diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..84132977 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +src/duckdb \ No newline at end of file diff --git a/binding.gyp b/binding.gyp new file mode 100644 index 00000000..2ceccc0d --- /dev/null +++ b/binding.gyp @@ -0,0 +1,312 @@ +{ + "targets": [ + { + "target_name": "<(module_name)", + "sources": [ + "src/duckdb_node.cpp", + "src/database.cpp", + "src/data_chunk.cpp", + "src/connection.cpp", + "src/statement.cpp", + "src/utils.cpp", + "src/duckdb/ub_src_catalog.cpp", + "src/duckdb/ub_src_catalog_catalog_entry.cpp", + "src/duckdb/ub_src_catalog_default.cpp", + "src/duckdb/ub_src_common_adbc.cpp", + "src/duckdb/ub_src_common_adbc_nanoarrow.cpp", + "src/duckdb/ub_src_common.cpp", + "src/duckdb/ub_src_common_arrow_appender.cpp", + "src/duckdb/ub_src_common_arrow.cpp", + "src/duckdb/ub_src_common_crypto.cpp", + "src/duckdb/ub_src_common_enums.cpp", + "src/duckdb/ub_src_common_operator.cpp", + "src/duckdb/ub_src_common_progress_bar.cpp", + "src/duckdb/ub_src_common_row_operations.cpp", + "src/duckdb/ub_src_common_serializer.cpp", + "src/duckdb/ub_src_common_sort.cpp", + "src/duckdb/ub_src_common_types.cpp", + "src/duckdb/ub_src_common_types_column.cpp", + "src/duckdb/ub_src_common_types_row.cpp", + "src/duckdb/ub_src_common_value_operations.cpp", + "src/duckdb/src/common/vector_operations/boolean_operators.cpp", + "src/duckdb/src/common/vector_operations/comparison_operators.cpp", + "src/duckdb/src/common/vector_operations/generators.cpp", + "src/duckdb/src/common/vector_operations/is_distinct_from.cpp", + "src/duckdb/src/common/vector_operations/null_operations.cpp", + "src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp", + "src/duckdb/src/common/vector_operations/vector_cast.cpp", + "src/duckdb/src/common/vector_operations/vector_copy.cpp", + "src/duckdb/src/common/vector_operations/vector_hash.cpp", + "src/duckdb/src/common/vector_operations/vector_storage.cpp", + "src/duckdb/ub_src_core_functions_aggregate_algebraic.cpp", + "src/duckdb/ub_src_core_functions_aggregate_distributive.cpp", + "src/duckdb/ub_src_core_functions_aggregate_holistic.cpp", + "src/duckdb/ub_src_core_functions_aggregate_nested.cpp", + "src/duckdb/ub_src_core_functions_aggregate_regression.cpp", + "src/duckdb/ub_src_core_functions.cpp", + "src/duckdb/ub_src_core_functions_scalar_bit.cpp", + "src/duckdb/ub_src_core_functions_scalar_blob.cpp", + "src/duckdb/ub_src_core_functions_scalar_date.cpp", + "src/duckdb/ub_src_core_functions_scalar_debug.cpp", + "src/duckdb/ub_src_core_functions_scalar_enum.cpp", + "src/duckdb/ub_src_core_functions_scalar_generic.cpp", + "src/duckdb/ub_src_core_functions_scalar_list.cpp", + "src/duckdb/ub_src_core_functions_scalar_map.cpp", + "src/duckdb/ub_src_core_functions_scalar_math.cpp", + "src/duckdb/ub_src_core_functions_scalar_operators.cpp", + "src/duckdb/ub_src_core_functions_scalar_random.cpp", + "src/duckdb/ub_src_core_functions_scalar_string.cpp", + "src/duckdb/ub_src_core_functions_scalar_struct.cpp", + "src/duckdb/ub_src_core_functions_scalar_union.cpp", + "src/duckdb/ub_src_execution.cpp", + "src/duckdb/ub_src_execution_expression_executor.cpp", + "src/duckdb/ub_src_execution_index_art.cpp", + "src/duckdb/ub_src_execution_index.cpp", + "src/duckdb/ub_src_execution_nested_loop_join.cpp", + "src/duckdb/ub_src_execution_operator_aggregate.cpp", + "src/duckdb/ub_src_execution_operator_csv_scanner.cpp", + "src/duckdb/ub_src_execution_operator_csv_scanner_sniffer.cpp", + "src/duckdb/ub_src_execution_operator_filter.cpp", + "src/duckdb/ub_src_execution_operator_helper.cpp", + "src/duckdb/ub_src_execution_operator_join.cpp", + "src/duckdb/ub_src_execution_operator_order.cpp", + "src/duckdb/ub_src_execution_operator_persistent.cpp", + "src/duckdb/ub_src_execution_operator_projection.cpp", + "src/duckdb/ub_src_execution_operator_scan.cpp", + "src/duckdb/ub_src_execution_operator_schema.cpp", + "src/duckdb/ub_src_execution_operator_set.cpp", + "src/duckdb/ub_src_execution_physical_plan.cpp", + "src/duckdb/ub_src_function_aggregate_distributive.cpp", + "src/duckdb/ub_src_function_aggregate.cpp", + "src/duckdb/ub_src_function.cpp", + "src/duckdb/ub_src_function_cast.cpp", + "src/duckdb/ub_src_function_cast_union.cpp", + "src/duckdb/ub_src_function_pragma.cpp", + "src/duckdb/ub_src_function_scalar_compressed_materialization.cpp", + "src/duckdb/ub_src_function_scalar.cpp", + "src/duckdb/ub_src_function_scalar_generic.cpp", + "src/duckdb/ub_src_function_scalar_list.cpp", + "src/duckdb/ub_src_function_scalar_operators.cpp", + "src/duckdb/ub_src_function_scalar_sequence.cpp", + "src/duckdb/ub_src_function_scalar_string.cpp", + "src/duckdb/ub_src_function_scalar_string_regexp.cpp", + "src/duckdb/ub_src_function_scalar_struct.cpp", + "src/duckdb/ub_src_function_scalar_system.cpp", + "src/duckdb/ub_src_function_table_arrow.cpp", + "src/duckdb/ub_src_function_table.cpp", + "src/duckdb/ub_src_function_table_system.cpp", + "src/duckdb/ub_src_function_table_version.cpp", + "src/duckdb/ub_src_main.cpp", + "src/duckdb/ub_src_main_capi.cpp", + "src/duckdb/ub_src_main_capi_cast.cpp", + "src/duckdb/ub_src_main_chunk_scan_state.cpp", + "src/duckdb/ub_src_main_extension.cpp", + "src/duckdb/ub_src_main_relation.cpp", + "src/duckdb/ub_src_main_settings.cpp", + "src/duckdb/ub_src_optimizer.cpp", + "src/duckdb/ub_src_optimizer_compressed_materialization.cpp", + "src/duckdb/ub_src_optimizer_join_order.cpp", + "src/duckdb/ub_src_optimizer_matcher.cpp", + "src/duckdb/ub_src_optimizer_pullup.cpp", + "src/duckdb/ub_src_optimizer_pushdown.cpp", + "src/duckdb/ub_src_optimizer_rule.cpp", + "src/duckdb/ub_src_optimizer_statistics_expression.cpp", + "src/duckdb/ub_src_optimizer_statistics_operator.cpp", + "src/duckdb/ub_src_parallel.cpp", + "src/duckdb/ub_src_parser.cpp", + "src/duckdb/ub_src_parser_constraints.cpp", + "src/duckdb/ub_src_parser_expression.cpp", + "src/duckdb/ub_src_parser_parsed_data.cpp", + "src/duckdb/ub_src_parser_query_node.cpp", + "src/duckdb/ub_src_parser_statement.cpp", + "src/duckdb/ub_src_parser_tableref.cpp", + "src/duckdb/ub_src_parser_transform_constraint.cpp", + "src/duckdb/ub_src_parser_transform_expression.cpp", + "src/duckdb/ub_src_parser_transform_helpers.cpp", + "src/duckdb/ub_src_parser_transform_statement.cpp", + "src/duckdb/ub_src_parser_transform_tableref.cpp", + "src/duckdb/ub_src_planner.cpp", + "src/duckdb/ub_src_planner_binder_expression.cpp", + "src/duckdb/ub_src_planner_binder_query_node.cpp", + "src/duckdb/ub_src_planner_binder_statement.cpp", + "src/duckdb/ub_src_planner_binder_tableref.cpp", + "src/duckdb/ub_src_planner_expression.cpp", + "src/duckdb/ub_src_planner_expression_binder.cpp", + "src/duckdb/ub_src_planner_filter.cpp", + "src/duckdb/ub_src_planner_operator.cpp", + "src/duckdb/ub_src_planner_subquery.cpp", + "src/duckdb/ub_src_storage.cpp", + "src/duckdb/ub_src_storage_buffer.cpp", + "src/duckdb/ub_src_storage_checkpoint.cpp", + "src/duckdb/ub_src_storage_compression.cpp", + "src/duckdb/ub_src_storage_compression_chimp.cpp", + "src/duckdb/ub_src_storage_metadata.cpp", + "src/duckdb/ub_src_storage_serialization.cpp", + "src/duckdb/ub_src_storage_statistics.cpp", + "src/duckdb/ub_src_storage_table.cpp", + "src/duckdb/ub_src_transaction.cpp", + "src/duckdb/src/verification/copied_statement_verifier.cpp", + "src/duckdb/src/verification/deserialized_statement_verifier.cpp", + "src/duckdb/src/verification/external_statement_verifier.cpp", + "src/duckdb/src/verification/no_operator_caching_verifier.cpp", + "src/duckdb/src/verification/parsed_statement_verifier.cpp", + "src/duckdb/src/verification/prepared_statement_verifier.cpp", + "src/duckdb/src/verification/statement_verifier.cpp", + "src/duckdb/src/verification/unoptimized_statement_verifier.cpp", + "src/duckdb/third_party/fmt/format.cc", + "src/duckdb/third_party/fsst/fsst_avx512.cpp", + "src/duckdb/third_party/fsst/libfsst.cpp", + "src/duckdb/third_party/miniz/miniz.cpp", + "src/duckdb/third_party/re2/re2/bitstate.cc", + "src/duckdb/third_party/re2/re2/compile.cc", + "src/duckdb/third_party/re2/re2/dfa.cc", + "src/duckdb/third_party/re2/re2/filtered_re2.cc", + "src/duckdb/third_party/re2/re2/mimics_pcre.cc", + "src/duckdb/third_party/re2/re2/nfa.cc", + "src/duckdb/third_party/re2/re2/onepass.cc", + "src/duckdb/third_party/re2/re2/parse.cc", + "src/duckdb/third_party/re2/re2/perl_groups.cc", + "src/duckdb/third_party/re2/re2/prefilter.cc", + "src/duckdb/third_party/re2/re2/prefilter_tree.cc", + "src/duckdb/third_party/re2/re2/prog.cc", + "src/duckdb/third_party/re2/re2/re2.cc", + "src/duckdb/third_party/re2/re2/regexp.cc", + "src/duckdb/third_party/re2/re2/set.cc", + "src/duckdb/third_party/re2/re2/simplify.cc", + "src/duckdb/third_party/re2/re2/stringpiece.cc", + "src/duckdb/third_party/re2/re2/tostring.cc", + "src/duckdb/third_party/re2/re2/unicode_casefold.cc", + "src/duckdb/third_party/re2/re2/unicode_groups.cc", + "src/duckdb/third_party/re2/util/rune.cc", + "src/duckdb/third_party/re2/util/strutil.cc", + "src/duckdb/third_party/hyperloglog/hyperloglog.cpp", + "src/duckdb/third_party/hyperloglog/sds.cpp", + "src/duckdb/third_party/fastpforlib/bitpacking.cpp", + "src/duckdb/third_party/utf8proc/utf8proc.cpp", + "src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp", + "src/duckdb/third_party/libpg_query/pg_functions.cpp", + "src/duckdb/third_party/libpg_query/postgres_parser.cpp", + "src/duckdb/third_party/libpg_query/src_backend_nodes_list.cpp", + "src/duckdb/third_party/libpg_query/src_backend_nodes_makefuncs.cpp", + "src/duckdb/third_party/libpg_query/src_backend_nodes_value.cpp", + "src/duckdb/third_party/libpg_query/src_backend_parser_gram.cpp", + "src/duckdb/third_party/libpg_query/src_backend_parser_parser.cpp", + "src/duckdb/third_party/libpg_query/src_backend_parser_scan.cpp", + "src/duckdb/third_party/libpg_query/src_backend_parser_scansup.cpp", + "src/duckdb/third_party/libpg_query/src_common_keywords.cpp", + "src/duckdb/third_party/mbedtls/library/asn1parse.cpp", + "src/duckdb/third_party/mbedtls/library/base64.cpp", + "src/duckdb/third_party/mbedtls/library/bignum.cpp", + "src/duckdb/third_party/mbedtls/library/constant_time.cpp", + "src/duckdb/third_party/mbedtls/library/md.cpp", + "src/duckdb/third_party/mbedtls/library/oid.cpp", + "src/duckdb/third_party/mbedtls/library/pem.cpp", + "src/duckdb/third_party/mbedtls/library/pk.cpp", + "src/duckdb/third_party/mbedtls/library/pk_wrap.cpp", + "src/duckdb/third_party/mbedtls/library/pkparse.cpp", + "src/duckdb/third_party/mbedtls/library/platform_util.cpp", + "src/duckdb/third_party/mbedtls/library/rsa.cpp", + "src/duckdb/third_party/mbedtls/library/rsa_alt_helpers.cpp", + "src/duckdb/third_party/mbedtls/library/sha1.cpp", + "src/duckdb/third_party/mbedtls/library/sha256.cpp", + "src/duckdb/third_party/mbedtls/library/sha512.cpp", + "src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp" + ], + "include_dirs": [ + " + +namespace duckdb { + +Catalog::Catalog(AttachedDatabase &db) : db(db) { +} + +Catalog::~Catalog() { +} + +DatabaseInstance &Catalog::GetDatabase() { + return db.GetDatabase(); +} + +AttachedDatabase &Catalog::GetAttached() { + return db; +} + +const string &Catalog::GetName() { + return GetAttached().GetName(); +} + +idx_t Catalog::GetOid() { + return GetAttached().oid; +} + +Catalog &Catalog::GetSystemCatalog(ClientContext &context) { + return Catalog::GetSystemCatalog(*context.db); +} + +optional_ptr Catalog::GetCatalogEntry(ClientContext &context, const string &catalog_name) { + auto &db_manager = DatabaseManager::Get(context); + if (catalog_name == TEMP_CATALOG) { + return &ClientData::Get(context).temporary_objects->GetCatalog(); + } + if (catalog_name == SYSTEM_CATALOG) { + return &GetSystemCatalog(context); + } + auto entry = db_manager.GetDatabase( + context, IsInvalidCatalog(catalog_name) ? DatabaseManager::GetDefaultDatabase(context) : catalog_name); + if (!entry) { + return nullptr; + } + return &entry->GetCatalog(); +} + +Catalog &Catalog::GetCatalog(ClientContext &context, const string &catalog_name) { + auto catalog = Catalog::GetCatalogEntry(context, catalog_name); + if (!catalog) { + throw BinderException("Catalog \"%s\" does not exist!", catalog_name); + } + return *catalog; +} + +//===--------------------------------------------------------------------===// +// Schema +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateSchema(ClientContext &context, CreateSchemaInfo &info) { + return CreateSchema(GetCatalogTransaction(context), info); +} + +CatalogTransaction Catalog::GetCatalogTransaction(ClientContext &context) { + return CatalogTransaction(*this, context); +} + +//===--------------------------------------------------------------------===// +// Table +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateTable(ClientContext &context, BoundCreateTableInfo &info) { + return CreateTable(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateTable(ClientContext &context, unique_ptr info) { + auto binder = Binder::CreateBinder(context); + auto bound_info = binder->BindCreateTableInfo(std::move(info)); + return CreateTable(context, *bound_info); +} + +optional_ptr Catalog::CreateTable(CatalogTransaction transaction, SchemaCatalogEntry &schema, + BoundCreateTableInfo &info) { + return schema.CreateTable(transaction, info); +} + +optional_ptr Catalog::CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) { + auto &schema = GetSchema(transaction, info.base->schema); + return CreateTable(transaction, schema, info); +} + +//===--------------------------------------------------------------------===// +// View +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateView(CatalogTransaction transaction, CreateViewInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreateView(transaction, schema, info); +} + +optional_ptr Catalog::CreateView(ClientContext &context, CreateViewInfo &info) { + return CreateView(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateView(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateViewInfo &info) { + return schema.CreateView(transaction, info); +} + +//===--------------------------------------------------------------------===// +// Sequence +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreateSequence(transaction, schema, info); +} + +optional_ptr Catalog::CreateSequence(ClientContext &context, CreateSequenceInfo &info) { + return CreateSequence(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateSequence(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateSequenceInfo &info) { + return schema.CreateSequence(transaction, info); +} + +//===--------------------------------------------------------------------===// +// Type +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateType(CatalogTransaction transaction, CreateTypeInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreateType(transaction, schema, info); +} + +optional_ptr Catalog::CreateType(ClientContext &context, CreateTypeInfo &info) { + return CreateType(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateType(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateTypeInfo &info) { + return schema.CreateType(transaction, info); +} + +//===--------------------------------------------------------------------===// +// Table Function +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateTableFunction(CatalogTransaction transaction, CreateTableFunctionInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreateTableFunction(transaction, schema, info); +} + +optional_ptr Catalog::CreateTableFunction(ClientContext &context, CreateTableFunctionInfo &info) { + return CreateTableFunction(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateTableFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateTableFunctionInfo &info) { + return schema.CreateTableFunction(transaction, info); +} + +optional_ptr Catalog::CreateTableFunction(ClientContext &context, + optional_ptr info) { + return CreateTableFunction(context, *info); +} + +//===--------------------------------------------------------------------===// +// Copy Function +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateCopyFunction(CatalogTransaction transaction, CreateCopyFunctionInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreateCopyFunction(transaction, schema, info); +} + +optional_ptr Catalog::CreateCopyFunction(ClientContext &context, CreateCopyFunctionInfo &info) { + return CreateCopyFunction(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateCopyFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateCopyFunctionInfo &info) { + return schema.CreateCopyFunction(transaction, info); +} + +//===--------------------------------------------------------------------===// +// Pragma Function +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreatePragmaFunction(CatalogTransaction transaction, + CreatePragmaFunctionInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreatePragmaFunction(transaction, schema, info); +} + +optional_ptr Catalog::CreatePragmaFunction(ClientContext &context, CreatePragmaFunctionInfo &info) { + return CreatePragmaFunction(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreatePragmaFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreatePragmaFunctionInfo &info) { + return schema.CreatePragmaFunction(transaction, info); +} + +//===--------------------------------------------------------------------===// +// Function +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreateFunction(transaction, schema, info); +} + +optional_ptr Catalog::CreateFunction(ClientContext &context, CreateFunctionInfo &info) { + return CreateFunction(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateFunctionInfo &info) { + return schema.CreateFunction(transaction, info); +} + +optional_ptr Catalog::AddFunction(ClientContext &context, CreateFunctionInfo &info) { + info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; + return CreateFunction(context, info); +} + +//===--------------------------------------------------------------------===// +// Collation +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) { + auto &schema = GetSchema(transaction, info.schema); + return CreateCollation(transaction, schema, info); +} + +optional_ptr Catalog::CreateCollation(ClientContext &context, CreateCollationInfo &info) { + return CreateCollation(GetCatalogTransaction(context), info); +} + +optional_ptr Catalog::CreateCollation(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateCollationInfo &info) { + return schema.CreateCollation(transaction, info); +} + +//===--------------------------------------------------------------------===// +// Index +//===--------------------------------------------------------------------===// +optional_ptr Catalog::CreateIndex(CatalogTransaction transaction, CreateIndexInfo &info) { + auto &context = transaction.GetContext(); + return CreateIndex(context, info); +} + +optional_ptr Catalog::CreateIndex(ClientContext &context, CreateIndexInfo &info) { + auto &schema = GetSchema(context, info.schema); + auto &table = GetEntry(context, schema.name, info.table); + return schema.CreateIndex(context, info, table); +} + +//===--------------------------------------------------------------------===// +// Lookup Structures +//===--------------------------------------------------------------------===// +struct CatalogLookup { + CatalogLookup(Catalog &catalog, string schema_p) : catalog(catalog), schema(std::move(schema_p)) { + } + + Catalog &catalog; + string schema; +}; + +//! Return value of Catalog::LookupEntry +struct CatalogEntryLookup { + optional_ptr schema; + optional_ptr entry; + PreservedError error; + + DUCKDB_API bool Found() const { + return entry; + } +}; + +//===--------------------------------------------------------------------===// +// Generic +//===--------------------------------------------------------------------===// +void Catalog::DropEntry(ClientContext &context, DropInfo &info) { + ModifyCatalog(); + if (info.type == CatalogType::SCHEMA_ENTRY) { + // DROP SCHEMA + DropSchema(context, info); + return; + } + + auto lookup = LookupEntry(context, info.type, info.schema, info.name, info.if_not_found); + + if (!lookup.Found()) { + return; + } + + lookup.schema->DropEntry(context, info); +} + +SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &name, QueryErrorContext error_context) { + return *Catalog::GetSchema(context, name, OnEntryNotFound::THROW_EXCEPTION, error_context); +} + +optional_ptr Catalog::GetSchema(ClientContext &context, const string &schema_name, + OnEntryNotFound if_not_found, QueryErrorContext error_context) { + return GetSchema(GetCatalogTransaction(context), schema_name, if_not_found, error_context); +} + +SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &catalog_name, const string &schema_name, + QueryErrorContext error_context) { + return *Catalog::GetSchema(context, catalog_name, schema_name, OnEntryNotFound::THROW_EXCEPTION, error_context); +} + +SchemaCatalogEntry &Catalog::GetSchema(CatalogTransaction transaction, const string &name, + QueryErrorContext error_context) { + return *GetSchema(transaction, name, OnEntryNotFound::THROW_EXCEPTION, error_context); +} + +//===--------------------------------------------------------------------===// +// Lookup +//===--------------------------------------------------------------------===// +SimilarCatalogEntry Catalog::SimilarEntryInSchemas(ClientContext &context, const string &entry_name, CatalogType type, + const reference_set_t &schemas) { + SimilarCatalogEntry result; + for (auto schema_ref : schemas) { + auto &schema = schema_ref.get(); + auto transaction = schema.catalog.GetCatalogTransaction(context); + auto entry = schema.GetSimilarEntry(transaction, type, entry_name); + if (!entry.Found()) { + // no similar entry found + continue; + } + if (!result.Found() || result.distance > entry.distance) { + result = entry; + result.schema = &schema; + } + } + return result; +} + +vector GetCatalogEntries(ClientContext &context, const string &catalog, const string &schema) { + vector entries; + auto &search_path = *context.client_data->catalog_search_path; + if (IsInvalidCatalog(catalog) && IsInvalidSchema(schema)) { + // no catalog or schema provided - scan the entire search path + entries = search_path.Get(); + } else if (IsInvalidCatalog(catalog)) { + auto catalogs = search_path.GetCatalogsForSchema(schema); + for (auto &catalog_name : catalogs) { + entries.emplace_back(catalog_name, schema); + } + if (entries.empty()) { + entries.emplace_back(DatabaseManager::GetDefaultDatabase(context), schema); + } + } else if (IsInvalidSchema(schema)) { + auto schemas = search_path.GetSchemasForCatalog(catalog); + for (auto &schema_name : schemas) { + entries.emplace_back(catalog, schema_name); + } + if (entries.empty()) { + entries.emplace_back(catalog, DEFAULT_SCHEMA); + } + } else { + // specific catalog and schema provided + entries.emplace_back(catalog, schema); + } + return entries; +} + +void FindMinimalQualification(ClientContext &context, const string &catalog_name, const string &schema_name, + bool &qualify_database, bool &qualify_schema) { + // check if we can we qualify ONLY the schema + bool found = false; + auto entries = GetCatalogEntries(context, INVALID_CATALOG, schema_name); + for (auto &entry : entries) { + if (entry.catalog == catalog_name && entry.schema == schema_name) { + found = true; + break; + } + } + if (found) { + qualify_database = false; + qualify_schema = true; + return; + } + // check if we can qualify ONLY the catalog + found = false; + entries = GetCatalogEntries(context, catalog_name, INVALID_SCHEMA); + for (auto &entry : entries) { + if (entry.catalog == catalog_name && entry.schema == schema_name) { + found = true; + break; + } + } + if (found) { + qualify_database = true; + qualify_schema = false; + return; + } + // need to qualify both catalog and schema + qualify_database = true; + qualify_schema = true; +} + +bool Catalog::TryAutoLoad(ClientContext &context, const string &original_name) noexcept { + string extension_name = ExtensionHelper::ApplyExtensionAlias(original_name); + if (context.db->ExtensionIsLoaded(extension_name)) { + return true; + } +#ifndef DUCKDB_DISABLE_EXTENSION_LOAD + auto &dbconfig = DBConfig::GetConfig(context); + if (!dbconfig.options.autoload_known_extensions) { + return false; + } + try { + if (ExtensionHelper::CanAutoloadExtension(extension_name)) { + return ExtensionHelper::TryAutoLoadExtension(context, extension_name); + } + } catch (...) { + return false; + } +#endif + return false; +} + +void Catalog::AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name) { +#ifndef DUCKDB_DISABLE_EXTENSION_LOAD + auto &dbconfig = DBConfig::GetConfig(context); + if (dbconfig.options.autoload_known_extensions) { + auto extension_name = ExtensionHelper::FindExtensionInEntries(configuration_name, EXTENSION_SETTINGS); + if (ExtensionHelper::CanAutoloadExtension(extension_name)) { + ExtensionHelper::AutoLoadExtension(context, extension_name); + return; + } + } +#endif + + throw Catalog::UnrecognizedConfigurationError(context, configuration_name); +} + +bool Catalog::AutoLoadExtensionByCatalogEntry(ClientContext &context, CatalogType type, const string &entry_name) { +#ifndef DUCKDB_DISABLE_EXTENSION_LOAD + auto &dbconfig = DBConfig::GetConfig(context); + if (dbconfig.options.autoload_known_extensions) { + string extension_name; + if (type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::SCALAR_FUNCTION_ENTRY || + type == CatalogType::AGGREGATE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_FUNCTIONS); + } else if (type == CatalogType::COPY_FUNCTION_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COPY_FUNCTIONS); + } else if (type == CatalogType::TYPE_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_TYPES); + } else if (type == CatalogType::COLLATION_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COLLATIONS); + } + + if (!extension_name.empty() && ExtensionHelper::CanAutoloadExtension(extension_name)) { + ExtensionHelper::AutoLoadExtension(context, extension_name); + return true; + } + } +#endif + + return false; +} + +CatalogException Catalog::UnrecognizedConfigurationError(ClientContext &context, const string &name) { + // check if the setting exists in any extensions + auto extension_name = ExtensionHelper::FindExtensionInEntries(name, EXTENSION_SETTINGS); + if (!extension_name.empty()) { + auto error_message = "Setting with name \"" + name + "\" is not in the catalog, but it exists in the " + + extension_name + " extension."; + error_message = ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, extension_name); + return CatalogException(error_message); + } + // the setting is not in an extension + // get a list of all options + vector potential_names = DBConfig::GetOptionNames(); + for (auto &entry : DBConfig::GetConfig(context).extension_parameters) { + potential_names.push_back(entry.first); + } + + throw CatalogException("unrecognized configuration parameter \"%s\"\n%s", name, + StringUtil::CandidatesErrorMessage(potential_names, name, "Did you mean")); +} + +CatalogException Catalog::CreateMissingEntryException(ClientContext &context, const string &entry_name, + CatalogType type, + const reference_set_t &schemas, + QueryErrorContext error_context) { + auto entry = SimilarEntryInSchemas(context, entry_name, type, schemas); + + reference_set_t unseen_schemas; + auto &db_manager = DatabaseManager::Get(context); + auto databases = db_manager.GetDatabases(context); + for (auto database : databases) { + auto &catalog = database.get().GetCatalog(); + auto current_schemas = catalog.GetAllSchemas(context); + for (auto ¤t_schema : current_schemas) { + unseen_schemas.insert(current_schema.get()); + } + } + // check if the entry exists in any extension + string extension_name; + if (type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::SCALAR_FUNCTION_ENTRY || + type == CatalogType::AGGREGATE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_FUNCTIONS); + } else if (type == CatalogType::TYPE_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_TYPES); + } else if (type == CatalogType::COPY_FUNCTION_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COPY_FUNCTIONS); + } else if (type == CatalogType::COLLATION_ENTRY) { + extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COLLATIONS); + } + + // if we found an extension that can handle this catalog entry, create an error hinting the user + if (!extension_name.empty()) { + auto error_message = CatalogTypeToString(type) + " with name \"" + entry_name + + "\" is not in the catalog, but it exists in the " + extension_name + " extension."; + error_message = ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, extension_name); + return CatalogException(error_message); + } + + auto unseen_entry = SimilarEntryInSchemas(context, entry_name, type, unseen_schemas); + string did_you_mean; + if (unseen_entry.Found() && unseen_entry.distance < entry.distance) { + // the closest matching entry requires qualification as it is not in the default search path + // check how to minimally qualify this entry + auto catalog_name = unseen_entry.schema->catalog.GetName(); + auto schema_name = unseen_entry.schema->name; + bool qualify_database; + bool qualify_schema; + FindMinimalQualification(context, catalog_name, schema_name, qualify_database, qualify_schema); + did_you_mean = "\nDid you mean \"" + unseen_entry.GetQualifiedName(qualify_database, qualify_schema) + "\"?"; + } else if (entry.Found()) { + did_you_mean = "\nDid you mean \"" + entry.name + "\"?"; + } + + return CatalogException(error_context.FormatError("%s with name %s does not exist!%s", CatalogTypeToString(type), + entry_name, did_you_mean)); +} + +CatalogEntryLookup Catalog::TryLookupEntryInternal(CatalogTransaction transaction, CatalogType type, + const string &schema, const string &name) { + auto schema_entry = GetSchema(transaction, schema, OnEntryNotFound::RETURN_NULL); + if (!schema_entry) { + return {nullptr, nullptr, PreservedError()}; + } + auto entry = schema_entry->GetEntry(transaction, type, name); + if (!entry) { + return {schema_entry, nullptr, PreservedError()}; + } + return {schema_entry, entry, PreservedError()}; +} + +CatalogEntryLookup Catalog::TryLookupEntry(ClientContext &context, CatalogType type, const string &schema, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context) { + reference_set_t schemas; + if (IsInvalidSchema(schema)) { + // try all schemas for this catalog + auto entries = GetCatalogEntries(context, GetName(), INVALID_SCHEMA); + for (auto &entry : entries) { + auto &candidate_schema = entry.schema; + auto transaction = GetCatalogTransaction(context); + auto result = TryLookupEntryInternal(transaction, type, candidate_schema, name); + if (result.Found()) { + return result; + } + if (result.schema) { + schemas.insert(*result.schema); + } + } + } else { + auto transaction = GetCatalogTransaction(context); + auto result = TryLookupEntryInternal(transaction, type, schema, name); + if (result.Found()) { + return result; + } + if (result.schema) { + schemas.insert(*result.schema); + } + } + + if (if_not_found == OnEntryNotFound::RETURN_NULL) { + return {nullptr, nullptr, PreservedError()}; + } else { + auto except = CreateMissingEntryException(context, name, type, schemas, error_context); + return {nullptr, nullptr, PreservedError(except)}; + } +} + +CatalogEntryLookup Catalog::LookupEntry(ClientContext &context, CatalogType type, const string &schema, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context) { + auto res = TryLookupEntry(context, type, schema, name, if_not_found, error_context); + + if (res.error) { + res.error.Throw(); + } + + return res; +} + +CatalogEntryLookup Catalog::TryLookupEntry(ClientContext &context, vector &lookups, CatalogType type, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context) { + reference_set_t schemas; + for (auto &lookup : lookups) { + auto transaction = lookup.catalog.GetCatalogTransaction(context); + auto result = lookup.catalog.TryLookupEntryInternal(transaction, type, lookup.schema, name); + if (result.Found()) { + return result; + } + if (result.schema) { + schemas.insert(*result.schema); + } + } + + if (if_not_found == OnEntryNotFound::RETURN_NULL) { + return {nullptr, nullptr, PreservedError()}; + } else { + auto except = CreateMissingEntryException(context, name, type, schemas, error_context); + return {nullptr, nullptr, PreservedError(except)}; + } +} + +CatalogEntryLookup Catalog::TryLookupEntry(ClientContext &context, CatalogType type, const string &catalog, + const string &schema, const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context) { + auto entries = GetCatalogEntries(context, catalog, schema); + vector lookups; + lookups.reserve(entries.size()); + for (auto &entry : entries) { + if (if_not_found == OnEntryNotFound::RETURN_NULL) { + auto catalog_entry = Catalog::GetCatalogEntry(context, entry.catalog); + if (!catalog_entry) { + return {nullptr, nullptr, PreservedError()}; + } + lookups.emplace_back(*catalog_entry, entry.schema); + } else { + lookups.emplace_back(Catalog::GetCatalog(context, entry.catalog), entry.schema); + } + } + return Catalog::TryLookupEntry(context, lookups, type, name, if_not_found, error_context); +} + +CatalogEntry &Catalog::GetEntry(ClientContext &context, const string &schema, const string &name) { + vector entry_types {CatalogType::TABLE_ENTRY, CatalogType::SEQUENCE_ENTRY}; + + for (auto entry_type : entry_types) { + auto result = GetEntry(context, entry_type, schema, name, OnEntryNotFound::RETURN_NULL); + if (result) { + return *result; + } + } + + throw CatalogException("CatalogElement \"%s.%s\" does not exist!", schema, name); +} + +optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType type, const string &schema_name, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context) { + auto lookup_entry = TryLookupEntry(context, type, schema_name, name, if_not_found, error_context); + + // Try autoloading extension to resolve lookup + if (!lookup_entry.Found()) { + if (AutoLoadExtensionByCatalogEntry(context, type, name)) { + lookup_entry = TryLookupEntry(context, type, schema_name, name, if_not_found, error_context); + } + } + + if (lookup_entry.error) { + lookup_entry.error.Throw(); + } + + return lookup_entry.entry.get(); +} + +CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType type, const string &schema, const string &name, + QueryErrorContext error_context) { + return *Catalog::GetEntry(context, type, schema, name, OnEntryNotFound::THROW_EXCEPTION, error_context); +} + +optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType type, const string &catalog, + const string &schema, const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context) { + auto result = TryLookupEntry(context, type, catalog, schema, name, if_not_found, error_context); + + // Try autoloading extension to resolve lookup + if (!result.Found()) { + if (AutoLoadExtensionByCatalogEntry(context, type, name)) { + result = TryLookupEntry(context, type, catalog, schema, name, if_not_found, error_context); + } + } + + if (result.error) { + result.error.Throw(); + } + + if (!result.Found()) { + D_ASSERT(if_not_found == OnEntryNotFound::RETURN_NULL); + return nullptr; + } + return result.entry.get(); +} + +CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType type, const string &catalog, const string &schema, + const string &name, QueryErrorContext error_context) { + return *Catalog::GetEntry(context, type, catalog, schema, name, OnEntryNotFound::THROW_EXCEPTION, error_context); +} + +optional_ptr Catalog::GetSchema(ClientContext &context, const string &catalog_name, + const string &schema_name, OnEntryNotFound if_not_found, + QueryErrorContext error_context) { + auto entries = GetCatalogEntries(context, catalog_name, schema_name); + for (idx_t i = 0; i < entries.size(); i++) { + auto on_not_found = i + 1 == entries.size() ? if_not_found : OnEntryNotFound::RETURN_NULL; + auto &catalog = Catalog::GetCatalog(context, entries[i].catalog); + auto result = catalog.GetSchema(context, schema_name, on_not_found, error_context); + if (result) { + return result; + } + } + return nullptr; +} + +LogicalType Catalog::GetType(ClientContext &context, const string &schema, const string &name, + OnEntryNotFound if_not_found) { + auto type_entry = GetEntry(context, schema, name, if_not_found); + if (!type_entry) { + return LogicalType::INVALID; + } + return type_entry->user_type; +} + +LogicalType Catalog::GetType(ClientContext &context, const string &catalog_name, const string &schema, + const string &name) { + auto &type_entry = Catalog::GetEntry(context, catalog_name, schema, name); + return type_entry.user_type; +} + +vector> Catalog::GetSchemas(ClientContext &context) { + vector> schemas; + ScanSchemas(context, [&](SchemaCatalogEntry &entry) { schemas.push_back(entry); }); + return schemas; +} + +vector> Catalog::GetSchemas(ClientContext &context, const string &catalog_name) { + vector> catalogs; + if (IsInvalidCatalog(catalog_name)) { + reference_set_t inserted_catalogs; + + auto &search_path = *context.client_data->catalog_search_path; + for (auto &entry : search_path.Get()) { + auto &catalog = Catalog::GetCatalog(context, entry.catalog); + if (inserted_catalogs.find(catalog) != inserted_catalogs.end()) { + continue; + } + inserted_catalogs.insert(catalog); + catalogs.push_back(catalog); + } + } else { + catalogs.push_back(Catalog::GetCatalog(context, catalog_name)); + } + vector> result; + for (auto catalog : catalogs) { + auto schemas = catalog.get().GetSchemas(context); + result.insert(result.end(), schemas.begin(), schemas.end()); + } + return result; +} + +vector> Catalog::GetAllSchemas(ClientContext &context) { + vector> result; + + auto &db_manager = DatabaseManager::Get(context); + auto databases = db_manager.GetDatabases(context); + for (auto database : databases) { + auto &catalog = database.get().GetCatalog(); + auto new_schemas = catalog.GetSchemas(context); + result.insert(result.end(), new_schemas.begin(), new_schemas.end()); + } + sort(result.begin(), result.end(), + [&](reference left_p, reference right_p) { + auto &left = left_p.get(); + auto &right = right_p.get(); + if (left.catalog.GetName() < right.catalog.GetName()) { + return true; + } + if (left.catalog.GetName() == right.catalog.GetName()) { + return left.name < right.name; + } + return false; + }); + + return result; +} + +void Catalog::Alter(ClientContext &context, AlterInfo &info) { + ModifyCatalog(); + auto lookup = LookupEntry(context, info.GetCatalogType(), info.schema, info.name, info.if_not_found); + + if (!lookup.Found()) { + return; + } + return lookup.schema->Alter(context, info); +} + +vector Catalog::GetMetadataInfo(ClientContext &context) { + return vector(); +} + +void Catalog::Verify() { +} + +//===--------------------------------------------------------------------===// +// Catalog Version +//===--------------------------------------------------------------------===// +idx_t Catalog::GetCatalogVersion() { + return GetDatabase().GetDatabaseManager().catalog_version; +} + +idx_t Catalog::ModifyCatalog() { + return GetDatabase().GetDatabaseManager().ModifyCatalog(); +} + +bool Catalog::IsSystemCatalog() const { + return db.IsSystem(); +} + +bool Catalog::IsTemporaryCatalog() const { + return db.IsTemporary(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry.cpp new file mode 100644 index 00000000..844eaeaa --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry.cpp @@ -0,0 +1,75 @@ +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" + +namespace duckdb { + +CatalogEntry::CatalogEntry(CatalogType type, string name_p, idx_t oid) + : oid(oid), type(type), set(nullptr), name(std::move(name_p)), deleted(false), temporary(false), internal(false), + parent(nullptr) { +} + +CatalogEntry::CatalogEntry(CatalogType type, Catalog &catalog, string name_p) + : CatalogEntry(type, std::move(name_p), catalog.ModifyCatalog()) { +} + +CatalogEntry::~CatalogEntry() { +} + +void CatalogEntry::SetAsRoot() { +} + +// LCOV_EXCL_START +unique_ptr CatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { + throw InternalException("Unsupported alter type for catalog entry!"); +} + +void CatalogEntry::UndoAlter(ClientContext &context, AlterInfo &info) { +} + +unique_ptr CatalogEntry::Copy(ClientContext &context) const { + throw InternalException("Unsupported copy type for catalog entry!"); +} + +unique_ptr CatalogEntry::GetInfo() const { + throw InternalException("Unsupported type for CatalogEntry::GetInfo!"); +} + +string CatalogEntry::ToSQL() const { + throw InternalException("Unsupported catalog type for ToSQL()"); +} + +Catalog &CatalogEntry::ParentCatalog() { + throw InternalException("CatalogEntry::ParentCatalog called on catalog entry without catalog"); +} + +SchemaCatalogEntry &CatalogEntry::ParentSchema() { + throw InternalException("CatalogEntry::ParentSchema called on catalog entry without schema"); +} +// LCOV_EXCL_STOP + +void CatalogEntry::Serialize(Serializer &serializer) const { + const auto info = GetInfo(); + info->Serialize(serializer); +} + +unique_ptr CatalogEntry::Deserialize(Deserializer &deserializer) { + return CreateInfo::Deserialize(deserializer); +} + +void CatalogEntry::Verify(Catalog &catalog_p) { +} + +InCatalogEntry::InCatalogEntry(CatalogType type, Catalog &catalog, string name) + : CatalogEntry(type, catalog, std::move(name)), catalog(catalog) { +} + +InCatalogEntry::~InCatalogEntry() { +} + +void InCatalogEntry::Verify(Catalog &catalog_p) { + D_ASSERT(&catalog_p == &catalog); +} +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp b/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp new file mode 100644 index 00000000..e09f73df --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp @@ -0,0 +1,270 @@ +#include "duckdb/catalog/catalog_entry/column_dependency_manager.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/common/queue.hpp" + +namespace duckdb { + +ColumnDependencyManager::ColumnDependencyManager() { +} + +ColumnDependencyManager::~ColumnDependencyManager() { +} + +void ColumnDependencyManager::AddGeneratedColumn(const ColumnDefinition &column, const ColumnList &list) { + D_ASSERT(column.Generated()); + vector referenced_columns; + column.GetListOfDependencies(referenced_columns); + vector indices; + for (auto &col : referenced_columns) { + if (!list.ColumnExists(col)) { + throw BinderException("Column \"%s\" referenced by generated column does not exist", col); + } + auto &entry = list.GetColumn(col); + indices.push_back(entry.Logical()); + } + return AddGeneratedColumn(column.Logical(), indices); +} + +void ColumnDependencyManager::AddGeneratedColumn(LogicalIndex index, const vector &indices, bool root) { + if (indices.empty()) { + return; + } + auto &list = dependents_map[index]; + // Create a link between the dependencies + for (auto &dep : indices) { + // Add this column as a dependency of the new column + list.insert(dep); + // Add the new column as a dependent of the column + dependencies_map[dep].insert(index); + // Inherit the dependencies + if (HasDependencies(dep)) { + auto &inherited_deps = dependents_map[dep]; + D_ASSERT(!inherited_deps.empty()); + for (auto &inherited_dep : inherited_deps) { + list.insert(inherited_dep); + dependencies_map[inherited_dep].insert(index); + } + } + if (!root) { + continue; + } + direct_dependencies[index].insert(dep); + } + if (!HasDependents(index)) { + return; + } + auto &dependents = dependencies_map[index]; + if (dependents.count(index)) { + throw InvalidInputException("Circular dependency encountered when resolving generated column expressions"); + } + // Also let the dependents of this generated column inherit the dependencies + for (auto &dependent : dependents) { + AddGeneratedColumn(dependent, indices, false); + } +} + +vector ColumnDependencyManager::RemoveColumn(LogicalIndex index, idx_t column_amount) { + // Always add the initial column + deleted_columns.insert(index); + + RemoveGeneratedColumn(index); + RemoveStandardColumn(index); + + // Clean up the internal list + vector new_indices = CleanupInternals(column_amount); + D_ASSERT(deleted_columns.empty()); + return new_indices; +} + +bool ColumnDependencyManager::IsDependencyOf(LogicalIndex gcol, LogicalIndex col) const { + auto entry = dependents_map.find(gcol); + if (entry == dependents_map.end()) { + return false; + } + auto &list = entry->second; + return list.count(col); +} + +bool ColumnDependencyManager::HasDependencies(LogicalIndex index) const { + auto entry = dependents_map.find(index); + if (entry == dependents_map.end()) { + return false; + } + return true; +} + +const logical_index_set_t &ColumnDependencyManager::GetDependencies(LogicalIndex index) const { + auto entry = dependents_map.find(index); + D_ASSERT(entry != dependents_map.end()); + return entry->second; +} + +bool ColumnDependencyManager::HasDependents(LogicalIndex index) const { + auto entry = dependencies_map.find(index); + if (entry == dependencies_map.end()) { + return false; + } + return true; +} + +const logical_index_set_t &ColumnDependencyManager::GetDependents(LogicalIndex index) const { + auto entry = dependencies_map.find(index); + D_ASSERT(entry != dependencies_map.end()); + return entry->second; +} + +void ColumnDependencyManager::RemoveStandardColumn(LogicalIndex index) { + if (!HasDependents(index)) { + return; + } + auto dependents = dependencies_map[index]; + for (auto &gcol : dependents) { + // If index is a direct dependency of gcol, remove it from the list + if (direct_dependencies.find(gcol) != direct_dependencies.end()) { + direct_dependencies[gcol].erase(index); + } + RemoveGeneratedColumn(gcol); + } + // Remove this column from the dependencies map + dependencies_map.erase(index); +} + +void ColumnDependencyManager::RemoveGeneratedColumn(LogicalIndex index) { + deleted_columns.insert(index); + if (!HasDependencies(index)) { + return; + } + auto &dependencies = dependents_map[index]; + for (auto &col : dependencies) { + // Remove this generated column from the list of this column + auto &col_dependents = dependencies_map[col]; + D_ASSERT(col_dependents.count(index)); + col_dependents.erase(index); + // If the resulting list is empty, remove the column from the dependencies map altogether + if (col_dependents.empty()) { + dependencies_map.erase(col); + } + } + // Remove this column from the dependents_map map + dependents_map.erase(index); +} + +void ColumnDependencyManager::AdjustSingle(LogicalIndex idx, idx_t offset) { + D_ASSERT(idx.index >= offset); + LogicalIndex new_idx = LogicalIndex(idx.index - offset); + // Adjust this index in the dependents of this column + bool has_dependents = HasDependents(idx); + bool has_dependencies = HasDependencies(idx); + + if (has_dependents) { + auto &dependents = GetDependents(idx); + for (auto &dep : dependents) { + auto &dep_dependencies = dependents_map[dep]; + dep_dependencies.erase(idx); + D_ASSERT(!dep_dependencies.count(new_idx)); + dep_dependencies.insert(new_idx); + } + } + if (has_dependencies) { + auto &dependencies = GetDependencies(idx); + for (auto &dep : dependencies) { + auto &dep_dependents = dependencies_map[dep]; + dep_dependents.erase(idx); + D_ASSERT(!dep_dependents.count(new_idx)); + dep_dependents.insert(new_idx); + } + } + if (has_dependents) { + D_ASSERT(!dependencies_map.count(new_idx)); + dependencies_map[new_idx] = std::move(dependencies_map[idx]); + dependencies_map.erase(idx); + } + if (has_dependencies) { + D_ASSERT(!dependents_map.count(new_idx)); + dependents_map[new_idx] = std::move(dependents_map[idx]); + dependents_map.erase(idx); + } +} + +vector ColumnDependencyManager::CleanupInternals(idx_t column_amount) { + vector to_adjust; + D_ASSERT(!deleted_columns.empty()); + // Get the lowest index that was deleted + vector new_indices(column_amount, LogicalIndex(DConstants::INVALID_INDEX)); + idx_t threshold = deleted_columns.begin()->index; + + idx_t offset = 0; + for (idx_t i = 0; i < column_amount; i++) { + auto current_index = LogicalIndex(i); + auto new_index = LogicalIndex(i - offset); + new_indices[i] = new_index; + if (deleted_columns.count(current_index)) { + offset++; + continue; + } + if (i > threshold && (HasDependencies(current_index) || HasDependents(current_index))) { + to_adjust.push_back(current_index); + } + } + + // Adjust all indices inside the dependency managers internal mappings + for (auto &col : to_adjust) { + auto offset = col.index - new_indices[col.index].index; + AdjustSingle(col, offset); + } + deleted_columns.clear(); + return new_indices; +} + +stack ColumnDependencyManager::GetBindOrder(const ColumnList &columns) { + stack bind_order; + queue to_visit; + logical_index_set_t visited; + + for (auto &entry : direct_dependencies) { + auto dependent = entry.first; + //! Skip the dependents that are also dependencies + if (dependencies_map.find(dependent) != dependencies_map.end()) { + continue; + } + bind_order.push(dependent); + visited.insert(dependent); + for (auto &dependency : direct_dependencies[dependent]) { + to_visit.push(dependency); + } + } + + while (!to_visit.empty()) { + auto column = to_visit.front(); + to_visit.pop(); + + //! If this column does not have dependencies, the queue stops getting filled + if (direct_dependencies.find(column) == direct_dependencies.end()) { + continue; + } + bind_order.push(column); + visited.insert(column); + + for (auto &dependency : direct_dependencies[column]) { + to_visit.push(dependency); + } + } + + // Add generated columns that have no dependencies, but still might need to have their type resolved + for (auto &col : columns.Logical()) { + // Not a generated column + if (!col.Generated()) { + continue; + } + // Already added to the bind_order stack + if (visited.count(col.Logical())) { + continue; + } + bind_order.push(col.Logical()); + } + + return bind_order; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp new file mode 100644 index 00000000..25544a34 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp @@ -0,0 +1,11 @@ +#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" + +namespace duckdb { + +CopyFunctionCatalogEntry::CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, + CreateCopyFunctionInfo &info) + : StandardEntry(CatalogType::COPY_FUNCTION_ENTRY, schema, catalog, info.name), function(info.function) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp new file mode 100644 index 00000000..0848bc16 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp @@ -0,0 +1,32 @@ +#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/execution/index/art/art.hpp" + +namespace duckdb { + +DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info) + : IndexCatalogEntry(catalog, schema, info) { +} + +DuckIndexEntry::~DuckIndexEntry() { + // remove the associated index from the info + if (!info || !index) { + return; + } + info->indexes.RemoveIndex(*index); +} + +string DuckIndexEntry::GetSchemaName() const { + return info->schema; +} + +string DuckIndexEntry::GetTableName() const { + return info->table; +} + +void DuckIndexEntry::CommitDrop() { + D_ASSERT(info && index); + index->CommitDrop(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp new file mode 100644 index 00000000..f501a946 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp @@ -0,0 +1,342 @@ +#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" +#include "duckdb/catalog/default/default_functions.hpp" +#include "duckdb/catalog/default/default_types.hpp" +#include "duckdb/catalog/default/default_views.hpp" +#include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" +#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/dependency_list.hpp" +#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" +#include "duckdb/parser/constraints/foreign_key_constraint.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/parser/parsed_data/create_collation_info.hpp" +#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" + +namespace duckdb { + +void FindForeignKeyInformation(CatalogEntry &entry, AlterForeignKeyType alter_fk_type, + vector> &fk_arrays) { + if (entry.type != CatalogType::TABLE_ENTRY) { + return; + } + auto &table_entry = entry.Cast(); + auto &constraints = table_entry.GetConstraints(); + for (idx_t i = 0; i < constraints.size(); i++) { + auto &cond = constraints[i]; + if (cond->type != ConstraintType::FOREIGN_KEY) { + continue; + } + auto &fk = cond->Cast(); + if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { + AlterEntryData alter_data(entry.ParentCatalog().GetName(), fk.info.schema, fk.info.table, + OnEntryNotFound::THROW_EXCEPTION); + fk_arrays.push_back(make_uniq(std::move(alter_data), entry.name, fk.pk_columns, + fk.fk_columns, fk.info.pk_keys, fk.info.fk_keys, + alter_fk_type)); + } else if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && + alter_fk_type == AlterForeignKeyType::AFT_DELETE) { + throw CatalogException("Could not drop the table because this table is main key table of the table \"%s\"", + fk.info.table); + } + } +} + +DuckSchemaEntry::DuckSchemaEntry(Catalog &catalog, string name_p, bool is_internal) + : SchemaCatalogEntry(catalog, std::move(name_p), is_internal), + tables(catalog, make_uniq(catalog, *this)), indexes(catalog), table_functions(catalog), + copy_functions(catalog), pragma_functions(catalog), + functions(catalog, make_uniq(catalog, *this)), sequences(catalog), collations(catalog), + types(catalog, make_uniq(catalog, *this)) { +} + +optional_ptr DuckSchemaEntry::AddEntryInternal(CatalogTransaction transaction, + unique_ptr entry, + OnCreateConflict on_conflict, + DependencyList dependencies) { + auto entry_name = entry->name; + auto entry_type = entry->type; + auto result = entry.get(); + + // first find the set for this entry + auto &set = GetCatalogSet(entry_type); + dependencies.AddDependency(*this); + if (on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT) { + // CREATE OR REPLACE: first try to drop the entry + auto old_entry = set.GetEntry(transaction, entry_name); + if (old_entry) { + if (old_entry->type != entry_type) { + throw CatalogException("Existing object %s is of type %s, trying to replace with type %s", entry_name, + CatalogTypeToString(old_entry->type), CatalogTypeToString(entry_type)); + } + (void)set.DropEntry(transaction, entry_name, false, entry->internal); + } + } + // now try to add the entry + if (!set.CreateEntry(transaction, entry_name, std::move(entry), dependencies)) { + // entry already exists! + if (on_conflict == OnCreateConflict::ERROR_ON_CONFLICT) { + throw CatalogException("%s with name \"%s\" already exists!", CatalogTypeToString(entry_type), entry_name); + } else { + return nullptr; + } + } + return result; +} + +optional_ptr DuckSchemaEntry::CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) { + auto table = make_uniq(catalog, *this, info); + auto &storage = table->GetStorage(); + storage.info->cardinality = storage.GetTotalRows(); + + auto entry = AddEntryInternal(transaction, std::move(table), info.Base().on_conflict, info.dependencies); + if (!entry) { + return nullptr; + } + + // add a foreign key constraint in main key table if there is a foreign key constraint + vector> fk_arrays; + FindForeignKeyInformation(*entry, AlterForeignKeyType::AFT_ADD, fk_arrays); + for (idx_t i = 0; i < fk_arrays.size(); i++) { + // alter primary key table + auto &fk_info = *fk_arrays[i]; + catalog.Alter(transaction.GetContext(), fk_info); + + // make a dependency between this table and referenced table + auto &set = GetCatalogSet(CatalogType::TABLE_ENTRY); + info.dependencies.AddDependency(*set.GetEntry(transaction, fk_info.name)); + } + return entry; +} + +optional_ptr DuckSchemaEntry::CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) { + if (info.on_conflict == OnCreateConflict::ALTER_ON_CONFLICT) { + // check if the original entry exists + auto &catalog_set = GetCatalogSet(info.type); + auto current_entry = catalog_set.GetEntry(transaction, info.name); + if (current_entry) { + // the current entry exists - alter it instead + auto alter_info = info.GetAlterInfo(); + Alter(transaction.GetContext(), *alter_info); + return nullptr; + } + } + unique_ptr function; + switch (info.type) { + case CatalogType::SCALAR_FUNCTION_ENTRY: + function = make_uniq_base(catalog, *this, + info.Cast()); + break; + case CatalogType::TABLE_FUNCTION_ENTRY: + function = make_uniq_base(catalog, *this, + info.Cast()); + break; + case CatalogType::MACRO_ENTRY: + // create a macro function + function = make_uniq_base(catalog, *this, info.Cast()); + break; + + case CatalogType::TABLE_MACRO_ENTRY: + // create a macro table function + function = make_uniq_base(catalog, *this, info.Cast()); + break; + case CatalogType::AGGREGATE_FUNCTION_ENTRY: + D_ASSERT(info.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); + // create an aggregate function + function = make_uniq_base( + catalog, *this, info.Cast()); + break; + default: + throw InternalException("Unknown function type \"%s\"", CatalogTypeToString(info.type)); + } + function->internal = info.internal; + return AddEntry(transaction, std::move(function), info.on_conflict); +} + +optional_ptr DuckSchemaEntry::AddEntry(CatalogTransaction transaction, unique_ptr entry, + OnCreateConflict on_conflict) { + DependencyList dependencies; + return AddEntryInternal(transaction, std::move(entry), on_conflict, dependencies); +} + +optional_ptr DuckSchemaEntry::CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) { + auto sequence = make_uniq(catalog, *this, info); + return AddEntry(transaction, std::move(sequence), info.on_conflict); +} + +optional_ptr DuckSchemaEntry::CreateType(CatalogTransaction transaction, CreateTypeInfo &info) { + auto type_entry = make_uniq(catalog, *this, info); + return AddEntry(transaction, std::move(type_entry), info.on_conflict); +} + +optional_ptr DuckSchemaEntry::CreateView(CatalogTransaction transaction, CreateViewInfo &info) { + auto view = make_uniq(catalog, *this, info); + return AddEntry(transaction, std::move(view), info.on_conflict); +} + +optional_ptr DuckSchemaEntry::CreateIndex(ClientContext &context, CreateIndexInfo &info, + TableCatalogEntry &table) { + DependencyList dependencies; + dependencies.AddDependency(table); + auto index = make_uniq(catalog, *this, info); + return AddEntryInternal(GetCatalogTransaction(context), std::move(index), info.on_conflict, dependencies); +} + +optional_ptr DuckSchemaEntry::CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) { + auto collation = make_uniq(catalog, *this, info); + collation->internal = info.internal; + return AddEntry(transaction, std::move(collation), info.on_conflict); +} + +optional_ptr DuckSchemaEntry::CreateTableFunction(CatalogTransaction transaction, + CreateTableFunctionInfo &info) { + auto table_function = make_uniq(catalog, *this, info); + table_function->internal = info.internal; + return AddEntry(transaction, std::move(table_function), info.on_conflict); +} + +optional_ptr DuckSchemaEntry::CreateCopyFunction(CatalogTransaction transaction, + CreateCopyFunctionInfo &info) { + auto copy_function = make_uniq(catalog, *this, info); + copy_function->internal = info.internal; + return AddEntry(transaction, std::move(copy_function), info.on_conflict); +} + +optional_ptr DuckSchemaEntry::CreatePragmaFunction(CatalogTransaction transaction, + CreatePragmaFunctionInfo &info) { + auto pragma_function = make_uniq(catalog, *this, info); + pragma_function->internal = info.internal; + return AddEntry(transaction, std::move(pragma_function), info.on_conflict); +} + +void DuckSchemaEntry::Alter(ClientContext &context, AlterInfo &info) { + CatalogType type = info.GetCatalogType(); + auto &set = GetCatalogSet(type); + auto transaction = GetCatalogTransaction(context); + if (info.type == AlterType::CHANGE_OWNERSHIP) { + if (!set.AlterOwnership(transaction, info.Cast())) { + throw CatalogException("Couldn't change ownership!"); + } + } else { + string name = info.name; + if (!set.AlterEntry(transaction, name, info)) { + throw CatalogException("Entry with name \"%s\" does not exist!", name); + } + } +} + +void DuckSchemaEntry::Scan(ClientContext &context, CatalogType type, + const std::function &callback) { + auto &set = GetCatalogSet(type); + set.Scan(GetCatalogTransaction(context), callback); +} + +void DuckSchemaEntry::Scan(CatalogType type, const std::function &callback) { + auto &set = GetCatalogSet(type); + set.Scan(callback); +} + +void DuckSchemaEntry::DropEntry(ClientContext &context, DropInfo &info) { + auto &set = GetCatalogSet(info.type); + + // first find the entry + auto transaction = GetCatalogTransaction(context); + auto existing_entry = set.GetEntry(transaction, info.name); + if (!existing_entry) { + throw InternalException("Failed to drop entry \"%s\" - entry could not be found", info.name); + } + if (existing_entry->type != info.type) { + throw CatalogException("Existing object %s is of type %s, trying to replace with type %s", info.name, + CatalogTypeToString(existing_entry->type), CatalogTypeToString(info.type)); + } + + // if there is a foreign key constraint, get that information + vector> fk_arrays; + FindForeignKeyInformation(*existing_entry, AlterForeignKeyType::AFT_DELETE, fk_arrays); + + if (!set.DropEntry(transaction, info.name, info.cascade, info.allow_drop_internal)) { + throw InternalException("Could not drop element because of an internal error"); + } + + // remove the foreign key constraint in main key table if main key table's name is valid + for (idx_t i = 0; i < fk_arrays.size(); i++) { + // alter primary key table + catalog.Alter(context, *fk_arrays[i]); + } +} + +optional_ptr DuckSchemaEntry::GetEntry(CatalogTransaction transaction, CatalogType type, + const string &name) { + return GetCatalogSet(type).GetEntry(transaction, name); +} + +SimilarCatalogEntry DuckSchemaEntry::GetSimilarEntry(CatalogTransaction transaction, CatalogType type, + const string &name) { + return GetCatalogSet(type).SimilarEntry(transaction, name); +} + +CatalogSet &DuckSchemaEntry::GetCatalogSet(CatalogType type) { + switch (type) { + case CatalogType::VIEW_ENTRY: + case CatalogType::TABLE_ENTRY: + return tables; + case CatalogType::INDEX_ENTRY: + return indexes; + case CatalogType::TABLE_FUNCTION_ENTRY: + case CatalogType::TABLE_MACRO_ENTRY: + return table_functions; + case CatalogType::COPY_FUNCTION_ENTRY: + return copy_functions; + case CatalogType::PRAGMA_FUNCTION_ENTRY: + return pragma_functions; + case CatalogType::AGGREGATE_FUNCTION_ENTRY: + case CatalogType::SCALAR_FUNCTION_ENTRY: + case CatalogType::MACRO_ENTRY: + return functions; + case CatalogType::SEQUENCE_ENTRY: + return sequences; + case CatalogType::COLLATION_ENTRY: + return collations; + case CatalogType::TYPE_ENTRY: + return types; + default: + throw InternalException("Unsupported catalog type in schema"); + } +} + +void DuckSchemaEntry::Verify(Catalog &catalog) { + InCatalogEntry::Verify(catalog); + + tables.Verify(catalog); + indexes.Verify(catalog); + table_functions.Verify(catalog); + copy_functions.Verify(catalog); + pragma_functions.Verify(catalog); + functions.Verify(catalog); + sequences.Verify(catalog); + collations.Verify(catalog); + types.Verify(catalog); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp new file mode 100644 index 00000000..0a473398 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp @@ -0,0 +1,736 @@ +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/constraints/bound_check_constraint.hpp" +#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" +#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" +#include "duckdb/planner/constraints/bound_unique_constraint.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression_binder/alter_binder.hpp" +#include "duckdb/planner/filter/null_filter.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/common/index_map.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/constraints/list.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_update.hpp" + +namespace duckdb { + +void AddDataTableIndex(DataTable &storage, const ColumnList &columns, const vector &keys, + IndexConstraintType constraint_type, BlockPointer index_block = BlockPointer()) { + // fetch types and create expressions for the index from the columns + vector column_ids; + vector> unbound_expressions; + vector> bound_expressions; + idx_t key_nr = 0; + column_ids.reserve(keys.size()); + for (auto &physical_key : keys) { + auto &column = columns.GetColumn(physical_key); + D_ASSERT(!column.Generated()); + unbound_expressions.push_back( + make_uniq(column.Name(), column.Type(), ColumnBinding(0, column_ids.size()))); + + bound_expressions.push_back(make_uniq(column.Type(), key_nr++)); + column_ids.push_back(column.StorageOid()); + } + unique_ptr art; + // create an adaptive radix tree around the expressions + if (index_block.IsValid()) { + art = make_uniq(column_ids, TableIOManager::Get(storage), std::move(unbound_expressions), constraint_type, + storage.db, nullptr, index_block); + } else { + art = make_uniq(column_ids, TableIOManager::Get(storage), std::move(unbound_expressions), constraint_type, + storage.db); + if (!storage.IsRoot()) { + throw TransactionException("Transaction conflict: cannot add an index to a table that has been altered!"); + } + } + storage.info->indexes.AddIndex(std::move(art)); +} + +void AddDataTableIndex(DataTable &storage, const ColumnList &columns, vector &keys, + IndexConstraintType constraint_type, BlockPointer index_block = BlockPointer()) { + vector new_keys; + new_keys.reserve(keys.size()); + for (auto &logical_key : keys) { + new_keys.push_back(columns.LogicalToPhysical(logical_key)); + } + AddDataTableIndex(storage, columns, new_keys, constraint_type, index_block); +} + +DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, + std::shared_ptr inherited_storage) + : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), + bound_constraints(std::move(info.bound_constraints)), + column_dependency_manager(std::move(info.column_dependency_manager)) { + if (!storage) { + // create the physical storage + vector storage_columns; + for (auto &col_def : columns.Physical()) { + storage_columns.push_back(col_def.Copy()); + } + storage = make_shared(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), + schema.name, name, std::move(storage_columns), std::move(info.data)); + + // create the unique indexes for the UNIQUE and PRIMARY KEY and FOREIGN KEY constraints + idx_t indexes_idx = 0; + for (idx_t i = 0; i < bound_constraints.size(); i++) { + auto &constraint = bound_constraints[i]; + if (constraint->type == ConstraintType::UNIQUE) { + // unique constraint: create a unique index + auto &unique = constraint->Cast(); + IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; + if (unique.is_primary_key) { + constraint_type = IndexConstraintType::PRIMARY; + } + if (info.indexes.empty()) { + AddDataTableIndex(*storage, columns, unique.keys, constraint_type); + } else { + AddDataTableIndex(*storage, columns, unique.keys, constraint_type, info.indexes[indexes_idx++]); + } + } else if (constraint->type == ConstraintType::FOREIGN_KEY) { + // foreign key constraint: create a foreign key index + auto &bfk = constraint->Cast(); + if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || + bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + if (info.indexes.empty()) { + AddDataTableIndex(*storage, columns, bfk.info.fk_keys, IndexConstraintType::FOREIGN); + } else { + AddDataTableIndex(*storage, columns, bfk.info.fk_keys, IndexConstraintType::FOREIGN, + info.indexes[indexes_idx++]); + } + } + } + } + } +} + +unique_ptr DuckTableEntry::GetStatistics(ClientContext &context, column_t column_id) { + if (column_id == COLUMN_IDENTIFIER_ROW_ID) { + return nullptr; + } + auto &column = columns.GetColumn(LogicalIndex(column_id)); + if (column.Generated()) { + return nullptr; + } + return storage->GetStatistics(context, column.StorageOid()); +} + +unique_ptr DuckTableEntry::AlterEntry(ClientContext &context, AlterInfo &info) { + D_ASSERT(!internal); + if (info.type != AlterType::ALTER_TABLE) { + throw CatalogException("Can only modify table with ALTER TABLE statement"); + } + auto &table_info = info.Cast(); + switch (table_info.alter_table_type) { + case AlterTableType::RENAME_COLUMN: { + auto &rename_info = table_info.Cast(); + return RenameColumn(context, rename_info); + } + case AlterTableType::RENAME_TABLE: { + auto &rename_info = table_info.Cast(); + auto copied_table = Copy(context); + copied_table->name = rename_info.new_table_name; + storage->info->table = rename_info.new_table_name; + return copied_table; + } + case AlterTableType::ADD_COLUMN: { + auto &add_info = table_info.Cast(); + return AddColumn(context, add_info); + } + case AlterTableType::REMOVE_COLUMN: { + auto &remove_info = table_info.Cast(); + return RemoveColumn(context, remove_info); + } + case AlterTableType::SET_DEFAULT: { + auto &set_default_info = table_info.Cast(); + return SetDefault(context, set_default_info); + } + case AlterTableType::ALTER_COLUMN_TYPE: { + auto &change_type_info = table_info.Cast(); + return ChangeColumnType(context, change_type_info); + } + case AlterTableType::FOREIGN_KEY_CONSTRAINT: { + auto &foreign_key_constraint_info = table_info.Cast(); + if (foreign_key_constraint_info.type == AlterForeignKeyType::AFT_ADD) { + return AddForeignKeyConstraint(context, foreign_key_constraint_info); + } else { + return DropForeignKeyConstraint(context, foreign_key_constraint_info); + } + } + case AlterTableType::SET_NOT_NULL: { + auto &set_not_null_info = table_info.Cast(); + return SetNotNull(context, set_not_null_info); + } + case AlterTableType::DROP_NOT_NULL: { + auto &drop_not_null_info = table_info.Cast(); + return DropNotNull(context, drop_not_null_info); + } + default: + throw InternalException("Unrecognized alter table type!"); + } +} + +void DuckTableEntry::UndoAlter(ClientContext &context, AlterInfo &info) { + D_ASSERT(!internal); + D_ASSERT(info.type == AlterType::ALTER_TABLE); + auto &table_info = info.Cast(); + switch (table_info.alter_table_type) { + case AlterTableType::RENAME_TABLE: { + storage->info->table = this->name; + break; + default: + break; + } + } +} + +static void RenameExpression(ParsedExpression &expr, RenameColumnInfo &info) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto &colref = expr.Cast(); + if (colref.column_names.back() == info.old_name) { + colref.column_names.back() = info.new_name; + } + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](const ParsedExpression &child) { RenameExpression((ParsedExpression &)child, info); }); +} + +unique_ptr DuckTableEntry::RenameColumn(ClientContext &context, RenameColumnInfo &info) { + auto rename_idx = GetColumnIndex(info.old_name); + if (rename_idx.index == COLUMN_IDENTIFIER_ROW_ID) { + throw CatalogException("Cannot rename rowid column"); + } + auto create_info = make_uniq(schema, name); + create_info->temporary = temporary; + for (auto &col : columns.Logical()) { + auto copy = col.Copy(); + if (rename_idx == col.Logical()) { + copy.SetName(info.new_name); + } + if (col.Generated() && column_dependency_manager.IsDependencyOf(col.Logical(), rename_idx)) { + RenameExpression(copy.GeneratedExpressionMutable(), info); + } + create_info->columns.AddColumn(std::move(copy)); + } + for (idx_t c_idx = 0; c_idx < constraints.size(); c_idx++) { + auto copy = constraints[c_idx]->Copy(); + switch (copy->type) { + case ConstraintType::NOT_NULL: + // NOT NULL constraint: no adjustments necessary + break; + case ConstraintType::CHECK: { + // CHECK constraint: need to rename column references that refer to the renamed column + auto &check = copy->Cast(); + RenameExpression(*check.expression, info); + break; + } + case ConstraintType::UNIQUE: { + // UNIQUE constraint: possibly need to rename columns + auto &unique = copy->Cast(); + for (idx_t i = 0; i < unique.columns.size(); i++) { + if (unique.columns[i] == info.old_name) { + unique.columns[i] = info.new_name; + } + } + break; + } + case ConstraintType::FOREIGN_KEY: { + // FOREIGN KEY constraint: possibly need to rename columns + auto &fk = copy->Cast(); + vector columns = fk.pk_columns; + if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { + columns = fk.fk_columns; + } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + for (idx_t i = 0; i < fk.fk_columns.size(); i++) { + columns.push_back(fk.fk_columns[i]); + } + } + for (idx_t i = 0; i < columns.size(); i++) { + if (columns[i] == info.old_name) { + throw CatalogException( + "Cannot rename column \"%s\" because this is involved in the foreign key constraint", + info.old_name); + } + } + break; + } + default: + throw InternalException("Unsupported constraint for entry!"); + } + create_info->constraints.push_back(std::move(copy)); + } + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + return make_uniq(catalog, schema, *bound_create_info, storage); +} + +unique_ptr DuckTableEntry::AddColumn(ClientContext &context, AddColumnInfo &info) { + auto col_name = info.new_column.GetName(); + + // We're checking for the opposite condition (ADD COLUMN IF _NOT_ EXISTS ...). + if (info.if_column_not_exists && ColumnExists(col_name)) { + return nullptr; + } + + auto create_info = make_uniq(schema, name); + create_info->temporary = temporary; + + for (auto &col : columns.Logical()) { + create_info->columns.AddColumn(col.Copy()); + } + for (auto &constraint : constraints) { + create_info->constraints.push_back(constraint->Copy()); + } + Binder::BindLogicalType(context, info.new_column.TypeMutable(), &catalog, schema.name); + info.new_column.SetOid(columns.LogicalColumnCount()); + info.new_column.SetStorageOid(columns.PhysicalColumnCount()); + auto col = info.new_column.Copy(); + + create_info->columns.AddColumn(std::move(col)); + + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + auto new_storage = + make_shared(context, *storage, info.new_column, *bound_create_info->bound_defaults.back()); + return make_uniq(catalog, schema, *bound_create_info, new_storage); +} + +void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, + const vector &adjusted_indices, + const RemoveColumnInfo &info, CreateTableInfo &create_info, + bool is_generated) { + // handle constraints for the new table + D_ASSERT(constraints.size() == bound_constraints.size()); + + for (idx_t constr_idx = 0; constr_idx < constraints.size(); constr_idx++) { + auto &constraint = constraints[constr_idx]; + auto &bound_constraint = bound_constraints[constr_idx]; + switch (constraint->type) { + case ConstraintType::NOT_NULL: { + auto ¬_null_constraint = bound_constraint->Cast(); + auto not_null_index = columns.PhysicalToLogical(not_null_constraint.index); + if (not_null_index != removed_index) { + // the constraint is not about this column: we need to copy it + // we might need to shift the index back by one though, to account for the removed column + auto new_index = adjusted_indices[not_null_index.index]; + create_info.constraints.push_back(make_uniq(new_index)); + } + break; + } + case ConstraintType::CHECK: { + // Generated columns can not be part of an index + // CHECK constraint + auto &bound_check = bound_constraint->Cast(); + // check if the removed column is part of the check constraint + if (is_generated) { + // generated columns can not be referenced by constraints, we can just add the constraint back + create_info.constraints.push_back(constraint->Copy()); + break; + } + auto physical_index = columns.LogicalToPhysical(removed_index); + if (bound_check.bound_columns.find(physical_index) != bound_check.bound_columns.end()) { + if (bound_check.bound_columns.size() > 1) { + // CHECK constraint that concerns mult + throw CatalogException( + "Cannot drop column \"%s\" because there is a CHECK constraint that depends on it", + info.removed_column); + } else { + // CHECK constraint that ONLY concerns this column, strip the constraint + } + } else { + // check constraint does not concern the removed column: simply re-add it + create_info.constraints.push_back(constraint->Copy()); + } + break; + } + case ConstraintType::UNIQUE: { + auto copy = constraint->Copy(); + auto &unique = copy->Cast(); + if (unique.index.index != DConstants::INVALID_INDEX) { + if (unique.index == removed_index) { + throw CatalogException( + "Cannot drop column \"%s\" because there is a UNIQUE constraint that depends on it", + info.removed_column); + } + unique.index = adjusted_indices[unique.index.index]; + } + create_info.constraints.push_back(std::move(copy)); + break; + } + case ConstraintType::FOREIGN_KEY: { + auto copy = constraint->Copy(); + auto &fk = copy->Cast(); + vector columns = fk.pk_columns; + if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { + columns = fk.fk_columns; + } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + for (idx_t i = 0; i < fk.fk_columns.size(); i++) { + columns.push_back(fk.fk_columns[i]); + } + } + for (idx_t i = 0; i < columns.size(); i++) { + if (columns[i] == info.removed_column) { + throw CatalogException( + "Cannot drop column \"%s\" because there is a FOREIGN KEY constraint that depends on it", + info.removed_column); + } + } + create_info.constraints.push_back(std::move(copy)); + break; + } + default: + throw InternalException("Unsupported constraint for entry!"); + } + } +} + +unique_ptr DuckTableEntry::RemoveColumn(ClientContext &context, RemoveColumnInfo &info) { + auto removed_index = GetColumnIndex(info.removed_column, info.if_column_exists); + if (!removed_index.IsValid()) { + if (!info.if_column_exists) { + throw CatalogException("Cannot drop column: rowid column cannot be dropped"); + } + return nullptr; + } + + auto create_info = make_uniq(schema, name); + create_info->temporary = temporary; + + logical_index_set_t removed_columns; + if (column_dependency_manager.HasDependents(removed_index)) { + removed_columns = column_dependency_manager.GetDependents(removed_index); + } + if (!removed_columns.empty() && !info.cascade) { + throw CatalogException("Cannot drop column: column is a dependency of 1 or more generated column(s)"); + } + bool dropped_column_is_generated = false; + for (auto &col : columns.Logical()) { + if (col.Logical() == removed_index || removed_columns.count(col.Logical())) { + if (col.Generated()) { + dropped_column_is_generated = true; + } + continue; + } + create_info->columns.AddColumn(col.Copy()); + } + if (create_info->columns.empty()) { + throw CatalogException("Cannot drop column: table only has one column remaining!"); + } + auto adjusted_indices = column_dependency_manager.RemoveColumn(removed_index, columns.LogicalColumnCount()); + + UpdateConstraintsOnColumnDrop(removed_index, adjusted_indices, info, *create_info, dropped_column_is_generated); + + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + if (columns.GetColumn(LogicalIndex(removed_index)).Generated()) { + return make_uniq(catalog, schema, *bound_create_info, storage); + } + auto new_storage = + make_shared(context, *storage, columns.LogicalToPhysical(LogicalIndex(removed_index)).index); + return make_uniq(catalog, schema, *bound_create_info, new_storage); +} + +unique_ptr DuckTableEntry::SetDefault(ClientContext &context, SetDefaultInfo &info) { + auto create_info = make_uniq(schema, name); + auto default_idx = GetColumnIndex(info.column_name); + if (default_idx.index == COLUMN_IDENTIFIER_ROW_ID) { + throw CatalogException("Cannot SET DEFAULT for rowid column"); + } + + // Copy all the columns, changing the value of the one that was specified by 'column_name' + for (auto &col : columns.Logical()) { + auto copy = col.Copy(); + if (default_idx == col.Logical()) { + // set the default value of this column + if (copy.Generated()) { + throw BinderException("Cannot SET DEFAULT for generated column \"%s\"", col.Name()); + } + copy.SetDefaultValue(info.expression ? info.expression->Copy() : nullptr); + } + create_info->columns.AddColumn(std::move(copy)); + } + // Copy all the constraints + for (idx_t i = 0; i < constraints.size(); i++) { + auto constraint = constraints[i]->Copy(); + create_info->constraints.push_back(std::move(constraint)); + } + + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + return make_uniq(catalog, schema, *bound_create_info, storage); +} + +unique_ptr DuckTableEntry::SetNotNull(ClientContext &context, SetNotNullInfo &info) { + + auto create_info = make_uniq(schema, name); + create_info->columns = columns.Copy(); + + auto not_null_idx = GetColumnIndex(info.column_name); + if (columns.GetColumn(LogicalIndex(not_null_idx)).Generated()) { + throw BinderException("Unsupported constraint for generated column!"); + } + bool has_not_null = false; + for (idx_t i = 0; i < constraints.size(); i++) { + auto constraint = constraints[i]->Copy(); + if (constraint->type == ConstraintType::NOT_NULL) { + auto ¬_null = constraint->Cast(); + if (not_null.index == not_null_idx) { + has_not_null = true; + } + } + create_info->constraints.push_back(std::move(constraint)); + } + if (!has_not_null) { + create_info->constraints.push_back(make_uniq(not_null_idx)); + } + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + + // Early return + if (has_not_null) { + return make_uniq(catalog, schema, *bound_create_info, storage); + } + + // Return with new storage info. Note that we need the bound column index here. + auto new_storage = make_shared( + context, *storage, make_uniq(columns.LogicalToPhysical(LogicalIndex(not_null_idx)))); + return make_uniq(catalog, schema, *bound_create_info, new_storage); +} + +unique_ptr DuckTableEntry::DropNotNull(ClientContext &context, DropNotNullInfo &info) { + auto create_info = make_uniq(schema, name); + create_info->columns = columns.Copy(); + + auto not_null_idx = GetColumnIndex(info.column_name); + for (idx_t i = 0; i < constraints.size(); i++) { + auto constraint = constraints[i]->Copy(); + // Skip/drop not_null + if (constraint->type == ConstraintType::NOT_NULL) { + auto ¬_null = constraint->Cast(); + if (not_null.index == not_null_idx) { + continue; + } + } + create_info->constraints.push_back(std::move(constraint)); + } + + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + return make_uniq(catalog, schema, *bound_create_info, storage); +} + +unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context, ChangeColumnTypeInfo &info) { + Binder::BindLogicalType(context, info.target_type, &catalog, schema.name); + auto change_idx = GetColumnIndex(info.column_name); + auto create_info = make_uniq(schema, name); + create_info->temporary = temporary; + + for (auto &col : columns.Logical()) { + auto copy = col.Copy(); + if (change_idx == col.Logical()) { + // set the type of this column + if (copy.Generated()) { + throw NotImplementedException("Changing types of generated columns is not supported yet"); + } + copy.SetType(info.target_type); + } + // TODO: check if the generated_expression breaks, only delete it if it does + if (copy.Generated() && column_dependency_manager.IsDependencyOf(col.Logical(), change_idx)) { + throw BinderException( + "This column is referenced by the generated column \"%s\", so its type can not be changed", + copy.Name()); + } + create_info->columns.AddColumn(std::move(copy)); + } + + for (idx_t i = 0; i < constraints.size(); i++) { + auto constraint = constraints[i]->Copy(); + switch (constraint->type) { + case ConstraintType::CHECK: { + auto &bound_check = bound_constraints[i]->Cast(); + auto physical_index = columns.LogicalToPhysical(change_idx); + if (bound_check.bound_columns.find(physical_index) != bound_check.bound_columns.end()) { + throw BinderException("Cannot change the type of a column that has a CHECK constraint specified"); + } + break; + } + case ConstraintType::NOT_NULL: + break; + case ConstraintType::UNIQUE: { + auto &bound_unique = bound_constraints[i]->Cast(); + if (bound_unique.key_set.find(change_idx) != bound_unique.key_set.end()) { + throw BinderException( + "Cannot change the type of a column that has a UNIQUE or PRIMARY KEY constraint specified"); + } + break; + } + case ConstraintType::FOREIGN_KEY: { + auto &bfk = bound_constraints[i]->Cast(); + auto key_set = bfk.pk_key_set; + if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { + key_set = bfk.fk_key_set; + } else if (bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + for (idx_t i = 0; i < bfk.info.fk_keys.size(); i++) { + key_set.insert(bfk.info.fk_keys[i]); + } + } + if (key_set.find(columns.LogicalToPhysical(change_idx)) != key_set.end()) { + throw BinderException("Cannot change the type of a column that has a FOREIGN KEY constraint specified"); + } + break; + } + default: + throw InternalException("Unsupported constraint for entry!"); + } + create_info->constraints.push_back(std::move(constraint)); + } + + auto binder = Binder::CreateBinder(context); + // bind the specified expression + vector bound_columns; + AlterBinder expr_binder(*binder, context, *this, bound_columns, info.target_type); + auto expression = info.expression->Copy(); + auto bound_expression = expr_binder.Bind(expression); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + vector storage_oids; + for (idx_t i = 0; i < bound_columns.size(); i++) { + storage_oids.push_back(columns.LogicalToPhysical(bound_columns[i]).index); + } + if (storage_oids.empty()) { + storage_oids.push_back(COLUMN_IDENTIFIER_ROW_ID); + } + + auto new_storage = + make_shared(context, *storage, columns.LogicalToPhysical(LogicalIndex(change_idx)).index, + info.target_type, std::move(storage_oids), *bound_expression); + auto result = make_uniq(catalog, schema, *bound_create_info, new_storage); + return std::move(result); +} + +unique_ptr DuckTableEntry::AddForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info) { + D_ASSERT(info.type == AlterForeignKeyType::AFT_ADD); + auto create_info = make_uniq(schema, name); + create_info->temporary = temporary; + + create_info->columns = columns.Copy(); + for (idx_t i = 0; i < constraints.size(); i++) { + create_info->constraints.push_back(constraints[i]->Copy()); + } + ForeignKeyInfo fk_info; + fk_info.type = ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE; + fk_info.schema = info.schema; + fk_info.table = info.fk_table; + fk_info.pk_keys = info.pk_keys; + fk_info.fk_keys = info.fk_keys; + create_info->constraints.push_back( + make_uniq(info.pk_columns, info.fk_columns, std::move(fk_info))); + + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + + return make_uniq(catalog, schema, *bound_create_info, storage); +} + +unique_ptr DuckTableEntry::DropForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info) { + D_ASSERT(info.type == AlterForeignKeyType::AFT_DELETE); + auto create_info = make_uniq(schema, name); + create_info->temporary = temporary; + + create_info->columns = columns.Copy(); + for (idx_t i = 0; i < constraints.size(); i++) { + auto constraint = constraints[i]->Copy(); + if (constraint->type == ConstraintType::FOREIGN_KEY) { + ForeignKeyConstraint &fk = constraint->Cast(); + if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && fk.info.table == info.fk_table) { + continue; + } + } + create_info->constraints.push_back(std::move(constraint)); + } + + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + + return make_uniq(catalog, schema, *bound_create_info, storage); +} + +unique_ptr DuckTableEntry::Copy(ClientContext &context) const { + auto create_info = make_uniq(schema, name); + create_info->columns = columns.Copy(); + + for (idx_t i = 0; i < constraints.size(); i++) { + auto constraint = constraints[i]->Copy(); + create_info->constraints.push_back(std::move(constraint)); + } + + auto binder = Binder::CreateBinder(context); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info)); + return make_uniq(catalog, schema, *bound_create_info, storage); +} + +void DuckTableEntry::SetAsRoot() { + storage->SetAsRoot(); + storage->info->table = name; +} + +void DuckTableEntry::CommitAlter(string &column_name) { + D_ASSERT(!column_name.empty()); + idx_t removed_index = DConstants::INVALID_INDEX; + for (auto &col : columns.Logical()) { + if (col.Name() == column_name) { + // No need to alter storage, removed column is generated column + if (col.Generated()) { + return; + } + removed_index = col.Oid(); + break; + } + } + D_ASSERT(removed_index != DConstants::INVALID_INDEX); + storage->CommitDropColumn(columns.LogicalToPhysical(LogicalIndex(removed_index)).index); +} + +void DuckTableEntry::CommitDrop() { + storage->CommitDropTable(); +} + +DataTable &DuckTableEntry::GetStorage() { + return *storage; +} + +const vector> &DuckTableEntry::GetBoundConstraints() { + return bound_constraints; +} + +TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr &bind_data) { + bind_data = make_uniq(*this); + return TableScanFunction::GetFunction(); +} + +vector DuckTableEntry::GetColumnSegmentInfo() { + return storage->GetColumnSegmentInfo(); +} + +TableStorageInfo DuckTableEntry::GetStorageInfo(ClientContext &context) { + TableStorageInfo result; + result.cardinality = storage->info->cardinality.load(); + storage->info->indexes.Scan([&](Index &index) { + IndexInfo info; + info.is_primary = index.IsPrimary(); + info.is_unique = index.IsUnique() || info.is_primary; + info.is_foreign = index.IsForeign(); + info.column_set = index.column_id_set; + result.index_info.push_back(std::move(info)); + return false; + }); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp new file mode 100644 index 00000000..84798bb6 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp @@ -0,0 +1,40 @@ +#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" +#include "duckdb/storage/index.hpp" + +namespace duckdb { + +IndexCatalogEntry::IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info) + : StandardEntry(CatalogType::INDEX_ENTRY, schema, catalog, info.index_name), index(nullptr), sql(info.sql) { + this->temporary = info.temporary; +} + +string IndexCatalogEntry::ToSQL() const { + if (sql.empty()) { + return sql; + } + if (sql[sql.size() - 1] != ';') { + return sql + ";"; + } + return sql; +} + +unique_ptr IndexCatalogEntry::GetInfo() const { + auto result = make_uniq(); + result->schema = GetSchemaName(); + result->table = GetTableName(); + result->index_name = name; + result->sql = sql; + result->index_type = index->type; + result->constraint_type = index->constraint_type; + for (auto &expr : expressions) { + result->expressions.push_back(expr->Copy()); + } + for (auto &expr : parsed_expressions) { + result->parsed_expressions.push_back(expr->Copy()); + } + result->column_ids = index->column_ids; + result->temporary = temporary; + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp new file mode 100644 index 00000000..63ea28e3 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp @@ -0,0 +1,33 @@ +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" +#include "duckdb/function/scalar_macro_function.hpp" + +namespace duckdb { + +MacroCatalogEntry::MacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) + : FunctionEntry( + (info.function->type == MacroType::SCALAR_MACRO ? CatalogType::MACRO_ENTRY : CatalogType::TABLE_MACRO_ENTRY), + catalog, schema, info), + function(std::move(info.function)) { + this->temporary = info.temporary; + this->internal = info.internal; +} + +ScalarMacroCatalogEntry::ScalarMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) + : MacroCatalogEntry(catalog, schema, info) { +} + +TableMacroCatalogEntry::TableMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) + : MacroCatalogEntry(catalog, schema, info) { +} + +unique_ptr MacroCatalogEntry::GetInfo() const { + auto info = make_uniq(type); + info->catalog = catalog.GetName(); + info->schema = schema.name; + info->name = name; + info->function = function->Copy(); + return std::move(info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp new file mode 100644 index 00000000..ff247dcb --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp @@ -0,0 +1,11 @@ +#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" + +namespace duckdb { + +PragmaFunctionCatalogEntry::PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, + CreatePragmaFunctionInfo &info) + : FunctionEntry(CatalogType::PRAGMA_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp new file mode 100644 index 00000000..865ac473 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp @@ -0,0 +1,30 @@ +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" + +namespace duckdb { + +ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, + CreateScalarFunctionInfo &info) + : FunctionEntry(CatalogType::SCALAR_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { +} + +unique_ptr ScalarFunctionCatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { + if (info.type != AlterType::ALTER_SCALAR_FUNCTION) { + throw InternalException("Attempting to alter ScalarFunctionCatalogEntry with unsupported alter type"); + } + auto &function_info = info.Cast(); + if (function_info.alter_scalar_function_type != AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS) { + throw InternalException( + "Attempting to alter ScalarFunctionCatalogEntry with unsupported alter scalar function type"); + } + auto &add_overloads = function_info.Cast(); + + ScalarFunctionSet new_set = functions; + if (!new_set.MergeFunctionSet(add_overloads.new_overloads)) { + throw BinderException("Failed to add new function overloads to function \"%s\": function already exists", name); + } + CreateScalarFunctionInfo new_info(std::move(new_set)); + return make_uniq(catalog, schema, new_info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp new file mode 100644 index 00000000..7b106252 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp @@ -0,0 +1,47 @@ +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/catalog/dependency_list.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" + +#include + +namespace duckdb { + +SchemaCatalogEntry::SchemaCatalogEntry(Catalog &catalog, string name_p, bool internal) + : InCatalogEntry(CatalogType::SCHEMA_ENTRY, catalog, std::move(name_p)) { + this->internal = internal; +} + +CatalogTransaction SchemaCatalogEntry::GetCatalogTransaction(ClientContext &context) { + return CatalogTransaction(catalog, context); +} + +SimilarCatalogEntry SchemaCatalogEntry::GetSimilarEntry(CatalogTransaction transaction, CatalogType type, + const string &name) { + SimilarCatalogEntry result; + Scan(transaction.GetContext(), type, [&](CatalogEntry &entry) { + auto ldist = StringUtil::SimilarityScore(entry.name, name); + if (ldist < result.distance) { + result.distance = ldist; + result.name = entry.name; + } + }); + return result; +} + +unique_ptr SchemaCatalogEntry::GetInfo() const { + auto result = make_uniq(); + result->schema = name; + return std::move(result); +} + +string SchemaCatalogEntry::ToSQL() const { + std::stringstream ss; + ss << "CREATE SCHEMA " << name << ";"; + return ss.str(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp new file mode 100644 index 00000000..928335bc --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp @@ -0,0 +1,44 @@ +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" + +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/catalog/dependency_manager.hpp" + +#include +#include + +namespace duckdb { + +SequenceCatalogEntry::SequenceCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateSequenceInfo &info) + : StandardEntry(CatalogType::SEQUENCE_ENTRY, schema, catalog, info.name), usage_count(info.usage_count), + counter(info.start_value), increment(info.increment), start_value(info.start_value), min_value(info.min_value), + max_value(info.max_value), cycle(info.cycle) { + this->temporary = info.temporary; +} + +unique_ptr SequenceCatalogEntry::GetInfo() const { + auto result = make_uniq(); + result->schema = schema.name; + result->name = name; + result->usage_count = usage_count; + result->increment = increment; + result->min_value = min_value; + result->max_value = max_value; + result->start_value = counter; + result->cycle = cycle; + return std::move(result); +} + +string SequenceCatalogEntry::ToSQL() const { + std::stringstream ss; + ss << "CREATE SEQUENCE "; + ss << name; + ss << " INCREMENT BY " << increment; + ss << " MINVALUE " << min_value; + ss << " MAXVALUE " << max_value; + ss << " START " << counter; + ss << " " << (cycle ? "CYCLE" : "NO CYCLE") << ";"; + return ss.str(); +} +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp new file mode 100644 index 00000000..9fd270ac --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -0,0 +1,307 @@ +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/constraints/list.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/planner/operator/logical_update.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/constraints/bound_check_constraint.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" + +#include + +namespace duckdb { + +TableCatalogEntry::TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info) + : StandardEntry(CatalogType::TABLE_ENTRY, schema, catalog, info.table), columns(std::move(info.columns)), + constraints(std::move(info.constraints)) { + this->temporary = info.temporary; +} + +bool TableCatalogEntry::HasGeneratedColumns() const { + return columns.LogicalColumnCount() != columns.PhysicalColumnCount(); +} + +LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exists) { + auto entry = columns.GetColumnIndex(column_name); + if (!entry.IsValid()) { + if (if_exists) { + return entry; + } + throw BinderException("Table \"%s\" does not have a column with name \"%s\"", name, column_name); + } + return entry; +} + +bool TableCatalogEntry::ColumnExists(const string &name) { + return columns.ColumnExists(name); +} + +const ColumnDefinition &TableCatalogEntry::GetColumn(const string &name) { + return columns.GetColumn(name); +} + +vector TableCatalogEntry::GetTypes() { + vector types; + for (auto &col : columns.Physical()) { + types.push_back(col.Type()); + } + return types; +} + +unique_ptr TableCatalogEntry::GetInfo() const { + auto result = make_uniq(); + result->catalog = catalog.GetName(); + result->schema = schema.name; + result->table = name; + result->columns = columns.Copy(); + result->constraints.reserve(constraints.size()); + std::for_each(constraints.begin(), constraints.end(), + [&result](const unique_ptr &c) { result->constraints.emplace_back(c->Copy()); }); + return std::move(result); +} + +string TableCatalogEntry::ColumnsToSQL(const ColumnList &columns, const vector> &constraints) { + std::stringstream ss; + + ss << "("; + + // find all columns that have NOT NULL specified, but are NOT primary key columns + logical_index_set_t not_null_columns; + logical_index_set_t unique_columns; + logical_index_set_t pk_columns; + unordered_set multi_key_pks; + vector extra_constraints; + for (auto &constraint : constraints) { + if (constraint->type == ConstraintType::NOT_NULL) { + auto ¬_null = constraint->Cast(); + not_null_columns.insert(not_null.index); + } else if (constraint->type == ConstraintType::UNIQUE) { + auto &pk = constraint->Cast(); + vector constraint_columns = pk.columns; + if (pk.index.index != DConstants::INVALID_INDEX) { + // no columns specified: single column constraint + if (pk.is_primary_key) { + pk_columns.insert(pk.index); + } else { + unique_columns.insert(pk.index); + } + } else { + // multi-column constraint, this constraint needs to go at the end after all columns + if (pk.is_primary_key) { + // multi key pk column: insert set of columns into multi_key_pks + for (auto &col : pk.columns) { + multi_key_pks.insert(col); + } + } + extra_constraints.push_back(constraint->ToString()); + } + } else if (constraint->type == ConstraintType::FOREIGN_KEY) { + auto &fk = constraint->Cast(); + if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || + fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + extra_constraints.push_back(constraint->ToString()); + } + } else { + extra_constraints.push_back(constraint->ToString()); + } + } + + for (auto &column : columns.Logical()) { + if (column.Oid() > 0) { + ss << ", "; + } + ss << KeywordHelper::WriteOptionallyQuoted(column.Name()) << " "; + ss << column.Type().ToString(); + bool not_null = not_null_columns.find(column.Logical()) != not_null_columns.end(); + bool is_single_key_pk = pk_columns.find(column.Logical()) != pk_columns.end(); + bool is_multi_key_pk = multi_key_pks.find(column.Name()) != multi_key_pks.end(); + bool is_unique = unique_columns.find(column.Logical()) != unique_columns.end(); + if (not_null && !is_single_key_pk && !is_multi_key_pk) { + // NOT NULL but not a primary key column + ss << " NOT NULL"; + } + if (is_single_key_pk) { + // single column pk: insert constraint here + ss << " PRIMARY KEY"; + } + if (is_unique) { + // single column unique: insert constraint here + ss << " UNIQUE"; + } + if (column.Generated()) { + ss << " GENERATED ALWAYS AS(" << column.GeneratedExpression().ToString() << ")"; + } else if (column.DefaultValue()) { + ss << " DEFAULT(" << column.DefaultValue()->ToString() << ")"; + } + } + // print any extra constraints that still need to be printed + for (auto &extra_constraint : extra_constraints) { + ss << ", "; + ss << extra_constraint; + } + + ss << ")"; + return ss.str(); +} + +string TableCatalogEntry::ToSQL() const { + std::stringstream ss; + + ss << "CREATE TABLE "; + + if (schema.name != DEFAULT_SCHEMA) { + ss << KeywordHelper::WriteOptionallyQuoted(schema.name) << "."; + } + + ss << KeywordHelper::WriteOptionallyQuoted(name); + ss << ColumnsToSQL(columns, constraints); + ss << ";"; + + return ss.str(); +} + +const ColumnList &TableCatalogEntry::GetColumns() const { + return columns; +} + +const ColumnDefinition &TableCatalogEntry::GetColumn(LogicalIndex idx) { + return columns.GetColumn(idx); +} + +const vector> &TableCatalogEntry::GetConstraints() { + return constraints; +} + +// LCOV_EXCL_START +DataTable &TableCatalogEntry::GetStorage() { + throw InternalException("Calling GetStorage on a TableCatalogEntry that is not a DuckTableEntry"); +} + +const vector> &TableCatalogEntry::GetBoundConstraints() { + throw InternalException("Calling GetBoundConstraints on a TableCatalogEntry that is not a DuckTableEntry"); +} + +// LCOV_EXCL_STOP + +static void BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, + physical_index_set_t &bound_columns) { + if (bound_columns.size() <= 1) { + return; + } + idx_t found_column_count = 0; + physical_index_set_t found_columns; + for (idx_t i = 0; i < update.columns.size(); i++) { + if (bound_columns.find(update.columns[i]) != bound_columns.end()) { + // this column is referenced in the CHECK constraint + found_column_count++; + found_columns.insert(update.columns[i]); + } + } + if (found_column_count > 0 && found_column_count != bound_columns.size()) { + // columns in this CHECK constraint were referenced, but not all were part of the UPDATE + // add them to the scan and update set + for (auto &check_column_id : bound_columns) { + if (found_columns.find(check_column_id) != found_columns.end()) { + // column is already projected + continue; + } + // column is not projected yet: project it by adding the clause "i=i" to the set of updated columns + auto &column = table.GetColumns().GetColumn(check_column_id); + update.expressions.push_back(make_uniq( + column.Type(), ColumnBinding(proj.table_index, proj.expressions.size()))); + proj.expressions.push_back(make_uniq( + column.Type(), ColumnBinding(get.table_index, get.column_ids.size()))); + get.column_ids.push_back(check_column_id.index); + update.columns.push_back(check_column_id); + } + } +} + +static bool TypeSupportsRegularUpdate(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::LIST: + case LogicalTypeId::MAP: + case LogicalTypeId::UNION: + // lists and maps and unions don't support updates directly + return false; + case LogicalTypeId::STRUCT: { + auto &child_types = StructType::GetChildTypes(type); + for (auto &entry : child_types) { + if (!TypeSupportsRegularUpdate(entry.second)) { + return false; + } + } + return true; + } + default: + return true; + } +} + +vector TableCatalogEntry::GetColumnSegmentInfo() { + return {}; +} + +void TableCatalogEntry::BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, + ClientContext &context) { + // check the constraints and indexes of the table to see if we need to project any additional columns + // we do this for indexes with multiple columns and CHECK constraints in the UPDATE clause + // suppose we have a constraint CHECK(i + j < 10); now we need both i and j to check the constraint + // if we are only updating one of the two columns we add the other one to the UPDATE set + // with a "useless" update (i.e. i=i) so we can verify that the CHECK constraint is not violated + for (auto &constraint : GetBoundConstraints()) { + if (constraint->type == ConstraintType::CHECK) { + auto &check = constraint->Cast(); + // check constraint! check if we need to add any extra columns to the UPDATE clause + BindExtraColumns(*this, get, proj, update, check.bound_columns); + } + } + if (update.return_chunk) { + physical_index_set_t all_columns; + for (auto &column : GetColumns().Physical()) { + all_columns.insert(column.Physical()); + } + BindExtraColumns(*this, get, proj, update, all_columns); + } + // for index updates we always turn any update into an insert and a delete + // we thus need all the columns to be available, hence we check if the update touches any index columns + // If the returning keyword is used, we need access to the whole row in case the user requests it. + // Therefore switch the update to a delete and insert. + update.update_is_del_and_insert = false; + TableStorageInfo table_storage_info = GetStorageInfo(context); + for (auto index : table_storage_info.index_info) { + for (auto &column : update.columns) { + if (index.column_set.find(column.index) != index.column_set.end()) { + update.update_is_del_and_insert = true; + break; + } + } + }; + + // we also convert any updates on LIST columns into delete + insert + for (auto &col_index : update.columns) { + auto &column = GetColumns().GetColumn(col_index); + if (!TypeSupportsRegularUpdate(column.Type())) { + update.update_is_del_and_insert = true; + break; + } + } + + if (update.update_is_del_and_insert) { + // the update updates a column required by an index or requires returning the updated rows, + // push projections for all columns + physical_index_set_t all_columns; + for (auto &column : GetColumns().Physical()) { + all_columns.insert(column.Physical()); + } + BindExtraColumns(*this, get, proj, update, all_columns); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp new file mode 100644 index 00000000..1b3b566b --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp @@ -0,0 +1,31 @@ +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/alter_table_function_info.hpp" + +namespace duckdb { + +TableFunctionCatalogEntry::TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, + CreateTableFunctionInfo &info) + : FunctionEntry(CatalogType::TABLE_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { + D_ASSERT(this->functions.Size() > 0); +} + +unique_ptr TableFunctionCatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { + if (info.type != AlterType::ALTER_TABLE_FUNCTION) { + throw InternalException("Attempting to alter TableFunctionCatalogEntry with unsupported alter type"); + } + auto &function_info = info.Cast(); + if (function_info.alter_table_function_type != AlterTableFunctionType::ADD_FUNCTION_OVERLOADS) { + throw InternalException( + "Attempting to alter TableFunctionCatalogEntry with unsupported alter table function type"); + } + auto &add_overloads = function_info.Cast(); + + TableFunctionSet new_set = functions; + if (!new_set.MergeFunctionSet(add_overloads.new_overloads)) { + throw BinderException("Failed to add new function overloads to function \"%s\": function already exists", name); + } + CreateTableFunctionInfo new_info(std::move(new_set)); + return make_uniq(catalog, schema, new_info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp new file mode 100644 index 00000000..055bdfc7 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp @@ -0,0 +1,53 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include +#include + +namespace duckdb { + +TypeCatalogEntry::TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info) + : StandardEntry(CatalogType::TYPE_ENTRY, schema, catalog, info.name), user_type(info.type) { + this->temporary = info.temporary; + this->internal = info.internal; +} + +unique_ptr TypeCatalogEntry::GetInfo() const { + auto result = make_uniq(); + result->catalog = catalog.GetName(); + result->schema = schema.name; + result->name = name; + result->type = user_type; + return std::move(result); +} + +string TypeCatalogEntry::ToSQL() const { + std::stringstream ss; + switch (user_type.id()) { + case (LogicalTypeId::ENUM): { + auto &values_insert_order = EnumType::GetValuesInsertOrder(user_type); + idx_t size = EnumType::GetSize(user_type); + ss << "CREATE TYPE "; + ss << KeywordHelper::WriteOptionallyQuoted(name); + ss << " AS ENUM ( "; + + for (idx_t i = 0; i < size; i++) { + ss << "'" << values_insert_order.GetValue(i).ToString() << "'"; + if (i != size - 1) { + ss << ", "; + } + } + ss << ");"; + break; + } + default: + throw InternalException("Logical Type can't be used as a User Defined Type"); + } + + return ss.str(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp new file mode 100644 index 00000000..1f41f740 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp @@ -0,0 +1,80 @@ +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" + +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/common/limits.hpp" + +#include + +namespace duckdb { + +void ViewCatalogEntry::Initialize(CreateViewInfo &info) { + query = std::move(info.query); + this->aliases = info.aliases; + this->types = info.types; + this->temporary = info.temporary; + this->sql = info.sql; + this->internal = info.internal; +} + +ViewCatalogEntry::ViewCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateViewInfo &info) + : StandardEntry(CatalogType::VIEW_ENTRY, schema, catalog, info.view_name) { + Initialize(info); +} + +unique_ptr ViewCatalogEntry::GetInfo() const { + auto result = make_uniq(); + result->schema = schema.name; + result->view_name = name; + result->sql = sql; + result->query = unique_ptr_cast(query->Copy()); + result->aliases = aliases; + result->types = types; + return std::move(result); +} + +unique_ptr ViewCatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { + D_ASSERT(!internal); + if (info.type != AlterType::ALTER_VIEW) { + throw CatalogException("Can only modify view with ALTER VIEW statement"); + } + auto &view_info = info.Cast(); + switch (view_info.alter_view_type) { + case AlterViewType::RENAME_VIEW: { + auto &rename_info = view_info.Cast(); + auto copied_view = Copy(context); + copied_view->name = rename_info.new_view_name; + return copied_view; + } + default: + throw InternalException("Unrecognized alter view type!"); + } +} + +string ViewCatalogEntry::ToSQL() const { + if (sql.empty()) { + //! Return empty sql with view name so pragma view_tables don't complain + return sql; + } + return sql + "\n;"; +} + +unique_ptr ViewCatalogEntry::Copy(ClientContext &context) const { + D_ASSERT(!internal); + CreateViewInfo create_info(schema, name); + create_info.query = unique_ptr_cast(query->Copy()); + for (idx_t i = 0; i < aliases.size(); i++) { + create_info.aliases.push_back(aliases[i]); + } + for (idx_t i = 0; i < types.size(); i++) { + create_info.types.push_back(types[i]); + } + create_info.temporary = temporary; + create_info.sql = sql; + + return make_uniq(catalog, schema, create_info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_search_path.cpp b/src/duckdb/src/catalog/catalog_search_path.cpp new file mode 100644 index 00000000..6be5f491 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_search_path.cpp @@ -0,0 +1,266 @@ +#include "duckdb/catalog/catalog_search_path.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database_manager.hpp" + +namespace duckdb { + +CatalogSearchEntry::CatalogSearchEntry(string catalog_p, string schema_p) + : catalog(std::move(catalog_p)), schema(std::move(schema_p)) { +} + +string CatalogSearchEntry::ToString() const { + if (catalog.empty()) { + return WriteOptionallyQuoted(schema); + } else { + return WriteOptionallyQuoted(catalog) + "." + WriteOptionallyQuoted(schema); + } +} + +string CatalogSearchEntry::WriteOptionallyQuoted(const string &input) { + for (idx_t i = 0; i < input.size(); i++) { + if (input[i] == '.' || input[i] == ',') { + return "\"" + input + "\""; + } + } + return input; +} + +string CatalogSearchEntry::ListToString(const vector &input) { + string result; + for (auto &entry : input) { + if (!result.empty()) { + result += ","; + } + result += entry.ToString(); + } + return result; +} + +CatalogSearchEntry CatalogSearchEntry::ParseInternal(const string &input, idx_t &idx) { + string catalog; + string schema; + string entry; + bool finished = false; +normal: + for (; idx < input.size(); idx++) { + if (input[idx] == '"') { + idx++; + goto quoted; + } else if (input[idx] == '.') { + goto separator; + } else if (input[idx] == ',') { + finished = true; + goto separator; + } + entry += input[idx]; + } + finished = true; + goto separator; +quoted: + //! look for another quote + for (; idx < input.size(); idx++) { + if (input[idx] == '"') { + //! unquote + idx++; + goto normal; + } + entry += input[idx]; + } + throw ParserException("Unterminated quote in qualified name!"); +separator: + if (entry.empty()) { + throw ParserException("Unexpected dot - empty CatalogSearchEntry"); + } + if (schema.empty()) { + // if we parse one entry it is the schema + schema = std::move(entry); + } else if (catalog.empty()) { + // if we parse two entries it is [catalog.schema] + catalog = std::move(schema); + schema = std::move(entry); + } else { + throw ParserException("Too many dots - expected [schema] or [catalog.schema] for CatalogSearchEntry"); + } + entry = ""; + idx++; + if (finished) { + goto final; + } + goto normal; +final: + if (schema.empty()) { + throw ParserException("Unexpected end of entry - empty CatalogSearchEntry"); + } + return CatalogSearchEntry(std::move(catalog), std::move(schema)); +} + +CatalogSearchEntry CatalogSearchEntry::Parse(const string &input) { + idx_t pos = 0; + auto result = ParseInternal(input, pos); + if (pos < input.size()) { + throw ParserException("Failed to convert entry \"%s\" to CatalogSearchEntry - expected a single entry", input); + } + return result; +} + +vector CatalogSearchEntry::ParseList(const string &input) { + idx_t pos = 0; + vector result; + while (pos < input.size()) { + auto entry = ParseInternal(input, pos); + result.push_back(entry); + } + return result; +} + +CatalogSearchPath::CatalogSearchPath(ClientContext &context_p) : context(context_p) { + Reset(); +} + +void CatalogSearchPath::Reset() { + vector empty; + SetPaths(empty); +} + +string CatalogSearchPath::GetSetName(CatalogSetPathType set_type) { + switch (set_type) { + case CatalogSetPathType::SET_SCHEMA: + return "SET schema"; + case CatalogSetPathType::SET_SCHEMAS: + return "SET search_path"; + default: + throw InternalException("Unrecognized CatalogSetPathType"); + } +} + +void CatalogSearchPath::Set(vector new_paths, CatalogSetPathType set_type) { + if (set_type != CatalogSetPathType::SET_SCHEMAS && new_paths.size() != 1) { + throw CatalogException("%s can set only 1 schema. This has %d", GetSetName(set_type), new_paths.size()); + } + for (auto &path : new_paths) { + auto schema_entry = Catalog::GetSchema(context, path.catalog, path.schema, OnEntryNotFound::RETURN_NULL); + if (schema_entry) { + // we are setting a schema - update the catalog and schema + if (path.catalog.empty()) { + path.catalog = GetDefault().catalog; + } + continue; + } + // only schema supplied - check if this is a catalog instead + if (path.catalog.empty()) { + auto catalog = Catalog::GetCatalogEntry(context, path.schema); + if (catalog) { + auto schema = catalog->GetSchema(context, DEFAULT_SCHEMA, OnEntryNotFound::RETURN_NULL); + if (schema) { + path.catalog = std::move(path.schema); + path.schema = schema->name; + continue; + } + } + } + throw CatalogException("%s: No catalog + schema named \"%s\" found.", GetSetName(set_type), path.ToString()); + } + if (set_type == CatalogSetPathType::SET_SCHEMA) { + if (new_paths[0].catalog == TEMP_CATALOG || new_paths[0].catalog == SYSTEM_CATALOG) { + throw CatalogException("%s cannot be set to internal schema \"%s\"", GetSetName(set_type), + new_paths[0].catalog); + } + } + this->set_paths = std::move(new_paths); + SetPaths(set_paths); +} + +void CatalogSearchPath::Set(CatalogSearchEntry new_value, CatalogSetPathType set_type) { + vector new_paths {std::move(new_value)}; + Set(std::move(new_paths), set_type); +} + +const vector &CatalogSearchPath::Get() { + return paths; +} + +string CatalogSearchPath::GetDefaultSchema(const string &catalog) { + for (auto &path : paths) { + if (path.catalog == TEMP_CATALOG) { + continue; + } + if (StringUtil::CIEquals(path.catalog, catalog)) { + return path.schema; + } + } + return DEFAULT_SCHEMA; +} + +string CatalogSearchPath::GetDefaultCatalog(const string &schema) { + for (auto &path : paths) { + if (path.catalog == TEMP_CATALOG) { + continue; + } + if (StringUtil::CIEquals(path.schema, schema)) { + return path.catalog; + } + } + return INVALID_CATALOG; +} + +vector CatalogSearchPath::GetCatalogsForSchema(const string &schema) { + vector schemas; + for (auto &path : paths) { + if (StringUtil::CIEquals(path.schema, schema)) { + schemas.push_back(path.catalog); + } + } + return schemas; +} + +vector CatalogSearchPath::GetSchemasForCatalog(const string &catalog) { + vector schemas; + for (auto &path : paths) { + if (StringUtil::CIEquals(path.catalog, catalog)) { + schemas.push_back(path.schema); + } + } + return schemas; +} + +const CatalogSearchEntry &CatalogSearchPath::GetDefault() { + const auto &paths = Get(); + D_ASSERT(paths.size() >= 2); + return paths[1]; +} + +void CatalogSearchPath::SetPaths(vector new_paths) { + paths.clear(); + paths.reserve(new_paths.size() + 3); + paths.emplace_back(TEMP_CATALOG, DEFAULT_SCHEMA); + for (auto &path : new_paths) { + paths.push_back(std::move(path)); + } + paths.emplace_back(INVALID_CATALOG, DEFAULT_SCHEMA); + paths.emplace_back(SYSTEM_CATALOG, DEFAULT_SCHEMA); + paths.emplace_back(SYSTEM_CATALOG, "pg_catalog"); +} + +bool CatalogSearchPath::SchemaInSearchPath(ClientContext &context, const string &catalog_name, + const string &schema_name) { + for (auto &path : paths) { + if (!StringUtil::CIEquals(path.schema, schema_name)) { + continue; + } + if (StringUtil::CIEquals(path.catalog, catalog_name)) { + return true; + } + if (IsInvalidCatalog(path.catalog) && + StringUtil::CIEquals(catalog_name, DatabaseManager::GetDefaultDatabase(context))) { + return true; + } + } + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp new file mode 100644 index 00000000..8440786d --- /dev/null +++ b/src/duckdb/src/catalog/catalog_set.cpp @@ -0,0 +1,663 @@ +#include "duckdb/catalog/catalog_set.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/catalog/mapping_value.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/transaction/transaction_manager.hpp" + +namespace duckdb { + +//! Class responsible to keep track of state when removing entries from the catalog. +//! When deleting, many types of errors can be thrown, since we want to avoid try/catch blocks +//! this class makes sure that whatever elements were modified are returned to a correct state +//! when exceptions are thrown. +//! The idea here is to use RAII (Resource acquisition is initialization) to mimic a try/catch/finally block. +//! If any exception is raised when this object exists, then its destructor will be called +//! and the entry will return to its previous state during deconstruction. +class EntryDropper { +public: + //! Both constructor and destructor are privates because they should only be called by DropEntryDependencies + explicit EntryDropper(EntryIndex &entry_index_p) : entry_index(entry_index_p) { + old_deleted = entry_index.GetEntry()->deleted; + } + + ~EntryDropper() { + entry_index.GetEntry()->deleted = old_deleted; + } + +private: + //! Keeps track of the state of the entry before starting the delete + bool old_deleted; + //! Index of entry to be deleted + EntryIndex &entry_index; +}; + +CatalogSet::CatalogSet(Catalog &catalog_p, unique_ptr defaults) + : catalog(catalog_p.Cast()), defaults(std::move(defaults)) { + D_ASSERT(catalog_p.IsDuckCatalog()); +} +CatalogSet::~CatalogSet() { +} + +EntryIndex CatalogSet::PutEntry(idx_t entry_index, unique_ptr entry) { + if (entries.find(entry_index) != entries.end()) { + throw InternalException("Entry with entry index \"%llu\" already exists", entry_index); + } + entries.insert(make_pair(entry_index, EntryValue(std::move(entry)))); + return EntryIndex(*this, entry_index); +} + +void CatalogSet::PutEntry(EntryIndex index, unique_ptr catalog_entry) { + auto entry = entries.find(index.GetIndex()); + if (entry == entries.end()) { + throw InternalException("Entry with entry index \"%llu\" does not exist", index.GetIndex()); + } + catalog_entry->child = std::move(entry->second.entry); + catalog_entry->child->parent = catalog_entry.get(); + entry->second.entry = std::move(catalog_entry); +} + +bool CatalogSet::CreateEntry(CatalogTransaction transaction, const string &name, unique_ptr value, + DependencyList &dependencies) { + if (value->internal && !catalog.IsSystemCatalog() && name != DEFAULT_SCHEMA) { + throw InternalException("Attempting to create internal entry \"%s\" in non-system catalog - internal entries " + "can only be created in the system catalog", + name); + } + if (!value->internal) { + if (!value->temporary && catalog.IsSystemCatalog()) { + throw InternalException( + "Attempting to create non-internal entry \"%s\" in system catalog - the system catalog " + "can only contain internal entries", + name); + } + if (value->temporary && !catalog.IsTemporaryCatalog()) { + throw InternalException("Attempting to create temporary entry \"%s\" in non-temporary catalog", name); + } + if (!value->temporary && catalog.IsTemporaryCatalog() && name != DEFAULT_SCHEMA) { + throw InvalidInputException("Cannot create non-temporary entry \"%s\" in temporary catalog", name); + } + } + // lock the catalog for writing + lock_guard write_lock(catalog.GetWriteLock()); + // lock this catalog set to disallow reading + unique_lock read_lock(catalog_lock); + + // first check if the entry exists in the unordered set + idx_t index; + auto mapping_value = GetMapping(transaction, name); + if (mapping_value == nullptr || mapping_value->deleted) { + // if it does not: entry has never been created + + // check if there is a default entry + auto entry = CreateDefaultEntry(transaction, name, read_lock); + if (entry) { + return false; + } + + // first create a dummy deleted entry for this entry + // so transactions started before the commit of this transaction don't + // see it yet + auto dummy_node = make_uniq(CatalogType::INVALID, value->ParentCatalog(), name); + dummy_node->timestamp = 0; + dummy_node->deleted = true; + dummy_node->set = this; + + auto entry_index = PutEntry(current_entry++, std::move(dummy_node)); + index = entry_index.GetIndex(); + PutMapping(transaction, name, std::move(entry_index)); + } else { + index = mapping_value->index.GetIndex(); + auto ¤t = *mapping_value->index.GetEntry(); + // if it does, we have to check version numbers + if (HasConflict(transaction, current.timestamp)) { + // current version has been written to by a currently active + // transaction + throw TransactionException("Catalog write-write conflict on create with \"%s\"", current.name); + } + // there is a current version that has been committed + // if it has not been deleted there is a conflict + if (!current.deleted) { + return false; + } + } + // create a new entry and replace the currently stored one + // set the timestamp to the timestamp of the current transaction + // and point it at the dummy node + value->timestamp = transaction.transaction_id; + value->set = this; + + // now add the dependency set of this object to the dependency manager + catalog.GetDependencyManager().AddObject(transaction, *value, dependencies); + + auto value_ptr = value.get(); + EntryIndex entry_index(*this, index); + PutEntry(std::move(entry_index), std::move(value)); + // push the old entry in the undo buffer for this transaction + if (transaction.transaction) { + auto &dtransaction = transaction.transaction->Cast(); + dtransaction.PushCatalogEntry(*value_ptr->child); + } + return true; +} + +bool CatalogSet::CreateEntry(ClientContext &context, const string &name, unique_ptr value, + DependencyList &dependencies) { + return CreateEntry(catalog.GetCatalogTransaction(context), name, std::move(value), dependencies); +} + +optional_ptr CatalogSet::GetEntryInternal(CatalogTransaction transaction, EntryIndex &entry_index) { + auto &catalog_entry = *entry_index.GetEntry(); + // if it does: we have to retrieve the entry and to check version numbers + if (HasConflict(transaction, catalog_entry.timestamp)) { + // current version has been written to by a currently active + // transaction + throw TransactionException("Catalog write-write conflict on alter with \"%s\"", catalog_entry.name); + } + // there is a current version that has been committed by this transaction + if (catalog_entry.deleted) { + // if the entry was already deleted, it now does not exist anymore + // so we return that we could not find it + return nullptr; + } + return &catalog_entry; +} + +optional_ptr CatalogSet::GetEntryInternal(CatalogTransaction transaction, const string &name, + EntryIndex *entry_index) { + auto mapping_value = GetMapping(transaction, name); + if (mapping_value == nullptr || mapping_value->deleted) { + // the entry does not exist, check if we can create a default entry + return nullptr; + } + if (entry_index) { + *entry_index = mapping_value->index.Copy(); + } + return GetEntryInternal(transaction, mapping_value->index); +} + +bool CatalogSet::AlterOwnership(CatalogTransaction transaction, ChangeOwnershipInfo &info) { + auto entry = GetEntryInternal(transaction, info.name, nullptr); + if (!entry) { + return false; + } + + auto &owner_entry = catalog.GetEntry(transaction.GetContext(), info.owner_schema, info.owner_name); + catalog.GetDependencyManager().AddOwnership(transaction, owner_entry, *entry); + return true; +} + +bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, AlterInfo &alter_info) { + // lock the catalog for writing + lock_guard write_lock(catalog.GetWriteLock()); + + // first check if the entry exists in the unordered set + EntryIndex entry_index; + auto entry = GetEntryInternal(transaction, name, &entry_index); + if (!entry) { + return false; + } + if (!alter_info.allow_internal && entry->internal) { + throw CatalogException("Cannot alter entry \"%s\" because it is an internal system entry", entry->name); + } + + // lock this catalog set to disallow reading + lock_guard read_lock(catalog_lock); + + // create a new entry and replace the currently stored one + // set the timestamp to the timestamp of the current transaction + // and point it to the updated table node + string original_name = entry->name; + if (!transaction.context) { + throw InternalException("Cannot AlterEntry without client context"); + } + auto &context = *transaction.context; + auto value = entry->AlterEntry(context, alter_info); + if (!value) { + // alter failed, but did not result in an error + return true; + } + + if (value->name != original_name) { + auto mapping_value = GetMapping(transaction, value->name); + if (mapping_value && !mapping_value->deleted) { + auto &original_entry = GetEntryForTransaction(transaction, *mapping_value->index.GetEntry()); + if (!original_entry.deleted) { + entry->UndoAlter(context, alter_info); + string rename_err_msg = + "Could not rename \"%s\" to \"%s\": another entry with this name already exists!"; + throw CatalogException(rename_err_msg, original_name, value->name); + } + } + } + + if (value->name != original_name) { + // Do PutMapping and DeleteMapping after dependency check + PutMapping(transaction, value->name, entry_index.Copy()); + DeleteMapping(transaction, original_name); + } + + value->timestamp = transaction.transaction_id; + value->set = this; + auto new_entry = value.get(); + PutEntry(std::move(entry_index), std::move(value)); + + // serialize the AlterInfo into a temporary buffer + MemoryStream stream; + BinarySerializer serializer(stream); + serializer.Begin(); + serializer.WriteProperty(100, "column_name", alter_info.GetColumnName()); + serializer.WriteProperty(101, "alter_info", &alter_info); + serializer.End(); + + // push the old entry in the undo buffer for this transaction + if (transaction.transaction) { + auto &dtransaction = transaction.transaction->Cast(); + dtransaction.PushCatalogEntry(*new_entry->child, stream.GetData(), stream.GetPosition()); + } + + // Check the dependency manager to verify that there are no conflicting dependencies with this alter + // Note that we do this AFTER the new entry has been entirely set up in the catalog set + // that is because in case the alter fails because of a dependency conflict, we need to be able to cleanly roll back + // to the old entry. + catalog.GetDependencyManager().AlterObject(transaction, *entry, *new_entry); + + return true; +} + +void CatalogSet::DropEntryDependencies(CatalogTransaction transaction, EntryIndex &entry_index, CatalogEntry &entry, + bool cascade) { + // Stores the deleted value of the entry before starting the process + EntryDropper dropper(entry_index); + + // To correctly delete the object and its dependencies, it temporarily is set to deleted. + entry_index.GetEntry()->deleted = true; + + // check any dependencies of this object + D_ASSERT(entry.ParentCatalog().IsDuckCatalog()); + auto &duck_catalog = entry.ParentCatalog().Cast(); + duck_catalog.GetDependencyManager().DropObject(transaction, entry, cascade); + + // dropper destructor is called here + // the destructor makes sure to return the value to the previous state + // dropper.~EntryDropper() +} + +void CatalogSet::DropEntryInternal(CatalogTransaction transaction, EntryIndex entry_index, CatalogEntry &entry, + bool cascade) { + DropEntryDependencies(transaction, entry_index, entry, cascade); + + // create a new entry and replace the currently stored one + // set the timestamp to the timestamp of the current transaction + // and point it at the dummy node + auto value = make_uniq(CatalogType::DELETED_ENTRY, entry.ParentCatalog(), entry.name); + value->timestamp = transaction.transaction_id; + value->set = this; + value->deleted = true; + auto value_ptr = value.get(); + PutEntry(std::move(entry_index), std::move(value)); + + // push the old entry in the undo buffer for this transaction + if (transaction.transaction) { + auto &dtransaction = transaction.transaction->Cast(); + dtransaction.PushCatalogEntry(*value_ptr->child); + } +} + +bool CatalogSet::DropEntry(CatalogTransaction transaction, const string &name, bool cascade, bool allow_drop_internal) { + // lock the catalog for writing + lock_guard write_lock(catalog.GetWriteLock()); + // we can only delete an entry that exists + EntryIndex entry_index; + auto entry = GetEntryInternal(transaction, name, &entry_index); + if (!entry) { + return false; + } + if (entry->internal && !allow_drop_internal) { + throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name); + } + + lock_guard read_lock(catalog_lock); + DropEntryInternal(transaction, std::move(entry_index), *entry, cascade); + return true; +} + +bool CatalogSet::DropEntry(ClientContext &context, const string &name, bool cascade, bool allow_drop_internal) { + return DropEntry(catalog.GetCatalogTransaction(context), name, cascade, allow_drop_internal); +} + +DuckCatalog &CatalogSet::GetCatalog() { + return catalog; +} + +void CatalogSet::CleanupEntry(CatalogEntry &catalog_entry) { + // destroy the backed up entry: it is no longer required + D_ASSERT(catalog_entry.parent); + if (catalog_entry.parent->type != CatalogType::UPDATED_ENTRY) { + lock_guard write_lock(catalog.GetWriteLock()); + lock_guard lock(catalog_lock); + if (!catalog_entry.deleted) { + // delete the entry from the dependency manager, if it is not deleted yet + D_ASSERT(catalog_entry.ParentCatalog().IsDuckCatalog()); + catalog_entry.ParentCatalog().Cast().GetDependencyManager().EraseObject(catalog_entry); + } + auto parent = catalog_entry.parent; + parent->child = std::move(catalog_entry.child); + if (parent->deleted && !parent->child && !parent->parent) { + auto mapping_entry = mapping.find(parent->name); + D_ASSERT(mapping_entry != mapping.end()); + auto &entry = mapping_entry->second->index.GetEntry(); + D_ASSERT(entry); + if (entry.get() == parent.get()) { + mapping.erase(mapping_entry); + } + } + } +} + +bool CatalogSet::HasConflict(CatalogTransaction transaction, transaction_t timestamp) { + return (timestamp >= TRANSACTION_ID_START && timestamp != transaction.transaction_id) || + (timestamp < TRANSACTION_ID_START && timestamp > transaction.start_time); +} + +optional_ptr CatalogSet::GetMapping(CatalogTransaction transaction, const string &name, bool get_latest) { + optional_ptr mapping_value; + auto entry = mapping.find(name); + if (entry != mapping.end()) { + mapping_value = entry->second.get(); + } else { + + return nullptr; + } + if (get_latest) { + return mapping_value; + } + while (mapping_value->child) { + if (UseTimestamp(transaction, mapping_value->timestamp)) { + break; + } + mapping_value = mapping_value->child.get(); + D_ASSERT(mapping_value); + } + return mapping_value; +} + +void CatalogSet::PutMapping(CatalogTransaction transaction, const string &name, EntryIndex entry_index) { + auto entry = mapping.find(name); + auto new_value = make_uniq(std::move(entry_index)); + new_value->timestamp = transaction.transaction_id; + if (entry != mapping.end()) { + if (HasConflict(transaction, entry->second->timestamp)) { + throw TransactionException("Catalog write-write conflict on name \"%s\"", name); + } + new_value->child = std::move(entry->second); + new_value->child->parent = new_value.get(); + } + mapping[name] = std::move(new_value); +} + +void CatalogSet::DeleteMapping(CatalogTransaction transaction, const string &name) { + auto entry = mapping.find(name); + D_ASSERT(entry != mapping.end()); + auto delete_marker = make_uniq(entry->second->index.Copy()); + delete_marker->deleted = true; + delete_marker->timestamp = transaction.transaction_id; + delete_marker->child = std::move(entry->second); + delete_marker->child->parent = delete_marker.get(); + mapping[name] = std::move(delete_marker); +} + +bool CatalogSet::UseTimestamp(CatalogTransaction transaction, transaction_t timestamp) { + if (timestamp == transaction.transaction_id) { + // we created this version + return true; + } + if (timestamp < transaction.start_time) { + // this version was commited before we started the transaction + return true; + } + return false; +} + +CatalogEntry &CatalogSet::GetEntryForTransaction(CatalogTransaction transaction, CatalogEntry ¤t) { + reference entry(current); + while (entry.get().child) { + if (UseTimestamp(transaction, entry.get().timestamp)) { + break; + } + entry = *entry.get().child; + } + return entry.get(); +} + +CatalogEntry &CatalogSet::GetCommittedEntry(CatalogEntry ¤t) { + reference entry(current); + while (entry.get().child) { + if (entry.get().timestamp < TRANSACTION_ID_START) { + // this entry is committed: use it + break; + } + entry = *entry.get().child; + } + return entry.get(); +} + +SimilarCatalogEntry CatalogSet::SimilarEntry(CatalogTransaction transaction, const string &name) { + unique_lock lock(catalog_lock); + CreateDefaultEntries(transaction, lock); + + SimilarCatalogEntry result; + for (auto &kv : mapping) { + auto mapping_value = GetMapping(transaction, kv.first); + if (mapping_value && !mapping_value->deleted) { + auto ldist = StringUtil::SimilarityScore(kv.first, name); + if (ldist < result.distance) { + result.distance = ldist; + result.name = kv.first; + } + } + } + return result; +} + +optional_ptr CatalogSet::CreateEntryInternal(CatalogTransaction transaction, + unique_ptr entry) { + if (mapping.find(entry->name) != mapping.end()) { + return nullptr; + } + auto &name = entry->name; + auto catalog_entry = entry.get(); + + entry->set = this; + entry->timestamp = 0; + + auto entry_index = PutEntry(current_entry++, std::move(entry)); + PutMapping(transaction, name, std::move(entry_index)); + mapping[name]->timestamp = 0; + return catalog_entry; +} + +optional_ptr CatalogSet::CreateDefaultEntry(CatalogTransaction transaction, const string &name, + unique_lock &lock) { + // no entry found with this name, check for defaults + if (!defaults || defaults->created_all_entries) { + // no defaults either: return null + return nullptr; + } + // this catalog set has a default map defined + // check if there is a default entry that we can create with this name + if (!transaction.context) { + // no context - cannot create default entry + return nullptr; + } + lock.unlock(); + auto entry = defaults->CreateDefaultEntry(*transaction.context, name); + + lock.lock(); + if (!entry) { + // no default entry + return nullptr; + } + // there is a default entry! create it + auto result = CreateEntryInternal(transaction, std::move(entry)); + if (result) { + return result; + } + // we found a default entry, but failed + // this means somebody else created the entry first + // just retry? + lock.unlock(); + return GetEntry(transaction, name); +} + +optional_ptr CatalogSet::GetEntry(CatalogTransaction transaction, const string &name) { + unique_lock lock(catalog_lock); + auto mapping_value = GetMapping(transaction, name); + if (mapping_value != nullptr && !mapping_value->deleted) { + // we found an entry for this name + // check the version numbers + + auto &catalog_entry = *mapping_value->index.GetEntry(); + auto ¤t = GetEntryForTransaction(transaction, catalog_entry); + if (current.deleted || (current.name != name && !UseTimestamp(transaction, mapping_value->timestamp))) { + return nullptr; + } + return ¤t; + } + return CreateDefaultEntry(transaction, name, lock); +} + +optional_ptr CatalogSet::GetEntry(ClientContext &context, const string &name) { + return GetEntry(catalog.GetCatalogTransaction(context), name); +} + +void CatalogSet::UpdateTimestamp(CatalogEntry &entry, transaction_t timestamp) { + entry.timestamp = timestamp; + mapping[entry.name]->timestamp = timestamp; +} + +void CatalogSet::Undo(CatalogEntry &entry) { + lock_guard write_lock(catalog.GetWriteLock()); + lock_guard lock(catalog_lock); + + // entry has to be restored + // and entry->parent has to be removed ("rolled back") + + // i.e. we have to place (entry) as (entry->parent) again + auto &to_be_removed_node = *entry.parent; + + if (!to_be_removed_node.deleted) { + // delete the entry from the dependency manager as well + auto &dependency_manager = catalog.GetDependencyManager(); + dependency_manager.EraseObject(to_be_removed_node); + } + if (!StringUtil::CIEquals(entry.name, to_be_removed_node.name)) { + // rename: clean up the new name when the rename is rolled back + auto removed_entry = mapping.find(to_be_removed_node.name); + if (removed_entry->second->child) { + removed_entry->second->child->parent = nullptr; + mapping[to_be_removed_node.name] = std::move(removed_entry->second->child); + } else { + mapping.erase(removed_entry); + } + } + if (to_be_removed_node.parent) { + // if the to be removed node has a parent, set the child pointer to the + // to be restored node + to_be_removed_node.parent->child = std::move(to_be_removed_node.child); + entry.parent = to_be_removed_node.parent; + } else { + // otherwise we need to update the base entry tables + auto &name = entry.name; + to_be_removed_node.child->SetAsRoot(); + mapping[name]->index.GetEntry() = std::move(to_be_removed_node.child); + entry.parent = nullptr; + } + + // restore the name if it was deleted + auto restored_entry = mapping.find(entry.name); + if (restored_entry->second->deleted || entry.type == CatalogType::INVALID) { + if (restored_entry->second->child) { + restored_entry->second->child->parent = nullptr; + mapping[entry.name] = std::move(restored_entry->second->child); + } else { + mapping.erase(restored_entry); + } + } + // we mark the catalog as being modified, since this action can lead to e.g. tables being dropped + catalog.ModifyCatalog(); +} + +void CatalogSet::CreateDefaultEntries(CatalogTransaction transaction, unique_lock &lock) { + if (!defaults || defaults->created_all_entries || !transaction.context) { + return; + } + // this catalog set has a default set defined: + auto default_entries = defaults->GetDefaultEntries(); + for (auto &default_entry : default_entries) { + auto map_entry = mapping.find(default_entry); + if (map_entry == mapping.end()) { + // we unlock during the CreateEntry, since it might reference other catalog sets... + // specifically for views this can happen since the view will be bound + lock.unlock(); + auto entry = defaults->CreateDefaultEntry(*transaction.context, default_entry); + if (!entry) { + throw InternalException("Failed to create default entry for %s", default_entry); + } + + lock.lock(); + CreateEntryInternal(transaction, std::move(entry)); + } + } + defaults->created_all_entries = true; +} + +void CatalogSet::Scan(CatalogTransaction transaction, const std::function &callback) { + // lock the catalog set + unique_lock lock(catalog_lock); + CreateDefaultEntries(transaction, lock); + + for (auto &kv : entries) { + auto &entry = *kv.second.entry.get(); + auto &entry_for_transaction = GetEntryForTransaction(transaction, entry); + if (!entry_for_transaction.deleted) { + callback(entry_for_transaction); + } + } +} + +void CatalogSet::Scan(ClientContext &context, const std::function &callback) { + Scan(catalog.GetCatalogTransaction(context), callback); +} + +void CatalogSet::Scan(const std::function &callback) { + // lock the catalog set + lock_guard lock(catalog_lock); + for (auto &kv : entries) { + auto entry = kv.second.entry.get(); + auto &commited_entry = GetCommittedEntry(*entry); + if (!commited_entry.deleted) { + callback(commited_entry); + } + } +} + +void CatalogSet::Verify(Catalog &catalog_p) { + D_ASSERT(&catalog_p == &catalog); + vector> entries; + Scan([&](CatalogEntry &entry) { entries.push_back(entry); }); + for (auto &entry : entries) { + entry.get().Verify(catalog_p); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_transaction.cpp b/src/duckdb/src/catalog/catalog_transaction.cpp new file mode 100644 index 00000000..a99a8028 --- /dev/null +++ b/src/duckdb/src/catalog/catalog_transaction.cpp @@ -0,0 +1,38 @@ +#include "duckdb/catalog/catalog_transaction.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/main/database.hpp" + +namespace duckdb { + +CatalogTransaction::CatalogTransaction(Catalog &catalog, ClientContext &context) { + auto &transaction = Transaction::Get(context, catalog); + this->db = &DatabaseInstance::GetDatabase(context); + if (!transaction.IsDuckTransaction()) { + this->transaction_id = transaction_t(-1); + this->start_time = transaction_t(-1); + } else { + auto &dtransaction = transaction.Cast(); + this->transaction_id = dtransaction.transaction_id; + this->start_time = dtransaction.start_time; + } + this->transaction = &transaction; + this->context = &context; +} + +CatalogTransaction::CatalogTransaction(DatabaseInstance &db, transaction_t transaction_id_p, transaction_t start_time_p) + : db(&db), context(nullptr), transaction(nullptr), transaction_id(transaction_id_p), start_time(start_time_p) { +} + +ClientContext &CatalogTransaction::GetContext() { + if (!context) { + throw InternalException("Attempting to get a context in a CatalogTransaction without a context"); + } + return *context; +} + +CatalogTransaction CatalogTransaction::GetSystemTransaction(DatabaseInstance &db) { + return CatalogTransaction(db, 1, 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_functions.cpp b/src/duckdb/src/catalog/default/default_functions.cpp new file mode 100644 index 00000000..1c8a528f --- /dev/null +++ b/src/duckdb/src/catalog/default/default_functions.cpp @@ -0,0 +1,236 @@ +#include "duckdb/catalog/default/default_functions.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/function/table_macro_function.hpp" + +#include "duckdb/function/scalar_macro_function.hpp" + +namespace duckdb { + +static DefaultMacro internal_macros[] = { + {DEFAULT_SCHEMA, "current_role", {nullptr}, "'duckdb'"}, // user name of current execution context + {DEFAULT_SCHEMA, "current_user", {nullptr}, "'duckdb'"}, // user name of current execution context + {DEFAULT_SCHEMA, "current_catalog", {nullptr}, "current_database()"}, // name of current database (called "catalog" in the SQL standard) + {DEFAULT_SCHEMA, "user", {nullptr}, "current_user"}, // equivalent to current_user + {DEFAULT_SCHEMA, "session_user", {nullptr}, "'duckdb'"}, // session user name + {"pg_catalog", "inet_client_addr", {nullptr}, "NULL"}, // address of the remote connection + {"pg_catalog", "inet_client_port", {nullptr}, "NULL"}, // port of the remote connection + {"pg_catalog", "inet_server_addr", {nullptr}, "NULL"}, // address of the local connection + {"pg_catalog", "inet_server_port", {nullptr}, "NULL"}, // port of the local connection + {"pg_catalog", "pg_my_temp_schema", {nullptr}, "0"}, // OID of session's temporary schema, or 0 if none + {"pg_catalog", "pg_is_other_temp_schema", {"schema_id", nullptr}, "false"}, // is schema another session's temporary schema? + + {"pg_catalog", "pg_conf_load_time", {nullptr}, "current_timestamp"}, // configuration load time + {"pg_catalog", "pg_postmaster_start_time", {nullptr}, "current_timestamp"}, // server start time + + {"pg_catalog", "pg_typeof", {"expression", nullptr}, "lower(typeof(expression))"}, // get the data type of any value + + // privilege functions + // {"has_any_column_privilege", {"user", "table", "privilege", nullptr}, "true"}, //boolean //does user have privilege for any column of table + {"pg_catalog", "has_any_column_privilege", {"table", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for any column of table + // {"has_column_privilege", {"user", "table", "column", "privilege", nullptr}, "true"}, //boolean //does user have privilege for column + {"pg_catalog", "has_column_privilege", {"table", "column", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for column + // {"has_database_privilege", {"user", "database", "privilege", nullptr}, "true"}, //boolean //does user have privilege for database + {"pg_catalog", "has_database_privilege", {"database", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for database + // {"has_foreign_data_wrapper_privilege", {"user", "fdw", "privilege", nullptr}, "true"}, //boolean //does user have privilege for foreign-data wrapper + {"pg_catalog", "has_foreign_data_wrapper_privilege", {"fdw", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for foreign-data wrapper + // {"has_function_privilege", {"user", "function", "privilege", nullptr}, "true"}, //boolean //does user have privilege for function + {"pg_catalog", "has_function_privilege", {"function", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for function + // {"has_language_privilege", {"user", "language", "privilege", nullptr}, "true"}, //boolean //does user have privilege for language + {"pg_catalog", "has_language_privilege", {"language", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for language + // {"has_schema_privilege", {"user", "schema, privilege", nullptr}, "true"}, //boolean //does user have privilege for schema + {"pg_catalog", "has_schema_privilege", {"schema", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for schema + // {"has_sequence_privilege", {"user", "sequence", "privilege", nullptr}, "true"}, //boolean //does user have privilege for sequence + {"pg_catalog", "has_sequence_privilege", {"sequence", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for sequence + // {"has_server_privilege", {"user", "server", "privilege", nullptr}, "true"}, //boolean //does user have privilege for foreign server + {"pg_catalog", "has_server_privilege", {"server", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for foreign server + // {"has_table_privilege", {"user", "table", "privilege", nullptr}, "true"}, //boolean //does user have privilege for table + {"pg_catalog", "has_table_privilege", {"table", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for table + // {"has_tablespace_privilege", {"user", "tablespace", "privilege", nullptr}, "true"}, //boolean //does user have privilege for tablespace + {"pg_catalog", "has_tablespace_privilege", {"tablespace", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for tablespace + + // various postgres system functions + {"pg_catalog", "pg_get_viewdef", {"oid", nullptr}, "(select sql from duckdb_views() v where v.view_oid=oid)"}, + {"pg_catalog", "pg_get_constraintdef", {"constraint_oid", "pretty_bool", nullptr}, "(select constraint_text from duckdb_constraints() d_constraint where d_constraint.table_oid=constraint_oid//1000000 and d_constraint.constraint_index=constraint_oid%1000000)"}, + {"pg_catalog", "pg_get_expr", {"pg_node_tree", "relation_oid", nullptr}, "pg_node_tree"}, + {"pg_catalog", "format_pg_type", {"type_name", nullptr}, "case when logical_type='FLOAT' then 'real' when logical_type='DOUBLE' then 'double precision' when logical_type='DECIMAL' then 'numeric' when logical_type='ENUM' then lower(type_name) when logical_type='VARCHAR' then 'character varying' when logical_type='BLOB' then 'bytea' when logical_type='TIMESTAMP' then 'timestamp without time zone' when logical_type='TIME' then 'time without time zone' else lower(logical_type) end"}, + {"pg_catalog", "format_type", {"type_oid", "typemod", nullptr}, "(select format_pg_type(type_name) from duckdb_types() t where t.type_oid=type_oid) || case when typemod>0 then concat('(', typemod//1000, ',', typemod%1000, ')') else '' end"}, + + {"pg_catalog", "pg_has_role", {"user", "role", "privilege", nullptr}, "true"}, //boolean //does user have privilege for role + {"pg_catalog", "pg_has_role", {"role", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for role + + {"pg_catalog", "col_description", {"table_oid", "column_number", nullptr}, "NULL"}, // get comment for a table column + {"pg_catalog", "obj_description", {"object_oid", "catalog_name", nullptr}, "NULL"}, // get comment for a database object + {"pg_catalog", "shobj_description", {"object_oid", "catalog_name", nullptr}, "NULL"}, // get comment for a shared database object + + // visibility functions + {"pg_catalog", "pg_collation_is_visible", {"collation_oid", nullptr}, "true"}, + {"pg_catalog", "pg_conversion_is_visible", {"conversion_oid", nullptr}, "true"}, + {"pg_catalog", "pg_function_is_visible", {"function_oid", nullptr}, "true"}, + {"pg_catalog", "pg_opclass_is_visible", {"opclass_oid", nullptr}, "true"}, + {"pg_catalog", "pg_operator_is_visible", {"operator_oid", nullptr}, "true"}, + {"pg_catalog", "pg_opfamily_is_visible", {"opclass_oid", nullptr}, "true"}, + {"pg_catalog", "pg_table_is_visible", {"table_oid", nullptr}, "true"}, + {"pg_catalog", "pg_ts_config_is_visible", {"config_oid", nullptr}, "true"}, + {"pg_catalog", "pg_ts_dict_is_visible", {"dict_oid", nullptr}, "true"}, + {"pg_catalog", "pg_ts_parser_is_visible", {"parser_oid", nullptr}, "true"}, + {"pg_catalog", "pg_ts_template_is_visible", {"template_oid", nullptr}, "true"}, + {"pg_catalog", "pg_type_is_visible", {"type_oid", nullptr}, "true"}, + + {"pg_catalog", "pg_size_pretty", {"bytes", nullptr}, "format_bytes(bytes)"}, + + {DEFAULT_SCHEMA, "round_even", {"x", "n", nullptr}, "CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, + {DEFAULT_SCHEMA, "roundbankers", {"x", "n", nullptr}, "round_even(x, n)"}, + {DEFAULT_SCHEMA, "nullif", {"a", "b", nullptr}, "CASE WHEN a=b THEN NULL ELSE a END"}, + {DEFAULT_SCHEMA, "list_append", {"l", "e", nullptr}, "list_concat(l, list_value(e))"}, + {DEFAULT_SCHEMA, "array_append", {"arr", "el", nullptr}, "list_append(arr, el)"}, + {DEFAULT_SCHEMA, "list_prepend", {"e", "l", nullptr}, "list_concat(list_value(e), l)"}, + {DEFAULT_SCHEMA, "array_prepend", {"el", "arr", nullptr}, "list_prepend(el, arr)"}, + {DEFAULT_SCHEMA, "array_pop_back", {"arr", nullptr}, "arr[:LEN(arr)-1]"}, + {DEFAULT_SCHEMA, "array_pop_front", {"arr", nullptr}, "arr[2:]"}, + {DEFAULT_SCHEMA, "array_push_back", {"arr", "e", nullptr}, "list_concat(arr, list_value(e))"}, + {DEFAULT_SCHEMA, "array_push_front", {"arr", "e", nullptr}, "list_concat(list_value(e), arr)"}, + {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, "list_aggr(arr, 'string_agg', sep)"}, + {DEFAULT_SCHEMA, "generate_subscripts", {"arr", "dim", nullptr}, "unnest(generate_series(1, array_length(arr, dim)))"}, + {DEFAULT_SCHEMA, "fdiv", {"x", "y", nullptr}, "floor(x/y)"}, + {DEFAULT_SCHEMA, "fmod", {"x", "y", nullptr}, "(x-y*floor(x/y))"}, + {DEFAULT_SCHEMA, "count_if", {"l", nullptr}, "sum(if(l, 1, 0))"}, + {DEFAULT_SCHEMA, "split_part", {"string", "delimiter", "position", nullptr}, "coalesce(string_split(string, delimiter)[position],'')"}, + {DEFAULT_SCHEMA, "geomean", {"x", nullptr}, "exp(avg(ln(x)))"}, + {DEFAULT_SCHEMA, "geometric_mean", {"x", nullptr}, "geomean(x)"}, + + {DEFAULT_SCHEMA, "list_reverse", {"l", nullptr}, "l[:-:-1]"}, + {DEFAULT_SCHEMA, "array_reverse", {"l", nullptr}, "list_reverse(l)"}, + + // FIXME implement as actual function if we encounter a lot of performance issues. Complexity now: n * m, with hashing possibly n + m + {DEFAULT_SCHEMA, "list_intersect", {"l1", "l2", nullptr}, "list_filter(l1, (x) -> list_contains(l2, x))"}, + {DEFAULT_SCHEMA, "array_intersect", {"l1", "l2", nullptr}, "list_intersect(l1, l2)"}, + + {DEFAULT_SCHEMA, "list_has_any", {"l1", "l2", nullptr}, "CASE WHEN l1 IS NULL THEN NULL WHEN l2 IS NULL THEN NULL WHEN len(list_intersect(l1, l2)) > 0 THEN true ELSE false END"}, + {DEFAULT_SCHEMA, "array_has_any", {"l1", "l2", nullptr}, "list_has_any(l1, l2)" }, + {DEFAULT_SCHEMA, "&&", {"l1", "l2", nullptr}, "list_has_any(l1, l2)" }, // "&&" is the operator for "list_has_any + + {DEFAULT_SCHEMA, "list_has_all", {"l1", "l2", nullptr}, "CASE WHEN l1 IS NULL THEN NULL WHEN l2 IS NULL THEN NULL WHEN len(list_intersect(l2, l1)) = len(list_filter(l2, x -> x IS NOT NULL)) THEN true ELSE false END"}, + {DEFAULT_SCHEMA, "array_has_all", {"l1", "l2", nullptr}, "list_has_all(l1, l2)" }, + {DEFAULT_SCHEMA, "@>", {"l1", "l2", nullptr}, "list_has_all(l1, l2)" }, // "@>" is the operator for "list_has_all + {DEFAULT_SCHEMA, "<@", {"l1", "l2", nullptr}, "list_has_all(l2, l1)" }, // "<@" is the operator for "list_has_all + + // algebraic list aggregates + {DEFAULT_SCHEMA, "list_avg", {"l", nullptr}, "list_aggr(l, 'avg')"}, + {DEFAULT_SCHEMA, "list_var_samp", {"l", nullptr}, "list_aggr(l, 'var_samp')"}, + {DEFAULT_SCHEMA, "list_var_pop", {"l", nullptr}, "list_aggr(l, 'var_pop')"}, + {DEFAULT_SCHEMA, "list_stddev_pop", {"l", nullptr}, "list_aggr(l, 'stddev_pop')"}, + {DEFAULT_SCHEMA, "list_stddev_samp", {"l", nullptr}, "list_aggr(l, 'stddev_samp')"}, + {DEFAULT_SCHEMA, "list_sem", {"l", nullptr}, "list_aggr(l, 'sem')"}, + + // distributive list aggregates + {DEFAULT_SCHEMA, "list_approx_count_distinct", {"l", nullptr}, "list_aggr(l, 'approx_count_distinct')"}, + {DEFAULT_SCHEMA, "list_bit_xor", {"l", nullptr}, "list_aggr(l, 'bit_xor')"}, + {DEFAULT_SCHEMA, "list_bit_or", {"l", nullptr}, "list_aggr(l, 'bit_or')"}, + {DEFAULT_SCHEMA, "list_bit_and", {"l", nullptr}, "list_aggr(l, 'bit_and')"}, + {DEFAULT_SCHEMA, "list_bool_and", {"l", nullptr}, "list_aggr(l, 'bool_and')"}, + {DEFAULT_SCHEMA, "list_bool_or", {"l", nullptr}, "list_aggr(l, 'bool_or')"}, + {DEFAULT_SCHEMA, "list_count", {"l", nullptr}, "list_aggr(l, 'count')"}, + {DEFAULT_SCHEMA, "list_entropy", {"l", nullptr}, "list_aggr(l, 'entropy')"}, + {DEFAULT_SCHEMA, "list_last", {"l", nullptr}, "list_aggr(l, 'last')"}, + {DEFAULT_SCHEMA, "list_first", {"l", nullptr}, "list_aggr(l, 'first')"}, + {DEFAULT_SCHEMA, "list_any_value", {"l", nullptr}, "list_aggr(l, 'any_value')"}, + {DEFAULT_SCHEMA, "list_kurtosis", {"l", nullptr}, "list_aggr(l, 'kurtosis')"}, + {DEFAULT_SCHEMA, "list_min", {"l", nullptr}, "list_aggr(l, 'min')"}, + {DEFAULT_SCHEMA, "list_max", {"l", nullptr}, "list_aggr(l, 'max')"}, + {DEFAULT_SCHEMA, "list_product", {"l", nullptr}, "list_aggr(l, 'product')"}, + {DEFAULT_SCHEMA, "list_skewness", {"l", nullptr}, "list_aggr(l, 'skewness')"}, + {DEFAULT_SCHEMA, "list_sum", {"l", nullptr}, "list_aggr(l, 'sum')"}, + {DEFAULT_SCHEMA, "list_string_agg", {"l", nullptr}, "list_aggr(l, 'string_agg')"}, + + // holistic list aggregates + {DEFAULT_SCHEMA, "list_mode", {"l", nullptr}, "list_aggr(l, 'mode')"}, + {DEFAULT_SCHEMA, "list_median", {"l", nullptr}, "list_aggr(l, 'median')"}, + {DEFAULT_SCHEMA, "list_mad", {"l", nullptr}, "list_aggr(l, 'mad')"}, + + // nested list aggregates + {DEFAULT_SCHEMA, "list_histogram", {"l", nullptr}, "list_aggr(l, 'histogram')"}, + + // date functions + {DEFAULT_SCHEMA, "date_add", {"date", "interval", nullptr}, "date + interval"}, + + {nullptr, nullptr, {nullptr}, nullptr} + }; + +unique_ptr DefaultFunctionGenerator::CreateInternalTableMacroInfo(DefaultMacro &default_macro, unique_ptr function) { + for (idx_t param_idx = 0; default_macro.parameters[param_idx] != nullptr; param_idx++) { + function->parameters.push_back( + make_uniq(default_macro.parameters[param_idx])); + } + + auto type = function->type == MacroType::TABLE_MACRO ? CatalogType::TABLE_MACRO_ENTRY : CatalogType::MACRO_ENTRY; + auto bind_info = make_uniq(type); + bind_info->schema = default_macro.schema; + bind_info->name = default_macro.name; + bind_info->temporary = true; + bind_info->internal = true; + bind_info->function = std::move(function); + return bind_info; + +} + +unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(DefaultMacro &default_macro) { + // parse the expression + auto expressions = Parser::ParseExpressionList(default_macro.macro); + D_ASSERT(expressions.size() == 1); + + auto result = make_uniq(std::move(expressions[0])); + return CreateInternalTableMacroInfo(default_macro, std::move(result)); +} + +unique_ptr DefaultFunctionGenerator::CreateInternalTableMacroInfo(DefaultMacro &default_macro) { + Parser parser; + parser.ParseQuery(default_macro.macro); + D_ASSERT(parser.statements.size() == 1); + D_ASSERT(parser.statements[0]->type == StatementType::SELECT_STATEMENT); + + auto &select = parser.statements[0]->Cast(); + auto result = make_uniq(std::move(select.node)); + return CreateInternalTableMacroInfo(default_macro, std::move(result)); +} + +static unique_ptr GetDefaultFunction(const string &input_schema, const string &input_name) { + auto schema = StringUtil::Lower(input_schema); + auto name = StringUtil::Lower(input_name); + for (idx_t index = 0; internal_macros[index].name != nullptr; index++) { + if (internal_macros[index].schema == schema && internal_macros[index].name == name) { + return DefaultFunctionGenerator::CreateInternalMacroInfo(internal_macros[index]); + } + } + return nullptr; +} + +DefaultFunctionGenerator::DefaultFunctionGenerator(Catalog &catalog, SchemaCatalogEntry &schema) + : DefaultGenerator(catalog), schema(schema) { +} + +unique_ptr DefaultFunctionGenerator::CreateDefaultEntry(ClientContext &context, + const string &entry_name) { + auto info = GetDefaultFunction(schema.name, entry_name); + if (info) { + return make_uniq_base(catalog, schema, info->Cast()); + } + return nullptr; +} + +vector DefaultFunctionGenerator::GetDefaultEntries() { + vector result; + for (idx_t index = 0; internal_macros[index].name != nullptr; index++) { + if (StringUtil::Lower(internal_macros[index].name) != internal_macros[index].name) { + throw InternalException("Default macro name %s should be lowercase", internal_macros[index].name); + } + if (internal_macros[index].schema == schema.name) { + result.emplace_back(internal_macros[index].name); + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_schemas.cpp b/src/duckdb/src/catalog/default/default_schemas.cpp new file mode 100644 index 00000000..394e1ba2 --- /dev/null +++ b/src/duckdb/src/catalog/default/default_schemas.cpp @@ -0,0 +1,41 @@ +#include "duckdb/catalog/default/default_schemas.hpp" +#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +struct DefaultSchema { + const char *name; +}; + +static DefaultSchema internal_schemas[] = {{"information_schema"}, {"pg_catalog"}, {nullptr}}; + +static bool GetDefaultSchema(const string &input_schema) { + auto schema = StringUtil::Lower(input_schema); + for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { + if (internal_schemas[index].name == schema) { + return true; + } + } + return false; +} + +DefaultSchemaGenerator::DefaultSchemaGenerator(Catalog &catalog) : DefaultGenerator(catalog) { +} + +unique_ptr DefaultSchemaGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { + if (GetDefaultSchema(entry_name)) { + return make_uniq_base(catalog, StringUtil::Lower(entry_name), true); + } + return nullptr; +} + +vector DefaultSchemaGenerator::GetDefaultEntries() { + vector result; + for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { + result.emplace_back(internal_schemas[index].name); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_types.cpp b/src/duckdb/src/catalog/default/default_types.cpp new file mode 100644 index 00000000..23edac04 --- /dev/null +++ b/src/duckdb/src/catalog/default/default_types.cpp @@ -0,0 +1,53 @@ +#include "duckdb/catalog/default/default_types.hpp" + +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/catalog/default/builtin_types/types.hpp" + +namespace duckdb { + +LogicalTypeId DefaultTypeGenerator::GetDefaultType(const string &name) { + auto &internal_types = BUILTIN_TYPES; + for (auto &type : internal_types) { + if (StringUtil::CIEquals(name, type.name)) { + return type.type; + } + } + return LogicalType::INVALID; +} + +DefaultTypeGenerator::DefaultTypeGenerator(Catalog &catalog, SchemaCatalogEntry &schema) + : DefaultGenerator(catalog), schema(schema) { +} + +unique_ptr DefaultTypeGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { + if (schema.name != DEFAULT_SCHEMA) { + return nullptr; + } + auto type_id = GetDefaultType(entry_name); + if (type_id == LogicalTypeId::INVALID) { + return nullptr; + } + CreateTypeInfo info; + info.name = entry_name; + info.type = LogicalType(type_id); + info.internal = true; + info.temporary = true; + return make_uniq_base(catalog, schema, info); +} + +vector DefaultTypeGenerator::GetDefaultEntries() { + vector result; + if (schema.name != DEFAULT_SCHEMA) { + return result; + } + auto &internal_types = BUILTIN_TYPES; + for (auto &type : internal_types) { + result.emplace_back(StringUtil::Lower(type.name)); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_views.cpp b/src/duckdb/src/catalog/default/default_views.cpp new file mode 100644 index 00000000..72004502 --- /dev/null +++ b/src/duckdb/src/catalog/default/default_views.cpp @@ -0,0 +1,94 @@ +#include "duckdb/catalog/default/default_views.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +struct DefaultView { + const char *schema; + const char *name; + const char *sql; +}; + +static DefaultView internal_views[] = { + {DEFAULT_SCHEMA, "pragma_database_list", "SELECT database_oid AS seq, database_name AS name, path AS file FROM duckdb_databases() WHERE NOT internal ORDER BY 1"}, + {DEFAULT_SCHEMA, "sqlite_master", "select 'table' \"type\", table_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_tables union all select 'view' \"type\", view_name \"name\", view_name \"tbl_name\", 0 rootpage, sql from duckdb_views union all select 'index' \"type\", index_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_indexes;"}, + {DEFAULT_SCHEMA, "sqlite_schema", "SELECT * FROM sqlite_master"}, + {DEFAULT_SCHEMA, "sqlite_temp_master", "SELECT * FROM sqlite_master"}, + {DEFAULT_SCHEMA, "sqlite_temp_schema", "SELECT * FROM sqlite_master"}, + {DEFAULT_SCHEMA, "duckdb_constraints", "SELECT * FROM duckdb_constraints()"}, + {DEFAULT_SCHEMA, "duckdb_columns", "SELECT * FROM duckdb_columns() WHERE NOT internal"}, + {DEFAULT_SCHEMA, "duckdb_databases", "SELECT * FROM duckdb_databases() WHERE NOT internal"}, + {DEFAULT_SCHEMA, "duckdb_indexes", "SELECT * FROM duckdb_indexes()"}, + {DEFAULT_SCHEMA, "duckdb_schemas", "SELECT * FROM duckdb_schemas() WHERE NOT internal"}, + {DEFAULT_SCHEMA, "duckdb_tables", "SELECT * FROM duckdb_tables() WHERE NOT internal"}, + {DEFAULT_SCHEMA, "duckdb_types", "SELECT * FROM duckdb_types()"}, + {DEFAULT_SCHEMA, "duckdb_views", "SELECT * FROM duckdb_views() WHERE NOT internal"}, + {"pg_catalog", "pg_am", "SELECT 0 oid, 'art' amname, NULL amhandler, 'i' amtype"}, + {"pg_catalog", "pg_attribute", "SELECT table_oid attrelid, column_name attname, data_type_id atttypid, 0 attstattarget, NULL attlen, column_index attnum, 0 attndims, -1 attcacheoff, case when data_type ilike '%decimal%' then numeric_precision*1000+numeric_scale else -1 end atttypmod, false attbyval, NULL attstorage, NULL attalign, NOT is_nullable attnotnull, column_default IS NOT NULL atthasdef, false atthasmissing, '' attidentity, '' attgenerated, false attisdropped, true attislocal, 0 attinhcount, 0 attcollation, NULL attcompression, NULL attacl, NULL attoptions, NULL attfdwoptions, NULL attmissingval FROM duckdb_columns()"}, + {"pg_catalog", "pg_attrdef", "SELECT column_index oid, table_oid adrelid, column_index adnum, column_default adbin from duckdb_columns() where column_default is not null;"}, + {"pg_catalog", "pg_class", "SELECT table_oid oid, table_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, estimated_size::real reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, index_count > 0 relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'r' relkind, column_count relnatts, check_constraint_count relchecks, false relhasoids, has_primary_key relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_tables() UNION ALL SELECT view_oid oid, view_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'v' relkind, column_count relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_views() UNION ALL SELECT sequence_oid oid, sequence_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'S' relkind, 0 relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_sequences() UNION ALL SELECT index_oid oid, index_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, 't' relpersistence, 'i' relkind, NULL relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_indexes()"}, + {"pg_catalog", "pg_constraint", "SELECT table_oid*1000000+constraint_index oid, constraint_text conname, schema_oid connamespace, CASE constraint_type WHEN 'CHECK' then 'c' WHEN 'UNIQUE' then 'u' WHEN 'PRIMARY KEY' THEN 'p' WHEN 'FOREIGN KEY' THEN 'f' ELSE 'x' END contype, false condeferrable, false condeferred, true convalidated, table_oid conrelid, 0 contypid, 0 conindid, 0 conparentid, 0 confrelid, NULL confupdtype, NULL confdeltype, NULL confmatchtype, true conislocal, 0 coninhcount, false connoinherit, constraint_column_indexes conkey, NULL confkey, NULL conpfeqop, NULL conppeqop, NULL conffeqop, NULL conexclop, expression conbin FROM duckdb_constraints()"}, + {"pg_catalog", "pg_database", "SELECT database_oid oid, database_name datname FROM duckdb_databases()"}, + {"pg_catalog", "pg_depend", "SELECT * FROM duckdb_dependencies()"}, + {"pg_catalog", "pg_description", "SELECT NULL objoid, NULL classoid, NULL objsubid, NULL description WHERE 1=0"}, + {"pg_catalog", "pg_enum", "SELECT NULL oid, a.type_oid enumtypid, list_position(b.labels, a.elabel) enumsortorder, a.elabel enumlabel FROM (SELECT UNNEST(labels) elabel, type_oid FROM duckdb_types() WHERE logical_type='ENUM') a JOIN duckdb_types() b ON a.type_oid=b.type_oid;"}, + {"pg_catalog", "pg_index", "SELECT index_oid indexrelid, table_oid indrelid, 0 indnatts, 0 indnkeyatts, is_unique indisunique, is_primary indisprimary, false indisexclusion, true indimmediate, false indisclustered, true indisvalid, false indcheckxmin, true indisready, true indislive, false indisreplident, NULL::INT[] indkey, NULL::OID[] indcollation, NULL::OID[] indclass, NULL::INT[] indoption, expressions indexprs, NULL indpred FROM duckdb_indexes()"}, + {"pg_catalog", "pg_indexes", "SELECT schema_name schemaname, table_name tablename, index_name indexname, NULL \"tablespace\", sql indexdef FROM duckdb_indexes()"}, + {"pg_catalog", "pg_namespace", "SELECT oid, schema_name nspname, 0 nspowner, NULL nspacl FROM duckdb_schemas()"}, + {"pg_catalog", "pg_proc", "SELECT f.function_oid oid, function_name proname, s.oid pronamespace, varargs provariadic, function_type = 'aggregate' proisagg, function_type = 'table' proretset, return_type prorettype, parameter_types proargtypes, parameters proargnames FROM duckdb_functions() f LEFT JOIN duckdb_schemas() s USING (database_name, schema_name)"}, + {"pg_catalog", "pg_sequence", "SELECT sequence_oid seqrelid, 0 seqtypid, start_value seqstart, increment_by seqincrement, max_value seqmax, min_value seqmin, 0 seqcache, cycle seqcycle FROM duckdb_sequences()"}, + {"pg_catalog", "pg_sequences", "SELECT schema_name schemaname, sequence_name sequencename, 'duckdb' sequenceowner, 0 data_type, start_value, min_value, max_value, increment_by, cycle, 0 cache_size, last_value FROM duckdb_sequences()"}, + {"pg_catalog", "pg_settings", "SELECT name, value setting, description short_desc, CASE WHEN input_type = 'VARCHAR' THEN 'string' WHEN input_type = 'BOOLEAN' THEN 'bool' WHEN input_type IN ('BIGINT', 'UBIGINT') THEN 'integer' ELSE input_type END vartype FROM duckdb_settings()"}, + {"pg_catalog", "pg_tables", "SELECT schema_name schemaname, table_name tablename, 'duckdb' tableowner, NULL \"tablespace\", index_count > 0 hasindexes, false hasrules, false hastriggers FROM duckdb_tables()"}, + {"pg_catalog", "pg_tablespace", "SELECT 0 oid, 'pg_default' spcname, 0 spcowner, NULL spcacl, NULL spcoptions"}, + {"pg_catalog", "pg_type", "SELECT type_oid oid, format_pg_type(type_name) typname, schema_oid typnamespace, 0 typowner, type_size typlen, false typbyval, CASE WHEN logical_type='ENUM' THEN 'e' else 'b' end typtype, CASE WHEN type_category='NUMERIC' THEN 'N' WHEN type_category='STRING' THEN 'S' WHEN type_category='DATETIME' THEN 'D' WHEN type_category='BOOLEAN' THEN 'B' WHEN type_category='COMPOSITE' THEN 'C' WHEN type_category='USER' THEN 'U' ELSE 'X' END typcategory, false typispreferred, true typisdefined, NULL typdelim, NULL typrelid, NULL typsubscript, NULL typelem, NULL typarray, NULL typinput, NULL typoutput, NULL typreceive, NULL typsend, NULL typmodin, NULL typmodout, NULL typanalyze, 'd' typalign, 'p' typstorage, NULL typnotnull, NULL typbasetype, NULL typtypmod, NULL typndims, NULL typcollation, NULL typdefaultbin, NULL typdefault, NULL typacl FROM duckdb_types() WHERE type_size IS NOT NULL;"}, + {"pg_catalog", "pg_views", "SELECT schema_name schemaname, view_name viewname, 'duckdb' viewowner, sql definition FROM duckdb_views()"}, + {"information_schema", "columns", "SELECT database_name table_catalog, schema_name table_schema, table_name, column_name, column_index ordinal_position, column_default, CASE WHEN is_nullable THEN 'YES' ELSE 'NO' END is_nullable, data_type, character_maximum_length, NULL character_octet_length, numeric_precision, numeric_precision_radix, numeric_scale, NULL datetime_precision, NULL interval_type, NULL interval_precision, NULL character_set_catalog, NULL character_set_schema, NULL character_set_name, NULL collation_catalog, NULL collation_schema, NULL collation_name, NULL domain_catalog, NULL domain_schema, NULL domain_name, NULL udt_catalog, NULL udt_schema, NULL udt_name, NULL scope_catalog, NULL scope_schema, NULL scope_name, NULL maximum_cardinality, NULL dtd_identifier, NULL is_self_referencing, NULL is_identity, NULL identity_generation, NULL identity_start, NULL identity_increment, NULL identity_maximum, NULL identity_minimum, NULL identity_cycle, NULL is_generated, NULL generation_expression, NULL is_updatable FROM duckdb_columns;"}, + {"information_schema", "schemata", "SELECT database_name catalog_name, schema_name, 'duckdb' schema_owner, NULL default_character_set_catalog, NULL default_character_set_schema, NULL default_character_set_name, sql sql_path FROM duckdb_schemas()"}, + {"information_schema", "tables", "SELECT database_name table_catalog, schema_name table_schema, table_name, CASE WHEN temporary THEN 'LOCAL TEMPORARY' ELSE 'BASE TABLE' END table_type, NULL self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL user_defined_type_schema, NULL user_defined_type_name, 'YES' is_insertable_into, 'NO' is_typed, CASE WHEN temporary THEN 'PRESERVE' ELSE NULL END commit_action FROM duckdb_tables() UNION ALL SELECT database_name table_catalog, schema_name table_schema, view_name table_name, 'VIEW' table_type, NULL self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL user_defined_type_schema, NULL user_defined_type_name, 'NO' is_insertable_into, 'NO' is_typed, NULL commit_action FROM duckdb_views;"}, + {nullptr, nullptr, nullptr}}; + +static unique_ptr GetDefaultView(ClientContext &context, const string &input_schema, const string &input_name) { + auto schema = StringUtil::Lower(input_schema); + auto name = StringUtil::Lower(input_name); + for (idx_t index = 0; internal_views[index].name != nullptr; index++) { + if (internal_views[index].schema == schema && internal_views[index].name == name) { + auto result = make_uniq(); + result->schema = schema; + result->view_name = name; + result->sql = internal_views[index].sql; + result->temporary = true; + result->internal = true; + + return CreateViewInfo::FromSelect(context, std::move(result)); + } + } + return nullptr; +} + +DefaultViewGenerator::DefaultViewGenerator(Catalog &catalog, SchemaCatalogEntry &schema) + : DefaultGenerator(catalog), schema(schema) { +} + +unique_ptr DefaultViewGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { + auto info = GetDefaultView(context, schema.name, entry_name); + if (info) { + return make_uniq_base(catalog, schema, *info); + } + return nullptr; +} + +vector DefaultViewGenerator::GetDefaultEntries() { + vector result; + for (idx_t index = 0; internal_views[index].name != nullptr; index++) { + if (internal_views[index].schema == schema.name) { + result.emplace_back(internal_views[index].name); + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/dependency_list.cpp b/src/duckdb/src/catalog/dependency_list.cpp new file mode 100644 index 00000000..76ae806b --- /dev/null +++ b/src/duckdb/src/catalog/dependency_list.cpp @@ -0,0 +1,26 @@ +#include "duckdb/catalog/dependency_list.hpp" +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +void DependencyList::AddDependency(CatalogEntry &entry) { + if (entry.internal) { + return; + } + set.insert(entry); +} + +void DependencyList::VerifyDependencies(Catalog &catalog, const string &name) { + for (auto &dep_entry : set) { + auto &dep = dep_entry.get(); + if (&dep.ParentCatalog() != &catalog) { + throw DependencyException( + "Error adding dependency for object \"%s\" - dependency \"%s\" is in catalog " + "\"%s\", which does not match the catalog \"%s\".\nCross catalog dependencies are not supported.", + name, dep.name, dep.ParentCatalog().GetName(), catalog.GetName()); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/dependency_manager.cpp b/src/duckdb/src/catalog/dependency_manager.cpp new file mode 100644 index 00000000..901604a5 --- /dev/null +++ b/src/duckdb/src/catalog/dependency_manager.cpp @@ -0,0 +1,192 @@ +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/catalog/mapping_value.hpp" +#include "duckdb/catalog/dependency_list.hpp" + +namespace duckdb { + +DependencyManager::DependencyManager(DuckCatalog &catalog) : catalog(catalog) { +} + +void DependencyManager::AddObject(CatalogTransaction transaction, CatalogEntry &object, DependencyList &dependencies) { + // check for each object in the sources if they were not deleted yet + for (auto &dep : dependencies.set) { + auto &dependency = dep.get(); + if (&dependency.ParentCatalog() != &object.ParentCatalog()) { + throw DependencyException( + "Error adding dependency for object \"%s\" - dependency \"%s\" is in catalog " + "\"%s\", which does not match the catalog \"%s\".\nCross catalog dependencies are not supported.", + object.name, dependency.name, dependency.ParentCatalog().GetName(), object.ParentCatalog().GetName()); + } + if (!dependency.set) { + throw InternalException("Dependency has no set"); + } + auto catalog_entry = dependency.set->GetEntryInternal(transaction, dependency.name, nullptr); + if (!catalog_entry) { + throw InternalException("Dependency has already been deleted?"); + } + } + // indexes do not require CASCADE to be dropped, they are simply always dropped along with the table + auto dependency_type = object.type == CatalogType::INDEX_ENTRY ? DependencyType::DEPENDENCY_AUTOMATIC + : DependencyType::DEPENDENCY_REGULAR; + // add the object to the dependents_map of each object that it depends on + for (auto &dependency : dependencies.set) { + auto &set = dependents_map[dependency]; + set.insert(Dependency(object, dependency_type)); + } + // create the dependents map for this object: it starts out empty + dependents_map[object] = dependency_set_t(); + dependencies_map[object] = dependencies.set; +} + +void DependencyManager::DropObject(CatalogTransaction transaction, CatalogEntry &object, bool cascade) { + D_ASSERT(dependents_map.find(object) != dependents_map.end()); + + // first check the objects that depend on this object + auto &dependent_objects = dependents_map[object]; + for (auto &dep : dependent_objects) { + // look up the entry in the catalog set + auto &entry = dep.entry.get(); + auto &catalog_set = *entry.set; + auto mapping_value = catalog_set.GetMapping(transaction, entry.name, true /* get_latest */); + if (mapping_value == nullptr) { + continue; + } + auto dependency_entry = catalog_set.GetEntryInternal(transaction, mapping_value->index); + if (!dependency_entry) { + // the dependent object was already deleted, no conflict + continue; + } + // conflict: attempting to delete this object but the dependent object still exists + if (cascade || dep.dependency_type == DependencyType::DEPENDENCY_AUTOMATIC || + dep.dependency_type == DependencyType::DEPENDENCY_OWNS) { + // cascade: drop the dependent object + catalog_set.DropEntryInternal(transaction, mapping_value->index.Copy(), *dependency_entry, cascade); + } else { + // no cascade and there are objects that depend on this object: throw error + throw DependencyException("Cannot drop entry \"%s\" because there are entries that " + "depend on it. Use DROP...CASCADE to drop all dependents.", + object.name); + } + } +} + +void DependencyManager::AlterObject(CatalogTransaction transaction, CatalogEntry &old_obj, CatalogEntry &new_obj) { + D_ASSERT(dependents_map.find(old_obj) != dependents_map.end()); + D_ASSERT(dependencies_map.find(old_obj) != dependencies_map.end()); + + // first check the objects that depend on this object + catalog_entry_vector_t owned_objects_to_add; + auto &dependent_objects = dependents_map[old_obj]; + for (auto &dep : dependent_objects) { + // look up the entry in the catalog set + auto &entry = dep.entry.get(); + auto &catalog_set = *entry.set; + auto dependency_entry = catalog_set.GetEntryInternal(transaction, entry.name, nullptr); + if (!dependency_entry) { + // the dependent object was already deleted, no conflict + continue; + } + if (dep.dependency_type == DependencyType::DEPENDENCY_OWNS) { + // the dependent object is owned by the current object + owned_objects_to_add.push_back(dep.entry); + continue; + } + // conflict: attempting to alter this object but the dependent object still exists + // no cascade and there are objects that depend on this object: throw error + throw DependencyException("Cannot alter entry \"%s\" because there are entries that " + "depend on it.", + old_obj.name); + } + // add the new object to the dependents_map of each object that it depends on + auto &old_dependencies = dependencies_map[old_obj]; + for (auto &dep : old_dependencies) { + auto &dependency = dep.get(); + dependents_map[dependency].insert(new_obj); + } + + // We might have to add a type dependency + // add the new object to the dependency manager + dependents_map[new_obj] = dependency_set_t(); + dependencies_map[new_obj] = old_dependencies; + + for (auto &dependency : owned_objects_to_add) { + dependents_map[new_obj].insert(Dependency(dependency, DependencyType::DEPENDENCY_OWNS)); + dependents_map[dependency].insert(Dependency(new_obj, DependencyType::DEPENDENCY_OWNED_BY)); + dependencies_map[new_obj].insert(dependency); + } +} + +void DependencyManager::EraseObject(CatalogEntry &object) { + // obtain the writing lock + EraseObjectInternal(object); +} + +void DependencyManager::EraseObjectInternal(CatalogEntry &object) { + if (dependents_map.find(object) == dependents_map.end()) { + // dependencies already removed + return; + } + D_ASSERT(dependents_map.find(object) != dependents_map.end()); + D_ASSERT(dependencies_map.find(object) != dependencies_map.end()); + // now for each of the dependencies, erase the entries from the dependents_map + for (auto &dependency : dependencies_map[object]) { + auto entry = dependents_map.find(dependency); + if (entry != dependents_map.end()) { + D_ASSERT(entry->second.find(object) != entry->second.end()); + entry->second.erase(object); + } + } + // erase the dependents and dependencies for this object + dependents_map.erase(object); + dependencies_map.erase(object); +} + +void DependencyManager::Scan(const std::function &callback) { + lock_guard write_lock(catalog.GetWriteLock()); + for (auto &entry : dependents_map) { + for (auto &dependent : entry.second) { + callback(entry.first, dependent.entry, dependent.dependency_type); + } + } +} + +void DependencyManager::AddOwnership(CatalogTransaction transaction, CatalogEntry &owner, CatalogEntry &entry) { + // lock the catalog for writing + lock_guard write_lock(catalog.GetWriteLock()); + + // If the owner is already owned by something else, throw an error + for (auto &dep : dependents_map[owner]) { + if (dep.dependency_type == DependencyType::DEPENDENCY_OWNED_BY) { + throw DependencyException(owner.name + " already owned by " + dep.entry.get().name); + } + } + + // If the entry is already owned, throw an error + for (auto &dep : dependents_map[entry]) { + // if the entry is already owned, throw error + if (&dep.entry.get() != &owner) { + throw DependencyException(entry.name + " already depends on " + dep.entry.get().name); + } + // if the entry owns the owner, throw error + if (&dep.entry.get() == &owner && dep.dependency_type == DependencyType::DEPENDENCY_OWNS) { + throw DependencyException(entry.name + " already owns " + owner.name + + ". Cannot have circular dependencies"); + } + } + + // Emplace guarantees that the same object cannot be inserted twice in the unordered_set + // In the case AddOwnership is called twice, because of emplace, the object will not be repeated in the set. + // We use an automatic dependency because if the Owner gets deleted, then the owned objects are also deleted + dependents_map[owner].emplace(entry, DependencyType::DEPENDENCY_OWNS); + dependents_map[entry].emplace(owner, DependencyType::DEPENDENCY_OWNED_BY); + dependencies_map[owner].emplace(entry); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/duck_catalog.cpp b/src/duckdb/src/catalog/duck_catalog.cpp new file mode 100644 index 00000000..879761d2 --- /dev/null +++ b/src/duckdb/src/catalog/duck_catalog.cpp @@ -0,0 +1,154 @@ +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/catalog/default/default_schemas.hpp" +#include "duckdb/function/built_in_functions.hpp" +#include "duckdb/main/attached_database.hpp" +#ifndef DISABLE_CORE_FUNCTIONS_EXTENSION +#include "duckdb/core_functions/core_functions.hpp" +#endif + +namespace duckdb { + +DuckCatalog::DuckCatalog(AttachedDatabase &db) + : Catalog(db), dependency_manager(make_uniq(*this)), + schemas(make_uniq(*this, make_uniq(*this))) { +} + +DuckCatalog::~DuckCatalog() { +} + +void DuckCatalog::Initialize(bool load_builtin) { + // first initialize the base system catalogs + // these are never written to the WAL + // we start these at 1 because deleted entries default to 0 + auto data = CatalogTransaction::GetSystemTransaction(GetDatabase()); + + // create the default schema + CreateSchemaInfo info; + info.schema = DEFAULT_SCHEMA; + info.internal = true; + CreateSchema(data, info); + + if (load_builtin) { + // initialize default functions + BuiltinFunctions builtin(data, *this); + builtin.Initialize(); + +#ifndef DISABLE_CORE_FUNCTIONS_EXTENSION + CoreFunctions::RegisterFunctions(*this, data); +#endif + } + + Verify(); +} + +bool DuckCatalog::IsDuckCatalog() { + return true; +} + +//===--------------------------------------------------------------------===// +// Schema +//===--------------------------------------------------------------------===// +optional_ptr DuckCatalog::CreateSchemaInternal(CatalogTransaction transaction, CreateSchemaInfo &info) { + DependencyList dependencies; + auto entry = make_uniq(*this, info.schema, info.internal); + auto result = entry.get(); + if (!schemas->CreateEntry(transaction, info.schema, std::move(entry), dependencies)) { + return nullptr; + } + return (CatalogEntry *)result; +} + +optional_ptr DuckCatalog::CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) { + D_ASSERT(!info.schema.empty()); + auto result = CreateSchemaInternal(transaction, info); + if (!result) { + switch (info.on_conflict) { + case OnCreateConflict::ERROR_ON_CONFLICT: + throw CatalogException("Schema with name %s already exists!", info.schema); + case OnCreateConflict::REPLACE_ON_CONFLICT: { + DropInfo drop_info; + drop_info.type = CatalogType::SCHEMA_ENTRY; + drop_info.catalog = info.catalog; + drop_info.name = info.schema; + DropSchema(transaction, drop_info); + result = CreateSchemaInternal(transaction, info); + if (!result) { + throw InternalException("Failed to create schema entry in CREATE_OR_REPLACE"); + } + break; + } + case OnCreateConflict::IGNORE_ON_CONFLICT: + break; + default: + throw InternalException("Unsupported OnCreateConflict for CreateSchema"); + } + return nullptr; + } + return result; +} + +void DuckCatalog::DropSchema(CatalogTransaction transaction, DropInfo &info) { + D_ASSERT(!info.name.empty()); + ModifyCatalog(); + if (!schemas->DropEntry(transaction, info.name, info.cascade)) { + if (info.if_not_found == OnEntryNotFound::THROW_EXCEPTION) { + throw CatalogException("Schema with name \"%s\" does not exist!", info.name); + } + } +} + +void DuckCatalog::DropSchema(ClientContext &context, DropInfo &info) { + DropSchema(GetCatalogTransaction(context), info); +} + +void DuckCatalog::ScanSchemas(ClientContext &context, std::function callback) { + schemas->Scan(GetCatalogTransaction(context), + [&](CatalogEntry &entry) { callback(entry.Cast()); }); +} + +void DuckCatalog::ScanSchemas(std::function callback) { + schemas->Scan([&](CatalogEntry &entry) { callback(entry.Cast()); }); +} + +optional_ptr DuckCatalog::GetSchema(CatalogTransaction transaction, const string &schema_name, + OnEntryNotFound if_not_found, QueryErrorContext error_context) { + D_ASSERT(!schema_name.empty()); + auto entry = schemas->GetEntry(transaction, schema_name); + if (!entry) { + if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { + throw CatalogException(error_context.FormatError("Schema with name %s does not exist!", schema_name)); + } + return nullptr; + } + return &entry->Cast(); +} + +DatabaseSize DuckCatalog::GetDatabaseSize(ClientContext &context) { + return db.GetStorageManager().GetDatabaseSize(); +} + +vector DuckCatalog::GetMetadataInfo(ClientContext &context) { + return db.GetStorageManager().GetMetadataInfo(); +} + +bool DuckCatalog::InMemory() { + return db.GetStorageManager().InMemory(); +} + +string DuckCatalog::GetDBPath() { + return db.GetStorageManager().GetDBPath(); +} + +void DuckCatalog::Verify() { +#ifdef DEBUG + Catalog::Verify(); + schemas->Verify(*this); +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/similar_catalog_entry.cpp b/src/duckdb/src/catalog/similar_catalog_entry.cpp new file mode 100644 index 00000000..d3e3487b --- /dev/null +++ b/src/duckdb/src/catalog/similar_catalog_entry.cpp @@ -0,0 +1,26 @@ +#include "duckdb/catalog/similar_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +string SimilarCatalogEntry::GetQualifiedName(bool qualify_catalog, bool qualify_schema) const { + D_ASSERT(Found()); + string result; + if (qualify_catalog) { + result += schema->catalog.GetName(); + } + if (qualify_schema) { + if (!result.empty()) { + result += "."; + } + result += schema->name; + } + if (!result.empty()) { + result += "."; + } + result += name; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp new file mode 100644 index 00000000..14016adc --- /dev/null +++ b/src/duckdb/src/common/adbc/adbc.cpp @@ -0,0 +1,1009 @@ +#include "duckdb/common/adbc/adbc.hpp" +#include "duckdb/common/adbc/adbc-init.hpp" + +#include "duckdb/common/string.hpp" +#include "duckdb/common/string_util.hpp" + +#include "duckdb.h" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" + +#ifndef DUCKDB_AMALGAMATION +#include "duckdb/main/connection.hpp" +#endif + +#include "duckdb/common/adbc/single_batch_array_stream.hpp" + +#include +#include + +// We must leak the symbols of the init function +duckdb_adbc::AdbcStatusCode duckdb_adbc_init(size_t count, struct duckdb_adbc::AdbcDriver *driver, + struct duckdb_adbc::AdbcError *error) { + if (!driver) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + + driver->DatabaseNew = duckdb_adbc::DatabaseNew; + driver->DatabaseSetOption = duckdb_adbc::DatabaseSetOption; + driver->DatabaseInit = duckdb_adbc::DatabaseInit; + driver->DatabaseRelease = duckdb_adbc::DatabaseRelease; + driver->ConnectionNew = duckdb_adbc::ConnectionNew; + driver->ConnectionSetOption = duckdb_adbc::ConnectionSetOption; + driver->ConnectionInit = duckdb_adbc::ConnectionInit; + driver->ConnectionRelease = duckdb_adbc::ConnectionRelease; + driver->ConnectionGetTableTypes = duckdb_adbc::ConnectionGetTableTypes; + driver->StatementNew = duckdb_adbc::StatementNew; + driver->StatementRelease = duckdb_adbc::StatementRelease; + driver->StatementBind = duckdb_adbc::StatementBind; + driver->StatementBindStream = duckdb_adbc::StatementBindStream; + driver->StatementExecuteQuery = duckdb_adbc::StatementExecuteQuery; + driver->StatementPrepare = duckdb_adbc::StatementPrepare; + driver->StatementSetOption = duckdb_adbc::StatementSetOption; + driver->StatementSetSqlQuery = duckdb_adbc::StatementSetSqlQuery; + driver->ConnectionGetObjects = duckdb_adbc::ConnectionGetObjects; + driver->ConnectionCommit = duckdb_adbc::ConnectionCommit; + driver->ConnectionRollback = duckdb_adbc::ConnectionRollback; + driver->ConnectionReadPartition = duckdb_adbc::ConnectionReadPartition; + driver->StatementExecutePartitions = duckdb_adbc::StatementExecutePartitions; + driver->ConnectionGetInfo = duckdb_adbc::ConnectionGetInfo; + driver->StatementGetParameterSchema = duckdb_adbc::StatementGetParameterSchema; + driver->ConnectionGetTableSchema = duckdb_adbc::ConnectionGetTableSchema; + driver->StatementSetSubstraitPlan = duckdb_adbc::StatementSetSubstraitPlan; + + driver->ConnectionGetInfo = duckdb_adbc::ConnectionGetInfo; + driver->StatementGetParameterSchema = duckdb_adbc::StatementGetParameterSchema; + return ADBC_STATUS_OK; +} + +namespace duckdb_adbc { + +enum class IngestionMode { CREATE = 0, APPEND = 1 }; +struct DuckDBAdbcStatementWrapper { + ::duckdb_connection connection; + ::duckdb_arrow result; + ::duckdb_prepared_statement statement; + char *ingestion_table_name; + ArrowArrayStream ingestion_stream; + IngestionMode ingestion_mode = IngestionMode::CREATE; +}; + +static AdbcStatusCode QueryInternal(struct AdbcConnection *connection, struct ArrowArrayStream *out, const char *query, + struct AdbcError *error) { + AdbcStatement statement; + + auto status = StatementNew(connection, &statement, error); + if (status != ADBC_STATUS_OK) { + SetError(error, "unable to initialize statement"); + return status; + } + status = StatementSetSqlQuery(&statement, query, error); + if (status != ADBC_STATUS_OK) { + SetError(error, "unable to initialize statement"); + return status; + } + status = StatementExecuteQuery(&statement, out, nullptr, error); + if (status != ADBC_STATUS_OK) { + SetError(error, "unable to initialize statement"); + return status; + } + + return ADBC_STATUS_OK; +} + +struct DuckDBAdbcDatabaseWrapper { + //! The DuckDB Database Configuration + ::duckdb_config config; + //! The DuckDB Database + ::duckdb_database database; + //! Path of Disk-Based Database or :memory: database + std::string path; +}; + +static void EmptyErrorRelease(AdbcError *error) { + // The object is valid but doesn't contain any data that needs to be cleaned up + // Just set the release to nullptr to indicate that it's no longer valid. + error->release = nullptr; + return; +} + +void InitializeADBCError(AdbcError *error) { + if (!error) { + return; + } + error->message = nullptr; + // Don't set to nullptr, as that indicates that it's invalid + error->release = EmptyErrorRelease; + std::memset(error->sqlstate, '\0', sizeof(error->sqlstate)); + error->vendor_code = -1; +} + +AdbcStatusCode CheckResult(duckdb_state &res, AdbcError *error, const char *error_msg) { + if (!error) { + // Error should be a non-null pointer + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (res != DuckDBSuccess) { + duckdb_adbc::SetError(error, error_msg); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + +AdbcStatusCode DatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { + if (!database) { + SetError(error, "Missing database object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + database->private_data = nullptr; + // you can't malloc a struct with a non-trivial C++ constructor + // and std::string has a non-trivial constructor. so we need + // to use new and delete rather than malloc and free. + auto wrapper = new (std::nothrow) DuckDBAdbcDatabaseWrapper; + if (!wrapper) { + SetError(error, "Allocation error"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + database->private_data = wrapper; + auto res = duckdb_create_config(&wrapper->config); + return CheckResult(res, error, "Failed to allocate"); +} + +AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, + struct AdbcError *error) { + if (!statement) { + SetError(error, "Statement is not set"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!plan) { + SetError(error, "Substrait Plan is not set"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (length == 0) { + SetError(error, "Can't execute plan with size = 0"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto wrapper = reinterpret_cast(statement->private_data); + auto plan_str = std::string(reinterpret_cast(plan), length); + auto query = "CALL from_substrait('" + plan_str + "'::BLOB)"; + auto res = duckdb_prepare(wrapper->connection, query.c_str(), &wrapper->statement); + auto error_msg = duckdb_prepare_error(wrapper->statement); + return CheckResult(res, error, error_msg); +} + +AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, + struct AdbcError *error) { + if (!database) { + SetError(error, "Missing database object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!key) { + SetError(error, "Missing key"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + if (strcmp(key, "path") == 0) { + wrapper->path = value; + return ADBC_STATUS_OK; + } + auto res = duckdb_set_config(wrapper->config, key, value); + + return CheckResult(res, error, "Failed to set configuration option"); +} + +AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { + if (!error) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!database) { + duckdb_adbc::SetError(error, "ADBC Database has an invalid pointer"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + char *errormsg; + // TODO can we set the database path via option, too? Does not look like it... + auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + auto res = duckdb_open_ext(wrapper->path.c_str(), &wrapper->database, wrapper->config, &errormsg); + return CheckResult(res, error, errormsg); +} + +AdbcStatusCode DatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { + + if (database && database->private_data) { + auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + + duckdb_close(&wrapper->database); + duckdb_destroy_config(&wrapper->config); + delete wrapper; + database->private_data = nullptr; + } + return ADBC_STATUS_OK; +} + +AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, const char *db_schema, + const char *table_name, struct ArrowSchema *schema, struct AdbcError *error) { + if (!connection) { + SetError(error, "Connection is not set"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (db_schema == nullptr) { + // if schema is not set, we use the default schema + db_schema = "main"; + } + if (catalog != nullptr && strlen(catalog) > 0) { + // In DuckDB this is the name of the database, not sure what's the expected functionality here, so for now, + // scream. + SetError(error, "Catalog Name is not used in DuckDB. It must be set to nullptr or an empty string"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else if (table_name == nullptr) { + SetError(error, "AdbcConnectionGetTableSchema: must provide table_name"); + return ADBC_STATUS_INVALID_ARGUMENT; + } else if (strlen(table_name) == 0) { + SetError(error, "AdbcConnectionGetTableSchema: must provide table_name"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + ArrowArrayStream arrow_stream; + + std::string query = "SELECT * FROM "; + if (strlen(db_schema) > 0) { + query += std::string(db_schema) + "."; + } + query += std::string(table_name) + " LIMIT 0;"; + + auto success = QueryInternal(connection, &arrow_stream, query.c_str(), error); + if (success != ADBC_STATUS_OK) { + return success; + } + arrow_stream.get_schema(&arrow_stream, schema); + arrow_stream.release(&arrow_stream); + return ADBC_STATUS_OK; +} + +AdbcStatusCode ConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { + if (!connection) { + SetError(error, "Missing connection object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + connection->private_data = nullptr; + return ADBC_STATUS_OK; +} + +AdbcStatusCode ExecuteQuery(duckdb::Connection *conn, const char *query, struct AdbcError *error) { + auto res = conn->Query(query); + if (res->HasError()) { + auto error_message = "Failed to execute query \"" + std::string(query) + "\": " + res->GetError(); + SetError(error, error_message); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + +AdbcStatusCode ConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, + struct AdbcError *error) { + if (!connection) { + SetError(error, "Connection is not set"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto conn = (duckdb::Connection *)connection->private_data; + if (strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { + if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { + if (conn->HasActiveTransaction()) { + AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error); + if (status != ADBC_STATUS_OK) { + return status; + } + } else { + // no-op + } + } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { + if (conn->HasActiveTransaction()) { + // no-op + } else { + // begin + AdbcStatusCode status = ExecuteQuery(conn, "START TRANSACTION", error); + if (status != ADBC_STATUS_OK) { + return status; + } + } + } else { + auto error_message = "Invalid connection option value " + std::string(key) + "=" + std::string(value); + SetError(error, error_message); + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ADBC_STATUS_OK; + } + auto error_message = + "Unknown connection option " + std::string(key) + "=" + (value ? std::string(value) : "(NULL)"); + SetError(error, error_message); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, + size_t serialized_length, struct ArrowArrayStream *out, + struct AdbcError *error) { + SetError(error, "Read Partitions are not supported in DuckDB"); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementExecutePartitions(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcPartitions *partitions, int64_t *rows_affected, + struct AdbcError *error) { + SetError(error, "Execute Partitions are not supported in DuckDB"); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { + if (!connection) { + SetError(error, "Connection is not set"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto conn = (duckdb::Connection *)connection->private_data; + if (!conn->HasActiveTransaction()) { + SetError(error, "No active transaction, cannot commit"); + return ADBC_STATUS_INVALID_STATE; + } + + AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error); + if (status != ADBC_STATUS_OK) { + return status; + } + return ExecuteQuery(conn, "START TRANSACTION", error); +} + +AdbcStatusCode ConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { + if (!connection) { + SetError(error, "Connection is not set"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto conn = (duckdb::Connection *)connection->private_data; + if (!conn->HasActiveTransaction()) { + SetError(error, "No active transaction, cannot rollback"); + return ADBC_STATUS_INVALID_STATE; + } + + AdbcStatusCode status = ExecuteQuery(conn, "ROLLBACK", error); + if (status != ADBC_STATUS_OK) { + return status; + } + return ExecuteQuery(conn, "START TRANSACTION", error); +} + +enum class AdbcInfoCode : uint32_t { + VENDOR_NAME, + VENDOR_VERSION, + DRIVER_NAME, + DRIVER_VERSION, + DRIVER_ARROW_VERSION, + UNRECOGNIZED // always the last entry of the enum +}; + +static AdbcInfoCode ConvertToInfoCode(uint32_t info_code) { + switch (info_code) { + case 0: + return AdbcInfoCode::VENDOR_NAME; + case 1: + return AdbcInfoCode::VENDOR_VERSION; + case 2: + return AdbcInfoCode::DRIVER_NAME; + case 3: + return AdbcInfoCode::DRIVER_VERSION; + case 4: + return AdbcInfoCode::DRIVER_ARROW_VERSION; + default: + return AdbcInfoCode::UNRECOGNIZED; + } +} + +AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, + struct ArrowArrayStream *out, struct AdbcError *error) { + if (!connection) { + SetError(error, "Missing connection object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_data) { + SetError(error, "Connection is invalid"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!out) { + SetError(error, "Output parameter was not provided"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + // If 'info_codes' is NULL, we should output all the info codes we recognize + size_t length = info_codes ? info_codes_length : (size_t)AdbcInfoCode::UNRECOGNIZED; + + duckdb::string q = R"EOF( + select + name::UINTEGER as info_name, + info::UNION( + string_value VARCHAR, + bool_value BOOL, + int64_value BIGINT, + int32_bitmask INTEGER, + string_list VARCHAR[], + int32_to_int32_list_map MAP(INTEGER, INTEGER[]) + ) as info_value from values + )EOF"; + + duckdb::string results = ""; + + for (size_t i = 0; i < length; i++) { + uint32_t code = info_codes ? info_codes[i] : i; + auto info_code = ConvertToInfoCode(code); + switch (info_code) { + case AdbcInfoCode::VENDOR_NAME: { + results += "(0, 'duckdb'),"; + break; + } + case AdbcInfoCode::VENDOR_VERSION: { + results += duckdb::StringUtil::Format("(1, '%s'),", duckdb_library_version()); + break; + } + case AdbcInfoCode::DRIVER_NAME: { + results += "(2, 'ADBC DuckDB Driver'),"; + break; + } + case AdbcInfoCode::DRIVER_VERSION: { + // TODO: fill in driver version + results += "(3, '(unknown)'),"; + break; + } + case AdbcInfoCode::DRIVER_ARROW_VERSION: { + // TODO: fill in arrow version + results += "(4, '(unknown)'),"; + break; + } + case AdbcInfoCode::UNRECOGNIZED: { + // Unrecognized codes are not an error, just ignored + continue; + } + default: { + // Codes that we have implemented but not handled here are a developer error + SetError(error, "Info code recognized but not handled"); + return ADBC_STATUS_INTERNAL; + } + } + } + if (results.empty()) { + // Add a group of values so the query parses + q += "(NULL, NULL)"; + } else { + q += results; + } + q += " tbl(name, info)"; + if (results.empty()) { + // Add an impossible where clause to return an empty result set + q += " where true = false"; + } + return QueryInternal(connection, out, q.c_str(), error); +} + +AdbcStatusCode ConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, + struct AdbcError *error) { + if (!database) { + SetError(error, "Missing database object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!database->private_data) { + SetError(error, "Invalid database"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection) { + SetError(error, "Missing connection object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto database_wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + + connection->private_data = nullptr; + auto res = duckdb_connect(database_wrapper->database, (duckdb_connection *)&connection->private_data); + return CheckResult(res, error, "Failed to connect to Database"); +} + +AdbcStatusCode ConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { + if (connection && connection->private_data) { + duckdb_disconnect((duckdb_connection *)&connection->private_data); + connection->private_data = nullptr; + } + return ADBC_STATUS_OK; +} + +// some stream callbacks + +static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { + if (!stream || !stream->private_data || !out) { + return DuckDBError; + } + return duckdb_query_arrow_schema((duckdb_arrow)stream->private_data, (duckdb_arrow_schema *)&out); +} + +static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { + if (!stream || !stream->private_data || !out) { + return DuckDBError; + } + out->release = nullptr; + + return duckdb_query_arrow_array((duckdb_arrow)stream->private_data, (duckdb_arrow_array *)&out); +} + +void release(struct ArrowArrayStream *stream) { + if (!stream || !stream->release) { + return; + } + if (stream->private_data) { + duckdb_destroy_arrow((duckdb_arrow *)&stream->private_data); + stream->private_data = nullptr; + } + stream->release = nullptr; +} + +const char *get_last_error(struct ArrowArrayStream *stream) { + if (!stream) { + return nullptr; + } + return nullptr; + // return duckdb_query_arrow_error(stream); +} + +// this is an evil hack, normally we would need a stream factory here, but its probably much easier if the adbc clients +// just hand over a stream + +duckdb::unique_ptr +stream_produce(uintptr_t factory_ptr, + std::pair, std::vector> &project_columns, + duckdb::TableFilterSet *filters) { + + // TODO this will ignore any projections or filters but since we don't expose the scan it should be sort of fine + auto res = duckdb::make_uniq(); + res->arrow_array_stream = *(ArrowArrayStream *)factory_ptr; + return res; +} + +void stream_schema(uintptr_t factory_ptr, duckdb::ArrowSchemaWrapper &schema) { + auto stream = (ArrowArrayStream *)factory_ptr; + get_schema(stream, &schema.arrow_schema); +} + +AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, struct ArrowArrayStream *input, + struct AdbcError *error, IngestionMode ingestion_mode) { + + if (!connection) { + SetError(error, "Missing connection object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!input) { + SetError(error, "Missing input arrow stream pointer"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!table_name) { + SetError(error, "Missing database object name"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto cconn = (duckdb::Connection *)connection; + + auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), + duckdb::Value::POINTER((uintptr_t)stream_produce), + duckdb::Value::POINTER((uintptr_t)input->get_schema)}); + try { + if (ingestion_mode == IngestionMode::CREATE) { + // We create the table based on an Arrow Scanner + arrow_scan->Create(table_name); + } else { + arrow_scan->CreateView("temp_adbc_view", true, true); + auto query = duckdb::StringUtil::Format("insert into \"%s\" select * from temp_adbc_view", table_name); + auto result = cconn->Query(query); + } + // After creating a table, the arrow array stream is released. Hence we must set it as released to avoid + // double-releasing it + input->release = nullptr; + } catch (std::exception &ex) { + if (error) { + error->message = strdup(ex.what()); + } + return ADBC_STATUS_INTERNAL; + } catch (...) { + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + +AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, + struct AdbcError *error) { + if (!connection) { + SetError(error, "Missing connection object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_data) { + SetError(error, "Invalid connection object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + statement->private_data = nullptr; + + auto statement_wrapper = (DuckDBAdbcStatementWrapper *)malloc(sizeof(DuckDBAdbcStatementWrapper)); + if (!statement_wrapper) { + SetError(error, "Allocation error"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + statement->private_data = statement_wrapper; + statement_wrapper->connection = (duckdb_connection)connection->private_data; + statement_wrapper->statement = nullptr; + statement_wrapper->result = nullptr; + statement_wrapper->ingestion_stream.release = nullptr; + statement_wrapper->ingestion_table_name = nullptr; + statement_wrapper->ingestion_mode = IngestionMode::CREATE; + return ADBC_STATUS_OK; +} + +AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { + if (!statement || !statement->private_data) { + return ADBC_STATUS_OK; + } + auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + if (wrapper->statement) { + duckdb_destroy_prepare(&wrapper->statement); + wrapper->statement = nullptr; + } + if (wrapper->result) { + duckdb_destroy_arrow(&wrapper->result); + wrapper->result = nullptr; + } + if (wrapper->ingestion_stream.release) { + wrapper->ingestion_stream.release(&wrapper->ingestion_stream); + wrapper->ingestion_stream.release = nullptr; + } + if (wrapper->ingestion_table_name) { + free(wrapper->ingestion_table_name); + wrapper->ingestion_table_name = nullptr; + } + free(statement->private_data); + statement->private_data = nullptr; + return ADBC_STATUS_OK; +} + +AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!schema) { + SetError(error, "Missing schema object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + // TODO: we might want to cache this, but then we need to return a deep copy anyways.., so I'm not sure if that + // would be worth the extra management + auto res = duckdb_prepared_arrow_schema(wrapper->statement, (duckdb_arrow_schema *)&schema); + if (res != DuckDBSuccess) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ADBC_STATUS_OK; +} + +AdbcStatusCode GetPreparedParameters(duckdb_connection connection, duckdb::unique_ptr &result, + ArrowArrayStream *input, AdbcError *error) { + + auto cconn = (duckdb::Connection *)connection; + + try { + auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), + duckdb::Value::POINTER((uintptr_t)stream_produce), + duckdb::Value::POINTER((uintptr_t)input->get_schema)}); + result = arrow_scan->Execute(); + // After creating a table, the arrow array stream is released. Hence we must set it as released to avoid + // double-releasing it + input->release = nullptr; + } catch (std::exception &ex) { + if (error) { + error->message = strdup(ex.what()); + } + return ADBC_STATUS_INTERNAL; + } catch (...) { + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + +static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *statement, AdbcError *error) { + // See ADBC_INGEST_OPTION_TARGET_TABLE + D_ASSERT(statement->ingestion_stream.release); + D_ASSERT(statement->ingestion_table_name); + + // Take the input stream from the statement + auto stream = statement->ingestion_stream; + statement->ingestion_stream.release = nullptr; + + // Ingest into a table from the bound stream + return Ingest(statement->connection, statement->ingestion_table_name, &stream, error, statement->ingestion_mode); +} + +AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, + int64_t *rows_affected, struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + + // TODO: Set affected rows, careful with early return + if (rows_affected) { + *rows_affected = 0; + } + + const auto has_stream = wrapper->ingestion_stream.release != nullptr; + const auto to_table = wrapper->ingestion_table_name != nullptr; + + if (has_stream && to_table) { + return IngestToTableFromBoundStream(wrapper, error); + } + + if (has_stream) { + // A stream was bound to the statement, use that to bind parameters + duckdb::unique_ptr result; + ArrowArrayStream stream = wrapper->ingestion_stream; + wrapper->ingestion_stream.release = nullptr; + auto adbc_res = GetPreparedParameters(wrapper->connection, result, &stream, error); + if (adbc_res != ADBC_STATUS_OK) { + return adbc_res; + } + if (!result) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + duckdb::unique_ptr chunk; + while ((chunk = result->Fetch()) != nullptr) { + if (chunk->size() == 0) { + SetError(error, "Please provide a non-empty chunk to be bound"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (chunk->size() != 1) { + // TODO: add support for binding multiple rows + SetError(error, "Binding multiple rows at once is not supported yet"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + duckdb_clear_bindings(wrapper->statement); + for (idx_t col_idx = 0; col_idx < chunk->ColumnCount(); col_idx++) { + auto val = chunk->GetValue(col_idx, 0); + auto duck_val = (duckdb_value)&val; + auto res = duckdb_bind_value(wrapper->statement, 1 + col_idx, duck_val); + if (res != DuckDBSuccess) { + SetError(error, duckdb_prepare_error(wrapper->statement)); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + + auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); + if (res != DuckDBSuccess) { + SetError(error, duckdb_query_arrow_error(wrapper->result)); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + } else { + auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); + if (res != DuckDBSuccess) { + SetError(error, duckdb_query_arrow_error(wrapper->result)); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + + if (out) { + out->private_data = wrapper->result; + out->get_schema = get_schema; + out->get_next = get_next; + out->release = release; + out->get_last_error = get_last_error; + + // because we handed out the stream pointer its no longer our responsibility to destroy it in + // AdbcStatementRelease, this is now done in release() + wrapper->result = nullptr; + } + + return ADBC_STATUS_OK; +} + +// this is a nop for us +AdbcStatusCode StatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ADBC_STATUS_OK; +} + +AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!query) { + SetError(error, "Missing query"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto res = duckdb_prepare(wrapper->connection, query, &wrapper->statement); + auto error_msg = duckdb_prepare_error(wrapper->statement); + return CheckResult(res, error, error_msg); +} + +AdbcStatusCode StatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schemas, + struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!values) { + SetError(error, "Missing values object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!schemas) { + SetError(error, "Invalid schemas object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + if (wrapper->ingestion_stream.release) { + // Free the stream that was previously bound + wrapper->ingestion_stream.release(&wrapper->ingestion_stream); + } + auto status = BatchToArrayStream(values, schemas, &wrapper->ingestion_stream, error); + return status; +} + +AdbcStatusCode StatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *values, + struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!values) { + SetError(error, "Missing values object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + if (wrapper->ingestion_stream.release) { + // Release any resources currently held by the ingestion stream before we overwrite it + wrapper->ingestion_stream.release(&wrapper->ingestion_stream); + } + wrapper->ingestion_stream = *values; + values->release = nullptr; + return ADBC_STATUS_OK; +} + +AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, + struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!key) { + SetError(error, "Missing key object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + + if (strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { + wrapper->ingestion_table_name = strdup(value); + return ADBC_STATUS_OK; + } + if (strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { + if (strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { + wrapper->ingestion_mode = IngestionMode::CREATE; + return ADBC_STATUS_OK; + } else if (strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { + wrapper->ingestion_mode = IngestionMode::APPEND; + return ADBC_STATUS_OK; + } else { + SetError(error, "Invalid ingestion mode"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + return ADBC_STATUS_INVALID_ARGUMENT; +} + +AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, + const char *db_schema, const char *table_name, const char **table_type, + const char *column_name, struct ArrowArrayStream *out, struct AdbcError *error) { + if (catalog != nullptr) { + if (strcmp(catalog, "duckdb") == 0) { + SetError(error, "catalog must be NULL or 'duckdb'"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + + if (table_type != nullptr) { + SetError(error, "Table types parameter not yet supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + std::string query; + switch (depth) { + case ADBC_OBJECT_DEPTH_CATALOGS: + SetError(error, "ADBC_OBJECT_DEPTH_CATALOGS not yet supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + case ADBC_OBJECT_DEPTH_DB_SCHEMAS: + // Return metadata on catalogs and schemas. + query = duckdb::StringUtil::Format(R"( + SELECT table_schema db_schema_name + FROM information_schema.columns + WHERE table_schema LIKE '%s' AND table_name LIKE '%s' AND column_name LIKE '%s' ; + )", + db_schema ? db_schema : "%", table_name ? table_name : "%", + column_name ? column_name : "%"); + break; + case ADBC_OBJECT_DEPTH_TABLES: + // Return metadata on catalogs, schemas, and tables. + query = duckdb::StringUtil::Format(R"( + SELECT table_schema db_schema_name, LIST(table_schema_list) db_schema_tables + FROM ( + SELECT table_schema, { table_name : table_name} table_schema_list + FROM information_schema.columns + WHERE table_schema LIKE '%s' AND table_name LIKE '%s' AND column_name LIKE '%s' GROUP BY table_schema, table_name + ) GROUP BY table_schema; + )", + db_schema ? db_schema : "%", table_name ? table_name : "%", + column_name ? column_name : "%"); + break; + case ADBC_OBJECT_DEPTH_COLUMNS: + // Return metadata on catalogs, schemas, tables, and columns. + query = duckdb::StringUtil::Format(R"( + SELECT table_schema db_schema_name, LIST(table_schema_list) db_schema_tables + FROM ( + SELECT table_schema, { table_name : table_name, table_columns : LIST({column_name : column_name, ordinal_position : ordinal_position + 1, remarks : ''})} table_schema_list + FROM information_schema.columns + WHERE table_schema LIKE '%s' AND table_name LIKE '%s' AND column_name LIKE '%s' GROUP BY table_schema, table_name + ) GROUP BY table_schema; + )", + db_schema ? db_schema : "%", table_name ? table_name : "%", + column_name ? column_name : "%"); + break; + default: + SetError(error, "Invalid value of Depth"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + return QueryInternal(connection, out, query.c_str(), error); +} + +AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *out, + struct AdbcError *error) { + const char *q = "SELECT DISTINCT table_type FROM information_schema.tables ORDER BY table_type"; + return QueryInternal(connection, out, q, error); +} + +} // namespace duckdb_adbc diff --git a/src/duckdb/src/common/adbc/driver_manager.cpp b/src/duckdb/src/common/adbc/driver_manager.cpp new file mode 100644 index 00000000..23ae9826 --- /dev/null +++ b/src/duckdb/src/common/adbc/driver_manager.cpp @@ -0,0 +1,790 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "duckdb/common/adbc/driver_manager.h" +#include "duckdb/common/adbc/adbc.h" +#include "duckdb/common/adbc/adbc.hpp" + +#include +#include +#include +#include +#include + +#if defined(_WIN32) +#include // Must come first + +#include +#include +#else +#include +#endif // defined(_WIN32) + +namespace duckdb_adbc { + +// Platform-specific helpers + +#if defined(_WIN32) +/// Append a description of the Windows error to the buffer. +void GetWinError(std::string *buffer) { + DWORD rc = GetLastError(); + LPVOID message; + + FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr); + + (*buffer) += '('; + (*buffer) += std::to_string(rc); + (*buffer) += ") "; + (*buffer) += reinterpret_cast(message); + LocalFree(message); +} + +#endif // defined(_WIN32) + +// Temporary state while the database is being configured. +struct TempDatabase { + std::unordered_map options; + std::string driver; + // Default name (see adbc.h) + std::string entrypoint = "AdbcDriverInit"; + AdbcDriverInitFunc init_func = nullptr; +}; + +// Error handling + +void ReleaseError(struct AdbcError *error) { + if (error) { + if (error->message) { + delete[] error->message; + } + error->message = nullptr; + error->release = nullptr; + } +} + +void SetError(struct AdbcError *error, const std::string &message) { + if (!error) { + return; + } + if (error->message) { + // Append + std::string buffer = error->message; + buffer.reserve(buffer.size() + message.size() + 1); + buffer += '\n'; + buffer += message; + error->release(error); + + error->message = new char[buffer.size() + 1]; + buffer.copy(error->message, buffer.size()); + error->message[buffer.size()] = '\0'; + } else { + error->message = new char[message.size() + 1]; + message.copy(error->message, message.size()); + error->message[message.size()] = '\0'; + } + error->release = ReleaseError; +} + +void SetError(struct AdbcError *error, const char *message_p) { + if (!message_p) { + message_p = ""; + } + std::string message(message_p); + SetError(error, message); +} + +// Driver state + +/// Hold the driver DLL and the driver release callback in the driver struct. +struct ManagerDriverState { + // The original release callback + AdbcStatusCode (*driver_release)(struct AdbcDriver *driver, struct AdbcError *error); + +#if defined(_WIN32) + // The loaded DLL + HMODULE handle; +#endif // defined(_WIN32) +}; + +/// Unload the driver DLL. +static AdbcStatusCode ReleaseDriver(struct AdbcDriver *driver, struct AdbcError *error) { + AdbcStatusCode status = ADBC_STATUS_OK; + + if (!driver->private_manager) { + return status; + } + ManagerDriverState *state = reinterpret_cast(driver->private_manager); + + if (state->driver_release) { + status = state->driver_release(driver, error); + } + +#if defined(_WIN32) + // TODO(apache/arrow-adbc#204): causes tests to segfault + // if (!FreeLibrary(state->handle)) { + // std::string message = "FreeLibrary() failed: "; + // GetWinError(&message); + // SetError(error, message); + // } +#endif // defined(_WIN32) + + driver->private_manager = nullptr; + delete state; + return status; +} + +/// Temporary state while the database is being configured. +struct TempConnection { + std::unordered_map options; +}; + +// Direct implementations of API methods + +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { + // Allocate a temporary structure to store options pre-Init + database->private_data = new TempDatabase(); + database->private_driver = nullptr; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, + struct AdbcError *error) { + if (!database) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (database->private_driver) { + return database->private_driver->DatabaseSetOption(database, key, value, error); + } + + TempDatabase *args = reinterpret_cast(database->private_data); + if (std::strcmp(key, "driver") == 0) { + args->driver = value; + } else if (std::strcmp(key, "entrypoint") == 0) { + args->entrypoint = value; + } else { + args->options[key] = value; + } + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, + struct AdbcError *error) { + if (!database) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (database->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + + TempDatabase *args = reinterpret_cast(database->private_data); + args->init_func = init_func; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { + if (!database->private_data) { + SetError(error, "Must call AdbcDatabaseNew first"); + return ADBC_STATUS_INVALID_STATE; + } + TempDatabase *args = reinterpret_cast(database->private_data); + if (args->init_func) { + // Do nothing + } else if (args->driver.empty()) { + SetError(error, "Must provide 'driver' parameter"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + database->private_driver = new AdbcDriver; + std::memset(database->private_driver, 0, sizeof(AdbcDriver)); + AdbcStatusCode status; + // So we don't confuse a driver into thinking it's initialized already + database->private_data = nullptr; + if (args->init_func) { + status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, database->private_driver, error); + } else { + status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), ADBC_VERSION_1_0_0, + database->private_driver, error); + } + if (status != ADBC_STATUS_OK) { + // Restore private_data so it will be released by AdbcDatabaseRelease + database->private_data = args; + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); + } + delete database->private_driver; + database->private_driver = nullptr; + return status; + } + status = database->private_driver->DatabaseNew(database, error); + if (status != ADBC_STATUS_OK) { + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); + } + delete database->private_driver; + database->private_driver = nullptr; + return status; + } + for (const auto &option : args->options) { + status = + database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); + if (status != ADBC_STATUS_OK) { + delete args; + // Release the database + std::ignore = database->private_driver->DatabaseRelease(database, error); + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); + } + delete database->private_driver; + database->private_driver = nullptr; + // Should be redundant, but ensure that AdbcDatabaseRelease + // below doesn't think that it contains a TempDatabase + database->private_data = nullptr; + return status; + } + } + delete args; + return database->private_driver->DatabaseInit(database, error); +} + +AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { + if (!database->private_driver) { + if (database->private_data) { + TempDatabase *args = reinterpret_cast(database->private_data); + delete args; + database->private_data = nullptr; + return ADBC_STATUS_OK; + } + return ADBC_STATUS_INVALID_STATE; + } + auto status = database->private_driver->DatabaseRelease(database, error); + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); + } + delete database->private_driver; + database->private_data = nullptr; + database->private_driver = nullptr; + return status; +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionCommit(connection, error); +} + +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, + struct ArrowArrayStream *out, struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionGetInfo(connection, info_codes, info_codes_length, out, error); +} + +AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, + const char *db_schema, const char *table_name, const char **table_types, + const char *column_name, struct ArrowArrayStream *stream, + struct AdbcError *error) { + if (!connection) { + SetError(error, "connection can't be null"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_data) { + SetError(error, "connection must be initialized"); + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, + table_types, column_name, stream, error); +} + +AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, + const char *db_schema, const char *table_name, struct ArrowSchema *schema, + struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema, + error); +} + +AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *stream, + struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); +} + +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, + struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_data) { + SetError(error, "Must call AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } else if (!database->private_driver) { + SetError(error, "Database is not initialized"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + TempConnection *args = reinterpret_cast(connection->private_data); + connection->private_data = nullptr; + std::unordered_map options = std::move(args->options); + delete args; + + auto status = database->private_driver->ConnectionNew(connection, error); + if (status != ADBC_STATUS_OK) { + return status; + } + connection->private_driver = database->private_driver; + + for (const auto &option : options) { + status = database->private_driver->ConnectionSetOption(connection, option.first.c_str(), option.second.c_str(), + error); + if (status != ADBC_STATUS_OK) { + return status; + } + } + return connection->private_driver->ConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { + // Allocate a temporary structure to store options pre-Init, because + // we don't get access to the database (and hence the driver + // function table) until then + connection->private_data = new TempConnection; + connection->private_driver = nullptr; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, + size_t serialized_length, struct ArrowArrayStream *out, + struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionReadPartition(connection, serialized_partition, serialized_length, out, + error); +} + +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + if (connection->private_data) { + TempConnection *args = reinterpret_cast(connection->private_data); + delete args; + connection->private_data = nullptr; + return ADBC_STATUS_OK; + } + return ADBC_STATUS_INVALID_STATE; + } + auto status = connection->private_driver->ConnectionRelease(connection, error); + connection->private_driver = nullptr; + return status; +} + +AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionRollback(connection, error); +} + +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, + struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection *args = reinterpret_cast(connection->private_data); + args->options[key] = value; + return ADBC_STATUS_OK; + } + return connection->private_driver->ConnectionSetOption(connection, key, value, error); +} + +AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, + struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementBind(statement, values, schema, error); +} + +AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, + struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementBindStream(statement, stream, error); +} + +// XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, ArrowSchema *schema, + struct AdbcPartitions *partitions, int64_t *rows_affected, + struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementExecutePartitions(statement, schema, partitions, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, + int64_t *rows_affected, struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementGetParameterSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, + struct AdbcError *error) { + if (!connection) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + auto status = connection->private_driver->StatementNew(connection, statement, error); + statement->private_driver = connection->private_driver; + return status; +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + auto status = statement->private_driver->StatementRelease(statement, error); + statement->private_driver = nullptr; + return status; +} + +AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, + struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementSetOption(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, + struct AdbcError *error) { + if (!statement) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); +} + +const char *AdbcStatusCodeMessage(AdbcStatusCode code) { +#define STRINGIFY(s) #s +#define STRINGIFY_VALUE(s) STRINGIFY(s) +#define CASE(CONSTANT) \ + case CONSTANT: \ + return #CONSTANT " (" STRINGIFY_VALUE(CONSTANT) ")"; + + switch (code) { + CASE(ADBC_STATUS_OK); + CASE(ADBC_STATUS_UNKNOWN); + CASE(ADBC_STATUS_NOT_IMPLEMENTED); + CASE(ADBC_STATUS_NOT_FOUND); + CASE(ADBC_STATUS_ALREADY_EXISTS); + CASE(ADBC_STATUS_INVALID_ARGUMENT); + CASE(ADBC_STATUS_INVALID_STATE); + CASE(ADBC_STATUS_INVALID_DATA); + CASE(ADBC_STATUS_INTEGRITY); + CASE(ADBC_STATUS_INTERNAL); + CASE(ADBC_STATUS_IO); + CASE(ADBC_STATUS_CANCELLED); + CASE(ADBC_STATUS_TIMEOUT); + CASE(ADBC_STATUS_UNAUTHENTICATED); + CASE(ADBC_STATUS_UNAUTHORIZED); + default: + return "(invalid code)"; + } +#undef CASE +#undef STRINGIFY_VALUE +#undef STRINGIFY +} + +AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *raw_driver, + struct AdbcError *error) { + AdbcDriverInitFunc init_func; + std::string error_message; + + if (version != ADBC_VERSION_1_0_0) { + SetError(error, "Only ADBC 1.0.0 is supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + auto *driver = reinterpret_cast(raw_driver); + + if (!entrypoint) { + // Default entrypoint (see adbc.h) + entrypoint = "AdbcDriverInit"; + } + +#if defined(_WIN32) + + HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); + if (!handle) { + error_message += driver_name; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + + std::string full_driver_name = driver_name; + full_driver_name += ".lib"; + handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); + if (!handle) { + error_message += '\n'; + error_message += full_driver_name; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + } + } + if (!handle) { + SetError(error, error_message); + return ADBC_STATUS_INTERNAL; + } + + void *load_handle = reinterpret_cast(GetProcAddress(handle, entrypoint)); + init_func = reinterpret_cast(load_handle); + if (!init_func) { + std::string message = "GetProcAddress("; + message += entrypoint; + message += ") failed: "; + GetWinError(&message); + if (!FreeLibrary(handle)) { + message += "\nFreeLibrary() failed: "; + GetWinError(&message); + } + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } + +#else + +#if defined(__APPLE__) + const std::string kPlatformLibraryPrefix = "lib"; + const std::string kPlatformLibrarySuffix = ".dylib"; +#else + const std::string kPlatformLibraryPrefix = "lib"; + const std::string kPlatformLibrarySuffix = ".so"; +#endif // defined(__APPLE__) + + void *handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message = "dlopen() failed: "; + error_message += dlerror(); + + // If applicable, append the shared library prefix/extension and + // try again (this way you don't have to hardcode driver names by + // platform in the application) + const std::string driver_str = driver_name; + + std::string full_driver_name; + if (driver_str.size() < kPlatformLibraryPrefix.size() || + driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != 0) { + full_driver_name += kPlatformLibraryPrefix; + } + full_driver_name += driver_name; + if (driver_str.size() < kPlatformLibrarySuffix.size() || + driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix.size(), + kPlatformLibrarySuffix) != 0) { + full_driver_name += kPlatformLibrarySuffix; + } + handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message += "\ndlopen() failed: "; + error_message += dlerror(); + } + } + if (!handle) { + SetError(error, error_message); + // AdbcDatabaseInit tries to call this if set + driver->release = nullptr; + return ADBC_STATUS_INTERNAL; + } + + void *load_handle = dlsym(handle, entrypoint); + if (!load_handle) { + std::string message = "dlsym("; + message += entrypoint; + message += ") failed: "; + message += dlerror(); + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } + init_func = reinterpret_cast(load_handle); + +#endif // defined(_WIN32) + + AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + if (status == ADBC_STATUS_OK) { + ManagerDriverState *state = new ManagerDriverState; + state->driver_release = driver->release; +#if defined(_WIN32) + state->handle = handle; +#endif // defined(_WIN32) + driver->release = &ReleaseDriver; + driver->private_manager = state; + } else { +#if defined(_WIN32) + if (!FreeLibrary(handle)) { + std::string message = "FreeLibrary() failed: "; + GetWinError(&message); + SetError(error, message); + } +#endif // defined(_WIN32) + } + return status; +} + +AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *raw_driver, + struct AdbcError *error) { +#define FILL_DEFAULT(DRIVER, STUB) \ + if (!DRIVER->STUB) { \ + DRIVER->STUB = &STUB; \ + } +#define CHECK_REQUIRED(DRIVER, STUB) \ + if (!DRIVER->STUB) { \ + SetError(error, "Driver does not implement required function Adbc" #STUB); \ + return ADBC_STATUS_INTERNAL; \ + } + + auto result = init_func(version, raw_driver, error); + if (result != ADBC_STATUS_OK) { + return result; + } + + if (version == ADBC_VERSION_1_0_0) { + auto *driver = reinterpret_cast(raw_driver); + CHECK_REQUIRED(driver, DatabaseNew); + CHECK_REQUIRED(driver, DatabaseInit); + CHECK_REQUIRED(driver, DatabaseRelease); + FILL_DEFAULT(driver, DatabaseSetOption); + + CHECK_REQUIRED(driver, ConnectionNew); + CHECK_REQUIRED(driver, ConnectionInit); + CHECK_REQUIRED(driver, ConnectionRelease); + FILL_DEFAULT(driver, ConnectionCommit); + FILL_DEFAULT(driver, ConnectionGetInfo); + FILL_DEFAULT(driver, ConnectionGetObjects); + FILL_DEFAULT(driver, ConnectionGetTableSchema); + FILL_DEFAULT(driver, ConnectionGetTableTypes); + FILL_DEFAULT(driver, ConnectionReadPartition); + FILL_DEFAULT(driver, ConnectionRollback); + FILL_DEFAULT(driver, ConnectionSetOption); + + FILL_DEFAULT(driver, StatementExecutePartitions); + CHECK_REQUIRED(driver, StatementExecuteQuery); + CHECK_REQUIRED(driver, StatementNew); + CHECK_REQUIRED(driver, StatementRelease); + FILL_DEFAULT(driver, StatementBind); + FILL_DEFAULT(driver, StatementGetParameterSchema); + FILL_DEFAULT(driver, StatementPrepare); + FILL_DEFAULT(driver, StatementSetOption); + FILL_DEFAULT(driver, StatementSetSqlQuery); + FILL_DEFAULT(driver, StatementSetSubstraitPlan); + } + + return ADBC_STATUS_OK; + +#undef FILL_DEFAULT +#undef CHECK_REQUIRED +} +} // namespace duckdb_adbc diff --git a/src/duckdb/src/common/adbc/nanoarrow/allocator.cpp b/src/duckdb/src/common/adbc/nanoarrow/allocator.cpp new file mode 100644 index 00000000..692cb58a --- /dev/null +++ b/src/duckdb/src/common/adbc/nanoarrow/allocator.cpp @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" + +namespace duckdb_nanoarrow { + +void *ArrowMalloc(int64_t size) { + return malloc(size); +} + +void *ArrowRealloc(void *ptr, int64_t size) { + return realloc(ptr, size); +} + +void ArrowFree(void *ptr) { + free(ptr); +} + +static uint8_t *ArrowBufferAllocatorMallocAllocate(struct ArrowBufferAllocator *allocator, int64_t size) { + return (uint8_t *)ArrowMalloc(size); +} + +static uint8_t *ArrowBufferAllocatorMallocReallocate(struct ArrowBufferAllocator *allocator, uint8_t *ptr, + int64_t old_size, int64_t new_size) { + return (uint8_t *)ArrowRealloc(ptr, new_size); +} + +static void ArrowBufferAllocatorMallocFree(struct ArrowBufferAllocator *allocator, uint8_t *ptr, int64_t size) { + ArrowFree(ptr); +} + +static struct ArrowBufferAllocator ArrowBufferAllocatorMalloc = { + &ArrowBufferAllocatorMallocAllocate, &ArrowBufferAllocatorMallocReallocate, &ArrowBufferAllocatorMallocFree, NULL}; + +struct ArrowBufferAllocator *ArrowBufferAllocatorDefault() { + return &ArrowBufferAllocatorMalloc; +} + +} // namespace duckdb_nanoarrow diff --git a/src/duckdb/src/common/adbc/nanoarrow/metadata.cpp b/src/duckdb/src/common/adbc/nanoarrow/metadata.cpp new file mode 100644 index 00000000..742bbe41 --- /dev/null +++ b/src/duckdb/src/common/adbc/nanoarrow/metadata.cpp @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" + +namespace duckdb_nanoarrow { + +ArrowErrorCode ArrowMetadataReaderInit(struct ArrowMetadataReader *reader, const char *metadata) { + reader->metadata = metadata; + + if (reader->metadata == NULL) { + reader->offset = 0; + reader->remaining_keys = 0; + } else { + memcpy(&reader->remaining_keys, reader->metadata, sizeof(int32_t)); + reader->offset = sizeof(int32_t); + } + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowMetadataReaderRead(struct ArrowMetadataReader *reader, struct ArrowStringView *key_out, + struct ArrowStringView *value_out) { + if (reader->remaining_keys <= 0) { + return EINVAL; + } + + int64_t pos = 0; + + int32_t key_size; + memcpy(&key_size, reader->metadata + reader->offset + pos, sizeof(int32_t)); + pos += sizeof(int32_t); + + key_out->data = reader->metadata + reader->offset + pos; + key_out->n_bytes = key_size; + pos += key_size; + + int32_t value_size; + memcpy(&value_size, reader->metadata + reader->offset + pos, sizeof(int32_t)); + pos += sizeof(int32_t); + + value_out->data = reader->metadata + reader->offset + pos; + value_out->n_bytes = value_size; + pos += value_size; + + reader->offset += pos; + reader->remaining_keys--; + return NANOARROW_OK; +} + +int64_t ArrowMetadataSizeOf(const char *metadata) { + if (metadata == NULL) { + return 0; + } + + struct ArrowMetadataReader reader; + struct ArrowStringView key; + struct ArrowStringView value; + ArrowMetadataReaderInit(&reader, metadata); + + int64_t size = sizeof(int32_t); + while (ArrowMetadataReaderRead(&reader, &key, &value) == NANOARROW_OK) { + size += sizeof(int32_t) + key.n_bytes + sizeof(int32_t) + value.n_bytes; + } + + return size; +} + +ArrowErrorCode ArrowMetadataGetValue(const char *metadata, const char *key, const char *default_value, + struct ArrowStringView *value_out) { + struct ArrowStringView target_key_view = {key, static_cast(strlen(key))}; + value_out->data = default_value; + if (default_value != NULL) { + value_out->n_bytes = strlen(default_value); + } else { + value_out->n_bytes = 0; + } + + struct ArrowMetadataReader reader; + struct ArrowStringView key_view; + struct ArrowStringView value; + ArrowMetadataReaderInit(&reader, metadata); + + while (ArrowMetadataReaderRead(&reader, &key_view, &value) == NANOARROW_OK) { + int key_equal = target_key_view.n_bytes == key_view.n_bytes && + strncmp(target_key_view.data, key_view.data, key_view.n_bytes) == 0; + if (key_equal) { + value_out->data = value.data; + value_out->n_bytes = value.n_bytes; + break; + } + } + + return NANOARROW_OK; +} + +char ArrowMetadataHasKey(const char *metadata, const char *key) { + struct ArrowStringView value; + ArrowMetadataGetValue(metadata, key, NULL, &value); + return value.data != NULL; +} + +} // namespace duckdb_nanoarrow diff --git a/src/duckdb/src/common/adbc/nanoarrow/schema.cpp b/src/duckdb/src/common/adbc/nanoarrow/schema.cpp new file mode 100644 index 00000000..1ed36f1f --- /dev/null +++ b/src/duckdb/src/common/adbc/nanoarrow/schema.cpp @@ -0,0 +1,474 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" + +namespace duckdb_nanoarrow { + +void ArrowSchemaRelease(struct ArrowSchema *schema) { + if (schema->format != NULL) + ArrowFree((void *)schema->format); + if (schema->name != NULL) + ArrowFree((void *)schema->name); + if (schema->metadata != NULL) + ArrowFree((void *)schema->metadata); + + // This object owns the memory for all the children, but those + // children may have been generated elsewhere and might have + // their own release() callback. + if (schema->children != NULL) { + for (int64_t i = 0; i < schema->n_children; i++) { + if (schema->children[i] != NULL) { + if (schema->children[i]->release != NULL) { + schema->children[i]->release(schema->children[i]); + } + + ArrowFree(schema->children[i]); + } + } + + ArrowFree(schema->children); + } + + // This object owns the memory for the dictionary but it + // may have been generated somewhere else and have its own + // release() callback. + if (schema->dictionary != NULL) { + if (schema->dictionary->release != NULL) { + schema->dictionary->release(schema->dictionary); + } + + ArrowFree(schema->dictionary); + } + + // private data not currently used + if (schema->private_data != NULL) { + ArrowFree(schema->private_data); + } + + schema->release = NULL; +} + +const char *ArrowSchemaFormatTemplate(enum ArrowType data_type) { + switch (data_type) { + case NANOARROW_TYPE_UNINITIALIZED: + return NULL; + case NANOARROW_TYPE_NA: + return "n"; + case NANOARROW_TYPE_BOOL: + return "b"; + + case NANOARROW_TYPE_UINT8: + return "C"; + case NANOARROW_TYPE_INT8: + return "c"; + case NANOARROW_TYPE_UINT16: + return "S"; + case NANOARROW_TYPE_INT16: + return "s"; + case NANOARROW_TYPE_UINT32: + return "I"; + case NANOARROW_TYPE_INT32: + return "i"; + case NANOARROW_TYPE_UINT64: + return "L"; + case NANOARROW_TYPE_INT64: + return "l"; + + case NANOARROW_TYPE_HALF_FLOAT: + return "e"; + case NANOARROW_TYPE_FLOAT: + return "f"; + case NANOARROW_TYPE_DOUBLE: + return "g"; + + case NANOARROW_TYPE_STRING: + return "u"; + case NANOARROW_TYPE_LARGE_STRING: + return "U"; + case NANOARROW_TYPE_BINARY: + return "z"; + case NANOARROW_TYPE_LARGE_BINARY: + return "Z"; + + case NANOARROW_TYPE_DATE32: + return "tdD"; + case NANOARROW_TYPE_DATE64: + return "tdm"; + case NANOARROW_TYPE_INTERVAL_MONTHS: + return "tiM"; + case NANOARROW_TYPE_INTERVAL_DAY_TIME: + return "tiD"; + case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: + return "tin"; + + case NANOARROW_TYPE_LIST: + return "+l"; + case NANOARROW_TYPE_LARGE_LIST: + return "+L"; + case NANOARROW_TYPE_STRUCT: + return "+s"; + case NANOARROW_TYPE_MAP: + return "+m"; + + default: + return NULL; + } +} + +ArrowErrorCode ArrowSchemaInit(struct ArrowSchema *schema, enum ArrowType data_type) { + schema->format = NULL; + schema->name = NULL; + schema->metadata = NULL; + schema->flags = ARROW_FLAG_NULLABLE; + schema->n_children = 0; + schema->children = NULL; + schema->dictionary = NULL; + schema->private_data = NULL; + schema->release = &ArrowSchemaRelease; + + // We don't allocate the dictionary because it has to be nullptr + // for non-dictionary-encoded arrays. + + // Set the format to a valid format string for data_type + const char *template_format = ArrowSchemaFormatTemplate(data_type); + + // If data_type isn't recognized and not explicitly unset + if (template_format == NULL && data_type != NANOARROW_TYPE_UNINITIALIZED) { + schema->release(schema); + return EINVAL; + } + + int result = ArrowSchemaSetFormat(schema, template_format); + if (result != NANOARROW_OK) { + schema->release(schema); + return result; + } + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowSchemaInitFixedSize(struct ArrowSchema *schema, enum ArrowType data_type, int32_t fixed_size) { + int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); + if (result != NANOARROW_OK) { + return result; + } + + if (fixed_size <= 0) { + schema->release(schema); + return EINVAL; + } + + char buffer[64]; + int n_chars; + switch (data_type) { + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + n_chars = snprintf(buffer, sizeof(buffer), "w:%d", (int)fixed_size); + break; + case NANOARROW_TYPE_FIXED_SIZE_LIST: + n_chars = snprintf(buffer, sizeof(buffer), "+w:%d", (int)fixed_size); + break; + default: + schema->release(schema); + return EINVAL; + } + + buffer[n_chars] = '\0'; + result = ArrowSchemaSetFormat(schema, buffer); + if (result != NANOARROW_OK) { + schema->release(schema); + } + + return result; +} + +ArrowErrorCode ArrowSchemaInitDecimal(struct ArrowSchema *schema, enum ArrowType data_type, int32_t decimal_precision, + int32_t decimal_scale) { + int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); + if (result != NANOARROW_OK) { + return result; + } + + if (decimal_precision <= 0) { + schema->release(schema); + return EINVAL; + } + + char buffer[64]; + int n_chars; + switch (data_type) { + case NANOARROW_TYPE_DECIMAL128: + n_chars = snprintf(buffer, sizeof(buffer), "d:%d,%d", decimal_precision, decimal_scale); + break; + case NANOARROW_TYPE_DECIMAL256: + n_chars = snprintf(buffer, sizeof(buffer), "d:%d,%d,256", decimal_precision, decimal_scale); + break; + default: + schema->release(schema); + return EINVAL; + } + + buffer[n_chars] = '\0'; + + result = ArrowSchemaSetFormat(schema, buffer); + if (result != NANOARROW_OK) { + schema->release(schema); + return result; + } + + return NANOARROW_OK; +} + +static const char *ArrowTimeUnitString(enum ArrowTimeUnit time_unit) { + switch (time_unit) { + case NANOARROW_TIME_UNIT_SECOND: + return "s"; + case NANOARROW_TIME_UNIT_MILLI: + return "m"; + case NANOARROW_TIME_UNIT_MICRO: + return "u"; + case NANOARROW_TIME_UNIT_NANO: + return "n"; + default: + return NULL; + } +} + +ArrowErrorCode ArrowSchemaInitDateTime(struct ArrowSchema *schema, enum ArrowType data_type, + enum ArrowTimeUnit time_unit, const char *timezone) { + int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); + if (result != NANOARROW_OK) { + return result; + } + + const char *time_unit_str = ArrowTimeUnitString(time_unit); + if (time_unit_str == NULL) { + schema->release(schema); + return EINVAL; + } + + char buffer[128]; + int n_chars; + switch (data_type) { + case NANOARROW_TYPE_TIME32: + case NANOARROW_TYPE_TIME64: + if (timezone != NULL) { + schema->release(schema); + return EINVAL; + } + n_chars = snprintf(buffer, sizeof(buffer), "tt%s", time_unit_str); + break; + case NANOARROW_TYPE_TIMESTAMP: + if (timezone == NULL) { + timezone = ""; + } + n_chars = snprintf(buffer, sizeof(buffer), "ts%s:%s", time_unit_str, timezone); + break; + case NANOARROW_TYPE_DURATION: + if (timezone != NULL) { + schema->release(schema); + return EINVAL; + } + n_chars = snprintf(buffer, sizeof(buffer), "tD%s", time_unit_str); + break; + default: + schema->release(schema); + return EINVAL; + } + + if (static_cast(n_chars) >= sizeof(buffer)) { + schema->release(schema); + return ERANGE; + } + + buffer[n_chars] = '\0'; + + result = ArrowSchemaSetFormat(schema, buffer); + if (result != NANOARROW_OK) { + schema->release(schema); + return result; + } + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowSchemaSetFormat(struct ArrowSchema *schema, const char *format) { + if (schema->format != NULL) { + ArrowFree((void *)schema->format); + } + + if (format != NULL) { + size_t format_size = strlen(format) + 1; + schema->format = (const char *)ArrowMalloc(format_size); + if (schema->format == NULL) { + return ENOMEM; + } + + memcpy((void *)schema->format, format, format_size); + } else { + schema->format = NULL; + } + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowSchemaSetName(struct ArrowSchema *schema, const char *name) { + if (schema->name != NULL) { + ArrowFree((void *)schema->name); + } + + if (name != NULL) { + size_t name_size = strlen(name) + 1; + schema->name = (const char *)ArrowMalloc(name_size); + if (schema->name == NULL) { + return ENOMEM; + } + + memcpy((void *)schema->name, name, name_size); + } else { + schema->name = NULL; + } + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowSchemaSetMetadata(struct ArrowSchema *schema, const char *metadata) { + if (schema->metadata != NULL) { + ArrowFree((void *)schema->metadata); + } + + if (metadata != NULL) { + size_t metadata_size = ArrowMetadataSizeOf(metadata); + schema->metadata = (const char *)ArrowMalloc(metadata_size); + if (schema->metadata == NULL) { + return ENOMEM; + } + + memcpy((void *)schema->metadata, metadata, metadata_size); + } else { + schema->metadata = NULL; + } + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowSchemaAllocateChildren(struct ArrowSchema *schema, int64_t n_children) { + if (schema->children != NULL) { + return EEXIST; + } + + if (n_children > 0) { + schema->children = (struct ArrowSchema **)ArrowMalloc(n_children * sizeof(struct ArrowSchema *)); + + if (schema->children == NULL) { + return ENOMEM; + } + + schema->n_children = n_children; + + memset(schema->children, 0, n_children * sizeof(struct ArrowSchema *)); + + for (int64_t i = 0; i < n_children; i++) { + schema->children[i] = (struct ArrowSchema *)ArrowMalloc(sizeof(struct ArrowSchema)); + + if (schema->children[i] == NULL) { + return ENOMEM; + } + + schema->children[i]->release = NULL; + } + } + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowSchemaAllocateDictionary(struct ArrowSchema *schema) { + if (schema->dictionary != NULL) { + return EEXIST; + } + + schema->dictionary = (struct ArrowSchema *)ArrowMalloc(sizeof(struct ArrowSchema)); + if (schema->dictionary == NULL) { + return ENOMEM; + } + + schema->dictionary->release = NULL; + return NANOARROW_OK; +} + +int ArrowSchemaDeepCopy(struct ArrowSchema *schema, struct ArrowSchema *schema_out) { + int result; + result = ArrowSchemaInit(schema_out, NANOARROW_TYPE_NA); + if (result != NANOARROW_OK) { + return result; + } + + result = ArrowSchemaSetFormat(schema_out, schema->format); + if (result != NANOARROW_OK) { + schema_out->release(schema_out); + return result; + } + + result = ArrowSchemaSetName(schema_out, schema->name); + if (result != NANOARROW_OK) { + schema_out->release(schema_out); + return result; + } + + result = ArrowSchemaSetMetadata(schema_out, schema->metadata); + if (result != NANOARROW_OK) { + schema_out->release(schema_out); + return result; + } + + result = ArrowSchemaAllocateChildren(schema_out, schema->n_children); + if (result != NANOARROW_OK) { + schema_out->release(schema_out); + return result; + } + + for (int64_t i = 0; i < schema->n_children; i++) { + result = ArrowSchemaDeepCopy(schema->children[i], schema_out->children[i]); + if (result != NANOARROW_OK) { + schema_out->release(schema_out); + return result; + } + } + + if (schema->dictionary != NULL) { + result = ArrowSchemaAllocateDictionary(schema_out); + if (result != NANOARROW_OK) { + schema_out->release(schema_out); + return result; + } + + result = ArrowSchemaDeepCopy(schema->dictionary, schema_out->dictionary); + if (result != NANOARROW_OK) { + schema_out->release(schema_out); + return result; + } + } + + return NANOARROW_OK; +} + +} // namespace duckdb_nanoarrow diff --git a/src/duckdb/src/common/adbc/nanoarrow/single_batch_array_stream.cpp b/src/duckdb/src/common/adbc/nanoarrow/single_batch_array_stream.cpp new file mode 100644 index 00000000..bddcd4e0 --- /dev/null +++ b/src/duckdb/src/common/adbc/nanoarrow/single_batch_array_stream.cpp @@ -0,0 +1,84 @@ +#include "duckdb/common/adbc/single_batch_array_stream.hpp" +#include "duckdb/common/arrow/nanoarrow/nanoarrow.h" +#include "duckdb/common/adbc/adbc.hpp" + +#include "duckdb.h" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" + +#include +#include +#include +#include +#include + +namespace duckdb_adbc { + +using duckdb_nanoarrow::ArrowSchemaDeepCopy; + +static const char *SingleBatchArrayStreamGetLastError(struct ArrowArrayStream *stream) { + return NULL; +} + +static int SingleBatchArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *batch) { + if (!stream || !stream->private_data) { + return EINVAL; + } + struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; + + memcpy(batch, &impl->batch, sizeof(*batch)); + memset(&impl->batch, 0, sizeof(*batch)); + return 0; +} + +static int SingleBatchArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *schema) { + if (!stream || !stream->private_data) { + return EINVAL; + } + struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; + + return ArrowSchemaDeepCopy(&impl->schema, schema); +} + +static void SingleBatchArrayStreamRelease(struct ArrowArrayStream *stream) { + if (!stream || !stream->private_data) { + return; + } + struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; + impl->schema.release(&impl->schema); + if (impl->batch.release) { + impl->batch.release(&impl->batch); + } + free(impl); + + memset(stream, 0, sizeof(*stream)); +} + +AdbcStatusCode BatchToArrayStream(struct ArrowArray *values, struct ArrowSchema *schema, + struct ArrowArrayStream *stream, struct AdbcError *error) { + if (!values->release) { + SetError(error, "ArrowArray is not initialized"); + return ADBC_STATUS_INTERNAL; + } else if (!schema->release) { + SetError(error, "ArrowSchema is not initialized"); + return ADBC_STATUS_INTERNAL; + } else if (stream->release) { + SetError(error, "ArrowArrayStream is already initialized"); + return ADBC_STATUS_INTERNAL; + } + + struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)malloc(sizeof(*impl)); + memcpy(&impl->schema, schema, sizeof(*schema)); + memcpy(&impl->batch, values, sizeof(*values)); + memset(schema, 0, sizeof(*schema)); + memset(values, 0, sizeof(*values)); + stream->private_data = impl; + stream->get_last_error = SingleBatchArrayStreamGetLastError; + stream->get_next = SingleBatchArrayStreamGetNext; + stream->get_schema = SingleBatchArrayStreamGetSchema; + stream->release = SingleBatchArrayStreamRelease; + + return ADBC_STATUS_OK; +} + +} // namespace duckdb_adbc diff --git a/src/duckdb/src/common/allocator.cpp b/src/duckdb/src/common/allocator.cpp new file mode 100644 index 00000000..939a790d --- /dev/null +++ b/src/duckdb/src/common/allocator.cpp @@ -0,0 +1,245 @@ +#include "duckdb/common/allocator.hpp" + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/exception.hpp" + +#include + +#ifdef DUCKDB_DEBUG_ALLOCATION +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/unordered_map.hpp" + +#include +#endif + +#ifndef USE_JEMALLOC +#if defined(DUCKDB_EXTENSION_JEMALLOC_LINKED) && DUCKDB_EXTENSION_JEMALLOC_LINKED && !defined(WIN32) +#define USE_JEMALLOC +#endif +#endif + +#ifdef USE_JEMALLOC +#include "jemalloc_extension.hpp" +#endif + +namespace duckdb { + +AllocatedData::AllocatedData() : allocator(nullptr), pointer(nullptr), allocated_size(0) { +} + +AllocatedData::AllocatedData(Allocator &allocator, data_ptr_t pointer, idx_t allocated_size) + : allocator(&allocator), pointer(pointer), allocated_size(allocated_size) { + if (!pointer) { + throw InternalException("AllocatedData object constructed with nullptr"); + } +} +AllocatedData::~AllocatedData() { + Reset(); +} + +AllocatedData::AllocatedData(AllocatedData &&other) noexcept + : allocator(other.allocator), pointer(nullptr), allocated_size(0) { + std::swap(pointer, other.pointer); + std::swap(allocated_size, other.allocated_size); +} + +AllocatedData &AllocatedData::operator=(AllocatedData &&other) noexcept { + std::swap(allocator, other.allocator); + std::swap(pointer, other.pointer); + std::swap(allocated_size, other.allocated_size); + return *this; +} + +void AllocatedData::Reset() { + if (!pointer) { + return; + } + D_ASSERT(allocator); + allocator->FreeData(pointer, allocated_size); + allocated_size = 0; + pointer = nullptr; +} + +//===--------------------------------------------------------------------===// +// Debug Info +//===--------------------------------------------------------------------===// +struct AllocatorDebugInfo { +#ifdef DEBUG + AllocatorDebugInfo(); + ~AllocatorDebugInfo(); + + void AllocateData(data_ptr_t pointer, idx_t size); + void FreeData(data_ptr_t pointer, idx_t size); + void ReallocateData(data_ptr_t pointer, data_ptr_t new_pointer, idx_t old_size, idx_t new_size); + +private: + //! The number of bytes that are outstanding (i.e. that have been allocated - but not freed) + //! Used for debug purposes + atomic allocation_count; +#ifdef DUCKDB_DEBUG_ALLOCATION + mutex pointer_lock; + //! Set of active outstanding pointers together with stack traces + unordered_map> pointers; +#endif +#endif +}; + +PrivateAllocatorData::PrivateAllocatorData() { +} + +PrivateAllocatorData::~PrivateAllocatorData() { +} + +//===--------------------------------------------------------------------===// +// Allocator +//===--------------------------------------------------------------------===// +#ifdef USE_JEMALLOC +Allocator::Allocator() + : Allocator(JemallocExtension::Allocate, JemallocExtension::Free, JemallocExtension::Reallocate, nullptr) { +} +#else +Allocator::Allocator() + : Allocator(Allocator::DefaultAllocate, Allocator::DefaultFree, Allocator::DefaultReallocate, nullptr) { +} +#endif + +Allocator::Allocator(allocate_function_ptr_t allocate_function_p, free_function_ptr_t free_function_p, + reallocate_function_ptr_t reallocate_function_p, unique_ptr private_data_p) + : allocate_function(allocate_function_p), free_function(free_function_p), + reallocate_function(reallocate_function_p), private_data(std::move(private_data_p)) { + D_ASSERT(allocate_function); + D_ASSERT(free_function); + D_ASSERT(reallocate_function); +#ifdef DEBUG + if (!private_data) { + private_data = make_uniq(); + } + private_data->debug_info = make_uniq(); +#endif +} + +Allocator::~Allocator() { +} + +data_ptr_t Allocator::AllocateData(idx_t size) { + D_ASSERT(size > 0); + if (size >= MAXIMUM_ALLOC_SIZE) { + D_ASSERT(false); + throw InternalException("Requested allocation size of %llu is out of range - maximum allocation size is %llu", + size, MAXIMUM_ALLOC_SIZE); + } + auto result = allocate_function(private_data.get(), size); +#ifdef DEBUG + D_ASSERT(private_data); + private_data->debug_info->AllocateData(result, size); +#endif + if (!result) { + throw OutOfMemoryException("Failed to allocate block of %llu bytes", size); + } + return result; +} + +void Allocator::FreeData(data_ptr_t pointer, idx_t size) { + if (!pointer) { + return; + } + D_ASSERT(size > 0); +#ifdef DEBUG + D_ASSERT(private_data); + private_data->debug_info->FreeData(pointer, size); +#endif + free_function(private_data.get(), pointer, size); +} + +data_ptr_t Allocator::ReallocateData(data_ptr_t pointer, idx_t old_size, idx_t size) { + if (!pointer) { + return nullptr; + } + if (size >= MAXIMUM_ALLOC_SIZE) { + D_ASSERT(false); + throw InternalException( + "Requested re-allocation size of %llu is out of range - maximum allocation size is %llu", size, + MAXIMUM_ALLOC_SIZE); + } + auto new_pointer = reallocate_function(private_data.get(), pointer, old_size, size); +#ifdef DEBUG + D_ASSERT(private_data); + private_data->debug_info->ReallocateData(pointer, new_pointer, old_size, size); +#endif + if (!new_pointer) { + throw OutOfMemoryException("Failed to re-allocate block of %llu bytes", size); + } + return new_pointer; +} + +shared_ptr &Allocator::DefaultAllocatorReference() { + static shared_ptr DEFAULT_ALLOCATOR = make_shared(); + return DEFAULT_ALLOCATOR; +} + +Allocator &Allocator::DefaultAllocator() { + return *DefaultAllocatorReference(); +} + +void Allocator::ThreadFlush(idx_t threshold) { +#ifdef USE_JEMALLOC + JemallocExtension::ThreadFlush(threshold); +#endif +} + +//===--------------------------------------------------------------------===// +// Debug Info (extended) +//===--------------------------------------------------------------------===// +#ifdef DEBUG +AllocatorDebugInfo::AllocatorDebugInfo() { + allocation_count = 0; +} +AllocatorDebugInfo::~AllocatorDebugInfo() { +#ifdef DUCKDB_DEBUG_ALLOCATION + if (allocation_count != 0) { + printf("Outstanding allocations found for Allocator\n"); + for (auto &entry : pointers) { + printf("Allocation of size %llu at address %p\n", entry.second.first, (void *)entry.first); + printf("Stack trace:\n%s\n", entry.second.second.c_str()); + printf("\n"); + } + } +#endif + //! Verify that there is no outstanding memory still associated with the batched allocator + //! Only works for access to the batched allocator through the batched allocator interface + //! If this assertion triggers, enable DUCKDB_DEBUG_ALLOCATION for more information about the allocations + D_ASSERT(allocation_count == 0); +} + +void AllocatorDebugInfo::AllocateData(data_ptr_t pointer, idx_t size) { + allocation_count += size; +#ifdef DUCKDB_DEBUG_ALLOCATION + lock_guard l(pointer_lock); + pointers[pointer] = make_pair(size, Exception::GetStackTrace()); +#endif +} + +void AllocatorDebugInfo::FreeData(data_ptr_t pointer, idx_t size) { + D_ASSERT(allocation_count >= size); + allocation_count -= size; +#ifdef DUCKDB_DEBUG_ALLOCATION + lock_guard l(pointer_lock); + // verify that the pointer exists + D_ASSERT(pointers.find(pointer) != pointers.end()); + // verify that the stored size matches the passed in size + D_ASSERT(pointers[pointer].first == size); + // erase the pointer + pointers.erase(pointer); +#endif +} + +void AllocatorDebugInfo::ReallocateData(data_ptr_t pointer, data_ptr_t new_pointer, idx_t old_size, idx_t new_size) { + FreeData(pointer, old_size); + AllocateData(new_pointer, new_size); +} + +#endif + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/bool_data.cpp b/src/duckdb/src/common/arrow/appender/bool_data.cpp new file mode 100644 index 00000000..6a5f6728 --- /dev/null +++ b/src/duckdb/src/common/arrow/appender/bool_data.cpp @@ -0,0 +1,44 @@ +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/appender/bool_data.hpp" + +namespace duckdb { + +void ArrowBoolData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + auto byte_count = (capacity + 7) / 8; + result.main_buffer.reserve(byte_count); +} + +void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + idx_t size = to - from; + UnifiedVectorFormat format; + input.ToUnifiedFormat(input_size, format); + + // we initialize both the validity and the bit set to 1's + ResizeValidity(append_data.validity, append_data.row_count + size); + ResizeValidity(append_data.main_buffer, append_data.row_count + size); + auto data = UnifiedVectorFormat::GetData(format); + + auto result_data = append_data.main_buffer.GetData(); + auto validity_data = append_data.validity.GetData(); + uint8_t current_bit; + idx_t current_byte; + GetBitPosition(append_data.row_count, current_byte, current_bit); + for (idx_t i = from; i < to; i++) { + auto source_idx = format.sel->get_index(i); + // append the validity mask + if (!format.validity.RowIsValid(source_idx)) { + SetNull(append_data, validity_data, current_byte, current_bit); + } else if (!data[source_idx]) { + UnsetBit(result_data, current_byte, current_bit); + } + NextBit(current_byte, current_bit); + } + append_data.row_count += size; +} + +void ArrowBoolData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + result->n_buffers = 2; + result->buffers[1] = append_data.main_buffer.data(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/list_data.cpp b/src/duckdb/src/common/arrow/appender/list_data.cpp new file mode 100644 index 00000000..57400fc7 --- /dev/null +++ b/src/duckdb/src/common/arrow/appender/list_data.cpp @@ -0,0 +1,78 @@ +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/appender/list_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Lists +//===--------------------------------------------------------------------===// +void ArrowListData::AppendOffsets(ArrowAppendData &append_data, UnifiedVectorFormat &format, idx_t from, idx_t to, + vector &child_sel) { + // resize the offset buffer - the offset buffer holds the offsets into the child array + idx_t size = to - from; + append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(uint32_t) * (size + 1)); + auto data = UnifiedVectorFormat::GetData(format); + auto offset_data = append_data.main_buffer.GetData(); + if (append_data.row_count == 0) { + // first entry + offset_data[0] = 0; + } + // set up the offsets using the list entries + auto last_offset = offset_data[append_data.row_count]; + for (idx_t i = from; i < to; i++) { + auto source_idx = format.sel->get_index(i); + auto offset_idx = append_data.row_count + i + 1 - from; + + if (!format.validity.RowIsValid(source_idx)) { + offset_data[offset_idx] = last_offset; + continue; + } + + // append the offset data + auto list_length = data[source_idx].length; + last_offset += list_length; + offset_data[offset_idx] = last_offset; + + for (idx_t k = 0; k < list_length; k++) { + child_sel.push_back(data[source_idx].offset + k); + } + } +} + +void ArrowListData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + auto &child_type = ListType::GetChildType(type); + result.main_buffer.reserve((capacity + 1) * sizeof(uint32_t)); + auto child_buffer = ArrowAppender::InitializeChild(child_type, capacity, result.options); + result.child_data.push_back(std::move(child_buffer)); +} + +void ArrowListData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + UnifiedVectorFormat format; + input.ToUnifiedFormat(input_size, format); + idx_t size = to - from; + vector child_indices; + AppendValidity(append_data, format, from, to); + ArrowListData::AppendOffsets(append_data, format, from, to, child_indices); + + // append the child vector of the list + SelectionVector child_sel(child_indices.data()); + auto &child = ListVector::GetEntry(input); + auto child_size = child_indices.size(); + Vector child_copy(child.GetType()); + child_copy.Slice(child, child_sel, child_size); + append_data.child_data[0]->append_vector(*append_data.child_data[0], child_copy, 0, child_size, child_size); + append_data.row_count += size; +} + +void ArrowListData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + result->n_buffers = 2; + result->buffers[1] = append_data.main_buffer.data(); + + auto &child_type = ListType::GetChildType(type); + append_data.child_pointers.resize(1); + result->children = append_data.child_pointers.data(); + result->n_children = 1; + append_data.child_pointers[0] = ArrowAppender::FinalizeChild(child_type, *append_data.child_data[0]); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/map_data.cpp b/src/duckdb/src/common/arrow/appender/map_data.cpp new file mode 100644 index 00000000..90e99a9a --- /dev/null +++ b/src/duckdb/src/common/arrow/appender/map_data.cpp @@ -0,0 +1,86 @@ +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/appender/map_data.hpp" +#include "duckdb/common/arrow/appender/list_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Maps +//===--------------------------------------------------------------------===// +void ArrowMapData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + // map types are stored in a (too) clever way + // the main buffer holds the null values and the offsets + // then we have a single child, which is a struct of the map_type, and the key_type + result.main_buffer.reserve((capacity + 1) * sizeof(uint32_t)); + + auto &key_type = MapType::KeyType(type); + auto &value_type = MapType::ValueType(type); + auto internal_struct = make_uniq(result.options); + internal_struct->child_data.push_back(ArrowAppender::InitializeChild(key_type, capacity, result.options)); + internal_struct->child_data.push_back(ArrowAppender::InitializeChild(value_type, capacity, result.options)); + + result.child_data.push_back(std::move(internal_struct)); +} + +void ArrowMapData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + UnifiedVectorFormat format; + input.ToUnifiedFormat(input_size, format); + idx_t size = to - from; + AppendValidity(append_data, format, from, to); + vector child_indices; + ArrowListData::AppendOffsets(append_data, format, from, to, child_indices); + + SelectionVector child_sel(child_indices.data()); + auto &key_vector = MapVector::GetKeys(input); + auto &value_vector = MapVector::GetValues(input); + auto list_size = child_indices.size(); + + auto &struct_data = *append_data.child_data[0]; + auto &key_data = *struct_data.child_data[0]; + auto &value_data = *struct_data.child_data[1]; + + Vector key_vector_copy(key_vector.GetType()); + key_vector_copy.Slice(key_vector, child_sel, list_size); + Vector value_vector_copy(value_vector.GetType()); + value_vector_copy.Slice(value_vector, child_sel, list_size); + key_data.append_vector(key_data, key_vector_copy, 0, list_size, list_size); + value_data.append_vector(value_data, value_vector_copy, 0, list_size, list_size); + + append_data.row_count += size; + struct_data.row_count += size; +} + +void ArrowMapData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + // set up the main map buffer + result->n_buffers = 2; + result->buffers[1] = append_data.main_buffer.data(); + + // the main map buffer has a single child: a struct + append_data.child_pointers.resize(1); + result->children = append_data.child_pointers.data(); + result->n_children = 1; + append_data.child_pointers[0] = ArrowAppender::FinalizeChild(type, *append_data.child_data[0]); + + // now that struct has two children: the key and the value type + auto &struct_data = *append_data.child_data[0]; + auto &struct_result = append_data.child_pointers[0]; + struct_data.child_pointers.resize(2); + struct_result->n_buffers = 1; + struct_result->n_children = 2; + struct_result->length = struct_data.child_data[0]->row_count; + struct_result->children = struct_data.child_pointers.data(); + + D_ASSERT(struct_data.child_data[0]->row_count == struct_data.child_data[1]->row_count); + + auto &key_type = MapType::KeyType(type); + auto &value_type = MapType::ValueType(type); + struct_data.child_pointers[0] = ArrowAppender::FinalizeChild(key_type, *struct_data.child_data[0]); + struct_data.child_pointers[1] = ArrowAppender::FinalizeChild(value_type, *struct_data.child_data[1]); + + // keys cannot have null values + if (struct_data.child_pointers[0]->null_count > 0) { + throw std::runtime_error("Arrow doesn't accept NULL keys on Maps"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/struct_data.cpp b/src/duckdb/src/common/arrow/appender/struct_data.cpp new file mode 100644 index 00000000..b6c0972e --- /dev/null +++ b/src/duckdb/src/common/arrow/appender/struct_data.cpp @@ -0,0 +1,45 @@ +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/appender/struct_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Structs +//===--------------------------------------------------------------------===// +void ArrowStructData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + auto &children = StructType::GetChildTypes(type); + for (auto &child : children) { + auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); + result.child_data.push_back(std::move(child_buffer)); + } +} + +void ArrowStructData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + UnifiedVectorFormat format; + input.ToUnifiedFormat(input_size, format); + idx_t size = to - from; + AppendValidity(append_data, format, from, to); + // append the children of the struct + auto &children = StructVector::GetEntries(input); + for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { + auto &child = children[child_idx]; + auto &child_data = *append_data.child_data[child_idx]; + child_data.append_vector(child_data, *child, from, to, size); + } + append_data.row_count += size; +} + +void ArrowStructData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + result->n_buffers = 1; + + auto &child_types = StructType::GetChildTypes(type); + append_data.child_pointers.resize(child_types.size()); + result->children = append_data.child_pointers.data(); + result->n_children = child_types.size(); + for (idx_t i = 0; i < child_types.size(); i++) { + auto &child_type = child_types[i].second; + append_data.child_pointers[i] = ArrowAppender::FinalizeChild(child_type, *append_data.child_data[i]); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/union_data.cpp b/src/duckdb/src/common/arrow/appender/union_data.cpp new file mode 100644 index 00000000..0c52f80e --- /dev/null +++ b/src/duckdb/src/common/arrow/appender/union_data.cpp @@ -0,0 +1,70 @@ +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/appender/union_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Unions +//===--------------------------------------------------------------------===// +void ArrowUnionData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + result.main_buffer.reserve(capacity * sizeof(int8_t)); + + for (auto &child : UnionType::CopyMemberTypes(type)) { + auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); + result.child_data.push_back(std::move(child_buffer)); + } +} + +void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + UnifiedVectorFormat format; + input.ToUnifiedFormat(input_size, format); + idx_t size = to - from; + + auto &types_buffer = append_data.main_buffer; + + duckdb::vector child_vectors; + for (const auto &child : UnionType::CopyMemberTypes(input.GetType())) { + child_vectors.emplace_back(child.second); + } + + for (idx_t input_idx = from; input_idx < to; input_idx++) { + const auto &val = input.GetValue(input_idx); + + idx_t tag = 0; + Value resolved_value(nullptr); + if (!val.IsNull()) { + tag = UnionValue::GetTag(val); + + resolved_value = UnionValue::GetValue(val); + } + + for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { + child_vectors[child_idx].SetValue(input_idx, child_idx == tag ? resolved_value : Value(nullptr)); + } + + types_buffer.data()[input_idx] = tag; + } + + for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { + auto &child_buffer = append_data.child_data[child_idx]; + auto &child = child_vectors[child_idx]; + child_buffer->append_vector(*child_buffer, child, from, to, size); + } + append_data.row_count += size; +} + +void ArrowUnionData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + result->n_buffers = 2; + result->buffers[1] = append_data.main_buffer.data(); + + auto &child_types = UnionType::CopyMemberTypes(type); + append_data.child_pointers.resize(child_types.size()); + result->children = append_data.child_pointers.data(); + result->n_children = child_types.size(); + for (idx_t i = 0; i < child_types.size(); i++) { + auto &child_type = child_types[i].second; + append_data.child_pointers[i] = ArrowAppender::FinalizeChild(child_type, *append_data.child_data[i]); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_appender.cpp b/src/duckdb/src/common/arrow/arrow_appender.cpp new file mode 100644 index 00000000..18414f5b --- /dev/null +++ b/src/duckdb/src/common/arrow/arrow_appender.cpp @@ -0,0 +1,241 @@ +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/arrow_buffer.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/array.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/uuid.hpp" +#include "duckdb/function/table/arrow.hpp" +#include "duckdb/common/arrow/appender/append_data.hpp" +#include "duckdb/common/arrow/appender/list.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// ArrowAppender +//===--------------------------------------------------------------------===// + +ArrowAppender::ArrowAppender(vector types_p, idx_t initial_capacity, ClientProperties options) + : types(std::move(types_p)) { + for (auto &type : types) { + auto entry = ArrowAppender::InitializeChild(type, initial_capacity, options); + root_data.push_back(std::move(entry)); + } +} + +ArrowAppender::~ArrowAppender() { +} + +//! Append a data chunk to the underlying arrow array +void ArrowAppender::Append(DataChunk &input, idx_t from, idx_t to, idx_t input_size) { + D_ASSERT(types == input.GetTypes()); + D_ASSERT(to >= from); + for (idx_t i = 0; i < input.ColumnCount(); i++) { + root_data[i]->append_vector(*root_data[i], input.data[i], from, to, input_size); + } + row_count += to - from; +} + +void ArrowAppender::ReleaseArray(ArrowArray *array) { + if (!array || !array->release) { + return; + } + array->release = nullptr; + auto holder = static_cast(array->private_data); + delete holder; +} + +//===--------------------------------------------------------------------===// +// Finalize Arrow Child +//===--------------------------------------------------------------------===// +ArrowArray *ArrowAppender::FinalizeChild(const LogicalType &type, ArrowAppendData &append_data) { + auto result = make_uniq(); + + result->private_data = nullptr; + result->release = ArrowAppender::ReleaseArray; + result->n_children = 0; + result->null_count = 0; + result->offset = 0; + result->dictionary = nullptr; + result->buffers = append_data.buffers.data(); + result->null_count = append_data.null_count; + result->length = append_data.row_count; + result->buffers[0] = append_data.validity.data(); + + if (append_data.finalize) { + append_data.finalize(append_data, type, result.get()); + } + + append_data.array = std::move(result); + return append_data.array.get(); +} + +//! Returns the underlying arrow array +ArrowArray ArrowAppender::Finalize() { + D_ASSERT(root_data.size() == types.size()); + auto root_holder = make_uniq(options); + + ArrowArray result; + root_holder->child_pointers.resize(types.size()); + result.children = root_holder->child_pointers.data(); + result.n_children = types.size(); + + // Configure root array + result.length = row_count; + result.n_buffers = 1; + result.buffers = root_holder->buffers.data(); // there is no actual buffer there since we don't have NULLs + result.offset = 0; + result.null_count = 0; // needs to be 0 + result.dictionary = nullptr; + root_holder->child_data = std::move(root_data); + + // FIXME: this violates a property of the arrow format, if root owns all the child memory then consumers can't move + // child arrays https://arrow.apache.org/docs/format/CDataInterface.html#moving-child-arrays + for (idx_t i = 0; i < root_holder->child_data.size(); i++) { + root_holder->child_pointers[i] = ArrowAppender::FinalizeChild(types[i], *root_holder->child_data[i]); + } + + // Release ownership to caller + result.private_data = root_holder.release(); + result.release = ArrowAppender::ReleaseArray; + return result; +} + +//===--------------------------------------------------------------------===// +// Initialize Arrow Child +//===--------------------------------------------------------------------===// + +template +static void InitializeAppenderForType(ArrowAppendData &append_data) { + append_data.initialize = OP::Initialize; + append_data.append_vector = OP::Append; + append_data.finalize = OP::Finalize; +} + +static void InitializeFunctionPointers(ArrowAppendData &append_data, const LogicalType &type) { + // handle special logical types + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + InitializeAppenderForType(append_data); + break; + case LogicalTypeId::TINYINT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::SMALLINT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::DATE: + case LogicalTypeId::INTEGER: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::BIGINT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::HUGEINT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::UTINYINT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::USMALLINT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::UINTEGER: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::UBIGINT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::FLOAT: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::DOUBLE: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + InitializeAppenderForType>(append_data); + break; + case PhysicalType::INT32: + InitializeAppenderForType>(append_data); + break; + case PhysicalType::INT64: + InitializeAppenderForType>(append_data); + break; + case PhysicalType::INT128: + InitializeAppenderForType>(append_data); + break; + default: + throw InternalException("Unsupported internal decimal type"); + } + break; + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + case LogicalTypeId::BIT: + if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { + InitializeAppenderForType>(append_data); + } else { + InitializeAppenderForType>(append_data); + } + break; + case LogicalTypeId::UUID: + if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { + InitializeAppenderForType>(append_data); + } else { + InitializeAppenderForType>(append_data); + } + break; + case LogicalTypeId::ENUM: + switch (type.InternalType()) { + case PhysicalType::UINT8: + InitializeAppenderForType>(append_data); + break; + case PhysicalType::UINT16: + InitializeAppenderForType>(append_data); + break; + case PhysicalType::UINT32: + InitializeAppenderForType>(append_data); + break; + default: + throw InternalException("Unsupported internal enum type"); + } + break; + case LogicalTypeId::INTERVAL: + InitializeAppenderForType>(append_data); + break; + case LogicalTypeId::UNION: + InitializeAppenderForType(append_data); + break; + case LogicalTypeId::STRUCT: + InitializeAppenderForType(append_data); + break; + case LogicalTypeId::LIST: + InitializeAppenderForType(append_data); + break; + case LogicalTypeId::MAP: + InitializeAppenderForType(append_data); + break; + default: + throw NotImplementedException("Unsupported type in DuckDB -> Arrow Conversion: %s\n", type.ToString()); + } +} + +unique_ptr ArrowAppender::InitializeChild(const LogicalType &type, idx_t capacity, + ClientProperties &options) { + auto result = make_uniq(options); + InitializeFunctionPointers(*result, type); + + auto byte_count = (capacity + 7) / 8; + result->validity.reserve(byte_count); + result->initialize(*result, type, capacity); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp new file mode 100644 index 00000000..0ecc46e0 --- /dev/null +++ b/src/duckdb/src/common/arrow/arrow_converter.cpp @@ -0,0 +1,327 @@ +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/arrow/arrow_converter.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/sel_cache.hpp" +#include "duckdb/common/types/vector_cache.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/vector.hpp" +#include +#include "duckdb/common/arrow/arrow_appender.hpp" + +namespace duckdb { + +void ArrowConverter::ToArrowArray(DataChunk &input, ArrowArray *out_array, ClientProperties options) { + ArrowAppender appender(input.GetTypes(), input.size(), std::move(options)); + appender.Append(input, 0, input.size(), input.size()); + *out_array = appender.Finalize(); +} + +unsafe_unique_array AddName(const string &name) { + auto name_ptr = make_unsafe_uniq_array(name.size() + 1); + for (size_t i = 0; i < name.size(); i++) { + name_ptr[i] = name[i]; + } + name_ptr[name.size()] = '\0'; + return name_ptr; +} + +//===--------------------------------------------------------------------===// +// Arrow Schema +//===--------------------------------------------------------------------===// +struct DuckDBArrowSchemaHolder { + // unused in children + vector children; + // unused in children + vector children_ptrs; + //! used for nested structures + std::list> nested_children; + std::list> nested_children_ptr; + //! This holds strings created to represent decimal types + vector> owned_type_names; + vector> owned_column_names; +}; + +static void ReleaseDuckDBArrowSchema(ArrowSchema *schema) { + if (!schema || !schema->release) { + return; + } + schema->release = nullptr; + auto holder = static_cast(schema->private_data); + delete holder; +} + +void InitializeChild(ArrowSchema &child, DuckDBArrowSchemaHolder &root_holder, const string &name = "") { + //! Child is cleaned up by parent + child.private_data = nullptr; + child.release = ReleaseDuckDBArrowSchema; + + // Store the child schema + child.flags = ARROW_FLAG_NULLABLE; + root_holder.owned_type_names.push_back(AddName(name)); + + child.name = root_holder.owned_type_names.back().get(); + child.n_children = 0; + child.children = nullptr; + child.metadata = nullptr; + child.dictionary = nullptr; +} + +void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, + const ClientProperties &options); + +void SetArrowMapFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, + const ClientProperties &options) { + child.format = "+m"; + //! Map has one child which is a struct + child.n_children = 1; + root_holder.nested_children.emplace_back(); + root_holder.nested_children.back().resize(1); + root_holder.nested_children_ptr.emplace_back(); + root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); + InitializeChild(root_holder.nested_children.back()[0], root_holder); + child.children = &root_holder.nested_children_ptr.back()[0]; + child.children[0]->name = "entries"; + SetArrowFormat(root_holder, **child.children, ListType::GetChildType(type), options); +} + +void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, + const ClientProperties &options) { + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + child.format = "b"; + break; + case LogicalTypeId::TINYINT: + child.format = "c"; + break; + case LogicalTypeId::SMALLINT: + child.format = "s"; + break; + case LogicalTypeId::INTEGER: + child.format = "i"; + break; + case LogicalTypeId::BIGINT: + child.format = "l"; + break; + case LogicalTypeId::UTINYINT: + child.format = "C"; + break; + case LogicalTypeId::USMALLINT: + child.format = "S"; + break; + case LogicalTypeId::UINTEGER: + child.format = "I"; + break; + case LogicalTypeId::UBIGINT: + child.format = "L"; + break; + case LogicalTypeId::FLOAT: + child.format = "f"; + break; + case LogicalTypeId::HUGEINT: + child.format = "d:38,0"; + break; + case LogicalTypeId::DOUBLE: + child.format = "g"; + break; + case LogicalTypeId::UUID: + case LogicalTypeId::VARCHAR: + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + child.format = "U"; + } else { + child.format = "u"; + } + break; + case LogicalTypeId::DATE: + child.format = "tdD"; + break; +#ifdef DUCKDB_WASM + case LogicalTypeId::TIME_TZ: +#endif + case LogicalTypeId::TIME: + child.format = "ttu"; + break; + case LogicalTypeId::TIMESTAMP: + child.format = "tsu:"; + break; + case LogicalTypeId::TIMESTAMP_TZ: { + string format = "tsu:" + options.time_zone; + root_holder.owned_type_names.push_back(AddName(format)); + child.format = root_holder.owned_type_names.back().get(); + break; + } + case LogicalTypeId::TIMESTAMP_SEC: + child.format = "tss:"; + break; + case LogicalTypeId::TIMESTAMP_NS: + child.format = "tsn:"; + break; + case LogicalTypeId::TIMESTAMP_MS: + child.format = "tsm:"; + break; + case LogicalTypeId::INTERVAL: + child.format = "tin"; + break; + case LogicalTypeId::DECIMAL: { + uint8_t width, scale; + type.GetDecimalProperties(width, scale); + string format = "d:" + to_string(width) + "," + to_string(scale); + root_holder.owned_type_names.push_back(AddName(format)); + child.format = root_holder.owned_type_names.back().get(); + break; + } + case LogicalTypeId::SQLNULL: { + child.format = "n"; + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::BIT: { + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + child.format = "Z"; + } else { + child.format = "z"; + } + break; + } + case LogicalTypeId::LIST: { + child.format = "+l"; + child.n_children = 1; + root_holder.nested_children.emplace_back(); + root_holder.nested_children.back().resize(1); + root_holder.nested_children_ptr.emplace_back(); + root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); + InitializeChild(root_holder.nested_children.back()[0], root_holder); + child.children = &root_holder.nested_children_ptr.back()[0]; + child.children[0]->name = "l"; + SetArrowFormat(root_holder, **child.children, ListType::GetChildType(type), options); + break; + } + case LogicalTypeId::STRUCT: { + child.format = "+s"; + auto &child_types = StructType::GetChildTypes(type); + child.n_children = child_types.size(); + root_holder.nested_children.emplace_back(); + root_holder.nested_children.back().resize(child_types.size()); + root_holder.nested_children_ptr.emplace_back(); + root_holder.nested_children_ptr.back().resize(child_types.size()); + for (idx_t type_idx = 0; type_idx < child_types.size(); type_idx++) { + root_holder.nested_children_ptr.back()[type_idx] = &root_holder.nested_children.back()[type_idx]; + } + child.children = &root_holder.nested_children_ptr.back()[0]; + for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { + + InitializeChild(*child.children[type_idx], root_holder); + + root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); + + child.children[type_idx]->name = root_holder.owned_type_names.back().get(); + SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options); + } + break; + } + case LogicalTypeId::MAP: { + SetArrowMapFormat(root_holder, child, type, options); + break; + } + case LogicalTypeId::UNION: { + std::string format = "+us:"; + + auto &child_types = UnionType::CopyMemberTypes(type); + child.n_children = child_types.size(); + root_holder.nested_children.emplace_back(); + root_holder.nested_children.back().resize(child_types.size()); + root_holder.nested_children_ptr.emplace_back(); + root_holder.nested_children_ptr.back().resize(child_types.size()); + for (idx_t type_idx = 0; type_idx < child_types.size(); type_idx++) { + root_holder.nested_children_ptr.back()[type_idx] = &root_holder.nested_children.back()[type_idx]; + } + child.children = &root_holder.nested_children_ptr.back()[0]; + for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { + + InitializeChild(*child.children[type_idx], root_holder); + + root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); + + child.children[type_idx]->name = root_holder.owned_type_names.back().get(); + SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options); + + format += to_string(type_idx) + ","; + } + + format.pop_back(); + + root_holder.owned_type_names.push_back(AddName(format)); + child.format = root_holder.owned_type_names.back().get(); + + break; + } + case LogicalTypeId::ENUM: { + // TODO what do we do with pointer enums here? + switch (EnumType::GetPhysicalType(type)) { + case PhysicalType::UINT8: + child.format = "C"; + break; + case PhysicalType::UINT16: + child.format = "S"; + break; + case PhysicalType::UINT32: + child.format = "I"; + break; + default: + throw InternalException("Unsupported Enum Internal Type"); + } + root_holder.nested_children.emplace_back(); + root_holder.nested_children.back().resize(1); + root_holder.nested_children_ptr.emplace_back(); + root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); + InitializeChild(root_holder.nested_children.back()[0], root_holder); + child.dictionary = root_holder.nested_children_ptr.back()[0]; + child.dictionary->format = "u"; + break; + } + default: + throw NotImplementedException("Unsupported Arrow type " + type.ToString()); + } +} + +void ArrowConverter::ToArrowSchema(ArrowSchema *out_schema, const vector &types, + const vector &names, const ClientProperties &options) { + D_ASSERT(out_schema); + D_ASSERT(types.size() == names.size()); + idx_t column_count = types.size(); + // Allocate as unique_ptr first to cleanup properly on error + auto root_holder = make_uniq(); + + // Allocate the children + root_holder->children.resize(column_count); + root_holder->children_ptrs.resize(column_count, nullptr); + for (size_t i = 0; i < column_count; ++i) { + root_holder->children_ptrs[i] = &root_holder->children[i]; + } + out_schema->children = root_holder->children_ptrs.data(); + out_schema->n_children = column_count; + + // Store the schema + out_schema->format = "+s"; // struct apparently + out_schema->flags = 0; + out_schema->metadata = nullptr; + out_schema->name = "duckdb_query_result"; + out_schema->dictionary = nullptr; + + // Configure all child schemas + for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { + root_holder->owned_column_names.push_back(AddName(names[col_idx])); + auto &child = root_holder->children[col_idx]; + InitializeChild(child, *root_holder, names[col_idx]); + SetArrowFormat(*root_holder, child, types[col_idx], options); + } + + // Release ownership to caller + out_schema->private_data = root_holder.release(); + out_schema->release = ReleaseDuckDBArrowSchema; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_wrapper.cpp b/src/duckdb/src/common/arrow/arrow_wrapper.cpp new file mode 100644 index 00000000..68170c96 --- /dev/null +++ b/src/duckdb/src/common/arrow/arrow_wrapper.cpp @@ -0,0 +1,221 @@ +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/common/arrow/arrow_converter.hpp" + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" + +#include "duckdb/main/stream_query_result.hpp" + +#include "duckdb/common/arrow/result_arrow_wrapper.hpp" +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/chunk_scan_state/query_result.hpp" + +namespace duckdb { + +ArrowSchemaWrapper::~ArrowSchemaWrapper() { + if (arrow_schema.release) { + arrow_schema.release(&arrow_schema); + arrow_schema.release = nullptr; + } +} + +ArrowArrayWrapper::~ArrowArrayWrapper() { + if (arrow_array.release) { + arrow_array.release(&arrow_array); + arrow_array.release = nullptr; + } +} + +ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() { + if (arrow_array_stream.release) { + arrow_array_stream.release(&arrow_array_stream); + arrow_array_stream.release = nullptr; + } +} + +void ArrowArrayStreamWrapper::GetSchema(ArrowSchemaWrapper &schema) { + D_ASSERT(arrow_array_stream.get_schema); + // LCOV_EXCL_START + if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema)) { + throw InvalidInputException("arrow_scan: get_schema failed(): %s", string(GetError())); + } + if (!schema.arrow_schema.release) { + throw InvalidInputException("arrow_scan: released schema passed"); + } + if (schema.arrow_schema.n_children < 1) { + throw InvalidInputException("arrow_scan: empty schema passed"); + } + // LCOV_EXCL_STOP +} + +shared_ptr ArrowArrayStreamWrapper::GetNextChunk() { + auto current_chunk = make_shared(); + if (arrow_array_stream.get_next(&arrow_array_stream, ¤t_chunk->arrow_array)) { // LCOV_EXCL_START + throw InvalidInputException("arrow_scan: get_next failed(): %s", string(GetError())); + } // LCOV_EXCL_STOP + + return current_chunk; +} + +const char *ArrowArrayStreamWrapper::GetError() { // LCOV_EXCL_START + return arrow_array_stream.get_last_error(&arrow_array_stream); +} // LCOV_EXCL_STOP + +int ResultArrowArrayStreamWrapper::MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { + if (!stream->release) { + return -1; + } + auto my_stream = reinterpret_cast(stream->private_data); + if (!my_stream->column_types.empty()) { + ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names, + my_stream->result->client_properties); + return 0; + } + + auto &result = *my_stream->result; + if (result.HasError()) { + my_stream->last_error = result.GetErrorObject(); + return -1; + } + if (result.type == QueryResultType::STREAM_RESULT) { + auto &stream_result = result.Cast(); + if (!stream_result.IsOpen()) { + my_stream->last_error = PreservedError("Query Stream is closed"); + return -1; + } + } + if (my_stream->column_types.empty()) { + my_stream->column_types = result.types; + my_stream->column_names = result.names; + } + ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names, + my_stream->result->client_properties); + return 0; +} + +int ResultArrowArrayStreamWrapper::MyStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) { + if (!stream->release) { + return -1; + } + auto my_stream = reinterpret_cast(stream->private_data); + auto &result = *my_stream->result; + auto &scan_state = *my_stream->scan_state; + if (result.HasError()) { + my_stream->last_error = result.GetErrorObject(); + return -1; + } + if (result.type == QueryResultType::STREAM_RESULT) { + auto &stream_result = result.Cast(); + if (!stream_result.IsOpen()) { + // Nothing to output + out->release = nullptr; + return 0; + } + } + if (my_stream->column_types.empty()) { + my_stream->column_types = result.types; + my_stream->column_names = result.names; + } + idx_t result_count; + PreservedError error; + if (!ArrowUtil::TryFetchChunk(scan_state, result.client_properties, my_stream->batch_size, out, result_count, + error)) { + D_ASSERT(error); + my_stream->last_error = error; + return -1; + } + if (result_count == 0) { + // Nothing to output + out->release = nullptr; + } + return 0; +} + +void ResultArrowArrayStreamWrapper::MyStreamRelease(struct ArrowArrayStream *stream) { + if (!stream || !stream->release) { + return; + } + stream->release = nullptr; + delete reinterpret_cast(stream->private_data); +} + +const char *ResultArrowArrayStreamWrapper::MyStreamGetLastError(struct ArrowArrayStream *stream) { + if (!stream->release) { + return "stream was released"; + } + D_ASSERT(stream->private_data); + auto my_stream = reinterpret_cast(stream->private_data); + return my_stream->last_error.Message().c_str(); +} + +ResultArrowArrayStreamWrapper::ResultArrowArrayStreamWrapper(unique_ptr result_p, idx_t batch_size_p) + : result(std::move(result_p)), scan_state(make_uniq(*result)) { + //! We first initialize the private data of the stream + stream.private_data = this; + //! Ceil Approx_Batch_Size/STANDARD_VECTOR_SIZE + if (batch_size_p == 0) { + throw std::runtime_error("Approximate Batch Size of Record Batch MUST be higher than 0"); + } + batch_size = batch_size_p; + //! We initialize the stream functions + stream.get_schema = ResultArrowArrayStreamWrapper::MyStreamGetSchema; + stream.get_next = ResultArrowArrayStreamWrapper::MyStreamGetNext; + stream.release = ResultArrowArrayStreamWrapper::MyStreamRelease; + stream.get_last_error = ResultArrowArrayStreamWrapper::MyStreamGetLastError; +} + +bool ArrowUtil::TryFetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t batch_size, ArrowArray *out, + idx_t &count, PreservedError &error) { + count = 0; + ArrowAppender appender(scan_state.Types(), batch_size, std::move(options)); + auto remaining_tuples_in_chunk = scan_state.RemainingInChunk(); + if (remaining_tuples_in_chunk) { + // We start by scanning the non-finished current chunk + idx_t cur_consumption = MinValue(remaining_tuples_in_chunk, batch_size); + count += cur_consumption; + auto ¤t_chunk = scan_state.CurrentChunk(); + appender.Append(current_chunk, scan_state.CurrentOffset(), scan_state.CurrentOffset() + cur_consumption, + current_chunk.size()); + scan_state.IncreaseOffset(cur_consumption); + } + while (count < batch_size) { + if (!scan_state.LoadNextChunk(error)) { + if (scan_state.HasError()) { + error = scan_state.GetError(); + } + return false; + } + if (scan_state.ChunkIsEmpty()) { + // The scan was successful, but an empty chunk was returned + break; + } + auto ¤t_chunk = scan_state.CurrentChunk(); + if (scan_state.Finished() || current_chunk.size() == 0) { + break; + } + // The amount we still need to append into this chunk + auto remaining = batch_size - count; + + // The amount remaining, capped by the amount left in the current chunk + auto to_append_to_batch = MinValue(remaining, scan_state.RemainingInChunk()); + appender.Append(current_chunk, 0, to_append_to_batch, current_chunk.size()); + count += to_append_to_batch; + scan_state.IncreaseOffset(to_append_to_batch); + } + if (count > 0) { + *out = appender.Finalize(); + } + return true; +} + +idx_t ArrowUtil::FetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t chunk_size, ArrowArray *out) { + PreservedError error; + idx_t result_count; + if (!TryFetchChunk(scan_state, std::move(options), chunk_size, out, result_count, error)) { + error.Throw(); + } + return result_count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/assert.cpp b/src/duckdb/src/common/assert.cpp new file mode 100644 index 00000000..3397fce2 --- /dev/null +++ b/src/duckdb/src/common/assert.cpp @@ -0,0 +1,17 @@ +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +void DuckDBAssertInternal(bool condition, const char *condition_name, const char *file, int linenr) { +#ifdef DISABLE_ASSERTIONS + return; +#endif + if (condition) { + return; + } + throw InternalException("Assertion triggered in file \"%s\" on line %d: %s%s", file, linenr, condition_name, + Exception::GetStackTrace()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/bind_helpers.cpp b/src/duckdb/src/common/bind_helpers.cpp new file mode 100644 index 00000000..9074b3e7 --- /dev/null +++ b/src/duckdb/src/common/bind_helpers.cpp @@ -0,0 +1,122 @@ +#include "duckdb/common/bind_helpers.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include + +namespace duckdb { + +Value ConvertVectorToValue(vector set) { + if (set.empty()) { + return Value::EMPTYLIST(LogicalType::BOOLEAN); + } + return Value::LIST(std::move(set)); +} + +vector ParseColumnList(const vector &set, vector &names, const string &loption) { + vector result; + + if (set.empty()) { + throw BinderException("\"%s\" expects a column list or * as parameter", loption); + } + // list of options: parse the list + case_insensitive_map_t option_map; + for (idx_t i = 0; i < set.size(); i++) { + option_map[set[i].ToString()] = false; + } + result.resize(names.size(), false); + for (idx_t i = 0; i < names.size(); i++) { + auto entry = option_map.find(names[i]); + if (entry != option_map.end()) { + result[i] = true; + entry->second = true; + } + } + for (auto &entry : option_map) { + if (!entry.second) { + throw BinderException("\"%s\" expected to find %s, but it was not found in the table", loption, + entry.first.c_str()); + } + } + return result; +} + +vector ParseColumnList(const Value &value, vector &names, const string &loption) { + vector result; + + // Only accept a list of arguments + if (value.type().id() != LogicalTypeId::LIST) { + // Support a single argument if it's '*' + if (value.type().id() == LogicalTypeId::VARCHAR && value.GetValue() == "*") { + result.resize(names.size(), true); + return result; + } + throw BinderException("\"%s\" expects a column list or * as parameter", loption); + } + auto &children = ListValue::GetChildren(value); + // accept '*' as single argument + if (children.size() == 1 && children[0].type().id() == LogicalTypeId::VARCHAR && + children[0].GetValue() == "*") { + result.resize(names.size(), true); + return result; + } + return ParseColumnList(children, names, loption); +} + +vector ParseColumnsOrdered(const vector &set, vector &names, const string &loption) { + vector result; + + if (set.empty()) { + throw BinderException("\"%s\" expects a column list or * as parameter", loption); + } + + // Maps option to bool indicating if its found and the index in the original set + case_insensitive_map_t> option_map; + for (idx_t i = 0; i < set.size(); i++) { + option_map[set[i].ToString()] = {false, i}; + } + result.resize(option_map.size()); + + for (idx_t i = 0; i < names.size(); i++) { + auto entry = option_map.find(names[i]); + if (entry != option_map.end()) { + result[entry->second.second] = i; + entry->second.first = true; + } + } + for (auto &entry : option_map) { + if (!entry.second.first) { + throw BinderException("\"%s\" expected to find %s, but it was not found in the table", loption, + entry.first.c_str()); + } + } + return result; +} + +vector ParseColumnsOrdered(const Value &value, vector &names, const string &loption) { + vector result; + + // Only accept a list of arguments + if (value.type().id() != LogicalTypeId::LIST) { + // Support a single argument if it's '*' + if (value.type().id() == LogicalTypeId::VARCHAR && value.GetValue() == "*") { + result.resize(names.size(), 0); + std::iota(std::begin(result), std::end(result), 0); + return result; + } + throw BinderException("\"%s\" expects a column list or * as parameter", loption); + } + auto &children = ListValue::GetChildren(value); + // accept '*' as single argument + if (children.size() == 1 && children[0].type().id() == LogicalTypeId::VARCHAR && + children[0].GetValue() == "*") { + result.resize(names.size(), 0); + std::iota(std::begin(result), std::end(result), 0); + return result; + } + return ParseColumnsOrdered(children, names, loption); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/box_renderer.cpp b/src/duckdb/src/common/box_renderer.cpp new file mode 100644 index 00000000..bdb4afa7 --- /dev/null +++ b/src/duckdb/src/common/box_renderer.cpp @@ -0,0 +1,748 @@ +#include "duckdb/common/box_renderer.hpp" + +#include "duckdb/common/printer.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "utf8proc_wrapper.hpp" + +#include + +namespace duckdb { + +const idx_t BoxRenderer::SPLIT_COLUMN = idx_t(-1); + +BoxRenderer::BoxRenderer(BoxRendererConfig config_p) : config(std::move(config_p)) { +} + +string BoxRenderer::ToString(ClientContext &context, const vector &names, const ColumnDataCollection &result) { + std::stringstream ss; + Render(context, names, result, ss); + return ss.str(); +} + +void BoxRenderer::Print(ClientContext &context, const vector &names, const ColumnDataCollection &result) { + Printer::Print(ToString(context, names, result)); +} + +void BoxRenderer::RenderValue(std::ostream &ss, const string &value, idx_t column_width, + ValueRenderAlignment alignment) { + auto render_width = Utf8Proc::RenderWidth(value); + + const string *render_value = &value; + string small_value; + if (render_width > column_width) { + // the string is too large to fit in this column! + // the size of this column must have been reduced + // figure out how much of this value we can render + idx_t pos = 0; + idx_t current_render_width = config.DOTDOTDOT_LENGTH; + while (pos < value.size()) { + // check if this character fits... + auto char_size = Utf8Proc::RenderWidth(value.c_str(), value.size(), pos); + if (current_render_width + char_size >= column_width) { + // it doesn't! stop + break; + } + // it does! move to the next character + current_render_width += char_size; + pos = Utf8Proc::NextGraphemeCluster(value.c_str(), value.size(), pos); + } + small_value = value.substr(0, pos) + config.DOTDOTDOT; + render_value = &small_value; + render_width = current_render_width; + } + auto padding_count = (column_width - render_width) + 2; + idx_t lpadding; + idx_t rpadding; + switch (alignment) { + case ValueRenderAlignment::LEFT: + lpadding = 1; + rpadding = padding_count - 1; + break; + case ValueRenderAlignment::MIDDLE: + lpadding = padding_count / 2; + rpadding = padding_count - lpadding; + break; + case ValueRenderAlignment::RIGHT: + lpadding = padding_count - 1; + rpadding = 1; + break; + default: + throw InternalException("Unrecognized value renderer alignment"); + } + ss << config.VERTICAL; + ss << string(lpadding, ' '); + ss << *render_value; + ss << string(rpadding, ' '); +} + +string BoxRenderer::RenderType(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return "int8"; + case LogicalTypeId::SMALLINT: + return "int16"; + case LogicalTypeId::INTEGER: + return "int32"; + case LogicalTypeId::BIGINT: + return "int64"; + case LogicalTypeId::HUGEINT: + return "int128"; + case LogicalTypeId::UTINYINT: + return "uint8"; + case LogicalTypeId::USMALLINT: + return "uint16"; + case LogicalTypeId::UINTEGER: + return "uint32"; + case LogicalTypeId::UBIGINT: + return "uint64"; + case LogicalTypeId::LIST: { + auto child = RenderType(ListType::GetChildType(type)); + return child + "[]"; + } + default: + return StringUtil::Lower(type.ToString()); + } +} + +ValueRenderAlignment BoxRenderer::TypeAlignment(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + return ValueRenderAlignment::RIGHT; + default: + return ValueRenderAlignment::LEFT; + } +} + +list BoxRenderer::FetchRenderCollections(ClientContext &context, + const ColumnDataCollection &result, idx_t top_rows, + idx_t bottom_rows) { + auto column_count = result.ColumnCount(); + vector varchar_types; + for (idx_t c = 0; c < column_count; c++) { + varchar_types.emplace_back(LogicalType::VARCHAR); + } + std::list collections; + collections.emplace_back(context, varchar_types); + collections.emplace_back(context, varchar_types); + + auto &top_collection = collections.front(); + auto &bottom_collection = collections.back(); + + DataChunk fetch_result; + fetch_result.Initialize(context, result.Types()); + + DataChunk insert_result; + insert_result.Initialize(context, varchar_types); + + // fetch the top rows from the ColumnDataCollection + idx_t chunk_idx = 0; + idx_t row_idx = 0; + while (row_idx < top_rows) { + fetch_result.Reset(); + insert_result.Reset(); + // fetch the next chunk + result.FetchChunk(chunk_idx, fetch_result); + idx_t insert_count = MinValue(fetch_result.size(), top_rows - row_idx); + + // cast all columns to varchar + for (idx_t c = 0; c < column_count; c++) { + VectorOperations::Cast(context, fetch_result.data[c], insert_result.data[c], insert_count); + } + insert_result.SetCardinality(insert_count); + + // construct the render collection + top_collection.Append(insert_result); + + chunk_idx++; + row_idx += fetch_result.size(); + } + + // fetch the bottom rows from the ColumnDataCollection + row_idx = 0; + chunk_idx = result.ChunkCount() - 1; + while (row_idx < bottom_rows) { + fetch_result.Reset(); + insert_result.Reset(); + // fetch the next chunk + result.FetchChunk(chunk_idx, fetch_result); + idx_t insert_count = MinValue(fetch_result.size(), bottom_rows - row_idx); + + // invert the rows + SelectionVector inverted_sel(insert_count); + for (idx_t r = 0; r < insert_count; r++) { + inverted_sel.set_index(r, fetch_result.size() - r - 1); + } + + for (idx_t c = 0; c < column_count; c++) { + Vector slice(fetch_result.data[c], inverted_sel, insert_count); + VectorOperations::Cast(context, slice, insert_result.data[c], insert_count); + } + insert_result.SetCardinality(insert_count); + // construct the render collection + bottom_collection.Append(insert_result); + + chunk_idx--; + row_idx += fetch_result.size(); + } + return collections; +} + +list BoxRenderer::PivotCollections(ClientContext &context, list input, + vector &column_names, + vector &result_types, idx_t row_count) { + auto &top = input.front(); + auto &bottom = input.back(); + + vector varchar_types; + vector new_names; + new_names.emplace_back("Column"); + new_names.emplace_back("Type"); + varchar_types.emplace_back(LogicalType::VARCHAR); + varchar_types.emplace_back(LogicalType::VARCHAR); + for (idx_t r = 0; r < top.Count(); r++) { + new_names.emplace_back("Row " + to_string(r + 1)); + varchar_types.emplace_back(LogicalType::VARCHAR); + } + for (idx_t r = 0; r < bottom.Count(); r++) { + auto row_index = row_count - bottom.Count() + r + 1; + new_names.emplace_back("Row " + to_string(row_index)); + varchar_types.emplace_back(LogicalType::VARCHAR); + } + // + DataChunk row_chunk; + row_chunk.Initialize(Allocator::DefaultAllocator(), varchar_types); + std::list result; + result.emplace_back(context, varchar_types); + result.emplace_back(context, varchar_types); + auto &res_coll = result.front(); + ColumnDataAppendState append_state; + res_coll.InitializeAppend(append_state); + for (idx_t c = 0; c < top.ColumnCount(); c++) { + vector column_ids {c}; + auto row_index = row_chunk.size(); + idx_t current_index = 0; + row_chunk.SetValue(current_index++, row_index, column_names[c]); + row_chunk.SetValue(current_index++, row_index, RenderType(result_types[c])); + for (auto &collection : input) { + for (auto &chunk : collection.Chunks(column_ids)) { + for (idx_t r = 0; r < chunk.size(); r++) { + row_chunk.SetValue(current_index++, row_index, chunk.GetValue(0, r)); + } + } + } + row_chunk.SetCardinality(row_chunk.size() + 1); + if (row_chunk.size() == STANDARD_VECTOR_SIZE || c + 1 == top.ColumnCount()) { + res_coll.Append(append_state, row_chunk); + row_chunk.Reset(); + } + } + column_names = std::move(new_names); + result_types = std::move(varchar_types); + return result; +} + +string ConvertRenderValue(const string &input) { + return StringUtil::Replace(StringUtil::Replace(input, "\n", "\\n"), string("\0", 1), "\\0"); +} + +string BoxRenderer::GetRenderValue(ColumnDataRowCollection &rows, idx_t c, idx_t r) { + try { + auto row = rows.GetValue(c, r); + if (row.IsNull()) { + return config.null_value; + } + return ConvertRenderValue(StringValue::Get(row)); + } catch (std::exception &ex) { + return "????INVALID VALUE - " + string(ex.what()) + "?????"; + } +} + +vector BoxRenderer::ComputeRenderWidths(const vector &names, const vector &result_types, + list &collections, idx_t min_width, + idx_t max_width, vector &column_map, idx_t &total_length) { + auto column_count = result_types.size(); + + vector widths; + widths.reserve(column_count); + for (idx_t c = 0; c < column_count; c++) { + auto name_width = Utf8Proc::RenderWidth(ConvertRenderValue(names[c])); + auto type_width = Utf8Proc::RenderWidth(RenderType(result_types[c])); + widths.push_back(MaxValue(name_width, type_width)); + } + + // now iterate over the data in the render collection and find out the true max width + for (auto &collection : collections) { + for (auto &chunk : collection.Chunks()) { + for (idx_t c = 0; c < column_count; c++) { + auto string_data = FlatVector::GetData(chunk.data[c]); + for (idx_t r = 0; r < chunk.size(); r++) { + string render_value; + if (FlatVector::IsNull(chunk.data[c], r)) { + render_value = config.null_value; + } else { + render_value = ConvertRenderValue(string_data[r].GetString()); + } + auto render_width = Utf8Proc::RenderWidth(render_value); + widths[c] = MaxValue(render_width, widths[c]); + } + } + } + } + + // figure out the total length + // we start off with a pipe (|) + total_length = 1; + for (idx_t c = 0; c < widths.size(); c++) { + // each column has a space at the beginning, and a space plus a pipe (|) at the end + // hence + 3 + total_length += widths[c] + 3; + } + if (total_length < min_width) { + // if there are hidden rows we should always display that + // stretch up the first column until we have space to show the row count + widths[0] += min_width - total_length; + total_length = min_width; + } + // now we need to constrain the length + unordered_set pruned_columns; + if (total_length > max_width) { + // before we remove columns, check if we can just reduce the size of columns + for (auto &w : widths) { + if (w > config.max_col_width) { + auto max_diff = w - config.max_col_width; + if (total_length - max_diff <= max_width) { + // if we reduce the size of this column we fit within the limits! + // reduce the width exactly enough so that the box fits + w -= total_length - max_width; + total_length = max_width; + break; + } else { + // reducing the width of this column does not make the result fit + // reduce the column width by the maximum amount anyway + w = config.max_col_width; + total_length -= max_diff; + } + } + } + + if (total_length > max_width) { + // the total length is still too large + // we need to remove columns! + // first, we add 6 characters to the total length + // this is what we need to add the "..." in the middle + total_length += 3 + config.DOTDOTDOT_LENGTH; + // now select columns to prune + // we select columns in zig-zag order starting from the middle + // e.g. if we have 10 columns, we remove #5, then #4, then #6, then #3, then #7, etc + int64_t offset = 0; + while (total_length > max_width) { + idx_t c = column_count / 2 + offset; + total_length -= widths[c] + 3; + pruned_columns.insert(c); + if (offset >= 0) { + offset = -offset - 1; + } else { + offset = -offset; + } + } + } + } + + bool added_split_column = false; + vector new_widths; + for (idx_t c = 0; c < column_count; c++) { + if (pruned_columns.find(c) == pruned_columns.end()) { + column_map.push_back(c); + new_widths.push_back(widths[c]); + } else { + if (!added_split_column) { + // "..." + column_map.push_back(SPLIT_COLUMN); + new_widths.push_back(config.DOTDOTDOT_LENGTH); + added_split_column = true; + } + } + } + return new_widths; +} + +void BoxRenderer::RenderHeader(const vector &names, const vector &result_types, + const vector &column_map, const vector &widths, + const vector &boundaries, idx_t total_length, bool has_results, + std::ostream &ss) { + auto column_count = column_map.size(); + // render the top line + ss << config.LTCORNER; + idx_t column_index = 0; + for (idx_t k = 0; k < total_length - 2; k++) { + if (column_index + 1 < column_count && k == boundaries[column_index]) { + ss << config.TMIDDLE; + column_index++; + } else { + ss << config.HORIZONTAL; + } + } + ss << config.RTCORNER; + ss << std::endl; + + // render the header names + for (idx_t c = 0; c < column_count; c++) { + auto column_idx = column_map[c]; + string name; + if (column_idx == SPLIT_COLUMN) { + name = config.DOTDOTDOT; + } else { + name = ConvertRenderValue(names[column_idx]); + } + RenderValue(ss, name, widths[c]); + } + ss << config.VERTICAL; + ss << std::endl; + + // render the types + if (config.render_mode == RenderMode::ROWS) { + for (idx_t c = 0; c < column_count; c++) { + auto column_idx = column_map[c]; + auto type = column_idx == SPLIT_COLUMN ? "" : RenderType(result_types[column_idx]); + RenderValue(ss, type, widths[c]); + } + ss << config.VERTICAL; + ss << std::endl; + } + + // render the line under the header + ss << config.LMIDDLE; + column_index = 0; + for (idx_t k = 0; k < total_length - 2; k++) { + if (has_results && column_index + 1 < column_count && k == boundaries[column_index]) { + ss << config.MIDDLE; + column_index++; + } else { + ss << config.HORIZONTAL; + } + } + ss << config.RMIDDLE; + ss << std::endl; +} + +void BoxRenderer::RenderValues(const list &collections, const vector &column_map, + const vector &widths, const vector &result_types, std::ostream &ss) { + auto &top_collection = collections.front(); + auto &bottom_collection = collections.back(); + // render the top rows + auto top_rows = top_collection.Count(); + auto bottom_rows = bottom_collection.Count(); + auto column_count = column_map.size(); + + vector alignments; + if (config.render_mode == RenderMode::ROWS) { + for (idx_t c = 0; c < column_count; c++) { + auto column_idx = column_map[c]; + if (column_idx == SPLIT_COLUMN) { + alignments.push_back(ValueRenderAlignment::MIDDLE); + } else { + alignments.push_back(TypeAlignment(result_types[column_idx])); + } + } + } + + auto rows = top_collection.GetRows(); + for (idx_t r = 0; r < top_rows; r++) { + for (idx_t c = 0; c < column_count; c++) { + auto column_idx = column_map[c]; + string str; + if (column_idx == SPLIT_COLUMN) { + str = config.DOTDOTDOT; + } else { + str = GetRenderValue(rows, column_idx, r); + } + ValueRenderAlignment alignment; + if (config.render_mode == RenderMode::ROWS) { + alignment = alignments[c]; + } else { + if (c < 2) { + alignment = ValueRenderAlignment::LEFT; + } else if (c == SPLIT_COLUMN) { + alignment = ValueRenderAlignment::MIDDLE; + } else { + alignment = ValueRenderAlignment::RIGHT; + } + } + RenderValue(ss, str, widths[c], alignment); + } + ss << config.VERTICAL; + ss << std::endl; + } + + if (bottom_rows > 0) { + if (config.render_mode == RenderMode::COLUMNS) { + throw InternalException("Columns render mode does not support bottom rows"); + } + // render the bottom rows + // first render the divider + auto brows = bottom_collection.GetRows(); + for (idx_t k = 0; k < 3; k++) { + for (idx_t c = 0; c < column_count; c++) { + auto column_idx = column_map[c]; + string str; + auto alignment = alignments[c]; + if (alignment == ValueRenderAlignment::MIDDLE || column_idx == SPLIT_COLUMN) { + str = config.DOT; + } else { + // align the dots in the center of the column + auto top_value = GetRenderValue(rows, column_idx, top_rows - 1); + auto bottom_value = GetRenderValue(brows, column_idx, bottom_rows - 1); + auto top_length = MinValue(widths[c], Utf8Proc::RenderWidth(top_value)); + auto bottom_length = MinValue(widths[c], Utf8Proc::RenderWidth(bottom_value)); + auto dot_length = MinValue(top_length, bottom_length); + if (top_length == 0) { + dot_length = bottom_length; + } else if (bottom_length == 0) { + dot_length = top_length; + } + if (dot_length > 1) { + auto padding = dot_length - 1; + idx_t left_padding, right_padding; + switch (alignment) { + case ValueRenderAlignment::LEFT: + left_padding = padding / 2; + right_padding = padding - left_padding; + break; + case ValueRenderAlignment::RIGHT: + right_padding = padding / 2; + left_padding = padding - right_padding; + break; + default: + throw InternalException("Unrecognized value renderer alignment"); + } + str = string(left_padding, ' ') + config.DOT + string(right_padding, ' '); + } else { + if (dot_length == 0) { + // everything is empty + alignment = ValueRenderAlignment::MIDDLE; + } + str = config.DOT; + } + } + RenderValue(ss, str, widths[c], alignment); + } + ss << config.VERTICAL; + ss << std::endl; + } + // note that the bottom rows are in reverse order + for (idx_t r = 0; r < bottom_rows; r++) { + for (idx_t c = 0; c < column_count; c++) { + auto column_idx = column_map[c]; + string str; + if (column_idx == SPLIT_COLUMN) { + str = config.DOTDOTDOT; + } else { + str = GetRenderValue(brows, column_idx, bottom_rows - r - 1); + } + RenderValue(ss, str, widths[c], alignments[c]); + } + ss << config.VERTICAL; + ss << std::endl; + } + } +} + +void BoxRenderer::RenderRowCount(string row_count_str, string shown_str, const string &column_count_str, + const vector &boundaries, bool has_hidden_rows, bool has_hidden_columns, + idx_t total_length, idx_t row_count, idx_t column_count, idx_t minimum_row_length, + std::ostream &ss) { + // check if we can merge the row_count_str and the shown_str + bool display_shown_separately = has_hidden_rows; + if (has_hidden_rows && total_length >= row_count_str.size() + shown_str.size() + 5) { + // we can! + row_count_str += " " + shown_str; + shown_str = string(); + display_shown_separately = false; + minimum_row_length = row_count_str.size() + 4; + } + auto minimum_length = row_count_str.size() + column_count_str.size() + 6; + bool render_rows_and_columns = total_length >= minimum_length && + ((has_hidden_columns && row_count > 0) || (row_count >= 10 && column_count > 1)); + bool render_rows = total_length >= minimum_row_length && (row_count == 0 || row_count >= 10); + bool render_anything = true; + if (!render_rows && !render_rows_and_columns) { + render_anything = false; + } + // render the bottom of the result values, if there are any + if (row_count > 0) { + ss << (render_anything ? config.LMIDDLE : config.LDCORNER); + idx_t column_index = 0; + for (idx_t k = 0; k < total_length - 2; k++) { + if (column_index + 1 < boundaries.size() && k == boundaries[column_index]) { + ss << config.DMIDDLE; + column_index++; + } else { + ss << config.HORIZONTAL; + } + } + ss << (render_anything ? config.RMIDDLE : config.RDCORNER); + ss << std::endl; + } + if (!render_anything) { + return; + } + + if (render_rows_and_columns) { + ss << config.VERTICAL; + ss << " "; + ss << row_count_str; + ss << string(total_length - row_count_str.size() - column_count_str.size() - 4, ' '); + ss << column_count_str; + ss << " "; + ss << config.VERTICAL; + ss << std::endl; + } else if (render_rows) { + RenderValue(ss, row_count_str, total_length - 4); + ss << config.VERTICAL; + ss << std::endl; + + if (display_shown_separately) { + RenderValue(ss, shown_str, total_length - 4); + ss << config.VERTICAL; + ss << std::endl; + } + } + // render the bottom line + ss << config.LDCORNER; + for (idx_t k = 0; k < total_length - 2; k++) { + ss << config.HORIZONTAL; + } + ss << config.RDCORNER; + ss << std::endl; +} + +void BoxRenderer::Render(ClientContext &context, const vector &names, const ColumnDataCollection &result, + std::ostream &ss) { + if (result.ColumnCount() != names.size()) { + throw InternalException("Error in BoxRenderer::Render - unaligned columns and names"); + } + auto max_width = config.max_width; + if (max_width == 0) { + if (Printer::IsTerminal(OutputStream::STREAM_STDOUT)) { + max_width = Printer::TerminalWidth(); + } else { + max_width = 120; + } + } + // we do not support max widths under 80 + max_width = MaxValue(80, max_width); + + // figure out how many/which rows to render + idx_t row_count = result.Count(); + idx_t rows_to_render = MinValue(row_count, config.max_rows); + if (row_count <= config.max_rows + 3) { + // hiding rows adds 3 extra rows + // so hiding rows makes no sense if we are only slightly over the limit + // if we are 1 row over the limit hiding rows will actually increase the number of lines we display! + // in this case render all the rows + rows_to_render = row_count; + } + idx_t top_rows; + idx_t bottom_rows; + if (rows_to_render == row_count) { + top_rows = row_count; + bottom_rows = 0; + } else { + top_rows = rows_to_render / 2 + (rows_to_render % 2 != 0 ? 1 : 0); + bottom_rows = rows_to_render - top_rows; + } + auto row_count_str = to_string(row_count) + " rows"; + bool has_limited_rows = config.limit > 0 && row_count == config.limit; + if (has_limited_rows) { + row_count_str = "? rows"; + } + string shown_str; + bool has_hidden_rows = top_rows < row_count; + if (has_hidden_rows) { + shown_str = "("; + if (has_limited_rows) { + shown_str += ">" + to_string(config.limit - 1) + " rows, "; + } + shown_str += to_string(top_rows + bottom_rows) + " shown)"; + } + auto minimum_row_length = MaxValue(row_count_str.size(), shown_str.size()) + 4; + + // fetch the top and bottom render collections from the result + auto collections = FetchRenderCollections(context, result, top_rows, bottom_rows); + auto column_names = names; + auto result_types = result.Types(); + if (config.render_mode == RenderMode::COLUMNS) { + collections = PivotCollections(context, std::move(collections), column_names, result_types, row_count); + } + + // for each column, figure out the width + // start off by figuring out the name of the header by looking at the column name and column type + idx_t min_width = has_hidden_rows || row_count == 0 ? minimum_row_length : 0; + vector column_map; + idx_t total_length; + auto widths = + ComputeRenderWidths(column_names, result_types, collections, min_width, max_width, column_map, total_length); + + // render boundaries for the individual columns + vector boundaries; + for (idx_t c = 0; c < widths.size(); c++) { + idx_t render_boundary; + if (c == 0) { + render_boundary = widths[c] + 2; + } else { + render_boundary = boundaries[c - 1] + widths[c] + 3; + } + boundaries.push_back(render_boundary); + } + + // now begin rendering + // first render the header + RenderHeader(column_names, result_types, column_map, widths, boundaries, total_length, row_count > 0, ss); + + // render the values, if there are any + RenderValues(collections, column_map, widths, result_types, ss); + + // render the row count and column count + auto column_count_str = to_string(result.ColumnCount()) + " column"; + if (result.ColumnCount() > 1) { + column_count_str += "s"; + } + bool has_hidden_columns = false; + for (auto entry : column_map) { + if (entry == SPLIT_COLUMN) { + has_hidden_columns = true; + break; + } + } + idx_t column_count = column_map.size(); + if (config.render_mode == RenderMode::COLUMNS) { + if (has_hidden_columns) { + has_hidden_rows = true; + shown_str = " (" + to_string(column_count - 3) + " shown)"; + } else { + shown_str = string(); + } + } else { + if (has_hidden_columns) { + column_count--; + column_count_str += " (" + to_string(column_count) + " shown)"; + } + } + + RenderRowCount(std::move(row_count_str), std::move(shown_str), column_count_str, boundaries, has_hidden_rows, + has_hidden_columns, total_length, row_count, column_count, minimum_row_length, ss); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/checksum.cpp b/src/duckdb/src/common/checksum.cpp new file mode 100644 index 00000000..2fbca299 --- /dev/null +++ b/src/duckdb/src/common/checksum.cpp @@ -0,0 +1,25 @@ +#include "duckdb/common/checksum.hpp" +#include "duckdb/common/types/hash.hpp" + +namespace duckdb { + +hash_t Checksum(uint64_t x) { + return x * UINT64_C(0xbf58476d1ce4e5b9); +} + +uint64_t Checksum(uint8_t *buffer, size_t size) { + uint64_t result = 5381; + uint64_t *ptr = reinterpret_cast(buffer); + size_t i; + // for efficiency, we first checksum uint64_t values + for (i = 0; i < size / 8; i++) { + result ^= Checksum(ptr[i]); + } + if (size - i * 8 > 0) { + // the remaining 0-7 bytes we hash using a string hash + result ^= Hash(buffer + i * 8, size - i * 8); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/compressed_file_system.cpp b/src/duckdb/src/common/compressed_file_system.cpp new file mode 100644 index 00000000..a9d28933 --- /dev/null +++ b/src/duckdb/src/common/compressed_file_system.cpp @@ -0,0 +1,147 @@ +#include "duckdb/common/compressed_file_system.hpp" + +namespace duckdb { + +StreamWrapper::~StreamWrapper() { +} + +CompressedFile::CompressedFile(CompressedFileSystem &fs, unique_ptr child_handle_p, const string &path) + : FileHandle(fs, path), compressed_fs(fs), child_handle(std::move(child_handle_p)) { +} + +CompressedFile::~CompressedFile() { + CompressedFile::Close(); +} + +void CompressedFile::Initialize(bool write) { + Close(); + + this->write = write; + stream_data.in_buf_size = compressed_fs.InBufferSize(); + stream_data.out_buf_size = compressed_fs.OutBufferSize(); + stream_data.in_buff = make_unsafe_uniq_array(stream_data.in_buf_size); + stream_data.in_buff_start = stream_data.in_buff.get(); + stream_data.in_buff_end = stream_data.in_buff.get(); + stream_data.out_buff = make_unsafe_uniq_array(stream_data.out_buf_size); + stream_data.out_buff_start = stream_data.out_buff.get(); + stream_data.out_buff_end = stream_data.out_buff.get(); + + stream_wrapper = compressed_fs.CreateStream(); + stream_wrapper->Initialize(*this, write); +} + +int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { + idx_t total_read = 0; + while (true) { + // first check if there are input bytes available in the output buffers + if (stream_data.out_buff_start != stream_data.out_buff_end) { + // there is! copy it into the output buffer + idx_t available = MinValue(remaining, stream_data.out_buff_end - stream_data.out_buff_start); + memcpy(data_ptr_t(buffer) + total_read, stream_data.out_buff_start, available); + + // increment the total read variables as required + stream_data.out_buff_start += available; + total_read += available; + remaining -= available; + if (remaining == 0) { + // done! read enough + return total_read; + } + } + if (!stream_wrapper) { + return total_read; + } + + // ran out of buffer: read more data from the child stream + stream_data.out_buff_start = stream_data.out_buff.get(); + stream_data.out_buff_end = stream_data.out_buff.get(); + D_ASSERT(stream_data.in_buff_start <= stream_data.in_buff_end); + D_ASSERT(stream_data.in_buff_end <= stream_data.in_buff_start + stream_data.in_buf_size); + + // read more input when requested and still data in the input stream + if (stream_data.refresh && (stream_data.in_buff_end == stream_data.in_buff.get() + stream_data.in_buf_size)) { + auto bufrem = stream_data.in_buff_end - stream_data.in_buff_start; + // buffer not empty, move remaining bytes to the beginning + memmove(stream_data.in_buff.get(), stream_data.in_buff_start, bufrem); + stream_data.in_buff_start = stream_data.in_buff.get(); + // refill the rest of input buffer + auto sz = child_handle->Read(stream_data.in_buff_start + bufrem, stream_data.in_buf_size - bufrem); + stream_data.in_buff_end = stream_data.in_buff_start + bufrem + sz; + if (sz <= 0) { + stream_wrapper.reset(); + break; + } + } + + // read more input if none available + if (stream_data.in_buff_start == stream_data.in_buff_end) { + // empty input buffer: refill from the start + stream_data.in_buff_start = stream_data.in_buff.get(); + stream_data.in_buff_end = stream_data.in_buff_start; + auto sz = child_handle->Read(stream_data.in_buff.get(), stream_data.in_buf_size); + if (sz <= 0) { + stream_wrapper.reset(); + break; + } + stream_data.in_buff_end = stream_data.in_buff_start + sz; + } + + auto finished = stream_wrapper->Read(stream_data); + if (finished) { + stream_wrapper.reset(); + } + } + return total_read; +} + +int64_t CompressedFile::WriteData(data_ptr_t buffer, int64_t nr_bytes) { + stream_wrapper->Write(*this, stream_data, buffer, nr_bytes); + return nr_bytes; +} + +void CompressedFile::Close() { + if (stream_wrapper) { + stream_wrapper->Close(); + stream_wrapper.reset(); + } + stream_data.in_buff.reset(); + stream_data.out_buff.reset(); + stream_data.out_buff_start = nullptr; + stream_data.out_buff_end = nullptr; + stream_data.in_buff_start = nullptr; + stream_data.in_buff_end = nullptr; + stream_data.in_buf_size = 0; + stream_data.out_buf_size = 0; +} + +int64_t CompressedFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + auto &compressed_file = handle.Cast(); + return compressed_file.ReadData(buffer, nr_bytes); +} + +int64_t CompressedFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + auto &compressed_file = handle.Cast(); + return compressed_file.WriteData(data_ptr_cast(buffer), nr_bytes); +} + +void CompressedFileSystem::Reset(FileHandle &handle) { + auto &compressed_file = handle.Cast(); + compressed_file.child_handle->Reset(); + compressed_file.Initialize(compressed_file.write); +} + +int64_t CompressedFileSystem::GetFileSize(FileHandle &handle) { + auto &compressed_file = handle.Cast(); + return compressed_file.child_handle->GetFileSize(); +} + +bool CompressedFileSystem::OnDiskFile(FileHandle &handle) { + auto &compressed_file = handle.Cast(); + return compressed_file.child_handle->OnDiskFile(); +} + +bool CompressedFileSystem::CanSeek() { + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/constants.cpp b/src/duckdb/src/common/constants.cpp new file mode 100644 index 00000000..0e2461c2 --- /dev/null +++ b/src/duckdb/src/common/constants.cpp @@ -0,0 +1,52 @@ +#include "duckdb/common/constants.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/vector_size.hpp" + +namespace duckdb { + +constexpr const idx_t DConstants::INVALID_INDEX; +const row_t MAX_ROW_ID = 36028797018960000ULL; // 2^55 +const row_t MAX_ROW_ID_LOCAL = 72057594037920000ULL; // 2^56 +const column_t COLUMN_IDENTIFIER_ROW_ID = (column_t)-1; +const sel_t ZERO_VECTOR[STANDARD_VECTOR_SIZE] = {0}; +const double PI = 3.141592653589793; + +const transaction_t TRANSACTION_ID_START = 4611686018427388000ULL; // 2^62 +const transaction_t MAX_TRANSACTION_ID = NumericLimits::Maximum(); // 2^63 +const transaction_t NOT_DELETED_ID = NumericLimits::Maximum() - 1; // 2^64 - 1 +const transaction_t MAXIMUM_QUERY_ID = NumericLimits::Maximum(); // 2^64 + +bool IsPowerOfTwo(uint64_t v) { + return (v & (v - 1)) == 0; +} + +uint64_t NextPowerOfTwo(uint64_t v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} + +uint64_t PreviousPowerOfTwo(uint64_t v) { + return NextPowerOfTwo((v / 2) + 1); +} + +bool IsInvalidSchema(const string &str) { + return str.empty(); +} + +bool IsInvalidCatalog(const string &str) { + return str.empty(); +} + +bool IsRowIdColumnId(column_t column_id) { + return column_id == COLUMN_IDENTIFIER_ROW_ID; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/crypto/md5.cpp b/src/duckdb/src/common/crypto/md5.cpp new file mode 100644 index 00000000..228fa8e2 --- /dev/null +++ b/src/duckdb/src/common/crypto/md5.cpp @@ -0,0 +1,256 @@ +/* +** This code taken from the SQLite test library. Originally found on +** the internet. The original header comment follows this comment. +** The code is largerly unchanged, but there have been some modifications. +*/ +/* + * This code implements the MD5 message-digest algorithm. + * The algorithm is due to Ron Rivest. This code was + * written by Colin Plumb in 1993, no copyright is claimed. + * This code is in the public domain; do with it what you wish. + * + * Equivalent code is available from RSA Data Security, Inc. + * This code has been tested against that, and is equivalent, + * except that you don't need to include two pages of legalese + * with every copy. + * + * To compute the message digest of a chunk of bytes, declare an + * MD5Context structure, pass it to MD5Init, call MD5Update as + * needed on buffers full of bytes, and then call MD5Final, which + * will fill a supplied 16-byte array with the digest. + */ +#include "duckdb/common/crypto/md5.hpp" +#include "mbedtls_wrapper.hpp" + +namespace duckdb { + +/* + * Note: this code is harmless on little-endian machines. + */ +static void ByteReverse(unsigned char *buf, unsigned longs) { + uint32_t t; + do { + t = (uint32_t)((unsigned)buf[3] << 8 | buf[2]) << 16 | ((unsigned)buf[1] << 8 | buf[0]); + *reinterpret_cast(buf) = t; + buf += 4; + } while (--longs); +} +/* The four core functions - F1 is optimized somewhat */ + +/* #define F1(x, y, z) (x & y | ~x & z) */ +#define F1(x, y, z) ((z) ^ ((x) & ((y) ^ (z)))) +#define F2(x, y, z) F1(z, x, y) +#define F3(x, y, z) ((x) ^ (y) ^ (z)) +#define F4(x, y, z) ((y) ^ ((x) | ~(z))) + +/* This is the central step in the MD5 algorithm. */ +#define MD5STEP(f, w, x, y, z, data, s) ((w) += f(x, y, z) + (data), (w) = (w) << (s) | (w) >> (32 - (s)), (w) += (x)) + +/* + * The core of the MD5 algorithm, this alters an existing MD5 hash to + * reflect the addition of 16 longwords of new data. MD5Update blocks + * the data and converts bytes into longwords for this routine. + */ +static void MD5Transform(uint32_t buf[4], const uint32_t in[16]) { + uint32_t a, b, c, d; + + a = buf[0]; + b = buf[1]; + c = buf[2]; + d = buf[3]; + + MD5STEP(F1, a, b, c, d, in[0] + 0xd76aa478, 7); + MD5STEP(F1, d, a, b, c, in[1] + 0xe8c7b756, 12); + MD5STEP(F1, c, d, a, b, in[2] + 0x242070db, 17); + MD5STEP(F1, b, c, d, a, in[3] + 0xc1bdceee, 22); + MD5STEP(F1, a, b, c, d, in[4] + 0xf57c0faf, 7); + MD5STEP(F1, d, a, b, c, in[5] + 0x4787c62a, 12); + MD5STEP(F1, c, d, a, b, in[6] + 0xa8304613, 17); + MD5STEP(F1, b, c, d, a, in[7] + 0xfd469501, 22); + MD5STEP(F1, a, b, c, d, in[8] + 0x698098d8, 7); + MD5STEP(F1, d, a, b, c, in[9] + 0x8b44f7af, 12); + MD5STEP(F1, c, d, a, b, in[10] + 0xffff5bb1, 17); + MD5STEP(F1, b, c, d, a, in[11] + 0x895cd7be, 22); + MD5STEP(F1, a, b, c, d, in[12] + 0x6b901122, 7); + MD5STEP(F1, d, a, b, c, in[13] + 0xfd987193, 12); + MD5STEP(F1, c, d, a, b, in[14] + 0xa679438e, 17); + MD5STEP(F1, b, c, d, a, in[15] + 0x49b40821, 22); + + MD5STEP(F2, a, b, c, d, in[1] + 0xf61e2562, 5); + MD5STEP(F2, d, a, b, c, in[6] + 0xc040b340, 9); + MD5STEP(F2, c, d, a, b, in[11] + 0x265e5a51, 14); + MD5STEP(F2, b, c, d, a, in[0] + 0xe9b6c7aa, 20); + MD5STEP(F2, a, b, c, d, in[5] + 0xd62f105d, 5); + MD5STEP(F2, d, a, b, c, in[10] + 0x02441453, 9); + MD5STEP(F2, c, d, a, b, in[15] + 0xd8a1e681, 14); + MD5STEP(F2, b, c, d, a, in[4] + 0xe7d3fbc8, 20); + MD5STEP(F2, a, b, c, d, in[9] + 0x21e1cde6, 5); + MD5STEP(F2, d, a, b, c, in[14] + 0xc33707d6, 9); + MD5STEP(F2, c, d, a, b, in[3] + 0xf4d50d87, 14); + MD5STEP(F2, b, c, d, a, in[8] + 0x455a14ed, 20); + MD5STEP(F2, a, b, c, d, in[13] + 0xa9e3e905, 5); + MD5STEP(F2, d, a, b, c, in[2] + 0xfcefa3f8, 9); + MD5STEP(F2, c, d, a, b, in[7] + 0x676f02d9, 14); + MD5STEP(F2, b, c, d, a, in[12] + 0x8d2a4c8a, 20); + + MD5STEP(F3, a, b, c, d, in[5] + 0xfffa3942, 4); + MD5STEP(F3, d, a, b, c, in[8] + 0x8771f681, 11); + MD5STEP(F3, c, d, a, b, in[11] + 0x6d9d6122, 16); + MD5STEP(F3, b, c, d, a, in[14] + 0xfde5380c, 23); + MD5STEP(F3, a, b, c, d, in[1] + 0xa4beea44, 4); + MD5STEP(F3, d, a, b, c, in[4] + 0x4bdecfa9, 11); + MD5STEP(F3, c, d, a, b, in[7] + 0xf6bb4b60, 16); + MD5STEP(F3, b, c, d, a, in[10] + 0xbebfbc70, 23); + MD5STEP(F3, a, b, c, d, in[13] + 0x289b7ec6, 4); + MD5STEP(F3, d, a, b, c, in[0] + 0xeaa127fa, 11); + MD5STEP(F3, c, d, a, b, in[3] + 0xd4ef3085, 16); + MD5STEP(F3, b, c, d, a, in[6] + 0x04881d05, 23); + MD5STEP(F3, a, b, c, d, in[9] + 0xd9d4d039, 4); + MD5STEP(F3, d, a, b, c, in[12] + 0xe6db99e5, 11); + MD5STEP(F3, c, d, a, b, in[15] + 0x1fa27cf8, 16); + MD5STEP(F3, b, c, d, a, in[2] + 0xc4ac5665, 23); + + MD5STEP(F4, a, b, c, d, in[0] + 0xf4292244, 6); + MD5STEP(F4, d, a, b, c, in[7] + 0x432aff97, 10); + MD5STEP(F4, c, d, a, b, in[14] + 0xab9423a7, 15); + MD5STEP(F4, b, c, d, a, in[5] + 0xfc93a039, 21); + MD5STEP(F4, a, b, c, d, in[12] + 0x655b59c3, 6); + MD5STEP(F4, d, a, b, c, in[3] + 0x8f0ccc92, 10); + MD5STEP(F4, c, d, a, b, in[10] + 0xffeff47d, 15); + MD5STEP(F4, b, c, d, a, in[1] + 0x85845dd1, 21); + MD5STEP(F4, a, b, c, d, in[8] + 0x6fa87e4f, 6); + MD5STEP(F4, d, a, b, c, in[15] + 0xfe2ce6e0, 10); + MD5STEP(F4, c, d, a, b, in[6] + 0xa3014314, 15); + MD5STEP(F4, b, c, d, a, in[13] + 0x4e0811a1, 21); + MD5STEP(F4, a, b, c, d, in[4] + 0xf7537e82, 6); + MD5STEP(F4, d, a, b, c, in[11] + 0xbd3af235, 10); + MD5STEP(F4, c, d, a, b, in[2] + 0x2ad7d2bb, 15); + MD5STEP(F4, b, c, d, a, in[9] + 0xeb86d391, 21); + + buf[0] += a; + buf[1] += b; + buf[2] += c; + buf[3] += d; +} + +/* + * Start MD5 accumulation. Set bit count to 0 and buffer to mysterious + * initialization constants. + */ +MD5Context::MD5Context() { + buf[0] = 0x67452301; + buf[1] = 0xefcdab89; + buf[2] = 0x98badcfe; + buf[3] = 0x10325476; + bits[0] = 0; + bits[1] = 0; +} + +/* + * Update context to reflect the concatenation of another buffer full + * of bytes. + */ +void MD5Context::MD5Update(const_data_ptr_t input, idx_t len) { + uint32_t t; + + /* Update bitcount */ + + t = bits[0]; + if ((bits[0] = t + ((uint32_t)len << 3)) < t) { + bits[1]++; /* Carry from low to high */ + } + bits[1] += len >> 29; + + t = (t >> 3) & 0x3f; /* Bytes already in shsInfo->data */ + + /* Handle any leading odd-sized chunks */ + + if (t) { + unsigned char *p = (unsigned char *)in + t; + + t = 64 - t; + if (len < t) { + memcpy(p, input, len); + return; + } + memcpy(p, input, t); + ByteReverse(in, 16); + MD5Transform(buf, reinterpret_cast(in)); + input += t; + len -= t; + } + + /* Process data in 64-byte chunks */ + + while (len >= 64) { + memcpy(in, input, 64); + ByteReverse(in, 16); + MD5Transform(buf, reinterpret_cast(in)); + input += 64; + len -= 64; + } + + /* Handle any remaining bytes of data. */ + memcpy(in, input, len); +} + +/* + * Final wrapup - pad to 64-byte boundary with the bit pattern + * 1 0* (64-bit count of bits processed, MSB-first) + */ +void MD5Context::Finish(data_ptr_t out_digest) { + unsigned count; + unsigned char *p; + + /* Compute number of bytes mod 64 */ + count = (bits[0] >> 3) & 0x3F; + + /* Set the first char of padding to 0x80. This is safe since there is + always at least one byte free */ + p = in + count; + *p++ = 0x80; + + /* Bytes of padding needed to make 64 bytes */ + count = 64 - 1 - count; + + /* Pad out to 56 mod 64 */ + if (count < 8) { + /* Two lots of padding: Pad the first block to 64 bytes */ + memset(p, 0, count); + ByteReverse(in, 16); + MD5Transform(buf, reinterpret_cast(in)); + + /* Now fill the next block with 56 bytes */ + memset(in, 0, 56); + } else { + /* Pad block to 56 bytes */ + memset(p, 0, count - 8); + } + ByteReverse(in, 14); + + /* Append length in bits and transform */ + (reinterpret_cast(in))[14] = bits[0]; + (reinterpret_cast(in))[15] = bits[1]; + + MD5Transform(buf, reinterpret_cast(in)); + ByteReverse(reinterpret_cast(buf), 4); + memcpy(out_digest, buf, 16); +} + +void MD5Context::FinishHex(char *out_digest) { + data_t digest[MD5_HASH_LENGTH_BINARY]; + Finish(digest); + duckdb_mbedtls::MbedTlsWrapper::ToBase16(reinterpret_cast(digest), out_digest, MD5_HASH_LENGTH_BINARY); +} + +string MD5Context::FinishHex() { + char digest[MD5_HASH_LENGTH_TEXT]; + FinishHex(digest); + return string(digest, MD5_HASH_LENGTH_TEXT); +} + +void MD5Context::Add(const char *data) { + MD5Update(const_data_ptr_cast(data), strlen(data)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/cycle_counter.cpp b/src/duckdb/src/common/cycle_counter.cpp new file mode 100644 index 00000000..89d4bfc1 --- /dev/null +++ b/src/duckdb/src/common/cycle_counter.cpp @@ -0,0 +1,76 @@ +// This file is licensed under Apache License 2.0 +// Source code taken from https://github.com/google/benchmark +// It is highly modified + +#include "duckdb/common/cycle_counter.hpp" +#include "duckdb/common/chrono.hpp" + +namespace duckdb { + +inline uint64_t ChronoNow() { + return std::chrono::duration_cast( + std::chrono::time_point_cast(std::chrono::high_resolution_clock::now()) + .time_since_epoch()) + .count(); +} + +inline uint64_t Now() { +#if defined(RDTSC) +#if defined(__i386__) + uint64_t ret; + __asm__ volatile("rdtsc" : "=A"(ret)); + return ret; +#elif defined(__x86_64__) || defined(__amd64__) + uint64_t low, high; + __asm__ volatile("rdtsc" : "=a"(low), "=d"(high)); + return (high << 32) | low; +#elif defined(__powerpc__) || defined(__ppc__) + uint64_t tbl, tbu0, tbu1; + asm("mftbu %0" : "=r"(tbu0)); + asm("mftb %0" : "=r"(tbl)); + asm("mftbu %0" : "=r"(tbu1)); + tbl &= -static_cast(tbu0 == tbu1); + return (tbu1 << 32) | tbl; +#elif defined(__sparc__) + uint64_t tick; + asm(".byte 0x83, 0x41, 0x00, 0x00"); + asm("mov %%g1, %0" : "=r"(tick)); + return tick; +#elif defined(__ia64__) + uint64_t itc; + asm("mov %0 = ar.itc" : "=r"(itc)); + return itc; +#elif defined(COMPILER_MSVC) && defined(_M_IX86) + _asm rdtsc +#elif defined(COMPILER_MSVC) + return __rdtsc(); +#elif defined(__aarch64__) + uint64_t virtual_timer_value; + asm volatile("mrs %0, cntvct_el0" : "=r"(virtual_timer_value)); + return virtual_timer_value; +#elif defined(__ARM_ARCH) +#if (__ARM_ARCH >= 6) + uint32_t pmccntr; + uint32_t pmuseren; + uint32_t pmcntenset; + asm volatile("mrc p15, 0, %0, c9, c14, 0" : "=r"(pmuseren)); + if (pmuseren & 1) { // Allows reading perfmon counters for user mode code. + asm volatile("mrc p15, 0, %0, c9, c12, 1" : "=r"(pmcntenset)); + if (pmcntenset & 0x80000000ul) { // Is it counting? + asm volatile("mrc p15, 0, %0, c9, c13, 0" : "=r"(pmccntr)); + return static_cast(pmccntr) * 64; // Should optimize to << 6 + } + } +#endif + return ChronoNow(); +#else + return ChronoNow(); +#endif +#else + return ChronoNow(); +#endif // defined(RDTSC) +} +uint64_t CycleCounter::Tick() const { + return Now(); +} +} // namespace duckdb diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp new file mode 100644 index 00000000..305d09d5 --- /dev/null +++ b/src/duckdb/src/common/enum_util.cpp @@ -0,0 +1,6463 @@ +//------------------------------------------------------------------------- +// This file is automatically generated by scripts/generate_enum_util.py +// Do not edit this file manually, your changes will be overwritten +// If you want to exclude an enum from serialization, add it to the blacklist in the script +// +// Note: The generated code will only work properly if the enum is a top level item in the duckdb namespace +// If the enum is nested in a class, or in another namespace, the generated code will not compile. +// You should move the enum to the duckdb namespace, manually write a specialization or add it to the blacklist +//------------------------------------------------------------------------- + + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/catalog/catalog_entry/table_column_type.hpp" +#include "duckdb/common/box_renderer.hpp" +#include "duckdb/common/enums/access_mode.hpp" +#include "duckdb/common/enums/aggregate_handling.hpp" +#include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/enums/compression_type.hpp" +#include "duckdb/common/enums/cte_materialize.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/enums/debug_initialize.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/file_compression_type.hpp" +#include "duckdb/common/enums/file_glob_options.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/enums/index_type.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enums/joinref_type.hpp" +#include "duckdb/common/enums/logical_operator_type.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" +#include "duckdb/common/enums/operator_result_type.hpp" +#include "duckdb/common/enums/optimizer_type.hpp" +#include "duckdb/common/enums/order_preservation_type.hpp" +#include "duckdb/common/enums/order_type.hpp" +#include "duckdb/common/enums/output_type.hpp" +#include "duckdb/common/enums/pending_execution_result.hpp" +#include "duckdb/common/enums/physical_operator_type.hpp" +#include "duckdb/common/enums/profiler_format.hpp" +#include "duckdb/common/enums/relation_type.hpp" +#include "duckdb/common/enums/scan_options.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/common/enums/set_type.hpp" +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/enums/subquery_type.hpp" +#include "duckdb/common/enums/tableref_type.hpp" +#include "duckdb/common/enums/undo_flags.hpp" +#include "duckdb/common/enums/vector_type.hpp" +#include "duckdb/common/enums/wal_type.hpp" +#include "duckdb/common/enums/window_aggregation_mode.hpp" +#include "duckdb/common/exception_format_value.hpp" +#include "duckdb/common/extra_type_info.hpp" +#include "duckdb/common/file_buffer.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/sort/partition_state.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/column/column_data_scan_states.hpp" +#include "duckdb/common/types/column/partitioned_column_data.hpp" +#include "duckdb/common/types/conflict_manager.hpp" +#include "duckdb/common/types/hyperloglog.hpp" +#include "duckdb/common/types/row/partitioned_tuple_data.hpp" +#include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/vector_buffer.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" +#include "duckdb/execution/operator/scan/csv/base_csv_reader.hpp" +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine.hpp" +#include "duckdb/execution/operator/scan/csv/quote_rules.hpp" +#include "duckdb/function/aggregate_state.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/function/macro_function.hpp" +#include "duckdb/function/scalar/compressed_materialization_functions.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/main/appender.hpp" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/error_manager.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/parallel/task.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/parsed_data/alter_info.hpp" +#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" +#include "duckdb/parser/parsed_data/alter_table_function_info.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/parser/parsed_data/load_info.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" +#include "duckdb/parser/parsed_data/transaction_info.hpp" +#include "duckdb/parser/parser_extension.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/simplified_token.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/compression/bitpacking.hpp" +#include "duckdb/storage/magic_bytes.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/table/chunk_info.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +template<> +const char* EnumUtil::ToChars(AccessMode value) { + switch(value) { + case AccessMode::UNDEFINED: + return "UNDEFINED"; + case AccessMode::AUTOMATIC: + return "AUTOMATIC"; + case AccessMode::READ_ONLY: + return "READ_ONLY"; + case AccessMode::READ_WRITE: + return "READ_WRITE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AccessMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "UNDEFINED")) { + return AccessMode::UNDEFINED; + } + if (StringUtil::Equals(value, "AUTOMATIC")) { + return AccessMode::AUTOMATIC; + } + if (StringUtil::Equals(value, "READ_ONLY")) { + return AccessMode::READ_ONLY; + } + if (StringUtil::Equals(value, "READ_WRITE")) { + return AccessMode::READ_WRITE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AggregateHandling value) { + switch(value) { + case AggregateHandling::STANDARD_HANDLING: + return "STANDARD_HANDLING"; + case AggregateHandling::NO_AGGREGATES_ALLOWED: + return "NO_AGGREGATES_ALLOWED"; + case AggregateHandling::FORCE_AGGREGATES: + return "FORCE_AGGREGATES"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AggregateHandling EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "STANDARD_HANDLING")) { + return AggregateHandling::STANDARD_HANDLING; + } + if (StringUtil::Equals(value, "NO_AGGREGATES_ALLOWED")) { + return AggregateHandling::NO_AGGREGATES_ALLOWED; + } + if (StringUtil::Equals(value, "FORCE_AGGREGATES")) { + return AggregateHandling::FORCE_AGGREGATES; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AggregateOrderDependent value) { + switch(value) { + case AggregateOrderDependent::ORDER_DEPENDENT: + return "ORDER_DEPENDENT"; + case AggregateOrderDependent::NOT_ORDER_DEPENDENT: + return "NOT_ORDER_DEPENDENT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AggregateOrderDependent EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ORDER_DEPENDENT")) { + return AggregateOrderDependent::ORDER_DEPENDENT; + } + if (StringUtil::Equals(value, "NOT_ORDER_DEPENDENT")) { + return AggregateOrderDependent::NOT_ORDER_DEPENDENT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AggregateType value) { + switch(value) { + case AggregateType::NON_DISTINCT: + return "NON_DISTINCT"; + case AggregateType::DISTINCT: + return "DISTINCT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AggregateType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NON_DISTINCT")) { + return AggregateType::NON_DISTINCT; + } + if (StringUtil::Equals(value, "DISTINCT")) { + return AggregateType::DISTINCT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AlterForeignKeyType value) { + switch(value) { + case AlterForeignKeyType::AFT_ADD: + return "AFT_ADD"; + case AlterForeignKeyType::AFT_DELETE: + return "AFT_DELETE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AlterForeignKeyType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "AFT_ADD")) { + return AlterForeignKeyType::AFT_ADD; + } + if (StringUtil::Equals(value, "AFT_DELETE")) { + return AlterForeignKeyType::AFT_DELETE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AlterScalarFunctionType value) { + switch(value) { + case AlterScalarFunctionType::INVALID: + return "INVALID"; + case AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS: + return "ADD_FUNCTION_OVERLOADS"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AlterScalarFunctionType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return AlterScalarFunctionType::INVALID; + } + if (StringUtil::Equals(value, "ADD_FUNCTION_OVERLOADS")) { + return AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AlterTableFunctionType value) { + switch(value) { + case AlterTableFunctionType::INVALID: + return "INVALID"; + case AlterTableFunctionType::ADD_FUNCTION_OVERLOADS: + return "ADD_FUNCTION_OVERLOADS"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AlterTableFunctionType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return AlterTableFunctionType::INVALID; + } + if (StringUtil::Equals(value, "ADD_FUNCTION_OVERLOADS")) { + return AlterTableFunctionType::ADD_FUNCTION_OVERLOADS; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AlterTableType value) { + switch(value) { + case AlterTableType::INVALID: + return "INVALID"; + case AlterTableType::RENAME_COLUMN: + return "RENAME_COLUMN"; + case AlterTableType::RENAME_TABLE: + return "RENAME_TABLE"; + case AlterTableType::ADD_COLUMN: + return "ADD_COLUMN"; + case AlterTableType::REMOVE_COLUMN: + return "REMOVE_COLUMN"; + case AlterTableType::ALTER_COLUMN_TYPE: + return "ALTER_COLUMN_TYPE"; + case AlterTableType::SET_DEFAULT: + return "SET_DEFAULT"; + case AlterTableType::FOREIGN_KEY_CONSTRAINT: + return "FOREIGN_KEY_CONSTRAINT"; + case AlterTableType::SET_NOT_NULL: + return "SET_NOT_NULL"; + case AlterTableType::DROP_NOT_NULL: + return "DROP_NOT_NULL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AlterTableType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return AlterTableType::INVALID; + } + if (StringUtil::Equals(value, "RENAME_COLUMN")) { + return AlterTableType::RENAME_COLUMN; + } + if (StringUtil::Equals(value, "RENAME_TABLE")) { + return AlterTableType::RENAME_TABLE; + } + if (StringUtil::Equals(value, "ADD_COLUMN")) { + return AlterTableType::ADD_COLUMN; + } + if (StringUtil::Equals(value, "REMOVE_COLUMN")) { + return AlterTableType::REMOVE_COLUMN; + } + if (StringUtil::Equals(value, "ALTER_COLUMN_TYPE")) { + return AlterTableType::ALTER_COLUMN_TYPE; + } + if (StringUtil::Equals(value, "SET_DEFAULT")) { + return AlterTableType::SET_DEFAULT; + } + if (StringUtil::Equals(value, "FOREIGN_KEY_CONSTRAINT")) { + return AlterTableType::FOREIGN_KEY_CONSTRAINT; + } + if (StringUtil::Equals(value, "SET_NOT_NULL")) { + return AlterTableType::SET_NOT_NULL; + } + if (StringUtil::Equals(value, "DROP_NOT_NULL")) { + return AlterTableType::DROP_NOT_NULL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AlterType value) { + switch(value) { + case AlterType::INVALID: + return "INVALID"; + case AlterType::ALTER_TABLE: + return "ALTER_TABLE"; + case AlterType::ALTER_VIEW: + return "ALTER_VIEW"; + case AlterType::ALTER_SEQUENCE: + return "ALTER_SEQUENCE"; + case AlterType::CHANGE_OWNERSHIP: + return "CHANGE_OWNERSHIP"; + case AlterType::ALTER_SCALAR_FUNCTION: + return "ALTER_SCALAR_FUNCTION"; + case AlterType::ALTER_TABLE_FUNCTION: + return "ALTER_TABLE_FUNCTION"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AlterType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return AlterType::INVALID; + } + if (StringUtil::Equals(value, "ALTER_TABLE")) { + return AlterType::ALTER_TABLE; + } + if (StringUtil::Equals(value, "ALTER_VIEW")) { + return AlterType::ALTER_VIEW; + } + if (StringUtil::Equals(value, "ALTER_SEQUENCE")) { + return AlterType::ALTER_SEQUENCE; + } + if (StringUtil::Equals(value, "CHANGE_OWNERSHIP")) { + return AlterType::CHANGE_OWNERSHIP; + } + if (StringUtil::Equals(value, "ALTER_SCALAR_FUNCTION")) { + return AlterType::ALTER_SCALAR_FUNCTION; + } + if (StringUtil::Equals(value, "ALTER_TABLE_FUNCTION")) { + return AlterType::ALTER_TABLE_FUNCTION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AlterViewType value) { + switch(value) { + case AlterViewType::INVALID: + return "INVALID"; + case AlterViewType::RENAME_VIEW: + return "RENAME_VIEW"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AlterViewType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return AlterViewType::INVALID; + } + if (StringUtil::Equals(value, "RENAME_VIEW")) { + return AlterViewType::RENAME_VIEW; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(AppenderType value) { + switch(value) { + case AppenderType::LOGICAL: + return "LOGICAL"; + case AppenderType::PHYSICAL: + return "PHYSICAL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AppenderType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "LOGICAL")) { + return AppenderType::LOGICAL; + } + if (StringUtil::Equals(value, "PHYSICAL")) { + return AppenderType::PHYSICAL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ArrowDateTimeType value) { + switch(value) { + case ArrowDateTimeType::MILLISECONDS: + return "MILLISECONDS"; + case ArrowDateTimeType::MICROSECONDS: + return "MICROSECONDS"; + case ArrowDateTimeType::NANOSECONDS: + return "NANOSECONDS"; + case ArrowDateTimeType::SECONDS: + return "SECONDS"; + case ArrowDateTimeType::DAYS: + return "DAYS"; + case ArrowDateTimeType::MONTHS: + return "MONTHS"; + case ArrowDateTimeType::MONTH_DAY_NANO: + return "MONTH_DAY_NANO"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ArrowDateTimeType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "MILLISECONDS")) { + return ArrowDateTimeType::MILLISECONDS; + } + if (StringUtil::Equals(value, "MICROSECONDS")) { + return ArrowDateTimeType::MICROSECONDS; + } + if (StringUtil::Equals(value, "NANOSECONDS")) { + return ArrowDateTimeType::NANOSECONDS; + } + if (StringUtil::Equals(value, "SECONDS")) { + return ArrowDateTimeType::SECONDS; + } + if (StringUtil::Equals(value, "DAYS")) { + return ArrowDateTimeType::DAYS; + } + if (StringUtil::Equals(value, "MONTHS")) { + return ArrowDateTimeType::MONTHS; + } + if (StringUtil::Equals(value, "MONTH_DAY_NANO")) { + return ArrowDateTimeType::MONTH_DAY_NANO; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ArrowVariableSizeType value) { + switch(value) { + case ArrowVariableSizeType::FIXED_SIZE: + return "FIXED_SIZE"; + case ArrowVariableSizeType::NORMAL: + return "NORMAL"; + case ArrowVariableSizeType::SUPER_SIZE: + return "SUPER_SIZE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ArrowVariableSizeType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "FIXED_SIZE")) { + return ArrowVariableSizeType::FIXED_SIZE; + } + if (StringUtil::Equals(value, "NORMAL")) { + return ArrowVariableSizeType::NORMAL; + } + if (StringUtil::Equals(value, "SUPER_SIZE")) { + return ArrowVariableSizeType::SUPER_SIZE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(BindingMode value) { + switch(value) { + case BindingMode::STANDARD_BINDING: + return "STANDARD_BINDING"; + case BindingMode::EXTRACT_NAMES: + return "EXTRACT_NAMES"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +BindingMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "STANDARD_BINDING")) { + return BindingMode::STANDARD_BINDING; + } + if (StringUtil::Equals(value, "EXTRACT_NAMES")) { + return BindingMode::EXTRACT_NAMES; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(BitpackingMode value) { + switch(value) { + case BitpackingMode::INVALID: + return "INVALID"; + case BitpackingMode::AUTO: + return "AUTO"; + case BitpackingMode::CONSTANT: + return "CONSTANT"; + case BitpackingMode::CONSTANT_DELTA: + return "CONSTANT_DELTA"; + case BitpackingMode::DELTA_FOR: + return "DELTA_FOR"; + case BitpackingMode::FOR: + return "FOR"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +BitpackingMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return BitpackingMode::INVALID; + } + if (StringUtil::Equals(value, "AUTO")) { + return BitpackingMode::AUTO; + } + if (StringUtil::Equals(value, "CONSTANT")) { + return BitpackingMode::CONSTANT; + } + if (StringUtil::Equals(value, "CONSTANT_DELTA")) { + return BitpackingMode::CONSTANT_DELTA; + } + if (StringUtil::Equals(value, "DELTA_FOR")) { + return BitpackingMode::DELTA_FOR; + } + if (StringUtil::Equals(value, "FOR")) { + return BitpackingMode::FOR; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(BlockState value) { + switch(value) { + case BlockState::BLOCK_UNLOADED: + return "BLOCK_UNLOADED"; + case BlockState::BLOCK_LOADED: + return "BLOCK_LOADED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +BlockState EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "BLOCK_UNLOADED")) { + return BlockState::BLOCK_UNLOADED; + } + if (StringUtil::Equals(value, "BLOCK_LOADED")) { + return BlockState::BLOCK_LOADED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CAPIResultSetType value) { + switch(value) { + case CAPIResultSetType::CAPI_RESULT_TYPE_NONE: + return "CAPI_RESULT_TYPE_NONE"; + case CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED: + return "CAPI_RESULT_TYPE_MATERIALIZED"; + case CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING: + return "CAPI_RESULT_TYPE_STREAMING"; + case CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED: + return "CAPI_RESULT_TYPE_DEPRECATED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CAPIResultSetType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_NONE")) { + return CAPIResultSetType::CAPI_RESULT_TYPE_NONE; + } + if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_MATERIALIZED")) { + return CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED; + } + if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_STREAMING")) { + return CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING; + } + if (StringUtil::Equals(value, "CAPI_RESULT_TYPE_DEPRECATED")) { + return CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CSVState value) { + switch(value) { + case CSVState::STANDARD: + return "STANDARD"; + case CSVState::DELIMITER: + return "DELIMITER"; + case CSVState::RECORD_SEPARATOR: + return "RECORD_SEPARATOR"; + case CSVState::CARRIAGE_RETURN: + return "CARRIAGE_RETURN"; + case CSVState::QUOTED: + return "QUOTED"; + case CSVState::UNQUOTED: + return "UNQUOTED"; + case CSVState::ESCAPE: + return "ESCAPE"; + case CSVState::EMPTY_LINE: + return "EMPTY_LINE"; + case CSVState::INVALID: + return "INVALID"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CSVState EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "STANDARD")) { + return CSVState::STANDARD; + } + if (StringUtil::Equals(value, "DELIMITER")) { + return CSVState::DELIMITER; + } + if (StringUtil::Equals(value, "RECORD_SEPARATOR")) { + return CSVState::RECORD_SEPARATOR; + } + if (StringUtil::Equals(value, "CARRIAGE_RETURN")) { + return CSVState::CARRIAGE_RETURN; + } + if (StringUtil::Equals(value, "QUOTED")) { + return CSVState::QUOTED; + } + if (StringUtil::Equals(value, "UNQUOTED")) { + return CSVState::UNQUOTED; + } + if (StringUtil::Equals(value, "ESCAPE")) { + return CSVState::ESCAPE; + } + if (StringUtil::Equals(value, "EMPTY_LINE")) { + return CSVState::EMPTY_LINE; + } + if (StringUtil::Equals(value, "INVALID")) { + return CSVState::INVALID; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CTEMaterialize value) { + switch(value) { + case CTEMaterialize::CTE_MATERIALIZE_DEFAULT: + return "CTE_MATERIALIZE_DEFAULT"; + case CTEMaterialize::CTE_MATERIALIZE_ALWAYS: + return "CTE_MATERIALIZE_ALWAYS"; + case CTEMaterialize::CTE_MATERIALIZE_NEVER: + return "CTE_MATERIALIZE_NEVER"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CTEMaterialize EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "CTE_MATERIALIZE_DEFAULT")) { + return CTEMaterialize::CTE_MATERIALIZE_DEFAULT; + } + if (StringUtil::Equals(value, "CTE_MATERIALIZE_ALWAYS")) { + return CTEMaterialize::CTE_MATERIALIZE_ALWAYS; + } + if (StringUtil::Equals(value, "CTE_MATERIALIZE_NEVER")) { + return CTEMaterialize::CTE_MATERIALIZE_NEVER; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CatalogType value) { + switch(value) { + case CatalogType::INVALID: + return "INVALID"; + case CatalogType::TABLE_ENTRY: + return "TABLE_ENTRY"; + case CatalogType::SCHEMA_ENTRY: + return "SCHEMA_ENTRY"; + case CatalogType::VIEW_ENTRY: + return "VIEW_ENTRY"; + case CatalogType::INDEX_ENTRY: + return "INDEX_ENTRY"; + case CatalogType::PREPARED_STATEMENT: + return "PREPARED_STATEMENT"; + case CatalogType::SEQUENCE_ENTRY: + return "SEQUENCE_ENTRY"; + case CatalogType::COLLATION_ENTRY: + return "COLLATION_ENTRY"; + case CatalogType::TYPE_ENTRY: + return "TYPE_ENTRY"; + case CatalogType::DATABASE_ENTRY: + return "DATABASE_ENTRY"; + case CatalogType::TABLE_FUNCTION_ENTRY: + return "TABLE_FUNCTION_ENTRY"; + case CatalogType::SCALAR_FUNCTION_ENTRY: + return "SCALAR_FUNCTION_ENTRY"; + case CatalogType::AGGREGATE_FUNCTION_ENTRY: + return "AGGREGATE_FUNCTION_ENTRY"; + case CatalogType::PRAGMA_FUNCTION_ENTRY: + return "PRAGMA_FUNCTION_ENTRY"; + case CatalogType::COPY_FUNCTION_ENTRY: + return "COPY_FUNCTION_ENTRY"; + case CatalogType::MACRO_ENTRY: + return "MACRO_ENTRY"; + case CatalogType::TABLE_MACRO_ENTRY: + return "TABLE_MACRO_ENTRY"; + case CatalogType::UPDATED_ENTRY: + return "UPDATED_ENTRY"; + case CatalogType::DELETED_ENTRY: + return "DELETED_ENTRY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CatalogType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return CatalogType::INVALID; + } + if (StringUtil::Equals(value, "TABLE_ENTRY")) { + return CatalogType::TABLE_ENTRY; + } + if (StringUtil::Equals(value, "SCHEMA_ENTRY")) { + return CatalogType::SCHEMA_ENTRY; + } + if (StringUtil::Equals(value, "VIEW_ENTRY")) { + return CatalogType::VIEW_ENTRY; + } + if (StringUtil::Equals(value, "INDEX_ENTRY")) { + return CatalogType::INDEX_ENTRY; + } + if (StringUtil::Equals(value, "PREPARED_STATEMENT")) { + return CatalogType::PREPARED_STATEMENT; + } + if (StringUtil::Equals(value, "SEQUENCE_ENTRY")) { + return CatalogType::SEQUENCE_ENTRY; + } + if (StringUtil::Equals(value, "COLLATION_ENTRY")) { + return CatalogType::COLLATION_ENTRY; + } + if (StringUtil::Equals(value, "TYPE_ENTRY")) { + return CatalogType::TYPE_ENTRY; + } + if (StringUtil::Equals(value, "DATABASE_ENTRY")) { + return CatalogType::DATABASE_ENTRY; + } + if (StringUtil::Equals(value, "TABLE_FUNCTION_ENTRY")) { + return CatalogType::TABLE_FUNCTION_ENTRY; + } + if (StringUtil::Equals(value, "SCALAR_FUNCTION_ENTRY")) { + return CatalogType::SCALAR_FUNCTION_ENTRY; + } + if (StringUtil::Equals(value, "AGGREGATE_FUNCTION_ENTRY")) { + return CatalogType::AGGREGATE_FUNCTION_ENTRY; + } + if (StringUtil::Equals(value, "PRAGMA_FUNCTION_ENTRY")) { + return CatalogType::PRAGMA_FUNCTION_ENTRY; + } + if (StringUtil::Equals(value, "COPY_FUNCTION_ENTRY")) { + return CatalogType::COPY_FUNCTION_ENTRY; + } + if (StringUtil::Equals(value, "MACRO_ENTRY")) { + return CatalogType::MACRO_ENTRY; + } + if (StringUtil::Equals(value, "TABLE_MACRO_ENTRY")) { + return CatalogType::TABLE_MACRO_ENTRY; + } + if (StringUtil::Equals(value, "UPDATED_ENTRY")) { + return CatalogType::UPDATED_ENTRY; + } + if (StringUtil::Equals(value, "DELETED_ENTRY")) { + return CatalogType::DELETED_ENTRY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CheckpointAbort value) { + switch(value) { + case CheckpointAbort::NO_ABORT: + return "NO_ABORT"; + case CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE: + return "DEBUG_ABORT_BEFORE_TRUNCATE"; + case CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER: + return "DEBUG_ABORT_BEFORE_HEADER"; + case CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE: + return "DEBUG_ABORT_AFTER_FREE_LIST_WRITE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CheckpointAbort EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NO_ABORT")) { + return CheckpointAbort::NO_ABORT; + } + if (StringUtil::Equals(value, "DEBUG_ABORT_BEFORE_TRUNCATE")) { + return CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE; + } + if (StringUtil::Equals(value, "DEBUG_ABORT_BEFORE_HEADER")) { + return CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER; + } + if (StringUtil::Equals(value, "DEBUG_ABORT_AFTER_FREE_LIST_WRITE")) { + return CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ChunkInfoType value) { + switch(value) { + case ChunkInfoType::CONSTANT_INFO: + return "CONSTANT_INFO"; + case ChunkInfoType::VECTOR_INFO: + return "VECTOR_INFO"; + case ChunkInfoType::EMPTY_INFO: + return "EMPTY_INFO"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ChunkInfoType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "CONSTANT_INFO")) { + return ChunkInfoType::CONSTANT_INFO; + } + if (StringUtil::Equals(value, "VECTOR_INFO")) { + return ChunkInfoType::VECTOR_INFO; + } + if (StringUtil::Equals(value, "EMPTY_INFO")) { + return ChunkInfoType::EMPTY_INFO; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ColumnDataAllocatorType value) { + switch(value) { + case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: + return "BUFFER_MANAGER_ALLOCATOR"; + case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: + return "IN_MEMORY_ALLOCATOR"; + case ColumnDataAllocatorType::HYBRID: + return "HYBRID"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ColumnDataAllocatorType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "BUFFER_MANAGER_ALLOCATOR")) { + return ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR; + } + if (StringUtil::Equals(value, "IN_MEMORY_ALLOCATOR")) { + return ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR; + } + if (StringUtil::Equals(value, "HYBRID")) { + return ColumnDataAllocatorType::HYBRID; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ColumnDataScanProperties value) { + switch(value) { + case ColumnDataScanProperties::INVALID: + return "INVALID"; + case ColumnDataScanProperties::ALLOW_ZERO_COPY: + return "ALLOW_ZERO_COPY"; + case ColumnDataScanProperties::DISALLOW_ZERO_COPY: + return "DISALLOW_ZERO_COPY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ColumnDataScanProperties EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return ColumnDataScanProperties::INVALID; + } + if (StringUtil::Equals(value, "ALLOW_ZERO_COPY")) { + return ColumnDataScanProperties::ALLOW_ZERO_COPY; + } + if (StringUtil::Equals(value, "DISALLOW_ZERO_COPY")) { + return ColumnDataScanProperties::DISALLOW_ZERO_COPY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ColumnSegmentType value) { + switch(value) { + case ColumnSegmentType::TRANSIENT: + return "TRANSIENT"; + case ColumnSegmentType::PERSISTENT: + return "PERSISTENT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ColumnSegmentType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "TRANSIENT")) { + return ColumnSegmentType::TRANSIENT; + } + if (StringUtil::Equals(value, "PERSISTENT")) { + return ColumnSegmentType::PERSISTENT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CompressedMaterializationDirection value) { + switch(value) { + case CompressedMaterializationDirection::INVALID: + return "INVALID"; + case CompressedMaterializationDirection::COMPRESS: + return "COMPRESS"; + case CompressedMaterializationDirection::DECOMPRESS: + return "DECOMPRESS"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CompressedMaterializationDirection EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return CompressedMaterializationDirection::INVALID; + } + if (StringUtil::Equals(value, "COMPRESS")) { + return CompressedMaterializationDirection::COMPRESS; + } + if (StringUtil::Equals(value, "DECOMPRESS")) { + return CompressedMaterializationDirection::DECOMPRESS; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CompressionType value) { + switch(value) { + case CompressionType::COMPRESSION_AUTO: + return "COMPRESSION_AUTO"; + case CompressionType::COMPRESSION_UNCOMPRESSED: + return "COMPRESSION_UNCOMPRESSED"; + case CompressionType::COMPRESSION_CONSTANT: + return "COMPRESSION_CONSTANT"; + case CompressionType::COMPRESSION_RLE: + return "COMPRESSION_RLE"; + case CompressionType::COMPRESSION_DICTIONARY: + return "COMPRESSION_DICTIONARY"; + case CompressionType::COMPRESSION_PFOR_DELTA: + return "COMPRESSION_PFOR_DELTA"; + case CompressionType::COMPRESSION_BITPACKING: + return "COMPRESSION_BITPACKING"; + case CompressionType::COMPRESSION_FSST: + return "COMPRESSION_FSST"; + case CompressionType::COMPRESSION_CHIMP: + return "COMPRESSION_CHIMP"; + case CompressionType::COMPRESSION_PATAS: + return "COMPRESSION_PATAS"; + case CompressionType::COMPRESSION_COUNT: + return "COMPRESSION_COUNT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CompressionType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "COMPRESSION_AUTO")) { + return CompressionType::COMPRESSION_AUTO; + } + if (StringUtil::Equals(value, "COMPRESSION_UNCOMPRESSED")) { + return CompressionType::COMPRESSION_UNCOMPRESSED; + } + if (StringUtil::Equals(value, "COMPRESSION_CONSTANT")) { + return CompressionType::COMPRESSION_CONSTANT; + } + if (StringUtil::Equals(value, "COMPRESSION_RLE")) { + return CompressionType::COMPRESSION_RLE; + } + if (StringUtil::Equals(value, "COMPRESSION_DICTIONARY")) { + return CompressionType::COMPRESSION_DICTIONARY; + } + if (StringUtil::Equals(value, "COMPRESSION_PFOR_DELTA")) { + return CompressionType::COMPRESSION_PFOR_DELTA; + } + if (StringUtil::Equals(value, "COMPRESSION_BITPACKING")) { + return CompressionType::COMPRESSION_BITPACKING; + } + if (StringUtil::Equals(value, "COMPRESSION_FSST")) { + return CompressionType::COMPRESSION_FSST; + } + if (StringUtil::Equals(value, "COMPRESSION_CHIMP")) { + return CompressionType::COMPRESSION_CHIMP; + } + if (StringUtil::Equals(value, "COMPRESSION_PATAS")) { + return CompressionType::COMPRESSION_PATAS; + } + if (StringUtil::Equals(value, "COMPRESSION_COUNT")) { + return CompressionType::COMPRESSION_COUNT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ConflictManagerMode value) { + switch(value) { + case ConflictManagerMode::SCAN: + return "SCAN"; + case ConflictManagerMode::THROW: + return "THROW"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ConflictManagerMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "SCAN")) { + return ConflictManagerMode::SCAN; + } + if (StringUtil::Equals(value, "THROW")) { + return ConflictManagerMode::THROW; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ConstraintType value) { + switch(value) { + case ConstraintType::INVALID: + return "INVALID"; + case ConstraintType::NOT_NULL: + return "NOT_NULL"; + case ConstraintType::CHECK: + return "CHECK"; + case ConstraintType::UNIQUE: + return "UNIQUE"; + case ConstraintType::FOREIGN_KEY: + return "FOREIGN_KEY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ConstraintType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return ConstraintType::INVALID; + } + if (StringUtil::Equals(value, "NOT_NULL")) { + return ConstraintType::NOT_NULL; + } + if (StringUtil::Equals(value, "CHECK")) { + return ConstraintType::CHECK; + } + if (StringUtil::Equals(value, "UNIQUE")) { + return ConstraintType::UNIQUE; + } + if (StringUtil::Equals(value, "FOREIGN_KEY")) { + return ConstraintType::FOREIGN_KEY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(DataFileType value) { + switch(value) { + case DataFileType::FILE_DOES_NOT_EXIST: + return "FILE_DOES_NOT_EXIST"; + case DataFileType::DUCKDB_FILE: + return "DUCKDB_FILE"; + case DataFileType::SQLITE_FILE: + return "SQLITE_FILE"; + case DataFileType::PARQUET_FILE: + return "PARQUET_FILE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +DataFileType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "FILE_DOES_NOT_EXIST")) { + return DataFileType::FILE_DOES_NOT_EXIST; + } + if (StringUtil::Equals(value, "DUCKDB_FILE")) { + return DataFileType::DUCKDB_FILE; + } + if (StringUtil::Equals(value, "SQLITE_FILE")) { + return DataFileType::SQLITE_FILE; + } + if (StringUtil::Equals(value, "PARQUET_FILE")) { + return DataFileType::PARQUET_FILE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(DatePartSpecifier value) { + switch(value) { + case DatePartSpecifier::YEAR: + return "YEAR"; + case DatePartSpecifier::MONTH: + return "MONTH"; + case DatePartSpecifier::DAY: + return "DAY"; + case DatePartSpecifier::DECADE: + return "DECADE"; + case DatePartSpecifier::CENTURY: + return "CENTURY"; + case DatePartSpecifier::MILLENNIUM: + return "MILLENNIUM"; + case DatePartSpecifier::MICROSECONDS: + return "MICROSECONDS"; + case DatePartSpecifier::MILLISECONDS: + return "MILLISECONDS"; + case DatePartSpecifier::SECOND: + return "SECOND"; + case DatePartSpecifier::MINUTE: + return "MINUTE"; + case DatePartSpecifier::HOUR: + return "HOUR"; + case DatePartSpecifier::DOW: + return "DOW"; + case DatePartSpecifier::ISODOW: + return "ISODOW"; + case DatePartSpecifier::WEEK: + return "WEEK"; + case DatePartSpecifier::ISOYEAR: + return "ISOYEAR"; + case DatePartSpecifier::QUARTER: + return "QUARTER"; + case DatePartSpecifier::DOY: + return "DOY"; + case DatePartSpecifier::YEARWEEK: + return "YEARWEEK"; + case DatePartSpecifier::ERA: + return "ERA"; + case DatePartSpecifier::TIMEZONE: + return "TIMEZONE"; + case DatePartSpecifier::TIMEZONE_HOUR: + return "TIMEZONE_HOUR"; + case DatePartSpecifier::TIMEZONE_MINUTE: + return "TIMEZONE_MINUTE"; + case DatePartSpecifier::EPOCH: + return "EPOCH"; + case DatePartSpecifier::JULIAN_DAY: + return "JULIAN_DAY"; + case DatePartSpecifier::INVALID: + return "INVALID"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +DatePartSpecifier EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "YEAR")) { + return DatePartSpecifier::YEAR; + } + if (StringUtil::Equals(value, "MONTH")) { + return DatePartSpecifier::MONTH; + } + if (StringUtil::Equals(value, "DAY")) { + return DatePartSpecifier::DAY; + } + if (StringUtil::Equals(value, "DECADE")) { + return DatePartSpecifier::DECADE; + } + if (StringUtil::Equals(value, "CENTURY")) { + return DatePartSpecifier::CENTURY; + } + if (StringUtil::Equals(value, "MILLENNIUM")) { + return DatePartSpecifier::MILLENNIUM; + } + if (StringUtil::Equals(value, "MICROSECONDS")) { + return DatePartSpecifier::MICROSECONDS; + } + if (StringUtil::Equals(value, "MILLISECONDS")) { + return DatePartSpecifier::MILLISECONDS; + } + if (StringUtil::Equals(value, "SECOND")) { + return DatePartSpecifier::SECOND; + } + if (StringUtil::Equals(value, "MINUTE")) { + return DatePartSpecifier::MINUTE; + } + if (StringUtil::Equals(value, "HOUR")) { + return DatePartSpecifier::HOUR; + } + if (StringUtil::Equals(value, "DOW")) { + return DatePartSpecifier::DOW; + } + if (StringUtil::Equals(value, "ISODOW")) { + return DatePartSpecifier::ISODOW; + } + if (StringUtil::Equals(value, "WEEK")) { + return DatePartSpecifier::WEEK; + } + if (StringUtil::Equals(value, "ISOYEAR")) { + return DatePartSpecifier::ISOYEAR; + } + if (StringUtil::Equals(value, "QUARTER")) { + return DatePartSpecifier::QUARTER; + } + if (StringUtil::Equals(value, "DOY")) { + return DatePartSpecifier::DOY; + } + if (StringUtil::Equals(value, "YEARWEEK")) { + return DatePartSpecifier::YEARWEEK; + } + if (StringUtil::Equals(value, "ERA")) { + return DatePartSpecifier::ERA; + } + if (StringUtil::Equals(value, "TIMEZONE")) { + return DatePartSpecifier::TIMEZONE; + } + if (StringUtil::Equals(value, "TIMEZONE_HOUR")) { + return DatePartSpecifier::TIMEZONE_HOUR; + } + if (StringUtil::Equals(value, "TIMEZONE_MINUTE")) { + return DatePartSpecifier::TIMEZONE_MINUTE; + } + if (StringUtil::Equals(value, "EPOCH")) { + return DatePartSpecifier::EPOCH; + } + if (StringUtil::Equals(value, "JULIAN_DAY")) { + return DatePartSpecifier::JULIAN_DAY; + } + if (StringUtil::Equals(value, "INVALID")) { + return DatePartSpecifier::INVALID; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(DebugInitialize value) { + switch(value) { + case DebugInitialize::NO_INITIALIZE: + return "NO_INITIALIZE"; + case DebugInitialize::DEBUG_ZERO_INITIALIZE: + return "DEBUG_ZERO_INITIALIZE"; + case DebugInitialize::DEBUG_ONE_INITIALIZE: + return "DEBUG_ONE_INITIALIZE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +DebugInitialize EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NO_INITIALIZE")) { + return DebugInitialize::NO_INITIALIZE; + } + if (StringUtil::Equals(value, "DEBUG_ZERO_INITIALIZE")) { + return DebugInitialize::DEBUG_ZERO_INITIALIZE; + } + if (StringUtil::Equals(value, "DEBUG_ONE_INITIALIZE")) { + return DebugInitialize::DEBUG_ONE_INITIALIZE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(DefaultOrderByNullType value) { + switch(value) { + case DefaultOrderByNullType::INVALID: + return "INVALID"; + case DefaultOrderByNullType::NULLS_FIRST: + return "NULLS_FIRST"; + case DefaultOrderByNullType::NULLS_LAST: + return "NULLS_LAST"; + case DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC: + return "NULLS_FIRST_ON_ASC_LAST_ON_DESC"; + case DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC: + return "NULLS_LAST_ON_ASC_FIRST_ON_DESC"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +DefaultOrderByNullType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return DefaultOrderByNullType::INVALID; + } + if (StringUtil::Equals(value, "NULLS_FIRST")) { + return DefaultOrderByNullType::NULLS_FIRST; + } + if (StringUtil::Equals(value, "NULLS_LAST")) { + return DefaultOrderByNullType::NULLS_LAST; + } + if (StringUtil::Equals(value, "NULLS_FIRST_ON_ASC_LAST_ON_DESC")) { + return DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC; + } + if (StringUtil::Equals(value, "NULLS_LAST_ON_ASC_FIRST_ON_DESC")) { + return DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(DistinctType value) { + switch(value) { + case DistinctType::DISTINCT: + return "DISTINCT"; + case DistinctType::DISTINCT_ON: + return "DISTINCT_ON"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +DistinctType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "DISTINCT")) { + return DistinctType::DISTINCT; + } + if (StringUtil::Equals(value, "DISTINCT_ON")) { + return DistinctType::DISTINCT_ON; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ErrorType value) { + switch(value) { + case ErrorType::UNSIGNED_EXTENSION: + return "UNSIGNED_EXTENSION"; + case ErrorType::INVALIDATED_TRANSACTION: + return "INVALIDATED_TRANSACTION"; + case ErrorType::INVALIDATED_DATABASE: + return "INVALIDATED_DATABASE"; + case ErrorType::ERROR_COUNT: + return "ERROR_COUNT"; + case ErrorType::INVALID: + return "INVALID"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ErrorType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "UNSIGNED_EXTENSION")) { + return ErrorType::UNSIGNED_EXTENSION; + } + if (StringUtil::Equals(value, "INVALIDATED_TRANSACTION")) { + return ErrorType::INVALIDATED_TRANSACTION; + } + if (StringUtil::Equals(value, "INVALIDATED_DATABASE")) { + return ErrorType::INVALIDATED_DATABASE; + } + if (StringUtil::Equals(value, "ERROR_COUNT")) { + return ErrorType::ERROR_COUNT; + } + if (StringUtil::Equals(value, "INVALID")) { + return ErrorType::INVALID; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExceptionFormatValueType value) { + switch(value) { + case ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE: + return "FORMAT_VALUE_TYPE_DOUBLE"; + case ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER: + return "FORMAT_VALUE_TYPE_INTEGER"; + case ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING: + return "FORMAT_VALUE_TYPE_STRING"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExceptionFormatValueType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "FORMAT_VALUE_TYPE_DOUBLE")) { + return ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE; + } + if (StringUtil::Equals(value, "FORMAT_VALUE_TYPE_INTEGER")) { + return ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER; + } + if (StringUtil::Equals(value, "FORMAT_VALUE_TYPE_STRING")) { + return ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExplainOutputType value) { + switch(value) { + case ExplainOutputType::ALL: + return "ALL"; + case ExplainOutputType::OPTIMIZED_ONLY: + return "OPTIMIZED_ONLY"; + case ExplainOutputType::PHYSICAL_ONLY: + return "PHYSICAL_ONLY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExplainOutputType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ALL")) { + return ExplainOutputType::ALL; + } + if (StringUtil::Equals(value, "OPTIMIZED_ONLY")) { + return ExplainOutputType::OPTIMIZED_ONLY; + } + if (StringUtil::Equals(value, "PHYSICAL_ONLY")) { + return ExplainOutputType::PHYSICAL_ONLY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExplainType value) { + switch(value) { + case ExplainType::EXPLAIN_STANDARD: + return "EXPLAIN_STANDARD"; + case ExplainType::EXPLAIN_ANALYZE: + return "EXPLAIN_ANALYZE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExplainType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "EXPLAIN_STANDARD")) { + return ExplainType::EXPLAIN_STANDARD; + } + if (StringUtil::Equals(value, "EXPLAIN_ANALYZE")) { + return ExplainType::EXPLAIN_ANALYZE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExpressionClass value) { + switch(value) { + case ExpressionClass::INVALID: + return "INVALID"; + case ExpressionClass::AGGREGATE: + return "AGGREGATE"; + case ExpressionClass::CASE: + return "CASE"; + case ExpressionClass::CAST: + return "CAST"; + case ExpressionClass::COLUMN_REF: + return "COLUMN_REF"; + case ExpressionClass::COMPARISON: + return "COMPARISON"; + case ExpressionClass::CONJUNCTION: + return "CONJUNCTION"; + case ExpressionClass::CONSTANT: + return "CONSTANT"; + case ExpressionClass::DEFAULT: + return "DEFAULT"; + case ExpressionClass::FUNCTION: + return "FUNCTION"; + case ExpressionClass::OPERATOR: + return "OPERATOR"; + case ExpressionClass::STAR: + return "STAR"; + case ExpressionClass::SUBQUERY: + return "SUBQUERY"; + case ExpressionClass::WINDOW: + return "WINDOW"; + case ExpressionClass::PARAMETER: + return "PARAMETER"; + case ExpressionClass::COLLATE: + return "COLLATE"; + case ExpressionClass::LAMBDA: + return "LAMBDA"; + case ExpressionClass::POSITIONAL_REFERENCE: + return "POSITIONAL_REFERENCE"; + case ExpressionClass::BETWEEN: + return "BETWEEN"; + case ExpressionClass::BOUND_AGGREGATE: + return "BOUND_AGGREGATE"; + case ExpressionClass::BOUND_CASE: + return "BOUND_CASE"; + case ExpressionClass::BOUND_CAST: + return "BOUND_CAST"; + case ExpressionClass::BOUND_COLUMN_REF: + return "BOUND_COLUMN_REF"; + case ExpressionClass::BOUND_COMPARISON: + return "BOUND_COMPARISON"; + case ExpressionClass::BOUND_CONJUNCTION: + return "BOUND_CONJUNCTION"; + case ExpressionClass::BOUND_CONSTANT: + return "BOUND_CONSTANT"; + case ExpressionClass::BOUND_DEFAULT: + return "BOUND_DEFAULT"; + case ExpressionClass::BOUND_FUNCTION: + return "BOUND_FUNCTION"; + case ExpressionClass::BOUND_OPERATOR: + return "BOUND_OPERATOR"; + case ExpressionClass::BOUND_PARAMETER: + return "BOUND_PARAMETER"; + case ExpressionClass::BOUND_REF: + return "BOUND_REF"; + case ExpressionClass::BOUND_SUBQUERY: + return "BOUND_SUBQUERY"; + case ExpressionClass::BOUND_WINDOW: + return "BOUND_WINDOW"; + case ExpressionClass::BOUND_BETWEEN: + return "BOUND_BETWEEN"; + case ExpressionClass::BOUND_UNNEST: + return "BOUND_UNNEST"; + case ExpressionClass::BOUND_LAMBDA: + return "BOUND_LAMBDA"; + case ExpressionClass::BOUND_LAMBDA_REF: + return "BOUND_LAMBDA_REF"; + case ExpressionClass::BOUND_EXPRESSION: + return "BOUND_EXPRESSION"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExpressionClass EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return ExpressionClass::INVALID; + } + if (StringUtil::Equals(value, "AGGREGATE")) { + return ExpressionClass::AGGREGATE; + } + if (StringUtil::Equals(value, "CASE")) { + return ExpressionClass::CASE; + } + if (StringUtil::Equals(value, "CAST")) { + return ExpressionClass::CAST; + } + if (StringUtil::Equals(value, "COLUMN_REF")) { + return ExpressionClass::COLUMN_REF; + } + if (StringUtil::Equals(value, "COMPARISON")) { + return ExpressionClass::COMPARISON; + } + if (StringUtil::Equals(value, "CONJUNCTION")) { + return ExpressionClass::CONJUNCTION; + } + if (StringUtil::Equals(value, "CONSTANT")) { + return ExpressionClass::CONSTANT; + } + if (StringUtil::Equals(value, "DEFAULT")) { + return ExpressionClass::DEFAULT; + } + if (StringUtil::Equals(value, "FUNCTION")) { + return ExpressionClass::FUNCTION; + } + if (StringUtil::Equals(value, "OPERATOR")) { + return ExpressionClass::OPERATOR; + } + if (StringUtil::Equals(value, "STAR")) { + return ExpressionClass::STAR; + } + if (StringUtil::Equals(value, "SUBQUERY")) { + return ExpressionClass::SUBQUERY; + } + if (StringUtil::Equals(value, "WINDOW")) { + return ExpressionClass::WINDOW; + } + if (StringUtil::Equals(value, "PARAMETER")) { + return ExpressionClass::PARAMETER; + } + if (StringUtil::Equals(value, "COLLATE")) { + return ExpressionClass::COLLATE; + } + if (StringUtil::Equals(value, "LAMBDA")) { + return ExpressionClass::LAMBDA; + } + if (StringUtil::Equals(value, "POSITIONAL_REFERENCE")) { + return ExpressionClass::POSITIONAL_REFERENCE; + } + if (StringUtil::Equals(value, "BETWEEN")) { + return ExpressionClass::BETWEEN; + } + if (StringUtil::Equals(value, "BOUND_AGGREGATE")) { + return ExpressionClass::BOUND_AGGREGATE; + } + if (StringUtil::Equals(value, "BOUND_CASE")) { + return ExpressionClass::BOUND_CASE; + } + if (StringUtil::Equals(value, "BOUND_CAST")) { + return ExpressionClass::BOUND_CAST; + } + if (StringUtil::Equals(value, "BOUND_COLUMN_REF")) { + return ExpressionClass::BOUND_COLUMN_REF; + } + if (StringUtil::Equals(value, "BOUND_COMPARISON")) { + return ExpressionClass::BOUND_COMPARISON; + } + if (StringUtil::Equals(value, "BOUND_CONJUNCTION")) { + return ExpressionClass::BOUND_CONJUNCTION; + } + if (StringUtil::Equals(value, "BOUND_CONSTANT")) { + return ExpressionClass::BOUND_CONSTANT; + } + if (StringUtil::Equals(value, "BOUND_DEFAULT")) { + return ExpressionClass::BOUND_DEFAULT; + } + if (StringUtil::Equals(value, "BOUND_FUNCTION")) { + return ExpressionClass::BOUND_FUNCTION; + } + if (StringUtil::Equals(value, "BOUND_OPERATOR")) { + return ExpressionClass::BOUND_OPERATOR; + } + if (StringUtil::Equals(value, "BOUND_PARAMETER")) { + return ExpressionClass::BOUND_PARAMETER; + } + if (StringUtil::Equals(value, "BOUND_REF")) { + return ExpressionClass::BOUND_REF; + } + if (StringUtil::Equals(value, "BOUND_SUBQUERY")) { + return ExpressionClass::BOUND_SUBQUERY; + } + if (StringUtil::Equals(value, "BOUND_WINDOW")) { + return ExpressionClass::BOUND_WINDOW; + } + if (StringUtil::Equals(value, "BOUND_BETWEEN")) { + return ExpressionClass::BOUND_BETWEEN; + } + if (StringUtil::Equals(value, "BOUND_UNNEST")) { + return ExpressionClass::BOUND_UNNEST; + } + if (StringUtil::Equals(value, "BOUND_LAMBDA")) { + return ExpressionClass::BOUND_LAMBDA; + } + if (StringUtil::Equals(value, "BOUND_LAMBDA_REF")) { + return ExpressionClass::BOUND_LAMBDA_REF; + } + if (StringUtil::Equals(value, "BOUND_EXPRESSION")) { + return ExpressionClass::BOUND_EXPRESSION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExpressionType value) { + switch(value) { + case ExpressionType::INVALID: + return "INVALID"; + case ExpressionType::OPERATOR_CAST: + return "OPERATOR_CAST"; + case ExpressionType::OPERATOR_NOT: + return "OPERATOR_NOT"; + case ExpressionType::OPERATOR_IS_NULL: + return "OPERATOR_IS_NULL"; + case ExpressionType::OPERATOR_IS_NOT_NULL: + return "OPERATOR_IS_NOT_NULL"; + case ExpressionType::COMPARE_EQUAL: + return "COMPARE_EQUAL"; + case ExpressionType::COMPARE_NOTEQUAL: + return "COMPARE_NOTEQUAL"; + case ExpressionType::COMPARE_LESSTHAN: + return "COMPARE_LESSTHAN"; + case ExpressionType::COMPARE_GREATERTHAN: + return "COMPARE_GREATERTHAN"; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return "COMPARE_LESSTHANOREQUALTO"; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return "COMPARE_GREATERTHANOREQUALTO"; + case ExpressionType::COMPARE_IN: + return "COMPARE_IN"; + case ExpressionType::COMPARE_NOT_IN: + return "COMPARE_NOT_IN"; + case ExpressionType::COMPARE_DISTINCT_FROM: + return "COMPARE_DISTINCT_FROM"; + case ExpressionType::COMPARE_BETWEEN: + return "COMPARE_BETWEEN"; + case ExpressionType::COMPARE_NOT_BETWEEN: + return "COMPARE_NOT_BETWEEN"; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return "COMPARE_NOT_DISTINCT_FROM"; + case ExpressionType::CONJUNCTION_AND: + return "CONJUNCTION_AND"; + case ExpressionType::CONJUNCTION_OR: + return "CONJUNCTION_OR"; + case ExpressionType::VALUE_CONSTANT: + return "VALUE_CONSTANT"; + case ExpressionType::VALUE_PARAMETER: + return "VALUE_PARAMETER"; + case ExpressionType::VALUE_TUPLE: + return "VALUE_TUPLE"; + case ExpressionType::VALUE_TUPLE_ADDRESS: + return "VALUE_TUPLE_ADDRESS"; + case ExpressionType::VALUE_NULL: + return "VALUE_NULL"; + case ExpressionType::VALUE_VECTOR: + return "VALUE_VECTOR"; + case ExpressionType::VALUE_SCALAR: + return "VALUE_SCALAR"; + case ExpressionType::VALUE_DEFAULT: + return "VALUE_DEFAULT"; + case ExpressionType::AGGREGATE: + return "AGGREGATE"; + case ExpressionType::BOUND_AGGREGATE: + return "BOUND_AGGREGATE"; + case ExpressionType::GROUPING_FUNCTION: + return "GROUPING_FUNCTION"; + case ExpressionType::WINDOW_AGGREGATE: + return "WINDOW_AGGREGATE"; + case ExpressionType::WINDOW_RANK: + return "WINDOW_RANK"; + case ExpressionType::WINDOW_RANK_DENSE: + return "WINDOW_RANK_DENSE"; + case ExpressionType::WINDOW_NTILE: + return "WINDOW_NTILE"; + case ExpressionType::WINDOW_PERCENT_RANK: + return "WINDOW_PERCENT_RANK"; + case ExpressionType::WINDOW_CUME_DIST: + return "WINDOW_CUME_DIST"; + case ExpressionType::WINDOW_ROW_NUMBER: + return "WINDOW_ROW_NUMBER"; + case ExpressionType::WINDOW_FIRST_VALUE: + return "WINDOW_FIRST_VALUE"; + case ExpressionType::WINDOW_LAST_VALUE: + return "WINDOW_LAST_VALUE"; + case ExpressionType::WINDOW_LEAD: + return "WINDOW_LEAD"; + case ExpressionType::WINDOW_LAG: + return "WINDOW_LAG"; + case ExpressionType::WINDOW_NTH_VALUE: + return "WINDOW_NTH_VALUE"; + case ExpressionType::FUNCTION: + return "FUNCTION"; + case ExpressionType::BOUND_FUNCTION: + return "BOUND_FUNCTION"; + case ExpressionType::CASE_EXPR: + return "CASE_EXPR"; + case ExpressionType::OPERATOR_NULLIF: + return "OPERATOR_NULLIF"; + case ExpressionType::OPERATOR_COALESCE: + return "OPERATOR_COALESCE"; + case ExpressionType::ARRAY_EXTRACT: + return "ARRAY_EXTRACT"; + case ExpressionType::ARRAY_SLICE: + return "ARRAY_SLICE"; + case ExpressionType::STRUCT_EXTRACT: + return "STRUCT_EXTRACT"; + case ExpressionType::ARRAY_CONSTRUCTOR: + return "ARRAY_CONSTRUCTOR"; + case ExpressionType::ARROW: + return "ARROW"; + case ExpressionType::SUBQUERY: + return "SUBQUERY"; + case ExpressionType::STAR: + return "STAR"; + case ExpressionType::TABLE_STAR: + return "TABLE_STAR"; + case ExpressionType::PLACEHOLDER: + return "PLACEHOLDER"; + case ExpressionType::COLUMN_REF: + return "COLUMN_REF"; + case ExpressionType::FUNCTION_REF: + return "FUNCTION_REF"; + case ExpressionType::TABLE_REF: + return "TABLE_REF"; + case ExpressionType::CAST: + return "CAST"; + case ExpressionType::BOUND_REF: + return "BOUND_REF"; + case ExpressionType::BOUND_COLUMN_REF: + return "BOUND_COLUMN_REF"; + case ExpressionType::BOUND_UNNEST: + return "BOUND_UNNEST"; + case ExpressionType::COLLATE: + return "COLLATE"; + case ExpressionType::LAMBDA: + return "LAMBDA"; + case ExpressionType::POSITIONAL_REFERENCE: + return "POSITIONAL_REFERENCE"; + case ExpressionType::BOUND_LAMBDA_REF: + return "BOUND_LAMBDA_REF"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExpressionType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return ExpressionType::INVALID; + } + if (StringUtil::Equals(value, "OPERATOR_CAST")) { + return ExpressionType::OPERATOR_CAST; + } + if (StringUtil::Equals(value, "OPERATOR_NOT")) { + return ExpressionType::OPERATOR_NOT; + } + if (StringUtil::Equals(value, "OPERATOR_IS_NULL")) { + return ExpressionType::OPERATOR_IS_NULL; + } + if (StringUtil::Equals(value, "OPERATOR_IS_NOT_NULL")) { + return ExpressionType::OPERATOR_IS_NOT_NULL; + } + if (StringUtil::Equals(value, "COMPARE_EQUAL")) { + return ExpressionType::COMPARE_EQUAL; + } + if (StringUtil::Equals(value, "COMPARE_NOTEQUAL")) { + return ExpressionType::COMPARE_NOTEQUAL; + } + if (StringUtil::Equals(value, "COMPARE_LESSTHAN")) { + return ExpressionType::COMPARE_LESSTHAN; + } + if (StringUtil::Equals(value, "COMPARE_GREATERTHAN")) { + return ExpressionType::COMPARE_GREATERTHAN; + } + if (StringUtil::Equals(value, "COMPARE_LESSTHANOREQUALTO")) { + return ExpressionType::COMPARE_LESSTHANOREQUALTO; + } + if (StringUtil::Equals(value, "COMPARE_GREATERTHANOREQUALTO")) { + return ExpressionType::COMPARE_GREATERTHANOREQUALTO; + } + if (StringUtil::Equals(value, "COMPARE_IN")) { + return ExpressionType::COMPARE_IN; + } + if (StringUtil::Equals(value, "COMPARE_NOT_IN")) { + return ExpressionType::COMPARE_NOT_IN; + } + if (StringUtil::Equals(value, "COMPARE_DISTINCT_FROM")) { + return ExpressionType::COMPARE_DISTINCT_FROM; + } + if (StringUtil::Equals(value, "COMPARE_BETWEEN")) { + return ExpressionType::COMPARE_BETWEEN; + } + if (StringUtil::Equals(value, "COMPARE_NOT_BETWEEN")) { + return ExpressionType::COMPARE_NOT_BETWEEN; + } + if (StringUtil::Equals(value, "COMPARE_NOT_DISTINCT_FROM")) { + return ExpressionType::COMPARE_NOT_DISTINCT_FROM; + } + if (StringUtil::Equals(value, "CONJUNCTION_AND")) { + return ExpressionType::CONJUNCTION_AND; + } + if (StringUtil::Equals(value, "CONJUNCTION_OR")) { + return ExpressionType::CONJUNCTION_OR; + } + if (StringUtil::Equals(value, "VALUE_CONSTANT")) { + return ExpressionType::VALUE_CONSTANT; + } + if (StringUtil::Equals(value, "VALUE_PARAMETER")) { + return ExpressionType::VALUE_PARAMETER; + } + if (StringUtil::Equals(value, "VALUE_TUPLE")) { + return ExpressionType::VALUE_TUPLE; + } + if (StringUtil::Equals(value, "VALUE_TUPLE_ADDRESS")) { + return ExpressionType::VALUE_TUPLE_ADDRESS; + } + if (StringUtil::Equals(value, "VALUE_NULL")) { + return ExpressionType::VALUE_NULL; + } + if (StringUtil::Equals(value, "VALUE_VECTOR")) { + return ExpressionType::VALUE_VECTOR; + } + if (StringUtil::Equals(value, "VALUE_SCALAR")) { + return ExpressionType::VALUE_SCALAR; + } + if (StringUtil::Equals(value, "VALUE_DEFAULT")) { + return ExpressionType::VALUE_DEFAULT; + } + if (StringUtil::Equals(value, "AGGREGATE")) { + return ExpressionType::AGGREGATE; + } + if (StringUtil::Equals(value, "BOUND_AGGREGATE")) { + return ExpressionType::BOUND_AGGREGATE; + } + if (StringUtil::Equals(value, "GROUPING_FUNCTION")) { + return ExpressionType::GROUPING_FUNCTION; + } + if (StringUtil::Equals(value, "WINDOW_AGGREGATE")) { + return ExpressionType::WINDOW_AGGREGATE; + } + if (StringUtil::Equals(value, "WINDOW_RANK")) { + return ExpressionType::WINDOW_RANK; + } + if (StringUtil::Equals(value, "WINDOW_RANK_DENSE")) { + return ExpressionType::WINDOW_RANK_DENSE; + } + if (StringUtil::Equals(value, "WINDOW_NTILE")) { + return ExpressionType::WINDOW_NTILE; + } + if (StringUtil::Equals(value, "WINDOW_PERCENT_RANK")) { + return ExpressionType::WINDOW_PERCENT_RANK; + } + if (StringUtil::Equals(value, "WINDOW_CUME_DIST")) { + return ExpressionType::WINDOW_CUME_DIST; + } + if (StringUtil::Equals(value, "WINDOW_ROW_NUMBER")) { + return ExpressionType::WINDOW_ROW_NUMBER; + } + if (StringUtil::Equals(value, "WINDOW_FIRST_VALUE")) { + return ExpressionType::WINDOW_FIRST_VALUE; + } + if (StringUtil::Equals(value, "WINDOW_LAST_VALUE")) { + return ExpressionType::WINDOW_LAST_VALUE; + } + if (StringUtil::Equals(value, "WINDOW_LEAD")) { + return ExpressionType::WINDOW_LEAD; + } + if (StringUtil::Equals(value, "WINDOW_LAG")) { + return ExpressionType::WINDOW_LAG; + } + if (StringUtil::Equals(value, "WINDOW_NTH_VALUE")) { + return ExpressionType::WINDOW_NTH_VALUE; + } + if (StringUtil::Equals(value, "FUNCTION")) { + return ExpressionType::FUNCTION; + } + if (StringUtil::Equals(value, "BOUND_FUNCTION")) { + return ExpressionType::BOUND_FUNCTION; + } + if (StringUtil::Equals(value, "CASE_EXPR")) { + return ExpressionType::CASE_EXPR; + } + if (StringUtil::Equals(value, "OPERATOR_NULLIF")) { + return ExpressionType::OPERATOR_NULLIF; + } + if (StringUtil::Equals(value, "OPERATOR_COALESCE")) { + return ExpressionType::OPERATOR_COALESCE; + } + if (StringUtil::Equals(value, "ARRAY_EXTRACT")) { + return ExpressionType::ARRAY_EXTRACT; + } + if (StringUtil::Equals(value, "ARRAY_SLICE")) { + return ExpressionType::ARRAY_SLICE; + } + if (StringUtil::Equals(value, "STRUCT_EXTRACT")) { + return ExpressionType::STRUCT_EXTRACT; + } + if (StringUtil::Equals(value, "ARRAY_CONSTRUCTOR")) { + return ExpressionType::ARRAY_CONSTRUCTOR; + } + if (StringUtil::Equals(value, "ARROW")) { + return ExpressionType::ARROW; + } + if (StringUtil::Equals(value, "SUBQUERY")) { + return ExpressionType::SUBQUERY; + } + if (StringUtil::Equals(value, "STAR")) { + return ExpressionType::STAR; + } + if (StringUtil::Equals(value, "TABLE_STAR")) { + return ExpressionType::TABLE_STAR; + } + if (StringUtil::Equals(value, "PLACEHOLDER")) { + return ExpressionType::PLACEHOLDER; + } + if (StringUtil::Equals(value, "COLUMN_REF")) { + return ExpressionType::COLUMN_REF; + } + if (StringUtil::Equals(value, "FUNCTION_REF")) { + return ExpressionType::FUNCTION_REF; + } + if (StringUtil::Equals(value, "TABLE_REF")) { + return ExpressionType::TABLE_REF; + } + if (StringUtil::Equals(value, "CAST")) { + return ExpressionType::CAST; + } + if (StringUtil::Equals(value, "BOUND_REF")) { + return ExpressionType::BOUND_REF; + } + if (StringUtil::Equals(value, "BOUND_COLUMN_REF")) { + return ExpressionType::BOUND_COLUMN_REF; + } + if (StringUtil::Equals(value, "BOUND_UNNEST")) { + return ExpressionType::BOUND_UNNEST; + } + if (StringUtil::Equals(value, "COLLATE")) { + return ExpressionType::COLLATE; + } + if (StringUtil::Equals(value, "LAMBDA")) { + return ExpressionType::LAMBDA; + } + if (StringUtil::Equals(value, "POSITIONAL_REFERENCE")) { + return ExpressionType::POSITIONAL_REFERENCE; + } + if (StringUtil::Equals(value, "BOUND_LAMBDA_REF")) { + return ExpressionType::BOUND_LAMBDA_REF; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExtensionLoadResult value) { + switch(value) { + case ExtensionLoadResult::LOADED_EXTENSION: + return "LOADED_EXTENSION"; + case ExtensionLoadResult::EXTENSION_UNKNOWN: + return "EXTENSION_UNKNOWN"; + case ExtensionLoadResult::NOT_LOADED: + return "NOT_LOADED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExtensionLoadResult EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "LOADED_EXTENSION")) { + return ExtensionLoadResult::LOADED_EXTENSION; + } + if (StringUtil::Equals(value, "EXTENSION_UNKNOWN")) { + return ExtensionLoadResult::EXTENSION_UNKNOWN; + } + if (StringUtil::Equals(value, "NOT_LOADED")) { + return ExtensionLoadResult::NOT_LOADED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExtraTypeInfoType value) { + switch(value) { + case ExtraTypeInfoType::INVALID_TYPE_INFO: + return "INVALID_TYPE_INFO"; + case ExtraTypeInfoType::GENERIC_TYPE_INFO: + return "GENERIC_TYPE_INFO"; + case ExtraTypeInfoType::DECIMAL_TYPE_INFO: + return "DECIMAL_TYPE_INFO"; + case ExtraTypeInfoType::STRING_TYPE_INFO: + return "STRING_TYPE_INFO"; + case ExtraTypeInfoType::LIST_TYPE_INFO: + return "LIST_TYPE_INFO"; + case ExtraTypeInfoType::STRUCT_TYPE_INFO: + return "STRUCT_TYPE_INFO"; + case ExtraTypeInfoType::ENUM_TYPE_INFO: + return "ENUM_TYPE_INFO"; + case ExtraTypeInfoType::USER_TYPE_INFO: + return "USER_TYPE_INFO"; + case ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO: + return "AGGREGATE_STATE_TYPE_INFO"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExtraTypeInfoType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID_TYPE_INFO")) { + return ExtraTypeInfoType::INVALID_TYPE_INFO; + } + if (StringUtil::Equals(value, "GENERIC_TYPE_INFO")) { + return ExtraTypeInfoType::GENERIC_TYPE_INFO; + } + if (StringUtil::Equals(value, "DECIMAL_TYPE_INFO")) { + return ExtraTypeInfoType::DECIMAL_TYPE_INFO; + } + if (StringUtil::Equals(value, "STRING_TYPE_INFO")) { + return ExtraTypeInfoType::STRING_TYPE_INFO; + } + if (StringUtil::Equals(value, "LIST_TYPE_INFO")) { + return ExtraTypeInfoType::LIST_TYPE_INFO; + } + if (StringUtil::Equals(value, "STRUCT_TYPE_INFO")) { + return ExtraTypeInfoType::STRUCT_TYPE_INFO; + } + if (StringUtil::Equals(value, "ENUM_TYPE_INFO")) { + return ExtraTypeInfoType::ENUM_TYPE_INFO; + } + if (StringUtil::Equals(value, "USER_TYPE_INFO")) { + return ExtraTypeInfoType::USER_TYPE_INFO; + } + if (StringUtil::Equals(value, "AGGREGATE_STATE_TYPE_INFO")) { + return ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(FileBufferType value) { + switch(value) { + case FileBufferType::BLOCK: + return "BLOCK"; + case FileBufferType::MANAGED_BUFFER: + return "MANAGED_BUFFER"; + case FileBufferType::TINY_BUFFER: + return "TINY_BUFFER"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +FileBufferType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "BLOCK")) { + return FileBufferType::BLOCK; + } + if (StringUtil::Equals(value, "MANAGED_BUFFER")) { + return FileBufferType::MANAGED_BUFFER; + } + if (StringUtil::Equals(value, "TINY_BUFFER")) { + return FileBufferType::TINY_BUFFER; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(FileCompressionType value) { + switch(value) { + case FileCompressionType::AUTO_DETECT: + return "AUTO_DETECT"; + case FileCompressionType::UNCOMPRESSED: + return "UNCOMPRESSED"; + case FileCompressionType::GZIP: + return "GZIP"; + case FileCompressionType::ZSTD: + return "ZSTD"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +FileCompressionType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "AUTO_DETECT")) { + return FileCompressionType::AUTO_DETECT; + } + if (StringUtil::Equals(value, "UNCOMPRESSED")) { + return FileCompressionType::UNCOMPRESSED; + } + if (StringUtil::Equals(value, "GZIP")) { + return FileCompressionType::GZIP; + } + if (StringUtil::Equals(value, "ZSTD")) { + return FileCompressionType::ZSTD; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(FileGlobOptions value) { + switch(value) { + case FileGlobOptions::DISALLOW_EMPTY: + return "DISALLOW_EMPTY"; + case FileGlobOptions::ALLOW_EMPTY: + return "ALLOW_EMPTY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +FileGlobOptions EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "DISALLOW_EMPTY")) { + return FileGlobOptions::DISALLOW_EMPTY; + } + if (StringUtil::Equals(value, "ALLOW_EMPTY")) { + return FileGlobOptions::ALLOW_EMPTY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(FileLockType value) { + switch(value) { + case FileLockType::NO_LOCK: + return "NO_LOCK"; + case FileLockType::READ_LOCK: + return "READ_LOCK"; + case FileLockType::WRITE_LOCK: + return "WRITE_LOCK"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +FileLockType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NO_LOCK")) { + return FileLockType::NO_LOCK; + } + if (StringUtil::Equals(value, "READ_LOCK")) { + return FileLockType::READ_LOCK; + } + if (StringUtil::Equals(value, "WRITE_LOCK")) { + return FileLockType::WRITE_LOCK; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(FilterPropagateResult value) { + switch(value) { + case FilterPropagateResult::NO_PRUNING_POSSIBLE: + return "NO_PRUNING_POSSIBLE"; + case FilterPropagateResult::FILTER_ALWAYS_TRUE: + return "FILTER_ALWAYS_TRUE"; + case FilterPropagateResult::FILTER_ALWAYS_FALSE: + return "FILTER_ALWAYS_FALSE"; + case FilterPropagateResult::FILTER_TRUE_OR_NULL: + return "FILTER_TRUE_OR_NULL"; + case FilterPropagateResult::FILTER_FALSE_OR_NULL: + return "FILTER_FALSE_OR_NULL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +FilterPropagateResult EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NO_PRUNING_POSSIBLE")) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + if (StringUtil::Equals(value, "FILTER_ALWAYS_TRUE")) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + if (StringUtil::Equals(value, "FILTER_ALWAYS_FALSE")) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + if (StringUtil::Equals(value, "FILTER_TRUE_OR_NULL")) { + return FilterPropagateResult::FILTER_TRUE_OR_NULL; + } + if (StringUtil::Equals(value, "FILTER_FALSE_OR_NULL")) { + return FilterPropagateResult::FILTER_FALSE_OR_NULL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ForeignKeyType value) { + switch(value) { + case ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE: + return "FK_TYPE_PRIMARY_KEY_TABLE"; + case ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE: + return "FK_TYPE_FOREIGN_KEY_TABLE"; + case ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE: + return "FK_TYPE_SELF_REFERENCE_TABLE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ForeignKeyType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "FK_TYPE_PRIMARY_KEY_TABLE")) { + return ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE; + } + if (StringUtil::Equals(value, "FK_TYPE_FOREIGN_KEY_TABLE")) { + return ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; + } + if (StringUtil::Equals(value, "FK_TYPE_SELF_REFERENCE_TABLE")) { + return ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(FunctionNullHandling value) { + switch(value) { + case FunctionNullHandling::DEFAULT_NULL_HANDLING: + return "DEFAULT_NULL_HANDLING"; + case FunctionNullHandling::SPECIAL_HANDLING: + return "SPECIAL_HANDLING"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +FunctionNullHandling EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "DEFAULT_NULL_HANDLING")) { + return FunctionNullHandling::DEFAULT_NULL_HANDLING; + } + if (StringUtil::Equals(value, "SPECIAL_HANDLING")) { + return FunctionNullHandling::SPECIAL_HANDLING; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(FunctionSideEffects value) { + switch(value) { + case FunctionSideEffects::NO_SIDE_EFFECTS: + return "NO_SIDE_EFFECTS"; + case FunctionSideEffects::HAS_SIDE_EFFECTS: + return "HAS_SIDE_EFFECTS"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +FunctionSideEffects EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NO_SIDE_EFFECTS")) { + return FunctionSideEffects::NO_SIDE_EFFECTS; + } + if (StringUtil::Equals(value, "HAS_SIDE_EFFECTS")) { + return FunctionSideEffects::HAS_SIDE_EFFECTS; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(HLLStorageType value) { + switch(value) { + case HLLStorageType::UNCOMPRESSED: + return "UNCOMPRESSED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +HLLStorageType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "UNCOMPRESSED")) { + return HLLStorageType::UNCOMPRESSED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(IndexConstraintType value) { + switch(value) { + case IndexConstraintType::NONE: + return "NONE"; + case IndexConstraintType::UNIQUE: + return "UNIQUE"; + case IndexConstraintType::PRIMARY: + return "PRIMARY"; + case IndexConstraintType::FOREIGN: + return "FOREIGN"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +IndexConstraintType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NONE")) { + return IndexConstraintType::NONE; + } + if (StringUtil::Equals(value, "UNIQUE")) { + return IndexConstraintType::UNIQUE; + } + if (StringUtil::Equals(value, "PRIMARY")) { + return IndexConstraintType::PRIMARY; + } + if (StringUtil::Equals(value, "FOREIGN")) { + return IndexConstraintType::FOREIGN; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(IndexType value) { + switch(value) { + case IndexType::INVALID: + return "INVALID"; + case IndexType::ART: + return "ART"; + case IndexType::EXTENSION: + return "EXTENSION"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +IndexType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return IndexType::INVALID; + } + if (StringUtil::Equals(value, "ART")) { + return IndexType::ART; + } + if (StringUtil::Equals(value, "EXTENSION")) { + return IndexType::EXTENSION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(InsertColumnOrder value) { + switch(value) { + case InsertColumnOrder::INSERT_BY_POSITION: + return "INSERT_BY_POSITION"; + case InsertColumnOrder::INSERT_BY_NAME: + return "INSERT_BY_NAME"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +InsertColumnOrder EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INSERT_BY_POSITION")) { + return InsertColumnOrder::INSERT_BY_POSITION; + } + if (StringUtil::Equals(value, "INSERT_BY_NAME")) { + return InsertColumnOrder::INSERT_BY_NAME; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(InterruptMode value) { + switch(value) { + case InterruptMode::NO_INTERRUPTS: + return "NO_INTERRUPTS"; + case InterruptMode::TASK: + return "TASK"; + case InterruptMode::BLOCKING: + return "BLOCKING"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +InterruptMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NO_INTERRUPTS")) { + return InterruptMode::NO_INTERRUPTS; + } + if (StringUtil::Equals(value, "TASK")) { + return InterruptMode::TASK; + } + if (StringUtil::Equals(value, "BLOCKING")) { + return InterruptMode::BLOCKING; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(JoinRefType value) { + switch(value) { + case JoinRefType::REGULAR: + return "REGULAR"; + case JoinRefType::NATURAL: + return "NATURAL"; + case JoinRefType::CROSS: + return "CROSS"; + case JoinRefType::POSITIONAL: + return "POSITIONAL"; + case JoinRefType::ASOF: + return "ASOF"; + case JoinRefType::DEPENDENT: + return "DEPENDENT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +JoinRefType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "REGULAR")) { + return JoinRefType::REGULAR; + } + if (StringUtil::Equals(value, "NATURAL")) { + return JoinRefType::NATURAL; + } + if (StringUtil::Equals(value, "CROSS")) { + return JoinRefType::CROSS; + } + if (StringUtil::Equals(value, "POSITIONAL")) { + return JoinRefType::POSITIONAL; + } + if (StringUtil::Equals(value, "ASOF")) { + return JoinRefType::ASOF; + } + if (StringUtil::Equals(value, "DEPENDENT")) { + return JoinRefType::DEPENDENT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(JoinType value) { + switch(value) { + case JoinType::INVALID: + return "INVALID"; + case JoinType::LEFT: + return "LEFT"; + case JoinType::RIGHT: + return "RIGHT"; + case JoinType::INNER: + return "INNER"; + case JoinType::OUTER: + return "FULL"; + case JoinType::SEMI: + return "SEMI"; + case JoinType::ANTI: + return "ANTI"; + case JoinType::MARK: + return "MARK"; + case JoinType::SINGLE: + return "SINGLE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +JoinType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return JoinType::INVALID; + } + if (StringUtil::Equals(value, "LEFT")) { + return JoinType::LEFT; + } + if (StringUtil::Equals(value, "RIGHT")) { + return JoinType::RIGHT; + } + if (StringUtil::Equals(value, "INNER")) { + return JoinType::INNER; + } + if (StringUtil::Equals(value, "FULL")) { + return JoinType::OUTER; + } + if (StringUtil::Equals(value, "SEMI")) { + return JoinType::SEMI; + } + if (StringUtil::Equals(value, "ANTI")) { + return JoinType::ANTI; + } + if (StringUtil::Equals(value, "MARK")) { + return JoinType::MARK; + } + if (StringUtil::Equals(value, "SINGLE")) { + return JoinType::SINGLE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(KeywordCategory value) { + switch(value) { + case KeywordCategory::KEYWORD_RESERVED: + return "KEYWORD_RESERVED"; + case KeywordCategory::KEYWORD_UNRESERVED: + return "KEYWORD_UNRESERVED"; + case KeywordCategory::KEYWORD_TYPE_FUNC: + return "KEYWORD_TYPE_FUNC"; + case KeywordCategory::KEYWORD_COL_NAME: + return "KEYWORD_COL_NAME"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +KeywordCategory EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "KEYWORD_RESERVED")) { + return KeywordCategory::KEYWORD_RESERVED; + } + if (StringUtil::Equals(value, "KEYWORD_UNRESERVED")) { + return KeywordCategory::KEYWORD_UNRESERVED; + } + if (StringUtil::Equals(value, "KEYWORD_TYPE_FUNC")) { + return KeywordCategory::KEYWORD_TYPE_FUNC; + } + if (StringUtil::Equals(value, "KEYWORD_COL_NAME")) { + return KeywordCategory::KEYWORD_COL_NAME; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(LoadType value) { + switch(value) { + case LoadType::LOAD: + return "LOAD"; + case LoadType::INSTALL: + return "INSTALL"; + case LoadType::FORCE_INSTALL: + return "FORCE_INSTALL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +LoadType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "LOAD")) { + return LoadType::LOAD; + } + if (StringUtil::Equals(value, "INSTALL")) { + return LoadType::INSTALL; + } + if (StringUtil::Equals(value, "FORCE_INSTALL")) { + return LoadType::FORCE_INSTALL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(LogicalOperatorType value) { + switch(value) { + case LogicalOperatorType::LOGICAL_INVALID: + return "LOGICAL_INVALID"; + case LogicalOperatorType::LOGICAL_PROJECTION: + return "LOGICAL_PROJECTION"; + case LogicalOperatorType::LOGICAL_FILTER: + return "LOGICAL_FILTER"; + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + return "LOGICAL_AGGREGATE_AND_GROUP_BY"; + case LogicalOperatorType::LOGICAL_WINDOW: + return "LOGICAL_WINDOW"; + case LogicalOperatorType::LOGICAL_UNNEST: + return "LOGICAL_UNNEST"; + case LogicalOperatorType::LOGICAL_LIMIT: + return "LOGICAL_LIMIT"; + case LogicalOperatorType::LOGICAL_ORDER_BY: + return "LOGICAL_ORDER_BY"; + case LogicalOperatorType::LOGICAL_TOP_N: + return "LOGICAL_TOP_N"; + case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + return "LOGICAL_COPY_TO_FILE"; + case LogicalOperatorType::LOGICAL_DISTINCT: + return "LOGICAL_DISTINCT"; + case LogicalOperatorType::LOGICAL_SAMPLE: + return "LOGICAL_SAMPLE"; + case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: + return "LOGICAL_LIMIT_PERCENT"; + case LogicalOperatorType::LOGICAL_PIVOT: + return "LOGICAL_PIVOT"; + case LogicalOperatorType::LOGICAL_GET: + return "LOGICAL_GET"; + case LogicalOperatorType::LOGICAL_CHUNK_GET: + return "LOGICAL_CHUNK_GET"; + case LogicalOperatorType::LOGICAL_DELIM_GET: + return "LOGICAL_DELIM_GET"; + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + return "LOGICAL_EXPRESSION_GET"; + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + return "LOGICAL_DUMMY_SCAN"; + case LogicalOperatorType::LOGICAL_EMPTY_RESULT: + return "LOGICAL_EMPTY_RESULT"; + case LogicalOperatorType::LOGICAL_CTE_REF: + return "LOGICAL_CTE_REF"; + case LogicalOperatorType::LOGICAL_JOIN: + return "LOGICAL_JOIN"; + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + return "LOGICAL_DELIM_JOIN"; + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + return "LOGICAL_COMPARISON_JOIN"; + case LogicalOperatorType::LOGICAL_ANY_JOIN: + return "LOGICAL_ANY_JOIN"; + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + return "LOGICAL_CROSS_PRODUCT"; + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + return "LOGICAL_POSITIONAL_JOIN"; + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + return "LOGICAL_ASOF_JOIN"; + case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + return "LOGICAL_DEPENDENT_JOIN"; + case LogicalOperatorType::LOGICAL_UNION: + return "LOGICAL_UNION"; + case LogicalOperatorType::LOGICAL_EXCEPT: + return "LOGICAL_EXCEPT"; + case LogicalOperatorType::LOGICAL_INTERSECT: + return "LOGICAL_INTERSECT"; + case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: + return "LOGICAL_RECURSIVE_CTE"; + case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: + return "LOGICAL_MATERIALIZED_CTE"; + case LogicalOperatorType::LOGICAL_INSERT: + return "LOGICAL_INSERT"; + case LogicalOperatorType::LOGICAL_DELETE: + return "LOGICAL_DELETE"; + case LogicalOperatorType::LOGICAL_UPDATE: + return "LOGICAL_UPDATE"; + case LogicalOperatorType::LOGICAL_ALTER: + return "LOGICAL_ALTER"; + case LogicalOperatorType::LOGICAL_CREATE_TABLE: + return "LOGICAL_CREATE_TABLE"; + case LogicalOperatorType::LOGICAL_CREATE_INDEX: + return "LOGICAL_CREATE_INDEX"; + case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: + return "LOGICAL_CREATE_SEQUENCE"; + case LogicalOperatorType::LOGICAL_CREATE_VIEW: + return "LOGICAL_CREATE_VIEW"; + case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: + return "LOGICAL_CREATE_SCHEMA"; + case LogicalOperatorType::LOGICAL_CREATE_MACRO: + return "LOGICAL_CREATE_MACRO"; + case LogicalOperatorType::LOGICAL_DROP: + return "LOGICAL_DROP"; + case LogicalOperatorType::LOGICAL_PRAGMA: + return "LOGICAL_PRAGMA"; + case LogicalOperatorType::LOGICAL_TRANSACTION: + return "LOGICAL_TRANSACTION"; + case LogicalOperatorType::LOGICAL_CREATE_TYPE: + return "LOGICAL_CREATE_TYPE"; + case LogicalOperatorType::LOGICAL_ATTACH: + return "LOGICAL_ATTACH"; + case LogicalOperatorType::LOGICAL_DETACH: + return "LOGICAL_DETACH"; + case LogicalOperatorType::LOGICAL_EXPLAIN: + return "LOGICAL_EXPLAIN"; + case LogicalOperatorType::LOGICAL_SHOW: + return "LOGICAL_SHOW"; + case LogicalOperatorType::LOGICAL_PREPARE: + return "LOGICAL_PREPARE"; + case LogicalOperatorType::LOGICAL_EXECUTE: + return "LOGICAL_EXECUTE"; + case LogicalOperatorType::LOGICAL_EXPORT: + return "LOGICAL_EXPORT"; + case LogicalOperatorType::LOGICAL_VACUUM: + return "LOGICAL_VACUUM"; + case LogicalOperatorType::LOGICAL_SET: + return "LOGICAL_SET"; + case LogicalOperatorType::LOGICAL_LOAD: + return "LOGICAL_LOAD"; + case LogicalOperatorType::LOGICAL_RESET: + return "LOGICAL_RESET"; + case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: + return "LOGICAL_EXTENSION_OPERATOR"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +LogicalOperatorType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "LOGICAL_INVALID")) { + return LogicalOperatorType::LOGICAL_INVALID; + } + if (StringUtil::Equals(value, "LOGICAL_PROJECTION")) { + return LogicalOperatorType::LOGICAL_PROJECTION; + } + if (StringUtil::Equals(value, "LOGICAL_FILTER")) { + return LogicalOperatorType::LOGICAL_FILTER; + } + if (StringUtil::Equals(value, "LOGICAL_AGGREGATE_AND_GROUP_BY")) { + return LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY; + } + if (StringUtil::Equals(value, "LOGICAL_WINDOW")) { + return LogicalOperatorType::LOGICAL_WINDOW; + } + if (StringUtil::Equals(value, "LOGICAL_UNNEST")) { + return LogicalOperatorType::LOGICAL_UNNEST; + } + if (StringUtil::Equals(value, "LOGICAL_LIMIT")) { + return LogicalOperatorType::LOGICAL_LIMIT; + } + if (StringUtil::Equals(value, "LOGICAL_ORDER_BY")) { + return LogicalOperatorType::LOGICAL_ORDER_BY; + } + if (StringUtil::Equals(value, "LOGICAL_TOP_N")) { + return LogicalOperatorType::LOGICAL_TOP_N; + } + if (StringUtil::Equals(value, "LOGICAL_COPY_TO_FILE")) { + return LogicalOperatorType::LOGICAL_COPY_TO_FILE; + } + if (StringUtil::Equals(value, "LOGICAL_DISTINCT")) { + return LogicalOperatorType::LOGICAL_DISTINCT; + } + if (StringUtil::Equals(value, "LOGICAL_SAMPLE")) { + return LogicalOperatorType::LOGICAL_SAMPLE; + } + if (StringUtil::Equals(value, "LOGICAL_LIMIT_PERCENT")) { + return LogicalOperatorType::LOGICAL_LIMIT_PERCENT; + } + if (StringUtil::Equals(value, "LOGICAL_PIVOT")) { + return LogicalOperatorType::LOGICAL_PIVOT; + } + if (StringUtil::Equals(value, "LOGICAL_GET")) { + return LogicalOperatorType::LOGICAL_GET; + } + if (StringUtil::Equals(value, "LOGICAL_CHUNK_GET")) { + return LogicalOperatorType::LOGICAL_CHUNK_GET; + } + if (StringUtil::Equals(value, "LOGICAL_DELIM_GET")) { + return LogicalOperatorType::LOGICAL_DELIM_GET; + } + if (StringUtil::Equals(value, "LOGICAL_EXPRESSION_GET")) { + return LogicalOperatorType::LOGICAL_EXPRESSION_GET; + } + if (StringUtil::Equals(value, "LOGICAL_DUMMY_SCAN")) { + return LogicalOperatorType::LOGICAL_DUMMY_SCAN; + } + if (StringUtil::Equals(value, "LOGICAL_EMPTY_RESULT")) { + return LogicalOperatorType::LOGICAL_EMPTY_RESULT; + } + if (StringUtil::Equals(value, "LOGICAL_CTE_REF")) { + return LogicalOperatorType::LOGICAL_CTE_REF; + } + if (StringUtil::Equals(value, "LOGICAL_JOIN")) { + return LogicalOperatorType::LOGICAL_JOIN; + } + if (StringUtil::Equals(value, "LOGICAL_DELIM_JOIN")) { + return LogicalOperatorType::LOGICAL_DELIM_JOIN; + } + if (StringUtil::Equals(value, "LOGICAL_COMPARISON_JOIN")) { + return LogicalOperatorType::LOGICAL_COMPARISON_JOIN; + } + if (StringUtil::Equals(value, "LOGICAL_ANY_JOIN")) { + return LogicalOperatorType::LOGICAL_ANY_JOIN; + } + if (StringUtil::Equals(value, "LOGICAL_CROSS_PRODUCT")) { + return LogicalOperatorType::LOGICAL_CROSS_PRODUCT; + } + if (StringUtil::Equals(value, "LOGICAL_POSITIONAL_JOIN")) { + return LogicalOperatorType::LOGICAL_POSITIONAL_JOIN; + } + if (StringUtil::Equals(value, "LOGICAL_ASOF_JOIN")) { + return LogicalOperatorType::LOGICAL_ASOF_JOIN; + } + if (StringUtil::Equals(value, "LOGICAL_DEPENDENT_JOIN")) { + return LogicalOperatorType::LOGICAL_DEPENDENT_JOIN; + } + if (StringUtil::Equals(value, "LOGICAL_UNION")) { + return LogicalOperatorType::LOGICAL_UNION; + } + if (StringUtil::Equals(value, "LOGICAL_EXCEPT")) { + return LogicalOperatorType::LOGICAL_EXCEPT; + } + if (StringUtil::Equals(value, "LOGICAL_INTERSECT")) { + return LogicalOperatorType::LOGICAL_INTERSECT; + } + if (StringUtil::Equals(value, "LOGICAL_RECURSIVE_CTE")) { + return LogicalOperatorType::LOGICAL_RECURSIVE_CTE; + } + if (StringUtil::Equals(value, "LOGICAL_MATERIALIZED_CTE")) { + return LogicalOperatorType::LOGICAL_MATERIALIZED_CTE; + } + if (StringUtil::Equals(value, "LOGICAL_INSERT")) { + return LogicalOperatorType::LOGICAL_INSERT; + } + if (StringUtil::Equals(value, "LOGICAL_DELETE")) { + return LogicalOperatorType::LOGICAL_DELETE; + } + if (StringUtil::Equals(value, "LOGICAL_UPDATE")) { + return LogicalOperatorType::LOGICAL_UPDATE; + } + if (StringUtil::Equals(value, "LOGICAL_ALTER")) { + return LogicalOperatorType::LOGICAL_ALTER; + } + if (StringUtil::Equals(value, "LOGICAL_CREATE_TABLE")) { + return LogicalOperatorType::LOGICAL_CREATE_TABLE; + } + if (StringUtil::Equals(value, "LOGICAL_CREATE_INDEX")) { + return LogicalOperatorType::LOGICAL_CREATE_INDEX; + } + if (StringUtil::Equals(value, "LOGICAL_CREATE_SEQUENCE")) { + return LogicalOperatorType::LOGICAL_CREATE_SEQUENCE; + } + if (StringUtil::Equals(value, "LOGICAL_CREATE_VIEW")) { + return LogicalOperatorType::LOGICAL_CREATE_VIEW; + } + if (StringUtil::Equals(value, "LOGICAL_CREATE_SCHEMA")) { + return LogicalOperatorType::LOGICAL_CREATE_SCHEMA; + } + if (StringUtil::Equals(value, "LOGICAL_CREATE_MACRO")) { + return LogicalOperatorType::LOGICAL_CREATE_MACRO; + } + if (StringUtil::Equals(value, "LOGICAL_DROP")) { + return LogicalOperatorType::LOGICAL_DROP; + } + if (StringUtil::Equals(value, "LOGICAL_PRAGMA")) { + return LogicalOperatorType::LOGICAL_PRAGMA; + } + if (StringUtil::Equals(value, "LOGICAL_TRANSACTION")) { + return LogicalOperatorType::LOGICAL_TRANSACTION; + } + if (StringUtil::Equals(value, "LOGICAL_CREATE_TYPE")) { + return LogicalOperatorType::LOGICAL_CREATE_TYPE; + } + if (StringUtil::Equals(value, "LOGICAL_ATTACH")) { + return LogicalOperatorType::LOGICAL_ATTACH; + } + if (StringUtil::Equals(value, "LOGICAL_DETACH")) { + return LogicalOperatorType::LOGICAL_DETACH; + } + if (StringUtil::Equals(value, "LOGICAL_EXPLAIN")) { + return LogicalOperatorType::LOGICAL_EXPLAIN; + } + if (StringUtil::Equals(value, "LOGICAL_SHOW")) { + return LogicalOperatorType::LOGICAL_SHOW; + } + if (StringUtil::Equals(value, "LOGICAL_PREPARE")) { + return LogicalOperatorType::LOGICAL_PREPARE; + } + if (StringUtil::Equals(value, "LOGICAL_EXECUTE")) { + return LogicalOperatorType::LOGICAL_EXECUTE; + } + if (StringUtil::Equals(value, "LOGICAL_EXPORT")) { + return LogicalOperatorType::LOGICAL_EXPORT; + } + if (StringUtil::Equals(value, "LOGICAL_VACUUM")) { + return LogicalOperatorType::LOGICAL_VACUUM; + } + if (StringUtil::Equals(value, "LOGICAL_SET")) { + return LogicalOperatorType::LOGICAL_SET; + } + if (StringUtil::Equals(value, "LOGICAL_LOAD")) { + return LogicalOperatorType::LOGICAL_LOAD; + } + if (StringUtil::Equals(value, "LOGICAL_RESET")) { + return LogicalOperatorType::LOGICAL_RESET; + } + if (StringUtil::Equals(value, "LOGICAL_EXTENSION_OPERATOR")) { + return LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(LogicalTypeId value) { + switch(value) { + case LogicalTypeId::INVALID: + return "INVALID"; + case LogicalTypeId::SQLNULL: + return "NULL"; + case LogicalTypeId::UNKNOWN: + return "UNKNOWN"; + case LogicalTypeId::ANY: + return "ANY"; + case LogicalTypeId::USER: + return "USER"; + case LogicalTypeId::BOOLEAN: + return "BOOLEAN"; + case LogicalTypeId::TINYINT: + return "TINYINT"; + case LogicalTypeId::SMALLINT: + return "SMALLINT"; + case LogicalTypeId::INTEGER: + return "INTEGER"; + case LogicalTypeId::BIGINT: + return "BIGINT"; + case LogicalTypeId::DATE: + return "DATE"; + case LogicalTypeId::TIME: + return "TIME"; + case LogicalTypeId::TIMESTAMP_SEC: + return "TIMESTAMP_S"; + case LogicalTypeId::TIMESTAMP_MS: + return "TIMESTAMP_MS"; + case LogicalTypeId::TIMESTAMP: + return "TIMESTAMP"; + case LogicalTypeId::TIMESTAMP_NS: + return "TIMESTAMP_NS"; + case LogicalTypeId::DECIMAL: + return "DECIMAL"; + case LogicalTypeId::FLOAT: + return "FLOAT"; + case LogicalTypeId::DOUBLE: + return "DOUBLE"; + case LogicalTypeId::CHAR: + return "CHAR"; + case LogicalTypeId::VARCHAR: + return "VARCHAR"; + case LogicalTypeId::BLOB: + return "BLOB"; + case LogicalTypeId::INTERVAL: + return "INTERVAL"; + case LogicalTypeId::UTINYINT: + return "UTINYINT"; + case LogicalTypeId::USMALLINT: + return "USMALLINT"; + case LogicalTypeId::UINTEGER: + return "UINTEGER"; + case LogicalTypeId::UBIGINT: + return "UBIGINT"; + case LogicalTypeId::TIMESTAMP_TZ: + return "TIMESTAMP WITH TIME ZONE"; + case LogicalTypeId::TIME_TZ: + return "TIME WITH TIME ZONE"; + case LogicalTypeId::BIT: + return "BIT"; + case LogicalTypeId::HUGEINT: + return "HUGEINT"; + case LogicalTypeId::POINTER: + return "POINTER"; + case LogicalTypeId::VALIDITY: + return "VALIDITY"; + case LogicalTypeId::UUID: + return "UUID"; + case LogicalTypeId::STRUCT: + return "STRUCT"; + case LogicalTypeId::LIST: + return "LIST"; + case LogicalTypeId::MAP: + return "MAP"; + case LogicalTypeId::TABLE: + return "TABLE"; + case LogicalTypeId::ENUM: + return "ENUM"; + case LogicalTypeId::AGGREGATE_STATE: + return "AGGREGATE_STATE"; + case LogicalTypeId::LAMBDA: + return "LAMBDA"; + case LogicalTypeId::UNION: + return "UNION"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +LogicalTypeId EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return LogicalTypeId::INVALID; + } + if (StringUtil::Equals(value, "NULL")) { + return LogicalTypeId::SQLNULL; + } + if (StringUtil::Equals(value, "UNKNOWN")) { + return LogicalTypeId::UNKNOWN; + } + if (StringUtil::Equals(value, "ANY")) { + return LogicalTypeId::ANY; + } + if (StringUtil::Equals(value, "USER")) { + return LogicalTypeId::USER; + } + if (StringUtil::Equals(value, "BOOLEAN")) { + return LogicalTypeId::BOOLEAN; + } + if (StringUtil::Equals(value, "TINYINT")) { + return LogicalTypeId::TINYINT; + } + if (StringUtil::Equals(value, "SMALLINT")) { + return LogicalTypeId::SMALLINT; + } + if (StringUtil::Equals(value, "INTEGER")) { + return LogicalTypeId::INTEGER; + } + if (StringUtil::Equals(value, "BIGINT")) { + return LogicalTypeId::BIGINT; + } + if (StringUtil::Equals(value, "DATE")) { + return LogicalTypeId::DATE; + } + if (StringUtil::Equals(value, "TIME")) { + return LogicalTypeId::TIME; + } + if (StringUtil::Equals(value, "TIMESTAMP_S")) { + return LogicalTypeId::TIMESTAMP_SEC; + } + if (StringUtil::Equals(value, "TIMESTAMP_MS")) { + return LogicalTypeId::TIMESTAMP_MS; + } + if (StringUtil::Equals(value, "TIMESTAMP")) { + return LogicalTypeId::TIMESTAMP; + } + if (StringUtil::Equals(value, "TIMESTAMP_NS")) { + return LogicalTypeId::TIMESTAMP_NS; + } + if (StringUtil::Equals(value, "DECIMAL")) { + return LogicalTypeId::DECIMAL; + } + if (StringUtil::Equals(value, "FLOAT")) { + return LogicalTypeId::FLOAT; + } + if (StringUtil::Equals(value, "DOUBLE")) { + return LogicalTypeId::DOUBLE; + } + if (StringUtil::Equals(value, "CHAR")) { + return LogicalTypeId::CHAR; + } + if (StringUtil::Equals(value, "VARCHAR")) { + return LogicalTypeId::VARCHAR; + } + if (StringUtil::Equals(value, "BLOB")) { + return LogicalTypeId::BLOB; + } + if (StringUtil::Equals(value, "INTERVAL")) { + return LogicalTypeId::INTERVAL; + } + if (StringUtil::Equals(value, "UTINYINT")) { + return LogicalTypeId::UTINYINT; + } + if (StringUtil::Equals(value, "USMALLINT")) { + return LogicalTypeId::USMALLINT; + } + if (StringUtil::Equals(value, "UINTEGER")) { + return LogicalTypeId::UINTEGER; + } + if (StringUtil::Equals(value, "UBIGINT")) { + return LogicalTypeId::UBIGINT; + } + if (StringUtil::Equals(value, "TIMESTAMP WITH TIME ZONE")) { + return LogicalTypeId::TIMESTAMP_TZ; + } + if (StringUtil::Equals(value, "TIME WITH TIME ZONE")) { + return LogicalTypeId::TIME_TZ; + } + if (StringUtil::Equals(value, "BIT")) { + return LogicalTypeId::BIT; + } + if (StringUtil::Equals(value, "HUGEINT")) { + return LogicalTypeId::HUGEINT; + } + if (StringUtil::Equals(value, "POINTER")) { + return LogicalTypeId::POINTER; + } + if (StringUtil::Equals(value, "VALIDITY")) { + return LogicalTypeId::VALIDITY; + } + if (StringUtil::Equals(value, "UUID")) { + return LogicalTypeId::UUID; + } + if (StringUtil::Equals(value, "STRUCT")) { + return LogicalTypeId::STRUCT; + } + if (StringUtil::Equals(value, "LIST")) { + return LogicalTypeId::LIST; + } + if (StringUtil::Equals(value, "MAP")) { + return LogicalTypeId::MAP; + } + if (StringUtil::Equals(value, "TABLE")) { + return LogicalTypeId::TABLE; + } + if (StringUtil::Equals(value, "ENUM")) { + return LogicalTypeId::ENUM; + } + if (StringUtil::Equals(value, "AGGREGATE_STATE")) { + return LogicalTypeId::AGGREGATE_STATE; + } + if (StringUtil::Equals(value, "LAMBDA")) { + return LogicalTypeId::LAMBDA; + } + if (StringUtil::Equals(value, "UNION")) { + return LogicalTypeId::UNION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(LookupResultType value) { + switch(value) { + case LookupResultType::LOOKUP_MISS: + return "LOOKUP_MISS"; + case LookupResultType::LOOKUP_HIT: + return "LOOKUP_HIT"; + case LookupResultType::LOOKUP_NULL: + return "LOOKUP_NULL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +LookupResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "LOOKUP_MISS")) { + return LookupResultType::LOOKUP_MISS; + } + if (StringUtil::Equals(value, "LOOKUP_HIT")) { + return LookupResultType::LOOKUP_HIT; + } + if (StringUtil::Equals(value, "LOOKUP_NULL")) { + return LookupResultType::LOOKUP_NULL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(MacroType value) { + switch(value) { + case MacroType::VOID_MACRO: + return "VOID_MACRO"; + case MacroType::TABLE_MACRO: + return "TABLE_MACRO"; + case MacroType::SCALAR_MACRO: + return "SCALAR_MACRO"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +MacroType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "VOID_MACRO")) { + return MacroType::VOID_MACRO; + } + if (StringUtil::Equals(value, "TABLE_MACRO")) { + return MacroType::TABLE_MACRO; + } + if (StringUtil::Equals(value, "SCALAR_MACRO")) { + return MacroType::SCALAR_MACRO; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(MapInvalidReason value) { + switch(value) { + case MapInvalidReason::VALID: + return "VALID"; + case MapInvalidReason::NULL_KEY_LIST: + return "NULL_KEY_LIST"; + case MapInvalidReason::NULL_KEY: + return "NULL_KEY"; + case MapInvalidReason::DUPLICATE_KEY: + return "DUPLICATE_KEY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +MapInvalidReason EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "VALID")) { + return MapInvalidReason::VALID; + } + if (StringUtil::Equals(value, "NULL_KEY_LIST")) { + return MapInvalidReason::NULL_KEY_LIST; + } + if (StringUtil::Equals(value, "NULL_KEY")) { + return MapInvalidReason::NULL_KEY; + } + if (StringUtil::Equals(value, "DUPLICATE_KEY")) { + return MapInvalidReason::DUPLICATE_KEY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(NType value) { + switch(value) { + case NType::PREFIX: + return "PREFIX"; + case NType::LEAF: + return "LEAF"; + case NType::NODE_4: + return "NODE_4"; + case NType::NODE_16: + return "NODE_16"; + case NType::NODE_48: + return "NODE_48"; + case NType::NODE_256: + return "NODE_256"; + case NType::LEAF_INLINED: + return "LEAF_INLINED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +NType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "PREFIX")) { + return NType::PREFIX; + } + if (StringUtil::Equals(value, "LEAF")) { + return NType::LEAF; + } + if (StringUtil::Equals(value, "NODE_4")) { + return NType::NODE_4; + } + if (StringUtil::Equals(value, "NODE_16")) { + return NType::NODE_16; + } + if (StringUtil::Equals(value, "NODE_48")) { + return NType::NODE_48; + } + if (StringUtil::Equals(value, "NODE_256")) { + return NType::NODE_256; + } + if (StringUtil::Equals(value, "LEAF_INLINED")) { + return NType::LEAF_INLINED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(NewLineIdentifier value) { + switch(value) { + case NewLineIdentifier::SINGLE: + return "SINGLE"; + case NewLineIdentifier::CARRY_ON: + return "CARRY_ON"; + case NewLineIdentifier::MIX: + return "MIX"; + case NewLineIdentifier::NOT_SET: + return "NOT_SET"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +NewLineIdentifier EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "SINGLE")) { + return NewLineIdentifier::SINGLE; + } + if (StringUtil::Equals(value, "CARRY_ON")) { + return NewLineIdentifier::CARRY_ON; + } + if (StringUtil::Equals(value, "MIX")) { + return NewLineIdentifier::MIX; + } + if (StringUtil::Equals(value, "NOT_SET")) { + return NewLineIdentifier::NOT_SET; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OnConflictAction value) { + switch(value) { + case OnConflictAction::THROW: + return "THROW"; + case OnConflictAction::NOTHING: + return "NOTHING"; + case OnConflictAction::UPDATE: + return "UPDATE"; + case OnConflictAction::REPLACE: + return "REPLACE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OnConflictAction EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "THROW")) { + return OnConflictAction::THROW; + } + if (StringUtil::Equals(value, "NOTHING")) { + return OnConflictAction::NOTHING; + } + if (StringUtil::Equals(value, "UPDATE")) { + return OnConflictAction::UPDATE; + } + if (StringUtil::Equals(value, "REPLACE")) { + return OnConflictAction::REPLACE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OnCreateConflict value) { + switch(value) { + case OnCreateConflict::ERROR_ON_CONFLICT: + return "ERROR_ON_CONFLICT"; + case OnCreateConflict::IGNORE_ON_CONFLICT: + return "IGNORE_ON_CONFLICT"; + case OnCreateConflict::REPLACE_ON_CONFLICT: + return "REPLACE_ON_CONFLICT"; + case OnCreateConflict::ALTER_ON_CONFLICT: + return "ALTER_ON_CONFLICT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OnCreateConflict EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ERROR_ON_CONFLICT")) { + return OnCreateConflict::ERROR_ON_CONFLICT; + } + if (StringUtil::Equals(value, "IGNORE_ON_CONFLICT")) { + return OnCreateConflict::IGNORE_ON_CONFLICT; + } + if (StringUtil::Equals(value, "REPLACE_ON_CONFLICT")) { + return OnCreateConflict::REPLACE_ON_CONFLICT; + } + if (StringUtil::Equals(value, "ALTER_ON_CONFLICT")) { + return OnCreateConflict::ALTER_ON_CONFLICT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OnEntryNotFound value) { + switch(value) { + case OnEntryNotFound::THROW_EXCEPTION: + return "THROW_EXCEPTION"; + case OnEntryNotFound::RETURN_NULL: + return "RETURN_NULL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OnEntryNotFound EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "THROW_EXCEPTION")) { + return OnEntryNotFound::THROW_EXCEPTION; + } + if (StringUtil::Equals(value, "RETURN_NULL")) { + return OnEntryNotFound::RETURN_NULL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OperatorFinalizeResultType value) { + switch(value) { + case OperatorFinalizeResultType::HAVE_MORE_OUTPUT: + return "HAVE_MORE_OUTPUT"; + case OperatorFinalizeResultType::FINISHED: + return "FINISHED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OperatorFinalizeResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "HAVE_MORE_OUTPUT")) { + return OperatorFinalizeResultType::HAVE_MORE_OUTPUT; + } + if (StringUtil::Equals(value, "FINISHED")) { + return OperatorFinalizeResultType::FINISHED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OperatorResultType value) { + switch(value) { + case OperatorResultType::NEED_MORE_INPUT: + return "NEED_MORE_INPUT"; + case OperatorResultType::HAVE_MORE_OUTPUT: + return "HAVE_MORE_OUTPUT"; + case OperatorResultType::FINISHED: + return "FINISHED"; + case OperatorResultType::BLOCKED: + return "BLOCKED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OperatorResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NEED_MORE_INPUT")) { + return OperatorResultType::NEED_MORE_INPUT; + } + if (StringUtil::Equals(value, "HAVE_MORE_OUTPUT")) { + return OperatorResultType::HAVE_MORE_OUTPUT; + } + if (StringUtil::Equals(value, "FINISHED")) { + return OperatorResultType::FINISHED; + } + if (StringUtil::Equals(value, "BLOCKED")) { + return OperatorResultType::BLOCKED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OptimizerType value) { + switch(value) { + case OptimizerType::INVALID: + return "INVALID"; + case OptimizerType::EXPRESSION_REWRITER: + return "EXPRESSION_REWRITER"; + case OptimizerType::FILTER_PULLUP: + return "FILTER_PULLUP"; + case OptimizerType::FILTER_PUSHDOWN: + return "FILTER_PUSHDOWN"; + case OptimizerType::REGEX_RANGE: + return "REGEX_RANGE"; + case OptimizerType::IN_CLAUSE: + return "IN_CLAUSE"; + case OptimizerType::JOIN_ORDER: + return "JOIN_ORDER"; + case OptimizerType::DELIMINATOR: + return "DELIMINATOR"; + case OptimizerType::UNNEST_REWRITER: + return "UNNEST_REWRITER"; + case OptimizerType::UNUSED_COLUMNS: + return "UNUSED_COLUMNS"; + case OptimizerType::STATISTICS_PROPAGATION: + return "STATISTICS_PROPAGATION"; + case OptimizerType::COMMON_SUBEXPRESSIONS: + return "COMMON_SUBEXPRESSIONS"; + case OptimizerType::COMMON_AGGREGATE: + return "COMMON_AGGREGATE"; + case OptimizerType::COLUMN_LIFETIME: + return "COLUMN_LIFETIME"; + case OptimizerType::TOP_N: + return "TOP_N"; + case OptimizerType::COMPRESSED_MATERIALIZATION: + return "COMPRESSED_MATERIALIZATION"; + case OptimizerType::DUPLICATE_GROUPS: + return "DUPLICATE_GROUPS"; + case OptimizerType::REORDER_FILTER: + return "REORDER_FILTER"; + case OptimizerType::EXTENSION: + return "EXTENSION"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OptimizerType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return OptimizerType::INVALID; + } + if (StringUtil::Equals(value, "EXPRESSION_REWRITER")) { + return OptimizerType::EXPRESSION_REWRITER; + } + if (StringUtil::Equals(value, "FILTER_PULLUP")) { + return OptimizerType::FILTER_PULLUP; + } + if (StringUtil::Equals(value, "FILTER_PUSHDOWN")) { + return OptimizerType::FILTER_PUSHDOWN; + } + if (StringUtil::Equals(value, "REGEX_RANGE")) { + return OptimizerType::REGEX_RANGE; + } + if (StringUtil::Equals(value, "IN_CLAUSE")) { + return OptimizerType::IN_CLAUSE; + } + if (StringUtil::Equals(value, "JOIN_ORDER")) { + return OptimizerType::JOIN_ORDER; + } + if (StringUtil::Equals(value, "DELIMINATOR")) { + return OptimizerType::DELIMINATOR; + } + if (StringUtil::Equals(value, "UNNEST_REWRITER")) { + return OptimizerType::UNNEST_REWRITER; + } + if (StringUtil::Equals(value, "UNUSED_COLUMNS")) { + return OptimizerType::UNUSED_COLUMNS; + } + if (StringUtil::Equals(value, "STATISTICS_PROPAGATION")) { + return OptimizerType::STATISTICS_PROPAGATION; + } + if (StringUtil::Equals(value, "COMMON_SUBEXPRESSIONS")) { + return OptimizerType::COMMON_SUBEXPRESSIONS; + } + if (StringUtil::Equals(value, "COMMON_AGGREGATE")) { + return OptimizerType::COMMON_AGGREGATE; + } + if (StringUtil::Equals(value, "COLUMN_LIFETIME")) { + return OptimizerType::COLUMN_LIFETIME; + } + if (StringUtil::Equals(value, "TOP_N")) { + return OptimizerType::TOP_N; + } + if (StringUtil::Equals(value, "COMPRESSED_MATERIALIZATION")) { + return OptimizerType::COMPRESSED_MATERIALIZATION; + } + if (StringUtil::Equals(value, "DUPLICATE_GROUPS")) { + return OptimizerType::DUPLICATE_GROUPS; + } + if (StringUtil::Equals(value, "REORDER_FILTER")) { + return OptimizerType::REORDER_FILTER; + } + if (StringUtil::Equals(value, "EXTENSION")) { + return OptimizerType::EXTENSION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OrderByNullType value) { + switch(value) { + case OrderByNullType::INVALID: + return "INVALID"; + case OrderByNullType::ORDER_DEFAULT: + return "ORDER_DEFAULT"; + case OrderByNullType::NULLS_FIRST: + return "NULLS_FIRST"; + case OrderByNullType::NULLS_LAST: + return "NULLS_LAST"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OrderByNullType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return OrderByNullType::INVALID; + } + if (StringUtil::Equals(value, "ORDER_DEFAULT") || StringUtil::Equals(value, "DEFAULT")) { + return OrderByNullType::ORDER_DEFAULT; + } + if (StringUtil::Equals(value, "NULLS_FIRST") || StringUtil::Equals(value, "NULLS FIRST")) { + return OrderByNullType::NULLS_FIRST; + } + if (StringUtil::Equals(value, "NULLS_LAST") || StringUtil::Equals(value, "NULLS LAST")) { + return OrderByNullType::NULLS_LAST; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OrderPreservationType value) { + switch(value) { + case OrderPreservationType::NO_ORDER: + return "NO_ORDER"; + case OrderPreservationType::INSERTION_ORDER: + return "INSERTION_ORDER"; + case OrderPreservationType::FIXED_ORDER: + return "FIXED_ORDER"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OrderPreservationType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NO_ORDER")) { + return OrderPreservationType::NO_ORDER; + } + if (StringUtil::Equals(value, "INSERTION_ORDER")) { + return OrderPreservationType::INSERTION_ORDER; + } + if (StringUtil::Equals(value, "FIXED_ORDER")) { + return OrderPreservationType::FIXED_ORDER; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OrderType value) { + switch(value) { + case OrderType::INVALID: + return "INVALID"; + case OrderType::ORDER_DEFAULT: + return "ORDER_DEFAULT"; + case OrderType::ASCENDING: + return "ASCENDING"; + case OrderType::DESCENDING: + return "DESCENDING"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OrderType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return OrderType::INVALID; + } + if (StringUtil::Equals(value, "ORDER_DEFAULT") || StringUtil::Equals(value, "DEFAULT")) { + return OrderType::ORDER_DEFAULT; + } + if (StringUtil::Equals(value, "ASCENDING") || StringUtil::Equals(value, "ASC")) { + return OrderType::ASCENDING; + } + if (StringUtil::Equals(value, "DESCENDING") || StringUtil::Equals(value, "DESC")) { + return OrderType::DESCENDING; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(OutputStream value) { + switch(value) { + case OutputStream::STREAM_STDOUT: + return "STREAM_STDOUT"; + case OutputStream::STREAM_STDERR: + return "STREAM_STDERR"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +OutputStream EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "STREAM_STDOUT")) { + return OutputStream::STREAM_STDOUT; + } + if (StringUtil::Equals(value, "STREAM_STDERR")) { + return OutputStream::STREAM_STDERR; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ParseInfoType value) { + switch(value) { + case ParseInfoType::ALTER_INFO: + return "ALTER_INFO"; + case ParseInfoType::ATTACH_INFO: + return "ATTACH_INFO"; + case ParseInfoType::COPY_INFO: + return "COPY_INFO"; + case ParseInfoType::CREATE_INFO: + return "CREATE_INFO"; + case ParseInfoType::DETACH_INFO: + return "DETACH_INFO"; + case ParseInfoType::DROP_INFO: + return "DROP_INFO"; + case ParseInfoType::BOUND_EXPORT_DATA: + return "BOUND_EXPORT_DATA"; + case ParseInfoType::LOAD_INFO: + return "LOAD_INFO"; + case ParseInfoType::PRAGMA_INFO: + return "PRAGMA_INFO"; + case ParseInfoType::SHOW_SELECT_INFO: + return "SHOW_SELECT_INFO"; + case ParseInfoType::TRANSACTION_INFO: + return "TRANSACTION_INFO"; + case ParseInfoType::VACUUM_INFO: + return "VACUUM_INFO"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ParseInfoType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ALTER_INFO")) { + return ParseInfoType::ALTER_INFO; + } + if (StringUtil::Equals(value, "ATTACH_INFO")) { + return ParseInfoType::ATTACH_INFO; + } + if (StringUtil::Equals(value, "COPY_INFO")) { + return ParseInfoType::COPY_INFO; + } + if (StringUtil::Equals(value, "CREATE_INFO")) { + return ParseInfoType::CREATE_INFO; + } + if (StringUtil::Equals(value, "DETACH_INFO")) { + return ParseInfoType::DETACH_INFO; + } + if (StringUtil::Equals(value, "DROP_INFO")) { + return ParseInfoType::DROP_INFO; + } + if (StringUtil::Equals(value, "BOUND_EXPORT_DATA")) { + return ParseInfoType::BOUND_EXPORT_DATA; + } + if (StringUtil::Equals(value, "LOAD_INFO")) { + return ParseInfoType::LOAD_INFO; + } + if (StringUtil::Equals(value, "PRAGMA_INFO")) { + return ParseInfoType::PRAGMA_INFO; + } + if (StringUtil::Equals(value, "SHOW_SELECT_INFO")) { + return ParseInfoType::SHOW_SELECT_INFO; + } + if (StringUtil::Equals(value, "TRANSACTION_INFO")) { + return ParseInfoType::TRANSACTION_INFO; + } + if (StringUtil::Equals(value, "VACUUM_INFO")) { + return ParseInfoType::VACUUM_INFO; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ParserExtensionResultType value) { + switch(value) { + case ParserExtensionResultType::PARSE_SUCCESSFUL: + return "PARSE_SUCCESSFUL"; + case ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR: + return "DISPLAY_ORIGINAL_ERROR"; + case ParserExtensionResultType::DISPLAY_EXTENSION_ERROR: + return "DISPLAY_EXTENSION_ERROR"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ParserExtensionResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "PARSE_SUCCESSFUL")) { + return ParserExtensionResultType::PARSE_SUCCESSFUL; + } + if (StringUtil::Equals(value, "DISPLAY_ORIGINAL_ERROR")) { + return ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR; + } + if (StringUtil::Equals(value, "DISPLAY_EXTENSION_ERROR")) { + return ParserExtensionResultType::DISPLAY_EXTENSION_ERROR; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ParserMode value) { + switch(value) { + case ParserMode::PARSING: + return "PARSING"; + case ParserMode::SNIFFING_DATATYPES: + return "SNIFFING_DATATYPES"; + case ParserMode::PARSING_HEADER: + return "PARSING_HEADER"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ParserMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "PARSING")) { + return ParserMode::PARSING; + } + if (StringUtil::Equals(value, "SNIFFING_DATATYPES")) { + return ParserMode::SNIFFING_DATATYPES; + } + if (StringUtil::Equals(value, "PARSING_HEADER")) { + return ParserMode::PARSING_HEADER; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PartitionSortStage value) { + switch(value) { + case PartitionSortStage::INIT: + return "INIT"; + case PartitionSortStage::SCAN: + return "SCAN"; + case PartitionSortStage::PREPARE: + return "PREPARE"; + case PartitionSortStage::MERGE: + return "MERGE"; + case PartitionSortStage::SORTED: + return "SORTED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PartitionSortStage EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INIT")) { + return PartitionSortStage::INIT; + } + if (StringUtil::Equals(value, "SCAN")) { + return PartitionSortStage::SCAN; + } + if (StringUtil::Equals(value, "PREPARE")) { + return PartitionSortStage::PREPARE; + } + if (StringUtil::Equals(value, "MERGE")) { + return PartitionSortStage::MERGE; + } + if (StringUtil::Equals(value, "SORTED")) { + return PartitionSortStage::SORTED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PartitionedColumnDataType value) { + switch(value) { + case PartitionedColumnDataType::INVALID: + return "INVALID"; + case PartitionedColumnDataType::RADIX: + return "RADIX"; + case PartitionedColumnDataType::HIVE: + return "HIVE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PartitionedColumnDataType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return PartitionedColumnDataType::INVALID; + } + if (StringUtil::Equals(value, "RADIX")) { + return PartitionedColumnDataType::RADIX; + } + if (StringUtil::Equals(value, "HIVE")) { + return PartitionedColumnDataType::HIVE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PartitionedTupleDataType value) { + switch(value) { + case PartitionedTupleDataType::INVALID: + return "INVALID"; + case PartitionedTupleDataType::RADIX: + return "RADIX"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PartitionedTupleDataType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return PartitionedTupleDataType::INVALID; + } + if (StringUtil::Equals(value, "RADIX")) { + return PartitionedTupleDataType::RADIX; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PendingExecutionResult value) { + switch(value) { + case PendingExecutionResult::RESULT_READY: + return "RESULT_READY"; + case PendingExecutionResult::RESULT_NOT_READY: + return "RESULT_NOT_READY"; + case PendingExecutionResult::EXECUTION_ERROR: + return "EXECUTION_ERROR"; + case PendingExecutionResult::NO_TASKS_AVAILABLE: + return "NO_TASKS_AVAILABLE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PendingExecutionResult EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "RESULT_READY")) { + return PendingExecutionResult::RESULT_READY; + } + if (StringUtil::Equals(value, "RESULT_NOT_READY")) { + return PendingExecutionResult::RESULT_NOT_READY; + } + if (StringUtil::Equals(value, "EXECUTION_ERROR")) { + return PendingExecutionResult::EXECUTION_ERROR; + } + if (StringUtil::Equals(value, "NO_TASKS_AVAILABLE")) { + return PendingExecutionResult::NO_TASKS_AVAILABLE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PhysicalOperatorType value) { + switch(value) { + case PhysicalOperatorType::INVALID: + return "INVALID"; + case PhysicalOperatorType::ORDER_BY: + return "ORDER_BY"; + case PhysicalOperatorType::LIMIT: + return "LIMIT"; + case PhysicalOperatorType::STREAMING_LIMIT: + return "STREAMING_LIMIT"; + case PhysicalOperatorType::LIMIT_PERCENT: + return "LIMIT_PERCENT"; + case PhysicalOperatorType::TOP_N: + return "TOP_N"; + case PhysicalOperatorType::WINDOW: + return "WINDOW"; + case PhysicalOperatorType::UNNEST: + return "UNNEST"; + case PhysicalOperatorType::UNGROUPED_AGGREGATE: + return "UNGROUPED_AGGREGATE"; + case PhysicalOperatorType::HASH_GROUP_BY: + return "HASH_GROUP_BY"; + case PhysicalOperatorType::PERFECT_HASH_GROUP_BY: + return "PERFECT_HASH_GROUP_BY"; + case PhysicalOperatorType::FILTER: + return "FILTER"; + case PhysicalOperatorType::PROJECTION: + return "PROJECTION"; + case PhysicalOperatorType::COPY_TO_FILE: + return "COPY_TO_FILE"; + case PhysicalOperatorType::BATCH_COPY_TO_FILE: + return "BATCH_COPY_TO_FILE"; + case PhysicalOperatorType::FIXED_BATCH_COPY_TO_FILE: + return "FIXED_BATCH_COPY_TO_FILE"; + case PhysicalOperatorType::RESERVOIR_SAMPLE: + return "RESERVOIR_SAMPLE"; + case PhysicalOperatorType::STREAMING_SAMPLE: + return "STREAMING_SAMPLE"; + case PhysicalOperatorType::STREAMING_WINDOW: + return "STREAMING_WINDOW"; + case PhysicalOperatorType::PIVOT: + return "PIVOT"; + case PhysicalOperatorType::TABLE_SCAN: + return "TABLE_SCAN"; + case PhysicalOperatorType::DUMMY_SCAN: + return "DUMMY_SCAN"; + case PhysicalOperatorType::COLUMN_DATA_SCAN: + return "COLUMN_DATA_SCAN"; + case PhysicalOperatorType::CHUNK_SCAN: + return "CHUNK_SCAN"; + case PhysicalOperatorType::RECURSIVE_CTE_SCAN: + return "RECURSIVE_CTE_SCAN"; + case PhysicalOperatorType::CTE_SCAN: + return "CTE_SCAN"; + case PhysicalOperatorType::DELIM_SCAN: + return "DELIM_SCAN"; + case PhysicalOperatorType::EXPRESSION_SCAN: + return "EXPRESSION_SCAN"; + case PhysicalOperatorType::POSITIONAL_SCAN: + return "POSITIONAL_SCAN"; + case PhysicalOperatorType::BLOCKWISE_NL_JOIN: + return "BLOCKWISE_NL_JOIN"; + case PhysicalOperatorType::NESTED_LOOP_JOIN: + return "NESTED_LOOP_JOIN"; + case PhysicalOperatorType::HASH_JOIN: + return "HASH_JOIN"; + case PhysicalOperatorType::CROSS_PRODUCT: + return "CROSS_PRODUCT"; + case PhysicalOperatorType::PIECEWISE_MERGE_JOIN: + return "PIECEWISE_MERGE_JOIN"; + case PhysicalOperatorType::IE_JOIN: + return "IE_JOIN"; + case PhysicalOperatorType::DELIM_JOIN: + return "DELIM_JOIN"; + case PhysicalOperatorType::INDEX_JOIN: + return "INDEX_JOIN"; + case PhysicalOperatorType::POSITIONAL_JOIN: + return "POSITIONAL_JOIN"; + case PhysicalOperatorType::ASOF_JOIN: + return "ASOF_JOIN"; + case PhysicalOperatorType::UNION: + return "UNION"; + case PhysicalOperatorType::RECURSIVE_CTE: + return "RECURSIVE_CTE"; + case PhysicalOperatorType::CTE: + return "CTE"; + case PhysicalOperatorType::INSERT: + return "INSERT"; + case PhysicalOperatorType::BATCH_INSERT: + return "BATCH_INSERT"; + case PhysicalOperatorType::DELETE_OPERATOR: + return "DELETE_OPERATOR"; + case PhysicalOperatorType::UPDATE: + return "UPDATE"; + case PhysicalOperatorType::CREATE_TABLE: + return "CREATE_TABLE"; + case PhysicalOperatorType::CREATE_TABLE_AS: + return "CREATE_TABLE_AS"; + case PhysicalOperatorType::BATCH_CREATE_TABLE_AS: + return "BATCH_CREATE_TABLE_AS"; + case PhysicalOperatorType::CREATE_INDEX: + return "CREATE_INDEX"; + case PhysicalOperatorType::ALTER: + return "ALTER"; + case PhysicalOperatorType::CREATE_SEQUENCE: + return "CREATE_SEQUENCE"; + case PhysicalOperatorType::CREATE_VIEW: + return "CREATE_VIEW"; + case PhysicalOperatorType::CREATE_SCHEMA: + return "CREATE_SCHEMA"; + case PhysicalOperatorType::CREATE_MACRO: + return "CREATE_MACRO"; + case PhysicalOperatorType::DROP: + return "DROP"; + case PhysicalOperatorType::PRAGMA: + return "PRAGMA"; + case PhysicalOperatorType::TRANSACTION: + return "TRANSACTION"; + case PhysicalOperatorType::CREATE_TYPE: + return "CREATE_TYPE"; + case PhysicalOperatorType::ATTACH: + return "ATTACH"; + case PhysicalOperatorType::DETACH: + return "DETACH"; + case PhysicalOperatorType::EXPLAIN: + return "EXPLAIN"; + case PhysicalOperatorType::EXPLAIN_ANALYZE: + return "EXPLAIN_ANALYZE"; + case PhysicalOperatorType::EMPTY_RESULT: + return "EMPTY_RESULT"; + case PhysicalOperatorType::EXECUTE: + return "EXECUTE"; + case PhysicalOperatorType::PREPARE: + return "PREPARE"; + case PhysicalOperatorType::VACUUM: + return "VACUUM"; + case PhysicalOperatorType::EXPORT: + return "EXPORT"; + case PhysicalOperatorType::SET: + return "SET"; + case PhysicalOperatorType::LOAD: + return "LOAD"; + case PhysicalOperatorType::INOUT_FUNCTION: + return "INOUT_FUNCTION"; + case PhysicalOperatorType::RESULT_COLLECTOR: + return "RESULT_COLLECTOR"; + case PhysicalOperatorType::RESET: + return "RESET"; + case PhysicalOperatorType::EXTENSION: + return "EXTENSION"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PhysicalOperatorType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return PhysicalOperatorType::INVALID; + } + if (StringUtil::Equals(value, "ORDER_BY")) { + return PhysicalOperatorType::ORDER_BY; + } + if (StringUtil::Equals(value, "LIMIT")) { + return PhysicalOperatorType::LIMIT; + } + if (StringUtil::Equals(value, "STREAMING_LIMIT")) { + return PhysicalOperatorType::STREAMING_LIMIT; + } + if (StringUtil::Equals(value, "LIMIT_PERCENT")) { + return PhysicalOperatorType::LIMIT_PERCENT; + } + if (StringUtil::Equals(value, "TOP_N")) { + return PhysicalOperatorType::TOP_N; + } + if (StringUtil::Equals(value, "WINDOW")) { + return PhysicalOperatorType::WINDOW; + } + if (StringUtil::Equals(value, "UNNEST")) { + return PhysicalOperatorType::UNNEST; + } + if (StringUtil::Equals(value, "UNGROUPED_AGGREGATE")) { + return PhysicalOperatorType::UNGROUPED_AGGREGATE; + } + if (StringUtil::Equals(value, "HASH_GROUP_BY")) { + return PhysicalOperatorType::HASH_GROUP_BY; + } + if (StringUtil::Equals(value, "PERFECT_HASH_GROUP_BY")) { + return PhysicalOperatorType::PERFECT_HASH_GROUP_BY; + } + if (StringUtil::Equals(value, "FILTER")) { + return PhysicalOperatorType::FILTER; + } + if (StringUtil::Equals(value, "PROJECTION")) { + return PhysicalOperatorType::PROJECTION; + } + if (StringUtil::Equals(value, "COPY_TO_FILE")) { + return PhysicalOperatorType::COPY_TO_FILE; + } + if (StringUtil::Equals(value, "BATCH_COPY_TO_FILE")) { + return PhysicalOperatorType::BATCH_COPY_TO_FILE; + } + if (StringUtil::Equals(value, "FIXED_BATCH_COPY_TO_FILE")) { + return PhysicalOperatorType::FIXED_BATCH_COPY_TO_FILE; + } + if (StringUtil::Equals(value, "RESERVOIR_SAMPLE")) { + return PhysicalOperatorType::RESERVOIR_SAMPLE; + } + if (StringUtil::Equals(value, "STREAMING_SAMPLE")) { + return PhysicalOperatorType::STREAMING_SAMPLE; + } + if (StringUtil::Equals(value, "STREAMING_WINDOW")) { + return PhysicalOperatorType::STREAMING_WINDOW; + } + if (StringUtil::Equals(value, "PIVOT")) { + return PhysicalOperatorType::PIVOT; + } + if (StringUtil::Equals(value, "TABLE_SCAN")) { + return PhysicalOperatorType::TABLE_SCAN; + } + if (StringUtil::Equals(value, "DUMMY_SCAN")) { + return PhysicalOperatorType::DUMMY_SCAN; + } + if (StringUtil::Equals(value, "COLUMN_DATA_SCAN")) { + return PhysicalOperatorType::COLUMN_DATA_SCAN; + } + if (StringUtil::Equals(value, "CHUNK_SCAN")) { + return PhysicalOperatorType::CHUNK_SCAN; + } + if (StringUtil::Equals(value, "RECURSIVE_CTE_SCAN")) { + return PhysicalOperatorType::RECURSIVE_CTE_SCAN; + } + if (StringUtil::Equals(value, "CTE_SCAN")) { + return PhysicalOperatorType::CTE_SCAN; + } + if (StringUtil::Equals(value, "DELIM_SCAN")) { + return PhysicalOperatorType::DELIM_SCAN; + } + if (StringUtil::Equals(value, "EXPRESSION_SCAN")) { + return PhysicalOperatorType::EXPRESSION_SCAN; + } + if (StringUtil::Equals(value, "POSITIONAL_SCAN")) { + return PhysicalOperatorType::POSITIONAL_SCAN; + } + if (StringUtil::Equals(value, "BLOCKWISE_NL_JOIN")) { + return PhysicalOperatorType::BLOCKWISE_NL_JOIN; + } + if (StringUtil::Equals(value, "NESTED_LOOP_JOIN")) { + return PhysicalOperatorType::NESTED_LOOP_JOIN; + } + if (StringUtil::Equals(value, "HASH_JOIN")) { + return PhysicalOperatorType::HASH_JOIN; + } + if (StringUtil::Equals(value, "CROSS_PRODUCT")) { + return PhysicalOperatorType::CROSS_PRODUCT; + } + if (StringUtil::Equals(value, "PIECEWISE_MERGE_JOIN")) { + return PhysicalOperatorType::PIECEWISE_MERGE_JOIN; + } + if (StringUtil::Equals(value, "IE_JOIN")) { + return PhysicalOperatorType::IE_JOIN; + } + if (StringUtil::Equals(value, "DELIM_JOIN")) { + return PhysicalOperatorType::DELIM_JOIN; + } + if (StringUtil::Equals(value, "INDEX_JOIN")) { + return PhysicalOperatorType::INDEX_JOIN; + } + if (StringUtil::Equals(value, "POSITIONAL_JOIN")) { + return PhysicalOperatorType::POSITIONAL_JOIN; + } + if (StringUtil::Equals(value, "ASOF_JOIN")) { + return PhysicalOperatorType::ASOF_JOIN; + } + if (StringUtil::Equals(value, "UNION")) { + return PhysicalOperatorType::UNION; + } + if (StringUtil::Equals(value, "RECURSIVE_CTE")) { + return PhysicalOperatorType::RECURSIVE_CTE; + } + if (StringUtil::Equals(value, "CTE")) { + return PhysicalOperatorType::CTE; + } + if (StringUtil::Equals(value, "INSERT")) { + return PhysicalOperatorType::INSERT; + } + if (StringUtil::Equals(value, "BATCH_INSERT")) { + return PhysicalOperatorType::BATCH_INSERT; + } + if (StringUtil::Equals(value, "DELETE_OPERATOR")) { + return PhysicalOperatorType::DELETE_OPERATOR; + } + if (StringUtil::Equals(value, "UPDATE")) { + return PhysicalOperatorType::UPDATE; + } + if (StringUtil::Equals(value, "CREATE_TABLE")) { + return PhysicalOperatorType::CREATE_TABLE; + } + if (StringUtil::Equals(value, "CREATE_TABLE_AS")) { + return PhysicalOperatorType::CREATE_TABLE_AS; + } + if (StringUtil::Equals(value, "BATCH_CREATE_TABLE_AS")) { + return PhysicalOperatorType::BATCH_CREATE_TABLE_AS; + } + if (StringUtil::Equals(value, "CREATE_INDEX")) { + return PhysicalOperatorType::CREATE_INDEX; + } + if (StringUtil::Equals(value, "ALTER")) { + return PhysicalOperatorType::ALTER; + } + if (StringUtil::Equals(value, "CREATE_SEQUENCE")) { + return PhysicalOperatorType::CREATE_SEQUENCE; + } + if (StringUtil::Equals(value, "CREATE_VIEW")) { + return PhysicalOperatorType::CREATE_VIEW; + } + if (StringUtil::Equals(value, "CREATE_SCHEMA")) { + return PhysicalOperatorType::CREATE_SCHEMA; + } + if (StringUtil::Equals(value, "CREATE_MACRO")) { + return PhysicalOperatorType::CREATE_MACRO; + } + if (StringUtil::Equals(value, "DROP")) { + return PhysicalOperatorType::DROP; + } + if (StringUtil::Equals(value, "PRAGMA")) { + return PhysicalOperatorType::PRAGMA; + } + if (StringUtil::Equals(value, "TRANSACTION")) { + return PhysicalOperatorType::TRANSACTION; + } + if (StringUtil::Equals(value, "CREATE_TYPE")) { + return PhysicalOperatorType::CREATE_TYPE; + } + if (StringUtil::Equals(value, "ATTACH")) { + return PhysicalOperatorType::ATTACH; + } + if (StringUtil::Equals(value, "DETACH")) { + return PhysicalOperatorType::DETACH; + } + if (StringUtil::Equals(value, "EXPLAIN")) { + return PhysicalOperatorType::EXPLAIN; + } + if (StringUtil::Equals(value, "EXPLAIN_ANALYZE")) { + return PhysicalOperatorType::EXPLAIN_ANALYZE; + } + if (StringUtil::Equals(value, "EMPTY_RESULT")) { + return PhysicalOperatorType::EMPTY_RESULT; + } + if (StringUtil::Equals(value, "EXECUTE")) { + return PhysicalOperatorType::EXECUTE; + } + if (StringUtil::Equals(value, "PREPARE")) { + return PhysicalOperatorType::PREPARE; + } + if (StringUtil::Equals(value, "VACUUM")) { + return PhysicalOperatorType::VACUUM; + } + if (StringUtil::Equals(value, "EXPORT")) { + return PhysicalOperatorType::EXPORT; + } + if (StringUtil::Equals(value, "SET")) { + return PhysicalOperatorType::SET; + } + if (StringUtil::Equals(value, "LOAD")) { + return PhysicalOperatorType::LOAD; + } + if (StringUtil::Equals(value, "INOUT_FUNCTION")) { + return PhysicalOperatorType::INOUT_FUNCTION; + } + if (StringUtil::Equals(value, "RESULT_COLLECTOR")) { + return PhysicalOperatorType::RESULT_COLLECTOR; + } + if (StringUtil::Equals(value, "RESET")) { + return PhysicalOperatorType::RESET; + } + if (StringUtil::Equals(value, "EXTENSION")) { + return PhysicalOperatorType::EXTENSION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PhysicalType value) { + switch(value) { + case PhysicalType::BOOL: + return "BOOL"; + case PhysicalType::UINT8: + return "UINT8"; + case PhysicalType::INT8: + return "INT8"; + case PhysicalType::UINT16: + return "UINT16"; + case PhysicalType::INT16: + return "INT16"; + case PhysicalType::UINT32: + return "UINT32"; + case PhysicalType::INT32: + return "INT32"; + case PhysicalType::UINT64: + return "UINT64"; + case PhysicalType::INT64: + return "INT64"; + case PhysicalType::FLOAT: + return "FLOAT"; + case PhysicalType::DOUBLE: + return "DOUBLE"; + case PhysicalType::INTERVAL: + return "INTERVAL"; + case PhysicalType::LIST: + return "LIST"; + case PhysicalType::STRUCT: + return "STRUCT"; + case PhysicalType::VARCHAR: + return "VARCHAR"; + case PhysicalType::INT128: + return "INT128"; + case PhysicalType::UNKNOWN: + return "UNKNOWN"; + case PhysicalType::BIT: + return "BIT"; + case PhysicalType::INVALID: + return "INVALID"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PhysicalType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "BOOL")) { + return PhysicalType::BOOL; + } + if (StringUtil::Equals(value, "UINT8")) { + return PhysicalType::UINT8; + } + if (StringUtil::Equals(value, "INT8")) { + return PhysicalType::INT8; + } + if (StringUtil::Equals(value, "UINT16")) { + return PhysicalType::UINT16; + } + if (StringUtil::Equals(value, "INT16")) { + return PhysicalType::INT16; + } + if (StringUtil::Equals(value, "UINT32")) { + return PhysicalType::UINT32; + } + if (StringUtil::Equals(value, "INT32")) { + return PhysicalType::INT32; + } + if (StringUtil::Equals(value, "UINT64")) { + return PhysicalType::UINT64; + } + if (StringUtil::Equals(value, "INT64")) { + return PhysicalType::INT64; + } + if (StringUtil::Equals(value, "FLOAT")) { + return PhysicalType::FLOAT; + } + if (StringUtil::Equals(value, "DOUBLE")) { + return PhysicalType::DOUBLE; + } + if (StringUtil::Equals(value, "INTERVAL")) { + return PhysicalType::INTERVAL; + } + if (StringUtil::Equals(value, "LIST")) { + return PhysicalType::LIST; + } + if (StringUtil::Equals(value, "STRUCT")) { + return PhysicalType::STRUCT; + } + if (StringUtil::Equals(value, "VARCHAR")) { + return PhysicalType::VARCHAR; + } + if (StringUtil::Equals(value, "INT128")) { + return PhysicalType::INT128; + } + if (StringUtil::Equals(value, "UNKNOWN")) { + return PhysicalType::UNKNOWN; + } + if (StringUtil::Equals(value, "BIT")) { + return PhysicalType::BIT; + } + if (StringUtil::Equals(value, "INVALID")) { + return PhysicalType::INVALID; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PragmaType value) { + switch(value) { + case PragmaType::PRAGMA_STATEMENT: + return "PRAGMA_STATEMENT"; + case PragmaType::PRAGMA_CALL: + return "PRAGMA_CALL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PragmaType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "PRAGMA_STATEMENT")) { + return PragmaType::PRAGMA_STATEMENT; + } + if (StringUtil::Equals(value, "PRAGMA_CALL")) { + return PragmaType::PRAGMA_CALL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(PreparedParamType value) { + switch(value) { + case PreparedParamType::AUTO_INCREMENT: + return "AUTO_INCREMENT"; + case PreparedParamType::POSITIONAL: + return "POSITIONAL"; + case PreparedParamType::NAMED: + return "NAMED"; + case PreparedParamType::INVALID: + return "INVALID"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +PreparedParamType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "AUTO_INCREMENT")) { + return PreparedParamType::AUTO_INCREMENT; + } + if (StringUtil::Equals(value, "POSITIONAL")) { + return PreparedParamType::POSITIONAL; + } + if (StringUtil::Equals(value, "NAMED")) { + return PreparedParamType::NAMED; + } + if (StringUtil::Equals(value, "INVALID")) { + return PreparedParamType::INVALID; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ProfilerPrintFormat value) { + switch(value) { + case ProfilerPrintFormat::QUERY_TREE: + return "QUERY_TREE"; + case ProfilerPrintFormat::JSON: + return "JSON"; + case ProfilerPrintFormat::QUERY_TREE_OPTIMIZER: + return "QUERY_TREE_OPTIMIZER"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ProfilerPrintFormat EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "QUERY_TREE")) { + return ProfilerPrintFormat::QUERY_TREE; + } + if (StringUtil::Equals(value, "JSON")) { + return ProfilerPrintFormat::JSON; + } + if (StringUtil::Equals(value, "QUERY_TREE_OPTIMIZER")) { + return ProfilerPrintFormat::QUERY_TREE_OPTIMIZER; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(QueryNodeType value) { + switch(value) { + case QueryNodeType::SELECT_NODE: + return "SELECT_NODE"; + case QueryNodeType::SET_OPERATION_NODE: + return "SET_OPERATION_NODE"; + case QueryNodeType::BOUND_SUBQUERY_NODE: + return "BOUND_SUBQUERY_NODE"; + case QueryNodeType::RECURSIVE_CTE_NODE: + return "RECURSIVE_CTE_NODE"; + case QueryNodeType::CTE_NODE: + return "CTE_NODE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +QueryNodeType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "SELECT_NODE")) { + return QueryNodeType::SELECT_NODE; + } + if (StringUtil::Equals(value, "SET_OPERATION_NODE")) { + return QueryNodeType::SET_OPERATION_NODE; + } + if (StringUtil::Equals(value, "BOUND_SUBQUERY_NODE")) { + return QueryNodeType::BOUND_SUBQUERY_NODE; + } + if (StringUtil::Equals(value, "RECURSIVE_CTE_NODE")) { + return QueryNodeType::RECURSIVE_CTE_NODE; + } + if (StringUtil::Equals(value, "CTE_NODE")) { + return QueryNodeType::CTE_NODE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(QueryResultType value) { + switch(value) { + case QueryResultType::MATERIALIZED_RESULT: + return "MATERIALIZED_RESULT"; + case QueryResultType::STREAM_RESULT: + return "STREAM_RESULT"; + case QueryResultType::PENDING_RESULT: + return "PENDING_RESULT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +QueryResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "MATERIALIZED_RESULT")) { + return QueryResultType::MATERIALIZED_RESULT; + } + if (StringUtil::Equals(value, "STREAM_RESULT")) { + return QueryResultType::STREAM_RESULT; + } + if (StringUtil::Equals(value, "PENDING_RESULT")) { + return QueryResultType::PENDING_RESULT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(QuoteRule value) { + switch(value) { + case QuoteRule::QUOTES_RFC: + return "QUOTES_RFC"; + case QuoteRule::QUOTES_OTHER: + return "QUOTES_OTHER"; + case QuoteRule::NO_QUOTES: + return "NO_QUOTES"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +QuoteRule EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "QUOTES_RFC")) { + return QuoteRule::QUOTES_RFC; + } + if (StringUtil::Equals(value, "QUOTES_OTHER")) { + return QuoteRule::QUOTES_OTHER; + } + if (StringUtil::Equals(value, "NO_QUOTES")) { + return QuoteRule::NO_QUOTES; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(RelationType value) { + switch(value) { + case RelationType::INVALID_RELATION: + return "INVALID_RELATION"; + case RelationType::TABLE_RELATION: + return "TABLE_RELATION"; + case RelationType::PROJECTION_RELATION: + return "PROJECTION_RELATION"; + case RelationType::FILTER_RELATION: + return "FILTER_RELATION"; + case RelationType::EXPLAIN_RELATION: + return "EXPLAIN_RELATION"; + case RelationType::CROSS_PRODUCT_RELATION: + return "CROSS_PRODUCT_RELATION"; + case RelationType::JOIN_RELATION: + return "JOIN_RELATION"; + case RelationType::AGGREGATE_RELATION: + return "AGGREGATE_RELATION"; + case RelationType::SET_OPERATION_RELATION: + return "SET_OPERATION_RELATION"; + case RelationType::DISTINCT_RELATION: + return "DISTINCT_RELATION"; + case RelationType::LIMIT_RELATION: + return "LIMIT_RELATION"; + case RelationType::ORDER_RELATION: + return "ORDER_RELATION"; + case RelationType::CREATE_VIEW_RELATION: + return "CREATE_VIEW_RELATION"; + case RelationType::CREATE_TABLE_RELATION: + return "CREATE_TABLE_RELATION"; + case RelationType::INSERT_RELATION: + return "INSERT_RELATION"; + case RelationType::VALUE_LIST_RELATION: + return "VALUE_LIST_RELATION"; + case RelationType::DELETE_RELATION: + return "DELETE_RELATION"; + case RelationType::UPDATE_RELATION: + return "UPDATE_RELATION"; + case RelationType::WRITE_CSV_RELATION: + return "WRITE_CSV_RELATION"; + case RelationType::WRITE_PARQUET_RELATION: + return "WRITE_PARQUET_RELATION"; + case RelationType::READ_CSV_RELATION: + return "READ_CSV_RELATION"; + case RelationType::SUBQUERY_RELATION: + return "SUBQUERY_RELATION"; + case RelationType::TABLE_FUNCTION_RELATION: + return "TABLE_FUNCTION_RELATION"; + case RelationType::VIEW_RELATION: + return "VIEW_RELATION"; + case RelationType::QUERY_RELATION: + return "QUERY_RELATION"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +RelationType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID_RELATION")) { + return RelationType::INVALID_RELATION; + } + if (StringUtil::Equals(value, "TABLE_RELATION")) { + return RelationType::TABLE_RELATION; + } + if (StringUtil::Equals(value, "PROJECTION_RELATION")) { + return RelationType::PROJECTION_RELATION; + } + if (StringUtil::Equals(value, "FILTER_RELATION")) { + return RelationType::FILTER_RELATION; + } + if (StringUtil::Equals(value, "EXPLAIN_RELATION")) { + return RelationType::EXPLAIN_RELATION; + } + if (StringUtil::Equals(value, "CROSS_PRODUCT_RELATION")) { + return RelationType::CROSS_PRODUCT_RELATION; + } + if (StringUtil::Equals(value, "JOIN_RELATION")) { + return RelationType::JOIN_RELATION; + } + if (StringUtil::Equals(value, "AGGREGATE_RELATION")) { + return RelationType::AGGREGATE_RELATION; + } + if (StringUtil::Equals(value, "SET_OPERATION_RELATION")) { + return RelationType::SET_OPERATION_RELATION; + } + if (StringUtil::Equals(value, "DISTINCT_RELATION")) { + return RelationType::DISTINCT_RELATION; + } + if (StringUtil::Equals(value, "LIMIT_RELATION")) { + return RelationType::LIMIT_RELATION; + } + if (StringUtil::Equals(value, "ORDER_RELATION")) { + return RelationType::ORDER_RELATION; + } + if (StringUtil::Equals(value, "CREATE_VIEW_RELATION")) { + return RelationType::CREATE_VIEW_RELATION; + } + if (StringUtil::Equals(value, "CREATE_TABLE_RELATION")) { + return RelationType::CREATE_TABLE_RELATION; + } + if (StringUtil::Equals(value, "INSERT_RELATION")) { + return RelationType::INSERT_RELATION; + } + if (StringUtil::Equals(value, "VALUE_LIST_RELATION")) { + return RelationType::VALUE_LIST_RELATION; + } + if (StringUtil::Equals(value, "DELETE_RELATION")) { + return RelationType::DELETE_RELATION; + } + if (StringUtil::Equals(value, "UPDATE_RELATION")) { + return RelationType::UPDATE_RELATION; + } + if (StringUtil::Equals(value, "WRITE_CSV_RELATION")) { + return RelationType::WRITE_CSV_RELATION; + } + if (StringUtil::Equals(value, "WRITE_PARQUET_RELATION")) { + return RelationType::WRITE_PARQUET_RELATION; + } + if (StringUtil::Equals(value, "READ_CSV_RELATION")) { + return RelationType::READ_CSV_RELATION; + } + if (StringUtil::Equals(value, "SUBQUERY_RELATION")) { + return RelationType::SUBQUERY_RELATION; + } + if (StringUtil::Equals(value, "TABLE_FUNCTION_RELATION")) { + return RelationType::TABLE_FUNCTION_RELATION; + } + if (StringUtil::Equals(value, "VIEW_RELATION")) { + return RelationType::VIEW_RELATION; + } + if (StringUtil::Equals(value, "QUERY_RELATION")) { + return RelationType::QUERY_RELATION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(RenderMode value) { + switch(value) { + case RenderMode::ROWS: + return "ROWS"; + case RenderMode::COLUMNS: + return "COLUMNS"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +RenderMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ROWS")) { + return RenderMode::ROWS; + } + if (StringUtil::Equals(value, "COLUMNS")) { + return RenderMode::COLUMNS; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ResultModifierType value) { + switch(value) { + case ResultModifierType::LIMIT_MODIFIER: + return "LIMIT_MODIFIER"; + case ResultModifierType::ORDER_MODIFIER: + return "ORDER_MODIFIER"; + case ResultModifierType::DISTINCT_MODIFIER: + return "DISTINCT_MODIFIER"; + case ResultModifierType::LIMIT_PERCENT_MODIFIER: + return "LIMIT_PERCENT_MODIFIER"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ResultModifierType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "LIMIT_MODIFIER")) { + return ResultModifierType::LIMIT_MODIFIER; + } + if (StringUtil::Equals(value, "ORDER_MODIFIER")) { + return ResultModifierType::ORDER_MODIFIER; + } + if (StringUtil::Equals(value, "DISTINCT_MODIFIER")) { + return ResultModifierType::DISTINCT_MODIFIER; + } + if (StringUtil::Equals(value, "LIMIT_PERCENT_MODIFIER")) { + return ResultModifierType::LIMIT_PERCENT_MODIFIER; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SampleMethod value) { + switch(value) { + case SampleMethod::SYSTEM_SAMPLE: + return "System"; + case SampleMethod::BERNOULLI_SAMPLE: + return "Bernoulli"; + case SampleMethod::RESERVOIR_SAMPLE: + return "Reservoir"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SampleMethod EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "System")) { + return SampleMethod::SYSTEM_SAMPLE; + } + if (StringUtil::Equals(value, "Bernoulli")) { + return SampleMethod::BERNOULLI_SAMPLE; + } + if (StringUtil::Equals(value, "Reservoir")) { + return SampleMethod::RESERVOIR_SAMPLE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SequenceInfo value) { + switch(value) { + case SequenceInfo::SEQ_START: + return "SEQ_START"; + case SequenceInfo::SEQ_INC: + return "SEQ_INC"; + case SequenceInfo::SEQ_MIN: + return "SEQ_MIN"; + case SequenceInfo::SEQ_MAX: + return "SEQ_MAX"; + case SequenceInfo::SEQ_CYCLE: + return "SEQ_CYCLE"; + case SequenceInfo::SEQ_OWN: + return "SEQ_OWN"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SequenceInfo EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "SEQ_START")) { + return SequenceInfo::SEQ_START; + } + if (StringUtil::Equals(value, "SEQ_INC")) { + return SequenceInfo::SEQ_INC; + } + if (StringUtil::Equals(value, "SEQ_MIN")) { + return SequenceInfo::SEQ_MIN; + } + if (StringUtil::Equals(value, "SEQ_MAX")) { + return SequenceInfo::SEQ_MAX; + } + if (StringUtil::Equals(value, "SEQ_CYCLE")) { + return SequenceInfo::SEQ_CYCLE; + } + if (StringUtil::Equals(value, "SEQ_OWN")) { + return SequenceInfo::SEQ_OWN; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SetOperationType value) { + switch(value) { + case SetOperationType::NONE: + return "NONE"; + case SetOperationType::UNION: + return "UNION"; + case SetOperationType::EXCEPT: + return "EXCEPT"; + case SetOperationType::INTERSECT: + return "INTERSECT"; + case SetOperationType::UNION_BY_NAME: + return "UNION_BY_NAME"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SetOperationType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NONE")) { + return SetOperationType::NONE; + } + if (StringUtil::Equals(value, "UNION")) { + return SetOperationType::UNION; + } + if (StringUtil::Equals(value, "EXCEPT")) { + return SetOperationType::EXCEPT; + } + if (StringUtil::Equals(value, "INTERSECT")) { + return SetOperationType::INTERSECT; + } + if (StringUtil::Equals(value, "UNION_BY_NAME")) { + return SetOperationType::UNION_BY_NAME; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SetScope value) { + switch(value) { + case SetScope::AUTOMATIC: + return "AUTOMATIC"; + case SetScope::LOCAL: + return "LOCAL"; + case SetScope::SESSION: + return "SESSION"; + case SetScope::GLOBAL: + return "GLOBAL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SetScope EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "AUTOMATIC")) { + return SetScope::AUTOMATIC; + } + if (StringUtil::Equals(value, "LOCAL")) { + return SetScope::LOCAL; + } + if (StringUtil::Equals(value, "SESSION")) { + return SetScope::SESSION; + } + if (StringUtil::Equals(value, "GLOBAL")) { + return SetScope::GLOBAL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SetType value) { + switch(value) { + case SetType::SET: + return "SET"; + case SetType::RESET: + return "RESET"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SetType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "SET")) { + return SetType::SET; + } + if (StringUtil::Equals(value, "RESET")) { + return SetType::RESET; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SimplifiedTokenType value) { + switch(value) { + case SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER: + return "SIMPLIFIED_TOKEN_IDENTIFIER"; + case SimplifiedTokenType::SIMPLIFIED_TOKEN_NUMERIC_CONSTANT: + return "SIMPLIFIED_TOKEN_NUMERIC_CONSTANT"; + case SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT: + return "SIMPLIFIED_TOKEN_STRING_CONSTANT"; + case SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR: + return "SIMPLIFIED_TOKEN_OPERATOR"; + case SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD: + return "SIMPLIFIED_TOKEN_KEYWORD"; + case SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT: + return "SIMPLIFIED_TOKEN_COMMENT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SimplifiedTokenType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_IDENTIFIER")) { + return SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER; + } + if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_NUMERIC_CONSTANT")) { + return SimplifiedTokenType::SIMPLIFIED_TOKEN_NUMERIC_CONSTANT; + } + if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_STRING_CONSTANT")) { + return SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT; + } + if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_OPERATOR")) { + return SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR; + } + if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_KEYWORD")) { + return SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD; + } + if (StringUtil::Equals(value, "SIMPLIFIED_TOKEN_COMMENT")) { + return SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SinkCombineResultType value) { + switch(value) { + case SinkCombineResultType::FINISHED: + return "FINISHED"; + case SinkCombineResultType::BLOCKED: + return "BLOCKED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SinkCombineResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "FINISHED")) { + return SinkCombineResultType::FINISHED; + } + if (StringUtil::Equals(value, "BLOCKED")) { + return SinkCombineResultType::BLOCKED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SinkFinalizeType value) { + switch(value) { + case SinkFinalizeType::READY: + return "READY"; + case SinkFinalizeType::NO_OUTPUT_POSSIBLE: + return "NO_OUTPUT_POSSIBLE"; + case SinkFinalizeType::BLOCKED: + return "BLOCKED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SinkFinalizeType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "READY")) { + return SinkFinalizeType::READY; + } + if (StringUtil::Equals(value, "NO_OUTPUT_POSSIBLE")) { + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + if (StringUtil::Equals(value, "BLOCKED")) { + return SinkFinalizeType::BLOCKED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SinkResultType value) { + switch(value) { + case SinkResultType::NEED_MORE_INPUT: + return "NEED_MORE_INPUT"; + case SinkResultType::FINISHED: + return "FINISHED"; + case SinkResultType::BLOCKED: + return "BLOCKED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SinkResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NEED_MORE_INPUT")) { + return SinkResultType::NEED_MORE_INPUT; + } + if (StringUtil::Equals(value, "FINISHED")) { + return SinkResultType::FINISHED; + } + if (StringUtil::Equals(value, "BLOCKED")) { + return SinkResultType::BLOCKED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SourceResultType value) { + switch(value) { + case SourceResultType::HAVE_MORE_OUTPUT: + return "HAVE_MORE_OUTPUT"; + case SourceResultType::FINISHED: + return "FINISHED"; + case SourceResultType::BLOCKED: + return "BLOCKED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SourceResultType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "HAVE_MORE_OUTPUT")) { + return SourceResultType::HAVE_MORE_OUTPUT; + } + if (StringUtil::Equals(value, "FINISHED")) { + return SourceResultType::FINISHED; + } + if (StringUtil::Equals(value, "BLOCKED")) { + return SourceResultType::BLOCKED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(StatementReturnType value) { + switch(value) { + case StatementReturnType::QUERY_RESULT: + return "QUERY_RESULT"; + case StatementReturnType::CHANGED_ROWS: + return "CHANGED_ROWS"; + case StatementReturnType::NOTHING: + return "NOTHING"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +StatementReturnType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "QUERY_RESULT")) { + return StatementReturnType::QUERY_RESULT; + } + if (StringUtil::Equals(value, "CHANGED_ROWS")) { + return StatementReturnType::CHANGED_ROWS; + } + if (StringUtil::Equals(value, "NOTHING")) { + return StatementReturnType::NOTHING; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(StatementType value) { + switch(value) { + case StatementType::INVALID_STATEMENT: + return "INVALID_STATEMENT"; + case StatementType::SELECT_STATEMENT: + return "SELECT_STATEMENT"; + case StatementType::INSERT_STATEMENT: + return "INSERT_STATEMENT"; + case StatementType::UPDATE_STATEMENT: + return "UPDATE_STATEMENT"; + case StatementType::CREATE_STATEMENT: + return "CREATE_STATEMENT"; + case StatementType::DELETE_STATEMENT: + return "DELETE_STATEMENT"; + case StatementType::PREPARE_STATEMENT: + return "PREPARE_STATEMENT"; + case StatementType::EXECUTE_STATEMENT: + return "EXECUTE_STATEMENT"; + case StatementType::ALTER_STATEMENT: + return "ALTER_STATEMENT"; + case StatementType::TRANSACTION_STATEMENT: + return "TRANSACTION_STATEMENT"; + case StatementType::COPY_STATEMENT: + return "COPY_STATEMENT"; + case StatementType::ANALYZE_STATEMENT: + return "ANALYZE_STATEMENT"; + case StatementType::VARIABLE_SET_STATEMENT: + return "VARIABLE_SET_STATEMENT"; + case StatementType::CREATE_FUNC_STATEMENT: + return "CREATE_FUNC_STATEMENT"; + case StatementType::EXPLAIN_STATEMENT: + return "EXPLAIN_STATEMENT"; + case StatementType::DROP_STATEMENT: + return "DROP_STATEMENT"; + case StatementType::EXPORT_STATEMENT: + return "EXPORT_STATEMENT"; + case StatementType::PRAGMA_STATEMENT: + return "PRAGMA_STATEMENT"; + case StatementType::SHOW_STATEMENT: + return "SHOW_STATEMENT"; + case StatementType::VACUUM_STATEMENT: + return "VACUUM_STATEMENT"; + case StatementType::CALL_STATEMENT: + return "CALL_STATEMENT"; + case StatementType::SET_STATEMENT: + return "SET_STATEMENT"; + case StatementType::LOAD_STATEMENT: + return "LOAD_STATEMENT"; + case StatementType::RELATION_STATEMENT: + return "RELATION_STATEMENT"; + case StatementType::EXTENSION_STATEMENT: + return "EXTENSION_STATEMENT"; + case StatementType::LOGICAL_PLAN_STATEMENT: + return "LOGICAL_PLAN_STATEMENT"; + case StatementType::ATTACH_STATEMENT: + return "ATTACH_STATEMENT"; + case StatementType::DETACH_STATEMENT: + return "DETACH_STATEMENT"; + case StatementType::MULTI_STATEMENT: + return "MULTI_STATEMENT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +StatementType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID_STATEMENT")) { + return StatementType::INVALID_STATEMENT; + } + if (StringUtil::Equals(value, "SELECT_STATEMENT")) { + return StatementType::SELECT_STATEMENT; + } + if (StringUtil::Equals(value, "INSERT_STATEMENT")) { + return StatementType::INSERT_STATEMENT; + } + if (StringUtil::Equals(value, "UPDATE_STATEMENT")) { + return StatementType::UPDATE_STATEMENT; + } + if (StringUtil::Equals(value, "CREATE_STATEMENT")) { + return StatementType::CREATE_STATEMENT; + } + if (StringUtil::Equals(value, "DELETE_STATEMENT")) { + return StatementType::DELETE_STATEMENT; + } + if (StringUtil::Equals(value, "PREPARE_STATEMENT")) { + return StatementType::PREPARE_STATEMENT; + } + if (StringUtil::Equals(value, "EXECUTE_STATEMENT")) { + return StatementType::EXECUTE_STATEMENT; + } + if (StringUtil::Equals(value, "ALTER_STATEMENT")) { + return StatementType::ALTER_STATEMENT; + } + if (StringUtil::Equals(value, "TRANSACTION_STATEMENT")) { + return StatementType::TRANSACTION_STATEMENT; + } + if (StringUtil::Equals(value, "COPY_STATEMENT")) { + return StatementType::COPY_STATEMENT; + } + if (StringUtil::Equals(value, "ANALYZE_STATEMENT")) { + return StatementType::ANALYZE_STATEMENT; + } + if (StringUtil::Equals(value, "VARIABLE_SET_STATEMENT")) { + return StatementType::VARIABLE_SET_STATEMENT; + } + if (StringUtil::Equals(value, "CREATE_FUNC_STATEMENT")) { + return StatementType::CREATE_FUNC_STATEMENT; + } + if (StringUtil::Equals(value, "EXPLAIN_STATEMENT")) { + return StatementType::EXPLAIN_STATEMENT; + } + if (StringUtil::Equals(value, "DROP_STATEMENT")) { + return StatementType::DROP_STATEMENT; + } + if (StringUtil::Equals(value, "EXPORT_STATEMENT")) { + return StatementType::EXPORT_STATEMENT; + } + if (StringUtil::Equals(value, "PRAGMA_STATEMENT")) { + return StatementType::PRAGMA_STATEMENT; + } + if (StringUtil::Equals(value, "SHOW_STATEMENT")) { + return StatementType::SHOW_STATEMENT; + } + if (StringUtil::Equals(value, "VACUUM_STATEMENT")) { + return StatementType::VACUUM_STATEMENT; + } + if (StringUtil::Equals(value, "CALL_STATEMENT")) { + return StatementType::CALL_STATEMENT; + } + if (StringUtil::Equals(value, "SET_STATEMENT")) { + return StatementType::SET_STATEMENT; + } + if (StringUtil::Equals(value, "LOAD_STATEMENT")) { + return StatementType::LOAD_STATEMENT; + } + if (StringUtil::Equals(value, "RELATION_STATEMENT")) { + return StatementType::RELATION_STATEMENT; + } + if (StringUtil::Equals(value, "EXTENSION_STATEMENT")) { + return StatementType::EXTENSION_STATEMENT; + } + if (StringUtil::Equals(value, "LOGICAL_PLAN_STATEMENT")) { + return StatementType::LOGICAL_PLAN_STATEMENT; + } + if (StringUtil::Equals(value, "ATTACH_STATEMENT")) { + return StatementType::ATTACH_STATEMENT; + } + if (StringUtil::Equals(value, "DETACH_STATEMENT")) { + return StatementType::DETACH_STATEMENT; + } + if (StringUtil::Equals(value, "MULTI_STATEMENT")) { + return StatementType::MULTI_STATEMENT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(StatisticsType value) { + switch(value) { + case StatisticsType::NUMERIC_STATS: + return "NUMERIC_STATS"; + case StatisticsType::STRING_STATS: + return "STRING_STATS"; + case StatisticsType::LIST_STATS: + return "LIST_STATS"; + case StatisticsType::STRUCT_STATS: + return "STRUCT_STATS"; + case StatisticsType::BASE_STATS: + return "BASE_STATS"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +StatisticsType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NUMERIC_STATS")) { + return StatisticsType::NUMERIC_STATS; + } + if (StringUtil::Equals(value, "STRING_STATS")) { + return StatisticsType::STRING_STATS; + } + if (StringUtil::Equals(value, "LIST_STATS")) { + return StatisticsType::LIST_STATS; + } + if (StringUtil::Equals(value, "STRUCT_STATS")) { + return StatisticsType::STRUCT_STATS; + } + if (StringUtil::Equals(value, "BASE_STATS")) { + return StatisticsType::BASE_STATS; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(StatsInfo value) { + switch(value) { + case StatsInfo::CAN_HAVE_NULL_VALUES: + return "CAN_HAVE_NULL_VALUES"; + case StatsInfo::CANNOT_HAVE_NULL_VALUES: + return "CANNOT_HAVE_NULL_VALUES"; + case StatsInfo::CAN_HAVE_VALID_VALUES: + return "CAN_HAVE_VALID_VALUES"; + case StatsInfo::CANNOT_HAVE_VALID_VALUES: + return "CANNOT_HAVE_VALID_VALUES"; + case StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES: + return "CAN_HAVE_NULL_AND_VALID_VALUES"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +StatsInfo EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "CAN_HAVE_NULL_VALUES")) { + return StatsInfo::CAN_HAVE_NULL_VALUES; + } + if (StringUtil::Equals(value, "CANNOT_HAVE_NULL_VALUES")) { + return StatsInfo::CANNOT_HAVE_NULL_VALUES; + } + if (StringUtil::Equals(value, "CAN_HAVE_VALID_VALUES")) { + return StatsInfo::CAN_HAVE_VALID_VALUES; + } + if (StringUtil::Equals(value, "CANNOT_HAVE_VALID_VALUES")) { + return StatsInfo::CANNOT_HAVE_VALID_VALUES; + } + if (StringUtil::Equals(value, "CAN_HAVE_NULL_AND_VALID_VALUES")) { + return StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(StrTimeSpecifier value) { + switch(value) { + case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: + return "ABBREVIATED_WEEKDAY_NAME"; + case StrTimeSpecifier::FULL_WEEKDAY_NAME: + return "FULL_WEEKDAY_NAME"; + case StrTimeSpecifier::WEEKDAY_DECIMAL: + return "WEEKDAY_DECIMAL"; + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + return "DAY_OF_MONTH_PADDED"; + case StrTimeSpecifier::DAY_OF_MONTH: + return "DAY_OF_MONTH"; + case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: + return "ABBREVIATED_MONTH_NAME"; + case StrTimeSpecifier::FULL_MONTH_NAME: + return "FULL_MONTH_NAME"; + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + return "MONTH_DECIMAL_PADDED"; + case StrTimeSpecifier::MONTH_DECIMAL: + return "MONTH_DECIMAL"; + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: + return "YEAR_WITHOUT_CENTURY_PADDED"; + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: + return "YEAR_WITHOUT_CENTURY"; + case StrTimeSpecifier::YEAR_DECIMAL: + return "YEAR_DECIMAL"; + case StrTimeSpecifier::HOUR_24_PADDED: + return "HOUR_24_PADDED"; + case StrTimeSpecifier::HOUR_24_DECIMAL: + return "HOUR_24_DECIMAL"; + case StrTimeSpecifier::HOUR_12_PADDED: + return "HOUR_12_PADDED"; + case StrTimeSpecifier::HOUR_12_DECIMAL: + return "HOUR_12_DECIMAL"; + case StrTimeSpecifier::AM_PM: + return "AM_PM"; + case StrTimeSpecifier::MINUTE_PADDED: + return "MINUTE_PADDED"; + case StrTimeSpecifier::MINUTE_DECIMAL: + return "MINUTE_DECIMAL"; + case StrTimeSpecifier::SECOND_PADDED: + return "SECOND_PADDED"; + case StrTimeSpecifier::SECOND_DECIMAL: + return "SECOND_DECIMAL"; + case StrTimeSpecifier::MICROSECOND_PADDED: + return "MICROSECOND_PADDED"; + case StrTimeSpecifier::MILLISECOND_PADDED: + return "MILLISECOND_PADDED"; + case StrTimeSpecifier::UTC_OFFSET: + return "UTC_OFFSET"; + case StrTimeSpecifier::TZ_NAME: + return "TZ_NAME"; + case StrTimeSpecifier::DAY_OF_YEAR_PADDED: + return "DAY_OF_YEAR_PADDED"; + case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: + return "DAY_OF_YEAR_DECIMAL"; + case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: + return "WEEK_NUMBER_PADDED_SUN_FIRST"; + case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: + return "WEEK_NUMBER_PADDED_MON_FIRST"; + case StrTimeSpecifier::LOCALE_APPROPRIATE_DATE_AND_TIME: + return "LOCALE_APPROPRIATE_DATE_AND_TIME"; + case StrTimeSpecifier::LOCALE_APPROPRIATE_DATE: + return "LOCALE_APPROPRIATE_DATE"; + case StrTimeSpecifier::LOCALE_APPROPRIATE_TIME: + return "LOCALE_APPROPRIATE_TIME"; + case StrTimeSpecifier::NANOSECOND_PADDED: + return "NANOSECOND_PADDED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +StrTimeSpecifier EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ABBREVIATED_WEEKDAY_NAME")) { + return StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME; + } + if (StringUtil::Equals(value, "FULL_WEEKDAY_NAME")) { + return StrTimeSpecifier::FULL_WEEKDAY_NAME; + } + if (StringUtil::Equals(value, "WEEKDAY_DECIMAL")) { + return StrTimeSpecifier::WEEKDAY_DECIMAL; + } + if (StringUtil::Equals(value, "DAY_OF_MONTH_PADDED")) { + return StrTimeSpecifier::DAY_OF_MONTH_PADDED; + } + if (StringUtil::Equals(value, "DAY_OF_MONTH")) { + return StrTimeSpecifier::DAY_OF_MONTH; + } + if (StringUtil::Equals(value, "ABBREVIATED_MONTH_NAME")) { + return StrTimeSpecifier::ABBREVIATED_MONTH_NAME; + } + if (StringUtil::Equals(value, "FULL_MONTH_NAME")) { + return StrTimeSpecifier::FULL_MONTH_NAME; + } + if (StringUtil::Equals(value, "MONTH_DECIMAL_PADDED")) { + return StrTimeSpecifier::MONTH_DECIMAL_PADDED; + } + if (StringUtil::Equals(value, "MONTH_DECIMAL")) { + return StrTimeSpecifier::MONTH_DECIMAL; + } + if (StringUtil::Equals(value, "YEAR_WITHOUT_CENTURY_PADDED")) { + return StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED; + } + if (StringUtil::Equals(value, "YEAR_WITHOUT_CENTURY")) { + return StrTimeSpecifier::YEAR_WITHOUT_CENTURY; + } + if (StringUtil::Equals(value, "YEAR_DECIMAL")) { + return StrTimeSpecifier::YEAR_DECIMAL; + } + if (StringUtil::Equals(value, "HOUR_24_PADDED")) { + return StrTimeSpecifier::HOUR_24_PADDED; + } + if (StringUtil::Equals(value, "HOUR_24_DECIMAL")) { + return StrTimeSpecifier::HOUR_24_DECIMAL; + } + if (StringUtil::Equals(value, "HOUR_12_PADDED")) { + return StrTimeSpecifier::HOUR_12_PADDED; + } + if (StringUtil::Equals(value, "HOUR_12_DECIMAL")) { + return StrTimeSpecifier::HOUR_12_DECIMAL; + } + if (StringUtil::Equals(value, "AM_PM")) { + return StrTimeSpecifier::AM_PM; + } + if (StringUtil::Equals(value, "MINUTE_PADDED")) { + return StrTimeSpecifier::MINUTE_PADDED; + } + if (StringUtil::Equals(value, "MINUTE_DECIMAL")) { + return StrTimeSpecifier::MINUTE_DECIMAL; + } + if (StringUtil::Equals(value, "SECOND_PADDED")) { + return StrTimeSpecifier::SECOND_PADDED; + } + if (StringUtil::Equals(value, "SECOND_DECIMAL")) { + return StrTimeSpecifier::SECOND_DECIMAL; + } + if (StringUtil::Equals(value, "MICROSECOND_PADDED")) { + return StrTimeSpecifier::MICROSECOND_PADDED; + } + if (StringUtil::Equals(value, "MILLISECOND_PADDED")) { + return StrTimeSpecifier::MILLISECOND_PADDED; + } + if (StringUtil::Equals(value, "UTC_OFFSET")) { + return StrTimeSpecifier::UTC_OFFSET; + } + if (StringUtil::Equals(value, "TZ_NAME")) { + return StrTimeSpecifier::TZ_NAME; + } + if (StringUtil::Equals(value, "DAY_OF_YEAR_PADDED")) { + return StrTimeSpecifier::DAY_OF_YEAR_PADDED; + } + if (StringUtil::Equals(value, "DAY_OF_YEAR_DECIMAL")) { + return StrTimeSpecifier::DAY_OF_YEAR_DECIMAL; + } + if (StringUtil::Equals(value, "WEEK_NUMBER_PADDED_SUN_FIRST")) { + return StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST; + } + if (StringUtil::Equals(value, "WEEK_NUMBER_PADDED_MON_FIRST")) { + return StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST; + } + if (StringUtil::Equals(value, "LOCALE_APPROPRIATE_DATE_AND_TIME")) { + return StrTimeSpecifier::LOCALE_APPROPRIATE_DATE_AND_TIME; + } + if (StringUtil::Equals(value, "LOCALE_APPROPRIATE_DATE")) { + return StrTimeSpecifier::LOCALE_APPROPRIATE_DATE; + } + if (StringUtil::Equals(value, "LOCALE_APPROPRIATE_TIME")) { + return StrTimeSpecifier::LOCALE_APPROPRIATE_TIME; + } + if (StringUtil::Equals(value, "NANOSECOND_PADDED")) { + return StrTimeSpecifier::NANOSECOND_PADDED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(SubqueryType value) { + switch(value) { + case SubqueryType::INVALID: + return "INVALID"; + case SubqueryType::SCALAR: + return "SCALAR"; + case SubqueryType::EXISTS: + return "EXISTS"; + case SubqueryType::NOT_EXISTS: + return "NOT_EXISTS"; + case SubqueryType::ANY: + return "ANY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +SubqueryType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return SubqueryType::INVALID; + } + if (StringUtil::Equals(value, "SCALAR")) { + return SubqueryType::SCALAR; + } + if (StringUtil::Equals(value, "EXISTS")) { + return SubqueryType::EXISTS; + } + if (StringUtil::Equals(value, "NOT_EXISTS")) { + return SubqueryType::NOT_EXISTS; + } + if (StringUtil::Equals(value, "ANY")) { + return SubqueryType::ANY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TableColumnType value) { + switch(value) { + case TableColumnType::STANDARD: + return "STANDARD"; + case TableColumnType::GENERATED: + return "GENERATED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TableColumnType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "STANDARD")) { + return TableColumnType::STANDARD; + } + if (StringUtil::Equals(value, "GENERATED")) { + return TableColumnType::GENERATED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TableFilterType value) { + switch(value) { + case TableFilterType::CONSTANT_COMPARISON: + return "CONSTANT_COMPARISON"; + case TableFilterType::IS_NULL: + return "IS_NULL"; + case TableFilterType::IS_NOT_NULL: + return "IS_NOT_NULL"; + case TableFilterType::CONJUNCTION_OR: + return "CONJUNCTION_OR"; + case TableFilterType::CONJUNCTION_AND: + return "CONJUNCTION_AND"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TableFilterType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "CONSTANT_COMPARISON")) { + return TableFilterType::CONSTANT_COMPARISON; + } + if (StringUtil::Equals(value, "IS_NULL")) { + return TableFilterType::IS_NULL; + } + if (StringUtil::Equals(value, "IS_NOT_NULL")) { + return TableFilterType::IS_NOT_NULL; + } + if (StringUtil::Equals(value, "CONJUNCTION_OR")) { + return TableFilterType::CONJUNCTION_OR; + } + if (StringUtil::Equals(value, "CONJUNCTION_AND")) { + return TableFilterType::CONJUNCTION_AND; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TableReferenceType value) { + switch(value) { + case TableReferenceType::INVALID: + return "INVALID"; + case TableReferenceType::BASE_TABLE: + return "BASE_TABLE"; + case TableReferenceType::SUBQUERY: + return "SUBQUERY"; + case TableReferenceType::JOIN: + return "JOIN"; + case TableReferenceType::TABLE_FUNCTION: + return "TABLE_FUNCTION"; + case TableReferenceType::EXPRESSION_LIST: + return "EXPRESSION_LIST"; + case TableReferenceType::CTE: + return "CTE"; + case TableReferenceType::EMPTY: + return "EMPTY"; + case TableReferenceType::PIVOT: + return "PIVOT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TableReferenceType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return TableReferenceType::INVALID; + } + if (StringUtil::Equals(value, "BASE_TABLE")) { + return TableReferenceType::BASE_TABLE; + } + if (StringUtil::Equals(value, "SUBQUERY")) { + return TableReferenceType::SUBQUERY; + } + if (StringUtil::Equals(value, "JOIN")) { + return TableReferenceType::JOIN; + } + if (StringUtil::Equals(value, "TABLE_FUNCTION")) { + return TableReferenceType::TABLE_FUNCTION; + } + if (StringUtil::Equals(value, "EXPRESSION_LIST")) { + return TableReferenceType::EXPRESSION_LIST; + } + if (StringUtil::Equals(value, "CTE")) { + return TableReferenceType::CTE; + } + if (StringUtil::Equals(value, "EMPTY")) { + return TableReferenceType::EMPTY; + } + if (StringUtil::Equals(value, "PIVOT")) { + return TableReferenceType::PIVOT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TableScanType value) { + switch(value) { + case TableScanType::TABLE_SCAN_REGULAR: + return "TABLE_SCAN_REGULAR"; + case TableScanType::TABLE_SCAN_COMMITTED_ROWS: + return "TABLE_SCAN_COMMITTED_ROWS"; + case TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES: + return "TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES"; + case TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED: + return "TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TableScanType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "TABLE_SCAN_REGULAR")) { + return TableScanType::TABLE_SCAN_REGULAR; + } + if (StringUtil::Equals(value, "TABLE_SCAN_COMMITTED_ROWS")) { + return TableScanType::TABLE_SCAN_COMMITTED_ROWS; + } + if (StringUtil::Equals(value, "TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES")) { + return TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES; + } + if (StringUtil::Equals(value, "TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED")) { + return TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TaskExecutionMode value) { + switch(value) { + case TaskExecutionMode::PROCESS_ALL: + return "PROCESS_ALL"; + case TaskExecutionMode::PROCESS_PARTIAL: + return "PROCESS_PARTIAL"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TaskExecutionMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "PROCESS_ALL")) { + return TaskExecutionMode::PROCESS_ALL; + } + if (StringUtil::Equals(value, "PROCESS_PARTIAL")) { + return TaskExecutionMode::PROCESS_PARTIAL; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TaskExecutionResult value) { + switch(value) { + case TaskExecutionResult::TASK_FINISHED: + return "TASK_FINISHED"; + case TaskExecutionResult::TASK_NOT_FINISHED: + return "TASK_NOT_FINISHED"; + case TaskExecutionResult::TASK_ERROR: + return "TASK_ERROR"; + case TaskExecutionResult::TASK_BLOCKED: + return "TASK_BLOCKED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TaskExecutionResult EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "TASK_FINISHED")) { + return TaskExecutionResult::TASK_FINISHED; + } + if (StringUtil::Equals(value, "TASK_NOT_FINISHED")) { + return TaskExecutionResult::TASK_NOT_FINISHED; + } + if (StringUtil::Equals(value, "TASK_ERROR")) { + return TaskExecutionResult::TASK_ERROR; + } + if (StringUtil::Equals(value, "TASK_BLOCKED")) { + return TaskExecutionResult::TASK_BLOCKED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TimestampCastResult value) { + switch(value) { + case TimestampCastResult::SUCCESS: + return "SUCCESS"; + case TimestampCastResult::ERROR_INCORRECT_FORMAT: + return "ERROR_INCORRECT_FORMAT"; + case TimestampCastResult::ERROR_NON_UTC_TIMEZONE: + return "ERROR_NON_UTC_TIMEZONE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TimestampCastResult EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "SUCCESS")) { + return TimestampCastResult::SUCCESS; + } + if (StringUtil::Equals(value, "ERROR_INCORRECT_FORMAT")) { + return TimestampCastResult::ERROR_INCORRECT_FORMAT; + } + if (StringUtil::Equals(value, "ERROR_NON_UTC_TIMEZONE")) { + return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TransactionType value) { + switch(value) { + case TransactionType::INVALID: + return "INVALID"; + case TransactionType::BEGIN_TRANSACTION: + return "BEGIN_TRANSACTION"; + case TransactionType::COMMIT: + return "COMMIT"; + case TransactionType::ROLLBACK: + return "ROLLBACK"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TransactionType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return TransactionType::INVALID; + } + if (StringUtil::Equals(value, "BEGIN_TRANSACTION")) { + return TransactionType::BEGIN_TRANSACTION; + } + if (StringUtil::Equals(value, "COMMIT")) { + return TransactionType::COMMIT; + } + if (StringUtil::Equals(value, "ROLLBACK")) { + return TransactionType::ROLLBACK; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(TupleDataPinProperties value) { + switch(value) { + case TupleDataPinProperties::INVALID: + return "INVALID"; + case TupleDataPinProperties::KEEP_EVERYTHING_PINNED: + return "KEEP_EVERYTHING_PINNED"; + case TupleDataPinProperties::UNPIN_AFTER_DONE: + return "UNPIN_AFTER_DONE"; + case TupleDataPinProperties::DESTROY_AFTER_DONE: + return "DESTROY_AFTER_DONE"; + case TupleDataPinProperties::ALREADY_PINNED: + return "ALREADY_PINNED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TupleDataPinProperties EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return TupleDataPinProperties::INVALID; + } + if (StringUtil::Equals(value, "KEEP_EVERYTHING_PINNED")) { + return TupleDataPinProperties::KEEP_EVERYTHING_PINNED; + } + if (StringUtil::Equals(value, "UNPIN_AFTER_DONE")) { + return TupleDataPinProperties::UNPIN_AFTER_DONE; + } + if (StringUtil::Equals(value, "DESTROY_AFTER_DONE")) { + return TupleDataPinProperties::DESTROY_AFTER_DONE; + } + if (StringUtil::Equals(value, "ALREADY_PINNED")) { + return TupleDataPinProperties::ALREADY_PINNED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(UndoFlags value) { + switch(value) { + case UndoFlags::EMPTY_ENTRY: + return "EMPTY_ENTRY"; + case UndoFlags::CATALOG_ENTRY: + return "CATALOG_ENTRY"; + case UndoFlags::INSERT_TUPLE: + return "INSERT_TUPLE"; + case UndoFlags::DELETE_TUPLE: + return "DELETE_TUPLE"; + case UndoFlags::UPDATE_TUPLE: + return "UPDATE_TUPLE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +UndoFlags EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "EMPTY_ENTRY")) { + return UndoFlags::EMPTY_ENTRY; + } + if (StringUtil::Equals(value, "CATALOG_ENTRY")) { + return UndoFlags::CATALOG_ENTRY; + } + if (StringUtil::Equals(value, "INSERT_TUPLE")) { + return UndoFlags::INSERT_TUPLE; + } + if (StringUtil::Equals(value, "DELETE_TUPLE")) { + return UndoFlags::DELETE_TUPLE; + } + if (StringUtil::Equals(value, "UPDATE_TUPLE")) { + return UndoFlags::UPDATE_TUPLE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(UnionInvalidReason value) { + switch(value) { + case UnionInvalidReason::VALID: + return "VALID"; + case UnionInvalidReason::TAG_OUT_OF_RANGE: + return "TAG_OUT_OF_RANGE"; + case UnionInvalidReason::NO_MEMBERS: + return "NO_MEMBERS"; + case UnionInvalidReason::VALIDITY_OVERLAP: + return "VALIDITY_OVERLAP"; + case UnionInvalidReason::TAG_MISMATCH: + return "TAG_MISMATCH"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +UnionInvalidReason EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "VALID")) { + return UnionInvalidReason::VALID; + } + if (StringUtil::Equals(value, "TAG_OUT_OF_RANGE")) { + return UnionInvalidReason::TAG_OUT_OF_RANGE; + } + if (StringUtil::Equals(value, "NO_MEMBERS")) { + return UnionInvalidReason::NO_MEMBERS; + } + if (StringUtil::Equals(value, "VALIDITY_OVERLAP")) { + return UnionInvalidReason::VALIDITY_OVERLAP; + } + if (StringUtil::Equals(value, "TAG_MISMATCH")) { + return UnionInvalidReason::TAG_MISMATCH; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(VectorAuxiliaryDataType value) { + switch(value) { + case VectorAuxiliaryDataType::ARROW_AUXILIARY: + return "ARROW_AUXILIARY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +VectorAuxiliaryDataType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ARROW_AUXILIARY")) { + return VectorAuxiliaryDataType::ARROW_AUXILIARY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(VectorBufferType value) { + switch(value) { + case VectorBufferType::STANDARD_BUFFER: + return "STANDARD_BUFFER"; + case VectorBufferType::DICTIONARY_BUFFER: + return "DICTIONARY_BUFFER"; + case VectorBufferType::VECTOR_CHILD_BUFFER: + return "VECTOR_CHILD_BUFFER"; + case VectorBufferType::STRING_BUFFER: + return "STRING_BUFFER"; + case VectorBufferType::FSST_BUFFER: + return "FSST_BUFFER"; + case VectorBufferType::STRUCT_BUFFER: + return "STRUCT_BUFFER"; + case VectorBufferType::LIST_BUFFER: + return "LIST_BUFFER"; + case VectorBufferType::MANAGED_BUFFER: + return "MANAGED_BUFFER"; + case VectorBufferType::OPAQUE_BUFFER: + return "OPAQUE_BUFFER"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +VectorBufferType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "STANDARD_BUFFER")) { + return VectorBufferType::STANDARD_BUFFER; + } + if (StringUtil::Equals(value, "DICTIONARY_BUFFER")) { + return VectorBufferType::DICTIONARY_BUFFER; + } + if (StringUtil::Equals(value, "VECTOR_CHILD_BUFFER")) { + return VectorBufferType::VECTOR_CHILD_BUFFER; + } + if (StringUtil::Equals(value, "STRING_BUFFER")) { + return VectorBufferType::STRING_BUFFER; + } + if (StringUtil::Equals(value, "FSST_BUFFER")) { + return VectorBufferType::FSST_BUFFER; + } + if (StringUtil::Equals(value, "STRUCT_BUFFER")) { + return VectorBufferType::STRUCT_BUFFER; + } + if (StringUtil::Equals(value, "LIST_BUFFER")) { + return VectorBufferType::LIST_BUFFER; + } + if (StringUtil::Equals(value, "MANAGED_BUFFER")) { + return VectorBufferType::MANAGED_BUFFER; + } + if (StringUtil::Equals(value, "OPAQUE_BUFFER")) { + return VectorBufferType::OPAQUE_BUFFER; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(VectorType value) { + switch(value) { + case VectorType::FLAT_VECTOR: + return "FLAT_VECTOR"; + case VectorType::FSST_VECTOR: + return "FSST_VECTOR"; + case VectorType::CONSTANT_VECTOR: + return "CONSTANT_VECTOR"; + case VectorType::DICTIONARY_VECTOR: + return "DICTIONARY_VECTOR"; + case VectorType::SEQUENCE_VECTOR: + return "SEQUENCE_VECTOR"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +VectorType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "FLAT_VECTOR")) { + return VectorType::FLAT_VECTOR; + } + if (StringUtil::Equals(value, "FSST_VECTOR")) { + return VectorType::FSST_VECTOR; + } + if (StringUtil::Equals(value, "CONSTANT_VECTOR")) { + return VectorType::CONSTANT_VECTOR; + } + if (StringUtil::Equals(value, "DICTIONARY_VECTOR")) { + return VectorType::DICTIONARY_VECTOR; + } + if (StringUtil::Equals(value, "SEQUENCE_VECTOR")) { + return VectorType::SEQUENCE_VECTOR; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(VerificationType value) { + switch(value) { + case VerificationType::ORIGINAL: + return "ORIGINAL"; + case VerificationType::COPIED: + return "COPIED"; + case VerificationType::DESERIALIZED: + return "DESERIALIZED"; + case VerificationType::PARSED: + return "PARSED"; + case VerificationType::UNOPTIMIZED: + return "UNOPTIMIZED"; + case VerificationType::NO_OPERATOR_CACHING: + return "NO_OPERATOR_CACHING"; + case VerificationType::PREPARED: + return "PREPARED"; + case VerificationType::EXTERNAL: + return "EXTERNAL"; + case VerificationType::INVALID: + return "INVALID"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +VerificationType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "ORIGINAL")) { + return VerificationType::ORIGINAL; + } + if (StringUtil::Equals(value, "COPIED")) { + return VerificationType::COPIED; + } + if (StringUtil::Equals(value, "DESERIALIZED")) { + return VerificationType::DESERIALIZED; + } + if (StringUtil::Equals(value, "PARSED")) { + return VerificationType::PARSED; + } + if (StringUtil::Equals(value, "UNOPTIMIZED")) { + return VerificationType::UNOPTIMIZED; + } + if (StringUtil::Equals(value, "NO_OPERATOR_CACHING")) { + return VerificationType::NO_OPERATOR_CACHING; + } + if (StringUtil::Equals(value, "PREPARED")) { + return VerificationType::PREPARED; + } + if (StringUtil::Equals(value, "EXTERNAL")) { + return VerificationType::EXTERNAL; + } + if (StringUtil::Equals(value, "INVALID")) { + return VerificationType::INVALID; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(VerifyExistenceType value) { + switch(value) { + case VerifyExistenceType::APPEND: + return "APPEND"; + case VerifyExistenceType::APPEND_FK: + return "APPEND_FK"; + case VerifyExistenceType::DELETE_FK: + return "DELETE_FK"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +VerifyExistenceType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "APPEND")) { + return VerifyExistenceType::APPEND; + } + if (StringUtil::Equals(value, "APPEND_FK")) { + return VerifyExistenceType::APPEND_FK; + } + if (StringUtil::Equals(value, "DELETE_FK")) { + return VerifyExistenceType::DELETE_FK; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(WALType value) { + switch(value) { + case WALType::INVALID: + return "INVALID"; + case WALType::CREATE_TABLE: + return "CREATE_TABLE"; + case WALType::DROP_TABLE: + return "DROP_TABLE"; + case WALType::CREATE_SCHEMA: + return "CREATE_SCHEMA"; + case WALType::DROP_SCHEMA: + return "DROP_SCHEMA"; + case WALType::CREATE_VIEW: + return "CREATE_VIEW"; + case WALType::DROP_VIEW: + return "DROP_VIEW"; + case WALType::CREATE_SEQUENCE: + return "CREATE_SEQUENCE"; + case WALType::DROP_SEQUENCE: + return "DROP_SEQUENCE"; + case WALType::SEQUENCE_VALUE: + return "SEQUENCE_VALUE"; + case WALType::CREATE_MACRO: + return "CREATE_MACRO"; + case WALType::DROP_MACRO: + return "DROP_MACRO"; + case WALType::CREATE_TYPE: + return "CREATE_TYPE"; + case WALType::DROP_TYPE: + return "DROP_TYPE"; + case WALType::ALTER_INFO: + return "ALTER_INFO"; + case WALType::CREATE_TABLE_MACRO: + return "CREATE_TABLE_MACRO"; + case WALType::DROP_TABLE_MACRO: + return "DROP_TABLE_MACRO"; + case WALType::CREATE_INDEX: + return "CREATE_INDEX"; + case WALType::DROP_INDEX: + return "DROP_INDEX"; + case WALType::USE_TABLE: + return "USE_TABLE"; + case WALType::INSERT_TUPLE: + return "INSERT_TUPLE"; + case WALType::DELETE_TUPLE: + return "DELETE_TUPLE"; + case WALType::UPDATE_TUPLE: + return "UPDATE_TUPLE"; + case WALType::CHECKPOINT: + return "CHECKPOINT"; + case WALType::WAL_FLUSH: + return "WAL_FLUSH"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +WALType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return WALType::INVALID; + } + if (StringUtil::Equals(value, "CREATE_TABLE")) { + return WALType::CREATE_TABLE; + } + if (StringUtil::Equals(value, "DROP_TABLE")) { + return WALType::DROP_TABLE; + } + if (StringUtil::Equals(value, "CREATE_SCHEMA")) { + return WALType::CREATE_SCHEMA; + } + if (StringUtil::Equals(value, "DROP_SCHEMA")) { + return WALType::DROP_SCHEMA; + } + if (StringUtil::Equals(value, "CREATE_VIEW")) { + return WALType::CREATE_VIEW; + } + if (StringUtil::Equals(value, "DROP_VIEW")) { + return WALType::DROP_VIEW; + } + if (StringUtil::Equals(value, "CREATE_SEQUENCE")) { + return WALType::CREATE_SEQUENCE; + } + if (StringUtil::Equals(value, "DROP_SEQUENCE")) { + return WALType::DROP_SEQUENCE; + } + if (StringUtil::Equals(value, "SEQUENCE_VALUE")) { + return WALType::SEQUENCE_VALUE; + } + if (StringUtil::Equals(value, "CREATE_MACRO")) { + return WALType::CREATE_MACRO; + } + if (StringUtil::Equals(value, "DROP_MACRO")) { + return WALType::DROP_MACRO; + } + if (StringUtil::Equals(value, "CREATE_TYPE")) { + return WALType::CREATE_TYPE; + } + if (StringUtil::Equals(value, "DROP_TYPE")) { + return WALType::DROP_TYPE; + } + if (StringUtil::Equals(value, "ALTER_INFO")) { + return WALType::ALTER_INFO; + } + if (StringUtil::Equals(value, "CREATE_TABLE_MACRO")) { + return WALType::CREATE_TABLE_MACRO; + } + if (StringUtil::Equals(value, "DROP_TABLE_MACRO")) { + return WALType::DROP_TABLE_MACRO; + } + if (StringUtil::Equals(value, "CREATE_INDEX")) { + return WALType::CREATE_INDEX; + } + if (StringUtil::Equals(value, "DROP_INDEX")) { + return WALType::DROP_INDEX; + } + if (StringUtil::Equals(value, "USE_TABLE")) { + return WALType::USE_TABLE; + } + if (StringUtil::Equals(value, "INSERT_TUPLE")) { + return WALType::INSERT_TUPLE; + } + if (StringUtil::Equals(value, "DELETE_TUPLE")) { + return WALType::DELETE_TUPLE; + } + if (StringUtil::Equals(value, "UPDATE_TUPLE")) { + return WALType::UPDATE_TUPLE; + } + if (StringUtil::Equals(value, "CHECKPOINT")) { + return WALType::CHECKPOINT; + } + if (StringUtil::Equals(value, "WAL_FLUSH")) { + return WALType::WAL_FLUSH; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(WindowAggregationMode value) { + switch(value) { + case WindowAggregationMode::WINDOW: + return "WINDOW"; + case WindowAggregationMode::COMBINE: + return "COMBINE"; + case WindowAggregationMode::SEPARATE: + return "SEPARATE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +WindowAggregationMode EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "WINDOW")) { + return WindowAggregationMode::WINDOW; + } + if (StringUtil::Equals(value, "COMBINE")) { + return WindowAggregationMode::COMBINE; + } + if (StringUtil::Equals(value, "SEPARATE")) { + return WindowAggregationMode::SEPARATE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(WindowBoundary value) { + switch(value) { + case WindowBoundary::INVALID: + return "INVALID"; + case WindowBoundary::UNBOUNDED_PRECEDING: + return "UNBOUNDED_PRECEDING"; + case WindowBoundary::UNBOUNDED_FOLLOWING: + return "UNBOUNDED_FOLLOWING"; + case WindowBoundary::CURRENT_ROW_RANGE: + return "CURRENT_ROW_RANGE"; + case WindowBoundary::CURRENT_ROW_ROWS: + return "CURRENT_ROW_ROWS"; + case WindowBoundary::EXPR_PRECEDING_ROWS: + return "EXPR_PRECEDING_ROWS"; + case WindowBoundary::EXPR_FOLLOWING_ROWS: + return "EXPR_FOLLOWING_ROWS"; + case WindowBoundary::EXPR_PRECEDING_RANGE: + return "EXPR_PRECEDING_RANGE"; + case WindowBoundary::EXPR_FOLLOWING_RANGE: + return "EXPR_FOLLOWING_RANGE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +WindowBoundary EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return WindowBoundary::INVALID; + } + if (StringUtil::Equals(value, "UNBOUNDED_PRECEDING")) { + return WindowBoundary::UNBOUNDED_PRECEDING; + } + if (StringUtil::Equals(value, "UNBOUNDED_FOLLOWING")) { + return WindowBoundary::UNBOUNDED_FOLLOWING; + } + if (StringUtil::Equals(value, "CURRENT_ROW_RANGE")) { + return WindowBoundary::CURRENT_ROW_RANGE; + } + if (StringUtil::Equals(value, "CURRENT_ROW_ROWS")) { + return WindowBoundary::CURRENT_ROW_ROWS; + } + if (StringUtil::Equals(value, "EXPR_PRECEDING_ROWS")) { + return WindowBoundary::EXPR_PRECEDING_ROWS; + } + if (StringUtil::Equals(value, "EXPR_FOLLOWING_ROWS")) { + return WindowBoundary::EXPR_FOLLOWING_ROWS; + } + if (StringUtil::Equals(value, "EXPR_PRECEDING_RANGE")) { + return WindowBoundary::EXPR_PRECEDING_RANGE; + } + if (StringUtil::Equals(value, "EXPR_FOLLOWING_RANGE")) { + return WindowBoundary::EXPR_FOLLOWING_RANGE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +} + diff --git a/src/duckdb/src/common/enums/catalog_type.cpp b/src/duckdb/src/common/enums/catalog_type.cpp new file mode 100644 index 00000000..0f7ea053 --- /dev/null +++ b/src/duckdb/src/common/enums/catalog_type.cpp @@ -0,0 +1,51 @@ +#include "duckdb/common/enums/catalog_type.hpp" + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +// LCOV_EXCL_START +string CatalogTypeToString(CatalogType type) { + switch (type) { + case CatalogType::COLLATION_ENTRY: + return "Collation"; + case CatalogType::TYPE_ENTRY: + return "Type"; + case CatalogType::TABLE_ENTRY: + return "Table"; + case CatalogType::SCHEMA_ENTRY: + return "Schema"; + case CatalogType::DATABASE_ENTRY: + return "Database"; + case CatalogType::TABLE_FUNCTION_ENTRY: + return "Table Function"; + case CatalogType::SCALAR_FUNCTION_ENTRY: + return "Scalar Function"; + case CatalogType::AGGREGATE_FUNCTION_ENTRY: + return "Aggregate Function"; + case CatalogType::COPY_FUNCTION_ENTRY: + return "Copy Function"; + case CatalogType::PRAGMA_FUNCTION_ENTRY: + return "Pragma Function"; + case CatalogType::MACRO_ENTRY: + return "Macro Function"; + case CatalogType::TABLE_MACRO_ENTRY: + return "Table Macro Function"; + case CatalogType::VIEW_ENTRY: + return "View"; + case CatalogType::INDEX_ENTRY: + return "Index"; + case CatalogType::PREPARED_STATEMENT: + return "Prepared Statement"; + case CatalogType::SEQUENCE_ENTRY: + return "Sequence"; + case CatalogType::INVALID: + case CatalogType::DELETED_ENTRY: + case CatalogType::UPDATED_ENTRY: + break; + } + return "INVALID"; +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/compression_type.cpp b/src/duckdb/src/common/enums/compression_type.cpp new file mode 100644 index 00000000..267c2253 --- /dev/null +++ b/src/duckdb/src/common/enums/compression_type.cpp @@ -0,0 +1,70 @@ +#include "duckdb/common/enums/compression_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +// LCOV_EXCL_START + +vector ListCompressionTypes(void) { + vector compression_types; + uint8_t amount_of_compression_options = (uint8_t)CompressionType::COMPRESSION_COUNT; + compression_types.reserve(amount_of_compression_options); + for (uint8_t i = 0; i < amount_of_compression_options; i++) { + compression_types.push_back(CompressionTypeToString((CompressionType)i)); + } + return compression_types; +} + +CompressionType CompressionTypeFromString(const string &str) { + auto compression = StringUtil::Lower(str); + if (compression == "uncompressed") { + return CompressionType::COMPRESSION_UNCOMPRESSED; + } else if (compression == "rle") { + return CompressionType::COMPRESSION_RLE; + } else if (compression == "dictionary") { + return CompressionType::COMPRESSION_DICTIONARY; + } else if (compression == "pfor") { + return CompressionType::COMPRESSION_PFOR_DELTA; + } else if (compression == "bitpacking") { + return CompressionType::COMPRESSION_BITPACKING; + } else if (compression == "fsst") { + return CompressionType::COMPRESSION_FSST; + } else if (compression == "chimp") { + return CompressionType::COMPRESSION_CHIMP; + } else if (compression == "patas") { + return CompressionType::COMPRESSION_PATAS; + } else { + return CompressionType::COMPRESSION_AUTO; + } +} + +string CompressionTypeToString(CompressionType type) { + switch (type) { + case CompressionType::COMPRESSION_AUTO: + return "Auto"; + case CompressionType::COMPRESSION_UNCOMPRESSED: + return "Uncompressed"; + case CompressionType::COMPRESSION_CONSTANT: + return "Constant"; + case CompressionType::COMPRESSION_RLE: + return "RLE"; + case CompressionType::COMPRESSION_DICTIONARY: + return "Dictionary"; + case CompressionType::COMPRESSION_PFOR_DELTA: + return "PFOR"; + case CompressionType::COMPRESSION_BITPACKING: + return "BitPacking"; + case CompressionType::COMPRESSION_FSST: + return "FSST"; + case CompressionType::COMPRESSION_CHIMP: + return "Chimp"; + case CompressionType::COMPRESSION_PATAS: + return "Patas"; + default: + throw InternalException("Unrecognized compression type!"); + } +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/date_part_specifier.cpp b/src/duckdb/src/common/enums/date_part_specifier.cpp new file mode 100644 index 00000000..4a633a51 --- /dev/null +++ b/src/duckdb/src/common/enums/date_part_specifier.cpp @@ -0,0 +1,84 @@ +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +bool TryGetDatePartSpecifier(const string &specifier_p, DatePartSpecifier &result) { + auto specifier = StringUtil::Lower(specifier_p); + if (specifier == "year" || specifier == "yr" || specifier == "y" || specifier == "years" || specifier == "yrs") { + result = DatePartSpecifier::YEAR; + } else if (specifier == "month" || specifier == "mon" || specifier == "months" || specifier == "mons") { + result = DatePartSpecifier::MONTH; + } else if (specifier == "day" || specifier == "days" || specifier == "d" || specifier == "dayofmonth") { + result = DatePartSpecifier::DAY; + } else if (specifier == "decade" || specifier == "dec" || specifier == "decades" || specifier == "decs") { + result = DatePartSpecifier::DECADE; + } else if (specifier == "century" || specifier == "cent" || specifier == "centuries" || specifier == "c") { + result = DatePartSpecifier::CENTURY; + } else if (specifier == "millennium" || specifier == "mil" || specifier == "millenniums" || + specifier == "millennia" || specifier == "mils" || specifier == "millenium") { + result = DatePartSpecifier::MILLENNIUM; + } else if (specifier == "microseconds" || specifier == "microsecond" || specifier == "us" || specifier == "usec" || + specifier == "usecs" || specifier == "usecond" || specifier == "useconds") { + result = DatePartSpecifier::MICROSECONDS; + } else if (specifier == "milliseconds" || specifier == "millisecond" || specifier == "ms" || specifier == "msec" || + specifier == "msecs" || specifier == "msecond" || specifier == "mseconds") { + result = DatePartSpecifier::MILLISECONDS; + } else if (specifier == "second" || specifier == "sec" || specifier == "seconds" || specifier == "secs" || + specifier == "s") { + result = DatePartSpecifier::SECOND; + } else if (specifier == "minute" || specifier == "min" || specifier == "minutes" || specifier == "mins" || + specifier == "m") { + result = DatePartSpecifier::MINUTE; + } else if (specifier == "hour" || specifier == "hr" || specifier == "hours" || specifier == "hrs" || + specifier == "h") { + result = DatePartSpecifier::HOUR; + } else if (specifier == "epoch") { + // seconds since 1970-01-01 + result = DatePartSpecifier::EPOCH; + } else if (specifier == "dow" || specifier == "dayofweek" || specifier == "weekday") { + // day of the week (Sunday = 0, Saturday = 6) + result = DatePartSpecifier::DOW; + } else if (specifier == "isodow") { + // isodow (Monday = 1, Sunday = 7) + result = DatePartSpecifier::ISODOW; + } else if (specifier == "week" || specifier == "weeks" || specifier == "w" || specifier == "weekofyear") { + // ISO week number + result = DatePartSpecifier::WEEK; + } else if (specifier == "doy" || specifier == "dayofyear") { + // day of the year (1-365/366) + result = DatePartSpecifier::DOY; + } else if (specifier == "quarter" || specifier == "quarters") { + // quarter of the year (1-4) + result = DatePartSpecifier::QUARTER; + } else if (specifier == "yearweek") { + // Combined isoyear and isoweek YYYYWW + result = DatePartSpecifier::YEARWEEK; + } else if (specifier == "isoyear") { + // ISO year (first week of the year may be in previous year) + result = DatePartSpecifier::ISOYEAR; + } else if (specifier == "era") { + result = DatePartSpecifier::ERA; + } else if (specifier == "timezone") { + result = DatePartSpecifier::TIMEZONE; + } else if (specifier == "timezone_hour") { + result = DatePartSpecifier::TIMEZONE_HOUR; + } else if (specifier == "timezone_minute") { + result = DatePartSpecifier::TIMEZONE_MINUTE; + } else if (specifier == "julian" || specifier == "jd") { + result = DatePartSpecifier::JULIAN_DAY; + } else { + return false; + } + return true; +} + +DatePartSpecifier GetDatePartSpecifier(const string &specifier) { + DatePartSpecifier result; + if (!TryGetDatePartSpecifier(specifier, result)) { + throw ConversionException("extract specifier \"%s\" not recognized", specifier); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/expression_type.cpp b/src/duckdb/src/common/enums/expression_type.cpp new file mode 100644 index 00000000..f6755c3c --- /dev/null +++ b/src/duckdb/src/common/enums/expression_type.cpp @@ -0,0 +1,326 @@ +#include "duckdb/common/enums/expression_type.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +string ExpressionTypeToString(ExpressionType type) { + switch (type) { + case ExpressionType::OPERATOR_CAST: + return "CAST"; + case ExpressionType::OPERATOR_NOT: + return "NOT"; + case ExpressionType::OPERATOR_IS_NULL: + return "IS_NULL"; + case ExpressionType::OPERATOR_IS_NOT_NULL: + return "IS_NOT_NULL"; + case ExpressionType::COMPARE_EQUAL: + return "EQUAL"; + case ExpressionType::COMPARE_NOTEQUAL: + return "NOTEQUAL"; + case ExpressionType::COMPARE_LESSTHAN: + return "LESSTHAN"; + case ExpressionType::COMPARE_GREATERTHAN: + return "GREATERTHAN"; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return "LESSTHANOREQUALTO"; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return "GREATERTHANOREQUALTO"; + case ExpressionType::COMPARE_IN: + return "IN"; + case ExpressionType::COMPARE_DISTINCT_FROM: + return "DISTINCT_FROM"; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return "NOT_DISTINCT_FROM"; + case ExpressionType::CONJUNCTION_AND: + return "AND"; + case ExpressionType::CONJUNCTION_OR: + return "OR"; + case ExpressionType::VALUE_CONSTANT: + return "CONSTANT"; + case ExpressionType::VALUE_PARAMETER: + return "PARAMETER"; + case ExpressionType::VALUE_TUPLE: + return "TUPLE"; + case ExpressionType::VALUE_TUPLE_ADDRESS: + return "TUPLE_ADDRESS"; + case ExpressionType::VALUE_NULL: + return "NULL"; + case ExpressionType::VALUE_VECTOR: + return "VECTOR"; + case ExpressionType::VALUE_SCALAR: + return "SCALAR"; + case ExpressionType::AGGREGATE: + return "AGGREGATE"; + case ExpressionType::WINDOW_AGGREGATE: + return "WINDOW_AGGREGATE"; + case ExpressionType::WINDOW_RANK: + return "RANK"; + case ExpressionType::WINDOW_RANK_DENSE: + return "RANK_DENSE"; + case ExpressionType::WINDOW_PERCENT_RANK: + return "PERCENT_RANK"; + case ExpressionType::WINDOW_ROW_NUMBER: + return "ROW_NUMBER"; + case ExpressionType::WINDOW_FIRST_VALUE: + return "FIRST_VALUE"; + case ExpressionType::WINDOW_LAST_VALUE: + return "LAST_VALUE"; + case ExpressionType::WINDOW_NTH_VALUE: + return "NTH_VALUE"; + case ExpressionType::WINDOW_CUME_DIST: + return "CUME_DIST"; + case ExpressionType::WINDOW_LEAD: + return "LEAD"; + case ExpressionType::WINDOW_LAG: + return "LAG"; + case ExpressionType::WINDOW_NTILE: + return "NTILE"; + case ExpressionType::FUNCTION: + return "FUNCTION"; + case ExpressionType::CASE_EXPR: + return "CASE"; + case ExpressionType::OPERATOR_NULLIF: + return "NULLIF"; + case ExpressionType::OPERATOR_COALESCE: + return "COALESCE"; + case ExpressionType::ARRAY_EXTRACT: + return "ARRAY_EXTRACT"; + case ExpressionType::ARRAY_SLICE: + return "ARRAY_SLICE"; + case ExpressionType::STRUCT_EXTRACT: + return "STRUCT_EXTRACT"; + case ExpressionType::SUBQUERY: + return "SUBQUERY"; + case ExpressionType::STAR: + return "STAR"; + case ExpressionType::PLACEHOLDER: + return "PLACEHOLDER"; + case ExpressionType::COLUMN_REF: + return "COLUMN_REF"; + case ExpressionType::FUNCTION_REF: + return "FUNCTION_REF"; + case ExpressionType::TABLE_REF: + return "TABLE_REF"; + case ExpressionType::CAST: + return "CAST"; + case ExpressionType::COMPARE_NOT_IN: + return "COMPARE_NOT_IN"; + case ExpressionType::COMPARE_BETWEEN: + return "COMPARE_BETWEEN"; + case ExpressionType::COMPARE_NOT_BETWEEN: + return "COMPARE_NOT_BETWEEN"; + case ExpressionType::VALUE_DEFAULT: + return "VALUE_DEFAULT"; + case ExpressionType::BOUND_REF: + return "BOUND_REF"; + case ExpressionType::BOUND_COLUMN_REF: + return "BOUND_COLUMN_REF"; + case ExpressionType::BOUND_FUNCTION: + return "BOUND_FUNCTION"; + case ExpressionType::BOUND_AGGREGATE: + return "BOUND_AGGREGATE"; + case ExpressionType::GROUPING_FUNCTION: + return "GROUPING"; + case ExpressionType::ARRAY_CONSTRUCTOR: + return "ARRAY_CONSTRUCTOR"; + case ExpressionType::TABLE_STAR: + return "TABLE_STAR"; + case ExpressionType::BOUND_UNNEST: + return "BOUND_UNNEST"; + case ExpressionType::COLLATE: + return "COLLATE"; + case ExpressionType::POSITIONAL_REFERENCE: + return "POSITIONAL_REFERENCE"; + case ExpressionType::BOUND_LAMBDA_REF: + return "BOUND_LAMBDA_REF"; + case ExpressionType::LAMBDA: + return "LAMBDA"; + case ExpressionType::ARROW: + return "ARROW"; + case ExpressionType::INVALID: + break; + } + return "INVALID"; +} +string ExpressionClassToString(ExpressionClass type) { + switch (type) { + case ExpressionClass::INVALID: + return "INVALID"; + case ExpressionClass::AGGREGATE: + return "AGGREGATE"; + case ExpressionClass::CASE: + return "CASE"; + case ExpressionClass::CAST: + return "CAST"; + case ExpressionClass::COLUMN_REF: + return "COLUMN_REF"; + case ExpressionClass::COMPARISON: + return "COMPARISON"; + case ExpressionClass::CONJUNCTION: + return "CONJUNCTION"; + case ExpressionClass::CONSTANT: + return "CONSTANT"; + case ExpressionClass::DEFAULT: + return "DEFAULT"; + case ExpressionClass::FUNCTION: + return "FUNCTION"; + case ExpressionClass::OPERATOR: + return "OPERATOR"; + case ExpressionClass::STAR: + return "STAR"; + case ExpressionClass::SUBQUERY: + return "SUBQUERY"; + case ExpressionClass::WINDOW: + return "WINDOW"; + case ExpressionClass::PARAMETER: + return "PARAMETER"; + case ExpressionClass::COLLATE: + return "COLLATE"; + case ExpressionClass::LAMBDA: + return "LAMBDA"; + case ExpressionClass::POSITIONAL_REFERENCE: + return "POSITIONAL_REFERENCE"; + case ExpressionClass::BETWEEN: + return "BETWEEN"; + case ExpressionClass::BOUND_AGGREGATE: + return "BOUND_AGGREGATE"; + case ExpressionClass::BOUND_CASE: + return "BOUND_CASE"; + case ExpressionClass::BOUND_CAST: + return "BOUND_CAST"; + case ExpressionClass::BOUND_COLUMN_REF: + return "BOUND_COLUMN_REF"; + case ExpressionClass::BOUND_COMPARISON: + return "BOUND_COMPARISON"; + case ExpressionClass::BOUND_CONJUNCTION: + return "BOUND_CONJUNCTION"; + case ExpressionClass::BOUND_CONSTANT: + return "BOUND_CONSTANT"; + case ExpressionClass::BOUND_DEFAULT: + return "BOUND_DEFAULT"; + case ExpressionClass::BOUND_FUNCTION: + return "BOUND_FUNCTION"; + case ExpressionClass::BOUND_OPERATOR: + return "BOUND_OPERATOR"; + case ExpressionClass::BOUND_PARAMETER: + return "BOUND_PARAMETER"; + case ExpressionClass::BOUND_REF: + return "BOUND_REF"; + case ExpressionClass::BOUND_SUBQUERY: + return "BOUND_SUBQUERY"; + case ExpressionClass::BOUND_WINDOW: + return "BOUND_WINDOW"; + case ExpressionClass::BOUND_BETWEEN: + return "BOUND_BETWEEN"; + case ExpressionClass::BOUND_UNNEST: + return "BOUND_UNNEST"; + case ExpressionClass::BOUND_LAMBDA: + return "BOUND_LAMBDA"; + case ExpressionClass::BOUND_EXPRESSION: + return "BOUND_EXPRESSION"; + default: + return "ExpressionClass::!!UNIMPLEMENTED_CASE!!"; + } +} + +string ExpressionTypeToOperator(ExpressionType type) { + switch (type) { + case ExpressionType::COMPARE_EQUAL: + return "="; + case ExpressionType::COMPARE_NOTEQUAL: + return "!="; + case ExpressionType::COMPARE_LESSTHAN: + return "<"; + case ExpressionType::COMPARE_GREATERTHAN: + return ">"; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return "<="; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return ">="; + case ExpressionType::COMPARE_DISTINCT_FROM: + return "IS DISTINCT FROM"; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return "IS NOT DISTINCT FROM"; + case ExpressionType::CONJUNCTION_AND: + return "AND"; + case ExpressionType::CONJUNCTION_OR: + return "OR"; + default: + return ""; + } +} + +ExpressionType NegateComparisonExpression(ExpressionType type) { + ExpressionType negated_type = ExpressionType::INVALID; + switch (type) { + case ExpressionType::COMPARE_EQUAL: + negated_type = ExpressionType::COMPARE_NOTEQUAL; + break; + case ExpressionType::COMPARE_NOTEQUAL: + negated_type = ExpressionType::COMPARE_EQUAL; + break; + case ExpressionType::COMPARE_LESSTHAN: + negated_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + break; + case ExpressionType::COMPARE_GREATERTHAN: + negated_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + negated_type = ExpressionType::COMPARE_GREATERTHAN; + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + negated_type = ExpressionType::COMPARE_LESSTHAN; + break; + default: + throw InternalException("Unsupported comparison type in negation"); + } + return negated_type; +} + +ExpressionType FlipComparisonExpression(ExpressionType type) { + ExpressionType flipped_type = ExpressionType::INVALID; + switch (type) { + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + case ExpressionType::COMPARE_DISTINCT_FROM: + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_EQUAL: + flipped_type = type; + break; + case ExpressionType::COMPARE_LESSTHAN: + flipped_type = ExpressionType::COMPARE_GREATERTHAN; + break; + case ExpressionType::COMPARE_GREATERTHAN: + flipped_type = ExpressionType::COMPARE_LESSTHAN; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + flipped_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + flipped_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; + break; + default: + throw InternalException("Unsupported comparison type in flip"); + } + return flipped_type; +} + +ExpressionType OperatorToExpressionType(const string &op) { + if (op == "=" || op == "==") { + return ExpressionType::COMPARE_EQUAL; + } else if (op == "!=" || op == "<>") { + return ExpressionType::COMPARE_NOTEQUAL; + } else if (op == "<") { + return ExpressionType::COMPARE_LESSTHAN; + } else if (op == ">") { + return ExpressionType::COMPARE_GREATERTHAN; + } else if (op == "<=") { + return ExpressionType::COMPARE_LESSTHANOREQUALTO; + } else if (op == ">=") { + return ExpressionType::COMPARE_GREATERTHANOREQUALTO; + } + return ExpressionType::INVALID; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/file_compression_type.cpp b/src/duckdb/src/common/enums/file_compression_type.cpp new file mode 100644 index 00000000..c3c790d2 --- /dev/null +++ b/src/duckdb/src/common/enums/file_compression_type.cpp @@ -0,0 +1,21 @@ +#include "duckdb/common/enums/file_compression_type.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +FileCompressionType FileCompressionTypeFromString(const string &input) { + auto parameter = StringUtil::Lower(input); + if (parameter == "infer" || parameter == "auto") { + return FileCompressionType::AUTO_DETECT; + } else if (parameter == "gzip") { + return FileCompressionType::GZIP; + } else if (parameter == "zstd") { + return FileCompressionType::ZSTD; + } else if (parameter == "uncompressed" || parameter == "none" || parameter.empty()) { + return FileCompressionType::UNCOMPRESSED; + } else { + throw ParserException("Unrecognized file compression type \"%s\"", input); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/join_type.cpp b/src/duckdb/src/common/enums/join_type.cpp new file mode 100644 index 00000000..f9112794 --- /dev/null +++ b/src/duckdb/src/common/enums/join_type.cpp @@ -0,0 +1,19 @@ +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +bool IsLeftOuterJoin(JoinType type) { + return type == JoinType::LEFT || type == JoinType::OUTER; +} + +bool IsRightOuterJoin(JoinType type) { + return type == JoinType::OUTER || type == JoinType::RIGHT; +} + +// **DEPRECATED**: Use EnumUtil directly instead. +string JoinTypeToString(JoinType type) { + return EnumUtil::ToString(type); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/logical_operator_type.cpp b/src/duckdb/src/common/enums/logical_operator_type.cpp new file mode 100644 index 00000000..a941343a --- /dev/null +++ b/src/duckdb/src/common/enums/logical_operator_type.cpp @@ -0,0 +1,136 @@ +#include "duckdb/common/enums/logical_operator_type.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Value <--> String Utilities +//===--------------------------------------------------------------------===// +// LCOV_EXCL_START +string LogicalOperatorToString(LogicalOperatorType type) { + switch (type) { + case LogicalOperatorType::LOGICAL_GET: + return "GET"; + case LogicalOperatorType::LOGICAL_CHUNK_GET: + return "CHUNK_GET"; + case LogicalOperatorType::LOGICAL_DELIM_GET: + return "DELIM_GET"; + case LogicalOperatorType::LOGICAL_EMPTY_RESULT: + return "EMPTY_RESULT"; + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + return "EXPRESSION_GET"; + case LogicalOperatorType::LOGICAL_ANY_JOIN: + return "ANY_JOIN"; + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + return "ASOF_JOIN"; + case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + return "DEPENDENT_JOIN"; + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + return "COMPARISON_JOIN"; + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + return "DELIM_JOIN"; + case LogicalOperatorType::LOGICAL_PROJECTION: + return "PROJECTION"; + case LogicalOperatorType::LOGICAL_FILTER: + return "FILTER"; + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + return "AGGREGATE"; + case LogicalOperatorType::LOGICAL_WINDOW: + return "WINDOW"; + case LogicalOperatorType::LOGICAL_UNNEST: + return "UNNEST"; + case LogicalOperatorType::LOGICAL_LIMIT: + return "LIMIT"; + case LogicalOperatorType::LOGICAL_ORDER_BY: + return "ORDER_BY"; + case LogicalOperatorType::LOGICAL_TOP_N: + return "TOP_N"; + case LogicalOperatorType::LOGICAL_SAMPLE: + return "SAMPLE"; + case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: + return "LIMIT_PERCENT"; + case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + return "COPY_TO_FILE"; + case LogicalOperatorType::LOGICAL_JOIN: + return "JOIN"; + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + return "CROSS_PRODUCT"; + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + return "POSITIONAL_JOIN"; + case LogicalOperatorType::LOGICAL_UNION: + return "UNION"; + case LogicalOperatorType::LOGICAL_EXCEPT: + return "EXCEPT"; + case LogicalOperatorType::LOGICAL_INTERSECT: + return "INTERSECT"; + case LogicalOperatorType::LOGICAL_INSERT: + return "INSERT"; + case LogicalOperatorType::LOGICAL_DISTINCT: + return "DISTINCT"; + case LogicalOperatorType::LOGICAL_DELETE: + return "DELETE"; + case LogicalOperatorType::LOGICAL_UPDATE: + return "UPDATE"; + case LogicalOperatorType::LOGICAL_PREPARE: + return "PREPARE"; + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + return "DUMMY_SCAN"; + case LogicalOperatorType::LOGICAL_CREATE_INDEX: + return "CREATE_INDEX"; + case LogicalOperatorType::LOGICAL_CREATE_TABLE: + return "CREATE_TABLE"; + case LogicalOperatorType::LOGICAL_CREATE_MACRO: + return "CREATE_MACRO"; + case LogicalOperatorType::LOGICAL_EXPLAIN: + return "EXPLAIN"; + case LogicalOperatorType::LOGICAL_EXECUTE: + return "EXECUTE"; + case LogicalOperatorType::LOGICAL_VACUUM: + return "VACUUM"; + case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: + return "REC_CTE"; + case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: + return "CTE"; + case LogicalOperatorType::LOGICAL_CTE_REF: + return "CTE_SCAN"; + case LogicalOperatorType::LOGICAL_SHOW: + return "SHOW"; + case LogicalOperatorType::LOGICAL_ALTER: + return "ALTER"; + case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: + return "CREATE_SEQUENCE"; + case LogicalOperatorType::LOGICAL_CREATE_TYPE: + return "CREATE_TYPE"; + case LogicalOperatorType::LOGICAL_CREATE_VIEW: + return "CREATE_VIEW"; + case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: + return "CREATE_SCHEMA"; + case LogicalOperatorType::LOGICAL_ATTACH: + return "ATTACH"; + case LogicalOperatorType::LOGICAL_DETACH: + return "ATTACH"; + case LogicalOperatorType::LOGICAL_DROP: + return "DROP"; + case LogicalOperatorType::LOGICAL_PRAGMA: + return "PRAGMA"; + case LogicalOperatorType::LOGICAL_TRANSACTION: + return "TRANSACTION"; + case LogicalOperatorType::LOGICAL_EXPORT: + return "EXPORT"; + case LogicalOperatorType::LOGICAL_SET: + return "SET"; + case LogicalOperatorType::LOGICAL_RESET: + return "RESET"; + case LogicalOperatorType::LOGICAL_LOAD: + return "LOAD"; + case LogicalOperatorType::LOGICAL_INVALID: + break; + case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: + return "CUSTOM_OP"; + case LogicalOperatorType::LOGICAL_PIVOT: + return "PIVOT"; + } + return "INVALID"; +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp new file mode 100644 index 00000000..480e9eb5 --- /dev/null +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -0,0 +1,58 @@ +#include "duckdb/common/enums/optimizer_type.hpp" +#include "duckdb/common/string_util.hpp" + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +struct DefaultOptimizerType { + const char *name; + OptimizerType type; +}; + +static DefaultOptimizerType internal_optimizer_types[] = { + {"expression_rewriter", OptimizerType::EXPRESSION_REWRITER}, + {"filter_pullup", OptimizerType::FILTER_PULLUP}, + {"filter_pushdown", OptimizerType::FILTER_PUSHDOWN}, + {"regex_range", OptimizerType::REGEX_RANGE}, + {"in_clause", OptimizerType::IN_CLAUSE}, + {"join_order", OptimizerType::JOIN_ORDER}, + {"deliminator", OptimizerType::DELIMINATOR}, + {"unnest_rewriter", OptimizerType::UNNEST_REWRITER}, + {"unused_columns", OptimizerType::UNUSED_COLUMNS}, + {"statistics_propagation", OptimizerType::STATISTICS_PROPAGATION}, + {"common_subexpressions", OptimizerType::COMMON_SUBEXPRESSIONS}, + {"common_aggregate", OptimizerType::COMMON_AGGREGATE}, + {"column_lifetime", OptimizerType::COLUMN_LIFETIME}, + {"top_n", OptimizerType::TOP_N}, + {"compressed_materialization", OptimizerType::COMPRESSED_MATERIALIZATION}, + {"duplicate_groups", OptimizerType::DUPLICATE_GROUPS}, + {"reorder_filter", OptimizerType::REORDER_FILTER}, + {"extension", OptimizerType::EXTENSION}, + {nullptr, OptimizerType::INVALID}}; + +string OptimizerTypeToString(OptimizerType type) { + for (idx_t i = 0; internal_optimizer_types[i].name; i++) { + if (internal_optimizer_types[i].type == type) { + return internal_optimizer_types[i].name; + } + } + throw InternalException("Invalid optimizer type"); +} + +OptimizerType OptimizerTypeFromString(const string &str) { + for (idx_t i = 0; internal_optimizer_types[i].name; i++) { + if (internal_optimizer_types[i].name == str) { + return internal_optimizer_types[i].type; + } + } + // optimizer not found, construct candidate list + vector optimizer_names; + for (idx_t i = 0; internal_optimizer_types[i].name; i++) { + optimizer_names.emplace_back(internal_optimizer_types[i].name); + } + throw ParserException("Optimizer type \"%s\" not recognized\n%s", str, + StringUtil::CandidatesErrorMessage(optimizer_names, str, "Candidate optimizers")); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/physical_operator_type.cpp b/src/duckdb/src/common/enums/physical_operator_type.cpp new file mode 100644 index 00000000..e07942d1 --- /dev/null +++ b/src/duckdb/src/common/enums/physical_operator_type.cpp @@ -0,0 +1,161 @@ +#include "duckdb/common/enums/physical_operator_type.hpp" + +namespace duckdb { + +// LCOV_EXCL_START +string PhysicalOperatorToString(PhysicalOperatorType type) { + switch (type) { + case PhysicalOperatorType::TABLE_SCAN: + return "TABLE_SCAN"; + case PhysicalOperatorType::DUMMY_SCAN: + return "DUMMY_SCAN"; + case PhysicalOperatorType::CHUNK_SCAN: + return "CHUNK_SCAN"; + case PhysicalOperatorType::COLUMN_DATA_SCAN: + return "COLUMN_DATA_SCAN"; + case PhysicalOperatorType::DELIM_SCAN: + return "DELIM_SCAN"; + case PhysicalOperatorType::ORDER_BY: + return "ORDER_BY"; + case PhysicalOperatorType::LIMIT: + return "LIMIT"; + case PhysicalOperatorType::LIMIT_PERCENT: + return "LIMIT_PERCENT"; + case PhysicalOperatorType::STREAMING_LIMIT: + return "STREAMING_LIMIT"; + case PhysicalOperatorType::RESERVOIR_SAMPLE: + return "RESERVOIR_SAMPLE"; + case PhysicalOperatorType::STREAMING_SAMPLE: + return "STREAMING_SAMPLE"; + case PhysicalOperatorType::TOP_N: + return "TOP_N"; + case PhysicalOperatorType::WINDOW: + return "WINDOW"; + case PhysicalOperatorType::STREAMING_WINDOW: + return "STREAMING_WINDOW"; + case PhysicalOperatorType::UNNEST: + return "UNNEST"; + case PhysicalOperatorType::UNGROUPED_AGGREGATE: + return "UNGROUPED_AGGREGATE"; + case PhysicalOperatorType::HASH_GROUP_BY: + return "HASH_GROUP_BY"; + case PhysicalOperatorType::PERFECT_HASH_GROUP_BY: + return "PERFECT_HASH_GROUP_BY"; + case PhysicalOperatorType::FILTER: + return "FILTER"; + case PhysicalOperatorType::PROJECTION: + return "PROJECTION"; + case PhysicalOperatorType::COPY_TO_FILE: + return "COPY_TO_FILE"; + case PhysicalOperatorType::BATCH_COPY_TO_FILE: + return "BATCH_COPY_TO_FILE"; + case PhysicalOperatorType::FIXED_BATCH_COPY_TO_FILE: + return "FIXED_BATCH_COPY_TO_FILE"; + case PhysicalOperatorType::DELIM_JOIN: + return "DELIM_JOIN"; + case PhysicalOperatorType::BLOCKWISE_NL_JOIN: + return "BLOCKWISE_NL_JOIN"; + case PhysicalOperatorType::NESTED_LOOP_JOIN: + return "NESTED_LOOP_JOIN"; + case PhysicalOperatorType::HASH_JOIN: + return "HASH_JOIN"; + case PhysicalOperatorType::INDEX_JOIN: + return "INDEX_JOIN"; + case PhysicalOperatorType::PIECEWISE_MERGE_JOIN: + return "PIECEWISE_MERGE_JOIN"; + case PhysicalOperatorType::IE_JOIN: + return "IE_JOIN"; + case PhysicalOperatorType::ASOF_JOIN: + return "ASOF_JOIN"; + case PhysicalOperatorType::CROSS_PRODUCT: + return "CROSS_PRODUCT"; + case PhysicalOperatorType::POSITIONAL_JOIN: + return "POSITIONAL_JOIN"; + case PhysicalOperatorType::POSITIONAL_SCAN: + return "POSITIONAL_SCAN"; + case PhysicalOperatorType::UNION: + return "UNION"; + case PhysicalOperatorType::INSERT: + return "INSERT"; + case PhysicalOperatorType::BATCH_INSERT: + return "BATCH_INSERT"; + case PhysicalOperatorType::DELETE_OPERATOR: + return "DELETE"; + case PhysicalOperatorType::UPDATE: + return "UPDATE"; + case PhysicalOperatorType::EMPTY_RESULT: + return "EMPTY_RESULT"; + case PhysicalOperatorType::CREATE_TABLE: + return "CREATE_TABLE"; + case PhysicalOperatorType::CREATE_TABLE_AS: + return "CREATE_TABLE_AS"; + case PhysicalOperatorType::BATCH_CREATE_TABLE_AS: + return "BATCH_CREATE_TABLE_AS"; + case PhysicalOperatorType::CREATE_INDEX: + return "CREATE_INDEX"; + case PhysicalOperatorType::EXPLAIN: + return "EXPLAIN"; + case PhysicalOperatorType::EXPLAIN_ANALYZE: + return "EXPLAIN_ANALYZE"; + case PhysicalOperatorType::EXECUTE: + return "EXECUTE"; + case PhysicalOperatorType::VACUUM: + return "VACUUM"; + case PhysicalOperatorType::RECURSIVE_CTE: + return "REC_CTE"; + case PhysicalOperatorType::CTE: + return "CTE"; + case PhysicalOperatorType::RECURSIVE_CTE_SCAN: + return "REC_CTE_SCAN"; + case PhysicalOperatorType::CTE_SCAN: + return "CTE_SCAN"; + case PhysicalOperatorType::EXPRESSION_SCAN: + return "EXPRESSION_SCAN"; + case PhysicalOperatorType::ALTER: + return "ALTER"; + case PhysicalOperatorType::CREATE_SEQUENCE: + return "CREATE_SEQUENCE"; + case PhysicalOperatorType::CREATE_VIEW: + return "CREATE_VIEW"; + case PhysicalOperatorType::CREATE_SCHEMA: + return "CREATE_SCHEMA"; + case PhysicalOperatorType::CREATE_MACRO: + return "CREATE_MACRO"; + case PhysicalOperatorType::DROP: + return "DROP"; + case PhysicalOperatorType::PRAGMA: + return "PRAGMA"; + case PhysicalOperatorType::TRANSACTION: + return "TRANSACTION"; + case PhysicalOperatorType::PREPARE: + return "PREPARE"; + case PhysicalOperatorType::EXPORT: + return "EXPORT"; + case PhysicalOperatorType::SET: + return "SET"; + case PhysicalOperatorType::RESET: + return "RESET"; + case PhysicalOperatorType::LOAD: + return "LOAD"; + case PhysicalOperatorType::INOUT_FUNCTION: + return "INOUT_FUNCTION"; + case PhysicalOperatorType::CREATE_TYPE: + return "CREATE_TYPE"; + case PhysicalOperatorType::ATTACH: + return "ATTACH"; + case PhysicalOperatorType::DETACH: + return "DETACH"; + case PhysicalOperatorType::RESULT_COLLECTOR: + return "RESULT_COLLECTOR"; + case PhysicalOperatorType::EXTENSION: + return "EXTENSION"; + case PhysicalOperatorType::PIVOT: + return "PIVOT"; + case PhysicalOperatorType::INVALID: + break; + } + return "INVALID"; +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/relation_type.cpp b/src/duckdb/src/common/enums/relation_type.cpp new file mode 100644 index 00000000..3c55f95a --- /dev/null +++ b/src/duckdb/src/common/enums/relation_type.cpp @@ -0,0 +1,65 @@ +#include "duckdb/common/enums/relation_type.hpp" + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +// LCOV_EXCL_START +string RelationTypeToString(RelationType type) { + switch (type) { + case RelationType::TABLE_RELATION: + return "TABLE_RELATION"; + case RelationType::PROJECTION_RELATION: + return "PROJECTION_RELATION"; + case RelationType::FILTER_RELATION: + return "FILTER_RELATION"; + case RelationType::EXPLAIN_RELATION: + return "EXPLAIN_RELATION"; + case RelationType::CROSS_PRODUCT_RELATION: + return "CROSS_PRODUCT_RELATION"; + case RelationType::JOIN_RELATION: + return "JOIN_RELATION"; + case RelationType::AGGREGATE_RELATION: + return "AGGREGATE_RELATION"; + case RelationType::SET_OPERATION_RELATION: + return "SET_OPERATION_RELATION"; + case RelationType::DISTINCT_RELATION: + return "DISTINCT_RELATION"; + case RelationType::LIMIT_RELATION: + return "LIMIT_RELATION"; + case RelationType::ORDER_RELATION: + return "ORDER_RELATION"; + case RelationType::CREATE_VIEW_RELATION: + return "CREATE_VIEW_RELATION"; + case RelationType::CREATE_TABLE_RELATION: + return "CREATE_TABLE_RELATION"; + case RelationType::INSERT_RELATION: + return "INSERT_RELATION"; + case RelationType::VALUE_LIST_RELATION: + return "VALUE_LIST_RELATION"; + case RelationType::DELETE_RELATION: + return "DELETE_RELATION"; + case RelationType::UPDATE_RELATION: + return "UPDATE_RELATION"; + case RelationType::WRITE_CSV_RELATION: + return "WRITE_CSV_RELATION"; + case RelationType::WRITE_PARQUET_RELATION: + return "WRITE_PARQUET_RELATION"; + case RelationType::READ_CSV_RELATION: + return "READ_CSV_RELATION"; + case RelationType::SUBQUERY_RELATION: + return "SUBQUERY_RELATION"; + case RelationType::TABLE_FUNCTION_RELATION: + return "TABLE_FUNCTION_RELATION"; + case RelationType::VIEW_RELATION: + return "VIEW_RELATION"; + case RelationType::QUERY_RELATION: + return "QUERY_RELATION"; + case RelationType::INVALID_RELATION: + break; + } + return "INVALID_RELATION"; +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/statement_type.cpp b/src/duckdb/src/common/enums/statement_type.cpp new file mode 100644 index 00000000..c3a48658 --- /dev/null +++ b/src/duckdb/src/common/enums/statement_type.cpp @@ -0,0 +1,83 @@ +#include "duckdb/common/enums/statement_type.hpp" + +namespace duckdb { + +// LCOV_EXCL_START +string StatementTypeToString(StatementType type) { + switch (type) { + case StatementType::SELECT_STATEMENT: + return "SELECT"; + case StatementType::INSERT_STATEMENT: + return "INSERT"; + case StatementType::UPDATE_STATEMENT: + return "UPDATE"; + case StatementType::DELETE_STATEMENT: + return "DELETE"; + case StatementType::PREPARE_STATEMENT: + return "PREPARE"; + case StatementType::EXECUTE_STATEMENT: + return "EXECUTE"; + case StatementType::ALTER_STATEMENT: + return "ALTER"; + case StatementType::TRANSACTION_STATEMENT: + return "TRANSACTION"; + case StatementType::COPY_STATEMENT: + return "COPY"; + case StatementType::ANALYZE_STATEMENT: + return "ANALYZE"; + case StatementType::VARIABLE_SET_STATEMENT: + return "VARIABLE_SET"; + case StatementType::CREATE_FUNC_STATEMENT: + return "CREATE_FUNC"; + case StatementType::EXPLAIN_STATEMENT: + return "EXPLAIN"; + case StatementType::CREATE_STATEMENT: + return "CREATE"; + case StatementType::DROP_STATEMENT: + return "DROP"; + case StatementType::PRAGMA_STATEMENT: + return "PRAGMA"; + case StatementType::SHOW_STATEMENT: + return "SHOW"; + case StatementType::VACUUM_STATEMENT: + return "VACUUM"; + case StatementType::RELATION_STATEMENT: + return "RELATION"; + case StatementType::EXPORT_STATEMENT: + return "EXPORT"; + case StatementType::CALL_STATEMENT: + return "CALL"; + case StatementType::SET_STATEMENT: + return "SET"; + case StatementType::LOAD_STATEMENT: + return "LOAD"; + case StatementType::EXTENSION_STATEMENT: + return "EXTENSION"; + case StatementType::LOGICAL_PLAN_STATEMENT: + return "LOGICAL_PLAN"; + case StatementType::ATTACH_STATEMENT: + return "ATTACH"; + case StatementType::DETACH_STATEMENT: + return "DETACH"; + case StatementType::MULTI_STATEMENT: + return "MULTI"; + case StatementType::INVALID_STATEMENT: + break; + } + return "INVALID"; +} + +string StatementReturnTypeToString(StatementReturnType type) { + switch (type) { + case StatementReturnType::QUERY_RESULT: + return "QUERY_RESULT"; + case StatementReturnType::CHANGED_ROWS: + return "CHANGED_ROWS"; + case StatementReturnType::NOTHING: + return "NOTHING"; + } + return "INVALID"; +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/exception.cpp b/src/duckdb/src/common/exception.cpp new file mode 100644 index 00000000..68e761ea --- /dev/null +++ b/src/duckdb/src/common/exception.cpp @@ -0,0 +1,396 @@ +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types.hpp" + +#ifdef DUCKDB_CRASH_ON_ASSERT +#include "duckdb/common/printer.hpp" +#include +#include +#endif +#ifdef DUCKDB_DEBUG_STACKTRACE +#include +#endif + +namespace duckdb { + +Exception::Exception(const string &msg) : std::exception(), type(ExceptionType::INVALID), raw_message_(msg) { + exception_message_ = msg; +} + +Exception::Exception(ExceptionType exception_type, const string &message) + : std::exception(), type(exception_type), raw_message_(message) { + exception_message_ = ExceptionTypeToString(exception_type) + " Error: " + message; +} + +const char *Exception::what() const noexcept { + return exception_message_.c_str(); +} + +const string &Exception::RawMessage() const { + return raw_message_; +} + +bool Exception::UncaughtException() { +#if __cplusplus >= 201703L + return std::uncaught_exceptions() > 0; +#else + return std::uncaught_exception(); +#endif +} + +string Exception::GetStackTrace(int max_depth) { +#ifdef DUCKDB_DEBUG_STACKTRACE + string result; + auto callstack = unique_ptr(new void *[max_depth]); + int frames = backtrace(callstack.get(), max_depth); + char **strs = backtrace_symbols(callstack.get(), frames); + for (int i = 0; i < frames; i++) { + result += strs[i]; + result += "\n"; + } + free(strs); + return "\n" + result; +#else + // Stack trace not available. Toggle DUCKDB_DEBUG_STACKTRACE in exception.cpp to enable stack traces. + return ""; +#endif +} + +string Exception::ConstructMessageRecursive(const string &msg, std::vector &values) { +#ifdef DEBUG + // Verify that we have the required amount of values for the message + idx_t parameter_count = 0; + for (idx_t i = 0; i + 1 < msg.size(); i++) { + if (msg[i] != '%') { + continue; + } + if (msg[i + 1] == '%') { + i++; + continue; + } + parameter_count++; + } + if (parameter_count != values.size()) { + throw InternalException("Primary exception: %s\nSecondary exception in ConstructMessageRecursive: Expected %d " + "parameters, received %d", + msg.c_str(), parameter_count, values.size()); + } + +#endif + return ExceptionFormatValue::Format(msg, values); +} + +string Exception::ExceptionTypeToString(ExceptionType type) { + switch (type) { + case ExceptionType::INVALID: + return "Invalid"; + case ExceptionType::OUT_OF_RANGE: + return "Out of Range"; + case ExceptionType::CONVERSION: + return "Conversion"; + case ExceptionType::UNKNOWN_TYPE: + return "Unknown Type"; + case ExceptionType::DECIMAL: + return "Decimal"; + case ExceptionType::MISMATCH_TYPE: + return "Mismatch Type"; + case ExceptionType::DIVIDE_BY_ZERO: + return "Divide by Zero"; + case ExceptionType::OBJECT_SIZE: + return "Object Size"; + case ExceptionType::INVALID_TYPE: + return "Invalid type"; + case ExceptionType::SERIALIZATION: + return "Serialization"; + case ExceptionType::TRANSACTION: + return "TransactionContext"; + case ExceptionType::NOT_IMPLEMENTED: + return "Not implemented"; + case ExceptionType::EXPRESSION: + return "Expression"; + case ExceptionType::CATALOG: + return "Catalog"; + case ExceptionType::PARSER: + return "Parser"; + case ExceptionType::BINDER: + return "Binder"; + case ExceptionType::PLANNER: + return "Planner"; + case ExceptionType::SCHEDULER: + return "Scheduler"; + case ExceptionType::EXECUTOR: + return "Executor"; + case ExceptionType::CONSTRAINT: + return "Constraint"; + case ExceptionType::INDEX: + return "Index"; + case ExceptionType::STAT: + return "Stat"; + case ExceptionType::CONNECTION: + return "Connection"; + case ExceptionType::SYNTAX: + return "Syntax"; + case ExceptionType::SETTINGS: + return "Settings"; + case ExceptionType::OPTIMIZER: + return "Optimizer"; + case ExceptionType::NULL_POINTER: + return "NullPointer"; + case ExceptionType::IO: + return "IO"; + case ExceptionType::INTERRUPT: + return "INTERRUPT"; + case ExceptionType::FATAL: + return "FATAL"; + case ExceptionType::INTERNAL: + return "INTERNAL"; + case ExceptionType::INVALID_INPUT: + return "Invalid Input"; + case ExceptionType::OUT_OF_MEMORY: + return "Out of Memory"; + case ExceptionType::PERMISSION: + return "Permission"; + case ExceptionType::PARAMETER_NOT_RESOLVED: + return "Parameter Not Resolved"; + case ExceptionType::PARAMETER_NOT_ALLOWED: + return "Parameter Not Allowed"; + case ExceptionType::DEPENDENCY: + return "Dependency"; + case ExceptionType::MISSING_EXTENSION: + return "Missing Extension"; + case ExceptionType::HTTP: + return "HTTP"; + case ExceptionType::AUTOLOAD: + return "Extension Autoloading"; + default: + return "Unknown"; + } +} + +const HTTPException &Exception::AsHTTPException() const { + D_ASSERT(type == ExceptionType::HTTP); + const auto &e = static_cast(this); + D_ASSERT(e->GetStatusCode() != 0); + D_ASSERT(e->GetHeaders().size() > 0); + return *e; +} + +void Exception::ThrowAsTypeWithMessage(ExceptionType type, const string &message, + const std::shared_ptr &original) { + switch (type) { + case ExceptionType::OUT_OF_RANGE: + throw OutOfRangeException(message); + case ExceptionType::CONVERSION: + throw ConversionException(message); // FIXME: make a separation between Conversion/Cast exception? + case ExceptionType::INVALID_TYPE: + throw InvalidTypeException(message); + case ExceptionType::MISMATCH_TYPE: + throw TypeMismatchException(message); + case ExceptionType::TRANSACTION: + throw TransactionException(message); + case ExceptionType::NOT_IMPLEMENTED: + throw NotImplementedException(message); + case ExceptionType::CATALOG: + throw CatalogException(message); + case ExceptionType::CONNECTION: + throw ConnectionException(message); + case ExceptionType::PARSER: + throw ParserException(message); + case ExceptionType::PERMISSION: + throw PermissionException(message); + case ExceptionType::SYNTAX: + throw SyntaxException(message); + case ExceptionType::CONSTRAINT: + throw ConstraintException(message); + case ExceptionType::BINDER: + throw BinderException(message); + case ExceptionType::IO: + throw IOException(message); + case ExceptionType::SERIALIZATION: + throw SerializationException(message); + case ExceptionType::INTERRUPT: + throw InterruptException(); + case ExceptionType::INTERNAL: + throw InternalException(message); + case ExceptionType::INVALID_INPUT: + throw InvalidInputException(message); + case ExceptionType::OUT_OF_MEMORY: + throw OutOfMemoryException(message); + case ExceptionType::PARAMETER_NOT_ALLOWED: + throw ParameterNotAllowedException(message); + case ExceptionType::PARAMETER_NOT_RESOLVED: + throw ParameterNotResolvedException(); + case ExceptionType::FATAL: + throw FatalException(message); + case ExceptionType::DEPENDENCY: + throw DependencyException(message); + case ExceptionType::HTTP: { + original->AsHTTPException().Throw(); + } + case ExceptionType::MISSING_EXTENSION: + throw MissingExtensionException(message); + default: + throw Exception(type, message); + } +} + +StandardException::StandardException(ExceptionType exception_type, const string &message) + : Exception(exception_type, message) { +} + +CastException::CastException(const PhysicalType orig_type, const PhysicalType new_type) + : Exception(ExceptionType::CONVERSION, + "Type " + TypeIdToString(orig_type) + " can't be cast as " + TypeIdToString(new_type)) { +} + +CastException::CastException(const LogicalType &orig_type, const LogicalType &new_type) + : Exception(ExceptionType::CONVERSION, + "Type " + orig_type.ToString() + " can't be cast as " + new_type.ToString()) { +} + +CastException::CastException(const string &msg) : Exception(ExceptionType::CONVERSION, msg) { +} + +ValueOutOfRangeException::ValueOutOfRangeException(const int64_t value, const PhysicalType orig_type, + const PhysicalType new_type) + : Exception(ExceptionType::CONVERSION, "Type " + TypeIdToString(orig_type) + " with value " + + to_string((intmax_t)value) + + " can't be cast because the value is out of range " + "for the destination type " + + TypeIdToString(new_type)) { +} + +ValueOutOfRangeException::ValueOutOfRangeException(const double value, const PhysicalType orig_type, + const PhysicalType new_type) + : Exception(ExceptionType::CONVERSION, "Type " + TypeIdToString(orig_type) + " with value " + to_string(value) + + " can't be cast because the value is out of range " + "for the destination type " + + TypeIdToString(new_type)) { +} + +ValueOutOfRangeException::ValueOutOfRangeException(const hugeint_t value, const PhysicalType orig_type, + const PhysicalType new_type) + : Exception(ExceptionType::CONVERSION, "Type " + TypeIdToString(orig_type) + " with value " + value.ToString() + + " can't be cast because the value is out of range " + "for the destination type " + + TypeIdToString(new_type)) { +} + +ValueOutOfRangeException::ValueOutOfRangeException(const PhysicalType var_type, const idx_t length) + : Exception(ExceptionType::OUT_OF_RANGE, + "The value is too long to fit into type " + TypeIdToString(var_type) + "(" + to_string(length) + ")") { +} + +ValueOutOfRangeException::ValueOutOfRangeException(const string &msg) : Exception(ExceptionType::OUT_OF_RANGE, msg) { +} + +ConversionException::ConversionException(const string &msg) : Exception(ExceptionType::CONVERSION, msg) { +} + +InvalidTypeException::InvalidTypeException(PhysicalType type, const string &msg) + : Exception(ExceptionType::INVALID_TYPE, "Invalid Type [" + TypeIdToString(type) + "]: " + msg) { +} + +InvalidTypeException::InvalidTypeException(const LogicalType &type, const string &msg) + : Exception(ExceptionType::INVALID_TYPE, "Invalid Type [" + type.ToString() + "]: " + msg) { +} + +InvalidTypeException::InvalidTypeException(const string &msg) : Exception(ExceptionType::INVALID_TYPE, msg) { +} + +TypeMismatchException::TypeMismatchException(const PhysicalType type_1, const PhysicalType type_2, const string &msg) + : Exception(ExceptionType::MISMATCH_TYPE, + "Type " + TypeIdToString(type_1) + " does not match with " + TypeIdToString(type_2) + ". " + msg) { +} + +TypeMismatchException::TypeMismatchException(const LogicalType &type_1, const LogicalType &type_2, const string &msg) + : Exception(ExceptionType::MISMATCH_TYPE, + "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". " + msg) { +} + +TypeMismatchException::TypeMismatchException(const string &msg) : Exception(ExceptionType::MISMATCH_TYPE, msg) { +} + +TransactionException::TransactionException(const string &msg) : Exception(ExceptionType::TRANSACTION, msg) { +} + +NotImplementedException::NotImplementedException(const string &msg) : Exception(ExceptionType::NOT_IMPLEMENTED, msg) { +} + +OutOfRangeException::OutOfRangeException(const string &msg) : Exception(ExceptionType::OUT_OF_RANGE, msg) { +} + +CatalogException::CatalogException(const string &msg) : StandardException(ExceptionType::CATALOG, msg) { +} + +ConnectionException::ConnectionException(const string &msg) : StandardException(ExceptionType::CONNECTION, msg) { +} + +ParserException::ParserException(const string &msg) : StandardException(ExceptionType::PARSER, msg) { +} + +PermissionException::PermissionException(const string &msg) : StandardException(ExceptionType::PERMISSION, msg) { +} + +SyntaxException::SyntaxException(const string &msg) : Exception(ExceptionType::SYNTAX, msg) { +} + +ConstraintException::ConstraintException(const string &msg) : Exception(ExceptionType::CONSTRAINT, msg) { +} + +DependencyException::DependencyException(const string &msg) : Exception(ExceptionType::DEPENDENCY, msg) { +} + +BinderException::BinderException(const string &msg) : StandardException(ExceptionType::BINDER, msg) { +} + +IOException::IOException(const string &msg) : Exception(ExceptionType::IO, msg) { +} + +MissingExtensionException::MissingExtensionException(const string &msg) + : Exception(ExceptionType::MISSING_EXTENSION, msg) { +} + +AutoloadException::AutoloadException(const string &extension_name, Exception &e) + : Exception(ExceptionType::AUTOLOAD, + "An error occurred while trying to automatically install the required extension '" + extension_name + + "':\n" + e.RawMessage()), + wrapped_exception(e) { +} + +SerializationException::SerializationException(const string &msg) : Exception(ExceptionType::SERIALIZATION, msg) { +} + +SequenceException::SequenceException(const string &msg) : Exception(ExceptionType::SERIALIZATION, msg) { +} + +InterruptException::InterruptException() : Exception(ExceptionType::INTERRUPT, "Interrupted!") { +} + +FatalException::FatalException(ExceptionType type, const string &msg) : Exception(type, msg) { +} + +InternalException::InternalException(const string &msg) : FatalException(ExceptionType::INTERNAL, msg) { +#ifdef DUCKDB_CRASH_ON_ASSERT + Printer::Print("ABORT THROWN BY INTERNAL EXCEPTION: " + msg); + abort(); +#endif +} + +InvalidInputException::InvalidInputException(const string &msg) : Exception(ExceptionType::INVALID_INPUT, msg) { +} + +OutOfMemoryException::OutOfMemoryException(const string &msg) : Exception(ExceptionType::OUT_OF_MEMORY, msg) { +} + +ParameterNotAllowedException::ParameterNotAllowedException(const string &msg) + : StandardException(ExceptionType::PARAMETER_NOT_ALLOWED, msg) { +} + +ParameterNotResolvedException::ParameterNotResolvedException() + : Exception(ExceptionType::PARAMETER_NOT_RESOLVED, "Parameter types could not be resolved") { +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/exception_format_value.cpp b/src/duckdb/src/common/exception_format_value.cpp new file mode 100644 index 00000000..5a433030 --- /dev/null +++ b/src/duckdb/src/common/exception_format_value.cpp @@ -0,0 +1,98 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types.hpp" +#include "fmt/format.h" +#include "fmt/printf.h" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/parser/keyword_helper.hpp" + +namespace duckdb { + +ExceptionFormatValue::ExceptionFormatValue(double dbl_val) + : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE), dbl_val(dbl_val) { +} +ExceptionFormatValue::ExceptionFormatValue(int64_t int_val) + : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER), int_val(int_val) { +} +ExceptionFormatValue::ExceptionFormatValue(hugeint_t huge_val) + : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(Hugeint::ToString(huge_val)) { +} +ExceptionFormatValue::ExceptionFormatValue(string str_val) + : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(std::move(str_val)) { +} + +template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value) { + return ExceptionFormatValue(TypeIdToString(value)); +} +template <> +ExceptionFormatValue +ExceptionFormatValue::CreateFormatValue(LogicalType value) { // NOLINT: templating requires us to copy value here + return ExceptionFormatValue(value.ToString()); +} +template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value) { + return ExceptionFormatValue(double(value)); +} +template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value) { + return ExceptionFormatValue(double(value)); +} +template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value) { + return ExceptionFormatValue(std::move(value)); +} + +template <> +ExceptionFormatValue +ExceptionFormatValue::CreateFormatValue(SQLString value) { // NOLINT: templating requires us to copy value here + return KeywordHelper::WriteQuoted(value.raw_string, '\''); +} + +template <> +ExceptionFormatValue +ExceptionFormatValue::CreateFormatValue(SQLIdentifier value) { // NOLINT: templating requires us to copy value here + return KeywordHelper::WriteOptionallyQuoted(value.raw_string, '"'); +} + +template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value) { + return ExceptionFormatValue(string(value)); +} +template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value) { + return ExceptionFormatValue(string(value)); +} +template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value) { + return ExceptionFormatValue(value); +} + +string ExceptionFormatValue::Format(const string &msg, std::vector &values) { + try { + std::vector> format_args; + for (auto &val : values) { + switch (val.type) { + case ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE: + format_args.push_back(duckdb_fmt::internal::make_arg(val.dbl_val)); + break; + case ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER: + format_args.push_back(duckdb_fmt::internal::make_arg(val.int_val)); + break; + case ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING: + format_args.push_back(duckdb_fmt::internal::make_arg(val.str_val)); + break; + } + } + return duckdb_fmt::vsprintf(msg, duckdb_fmt::basic_format_args( + format_args.data(), static_cast(format_args.size()))); + } catch (std::exception &ex) { // LCOV_EXCL_START + // work-around for oss-fuzz limiting memory which causes issues here + if (StringUtil::Contains(ex.what(), "fuzz mode")) { + throw Exception(msg); + } + throw InternalException(std::string("Primary exception: ") + msg + + "\nSecondary exception in ExceptionFormatValue: " + ex.what()); + } // LCOV_EXCL_STOP +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp new file mode 100644 index 00000000..235b5b7c --- /dev/null +++ b/src/duckdb/src/common/extra_type_info.cpp @@ -0,0 +1,315 @@ +#include "duckdb/common/extra_type_info.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/string_map_set.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Extra Type Info +//===--------------------------------------------------------------------===// +ExtraTypeInfo::ExtraTypeInfo(ExtraTypeInfoType type) : type(type) { +} +ExtraTypeInfo::ExtraTypeInfo(ExtraTypeInfoType type, string alias) : type(type), alias(std::move(alias)) { +} +ExtraTypeInfo::~ExtraTypeInfo() { +} + +bool ExtraTypeInfo::Equals(ExtraTypeInfo *other_p) const { + if (type == ExtraTypeInfoType::INVALID_TYPE_INFO || type == ExtraTypeInfoType::STRING_TYPE_INFO || + type == ExtraTypeInfoType::GENERIC_TYPE_INFO) { + if (!other_p) { + if (!alias.empty()) { + return false; + } + //! We only need to compare aliases when both types have them in this case + return true; + } + if (alias != other_p->alias) { + return false; + } + return true; + } + if (!other_p) { + return false; + } + if (type != other_p->type) { + return false; + } + return alias == other_p->alias && EqualsInternal(other_p); +} + +bool ExtraTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + // Do nothing + return true; +} + +//===--------------------------------------------------------------------===// +// Decimal Type Info +//===--------------------------------------------------------------------===// +DecimalTypeInfo::DecimalTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::DECIMAL_TYPE_INFO) { +} + +DecimalTypeInfo::DecimalTypeInfo(uint8_t width_p, uint8_t scale_p) + : ExtraTypeInfo(ExtraTypeInfoType::DECIMAL_TYPE_INFO), width(width_p), scale(scale_p) { + D_ASSERT(width_p >= scale_p); +} + +bool DecimalTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + auto &other = other_p->Cast(); + return width == other.width && scale == other.scale; +} + +//===--------------------------------------------------------------------===// +// String Type Info +//===--------------------------------------------------------------------===// +StringTypeInfo::StringTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::STRING_TYPE_INFO) { +} + +StringTypeInfo::StringTypeInfo(string collation_p) + : ExtraTypeInfo(ExtraTypeInfoType::STRING_TYPE_INFO), collation(std::move(collation_p)) { +} + +bool StringTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + // collation info has no impact on equality + return true; +} + +//===--------------------------------------------------------------------===// +// List Type Info +//===--------------------------------------------------------------------===// +ListTypeInfo::ListTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::LIST_TYPE_INFO) { +} + +ListTypeInfo::ListTypeInfo(LogicalType child_type_p) + : ExtraTypeInfo(ExtraTypeInfoType::LIST_TYPE_INFO), child_type(std::move(child_type_p)) { +} + +bool ListTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + auto &other = other_p->Cast(); + return child_type == other.child_type; +} + +//===--------------------------------------------------------------------===// +// Struct Type Info +//===--------------------------------------------------------------------===// +StructTypeInfo::StructTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::STRUCT_TYPE_INFO) { +} + +StructTypeInfo::StructTypeInfo(child_list_t child_types_p) + : ExtraTypeInfo(ExtraTypeInfoType::STRUCT_TYPE_INFO), child_types(std::move(child_types_p)) { +} + +bool StructTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + auto &other = other_p->Cast(); + return child_types == other.child_types; +} + +//===--------------------------------------------------------------------===// +// Aggregate State Type Info +//===--------------------------------------------------------------------===// +AggregateStateTypeInfo::AggregateStateTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO) { +} + +AggregateStateTypeInfo::AggregateStateTypeInfo(aggregate_state_t state_type_p) + : ExtraTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO), state_type(std::move(state_type_p)) { +} + +bool AggregateStateTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + auto &other = other_p->Cast(); + return state_type.function_name == other.state_type.function_name && + state_type.return_type == other.state_type.return_type && + state_type.bound_argument_types == other.state_type.bound_argument_types; +} + +//===--------------------------------------------------------------------===// +// User Type Info +//===--------------------------------------------------------------------===// +UserTypeInfo::UserTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO) { +} + +UserTypeInfo::UserTypeInfo(string name_p) + : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO), user_type_name(std::move(name_p)) { +} + +bool UserTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + auto &other = other_p->Cast(); + return other.user_type_name == user_type_name; +} + +//===--------------------------------------------------------------------===// +// Enum Type Info +//===--------------------------------------------------------------------===// +PhysicalType EnumTypeInfo::DictType(idx_t size) { + if (size <= NumericLimits::Maximum()) { + return PhysicalType::UINT8; + } else if (size <= NumericLimits::Maximum()) { + return PhysicalType::UINT16; + } else if (size <= NumericLimits::Maximum()) { + return PhysicalType::UINT32; + } else { + throw InternalException("Enum size must be lower than " + std::to_string(NumericLimits::Maximum())); + } +} + +template +struct EnumTypeInfoTemplated : public EnumTypeInfo { + explicit EnumTypeInfoTemplated(Vector &values_insert_order_p, idx_t size_p) + : EnumTypeInfo(values_insert_order_p, size_p) { + D_ASSERT(values_insert_order_p.GetType().InternalType() == PhysicalType::VARCHAR); + + UnifiedVectorFormat vdata; + values_insert_order.ToUnifiedFormat(size_p, vdata); + + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < size_p; i++) { + auto idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(idx)) { + throw InternalException("Attempted to create ENUM type with NULL value"); + } + if (values.count(data[idx]) > 0) { + throw InvalidInputException("Attempted to create ENUM type with duplicate value %s", + data[idx].GetString()); + } + values[data[idx]] = i; + } + } + + static shared_ptr Deserialize(Deserializer &deserializer, uint32_t size) { + Vector values_insert_order(LogicalType::VARCHAR, size); + auto strings = FlatVector::GetData(values_insert_order); + + deserializer.ReadList(201, "values", [&](Deserializer::List &list, idx_t i) { + strings[i] = StringVector::AddStringOrBlob(values_insert_order, list.ReadElement()); + }); + return make_shared(values_insert_order, size); + } + + const string_map_t &GetValues() const { + return values; + } + + EnumTypeInfoTemplated(const EnumTypeInfoTemplated &) = delete; + EnumTypeInfoTemplated &operator=(const EnumTypeInfoTemplated &) = delete; + +private: + string_map_t values; +}; + +EnumTypeInfo::EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p) + : ExtraTypeInfo(ExtraTypeInfoType::ENUM_TYPE_INFO), values_insert_order(values_insert_order_p), + dict_type(EnumDictType::VECTOR_DICT), dict_size(dict_size_p) { +} + +const EnumDictType &EnumTypeInfo::GetEnumDictType() const { + return dict_type; +} + +const Vector &EnumTypeInfo::GetValuesInsertOrder() const { + return values_insert_order; +} + +const idx_t &EnumTypeInfo::GetDictSize() const { + return dict_size; +} + +LogicalType EnumTypeInfo::CreateType(Vector &ordered_data, idx_t size) { + // Generate EnumTypeInfo + shared_ptr info; + auto enum_internal_type = EnumTypeInfo::DictType(size); + switch (enum_internal_type) { + case PhysicalType::UINT8: + info = make_shared>(ordered_data, size); + break; + case PhysicalType::UINT16: + info = make_shared>(ordered_data, size); + break; + case PhysicalType::UINT32: + info = make_shared>(ordered_data, size); + break; + default: + throw InternalException("Invalid Physical Type for ENUMs"); + } + // Generate Actual Enum Type + return LogicalType(LogicalTypeId::ENUM, info); +} + +template +int64_t TemplatedGetPos(const string_map_t &map, const string_t &key) { + auto it = map.find(key); + if (it == map.end()) { + return -1; + } + return it->second; +} + +int64_t EnumType::GetPos(const LogicalType &type, const string_t &key) { + auto info = type.AuxInfo(); + switch (type.InternalType()) { + case PhysicalType::UINT8: + return TemplatedGetPos(info->Cast>().GetValues(), key); + case PhysicalType::UINT16: + return TemplatedGetPos(info->Cast>().GetValues(), key); + case PhysicalType::UINT32: + return TemplatedGetPos(info->Cast>().GetValues(), key); + default: + throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); + } +} + +string_t EnumType::GetString(const LogicalType &type, idx_t pos) { + D_ASSERT(pos < EnumType::GetSize(type)); + return FlatVector::GetData(EnumType::GetValuesInsertOrder(type))[pos]; +} + +shared_ptr EnumTypeInfo::Deserialize(Deserializer &deserializer) { + auto values_count = deserializer.ReadProperty(200, "values_count"); + auto enum_internal_type = EnumTypeInfo::DictType(values_count); + switch (enum_internal_type) { + case PhysicalType::UINT8: + return EnumTypeInfoTemplated::Deserialize(deserializer, values_count); + case PhysicalType::UINT16: + return EnumTypeInfoTemplated::Deserialize(deserializer, values_count); + case PhysicalType::UINT32: + return EnumTypeInfoTemplated::Deserialize(deserializer, values_count); + default: + throw InternalException("Invalid Physical Type for ENUMs"); + } +} + +// Equalities are only used in enums with different catalog entries +bool EnumTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + auto &other = other_p->Cast(); + if (dict_type != other.dict_type) { + return false; + } + D_ASSERT(dict_type == EnumDictType::VECTOR_DICT); + // We must check if both enums have the same size + if (other.dict_size != dict_size) { + return false; + } + auto other_vector_ptr = FlatVector::GetData(other.values_insert_order); + auto this_vector_ptr = FlatVector::GetData(values_insert_order); + + // Now we must check if all strings are the same + for (idx_t i = 0; i < dict_size; i++) { + if (!Equals::Operation(other_vector_ptr[i], this_vector_ptr[i])) { + return false; + } + } + return true; +} + +void EnumTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + + // Enums are special in that we serialize their values as a list instead of dumping the whole vector + auto strings = FlatVector::GetData(values_insert_order); + serializer.WriteProperty(200, "values_count", dict_size); + serializer.WriteList(201, "values", dict_size, + [&](Serializer::List &list, idx_t i) { list.WriteElement(strings[i]); }); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/file_buffer.cpp b/src/duckdb/src/common/file_buffer.cpp new file mode 100644 index 00000000..01fa5d77 --- /dev/null +++ b/src/duckdb/src/common/file_buffer.cpp @@ -0,0 +1,108 @@ +#include "duckdb/common/file_buffer.hpp" + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/checksum.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/storage/storage_info.hpp" +#include + +namespace duckdb { + +FileBuffer::FileBuffer(Allocator &allocator, FileBufferType type, uint64_t user_size) + : allocator(allocator), type(type) { + Init(); + if (user_size) { + Resize(user_size); + } +} + +void FileBuffer::Init() { + buffer = nullptr; + size = 0; + internal_buffer = nullptr; + internal_size = 0; +} + +FileBuffer::FileBuffer(FileBuffer &source, FileBufferType type_p) : allocator(source.allocator), type(type_p) { + // take over the structures of the source buffer + buffer = source.buffer; + size = source.size; + internal_buffer = source.internal_buffer; + internal_size = source.internal_size; + + source.Init(); +} + +FileBuffer::~FileBuffer() { + if (!internal_buffer) { + return; + } + allocator.FreeData(internal_buffer, internal_size); +} + +void FileBuffer::ReallocBuffer(size_t new_size) { + data_ptr_t new_buffer; + if (internal_buffer) { + new_buffer = allocator.ReallocateData(internal_buffer, internal_size, new_size); + } else { + new_buffer = allocator.AllocateData(new_size); + } + if (!new_buffer) { + throw std::bad_alloc(); + } + internal_buffer = new_buffer; + internal_size = new_size; + // Caller must update these. + buffer = nullptr; + size = 0; +} + +FileBuffer::MemoryRequirement FileBuffer::CalculateMemory(uint64_t user_size) { + FileBuffer::MemoryRequirement result; + + if (type == FileBufferType::TINY_BUFFER) { + // We never do IO on tiny buffers, so there's no need to add a header or sector-align. + result.header_size = 0; + result.alloc_size = user_size; + } else { + result.header_size = Storage::BLOCK_HEADER_SIZE; + result.alloc_size = AlignValue(result.header_size + user_size); + } + return result; +} + +void FileBuffer::Resize(uint64_t new_size) { + auto req = CalculateMemory(new_size); + ReallocBuffer(req.alloc_size); + + if (new_size > 0) { + buffer = internal_buffer + req.header_size; + size = internal_size - req.header_size; + } +} + +void FileBuffer::Read(FileHandle &handle, uint64_t location) { + D_ASSERT(type != FileBufferType::TINY_BUFFER); + handle.Read(internal_buffer, internal_size, location); +} + +void FileBuffer::Write(FileHandle &handle, uint64_t location) { + D_ASSERT(type != FileBufferType::TINY_BUFFER); + handle.Write(internal_buffer, internal_size, location); +} + +void FileBuffer::Clear() { + memset(internal_buffer, 0, internal_size); +} + +void FileBuffer::Initialize(DebugInitialize initialize) { + if (initialize == DebugInitialize::NO_INITIALIZE) { + return; + } + uint8_t value = initialize == DebugInitialize::DEBUG_ZERO_INITIALIZE ? 0 : 0xFF; + memset(internal_buffer, value, internal_size); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp new file mode 100644 index 00000000..be51cda9 --- /dev/null +++ b/src/duckdb/src/common/file_system.cpp @@ -0,0 +1,547 @@ +#include "duckdb/common/file_system.hpp" + +#include "duckdb/common/checksum.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/windows.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/common/windows_util.hpp" + +#include +#include + +#ifndef _WIN32 +#include +#include +#include +#include +#include +#include + +#ifdef __MVS__ +#define _XOPEN_SOURCE_EXTENDED 1 +#include +// enjoy - https://reviews.llvm.org/D92110 +#define PATH_MAX _XOPEN_PATH_MAX +#endif + +#else +#include +#include + +#ifdef __MINGW32__ +// need to manually define this for mingw +extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG); +#endif + +#undef FILE_CREATE // woo mingw +#endif + +namespace duckdb { + +FileSystem::~FileSystem() { +} + +FileSystem &FileSystem::GetFileSystem(ClientContext &context) { + auto &client_data = ClientData::Get(context); + return *client_data.client_file_system; +} + +bool PathMatched(const string &path, const string &sub_path) { + if (path.rfind(sub_path, 0) == 0) { + return true; + } + return false; +} + +#ifndef _WIN32 + +string FileSystem::GetEnvVariable(const string &name) { + const char *env = getenv(name.c_str()); + if (!env) { + return string(); + } + return env; +} + +bool FileSystem::IsPathAbsolute(const string &path) { + auto path_separator = PathSeparator(path); + return PathMatched(path, path_separator); +} + +string FileSystem::PathSeparator(const string &path) { + return "/"; +} + +void FileSystem::SetWorkingDirectory(const string &path) { + if (chdir(path.c_str()) != 0) { + throw IOException("Could not change working directory!"); + } +} + +idx_t FileSystem::GetAvailableMemory() { + errno = 0; + +#ifdef __MVS__ + struct rlimit limit; + int rlim_rc = getrlimit(RLIMIT_AS, &limit); + idx_t max_memory = MinValue(limit.rlim_max, UINTPTR_MAX); +#else + idx_t max_memory = MinValue((idx_t)sysconf(_SC_PHYS_PAGES) * (idx_t)sysconf(_SC_PAGESIZE), UINTPTR_MAX); +#endif + if (errno != 0) { + return DConstants::INVALID_INDEX; + } + return max_memory; +} + +string FileSystem::GetWorkingDirectory() { + auto buffer = make_unsafe_uniq_array(PATH_MAX); + char *ret = getcwd(buffer.get(), PATH_MAX); + if (!ret) { + throw IOException("Could not get working directory!"); + } + return string(buffer.get()); +} + +string FileSystem::NormalizeAbsolutePath(const string &path) { + D_ASSERT(IsPathAbsolute(path)); + return path; +} + +#else + +string FileSystem::GetEnvVariable(const string &env) { + // first convert the environment variable name to the correct encoding + auto env_w = WindowsUtil::UTF8ToUnicode(env.c_str()); + // use _wgetenv to get the value + auto res_w = _wgetenv(env_w.c_str()); + if (!res_w) { + // no environment variable of this name found + return string(); + } + return WindowsUtil::UnicodeToUTF8(res_w); +} + +static bool StartsWithSingleBackslash(const string &path) { + if (path.size() < 2) { + return false; + } + if (path[0] != '/' && path[0] != '\\') { + return false; + } + if (path[1] == '/' || path[1] == '\\') { + return false; + } + return true; +} + +bool FileSystem::IsPathAbsolute(const string &path) { + // 1) A single backslash or forward-slash + if (StartsWithSingleBackslash(path)) { + return true; + } + // 2) A disk designator with a backslash (e.g., C:\ or C:/) + auto path_aux = path; + path_aux.erase(0, 1); + if (PathMatched(path_aux, ":\\") || PathMatched(path_aux, ":/")) { + return true; + } + return false; +} + +string FileSystem::NormalizeAbsolutePath(const string &path) { + D_ASSERT(IsPathAbsolute(path)); + auto result = StringUtil::Lower(FileSystem::ConvertSeparators(path)); + if (StartsWithSingleBackslash(result)) { + // Path starts with a single backslash or forward slash + // prepend drive letter + return GetWorkingDirectory().substr(0, 2) + result; + } + return result; +} + +string FileSystem::PathSeparator(const string &path) { + return "\\"; +} + +void FileSystem::SetWorkingDirectory(const string &path) { + auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str()); + if (!SetCurrentDirectoryW(unicode_path.c_str())) { + throw IOException("Could not change working directory to \"%s\"", path); + } +} + +idx_t FileSystem::GetAvailableMemory() { + ULONGLONG available_memory_kb; + if (GetPhysicallyInstalledSystemMemory(&available_memory_kb)) { + return MinValue(available_memory_kb * 1000, UINTPTR_MAX); + } + // fallback: try GlobalMemoryStatusEx + MEMORYSTATUSEX mem_state; + mem_state.dwLength = sizeof(MEMORYSTATUSEX); + + if (GlobalMemoryStatusEx(&mem_state)) { + return MinValue(mem_state.ullTotalPhys, UINTPTR_MAX); + } + return DConstants::INVALID_INDEX; +} + +string FileSystem::GetWorkingDirectory() { + idx_t count = GetCurrentDirectoryW(0, nullptr); + if (count == 0) { + throw IOException("Could not get working directory!"); + } + auto buffer = make_unsafe_uniq_array(count); + idx_t ret = GetCurrentDirectoryW(count, buffer.get()); + if (count != ret + 1) { + throw IOException("Could not get working directory!"); + } + return WindowsUtil::UnicodeToUTF8(buffer.get()); +} + +#endif + +string FileSystem::JoinPath(const string &a, const string &b) { + // FIXME: sanitize paths + return a + PathSeparator(a) + b; +} + +string FileSystem::ConvertSeparators(const string &path) { + auto separator_str = PathSeparator(path); + char separator = separator_str[0]; + if (separator == '/') { + // on unix-based systems we only accept / as a separator + return path; + } + // on windows-based systems we accept both + return StringUtil::Replace(path, "/", separator_str); +} + +string FileSystem::ExtractName(const string &path) { + if (path.empty()) { + return string(); + } + auto normalized_path = ConvertSeparators(path); + auto sep = PathSeparator(path); + auto splits = StringUtil::Split(normalized_path, sep); + D_ASSERT(!splits.empty()); + return splits.back(); +} + +string FileSystem::ExtractBaseName(const string &path) { + if (path.empty()) { + return string(); + } + auto vec = StringUtil::Split(ExtractName(path), "."); + D_ASSERT(!vec.empty()); + return vec[0]; +} + +string FileSystem::GetHomeDirectory(optional_ptr opener) { + // read the home_directory setting first, if it is set + if (opener) { + Value result; + if (opener->TryGetCurrentSetting("home_directory", result)) { + if (!result.IsNull() && !result.ToString().empty()) { + return result.ToString(); + } + } + } + // fallback to the default home directories for the specified system +#ifdef DUCKDB_WINDOWS + return FileSystem::GetEnvVariable("USERPROFILE"); +#else + return FileSystem::GetEnvVariable("HOME"); +#endif +} + +string FileSystem::GetHomeDirectory() { + return GetHomeDirectory(nullptr); +} + +string FileSystem::ExpandPath(const string &path, optional_ptr opener) { + if (path.empty()) { + return path; + } + if (path[0] == '~') { + return GetHomeDirectory(opener) + path.substr(1); + } + return path; +} + +string FileSystem::ExpandPath(const string &path) { + return FileSystem::ExpandPath(path, nullptr); +} + +// LCOV_EXCL_START +unique_ptr FileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, + FileCompressionType compression, FileOpener *opener) { + throw NotImplementedException("%s: OpenFile is not implemented!", GetName()); +} + +void FileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + throw NotImplementedException("%s: Read (with location) is not implemented!", GetName()); +} + +void FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + throw NotImplementedException("%s: Write (with location) is not implemented!", GetName()); +} + +int64_t FileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + throw NotImplementedException("%s: Read is not implemented!", GetName()); +} + +int64_t FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + throw NotImplementedException("%s: Write is not implemented!", GetName()); +} + +int64_t FileSystem::GetFileSize(FileHandle &handle) { + throw NotImplementedException("%s: GetFileSize is not implemented!", GetName()); +} + +time_t FileSystem::GetLastModifiedTime(FileHandle &handle) { + throw NotImplementedException("%s: GetLastModifiedTime is not implemented!", GetName()); +} + +FileType FileSystem::GetFileType(FileHandle &handle) { + return FileType::FILE_TYPE_INVALID; +} + +void FileSystem::Truncate(FileHandle &handle, int64_t new_size) { + throw NotImplementedException("%s: Truncate is not implemented!", GetName()); +} + +bool FileSystem::DirectoryExists(const string &directory) { + throw NotImplementedException("%s: DirectoryExists is not implemented!", GetName()); +} + +void FileSystem::CreateDirectory(const string &directory) { + throw NotImplementedException("%s: CreateDirectory is not implemented!", GetName()); +} + +void FileSystem::RemoveDirectory(const string &directory) { + throw NotImplementedException("%s: RemoveDirectory is not implemented!", GetName()); +} + +bool FileSystem::ListFiles(const string &directory, const std::function &callback, + FileOpener *opener) { + throw NotImplementedException("%s: ListFiles is not implemented!", GetName()); +} + +void FileSystem::MoveFile(const string &source, const string &target) { + throw NotImplementedException("%s: MoveFile is not implemented!", GetName()); +} + +bool FileSystem::FileExists(const string &filename) { + throw NotImplementedException("%s: FileExists is not implemented!", GetName()); +} + +bool FileSystem::IsPipe(const string &filename) { + throw NotImplementedException("%s: IsPipe is not implemented!", GetName()); +} + +void FileSystem::RemoveFile(const string &filename) { + throw NotImplementedException("%s: RemoveFile is not implemented!", GetName()); +} + +void FileSystem::FileSync(FileHandle &handle) { + throw NotImplementedException("%s: FileSync is not implemented!", GetName()); +} + +bool FileSystem::HasGlob(const string &str) { + for (idx_t i = 0; i < str.size(); i++) { + switch (str[i]) { + case '*': + case '?': + case '[': + return true; + default: + break; + } + } + return false; +} + +vector FileSystem::Glob(const string &path, FileOpener *opener) { + throw NotImplementedException("%s: Glob is not implemented!", GetName()); +} + +void FileSystem::RegisterSubSystem(unique_ptr sub_fs) { + throw NotImplementedException("%s: Can't register a sub system on a non-virtual file system", GetName()); +} + +void FileSystem::RegisterSubSystem(FileCompressionType compression_type, unique_ptr sub_fs) { + throw NotImplementedException("%s: Can't register a sub system on a non-virtual file system", GetName()); +} + +void FileSystem::UnregisterSubSystem(const string &name) { + throw NotImplementedException("%s: Can't unregister a sub system on a non-virtual file system", GetName()); +} + +void FileSystem::SetDisabledFileSystems(const vector &names) { + throw NotImplementedException("%s: Can't disable file systems on a non-virtual file system", GetName()); +} + +vector FileSystem::ListSubSystems() { + throw NotImplementedException("%s: Can't list sub systems on a non-virtual file system", GetName()); +} + +bool FileSystem::CanHandleFile(const string &fpath) { + throw NotImplementedException("%s: CanHandleFile is not implemented!", GetName()); +} + +static string LookupExtensionForPattern(const string &pattern) { + for (const auto &entry : EXTENSION_FILE_PREFIXES) { + if (StringUtil::StartsWith(pattern, entry.name)) { + return entry.extension; + } + } + return ""; +} + +vector FileSystem::GlobFiles(const string &pattern, ClientContext &context, FileGlobOptions options) { + auto result = Glob(pattern); + if (result.empty()) { + string required_extension = LookupExtensionForPattern(pattern); + if (!required_extension.empty() && !context.db->ExtensionIsLoaded(required_extension)) { + auto &dbconfig = DBConfig::GetConfig(context); + if (!ExtensionHelper::CanAutoloadExtension(required_extension) || + !dbconfig.options.autoload_known_extensions) { + auto error_message = + "File " + pattern + " requires the extension " + required_extension + " to be loaded"; + error_message = + ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, required_extension); + throw MissingExtensionException(error_message); + } + // an extension is required to read this file, but it is not loaded - try to load it + ExtensionHelper::AutoLoadExtension(context, required_extension); + // success! glob again + // check the extension is loaded just in case to prevent an infinite loop here + if (!context.db->ExtensionIsLoaded(required_extension)) { + throw InternalException("Extension load \"%s\" did not throw but somehow the extension was not loaded", + required_extension); + } + return GlobFiles(pattern, context, options); + } + if (options == FileGlobOptions::DISALLOW_EMPTY) { + throw IOException("No files found that match the pattern \"%s\"", pattern); + } + } + return result; +} + +void FileSystem::Seek(FileHandle &handle, idx_t location) { + throw NotImplementedException("%s: Seek is not implemented!", GetName()); +} + +void FileSystem::Reset(FileHandle &handle) { + handle.Seek(0); +} + +idx_t FileSystem::SeekPosition(FileHandle &handle) { + throw NotImplementedException("%s: SeekPosition is not implemented!", GetName()); +} + +bool FileSystem::CanSeek() { + throw NotImplementedException("%s: CanSeek is not implemented!", GetName()); +} + +unique_ptr FileSystem::OpenCompressedFile(unique_ptr handle, bool write) { + throw NotImplementedException("%s: OpenCompressedFile is not implemented!", GetName()); +} + +bool FileSystem::OnDiskFile(FileHandle &handle) { + throw NotImplementedException("%s: OnDiskFile is not implemented!", GetName()); +} +// LCOV_EXCL_STOP + +FileHandle::FileHandle(FileSystem &file_system, string path_p) : file_system(file_system), path(std::move(path_p)) { +} + +FileHandle::~FileHandle() { +} + +int64_t FileHandle::Read(void *buffer, idx_t nr_bytes) { + return file_system.Read(*this, buffer, nr_bytes); +} + +int64_t FileHandle::Write(void *buffer, idx_t nr_bytes) { + return file_system.Write(*this, buffer, nr_bytes); +} + +void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { + file_system.Read(*this, buffer, nr_bytes, location); +} + +void FileHandle::Write(void *buffer, idx_t nr_bytes, idx_t location) { + file_system.Write(*this, buffer, nr_bytes, location); +} + +void FileHandle::Seek(idx_t location) { + file_system.Seek(*this, location); +} + +void FileHandle::Reset() { + file_system.Reset(*this); +} + +idx_t FileHandle::SeekPosition() { + return file_system.SeekPosition(*this); +} + +bool FileHandle::CanSeek() { + return file_system.CanSeek(); +} + +string FileHandle::ReadLine() { + string result; + char buffer[1]; + while (true) { + idx_t tuples_read = Read(buffer, 1); + if (tuples_read == 0 || buffer[0] == '\n') { + return result; + } + if (buffer[0] != '\r') { + result += buffer[0]; + } + } +} + +bool FileHandle::OnDiskFile() { + return file_system.OnDiskFile(*this); +} + +idx_t FileHandle::GetFileSize() { + return file_system.GetFileSize(*this); +} + +void FileHandle::Sync() { + file_system.FileSync(*this); +} + +void FileHandle::Truncate(int64_t new_size) { + file_system.Truncate(*this, new_size); +} + +FileType FileHandle::GetType() { + return file_system.GetFileType(*this); +} + +bool FileSystem::IsRemoteFile(const string &path) { + const string prefixes[] = {"http://", "https://", "s3://"}; + for (auto &prefix : prefixes) { + if (StringUtil::StartsWith(path, prefix)) { + return true; + } + } + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/filename_pattern.cpp b/src/duckdb/src/common/filename_pattern.cpp new file mode 100644 index 00000000..e52b9a61 --- /dev/null +++ b/src/duckdb/src/common/filename_pattern.cpp @@ -0,0 +1,41 @@ +#include "duckdb/common/filename_pattern.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +void FilenamePattern::SetFilenamePattern(const string &pattern) { + const string id_format {"{i}"}; + const string uuid_format {"{uuid}"}; + + _base = pattern; + + _pos = _base.find(id_format); + if (_pos != string::npos) { + _base = StringUtil::Replace(_base, id_format, ""); + _uuid = false; + } + + _pos = _base.find(uuid_format); + if (_pos != string::npos) { + _base = StringUtil::Replace(_base, uuid_format, ""); + _uuid = true; + } + + _pos = std::min(_pos, (idx_t)_base.length()); +} + +string FilenamePattern::CreateFilename(FileSystem &fs, const string &path, const string &extension, + idx_t offset) const { + string result(_base); + string replacement; + + if (_uuid) { + replacement = UUID::ToString(UUID::GenerateRandomUUID()); + } else { + replacement = std::to_string(offset); + } + result.insert(_pos, replacement); + return fs.JoinPath(path, result + "." + extension); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/fsst.cpp b/src/duckdb/src/common/fsst.cpp new file mode 100644 index 00000000..6c8c3de4 --- /dev/null +++ b/src/duckdb/src/common/fsst.cpp @@ -0,0 +1,35 @@ +#include "duckdb/storage/string_uncompressed.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/fsst.hpp" +#include "fsst.h" + +namespace duckdb { + +string_t FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, Vector &result, const char *compressed_string, + idx_t compressed_string_len) { + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + unsigned char decompress_buffer[StringUncompressed::STRING_BLOCK_LIMIT + 1]; + auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); + auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT + auto decompressed_string_size = + duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, + StringUncompressed::STRING_BLOCK_LIMIT + 1, &decompress_buffer[0]); + D_ASSERT(decompressed_string_size <= StringUncompressed::STRING_BLOCK_LIMIT); + + return StringVector::AddStringOrBlob(result, const_char_ptr_cast(decompress_buffer), decompressed_string_size); +} + +Value FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, const char *compressed_string, + idx_t compressed_string_len) { + unsigned char decompress_buffer[StringUncompressed::STRING_BLOCK_LIMIT + 1]; + auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT + auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); + auto decompressed_string_size = + duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, + StringUncompressed::STRING_BLOCK_LIMIT + 1, &decompress_buffer[0]); + D_ASSERT(decompressed_string_size <= StringUncompressed::STRING_BLOCK_LIMIT); + + return Value(string(char_ptr_cast(decompress_buffer), decompressed_string_size)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/gzip_file_system.cpp b/src/duckdb/src/common/gzip_file_system.cpp new file mode 100644 index 00000000..51774bd8 --- /dev/null +++ b/src/duckdb/src/common/gzip_file_system.cpp @@ -0,0 +1,397 @@ +#include "duckdb/common/gzip_file_system.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_system.hpp" + +#include "miniz.hpp" +#include "miniz_wrapper.hpp" + +#include "duckdb/common/limits.hpp" + +namespace duckdb { + +/* + + 0 2 bytes magic header 0x1f, 0x8b (\037 \213) + 2 1 byte compression method + 0: store (copied) + 1: compress + 2: pack + 3: lzh + 4..7: reserved + 8: deflate + 3 1 byte flags + bit 0 set: file probably ascii text + bit 1 set: continuation of multi-part gzip file, part number present + bit 2 set: extra field present + bit 3 set: original file name present + bit 4 set: file comment present + bit 5 set: file is encrypted, encryption header present + bit 6,7: reserved + 4 4 bytes file modification time in Unix format + 8 1 byte extra flags (depend on compression method) + 9 1 byte OS type +[ + 2 bytes optional part number (second part=1) +]? +[ + 2 bytes optional extra field length (e) + (e)bytes optional extra field +]? +[ + bytes optional original file name, zero terminated +]? +[ + bytes optional file comment, zero terminated +]? +[ + 12 bytes optional encryption header +]? + bytes compressed data + 4 bytes crc32 + 4 bytes uncompressed input size modulo 2^32 + + */ + +static idx_t GZipConsumeString(FileHandle &input) { + idx_t size = 1; // terminator + char buffer[1]; + while (input.Read(buffer, 1) == 1) { + if (buffer[0] == '\0') { + break; + } + size++; + } + return size; +} + +struct MiniZStreamWrapper : public StreamWrapper { + ~MiniZStreamWrapper() override; + + CompressedFile *file = nullptr; + duckdb_miniz::mz_stream *mz_stream_ptr = nullptr; + bool writing = false; + duckdb_miniz::mz_ulong crc; + idx_t total_size; + +public: + void Initialize(CompressedFile &file, bool write) override; + + bool Read(StreamData &stream_data) override; + void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override; + + void Close() override; + + void FlushStream(); +}; + +MiniZStreamWrapper::~MiniZStreamWrapper() { + // avoid closing if destroyed during stack unwinding + if (Exception::UncaughtException()) { + return; + } + try { + MiniZStreamWrapper::Close(); + } catch (...) { + } +} + +void MiniZStreamWrapper::Initialize(CompressedFile &file, bool write) { + Close(); + this->file = &file; + mz_stream_ptr = new duckdb_miniz::mz_stream(); + memset(mz_stream_ptr, 0, sizeof(duckdb_miniz::mz_stream)); + this->writing = write; + + // TODO use custom alloc/free methods in miniz to throw exceptions on OOM + uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; + if (write) { + crc = MZ_CRC32_INIT; + total_size = 0; + + MiniZStream::InitializeGZIPHeader(gzip_hdr); + file.child_handle->Write(gzip_hdr, GZIP_HEADER_MINSIZE); + + auto ret = mz_deflateInit2((duckdb_miniz::mz_streamp)mz_stream_ptr, duckdb_miniz::MZ_DEFAULT_LEVEL, MZ_DEFLATED, + -MZ_DEFAULT_WINDOW_BITS, 1, 0); + if (ret != duckdb_miniz::MZ_OK) { + throw InternalException("Failed to initialize miniz"); + } + } else { + idx_t data_start = GZIP_HEADER_MINSIZE; + auto read_count = file.child_handle->Read(gzip_hdr, GZIP_HEADER_MINSIZE); + GZipFileSystem::VerifyGZIPHeader(gzip_hdr, read_count); + // Skip over the extra field if necessary + if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { + uint8_t gzip_xlen[2]; + file.child_handle->Seek(data_start); + file.child_handle->Read(gzip_xlen, 2); + idx_t xlen = (uint8_t)gzip_xlen[0] | (uint8_t)gzip_xlen[1] << 8; + data_start += xlen + 2; + } + // Skip over the file name if necessary + if (gzip_hdr[3] & GZIP_FLAG_NAME) { + file.child_handle->Seek(data_start); + data_start += GZipConsumeString(*file.child_handle); + } + file.child_handle->Seek(data_start); + // stream is now set to beginning of payload data + auto ret = duckdb_miniz::mz_inflateInit2((duckdb_miniz::mz_streamp)mz_stream_ptr, -MZ_DEFAULT_WINDOW_BITS); + if (ret != duckdb_miniz::MZ_OK) { + throw InternalException("Failed to initialize miniz"); + } + } +} + +bool MiniZStreamWrapper::Read(StreamData &sd) { + // Handling for the concatenated files + if (sd.refresh) { + auto available = (uint32_t)(sd.in_buff_end - sd.in_buff_start); + if (available <= GZIP_FOOTER_SIZE) { + // Only footer is available so we just close and return finished + Close(); + return true; + } + + sd.refresh = false; + auto body_ptr = sd.in_buff_start + GZIP_FOOTER_SIZE; + uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; + memcpy(gzip_hdr, body_ptr, GZIP_HEADER_MINSIZE); + GZipFileSystem::VerifyGZIPHeader(gzip_hdr, GZIP_HEADER_MINSIZE); + body_ptr += GZIP_HEADER_MINSIZE; + if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { + idx_t xlen = (uint8_t)*body_ptr | (uint8_t) * (body_ptr + 1) << 8; + body_ptr += xlen + 2; + if (GZIP_FOOTER_SIZE + GZIP_HEADER_MINSIZE + 2 + xlen >= GZIP_HEADER_MAXSIZE) { + throw InternalException("Extra field resulting in GZIP header larger than defined maximum (%d)", + GZIP_HEADER_MAXSIZE); + } + } + if (gzip_hdr[3] & GZIP_FLAG_NAME) { + char c; + do { + c = *body_ptr; + body_ptr++; + } while (c != '\0' && body_ptr < sd.in_buff_end); + if ((idx_t)(body_ptr - sd.in_buff_start) >= GZIP_HEADER_MAXSIZE) { + throw InternalException("Filename resulting in GZIP header larger than defined maximum (%d)", + GZIP_HEADER_MAXSIZE); + } + } + sd.in_buff_start = body_ptr; + if (sd.in_buff_end - sd.in_buff_start < 1) { + Close(); + return true; + } + duckdb_miniz::mz_inflateEnd(mz_stream_ptr); + auto sta = duckdb_miniz::mz_inflateInit2((duckdb_miniz::mz_streamp)mz_stream_ptr, -MZ_DEFAULT_WINDOW_BITS); + if (sta != duckdb_miniz::MZ_OK) { + throw InternalException("Failed to initialize miniz"); + } + } + + // actually decompress + mz_stream_ptr->next_in = sd.in_buff_start; + D_ASSERT(sd.in_buff_end - sd.in_buff_start < NumericLimits::Maximum()); + mz_stream_ptr->avail_in = (uint32_t)(sd.in_buff_end - sd.in_buff_start); + mz_stream_ptr->next_out = data_ptr_cast(sd.out_buff_end); + mz_stream_ptr->avail_out = (uint32_t)((sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_end); + auto ret = duckdb_miniz::mz_inflate(mz_stream_ptr, duckdb_miniz::MZ_NO_FLUSH); + if (ret != duckdb_miniz::MZ_OK && ret != duckdb_miniz::MZ_STREAM_END) { + throw IOException("Failed to decode gzip stream: %s", duckdb_miniz::mz_error(ret)); + } + // update pointers following inflate() + sd.in_buff_start = (data_ptr_t)mz_stream_ptr->next_in; // NOLINT + sd.in_buff_end = sd.in_buff_start + mz_stream_ptr->avail_in; + sd.out_buff_end = data_ptr_cast(mz_stream_ptr->next_out); + D_ASSERT(sd.out_buff_end + mz_stream_ptr->avail_out == sd.out_buff.get() + sd.out_buf_size); + + // if stream ended, deallocate inflator + if (ret == duckdb_miniz::MZ_STREAM_END) { + // Concatenated GZIP potentially coming up - refresh input buffer + sd.refresh = true; + } + return false; +} + +void MiniZStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data, + int64_t uncompressed_size) { + // update the src and the total size + crc = duckdb_miniz::mz_crc32(crc, reinterpret_cast(uncompressed_data), uncompressed_size); + total_size += uncompressed_size; + + auto remaining = uncompressed_size; + while (remaining > 0) { + idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; + + mz_stream_ptr->next_in = reinterpret_cast(uncompressed_data); + mz_stream_ptr->avail_in = remaining; + mz_stream_ptr->next_out = sd.out_buff_start; + mz_stream_ptr->avail_out = output_remaining; + + auto res = mz_deflate(mz_stream_ptr, duckdb_miniz::MZ_NO_FLUSH); + if (res != duckdb_miniz::MZ_OK) { + D_ASSERT(res != duckdb_miniz::MZ_STREAM_END); + throw InternalException("Failed to compress GZIP block"); + } + sd.out_buff_start += output_remaining - mz_stream_ptr->avail_out; + if (mz_stream_ptr->avail_out == 0) { + // no more output buffer available: flush + file.child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); + sd.out_buff_start = sd.out_buff.get(); + } + idx_t written = remaining - mz_stream_ptr->avail_in; + uncompressed_data += written; + remaining = mz_stream_ptr->avail_in; + } +} + +void MiniZStreamWrapper::FlushStream() { + auto &sd = file->stream_data; + mz_stream_ptr->next_in = nullptr; + mz_stream_ptr->avail_in = 0; + while (true) { + auto output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; + mz_stream_ptr->next_out = sd.out_buff_start; + mz_stream_ptr->avail_out = output_remaining; + + auto res = mz_deflate(mz_stream_ptr, duckdb_miniz::MZ_FINISH); + sd.out_buff_start += (output_remaining - mz_stream_ptr->avail_out); + if (sd.out_buff_start > sd.out_buff.get()) { + file->child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); + sd.out_buff_start = sd.out_buff.get(); + } + if (res == duckdb_miniz::MZ_STREAM_END) { + break; + } + if (res != duckdb_miniz::MZ_OK) { + throw InternalException("Failed to compress GZIP block"); + } + } +} + +void MiniZStreamWrapper::Close() { + if (!mz_stream_ptr) { + return; + } + if (writing) { + // flush anything remaining in the stream + FlushStream(); + + // write the footer + unsigned char gzip_footer[MiniZStream::GZIP_FOOTER_SIZE]; + MiniZStream::InitializeGZIPFooter(gzip_footer, crc, total_size); + file->child_handle->Write(gzip_footer, MiniZStream::GZIP_FOOTER_SIZE); + + duckdb_miniz::mz_deflateEnd(mz_stream_ptr); + } else { + duckdb_miniz::mz_inflateEnd(mz_stream_ptr); + } + delete mz_stream_ptr; + mz_stream_ptr = nullptr; + file = nullptr; +} + +class GZipFile : public CompressedFile { +public: + GZipFile(unique_ptr child_handle_p, const string &path, bool write) + : CompressedFile(gzip_fs, std::move(child_handle_p), path) { + Initialize(write); + } + + GZipFileSystem gzip_fs; +}; + +void GZipFileSystem::VerifyGZIPHeader(uint8_t gzip_hdr[], idx_t read_count) { + // check for incorrectly formatted files + if (read_count != GZIP_HEADER_MINSIZE) { + throw IOException("Input is not a GZIP stream"); + } + if (gzip_hdr[0] != 0x1F || gzip_hdr[1] != 0x8B) { // magic header + throw IOException("Input is not a GZIP stream"); + } + if (gzip_hdr[2] != GZIP_COMPRESSION_DEFLATE) { // compression method + throw IOException("Unsupported GZIP compression method"); + } + if (gzip_hdr[3] & GZIP_FLAG_UNSUPPORTED) { + throw IOException("Unsupported GZIP archive"); + } +} + +string GZipFileSystem::UncompressGZIPString(const string &in) { + // decompress file + auto body_ptr = in.data(); + + auto mz_stream_ptr = new duckdb_miniz::mz_stream(); + memset(mz_stream_ptr, 0, sizeof(duckdb_miniz::mz_stream)); + + uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; + + // check for incorrectly formatted files + + // TODO this is mostly the same as gzip_file_system.cpp + if (in.size() < GZIP_HEADER_MINSIZE) { + throw IOException("Input is not a GZIP stream"); + } + memcpy(gzip_hdr, body_ptr, GZIP_HEADER_MINSIZE); + body_ptr += GZIP_HEADER_MINSIZE; + GZipFileSystem::VerifyGZIPHeader(gzip_hdr, GZIP_HEADER_MINSIZE); + + if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { + throw IOException("Extra field in a GZIP stream unsupported"); + } + + if (gzip_hdr[3] & GZIP_FLAG_NAME) { + char c; + do { + c = *body_ptr; + body_ptr++; + } while (c != '\0' && (idx_t)(body_ptr - in.data()) < in.size()); + } + + // stream is now set to beginning of payload data + auto status = duckdb_miniz::mz_inflateInit2(mz_stream_ptr, -MZ_DEFAULT_WINDOW_BITS); + if (status != duckdb_miniz::MZ_OK) { + throw InternalException("Failed to initialize miniz"); + } + + auto bytes_remaining = in.size() - (body_ptr - in.data()); + mz_stream_ptr->next_in = const_uchar_ptr_cast(body_ptr); + mz_stream_ptr->avail_in = bytes_remaining; + + unsigned char decompress_buffer[BUFSIZ]; + string decompressed; + + while (status == duckdb_miniz::MZ_OK) { + mz_stream_ptr->next_out = decompress_buffer; + mz_stream_ptr->avail_out = sizeof(decompress_buffer); + status = mz_inflate(mz_stream_ptr, duckdb_miniz::MZ_NO_FLUSH); + if (status != duckdb_miniz::MZ_STREAM_END && status != duckdb_miniz::MZ_OK) { + throw IOException("Failed to uncompress"); + } + decompressed.append(char_ptr_cast(decompress_buffer), mz_stream_ptr->total_out - decompressed.size()); + } + duckdb_miniz::mz_inflateEnd(mz_stream_ptr); + if (decompressed.empty()) { + throw IOException("Failed to uncompress"); + } + return decompressed; +} + +unique_ptr GZipFileSystem::OpenCompressedFile(unique_ptr handle, bool write) { + auto path = handle->path; + return make_uniq(std::move(handle), path, write); +} + +unique_ptr GZipFileSystem::CreateStream() { + return make_uniq(); +} + +idx_t GZipFileSystem::InBufferSize() { + return BUFFER_SIZE; +} + +idx_t GZipFileSystem::OutBufferSize() { + return BUFFER_SIZE; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/hive_partitioning.cpp b/src/duckdb/src/common/hive_partitioning.cpp new file mode 100644 index 00000000..d3b24d76 --- /dev/null +++ b/src/duckdb/src/common/hive_partitioning.cpp @@ -0,0 +1,394 @@ +#include "duckdb/common/hive_partitioning.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "re2/re2.h" + +namespace duckdb { + +static unordered_map GetKnownColumnValues(string &filename, + unordered_map &column_map, + duckdb_re2::RE2 &compiled_regex, bool filename_col, + bool hive_partition_cols) { + unordered_map result; + + if (filename_col) { + auto lookup_column_id = column_map.find("filename"); + if (lookup_column_id != column_map.end()) { + result[lookup_column_id->second] = filename; + } + } + + if (hive_partition_cols) { + auto partitions = HivePartitioning::Parse(filename, compiled_regex); + for (auto &partition : partitions) { + auto lookup_column_id = column_map.find(partition.first); + if (lookup_column_id != column_map.end()) { + result[lookup_column_id->second] = partition.second; + } + } + } + + return result; +} + +// Takes an expression and converts a list of known column_refs to constants +static void ConvertKnownColRefToConstants(unique_ptr &expr, + unordered_map &known_column_values, idx_t table_index) { + if (expr->type == ExpressionType::BOUND_COLUMN_REF) { + auto &bound_colref = expr->Cast(); + + // This bound column ref is for another table + if (table_index != bound_colref.binding.table_index) { + return; + } + + auto lookup = known_column_values.find(bound_colref.binding.column_index); + if (lookup != known_column_values.end()) { + expr = make_uniq(Value(lookup->second).DefaultCastAs(bound_colref.return_type)); + } + } else { + ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { + ConvertKnownColRefToConstants(child, known_column_values, table_index); + }); + } +} + +// matches hive partitions in file name. For example: +// - s3://bucket/var1=value1/bla/bla/var2=value2 +// - http(s)://domain(:port)/lala/kasdl/var1=value1/?not-a-var=not-a-value +// - folder/folder/folder/../var1=value1/etc/.//var2=value2 +const string HivePartitioning::REGEX_STRING = "[\\/\\\\]([^\\/\\?\\\\]+)=([^\\/\\n\\?\\\\]+)"; + +std::map HivePartitioning::Parse(const string &filename, duckdb_re2::RE2 ®ex) { + std::map result; + duckdb_re2::StringPiece input(filename); // Wrap a StringPiece around it + + string var; + string value; + while (RE2::FindAndConsume(&input, regex, &var, &value)) { + result.insert(std::pair(var, value)); + } + return result; +} + +std::map HivePartitioning::Parse(const string &filename) { + duckdb_re2::RE2 regex(REGEX_STRING); + return Parse(filename, regex); +} + +// TODO: this can still be improved by removing the parts of filter expressions that are true for all remaining files. +// currently, only expressions that cannot be evaluated during pushdown are removed. +void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vector &files, + vector> &filters, + unordered_map &column_map, LogicalGet &get, + bool hive_enabled, bool filename_enabled) { + + vector pruned_files; + vector have_preserved_filter(filters.size(), false); + vector> pruned_filters; + unordered_set filters_applied_to_files; + duckdb_re2::RE2 regex(REGEX_STRING); + auto table_index = get.table_index; + + if ((!filename_enabled && !hive_enabled) || filters.empty()) { + return; + } + + for (idx_t i = 0; i < files.size(); i++) { + auto &file = files[i]; + bool should_prune_file = false; + auto known_values = GetKnownColumnValues(file, column_map, regex, filename_enabled, hive_enabled); + + FilterCombiner combiner(context); + + for (idx_t j = 0; j < filters.size(); j++) { + auto &filter = filters[j]; + unique_ptr filter_copy = filter->Copy(); + ConvertKnownColRefToConstants(filter_copy, known_values, table_index); + // Evaluate the filter, if it can be evaluated here, we can not prune this filter + Value result_value; + + if (!filter_copy->IsScalar() || !filter_copy->IsFoldable() || + !ExpressionExecutor::TryEvaluateScalar(context, *filter_copy, result_value)) { + // can not be evaluated only with the filename/hive columns added, we can not prune this filter + if (!have_preserved_filter[j]) { + pruned_filters.emplace_back(filter->Copy()); + have_preserved_filter[j] = true; + } + } else if (!result_value.GetValue()) { + // filter evaluates to false + should_prune_file = true; + // convert the filter to a table filter. + if (filters_applied_to_files.find(j) == filters_applied_to_files.end()) { + get.extra_info.file_filters += filter->ToString(); + filters_applied_to_files.insert(j); + } + } + } + + if (!should_prune_file) { + pruned_files.push_back(file); + } + } + + D_ASSERT(filters.size() >= pruned_filters.size()); + + filters = std::move(pruned_filters); + files = std::move(pruned_files); +} + +HivePartitionedColumnData::HivePartitionedColumnData(const HivePartitionedColumnData &other) + : PartitionedColumnData(other), hashes_v(LogicalType::HASH) { + // Synchronize to ensure consistency of shared partition map + if (other.global_state) { + global_state = other.global_state; + unique_lock lck(global_state->lock); + SynchronizeLocalMap(); + } + InitializeKeys(); +} + +void HivePartitionedColumnData::InitializeKeys() { + keys.resize(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + keys[i].values.resize(group_by_columns.size()); + } +} + +template +static inline Value GetHiveKeyValue(const T &val) { + return Value::CreateValue(val); +} + +template +static inline Value GetHiveKeyValue(const T &val, const LogicalType &type) { + auto result = GetHiveKeyValue(val); + result.Reinterpret(type); + return result; +} + +static inline Value GetHiveKeyNullValue(const LogicalType &type) { + Value result; + result.Reinterpret(type); + return result; +} + +template +static void TemplatedGetHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, + const idx_t count) { + UnifiedVectorFormat format; + input.ToUnifiedFormat(count, format); + + const auto &sel = *format.sel; + const auto data = UnifiedVectorFormat::GetData(format); + const auto &validity = format.validity; + + const auto &type = input.GetType(); + + const auto reinterpret = Value::CreateValue(data[0]).GetTypeMutable() != type; + if (reinterpret) { + for (idx_t i = 0; i < count; i++) { + auto &key = keys[i]; + const auto idx = sel.get_index(i); + if (validity.RowIsValid(idx)) { + key.values[col_idx] = GetHiveKeyValue(data[idx], type); + } else { + key.values[col_idx] = GetHiveKeyNullValue(type); + } + } + } else { + for (idx_t i = 0; i < count; i++) { + auto &key = keys[i]; + const auto idx = sel.get_index(i); + if (validity.RowIsValid(idx)) { + key.values[col_idx] = GetHiveKeyValue(data[idx]); + } else { + key.values[col_idx] = GetHiveKeyNullValue(type); + } + } + } +} + +static void GetNestedHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, + const idx_t count) { + for (idx_t i = 0; i < count; i++) { + auto &key = keys[i]; + key.values[col_idx] = input.GetValue(i); + } +} + +static void GetHivePartitionValuesTypeSwitch(Vector &input, vector &keys, const idx_t col_idx, + const idx_t count) { + const auto &type = input.GetType(); + switch (type.InternalType()) { + case PhysicalType::BOOL: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::INT8: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::INT16: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::INT32: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::INT64: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::INT128: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::UINT8: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::UINT16: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::UINT32: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::UINT64: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::FLOAT: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::DOUBLE: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::INTERVAL: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::VARCHAR: + TemplatedGetHivePartitionValues(input, keys, col_idx, count); + break; + case PhysicalType::STRUCT: + case PhysicalType::LIST: + GetNestedHivePartitionValues(input, keys, col_idx, count); + break; + default: + throw InternalException("Unsupported type for HivePartitionedColumnData::ComputePartitionIndices"); + } +} + +void HivePartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { + const auto count = input.size(); + + input.Hash(group_by_columns, hashes_v); + hashes_v.Flatten(count); + + for (idx_t col_idx = 0; col_idx < group_by_columns.size(); col_idx++) { + auto &group_by_col = input.data[group_by_columns[col_idx]]; + GetHivePartitionValuesTypeSwitch(group_by_col, keys, col_idx, count); + } + + const auto hashes = FlatVector::GetData(hashes_v); + const auto partition_indices = FlatVector::GetData(state.partition_indices); + for (idx_t i = 0; i < count; i++) { + auto &key = keys[i]; + key.hash = hashes[i]; + auto lookup = local_partition_map.find(key); + if (lookup == local_partition_map.end()) { + idx_t new_partition_id = RegisterNewPartition(key, state); + partition_indices[i] = new_partition_id; + } else { + partition_indices[i] = lookup->second; + } + } +} + +std::map HivePartitionedColumnData::GetReverseMap() { + std::map ret; + for (const auto &pair : local_partition_map) { + ret[pair.second] = &(pair.first); + } + return ret; +} + +void HivePartitionedColumnData::GrowAllocators() { + unique_lock lck_gstate(allocators->lock); + + idx_t current_allocator_size = allocators->allocators.size(); + idx_t required_allocators = local_partition_map.size(); + + allocators->allocators.reserve(current_allocator_size); + for (idx_t i = current_allocator_size; i < required_allocators; i++) { + CreateAllocator(); + } + + D_ASSERT(allocators->allocators.size() == local_partition_map.size()); +} + +void HivePartitionedColumnData::GrowAppendState(PartitionedColumnDataAppendState &state) { + idx_t current_append_state_size = state.partition_append_states.size(); + idx_t required_append_state_size = local_partition_map.size(); + + for (idx_t i = current_append_state_size; i < required_append_state_size; i++) { + state.partition_append_states.emplace_back(make_uniq()); + state.partition_buffers.emplace_back(CreatePartitionBuffer()); + } +} + +void HivePartitionedColumnData::GrowPartitions(PartitionedColumnDataAppendState &state) { + idx_t current_partitions = partitions.size(); + idx_t required_partitions = local_partition_map.size(); + + D_ASSERT(allocators->allocators.size() == required_partitions); + + for (idx_t i = current_partitions; i < required_partitions; i++) { + partitions.emplace_back(CreatePartitionCollection(i)); + partitions[i]->InitializeAppend(*state.partition_append_states[i]); + } + D_ASSERT(partitions.size() == local_partition_map.size()); +} + +void HivePartitionedColumnData::SynchronizeLocalMap() { + // Synchronise global map into local, may contain changes from other threads too + for (auto it = global_state->partitions.begin() + local_partition_map.size(); it < global_state->partitions.end(); + it++) { + local_partition_map[(*it)->first] = (*it)->second; + } +} + +idx_t HivePartitionedColumnData::RegisterNewPartition(HivePartitionKey key, PartitionedColumnDataAppendState &state) { + if (global_state) { + idx_t partition_id; + + // Synchronize Global state with our local state with the newly discoveren partition + { + unique_lock lck_gstate(global_state->lock); + + // Insert into global map, or return partition if already present + auto res = + global_state->partition_map.emplace(std::make_pair(std::move(key), global_state->partition_map.size())); + auto it = res.first; + partition_id = it->second; + + // Add iterator to vector to allow incrementally updating local states from global state + global_state->partitions.emplace_back(it); + SynchronizeLocalMap(); + } + + // After synchronizing with the global state, we need to grow the shared allocators to support + // the number of partitions, which guarantees that there's always enough allocators available to each thread + GrowAllocators(); + + // Grow local partition data + GrowAppendState(state); + GrowPartitions(state); + + return partition_id; + } else { + return local_partition_map.emplace(std::make_pair(std::move(key), local_partition_map.size())).first->second; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/http_state.cpp b/src/duckdb/src/common/http_state.cpp new file mode 100644 index 00000000..4a50ecd7 --- /dev/null +++ b/src/duckdb/src/common/http_state.cpp @@ -0,0 +1,78 @@ +#include "duckdb/common/http_state.hpp" + +namespace duckdb { + +CachedFileHandle::CachedFileHandle(shared_ptr &file_p) { + // If the file was not yet initialized, we need to grab a lock. + if (!file_p->initialized) { + lock = make_uniq>(file_p->lock); + } + file = file_p; +} + +void CachedFileHandle::SetInitialized() { + if (file->initialized) { + throw InternalException("Cannot set initialized on cached file that was already initialized"); + } + if (!lock) { + throw InternalException("Cannot set initialized on cached file without lock"); + } + file->initialized = true; + lock = nullptr; +} + +void CachedFileHandle::AllocateBuffer(idx_t size) { + if (file->initialized) { + throw InternalException("Cannot allocate a buffer for a cached file that was already initialized"); + } + file->data = std::shared_ptr(new char[size], std::default_delete()); + file->capacity = size; +} + +void CachedFileHandle::GrowBuffer(idx_t new_capacity, idx_t bytes_to_copy) { + // copy shared ptr to old data + auto old_data = file->data; + // allocate new buffer that can hold the new capacity + AllocateBuffer(new_capacity); + // copy the old data + Write(old_data.get(), bytes_to_copy); +} + +void CachedFileHandle::Write(const char *buffer, idx_t length, idx_t offset) { + //! Only write to non-initialized files with a lock; + D_ASSERT(!file->initialized && lock); + memcpy(file->data.get() + offset, buffer, length); +} + +void HTTPState::Reset() { + // Reset Counters + head_count = 0; + get_count = 0; + put_count = 0; + post_count = 0; + total_bytes_received = 0; + total_bytes_sent = 0; + + // Reset cached files + cached_files.clear(); +} + +shared_ptr HTTPState::TryGetState(FileOpener *opener) { + auto client_context = FileOpener::TryGetClientContext(opener); + if (client_context) { + return client_context->client_data->http_state; + } + return nullptr; +} + +//! Get cache entry, create if not exists +shared_ptr &HTTPState::GetCachedFile(const string &path) { + lock_guard lock(cached_files_mutex); + auto &cache_entry_ref = cached_files[path]; + if (!cache_entry_ref) { + cache_entry_ref = make_shared(); + } + return cache_entry_ref; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp new file mode 100644 index 00000000..ccf3e1cb --- /dev/null +++ b/src/duckdb/src/common/local_file_system.cpp @@ -0,0 +1,1043 @@ +#include "duckdb/common/local_file_system.hpp" + +#include "duckdb/common/checksum.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/windows.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" + +#include +#include +#include + +#ifndef _WIN32 +#include +#include +#include +#include +#include +#else +#include "duckdb/common/windows_util.hpp" + +#include +#include + +#ifdef __MINGW32__ +// need to manually define this for mingw +extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG); +#endif + +#undef FILE_CREATE // woo mingw +#endif + +namespace duckdb { + +static void AssertValidFileFlags(uint8_t flags) { +#ifdef DEBUG + bool is_read = flags & FileFlags::FILE_FLAGS_READ; + bool is_write = flags & FileFlags::FILE_FLAGS_WRITE; + // require either READ or WRITE (or both) + D_ASSERT(is_read || is_write); + // CREATE/Append flags require writing + D_ASSERT(is_write || !(flags & FileFlags::FILE_FLAGS_APPEND)); + D_ASSERT(is_write || !(flags & FileFlags::FILE_FLAGS_FILE_CREATE)); + D_ASSERT(is_write || !(flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW)); + // cannot combine CREATE and CREATE_NEW flags + D_ASSERT(!(flags & FileFlags::FILE_FLAGS_FILE_CREATE && flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW)); +#endif +} + +#ifndef _WIN32 +bool LocalFileSystem::FileExists(const string &filename) { + if (!filename.empty()) { + if (access(filename.c_str(), 0) == 0) { + struct stat status; + stat(filename.c_str(), &status); + if (S_ISREG(status.st_mode)) { + return true; + } + } + } + // if any condition fails + return false; +} + +bool LocalFileSystem::IsPipe(const string &filename) { + if (!filename.empty()) { + if (access(filename.c_str(), 0) == 0) { + struct stat status; + stat(filename.c_str(), &status); + if (S_ISFIFO(status.st_mode)) { + return true; + } + } + } + // if any condition fails + return false; +} + +#else +bool LocalFileSystem::FileExists(const string &filename) { + auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); + const wchar_t *wpath = unicode_path.c_str(); + if (_waccess(wpath, 0) == 0) { + struct _stati64 status; + _wstati64(wpath, &status); + if (status.st_mode & S_IFREG) { + return true; + } + } + return false; +} +bool LocalFileSystem::IsPipe(const string &filename) { + auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); + const wchar_t *wpath = unicode_path.c_str(); + if (_waccess(wpath, 0) == 0) { + struct _stati64 status; + _wstati64(wpath, &status); + if (status.st_mode & _S_IFCHR) { + return true; + } + } + return false; +} +#endif + +#ifndef _WIN32 +// somehow sometimes this is missing +#ifndef O_CLOEXEC +#define O_CLOEXEC 0 +#endif + +// Solaris +#ifndef O_DIRECT +#define O_DIRECT 0 +#endif + +struct UnixFileHandle : public FileHandle { +public: + UnixFileHandle(FileSystem &file_system, string path, int fd) : FileHandle(file_system, std::move(path)), fd(fd) { + } + ~UnixFileHandle() override { + UnixFileHandle::Close(); + } + + int fd; + +public: + void Close() override { + if (fd != -1) { + close(fd); + fd = -1; + } + }; +}; + +static FileType GetFileTypeInternal(int fd) { // LCOV_EXCL_START + struct stat s; + if (fstat(fd, &s) == -1) { + return FileType::FILE_TYPE_INVALID; + } + switch (s.st_mode & S_IFMT) { + case S_IFBLK: + return FileType::FILE_TYPE_BLOCKDEV; + case S_IFCHR: + return FileType::FILE_TYPE_CHARDEV; + case S_IFIFO: + return FileType::FILE_TYPE_FIFO; + case S_IFDIR: + return FileType::FILE_TYPE_DIR; + case S_IFLNK: + return FileType::FILE_TYPE_LINK; + case S_IFREG: + return FileType::FILE_TYPE_REGULAR; + case S_IFSOCK: + return FileType::FILE_TYPE_SOCKET; + default: + return FileType::FILE_TYPE_INVALID; + } +} // LCOV_EXCL_STOP + +unique_ptr LocalFileSystem::OpenFile(const string &path_p, uint8_t flags, FileLockType lock_type, + FileCompressionType compression, FileOpener *opener) { + auto path = FileSystem::ExpandPath(path_p, opener); + if (compression != FileCompressionType::UNCOMPRESSED) { + throw NotImplementedException("Unsupported compression type for default file system"); + } + + AssertValidFileFlags(flags); + + int open_flags = 0; + int rc; + bool open_read = flags & FileFlags::FILE_FLAGS_READ; + bool open_write = flags & FileFlags::FILE_FLAGS_WRITE; + if (open_read && open_write) { + open_flags = O_RDWR; + } else if (open_read) { + open_flags = O_RDONLY; + } else if (open_write) { + open_flags = O_WRONLY; + } else { + throw InternalException("READ, WRITE or both should be specified when opening a file"); + } + if (open_write) { + // need Read or Write + D_ASSERT(flags & FileFlags::FILE_FLAGS_WRITE); + open_flags |= O_CLOEXEC; + if (flags & FileFlags::FILE_FLAGS_FILE_CREATE) { + open_flags |= O_CREAT; + } else if (flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW) { + open_flags |= O_CREAT | O_TRUNC; + } + if (flags & FileFlags::FILE_FLAGS_APPEND) { + open_flags |= O_APPEND; + } + } + if (flags & FileFlags::FILE_FLAGS_DIRECT_IO) { +#if defined(__sun) && defined(__SVR4) + throw Exception("DIRECT_IO not supported on Solaris"); +#endif +#if defined(__DARWIN__) || defined(__APPLE__) || defined(__OpenBSD__) + // OSX does not have O_DIRECT, instead we need to use fcntl afterwards to support direct IO + open_flags |= O_SYNC; +#else + open_flags |= O_DIRECT | O_SYNC; +#endif + } + int fd = open(path.c_str(), open_flags, 0666); + if (fd == -1) { + throw IOException("Cannot open file \"%s\": %s", path, strerror(errno)); + } + // #if defined(__DARWIN__) || defined(__APPLE__) + // if (flags & FileFlags::FILE_FLAGS_DIRECT_IO) { + // // OSX requires fcntl for Direct IO + // rc = fcntl(fd, F_NOCACHE, 1); + // if (fd == -1) { + // throw IOException("Could not enable direct IO for file \"%s\": %s", path, strerror(errno)); + // } + // } + // #endif + if (lock_type != FileLockType::NO_LOCK) { + // set lock on file + // but only if it is not an input/output stream + auto file_type = GetFileTypeInternal(fd); + if (file_type != FileType::FILE_TYPE_FIFO && file_type != FileType::FILE_TYPE_SOCKET) { + struct flock fl; + memset(&fl, 0, sizeof fl); + fl.l_type = lock_type == FileLockType::READ_LOCK ? F_RDLCK : F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 0; + rc = fcntl(fd, F_SETLK, &fl); + if (rc == -1) { + throw IOException("Could not set lock on file \"%s\": %s", path, strerror(errno)); + } + } + } + return make_uniq(*this, path, fd); +} + +void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { + int fd = handle.Cast().fd; + off_t offset = lseek(fd, location, SEEK_SET); + if (offset == (off_t)-1) { + throw IOException("Could not seek to location %lld for file \"%s\": %s", location, handle.path, + strerror(errno)); + } +} + +idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { + int fd = handle.Cast().fd; + off_t position = lseek(fd, 0, SEEK_CUR); + if (position == (off_t)-1) { + throw IOException("Could not get file position file \"%s\": %s", handle.path, strerror(errno)); + } + return position; +} + +void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + int fd = handle.Cast().fd; + auto read_buffer = char_ptr_cast(buffer); + while (nr_bytes > 0) { + int64_t bytes_read = pread(fd, read_buffer, nr_bytes, location); + if (bytes_read == -1) { + throw IOException("Could not read from file \"%s\": %s", handle.path, strerror(errno)); + } + if (bytes_read == 0) { + throw IOException( + "Could not read enough bytes from file \"%s\": attempted to read %llu bytes from location %llu", + handle.path, nr_bytes, location); + } + read_buffer += bytes_read; + nr_bytes -= bytes_read; + } +} + +int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + int fd = handle.Cast().fd; + int64_t bytes_read = read(fd, buffer, nr_bytes); + if (bytes_read == -1) { + throw IOException("Could not read from file \"%s\": %s", handle.path, strerror(errno)); + } + return bytes_read; +} + +void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + int fd = handle.Cast().fd; + auto write_buffer = char_ptr_cast(buffer); + while (nr_bytes > 0) { + int64_t bytes_written = pwrite(fd, write_buffer, nr_bytes, location); + if (bytes_written < 0) { + throw IOException("Could not write file \"%s\": %s", handle.path, strerror(errno)); + } + D_ASSERT(bytes_written >= 0 && bytes_written); + write_buffer += bytes_written; + nr_bytes -= bytes_written; + } +} + +int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + int fd = handle.Cast().fd; + int64_t bytes_written = write(fd, buffer, nr_bytes); + if (bytes_written == -1) { + throw IOException("Could not write file \"%s\": %s", handle.path, strerror(errno)); + } + return bytes_written; +} + +int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { + int fd = handle.Cast().fd; + struct stat s; + if (fstat(fd, &s) == -1) { + return -1; + } + return s.st_size; +} + +time_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { + int fd = handle.Cast().fd; + struct stat s; + if (fstat(fd, &s) == -1) { + return -1; + } + return s.st_mtime; +} + +FileType LocalFileSystem::GetFileType(FileHandle &handle) { + int fd = handle.Cast().fd; + return GetFileTypeInternal(fd); +} + +void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { + int fd = handle.Cast().fd; + if (ftruncate(fd, new_size) != 0) { + throw IOException("Could not truncate file \"%s\": %s", handle.path, strerror(errno)); + } +} + +bool LocalFileSystem::DirectoryExists(const string &directory) { + if (!directory.empty()) { + if (access(directory.c_str(), 0) == 0) { + struct stat status; + stat(directory.c_str(), &status); + if (status.st_mode & S_IFDIR) { + return true; + } + } + } + // if any condition fails + return false; +} + +void LocalFileSystem::CreateDirectory(const string &directory) { + struct stat st; + + if (stat(directory.c_str(), &st) != 0) { + /* Directory does not exist. EEXIST for race condition */ + if (mkdir(directory.c_str(), 0755) != 0 && errno != EEXIST) { + throw IOException("Failed to create directory \"%s\"!", directory); + } + } else if (!S_ISDIR(st.st_mode)) { + throw IOException("Failed to create directory \"%s\": path exists but is not a directory!", directory); + } +} + +int RemoveDirectoryRecursive(const char *path) { + DIR *d = opendir(path); + idx_t path_len = (idx_t)strlen(path); + int r = -1; + + if (d) { + struct dirent *p; + r = 0; + while (!r && (p = readdir(d))) { + int r2 = -1; + char *buf; + idx_t len; + /* Skip the names "." and ".." as we don't want to recurse on them. */ + if (!strcmp(p->d_name, ".") || !strcmp(p->d_name, "..")) { + continue; + } + len = path_len + (idx_t)strlen(p->d_name) + 2; + buf = new (std::nothrow) char[len]; + if (buf) { + struct stat statbuf; + snprintf(buf, len, "%s/%s", path, p->d_name); + if (!stat(buf, &statbuf)) { + if (S_ISDIR(statbuf.st_mode)) { + r2 = RemoveDirectoryRecursive(buf); + } else { + r2 = unlink(buf); + } + } + delete[] buf; + } + r = r2; + } + closedir(d); + } + if (!r) { + r = rmdir(path); + } + return r; +} + +void LocalFileSystem::RemoveDirectory(const string &directory) { + RemoveDirectoryRecursive(directory.c_str()); +} + +void LocalFileSystem::RemoveFile(const string &filename) { + if (std::remove(filename.c_str()) != 0) { + throw IOException("Could not remove file \"%s\": %s", filename, strerror(errno)); + } +} + +bool LocalFileSystem::ListFiles(const string &directory, const std::function &callback, + FileOpener *opener) { + if (!DirectoryExists(directory)) { + return false; + } + DIR *dir = opendir(directory.c_str()); + if (!dir) { + return false; + } + struct dirent *ent; + // loop over all files in the directory + while ((ent = readdir(dir)) != nullptr) { + string name = string(ent->d_name); + // skip . .. and empty files + if (name.empty() || name == "." || name == "..") { + continue; + } + // now stat the file to figure out if it is a regular file or directory + string full_path = JoinPath(directory, name); + if (access(full_path.c_str(), 0) != 0) { + continue; + } + struct stat status; + stat(full_path.c_str(), &status); + if (!(status.st_mode & S_IFREG) && !(status.st_mode & S_IFDIR)) { + // not a file or directory: skip + continue; + } + // invoke callback + callback(name, status.st_mode & S_IFDIR); + } + closedir(dir); + return true; +} + +void LocalFileSystem::FileSync(FileHandle &handle) { + int fd = handle.Cast().fd; + if (fsync(fd) != 0) { + throw FatalException("fsync failed!"); + } +} + +void LocalFileSystem::MoveFile(const string &source, const string &target) { + //! FIXME: rename does not guarantee atomicity or overwriting target file if it exists + if (rename(source.c_str(), target.c_str()) != 0) { + throw IOException("Could not rename file!"); + } +} + +std::string LocalFileSystem::GetLastErrorAsString() { + return string(); +} + +#else + +constexpr char PIPE_PREFIX[] = "\\\\.\\pipe\\"; + +// Returns the last Win32 error, in string format. Returns an empty string if there is no error. +std::string LocalFileSystem::GetLastErrorAsString() { + // Get the error message, if any. + DWORD errorMessageID = GetLastError(); + if (errorMessageID == 0) + return std::string(); // No error message has been recorded + + LPSTR messageBuffer = nullptr; + idx_t size = + FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, errorMessageID, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); + + std::string message(messageBuffer, size); + + // Free the buffer. + LocalFree(messageBuffer); + + return message; +} + +struct WindowsFileHandle : public FileHandle { +public: + WindowsFileHandle(FileSystem &file_system, string path, HANDLE fd) + : FileHandle(file_system, path), position(0), fd(fd) { + } + ~WindowsFileHandle() override { + Close(); + } + + idx_t position; + HANDLE fd; + +public: + void Close() override { + if (!fd) { + return; + } + CloseHandle(fd); + fd = nullptr; + }; +}; + +unique_ptr LocalFileSystem::OpenFile(const string &path_p, uint8_t flags, FileLockType lock_type, + FileCompressionType compression, FileOpener *opener) { + auto path = FileSystem::ExpandPath(path_p, opener); + if (compression != FileCompressionType::UNCOMPRESSED) { + throw NotImplementedException("Unsupported compression type for default file system"); + } + AssertValidFileFlags(flags); + + DWORD desired_access; + DWORD share_mode; + DWORD creation_disposition = OPEN_EXISTING; + DWORD flags_and_attributes = FILE_ATTRIBUTE_NORMAL; + bool open_read = flags & FileFlags::FILE_FLAGS_READ; + bool open_write = flags & FileFlags::FILE_FLAGS_WRITE; + if (open_read && open_write) { + desired_access = GENERIC_READ | GENERIC_WRITE; + share_mode = 0; + } else if (open_read) { + desired_access = GENERIC_READ; + share_mode = FILE_SHARE_READ; + } else if (open_write) { + desired_access = GENERIC_WRITE; + share_mode = 0; + } else { + throw InternalException("READ, WRITE or both should be specified when opening a file"); + } + if (open_write) { + if (flags & FileFlags::FILE_FLAGS_FILE_CREATE) { + creation_disposition = OPEN_ALWAYS; + } else if (flags & FileFlags::FILE_FLAGS_FILE_CREATE_NEW) { + creation_disposition = CREATE_ALWAYS; + } + } + if (flags & FileFlags::FILE_FLAGS_DIRECT_IO) { + flags_and_attributes |= FILE_FLAG_NO_BUFFERING; + } + auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str()); + HANDLE hFile = CreateFileW(unicode_path.c_str(), desired_access, share_mode, NULL, creation_disposition, + flags_and_attributes, NULL); + if (hFile == INVALID_HANDLE_VALUE) { + auto error = LocalFileSystem::GetLastErrorAsString(); + throw IOException("Cannot open file \"%s\": %s", path.c_str(), error); + } + auto handle = make_uniq(*this, path.c_str(), hFile); + if (flags & FileFlags::FILE_FLAGS_APPEND) { + auto file_size = GetFileSize(*handle); + SetFilePointer(*handle, file_size); + } + return std::move(handle); +} + +void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { + auto &whandle = handle.Cast(); + whandle.position = location; + LARGE_INTEGER wlocation; + wlocation.QuadPart = location; + SetFilePointerEx(whandle.fd, wlocation, NULL, FILE_BEGIN); +} + +idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { + return handle.Cast().position; +} + +static DWORD FSInternalRead(FileHandle &handle, HANDLE hFile, void *buffer, int64_t nr_bytes, idx_t location) { + DWORD bytes_read = 0; + OVERLAPPED ov = {}; + ov.Internal = 0; + ov.InternalHigh = 0; + ov.Offset = location & 0xFFFFFFFF; + ov.OffsetHigh = location >> 32; + ov.hEvent = 0; + auto rc = ReadFile(hFile, buffer, (DWORD)nr_bytes, &bytes_read, &ov); + if (!rc) { + auto error = LocalFileSystem::GetLastErrorAsString(); + throw IOException("Could not read file \"%s\" (error in ReadFile(location: %llu, nr_bytes: %lld)): %s", + handle.path, location, nr_bytes, error); + } + return bytes_read; +} + +void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + HANDLE hFile = ((WindowsFileHandle &)handle).fd; + auto bytes_read = FSInternalRead(handle, hFile, buffer, nr_bytes, location); + if (bytes_read != nr_bytes) { + throw IOException("Could not read all bytes from file \"%s\": wanted=%lld read=%lld", handle.path, nr_bytes, + bytes_read); + } +} + +int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + HANDLE hFile = handle.Cast().fd; + auto &pos = handle.Cast().position; + auto n = std::min(std::max(GetFileSize(handle), pos) - pos, nr_bytes); + auto bytes_read = FSInternalRead(handle, hFile, buffer, n, pos); + pos += bytes_read; + return bytes_read; +} + +static DWORD FSInternalWrite(FileHandle &handle, HANDLE hFile, void *buffer, int64_t nr_bytes, idx_t location) { + DWORD bytes_written = 0; + OVERLAPPED ov = {}; + ov.Internal = 0; + ov.InternalHigh = 0; + ov.Offset = location & 0xFFFFFFFF; + ov.OffsetHigh = location >> 32; + ov.hEvent = 0; + auto rc = WriteFile(hFile, buffer, (DWORD)nr_bytes, &bytes_written, &ov); + if (!rc) { + auto error = LocalFileSystem::GetLastErrorAsString(); + throw IOException("Could not write file \"%s\" (error in WriteFile): %s", handle.path, error); + } + return bytes_written; +} + +void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + HANDLE hFile = handle.Cast().fd; + auto bytes_written = FSInternalWrite(handle, hFile, buffer, nr_bytes, location); + if (bytes_written != nr_bytes) { + throw IOException("Could not write all bytes from file \"%s\": wanted=%lld wrote=%lld", handle.path, nr_bytes, + bytes_written); + } +} + +int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + HANDLE hFile = handle.Cast().fd; + auto &pos = handle.Cast().position; + auto bytes_written = FSInternalWrite(handle, hFile, buffer, nr_bytes, pos); + pos += bytes_written; + return bytes_written; +} + +int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { + HANDLE hFile = handle.Cast().fd; + LARGE_INTEGER result; + if (!GetFileSizeEx(hFile, &result)) { + return -1; + } + return result.QuadPart; +} + +time_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { + HANDLE hFile = handle.Cast().fd; + + // https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfiletime + FILETIME last_write; + if (GetFileTime(hFile, nullptr, nullptr, &last_write) == 0) { + return -1; + } + + // https://stackoverflow.com/questions/29266743/what-is-dwlowdatetime-and-dwhighdatetime + ULARGE_INTEGER ul; + ul.LowPart = last_write.dwLowDateTime; + ul.HighPart = last_write.dwHighDateTime; + int64_t fileTime64 = ul.QuadPart; + + // fileTime64 contains a 64-bit value representing the number of + // 100-nanosecond intervals since January 1, 1601 (UTC). + // https://docs.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-filetime + + // Adapted from: https://stackoverflow.com/questions/6161776/convert-windows-filetime-to-second-in-unix-linux + const auto WINDOWS_TICK = 10000000; + const auto SEC_TO_UNIX_EPOCH = 11644473600LL; + time_t result = (fileTime64 / WINDOWS_TICK - SEC_TO_UNIX_EPOCH); + return result; +} + +void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { + HANDLE hFile = handle.Cast().fd; + // seek to the location + SetFilePointer(handle, new_size); + // now set the end of file position + if (!SetEndOfFile(hFile)) { + auto error = LocalFileSystem::GetLastErrorAsString(); + throw IOException("Failure in SetEndOfFile call on file \"%s\": %s", handle.path, error); + } +} + +static DWORD WindowsGetFileAttributes(const string &filename) { + auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); + return GetFileAttributesW(unicode_path.c_str()); +} + +bool LocalFileSystem::DirectoryExists(const string &directory) { + DWORD attrs = WindowsGetFileAttributes(directory); + return (attrs != INVALID_FILE_ATTRIBUTES && (attrs & FILE_ATTRIBUTE_DIRECTORY)); +} + +void LocalFileSystem::CreateDirectory(const string &directory) { + if (DirectoryExists(directory)) { + return; + } + auto unicode_path = WindowsUtil::UTF8ToUnicode(directory.c_str()); + if (directory.empty() || !CreateDirectoryW(unicode_path.c_str(), NULL) || !DirectoryExists(directory)) { + throw IOException("Could not create directory: \'%s\'", directory.c_str()); + } +} + +static void DeleteDirectoryRecursive(FileSystem &fs, string directory) { + fs.ListFiles(directory, [&](const string &fname, bool is_directory) { + if (is_directory) { + DeleteDirectoryRecursive(fs, fs.JoinPath(directory, fname)); + } else { + fs.RemoveFile(fs.JoinPath(directory, fname)); + } + }); + auto unicode_path = WindowsUtil::UTF8ToUnicode(directory.c_str()); + if (!RemoveDirectoryW(unicode_path.c_str())) { + auto error = LocalFileSystem::GetLastErrorAsString(); + throw IOException("Failed to delete directory \"%s\": %s", directory, error); + } +} + +void LocalFileSystem::RemoveDirectory(const string &directory) { + if (FileExists(directory)) { + throw IOException("Attempting to delete directory \"%s\", but it is a file and not a directory!", directory); + } + if (!DirectoryExists(directory)) { + return; + } + DeleteDirectoryRecursive(*this, directory.c_str()); +} + +void LocalFileSystem::RemoveFile(const string &filename) { + auto unicode_path = WindowsUtil::UTF8ToUnicode(filename.c_str()); + if (!DeleteFileW(unicode_path.c_str())) { + auto error = LocalFileSystem::GetLastErrorAsString(); + throw IOException("Failed to delete file \"%s\": %s", filename, error); + } +} + +bool LocalFileSystem::ListFiles(const string &directory, const std::function &callback, + FileOpener *opener) { + string search_dir = JoinPath(directory, "*"); + + auto unicode_path = WindowsUtil::UTF8ToUnicode(search_dir.c_str()); + + WIN32_FIND_DATAW ffd; + HANDLE hFind = FindFirstFileW(unicode_path.c_str(), &ffd); + if (hFind == INVALID_HANDLE_VALUE) { + return false; + } + do { + string cFileName = WindowsUtil::UnicodeToUTF8(ffd.cFileName); + if (cFileName == "." || cFileName == "..") { + continue; + } + callback(cFileName, ffd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY); + } while (FindNextFileW(hFind, &ffd) != 0); + + DWORD dwError = GetLastError(); + if (dwError != ERROR_NO_MORE_FILES) { + FindClose(hFind); + return false; + } + + FindClose(hFind); + return true; +} + +void LocalFileSystem::FileSync(FileHandle &handle) { + HANDLE hFile = handle.Cast().fd; + if (FlushFileBuffers(hFile) == 0) { + throw IOException("Could not flush file handle to disk!"); + } +} + +void LocalFileSystem::MoveFile(const string &source, const string &target) { + auto source_unicode = WindowsUtil::UTF8ToUnicode(source.c_str()); + auto target_unicode = WindowsUtil::UTF8ToUnicode(target.c_str()); + if (!MoveFileW(source_unicode.c_str(), target_unicode.c_str())) { + throw IOException("Could not move file: %s", GetLastErrorAsString()); + } +} + +FileType LocalFileSystem::GetFileType(FileHandle &handle) { + auto path = handle.Cast().path; + // pipes in windows are just files in '\\.\pipe\' folder + if (strncmp(path.c_str(), PIPE_PREFIX, strlen(PIPE_PREFIX)) == 0) { + return FileType::FILE_TYPE_FIFO; + } + DWORD attrs = WindowsGetFileAttributes(path.c_str()); + if (attrs != INVALID_FILE_ATTRIBUTES) { + if (attrs & FILE_ATTRIBUTE_DIRECTORY) { + return FileType::FILE_TYPE_DIR; + } else { + return FileType::FILE_TYPE_REGULAR; + } + } + return FileType::FILE_TYPE_INVALID; +} +#endif + +bool LocalFileSystem::CanSeek() { + return true; +} + +bool LocalFileSystem::OnDiskFile(FileHandle &handle) { + return true; +} + +void LocalFileSystem::Seek(FileHandle &handle, idx_t location) { + if (!CanSeek()) { + throw IOException("Cannot seek in files of this type"); + } + SetFilePointer(handle, location); +} + +idx_t LocalFileSystem::SeekPosition(FileHandle &handle) { + if (!CanSeek()) { + throw IOException("Cannot seek in files of this type"); + } + return GetFilePointer(handle); +} + +static bool IsCrawl(const string &glob) { + // glob must match exactly + return glob == "**"; +} +static bool HasMultipleCrawl(const vector &splits) { + return std::count(splits.begin(), splits.end(), "**") > 1; +} +static bool IsSymbolicLink(const string &path) { +#ifndef _WIN32 + struct stat status; + return (lstat(path.c_str(), &status) != -1 && S_ISLNK(status.st_mode)); +#else + auto attributes = WindowsGetFileAttributes(path); + if (attributes == INVALID_FILE_ATTRIBUTES) + return false; + return attributes & FILE_ATTRIBUTE_REPARSE_POINT; +#endif +} + +static void RecursiveGlobDirectories(FileSystem &fs, const string &path, vector &result, bool match_directory, + bool join_path) { + + fs.ListFiles(path, [&](const string &fname, bool is_directory) { + string concat; + if (join_path) { + concat = fs.JoinPath(path, fname); + } else { + concat = fname; + } + if (IsSymbolicLink(concat)) { + return; + } + if (is_directory == match_directory) { + result.push_back(concat); + } + if (is_directory) { + RecursiveGlobDirectories(fs, concat, result, match_directory, true); + } + }); +} + +static void GlobFilesInternal(FileSystem &fs, const string &path, const string &glob, bool match_directory, + vector &result, bool join_path) { + fs.ListFiles(path, [&](const string &fname, bool is_directory) { + if (is_directory != match_directory) { + return; + } + if (LikeFun::Glob(fname.c_str(), fname.size(), glob.c_str(), glob.size())) { + if (join_path) { + result.push_back(fs.JoinPath(path, fname)); + } else { + result.push_back(fname); + } + } + }); +} + +vector LocalFileSystem::FetchFileWithoutGlob(const string &path, FileOpener *opener, bool absolute_path) { + vector result; + if (FileExists(path) || IsPipe(path)) { + result.push_back(path); + } else if (!absolute_path) { + Value value; + if (opener && opener->TryGetCurrentSetting("file_search_path", value)) { + auto search_paths_str = value.ToString(); + vector search_paths = StringUtil::Split(search_paths_str, ','); + for (const auto &search_path : search_paths) { + auto joined_path = JoinPath(search_path, path); + if (FileExists(joined_path) || IsPipe(joined_path)) { + result.push_back(joined_path); + } + } + } + } + return result; +} + +vector LocalFileSystem::Glob(const string &path, FileOpener *opener) { + if (path.empty()) { + return vector(); + } + // split up the path into separate chunks + vector splits; + idx_t last_pos = 0; + for (idx_t i = 0; i < path.size(); i++) { + if (path[i] == '\\' || path[i] == '/') { + if (i == last_pos) { + // empty: skip this position + last_pos = i + 1; + continue; + } + if (splits.empty()) { + splits.push_back(path.substr(0, i)); + } else { + splits.push_back(path.substr(last_pos, i - last_pos)); + } + last_pos = i + 1; + } + } + splits.push_back(path.substr(last_pos, path.size() - last_pos)); + // handle absolute paths + bool absolute_path = false; + if (path[0] == '/') { + // first character is a slash - unix absolute path + absolute_path = true; + } else if (StringUtil::Contains(splits[0], ":")) { + // first split has a colon - windows absolute path + absolute_path = true; + } else if (splits[0] == "~") { + // starts with home directory + auto home_directory = GetHomeDirectory(opener); + if (!home_directory.empty()) { + absolute_path = true; + splits[0] = home_directory; + D_ASSERT(path[0] == '~'); + if (!HasGlob(path)) { + return Glob(home_directory + path.substr(1)); + } + } + } + // Check if the path has a glob at all + if (!HasGlob(path)) { + // no glob: return only the file (if it exists or is a pipe) + return FetchFileWithoutGlob(path, opener, absolute_path); + } + vector previous_directories; + if (absolute_path) { + // for absolute paths, we don't start by scanning the current directory + previous_directories.push_back(splits[0]); + } else { + // If file_search_path is set, use those paths as the first glob elements + Value value; + if (opener && opener->TryGetCurrentSetting("file_search_path", value)) { + auto search_paths_str = value.ToString(); + vector search_paths = StringUtil::Split(search_paths_str, ','); + for (const auto &search_path : search_paths) { + previous_directories.push_back(search_path); + } + } + } + + if (HasMultipleCrawl(splits)) { + throw IOException("Cannot use multiple \'**\' in one path"); + } + + for (idx_t i = absolute_path ? 1 : 0; i < splits.size(); i++) { + bool is_last_chunk = i + 1 == splits.size(); + bool has_glob = HasGlob(splits[i]); + // if it's the last chunk we need to find files, otherwise we find directories + // not the last chunk: gather a list of all directories that match the glob pattern + vector result; + if (!has_glob) { + // no glob, just append as-is + if (previous_directories.empty()) { + result.push_back(splits[i]); + } else { + if (is_last_chunk) { + for (auto &prev_directory : previous_directories) { + const string filename = JoinPath(prev_directory, splits[i]); + if (FileExists(filename) || DirectoryExists(filename)) { + result.push_back(filename); + } + } + } else { + for (auto &prev_directory : previous_directories) { + result.push_back(JoinPath(prev_directory, splits[i])); + } + } + } + } else { + if (IsCrawl(splits[i])) { + if (!is_last_chunk) { + result = previous_directories; + } + if (previous_directories.empty()) { + RecursiveGlobDirectories(*this, ".", result, !is_last_chunk, false); + } else { + for (auto &prev_dir : previous_directories) { + RecursiveGlobDirectories(*this, prev_dir, result, !is_last_chunk, true); + } + } + } else { + if (previous_directories.empty()) { + // no previous directories: list in the current path + GlobFilesInternal(*this, ".", splits[i], !is_last_chunk, result, false); + } else { + // previous directories + // we iterate over each of the previous directories, and apply the glob of the current directory + for (auto &prev_directory : previous_directories) { + GlobFilesInternal(*this, prev_directory, splits[i], !is_last_chunk, result, true); + } + } + } + } + if (result.empty()) { + // no result found that matches the glob + // last ditch effort: search the path as a string literal + return FetchFileWithoutGlob(path, opener, absolute_path); + } + if (is_last_chunk) { + return result; + } + previous_directories = std::move(result); + } + return vector(); +} + +unique_ptr FileSystem::CreateLocal() { + return make_uniq(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/multi_file_reader.cpp b/src/duckdb/src/common/multi_file_reader.cpp new file mode 100644 index 00000000..c68d950c --- /dev/null +++ b/src/duckdb/src/common/multi_file_reader.cpp @@ -0,0 +1,498 @@ +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/common/hive_partitioning.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +void MultiFileReader::AddParameters(TableFunction &table_function) { + table_function.named_parameters["filename"] = LogicalType::BOOLEAN; + table_function.named_parameters["hive_partitioning"] = LogicalType::BOOLEAN; + table_function.named_parameters["union_by_name"] = LogicalType::BOOLEAN; + table_function.named_parameters["hive_types"] = LogicalType::ANY; + table_function.named_parameters["hive_types_autocast"] = LogicalType::BOOLEAN; +} + +vector MultiFileReader::GetFileList(ClientContext &context, const Value &input, const string &name, + FileGlobOptions options) { + auto &config = DBConfig::GetConfig(context); + if (!config.options.enable_external_access) { + throw PermissionException("Scanning %s files is disabled through configuration", name); + } + if (input.IsNull()) { + throw ParserException("%s reader cannot take NULL list as parameter", name); + } + FileSystem &fs = FileSystem::GetFileSystem(context); + vector files; + if (input.type().id() == LogicalTypeId::VARCHAR) { + auto file_name = StringValue::Get(input); + files = fs.GlobFiles(file_name, context, options); + } else if (input.type().id() == LogicalTypeId::LIST) { + for (auto &val : ListValue::GetChildren(input)) { + if (val.IsNull()) { + throw ParserException("%s reader cannot take NULL input as parameter", name); + } + if (val.type().id() != LogicalTypeId::VARCHAR) { + throw ParserException("%s reader can only take a list of strings as a parameter", name); + } + auto glob_files = fs.GlobFiles(StringValue::Get(val), context, options); + files.insert(files.end(), glob_files.begin(), glob_files.end()); + } + } else { + throw InternalException("Unsupported type for MultiFileReader::GetFileList"); + } + if (files.empty() && options == FileGlobOptions::DISALLOW_EMPTY) { + throw IOException("%s reader needs at least one file to read", name); + } + return files; +} + +bool MultiFileReader::ParseOption(const string &key, const Value &val, MultiFileReaderOptions &options, + ClientContext &context) { + auto loption = StringUtil::Lower(key); + if (loption == "filename") { + options.filename = BooleanValue::Get(val); + } else if (loption == "hive_partitioning") { + options.hive_partitioning = BooleanValue::Get(val); + options.auto_detect_hive_partitioning = false; + } else if (loption == "union_by_name") { + options.union_by_name = BooleanValue::Get(val); + } else if (loption == "hive_types_autocast" || loption == "hive_type_autocast") { + options.hive_types_autocast = BooleanValue::Get(val); + } else if (loption == "hive_types" || loption == "hive_type") { + if (val.type().id() != LogicalTypeId::STRUCT) { + throw InvalidInputException( + "'hive_types' only accepts a STRUCT('name':VARCHAR, ...), but '%s' was provided", + val.type().ToString()); + } + // verify that that all the children of the struct value are VARCHAR + auto &children = StructValue::GetChildren(val); + for (idx_t i = 0; i < children.size(); i++) { + const Value &child = children[i]; + if (child.type().id() != LogicalType::VARCHAR) { + throw InvalidInputException("hive_types: '%s' must be a VARCHAR, instead: '%s' was provided", + StructType::GetChildName(val.type(), i), child.type().ToString()); + } + // for every child of the struct, get the logical type + LogicalType transformed_type = TransformStringToLogicalType(child.ToString(), context); + const string &name = StructType::GetChildName(val.type(), i); + options.hive_types_schema[name] = transformed_type; + } + D_ASSERT(!options.hive_types_schema.empty()); + } else { + return false; + } + return true; +} + +bool MultiFileReader::ComplexFilterPushdown(ClientContext &context, vector &files, + const MultiFileReaderOptions &options, LogicalGet &get, + vector> &filters) { + if (files.empty()) { + return false; + } + if (!options.hive_partitioning && !options.filename) { + return false; + } + + unordered_map column_map; + for (idx_t i = 0; i < get.column_ids.size(); i++) { + column_map.insert({get.names[get.column_ids[i]], i}); + } + + auto start_files = files.size(); + HivePartitioning::ApplyFiltersToFileList(context, files, filters, column_map, get, options.hive_partitioning, + options.filename); + + if (files.size() != start_files) { + // we have pruned files + return true; + } + return false; +} + +MultiFileReaderBindData MultiFileReader::BindOptions(MultiFileReaderOptions &options, const vector &files, + vector &return_types, vector &names) { + MultiFileReaderBindData bind_data; + // Add generated constant column for filename + if (options.filename) { + if (std::find(names.begin(), names.end(), "filename") != names.end()) { + throw BinderException("Using filename option on file with column named filename is not supported"); + } + bind_data.filename_idx = names.size(); + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("filename"); + } + + // Add generated constant columns from hive partitioning scheme + if (options.hive_partitioning) { + D_ASSERT(!files.empty()); + auto partitions = HivePartitioning::Parse(files[0]); + // verify that all files have the same hive partitioning scheme + for (auto &f : files) { + auto file_partitions = HivePartitioning::Parse(f); + for (auto &part_info : partitions) { + if (file_partitions.find(part_info.first) == file_partitions.end()) { + string error = "Hive partition mismatch between file \"%s\" and \"%s\": key \"%s\" not found"; + if (options.auto_detect_hive_partitioning == true) { + throw InternalException(error + "(hive partitioning was autodetected)", files[0], f, + part_info.first); + } + throw BinderException(error.c_str(), files[0], f, part_info.first); + } + } + if (partitions.size() != file_partitions.size()) { + string error_msg = "Hive partition mismatch between file \"%s\" and \"%s\""; + if (options.auto_detect_hive_partitioning == true) { + throw InternalException(error_msg + "(hive partitioning was autodetected)", files[0], f); + } + throw BinderException(error_msg.c_str(), files[0], f); + } + } + + if (!options.hive_types_schema.empty()) { + // verify that all hive_types are existing partitions + options.VerifyHiveTypesArePartitions(partitions); + } + + for (auto &part : partitions) { + idx_t hive_partitioning_index = DConstants::INVALID_INDEX; + auto lookup = std::find(names.begin(), names.end(), part.first); + if (lookup != names.end()) { + // hive partitioning column also exists in file - override + auto idx = lookup - names.begin(); + hive_partitioning_index = idx; + return_types[idx] = options.GetHiveLogicalType(part.first); + } else { + // hive partitioning column does not exist in file - add a new column containing the key + hive_partitioning_index = names.size(); + return_types.emplace_back(options.GetHiveLogicalType(part.first)); + names.emplace_back(part.first); + } + bind_data.hive_partitioning_indexes.emplace_back(part.first, hive_partitioning_index); + } + } + return bind_data; +} + +void MultiFileReader::FinalizeBind(const MultiFileReaderOptions &file_options, const MultiFileReaderBindData &options, + const string &filename, const vector &local_names, + const vector &global_types, const vector &global_names, + const vector &global_column_ids, MultiFileReaderData &reader_data, + ClientContext &context) { + + // create a map of name -> column index + case_insensitive_map_t name_map; + if (file_options.union_by_name) { + for (idx_t col_idx = 0; col_idx < local_names.size(); col_idx++) { + name_map[local_names[col_idx]] = col_idx; + } + } + for (idx_t i = 0; i < global_column_ids.size(); i++) { + auto column_id = global_column_ids[i]; + if (IsRowIdColumnId(column_id)) { + // row-id + reader_data.constant_map.emplace_back(i, Value::BIGINT(42)); + continue; + } + if (column_id == options.filename_idx) { + // filename + reader_data.constant_map.emplace_back(i, Value(filename)); + continue; + } + if (!options.hive_partitioning_indexes.empty()) { + // hive partition constants + auto partitions = HivePartitioning::Parse(filename); + D_ASSERT(partitions.size() == options.hive_partitioning_indexes.size()); + bool found_partition = false; + for (auto &entry : options.hive_partitioning_indexes) { + if (column_id == entry.index) { + Value value = file_options.GetHivePartitionValue(partitions[entry.value], entry.value, context); + reader_data.constant_map.emplace_back(i, value); + found_partition = true; + break; + } + } + if (found_partition) { + continue; + } + } + if (file_options.union_by_name) { + auto &global_name = global_names[column_id]; + auto entry = name_map.find(global_name); + bool not_present_in_file = entry == name_map.end(); + if (not_present_in_file) { + // we need to project a column with name \"global_name\" - but it does not exist in the current file + // push a NULL value of the specified type + reader_data.constant_map.emplace_back(i, Value(global_types[column_id])); + continue; + } + } + } +} + +void MultiFileReader::CreateNameMapping(const string &file_name, const vector &local_types, + const vector &local_names, const vector &global_types, + const vector &global_names, const vector &global_column_ids, + MultiFileReaderData &reader_data, const string &initial_file) { + D_ASSERT(global_types.size() == global_names.size()); + D_ASSERT(local_types.size() == local_names.size()); + // we have expected types: create a map of name -> column index + case_insensitive_map_t name_map; + for (idx_t col_idx = 0; col_idx < local_names.size(); col_idx++) { + name_map[local_names[col_idx]] = col_idx; + } + for (idx_t i = 0; i < global_column_ids.size(); i++) { + // check if this is a constant column + bool constant = false; + for (auto &entry : reader_data.constant_map) { + if (entry.column_id == i) { + constant = true; + break; + } + } + if (constant) { + // this column is constant for this file + continue; + } + // not constant - look up the column in the name map + auto global_id = global_column_ids[i]; + if (global_id >= global_types.size()) { + throw InternalException( + "MultiFileReader::CreatePositionalMapping - global_id is out of range in global_types for this file"); + } + auto &global_name = global_names[global_id]; + auto entry = name_map.find(global_name); + if (entry == name_map.end()) { + string candidate_names; + for (auto &local_name : local_names) { + if (!candidate_names.empty()) { + candidate_names += ", "; + } + candidate_names += local_name; + } + throw IOException( + StringUtil::Format("Failed to read file \"%s\": schema mismatch in glob: column \"%s\" was read from " + "the original file \"%s\", but could not be found in file \"%s\".\nCandidate names: " + "%s\nIf you are trying to " + "read files with different schemas, try setting union_by_name=True", + file_name, global_name, initial_file, file_name, candidate_names)); + } + // we found the column in the local file - check if the types are the same + auto local_id = entry->second; + D_ASSERT(global_id < global_types.size()); + D_ASSERT(local_id < local_types.size()); + auto &global_type = global_types[global_id]; + auto &local_type = local_types[local_id]; + if (global_type != local_type) { + reader_data.cast_map[local_id] = global_type; + } + // the types are the same - create the mapping + reader_data.column_mapping.push_back(i); + reader_data.column_ids.push_back(local_id); + } + reader_data.empty_columns = reader_data.column_ids.empty(); +} + +void MultiFileReader::CreateMapping(const string &file_name, const vector &local_types, + const vector &local_names, const vector &global_types, + const vector &global_names, const vector &global_column_ids, + optional_ptr filters, MultiFileReaderData &reader_data, + const string &initial_file) { + CreateNameMapping(file_name, local_types, local_names, global_types, global_names, global_column_ids, reader_data, + initial_file); + if (filters) { + reader_data.filter_map.resize(global_types.size()); + for (idx_t c = 0; c < reader_data.column_mapping.size(); c++) { + auto map_index = reader_data.column_mapping[c]; + reader_data.filter_map[map_index].index = c; + reader_data.filter_map[map_index].is_constant = false; + } + for (idx_t c = 0; c < reader_data.constant_map.size(); c++) { + auto constant_index = reader_data.constant_map[c].column_id; + reader_data.filter_map[constant_index].index = c; + reader_data.filter_map[constant_index].is_constant = true; + } + } +} + +void MultiFileReader::FinalizeChunk(const MultiFileReaderBindData &bind_data, const MultiFileReaderData &reader_data, + DataChunk &chunk) { + // reference all the constants set up in MultiFileReader::FinalizeBind + for (auto &entry : reader_data.constant_map) { + chunk.data[entry.column_id].Reference(entry.value); + } + chunk.Verify(); +} + +TableFunctionSet MultiFileReader::CreateFunctionSet(TableFunction table_function) { + TableFunctionSet function_set(table_function.name); + function_set.AddFunction(table_function); + D_ASSERT(table_function.arguments.size() == 1 && table_function.arguments[0] == LogicalType::VARCHAR); + table_function.arguments[0] = LogicalType::LIST(LogicalType::VARCHAR); + function_set.AddFunction(std::move(table_function)); + return function_set; +} + +HivePartitioningIndex::HivePartitioningIndex(string value_p, idx_t index) : value(std::move(value_p)), index(index) { +} + +void MultiFileReaderOptions::AddBatchInfo(BindInfo &bind_info) const { + bind_info.InsertOption("filename", Value::BOOLEAN(filename)); + bind_info.InsertOption("hive_partitioning", Value::BOOLEAN(hive_partitioning)); + bind_info.InsertOption("auto_detect_hive_partitioning", Value::BOOLEAN(auto_detect_hive_partitioning)); + bind_info.InsertOption("union_by_name", Value::BOOLEAN(union_by_name)); + bind_info.InsertOption("hive_types_autocast", Value::BOOLEAN(hive_types_autocast)); +} + +void UnionByName::CombineUnionTypes(const vector &col_names, const vector &sql_types, + vector &union_col_types, vector &union_col_names, + case_insensitive_map_t &union_names_map) { + D_ASSERT(col_names.size() == sql_types.size()); + + for (idx_t col = 0; col < col_names.size(); ++col) { + auto union_find = union_names_map.find(col_names[col]); + + if (union_find != union_names_map.end()) { + // given same name , union_col's type must compatible with col's type + auto ¤t_type = union_col_types[union_find->second]; + LogicalType compatible_type; + compatible_type = LogicalType::MaxLogicalType(current_type, sql_types[col]); + union_col_types[union_find->second] = compatible_type; + } else { + union_names_map[col_names[col]] = union_col_names.size(); + union_col_names.emplace_back(col_names[col]); + union_col_types.emplace_back(sql_types[col]); + } + } +} + +bool MultiFileReaderOptions::AutoDetectHivePartitioningInternal(const vector &files, ClientContext &context) { + std::unordered_set partitions; + auto &fs = FileSystem::GetFileSystem(context); + + auto splits_first_file = StringUtil::Split(files.front(), fs.PathSeparator(files.front())); + if (splits_first_file.size() < 2) { + return false; + } + for (auto it = splits_first_file.begin(); it != splits_first_file.end(); it++) { + auto partition = StringUtil::Split(*it, "="); + if (partition.size() == 2) { + partitions.insert(partition.front()); + } + } + if (partitions.empty()) { + return false; + } + for (auto &file : files) { + auto splits = StringUtil::Split(file, fs.PathSeparator(file)); + if (splits.size() != splits_first_file.size()) { + return false; + } + for (auto it = splits.begin(); it != std::prev(splits.end()); it++) { + auto part = StringUtil::Split(*it, "="); + if (part.size() != 2) { + continue; + } + if (partitions.find(part.front()) == partitions.end()) { + return false; + } + } + } + return true; +} +void MultiFileReaderOptions::AutoDetectHiveTypesInternal(const string &file, ClientContext &context) { + auto &fs = FileSystem::GetFileSystem(context); + + std::map partitions; + auto splits = StringUtil::Split(file, fs.PathSeparator(file)); + if (splits.size() < 2) { + return; + } + for (auto it = splits.begin(); it != std::prev(splits.end()); it++) { + auto part = StringUtil::Split(*it, "="); + if (part.size() == 2) { + partitions[part.front()] = part.back(); + } + } + if (partitions.empty()) { + return; + } + + const LogicalType candidates[] = {LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::BIGINT}; + for (auto &part : partitions) { + const string &name = part.first; + if (hive_types_schema.find(name) != hive_types_schema.end()) { + continue; + } + Value value(part.second); + for (auto &candidate : candidates) { + const bool success = value.TryCastAs(context, candidate); + if (success) { + hive_types_schema[name] = candidate; + break; + } + } + } +} +void MultiFileReaderOptions::AutoDetectHivePartitioning(const vector &files, ClientContext &context) { + D_ASSERT(!files.empty()); + const bool hp_explicitly_disabled = !auto_detect_hive_partitioning && !hive_partitioning; + const bool ht_enabled = !hive_types_schema.empty(); + if (hp_explicitly_disabled && ht_enabled) { + throw InvalidInputException("cannot disable hive_partitioning when hive_types is enabled"); + } + if (ht_enabled && auto_detect_hive_partitioning && !hive_partitioning) { + // hive_types flag implies hive_partitioning + hive_partitioning = true; + auto_detect_hive_partitioning = false; + } + if (auto_detect_hive_partitioning) { + hive_partitioning = AutoDetectHivePartitioningInternal(files, context); + } + if (hive_partitioning && hive_types_autocast) { + AutoDetectHiveTypesInternal(files.front(), context); + } +} +void MultiFileReaderOptions::VerifyHiveTypesArePartitions(const std::map &partitions) const { + for (auto &hive_type : hive_types_schema) { + if (partitions.find(hive_type.first) == partitions.end()) { + throw InvalidInputException("Unknown hive_type: \"%s\" does not appear to be a partition", hive_type.first); + } + } +} +LogicalType MultiFileReaderOptions::GetHiveLogicalType(const string &hive_partition_column) const { + if (!hive_types_schema.empty()) { + auto it = hive_types_schema.find(hive_partition_column); + if (it != hive_types_schema.end()) { + return it->second; + } + } + return LogicalType::VARCHAR; +} +Value MultiFileReaderOptions::GetHivePartitionValue(const string &base, const string &entry, + ClientContext &context) const { + Value value(base); + auto it = hive_types_schema.find(entry); + if (it == hive_types_schema.end()) { + return value; + } + + // Handle nulls + if (base.empty() || StringUtil::CIEquals(base, "NULL")) { + return Value(it->second); + } + + if (!value.TryCastAs(context, it->second)) { + throw InvalidInputException("Unable to cast '%s' (from hive partition column '%s') to: '%s'", value.ToString(), + StringUtil::Upper(it->first), it->second.ToString()); + } + return value; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/operator/cast_operators.cpp b/src/duckdb/src/common/operator/cast_operators.cpp new file mode 100644 index 00000000..8de8cd5b --- /dev/null +++ b/src/duckdb/src/common/operator/cast_operators.cpp @@ -0,0 +1,2777 @@ +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/operator/string_cast.hpp" +#include "duckdb/common/operator/numeric_cast.hpp" +#include "duckdb/common/operator/decimal_cast_operators.hpp" +#include "duckdb/common/operator/multiply.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/blob.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/uuid.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types.hpp" +#include "fast_float/fast_float.h" +#include "fmt/format.h" +#include "duckdb/common/types/bit.hpp" + +#include +#include +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Cast bool -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(bool input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(bool input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast int8_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(int8_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int8_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast int16_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(int16_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int16_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast int32_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(int32_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int32_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast int64_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(int64_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(int64_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast hugeint_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(hugeint_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(hugeint_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast uint8_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(uint8_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint8_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast uint16_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(uint16_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint16_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast uint32_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(uint32_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint32_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast uint64_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(uint64_t input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(uint64_t input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast float -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(float input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(float input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast double -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(double input, bool &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, int8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, int16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, int32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, int64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, hugeint_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, uint8_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, uint16_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, uint32_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, uint64_t &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, float &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +template <> +bool TryCast::Operation(double input, double &result, bool strict) { + return NumericTryCast::Operation(input, result, strict); +} + +//===--------------------------------------------------------------------===// +// Cast String -> Numeric +//===--------------------------------------------------------------------===// +template +struct IntegerCastData { + using Result = T; + Result result; + bool seen_decimal; +}; + +struct IntegerCastOperation { + template + static bool HandleDigit(T &state, uint8_t digit) { + using result_t = typename T::Result; + if (NEGATIVE) { + if (state.result < (NumericLimits::Minimum() + digit) / 10) { + return false; + } + state.result = state.result * 10 - digit; + } else { + if (state.result > (NumericLimits::Maximum() - digit) / 10) { + return false; + } + state.result = state.result * 10 + digit; + } + return true; + } + + template + static bool HandleHexDigit(T &state, uint8_t digit) { + using result_t = typename T::Result; + if (state.result > (NumericLimits::Maximum() - digit) / 16) { + return false; + } + state.result = state.result * 16 + digit; + return true; + } + + template + static bool HandleBinaryDigit(T &state, uint8_t digit) { + using result_t = typename T::Result; + if (state.result > (NumericLimits::Maximum() - digit) / 2) { + return false; + } + state.result = state.result * 2 + digit; + return true; + } + + template + static bool HandleExponent(T &state, int32_t exponent) { + using result_t = typename T::Result; + double dbl_res = state.result * std::pow(10.0L, exponent); + if (dbl_res < (double)NumericLimits::Minimum() || + dbl_res > (double)NumericLimits::Maximum()) { + return false; + } + state.result = (result_t)std::nearbyint(dbl_res); + return true; + } + + template + static bool HandleDecimal(T &state, uint8_t digit) { + if (state.seen_decimal) { + return true; + } + state.seen_decimal = true; + // round the integer based on what is after the decimal point + // if digit >= 5, then we round up (or down in case of negative numbers) + auto increment = digit >= 5; + if (!increment) { + return true; + } + if (NEGATIVE) { + if (state.result == NumericLimits::Minimum()) { + return false; + } + state.result--; + } else { + if (state.result == NumericLimits::Maximum()) { + return false; + } + state.result++; + } + return true; + } + + template + static bool Finalize(T &state) { + return true; + } +}; + +template +static bool IntegerCastLoop(const char *buf, idx_t len, T &result, bool strict) { + idx_t start_pos; + if (NEGATIVE) { + start_pos = 1; + } else { + if (*buf == '+') { + if (strict) { + // leading plus is not allowed in strict mode + return false; + } + start_pos = 1; + } else { + start_pos = 0; + } + } + idx_t pos = start_pos; + while (pos < len) { + if (!StringUtil::CharacterIsDigit(buf[pos])) { + // not a digit! + if (buf[pos] == decimal_separator) { + if (strict) { + return false; + } + bool number_before_period = pos > start_pos; + // decimal point: we accept decimal values for integers as well + // we just truncate them + // make sure everything after the period is a number + pos++; + idx_t start_digit = pos; + while (pos < len) { + if (!StringUtil::CharacterIsDigit(buf[pos])) { + break; + } + if (!OP::template HandleDecimal(result, buf[pos] - '0')) { + return false; + } + pos++; + } + // make sure there is either (1) one number after the period, or (2) one number before the period + // i.e. we accept "1." and ".1" as valid numbers, but not "." + if (!(number_before_period || pos > start_digit)) { + return false; + } + if (pos >= len) { + break; + } + } + if (StringUtil::CharacterIsSpace(buf[pos])) { + // skip any trailing spaces + while (++pos < len) { + if (!StringUtil::CharacterIsSpace(buf[pos])) { + return false; + } + } + break; + } + if (ALLOW_EXPONENT) { + if (buf[pos] == 'e' || buf[pos] == 'E') { + if (pos == start_pos) { + return false; + } + pos++; + if (pos >= len) { + return false; + } + using ExponentData = IntegerCastData; + ExponentData exponent {0, false}; + int negative = buf[pos] == '-'; + if (negative) { + if (!IntegerCastLoop( + buf + pos, len - pos, exponent, strict)) { + return false; + } + } else { + if (!IntegerCastLoop( + buf + pos, len - pos, exponent, strict)) { + return false; + } + } + return OP::template HandleExponent(result, exponent.result); + } + } + return false; + } + uint8_t digit = buf[pos++] - '0'; + if (!OP::template HandleDigit(result, digit)) { + return false; + } + } + if (!OP::template Finalize(result)) { + return false; + } + return pos > start_pos; +} + +template +static bool IntegerHexCastLoop(const char *buf, idx_t len, T &result, bool strict) { + if (ALLOW_EXPONENT || NEGATIVE) { + return false; + } + idx_t start_pos = 1; + idx_t pos = start_pos; + char current_char; + while (pos < len) { + current_char = StringUtil::CharacterToLower(buf[pos]); + if (!StringUtil::CharacterIsHex(current_char)) { + return false; + } + uint8_t digit; + if (current_char >= 'a') { + digit = current_char - 'a' + 10; + } else { + digit = current_char - '0'; + } + pos++; + if (!OP::template HandleHexDigit(result, digit)) { + return false; + } + } + if (!OP::template Finalize(result)) { + return false; + } + return pos > start_pos; +} + +template +static bool IntegerBinaryCastLoop(const char *buf, idx_t len, T &result, bool strict) { + if (ALLOW_EXPONENT || NEGATIVE) { + return false; + } + idx_t start_pos = 1; + idx_t pos = start_pos; + uint8_t digit; + char current_char; + while (pos < len) { + current_char = buf[pos]; + if (current_char == '_' && pos > start_pos) { + // skip underscore, if it is not the first character + pos++; + if (pos == len) { + // we cant end on an underscore either + return false; + } + continue; + } else if (current_char == '0') { + digit = 0; + } else if (current_char == '1') { + digit = 1; + } else { + return false; + } + pos++; + if (!OP::template HandleBinaryDigit(result, digit)) { + return false; + } + } + if (!OP::template Finalize(result)) { + return false; + } + return pos > start_pos; +} + +template +static bool TryIntegerCast(const char *buf, idx_t len, T &result, bool strict) { + // skip any spaces at the start + while (len > 0 && StringUtil::CharacterIsSpace(*buf)) { + buf++; + len--; + } + if (len == 0) { + return false; + } + if (ZERO_INITIALIZE) { + memset(&result, 0, sizeof(T)); + } + // if the number is negative, we set the negative flag and skip the negative sign + if (*buf == '-') { + if (!IS_SIGNED) { + // Need to check if its not -0 + idx_t pos = 1; + while (pos < len) { + if (buf[pos++] != '0') { + return false; + } + } + } + return IntegerCastLoop(buf, len, result, strict); + } + if (len > 1 && *buf == '0') { + if (buf[1] == 'x' || buf[1] == 'X') { + // If it starts with 0x or 0X, we parse it as a hex value + buf++; + len--; + return IntegerHexCastLoop(buf, len, result, strict); + } else if (buf[1] == 'b' || buf[1] == 'B') { + // If it starts with 0b or 0B, we parse it as a binary value + buf++; + len--; + return IntegerBinaryCastLoop(buf, len, result, strict); + } else if (strict && StringUtil::CharacterIsDigit(buf[1])) { + // leading zeros are not allowed in strict mode + return false; + } + } + return IntegerCastLoop(buf, len, result, strict); +} + +template +static inline bool TrySimpleIntegerCast(const char *buf, idx_t len, T &result, bool strict) { + IntegerCastData data; + if (TryIntegerCast, IS_SIGNED>(buf, len, data, strict)) { + result = data.result; + return true; + } + return false; +} + +template <> +bool TryCast::Operation(string_t input, bool &result, bool strict) { + auto input_data = input.GetData(); + auto input_size = input.GetSize(); + + switch (input_size) { + case 1: { + char c = std::tolower(*input_data); + if (c == 't' || (!strict && c == '1')) { + result = true; + return true; + } else if (c == 'f' || (!strict && c == '0')) { + result = false; + return true; + } + return false; + } + case 4: { + char t = std::tolower(input_data[0]); + char r = std::tolower(input_data[1]); + char u = std::tolower(input_data[2]); + char e = std::tolower(input_data[3]); + if (t == 't' && r == 'r' && u == 'u' && e == 'e') { + result = true; + return true; + } + return false; + } + case 5: { + char f = std::tolower(input_data[0]); + char a = std::tolower(input_data[1]); + char l = std::tolower(input_data[2]); + char s = std::tolower(input_data[3]); + char e = std::tolower(input_data[4]); + if (f == 'f' && a == 'a' && l == 'l' && s == 's' && e == 'e') { + result = false; + return true; + } + return false; + } + default: + return false; + } +} +template <> +bool TryCast::Operation(string_t input, int8_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} +template <> +bool TryCast::Operation(string_t input, int16_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} +template <> +bool TryCast::Operation(string_t input, int32_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} +template <> +bool TryCast::Operation(string_t input, int64_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} + +template <> +bool TryCast::Operation(string_t input, uint8_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} +template <> +bool TryCast::Operation(string_t input, uint16_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} +template <> +bool TryCast::Operation(string_t input, uint32_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} +template <> +bool TryCast::Operation(string_t input, uint64_t &result, bool strict) { + return TrySimpleIntegerCast(input.GetData(), input.GetSize(), result, strict); +} + +template +static bool TryDoubleCast(const char *buf, idx_t len, T &result, bool strict) { + // skip any spaces at the start + while (len > 0 && StringUtil::CharacterIsSpace(*buf)) { + buf++; + len--; + } + if (len == 0) { + return false; + } + if (*buf == '+') { + if (strict) { + // plus is not allowed in strict mode + return false; + } + buf++; + len--; + } + if (strict && len >= 2) { + if (buf[0] == '0' && StringUtil::CharacterIsDigit(buf[1])) { + // leading zeros are not allowed in strict mode + return false; + } + } + auto endptr = buf + len; + auto parse_result = duckdb_fast_float::from_chars(buf, buf + len, result, decimal_separator); + if (parse_result.ec != std::errc()) { + return false; + } + auto current_end = parse_result.ptr; + if (!strict) { + while (current_end < endptr && StringUtil::CharacterIsSpace(*current_end)) { + current_end++; + } + } + return current_end == endptr; +} + +template <> +bool TryCast::Operation(string_t input, float &result, bool strict) { + return TryDoubleCast(input.GetData(), input.GetSize(), result, strict); +} + +template <> +bool TryCast::Operation(string_t input, double &result, bool strict) { + return TryDoubleCast(input.GetData(), input.GetSize(), result, strict); +} + +template <> +bool TryCastErrorMessageCommaSeparated::Operation(string_t input, float &result, string *error_message, bool strict) { + if (!TryDoubleCast(input.GetData(), input.GetSize(), result, strict)) { + HandleCastError::AssignError(StringUtil::Format("Could not cast string to float: \"%s\"", input.GetString()), + error_message); + return false; + } + return true; +} + +template <> +bool TryCastErrorMessageCommaSeparated::Operation(string_t input, double &result, string *error_message, bool strict) { + if (!TryDoubleCast(input.GetData(), input.GetSize(), result, strict)) { + HandleCastError::AssignError(StringUtil::Format("Could not cast string to double: \"%s\"", input.GetString()), + error_message); + return false; + } + return true; +} + +//===--------------------------------------------------------------------===// +// Cast From Date +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(date_t input, date_t &result, bool strict) { + result = input; + return true; +} + +template <> +bool TryCast::Operation(date_t input, timestamp_t &result, bool strict) { + if (input == date_t::infinity()) { + result = timestamp_t::infinity(); + return true; + } else if (input == date_t::ninfinity()) { + result = timestamp_t::ninfinity(); + return true; + } + return Timestamp::TryFromDatetime(input, Time::FromTime(0, 0, 0), result); +} + +//===--------------------------------------------------------------------===// +// Cast From Time +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(dtime_t input, dtime_t &result, bool strict) { + result = input; + return true; +} + +template <> +bool TryCast::Operation(dtime_t input, dtime_tz_t &result, bool strict) { + result = dtime_tz_t(input, 0); + return true; +} + +//===--------------------------------------------------------------------===// +// Cast From Time With Time Zone (Offset) +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(dtime_tz_t input, dtime_tz_t &result, bool strict) { + result = input; + return true; +} + +template <> +bool TryCast::Operation(dtime_tz_t input, dtime_t &result, bool strict) { + result = input.time(); + return true; +} + +//===--------------------------------------------------------------------===// +// Cast From Timestamps +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(timestamp_t input, date_t &result, bool strict) { + result = Timestamp::GetDate(input); + return true; +} + +template <> +bool TryCast::Operation(timestamp_t input, dtime_t &result, bool strict) { + if (!Timestamp::IsFinite(input)) { + return false; + } + result = Timestamp::GetTime(input); + return true; +} + +template <> +bool TryCast::Operation(timestamp_t input, timestamp_t &result, bool strict) { + result = input; + return true; +} + +template <> +bool TryCast::Operation(timestamp_t input, dtime_tz_t &result, bool strict) { + if (!Timestamp::IsFinite(input)) { + return false; + } + result = dtime_tz_t(Timestamp::GetTime(input), 0); + return true; +} + +//===--------------------------------------------------------------------===// +// Cast from Interval +//===--------------------------------------------------------------------===// +template <> +bool TryCast::Operation(interval_t input, interval_t &result, bool strict) { + result = input; + return true; +} + +//===--------------------------------------------------------------------===// +// Non-Standard Timestamps +//===--------------------------------------------------------------------===// +template <> +duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_t input, Vector &result) { + return StringCast::Operation(Timestamp::FromEpochNanoSeconds(input.value), result); +} +template <> +duckdb::string_t CastFromTimestampMS::Operation(duckdb::timestamp_t input, Vector &result) { + return StringCast::Operation(Timestamp::FromEpochMs(input.value), result); +} +template <> +duckdb::string_t CastFromTimestampSec::Operation(duckdb::timestamp_t input, Vector &result) { + return StringCast::Operation(Timestamp::FromEpochSeconds(input.value), result); +} + +template <> +timestamp_t CastTimestampUsToMs::Operation(timestamp_t input) { + timestamp_t cast_timestamp(Timestamp::GetEpochMs(input)); + return cast_timestamp; +} + +template <> +timestamp_t CastTimestampUsToNs::Operation(timestamp_t input) { + timestamp_t cast_timestamp(Timestamp::GetEpochNanoSeconds(input)); + return cast_timestamp; +} + +template <> +timestamp_t CastTimestampUsToSec::Operation(timestamp_t input) { + timestamp_t cast_timestamp(Timestamp::GetEpochSeconds(input)); + return cast_timestamp; +} +template <> +timestamp_t CastTimestampMsToUs::Operation(timestamp_t input) { + return Timestamp::FromEpochMs(input.value); +} + +template <> +timestamp_t CastTimestampMsToNs::Operation(timestamp_t input) { + auto us = CastTimestampMsToUs::Operation(input); + return CastTimestampUsToNs::Operation(us); +} + +template <> +timestamp_t CastTimestampNsToUs::Operation(timestamp_t input) { + return Timestamp::FromEpochNanoSeconds(input.value); +} + +template <> +timestamp_t CastTimestampSecToUs::Operation(timestamp_t input) { + return Timestamp::FromEpochSeconds(input.value); +} + +template <> +timestamp_t CastTimestampSecToMs::Operation(timestamp_t input) { + auto us = CastTimestampSecToUs::Operation(input); + return CastTimestampUsToMs::Operation(us); +} + +template <> +timestamp_t CastTimestampSecToNs::Operation(timestamp_t input) { + auto us = CastTimestampSecToUs::Operation(input); + return CastTimestampUsToNs::Operation(us); +} + +//===--------------------------------------------------------------------===// +// Cast To Timestamp +//===--------------------------------------------------------------------===// +template <> +bool TryCastToTimestampNS::Operation(string_t input, timestamp_t &result, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + return false; + } + result = Timestamp::GetEpochNanoSeconds(result); + return true; +} + +template <> +bool TryCastToTimestampMS::Operation(string_t input, timestamp_t &result, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + return false; + } + result = Timestamp::GetEpochMs(result); + return true; +} + +template <> +bool TryCastToTimestampSec::Operation(string_t input, timestamp_t &result, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + return false; + } + result = Timestamp::GetEpochSeconds(result); + return true; +} + +template <> +bool TryCastToTimestampNS::Operation(date_t input, timestamp_t &result, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + return false; + } + if (!TryMultiplyOperator::Operation(result.value, Interval::NANOS_PER_MICRO, result.value)) { + return false; + } + return true; +} + +template <> +bool TryCastToTimestampMS::Operation(date_t input, timestamp_t &result, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + return false; + } + result.value /= Interval::MICROS_PER_MSEC; + return true; +} + +template <> +bool TryCastToTimestampSec::Operation(date_t input, timestamp_t &result, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + return false; + } + result.value /= Interval::MICROS_PER_MSEC * Interval::MSECS_PER_SEC; + return true; +} + +//===--------------------------------------------------------------------===// +// Cast From Blob +//===--------------------------------------------------------------------===// +template <> +string_t CastFromBlob::Operation(string_t input, Vector &vector) { + idx_t result_size = Blob::GetStringSize(input); + + string_t result = StringVector::EmptyString(vector, result_size); + Blob::ToString(input, result.GetDataWriteable()); + result.Finalize(); + + return result; +} + +template <> +string_t CastFromBlobToBit::Operation(string_t input, Vector &vector) { + idx_t result_size = input.GetSize() + 1; + if (result_size <= 1) { + throw ConversionException("Cannot cast empty BLOB to BIT"); + } + return StringVector::AddStringOrBlob(vector, Bit::BlobToBit(input)); +} + +//===--------------------------------------------------------------------===// +// Cast From Bit +//===--------------------------------------------------------------------===// +template <> +string_t CastFromBitToString::Operation(string_t input, Vector &vector) { + + idx_t result_size = Bit::BitLength(input); + string_t result = StringVector::EmptyString(vector, result_size); + Bit::ToString(input, result.GetDataWriteable()); + result.Finalize(); + + return result; +} + +//===--------------------------------------------------------------------===// +// Cast From Pointer +//===--------------------------------------------------------------------===// +template <> +string_t CastFromPointer::Operation(uintptr_t input, Vector &vector) { + std::string s = duckdb_fmt::format("0x{:x}", input); + return StringVector::AddString(vector, s); +} + +//===--------------------------------------------------------------------===// +// Cast To Blob +//===--------------------------------------------------------------------===// +template <> +bool TryCastToBlob::Operation(string_t input, string_t &result, Vector &result_vector, string *error_message, + bool strict) { + idx_t result_size; + if (!Blob::TryGetBlobSize(input, result_size, error_message)) { + return false; + } + + result = StringVector::EmptyString(result_vector, result_size); + Blob::ToBlob(input, data_ptr_cast(result.GetDataWriteable())); + result.Finalize(); + return true; +} + +//===--------------------------------------------------------------------===// +// Cast To Bit +//===--------------------------------------------------------------------===// +template <> +bool TryCastToBit::Operation(string_t input, string_t &result, Vector &result_vector, string *error_message, + bool strict) { + idx_t result_size; + if (!Bit::TryGetBitStringSize(input, result_size, error_message)) { + return false; + } + + result = StringVector::EmptyString(result_vector, result_size); + Bit::ToBit(input, result); + result.Finalize(); + return true; +} + +template <> +bool CastFromBitToNumeric::Operation(string_t input, bool &result, bool strict) { + D_ASSERT(input.GetSize() > 1); + + uint8_t value; + bool success = CastFromBitToNumeric::Operation(input, value, strict); + result = (value > 0); + return (success); +} + +template <> +bool CastFromBitToNumeric::Operation(string_t input, hugeint_t &result, bool strict) { + D_ASSERT(input.GetSize() > 1); + + if (input.GetSize() - 1 > sizeof(hugeint_t)) { + throw ConversionException("Bitstring doesn't fit inside of %s", GetTypeId()); + } + Bit::BitToNumeric(input, result); + if (result < NumericLimits::Minimum()) { + throw ConversionException("Minimum limit for HUGEINT is %s", NumericLimits::Minimum().ToString()); + } + return (true); +} + +//===--------------------------------------------------------------------===// +// Cast From UUID +//===--------------------------------------------------------------------===// +template <> +string_t CastFromUUID::Operation(hugeint_t input, Vector &vector) { + string_t result = StringVector::EmptyString(vector, 36); + UUID::ToString(input, result.GetDataWriteable()); + result.Finalize(); + return result; +} + +//===--------------------------------------------------------------------===// +// Cast To UUID +//===--------------------------------------------------------------------===// +template <> +bool TryCastToUUID::Operation(string_t input, hugeint_t &result, Vector &result_vector, string *error_message, + bool strict) { + return UUID::FromString(input.GetString(), result); +} + +//===--------------------------------------------------------------------===// +// Cast To Date +//===--------------------------------------------------------------------===// +template <> +bool TryCastErrorMessage::Operation(string_t input, date_t &result, string *error_message, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + HandleCastError::AssignError(Date::ConversionError(input), error_message); + return false; + } + return true; +} + +template <> +bool TryCast::Operation(string_t input, date_t &result, bool strict) { + idx_t pos; + bool special = false; + return Date::TryConvertDate(input.GetData(), input.GetSize(), pos, result, special, strict); +} + +template <> +date_t Cast::Operation(string_t input) { + return Date::FromCString(input.GetData(), input.GetSize()); +} + +//===--------------------------------------------------------------------===// +// Cast To Time +//===--------------------------------------------------------------------===// +template <> +bool TryCastErrorMessage::Operation(string_t input, dtime_t &result, string *error_message, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + HandleCastError::AssignError(Time::ConversionError(input), error_message); + return false; + } + return true; +} + +template <> +bool TryCast::Operation(string_t input, dtime_t &result, bool strict) { + idx_t pos; + return Time::TryConvertTime(input.GetData(), input.GetSize(), pos, result, strict); +} + +template <> +dtime_t Cast::Operation(string_t input) { + return Time::FromCString(input.GetData(), input.GetSize()); +} + +//===--------------------------------------------------------------------===// +// Cast To TimeTZ +//===--------------------------------------------------------------------===// +template <> +bool TryCastErrorMessage::Operation(string_t input, dtime_tz_t &result, string *error_message, bool strict) { + if (!TryCast::Operation(input, result, strict)) { + HandleCastError::AssignError(Time::ConversionError(input), error_message); + return false; + } + return true; +} + +template <> +bool TryCast::Operation(string_t input, dtime_tz_t &result, bool strict) { + idx_t pos; + return Time::TryConvertTimeTZ(input.GetData(), input.GetSize(), pos, result, strict); +} + +template <> +dtime_tz_t Cast::Operation(string_t input) { + dtime_tz_t result; + if (!TryCast::Operation(input, result, false)) { + throw ConversionException(Time::ConversionError(input)); + } + return result; +} + +//===--------------------------------------------------------------------===// +// Cast To Timestamp +//===--------------------------------------------------------------------===// +template <> +bool TryCastErrorMessage::Operation(string_t input, timestamp_t &result, string *error_message, bool strict) { + auto cast_result = Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result); + if (cast_result == TimestampCastResult::SUCCESS) { + return true; + } + if (cast_result == TimestampCastResult::ERROR_INCORRECT_FORMAT) { + HandleCastError::AssignError(Timestamp::ConversionError(input), error_message); + } else { + HandleCastError::AssignError(Timestamp::UnsupportedTimezoneError(input), error_message); + } + return false; +} + +template <> +bool TryCast::Operation(string_t input, timestamp_t &result, bool strict) { + return Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result) == TimestampCastResult::SUCCESS; +} + +template <> +timestamp_t Cast::Operation(string_t input) { + return Timestamp::FromCString(input.GetData(), input.GetSize()); +} + +//===--------------------------------------------------------------------===// +// Cast From Interval +//===--------------------------------------------------------------------===// +template <> +bool TryCastErrorMessage::Operation(string_t input, interval_t &result, string *error_message, bool strict) { + return Interval::FromCString(input.GetData(), input.GetSize(), result, error_message, strict); +} + +//===--------------------------------------------------------------------===// +// Cast From Hugeint +//===--------------------------------------------------------------------===// +// parsing hugeint from string is done a bit differently for performance reasons +// for other integer types we keep track of a single value +// and multiply that value by 10 for every digit we read +// however, for hugeints, multiplication is very expensive (>20X as expensive as for int64) +// for that reason, we parse numbers first into an int64 value +// when that value is full, we perform a HUGEINT multiplication to flush it into the hugeint +// this takes the number of HUGEINT multiplications down from [0-38] to [0-2] +struct HugeIntCastData { + hugeint_t hugeint; + int64_t intermediate; + uint8_t digits; + bool decimal; + + bool Flush() { + if (digits == 0 && intermediate == 0) { + return true; + } + if (hugeint.lower != 0 || hugeint.upper != 0) { + if (digits > 38) { + return false; + } + if (!Hugeint::TryMultiply(hugeint, Hugeint::POWERS_OF_TEN[digits], hugeint)) { + return false; + } + } + if (!Hugeint::AddInPlace(hugeint, hugeint_t(intermediate))) { + return false; + } + digits = 0; + intermediate = 0; + return true; + } +}; + +struct HugeIntegerCastOperation { + template + static bool HandleDigit(T &result, uint8_t digit) { + if (NEGATIVE) { + if (result.intermediate < (NumericLimits::Minimum() + digit) / 10) { + // intermediate is full: need to flush it + if (!result.Flush()) { + return false; + } + } + result.intermediate = result.intermediate * 10 - digit; + } else { + if (result.intermediate > (NumericLimits::Maximum() - digit) / 10) { + if (!result.Flush()) { + return false; + } + } + result.intermediate = result.intermediate * 10 + digit; + } + result.digits++; + return true; + } + + template + static bool HandleHexDigit(T &result, uint8_t digit) { + return false; + } + + template + static bool HandleBinaryDigit(T &result, uint8_t digit) { + if (result.intermediate > (NumericLimits::Maximum() - digit) / 2) { + // intermediate is full: need to flush it + if (!result.Flush()) { + return false; + } + } + result.intermediate = result.intermediate * 2 + digit; + result.digits++; + return true; + } + + template + static bool HandleExponent(T &result, int32_t exponent) { + if (!result.Flush()) { + return false; + } + if (exponent < -38 || exponent > 38) { + // out of range for exact exponent: use double and convert + double dbl_res = Hugeint::Cast(result.hugeint) * std::pow(10.0L, exponent); + if (dbl_res < Hugeint::Cast(NumericLimits::Minimum()) || + dbl_res > Hugeint::Cast(NumericLimits::Maximum())) { + return false; + } + result.hugeint = Hugeint::Convert(dbl_res); + return true; + } + if (exponent < 0) { + // negative exponent: divide by power of 10 + result.hugeint = Hugeint::Divide(result.hugeint, Hugeint::POWERS_OF_TEN[-exponent]); + return true; + } else { + // positive exponent: multiply by power of 10 + return Hugeint::TryMultiply(result.hugeint, Hugeint::POWERS_OF_TEN[exponent], result.hugeint); + } + } + + template + static bool HandleDecimal(T &result, uint8_t digit) { + // Integer casts round + if (!result.decimal) { + if (!result.Flush()) { + return false; + } + if (NEGATIVE) { + result.intermediate = -(digit >= 5); + } else { + result.intermediate = (digit >= 5); + } + } + result.decimal = true; + + return true; + } + + template + static bool Finalize(T &result) { + return result.Flush(); + } +}; + +template <> +bool TryCast::Operation(string_t input, hugeint_t &result, bool strict) { + HugeIntCastData data; + if (!TryIntegerCast(input.GetData(), input.GetSize(), data, + strict)) { + return false; + } + result = data.hugeint; + return true; +} + +//===--------------------------------------------------------------------===// +// Decimal String Cast +//===--------------------------------------------------------------------===// + +template +struct DecimalCastData { + typedef TYPE type_t; + TYPE result; + uint8_t width; + uint8_t scale; + uint8_t digit_count; + uint8_t decimal_count; + //! Whether we have determined if the result should be rounded + bool round_set; + //! If the result should be rounded + bool should_round; + //! Only set when ALLOW_EXPONENT is enabled + enum class ExponentType : uint8_t { NONE, POSITIVE, NEGATIVE }; + uint8_t excessive_decimals; + ExponentType exponent_type; +}; + +struct DecimalCastOperation { + template + static bool HandleDigit(T &state, uint8_t digit) { + if (state.result == 0 && digit == 0) { + // leading zero's don't count towards the digit count + return true; + } + if (state.digit_count == state.width - state.scale) { + // width of decimal type is exceeded! + return false; + } + state.digit_count++; + if (NEGATIVE) { + if (state.result < (NumericLimits::Minimum() / 10)) { + return false; + } + state.result = state.result * 10 - digit; + } else { + if (state.result > (NumericLimits::Maximum() / 10)) { + return false; + } + state.result = state.result * 10 + digit; + } + return true; + } + + template + static bool HandleHexDigit(T &state, uint8_t digit) { + return false; + } + + template + static bool HandleBinaryDigit(T &state, uint8_t digit) { + return false; + } + + template + static void RoundUpResult(T &state) { + if (NEGATIVE) { + state.result -= 1; + } else { + state.result += 1; + } + } + + template + static bool HandleExponent(T &state, int32_t exponent) { + auto decimal_excess = (state.decimal_count > state.scale) ? state.decimal_count - state.scale : 0; + if (exponent > 0) { + state.exponent_type = T::ExponentType::POSITIVE; + // Positive exponents need up to 'exponent' amount of digits + // Everything beyond that amount needs to be truncated + if (decimal_excess > exponent) { + // We've allowed too many decimals + state.excessive_decimals = decimal_excess - exponent; + exponent = 0; + } else { + exponent -= decimal_excess; + } + D_ASSERT(exponent >= 0); + } else if (exponent < 0) { + state.exponent_type = T::ExponentType::NEGATIVE; + } + if (!Finalize(state)) { + return false; + } + if (exponent < 0) { + bool round_up = false; + for (idx_t i = 0; i < idx_t(-int64_t(exponent)); i++) { + auto mod = state.result % 10; + round_up = NEGATIVE ? mod <= -5 : mod >= 5; + state.result /= 10; + if (state.result == 0) { + break; + } + } + if (round_up) { + RoundUpResult(state); + } + return true; + } else { + // positive exponent: append 0's + for (idx_t i = 0; i < idx_t(exponent); i++) { + if (!HandleDigit(state, 0)) { + return false; + } + } + return true; + } + } + + template + static bool HandleDecimal(T &state, uint8_t digit) { + if (state.decimal_count == state.scale && !state.round_set) { + // Determine whether the last registered decimal should be rounded or not + state.round_set = true; + state.should_round = digit >= 5; + } + if (!ALLOW_EXPONENT && state.decimal_count == state.scale) { + // we exceeded the amount of supported decimals + // however, we don't throw an error here + // we just truncate the decimal + return true; + } + //! If we expect an exponent, we need to preserve the decimals + //! But we don't want to overflow, so we prevent overflowing the result with this check + if (state.digit_count + state.decimal_count >= DecimalWidth::max) { + return true; + } + state.decimal_count++; + if (NEGATIVE) { + state.result = state.result * 10 - digit; + } else { + state.result = state.result * 10 + digit; + } + return true; + } + + template + static bool TruncateExcessiveDecimals(T &state) { + D_ASSERT(state.excessive_decimals); + bool round_up = false; + for (idx_t i = 0; i < state.excessive_decimals; i++) { + auto mod = state.result % 10; + round_up = NEGATIVE ? mod <= -5 : mod >= 5; + state.result /= 10.0; + } + //! Only round up when exponents are involved + if (state.exponent_type == T::ExponentType::POSITIVE && round_up) { + RoundUpResult(state); + } + D_ASSERT(state.decimal_count > state.scale); + state.decimal_count = state.scale; + return true; + } + + template + static bool Finalize(T &state) { + if (state.exponent_type != T::ExponentType::POSITIVE && state.decimal_count > state.scale) { + //! Did not encounter an exponent, but ALLOW_EXPONENT was on + state.excessive_decimals = state.decimal_count - state.scale; + } + if (state.excessive_decimals && !TruncateExcessiveDecimals(state)) { + return false; + } + if (state.exponent_type == T::ExponentType::NONE && state.round_set && state.should_round) { + RoundUpResult(state); + } + // if we have not gotten exactly "scale" decimals, we need to multiply the result + // e.g. if we have a string "1.0" that is cast to a DECIMAL(9,3), the value needs to be 1000 + // but we have only gotten the value "10" so far, so we multiply by 1000 + for (uint8_t i = state.decimal_count; i < state.scale; i++) { + state.result *= 10; + } + return true; + } +}; + +template +bool TryDecimalStringCast(string_t input, T &result, string *error_message, uint8_t width, uint8_t scale) { + DecimalCastData state; + state.result = 0; + state.width = width; + state.scale = scale; + state.digit_count = 0; + state.decimal_count = 0; + state.excessive_decimals = 0; + state.exponent_type = DecimalCastData::ExponentType::NONE; + state.round_set = false; + state.should_round = false; + if (!TryIntegerCast, true, true, DecimalCastOperation, false, decimal_separator>( + input.GetData(), input.GetSize(), state, false)) { + string error = StringUtil::Format("Could not convert string \"%s\" to DECIMAL(%d,%d)", input.GetString(), + (int)width, (int)scale); + HandleCastError::AssignError(error, error_message); + return false; + } + result = state.result; + return true; +} + +template <> +bool TryCastToDecimal::Operation(string_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(string_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(string_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(string_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimalCommaSeparated::Operation(string_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimalCommaSeparated::Operation(string_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimalCommaSeparated::Operation(string_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimalCommaSeparated::Operation(string_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryDecimalStringCast(input, result, error_message, width, scale); +} + +template <> +string_t StringCastFromDecimal::Operation(int16_t input, uint8_t width, uint8_t scale, Vector &result) { + return DecimalToString::Format(input, width, scale, result); +} + +template <> +string_t StringCastFromDecimal::Operation(int32_t input, uint8_t width, uint8_t scale, Vector &result) { + return DecimalToString::Format(input, width, scale, result); +} + +template <> +string_t StringCastFromDecimal::Operation(int64_t input, uint8_t width, uint8_t scale, Vector &result) { + return DecimalToString::Format(input, width, scale, result); +} + +template <> +string_t StringCastFromDecimal::Operation(hugeint_t input, uint8_t width, uint8_t scale, Vector &result) { + return HugeintToStringCast::FormatDecimal(input, width, scale, result); +} + +//===--------------------------------------------------------------------===// +// Decimal Casts +//===--------------------------------------------------------------------===// +// Decimal <-> Bool +//===--------------------------------------------------------------------===// +template +bool TryCastBoolToDecimal(bool input, T &result, string *error_message, uint8_t width, uint8_t scale) { + if (width > scale) { + result = input ? OP::POWERS_OF_TEN[scale] : 0; + return true; + } else { + return TryCast::Operation(input, result); + } +} + +template <> +bool TryCastToDecimal::Operation(bool input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastBoolToDecimal(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(bool input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastBoolToDecimal(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(bool input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastBoolToDecimal(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(bool input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastBoolToDecimal(input, result, error_message, width, scale); +} + +template <> +bool TryCastFromDecimal::Operation(int16_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCast::Operation(input, result); +} + +template <> +bool TryCastFromDecimal::Operation(int32_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCast::Operation(input, result); +} + +template <> +bool TryCastFromDecimal::Operation(int64_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCast::Operation(input, result); +} + +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, bool &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCast::Operation(input, result); +} + +//===--------------------------------------------------------------------===// +// Numeric -> Decimal Cast +//===--------------------------------------------------------------------===// +struct SignedToDecimalOperator { + template + static bool Operation(SRC input, DST max_width) { + return int64_t(input) >= int64_t(max_width) || int64_t(input) <= int64_t(-max_width); + } +}; + +struct UnsignedToDecimalOperator { + template + static bool Operation(SRC input, DST max_width) { + return uint64_t(input) >= uint64_t(max_width); + } +}; + +template +bool StandardNumericToDecimalCast(SRC input, DST &result, string *error_message, uint8_t width, uint8_t scale) { + // check for overflow + DST max_width = NumericHelper::POWERS_OF_TEN[width - scale]; + if (OP::template Operation(input, max_width)) { + string error = StringUtil::Format("Could not cast value %d to DECIMAL(%d,%d)", input, width, scale); + HandleCastError::AssignError(error, error_message); + return false; + } + result = DST(input) * NumericHelper::POWERS_OF_TEN[scale]; + return true; +} + +template +bool NumericToHugeDecimalCast(SRC input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { + // check for overflow + hugeint_t max_width = Hugeint::POWERS_OF_TEN[width - scale]; + hugeint_t hinput = Hugeint::Convert(input); + if (hinput >= max_width || hinput <= -max_width) { + string error = StringUtil::Format("Could not cast value %s to DECIMAL(%d,%d)", hinput.ToString(), width, scale); + HandleCastError::AssignError(error, error_message); + return false; + } + result = hinput * Hugeint::POWERS_OF_TEN[scale]; + return true; +} + +//===--------------------------------------------------------------------===// +// Cast int8_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(int8_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int8_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int8_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int8_t input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Cast int16_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(int16_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int16_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int16_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int16_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Cast int32_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(int32_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int32_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int32_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int32_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Cast int64_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(int64_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int64_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int64_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, width, scale); +} +template <> +bool TryCastToDecimal::Operation(int64_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Cast uint8_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(uint8_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint8_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint8_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint8_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Cast uint16_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(uint16_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint16_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint16_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint16_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Cast uint32_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(uint32_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint32_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint32_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint32_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Cast uint64_t -> Decimal +//===--------------------------------------------------------------------===// +template <> +bool TryCastToDecimal::Operation(uint64_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint64_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint64_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return StandardNumericToDecimalCast(input, result, error_message, + width, scale); +} +template <> +bool TryCastToDecimal::Operation(uint64_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return NumericToHugeDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Hugeint -> Decimal Cast +//===--------------------------------------------------------------------===// +template +bool HugeintToDecimalCast(hugeint_t input, DST &result, string *error_message, uint8_t width, uint8_t scale) { + // check for overflow + hugeint_t max_width = Hugeint::POWERS_OF_TEN[width - scale]; + if (input >= max_width || input <= -max_width) { + string error = StringUtil::Format("Could not cast value %s to DECIMAL(%d,%d)", input.ToString(), width, scale); + HandleCastError::AssignError(error, error_message); + return false; + } + result = Hugeint::Cast(input * Hugeint::POWERS_OF_TEN[scale]); + return true; +} + +template <> +bool TryCastToDecimal::Operation(hugeint_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return HugeintToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(hugeint_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return HugeintToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(hugeint_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return HugeintToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(hugeint_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return HugeintToDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Float/Double -> Decimal Cast +//===--------------------------------------------------------------------===// +template +bool DoubleToDecimalCast(SRC input, DST &result, string *error_message, uint8_t width, uint8_t scale) { + double value = input * NumericHelper::DOUBLE_POWERS_OF_TEN[scale]; + // Add the sign (-1, 0, 1) times a tiny value to fix floating point issues (issue 3091) + double sign = (double(0) < value) - (value < double(0)); + value += 1e-9 * sign; + if (value <= -NumericHelper::DOUBLE_POWERS_OF_TEN[width] || value >= NumericHelper::DOUBLE_POWERS_OF_TEN[width]) { + string error = StringUtil::Format("Could not cast value %f to DECIMAL(%d,%d)", value, width, scale); + HandleCastError::AssignError(error, error_message); + return false; + } + result = Cast::Operation(value); + return true; +} + +template <> +bool TryCastToDecimal::Operation(float input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(float input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(float input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(float input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(double input, int16_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(double input, int32_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(double input, int64_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +template <> +bool TryCastToDecimal::Operation(double input, hugeint_t &result, string *error_message, uint8_t width, uint8_t scale) { + return DoubleToDecimalCast(input, result, error_message, width, scale); +} + +//===--------------------------------------------------------------------===// +// Decimal -> Numeric Cast +//===--------------------------------------------------------------------===// +template +bool TryCastDecimalToNumeric(SRC input, DST &result, string *error_message, uint8_t scale) { + // Round away from 0. + const auto power = NumericHelper::POWERS_OF_TEN[scale]; + // https://graphics.stanford.edu/~seander/bithacks.html#ConditionalNegate + const auto fNegate = int64_t(input < 0); + const auto rounding = ((power ^ -fNegate) + fNegate) / 2; + const auto scaled_value = (input + rounding) / power; + if (!TryCast::Operation(scaled_value, result)) { + string error = StringUtil::Format("Failed to cast decimal value %d to type %s", scaled_value, GetTypeId()); + HandleCastError::AssignError(error, error_message); + return false; + } + return true; +} + +template +bool TryCastHugeDecimalToNumeric(hugeint_t input, DST &result, string *error_message, uint8_t scale) { + const auto power = Hugeint::POWERS_OF_TEN[scale]; + const auto rounding = ((input < 0) ? -power : power) / 2; + auto scaled_value = (input + rounding) / power; + if (!TryCast::Operation(scaled_value, result)) { + string error = StringUtil::Format("Failed to cast decimal value %s to type %s", + ConvertToString::Operation(scaled_value), GetTypeId()); + HandleCastError::AssignError(error, error_message); + return false; + } + return true; +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> int8_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int8_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> int16_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> int32_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> int64_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> uint8_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint8_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint8_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint8_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint8_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> uint16_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> uint32_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> uint64_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Cast Decimal -> hugeint_t +//===--------------------------------------------------------------------===// +template <> +bool TryCastFromDecimal::Operation(int16_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int32_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(int64_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToNumeric(input, result, error_message, scale); +} +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastHugeDecimalToNumeric(input, result, error_message, scale); +} + +//===--------------------------------------------------------------------===// +// Decimal -> Float/Double Cast +//===--------------------------------------------------------------------===// +template +bool TryCastDecimalToFloatingPoint(SRC input, DST &result, uint8_t scale) { + result = Cast::Operation(input) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); + return true; +} + +// DECIMAL -> FLOAT +template <> +bool TryCastFromDecimal::Operation(int16_t input, float &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +template <> +bool TryCastFromDecimal::Operation(int32_t input, float &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +template <> +bool TryCastFromDecimal::Operation(int64_t input, float &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, float &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +// DECIMAL -> DOUBLE +template <> +bool TryCastFromDecimal::Operation(int16_t input, double &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +template <> +bool TryCastFromDecimal::Operation(int32_t input, double &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +template <> +bool TryCastFromDecimal::Operation(int64_t input, double &result, string *error_message, uint8_t width, uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, double &result, string *error_message, uint8_t width, + uint8_t scale) { + return TryCastDecimalToFloatingPoint(input, result, scale); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/operator/convert_to_string.cpp b/src/duckdb/src/common/operator/convert_to_string.cpp new file mode 100644 index 00000000..d2a20106 --- /dev/null +++ b/src/duckdb/src/common/operator/convert_to_string.cpp @@ -0,0 +1,82 @@ +#include "duckdb/common/operator/convert_to_string.hpp" +#include "duckdb/common/operator/string_cast.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +template +string StandardStringCast(T input) { + Vector v(LogicalType::VARCHAR); + return StringCast::Operation(input, v).GetString(); +} + +template <> +string ConvertToString::Operation(bool input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(int8_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(int16_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(int32_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(int64_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(uint8_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(uint16_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(uint32_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(uint64_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(hugeint_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(float input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(double input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(interval_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(date_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(dtime_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(timestamp_t input) { + return StandardStringCast(input); +} +template <> +string ConvertToString::Operation(string_t input) { + return input.GetString(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/operator/string_cast.cpp b/src/duckdb/src/common/operator/string_cast.cpp new file mode 100644 index 00000000..633c1408 --- /dev/null +++ b/src/duckdb/src/common/operator/string_cast.cpp @@ -0,0 +1,264 @@ +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/operator/string_cast.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Cast Numeric -> String +//===--------------------------------------------------------------------===// +template <> +string_t StringCast::Operation(bool input, Vector &vector) { + if (input) { + return StringVector::AddString(vector, "true", 4); + } else { + return StringVector::AddString(vector, "false", 5); + } +} + +template <> +string_t StringCast::Operation(int8_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} + +template <> +string_t StringCast::Operation(int16_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} +template <> +string_t StringCast::Operation(int32_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} + +template <> +string_t StringCast::Operation(int64_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} +template <> +duckdb::string_t StringCast::Operation(uint8_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} +template <> +duckdb::string_t StringCast::Operation(uint16_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} +template <> +duckdb::string_t StringCast::Operation(uint32_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} +template <> +duckdb::string_t StringCast::Operation(uint64_t input, Vector &vector) { + return NumericHelper::FormatSigned(input, vector); +} + +template <> +string_t StringCast::Operation(float input, Vector &vector) { + std::string s = duckdb_fmt::format("{}", input); + return StringVector::AddString(vector, s); +} + +template <> +string_t StringCast::Operation(double input, Vector &vector) { + std::string s = duckdb_fmt::format("{}", input); + return StringVector::AddString(vector, s); +} + +template <> +string_t StringCast::Operation(interval_t input, Vector &vector) { + char buffer[70]; + idx_t length = IntervalToStringCast::Format(input, buffer); + return StringVector::AddString(vector, buffer, length); +} + +template <> +duckdb::string_t StringCast::Operation(hugeint_t input, Vector &vector) { + return HugeintToStringCast::FormatSigned(input, vector); +} + +template <> +duckdb::string_t StringCast::Operation(date_t input, Vector &vector) { + if (input == date_t::infinity()) { + return StringVector::AddString(vector, Date::PINF); + } else if (input == date_t::ninfinity()) { + return StringVector::AddString(vector, Date::NINF); + } + int32_t date[3]; + Date::Convert(input, date[0], date[1], date[2]); + + idx_t year_length; + bool add_bc; + idx_t length = DateToStringCast::Length(date, year_length, add_bc); + + string_t result = StringVector::EmptyString(vector, length); + auto data = result.GetDataWriteable(); + + DateToStringCast::Format(data, date, year_length, add_bc); + + result.Finalize(); + return result; +} + +template <> +duckdb::string_t StringCast::Operation(dtime_t input, Vector &vector) { + int32_t time[4]; + Time::Convert(input, time[0], time[1], time[2], time[3]); + + char micro_buffer[10]; + idx_t length = TimeToStringCast::Length(time, micro_buffer); + + string_t result = StringVector::EmptyString(vector, length); + auto data = result.GetDataWriteable(); + + TimeToStringCast::Format(data, length, time, micro_buffer); + + result.Finalize(); + return result; +} + +template <> +duckdb::string_t StringCast::Operation(timestamp_t input, Vector &vector) { + if (input == timestamp_t::infinity()) { + return StringVector::AddString(vector, Date::PINF); + } else if (input == timestamp_t::ninfinity()) { + return StringVector::AddString(vector, Date::NINF); + } + date_t date_entry; + dtime_t time_entry; + Timestamp::Convert(input, date_entry, time_entry); + + int32_t date[3], time[4]; + Date::Convert(date_entry, date[0], date[1], date[2]); + Time::Convert(time_entry, time[0], time[1], time[2], time[3]); + + // format for timestamp is DATE TIME (separated by space) + idx_t year_length; + bool add_bc; + char micro_buffer[6]; + idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); + idx_t time_length = TimeToStringCast::Length(time, micro_buffer); + idx_t length = date_length + time_length + 1; + + string_t result = StringVector::EmptyString(vector, length); + auto data = result.GetDataWriteable(); + + DateToStringCast::Format(data, date, year_length, add_bc); + data[date_length] = ' '; + TimeToStringCast::Format(data + date_length + 1, time_length, time, micro_buffer); + + result.Finalize(); + return result; +} + +template <> +duckdb::string_t StringCast::Operation(duckdb::string_t input, Vector &result) { + return StringVector::AddStringOrBlob(result, input); +} + +template <> +string_t StringCastTZ::Operation(dtime_tz_t input, Vector &vector) { + int32_t time[4]; + Time::Convert(input.time(), time[0], time[1], time[2], time[3]); + + char micro_buffer[10]; + const auto time_length = TimeToStringCast::Length(time, micro_buffer); + idx_t length = time_length; + + const auto offset = input.offset(); + const bool negative = (offset < 0); + ++length; + + auto ss = std::abs(offset); + const auto hh = ss / Interval::SECS_PER_HOUR; + + const auto hh_length = (hh < 100) ? 2 : NumericHelper::UnsignedLength(uint32_t(hh)); + length += hh_length; + + ss %= Interval::SECS_PER_HOUR; + const auto mm = ss / Interval::SECS_PER_MINUTE; + if (mm) { + length += 3; + } + + ss %= Interval::SECS_PER_MINUTE; + if (ss) { + length += 3; + } + + string_t result = StringVector::EmptyString(vector, length); + auto data = result.GetDataWriteable(); + + idx_t pos = 0; + TimeToStringCast::Format(data + pos, time_length, time, micro_buffer); + pos += time_length; + + data[pos++] = negative ? '-' : '+'; + if (hh < 100) { + TimeToStringCast::FormatTwoDigits(data + pos, hh); + } else { + NumericHelper::FormatUnsigned(hh, data + pos + hh_length); + } + pos += hh_length; + + if (mm) { + data[pos++] = ':'; + TimeToStringCast::FormatTwoDigits(data + pos, mm); + pos += 2; + } + + if (ss) { + data[pos++] = ':'; + TimeToStringCast::FormatTwoDigits(data + pos, ss); + pos += 2; + } + + result.Finalize(); + return result; +} + +template <> +string_t StringCastTZ::Operation(timestamp_t input, Vector &vector) { + if (input == timestamp_t::infinity()) { + return StringVector::AddString(vector, Date::PINF); + } else if (input == timestamp_t::ninfinity()) { + return StringVector::AddString(vector, Date::NINF); + } + date_t date_entry; + dtime_t time_entry; + Timestamp::Convert(input, date_entry, time_entry); + + int32_t date[3], time[4]; + Date::Convert(date_entry, date[0], date[1], date[2]); + Time::Convert(time_entry, time[0], time[1], time[2], time[3]); + + // format for timestamptz is DATE TIME+00 (separated by space) + idx_t year_length; + bool add_bc; + char micro_buffer[6]; + const idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); + const idx_t time_length = TimeToStringCast::Length(time, micro_buffer); + const idx_t length = date_length + 1 + time_length + 3; + + string_t result = StringVector::EmptyString(vector, length); + auto data = result.GetDataWriteable(); + + idx_t pos = 0; + DateToStringCast::Format(data + pos, date, year_length, add_bc); + pos += date_length; + data[pos++] = ' '; + TimeToStringCast::Format(data + pos, time_length, time, micro_buffer); + pos += time_length; + data[pos++] = '+'; + data[pos++] = '0'; + data[pos++] = '0'; + + result.Finalize(); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/pipe_file_system.cpp b/src/duckdb/src/common/pipe_file_system.cpp new file mode 100644 index 00000000..39a1877d --- /dev/null +++ b/src/duckdb/src/common/pipe_file_system.cpp @@ -0,0 +1,57 @@ +#include "duckdb/common/pipe_file_system.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { +class PipeFile : public FileHandle { +public: + PipeFile(unique_ptr child_handle_p, const string &path) + : FileHandle(pipe_fs, path), child_handle(std::move(child_handle_p)) { + } + + PipeFileSystem pipe_fs; + unique_ptr child_handle; + +public: + int64_t ReadChunk(void *buffer, int64_t nr_bytes); + int64_t WriteChunk(void *buffer, int64_t nr_bytes); + + void Close() override { + } +}; + +int64_t PipeFile::ReadChunk(void *buffer, int64_t nr_bytes) { + return child_handle->Read(buffer, nr_bytes); +} +int64_t PipeFile::WriteChunk(void *buffer, int64_t nr_bytes) { + return child_handle->Write(buffer, nr_bytes); +} + +void PipeFileSystem::Reset(FileHandle &handle) { + throw InternalException("Cannot reset pipe file system"); +} + +int64_t PipeFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + auto &pipe = handle.Cast(); + return pipe.ReadChunk(buffer, nr_bytes); +} + +int64_t PipeFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + auto &pipe = handle.Cast(); + return pipe.WriteChunk(buffer, nr_bytes); +} + +int64_t PipeFileSystem::GetFileSize(FileHandle &handle) { + return 0; +} + +void PipeFileSystem::FileSync(FileHandle &handle) { +} + +unique_ptr PipeFileSystem::OpenPipe(unique_ptr handle) { + auto path = handle->path; + return make_uniq(std::move(handle), path); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/preserved_error.cpp b/src/duckdb/src/common/preserved_error.cpp new file mode 100644 index 00000000..8fbe0f0b --- /dev/null +++ b/src/duckdb/src/common/preserved_error.cpp @@ -0,0 +1,67 @@ +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +PreservedError::PreservedError() : initialized(false), exception_instance(nullptr) { +} + +PreservedError::PreservedError(const Exception &exception) + : initialized(true), type(exception.type), raw_message(SanitizeErrorMessage(exception.RawMessage())), + exception_instance(exception.Copy()) { +} + +PreservedError::PreservedError(const string &message) + : initialized(true), type(ExceptionType::INVALID), raw_message(SanitizeErrorMessage(message)), + exception_instance(nullptr) { +} + +const string &PreservedError::Message() { + if (final_message.empty()) { + final_message = Exception::ExceptionTypeToString(type) + " Error: " + raw_message; + } + return final_message; +} + +string PreservedError::SanitizeErrorMessage(string error) { + return StringUtil::Replace(std::move(error), string("\0", 1), "\\0"); +} + +void PreservedError::Throw(const string &prepended_message) const { + D_ASSERT(initialized); + if (!prepended_message.empty()) { + string new_message = prepended_message + raw_message; + Exception::ThrowAsTypeWithMessage(type, new_message, exception_instance); + } + Exception::ThrowAsTypeWithMessage(type, raw_message, exception_instance); +} + +const ExceptionType &PreservedError::Type() const { + D_ASSERT(initialized); + return this->type; +} + +PreservedError &PreservedError::AddToMessage(const string &prepended_message) { + raw_message = prepended_message + raw_message; + return *this; +} + +PreservedError::operator bool() const { + return initialized; +} + +bool PreservedError::operator==(const PreservedError &other) const { + if (initialized != other.initialized) { + return false; + } + if (type != other.type) { + return false; + } + return raw_message == other.raw_message; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/printer.cpp b/src/duckdb/src/common/printer.cpp new file mode 100644 index 00000000..2c24069c --- /dev/null +++ b/src/duckdb/src/common/printer.cpp @@ -0,0 +1,81 @@ +#include "duckdb/common/printer.hpp" +#include "duckdb/common/progress_bar/progress_bar.hpp" +#include "duckdb/common/windows_util.hpp" +#include "duckdb/common/windows.hpp" +#include + +#ifndef DUCKDB_DISABLE_PRINT +#ifdef DUCKDB_WINDOWS +#include +#else +#include +#include +#include +#endif +#endif + +namespace duckdb { + +void Printer::RawPrint(OutputStream stream, const string &str) { +#ifndef DUCKDB_DISABLE_PRINT +#ifdef DUCKDB_WINDOWS + if (IsTerminal(stream)) { + // print utf8 to terminal + auto unicode = WindowsUtil::UTF8ToMBCS(str.c_str()); + fprintf(stream == OutputStream::STREAM_STDERR ? stderr : stdout, "%s", unicode.c_str()); + return; + } +#endif + fprintf(stream == OutputStream::STREAM_STDERR ? stderr : stdout, "%s", str.c_str()); +#endif +} + +// LCOV_EXCL_START +void Printer::Print(OutputStream stream, const string &str) { + Printer::RawPrint(stream, str); + Printer::RawPrint(stream, "\n"); +} +void Printer::Flush(OutputStream stream) { +#ifndef DUCKDB_DISABLE_PRINT + fflush(stream == OutputStream::STREAM_STDERR ? stderr : stdout); +#endif +} + +void Printer::Print(const string &str) { + Printer::Print(OutputStream::STREAM_STDERR, str); +} + +bool Printer::IsTerminal(OutputStream stream) { +#ifndef DUCKDB_DISABLE_PRINT +#ifdef DUCKDB_WINDOWS + auto stream_handle = stream == OutputStream::STREAM_STDERR ? STD_ERROR_HANDLE : STD_OUTPUT_HANDLE; + return GetFileType(GetStdHandle(stream_handle)) == FILE_TYPE_CHAR; +#else + return isatty(stream == OutputStream::STREAM_STDERR ? 2 : 1); +#endif +#else + throw InternalException("IsTerminal called while printing is disabled"); +#endif +} + +idx_t Printer::TerminalWidth() { +#ifndef DUCKDB_DISABLE_PRINT +#ifdef DUCKDB_WINDOWS + CONSOLE_SCREEN_BUFFER_INFO csbi; + int columns, rows; + + GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); + rows = csbi.srWindow.Right - csbi.srWindow.Left + 1; + return rows; +#else + struct winsize w; + ioctl(0, TIOCGWINSZ, &w); + return w.ws_col; +#endif +#else + throw InternalException("TerminalWidth called while printing is disabled"); +#endif +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/progress_bar/progress_bar.cpp b/src/duckdb/src/common/progress_bar/progress_bar.cpp new file mode 100644 index 00000000..26e56ed2 --- /dev/null +++ b/src/duckdb/src/common/progress_bar/progress_bar.cpp @@ -0,0 +1,98 @@ +#include "duckdb/common/progress_bar/progress_bar.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp" + +namespace duckdb { + +void ProgressBar::SystemOverrideCheck(ClientConfig &config) { + if (config.system_progress_bar_disable_reason != nullptr) { + throw InvalidInputException("Could not change the progress bar setting because: '%s'", + config.system_progress_bar_disable_reason); + } +} + +unique_ptr ProgressBar::DefaultProgressBarDisplay() { + return make_uniq(); +} + +ProgressBar::ProgressBar(Executor &executor, idx_t show_progress_after, + progress_bar_display_create_func_t create_display_func) + : executor(executor), show_progress_after(show_progress_after), current_percentage(-1) { + if (create_display_func) { + display = create_display_func(); + } +} + +double ProgressBar::GetCurrentPercentage() { + return current_percentage; +} + +void ProgressBar::Start() { + profiler.Start(); + current_percentage = 0; + supported = true; +} + +bool ProgressBar::PrintEnabled() const { + return display != nullptr; +} + +bool ProgressBar::ShouldPrint(bool final) const { + if (!PrintEnabled()) { + // Don't print progress at all + return false; + } + // FIXME - do we need to check supported before running `profiler.Elapsed()` ? + auto sufficient_time_elapsed = profiler.Elapsed() > show_progress_after / 1000.0; + if (!sufficient_time_elapsed) { + // Don't print yet + return false; + } + if (final) { + // Print the last completed bar + return true; + } + if (!supported) { + return false; + } + return current_percentage > -1; +} + +void ProgressBar::Update(bool final) { + if (!final && !supported) { + return; + } + double new_percentage; + supported = executor.GetPipelinesProgress(new_percentage); + if (!final && !supported) { + return; + } + if (new_percentage > current_percentage) { + current_percentage = new_percentage; + } + if (ShouldPrint(final)) { +#ifndef DUCKDB_DISABLE_PRINT + if (final) { + FinishProgressBarPrint(); + } else { + PrintProgress(current_percentage); + } +#endif + } +} + +void ProgressBar::PrintProgress(int current_percentage) { + D_ASSERT(display); + display->Update(current_percentage); +} + +void ProgressBar::FinishProgressBarPrint() { + if (finished) { + return; + } + D_ASSERT(display); + display->Finish(); + finished = true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp new file mode 100644 index 00000000..a4b10a4d --- /dev/null +++ b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp @@ -0,0 +1,66 @@ +#include "duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +void TerminalProgressBarDisplay::PrintProgressInternal(int percentage) { + if (percentage > 100) { + percentage = 100; + } + if (percentage < 0) { + percentage = 0; + } + string result; + // we divide the number of blocks by the percentage + // 0% = 0 + // 100% = PROGRESS_BAR_WIDTH + // the percentage determines how many blocks we need to draw + double blocks_to_draw = PROGRESS_BAR_WIDTH * (percentage / 100.0); + // because of the power of unicode, we can also draw partial blocks + + // render the percentage with some padding to ensure everything stays nicely aligned + result = "\r"; + if (percentage < 100) { + result += " "; + } + if (percentage < 10) { + result += " "; + } + result += to_string(percentage) + "%"; + result += " "; + result += PROGRESS_START; + idx_t i; + for (i = 0; i < idx_t(blocks_to_draw); i++) { + result += PROGRESS_BLOCK; + } + if (i < PROGRESS_BAR_WIDTH) { + // print a partial block based on the percentage of the progress bar remaining + idx_t index = idx_t((blocks_to_draw - idx_t(blocks_to_draw)) * PARTIAL_BLOCK_COUNT); + if (index >= PARTIAL_BLOCK_COUNT) { + index = PARTIAL_BLOCK_COUNT - 1; + } + result += PROGRESS_PARTIAL[index]; + i++; + } + for (; i < PROGRESS_BAR_WIDTH; i++) { + result += PROGRESS_EMPTY; + } + result += PROGRESS_END; + result += " "; + + Printer::RawPrint(OutputStream::STREAM_STDOUT, result); +} + +void TerminalProgressBarDisplay::Update(double percentage) { + PrintProgressInternal(percentage); + Printer::Flush(OutputStream::STREAM_STDOUT); +} + +void TerminalProgressBarDisplay::Finish() { + PrintProgressInternal(100); + Printer::RawPrint(OutputStream::STREAM_STDOUT, "\n"); + Printer::Flush(OutputStream::STREAM_STDOUT); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/radix_partitioning.cpp b/src/duckdb/src/common/radix_partitioning.cpp new file mode 100644 index 00000000..b4dc39bd --- /dev/null +++ b/src/duckdb/src/common/radix_partitioning.cpp @@ -0,0 +1,237 @@ +#include "duckdb/common/radix_partitioning.hpp" + +#include "duckdb/common/types/column/partitioned_column_data.hpp" +#include "duckdb/common/types/row/row_data_collection.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" + +namespace duckdb { + +//! Templated radix partitioning constants, can be templated to the number of radix bits +template +struct RadixPartitioningConstants { +public: + //! Bitmask of the upper bits starting at the 5th byte + static constexpr const idx_t NUM_PARTITIONS = RadixPartitioning::NumberOfPartitions(radix_bits); + static constexpr const idx_t SHIFT = RadixPartitioning::Shift(radix_bits); + static constexpr const hash_t MASK = RadixPartitioning::Mask(radix_bits); + +public: + //! Apply bitmask and right shift to get a number between 0 and NUM_PARTITIONS + static inline hash_t ApplyMask(hash_t hash) { + D_ASSERT((hash & MASK) >> SHIFT < NUM_PARTITIONS); + return (hash & MASK) >> SHIFT; + } +}; + +template +RETURN_TYPE RadixBitsSwitch(idx_t radix_bits, ARGS &&... args) { + D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); + switch (radix_bits) { + case 0: + return OP::template Operation<0>(std::forward(args)...); + case 1: + return OP::template Operation<1>(std::forward(args)...); + case 2: + return OP::template Operation<2>(std::forward(args)...); + case 3: + return OP::template Operation<3>(std::forward(args)...); + case 4: + return OP::template Operation<4>(std::forward(args)...); + case 5: // LCOV_EXCL_START + return OP::template Operation<5>(std::forward(args)...); + case 6: + return OP::template Operation<6>(std::forward(args)...); + case 7: + return OP::template Operation<7>(std::forward(args)...); + case 8: + return OP::template Operation<8>(std::forward(args)...); + case 9: + return OP::template Operation<9>(std::forward(args)...); + case 10: + return OP::template Operation<10>(std::forward(args)...); + case 11: + return OP::template Operation<10>(std::forward(args)...); + case 12: + return OP::template Operation<10>(std::forward(args)...); + default: + throw InternalException( + "radix_bits higher than RadixPartitioning::MAX_RADIX_BITS encountered in RadixBitsSwitch"); + } // LCOV_EXCL_STOP +} + +template +struct RadixLessThan { + static inline bool Operation(hash_t hash, hash_t cutoff) { + using CONSTANTS = RadixPartitioningConstants; + return CONSTANTS::ApplyMask(hash) < cutoff; + } +}; + +struct SelectFunctor { + template + static idx_t Operation(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t cutoff, + SelectionVector *true_sel, SelectionVector *false_sel) { + Vector cutoff_vector(Value::HASH(cutoff)); + return BinaryExecutor::Select>(hashes, cutoff_vector, sel, count, + true_sel, false_sel); + } +}; + +idx_t RadixPartitioning::Select(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t radix_bits, idx_t cutoff, + SelectionVector *true_sel, SelectionVector *false_sel) { + return RadixBitsSwitch(radix_bits, hashes, sel, count, cutoff, true_sel, false_sel); +} + +struct ComputePartitionIndicesFunctor { + template + static void Operation(Vector &hashes, Vector &partition_indices, idx_t count) { + UnaryExecutor::Execute(hashes, partition_indices, count, [&](hash_t hash) { + using CONSTANTS = RadixPartitioningConstants; + return CONSTANTS::ApplyMask(hash); + }); + } +}; + +//===--------------------------------------------------------------------===// +// Column Data Partitioning +//===--------------------------------------------------------------------===// +RadixPartitionedColumnData::RadixPartitionedColumnData(ClientContext &context_p, vector types_p, + idx_t radix_bits_p, idx_t hash_col_idx_p) + : PartitionedColumnData(PartitionedColumnDataType::RADIX, context_p, std::move(types_p)), radix_bits(radix_bits_p), + hash_col_idx(hash_col_idx_p) { + D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); + D_ASSERT(hash_col_idx < types.size()); + const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + allocators->allocators.reserve(num_partitions); + for (idx_t i = 0; i < num_partitions; i++) { + CreateAllocator(); + } + D_ASSERT(allocators->allocators.size() == num_partitions); +} + +RadixPartitionedColumnData::RadixPartitionedColumnData(const RadixPartitionedColumnData &other) + : PartitionedColumnData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { + for (idx_t i = 0; i < RadixPartitioning::NumberOfPartitions(radix_bits); i++) { + partitions.emplace_back(CreatePartitionCollection(i)); + } +} + +RadixPartitionedColumnData::~RadixPartitionedColumnData() { +} + +void RadixPartitionedColumnData::InitializeAppendStateInternal(PartitionedColumnDataAppendState &state) const { + const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + state.partition_append_states.reserve(num_partitions); + state.partition_buffers.reserve(num_partitions); + for (idx_t i = 0; i < num_partitions; i++) { + state.partition_append_states.emplace_back(make_uniq()); + partitions[i]->InitializeAppend(*state.partition_append_states[i]); + state.partition_buffers.emplace_back(CreatePartitionBuffer()); + } +} + +void RadixPartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { + D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); + D_ASSERT(state.partition_buffers.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); + RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, + input.size()); +} + +//===--------------------------------------------------------------------===// +// Tuple Data Partitioning +//===--------------------------------------------------------------------===// +RadixPartitionedTupleData::RadixPartitionedTupleData(BufferManager &buffer_manager, const TupleDataLayout &layout_p, + idx_t radix_bits_p, idx_t hash_col_idx_p) + : PartitionedTupleData(PartitionedTupleDataType::RADIX, buffer_manager, layout_p.Copy()), radix_bits(radix_bits_p), + hash_col_idx(hash_col_idx_p) { + D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); + D_ASSERT(hash_col_idx < layout.GetTypes().size()); + const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + allocators->allocators.reserve(num_partitions); + for (idx_t i = 0; i < num_partitions; i++) { + CreateAllocator(); + } + D_ASSERT(allocators->allocators.size() == num_partitions); + Initialize(); +} + +RadixPartitionedTupleData::RadixPartitionedTupleData(const RadixPartitionedTupleData &other) + : PartitionedTupleData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { + Initialize(); +} + +RadixPartitionedTupleData::~RadixPartitionedTupleData() { +} + +void RadixPartitionedTupleData::Initialize() { + for (idx_t i = 0; i < RadixPartitioning::NumberOfPartitions(radix_bits); i++) { + partitions.emplace_back(CreatePartitionCollection(i)); + } +} + +void RadixPartitionedTupleData::InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties) const { + // Init pin state per partition + const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + state.partition_pin_states.reserve(num_partitions); + for (idx_t i = 0; i < num_partitions; i++) { + state.partition_pin_states.emplace_back(make_uniq()); + partitions[i]->InitializeAppend(*state.partition_pin_states[i], properties); + } + + // Init single chunk state + auto column_count = layout.ColumnCount(); + vector column_ids; + column_ids.reserve(column_count); + for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { + column_ids.emplace_back(col_idx); + } + partitions[0]->InitializeChunkState(state.chunk_state, std::move(column_ids)); + + // Initialize fixed-size map + state.fixed_partition_entries.resize(RadixPartitioning::NumberOfPartitions(radix_bits)); +} + +void RadixPartitionedTupleData::ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input) { + D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); + RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, + input.size()); +} + +void RadixPartitionedTupleData::ComputePartitionIndices(Vector &row_locations, idx_t count, + Vector &partition_indices) const { + Vector intermediate(LogicalType::HASH); + partitions[0]->Gather(row_locations, *FlatVector::IncrementalSelectionVector(), count, hash_col_idx, intermediate, + *FlatVector::IncrementalSelectionVector()); + RadixBitsSwitch(radix_bits, intermediate, partition_indices, count); +} + +void RadixPartitionedTupleData::RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, + PartitionedTupleData &new_partitioned_data, + PartitionedTupleDataAppendState &state, + idx_t finished_partition_idx) const { + D_ASSERT(old_partitioned_data.GetType() == PartitionedTupleDataType::RADIX && + new_partitioned_data.GetType() == PartitionedTupleDataType::RADIX); + const auto &old_radix_partitions = old_partitioned_data.Cast(); + const auto &new_radix_partitions = new_partitioned_data.Cast(); + const auto old_radix_bits = old_radix_partitions.GetRadixBits(); + const auto new_radix_bits = new_radix_partitions.GetRadixBits(); + D_ASSERT(new_radix_bits > old_radix_bits); + + // We take the most significant digits as the partition index + // When repartitioning, e.g., partition 0 from "old" goes into the first N partitions in "new" + // When partition 0 is done, we can already finalize the append states, unpinning blocks + const auto multiplier = RadixPartitioning::NumberOfPartitions(new_radix_bits - old_radix_bits); + const auto from_idx = finished_partition_idx * multiplier; + const auto to_idx = from_idx + multiplier; + auto &partitions = new_partitioned_data.GetPartitions(); + for (idx_t partition_index = from_idx; partition_index < to_idx; partition_index++) { + auto &partition = *partitions[partition_index]; + auto &partition_pin_state = *state.partition_pin_states[partition_index]; + partition.FinalizePinState(partition_pin_state); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/random_engine.cpp b/src/duckdb/src/common/random_engine.cpp new file mode 100644 index 00000000..0c9aec4e --- /dev/null +++ b/src/duckdb/src/common/random_engine.cpp @@ -0,0 +1,41 @@ +#include "duckdb/common/random_engine.hpp" +#include "pcg_random.hpp" +#include + +namespace duckdb { + +struct RandomState { + RandomState() { + } + + pcg32 pcg; +}; + +RandomEngine::RandomEngine(int64_t seed) : random_state(make_uniq()) { + if (seed < 0) { + random_state->pcg.seed(pcg_extras::seed_seq_from()); + } else { + random_state->pcg.seed(seed); + } +} + +RandomEngine::~RandomEngine() { +} + +double RandomEngine::NextRandom(double min, double max) { + D_ASSERT(max >= min); + return min + (NextRandom() * (max - min)); +} + +double RandomEngine::NextRandom() { + return std::ldexp(random_state->pcg(), -32); +} +uint32_t RandomEngine::NextRandomInteger() { + return random_state->pcg(); +} + +void RandomEngine::SetSeed(uint32_t seed) { + random_state->pcg.seed(seed); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/re2_regex.cpp b/src/duckdb/src/common/re2_regex.cpp new file mode 100644 index 00000000..c1f7dffa --- /dev/null +++ b/src/duckdb/src/common/re2_regex.cpp @@ -0,0 +1,62 @@ +#include "duckdb/common/vector.hpp" +#include + +#include "duckdb/common/re2_regex.hpp" +#include "re2/re2.h" + +namespace duckdb_re2 { + +Regex::Regex(const std::string &pattern, RegexOptions options) { + RE2::Options o; + o.set_case_sensitive(options == RegexOptions::CASE_INSENSITIVE); + regex = std::make_shared(StringPiece(pattern), o); +} + +bool RegexSearchInternal(const char *input, Match &match, const Regex &r, RE2::Anchor anchor, size_t start, + size_t end) { + auto ®ex = r.GetRegex(); + duckdb::vector target_groups; + auto group_count = regex.NumberOfCapturingGroups() + 1; + target_groups.resize(group_count); + match.groups.clear(); + if (!regex.Match(StringPiece(input), start, end, anchor, target_groups.data(), group_count)) { + return false; + } + for (auto &group : target_groups) { + GroupMatch group_match; + group_match.text = group.ToString(); + group_match.position = group.data() - input; + match.groups.emplace_back(group_match); + } + return true; +} + +bool RegexSearch(const std::string &input, Match &match, const Regex ®ex) { + return RegexSearchInternal(input.c_str(), match, regex, RE2::UNANCHORED, 0, input.size()); +} + +bool RegexMatch(const std::string &input, Match &match, const Regex ®ex) { + return RegexSearchInternal(input.c_str(), match, regex, RE2::ANCHOR_BOTH, 0, input.size()); +} + +bool RegexMatch(const char *start, const char *end, Match &match, const Regex ®ex) { + return RegexSearchInternal(start, match, regex, RE2::ANCHOR_BOTH, 0, end - start); +} + +bool RegexMatch(const std::string &input, const Regex ®ex) { + Match nop_match; + return RegexSearchInternal(input.c_str(), nop_match, regex, RE2::ANCHOR_BOTH, 0, input.size()); +} + +duckdb::vector RegexFindAll(const std::string &input, const Regex ®ex) { + duckdb::vector matches; + size_t position = 0; + Match match; + while (RegexSearchInternal(input.c_str(), match, regex, RE2::UNANCHORED, position, input.size())) { + position += match.position(0) + match.length(0); + matches.emplace_back(match); + } + return matches; +} + +} // namespace duckdb_re2 diff --git a/src/duckdb/src/common/row_operations/row_aggregate.cpp b/src/duckdb/src/common/row_operations/row_aggregate.cpp new file mode 100644 index 00000000..6c89d887 --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_aggregate.cpp @@ -0,0 +1,122 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row_operations/row_aggregate.cpp +// +// +//===----------------------------------------------------------------------===// +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/row/tuple_data_layout.hpp" +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" + +namespace duckdb { + +void RowOperations::InitializeStates(TupleDataLayout &layout, Vector &addresses, const SelectionVector &sel, + idx_t count) { + if (count == 0) { + return; + } + auto pointers = FlatVector::GetData(addresses); + auto &offsets = layout.GetOffsets(); + auto aggr_idx = layout.ColumnCount(); + + for (const auto &aggr : layout.GetAggregates()) { + for (idx_t i = 0; i < count; ++i) { + auto row_idx = sel.get_index(i); + auto row = pointers[row_idx]; + aggr.function.initialize(row + offsets[aggr_idx]); + } + ++aggr_idx; + } +} + +void RowOperations::DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, idx_t count) { + if (count == 0) { + return; + } + // Move to the first aggregate state + VectorOperations::AddInPlace(addresses, layout.GetAggrOffset(), count); + for (const auto &aggr : layout.GetAggregates()) { + if (aggr.function.destructor) { + AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); + aggr.function.destructor(addresses, aggr_input_data, count); + } + // Move to the next aggregate state + VectorOperations::AddInPlace(addresses, aggr.payload_size, count); + } +} + +void RowOperations::UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, + DataChunk &payload, idx_t arg_idx, idx_t count) { + AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); + aggr.function.update(aggr.child_count == 0 ? nullptr : &payload.data[arg_idx], aggr_input_data, aggr.child_count, + addresses, count); +} + +void RowOperations::UpdateFilteredStates(RowOperationsState &state, AggregateFilterData &filter_data, + AggregateObject &aggr, Vector &addresses, DataChunk &payload, idx_t arg_idx) { + idx_t count = filter_data.ApplyFilter(payload); + if (count == 0) { + return; + } + + Vector filtered_addresses(addresses, filter_data.true_sel, count); + filtered_addresses.Flatten(count); + + UpdateStates(state, aggr, filtered_addresses, filter_data.filtered_payload, arg_idx, count); +} + +void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets, + idx_t count) { + if (count == 0) { + return; + } + + // Move to the first aggregate states + VectorOperations::AddInPlace(sources, layout.GetAggrOffset(), count); + VectorOperations::AddInPlace(targets, layout.GetAggrOffset(), count); + + // Keep track of the offset + idx_t offset = layout.GetAggrOffset(); + + for (auto &aggr : layout.GetAggregates()) { + D_ASSERT(aggr.function.combine); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); + aggr.function.combine(sources, targets, aggr_input_data, count); + + // Move to the next aggregate states + VectorOperations::AddInPlace(sources, aggr.payload_size, count); + VectorOperations::AddInPlace(targets, aggr.payload_size, count); + + // Increment the offset + offset += aggr.payload_size; + } + + // Now subtract the offset to get back to the original position + VectorOperations::AddInPlace(sources, -offset, count); + VectorOperations::AddInPlace(targets, -offset, count); +} + +void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, + DataChunk &result, idx_t aggr_idx) { + // Copy the addresses + Vector addresses_copy(LogicalType::POINTER); + VectorOperations::Copy(addresses, addresses_copy, result.size(), 0, 0); + + // Move to the first aggregate state + VectorOperations::AddInPlace(addresses_copy, layout.GetAggrOffset(), result.size()); + + auto &aggregates = layout.GetAggregates(); + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &target = result.data[aggr_idx + i]; + auto &aggr = aggregates[i]; + AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); + aggr.function.finalize(addresses_copy, aggr_input_data, target, result.size(), 0); + + // Move to the next aggregate state + VectorOperations::AddInPlace(addresses_copy, aggr.payload_size, result.size()); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_external.cpp b/src/duckdb/src/common/row_operations/row_external.cpp new file mode 100644 index 00000000..9e2fa071 --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_external.cpp @@ -0,0 +1,163 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row_operations/row_external.cpp +// +// +//===----------------------------------------------------------------------===// +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/row/row_layout.hpp" + +namespace duckdb { + +using ValidityBytes = RowLayout::ValidityBytes; + +void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count) { + const idx_t row_width = layout.GetRowWidth(); + data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; + idx_t done = 0; + while (done != count) { + const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); + const data_ptr_t row_ptr = base_row_ptr + done * row_width; + // Load heap row pointers + data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); + for (idx_t i = 0; i < next; i++) { + heap_row_ptrs[i] = Load(heap_ptr_ptr); + heap_ptr_ptr += row_width; + } + // Loop through the blob columns + for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { + auto physical_type = layout.GetTypes()[col_idx].InternalType(); + if (TypeIsConstantSize(physical_type)) { + continue; + } + data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; + if (physical_type == PhysicalType::VARCHAR) { + data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; + for (idx_t i = 0; i < next; i++) { + if (Load(col_ptr) > string_t::INLINE_LENGTH) { + // Overwrite the string pointer with the within-row offset (if not inlined) + Store(Load(string_ptr) - heap_row_ptrs[i], string_ptr); + } + col_ptr += row_width; + string_ptr += row_width; + } + } else { + // Non-varchar blob columns + for (idx_t i = 0; i < next; i++) { + // Overwrite the column data pointer with the within-row offset + Store(Load(col_ptr) - heap_row_ptrs[i], col_ptr); + col_ptr += row_width; + } + } + } + done += next; + } +} + +void RowOperations::SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, + const idx_t count, const idx_t base_offset) { + const idx_t row_width = layout.GetRowWidth(); + row_ptr += layout.GetHeapOffset(); + idx_t cumulative_offset = 0; + for (idx_t i = 0; i < count; i++) { + Store(base_offset + cumulative_offset, row_ptr); + cumulative_offset += Load(heap_base_ptr + cumulative_offset); + row_ptr += row_width; + } +} + +void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, + data_ptr_t heap_ptr, const idx_t count) { + const auto row_width = layout.GetRowWidth(); + const auto heap_offset = layout.GetHeapOffset(); + for (idx_t i = 0; i < count; i++) { + // Figure out source and size + const auto source_heap_ptr = Load(row_ptr + heap_offset); + const auto size = Load(source_heap_ptr); + D_ASSERT(size >= sizeof(uint32_t)); + + // Copy and swizzle + memcpy(heap_ptr, source_heap_ptr, size); + Store(heap_ptr - heap_base_ptr, row_ptr + heap_offset); + + // Increment for next iteration + row_ptr += row_width; + heap_ptr += size; + } +} + +void RowOperations::UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, + const data_ptr_t base_heap_ptr, const idx_t count) { + const auto row_width = layout.GetRowWidth(); + data_ptr_t heap_ptr_ptr = base_row_ptr + layout.GetHeapOffset(); + for (idx_t i = 0; i < count; i++) { + Store(base_heap_ptr + Load(heap_ptr_ptr), heap_ptr_ptr); + heap_ptr_ptr += row_width; + } +} + +static inline void VerifyUnswizzledString(const RowLayout &layout, const idx_t &col_idx, const data_ptr_t &row_ptr) { +#ifdef DEBUG + if (layout.GetTypes()[col_idx].id() != LogicalTypeId::VARCHAR) { + return; + } + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + ValidityBytes row_mask(row_ptr); + if (row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { + auto str = Load(row_ptr + layout.GetOffsets()[col_idx]); + str.Verify(); + } +#endif +} + +void RowOperations::UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, + const data_ptr_t base_heap_ptr, const idx_t count) { + const idx_t row_width = layout.GetRowWidth(); + data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; + idx_t done = 0; + while (done != count) { + const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); + const data_ptr_t row_ptr = base_row_ptr + done * row_width; + // Restore heap row pointers + data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); + for (idx_t i = 0; i < next; i++) { + heap_row_ptrs[i] = base_heap_ptr + Load(heap_ptr_ptr); + Store(heap_row_ptrs[i], heap_ptr_ptr); + heap_ptr_ptr += row_width; + } + // Loop through the blob columns + for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { + auto physical_type = layout.GetTypes()[col_idx].InternalType(); + if (TypeIsConstantSize(physical_type)) { + continue; + } + data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; + if (physical_type == PhysicalType::VARCHAR) { + data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; + for (idx_t i = 0; i < next; i++) { + if (Load(col_ptr) > string_t::INLINE_LENGTH) { + // Overwrite the string offset with the pointer (if not inlined) + Store(heap_row_ptrs[i] + Load(string_ptr), string_ptr); + VerifyUnswizzledString(layout, col_idx, row_ptr + i * row_width); + } + col_ptr += row_width; + string_ptr += row_width; + } + } else { + // Non-varchar blob columns + for (idx_t i = 0; i < next; i++) { + // Overwrite the column data offset with the pointer + Store(heap_row_ptrs[i] + Load(col_ptr), col_ptr); + col_ptr += row_width; + } + } + } + done += next; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_gather.cpp b/src/duckdb/src/common/row_operations/row_gather.cpp new file mode 100644 index 00000000..631c5ffa --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_gather.cpp @@ -0,0 +1,231 @@ +//===--------------------------------------------------------------------===// +// row_gather.cpp +// Description: This file contains the implementation of the gather operators +//===--------------------------------------------------------------------===// + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/constant_operators.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/row/row_data_collection.hpp" +#include "duckdb/common/types/row/row_layout.hpp" +#include "duckdb/common/types/row/tuple_data_layout.hpp" + +namespace duckdb { + +using ValidityBytes = RowLayout::ValidityBytes; + +template +static void TemplatedGatherLoop(Vector &rows, const SelectionVector &row_sel, Vector &col, + const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, + idx_t build_size) { + // Precompute mask indexes + const auto &offsets = layout.GetOffsets(); + const auto col_offset = offsets[col_no]; + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); + + auto ptrs = FlatVector::GetData(rows); + auto data = FlatVector::GetData(col); + auto &col_mask = FlatVector::Validity(col); + + for (idx_t i = 0; i < count; i++) { + auto row_idx = row_sel.get_index(i); + auto row = ptrs[row_idx]; + auto col_idx = col_sel.get_index(i); + data[col_idx] = Load(row + col_offset); + ValidityBytes row_mask(row); + if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { + if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { + //! We need to initialize the mask with the vector size. + col_mask.Initialize(build_size); + } + col_mask.SetInvalid(col_idx); + } + } +} + +static void GatherVarchar(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, + idx_t count, const RowLayout &layout, idx_t col_no, idx_t build_size, + data_ptr_t base_heap_ptr) { + // Precompute mask indexes + const auto &offsets = layout.GetOffsets(); + const auto col_offset = offsets[col_no]; + const auto heap_offset = layout.GetHeapOffset(); + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); + + auto ptrs = FlatVector::GetData(rows); + auto data = FlatVector::GetData(col); + auto &col_mask = FlatVector::Validity(col); + + for (idx_t i = 0; i < count; i++) { + auto row_idx = row_sel.get_index(i); + auto row = ptrs[row_idx]; + auto col_idx = col_sel.get_index(i); + auto col_ptr = row + col_offset; + data[col_idx] = Load(col_ptr); + ValidityBytes row_mask(row); + if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { + if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { + //! We need to initialize the mask with the vector size. + col_mask.Initialize(build_size); + } + col_mask.SetInvalid(col_idx); + } else if (base_heap_ptr && Load(col_ptr) > string_t::INLINE_LENGTH) { + // Not inline, so unswizzle the copied pointer the pointer + auto heap_ptr_ptr = row + heap_offset; + auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); + auto string_ptr = data_ptr_t(data + col_idx) + string_t::HEADER_SIZE; + Store(heap_row_ptr + Load(string_ptr), string_ptr); +#ifdef DEBUG + data[col_idx].Verify(); +#endif + } + } +} + +static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vector &col, + const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, + data_ptr_t base_heap_ptr) { + const auto &offsets = layout.GetOffsets(); + const auto col_offset = offsets[col_no]; + const auto heap_offset = layout.GetHeapOffset(); + auto ptrs = FlatVector::GetData(rows); + + // Build the gather locations + auto data_locations = make_unsafe_uniq_array(count); + auto mask_locations = make_unsafe_uniq_array(count); + for (idx_t i = 0; i < count; i++) { + auto row_idx = row_sel.get_index(i); + auto row = ptrs[row_idx]; + mask_locations[i] = row; + auto col_ptr = ptrs[row_idx] + col_offset; + if (base_heap_ptr) { + auto heap_ptr_ptr = row + heap_offset; + auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); + data_locations[i] = heap_row_ptr + Load(col_ptr); + } else { + data_locations[i] = Load(col_ptr); + } + } + + // Deserialise into the selected locations + RowOperations::HeapGather(col, count, col_sel, col_no, data_locations.get(), mask_locations.get()); +} + +void RowOperations::Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, + const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size, + data_ptr_t heap_ptr) { + D_ASSERT(rows.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(rows.GetType().id() == LogicalTypeId::POINTER); // "Cannot gather from non-pointer type!" + + col.SetVectorType(VectorType::FLAT_VECTOR); + switch (col.GetType().InternalType()) { + case PhysicalType::UINT8: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::UINT16: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::UINT32: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::UINT64: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::INT16: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::INT32: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::INT64: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::INT128: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::FLOAT: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::DOUBLE: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::INTERVAL: + TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); + break; + case PhysicalType::VARCHAR: + GatherVarchar(rows, row_sel, col, col_sel, count, layout, col_no, build_size, heap_ptr); + break; + case PhysicalType::LIST: + case PhysicalType::STRUCT: + GatherNestedVector(rows, row_sel, col, col_sel, count, layout, col_no, heap_ptr); + break; + default: + throw InternalException("Unimplemented type for RowOperations::Gather"); + } +} + +template +static void TemplatedFullScanLoop(Vector &rows, Vector &col, idx_t count, idx_t col_offset, idx_t col_no) { + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); + + auto ptrs = FlatVector::GetData(rows); + auto data = FlatVector::GetData(col); + // auto &col_mask = FlatVector::Validity(col); + + for (idx_t i = 0; i < count; i++) { + auto row = ptrs[i]; + data[i] = Load(row + col_offset); + ValidityBytes row_mask(row); + if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { + throw InternalException("Null value comparisons not implemented for perfect hash table yet"); + // col_mask.SetInvalid(i); + } + } +} + +void RowOperations::FullScanColumn(const TupleDataLayout &layout, Vector &rows, Vector &col, idx_t count, + idx_t col_no) { + const auto col_offset = layout.GetOffsets()[col_no]; + col.SetVectorType(VectorType::FLAT_VECTOR); + switch (col.GetType().InternalType()) { + case PhysicalType::UINT8: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + case PhysicalType::UINT16: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + case PhysicalType::UINT32: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + case PhysicalType::UINT64: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + case PhysicalType::INT8: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + case PhysicalType::INT16: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + case PhysicalType::INT32: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + case PhysicalType::INT64: + TemplatedFullScanLoop(rows, col, count, col_offset, col_no); + break; + default: + throw NotImplementedException("Unimplemented type for RowOperations::FullScanColumn"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_gather.cpp b/src/duckdb/src/common/row_operations/row_heap_gather.cpp new file mode 100644 index 00000000..a8b6e7b9 --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_heap_gather.cpp @@ -0,0 +1,208 @@ +#include "duckdb/common/helper.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +using ValidityBytes = TemplatedValidityMask; + +template +static void TemplatedHeapGather(Vector &v, const idx_t count, const SelectionVector &sel, data_ptr_t *key_locations) { + auto target = FlatVector::GetData(v); + + for (idx_t i = 0; i < count; ++i) { + const auto col_idx = sel.get_index(i); + target[col_idx] = Load(key_locations[i]); + key_locations[i] += sizeof(T); + } +} + +static void HeapGatherStringVector(Vector &v, const idx_t vcount, const SelectionVector &sel, + data_ptr_t *key_locations) { + const auto &validity = FlatVector::Validity(v); + auto target = FlatVector::GetData(v); + + for (idx_t i = 0; i < vcount; i++) { + const auto col_idx = sel.get_index(i); + if (!validity.RowIsValid(col_idx)) { + continue; + } + auto len = Load(key_locations[i]); + key_locations[i] += sizeof(uint32_t); + target[col_idx] = StringVector::AddStringOrBlob(v, string_t(const_char_ptr_cast(key_locations[i]), len)); + key_locations[i] += len; + } +} + +static void HeapGatherStructVector(Vector &v, const idx_t vcount, const SelectionVector &sel, + data_ptr_t *key_locations) { + // struct must have a validitymask for its fields + auto &child_types = StructType::GetChildTypes(v.GetType()); + const idx_t struct_validitymask_size = (child_types.size() + 7) / 8; + data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; + for (idx_t i = 0; i < vcount; i++) { + // use key_locations as the validitymask, and create struct_key_locations + struct_validitymask_locations[i] = key_locations[i]; + key_locations[i] += struct_validitymask_size; + } + + // now deserialize into the struct vectors + auto &children = StructVector::GetEntries(v); + for (idx_t i = 0; i < child_types.size(); i++) { + RowOperations::HeapGather(*children[i], vcount, sel, i, key_locations, struct_validitymask_locations); + } +} + +static void HeapGatherListVector(Vector &v, const idx_t vcount, const SelectionVector &sel, data_ptr_t *key_locations) { + const auto &validity = FlatVector::Validity(v); + + auto child_type = ListType::GetChildType(v.GetType()); + auto list_data = ListVector::GetData(v); + data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; + + uint64_t entry_offset = ListVector::GetListSize(v); + for (idx_t i = 0; i < vcount; i++) { + const auto col_idx = sel.get_index(i); + if (!validity.RowIsValid(col_idx)) { + continue; + } + // read list length + auto entry_remaining = Load(key_locations[i]); + key_locations[i] += sizeof(uint64_t); + // set list entry attributes + list_data[col_idx].length = entry_remaining; + list_data[col_idx].offset = entry_offset; + // skip over the validity mask + data_ptr_t validitymask_location = key_locations[i]; + idx_t offset_in_byte = 0; + key_locations[i] += (entry_remaining + 7) / 8; + // entry sizes + data_ptr_t var_entry_size_ptr = nullptr; + if (!TypeIsConstantSize(child_type.InternalType())) { + var_entry_size_ptr = key_locations[i]; + key_locations[i] += entry_remaining * sizeof(idx_t); + } + + // now read the list data + while (entry_remaining > 0) { + auto next = MinValue(entry_remaining, (idx_t)STANDARD_VECTOR_SIZE); + + // initialize a new vector to append + Vector append_vector(v.GetType()); + append_vector.SetVectorType(v.GetVectorType()); + + auto &list_vec_to_append = ListVector::GetEntry(append_vector); + + // set validity + //! Since we are constructing the vector, this will always be a flat vector. + auto &append_validity = FlatVector::Validity(list_vec_to_append); + for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { + append_validity.Set(entry_idx, *(validitymask_location) & (1 << offset_in_byte)); + if (++offset_in_byte == 8) { + validitymask_location++; + offset_in_byte = 0; + } + } + + // compute entry sizes and set locations where the list entries are + if (TypeIsConstantSize(child_type.InternalType())) { + // constant size list entries + const idx_t type_size = GetTypeIdSize(child_type.InternalType()); + for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { + list_entry_locations[entry_idx] = key_locations[i]; + key_locations[i] += type_size; + } + } else { + // variable size list entries + for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { + list_entry_locations[entry_idx] = key_locations[i]; + key_locations[i] += Load(var_entry_size_ptr); + var_entry_size_ptr += sizeof(idx_t); + } + } + + // now deserialize and add to listvector + RowOperations::HeapGather(list_vec_to_append, next, *FlatVector::IncrementalSelectionVector(), 0, + list_entry_locations, nullptr); + ListVector::Append(v, list_vec_to_append, next); + + // update for next iteration + entry_remaining -= next; + entry_offset += next; + } + } +} + +void RowOperations::HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, const idx_t &col_no, + data_ptr_t *key_locations, data_ptr_t *validitymask_locations) { + v.SetVectorType(VectorType::FLAT_VECTOR); + + auto &validity = FlatVector::Validity(v); + if (validitymask_locations) { + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); + + for (idx_t i = 0; i < vcount; i++) { + ValidityBytes row_mask(validitymask_locations[i]); + const auto valid = row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry); + const auto col_idx = sel.get_index(i); + validity.Set(col_idx, valid); + } + } + + auto type = v.GetType().InternalType(); + switch (type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::INT16: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::INT32: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::INT64: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::UINT8: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::UINT16: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::UINT32: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::UINT64: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::INT128: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::FLOAT: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::DOUBLE: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::INTERVAL: + TemplatedHeapGather(v, vcount, sel, key_locations); + break; + case PhysicalType::VARCHAR: + HeapGatherStringVector(v, vcount, sel, key_locations); + break; + case PhysicalType::STRUCT: + HeapGatherStructVector(v, vcount, sel, key_locations); + break; + case PhysicalType::LIST: + HeapGatherListVector(v, vcount, sel, key_locations); + break; + default: + throw NotImplementedException("Unimplemented deserialize from row-format"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp b/src/duckdb/src/common/row_operations/row_heap_scatter.cpp new file mode 100644 index 00000000..c51e2b5d --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_heap_scatter.cpp @@ -0,0 +1,407 @@ +#include "duckdb/common/helper.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +using ValidityBytes = TemplatedValidityMask; + +static void ComputeStringEntrySizes(UnifiedVectorFormat &vdata, idx_t entry_sizes[], const idx_t ser_count, + const SelectionVector &sel, const idx_t offset) { + auto strings = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < ser_count; i++) { + auto idx = sel.get_index(i); + auto str_idx = vdata.sel->get_index(idx + offset); + if (vdata.validity.RowIsValid(str_idx)) { + entry_sizes[i] += sizeof(uint32_t) + strings[str_idx].GetSize(); + } + } +} + +static void ComputeStructEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, + const SelectionVector &sel, idx_t offset) { + // obtain child vectors + idx_t num_children; + auto &children = StructVector::GetEntries(v); + num_children = children.size(); + // add struct validitymask size + const idx_t struct_validitymask_size = (num_children + 7) / 8; + for (idx_t i = 0; i < ser_count; i++) { + entry_sizes[i] += struct_validitymask_size; + } + // compute size of child vectors + for (auto &struct_vector : children) { + RowOperations::ComputeEntrySizes(*struct_vector, entry_sizes, vcount, ser_count, sel, offset); + } +} + +static void ComputeListEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, + const SelectionVector &sel, idx_t offset) { + auto list_data = ListVector::GetData(v); + auto &child_vector = ListVector::GetEntry(v); + idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; + for (idx_t i = 0; i < ser_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx + offset); + if (vdata.validity.RowIsValid(source_idx)) { + auto list_entry = list_data[source_idx]; + + // make room for list length, list validitymask + entry_sizes[i] += sizeof(list_entry.length); + entry_sizes[i] += (list_entry.length + 7) / 8; + + // serialize size of each entry (if non-constant size) + if (!TypeIsConstantSize(ListType::GetChildType(v.GetType()).InternalType())) { + entry_sizes[i] += list_entry.length * sizeof(list_entry.length); + } + + // compute size of each the elements in list_entry and sum them + auto entry_remaining = list_entry.length; + auto entry_offset = list_entry.offset; + while (entry_remaining > 0) { + // the list entry can span multiple vectors + auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); + + // compute and add to the total + std::fill_n(list_entry_sizes, next, 0); + RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, + *FlatVector::IncrementalSelectionVector(), entry_offset); + for (idx_t list_idx = 0; list_idx < next; list_idx++) { + entry_sizes[i] += list_entry_sizes[list_idx]; + } + + // update for next iteration + entry_remaining -= next; + entry_offset += next; + } + } + } +} + +void RowOperations::ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, + idx_t ser_count, const SelectionVector &sel, idx_t offset) { + const auto physical_type = v.GetType().InternalType(); + if (TypeIsConstantSize(physical_type)) { + const auto type_size = GetTypeIdSize(physical_type); + for (idx_t i = 0; i < ser_count; i++) { + entry_sizes[i] += type_size; + } + } else { + switch (physical_type) { + case PhysicalType::VARCHAR: + ComputeStringEntrySizes(vdata, entry_sizes, ser_count, sel, offset); + break; + case PhysicalType::STRUCT: + ComputeStructEntrySizes(v, entry_sizes, vcount, ser_count, sel, offset); + break; + case PhysicalType::LIST: + ComputeListEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); + break; + default: + // LCOV_EXCL_START + throw NotImplementedException("Column with variable size type %s cannot be serialized to row-format", + v.GetType().ToString()); + // LCOV_EXCL_STOP + } + } +} + +void RowOperations::ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, + const SelectionVector &sel, idx_t offset) { + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(vcount, vdata); + ComputeEntrySizes(v, vdata, entry_sizes, vcount, ser_count, sel, offset); +} + +template +static void TemplatedHeapScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, idx_t col_idx, + data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { + auto source = UnifiedVectorFormat::GetData(vdata); + if (!validitymask_locations) { + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx + offset); + + auto target = (T *)key_locations[i]; + Store(source[source_idx], data_ptr_cast(target)); + key_locations[i] += sizeof(T); + } + } else { + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + const auto bit = ~(1UL << idx_in_entry); + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx + offset); + + auto target = (T *)key_locations[i]; + Store(source[source_idx], data_ptr_cast(target)); + key_locations[i] += sizeof(T); + + // set the validitymask + if (!vdata.validity.RowIsValid(source_idx)) { + *(validitymask_locations[i] + entry_idx) &= bit; + } + } + } +} + +static void HeapScatterStringVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_idx, + data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(vcount, vdata); + + auto strings = UnifiedVectorFormat::GetData(vdata); + if (!validitymask_locations) { + for (idx_t i = 0; i < ser_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx + offset); + if (vdata.validity.RowIsValid(source_idx)) { + auto &string_entry = strings[source_idx]; + // store string size + Store(string_entry.GetSize(), key_locations[i]); + key_locations[i] += sizeof(uint32_t); + // store the string + memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); + key_locations[i] += string_entry.GetSize(); + } + } + } else { + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + const auto bit = ~(1UL << idx_in_entry); + for (idx_t i = 0; i < ser_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx + offset); + if (vdata.validity.RowIsValid(source_idx)) { + auto &string_entry = strings[source_idx]; + // store string size + Store(string_entry.GetSize(), key_locations[i]); + key_locations[i] += sizeof(uint32_t); + // store the string + memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); + key_locations[i] += string_entry.GetSize(); + } else { + // set the validitymask + *(validitymask_locations[i] + entry_idx) &= bit; + } + } + } +} + +static void HeapScatterStructVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_idx, + data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(vcount, vdata); + + auto &children = StructVector::GetEntries(v); + idx_t num_children = children.size(); + + // the whole struct itself can be NULL + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + const auto bit = ~(1UL << idx_in_entry); + + // struct must have a validitymask for its fields + const idx_t struct_validitymask_size = (num_children + 7) / 8; + data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; + for (idx_t i = 0; i < ser_count; i++) { + // initialize the struct validity mask + struct_validitymask_locations[i] = key_locations[i]; + memset(struct_validitymask_locations[i], -1, struct_validitymask_size); + key_locations[i] += struct_validitymask_size; + + // set whether the whole struct is null + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + if (validitymask_locations && !vdata.validity.RowIsValid(source_idx)) { + *(validitymask_locations[i] + entry_idx) &= bit; + } + } + + // now serialize the struct vectors + for (idx_t i = 0; i < children.size(); i++) { + auto &struct_vector = *children[i]; + RowOperations::HeapScatter(struct_vector, vcount, sel, ser_count, i, key_locations, + struct_validitymask_locations, offset); + } +} + +static void HeapScatterListVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_no, + data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(vcount, vdata); + + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); + + auto list_data = ListVector::GetData(v); + + auto &child_vector = ListVector::GetEntry(v); + + UnifiedVectorFormat list_vdata; + child_vector.ToUnifiedFormat(ListVector::GetListSize(v), list_vdata); + auto child_type = ListType::GetChildType(v.GetType()).InternalType(); + + idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; + data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; + + for (idx_t i = 0; i < ser_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx + offset); + if (!vdata.validity.RowIsValid(source_idx)) { + if (validitymask_locations) { + // set the row validitymask for this column to invalid + ValidityBytes row_mask(validitymask_locations[i]); + row_mask.SetInvalidUnsafe(entry_idx, idx_in_entry); + } + continue; + } + auto list_entry = list_data[source_idx]; + + // store list length + Store(list_entry.length, key_locations[i]); + key_locations[i] += sizeof(list_entry.length); + + // make room for the validitymask + data_ptr_t list_validitymask_location = key_locations[i]; + idx_t entry_offset_in_byte = 0; + idx_t validitymask_size = (list_entry.length + 7) / 8; + memset(list_validitymask_location, -1, validitymask_size); + key_locations[i] += validitymask_size; + + // serialize size of each entry (if non-constant size) + data_ptr_t var_entry_size_ptr = nullptr; + if (!TypeIsConstantSize(child_type)) { + var_entry_size_ptr = key_locations[i]; + key_locations[i] += list_entry.length * sizeof(idx_t); + } + + auto entry_remaining = list_entry.length; + auto entry_offset = list_entry.offset; + while (entry_remaining > 0) { + // the list entry can span multiple vectors + auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); + + // serialize list validity + for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { + auto list_idx = list_vdata.sel->get_index(entry_idx + entry_offset); + if (!list_vdata.validity.RowIsValid(list_idx)) { + *(list_validitymask_location) &= ~(1UL << entry_offset_in_byte); + } + if (++entry_offset_in_byte == 8) { + list_validitymask_location++; + entry_offset_in_byte = 0; + } + } + + if (TypeIsConstantSize(child_type)) { + // constant size list entries: set list entry locations + const idx_t type_size = GetTypeIdSize(child_type); + for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { + list_entry_locations[entry_idx] = key_locations[i]; + key_locations[i] += type_size; + } + } else { + // variable size list entries: compute entry sizes and set list entry locations + std::fill_n(list_entry_sizes, next, 0); + RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, + *FlatVector::IncrementalSelectionVector(), entry_offset); + for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { + list_entry_locations[entry_idx] = key_locations[i]; + key_locations[i] += list_entry_sizes[entry_idx]; + Store(list_entry_sizes[entry_idx], var_entry_size_ptr); + var_entry_size_ptr += sizeof(idx_t); + } + } + + // now serialize to the locations + RowOperations::HeapScatter(child_vector, ListVector::GetListSize(v), + *FlatVector::IncrementalSelectionVector(), next, 0, list_entry_locations, + nullptr, entry_offset); + + // update for next iteration + entry_remaining -= next; + entry_offset += next; + } + } +} + +void RowOperations::HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_idx, + data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset) { + if (TypeIsConstantSize(v.GetType().InternalType())) { + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(vcount, vdata); + RowOperations::HeapScatterVData(vdata, v.GetType().InternalType(), sel, ser_count, col_idx, key_locations, + validitymask_locations, offset); + } else { + switch (v.GetType().InternalType()) { + case PhysicalType::VARCHAR: + HeapScatterStringVector(v, vcount, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::STRUCT: + HeapScatterStructVector(v, vcount, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::LIST: + HeapScatterListVector(v, vcount, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + default: + // LCOV_EXCL_START + throw NotImplementedException("Serialization of variable length vector with type %s", + v.GetType().ToString()); + // LCOV_EXCL_STOP + } + } +} + +void RowOperations::HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, + idx_t ser_count, idx_t col_idx, data_ptr_t *key_locations, + data_ptr_t *validitymask_locations, idx_t offset) { + switch (type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::INT16: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::INT32: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::INT64: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::UINT8: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::UINT16: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::UINT32: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::UINT64: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::INT128: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::FLOAT: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::DOUBLE: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + case PhysicalType::INTERVAL: + TemplatedHeapScatter(vdata, sel, ser_count, col_idx, key_locations, validitymask_locations, offset); + break; + default: + throw NotImplementedException("FIXME: Serialize to of constant type column to row-format"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_matcher.cpp b/src/duckdb/src/common/row_operations/row_matcher.cpp new file mode 100644 index 00000000..95867096 --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_matcher.cpp @@ -0,0 +1,375 @@ +#include "duckdb/common/row_operations/row_matcher.hpp" + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/row/tuple_data_collection.hpp" + +namespace duckdb { + +using ValidityBytes = TupleDataLayout::ValidityBytes; + +template +static idx_t TemplatedMatch(Vector &, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, + const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, const idx_t col_idx, + const vector &, SelectionVector *no_match_sel, idx_t &no_match_count) { + using COMPARISON_OP = ComparisonOperationWrapper; + + // LHS + const auto &lhs_sel = *lhs_format.unified.sel; + const auto lhs_data = UnifiedVectorFormat::GetData(lhs_format.unified); + const auto &lhs_validity = lhs_format.unified.validity; + + // RHS + const auto rhs_locations = FlatVector::GetData(rhs_row_locations); + const auto rhs_offset_in_row = rhs_layout.GetOffsets()[col_idx]; + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + idx_t match_count = 0; + for (idx_t i = 0; i < count; i++) { + const auto idx = sel.get_index(i); + + const auto lhs_idx = lhs_sel.get_index(idx); + const auto lhs_null = lhs_validity.AllValid() ? false : !lhs_validity.RowIsValid(lhs_idx); + + const auto &rhs_location = rhs_locations[idx]; + const ValidityBytes rhs_mask(rhs_location); + const auto rhs_null = !rhs_mask.RowIsValid(rhs_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry); + + if (COMPARISON_OP::template Operation(lhs_data[lhs_idx], Load(rhs_location + rhs_offset_in_row), lhs_null, + rhs_null)) { + sel.set_index(match_count++, idx); + } else if (NO_MATCH_SEL) { + no_match_sel->set_index(no_match_count++, idx); + } + } + return match_count; +} + +template +static idx_t StructMatchEquality(Vector &lhs_vector, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, + const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, + const idx_t col_idx, const vector &child_functions, + SelectionVector *no_match_sel, idx_t &no_match_count) { + using COMPARISON_OP = ComparisonOperationWrapper; + + // LHS + const auto &lhs_sel = *lhs_format.unified.sel; + const auto &lhs_validity = lhs_format.unified.validity; + + // RHS + const auto rhs_locations = FlatVector::GetData(rhs_row_locations); + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + idx_t match_count = 0; + for (idx_t i = 0; i < count; i++) { + const auto idx = sel.get_index(i); + + const auto lhs_idx = lhs_sel.get_index(idx); + const auto lhs_null = lhs_validity.AllValid() ? false : !lhs_validity.RowIsValid(lhs_idx); + + const auto &rhs_location = rhs_locations[idx]; + const ValidityBytes rhs_mask(rhs_location); + const auto rhs_null = !rhs_mask.RowIsValid(rhs_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry); + + // For structs there is no value to compare, here we match NULLs and let recursion do the rest + // So we use the comparison only if rhs or LHS is NULL and COMPARE_NULL is true + if (!(lhs_null || rhs_null) || + (COMPARISON_OP::COMPARE_NULL && COMPARISON_OP::template Operation(0, 0, lhs_null, rhs_null))) { + sel.set_index(match_count++, idx); + } else if (NO_MATCH_SEL) { + no_match_sel->set_index(no_match_count++, idx); + } + } + + // Create a Vector of pointers to the start of the TupleDataLayout of the STRUCT + Vector rhs_struct_row_locations(LogicalType::POINTER); + const auto rhs_offset_in_row = rhs_layout.GetOffsets()[col_idx]; + auto rhs_struct_locations = FlatVector::GetData(rhs_struct_row_locations); + for (idx_t i = 0; i < match_count; i++) { + const auto idx = sel.get_index(i); + rhs_struct_locations[idx] = rhs_locations[idx] + rhs_offset_in_row; + } + + // Get the struct layout and struct entries + const auto &rhs_struct_layout = rhs_layout.GetStructLayout(col_idx); + auto &lhs_struct_vectors = StructVector::GetEntries(lhs_vector); + D_ASSERT(rhs_struct_layout.ColumnCount() == lhs_struct_vectors.size()); + + for (idx_t struct_col_idx = 0; struct_col_idx < rhs_struct_layout.ColumnCount(); struct_col_idx++) { + auto &lhs_struct_vector = *lhs_struct_vectors[struct_col_idx]; + auto &lhs_struct_format = lhs_format.children[struct_col_idx]; + const auto &child_function = child_functions[struct_col_idx]; + match_count = child_function.function(lhs_struct_vector, lhs_struct_format, sel, match_count, rhs_struct_layout, + rhs_struct_row_locations, struct_col_idx, child_function.child_functions, + no_match_sel, no_match_count); + } + + return match_count; +} + +template +static idx_t SelectComparison(Vector &, Vector &, const SelectionVector &, idx_t, SelectionVector *, + SelectionVector *) { + throw NotImplementedException("Unsupported list comparison operand for RowMatcher::GetMatchFunction"); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::NestedEquals(left, right, sel, count, true_sel, false_sel); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::NestedNotEquals(left, right, sel, count, true_sel, false_sel); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::DistinctFrom(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::NotDistinctFrom(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::DistinctLessThan(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::DistinctLessThanEquals(left, right, &sel, count, true_sel, false_sel); +} + +template +static idx_t GenericNestedMatch(Vector &lhs_vector, const TupleDataVectorFormat &, SelectionVector &sel, + const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, + const idx_t col_idx, const vector &, SelectionVector *no_match_sel, + idx_t &no_match_count) { + const auto &type = rhs_layout.GetTypes()[col_idx]; + + // Gather a dense Vector containing the column values being matched + Vector key(type); + const auto gather_function = TupleDataCollection::GetGatherFunction(type); + gather_function.function(rhs_layout, rhs_row_locations, col_idx, sel, count, key, + *FlatVector::IncrementalSelectionVector(), key, gather_function.child_functions); + + // Densify the input column + Vector sliced(lhs_vector, sel, count); + + if (NO_MATCH_SEL) { + SelectionVector no_match_sel_offset(no_match_sel->data() + no_match_count); + auto match_count = SelectComparison(sliced, key, sel, count, &sel, &no_match_sel_offset); + no_match_count += count - match_count; + return match_count; + } + return SelectComparison(sliced, key, sel, count, &sel, nullptr); +} + +void RowMatcher::Initialize(const bool no_match_sel, const TupleDataLayout &layout, const Predicates &predicates) { + match_functions.reserve(predicates.size()); + for (idx_t col_idx = 0; col_idx < predicates.size(); col_idx++) { + match_functions.push_back(GetMatchFunction(no_match_sel, layout.GetTypes()[col_idx], predicates[col_idx])); + } +} + +idx_t RowMatcher::Match(DataChunk &lhs, const vector &lhs_formats, SelectionVector &sel, + idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, + SelectionVector *no_match_sel, idx_t &no_match_count) { + D_ASSERT(!match_functions.empty()); + for (idx_t col_idx = 0; col_idx < match_functions.size(); col_idx++) { + const auto &match_function = match_functions[col_idx]; + count = + match_function.function(lhs.data[col_idx], lhs_formats[col_idx], sel, count, rhs_layout, rhs_row_locations, + col_idx, match_function.child_functions, no_match_sel, no_match_count); + } + return count; +} + +MatchFunction RowMatcher::GetMatchFunction(const bool no_match_sel, const LogicalType &type, + const ExpressionType predicate) { + return no_match_sel ? GetMatchFunction(type, predicate) : GetMatchFunction(type, predicate); +} + +template +MatchFunction RowMatcher::GetMatchFunction(const LogicalType &type, const ExpressionType predicate) { + switch (type.InternalType()) { + case PhysicalType::BOOL: + return GetMatchFunction(predicate); + case PhysicalType::INT8: + return GetMatchFunction(predicate); + case PhysicalType::INT16: + return GetMatchFunction(predicate); + case PhysicalType::INT32: + return GetMatchFunction(predicate); + case PhysicalType::INT64: + return GetMatchFunction(predicate); + case PhysicalType::INT128: + return GetMatchFunction(predicate); + case PhysicalType::UINT8: + return GetMatchFunction(predicate); + case PhysicalType::UINT16: + return GetMatchFunction(predicate); + case PhysicalType::UINT32: + return GetMatchFunction(predicate); + case PhysicalType::UINT64: + return GetMatchFunction(predicate); + case PhysicalType::FLOAT: + return GetMatchFunction(predicate); + case PhysicalType::DOUBLE: + return GetMatchFunction(predicate); + case PhysicalType::INTERVAL: + return GetMatchFunction(predicate); + case PhysicalType::VARCHAR: + return GetMatchFunction(predicate); + case PhysicalType::STRUCT: + return GetStructMatchFunction(type, predicate); + case PhysicalType::LIST: + return GetListMatchFunction(predicate); + default: + throw InternalException("Unsupported PhysicalType for RowMatcher::GetMatchFunction: %s", + EnumUtil::ToString(type.InternalType())); + } +} + +template +MatchFunction RowMatcher::GetMatchFunction(const ExpressionType predicate) { + MatchFunction result; + switch (predicate) { + case ExpressionType::COMPARE_EQUAL: + result.function = TemplatedMatch; + break; + case ExpressionType::COMPARE_NOTEQUAL: + result.function = TemplatedMatch; + break; + case ExpressionType::COMPARE_DISTINCT_FROM: + result.function = TemplatedMatch; + break; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + result.function = TemplatedMatch; + break; + case ExpressionType::COMPARE_GREATERTHAN: + result.function = TemplatedMatch; + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + result.function = TemplatedMatch; + break; + case ExpressionType::COMPARE_LESSTHAN: + result.function = TemplatedMatch; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + result.function = TemplatedMatch; + break; + default: + throw InternalException("Unsupported ExpressionType for RowMatcher::GetMatchFunction: %s", + EnumUtil::ToString(predicate)); + } + return result; +} + +template +MatchFunction RowMatcher::GetStructMatchFunction(const LogicalType &type, const ExpressionType predicate) { + // We perform equality conditions like it's just a row, but we cannot perform inequality conditions like a row, + // because for equality conditions we need to always loop through all columns, but for inequality conditions, + // we need to find the first inequality, so the loop looks very different + MatchFunction result; + ExpressionType child_predicate = predicate; + switch (predicate) { + case ExpressionType::COMPARE_EQUAL: + result.function = StructMatchEquality; + child_predicate = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + break; + case ExpressionType::COMPARE_NOTEQUAL: + result.function = GenericNestedMatch; + return result; + case ExpressionType::COMPARE_DISTINCT_FROM: + result.function = GenericNestedMatch; + return result; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + result.function = StructMatchEquality; + break; + case ExpressionType::COMPARE_GREATERTHAN: + result.function = GenericNestedMatch; + return result; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + result.function = GenericNestedMatch; + return result; + case ExpressionType::COMPARE_LESSTHAN: + result.function = GenericNestedMatch; + return result; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + result.function = GenericNestedMatch; + return result; + default: + throw InternalException("Unsupported ExpressionType for RowMatcher::GetStructMatchFunction: %s", + EnumUtil::ToString(predicate)); + } + + result.child_functions.reserve(StructType::GetChildCount(type)); + for (const auto &child_type : StructType::GetChildTypes(type)) { + result.child_functions.push_back(GetMatchFunction(child_type.second, child_predicate)); + } + + return result; +} + +template +MatchFunction RowMatcher::GetListMatchFunction(const ExpressionType predicate) { + MatchFunction result; + switch (predicate) { + case ExpressionType::COMPARE_EQUAL: + result.function = GenericNestedMatch; + break; + case ExpressionType::COMPARE_NOTEQUAL: + result.function = GenericNestedMatch; + break; + case ExpressionType::COMPARE_DISTINCT_FROM: + result.function = GenericNestedMatch; + break; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + result.function = GenericNestedMatch; + break; + case ExpressionType::COMPARE_GREATERTHAN: + result.function = GenericNestedMatch; + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + result.function = GenericNestedMatch; + break; + case ExpressionType::COMPARE_LESSTHAN: + result.function = GenericNestedMatch; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + result.function = GenericNestedMatch; + break; + default: + throw InternalException("Unsupported ExpressionType for RowMatcher::GetListMatchFunction: %s", + EnumUtil::ToString(predicate)); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp new file mode 100644 index 00000000..1a8acaec --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp @@ -0,0 +1,271 @@ +#include "duckdb/common/helper.hpp" +#include "duckdb/common/radix.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +template +void TemplatedRadixScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, + data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, + const idx_t offset) { + auto source = UnifiedVectorFormat::GetData(vdata); + if (has_null) { + auto &validity = vdata.validity; + const data_t valid = nulls_first ? 1 : 0; + const data_t invalid = 1 - valid; + + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + // write validity and according value + if (validity.RowIsValid(source_idx)) { + key_locations[i][0] = valid; + Radix::EncodeData(key_locations[i] + 1, source[source_idx]); + // invert bits if desc + if (desc) { + for (idx_t s = 1; s < sizeof(T) + 1; s++) { + *(key_locations[i] + s) = ~*(key_locations[i] + s); + } + } + } else { + key_locations[i][0] = invalid; + memset(key_locations[i] + 1, '\0', sizeof(T)); + } + key_locations[i] += sizeof(T) + 1; + } + } else { + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + // write value + Radix::EncodeData(key_locations[i], source[source_idx]); + // invert bits if desc + if (desc) { + for (idx_t s = 0; s < sizeof(T); s++) { + *(key_locations[i] + s) = ~*(key_locations[i] + s); + } + } + key_locations[i] += sizeof(T); + } + } +} + +void RadixScatterStringVector(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, + data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, + const idx_t prefix_len, idx_t offset) { + auto source = UnifiedVectorFormat::GetData(vdata); + if (has_null) { + auto &validity = vdata.validity; + const data_t valid = nulls_first ? 1 : 0; + const data_t invalid = 1 - valid; + + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + // write validity and according value + if (validity.RowIsValid(source_idx)) { + key_locations[i][0] = valid; + Radix::EncodeStringDataPrefix(key_locations[i] + 1, source[source_idx], prefix_len); + // invert bits if desc + if (desc) { + for (idx_t s = 1; s < prefix_len + 1; s++) { + *(key_locations[i] + s) = ~*(key_locations[i] + s); + } + } + } else { + key_locations[i][0] = invalid; + memset(key_locations[i] + 1, '\0', prefix_len); + } + key_locations[i] += prefix_len + 1; + } + } else { + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + // write value + Radix::EncodeStringDataPrefix(key_locations[i], source[source_idx], prefix_len); + // invert bits if desc + if (desc) { + for (idx_t s = 0; s < prefix_len; s++) { + *(key_locations[i] + s) = ~*(key_locations[i] + s); + } + } + key_locations[i] += prefix_len; + } + } +} + +void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, + data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, + const idx_t prefix_len, const idx_t width, const idx_t offset) { + auto list_data = ListVector::GetData(v); + auto &child_vector = ListVector::GetEntry(v); + auto list_size = ListVector::GetListSize(v); + child_vector.Flatten(list_size); + + // serialize null values + if (has_null) { + auto &validity = vdata.validity; + const data_t valid = nulls_first ? 1 : 0; + const data_t invalid = 1 - valid; + + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + data_ptr_t key_location = key_locations[i] + 1; + // write validity and according value + if (validity.RowIsValid(source_idx)) { + key_locations[i][0] = valid; + key_locations[i]++; + auto &list_entry = list_data[source_idx]; + if (list_entry.length > 0) { + // denote that the list is not empty with a 1 + key_locations[i][0] = 1; + key_locations[i]++; + RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, + key_locations + i, false, true, false, prefix_len, width - 1, + list_entry.offset); + } else { + // denote that the list is empty with a 0 + key_locations[i][0] = 0; + key_locations[i]++; + memset(key_locations[i], '\0', width - 2); + } + // invert bits if desc + if (desc) { + for (idx_t s = 0; s < width - 1; s++) { + *(key_location + s) = ~*(key_location + s); + } + } + } else { + key_locations[i][0] = invalid; + memset(key_locations[i] + 1, '\0', width - 1); + key_locations[i] += width; + } + } + } else { + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + auto &list_entry = list_data[source_idx]; + data_ptr_t key_location = key_locations[i]; + if (list_entry.length > 0) { + // denote that the list is not empty with a 1 + key_locations[i][0] = 1; + key_locations[i]++; + RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, + key_locations + i, false, true, false, prefix_len, width - 1, + list_entry.offset); + } else { + // denote that the list is empty with a 0 + key_locations[i][0] = 0; + key_locations[i]++; + memset(key_locations[i], '\0', width - 1); + } + // invert bits if desc + if (desc) { + for (idx_t s = 0; s < width; s++) { + *(key_location + s) = ~*(key_location + s); + } + } + } + } +} + +void RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, + idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, + const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { + // serialize null values + if (has_null) { + auto &validity = vdata.validity; + const data_t valid = nulls_first ? 1 : 0; + const data_t invalid = 1 - valid; + + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + // write validity and according value + if (validity.RowIsValid(source_idx)) { + key_locations[i][0] = valid; + } else { + key_locations[i][0] = invalid; + } + key_locations[i]++; + } + width--; + } + // serialize the struct + auto &child_vector = *StructVector::GetEntries(v)[0]; + RowOperations::RadixScatter(child_vector, vcount, *FlatVector::IncrementalSelectionVector(), add_count, + key_locations, false, true, false, prefix_len, width, offset); + // invert bits if desc + if (desc) { + for (idx_t i = 0; i < add_count; i++) { + for (idx_t s = 0; s < width; s++) { + *(key_locations[i] - width + s) = ~*(key_locations[i] - width + s); + } + } + } +} + +void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, + data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first, + idx_t prefix_len, idx_t width, idx_t offset) { + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(vcount, vdata); + switch (v.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::INT16: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::INT32: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::INT64: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::UINT8: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::UINT16: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::UINT32: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::UINT64: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::INT128: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::FLOAT: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::DOUBLE: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::INTERVAL: + TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); + break; + case PhysicalType::VARCHAR: + RadixScatterStringVector(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, offset); + break; + case PhysicalType::LIST: + RadixScatterListVector(v, vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, width, + offset); + break; + case PhysicalType::STRUCT: + RadixScatterStructVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, + prefix_len, width, offset); + break; + default: + throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_scatter.cpp b/src/duckdb/src/common/row_operations/row_scatter.cpp new file mode 100644 index 00000000..c5104d4f --- /dev/null +++ b/src/duckdb/src/common/row_operations/row_scatter.cpp @@ -0,0 +1,228 @@ +//===--------------------------------------------------------------------===// +// row_scatter.cpp +// Description: This file contains the implementation of the row scattering +// operators +//===--------------------------------------------------------------------===// + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/row/row_data_collection.hpp" +#include "duckdb/common/types/row/row_layout.hpp" +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +using ValidityBytes = RowLayout::ValidityBytes; + +template +static void TemplatedScatter(UnifiedVectorFormat &col, Vector &rows, const SelectionVector &sel, const idx_t count, + const idx_t col_offset, const idx_t col_no) { + auto data = UnifiedVectorFormat::GetData(col); + auto ptrs = FlatVector::GetData(rows); + + if (!col.validity.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto col_idx = col.sel->get_index(idx); + auto row = ptrs[idx]; + + auto isnull = !col.validity.RowIsValid(col_idx); + T store_value = isnull ? NullValue() : data[col_idx]; + Store(store_value, row + col_offset); + if (isnull) { + ValidityBytes col_mask(ptrs[idx]); + col_mask.SetInvalidUnsafe(col_no); + } + } + } else { + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto col_idx = col.sel->get_index(idx); + auto row = ptrs[idx]; + + Store(data[col_idx], row + col_offset); + } + } +} + +static void ComputeStringEntrySizes(const UnifiedVectorFormat &col, idx_t entry_sizes[], const SelectionVector &sel, + const idx_t count, const idx_t offset = 0) { + auto data = UnifiedVectorFormat::GetData(col); + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto col_idx = col.sel->get_index(idx) + offset; + const auto &str = data[col_idx]; + if (col.validity.RowIsValid(col_idx) && !str.IsInlined()) { + entry_sizes[i] += str.GetSize(); + } + } +} + +static void ScatterStringVector(UnifiedVectorFormat &col, Vector &rows, data_ptr_t str_locations[], + const SelectionVector &sel, const idx_t count, const idx_t col_offset, + const idx_t col_no) { + auto string_data = UnifiedVectorFormat::GetData(col); + auto ptrs = FlatVector::GetData(rows); + + // Write out zero length to avoid swizzling problems. + const string_t null(nullptr, 0); + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto col_idx = col.sel->get_index(idx); + auto row = ptrs[idx]; + if (!col.validity.RowIsValid(col_idx)) { + ValidityBytes col_mask(row); + col_mask.SetInvalidUnsafe(col_no); + Store(null, row + col_offset); + } else if (string_data[col_idx].IsInlined()) { + Store(string_data[col_idx], row + col_offset); + } else { + const auto &str = string_data[col_idx]; + string_t inserted(const_char_ptr_cast(str_locations[i]), str.GetSize()); + memcpy(inserted.GetDataWriteable(), str.GetData(), str.GetSize()); + str_locations[i] += str.GetSize(); + inserted.Finalize(); + Store(inserted, row + col_offset); + } + } +} + +static void ScatterNestedVector(Vector &vec, UnifiedVectorFormat &col, Vector &rows, data_ptr_t data_locations[], + const SelectionVector &sel, const idx_t count, const idx_t col_offset, + const idx_t col_no, const idx_t vcount) { + // Store pointers to the data in the row + // Do this first because SerializeVector destroys the locations + auto ptrs = FlatVector::GetData(rows); + data_ptr_t validitymask_locations[STANDARD_VECTOR_SIZE]; + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto row = ptrs[idx]; + validitymask_locations[i] = row; + + Store(data_locations[i], row + col_offset); + } + + // Serialise the data + RowOperations::HeapScatter(vec, vcount, sel, count, col_no, data_locations, validitymask_locations); +} + +void RowOperations::Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, + RowDataCollection &string_heap, const SelectionVector &sel, idx_t count) { + if (count == 0) { + return; + } + + // Set the validity mask for each row before inserting data + auto ptrs = FlatVector::GetData(rows); + for (idx_t i = 0; i < count; ++i) { + auto row_idx = sel.get_index(i); + auto row = ptrs[row_idx]; + ValidityBytes(row).SetAllValid(layout.ColumnCount()); + } + + const auto vcount = columns.size(); + auto &offsets = layout.GetOffsets(); + auto &types = layout.GetTypes(); + + // Compute the entry size of the variable size columns + vector handles; + data_ptr_t data_locations[STANDARD_VECTOR_SIZE]; + if (!layout.AllConstant()) { + idx_t entry_sizes[STANDARD_VECTOR_SIZE]; + std::fill_n(entry_sizes, count, sizeof(uint32_t)); + for (idx_t col_no = 0; col_no < types.size(); col_no++) { + if (TypeIsConstantSize(types[col_no].InternalType())) { + continue; + } + + auto &vec = columns.data[col_no]; + auto &col = col_data[col_no]; + switch (types[col_no].InternalType()) { + case PhysicalType::VARCHAR: + ComputeStringEntrySizes(col, entry_sizes, sel, count); + break; + case PhysicalType::LIST: + case PhysicalType::STRUCT: + RowOperations::ComputeEntrySizes(vec, col, entry_sizes, vcount, count, sel); + break; + default: + throw InternalException("Unsupported type for RowOperations::Scatter"); + } + } + + // Build out the buffer space + handles = string_heap.Build(count, data_locations, entry_sizes); + + // Serialize information that is needed for swizzling if the computation goes out-of-core + const idx_t heap_pointer_offset = layout.GetHeapOffset(); + for (idx_t i = 0; i < count; i++) { + auto row_idx = sel.get_index(i); + auto row = ptrs[row_idx]; + // Pointer to this row in the heap block + Store(data_locations[i], row + heap_pointer_offset); + // Row size is stored in the heap in front of each row + Store(entry_sizes[i], data_locations[i]); + data_locations[i] += sizeof(uint32_t); + } + } + + for (idx_t col_no = 0; col_no < types.size(); col_no++) { + auto &vec = columns.data[col_no]; + auto &col = col_data[col_no]; + auto col_offset = offsets[col_no]; + + switch (types[col_no].InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::INT16: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::INT32: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::INT64: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::UINT8: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::UINT16: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::UINT32: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::UINT64: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::INT128: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::FLOAT: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::DOUBLE: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::INTERVAL: + TemplatedScatter(col, rows, sel, count, col_offset, col_no); + break; + case PhysicalType::VARCHAR: + ScatterStringVector(col, rows, data_locations, sel, count, col_offset, col_no); + break; + case PhysicalType::LIST: + case PhysicalType::STRUCT: + ScatterNestedVector(vec, col, rows, data_locations, sel, count, col_offset, col_no, vcount); + break; + default: + throw InternalException("Unsupported type for RowOperations::Scatter"); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/binary_deserializer.cpp b/src/duckdb/src/common/serializer/binary_deserializer.cpp new file mode 100644 index 00000000..86d9638b --- /dev/null +++ b/src/duckdb/src/common/serializer/binary_deserializer.cpp @@ -0,0 +1,133 @@ +#include "duckdb/common/serializer/binary_deserializer.hpp" + +namespace duckdb { + +//------------------------------------------------------------------------- +// Nested Type Hooks +//------------------------------------------------------------------------- +void BinaryDeserializer::OnPropertyBegin(const field_id_t field_id, const char *) { + auto field = NextField(); + if (field != field_id) { + throw InternalException("Failed to deserialize: field id mismatch, expected: %d, got: %d", field_id, field); + } +} + +void BinaryDeserializer::OnPropertyEnd() { +} + +bool BinaryDeserializer::OnOptionalPropertyBegin(const field_id_t field_id, const char *s) { + auto next_field = PeekField(); + auto present = next_field == field_id; + if (present) { + ConsumeField(); + } + return present; +} + +void BinaryDeserializer::OnOptionalPropertyEnd(bool present) { +} + +void BinaryDeserializer::OnObjectBegin() { + nesting_level++; +} + +void BinaryDeserializer::OnObjectEnd() { + auto next_field = NextField(); + if (next_field != MESSAGE_TERMINATOR_FIELD_ID) { + throw InternalException("Failed to deserialize: expected end of object, but found field id: %d", next_field); + } + nesting_level--; +} + +idx_t BinaryDeserializer::OnListBegin() { + return VarIntDecode(); +} + +void BinaryDeserializer::OnListEnd() { +} + +bool BinaryDeserializer::OnNullableBegin() { + return ReadBool(); +} + +void BinaryDeserializer::OnNullableEnd() { +} + +//------------------------------------------------------------------------- +// Primitive Types +//------------------------------------------------------------------------- +bool BinaryDeserializer::ReadBool() { + return static_cast(ReadPrimitive()); +} + +char BinaryDeserializer::ReadChar() { + return ReadPrimitive(); +} + +int8_t BinaryDeserializer::ReadSignedInt8() { + return VarIntDecode(); +} + +uint8_t BinaryDeserializer::ReadUnsignedInt8() { + return VarIntDecode(); +} + +int16_t BinaryDeserializer::ReadSignedInt16() { + return VarIntDecode(); +} + +uint16_t BinaryDeserializer::ReadUnsignedInt16() { + return VarIntDecode(); +} + +int32_t BinaryDeserializer::ReadSignedInt32() { + return VarIntDecode(); +} + +uint32_t BinaryDeserializer::ReadUnsignedInt32() { + return VarIntDecode(); +} + +int64_t BinaryDeserializer::ReadSignedInt64() { + return VarIntDecode(); +} + +uint64_t BinaryDeserializer::ReadUnsignedInt64() { + return VarIntDecode(); +} + +float BinaryDeserializer::ReadFloat() { + auto value = ReadPrimitive(); + return value; +} + +double BinaryDeserializer::ReadDouble() { + auto value = ReadPrimitive(); + return value; +} + +string BinaryDeserializer::ReadString() { + auto len = VarIntDecode(); + if (len == 0) { + return string(); + } + auto buffer = make_unsafe_uniq_array(len); + ReadData(buffer.get(), len); + return string(const_char_ptr_cast(buffer.get()), len); +} + +hugeint_t BinaryDeserializer::ReadHugeInt() { + auto upper = VarIntDecode(); + auto lower = VarIntDecode(); + return hugeint_t(upper, lower); +} + +void BinaryDeserializer::ReadDataPtr(data_ptr_t &ptr_p, idx_t count) { + auto len = VarIntDecode(); + if (len != count) { + throw SerializationException("Tried to read blob of %d size, but only %d elements are available", count, len); + } + ReadData(ptr_p, count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/binary_serializer.cpp b/src/duckdb/src/common/serializer/binary_serializer.cpp new file mode 100644 index 00000000..67f31b56 --- /dev/null +++ b/src/duckdb/src/common/serializer/binary_serializer.cpp @@ -0,0 +1,170 @@ +#include "duckdb/common/serializer/binary_serializer.hpp" + +#ifdef DEBUG +#include "duckdb/common/string_util.hpp" +#endif + +namespace duckdb { + +void BinarySerializer::OnPropertyBegin(const field_id_t field_id, const char *tag) { + // Just write the field id straight up + Write(field_id); +#ifdef DEBUG + // First of check that we are inside an object + if (debug_stack.empty()) { + throw InternalException("OnPropertyBegin called outside of object"); + } + + // Check that the tag is unique + auto &state = debug_stack.back(); + auto &seen_field_ids = state.seen_field_ids; + auto &seen_field_tags = state.seen_field_tags; + auto &seen_fields = state.seen_fields; + + if (seen_field_ids.find(field_id) != seen_field_ids.end() || seen_field_tags.find(tag) != seen_field_tags.end()) { + string all_fields; + for (auto &field : seen_fields) { + all_fields += StringUtil::Format("\"%s\":%d ", field.first, field.second); + } + throw InternalException("Duplicate field id/tag in field: \"%s\":%d, other fields: %s", tag, field_id, + all_fields); + } + + seen_field_ids.insert(field_id); + seen_field_tags.insert(tag); + seen_fields.emplace_back(tag, field_id); +#else + (void)tag; +#endif +} + +void BinarySerializer::OnPropertyEnd() { + // Nothing to do here +} + +void BinarySerializer::OnOptionalPropertyBegin(const field_id_t field_id, const char *tag, bool present) { + // Dont write anything at all if the property is not present + if (present) { + OnPropertyBegin(field_id, tag); + } +} + +void BinarySerializer::OnOptionalPropertyEnd(bool present) { + // Nothing to do here +} + +//------------------------------------------------------------------------- +// Nested Type Hooks +//------------------------------------------------------------------------- +void BinarySerializer::OnObjectBegin() { +#ifdef DEBUG + debug_stack.emplace_back(); +#endif +} + +void BinarySerializer::OnObjectEnd() { +#ifdef DEBUG + debug_stack.pop_back(); +#endif + // Write object terminator + Write(MESSAGE_TERMINATOR_FIELD_ID); +} + +void BinarySerializer::OnListBegin(idx_t count) { + VarIntEncode(count); +} + +void BinarySerializer::OnListEnd() { +} + +void BinarySerializer::OnNullableBegin(bool present) { + WriteValue(present); +} + +void BinarySerializer::OnNullableEnd() { +} + +//------------------------------------------------------------------------- +// Primitive Types +//------------------------------------------------------------------------- +void BinarySerializer::WriteNull() { + // This should never be called, optional writes should be handled by OnOptionalBegin +} + +void BinarySerializer::WriteValue(bool value) { + Write(value); +} + +void BinarySerializer::WriteValue(uint8_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(char value) { + Write(value); +} + +void BinarySerializer::WriteValue(int8_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(uint16_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(int16_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(uint32_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(int32_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(uint64_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(int64_t value) { + VarIntEncode(value); +} + +void BinarySerializer::WriteValue(hugeint_t value) { + VarIntEncode(value.upper); + VarIntEncode(value.lower); +} + +void BinarySerializer::WriteValue(float value) { + Write(value); +} + +void BinarySerializer::WriteValue(double value) { + Write(value); +} + +void BinarySerializer::WriteValue(const string &value) { + uint32_t len = value.length(); + VarIntEncode(len); + WriteData(value.c_str(), len); +} + +void BinarySerializer::WriteValue(const string_t value) { + uint32_t len = value.GetSize(); + VarIntEncode(len); + WriteData(value.GetDataUnsafe(), len); +} + +void BinarySerializer::WriteValue(const char *value) { + uint32_t len = strlen(value); + VarIntEncode(len); + WriteData(value, len); +} + +void BinarySerializer::WriteDataPtr(const_data_ptr_t ptr, idx_t count) { + VarIntEncode(static_cast(count)); + WriteData(ptr, count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/buffered_file_reader.cpp b/src/duckdb/src/common/serializer/buffered_file_reader.cpp new file mode 100644 index 00000000..ce76c354 --- /dev/null +++ b/src/duckdb/src/common/serializer/buffered_file_reader.cpp @@ -0,0 +1,58 @@ +#include "duckdb/common/serializer/buffered_file_reader.hpp" +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/exception.hpp" + +#include +#include + +namespace duckdb { + +BufferedFileReader::BufferedFileReader(FileSystem &fs, const char *path, FileLockType lock_type, + optional_ptr opener) + : fs(fs), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), read_data(0), total_read(0) { + handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ, lock_type, FileSystem::DEFAULT_COMPRESSION, opener.get()); + file_size = fs.GetFileSize(*handle); +} + +void BufferedFileReader::ReadData(data_ptr_t target_buffer, uint64_t read_size) { + // first copy anything we can from the buffer + data_ptr_t end_ptr = target_buffer + read_size; + while (true) { + idx_t to_read = MinValue(end_ptr - target_buffer, read_data - offset); + if (to_read > 0) { + memcpy(target_buffer, data.get() + offset, to_read); + offset += to_read; + target_buffer += to_read; + } + if (target_buffer < end_ptr) { + D_ASSERT(offset == read_data); + total_read += read_data; + // did not finish reading yet but exhausted buffer + // read data into buffer + offset = 0; + read_data = fs.Read(*handle, data.get(), FILE_BUFFER_SIZE); + if (read_data == 0) { + throw SerializationException("not enough data in file to deserialize result"); + } + } else { + return; + } + } +} + +bool BufferedFileReader::Finished() { + return total_read + offset == file_size; +} + +void BufferedFileReader::Seek(uint64_t location) { + D_ASSERT(location <= file_size); + handle->Seek(location); + total_read = location; + read_data = offset = 0; +} + +uint64_t BufferedFileReader::CurrentOffset() { + return total_read + offset; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/buffered_file_writer.cpp b/src/duckdb/src/common/serializer/buffered_file_writer.cpp new file mode 100644 index 00000000..bf358161 --- /dev/null +++ b/src/duckdb/src/common/serializer/buffered_file_writer.cpp @@ -0,0 +1,67 @@ +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/algorithm.hpp" +#include + +namespace duckdb { + +// Remove this when we switch C++17: https://stackoverflow.com/a/53350948 +constexpr uint8_t BufferedFileWriter::DEFAULT_OPEN_FLAGS; + +BufferedFileWriter::BufferedFileWriter(FileSystem &fs, const string &path_p, uint8_t open_flags) + : fs(fs), path(path_p), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), total_written(0) { + handle = fs.OpenFile(path, open_flags, FileLockType::WRITE_LOCK); +} + +int64_t BufferedFileWriter::GetFileSize() { + return fs.GetFileSize(*handle) + offset; +} + +idx_t BufferedFileWriter::GetTotalWritten() { + return total_written + offset; +} + +void BufferedFileWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) { + // first copy anything we can from the buffer + const_data_ptr_t end_ptr = buffer + write_size; + while (buffer < end_ptr) { + idx_t to_write = MinValue((end_ptr - buffer), FILE_BUFFER_SIZE - offset); + D_ASSERT(to_write > 0); + memcpy(data.get() + offset, buffer, to_write); + offset += to_write; + buffer += to_write; + if (offset == FILE_BUFFER_SIZE) { + Flush(); + } + } +} + +void BufferedFileWriter::Flush() { + if (offset == 0) { + return; + } + fs.Write(*handle, data.get(), offset); + total_written += offset; + offset = 0; +} + +void BufferedFileWriter::Sync() { + Flush(); + handle->Sync(); +} + +void BufferedFileWriter::Truncate(int64_t size) { + uint64_t persistent = fs.GetFileSize(*handle); + D_ASSERT((uint64_t)size <= persistent + offset); + if (persistent <= (uint64_t)size) { + // truncating into the pending write buffer. + offset = size - persistent; + } else { + // truncate the physical file on disk + handle->Truncate(size); + // reset anything written in the buffer + offset = 0; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/memory_stream.cpp b/src/duckdb/src/common/serializer/memory_stream.cpp new file mode 100644 index 00000000..6064ebe1 --- /dev/null +++ b/src/duckdb/src/common/serializer/memory_stream.cpp @@ -0,0 +1,61 @@ +#include "duckdb/common/serializer/memory_stream.hpp" + +namespace duckdb { + +MemoryStream::MemoryStream(idx_t capacity) + : position(0), capacity(capacity), owns_data(true), data(static_cast(malloc(capacity))) { +} + +MemoryStream::MemoryStream(data_ptr_t buffer, idx_t capacity) + : position(0), capacity(capacity), owns_data(false), data(buffer) { +} + +MemoryStream::~MemoryStream() { + if (owns_data) { + free(data); + } +} + +void MemoryStream::WriteData(const_data_ptr_t source, idx_t write_size) { + while (position + write_size > capacity) { + if (owns_data) { + capacity *= 2; + data = static_cast(realloc(data, capacity)); + } else { + throw SerializationException("Failed to serialize: not enough space in buffer to fulfill write request"); + } + } + + memcpy(data + position, source, write_size); + position += write_size; +} + +void MemoryStream::ReadData(data_ptr_t destination, idx_t read_size) { + if (position + read_size > capacity) { + throw SerializationException("Failed to deserialize: not enough data in buffer to fulfill read request"); + } + memcpy(destination, data + position, read_size); + position += read_size; +} + +void MemoryStream::Rewind() { + position = 0; +} + +void MemoryStream::Release() { + owns_data = false; +} + +data_ptr_t MemoryStream::GetData() const { + return data; +} + +idx_t MemoryStream::GetPosition() const { + return position; +} + +idx_t MemoryStream::GetCapacity() const { + return capacity; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/serializer.cpp b/src/duckdb/src/common/serializer/serializer.cpp new file mode 100644 index 00000000..303abb34 --- /dev/null +++ b/src/duckdb/src/common/serializer/serializer.cpp @@ -0,0 +1,15 @@ +#include "duckdb/common/serializer/serializer.hpp" + +namespace duckdb { + +template <> +void Serializer::WriteValue(const vector &vec) { + auto count = vec.size(); + OnListBegin(count); + for (auto item : vec) { + WriteValue(item); + } + OnListEnd(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/comparators.cpp b/src/duckdb/src/common/sort/comparators.cpp new file mode 100644 index 00000000..b6cf7b55 --- /dev/null +++ b/src/duckdb/src/common/sort/comparators.cpp @@ -0,0 +1,381 @@ +#include "duckdb/common/sort/comparators.hpp" + +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/sort/sort.hpp" + +namespace duckdb { + +bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) { + const auto &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); + // Check if the blob is NULL + ValidityBytes row_mask(row_ptr); + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { + // Can't break a NULL tie + return false; + } + auto &row_layout = sort_layout.blob_layout; + if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) { + // Nested type, must be broken + return true; + } + const auto &tie_col_offset = row_layout.GetOffsets()[col_idx]; + auto tie_string = Load(row_ptr + tie_col_offset); + if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col]) { + // No need to break the tie - we already compared the full string + return false; + } + return true; +} + +int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, + const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) { + // Compare the sorting columns one by one + int comp_res = 0; + data_ptr_t l_ptr_offset = l_ptr; + data_ptr_t r_ptr_offset = r_ptr; + for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) { + comp_res = FastMemcmp(l_ptr_offset, r_ptr_offset, sort_layout.column_sizes[col_idx]); + if (comp_res == 0 && !sort_layout.constant_size[col_idx]) { + comp_res = BreakBlobTie(col_idx, left, right, sort_layout, external_sort); + } + if (comp_res != 0) { + break; + } + l_ptr_offset += sort_layout.column_sizes[col_idx]; + r_ptr_offset += sort_layout.column_sizes[col_idx]; + } + return comp_res; +} + +int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) { + switch (type.InternalType()) { + case PhysicalType::VARCHAR: + return TemplatedCompareVal(l_ptr, r_ptr); + case PhysicalType::LIST: + case PhysicalType::STRUCT: { + auto l_nested_ptr = Load(l_ptr); + auto r_nested_ptr = Load(r_ptr); + return CompareValAndAdvance(l_nested_ptr, r_nested_ptr, type, true); + } + default: + throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString()); + } +} + +int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, + const SortLayout &sort_layout, const bool &external) { + data_ptr_t l_data_ptr = left.DataPtr(*left.sb->blob_sorting_data); + data_ptr_t r_data_ptr = right.DataPtr(*right.sb->blob_sorting_data); + if (!TieIsBreakable(tie_col, l_data_ptr, sort_layout)) { + // Quick check to see if ties can be broken + return 0; + } + // Align the pointers + const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); + const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; + l_data_ptr += tie_col_offset; + r_data_ptr += tie_col_offset; + // Do the comparison + const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? -1 : 1; + const auto &type = sort_layout.blob_layout.GetTypes()[col_idx]; + int result; + if (external) { + // Store heap pointers + data_ptr_t l_heap_ptr = left.HeapPtr(*left.sb->blob_sorting_data); + data_ptr_t r_heap_ptr = right.HeapPtr(*right.sb->blob_sorting_data); + // Unswizzle offset to pointer + UnswizzleSingleValue(l_data_ptr, l_heap_ptr, type); + UnswizzleSingleValue(r_data_ptr, r_heap_ptr, type); + // Compare + result = CompareVal(l_data_ptr, r_data_ptr, type); + // Swizzle the pointers back to offsets + SwizzleSingleValue(l_data_ptr, l_heap_ptr, type); + SwizzleSingleValue(r_data_ptr, r_heap_ptr, type); + } else { + result = CompareVal(l_data_ptr, r_data_ptr, type); + } + return order * result; +} + +template +int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) { + const auto left_val = Load(left_ptr); + const auto right_val = Load(right_ptr); + if (Equals::Operation(left_val, right_val)) { + return 0; + } else if (LessThan::Operation(left_val, right_val)) { + return -1; + } else { + return 1; + } +} + +int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) { + switch (type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::INT16: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::INT32: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::INT64: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::UINT8: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::UINT16: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::UINT32: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::UINT64: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::INT128: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::FLOAT: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::DOUBLE: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::INTERVAL: + return TemplatedCompareAndAdvance(l_ptr, r_ptr); + case PhysicalType::VARCHAR: + return CompareStringAndAdvance(l_ptr, r_ptr, valid); + case PhysicalType::LIST: + return CompareListAndAdvance(l_ptr, r_ptr, ListType::GetChildType(type), valid); + case PhysicalType::STRUCT: + return CompareStructAndAdvance(l_ptr, r_ptr, StructType::GetChildTypes(type), valid); + default: + throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString()); + } +} + +template +int Comparators::TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr) { + auto result = TemplatedCompareVal(left_ptr, right_ptr); + left_ptr += sizeof(T); + right_ptr += sizeof(T); + return result; +} + +int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) { + if (!valid) { + return 0; + } + uint32_t left_string_size = Load(left_ptr); + uint32_t right_string_size = Load(right_ptr); + left_ptr += sizeof(uint32_t); + right_ptr += sizeof(uint32_t); + auto memcmp_res = memcmp(const_char_ptr_cast(left_ptr), const_char_ptr_cast(right_ptr), + std::min(left_string_size, right_string_size)); + + left_ptr += left_string_size; + right_ptr += right_string_size; + + if (memcmp_res != 0) { + return memcmp_res; + } + if (left_string_size == right_string_size) { + return 0; + } + if (left_string_size < right_string_size) { + return -1; + } + return 1; +} + +int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, + const child_list_t &types, bool valid) { + idx_t count = types.size(); + // Load validity masks + ValidityBytes left_validity(left_ptr); + ValidityBytes right_validity(right_ptr); + left_ptr += (count + 7) / 8; + right_ptr += (count + 7) / 8; + // Initialize variables + bool left_valid; + bool right_valid; + idx_t entry_idx; + idx_t idx_in_entry; + // Compare + int comp_res = 0; + for (idx_t i = 0; i < count; i++) { + ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); + left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); + right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); + auto &type = types[i].second; + if ((left_valid == right_valid) || TypeIsConstantSize(type.InternalType())) { + comp_res = CompareValAndAdvance(left_ptr, right_ptr, types[i].second, left_valid && valid); + } + if (!left_valid && !right_valid) { + comp_res = 0; + } else if (!left_valid) { + comp_res = 1; + } else if (!right_valid) { + comp_res = -1; + } + if (comp_res != 0) { + break; + } + } + return comp_res; +} + +int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, + bool valid) { + if (!valid) { + return 0; + } + // Load list lengths + auto left_len = Load(left_ptr); + auto right_len = Load(right_ptr); + left_ptr += sizeof(idx_t); + right_ptr += sizeof(idx_t); + // Load list validity masks + ValidityBytes left_validity(left_ptr); + ValidityBytes right_validity(right_ptr); + left_ptr += (left_len + 7) / 8; + right_ptr += (right_len + 7) / 8; + // Compare + int comp_res = 0; + idx_t count = MinValue(left_len, right_len); + if (TypeIsConstantSize(type.InternalType())) { + // Templated code for fixed-size types + switch (type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::INT16: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::INT32: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::INT64: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::UINT8: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::UINT16: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::UINT32: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::UINT64: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::INT128: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::FLOAT: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::DOUBLE: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + case PhysicalType::INTERVAL: + comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); + break; + default: + throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); + } + } else { + // Variable-sized list entries + bool left_valid; + bool right_valid; + idx_t entry_idx; + idx_t idx_in_entry; + // Size (in bytes) of all variable-sizes entries is stored before the entries begin, + // to make deserialization easier. We need to skip over them + left_ptr += left_len * sizeof(idx_t); + right_ptr += right_len * sizeof(idx_t); + for (idx_t i = 0; i < count; i++) { + ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); + left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); + right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); + if (left_valid && right_valid) { + switch (type.InternalType()) { + case PhysicalType::LIST: + comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); + break; + case PhysicalType::VARCHAR: + comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); + break; + case PhysicalType::STRUCT: + comp_res = + CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); + break; + default: + throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString()); + } + } else if (!left_valid && !right_valid) { + comp_res = 0; + } else if (left_valid) { + comp_res = -1; + } else { + comp_res = 1; + } + if (comp_res != 0) { + break; + } + } + } + // All values that we looped over were equal + if (comp_res == 0 && left_len != right_len) { + // Smaller lists first + if (left_len < right_len) { + comp_res = -1; + } else { + comp_res = 1; + } + } + return comp_res; +} + +template +int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, + const ValidityBytes &left_validity, const ValidityBytes &right_validity, + const idx_t &count) { + int comp_res = 0; + bool left_valid; + bool right_valid; + idx_t entry_idx; + idx_t idx_in_entry; + for (idx_t i = 0; i < count; i++) { + ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); + left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); + right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); + comp_res = TemplatedCompareAndAdvance(left_ptr, right_ptr); + if (!left_valid && !right_valid) { + comp_res = 0; + } else if (!left_valid) { + comp_res = 1; + } else if (!right_valid) { + comp_res = -1; + } + if (comp_res != 0) { + break; + } + } + return comp_res; +} + +void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { + if (type.InternalType() == PhysicalType::VARCHAR) { + data_ptr += string_t::HEADER_SIZE; + } + Store(heap_ptr + Load(data_ptr), data_ptr); +} + +void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { + if (type.InternalType() == PhysicalType::VARCHAR) { + data_ptr += string_t::HEADER_SIZE; + } + Store(Load(data_ptr) - heap_ptr, data_ptr); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/merge_sorter.cpp b/src/duckdb/src/common/sort/merge_sorter.cpp new file mode 100644 index 00000000..7d2f6a1d --- /dev/null +++ b/src/duckdb/src/common/sort/merge_sorter.cpp @@ -0,0 +1,663 @@ +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/sort/comparators.hpp" +#include "duckdb/common/sort/sort.hpp" + +namespace duckdb { + +MergeSorter::MergeSorter(GlobalSortState &state, BufferManager &buffer_manager) + : state(state), buffer_manager(buffer_manager), sort_layout(state.sort_layout) { +} + +void MergeSorter::PerformInMergeRound() { + while (true) { + { + lock_guard pair_guard(state.lock); + if (state.pair_idx == state.num_pairs) { + break; + } + GetNextPartition(); + } + MergePartition(); + } +} + +void MergeSorter::MergePartition() { + auto &left_block = *left->sb; + auto &right_block = *right->sb; +#ifdef DEBUG + D_ASSERT(left_block.radix_sorting_data.size() == left_block.payload_data->data_blocks.size()); + D_ASSERT(right_block.radix_sorting_data.size() == right_block.payload_data->data_blocks.size()); + if (!state.payload_layout.AllConstant() && state.external) { + D_ASSERT(left_block.payload_data->data_blocks.size() == left_block.payload_data->heap_blocks.size()); + D_ASSERT(right_block.payload_data->data_blocks.size() == right_block.payload_data->heap_blocks.size()); + } + if (!sort_layout.all_constant) { + D_ASSERT(left_block.radix_sorting_data.size() == left_block.blob_sorting_data->data_blocks.size()); + D_ASSERT(right_block.radix_sorting_data.size() == right_block.blob_sorting_data->data_blocks.size()); + if (state.external) { + D_ASSERT(left_block.blob_sorting_data->data_blocks.size() == + left_block.blob_sorting_data->heap_blocks.size()); + D_ASSERT(right_block.blob_sorting_data->data_blocks.size() == + right_block.blob_sorting_data->heap_blocks.size()); + } + } +#endif + // Set up the write block + // Each merge task produces a SortedBlock with exactly state.block_capacity rows or less + result->InitializeWrite(); + // Initialize arrays to store merge data + bool left_smaller[STANDARD_VECTOR_SIZE]; + idx_t next_entry_sizes[STANDARD_VECTOR_SIZE]; + // Merge loop +#ifdef DEBUG + auto l_count = left->Remaining(); + auto r_count = right->Remaining(); +#endif + while (true) { + auto l_remaining = left->Remaining(); + auto r_remaining = right->Remaining(); + if (l_remaining + r_remaining == 0) { + // Done + break; + } + const idx_t next = MinValue(l_remaining + r_remaining, (idx_t)STANDARD_VECTOR_SIZE); + if (l_remaining != 0 && r_remaining != 0) { + // Compute the merge (not needed if one side is exhausted) + ComputeMerge(next, left_smaller); + } + // Actually merge the data (radix, blob, and payload) + MergeRadix(next, left_smaller); + if (!sort_layout.all_constant) { + MergeData(*result->blob_sorting_data, *left_block.blob_sorting_data, *right_block.blob_sorting_data, next, + left_smaller, next_entry_sizes, true); + D_ASSERT(result->radix_sorting_data.size() == result->blob_sorting_data->data_blocks.size()); + } + MergeData(*result->payload_data, *left_block.payload_data, *right_block.payload_data, next, left_smaller, + next_entry_sizes, false); + D_ASSERT(result->radix_sorting_data.size() == result->payload_data->data_blocks.size()); + } +#ifdef DEBUG + D_ASSERT(result->Count() == l_count + r_count); +#endif +} + +void MergeSorter::GetNextPartition() { + // Create result block + state.sorted_blocks_temp[state.pair_idx].push_back(make_uniq(buffer_manager, state)); + result = state.sorted_blocks_temp[state.pair_idx].back().get(); + // Determine which blocks must be merged + auto &left_block = *state.sorted_blocks[state.pair_idx * 2]; + auto &right_block = *state.sorted_blocks[state.pair_idx * 2 + 1]; + const idx_t l_count = left_block.Count(); + const idx_t r_count = right_block.Count(); + // Initialize left and right reader + left = make_uniq(buffer_manager, state); + right = make_uniq(buffer_manager, state); + // Compute the work that this thread must do using Merge Path + idx_t l_end; + idx_t r_end; + if (state.l_start + state.r_start + state.block_capacity < l_count + r_count) { + left->sb = state.sorted_blocks[state.pair_idx * 2].get(); + right->sb = state.sorted_blocks[state.pair_idx * 2 + 1].get(); + const idx_t intersection = state.l_start + state.r_start + state.block_capacity; + GetIntersection(intersection, l_end, r_end); + D_ASSERT(l_end <= l_count); + D_ASSERT(r_end <= r_count); + D_ASSERT(intersection == l_end + r_end); + } else { + l_end = l_count; + r_end = r_count; + } + // Create slices of the data that this thread must merge + left->SetIndices(0, 0); + right->SetIndices(0, 0); + left_input = left_block.CreateSlice(state.l_start, l_end, left->entry_idx); + right_input = right_block.CreateSlice(state.r_start, r_end, right->entry_idx); + left->sb = left_input.get(); + right->sb = right_input.get(); + state.l_start = l_end; + state.r_start = r_end; + D_ASSERT(left->Remaining() + right->Remaining() == state.block_capacity || (l_end == l_count && r_end == r_count)); + // Update global state + if (state.l_start == l_count && state.r_start == r_count) { + // Delete references to previous pair + state.sorted_blocks[state.pair_idx * 2] = nullptr; + state.sorted_blocks[state.pair_idx * 2 + 1] = nullptr; + // Advance pair + state.pair_idx++; + state.l_start = 0; + state.r_start = 0; + } +} + +int MergeSorter::CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx) { + D_ASSERT(l_idx < l.sb->Count()); + D_ASSERT(r_idx < r.sb->Count()); + + // Easy comparison using the previous result (intersections must increase monotonically) + if (l_idx < state.l_start) { + return -1; + } + if (r_idx < state.r_start) { + return 1; + } + + l.sb->GlobalToLocalIndex(l_idx, l.block_idx, l.entry_idx); + r.sb->GlobalToLocalIndex(r_idx, r.block_idx, r.entry_idx); + + l.PinRadix(l.block_idx); + r.PinRadix(r.block_idx); + data_ptr_t l_ptr = l.radix_handle.Ptr() + l.entry_idx * sort_layout.entry_size; + data_ptr_t r_ptr = r.radix_handle.Ptr() + r.entry_idx * sort_layout.entry_size; + + int comp_res; + if (sort_layout.all_constant) { + comp_res = FastMemcmp(l_ptr, r_ptr, sort_layout.comparison_size); + } else { + l.PinData(*l.sb->blob_sorting_data); + r.PinData(*r.sb->blob_sorting_data); + comp_res = Comparators::CompareTuple(l, r, l_ptr, r_ptr, sort_layout, state.external); + } + return comp_res; +} + +void MergeSorter::GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx) { + const idx_t l_count = left->sb->Count(); + const idx_t r_count = right->sb->Count(); + // Cover some edge cases + // Code coverage off because these edge cases cannot happen unless other code changes + // Edge cases have been tested extensively while developing Merge Path in a script + // LCOV_EXCL_START + if (diagonal >= l_count + r_count) { + l_idx = l_count; + r_idx = r_count; + return; + } else if (diagonal == 0) { + l_idx = 0; + r_idx = 0; + return; + } else if (l_count == 0) { + l_idx = 0; + r_idx = diagonal; + return; + } else if (r_count == 0) { + r_idx = 0; + l_idx = diagonal; + return; + } + // LCOV_EXCL_STOP + // Determine offsets for the binary search + const idx_t l_offset = MinValue(l_count, diagonal); + const idx_t r_offset = diagonal > l_count ? diagonal - l_count : 0; + D_ASSERT(l_offset + r_offset == diagonal); + const idx_t search_space = diagonal > MaxValue(l_count, r_count) ? l_count + r_count - diagonal + : MinValue(diagonal, MinValue(l_count, r_count)); + // Double binary search + idx_t li = 0; + idx_t ri = search_space - 1; + idx_t middle; + int comp_res; + while (li <= ri) { + middle = (li + ri) / 2; + l_idx = l_offset - middle; + r_idx = r_offset + middle; + if (l_idx == l_count || r_idx == 0) { + comp_res = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); + if (comp_res > 0) { + l_idx--; + r_idx++; + } else { + return; + } + if (l_idx == 0 || r_idx == r_count) { + // This case is incredibly difficult to cover as it is dependent on parallelism randomness + // But it has been tested extensively during development in a script + // LCOV_EXCL_START + return; + // LCOV_EXCL_STOP + } else { + break; + } + } + comp_res = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx); + if (comp_res > 0) { + li = middle + 1; + } else { + ri = middle - 1; + } + } + int l_r_min1 = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx - 1); + int l_min1_r = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); + if (l_r_min1 > 0 && l_min1_r < 0) { + return; + } else if (l_r_min1 > 0) { + l_idx--; + r_idx++; + } else if (l_min1_r < 0) { + l_idx++; + r_idx--; + } +} + +void MergeSorter::ComputeMerge(const idx_t &count, bool left_smaller[]) { + auto &l = *left; + auto &r = *right; + auto &l_sorted_block = *l.sb; + auto &r_sorted_block = *r.sb; + // Save indices to restore afterwards + idx_t l_block_idx_before = l.block_idx; + idx_t l_entry_idx_before = l.entry_idx; + idx_t r_block_idx_before = r.block_idx; + idx_t r_entry_idx_before = r.entry_idx; + // Data pointers for both sides + data_ptr_t l_radix_ptr; + data_ptr_t r_radix_ptr; + // Compute the merge of the next 'count' tuples + idx_t compared = 0; + while (compared < count) { + // Move to the next block (if needed) + if (l.block_idx < l_sorted_block.radix_sorting_data.size() && + l.entry_idx == l_sorted_block.radix_sorting_data[l.block_idx]->count) { + l.block_idx++; + l.entry_idx = 0; + } + if (r.block_idx < r_sorted_block.radix_sorting_data.size() && + r.entry_idx == r_sorted_block.radix_sorting_data[r.block_idx]->count) { + r.block_idx++; + r.entry_idx = 0; + } + const bool l_done = l.block_idx == l_sorted_block.radix_sorting_data.size(); + const bool r_done = r.block_idx == r_sorted_block.radix_sorting_data.size(); + if (l_done || r_done) { + // One of the sides is exhausted, no need to compare + break; + } + // Pin the radix sorting data + left->PinRadix(l.block_idx); + l_radix_ptr = left->RadixPtr(); + right->PinRadix(r.block_idx); + r_radix_ptr = right->RadixPtr(); + + const idx_t l_count = l_sorted_block.radix_sorting_data[l.block_idx]->count; + const idx_t r_count = r_sorted_block.radix_sorting_data[r.block_idx]->count; + // Compute the merge + if (sort_layout.all_constant) { + // All sorting columns are constant size + for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { + left_smaller[compared] = FastMemcmp(l_radix_ptr, r_radix_ptr, sort_layout.comparison_size) < 0; + const bool &l_smaller = left_smaller[compared]; + const bool r_smaller = !l_smaller; + // Use comparison bool (0 or 1) to increment entries and pointers + l.entry_idx += l_smaller; + r.entry_idx += r_smaller; + l_radix_ptr += l_smaller * sort_layout.entry_size; + r_radix_ptr += r_smaller * sort_layout.entry_size; + } + } else { + // Pin the blob data + left->PinData(*l_sorted_block.blob_sorting_data); + right->PinData(*r_sorted_block.blob_sorting_data); + // Merge with variable size sorting columns + for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { + left_smaller[compared] = + Comparators::CompareTuple(*left, *right, l_radix_ptr, r_radix_ptr, sort_layout, state.external) < 0; + const bool &l_smaller = left_smaller[compared]; + const bool r_smaller = !l_smaller; + // Use comparison bool (0 or 1) to increment entries and pointers + l.entry_idx += l_smaller; + r.entry_idx += r_smaller; + l_radix_ptr += l_smaller * sort_layout.entry_size; + r_radix_ptr += r_smaller * sort_layout.entry_size; + } + } + } + // Reset block indices + left->SetIndices(l_block_idx_before, l_entry_idx_before); + right->SetIndices(r_block_idx_before, r_entry_idx_before); +} + +void MergeSorter::MergeRadix(const idx_t &count, const bool left_smaller[]) { + auto &l = *left; + auto &r = *right; + // Save indices to restore afterwards + idx_t l_block_idx_before = l.block_idx; + idx_t l_entry_idx_before = l.entry_idx; + idx_t r_block_idx_before = r.block_idx; + idx_t r_entry_idx_before = r.entry_idx; + + auto &l_blocks = l.sb->radix_sorting_data; + auto &r_blocks = r.sb->radix_sorting_data; + RowDataBlock *l_block = nullptr; + RowDataBlock *r_block = nullptr; + + data_ptr_t l_ptr; + data_ptr_t r_ptr; + + RowDataBlock *result_block = result->radix_sorting_data.back().get(); + auto result_handle = buffer_manager.Pin(result_block->block); + data_ptr_t result_ptr = result_handle.Ptr() + result_block->count * sort_layout.entry_size; + + idx_t copied = 0; + while (copied < count) { + // Move to the next block (if needed) + if (l.block_idx < l_blocks.size() && l.entry_idx == l_blocks[l.block_idx]->count) { + // Delete reference to previous block + l_blocks[l.block_idx]->block = nullptr; + // Advance block + l.block_idx++; + l.entry_idx = 0; + } + if (r.block_idx < r_blocks.size() && r.entry_idx == r_blocks[r.block_idx]->count) { + // Delete reference to previous block + r_blocks[r.block_idx]->block = nullptr; + // Advance block + r.block_idx++; + r.entry_idx = 0; + } + const bool l_done = l.block_idx == l_blocks.size(); + const bool r_done = r.block_idx == r_blocks.size(); + // Pin the radix sortable blocks + idx_t l_count; + if (!l_done) { + l_block = l_blocks[l.block_idx].get(); + left->PinRadix(l.block_idx); + l_ptr = l.RadixPtr(); + l_count = l_block->count; + } else { + l_count = 0; + } + idx_t r_count; + if (!r_done) { + r_block = r_blocks[r.block_idx].get(); + r.PinRadix(r.block_idx); + r_ptr = r.RadixPtr(); + r_count = r_block->count; + } else { + r_count = 0; + } + // Copy using computed merge + if (!l_done && !r_done) { + // Both sides have data - merge + MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_block, result_ptr, + sort_layout.entry_size, left_smaller, copied, count); + } else if (r_done) { + // Right side is exhausted + FlushRows(l_ptr, l.entry_idx, l_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); + } else { + // Left side is exhausted + FlushRows(r_ptr, r.entry_idx, r_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); + } + } + // Reset block indices + left->SetIndices(l_block_idx_before, l_entry_idx_before); + right->SetIndices(r_block_idx_before, r_entry_idx_before); +} + +void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, + const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices) { + auto &l = *left; + auto &r = *right; + // Save indices to restore afterwards + idx_t l_block_idx_before = l.block_idx; + idx_t l_entry_idx_before = l.entry_idx; + idx_t r_block_idx_before = r.block_idx; + idx_t r_entry_idx_before = r.entry_idx; + + const auto &layout = result_data.layout; + const idx_t row_width = layout.GetRowWidth(); + const idx_t heap_pointer_offset = layout.GetHeapOffset(); + + // Left and right row data to merge + data_ptr_t l_ptr; + data_ptr_t r_ptr; + // Accompanying left and right heap data (if needed) + data_ptr_t l_heap_ptr; + data_ptr_t r_heap_ptr; + + // Result rows to write to + RowDataBlock *result_data_block = result_data.data_blocks.back().get(); + auto result_data_handle = buffer_manager.Pin(result_data_block->block); + data_ptr_t result_data_ptr = result_data_handle.Ptr() + result_data_block->count * row_width; + // Result heap to write to (if needed) + RowDataBlock *result_heap_block = nullptr; + BufferHandle result_heap_handle; + data_ptr_t result_heap_ptr; + if (!layout.AllConstant() && state.external) { + result_heap_block = result_data.heap_blocks.back().get(); + result_heap_handle = buffer_manager.Pin(result_heap_block->block); + result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; + } + + idx_t copied = 0; + while (copied < count) { + // Move to new data blocks (if needed) + if (l.block_idx < l_data.data_blocks.size() && l.entry_idx == l_data.data_blocks[l.block_idx]->count) { + // Delete reference to previous block + l_data.data_blocks[l.block_idx]->block = nullptr; + if (!layout.AllConstant() && state.external) { + l_data.heap_blocks[l.block_idx]->block = nullptr; + } + // Advance block + l.block_idx++; + l.entry_idx = 0; + } + if (r.block_idx < r_data.data_blocks.size() && r.entry_idx == r_data.data_blocks[r.block_idx]->count) { + // Delete reference to previous block + r_data.data_blocks[r.block_idx]->block = nullptr; + if (!layout.AllConstant() && state.external) { + r_data.heap_blocks[r.block_idx]->block = nullptr; + } + // Advance block + r.block_idx++; + r.entry_idx = 0; + } + const bool l_done = l.block_idx == l_data.data_blocks.size(); + const bool r_done = r.block_idx == r_data.data_blocks.size(); + // Pin the row data blocks + if (!l_done) { + l.PinData(l_data); + l_ptr = l.DataPtr(l_data); + } + if (!r_done) { + r.PinData(r_data); + r_ptr = r.DataPtr(r_data); + } + const idx_t &l_count = !l_done ? l_data.data_blocks[l.block_idx]->count : 0; + const idx_t &r_count = !r_done ? r_data.data_blocks[r.block_idx]->count : 0; + // Perform the merge + if (layout.AllConstant() || !state.external) { + // If all constant size, or if we are doing an in-memory sort, we do not need to touch the heap + if (!l_done && !r_done) { + // Both sides have data - merge + MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, + row_width, left_smaller, copied, count); + } else if (r_done) { + // Right side is exhausted + FlushRows(l_ptr, l.entry_idx, l_count, *result_data_block, result_data_ptr, row_width, copied, count); + } else { + // Left side is exhausted + FlushRows(r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, row_width, copied, count); + } + } else { + // External sorting with variable size data. Pin the heap blocks too + if (!l_done) { + l_heap_ptr = l.BaseHeapPtr(l_data) + Load(l_ptr + heap_pointer_offset); + D_ASSERT(l_heap_ptr - l.BaseHeapPtr(l_data) >= 0); + D_ASSERT((idx_t)(l_heap_ptr - l.BaseHeapPtr(l_data)) < l_data.heap_blocks[l.block_idx]->byte_offset); + } + if (!r_done) { + r_heap_ptr = r.BaseHeapPtr(r_data) + Load(r_ptr + heap_pointer_offset); + D_ASSERT(r_heap_ptr - r.BaseHeapPtr(r_data) >= 0); + D_ASSERT((idx_t)(r_heap_ptr - r.BaseHeapPtr(r_data)) < r_data.heap_blocks[r.block_idx]->byte_offset); + } + // Both the row and heap data need to be dealt with + if (!l_done && !r_done) { + // Both sides have data - merge + idx_t l_idx_copy = l.entry_idx; + idx_t r_idx_copy = r.entry_idx; + data_ptr_t result_data_ptr_copy = result_data_ptr; + idx_t copied_copy = copied; + // Merge row data + MergeRows(l_ptr, l_idx_copy, l_count, r_ptr, r_idx_copy, r_count, *result_data_block, + result_data_ptr_copy, row_width, left_smaller, copied_copy, count); + const idx_t merged = copied_copy - copied; + // Compute the entry sizes and number of heap bytes that will be copied + idx_t copy_bytes = 0; + data_ptr_t l_heap_ptr_copy = l_heap_ptr; + data_ptr_t r_heap_ptr_copy = r_heap_ptr; + for (idx_t i = 0; i < merged; i++) { + // Store base heap offset in the row data + Store(result_heap_block->byte_offset + copy_bytes, result_data_ptr + heap_pointer_offset); + result_data_ptr += row_width; + // Compute entry size and add to total + const bool &l_smaller = left_smaller[copied + i]; + const bool r_smaller = !l_smaller; + auto &entry_size = next_entry_sizes[copied + i]; + entry_size = + l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); + D_ASSERT(entry_size >= sizeof(uint32_t)); + D_ASSERT(l_heap_ptr_copy - l.BaseHeapPtr(l_data) + l_smaller * entry_size <= + l_data.heap_blocks[l.block_idx]->byte_offset); + D_ASSERT(r_heap_ptr_copy - r.BaseHeapPtr(r_data) + r_smaller * entry_size <= + r_data.heap_blocks[r.block_idx]->byte_offset); + l_heap_ptr_copy += l_smaller * entry_size; + r_heap_ptr_copy += r_smaller * entry_size; + copy_bytes += entry_size; + } + // Reallocate result heap block size (if needed) + if (result_heap_block->byte_offset + copy_bytes > result_heap_block->capacity) { + idx_t new_capacity = result_heap_block->byte_offset + copy_bytes; + buffer_manager.ReAllocate(result_heap_block->block, new_capacity); + result_heap_block->capacity = new_capacity; + result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; + } + D_ASSERT(result_heap_block->byte_offset + copy_bytes <= result_heap_block->capacity); + // Now copy the heap data + for (idx_t i = 0; i < merged; i++) { + const bool &l_smaller = left_smaller[copied + i]; + const bool r_smaller = !l_smaller; + const auto &entry_size = next_entry_sizes[copied + i]; + memcpy(result_heap_ptr, + reinterpret_cast(l_smaller * CastPointerToValue(l_heap_ptr) + + r_smaller * CastPointerToValue(r_heap_ptr)), + entry_size); + D_ASSERT(Load(result_heap_ptr) == entry_size); + result_heap_ptr += entry_size; + l_heap_ptr += l_smaller * entry_size; + r_heap_ptr += r_smaller * entry_size; + l.entry_idx += l_smaller; + r.entry_idx += r_smaller; + } + // Update result indices and pointers + result_heap_block->count += merged; + result_heap_block->byte_offset += copy_bytes; + copied += merged; + } else if (r_done) { + // Right side is exhausted - flush left + FlushBlobs(layout, l_count, l_ptr, l.entry_idx, l_heap_ptr, *result_data_block, result_data_ptr, + *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); + } else { + // Left side is exhausted - flush right + FlushBlobs(layout, r_count, r_ptr, r.entry_idx, r_heap_ptr, *result_data_block, result_data_ptr, + *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); + } + D_ASSERT(result_data_block->count == result_heap_block->count); + } + } + if (reset_indices) { + left->SetIndices(l_block_idx_before, l_entry_idx_before); + right->SetIndices(r_block_idx_before, r_entry_idx_before); + } +} + +void MergeSorter::MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, + idx_t &r_entry_idx, const idx_t &r_count, RowDataBlock &target_block, + data_ptr_t &target_ptr, const idx_t &entry_size, const bool left_smaller[], idx_t &copied, + const idx_t &count) { + const idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); + idx_t i; + for (i = 0; i < next && l_entry_idx < l_count && r_entry_idx < r_count; i++) { + const bool &l_smaller = left_smaller[copied + i]; + const bool r_smaller = !l_smaller; + // Use comparison bool (0 or 1) to copy an entry from either side + FastMemcpy( + target_ptr, + reinterpret_cast(l_smaller * CastPointerToValue(l_ptr) + r_smaller * CastPointerToValue(r_ptr)), + entry_size); + target_ptr += entry_size; + // Use the comparison bool to increment entries and pointers + l_entry_idx += l_smaller; + r_entry_idx += r_smaller; + l_ptr += l_smaller * entry_size; + r_ptr += r_smaller * entry_size; + } + // Update counts + target_block.count += i; + copied += i; +} + +void MergeSorter::FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, + RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, + const idx_t &count) { + // Compute how many entries we can fit + idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); + next = MinValue(next, source_count - source_entry_idx); + // Copy them all in a single memcpy + const idx_t copy_bytes = next * entry_size; + memcpy(target_ptr, source_ptr, copy_bytes); + target_ptr += copy_bytes; + source_ptr += copy_bytes; + // Update counts + source_entry_idx += next; + target_block.count += next; + copied += next; +} + +void MergeSorter::FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, + idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, + data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, + BufferHandle &target_heap_handle, data_ptr_t &target_heap_ptr, idx_t &copied, + const idx_t &count) { + const idx_t row_width = layout.GetRowWidth(); + const idx_t heap_pointer_offset = layout.GetHeapOffset(); + idx_t source_entry_idx_copy = source_entry_idx; + data_ptr_t target_data_ptr_copy = target_data_ptr; + idx_t copied_copy = copied; + // Flush row data + FlushRows(source_data_ptr, source_entry_idx_copy, source_count, target_data_block, target_data_ptr_copy, row_width, + copied_copy, count); + const idx_t flushed = copied_copy - copied; + // Compute the entry sizes and number of heap bytes that will be copied + idx_t copy_bytes = 0; + data_ptr_t source_heap_ptr_copy = source_heap_ptr; + for (idx_t i = 0; i < flushed; i++) { + // Store base heap offset in the row data + Store(target_heap_block.byte_offset + copy_bytes, target_data_ptr + heap_pointer_offset); + target_data_ptr += row_width; + // Compute entry size and add to total + auto entry_size = Load(source_heap_ptr_copy); + D_ASSERT(entry_size >= sizeof(uint32_t)); + source_heap_ptr_copy += entry_size; + copy_bytes += entry_size; + } + // Reallocate result heap block size (if needed) + if (target_heap_block.byte_offset + copy_bytes > target_heap_block.capacity) { + idx_t new_capacity = target_heap_block.byte_offset + copy_bytes; + buffer_manager.ReAllocate(target_heap_block.block, new_capacity); + target_heap_block.capacity = new_capacity; + target_heap_ptr = target_heap_handle.Ptr() + target_heap_block.byte_offset; + } + D_ASSERT(target_heap_block.byte_offset + copy_bytes <= target_heap_block.capacity); + // Copy the heap data in one go + memcpy(target_heap_ptr, source_heap_ptr, copy_bytes); + target_heap_ptr += copy_bytes; + source_heap_ptr += copy_bytes; + source_entry_idx += flushed; + copied += flushed; + // Update result indices and pointers + target_heap_block.count += flushed; + target_heap_block.byte_offset += copy_bytes; + D_ASSERT(target_heap_block.byte_offset <= target_heap_block.capacity); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/partition_state.cpp b/src/duckdb/src/common/sort/partition_state.cpp new file mode 100644 index 00000000..5f6f3133 --- /dev/null +++ b/src/duckdb/src/common/sort/partition_state.cpp @@ -0,0 +1,674 @@ +#include "duckdb/common/sort/partition_state.hpp" + +#include "duckdb/common/types/column/column_data_consumer.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parallel/event.hpp" + +#include + +namespace duckdb { + +PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions, + const Orders &orders, const Types &payload_types, bool external) + : count(0), batch_base(0) { + + RowLayout payload_layout; + payload_layout.Initialize(payload_types); + global_sort = make_uniq(buffer_manager, orders, payload_layout); + global_sort->external = external; + + // Set up a comparator for the partition subset + partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size()); +} + +int PartitionGlobalHashGroup::ComparePartitions(const SBIterator &left, const SBIterator &right) const { + int part_cmp = 0; + if (partition_layout.all_constant) { + part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size); + } else { + part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, partition_layout, + left.external); + } + return part_cmp; +} + +void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask) { + D_ASSERT(count > 0); + + SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); + SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); + + partition_mask.SetValidUnsafe(0); + order_mask.SetValidUnsafe(0); + for (++curr; curr.GetIndex() < count; ++curr) { + // Compare the partition subset first because if that differs, then so does the full ordering + const auto part_cmp = ComparePartitions(prev, curr); + ; + + if (part_cmp) { + partition_mask.SetValidUnsafe(curr.GetIndex()); + order_mask.SetValidUnsafe(curr.GetIndex()); + } else if (prev.Compare(curr)) { + order_mask.SetValidUnsafe(curr.GetIndex()); + } + ++prev; + } +} + +void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders, + const vector> &partition_bys, + const Orders &order_bys, + const vector> &partition_stats) { + + // we sort by both 1) partition by expression list and 2) order by expressions + const auto partition_cols = partition_bys.size(); + for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { + auto &pexpr = partition_bys[prt_idx]; + + if (partition_stats.empty() || !partition_stats[prt_idx]) { + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr); + } else { + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), + partition_stats[prt_idx]->ToUnique()); + } + partitions.emplace_back(orders.back().Copy()); + } + + for (const auto &order : order_bys) { + orders.emplace_back(order.Copy()); + } +} + +PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, + const vector> &partition_bys, + const vector &order_bys, + const Types &payload_types, + const vector> &partition_stats, + idx_t estimated_cardinality) + : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)), + fixed_bits(0), payload_types(payload_types), memory_per_thread(0), max_bits(1), count(0) { + + GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); + + memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); + external = ClientConfig::GetConfig(context).force_external; + + const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * idx_t(Storage::BLOCK_ALLOC_SIZE))); + while (max_bits < 10 && (thread_pages >> max_bits) > 1) { + ++max_bits; + } + + if (!orders.empty()) { + if (partitions.empty()) { + // Sort early into a dedicated hash group if we only sort. + grouping_types.Initialize(payload_types); + auto new_group = + make_uniq(buffer_manager, partitions, orders, payload_types, external); + hash_groups.emplace_back(std::move(new_group)); + } else { + auto types = payload_types; + types.push_back(LogicalType::HASH); + grouping_types.Initialize(types); + ResizeGroupingData(estimated_cardinality); + } + } +} + +bool PartitionGlobalSinkState::HasMergeTasks() const { + if (grouping_data) { + auto &groups = grouping_data->GetPartitions(); + return !groups.empty(); + } else if (!hash_groups.empty()) { + D_ASSERT(hash_groups.size() == 1); + return hash_groups[0]->count > 0; + } else { + return false; + } +} + +void PartitionGlobalSinkState::SyncPartitioning(const PartitionGlobalSinkState &other) { + fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; + + const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; + if (fixed_bits != old_bits) { + const auto hash_col_idx = payload_types.size(); + grouping_data = make_uniq(buffer_manager, grouping_types, fixed_bits, hash_col_idx); + } +} + +unique_ptr PartitionGlobalSinkState::CreatePartition(idx_t new_bits) const { + const auto hash_col_idx = payload_types.size(); + return make_uniq(buffer_manager, grouping_types, new_bits, hash_col_idx); +} + +void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) { + // Have we started to combine? Then just live with it. + if (fixed_bits || (grouping_data && !grouping_data->GetPartitions().empty())) { + return; + } + // Is the average partition size too large? + const idx_t partition_size = STANDARD_ROW_GROUPS_SIZE; + const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0; + auto new_bits = bits ? bits : 4; + while (new_bits < max_bits && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) { + ++new_bits; + } + + // Repartition the grouping data + if (new_bits != bits) { + grouping_data = CreatePartition(new_bits); + } +} + +void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { + // We are done if the local_partition is right sized. + auto &local_radix = local_partition->Cast(); + const auto new_bits = grouping_data->GetRadixBits(); + if (local_radix.GetRadixBits() == new_bits) { + return; + } + + // If the local partition is now too small, flush it and reallocate + auto new_partition = CreatePartition(new_bits); + local_partition->FlushAppendState(*local_append); + local_partition->Repartition(*new_partition); + + local_partition = std::move(new_partition); + local_append = make_uniq(); + local_partition->InitializeAppendState(*local_append); +} + +void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { + // Make sure grouping_data doesn't change under us. + lock_guard guard(lock); + + if (!local_partition) { + local_partition = CreatePartition(grouping_data->GetRadixBits()); + local_append = make_uniq(); + local_partition->InitializeAppendState(*local_append); + return; + } + + // Grow the groups if they are too big + ResizeGroupingData(count); + + // Sync local partition to have the same bit count + SyncLocalPartition(local_partition, local_append); +} + +void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { + if (!local_partition) { + return; + } + local_partition->FlushAppendState(*local_append); + + // Make sure grouping_data doesn't change under us. + // Combine has an internal mutex, so this is single-threaded anyway. + lock_guard guard(lock); + SyncLocalPartition(local_partition, local_append); + grouping_data->Combine(*local_partition); +} + +PartitionLocalMergeState::PartitionLocalMergeState(PartitionGlobalSinkState &gstate) + : merge_state(nullptr), stage(PartitionSortStage::INIT), finished(true), executor(gstate.context) { + + // Set up the sort expression computation. + vector sort_types; + for (auto &order : gstate.orders) { + auto &oexpr = order.expression; + sort_types.emplace_back(oexpr->return_type); + executor.AddExpression(*oexpr); + } + sort_chunk.Initialize(gstate.allocator, sort_types); + payload_chunk.Initialize(gstate.allocator, gstate.payload_types); +} + +void PartitionLocalMergeState::Scan() { + if (!merge_state->group_data) { + // OVER(ORDER BY...) + // Already sorted + return; + } + + auto &group_data = *merge_state->group_data; + auto &hash_group = *merge_state->hash_group; + auto &chunk_state = merge_state->chunk_state; + // Copy the data from the group into the sort code. + auto &global_sort = *hash_group.global_sort; + LocalSortState local_sort; + local_sort.Initialize(global_sort, global_sort.buffer_manager); + + TupleDataScanState local_scan; + group_data.InitializeScan(local_scan, merge_state->column_ids); + while (group_data.Scan(chunk_state, local_scan, payload_chunk)) { + sort_chunk.Reset(); + executor.Execute(payload_chunk, sort_chunk); + + local_sort.SinkChunk(sort_chunk, payload_chunk); + if (local_sort.SizeInBytes() > merge_state->memory_per_thread) { + local_sort.Sort(global_sort, true); + } + hash_group.count += payload_chunk.size(); + } + + global_sort.AddLocalState(local_sort); +} + +// Per-thread sink state +PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) + : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) { + + vector group_types; + for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) { + auto &pexpr = *gstate.partitions[prt_idx].expression.get(); + group_types.push_back(pexpr.return_type); + executor.AddExpression(pexpr); + } + sort_cols = gstate.orders.size() + group_types.size(); + + if (sort_cols) { + auto payload_types = gstate.payload_types; + if (!group_types.empty()) { + // OVER(PARTITION BY...) + group_chunk.Initialize(allocator, group_types); + payload_types.emplace_back(LogicalType::HASH); + } else { + // OVER(ORDER BY...) + for (idx_t ord_idx = 0; ord_idx < gstate.orders.size(); ord_idx++) { + auto &pexpr = *gstate.orders[ord_idx].expression.get(); + group_types.push_back(pexpr.return_type); + executor.AddExpression(pexpr); + } + group_chunk.Initialize(allocator, group_types); + + // Single partition + auto &global_sort = *gstate.hash_groups[0]->global_sort; + local_sort = make_uniq(); + local_sort->Initialize(global_sort, global_sort.buffer_manager); + } + // OVER(...) + payload_chunk.Initialize(allocator, payload_types); + } else { + // OVER() + payload_layout.Initialize(gstate.payload_types); + } +} + +void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { + const auto count = input_chunk.size(); + D_ASSERT(group_chunk.ColumnCount() > 0); + + // OVER(PARTITION BY...) (hash grouping) + group_chunk.Reset(); + executor.Execute(input_chunk, group_chunk); + VectorOperations::Hash(group_chunk.data[0], hash_vector, count); + for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { + VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); + } +} + +void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { + gstate.count += input_chunk.size(); + + // OVER() + if (sort_cols == 0) { + // No sorts, so build paged row chunks + if (!rows) { + const auto entry_size = payload_layout.GetRowWidth(); + const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1); + rows = make_uniq(gstate.buffer_manager, capacity, entry_size); + strings = make_uniq(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + } + const auto row_count = input_chunk.size(); + const auto row_sel = FlatVector::IncrementalSelectionVector(); + Vector addresses(LogicalType::POINTER); + auto key_locations = FlatVector::GetData(addresses); + const auto prev_rows_blocks = rows->blocks.size(); + auto handles = rows->Build(row_count, key_locations, nullptr, row_sel); + auto input_data = input_chunk.ToUnifiedFormat(); + RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count); + // Mark that row blocks contain pointers (heap blocks are pinned) + if (!payload_layout.AllConstant()) { + D_ASSERT(strings->keep_pinned); + for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) { + rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink"); + } + } + return; + } + + if (local_sort) { + // OVER(ORDER BY...) + group_chunk.Reset(); + executor.Execute(input_chunk, group_chunk); + local_sort->SinkChunk(group_chunk, input_chunk); + + auto &hash_group = *gstate.hash_groups[0]; + hash_group.count += input_chunk.size(); + + if (local_sort->SizeInBytes() > gstate.memory_per_thread) { + auto &global_sort = *hash_group.global_sort; + local_sort->Sort(global_sort, true); + } + return; + } + + // OVER(...) + payload_chunk.Reset(); + auto &hash_vector = payload_chunk.data.back(); + Hash(input_chunk, hash_vector); + for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) { + payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]); + } + payload_chunk.SetCardinality(input_chunk); + + gstate.UpdateLocalPartition(local_partition, local_append); + local_partition->Append(*local_append, payload_chunk); +} + +void PartitionLocalSinkState::Combine() { + // OVER() + if (sort_cols == 0) { + // Only one partition again, so need a global lock. + lock_guard glock(gstate.lock); + if (gstate.rows) { + if (rows) { + gstate.rows->Merge(*rows); + gstate.strings->Merge(*strings); + rows.reset(); + strings.reset(); + } + } else { + gstate.rows = std::move(rows); + gstate.strings = std::move(strings); + } + return; + } + + if (local_sort) { + // OVER(ORDER BY...) + auto &hash_group = *gstate.hash_groups[0]; + auto &global_sort = *hash_group.global_sort; + global_sort.AddLocalState(*local_sort); + local_sort.reset(); + return; + } + + // OVER(...) + gstate.CombineLocalPartition(local_partition, local_append); +} + +PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, + hash_t hash_bin) + : sink(sink), group_data(std::move(group_data_p)), memory_per_thread(sink.memory_per_thread), + num_threads(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), stage(PartitionSortStage::INIT), + total_tasks(0), tasks_assigned(0), tasks_completed(0) { + + const auto group_idx = sink.hash_groups.size(); + auto new_group = make_uniq(sink.buffer_manager, sink.partitions, sink.orders, + sink.payload_types, sink.external); + sink.hash_groups.emplace_back(std::move(new_group)); + + hash_group = sink.hash_groups[group_idx].get(); + global_sort = sink.hash_groups[group_idx]->global_sort.get(); + + sink.bin_groups[hash_bin] = group_idx; + + column_ids.reserve(sink.payload_types.size()); + for (column_t i = 0; i < sink.payload_types.size(); ++i) { + column_ids.emplace_back(i); + } + group_data->InitializeScan(chunk_state, column_ids); +} + +PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) + : sink(sink), memory_per_thread(sink.memory_per_thread), + num_threads(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), stage(PartitionSortStage::INIT), + total_tasks(0), tasks_assigned(0), tasks_completed(0) { + + const hash_t hash_bin = 0; + const size_t group_idx = 0; + hash_group = sink.hash_groups[group_idx].get(); + global_sort = sink.hash_groups[group_idx]->global_sort.get(); + + sink.bin_groups[hash_bin] = group_idx; +} + +void PartitionLocalMergeState::Prepare() { + merge_state->group_data.reset(); + + auto &global_sort = *merge_state->global_sort; + global_sort.PrepareMergePhase(); +} + +void PartitionLocalMergeState::Merge() { + auto &global_sort = *merge_state->global_sort; + MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); + merge_sorter.PerformInMergeRound(); +} + +void PartitionLocalMergeState::ExecuteTask() { + switch (stage) { + case PartitionSortStage::SCAN: + Scan(); + break; + case PartitionSortStage::PREPARE: + Prepare(); + break; + case PartitionSortStage::MERGE: + Merge(); + break; + default: + throw InternalException("Unexpected PartitionSortStage in ExecuteTask!"); + } + + merge_state->CompleteTask(); + finished = true; +} + +bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) { + lock_guard guard(lock); + + if (tasks_assigned >= total_tasks) { + return false; + } + + local_state.merge_state = this; + local_state.stage = stage; + local_state.finished = false; + tasks_assigned++; + + return true; +} + +void PartitionGlobalMergeState::CompleteTask() { + lock_guard guard(lock); + + ++tasks_completed; +} + +bool PartitionGlobalMergeState::TryPrepareNextStage() { + lock_guard guard(lock); + + if (tasks_completed < total_tasks) { + return false; + } + + tasks_assigned = tasks_completed = 0; + + switch (stage) { + case PartitionSortStage::INIT: + // If the partitions are unordered, don't scan in parallel + // because it produces non-deterministic orderings. + // This can theoretically happen with ORDER BY, + // but that is something the query should be explicit about. + total_tasks = sink.orders.size() > sink.partitions.size() ? num_threads : 1; + stage = PartitionSortStage::SCAN; + return true; + + case PartitionSortStage::SCAN: + total_tasks = 1; + stage = PartitionSortStage::PREPARE; + return true; + + case PartitionSortStage::PREPARE: + total_tasks = global_sort->sorted_blocks.size() / 2; + if (!total_tasks) { + break; + } + stage = PartitionSortStage::MERGE; + global_sort->InitializeMergeRound(); + return true; + + case PartitionSortStage::MERGE: + global_sort->CompleteMergeRound(true); + total_tasks = global_sort->sorted_blocks.size() / 2; + if (!total_tasks) { + break; + } + global_sort->InitializeMergeRound(); + return true; + + case PartitionSortStage::SORTED: + break; + } + + stage = PartitionSortStage::SORTED; + + return false; +} + +PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) { + // Schedule all the sorts for maximum thread utilisation + if (sink.grouping_data) { + auto &partitions = sink.grouping_data->GetPartitions(); + sink.bin_groups.resize(partitions.size(), partitions.size()); + for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + auto &group_data = partitions[hash_bin]; + // Prepare for merge sort phase + if (group_data->Count()) { + auto state = make_uniq(sink, std::move(group_data), hash_bin); + states.emplace_back(std::move(state)); + } + } + } else { + // OVER(ORDER BY...) + // Already sunk into the single global sort, so set up single merge with no data + sink.bin_groups.resize(1, 1); + auto state = make_uniq(sink); + states.emplace_back(std::move(state)); + } +} + +class PartitionMergeTask : public ExecutorTask { +public: + PartitionMergeTask(shared_ptr event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p, + PartitionGlobalSinkState &gstate) + : ExecutorTask(context_p), event(std::move(event_p)), local_state(gstate), hash_groups(hash_groups_p) { + } + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; + +private: + struct ExecutorCallback : public PartitionGlobalMergeStates::Callback { + explicit ExecutorCallback(Executor &executor) : executor(executor) { + } + + bool HasError() const override { + return executor.HasError(); + } + + Executor &executor; + }; + + shared_ptr event; + PartitionLocalMergeState local_state; + PartitionGlobalMergeStates &hash_groups; +}; + +bool PartitionGlobalMergeStates::ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback) { + // Loop until all hash groups are done + size_t sorted = 0; + while (sorted < states.size()) { + // First check if there is an unfinished task for this thread + if (callback.HasError()) { + return false; + } + if (!local_state.TaskFinished()) { + local_state.ExecuteTask(); + continue; + } + + // Thread is done with its assigned task, try to fetch new work + for (auto group = sorted; group < states.size(); ++group) { + auto &global_state = states[group]; + if (global_state->IsSorted()) { + // This hash group is done + // Update the high water mark of densely completed groups + if (sorted == group) { + ++sorted; + } + continue; + } + + // Try to assign work for this hash group to this thread + if (global_state->AssignTask(local_state)) { + // We assigned a task to this thread! + // Break out of this loop to re-enter the top-level loop and execute the task + break; + } + + // Hash group global state couldn't assign a task to this thread + // Try to prepare the next stage + if (!global_state->TryPrepareNextStage()) { + // This current hash group is not yet done + // But we were not able to assign a task for it to this thread + // See if the next hash group is better + continue; + } + + // We were able to prepare the next stage for this hash group! + // Try to assign a task once more + if (global_state->AssignTask(local_state)) { + // We assigned a task to this thread! + // Break out of this loop to re-enter the top-level loop and execute the task + break; + } + + // We were able to prepare the next merge round, + // but we were not able to assign a task for it to this thread + // The tasks were assigned to other threads while this thread waited for the lock + // Go to the next iteration to see if another hash group has a task + } + } + + return true; +} + +TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) { + ExecutorCallback callback(executor); + + if (!hash_groups.ExecuteTask(local_state, callback)) { + return TaskExecutionResult::TASK_ERROR; + } + + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; +} + +void PartitionMergeEvent::Schedule() { + auto &context = pipeline->GetClientContext(); + + // Schedule tasks equal to the number of threads, which will each merge multiple partitions + auto &ts = TaskScheduler::GetScheduler(context); + idx_t num_threads = ts.NumberOfThreads(); + + vector> merge_tasks; + for (idx_t tnum = 0; tnum < num_threads; tnum++) { + merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate)); + } + SetTasks(std::move(merge_tasks)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/radix_sort.cpp b/src/duckdb/src/common/sort/radix_sort.cpp new file mode 100644 index 00000000..179d5f9d --- /dev/null +++ b/src/duckdb/src/common/sort/radix_sort.cpp @@ -0,0 +1,344 @@ +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/sort/comparators.hpp" +#include "duckdb/common/sort/duckdb_pdqsort.hpp" +#include "duckdb/common/sort/sort.hpp" + +namespace duckdb { + +//! Calls std::sort on strings that are tied by their prefix after the radix sort +static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &start, const idx_t &end, + const idx_t &tie_col, bool *ties, const data_ptr_t blob_ptr, const SortLayout &sort_layout) { + const auto row_width = sort_layout.blob_layout.GetRowWidth(); + // Locate the first blob row in question + data_ptr_t row_ptr = dataptr + start * sort_layout.entry_size; + data_ptr_t blob_row_ptr = blob_ptr + Load(row_ptr + sort_layout.comparison_size) * row_width; + if (!Comparators::TieIsBreakable(tie_col, blob_row_ptr, sort_layout)) { + // Quick check to see if ties can be broken + return; + } + // Fill pointer array for sorting + auto ptr_block = make_unsafe_uniq_array(end - start); + auto entry_ptrs = (data_ptr_t *)ptr_block.get(); + for (idx_t i = start; i < end; i++) { + entry_ptrs[i - start] = row_ptr; + row_ptr += sort_layout.entry_size; + } + // Slow pointer-based sorting + const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? -1 : 1; + const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); + const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; + auto logical_type = sort_layout.blob_layout.GetTypes()[col_idx]; + std::sort(entry_ptrs, entry_ptrs + end - start, + [&blob_ptr, &order, &sort_layout, &tie_col_offset, &row_width, &logical_type](const data_ptr_t l, + const data_ptr_t r) { + idx_t left_idx = Load(l + sort_layout.comparison_size); + idx_t right_idx = Load(r + sort_layout.comparison_size); + data_ptr_t left_ptr = blob_ptr + left_idx * row_width + tie_col_offset; + data_ptr_t right_ptr = blob_ptr + right_idx * row_width + tie_col_offset; + return order * Comparators::CompareVal(left_ptr, right_ptr, logical_type) < 0; + }); + // Re-order + auto temp_block = buffer_manager.GetBufferAllocator().Allocate((end - start) * sort_layout.entry_size); + data_ptr_t temp_ptr = temp_block.get(); + for (idx_t i = 0; i < end - start; i++) { + FastMemcpy(temp_ptr, entry_ptrs[i], sort_layout.entry_size); + temp_ptr += sort_layout.entry_size; + } + memcpy(dataptr + start * sort_layout.entry_size, temp_block.get(), (end - start) * sort_layout.entry_size); + // Determine if there are still ties (if this is not the last column) + if (tie_col < sort_layout.column_count - 1) { + data_ptr_t idx_ptr = dataptr + start * sort_layout.entry_size + sort_layout.comparison_size; + // Load current entry + data_ptr_t current_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; + for (idx_t i = 0; i < end - start - 1; i++) { + // Load next entry and compare + idx_ptr += sort_layout.entry_size; + data_ptr_t next_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; + ties[start + i] = Comparators::CompareVal(current_ptr, next_ptr, logical_type) == 0; + current_ptr = next_ptr; + } + } +} + +//! Identifies sequences of rows that are tied by the prefix of a blob column, and sorts them +static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool *ties, data_ptr_t dataptr, + const idx_t &count, const idx_t &tie_col, const SortLayout &sort_layout) { + D_ASSERT(!ties[count - 1]); + auto &blob_block = *sb.blob_sorting_data->data_blocks.back(); + auto blob_handle = buffer_manager.Pin(blob_block.block); + const data_ptr_t blob_ptr = blob_handle.Ptr(); + + for (idx_t i = 0; i < count; i++) { + if (!ties[i]) { + continue; + } + idx_t j; + for (j = i; j < count; j++) { + if (!ties[j]) { + break; + } + } + SortTiedBlobs(buffer_manager, dataptr, i, j + 1, tie_col, ties, blob_ptr, sort_layout); + i = j; + } +} + +//! Returns whether there are any 'true' values in the ties[] array +static bool AnyTies(bool ties[], const idx_t &count) { + D_ASSERT(!ties[count - 1]); + bool any_ties = false; + for (idx_t i = 0; i < count - 1; i++) { + any_ties = any_ties || ties[i]; + } + return any_ties; +} + +//! Compares subsequent rows to check for ties +static void ComputeTies(data_ptr_t dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &tie_size, + bool ties[], const SortLayout &sort_layout) { + D_ASSERT(!ties[count - 1]); + D_ASSERT(col_offset + tie_size <= sort_layout.comparison_size); + // Align dataptr + dataptr += col_offset; + for (idx_t i = 0; i < count - 1; i++) { + ties[i] = ties[i] && FastMemcmp(dataptr, dataptr + sort_layout.entry_size, tie_size) == 0; + dataptr += sort_layout.entry_size; + } +} + +//! Textbook LSD radix sort +void RadixSortLSD(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, + const idx_t &row_width, const idx_t &sorting_size) { + auto temp_block = buffer_manager.GetBufferAllocator().Allocate(count * row_width); + bool swap = false; + + idx_t counts[SortConstants::VALUES_PER_RADIX]; + for (idx_t r = 1; r <= sorting_size; r++) { + // Init counts to 0 + memset(counts, 0, sizeof(counts)); + // Const some values for convenience + const data_ptr_t source_ptr = swap ? temp_block.get() : dataptr; + const data_ptr_t target_ptr = swap ? dataptr : temp_block.get(); + const idx_t offset = col_offset + sorting_size - r; + // Collect counts + data_ptr_t offset_ptr = source_ptr + offset; + for (idx_t i = 0; i < count; i++) { + counts[*offset_ptr]++; + offset_ptr += row_width; + } + // Compute offsets from counts + idx_t max_count = counts[0]; + for (idx_t val = 1; val < SortConstants::VALUES_PER_RADIX; val++) { + max_count = MaxValue(max_count, counts[val]); + counts[val] = counts[val] + counts[val - 1]; + } + if (max_count == count) { + continue; + } + // Re-order the data in temporary array + data_ptr_t row_ptr = source_ptr + (count - 1) * row_width; + for (idx_t i = 0; i < count; i++) { + idx_t &radix_offset = --counts[*(row_ptr + offset)]; + FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); + row_ptr -= row_width; + } + swap = !swap; + } + // Move data back to original buffer (if it was swapped) + if (swap) { + memcpy(dataptr, temp_block.get(), count * row_width); + } +} + +//! Insertion sort, used when count of values is low +inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, + const idx_t &col_offset, const idx_t &row_width, const idx_t &total_comp_width, + const idx_t &offset, bool swap) { + const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; + const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; + if (count > 1) { + const idx_t total_offset = col_offset + offset; + auto temp_val = make_unsafe_uniq_array(row_width); + const data_ptr_t val = temp_val.get(); + const auto comp_width = total_comp_width - offset; + for (idx_t i = 1; i < count; i++) { + FastMemcpy(val, source_ptr + i * row_width, row_width); + idx_t j = i; + while (j > 0 && + FastMemcmp(source_ptr + (j - 1) * row_width + total_offset, val + total_offset, comp_width) > 0) { + FastMemcpy(source_ptr + j * row_width, source_ptr + (j - 1) * row_width, row_width); + j--; + } + FastMemcpy(source_ptr + j * row_width, val, row_width); + } + } + if (swap) { + memcpy(target_ptr, source_ptr, count * row_width); + } +} + +//! MSD radix sort that switches to insertion sort with low bucket sizes +void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, const idx_t &col_offset, + const idx_t &row_width, const idx_t &comp_width, const idx_t &offset, idx_t locations[], bool swap) { + const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; + const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; + // Init counts to 0 + memset(locations, 0, SortConstants::MSD_RADIX_LOCATIONS * sizeof(idx_t)); + idx_t *counts = locations + 1; + // Collect counts + const idx_t total_offset = col_offset + offset; + data_ptr_t offset_ptr = source_ptr + total_offset; + for (idx_t i = 0; i < count; i++) { + counts[*offset_ptr]++; + offset_ptr += row_width; + } + // Compute locations from counts + idx_t max_count = 0; + for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { + max_count = MaxValue(max_count, counts[radix]); + counts[radix] += locations[radix]; + } + if (max_count != count) { + // Re-order the data in temporary array + data_ptr_t row_ptr = source_ptr; + for (idx_t i = 0; i < count; i++) { + const idx_t &radix_offset = locations[*(row_ptr + total_offset)]++; + FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); + row_ptr += row_width; + } + swap = !swap; + } + // Check if done + if (offset == comp_width - 1) { + if (swap) { + memcpy(orig_ptr, temp_ptr, count * row_width); + } + return; + } + if (max_count == count) { + RadixSortMSD(orig_ptr, temp_ptr, count, col_offset, row_width, comp_width, offset + 1, + locations + SortConstants::MSD_RADIX_LOCATIONS, swap); + return; + } + // Recurse + idx_t radix_count = locations[0]; + for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { + const idx_t loc = (locations[radix] - radix_count) * row_width; + if (radix_count > SortConstants::INSERTION_SORT_THRESHOLD) { + RadixSortMSD(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, + locations + SortConstants::MSD_RADIX_LOCATIONS, swap); + } else if (radix_count != 0) { + InsertionSort(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, + swap); + } + radix_count = locations[radix + 1] - locations[radix]; + } +} + +//! Calls different sort functions, depending on the count and sorting sizes +void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, + const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { + if (contains_string) { + auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); + auto end = begin + count; + duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); + duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); + } else if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { + InsertionSort(dataptr, nullptr, count, 0, sort_layout.entry_size, sort_layout.comparison_size, 0, false); + } else if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { + RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); + } else { + auto temp_block = buffer_manager.Allocate(MaxValue(count * sort_layout.entry_size, (idx_t)Storage::BLOCK_SIZE)); + auto preallocated_array = make_unsafe_uniq_array(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); + RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, + preallocated_array.get(), false); + } +} + +//! Identifies sequences of rows that are tied, and calls radix sort on these +static void SubSortTiedTuples(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &count, + const idx_t &col_offset, const idx_t &sorting_size, bool ties[], + const SortLayout &sort_layout, bool contains_string) { + D_ASSERT(!ties[count - 1]); + for (idx_t i = 0; i < count; i++) { + if (!ties[i]) { + continue; + } + idx_t j; + for (j = i + 1; j < count; j++) { + if (!ties[j]) { + break; + } + } + RadixSort(buffer_manager, dataptr + i * sort_layout.entry_size, j - i + 1, col_offset, sorting_size, + sort_layout, contains_string); + i = j; + } +} + +void LocalSortState::SortInMemory() { + auto &sb = *sorted_blocks.back(); + auto &block = *sb.radix_sorting_data.back(); + const auto &count = block.count; + auto handle = buffer_manager->Pin(block.block); + const auto dataptr = handle.Ptr(); + // Assign an index to each row + data_ptr_t idx_dataptr = dataptr + sort_layout->comparison_size; + for (uint32_t i = 0; i < count; i++) { + Store(i, idx_dataptr); + idx_dataptr += sort_layout->entry_size; + } + // Radix sort and break ties until no more ties, or until all columns are sorted + idx_t sorting_size = 0; + idx_t col_offset = 0; + unsafe_unique_array ties_ptr; + bool *ties = nullptr; + bool contains_string = false; + for (idx_t i = 0; i < sort_layout->column_count; i++) { + sorting_size += sort_layout->column_sizes[i]; + contains_string = contains_string || sort_layout->logical_types[i].InternalType() == PhysicalType::VARCHAR; + if (sort_layout->constant_size[i] && i < sort_layout->column_count - 1) { + // Add columns to the sorting size until we reach a variable size column, or the last column + continue; + } + + if (!ties) { + // This is the first sort + RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); + ties_ptr = make_unsafe_uniq_array(count); + ties = ties_ptr.get(); + std::fill_n(ties, count - 1, true); + ties[count - 1] = false; + } else { + // For subsequent sorts, we only have to subsort the tied tuples + SubSortTiedTuples(*buffer_manager, dataptr, count, col_offset, sorting_size, ties, *sort_layout, + contains_string); + } + + contains_string = false; + + if (sort_layout->constant_size[i] && i == sort_layout->column_count - 1) { + // All columns are sorted, no ties to break because last column is constant size + break; + } + + ComputeTies(dataptr, count, col_offset, sorting_size, ties, *sort_layout); + if (!AnyTies(ties, count)) { + // No ties, stop sorting + break; + } + + if (!sort_layout->constant_size[i]) { + SortTiedBlobs(*buffer_manager, sb, ties, dataptr, count, i, *sort_layout); + if (!AnyTies(ties, count)) { + // No more ties after tie-breaking, stop + break; + } + } + + col_offset += sorting_size; + sorting_size = 0; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sort_state.cpp b/src/duckdb/src/common/sort/sort_state.cpp new file mode 100644 index 00000000..b10970d0 --- /dev/null +++ b/src/duckdb/src/common/sort/sort_state.cpp @@ -0,0 +1,479 @@ +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/radix.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sort/sorted_block.hpp" + +#include +#include + +namespace duckdb { + +idx_t GetNestedSortingColSize(idx_t &col_size, const LogicalType &type) { + auto physical_type = type.InternalType(); + if (TypeIsConstantSize(physical_type)) { + col_size += GetTypeIdSize(physical_type); + return 0; + } else { + switch (physical_type) { + case PhysicalType::VARCHAR: { + // Nested strings are between 4 and 11 chars long for alignment + auto size_before_str = col_size; + col_size += 11; + col_size -= (col_size - 12) % 8; + return col_size - size_before_str; + } + case PhysicalType::LIST: + // Lists get 2 bytes (null and empty list) + col_size += 2; + return GetNestedSortingColSize(col_size, ListType::GetChildType(type)); + case PhysicalType::STRUCT: + // Structs get 1 bytes (null) + col_size++; + return GetNestedSortingColSize(col_size, StructType::GetChildType(type, 0)); + default: + throw NotImplementedException("Unable to order column with type %s", type.ToString()); + } + } +} + +SortLayout::SortLayout(const vector &orders) + : column_count(orders.size()), all_constant(true), comparison_size(0), entry_size(0) { + vector blob_layout_types; + for (idx_t i = 0; i < column_count; i++) { + const auto &order = orders[i]; + + order_types.push_back(order.type); + order_by_null_types.push_back(order.null_order); + auto &expr = *order.expression; + logical_types.push_back(expr.return_type); + + auto physical_type = expr.return_type.InternalType(); + constant_size.push_back(TypeIsConstantSize(physical_type)); + + if (order.stats) { + stats.push_back(order.stats.get()); + has_null.push_back(stats.back()->CanHaveNull()); + } else { + stats.push_back(nullptr); + has_null.push_back(true); + } + + idx_t col_size = has_null.back() ? 1 : 0; + prefix_lengths.push_back(0); + if (!TypeIsConstantSize(physical_type) && physical_type != PhysicalType::VARCHAR) { + prefix_lengths.back() = GetNestedSortingColSize(col_size, expr.return_type); + } else if (physical_type == PhysicalType::VARCHAR) { + idx_t size_before = col_size; + if (stats.back() && StringStats::HasMaxStringLength(*stats.back())) { + col_size += StringStats::MaxStringLength(*stats.back()); + if (col_size > 12) { + col_size = 12; + } else { + constant_size.back() = true; + } + } else { + col_size = 12; + } + prefix_lengths.back() = col_size - size_before; + } else { + col_size += GetTypeIdSize(physical_type); + } + + comparison_size += col_size; + column_sizes.push_back(col_size); + } + entry_size = comparison_size + sizeof(uint32_t); + + // 8-byte alignment + if (entry_size % 8 != 0) { + // First assign more bytes to strings instead of aligning + idx_t bytes_to_fill = 8 - (entry_size % 8); + for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { + if (bytes_to_fill == 0) { + break; + } + if (logical_types[col_idx].InternalType() == PhysicalType::VARCHAR && stats[col_idx] && + StringStats::HasMaxStringLength(*stats[col_idx])) { + idx_t diff = StringStats::MaxStringLength(*stats[col_idx]) - prefix_lengths[col_idx]; + if (diff > 0) { + // Increase all sizes accordingly + idx_t increase = MinValue(bytes_to_fill, diff); + column_sizes[col_idx] += increase; + prefix_lengths[col_idx] += increase; + constant_size[col_idx] = increase == diff; + comparison_size += increase; + entry_size += increase; + bytes_to_fill -= increase; + } + } + } + entry_size = AlignValue(entry_size); + } + + for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { + all_constant = all_constant && constant_size[col_idx]; + if (!constant_size[col_idx]) { + sorting_to_blob_col[col_idx] = blob_layout_types.size(); + blob_layout_types.push_back(logical_types[col_idx]); + } + } + + blob_layout.Initialize(blob_layout_types); +} + +SortLayout SortLayout::GetPrefixComparisonLayout(idx_t num_prefix_cols) const { + SortLayout result; + result.column_count = num_prefix_cols; + result.all_constant = true; + result.comparison_size = 0; + for (idx_t col_idx = 0; col_idx < num_prefix_cols; col_idx++) { + result.order_types.push_back(order_types[col_idx]); + result.order_by_null_types.push_back(order_by_null_types[col_idx]); + result.logical_types.push_back(logical_types[col_idx]); + + result.all_constant = result.all_constant && constant_size[col_idx]; + result.constant_size.push_back(constant_size[col_idx]); + + result.comparison_size += column_sizes[col_idx]; + result.column_sizes.push_back(column_sizes[col_idx]); + + result.prefix_lengths.push_back(prefix_lengths[col_idx]); + result.stats.push_back(stats[col_idx]); + result.has_null.push_back(has_null[col_idx]); + } + result.entry_size = entry_size; + result.blob_layout = blob_layout; + result.sorting_to_blob_col = sorting_to_blob_col; + return result; +} + +LocalSortState::LocalSortState() : initialized(false) { + if (!Radix::IsLittleEndian()) { + throw NotImplementedException("Sorting is not supported on big endian architectures"); + } +} + +void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p) { + sort_layout = &global_sort_state.sort_layout; + payload_layout = &global_sort_state.payload_layout; + buffer_manager = &buffer_manager_p; + // Radix sorting data + radix_sorting_data = make_uniq( + *buffer_manager, RowDataCollection::EntriesPerBlock(sort_layout->entry_size), sort_layout->entry_size); + // Blob sorting data + if (!sort_layout->all_constant) { + auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); + blob_sorting_data = make_uniq( + *buffer_manager, RowDataCollection::EntriesPerBlock(blob_row_width), blob_row_width); + blob_sorting_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + } + // Payload data + auto payload_row_width = payload_layout->GetRowWidth(); + payload_data = make_uniq(*buffer_manager, RowDataCollection::EntriesPerBlock(payload_row_width), + payload_row_width); + payload_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + // Init done + initialized = true; +} + +void LocalSortState::SinkChunk(DataChunk &sort, DataChunk &payload) { + D_ASSERT(sort.size() == payload.size()); + // Build and serialize sorting data to radix sortable rows + auto data_pointers = FlatVector::GetData(addresses); + auto handles = radix_sorting_data->Build(sort.size(), data_pointers, nullptr); + for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { + bool has_null = sort_layout->has_null[sort_col]; + bool nulls_first = sort_layout->order_by_null_types[sort_col] == OrderByNullType::NULLS_FIRST; + bool desc = sort_layout->order_types[sort_col] == OrderType::DESCENDING; + RowOperations::RadixScatter(sort.data[sort_col], sort.size(), sel_ptr, sort.size(), data_pointers, desc, + has_null, nulls_first, sort_layout->prefix_lengths[sort_col], + sort_layout->column_sizes[sort_col]); + } + + // Also fully serialize blob sorting columns (to be able to break ties + if (!sort_layout->all_constant) { + DataChunk blob_chunk; + blob_chunk.SetCardinality(sort.size()); + for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { + if (!sort_layout->constant_size[sort_col]) { + blob_chunk.data.emplace_back(sort.data[sort_col]); + } + } + handles = blob_sorting_data->Build(blob_chunk.size(), data_pointers, nullptr); + auto blob_data = blob_chunk.ToUnifiedFormat(); + RowOperations::Scatter(blob_chunk, blob_data.get(), sort_layout->blob_layout, addresses, *blob_sorting_heap, + sel_ptr, blob_chunk.size()); + D_ASSERT(blob_sorting_heap->keep_pinned); + } + + // Finally, serialize payload data + handles = payload_data->Build(payload.size(), data_pointers, nullptr); + auto input_data = payload.ToUnifiedFormat(); + RowOperations::Scatter(payload, input_data.get(), *payload_layout, addresses, *payload_heap, sel_ptr, + payload.size()); + D_ASSERT(payload_heap->keep_pinned); +} + +idx_t LocalSortState::SizeInBytes() const { + idx_t size_in_bytes = radix_sorting_data->SizeInBytes() + payload_data->SizeInBytes(); + if (!sort_layout->all_constant) { + size_in_bytes += blob_sorting_data->SizeInBytes() + blob_sorting_heap->SizeInBytes(); + } + if (!payload_layout->AllConstant()) { + size_in_bytes += payload_heap->SizeInBytes(); + } + return size_in_bytes; +} + +void LocalSortState::Sort(GlobalSortState &global_sort_state, bool reorder_heap) { + D_ASSERT(radix_sorting_data->count == payload_data->count); + if (radix_sorting_data->count == 0) { + return; + } + // Move all data to a single SortedBlock + sorted_blocks.emplace_back(make_uniq(*buffer_manager, global_sort_state)); + auto &sb = *sorted_blocks.back(); + // Fixed-size sorting data + auto sorting_block = ConcatenateBlocks(*radix_sorting_data); + sb.radix_sorting_data.push_back(std::move(sorting_block)); + // Variable-size sorting data + if (!sort_layout->all_constant) { + auto &blob_data = *blob_sorting_data; + auto new_block = ConcatenateBlocks(blob_data); + sb.blob_sorting_data->data_blocks.push_back(std::move(new_block)); + } + // Payload data + auto payload_block = ConcatenateBlocks(*payload_data); + sb.payload_data->data_blocks.push_back(std::move(payload_block)); + // Now perform the actual sort + SortInMemory(); + // Re-order before the merge sort + ReOrder(global_sort_state, reorder_heap); +} + +unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &row_data) { + // Don't copy and delete if there is only one block. + if (row_data.blocks.size() == 1) { + auto new_block = std::move(row_data.blocks[0]); + row_data.blocks.clear(); + row_data.count = 0; + return new_block; + } + // Create block with the correct capacity + auto buffer_manager = &row_data.buffer_manager; + const idx_t &entry_size = row_data.entry_size; + idx_t capacity = MaxValue(((idx_t)Storage::BLOCK_SIZE + entry_size - 1) / entry_size, row_data.count); + auto new_block = make_uniq(*buffer_manager, capacity, entry_size); + new_block->count = row_data.count; + auto new_block_handle = buffer_manager->Pin(new_block->block); + data_ptr_t new_block_ptr = new_block_handle.Ptr(); + // Copy the data of the blocks into a single block + for (idx_t i = 0; i < row_data.blocks.size(); i++) { + auto &block = row_data.blocks[i]; + auto block_handle = buffer_manager->Pin(block->block); + memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); + new_block_ptr += block->count * entry_size; + block.reset(); + } + row_data.blocks.clear(); + row_data.count = 0; + return new_block; +} + +void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, + bool reorder_heap) { + sd.swizzled = reorder_heap; + auto &unordered_data_block = sd.data_blocks.back(); + const idx_t count = unordered_data_block->count; + auto unordered_data_handle = buffer_manager->Pin(unordered_data_block->block); + const data_ptr_t unordered_data_ptr = unordered_data_handle.Ptr(); + // Create new block that will hold re-ordered row data + auto ordered_data_block = + make_uniq(*buffer_manager, unordered_data_block->capacity, unordered_data_block->entry_size); + ordered_data_block->count = count; + auto ordered_data_handle = buffer_manager->Pin(ordered_data_block->block); + data_ptr_t ordered_data_ptr = ordered_data_handle.Ptr(); + // Re-order fixed-size row layout + const idx_t row_width = sd.layout.GetRowWidth(); + const idx_t sorting_entry_size = gstate.sort_layout.entry_size; + for (idx_t i = 0; i < count; i++) { + auto index = Load(sorting_ptr); + FastMemcpy(ordered_data_ptr, unordered_data_ptr + index * row_width, row_width); + ordered_data_ptr += row_width; + sorting_ptr += sorting_entry_size; + } + ordered_data_block->block->SetSwizzling( + sd.layout.AllConstant() || !sd.swizzled ? nullptr : "LocalSortState::ReOrder.ordered_data"); + // Replace the unordered data block with the re-ordered data block + sd.data_blocks.clear(); + sd.data_blocks.push_back(std::move(ordered_data_block)); + // Deal with the heap (if necessary) + if (!sd.layout.AllConstant() && reorder_heap) { + // Swizzle the column pointers to offsets + RowOperations::SwizzleColumns(sd.layout, ordered_data_handle.Ptr(), count); + sd.data_blocks.back()->block->SetSwizzling(nullptr); + // Create a single heap block to store the ordered heap + idx_t total_byte_offset = + std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, + [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); + idx_t heap_block_size = MaxValue(total_byte_offset, (idx_t)Storage::BLOCK_SIZE); + auto ordered_heap_block = make_uniq(*buffer_manager, heap_block_size, 1); + ordered_heap_block->count = count; + ordered_heap_block->byte_offset = total_byte_offset; + auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); + data_ptr_t ordered_heap_ptr = ordered_heap_handle.Ptr(); + // Fill the heap in order + ordered_data_ptr = ordered_data_handle.Ptr(); + const idx_t heap_pointer_offset = sd.layout.GetHeapOffset(); + for (idx_t i = 0; i < count; i++) { + auto heap_row_ptr = Load(ordered_data_ptr + heap_pointer_offset); + auto heap_row_size = Load(heap_row_ptr); + memcpy(ordered_heap_ptr, heap_row_ptr, heap_row_size); + ordered_heap_ptr += heap_row_size; + ordered_data_ptr += row_width; + } + // Swizzle the base pointer to the offset of each row in the heap + RowOperations::SwizzleHeapPointer(sd.layout, ordered_data_handle.Ptr(), ordered_heap_handle.Ptr(), count); + // Move the re-ordered heap to the SortedData, and clear the local heap + sd.heap_blocks.push_back(std::move(ordered_heap_block)); + heap.pinned_blocks.clear(); + heap.blocks.clear(); + heap.count = 0; + } +} + +void LocalSortState::ReOrder(GlobalSortState &gstate, bool reorder_heap) { + auto &sb = *sorted_blocks.back(); + auto sorting_handle = buffer_manager->Pin(sb.radix_sorting_data.back()->block); + const data_ptr_t sorting_ptr = sorting_handle.Ptr() + gstate.sort_layout.comparison_size; + // Re-order variable size sorting columns + if (!gstate.sort_layout.all_constant) { + ReOrder(*sb.blob_sorting_data, sorting_ptr, *blob_sorting_heap, gstate, reorder_heap); + } + // And the payload + ReOrder(*sb.payload_data, sorting_ptr, *payload_heap, gstate, reorder_heap); +} + +GlobalSortState::GlobalSortState(BufferManager &buffer_manager, const vector &orders, + RowLayout &payload_layout) + : buffer_manager(buffer_manager), sort_layout(SortLayout(orders)), payload_layout(payload_layout), + block_capacity(0), external(false) { +} + +void GlobalSortState::AddLocalState(LocalSortState &local_sort_state) { + if (!local_sort_state.radix_sorting_data) { + return; + } + + // Sort accumulated data + // we only re-order the heap when the data is expected to not fit in memory + // re-ordering the heap avoids random access when reading/merging but incurs a significant cost of shuffling data + // when data fits in memory, doing random access on reads is cheaper than re-shuffling + local_sort_state.Sort(*this, external || !local_sort_state.sorted_blocks.empty()); + + // Append local state sorted data to this global state + lock_guard append_guard(lock); + for (auto &sb : local_sort_state.sorted_blocks) { + sorted_blocks.push_back(std::move(sb)); + } + auto &payload_heap = local_sort_state.payload_heap; + for (idx_t i = 0; i < payload_heap->blocks.size(); i++) { + heap_blocks.push_back(std::move(payload_heap->blocks[i])); + pinned_blocks.push_back(std::move(payload_heap->pinned_blocks[i])); + } + if (!sort_layout.all_constant) { + auto &blob_heap = local_sort_state.blob_sorting_heap; + for (idx_t i = 0; i < blob_heap->blocks.size(); i++) { + heap_blocks.push_back(std::move(blob_heap->blocks[i])); + pinned_blocks.push_back(std::move(blob_heap->pinned_blocks[i])); + } + } +} + +void GlobalSortState::PrepareMergePhase() { + // Determine if we need to use do an external sort + idx_t total_heap_size = + std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, + [](idx_t a, const unique_ptr &b) { return a + b->HeapSize(); }); + if (external || (pinned_blocks.empty() && total_heap_size > 0.25 * buffer_manager.GetMaxMemory())) { + external = true; + } + // Use the data that we have to determine which partition size to use during the merge + if (external && total_heap_size > 0) { + // If we have variable size data we need to be conservative, as there might be skew + idx_t max_block_size = 0; + for (auto &sb : sorted_blocks) { + idx_t size_in_bytes = sb->SizeInBytes(); + if (size_in_bytes > max_block_size) { + max_block_size = size_in_bytes; + block_capacity = sb->Count(); + } + } + } else { + for (auto &sb : sorted_blocks) { + block_capacity = MaxValue(block_capacity, sb->Count()); + } + } + // Unswizzle and pin heap blocks if we can fit everything in memory + if (!external) { + for (auto &sb : sorted_blocks) { + sb->blob_sorting_data->Unswizzle(); + sb->payload_data->Unswizzle(); + } + } +} + +void GlobalSortState::InitializeMergeRound() { + D_ASSERT(sorted_blocks_temp.empty()); + // If we reverse this list, the blocks that were merged last will be merged first in the next round + // These are still in memory, therefore this reduces the amount of read/write to disk! + std::reverse(sorted_blocks.begin(), sorted_blocks.end()); + // Uneven number of blocks - keep one on the side + if (sorted_blocks.size() % 2 == 1) { + odd_one_out = std::move(sorted_blocks.back()); + sorted_blocks.pop_back(); + } + // Init merge path path indices + pair_idx = 0; + num_pairs = sorted_blocks.size() / 2; + l_start = 0; + r_start = 0; + // Allocate room for merge results + for (idx_t p_idx = 0; p_idx < num_pairs; p_idx++) { + sorted_blocks_temp.emplace_back(); + } +} + +void GlobalSortState::CompleteMergeRound(bool keep_radix_data) { + sorted_blocks.clear(); + for (auto &sorted_block_vector : sorted_blocks_temp) { + sorted_blocks.push_back(make_uniq(buffer_manager, *this)); + sorted_blocks.back()->AppendSortedBlocks(sorted_block_vector); + } + sorted_blocks_temp.clear(); + if (odd_one_out) { + sorted_blocks.push_back(std::move(odd_one_out)); + odd_one_out = nullptr; + } + // Only one block left: Done! + if (sorted_blocks.size() == 1 && !keep_radix_data) { + sorted_blocks[0]->radix_sorting_data.clear(); + sorted_blocks[0]->blob_sorting_data = nullptr; + } +} +void GlobalSortState::Print() { + PayloadScanner scanner(*this, false); + DataChunk chunk; + chunk.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); + for (;;) { + scanner.Scan(chunk); + const auto count = chunk.size(); + if (!count) { + break; + } + chunk.Print(); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sorted_block.cpp b/src/duckdb/src/common/sort/sorted_block.cpp new file mode 100644 index 00000000..e0d4bfb4 --- /dev/null +++ b/src/duckdb/src/common/sort/sorted_block.cpp @@ -0,0 +1,385 @@ +#include "duckdb/common/sort/sorted_block.hpp" + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/types/row/row_data_collection.hpp" + +#include + +namespace duckdb { + +SortedData::SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, + GlobalSortState &state) + : type(type), layout(layout), swizzled(state.external), buffer_manager(buffer_manager), state(state) { +} + +idx_t SortedData::Count() { + idx_t count = std::accumulate(data_blocks.begin(), data_blocks.end(), (idx_t)0, + [](idx_t a, const unique_ptr &b) { return a + b->count; }); + if (!layout.AllConstant() && state.external) { + D_ASSERT(count == std::accumulate(heap_blocks.begin(), heap_blocks.end(), (idx_t)0, + [](idx_t a, const unique_ptr &b) { return a + b->count; })); + } + return count; +} + +void SortedData::CreateBlock() { + auto capacity = + MaxValue(((idx_t)Storage::BLOCK_SIZE + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); + data_blocks.push_back(make_uniq(buffer_manager, capacity, layout.GetRowWidth())); + if (!layout.AllConstant() && state.external) { + heap_blocks.push_back(make_uniq(buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1)); + D_ASSERT(data_blocks.size() == heap_blocks.size()); + } +} + +unique_ptr SortedData::CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index) { + // Add the corresponding blocks to the result + auto result = make_uniq(type, layout, buffer_manager, state); + for (idx_t i = start_block_index; i <= end_block_index; i++) { + result->data_blocks.push_back(data_blocks[i]->Copy()); + if (!layout.AllConstant() && state.external) { + result->heap_blocks.push_back(heap_blocks[i]->Copy()); + } + } + // All of the blocks that come before block with idx = start_block_idx can be reset (other references exist) + for (idx_t i = 0; i < start_block_index; i++) { + data_blocks[i]->block = nullptr; + if (!layout.AllConstant() && state.external) { + heap_blocks[i]->block = nullptr; + } + } + // Use start and end entry indices to set the boundaries + D_ASSERT(end_entry_index <= result->data_blocks.back()->count); + result->data_blocks.back()->count = end_entry_index; + if (!layout.AllConstant() && state.external) { + result->heap_blocks.back()->count = end_entry_index; + } + return result; +} + +void SortedData::Unswizzle() { + if (layout.AllConstant() || !swizzled) { + return; + } + for (idx_t i = 0; i < data_blocks.size(); i++) { + auto &data_block = data_blocks[i]; + auto &heap_block = heap_blocks[i]; + D_ASSERT(data_block->block->IsSwizzled()); + auto data_handle_p = buffer_manager.Pin(data_block->block); + auto heap_handle_p = buffer_manager.Pin(heap_block->block); + RowOperations::UnswizzlePointers(layout, data_handle_p.Ptr(), heap_handle_p.Ptr(), data_block->count); + state.heap_blocks.push_back(std::move(heap_block)); + state.pinned_blocks.push_back(std::move(heap_handle_p)); + } + swizzled = false; + heap_blocks.clear(); +} + +SortedBlock::SortedBlock(BufferManager &buffer_manager, GlobalSortState &state) + : buffer_manager(buffer_manager), state(state), sort_layout(state.sort_layout), + payload_layout(state.payload_layout) { + blob_sorting_data = make_uniq(SortedDataType::BLOB, sort_layout.blob_layout, buffer_manager, state); + payload_data = make_uniq(SortedDataType::PAYLOAD, payload_layout, buffer_manager, state); +} + +idx_t SortedBlock::Count() const { + idx_t count = std::accumulate(radix_sorting_data.begin(), radix_sorting_data.end(), (idx_t)0, + [](idx_t a, const unique_ptr &b) { return a + b->count; }); + if (!sort_layout.all_constant) { + D_ASSERT(count == blob_sorting_data->Count()); + } + D_ASSERT(count == payload_data->Count()); + return count; +} + +void SortedBlock::InitializeWrite() { + CreateBlock(); + if (!sort_layout.all_constant) { + blob_sorting_data->CreateBlock(); + } + payload_data->CreateBlock(); +} + +void SortedBlock::CreateBlock() { + auto capacity = MaxValue(((idx_t)Storage::BLOCK_SIZE + sort_layout.entry_size - 1) / sort_layout.entry_size, + state.block_capacity); + radix_sorting_data.push_back(make_uniq(buffer_manager, capacity, sort_layout.entry_size)); +} + +void SortedBlock::AppendSortedBlocks(vector> &sorted_blocks) { + D_ASSERT(Count() == 0); + for (auto &sb : sorted_blocks) { + for (auto &radix_block : sb->radix_sorting_data) { + radix_sorting_data.push_back(std::move(radix_block)); + } + if (!sort_layout.all_constant) { + for (auto &blob_block : sb->blob_sorting_data->data_blocks) { + blob_sorting_data->data_blocks.push_back(std::move(blob_block)); + } + for (auto &heap_block : sb->blob_sorting_data->heap_blocks) { + blob_sorting_data->heap_blocks.push_back(std::move(heap_block)); + } + } + for (auto &payload_data_block : sb->payload_data->data_blocks) { + payload_data->data_blocks.push_back(std::move(payload_data_block)); + } + if (!payload_data->layout.AllConstant()) { + for (auto &payload_heap_block : sb->payload_data->heap_blocks) { + payload_data->heap_blocks.push_back(std::move(payload_heap_block)); + } + } + } +} + +void SortedBlock::GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index) { + if (global_idx == Count()) { + local_block_index = radix_sorting_data.size() - 1; + local_entry_index = radix_sorting_data.back()->count; + return; + } + D_ASSERT(global_idx < Count()); + local_entry_index = global_idx; + for (local_block_index = 0; local_block_index < radix_sorting_data.size(); local_block_index++) { + const idx_t &block_count = radix_sorting_data[local_block_index]->count; + if (local_entry_index >= block_count) { + local_entry_index -= block_count; + } else { + break; + } + } + D_ASSERT(local_entry_index < radix_sorting_data[local_block_index]->count); +} + +unique_ptr SortedBlock::CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx) { + // Identify blocks/entry indices of this slice + idx_t start_block_index; + idx_t start_entry_index; + GlobalToLocalIndex(start, start_block_index, start_entry_index); + idx_t end_block_index; + idx_t end_entry_index; + GlobalToLocalIndex(end, end_block_index, end_entry_index); + // Add the corresponding blocks to the result + auto result = make_uniq(buffer_manager, state); + for (idx_t i = start_block_index; i <= end_block_index; i++) { + result->radix_sorting_data.push_back(radix_sorting_data[i]->Copy()); + } + // Reset all blocks that come before block with idx = start_block_idx (slice holds new reference) + for (idx_t i = 0; i < start_block_index; i++) { + radix_sorting_data[i]->block = nullptr; + } + // Use start and end entry indices to set the boundaries + entry_idx = start_entry_index; + D_ASSERT(end_entry_index <= result->radix_sorting_data.back()->count); + result->radix_sorting_data.back()->count = end_entry_index; + // Same for the var size sorting data + if (!sort_layout.all_constant) { + result->blob_sorting_data = blob_sorting_data->CreateSlice(start_block_index, end_block_index, end_entry_index); + } + // And the payload data + result->payload_data = payload_data->CreateSlice(start_block_index, end_block_index, end_entry_index); + return result; +} + +idx_t SortedBlock::HeapSize() const { + idx_t result = 0; + if (!sort_layout.all_constant) { + for (auto &block : blob_sorting_data->heap_blocks) { + result += block->capacity; + } + } + if (!payload_layout.AllConstant()) { + for (auto &block : payload_data->heap_blocks) { + result += block->capacity; + } + } + return result; +} + +idx_t SortedBlock::SizeInBytes() const { + idx_t bytes = 0; + for (idx_t i = 0; i < radix_sorting_data.size(); i++) { + bytes += radix_sorting_data[i]->capacity * sort_layout.entry_size; + if (!sort_layout.all_constant) { + bytes += blob_sorting_data->data_blocks[i]->capacity * sort_layout.blob_layout.GetRowWidth(); + bytes += blob_sorting_data->heap_blocks[i]->capacity; + } + bytes += payload_data->data_blocks[i]->capacity * payload_layout.GetRowWidth(); + if (!payload_layout.AllConstant()) { + bytes += payload_data->heap_blocks[i]->capacity; + } + } + return bytes; +} + +SBScanState::SBScanState(BufferManager &buffer_manager, GlobalSortState &state) + : buffer_manager(buffer_manager), sort_layout(state.sort_layout), state(state), block_idx(0), entry_idx(0) { +} + +void SBScanState::PinRadix(idx_t block_idx_to) { + auto &radix_sorting_data = sb->radix_sorting_data; + D_ASSERT(block_idx_to < radix_sorting_data.size()); + auto &block = radix_sorting_data[block_idx_to]; + if (!radix_handle.IsValid() || radix_handle.GetBlockHandle() != block->block) { + radix_handle = buffer_manager.Pin(block->block); + } +} + +void SBScanState::PinData(SortedData &sd) { + D_ASSERT(block_idx < sd.data_blocks.size()); + auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; + auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; + + auto &data_block = sd.data_blocks[block_idx]; + if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { + data_handle = buffer_manager.Pin(data_block->block); + } + if (sd.layout.AllConstant() || !state.external) { + return; + } + auto &heap_block = sd.heap_blocks[block_idx]; + if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { + heap_handle = buffer_manager.Pin(heap_block->block); + } +} + +data_ptr_t SBScanState::RadixPtr() const { + return radix_handle.Ptr() + entry_idx * sort_layout.entry_size; +} + +data_ptr_t SBScanState::DataPtr(SortedData &sd) const { + auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; + D_ASSERT(sd.data_blocks[block_idx]->block->Readers() != 0 && + data_handle.GetBlockHandle() == sd.data_blocks[block_idx]->block); + return data_handle.Ptr() + entry_idx * sd.layout.GetRowWidth(); +} + +data_ptr_t SBScanState::HeapPtr(SortedData &sd) const { + return BaseHeapPtr(sd) + Load(DataPtr(sd) + sd.layout.GetHeapOffset()); +} + +data_ptr_t SBScanState::BaseHeapPtr(SortedData &sd) const { + auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; + D_ASSERT(!sd.layout.AllConstant() && state.external); + D_ASSERT(sd.heap_blocks[block_idx]->block->Readers() != 0 && + heap_handle.GetBlockHandle() == sd.heap_blocks[block_idx]->block); + return heap_handle.Ptr(); +} + +idx_t SBScanState::Remaining() const { + const auto &blocks = sb->radix_sorting_data; + idx_t remaining = 0; + if (block_idx < blocks.size()) { + remaining += blocks[block_idx]->count - entry_idx; + for (idx_t i = block_idx + 1; i < blocks.size(); i++) { + remaining += blocks[i]->count; + } + } + return remaining; +} + +void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { + block_idx = block_idx_to; + entry_idx = entry_idx_to; +} + +PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { + auto count = sorted_data.Count(); + auto &layout = sorted_data.layout; + + // Create collections to put the data into so we can use RowDataCollectionScanner + rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + rows->count = count; + + heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + if (!sorted_data.layout.AllConstant()) { + heap->count = count; + } + + if (flush_p) { + // If we are flushing, we can just move the data + rows->blocks = std::move(sorted_data.data_blocks); + if (!layout.AllConstant()) { + heap->blocks = std::move(sorted_data.heap_blocks); + } + } else { + // Not flushing, create references to the blocks + for (auto &block : sorted_data.data_blocks) { + rows->blocks.emplace_back(block->Copy()); + } + if (!layout.AllConstant()) { + for (auto &block : sorted_data.heap_blocks) { + heap->blocks.emplace_back(block->Copy()); + } + } + } + + scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); +} + +PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, bool flush_p) + : PayloadScanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state, flush_p) { +} + +PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush_p) { + auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; + auto count = sorted_data.data_blocks[block_idx]->count; + auto &layout = sorted_data.layout; + + // Create collections to put the data into so we can use RowDataCollectionScanner + rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + if (flush_p) { + rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); + } else { + rows->blocks.emplace_back(sorted_data.data_blocks[block_idx]->Copy()); + } + rows->count = count; + + heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { + if (flush_p) { + heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); + } else { + heap->blocks.emplace_back(sorted_data.heap_blocks[block_idx]->Copy()); + } + heap->count = count; + } + + scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); +} + +void PayloadScanner::Scan(DataChunk &chunk) { + scanner->Scan(chunk); +} + +int SBIterator::ComparisonValue(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return -1; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return 0; + default: + throw InternalException("Unimplemented comparison type for IEJoin!"); + } +} + +static idx_t GetBlockCountWithEmptyCheck(const GlobalSortState &gss) { + D_ASSERT(!gss.sorted_blocks.empty()); + return gss.sorted_blocks[0]->radix_sorting_data.size(); +} + +SBIterator::SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p) + : sort_layout(gss.sort_layout), block_count(GetBlockCountWithEmptyCheck(gss)), block_capacity(gss.block_capacity), + cmp_size(sort_layout.comparison_size), entry_size(sort_layout.entry_size), all_constant(sort_layout.all_constant), + external(gss.external), cmp(ComparisonValue(comparison)), scan(gss.buffer_manager, gss), block_ptr(nullptr), + entry_ptr(nullptr) { + + scan.sb = gss.sorted_blocks[0].get(); + scan.block_idx = block_count; + SetIndex(entry_idx_p); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp new file mode 100644 index 00000000..4f994b6a --- /dev/null +++ b/src/duckdb/src/common/string_util.cpp @@ -0,0 +1,375 @@ +#include "duckdb/common/string_util.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/helper.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace duckdb { + +string StringUtil::GenerateRandomName(idx_t length) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, 15); + + std::stringstream ss; + ss << std::hex; + for (idx_t i = 0; i < length; i++) { + ss << dis(gen); + } + return ss.str(); +} + +bool StringUtil::Contains(const string &haystack, const string &needle) { + return (haystack.find(needle) != string::npos); +} + +void StringUtil::LTrim(string &str) { + auto it = str.begin(); + while (it != str.end() && CharacterIsSpace(*it)) { + it++; + } + str.erase(str.begin(), it); +} + +// Remove trailing ' ', '\f', '\n', '\r', '\t', '\v' +void StringUtil::RTrim(string &str) { + str.erase(find_if(str.rbegin(), str.rend(), [](int ch) { return ch > 0 && !CharacterIsSpace(ch); }).base(), + str.end()); +} + +void StringUtil::RTrim(string &str, const string &chars_to_trim) { + str.erase(find_if(str.rbegin(), str.rend(), + [&chars_to_trim](int ch) { return ch > 0 && chars_to_trim.find(ch) == string::npos; }) + .base(), + str.end()); +} + +void StringUtil::Trim(string &str) { + StringUtil::LTrim(str); + StringUtil::RTrim(str); +} + +bool StringUtil::StartsWith(string str, string prefix) { + if (prefix.size() > str.size()) { + return false; + } + return equal(prefix.begin(), prefix.end(), str.begin()); +} + +bool StringUtil::EndsWith(const string &str, const string &suffix) { + if (suffix.size() > str.size()) { + return false; + } + return equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + +string StringUtil::Repeat(const string &str, idx_t n) { + std::ostringstream os; + for (idx_t i = 0; i < n; i++) { + os << str; + } + return (os.str()); +} + +vector StringUtil::Split(const string &str, char delimiter) { + std::stringstream ss(str); + vector lines; + string temp; + while (getline(ss, temp, delimiter)) { + lines.push_back(temp); + } + return (lines); +} + +namespace string_util_internal { + +inline void SkipSpaces(const string &str, idx_t &index) { + while (index < str.size() && std::isspace(str[index])) { + index++; + } +} + +inline void ConsumeLetter(const string &str, idx_t &index, char expected) { + if (index >= str.size() || str[index] != expected) { + throw ParserException("Invalid quoted list: %s", str); + } + + index++; +} + +template +inline void TakeWhile(const string &str, idx_t &index, const F &cond, string &taker) { + while (index < str.size() && cond(str[index])) { + taker.push_back(str[index]); + index++; + } +} + +inline string TakePossiblyQuotedItem(const string &str, idx_t &index, char delimiter, char quote) { + string entry; + + if (str[index] == quote) { + index++; + TakeWhile( + str, index, [quote](char c) { return c != quote; }, entry); + ConsumeLetter(str, index, quote); + } else { + TakeWhile( + str, index, [delimiter, quote](char c) { return c != delimiter && c != quote && !std::isspace(c); }, entry); + } + + return entry; +} + +} // namespace string_util_internal + +vector StringUtil::SplitWithQuote(const string &str, char delimiter, char quote) { + vector entries; + idx_t i = 0; + + string_util_internal::SkipSpaces(str, i); + while (i < str.size()) { + if (!entries.empty()) { + string_util_internal::ConsumeLetter(str, i, delimiter); + } + + entries.emplace_back(string_util_internal::TakePossiblyQuotedItem(str, i, delimiter, quote)); + string_util_internal::SkipSpaces(str, i); + } + + return entries; +} + +string StringUtil::Join(const vector &input, const string &separator) { + return StringUtil::Join(input, input.size(), separator, [](const string &s) { return s; }); +} + +string StringUtil::BytesToHumanReadableString(idx_t bytes) { + string db_size; + auto kilobytes = bytes / 1000; + auto megabytes = kilobytes / 1000; + kilobytes -= megabytes * 1000; + auto gigabytes = megabytes / 1000; + megabytes -= gigabytes * 1000; + auto terabytes = gigabytes / 1000; + gigabytes -= terabytes * 1000; + auto petabytes = terabytes / 1000; + terabytes -= petabytes * 1000; + if (petabytes > 0) { + return to_string(petabytes) + "." + to_string(terabytes / 100) + "PB"; + } + if (terabytes > 0) { + return to_string(terabytes) + "." + to_string(gigabytes / 100) + "TB"; + } else if (gigabytes > 0) { + return to_string(gigabytes) + "." + to_string(megabytes / 100) + "GB"; + } else if (megabytes > 0) { + return to_string(megabytes) + "." + to_string(kilobytes / 100) + "MB"; + } else if (kilobytes > 0) { + return to_string(kilobytes) + "KB"; + } else { + return to_string(bytes) + (bytes == 1 ? " byte" : " bytes"); + } +} + +string StringUtil::Upper(const string &str) { + string copy(str); + transform(copy.begin(), copy.end(), copy.begin(), [](unsigned char c) { return std::toupper(c); }); + return (copy); +} + +string StringUtil::Lower(const string &str) { + string copy(str); + transform(copy.begin(), copy.end(), copy.begin(), [](unsigned char c) { return StringUtil::CharacterToLower(c); }); + return (copy); +} + +bool StringUtil::IsLower(const string &str) { + return str == Lower(str); +} + +// Jenkins hash function: https://en.wikipedia.org/wiki/Jenkins_hash_function +uint64_t StringUtil::CIHash(const string &str) { + uint32_t hash = 0; + for (auto c : str) { + hash += StringUtil::CharacterToLower(c); + hash += hash << 10; + hash ^= hash >> 6; + } + hash += hash << 3; + hash ^= hash >> 11; + hash += hash << 15; + return hash; +} + +bool StringUtil::CIEquals(const string &l1, const string &l2) { + if (l1.size() != l2.size()) { + return false; + } + for (idx_t c = 0; c < l1.size(); c++) { + if (StringUtil::CharacterToLower(l1[c]) != StringUtil::CharacterToLower(l2[c])) { + return false; + } + } + return true; +} + +vector StringUtil::Split(const string &input, const string &split) { + vector splits; + + idx_t last = 0; + idx_t input_len = input.size(); + idx_t split_len = split.size(); + while (last <= input_len) { + idx_t next = input.find(split, last); + if (next == string::npos) { + next = input_len; + } + + // Push the substring [last, next) on to splits + string substr = input.substr(last, next - last); + if (!substr.empty()) { + splits.push_back(substr); + } + last = next + split_len; + } + if (splits.empty()) { + splits.push_back(input); + } + return splits; +} + +string StringUtil::Replace(string source, const string &from, const string &to) { + if (from.empty()) { + throw InternalException("Invalid argument to StringUtil::Replace - empty FROM"); + } + idx_t start_pos = 0; + while ((start_pos = source.find(from, start_pos)) != string::npos) { + source.replace(start_pos, from.length(), to); + start_pos += to.length(); // In case 'to' contains 'from', like + // replacing 'x' with 'yx' + } + return source; +} + +vector StringUtil::TopNStrings(vector> scores, idx_t n, idx_t threshold) { + if (scores.empty()) { + return vector(); + } + sort(scores.begin(), scores.end(), [](const pair &a, const pair &b) -> bool { + return a.second < b.second || (a.second == b.second && a.first.size() < b.first.size()); + }); + vector result; + result.push_back(scores[0].first); + for (idx_t i = 1; i < MinValue(scores.size(), n); i++) { + if (scores[i].second > threshold) { + break; + } + result.push_back(scores[i].first); + } + return result; +} + +struct LevenshteinArray { + LevenshteinArray(idx_t len1, idx_t len2) : len1(len1) { + dist = make_unsafe_uniq_array(len1 * len2); + } + + idx_t &Score(idx_t i, idx_t j) { + return dist[GetIndex(i, j)]; + } + +private: + idx_t len1; + unsafe_unique_array dist; + + idx_t GetIndex(idx_t i, idx_t j) { + return j * len1 + i; + } +}; + +// adapted from https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#C++ +idx_t StringUtil::LevenshteinDistance(const string &s1_p, const string &s2_p, idx_t not_equal_penalty) { + auto s1 = StringUtil::Lower(s1_p); + auto s2 = StringUtil::Lower(s2_p); + idx_t len1 = s1.size(); + idx_t len2 = s2.size(); + if (len1 == 0) { + return len2; + } + if (len2 == 0) { + return len1; + } + LevenshteinArray array(len1 + 1, len2 + 1); + array.Score(0, 0) = 0; + for (idx_t i = 0; i <= len1; i++) { + array.Score(i, 0) = i; + } + for (idx_t j = 0; j <= len2; j++) { + array.Score(0, j) = j; + } + for (idx_t i = 1; i <= len1; i++) { + for (idx_t j = 1; j <= len2; j++) { + // d[i][j] = std::min({ d[i - 1][j] + 1, + // d[i][j - 1] + 1, + // d[i - 1][j - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1) }); + int equal = s1[i - 1] == s2[j - 1] ? 0 : not_equal_penalty; + idx_t adjacent_score1 = array.Score(i - 1, j) + 1; + idx_t adjacent_score2 = array.Score(i, j - 1) + 1; + idx_t adjacent_score3 = array.Score(i - 1, j - 1) + equal; + + idx_t t = MinValue(adjacent_score1, adjacent_score2); + array.Score(i, j) = MinValue(t, adjacent_score3); + } + } + return array.Score(len1, len2); +} + +idx_t StringUtil::SimilarityScore(const string &s1, const string &s2) { + return LevenshteinDistance(s1, s2, 3); +} + +vector StringUtil::TopNLevenshtein(const vector &strings, const string &target, idx_t n, + idx_t threshold) { + vector> scores; + scores.reserve(strings.size()); + for (auto &str : strings) { + if (target.size() < str.size()) { + scores.emplace_back(str, SimilarityScore(str.substr(0, target.size()), target)); + } else { + scores.emplace_back(str, SimilarityScore(str, target)); + } + } + return TopNStrings(scores, n, threshold); +} + +string StringUtil::CandidatesMessage(const vector &candidates, const string &candidate) { + string result_str; + if (!candidates.empty()) { + result_str = "\n" + candidate + ": "; + for (idx_t i = 0; i < candidates.size(); i++) { + if (i > 0) { + result_str += ", "; + } + result_str += "\"" + candidates[i] + "\""; + } + } + return result_str; +} + +string StringUtil::CandidatesErrorMessage(const vector &strings, const string &target, + const string &message_prefix, idx_t n) { + auto closest_strings = StringUtil::TopNLevenshtein(strings, target, n); + return StringUtil::CandidatesMessage(closest_strings, message_prefix); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer.cpp b/src/duckdb/src/common/tree_renderer.cpp new file mode 100644 index 00000000..1ce04df6 --- /dev/null +++ b/src/duckdb/src/common/tree_renderer.cpp @@ -0,0 +1,552 @@ +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "utf8proc_wrapper.hpp" + +#include + +namespace duckdb { + +RenderTree::RenderTree(idx_t width_p, idx_t height_p) : width(width_p), height(height_p) { + nodes = unique_ptr[]>(new unique_ptr[(width + 1) * (height + 1)]); +} + +RenderTreeNode *RenderTree::GetNode(idx_t x, idx_t y) { + if (x >= width || y >= height) { + return nullptr; + } + return nodes[GetPosition(x, y)].get(); +} + +bool RenderTree::HasNode(idx_t x, idx_t y) { + if (x >= width || y >= height) { + return false; + } + return nodes[GetPosition(x, y)].get() != nullptr; +} + +idx_t RenderTree::GetPosition(idx_t x, idx_t y) { + return y * width + x; +} + +void RenderTree::SetNode(idx_t x, idx_t y, unique_ptr node) { + nodes[GetPosition(x, y)] = std::move(node); +} + +void TreeRenderer::RenderTopLayer(RenderTree &root, std::ostream &ss, idx_t y) { + for (idx_t x = 0; x < root.width; x++) { + if (x * config.NODE_RENDER_WIDTH >= config.MAXIMUM_RENDER_WIDTH) { + break; + } + if (root.HasNode(x, y)) { + ss << config.LTCORNER; + ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); + if (y == 0) { + // top level node: no node above this one + ss << config.HORIZONTAL; + } else { + // render connection to node above this one + ss << config.DMIDDLE; + } + ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); + ss << config.RTCORNER; + } else { + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); + } + } + ss << std::endl; +} + +void TreeRenderer::RenderBottomLayer(RenderTree &root, std::ostream &ss, idx_t y) { + for (idx_t x = 0; x <= root.width; x++) { + if (x * config.NODE_RENDER_WIDTH >= config.MAXIMUM_RENDER_WIDTH) { + break; + } + if (root.HasNode(x, y)) { + ss << config.LDCORNER; + ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); + if (root.HasNode(x, y + 1)) { + // node below this one: connect to that one + ss << config.TMIDDLE; + } else { + // no node below this one: end the box + ss << config.HORIZONTAL; + } + ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2 - 1); + ss << config.RDCORNER; + } else if (root.HasNode(x, y + 1)) { + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); + ss << config.VERTICAL; + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); + } else { + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); + } + } + ss << std::endl; +} + +string AdjustTextForRendering(string source, idx_t max_render_width) { + idx_t cpos = 0; + idx_t render_width = 0; + vector> render_widths; + while (cpos < source.size()) { + idx_t char_render_width = Utf8Proc::RenderWidth(source.c_str(), source.size(), cpos); + cpos = Utf8Proc::NextGraphemeCluster(source.c_str(), source.size(), cpos); + render_width += char_render_width; + render_widths.emplace_back(cpos, render_width); + if (render_width > max_render_width) { + break; + } + } + if (render_width > max_render_width) { + // need to find a position to truncate + for (idx_t pos = render_widths.size(); pos > 0; pos--) { + if (render_widths[pos - 1].second < max_render_width - 4) { + return source.substr(0, render_widths[pos - 1].first) + "..." + + string(max_render_width - render_widths[pos - 1].second - 3, ' '); + } + } + source = "..."; + } + // need to pad with spaces + idx_t total_spaces = max_render_width - render_width; + idx_t half_spaces = total_spaces / 2; + idx_t extra_left_space = total_spaces % 2 == 0 ? 0 : 1; + return string(half_spaces + extra_left_space, ' ') + source + string(half_spaces, ' '); +} + +static bool NodeHasMultipleChildren(RenderTree &root, idx_t x, idx_t y) { + for (; x < root.width && !root.HasNode(x + 1, y); x++) { + if (root.HasNode(x + 1, y + 1)) { + return true; + } + } + return false; +} + +void TreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_t y) { + // we first need to figure out how high our boxes are going to be + vector> extra_info; + idx_t extra_height = 0; + extra_info.resize(root.width); + for (idx_t x = 0; x < root.width; x++) { + auto node = root.GetNode(x, y); + if (node) { + SplitUpExtraInfo(node->extra_text, extra_info[x]); + if (extra_info[x].size() > extra_height) { + extra_height = extra_info[x].size(); + } + } + } + extra_height = MinValue(extra_height, config.MAX_EXTRA_LINES); + idx_t halfway_point = (extra_height + 1) / 2; + // now we render the actual node + for (idx_t render_y = 0; render_y <= extra_height; render_y++) { + for (idx_t x = 0; x < root.width; x++) { + if (x * config.NODE_RENDER_WIDTH >= config.MAXIMUM_RENDER_WIDTH) { + break; + } + auto node = root.GetNode(x, y); + if (!node) { + if (render_y == halfway_point) { + bool has_child_to_the_right = NodeHasMultipleChildren(root, x, y); + if (root.HasNode(x, y + 1)) { + // node right below this one + ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2); + ss << config.RTCORNER; + if (has_child_to_the_right) { + // but we have another child to the right! keep rendering the line + ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH / 2); + } else { + // only a child below this one: fill the rest with spaces + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); + } + } else if (has_child_to_the_right) { + // child to the right, but no child right below this one: render a full line + ss << StringUtil::Repeat(config.HORIZONTAL, config.NODE_RENDER_WIDTH); + } else { + // empty spot: render spaces + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); + } + } else if (render_y >= halfway_point) { + if (root.HasNode(x, y + 1)) { + // we have a node below this empty spot: render a vertical line + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); + ss << config.VERTICAL; + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH / 2); + } else { + // empty spot: render spaces + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); + } + } else { + // empty spot: render spaces + ss << StringUtil::Repeat(" ", config.NODE_RENDER_WIDTH); + } + } else { + ss << config.VERTICAL; + // figure out what to render + string render_text; + if (render_y == 0) { + render_text = node->name; + } else { + if (render_y <= extra_info[x].size()) { + render_text = extra_info[x][render_y - 1]; + } + } + render_text = AdjustTextForRendering(render_text, config.NODE_RENDER_WIDTH - 2); + ss << render_text; + + if (render_y == halfway_point && NodeHasMultipleChildren(root, x, y)) { + ss << config.LMIDDLE; + } else { + ss << config.VERTICAL; + } + } + } + ss << std::endl; + } +} + +string TreeRenderer::ToString(const LogicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string TreeRenderer::ToString(const PhysicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string TreeRenderer::ToString(const QueryProfiler::TreeNode &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string TreeRenderer::ToString(const Pipeline &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +void TreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { + auto tree = CreateTree(op); + ToStream(*tree, ss); +} + +void TreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) { + auto tree = CreateTree(op); + ToStream(*tree, ss); +} + +void TreeRenderer::Render(const QueryProfiler::TreeNode &op, std::ostream &ss) { + auto tree = CreateTree(op); + ToStream(*tree, ss); +} + +void TreeRenderer::Render(const Pipeline &op, std::ostream &ss) { + auto tree = CreateTree(op); + ToStream(*tree, ss); +} + +void TreeRenderer::ToStream(RenderTree &root, std::ostream &ss) { + while (root.width * config.NODE_RENDER_WIDTH > config.MAXIMUM_RENDER_WIDTH) { + if (config.NODE_RENDER_WIDTH - 2 < config.MINIMUM_RENDER_WIDTH) { + break; + } + config.NODE_RENDER_WIDTH -= 2; + } + + for (idx_t y = 0; y < root.height; y++) { + // start by rendering the top layer + RenderTopLayer(root, ss, y); + // now we render the content of the boxes + RenderBoxContent(root, ss, y); + // render the bottom layer of each of the boxes + RenderBottomLayer(root, ss, y); + } +} + +bool TreeRenderer::CanSplitOnThisChar(char l) { + return (l < '0' || (l > '9' && l < 'A') || (l > 'Z' && l < 'a')) && l != '_'; +} + +bool TreeRenderer::IsPadding(char l) { + return l == ' ' || l == '\t' || l == '\n' || l == '\r'; +} + +string TreeRenderer::RemovePadding(string l) { + idx_t start = 0, end = l.size(); + while (start < l.size() && IsPadding(l[start])) { + start++; + } + while (end > 0 && IsPadding(l[end - 1])) { + end--; + } + return l.substr(start, end - start); +} + +void TreeRenderer::SplitStringBuffer(const string &source, vector &result) { + D_ASSERT(Utf8Proc::IsValid(source.c_str(), source.size())); + idx_t max_line_render_size = config.NODE_RENDER_WIDTH - 2; + // utf8 in prompt, get render width + idx_t cpos = 0; + idx_t start_pos = 0; + idx_t render_width = 0; + idx_t last_possible_split = 0; + while (cpos < source.size()) { + // check if we can split on this character + if (CanSplitOnThisChar(source[cpos])) { + last_possible_split = cpos; + } + size_t char_render_width = Utf8Proc::RenderWidth(source.c_str(), source.size(), cpos); + idx_t next_cpos = Utf8Proc::NextGraphemeCluster(source.c_str(), source.size(), cpos); + if (render_width + char_render_width > max_line_render_size) { + if (last_possible_split <= start_pos + 8) { + last_possible_split = cpos; + } + result.push_back(source.substr(start_pos, last_possible_split - start_pos)); + start_pos = last_possible_split; + cpos = last_possible_split; + render_width = 0; + } + cpos = next_cpos; + render_width += char_render_width; + } + if (source.size() > start_pos) { + result.push_back(source.substr(start_pos, source.size() - start_pos)); + } +} + +void TreeRenderer::SplitUpExtraInfo(const string &extra_info, vector &result) { + if (extra_info.empty()) { + return; + } + if (!Utf8Proc::IsValid(extra_info.c_str(), extra_info.size())) { + return; + } + auto splits = StringUtil::Split(extra_info, "\n"); + if (!splits.empty() && splits[0] != "[INFOSEPARATOR]") { + result.push_back(ExtraInfoSeparator()); + } + for (auto &split : splits) { + if (split == "[INFOSEPARATOR]") { + result.push_back(ExtraInfoSeparator()); + continue; + } + string str = RemovePadding(split); + if (str.empty()) { + continue; + } + SplitStringBuffer(str, result); + } +} + +string TreeRenderer::ExtraInfoSeparator() { + return StringUtil::Repeat(string(config.HORIZONTAL) + " ", (config.NODE_RENDER_WIDTH - 7) / 2); +} + +unique_ptr TreeRenderer::CreateRenderNode(string name, string extra_info) { + auto result = make_uniq(); + result->name = std::move(name); + result->extra_text = std::move(extra_info); + return result; +} + +class TreeChildrenIterator { +public: + template + static bool HasChildren(const T &op) { + return !op.children.empty(); + } + template + static void Iterate(const T &op, const std::function &callback) { + for (auto &child : op.children) { + callback(*child); + } + } +}; + +template <> +bool TreeChildrenIterator::HasChildren(const PhysicalOperator &op) { + switch (op.type) { + case PhysicalOperatorType::DELIM_JOIN: + case PhysicalOperatorType::POSITIONAL_SCAN: + return true; + default: + return !op.children.empty(); + } +} +template <> +void TreeChildrenIterator::Iterate(const PhysicalOperator &op, + const std::function &callback) { + for (auto &child : op.children) { + callback(*child); + } + if (op.type == PhysicalOperatorType::DELIM_JOIN) { + auto &delim = op.Cast(); + callback(*delim.join); + } else if ((op.type == PhysicalOperatorType::POSITIONAL_SCAN)) { + auto &pscan = op.Cast(); + for (auto &table : pscan.child_tables) { + callback(*table); + } + } +} + +struct PipelineRenderNode { + explicit PipelineRenderNode(const PhysicalOperator &op) : op(op) { + } + + const PhysicalOperator &op; + unique_ptr child; +}; + +template <> +bool TreeChildrenIterator::HasChildren(const PipelineRenderNode &op) { + return op.child.get(); +} + +template <> +void TreeChildrenIterator::Iterate(const PipelineRenderNode &op, + const std::function &callback) { + if (op.child) { + callback(*op.child); + } +} + +template +static void GetTreeWidthHeight(const T &op, idx_t &width, idx_t &height) { + if (!TreeChildrenIterator::HasChildren(op)) { + width = 1; + height = 1; + return; + } + width = 0; + height = 0; + + TreeChildrenIterator::Iterate(op, [&](const T &child) { + idx_t child_width, child_height; + GetTreeWidthHeight(child, child_width, child_height); + width += child_width; + height = MaxValue(height, child_height); + }); + height++; +} + +template +idx_t TreeRenderer::CreateRenderTreeRecursive(RenderTree &result, const T &op, idx_t x, idx_t y) { + auto node = TreeRenderer::CreateNode(op); + result.SetNode(x, y, std::move(node)); + + if (!TreeChildrenIterator::HasChildren(op)) { + return 1; + } + idx_t width = 0; + // render the children of this node + TreeChildrenIterator::Iterate( + op, [&](const T &child) { width += CreateRenderTreeRecursive(result, child, x + width, y + 1); }); + return width; +} + +template +unique_ptr TreeRenderer::CreateRenderTree(const T &op) { + idx_t width, height; + GetTreeWidthHeight(op, width, height); + + auto result = make_uniq(width, height); + + // now fill in the tree + CreateRenderTreeRecursive(*result, op, 0, 0); + return result; +} + +unique_ptr TreeRenderer::CreateNode(const LogicalOperator &op) { + return CreateRenderNode(op.GetName(), op.ParamsToString()); +} + +unique_ptr TreeRenderer::CreateNode(const PhysicalOperator &op) { + return CreateRenderNode(op.GetName(), op.ParamsToString()); +} + +unique_ptr TreeRenderer::CreateNode(const PipelineRenderNode &op) { + return CreateNode(op.op); +} + +string TreeRenderer::ExtractExpressionsRecursive(ExpressionInfo &state) { + string result = "\n[INFOSEPARATOR]"; + result += "\n" + state.function_name; + result += "\n" + StringUtil::Format("%.9f", double(state.function_time)); + if (state.children.empty()) { + return result; + } + // render the children of this node + for (auto &child : state.children) { + result += ExtractExpressionsRecursive(*child); + } + return result; +} + +unique_ptr TreeRenderer::CreateNode(const QueryProfiler::TreeNode &op) { + auto result = TreeRenderer::CreateRenderNode(op.name, op.extra_info); + result->extra_text += "\n[INFOSEPARATOR]"; + result->extra_text += "\n" + to_string(op.info.elements); + string timing = StringUtil::Format("%.2f", op.info.time); + result->extra_text += "\n(" + timing + "s)"; + if (config.detailed) { + for (auto &info : op.info.executors_info) { + if (!info) { + continue; + } + for (auto &executor_info : info->roots) { + string sample_count = to_string(executor_info->sample_count); + result->extra_text += "\n[INFOSEPARATOR]"; + result->extra_text += "\nsample_count: " + sample_count; + string sample_tuples_count = to_string(executor_info->sample_tuples_count); + result->extra_text += "\n[INFOSEPARATOR]"; + result->extra_text += "\nsample_tuples_count: " + sample_tuples_count; + string total_count = to_string(executor_info->total_count); + result->extra_text += "\n[INFOSEPARATOR]"; + result->extra_text += "\ntotal_count: " + total_count; + for (auto &state : executor_info->root->children) { + result->extra_text += ExtractExpressionsRecursive(*state); + } + } + } + } + return result; +} + +unique_ptr TreeRenderer::CreateTree(const LogicalOperator &op) { + return CreateRenderTree(op); +} + +unique_ptr TreeRenderer::CreateTree(const PhysicalOperator &op) { + return CreateRenderTree(op); +} + +unique_ptr TreeRenderer::CreateTree(const QueryProfiler::TreeNode &op) { + return CreateRenderTree(op); +} + +unique_ptr TreeRenderer::CreateTree(const Pipeline &op) { + auto operators = op.GetOperators(); + D_ASSERT(!operators.empty()); + unique_ptr node; + for (auto &op : operators) { + auto new_node = make_uniq(op.get()); + new_node->child = std::move(node); + node = std::move(new_node); + } + return CreateRenderTree(*node); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp new file mode 100644 index 00000000..043ac3e7 --- /dev/null +++ b/src/duckdb/src/common/types.cpp @@ -0,0 +1,1094 @@ +#include "duckdb/common/types.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/catalog/default/default_types.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/function/cast_rules.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/common/extra_type_info.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include + +namespace duckdb { + +LogicalType::LogicalType() : LogicalType(LogicalTypeId::INVALID) { +} + +LogicalType::LogicalType(LogicalTypeId id) : id_(id) { + physical_type_ = GetInternalType(); +} +LogicalType::LogicalType(LogicalTypeId id, shared_ptr type_info_p) + : id_(id), type_info_(std::move(type_info_p)) { + physical_type_ = GetInternalType(); +} + +LogicalType::LogicalType(const LogicalType &other) + : id_(other.id_), physical_type_(other.physical_type_), type_info_(other.type_info_) { +} + +LogicalType::LogicalType(LogicalType &&other) noexcept + : id_(other.id_), physical_type_(other.physical_type_), type_info_(std::move(other.type_info_)) { +} + +hash_t LogicalType::Hash() const { + return duckdb::Hash((uint8_t)id_); +} + +PhysicalType LogicalType::GetInternalType() { + switch (id_) { + case LogicalTypeId::BOOLEAN: + return PhysicalType::BOOL; + case LogicalTypeId::TINYINT: + return PhysicalType::INT8; + case LogicalTypeId::UTINYINT: + return PhysicalType::UINT8; + case LogicalTypeId::SMALLINT: + return PhysicalType::INT16; + case LogicalTypeId::USMALLINT: + return PhysicalType::UINT16; + case LogicalTypeId::SQLNULL: + case LogicalTypeId::DATE: + case LogicalTypeId::INTEGER: + return PhysicalType::INT32; + case LogicalTypeId::UINTEGER: + return PhysicalType::UINT32; + case LogicalTypeId::BIGINT: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + return PhysicalType::INT64; + case LogicalTypeId::UBIGINT: + return PhysicalType::UINT64; + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UUID: + return PhysicalType::INT128; + case LogicalTypeId::FLOAT: + return PhysicalType::FLOAT; + case LogicalTypeId::DOUBLE: + return PhysicalType::DOUBLE; + case LogicalTypeId::DECIMAL: { + if (!type_info_) { + return PhysicalType::INVALID; + } + auto width = DecimalType::GetWidth(*this); + if (width <= Decimal::MAX_WIDTH_INT16) { + return PhysicalType::INT16; + } else if (width <= Decimal::MAX_WIDTH_INT32) { + return PhysicalType::INT32; + } else if (width <= Decimal::MAX_WIDTH_INT64) { + return PhysicalType::INT64; + } else if (width <= Decimal::MAX_WIDTH_INT128) { + return PhysicalType::INT128; + } else { + throw InternalException("Decimal has a width of %d which is bigger than the maximum supported width of %d", + width, DecimalType::MaxWidth()); + } + } + case LogicalTypeId::VARCHAR: + case LogicalTypeId::CHAR: + case LogicalTypeId::BLOB: + case LogicalTypeId::BIT: + return PhysicalType::VARCHAR; + case LogicalTypeId::INTERVAL: + return PhysicalType::INTERVAL; + case LogicalTypeId::UNION: + case LogicalTypeId::STRUCT: + return PhysicalType::STRUCT; + case LogicalTypeId::LIST: + case LogicalTypeId::MAP: + return PhysicalType::LIST; + case LogicalTypeId::POINTER: + // LCOV_EXCL_START + if (sizeof(uintptr_t) == sizeof(uint32_t)) { + return PhysicalType::UINT32; + } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { + return PhysicalType::UINT64; + } else { + throw InternalException("Unsupported pointer size"); + } + // LCOV_EXCL_STOP + case LogicalTypeId::VALIDITY: + return PhysicalType::BIT; + case LogicalTypeId::ENUM: { + if (!type_info_) { + return PhysicalType::INVALID; + } + return EnumType::GetPhysicalType(*this); + } + case LogicalTypeId::TABLE: + case LogicalTypeId::LAMBDA: + case LogicalTypeId::ANY: + case LogicalTypeId::INVALID: + case LogicalTypeId::UNKNOWN: + return PhysicalType::INVALID; + case LogicalTypeId::USER: + return PhysicalType::UNKNOWN; + case LogicalTypeId::AGGREGATE_STATE: + return PhysicalType::VARCHAR; + default: + throw InternalException("Invalid LogicalType %s", ToString()); + } +} + +// **DEPRECATED**: Use EnumUtil directly instead. +string LogicalTypeIdToString(LogicalTypeId type) { + return EnumUtil::ToString(type); +} + +constexpr const LogicalTypeId LogicalType::INVALID; +constexpr const LogicalTypeId LogicalType::SQLNULL; +constexpr const LogicalTypeId LogicalType::BOOLEAN; +constexpr const LogicalTypeId LogicalType::TINYINT; +constexpr const LogicalTypeId LogicalType::UTINYINT; +constexpr const LogicalTypeId LogicalType::SMALLINT; +constexpr const LogicalTypeId LogicalType::USMALLINT; +constexpr const LogicalTypeId LogicalType::INTEGER; +constexpr const LogicalTypeId LogicalType::UINTEGER; +constexpr const LogicalTypeId LogicalType::BIGINT; +constexpr const LogicalTypeId LogicalType::UBIGINT; +constexpr const LogicalTypeId LogicalType::HUGEINT; +constexpr const LogicalTypeId LogicalType::UUID; +constexpr const LogicalTypeId LogicalType::FLOAT; +constexpr const LogicalTypeId LogicalType::DOUBLE; +constexpr const LogicalTypeId LogicalType::DATE; + +constexpr const LogicalTypeId LogicalType::TIMESTAMP; +constexpr const LogicalTypeId LogicalType::TIMESTAMP_MS; +constexpr const LogicalTypeId LogicalType::TIMESTAMP_NS; +constexpr const LogicalTypeId LogicalType::TIMESTAMP_S; + +constexpr const LogicalTypeId LogicalType::TIME; + +constexpr const LogicalTypeId LogicalType::TIME_TZ; +constexpr const LogicalTypeId LogicalType::TIMESTAMP_TZ; + +constexpr const LogicalTypeId LogicalType::HASH; +constexpr const LogicalTypeId LogicalType::POINTER; + +constexpr const LogicalTypeId LogicalType::VARCHAR; + +constexpr const LogicalTypeId LogicalType::BLOB; +constexpr const LogicalTypeId LogicalType::BIT; +constexpr const LogicalTypeId LogicalType::INTERVAL; +constexpr const LogicalTypeId LogicalType::ROW_TYPE; + +// TODO these are incomplete and should maybe not exist as such +constexpr const LogicalTypeId LogicalType::TABLE; +constexpr const LogicalTypeId LogicalType::LAMBDA; + +constexpr const LogicalTypeId LogicalType::ANY; + +const vector LogicalType::Numeric() { + vector types = {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, + LogicalType::BIGINT, LogicalType::HUGEINT, LogicalType::FLOAT, + LogicalType::DOUBLE, LogicalTypeId::DECIMAL, LogicalType::UTINYINT, + LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT}; + return types; +} + +const vector LogicalType::Integral() { + vector types = {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, + LogicalType::BIGINT, LogicalType::HUGEINT, LogicalType::UTINYINT, + LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT}; + return types; +} + +const vector LogicalType::AllTypes() { + vector types = { + LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, + LogicalType::BIGINT, LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::DOUBLE, + LogicalType::FLOAT, LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::BIT, + LogicalType::INTERVAL, LogicalType::HUGEINT, LogicalTypeId::DECIMAL, LogicalType::UTINYINT, + LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, LogicalType::TIME, + LogicalTypeId::LIST, LogicalTypeId::STRUCT, LogicalType::TIME_TZ, LogicalType::TIMESTAMP_TZ, + LogicalTypeId::MAP, LogicalTypeId::UNION, LogicalType::UUID}; + return types; +} + +const PhysicalType ROW_TYPE = PhysicalType::INT64; + +// LCOV_EXCL_START +string TypeIdToString(PhysicalType type) { + switch (type) { + case PhysicalType::BOOL: + return "BOOL"; + case PhysicalType::INT8: + return "INT8"; + case PhysicalType::INT16: + return "INT16"; + case PhysicalType::INT32: + return "INT32"; + case PhysicalType::INT64: + return "INT64"; + case PhysicalType::UINT8: + return "UINT8"; + case PhysicalType::UINT16: + return "UINT16"; + case PhysicalType::UINT32: + return "UINT32"; + case PhysicalType::UINT64: + return "UINT64"; + case PhysicalType::INT128: + return "INT128"; + case PhysicalType::FLOAT: + return "FLOAT"; + case PhysicalType::DOUBLE: + return "DOUBLE"; + case PhysicalType::VARCHAR: + return "VARCHAR"; + case PhysicalType::INTERVAL: + return "INTERVAL"; + case PhysicalType::STRUCT: + return "STRUCT"; + case PhysicalType::LIST: + return "LIST"; + case PhysicalType::INVALID: + return "INVALID"; + case PhysicalType::BIT: + return "BIT"; + case PhysicalType::UNKNOWN: + return "UNKNOWN"; + } + return "INVALID"; +} +// LCOV_EXCL_STOP + +idx_t GetTypeIdSize(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + case PhysicalType::BOOL: + return sizeof(bool); + case PhysicalType::INT8: + return sizeof(int8_t); + case PhysicalType::INT16: + return sizeof(int16_t); + case PhysicalType::INT32: + return sizeof(int32_t); + case PhysicalType::INT64: + return sizeof(int64_t); + case PhysicalType::UINT8: + return sizeof(uint8_t); + case PhysicalType::UINT16: + return sizeof(uint16_t); + case PhysicalType::UINT32: + return sizeof(uint32_t); + case PhysicalType::UINT64: + return sizeof(uint64_t); + case PhysicalType::INT128: + return sizeof(hugeint_t); + case PhysicalType::FLOAT: + return sizeof(float); + case PhysicalType::DOUBLE: + return sizeof(double); + case PhysicalType::VARCHAR: + return sizeof(string_t); + case PhysicalType::INTERVAL: + return sizeof(interval_t); + case PhysicalType::STRUCT: + case PhysicalType::UNKNOWN: + return 0; // no own payload + case PhysicalType::LIST: + return sizeof(list_entry_t); // offset + len + default: + throw InternalException("Invalid PhysicalType for GetTypeIdSize"); + } +} + +bool TypeIsConstantSize(PhysicalType type) { + return (type >= PhysicalType::BOOL && type <= PhysicalType::DOUBLE) || type == PhysicalType::INTERVAL || + type == PhysicalType::INT128; +} +bool TypeIsIntegral(PhysicalType type) { + return (type >= PhysicalType::UINT8 && type <= PhysicalType::INT64) || type == PhysicalType::INT128; +} +bool TypeIsNumeric(PhysicalType type) { + return (type >= PhysicalType::UINT8 && type <= PhysicalType::DOUBLE) || type == PhysicalType::INT128; +} +bool TypeIsInteger(PhysicalType type) { + return (type >= PhysicalType::UINT8 && type <= PhysicalType::INT64) || type == PhysicalType::INT128; +} + +string LogicalType::ToString() const { + auto alias = GetAlias(); + if (!alias.empty()) { + return alias; + } + switch (id_) { + case LogicalTypeId::STRUCT: { + if (!type_info_) { + return "STRUCT"; + } + auto &child_types = StructType::GetChildTypes(*this); + string ret = "STRUCT("; + for (size_t i = 0; i < child_types.size(); i++) { + ret += StringUtil::Format("%s %s", SQLIdentifier(child_types[i].first), child_types[i].second); + if (i < child_types.size() - 1) { + ret += ", "; + } + } + ret += ")"; + return ret; + } + case LogicalTypeId::LIST: { + if (!type_info_) { + return "LIST"; + } + return ListType::GetChildType(*this).ToString() + "[]"; + } + case LogicalTypeId::MAP: { + if (!type_info_) { + return "MAP"; + } + auto &key_type = MapType::KeyType(*this); + auto &value_type = MapType::ValueType(*this); + return "MAP(" + key_type.ToString() + ", " + value_type.ToString() + ")"; + } + case LogicalTypeId::UNION: { + if (!type_info_) { + return "UNION"; + } + string ret = "UNION("; + size_t count = UnionType::GetMemberCount(*this); + for (size_t i = 0; i < count; i++) { + ret += UnionType::GetMemberName(*this, i) + " " + UnionType::GetMemberType(*this, i).ToString(); + if (i < count - 1) { + ret += ", "; + } + } + ret += ")"; + return ret; + } + case LogicalTypeId::DECIMAL: { + if (!type_info_) { + return "DECIMAL"; + } + auto width = DecimalType::GetWidth(*this); + auto scale = DecimalType::GetScale(*this); + if (width == 0) { + return "DECIMAL"; + } + return StringUtil::Format("DECIMAL(%d,%d)", width, scale); + } + case LogicalTypeId::ENUM: { + string ret = "ENUM("; + for (idx_t i = 0; i < EnumType::GetSize(*this); i++) { + if (i > 0) { + ret += ", "; + } + ret += KeywordHelper::WriteQuoted(EnumType::GetString(*this, i).GetString(), '\''); + } + ret += ")"; + return ret; + } + case LogicalTypeId::USER: { + return KeywordHelper::WriteOptionallyQuoted(UserType::GetTypeName(*this)); + } + case LogicalTypeId::AGGREGATE_STATE: { + return AggregateStateType::GetTypeName(*this); + } + default: + return EnumUtil::ToString(id_); + } +} +// LCOV_EXCL_STOP + +LogicalTypeId TransformStringToLogicalTypeId(const string &str) { + auto type = DefaultTypeGenerator::GetDefaultType(str); + if (type == LogicalTypeId::INVALID) { + // This is a User Type, at this point we don't know if its one of the User Defined Types or an error + // It is checked in the binder + type = LogicalTypeId::USER; + } + return type; +} + +LogicalType TransformStringToLogicalType(const string &str) { + if (StringUtil::Lower(str) == "null") { + return LogicalType::SQLNULL; + } + return Parser::ParseColumnList("dummy " + str).GetColumn(LogicalIndex(0)).Type(); +} + +LogicalType GetUserTypeRecursive(const LogicalType &type, ClientContext &context) { + if (type.id() == LogicalTypeId::USER && type.HasAlias()) { + return Catalog::GetType(context, INVALID_CATALOG, INVALID_SCHEMA, type.GetAlias()); + } + // Look for LogicalTypeId::USER in nested types + if (type.id() == LogicalTypeId::STRUCT) { + child_list_t children; + children.reserve(StructType::GetChildCount(type)); + for (auto &child : StructType::GetChildTypes(type)) { + children.emplace_back(child.first, GetUserTypeRecursive(child.second, context)); + } + return LogicalType::STRUCT(children); + } + if (type.id() == LogicalTypeId::LIST) { + return LogicalType::LIST(GetUserTypeRecursive(ListType::GetChildType(type), context)); + } + if (type.id() == LogicalTypeId::MAP) { + return LogicalType::MAP(GetUserTypeRecursive(MapType::KeyType(type), context), + GetUserTypeRecursive(MapType::ValueType(type), context)); + } + // Not LogicalTypeId::USER or a nested type + return type; +} + +LogicalType TransformStringToLogicalType(const string &str, ClientContext &context) { + return GetUserTypeRecursive(TransformStringToLogicalType(str), context); +} + +bool LogicalType::IsIntegral() const { + switch (id_) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::HUGEINT: + return true; + default: + return false; + } +} + +bool LogicalType::IsNumeric() const { + switch (id_) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + return true; + default: + return false; + } +} + +bool LogicalType::IsValid() const { + return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN; +} + +bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const { + switch (id_) { + case LogicalTypeId::SQLNULL: + width = 0; + scale = 0; + break; + case LogicalTypeId::BOOLEAN: + width = 1; + scale = 0; + break; + case LogicalTypeId::TINYINT: + // tinyint: [-127, 127] = DECIMAL(3,0) + width = 3; + scale = 0; + break; + case LogicalTypeId::SMALLINT: + // smallint: [-32767, 32767] = DECIMAL(5,0) + width = 5; + scale = 0; + break; + case LogicalTypeId::INTEGER: + // integer: [-2147483647, 2147483647] = DECIMAL(10,0) + width = 10; + scale = 0; + break; + case LogicalTypeId::BIGINT: + // bigint: [-9223372036854775807, 9223372036854775807] = DECIMAL(19,0) + width = 19; + scale = 0; + break; + case LogicalTypeId::UTINYINT: + // UInt8 — [0 : 255] + width = 3; + scale = 0; + break; + case LogicalTypeId::USMALLINT: + // UInt16 — [0 : 65535] + width = 5; + scale = 0; + break; + case LogicalTypeId::UINTEGER: + // UInt32 — [0 : 4294967295] + width = 10; + scale = 0; + break; + case LogicalTypeId::UBIGINT: + // UInt64 — [0 : 18446744073709551615] + width = 20; + scale = 0; + break; + case LogicalTypeId::HUGEINT: + // hugeint: max size decimal (38, 0) + // note that a hugeint is not guaranteed to fit in this + width = 38; + scale = 0; + break; + case LogicalTypeId::DECIMAL: + width = DecimalType::GetWidth(*this); + scale = DecimalType::GetScale(*this); + break; + default: + // Nonsense values to ensure initialization + width = 255u; + scale = 255u; + // FIXME(carlo): This should be probably a throw, requires checkign the various call-sites + return false; + } + return true; +} + +//! Grows Decimal width/scale when appropriate +static LogicalType DecimalSizeCheck(const LogicalType &left, const LogicalType &right) { + D_ASSERT(left.id() == LogicalTypeId::DECIMAL || right.id() == LogicalTypeId::DECIMAL); + D_ASSERT(left.id() != right.id()); + + //! Make sure the 'right' is the DECIMAL type + if (left.id() == LogicalTypeId::DECIMAL) { + return DecimalSizeCheck(right, left); + } + auto width = DecimalType::GetWidth(right); + auto scale = DecimalType::GetScale(right); + + uint8_t other_width; + uint8_t other_scale; + bool success = left.GetDecimalProperties(other_width, other_scale); + if (!success) { + throw InternalException("Type provided to DecimalSizeCheck was not a numeric type"); + } + D_ASSERT(other_scale == 0); + const auto effective_width = width - scale; + if (other_width > effective_width) { + auto new_width = other_width + scale; + //! Cap the width at max, if an actual value exceeds this, an exception will be thrown later + if (new_width > DecimalType::MaxWidth()) { + new_width = DecimalType::MaxWidth(); + } + return LogicalType::DECIMAL(new_width, scale); + } + return right; +} + +static LogicalType CombineNumericTypes(const LogicalType &left, const LogicalType &right) { + D_ASSERT(left.id() != right.id()); + if (left.id() > right.id()) { + // this method is symmetric + // arrange it so the left type is smaller to limit the number of options we need to check + return CombineNumericTypes(right, left); + } + if (CastRules::ImplicitCast(left, right) >= 0) { + // we can implicitly cast left to right, return right + //! Depending on the type, we might need to grow the `width` of the DECIMAL type + if (right.id() == LogicalTypeId::DECIMAL) { + return DecimalSizeCheck(left, right); + } + return right; + } + if (CastRules::ImplicitCast(right, left) >= 0) { + // we can implicitly cast right to left, return left + //! Depending on the type, we might need to grow the `width` of the DECIMAL type + if (left.id() == LogicalTypeId::DECIMAL) { + return DecimalSizeCheck(right, left); + } + return left; + } + // we can't cast implicitly either way and types are not equal + // this happens when left is signed and right is unsigned + // e.g. INTEGER and UINTEGER + // in this case we need to upcast to make sure the types fit + + if (left.id() == LogicalTypeId::BIGINT || right.id() == LogicalTypeId::UBIGINT) { + return LogicalType::HUGEINT; + } + if (left.id() == LogicalTypeId::INTEGER || right.id() == LogicalTypeId::UINTEGER) { + return LogicalType::BIGINT; + } + if (left.id() == LogicalTypeId::SMALLINT || right.id() == LogicalTypeId::USMALLINT) { + return LogicalType::INTEGER; + } + if (left.id() == LogicalTypeId::TINYINT || right.id() == LogicalTypeId::UTINYINT) { + return LogicalType::SMALLINT; + } + throw InternalException("Cannot combine these numeric types!?"); +} + +LogicalType LogicalType::MaxLogicalType(const LogicalType &left, const LogicalType &right) { + // we always prefer aliased types + if (!left.GetAlias().empty()) { + return left; + } + if (!right.GetAlias().empty()) { + return right; + } + if (left.id() != right.id() && left.IsNumeric() && right.IsNumeric()) { + return CombineNumericTypes(left, right); + } else if (left.id() == LogicalTypeId::UNKNOWN) { + return right; + } else if (right.id() == LogicalTypeId::UNKNOWN) { + return left; + } else if ((right.id() == LogicalTypeId::ENUM || left.id() == LogicalTypeId::ENUM) && right.id() != left.id()) { + // if one is an enum and the other is not, compare strings, not enums + // see https://github.com/duckdb/duckdb/issues/8561 + return LogicalTypeId::VARCHAR; + } else if (left.id() < right.id()) { + return right; + } + if (right.id() < left.id()) { + return left; + } + // Since both left and right are equal we get the left type as our type_id for checks + auto type_id = left.id(); + if (type_id == LogicalTypeId::ENUM) { + // If both types are different ENUMs we do a string comparison. + return left == right ? left : LogicalType::VARCHAR; + } + if (type_id == LogicalTypeId::VARCHAR) { + // varchar: use type that has collation (if any) + if (StringType::GetCollation(right).empty()) { + return left; + } + return right; + } + if (type_id == LogicalTypeId::DECIMAL) { + // unify the width/scale so that the resulting decimal always fits + // "width - scale" gives us the number of digits on the left side of the decimal point + // "scale" gives us the number of digits allowed on the right of the decimal point + // using the max of these of the two types gives us the new decimal size + auto extra_width_left = DecimalType::GetWidth(left) - DecimalType::GetScale(left); + auto extra_width_right = DecimalType::GetWidth(right) - DecimalType::GetScale(right); + auto extra_width = MaxValue(extra_width_left, extra_width_right); + auto scale = MaxValue(DecimalType::GetScale(left), DecimalType::GetScale(right)); + auto width = extra_width + scale; + if (width > DecimalType::MaxWidth()) { + // if the resulting decimal does not fit, we truncate the scale + width = DecimalType::MaxWidth(); + scale = width - extra_width; + } + return LogicalType::DECIMAL(width, scale); + } + if (type_id == LogicalTypeId::LIST) { + // list: perform max recursively on child type + auto new_child = MaxLogicalType(ListType::GetChildType(left), ListType::GetChildType(right)); + return LogicalType::LIST(new_child); + } + if (type_id == LogicalTypeId::MAP) { + // list: perform max recursively on child type + auto new_child = MaxLogicalType(ListType::GetChildType(left), ListType::GetChildType(right)); + return LogicalType::MAP(new_child); + } + if (type_id == LogicalTypeId::STRUCT) { + // struct: perform recursively + auto &left_child_types = StructType::GetChildTypes(left); + auto &right_child_types = StructType::GetChildTypes(right); + if (left_child_types.size() != right_child_types.size()) { + // child types are not of equal size, we can't cast anyway + // just return the left child + return left; + } + child_list_t child_types; + for (idx_t i = 0; i < left_child_types.size(); i++) { + auto child_type = MaxLogicalType(left_child_types[i].second, right_child_types[i].second); + child_types.emplace_back(left_child_types[i].first, std::move(child_type)); + } + + return LogicalType::STRUCT(child_types); + } + if (type_id == LogicalTypeId::UNION) { + auto left_member_count = UnionType::GetMemberCount(left); + auto right_member_count = UnionType::GetMemberCount(right); + if (left_member_count != right_member_count) { + // return the "larger" type, with the most members + return left_member_count > right_member_count ? left : right; + } + // otherwise, keep left, don't try to meld the two together. + return left; + } + // types are equal but no extra specifier: just return the type + return left; +} + +void LogicalType::Verify() const { +#ifdef DEBUG + if (id_ == LogicalTypeId::DECIMAL) { + D_ASSERT(DecimalType::GetWidth(*this) >= 1 && DecimalType::GetWidth(*this) <= Decimal::MAX_WIDTH_DECIMAL); + D_ASSERT(DecimalType::GetScale(*this) >= 0 && DecimalType::GetScale(*this) <= DecimalType::GetWidth(*this)); + } +#endif +} + +bool ApproxEqual(float ldecimal, float rdecimal) { + if (Value::IsNan(ldecimal) && Value::IsNan(rdecimal)) { + return true; + } + if (!Value::FloatIsFinite(ldecimal) || !Value::FloatIsFinite(rdecimal)) { + return ldecimal == rdecimal; + } + float epsilon = std::fabs(rdecimal) * 0.01 + 0.00000001; + return std::fabs(ldecimal - rdecimal) <= epsilon; +} + +bool ApproxEqual(double ldecimal, double rdecimal) { + if (Value::IsNan(ldecimal) && Value::IsNan(rdecimal)) { + return true; + } + if (!Value::DoubleIsFinite(ldecimal) || !Value::DoubleIsFinite(rdecimal)) { + return ldecimal == rdecimal; + } + double epsilon = std::fabs(rdecimal) * 0.01 + 0.00000001; + return std::fabs(ldecimal - rdecimal) <= epsilon; +} + +//===--------------------------------------------------------------------===// +// Extra Type Info +//===--------------------------------------------------------------------===// +void LogicalType::SetAlias(string alias) { + if (!type_info_) { + type_info_ = make_shared(ExtraTypeInfoType::GENERIC_TYPE_INFO, std::move(alias)); + } else { + type_info_->alias = std::move(alias); + } +} + +string LogicalType::GetAlias() const { + if (id() == LogicalTypeId::USER) { + return UserType::GetTypeName(*this); + } + if (type_info_) { + return type_info_->alias; + } + return string(); +} + +bool LogicalType::HasAlias() const { + if (id() == LogicalTypeId::USER) { + return !UserType::GetTypeName(*this).empty(); + } + if (type_info_ && !type_info_->alias.empty()) { + return true; + } + return false; +} + +//===--------------------------------------------------------------------===// +// Decimal Type +//===--------------------------------------------------------------------===// +uint8_t DecimalType::GetWidth(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::DECIMAL); + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().width; +} + +uint8_t DecimalType::GetScale(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::DECIMAL); + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().scale; +} + +uint8_t DecimalType::MaxWidth() { + return DecimalWidth::max; +} + +LogicalType LogicalType::DECIMAL(int width, int scale) { + D_ASSERT(width >= scale); + auto type_info = make_shared(width, scale); + return LogicalType(LogicalTypeId::DECIMAL, std::move(type_info)); +} + +//===--------------------------------------------------------------------===// +// String Type +//===--------------------------------------------------------------------===// +string StringType::GetCollation(const LogicalType &type) { + if (type.id() != LogicalTypeId::VARCHAR) { + return string(); + } + auto info = type.AuxInfo(); + if (!info) { + return string(); + } + if (info->type == ExtraTypeInfoType::GENERIC_TYPE_INFO) { + return string(); + } + return info->Cast().collation; +} + +LogicalType LogicalType::VARCHAR_COLLATION(string collation) { // NOLINT + auto string_info = make_shared(std::move(collation)); + return LogicalType(LogicalTypeId::VARCHAR, std::move(string_info)); +} + +//===--------------------------------------------------------------------===// +// List Type +//===--------------------------------------------------------------------===// +const LogicalType &ListType::GetChildType(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::MAP); + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().child_type; +} + +LogicalType LogicalType::LIST(const LogicalType &child) { + auto info = make_shared(child); + return LogicalType(LogicalTypeId::LIST, std::move(info)); +} + +//===--------------------------------------------------------------------===// +// Aggregate State Type +//===--------------------------------------------------------------------===// +const aggregate_state_t &AggregateStateType::GetStateType(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::AGGREGATE_STATE); + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().state_type; +} + +const string AggregateStateType::GetTypeName(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::AGGREGATE_STATE); + auto info = type.AuxInfo(); + if (!info) { + return "AGGREGATE_STATE"; + } + auto aggr_state = info->Cast().state_type; + return "AGGREGATE_STATE<" + aggr_state.function_name + "(" + + StringUtil::Join(aggr_state.bound_argument_types, aggr_state.bound_argument_types.size(), ", ", + [](const LogicalType &arg_type) { return arg_type.ToString(); }) + + ")" + "::" + aggr_state.return_type.ToString() + ">"; +} + +//===--------------------------------------------------------------------===// +// Struct Type +//===--------------------------------------------------------------------===// +const child_list_t &StructType::GetChildTypes(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION); + + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().child_types; +} + +const LogicalType &StructType::GetChildType(const LogicalType &type, idx_t index) { + auto &child_types = StructType::GetChildTypes(type); + D_ASSERT(index < child_types.size()); + return child_types[index].second; +} + +const string &StructType::GetChildName(const LogicalType &type, idx_t index) { + auto &child_types = StructType::GetChildTypes(type); + D_ASSERT(index < child_types.size()); + return child_types[index].first; +} + +idx_t StructType::GetChildCount(const LogicalType &type) { + return StructType::GetChildTypes(type).size(); +} +bool StructType::IsUnnamed(const LogicalType &type) { + auto &child_types = StructType::GetChildTypes(type); + D_ASSERT(child_types.size() > 0); + return child_types[0].first.empty(); +} + +LogicalType LogicalType::STRUCT(child_list_t children) { + auto info = make_shared(std::move(children)); + return LogicalType(LogicalTypeId::STRUCT, std::move(info)); +} + +LogicalType LogicalType::AGGREGATE_STATE(aggregate_state_t state_type) { // NOLINT + auto info = make_shared(std::move(state_type)); + return LogicalType(LogicalTypeId::AGGREGATE_STATE, std::move(info)); +} + +//===--------------------------------------------------------------------===// +// Map Type +//===--------------------------------------------------------------------===// +LogicalType LogicalType::MAP(const LogicalType &child_p) { + D_ASSERT(child_p.id() == LogicalTypeId::STRUCT); + auto &children = StructType::GetChildTypes(child_p); + D_ASSERT(children.size() == 2); + + // We do this to enforce that for every MAP created, the keys are called "key" + // and the values are called "value" + + // This is done because for Vector the keys of the STRUCT are used in equality checks. + // Vector::Reference will throw if the types don't match + child_list_t new_children(2); + new_children[0] = children[0]; + new_children[0].first = "key"; + + new_children[1] = children[1]; + new_children[1].first = "value"; + + auto child = LogicalType::STRUCT(std::move(new_children)); + auto info = make_shared(child); + return LogicalType(LogicalTypeId::MAP, std::move(info)); +} + +LogicalType LogicalType::MAP(LogicalType key, LogicalType value) { + child_list_t child_types; + child_types.emplace_back("key", std::move(key)); + child_types.emplace_back("value", std::move(value)); + return LogicalType::MAP(LogicalType::STRUCT(child_types)); +} + +const LogicalType &MapType::KeyType(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::MAP); + return StructType::GetChildTypes(ListType::GetChildType(type))[0].second; +} + +const LogicalType &MapType::ValueType(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::MAP); + return StructType::GetChildTypes(ListType::GetChildType(type))[1].second; +} + +//===--------------------------------------------------------------------===// +// Union Type +//===--------------------------------------------------------------------===// +LogicalType LogicalType::UNION(child_list_t members) { + D_ASSERT(!members.empty()); + D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS); + // union types always have a hidden "tag" field in front + members.insert(members.begin(), {"", LogicalType::UTINYINT}); + auto info = make_shared(std::move(members)); + return LogicalType(LogicalTypeId::UNION, std::move(info)); +} + +const LogicalType &UnionType::GetMemberType(const LogicalType &type, idx_t index) { + auto &child_types = StructType::GetChildTypes(type); + D_ASSERT(index < child_types.size()); + // skip the "tag" field + return child_types[index + 1].second; +} + +const string &UnionType::GetMemberName(const LogicalType &type, idx_t index) { + auto &child_types = StructType::GetChildTypes(type); + D_ASSERT(index < child_types.size()); + // skip the "tag" field + return child_types[index + 1].first; +} + +idx_t UnionType::GetMemberCount(const LogicalType &type) { + // don't count the "tag" field + return StructType::GetChildTypes(type).size() - 1; +} +const child_list_t UnionType::CopyMemberTypes(const LogicalType &type) { + auto child_types = StructType::GetChildTypes(type); + child_types.erase(child_types.begin()); + return child_types; +} + +//===--------------------------------------------------------------------===// +// User Type +//===--------------------------------------------------------------------===// +const string &UserType::GetTypeName(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::USER); + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().user_type_name; +} + +LogicalType LogicalType::USER(const string &user_type_name) { + auto info = make_shared(user_type_name); + return LogicalType(LogicalTypeId::USER, std::move(info)); +} + +//===--------------------------------------------------------------------===// +// Enum Type +//===--------------------------------------------------------------------===// +LogicalType LogicalType::ENUM(Vector &ordered_data, idx_t size) { + return EnumTypeInfo::CreateType(ordered_data, size); +} + +LogicalType LogicalType::ENUM(const string &enum_name, Vector &ordered_data, idx_t size) { + return LogicalType::ENUM(ordered_data, size); +} + +const string EnumType::GetValue(const Value &val) { + auto info = val.type().AuxInfo(); + auto &values_insert_order = info->Cast().GetValuesInsertOrder(); + return StringValue::Get(values_insert_order.GetValue(val.GetValue())); +} + +const Vector &EnumType::GetValuesInsertOrder(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::ENUM); + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().GetValuesInsertOrder(); +} + +idx_t EnumType::GetSize(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::ENUM); + auto info = type.AuxInfo(); + D_ASSERT(info); + return info->Cast().GetDictSize(); +} + +PhysicalType EnumType::GetPhysicalType(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::ENUM); + auto aux_info = type.AuxInfo(); + D_ASSERT(aux_info); + auto &info = aux_info->Cast(); + D_ASSERT(info.GetEnumDictType() == EnumDictType::VECTOR_DICT); + return EnumTypeInfo::DictType(info.GetDictSize()); +} + +//===--------------------------------------------------------------------===// +// Logical Type +//===--------------------------------------------------------------------===// + +// the destructor needs to know about the extra type info +LogicalType::~LogicalType() { +} + +bool LogicalType::EqualTypeInfo(const LogicalType &rhs) const { + if (type_info_.get() == rhs.type_info_.get()) { + return true; + } + if (type_info_) { + return type_info_->Equals(rhs.type_info_.get()); + } else { + D_ASSERT(rhs.type_info_); + return rhs.type_info_->Equals(type_info_.get()); + } +} + +bool LogicalType::operator==(const LogicalType &rhs) const { + if (id_ != rhs.id_) { + return false; + } + return EqualTypeInfo(rhs); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/batched_data_collection.cpp b/src/duckdb/src/common/types/batched_data_collection.cpp new file mode 100644 index 00000000..072e0c73 --- /dev/null +++ b/src/duckdb/src/common/types/batched_data_collection.cpp @@ -0,0 +1,109 @@ +#include "duckdb/common/types/batched_data_collection.hpp" + +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, + bool buffer_managed_p) + : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p) { +} + +void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { + D_ASSERT(batch_index != DConstants::INVALID_INDEX); + optional_ptr collection; + if (last_collection.collection && last_collection.batch_index == batch_index) { + // we are inserting into the same collection as before: use it directly + collection = last_collection.collection; + } else { + // new collection: check if there is already an entry + D_ASSERT(data.find(batch_index) == data.end()); + unique_ptr new_collection; + if (last_collection.collection) { + new_collection = make_uniq(*last_collection.collection); + } else if (buffer_managed) { + new_collection = make_uniq(BufferManager::GetBufferManager(context), types); + } else { + new_collection = make_uniq(Allocator::DefaultAllocator(), types); + } + last_collection.collection = new_collection.get(); + last_collection.batch_index = batch_index; + new_collection->InitializeAppend(last_collection.append_state); + collection = new_collection.get(); + data.insert(make_pair(batch_index, std::move(new_collection))); + } + collection->Append(last_collection.append_state, input); +} + +void BatchedDataCollection::Merge(BatchedDataCollection &other) { + for (auto &entry : other.data) { + if (data.find(entry.first) != data.end()) { + throw InternalException( + "BatchedDataCollection::Merge error - batch index %d is present in both collections. This occurs when " + "batch indexes are not uniquely distributed over threads", + entry.first); + } + data[entry.first] = std::move(entry.second); + } + other.data.clear(); +} + +void BatchedDataCollection::InitializeScan(BatchedChunkScanState &state) { + state.iterator = data.begin(); + if (state.iterator == data.end()) { + return; + } + state.iterator->second->InitializeScan(state.scan_state); +} + +void BatchedDataCollection::Scan(BatchedChunkScanState &state, DataChunk &output) { + while (state.iterator != data.end()) { + // check if there is a chunk remaining in this collection + auto collection = state.iterator->second.get(); + collection->Scan(state.scan_state, output); + if (output.size() > 0) { + return; + } + // there isn't! move to the next collection + state.iterator++; + if (state.iterator == data.end()) { + return; + } + state.iterator->second->InitializeScan(state.scan_state); + } +} + +unique_ptr BatchedDataCollection::FetchCollection() { + unique_ptr result; + for (auto &entry : data) { + if (!result) { + result = std::move(entry.second); + } else { + result->Combine(*entry.second); + } + } + data.clear(); + if (!result) { + // empty result + return make_uniq(Allocator::DefaultAllocator(), types); + } + return result; +} + +string BatchedDataCollection::ToString() const { + string result; + result += "Batched Data Collection\n"; + for (auto &entry : data) { + result += "Batch Index - " + to_string(entry.first) + "\n"; + result += entry.second->ToString() + "\n\n"; + } + return result; +} + +void BatchedDataCollection::Print() const { + Printer::Print(ToString()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/bit.cpp b/src/duckdb/src/common/types/bit.cpp new file mode 100644 index 00000000..761ac953 --- /dev/null +++ b/src/duckdb/src/common/types/bit.cpp @@ -0,0 +1,408 @@ +#include "duckdb/common/assert.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/string_type.hpp" + +namespace duckdb { + +// **** helper functions **** +static char ComputePadding(idx_t len) { + return (8 - (len % 8)) % 8; +} + +idx_t Bit::ComputeBitstringLen(idx_t len) { + idx_t result = len / 8; + if (len % 8 != 0) { + result++; + } + // additional first byte to store info on zero padding + result++; + return result; +} + +static inline idx_t GetBitPadding(const string_t &bit_string) { + auto data = const_data_ptr_cast(bit_string.GetData()); + D_ASSERT(idx_t(data[0]) <= 8); + return data[0]; +} + +static inline idx_t GetBitSize(const string_t &str) { + string error_message; + idx_t str_len; + if (!Bit::TryGetBitStringSize(str, str_len, &error_message)) { + throw ConversionException(error_message); + } + return str_len; +} + +uint8_t Bit::GetFirstByte(const string_t &str) { + D_ASSERT(str.GetSize() > 1); + + auto data = const_data_ptr_cast(str.GetData()); + return data[1] & ((1 << (8 - data[0])) - 1); +} + +void Bit::Finalize(string_t &str) { + // bit strings require all padding bits to be set to 1 + // this method sets all padding bits to 1 + auto padding = GetBitPadding(str); + for (idx_t i = 0; i < idx_t(padding); i++) { + Bit::SetBitInternal(str, i, 1); + } + Bit::Verify(str); +} + +void Bit::SetEmptyBitString(string_t &target, string_t &input) { + char *res_buf = target.GetDataWriteable(); + const char *buf = input.GetData(); + memset(res_buf, 0, input.GetSize()); + res_buf[0] = buf[0]; + Bit::Finalize(target); +} + +void Bit::SetEmptyBitString(string_t &target, idx_t len) { + char *res_buf = target.GetDataWriteable(); + memset(res_buf, 0, target.GetSize()); + res_buf[0] = ComputePadding(len); + Bit::Finalize(target); +} + +// **** casting functions **** +void Bit::ToString(string_t bits, char *output) { + auto data = const_data_ptr_cast(bits.GetData()); + auto len = bits.GetSize(); + + idx_t padding = GetBitPadding(bits); + idx_t output_idx = 0; + for (idx_t bit_idx = padding; bit_idx < 8; bit_idx++) { + output[output_idx++] = data[1] & (1 << (7 - bit_idx)) ? '1' : '0'; + } + for (idx_t byte_idx = 2; byte_idx < len; byte_idx++) { + for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { + output[output_idx++] = data[byte_idx] & (1 << (7 - bit_idx)) ? '1' : '0'; + } + } +} + +string Bit::ToString(string_t str) { + auto len = BitLength(str); + auto buffer = make_unsafe_uniq_array(len); + ToString(str, buffer.get()); + return string(buffer.get(), len); +} + +bool Bit::TryGetBitStringSize(string_t str, idx_t &str_len, string *error_message) { + auto data = const_data_ptr_cast(str.GetData()); + auto len = str.GetSize(); + str_len = 0; + for (idx_t i = 0; i < len; i++) { + if (data[i] == '0' || data[i] == '1') { + str_len++; + } else { + string error = StringUtil::Format("Invalid character encountered in string -> bit conversion: '%s'", + string(const_char_ptr_cast(data) + i, 1)); + HandleCastError::AssignError(error, error_message); + return false; + } + } + if (str_len == 0) { + string error = "Cannot cast empty string to BIT"; + HandleCastError::AssignError(error, error_message); + return false; + } + str_len = ComputeBitstringLen(str_len); + return true; +} + +void Bit::ToBit(string_t str, string_t &output_str) { + auto data = const_data_ptr_cast(str.GetData()); + auto len = str.GetSize(); + auto output = output_str.GetDataWriteable(); + + char byte = 0; + idx_t padded_byte = len % 8; + for (idx_t i = 0; i < padded_byte; i++) { + byte <<= 1; + if (data[i] == '1') { + byte |= 1; + } + } + if (padded_byte != 0) { + *(output++) = (8 - padded_byte); // the first byte contains the number of padded zeroes + } + *(output++) = byte; + + for (idx_t byte_idx = padded_byte; byte_idx < len; byte_idx += 8) { + byte = 0; + for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { + byte <<= 1; + if (data[byte_idx + bit_idx] == '1') { + byte |= 1; + } + } + *(output++) = byte; + } + Bit::Finalize(output_str); + Bit::Verify(output_str); +} + +string Bit::ToBit(string_t str) { + auto bit_len = GetBitSize(str); + auto buffer = make_unsafe_uniq_array(bit_len); + string_t output_str(buffer.get(), bit_len); + Bit::ToBit(str, output_str); + return output_str.GetString(); +} + +void Bit::BlobToBit(string_t blob, string_t &output_str) { + auto data = const_data_ptr_cast(blob.GetData()); + auto output = output_str.GetDataWriteable(); + idx_t size = blob.GetSize(); + + *output = 0; // No padding + memcpy(output + 1, data, size); +} + +string Bit::BlobToBit(string_t blob) { + auto buffer = make_unsafe_uniq_array(blob.GetSize() + 1); + string_t output_str(buffer.get(), blob.GetSize() + 1); + Bit::BlobToBit(blob, output_str); + return output_str.GetString(); +} + +void Bit::BitToBlob(string_t bit, string_t &output_blob) { + D_ASSERT(bit.GetSize() == output_blob.GetSize() + 1); + + auto data = const_data_ptr_cast(bit.GetData()); + auto output = output_blob.GetDataWriteable(); + idx_t size = output_blob.GetSize(); + + output[0] = GetFirstByte(bit); + if (size > 2) { + ++output; + // First byte in bitstring contains amount of padded bits, + // second byte in bitstring is the padded byte, + // therefore the rest of the data starts at data + 2 (third byte) + memcpy(output, data + 2, size - 1); + } +} + +string Bit::BitToBlob(string_t bit) { + D_ASSERT(bit.GetSize() > 1); + + auto buffer = make_unsafe_uniq_array(bit.GetSize() - 1); + string_t output_str(buffer.get(), bit.GetSize() - 1); + Bit::BitToBlob(bit, output_str); + return output_str.GetString(); +} + +// **** scalar functions **** +void Bit::BitString(const string_t &input, const idx_t &bit_length, string_t &result) { + char *res_buf = result.GetDataWriteable(); + const char *buf = input.GetData(); + + auto padding = ComputePadding(bit_length); + res_buf[0] = padding; + for (idx_t i = 0; i < bit_length; i++) { + if (i < bit_length - input.GetSize()) { + Bit::SetBit(result, i, 0); + } else { + idx_t bit = buf[i - (bit_length - input.GetSize())] == '1' ? 1 : 0; + Bit::SetBit(result, i, bit); + } + } + Bit::Finalize(result); +} + +idx_t Bit::BitLength(string_t bits) { + return ((bits.GetSize() - 1) * 8) - GetBitPadding(bits); +} + +idx_t Bit::OctetLength(string_t bits) { + return bits.GetSize() - 1; +} + +idx_t Bit::BitCount(string_t bits) { + idx_t count = 0; + const char *buf = bits.GetData(); + for (idx_t byte_idx = 1; byte_idx < OctetLength(bits) + 1; byte_idx++) { + for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { + count += (buf[byte_idx] & (1 << bit_idx)) ? 1 : 0; + } + } + return count - GetBitPadding(bits); +} + +idx_t Bit::BitPosition(string_t substring, string_t bits) { + const char *buf = bits.GetData(); + auto len = bits.GetSize(); + auto substr_len = BitLength(substring); + idx_t substr_idx = 0; + + for (idx_t bit_idx = GetBitPadding(bits); bit_idx < 8; bit_idx++) { + idx_t bit = buf[1] & (1 << (7 - bit_idx)) ? 1 : 0; + if (bit == GetBit(substring, substr_idx)) { + substr_idx++; + if (substr_idx == substr_len) { + return (bit_idx - GetBitPadding(bits)) - substr_len + 2; + } + } else { + substr_idx = 0; + } + } + + for (idx_t byte_idx = 2; byte_idx < len; byte_idx++) { + for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { + idx_t bit = buf[byte_idx] & (1 << (7 - bit_idx)) ? 1 : 0; + if (bit == GetBit(substring, substr_idx)) { + substr_idx++; + if (substr_idx == substr_len) { + return (((byte_idx - 1) * 8) + bit_idx - GetBitPadding(bits)) - substr_len + 2; + } + } else { + substr_idx = 0; + } + } + } + return 0; +} + +idx_t Bit::GetBit(string_t bit_string, idx_t n) { + return Bit::GetBitInternal(bit_string, n + GetBitPadding(bit_string)); +} + +idx_t Bit::GetBitIndex(idx_t n) { + return n / 8 + 1; +} + +idx_t Bit::GetBitInternal(string_t bit_string, idx_t n) { + const char *buf = bit_string.GetData(); + auto idx = Bit::GetBitIndex(n); + D_ASSERT(idx < bit_string.GetSize()); + char byte = buf[idx] >> (7 - (n % 8)); + return (byte & 1 ? 1 : 0); +} + +void Bit::SetBit(string_t &bit_string, idx_t n, idx_t new_value) { + SetBitInternal(bit_string, n + GetBitPadding(bit_string), new_value); +} + +void Bit::SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value) { + char *buf = bit_string.GetDataWriteable(); + + auto idx = Bit::GetBitIndex(n); + D_ASSERT(idx < bit_string.GetSize()); + char shift_byte = 1 << (7 - (n % 8)); + if (new_value == 0) { + shift_byte = ~shift_byte; + buf[idx] &= shift_byte; + } else { + buf[idx] |= shift_byte; + } +} + +// **** BITWISE operators **** +void Bit::RightShift(const string_t &bit_string, const idx_t &shift, string_t &result) { + char *res_buf = result.GetDataWriteable(); + const char *buf = bit_string.GetData(); + res_buf[0] = buf[0]; + for (idx_t i = 0; i < Bit::BitLength(result); i++) { + if (i < shift) { + Bit::SetBit(result, i, 0); + } else { + idx_t bit = Bit::GetBit(bit_string, i - shift); + Bit::SetBit(result, i, bit); + } + } + Bit::Finalize(result); +} + +void Bit::LeftShift(const string_t &bit_string, const idx_t &shift, string_t &result) { + char *res_buf = result.GetDataWriteable(); + const char *buf = bit_string.GetData(); + res_buf[0] = buf[0]; + for (idx_t i = 0; i < Bit::BitLength(bit_string); i++) { + if (i < (Bit::BitLength(bit_string) - shift)) { + idx_t bit = Bit::GetBit(bit_string, shift + i); + Bit::SetBit(result, i, bit); + } else { + Bit::SetBit(result, i, 0); + } + } + Bit::Finalize(result); + Bit::Verify(result); +} + +void Bit::BitwiseAnd(const string_t &rhs, const string_t &lhs, string_t &result) { + if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { + throw InvalidInputException("Cannot AND bit strings of different sizes"); + } + + char *buf = result.GetDataWriteable(); + const char *r_buf = rhs.GetData(); + const char *l_buf = lhs.GetData(); + + buf[0] = l_buf[0]; + for (idx_t i = 1; i < lhs.GetSize(); i++) { + buf[i] = l_buf[i] & r_buf[i]; + } + // and should preserve padding bits + Bit::Verify(result); +} + +void Bit::BitwiseOr(const string_t &rhs, const string_t &lhs, string_t &result) { + if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { + throw InvalidInputException("Cannot OR bit strings of different sizes"); + } + + char *buf = result.GetDataWriteable(); + const char *r_buf = rhs.GetData(); + const char *l_buf = lhs.GetData(); + + buf[0] = l_buf[0]; + for (idx_t i = 1; i < lhs.GetSize(); i++) { + buf[i] = l_buf[i] | r_buf[i]; + } + // or should preserve padding bits + Bit::Verify(result); +} + +void Bit::BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result) { + if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { + throw InvalidInputException("Cannot XOR bit strings of different sizes"); + } + + char *buf = result.GetDataWriteable(); + const char *r_buf = rhs.GetData(); + const char *l_buf = lhs.GetData(); + + buf[0] = l_buf[0]; + for (idx_t i = 1; i < lhs.GetSize(); i++) { + buf[i] = l_buf[i] ^ r_buf[i]; + } + Bit::Finalize(result); +} + +void Bit::BitwiseNot(const string_t &input, string_t &result) { + char *result_buf = result.GetDataWriteable(); + const char *buf = input.GetData(); + + result_buf[0] = buf[0]; + for (idx_t i = 1; i < input.GetSize(); i++) { + result_buf[i] = ~buf[i]; + } + Bit::Finalize(result); +} + +void Bit::Verify(const string_t &input) { +#ifdef DEBUG + // bit strings require all padding bits to be set to 1 + auto padding = GetBitPadding(input); + for (idx_t i = 0; i < padding; i++) { + D_ASSERT(Bit::GetBitInternal(input, i)); + } +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/blob.cpp b/src/duckdb/src/common/types/blob.cpp new file mode 100644 index 00000000..8e5720d1 --- /dev/null +++ b/src/duckdb/src/common/types/blob.cpp @@ -0,0 +1,266 @@ +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/blob.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/operator/cast_operators.hpp" + +namespace duckdb { + +constexpr const char *Blob::HEX_TABLE; +const int Blob::HEX_MAP[256] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + -1, -1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; + +bool IsRegularCharacter(data_t c) { + return c >= 32 && c <= 126 && c != '\\' && c != '\'' && c != '"'; +} + +idx_t Blob::GetStringSize(string_t blob) { + auto data = const_data_ptr_cast(blob.GetData()); + auto len = blob.GetSize(); + idx_t str_len = 0; + for (idx_t i = 0; i < len; i++) { + if (IsRegularCharacter(data[i])) { + // ascii characters are rendered as-is + str_len++; + } else { + // non-ascii characters are rendered as hexadecimal (e.g. \x00) + str_len += 4; + } + } + return str_len; +} + +void Blob::ToString(string_t blob, char *output) { + auto data = const_data_ptr_cast(blob.GetData()); + auto len = blob.GetSize(); + idx_t str_idx = 0; + for (idx_t i = 0; i < len; i++) { + if (IsRegularCharacter(data[i])) { + // ascii characters are rendered as-is + output[str_idx++] = data[i]; + } else { + auto byte_a = data[i] >> 4; + auto byte_b = data[i] & 0x0F; + D_ASSERT(byte_a >= 0 && byte_a < 16); + D_ASSERT(byte_b >= 0 && byte_b < 16); + // non-ascii characters are rendered as hexadecimal (e.g. \x00) + output[str_idx++] = '\\'; + output[str_idx++] = 'x'; + output[str_idx++] = Blob::HEX_TABLE[byte_a]; + output[str_idx++] = Blob::HEX_TABLE[byte_b]; + } + } + D_ASSERT(str_idx == GetStringSize(blob)); +} + +string Blob::ToString(string_t blob) { + auto str_len = GetStringSize(blob); + auto buffer = make_unsafe_uniq_array(str_len); + Blob::ToString(blob, buffer.get()); + return string(buffer.get(), str_len); +} + +bool Blob::TryGetBlobSize(string_t str, idx_t &str_len, string *error_message) { + auto data = const_data_ptr_cast(str.GetData()); + auto len = str.GetSize(); + str_len = 0; + for (idx_t i = 0; i < len; i++) { + if (data[i] == '\\') { + if (i + 3 >= len) { + string error = "Invalid hex escape code encountered in string -> blob conversion: " + "unterminated escape code at end of blob"; + HandleCastError::AssignError(error, error_message); + return false; + } + if (data[i + 1] != 'x' || Blob::HEX_MAP[data[i + 2]] < 0 || Blob::HEX_MAP[data[i + 3]] < 0) { + string error = + StringUtil::Format("Invalid hex escape code encountered in string -> blob conversion: %s", + string(const_char_ptr_cast(data) + i, 4)); + HandleCastError::AssignError(error, error_message); + return false; + } + str_len++; + i += 3; + } else if (data[i] <= 127) { + str_len++; + } else { + string error = "Invalid byte encountered in STRING -> BLOB conversion. All non-ascii characters " + "must be escaped with hex codes (e.g. \\xAA)"; + HandleCastError::AssignError(error, error_message); + return false; + } + } + return true; +} + +idx_t Blob::GetBlobSize(string_t str) { + string error_message; + idx_t str_len; + if (!Blob::TryGetBlobSize(str, str_len, &error_message)) { + throw ConversionException(error_message); + } + return str_len; +} + +void Blob::ToBlob(string_t str, data_ptr_t output) { + auto data = const_data_ptr_cast(str.GetData()); + auto len = str.GetSize(); + idx_t blob_idx = 0; + for (idx_t i = 0; i < len; i++) { + if (data[i] == '\\') { + int byte_a = Blob::HEX_MAP[data[i + 2]]; + int byte_b = Blob::HEX_MAP[data[i + 3]]; + D_ASSERT(i + 3 < len); + D_ASSERT(byte_a >= 0 && byte_b >= 0); + D_ASSERT(data[i + 1] == 'x'); + output[blob_idx++] = (byte_a << 4) + byte_b; + i += 3; + } else if (data[i] <= 127) { + output[blob_idx++] = data_t(data[i]); + } else { + throw ConversionException("Invalid byte encountered in STRING -> BLOB conversion. All non-ascii characters " + "must be escaped with hex codes (e.g. \\xAA)"); + } + } + D_ASSERT(blob_idx == GetBlobSize(str)); +} + +string Blob::ToBlob(string_t str) { + auto blob_len = GetBlobSize(str); + auto buffer = make_unsafe_uniq_array(blob_len); + Blob::ToBlob(str, data_ptr_cast(buffer.get())); + return string(buffer.get(), blob_len); +} + +// base64 functions are adapted from https://gist.github.com/tomykaira/f0fd86b6c73063283afe550bc5d77594 +idx_t Blob::ToBase64Size(string_t blob) { + // every 4 characters in base64 encode 3 bytes, plus (potential) padding at the end + auto input_size = blob.GetSize(); + return ((input_size + 2) / 3) * 4; +} + +void Blob::ToBase64(string_t blob, char *output) { + auto input_data = const_data_ptr_cast(blob.GetData()); + auto input_size = blob.GetSize(); + idx_t out_idx = 0; + idx_t i; + // convert the bulk of the string to base64 + // this happens in steps of 3 bytes -> 4 output bytes + for (i = 0; i + 2 < input_size; i += 3) { + output[out_idx++] = Blob::BASE64_MAP[(input_data[i] >> 2) & 0x3F]; + output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4) | ((input_data[i + 1] & 0xF0) >> 4)]; + output[out_idx++] = Blob::BASE64_MAP[((input_data[i + 1] & 0xF) << 2) | ((input_data[i + 2] & 0xC0) >> 6)]; + output[out_idx++] = Blob::BASE64_MAP[input_data[i + 2] & 0x3F]; + } + + if (i < input_size) { + // there are one or two bytes left over: we have to insert padding + // first write the first 6 bits of the first byte + output[out_idx++] = Blob::BASE64_MAP[(input_data[i] >> 2) & 0x3F]; + // now check the character count + if (i == input_size - 1) { + // single byte left over: convert the remainder of that byte and insert padding + output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4)]; + output[out_idx++] = Blob::BASE64_PADDING; + } else { + // two bytes left over: convert the second byte as well + output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4) | ((input_data[i + 1] & 0xF0) >> 4)]; + output[out_idx++] = Blob::BASE64_MAP[((input_data[i + 1] & 0xF) << 2)]; + } + output[out_idx++] = Blob::BASE64_PADDING; + } +} + +static constexpr int BASE64_DECODING_TABLE[256] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, + -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, -1, -1, -1, -1, -1, -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; + +idx_t Blob::FromBase64Size(string_t str) { + auto input_data = str.GetData(); + auto input_size = str.GetSize(); + if (input_size % 4 != 0) { + // valid base64 needs to always be cleanly divisible by 4 + throw ConversionException("Could not decode string \"%s\" as base64: length must be a multiple of 4", + str.GetString()); + } + if (input_size < 4) { + // empty string + return 0; + } + auto base_size = input_size / 4 * 3; + // check for padding to figure out the length + if (input_data[input_size - 2] == Blob::BASE64_PADDING) { + // two bytes of padding + return base_size - 2; + } + if (input_data[input_size - 1] == Blob::BASE64_PADDING) { + // one byte of padding + return base_size - 1; + } + // no padding + return base_size; +} + +template +uint32_t DecodeBase64Bytes(const string_t &str, const_data_ptr_t input_data, idx_t base_idx) { + int decoded_bytes[4]; + for (idx_t decode_idx = 0; decode_idx < 4; decode_idx++) { + if (ALLOW_PADDING && decode_idx >= 2 && input_data[base_idx + decode_idx] == Blob::BASE64_PADDING) { + // the last two bytes of a base64 string can have padding: in this case we set the byte to 0 + decoded_bytes[decode_idx] = 0; + } else { + decoded_bytes[decode_idx] = BASE64_DECODING_TABLE[input_data[base_idx + decode_idx]]; + } + if (decoded_bytes[decode_idx] < 0) { + throw ConversionException( + "Could not decode string \"%s\" as base64: invalid byte value '%d' at position %d", str.GetString(), + input_data[base_idx + decode_idx], base_idx + decode_idx); + } + } + return (decoded_bytes[0] << 3 * 6) + (decoded_bytes[1] << 2 * 6) + (decoded_bytes[2] << 1 * 6) + + (decoded_bytes[3] << 0 * 6); +} + +void Blob::FromBase64(string_t str, data_ptr_t output, idx_t output_size) { + D_ASSERT(output_size == FromBase64Size(str)); + auto input_data = const_data_ptr_cast(str.GetData()); + auto input_size = str.GetSize(); + if (input_size == 0) { + return; + } + idx_t out_idx = 0; + idx_t i = 0; + for (i = 0; i + 4 < input_size; i += 4) { + auto combined = DecodeBase64Bytes(str, input_data, i); + output[out_idx++] = (combined >> 2 * 8) & 0xFF; + output[out_idx++] = (combined >> 1 * 8) & 0xFF; + output[out_idx++] = (combined >> 0 * 8) & 0xFF; + } + // decode the final four bytes: padding is allowed here + auto combined = DecodeBase64Bytes(str, input_data, i); + output[out_idx++] = (combined >> 2 * 8) & 0xFF; + if (out_idx < output_size) { + output[out_idx++] = (combined >> 1 * 8) & 0xFF; + } + if (out_idx < output_size) { + output[out_idx++] = (combined >> 0 * 8) & 0xFF; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/cast_helpers.cpp b/src/duckdb/src/common/types/cast_helpers.cpp new file mode 100644 index 00000000..5011b674 --- /dev/null +++ b/src/duckdb/src/common/types/cast_helpers.cpp @@ -0,0 +1,110 @@ +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/hugeint.hpp" + +namespace duckdb { + +const int64_t NumericHelper::POWERS_OF_TEN[] {1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000}; + +const double NumericHelper::DOUBLE_POWERS_OF_TEN[] {1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, + 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, + 1e20, 1e21, 1e22, 1e23, 1e24, 1e25, 1e26, 1e27, 1e28, 1e29, + 1e30, 1e31, 1e32, 1e33, 1e34, 1e35, 1e36, 1e37, 1e38, 1e39}; + +template <> +int NumericHelper::UnsignedLength(uint8_t value) { + int length = 1; + length += value >= 10; + length += value >= 100; + return length; +} + +template <> +int NumericHelper::UnsignedLength(uint16_t value) { + int length = 1; + length += value >= 10; + length += value >= 100; + length += value >= 1000; + length += value >= 10000; + return length; +} + +template <> +int NumericHelper::UnsignedLength(uint32_t value) { + if (value >= 10000) { + int length = 5; + length += value >= 100000; + length += value >= 1000000; + length += value >= 10000000; + length += value >= 100000000; + length += value >= 1000000000; + return length; + } else { + int length = 1; + length += value >= 10; + length += value >= 100; + length += value >= 1000; + return length; + } +} + +template <> +int NumericHelper::UnsignedLength(uint64_t value) { + if (value >= 10000000000ULL) { + if (value >= 1000000000000000ULL) { + int length = 16; + length += value >= 10000000000000000ULL; + length += value >= 100000000000000000ULL; + length += value >= 1000000000000000000ULL; + length += value >= 10000000000000000000ULL; + return length; + } else { + int length = 11; + length += value >= 100000000000ULL; + length += value >= 1000000000000ULL; + length += value >= 10000000000000ULL; + length += value >= 100000000000000ULL; + return length; + } + } else { + if (value >= 100000ULL) { + int length = 6; + length += value >= 1000000ULL; + length += value >= 10000000ULL; + length += value >= 100000000ULL; + length += value >= 1000000000ULL; + return length; + } else { + int length = 1; + length += value >= 10ULL; + length += value >= 100ULL; + length += value >= 1000ULL; + length += value >= 10000ULL; + return length; + } + } +} + +template <> +std::string NumericHelper::ToString(hugeint_t value) { + return Hugeint::ToString(value); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/chunk_collection.cpp b/src/duckdb/src/common/types/chunk_collection.cpp new file mode 100644 index 00000000..abd4314e --- /dev/null +++ b/src/duckdb/src/common/types/chunk_collection.cpp @@ -0,0 +1,190 @@ +#include "duckdb/common/types/chunk_collection.hpp" + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include +#include + +namespace duckdb { + +ChunkCollection::ChunkCollection(Allocator &allocator) : allocator(allocator), count(0) { +} + +ChunkCollection::ChunkCollection(ClientContext &context) : ChunkCollection(Allocator::Get(context)) { +} + +void ChunkCollection::Verify() { +#ifdef DEBUG + for (auto &chunk : chunks) { + chunk->Verify(); + } +#endif +} + +void ChunkCollection::Append(ChunkCollection &other) { + for (auto &chunk : other.chunks) { + Append(*chunk); + } +} + +void ChunkCollection::Merge(ChunkCollection &other) { + if (other.count == 0) { + return; + } + if (count == 0) { + chunks = std::move(other.chunks); + types = std::move(other.types); + count = other.count; + return; + } + unique_ptr old_back; + if (!chunks.empty() && chunks.back()->size() != STANDARD_VECTOR_SIZE) { + old_back = std::move(chunks.back()); + chunks.pop_back(); + count -= old_back->size(); + } + for (auto &chunk : other.chunks) { + chunks.push_back(std::move(chunk)); + } + count += other.count; + if (old_back) { + Append(*old_back); + } + Verify(); +} + +void ChunkCollection::Append(DataChunk &new_chunk) { + if (new_chunk.size() == 0) { + return; + } + new_chunk.Verify(); + + // we have to ensure that every chunk in the ChunkCollection is completely + // filled, otherwise our O(1) lookup in GetValue and SetValue does not work + // first fill the latest chunk, if it exists + count += new_chunk.size(); + + idx_t remaining_data = new_chunk.size(); + idx_t offset = 0; + if (chunks.empty()) { + // first chunk + types = new_chunk.GetTypes(); + } else { + // the types of the new chunk should match the types of the previous one + D_ASSERT(types.size() == new_chunk.ColumnCount()); + auto new_types = new_chunk.GetTypes(); + for (idx_t i = 0; i < types.size(); i++) { + if (new_types[i] != types[i]) { + throw TypeMismatchException(new_types[i], types[i], "Type mismatch when combining rows"); + } + if (types[i].InternalType() == PhysicalType::LIST) { + // need to check all the chunks because they can have only-null list entries + for (auto &chunk : chunks) { + auto &chunk_vec = chunk->data[i]; + auto &new_vec = new_chunk.data[i]; + auto &chunk_type = chunk_vec.GetType(); + auto &new_type = new_vec.GetType(); + if (chunk_type != new_type) { + throw TypeMismatchException(chunk_type, new_type, "Type mismatch when combining lists"); + } + } + } + // TODO check structs, too + } + + // first append data to the current chunk + DataChunk &last_chunk = *chunks.back(); + idx_t added_data = MinValue(remaining_data, STANDARD_VECTOR_SIZE - last_chunk.size()); + if (added_data > 0) { + // copy elements to the last chunk + new_chunk.Flatten(); + // have to be careful here: setting the cardinality without calling normalify can cause incorrect partial + // decompression + idx_t old_count = new_chunk.size(); + new_chunk.SetCardinality(added_data); + + last_chunk.Append(new_chunk); + remaining_data -= added_data; + // reset the chunk to the old data + new_chunk.SetCardinality(old_count); + offset = added_data; + } + } + + if (remaining_data > 0) { + // create a new chunk and fill it with the remainder + auto chunk = make_uniq(); + chunk->Initialize(allocator, types); + new_chunk.Copy(*chunk, offset); + chunks.push_back(std::move(chunk)); + } +} + +void ChunkCollection::Append(unique_ptr new_chunk) { + if (types.empty()) { + types = new_chunk->GetTypes(); + } + D_ASSERT(types == new_chunk->GetTypes()); + count += new_chunk->size(); + chunks.push_back(std::move(new_chunk)); +} + +void ChunkCollection::Fuse(ChunkCollection &other) { + if (count == 0) { + chunks.reserve(other.ChunkCount()); + for (idx_t chunk_idx = 0; chunk_idx < other.ChunkCount(); ++chunk_idx) { + auto lhs = make_uniq(); + auto &rhs = other.GetChunk(chunk_idx); + lhs->data.reserve(rhs.data.size()); + for (auto &v : rhs.data) { + lhs->data.emplace_back(v); + } + lhs->SetCardinality(rhs.size()); + chunks.push_back(std::move(lhs)); + } + count = other.Count(); + } else { + D_ASSERT(this->ChunkCount() == other.ChunkCount()); + for (idx_t chunk_idx = 0; chunk_idx < ChunkCount(); ++chunk_idx) { + auto &lhs = this->GetChunk(chunk_idx); + auto &rhs = other.GetChunk(chunk_idx); + D_ASSERT(lhs.size() == rhs.size()); + for (auto &v : rhs.data) { + lhs.data.emplace_back(v); + } + } + } + types.insert(types.end(), other.types.begin(), other.types.end()); +} + +Value ChunkCollection::GetValue(idx_t column, idx_t index) { + return chunks[LocateChunk(index)]->GetValue(column, index % STANDARD_VECTOR_SIZE); +} + +void ChunkCollection::SetValue(idx_t column, idx_t index, const Value &value) { + chunks[LocateChunk(index)]->SetValue(column, index % STANDARD_VECTOR_SIZE, value); +} + +void ChunkCollection::CopyCell(idx_t column, idx_t index, Vector &target, idx_t target_offset) { + auto &chunk = GetChunkForRow(index); + auto &source = chunk.data[column]; + const auto source_offset = index % STANDARD_VECTOR_SIZE; + VectorOperations::Copy(source, target, source_offset + 1, source_offset, target_offset); +} + +string ChunkCollection::ToString() const { + return chunks.empty() ? "ChunkCollection [ 0 ]" + : "ChunkCollection [ " + std::to_string(count) + " ]: \n" + chunks[0]->ToString(); +} + +void ChunkCollection::Print() const { + Printer::Print(ToString()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/column_data_allocator.cpp b/src/duckdb/src/common/types/column/column_data_allocator.cpp new file mode 100644 index 00000000..3e7d1843 --- /dev/null +++ b/src/duckdb/src/common/types/column/column_data_allocator.cpp @@ -0,0 +1,265 @@ +#include "duckdb/common/types/column/column_data_allocator.hpp" + +#include "duckdb/common/types/column/column_data_collection_segment.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +ColumnDataAllocator::ColumnDataAllocator(Allocator &allocator) : type(ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { + alloc.allocator = &allocator; +} + +ColumnDataAllocator::ColumnDataAllocator(BufferManager &buffer_manager) + : type(ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { + alloc.buffer_manager = &buffer_manager; +} + +ColumnDataAllocator::ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type) + : type(allocator_type) { + switch (type) { + case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: + case ColumnDataAllocatorType::HYBRID: + alloc.buffer_manager = &BufferManager::GetBufferManager(context); + break; + case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: + alloc.allocator = &Allocator::Get(context); + break; + default: + throw InternalException("Unrecognized column data allocator type"); + } +} + +ColumnDataAllocator::ColumnDataAllocator(ColumnDataAllocator &other) { + type = other.GetType(); + switch (type) { + case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: + case ColumnDataAllocatorType::HYBRID: + alloc.allocator = other.alloc.allocator; + break; + case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: + alloc.buffer_manager = other.alloc.buffer_manager; + break; + default: + throw InternalException("Unrecognized column data allocator type"); + } +} + +BufferHandle ColumnDataAllocator::Pin(uint32_t block_id) { + D_ASSERT(type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || type == ColumnDataAllocatorType::HYBRID); + shared_ptr handle; + if (shared) { + // we only need to grab the lock when accessing the vector, because vector access is not thread-safe: + // the vector can be resized by another thread while we try to access it + lock_guard guard(lock); + handle = blocks[block_id].handle; + } else { + handle = blocks[block_id].handle; + } + return alloc.buffer_manager->Pin(handle); +} + +BufferHandle ColumnDataAllocator::AllocateBlock(idx_t size) { + D_ASSERT(type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || type == ColumnDataAllocatorType::HYBRID); + auto block_size = MaxValue(size, Storage::BLOCK_SIZE); + BlockMetaData data; + data.size = 0; + data.capacity = block_size; + auto pin = alloc.buffer_manager->Allocate(block_size, false, &data.handle); + blocks.push_back(std::move(data)); + return pin; +} + +void ColumnDataAllocator::AllocateEmptyBlock(idx_t size) { + auto allocation_amount = MaxValue(NextPowerOfTwo(size), 4096); + if (!blocks.empty()) { + idx_t last_capacity = blocks.back().capacity; + auto next_capacity = MinValue(last_capacity * 2, last_capacity + Storage::BLOCK_SIZE); + allocation_amount = MaxValue(next_capacity, allocation_amount); + } + D_ASSERT(type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); + BlockMetaData data; + data.size = 0; + data.capacity = allocation_amount; + data.handle = nullptr; + blocks.push_back(std::move(data)); +} + +void ColumnDataAllocator::AssignPointer(uint32_t &block_id, uint32_t &offset, data_ptr_t pointer) { + auto pointer_value = uintptr_t(pointer); + if (sizeof(uintptr_t) == sizeof(uint32_t)) { + block_id = uint32_t(pointer_value); + } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { + block_id = uint32_t(pointer_value & 0xFFFFFFFF); + offset = uint32_t(pointer_value >> 32); + } else { + throw InternalException("ColumnDataCollection: Architecture not supported!?"); + } +} + +void ColumnDataAllocator::AllocateBuffer(idx_t size, uint32_t &block_id, uint32_t &offset, + ChunkManagementState *chunk_state) { + D_ASSERT(allocated_data.empty()); + if (blocks.empty() || blocks.back().Capacity() < size) { + auto pinned_block = AllocateBlock(size); + if (chunk_state) { + D_ASSERT(!blocks.empty()); + auto new_block_id = blocks.size() - 1; + chunk_state->handles[new_block_id] = std::move(pinned_block); + } + } + auto &block = blocks.back(); + D_ASSERT(size <= block.capacity - block.size); + block_id = blocks.size() - 1; + if (chunk_state && chunk_state->handles.find(block_id) == chunk_state->handles.end()) { + // not guaranteed to be pinned already by this thread (if shared allocator) + chunk_state->handles[block_id] = alloc.buffer_manager->Pin(blocks[block_id].handle); + } + offset = block.size; + block.size += size; +} + +void ColumnDataAllocator::AllocateMemory(idx_t size, uint32_t &block_id, uint32_t &offset, + ChunkManagementState *chunk_state) { + D_ASSERT(blocks.size() == allocated_data.size()); + if (blocks.empty() || blocks.back().Capacity() < size) { + AllocateEmptyBlock(size); + auto &last_block = blocks.back(); + auto allocated = alloc.allocator->Allocate(last_block.capacity); + allocated_data.push_back(std::move(allocated)); + } + auto &block = blocks.back(); + D_ASSERT(size <= block.capacity - block.size); + AssignPointer(block_id, offset, allocated_data.back().get() + block.size); + block.size += size; +} + +void ColumnDataAllocator::AllocateData(idx_t size, uint32_t &block_id, uint32_t &offset, + ChunkManagementState *chunk_state) { + switch (type) { + case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: + case ColumnDataAllocatorType::HYBRID: + if (shared) { + lock_guard guard(lock); + AllocateBuffer(size, block_id, offset, chunk_state); + } else { + AllocateBuffer(size, block_id, offset, chunk_state); + } + break; + case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: + D_ASSERT(!shared); + AllocateMemory(size, block_id, offset, chunk_state); + break; + default: + throw InternalException("Unrecognized allocator type"); + } +} + +void ColumnDataAllocator::Initialize(ColumnDataAllocator &other) { + D_ASSERT(other.HasBlocks()); + blocks.push_back(other.blocks.back()); +} + +data_ptr_t ColumnDataAllocator::GetDataPointer(ChunkManagementState &state, uint32_t block_id, uint32_t offset) { + if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { + // in-memory allocator: construct pointer from block_id and offset + if (sizeof(uintptr_t) == sizeof(uint32_t)) { + uintptr_t pointer_value = uintptr_t(block_id); + return (data_ptr_t)pointer_value; // NOLINT - convert from pointer value back to pointer + } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { + uintptr_t pointer_value = (uintptr_t(offset) << 32) | uintptr_t(block_id); + return (data_ptr_t)pointer_value; // NOLINT - convert from pointer value back to pointer + } else { + throw InternalException("ColumnDataCollection: Architecture not supported!?"); + } + } + D_ASSERT(state.handles.find(block_id) != state.handles.end()); + return state.handles[block_id].Ptr() + offset; +} + +void ColumnDataAllocator::UnswizzlePointers(ChunkManagementState &state, Vector &result, idx_t v_offset, uint16_t count, + uint32_t block_id, uint32_t offset) { + D_ASSERT(result.GetType().InternalType() == PhysicalType::VARCHAR); + lock_guard guard(lock); + + auto &validity = FlatVector::Validity(result); + auto strings = FlatVector::GetData(result); + + // find first non-inlined string + uint32_t i = v_offset; + const uint32_t end = v_offset + count; + for (; i < end; i++) { + if (!validity.RowIsValid(i)) { + continue; + } + if (!strings[i].IsInlined()) { + break; + } + } + // at least one string must be non-inlined, otherwise this function should not be called + D_ASSERT(i < end); + + auto base_ptr = char_ptr_cast(GetDataPointer(state, block_id, offset)); + if (strings[i].GetData() == base_ptr) { + // pointers are still valid + return; + } + + // pointer mismatch! pointers are invalid, set them correctly + for (; i < end; i++) { + if (!validity.RowIsValid(i)) { + continue; + } + if (strings[i].IsInlined()) { + continue; + } + strings[i].SetPointer(base_ptr); + base_ptr += strings[i].GetSize(); + } +} + +void ColumnDataAllocator::DeleteBlock(uint32_t block_id) { + blocks[block_id].handle->SetCanDestroy(true); +} + +Allocator &ColumnDataAllocator::GetAllocator() { + return type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR ? *alloc.allocator + : alloc.buffer_manager->GetBufferAllocator(); +} + +void ColumnDataAllocator::InitializeChunkState(ChunkManagementState &state, ChunkMetaData &chunk) { + if (type != ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR && type != ColumnDataAllocatorType::HYBRID) { + // nothing to pin + return; + } + // release any handles that are no longer required + bool found_handle; + do { + found_handle = false; + for (auto it = state.handles.begin(); it != state.handles.end(); it++) { + if (chunk.block_ids.find(it->first) != chunk.block_ids.end()) { + // still required: do not release + continue; + } + state.handles.erase(it); + found_handle = true; + break; + } + } while (found_handle); + + // grab any handles that are now required + for (auto &block_id : chunk.block_ids) { + if (state.handles.find(block_id) != state.handles.end()) { + // already pinned: don't need to do anything + continue; + } + state.handles[block_id] = Pin(block_id); + } +} + +uint32_t BlockMetaData::Capacity() { + D_ASSERT(size <= capacity); + return capacity - size; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/column_data_collection.cpp b/src/duckdb/src/common/types/column/column_data_collection.cpp new file mode 100644 index 00000000..070620fd --- /dev/null +++ b/src/duckdb/src/common/types/column/column_data_collection.cpp @@ -0,0 +1,1103 @@ +#include "duckdb/common/types/column/column_data_collection.hpp" + +#include "duckdb/common/printer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/column/column_data_collection_segment.hpp" +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +struct ColumnDataMetaData; + +typedef void (*column_data_copy_function_t)(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, + Vector &source, idx_t offset, idx_t copy_count); + +struct ColumnDataCopyFunction { + column_data_copy_function_t function; + vector child_functions; +}; + +struct ColumnDataMetaData { + ColumnDataMetaData(ColumnDataCopyFunction ©_function, ColumnDataCollectionSegment &segment, + ColumnDataAppendState &state, ChunkMetaData &chunk_data, VectorDataIndex vector_data_index) + : copy_function(copy_function), segment(segment), state(state), chunk_data(chunk_data), + vector_data_index(vector_data_index) { + } + ColumnDataMetaData(ColumnDataCopyFunction ©_function, ColumnDataMetaData &parent, + VectorDataIndex vector_data_index) + : copy_function(copy_function), segment(parent.segment), state(parent.state), chunk_data(parent.chunk_data), + vector_data_index(vector_data_index) { + } + + ColumnDataCopyFunction ©_function; + ColumnDataCollectionSegment &segment; + ColumnDataAppendState &state; + ChunkMetaData &chunk_data; + VectorDataIndex vector_data_index; + idx_t child_list_size = DConstants::INVALID_INDEX; + + VectorMetaData &GetVectorMetaData() { + return segment.GetVectorData(vector_data_index); + } +}; + +//! Explicitly initialized without types +ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p) { + types.clear(); + count = 0; + this->finished_append = false; + allocator = make_shared(allocator_p); +} + +ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p, vector types_p) { + Initialize(std::move(types_p)); + allocator = make_shared(allocator_p); +} + +ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p) { + Initialize(std::move(types_p)); + allocator = make_shared(buffer_manager); +} + +ColumnDataCollection::ColumnDataCollection(shared_ptr allocator_p, vector types_p) { + Initialize(std::move(types_p)); + this->allocator = std::move(allocator_p); +} + +ColumnDataCollection::ColumnDataCollection(ClientContext &context, vector types_p, + ColumnDataAllocatorType type) + : ColumnDataCollection(make_shared(context, type), std::move(types_p)) { + D_ASSERT(!types.empty()); +} + +ColumnDataCollection::ColumnDataCollection(ColumnDataCollection &other) + : ColumnDataCollection(other.allocator, other.types) { + other.finished_append = true; + D_ASSERT(!types.empty()); +} + +ColumnDataCollection::~ColumnDataCollection() { +} + +void ColumnDataCollection::Initialize(vector types_p) { + this->types = std::move(types_p); + this->count = 0; + this->finished_append = false; + D_ASSERT(!types.empty()); + copy_functions.reserve(types.size()); + for (auto &type : types) { + copy_functions.push_back(GetCopyFunction(type)); + } +} + +void ColumnDataCollection::CreateSegment() { + segments.emplace_back(make_uniq(allocator, types)); +} + +Allocator &ColumnDataCollection::GetAllocator() const { + return allocator->GetAllocator(); +} + +idx_t ColumnDataCollection::SizeInBytes() const { + idx_t total_size = 0; + for (const auto &segment : segments) { + total_size += segment->SizeInBytes(); + } + return total_size; +} + +//===--------------------------------------------------------------------===// +// ColumnDataRow +//===--------------------------------------------------------------------===// +ColumnDataRow::ColumnDataRow(DataChunk &chunk_p, idx_t row_index, idx_t base_index) + : chunk(chunk_p), row_index(row_index), base_index(base_index) { +} + +Value ColumnDataRow::GetValue(idx_t column_index) const { + D_ASSERT(column_index < chunk.ColumnCount()); + D_ASSERT(row_index < chunk.size()); + return chunk.data[column_index].GetValue(row_index); +} + +idx_t ColumnDataRow::RowIndex() const { + return base_index + row_index; +} + +//===--------------------------------------------------------------------===// +// ColumnDataRowCollection +//===--------------------------------------------------------------------===// +ColumnDataRowCollection::ColumnDataRowCollection(const ColumnDataCollection &collection) { + if (collection.Count() == 0) { + return; + } + // read all the chunks + ColumnDataScanState temp_scan_state; + collection.InitializeScan(temp_scan_state, ColumnDataScanProperties::DISALLOW_ZERO_COPY); + while (true) { + auto chunk = make_uniq(); + collection.InitializeScanChunk(*chunk); + if (!collection.Scan(temp_scan_state, *chunk)) { + break; + } + chunks.push_back(std::move(chunk)); + } + // now create all of the column data rows + rows.reserve(collection.Count()); + idx_t base_row = 0; + for (auto &chunk : chunks) { + for (idx_t row_idx = 0; row_idx < chunk->size(); row_idx++) { + rows.emplace_back(*chunk, row_idx, base_row); + } + base_row += chunk->size(); + } +} + +ColumnDataRow &ColumnDataRowCollection::operator[](idx_t i) { + return rows[i]; +} + +const ColumnDataRow &ColumnDataRowCollection::operator[](idx_t i) const { + return rows[i]; +} + +Value ColumnDataRowCollection::GetValue(idx_t column, idx_t index) const { + return rows[index].GetValue(column); +} + +//===--------------------------------------------------------------------===// +// ColumnDataChunkIterator +//===--------------------------------------------------------------------===// +ColumnDataChunkIterationHelper ColumnDataCollection::Chunks() const { + vector column_ids; + for (idx_t i = 0; i < ColumnCount(); i++) { + column_ids.push_back(i); + } + return Chunks(column_ids); +} + +ColumnDataChunkIterationHelper ColumnDataCollection::Chunks(vector column_ids) const { + return ColumnDataChunkIterationHelper(*this, std::move(column_ids)); +} + +ColumnDataChunkIterationHelper::ColumnDataChunkIterationHelper(const ColumnDataCollection &collection_p, + vector column_ids_p) + : collection(collection_p), column_ids(std::move(column_ids_p)) { +} + +ColumnDataChunkIterationHelper::ColumnDataChunkIterator::ColumnDataChunkIterator( + const ColumnDataCollection *collection_p, vector column_ids_p) + : collection(collection_p), scan_chunk(make_shared()), row_index(0) { + if (!collection) { + return; + } + collection->InitializeScan(scan_state, std::move(column_ids_p)); + collection->InitializeScanChunk(scan_state, *scan_chunk); + collection->Scan(scan_state, *scan_chunk); +} + +void ColumnDataChunkIterationHelper::ColumnDataChunkIterator::Next() { + if (!collection) { + return; + } + if (!collection->Scan(scan_state, *scan_chunk)) { + collection = nullptr; + row_index = 0; + } else { + row_index += scan_chunk->size(); + } +} + +ColumnDataChunkIterationHelper::ColumnDataChunkIterator & +ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator++() { + Next(); + return *this; +} + +bool ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator!=(const ColumnDataChunkIterator &other) const { + return collection != other.collection || row_index != other.row_index; +} + +DataChunk &ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator*() const { + return *scan_chunk; +} + +//===--------------------------------------------------------------------===// +// ColumnDataRowIterator +//===--------------------------------------------------------------------===// +ColumnDataRowIterationHelper ColumnDataCollection::Rows() const { + return ColumnDataRowIterationHelper(*this); +} + +ColumnDataRowIterationHelper::ColumnDataRowIterationHelper(const ColumnDataCollection &collection_p) + : collection(collection_p) { +} + +ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p) + : collection(collection_p), scan_chunk(make_shared()), current_row(*scan_chunk, 0, 0) { + if (!collection) { + return; + } + collection->InitializeScan(scan_state); + collection->InitializeScanChunk(*scan_chunk); + collection->Scan(scan_state, *scan_chunk); +} + +void ColumnDataRowIterationHelper::ColumnDataRowIterator::Next() { + if (!collection) { + return; + } + current_row.row_index++; + if (current_row.row_index >= scan_chunk->size()) { + current_row.base_index += scan_chunk->size(); + current_row.row_index = 0; + if (!collection->Scan(scan_state, *scan_chunk)) { + // exhausted collection: move iterator to nop state + current_row.base_index = 0; + collection = nullptr; + } + } +} + +ColumnDataRowIterationHelper::ColumnDataRowIterator ColumnDataRowIterationHelper::begin() { // NOLINT + return ColumnDataRowIterationHelper::ColumnDataRowIterator(collection.Count() == 0 ? nullptr : &collection); +} +ColumnDataRowIterationHelper::ColumnDataRowIterator ColumnDataRowIterationHelper::end() { // NOLINT + return ColumnDataRowIterationHelper::ColumnDataRowIterator(nullptr); +} + +ColumnDataRowIterationHelper::ColumnDataRowIterator &ColumnDataRowIterationHelper::ColumnDataRowIterator::operator++() { + Next(); + return *this; +} + +bool ColumnDataRowIterationHelper::ColumnDataRowIterator::operator!=(const ColumnDataRowIterator &other) const { + return collection != other.collection || current_row.row_index != other.current_row.row_index || + current_row.base_index != other.current_row.base_index; +} + +const ColumnDataRow &ColumnDataRowIterationHelper::ColumnDataRowIterator::operator*() const { + return current_row; +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +void ColumnDataCollection::InitializeAppend(ColumnDataAppendState &state) { + D_ASSERT(!finished_append); + state.vector_data.resize(types.size()); + if (segments.empty()) { + CreateSegment(); + } + auto &segment = *segments.back(); + if (segment.chunk_data.empty()) { + segment.AllocateNewChunk(); + } + segment.InitializeChunkState(segment.chunk_data.size() - 1, state.current_chunk_state); +} + +void ColumnDataCopyValidity(const UnifiedVectorFormat &source_data, validity_t *target, idx_t source_offset, + idx_t target_offset, idx_t copy_count) { + ValidityMask validity(target); + if (target_offset == 0) { + // first time appending to this vector + // all data here is still uninitialized + // initialize the validity mask to set all to valid + validity.SetAllValid(STANDARD_VECTOR_SIZE); + } + // FIXME: we can do something more optimized here using bitshifts & bitwise ors + if (!source_data.validity.AllValid()) { + for (idx_t i = 0; i < copy_count; i++) { + auto idx = source_data.sel->get_index(source_offset + i); + if (!source_data.validity.RowIsValid(idx)) { + validity.SetInvalid(target_offset + i); + } + } + } +} + +template +struct BaseValueCopy { + static idx_t TypeSize() { + return sizeof(T); + } + + template + static void Assign(ColumnDataMetaData &meta_data, data_ptr_t target, data_ptr_t source, idx_t target_idx, + idx_t source_idx) { + auto result_data = (T *)target; + auto source_data = (T *)source; + result_data[target_idx] = OP::Operation(meta_data, source_data[source_idx]); + } +}; + +template +struct StandardValueCopy : public BaseValueCopy { + static T Operation(ColumnDataMetaData &, T input) { + return input; + } +}; + +struct StringValueCopy : public BaseValueCopy { + static string_t Operation(ColumnDataMetaData &meta_data, string_t input) { + return input.IsInlined() ? input : meta_data.segment.heap->AddBlob(input); + } +}; + +struct ConstListValueCopy : public BaseValueCopy { + using TYPE = list_entry_t; + + static TYPE Operation(ColumnDataMetaData &meta_data, TYPE input) { + input.offset = meta_data.child_list_size; + return input; + } +}; + +struct ListValueCopy : public BaseValueCopy { + using TYPE = list_entry_t; + + static TYPE Operation(ColumnDataMetaData &meta_data, TYPE input) { + input.offset = meta_data.child_list_size; + meta_data.child_list_size += input.length; + return input; + } +}; + +struct StructValueCopy { + static idx_t TypeSize() { + return 0; + } + + template + static void Assign(ColumnDataMetaData &meta_data, data_ptr_t target, data_ptr_t source, idx_t target_idx, + idx_t source_idx) { + } +}; + +template +static void TemplatedColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, + Vector &source, idx_t offset, idx_t count) { + auto &segment = meta_data.segment; + auto &append_state = meta_data.state; + + auto current_index = meta_data.vector_data_index; + idx_t remaining = count; + while (remaining > 0) { + auto ¤t_segment = segment.GetVectorData(current_index); + idx_t append_count = MinValue(STANDARD_VECTOR_SIZE - current_segment.count, remaining); + + auto base_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, current_segment.block_id, + current_segment.offset); + auto validity_data = ColumnDataCollectionSegment::GetValidityPointer(base_ptr, OP::TypeSize()); + + ValidityMask result_validity(validity_data); + if (current_segment.count == 0) { + // first time appending to this vector + // all data here is still uninitialized + // initialize the validity mask to set all to valid + result_validity.SetAllValid(STANDARD_VECTOR_SIZE); + } + for (idx_t i = 0; i < append_count; i++) { + auto source_idx = source_data.sel->get_index(offset + i); + if (source_data.validity.RowIsValid(source_idx)) { + OP::template Assign(meta_data, base_ptr, source_data.data, current_segment.count + i, source_idx); + } else { + result_validity.SetInvalid(current_segment.count + i); + } + } + current_segment.count += append_count; + offset += append_count; + remaining -= append_count; + if (remaining > 0) { + // need to append more, check if we need to allocate a new vector or not + if (!current_segment.next_data.IsValid()) { + segment.AllocateVector(source.GetType(), meta_data.chunk_data, append_state, current_index); + } + D_ASSERT(segment.GetVectorData(current_index).next_data.IsValid()); + current_index = segment.GetVectorData(current_index).next_data; + } + } +} + +template +static void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, + idx_t offset, idx_t copy_count) { + TemplatedColumnDataCopy>(meta_data, source_data, source, offset, copy_count); +} + +template <> +void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, + idx_t offset, idx_t copy_count) { + + const auto &allocator_type = meta_data.segment.allocator->GetType(); + if (allocator_type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR || + allocator_type == ColumnDataAllocatorType::HYBRID) { + // strings cannot be spilled to disk - use StringHeap + TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); + return; + } + D_ASSERT(allocator_type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); + + auto &segment = meta_data.segment; + auto &append_state = meta_data.state; + + VectorDataIndex child_index; + if (meta_data.GetVectorMetaData().child_index.IsValid()) { + // find the last child index + child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index); + auto next_child_index = segment.GetVectorData(child_index).next_data; + while (next_child_index.IsValid()) { + child_index = next_child_index; + next_child_index = segment.GetVectorData(child_index).next_data; + } + } + + auto current_index = meta_data.vector_data_index; + idx_t remaining = copy_count; + while (remaining > 0) { + // how many values fit in the current string vector + idx_t vector_remaining = + MinValue(STANDARD_VECTOR_SIZE - segment.GetVectorData(current_index).count, remaining); + + // 'append_count' is less if we cannot fit that amount of non-inlined strings on one buffer-managed block + idx_t append_count; + idx_t heap_size = 0; + const auto source_entries = UnifiedVectorFormat::GetData(source_data); + for (append_count = 0; append_count < vector_remaining; append_count++) { + auto source_idx = source_data.sel->get_index(offset + append_count); + if (!source_data.validity.RowIsValid(source_idx)) { + continue; + } + const auto &entry = source_entries[source_idx]; + if (entry.IsInlined()) { + continue; + } + if (heap_size + entry.GetSize() > Storage::BLOCK_SIZE) { + break; + } + heap_size += entry.GetSize(); + } + + if (vector_remaining != 0 && append_count == 0) { + // single string is longer than Storage::BLOCK_SIZE + // we allocate one block at a time for long strings + auto source_idx = source_data.sel->get_index(offset + append_count); + D_ASSERT(source_data.validity.RowIsValid(source_idx)); + D_ASSERT(!source_entries[source_idx].IsInlined()); + D_ASSERT(source_entries[source_idx].GetSize() > Storage::BLOCK_SIZE); + heap_size += source_entries[source_idx].GetSize(); + append_count++; + } + + // allocate string heap for the next 'append_count' strings + data_ptr_t heap_ptr = nullptr; + if (heap_size != 0) { + child_index = segment.AllocateStringHeap(heap_size, meta_data.chunk_data, append_state, child_index); + if (!meta_data.GetVectorMetaData().child_index.IsValid()) { + meta_data.GetVectorMetaData().child_index = meta_data.segment.AddChildIndex(child_index); + } + auto &child_segment = segment.GetVectorData(child_index); + heap_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, child_segment.block_id, + child_segment.offset); + } + + auto ¤t_segment = segment.GetVectorData(current_index); + auto base_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, current_segment.block_id, + current_segment.offset); + auto validity_data = ColumnDataCollectionSegment::GetValidityPointer(base_ptr, sizeof(string_t)); + ValidityMask target_validity(validity_data); + if (current_segment.count == 0) { + // first time appending to this vector + // all data here is still uninitialized + // initialize the validity mask to set all to valid + target_validity.SetAllValid(STANDARD_VECTOR_SIZE); + } + + auto target_entries = reinterpret_cast(base_ptr); + for (idx_t i = 0; i < append_count; i++) { + auto source_idx = source_data.sel->get_index(offset + i); + auto target_idx = current_segment.count + i; + if (!source_data.validity.RowIsValid(source_idx)) { + target_validity.SetInvalid(target_idx); + continue; + } + const auto &source_entry = source_entries[source_idx]; + auto &target_entry = target_entries[target_idx]; + if (source_entry.IsInlined()) { + target_entry = source_entry; + } else { + D_ASSERT(heap_ptr != nullptr); + memcpy(heap_ptr, source_entry.GetData(), source_entry.GetSize()); + target_entry = string_t(const_char_ptr_cast(heap_ptr), source_entry.GetSize()); + heap_ptr += source_entry.GetSize(); + } + } + + if (heap_size != 0) { + current_segment.swizzle_data.emplace_back(child_index, current_segment.count, append_count); + } + + current_segment.count += append_count; + offset += append_count; + remaining -= append_count; + + if (vector_remaining - append_count == 0) { + // need to append more, check if we need to allocate a new vector or not + if (!current_segment.next_data.IsValid()) { + segment.AllocateVector(source.GetType(), meta_data.chunk_data, append_state, current_index); + } + D_ASSERT(segment.GetVectorData(current_index).next_data.IsValid()); + current_index = segment.GetVectorData(current_index).next_data; + } + } +} + +template <> +void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, + idx_t offset, idx_t copy_count) { + + auto &segment = meta_data.segment; + + auto &child_vector = ListVector::GetEntry(source); + auto &child_type = child_vector.GetType(); + + if (!meta_data.GetVectorMetaData().child_index.IsValid()) { + auto child_index = segment.AllocateVector(child_type, meta_data.chunk_data, meta_data.state); + meta_data.GetVectorMetaData().child_index = meta_data.segment.AddChildIndex(child_index); + } + + auto &child_function = meta_data.copy_function.child_functions[0]; + auto child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index); + + // figure out the current list size by traversing the set of child entries + idx_t current_list_size = 0; + auto current_child_index = child_index; + while (current_child_index.IsValid()) { + auto &child_vdata = segment.GetVectorData(current_child_index); + current_list_size += child_vdata.count; + current_child_index = child_vdata.next_data; + } + + // set the child vector + UnifiedVectorFormat child_vector_data; + ColumnDataMetaData child_meta_data(child_function, meta_data, child_index); + auto info = ListVector::GetConsecutiveChildListInfo(source, offset, copy_count); + + if (info.needs_slicing) { + SelectionVector sel(info.child_list_info.length); + ListVector::GetConsecutiveChildSelVector(source, sel, offset, copy_count); + + auto sliced_child_vector = Vector(child_vector, sel, info.child_list_info.length); + sliced_child_vector.Flatten(info.child_list_info.length); + info.child_list_info.offset = 0; + + sliced_child_vector.ToUnifiedFormat(info.child_list_info.length, child_vector_data); + child_function.function(child_meta_data, child_vector_data, sliced_child_vector, info.child_list_info.offset, + info.child_list_info.length); + + } else { + child_vector.ToUnifiedFormat(info.child_list_info.length, child_vector_data); + child_function.function(child_meta_data, child_vector_data, child_vector, info.child_list_info.offset, + info.child_list_info.length); + } + + // now copy the list entries + meta_data.child_list_size = current_list_size; + if (info.is_constant) { + TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); + } else { + TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); + } +} + +void ColumnDataCopyStruct(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, + idx_t offset, idx_t copy_count) { + auto &segment = meta_data.segment; + + // copy the NULL values for the main struct vector + TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); + + auto &child_types = StructType::GetChildTypes(source.GetType()); + // now copy all the child vectors + D_ASSERT(meta_data.GetVectorMetaData().child_index.IsValid()); + auto &child_vectors = StructVector::GetEntries(source); + for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { + auto &child_function = meta_data.copy_function.child_functions[child_idx]; + auto child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index, child_idx); + ColumnDataMetaData child_meta_data(child_function, meta_data, child_index); + + UnifiedVectorFormat child_data; + child_vectors[child_idx]->ToUnifiedFormat(copy_count, child_data); + + child_function.function(child_meta_data, child_data, *child_vectors[child_idx], offset, copy_count); + } +} + +ColumnDataCopyFunction ColumnDataCollection::GetCopyFunction(const LogicalType &type) { + ColumnDataCopyFunction result; + column_data_copy_function_t function; + switch (type.InternalType()) { + case PhysicalType::BOOL: + function = ColumnDataCopy; + break; + case PhysicalType::INT8: + function = ColumnDataCopy; + break; + case PhysicalType::INT16: + function = ColumnDataCopy; + break; + case PhysicalType::INT32: + function = ColumnDataCopy; + break; + case PhysicalType::INT64: + function = ColumnDataCopy; + break; + case PhysicalType::INT128: + function = ColumnDataCopy; + break; + case PhysicalType::UINT8: + function = ColumnDataCopy; + break; + case PhysicalType::UINT16: + function = ColumnDataCopy; + break; + case PhysicalType::UINT32: + function = ColumnDataCopy; + break; + case PhysicalType::UINT64: + function = ColumnDataCopy; + break; + case PhysicalType::FLOAT: + function = ColumnDataCopy; + break; + case PhysicalType::DOUBLE: + function = ColumnDataCopy; + break; + case PhysicalType::INTERVAL: + function = ColumnDataCopy; + break; + case PhysicalType::VARCHAR: + function = ColumnDataCopy; + break; + case PhysicalType::STRUCT: { + function = ColumnDataCopyStruct; + auto &child_types = StructType::GetChildTypes(type); + for (auto &kv : child_types) { + result.child_functions.push_back(GetCopyFunction(kv.second)); + } + break; + } + case PhysicalType::LIST: { + function = ColumnDataCopy; + auto child_function = GetCopyFunction(ListType::GetChildType(type)); + result.child_functions.push_back(child_function); + break; + } + default: + throw InternalException("Unsupported type for ColumnDataCollection::GetCopyFunction"); + } + result.function = function; + return result; +} + +static bool IsComplexType(const LogicalType &type) { + switch (type.InternalType()) { + case PhysicalType::STRUCT: + case PhysicalType::LIST: + return true; + default: + return false; + }; +} + +void ColumnDataCollection::Append(ColumnDataAppendState &state, DataChunk &input) { + D_ASSERT(!finished_append); + D_ASSERT(types == input.GetTypes()); + + auto &segment = *segments.back(); + for (idx_t vector_idx = 0; vector_idx < types.size(); vector_idx++) { + if (IsComplexType(input.data[vector_idx].GetType())) { + input.data[vector_idx].Flatten(input.size()); + } + input.data[vector_idx].ToUnifiedFormat(input.size(), state.vector_data[vector_idx]); + } + + idx_t remaining = input.size(); + while (remaining > 0) { + auto &chunk_data = segment.chunk_data.back(); + idx_t append_amount = MinValue(remaining, STANDARD_VECTOR_SIZE - chunk_data.count); + if (append_amount > 0) { + idx_t offset = input.size() - remaining; + for (idx_t vector_idx = 0; vector_idx < types.size(); vector_idx++) { + ColumnDataMetaData meta_data(copy_functions[vector_idx], segment, state, chunk_data, + chunk_data.vector_data[vector_idx]); + copy_functions[vector_idx].function(meta_data, state.vector_data[vector_idx], input.data[vector_idx], + offset, append_amount); + } + chunk_data.count += append_amount; + } + remaining -= append_amount; + if (remaining > 0) { + // more to do + // allocate a new chunk + segment.AllocateNewChunk(); + segment.InitializeChunkState(segment.chunk_data.size() - 1, state.current_chunk_state); + } + } + segment.count += input.size(); + count += input.size(); +} + +void ColumnDataCollection::Append(DataChunk &input) { + ColumnDataAppendState state; + InitializeAppend(state); + Append(state, input); +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, ColumnDataScanProperties properties) const { + vector column_ids; + column_ids.reserve(types.size()); + for (idx_t i = 0; i < types.size(); i++) { + column_ids.push_back(i); + } + InitializeScan(state, std::move(column_ids), properties); +} + +void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, vector column_ids, + ColumnDataScanProperties properties) const { + state.chunk_index = 0; + state.segment_index = 0; + state.current_row_index = 0; + state.next_row_index = 0; + state.current_chunk_state.handles.clear(); + state.properties = properties; + state.column_ids = std::move(column_ids); +} + +void ColumnDataCollection::InitializeScan(ColumnDataParallelScanState &state, + ColumnDataScanProperties properties) const { + InitializeScan(state.scan_state, properties); +} + +void ColumnDataCollection::InitializeScan(ColumnDataParallelScanState &state, vector column_ids, + ColumnDataScanProperties properties) const { + InitializeScan(state.scan_state, std::move(column_ids), properties); +} + +bool ColumnDataCollection::Scan(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, + DataChunk &result) const { + result.Reset(); + + idx_t chunk_index; + idx_t segment_index; + idx_t row_index; + { + lock_guard l(state.lock); + if (!NextScanIndex(state.scan_state, chunk_index, segment_index, row_index)) { + return false; + } + } + ScanAtIndex(state, lstate, result, chunk_index, segment_index, row_index); + return true; +} + +void ColumnDataCollection::InitializeScanChunk(DataChunk &chunk) const { + chunk.Initialize(allocator->GetAllocator(), types); +} + +void ColumnDataCollection::InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const { + D_ASSERT(!state.column_ids.empty()); + vector chunk_types; + chunk_types.reserve(state.column_ids.size()); + for (idx_t i = 0; i < state.column_ids.size(); i++) { + auto column_idx = state.column_ids[i]; + D_ASSERT(column_idx < types.size()); + chunk_types.push_back(types[column_idx]); + } + chunk.Initialize(allocator->GetAllocator(), chunk_types); +} + +bool ColumnDataCollection::NextScanIndex(ColumnDataScanState &state, idx_t &chunk_index, idx_t &segment_index, + idx_t &row_index) const { + row_index = state.current_row_index = state.next_row_index; + // check if we still have collections to scan + if (state.segment_index >= segments.size()) { + // no more data left in the scan + return false; + } + // check within the current collection if we still have chunks to scan + while (state.chunk_index >= segments[state.segment_index]->chunk_data.size()) { + // exhausted all chunks for this internal data structure: move to the next one + state.chunk_index = 0; + state.segment_index++; + state.current_chunk_state.handles.clear(); + if (state.segment_index >= segments.size()) { + return false; + } + } + state.next_row_index += segments[state.segment_index]->chunk_data[state.chunk_index].count; + segment_index = state.segment_index; + chunk_index = state.chunk_index++; + return true; +} + +void ColumnDataCollection::ScanAtIndex(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, + DataChunk &result, idx_t chunk_index, idx_t segment_index, + idx_t row_index) const { + if (segment_index != lstate.current_segment_index) { + lstate.current_chunk_state.handles.clear(); + lstate.current_segment_index = segment_index; + } + auto &segment = *segments[segment_index]; + lstate.current_chunk_state.properties = state.scan_state.properties; + segment.ReadChunk(chunk_index, lstate.current_chunk_state, result, state.scan_state.column_ids); + lstate.current_row_index = row_index; + result.Verify(); +} + +bool ColumnDataCollection::Scan(ColumnDataScanState &state, DataChunk &result) const { + result.Reset(); + + idx_t chunk_index; + idx_t segment_index; + idx_t row_index; + if (!NextScanIndex(state, chunk_index, segment_index, row_index)) { + return false; + } + + // found a chunk to scan -> scan it + auto &segment = *segments[segment_index]; + state.current_chunk_state.properties = state.properties; + segment.ReadChunk(chunk_index, state.current_chunk_state, result, state.column_ids); + result.Verify(); + return true; +} + +ColumnDataRowCollection ColumnDataCollection::GetRows() const { + return ColumnDataRowCollection(*this); +} + +//===--------------------------------------------------------------------===// +// Combine +//===--------------------------------------------------------------------===// +void ColumnDataCollection::Combine(ColumnDataCollection &other) { + if (other.count == 0) { + return; + } + if (types != other.types) { + throw InternalException("Attempting to combine ColumnDataCollections with mismatching types"); + } + this->count += other.count; + this->segments.reserve(segments.size() + other.segments.size()); + for (auto &other_seg : other.segments) { + segments.push_back(std::move(other_seg)); + } + other.Reset(); + Verify(); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +idx_t ColumnDataCollection::ChunkCount() const { + idx_t chunk_count = 0; + for (auto &segment : segments) { + chunk_count += segment->ChunkCount(); + } + return chunk_count; +} + +void ColumnDataCollection::FetchChunk(idx_t chunk_idx, DataChunk &result) const { + D_ASSERT(chunk_idx < ChunkCount()); + for (auto &segment : segments) { + if (chunk_idx >= segment->ChunkCount()) { + chunk_idx -= segment->ChunkCount(); + } else { + segment->FetchChunk(chunk_idx, result); + return; + } + } + throw InternalException("Failed to find chunk in ColumnDataCollection"); +} + +//===--------------------------------------------------------------------===// +// Helpers +//===--------------------------------------------------------------------===// +void ColumnDataCollection::Verify() { +#ifdef DEBUG + // verify counts + idx_t total_segment_count = 0; + for (auto &segment : segments) { + segment->Verify(); + total_segment_count += segment->count; + } + D_ASSERT(total_segment_count == this->count); +#endif +} + +// LCOV_EXCL_START +string ColumnDataCollection::ToString() const { + DataChunk chunk; + InitializeScanChunk(chunk); + + ColumnDataScanState scan_state; + InitializeScan(scan_state); + + string result = StringUtil::Format("ColumnDataCollection - [%llu Chunks, %llu Rows]\n", ChunkCount(), Count()); + idx_t chunk_idx = 0; + idx_t row_count = 0; + while (Scan(scan_state, chunk)) { + result += + StringUtil::Format("Chunk %llu - [Rows %llu - %llu]\n", chunk_idx, row_count, row_count + chunk.size()) + + chunk.ToString(); + chunk_idx++; + row_count += chunk.size(); + } + + return result; +} +// LCOV_EXCL_STOP + +void ColumnDataCollection::Print() const { + Printer::Print(ToString()); +} + +void ColumnDataCollection::Reset() { + count = 0; + segments.clear(); + + // Refreshes the ColumnDataAllocator to prevent holding on to allocated data unnecessarily + allocator = make_shared(*allocator); +} + +struct ValueResultEquals { + bool operator()(const Value &a, const Value &b) const { + return Value::DefaultValuesAreEqual(a, b); + } +}; + +bool ColumnDataCollection::ResultEquals(const ColumnDataCollection &left, const ColumnDataCollection &right, + string &error_message, bool ordered) { + if (left.ColumnCount() != right.ColumnCount()) { + error_message = "Column count mismatch"; + return false; + } + if (left.Count() != right.Count()) { + error_message = "Row count mismatch"; + return false; + } + auto left_rows = left.GetRows(); + auto right_rows = right.GetRows(); + for (idx_t r = 0; r < left.Count(); r++) { + for (idx_t c = 0; c < left.ColumnCount(); c++) { + auto lvalue = left_rows.GetValue(c, r); + auto rvalue = right_rows.GetValue(c, r); + + if (!Value::DefaultValuesAreEqual(lvalue, rvalue)) { + error_message = + StringUtil::Format("%s <> %s (row: %lld, col: %lld)\n", lvalue.ToString(), rvalue.ToString(), r, c); + break; + } + } + if (!error_message.empty()) { + if (ordered) { + return false; + } else { + break; + } + } + } + if (!error_message.empty()) { + // do an unordered comparison + bool found_all = true; + for (idx_t c = 0; c < left.ColumnCount(); c++) { + std::unordered_multiset lvalues; + for (idx_t r = 0; r < left.Count(); r++) { + auto lvalue = left_rows.GetValue(c, r); + lvalues.insert(lvalue); + } + for (idx_t r = 0; r < right.Count(); r++) { + auto rvalue = right_rows.GetValue(c, r); + auto entry = lvalues.find(rvalue); + if (entry == lvalues.end()) { + found_all = false; + break; + } + lvalues.erase(entry); + } + if (!found_all) { + break; + } + } + if (!found_all) { + return false; + } + error_message = string(); + } + return true; +} + +vector> ColumnDataCollection::GetHeapReferences() { + vector> result(segments.size(), nullptr); + for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { + result[segment_idx] = segments[segment_idx]->heap; + } + return result; +} + +ColumnDataAllocatorType ColumnDataCollection::GetAllocatorType() const { + return allocator->GetType(); +} + +const vector> &ColumnDataCollection::GetSegments() const { + return segments; +} + +void ColumnDataCollection::Serialize(Serializer &serializer) const { + vector> values; + values.resize(ColumnCount()); + for (auto &chunk : Chunks()) { + for (idx_t c = 0; c < chunk.ColumnCount(); c++) { + for (idx_t r = 0; r < chunk.size(); r++) { + values[c].push_back(chunk.GetValue(c, r)); + } + } + } + serializer.WriteProperty(100, "types", types); + serializer.WriteProperty(101, "values", values); +} + +unique_ptr ColumnDataCollection::Deserialize(Deserializer &deserializer) { + auto types = deserializer.ReadProperty>(100, "types"); + auto values = deserializer.ReadProperty>>(101, "values"); + + auto collection = make_uniq(Allocator::DefaultAllocator(), types); + if (values.empty()) { + return collection; + } + DataChunk chunk; + chunk.Initialize(Allocator::DefaultAllocator(), types); + + for (idx_t r = 0; r < values[0].size(); r++) { + for (idx_t c = 0; c < types.size(); c++) { + chunk.SetValue(c, chunk.size(), values[c][r]); + } + chunk.SetCardinality(chunk.size() + 1); + if (chunk.size() == STANDARD_VECTOR_SIZE) { + collection->Append(chunk); + chunk.Reset(); + } + } + if (chunk.size() > 0) { + collection->Append(chunk); + } + return collection; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp new file mode 100644 index 00000000..2f534dc9 --- /dev/null +++ b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp @@ -0,0 +1,277 @@ +#include "duckdb/common/types/column/column_data_collection_segment.hpp" + +#include "duckdb/common/vector_operations/vector_operations.hpp" + +namespace duckdb { + +ColumnDataCollectionSegment::ColumnDataCollectionSegment(shared_ptr allocator_p, + vector types_p) + : allocator(std::move(allocator_p)), types(std::move(types_p)), count(0), + heap(make_shared(allocator->GetAllocator())) { +} + +idx_t ColumnDataCollectionSegment::GetDataSize(idx_t type_size) { + return AlignValue(type_size * STANDARD_VECTOR_SIZE); +} + +validity_t *ColumnDataCollectionSegment::GetValidityPointer(data_ptr_t base_ptr, idx_t type_size) { + return reinterpret_cast(base_ptr + GetDataSize(type_size)); +} + +VectorDataIndex ColumnDataCollectionSegment::AllocateVectorInternal(const LogicalType &type, ChunkMetaData &chunk_meta, + ChunkManagementState *chunk_state) { + VectorMetaData meta_data; + meta_data.count = 0; + + auto internal_type = type.InternalType(); + auto type_size = internal_type == PhysicalType::STRUCT ? 0 : GetTypeIdSize(internal_type); + allocator->AllocateData(GetDataSize(type_size) + ValidityMask::STANDARD_MASK_SIZE, meta_data.block_id, + meta_data.offset, chunk_state); + if (allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || + allocator->GetType() == ColumnDataAllocatorType::HYBRID) { + chunk_meta.block_ids.insert(meta_data.block_id); + } + + auto index = vector_data.size(); + vector_data.push_back(meta_data); + return VectorDataIndex(index); +} + +VectorDataIndex ColumnDataCollectionSegment::AllocateVector(const LogicalType &type, ChunkMetaData &chunk_meta, + ChunkManagementState *chunk_state, + VectorDataIndex prev_index) { + auto index = AllocateVectorInternal(type, chunk_meta, chunk_state); + if (prev_index.IsValid()) { + GetVectorData(prev_index).next_data = index; + } + if (type.InternalType() == PhysicalType::STRUCT) { + // initialize the struct children + auto &child_types = StructType::GetChildTypes(type); + auto base_child_index = ReserveChildren(child_types.size()); + for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { + VectorDataIndex prev_child_index; + if (prev_index.IsValid()) { + prev_child_index = GetChildIndex(GetVectorData(prev_index).child_index, child_idx); + } + auto child_index = AllocateVector(child_types[child_idx].second, chunk_meta, chunk_state, prev_child_index); + SetChildIndex(base_child_index, child_idx, child_index); + } + GetVectorData(index).child_index = base_child_index; + } + return index; +} + +VectorDataIndex ColumnDataCollectionSegment::AllocateVector(const LogicalType &type, ChunkMetaData &chunk_meta, + ColumnDataAppendState &append_state, + VectorDataIndex prev_index) { + return AllocateVector(type, chunk_meta, &append_state.current_chunk_state, prev_index); +} + +VectorDataIndex ColumnDataCollectionSegment::AllocateStringHeap(idx_t size, ChunkMetaData &chunk_meta, + ColumnDataAppendState &append_state, + VectorDataIndex prev_index) { + D_ASSERT(allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); + D_ASSERT(size != 0); + + VectorMetaData meta_data; + meta_data.count = 0; + + allocator->AllocateData(AlignValue(size), meta_data.block_id, meta_data.offset, &append_state.current_chunk_state); + chunk_meta.block_ids.insert(meta_data.block_id); + + VectorDataIndex index(vector_data.size()); + vector_data.push_back(meta_data); + + if (prev_index.IsValid()) { + GetVectorData(prev_index).next_data = index; + } + + return index; +} + +void ColumnDataCollectionSegment::AllocateNewChunk() { + ChunkMetaData meta_data; + meta_data.count = 0; + meta_data.vector_data.reserve(types.size()); + for (idx_t i = 0; i < types.size(); i++) { + auto vector_idx = AllocateVector(types[i], meta_data); + meta_data.vector_data.push_back(vector_idx); + } + chunk_data.push_back(std::move(meta_data)); +} + +void ColumnDataCollectionSegment::InitializeChunkState(idx_t chunk_index, ChunkManagementState &state) { + auto &chunk = chunk_data[chunk_index]; + allocator->InitializeChunkState(state, chunk); +} + +VectorDataIndex ColumnDataCollectionSegment::GetChildIndex(VectorChildIndex index, idx_t child_entry) { + D_ASSERT(index.IsValid()); + D_ASSERT(index.index + child_entry < child_indices.size()); + return VectorDataIndex(child_indices[index.index + child_entry]); +} + +VectorChildIndex ColumnDataCollectionSegment::AddChildIndex(VectorDataIndex index) { + auto result = child_indices.size(); + child_indices.push_back(index); + return VectorChildIndex(result); +} + +VectorChildIndex ColumnDataCollectionSegment::ReserveChildren(idx_t child_count) { + auto result = child_indices.size(); + for (idx_t i = 0; i < child_count; i++) { + child_indices.emplace_back(); + } + return VectorChildIndex(result); +} + +void ColumnDataCollectionSegment::SetChildIndex(VectorChildIndex base_idx, idx_t child_number, VectorDataIndex index) { + D_ASSERT(base_idx.IsValid()); + D_ASSERT(index.IsValid()); + D_ASSERT(base_idx.index + child_number < child_indices.size()); + child_indices[base_idx.index + child_number] = index; +} + +idx_t ColumnDataCollectionSegment::ReadVectorInternal(ChunkManagementState &state, VectorDataIndex vector_index, + Vector &result) { + auto &vector_type = result.GetType(); + auto internal_type = vector_type.InternalType(); + auto type_size = GetTypeIdSize(internal_type); + auto &vdata = GetVectorData(vector_index); + + auto base_ptr = allocator->GetDataPointer(state, vdata.block_id, vdata.offset); + auto validity_data = GetValidityPointer(base_ptr, type_size); + if (!vdata.next_data.IsValid() && state.properties != ColumnDataScanProperties::DISALLOW_ZERO_COPY) { + // no next data, we can do a zero-copy read of this vector + FlatVector::SetData(result, base_ptr); + FlatVector::Validity(result).Initialize(validity_data); + return vdata.count; + } + + // the data for this vector is spread over multiple vector data entries + // we need to copy over the data for each of the vectors + // first figure out how many rows we need to copy by looping over all of the child vector indexes + idx_t vector_count = 0; + auto next_index = vector_index; + while (next_index.IsValid()) { + auto ¤t_vdata = GetVectorData(next_index); + vector_count += current_vdata.count; + next_index = current_vdata.next_data; + } + // resize the result vector + result.Resize(0, vector_count); + next_index = vector_index; + // now perform the copy of each of the vectors + auto target_data = FlatVector::GetData(result); + auto &target_validity = FlatVector::Validity(result); + idx_t current_offset = 0; + while (next_index.IsValid()) { + auto ¤t_vdata = GetVectorData(next_index); + base_ptr = allocator->GetDataPointer(state, current_vdata.block_id, current_vdata.offset); + validity_data = GetValidityPointer(base_ptr, type_size); + if (type_size > 0) { + memcpy(target_data + current_offset * type_size, base_ptr, current_vdata.count * type_size); + } + ValidityMask current_validity(validity_data); + target_validity.SliceInPlace(current_validity, current_offset, 0, current_vdata.count); + current_offset += current_vdata.count; + next_index = current_vdata.next_data; + } + return vector_count; +} + +idx_t ColumnDataCollectionSegment::ReadVector(ChunkManagementState &state, VectorDataIndex vector_index, + Vector &result) { + auto &vector_type = result.GetType(); + auto internal_type = vector_type.InternalType(); + auto &vdata = GetVectorData(vector_index); + if (vdata.count == 0) { + return 0; + } + auto vcount = ReadVectorInternal(state, vector_index, result); + if (internal_type == PhysicalType::LIST) { + // list: copy child + auto &child_vector = ListVector::GetEntry(result); + auto child_count = ReadVector(state, GetChildIndex(vdata.child_index), child_vector); + ListVector::SetListSize(result, child_count); + } else if (internal_type == PhysicalType::STRUCT) { + auto &child_vectors = StructVector::GetEntries(result); + for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { + auto child_count = + ReadVector(state, GetChildIndex(vdata.child_index, child_idx), *child_vectors[child_idx]); + if (child_count != vcount) { + throw InternalException("Column Data Collection: mismatch in struct child sizes"); + } + } + } else if (internal_type == PhysicalType::VARCHAR) { + if (allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { + auto next_index = vector_index; + idx_t offset = 0; + while (next_index.IsValid()) { + auto ¤t_vdata = GetVectorData(next_index); + for (auto &swizzle_segment : current_vdata.swizzle_data) { + auto &string_heap_segment = GetVectorData(swizzle_segment.child_index); + allocator->UnswizzlePointers(state, result, offset + swizzle_segment.offset, swizzle_segment.count, + string_heap_segment.block_id, string_heap_segment.offset); + } + offset += current_vdata.count; + next_index = current_vdata.next_data; + } + } + if (state.properties == ColumnDataScanProperties::DISALLOW_ZERO_COPY) { + VectorOperations::Copy(result, result, vdata.count, 0, 0); + } + } + return vcount; +} + +void ColumnDataCollectionSegment::ReadChunk(idx_t chunk_index, ChunkManagementState &state, DataChunk &chunk, + const vector &column_ids) { + D_ASSERT(chunk.ColumnCount() == column_ids.size()); + D_ASSERT(state.properties != ColumnDataScanProperties::INVALID); + InitializeChunkState(chunk_index, state); + auto &chunk_meta = chunk_data[chunk_index]; + for (idx_t i = 0; i < column_ids.size(); i++) { + auto vector_idx = column_ids[i]; + D_ASSERT(vector_idx < chunk_meta.vector_data.size()); + ReadVector(state, chunk_meta.vector_data[vector_idx], chunk.data[i]); + } + chunk.SetCardinality(chunk_meta.count); +} + +idx_t ColumnDataCollectionSegment::ChunkCount() const { + return chunk_data.size(); +} + +idx_t ColumnDataCollectionSegment::SizeInBytes() const { + D_ASSERT(!allocator->IsShared()); + return allocator->SizeInBytes() + heap->SizeInBytes(); +} + +void ColumnDataCollectionSegment::FetchChunk(idx_t chunk_idx, DataChunk &result) { + vector column_ids; + column_ids.reserve(types.size()); + for (idx_t i = 0; i < types.size(); i++) { + column_ids.push_back(i); + } + FetchChunk(chunk_idx, result, column_ids); +} + +void ColumnDataCollectionSegment::FetchChunk(idx_t chunk_idx, DataChunk &result, const vector &column_ids) { + D_ASSERT(chunk_idx < chunk_data.size()); + ChunkManagementState state; + state.properties = ColumnDataScanProperties::DISALLOW_ZERO_COPY; + ReadChunk(chunk_idx, state, result, column_ids); +} + +void ColumnDataCollectionSegment::Verify() { +#ifdef DEBUG + idx_t total_count = 0; + for (idx_t i = 0; i < chunk_data.size(); i++) { + total_count += chunk_data[i].count; + } + D_ASSERT(total_count == this->count); +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/column_data_consumer.cpp b/src/duckdb/src/common/types/column/column_data_consumer.cpp new file mode 100644 index 00000000..d9fb4fdb --- /dev/null +++ b/src/duckdb/src/common/types/column/column_data_consumer.cpp @@ -0,0 +1,102 @@ +#include "duckdb/common/types/column/column_data_consumer.hpp" + +#include + +namespace duckdb { + +using ChunkReference = ColumnDataConsumer::ChunkReference; + +ChunkReference::ChunkReference(ColumnDataCollectionSegment *segment_p, uint32_t chunk_index_p) + : segment(segment_p), chunk_index_in_segment(chunk_index_p) { +} + +uint32_t ChunkReference::GetMinimumBlockID() const { + const auto &block_ids = segment->chunk_data[chunk_index_in_segment].block_ids; + return *std::min_element(block_ids.begin(), block_ids.end()); +} + +ColumnDataConsumer::ColumnDataConsumer(ColumnDataCollection &collection_p, vector column_ids) + : collection(collection_p), column_ids(std::move(column_ids)) { +} + +void ColumnDataConsumer::InitializeScan() { + chunk_count = collection.ChunkCount(); + current_chunk_index = 0; + chunk_delete_index = DConstants::INVALID_INDEX; + + // Initialize chunk references and sort them, so we can scan them in a sane order, regardless of how it was created + chunk_references.reserve(chunk_count); + for (auto &segment : collection.GetSegments()) { + for (idx_t chunk_index = 0; chunk_index < segment->chunk_data.size(); chunk_index++) { + chunk_references.emplace_back(segment.get(), chunk_index); + } + } + std::sort(chunk_references.begin(), chunk_references.end()); +} + +bool ColumnDataConsumer::AssignChunk(ColumnDataConsumerScanState &state) { + lock_guard guard(lock); + if (current_chunk_index == chunk_count) { + // All chunks have been assigned + state.current_chunk_state.handles.clear(); + state.chunk_index = DConstants::INVALID_INDEX; + return false; + } + // Assign chunk index + state.chunk_index = current_chunk_index++; + D_ASSERT(chunks_in_progress.find(state.chunk_index) == chunks_in_progress.end()); + chunks_in_progress.insert(state.chunk_index); + return true; +} + +void ColumnDataConsumer::ScanChunk(ColumnDataConsumerScanState &state, DataChunk &chunk) const { + D_ASSERT(state.chunk_index < chunk_count); + auto &chunk_ref = chunk_references[state.chunk_index]; + if (state.allocator != chunk_ref.segment->allocator.get()) { + // Previously scanned a chunk from a different allocator, reset the handles + state.allocator = chunk_ref.segment->allocator.get(); + state.current_chunk_state.handles.clear(); + } + chunk_ref.segment->ReadChunk(chunk_ref.chunk_index_in_segment, state.current_chunk_state, chunk, column_ids); +} + +void ColumnDataConsumer::FinishChunk(ColumnDataConsumerScanState &state) { + D_ASSERT(state.chunk_index < chunk_count); + idx_t delete_index_start; + idx_t delete_index_end; + { + lock_guard guard(lock); + D_ASSERT(chunks_in_progress.find(state.chunk_index) != chunks_in_progress.end()); + delete_index_start = chunk_delete_index; + delete_index_end = *std::min_element(chunks_in_progress.begin(), chunks_in_progress.end()); + chunks_in_progress.erase(state.chunk_index); + chunk_delete_index = delete_index_end; + } + ConsumeChunks(delete_index_start, delete_index_end); +} +void ColumnDataConsumer::ConsumeChunks(idx_t delete_index_start, idx_t delete_index_end) { + for (idx_t chunk_index = delete_index_start; chunk_index < delete_index_end; chunk_index++) { + if (chunk_index == 0) { + continue; + } + auto &prev_chunk_ref = chunk_references[chunk_index - 1]; + auto &curr_chunk_ref = chunk_references[chunk_index]; + auto prev_allocator = prev_chunk_ref.segment->allocator.get(); + auto curr_allocator = curr_chunk_ref.segment->allocator.get(); + auto prev_min_block_id = prev_chunk_ref.GetMinimumBlockID(); + auto curr_min_block_id = curr_chunk_ref.GetMinimumBlockID(); + if (prev_allocator != curr_allocator) { + // Moved to the next allocator, delete all remaining blocks in the previous one + for (uint32_t block_id = prev_min_block_id; block_id < prev_allocator->BlockCount(); block_id++) { + prev_allocator->DeleteBlock(block_id); + } + continue; + } + // Same allocator, see if we can delete blocks + for (uint32_t block_id = prev_min_block_id; block_id < curr_min_block_id; block_id++) { + prev_allocator->DeleteBlock(block_id); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/partitioned_column_data.cpp b/src/duckdb/src/common/types/column/partitioned_column_data.cpp new file mode 100644 index 00000000..0843931c --- /dev/null +++ b/src/duckdb/src/common/types/column/partitioned_column_data.cpp @@ -0,0 +1,172 @@ +#include "duckdb/common/types/column/partitioned_column_data.hpp" + +#include "duckdb/common/hive_partitioning.hpp" +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +PartitionedColumnData::PartitionedColumnData(PartitionedColumnDataType type_p, ClientContext &context_p, + vector types_p) + : type(type_p), context(context_p), types(std::move(types_p)), + allocators(make_shared()) { +} + +PartitionedColumnData::PartitionedColumnData(const PartitionedColumnData &other) + : type(other.type), context(other.context), types(other.types), allocators(other.allocators) { +} + +unique_ptr PartitionedColumnData::CreateShared() { + switch (type) { + case PartitionedColumnDataType::RADIX: + return make_uniq(Cast()); + case PartitionedColumnDataType::HIVE: + return make_uniq(Cast()); + default: + throw NotImplementedException("CreateShared for this type of PartitionedColumnData"); + } +} + +PartitionedColumnData::~PartitionedColumnData() { +} + +void PartitionedColumnData::InitializeAppendState(PartitionedColumnDataAppendState &state) const { + state.partition_sel.Initialize(); + state.slice_chunk.Initialize(BufferAllocator::Get(context), types); + InitializeAppendStateInternal(state); +} + +unique_ptr PartitionedColumnData::CreatePartitionBuffer() const { + auto result = make_uniq(); + result->Initialize(BufferAllocator::Get(context), types, BufferSize()); + return result; +} + +void PartitionedColumnData::Append(PartitionedColumnDataAppendState &state, DataChunk &input) { + // Compute partition indices and store them in state.partition_indices + ComputePartitionIndices(state, input); + + // Compute the counts per partition + const auto count = input.size(); + const auto partition_indices = FlatVector::GetData(state.partition_indices); + auto &partition_entries = state.partition_entries; + partition_entries.clear(); + switch (state.partition_indices.GetVectorType()) { + case VectorType::FLAT_VECTOR: + for (idx_t i = 0; i < count; i++) { + const auto &partition_index = partition_indices[i]; + auto partition_entry = partition_entries.find(partition_index); + if (partition_entry == partition_entries.end()) { + partition_entries[partition_index] = list_entry_t(0, 1); + } else { + partition_entry->second.length++; + } + } + break; + case VectorType::CONSTANT_VECTOR: + partition_entries[partition_indices[0]] = list_entry_t(0, count); + break; + default: + throw InternalException("Unexpected VectorType in PartitionedColumnData::Append"); + } + + // Early out: check if everything belongs to a single partition + if (partition_entries.size() == 1) { + const auto &partition_index = partition_entries.begin()->first; + auto &partition = *partitions[partition_index]; + auto &partition_append_state = *state.partition_append_states[partition_index]; + partition.Append(partition_append_state, input); + return; + } + + // Compute offsets from the counts + idx_t offset = 0; + for (auto &pc : partition_entries) { + auto &partition_entry = pc.second; + partition_entry.offset = offset; + offset += partition_entry.length; + } + + // Now initialize a single selection vector that acts as a selection vector for every partition + auto &all_partitions_sel = state.partition_sel; + for (idx_t i = 0; i < count; i++) { + const auto &partition_index = partition_indices[i]; + auto &partition_offset = partition_entries[partition_index].offset; + all_partitions_sel[partition_offset++] = i; + } + + // Loop through the partitions to append the new data to the partition buffers, and flush the buffers if necessary + SelectionVector partition_sel; + for (auto &pc : partition_entries) { + const auto &partition_index = pc.first; + + // Partition, buffer, and append state for this partition index + auto &partition = *partitions[partition_index]; + auto &partition_buffer = *state.partition_buffers[partition_index]; + auto &partition_append_state = *state.partition_append_states[partition_index]; + + // Length and offset into the selection vector for this chunk, for this partition + const auto &partition_entry = pc.second; + const auto &partition_length = partition_entry.length; + const auto partition_offset = partition_entry.offset - partition_length; + + // Create a selection vector for this partition using the offset into the single selection vector + partition_sel.Initialize(all_partitions_sel.data() + partition_offset); + + if (partition_length >= HalfBufferSize()) { + // Slice the input chunk using the selection vector + state.slice_chunk.Reset(); + state.slice_chunk.Slice(input, partition_sel, partition_length); + + // Append it to the partition directly + partition.Append(partition_append_state, state.slice_chunk); + } else { + // Append the input chunk to the partition buffer using the selection vector + partition_buffer.Append(input, false, &partition_sel, partition_length); + + if (partition_buffer.size() >= HalfBufferSize()) { + // Next batch won't fit in the buffer, flush it to the partition + partition.Append(partition_append_state, partition_buffer); + partition_buffer.Reset(); + partition_buffer.SetCapacity(BufferSize()); + } + } + } +} + +void PartitionedColumnData::FlushAppendState(PartitionedColumnDataAppendState &state) { + for (idx_t i = 0; i < state.partition_buffers.size(); i++) { + auto &partition_buffer = *state.partition_buffers[i]; + if (partition_buffer.size() > 0) { + partitions[i]->Append(partition_buffer); + partition_buffer.Reset(); + } + } +} + +void PartitionedColumnData::Combine(PartitionedColumnData &other) { + // Now combine the state's partitions into this + lock_guard guard(lock); + + if (partitions.empty()) { + // This is the first merge, we just copy them over + partitions = std::move(other.partitions); + } else { + D_ASSERT(partitions.size() == other.partitions.size()); + // Combine the append state's partitions into this PartitionedColumnData + for (idx_t i = 0; i < other.partitions.size(); i++) { + partitions[i]->Combine(*other.partitions[i]); + } + } +} + +vector> &PartitionedColumnData::GetPartitions() { + return partitions; +} + +void PartitionedColumnData::CreateAllocator() { + allocators->allocators.emplace_back(make_shared(BufferManager::GetBufferManager(context))); + allocators->allocators.back()->MakeShared(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/conflict_info.cpp b/src/duckdb/src/common/types/conflict_info.cpp new file mode 100644 index 00000000..e832bf89 --- /dev/null +++ b/src/duckdb/src/common/types/conflict_info.cpp @@ -0,0 +1,18 @@ +#include "duckdb/common/types/constraint_conflict_info.hpp" +#include "duckdb/storage/index.hpp" + +namespace duckdb { + +bool ConflictInfo::ConflictTargetMatches(Index &index) const { + if (only_check_unique && !index.IsUnique()) { + // We only support checking ON CONFLICT for Unique/Primary key constraints + return false; + } + if (column_ids.empty()) { + return true; + } + // Check whether the column ids match + return column_ids == index.column_id_set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/conflict_manager.cpp b/src/duckdb/src/common/types/conflict_manager.cpp new file mode 100644 index 00000000..38d61240 --- /dev/null +++ b/src/duckdb/src/common/types/conflict_manager.cpp @@ -0,0 +1,258 @@ +#include "duckdb/common/types/conflict_manager.hpp" +#include "duckdb/storage/index.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/common/types/constraint_conflict_info.hpp" + +namespace duckdb { + +ConflictManager::ConflictManager(VerifyExistenceType lookup_type, idx_t input_size, + optional_ptr conflict_info) + : lookup_type(lookup_type), input_size(input_size), conflict_info(conflict_info), conflicts(input_size, false), + mode(ConflictManagerMode::THROW) { +} + +ManagedSelection &ConflictManager::InternalSelection() { + if (!conflicts.Initialized()) { + conflicts.Initialize(input_size); + } + return conflicts; +} + +const unordered_set &ConflictManager::InternalConflictSet() const { + D_ASSERT(conflict_set); + return *conflict_set; +} + +Vector &ConflictManager::InternalRowIds() { + if (!row_ids) { + row_ids = make_uniq(LogicalType::ROW_TYPE, input_size); + } + return *row_ids; +} + +Vector &ConflictManager::InternalIntermediate() { + if (!intermediate_vector) { + intermediate_vector = make_uniq(LogicalType::BOOLEAN, true, true, input_size); + } + return *intermediate_vector; +} + +const ConflictInfo &ConflictManager::GetConflictInfo() const { + D_ASSERT(conflict_info); + return *conflict_info; +} + +void ConflictManager::FinishLookup() { + if (mode == ConflictManagerMode::THROW) { + return; + } + if (!SingleIndexTarget()) { + return; + } + if (conflicts.Count() != 0) { + // We have recorded conflicts from the one index we're interested in + // We set this so we don't duplicate the conflicts when there are duplicate indexes + // that also match our conflict target + single_index_finished = true; + } +} + +void ConflictManager::SetMode(ConflictManagerMode mode) { + // Only allow SCAN when we have conflict info + D_ASSERT(mode != ConflictManagerMode::SCAN || conflict_info != nullptr); + this->mode = mode; +} + +void ConflictManager::AddToConflictSet(idx_t chunk_index) { + if (!conflict_set) { + conflict_set = make_uniq>(); + } + auto &set = *conflict_set; + set.insert(chunk_index); +} + +void ConflictManager::AddConflictInternal(idx_t chunk_index, row_t row_id) { + D_ASSERT(mode == ConflictManagerMode::SCAN); + + // Only when we should not throw on conflict should we get here + D_ASSERT(!ShouldThrow(chunk_index)); + AddToConflictSet(chunk_index); + if (SingleIndexTarget()) { + // If we have identical indexes, only the conflicts of the first index should be recorded + // as the other index(es) would produce the exact same conflicts anyways + if (single_index_finished) { + return; + } + + // We can be more efficient because we don't need to merge conflicts of multiple indexes + auto &selection = InternalSelection(); + auto &row_ids = InternalRowIds(); + auto data = FlatVector::GetData(row_ids); + data[selection.Count()] = row_id; + selection.Append(chunk_index); + } else { + auto &intermediate = InternalIntermediate(); + auto data = FlatVector::GetData(intermediate); + // Mark this index in the chunk as producing a conflict + data[chunk_index] = true; + if (row_id_map.empty()) { + row_id_map.resize(input_size); + } + row_id_map[chunk_index] = row_id; + } +} + +bool ConflictManager::IsConflict(LookupResultType type) { + switch (type) { + case LookupResultType::LOOKUP_NULL: { + if (ShouldIgnoreNulls()) { + return false; + } + // If nulls are not ignored, treat this as a hit instead + return IsConflict(LookupResultType::LOOKUP_HIT); + } + case LookupResultType::LOOKUP_HIT: { + return true; + } + case LookupResultType::LOOKUP_MISS: { + // FIXME: If we record a miss as a conflict when the verify type is APPEND_FK, then we can simplify the checks + // in VerifyForeignKeyConstraint This also means we should not record a hit as a conflict when the verify type + // is APPEND_FK + return false; + } + default: { + throw NotImplementedException("Type not implemented for LookupResultType"); + } + } +} + +bool ConflictManager::AddHit(idx_t chunk_index, row_t row_id) { + D_ASSERT(chunk_index < input_size); + // First check if this causes a conflict + if (!IsConflict(LookupResultType::LOOKUP_HIT)) { + return false; + } + + // Then check if we should throw on a conflict + if (ShouldThrow(chunk_index)) { + return true; + } + if (mode == ConflictManagerMode::THROW) { + // When our mode is THROW, and the chunk index is part of the previously scanned conflicts + // then we ignore the conflict instead + D_ASSERT(!ShouldThrow(chunk_index)); + return false; + } + D_ASSERT(conflict_info); + // Because we don't throw, we need to register the conflict + AddConflictInternal(chunk_index, row_id); + return false; +} + +bool ConflictManager::AddMiss(idx_t chunk_index) { + D_ASSERT(chunk_index < input_size); + return IsConflict(LookupResultType::LOOKUP_MISS); +} + +bool ConflictManager::AddNull(idx_t chunk_index) { + D_ASSERT(chunk_index < input_size); + if (!IsConflict(LookupResultType::LOOKUP_NULL)) { + return false; + } + return AddHit(chunk_index, DConstants::INVALID_INDEX); +} + +bool ConflictManager::SingleIndexTarget() const { + D_ASSERT(conflict_info); + // We are only interested in a specific index + return !conflict_info->column_ids.empty(); +} + +bool ConflictManager::ShouldThrow(idx_t chunk_index) const { + if (mode == ConflictManagerMode::SCAN) { + return false; + } + D_ASSERT(mode == ConflictManagerMode::THROW); + if (conflict_set == nullptr) { + // No conflicts were scanned, so this conflict is not in the set + return true; + } + auto &set = InternalConflictSet(); + if (set.count(chunk_index)) { + return false; + } + // None of the scanned conflicts arose from this insert tuple + return true; +} + +bool ConflictManager::ShouldIgnoreNulls() const { + switch (lookup_type) { + case VerifyExistenceType::APPEND: + return true; + case VerifyExistenceType::APPEND_FK: + return false; + case VerifyExistenceType::DELETE_FK: + return true; + default: + throw InternalException("Type not implemented for VerifyExistenceType"); + } +} + +Vector &ConflictManager::RowIds() { + D_ASSERT(finalized); + return *row_ids; +} + +const ManagedSelection &ConflictManager::Conflicts() const { + D_ASSERT(finalized); + return conflicts; +} + +idx_t ConflictManager::ConflictCount() const { + return conflicts.Count(); +} + +void ConflictManager::Finalize() { + D_ASSERT(!finalized); + if (SingleIndexTarget()) { + // Selection vector has been directly populated already, no need to finalize + finalized = true; + return; + } + finalized = true; + if (!intermediate_vector) { + // No conflicts were found, we're done + return; + } + auto &intermediate = InternalIntermediate(); + auto data = FlatVector::GetData(intermediate); + auto &selection = InternalSelection(); + // Create the selection vector from the encountered conflicts + for (idx_t i = 0; i < input_size; i++) { + if (data[i]) { + selection.Append(i); + } + } + // Now create the row_ids Vector, aligned with the selection vector + auto &row_ids = InternalRowIds(); + auto row_id_data = FlatVector::GetData(row_ids); + + for (idx_t i = 0; i < selection.Count(); i++) { + D_ASSERT(!row_id_map.empty()); + auto index = selection[i]; + D_ASSERT(index < row_id_map.size()); + auto row_id = row_id_map[index]; + row_id_data[i] = row_id; + } + intermediate_vector.reset(); +} + +VerifyExistenceType ConflictManager::LookupType() const { + return this->lookup_type; +} + +void ConflictManager::SetIndexCount(idx_t count) { + index_count = count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/data_chunk.cpp b/src/duckdb/src/common/types/data_chunk.cpp new file mode 100644 index 00000000..4c7e16a6 --- /dev/null +++ b/src/duckdb/src/common/types/data_chunk.cpp @@ -0,0 +1,375 @@ +#include "duckdb/common/types/data_chunk.hpp" + +#include "duckdb/common/array.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/sel_cache.hpp" +#include "duckdb/common/types/vector_cache.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/execution_context.hpp" + +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" + +namespace duckdb { + +DataChunk::DataChunk() : count(0), capacity(STANDARD_VECTOR_SIZE) { +} + +DataChunk::~DataChunk() { +} + +void DataChunk::InitializeEmpty(const vector &types) { + InitializeEmpty(types.begin(), types.end()); +} + +void DataChunk::Initialize(Allocator &allocator, const vector &types, idx_t capacity_p) { + Initialize(allocator, types.begin(), types.end(), capacity_p); +} + +void DataChunk::Initialize(ClientContext &context, const vector &types, idx_t capacity_p) { + Initialize(Allocator::Get(context), types, capacity_p); +} + +void DataChunk::Initialize(Allocator &allocator, vector::const_iterator begin, + vector::const_iterator end, idx_t capacity_p) { + D_ASSERT(data.empty()); // can only be initialized once + D_ASSERT(std::distance(begin, end) != 0); // empty chunk not allowed + capacity = capacity_p; + for (; begin != end; begin++) { + VectorCache cache(allocator, *begin, capacity); + data.emplace_back(cache); + vector_caches.push_back(std::move(cache)); + } +} + +void DataChunk::Initialize(ClientContext &context, vector::const_iterator begin, + vector::const_iterator end, idx_t capacity_p) { + Initialize(Allocator::Get(context), begin, end, capacity_p); +} + +void DataChunk::InitializeEmpty(vector::const_iterator begin, vector::const_iterator end) { + capacity = STANDARD_VECTOR_SIZE; + D_ASSERT(data.empty()); // can only be initialized once + D_ASSERT(std::distance(begin, end) != 0); // empty chunk not allowed + for (; begin != end; begin++) { + data.emplace_back(*begin, nullptr); + } +} + +void DataChunk::Reset() { + if (data.empty()) { + return; + } + if (vector_caches.size() != data.size()) { + throw InternalException("VectorCache and column count mismatch in DataChunk::Reset"); + } + for (idx_t i = 0; i < ColumnCount(); i++) { + data[i].ResetFromCache(vector_caches[i]); + } + capacity = STANDARD_VECTOR_SIZE; + SetCardinality(0); +} + +void DataChunk::Destroy() { + data.clear(); + vector_caches.clear(); + capacity = 0; + SetCardinality(0); +} + +Value DataChunk::GetValue(idx_t col_idx, idx_t index) const { + D_ASSERT(index < size()); + return data[col_idx].GetValue(index); +} + +void DataChunk::SetValue(idx_t col_idx, idx_t index, const Value &val) { + data[col_idx].SetValue(index, val); +} + +bool DataChunk::AllConstant() const { + for (auto &v : data) { + if (v.GetVectorType() != VectorType::CONSTANT_VECTOR) { + return false; + } + } + return true; +} + +void DataChunk::Reference(DataChunk &chunk) { + D_ASSERT(chunk.ColumnCount() <= ColumnCount()); + SetCapacity(chunk); + SetCardinality(chunk); + for (idx_t i = 0; i < chunk.ColumnCount(); i++) { + data[i].Reference(chunk.data[i]); + } +} + +void DataChunk::Move(DataChunk &chunk) { + SetCardinality(chunk); + SetCapacity(chunk); + data = std::move(chunk.data); + vector_caches = std::move(chunk.vector_caches); + + chunk.Destroy(); +} + +void DataChunk::Copy(DataChunk &other, idx_t offset) const { + D_ASSERT(ColumnCount() == other.ColumnCount()); + D_ASSERT(other.size() == 0); + + for (idx_t i = 0; i < ColumnCount(); i++) { + D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); + VectorOperations::Copy(data[i], other.data[i], size(), offset, 0); + } + other.SetCardinality(size() - offset); +} + +void DataChunk::Copy(DataChunk &other, const SelectionVector &sel, const idx_t source_count, const idx_t offset) const { + D_ASSERT(ColumnCount() == other.ColumnCount()); + D_ASSERT(other.size() == 0); + D_ASSERT((offset + source_count) <= size()); + + for (idx_t i = 0; i < ColumnCount(); i++) { + D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); + VectorOperations::Copy(data[i], other.data[i], sel, source_count, offset, 0); + } + other.SetCardinality(source_count - offset); +} + +void DataChunk::Split(DataChunk &other, idx_t split_idx) { + D_ASSERT(other.size() == 0); + D_ASSERT(other.data.empty()); + D_ASSERT(split_idx < data.size()); + const idx_t num_cols = data.size(); + for (idx_t col_idx = split_idx; col_idx < num_cols; col_idx++) { + other.data.push_back(std::move(data[col_idx])); + other.vector_caches.push_back(std::move(vector_caches[col_idx])); + } + for (idx_t col_idx = split_idx; col_idx < num_cols; col_idx++) { + data.pop_back(); + vector_caches.pop_back(); + } + other.SetCapacity(*this); + other.SetCardinality(*this); +} + +void DataChunk::Fuse(DataChunk &other) { + D_ASSERT(other.size() == size()); + const idx_t num_cols = other.data.size(); + for (idx_t col_idx = 0; col_idx < num_cols; ++col_idx) { + data.emplace_back(std::move(other.data[col_idx])); + vector_caches.emplace_back(std::move(other.vector_caches[col_idx])); + } + other.Destroy(); +} + +void DataChunk::ReferenceColumns(DataChunk &other, const vector &column_ids) { + D_ASSERT(ColumnCount() == column_ids.size()); + Reset(); + for (idx_t col_idx = 0; col_idx < ColumnCount(); col_idx++) { + auto &other_col = other.data[column_ids[col_idx]]; + auto &this_col = data[col_idx]; + D_ASSERT(other_col.GetType() == this_col.GetType()); + this_col.Reference(other_col); + } + SetCardinality(other.size()); +} + +void DataChunk::Append(const DataChunk &other, bool resize, SelectionVector *sel, idx_t sel_count) { + idx_t new_size = sel ? size() + sel_count : size() + other.size(); + if (other.size() == 0) { + return; + } + if (ColumnCount() != other.ColumnCount()) { + throw InternalException("Column counts of appending chunk doesn't match!"); + } + if (new_size > capacity) { + if (resize) { + auto new_capacity = NextPowerOfTwo(new_size); + for (idx_t i = 0; i < ColumnCount(); i++) { + data[i].Resize(size(), new_capacity); + } + capacity = new_capacity; + } else { + throw InternalException("Can't append chunk to other chunk without resizing"); + } + } + for (idx_t i = 0; i < ColumnCount(); i++) { + D_ASSERT(data[i].GetVectorType() == VectorType::FLAT_VECTOR); + if (sel) { + VectorOperations::Copy(other.data[i], data[i], *sel, sel_count, 0, size()); + } else { + VectorOperations::Copy(other.data[i], data[i], other.size(), 0, size()); + } + } + SetCardinality(new_size); +} + +void DataChunk::Flatten() { + for (idx_t i = 0; i < ColumnCount(); i++) { + data[i].Flatten(size()); + } +} + +vector DataChunk::GetTypes() { + vector types; + for (idx_t i = 0; i < ColumnCount(); i++) { + types.push_back(data[i].GetType()); + } + return types; +} + +string DataChunk::ToString() const { + string retval = "Chunk - [" + to_string(ColumnCount()) + " Columns]\n"; + for (idx_t i = 0; i < ColumnCount(); i++) { + retval += "- " + data[i].ToString(size()) + "\n"; + } + return retval; +} + +void DataChunk::Serialize(Serializer &serializer) const { + + // write the count + auto row_count = size(); + serializer.WriteProperty(100, "rows", row_count); + + // we should never try to serialize empty data chunks + auto column_count = ColumnCount(); + D_ASSERT(column_count); + + // write the types + serializer.WriteList(101, "types", column_count, + [&](Serializer::List &list, idx_t i) { list.WriteElement(data[i].GetType()); }); + + // write the data + serializer.WriteList(102, "columns", column_count, [&](Serializer::List &list, idx_t i) { + list.WriteObject([&](Serializer &object) { + // Reference the vector to avoid potentially mutating it during serialization + Vector serialized_vector(data[i].GetType()); + serialized_vector.Reference(data[i]); + serialized_vector.Serialize(object, row_count); + }); + }); +} + +void DataChunk::Deserialize(Deserializer &deserializer) { + + // read and set the row count + auto row_count = deserializer.ReadProperty(100, "rows"); + + // read the types + vector types; + deserializer.ReadList(101, "types", [&](Deserializer::List &list, idx_t i) { + auto type = list.ReadElement(); + types.push_back(type); + }); + + // initialize the data chunk + D_ASSERT(!types.empty()); + Initialize(Allocator::DefaultAllocator(), types); + SetCardinality(row_count); + + // read the data + deserializer.ReadList(102, "columns", [&](Deserializer::List &list, idx_t i) { + list.ReadObject([&](Deserializer &object) { data[i].Deserialize(object, row_count); }); + }); +} + +void DataChunk::Slice(const SelectionVector &sel_vector, idx_t count_p) { + this->count = count_p; + SelCache merge_cache; + for (idx_t c = 0; c < ColumnCount(); c++) { + data[c].Slice(sel_vector, count_p, merge_cache); + } +} + +void DataChunk::Slice(DataChunk &other, const SelectionVector &sel, idx_t count_p, idx_t col_offset) { + D_ASSERT(other.ColumnCount() <= col_offset + ColumnCount()); + this->count = count_p; + SelCache merge_cache; + for (idx_t c = 0; c < other.ColumnCount(); c++) { + if (other.data[c].GetVectorType() == VectorType::DICTIONARY_VECTOR) { + // already a dictionary! merge the dictionaries + data[col_offset + c].Reference(other.data[c]); + data[col_offset + c].Slice(sel, count_p, merge_cache); + } else { + data[col_offset + c].Slice(other.data[c], sel, count_p); + } + } +} + +unsafe_unique_array DataChunk::ToUnifiedFormat() { + auto unified_data = make_unsafe_uniq_array(ColumnCount()); + for (idx_t col_idx = 0; col_idx < ColumnCount(); col_idx++) { + data[col_idx].ToUnifiedFormat(size(), unified_data[col_idx]); + } + return unified_data; +} + +void DataChunk::Hash(Vector &result) { + D_ASSERT(result.GetType().id() == LogicalType::HASH); + VectorOperations::Hash(data[0], result, size()); + for (idx_t i = 1; i < ColumnCount(); i++) { + VectorOperations::CombineHash(result, data[i], size()); + } +} + +void DataChunk::Hash(vector &column_ids, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalType::HASH); + D_ASSERT(!column_ids.empty()); + + VectorOperations::Hash(data[column_ids[0]], result, size()); + for (idx_t i = 1; i < column_ids.size(); i++) { + VectorOperations::CombineHash(result, data[column_ids[i]], size()); + } +} + +void DataChunk::Verify() { +#ifdef DEBUG + D_ASSERT(size() <= capacity); + + // verify that all vectors in this chunk have the chunk selection vector + for (idx_t i = 0; i < ColumnCount(); i++) { + data[i].Verify(size()); + } + + if (!ColumnCount()) { + // don't try to round-trip dummy data chunks with no data + // e.g., these exist in queries like 'SELECT distinct(col0, col1) FROM tbl', where we have groups, but no + // payload so the payload will be such an empty data chunk + return; + } + + // verify that we can round-trip chunk serialization + MemoryStream mem_stream; + BinarySerializer serializer(mem_stream); + + serializer.Begin(); + Serialize(serializer); + serializer.End(); + + mem_stream.Rewind(); + + BinaryDeserializer deserializer(mem_stream); + DataChunk new_chunk; + + deserializer.Begin(); + new_chunk.Deserialize(deserializer); + deserializer.End(); + + D_ASSERT(size() == new_chunk.size()); +#endif +} + +void DataChunk::Print() const { + Printer::Print(ToString()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/date.cpp b/src/duckdb/src/common/types/date.cpp new file mode 100644 index 00000000..5ab61e3b --- /dev/null +++ b/src/duckdb/src/common/types/date.cpp @@ -0,0 +1,622 @@ +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/limits.hpp" + +#include +#include +#include + +namespace duckdb { + +static_assert(sizeof(date_t) == sizeof(int32_t), "date_t was padded"); + +const char *Date::PINF = "infinity"; // NOLINT +const char *Date::NINF = "-infinity"; // NOLINT +const char *Date::EPOCH = "epoch"; // NOLINT + +const string_t Date::MONTH_NAMES_ABBREVIATED[] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; +const string_t Date::MONTH_NAMES[] = {"January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December"}; +const string_t Date::DAY_NAMES[] = {"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}; +const string_t Date::DAY_NAMES_ABBREVIATED[] = {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"}; + +const int32_t Date::NORMAL_DAYS[] = {0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; +const int32_t Date::CUMULATIVE_DAYS[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; +const int32_t Date::LEAP_DAYS[] = {0, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; +const int32_t Date::CUMULATIVE_LEAP_DAYS[] = {0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366}; +const int8_t Date::MONTH_PER_DAY_OF_YEAR[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}; +const int8_t Date::LEAP_MONTH_PER_DAY_OF_YEAR[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}; +const int32_t Date::CUMULATIVE_YEAR_DAYS[] = { + 0, 365, 730, 1096, 1461, 1826, 2191, 2557, 2922, 3287, 3652, 4018, 4383, 4748, + 5113, 5479, 5844, 6209, 6574, 6940, 7305, 7670, 8035, 8401, 8766, 9131, 9496, 9862, + 10227, 10592, 10957, 11323, 11688, 12053, 12418, 12784, 13149, 13514, 13879, 14245, 14610, 14975, + 15340, 15706, 16071, 16436, 16801, 17167, 17532, 17897, 18262, 18628, 18993, 19358, 19723, 20089, + 20454, 20819, 21184, 21550, 21915, 22280, 22645, 23011, 23376, 23741, 24106, 24472, 24837, 25202, + 25567, 25933, 26298, 26663, 27028, 27394, 27759, 28124, 28489, 28855, 29220, 29585, 29950, 30316, + 30681, 31046, 31411, 31777, 32142, 32507, 32872, 33238, 33603, 33968, 34333, 34699, 35064, 35429, + 35794, 36160, 36525, 36890, 37255, 37621, 37986, 38351, 38716, 39082, 39447, 39812, 40177, 40543, + 40908, 41273, 41638, 42004, 42369, 42734, 43099, 43465, 43830, 44195, 44560, 44926, 45291, 45656, + 46021, 46387, 46752, 47117, 47482, 47847, 48212, 48577, 48942, 49308, 49673, 50038, 50403, 50769, + 51134, 51499, 51864, 52230, 52595, 52960, 53325, 53691, 54056, 54421, 54786, 55152, 55517, 55882, + 56247, 56613, 56978, 57343, 57708, 58074, 58439, 58804, 59169, 59535, 59900, 60265, 60630, 60996, + 61361, 61726, 62091, 62457, 62822, 63187, 63552, 63918, 64283, 64648, 65013, 65379, 65744, 66109, + 66474, 66840, 67205, 67570, 67935, 68301, 68666, 69031, 69396, 69762, 70127, 70492, 70857, 71223, + 71588, 71953, 72318, 72684, 73049, 73414, 73779, 74145, 74510, 74875, 75240, 75606, 75971, 76336, + 76701, 77067, 77432, 77797, 78162, 78528, 78893, 79258, 79623, 79989, 80354, 80719, 81084, 81450, + 81815, 82180, 82545, 82911, 83276, 83641, 84006, 84371, 84736, 85101, 85466, 85832, 86197, 86562, + 86927, 87293, 87658, 88023, 88388, 88754, 89119, 89484, 89849, 90215, 90580, 90945, 91310, 91676, + 92041, 92406, 92771, 93137, 93502, 93867, 94232, 94598, 94963, 95328, 95693, 96059, 96424, 96789, + 97154, 97520, 97885, 98250, 98615, 98981, 99346, 99711, 100076, 100442, 100807, 101172, 101537, 101903, + 102268, 102633, 102998, 103364, 103729, 104094, 104459, 104825, 105190, 105555, 105920, 106286, 106651, 107016, + 107381, 107747, 108112, 108477, 108842, 109208, 109573, 109938, 110303, 110669, 111034, 111399, 111764, 112130, + 112495, 112860, 113225, 113591, 113956, 114321, 114686, 115052, 115417, 115782, 116147, 116513, 116878, 117243, + 117608, 117974, 118339, 118704, 119069, 119435, 119800, 120165, 120530, 120895, 121260, 121625, 121990, 122356, + 122721, 123086, 123451, 123817, 124182, 124547, 124912, 125278, 125643, 126008, 126373, 126739, 127104, 127469, + 127834, 128200, 128565, 128930, 129295, 129661, 130026, 130391, 130756, 131122, 131487, 131852, 132217, 132583, + 132948, 133313, 133678, 134044, 134409, 134774, 135139, 135505, 135870, 136235, 136600, 136966, 137331, 137696, + 138061, 138427, 138792, 139157, 139522, 139888, 140253, 140618, 140983, 141349, 141714, 142079, 142444, 142810, + 143175, 143540, 143905, 144271, 144636, 145001, 145366, 145732, 146097}; + +void Date::ExtractYearOffset(int32_t &n, int32_t &year, int32_t &year_offset) { + year = Date::EPOCH_YEAR; + // first we normalize n to be in the year range [1970, 2370] + // since leap years repeat every 400 years, we can safely normalize just by "shifting" the CumulativeYearDays array + while (n < 0) { + n += Date::DAYS_PER_YEAR_INTERVAL; + year -= Date::YEAR_INTERVAL; + } + while (n >= Date::DAYS_PER_YEAR_INTERVAL) { + n -= Date::DAYS_PER_YEAR_INTERVAL; + year += Date::YEAR_INTERVAL; + } + // interpolation search + // we can find an upper bound of the year by assuming each year has 365 days + year_offset = n / 365; + // because of leap years we might be off by a little bit: compensate by decrementing the year offset until we find + // our year + while (n < Date::CUMULATIVE_YEAR_DAYS[year_offset]) { + year_offset--; + D_ASSERT(year_offset >= 0); + } + year += year_offset; + D_ASSERT(n >= Date::CUMULATIVE_YEAR_DAYS[year_offset]); +} + +void Date::Convert(date_t d, int32_t &year, int32_t &month, int32_t &day) { + auto n = d.days; + int32_t year_offset; + Date::ExtractYearOffset(n, year, year_offset); + + day = n - Date::CUMULATIVE_YEAR_DAYS[year_offset]; + D_ASSERT(day >= 0 && day <= 365); + + bool is_leap_year = (Date::CUMULATIVE_YEAR_DAYS[year_offset + 1] - Date::CUMULATIVE_YEAR_DAYS[year_offset]) == 366; + if (is_leap_year) { + month = Date::LEAP_MONTH_PER_DAY_OF_YEAR[day]; + day -= Date::CUMULATIVE_LEAP_DAYS[month - 1]; + } else { + month = Date::MONTH_PER_DAY_OF_YEAR[day]; + day -= Date::CUMULATIVE_DAYS[month - 1]; + } + day++; + D_ASSERT(day > 0 && day <= (is_leap_year ? Date::LEAP_DAYS[month] : Date::NORMAL_DAYS[month])); + D_ASSERT(month > 0 && month <= 12); +} + +bool Date::TryFromDate(int32_t year, int32_t month, int32_t day, date_t &result) { + int32_t n = 0; + if (!Date::IsValid(year, month, day)) { + return false; + } + n += Date::IsLeapYear(year) ? Date::CUMULATIVE_LEAP_DAYS[month - 1] : Date::CUMULATIVE_DAYS[month - 1]; + n += day - 1; + if (year < 1970) { + int32_t diff_from_base = 1970 - year; + int32_t year_index = 400 - (diff_from_base % 400); + int32_t fractions = diff_from_base / 400; + n += Date::CUMULATIVE_YEAR_DAYS[year_index]; + n -= Date::DAYS_PER_YEAR_INTERVAL; + n -= fractions * Date::DAYS_PER_YEAR_INTERVAL; + } else if (year >= 2370) { + int32_t diff_from_base = year - 2370; + int32_t year_index = diff_from_base % 400; + int32_t fractions = diff_from_base / 400; + n += Date::CUMULATIVE_YEAR_DAYS[year_index]; + n += Date::DAYS_PER_YEAR_INTERVAL; + n += fractions * Date::DAYS_PER_YEAR_INTERVAL; + } else { + n += Date::CUMULATIVE_YEAR_DAYS[year - 1970]; + } +#ifdef DEBUG + int32_t y, m, d; + Date::Convert(date_t(n), y, m, d); + D_ASSERT(year == y); + D_ASSERT(month == m); + D_ASSERT(day == d); +#endif + result = date_t(n); + return true; +} + +date_t Date::FromDate(int32_t year, int32_t month, int32_t day) { + date_t result; + if (!Date::TryFromDate(year, month, day, result)) { + throw ConversionException("Date out of range: %d-%d-%d", year, month, day); + } + return result; +} + +bool Date::ParseDoubleDigit(const char *buf, idx_t len, idx_t &pos, int32_t &result) { + if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { + result = buf[pos++] - '0'; + if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { + result = (buf[pos++] - '0') + result * 10; + } + return true; + } + return false; +} + +static bool TryConvertDateSpecial(const char *buf, idx_t len, idx_t &pos, const char *special) { + auto p = pos; + for (; p < len && *special; ++p) { + const auto s = *special++; + if (!s || StringUtil::CharacterToLower(buf[p]) != s) { + return false; + } + } + if (*special) { + return false; + } + pos = p; + return true; +} + +bool Date::TryConvertDate(const char *buf, idx_t len, idx_t &pos, date_t &result, bool &special, bool strict) { + special = false; + pos = 0; + if (len == 0) { + return false; + } + + int32_t day = 0; + int32_t month = -1; + int32_t year = 0; + bool yearneg = false; + int sep; + + // skip leading spaces + while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { + pos++; + } + + if (pos >= len) { + return false; + } + if (buf[pos] == '-') { + yearneg = true; + pos++; + if (pos >= len) { + return false; + } + } + if (!StringUtil::CharacterIsDigit(buf[pos])) { + // Check for special values + if (TryConvertDateSpecial(buf, len, pos, PINF)) { + result = yearneg ? date_t::ninfinity() : date_t::infinity(); + } else if (TryConvertDateSpecial(buf, len, pos, EPOCH)) { + result = date_t::epoch(); + } else { + return false; + } + // skip trailing spaces - parsing must be strict here + while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { + pos++; + } + special = true; + return pos == len; + } + // first parse the year + for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++) { + if (year >= 100000000) { + return false; + } + year = (buf[pos] - '0') + year * 10; + } + if (yearneg) { + year = -year; + } + + if (pos >= len) { + return false; + } + + // fetch the separator + sep = buf[pos++]; + if (sep != ' ' && sep != '-' && sep != '/' && sep != '\\') { + // invalid separator + return false; + } + + // parse the month + if (!Date::ParseDoubleDigit(buf, len, pos, month)) { + return false; + } + + if (pos >= len) { + return false; + } + + if (buf[pos++] != sep) { + return false; + } + + if (pos >= len) { + return false; + } + + // now parse the day + if (!Date::ParseDoubleDigit(buf, len, pos, day)) { + return false; + } + + // check for an optional trailing " (BC)"" + if (len - pos >= 5 && StringUtil::CharacterIsSpace(buf[pos]) && buf[pos + 1] == '(' && + StringUtil::CharacterToLower(buf[pos + 2]) == 'b' && StringUtil::CharacterToLower(buf[pos + 3]) == 'c' && + buf[pos + 4] == ')') { + if (yearneg || year == 0) { + return false; + } + year = -year + 1; + pos += 5; + } + + // in strict mode, check remaining string for non-space characters + if (strict) { + // skip trailing spaces + while (pos < len && StringUtil::CharacterIsSpace((unsigned char)buf[pos])) { + pos++; + } + // check position. if end was not reached, non-space chars remaining + if (pos < len) { + return false; + } + } else { + // in non-strict mode, check for any direct trailing digits + if (pos < len && StringUtil::CharacterIsDigit((unsigned char)buf[pos])) { + return false; + } + } + + return Date::TryFromDate(year, month, day, result); +} + +string Date::ConversionError(const string &str) { + return StringUtil::Format("date field value out of range: \"%s\", " + "expected format is (YYYY-MM-DD)", + str); +} + +string Date::ConversionError(string_t str) { + return ConversionError(str.GetString()); +} + +date_t Date::FromCString(const char *buf, idx_t len, bool strict) { + date_t result; + idx_t pos; + bool special = false; + if (!TryConvertDate(buf, len, pos, result, special, strict)) { + throw ConversionException(ConversionError(string(buf, len))); + } + return result; +} + +date_t Date::FromString(const string &str, bool strict) { + return Date::FromCString(str.c_str(), str.size(), strict); +} + +string Date::ToString(date_t date) { + // PG displays temporal infinities in lowercase, + // but numerics in Titlecase. + if (date == date_t::infinity()) { + return PINF; + } else if (date == date_t::ninfinity()) { + return NINF; + } + int32_t date_units[3]; + idx_t year_length; + bool add_bc; + Date::Convert(date, date_units[0], date_units[1], date_units[2]); + + auto length = DateToStringCast::Length(date_units, year_length, add_bc); + auto buffer = make_unsafe_uniq_array(length); + DateToStringCast::Format(buffer.get(), date_units, year_length, add_bc); + return string(buffer.get(), length); +} + +string Date::Format(int32_t year, int32_t month, int32_t day) { + return ToString(Date::FromDate(year, month, day)); +} + +bool Date::IsLeapYear(int32_t year) { + return year % 4 == 0 && (year % 100 != 0 || year % 400 == 0); +} + +bool Date::IsValid(int32_t year, int32_t month, int32_t day) { + if (month < 1 || month > 12) { + return false; + } + if (day < 1) { + return false; + } + if (year <= DATE_MIN_YEAR) { + if (year < DATE_MIN_YEAR) { + return false; + } else if (year == DATE_MIN_YEAR) { + if (month < DATE_MIN_MONTH || (month == DATE_MIN_MONTH && day < DATE_MIN_DAY)) { + return false; + } + } + } + if (year >= DATE_MAX_YEAR) { + if (year > DATE_MAX_YEAR) { + return false; + } else if (year == DATE_MAX_YEAR) { + if (month > DATE_MAX_MONTH || (month == DATE_MAX_MONTH && day > DATE_MAX_DAY)) { + return false; + } + } + } + return Date::IsLeapYear(year) ? day <= Date::LEAP_DAYS[month] : day <= Date::NORMAL_DAYS[month]; +} + +int32_t Date::MonthDays(int32_t year, int32_t month) { + D_ASSERT(month >= 1 && month <= 12); + return Date::IsLeapYear(year) ? Date::LEAP_DAYS[month] : Date::NORMAL_DAYS[month]; +} + +date_t Date::EpochDaysToDate(int32_t epoch) { + return (date_t)epoch; +} + +int32_t Date::EpochDays(date_t date) { + return date.days; +} + +date_t Date::EpochToDate(int64_t epoch) { + return date_t(epoch / Interval::SECS_PER_DAY); +} + +int64_t Date::Epoch(date_t date) { + return ((int64_t)date.days) * Interval::SECS_PER_DAY; +} + +int64_t Date::EpochNanoseconds(date_t date) { + int64_t result; + if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY * 1000, + result)) { + throw ConversionException("Could not convert DATE (%s) to nanoseconds", Date::ToString(date)); + } + return result; +} + +int64_t Date::EpochMicroseconds(date_t date) { + int64_t result; + if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY, result)) { + throw ConversionException("Could not convert DATE (%s) to microseconds", Date::ToString(date)); + } + return result; +} + +int64_t Date::EpochMilliseconds(date_t date) { + int64_t result; + const auto MILLIS_PER_DAY = Interval::MICROS_PER_DAY / Interval::MICROS_PER_MSEC; + if (!TryMultiplyOperator::Operation(date.days, MILLIS_PER_DAY, result)) { + throw ConversionException("Could not convert DATE (%s) to milliseconds", Date::ToString(date)); + } + return result; +} + +int32_t Date::ExtractYear(date_t d, int32_t *last_year) { + auto n = d.days; + // cached look up: check if year of this date is the same as the last one we looked up + // note that this only works for years in the range [1970, 2370] + if (n >= Date::CUMULATIVE_YEAR_DAYS[*last_year] && n < Date::CUMULATIVE_YEAR_DAYS[*last_year + 1]) { + return Date::EPOCH_YEAR + *last_year; + } + int32_t year; + Date::ExtractYearOffset(n, year, *last_year); + return year; +} + +int32_t Date::ExtractYear(timestamp_t ts, int32_t *last_year) { + return Date::ExtractYear(Timestamp::GetDate(ts), last_year); +} + +int32_t Date::ExtractYear(date_t d) { + int32_t year, year_offset; + Date::ExtractYearOffset(d.days, year, year_offset); + return year; +} + +int32_t Date::ExtractMonth(date_t date) { + int32_t out_year, out_month, out_day; + Date::Convert(date, out_year, out_month, out_day); + return out_month; +} + +int32_t Date::ExtractDay(date_t date) { + int32_t out_year, out_month, out_day; + Date::Convert(date, out_year, out_month, out_day); + return out_day; +} + +int32_t Date::ExtractDayOfTheYear(date_t date) { + int32_t year, year_offset; + Date::ExtractYearOffset(date.days, year, year_offset); + return date.days - Date::CUMULATIVE_YEAR_DAYS[year_offset] + 1; +} + +int64_t Date::ExtractJulianDay(date_t date) { + // Julian Day 0 is (-4713, 11, 24) in the proleptic Gregorian calendar. + static const int64_t JULIAN_EPOCH = -2440588; + return date.days - JULIAN_EPOCH; +} + +int32_t Date::ExtractISODayOfTheWeek(date_t date) { + // date of 0 is 1970-01-01, which was a Thursday (4) + // -7 = 4 + // -6 = 5 + // -5 = 6 + // -4 = 7 + // -3 = 1 + // -2 = 2 + // -1 = 3 + // 0 = 4 + // 1 = 5 + // 2 = 6 + // 3 = 7 + // 4 = 1 + // 5 = 2 + // 6 = 3 + // 7 = 4 + if (date.days < 0) { + // negative date: start off at 4 and cycle downwards + return (7 - ((-int64_t(date.days) + 3) % 7)); + } else { + // positive date: start off at 4 and cycle upwards + return ((int64_t(date.days) + 3) % 7) + 1; + } +} + +template +static T PythonDivMod(const T &x, const T &y, T &r) { + // D_ASSERT(y > 0); + T quo = x / y; + r = x - quo * y; + if (r < 0) { + --quo; + r += y; + } + // D_ASSERT(0 <= r && r < y); + return quo; +} + +static date_t GetISOWeekOne(int32_t year) { + const auto first_day = Date::FromDate(year, 1, 1); /* ord of 1/1 */ + /* 0 if 1/1 is a Monday, 1 if a Tue, etc. */ + const auto first_weekday = Date::ExtractISODayOfTheWeek(first_day) - 1; + /* ordinal of closest Monday at or before 1/1 */ + auto week1_monday = first_day - first_weekday; + + if (first_weekday > 3) { /* if 1/1 was Fri, Sat, Sun */ + week1_monday += 7; + } + + return week1_monday; +} + +static int32_t GetISOYearWeek(const date_t date, int32_t &year) { + int32_t month, day; + Date::Convert(date, year, month, day); + auto week1_monday = GetISOWeekOne(year); + auto week = PythonDivMod((date.days - week1_monday.days), 7, day); + if (week < 0) { + week1_monday = GetISOWeekOne(--year); + week = PythonDivMod((date.days - week1_monday.days), 7, day); + } else if (week >= 52 && date >= GetISOWeekOne(year + 1)) { + ++year; + week = 0; + } + + return week + 1; +} + +void Date::ExtractISOYearWeek(date_t date, int32_t &year, int32_t &week) { + week = GetISOYearWeek(date, year); +} + +int32_t Date::ExtractISOWeekNumber(date_t date) { + int32_t year, week; + ExtractISOYearWeek(date, year, week); + return week; +} + +int32_t Date::ExtractISOYearNumber(date_t date) { + int32_t year, week; + ExtractISOYearWeek(date, year, week); + return year; +} + +int32_t Date::ExtractWeekNumberRegular(date_t date, bool monday_first) { + int32_t year, month, day; + Date::Convert(date, year, month, day); + month -= 1; + day -= 1; + // get the day of the year + auto day_of_the_year = + (Date::IsLeapYear(year) ? Date::CUMULATIVE_LEAP_DAYS[month] : Date::CUMULATIVE_DAYS[month]) + day; + // now figure out the first monday or sunday of the year + // what day is January 1st? + auto day_of_jan_first = Date::ExtractISODayOfTheWeek(Date::FromDate(year, 1, 1)); + // monday = 1, sunday = 7 + int32_t first_week_start; + if (monday_first) { + // have to find next "1" + if (day_of_jan_first == 1) { + // jan 1 is monday: starts immediately + first_week_start = 0; + } else { + // jan 1 is not monday: count days until next monday + first_week_start = 8 - day_of_jan_first; + } + } else { + first_week_start = 7 - day_of_jan_first; + } + if (day_of_the_year < first_week_start) { + // day occurs before first week starts: week 0 + return 0; + } + return ((day_of_the_year - first_week_start) / 7) + 1; +} + +// Returns the date of the monday of the current week. +date_t Date::GetMondayOfCurrentWeek(date_t date) { + int32_t dotw = Date::ExtractISODayOfTheWeek(date); + return date - (dotw - 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/decimal.cpp b/src/duckdb/src/common/types/decimal.cpp new file mode 100644 index 00000000..323cec4e --- /dev/null +++ b/src/duckdb/src/common/types/decimal.cpp @@ -0,0 +1,33 @@ +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +template +string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { + auto len = DecimalToString::DecimalLength(value, width, scale); + auto data = make_unsafe_uniq_array(len + 1); + DecimalToString::FormatDecimal(value, width, scale, data.get(), len); + return string(data.get(), len); +} + +string Decimal::ToString(int16_t value, uint8_t width, uint8_t scale) { + return TemplatedDecimalToString(value, width, scale); +} + +string Decimal::ToString(int32_t value, uint8_t width, uint8_t scale) { + return TemplatedDecimalToString(value, width, scale); +} + +string Decimal::ToString(int64_t value, uint8_t width, uint8_t scale) { + return TemplatedDecimalToString(value, width, scale); +} + +string Decimal::ToString(hugeint_t value, uint8_t width, uint8_t scale) { + auto len = HugeintToStringCast::DecimalLength(value, width, scale); + auto data = make_unsafe_uniq_array(len + 1); + HugeintToStringCast::FormatDecimal(value, width, scale, data.get(), len); + return string(data.get(), len); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/hash.cpp b/src/duckdb/src/common/types/hash.cpp new file mode 100644 index 00000000..63e0d662 --- /dev/null +++ b/src/duckdb/src/common/types/hash.cpp @@ -0,0 +1,139 @@ +#include "duckdb/common/types/hash.hpp" + +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/interval.hpp" + +#include +#include + +namespace duckdb { + +template <> +hash_t Hash(uint64_t val) { + return murmurhash64(val); +} + +template <> +hash_t Hash(int64_t val) { + return murmurhash64((uint64_t)val); +} + +template <> +hash_t Hash(hugeint_t val) { + return murmurhash64(val.lower) ^ murmurhash64(val.upper); +} + +template +struct FloatingPointEqualityTransform { + static void OP(T &val) { + if (val == (T)0.0) { + // Turn negative zero into positive zero + val = (T)0.0; + } else if (std::isnan(val)) { + val = std::numeric_limits::quiet_NaN(); + } + } +}; + +template <> +hash_t Hash(float val) { + static_assert(sizeof(float) == sizeof(uint32_t), ""); + FloatingPointEqualityTransform::OP(val); + uint32_t uval = Load(const_data_ptr_cast(&val)); + return murmurhash64(uval); +} + +template <> +hash_t Hash(double val) { + static_assert(sizeof(double) == sizeof(uint64_t), ""); + FloatingPointEqualityTransform::OP(val); + uint64_t uval = Load(const_data_ptr_cast(&val)); + return murmurhash64(uval); +} + +template <> +hash_t Hash(interval_t val) { + return Hash(val.days) ^ Hash(val.months) ^ Hash(val.micros); +} + +template <> +hash_t Hash(const char *str) { + return Hash(str, strlen(str)); +} + +template <> +hash_t Hash(string_t val) { + return Hash(val.GetData(), val.GetSize()); +} + +template <> +hash_t Hash(char *val) { + return Hash(val); +} + +// MIT License +// Copyright (c) 2018-2021 Martin Ankerl +// https://github.com/martinus/robin-hood-hashing/blob/3.11.5/LICENSE +hash_t HashBytes(void *ptr, size_t len) noexcept { + static constexpr uint64_t M = UINT64_C(0xc6a4a7935bd1e995); + static constexpr uint64_t SEED = UINT64_C(0xe17a1465); + static constexpr unsigned int R = 47; + + auto const *const data64 = static_cast(ptr); + uint64_t h = SEED ^ (len * M); + + size_t const n_blocks = len / 8; + for (size_t i = 0; i < n_blocks; ++i) { + auto k = Load(reinterpret_cast(data64 + i)); + + k *= M; + k ^= k >> R; + k *= M; + + h ^= k; + h *= M; + } + + auto const *const data8 = reinterpret_cast(data64 + n_blocks); + switch (len & 7U) { + case 7: + h ^= static_cast(data8[6]) << 48U; + DUCKDB_EXPLICIT_FALLTHROUGH; + case 6: + h ^= static_cast(data8[5]) << 40U; + DUCKDB_EXPLICIT_FALLTHROUGH; + case 5: + h ^= static_cast(data8[4]) << 32U; + DUCKDB_EXPLICIT_FALLTHROUGH; + case 4: + h ^= static_cast(data8[3]) << 24U; + DUCKDB_EXPLICIT_FALLTHROUGH; + case 3: + h ^= static_cast(data8[2]) << 16U; + DUCKDB_EXPLICIT_FALLTHROUGH; + case 2: + h ^= static_cast(data8[1]) << 8U; + DUCKDB_EXPLICIT_FALLTHROUGH; + case 1: + h ^= static_cast(data8[0]); + h *= M; + DUCKDB_EXPLICIT_FALLTHROUGH; + default: + break; + } + h ^= h >> R; + h *= M; + h ^= h >> R; + return static_cast(h); +} + +hash_t Hash(const char *val, size_t size) { + return HashBytes((void *)val, size); +} + +hash_t Hash(uint8_t *val, size_t size) { + return HashBytes((void *)val, size); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/hugeint.cpp b/src/duckdb/src/common/types/hugeint.cpp new file mode 100644 index 00000000..f778c4e0 --- /dev/null +++ b/src/duckdb/src/common/types/hugeint.cpp @@ -0,0 +1,810 @@ +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/windows_undefs.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/operator/cast_operators.hpp" + +#include +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// String Conversion +//===--------------------------------------------------------------------===// +const hugeint_t Hugeint::POWERS_OF_TEN[] { + hugeint_t(1), + hugeint_t(10), + hugeint_t(100), + hugeint_t(1000), + hugeint_t(10000), + hugeint_t(100000), + hugeint_t(1000000), + hugeint_t(10000000), + hugeint_t(100000000), + hugeint_t(1000000000), + hugeint_t(10000000000), + hugeint_t(100000000000), + hugeint_t(1000000000000), + hugeint_t(10000000000000), + hugeint_t(100000000000000), + hugeint_t(1000000000000000), + hugeint_t(10000000000000000), + hugeint_t(100000000000000000), + hugeint_t(1000000000000000000), + hugeint_t(1000000000000000000) * hugeint_t(10), + hugeint_t(1000000000000000000) * hugeint_t(100), + hugeint_t(1000000000000000000) * hugeint_t(1000), + hugeint_t(1000000000000000000) * hugeint_t(10000), + hugeint_t(1000000000000000000) * hugeint_t(100000), + hugeint_t(1000000000000000000) * hugeint_t(1000000), + hugeint_t(1000000000000000000) * hugeint_t(10000000), + hugeint_t(1000000000000000000) * hugeint_t(100000000), + hugeint_t(1000000000000000000) * hugeint_t(1000000000), + hugeint_t(1000000000000000000) * hugeint_t(10000000000), + hugeint_t(1000000000000000000) * hugeint_t(100000000000), + hugeint_t(1000000000000000000) * hugeint_t(1000000000000), + hugeint_t(1000000000000000000) * hugeint_t(10000000000000), + hugeint_t(1000000000000000000) * hugeint_t(100000000000000), + hugeint_t(1000000000000000000) * hugeint_t(1000000000000000), + hugeint_t(1000000000000000000) * hugeint_t(10000000000000000), + hugeint_t(1000000000000000000) * hugeint_t(100000000000000000), + hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000), + hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000) * hugeint_t(10), + hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000) * hugeint_t(100)}; + +static uint8_t PositiveHugeintHighestBit(hugeint_t bits) { + uint8_t out = 0; + if (bits.upper) { + out = 64; + uint64_t up = bits.upper; + while (up) { + up >>= 1; + out++; + } + } else { + uint64_t low = bits.lower; + while (low) { + low >>= 1; + out++; + } + } + return out; +} + +static bool PositiveHugeintIsBitSet(hugeint_t lhs, uint8_t bit_position) { + if (bit_position < 64) { + return lhs.lower & (uint64_t(1) << uint64_t(bit_position)); + } else { + return lhs.upper & (uint64_t(1) << uint64_t(bit_position - 64)); + } +} + +hugeint_t PositiveHugeintLeftShift(hugeint_t lhs, uint32_t amount) { + D_ASSERT(amount > 0 && amount < 64); + hugeint_t result; + result.lower = lhs.lower << amount; + result.upper = (lhs.upper << amount) + (lhs.lower >> (64 - amount)); + return result; +} + +hugeint_t Hugeint::DivModPositive(hugeint_t lhs, uint64_t rhs, uint64_t &remainder) { + D_ASSERT(lhs.upper >= 0); + // DivMod code adapted from: + // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp + + // initialize the result and remainder to 0 + hugeint_t div_result; + div_result.lower = 0; + div_result.upper = 0; + remainder = 0; + + uint8_t highest_bit_set = PositiveHugeintHighestBit(lhs); + // now iterate over the amount of bits that are set in the LHS + for (uint8_t x = highest_bit_set; x > 0; x--) { + // left-shift the current result and remainder by 1 + div_result = PositiveHugeintLeftShift(div_result, 1); + remainder <<= 1; + // we get the value of the bit at position X, where position 0 is the least-significant bit + if (PositiveHugeintIsBitSet(lhs, x - 1)) { + // increment the remainder + remainder++; + } + if (remainder >= rhs) { + // the remainder has passed the division multiplier: add one to the divide result + remainder -= rhs; + div_result.lower++; + if (div_result.lower == 0) { + // overflow + div_result.upper++; + } + } + } + return div_result; +} + +string Hugeint::ToString(hugeint_t input) { + uint64_t remainder; + string result; + bool negative = input.upper < 0; + if (negative) { + NegateInPlace(input); + } + while (true) { + if (!input.lower && !input.upper) { + break; + } + input = Hugeint::DivModPositive(input, 10, remainder); + result = string(1, '0' + remainder) + result; // NOLINT + } + if (result.empty()) { + // value is zero + return "0"; + } + return negative ? "-" + result : result; +} + +//===--------------------------------------------------------------------===// +// Multiply +//===--------------------------------------------------------------------===// +bool Hugeint::TryMultiply(hugeint_t lhs, hugeint_t rhs, hugeint_t &result) { + bool lhs_negative = lhs.upper < 0; + bool rhs_negative = rhs.upper < 0; + if (lhs_negative) { + NegateInPlace(lhs); + } + if (rhs_negative) { + NegateInPlace(rhs); + } +#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) + __uint128_t left = __uint128_t(lhs.lower) + (__uint128_t(lhs.upper) << 64); + __uint128_t right = __uint128_t(rhs.lower) + (__uint128_t(rhs.upper) << 64); + __uint128_t result_i128; + if (__builtin_mul_overflow(left, right, &result_i128)) { + return false; + } + uint64_t upper = uint64_t(result_i128 >> 64); + if (upper & 0x8000000000000000) { + return false; + } + result.upper = int64_t(upper); + result.lower = uint64_t(result_i128 & 0xffffffffffffffff); +#else + // Multiply code adapted from: + // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp + + // split values into 4 32-bit parts + uint64_t top[4] = {uint64_t(lhs.upper) >> 32, uint64_t(lhs.upper) & 0xffffffff, lhs.lower >> 32, + lhs.lower & 0xffffffff}; + uint64_t bottom[4] = {uint64_t(rhs.upper) >> 32, uint64_t(rhs.upper) & 0xffffffff, rhs.lower >> 32, + rhs.lower & 0xffffffff}; + uint64_t products[4][4]; + + // multiply each component of the values + for (auto x = 0; x < 4; x++) { + for (auto y = 0; y < 4; y++) { + products[x][y] = top[x] * bottom[y]; + } + } + + // if any of these products are set to a non-zero value, there is always an overflow + if (products[0][0] || products[0][1] || products[0][2] || products[1][0] || products[2][0] || products[1][1]) { + return false; + } + // if the high bits of any of these are set, there is always an overflow + if ((products[0][3] & 0xffffffff80000000) || (products[1][2] & 0xffffffff80000000) || + (products[2][1] & 0xffffffff80000000) || (products[3][0] & 0xffffffff80000000)) { + return false; + } + + // otherwise we merge the result of the different products together in-order + + // first row + uint64_t fourth32 = (products[3][3] & 0xffffffff); + uint64_t third32 = (products[3][2] & 0xffffffff) + (products[3][3] >> 32); + uint64_t second32 = (products[3][1] & 0xffffffff) + (products[3][2] >> 32); + uint64_t first32 = (products[3][0] & 0xffffffff) + (products[3][1] >> 32); + + // second row + third32 += (products[2][3] & 0xffffffff); + second32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); + first32 += (products[2][1] & 0xffffffff) + (products[2][2] >> 32); + + // third row + second32 += (products[1][3] & 0xffffffff); + first32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); + + // fourth row + first32 += (products[0][3] & 0xffffffff); + + // move carry to next digit + third32 += fourth32 >> 32; + second32 += third32 >> 32; + first32 += second32 >> 32; + + // check if the combination of the different products resulted in an overflow + if (first32 & 0xffffff80000000) { + return false; + } + + // remove carry from current digit + fourth32 &= 0xffffffff; + third32 &= 0xffffffff; + second32 &= 0xffffffff; + first32 &= 0xffffffff; + + // combine components + result.lower = (third32 << 32) | fourth32; + result.upper = (first32 << 32) | second32; +#endif + if (lhs_negative ^ rhs_negative) { + NegateInPlace(result); + } + return true; +} + +hugeint_t Hugeint::Multiply(hugeint_t lhs, hugeint_t rhs) { + hugeint_t result; + if (!TryMultiply(lhs, rhs, result)) { + throw OutOfRangeException("Overflow in HUGEINT multiplication!"); + } + return result; +} + +//===--------------------------------------------------------------------===// +// Divide +//===--------------------------------------------------------------------===// +hugeint_t Hugeint::DivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &remainder) { + // division by zero not allowed + D_ASSERT(!(rhs.upper == 0 && rhs.lower == 0)); + + bool lhs_negative = lhs.upper < 0; + bool rhs_negative = rhs.upper < 0; + if (lhs_negative) { + Hugeint::NegateInPlace(lhs); + } + if (rhs_negative) { + Hugeint::NegateInPlace(rhs); + } + // DivMod code adapted from: + // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp + + // initialize the result and remainder to 0 + hugeint_t div_result; + div_result.lower = 0; + div_result.upper = 0; + remainder.lower = 0; + remainder.upper = 0; + + uint8_t highest_bit_set = PositiveHugeintHighestBit(lhs); + // now iterate over the amount of bits that are set in the LHS + for (uint8_t x = highest_bit_set; x > 0; x--) { + // left-shift the current result and remainder by 1 + div_result = PositiveHugeintLeftShift(div_result, 1); + remainder = PositiveHugeintLeftShift(remainder, 1); + + // we get the value of the bit at position X, where position 0 is the least-significant bit + if (PositiveHugeintIsBitSet(lhs, x - 1)) { + // increment the remainder + Hugeint::AddInPlace(remainder, 1); + } + if (Hugeint::GreaterThanEquals(remainder, rhs)) { + // the remainder has passed the division multiplier: add one to the divide result + remainder = Hugeint::Subtract(remainder, rhs); + Hugeint::AddInPlace(div_result, 1); + } + } + if (lhs_negative ^ rhs_negative) { + Hugeint::NegateInPlace(div_result); + } + if (lhs_negative) { + Hugeint::NegateInPlace(remainder); + } + return div_result; +} + +hugeint_t Hugeint::Divide(hugeint_t lhs, hugeint_t rhs) { + hugeint_t remainder; + return Hugeint::DivMod(lhs, rhs, remainder); +} + +hugeint_t Hugeint::Modulo(hugeint_t lhs, hugeint_t rhs) { + hugeint_t remainder; + Hugeint::DivMod(lhs, rhs, remainder); + return remainder; +} + +//===--------------------------------------------------------------------===// +// Add/Subtract +//===--------------------------------------------------------------------===// +bool Hugeint::AddInPlace(hugeint_t &lhs, hugeint_t rhs) { + int overflow = lhs.lower + rhs.lower < lhs.lower; + if (rhs.upper >= 0) { + // RHS is positive: check for overflow + if (lhs.upper > (std::numeric_limits::max() - rhs.upper - overflow)) { + return false; + } + lhs.upper = lhs.upper + overflow + rhs.upper; + } else { + // RHS is negative: check for underflow + if (lhs.upper < std::numeric_limits::min() - rhs.upper - overflow) { + return false; + } + lhs.upper = lhs.upper + (overflow + rhs.upper); + } + lhs.lower += rhs.lower; + if (lhs.upper == std::numeric_limits::min() && lhs.lower == 0) { + return false; + } + return true; +} + +bool Hugeint::SubtractInPlace(hugeint_t &lhs, hugeint_t rhs) { + // underflow + int underflow = lhs.lower - rhs.lower > lhs.lower; + if (rhs.upper >= 0) { + // RHS is positive: check for underflow + if (lhs.upper < (std::numeric_limits::min() + rhs.upper + underflow)) { + return false; + } + lhs.upper = (lhs.upper - rhs.upper) - underflow; + } else { + // RHS is negative: check for overflow + if (lhs.upper > std::numeric_limits::min() && + lhs.upper - 1 >= (std::numeric_limits::max() + rhs.upper + underflow)) { + return false; + } + lhs.upper = lhs.upper - (rhs.upper + underflow); + } + lhs.lower -= rhs.lower; + if (lhs.upper == std::numeric_limits::min() && lhs.lower == 0) { + return false; + } + return true; +} + +hugeint_t Hugeint::Add(hugeint_t lhs, hugeint_t rhs) { + if (!AddInPlace(lhs, rhs)) { + throw OutOfRangeException("Overflow in HUGEINT addition"); + } + return lhs; +} + +hugeint_t Hugeint::Subtract(hugeint_t lhs, hugeint_t rhs) { + if (!SubtractInPlace(lhs, rhs)) { + throw OutOfRangeException("Underflow in HUGEINT addition"); + } + return lhs; +} + +//===--------------------------------------------------------------------===// +// Hugeint Cast/Conversion +//===--------------------------------------------------------------------===// +template +bool HugeintTryCastInteger(hugeint_t input, DST &result) { + switch (input.upper) { + case 0: + // positive number: check if the positive number is in range + if (input.lower <= uint64_t(NumericLimits::Maximum())) { + result = DST(input.lower); + return true; + } + break; + case -1: + if (!SIGNED) { + return false; + } + // negative number: check if the negative number is in range + if (input.lower >= NumericLimits::Maximum() - uint64_t(NumericLimits::Maximum())) { + result = -DST(NumericLimits::Maximum() - input.lower) - 1; + return true; + } + break; + default: + break; + } + return false; +} + +template <> +bool Hugeint::TryCast(hugeint_t input, int8_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, int16_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, int32_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, int64_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, uint8_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, uint16_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, uint32_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, uint64_t &result) { + return HugeintTryCastInteger(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, hugeint_t &result) { + result = input; + return true; +} + +template <> +bool Hugeint::TryCast(hugeint_t input, float &result) { + double dbl_result; + Hugeint::TryCast(input, dbl_result); + result = (float)dbl_result; + return true; +} + +template +bool CastBigintToFloating(hugeint_t input, REAL_T &result) { + switch (input.upper) { + case -1: + // special case for upper = -1 to avoid rounding issues in small negative numbers + result = -REAL_T(NumericLimits::Maximum() - input.lower) - 1; + break; + default: + result = REAL_T(input.lower) + REAL_T(input.upper) * REAL_T(NumericLimits::Maximum()); + break; + } + return true; +} + +template <> +bool Hugeint::TryCast(hugeint_t input, double &result) { + return CastBigintToFloating(input, result); +} + +template <> +bool Hugeint::TryCast(hugeint_t input, long double &result) { + return CastBigintToFloating(input, result); +} + +template +hugeint_t HugeintConvertInteger(DST input) { + hugeint_t result; + result.lower = (uint64_t)input; + result.upper = (input < 0) * -1; + return result; +} + +template <> +bool Hugeint::TryConvert(int8_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} + +template <> +bool Hugeint::TryConvert(const char *value, hugeint_t &result) { + auto len = strlen(value); + string_t string_val(value, len); + return TryCast::Operation(string_val, result, true); +} + +template <> +bool Hugeint::TryConvert(int16_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} + +template <> +bool Hugeint::TryConvert(int32_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} + +template <> +bool Hugeint::TryConvert(int64_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} +template <> +bool Hugeint::TryConvert(uint8_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} +template <> +bool Hugeint::TryConvert(uint16_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} +template <> +bool Hugeint::TryConvert(uint32_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} +template <> +bool Hugeint::TryConvert(uint64_t value, hugeint_t &result) { + result = HugeintConvertInteger(value); + return true; +} + +template <> +bool Hugeint::TryConvert(hugeint_t value, hugeint_t &result) { + result = value; + return true; +} + +template <> +bool Hugeint::TryConvert(float value, hugeint_t &result) { + return Hugeint::TryConvert(double(value), result); +} + +template +bool ConvertFloatingToBigint(REAL_T value, hugeint_t &result) { + if (!Value::IsFinite(value)) { + return false; + } + if (value <= -170141183460469231731687303715884105728.0 || value >= 170141183460469231731687303715884105727.0) { + return false; + } + bool negative = value < 0; + if (negative) { + value = -value; + } + result.lower = (uint64_t)fmod(value, REAL_T(NumericLimits::Maximum())); + result.upper = (uint64_t)(value / REAL_T(NumericLimits::Maximum())); + if (negative) { + Hugeint::NegateInPlace(result); + } + return true; +} + +template <> +bool Hugeint::TryConvert(double value, hugeint_t &result) { + return ConvertFloatingToBigint(value, result); +} + +template <> +bool Hugeint::TryConvert(long double value, hugeint_t &result) { + return ConvertFloatingToBigint(value, result); +} + +//===--------------------------------------------------------------------===// +// hugeint_t operators +//===--------------------------------------------------------------------===// +hugeint_t::hugeint_t(int64_t value) { + auto result = Hugeint::Convert(value); + this->lower = result.lower; + this->upper = result.upper; +} + +bool hugeint_t::operator==(const hugeint_t &rhs) const { + return Hugeint::Equals(*this, rhs); +} + +bool hugeint_t::operator!=(const hugeint_t &rhs) const { + return Hugeint::NotEquals(*this, rhs); +} + +bool hugeint_t::operator<(const hugeint_t &rhs) const { + return Hugeint::LessThan(*this, rhs); +} + +bool hugeint_t::operator<=(const hugeint_t &rhs) const { + return Hugeint::LessThanEquals(*this, rhs); +} + +bool hugeint_t::operator>(const hugeint_t &rhs) const { + return Hugeint::GreaterThan(*this, rhs); +} + +bool hugeint_t::operator>=(const hugeint_t &rhs) const { + return Hugeint::GreaterThanEquals(*this, rhs); +} + +hugeint_t hugeint_t::operator+(const hugeint_t &rhs) const { + return Hugeint::Add(*this, rhs); +} + +hugeint_t hugeint_t::operator-(const hugeint_t &rhs) const { + return Hugeint::Subtract(*this, rhs); +} + +hugeint_t hugeint_t::operator*(const hugeint_t &rhs) const { + return Hugeint::Multiply(*this, rhs); +} + +hugeint_t hugeint_t::operator/(const hugeint_t &rhs) const { + return Hugeint::Divide(*this, rhs); +} + +hugeint_t hugeint_t::operator%(const hugeint_t &rhs) const { + return Hugeint::Modulo(*this, rhs); +} + +hugeint_t hugeint_t::operator-() const { + return Hugeint::Negate(*this); +} + +hugeint_t hugeint_t::operator>>(const hugeint_t &rhs) const { + hugeint_t result; + uint64_t shift = rhs.lower; + if (rhs.upper != 0 || shift >= 128) { + return hugeint_t(0); + } else if (shift == 0) { + return *this; + } else if (shift == 64) { + result.upper = (upper < 0) ? -1 : 0; + result.lower = upper; + } else if (shift < 64) { + // perform lower shift in unsigned integer, and mask away the most significant bit + result.lower = (uint64_t(upper) << (64 - shift)) | (lower >> shift); + result.upper = upper >> shift; + } else { + D_ASSERT(shift < 128); + result.lower = upper >> (shift - 64); + result.upper = (upper < 0) ? -1 : 0; + } + return result; +} + +hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { + if (upper < 0) { + return hugeint_t(0); + } + hugeint_t result; + uint64_t shift = rhs.lower; + if (rhs.upper != 0 || shift >= 128) { + return hugeint_t(0); + } else if (shift == 64) { + result.upper = lower; + result.lower = 0; + } else if (shift == 0) { + return *this; + } else if (shift < 64) { + // perform upper shift in unsigned integer, and mask away the most significant bit + uint64_t upper_shift = ((uint64_t(upper) << shift) + (lower >> (64 - shift))) & 0x7FFFFFFFFFFFFFFF; + result.lower = lower << shift; + result.upper = upper_shift; + } else { + D_ASSERT(shift < 128); + result.lower = 0; + result.upper = (lower << (shift - 64)) & 0x7FFFFFFFFFFFFFFF; + } + return result; +} + +hugeint_t hugeint_t::operator&(const hugeint_t &rhs) const { + hugeint_t result; + result.lower = lower & rhs.lower; + result.upper = upper & rhs.upper; + return result; +} + +hugeint_t hugeint_t::operator|(const hugeint_t &rhs) const { + hugeint_t result; + result.lower = lower | rhs.lower; + result.upper = upper | rhs.upper; + return result; +} + +hugeint_t hugeint_t::operator^(const hugeint_t &rhs) const { + hugeint_t result; + result.lower = lower ^ rhs.lower; + result.upper = upper ^ rhs.upper; + return result; +} + +hugeint_t hugeint_t::operator~() const { + hugeint_t result; + result.lower = ~lower; + result.upper = ~upper; + return result; +} + +hugeint_t &hugeint_t::operator+=(const hugeint_t &rhs) { + Hugeint::AddInPlace(*this, rhs); + return *this; +} +hugeint_t &hugeint_t::operator-=(const hugeint_t &rhs) { + Hugeint::SubtractInPlace(*this, rhs); + return *this; +} +hugeint_t &hugeint_t::operator*=(const hugeint_t &rhs) { + *this = Hugeint::Multiply(*this, rhs); + return *this; +} +hugeint_t &hugeint_t::operator/=(const hugeint_t &rhs) { + *this = Hugeint::Divide(*this, rhs); + return *this; +} +hugeint_t &hugeint_t::operator%=(const hugeint_t &rhs) { + *this = Hugeint::Modulo(*this, rhs); + return *this; +} +hugeint_t &hugeint_t::operator>>=(const hugeint_t &rhs) { + *this = *this >> rhs; + return *this; +} +hugeint_t &hugeint_t::operator<<=(const hugeint_t &rhs) { + *this = *this << rhs; + return *this; +} +hugeint_t &hugeint_t::operator&=(const hugeint_t &rhs) { + lower &= rhs.lower; + upper &= rhs.upper; + return *this; +} +hugeint_t &hugeint_t::operator|=(const hugeint_t &rhs) { + lower |= rhs.lower; + upper |= rhs.upper; + return *this; +} +hugeint_t &hugeint_t::operator^=(const hugeint_t &rhs) { + lower ^= rhs.lower; + upper ^= rhs.upper; + return *this; +} + +bool hugeint_t::operator!() const { + return *this == 0; +} + +hugeint_t::operator bool() const { + return *this != 0; +} + +template +static T NarrowCast(const hugeint_t &input) { + // NarrowCast is supposed to truncate (take lower) + return static_cast(input.lower); +} + +hugeint_t::operator uint8_t() const { + return NarrowCast(*this); +} +hugeint_t::operator uint16_t() const { + return NarrowCast(*this); +} +hugeint_t::operator uint32_t() const { + return NarrowCast(*this); +} +hugeint_t::operator uint64_t() const { + return NarrowCast(*this); +} +hugeint_t::operator int8_t() const { + return NarrowCast(*this); +} +hugeint_t::operator int16_t() const { + return NarrowCast(*this); +} +hugeint_t::operator int32_t() const { + return NarrowCast(*this); +} +hugeint_t::operator int64_t() const { + return NarrowCast(*this); +} + +string hugeint_t::ToString() const { + return Hugeint::ToString(*this); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/hyperloglog.cpp b/src/duckdb/src/common/types/hyperloglog.cpp new file mode 100644 index 00000000..4efcdc82 --- /dev/null +++ b/src/duckdb/src/common/types/hyperloglog.cpp @@ -0,0 +1,276 @@ +#include "duckdb/common/types/hyperloglog.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include "hyperloglog.hpp" + +namespace duckdb { + +HyperLogLog::HyperLogLog() : hll(nullptr) { + hll = duckdb_hll::hll_create(); + // Insert into a dense hll can be vectorized, sparse cannot, so we immediately convert + duckdb_hll::hllSparseToDense(hll); +} + +HyperLogLog::HyperLogLog(duckdb_hll::robj *hll) : hll(hll) { +} + +HyperLogLog::~HyperLogLog() { + duckdb_hll::hll_destroy(hll); +} + +void HyperLogLog::Add(data_ptr_t element, idx_t size) { + if (duckdb_hll::hll_add(hll, element, size) == HLL_C_ERR) { + throw InternalException("Could not add to HLL?"); + } +} + +idx_t HyperLogLog::Count() const { + // exception from size_t ban + size_t result; + + if (duckdb_hll::hll_count(hll, &result) != HLL_C_OK) { + throw InternalException("Could not count HLL?"); + } + return result; +} + +unique_ptr HyperLogLog::Merge(HyperLogLog &other) { + duckdb_hll::robj *hlls[2]; + hlls[0] = hll; + hlls[1] = other.hll; + auto new_hll = duckdb_hll::hll_merge(hlls, 2); + if (!new_hll) { + throw InternalException("Could not merge HLLs"); + } + return unique_ptr(new HyperLogLog(new_hll)); +} + +HyperLogLog *HyperLogLog::MergePointer(HyperLogLog &other) { + duckdb_hll::robj *hlls[2]; + hlls[0] = hll; + hlls[1] = other.hll; + auto new_hll = duckdb_hll::hll_merge(hlls, 2); + if (!new_hll) { + throw Exception("Could not merge HLLs"); + } + return new HyperLogLog(new_hll); +} + +unique_ptr HyperLogLog::Merge(HyperLogLog logs[], idx_t count) { + auto hlls_uptr = unique_ptr { + new duckdb_hll::robj *[count] + }; + auto hlls = hlls_uptr.get(); + for (idx_t i = 0; i < count; i++) { + hlls[i] = logs[i].hll; + } + auto new_hll = duckdb_hll::hll_merge(hlls, count); + if (!new_hll) { + throw InternalException("Could not merge HLLs"); + } + return unique_ptr(new HyperLogLog(new_hll)); +} + +idx_t HyperLogLog::GetSize() { + return duckdb_hll::get_size(); +} + +data_ptr_t HyperLogLog::GetPtr() const { + return data_ptr_cast((hll)->ptr); +} + +unique_ptr HyperLogLog::Copy() { + auto result = make_uniq(); + lock_guard guard(lock); + memcpy(result->GetPtr(), GetPtr(), GetSize()); + D_ASSERT(result->Count() == Count()); + return result; +} + +void HyperLogLog::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", HLLStorageType::UNCOMPRESSED); + serializer.WriteProperty(101, "data", GetPtr(), GetSize()); +} + +unique_ptr HyperLogLog::Deserialize(Deserializer &deserializer) { + auto result = make_uniq(); + auto storage_type = deserializer.ReadProperty(100, "type"); + switch (storage_type) { + case HLLStorageType::UNCOMPRESSED: + deserializer.ReadProperty(101, "data", result->GetPtr(), GetSize()); + break; + default: + throw SerializationException("Unknown HyperLogLog storage type!"); + } + return result; +} + +//===--------------------------------------------------------------------===// +// Vectorized HLL implementation +//===--------------------------------------------------------------------===// +//! Taken from https://nullprogram.com/blog/2018/07/31/ +template +inline uint64_t TemplatedHash(const T &elem) { + uint64_t x = elem; + x ^= x >> 30; + x *= UINT64_C(0xbf58476d1ce4e5b9); + x ^= x >> 27; + x *= UINT64_C(0x94d049bb133111eb); + x ^= x >> 31; + return x; +} + +template <> +inline uint64_t TemplatedHash(const hugeint_t &elem) { + return TemplatedHash(Load(const_data_ptr_cast(&elem.upper))) ^ + TemplatedHash(elem.lower); +} + +template +inline void CreateIntegerRecursive(const_data_ptr_t &data, uint64_t &x) { + x ^= (uint64_t)data[rest - 1] << ((rest - 1) * 8); + return CreateIntegerRecursive(data, x); +} + +template <> +inline void CreateIntegerRecursive<1>(const_data_ptr_t &data, uint64_t &x) { + x ^= (uint64_t)data[0]; +} + +inline uint64_t HashOtherSize(const_data_ptr_t &data, const idx_t &len) { + uint64_t x = 0; + switch (len & 7) { + case 7: + CreateIntegerRecursive<7>(data, x); + break; + case 6: + CreateIntegerRecursive<6>(data, x); + break; + case 5: + CreateIntegerRecursive<5>(data, x); + break; + case 4: + CreateIntegerRecursive<4>(data, x); + break; + case 3: + CreateIntegerRecursive<3>(data, x); + break; + case 2: + CreateIntegerRecursive<2>(data, x); + break; + case 1: + CreateIntegerRecursive<1>(data, x); + break; + case 0: + break; + } + return TemplatedHash(x); +} + +template <> +inline uint64_t TemplatedHash(const string_t &elem) { + auto data = const_data_ptr_cast(elem.GetData()); + const auto &len = elem.GetSize(); + uint64_t h = 0; + for (idx_t i = 0; i + sizeof(uint64_t) <= len; i += sizeof(uint64_t)) { + h ^= TemplatedHash(Load(data)); + data += sizeof(uint64_t); + } + switch (len & (sizeof(uint64_t) - 1)) { + case 4: + h ^= TemplatedHash(Load(data)); + break; + case 2: + h ^= TemplatedHash(Load(data)); + break; + case 1: + h ^= TemplatedHash(Load(data)); + break; + default: + h ^= HashOtherSize(data, len); + } + return h; +} + +template +void TemplatedComputeHashes(UnifiedVectorFormat &vdata, const idx_t &count, uint64_t hashes[]) { + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + if (vdata.validity.RowIsValid(idx)) { + hashes[i] = TemplatedHash(data[idx]); + } else { + hashes[i] = 0; + } + } +} + +static void ComputeHashes(UnifiedVectorFormat &vdata, const LogicalType &type, uint64_t hashes[], idx_t count) { + switch (type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + case PhysicalType::UINT8: + return TemplatedComputeHashes(vdata, count, hashes); + case PhysicalType::INT16: + case PhysicalType::UINT16: + return TemplatedComputeHashes(vdata, count, hashes); + case PhysicalType::INT32: + case PhysicalType::UINT32: + case PhysicalType::FLOAT: + return TemplatedComputeHashes(vdata, count, hashes); + case PhysicalType::INT64: + case PhysicalType::UINT64: + case PhysicalType::DOUBLE: + return TemplatedComputeHashes(vdata, count, hashes); + case PhysicalType::INT128: + case PhysicalType::INTERVAL: + static_assert(sizeof(hugeint_t) == sizeof(interval_t), "ComputeHashes assumes these are the same size!"); + return TemplatedComputeHashes(vdata, count, hashes); + case PhysicalType::VARCHAR: + return TemplatedComputeHashes(vdata, count, hashes); + default: + throw InternalException("Unimplemented type for HyperLogLog::ComputeHashes"); + } +} + +//! Taken from https://stackoverflow.com/a/72088344 +static inline uint8_t CountTrailingZeros(uint64_t &x) { + static constexpr const uint64_t DEBRUIJN = 0x03f79d71b4cb0a89; + static constexpr const uint8_t LOOKUP[] = {0, 47, 1, 56, 48, 27, 2, 60, 57, 49, 41, 37, 28, 16, 3, 61, + 54, 58, 35, 52, 50, 42, 21, 44, 38, 32, 29, 23, 17, 11, 4, 62, + 46, 55, 26, 59, 40, 36, 15, 53, 34, 51, 20, 43, 31, 22, 10, 45, + 25, 39, 14, 33, 19, 30, 9, 24, 13, 18, 8, 12, 7, 6, 5, 63}; + return LOOKUP[(DEBRUIJN * (x ^ (x - 1))) >> 58]; +} + +static inline void ComputeIndexAndCount(uint64_t &hash, uint8_t &prefix) { + uint64_t index = hash & ((1 << 12) - 1); /* Register index. */ + hash >>= 12; /* Remove bits used to address the register. */ + hash |= ((uint64_t)1 << (64 - 12)); /* Make sure the count will be <= Q+1. */ + + prefix = CountTrailingZeros(hash) + 1; /* Add 1 since we count the "00000...1" pattern. */ + hash = index; +} + +void HyperLogLog::ProcessEntries(UnifiedVectorFormat &vdata, const LogicalType &type, uint64_t hashes[], + uint8_t counts[], idx_t count) { + ComputeHashes(vdata, type, hashes, count); + for (idx_t i = 0; i < count; i++) { + ComputeIndexAndCount(hashes[i], counts[i]); + } +} + +void HyperLogLog::AddToLogs(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[], + HyperLogLog **logs[], const SelectionVector *log_sel) { + AddToLogsInternal(vdata, count, indices, counts, reinterpret_cast(logs), log_sel); +} + +void HyperLogLog::AddToLog(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[]) { + lock_guard guard(lock); + AddToSingleLogInternal(vdata, count, indices, counts, hll); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/interval.cpp b/src/duckdb/src/common/types/interval.cpp new file mode 100644 index 00000000..9fb09608 --- /dev/null +++ b/src/duckdb/src/common/types/interval.cpp @@ -0,0 +1,477 @@ +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/string_util.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +bool Interval::FromString(const string &str, interval_t &result) { + string error_message; + return Interval::FromCString(str.c_str(), str.size(), result, &error_message, false); +} + +template +void IntervalTryAddition(T &target, int64_t input, int64_t multiplier) { + int64_t addition; + if (!TryMultiplyOperator::Operation(input, multiplier, addition)) { + throw OutOfRangeException("interval value is out of range"); + } + T addition_base = Cast::Operation(addition); + if (!TryAddOperator::Operation(target, addition_base, target)) { + throw OutOfRangeException("interval value is out of range"); + } +} + +bool Interval::FromCString(const char *str, idx_t len, interval_t &result, string *error_message, bool strict) { + idx_t pos = 0; + idx_t start_pos; + bool negative; + bool found_any = false; + int64_t number; + DatePartSpecifier specifier; + string specifier_str; + + result.days = 0; + result.micros = 0; + result.months = 0; + + if (len == 0) { + return false; + } + + switch (str[pos]) { + case '@': + pos++; + goto standard_interval; + case 'P': + case 'p': + pos++; + goto posix_interval; + default: + goto standard_interval; + } +standard_interval: + // start parsing a standard interval (e.g. 2 years 3 months...) + for (; pos < len; pos++) { + char c = str[pos]; + if (c == ' ' || c == '\t' || c == '\n') { + // skip spaces + continue; + } else if (c >= '0' && c <= '9') { + // start parsing a positive number + negative = false; + goto interval_parse_number; + } else if (c == '-') { + // negative number + negative = true; + pos++; + goto interval_parse_number; + } else if (c == 'a' || c == 'A') { + // parse the word "ago" as the final specifier + goto interval_parse_ago; + } else { + // unrecognized character, expected a number or end of string + return false; + } + } + goto end_of_string; +interval_parse_number: + start_pos = pos; + for (; pos < len; pos++) { + char c = str[pos]; + if (c >= '0' && c <= '9') { + // the number continues + continue; + } else if (c == ':') { + // colon: we are parsing a time + goto interval_parse_time; + } else { + if (pos == start_pos) { + return false; + } + // finished the number, parse it from the string + string_t nr_string(str + start_pos, pos - start_pos); + number = Cast::Operation(nr_string); + if (negative) { + number = -number; + } + goto interval_parse_identifier; + } + } + goto end_of_string; +interval_parse_time : { + // parse the remainder of the time as a Time type + dtime_t time; + idx_t pos; + if (!Time::TryConvertTime(str + start_pos, len - start_pos, pos, time)) { + return false; + } + result.micros += time.micros; + found_any = true; + if (negative) { + result.micros = -result.micros; + } + goto end_of_string; +} +interval_parse_identifier: + for (; pos < len; pos++) { + char c = str[pos]; + if (c == ' ' || c == '\t' || c == '\n') { + // skip spaces at the start + continue; + } else { + break; + } + } + // now parse the identifier + start_pos = pos; + for (; pos < len; pos++) { + char c = str[pos]; + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) { + // keep parsing the string + continue; + } else { + break; + } + } + specifier_str = string(str + start_pos, pos - start_pos); + if (!TryGetDatePartSpecifier(specifier_str, specifier)) { + HandleCastError::AssignError(StringUtil::Format("extract specifier \"%s\" not recognized", specifier_str), + error_message); + return false; + } + // add the specifier to the interval + switch (specifier) { + case DatePartSpecifier::MILLENNIUM: + IntervalTryAddition(result.months, number, MONTHS_PER_MILLENIUM); + break; + case DatePartSpecifier::CENTURY: + IntervalTryAddition(result.months, number, MONTHS_PER_CENTURY); + break; + case DatePartSpecifier::DECADE: + IntervalTryAddition(result.months, number, MONTHS_PER_DECADE); + break; + case DatePartSpecifier::YEAR: + IntervalTryAddition(result.months, number, MONTHS_PER_YEAR); + break; + case DatePartSpecifier::QUARTER: + IntervalTryAddition(result.months, number, MONTHS_PER_QUARTER); + break; + case DatePartSpecifier::MONTH: + IntervalTryAddition(result.months, number, 1); + break; + case DatePartSpecifier::DAY: + IntervalTryAddition(result.days, number, 1); + break; + case DatePartSpecifier::WEEK: + IntervalTryAddition(result.days, number, DAYS_PER_WEEK); + break; + case DatePartSpecifier::MICROSECONDS: + IntervalTryAddition(result.micros, number, 1); + break; + case DatePartSpecifier::MILLISECONDS: + IntervalTryAddition(result.micros, number, MICROS_PER_MSEC); + break; + case DatePartSpecifier::SECOND: + IntervalTryAddition(result.micros, number, MICROS_PER_SEC); + break; + case DatePartSpecifier::MINUTE: + IntervalTryAddition(result.micros, number, MICROS_PER_MINUTE); + break; + case DatePartSpecifier::HOUR: + IntervalTryAddition(result.micros, number, MICROS_PER_HOUR); + break; + default: + HandleCastError::AssignError( + StringUtil::Format("extract specifier \"%s\" not supported for interval", specifier_str), error_message); + return false; + } + found_any = true; + goto standard_interval; +interval_parse_ago: + D_ASSERT(str[pos] == 'a' || str[pos] == 'A'); + // parse the "ago" string at the end of the interval + if (len - pos < 3) { + return false; + } + pos++; + if (!(str[pos] == 'g' || str[pos] == 'G')) { + return false; + } + pos++; + if (!(str[pos] == 'o' || str[pos] == 'O')) { + return false; + } + pos++; + // parse any trailing whitespace + for (; pos < len; pos++) { + char c = str[pos]; + if (c == ' ' || c == '\t' || c == '\n') { + continue; + } else { + return false; + } + } + // invert all the values + result.months = -result.months; + result.days = -result.days; + result.micros = -result.micros; + goto end_of_string; +end_of_string: + if (!found_any) { + // end of string and no identifiers were found: cannot convert empty interval + return false; + } + return true; +posix_interval: + return false; +} + +string Interval::ToString(const interval_t &interval) { + char buffer[70]; + idx_t length = IntervalToStringCast::Format(interval, buffer); + return string(buffer, length); +} + +int64_t Interval::GetMilli(const interval_t &val) { + int64_t milli_month, milli_day, milli; + if (!TryMultiplyOperator::Operation((int64_t)val.months, Interval::MICROS_PER_MONTH / 1000, milli_month)) { + throw ConversionException("Could not convert Interval to Milliseconds"); + } + if (!TryMultiplyOperator::Operation((int64_t)val.days, Interval::MICROS_PER_DAY / 1000, milli_day)) { + throw ConversionException("Could not convert Interval to Milliseconds"); + } + milli = val.micros / 1000; + if (!TryAddOperator::Operation(milli, milli_month, milli)) { + throw ConversionException("Could not convert Interval to Milliseconds"); + } + if (!TryAddOperator::Operation(milli, milli_day, milli)) { + throw ConversionException("Could not convert Interval to Milliseconds"); + } + return milli; +} + +int64_t Interval::GetMicro(const interval_t &val) { + int64_t micro_month, micro_day, micro_total; + micro_total = val.micros; + if (!TryMultiplyOperator::Operation((int64_t)val.months, MICROS_PER_MONTH, micro_month)) { + throw ConversionException("Could not convert Month to Microseconds"); + } + if (!TryMultiplyOperator::Operation((int64_t)val.days, MICROS_PER_DAY, micro_day)) { + throw ConversionException("Could not convert Day to Microseconds"); + } + if (!TryAddOperator::Operation(micro_total, micro_month, micro_total)) { + throw ConversionException("Could not convert Interval to Microseconds"); + } + if (!TryAddOperator::Operation(micro_total, micro_day, micro_total)) { + throw ConversionException("Could not convert Interval to Microseconds"); + } + + return micro_total; +} + +int64_t Interval::GetNanoseconds(const interval_t &val) { + int64_t nano; + const auto micro_total = GetMicro(val); + if (!TryMultiplyOperator::Operation(micro_total, NANOS_PER_MICRO, nano)) { + throw ConversionException("Could not convert Interval to Nanoseconds"); + } + + return nano; +} + +interval_t Interval::GetAge(timestamp_t timestamp_1, timestamp_t timestamp_2) { + D_ASSERT(Timestamp::IsFinite(timestamp_1) && Timestamp::IsFinite(timestamp_2)); + date_t date1, date2; + dtime_t time1, time2; + + Timestamp::Convert(timestamp_1, date1, time1); + Timestamp::Convert(timestamp_2, date2, time2); + + // and from date extract the years, months and days + int32_t year1, month1, day1; + int32_t year2, month2, day2; + Date::Convert(date1, year1, month1, day1); + Date::Convert(date2, year2, month2, day2); + // finally perform the differences + auto year_diff = year1 - year2; + auto month_diff = month1 - month2; + auto day_diff = day1 - day2; + + // and from time extract hours, minutes, seconds and milliseconds + int32_t hour1, min1, sec1, micros1; + int32_t hour2, min2, sec2, micros2; + Time::Convert(time1, hour1, min1, sec1, micros1); + Time::Convert(time2, hour2, min2, sec2, micros2); + // finally perform the differences + auto hour_diff = hour1 - hour2; + auto min_diff = min1 - min2; + auto sec_diff = sec1 - sec2; + auto micros_diff = micros1 - micros2; + + // flip sign if necessary + bool sign_flipped = false; + if (timestamp_1 < timestamp_2) { + year_diff = -year_diff; + month_diff = -month_diff; + day_diff = -day_diff; + hour_diff = -hour_diff; + min_diff = -min_diff; + sec_diff = -sec_diff; + micros_diff = -micros_diff; + sign_flipped = true; + } + // now propagate any negative field into the next higher field + while (micros_diff < 0) { + micros_diff += MICROS_PER_SEC; + sec_diff--; + } + while (sec_diff < 0) { + sec_diff += SECS_PER_MINUTE; + min_diff--; + } + while (min_diff < 0) { + min_diff += MINS_PER_HOUR; + hour_diff--; + } + while (hour_diff < 0) { + hour_diff += HOURS_PER_DAY; + day_diff--; + } + while (day_diff < 0) { + if (timestamp_1 < timestamp_2) { + day_diff += Date::IsLeapYear(year1) ? Date::LEAP_DAYS[month1] : Date::NORMAL_DAYS[month1]; + month_diff--; + } else { + day_diff += Date::IsLeapYear(year2) ? Date::LEAP_DAYS[month2] : Date::NORMAL_DAYS[month2]; + month_diff--; + } + } + while (month_diff < 0) { + month_diff += MONTHS_PER_YEAR; + year_diff--; + } + + // recover sign if necessary + if (sign_flipped) { + year_diff = -year_diff; + month_diff = -month_diff; + day_diff = -day_diff; + hour_diff = -hour_diff; + min_diff = -min_diff; + sec_diff = -sec_diff; + micros_diff = -micros_diff; + } + interval_t interval; + interval.months = year_diff * MONTHS_PER_YEAR + month_diff; + interval.days = day_diff; + interval.micros = Time::FromTime(hour_diff, min_diff, sec_diff, micros_diff).micros; + + return interval; +} + +interval_t Interval::GetDifference(timestamp_t timestamp_1, timestamp_t timestamp_2) { + if (!Timestamp::IsFinite(timestamp_1) || !Timestamp::IsFinite(timestamp_2)) { + throw InvalidInputException("Cannot subtract infinite timestamps"); + } + const auto us_1 = Timestamp::GetEpochMicroSeconds(timestamp_1); + const auto us_2 = Timestamp::GetEpochMicroSeconds(timestamp_2); + int64_t delta_us; + if (!TrySubtractOperator::Operation(us_1, us_2, delta_us)) { + throw ConversionException("Timestamp difference is out of bounds"); + } + return FromMicro(delta_us); +} + +interval_t Interval::FromMicro(int64_t delta_us) { + interval_t result; + result.months = 0; + result.days = delta_us / Interval::MICROS_PER_DAY; + result.micros = delta_us % Interval::MICROS_PER_DAY; + + return result; +} + +interval_t Interval::Invert(interval_t interval) { + interval.days = -interval.days; + interval.micros = -interval.micros; + interval.months = -interval.months; + return interval; +} + +date_t Interval::Add(date_t left, interval_t right) { + if (!Date::IsFinite(left)) { + return left; + } + date_t result; + if (right.months != 0) { + int32_t year, month, day; + Date::Convert(left, year, month, day); + int32_t year_diff = right.months / Interval::MONTHS_PER_YEAR; + year += year_diff; + month += right.months - year_diff * Interval::MONTHS_PER_YEAR; + if (month > Interval::MONTHS_PER_YEAR) { + year++; + month -= Interval::MONTHS_PER_YEAR; + } else if (month <= 0) { + year--; + month += Interval::MONTHS_PER_YEAR; + } + day = MinValue(day, Date::MonthDays(year, month)); + result = Date::FromDate(year, month, day); + } else { + result = left; + } + if (right.days != 0) { + if (!TryAddOperator::Operation(result.days, right.days, result.days)) { + throw OutOfRangeException("Date out of range"); + } + } + if (right.micros != 0) { + if (!TryAddOperator::Operation(result.days, int32_t(right.micros / Interval::MICROS_PER_DAY), result.days)) { + throw OutOfRangeException("Date out of range"); + } + } + if (!Date::IsFinite(result)) { + throw OutOfRangeException("Date out of range"); + } + return result; +} + +dtime_t Interval::Add(dtime_t left, interval_t right, date_t &date) { + int64_t diff = right.micros - ((right.micros / Interval::MICROS_PER_DAY) * Interval::MICROS_PER_DAY); + left += diff; + if (left.micros >= Interval::MICROS_PER_DAY) { + left.micros -= Interval::MICROS_PER_DAY; + date.days++; + } else if (left.micros < 0) { + left.micros += Interval::MICROS_PER_DAY; + date.days--; + } + return left; +} + +timestamp_t Interval::Add(timestamp_t left, interval_t right) { + if (!Timestamp::IsFinite(left)) { + return left; + } + date_t date; + dtime_t time; + Timestamp::Convert(left, date, time); + auto new_date = Interval::Add(date, right); + auto new_time = Interval::Add(time, right, new_date); + return Timestamp::FromDatetime(new_date, new_time); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/list_segment.cpp b/src/duckdb/src/common/types/list_segment.cpp new file mode 100644 index 00000000..de350b60 --- /dev/null +++ b/src/duckdb/src/common/types/list_segment.cpp @@ -0,0 +1,544 @@ +#include "duckdb/common/types/list_segment.hpp" + +namespace duckdb { + +// forward declarations +//===--------------------------------------------------------------------===// +// Primitives +//===--------------------------------------------------------------------===// +template +static idx_t GetAllocationSize(uint16_t capacity) { + return AlignValue(sizeof(ListSegment) + capacity * (sizeof(bool) + sizeof(T))); +} + +template +static data_ptr_t AllocatePrimitiveData(ArenaAllocator &allocator, uint16_t capacity) { + return allocator.Allocate(GetAllocationSize(capacity)); +} + +template +static T *GetPrimitiveData(ListSegment *segment) { + return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + segment->capacity * sizeof(bool)); +} + +template +static const T *GetPrimitiveData(const ListSegment *segment) { + return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + + segment->capacity * sizeof(bool)); +} + +//===--------------------------------------------------------------------===// +// Lists +//===--------------------------------------------------------------------===// +static idx_t GetAllocationSizeList(uint16_t capacity) { + return AlignValue(sizeof(ListSegment) + capacity * (sizeof(bool) + sizeof(uint64_t)) + sizeof(LinkedList)); +} + +static data_ptr_t AllocateListData(ArenaAllocator &allocator, uint16_t capacity) { + return allocator.Allocate(GetAllocationSizeList(capacity)); +} + +static uint64_t *GetListLengthData(ListSegment *segment) { + return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + + segment->capacity * sizeof(bool)); +} + +static const uint64_t *GetListLengthData(const ListSegment *segment) { + return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + + segment->capacity * sizeof(bool)); +} + +static const LinkedList *GetListChildData(const ListSegment *segment) { + return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + + segment->capacity * (sizeof(bool) + sizeof(uint64_t))); +} + +static LinkedList *GetListChildData(ListSegment *segment) { + return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + + segment->capacity * (sizeof(bool) + sizeof(uint64_t))); +} + +//===--------------------------------------------------------------------===// +// Structs +//===--------------------------------------------------------------------===// +static idx_t GetAllocationSizeStruct(uint16_t capacity, idx_t child_count) { + return AlignValue(sizeof(ListSegment) + capacity * sizeof(bool) + child_count * sizeof(ListSegment *)); +} + +static data_ptr_t AllocateStructData(ArenaAllocator &allocator, uint16_t capacity, idx_t child_count) { + return allocator.Allocate(GetAllocationSizeStruct(capacity, child_count)); +} + +static ListSegment **GetStructData(ListSegment *segment) { + return reinterpret_cast(data_ptr_cast(segment) + +sizeof(ListSegment) + + segment->capacity * sizeof(bool)); +} + +static const ListSegment *const *GetStructData(const ListSegment *segment) { + return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + + segment->capacity * sizeof(bool)); +} + +static bool *GetNullMask(ListSegment *segment) { + return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment)); +} + +static const bool *GetNullMask(const ListSegment *segment) { + return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment)); +} + +static uint16_t GetCapacityForNewSegment(uint16_t capacity) { + auto next_power_of_two = idx_t(capacity) * 2; + if (next_power_of_two >= NumericLimits::Maximum()) { + return capacity; + } + return uint16_t(next_power_of_two); +} + +//===--------------------------------------------------------------------===// +// Create +//===--------------------------------------------------------------------===// +template +static ListSegment *CreatePrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { + // allocate data and set the header + auto segment = (ListSegment *)AllocatePrimitiveData(allocator, capacity); + segment->capacity = capacity; + segment->count = 0; + segment->next = nullptr; + return segment; +} + +static ListSegment *CreateListSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { + // allocate data and set the header + auto segment = reinterpret_cast(AllocateListData(allocator, capacity)); + segment->capacity = capacity; + segment->count = 0; + segment->next = nullptr; + + // create an empty linked list for the child vector + auto linked_child_list = GetListChildData(segment); + LinkedList linked_list(0, nullptr, nullptr); + Store(linked_list, data_ptr_cast(linked_child_list)); + + return segment; +} + +static ListSegment *CreateStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, + uint16_t capacity) { + // allocate data and set header + auto segment = + reinterpret_cast(AllocateStructData(allocator, capacity, functions.child_functions.size())); + segment->capacity = capacity; + segment->count = 0; + segment->next = nullptr; + + // create a child ListSegment with exactly the same capacity for each child vector + auto child_segments = GetStructData(segment); + for (idx_t i = 0; i < functions.child_functions.size(); i++) { + auto child_function = functions.child_functions[i]; + auto child_segment = child_function.create_segment(child_function, allocator, capacity); + Store(child_segment, data_ptr_cast(child_segments + i)); + } + + return segment; +} + +static ListSegment *GetSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, + LinkedList &linked_list) { + ListSegment *segment; + + // determine segment + if (!linked_list.last_segment) { + // empty linked list, create the first (and last) segment + auto capacity = ListSegment::INITIAL_CAPACITY; + segment = functions.create_segment(functions, allocator, capacity); + linked_list.first_segment = segment; + linked_list.last_segment = segment; + + } else if (linked_list.last_segment->capacity == linked_list.last_segment->count) { + // the last segment of the linked list is full, create a new one and append it + auto capacity = GetCapacityForNewSegment(linked_list.last_segment->capacity); + segment = functions.create_segment(functions, allocator, capacity); + linked_list.last_segment->next = segment; + linked_list.last_segment = segment; + } else { + // the last segment of the linked list is not full, append the data to it + segment = linked_list.last_segment; + } + + D_ASSERT(segment); + return segment; +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +template +static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &, ListSegment *segment, + RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { + + auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); + + // write null validity + auto null_mask = GetNullMask(segment); + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[segment->count] = !valid; + + // write value + if (valid) { + auto segment_data = GetPrimitiveData(segment); + auto input_data_ptr = UnifiedVectorFormat::GetData(input_data.unified); + Store(input_data_ptr[sel_entry_idx], data_ptr_cast(segment_data + segment->count)); + } +} + +static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, + idx_t &entry_idx) { + + auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); + + // write null validity + auto null_mask = GetNullMask(segment); + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[segment->count] = !valid; + + // set the length of this string + auto str_length_data = GetListLengthData(segment); + uint64_t str_length = 0; + + // get the string + string_t str_entry; + if (valid) { + str_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; + str_length = str_entry.GetSize(); + } + + // we can reconstruct the offset from the length + Store(str_length, data_ptr_cast(str_length_data + segment->count)); + if (!valid) { + return; + } + + // write the characters to the linked list of child segments + auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); + for (char &c : str_entry.GetString()) { + auto child_segment = GetSegment(functions.child_functions.back(), allocator, child_segments); + auto data = GetPrimitiveData(child_segment); + data[child_segment->count] = c; + child_segment->count++; + child_segments.total_capacity++; + } + + // store the updated linked list + Store(child_segments, data_ptr_cast(GetListChildData(segment))); +} + +static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { + + auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); + + // write null validity + auto null_mask = GetNullMask(segment); + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[segment->count] = !valid; + + // set the length of this list + auto list_length_data = GetListLengthData(segment); + uint64_t list_length = 0; + + if (valid) { + // get list entry information + const auto &list_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; + list_length = list_entry.length; + + // loop over the child vector entries and recurse on them + auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); + D_ASSERT(functions.child_functions.size() == 1); + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + auto source_idx_child = list_entry.offset + child_idx; + functions.child_functions[0].AppendRow(allocator, child_segments, input_data.children.back(), + source_idx_child); + } + // store the updated linked list + Store(child_segments, data_ptr_cast(GetListChildData(segment))); + } + + Store(list_length, data_ptr_cast(list_length_data + segment->count)); +} + +static void WriteDataToStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { + + auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); + + // write null validity + auto null_mask = GetNullMask(segment); + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[segment->count] = !valid; + + // write value + D_ASSERT(input_data.children.size() == functions.child_functions.size()); + auto child_list = GetStructData(segment); + + // write the data of each of the children of the struct + for (idx_t i = 0; i < input_data.children.size(); i++) { + auto child_list_segment = Load(data_ptr_cast(child_list + i)); + auto &child_function = functions.child_functions[i]; + child_function.write_data(child_function, allocator, child_list_segment, input_data.children[i], entry_idx); + child_list_segment->count++; + } +} + +void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, + RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const { + + auto &write_data_to_segment = *this; + auto segment = GetSegment(write_data_to_segment, allocator, linked_list); + write_data_to_segment.write_data(write_data_to_segment, allocator, segment, input_data, entry_idx); + + linked_list.total_capacity++; + segment->count++; +} + +//===--------------------------------------------------------------------===// +// Read +//===--------------------------------------------------------------------===// +template +static void ReadDataFromPrimitiveSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector &result, + idx_t &total_count) { + + auto &aggr_vector_validity = FlatVector::Validity(result); + + // set NULLs + auto null_mask = GetNullMask(segment); + for (idx_t i = 0; i < segment->count; i++) { + if (null_mask[i]) { + aggr_vector_validity.SetInvalid(total_count + i); + } + } + + auto aggr_vector_data = FlatVector::GetData(result); + + // load values + for (idx_t i = 0; i < segment->count; i++) { + if (aggr_vector_validity.RowIsValid(total_count + i)) { + auto data = GetPrimitiveData(segment); + aggr_vector_data[total_count + i] = Load(const_data_ptr_cast(data + i)); + } + } +} + +static void ReadDataFromVarcharSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector &result, + idx_t &total_count) { + + auto &aggr_vector_validity = FlatVector::Validity(result); + + // set NULLs + auto null_mask = GetNullMask(segment); + for (idx_t i = 0; i < segment->count; i++) { + if (null_mask[i]) { + aggr_vector_validity.SetInvalid(total_count + i); + } + } + + // append all the child chars to one string + string str = ""; + auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); + while (linked_child_list.first_segment) { + auto child_segment = linked_child_list.first_segment; + auto data = GetPrimitiveData(child_segment); + str.append(data, child_segment->count); + linked_child_list.first_segment = child_segment->next; + } + linked_child_list.last_segment = nullptr; + + // use length and (reconstructed) offset to get the correct substrings + auto aggr_vector_data = FlatVector::GetData(result); + auto str_length_data = GetListLengthData(segment); + + // get the substrings and write them to the result vector + idx_t offset = 0; + for (idx_t i = 0; i < segment->count; i++) { + if (!null_mask[i]) { + auto str_length = Load(const_data_ptr_cast(str_length_data + i)); + auto substr = str.substr(offset, str_length); + auto str_t = StringVector::AddStringOrBlob(result, substr); + aggr_vector_data[total_count + i] = str_t; + offset += str_length; + } + } +} + +static void ReadDataFromListSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, + idx_t &total_count) { + + auto &aggr_vector_validity = FlatVector::Validity(result); + + // set NULLs + auto null_mask = GetNullMask(segment); + for (idx_t i = 0; i < segment->count; i++) { + if (null_mask[i]) { + aggr_vector_validity.SetInvalid(total_count + i); + } + } + + auto list_vector_data = FlatVector::GetData(result); + + // get the starting offset + idx_t offset = 0; + if (total_count != 0) { + offset = list_vector_data[total_count - 1].offset + list_vector_data[total_count - 1].length; + } + idx_t starting_offset = offset; + + // set length and offsets + auto list_length_data = GetListLengthData(segment); + for (idx_t i = 0; i < segment->count; i++) { + auto list_length = Load(const_data_ptr_cast(list_length_data + i)); + list_vector_data[total_count + i].length = list_length; + list_vector_data[total_count + i].offset = offset; + offset += list_length; + } + + auto &child_vector = ListVector::GetEntry(result); + auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); + ListVector::Reserve(result, offset); + + // recurse into the linked list of child values + D_ASSERT(functions.child_functions.size() == 1); + functions.child_functions[0].BuildListVector(linked_child_list, child_vector, starting_offset); + ListVector::SetListSize(result, offset); +} + +static void ReadDataFromStructSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, + idx_t &total_count) { + + auto &aggr_vector_validity = FlatVector::Validity(result); + + // set NULLs + auto null_mask = GetNullMask(segment); + for (idx_t i = 0; i < segment->count; i++) { + if (null_mask[i]) { + aggr_vector_validity.SetInvalid(total_count + i); + } + } + + auto &children = StructVector::GetEntries(result); + + // recurse into the child segments of each child of the struct + D_ASSERT(children.size() == functions.child_functions.size()); + auto struct_children = GetStructData(segment); + for (idx_t child_count = 0; child_count < children.size(); child_count++) { + auto struct_children_segment = Load(const_data_ptr_cast(struct_children + child_count)); + auto &child_function = functions.child_functions[child_count]; + child_function.read_data(child_function, struct_children_segment, *children[child_count], total_count); + } +} + +void ListSegmentFunctions::BuildListVector(const LinkedList &linked_list, Vector &result, + idx_t &initial_total_count) const { + auto &read_data_from_segment = *this; + idx_t total_count = initial_total_count; + auto segment = linked_list.first_segment; + while (segment) { + read_data_from_segment.read_data(read_data_from_segment, segment, result, total_count); + + total_count += segment->count; + segment = segment->next; + } +} + +//===--------------------------------------------------------------------===// +// Functions +//===--------------------------------------------------------------------===// +template +void SegmentPrimitiveFunction(ListSegmentFunctions &functions) { + functions.create_segment = CreatePrimitiveSegment; + functions.write_data = WriteDataToPrimitiveSegment; + functions.read_data = ReadDataFromPrimitiveSegment; +} + +void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type) { + + auto physical_type = type.InternalType(); + switch (physical_type) { + case PhysicalType::BIT: + case PhysicalType::BOOL: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::INT8: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::INT16: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::INT32: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::INT64: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::UINT8: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::UINT16: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::UINT32: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::UINT64: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::FLOAT: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::DOUBLE: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::INT128: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::INTERVAL: + SegmentPrimitiveFunction(functions); + break; + case PhysicalType::VARCHAR: { + functions.create_segment = CreateListSegment; + functions.write_data = WriteDataToVarcharSegment; + functions.read_data = ReadDataFromVarcharSegment; + + functions.child_functions.emplace_back(); + SegmentPrimitiveFunction(functions.child_functions.back()); + break; + } + case PhysicalType::LIST: { + functions.create_segment = CreateListSegment; + functions.write_data = WriteDataToListSegment; + functions.read_data = ReadDataFromListSegment; + + // recurse + functions.child_functions.emplace_back(); + GetSegmentDataFunctions(functions.child_functions.back(), ListType::GetChildType(type)); + break; + } + case PhysicalType::STRUCT: { + functions.create_segment = CreateStructSegment; + functions.write_data = WriteDataToStructSegment; + functions.read_data = ReadDataFromStructSegment; + + // recurse + auto child_types = StructType::GetChildTypes(type); + for (idx_t i = 0; i < child_types.size(); i++) { + functions.child_functions.emplace_back(); + GetSegmentDataFunctions(functions.child_functions.back(), child_types[i].second); + } + break; + } + default: + throw InternalException("LIST aggregate not yet implemented for " + type.ToString()); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp new file mode 100644 index 00000000..03fded4f --- /dev/null +++ b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp @@ -0,0 +1,436 @@ +#include "duckdb/common/types/row/partitioned_tuple_data.hpp" + +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/common/types/row/tuple_data_iterator.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, BufferManager &buffer_manager_p, + const TupleDataLayout &layout_p) + : type(type_p), buffer_manager(buffer_manager_p), layout(layout_p.Copy()), count(0), data_size(0), + allocators(make_shared()) { +} + +PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) + : type(other.type), buffer_manager(other.buffer_manager), layout(other.layout.Copy()) { +} + +PartitionedTupleData::~PartitionedTupleData() { +} + +const TupleDataLayout &PartitionedTupleData::GetLayout() const { + return layout; +} + +PartitionedTupleDataType PartitionedTupleData::GetType() const { + return type; +} + +void PartitionedTupleData::InitializeAppendState(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties) const { + state.partition_sel.Initialize(); + state.reverse_partition_sel.Initialize(); + + vector column_ids; + column_ids.reserve(layout.ColumnCount()); + for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { + column_ids.emplace_back(col_idx); + } + + InitializeAppendStateInternal(state, properties); +} + +void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, DataChunk &input, + const SelectionVector &append_sel, const idx_t append_count) { + TupleDataCollection::ToUnifiedFormat(state.chunk_state, input); + AppendUnified(state, input, append_sel, append_count); +} + +bool PartitionedTupleData::UseFixedSizeMap() const { + return MaxPartitionIndex() < PartitionedTupleDataAppendState::MAP_THRESHOLD; +} + +void PartitionedTupleData::AppendUnified(PartitionedTupleDataAppendState &state, DataChunk &input, + const SelectionVector &append_sel, const idx_t append_count) { + const idx_t actual_append_count = append_count == DConstants::INVALID_INDEX ? input.size() : append_count; + + // Compute partition indices and store them in state.partition_indices + ComputePartitionIndices(state, input); + + // Build the selection vector for the partitions + BuildPartitionSel(state, append_sel, actual_append_count); + + // Early out: check if everything belongs to a single partition + optional_idx partition_index; + if (UseFixedSizeMap()) { + if (state.fixed_partition_entries.size() == 1) { + partition_index = state.fixed_partition_entries.begin().GetKey(); + } + } else { + if (state.partition_entries.size() == 1) { + partition_index = state.partition_entries.begin()->first; + } + } + if (partition_index.IsValid()) { + auto &partition = *partitions[partition_index.GetIndex()]; + auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; + + const auto size_before = partition.SizeInBytes(); + partition.AppendUnified(partition_pin_state, state.chunk_state, input, append_sel, actual_append_count); + data_size += partition.SizeInBytes() - size_before; + } else { + // Compute the heap sizes for the whole chunk + if (!layout.AllConstant()) { + TupleDataCollection::ComputeHeapSizes(state.chunk_state, input, state.partition_sel, actual_append_count); + } + + // Build the buffer space + BuildBufferSpace(state); + + // Now scatter everything in one go + partitions[0]->Scatter(state.chunk_state, input, state.partition_sel, actual_append_count); + } + + count += actual_append_count; + Verify(); +} + +void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, TupleDataChunkState &input, + const idx_t append_count) { + // Compute partition indices and store them in state.partition_indices + ComputePartitionIndices(input.row_locations, append_count, state.partition_indices); + + // Build the selection vector for the partitions + BuildPartitionSel(state, *FlatVector::IncrementalSelectionVector(), append_count); + + // Early out: check if everything belongs to a single partition + optional_idx partition_index; + if (UseFixedSizeMap()) { + if (state.fixed_partition_entries.size() == 1) { + partition_index = state.fixed_partition_entries.begin().GetKey(); + } + } else { + if (state.partition_entries.size() == 1) { + partition_index = state.partition_entries.begin()->first; + } + } + + if (partition_index.IsValid()) { + auto &partition = *partitions[partition_index.GetIndex()]; + auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; + + state.chunk_state.heap_sizes.Reference(input.heap_sizes); + + const auto size_before = partition.SizeInBytes(); + partition.Build(partition_pin_state, state.chunk_state, 0, append_count); + data_size += partition.SizeInBytes() - size_before; + + partition.CopyRows(state.chunk_state, input, *FlatVector::IncrementalSelectionVector(), append_count); + } else { + // Build the buffer space + state.chunk_state.heap_sizes.Slice(input.heap_sizes, state.partition_sel, append_count); + state.chunk_state.heap_sizes.Flatten(append_count); + BuildBufferSpace(state); + + // Copy the rows + partitions[0]->CopyRows(state.chunk_state, input, state.partition_sel, append_count); + } + + count += append_count; + Verify(); +} + +// LCOV_EXCL_START +template +struct UnorderedMapGetter { + static inline const typename MAP_TYPE::key_type &GetKey(typename MAP_TYPE::iterator &iterator) { + return iterator->first; + } + + static inline const typename MAP_TYPE::key_type &GetKey(const typename MAP_TYPE::const_iterator &iterator) { + return iterator->first; + } + + static inline typename MAP_TYPE::mapped_type &GetValue(typename MAP_TYPE::iterator &iterator) { + return iterator->second; + } + + static inline const typename MAP_TYPE::mapped_type &GetValue(const typename MAP_TYPE::const_iterator &iterator) { + return iterator->second; + } +}; + +template +struct FixedSizeMapGetter { + static inline const idx_t &GetKey(fixed_size_map_iterator_t &iterator) { + return iterator.GetKey(); + } + + static inline const idx_t &GetKey(const const_fixed_size_map_iterator_t &iterator) { + return iterator.GetKey(); + } + + static inline T &GetValue(fixed_size_map_iterator_t &iterator) { + return iterator.GetValue(); + } + + static inline const T &GetValue(const const_fixed_size_map_iterator_t &iterator) { + return iterator.GetValue(); + } +}; +// LCOV_EXCL_STOP + +void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, + const idx_t append_count) { + if (UseFixedSizeMap()) { + BuildPartitionSel, FixedSizeMapGetter>( + state, state.fixed_partition_entries, append_sel, append_count); + } else { + BuildPartitionSel, UnorderedMapGetter>>( + state, state.partition_entries, append_sel, append_count); + } +} + +template +void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, MAP_TYPE &partition_entries, + const SelectionVector &append_sel, const idx_t append_count) { + const auto partition_indices = FlatVector::GetData(state.partition_indices); + partition_entries.clear(); + + switch (state.partition_indices.GetVectorType()) { + case VectorType::FLAT_VECTOR: + for (idx_t i = 0; i < append_count; i++) { + const auto index = append_sel.get_index(i); + const auto &partition_index = partition_indices[index]; + auto partition_entry = partition_entries.find(partition_index); + if (partition_entry == partition_entries.end()) { + partition_entries[partition_index] = list_entry_t(0, 1); + } else { + GETTER::GetValue(partition_entry).length++; + } + } + break; + case VectorType::CONSTANT_VECTOR: + partition_entries[partition_indices[0]] = list_entry_t(0, append_count); + break; + default: + throw InternalException("Unexpected VectorType in PartitionedTupleData::Append"); + } + + // Early out: check if everything belongs to a single partition + if (partition_entries.size() == 1) { + // This needs to be initialized, even if we go the short path here + for (idx_t i = 0; i < append_count; i++) { + const auto index = append_sel.get_index(i); + state.reverse_partition_sel[index] = i; + } + return; + } + + // Compute offsets from the counts + idx_t offset = 0; + for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { + auto &partition_entry = GETTER::GetValue(it); + partition_entry.offset = offset; + offset += partition_entry.length; + } + + // Now initialize a single selection vector that acts as a selection vector for every partition + auto &partition_sel = state.partition_sel; + auto &reverse_partition_sel = state.reverse_partition_sel; + for (idx_t i = 0; i < append_count; i++) { + const auto index = append_sel.get_index(i); + const auto &partition_index = partition_indices[index]; + auto &partition_offset = partition_entries[partition_index].offset; + reverse_partition_sel[index] = partition_offset; + partition_sel[partition_offset++] = index; + } +} + +void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state) { + if (UseFixedSizeMap()) { + BuildBufferSpace, FixedSizeMapGetter>( + state, state.fixed_partition_entries); + } else { + BuildBufferSpace, UnorderedMapGetter>>( + state, state.partition_entries); + } +} + +template +void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state, const MAP_TYPE &partition_entries) { + for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { + const auto &partition_index = GETTER::GetKey(it); + + // Partition, pin state for this partition index + auto &partition = *partitions[partition_index]; + auto &partition_pin_state = *state.partition_pin_states[partition_index]; + + // Length and offset for this partition + const auto &partition_entry = GETTER::GetValue(it); + const auto &partition_length = partition_entry.length; + const auto partition_offset = partition_entry.offset - partition_length; + + // Build out the buffer space for this partition + const auto size_before = partition.SizeInBytes(); + partition.Build(partition_pin_state, state.chunk_state, partition_offset, partition_length); + data_size += partition.SizeInBytes() - size_before; + } +} + +void PartitionedTupleData::FlushAppendState(PartitionedTupleDataAppendState &state) { + for (idx_t partition_index = 0; partition_index < partitions.size(); partition_index++) { + auto &partition = *partitions[partition_index]; + auto &partition_pin_state = *state.partition_pin_states[partition_index]; + partition.FinalizePinState(partition_pin_state); + } +} + +void PartitionedTupleData::Combine(PartitionedTupleData &other) { + if (other.Count() == 0) { + return; + } + + // Now combine the state's partitions into this + lock_guard guard(lock); + if (partitions.empty()) { + // This is the first merge, we just copy them over + partitions = std::move(other.partitions); + } else { + D_ASSERT(partitions.size() == other.partitions.size()); + // Combine the append state's partitions into this PartitionedTupleData + for (idx_t i = 0; i < other.partitions.size(); i++) { + partitions[i]->Combine(*other.partitions[i]); + } + } + this->count += other.count; + this->data_size += other.data_size; + Verify(); +} + +void PartitionedTupleData::Reset() { + for (auto &partition : partitions) { + partition->Reset(); + } + this->count = 0; + this->data_size = 0; + Verify(); +} + +void PartitionedTupleData::Repartition(PartitionedTupleData &new_partitioned_data) { + D_ASSERT(layout.GetTypes() == new_partitioned_data.layout.GetTypes()); + + if (partitions.size() == new_partitioned_data.partitions.size()) { + new_partitioned_data.Combine(*this); + return; + } + + PartitionedTupleDataAppendState append_state; + new_partitioned_data.InitializeAppendState(append_state); + + const auto reverse = RepartitionReverseOrder(); + const idx_t start_idx = reverse ? partitions.size() : 0; + const idx_t end_idx = reverse ? 0 : partitions.size(); + const int64_t update = reverse ? -1 : 1; + const int64_t adjustment = reverse ? -1 : 0; + + for (idx_t partition_idx = start_idx; partition_idx != end_idx; partition_idx += update) { + auto actual_partition_idx = partition_idx + adjustment; + auto &partition = *partitions[actual_partition_idx]; + + if (partition.Count() > 0) { + TupleDataChunkIterator iterator(partition, TupleDataPinProperties::DESTROY_AFTER_DONE, true); + auto &chunk_state = iterator.GetChunkState(); + do { + new_partitioned_data.Append(append_state, chunk_state, iterator.GetCurrentChunkCount()); + } while (iterator.Next()); + + RepartitionFinalizeStates(*this, new_partitioned_data, append_state, actual_partition_idx); + } + partitions[actual_partition_idx]->Reset(); + } + new_partitioned_data.FlushAppendState(append_state); + + count = 0; + data_size = 0; + + Verify(); +} + +void PartitionedTupleData::Unpin() { + for (auto &partition : partitions) { + partition->Unpin(); + } +} + +vector> &PartitionedTupleData::GetPartitions() { + return partitions; +} + +unique_ptr PartitionedTupleData::GetUnpartitioned() { + auto data_collection = std::move(partitions[0]); + partitions[0] = make_uniq(buffer_manager, layout); + + for (idx_t i = 1; i < partitions.size(); i++) { + data_collection->Combine(*partitions[i]); + } + count = 0; + data_size = 0; + + data_collection->Verify(); + Verify(); + + return data_collection; +} + +idx_t PartitionedTupleData::Count() const { + return count; +} + +idx_t PartitionedTupleData::SizeInBytes() const { + idx_t total_size = 0; + for (auto &partition : partitions) { + total_size += partition->SizeInBytes(); + } + return total_size; +} + +idx_t PartitionedTupleData::PartitionCount() const { + return partitions.size(); +} + +void PartitionedTupleData::Verify() const { +#ifdef DEBUG + idx_t total_count = 0; + idx_t total_size = 0; + for (auto &partition : partitions) { + partition->Verify(); + total_count += partition->Count(); + total_size += partition->SizeInBytes(); + } + D_ASSERT(total_count == this->count); + D_ASSERT(total_size == this->data_size); +#endif +} + +// LCOV_EXCL_START +string PartitionedTupleData::ToString() { + string result = + StringUtil::Format("PartitionedTupleData - [%llu Partitions, %llu Rows]\n", partitions.size(), Count()); + for (idx_t partition_idx = 0; partition_idx < partitions.size(); partition_idx++) { + result += StringUtil::Format("Partition %llu: ", partition_idx) + partitions[partition_idx]->ToString(); + } + return result; +} + +void PartitionedTupleData::Print() { + Printer::Print(ToString()); +} +// LCOV_EXCL_STOP + +void PartitionedTupleData::CreateAllocator() { + allocators->allocators.emplace_back(make_shared(buffer_manager, layout)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_data_collection.cpp b/src/duckdb/src/common/types/row/row_data_collection.cpp new file mode 100644 index 00000000..47a12efc --- /dev/null +++ b/src/duckdb/src/common/types/row/row_data_collection.cpp @@ -0,0 +1,141 @@ +#include "duckdb/common/types/row/row_data_collection.hpp" + +namespace duckdb { + +RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, + bool keep_pinned) + : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), + keep_pinned(keep_pinned) { + D_ASSERT(block_capacity * entry_size + entry_size > Storage::BLOCK_SIZE); +} + +idx_t RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, + vector &append_entries, idx_t remaining, idx_t entry_sizes[]) { + idx_t append_count = 0; + data_ptr_t dataptr; + if (entry_sizes) { + D_ASSERT(entry_size == 1); + // compute how many entries fit if entry size is variable + dataptr = handle.Ptr() + block.byte_offset; + for (idx_t i = 0; i < remaining; i++) { + if (block.byte_offset + entry_sizes[i] > block.capacity) { + if (block.count == 0 && append_count == 0 && entry_sizes[i] > block.capacity) { + // special case: single entry is bigger than block capacity + // resize current block to fit the entry, append it, and move to the next block + block.capacity = entry_sizes[i]; + buffer_manager.ReAllocate(block.block, block.capacity); + dataptr = handle.Ptr(); + append_count++; + block.byte_offset += entry_sizes[i]; + } + break; + } + append_count++; + block.byte_offset += entry_sizes[i]; + } + } else { + append_count = MinValue(remaining, block.capacity - block.count); + dataptr = handle.Ptr() + block.count * entry_size; + } + append_entries.emplace_back(dataptr, append_count); + block.count += append_count; + return append_count; +} + +RowDataBlock &RowDataCollection::CreateBlock() { + blocks.push_back(make_uniq(buffer_manager, block_capacity, entry_size)); + return *blocks.back(); +} + +vector RowDataCollection::Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], + const SelectionVector *sel) { + vector handles; + vector append_entries; + + // first allocate space of where to serialize the keys and payload columns + idx_t remaining = added_count; + { + // first append to the last block (if any) + lock_guard append_lock(rdc_lock); + count += added_count; + + if (!blocks.empty()) { + auto &last_block = *blocks.back(); + if (last_block.count < last_block.capacity) { + // last block has space: pin the buffer of this block + auto handle = buffer_manager.Pin(last_block.block); + // now append to the block + idx_t append_count = AppendToBlock(last_block, handle, append_entries, remaining, entry_sizes); + remaining -= append_count; + handles.push_back(std::move(handle)); + } + } + while (remaining > 0) { + // now for the remaining data, allocate new buffers to store the data and append there + auto &new_block = CreateBlock(); + auto handle = buffer_manager.Pin(new_block.block); + + // offset the entry sizes array if we have added entries already + idx_t *offset_entry_sizes = entry_sizes ? entry_sizes + added_count - remaining : nullptr; + + idx_t append_count = AppendToBlock(new_block, handle, append_entries, remaining, offset_entry_sizes); + D_ASSERT(new_block.count > 0); + remaining -= append_count; + + if (keep_pinned) { + pinned_blocks.push_back(std::move(handle)); + } else { + handles.push_back(std::move(handle)); + } + } + } + // now set up the key_locations based on the append entries + idx_t append_idx = 0; + for (auto &append_entry : append_entries) { + idx_t next = append_idx + append_entry.count; + if (entry_sizes) { + for (; append_idx < next; append_idx++) { + key_locations[append_idx] = append_entry.baseptr; + append_entry.baseptr += entry_sizes[append_idx]; + } + } else { + for (; append_idx < next; append_idx++) { + auto idx = sel->get_index(append_idx); + key_locations[idx] = append_entry.baseptr; + append_entry.baseptr += entry_size; + } + } + } + // return the unique pointers to the handles because they must stay pinned + return handles; +} + +void RowDataCollection::Merge(RowDataCollection &other) { + if (other.count == 0) { + return; + } + RowDataCollection temp(buffer_manager, Storage::BLOCK_SIZE, 1); + { + // One lock at a time to avoid deadlocks + lock_guard read_lock(other.rdc_lock); + temp.count = other.count; + temp.block_capacity = other.block_capacity; + temp.entry_size = other.entry_size; + temp.blocks = std::move(other.blocks); + temp.pinned_blocks = std::move(other.pinned_blocks); + } + other.Clear(); + + lock_guard write_lock(rdc_lock); + count += temp.count; + block_capacity = MaxValue(block_capacity, temp.block_capacity); + entry_size = MaxValue(entry_size, temp.entry_size); + for (auto &block : temp.blocks) { + blocks.emplace_back(std::move(block)); + } + for (auto &handle : temp.pinned_blocks) { + pinned_blocks.emplace_back(std::move(handle)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp new file mode 100644 index 00000000..a7b4e2db --- /dev/null +++ b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp @@ -0,0 +1,313 @@ +#include "duckdb/common/types/row/row_data_collection_scanner.hpp" + +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/row/row_data_collection.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include + +namespace duckdb { + +void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block_collection, + RowDataCollection &swizzled_string_heap, + RowDataCollection &block_collection, RowDataCollection &string_heap, + const RowLayout &layout) { + if (block_collection.count == 0) { + return; + } + + if (layout.AllConstant()) { + // No heap blocks! Just merge fixed-size data + swizzled_block_collection.Merge(block_collection); + return; + } + + // We create one heap block per data block and swizzle the pointers + D_ASSERT(string_heap.keep_pinned == swizzled_string_heap.keep_pinned); + auto &buffer_manager = block_collection.buffer_manager; + auto &heap_blocks = string_heap.blocks; + idx_t heap_block_idx = 0; + idx_t heap_block_remaining = heap_blocks[heap_block_idx]->count; + for (auto &data_block : block_collection.blocks) { + if (heap_block_remaining == 0) { + heap_block_remaining = heap_blocks[++heap_block_idx]->count; + } + + // Pin the data block and swizzle the pointers within the rows + auto data_handle = buffer_manager.Pin(data_block->block); + auto data_ptr = data_handle.Ptr(); + if (!string_heap.keep_pinned) { + D_ASSERT(!data_block->block->IsSwizzled()); + RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); + data_block->block->SetSwizzling(nullptr); + } + // At this point the data block is pinned and the heap pointer is valid + // so we can copy heap data as needed + + // We want to copy as little of the heap data as possible, check how the data and heap blocks line up + if (heap_block_remaining >= data_block->count) { + // Easy: current heap block contains all strings for this data block, just copy (reference) the block + swizzled_string_heap.blocks.emplace_back(heap_blocks[heap_block_idx]->Copy()); + swizzled_string_heap.blocks.back()->count = data_block->count; + + // Swizzle the heap pointer if we are not pinning the heap + auto &heap_block = swizzled_string_heap.blocks.back()->block; + auto heap_handle = buffer_manager.Pin(heap_block); + if (!swizzled_string_heap.keep_pinned) { + auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); + auto heap_offset = heap_ptr - heap_handle.Ptr(); + RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, heap_offset); + } else { + swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); + } + + // Update counter + heap_block_remaining -= data_block->count; + } else { + // Strings for this data block are spread over the current heap block and the next (and possibly more) + if (string_heap.keep_pinned) { + // The heap is changing underneath the data block, + // so swizzle the string pointers to make them portable. + RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); + } + idx_t data_block_remaining = data_block->count; + vector> ptrs_and_sizes; + idx_t total_size = 0; + const auto base_row_ptr = data_ptr; + while (data_block_remaining > 0) { + if (heap_block_remaining == 0) { + heap_block_remaining = heap_blocks[++heap_block_idx]->count; + } + auto next = MinValue(data_block_remaining, heap_block_remaining); + + // Figure out where to start copying strings, and how many bytes we need to copy + auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); + auto heap_end_ptr = + Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); + idx_t size = heap_end_ptr - heap_start_ptr + Load(heap_end_ptr); + ptrs_and_sizes.emplace_back(heap_start_ptr, size); + D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); + + // Swizzle the heap pointer + RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_start_ptr, next, total_size); + total_size += size; + + // Update where we are in the data and heap blocks + data_ptr += next * layout.GetRowWidth(); + data_block_remaining -= next; + heap_block_remaining -= next; + } + + // Finally, we allocate a new heap block and copy data to it + swizzled_string_heap.blocks.emplace_back( + make_uniq(buffer_manager, MaxValue(total_size, (idx_t)Storage::BLOCK_SIZE), 1)); + auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); + auto new_heap_ptr = new_heap_handle.Ptr(); + for (auto &ptr_and_size : ptrs_and_sizes) { + memcpy(new_heap_ptr, ptr_and_size.first, ptr_and_size.second); + new_heap_ptr += ptr_and_size.second; + } + new_heap_ptr = new_heap_handle.Ptr(); + if (swizzled_string_heap.keep_pinned) { + // Since the heap blocks are pinned, we can unswizzle the data again. + swizzled_string_heap.pinned_blocks.emplace_back(std::move(new_heap_handle)); + RowOperations::UnswizzlePointers(layout, base_row_ptr, new_heap_ptr, data_block->count); + RowOperations::UnswizzleHeapPointer(layout, base_row_ptr, new_heap_ptr, data_block->count); + } + } + } + + // We're done with variable-sized data, now just merge the fixed-size data + swizzled_block_collection.Merge(block_collection); + D_ASSERT(swizzled_block_collection.blocks.size() == swizzled_string_heap.blocks.size()); + + // Update counts and cleanup + swizzled_string_heap.count = string_heap.count; + string_heap.Clear(); +} + +void RowDataCollectionScanner::ScanState::PinData() { + auto &rows = scanner.rows; + D_ASSERT(block_idx < rows.blocks.size()); + auto &data_block = rows.blocks[block_idx]; + if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { + data_handle = rows.buffer_manager.Pin(data_block->block); + } + if (scanner.layout.AllConstant() || !scanner.external) { + return; + } + + auto &heap = scanner.heap; + D_ASSERT(block_idx < heap.blocks.size()); + auto &heap_block = heap.blocks[block_idx]; + if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { + heap_handle = heap.buffer_manager.Pin(heap_block->block); + } +} + +RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, + const RowLayout &layout_p, bool external_p, bool flush_p) + : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), + external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { + + if (unswizzling) { + D_ASSERT(rows.blocks.size() == heap.blocks.size()); + } + + ValidateUnscannedBlock(); +} + +RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, + const RowLayout &layout_p, bool external_p, idx_t block_idx, + bool flush_p) + : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), + external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { + + if (unswizzling) { + D_ASSERT(rows.blocks.size() == heap.blocks.size()); + } + + D_ASSERT(block_idx < rows.blocks.size()); + read_state.block_idx = block_idx; + read_state.entry_idx = 0; + + // Pretend that we have scanned up to the start block + // and will stop at the end + auto begin = rows.blocks.begin(); + auto end = begin + block_idx; + total_scanned = + std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); + total_count = total_scanned + (*end)->count; + + ValidateUnscannedBlock(); +} + +void RowDataCollectionScanner::SwizzleBlock(RowDataBlock &data_block, RowDataBlock &heap_block) { + // Pin the data block and swizzle the pointers within the rows + D_ASSERT(!data_block.block->IsSwizzled()); + auto data_handle = rows.buffer_manager.Pin(data_block.block); + auto data_ptr = data_handle.Ptr(); + RowOperations::SwizzleColumns(layout, data_ptr, data_block.count); + data_block.block->SetSwizzling(nullptr); + + // Swizzle the heap pointers + auto heap_handle = heap.buffer_manager.Pin(heap_block.block); + auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); + auto heap_offset = heap_ptr - heap_handle.Ptr(); + RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, heap_offset); +} + +void RowDataCollectionScanner::ReSwizzle() { + if (rows.count == 0) { + return; + } + + if (!unswizzling) { + // No swizzled blocks! + return; + } + + D_ASSERT(rows.blocks.size() == heap.blocks.size()); + for (idx_t i = 0; i < rows.blocks.size(); ++i) { + auto &data_block = rows.blocks[i]; + if (data_block->block && !data_block->block->IsSwizzled()) { + SwizzleBlock(*data_block, *heap.blocks[i]); + } + } +} + +void RowDataCollectionScanner::ValidateUnscannedBlock() const { + if (unswizzling && read_state.block_idx < rows.blocks.size() && Remaining()) { + D_ASSERT(rows.blocks[read_state.block_idx]->block->IsSwizzled()); + } +} + +void RowDataCollectionScanner::Scan(DataChunk &chunk) { + auto count = MinValue((idx_t)STANDARD_VECTOR_SIZE, total_count - total_scanned); + if (count == 0) { + chunk.SetCardinality(count); + return; + } + + // Only flush blocks we processed. + const auto flush_block_idx = read_state.block_idx; + + const idx_t &row_width = layout.GetRowWidth(); + // Set up a batch of pointers to scan data from + idx_t scanned = 0; + auto data_pointers = FlatVector::GetData(addresses); + + // We must pin ALL blocks we are going to gather from + vector pinned_blocks; + while (scanned < count) { + read_state.PinData(); + auto &data_block = rows.blocks[read_state.block_idx]; + idx_t next = MinValue(data_block->count - read_state.entry_idx, count - scanned); + const data_ptr_t data_ptr = read_state.data_handle.Ptr() + read_state.entry_idx * row_width; + // Set up the next pointers + data_ptr_t row_ptr = data_ptr; + for (idx_t i = 0; i < next; i++) { + data_pointers[scanned + i] = row_ptr; + row_ptr += row_width; + } + // Unswizzle the offsets back to pointers (if needed) + if (unswizzling) { + RowOperations::UnswizzlePointers(layout, data_ptr, read_state.heap_handle.Ptr(), next); + rows.blocks[read_state.block_idx]->block->SetSwizzling("RowDataCollectionScanner::Scan"); + } + // Update state indices + read_state.entry_idx += next; + scanned += next; + total_scanned += next; + if (read_state.entry_idx == data_block->count) { + // Pin completed blocks so we don't lose them + pinned_blocks.emplace_back(rows.buffer_manager.Pin(data_block->block)); + if (unswizzling) { + auto &heap_block = heap.blocks[read_state.block_idx]; + pinned_blocks.emplace_back(heap.buffer_manager.Pin(heap_block->block)); + } + read_state.block_idx++; + read_state.entry_idx = 0; + ValidateUnscannedBlock(); + } + } + D_ASSERT(scanned == count); + // Deserialize the payload data + for (idx_t col_no = 0; col_no < layout.ColumnCount(); col_no++) { + RowOperations::Gather(addresses, *FlatVector::IncrementalSelectionVector(), chunk.data[col_no], + *FlatVector::IncrementalSelectionVector(), count, layout, col_no); + } + chunk.SetCardinality(count); + chunk.Verify(); + + // Switch to a new set of pinned blocks + read_state.pinned_blocks.swap(pinned_blocks); + + if (flush) { + // Release blocks we have passed. + for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { + rows.blocks[i]->block = nullptr; + if (unswizzling) { + heap.blocks[i]->block = nullptr; + } + } + } else if (unswizzling) { + // Reswizzle blocks we have passed so they can be flushed safely. + for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { + auto &data_block = rows.blocks[i]; + if (data_block->block && !data_block->block->IsSwizzled()) { + SwizzleBlock(*data_block, *heap.blocks[i]); + } + } + } +} + +void RowDataCollectionScanner::Reset(bool flush_p) { + flush = flush_p; + total_scanned = 0; + + read_state.block_idx = 0; + read_state.entry_idx = 0; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_layout.cpp b/src/duckdb/src/common/types/row/row_layout.cpp new file mode 100644 index 00000000..3add8e42 --- /dev/null +++ b/src/duckdb/src/common/types/row/row_layout.cpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row_layout.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/common/types/row/row_layout.hpp" + +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +RowLayout::RowLayout() : flag_width(0), data_width(0), row_width(0), all_constant(true), heap_pointer_offset(0) { +} + +void RowLayout::Initialize(vector types_p, bool align) { + offsets.clear(); + types = std::move(types_p); + + // Null mask at the front - 1 bit per value. + flag_width = ValidityBytes::ValidityMaskSize(types.size()); + row_width = flag_width; + + // Whether all columns are constant size. + for (const auto &type : types) { + all_constant = all_constant && TypeIsConstantSize(type.InternalType()); + } + + // This enables pointer swizzling for out-of-core computation. + if (!all_constant) { + // When unswizzled, the pointer lives here. + // When swizzled, the pointer is replaced by an offset. + heap_pointer_offset = row_width; + // The 8 byte pointer will be replaced with an 8 byte idx_t when swizzled. + // However, this cannot be sizeof(data_ptr_t), since 32 bit builds use 4 byte pointers. + row_width += sizeof(idx_t); + } + + // Data columns. No alignment required. + for (const auto &type : types) { + offsets.push_back(row_width); + const auto internal_type = type.InternalType(); + if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { + row_width += GetTypeIdSize(type.InternalType()); + } else { + // Variable size types use pointers to the actual data (can be swizzled). + // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). + row_width += sizeof(idx_t); + } + } + + data_width = row_width - flag_width; + + // Alignment padding for the next row + if (align) { + row_width = AlignValue(row_width); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp new file mode 100644 index 00000000..b30af49c --- /dev/null +++ b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp @@ -0,0 +1,475 @@ +#include "duckdb/common/types/row/tuple_data_allocator.hpp" + +#include "duckdb/common/types/row/tuple_data_segment.hpp" +#include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +using ValidityBytes = TupleDataLayout::ValidityBytes; + +TupleDataBlock::TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p) : capacity(capacity_p), size(0) { + buffer_manager.Allocate(capacity, false, &handle); +} + +TupleDataBlock::TupleDataBlock(TupleDataBlock &&other) noexcept { + std::swap(handle, other.handle); + std::swap(capacity, other.capacity); + std::swap(size, other.size); +} + +TupleDataBlock &TupleDataBlock::operator=(TupleDataBlock &&other) noexcept { + std::swap(handle, other.handle); + std::swap(capacity, other.capacity); + std::swap(size, other.size); + return *this; +} + +TupleDataAllocator::TupleDataAllocator(BufferManager &buffer_manager, const TupleDataLayout &layout) + : buffer_manager(buffer_manager), layout(layout.Copy()) { +} + +TupleDataAllocator::TupleDataAllocator(TupleDataAllocator &allocator) + : buffer_manager(allocator.buffer_manager), layout(allocator.layout.Copy()) { +} + +BufferManager &TupleDataAllocator::GetBufferManager() { + return buffer_manager; +} + +Allocator &TupleDataAllocator::GetAllocator() { + return buffer_manager.GetBufferAllocator(); +} + +const TupleDataLayout &TupleDataAllocator::GetLayout() const { + return layout; +} + +idx_t TupleDataAllocator::RowBlockCount() const { + return row_blocks.size(); +} + +idx_t TupleDataAllocator::HeapBlockCount() const { + return heap_blocks.size(); +} + +void TupleDataAllocator::Build(TupleDataSegment &segment, TupleDataPinState &pin_state, + TupleDataChunkState &chunk_state, const idx_t append_offset, const idx_t append_count) { + D_ASSERT(this == segment.allocator.get()); + auto &chunks = segment.chunks; + if (!chunks.empty()) { + ReleaseOrStoreHandles(pin_state, segment, chunks.back(), true); + } + + // Build the chunk parts for the incoming data + chunk_part_indices.clear(); + idx_t offset = 0; + while (offset != append_count) { + if (chunks.empty() || chunks.back().count == STANDARD_VECTOR_SIZE) { + chunks.emplace_back(); + } + auto &chunk = chunks.back(); + + // Build the next part + auto next = MinValue(append_count - offset, STANDARD_VECTOR_SIZE - chunk.count); + chunk.AddPart(BuildChunkPart(pin_state, chunk_state, append_offset + offset, next, chunk), layout); + auto &chunk_part = chunk.parts.back(); + next = chunk_part.count; + + segment.count += next; + segment.data_size += chunk_part.count * layout.GetRowWidth(); + if (!layout.AllConstant()) { + segment.data_size += chunk_part.total_heap_size; + } + + offset += next; + chunk_part_indices.emplace_back(chunks.size() - 1, chunk.parts.size() - 1); + } + + // Now initialize the pointers to write the data to + chunk_parts.clear(); + for (auto &indices : chunk_part_indices) { + chunk_parts.emplace_back(segment.chunks[indices.first].parts[indices.second]); + } + InitializeChunkStateInternal(pin_state, chunk_state, append_offset, false, true, false, chunk_parts); + + // To reduce metadata, we try to merge chunk parts where possible + // Due to the way chunk parts are constructed, only the last part of the first chunk is eligible for merging + segment.chunks[chunk_part_indices[0].first].MergeLastChunkPart(layout); + + segment.Verify(); +} + +TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + const idx_t append_offset, const idx_t append_count, + TupleDataChunk &chunk) { + D_ASSERT(append_count != 0); + TupleDataChunkPart result(*chunk.lock); + + // Allocate row block (if needed) + if (row_blocks.empty() || row_blocks.back().RemainingCapacity() < layout.GetRowWidth()) { + row_blocks.emplace_back(buffer_manager, (idx_t)Storage::BLOCK_SIZE); + } + result.row_block_index = row_blocks.size() - 1; + auto &row_block = row_blocks[result.row_block_index]; + result.row_block_offset = row_block.size; + + // Set count (might be reduced later when checking heap space) + result.count = MinValue(row_block.RemainingCapacity(layout.GetRowWidth()), append_count); + if (!layout.AllConstant()) { + const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); + + // Compute total heap size first + idx_t total_heap_size = 0; + for (idx_t i = 0; i < result.count; i++) { + const auto &heap_size = heap_sizes[append_offset + i]; + total_heap_size += heap_size; + } + + if (total_heap_size == 0) { + // We don't need a heap at all + result.heap_block_index = TupleDataChunkPart::INVALID_INDEX; + result.heap_block_offset = TupleDataChunkPart::INVALID_INDEX; + result.total_heap_size = 0; + result.base_heap_ptr = nullptr; + } else { + // Allocate heap block (if needed) + if (heap_blocks.empty() || heap_blocks.back().RemainingCapacity() < heap_sizes[append_offset]) { + const auto size = MaxValue((idx_t)Storage::BLOCK_SIZE, heap_sizes[append_offset]); + heap_blocks.emplace_back(buffer_manager, size); + } + result.heap_block_index = heap_blocks.size() - 1; + auto &heap_block = heap_blocks[result.heap_block_index]; + result.heap_block_offset = heap_block.size; + + const auto heap_remaining = heap_block.RemainingCapacity(); + if (total_heap_size <= heap_remaining) { + // Everything fits + result.total_heap_size = total_heap_size; + } else { + // Not everything fits - determine how many we can read next + result.total_heap_size = 0; + for (idx_t i = 0; i < result.count; i++) { + const auto &heap_size = heap_sizes[append_offset + i]; + if (result.total_heap_size + heap_size > heap_remaining) { + result.count = i; + break; + } + result.total_heap_size += heap_size; + } + } + + // Mark this portion of the heap block as filled and set the pointer + heap_block.size += result.total_heap_size; + result.base_heap_ptr = GetBaseHeapPointer(pin_state, result); + } + } + D_ASSERT(result.count != 0 && result.count <= STANDARD_VECTOR_SIZE); + + // Mark this portion of the row block as filled + row_block.size += result.count * layout.GetRowWidth(); + + return result; +} + +void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, + TupleDataChunkState &chunk_state, idx_t chunk_idx, bool init_heap) { + D_ASSERT(this == segment.allocator.get()); + D_ASSERT(chunk_idx < segment.ChunkCount()); + auto &chunk = segment.chunks[chunk_idx]; + + // Release or store any handles that are no longer required: + // We can't release the heap here if the current chunk's heap_block_ids is empty, because if we are iterating with + // PinProperties::DESTROY_AFTER_DONE, we might destroy a heap block that is needed by a later chunk, e.g., + // when chunk 0 needs heap block 0, chunk 1 does not need any heap blocks, and chunk 2 needs heap block 0 again + ReleaseOrStoreHandles(pin_state, segment, chunk, !chunk.heap_block_ids.empty()); + + unsafe_vector> parts; + parts.reserve(chunk.parts.size()); + for (auto &part : chunk.parts) { + parts.emplace_back(part); + } + + InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, parts); +} + +static inline void InitializeHeapSizes(const data_ptr_t row_locations[], idx_t heap_sizes[], const idx_t offset, + const idx_t next, const TupleDataChunkPart &part, const idx_t heap_size_offset) { + // Read the heap sizes from the rows + for (idx_t i = 0; i < next; i++) { + auto idx = offset + i; + heap_sizes[idx] = Load(row_locations[idx] + heap_size_offset); + } + + // Verify total size +#ifdef DEBUG + idx_t total_heap_size = 0; + for (idx_t i = 0; i < next; i++) { + auto idx = offset + i; + total_heap_size += heap_sizes[idx]; + } + D_ASSERT(total_heap_size == part.total_heap_size); +#endif +} + +void TupleDataAllocator::InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + idx_t offset, bool recompute, bool init_heap_pointers, + bool init_heap_sizes, + unsafe_vector> &parts) { + auto row_locations = FlatVector::GetData(chunk_state.row_locations); + auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); + auto heap_locations = FlatVector::GetData(chunk_state.heap_locations); + + for (auto &part_ref : parts) { + auto &part = part_ref.get(); + const auto next = part.count; + + // Set up row locations for the scan + const auto row_width = layout.GetRowWidth(); + const auto base_row_ptr = GetRowPointer(pin_state, part); + for (idx_t i = 0; i < next; i++) { + row_locations[offset + i] = base_row_ptr + i * row_width; + } + + if (layout.AllConstant()) { // Can't have a heap + offset += next; + continue; + } + + if (part.total_heap_size == 0) { + if (init_heap_sizes) { // No heap, but we need the heap sizes + InitializeHeapSizes(row_locations, heap_sizes, offset, next, part, layout.GetHeapSizeOffset()); + } + offset += next; + continue; + } + + // Check if heap block has changed - re-compute the pointers within each row if so + if (recompute && pin_state.properties != TupleDataPinProperties::ALREADY_PINNED) { + const auto new_base_heap_ptr = GetBaseHeapPointer(pin_state, part); + if (part.base_heap_ptr != new_base_heap_ptr) { + lock_guard guard(part.lock); + const auto old_base_heap_ptr = part.base_heap_ptr; + if (old_base_heap_ptr != new_base_heap_ptr) { + Vector old_heap_ptrs( + Value::POINTER(CastPointerToValue(old_base_heap_ptr + part.heap_block_offset))); + Vector new_heap_ptrs( + Value::POINTER(CastPointerToValue(new_base_heap_ptr + part.heap_block_offset))); + RecomputeHeapPointers(old_heap_ptrs, *ConstantVector::ZeroSelectionVector(), row_locations, + new_heap_ptrs, offset, next, layout, 0); + part.base_heap_ptr = new_base_heap_ptr; + } + } + } + + if (init_heap_sizes) { + InitializeHeapSizes(row_locations, heap_sizes, offset, next, part, layout.GetHeapSizeOffset()); + } + + if (init_heap_pointers) { + // Set the pointers where the heap data will be written (if needed) + heap_locations[offset] = part.base_heap_ptr + part.heap_block_offset; + for (idx_t i = 1; i < next; i++) { + auto idx = offset + i; + heap_locations[idx] = heap_locations[idx - 1] + heap_sizes[idx - 1]; + } + } + + offset += next; + } + D_ASSERT(offset <= STANDARD_VECTOR_SIZE); +} + +static inline void VerifyStrings(const LogicalTypeId type_id, const data_ptr_t row_locations[], const idx_t col_idx, + const idx_t base_col_offset, const idx_t col_offset, const idx_t offset, + const idx_t count) { +#ifdef DEBUG + if (type_id != LogicalTypeId::VARCHAR) { + // Make sure we don't verify BLOB / AGGREGATE_STATE + return; + } + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + for (idx_t i = 0; i < count; i++) { + const auto &row_location = row_locations[offset + i] + base_col_offset; + ValidityBytes row_mask(row_location); + if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { + auto recomputed_string = Load(row_location + col_offset); + recomputed_string.Verify(); + } + } +#endif +} + +void TupleDataAllocator::RecomputeHeapPointers(Vector &old_heap_ptrs, const SelectionVector &old_heap_sel, + const data_ptr_t row_locations[], Vector &new_heap_ptrs, + const idx_t offset, const idx_t count, const TupleDataLayout &layout, + const idx_t base_col_offset) { + const auto old_heap_locations = FlatVector::GetData(old_heap_ptrs); + + UnifiedVectorFormat new_heap_data; + new_heap_ptrs.ToUnifiedFormat(offset + count, new_heap_data); + const auto new_heap_locations = UnifiedVectorFormat::GetData(new_heap_data); + const auto new_heap_sel = *new_heap_data.sel; + + for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { + const auto &col_offset = layout.GetOffsets()[col_idx]; + + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + const auto &type = layout.GetTypes()[col_idx]; + switch (type.InternalType()) { + case PhysicalType::VARCHAR: { + for (idx_t i = 0; i < count; i++) { + const auto idx = offset + i; + const auto &row_location = row_locations[idx] + base_col_offset; + ValidityBytes row_mask(row_location); + if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { + continue; + } + + const auto &old_heap_ptr = old_heap_locations[old_heap_sel.get_index(idx)]; + const auto &new_heap_ptr = new_heap_locations[new_heap_sel.get_index(idx)]; + + const auto string_location = row_location + col_offset; + if (Load(string_location) > string_t::INLINE_LENGTH) { + const auto string_ptr_location = string_location + string_t::HEADER_SIZE; + const auto string_ptr = Load(string_ptr_location); + const auto diff = string_ptr - old_heap_ptr; + D_ASSERT(diff >= 0); + Store(new_heap_ptr + diff, string_ptr_location); + } + } + VerifyStrings(type.id(), row_locations, col_idx, base_col_offset, col_offset, offset, count); + break; + } + case PhysicalType::LIST: { + for (idx_t i = 0; i < count; i++) { + const auto idx = offset + i; + const auto &row_location = row_locations[idx] + base_col_offset; + ValidityBytes row_mask(row_location); + if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { + continue; + } + + const auto &old_heap_ptr = old_heap_locations[old_heap_sel.get_index(idx)]; + const auto &new_heap_ptr = new_heap_locations[new_heap_sel.get_index(idx)]; + + const auto &list_ptr_location = row_location + col_offset; + const auto list_ptr = Load(list_ptr_location); + const auto diff = list_ptr - old_heap_ptr; + D_ASSERT(diff >= 0); + Store(new_heap_ptr + diff, list_ptr_location); + } + break; + } + case PhysicalType::STRUCT: { + const auto &struct_layout = layout.GetStructLayout(col_idx); + if (!struct_layout.AllConstant()) { + RecomputeHeapPointers(old_heap_ptrs, old_heap_sel, row_locations, new_heap_ptrs, offset, count, + struct_layout, base_col_offset + col_offset); + } + break; + } + default: + continue; + } + } +} + +void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment, + TupleDataChunk &chunk, bool release_heap) { + D_ASSERT(this == segment.allocator.get()); + ReleaseOrStoreHandlesInternal(segment, segment.pinned_row_handles, pin_state.row_handles, chunk.row_block_ids, + row_blocks, pin_state.properties); + if (!layout.AllConstant() && release_heap) { + ReleaseOrStoreHandlesInternal(segment, segment.pinned_heap_handles, pin_state.heap_handles, + chunk.heap_block_ids, heap_blocks, pin_state.properties); + } +} + +void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment) { + static TupleDataChunk DUMMY_CHUNK; + ReleaseOrStoreHandles(pin_state, segment, DUMMY_CHUNK, true); +} + +void TupleDataAllocator::ReleaseOrStoreHandlesInternal( + TupleDataSegment &segment, unsafe_vector &pinned_handles, perfect_map_t &handles, + const perfect_set_t &block_ids, unsafe_vector &blocks, TupleDataPinProperties properties) { + bool found_handle; + do { + found_handle = false; + for (auto it = handles.begin(); it != handles.end(); it++) { + const auto block_id = it->first; + if (block_ids.find(block_id) != block_ids.end()) { + // still required: do not release + continue; + } + switch (properties) { + case TupleDataPinProperties::KEEP_EVERYTHING_PINNED: { + lock_guard guard(segment.pinned_handles_lock); + const auto block_count = block_id + 1; + if (block_count > pinned_handles.size()) { + pinned_handles.resize(block_count); + } + pinned_handles[block_id] = std::move(it->second); + break; + } + case TupleDataPinProperties::UNPIN_AFTER_DONE: + case TupleDataPinProperties::ALREADY_PINNED: + break; + case TupleDataPinProperties::DESTROY_AFTER_DONE: + blocks[block_id].handle = nullptr; + break; + default: + D_ASSERT(properties == TupleDataPinProperties::INVALID); + throw InternalException("Encountered TupleDataPinProperties::INVALID"); + } + handles.erase(it); + found_handle = true; + break; + } + } while (found_handle); +} + +BufferHandle &TupleDataAllocator::PinRowBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { + const auto &row_block_index = part.row_block_index; + auto it = pin_state.row_handles.find(row_block_index); + if (it == pin_state.row_handles.end()) { + D_ASSERT(row_block_index < row_blocks.size()); + auto &row_block = row_blocks[row_block_index]; + D_ASSERT(row_block.handle); + D_ASSERT(part.row_block_offset < row_block.size); + D_ASSERT(part.row_block_offset + part.count * layout.GetRowWidth() <= row_block.size); + it = pin_state.row_handles.emplace(row_block_index, buffer_manager.Pin(row_block.handle)).first; + } + return it->second; +} + +BufferHandle &TupleDataAllocator::PinHeapBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { + const auto &heap_block_index = part.heap_block_index; + auto it = pin_state.heap_handles.find(heap_block_index); + if (it == pin_state.heap_handles.end()) { + D_ASSERT(heap_block_index < heap_blocks.size()); + auto &heap_block = heap_blocks[heap_block_index]; + D_ASSERT(heap_block.handle); + D_ASSERT(part.heap_block_offset < heap_block.size); + D_ASSERT(part.heap_block_offset + part.total_heap_size <= heap_block.size); + it = pin_state.heap_handles.emplace(heap_block_index, buffer_manager.Pin(heap_block.handle)).first; + } + return it->second; +} + +data_ptr_t TupleDataAllocator::GetRowPointer(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { + return PinRowBlock(pin_state, part).Ptr() + part.row_block_offset; +} + +data_ptr_t TupleDataAllocator::GetBaseHeapPointer(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { + return PinHeapBlock(pin_state, part).Ptr(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp new file mode 100644 index 00000000..97858151 --- /dev/null +++ b/src/duckdb/src/common/types/row/tuple_data_collection.cpp @@ -0,0 +1,544 @@ +#include "duckdb/common/types/row/tuple_data_collection.hpp" + +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/row/tuple_data_allocator.hpp" + +#include + +namespace duckdb { + +using ValidityBytes = TupleDataLayout::ValidityBytes; + +TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, const TupleDataLayout &layout_p) + : layout(layout_p.Copy()), allocator(make_shared(buffer_manager, layout)) { + Initialize(); +} + +TupleDataCollection::TupleDataCollection(shared_ptr allocator) + : layout(allocator->GetLayout().Copy()), allocator(std::move(allocator)) { + Initialize(); +} + +TupleDataCollection::~TupleDataCollection() { +} + +void TupleDataCollection::Initialize() { + D_ASSERT(!layout.GetTypes().empty()); + this->count = 0; + this->data_size = 0; + scatter_functions.reserve(layout.ColumnCount()); + gather_functions.reserve(layout.ColumnCount()); + for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { + auto &type = layout.GetTypes()[col_idx]; + scatter_functions.emplace_back(GetScatterFunction(type)); + gather_functions.emplace_back(GetGatherFunction(type)); + } +} + +void GetAllColumnIDsInternal(vector &column_ids, const idx_t column_count) { + column_ids.reserve(column_count); + for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { + column_ids.emplace_back(col_idx); + } +} + +void TupleDataCollection::GetAllColumnIDs(vector &column_ids) { + GetAllColumnIDsInternal(column_ids, layout.ColumnCount()); +} + +const TupleDataLayout &TupleDataCollection::GetLayout() const { + return layout; +} + +const idx_t &TupleDataCollection::Count() const { + return count; +} + +idx_t TupleDataCollection::ChunkCount() const { + idx_t total_chunk_count = 0; + for (const auto &segment : segments) { + total_chunk_count += segment.ChunkCount(); + } + return total_chunk_count; +} + +idx_t TupleDataCollection::SizeInBytes() const { + idx_t total_size = 0; + for (const auto &segment : segments) { + total_size += segment.SizeInBytes(); + } + return total_size; +} + +void TupleDataCollection::Unpin() { + for (auto &segment : segments) { + segment.Unpin(); + } +} + +// LCOV_EXCL_START +void VerifyAppendColumns(const TupleDataLayout &layout, const vector &column_ids) { +#ifdef DEBUG + for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { + if (std::find(column_ids.begin(), column_ids.end(), col_idx) != column_ids.end()) { + continue; + } + // This column will not be appended in the first go - verify that it is fixed-size - we cannot resize heap after + const auto physical_type = layout.GetTypes()[col_idx].InternalType(); + D_ASSERT(physical_type != PhysicalType::VARCHAR && physical_type != PhysicalType::LIST); + if (physical_type == PhysicalType::STRUCT) { + const auto &struct_layout = layout.GetStructLayout(col_idx); + vector struct_column_ids; + struct_column_ids.reserve(struct_layout.ColumnCount()); + for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { + struct_column_ids.emplace_back(struct_col_idx); + } + VerifyAppendColumns(struct_layout, struct_column_ids); + } + } +#endif +} +// LCOV_EXCL_STOP + +void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, TupleDataPinProperties properties) { + vector column_ids; + GetAllColumnIDs(column_ids); + InitializeAppend(append_state, std::move(column_ids), properties); +} + +void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, vector column_ids, + TupleDataPinProperties properties) { + VerifyAppendColumns(layout, column_ids); + InitializeAppend(append_state.pin_state, properties); + InitializeChunkState(append_state.chunk_state, std::move(column_ids)); +} + +void TupleDataCollection::InitializeAppend(TupleDataPinState &pin_state, TupleDataPinProperties properties) { + pin_state.properties = properties; + if (segments.empty()) { + segments.emplace_back(allocator); + } +} + +static void InitializeVectorFormat(vector &vector_data, const vector &types) { + vector_data.resize(types.size()); + for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { + const auto &type = types[col_idx]; + switch (type.InternalType()) { + case PhysicalType::STRUCT: { + const auto &child_list = StructType::GetChildTypes(type); + vector child_types; + child_types.reserve(child_list.size()); + for (const auto &child_entry : child_list) { + child_types.emplace_back(child_entry.second); + } + InitializeVectorFormat(vector_data[col_idx].children, child_types); + break; + } + case PhysicalType::LIST: + InitializeVectorFormat(vector_data[col_idx].children, {ListType::GetChildType(type)}); + break; + default: + break; + } + } +} + +void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state, vector column_ids) { + TupleDataCollection::InitializeChunkState(chunk_state, layout.GetTypes(), std::move(column_ids)); +} + +void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state, const vector &types, + vector column_ids) { + if (column_ids.empty()) { + GetAllColumnIDsInternal(column_ids, types.size()); + } + InitializeVectorFormat(chunk_state.vector_data, types); + chunk_state.column_ids = std::move(column_ids); +} + +void TupleDataCollection::Append(DataChunk &new_chunk, const SelectionVector &append_sel, idx_t append_count) { + TupleDataAppendState append_state; + InitializeAppend(append_state); + Append(append_state, new_chunk, append_sel, append_count); +} + +void TupleDataCollection::Append(DataChunk &new_chunk, vector column_ids, const SelectionVector &append_sel, + const idx_t append_count) { + TupleDataAppendState append_state; + InitializeAppend(append_state, std::move(column_ids)); + Append(append_state, new_chunk, append_sel, append_count); +} + +void TupleDataCollection::Append(TupleDataAppendState &append_state, DataChunk &new_chunk, + const SelectionVector &append_sel, const idx_t append_count) { + Append(append_state.pin_state, append_state.chunk_state, new_chunk, append_sel, append_count); +} + +void TupleDataCollection::Append(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, DataChunk &new_chunk, + const SelectionVector &append_sel, const idx_t append_count) { + TupleDataCollection::ToUnifiedFormat(chunk_state, new_chunk); + AppendUnified(pin_state, chunk_state, new_chunk, append_sel, append_count); +} + +void TupleDataCollection::AppendUnified(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + DataChunk &new_chunk, const SelectionVector &append_sel, + const idx_t append_count) { + const idx_t actual_append_count = append_count == DConstants::INVALID_INDEX ? new_chunk.size() : append_count; + if (actual_append_count == 0) { + return; + } + + if (!layout.AllConstant()) { + TupleDataCollection::ComputeHeapSizes(chunk_state, new_chunk, append_sel, actual_append_count); + } + + Build(pin_state, chunk_state, 0, actual_append_count); + +#ifdef DEBUG + Vector heap_locations_copy(LogicalType::POINTER); + if (!layout.AllConstant()) { + VectorOperations::Copy(chunk_state.heap_locations, heap_locations_copy, actual_append_count, 0, 0); + } +#endif + + Scatter(chunk_state, new_chunk, append_sel, actual_append_count); + +#ifdef DEBUG + // Verify that the size of the data written to the heap is the same as the size we computed it would be + if (!layout.AllConstant()) { + const auto original_heap_locations = FlatVector::GetData(heap_locations_copy); + const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); + const auto offset_heap_locations = FlatVector::GetData(chunk_state.heap_locations); + for (idx_t i = 0; i < actual_append_count; i++) { + D_ASSERT(offset_heap_locations[i] == original_heap_locations[i] + heap_sizes[i]); + } + } +#endif +} + +static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector &vector, const idx_t count) { + vector.ToUnifiedFormat(count, format.unified); + format.original_sel = format.unified.sel; + format.original_owned_sel.Initialize(format.unified.owned_sel); + switch (vector.GetType().InternalType()) { + case PhysicalType::STRUCT: { + auto &entries = StructVector::GetEntries(vector); + D_ASSERT(format.children.size() == entries.size()); + for (idx_t struct_col_idx = 0; struct_col_idx < entries.size(); struct_col_idx++) { + ToUnifiedFormatInternal(reinterpret_cast(format.children[struct_col_idx]), + *entries[struct_col_idx], count); + } + break; + } + case PhysicalType::LIST: + D_ASSERT(format.children.size() == 1); + ToUnifiedFormatInternal(reinterpret_cast(format.children[0]), + ListVector::GetEntry(vector), ListVector::GetListSize(vector)); + break; + default: + break; + } +} + +void TupleDataCollection::ToUnifiedFormat(TupleDataChunkState &chunk_state, DataChunk &new_chunk) { + D_ASSERT(chunk_state.vector_data.size() >= chunk_state.column_ids.size()); // Needs InitializeAppend + for (const auto &col_idx : chunk_state.column_ids) { + ToUnifiedFormatInternal(chunk_state.vector_data[col_idx], new_chunk.data[col_idx], new_chunk.size()); + } +} + +void TupleDataCollection::GetVectorData(const TupleDataChunkState &chunk_state, UnifiedVectorFormat result[]) { + const auto &vector_data = chunk_state.vector_data; + for (idx_t i = 0; i < vector_data.size(); i++) { + const auto &source = vector_data[i].unified; + auto &target = result[i]; + target.sel = source.sel; + target.data = source.data; + target.validity = source.validity; + } +} + +void TupleDataCollection::Build(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + const idx_t append_offset, const idx_t append_count) { + auto &segment = segments.back(); + const auto size_before = segment.SizeInBytes(); + segment.allocator->Build(segment, pin_state, chunk_state, append_offset, append_count); + data_size += segment.SizeInBytes() - size_before; + count += append_count; + Verify(); +} + +// LCOV_EXCL_START +void VerifyHeapSizes(const data_ptr_t source_locations[], const idx_t heap_sizes[], const SelectionVector &append_sel, + const idx_t append_count, const idx_t heap_size_offset) { +#ifdef DEBUG + for (idx_t i = 0; i < append_count; i++) { + auto idx = append_sel.get_index(i); + const auto stored_heap_size = Load(source_locations[idx] + heap_size_offset); + D_ASSERT(stored_heap_size == heap_sizes[idx]); + } +#endif +} +// LCOV_EXCL_STOP + +void TupleDataCollection::CopyRows(TupleDataChunkState &chunk_state, TupleDataChunkState &input, + const SelectionVector &append_sel, const idx_t append_count) const { + const auto source_locations = FlatVector::GetData(input.row_locations); + const auto target_locations = FlatVector::GetData(chunk_state.row_locations); + + // Copy rows + const auto row_width = layout.GetRowWidth(); + for (idx_t i = 0; i < append_count; i++) { + auto idx = append_sel.get_index(i); + FastMemcpy(target_locations[i], source_locations[idx], row_width); + } + + // Copy heap if we need to + if (!layout.AllConstant()) { + const auto source_heap_locations = FlatVector::GetData(input.heap_locations); + const auto target_heap_locations = FlatVector::GetData(chunk_state.heap_locations); + const auto heap_sizes = FlatVector::GetData(input.heap_sizes); + VerifyHeapSizes(source_locations, heap_sizes, append_sel, append_count, layout.GetHeapSizeOffset()); + + // Check if we need to copy anything at all + idx_t total_heap_size = 0; + for (idx_t i = 0; i < append_count; i++) { + auto idx = append_sel.get_index(i); + total_heap_size += heap_sizes[idx]; + } + if (total_heap_size == 0) { + return; + } + + // Copy heap + for (idx_t i = 0; i < append_count; i++) { + auto idx = append_sel.get_index(i); + FastMemcpy(target_heap_locations[i], source_heap_locations[idx], heap_sizes[idx]); + } + + // Recompute pointers after copying the data + TupleDataAllocator::RecomputeHeapPointers(input.heap_locations, append_sel, target_locations, + chunk_state.heap_locations, 0, append_count, layout, 0); + } +} + +void TupleDataCollection::Combine(TupleDataCollection &other) { + if (other.count == 0) { + return; + } + if (this->layout.GetTypes() != other.GetLayout().GetTypes()) { + throw InternalException("Attempting to combine TupleDataCollection with mismatching types"); + } + this->segments.reserve(this->segments.size() + other.segments.size()); + for (auto &other_seg : other.segments) { + AddSegment(std::move(other_seg)); + } + other.Reset(); +} + +void TupleDataCollection::AddSegment(TupleDataSegment &&segment) { + count += segment.count; + data_size += segment.data_size; + segments.emplace_back(std::move(segment)); + Verify(); +} + +void TupleDataCollection::Combine(unique_ptr other) { + Combine(*other); +} + +void TupleDataCollection::Reset() { + count = 0; + data_size = 0; + segments.clear(); + + // Refreshes the TupleDataAllocator to prevent holding on to allocated data unnecessarily + allocator = make_shared(*allocator); +} + +void TupleDataCollection::InitializeChunk(DataChunk &chunk) const { + chunk.Initialize(allocator->GetAllocator(), layout.GetTypes()); +} + +void TupleDataCollection::InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const { + auto &column_ids = state.chunk_state.column_ids; + D_ASSERT(!column_ids.empty()); + vector chunk_types; + chunk_types.reserve(column_ids.size()); + for (idx_t i = 0; i < column_ids.size(); i++) { + auto column_idx = column_ids[i]; + D_ASSERT(column_idx < layout.ColumnCount()); + chunk_types.push_back(layout.GetTypes()[column_idx]); + } + chunk.Initialize(allocator->GetAllocator(), chunk_types); +} + +void TupleDataCollection::InitializeScan(TupleDataScanState &state, TupleDataPinProperties properties) const { + vector column_ids; + column_ids.reserve(layout.ColumnCount()); + for (idx_t i = 0; i < layout.ColumnCount(); i++) { + column_ids.push_back(i); + } + InitializeScan(state, std::move(column_ids), properties); +} + +void TupleDataCollection::InitializeScan(TupleDataScanState &state, vector column_ids, + TupleDataPinProperties properties) const { + state.pin_state.row_handles.clear(); + state.pin_state.heap_handles.clear(); + state.pin_state.properties = properties; + state.segment_index = 0; + state.chunk_index = 0; + state.chunk_state.column_ids = std::move(column_ids); +} + +void TupleDataCollection::InitializeScan(TupleDataParallelScanState &gstate, TupleDataPinProperties properties) const { + InitializeScan(gstate.scan_state, properties); +} + +void TupleDataCollection::InitializeScan(TupleDataParallelScanState &state, vector column_ids, + TupleDataPinProperties properties) const { + InitializeScan(state.scan_state, std::move(column_ids), properties); +} + +bool TupleDataCollection::Scan(TupleDataScanState &state, DataChunk &result) { + const auto segment_index_before = state.segment_index; + idx_t segment_index; + idx_t chunk_index; + if (!NextScanIndex(state, segment_index, chunk_index)) { + if (!segments.empty()) { + FinalizePinState(state.pin_state, segments[segment_index_before]); + } + result.SetCardinality(0); + return false; + } + if (segment_index_before != DConstants::INVALID_INDEX && segment_index != segment_index_before) { + FinalizePinState(state.pin_state, segments[segment_index_before]); + } + ScanAtIndex(state.pin_state, state.chunk_state, state.chunk_state.column_ids, segment_index, chunk_index, result); + return true; +} + +bool TupleDataCollection::Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate, DataChunk &result) { + lstate.pin_state.properties = gstate.scan_state.pin_state.properties; + + const auto segment_index_before = lstate.segment_index; + { + lock_guard guard(gstate.lock); + if (!NextScanIndex(gstate.scan_state, lstate.segment_index, lstate.chunk_index)) { + if (!segments.empty()) { + FinalizePinState(lstate.pin_state, segments[segment_index_before]); + } + result.SetCardinality(0); + return false; + } + } + if (segment_index_before != DConstants::INVALID_INDEX && segment_index_before != lstate.segment_index) { + FinalizePinState(lstate.pin_state, segments[lstate.segment_index]); + } + ScanAtIndex(lstate.pin_state, lstate.chunk_state, gstate.scan_state.chunk_state.column_ids, lstate.segment_index, + lstate.chunk_index, result); + return true; +} + +bool TupleDataCollection::ScanComplete(const TupleDataScanState &state) const { + if (Count() == 0) { + return true; + } + return state.segment_index == segments.size() - 1 && state.chunk_index == segments.back().ChunkCount(); +} + +void TupleDataCollection::FinalizePinState(TupleDataPinState &pin_state, TupleDataSegment &segment) { + segment.allocator->ReleaseOrStoreHandles(pin_state, segment); +} + +void TupleDataCollection::FinalizePinState(TupleDataPinState &pin_state) { + D_ASSERT(!segments.empty()); + FinalizePinState(pin_state, segments.back()); +} + +bool TupleDataCollection::NextScanIndex(TupleDataScanState &state, idx_t &segment_index, idx_t &chunk_index) { + // Check if we still have segments to scan + if (state.segment_index >= segments.size()) { + // No more data left in the scan + return false; + } + // Check within the current segment if we still have chunks to scan + while (state.chunk_index >= segments[state.segment_index].ChunkCount()) { + // Exhausted all chunks for this segment: Move to the next one + state.segment_index++; + state.chunk_index = 0; + if (state.segment_index >= segments.size()) { + return false; + } + } + segment_index = state.segment_index; + chunk_index = state.chunk_index++; + return true; +} + +void TupleDataCollection::ScanAtIndex(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + const vector &column_ids, idx_t segment_index, idx_t chunk_index, + DataChunk &result) { + auto &segment = segments[segment_index]; + auto &chunk = segment.chunks[chunk_index]; + segment.allocator->InitializeChunkState(segment, pin_state, chunk_state, chunk_index, false); + result.Reset(); + Gather(chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), chunk.count, column_ids, result, + *FlatVector::IncrementalSelectionVector()); + result.SetCardinality(chunk.count); +} + +// LCOV_EXCL_START +string TupleDataCollection::ToString() { + DataChunk chunk; + InitializeChunk(chunk); + + TupleDataScanState scan_state; + InitializeScan(scan_state); + + string result = StringUtil::Format("TupleDataCollection - [%llu Chunks, %llu Rows]\n", ChunkCount(), Count()); + idx_t chunk_idx = 0; + idx_t row_count = 0; + while (Scan(scan_state, chunk)) { + result += + StringUtil::Format("Chunk %llu - [Rows %llu - %llu]\n", chunk_idx, row_count, row_count + chunk.size()) + + chunk.ToString(); + chunk_idx++; + row_count += chunk.size(); + } + + return result; +} + +void TupleDataCollection::Print() { + Printer::Print(ToString()); +} + +void TupleDataCollection::Verify() const { +#ifdef DEBUG + idx_t total_count = 0; + idx_t total_size = 0; + for (const auto &segment : segments) { + segment.Verify(); + total_count += segment.count; + total_size += segment.data_size; + } + D_ASSERT(total_count == this->count); + D_ASSERT(total_size == this->data_size); +#endif +} + +void TupleDataCollection::VerifyEverythingPinned() const { +#ifdef DEBUG + for (const auto &segment : segments) { + segment.VerifyEverythingPinned(); + } +#endif +} +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_iterator.cpp b/src/duckdb/src/common/types/row/tuple_data_iterator.cpp new file mode 100644 index 00000000..a209322a --- /dev/null +++ b/src/duckdb/src/common/types/row/tuple_data_iterator.cpp @@ -0,0 +1,96 @@ +#include "duckdb/common/types/row/tuple_data_iterator.hpp" + +#include "duckdb/common/types/row/tuple_data_allocator.hpp" + +namespace duckdb { + +TupleDataChunkIterator::TupleDataChunkIterator(TupleDataCollection &collection_p, TupleDataPinProperties properties_p, + bool init_heap) + : TupleDataChunkIterator(collection_p, properties_p, 0, collection_p.ChunkCount(), init_heap) { +} + +TupleDataChunkIterator::TupleDataChunkIterator(TupleDataCollection &collection_p, TupleDataPinProperties properties, + idx_t chunk_idx_from, idx_t chunk_idx_to, bool init_heap_p) + : collection(collection_p), init_heap(init_heap_p) { + state.pin_state.properties = properties; + D_ASSERT(chunk_idx_from < chunk_idx_to); + D_ASSERT(chunk_idx_to <= collection.ChunkCount()); + idx_t overall_chunk_index = 0; + for (idx_t segment_idx = 0; segment_idx < collection.segments.size(); segment_idx++) { + const auto &segment = collection.segments[segment_idx]; + if (chunk_idx_from >= overall_chunk_index && chunk_idx_from <= overall_chunk_index + segment.ChunkCount()) { + // We start in this segment + start_segment_idx = segment_idx; + start_chunk_idx = chunk_idx_from - overall_chunk_index; + } + if (chunk_idx_to >= overall_chunk_index && chunk_idx_to <= overall_chunk_index + segment.ChunkCount()) { + // We end in this segment + end_segment_idx = segment_idx; + end_chunk_idx = chunk_idx_to - overall_chunk_index; + } + overall_chunk_index += segment.ChunkCount(); + } + + Reset(); +} + +void TupleDataChunkIterator::InitializeCurrentChunk() { + auto &segment = collection.segments[current_segment_idx]; + segment.allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, current_chunk_idx, init_heap); +} + +bool TupleDataChunkIterator::Done() const { + return current_segment_idx == end_segment_idx && current_chunk_idx == end_chunk_idx; +} + +bool TupleDataChunkIterator::Next() { + D_ASSERT(!Done()); // Check if called after already done + + // Set the next indices and checks if we're at the end of the collection + // NextScanIndex can go past this iterators 'end', so we have to check the indices again + const auto segment_idx_before = current_segment_idx; + if (!collection.NextScanIndex(state, current_segment_idx, current_chunk_idx) || Done()) { + // Drop pins / stores them if TupleDataPinProperties::KEEP_EVERYTHING_PINNED + collection.FinalizePinState(state.pin_state, collection.segments[segment_idx_before]); + current_segment_idx = end_segment_idx; + current_chunk_idx = end_chunk_idx; + return false; + } + + // Finalize pin state when moving from one segment to the next + if (current_segment_idx != segment_idx_before) { + collection.FinalizePinState(state.pin_state, collection.segments[segment_idx_before]); + } + + InitializeCurrentChunk(); + return true; +} + +void TupleDataChunkIterator::Reset() { + state.segment_index = start_segment_idx; + state.chunk_index = start_chunk_idx; + collection.NextScanIndex(state, current_segment_idx, current_chunk_idx); + InitializeCurrentChunk(); +} + +idx_t TupleDataChunkIterator::GetCurrentChunkCount() const { + return collection.segments[current_segment_idx].chunks[current_chunk_idx].count; +} + +TupleDataChunkState &TupleDataChunkIterator::GetChunkState() { + return state.chunk_state; +} + +data_ptr_t *TupleDataChunkIterator::GetRowLocations() { + return FlatVector::GetData(state.chunk_state.row_locations); +} + +data_ptr_t *TupleDataChunkIterator::GetHeapLocations() { + return FlatVector::GetData(state.chunk_state.heap_locations); +} + +idx_t *TupleDataChunkIterator::GetHeapSizes() { + return FlatVector::GetData(state.chunk_state.heap_sizes); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_layout.cpp b/src/duckdb/src/common/types/row/tuple_data_layout.cpp new file mode 100644 index 00000000..3caa365b --- /dev/null +++ b/src/duckdb/src/common/types/row/tuple_data_layout.cpp @@ -0,0 +1,129 @@ +#include "duckdb/common/types/row/tuple_data_layout.hpp" + +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +TupleDataLayout::TupleDataLayout() + : flag_width(0), data_width(0), aggr_width(0), row_width(0), all_constant(true), heap_size_offset(0), + has_destructor(false) { +} + +TupleDataLayout TupleDataLayout::Copy() const { + TupleDataLayout result; + result.types = this->types; + result.aggregates = this->aggregates; + if (this->struct_layouts) { + result.struct_layouts = make_uniq>(); + for (const auto &entry : *this->struct_layouts) { + result.struct_layouts->emplace(entry.first, entry.second.Copy()); + } + } + result.flag_width = this->flag_width; + result.data_width = this->data_width; + result.aggr_width = this->aggr_width; + result.row_width = this->row_width; + result.offsets = this->offsets; + result.all_constant = this->all_constant; + result.heap_size_offset = this->heap_size_offset; + result.has_destructor = this->has_destructor; + return result; +} + +void TupleDataLayout::Initialize(vector types_p, Aggregates aggregates_p, bool align, bool heap_offset_p) { + offsets.clear(); + types = std::move(types_p); + + // Null mask at the front - 1 bit per value. + flag_width = ValidityBytes::ValidityMaskSize(types.size()); + row_width = flag_width; + + // Whether all columns are constant size. + for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { + const auto &type = types[col_idx]; + if (type.InternalType() == PhysicalType::STRUCT) { + // structs are recursively stored as a TupleDataLayout again + const auto &child_types = StructType::GetChildTypes(type); + vector child_type_vector; + child_type_vector.reserve(child_types.size()); + for (auto &ct : child_types) { + child_type_vector.emplace_back(ct.second); + } + if (!struct_layouts) { + struct_layouts = make_uniq>(); + } + auto struct_entry = struct_layouts->emplace(col_idx, TupleDataLayout()); + struct_entry.first->second.Initialize(std::move(child_type_vector), false, false); + all_constant = all_constant && struct_entry.first->second.AllConstant(); + } else { + all_constant = all_constant && TypeIsConstantSize(type.InternalType()); + } + } + + // This enables pointer swizzling for out-of-core computation. + if (heap_offset_p && !all_constant) { + heap_size_offset = row_width; + row_width += sizeof(uint32_t); + } + + // Data columns. No alignment required. + for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { + const auto &type = types[col_idx]; + offsets.push_back(row_width); + const auto internal_type = type.InternalType(); + if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { + row_width += GetTypeIdSize(type.InternalType()); + } else if (internal_type == PhysicalType::STRUCT) { + // Just get the size of the TupleDataLayout of the struct + row_width += GetStructLayout(col_idx).GetRowWidth(); + } else { + // Variable size types use pointers to the actual data (can be swizzled). + // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). + row_width += sizeof(idx_t); + } + } + + // Alignment padding for aggregates +#ifndef DUCKDB_ALLOW_UNDEFINED + if (align) { + row_width = AlignValue(row_width); + } +#endif + data_width = row_width - flag_width; + + // Aggregate fields. + aggregates = std::move(aggregates_p); + for (auto &aggregate : aggregates) { + offsets.push_back(row_width); + row_width += aggregate.payload_size; +#ifndef DUCKDB_ALLOW_UNDEFINED + D_ASSERT(aggregate.payload_size == AlignValue(aggregate.payload_size)); +#endif + } + aggr_width = row_width - data_width - flag_width; + + // Alignment padding for the next row +#ifndef DUCKDB_ALLOW_UNDEFINED + if (align) { + row_width = AlignValue(row_width); + } +#endif + + has_destructor = false; + for (auto &aggr : GetAggregates()) { + if (aggr.function.destructor) { + has_destructor = true; + break; + } + } +} + +void TupleDataLayout::Initialize(vector types_p, bool align, bool heap_offset_p) { + Initialize(std::move(types_p), Aggregates(), align, heap_offset_p); +} + +void TupleDataLayout::Initialize(Aggregates aggregates_p, bool align, bool heap_offset_p) { + Initialize(vector(), std::move(aggregates_p), align, heap_offset_p); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp new file mode 100644 index 00000000..849d5d48 --- /dev/null +++ b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp @@ -0,0 +1,1250 @@ +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/row/tuple_data_collection.hpp" + +namespace duckdb { + +using ValidityBytes = TupleDataLayout::ValidityBytes; + +template +static constexpr idx_t TupleDataWithinListFixedSize() { + return sizeof(T); +} + +template <> +constexpr idx_t TupleDataWithinListFixedSize() { + return sizeof(uint32_t); +} + +template +static inline void TupleDataValueStore(const T &source, const data_ptr_t &row_location, const idx_t offset_in_row, + data_ptr_t &heap_location) { + Store(source, row_location + offset_in_row); +} + +template <> +inline void TupleDataValueStore(const string_t &source, const data_ptr_t &row_location, const idx_t offset_in_row, + data_ptr_t &heap_location) { + if (source.IsInlined()) { + Store(source, row_location + offset_in_row); + } else { + memcpy(heap_location, source.GetData(), source.GetSize()); + Store(string_t(const_char_ptr_cast(heap_location), source.GetSize()), row_location + offset_in_row); + heap_location += source.GetSize(); + } +} + +template +static inline void TupleDataWithinListValueStore(const T &source, const data_ptr_t &location, + data_ptr_t &heap_location) { + Store(source, location); +} + +template <> +inline void TupleDataWithinListValueStore(const string_t &source, const data_ptr_t &location, + data_ptr_t &heap_location) { + Store(source.GetSize(), location); + memcpy(heap_location, source.GetData(), source.GetSize()); + heap_location += source.GetSize(); +} + +template +static inline T TupleDataWithinListValueLoad(const data_ptr_t &location, data_ptr_t &heap_location) { + return Load(location); +} + +template <> +inline string_t TupleDataWithinListValueLoad(const data_ptr_t &location, data_ptr_t &heap_location) { + const auto size = Load(location); + string_t result(const_char_ptr_cast(heap_location), size); + heap_location += size; + return result; +} + +#ifdef DEBUG +static void ResetCombinedListData(vector &vector_data) { + for (auto &vd : vector_data) { + vd.combined_list_data = nullptr; + ResetCombinedListData(vd.children); + } +} +#endif + +void TupleDataCollection::ComputeHeapSizes(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, + const SelectionVector &append_sel, const idx_t append_count) { +#ifdef DEBUG + ResetCombinedListData(chunk_state.vector_data); +#endif + + auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); + std::fill_n(heap_sizes, new_chunk.size(), 0); + + for (idx_t col_idx = 0; col_idx < new_chunk.ColumnCount(); col_idx++) { + auto &source_v = new_chunk.data[col_idx]; + auto &source_format = chunk_state.vector_data[col_idx]; + TupleDataCollection::ComputeHeapSizes(chunk_state.heap_sizes, source_v, source_format, append_sel, + append_count); + } +} + +static inline idx_t StringHeapSize(const string_t &val) { + return val.IsInlined() ? 0 : val.GetSize(); +} + +void TupleDataCollection::ComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, const SelectionVector &append_sel, + const idx_t append_count) { + const auto type = source_v.GetType().InternalType(); + if (type != PhysicalType::VARCHAR && type != PhysicalType::STRUCT && type != PhysicalType::LIST) { + return; + } + + auto heap_sizes = FlatVector::GetData(heap_sizes_v); + + const auto &source_vector_data = source_format.unified; + const auto &source_sel = *source_vector_data.sel; + const auto &source_validity = source_vector_data.validity; + + switch (type) { + case PhysicalType::VARCHAR: { + // Only non-inlined strings are stored in the heap + const auto source_data = UnifiedVectorFormat::GetData(source_vector_data); + for (idx_t i = 0; i < append_count; i++) { + const auto source_idx = source_sel.get_index(append_sel.get_index(i)); + if (source_validity.RowIsValid(source_idx)) { + heap_sizes[i] += StringHeapSize(source_data[source_idx]); + } else { + heap_sizes[i] += StringHeapSize(NullValue()); + } + } + break; + } + case PhysicalType::STRUCT: { + // Recurse through the struct children + auto &struct_sources = StructVector::GetEntries(source_v); + for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { + const auto &struct_source = struct_sources[struct_col_idx]; + auto &struct_format = source_format.children[struct_col_idx]; + TupleDataCollection::ComputeHeapSizes(heap_sizes_v, *struct_source, struct_format, append_sel, + append_count); + } + break; + } + case PhysicalType::LIST: { + // Lists are stored entirely in the heap + for (idx_t i = 0; i < append_count; i++) { + auto source_idx = source_sel.get_index(append_sel.get_index(i)); + if (source_validity.RowIsValid(source_idx)) { + heap_sizes[i] += sizeof(uint64_t); // Size of the list + } + } + + // Recurse + D_ASSERT(source_format.children.size() == 1); + auto &child_source_v = ListVector::GetEntry(source_v); + auto &child_format = source_format.children[0]; + TupleDataCollection::WithinListHeapComputeSizes(heap_sizes_v, child_source_v, child_format, append_sel, + append_count, source_vector_data); + break; + } + default: + throw NotImplementedException("ComputeHeapSizes for %s", EnumUtil::ToString(source_v.GetType().id())); + } +} + +void TupleDataCollection::WithinListHeapComputeSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const UnifiedVectorFormat &list_data) { + auto type = source_v.GetType().InternalType(); + if (TypeIsConstantSize(type)) { + TupleDataCollection::ComputeFixedWithinListHeapSizes(heap_sizes_v, source_v, source_format, append_sel, + append_count, list_data); + return; + } + + switch (type) { + case PhysicalType::VARCHAR: + TupleDataCollection::StringWithinListComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, + append_count, list_data); + break; + case PhysicalType::STRUCT: + TupleDataCollection::StructWithinListComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, + append_count, list_data); + break; + case PhysicalType::LIST: + TupleDataCollection::ListWithinListComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, + append_count, list_data); + break; + default: + throw NotImplementedException("WithinListHeapComputeSizes for %s", EnumUtil::ToString(source_v.GetType().id())); + } +} + +void TupleDataCollection::ComputeFixedWithinListHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const UnifiedVectorFormat &list_data) { + // List data + const auto list_sel = *list_data.sel; + const auto list_entries = UnifiedVectorFormat::GetData(list_data); + const auto &list_validity = list_data.validity; + + // Target + auto heap_sizes = FlatVector::GetData(heap_sizes_v); + + D_ASSERT(TypeIsConstantSize(source_v.GetType().InternalType())); + const auto type_size = GetTypeIdSize(source_v.GetType().InternalType()); + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; // Original list entry is invalid - no need to serialize the child + } + + // Get the current list length + const auto &list_length = list_entries[list_idx].length; + + // Size is validity mask and all values + auto &heap_size = heap_sizes[i]; + heap_size += ValidityBytes::SizeInBytes(list_length); + heap_size += list_length * type_size; + } +} + +void TupleDataCollection::StringWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const UnifiedVectorFormat &list_data) { + // Source + const auto &source_data = source_format.unified; + const auto &source_sel = *source_data.sel; + const auto data = UnifiedVectorFormat::GetData(source_data); + const auto &source_validity = source_data.validity; + + // List data + const auto list_sel = *list_data.sel; + const auto list_entries = UnifiedVectorFormat::GetData(list_data); + const auto &list_validity = list_data.validity; + + // Target + auto heap_sizes = FlatVector::GetData(heap_sizes_v); + + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; // Original list entry is invalid - no need to serialize the child + } + + // Get the current list entry + const auto &list_entry = list_entries[list_idx]; + const auto &list_offset = list_entry.offset; + const auto &list_length = list_entry.length; + + // Size is validity mask and all string sizes + auto &heap_size = heap_sizes[i]; + heap_size += ValidityBytes::SizeInBytes(list_length); + heap_size += list_length * TupleDataWithinListFixedSize(); + + // Plus all the actual strings + for (idx_t child_i = 0; child_i < list_length; child_i++) { + const auto child_source_idx = source_sel.get_index(list_offset + child_i); + if (source_validity.RowIsValid(child_source_idx)) { + heap_size += data[child_source_idx].GetSize(); + } + } + } +} + +void TupleDataCollection::StructWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const UnifiedVectorFormat &list_data) { + // List data + const auto list_sel = *list_data.sel; + const auto list_entries = UnifiedVectorFormat::GetData(list_data); + const auto &list_validity = list_data.validity; + + // Target + auto heap_sizes = FlatVector::GetData(heap_sizes_v); + + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; // Original list entry is invalid - no need to serialize the child + } + + // Get the current list length + const auto &list_length = list_entries[list_idx].length; + + // Size is just the validity mask + heap_sizes[i] += ValidityBytes::SizeInBytes(list_length); + } + + // Recurse + auto &struct_sources = StructVector::GetEntries(source_v); + for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { + auto &struct_source = *struct_sources[struct_col_idx]; + auto &struct_format = source_format.children[struct_col_idx]; + TupleDataCollection::WithinListHeapComputeSizes(heap_sizes_v, struct_source, struct_format, append_sel, + append_count, list_data); + } +} + +static void ApplySliceRecursive(const Vector &source_v, TupleDataVectorFormat &source_format, + const SelectionVector &combined_sel, const idx_t count) { + D_ASSERT(source_format.combined_list_data); + auto &combined_list_data = *source_format.combined_list_data; + + combined_list_data.selection_data = source_format.original_sel->Slice(combined_sel, count); + source_format.unified.owned_sel.Initialize(combined_list_data.selection_data); + source_format.unified.sel = &source_format.unified.owned_sel; + + if (source_v.GetType().InternalType() == PhysicalType::STRUCT) { + // We have to apply it to the child vectors too + auto &struct_sources = StructVector::GetEntries(source_v); + for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { + auto &struct_source = *struct_sources[struct_col_idx]; + auto &struct_format = source_format.children[struct_col_idx]; +#ifdef DEBUG + D_ASSERT(!struct_format.combined_list_data); +#endif + if (!struct_format.combined_list_data) { + struct_format.combined_list_data = make_uniq(); + } + ApplySliceRecursive(struct_source, struct_format, *source_format.unified.sel, count); + } + } +} + +void TupleDataCollection::ListWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const UnifiedVectorFormat &list_data) { + // List data (of the list Vector that "source_v" is in) + const auto list_sel = *list_data.sel; + const auto list_entries = UnifiedVectorFormat::GetData(list_data); + const auto &list_validity = list_data.validity; + + // Child list ("source_v") + const auto &child_list_data = source_format.unified; + const auto child_list_sel = *child_list_data.sel; + const auto child_list_entries = UnifiedVectorFormat::GetData(child_list_data); + const auto &child_list_validity = child_list_data.validity; + + // Figure out actual child list size (can differ from ListVector::GetListSize if dict/const vector), + // and we cannot use ConstantVector::ZeroSelectionVector because it may need to be longer than STANDARD_VECTOR_SIZE + idx_t sum_of_sizes = 0; + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; + } + const auto &list_entry = list_entries[list_idx]; + const auto &list_offset = list_entry.offset; + const auto &list_length = list_entry.length; + + for (idx_t child_i = 0; child_i < list_length; child_i++) { + const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); + if (!child_list_validity.RowIsValid(child_list_idx)) { + continue; + } + + const auto &child_list_entry = child_list_entries[child_list_idx]; + const auto &child_list_length = child_list_entry.length; + + sum_of_sizes += child_list_length; + } + } + const auto child_list_child_count = MaxValue(sum_of_sizes, ListVector::GetListSize(source_v)); + + // Target + auto heap_sizes = FlatVector::GetData(heap_sizes_v); + + // Construct combined list entries and a selection vector for the child list child + auto &child_format = source_format.children[0]; +#ifdef DEBUG + // In debug mode this should be deleted by ResetCombinedListData + D_ASSERT(!child_format.combined_list_data); +#endif + if (!child_format.combined_list_data) { + child_format.combined_list_data = make_uniq(); + } + auto &combined_list_data = *child_format.combined_list_data; + auto &combined_list_entries = combined_list_data.combined_list_entries; + SelectionVector combined_sel(child_list_child_count); + for (idx_t i = 0; i < child_list_child_count; i++) { + combined_sel.set_index(i, 0); + } + + idx_t combined_list_offset = 0; + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; // Original list entry is invalid - no need to serialize the child list + } + + // Get the current list entry + const auto &list_entry = list_entries[list_idx]; + const auto &list_offset = list_entry.offset; + const auto &list_length = list_entry.length; + + // Size is the validity mask and the list sizes + auto &heap_size = heap_sizes[i]; + heap_size += ValidityBytes::SizeInBytes(list_length); + heap_size += list_length * sizeof(uint64_t); + + idx_t child_list_size = 0; + for (idx_t child_i = 0; child_i < list_length; child_i++) { + const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); + const auto &child_list_entry = child_list_entries[child_list_idx]; + if (child_list_validity.RowIsValid(child_list_idx)) { + const auto &child_list_offset = child_list_entry.offset; + const auto &child_list_length = child_list_entry.length; + + // Add this child's list entries to the combined selection vector + for (idx_t child_value_i = 0; child_value_i < child_list_length; child_value_i++) { + auto idx = combined_list_offset + child_list_size + child_value_i; + auto loc = child_list_offset + child_value_i; + combined_sel.set_index(idx, loc); + } + + child_list_size += child_list_length; + } + } + + // Combine the child list entries into one + combined_list_entries[list_idx] = {combined_list_offset, child_list_size}; + combined_list_offset += child_list_size; + } + + // Create a combined child_list_data to be used as list_data in the recursion + auto &combined_child_list_data = combined_list_data.combined_data; + combined_child_list_data.sel = list_data.sel; + combined_child_list_data.data = data_ptr_cast(combined_list_entries); + combined_child_list_data.validity = list_data.validity; + + // Combine the selection vectors + D_ASSERT(source_format.children.size() == 1); + auto &child_source = ListVector::GetEntry(source_v); + ApplySliceRecursive(child_source, child_format, combined_sel, child_list_child_count); + + // Recurse + TupleDataCollection::WithinListHeapComputeSizes(heap_sizes_v, child_source, child_format, append_sel, append_count, + combined_child_list_data); +} + +void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, + const SelectionVector &append_sel, const idx_t append_count) const { + const auto row_locations = FlatVector::GetData(chunk_state.row_locations); + + // Set the validity mask for each row before inserting data + const auto validity_bytes = ValidityBytes::SizeInBytes(layout.ColumnCount()); + for (idx_t i = 0; i < append_count; i++) { + FastMemset(row_locations[i], ~0, validity_bytes); + } + + if (!layout.AllConstant()) { + // Set the heap size for each row + const auto heap_size_offset = layout.GetHeapSizeOffset(); + const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); + for (idx_t i = 0; i < append_count; i++) { + Store(heap_sizes[i], row_locations[i] + heap_size_offset); + } + } + + // Write the data + for (const auto &col_idx : chunk_state.column_ids) { + Scatter(chunk_state, new_chunk.data[col_idx], col_idx, append_sel, append_count); + } +} + +void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const Vector &source, const column_t column_id, + const SelectionVector &append_sel, const idx_t append_count) const { + const auto &scatter_function = scatter_functions[column_id]; + scatter_function.function(source, chunk_state.vector_data[column_id], append_sel, append_count, layout, + chunk_state.row_locations, chunk_state.heap_locations, column_id, + chunk_state.vector_data[column_id].unified, scatter_function.child_functions); +} + +template +static void TupleDataTemplatedScatter(const Vector &source, const TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const TupleDataLayout &layout, const Vector &row_locations, + Vector &heap_locations, const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, + const vector &child_functions) { + // Source + const auto &source_data = source_format.unified; + const auto &source_sel = *source_data.sel; + const auto data = UnifiedVectorFormat::GetData(source_data); + const auto &validity = source_data.validity; + + // Target + auto target_locations = FlatVector::GetData(row_locations); + auto target_heap_locations = FlatVector::GetData(heap_locations); + + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + const auto offset_in_row = layout.GetOffsets()[col_idx]; + if (validity.AllValid()) { + for (idx_t i = 0; i < append_count; i++) { + const auto source_idx = source_sel.get_index(append_sel.get_index(i)); + TupleDataValueStore(data[source_idx], target_locations[i], offset_in_row, target_heap_locations[i]); + } + } else { + for (idx_t i = 0; i < append_count; i++) { + const auto source_idx = source_sel.get_index(append_sel.get_index(i)); + if (validity.RowIsValid(source_idx)) { + TupleDataValueStore(data[source_idx], target_locations[i], offset_in_row, target_heap_locations[i]); + } else { + TupleDataValueStore(NullValue(), target_locations[i], offset_in_row, target_heap_locations[i]); + ValidityBytes(target_locations[i]).SetInvalidUnsafe(entry_idx, idx_in_entry); + } + } + } +} + +static void TupleDataStructScatter(const Vector &source, const TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, + const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, + const vector &child_functions) { + // Source + const auto &source_data = source_format.unified; + const auto &source_sel = *source_data.sel; + const auto &validity = source_data.validity; + + // Target + auto target_locations = FlatVector::GetData(row_locations); + + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + // Set validity of the STRUCT in this layout + if (!validity.AllValid()) { + for (idx_t i = 0; i < append_count; i++) { + const auto source_idx = source_sel.get_index(append_sel.get_index(i)); + if (!validity.RowIsValid(source_idx)) { + ValidityBytes(target_locations[i]).SetInvalidUnsafe(entry_idx, idx_in_entry); + } + } + } + + // Create a Vector of pointers to the TupleDataLayout of the STRUCT + Vector struct_row_locations(LogicalType::POINTER, append_count); + auto struct_target_locations = FlatVector::GetData(struct_row_locations); + const auto offset_in_row = layout.GetOffsets()[col_idx]; + for (idx_t i = 0; i < append_count; i++) { + struct_target_locations[i] = target_locations[i] + offset_in_row; + } + + const auto &struct_layout = layout.GetStructLayout(col_idx); + auto &struct_sources = StructVector::GetEntries(source); + D_ASSERT(struct_layout.ColumnCount() == struct_sources.size()); + + // Set the validity of the entries within the STRUCTs + const auto validity_bytes = ValidityBytes::SizeInBytes(struct_layout.ColumnCount()); + for (idx_t i = 0; i < append_count; i++) { + memset(struct_target_locations[i], ~0, validity_bytes); + } + + // Recurse through the struct children + for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { + auto &struct_source = *struct_sources[struct_col_idx]; + const auto &struct_source_format = source_format.children[struct_col_idx]; + const auto &struct_scatter_function = child_functions[struct_col_idx]; + struct_scatter_function.function(struct_source, struct_source_format, append_sel, append_count, struct_layout, + struct_row_locations, heap_locations, struct_col_idx, dummy_arg, + struct_scatter_function.child_functions); + } +} + +static void TupleDataListScatter(const Vector &source, const TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, + const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, + const vector &child_functions) { + // Source + const auto &source_data = source_format.unified; + const auto &source_sel = *source_data.sel; + const auto data = UnifiedVectorFormat::GetData(source_data); + const auto &validity = source_data.validity; + + // Target + auto target_locations = FlatVector::GetData(row_locations); + auto target_heap_locations = FlatVector::GetData(heap_locations); + + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + // Set validity of the LIST in this layout, and store pointer to where it's stored + const auto offset_in_row = layout.GetOffsets()[col_idx]; + for (idx_t i = 0; i < append_count; i++) { + const auto source_idx = source_sel.get_index(append_sel.get_index(i)); + if (validity.RowIsValid(source_idx)) { + auto &target_heap_location = target_heap_locations[i]; + Store(target_heap_location, target_locations[i] + offset_in_row); + + // Store list length and skip over it + Store(data[source_idx].length, target_heap_location); + target_heap_location += sizeof(uint64_t); + } else { + ValidityBytes(target_locations[i]).SetInvalidUnsafe(entry_idx, idx_in_entry); + } + } + + // Recurse + D_ASSERT(child_functions.size() == 1); + auto &child_source = ListVector::GetEntry(source); + auto &child_format = source_format.children[0]; + const auto &child_function = child_functions[0]; + child_function.function(child_source, child_format, append_sel, append_count, layout, row_locations, heap_locations, + col_idx, source_format.unified, child_function.child_functions); +} + +template +static void TupleDataTemplatedWithinListScatter(const Vector &source, const TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const TupleDataLayout &layout, const Vector &row_locations, + Vector &heap_locations, const idx_t col_idx, + const UnifiedVectorFormat &list_data, + const vector &child_functions) { + // Source + const auto &source_data = source_format.unified; + const auto &source_sel = *source_data.sel; + const auto data = UnifiedVectorFormat::GetData(source_data); + const auto &source_validity = source_data.validity; + + // List data + const auto list_sel = *list_data.sel; + const auto list_entries = UnifiedVectorFormat::GetData(list_data); + const auto &list_validity = list_data.validity; + + // Target + auto target_heap_locations = FlatVector::GetData(heap_locations); + + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; // Original list entry is invalid - no need to serialize the child + } + + // Get the current list entry + const auto &list_entry = list_entries[list_idx]; + const auto &list_offset = list_entry.offset; + const auto &list_length = list_entry.length; + + // Initialize validity mask and skip heap pointer over it + auto &target_heap_location = target_heap_locations[i]; + ValidityBytes child_mask(target_heap_location); + child_mask.SetAllValid(list_length); + target_heap_location += ValidityBytes::SizeInBytes(list_length); + + // Get the start to the fixed-size data and skip the heap pointer over it + const auto child_data_location = target_heap_location; + target_heap_location += list_length * TupleDataWithinListFixedSize(); + + // Store the data and validity belonging to this list entry + for (idx_t child_i = 0; child_i < list_length; child_i++) { + const auto child_source_idx = source_sel.get_index(list_offset + child_i); + if (source_validity.RowIsValid(child_source_idx)) { + TupleDataWithinListValueStore(data[child_source_idx], + child_data_location + child_i * TupleDataWithinListFixedSize(), + target_heap_location); + } else { + child_mask.SetInvalidUnsafe(child_i); + } + } + } +} + +static void TupleDataStructWithinListScatter(const Vector &source, const TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const TupleDataLayout &layout, const Vector &row_locations, + Vector &heap_locations, const idx_t col_idx, + const UnifiedVectorFormat &list_data, + const vector &child_functions) { + // Source + const auto &source_data = source_format.unified; + const auto &source_sel = *source_data.sel; + const auto &source_validity = source_data.validity; + + // List data + const auto list_sel = *list_data.sel; + const auto list_entries = UnifiedVectorFormat::GetData(list_data); + const auto &list_validity = list_data.validity; + + // Target + auto target_heap_locations = FlatVector::GetData(heap_locations); + + // Initialize the validity of the STRUCTs + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; // Original list entry is invalid - no need to serialize the child + } + + // Get the current list entry + const auto &list_entry = list_entries[list_idx]; + const auto &list_offset = list_entry.offset; + const auto &list_length = list_entry.length; + + // Initialize validity mask and skip the heap pointer over it + auto &target_heap_location = target_heap_locations[i]; + ValidityBytes child_mask(target_heap_location); + child_mask.SetAllValid(list_length); + target_heap_location += ValidityBytes::SizeInBytes(list_length); + + // Store the validity belonging to this list entry + for (idx_t child_i = 0; child_i < list_length; child_i++) { + const auto child_source_idx = source_sel.get_index(list_offset + child_i); + if (!source_validity.RowIsValid(child_source_idx)) { + child_mask.SetInvalidUnsafe(child_i); + } + } + } + + // Recurse through the children + auto &struct_sources = StructVector::GetEntries(source); + for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { + auto &struct_source = *struct_sources[struct_col_idx]; + auto &struct_format = source_format.children[struct_col_idx]; + const auto &struct_scatter_function = child_functions[struct_col_idx]; + struct_scatter_function.function(struct_source, struct_format, append_sel, append_count, layout, row_locations, + heap_locations, struct_col_idx, list_data, + struct_scatter_function.child_functions); + } +} + +static void TupleDataListWithinListScatter(const Vector &child_list, const TupleDataVectorFormat &child_list_format, + const SelectionVector &append_sel, const idx_t append_count, + const TupleDataLayout &layout, const Vector &row_locations, + Vector &heap_locations, const idx_t col_idx, + const UnifiedVectorFormat &list_data, + const vector &child_functions) { + // List data (of the list Vector that "child_list" is in) + const auto list_sel = *list_data.sel; + const auto list_entries = UnifiedVectorFormat::GetData(list_data); + const auto &list_validity = list_data.validity; + + // Child list + const auto &child_list_data = child_list_format.unified; + const auto child_list_sel = *child_list_data.sel; + const auto child_list_entries = UnifiedVectorFormat::GetData(child_list_data); + const auto &child_list_validity = child_list_data.validity; + + // Target + auto target_heap_locations = FlatVector::GetData(heap_locations); + + for (idx_t i = 0; i < append_count; i++) { + const auto list_idx = list_sel.get_index(append_sel.get_index(i)); + if (!list_validity.RowIsValid(list_idx)) { + continue; // Original list entry is invalid - no need to serialize the child list + } + + // Get the current list entry + const auto &list_entry = list_entries[list_idx]; + const auto &list_offset = list_entry.offset; + const auto &list_length = list_entry.length; + + // Initialize validity mask and skip heap pointer over it + auto &target_heap_location = target_heap_locations[i]; + ValidityBytes child_mask(target_heap_location); + child_mask.SetAllValid(list_length); + target_heap_location += ValidityBytes::SizeInBytes(list_length); + + // Get the start to the fixed-size data and skip the heap pointer over it + const auto child_data_location = target_heap_location; + target_heap_location += list_length * sizeof(uint64_t); + + for (idx_t child_i = 0; child_i < list_length; child_i++) { + const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); + if (child_list_validity.RowIsValid(child_list_idx)) { + const auto &child_list_length = child_list_entries[child_list_idx].length; + Store(child_list_length, child_data_location + child_i * sizeof(uint64_t)); + } else { + child_mask.SetInvalidUnsafe(child_i); + } + } + } + + // Recurse + D_ASSERT(child_functions.size() == 1); + auto &child_vec = ListVector::GetEntry(child_list); + auto &child_format = child_list_format.children[0]; + auto &combined_child_list_data = child_format.combined_list_data->combined_data; + const auto &child_function = child_functions[0]; + child_function.function(child_vec, child_format, append_sel, append_count, layout, row_locations, heap_locations, + col_idx, combined_child_list_data, child_function.child_functions); +} + +template +tuple_data_scatter_function_t TupleDataGetScatterFunction(bool within_list) { + return within_list ? TupleDataTemplatedWithinListScatter : TupleDataTemplatedScatter; +} + +TupleDataScatterFunction TupleDataCollection::GetScatterFunction(const LogicalType &type, bool within_list) { + TupleDataScatterFunction result; + switch (type.InternalType()) { + case PhysicalType::BOOL: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::INT8: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::INT16: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::INT32: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::INT64: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::INT128: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::UINT8: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::UINT16: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::UINT32: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::UINT64: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::FLOAT: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::DOUBLE: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::INTERVAL: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::VARCHAR: + result.function = TupleDataGetScatterFunction(within_list); + break; + case PhysicalType::STRUCT: { + result.function = within_list ? TupleDataStructWithinListScatter : TupleDataStructScatter; + for (const auto &child_type : StructType::GetChildTypes(type)) { + result.child_functions.push_back(GetScatterFunction(child_type.second, within_list)); + } + break; + } + case PhysicalType::LIST: + result.function = within_list ? TupleDataListWithinListScatter : TupleDataListScatter; + result.child_functions.emplace_back(GetScatterFunction(ListType::GetChildType(type), true)); + break; + default: + throw InternalException("Unsupported type for TupleDataCollection::GetScatterFunction"); + } + return result; +} + +void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, + DataChunk &result, const SelectionVector &target_sel) const { + D_ASSERT(result.ColumnCount() == layout.ColumnCount()); + vector column_ids; + column_ids.reserve(layout.ColumnCount()); + for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { + column_ids.emplace_back(col_idx); + } + Gather(row_locations, scan_sel, scan_count, column_ids, result, target_sel); +} + +void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, + const vector &column_ids, DataChunk &result, + const SelectionVector &target_sel) const { + for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { + Gather(row_locations, scan_sel, scan_count, column_ids[col_idx], result.data[col_idx], target_sel); + } +} + +void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, + const column_t column_id, Vector &result, const SelectionVector &target_sel) const { + const auto &gather_function = gather_functions[column_id]; + gather_function.function(layout, row_locations, column_id, scan_sel, scan_count, result, target_sel, result, + gather_function.child_functions); +} + +template +static void TupleDataTemplatedGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, + const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, + const SelectionVector &target_sel, Vector &dummy_vector, + const vector &child_functions) { + // Source + auto source_locations = FlatVector::GetData(row_locations); + + // Target + auto target_data = FlatVector::GetData(target); + auto &target_validity = FlatVector::Validity(target); + + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + const auto offset_in_row = layout.GetOffsets()[col_idx]; + for (idx_t i = 0; i < scan_count; i++) { + const auto &source_row = source_locations[scan_sel.get_index(i)]; + const auto target_idx = target_sel.get_index(i); + ValidityBytes row_mask(source_row); + if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { + target_data[target_idx] = Load(source_row + offset_in_row); + } else { + target_validity.SetInvalid(target_idx); + } + } +} + +static void TupleDataStructGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, + const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, + const SelectionVector &target_sel, Vector &dummy_vector, + const vector &child_functions) { + // Source + auto source_locations = FlatVector::GetData(row_locations); + + // Target + auto &target_validity = FlatVector::Validity(target); + + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + // Get validity of the struct and create a Vector of pointers to the start of the TupleDataLayout of the STRUCT + Vector struct_row_locations(LogicalType::POINTER); + auto struct_source_locations = FlatVector::GetData(struct_row_locations); + const auto offset_in_row = layout.GetOffsets()[col_idx]; + for (idx_t i = 0; i < scan_count; i++) { + const auto source_idx = scan_sel.get_index(i); + const auto &source_row = source_locations[source_idx]; + + // Set the validity + ValidityBytes row_mask(source_row); + if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { + const auto target_idx = target_sel.get_index(i); + target_validity.SetInvalid(target_idx); + } + + // Set the pointer + struct_source_locations[source_idx] = source_row + offset_in_row; + } + + // Get the struct layout and struct entries + const auto &struct_layout = layout.GetStructLayout(col_idx); + auto &struct_targets = StructVector::GetEntries(target); + D_ASSERT(struct_layout.ColumnCount() == struct_targets.size()); + + // Recurse through the struct children + for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { + auto &struct_target = *struct_targets[struct_col_idx]; + const auto &struct_gather_function = child_functions[struct_col_idx]; + struct_gather_function.function(struct_layout, struct_row_locations, struct_col_idx, scan_sel, scan_count, + struct_target, target_sel, dummy_vector, + struct_gather_function.child_functions); + } +} + +static void TupleDataListGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, + const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, + const SelectionVector &target_sel, Vector &dummy_vector, + const vector &child_functions) { + // Source + auto source_locations = FlatVector::GetData(row_locations); + + // Target + auto target_list_entries = FlatVector::GetData(target); + auto &target_validity = FlatVector::Validity(target); + + // Precompute mask indexes + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + // Load pointers to the data from the row + Vector heap_locations(LogicalType::POINTER); + auto source_heap_locations = FlatVector::GetData(heap_locations); + auto &source_heap_validity = FlatVector::Validity(heap_locations); + + const auto offset_in_row = layout.GetOffsets()[col_idx]; + uint64_t target_list_offset = 0; + for (idx_t i = 0; i < scan_count; i++) { + const auto source_idx = scan_sel.get_index(i); + const auto target_idx = target_sel.get_index(i); + + const auto &source_row = source_locations[source_idx]; + ValidityBytes row_mask(source_row); + if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { + auto &source_heap_location = source_heap_locations[source_idx]; + source_heap_location = Load(source_row + offset_in_row); + + // Load list size and skip over + const auto list_length = Load(source_heap_location); + source_heap_location += sizeof(uint64_t); + + // Initialize list entry, and increment offset + target_list_entries[target_idx] = {target_list_offset, list_length}; + target_list_offset += list_length; + } else { + source_heap_validity.SetInvalid(source_idx); + target_validity.SetInvalid(target_idx); + } + } + auto list_size_before = ListVector::GetListSize(target); + ListVector::Reserve(target, list_size_before + target_list_offset); + ListVector::SetListSize(target, list_size_before + target_list_offset); + + // Recurse + D_ASSERT(child_functions.size() == 1); + const auto &child_function = child_functions[0]; + child_function.function(layout, heap_locations, list_size_before, scan_sel, scan_count, + ListVector::GetEntry(target), target_sel, target, child_function.child_functions); +} + +template +static void TupleDataTemplatedWithinListGather(const TupleDataLayout &layout, Vector &heap_locations, + const idx_t list_size_before, const SelectionVector &scan_sel, + const idx_t scan_count, Vector &target, + const SelectionVector &target_sel, Vector &list_vector, + const vector &child_functions) { + // Source + auto source_heap_locations = FlatVector::GetData(heap_locations); + auto &source_heap_validity = FlatVector::Validity(heap_locations); + + // Target + auto target_data = FlatVector::GetData(target); + auto &target_validity = FlatVector::Validity(target); + + // List parent + const auto list_entries = FlatVector::GetData(list_vector); + + uint64_t target_offset = list_size_before; + for (idx_t i = 0; i < scan_count; i++) { + const auto source_idx = scan_sel.get_index(i); + if (!source_heap_validity.RowIsValid(source_idx)) { + continue; + } + + const auto &list_length = list_entries[target_sel.get_index(i)].length; + + // Initialize validity mask + auto &source_heap_location = source_heap_locations[source_idx]; + ValidityBytes source_mask(source_heap_location); + source_heap_location += ValidityBytes::SizeInBytes(list_length); + + // Get the start to the fixed-size data and skip the heap pointer over it + const auto source_data_location = source_heap_location; + source_heap_location += list_length * TupleDataWithinListFixedSize(); + + // Load the child validity and data belonging to this list entry + for (idx_t child_i = 0; child_i < list_length; child_i++) { + if (source_mask.RowIsValidUnsafe(child_i)) { + target_data[target_offset + child_i] = TupleDataWithinListValueLoad( + source_data_location + child_i * TupleDataWithinListFixedSize(), source_heap_location); + } else { + target_validity.SetInvalid(target_offset + child_i); + } + } + target_offset += list_length; + } +} + +static void TupleDataStructWithinListGather(const TupleDataLayout &layout, Vector &heap_locations, + const idx_t list_size_before, const SelectionVector &scan_sel, + const idx_t scan_count, Vector &target, const SelectionVector &target_sel, + Vector &list_vector, + const vector &child_functions) { + // Source + auto source_heap_locations = FlatVector::GetData(heap_locations); + auto &source_heap_validity = FlatVector::Validity(heap_locations); + + // Target + auto &target_validity = FlatVector::Validity(target); + + // List parent + const auto list_entries = FlatVector::GetData(list_vector); + + uint64_t target_offset = list_size_before; + for (idx_t i = 0; i < scan_count; i++) { + const auto source_idx = scan_sel.get_index(i); + if (!source_heap_validity.RowIsValid(source_idx)) { + continue; + } + + const auto &list_length = list_entries[target_sel.get_index(i)].length; + + // Initialize validity mask and skip over it + auto &source_heap_location = source_heap_locations[source_idx]; + ValidityBytes source_mask(source_heap_location); + source_heap_location += ValidityBytes::SizeInBytes(list_length); + + // Load the child validity belonging to this list entry + for (idx_t child_i = 0; child_i < list_length; child_i++) { + if (!source_mask.RowIsValidUnsafe(child_i)) { + target_validity.SetInvalid(target_offset + child_i); + } + } + target_offset += list_length; + } + + // Recurse + auto &struct_targets = StructVector::GetEntries(target); + for (idx_t struct_col_idx = 0; struct_col_idx < struct_targets.size(); struct_col_idx++) { + auto &struct_target = *struct_targets[struct_col_idx]; + const auto &struct_gather_function = child_functions[struct_col_idx]; + struct_gather_function.function(layout, heap_locations, list_size_before, scan_sel, scan_count, struct_target, + target_sel, list_vector, struct_gather_function.child_functions); + } +} + +static void TupleDataListWithinListGather(const TupleDataLayout &layout, Vector &heap_locations, + const idx_t list_size_before, const SelectionVector &scan_sel, + const idx_t scan_count, Vector &target, const SelectionVector &target_sel, + Vector &list_vector, const vector &child_functions) { + // Source + auto source_heap_locations = FlatVector::GetData(heap_locations); + auto &source_heap_validity = FlatVector::Validity(heap_locations); + + // Target + auto target_list_entries = FlatVector::GetData(target); + auto &target_validity = FlatVector::Validity(target); + const auto child_list_size_before = ListVector::GetListSize(target); + + // List parent + const auto list_entries = FlatVector::GetData(list_vector); + + // We need to create a vector that has the combined list sizes (hugeint_t has same size as list_entry_t) + Vector combined_list_vector(LogicalType::HUGEINT); + auto combined_list_entries = FlatVector::GetData(combined_list_vector); + + uint64_t target_offset = list_size_before; + uint64_t target_child_offset = child_list_size_before; + for (idx_t i = 0; i < scan_count; i++) { + const auto source_idx = scan_sel.get_index(i); + if (!source_heap_validity.RowIsValid(source_idx)) { + continue; + } + + const auto &list_length = list_entries[target_sel.get_index(i)].length; + + // Initialize validity mask and skip over it + auto &source_heap_location = source_heap_locations[source_idx]; + ValidityBytes source_mask(source_heap_location); + source_heap_location += ValidityBytes::SizeInBytes(list_length); + + // Get the start to the fixed-size data and skip the heap pointer over it + const auto source_data_location = source_heap_location; + source_heap_location += list_length * sizeof(uint64_t); + + // Set the offset of the combined list entry + auto &combined_list_entry = combined_list_entries[target_sel.get_index(i)]; + combined_list_entry.offset = target_child_offset; + + // Load the child validity and data belonging to this list entry + for (idx_t child_i = 0; child_i < list_length; child_i++) { + if (source_mask.RowIsValidUnsafe(child_i)) { + auto &target_list_entry = target_list_entries[target_offset + child_i]; + target_list_entry.offset = target_child_offset; + target_list_entry.length = Load(source_data_location + child_i * sizeof(uint64_t)); + target_child_offset += target_list_entry.length; + } else { + target_validity.SetInvalid(target_offset + child_i); + } + } + + // Set the length of the combined list entry + combined_list_entry.length = target_child_offset - combined_list_entry.offset; + + target_offset += list_length; + } + ListVector::Reserve(target, target_child_offset); + ListVector::SetListSize(target, target_child_offset); + + // Recurse + D_ASSERT(child_functions.size() == 1); + const auto &child_function = child_functions[0]; + child_function.function(layout, heap_locations, child_list_size_before, scan_sel, scan_count, + ListVector::GetEntry(target), target_sel, combined_list_vector, + child_function.child_functions); +} + +template +tuple_data_gather_function_t TupleDataGetGatherFunction(bool within_list) { + return within_list ? TupleDataTemplatedWithinListGather : TupleDataTemplatedGather; +} + +TupleDataGatherFunction TupleDataCollection::GetGatherFunction(const LogicalType &type, bool within_list) { + TupleDataGatherFunction result; + switch (type.InternalType()) { + case PhysicalType::BOOL: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::INT8: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::INT16: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::INT32: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::INT64: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::INT128: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::UINT8: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::UINT16: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::UINT32: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::UINT64: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::FLOAT: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::DOUBLE: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::INTERVAL: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::VARCHAR: + result.function = TupleDataGetGatherFunction(within_list); + break; + case PhysicalType::STRUCT: { + result.function = within_list ? TupleDataStructWithinListGather : TupleDataStructGather; + for (const auto &child_type : StructType::GetChildTypes(type)) { + result.child_functions.push_back(GetGatherFunction(child_type.second, within_list)); + } + break; + } + case PhysicalType::LIST: + result.function = within_list ? TupleDataListWithinListGather : TupleDataListGather; + result.child_functions.push_back(GetGatherFunction(ListType::GetChildType(type), true)); + break; + default: + throw InternalException("Unsupported type for TupleDataCollection::GetGatherFunction"); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_segment.cpp b/src/duckdb/src/common/types/row/tuple_data_segment.cpp new file mode 100644 index 00000000..d84da30f --- /dev/null +++ b/src/duckdb/src/common/types/row/tuple_data_segment.cpp @@ -0,0 +1,175 @@ +#include "duckdb/common/types/row/tuple_data_segment.hpp" + +#include "duckdb/common/types/row/tuple_data_allocator.hpp" + +namespace duckdb { + +TupleDataChunkPart::TupleDataChunkPart(mutex &lock_p) : lock(lock_p) { +} + +void SwapTupleDataChunkPart(TupleDataChunkPart &a, TupleDataChunkPart &b) { + std::swap(a.row_block_index, b.row_block_index); + std::swap(a.row_block_offset, b.row_block_offset); + std::swap(a.heap_block_index, b.heap_block_index); + std::swap(a.heap_block_offset, b.heap_block_offset); + std::swap(a.base_heap_ptr, b.base_heap_ptr); + std::swap(a.total_heap_size, b.total_heap_size); + std::swap(a.count, b.count); + std::swap(a.lock, b.lock); +} + +TupleDataChunkPart::TupleDataChunkPart(TupleDataChunkPart &&other) noexcept : lock((other.lock)) { + SwapTupleDataChunkPart(*this, other); +} + +TupleDataChunkPart &TupleDataChunkPart::operator=(TupleDataChunkPart &&other) noexcept { + SwapTupleDataChunkPart(*this, other); + return *this; +} + +TupleDataChunk::TupleDataChunk() : count(0), lock(make_unsafe_uniq()) { + parts.reserve(2); +} + +static inline void SwapTupleDataChunk(TupleDataChunk &a, TupleDataChunk &b) noexcept { + std::swap(a.parts, b.parts); + std::swap(a.row_block_ids, b.row_block_ids); + std::swap(a.heap_block_ids, b.heap_block_ids); + std::swap(a.count, b.count); + std::swap(a.lock, b.lock); +} + +TupleDataChunk::TupleDataChunk(TupleDataChunk &&other) noexcept { + SwapTupleDataChunk(*this, other); +} + +TupleDataChunk &TupleDataChunk::operator=(TupleDataChunk &&other) noexcept { + SwapTupleDataChunk(*this, other); + return *this; +} + +void TupleDataChunk::AddPart(TupleDataChunkPart &&part, const TupleDataLayout &layout) { + count += part.count; + row_block_ids.insert(part.row_block_index); + if (!layout.AllConstant() && part.total_heap_size > 0) { + heap_block_ids.insert(part.heap_block_index); + } + part.lock = *lock; + parts.emplace_back(std::move(part)); +} + +void TupleDataChunk::Verify() const { +#ifdef DEBUG + idx_t total_count = 0; + for (const auto &part : parts) { + total_count += part.count; + } + D_ASSERT(this->count == total_count); + D_ASSERT(this->count <= STANDARD_VECTOR_SIZE); +#endif +} + +void TupleDataChunk::MergeLastChunkPart(const TupleDataLayout &layout) { + if (parts.size() < 2) { + return; + } + + auto &second_to_last = parts[parts.size() - 2]; + auto &last = parts[parts.size() - 1]; + + auto rows_align = + last.row_block_index == second_to_last.row_block_index && + last.row_block_offset == second_to_last.row_block_offset + second_to_last.count * layout.GetRowWidth(); + + if (!rows_align) { // If rows don't align we can never merge + return; + } + + if (layout.AllConstant()) { // No heap and rows align - merge + second_to_last.count += last.count; + parts.pop_back(); + return; + } + + if (last.heap_block_index == second_to_last.heap_block_index && + last.heap_block_offset == second_to_last.heap_block_index + second_to_last.total_heap_size && + last.base_heap_ptr == second_to_last.base_heap_ptr) { // There is a heap and it aligns - merge + second_to_last.total_heap_size += last.total_heap_size; + second_to_last.count += last.count; + parts.pop_back(); + } +} + +TupleDataSegment::TupleDataSegment(shared_ptr allocator_p) + : allocator(std::move(allocator_p)), count(0), data_size(0) { +} + +TupleDataSegment::~TupleDataSegment() { + lock_guard guard(pinned_handles_lock); + pinned_row_handles.clear(); + pinned_heap_handles.clear(); + allocator = nullptr; +} + +void SwapTupleDataSegment(TupleDataSegment &a, TupleDataSegment &b) { + std::swap(a.allocator, b.allocator); + std::swap(a.chunks, b.chunks); + std::swap(a.count, b.count); + std::swap(a.data_size, b.data_size); + std::swap(a.pinned_row_handles, b.pinned_row_handles); + std::swap(a.pinned_heap_handles, b.pinned_heap_handles); +} + +TupleDataSegment::TupleDataSegment(TupleDataSegment &&other) noexcept { + SwapTupleDataSegment(*this, other); +} + +TupleDataSegment &TupleDataSegment::operator=(TupleDataSegment &&other) noexcept { + SwapTupleDataSegment(*this, other); + return *this; +} + +idx_t TupleDataSegment::ChunkCount() const { + return chunks.size(); +} + +idx_t TupleDataSegment::SizeInBytes() const { + return data_size; +} + +void TupleDataSegment::Unpin() { + lock_guard guard(pinned_handles_lock); + pinned_row_handles.clear(); + pinned_heap_handles.clear(); +} + +void TupleDataSegment::Verify() const { +#ifdef DEBUG + const auto &layout = allocator->GetLayout(); + + idx_t total_count = 0; + idx_t total_size = 0; + for (const auto &chunk : chunks) { + chunk.Verify(); + total_count += chunk.count; + + total_size += chunk.count * layout.GetRowWidth(); + if (!layout.AllConstant()) { + for (const auto &part : chunk.parts) { + total_size += part.total_heap_size; + } + } + } + D_ASSERT(total_count == this->count); + D_ASSERT(total_size == this->data_size); +#endif +} + +void TupleDataSegment::VerifyEverythingPinned() const { +#ifdef DEBUG + D_ASSERT(pinned_row_handles.size() == allocator->RowBlockCount()); + D_ASSERT(pinned_heap_handles.size() == allocator->HeapBlockCount()); +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/selection_vector.cpp b/src/duckdb/src/common/types/selection_vector.cpp new file mode 100644 index 00000000..463cc76d --- /dev/null +++ b/src/duckdb/src/common/types/selection_vector.cpp @@ -0,0 +1,46 @@ +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +SelectionData::SelectionData(idx_t count) { + owned_data = make_unsafe_uniq_array(count); +#ifdef DEBUG + for (idx_t i = 0; i < count; i++) { + owned_data[i] = std::numeric_limits::max(); + } +#endif +} + +// LCOV_EXCL_START +string SelectionVector::ToString(idx_t count) const { + string result = "Selection Vector (" + to_string(count) + ") ["; + for (idx_t i = 0; i < count; i++) { + if (i != 0) { + result += ", "; + } + result += to_string(get_index(i)); + } + result += "]"; + return result; +} + +void SelectionVector::Print(idx_t count) const { + Printer::Print(ToString(count)); +} +// LCOV_EXCL_STOP + +buffer_ptr SelectionVector::Slice(const SelectionVector &sel, idx_t count) const { + auto data = make_buffer(count); + auto result_ptr = data->owned_data.get(); + // for every element, we perform result[i] = target[new[i]] + for (idx_t i = 0; i < count; i++) { + auto new_idx = sel.get_index(i); + auto idx = this->get_index(new_idx); + result_ptr[i] = idx; + } + return data; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/string_heap.cpp b/src/duckdb/src/common/types/string_heap.cpp new file mode 100644 index 00000000..d1b7b661 --- /dev/null +++ b/src/duckdb/src/common/types/string_heap.cpp @@ -0,0 +1,62 @@ +#include "duckdb/common/types/string_heap.hpp" + +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "utf8proc_wrapper.hpp" + +#include + +namespace duckdb { + +StringHeap::StringHeap(Allocator &allocator) : allocator(allocator) { +} + +void StringHeap::Destroy() { + allocator.Destroy(); +} + +void StringHeap::Move(StringHeap &other) { + other.allocator.Move(allocator); +} + +string_t StringHeap::AddString(const char *data, idx_t len) { + D_ASSERT(Utf8Proc::Analyze(data, len) != UnicodeType::INVALID); + return AddBlob(data, len); +} + +string_t StringHeap::AddString(const char *data) { + return AddString(data, strlen(data)); +} + +string_t StringHeap::AddString(const string &data) { + return AddString(data.c_str(), data.size()); +} + +string_t StringHeap::AddString(const string_t &data) { + return AddString(data.GetData(), data.GetSize()); +} + +string_t StringHeap::AddBlob(const char *data, idx_t len) { + auto insert_string = EmptyString(len); + auto insert_pos = insert_string.GetDataWriteable(); + memcpy(insert_pos, data, len); + insert_string.Finalize(); + return insert_string; +} + +string_t StringHeap::AddBlob(const string_t &data) { + return AddBlob(data.GetData(), data.GetSize()); +} + +string_t StringHeap::EmptyString(idx_t len) { + D_ASSERT(len > string_t::INLINE_LENGTH); + auto insert_pos = const_char_ptr_cast(allocator.Allocate(len)); + return string_t(insert_pos, len); +} + +idx_t StringHeap::SizeInBytes() const { + return allocator.SizeInBytes(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/string_type.cpp b/src/duckdb/src/common/types/string_type.cpp new file mode 100644 index 00000000..e8e98718 --- /dev/null +++ b/src/duckdb/src/common/types/string_type.cpp @@ -0,0 +1,29 @@ +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/algorithm.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +void string_t::Verify() const { + auto dataptr = GetData(); + (void)dataptr; + D_ASSERT(dataptr); + +#ifdef DEBUG + auto utf_type = Utf8Proc::Analyze(dataptr, GetSize()); + D_ASSERT(utf_type != UnicodeType::INVALID); +#endif + + // verify that the prefix contains the first four characters of the string + for (idx_t i = 0; i < MinValue(PREFIX_LENGTH, GetSize()); i++) { + D_ASSERT(GetPrefix()[i] == dataptr[i]); + } + // verify that for strings with length <= INLINE_LENGTH, the rest of the string is zero + for (idx_t i = GetSize(); i < INLINE_LENGTH; i++) { + D_ASSERT(GetData()[i] == '\0'); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/time.cpp b/src/duckdb/src/common/types/time.cpp new file mode 100644 index 00000000..fa906a93 --- /dev/null +++ b/src/duckdb/src/common/types/time.cpp @@ -0,0 +1,342 @@ +#include "duckdb/common/types/time.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/operator/multiply.hpp" + +#include +#include +#include + +namespace duckdb { + +static_assert(sizeof(dtime_t) == sizeof(int64_t), "dtime_t was padded"); + +// string format is hh:mm:ss.microsecondsZ +// microseconds and Z are optional +// ISO 8601 + +bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict) { + int32_t hour = -1, min = -1, sec = -1, micros = -1; + pos = 0; + + if (len == 0) { + return false; + } + + int sep; + + // skip leading spaces + while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { + pos++; + } + + if (pos >= len) { + return false; + } + + if (!StringUtil::CharacterIsDigit(buf[pos])) { + return false; + } + + if (!Date::ParseDoubleDigit(buf, len, pos, hour)) { + return false; + } + if (hour < 0 || hour >= 24) { + return false; + } + + if (pos >= len) { + return false; + } + + // fetch the separator + sep = buf[pos++]; + if (sep != ':') { + // invalid separator + return false; + } + + if (!Date::ParseDoubleDigit(buf, len, pos, min)) { + return false; + } + if (min < 0 || min >= 60) { + return false; + } + + if (pos >= len) { + return false; + } + + if (buf[pos++] != sep) { + return false; + } + + if (!Date::ParseDoubleDigit(buf, len, pos, sec)) { + return false; + } + if (sec < 0 || sec >= 60) { + return false; + } + + micros = 0; + if (pos < len && buf[pos] == '.') { + pos++; + // we expect some microseconds + int32_t mult = 100000; + for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++, mult /= 10) { + if (mult > 0) { + micros += (buf[pos] - '0') * mult; + } + } + } + + // in strict mode, check remaining string for non-space characters + if (strict) { + // skip trailing spaces + while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { + pos++; + } + // check position. if end was not reached, non-space chars remaining + if (pos < len) { + return false; + } + } + + result = Time::FromTime(hour, min, sec, micros); + return true; +} + +bool Time::TryConvertTime(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict) { + if (!Time::TryConvertInternal(buf, len, pos, result, strict)) { + if (!strict) { + // last chance, check if we can parse as timestamp + timestamp_t timestamp; + if (Timestamp::TryConvertTimestamp(buf, len, timestamp) == TimestampCastResult::SUCCESS) { + if (!Timestamp::IsFinite(timestamp)) { + return false; + } + result = Timestamp::GetTime(timestamp); + return true; + } + } + return false; + } + return true; +} + +bool Time::TryParseUTCOffset(const char *str, idx_t &pos, idx_t len, int32_t &offset) { + offset = 0; + if (pos == len || StringUtil::CharacterIsSpace(str[pos])) { + return true; + } + + idx_t curpos = pos; + // Minimum of 3 characters + if (curpos + 3 > len) { + // no characters left to parse + return false; + } + + const auto sign_char = str[curpos]; + if (sign_char != '+' && sign_char != '-') { + // expected either + or - + return false; + } + curpos++; + + int32_t hh = 0; + idx_t start = curpos; + for (; curpos < len; ++curpos) { + const auto c = str[curpos]; + if (!StringUtil::CharacterIsDigit(c)) { + break; + } + hh = hh * 10 + (c - '0'); + } + // HH is in [-1559,+1559] and must be at least two digits + if (curpos - start < 2 || hh > 1559) { + return false; + } + + // optional minute specifier: expected ":MM" + int32_t mm = 0; + if (curpos + 3 <= len && str[curpos] == ':') { + ++curpos; + if (!Date::ParseDoubleDigit(str, len, curpos, mm) || mm >= Interval::MINS_PER_HOUR) { + return false; + } + } + + // optional seconds specifier: expected ":SS" + int32_t ss = 0; + if (curpos + 3 <= len && str[curpos] == ':') { + ++curpos; + if (!Date::ParseDoubleDigit(str, len, curpos, ss) || ss >= Interval::SECS_PER_MINUTE) { + return false; + } + } + + // Assemble the offset now that we know nothing went wrong + offset += hh * Interval::SECS_PER_HOUR; + offset += mm * Interval::SECS_PER_MINUTE; + offset += ss; + if (sign_char == '-') { + offset = -offset; + } + + pos = curpos; + + return true; +} + +bool Time::TryConvertTimeTZ(const char *buf, idx_t len, idx_t &pos, dtime_tz_t &result, bool strict) { + dtime_t time_part; + if (!Time::TryConvertInternal(buf, len, pos, time_part, false)) { + if (!strict) { + // last chance, check if we can parse as timestamp + timestamp_t timestamp; + if (Timestamp::TryConvertTimestamp(buf, len, timestamp) == TimestampCastResult::SUCCESS) { + if (!Timestamp::IsFinite(timestamp)) { + return false; + } + result = dtime_tz_t(Timestamp::GetTime(timestamp), 0); + return true; + } + } + return false; + } + + // We can't use Timestamp::TryParseUTCOffset because the colon is optional there but required here. + int32_t offset = 0; + if (!TryParseUTCOffset(buf, pos, len, offset)) { + return false; + } + + // in strict mode, check remaining string for non-space characters + if (strict) { + // skip trailing spaces + while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { + pos++; + } + // check position. if end was not reached, non-space chars remaining + if (pos < len) { + return false; + } + } + + result = dtime_tz_t(time_part, offset); + + return true; +} + +string Time::ConversionError(const string &str) { + return StringUtil::Format("time field value out of range: \"%s\", " + "expected format is ([YYYY-MM-DD ]HH:MM:SS[.MS])", + str); +} + +string Time::ConversionError(string_t str) { + return Time::ConversionError(str.GetString()); +} + +dtime_t Time::FromCString(const char *buf, idx_t len, bool strict) { + dtime_t result; + idx_t pos; + if (!Time::TryConvertTime(buf, len, pos, result, strict)) { + throw ConversionException(ConversionError(string(buf, len))); + } + return result; +} + +dtime_t Time::FromString(const string &str, bool strict) { + return Time::FromCString(str.c_str(), str.size(), strict); +} + +string Time::ToString(dtime_t time) { + int32_t time_units[4]; + Time::Convert(time, time_units[0], time_units[1], time_units[2], time_units[3]); + + char micro_buffer[6]; + auto length = TimeToStringCast::Length(time_units, micro_buffer); + auto buffer = make_unsafe_uniq_array(length); + TimeToStringCast::Format(buffer.get(), length, time_units, micro_buffer); + return string(buffer.get(), length); +} + +string Time::ToUTCOffset(int hour_offset, int minute_offset) { + dtime_t time((hour_offset * Interval::MINS_PER_HOUR + minute_offset) * Interval::MICROS_PER_MINUTE); + + char buffer[1 + 2 + 1 + 2]; + idx_t length = 0; + buffer[length++] = (time.micros < 0 ? '-' : '+'); + time.micros = std::abs(time.micros); + + int32_t time_units[4]; + Time::Convert(time, time_units[0], time_units[1], time_units[2], time_units[3]); + + TimeToStringCast::FormatTwoDigits(buffer + length, time_units[0]); + length += 2; + if (time_units[1]) { + buffer[length++] = ':'; + TimeToStringCast::FormatTwoDigits(buffer + length, time_units[1]); + length += 2; + } + + return string(buffer, length); +} + +dtime_t Time::FromTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { + int64_t result; + result = hour; // hours + result = result * Interval::MINS_PER_HOUR + minute; // hours -> minutes + result = result * Interval::SECS_PER_MINUTE + second; // minutes -> seconds + result = result * Interval::MICROS_PER_SEC + microseconds; // seconds -> microseconds + return dtime_t(result); +} + +bool Time::IsValidTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { + if (hour < 0 || hour >= 24) { + return false; + } + if (minute < 0 || minute >= 60) { + return false; + } + if (second < 0 || second > 60) { + return false; + } + if (microseconds < 0 || microseconds > 1000000) { + return false; + } + return true; +} + +void Time::Convert(dtime_t dtime, int32_t &hour, int32_t &min, int32_t &sec, int32_t µs) { + int64_t time = dtime.micros; + hour = int32_t(time / Interval::MICROS_PER_HOUR); + time -= int64_t(hour) * Interval::MICROS_PER_HOUR; + min = int32_t(time / Interval::MICROS_PER_MINUTE); + time -= int64_t(min) * Interval::MICROS_PER_MINUTE; + sec = int32_t(time / Interval::MICROS_PER_SEC); + time -= int64_t(sec) * Interval::MICROS_PER_SEC; + micros = int32_t(time); + D_ASSERT(Time::IsValidTime(hour, min, sec, micros)); +} + +dtime_t Time::FromTimeMs(int64_t time_ms) { + int64_t result; + if (!TryMultiplyOperator::Operation(time_ms, Interval::MICROS_PER_MSEC, result)) { + throw ConversionException("Could not convert Time(MS) to Time(US)"); + } + return dtime_t(result); +} + +dtime_t Time::FromTimeNs(int64_t time_ns) { + return dtime_t(time_ns / Interval::NANOS_PER_MICRO); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/timestamp.cpp b/src/duckdb/src/common/types/timestamp.cpp new file mode 100644 index 00000000..f0f5796d --- /dev/null +++ b/src/duckdb/src/common/types/timestamp.cpp @@ -0,0 +1,351 @@ +#include "duckdb/common/types/timestamp.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/chrono.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/limits.hpp" +#include + +namespace duckdb { + +static_assert(sizeof(timestamp_t) == sizeof(int64_t), "timestamp_t was padded"); + +// timestamp/datetime uses 64 bits, high 32 bits for date and low 32 bits for time +// string format is YYYY-MM-DDThh:mm:ssZ +// T may be a space +// Z is optional +// ISO 8601 + +// arithmetic operators +timestamp_t timestamp_t::operator+(const double &value) const { + timestamp_t result; + if (!TryAddOperator::Operation(this->value, int64_t(value), result.value)) { + throw OutOfRangeException("Overflow in timestamp addition"); + } + return result; +} + +int64_t timestamp_t::operator-(const timestamp_t &other) const { + int64_t result; + if (!TrySubtractOperator::Operation(value, int64_t(other.value), result)) { + throw OutOfRangeException("Overflow in timestamp subtraction"); + } + return result; +} + +// in-place operators +timestamp_t ×tamp_t::operator+=(const int64_t &delta) { + if (!TryAddOperator::Operation(value, delta, value)) { + throw OutOfRangeException("Overflow in timestamp increment"); + } + return *this; +} + +timestamp_t ×tamp_t::operator-=(const int64_t &delta) { + if (!TrySubtractOperator::Operation(value, delta, value)) { + throw OutOfRangeException("Overflow in timestamp decrement"); + } + return *this; +} + +bool Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &result, bool &has_offset, string_t &tz) { + idx_t pos; + date_t date; + dtime_t time; + has_offset = false; + if (!Date::TryConvertDate(str, len, pos, date, has_offset)) { + return false; + } + if (pos == len) { + // no time: only a date or special + if (date == date_t::infinity()) { + result = timestamp_t::infinity(); + return true; + } else if (date == date_t::ninfinity()) { + result = timestamp_t::ninfinity(); + return true; + } + return Timestamp::TryFromDatetime(date, dtime_t(0), result); + } + // try to parse a time field + if (str[pos] == ' ' || str[pos] == 'T') { + pos++; + } + idx_t time_pos = 0; + if (!Time::TryConvertTime(str + pos, len - pos, time_pos, time)) { + return false; + } + pos += time_pos; + if (!Timestamp::TryFromDatetime(date, time, result)) { + return false; + } + if (pos < len) { + // skip a "Z" at the end (as per the ISO8601 specs) + int hour_offset, minute_offset; + if (str[pos] == 'Z') { + pos++; + has_offset = true; + } else if (Timestamp::TryParseUTCOffset(str, pos, len, hour_offset, minute_offset)) { + const int64_t delta = hour_offset * Interval::MICROS_PER_HOUR + minute_offset * Interval::MICROS_PER_MINUTE; + if (!TrySubtractOperator::Operation(result.value, delta, result.value)) { + return false; + } + has_offset = true; + } else { + // Parse a time zone: / [A-Za-z0-9/_]+/ + if (str[pos++] != ' ') { + return false; + } + auto tz_name = str + pos; + for (; pos < len && CharacterIsTimeZone(str[pos]); ++pos) { + continue; + } + auto tz_len = str + pos - tz_name; + if (tz_len) { + tz = string_t(tz_name, tz_len); + } + // Note that the caller must reinterpret the instant we return to the given time zone + } + + // skip any spaces at the end + while (pos < len && StringUtil::CharacterIsSpace(str[pos])) { + pos++; + } + if (pos < len) { + return false; + } + } + return true; +} + +TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, timestamp_t &result) { + string_t tz(nullptr, 0); + bool has_offset = false; + // We don't understand TZ without an extension, so fail if one was provided. + auto success = TryConvertTimestampTZ(str, len, result, has_offset, tz); + if (!success) { + return TimestampCastResult::ERROR_INCORRECT_FORMAT; + } + if (tz.GetSize() == 0) { + // no timezone provided - success! + return TimestampCastResult::SUCCESS; + } + if (tz.GetSize() == 3) { + // we can ONLY handle UTC without ICU being loaded + auto tz_ptr = tz.GetData(); + if ((tz_ptr[0] == 'u' || tz_ptr[0] == 'U') && (tz_ptr[1] == 't' || tz_ptr[1] == 'T') && + (tz_ptr[2] == 'c' || tz_ptr[2] == 'C')) { + return TimestampCastResult::SUCCESS; + } + } + return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; +} + +string Timestamp::ConversionError(const string &str) { + return StringUtil::Format("timestamp field value out of range: \"%s\", " + "expected format is (YYYY-MM-DD HH:MM:SS[.US][±HH:MM| ZONE])", + str); +} + +string Timestamp::UnsupportedTimezoneError(const string &str) { + return StringUtil::Format("timestamp field value \"%s\" has a timestamp that is not UTC.\nUse the TIMESTAMPTZ type " + "with the ICU extension loaded to handle non-UTC timestamps.", + str); +} + +string Timestamp::ConversionError(string_t str) { + return Timestamp::ConversionError(str.GetString()); +} + +string Timestamp::UnsupportedTimezoneError(string_t str) { + return Timestamp::UnsupportedTimezoneError(str.GetString()); +} + +timestamp_t Timestamp::FromCString(const char *str, idx_t len) { + timestamp_t result; + auto cast_result = Timestamp::TryConvertTimestamp(str, len, result); + if (cast_result == TimestampCastResult::SUCCESS) { + return result; + } + if (cast_result == TimestampCastResult::ERROR_NON_UTC_TIMEZONE) { + throw ConversionException(Timestamp::UnsupportedTimezoneError(string(str, len))); + } else { + throw ConversionException(Timestamp::ConversionError(string(str, len))); + } +} + +bool Timestamp::TryParseUTCOffset(const char *str, idx_t &pos, idx_t len, int &hour_offset, int &minute_offset) { + minute_offset = 0; + idx_t curpos = pos; + // parse the next 3 characters + if (curpos + 3 > len) { + // no characters left to parse + return false; + } + char sign_char = str[curpos]; + if (sign_char != '+' && sign_char != '-') { + // expected either + or - + return false; + } + curpos++; + if (!StringUtil::CharacterIsDigit(str[curpos]) || !StringUtil::CharacterIsDigit(str[curpos + 1])) { + // expected +HH or -HH + return false; + } + hour_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); + if (sign_char == '-') { + hour_offset = -hour_offset; + } + curpos += 2; + + // optional minute specifier: expected either "MM" or ":MM" + if (curpos >= len) { + // done, nothing left + pos = curpos; + return true; + } + if (str[curpos] == ':') { + curpos++; + } + if (curpos + 2 > len || !StringUtil::CharacterIsDigit(str[curpos]) || + !StringUtil::CharacterIsDigit(str[curpos + 1])) { + // no MM specifier + pos = curpos; + return true; + } + // we have an MM specifier: parse it + minute_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); + if (sign_char == '-') { + minute_offset = -minute_offset; + } + pos = curpos + 2; + return true; +} + +timestamp_t Timestamp::FromString(const string &str) { + return Timestamp::FromCString(str.c_str(), str.size()); +} + +string Timestamp::ToString(timestamp_t timestamp) { + if (timestamp == timestamp_t::infinity()) { + return Date::PINF; + } else if (timestamp == timestamp_t::ninfinity()) { + return Date::NINF; + } + date_t date; + dtime_t time; + Timestamp::Convert(timestamp, date, time); + return Date::ToString(date) + " " + Time::ToString(time); +} + +date_t Timestamp::GetDate(timestamp_t timestamp) { + if (timestamp == timestamp_t::infinity()) { + return date_t::infinity(); + } else if (timestamp == timestamp_t::ninfinity()) { + return date_t::ninfinity(); + } + return date_t((timestamp.value + (timestamp.value < 0)) / Interval::MICROS_PER_DAY - (timestamp.value < 0)); +} + +dtime_t Timestamp::GetTime(timestamp_t timestamp) { + if (!IsFinite(timestamp)) { + throw ConversionException("Can't get TIME of infinite TIMESTAMP"); + } + date_t date = Timestamp::GetDate(timestamp); + return dtime_t(timestamp.value - (int64_t(date.days) * int64_t(Interval::MICROS_PER_DAY))); +} + +bool Timestamp::TryFromDatetime(date_t date, dtime_t time, timestamp_t &result) { + if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY, result.value)) { + return false; + } + if (!TryAddOperator::Operation(result.value, time.micros, result.value)) { + return false; + } + return Timestamp::IsFinite(result); +} + +timestamp_t Timestamp::FromDatetime(date_t date, dtime_t time) { + timestamp_t result; + if (!TryFromDatetime(date, time, result)) { + throw Exception("Overflow exception in date/time -> timestamp conversion"); + } + return result; +} + +void Timestamp::Convert(timestamp_t timestamp, date_t &out_date, dtime_t &out_time) { + out_date = GetDate(timestamp); + int64_t days_micros; + if (!TryMultiplyOperator::Operation(out_date.days, Interval::MICROS_PER_DAY, + days_micros)) { + throw ConversionException("Date out of range in timestamp conversion"); + } + out_time = dtime_t(timestamp.value - days_micros); + D_ASSERT(timestamp == Timestamp::FromDatetime(out_date, out_time)); +} + +timestamp_t Timestamp::GetCurrentTimestamp() { + auto now = system_clock::now(); + auto epoch_ms = duration_cast(now.time_since_epoch()).count(); + return Timestamp::FromEpochMs(epoch_ms); +} + +timestamp_t Timestamp::FromEpochSeconds(int64_t sec) { + int64_t result; + if (!TryMultiplyOperator::Operation(sec, Interval::MICROS_PER_SEC, result)) { + throw ConversionException("Could not convert Timestamp(S) to Timestamp(US)"); + } + return timestamp_t(result); +} + +timestamp_t Timestamp::FromEpochMs(int64_t ms) { + int64_t result; + if (!TryMultiplyOperator::Operation(ms, Interval::MICROS_PER_MSEC, result)) { + throw ConversionException("Could not convert Timestamp(MS) to Timestamp(US)"); + } + return timestamp_t(result); +} + +timestamp_t Timestamp::FromEpochMicroSeconds(int64_t micros) { + return timestamp_t(micros); +} + +timestamp_t Timestamp::FromEpochNanoSeconds(int64_t ns) { + return timestamp_t(ns / 1000); +} + +int64_t Timestamp::GetEpochSeconds(timestamp_t timestamp) { + return timestamp.value / Interval::MICROS_PER_SEC; +} + +int64_t Timestamp::GetEpochMs(timestamp_t timestamp) { + return timestamp.value / Interval::MICROS_PER_MSEC; +} + +int64_t Timestamp::GetEpochMicroSeconds(timestamp_t timestamp) { + return timestamp.value; +} + +int64_t Timestamp::GetEpochNanoSeconds(timestamp_t timestamp) { + int64_t result; + int64_t ns_in_us = 1000; + if (!TryMultiplyOperator::Operation(timestamp.value, ns_in_us, result)) { + throw ConversionException("Could not convert Timestamp(US) to Timestamp(NS)"); + } + return result; +} + +double Timestamp::GetJulianDay(timestamp_t timestamp) { + double result = Timestamp::GetTime(timestamp).micros; + result /= Interval::MICROS_PER_DAY; + result += Date::ExtractJulianDay(Timestamp::GetDate(timestamp)); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/uuid.cpp b/src/duckdb/src/common/types/uuid.cpp new file mode 100644 index 00000000..7ab19f9d --- /dev/null +++ b/src/duckdb/src/common/types/uuid.cpp @@ -0,0 +1,127 @@ +#include "duckdb/common/types/uuid.hpp" +#include "duckdb/common/random_engine.hpp" + +namespace duckdb { + +bool UUID::FromString(string str, hugeint_t &result) { + auto hex2char = [](char ch) -> unsigned char { + if (ch >= '0' && ch <= '9') { + return ch - '0'; + } + if (ch >= 'a' && ch <= 'f') { + return 10 + ch - 'a'; + } + if (ch >= 'A' && ch <= 'F') { + return 10 + ch - 'A'; + } + return 0; + }; + auto is_hex = [](char ch) -> bool { + return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F'); + }; + + if (str.empty()) { + return false; + } + int has_braces = 0; + if (str.front() == '{') { + has_braces = 1; + } + if (has_braces && str.back() != '}') { + return false; + } + + result.lower = 0; + result.upper = 0; + size_t count = 0; + for (size_t i = has_braces; i < str.size() - has_braces; ++i) { + if (str[i] == '-') { + continue; + } + if (count >= 32 || !is_hex(str[i])) { + return false; + } + if (count >= 16) { + result.lower = (result.lower << 4) | hex2char(str[i]); + } else { + result.upper = (result.upper << 4) | hex2char(str[i]); + } + count++; + } + // Flip the first bit to make `order by uuid` same as `order by uuid::varchar` + result.upper ^= (uint64_t(1) << 63); + return count == 32; +} + +void UUID::ToString(hugeint_t input, char *buf) { + auto byte_to_hex = [](char byte_val, char *buf, idx_t &pos) { + static char const HEX_DIGITS[] = "0123456789abcdef"; + buf[pos++] = HEX_DIGITS[(byte_val >> 4) & 0xf]; + buf[pos++] = HEX_DIGITS[byte_val & 0xf]; + }; + + // Flip back before convert to string + int64_t upper = input.upper ^ (uint64_t(1) << 63); + idx_t pos = 0; + byte_to_hex(upper >> 56 & 0xFF, buf, pos); + byte_to_hex(upper >> 48 & 0xFF, buf, pos); + byte_to_hex(upper >> 40 & 0xFF, buf, pos); + byte_to_hex(upper >> 32 & 0xFF, buf, pos); + buf[pos++] = '-'; + byte_to_hex(upper >> 24 & 0xFF, buf, pos); + byte_to_hex(upper >> 16 & 0xFF, buf, pos); + buf[pos++] = '-'; + byte_to_hex(upper >> 8 & 0xFF, buf, pos); + byte_to_hex(upper & 0xFF, buf, pos); + buf[pos++] = '-'; + byte_to_hex(input.lower >> 56 & 0xFF, buf, pos); + byte_to_hex(input.lower >> 48 & 0xFF, buf, pos); + buf[pos++] = '-'; + byte_to_hex(input.lower >> 40 & 0xFF, buf, pos); + byte_to_hex(input.lower >> 32 & 0xFF, buf, pos); + byte_to_hex(input.lower >> 24 & 0xFF, buf, pos); + byte_to_hex(input.lower >> 16 & 0xFF, buf, pos); + byte_to_hex(input.lower >> 8 & 0xFF, buf, pos); + byte_to_hex(input.lower & 0xFF, buf, pos); +} + +hugeint_t UUID::GenerateRandomUUID(RandomEngine &engine) { + uint8_t bytes[16]; + for (int i = 0; i < 16; i += 4) { + *reinterpret_cast(bytes + i) = engine.NextRandomInteger(); + } + // variant must be 10xxxxxx + bytes[8] &= 0xBF; + bytes[8] |= 0x80; + // version must be 0100xxxx + bytes[6] &= 0x4F; + bytes[6] |= 0x40; + + hugeint_t result; + result.upper = 0; + result.upper |= ((int64_t)bytes[0] << 56); + result.upper |= ((int64_t)bytes[1] << 48); + result.upper |= ((int64_t)bytes[2] << 40); + result.upper |= ((int64_t)bytes[3] << 32); + result.upper |= ((int64_t)bytes[4] << 24); + result.upper |= ((int64_t)bytes[5] << 16); + result.upper |= ((int64_t)bytes[6] << 8); + result.upper |= bytes[7]; + result.lower = 0; + result.lower |= ((uint64_t)bytes[8] << 56); + result.lower |= ((uint64_t)bytes[9] << 48); + result.lower |= ((uint64_t)bytes[10] << 40); + result.lower |= ((uint64_t)bytes[11] << 32); + result.lower |= ((uint64_t)bytes[12] << 24); + result.lower |= ((uint64_t)bytes[13] << 16); + result.lower |= ((uint64_t)bytes[14] << 8); + result.lower |= bytes[15]; + return result; +} + +hugeint_t UUID::GenerateRandomUUID() { + RandomEngine engine; + return GenerateRandomUUID(engine); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/validity_mask.cpp b/src/duckdb/src/common/types/validity_mask.cpp new file mode 100644 index 00000000..5ec42b77 --- /dev/null +++ b/src/duckdb/src/common/types/validity_mask.cpp @@ -0,0 +1,232 @@ +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/serializer/write_stream.hpp" +#include "duckdb/common/serializer/read_stream.hpp" + +namespace duckdb { + +ValidityData::ValidityData(idx_t count) : TemplatedValidityData(count) { +} +ValidityData::ValidityData(const ValidityMask &original, idx_t count) + : TemplatedValidityData(original.GetData(), count) { +} + +void ValidityMask::Combine(const ValidityMask &other, idx_t count) { + if (other.AllValid()) { + // X & 1 = X + return; + } + if (AllValid()) { + // 1 & Y = Y + Initialize(other); + return; + } + if (validity_mask == other.validity_mask) { + // X & X == X + return; + } + // have to merge + // create a new validity mask that contains the combined mask + auto owned_data = std::move(validity_data); + auto data = GetData(); + auto other_data = other.GetData(); + + Initialize(count); + auto result_data = GetData(); + + auto entry_count = ValidityData::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + result_data[entry_idx] = data[entry_idx] & other_data[entry_idx]; + } +} + +// LCOV_EXCL_START +string ValidityMask::ToString(idx_t count) const { + string result = "Validity Mask (" + to_string(count) + ") ["; + for (idx_t i = 0; i < count; i++) { + result += RowIsValid(i) ? "." : "X"; + } + result += "]"; + return result; +} +// LCOV_EXCL_STOP + +void ValidityMask::Resize(idx_t old_size, idx_t new_size) { + D_ASSERT(new_size >= old_size); + if (validity_mask) { + auto new_size_count = EntryCount(new_size); + auto old_size_count = EntryCount(old_size); + auto new_validity_data = make_buffer(new_size); + auto new_owned_data = new_validity_data->owned_data.get(); + for (idx_t entry_idx = 0; entry_idx < old_size_count; entry_idx++) { + new_owned_data[entry_idx] = validity_mask[entry_idx]; + } + for (idx_t entry_idx = old_size_count; entry_idx < new_size_count; entry_idx++) { + new_owned_data[entry_idx] = ValidityData::MAX_ENTRY; + } + validity_data = std::move(new_validity_data); + validity_mask = validity_data->owned_data.get(); + } else { + Initialize(new_size); + } +} + +void ValidityMask::Slice(const ValidityMask &other, idx_t source_offset, idx_t count) { + if (other.AllValid()) { + validity_mask = nullptr; + validity_data.reset(); + return; + } + if (source_offset == 0) { + Initialize(other); + return; + } + ValidityMask new_mask(count); + new_mask.SliceInPlace(other, 0, source_offset, count); + Initialize(new_mask); +} + +bool ValidityMask::IsAligned(idx_t count) { + return count % BITS_PER_VALUE == 0; +} + +void ValidityMask::SliceInPlace(const ValidityMask &other, idx_t target_offset, idx_t source_offset, idx_t count) { + if (IsAligned(source_offset) && IsAligned(target_offset)) { + auto target_validity = GetData(); + auto source_validity = other.GetData(); + auto source_offset_entries = EntryCount(source_offset); + auto target_offset_entries = EntryCount(target_offset); + memcpy(target_validity + target_offset_entries, source_validity + source_offset_entries, + sizeof(validity_t) * EntryCount(count)); + return; + } else if (IsAligned(target_offset)) { + // Simple common case where we are shifting into an aligned mask (e.g., 0 in Slice above) + const idx_t entire_units = count / BITS_PER_VALUE; + const idx_t ragged = count % BITS_PER_VALUE; + const idx_t tail = source_offset % BITS_PER_VALUE; + const idx_t head = BITS_PER_VALUE - tail; + auto source_validity = other.GetData() + (source_offset / BITS_PER_VALUE); + auto target_validity = this->GetData() + (target_offset / BITS_PER_VALUE); + auto src_entry = *source_validity++; + for (idx_t i = 0; i < entire_units; ++i) { + // Start with head of previous src + validity_t tgt_entry = src_entry >> tail; + src_entry = *source_validity++; + // Add in tail of current src + tgt_entry |= (src_entry << head); + *target_validity++ = tgt_entry; + } + // Finish last ragged entry + if (ragged) { + // Start with head of previous src + validity_t tgt_entry = (src_entry >> tail); + // Add in the tail of the next src, if head was too small + if (head < ragged) { + src_entry = *source_validity++; + tgt_entry |= (src_entry << head); + } + // Mask off the bits that go past the ragged end + tgt_entry &= (ValidityBuffer::MAX_ENTRY >> (BITS_PER_VALUE - ragged)); + // Restore the ragged end of the target + tgt_entry |= *target_validity & (ValidityBuffer::MAX_ENTRY << ragged); + *target_validity++ = tgt_entry; + } + return; + } + + // FIXME: use bitwise operations here +#if 1 + for (idx_t i = 0; i < count; i++) { + Set(target_offset + i, other.RowIsValid(source_offset + i)); + } +#else + // first shift the "whole" units + idx_t entire_units = offset / BITS_PER_VALUE; + idx_t sub_units = offset - entire_units * BITS_PER_VALUE; + if (entire_units > 0) { + idx_t validity_idx; + for (validity_idx = 0; validity_idx + entire_units < STANDARD_ENTRY_COUNT; validity_idx++) { + new_mask.validity_mask[validity_idx] = other.validity_mask[validity_idx + entire_units]; + } + } + // now we shift the remaining sub units + // this gets a bit more complicated because we have to shift over the borders of the entries + // e.g. suppose we have 2 entries of length 4 and we left-shift by two + // 0101|1010 + // a regular left-shift of both gets us: + // 0100|1000 + // we then OR the overflow (right-shifted by BITS_PER_VALUE - offset) together to get the correct result + // 0100|1000 -> + // 0110|1000 + if (sub_units > 0) { + idx_t validity_idx; + for (validity_idx = 0; validity_idx + 1 < STANDARD_ENTRY_COUNT; validity_idx++) { + new_mask.validity_mask[validity_idx] = + (other.validity_mask[validity_idx] >> sub_units) | + (other.validity_mask[validity_idx + 1] << (BITS_PER_VALUE - sub_units)); + } + new_mask.validity_mask[validity_idx] >>= sub_units; + } +#ifdef DEBUG + for (idx_t i = offset; i < STANDARD_VECTOR_SIZE; i++) { + D_ASSERT(new_mask.RowIsValid(i - offset) == other.RowIsValid(i)); + } + Initialize(new_mask); +#endif +#endif +} + +enum class ValiditySerialization : uint8_t { BITMASK = 0, VALID_VALUES = 1, INVALID_VALUES = 2 }; + +void ValidityMask::Write(WriteStream &writer, idx_t count) { + auto valid_values = CountValid(count); + auto invalid_values = count - valid_values; + auto bitmask_bytes = ValidityMask::ValidityMaskSize(count); + auto need_u32 = count >= NumericLimits::Maximum(); + auto bytes_per_value = need_u32 ? sizeof(uint32_t) : sizeof(uint16_t); + auto valid_value_size = bytes_per_value * valid_values + sizeof(uint32_t); + auto invalid_value_size = bytes_per_value * invalid_values + sizeof(uint32_t); + if (valid_value_size < bitmask_bytes || invalid_value_size < bitmask_bytes) { + auto serialize_valid = valid_value_size < invalid_value_size; + // serialize (in)valid value indexes as [COUNT][V0][V1][...][VN] + auto flag = serialize_valid ? ValiditySerialization::VALID_VALUES : ValiditySerialization::INVALID_VALUES; + writer.Write(flag); + writer.Write(MinValue(valid_values, invalid_values)); + for (idx_t i = 0; i < count; i++) { + if (RowIsValid(i) == serialize_valid) { + if (need_u32) { + writer.Write(i); + } else { + writer.Write(i); + } + } + } + } else { + // serialize the entire bitmask + writer.Write(ValiditySerialization::BITMASK); + writer.WriteData(const_data_ptr_cast(GetData()), bitmask_bytes); + } +} + +void ValidityMask::Read(ReadStream &reader, idx_t count) { + Initialize(count); + // deserialize the storage type + auto flag = reader.Read(); + if (flag == ValiditySerialization::BITMASK) { + // deserialize the bitmask + reader.ReadData(data_ptr_cast(GetData()), ValidityMask::ValidityMaskSize(count)); + return; + } + auto is_u32 = count >= NumericLimits::Maximum(); + auto is_valid = flag == ValiditySerialization::VALID_VALUES; + auto serialize_count = reader.Read(); + if (is_valid) { + SetAllInvalid(count); + } + for (idx_t i = 0; i < serialize_count; i++) { + idx_t index = is_u32 ? reader.Read() : reader.Read(); + Set(index, is_valid); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp new file mode 100644 index 00000000..22a75098 --- /dev/null +++ b/src/duckdb/src/common/types/value.cpp @@ -0,0 +1,1866 @@ +#include "duckdb/common/types/value.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/aggregate_operators.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" + +#include "utf8proc_wrapper.hpp" +#include "duckdb/common/operator/numeric_binary_operators.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/types/blob.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/uuid.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/main/error_manager.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Extra Value Info +//===--------------------------------------------------------------------===// +enum class ExtraValueInfoType : uint8_t { INVALID_TYPE_INFO = 0, STRING_VALUE_INFO = 1, NESTED_VALUE_INFO = 2 }; + +struct ExtraValueInfo { + explicit ExtraValueInfo(ExtraValueInfoType type) : type(type) { + } + virtual ~ExtraValueInfo() { + } + + ExtraValueInfoType type; + +public: + bool Equals(ExtraValueInfo *other_p) const { + if (!other_p) { + return false; + } + if (type != other_p->type) { + return false; + } + return EqualsInternal(other_p); + } + + template + T &Get() { + if (type != T::TYPE) { + throw InternalException("ExtraValueInfo type mismatch"); + } + return (T &)*this; + } + +protected: + virtual bool EqualsInternal(ExtraValueInfo *other_p) const { + return true; + } +}; + +//===--------------------------------------------------------------------===// +// String Value Info +//===--------------------------------------------------------------------===// +struct StringValueInfo : public ExtraValueInfo { + static constexpr const ExtraValueInfoType TYPE = ExtraValueInfoType::STRING_VALUE_INFO; + +public: + explicit StringValueInfo(string str_p) + : ExtraValueInfo(ExtraValueInfoType::STRING_VALUE_INFO), str(std::move(str_p)) { + } + + const string &GetString() { + return str; + } + +protected: + bool EqualsInternal(ExtraValueInfo *other_p) const override { + return other_p->Get().str == str; + } + + string str; +}; + +//===--------------------------------------------------------------------===// +// Nested Value Info +//===--------------------------------------------------------------------===// +struct NestedValueInfo : public ExtraValueInfo { + static constexpr const ExtraValueInfoType TYPE = ExtraValueInfoType::NESTED_VALUE_INFO; + +public: + NestedValueInfo() : ExtraValueInfo(ExtraValueInfoType::NESTED_VALUE_INFO) { + } + explicit NestedValueInfo(vector values_p) + : ExtraValueInfo(ExtraValueInfoType::NESTED_VALUE_INFO), values(std::move(values_p)) { + } + + const vector &GetValues() { + return values; + } + +protected: + bool EqualsInternal(ExtraValueInfo *other_p) const override { + return other_p->Get().values == values; + } + + vector values; +}; +//===--------------------------------------------------------------------===// +// Value +//===--------------------------------------------------------------------===// +Value::Value(LogicalType type) : type_(std::move(type)), is_null(true) { +} + +Value::Value(int32_t val) : type_(LogicalType::INTEGER), is_null(false) { + value_.integer = val; +} + +Value::Value(int64_t val) : type_(LogicalType::BIGINT), is_null(false) { + value_.bigint = val; +} + +Value::Value(float val) : type_(LogicalType::FLOAT), is_null(false) { + value_.float_ = val; +} + +Value::Value(double val) : type_(LogicalType::DOUBLE), is_null(false) { + value_.double_ = val; +} + +Value::Value(const char *val) : Value(val ? string(val) : string()) { +} + +Value::Value(std::nullptr_t val) : Value(LogicalType::VARCHAR) { +} + +Value::Value(string_t val) : Value(val.GetString()) { +} + +Value::Value(string val) : type_(LogicalType::VARCHAR), is_null(false) { + if (!Value::StringIsValid(val.c_str(), val.size())) { + throw Exception(ErrorManager::InvalidUnicodeError(val, "value construction")); + } + value_info_ = make_shared(std::move(val)); +} + +Value::~Value() { +} + +Value::Value(const Value &other) + : type_(other.type_), is_null(other.is_null), value_(other.value_), value_info_(other.value_info_) { +} + +Value::Value(Value &&other) noexcept + : type_(std::move(other.type_)), is_null(other.is_null), value_(other.value_), + value_info_(std::move(other.value_info_)) { +} + +Value &Value::operator=(const Value &other) { + if (this == &other) { + return *this; + } + type_ = other.type_; + is_null = other.is_null; + value_ = other.value_; + value_info_ = other.value_info_; + return *this; +} + +Value &Value::operator=(Value &&other) noexcept { + type_ = std::move(other.type_); + is_null = other.is_null; + value_ = other.value_; + value_info_ = std::move(other.value_info_); + return *this; +} + +Value Value::MinimumValue(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + return Value::BOOLEAN(false); + case LogicalTypeId::TINYINT: + return Value::TINYINT(NumericLimits::Minimum()); + case LogicalTypeId::SMALLINT: + return Value::SMALLINT(NumericLimits::Minimum()); + case LogicalTypeId::INTEGER: + case LogicalTypeId::SQLNULL: + return Value::INTEGER(NumericLimits::Minimum()); + case LogicalTypeId::BIGINT: + return Value::BIGINT(NumericLimits::Minimum()); + case LogicalTypeId::HUGEINT: + return Value::HUGEINT(NumericLimits::Minimum()); + case LogicalTypeId::UUID: + return Value::UUID(NumericLimits::Minimum()); + case LogicalTypeId::UTINYINT: + return Value::UTINYINT(NumericLimits::Minimum()); + case LogicalTypeId::USMALLINT: + return Value::USMALLINT(NumericLimits::Minimum()); + case LogicalTypeId::UINTEGER: + return Value::UINTEGER(NumericLimits::Minimum()); + case LogicalTypeId::UBIGINT: + return Value::UBIGINT(NumericLimits::Minimum()); + case LogicalTypeId::DATE: + return Value::DATE(Date::FromDate(Date::DATE_MIN_YEAR, Date::DATE_MIN_MONTH, Date::DATE_MIN_DAY)); + case LogicalTypeId::TIME: + return Value::TIME(dtime_t(0)); + case LogicalTypeId::TIMESTAMP: + return Value::TIMESTAMP(Date::FromDate(Timestamp::MIN_YEAR, Timestamp::MIN_MONTH, Timestamp::MIN_DAY), + dtime_t(0)); + case LogicalTypeId::TIMESTAMP_SEC: + return MinimumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_S); + case LogicalTypeId::TIMESTAMP_MS: + return MinimumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_MS); + case LogicalTypeId::TIMESTAMP_NS: + return Value::TIMESTAMPNS(timestamp_t(NumericLimits::Minimum())); + case LogicalTypeId::TIME_TZ: + return Value::TIMETZ(dtime_tz_t(dtime_t(0), dtime_tz_t::MIN_OFFSET)); + case LogicalTypeId::TIMESTAMP_TZ: + return Value::TIMESTAMPTZ(Timestamp::FromDatetime( + Date::FromDate(Timestamp::MIN_YEAR, Timestamp::MIN_MONTH, Timestamp::MIN_DAY), dtime_t(0))); + case LogicalTypeId::FLOAT: + return Value::FLOAT(NumericLimits::Minimum()); + case LogicalTypeId::DOUBLE: + return Value::DOUBLE(NumericLimits::Minimum()); + case LogicalTypeId::DECIMAL: { + auto width = DecimalType::GetWidth(type); + auto scale = DecimalType::GetScale(type); + switch (type.InternalType()) { + case PhysicalType::INT16: + return Value::DECIMAL(int16_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); + case PhysicalType::INT32: + return Value::DECIMAL(int32_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); + case PhysicalType::INT64: + return Value::DECIMAL(int64_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); + case PhysicalType::INT128: + return Value::DECIMAL(-Hugeint::POWERS_OF_TEN[width] + 1, width, scale); + default: + throw InternalException("Unknown decimal type"); + } + } + case LogicalTypeId::ENUM: + return Value::ENUM(0, type); + default: + throw InvalidTypeException(type, "MinimumValue requires numeric type"); + } +} + +Value Value::MaximumValue(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + return Value::BOOLEAN(true); + case LogicalTypeId::TINYINT: + return Value::TINYINT(NumericLimits::Maximum()); + case LogicalTypeId::SMALLINT: + return Value::SMALLINT(NumericLimits::Maximum()); + case LogicalTypeId::INTEGER: + case LogicalTypeId::SQLNULL: + return Value::INTEGER(NumericLimits::Maximum()); + case LogicalTypeId::BIGINT: + return Value::BIGINT(NumericLimits::Maximum()); + case LogicalTypeId::HUGEINT: + return Value::HUGEINT(NumericLimits::Maximum()); + case LogicalTypeId::UUID: + return Value::UUID(NumericLimits::Maximum()); + case LogicalTypeId::UTINYINT: + return Value::UTINYINT(NumericLimits::Maximum()); + case LogicalTypeId::USMALLINT: + return Value::USMALLINT(NumericLimits::Maximum()); + case LogicalTypeId::UINTEGER: + return Value::UINTEGER(NumericLimits::Maximum()); + case LogicalTypeId::UBIGINT: + return Value::UBIGINT(NumericLimits::Maximum()); + case LogicalTypeId::DATE: + return Value::DATE(Date::FromDate(Date::DATE_MAX_YEAR, Date::DATE_MAX_MONTH, Date::DATE_MAX_DAY)); + case LogicalTypeId::TIME: + return Value::TIME(dtime_t(Interval::SECS_PER_DAY * Interval::MICROS_PER_SEC - 1)); + case LogicalTypeId::TIMESTAMP: + return Value::TIMESTAMP(timestamp_t(NumericLimits::Maximum() - 1)); + case LogicalTypeId::TIMESTAMP_MS: + return MaximumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_MS); + case LogicalTypeId::TIMESTAMP_NS: + return Value::TIMESTAMPNS(timestamp_t(NumericLimits::Maximum() - 1)); + case LogicalTypeId::TIMESTAMP_SEC: + return MaximumValue(LogicalType::TIMESTAMP).DefaultCastAs(LogicalType::TIMESTAMP_S); + case LogicalTypeId::TIME_TZ: + return Value::TIMETZ( + dtime_tz_t(dtime_t(Interval::SECS_PER_DAY * Interval::MICROS_PER_SEC - 1), dtime_tz_t::MAX_OFFSET)); + case LogicalTypeId::TIMESTAMP_TZ: + return MaximumValue(LogicalType::TIMESTAMP); + case LogicalTypeId::FLOAT: + return Value::FLOAT(NumericLimits::Maximum()); + case LogicalTypeId::DOUBLE: + return Value::DOUBLE(NumericLimits::Maximum()); + case LogicalTypeId::DECIMAL: { + auto width = DecimalType::GetWidth(type); + auto scale = DecimalType::GetScale(type); + switch (type.InternalType()) { + case PhysicalType::INT16: + return Value::DECIMAL(int16_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); + case PhysicalType::INT32: + return Value::DECIMAL(int32_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); + case PhysicalType::INT64: + return Value::DECIMAL(int64_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); + case PhysicalType::INT128: + return Value::DECIMAL(Hugeint::POWERS_OF_TEN[width] - 1, width, scale); + default: + throw InternalException("Unknown decimal type"); + } + } + case LogicalTypeId::ENUM: + return Value::ENUM(EnumType::GetSize(type) - 1, type); + default: + throw InvalidTypeException(type, "MaximumValue requires numeric type"); + } +} + +Value Value::Infinity(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::DATE: + return Value::DATE(date_t::infinity()); + case LogicalTypeId::TIMESTAMP: + return Value::TIMESTAMP(timestamp_t::infinity()); + case LogicalTypeId::TIMESTAMP_MS: + return Value::TIMESTAMPMS(timestamp_t::infinity()); + case LogicalTypeId::TIMESTAMP_NS: + return Value::TIMESTAMPNS(timestamp_t::infinity()); + case LogicalTypeId::TIMESTAMP_SEC: + return Value::TIMESTAMPSEC(timestamp_t::infinity()); + case LogicalTypeId::TIMESTAMP_TZ: + return Value::TIMESTAMPTZ(timestamp_t::infinity()); + case LogicalTypeId::FLOAT: + return Value::FLOAT(std::numeric_limits::infinity()); + case LogicalTypeId::DOUBLE: + return Value::DOUBLE(std::numeric_limits::infinity()); + default: + throw InvalidTypeException(type, "Infinity requires numeric type"); + } +} + +Value Value::NegativeInfinity(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::DATE: + return Value::DATE(date_t::ninfinity()); + case LogicalTypeId::TIMESTAMP: + return Value::TIMESTAMP(timestamp_t::ninfinity()); + case LogicalTypeId::TIMESTAMP_MS: + return Value::TIMESTAMPMS(timestamp_t::ninfinity()); + case LogicalTypeId::TIMESTAMP_NS: + return Value::TIMESTAMPNS(timestamp_t::ninfinity()); + case LogicalTypeId::TIMESTAMP_SEC: + return Value::TIMESTAMPSEC(timestamp_t::ninfinity()); + case LogicalTypeId::TIMESTAMP_TZ: + return Value::TIMESTAMPTZ(timestamp_t::ninfinity()); + case LogicalTypeId::FLOAT: + return Value::FLOAT(-std::numeric_limits::infinity()); + case LogicalTypeId::DOUBLE: + return Value::DOUBLE(-std::numeric_limits::infinity()); + default: + throw InvalidTypeException(type, "NegativeInfinity requires numeric type"); + } +} + +Value Value::BOOLEAN(int8_t value) { + Value result(LogicalType::BOOLEAN); + result.value_.boolean = bool(value); + result.is_null = false; + return result; +} + +Value Value::TINYINT(int8_t value) { + Value result(LogicalType::TINYINT); + result.value_.tinyint = value; + result.is_null = false; + return result; +} + +Value Value::SMALLINT(int16_t value) { + Value result(LogicalType::SMALLINT); + result.value_.smallint = value; + result.is_null = false; + return result; +} + +Value Value::INTEGER(int32_t value) { + Value result(LogicalType::INTEGER); + result.value_.integer = value; + result.is_null = false; + return result; +} + +Value Value::BIGINT(int64_t value) { + Value result(LogicalType::BIGINT); + result.value_.bigint = value; + result.is_null = false; + return result; +} + +Value Value::HUGEINT(hugeint_t value) { + Value result(LogicalType::HUGEINT); + result.value_.hugeint = value; + result.is_null = false; + return result; +} + +Value Value::UUID(hugeint_t value) { + Value result(LogicalType::UUID); + result.value_.hugeint = value; + result.is_null = false; + return result; +} + +Value Value::UUID(const string &value) { + Value result(LogicalType::UUID); + result.value_.hugeint = UUID::FromString(value); + result.is_null = false; + return result; +} + +Value Value::UTINYINT(uint8_t value) { + Value result(LogicalType::UTINYINT); + result.value_.utinyint = value; + result.is_null = false; + return result; +} + +Value Value::USMALLINT(uint16_t value) { + Value result(LogicalType::USMALLINT); + result.value_.usmallint = value; + result.is_null = false; + return result; +} + +Value Value::UINTEGER(uint32_t value) { + Value result(LogicalType::UINTEGER); + result.value_.uinteger = value; + result.is_null = false; + return result; +} + +Value Value::UBIGINT(uint64_t value) { + Value result(LogicalType::UBIGINT); + result.value_.ubigint = value; + result.is_null = false; + return result; +} + +bool Value::FloatIsFinite(float value) { + return !(std::isnan(value) || std::isinf(value)); +} + +bool Value::DoubleIsFinite(double value) { + return !(std::isnan(value) || std::isinf(value)); +} + +template <> +bool Value::IsNan(float input) { + return std::isnan(input); +} + +template <> +bool Value::IsNan(double input) { + return std::isnan(input); +} + +template <> +bool Value::IsFinite(float input) { + return Value::FloatIsFinite(input); +} + +template <> +bool Value::IsFinite(double input) { + return Value::DoubleIsFinite(input); +} + +template <> +bool Value::IsFinite(date_t input) { + return Date::IsFinite(input); +} + +template <> +bool Value::IsFinite(timestamp_t input) { + return Timestamp::IsFinite(input); +} + +bool Value::StringIsValid(const char *str, idx_t length) { + auto utf_type = Utf8Proc::Analyze(str, length); + return utf_type != UnicodeType::INVALID; +} + +Value Value::DECIMAL(int16_t value, uint8_t width, uint8_t scale) { + return Value::DECIMAL(int64_t(value), width, scale); +} + +Value Value::DECIMAL(int32_t value, uint8_t width, uint8_t scale) { + return Value::DECIMAL(int64_t(value), width, scale); +} + +Value Value::DECIMAL(int64_t value, uint8_t width, uint8_t scale) { + auto decimal_type = LogicalType::DECIMAL(width, scale); + Value result(decimal_type); + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + result.value_.smallint = value; + break; + case PhysicalType::INT32: + result.value_.integer = value; + break; + case PhysicalType::INT64: + result.value_.bigint = value; + break; + default: + result.value_.hugeint = value; + break; + } + result.type_.Verify(); + result.is_null = false; + return result; +} + +Value Value::DECIMAL(hugeint_t value, uint8_t width, uint8_t scale) { + D_ASSERT(width >= Decimal::MAX_WIDTH_INT64 && width <= Decimal::MAX_WIDTH_INT128); + Value result(LogicalType::DECIMAL(width, scale)); + result.value_.hugeint = value; + result.is_null = false; + return result; +} + +Value Value::FLOAT(float value) { + Value result(LogicalType::FLOAT); + result.value_.float_ = value; + result.is_null = false; + return result; +} + +Value Value::DOUBLE(double value) { + Value result(LogicalType::DOUBLE); + result.value_.double_ = value; + result.is_null = false; + return result; +} + +Value Value::HASH(hash_t value) { + Value result(LogicalType::HASH); + result.value_.hash = value; + result.is_null = false; + return result; +} + +Value Value::POINTER(uintptr_t value) { + Value result(LogicalType::POINTER); + result.value_.pointer = value; + result.is_null = false; + return result; +} + +Value Value::DATE(date_t value) { + Value result(LogicalType::DATE); + result.value_.date = value; + result.is_null = false; + return result; +} + +Value Value::DATE(int32_t year, int32_t month, int32_t day) { + return Value::DATE(Date::FromDate(year, month, day)); +} + +Value Value::TIME(dtime_t value) { + Value result(LogicalType::TIME); + result.value_.time = value; + result.is_null = false; + return result; +} + +Value Value::TIMETZ(dtime_tz_t value) { + Value result(LogicalType::TIME_TZ); + result.value_.timetz = value; + result.is_null = false; + return result; +} + +Value Value::TIME(int32_t hour, int32_t min, int32_t sec, int32_t micros) { + return Value::TIME(Time::FromTime(hour, min, sec, micros)); +} + +Value Value::TIMESTAMP(timestamp_t value) { + Value result(LogicalType::TIMESTAMP); + result.value_.timestamp = value; + result.is_null = false; + return result; +} + +Value Value::TIMESTAMPTZ(timestamp_t value) { + Value result(LogicalType::TIMESTAMP_TZ); + result.value_.timestamp = value; + result.is_null = false; + return result; +} + +Value Value::TIMESTAMPNS(timestamp_t timestamp) { + Value result(LogicalType::TIMESTAMP_NS); + result.value_.timestamp = timestamp; + result.is_null = false; + return result; +} + +Value Value::TIMESTAMPMS(timestamp_t timestamp) { + Value result(LogicalType::TIMESTAMP_MS); + result.value_.timestamp = timestamp; + result.is_null = false; + return result; +} + +Value Value::TIMESTAMPSEC(timestamp_t timestamp) { + Value result(LogicalType::TIMESTAMP_S); + result.value_.timestamp = timestamp; + result.is_null = false; + return result; +} + +Value Value::TIMESTAMP(date_t date, dtime_t time) { + return Value::TIMESTAMP(Timestamp::FromDatetime(date, time)); +} + +Value Value::TIMESTAMP(int32_t year, int32_t month, int32_t day, int32_t hour, int32_t min, int32_t sec, + int32_t micros) { + auto val = Value::TIMESTAMP(Date::FromDate(year, month, day), Time::FromTime(hour, min, sec, micros)); + val.type_ = LogicalType::TIMESTAMP; + return val; +} + +Value Value::STRUCT(child_list_t values) { + Value result; + child_list_t child_types; + vector struct_values; + for (auto &child : values) { + child_types.push_back(make_pair(std::move(child.first), child.second.type())); + struct_values.push_back(std::move(child.second)); + } + result.value_info_ = make_shared(std::move(struct_values)); + result.type_ = LogicalType::STRUCT(child_types); + result.is_null = false; + return result; +} + +Value Value::MAP(const LogicalType &child_type, vector values) { + Value result; + + result.type_ = LogicalType::MAP(child_type); + result.is_null = false; + for (auto &val : values) { + D_ASSERT(val.type().InternalType() == PhysicalType::STRUCT); + auto &children = StructValue::GetChildren(val); + + // Ensure that the field containing the keys is called 'key' + // and that the field containing the values is called 'value' + // this is required to make equality checks work + D_ASSERT(children.size() == 2); + child_list_t new_children; + new_children.reserve(2); + new_children.push_back(std::make_pair("key", children[0])); + new_children.push_back(std::make_pair("value", children[1])); + val = Value::STRUCT(std::move(new_children)); + } + result.value_info_ = make_shared(std::move(values)); + return result; +} + +Value Value::UNION(child_list_t members, uint8_t tag, Value value) { + D_ASSERT(!members.empty()); + D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS); + D_ASSERT(members.size() > tag); + + D_ASSERT(value.type() == members[tag].second); + + Value result; + result.is_null = false; + // add the tag to the front of the struct + vector union_values; + union_values.emplace_back(Value::UTINYINT(tag)); + for (idx_t i = 0; i < members.size(); i++) { + if (i != tag) { + union_values.emplace_back(members[i].second); + } else { + union_values.emplace_back(nullptr); + } + } + union_values[tag + 1] = std::move(value); + result.value_info_ = make_shared(std::move(union_values)); + result.type_ = LogicalType::UNION(std::move(members)); + return result; +} + +Value Value::LIST(vector values) { + if (values.empty()) { + throw InternalException("Value::LIST without providing a child-type requires a non-empty list of values. Use " + "Value::LIST(child_type, list) instead."); + } +#ifdef DEBUG + for (idx_t i = 1; i < values.size(); i++) { + D_ASSERT(values[i].type() == values[0].type()); + } +#endif + Value result; + result.type_ = LogicalType::LIST(values[0].type()); + result.value_info_ = make_shared(std::move(values)); + result.is_null = false; + return result; +} + +Value Value::LIST(const LogicalType &child_type, vector values) { + if (values.empty()) { + return Value::EMPTYLIST(child_type); + } + for (auto &val : values) { + val = val.DefaultCastAs(child_type); + } + return Value::LIST(std::move(values)); +} + +Value Value::EMPTYLIST(const LogicalType &child_type) { + Value result; + result.type_ = LogicalType::LIST(child_type); + result.value_info_ = make_shared(); + result.is_null = false; + return result; +} + +Value Value::BLOB(const_data_ptr_t data, idx_t len) { + Value result(LogicalType::BLOB); + result.is_null = false; + result.value_info_ = make_shared(string(const_char_ptr_cast(data), len)); + return result; +} + +Value Value::BLOB(const string &data) { + Value result(LogicalType::BLOB); + result.is_null = false; + result.value_info_ = make_shared(Blob::ToBlob(string_t(data))); + return result; +} + +Value Value::BIT(const_data_ptr_t data, idx_t len) { + Value result(LogicalType::BIT); + result.is_null = false; + result.value_info_ = make_shared(string(const_char_ptr_cast(data), len)); + return result; +} + +Value Value::BIT(const string &data) { + Value result(LogicalType::BIT); + result.is_null = false; + result.value_info_ = make_shared(Bit::ToBit(string_t(data))); + return result; +} + +Value Value::ENUM(uint64_t value, const LogicalType &original_type) { + D_ASSERT(original_type.id() == LogicalTypeId::ENUM); + Value result(original_type); + switch (original_type.InternalType()) { + case PhysicalType::UINT8: + result.value_.utinyint = value; + break; + case PhysicalType::UINT16: + result.value_.usmallint = value; + break; + case PhysicalType::UINT32: + result.value_.uinteger = value; + break; + default: + throw InternalException("Incorrect Physical Type for ENUM"); + } + result.is_null = false; + return result; +} + +Value Value::INTERVAL(int32_t months, int32_t days, int64_t micros) { + Value result(LogicalType::INTERVAL); + result.is_null = false; + result.value_.interval.months = months; + result.value_.interval.days = days; + result.value_.interval.micros = micros; + return result; +} + +Value Value::INTERVAL(interval_t interval) { + return Value::INTERVAL(interval.months, interval.days, interval.micros); +} + +//===--------------------------------------------------------------------===// +// CreateValue +//===--------------------------------------------------------------------===// +template <> +Value Value::CreateValue(bool value) { + return Value::BOOLEAN(value); +} + +template <> +Value Value::CreateValue(int8_t value) { + return Value::TINYINT(value); +} + +template <> +Value Value::CreateValue(int16_t value) { + return Value::SMALLINT(value); +} + +template <> +Value Value::CreateValue(int32_t value) { + return Value::INTEGER(value); +} + +template <> +Value Value::CreateValue(int64_t value) { + return Value::BIGINT(value); +} + +template <> +Value Value::CreateValue(uint8_t value) { + return Value::UTINYINT(value); +} + +template <> +Value Value::CreateValue(uint16_t value) { + return Value::USMALLINT(value); +} + +template <> +Value Value::CreateValue(uint32_t value) { + return Value::UINTEGER(value); +} + +template <> +Value Value::CreateValue(uint64_t value) { + return Value::UBIGINT(value); +} + +template <> +Value Value::CreateValue(hugeint_t value) { + return Value::HUGEINT(value); +} + +template <> +Value Value::CreateValue(date_t value) { + return Value::DATE(value); +} + +template <> +Value Value::CreateValue(dtime_t value) { + return Value::TIME(value); +} + +template <> +Value Value::CreateValue(dtime_tz_t value) { + return Value::TIMETZ(value); +} + +template <> +Value Value::CreateValue(timestamp_t value) { + return Value::TIMESTAMP(value); +} + +template <> +Value Value::CreateValue(timestamp_sec_t value) { + return Value::TIMESTAMPSEC(value); +} + +template <> +Value Value::CreateValue(timestamp_ms_t value) { + return Value::TIMESTAMPMS(value); +} + +template <> +Value Value::CreateValue(timestamp_ns_t value) { + return Value::TIMESTAMPNS(value); +} + +template <> +Value Value::CreateValue(timestamp_tz_t value) { + return Value::TIMESTAMPTZ(value); +} + +template <> +Value Value::CreateValue(const char *value) { + return Value(string(value)); +} + +template <> +Value Value::CreateValue(string value) { // NOLINT: required for templating + return Value::BLOB(value); +} + +template <> +Value Value::CreateValue(string_t value) { + return Value(value); +} + +template <> +Value Value::CreateValue(float value) { + return Value::FLOAT(value); +} + +template <> +Value Value::CreateValue(double value) { + return Value::DOUBLE(value); +} + +template <> +Value Value::CreateValue(interval_t value) { + return Value::INTERVAL(value); +} + +template <> +Value Value::CreateValue(Value value) { + return value; +} + +//===--------------------------------------------------------------------===// +// GetValue +//===--------------------------------------------------------------------===// +template +T Value::GetValueInternal() const { + if (IsNull()) { + throw InternalException("Calling GetValueInternal on a value that is NULL"); + } + switch (type_.id()) { + case LogicalTypeId::BOOLEAN: + return Cast::Operation(value_.boolean); + case LogicalTypeId::TINYINT: + return Cast::Operation(value_.tinyint); + case LogicalTypeId::SMALLINT: + return Cast::Operation(value_.smallint); + case LogicalTypeId::INTEGER: + return Cast::Operation(value_.integer); + case LogicalTypeId::BIGINT: + return Cast::Operation(value_.bigint); + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UUID: + return Cast::Operation(value_.hugeint); + case LogicalTypeId::DATE: + return Cast::Operation(value_.date); + case LogicalTypeId::TIME: + return Cast::Operation(value_.time); + case LogicalTypeId::TIME_TZ: + return Cast::Operation(value_.timetz); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return Cast::Operation(value_.timestamp); + case LogicalTypeId::UTINYINT: + return Cast::Operation(value_.utinyint); + case LogicalTypeId::USMALLINT: + return Cast::Operation(value_.usmallint); + case LogicalTypeId::UINTEGER: + return Cast::Operation(value_.uinteger); + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::UBIGINT: + return Cast::Operation(value_.ubigint); + case LogicalTypeId::FLOAT: + return Cast::Operation(value_.float_); + case LogicalTypeId::DOUBLE: + return Cast::Operation(value_.double_); + case LogicalTypeId::VARCHAR: + return Cast::Operation(StringValue::Get(*this).c_str()); + case LogicalTypeId::INTERVAL: + return Cast::Operation(value_.interval); + case LogicalTypeId::DECIMAL: + return DefaultCastAs(LogicalType::DOUBLE).GetValueInternal(); + case LogicalTypeId::ENUM: { + switch (type_.InternalType()) { + case PhysicalType::UINT8: + return Cast::Operation(value_.utinyint); + case PhysicalType::UINT16: + return Cast::Operation(value_.usmallint); + case PhysicalType::UINT32: + return Cast::Operation(value_.uinteger); + default: + throw InternalException("Invalid Internal Type for ENUMs"); + } + } + default: + throw NotImplementedException("Unimplemented type \"%s\" for GetValue()", type_.ToString()); + } +} + +template <> +bool Value::GetValue() const { + return GetValueInternal(); +} +template <> +int8_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +int16_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +int32_t Value::GetValue() const { + if (type_.id() == LogicalTypeId::DATE) { + return value_.integer; + } + return GetValueInternal(); +} +template <> +int64_t Value::GetValue() const { + if (IsNull()) { + throw InternalException("Calling GetValue on a value that is NULL"); + } + switch (type_.id()) { + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP_TZ: + return value_.bigint; + default: + return GetValueInternal(); + } +} +template <> +hugeint_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +uint8_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +uint16_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +uint32_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +uint64_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +string Value::GetValue() const { + return ToString(); +} +template <> +float Value::GetValue() const { + return GetValueInternal(); +} +template <> +double Value::GetValue() const { + return GetValueInternal(); +} +template <> +date_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +dtime_t Value::GetValue() const { + return GetValueInternal(); +} +template <> +timestamp_t Value::GetValue() const { + return GetValueInternal(); +} + +template <> +DUCKDB_API interval_t Value::GetValue() const { + return GetValueInternal(); +} + +template <> +DUCKDB_API Value Value::GetValue() const { + return Value(*this); +} + +uintptr_t Value::GetPointer() const { + D_ASSERT(type() == LogicalType::POINTER); + return value_.pointer; +} + +Value Value::Numeric(const LogicalType &type, int64_t value) { + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + D_ASSERT(value == 0 || value == 1); + return Value::BOOLEAN(value ? 1 : 0); + case LogicalTypeId::TINYINT: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::TINYINT((int8_t)value); + case LogicalTypeId::SMALLINT: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::SMALLINT((int16_t)value); + case LogicalTypeId::INTEGER: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::INTEGER((int32_t)value); + case LogicalTypeId::BIGINT: + return Value::BIGINT(value); + case LogicalTypeId::UTINYINT: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::UTINYINT((uint8_t)value); + case LogicalTypeId::USMALLINT: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::USMALLINT((uint16_t)value); + case LogicalTypeId::UINTEGER: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::UINTEGER((uint32_t)value); + case LogicalTypeId::UBIGINT: + D_ASSERT(value >= 0); + return Value::UBIGINT(value); + case LogicalTypeId::HUGEINT: + return Value::HUGEINT(value); + case LogicalTypeId::DECIMAL: + return Value::DECIMAL(value, DecimalType::GetWidth(type), DecimalType::GetScale(type)); + case LogicalTypeId::FLOAT: + return Value((float)value); + case LogicalTypeId::DOUBLE: + return Value((double)value); + case LogicalTypeId::POINTER: + return Value::POINTER(value); + case LogicalTypeId::DATE: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::DATE(date_t(value)); + case LogicalTypeId::TIME: + return Value::TIME(dtime_t(value)); + case LogicalTypeId::TIMESTAMP: + return Value::TIMESTAMP(timestamp_t(value)); + case LogicalTypeId::TIMESTAMP_NS: + return Value::TIMESTAMPNS(timestamp_t(value)); + case LogicalTypeId::TIMESTAMP_MS: + return Value::TIMESTAMPMS(timestamp_t(value)); + case LogicalTypeId::TIMESTAMP_SEC: + return Value::TIMESTAMPSEC(timestamp_t(value)); + case LogicalTypeId::TIMESTAMP_TZ: + return Value::TIMESTAMPTZ(timestamp_t(value)); + case LogicalTypeId::ENUM: + switch (type.InternalType()) { + case PhysicalType::UINT8: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::UTINYINT((uint8_t)value); + case PhysicalType::UINT16: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::USMALLINT((uint16_t)value); + case PhysicalType::UINT32: + D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); + return Value::UINTEGER((uint32_t)value); + default: + throw InternalException("Enum doesn't accept this physical type"); + } + default: + throw InvalidTypeException(type, "Numeric requires numeric type"); + } +} + +Value Value::Numeric(const LogicalType &type, hugeint_t value) { +#ifdef DEBUG + // perform a throwing cast to verify that the type fits + Value::HUGEINT(value).DefaultCastAs(type); +#endif + switch (type.id()) { + case LogicalTypeId::HUGEINT: + return Value::HUGEINT(value); + case LogicalTypeId::UBIGINT: + return Value::UBIGINT(Hugeint::Cast(value)); + default: + return Value::Numeric(type, Hugeint::Cast(value)); + } +} + +//===--------------------------------------------------------------------===// +// GetValueUnsafe +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::BOOL); + return value_.boolean; +} + +template <> +int8_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT8 || type_.InternalType() == PhysicalType::BOOL); + return value_.tinyint; +} + +template <> +int16_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT16); + return value_.smallint; +} + +template <> +int32_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT32); + return value_.integer; +} + +template <> +int64_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT64); + return value_.bigint; +} + +template <> +hugeint_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT128); + return value_.hugeint; +} + +template <> +uint8_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::UINT8); + return value_.utinyint; +} + +template <> +uint16_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::UINT16); + return value_.usmallint; +} + +template <> +uint32_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::UINT32); + return value_.uinteger; +} + +template <> +uint64_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::UINT64); + return value_.ubigint; +} + +template <> +string Value::GetValueUnsafe() const { + return StringValue::Get(*this); +} + +template <> +DUCKDB_API string_t Value::GetValueUnsafe() const { + return string_t(StringValue::Get(*this)); +} + +template <> +float Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::FLOAT); + return value_.float_; +} + +template <> +double Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::DOUBLE); + return value_.double_; +} + +template <> +date_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT32); + return value_.date; +} + +template <> +dtime_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT64); + return value_.time; +} + +template <> +timestamp_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INT64); + return value_.timestamp; +} + +template <> +interval_t Value::GetValueUnsafe() const { + D_ASSERT(type_.InternalType() == PhysicalType::INTERVAL); + return value_.interval; +} + +//===--------------------------------------------------------------------===// +// Hash +//===--------------------------------------------------------------------===// +hash_t Value::Hash() const { + if (IsNull()) { + return 0; + } + Vector input(*this); + Vector result(LogicalType::HASH); + VectorOperations::Hash(input, result, 1); + + auto data = FlatVector::GetData(result); + return data[0]; +} + +string Value::ToString() const { + if (IsNull()) { + return "NULL"; + } + return StringValue::Get(DefaultCastAs(LogicalType::VARCHAR)); +} + +string Value::ToSQLString() const { + if (IsNull()) { + return ToString(); + } + switch (type_.id()) { + case LogicalTypeId::UUID: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::INTERVAL: + case LogicalTypeId::BLOB: + return "'" + ToString() + "'::" + type_.ToString(); + case LogicalTypeId::VARCHAR: + case LogicalTypeId::ENUM: + return "'" + StringUtil::Replace(ToString(), "'", "''") + "'"; + case LogicalTypeId::STRUCT: { + string ret = "{"; + auto &child_types = StructType::GetChildTypes(type_); + auto &struct_values = StructValue::GetChildren(*this); + for (size_t i = 0; i < struct_values.size(); i++) { + auto &name = child_types[i].first; + auto &child = struct_values[i]; + ret += "'" + name + "': " + child.ToSQLString(); + if (i < struct_values.size() - 1) { + ret += ", "; + } + } + ret += "}"; + return ret; + } + case LogicalTypeId::FLOAT: + if (!FloatIsFinite(FloatValue::Get(*this))) { + return "'" + ToString() + "'::" + type_.ToString(); + } + return ToString(); + case LogicalTypeId::DOUBLE: { + double val = DoubleValue::Get(*this); + if (!DoubleIsFinite(val)) { + if (!Value::IsNan(val)) { + // to infinity and beyond + return val < 0 ? "-1e1000" : "1e1000"; + } + return "'" + ToString() + "'::" + type_.ToString(); + } + return ToString(); + } + case LogicalTypeId::LIST: { + string ret = "["; + auto &list_values = ListValue::GetChildren(*this); + for (size_t i = 0; i < list_values.size(); i++) { + auto &child = list_values[i]; + ret += child.ToSQLString(); + if (i < list_values.size() - 1) { + ret += ", "; + } + } + ret += "]"; + return ret; + } + default: + return ToString(); + } +} + +//===--------------------------------------------------------------------===// +// Type-specific getters +//===--------------------------------------------------------------------===// +bool BooleanValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +int8_t TinyIntValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +int16_t SmallIntValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +int32_t IntegerValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +int64_t BigIntValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +hugeint_t HugeIntValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +uint8_t UTinyIntValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +uint16_t USmallIntValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +uint32_t UIntegerValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +uint64_t UBigIntValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +float FloatValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +double DoubleValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +const string &StringValue::Get(const Value &value) { + if (value.is_null) { + throw InternalException("Calling StringValue::Get on a NULL value"); + } + D_ASSERT(value.type().InternalType() == PhysicalType::VARCHAR); + D_ASSERT(value.value_info_); + return value.value_info_->Get().GetString(); +} + +date_t DateValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +dtime_t TimeValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +timestamp_t TimestampValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +interval_t IntervalValue::Get(const Value &value) { + return value.GetValueUnsafe(); +} + +const vector &StructValue::GetChildren(const Value &value) { + if (value.is_null) { + throw InternalException("Calling StructValue::GetChildren on a NULL value"); + } + D_ASSERT(value.type().InternalType() == PhysicalType::STRUCT); + D_ASSERT(value.value_info_); + return value.value_info_->Get().GetValues(); +} + +const vector &ListValue::GetChildren(const Value &value) { + if (value.is_null) { + throw InternalException("Calling ListValue::GetChildren on a NULL value"); + } + D_ASSERT(value.type().InternalType() == PhysicalType::LIST); + D_ASSERT(value.value_info_); + return value.value_info_->Get().GetValues(); +} + +const Value &UnionValue::GetValue(const Value &value) { + D_ASSERT(value.type().id() == LogicalTypeId::UNION); + auto &children = StructValue::GetChildren(value); + auto tag = children[0].GetValueUnsafe(); + D_ASSERT(tag < children.size() - 1); + return children[tag + 1]; +} + +union_tag_t UnionValue::GetTag(const Value &value) { + D_ASSERT(value.type().id() == LogicalTypeId::UNION); + auto children = StructValue::GetChildren(value); + auto tag = children[0].GetValueUnsafe(); + D_ASSERT(tag < children.size() - 1); + return tag; +} + +const LogicalType &UnionValue::GetType(const Value &value) { + return UnionType::GetMemberType(value.type(), UnionValue::GetTag(value)); +} + +hugeint_t IntegralValue::Get(const Value &value) { + switch (value.type().InternalType()) { + case PhysicalType::INT8: + return TinyIntValue::Get(value); + case PhysicalType::INT16: + return SmallIntValue::Get(value); + case PhysicalType::INT32: + return IntegerValue::Get(value); + case PhysicalType::INT64: + return BigIntValue::Get(value); + case PhysicalType::INT128: + return HugeIntValue::Get(value); + case PhysicalType::UINT8: + return UTinyIntValue::Get(value); + case PhysicalType::UINT16: + return USmallIntValue::Get(value); + case PhysicalType::UINT32: + return UIntegerValue::Get(value); + case PhysicalType::UINT64: + return UBigIntValue::Get(value); + default: + throw InternalException("Invalid internal type \"%s\" for IntegralValue::Get", value.type().ToString()); + } +} + +//===--------------------------------------------------------------------===// +// Comparison Operators +//===--------------------------------------------------------------------===// +bool Value::operator==(const Value &rhs) const { + return ValueOperations::Equals(*this, rhs); +} + +bool Value::operator!=(const Value &rhs) const { + return ValueOperations::NotEquals(*this, rhs); +} + +bool Value::operator<(const Value &rhs) const { + return ValueOperations::LessThan(*this, rhs); +} + +bool Value::operator>(const Value &rhs) const { + return ValueOperations::GreaterThan(*this, rhs); +} + +bool Value::operator<=(const Value &rhs) const { + return ValueOperations::LessThanEquals(*this, rhs); +} + +bool Value::operator>=(const Value &rhs) const { + return ValueOperations::GreaterThanEquals(*this, rhs); +} + +bool Value::operator==(const int64_t &rhs) const { + return *this == Value::Numeric(type_, rhs); +} + +bool Value::operator!=(const int64_t &rhs) const { + return *this != Value::Numeric(type_, rhs); +} + +bool Value::operator<(const int64_t &rhs) const { + return *this < Value::Numeric(type_, rhs); +} + +bool Value::operator>(const int64_t &rhs) const { + return *this > Value::Numeric(type_, rhs); +} + +bool Value::operator<=(const int64_t &rhs) const { + return *this <= Value::Numeric(type_, rhs); +} + +bool Value::operator>=(const int64_t &rhs) const { + return *this >= Value::Numeric(type_, rhs); +} + +bool Value::TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, + Value &new_value, string *error_message, bool strict) const { + if (type_ == target_type) { + new_value = Copy(); + return true; + } + Vector input(*this); + Vector result(target_type); + if (!VectorOperations::TryCast(set, get_input, input, result, 1, error_message, strict)) { + return false; + } + new_value = result.GetValue(0); + return true; +} + +bool Value::TryCastAs(ClientContext &context, const LogicalType &target_type, Value &new_value, string *error_message, + bool strict) const { + GetCastFunctionInput get_input(context); + return TryCastAs(CastFunctionSet::Get(context), get_input, target_type, new_value, error_message, strict); +} + +bool Value::DefaultTryCastAs(const LogicalType &target_type, Value &new_value, string *error_message, + bool strict) const { + CastFunctionSet set; + GetCastFunctionInput get_input; + return TryCastAs(set, get_input, target_type, new_value, error_message, strict); +} + +Value Value::CastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, + bool strict) const { + Value new_value; + string error_message; + if (!TryCastAs(set, get_input, target_type, new_value, &error_message, strict)) { + throw InvalidInputException("Failed to cast value: %s", error_message); + } + return new_value; +} + +Value Value::CastAs(ClientContext &context, const LogicalType &target_type, bool strict) const { + GetCastFunctionInput get_input(context); + return CastAs(CastFunctionSet::Get(context), get_input, target_type, strict); +} + +Value Value::DefaultCastAs(const LogicalType &target_type, bool strict) const { + CastFunctionSet set; + GetCastFunctionInput get_input; + return CastAs(set, get_input, target_type, strict); +} + +bool Value::TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, + bool strict) { + Value new_value; + string error_message; + if (!TryCastAs(set, get_input, target_type, new_value, &error_message, strict)) { + return false; + } + type_ = target_type; + is_null = new_value.is_null; + value_ = new_value.value_; + value_info_ = std::move(new_value.value_info_); + return true; +} + +bool Value::TryCastAs(ClientContext &context, const LogicalType &target_type, bool strict) { + GetCastFunctionInput get_input(context); + return TryCastAs(CastFunctionSet::Get(context), get_input, target_type, strict); +} + +bool Value::DefaultTryCastAs(const LogicalType &target_type, bool strict) { + CastFunctionSet set; + GetCastFunctionInput get_input; + return TryCastAs(set, get_input, target_type, strict); +} + +void Value::Reinterpret(LogicalType new_type) { + this->type_ = std::move(new_type); +} + +void Value::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type_); + serializer.WriteProperty(101, "is_null", is_null); + if (!IsNull()) { + switch (type_.InternalType()) { + case PhysicalType::BIT: + throw InternalException("BIT type should not be serialized"); + case PhysicalType::BOOL: + serializer.WriteProperty(102, "value", value_.boolean); + break; + case PhysicalType::INT8: + serializer.WriteProperty(102, "value", value_.tinyint); + break; + case PhysicalType::INT16: + serializer.WriteProperty(102, "value", value_.smallint); + break; + case PhysicalType::INT32: + serializer.WriteProperty(102, "value", value_.integer); + break; + case PhysicalType::INT64: + serializer.WriteProperty(102, "value", value_.bigint); + break; + case PhysicalType::UINT8: + serializer.WriteProperty(102, "value", value_.utinyint); + break; + case PhysicalType::UINT16: + serializer.WriteProperty(102, "value", value_.usmallint); + break; + case PhysicalType::UINT32: + serializer.WriteProperty(102, "value", value_.uinteger); + break; + case PhysicalType::UINT64: + serializer.WriteProperty(102, "value", value_.ubigint); + break; + case PhysicalType::INT128: + serializer.WriteProperty(102, "value", value_.hugeint); + break; + case PhysicalType::FLOAT: + serializer.WriteProperty(102, "value", value_.float_); + break; + case PhysicalType::DOUBLE: + serializer.WriteProperty(102, "value", value_.double_); + break; + case PhysicalType::INTERVAL: + serializer.WriteProperty(102, "value", value_.interval); + break; + case PhysicalType::VARCHAR: { + if (type_.id() == LogicalTypeId::BLOB) { + auto blob_str = Blob::ToString(StringValue::Get(*this)); + serializer.WriteProperty(102, "value", blob_str); + } else { + serializer.WriteProperty(102, "value", StringValue::Get(*this)); + } + } break; + case PhysicalType::LIST: { + serializer.WriteObject(102, "value", [&](Serializer &serializer) { + auto &children = ListValue::GetChildren(*this); + serializer.WriteProperty(100, "children", children); + }); + } break; + case PhysicalType::STRUCT: { + serializer.WriteObject(102, "value", [&](Serializer &serializer) { + auto &children = StructValue::GetChildren(*this); + serializer.WriteProperty(100, "children", children); + }); + } break; + default: + throw NotImplementedException("Unimplemented type for Serialize"); + } + } +} + +Value Value::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto is_null = deserializer.ReadProperty(101, "is_null"); + Value new_value = Value(type); + if (is_null) { + return new_value; + } + new_value.is_null = false; + switch (type.InternalType()) { + case PhysicalType::BIT: + throw InternalException("BIT type should not be deserialized"); + case PhysicalType::BOOL: + new_value.value_.boolean = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::UINT8: + new_value.value_.utinyint = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::INT8: + new_value.value_.tinyint = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::UINT16: + new_value.value_.usmallint = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::INT16: + new_value.value_.smallint = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::UINT32: + new_value.value_.uinteger = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::INT32: + new_value.value_.integer = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::UINT64: + new_value.value_.ubigint = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::INT64: + new_value.value_.bigint = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::INT128: + new_value.value_.hugeint = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::FLOAT: + new_value.value_.float_ = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::DOUBLE: + new_value.value_.double_ = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::INTERVAL: + new_value.value_.interval = deserializer.ReadProperty(102, "value"); + break; + case PhysicalType::VARCHAR: { + auto str = deserializer.ReadProperty(102, "value"); + if (type.id() == LogicalTypeId::BLOB) { + new_value.value_info_ = make_shared(Blob::ToBlob(str)); + } else { + new_value.value_info_ = make_shared(str); + } + } break; + case PhysicalType::LIST: { + deserializer.ReadObject(102, "value", [&](Deserializer &obj) { + auto children = obj.ReadProperty>(100, "children"); + new_value.value_info_ = make_shared(children); + }); + } break; + case PhysicalType::STRUCT: { + deserializer.ReadObject(102, "value", [&](Deserializer &obj) { + auto children = obj.ReadProperty>(100, "children"); + new_value.value_info_ = make_shared(children); + }); + } break; + default: + throw NotImplementedException("Unimplemented type for Deserialize"); + } + return new_value; +} + +void Value::Print() const { + Printer::Print(ToString()); +} + +bool Value::NotDistinctFrom(const Value &lvalue, const Value &rvalue) { + return ValueOperations::NotDistinctFrom(lvalue, rvalue); +} + +static string SanitizeValue(string input) { + // some results might contain padding spaces, e.g. when rendering + // VARCHAR(10) and the string only has 6 characters, they will be padded + // with spaces to 10 in the rendering. We don't do that here yet as we + // are looking at internal structures. So just ignore any extra spaces + // on the right + StringUtil::RTrim(input); + // for result checking code, replace null bytes with their escaped value (\0) + return StringUtil::Replace(input, string("\0", 1), "\\0"); +} + +bool Value::ValuesAreEqual(CastFunctionSet &set, GetCastFunctionInput &get_input, const Value &result_value, + const Value &value) { + if (result_value.IsNull() != value.IsNull()) { + return false; + } + if (result_value.IsNull() && value.IsNull()) { + // NULL = NULL in checking code + return true; + } + switch (value.type_.id()) { + case LogicalTypeId::FLOAT: { + auto other = result_value.CastAs(set, get_input, LogicalType::FLOAT); + float ldecimal = value.value_.float_; + float rdecimal = other.value_.float_; + return ApproxEqual(ldecimal, rdecimal); + } + case LogicalTypeId::DOUBLE: { + auto other = result_value.CastAs(set, get_input, LogicalType::DOUBLE); + double ldecimal = value.value_.double_; + double rdecimal = other.value_.double_; + return ApproxEqual(ldecimal, rdecimal); + } + case LogicalTypeId::VARCHAR: { + auto other = result_value.CastAs(set, get_input, LogicalType::VARCHAR); + string left = SanitizeValue(StringValue::Get(other)); + string right = SanitizeValue(StringValue::Get(value)); + return left == right; + } + default: + if (result_value.type_.id() == LogicalTypeId::FLOAT || result_value.type_.id() == LogicalTypeId::DOUBLE) { + return Value::ValuesAreEqual(set, get_input, value, result_value); + } + return value == result_value; + } +} + +bool Value::ValuesAreEqual(ClientContext &context, const Value &result_value, const Value &value) { + GetCastFunctionInput get_input(context); + return Value::ValuesAreEqual(CastFunctionSet::Get(context), get_input, result_value, value); +} +bool Value::DefaultValuesAreEqual(const Value &result_value, const Value &value) { + CastFunctionSet set; + GetCastFunctionInput get_input; + return Value::ValuesAreEqual(set, get_input, result_value, value); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp new file mode 100644 index 00000000..58b9f162 --- /dev/null +++ b/src/duckdb/src/common/types/vector.cpp @@ -0,0 +1,2019 @@ +#include "duckdb/common/types/vector.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/sel_cache.hpp" +#include "duckdb/common/types/vector_cache.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/storage/string_uncompressed.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/fsst.hpp" +#include "fsst.h" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/value_map.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include // strlen() on Solaris + +namespace duckdb { + +Vector::Vector(LogicalType type_p, bool create_data, bool zero_data, idx_t capacity) + : vector_type(VectorType::FLAT_VECTOR), type(std::move(type_p)), data(nullptr) { + if (create_data) { + Initialize(zero_data, capacity); + } +} + +Vector::Vector(LogicalType type_p, idx_t capacity) : Vector(std::move(type_p), true, false, capacity) { +} + +Vector::Vector(LogicalType type_p, data_ptr_t dataptr) + : vector_type(VectorType::FLAT_VECTOR), type(std::move(type_p)), data(dataptr) { + if (dataptr && !type.IsValid()) { + throw InternalException("Cannot create a vector of type INVALID!"); + } +} + +Vector::Vector(const VectorCache &cache) : type(cache.GetType()) { + ResetFromCache(cache); +} + +Vector::Vector(Vector &other) : type(other.type) { + Reference(other); +} + +Vector::Vector(Vector &other, const SelectionVector &sel, idx_t count) : type(other.type) { + Slice(other, sel, count); +} + +Vector::Vector(Vector &other, idx_t offset, idx_t end) : type(other.type) { + Slice(other, offset, end); +} + +Vector::Vector(const Value &value) : type(value.type()) { + Reference(value); +} + +Vector::Vector(Vector &&other) noexcept + : vector_type(other.vector_type), type(std::move(other.type)), data(other.data), + validity(std::move(other.validity)), buffer(std::move(other.buffer)), auxiliary(std::move(other.auxiliary)) { +} + +void Vector::Reference(const Value &value) { + D_ASSERT(GetType().id() == value.type().id()); + this->vector_type = VectorType::CONSTANT_VECTOR; + buffer = VectorBuffer::CreateConstantVector(value.type()); + auto internal_type = value.type().InternalType(); + if (internal_type == PhysicalType::STRUCT) { + auto struct_buffer = make_uniq(); + auto &child_types = StructType::GetChildTypes(value.type()); + auto &child_vectors = struct_buffer->GetChildren(); + for (idx_t i = 0; i < child_types.size(); i++) { + auto vector = + make_uniq(value.IsNull() ? Value(child_types[i].second) : StructValue::GetChildren(value)[i]); + child_vectors.push_back(std::move(vector)); + } + auxiliary = shared_ptr(struct_buffer.release()); + if (value.IsNull()) { + SetValue(0, value); + } + } else if (internal_type == PhysicalType::LIST) { + auto list_buffer = make_uniq(value.type()); + auxiliary = shared_ptr(list_buffer.release()); + data = buffer->GetData(); + SetValue(0, value); + } else { + auxiliary.reset(); + data = buffer->GetData(); + SetValue(0, value); + } +} + +void Vector::Reference(const Vector &other) { + if (other.GetType().id() != GetType().id()) { + throw InternalException("Vector::Reference used on vector of different type"); + } + D_ASSERT(other.GetType() == GetType()); + Reinterpret(other); +} + +void Vector::ReferenceAndSetType(const Vector &other) { + type = other.GetType(); + Reference(other); +} + +void Vector::Reinterpret(const Vector &other) { + vector_type = other.vector_type; + AssignSharedPointer(buffer, other.buffer); + AssignSharedPointer(auxiliary, other.auxiliary); + data = other.data; + validity = other.validity; +} + +void Vector::ResetFromCache(const VectorCache &cache) { + cache.ResetFromCache(*this); +} + +void Vector::Slice(Vector &other, idx_t offset, idx_t end) { + if (other.GetVectorType() == VectorType::CONSTANT_VECTOR) { + Reference(other); + return; + } + D_ASSERT(other.GetVectorType() == VectorType::FLAT_VECTOR); + + auto internal_type = GetType().InternalType(); + if (internal_type == PhysicalType::STRUCT) { + Vector new_vector(GetType()); + auto &entries = StructVector::GetEntries(new_vector); + auto &other_entries = StructVector::GetEntries(other); + D_ASSERT(entries.size() == other_entries.size()); + for (idx_t i = 0; i < entries.size(); i++) { + entries[i]->Slice(*other_entries[i], offset, end); + } + new_vector.validity.Slice(other.validity, offset, end - offset); + Reference(new_vector); + } else { + Reference(other); + if (offset > 0) { + data = data + GetTypeIdSize(internal_type) * offset; + validity.Slice(other.validity, offset, end - offset); + } + } +} + +void Vector::Slice(Vector &other, const SelectionVector &sel, idx_t count) { + Reference(other); + Slice(sel, count); +} + +void Vector::Slice(const SelectionVector &sel, idx_t count) { + if (GetVectorType() == VectorType::CONSTANT_VECTOR) { + // dictionary on a constant is just a constant + return; + } + if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { + // already a dictionary, slice the current dictionary + auto ¤t_sel = DictionaryVector::SelVector(*this); + auto sliced_dictionary = current_sel.Slice(sel, count); + buffer = make_buffer(std::move(sliced_dictionary)); + if (GetType().InternalType() == PhysicalType::STRUCT) { + auto &child_vector = DictionaryVector::Child(*this); + + Vector new_child(child_vector); + new_child.auxiliary = make_buffer(new_child, sel, count); + auxiliary = make_buffer(std::move(new_child)); + } + return; + } + + if (GetVectorType() == VectorType::FSST_VECTOR) { + Flatten(sel, count); + return; + } + + Vector child_vector(*this); + auto internal_type = GetType().InternalType(); + if (internal_type == PhysicalType::STRUCT) { + child_vector.auxiliary = make_buffer(*this, sel, count); + } + auto child_ref = make_buffer(std::move(child_vector)); + auto dict_buffer = make_buffer(sel); + vector_type = VectorType::DICTIONARY_VECTOR; + buffer = std::move(dict_buffer); + auxiliary = std::move(child_ref); +} + +void Vector::Slice(const SelectionVector &sel, idx_t count, SelCache &cache) { + if (GetVectorType() == VectorType::DICTIONARY_VECTOR && GetType().InternalType() != PhysicalType::STRUCT) { + // dictionary vector: need to merge dictionaries + // check if we have a cached entry + auto ¤t_sel = DictionaryVector::SelVector(*this); + auto target_data = current_sel.data(); + auto entry = cache.cache.find(target_data); + if (entry != cache.cache.end()) { + // cached entry exists: use that + this->buffer = make_buffer(entry->second->Cast().GetSelVector()); + vector_type = VectorType::DICTIONARY_VECTOR; + } else { + Slice(sel, count); + cache.cache[target_data] = this->buffer; + } + } else { + Slice(sel, count); + } +} + +void Vector::Initialize(bool zero_data, idx_t capacity) { + auxiliary.reset(); + validity.Reset(); + auto &type = GetType(); + auto internal_type = type.InternalType(); + if (internal_type == PhysicalType::STRUCT) { + auto struct_buffer = make_uniq(type, capacity); + auxiliary = shared_ptr(struct_buffer.release()); + } else if (internal_type == PhysicalType::LIST) { + auto list_buffer = make_uniq(type, capacity); + auxiliary = shared_ptr(list_buffer.release()); + } + auto type_size = GetTypeIdSize(internal_type); + if (type_size > 0) { + buffer = VectorBuffer::CreateStandardVector(type, capacity); + data = buffer->GetData(); + if (zero_data) { + memset(data, 0, capacity * type_size); + } + } + if (capacity > STANDARD_VECTOR_SIZE) { + validity.Resize(STANDARD_VECTOR_SIZE, capacity); + } +} + +struct DataArrays { + Vector &vec; + data_ptr_t data; + optional_ptr buffer; + idx_t type_size; + bool is_nested; + DataArrays(Vector &vec, data_ptr_t data, optional_ptr buffer, idx_t type_size, bool is_nested) + : vec(vec), data(data), buffer(buffer), type_size(type_size), is_nested(is_nested) { + } +}; + +void FindChildren(vector &to_resize, VectorBuffer &auxiliary) { + if (auxiliary.GetBufferType() == VectorBufferType::LIST_BUFFER) { + auto &buffer = auxiliary.Cast(); + auto &child = buffer.GetChild(); + auto data = child.GetData(); + if (!data) { + //! Nested type + DataArrays arrays(child, data, child.GetBuffer().get(), GetTypeIdSize(child.GetType().InternalType()), + true); + to_resize.emplace_back(arrays); + FindChildren(to_resize, *child.GetAuxiliary()); + } else { + DataArrays arrays(child, data, child.GetBuffer().get(), GetTypeIdSize(child.GetType().InternalType()), + false); + to_resize.emplace_back(arrays); + } + } else if (auxiliary.GetBufferType() == VectorBufferType::STRUCT_BUFFER) { + auto &buffer = auxiliary.Cast(); + auto &children = buffer.GetChildren(); + for (auto &child : children) { + auto data = child->GetData(); + if (!data) { + //! Nested type + DataArrays arrays(*child, data, child->GetBuffer().get(), + GetTypeIdSize(child->GetType().InternalType()), true); + to_resize.emplace_back(arrays); + FindChildren(to_resize, *child->GetAuxiliary()); + } else { + DataArrays arrays(*child, data, child->GetBuffer().get(), + GetTypeIdSize(child->GetType().InternalType()), false); + to_resize.emplace_back(arrays); + } + } + } +} +void Vector::Resize(idx_t cur_size, idx_t new_size) { + vector to_resize; + if (!buffer) { + buffer = make_buffer(0); + } + if (!data) { + //! this is a nested structure + DataArrays arrays(*this, data, buffer.get(), GetTypeIdSize(GetType().InternalType()), true); + to_resize.emplace_back(arrays); + FindChildren(to_resize, *auxiliary); + } else { + DataArrays arrays(*this, data, buffer.get(), GetTypeIdSize(GetType().InternalType()), false); + to_resize.emplace_back(arrays); + } + for (auto &data_to_resize : to_resize) { + if (!data_to_resize.is_nested) { + auto new_data = make_unsafe_uniq_array(new_size * data_to_resize.type_size); + memcpy(new_data.get(), data_to_resize.data, cur_size * data_to_resize.type_size * sizeof(data_t)); + data_to_resize.buffer->SetData(std::move(new_data)); + data_to_resize.vec.data = data_to_resize.buffer->GetData(); + } + data_to_resize.vec.validity.Resize(cur_size, new_size); + } +} + +void Vector::SetValue(idx_t index, const Value &val) { + if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { + // dictionary: apply dictionary and forward to child + auto &sel_vector = DictionaryVector::SelVector(*this); + auto &child = DictionaryVector::Child(*this); + return child.SetValue(sel_vector.get_index(index), val); + } + if (val.type() != GetType()) { + SetValue(index, val.DefaultCastAs(GetType())); + return; + } + D_ASSERT(val.type().InternalType() == GetType().InternalType()); + + validity.EnsureWritable(); + validity.Set(index, !val.IsNull()); + if (val.IsNull() && GetType().InternalType() != PhysicalType::STRUCT) { + // for structs we still need to set the child-entries to NULL + // so we do not bail out yet + return; + } + + switch (GetType().InternalType()) { + case PhysicalType::BOOL: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::INT8: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::INT16: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::INT32: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::INT64: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::INT128: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::UINT8: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::UINT16: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::UINT32: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::UINT64: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::FLOAT: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::DOUBLE: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::INTERVAL: + reinterpret_cast(data)[index] = val.GetValueUnsafe(); + break; + case PhysicalType::VARCHAR: + reinterpret_cast(data)[index] = StringVector::AddStringOrBlob(*this, StringValue::Get(val)); + break; + case PhysicalType::STRUCT: { + D_ASSERT(GetVectorType() == VectorType::CONSTANT_VECTOR || GetVectorType() == VectorType::FLAT_VECTOR); + + auto &children = StructVector::GetEntries(*this); + if (val.IsNull()) { + for (size_t i = 0; i < children.size(); i++) { + auto &vec_child = children[i]; + vec_child->SetValue(index, Value()); + } + } else { + auto &val_children = StructValue::GetChildren(val); + D_ASSERT(children.size() == val_children.size()); + for (size_t i = 0; i < children.size(); i++) { + auto &vec_child = children[i]; + auto &struct_child = val_children[i]; + vec_child->SetValue(index, struct_child); + } + } + break; + } + case PhysicalType::LIST: { + auto offset = ListVector::GetListSize(*this); + auto &val_children = ListValue::GetChildren(val); + if (!val_children.empty()) { + for (idx_t i = 0; i < val_children.size(); i++) { + ListVector::PushBack(*this, val_children[i]); + } + } + //! now set the pointer + auto &entry = reinterpret_cast(data)[index]; + entry.length = val_children.size(); + entry.offset = offset; + break; + } + default: + throw InternalException("Unimplemented type for Vector::SetValue"); + } +} + +Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { + const Vector *vector = &v_p; + idx_t index = index_p; + bool finished = false; + while (!finished) { + switch (vector->GetVectorType()) { + case VectorType::CONSTANT_VECTOR: + index = 0; + finished = true; + break; + case VectorType::FLAT_VECTOR: + finished = true; + break; + case VectorType::FSST_VECTOR: + finished = true; + break; + // dictionary: apply dictionary and forward to child + case VectorType::DICTIONARY_VECTOR: { + auto &sel_vector = DictionaryVector::SelVector(*vector); + auto &child = DictionaryVector::Child(*vector); + vector = &child; + index = sel_vector.get_index(index); + break; + } + case VectorType::SEQUENCE_VECTOR: { + int64_t start, increment; + SequenceVector::GetSequence(*vector, start, increment); + return Value::Numeric(vector->GetType(), start + increment * index); + } + default: + throw InternalException("Unimplemented vector type for Vector::GetValue"); + } + } + auto data = vector->data; + auto &validity = vector->validity; + auto &type = vector->GetType(); + + if (!validity.RowIsValid(index)) { + return Value(vector->GetType()); + } + + if (vector->GetVectorType() == VectorType::FSST_VECTOR) { + if (vector->GetType().InternalType() != PhysicalType::VARCHAR) { + throw InternalException("FSST Vector with non-string datatype found!"); + } + auto str_compressed = reinterpret_cast(data)[index]; + Value result = FSSTPrimitives::DecompressValue(FSSTVector::GetDecoder(const_cast(*vector)), + str_compressed.GetData(), str_compressed.GetSize()); + return result; + } + + switch (vector->GetType().id()) { + case LogicalTypeId::BOOLEAN: + return Value::BOOLEAN(reinterpret_cast(data)[index]); + case LogicalTypeId::TINYINT: + return Value::TINYINT(reinterpret_cast(data)[index]); + case LogicalTypeId::SMALLINT: + return Value::SMALLINT(reinterpret_cast(data)[index]); + case LogicalTypeId::INTEGER: + return Value::INTEGER(reinterpret_cast(data)[index]); + case LogicalTypeId::DATE: + return Value::DATE(reinterpret_cast(data)[index]); + case LogicalTypeId::TIME: + return Value::TIME(reinterpret_cast(data)[index]); + case LogicalTypeId::TIME_TZ: + return Value::TIMETZ(reinterpret_cast(data)[index]); + case LogicalTypeId::BIGINT: + return Value::BIGINT(reinterpret_cast(data)[index]); + case LogicalTypeId::UTINYINT: + return Value::UTINYINT(reinterpret_cast(data)[index]); + case LogicalTypeId::USMALLINT: + return Value::USMALLINT(reinterpret_cast(data)[index]); + case LogicalTypeId::UINTEGER: + return Value::UINTEGER(reinterpret_cast(data)[index]); + case LogicalTypeId::UBIGINT: + return Value::UBIGINT(reinterpret_cast(data)[index]); + case LogicalTypeId::TIMESTAMP: + return Value::TIMESTAMP(reinterpret_cast(data)[index]); + case LogicalTypeId::TIMESTAMP_NS: + return Value::TIMESTAMPNS(reinterpret_cast(data)[index]); + case LogicalTypeId::TIMESTAMP_MS: + return Value::TIMESTAMPMS(reinterpret_cast(data)[index]); + case LogicalTypeId::TIMESTAMP_SEC: + return Value::TIMESTAMPSEC(reinterpret_cast(data)[index]); + case LogicalTypeId::TIMESTAMP_TZ: + return Value::TIMESTAMPTZ(reinterpret_cast(data)[index]); + case LogicalTypeId::HUGEINT: + return Value::HUGEINT(reinterpret_cast(data)[index]); + case LogicalTypeId::UUID: + return Value::UUID(reinterpret_cast(data)[index]); + case LogicalTypeId::DECIMAL: { + auto width = DecimalType::GetWidth(type); + auto scale = DecimalType::GetScale(type); + switch (type.InternalType()) { + case PhysicalType::INT16: + return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); + case PhysicalType::INT32: + return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); + case PhysicalType::INT64: + return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); + case PhysicalType::INT128: + return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); + default: + throw InternalException("Physical type '%s' has a width bigger than 38, which is not supported", + TypeIdToString(type.InternalType())); + } + } + case LogicalTypeId::ENUM: { + switch (type.InternalType()) { + case PhysicalType::UINT8: + return Value::ENUM(reinterpret_cast(data)[index], type); + case PhysicalType::UINT16: + return Value::ENUM(reinterpret_cast(data)[index], type); + case PhysicalType::UINT32: + return Value::ENUM(reinterpret_cast(data)[index], type); + default: + throw InternalException("ENUM can only have unsigned integers as physical types"); + } + } + case LogicalTypeId::POINTER: + return Value::POINTER(reinterpret_cast(data)[index]); + case LogicalTypeId::FLOAT: + return Value::FLOAT(reinterpret_cast(data)[index]); + case LogicalTypeId::DOUBLE: + return Value::DOUBLE(reinterpret_cast(data)[index]); + case LogicalTypeId::INTERVAL: + return Value::INTERVAL(reinterpret_cast(data)[index]); + case LogicalTypeId::VARCHAR: { + auto str = reinterpret_cast(data)[index]; + return Value(str.GetString()); + } + case LogicalTypeId::AGGREGATE_STATE: + case LogicalTypeId::BLOB: { + auto str = reinterpret_cast(data)[index]; + return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + case LogicalTypeId::BIT: { + auto str = reinterpret_cast(data)[index]; + return Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + case LogicalTypeId::MAP: { + auto offlen = reinterpret_cast(data)[index]; + auto &child_vec = ListVector::GetEntry(*vector); + duckdb::vector children; + for (idx_t i = offlen.offset; i < offlen.offset + offlen.length; i++) { + children.push_back(child_vec.GetValue(i)); + } + return Value::MAP(ListType::GetChildType(type), std::move(children)); + } + case LogicalTypeId::UNION: { + auto tag = UnionVector::GetTag(*vector, index); + auto value = UnionVector::GetMember(*vector, tag).GetValue(index); + auto members = UnionType::CopyMemberTypes(type); + return Value::UNION(members, tag, std::move(value)); + } + case LogicalTypeId::STRUCT: { + // we can derive the value schema from the vector schema + auto &child_entries = StructVector::GetEntries(*vector); + child_list_t children; + for (idx_t child_idx = 0; child_idx < child_entries.size(); child_idx++) { + auto &struct_child = child_entries[child_idx]; + children.push_back(make_pair(StructType::GetChildName(type, child_idx), struct_child->GetValue(index_p))); + } + return Value::STRUCT(std::move(children)); + } + case LogicalTypeId::LIST: { + auto offlen = reinterpret_cast(data)[index]; + auto &child_vec = ListVector::GetEntry(*vector); + duckdb::vector children; + for (idx_t i = offlen.offset; i < offlen.offset + offlen.length; i++) { + children.push_back(child_vec.GetValue(i)); + } + return Value::LIST(ListType::GetChildType(type), std::move(children)); + } + default: + throw InternalException("Unimplemented type for value access"); + } +} + +Value Vector::GetValue(const Vector &v_p, idx_t index_p) { + auto value = GetValueInternal(v_p, index_p); + // set the alias of the type to the correct value, if there is a type alias + if (v_p.GetType().HasAlias()) { + value.GetTypeMutable().CopyAuxInfo(v_p.GetType()); + } + if (v_p.GetType().id() != LogicalTypeId::AGGREGATE_STATE && value.type().id() != LogicalTypeId::AGGREGATE_STATE) { + + D_ASSERT(v_p.GetType() == value.type()); + } + return value; +} + +Value Vector::GetValue(idx_t index) const { + return GetValue(*this, index); +} + +// LCOV_EXCL_START +string VectorTypeToString(VectorType type) { + switch (type) { + case VectorType::FLAT_VECTOR: + return "FLAT"; + case VectorType::FSST_VECTOR: + return "FSST"; + case VectorType::SEQUENCE_VECTOR: + return "SEQUENCE"; + case VectorType::DICTIONARY_VECTOR: + return "DICTIONARY"; + case VectorType::CONSTANT_VECTOR: + return "CONSTANT"; + default: + return "UNKNOWN"; + } +} + +string Vector::ToString(idx_t count) const { + string retval = + VectorTypeToString(GetVectorType()) + " " + GetType().ToString() + ": " + to_string(count) + " = [ "; + switch (GetVectorType()) { + case VectorType::FLAT_VECTOR: + case VectorType::DICTIONARY_VECTOR: + for (idx_t i = 0; i < count; i++) { + retval += GetValue(i).ToString() + (i == count - 1 ? "" : ", "); + } + break; + case VectorType::FSST_VECTOR: { + for (idx_t i = 0; i < count; i++) { + string_t compressed_string = reinterpret_cast(data)[i]; + Value val = FSSTPrimitives::DecompressValue(FSSTVector::GetDecoder(const_cast(*this)), + compressed_string.GetData(), compressed_string.GetSize()); + retval += GetValue(i).ToString() + (i == count - 1 ? "" : ", "); + } + } break; + case VectorType::CONSTANT_VECTOR: + retval += GetValue(0).ToString(); + break; + case VectorType::SEQUENCE_VECTOR: { + int64_t start, increment; + SequenceVector::GetSequence(*this, start, increment); + for (idx_t i = 0; i < count; i++) { + retval += to_string(start + increment * i) + (i == count - 1 ? "" : ", "); + } + break; + } + default: + retval += "UNKNOWN VECTOR TYPE"; + break; + } + retval += "]"; + return retval; +} + +void Vector::Print(idx_t count) const { + Printer::Print(ToString(count)); +} + +string Vector::ToString() const { + string retval = VectorTypeToString(GetVectorType()) + " " + GetType().ToString() + ": (UNKNOWN COUNT) [ "; + switch (GetVectorType()) { + case VectorType::FLAT_VECTOR: + case VectorType::DICTIONARY_VECTOR: + break; + case VectorType::CONSTANT_VECTOR: + retval += GetValue(0).ToString(); + break; + case VectorType::SEQUENCE_VECTOR: { + break; + } + default: + retval += "UNKNOWN VECTOR TYPE"; + break; + } + retval += "]"; + return retval; +} + +void Vector::Print() const { + Printer::Print(ToString()); +} +// LCOV_EXCL_STOP + +template +static void TemplatedFlattenConstantVector(data_ptr_t data, data_ptr_t old_data, idx_t count) { + auto constant = Load(old_data); + auto output = (T *)data; + for (idx_t i = 0; i < count; i++) { + output[i] = constant; + } +} + +void Vector::Flatten(idx_t count) { + switch (GetVectorType()) { + case VectorType::FLAT_VECTOR: + // already a flat vector + break; + case VectorType::FSST_VECTOR: { + // Even though count may only be a part of the vector, we need to flatten the whole thing due to the way + // ToUnifiedFormat uses flatten + idx_t total_count = FSSTVector::GetCount(*this); + // create vector to decompress into + Vector other(GetType(), total_count); + // now copy the data of this vector to the other vector, decompressing the strings in the process + VectorOperations::Copy(*this, other, total_count, 0, 0); + // create a reference to the data in the other vector + this->Reference(other); + break; + } + case VectorType::DICTIONARY_VECTOR: { + // create a new flat vector of this type + Vector other(GetType(), count); + // now copy the data of this vector to the other vector, removing the selection vector in the process + VectorOperations::Copy(*this, other, count, 0, 0); + // create a reference to the data in the other vector + this->Reference(other); + break; + } + case VectorType::CONSTANT_VECTOR: { + bool is_null = ConstantVector::IsNull(*this); + // allocate a new buffer for the vector + auto old_buffer = std::move(buffer); + auto old_data = data; + buffer = VectorBuffer::CreateStandardVector(type, MaxValue(STANDARD_VECTOR_SIZE, count)); + if (old_buffer) { + D_ASSERT(buffer->GetAuxiliaryData() == nullptr); + // The old buffer might be relying on the auxiliary data, keep it alive + buffer->MoveAuxiliaryData(*old_buffer); + } + data = buffer->GetData(); + vector_type = VectorType::FLAT_VECTOR; + if (is_null) { + // constant NULL, set nullmask + validity.EnsureWritable(); + validity.SetAllInvalid(count); + return; + } + // non-null constant: have to repeat the constant + switch (GetType().InternalType()) { + case PhysicalType::BOOL: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::INT8: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::INT16: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::INT32: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::INT64: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::UINT8: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::UINT16: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::UINT32: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::UINT64: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::INT128: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::FLOAT: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::DOUBLE: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::INTERVAL: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::VARCHAR: + TemplatedFlattenConstantVector(data, old_data, count); + break; + case PhysicalType::LIST: { + TemplatedFlattenConstantVector(data, old_data, count); + break; + } + case PhysicalType::STRUCT: { + auto normalified_buffer = make_uniq(); + + auto &new_children = normalified_buffer->GetChildren(); + + auto &child_entries = StructVector::GetEntries(*this); + for (auto &child : child_entries) { + D_ASSERT(child->GetVectorType() == VectorType::CONSTANT_VECTOR); + auto vector = make_uniq(*child); + vector->Flatten(count); + new_children.push_back(std::move(vector)); + } + auxiliary = shared_ptr(normalified_buffer.release()); + } break; + default: + throw InternalException("Unimplemented type for VectorOperations::Flatten"); + } + break; + } + case VectorType::SEQUENCE_VECTOR: { + int64_t start, increment, sequence_count; + SequenceVector::GetSequence(*this, start, increment, sequence_count); + + buffer = VectorBuffer::CreateStandardVector(GetType()); + data = buffer->GetData(); + VectorOperations::GenerateSequence(*this, sequence_count, start, increment); + break; + } + default: + throw InternalException("Unimplemented type for normalify"); + } +} + +void Vector::Flatten(const SelectionVector &sel, idx_t count) { + switch (GetVectorType()) { + case VectorType::FLAT_VECTOR: + // already a flat vector + break; + case VectorType::FSST_VECTOR: { + // create a new flat vector of this type + Vector other(GetType()); + // copy the data of this vector to the other vector, removing compression and selection vector in the process + VectorOperations::Copy(*this, other, sel, count, 0, 0); + // create a reference to the data in the other vector + this->Reference(other); + break; + } + case VectorType::SEQUENCE_VECTOR: { + int64_t start, increment; + SequenceVector::GetSequence(*this, start, increment); + + buffer = VectorBuffer::CreateStandardVector(GetType()); + data = buffer->GetData(); + VectorOperations::GenerateSequence(*this, count, sel, start, increment); + break; + } + default: + throw InternalException("Unimplemented type for normalify with selection vector"); + } +} + +void Vector::ToUnifiedFormat(idx_t count, UnifiedVectorFormat &format) { + switch (GetVectorType()) { + case VectorType::DICTIONARY_VECTOR: { + auto &sel = DictionaryVector::SelVector(*this); + format.owned_sel.Initialize(sel); + format.sel = &format.owned_sel; + + auto &child = DictionaryVector::Child(*this); + if (child.GetVectorType() == VectorType::FLAT_VECTOR) { + format.data = FlatVector::GetData(child); + format.validity = FlatVector::Validity(child); + } else { + // dictionary with non-flat child: create a new reference to the child and flatten it + Vector child_vector(child); + child_vector.Flatten(sel, count); + auto new_aux = make_buffer(std::move(child_vector)); + + format.data = FlatVector::GetData(new_aux->data); + format.validity = FlatVector::Validity(new_aux->data); + this->auxiliary = std::move(new_aux); + } + break; + } + case VectorType::CONSTANT_VECTOR: + format.sel = ConstantVector::ZeroSelectionVector(count, format.owned_sel); + format.data = ConstantVector::GetData(*this); + format.validity = ConstantVector::Validity(*this); + break; + default: + Flatten(count); + format.sel = FlatVector::IncrementalSelectionVector(); + format.data = FlatVector::GetData(*this); + format.validity = FlatVector::Validity(*this); + break; + } +} + +void Vector::RecursiveToUnifiedFormat(Vector &input, idx_t count, RecursiveUnifiedVectorFormat &data) { + + input.ToUnifiedFormat(count, data.unified); + + if (input.GetType().InternalType() == PhysicalType::LIST) { + auto &child = ListVector::GetEntry(input); + auto child_count = ListVector::GetListSize(input); + data.children.emplace_back(); + Vector::RecursiveToUnifiedFormat(child, child_count, data.children.back()); + + } else if (input.GetType().InternalType() == PhysicalType::STRUCT) { + auto &children = StructVector::GetEntries(input); + for (idx_t i = 0; i < children.size(); i++) { + data.children.emplace_back(); + } + for (idx_t i = 0; i < children.size(); i++) { + Vector::RecursiveToUnifiedFormat(*children[i], count, data.children[i]); + } + } +} + +void Vector::Sequence(int64_t start, int64_t increment, idx_t count) { + this->vector_type = VectorType::SEQUENCE_VECTOR; + this->buffer = make_buffer(sizeof(int64_t) * 3); + auto data = reinterpret_cast(buffer->GetData()); + data[0] = start; + data[1] = increment; + data[2] = int64_t(count); + validity.Reset(); + auxiliary.reset(); +} + +void Vector::Serialize(Serializer &serializer, idx_t count) { + auto &logical_type = GetType(); + + UnifiedVectorFormat vdata; + ToUnifiedFormat(count, vdata); + + const bool all_valid = (count > 0) && !vdata.validity.AllValid(); + serializer.WriteProperty(100, "all_valid", all_valid); + if (all_valid) { + ValidityMask flat_mask(count); + for (idx_t i = 0; i < count; ++i) { + auto row_idx = vdata.sel->get_index(i); + flat_mask.Set(i, vdata.validity.RowIsValid(row_idx)); + } + serializer.WriteProperty(101, "validity", const_data_ptr_cast(flat_mask.GetData()), + flat_mask.ValidityMaskSize(count)); + } + if (TypeIsConstantSize(logical_type.InternalType())) { + // constant size type: simple copy + idx_t write_size = GetTypeIdSize(logical_type.InternalType()) * count; + auto ptr = make_unsafe_uniq_array(write_size); + VectorOperations::WriteToStorage(*this, count, ptr.get()); + serializer.WriteProperty(102, "data", ptr.get(), write_size); + } else { + switch (logical_type.InternalType()) { + case PhysicalType::VARCHAR: { + auto strings = UnifiedVectorFormat::GetData(vdata); + + // Serialize data as a list + serializer.WriteList(102, "data", count, [&](Serializer::List &list, idx_t i) { + auto idx = vdata.sel->get_index(i); + auto str = !vdata.validity.RowIsValid(idx) ? NullValue() : strings[idx]; + list.WriteElement(str); + }); + break; + } + case PhysicalType::STRUCT: { + auto &entries = StructVector::GetEntries(*this); + + // Serialize entries as a list + serializer.WriteList(103, "children", entries.size(), [&](Serializer::List &list, idx_t i) { + list.WriteObject([&](Serializer &object) { entries[i]->Serialize(object, count); }); + }); + break; + } + case PhysicalType::LIST: { + auto &child = ListVector::GetEntry(*this); + auto list_size = ListVector::GetListSize(*this); + + // serialize the list entries in a flat array + auto entries = make_unsafe_uniq_array(count); + auto source_array = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + auto source = source_array[idx]; + entries[i].offset = source.offset; + entries[i].length = source.length; + } + serializer.WriteProperty(104, "list_size", list_size); + serializer.WriteList(105, "entries", count, [&](Serializer::List &list, idx_t i) { + list.WriteObject([&](Serializer &object) { + object.WriteProperty(100, "offset", entries[i].offset); + object.WriteProperty(101, "length", entries[i].length); + }); + }); + serializer.WriteObject(106, "child", [&](Serializer &object) { child.Serialize(object, list_size); }); + break; + } + default: + throw InternalException("Unimplemented variable width type for Vector::Serialize!"); + } + } +} + +void Vector::Deserialize(Deserializer &deserializer, idx_t count) { + auto &logical_type = GetType(); + + auto &validity = FlatVector::Validity(*this); + validity.Reset(); + const auto has_validity = deserializer.ReadProperty(100, "all_valid"); + if (has_validity) { + validity.Initialize(count); + deserializer.ReadProperty(101, "validity", data_ptr_cast(validity.GetData()), validity.ValidityMaskSize(count)); + } + + if (TypeIsConstantSize(logical_type.InternalType())) { + // constant size type: read fixed amount of data + auto column_size = GetTypeIdSize(logical_type.InternalType()) * count; + auto ptr = make_unsafe_uniq_array(column_size); + deserializer.ReadProperty(102, "data", ptr.get(), column_size); + + VectorOperations::ReadFromStorage(ptr.get(), count, *this); + } else { + switch (logical_type.InternalType()) { + case PhysicalType::VARCHAR: { + auto strings = FlatVector::GetData(*this); + deserializer.ReadList(102, "data", [&](Deserializer::List &list, idx_t i) { + auto str = list.ReadElement(); + if (validity.RowIsValid(i)) { + strings[i] = StringVector::AddStringOrBlob(*this, str); + } + }); + break; + } + case PhysicalType::STRUCT: { + auto &entries = StructVector::GetEntries(*this); + // Deserialize entries as a list + deserializer.ReadList(103, "children", [&](Deserializer::List &list, idx_t i) { + list.ReadObject([&](Deserializer &obj) { entries[i]->Deserialize(obj, count); }); + }); + break; + } + case PhysicalType::LIST: { + // Read the list size + auto list_size = deserializer.ReadProperty(104, "list_size"); + ListVector::Reserve(*this, list_size); + ListVector::SetListSize(*this, list_size); + + // Read the entries + auto list_entries = FlatVector::GetData(*this); + deserializer.ReadList(105, "entries", [&](Deserializer::List &list, idx_t i) { + list.ReadObject([&](Deserializer &obj) { + list_entries[i].offset = obj.ReadProperty(100, "offset"); + list_entries[i].length = obj.ReadProperty(101, "length"); + }); + }); + + // Read the child vector + deserializer.ReadObject(106, "child", [&](Deserializer &obj) { + auto &child = ListVector::GetEntry(*this); + child.Deserialize(obj, list_size); + }); + break; + } + default: + throw InternalException("Unimplemented variable width type for Vector::Deserialize!"); + } + } +} + +void Vector::SetVectorType(VectorType vector_type_p) { + this->vector_type = vector_type_p; + if (TypeIsConstantSize(GetType().InternalType()) && + (GetVectorType() == VectorType::CONSTANT_VECTOR || GetVectorType() == VectorType::FLAT_VECTOR)) { + auxiliary.reset(); + } + if (vector_type == VectorType::CONSTANT_VECTOR && GetType().InternalType() == PhysicalType::STRUCT) { + auto &entries = StructVector::GetEntries(*this); + for (auto &entry : entries) { + entry->SetVectorType(vector_type); + } + } +} + +void Vector::UTFVerify(const SelectionVector &sel, idx_t count) { +#ifdef DEBUG + if (count == 0) { + return; + } + if (GetType().InternalType() == PhysicalType::VARCHAR) { + // we just touch all the strings and let the sanitizer figure out if any + // of them are deallocated/corrupt + switch (GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + auto string = ConstantVector::GetData(*this); + if (!ConstantVector::IsNull(*this)) { + string->Verify(); + } + break; + } + case VectorType::FLAT_VECTOR: { + auto strings = FlatVector::GetData(*this); + for (idx_t i = 0; i < count; i++) { + auto oidx = sel.get_index(i); + if (validity.RowIsValid(oidx)) { + strings[oidx].Verify(); + } + } + break; + } + default: + break; + } + } +#endif +} + +void Vector::UTFVerify(idx_t count) { + auto flat_sel = FlatVector::IncrementalSelectionVector(); + + UTFVerify(*flat_sel, count); +} + +void Vector::VerifyMap(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { +#ifdef DEBUG + D_ASSERT(vector_p.GetType().id() == LogicalTypeId::MAP); + auto &child = ListType::GetChildType(vector_p.GetType()); + D_ASSERT(StructType::GetChildCount(child) == 2); + D_ASSERT(StructType::GetChildName(child, 0) == "key"); + D_ASSERT(StructType::GetChildName(child, 1) == "value"); + + auto valid_check = MapVector::CheckMapValidity(vector_p, count, sel_p); + D_ASSERT(valid_check == MapInvalidReason::VALID); +#endif // DEBUG +} + +void Vector::VerifyUnion(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { +#ifdef DEBUG + D_ASSERT(vector_p.GetType().id() == LogicalTypeId::UNION); + auto valid_check = UnionVector::CheckUnionValidity(vector_p, count, sel_p); + D_ASSERT(valid_check == UnionInvalidReason::VALID); +#endif // DEBUG +} + +void Vector::Verify(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { +#ifdef DEBUG + if (count == 0) { + return; + } + Vector *vector = &vector_p; + const SelectionVector *sel = &sel_p; + SelectionVector owned_sel; + auto &type = vector->GetType(); + auto vtype = vector->GetVectorType(); + if (vector->GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(*vector); + D_ASSERT(child.GetVectorType() != VectorType::DICTIONARY_VECTOR); + auto &dict_sel = DictionaryVector::SelVector(*vector); + // merge the selection vectors and verify the child + auto new_buffer = dict_sel.Slice(*sel, count); + owned_sel.Initialize(new_buffer); + sel = &owned_sel; + vector = &child; + vtype = vector->GetVectorType(); + } + if (TypeIsConstantSize(type.InternalType()) && + (vtype == VectorType::CONSTANT_VECTOR || vtype == VectorType::FLAT_VECTOR)) { + D_ASSERT(!vector->auxiliary); + } + if (type.id() == LogicalTypeId::VARCHAR) { + // verify that the string is correct unicode + switch (vtype) { + case VectorType::FLAT_VECTOR: { + auto &validity = FlatVector::Validity(*vector); + auto strings = FlatVector::GetData(*vector); + for (idx_t i = 0; i < count; i++) { + auto oidx = sel->get_index(i); + if (validity.RowIsValid(oidx)) { + strings[oidx].Verify(); + } + } + break; + } + default: + break; + } + } + + if (type.id() == LogicalTypeId::BIT) { + switch (vtype) { + case VectorType::FLAT_VECTOR: { + auto &validity = FlatVector::Validity(*vector); + auto strings = FlatVector::GetData(*vector); + for (idx_t i = 0; i < count; i++) { + auto oidx = sel->get_index(i); + if (validity.RowIsValid(oidx)) { + auto buf = strings[oidx].GetData(); + D_ASSERT(*buf >= 0 && *buf < 8); + Bit::Verify(strings[oidx]); + } + } + break; + } + default: + break; + } + } + + if (type.InternalType() == PhysicalType::STRUCT) { + auto &child_types = StructType::GetChildTypes(type); + D_ASSERT(!child_types.empty()); + // create a selection vector of the non-null entries of the struct vector + auto &children = StructVector::GetEntries(*vector); + D_ASSERT(child_types.size() == children.size()); + for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { + D_ASSERT(children[child_idx]->GetType() == child_types[child_idx].second); + Vector::Verify(*children[child_idx], sel_p, count); + if (vtype == VectorType::CONSTANT_VECTOR) { + D_ASSERT(children[child_idx]->GetVectorType() == VectorType::CONSTANT_VECTOR); + if (ConstantVector::IsNull(*vector)) { + D_ASSERT(ConstantVector::IsNull(*children[child_idx])); + } + } + if (vtype != VectorType::FLAT_VECTOR) { + continue; + } + optional_ptr child_validity; + SelectionVector owned_child_sel; + const SelectionVector *child_sel = &owned_child_sel; + if (children[child_idx]->GetVectorType() == VectorType::FLAT_VECTOR) { + child_sel = FlatVector::IncrementalSelectionVector(); + child_validity = &FlatVector::Validity(*children[child_idx]); + } else if (children[child_idx]->GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(*children[child_idx]); + if (child.GetVectorType() != VectorType::FLAT_VECTOR) { + continue; + } + child_validity = &FlatVector::Validity(child); + child_sel = &DictionaryVector::SelVector(*children[child_idx]); + } else if (children[child_idx]->GetVectorType() == VectorType::CONSTANT_VECTOR) { + child_sel = ConstantVector::ZeroSelectionVector(count, owned_child_sel); + child_validity = &ConstantVector::Validity(*children[child_idx]); + } else { + continue; + } + // for any NULL entry in the struct, the child should be NULL as well + auto &validity = FlatVector::Validity(*vector); + for (idx_t i = 0; i < count; i++) { + auto index = sel->get_index(i); + if (!validity.RowIsValid(index)) { + auto child_index = child_sel->get_index(sel_p.get_index(i)); + D_ASSERT(!child_validity->RowIsValid(child_index)); + } + } + } + + if (vector->GetType().id() == LogicalTypeId::UNION) { + VerifyUnion(*vector, *sel, count); + } + } + + if (type.InternalType() == PhysicalType::LIST) { + if (vtype == VectorType::CONSTANT_VECTOR) { + if (!ConstantVector::IsNull(*vector)) { + auto &child = ListVector::GetEntry(*vector); + SelectionVector child_sel(ListVector::GetListSize(*vector)); + idx_t child_count = 0; + auto le = ConstantVector::GetData(*vector); + D_ASSERT(le->offset + le->length <= ListVector::GetListSize(*vector)); + for (idx_t k = 0; k < le->length; k++) { + child_sel.set_index(child_count++, le->offset + k); + } + Vector::Verify(child, child_sel, child_count); + } + } else if (vtype == VectorType::FLAT_VECTOR) { + auto &validity = FlatVector::Validity(*vector); + auto &child = ListVector::GetEntry(*vector); + auto child_size = ListVector::GetListSize(*vector); + auto list_data = FlatVector::GetData(*vector); + idx_t total_size = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = sel->get_index(i); + auto &le = list_data[idx]; + if (validity.RowIsValid(idx)) { + D_ASSERT(le.offset + le.length <= child_size); + total_size += le.length; + } + } + SelectionVector child_sel(total_size); + idx_t child_count = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = sel->get_index(i); + auto &le = list_data[idx]; + if (validity.RowIsValid(idx)) { + D_ASSERT(le.offset + le.length <= child_size); + for (idx_t k = 0; k < le.length; k++) { + child_sel.set_index(child_count++, le.offset + k); + } + } + } + Vector::Verify(child, child_sel, child_count); + } + + if (vector->GetType().id() == LogicalTypeId::MAP) { + VerifyMap(*vector, *sel, count); + } + } +#endif +} + +void Vector::Verify(idx_t count) { + auto flat_sel = FlatVector::IncrementalSelectionVector(); + Verify(*this, *flat_sel, count); +} + +//===--------------------------------------------------------------------===// +// FlatVector +//===--------------------------------------------------------------------===// +void FlatVector::SetNull(Vector &vector, idx_t idx, bool is_null) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + vector.validity.Set(idx, !is_null); + if (is_null && vector.GetType().InternalType() == PhysicalType::STRUCT) { + // set all child entries to null as well + auto &entries = StructVector::GetEntries(vector); + for (auto &entry : entries) { + FlatVector::SetNull(*entry, idx, is_null); + } + } +} + +//===--------------------------------------------------------------------===// +// ConstantVector +//===--------------------------------------------------------------------===// +void ConstantVector::SetNull(Vector &vector, bool is_null) { + D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + vector.validity.Set(0, !is_null); + if (is_null && vector.GetType().InternalType() == PhysicalType::STRUCT) { + // set all child entries to null as well + auto &entries = StructVector::GetEntries(vector); + for (auto &entry : entries) { + entry->SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(*entry, is_null); + } + } +} + +const SelectionVector *ConstantVector::ZeroSelectionVector(idx_t count, SelectionVector &owned_sel) { + if (count <= STANDARD_VECTOR_SIZE) { + return ConstantVector::ZeroSelectionVector(); + } + owned_sel.Initialize(count); + for (idx_t i = 0; i < count; i++) { + owned_sel.set_index(i, 0); + } + return &owned_sel; +} + +void ConstantVector::Reference(Vector &vector, Vector &source, idx_t position, idx_t count) { + auto &source_type = source.GetType(); + switch (source_type.InternalType()) { + case PhysicalType::LIST: { + // retrieve the list entry from the source vector + UnifiedVectorFormat vdata; + source.ToUnifiedFormat(count, vdata); + + auto list_index = vdata.sel->get_index(position); + if (!vdata.validity.RowIsValid(list_index)) { + // list is null: create null value + Value null_value(source_type); + vector.Reference(null_value); + break; + } + + auto list_data = UnifiedVectorFormat::GetData(vdata); + auto list_entry = list_data[list_index]; + + // add the list entry as the first element of "vector" + // FIXME: we only need to allocate space for 1 tuple here + auto target_data = FlatVector::GetData(vector); + target_data[0] = list_entry; + + // create a reference to the child list of the source vector + auto &child = ListVector::GetEntry(vector); + child.Reference(ListVector::GetEntry(source)); + + ListVector::SetListSize(vector, ListVector::GetListSize(source)); + vector.SetVectorType(VectorType::CONSTANT_VECTOR); + break; + } + case PhysicalType::STRUCT: { + UnifiedVectorFormat vdata; + source.ToUnifiedFormat(count, vdata); + + auto struct_index = vdata.sel->get_index(position); + if (!vdata.validity.RowIsValid(struct_index)) { + // null struct: create null value + Value null_value(source_type); + vector.Reference(null_value); + break; + } + + // struct: pass constant reference into child entries + auto &source_entries = StructVector::GetEntries(source); + auto &target_entries = StructVector::GetEntries(vector); + for (idx_t i = 0; i < source_entries.size(); i++) { + ConstantVector::Reference(*target_entries[i], *source_entries[i], position, count); + } + vector.SetVectorType(VectorType::CONSTANT_VECTOR); + vector.validity.Set(0, true); + break; + } + default: + // default behavior: get a value from the vector and reference it + // this is not that expensive for scalar types + auto value = source.GetValue(position); + vector.Reference(value); + D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + break; + } +} + +//===--------------------------------------------------------------------===// +// StringVector +//===--------------------------------------------------------------------===// +string_t StringVector::AddString(Vector &vector, const char *data, idx_t len) { + return StringVector::AddString(vector, string_t(data, len)); +} + +string_t StringVector::AddStringOrBlob(Vector &vector, const char *data, idx_t len) { + return StringVector::AddStringOrBlob(vector, string_t(data, len)); +} + +string_t StringVector::AddString(Vector &vector, const char *data) { + return StringVector::AddString(vector, string_t(data, strlen(data))); +} + +string_t StringVector::AddString(Vector &vector, const string &data) { + return StringVector::AddString(vector, string_t(data.c_str(), data.size())); +} + +string_t StringVector::AddString(Vector &vector, string_t data) { + D_ASSERT(vector.GetType().id() == LogicalTypeId::VARCHAR || vector.GetType().id() == LogicalTypeId::BIT); + if (data.IsInlined()) { + // string will be inlined: no need to store in string heap + return data; + } + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); + auto &string_buffer = vector.auxiliary->Cast(); + return string_buffer.AddString(data); +} + +string_t StringVector::AddStringOrBlob(Vector &vector, string_t data) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + if (data.IsInlined()) { + // string will be inlined: no need to store in string heap + return data; + } + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); + auto &string_buffer = vector.auxiliary->Cast(); + return string_buffer.AddBlob(data); +} + +string_t StringVector::EmptyString(Vector &vector, idx_t len) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + if (len <= string_t::INLINE_LENGTH) { + return string_t(len); + } + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); + auto &string_buffer = vector.auxiliary->Cast(); + return string_buffer.EmptyString(len); +} + +void StringVector::AddHandle(Vector &vector, BufferHandle handle) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + auto &string_buffer = vector.auxiliary->Cast(); + string_buffer.AddHeapReference(make_buffer(std::move(handle))); +} + +void StringVector::AddBuffer(Vector &vector, buffer_ptr buffer) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + D_ASSERT(buffer.get() != vector.auxiliary.get()); + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + auto &string_buffer = vector.auxiliary->Cast(); + string_buffer.AddHeapReference(std::move(buffer)); +} + +void StringVector::AddHeapReference(Vector &vector, Vector &other) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + D_ASSERT(other.GetType().InternalType() == PhysicalType::VARCHAR); + + if (other.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + StringVector::AddHeapReference(vector, DictionaryVector::Child(other)); + return; + } + if (!other.auxiliary) { + return; + } + StringVector::AddBuffer(vector, other.auxiliary); +} + +//===--------------------------------------------------------------------===// +// FSSTVector +//===--------------------------------------------------------------------===// +string_t FSSTVector::AddCompressedString(Vector &vector, const char *data, idx_t len) { + return FSSTVector::AddCompressedString(vector, string_t(data, len)); +} + +string_t FSSTVector::AddCompressedString(Vector &vector, string_t data) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + if (data.IsInlined()) { + // string will be inlined: no need to store in string heap + return data; + } + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); + auto &fsst_string_buffer = vector.auxiliary->Cast(); + return fsst_string_buffer.AddBlob(data); +} + +void *FSSTVector::GetDecoder(const Vector &vector) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + if (!vector.auxiliary) { + throw InternalException("GetDecoder called on FSST Vector without registered buffer"); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); + auto &fsst_string_buffer = vector.auxiliary->Cast(); + return fsst_string_buffer.GetDecoder(); +} + +void FSSTVector::RegisterDecoder(Vector &vector, buffer_ptr &duckdb_fsst_decoder) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); + + auto &fsst_string_buffer = vector.auxiliary->Cast(); + fsst_string_buffer.AddDecoder(duckdb_fsst_decoder); +} + +void FSSTVector::SetCount(Vector &vector, idx_t count) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); + + auto &fsst_string_buffer = vector.auxiliary->Cast(); + fsst_string_buffer.SetCount(count); +} + +idx_t FSSTVector::GetCount(Vector &vector) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + + if (!vector.auxiliary) { + vector.auxiliary = make_buffer(); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); + + auto &fsst_string_buffer = vector.auxiliary->Cast(); + return fsst_string_buffer.GetCount(); +} + +void FSSTVector::DecompressVector(const Vector &src, Vector &dst, idx_t src_offset, idx_t dst_offset, idx_t copy_count, + const SelectionVector *sel) { + D_ASSERT(src.GetVectorType() == VectorType::FSST_VECTOR); + D_ASSERT(dst.GetVectorType() == VectorType::FLAT_VECTOR); + auto dst_mask = FlatVector::Validity(dst); + auto ldata = FSSTVector::GetCompressedData(src); + auto tdata = FlatVector::GetData(dst); + for (idx_t i = 0; i < copy_count; i++) { + auto source_idx = sel->get_index(src_offset + i); + auto target_idx = dst_offset + i; + string_t compressed_string = ldata[source_idx]; + if (dst_mask.RowIsValid(target_idx) && compressed_string.GetSize() > 0) { + tdata[target_idx] = FSSTPrimitives::DecompressValue( + FSSTVector::GetDecoder(src), dst, compressed_string.GetData(), compressed_string.GetSize()); + } else { + tdata[target_idx] = string_t(nullptr, 0); + } + } +} + +//===--------------------------------------------------------------------===// +// MapVector +//===--------------------------------------------------------------------===// +Vector &MapVector::GetKeys(Vector &vector) { + auto &entries = StructVector::GetEntries(ListVector::GetEntry(vector)); + D_ASSERT(entries.size() == 2); + return *entries[0]; +} +Vector &MapVector::GetValues(Vector &vector) { + auto &entries = StructVector::GetEntries(ListVector::GetEntry(vector)); + D_ASSERT(entries.size() == 2); + return *entries[1]; +} + +const Vector &MapVector::GetKeys(const Vector &vector) { + return GetKeys((Vector &)vector); +} +const Vector &MapVector::GetValues(const Vector &vector) { + return GetValues((Vector &)vector); +} + +MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel) { + D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); + UnifiedVectorFormat map_vdata; + + map.ToUnifiedFormat(count, map_vdata); + auto &map_validity = map_vdata.validity; + + auto list_data = ListVector::GetData(map); + auto &keys = MapVector::GetKeys(map); + UnifiedVectorFormat key_vdata; + keys.ToUnifiedFormat(count, key_vdata); + auto &key_validity = key_vdata.validity; + + for (idx_t row = 0; row < count; row++) { + auto mapped_row = sel.get_index(row); + auto map_idx = map_vdata.sel->get_index(mapped_row); + // map is allowed to be NULL + if (!map_validity.RowIsValid(map_idx)) { + continue; + } + value_set_t unique_keys; + for (idx_t i = 0; i < list_data[map_idx].length; i++) { + auto index = list_data[map_idx].offset + i; + index = key_vdata.sel->get_index(index); + if (!key_validity.RowIsValid(index)) { + return MapInvalidReason::NULL_KEY; + } + auto value = keys.GetValue(index); + auto result = unique_keys.insert(value); + if (!result.second) { + return MapInvalidReason::DUPLICATE_KEY; + } + } + } + return MapInvalidReason::VALID; +} + +void MapVector::MapConversionVerify(Vector &vector, idx_t count) { + auto valid_check = MapVector::CheckMapValidity(vector, count); + switch (valid_check) { + case MapInvalidReason::VALID: + break; + case MapInvalidReason::DUPLICATE_KEY: { + throw InvalidInputException("Map keys have to be unique"); + } + case MapInvalidReason::NULL_KEY: { + throw InvalidInputException("Map keys can not be NULL"); + } + case MapInvalidReason::NULL_KEY_LIST: { + throw InvalidInputException("The list of map keys is not allowed to be NULL"); + } + default: { + throw InternalException("MapInvalidReason not implemented"); + } + } +} + +//===--------------------------------------------------------------------===// +// StructVector +//===--------------------------------------------------------------------===// +vector> &StructVector::GetEntries(Vector &vector) { + D_ASSERT(vector.GetType().id() == LogicalTypeId::STRUCT || vector.GetType().id() == LogicalTypeId::UNION); + + if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(vector); + return StructVector::GetEntries(child); + } + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || + vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + D_ASSERT(vector.auxiliary); + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRUCT_BUFFER); + return vector.auxiliary->Cast().GetChildren(); +} + +const vector> &StructVector::GetEntries(const Vector &vector) { + return GetEntries((Vector &)vector); +} + +//===--------------------------------------------------------------------===// +// ListVector +//===--------------------------------------------------------------------===// +const Vector &ListVector::GetEntry(const Vector &vector) { + D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST || vector.GetType().id() == LogicalTypeId::MAP); + if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(vector); + return ListVector::GetEntry(child); + } + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || + vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + D_ASSERT(vector.auxiliary); + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::LIST_BUFFER); + return vector.auxiliary->Cast().GetChild(); +} + +Vector &ListVector::GetEntry(Vector &vector) { + const Vector &cvector = vector; + return const_cast(ListVector::GetEntry(cvector)); +} + +void ListVector::Reserve(Vector &vector, idx_t required_capacity) { + D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST || vector.GetType().id() == LogicalTypeId::MAP); + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || + vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + D_ASSERT(vector.auxiliary); + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::LIST_BUFFER); + auto &child_buffer = vector.auxiliary->Cast(); + child_buffer.Reserve(required_capacity); +} + +idx_t ListVector::GetListSize(const Vector &vec) { + if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(vec); + return ListVector::GetListSize(child); + } + D_ASSERT(vec.auxiliary); + return vec.auxiliary->Cast().GetSize(); +} + +idx_t ListVector::GetListCapacity(const Vector &vec) { + if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(vec); + return ListVector::GetListSize(child); + } + D_ASSERT(vec.auxiliary); + return vec.auxiliary->Cast().GetCapacity(); +} + +void ListVector::ReferenceEntry(Vector &vector, Vector &other) { + D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST); + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || + vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + D_ASSERT(other.GetType().id() == LogicalTypeId::LIST); + D_ASSERT(other.GetVectorType() == VectorType::FLAT_VECTOR || other.GetVectorType() == VectorType::CONSTANT_VECTOR); + vector.auxiliary = other.auxiliary; +} + +void ListVector::SetListSize(Vector &vec, idx_t size) { + if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(vec); + ListVector::SetListSize(child, size); + } + vec.auxiliary->Cast().SetSize(size); +} + +void ListVector::Append(Vector &target, const Vector &source, idx_t source_size, idx_t source_offset) { + if (source_size - source_offset == 0) { + //! Nothing to add + return; + } + auto &target_buffer = target.auxiliary->Cast(); + target_buffer.Append(source, source_size, source_offset); +} + +void ListVector::Append(Vector &target, const Vector &source, const SelectionVector &sel, idx_t source_size, + idx_t source_offset) { + if (source_size - source_offset == 0) { + //! Nothing to add + return; + } + auto &target_buffer = target.auxiliary->Cast(); + target_buffer.Append(source, sel, source_size, source_offset); +} + +void ListVector::PushBack(Vector &target, const Value &insert) { + auto &target_buffer = target.auxiliary->Cast(); + target_buffer.PushBack(insert); +} + +idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { + + auto info = ListVector::GetConsecutiveChildListInfo(list, offset, count); + if (info.needs_slicing) { + SelectionVector sel(info.child_list_info.length); + ListVector::GetConsecutiveChildSelVector(list, sel, offset, count); + + result.Slice(sel, info.child_list_info.length); + result.Flatten(info.child_list_info.length); + } + return info.child_list_info.length; +} + +ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count) { + + ConsecutiveChildListInfo info; + UnifiedVectorFormat unified_list_data; + list.ToUnifiedFormat(offset + count, unified_list_data); + auto list_data = UnifiedVectorFormat::GetData(unified_list_data); + + // find the first non-NULL entry + idx_t first_length = 0; + for (idx_t i = offset; i < offset + count; i++) { + auto idx = unified_list_data.sel->get_index(i); + if (!unified_list_data.validity.RowIsValid(idx)) { + continue; + } + info.child_list_info.offset = list_data[idx].offset; + first_length = list_data[idx].length; + break; + } + + // small performance improvement for constant vectors + // avoids iterating over all their (constant) elements + if (list.GetVectorType() == VectorType::CONSTANT_VECTOR) { + info.child_list_info.length = first_length; + return info; + } + + // now get the child count and determine whether the children are stored consecutively + // also determine if a flat vector has pseudo constant values (all offsets + length the same) + // this can happen e.g. for UNNESTs + bool is_consecutive = true; + for (idx_t i = offset; i < offset + count; i++) { + auto idx = unified_list_data.sel->get_index(i); + if (!unified_list_data.validity.RowIsValid(idx)) { + continue; + } + if (list_data[idx].offset != info.child_list_info.offset || list_data[idx].length != first_length) { + info.is_constant = false; + } + if (list_data[idx].offset != info.child_list_info.offset + info.child_list_info.length) { + is_consecutive = false; + } + info.child_list_info.length += list_data[idx].length; + } + + if (info.is_constant) { + info.child_list_info.length = first_length; + } + if (!info.is_constant && !is_consecutive) { + info.needs_slicing = true; + } + + return info; +} + +void ListVector::GetConsecutiveChildSelVector(Vector &list, SelectionVector &sel, idx_t offset, idx_t count) { + UnifiedVectorFormat unified_list_data; + list.ToUnifiedFormat(offset + count, unified_list_data); + auto list_data = UnifiedVectorFormat::GetData(unified_list_data); + + // SelectionVector child_sel(info.second.length); + idx_t entry = 0; + for (idx_t i = offset; i < offset + count; i++) { + auto idx = unified_list_data.sel->get_index(i); + if (!unified_list_data.validity.RowIsValid(idx)) { + continue; + } + for (idx_t k = 0; k < list_data[idx].length; k++) { + // child_sel.set_index(entry++, list_data[idx].offset + k); + sel.set_index(entry++, list_data[idx].offset + k); + } + } + // + // result.Slice(child_sel, info.second.length); + // result.Flatten(info.second.length); + // info.second.offset = 0; +} + +//===--------------------------------------------------------------------===// +// UnionVector +//===--------------------------------------------------------------------===// +const Vector &UnionVector::GetMember(const Vector &vector, idx_t member_index) { + D_ASSERT(member_index < UnionType::GetMemberCount(vector.GetType())); + auto &entries = StructVector::GetEntries(vector); + return *entries[member_index + 1]; // skip the "tag" entry +} + +Vector &UnionVector::GetMember(Vector &vector, idx_t member_index) { + D_ASSERT(member_index < UnionType::GetMemberCount(vector.GetType())); + auto &entries = StructVector::GetEntries(vector); + return *entries[member_index + 1]; // skip the "tag" entry +} + +const Vector &UnionVector::GetTags(const Vector &vector) { + // the tag vector is always the first struct child. + return *StructVector::GetEntries(vector)[0]; +} + +Vector &UnionVector::GetTags(Vector &vector) { + // the tag vector is always the first struct child. + return *StructVector::GetEntries(vector)[0]; +} + +void UnionVector::SetToMember(Vector &union_vector, union_tag_t tag, Vector &member_vector, idx_t count, + bool keep_tags_for_null) { + D_ASSERT(union_vector.GetType().id() == LogicalTypeId::UNION); + D_ASSERT(tag < UnionType::GetMemberCount(union_vector.GetType())); + + // Set the union member to the specified vector + UnionVector::GetMember(union_vector, tag).Reference(member_vector); + auto &tag_vector = UnionVector::GetTags(union_vector); + + if (member_vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // if the member vector is constant, we can set the union to constant as well + union_vector.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(tag_vector)[0] = tag; + ConstantVector::SetNull(union_vector, ConstantVector::IsNull(member_vector)); + + } else { + // otherwise flatten and set to flatvector + member_vector.Flatten(count); + union_vector.SetVectorType(VectorType::FLAT_VECTOR); + + if (member_vector.validity.AllValid()) { + // if the member vector is all valid, we can set the tag to constant + tag_vector.SetVectorType(VectorType::CONSTANT_VECTOR); + auto tag_data = ConstantVector::GetData(tag_vector); + *tag_data = tag; + } else { + tag_vector.SetVectorType(VectorType::FLAT_VECTOR); + if (keep_tags_for_null) { + FlatVector::Validity(tag_vector).SetAllValid(count); + FlatVector::Validity(union_vector).SetAllValid(count); + } else { + // ensure the tags have the same validity as the member + FlatVector::Validity(union_vector) = FlatVector::Validity(member_vector); + FlatVector::Validity(tag_vector) = FlatVector::Validity(member_vector); + } + + auto tag_data = FlatVector::GetData(tag_vector); + memset(tag_data, tag, count); + } + } + + // Set the non-selected members to constant null vectors + for (idx_t i = 0; i < UnionType::GetMemberCount(union_vector.GetType()); i++) { + if (i != tag) { + auto &member = UnionVector::GetMember(union_vector, i); + member.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(member, true); + } + } +} + +union_tag_t UnionVector::GetTag(const Vector &vector, idx_t index) { + // the tag vector is always the first struct child. + auto &tag_vector = *StructVector::GetEntries(vector)[0]; + if (tag_vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(tag_vector); + return FlatVector::GetData(child)[index]; + } + if (tag_vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return ConstantVector::GetData(tag_vector)[0]; + } + return FlatVector::GetData(tag_vector)[index]; +} + +UnionInvalidReason UnionVector::CheckUnionValidity(Vector &vector, idx_t count, const SelectionVector &sel) { + D_ASSERT(vector.GetType().id() == LogicalTypeId::UNION); + auto member_count = UnionType::GetMemberCount(vector.GetType()); + if (member_count == 0) { + return UnionInvalidReason::NO_MEMBERS; + } + + UnifiedVectorFormat union_vdata; + vector.ToUnifiedFormat(count, union_vdata); + + UnifiedVectorFormat tags_vdata; + auto &tag_vector = UnionVector::GetTags(vector); + tag_vector.ToUnifiedFormat(count, tags_vdata); + + // check that only one member is valid at a time + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + auto union_mapped_row_idx = sel.get_index(row_idx); + if (!union_vdata.validity.RowIsValid(union_mapped_row_idx)) { + continue; + } + + auto tag_mapped_row_idx = tags_vdata.sel->get_index(row_idx); + if (!tags_vdata.validity.RowIsValid(tag_mapped_row_idx)) { + continue; + } + + auto tag = (UnifiedVectorFormat::GetData(tags_vdata))[tag_mapped_row_idx]; + if (tag >= member_count) { + return UnionInvalidReason::TAG_OUT_OF_RANGE; + } + + bool found_valid = false; + for (idx_t member_idx = 0; member_idx < member_count; member_idx++) { + + UnifiedVectorFormat member_vdata; + auto &member = UnionVector::GetMember(vector, member_idx); + member.ToUnifiedFormat(count, member_vdata); + + auto mapped_row_idx = member_vdata.sel->get_index(row_idx); + if (member_vdata.validity.RowIsValid(mapped_row_idx)) { + if (found_valid) { + return UnionInvalidReason::VALIDITY_OVERLAP; + } + found_valid = true; + if (tag != static_cast(member_idx)) { + return UnionInvalidReason::TAG_MISMATCH; + } + } + } + } + + return UnionInvalidReason::VALID; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector_buffer.cpp b/src/duckdb/src/common/types/vector_buffer.cpp new file mode 100644 index 00000000..fcd7066f --- /dev/null +++ b/src/duckdb/src/common/types/vector_buffer.cpp @@ -0,0 +1,117 @@ +#include "duckdb/common/types/vector_buffer.hpp" + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" + +namespace duckdb { + +buffer_ptr VectorBuffer::CreateStandardVector(PhysicalType type, idx_t capacity) { + return make_buffer(capacity * GetTypeIdSize(type)); +} + +buffer_ptr VectorBuffer::CreateConstantVector(PhysicalType type) { + return make_buffer(GetTypeIdSize(type)); +} + +buffer_ptr VectorBuffer::CreateConstantVector(const LogicalType &type) { + return VectorBuffer::CreateConstantVector(type.InternalType()); +} + +buffer_ptr VectorBuffer::CreateStandardVector(const LogicalType &type, idx_t capacity) { + return VectorBuffer::CreateStandardVector(type.InternalType(), capacity); +} + +VectorStringBuffer::VectorStringBuffer() : VectorBuffer(VectorBufferType::STRING_BUFFER) { +} + +VectorStringBuffer::VectorStringBuffer(VectorBufferType type) : VectorBuffer(type) { +} + +VectorFSSTStringBuffer::VectorFSSTStringBuffer() : VectorStringBuffer(VectorBufferType::FSST_BUFFER) { +} + +VectorStructBuffer::VectorStructBuffer() : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { +} + +VectorStructBuffer::VectorStructBuffer(const LogicalType &type, idx_t capacity) + : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { + auto &child_types = StructType::GetChildTypes(type); + for (auto &child_type : child_types) { + auto vector = make_uniq(child_type.second, capacity); + children.push_back(std::move(vector)); + } +} + +VectorStructBuffer::VectorStructBuffer(Vector &other, const SelectionVector &sel, idx_t count) + : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { + auto &other_vector = StructVector::GetEntries(other); + for (auto &child_vector : other_vector) { + auto vector = make_uniq(*child_vector, sel, count); + children.push_back(std::move(vector)); + } +} + +VectorStructBuffer::~VectorStructBuffer() { +} + +VectorListBuffer::VectorListBuffer(unique_ptr vector, idx_t initial_capacity) + : VectorBuffer(VectorBufferType::LIST_BUFFER), child(std::move(vector)), capacity(initial_capacity) { +} + +VectorListBuffer::VectorListBuffer(const LogicalType &list_type, idx_t initial_capacity) + : VectorBuffer(VectorBufferType::LIST_BUFFER), + child(make_uniq(ListType::GetChildType(list_type), initial_capacity)), capacity(initial_capacity) { +} + +void VectorListBuffer::Reserve(idx_t to_reserve) { + if (to_reserve > capacity) { + idx_t new_capacity = NextPowerOfTwo(to_reserve); + D_ASSERT(new_capacity >= to_reserve); + child->Resize(capacity, new_capacity); + capacity = new_capacity; + } +} + +void VectorListBuffer::Append(const Vector &to_append, idx_t to_append_size, idx_t source_offset) { + Reserve(size + to_append_size - source_offset); + VectorOperations::Copy(to_append, *child, to_append_size, source_offset, size); + size += to_append_size - source_offset; +} + +void VectorListBuffer::Append(const Vector &to_append, const SelectionVector &sel, idx_t to_append_size, + idx_t source_offset) { + Reserve(size + to_append_size - source_offset); + VectorOperations::Copy(to_append, *child, sel, to_append_size, source_offset, size); + size += to_append_size - source_offset; +} + +void VectorListBuffer::PushBack(const Value &insert) { + while (size + 1 > capacity) { + child->Resize(capacity, capacity * 2); + capacity *= 2; + } + child->SetValue(size++, insert); +} + +void VectorListBuffer::SetCapacity(idx_t new_capacity) { + this->capacity = new_capacity; +} + +void VectorListBuffer::SetSize(idx_t new_size) { + this->size = new_size; +} + +VectorListBuffer::~VectorListBuffer() { +} + +ManagedVectorBuffer::ManagedVectorBuffer(BufferHandle handle) + : VectorBuffer(VectorBufferType::MANAGED_BUFFER), handle(std::move(handle)) { +} + +ManagedVectorBuffer::~ManagedVectorBuffer() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector_cache.cpp b/src/duckdb/src/common/types/vector_cache.cpp new file mode 100644 index 00000000..12f8827e --- /dev/null +++ b/src/duckdb/src/common/types/vector_cache.cpp @@ -0,0 +1,115 @@ +#include "duckdb/common/types/vector_cache.hpp" + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +class VectorCacheBuffer : public VectorBuffer { +public: + explicit VectorCacheBuffer(Allocator &allocator, const LogicalType &type_p, idx_t capacity_p = STANDARD_VECTOR_SIZE) + : VectorBuffer(VectorBufferType::OPAQUE_BUFFER), type(type_p), capacity(capacity_p) { + auto internal_type = type.InternalType(); + switch (internal_type) { + case PhysicalType::LIST: { + // memory for the list offsets + owned_data = allocator.Allocate(capacity * GetTypeIdSize(internal_type)); + // child data of the list + auto &child_type = ListType::GetChildType(type); + child_caches.push_back(make_buffer(allocator, child_type, capacity)); + auto child_vector = make_uniq(child_type, false, false); + auxiliary = make_shared(std::move(child_vector)); + break; + } + case PhysicalType::STRUCT: { + auto &child_types = StructType::GetChildTypes(type); + for (auto &child_type : child_types) { + child_caches.push_back(make_buffer(allocator, child_type.second, capacity)); + } + auto struct_buffer = make_shared(type); + auxiliary = std::move(struct_buffer); + break; + } + default: + owned_data = allocator.Allocate(capacity * GetTypeIdSize(internal_type)); + break; + } + } + + void ResetFromCache(Vector &result, const buffer_ptr &buffer) { + D_ASSERT(type == result.GetType()); + auto internal_type = type.InternalType(); + result.vector_type = VectorType::FLAT_VECTOR; + AssignSharedPointer(result.buffer, buffer); + result.validity.Reset(); + switch (internal_type) { + case PhysicalType::LIST: { + result.data = owned_data.get(); + // reinitialize the VectorListBuffer + AssignSharedPointer(result.auxiliary, auxiliary); + // propagate through child + auto &child_cache = child_caches[0]->Cast(); + auto &list_buffer = result.auxiliary->Cast(); + list_buffer.SetCapacity(child_cache.capacity); + list_buffer.SetSize(0); + list_buffer.SetAuxiliaryData(nullptr); + + auto &list_child = list_buffer.GetChild(); + child_cache.ResetFromCache(list_child, child_caches[0]); + break; + } + case PhysicalType::STRUCT: { + // struct does not have data + result.data = nullptr; + // reinitialize the VectorStructBuffer + auxiliary->SetAuxiliaryData(nullptr); + AssignSharedPointer(result.auxiliary, auxiliary); + // propagate through children + auto &children = result.auxiliary->Cast().GetChildren(); + for (idx_t i = 0; i < children.size(); i++) { + auto &child_cache = child_caches[i]->Cast(); + child_cache.ResetFromCache(*children[i], child_caches[i]); + } + break; + } + default: + // regular type: no aux data and reset data to cached data + result.data = owned_data.get(); + result.auxiliary.reset(); + break; + } + } + + const LogicalType &GetType() { + return type; + } + +private: + //! The type of the vector cache + LogicalType type; + //! Owned data + AllocatedData owned_data; + //! Child caches (if any). Used for nested types. + vector> child_caches; + //! Aux data for the vector (if any) + buffer_ptr auxiliary; + //! Capacity of the vector + idx_t capacity; +}; + +VectorCache::VectorCache(Allocator &allocator, const LogicalType &type_p, idx_t capacity_p) { + buffer = make_buffer(allocator, type_p, capacity_p); +} + +void VectorCache::ResetFromCache(Vector &result) const { + D_ASSERT(buffer); + auto &vcache = buffer->Cast(); + vcache.ResetFromCache(result, buffer); +} + +const LogicalType &VectorCache::GetType() const { + auto &vcache = buffer->Cast(); + return vcache.GetType(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector_constants.cpp b/src/duckdb/src/common/types/vector_constants.cpp new file mode 100644 index 00000000..78c7fb5b --- /dev/null +++ b/src/duckdb/src/common/types/vector_constants.cpp @@ -0,0 +1,18 @@ +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +const SelectionVector *ConstantVector::ZeroSelectionVector() { + static const SelectionVector ZERO_SELECTION_VECTOR = + SelectionVector(const_cast(ConstantVector::ZERO_VECTOR)); // NOLINT + return &ZERO_SELECTION_VECTOR; +} + +const SelectionVector *FlatVector::IncrementalSelectionVector() { + static const SelectionVector INCREMENTAL_SELECTION_VECTOR; + return &INCREMENTAL_SELECTION_VECTOR; +} + +const sel_t ConstantVector::ZERO_VECTOR[STANDARD_VECTOR_SIZE] = {0}; + +} // namespace duckdb diff --git a/src/duckdb/src/common/value_operations/comparison_operations.cpp b/src/duckdb/src/common/value_operations/comparison_operations.cpp new file mode 100644 index 00000000..3680ff81 --- /dev/null +++ b/src/duckdb/src/common/value_operations/comparison_operations.cpp @@ -0,0 +1,245 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Comparison Operations +//===--------------------------------------------------------------------===// + +struct ValuePositionComparator { + // Return true if the positional Values definitely match. + // Default to the same as the final value + template + static inline bool Definite(const Value &lhs, const Value &rhs) { + return Final(lhs, rhs); + } + + // Select the positional Values that need further testing. + // Usually this means Is Not Distinct, as those are the semantics used by Postges + template + static inline bool Possible(const Value &lhs, const Value &rhs) { + return ValueOperations::NotDistinctFrom(lhs, rhs); + } + + // Return true if the positional Values definitely match in the final position + // This needs to be specialised. + template + static inline bool Final(const Value &lhs, const Value &rhs) { + return false; + } + + // Tie-break based on length when one of the sides has been exhausted, returning true if the LHS matches. + // This essentially means that the existing positions compare equal. + // Default to the same semantics as the OP for idx_t. This works in most cases. + template + static inline bool TieBreak(const idx_t lpos, const idx_t rpos) { + return OP::Operation(lpos, rpos); + } +}; + +// Equals must always check every column +template <> +inline bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { + return false; +} + +template <> +inline bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { + return ValueOperations::NotDistinctFrom(lhs, rhs); +} + +// NotEquals must check everything that matched +template <> +inline bool ValuePositionComparator::Possible(const Value &lhs, const Value &rhs) { + return true; +} + +template <> +inline bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { + return ValueOperations::NotDistinctFrom(lhs, rhs); +} + +// Non-strict inequalities must use strict comparisons for Definite +template <> +bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { + return !ValuePositionComparator::Definite(lhs, rhs); +} + +template <> +bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { + return ValueOperations::DistinctGreaterThan(lhs, rhs); +} + +template <> +bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { + return !ValuePositionComparator::Final(lhs, rhs); +} + +template <> +bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { + return !ValuePositionComparator::Definite(rhs, lhs); +} + +template <> +bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { + return !ValuePositionComparator::Final(rhs, lhs); +} + +// Strict inequalities just use strict for both Definite and Final +template <> +bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { + return ValuePositionComparator::Final(rhs, lhs); +} + +template +static bool TemplatedBooleanOperation(const Value &left, const Value &right) { + const auto &left_type = left.type(); + const auto &right_type = right.type(); + if (left_type != right_type) { + Value left_copy = left; + Value right_copy = right; + + LogicalType comparison_type = BoundComparisonExpression::BindComparison(left_type, right_type); + if (!left_copy.DefaultTryCastAs(comparison_type) || !right_copy.DefaultTryCastAs(comparison_type)) { + return false; + } + D_ASSERT(left_copy.type() == right_copy.type()); + return TemplatedBooleanOperation(left_copy, right_copy); + } + switch (left_type.InternalType()) { + case PhysicalType::BOOL: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::INT8: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::INT16: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::INT32: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::INT64: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::UINT8: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::UINT16: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::UINT32: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::UINT64: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::INT128: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::FLOAT: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::DOUBLE: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::INTERVAL: + return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); + case PhysicalType::VARCHAR: + return OP::Operation(StringValue::Get(left), StringValue::Get(right)); + case PhysicalType::STRUCT: { + auto &left_children = StructValue::GetChildren(left); + auto &right_children = StructValue::GetChildren(right); + // this should be enforced by the type + D_ASSERT(left_children.size() == right_children.size()); + idx_t i = 0; + for (; i < left_children.size() - 1; ++i) { + if (ValuePositionComparator::Definite(left_children[i], right_children[i])) { + return true; + } + if (!ValuePositionComparator::Possible(left_children[i], right_children[i])) { + return false; + } + } + return ValuePositionComparator::Final(left_children[i], right_children[i]); + } + case PhysicalType::LIST: { + auto &left_children = ListValue::GetChildren(left); + auto &right_children = ListValue::GetChildren(right); + for (idx_t pos = 0;; ++pos) { + if (pos == left_children.size() || pos == right_children.size()) { + return ValuePositionComparator::TieBreak(left_children.size(), right_children.size()); + } + if (ValuePositionComparator::Definite(left_children[pos], right_children[pos])) { + return true; + } + if (!ValuePositionComparator::Possible(left_children[pos], right_children[pos])) { + return false; + } + } + return false; + } + default: + throw InternalException("Unimplemented type for value comparison"); + } +} + +bool ValueOperations::Equals(const Value &left, const Value &right) { + if (left.IsNull() || right.IsNull()) { + throw InternalException("Comparison on NULL values"); + } + return TemplatedBooleanOperation(left, right); +} + +bool ValueOperations::NotEquals(const Value &left, const Value &right) { + return !ValueOperations::Equals(left, right); +} + +bool ValueOperations::GreaterThan(const Value &left, const Value &right) { + if (left.IsNull() || right.IsNull()) { + throw InternalException("Comparison on NULL values"); + } + return TemplatedBooleanOperation(left, right); +} + +bool ValueOperations::GreaterThanEquals(const Value &left, const Value &right) { + return !ValueOperations::GreaterThan(right, left); +} + +bool ValueOperations::LessThan(const Value &left, const Value &right) { + return ValueOperations::GreaterThan(right, left); +} + +bool ValueOperations::LessThanEquals(const Value &left, const Value &right) { + return !ValueOperations::GreaterThan(left, right); +} + +bool ValueOperations::NotDistinctFrom(const Value &left, const Value &right) { + if (left.IsNull() && right.IsNull()) { + return true; + } + if (left.IsNull() != right.IsNull()) { + return false; + } + return TemplatedBooleanOperation(left, right); +} + +bool ValueOperations::DistinctFrom(const Value &left, const Value &right) { + return !ValueOperations::NotDistinctFrom(left, right); +} + +bool ValueOperations::DistinctGreaterThan(const Value &left, const Value &right) { + if (left.IsNull() && right.IsNull()) { + return false; + } else if (right.IsNull()) { + return false; + } else if (left.IsNull()) { + return true; + } + return TemplatedBooleanOperation(left, right); +} + +bool ValueOperations::DistinctGreaterThanEquals(const Value &left, const Value &right) { + return !ValueOperations::DistinctGreaterThan(right, left); +} + +bool ValueOperations::DistinctLessThan(const Value &left, const Value &right) { + return ValueOperations::DistinctGreaterThan(right, left); +} + +bool ValueOperations::DistinctLessThanEquals(const Value &left, const Value &right) { + return !ValueOperations::DistinctGreaterThan(left, right); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/boolean_operators.cpp b/src/duckdb/src/common/vector_operations/boolean_operators.cpp new file mode 100644 index 00000000..667c9208 --- /dev/null +++ b/src/duckdb/src/common/vector_operations/boolean_operators.cpp @@ -0,0 +1,177 @@ +//===--------------------------------------------------------------------===// +// boolean_operators.cpp +// Description: This file contains the implementation of the boolean +// operations AND OR ! +//===--------------------------------------------------------------------===// + +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// AND/OR +//===--------------------------------------------------------------------===// +template +static void TemplatedBooleanNullmask(Vector &left, Vector &right, Vector &result, idx_t count) { + D_ASSERT(left.GetType().id() == LogicalTypeId::BOOLEAN && right.GetType().id() == LogicalTypeId::BOOLEAN && + result.GetType().id() == LogicalTypeId::BOOLEAN); + + if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // operation on two constants, result is constant vector + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto ldata = ConstantVector::GetData(left); + auto rdata = ConstantVector::GetData(right); + auto result_data = ConstantVector::GetData(result); + + bool is_null = OP::Operation(*ldata > 0, *rdata > 0, ConstantVector::IsNull(left), + ConstantVector::IsNull(right), *result_data); + ConstantVector::SetNull(result, is_null); + } else { + // perform generic loop + UnifiedVectorFormat ldata, rdata; + left.ToUnifiedFormat(count, ldata); + right.ToUnifiedFormat(count, rdata); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto left_data = UnifiedVectorFormat::GetData(ldata); // we use uint8 to avoid load of gunk bools + auto right_data = UnifiedVectorFormat::GetData(rdata); + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + if (!ldata.validity.AllValid() || !rdata.validity.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto lidx = ldata.sel->get_index(i); + auto ridx = rdata.sel->get_index(i); + bool is_null = + OP::Operation(left_data[lidx] > 0, right_data[ridx] > 0, !ldata.validity.RowIsValid(lidx), + !rdata.validity.RowIsValid(ridx), result_data[i]); + result_mask.Set(i, !is_null); + } + } else { + for (idx_t i = 0; i < count; i++) { + auto lidx = ldata.sel->get_index(i); + auto ridx = rdata.sel->get_index(i); + result_data[i] = OP::SimpleOperation(left_data[lidx], right_data[ridx]); + } + } + } +} + +/* +SQL AND Rules: + +TRUE AND TRUE = TRUE +TRUE AND FALSE = FALSE +TRUE AND NULL = NULL +FALSE AND TRUE = FALSE +FALSE AND FALSE = FALSE +FALSE AND NULL = FALSE +NULL AND TRUE = NULL +NULL AND FALSE = FALSE +NULL AND NULL = NULL + +Basically: +- Only true if both are true +- False if either is false (regardless of NULLs) +- NULL otherwise +*/ +struct TernaryAnd { + static bool SimpleOperation(bool left, bool right) { + return left && right; + } + static bool Operation(bool left, bool right, bool left_null, bool right_null, bool &result) { + if (left_null && right_null) { + // both NULL: + // result is NULL + return true; + } else if (left_null) { + // left is NULL: + // result is FALSE if right is false + // result is NULL if right is true + result = right; + return right; + } else if (right_null) { + // right is NULL: + // result is FALSE if left is false + // result is NULL if left is true + result = left; + return left; + } else { + // no NULL: perform the AND + result = left && right; + return false; + } + } +}; + +void VectorOperations::And(Vector &left, Vector &right, Vector &result, idx_t count) { + TemplatedBooleanNullmask(left, right, result, count); +} + +/* +SQL OR Rules: + +OR +TRUE OR TRUE = TRUE +TRUE OR FALSE = TRUE +TRUE OR NULL = TRUE +FALSE OR TRUE = TRUE +FALSE OR FALSE = FALSE +FALSE OR NULL = NULL +NULL OR TRUE = TRUE +NULL OR FALSE = NULL +NULL OR NULL = NULL + +Basically: +- Only false if both are false +- True if either is true (regardless of NULLs) +- NULL otherwise +*/ + +struct TernaryOr { + static bool SimpleOperation(bool left, bool right) { + return left || right; + } + static bool Operation(bool left, bool right, bool left_null, bool right_null, bool &result) { + if (left_null && right_null) { + // both NULL: + // result is NULL + return true; + } else if (left_null) { + // left is NULL: + // result is TRUE if right is true + // result is NULL if right is false + result = right; + return !right; + } else if (right_null) { + // right is NULL: + // result is TRUE if left is true + // result is NULL if left is false + result = left; + return !left; + } else { + // no NULL: perform the OR + result = left || right; + return false; + } + } +}; + +void VectorOperations::Or(Vector &left, Vector &right, Vector &result, idx_t count) { + TemplatedBooleanNullmask(left, right, result, count); +} + +struct NotOperator { + template + static inline TR Operation(TA left) { + return !left; + } +}; + +void VectorOperations::Not(Vector &input, Vector &result, idx_t count) { + D_ASSERT(input.GetType() == LogicalType::BOOLEAN && result.GetType() == LogicalType::BOOLEAN); + UnaryExecutor::Execute(input, result, count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/comparison_operators.cpp b/src/duckdb/src/common/vector_operations/comparison_operators.cpp new file mode 100644 index 00000000..33686f37 --- /dev/null +++ b/src/duckdb/src/common/vector_operations/comparison_operators.cpp @@ -0,0 +1,286 @@ +//===--------------------------------------------------------------------===// +// comparison_operators.cpp +// Description: This file contains the implementation of the comparison +// operations == != >= <= > < +//===--------------------------------------------------------------------===// + +#include "duckdb/common/operator/comparison_operators.hpp" + +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include "duckdb/common/likely.hpp" + +namespace duckdb { + +template +bool EqualsFloat(T left, T right) { + if (DUCKDB_UNLIKELY(Value::IsNan(left) && Value::IsNan(right))) { + return true; + } + return left == right; +} + +template <> +bool Equals::Operation(const float &left, const float &right) { + return EqualsFloat(left, right); +} + +template <> +bool Equals::Operation(const double &left, const double &right) { + return EqualsFloat(left, right); +} + +template +bool GreaterThanFloat(T left, T right) { + // handle nans + // nan is always bigger than everything else + bool left_is_nan = Value::IsNan(left); + bool right_is_nan = Value::IsNan(right); + // if right is nan, there is no number that is bigger than right + if (DUCKDB_UNLIKELY(right_is_nan)) { + return false; + } + // if left is nan, but right is not, left is always bigger + if (DUCKDB_UNLIKELY(left_is_nan)) { + return true; + } + return left > right; +} + +template <> +bool GreaterThan::Operation(const float &left, const float &right) { + return GreaterThanFloat(left, right); +} + +template <> +bool GreaterThan::Operation(const double &left, const double &right) { + return GreaterThanFloat(left, right); +} + +template +bool GreaterThanEqualsFloat(T left, T right) { + // handle nans + // nan is always bigger than everything else + bool left_is_nan = Value::IsNan(left); + bool right_is_nan = Value::IsNan(right); + // if right is nan, there is no bigger number + // we only return true if left is also nan (in which case the numbers are equal) + if (DUCKDB_UNLIKELY(right_is_nan)) { + return left_is_nan; + } + // if left is nan, but right is not, left is always bigger + if (DUCKDB_UNLIKELY(left_is_nan)) { + return true; + } + return left >= right; +} + +template <> +bool GreaterThanEquals::Operation(const float &left, const float &right) { + return GreaterThanEqualsFloat(left, right); +} + +template <> +bool GreaterThanEquals::Operation(const double &left, const double &right) { + return GreaterThanEqualsFloat(left, right); +} + +struct ComparisonSelector { + template + static idx_t Select(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + throw NotImplementedException("Unknown comparison operation!"); + } +}; + +template <> +inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel); +} + +template <> +inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel); +} + +template <> +inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel); +} + +template <> +inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, + const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel); +} + +template <> +inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::GreaterThan(right, left, sel, count, true_sel, false_sel); +} + +template <> +inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::GreaterThanEquals(right, left, sel, count, true_sel, false_sel); +} + +static void ComparesNotNull(UnifiedVectorFormat &ldata, UnifiedVectorFormat &rdata, ValidityMask &vresult, + idx_t count) { + for (idx_t i = 0; i < count; ++i) { + auto lidx = ldata.sel->get_index(i); + auto ridx = rdata.sel->get_index(i); + if (!ldata.validity.RowIsValid(lidx) || !rdata.validity.RowIsValid(ridx)) { + vresult.SetInvalid(i); + } + } +} + +template +static void NestedComparisonExecutor(Vector &left, Vector &right, Vector &result, idx_t count) { + const auto left_constant = left.GetVectorType() == VectorType::CONSTANT_VECTOR; + const auto right_constant = right.GetVectorType() == VectorType::CONSTANT_VECTOR; + + if ((left_constant && ConstantVector::IsNull(left)) || (right_constant && ConstantVector::IsNull(right))) { + // either left or right is constant NULL: result is constant NULL + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + if (left_constant && right_constant) { + // both sides are constant, and neither is NULL so just compare one element. + result.SetVectorType(VectorType::CONSTANT_VECTOR); + SelectionVector true_sel(1); + auto match_count = ComparisonSelector::Select(left, right, nullptr, 1, &true_sel, nullptr); + auto result_data = ConstantVector::GetData(result); + result_data[0] = match_count > 0; + return; + } + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + UnifiedVectorFormat leftv, rightv; + left.ToUnifiedFormat(count, leftv); + right.ToUnifiedFormat(count, rightv); + if (!leftv.validity.AllValid() || !rightv.validity.AllValid()) { + ComparesNotNull(leftv, rightv, result_validity, count); + } + SelectionVector true_sel(count); + SelectionVector false_sel(count); + idx_t match_count = ComparisonSelector::Select(left, right, nullptr, count, &true_sel, &false_sel); + + for (idx_t i = 0; i < match_count; ++i) { + const auto idx = true_sel.get_index(i); + result_data[idx] = true; + } + + const idx_t no_match_count = count - match_count; + for (idx_t i = 0; i < no_match_count; ++i) { + const auto idx = false_sel.get_index(i); + result_data[idx] = false; + } +} + +struct ComparisonExecutor { +private: + template + static inline void TemplatedExecute(Vector &left, Vector &right, Vector &result, idx_t count) { + BinaryExecutor::Execute(left, right, result, count); + } + +public: + template + static inline void Execute(Vector &left, Vector &right, Vector &result, idx_t count) { + D_ASSERT(left.GetType() == right.GetType() && result.GetType() == LogicalType::BOOLEAN); + // the inplace loops take the result as the last parameter + switch (left.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::INT16: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::INT32: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::INT64: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::UINT8: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::UINT16: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::UINT32: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::UINT64: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::INT128: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::FLOAT: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::DOUBLE: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::INTERVAL: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::VARCHAR: + TemplatedExecute(left, right, result, count); + break; + case PhysicalType::LIST: + case PhysicalType::STRUCT: + NestedComparisonExecutor(left, right, result, count); + break; + default: + throw InternalException("Invalid type for comparison"); + } + } +}; + +void VectorOperations::Equals(Vector &left, Vector &right, Vector &result, idx_t count) { + ComparisonExecutor::Execute(left, right, result, count); +} + +void VectorOperations::NotEquals(Vector &left, Vector &right, Vector &result, idx_t count) { + ComparisonExecutor::Execute(left, right, result, count); +} + +void VectorOperations::GreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { + ComparisonExecutor::Execute(left, right, result, count); +} + +void VectorOperations::LessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { + ComparisonExecutor::Execute(right, left, result, count); +} + +void VectorOperations::GreaterThan(Vector &left, Vector &right, Vector &result, idx_t count) { + ComparisonExecutor::Execute(left, right, result, count); +} + +void VectorOperations::LessThan(Vector &left, Vector &right, Vector &result, idx_t count) { + ComparisonExecutor::Execute(right, left, result, count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/generators.cpp b/src/duckdb/src/common/vector_operations/generators.cpp new file mode 100644 index 00000000..77dd5f9c --- /dev/null +++ b/src/duckdb/src/common/vector_operations/generators.cpp @@ -0,0 +1,102 @@ +//===--------------------------------------------------------------------===// +// generators.cpp +// Description: This file contains the implementation of different generators +//===--------------------------------------------------------------------===// + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/limits.hpp" + +namespace duckdb { + +template +void TemplatedGenerateSequence(Vector &result, idx_t count, int64_t start, int64_t increment) { + D_ASSERT(result.GetType().IsNumeric()); + if (start > NumericLimits::Maximum() || increment > NumericLimits::Maximum()) { + throw Exception("Sequence start or increment out of type range"); + } + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto value = (T)start; + for (idx_t i = 0; i < count; i++) { + if (i > 0) { + value += increment; + } + result_data[i] = value; + } +} + +void VectorOperations::GenerateSequence(Vector &result, idx_t count, int64_t start, int64_t increment) { + if (!result.GetType().IsNumeric()) { + throw InvalidTypeException(result.GetType(), "Can only generate sequences for numeric values!"); + } + switch (result.GetType().InternalType()) { + case PhysicalType::INT8: + TemplatedGenerateSequence(result, count, start, increment); + break; + case PhysicalType::INT16: + TemplatedGenerateSequence(result, count, start, increment); + break; + case PhysicalType::INT32: + TemplatedGenerateSequence(result, count, start, increment); + break; + case PhysicalType::INT64: + TemplatedGenerateSequence(result, count, start, increment); + break; + case PhysicalType::FLOAT: + TemplatedGenerateSequence(result, count, start, increment); + break; + case PhysicalType::DOUBLE: + TemplatedGenerateSequence(result, count, start, increment); + break; + default: + throw NotImplementedException("Unimplemented type for generate sequence"); + } +} + +template +void TemplatedGenerateSequence(Vector &result, idx_t count, const SelectionVector &sel, int64_t start, + int64_t increment) { + D_ASSERT(result.GetType().IsNumeric()); + if (start > NumericLimits::Maximum() || increment > NumericLimits::Maximum()) { + throw Exception("Sequence start or increment out of type range"); + } + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto value = (T)start; + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + result_data[idx] = value + increment * idx; + } +} + +void VectorOperations::GenerateSequence(Vector &result, idx_t count, const SelectionVector &sel, int64_t start, + int64_t increment) { + if (!result.GetType().IsNumeric()) { + throw InvalidTypeException(result.GetType(), "Can only generate sequences for numeric values!"); + } + switch (result.GetType().InternalType()) { + case PhysicalType::INT8: + TemplatedGenerateSequence(result, count, sel, start, increment); + break; + case PhysicalType::INT16: + TemplatedGenerateSequence(result, count, sel, start, increment); + break; + case PhysicalType::INT32: + TemplatedGenerateSequence(result, count, sel, start, increment); + break; + case PhysicalType::INT64: + TemplatedGenerateSequence(result, count, sel, start, increment); + break; + case PhysicalType::FLOAT: + TemplatedGenerateSequence(result, count, sel, start, increment); + break; + case PhysicalType::DOUBLE: + TemplatedGenerateSequence(result, count, sel, start, increment); + break; + default: + throw NotImplementedException("Unimplemented type for generate sequence"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp new file mode 100644 index 00000000..1564a391 --- /dev/null +++ b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp @@ -0,0 +1,928 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" + +namespace duckdb { + +struct DistinctBinaryLambdaWrapper { + template + static inline RESULT_TYPE Operation(LEFT_TYPE left, RIGHT_TYPE right, bool is_left_null, bool is_right_null) { + return OP::template Operation(left, right, is_left_null, is_right_null); + } +}; + +template +static void DistinctExecuteGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + RESULT_TYPE *__restrict result_data, const SelectionVector *__restrict lsel, + const SelectionVector *__restrict rsel, idx_t count, ValidityMask &lmask, + ValidityMask &rmask, ValidityMask &result_mask) { + for (idx_t i = 0; i < count; i++) { + auto lindex = lsel->get_index(i); + auto rindex = rsel->get_index(i); + auto lentry = ldata[lindex]; + auto rentry = rdata[rindex]; + result_data[i] = + OP::template Operation(lentry, rentry, !lmask.RowIsValid(lindex), !rmask.RowIsValid(rindex)); + } +} + +template +static void DistinctExecuteConstant(Vector &left, Vector &right, Vector &result) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + auto ldata = ConstantVector::GetData(left); + auto rdata = ConstantVector::GetData(right); + auto result_data = ConstantVector::GetData(result); + *result_data = + OP::template Operation(*ldata, *rdata, ConstantVector::IsNull(left), ConstantVector::IsNull(right)); +} + +template +static void DistinctExecuteGeneric(Vector &left, Vector &right, Vector &result, idx_t count) { + if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { + DistinctExecuteConstant(left, right, result); + } else { + UnifiedVectorFormat ldata, rdata; + + left.ToUnifiedFormat(count, ldata); + right.ToUnifiedFormat(count, rdata); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + DistinctExecuteGenericLoop( + UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), + result_data, ldata.sel, rdata.sel, count, ldata.validity, rdata.validity, FlatVector::Validity(result)); + } +} + +template +static void DistinctExecuteSwitch(Vector &left, Vector &right, Vector &result, idx_t count) { + DistinctExecuteGeneric(left, right, result, count); +} + +template +static void DistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { + DistinctExecuteSwitch(left, right, result, count); +} + +template +static inline idx_t +DistinctSelectGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, + const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, + ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { + idx_t true_count = 0, false_count = 0; + for (idx_t i = 0; i < count; i++) { + auto result_idx = result_sel->get_index(i); + auto lindex = lsel->get_index(i); + auto rindex = rsel->get_index(i); + if (NO_NULL) { + if (OP::Operation(ldata[lindex], rdata[rindex], false, false)) { + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count++, result_idx); + } + } else { + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count++, result_idx); + } + } + } else { + if (OP::Operation(ldata[lindex], rdata[rindex], !lmask.RowIsValid(lindex), !rmask.RowIsValid(rindex))) { + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count++, result_idx); + } + } else { + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count++, result_idx); + } + } + } + } + if (HAS_TRUE_SEL) { + return true_count; + } else { + return count - false_count; + } +} +template +static inline idx_t +DistinctSelectGenericLoopSelSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, + const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, + ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { + if (true_sel && false_sel) { + return DistinctSelectGenericLoop( + ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); + } else if (true_sel) { + return DistinctSelectGenericLoop( + ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); + } else { + D_ASSERT(false_sel); + return DistinctSelectGenericLoop( + ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); + } +} + +template +static inline idx_t +DistinctSelectGenericLoopSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, + const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, + ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { + if (!lmask.AllValid() || !rmask.AllValid()) { + return DistinctSelectGenericLoopSelSwitch( + ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); + } else { + return DistinctSelectGenericLoopSelSwitch( + ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); + } +} + +template +static idx_t DistinctSelectGeneric(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + UnifiedVectorFormat ldata, rdata; + + left.ToUnifiedFormat(count, ldata); + right.ToUnifiedFormat(count, rdata); + + return DistinctSelectGenericLoopSwitch( + UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), ldata.sel, + rdata.sel, sel, count, ldata.validity, rdata.validity, true_sel, false_sel); +} +template +static inline idx_t DistinctSelectFlatLoop(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, + const SelectionVector *sel, idx_t count, ValidityMask &lmask, + ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { + idx_t true_count = 0, false_count = 0; + for (idx_t i = 0; i < count; i++) { + idx_t result_idx = sel->get_index(i); + idx_t lidx = LEFT_CONSTANT ? 0 : i; + idx_t ridx = RIGHT_CONSTANT ? 0 : i; + const bool lnull = !lmask.RowIsValid(lidx); + const bool rnull = !rmask.RowIsValid(ridx); + bool comparison_result = OP::Operation(ldata[lidx], rdata[ridx], lnull, rnull); + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count, result_idx); + true_count += comparison_result; + } + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count, result_idx); + false_count += !comparison_result; + } + } + if (HAS_TRUE_SEL) { + return true_count; + } else { + return count - false_count; + } +} + +template +static inline idx_t DistinctSelectFlatLoopSelSwitch(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, + const SelectionVector *sel, idx_t count, ValidityMask &lmask, + ValidityMask &rmask, SelectionVector *true_sel, + SelectionVector *false_sel) { + if (true_sel && false_sel) { + return DistinctSelectFlatLoop( + ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); + } else if (true_sel) { + return DistinctSelectFlatLoop( + ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); + } else { + D_ASSERT(false_sel); + return DistinctSelectFlatLoop( + ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); + } +} + +template +static inline idx_t DistinctSelectFlatLoopSwitch(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, + const SelectionVector *sel, idx_t count, ValidityMask &lmask, + ValidityMask &rmask, SelectionVector *true_sel, + SelectionVector *false_sel) { + return DistinctSelectFlatLoopSelSwitch( + ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); +} +template +static idx_t DistinctSelectFlat(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + auto ldata = FlatVector::GetData(left); + auto rdata = FlatVector::GetData(right); + if (LEFT_CONSTANT) { + ValidityMask validity; + if (ConstantVector::IsNull(left)) { + validity.SetAllInvalid(1); + } + return DistinctSelectFlatLoopSwitch( + ldata, rdata, sel, count, validity, FlatVector::Validity(right), true_sel, false_sel); + } else if (RIGHT_CONSTANT) { + ValidityMask validity; + if (ConstantVector::IsNull(right)) { + validity.SetAllInvalid(1); + } + return DistinctSelectFlatLoopSwitch( + ldata, rdata, sel, count, FlatVector::Validity(left), validity, true_sel, false_sel); + } else { + return DistinctSelectFlatLoopSwitch( + ldata, rdata, sel, count, FlatVector::Validity(left), FlatVector::Validity(right), true_sel, false_sel); + } +} +template +static idx_t DistinctSelectConstant(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + auto ldata = ConstantVector::GetData(left); + auto rdata = ConstantVector::GetData(right); + + // both sides are constant, return either 0 or the count + // in this case we do not fill in the result selection vector at all + if (!OP::Operation(*ldata, *rdata, ConstantVector::IsNull(left), ConstantVector::IsNull(right))) { + if (false_sel) { + for (idx_t i = 0; i < count; i++) { + false_sel->set_index(i, sel->get_index(i)); + } + } + return 0; + } else { + if (true_sel) { + for (idx_t i = 0; i < count; i++) { + true_sel->set_index(i, sel->get_index(i)); + } + } + return count; + } +} + +template +static idx_t DistinctSelect(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + if (!sel) { + sel = FlatVector::IncrementalSelectionVector(); + } + if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return DistinctSelectConstant(left, right, sel, count, true_sel, false_sel); + } else if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && + right.GetVectorType() == VectorType::FLAT_VECTOR) { + return DistinctSelectFlat(left, right, sel, count, true_sel, false_sel); + } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && + right.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return DistinctSelectFlat(left, right, sel, count, true_sel, false_sel); + } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && right.GetVectorType() == VectorType::FLAT_VECTOR) { + return DistinctSelectFlat(left, right, sel, count, true_sel, + false_sel); + } else { + return DistinctSelectGeneric(left, right, sel, count, true_sel, false_sel); + } +} + +template +static idx_t DistinctSelectNotNull(Vector &left, Vector &right, const idx_t count, idx_t &true_count, + const SelectionVector &sel, SelectionVector &maybe_vec, OptionalSelection &true_opt, + OptionalSelection &false_opt) { + UnifiedVectorFormat lvdata, rvdata; + left.ToUnifiedFormat(count, lvdata); + right.ToUnifiedFormat(count, rvdata); + + auto &lmask = lvdata.validity; + auto &rmask = rvdata.validity; + + idx_t remaining = 0; + if (lmask.AllValid() && rmask.AllValid()) { + // None are NULL, distinguish values. + for (idx_t i = 0; i < count; ++i) { + const auto idx = sel.get_index(i); + maybe_vec.set_index(remaining++, idx); + } + return remaining; + } + + // Slice the Vectors down to the rows that are not determined (i.e., neither is NULL) + SelectionVector slicer(count); + true_count = 0; + idx_t false_count = 0; + for (idx_t i = 0; i < count; ++i) { + const auto result_idx = sel.get_index(i); + const auto lidx = lvdata.sel->get_index(i); + const auto ridx = rvdata.sel->get_index(i); + const auto lnull = !lmask.RowIsValid(lidx); + const auto rnull = !rmask.RowIsValid(ridx); + if (lnull || rnull) { + // If either is NULL then we can major distinguish them + if (!OP::Operation(false, false, lnull, rnull)) { + false_opt.Append(false_count, result_idx); + } else { + true_opt.Append(true_count, result_idx); + } + } else { + // Neither is NULL, distinguish values. + slicer.set_index(remaining, i); + maybe_vec.set_index(remaining++, result_idx); + } + } + + true_opt.Advance(true_count); + false_opt.Advance(false_count); + + if (remaining && remaining < count) { + left.Slice(slicer, remaining); + right.Slice(slicer, remaining); + } + + return remaining; +} + +struct PositionComparator { + // Select the rows that definitely match. + // Default to the same as the final row + template + static idx_t Definite(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector &false_sel) { + return Final(left, right, sel, count, true_sel, &false_sel); + } + + // Select the possible rows that need further testing. + // Usually this means Is Not Distinct, as those are the semantics used by Postges + template + static idx_t Possible(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector &true_sel, SelectionVector *false_sel) { + return VectorOperations::NestedEquals(left, right, sel, count, &true_sel, false_sel); + } + + // Select the matching rows for the final position. + // This needs to be specialised. + template + static idx_t Final(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return 0; + } + + // Tie-break based on length when one of the sides has been exhausted, returning true if the LHS matches. + // This essentially means that the existing positions compare equal. + // Default to the same semantics as the OP for idx_t. This works in most cases. + template + static bool TieBreak(const idx_t lpos, const idx_t rpos) { + return OP::Operation(lpos, rpos, false, false); + } +}; + +// NotDistinctFrom must always check every column +template <> +idx_t PositionComparator::Definite(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector &false_sel) { + return 0; +} + +template <> +idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::NestedEquals(left, right, sel, count, true_sel, false_sel); +} + +// DistinctFrom must check everything that matched +template <> +idx_t PositionComparator::Possible(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector &true_sel, + SelectionVector *false_sel) { + return count; +} + +template <> +idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::NestedNotEquals(left, right, sel, count, true_sel, false_sel); +} + +// Non-strict inequalities must use strict comparisons for Definite +template <> +idx_t PositionComparator::Definite(Vector &left, Vector &right, + const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, + SelectionVector &false_sel) { + return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, &false_sel); +} + +template <> +idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThanEquals(right, left, &sel, count, true_sel, false_sel); +} + +template <> +idx_t PositionComparator::Definite(Vector &left, Vector &right, + const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, + SelectionVector &false_sel) { + return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, &false_sel); +} + +template <> +idx_t PositionComparator::Final(Vector &left, Vector &right, + const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); +} + +// Strict inequalities just use strict for both Definite and Final +template <> +idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, false_sel); +} + +template <> +idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); +} + +using StructEntries = vector>; + +static void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, + OptionalSelection &opt) { + + for (idx_t i = 0; i < count;) { + const auto slice_idx = slice_sel.get_index(i); + const auto result_idx = sel.get_index(slice_idx); + opt.Append(i, result_idx); + } + opt.Advance(count); +} + +static void DensifyNestedSelection(const SelectionVector &dense_sel, const idx_t count, SelectionVector &slice_sel) { + for (idx_t i = 0; i < count; ++i) { + slice_sel.set_index(i, dense_sel.get_index(i)); + } +} + +template +static idx_t DistinctSelectStruct(Vector &left, Vector &right, idx_t count, const SelectionVector &sel, + OptionalSelection &true_opt, OptionalSelection &false_opt) { + if (count == 0) { + return 0; + } + + // Avoid allocating in the 99% of the cases where we don't need to. + StructEntries lsliced, rsliced; + auto &lchildren = StructVector::GetEntries(left); + auto &rchildren = StructVector::GetEntries(right); + D_ASSERT(lchildren.size() == rchildren.size()); + + // In order to reuse the comparators, we have to track what passed and failed internally. + // To do that, we need local SVs that we then merge back into the real ones after every pass. + const auto vcount = count; + SelectionVector slice_sel(count); + for (idx_t i = 0; i < count; ++i) { + slice_sel.set_index(i, i); + } + + SelectionVector true_sel(count); + SelectionVector false_sel(count); + + idx_t match_count = 0; + for (idx_t col_no = 0; col_no < lchildren.size(); ++col_no) { + // Slice the children to maintain density + Vector lchild(*lchildren[col_no]); + lchild.Flatten(vcount); + lchild.Slice(slice_sel, count); + + Vector rchild(*rchildren[col_no]); + rchild.Flatten(vcount); + rchild.Slice(slice_sel, count); + + // Find everything that definitely matches + auto true_count = PositionComparator::Definite(lchild, rchild, slice_sel, count, &true_sel, false_sel); + if (true_count > 0) { + auto false_count = count - true_count; + + // Extract the definite matches into the true result + ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); + + // Remove the definite matches from the slicing vector + DensifyNestedSelection(false_sel, false_count, slice_sel); + + match_count += true_count; + count -= true_count; + } + + if (col_no != lchildren.size() - 1) { + // Find what might match on the next position + true_count = PositionComparator::Possible(lchild, rchild, slice_sel, count, true_sel, &false_sel); + auto false_count = count - true_count; + + // Extract the definite failures into the false result + ExtractNestedSelection(true_count ? false_sel : slice_sel, false_count, sel, false_opt); + + // Remove any definite failures from the slicing vector + if (false_count) { + DensifyNestedSelection(true_sel, true_count, slice_sel); + } + + count = true_count; + } else { + true_count = PositionComparator::Final(lchild, rchild, slice_sel, count, &true_sel, &false_sel); + auto false_count = count - true_count; + + // Extract the definite matches into the true result + ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); + + // Extract the definite failures into the false result + ExtractNestedSelection(true_count ? false_sel : slice_sel, false_count, sel, false_opt); + + match_count += true_count; + } + } + return match_count; +} + +static void PositionListCursor(SelectionVector &cursor, UnifiedVectorFormat &vdata, const idx_t pos, + const SelectionVector &slice_sel, const idx_t count) { + const auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; ++i) { + const auto slice_idx = slice_sel.get_index(i); + + const auto lidx = vdata.sel->get_index(slice_idx); + const auto &entry = data[lidx]; + cursor.set_index(i, entry.offset + pos); + } +} + +template +static idx_t DistinctSelectList(Vector &left, Vector &right, idx_t count, const SelectionVector &sel, + OptionalSelection &true_opt, OptionalSelection &false_opt) { + if (count == 0) { + return count; + } + + // Create dictionary views of the children so we can vectorise the positional comparisons. + SelectionVector lcursor(count); + SelectionVector rcursor(count); + + Vector lentry_flattened(ListVector::GetEntry(left)); + Vector rentry_flattened(ListVector::GetEntry(right)); + lentry_flattened.Flatten(ListVector::GetListSize(left)); + rentry_flattened.Flatten(ListVector::GetListSize(right)); + Vector lchild(lentry_flattened, lcursor, count); + Vector rchild(rentry_flattened, rcursor, count); + + // To perform the positional comparison, we use a vectorisation of the following algorithm: + // bool CompareLists(T *left, idx_t nleft, T *right, nright) { + // for (idx_t pos = 0; ; ++pos) { + // if (nleft == pos || nright == pos) + // return OP::TieBreak(nleft, nright); + // if (OP::Definite(*left, *right)) + // return true; + // if (!OP::Maybe(*left, *right)) + // return false; + // } + // ++left, ++right; + // } + // } + + // Get pointers to the list entries + UnifiedVectorFormat lvdata; + left.ToUnifiedFormat(count, lvdata); + const auto ldata = UnifiedVectorFormat::GetData(lvdata); + + UnifiedVectorFormat rvdata; + right.ToUnifiedFormat(count, rvdata); + const auto rdata = UnifiedVectorFormat::GetData(rvdata); + + // In order to reuse the comparators, we have to track what passed and failed internally. + // To do that, we need local SVs that we then merge back into the real ones after every pass. + SelectionVector slice_sel(count); + for (idx_t i = 0; i < count; ++i) { + slice_sel.set_index(i, i); + } + + SelectionVector true_sel(count); + SelectionVector false_sel(count); + + idx_t match_count = 0; + for (idx_t pos = 0; count > 0; ++pos) { + // Set up the cursors for the current position + PositionListCursor(lcursor, lvdata, pos, slice_sel, count); + PositionListCursor(rcursor, rvdata, pos, slice_sel, count); + + // Tie-break the pairs where one of the LISTs is exhausted. + idx_t true_count = 0; + idx_t false_count = 0; + idx_t maybe_count = 0; + for (idx_t i = 0; i < count; ++i) { + const auto slice_idx = slice_sel.get_index(i); + const auto lidx = lvdata.sel->get_index(slice_idx); + const auto &lentry = ldata[lidx]; + const auto ridx = rvdata.sel->get_index(slice_idx); + const auto &rentry = rdata[ridx]; + if (lentry.length == pos || rentry.length == pos) { + const auto idx = sel.get_index(slice_idx); + if (PositionComparator::TieBreak(lentry.length, rentry.length)) { + true_opt.Append(true_count, idx); + } else { + false_opt.Append(false_count, idx); + } + } else { + true_sel.set_index(maybe_count++, slice_idx); + } + } + true_opt.Advance(true_count); + false_opt.Advance(false_count); + match_count += true_count; + + // Redensify the list cursors + if (maybe_count < count) { + count = maybe_count; + DensifyNestedSelection(true_sel, count, slice_sel); + PositionListCursor(lcursor, lvdata, pos, slice_sel, count); + PositionListCursor(rcursor, rvdata, pos, slice_sel, count); + } + + // Find everything that definitely matches + true_count = PositionComparator::Definite(lchild, rchild, slice_sel, count, &true_sel, false_sel); + if (true_count) { + false_count = count - true_count; + ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); + match_count += true_count; + + // Redensify the list cursors + count -= true_count; + DensifyNestedSelection(false_sel, count, slice_sel); + PositionListCursor(lcursor, lvdata, pos, slice_sel, count); + PositionListCursor(rcursor, rvdata, pos, slice_sel, count); + } + + // Find what might match on the next position + true_count = PositionComparator::Possible(lchild, rchild, slice_sel, count, true_sel, &false_sel); + false_count = count - true_count; + ExtractNestedSelection(true_count ? false_sel : slice_sel, false_count, sel, false_opt); + + if (false_count) { + DensifyNestedSelection(true_sel, true_count, slice_sel); + } + count = true_count; + } + + return match_count; +} + +template +static idx_t DistinctSelectNested(Vector &left, Vector &right, const SelectionVector *sel, const idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + // The Select operations all use a dense pair of input vectors to partition + // a selection vector in a single pass. But to implement progressive comparisons, + // we have to make multiple passes, so we need to keep track of the original input positions + // and then scatter the output selections when we are done. + if (!sel) { + sel = FlatVector::IncrementalSelectionVector(); + } + + // Make buffered selections for progressive comparisons + // TODO: Remove unnecessary allocations + SelectionVector true_vec(count); + OptionalSelection true_opt(&true_vec); + + SelectionVector false_vec(count); + OptionalSelection false_opt(&false_vec); + + SelectionVector maybe_vec(count); + + // Handle NULL nested values + Vector l_not_null(left); + Vector r_not_null(right); + + idx_t match_count = 0; + auto unknown = + DistinctSelectNotNull(l_not_null, r_not_null, count, match_count, *sel, maybe_vec, true_opt, false_opt); + + if (PhysicalType::LIST == left.GetType().InternalType()) { + match_count += DistinctSelectList(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt); + } else { + match_count += DistinctSelectStruct(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt); + } + + // Copy the buffered selections to the output selections + if (true_sel) { + DensifyNestedSelection(true_vec, match_count, *true_sel); + } + + if (false_sel) { + DensifyNestedSelection(false_vec, count - match_count, *false_sel); + } + + return match_count; +} + +template +static void NestedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count); + +template +static inline void TemplatedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { + DistinctExecute(left, right, result, count); +} +template +static void ExecuteDistinct(Vector &left, Vector &right, Vector &result, idx_t count) { + D_ASSERT(left.GetType() == right.GetType() && result.GetType() == LogicalType::BOOLEAN); + // the inplace loops take the result as the last parameter + switch (left.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::INT16: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::INT32: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::INT64: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::UINT8: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::UINT16: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::UINT32: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::UINT64: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::INT128: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::FLOAT: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::DOUBLE: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::INTERVAL: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::VARCHAR: + TemplatedDistinctExecute(left, right, result, count); + break; + case PhysicalType::LIST: + case PhysicalType::STRUCT: + NestedDistinctExecute(left, right, result, count); + break; + default: + throw InternalException("Invalid type for distinct comparison"); + } +} + +template +static idx_t TemplatedDistinctSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + // the inplace loops take the result as the last parameter + switch (left.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT16: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT32: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT64: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT8: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT16: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT32: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT64: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT128: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::FLOAT: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::DOUBLE: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INTERVAL: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::VARCHAR: + return DistinctSelect(left, right, sel, count, true_sel, false_sel); + case PhysicalType::STRUCT: + case PhysicalType::LIST: + return DistinctSelectNested(left, right, sel, count, true_sel, false_sel); + default: + throw InternalException("Invalid type for distinct selection"); + } +} + +template +static void NestedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { + const auto left_constant = left.GetVectorType() == VectorType::CONSTANT_VECTOR; + const auto right_constant = right.GetVectorType() == VectorType::CONSTANT_VECTOR; + + if (left_constant && right_constant) { + // both sides are constant, so just compare one element. + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto result_data = ConstantVector::GetData(result); + SelectionVector true_sel(1); + auto match_count = TemplatedDistinctSelectOperation(left, right, nullptr, 1, &true_sel, nullptr); + result_data[0] = match_count > 0; + return; + } + + SelectionVector true_sel(count); + SelectionVector false_sel(count); + + // DISTINCT is either true or false + idx_t match_count = TemplatedDistinctSelectOperation(left, right, nullptr, count, &true_sel, &false_sel); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + + for (idx_t i = 0; i < match_count; ++i) { + const auto idx = true_sel.get_index(i); + result_data[idx] = true; + } + + const idx_t no_match_count = count - match_count; + for (idx_t i = 0; i < no_match_count; ++i) { + const auto idx = false_sel.get_index(i); + result_data[idx] = false; + } +} + +void VectorOperations::DistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { + ExecuteDistinct(left, right, result, count); +} + +void VectorOperations::NotDistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { + ExecuteDistinct(left, right, result, count); +} + +// true := A != B with nulls being equal +idx_t VectorOperations::DistinctFrom(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, false_sel); +} +// true := A == B with nulls being equal +idx_t VectorOperations::NotDistinctFrom(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return count - TemplatedDistinctSelectOperation(left, right, sel, count, false_sel, true_sel); +} + +// true := A > B with nulls being maximal +idx_t VectorOperations::DistinctGreaterThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, false_sel); +} + +// true := A > B with nulls being minimal +idx_t VectorOperations::DistinctGreaterThanNullsFirst(Vector &left, Vector &right, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return TemplatedDistinctSelectOperation( + left, right, sel, count, true_sel, false_sel); +} +// true := A >= B with nulls being maximal +idx_t VectorOperations::DistinctGreaterThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return count - + TemplatedDistinctSelectOperation(right, left, sel, count, false_sel, true_sel); +} +// true := A < B with nulls being maximal +idx_t VectorOperations::DistinctLessThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedDistinctSelectOperation(right, left, sel, count, true_sel, false_sel); +} + +// true := A < B with nulls being minimal +idx_t VectorOperations::DistinctLessThanNullsFirst(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedDistinctSelectOperation( + right, left, sel, count, true_sel, false_sel); +} + +// true := A <= B with nulls being maximal +idx_t VectorOperations::DistinctLessThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return count - + TemplatedDistinctSelectOperation(left, right, sel, count, false_sel, true_sel); +} + +// true := A != B with nulls being equal, inputs selected +idx_t VectorOperations::NestedNotEquals(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedDistinctSelectOperation(left, right, &sel, count, true_sel, false_sel); +} +// true := A == B with nulls being equal, inputs selected +idx_t VectorOperations::NestedEquals(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return count - + TemplatedDistinctSelectOperation(left, right, &sel, count, false_sel, true_sel); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/null_operations.cpp b/src/duckdb/src/common/vector_operations/null_operations.cpp new file mode 100644 index 00000000..48bc904d --- /dev/null +++ b/src/duckdb/src/common/vector_operations/null_operations.cpp @@ -0,0 +1,113 @@ +//===--------------------------------------------------------------------===// +// null_operators.cpp +// Description: This file contains the implementation of the +// IS NULL/NOT IS NULL operators +//===--------------------------------------------------------------------===// + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +namespace duckdb { + +template +void IsNullLoop(Vector &input, Vector &result, idx_t count) { + D_ASSERT(result.GetType() == LogicalType::BOOLEAN); + + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto result_data = ConstantVector::GetData(result); + *result_data = INVERSE ? !ConstantVector::IsNull(input) : ConstantVector::IsNull(input); + } else { + UnifiedVectorFormat data; + input.ToUnifiedFormat(count, data); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto idx = data.sel->get_index(i); + result_data[i] = INVERSE ? data.validity.RowIsValid(idx) : !data.validity.RowIsValid(idx); + } + } +} + +void VectorOperations::IsNotNull(Vector &input, Vector &result, idx_t count) { + IsNullLoop(input, result, count); +} + +void VectorOperations::IsNull(Vector &input, Vector &result, idx_t count) { + IsNullLoop(input, result, count); +} + +bool VectorOperations::HasNotNull(Vector &input, idx_t count) { + if (count == 0) { + return false; + } + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return !ConstantVector::IsNull(input); + } else { + UnifiedVectorFormat data; + input.ToUnifiedFormat(count, data); + + if (data.validity.AllValid()) { + return true; + } + for (idx_t i = 0; i < count; i++) { + auto idx = data.sel->get_index(i); + if (data.validity.RowIsValid(idx)) { + return true; + } + } + return false; + } +} + +bool VectorOperations::HasNull(Vector &input, idx_t count) { + if (count == 0) { + return false; + } + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return ConstantVector::IsNull(input); + } else { + UnifiedVectorFormat data; + input.ToUnifiedFormat(count, data); + + if (data.validity.AllValid()) { + return false; + } + for (idx_t i = 0; i < count; i++) { + auto idx = data.sel->get_index(i); + if (!data.validity.RowIsValid(idx)) { + return true; + } + } + return false; + } +} + +idx_t VectorOperations::CountNotNull(Vector &input, const idx_t count) { + idx_t valid = 0; + + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + if (vdata.validity.AllValid()) { + return count; + } + switch (input.GetVectorType()) { + case VectorType::FLAT_VECTOR: + valid += vdata.validity.CountValid(count); + break; + case VectorType::CONSTANT_VECTOR: + valid += vdata.validity.CountValid(1) * count; + break; + default: + for (idx_t i = 0; i < count; ++i) { + const auto row_idx = vdata.sel->get_index(i); + valid += int(vdata.validity.RowIsValid(row_idx)); + } + break; + } + + return valid; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp b/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp new file mode 100644 index 00000000..d2bd0f31 --- /dev/null +++ b/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp @@ -0,0 +1,40 @@ +//===--------------------------------------------------------------------===// +// numeric_inplace_operators.cpp +// Description: This file contains the implementation of numeric inplace ops +// += *= /= -= %= +//===--------------------------------------------------------------------===// + +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// In-Place Addition +//===--------------------------------------------------------------------===// + +void VectorOperations::AddInPlace(Vector &input, int64_t right, idx_t count) { + D_ASSERT(input.GetType().id() == LogicalTypeId::POINTER); + if (right == 0) { + return; + } + switch (input.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + D_ASSERT(!ConstantVector::IsNull(input)); + auto data = ConstantVector::GetData(input); + *data += right; + break; + } + default: { + D_ASSERT(input.GetVectorType() == VectorType::FLAT_VECTOR); + auto data = FlatVector::GetData(input); + for (idx_t i = 0; i < count; i++) { + data[i] += right; + } + break; + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/vector_cast.cpp b/src/duckdb/src/common/vector_operations/vector_cast.cpp new file mode 100644 index 00000000..3893f74a --- /dev/null +++ b/src/duckdb/src/common/vector_operations/vector_cast.cpp @@ -0,0 +1,43 @@ +#include "duckdb/common/limits.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/function/scalar_function.hpp" + +namespace duckdb { + +bool VectorOperations::TryCast(CastFunctionSet &set, GetCastFunctionInput &input, Vector &source, Vector &result, + idx_t count, string *error_message, bool strict) { + auto cast_function = set.GetCastFunction(source.GetType(), result.GetType(), input); + unique_ptr local_state; + if (cast_function.init_local_state) { + CastLocalStateParameters lparameters(input.context, cast_function.cast_data); + local_state = cast_function.init_local_state(lparameters); + } + CastParameters parameters(cast_function.cast_data.get(), strict, error_message, local_state.get()); + return cast_function.function(source, result, count, parameters); +} + +bool VectorOperations::DefaultTryCast(Vector &source, Vector &result, idx_t count, string *error_message, bool strict) { + CastFunctionSet set; + GetCastFunctionInput input; + return VectorOperations::TryCast(set, input, source, result, count, error_message, strict); +} + +void VectorOperations::DefaultCast(Vector &source, Vector &result, idx_t count, bool strict) { + VectorOperations::DefaultTryCast(source, result, count, nullptr, strict); +} + +bool VectorOperations::TryCast(ClientContext &context, Vector &source, Vector &result, idx_t count, + string *error_message, bool strict) { + auto &config = DBConfig::GetConfig(context); + auto &set = config.GetCastFunctions(); + GetCastFunctionInput get_input(context); + return VectorOperations::TryCast(set, get_input, source, result, count, error_message, strict); +} + +void VectorOperations::Cast(ClientContext &context, Vector &source, Vector &result, idx_t count, bool strict) { + VectorOperations::TryCast(context, source, result, count, nullptr, strict); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/vector_copy.cpp b/src/duckdb/src/common/vector_operations/vector_copy.cpp new file mode 100644 index 00000000..4f1870de --- /dev/null +++ b/src/duckdb/src/common/vector_operations/vector_copy.cpp @@ -0,0 +1,271 @@ +//===--------------------------------------------------------------------===// +// copy.cpp +// Description: This file contains the implementation of the different copy +// functions +//===--------------------------------------------------------------------===// + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/storage/segment/uncompressed.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +namespace duckdb { + +template +static void TemplatedCopy(const Vector &source, const SelectionVector &sel, Vector &target, idx_t source_offset, + idx_t target_offset, idx_t copy_count) { + auto ldata = FlatVector::GetData(source); + auto tdata = FlatVector::GetData(target); + for (idx_t i = 0; i < copy_count; i++) { + auto source_idx = sel.get_index(source_offset + i); + tdata[target_offset + i] = ldata[source_idx]; + } +} + +static const ValidityMask &CopyValidityMask(const Vector &v) { + switch (v.GetVectorType()) { + case VectorType::FLAT_VECTOR: + return FlatVector::Validity(v); + case VectorType::FSST_VECTOR: + return FSSTVector::Validity(v); + default: + throw InternalException("Unsupported vector type in vector copy"); + } +} + +void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, + idx_t source_offset, idx_t target_offset) { + D_ASSERT(source_offset <= source_count); + D_ASSERT(source_p.GetType() == target.GetType()); + idx_t copy_count = source_count - source_offset; + + SelectionVector owned_sel; + const SelectionVector *sel = &sel_p; + + const Vector *source = &source_p; + bool finished = false; + while (!finished) { + switch (source->GetVectorType()) { + case VectorType::DICTIONARY_VECTOR: { + // dictionary vector: merge selection vectors + auto &child = DictionaryVector::Child(*source); + auto &dict_sel = DictionaryVector::SelVector(*source); + // merge the selection vectors and verify the child + auto new_buffer = dict_sel.Slice(*sel, source_count); + owned_sel.Initialize(new_buffer); + sel = &owned_sel; + source = &child; + break; + } + case VectorType::SEQUENCE_VECTOR: { + int64_t start, increment; + Vector seq(source->GetType()); + SequenceVector::GetSequence(*source, start, increment); + VectorOperations::GenerateSequence(seq, source_count, *sel, start, increment); + VectorOperations::Copy(seq, target, *sel, source_count, source_offset, target_offset); + return; + } + case VectorType::CONSTANT_VECTOR: + sel = ConstantVector::ZeroSelectionVector(copy_count, owned_sel); + finished = true; + break; + case VectorType::FSST_VECTOR: + finished = true; + break; + case VectorType::FLAT_VECTOR: + finished = true; + break; + default: + throw NotImplementedException("FIXME unimplemented vector type for VectorOperations::Copy"); + } + } + + if (copy_count == 0) { + return; + } + + // Allow copying of a single value to constant vectors + const auto target_vector_type = target.GetVectorType(); + if (copy_count == 1 && target_vector_type == VectorType::CONSTANT_VECTOR) { + target_offset = 0; + target.SetVectorType(VectorType::FLAT_VECTOR); + } + D_ASSERT(target.GetVectorType() == VectorType::FLAT_VECTOR); + + // first copy the nullmask + auto &tmask = FlatVector::Validity(target); + if (source->GetVectorType() == VectorType::CONSTANT_VECTOR) { + const bool valid = !ConstantVector::IsNull(*source); + for (idx_t i = 0; i < copy_count; i++) { + tmask.Set(target_offset + i, valid); + } + } else { + auto &smask = CopyValidityMask(*source); + if (smask.IsMaskSet()) { + for (idx_t i = 0; i < copy_count; i++) { + auto idx = sel->get_index(source_offset + i); + + if (smask.RowIsValid(idx)) { + // set valid + if (!tmask.AllValid()) { + tmask.SetValidUnsafe(target_offset + i); + } + } else { + // set invalid + if (tmask.AllValid()) { + auto init_size = MaxValue(STANDARD_VECTOR_SIZE, target_offset + copy_count); + tmask.Initialize(init_size); + } + tmask.SetInvalidUnsafe(target_offset + i); + } + } + } + } + + D_ASSERT(sel); + + // For FSST Vectors we decompress instead of copying. + if (source->GetVectorType() == VectorType::FSST_VECTOR) { + FSSTVector::DecompressVector(*source, target, source_offset, target_offset, copy_count, sel); + return; + } + + // now copy over the data + switch (source->GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::INT16: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::INT32: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::INT64: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::UINT8: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::UINT16: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::UINT32: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::UINT64: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::INT128: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::FLOAT: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::DOUBLE: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::INTERVAL: + TemplatedCopy(*source, *sel, target, source_offset, target_offset, copy_count); + break; + case PhysicalType::VARCHAR: { + auto ldata = FlatVector::GetData(*source); + auto tdata = FlatVector::GetData(target); + for (idx_t i = 0; i < copy_count; i++) { + auto source_idx = sel->get_index(source_offset + i); + auto target_idx = target_offset + i; + if (tmask.RowIsValid(target_idx)) { + tdata[target_idx] = StringVector::AddStringOrBlob(target, ldata[source_idx]); + } + } + break; + } + case PhysicalType::STRUCT: { + auto &source_children = StructVector::GetEntries(*source); + auto &target_children = StructVector::GetEntries(target); + D_ASSERT(source_children.size() == target_children.size()); + for (idx_t i = 0; i < source_children.size(); i++) { + VectorOperations::Copy(*source_children[i], *target_children[i], sel_p, source_count, source_offset, + target_offset); + } + break; + } + case PhysicalType::LIST: { + D_ASSERT(target.GetType().InternalType() == PhysicalType::LIST); + + auto &source_child = ListVector::GetEntry(*source); + auto sdata = FlatVector::GetData(*source); + auto tdata = FlatVector::GetData(target); + + if (target_vector_type == VectorType::CONSTANT_VECTOR) { + // If we are only writing one value, then the copied values (if any) are contiguous + // and we can just Append from the offset position + if (!tmask.RowIsValid(target_offset)) { + break; + } + auto source_idx = sel->get_index(source_offset); + auto &source_entry = sdata[source_idx]; + const idx_t source_child_size = source_entry.length + source_entry.offset; + + //! overwrite constant target vectors. + ListVector::SetListSize(target, 0); + ListVector::Append(target, source_child, source_child_size, source_entry.offset); + + auto &target_entry = tdata[target_offset]; + target_entry.length = source_entry.length; + target_entry.offset = 0; + } else { + //! if the source has list offsets, we need to append them to the target + //! build a selection vector for the copied child elements + vector child_rows; + for (idx_t i = 0; i < copy_count; ++i) { + if (tmask.RowIsValid(target_offset + i)) { + auto source_idx = sel->get_index(source_offset + i); + auto &source_entry = sdata[source_idx]; + for (idx_t j = 0; j < source_entry.length; ++j) { + child_rows.emplace_back(source_entry.offset + j); + } + } + } + idx_t source_child_size = child_rows.size(); + SelectionVector child_sel(child_rows.data()); + + idx_t old_target_child_len = ListVector::GetListSize(target); + + //! append to list itself + ListVector::Append(target, source_child, child_sel, source_child_size); + + //! now write the list offsets + for (idx_t i = 0; i < copy_count; i++) { + auto source_idx = sel->get_index(source_offset + i); + auto &source_entry = sdata[source_idx]; + auto &target_entry = tdata[target_offset + i]; + + target_entry.length = source_entry.length; + target_entry.offset = old_target_child_len; + if (tmask.RowIsValid(target_offset + i)) { + old_target_child_len += target_entry.length; + } + } + } + break; + } + default: + throw NotImplementedException("Unimplemented type '%s' for copy!", + TypeIdToString(source->GetType().InternalType())); + } + + if (target_vector_type != VectorType::FLAT_VECTOR) { + target.SetVectorType(target_vector_type); + } +} + +void VectorOperations::Copy(const Vector &source, Vector &target, idx_t source_count, idx_t source_offset, + idx_t target_offset) { + VectorOperations::Copy(source, target, *FlatVector::IncrementalSelectionVector(), source_count, source_offset, + target_offset); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/vector_hash.cpp b/src/duckdb/src/common/vector_operations/vector_hash.cpp new file mode 100644 index 00000000..f489d8f2 --- /dev/null +++ b/src/duckdb/src/common/vector_operations/vector_hash.cpp @@ -0,0 +1,376 @@ +//===--------------------------------------------------------------------===// +// hash.cpp +// Description: This file contains the vectorized hash implementations +//===--------------------------------------------------------------------===// + +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" + +namespace duckdb { + +struct HashOp { + static const hash_t NULL_HASH = 0xbf58476d1ce4e5b9; + + template + static inline hash_t Operation(T input, bool is_null) { + return is_null ? NULL_HASH : duckdb::Hash(input); + } +}; + +static inline hash_t CombineHashScalar(hash_t a, hash_t b) { + return (a * UINT64_C(0xbf58476d1ce4e5b9)) ^ b; +} + +template +static inline void TightLoopHash(const T *__restrict ldata, hash_t *__restrict result_data, const SelectionVector *rsel, + idx_t count, const SelectionVector *__restrict sel_vector, ValidityMask &mask) { + if (!mask.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto ridx = HAS_RSEL ? rsel->get_index(i) : i; + auto idx = sel_vector->get_index(ridx); + result_data[ridx] = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); + } + } else { + for (idx_t i = 0; i < count; i++) { + auto ridx = HAS_RSEL ? rsel->get_index(i) : i; + auto idx = sel_vector->get_index(ridx); + result_data[ridx] = duckdb::Hash(ldata[idx]); + } + } +} + +template +static inline void TemplatedLoopHash(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + auto ldata = ConstantVector::GetData(input); + auto result_data = ConstantVector::GetData(result); + *result_data = HashOp::Operation(*ldata, ConstantVector::IsNull(input)); + } else { + result.SetVectorType(VectorType::FLAT_VECTOR); + + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + TightLoopHash(UnifiedVectorFormat::GetData(idata), FlatVector::GetData(result), rsel, + count, idata.sel, idata.validity); + } +} + +template +static inline void StructLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { + auto &children = StructVector::GetEntries(input); + + D_ASSERT(!children.empty()); + idx_t col_no = 0; + if (HAS_RSEL) { + if (FIRST_HASH) { + VectorOperations::Hash(*children[col_no++], hashes, *rsel, count); + } else { + VectorOperations::CombineHash(hashes, *children[col_no++], *rsel, count); + } + while (col_no < children.size()) { + VectorOperations::CombineHash(hashes, *children[col_no++], *rsel, count); + } + } else { + if (FIRST_HASH) { + VectorOperations::Hash(*children[col_no++], hashes, count); + } else { + VectorOperations::CombineHash(hashes, *children[col_no++], count); + } + while (col_no < children.size()) { + VectorOperations::CombineHash(hashes, *children[col_no++], count); + } + } +} + +template +static inline void ListLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { + auto hdata = FlatVector::GetData(hashes); + + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + const auto ldata = UnifiedVectorFormat::GetData(idata); + + // Hash the children into a temporary + auto &child = ListVector::GetEntry(input); + const auto child_count = ListVector::GetListSize(input); + + Vector child_hashes(LogicalType::HASH, child_count); + if (child_count > 0) { + VectorOperations::Hash(child, child_hashes, child_count); + child_hashes.Flatten(child_count); + } + auto chdata = FlatVector::GetData(child_hashes); + + // Reduce the number of entries to check to the non-empty ones + SelectionVector unprocessed(count); + SelectionVector cursor(HAS_RSEL ? STANDARD_VECTOR_SIZE : count); + idx_t remaining = 0; + for (idx_t i = 0; i < count; ++i) { + const idx_t ridx = HAS_RSEL ? rsel->get_index(i) : i; + const auto lidx = idata.sel->get_index(ridx); + const auto &entry = ldata[lidx]; + if (idata.validity.RowIsValid(lidx) && entry.length > 0) { + unprocessed.set_index(remaining++, ridx); + cursor.set_index(ridx, entry.offset); + } else if (FIRST_HASH) { + hdata[ridx] = HashOp::NULL_HASH; + } + // Empty or NULL non-first elements have no effect. + } + + count = remaining; + if (count == 0) { + return; + } + + // Merge the first position hash into the main hash + idx_t position = 1; + if (FIRST_HASH) { + remaining = 0; + for (idx_t i = 0; i < count; ++i) { + const auto ridx = unprocessed.get_index(i); + const auto cidx = cursor.get_index(ridx); + hdata[ridx] = chdata[cidx]; + + const auto lidx = idata.sel->get_index(ridx); + const auto &entry = ldata[lidx]; + if (entry.length > position) { + // Entry still has values to hash + unprocessed.set_index(remaining++, ridx); + cursor.set_index(ridx, cidx + 1); + } + } + count = remaining; + if (count == 0) { + return; + } + ++position; + } + + // Combine the hashes for the remaining positions until there are none left + for (;; ++position) { + remaining = 0; + for (idx_t i = 0; i < count; ++i) { + const auto ridx = unprocessed.get_index(i); + const auto cidx = cursor.get_index(ridx); + hdata[ridx] = CombineHashScalar(hdata[ridx], chdata[cidx]); + + const auto lidx = idata.sel->get_index(ridx); + const auto &entry = ldata[lidx]; + if (entry.length > position) { + // Entry still has values to hash + unprocessed.set_index(remaining++, ridx); + cursor.set_index(ridx, cidx + 1); + } + } + + count = remaining; + if (count == 0) { + break; + } + } +} + +template +static inline void HashTypeSwitch(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { + D_ASSERT(result.GetType().id() == LogicalType::HASH); + switch (input.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::INT16: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::INT32: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::INT64: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::UINT8: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::UINT16: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::UINT32: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::UINT64: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::INT128: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::FLOAT: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::DOUBLE: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::INTERVAL: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::VARCHAR: + TemplatedLoopHash(input, result, rsel, count); + break; + case PhysicalType::STRUCT: + StructLoopHash(input, result, rsel, count); + break; + case PhysicalType::LIST: + ListLoopHash(input, result, rsel, count); + break; + default: + throw InvalidTypeException(input.GetType(), "Invalid type for hash"); + } +} + +void VectorOperations::Hash(Vector &input, Vector &result, idx_t count) { + HashTypeSwitch(input, result, nullptr, count); +} + +void VectorOperations::Hash(Vector &input, Vector &result, const SelectionVector &sel, idx_t count) { + HashTypeSwitch(input, result, &sel, count); +} + +template +static inline void TightLoopCombineHashConstant(const T *__restrict ldata, hash_t constant_hash, + hash_t *__restrict hash_data, const SelectionVector *rsel, idx_t count, + const SelectionVector *__restrict sel_vector, ValidityMask &mask) { + if (!mask.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto ridx = HAS_RSEL ? rsel->get_index(i) : i; + auto idx = sel_vector->get_index(ridx); + auto other_hash = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); + hash_data[ridx] = CombineHashScalar(constant_hash, other_hash); + } + } else { + for (idx_t i = 0; i < count; i++) { + auto ridx = HAS_RSEL ? rsel->get_index(i) : i; + auto idx = sel_vector->get_index(ridx); + auto other_hash = duckdb::Hash(ldata[idx]); + hash_data[ridx] = CombineHashScalar(constant_hash, other_hash); + } + } +} + +template +static inline void TightLoopCombineHash(const T *__restrict ldata, hash_t *__restrict hash_data, + const SelectionVector *rsel, idx_t count, + const SelectionVector *__restrict sel_vector, ValidityMask &mask) { + if (!mask.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto ridx = HAS_RSEL ? rsel->get_index(i) : i; + auto idx = sel_vector->get_index(ridx); + auto other_hash = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); + hash_data[ridx] = CombineHashScalar(hash_data[ridx], other_hash); + } + } else { + for (idx_t i = 0; i < count; i++) { + auto ridx = HAS_RSEL ? rsel->get_index(i) : i; + auto idx = sel_vector->get_index(ridx); + auto other_hash = duckdb::Hash(ldata[idx]); + hash_data[ridx] = CombineHashScalar(hash_data[ridx], other_hash); + } + } +} + +template +void TemplatedLoopCombineHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR && hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { + auto ldata = ConstantVector::GetData(input); + auto hash_data = ConstantVector::GetData(hashes); + + auto other_hash = HashOp::Operation(*ldata, ConstantVector::IsNull(input)); + *hash_data = CombineHashScalar(*hash_data, other_hash); + } else { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + if (hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // mix constant with non-constant, first get the constant value + auto constant_hash = *ConstantVector::GetData(hashes); + // now re-initialize the hashes vector to an empty flat vector + hashes.SetVectorType(VectorType::FLAT_VECTOR); + TightLoopCombineHashConstant(UnifiedVectorFormat::GetData(idata), constant_hash, + FlatVector::GetData(hashes), rsel, count, idata.sel, + idata.validity); + } else { + D_ASSERT(hashes.GetVectorType() == VectorType::FLAT_VECTOR); + TightLoopCombineHash(UnifiedVectorFormat::GetData(idata), + FlatVector::GetData(hashes), rsel, count, idata.sel, + idata.validity); + } + } +} + +template +static inline void CombineHashTypeSwitch(Vector &hashes, Vector &input, const SelectionVector *rsel, idx_t count) { + D_ASSERT(hashes.GetType().id() == LogicalType::HASH); + switch (input.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::INT16: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::INT32: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::INT64: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::UINT8: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::UINT16: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::UINT32: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::UINT64: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::INT128: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::FLOAT: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::DOUBLE: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::INTERVAL: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::VARCHAR: + TemplatedLoopCombineHash(input, hashes, rsel, count); + break; + case PhysicalType::STRUCT: + StructLoopHash(input, hashes, rsel, count); + break; + case PhysicalType::LIST: + ListLoopHash(input, hashes, rsel, count); + break; + default: + throw InvalidTypeException(input.GetType(), "Invalid type for hash"); + } +} + +void VectorOperations::CombineHash(Vector &hashes, Vector &input, idx_t count) { + CombineHashTypeSwitch(hashes, input, nullptr, count); +} + +void VectorOperations::CombineHash(Vector &hashes, Vector &input, const SelectionVector &rsel, idx_t count) { + CombineHashTypeSwitch(hashes, input, &rsel, count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/vector_storage.cpp b/src/duckdb/src/common/vector_operations/vector_storage.cpp new file mode 100644 index 00000000..be7c97c7 --- /dev/null +++ b/src/duckdb/src/common/vector_operations/vector_storage.cpp @@ -0,0 +1,125 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +namespace duckdb { + +template +static void CopyToStorageLoop(UnifiedVectorFormat &vdata, idx_t count, data_ptr_t target) { + auto ldata = UnifiedVectorFormat::GetData(vdata); + auto result_data = (T *)target; + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(idx)) { + result_data[i] = NullValue(); + } else { + result_data[i] = ldata[idx]; + } + } +} + +void VectorOperations::WriteToStorage(Vector &source, idx_t count, data_ptr_t target) { + if (count == 0) { + return; + } + UnifiedVectorFormat vdata; + source.ToUnifiedFormat(count, vdata); + + switch (source.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::INT16: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::INT32: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::INT64: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::UINT8: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::UINT16: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::UINT32: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::UINT64: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::INT128: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::FLOAT: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::DOUBLE: + CopyToStorageLoop(vdata, count, target); + break; + case PhysicalType::INTERVAL: + CopyToStorageLoop(vdata, count, target); + break; + default: + throw NotImplementedException("Unimplemented type for WriteToStorage"); + } +} + +template +static void ReadFromStorageLoop(data_ptr_t source, idx_t count, Vector &result) { + auto ldata = (T *)source; + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + result_data[i] = ldata[i]; + } +} + +void VectorOperations::ReadFromStorage(data_ptr_t source, idx_t count, Vector &result) { + result.SetVectorType(VectorType::FLAT_VECTOR); + switch (result.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::INT16: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::INT32: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::INT64: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::UINT8: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::UINT16: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::UINT32: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::UINT64: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::INT128: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::FLOAT: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::DOUBLE: + ReadFromStorageLoop(source, count, result); + break; + case PhysicalType::INTERVAL: + ReadFromStorageLoop(source, count, result); + break; + default: + throw NotImplementedException("Unimplemented type for ReadFromStorage"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/virtual_file_system.cpp b/src/duckdb/src/common/virtual_file_system.cpp new file mode 100644 index 00000000..0aaff142 --- /dev/null +++ b/src/duckdb/src/common/virtual_file_system.cpp @@ -0,0 +1,186 @@ +#include "duckdb/common/virtual_file_system.hpp" + +#include "duckdb/common/gzip_file_system.hpp" +#include "duckdb/common/pipe_file_system.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +VirtualFileSystem::VirtualFileSystem() : default_fs(FileSystem::CreateLocal()) { + VirtualFileSystem::RegisterSubSystem(FileCompressionType::GZIP, make_uniq()); +} + +unique_ptr VirtualFileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, + FileCompressionType compression, FileOpener *opener) { + if (compression == FileCompressionType::AUTO_DETECT) { + // auto detect compression settings based on file name + auto lower_path = StringUtil::Lower(path); + if (StringUtil::EndsWith(lower_path, ".tmp")) { + // strip .tmp + lower_path = lower_path.substr(0, lower_path.length() - 4); + } + if (StringUtil::EndsWith(lower_path, ".gz")) { + compression = FileCompressionType::GZIP; + } else if (StringUtil::EndsWith(lower_path, ".zst")) { + compression = FileCompressionType::ZSTD; + } else { + compression = FileCompressionType::UNCOMPRESSED; + } + } + // open the base file handle + auto file_handle = FindFileSystem(path).OpenFile(path, flags, lock, FileCompressionType::UNCOMPRESSED, opener); + if (file_handle->GetType() == FileType::FILE_TYPE_FIFO) { + file_handle = PipeFileSystem::OpenPipe(std::move(file_handle)); + } else if (compression != FileCompressionType::UNCOMPRESSED) { + auto entry = compressed_fs.find(compression); + if (entry == compressed_fs.end()) { + throw NotImplementedException( + "Attempting to open a compressed file, but the compression type is not supported"); + } + file_handle = entry->second->OpenCompressedFile(std::move(file_handle), flags & FileFlags::FILE_FLAGS_WRITE); + } + return file_handle; +} + +void VirtualFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + handle.file_system.Read(handle, buffer, nr_bytes, location); +} + +void VirtualFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + handle.file_system.Write(handle, buffer, nr_bytes, location); +} + +int64_t VirtualFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + return handle.file_system.Read(handle, buffer, nr_bytes); +} + +int64_t VirtualFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + return handle.file_system.Write(handle, buffer, nr_bytes); +} + +int64_t VirtualFileSystem::GetFileSize(FileHandle &handle) { + return handle.file_system.GetFileSize(handle); +} +time_t VirtualFileSystem::GetLastModifiedTime(FileHandle &handle) { + return handle.file_system.GetLastModifiedTime(handle); +} +FileType VirtualFileSystem::GetFileType(FileHandle &handle) { + return handle.file_system.GetFileType(handle); +} + +void VirtualFileSystem::Truncate(FileHandle &handle, int64_t new_size) { + handle.file_system.Truncate(handle, new_size); +} + +void VirtualFileSystem::FileSync(FileHandle &handle) { + handle.file_system.FileSync(handle); +} + +// need to look up correct fs for this +bool VirtualFileSystem::DirectoryExists(const string &directory) { + return FindFileSystem(directory).DirectoryExists(directory); +} +void VirtualFileSystem::CreateDirectory(const string &directory) { + FindFileSystem(directory).CreateDirectory(directory); +} + +void VirtualFileSystem::RemoveDirectory(const string &directory) { + FindFileSystem(directory).RemoveDirectory(directory); +} + +bool VirtualFileSystem::ListFiles(const string &directory, const std::function &callback, + FileOpener *opener) { + return FindFileSystem(directory).ListFiles(directory, callback, opener); +} + +void VirtualFileSystem::MoveFile(const string &source, const string &target) { + FindFileSystem(source).MoveFile(source, target); +} + +bool VirtualFileSystem::FileExists(const string &filename) { + return FindFileSystem(filename).FileExists(filename); +} + +bool VirtualFileSystem::IsPipe(const string &filename) { + return FindFileSystem(filename).IsPipe(filename); +} +void VirtualFileSystem::RemoveFile(const string &filename) { + FindFileSystem(filename).RemoveFile(filename); +} + +string VirtualFileSystem::PathSeparator(const string &path) { + return FindFileSystem(path).PathSeparator(path); +} + +vector VirtualFileSystem::Glob(const string &path, FileOpener *opener) { + return FindFileSystem(path).Glob(path, opener); +} + +void VirtualFileSystem::RegisterSubSystem(unique_ptr fs) { + sub_systems.push_back(std::move(fs)); +} + +void VirtualFileSystem::UnregisterSubSystem(const string &name) { + for (auto sub_system = sub_systems.begin(); sub_system != sub_systems.end(); sub_system++) { + if (sub_system->get()->GetName() == name) { + sub_systems.erase(sub_system); + return; + } + } + throw InvalidInputException("Could not find filesystem with name %s", name); +} + +void VirtualFileSystem::RegisterSubSystem(FileCompressionType compression_type, unique_ptr fs) { + compressed_fs[compression_type] = std::move(fs); +} + +vector VirtualFileSystem::ListSubSystems() { + vector names(sub_systems.size()); + for (idx_t i = 0; i < sub_systems.size(); i++) { + names[i] = sub_systems[i]->GetName(); + } + return names; +} + +std::string VirtualFileSystem::GetName() const { + return "VirtualFileSystem"; +} + +void VirtualFileSystem::SetDisabledFileSystems(const vector &names) { + unordered_set new_disabled_file_systems; + for (auto &name : names) { + if (name.empty()) { + continue; + } + if (new_disabled_file_systems.find(name) != new_disabled_file_systems.end()) { + throw InvalidInputException("Duplicate disabled file system \"%s\"", name); + } + new_disabled_file_systems.insert(name); + } + for (auto &disabled_fs : disabled_file_systems) { + if (new_disabled_file_systems.find(disabled_fs) == new_disabled_file_systems.end()) { + throw InvalidInputException("File system \"%s\" has been disabled previously, it cannot be re-enabled", + disabled_fs); + } + } + disabled_file_systems = std::move(new_disabled_file_systems); +} + +FileSystem &VirtualFileSystem::FindFileSystem(const string &path) { + auto &fs = FindFileSystemInternal(path); + if (!disabled_file_systems.empty() && disabled_file_systems.find(fs.GetName()) != disabled_file_systems.end()) { + throw PermissionException("File system %s has been disabled by configuration", fs.GetName()); + } + return fs; +} + +FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { + for (auto &sub_system : sub_systems) { + if (sub_system->CanHandleFile(path)) { + return *sub_system; + } + } + return *default_fs; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/windows_util.cpp b/src/duckdb/src/common/windows_util.cpp new file mode 100644 index 00000000..8be7a3a0 --- /dev/null +++ b/src/duckdb/src/common/windows_util.cpp @@ -0,0 +1,53 @@ +#include "duckdb/common/windows_util.hpp" + +namespace duckdb { + +#ifdef DUCKDB_WINDOWS + +std::wstring WindowsUtil::UTF8ToUnicode(const char *input) { + idx_t result_size; + + result_size = MultiByteToWideChar(CP_UTF8, 0, input, -1, nullptr, 0); + if (result_size == 0) { + throw IOException("Failure in MultiByteToWideChar"); + } + auto buffer = make_unsafe_uniq_array(result_size); + result_size = MultiByteToWideChar(CP_UTF8, 0, input, -1, buffer.get(), result_size); + if (result_size == 0) { + throw IOException("Failure in MultiByteToWideChar"); + } + return std::wstring(buffer.get(), result_size); +} + +static string WideCharToMultiByteWrapper(LPCWSTR input, uint32_t code_page) { + idx_t result_size; + + result_size = WideCharToMultiByte(code_page, 0, input, -1, 0, 0, 0, 0); + if (result_size == 0) { + throw IOException("Failure in WideCharToMultiByte"); + } + auto buffer = make_unsafe_uniq_array(result_size); + result_size = WideCharToMultiByte(code_page, 0, input, -1, buffer.get(), result_size, 0, 0); + if (result_size == 0) { + throw IOException("Failure in WideCharToMultiByte"); + } + return string(buffer.get(), result_size - 1); +} + +string WindowsUtil::UnicodeToUTF8(LPCWSTR input) { + return WideCharToMultiByteWrapper(input, CP_UTF8); +} + +static string WindowsUnicodeToMBCS(LPCWSTR unicode_text, int use_ansi) { + uint32_t code_page = use_ansi ? CP_ACP : CP_OEMCP; + return WideCharToMultiByteWrapper(unicode_text, code_page); +} + +string WindowsUtil::UTF8ToMBCS(const char *input, bool use_ansi) { + auto unicode = WindowsUtil::UTF8ToUnicode(input); + return WindowsUnicodeToMBCS(unicode.c_str(), use_ansi); +} + +#endif + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/algebraic/avg.cpp b/src/duckdb/src/core_functions/aggregate/algebraic/avg.cpp new file mode 100644 index 00000000..9cebfc4a --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/algebraic/avg.cpp @@ -0,0 +1,196 @@ +#include "duckdb/core_functions/aggregate/algebraic_functions.hpp" +#include "duckdb/core_functions/aggregate/sum_helpers.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +template +struct AvgState { + uint64_t count; + T value; + + void Initialize() { + this->count = 0; + } + + void Combine(const AvgState &other) { + this->count += other.count; + this->value += other.value; + } +}; + +struct KahanAvgState { + uint64_t count; + double value; + double err; + + void Initialize() { + this->count = 0; + this->err = 0.0; + } + + void Combine(const KahanAvgState &other) { + this->count += other.count; + KahanAddInternal(other.value, this->value, this->err); + KahanAddInternal(other.err, this->value, this->err); + } +}; + +struct AverageDecimalBindData : public FunctionData { + explicit AverageDecimalBindData(double scale) : scale(scale) { + } + + double scale; + +public: + unique_ptr Copy() const override { + return make_uniq(scale); + }; + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return scale == other.scale; + } +}; + +struct AverageSetOperation { + template + static void Initialize(STATE &state) { + state.Initialize(); + } + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.Combine(source); + } + template + static void AddValues(STATE &state, idx_t count) { + state.count += count; + } +}; + +template +static T GetAverageDivident(uint64_t count, optional_ptr bind_data) { + T divident = T(count); + if (bind_data) { + auto &avg_bind_data = bind_data->Cast(); + divident *= avg_bind_data.scale; + } + return divident; +} + +struct IntegerAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); + target = double(state.value) / divident; + } + } +}; + +struct IntegerAverageOperationHugeint : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + long double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); + target = Hugeint::Cast(state.value) / divident; + } + } +}; + +struct HugeintAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + long double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); + target = Hugeint::Cast(state.value) / divident; + } + } +}; + +struct NumericAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.value / state.count; + } + } +}; + +struct KahanAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = (state.value / state.count) + (state.err / state.count); + } + } +}; + +AggregateFunction GetAverageAggregate(PhysicalType type) { + switch (type) { + case PhysicalType::INT16: { + return AggregateFunction::UnaryAggregate, int16_t, double, IntegerAverageOperation>( + LogicalType::SMALLINT, LogicalType::DOUBLE); + } + case PhysicalType::INT32: { + return AggregateFunction::UnaryAggregate, int32_t, double, IntegerAverageOperationHugeint>( + LogicalType::INTEGER, LogicalType::DOUBLE); + } + case PhysicalType::INT64: { + return AggregateFunction::UnaryAggregate, int64_t, double, IntegerAverageOperationHugeint>( + LogicalType::BIGINT, LogicalType::DOUBLE); + } + case PhysicalType::INT128: { + return AggregateFunction::UnaryAggregate, hugeint_t, double, HugeintAverageOperation>( + LogicalType::HUGEINT, LogicalType::DOUBLE); + } + default: + throw InternalException("Unimplemented average aggregate"); + } +} + +unique_ptr BindDecimalAvg(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + function = GetAverageAggregate(decimal_type.InternalType()); + function.name = "avg"; + function.arguments[0] = decimal_type; + function.return_type = LogicalType::DOUBLE; + return make_uniq( + Hugeint::Cast(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)])); +} + +AggregateFunctionSet AvgFun::GetFunctions() { + AggregateFunctionSet avg; + + avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + BindDecimalAvg)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT16)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT32)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT64)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT128)); + avg.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericAverageOperation>( + LogicalType::DOUBLE, LogicalType::DOUBLE)); + return avg; +} + +AggregateFunction FAvgFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/algebraic/corr.cpp b/src/duckdb/src/core_functions/aggregate/algebraic/corr.cpp new file mode 100644 index 00000000..61678684 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/algebraic/corr.cpp @@ -0,0 +1,13 @@ +#include "duckdb/core_functions/aggregate/algebraic_functions.hpp" +#include "duckdb/core_functions/aggregate/algebraic/covar.hpp" +#include "duckdb/core_functions/aggregate/algebraic/stddev.hpp" +#include "duckdb/core_functions/aggregate/algebraic/corr.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +AggregateFunction CorrFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/algebraic/covar.cpp b/src/duckdb/src/core_functions/aggregate/algebraic/covar.cpp new file mode 100644 index 00000000..ced7d8be --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/algebraic/covar.cpp @@ -0,0 +1,17 @@ +#include "duckdb/core_functions/aggregate/algebraic_functions.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/core_functions/aggregate/algebraic/covar.hpp" + +namespace duckdb { + +AggregateFunction CovarPopFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +AggregateFunction CovarSampFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/algebraic/stddev.cpp b/src/duckdb/src/core_functions/aggregate/algebraic/stddev.cpp new file mode 100644 index 00000000..b21467ee --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/algebraic/stddev.cpp @@ -0,0 +1,34 @@ +#include "duckdb/core_functions/aggregate/algebraic_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/core_functions/aggregate/algebraic/stddev.hpp" +#include + +namespace duckdb { + +AggregateFunction StdDevSampFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +AggregateFunction StdDevPopFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +AggregateFunction VarPopFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +AggregateFunction VarSampFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +AggregateFunction StandardErrorOfTheMeanFun::GetFunction() { + return AggregateFunction::UnaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/approx_count.cpp b/src/duckdb/src/core_functions/aggregate/distributive/approx_count.cpp new file mode 100644 index 00000000..599ecb7d --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/approx_count.cpp @@ -0,0 +1,145 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/types/hyperloglog.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +struct ApproxDistinctCountState { + ApproxDistinctCountState() : log(nullptr) { + } + ~ApproxDistinctCountState() { + if (log) { + delete log; + } + } + + HyperLogLog *log; +}; + +struct ApproxCountDistinctFunction { + template + static void Initialize(STATE &state) { + state.log = nullptr; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.log) { + return; + } + if (!target.log) { + target.log = new HyperLogLog(); + } + D_ASSERT(target.log); + D_ASSERT(source.log); + auto new_log = target.log->MergePointer(*source.log); + delete target.log; + target.log = new_log; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.log) { + target = state.log->Count(); + } else { + target = 0; + } + } + + static bool IgnoreNull() { + return true; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.log) { + delete state.log; + state.log = nullptr; + } + } +}; + +static void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, + data_ptr_t state, idx_t count) { + D_ASSERT(input_count == 1); + + auto agg_state = reinterpret_cast(state); + if (!agg_state->log) { + agg_state->log = new HyperLogLog(); + } + + UnifiedVectorFormat vdata; + inputs[0].ToUnifiedFormat(count, vdata); + + if (count > STANDARD_VECTOR_SIZE) { + throw InternalException("ApproxCountDistinct - count must be at most vector size"); + } + uint64_t indices[STANDARD_VECTOR_SIZE]; + uint8_t counts[STANDARD_VECTOR_SIZE]; + HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count); + agg_state->log->AddToLog(vdata, count, indices, counts); +} + +static void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, + Vector &state_vector, idx_t count) { + D_ASSERT(input_count == 1); + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetDataNoConst(sdata); + + for (idx_t i = 0; i < count; i++) { + auto agg_state = states[sdata.sel->get_index(i)]; + if (!agg_state->log) { + agg_state->log = new HyperLogLog(); + } + } + + UnifiedVectorFormat vdata; + inputs[0].ToUnifiedFormat(count, vdata); + + if (count > STANDARD_VECTOR_SIZE) { + throw InternalException("ApproxCountDistinct - count must be at most vector size"); + } + uint64_t indices[STANDARD_VECTOR_SIZE]; + uint8_t counts[STANDARD_VECTOR_SIZE]; + HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count); + HyperLogLog::AddToLogs(vdata, count, indices, counts, reinterpret_cast(states), sdata.sel); +} + +AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) { + auto fun = AggregateFunction( + {input_type}, LogicalTypeId::BIGINT, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + ApproxCountDistinctUpdateFunction, + AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, + ApproxCountDistinctSimpleUpdateFunction, nullptr, + AggregateFunction::StateDestroy); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +AggregateFunctionSet ApproxCountDistinctFun::GetFunctions() { + AggregateFunctionSet approx_count("approx_count_distinct"); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UTINYINT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::USMALLINT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UINTEGER)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UBIGINT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TINYINT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::SMALLINT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BIGINT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::HUGEINT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::FLOAT)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::DOUBLE)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::VARCHAR)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP_TZ)); + approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BLOB)); + return approx_count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/src/core_functions/aggregate/distributive/arg_min_max.cpp new file mode 100644 index 00000000..2fa99291 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/arg_min_max.cpp @@ -0,0 +1,338 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" + +namespace duckdb { + +struct ArgMinMaxStateBase { + ArgMinMaxStateBase() : is_initialized(false) { + } + + template + static inline void CreateValue(T &value) { + } + + template + static inline void DestroyValue(T &value) { + } + + template + static inline void AssignValue(T &target, T new_value, bool is_initialized) { + target = new_value; + } + + template + static inline void ReadValue(Vector &result, T &arg, T &target) { + target = arg; + } + + bool is_initialized; +}; + +// Out-of-line specialisations +template <> +void ArgMinMaxStateBase::CreateValue(Vector *&value) { + value = nullptr; +} + +template <> +void ArgMinMaxStateBase::DestroyValue(string_t &value) { + if (!value.IsInlined()) { + delete[] value.GetData(); + } +} + +template <> +void ArgMinMaxStateBase::DestroyValue(Vector *&value) { + delete value; + value = nullptr; +} + +template <> +void ArgMinMaxStateBase::AssignValue(string_t &target, string_t new_value, bool is_initialized) { + if (is_initialized) { + DestroyValue(target); + } + if (new_value.IsInlined()) { + target = new_value; + } else { + // non-inlined string, need to allocate space for it + auto len = new_value.GetSize(); + auto ptr = new char[len]; + memcpy(ptr, new_value.GetData(), len); + + target = string_t(ptr, len); + } +} + +template <> +void ArgMinMaxStateBase::ReadValue(Vector &result, string_t &arg, string_t &target) { + target = StringVector::AddStringOrBlob(result, arg); +} + +template +struct ArgMinMaxState : public ArgMinMaxStateBase { + using ARG_TYPE = A; + using BY_TYPE = B; + + ARG_TYPE arg; + BY_TYPE value; + + ArgMinMaxState() { + CreateValue(arg); + CreateValue(value); + } + + ~ArgMinMaxState() { + if (is_initialized) { + DestroyValue(arg); + DestroyValue(value); + is_initialized = false; + } + } +}; + +template +struct ArgMinMaxBase { + + template + static void Initialize(STATE &state) { + new (&state) STATE; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.~STATE(); + } + + template + static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &) { + if (!state.is_initialized) { + STATE::template AssignValue(state.arg, x, false); + STATE::template AssignValue(state.value, y, false); + state.is_initialized = true; + } else { + OP::template Execute(state, x, y); + } + } + + template + static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data) { + if (COMPARATOR::Operation(y_data, state.value)) { + STATE::template AssignValue(state.arg, x_data, true); + STATE::template AssignValue(state.value, y_data, true); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_initialized) { + return; + } + if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { + STATE::template AssignValue(target.arg, source.arg, target.is_initialized); + STATE::template AssignValue(target.value, source.value, target.is_initialized); + target.is_initialized = true; + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_initialized) { + finalize_data.ReturnNull(); + } else { + STATE::template ReadValue(finalize_data.result, state.arg, target); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +template +struct VectorArgMinMaxBase : ArgMinMaxBase { + template + static void AssignVector(STATE &state, Vector &arg, const idx_t idx) { + if (!state.is_initialized) { + state.arg = new Vector(arg.GetType()); + state.arg->SetVectorType(VectorType::CONSTANT_VECTOR); + } + sel_t selv = idx; + SelectionVector sel(&selv); + VectorOperations::Copy(arg, *state.arg, sel, 1, 0, 0); + } + + template + static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { + auto &arg = inputs[0]; + UnifiedVectorFormat adata; + arg.ToUnifiedFormat(count, adata); + + using BY_TYPE = typename STATE::BY_TYPE; + auto &by = inputs[1]; + UnifiedVectorFormat bdata; + by.ToUnifiedFormat(count, bdata); + const auto bys = UnifiedVectorFormat::GetData(bdata); + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto states = (STATE **)sdata.data; + for (idx_t i = 0; i < count; i++) { + const auto bidx = bdata.sel->get_index(i); + if (!bdata.validity.RowIsValid(bidx)) { + continue; + } + const auto bval = bys[bidx]; + + const auto sidx = sdata.sel->get_index(i); + auto &state = *states[sidx]; + if (!state.is_initialized) { + STATE::template AssignValue(state.value, bval, false); + AssignVector(state, arg, i); + state.is_initialized = true; + + } else if (COMPARATOR::template Operation(bval, state.value)) { + STATE::template AssignValue(state.value, bval, true); + AssignVector(state, arg, i); + } + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_initialized) { + return; + } + if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { + STATE::template AssignValue(target.value, source.value, target.is_initialized); + AssignVector(target, *source.arg, 0); + target.is_initialized = true; + } + } + + template + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (!state.is_initialized) { + finalize_data.ReturnNull(); + } else { + VectorOperations::Copy(*state.arg, finalize_data.result, 1, 0, finalize_data.result_idx); + } + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function.arguments[0] = arguments[0]->return_type; + function.return_type = arguments[0]->return_type; + return nullptr; + } +}; + +template +AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { + using STATE = ArgMinMaxState; + return AggregateFunction( + {type, by_type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + OP::template Update, AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { + switch (by_type.InternalType()) { + case PhysicalType::INT32: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::INT64: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::DOUBLE: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::VARCHAR: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + default: + throw InternalException("Unimplemented arg_min/arg_max aggregate"); + } +} + +template +void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::INTEGER, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::BIGINT, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::DOUBLE, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::VARCHAR, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::DATE, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::TIMESTAMP, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::TIMESTAMP_TZ, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(LogicalType::BLOB, type)); +} + +template +AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { + using STATE = ArgMinMaxState; + auto function = AggregateFunction::BinaryAggregate(type, by_type, type); + if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) { + function.destructor = AggregateFunction::StateDestroy; + } + return function; +} + +template +AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { + switch (by_type.InternalType()) { + case PhysicalType::INT32: + return GetArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::INT64: + return GetArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::DOUBLE: + return GetArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::VARCHAR: + return GetArgMinMaxFunctionInternal(by_type, type); + default: + throw InternalException("Unimplemented arg_min/arg_max aggregate"); + } +} + +template +void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::INTEGER, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::BIGINT, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::DOUBLE, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::VARCHAR, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::DATE, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::TIMESTAMP, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::TIMESTAMP_TZ, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(LogicalType::BLOB, type)); +} + +template +static void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { + using OP = ArgMinMaxBase; + AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER); + AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT); + AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE); + AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR); + AddArgMinMaxFunctionBy(fun, LogicalType::DATE); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ); + AddArgMinMaxFunctionBy(fun, LogicalType::BLOB); + + using VECTOR_OP = VectorArgMinMaxBase; + AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); +} + +AggregateFunctionSet ArgMinFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun); + return fun; +} + +AggregateFunctionSet ArgMaxFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun); + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/src/core_functions/aggregate/distributive/bitagg.cpp new file mode 100644 index 00000000..52b671d9 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/bitagg.cpp @@ -0,0 +1,226 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/aggregate_executor.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +template +struct BitState { + bool is_set; + T value; +}; + +template +static AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); + case LogicalTypeId::SMALLINT: + return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type); + case LogicalTypeId::INTEGER: + return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type); + case LogicalTypeId::BIGINT: + return AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type); + case LogicalTypeId::HUGEINT: + return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type); + case LogicalTypeId::UTINYINT: + return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type); + case LogicalTypeId::USMALLINT: + return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type); + case LogicalTypeId::UINTEGER: + return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type); + case LogicalTypeId::UBIGINT: + return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type); + default: + throw InternalException("Unimplemented bitfield type for unary aggregate"); + } +} + +struct BitwiseOperation { + template + static void Initialize(STATE &state) { + // If there are no matching rows, returns a null value. + state.is_set = false; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + if (!state.is_set) { + OP::template Assign(state, input); + state.is_set = true; + } else { + OP::template Execute(state, input); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + OP::template Operation(state, input, unary_input); + } + + template + static void Assign(STATE &state, INPUT_TYPE input) { + state.value = input; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_set) { + // source is NULL, nothing to do. + return; + } + if (!target.is_set) { + // target is NULL, use source value directly. + OP::template Assign(target, source.value); + target.is_set = true; + } else { + OP::template Execute(target, source.value); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct BitAndOperation : public BitwiseOperation { + template + static void Execute(STATE &state, INPUT_TYPE input) { + state.value &= input; + } +}; + +struct BitOrOperation : public BitwiseOperation { + template + static void Execute(STATE &state, INPUT_TYPE input) { + state.value |= input; + } +}; + +struct BitXorOperation : public BitwiseOperation { + template + static void Execute(STATE &state, INPUT_TYPE input) { + state.value ^= input; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } +}; + +struct BitStringBitwiseOperation : public BitwiseOperation { + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.is_set && !state.value.IsInlined()) { + delete[] state.value.GetData(); + } + } + + template + static void Assign(STATE &state, INPUT_TYPE input) { + D_ASSERT(state.is_set == false); + if (input.IsInlined()) { + state.value = input; + } else { // non-inlined string, need to allocate space for it + auto len = input.GetSize(); + auto ptr = new char[len]; + memcpy(ptr, input.GetData(), len); + + state.value = string_t(ptr, len); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set) { + finalize_data.ReturnNull(); + } else { + target = finalize_data.ReturnString(state.value); + } + } +}; + +struct BitStringAndOperation : public BitStringBitwiseOperation { + + template + static void Execute(STATE &state, INPUT_TYPE input) { + Bit::BitwiseAnd(input, state.value, state.value); + } +}; + +struct BitStringOrOperation : public BitStringBitwiseOperation { + + template + static void Execute(STATE &state, INPUT_TYPE input) { + Bit::BitwiseOr(input, state.value, state.value); + } +}; + +struct BitStringXorOperation : public BitStringBitwiseOperation { + template + static void Execute(STATE &state, INPUT_TYPE input) { + Bit::BitwiseXor(input, state.value, state.value); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } +}; + +AggregateFunctionSet BitAndFun::GetFunctions() { + AggregateFunctionSet bit_and; + for (auto &type : LogicalType::Integral()) { + bit_and.AddFunction(GetBitfieldUnaryAggregate(type)); + } + + bit_and.AddFunction( + AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringAndOperation>( + LogicalType::BIT, LogicalType::BIT)); + return bit_and; +} + +AggregateFunctionSet BitOrFun::GetFunctions() { + AggregateFunctionSet bit_or; + for (auto &type : LogicalType::Integral()) { + bit_or.AddFunction(GetBitfieldUnaryAggregate(type)); + } + bit_or.AddFunction( + AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringOrOperation>( + LogicalType::BIT, LogicalType::BIT)); + return bit_or; +} + +AggregateFunctionSet BitXorFun::GetFunctions() { + AggregateFunctionSet bit_xor; + for (auto &type : LogicalType::Integral()) { + bit_xor.AddFunction(GetBitfieldUnaryAggregate(type)); + } + bit_xor.AddFunction( + AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringXorOperation>( + LogicalType::BIT, LogicalType::BIT)); + return bit_xor; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/duckdb/src/core_functions/aggregate/distributive/bitstring_agg.cpp new file mode 100644 index 00000000..303021bd --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/bitstring_agg.cpp @@ -0,0 +1,269 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/aggregate_executor.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/operator/subtract.hpp" + +namespace duckdb { + +template +struct BitAggState { + bool is_set; + string_t value; + INPUT_TYPE min; + INPUT_TYPE max; +}; + +struct BitstringAggBindData : public FunctionData { + Value min; + Value max; + + BitstringAggBindData() { + } + + BitstringAggBindData(Value min, Value max) : min(std::move(min)), max(std::move(max)) { + } + + unique_ptr Copy() const override { + return make_uniq(*this); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + if (min.IsNull() && other.min.IsNull() && max.IsNull() && other.max.IsNull()) { + return true; + } + if (Value::NotDistinctFrom(min, other.min) && Value::NotDistinctFrom(max, other.max)) { + return true; + } + return false; + } +}; + +struct BitStringAggOperation { + static constexpr const idx_t MAX_BIT_RANGE = 1000000000; // for now capped at 1 billion bits + + template + static void Initialize(STATE &state) { + state.is_set = false; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + auto &bind_agg_data = unary_input.input.bind_data->template Cast(); + if (!state.is_set) { + if (bind_agg_data.min.IsNull() || bind_agg_data.max.IsNull()) { + throw BinderException( + "Could not retrieve required statistics. Alternatively, try by providing the statistics " + "explicitly: BITSTRING_AGG(col, min, max) "); + } + state.min = bind_agg_data.min.GetValue(); + state.max = bind_agg_data.max.GetValue(); + idx_t bit_range = + GetRange(bind_agg_data.min.GetValue(), bind_agg_data.max.GetValue()); + if (bit_range > MAX_BIT_RANGE) { + throw OutOfRangeException( + "The range between min and max value (%s <-> %s) is too large for bitstring aggregation", + NumericHelper::ToString(state.min), NumericHelper::ToString(state.max)); + } + idx_t len = Bit::ComputeBitstringLen(bit_range); + auto target = len > string_t::INLINE_LENGTH ? string_t(new char[len], len) : string_t(len); + Bit::SetEmptyBitString(target, bit_range); + + state.value = target; + state.is_set = true; + } + if (input >= state.min && input <= state.max) { + Execute(state, input, bind_agg_data.min.GetValue()); + } else { + throw OutOfRangeException("Value %s is outside of provided min and max range (%s <-> %s)", + NumericHelper::ToString(input), NumericHelper::ToString(state.min), + NumericHelper::ToString(state.max)); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + OP::template Operation(state, input, unary_input); + } + + template + static idx_t GetRange(INPUT_TYPE min, INPUT_TYPE max) { + D_ASSERT(max >= min); + INPUT_TYPE result; + if (!TrySubtractOperator::Operation(max, min, result)) { + return NumericLimits::Maximum(); + } + idx_t val(result); + if (val == NumericLimits::Maximum()) { + return val; + } + return val + 1; + } + + template + static void Execute(STATE &state, INPUT_TYPE input, INPUT_TYPE min) { + Bit::SetBit(state.value, input - min, 1); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_set) { + return; + } + if (!target.is_set) { + Assign(target, source.value); + target.is_set = true; + target.min = source.min; + target.max = source.max; + } else { + Bit::BitwiseOr(source.value, target.value, target.value); + } + } + + template + static void Assign(STATE &state, INPUT_TYPE input) { + D_ASSERT(state.is_set == false); + if (input.IsInlined()) { + state.value = input; + } else { // non-inlined string, need to allocate space for it + auto len = input.GetSize(); + auto ptr = new char[len]; + memcpy(ptr, input.GetData(), len); + state.value = string_t(ptr, len); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set) { + finalize_data.ReturnNull(); + } else { + target = StringVector::AddStringOrBlob(finalize_data.result, state.value); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.is_set && !state.value.IsInlined()) { + delete[] state.value.GetData(); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +template <> +void BitStringAggOperation::Execute(BitAggState &state, hugeint_t input, hugeint_t min) { + idx_t val; + if (Hugeint::TryCast(input - min, val)) { + Bit::SetBit(state.value, val, 1); + } else { + throw OutOfRangeException("Range too large for bitstring aggregation"); + } +} + +template <> +idx_t BitStringAggOperation::GetRange(hugeint_t min, hugeint_t max) { + hugeint_t result; + if (!TrySubtractOperator::Operation(max, min, result)) { + return NumericLimits::Maximum(); + } + idx_t range; + if (!Hugeint::TryCast(result + 1, range)) { + return NumericLimits::Maximum(); + } + return range; +} + +unique_ptr BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr, + AggregateStatisticsInput &input) { + + if (!NumericStats::HasMinMax(input.child_stats[0])) { + throw BinderException("Could not retrieve required statistics. Alternatively, try by providing the statistics " + "explicitly: BITSTRING_AGG(col, min, max) "); + } + auto &bind_agg_data = input.bind_data->Cast(); + bind_agg_data.min = NumericStats::Min(input.child_stats[0]); + bind_agg_data.max = NumericStats::Max(input.child_stats[0]); + return nullptr; +} + +unique_ptr BindBitstringAgg(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments.size() == 3) { + if (!arguments[1]->IsFoldable() || !arguments[2]->IsFoldable()) { + throw BinderException("bitstring_agg requires a constant min and max argument"); + } + auto min = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + auto max = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); + Function::EraseArgument(function, arguments, 2); + Function::EraseArgument(function, arguments, 1); + return make_uniq(min, max); + } + return make_uniq(); +} + +template +static void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &type) { + auto function = + AggregateFunction::UnaryAggregateDestructor, TYPE, string_t, BitStringAggOperation>( + type, LogicalType::BIT); + function.bind = BindBitstringAgg; // create new a 'BitstringAggBindData' + function.statistics = BitstringPropagateStats; // stores min and max from column stats in BitstringAggBindData + bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics for creating bitstring + function.arguments = {type, type, type}; + function.statistics = nullptr; // min and max are provided as arguments + bitstring_agg.AddFunction(function); +} + +void GetBitStringAggregate(const LogicalType &type, AggregateFunctionSet &bitstring_agg) { + switch (type.id()) { + case LogicalType::TINYINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::SMALLINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::INTEGER: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::BIGINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::HUGEINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::UTINYINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::USMALLINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::UINTEGER: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::UBIGINT: { + return BindBitString(bitstring_agg, type.id()); + } + default: + throw InternalException("Unimplemented bitstring aggregate"); + } +} + +AggregateFunctionSet BitstringAggFun::GetFunctions() { + AggregateFunctionSet bitstring_agg("bitstring_agg"); + for (auto &type : LogicalType::Integral()) { + GetBitStringAggregate(type, bitstring_agg); + } + return bitstring_agg; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/bool.cpp b/src/duckdb/src/core_functions/aggregate/distributive/bool.cpp new file mode 100644 index 00000000..20f2f3ba --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/bool.cpp @@ -0,0 +1,108 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct BoolState { + bool empty; + bool val; +}; + +struct BoolAndFunFunction { + template + static void Initialize(STATE &state) { + state.val = true; + state.empty = true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.val = target.val && source.val; + target.empty = target.empty && source.empty; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.empty) { + finalize_data.ReturnNull(); + return; + } + target = state.val; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.empty = false; + state.val = input && state.val; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + static bool IgnoreNull() { + return true; + } +}; + +struct BoolOrFunFunction { + template + static void Initialize(STATE &state) { + state.val = false; + state.empty = true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.val = target.val || source.val; + target.empty = target.empty && source.empty; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.empty) { + finalize_data.ReturnNull(); + return; + } + target = state.val; + } + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.empty = false; + state.val = input || state.val; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction BoolOrFun::GetFunction() { + auto fun = AggregateFunction::UnaryAggregate( + LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +AggregateFunction BoolAndFun::GetFunction() { + auto fun = AggregateFunction::UnaryAggregate( + LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/entropy.cpp b/src/duckdb/src/core_functions/aggregate/distributive/entropy.cpp new file mode 100644 index 00000000..b965b811 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/entropy.cpp @@ -0,0 +1,181 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { + +template +struct EntropyState { + using DistinctMap = unordered_map; + + idx_t count; + DistinctMap *distinct; + + EntropyState &operator=(const EntropyState &other) = delete; + + EntropyState &Assign(const EntropyState &other) { + D_ASSERT(!distinct); + distinct = new DistinctMap(*other.distinct); + count = other.count; + return *this; + } +}; + +struct EntropyFunctionBase { + template + static void Initialize(STATE &state) { + state.distinct = nullptr; + state.count = 0; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.distinct) { + return; + } + if (!target.distinct) { + target.Assign(source); + return; + } + for (auto &val : *source.distinct) { + auto value = val.first; + (*target.distinct)[value] += val.second; + } + target.count += source.count; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + double count = state.count; + if (state.distinct) { + double entropy = 0; + for (auto &val : *state.distinct) { + entropy += (val.second / count) * log2(count / val.second); + } + target = entropy; + } else { + target = 0; + } + } + + static bool IgnoreNull() { + return true; + } + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.distinct) { + delete state.distinct; + } + } +}; + +struct EntropyFunction : EntropyFunctionBase { + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (!state.distinct) { + state.distinct = new unordered_map(); + } + (*state.distinct)[input]++; + state.count++; + } + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } +}; + +struct EntropyFunctionString : EntropyFunctionBase { + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (!state.distinct) { + state.distinct = new unordered_map(); + } + auto value = input.GetString(); + (*state.distinct)[value]++; + state.count++; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } +}; + +template +AggregateFunction GetEntropyFunction(const LogicalType &input_type, const LogicalType &result_type) { + auto fun = + AggregateFunction::UnaryAggregateDestructor, INPUT_TYPE, RESULT_TYPE, EntropyFunction>( + input_type, result_type); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +AggregateFunction GetEntropyFunctionInternal(PhysicalType type) { + switch (type) { + case PhysicalType::UINT16: + return AggregateFunction::UnaryAggregateDestructor, uint16_t, double, EntropyFunction>( + LogicalType::USMALLINT, LogicalType::DOUBLE); + case PhysicalType::UINT32: + return AggregateFunction::UnaryAggregateDestructor, uint32_t, double, EntropyFunction>( + LogicalType::UINTEGER, LogicalType::DOUBLE); + case PhysicalType::UINT64: + return AggregateFunction::UnaryAggregateDestructor, uint64_t, double, EntropyFunction>( + LogicalType::UBIGINT, LogicalType::DOUBLE); + case PhysicalType::INT16: + return AggregateFunction::UnaryAggregateDestructor, int16_t, double, EntropyFunction>( + LogicalType::SMALLINT, LogicalType::DOUBLE); + case PhysicalType::INT32: + return AggregateFunction::UnaryAggregateDestructor, int32_t, double, EntropyFunction>( + LogicalType::INTEGER, LogicalType::DOUBLE); + case PhysicalType::INT64: + return AggregateFunction::UnaryAggregateDestructor, int64_t, double, EntropyFunction>( + LogicalType::BIGINT, LogicalType::DOUBLE); + case PhysicalType::FLOAT: + return AggregateFunction::UnaryAggregateDestructor, float, double, EntropyFunction>( + LogicalType::FLOAT, LogicalType::DOUBLE); + case PhysicalType::DOUBLE: + return AggregateFunction::UnaryAggregateDestructor, double, double, EntropyFunction>( + LogicalType::DOUBLE, LogicalType::DOUBLE); + case PhysicalType::VARCHAR: + return AggregateFunction::UnaryAggregateDestructor, string_t, double, + EntropyFunctionString>(LogicalType::VARCHAR, + LogicalType::DOUBLE); + + default: + throw InternalException("Unimplemented approximate_count aggregate"); + } +} + +AggregateFunction GetEntropyFunction(PhysicalType type) { + auto fun = GetEntropyFunctionInternal(type); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +AggregateFunctionSet EntropyFun::GetFunctions() { + AggregateFunctionSet entropy("entropy"); + entropy.AddFunction(GetEntropyFunction(PhysicalType::UINT16)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::UINT32)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::UINT64)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::FLOAT)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::INT16)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::INT32)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::INT64)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::DOUBLE)); + entropy.AddFunction(GetEntropyFunction(PhysicalType::VARCHAR)); + entropy.AddFunction(GetEntropyFunction(LogicalType::TIMESTAMP, LogicalType::DOUBLE)); + entropy.AddFunction(GetEntropyFunction(LogicalType::TIMESTAMP_TZ, LogicalType::DOUBLE)); + return entropy; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/kurtosis.cpp b/src/duckdb/src/core_functions/aggregate/distributive/kurtosis.cpp new file mode 100644 index 00000000..063408e0 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/kurtosis.cpp @@ -0,0 +1,93 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +struct KurtosisState { + idx_t n; + double sum; + double sum_sqr; + double sum_cub; + double sum_four; +}; + +struct KurtosisOperation { + template + static void Initialize(STATE &state) { + state.n = 0; + state.sum = state.sum_sqr = state.sum_cub = state.sum_four = 0.0; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.n++; + state.sum += input; + state.sum_sqr += pow(input, 2); + state.sum_cub += pow(input, 3); + state.sum_four += pow(input, 4); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.n == 0) { + return; + } + target.n += source.n; + target.sum += source.sum; + target.sum_sqr += source.sum_sqr; + target.sum_cub += source.sum_cub; + target.sum_four += source.sum_four; + } + + template + static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { + auto n = (double)state.n; + if (n <= 3) { + finalize_data.ReturnNull(); + return; + } + double temp = 1 / n; + //! This is necessary due to linux 32 bits + long double temp_aux = 1 / n; + if (state.sum_sqr - state.sum * state.sum * temp == 0 || + state.sum_sqr - state.sum * state.sum * temp_aux == 0) { + finalize_data.ReturnNull(); + return; + } + double m4 = + temp * (state.sum_four - 4 * state.sum_cub * state.sum * temp + + 6 * state.sum_sqr * state.sum * state.sum * temp * temp - 3 * pow(state.sum, 4) * pow(temp, 3)); + + double m2 = temp * (state.sum_sqr - state.sum * state.sum * temp); + if (m2 <= 0 || ((n - 2) * (n - 3)) == 0) { // m2 shouldn't be below 0 but floating points are weird + finalize_data.ReturnNull(); + return; + } + target = (n - 1) * ((n + 1) * m4 / (m2 * m2) - 3 * (n - 1)) / ((n - 2) * (n - 3)); + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("Kurtosis is out of range!"); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction KurtosisFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/minmax.cpp b/src/duckdb/src/core_functions/aggregate/distributive/minmax.cpp new file mode 100644 index 00000000..fa087bab --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/minmax.cpp @@ -0,0 +1,563 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +template +struct MinMaxState { + T value; + bool isset; +}; + +template +static AggregateFunction GetUnaryAggregate(LogicalType type) { + switch (type.InternalType()) { + case PhysicalType::BOOL: + return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); + case PhysicalType::INT8: + return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); + case PhysicalType::INT16: + return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type); + case PhysicalType::INT32: + return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type); + case PhysicalType::INT64: + return AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type); + case PhysicalType::UINT8: + return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type); + case PhysicalType::UINT16: + return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type); + case PhysicalType::UINT32: + return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type); + case PhysicalType::UINT64: + return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type); + case PhysicalType::INT128: + return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type); + case PhysicalType::FLOAT: + return AggregateFunction::UnaryAggregate, float, float, OP>(type, type); + case PhysicalType::DOUBLE: + return AggregateFunction::UnaryAggregate, double, double, OP>(type, type); + case PhysicalType::INTERVAL: + return AggregateFunction::UnaryAggregate, interval_t, interval_t, OP>(type, type); + default: + throw InternalException("Unimplemented type for min/max aggregate"); + } +} + +struct MinMaxBase { + template + static void Initialize(STATE &state) { + state.isset = false; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + if (!state.isset) { + OP::template Assign(state, input, unary_input.input); + state.isset = true; + } else { + OP::template Execute(state, input, unary_input.input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (!state.isset) { + OP::template Assign(state, input, unary_input.input); + state.isset = true; + } else { + OP::template Execute(state, input, unary_input.input); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct NumericMinMaxBase : public MinMaxBase { + template + static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &) { + state.value = input; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +struct MinOperation : public NumericMinMaxBase { + template + static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { + if (LessThan::Operation(input, state.value)) { + state.value = input; + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.isset) { + // source is NULL, nothing to do + return; + } + if (!target.isset) { + // target is NULL, use source value directly + target = source; + } else if (GreaterThan::Operation(target.value, source.value)) { + target.value = source.value; + } + } +}; + +struct MaxOperation : public NumericMinMaxBase { + template + static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { + if (GreaterThan::Operation(input, state.value)) { + state.value = input; + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.isset) { + // source is NULL, nothing to do + return; + } + if (!target.isset) { + // target is NULL, use source value directly + target = source; + } else if (LessThan::Operation(target.value, source.value)) { + target.value = source.value; + } + } +}; + +struct StringMinMaxBase : public MinMaxBase { + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.isset && !state.value.IsInlined()) { + delete[] state.value.GetData(); + } + } + + template + static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { + Destroy(state, input_data); + if (input.IsInlined()) { + state.value = input; + } else { + // non-inlined string, need to allocate space for it + auto len = input.GetSize(); + auto ptr = new char[len]; + memcpy(ptr, input.GetData(), len); + + state.value = string_t(ptr, len); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = StringVector::AddStringOrBlob(finalize_data.result, state.value); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (!source.isset) { + // source is NULL, nothing to do + return; + } + if (!target.isset) { + // target is NULL, use source value directly + Assign(target, source.value, input_data); + target.isset = true; + } else { + OP::template Execute(target, source.value, input_data); + } + } +}; + +struct MinOperationString : public StringMinMaxBase { + template + static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { + if (LessThan::Operation(input, state.value)) { + Assign(state, input, input_data); + } + } +}; + +struct MaxOperationString : public StringMinMaxBase { + template + static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { + if (GreaterThan::Operation(input, state.value)) { + Assign(state, input, input_data); + } + } +}; + +template +static bool TemplatedOptimumType(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { + UnifiedVectorFormat lvdata, rvdata; + left.ToUnifiedFormat(lcount, lvdata); + right.ToUnifiedFormat(rcount, rvdata); + + lidx = lvdata.sel->get_index(lidx); + ridx = rvdata.sel->get_index(ridx); + + auto ldata = UnifiedVectorFormat::GetData(lvdata); + auto rdata = UnifiedVectorFormat::GetData(rvdata); + + auto &lval = ldata[lidx]; + auto &rval = rdata[ridx]; + + auto lnull = !lvdata.validity.RowIsValid(lidx); + auto rnull = !rvdata.validity.RowIsValid(ridx); + + return OP::Operation(lval, rval, lnull, rnull); +} + +template +static bool TemplatedOptimumList(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount); + +template +static bool TemplatedOptimumStruct(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount); + +template +static bool TemplatedOptimumValue(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { + D_ASSERT(left.GetType() == right.GetType()); + switch (left.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::INT16: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::INT32: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::INT64: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::UINT8: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::UINT16: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::UINT32: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::UINT64: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::INT128: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::FLOAT: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::DOUBLE: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::INTERVAL: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::VARCHAR: + return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::LIST: + return TemplatedOptimumList(left, lidx, lcount, right, ridx, rcount); + case PhysicalType::STRUCT: + return TemplatedOptimumStruct(left, lidx, lcount, right, ridx, rcount); + default: + throw InternalException("Invalid type for distinct comparison"); + } +} + +template +static bool TemplatedOptimumStruct(Vector &left, idx_t lidx_p, idx_t lcount, Vector &right, idx_t ridx_p, + idx_t rcount) { + // STRUCT dictionaries apply to all the children + // so map the indexes first + UnifiedVectorFormat lvdata, rvdata; + left.ToUnifiedFormat(lcount, lvdata); + right.ToUnifiedFormat(rcount, rvdata); + + idx_t lidx = lvdata.sel->get_index(lidx_p); + idx_t ridx = rvdata.sel->get_index(ridx_p); + + // DISTINCT semantics are in effect for nested types + auto lnull = !lvdata.validity.RowIsValid(lidx); + auto rnull = !rvdata.validity.RowIsValid(ridx); + if (lnull || rnull) { + return OP::Operation(0, 0, lnull, rnull); + } + + auto &lchildren = StructVector::GetEntries(left); + auto &rchildren = StructVector::GetEntries(right); + + D_ASSERT(lchildren.size() == rchildren.size()); + for (idx_t col_no = 0; col_no < lchildren.size(); ++col_no) { + auto &lchild = *lchildren[col_no]; + auto &rchild = *rchildren[col_no]; + + // Strict comparisons use the OP for definite + if (TemplatedOptimumValue(lchild, lidx_p, lcount, rchild, ridx_p, rcount)) { + return true; + } + + if (col_no == lchildren.size() - 1) { + break; + } + + // Strict comparisons use IS NOT DISTINCT for possible + if (!TemplatedOptimumValue(lchild, lidx_p, lcount, rchild, ridx_p, rcount)) { + return false; + } + } + + return false; +} + +template +static bool TemplatedOptimumList(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { + UnifiedVectorFormat lvdata, rvdata; + left.ToUnifiedFormat(lcount, lvdata); + right.ToUnifiedFormat(rcount, rvdata); + + // Update the indexes and vector sizes for recursion. + lidx = lvdata.sel->get_index(lidx); + ridx = rvdata.sel->get_index(ridx); + + lcount = ListVector::GetListSize(left); + rcount = ListVector::GetListSize(right); + + // DISTINCT semantics are in effect for nested types + auto lnull = !lvdata.validity.RowIsValid(lidx); + auto rnull = !rvdata.validity.RowIsValid(ridx); + if (lnull || rnull) { + return OP::Operation(0, 0, lnull, rnull); + } + + auto &lchild = ListVector::GetEntry(left); + auto &rchild = ListVector::GetEntry(right); + + auto ldata = UnifiedVectorFormat::GetData(lvdata); + auto rdata = UnifiedVectorFormat::GetData(rvdata); + + auto &lval = ldata[lidx]; + auto &rval = rdata[ridx]; + + for (idx_t pos = 0;; ++pos) { + // Tie-breaking uses the OP + if (pos == lval.length || pos == rval.length) { + return OP::Operation(lval.length, rval.length, false, false); + } + + // Strict comparisons use the OP for definite + lidx = lval.offset + pos; + ridx = rval.offset + pos; + if (TemplatedOptimumValue(lchild, lidx, lcount, rchild, ridx, rcount)) { + return true; + } + + // Strict comparisons use IS NOT DISTINCT for possible + if (!TemplatedOptimumValue(lchild, lidx, lcount, rchild, ridx, rcount)) { + return false; + } + } + + return false; +} + +struct VectorMinMaxState { + Vector *value; +}; + +struct VectorMinMaxBase { + static bool IgnoreNull() { + return true; + } + + template + static void Initialize(STATE &state) { + state.value = nullptr; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.value) { + delete state.value; + } + state.value = nullptr; + } + + template + static void Assign(STATE &state, Vector &input, const idx_t idx) { + if (!state.value) { + state.value = new Vector(input.GetType()); + state.value->SetVectorType(VectorType::CONSTANT_VECTOR); + } + sel_t selv = idx; + SelectionVector sel(&selv); + VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); + } + + template + static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { + Assign(state, input, idx); + } + + template + static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { + auto &input = inputs[0]; + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto states = (STATE **)sdata.data; + for (idx_t i = 0; i < count; i++) { + const auto idx = idata.sel->get_index(i); + if (!idata.validity.RowIsValid(idx)) { + continue; + } + const auto sidx = sdata.sel->get_index(i); + auto &state = *states[sidx]; + if (!state.value) { + Assign(state, input, i); + } else { + OP::template Execute(state, input, i, count); + } + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.value) { + return; + } else if (!target.value) { + Assign(target, *source.value, 0); + } else { + OP::template Execute(target, *source.value, 0, 1); + } + } + + template + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (!state.value) { + finalize_data.ReturnNull(); + } else { + VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); + } + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function.arguments[0] = arguments[0]->return_type; + function.return_type = arguments[0]->return_type; + return nullptr; + } +}; + +struct MinOperationVector : public VectorMinMaxBase { + template + static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { + if (TemplatedOptimumValue(input, idx, count, *state.value, 0, 1)) { + Assign(state, input, idx); + } + } +}; + +struct MaxOperationVector : public VectorMinMaxBase { + template + static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { + if (TemplatedOptimumValue(input, idx, count, *state.value, 0, 1)) { + Assign(state, input, idx); + } + } +}; + +template +unique_ptr BindDecimalMinMax(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + auto name = function.name; + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + function = GetUnaryAggregate(LogicalType::SMALLINT); + break; + case PhysicalType::INT32: + function = GetUnaryAggregate(LogicalType::INTEGER); + break; + case PhysicalType::INT64: + function = GetUnaryAggregate(LogicalType::BIGINT); + break; + default: + function = GetUnaryAggregate(LogicalType::HUGEINT); + break; + } + function.name = std::move(name); + function.arguments[0] = decimal_type; + function.return_type = decimal_type; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return nullptr; +} + +template +static AggregateFunction GetMinMaxFunction(const LogicalType &type) { + return AggregateFunction( + {type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + OP::template Update, AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, AggregateFunction::StateDestroy); +} + +template +static AggregateFunction GetMinMaxOperator(const LogicalType &type) { + if (type.InternalType() == PhysicalType::VARCHAR) { + return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, OP_STRING>( + type.id(), type.id()); + } else if (type.InternalType() == PhysicalType::LIST || type.InternalType() == PhysicalType::STRUCT) { + return GetMinMaxFunction(type); + } else { + return GetUnaryAggregate(type); + } +} + +template +unique_ptr BindMinMax(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto input_type = arguments[0]->return_type; + auto name = std::move(function.name); + function = GetMinMaxOperator(input_type); + function.name = std::move(name); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + if (function.bind) { + return function.bind(context, function, arguments); + } else { + return nullptr; + } +} + +template +static void AddMinMaxOperator(AggregateFunctionSet &set) { + set.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, BindDecimalMinMax)); + set.AddFunction(AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, BindMinMax)); +} + +AggregateFunctionSet MinFun::GetFunctions() { + AggregateFunctionSet min("min"); + AddMinMaxOperator(min); + return min; +} + +AggregateFunctionSet MaxFun::GetFunctions() { + AggregateFunctionSet max("max"); + AddMinMaxOperator(max); + return max; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/product.cpp b/src/duckdb/src/core_functions/aggregate/distributive/product.cpp new file mode 100644 index 00000000..fbe76617 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/product.cpp @@ -0,0 +1,61 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ProductState { + bool empty; + double val; +}; + +struct ProductFunction { + template + static void Initialize(STATE &state) { + state.val = 1; + state.empty = true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.val *= source.val; + target.empty = target.empty && source.empty; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.empty) { + finalize_data.ReturnNull(); + return; + } + target = state.val; + } + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (state.empty) { + state.empty = false; + } + state.val *= input; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction ProductFun::GetFunction() { + return AggregateFunction::UnaryAggregate( + LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/skew.cpp b/src/duckdb/src/core_functions/aggregate/distributive/skew.cpp new file mode 100644 index 00000000..ef42dce4 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/skew.cpp @@ -0,0 +1,86 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +struct SkewState { + size_t n; + double sum; + double sum_sqr; + double sum_cub; +}; + +struct SkewnessOperation { + template + static void Initialize(STATE &state) { + state.n = 0; + state.sum = state.sum_sqr = state.sum_cub = 0; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.n++; + state.sum += input; + state.sum_sqr += pow(input, 2); + state.sum_cub += pow(input, 3); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.n == 0) { + return; + } + + target.n += source.n; + target.sum += source.sum; + target.sum_sqr += source.sum_sqr; + target.sum_cub += source.sum_cub; + } + + template + static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { + if (state.n <= 2) { + finalize_data.ReturnNull(); + return; + } + double n = state.n; + double temp = 1 / n; + auto p = std::pow(temp * (state.sum_sqr - state.sum * state.sum * temp), 3); + if (p < 0) { + p = 0; // Shouldn't be below 0 but floating points are weird + } + double div = std::sqrt(p); + if (div == 0) { + finalize_data.ReturnNull(); + return; + } + double temp1 = std::sqrt(n * (n - 1)) / (n - 2); + target = temp1 * temp * + (state.sum_cub - 3 * state.sum_sqr * state.sum * temp + 2 * pow(state.sum, 3) * temp * temp) / div; + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("SKEW is out of range!"); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction SkewnessFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/string_agg.cpp b/src/duckdb/src/core_functions/aggregate/distributive/string_agg.cpp new file mode 100644 index 00000000..b09c52fc --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/string_agg.cpp @@ -0,0 +1,173 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +struct StringAggState { + idx_t size; + idx_t alloc_size; + char *dataptr; +}; + +struct StringAggBindData : public FunctionData { + explicit StringAggBindData(string sep_p) : sep(std::move(sep_p)) { + } + + string sep; + + unique_ptr Copy() const override { + return make_uniq(sep); + } + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return sep == other.sep; + } +}; + +struct StringAggFunction { + template + static void Initialize(STATE &state) { + state.dataptr = nullptr; + state.alloc_size = 0; + state.size = 0; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.dataptr) { + finalize_data.ReturnNull(); + } else { + target = StringVector::AddString(finalize_data.result, state.dataptr, state.size); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.dataptr) { + delete[] state.dataptr; + } + } + + static bool IgnoreNull() { + return true; + } + + static inline void PerformOperation(StringAggState &state, const char *str, const char *sep, idx_t str_size, + idx_t sep_size) { + if (!state.dataptr) { + // first iteration: allocate space for the string and copy it into the state + state.alloc_size = MaxValue(8, NextPowerOfTwo(str_size)); + state.dataptr = new char[state.alloc_size]; + state.size = str_size; + memcpy(state.dataptr, str, str_size); + } else { + // subsequent iteration: first check if we have space to place the string and separator + idx_t required_size = state.size + str_size + sep_size; + if (required_size > state.alloc_size) { + // no space! allocate extra space + while (state.alloc_size < required_size) { + state.alloc_size *= 2; + } + auto new_data = new char[state.alloc_size]; + memcpy(new_data, state.dataptr, state.size); + delete[] state.dataptr; + state.dataptr = new_data; + } + // copy the separator + memcpy(state.dataptr + state.size, sep, sep_size); + state.size += sep_size; + // copy the string + memcpy(state.dataptr + state.size, str, str_size); + state.size += str_size; + } + } + + static inline void PerformOperation(StringAggState &state, string_t str, optional_ptr data_p) { + auto &data = data_p->Cast(); + PerformOperation(state, str.GetData(), data.sep.c_str(), str.GetSize(), data.sep.size()); + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + PerformOperation(state, input, unary_input.input.bind_data); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + if (!source.dataptr) { + // source is not set: skip combining + return; + } + PerformOperation(target, string_t(source.dataptr, source.size), aggr_input_data.bind_data); + } +}; + +unique_ptr StringAggBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments.size() == 1) { + // single argument: default to comma + return make_uniq(","); + } + D_ASSERT(arguments.size() == 2); + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("Separator argument to StringAgg must be a constant"); + } + auto separator_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + string separator_string = ","; + if (separator_val.IsNull()) { + arguments[0] = make_uniq(Value(LogicalType::VARCHAR)); + } else { + separator_string = separator_val.ToString(); + } + Function::EraseArgument(function, arguments, arguments.size() - 1); + return make_uniq(std::move(separator_string)); +} + +static void StringAggSerialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "separator", bind_data.sep); +} + +unique_ptr StringAggDeserialize(Deserializer &deserializer, AggregateFunction &bound_function) { + auto sep = deserializer.ReadProperty(100, "separator"); + return make_uniq(std::move(sep)); +} + +AggregateFunctionSet StringAggFun::GetFunctions() { + AggregateFunctionSet string_agg; + AggregateFunction string_agg_param( + {LogicalType::VARCHAR}, LogicalType::VARCHAR, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, + AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, + AggregateFunction::UnaryUpdate, StringAggBind, + AggregateFunction::StateDestroy); + string_agg_param.serialize = StringAggSerialize; + string_agg_param.deserialize = StringAggDeserialize; + string_agg.AddFunction(string_agg_param); + string_agg_param.arguments.emplace_back(LogicalType::VARCHAR); + string_agg.AddFunction(string_agg_param); + return string_agg; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/src/core_functions/aggregate/distributive/sum.cpp new file mode 100644 index 00000000..9f243869 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/distributive/sum.cpp @@ -0,0 +1,217 @@ +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/core_functions/aggregate/sum_helpers.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +struct SumSetOperation { + template + static void Initialize(STATE &state) { + state.Initialize(); + } + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.Combine(source); + } + template + static void AddValues(STATE &state, idx_t count) { + state.isset = true; + } +}; + +struct IntegerSumOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = Hugeint::Convert(state.value); + } + } +}; + +struct SumToHugeintOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +template +struct DoubleSumOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +using NumericSumOperation = DoubleSumOperation; +using KahanSumOperation = DoubleSumOperation; + +struct HugeintSumOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { + switch (type) { + case PhysicalType::INT32: { + auto function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, IntegerSumOperation>( + LogicalType::INTEGER, LogicalType::HUGEINT); + function.name = "sum_no_overflow"; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + case PhysicalType::INT64: { + auto function = AggregateFunction::UnaryAggregate, int64_t, hugeint_t, IntegerSumOperation>( + LogicalType::BIGINT, LogicalType::HUGEINT); + function.name = "sum_no_overflow"; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + default: + throw BinderException("Unsupported internal type for sum_no_overflow"); + } +} + +unique_ptr SumPropagateStats(ClientContext &context, BoundAggregateExpression &expr, + AggregateStatisticsInput &input) { + if (input.node_stats && input.node_stats->has_max_cardinality) { + auto &numeric_stats = input.child_stats[0]; + if (!NumericStats::HasMinMax(numeric_stats)) { + return nullptr; + } + auto internal_type = numeric_stats.GetType().InternalType(); + hugeint_t max_negative; + hugeint_t max_positive; + switch (internal_type) { + case PhysicalType::INT32: + max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); + max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); + break; + case PhysicalType::INT64: + max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); + max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); + break; + default: + throw InternalException("Unsupported type for propagate sum stats"); + } + auto max_sum_negative = max_negative * hugeint_t(input.node_stats->max_cardinality); + auto max_sum_positive = max_positive * hugeint_t(input.node_stats->max_cardinality); + if (max_sum_positive >= NumericLimits::Maximum() || + max_sum_negative <= NumericLimits::Minimum()) { + // sum can potentially exceed int64_t bounds: use hugeint sum + return nullptr; + } + // total sum is guaranteed to fit in a single int64: use int64 sum instead of hugeint sum + expr.function = GetSumAggregateNoOverflow(internal_type); + } + return nullptr; +} + +AggregateFunction GetSumAggregate(PhysicalType type) { + switch (type) { + case PhysicalType::INT16: { + auto function = AggregateFunction::UnaryAggregate, int16_t, hugeint_t, IntegerSumOperation>( + LogicalType::SMALLINT, LogicalType::HUGEINT); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + + case PhysicalType::INT32: { + auto function = + AggregateFunction::UnaryAggregate, int32_t, hugeint_t, SumToHugeintOperation>( + LogicalType::INTEGER, LogicalType::HUGEINT); + function.statistics = SumPropagateStats; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + case PhysicalType::INT64: { + auto function = + AggregateFunction::UnaryAggregate, int64_t, hugeint_t, SumToHugeintOperation>( + LogicalType::BIGINT, LogicalType::HUGEINT); + function.statistics = SumPropagateStats; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + case PhysicalType::INT128: { + auto function = + AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, HugeintSumOperation>( + LogicalType::HUGEINT, LogicalType::HUGEINT); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + default: + throw InternalException("Unimplemented sum aggregate"); + } +} + +unique_ptr BindDecimalSum(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + function = GetSumAggregate(decimal_type.InternalType()); + function.name = "sum"; + function.arguments[0] = decimal_type; + function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return nullptr; +} + +unique_ptr BindDecimalSumNoOverflow(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + function = GetSumAggregateNoOverflow(decimal_type.InternalType()); + function.name = "sum_no_overflow"; + function.arguments[0] = decimal_type; + function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return nullptr; +} + +AggregateFunctionSet SumFun::GetFunctions() { + AggregateFunctionSet sum; + // decimal + sum.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + BindDecimalSum)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT16)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT32)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT64)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT128)); + sum.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericSumOperation>( + LogicalType::DOUBLE, LogicalType::DOUBLE)); + return sum; +} + +AggregateFunctionSet SumNoOverflowFun::GetFunctions() { + AggregateFunctionSet sum_no_overflow; + sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT32)); + sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT64)); + sum_no_overflow.AddFunction( + AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr, nullptr, + FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, BindDecimalSumNoOverflow)); + return sum_no_overflow; +} + +AggregateFunction KahanSumFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/src/core_functions/aggregate/holistic/approximate_quantile.cpp new file mode 100644 index 00000000..36a08614 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -0,0 +1,347 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "t_digest.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include +#include +#include + +namespace duckdb { + +struct ApproxQuantileState { + duckdb_tdigest::TDigest *h; + idx_t pos; +}; + +struct ApproximateQuantileBindData : public FunctionData { + ApproximateQuantileBindData() { + } + explicit ApproximateQuantileBindData(float quantile_p) : quantiles(1, quantile_p) { + } + + explicit ApproximateQuantileBindData(vector quantiles_p) : quantiles(std::move(quantiles_p)) { + } + + unique_ptr Copy() const override { + return make_uniq(quantiles); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + // return quantiles == other.quantiles; + if (quantiles != other.quantiles) { + return false; + } + return true; + } + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "quantiles", bind_data.quantiles); + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + deserializer.ReadProperty(100, "quantiles", result->quantiles); + return std::move(result); + } + + vector quantiles; +}; + +struct ApproxQuantileOperation { + using SAVE_TYPE = duckdb_tdigest::Value; + + template + static void Initialize(STATE &state) { + state.pos = 0; + state.h = nullptr; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + auto val = Cast::template Operation(input); + if (!Value::DoubleIsFinite(val)) { + return; + } + if (!state.h) { + state.h = new duckdb_tdigest::TDigest(100); + } + state.h->add(val); + state.pos++; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.pos == 0) { + return; + } + D_ASSERT(source.h); + if (!target.h) { + target.h = new duckdb_tdigest::TDigest(100); + } + target.h->merge(source.h); + target.pos += source.pos; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.h) { + delete state.h; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct ApproxQuantileScalarOperation : public ApproxQuantileOperation { + template + static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(state.h); + D_ASSERT(finalize_data.input.bind_data); + state.h->compress(); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + target = Cast::template Operation(state.h->quantile(bind_data.quantiles[0])); + } +}; + +AggregateFunction GetApproximateQuantileAggregateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::INT16: + return AggregateFunction::UnaryAggregateDestructor(LogicalType::SMALLINT, + LogicalType::SMALLINT); + case PhysicalType::INT32: + return AggregateFunction::UnaryAggregateDestructor(LogicalType::INTEGER, + LogicalType::INTEGER); + case PhysicalType::INT64: + return AggregateFunction::UnaryAggregateDestructor(LogicalType::BIGINT, + LogicalType::BIGINT); + case PhysicalType::INT128: + return AggregateFunction::UnaryAggregateDestructor(LogicalType::HUGEINT, + LogicalType::HUGEINT); + case PhysicalType::DOUBLE: + return AggregateFunction::UnaryAggregateDestructor(LogicalType::DOUBLE, + LogicalType::DOUBLE); + default: + throw InternalException("Unimplemented quantile aggregate"); + } +} + +static float CheckApproxQuantile(const Value &quantile_val) { + if (quantile_val.IsNull()) { + throw BinderException("APPROXIMATE QUANTILE parameter cannot be NULL"); + } + auto quantile = quantile_val.GetValue(); + if (quantile < 0 || quantile > 1) { + throw BinderException("APPROXIMATE QUANTILE can only take parameters in range [0, 1]"); + } + + return quantile; +} + +unique_ptr BindApproxQuantile(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("APPROXIMATE QUANTILE can only take constant quantile parameters"); + } + Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + + vector quantiles; + if (quantile_val.type().id() != LogicalTypeId::LIST) { + quantiles.push_back(CheckApproxQuantile(quantile_val)); + } else { + for (const auto &element_val : ListValue::GetChildren(quantile_val)) { + quantiles.push_back(CheckApproxQuantile(element_val)); + } + } + + // remove the quantile argument so we can use the unary aggregate + Function::EraseArgument(function, arguments, arguments.size() - 1); + return make_uniq(quantiles); +} + +unique_ptr BindApproxQuantileDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindApproxQuantile(context, function, arguments); + function = GetApproximateQuantileAggregateFunction(arguments[0]->return_type.InternalType()); + function.name = "approx_quantile"; + function.serialize = ApproximateQuantileBindData::Serialize; + function.deserialize = ApproximateQuantileBindData::Deserialize; + return bind_data; +} + +AggregateFunction GetApproximateQuantileAggregate(PhysicalType type) { + auto fun = GetApproximateQuantileAggregateFunction(type); + fun.bind = BindApproxQuantile; + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproximateQuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::FLOAT); + return fun; +} + +template +struct ApproxQuantileListOperation : public ApproxQuantileOperation { + + template + static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); + auto rdata = FlatVector::GetData(result); + + D_ASSERT(state.h); + state.h->compress(); + + auto &entry = target; + entry.offset = ridx; + entry.length = bind_data.quantiles.size(); + for (size_t q = 0; q < entry.length; ++q) { + const auto &quantile = bind_data.quantiles[q]; + rdata[ridx + q] = Cast::template Operation(state.h->quantile(quantile)); + } + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); + } +}; + +template +static AggregateFunction ApproxQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { + LogicalType result_type = LogicalType::LIST(child_type); + return AggregateFunction( + {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, + nullptr, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType &type) { + using STATE = ApproxQuantileState; + using OP = ApproxQuantileListOperation; + auto fun = ApproxQuantileListAggregate(type, type); + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproximateQuantileBindData::Deserialize; + return fun; +} + +AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::SMALLINT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::INTEGER: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::BIGINT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::HUGEINT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::FLOAT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::DOUBLE: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedApproxQuantileListAggregateFunction(type); + case PhysicalType::INT32: + return GetTypedApproxQuantileListAggregateFunction(type); + case PhysicalType::INT64: + return GetTypedApproxQuantileListAggregateFunction(type); + case PhysicalType::INT128: + return GetTypedApproxQuantileListAggregateFunction(type); + default: + throw NotImplementedException("Unimplemented approximate quantile list aggregate"); + } + default: + // TODO: Add quantitative temporal types + throw NotImplementedException("Unimplemented approximate quantile list aggregate"); + } +} + +unique_ptr BindApproxQuantileDecimalList(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindApproxQuantile(context, function, arguments); + function = GetApproxQuantileListAggregateFunction(arguments[0]->return_type); + function.name = "approx_quantile"; + function.serialize = ApproximateQuantileBindData::Serialize; + function.deserialize = ApproximateQuantileBindData::Deserialize; + return bind_data; +} + +AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) { + auto fun = GetApproxQuantileListAggregateFunction(type); + fun.bind = BindApproxQuantile; + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproximateQuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + auto list_of_float = LogicalType::LIST(LogicalType::FLOAT); + fun.arguments.push_back(list_of_float); + return fun; +} + +AggregateFunctionSet ApproxQuantileFun::GetFunctions() { + AggregateFunctionSet approx_quantile; + approx_quantile.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, LogicalType::FLOAT}, LogicalTypeId::DECIMAL, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + BindApproxQuantileDecimal)); + + approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT16)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT32)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT64)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT128)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::DOUBLE)); + + // List variants + approx_quantile.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)}, + LogicalType::LIST(LogicalTypeId::DECIMAL), nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, BindApproxQuantileDecimalList)); + + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::TINYINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::SMALLINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::INTEGER)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::BIGINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::HUGEINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::FLOAT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::DOUBLE)); + return approx_quantile; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/src/core_functions/aggregate/holistic/mode.cpp new file mode 100644 index 00000000..0f81b88c --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/holistic/mode.cpp @@ -0,0 +1,358 @@ +// MODE( ) +// Returns the most frequent value for the values within expr1. +// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL. + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/unordered_map.hpp" + +#include + +namespace std { + +template <> +struct hash { + inline size_t operator()(const duckdb::interval_t &val) const { + return hash {}(val.days) ^ hash {}(val.months) ^ hash {}(val.micros); + } +}; + +template <> +struct hash { + inline size_t operator()(const duckdb::hugeint_t &val) const { + return hash {}(val.upper) ^ hash {}(val.lower); + } +}; + +} // namespace std + +namespace duckdb { + +template +struct ModeState { + struct ModeAttr { + ModeAttr() : count(0), first_row(std::numeric_limits::max()) { + } + size_t count; + idx_t first_row; + }; + using Counts = unordered_map; + + Counts *frequency_map; + KEY_TYPE *mode; + size_t nonzero; + bool valid; + size_t count; + + void Initialize() { + frequency_map = nullptr; + mode = nullptr; + nonzero = 0; + valid = false; + count = 0; + } + + void Destroy() { + if (frequency_map) { + delete frequency_map; + } + if (mode) { + delete mode; + } + } + + void Reset() { + Counts empty; + frequency_map->swap(empty); + nonzero = 0; + count = 0; + valid = false; + } + + void ModeAdd(const KEY_TYPE &key, idx_t row) { + auto &attr = (*frequency_map)[key]; + auto new_count = (attr.count += 1); + if (new_count == 1) { + ++nonzero; + attr.first_row = row; + } else { + attr.first_row = MinValue(row, attr.first_row); + } + if (new_count > count) { + valid = true; + count = new_count; + if (mode) { + *mode = key; + } else { + mode = new KEY_TYPE(key); + } + } + } + + void ModeRm(const KEY_TYPE &key, idx_t frame) { + auto &attr = (*frequency_map)[key]; + auto old_count = attr.count; + nonzero -= int(old_count == 1); + + attr.count -= 1; + if (count == old_count && key == *mode) { + valid = false; + } + } + + typename Counts::const_iterator Scan() const { + //! Initialize control variables to first variable of the frequency map + auto highest_frequency = frequency_map->begin(); + for (auto i = highest_frequency; i != frequency_map->end(); ++i) { + // Tie break with the lowest insert position + if (i->second.count > highest_frequency->second.count || + (i->second.count == highest_frequency->second.count && + i->second.first_row < highest_frequency->second.first_row)) { + highest_frequency = i; + } + } + return highest_frequency; + } +}; + +struct ModeIncluded { + inline explicit ModeIncluded(const ValidityMask &fmask_p, const ValidityMask &dmask_p, idx_t bias_p) + : fmask(fmask_p), dmask(dmask_p), bias(bias_p) { + } + + inline bool operator()(const idx_t &idx) const { + return fmask.RowIsValid(idx) && dmask.RowIsValid(idx - bias); + } + const ValidityMask &fmask; + const ValidityMask &dmask; + const idx_t bias; +}; + +struct ModeAssignmentStandard { + template + static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { + return RESULT_TYPE(input); + } +}; + +struct ModeAssignmentString { + template + static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { + return StringVector::AddString(result, input); + } +}; + +template +struct ModeFunction { + template + static void Initialize(STATE &state) { + state.Initialize(); + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + if (!state.frequency_map) { + state.frequency_map = new typename STATE::Counts(); + } + auto key = KEY_TYPE(input); + auto &i = (*state.frequency_map)[key]; + i.count++; + i.first_row = MinValue(i.first_row, state.count); + state.count++; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.frequency_map) { + return; + } + if (!target.frequency_map) { + // Copy - don't destroy! Otherwise windowing will break. + target.frequency_map = new typename STATE::Counts(*source.frequency_map); + return; + } + for (auto &val : *source.frequency_map) { + auto &i = (*target.frequency_map)[val.first]; + i.count += val.second.count; + i.first_row = MinValue(i.first_row, val.second.first_row); + } + target.count += source.count; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.frequency_map) { + finalize_data.ReturnNull(); + return; + } + auto highest_frequency = state.Scan(); + if (highest_frequency != state.frequency_map->end()) { + target = ASSIGN_OP::template Assign(finalize_data.result, highest_frequency->first); + } else { + finalize_data.ReturnNull(); + } + } + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { + if (!state.frequency_map) { + state.frequency_map = new typename STATE::Counts(); + } + auto key = KEY_TYPE(input); + auto &i = (*state.frequency_map)[key]; + i.count += count; + i.first_row = MinValue(i.first_row, state.count); + state.count += count; + } + + template + static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, + AggregateInputData &, STATE &state, const FrameBounds &frame, const FrameBounds &prev, + Vector &result, idx_t rid, idx_t bias) { + auto rdata = FlatVector::GetData(result); + auto &rmask = FlatVector::Validity(result); + + ModeIncluded included(fmask, dmask, bias); + + if (!state.frequency_map) { + state.frequency_map = new typename STATE::Counts; + } + const double tau = .25; + if (state.nonzero <= tau * state.frequency_map->size() || prev.end <= frame.start || frame.end <= prev.start) { + state.Reset(); + // for f ∈ F do + for (auto f = frame.start; f < frame.end; ++f) { + if (included(f)) { + state.ModeAdd(KEY_TYPE(data[f]), f); + } + } + } else { + // for f ∈ P \ F do + for (auto p = prev.start; p < frame.start; ++p) { + if (included(p)) { + state.ModeRm(KEY_TYPE(data[p]), p); + } + } + for (auto p = frame.end; p < prev.end; ++p) { + if (included(p)) { + state.ModeRm(KEY_TYPE(data[p]), p); + } + } + + // for f ∈ F \ P do + for (auto f = frame.start; f < prev.start; ++f) { + if (included(f)) { + state.ModeAdd(KEY_TYPE(data[f]), f); + } + } + for (auto f = prev.end; f < frame.end; ++f) { + if (included(f)) { + state.ModeAdd(KEY_TYPE(data[f]), f); + } + } + } + + if (!state.valid) { + // Rescan + auto highest_frequency = state.Scan(); + if (highest_frequency != state.frequency_map->end()) { + *(state.mode) = highest_frequency->first; + state.count = highest_frequency->second.count; + state.valid = (state.count > 0); + } + } + + if (state.valid) { + rdata[rid] = ASSIGN_OP::template Assign(result, *state.mode); + } else { + rmask.Set(rid, false); + } + } + + static bool IgnoreNull() { + return true; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.Destroy(); + } +}; + +template +AggregateFunction GetTypedModeFunction(const LogicalType &type) { + using STATE = ModeState; + using OP = ModeFunction; + auto func = AggregateFunction::UnaryAggregateDestructor(type, type); + func.window = AggregateFunction::UnaryWindow; + return func; +} + +AggregateFunction GetModeAggregate(const LogicalType &type) { + switch (type.InternalType()) { + case PhysicalType::INT8: + return GetTypedModeFunction(type); + case PhysicalType::UINT8: + return GetTypedModeFunction(type); + case PhysicalType::INT16: + return GetTypedModeFunction(type); + case PhysicalType::UINT16: + return GetTypedModeFunction(type); + case PhysicalType::INT32: + return GetTypedModeFunction(type); + case PhysicalType::UINT32: + return GetTypedModeFunction(type); + case PhysicalType::INT64: + return GetTypedModeFunction(type); + case PhysicalType::UINT64: + return GetTypedModeFunction(type); + case PhysicalType::INT128: + return GetTypedModeFunction(type); + + case PhysicalType::FLOAT: + return GetTypedModeFunction(type); + case PhysicalType::DOUBLE: + return GetTypedModeFunction(type); + + case PhysicalType::INTERVAL: + return GetTypedModeFunction(type); + + case PhysicalType::VARCHAR: + return GetTypedModeFunction(type); + + default: + throw NotImplementedException("Unimplemented mode aggregate"); + } +} + +unique_ptr BindModeDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetModeAggregate(arguments[0]->return_type); + function.name = "mode"; + return nullptr; +} + +AggregateFunctionSet ModeFun::GetFunctions() { + const vector TEMPORAL = {LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::TIME, + LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ, LogicalType::INTERVAL}; + + AggregateFunctionSet mode; + mode.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, BindModeDecimal)); + + for (const auto &type : LogicalType::Numeric()) { + if (type.id() != LogicalTypeId::DECIMAL) { + mode.AddFunction(GetModeAggregate(type)); + } + } + + for (const auto &type : TEMPORAL) { + mode.AddFunction(GetModeAggregate(type)); + } + + mode.AddFunction(GetModeAggregate(LogicalType::VARCHAR)); + return mode; +} +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/src/core_functions/aggregate/holistic/quantile.cpp new file mode 100644 index 00000000..045ae65f --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/holistic/quantile.cpp @@ -0,0 +1,1468 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/abs.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include +#include +#include + +namespace duckdb { + +// Hugeint arithmetic +static hugeint_t MultiplyByDouble(const hugeint_t &h, const double &d) { + D_ASSERT(d >= 0 && d <= 1); + return Hugeint::Convert(Hugeint::Cast(h) * d); +} + +// Interval arithmetic +static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT + D_ASSERT(d >= 0 && d <= 1); + return Interval::FromMicro(std::llround(Interval::GetMicro(i) * d)); +} + +inline interval_t operator+(const interval_t &lhs, const interval_t &rhs) { + return Interval::FromMicro(Interval::GetMicro(lhs) + Interval::GetMicro(rhs)); +} + +inline interval_t operator-(const interval_t &lhs, const interval_t &rhs) { + return Interval::FromMicro(Interval::GetMicro(lhs) - Interval::GetMicro(rhs)); +} + +template +struct QuantileState { + using SaveType = SAVE_TYPE; + + // Regular aggregation + vector v; + + // Windowed Quantile indirection + vector w; + idx_t pos; + + // Windowed MAD indirection + vector m; + + QuantileState() : pos(0) { + } + + ~QuantileState() { + } + + inline void SetPos(size_t pos_p) { + pos = pos_p; + if (pos >= w.size()) { + w.resize(pos); + } + } +}; + +struct QuantileIncluded { + inline explicit QuantileIncluded(const ValidityMask &fmask_p, const ValidityMask &dmask_p, idx_t bias_p) + : fmask(fmask_p), dmask(dmask_p), bias(bias_p) { + } + + inline bool operator()(const idx_t &idx) const { + return fmask.RowIsValid(idx) && dmask.RowIsValid(idx - bias); + } + + inline bool AllValid() const { + return fmask.AllValid() && dmask.AllValid(); + } + + const ValidityMask &fmask; + const ValidityMask &dmask; + const idx_t bias; +}; + +void ReuseIndexes(idx_t *index, const FrameBounds &frame, const FrameBounds &prev) { + idx_t j = 0; + + // Copy overlapping indices + for (idx_t p = 0; p < (prev.end - prev.start); ++p) { + auto idx = index[p]; + + // Shift down into any hole + if (j != p) { + index[j] = idx; + } + + // Skip overlapping values + if (frame.start <= idx && idx < frame.end) { + ++j; + } + } + + // Insert new indices + if (j > 0) { + // Overlap: append the new ends + for (auto f = frame.start; f < prev.start; ++f, ++j) { + index[j] = f; + } + for (auto f = prev.end; f < frame.end; ++f, ++j) { + index[j] = f; + } + } else { + // No overlap: overwrite with new values + for (auto f = frame.start; f < frame.end; ++f, ++j) { + index[j] = f; + } + } +} + +static idx_t ReplaceIndex(idx_t *index, const FrameBounds &frame, const FrameBounds &prev) { // NOLINT + D_ASSERT(index); + + idx_t j = 0; + for (idx_t p = 0; p < (prev.end - prev.start); ++p) { + auto idx = index[p]; + if (j != p) { + break; + } + + if (frame.start <= idx && idx < frame.end) { + ++j; + } + } + index[j] = frame.end - 1; + + return j; +} + +template +static inline int CanReplace(const idx_t *index, const INPUT_TYPE *fdata, const idx_t j, const idx_t k0, const idx_t k1, + const QuantileIncluded &validity) { + D_ASSERT(index); + + // NULLs sort to the end, so if we have inserted a NULL, + // it must be past the end of the quantile to be replaceable. + // Note that the quantile values are never NULL. + const auto ij = index[j]; + if (!validity(ij)) { + return k1 < j ? 1 : 0; + } + + auto curr = fdata[ij]; + if (k1 < j) { + auto hi = fdata[index[k0]]; + return hi < curr ? 1 : 0; + } else if (j < k0) { + auto lo = fdata[index[k1]]; + return curr < lo ? -1 : 0; + } + + return 0; +} + +template +struct IndirectLess { + inline explicit IndirectLess(const INPUT_TYPE *inputs_p) : inputs(inputs_p) { + } + + inline bool operator()(const idx_t &lhi, const idx_t &rhi) const { + return inputs[lhi] < inputs[rhi]; + } + + const INPUT_TYPE *inputs; +}; + +struct CastInterpolation { + + template + static inline TARGET_TYPE Cast(const INPUT_TYPE &src, Vector &result) { + return Cast::Operation(src); + } + template + static inline TARGET_TYPE Interpolate(const TARGET_TYPE &lo, const double d, const TARGET_TYPE &hi) { + const auto delta = hi - lo; + return lo + delta * d; + } +}; + +template <> +interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result) { + return {0, 0, src.micros}; +} + +template <> +double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi) { + return lo * (1.0 - d) + hi * d; +} + +template <> +dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi) { + return dtime_t(std::llround(lo.micros * (1.0 - d) + hi.micros * d)); +} + +template <> +timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi) { + return timestamp_t(std::llround(lo.value * (1.0 - d) + hi.value * d)); +} + +template <> +hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi) { + const hugeint_t delta = hi - lo; + return lo + MultiplyByDouble(delta, d); +} + +template <> +interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi) { + const interval_t delta = hi - lo; + return lo + MultiplyByDouble(delta, d); +} + +template <> +string_t CastInterpolation::Cast(const std::string &src, Vector &result) { + return StringVector::AddString(result, src); +} + +template <> +string_t CastInterpolation::Cast(const string_t &src, Vector &result) { + return StringVector::AddString(result, src); +} + +// Direct access +template +struct QuantileDirect { + using INPUT_TYPE = T; + using RESULT_TYPE = T; + + inline const INPUT_TYPE &operator()(const INPUT_TYPE &x) const { + return x; + } +}; + +// Indirect access +template +struct QuantileIndirect { + using INPUT_TYPE = idx_t; + using RESULT_TYPE = T; + const RESULT_TYPE *data; + + explicit QuantileIndirect(const RESULT_TYPE *data_p) : data(data_p) { + } + + inline RESULT_TYPE operator()(const idx_t &input) const { + return data[input]; + } +}; + +// Composed access +template +struct QuantileComposed { + using INPUT_TYPE = typename INNER::INPUT_TYPE; + using RESULT_TYPE = typename OUTER::RESULT_TYPE; + + const OUTER &outer; + const INNER &inner; + + explicit QuantileComposed(const OUTER &outer_p, const INNER &inner_p) : outer(outer_p), inner(inner_p) { + } + + inline RESULT_TYPE operator()(const idx_t &input) const { + return outer(inner(input)); + } +}; + +// Accessed comparison +template +struct QuantileCompare { + using INPUT_TYPE = typename ACCESSOR::INPUT_TYPE; + const ACCESSOR &accessor; + const bool desc; + explicit QuantileCompare(const ACCESSOR &accessor_p, bool desc_p) : accessor(accessor_p), desc(desc_p) { + } + + inline bool operator()(const INPUT_TYPE &lhs, const INPUT_TYPE &rhs) const { + const auto lval = accessor(lhs); + const auto rval = accessor(rhs); + + return desc ? (rval < lval) : (lval < rval); + } +}; + +// Avoid using naked Values in inner loops... +struct QuantileValue { + explicit QuantileValue(const Value &v) : val(v), dbl(v.GetValue()) { + const auto &type = val.type(); + switch (type.id()) { + case LogicalTypeId::DECIMAL: { + integral = IntegralValue::Get(v); + scaling = Hugeint::POWERS_OF_TEN[DecimalType::GetScale(type)]; + break; + } + default: + break; + } + } + + Value val; + + // DOUBLE + double dbl; + + // DECIMAL + hugeint_t integral; + hugeint_t scaling; +}; + +bool operator==(const QuantileValue &x, const QuantileValue &y) { + return x.val == y.val; +} + +// Continuous interpolation +template +struct Interpolator { + Interpolator(const QuantileValue &q, const idx_t n_p, const bool desc_p) + : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(floor(RN)), CRN(ceil(RN)), begin(0), end(n_p) { + } + + template > + TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + QuantileCompare comp(accessor, desc); + if (CRN == FRN) { + std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); + return CastInterpolation::Cast(accessor(v_t[FRN]), result); + } else { + std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); + std::nth_element(v_t + FRN, v_t + CRN, v_t + end, comp); + auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); + auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); + return CastInterpolation::Interpolate(lo, RN - FRN, hi); + } + } + + template > + TARGET_TYPE Replace(const INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + if (CRN == FRN) { + return CastInterpolation::Cast(accessor(v_t[FRN]), result); + } else { + auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); + auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); + return CastInterpolation::Interpolate(lo, RN - FRN, hi); + } + } + + const bool desc; + const double RN; + const idx_t FRN; + const idx_t CRN; + + idx_t begin; + idx_t end; +}; + +// Discrete "interpolation" +template <> +struct Interpolator { + static inline idx_t Index(const QuantileValue &q, const idx_t n) { + idx_t floored; + switch (q.val.type().id()) { + case LogicalTypeId::DECIMAL: { + // Integer arithmetic for accuracy + const auto integral = q.integral; + const auto scaling = q.scaling; + const auto scaled_q = DecimalMultiplyOverflowCheck::Operation(n, integral); + const auto scaled_n = DecimalMultiplyOverflowCheck::Operation(n, scaling); + floored = Cast::Operation((scaled_n - scaled_q) / scaling); + break; + } + default: + const auto scaled_q = (double)(n * q.dbl); + floored = floor(n - scaled_q); + break; + } + + return MaxValue(1, n - floored) - 1; + } + + Interpolator(const QuantileValue &q, const idx_t n_p, bool desc_p) + : desc(desc_p), FRN(Index(q, n_p)), CRN(FRN), begin(0), end(n_p) { + } + + template > + TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + QuantileCompare comp(accessor, desc); + std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); + return CastInterpolation::Cast(accessor(v_t[FRN]), result); + } + + template > + TARGET_TYPE Replace(const INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + return CastInterpolation::Cast(accessor(v_t[FRN]), result); + } + + const bool desc; + const idx_t FRN; + const idx_t CRN; + + idx_t begin; + idx_t end; +}; + +template +static inline T QuantileAbs(const T &t) { + return AbsOperator::Operation(t); +} + +template <> +inline Value QuantileAbs(const Value &v) { + const auto &type = v.type(); + switch (type.id()) { + case LogicalTypeId::DECIMAL: { + const auto integral = IntegralValue::Get(v); + const auto width = DecimalType::GetWidth(type); + const auto scale = DecimalType::GetScale(type); + switch (type.InternalType()) { + case PhysicalType::INT16: + return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); + case PhysicalType::INT32: + return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); + case PhysicalType::INT64: + return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); + case PhysicalType::INT128: + return Value::DECIMAL(QuantileAbs(integral), width, scale); + default: + throw InternalException("Unknown DECIMAL type"); + } + } + default: + return Value::DOUBLE(QuantileAbs(v.GetValue())); + } +} + +struct QuantileBindData : public FunctionData { + QuantileBindData() { + } + + explicit QuantileBindData(const Value &quantile_p) + : quantiles(1, QuantileValue(QuantileAbs(quantile_p))), order(1, 0), desc(quantile_p < 0) { + } + + explicit QuantileBindData(const vector &quantiles_p) { + vector normalised; + size_t pos = 0; + size_t neg = 0; + for (idx_t i = 0; i < quantiles_p.size(); ++i) { + const auto &q = quantiles_p[i]; + pos += (q > 0); + neg += (q < 0); + normalised.emplace_back(QuantileAbs(q)); + order.push_back(i); + } + if (pos && neg) { + throw BinderException("QUANTILE parameters must have consistent signs"); + } + desc = (neg > 0); + + IndirectLess lt(normalised.data()); + std::sort(order.begin(), order.end(), lt); + + for (const auto &q : normalised) { + quantiles.emplace_back(QuantileValue(q)); + } + } + + QuantileBindData(const QuantileBindData &other) : order(other.order), desc(other.desc) { + for (const auto &q : other.quantiles) { + quantiles.emplace_back(q); + } + } + + unique_ptr Copy() const override { + return make_uniq(*this); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return desc == other.desc && quantiles == other.quantiles && order == other.order; + } + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + vector raw; + for (const auto &q : bind_data.quantiles) { + raw.emplace_back(q.val); + } + serializer.WriteProperty(100, "quantiles", raw); + serializer.WriteProperty(101, "order", bind_data.order); + serializer.WriteProperty(102, "desc", bind_data.desc); + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + vector raw; + deserializer.ReadProperty(100, "quantiles", raw); + deserializer.ReadProperty(101, "order", result->order); + deserializer.ReadProperty(102, "desc", result->desc); + for (const auto &r : raw) { + result->quantiles.emplace_back(QuantileValue(r)); + } + return std::move(result); + } + + static void SerializeDecimal(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + throw NotImplementedException("FIXME: serializing quantiles with decimals is not supported right now"); + } + + vector quantiles; + vector order; + bool desc; +}; + +struct QuantileOperation { + template + static void Initialize(STATE &state) { + new (&state) STATE(); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + state.v.emplace_back(input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.v.empty()) { + return; + } + target.v.insert(target.v.end(), source.v.begin(), source.v.end()); + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.~STATE(); + } + + static bool IgnoreNull() { + return true; + } +}; + +template +static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT + LogicalType result_type = LogicalType::LIST(child_type); + return AggregateFunction( + {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, + nullptr, AggregateFunction::StateDestroy); +} + +template +struct QuantileScalarOperation : public QuantileOperation { + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); + target = interp.template Operation(state.v.data(), finalize_data.result); + } + + template + static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, + AggregateInputData &aggr_input_data, STATE &state, const FrameBounds &frame, + const FrameBounds &prev, Vector &result, idx_t ridx, idx_t bias) { + auto rdata = FlatVector::GetData(result); + auto &rmask = FlatVector::Validity(result); + + QuantileIncluded included(fmask, dmask, bias); + + // Lazily initialise frame state + auto prev_pos = state.pos; + state.SetPos(frame.end - frame.start); + + auto index = state.w.data(); + D_ASSERT(index); + + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); + + // Find the two positions needed + const auto &q = bind_data.quantiles[0]; + + bool replace = false; + if (frame.start == prev.start + 1 && frame.end == prev.end + 1) { + // Fixed frame size + const auto j = ReplaceIndex(index, frame, prev); + // We can only replace if the number of NULLs has not changed + if (included.AllValid() || included(prev.start) == included(prev.end)) { + Interpolator interp(q, prev_pos, false); + replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); + if (replace) { + state.pos = prev_pos; + } + } + } else { + ReuseIndexes(index, frame, prev); + } + + if (!replace && !included.AllValid()) { + // Remove the NULLs + state.pos = std::partition(index, index + state.pos, included) - index; + } + if (state.pos) { + Interpolator interp(q, state.pos, false); + + using ID = QuantileIndirect; + ID indirect(data); + rdata[ridx] = replace ? interp.template Replace(index, result, indirect) + : interp.template Operation(index, result, indirect); + } else { + rmask.Set(ridx, false); + } + } +}; + +template +AggregateFunction GetTypedDiscreteQuantileAggregateFunction(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileScalarOperation; + auto fun = AggregateFunction::UnaryAggregateDestructor(type, type); + fun.window = AggregateFunction::UnaryWindow; + return fun; +} + +AggregateFunction GetDiscreteQuantileAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::SMALLINT: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::INTEGER: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::BIGINT: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::HUGEINT: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::FLOAT: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::DOUBLE: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedDiscreteQuantileAggregateFunction(type); + case PhysicalType::INT32: + return GetTypedDiscreteQuantileAggregateFunction(type); + case PhysicalType::INT64: + return GetTypedDiscreteQuantileAggregateFunction(type); + case PhysicalType::INT128: + return GetTypedDiscreteQuantileAggregateFunction(type); + default: + throw NotImplementedException("Unimplemented discrete quantile aggregate"); + } + case LogicalTypeId::DATE: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return GetTypedDiscreteQuantileAggregateFunction(type); + case LogicalTypeId::INTERVAL: + return GetTypedDiscreteQuantileAggregateFunction(type); + + case LogicalTypeId::VARCHAR: + return GetTypedDiscreteQuantileAggregateFunction(type); + + default: + throw NotImplementedException("Unimplemented discrete quantile aggregate"); + } +} + +template +struct QuantileListOperation : public QuantileOperation { + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); + auto rdata = FlatVector::GetData(result); + + auto v_t = state.v.data(); + D_ASSERT(v_t); + + auto &entry = target; + entry.offset = ridx; + idx_t lower = 0; + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, state.v.size(), bind_data.desc); + interp.begin = lower; + rdata[ridx + q] = interp.template Operation(v_t, result); + lower = interp.FRN; + } + entry.length = bind_data.quantiles.size(); + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); + } + + template + static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, + AggregateInputData &aggr_input_data, STATE &state, const FrameBounds &frame, + const FrameBounds &prev, Vector &list, idx_t lidx, idx_t bias) { + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); + + QuantileIncluded included(fmask, dmask, bias); + + // Result is a constant LIST with a fixed length + auto ldata = FlatVector::GetData(list); + auto &lmask = FlatVector::Validity(list); + auto &lentry = ldata[lidx]; + lentry.offset = ListVector::GetListSize(list); + lentry.length = bind_data.quantiles.size(); + + ListVector::Reserve(list, lentry.offset + lentry.length); + ListVector::SetListSize(list, lentry.offset + lentry.length); + auto &result = ListVector::GetEntry(list); + auto rdata = FlatVector::GetData(result); + + // Lazily initialise frame state + auto prev_pos = state.pos; + state.SetPos(frame.end - frame.start); + + auto index = state.w.data(); + + // We can generalise replacement for quantile lists by observing that when a replacement is + // valid for a single quantile, it is valid for all quantiles greater/less than that quantile + // based on whether the insertion is below/above the quantile location. + // So if a replaced index in an IQR is located between Q25 and Q50, but has a value below Q25, + // then Q25 must be recomputed, but Q50 and Q75 are unaffected. + // For a single element list, this reduces to the scalar case. + std::pair replaceable {state.pos, 0}; + if (frame.start == prev.start + 1 && frame.end == prev.end + 1) { + // Fixed frame size + const auto j = ReplaceIndex(index, frame, prev); + // We can only replace if the number of NULLs has not changed + if (included.AllValid() || included(prev.start) == included(prev.end)) { + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, prev_pos, false); + const auto replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); + if (replace < 0) { + // Replacement is before this quantile, so the rest will be replaceable too. + replaceable.first = MinValue(replaceable.first, interp.FRN); + replaceable.second = prev_pos; + break; + } else if (replace > 0) { + // Replacement is after this quantile, so everything before it is replaceable too. + replaceable.first = 0; + replaceable.second = MaxValue(replaceable.second, interp.CRN); + } + } + if (replaceable.first < replaceable.second) { + state.pos = prev_pos; + } + } + } else { + ReuseIndexes(index, frame, prev); + } + + if (replaceable.first >= replaceable.second && !included.AllValid()) { + // Remove the NULLs + state.pos = std::partition(index, index + state.pos, included) - index; + } + + if (state.pos) { + using ID = QuantileIndirect; + ID indirect(data); + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, state.pos, false); + if (replaceable.first <= interp.FRN && interp.CRN <= replaceable.second) { + rdata[lentry.offset + q] = interp.template Replace(index, result, indirect); + } else { + // Make sure we don't disturb any replacements + if (replaceable.first < replaceable.second) { + if (interp.FRN < replaceable.first) { + interp.end = replaceable.first; + } + if (replaceable.second < interp.CRN) { + interp.begin = replaceable.second; + } + } + rdata[lentry.offset + q] = + interp.template Operation(index, result, indirect); + } + } + } else { + lmask.Set(lidx, false); + } + } +}; + +template +AggregateFunction GetTypedDiscreteQuantileListAggregateFunction(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileListOperation; + auto fun = QuantileListAggregate(type, type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + return fun; +} + +AggregateFunction GetDiscreteQuantileListAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::SMALLINT: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::INTEGER: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::BIGINT: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::HUGEINT: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::FLOAT: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::DOUBLE: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case PhysicalType::INT32: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case PhysicalType::INT64: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case PhysicalType::INT128: + return GetTypedDiscreteQuantileListAggregateFunction(type); + default: + throw NotImplementedException("Unimplemented discrete quantile list aggregate"); + } + case LogicalTypeId::DATE: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::INTERVAL: + return GetTypedDiscreteQuantileListAggregateFunction(type); + case LogicalTypeId::VARCHAR: + return GetTypedDiscreteQuantileListAggregateFunction(type); + default: + throw NotImplementedException("Unimplemented discrete quantile list aggregate"); + } +} + +template +AggregateFunction GetTypedContinuousQuantileAggregateFunction(const LogicalType &input_type, + const LogicalType &target_type) { + using STATE = QuantileState; + using OP = QuantileScalarOperation; + auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + return fun; +} + +AggregateFunction GetContinuousQuantileAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::SMALLINT: + return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::INTEGER: + return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::BIGINT: + return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::HUGEINT: + return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::FLOAT: + return GetTypedContinuousQuantileAggregateFunction(type, type); + case LogicalTypeId::DOUBLE: + return GetTypedContinuousQuantileAggregateFunction(type, type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedContinuousQuantileAggregateFunction(type, type); + case PhysicalType::INT32: + return GetTypedContinuousQuantileAggregateFunction(type, type); + case PhysicalType::INT64: + return GetTypedContinuousQuantileAggregateFunction(type, type); + case PhysicalType::INT128: + return GetTypedContinuousQuantileAggregateFunction(type, type); + default: + throw NotImplementedException("Unimplemented continuous quantile DECIMAL aggregate"); + } + case LogicalTypeId::DATE: + return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::TIMESTAMP); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedContinuousQuantileAggregateFunction(type, type); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return GetTypedContinuousQuantileAggregateFunction(type, type); + + default: + throw NotImplementedException("Unimplemented continuous quantile aggregate"); + } +} + +template +AggregateFunction GetTypedContinuousQuantileListAggregateFunction(const LogicalType &input_type, + const LogicalType &result_type) { + using STATE = QuantileState; + using OP = QuantileListOperation; + auto fun = QuantileListAggregate(input_type, result_type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + return fun; +} + +AggregateFunction GetContinuousQuantileListAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::SMALLINT: + return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::INTEGER: + return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::BIGINT: + return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::HUGEINT: + return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); + + case LogicalTypeId::FLOAT: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + case LogicalTypeId::DOUBLE: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + case PhysicalType::INT32: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + case PhysicalType::INT64: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + case PhysicalType::INT128: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + default: + throw NotImplementedException("Unimplemented discrete quantile DECIMAL list aggregate"); + } + break; + + case LogicalTypeId::DATE: + return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::TIMESTAMP); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return GetTypedContinuousQuantileListAggregateFunction(type, type); + + default: + throw NotImplementedException("Unimplemented discrete quantile list aggregate"); + } +} + +template +struct MadAccessor { + using INPUT_TYPE = T; + using RESULT_TYPE = R; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = input - median; + return TryAbsOperator::Operation(delta); + } +}; + +// hugeint_t - double => undefined +template <> +struct MadAccessor { + using INPUT_TYPE = hugeint_t; + using RESULT_TYPE = double; + using MEDIAN_TYPE = double; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = Hugeint::Cast(input) - median; + return TryAbsOperator::Operation(delta); + } +}; + +// date_t - timestamp_t => interval_t +template <> +struct MadAccessor { + using INPUT_TYPE = date_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = timestamp_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto dt = Cast::Operation(input); + const auto delta = dt - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +// timestamp_t - timestamp_t => int64_t +template <> +struct MadAccessor { + using INPUT_TYPE = timestamp_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = timestamp_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = input - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +// dtime_t - dtime_t => int64_t +template <> +struct MadAccessor { + using INPUT_TYPE = dtime_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = dtime_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = input - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +template +struct MedianAbsoluteDeviationOperation : public QuantileOperation { + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + using SAVE_TYPE = typename STATE::SaveType; + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + const auto &q = bind_data.quantiles[0]; + Interpolator interp(q, state.v.size(), false); + const auto med = interp.template Operation(state.v.data(), finalize_data.result); + + MadAccessor accessor(med); + target = interp.template Operation(state.v.data(), finalize_data.result, accessor); + } + + template + static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, + AggregateInputData &aggr_input_data, STATE &state, const FrameBounds &frame, + const FrameBounds &prev, Vector &result, idx_t ridx, idx_t bias) { + auto rdata = FlatVector::GetData(result); + auto &rmask = FlatVector::Validity(result); + + QuantileIncluded included(fmask, dmask, bias); + + // Lazily initialise frame state + auto prev_pos = state.pos; + state.SetPos(frame.end - frame.start); + + auto index = state.w.data(); + D_ASSERT(index); + + // We need a second index for the second pass. + if (state.pos > state.m.size()) { + state.m.resize(state.pos); + } + + auto index2 = state.m.data(); + D_ASSERT(index2); + + // The replacement trick does not work on the second index because if + // the median has changed, the previous order is not correct. + // It is probably close, however, and so reuse is helpful. + ReuseIndexes(index2, frame, prev); + std::partition(index2, index2 + state.pos, included); + + // Find the two positions needed for the median + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + const auto &q = bind_data.quantiles[0]; + + bool replace = false; + if (frame.start == prev.start + 1 && frame.end == prev.end + 1) { + // Fixed frame size + const auto j = ReplaceIndex(index, frame, prev); + // We can only replace if the number of NULLs has not changed + if (included.AllValid() || included(prev.start) == included(prev.end)) { + Interpolator interp(q, prev_pos, false); + replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); + if (replace) { + state.pos = prev_pos; + } + } + } else { + ReuseIndexes(index, frame, prev); + } + + if (!replace && !included.AllValid()) { + // Remove the NULLs + state.pos = std::partition(index, index + state.pos, included) - index; + } + + if (state.pos) { + Interpolator interp(q, state.pos, false); + + // Compute or replace median from the first index + using ID = QuantileIndirect; + ID indirect(data); + const auto med = replace ? interp.template Replace(index, result, indirect) + : interp.template Operation(index, result, indirect); + + // Compute mad from the second index + using MAD = MadAccessor; + MAD mad(med); + + using MadIndirect = QuantileComposed; + MadIndirect mad_indirect(mad, indirect); + rdata[ridx] = interp.template Operation(index2, result, mad_indirect); + } else { + rmask.Set(ridx, false); + } + } +}; + +unique_ptr BindMedian(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); +} + +template +AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type, + const LogicalType &target_type) { + using STATE = QuantileState; + using OP = MedianAbsoluteDeviationOperation; + auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); + fun.bind = BindMedian; + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + return fun; +} + +AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::FLOAT: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case LogicalTypeId::DOUBLE: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT32: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT64: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT128: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + default: + throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate"); + } + break; + + case LogicalTypeId::DATE: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, + LogicalType::INTERVAL); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedMedianAbsoluteDeviationAggregateFunction( + type, LogicalType::INTERVAL); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, + LogicalType::INTERVAL); + + default: + throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate"); + } +} + +unique_ptr BindMedianDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindMedian(context, function, arguments); + + function = GetDiscreteQuantileAggregateFunction(arguments[0]->return_type); + function.name = "median"; + function.serialize = QuantileBindData::SerializeDecimal; + function.deserialize = QuantileBindData::Deserialize; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return bind_data; +} + +unique_ptr BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type); + function.name = "mad"; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return BindMedian(context, function, arguments); +} + +static const Value &CheckQuantile(const Value &quantile_val) { + if (quantile_val.IsNull()) { + throw BinderException("QUANTILE parameter cannot be NULL"); + } + auto quantile = quantile_val.GetValue(); + if (quantile < -1 || quantile > 1) { + throw BinderException("QUANTILE can only take parameters in the range [-1, 1]"); + } + if (Value::IsNan(quantile)) { + throw BinderException("QUANTILE parameter cannot be NaN"); + } + + return quantile_val; +} + +unique_ptr BindQuantile(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("QUANTILE can only take constant parameters"); + } + Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + vector quantiles; + if (quantile_val.type().id() != LogicalTypeId::LIST) { + quantiles.push_back(CheckQuantile(quantile_val)); + } else { + for (const auto &element_val : ListValue::GetChildren(quantile_val)) { + quantiles.push_back(CheckQuantile(element_val)); + } + } + + Function::EraseArgument(function, arguments, arguments.size() - 1); + return make_uniq(quantiles); +} + +unique_ptr BindDiscreteQuantileDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindQuantile(context, function, arguments); + function = GetDiscreteQuantileAggregateFunction(arguments[0]->return_type); + function.name = "quantile_disc"; + function.serialize = QuantileBindData::SerializeDecimal; + function.deserialize = QuantileBindData::Deserialize; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return bind_data; +} + +unique_ptr BindDiscreteQuantileDecimalList(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindQuantile(context, function, arguments); + function = GetDiscreteQuantileListAggregateFunction(arguments[0]->return_type); + function.name = "quantile_disc"; + function.serialize = QuantileBindData::SerializeDecimal; + function.deserialize = QuantileBindData::Deserialize; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return bind_data; +} + +unique_ptr BindContinuousQuantileDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindQuantile(context, function, arguments); + function = GetContinuousQuantileAggregateFunction(arguments[0]->return_type); + function.name = "quantile_cont"; + function.serialize = QuantileBindData::SerializeDecimal; + function.deserialize = QuantileBindData::Deserialize; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return bind_data; +} + +unique_ptr BindContinuousQuantileDecimalList(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindQuantile(context, function, arguments); + function = GetContinuousQuantileListAggregateFunction(arguments[0]->return_type); + function.name = "quantile_cont"; + function.serialize = QuantileBindData::SerializeDecimal; + function.deserialize = QuantileBindData::Deserialize; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return bind_data; +} + +static bool CanInterpolate(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::INTERVAL: + case LogicalTypeId::VARCHAR: + return false; + default: + return true; + } +} + +AggregateFunction GetMedianAggregate(const LogicalType &type) { + auto fun = CanInterpolate(type) ? GetContinuousQuantileAggregateFunction(type) + : GetDiscreteQuantileAggregateFunction(type); + fun.bind = BindMedian; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = QuantileBindData::Deserialize; + return fun; +} + +AggregateFunction GetDiscreteQuantileAggregate(const LogicalType &type) { + auto fun = GetDiscreteQuantileAggregateFunction(type); + fun.bind = BindQuantile; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = QuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::DOUBLE); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +AggregateFunction GetDiscreteQuantileListAggregate(const LogicalType &type) { + auto fun = GetDiscreteQuantileListAggregateFunction(type); + fun.bind = BindQuantile; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = QuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); + fun.arguments.push_back(list_of_double); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +AggregateFunction GetContinuousQuantileAggregate(const LogicalType &type) { + auto fun = GetContinuousQuantileAggregateFunction(type); + fun.bind = BindQuantile; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = QuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::DOUBLE); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +AggregateFunction GetContinuousQuantileListAggregate(const LogicalType &type) { + auto fun = GetContinuousQuantileListAggregateFunction(type); + fun.bind = BindQuantile; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = QuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); + fun.arguments.push_back(list_of_double); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +AggregateFunction GetQuantileDecimalAggregate(const vector &arguments, const LogicalType &return_type, + bind_aggregate_function_t bind) { + AggregateFunction fun(arguments, return_type, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, bind); + fun.bind = bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = QuantileBindData::Deserialize; + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +vector GetQuantileTypes() { + return {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT, + LogicalType::HUGEINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, + LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ, + LogicalType::INTERVAL, LogicalType::VARCHAR}; +} + +AggregateFunctionSet MedianFun::GetFunctions() { + AggregateFunctionSet median("median"); + median.AddFunction( + GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, BindMedianDecimal)); + for (const auto &type : GetQuantileTypes()) { + median.AddFunction(GetMedianAggregate(type)); + } + return median; +} + +AggregateFunctionSet QuantileDiscFun::GetFunctions() { + AggregateFunctionSet quantile_disc("quantile_disc"); + quantile_disc.AddFunction(GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, + LogicalTypeId::DECIMAL, BindDiscreteQuantileDecimal)); + quantile_disc.AddFunction( + GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, + LogicalType::LIST(LogicalTypeId::DECIMAL), BindDiscreteQuantileDecimalList)); + for (const auto &type : GetQuantileTypes()) { + quantile_disc.AddFunction(GetDiscreteQuantileAggregate(type)); + quantile_disc.AddFunction(GetDiscreteQuantileListAggregate(type)); + } + return quantile_disc; + // quantile +} + +AggregateFunctionSet QuantileContFun::GetFunctions() { + AggregateFunctionSet quantile_cont("quantile_cont"); + quantile_cont.AddFunction(GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, + LogicalTypeId::DECIMAL, BindContinuousQuantileDecimal)); + quantile_cont.AddFunction( + GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, + LogicalType::LIST(LogicalTypeId::DECIMAL), BindContinuousQuantileDecimalList)); + + for (const auto &type : GetQuantileTypes()) { + if (CanInterpolate(type)) { + quantile_cont.AddFunction(GetContinuousQuantileAggregate(type)); + quantile_cont.AddFunction(GetContinuousQuantileListAggregate(type)); + } + } + return quantile_cont; +} + +AggregateFunctionSet MadFun::GetFunctions() { + AggregateFunctionSet mad("mad"); + mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, BindMedianAbsoluteDeviationDecimal)); + + const vector MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, + LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, + LogicalType::TIME_TZ}; + for (const auto &type : MAD_TYPES) { + mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type)); + } + return mad; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/duckdb/src/core_functions/aggregate/holistic/reservoir_quantile.cpp new file mode 100644 index 00000000..8945948e --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/holistic/reservoir_quantile.cpp @@ -0,0 +1,443 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/reservoir_sample.hpp" +#include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include +#include + +namespace duckdb { + +template +struct ReservoirQuantileState { + T *v; + idx_t len; + idx_t pos; + BaseReservoirSampling *r_samp; + + void Resize(idx_t new_len) { + if (new_len <= len) { + return; + } + T *old_v = v; + v = (T *)realloc(v, new_len * sizeof(T)); + if (!v) { + free(old_v); + throw InternalException("Memory allocation failure"); + } + len = new_len; + } + + void ReplaceElement(T &input) { + v[r_samp->min_entry] = input; + r_samp->ReplaceElement(); + } + + void FillReservoir(idx_t sample_size, T element) { + if (pos < sample_size) { + v[pos++] = element; + r_samp->InitializeReservoir(pos, len); + } else { + D_ASSERT(r_samp->next_index >= r_samp->current_count); + if (r_samp->next_index == r_samp->current_count) { + ReplaceElement(element); + } + } + } +}; + +struct ReservoirQuantileBindData : public FunctionData { + ReservoirQuantileBindData() { + } + ReservoirQuantileBindData(double quantile_p, int32_t sample_size_p) + : quantiles(1, quantile_p), sample_size(sample_size_p) { + } + + ReservoirQuantileBindData(vector quantiles_p, int32_t sample_size_p) + : quantiles(std::move(quantiles_p)), sample_size(sample_size_p) { + } + + unique_ptr Copy() const override { + return make_uniq(quantiles, sample_size); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return quantiles == other.quantiles && sample_size == other.sample_size; + } + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "quantiles", bind_data.quantiles); + serializer.WriteProperty(101, "sample_size", bind_data.sample_size); + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + deserializer.ReadProperty(100, "quantiles", result->quantiles); + deserializer.ReadProperty(101, "sample_size", result->sample_size); + return std::move(result); + } + + vector quantiles; + int32_t sample_size; +}; + +struct ReservoirQuantileOperation { + template + static void Initialize(STATE &state) { + state.v = nullptr; + state.len = 0; + state.pos = 0; + state.r_samp = nullptr; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + auto &bind_data = unary_input.input.bind_data->template Cast(); + if (state.pos == 0) { + state.Resize(bind_data.sample_size); + } + if (!state.r_samp) { + state.r_samp = new BaseReservoirSampling(); + } + D_ASSERT(state.v); + state.FillReservoir(bind_data.sample_size, input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.pos == 0) { + return; + } + if (target.pos == 0) { + target.Resize(source.len); + } + if (!target.r_samp) { + target.r_samp = new BaseReservoirSampling(); + } + for (idx_t src_idx = 0; src_idx < source.pos; src_idx++) { + target.FillReservoir(target.len, source.v[src_idx]); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.v) { + free(state.v); + state.v = nullptr; + } + if (state.r_samp) { + delete state.r_samp; + state.r_samp = nullptr; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct ReservoirQuantileScalarOperation : public ReservoirQuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(state.v); + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + auto v_t = state.v; + D_ASSERT(bind_data.quantiles.size() == 1); + auto offset = (idx_t)((double)(state.pos - 1) * bind_data.quantiles[0]); + std::nth_element(v_t, v_t + offset, v_t + state.pos); + target = v_t[offset]; + } +}; + +AggregateFunction GetReservoirQuantileAggregateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::INT8: + return AggregateFunction::UnaryAggregateDestructor, int8_t, int8_t, + ReservoirQuantileScalarOperation>(LogicalType::TINYINT, + LogicalType::TINYINT); + + case PhysicalType::INT16: + return AggregateFunction::UnaryAggregateDestructor, int16_t, int16_t, + ReservoirQuantileScalarOperation>(LogicalType::SMALLINT, + LogicalType::SMALLINT); + + case PhysicalType::INT32: + return AggregateFunction::UnaryAggregateDestructor, int32_t, int32_t, + ReservoirQuantileScalarOperation>(LogicalType::INTEGER, + LogicalType::INTEGER); + + case PhysicalType::INT64: + return AggregateFunction::UnaryAggregateDestructor, int64_t, int64_t, + ReservoirQuantileScalarOperation>(LogicalType::BIGINT, + LogicalType::BIGINT); + + case PhysicalType::INT128: + return AggregateFunction::UnaryAggregateDestructor, hugeint_t, hugeint_t, + ReservoirQuantileScalarOperation>(LogicalType::HUGEINT, + LogicalType::HUGEINT); + case PhysicalType::FLOAT: + return AggregateFunction::UnaryAggregateDestructor, float, float, + ReservoirQuantileScalarOperation>(LogicalType::FLOAT, + LogicalType::FLOAT); + case PhysicalType::DOUBLE: + return AggregateFunction::UnaryAggregateDestructor, double, double, + ReservoirQuantileScalarOperation>(LogicalType::DOUBLE, + LogicalType::DOUBLE); + default: + throw InternalException("Unimplemented reservoir quantile aggregate"); + } +} + +template +struct ReservoirQuantileListOperation : public ReservoirQuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); + auto rdata = FlatVector::GetData(result); + + auto v_t = state.v; + D_ASSERT(v_t); + + auto &entry = target; + entry.offset = ridx; + entry.length = bind_data.quantiles.size(); + for (size_t q = 0; q < entry.length; ++q) { + const auto &quantile = bind_data.quantiles[q]; + auto offset = (idx_t)((double)(state.pos - 1) * quantile); + std::nth_element(v_t, v_t + offset, v_t + state.pos); + rdata[ridx + q] = v_t[offset]; + } + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); + } +}; + +template +static AggregateFunction ReservoirQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { + LogicalType result_type = LogicalType::LIST(child_type); + return AggregateFunction( + {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, + nullptr, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetTypedReservoirQuantileListAggregateFunction(const LogicalType &type) { + using STATE = ReservoirQuantileState; + using OP = ReservoirQuantileListOperation; + auto fun = ReservoirQuantileListAggregate(type, type); + return fun; +} + +AggregateFunction GetReservoirQuantileListAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::SMALLINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::INTEGER: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::BIGINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::HUGEINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::FLOAT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::DOUBLE: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedReservoirQuantileListAggregateFunction(type); + case PhysicalType::INT32: + return GetTypedReservoirQuantileListAggregateFunction(type); + case PhysicalType::INT64: + return GetTypedReservoirQuantileListAggregateFunction(type); + case PhysicalType::INT128: + return GetTypedReservoirQuantileListAggregateFunction(type); + default: + throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); + } + default: + // TODO: Add quantitative temporal types + throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); + } +} + +static double CheckReservoirQuantile(const Value &quantile_val) { + if (quantile_val.IsNull()) { + throw BinderException("RESERVOIR_QUANTILE QUANTILE parameter cannot be NULL"); + } + auto quantile = quantile_val.GetValue(); + if (quantile < 0 || quantile > 1) { + throw BinderException("RESERVOIR_QUANTILE can only take parameters in the range [0, 1]"); + } + return quantile; +} + +unique_ptr BindReservoirQuantile(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + D_ASSERT(arguments.size() >= 2); + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("RESERVOIR_QUANTILE can only take constant quantile parameters"); + } + Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + vector quantiles; + if (quantile_val.type().id() != LogicalTypeId::LIST) { + quantiles.push_back(CheckReservoirQuantile(quantile_val)); + } else { + for (const auto &element_val : ListValue::GetChildren(quantile_val)) { + quantiles.push_back(CheckReservoirQuantile(element_val)); + } + } + + if (arguments.size() == 2) { + if (function.arguments.size() == 2) { + Function::EraseArgument(function, arguments, arguments.size() - 1); + } else { + arguments.pop_back(); + } + return make_uniq(quantiles, 8192); + } + if (!arguments[2]->IsFoldable()) { + throw BinderException("RESERVOIR_QUANTILE can only take constant sample size parameters"); + } + Value sample_size_val = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); + if (sample_size_val.IsNull()) { + throw BinderException("Size of the RESERVOIR_QUANTILE sample cannot be NULL"); + } + auto sample_size = sample_size_val.GetValue(); + + if (sample_size_val.IsNull() || sample_size <= 0) { + throw BinderException("Size of the RESERVOIR_QUANTILE sample must be bigger than 0"); + } + + // remove the quantile argument so we can use the unary aggregate + Function::EraseArgument(function, arguments, arguments.size() - 1); + Function::EraseArgument(function, arguments, arguments.size() - 1); + return make_uniq(quantiles, sample_size); +} + +unique_ptr BindReservoirQuantileDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetReservoirQuantileAggregateFunction(arguments[0]->return_type.InternalType()); + auto bind_data = BindReservoirQuantile(context, function, arguments); + function.name = "reservoir_quantile"; + function.serialize = ReservoirQuantileBindData::Serialize; + function.deserialize = ReservoirQuantileBindData::Deserialize; + return bind_data; +} + +AggregateFunction GetReservoirQuantileAggregate(PhysicalType type) { + auto fun = GetReservoirQuantileAggregateFunction(type); + fun.bind = BindReservoirQuantile; + fun.serialize = ReservoirQuantileBindData::Serialize; + fun.deserialize = ReservoirQuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::DOUBLE); + return fun; +} + +unique_ptr BindReservoirQuantileDecimalList(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetReservoirQuantileListAggregateFunction(arguments[0]->return_type); + auto bind_data = BindReservoirQuantile(context, function, arguments); + function.serialize = ReservoirQuantileBindData::Serialize; + function.deserialize = ReservoirQuantileBindData::Deserialize; + function.name = "reservoir_quantile"; + return bind_data; +} + +AggregateFunction GetReservoirQuantileListAggregate(const LogicalType &type) { + auto fun = GetReservoirQuantileListAggregateFunction(type); + fun.bind = BindReservoirQuantile; + fun.serialize = ReservoirQuantileBindData::Serialize; + fun.deserialize = ReservoirQuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); + fun.arguments.push_back(list_of_double); + return fun; +} + +static void DefineReservoirQuantile(AggregateFunctionSet &set, const LogicalType &type) { + // Four versions: type, scalar/list[, count] + auto fun = GetReservoirQuantileAggregate(type.InternalType()); + set.AddFunction(fun); + + fun.arguments.emplace_back(LogicalType::INTEGER); + set.AddFunction(fun); + + // List variants + fun = GetReservoirQuantileListAggregate(type); + set.AddFunction(fun); + + fun.arguments.emplace_back(LogicalType::INTEGER); + set.AddFunction(fun); +} + +static void GetReservoirQuantileDecimalFunction(AggregateFunctionSet &set, const vector &arguments, + const LogicalType &return_value) { + AggregateFunction fun(arguments, return_value, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + BindReservoirQuantileDecimal); + fun.serialize = ReservoirQuantileBindData::Serialize; + fun.deserialize = ReservoirQuantileBindData::Deserialize; + set.AddFunction(fun); + + fun.arguments.emplace_back(LogicalType::INTEGER); + set.AddFunction(fun); +} + +AggregateFunctionSet ReservoirQuantileFun::GetFunctions() { + AggregateFunctionSet reservoir_quantile; + + // DECIMAL + GetReservoirQuantileDecimalFunction(reservoir_quantile, {LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, + LogicalTypeId::DECIMAL); + GetReservoirQuantileDecimalFunction(reservoir_quantile, + {LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, + LogicalType::LIST(LogicalTypeId::DECIMAL)); + + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::TINYINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::SMALLINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::INTEGER); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::BIGINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::HUGEINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::FLOAT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::DOUBLE); + return reservoir_quantile; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/src/core_functions/aggregate/nested/histogram.cpp new file mode 100644 index 00000000..b50c535e --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/nested/histogram.cpp @@ -0,0 +1,265 @@ +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/core_functions/aggregate/nested_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +struct HistogramFunctor { + template > + static void HistogramUpdate(UnifiedVectorFormat &sdata, UnifiedVectorFormat &input_data, idx_t count) { + auto states = (HistogramAggState **)sdata.data; + for (idx_t i = 0; i < count; i++) { + if (input_data.validity.RowIsValid(input_data.sel->get_index(i))) { + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + state.hist = new MAP_TYPE(); + } + auto value = UnifiedVectorFormat::GetData(input_data); + (*state.hist)[value[input_data.sel->get_index(i)]]++; + } + } + } + + template + static Value HistogramFinalize(T first) { + return Value::CreateValue(first); + } +}; + +struct HistogramStringFunctor { + template > + static void HistogramUpdate(UnifiedVectorFormat &sdata, UnifiedVectorFormat &input_data, idx_t count) { + auto states = (HistogramAggState **)sdata.data; + auto input_strings = UnifiedVectorFormat::GetData(input_data); + for (idx_t i = 0; i < count; i++) { + if (input_data.validity.RowIsValid(input_data.sel->get_index(i))) { + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + state.hist = new MAP_TYPE(); + } + (*state.hist)[input_strings[input_data.sel->get_index(i)].GetString()]++; + } + } + } + + template + static Value HistogramFinalize(T first) { + string_t value = first; + return Value::CreateValue(value); + } +}; + +struct HistogramFunction { + template + static void Initialize(STATE &state) { + state.hist = nullptr; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.hist) { + delete state.hist; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +template +static void HistogramUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, + idx_t count) { + + D_ASSERT(input_count == 1); + + auto &input = inputs[0]; + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + UnifiedVectorFormat input_data; + input.ToUnifiedFormat(count, input_data); + + OP::template HistogramUpdate(sdata, input_data, count); +} + +template +static void HistogramCombineFunction(Vector &state_vector, Vector &combined, AggregateInputData &, idx_t count) { + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states_ptr = (HistogramAggState **)sdata.data; + + auto combined_ptr = FlatVector::GetData *>(combined); + + for (idx_t i = 0; i < count; i++) { + auto &state = *states_ptr[sdata.sel->get_index(i)]; + if (!state.hist) { + continue; + } + if (!combined_ptr[i]->hist) { + combined_ptr[i]->hist = new MAP_TYPE(); + } + D_ASSERT(combined_ptr[i]->hist); + D_ASSERT(state.hist); + for (auto &entry : *state.hist) { + (*combined_ptr[i]->hist)[entry.first] += entry.second; + } + } +} + +template +static void HistogramFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, + idx_t offset) { + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = (HistogramAggState **)sdata.data; + + auto &mask = FlatVector::Validity(result); + auto old_len = ListVector::GetListSize(result); + + for (idx_t i = 0; i < count; i++) { + const auto rid = i + offset; + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + mask.SetInvalid(rid); + continue; + } + + for (auto &entry : *state.hist) { + Value bucket_value = OP::template HistogramFinalize(entry.first); + auto count_value = Value::CreateValue(entry.second); + auto struct_value = + Value::STRUCT({std::make_pair("key", bucket_value), std::make_pair("value", count_value)}); + ListVector::PushBack(result, struct_value); + } + + auto list_struct_data = ListVector::GetData(result); + list_struct_data[rid].length = ListVector::GetListSize(result) - old_len; + list_struct_data[rid].offset = old_len; + old_len += list_struct_data[rid].length; + } + result.Verify(count); +} + +unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + + D_ASSERT(arguments.size() == 1); + + if (arguments[0]->return_type.id() == LogicalTypeId::LIST || + arguments[0]->return_type.id() == LogicalTypeId::STRUCT || + arguments[0]->return_type.id() == LogicalTypeId::MAP) { + throw NotImplementedException("Unimplemented type for histogram %s", arguments[0]->return_type.ToString()); + } + + auto struct_type = LogicalType::MAP(arguments[0]->return_type, LogicalType::UBIGINT); + + function.return_type = struct_type; + return make_uniq(function.return_type); +} + +template > +static AggregateFunction GetHistogramFunction(const LogicalType &type) { + + using STATE_TYPE = HistogramAggState; + + return AggregateFunction("histogram", {type}, LogicalTypeId::MAP, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + HistogramUpdateFunction, HistogramCombineFunction, + HistogramFinalizeFunction, nullptr, HistogramBindFunction, + AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetMapType(const LogicalType &type) { + + if (IS_ORDERED) { + return GetHistogramFunction(type); + } + return GetHistogramFunction>(type); +} + +template +AggregateFunction GetHistogramFunction(const LogicalType &type) { + + switch (type.id()) { + case LogicalType::BOOLEAN: + return GetMapType(type); + case LogicalType::UTINYINT: + return GetMapType(type); + case LogicalType::USMALLINT: + return GetMapType(type); + case LogicalType::UINTEGER: + return GetMapType(type); + case LogicalType::UBIGINT: + return GetMapType(type); + case LogicalType::TINYINT: + return GetMapType(type); + case LogicalType::SMALLINT: + return GetMapType(type); + case LogicalType::INTEGER: + return GetMapType(type); + case LogicalType::BIGINT: + return GetMapType(type); + case LogicalType::FLOAT: + return GetMapType(type); + case LogicalType::DOUBLE: + return GetMapType(type); + case LogicalType::VARCHAR: + return GetMapType(type); + case LogicalType::TIMESTAMP: + return GetMapType(type); + case LogicalType::TIMESTAMP_TZ: + return GetMapType(type); + case LogicalType::TIMESTAMP_S: + return GetMapType(type); + case LogicalType::TIMESTAMP_MS: + return GetMapType(type); + case LogicalType::TIMESTAMP_NS: + return GetMapType(type); + case LogicalType::TIME: + return GetMapType(type); + case LogicalType::TIME_TZ: + return GetMapType(type); + case LogicalType::DATE: + return GetMapType(type); + default: + throw InternalException("Unimplemented histogram aggregate"); + } +} + +AggregateFunctionSet HistogramFun::GetFunctions() { + AggregateFunctionSet fun; + fun.AddFunction(GetHistogramFunction<>(LogicalType::BOOLEAN)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::UTINYINT)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::USMALLINT)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::UINTEGER)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::UBIGINT)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TINYINT)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::SMALLINT)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::INTEGER)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::BIGINT)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::FLOAT)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::DOUBLE)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::VARCHAR)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_TZ)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_S)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_MS)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_NS)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME_TZ)); + fun.AddFunction(GetHistogramFunction<>(LogicalType::DATE)); + return fun; +} + +AggregateFunction HistogramFun::GetHistogramUnorderedMap(LogicalType &type) { + const auto &const_type = type; + return GetHistogramFunction(const_type); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/nested/list.cpp b/src/duckdb/src/core_functions/aggregate/nested/list.cpp new file mode 100644 index 00000000..a7d4d742 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/nested/list.cpp @@ -0,0 +1,213 @@ +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types/list_segment.hpp" +#include "duckdb/core_functions/aggregate/nested_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +struct ListBindData : public FunctionData { + explicit ListBindData(const LogicalType &stype_p); + ~ListBindData() override; + + LogicalType stype; + ListSegmentFunctions functions; + + unique_ptr Copy() const override { + return make_uniq(stype); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return stype == other.stype; + } +}; + +ListBindData::ListBindData(const LogicalType &stype_p) : stype(stype_p) { + // always unnest once because the result vector is of type LIST + auto type = ListType::GetChildType(stype_p); + GetSegmentDataFunctions(functions, type); +} + +ListBindData::~ListBindData() { +} + +struct ListAggState { + LinkedList linked_list; +}; + +struct ListFunction { + template + static void Initialize(STATE &state) { + state.linked_list.total_capacity = 0; + state.linked_list.first_segment = nullptr; + state.linked_list.last_segment = nullptr; + } + static bool IgnoreNull() { + return false; + } +}; + +static void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + Vector &state_vector, idx_t count) { + + D_ASSERT(input_count == 1); + auto &input = inputs[0]; + RecursiveUnifiedVectorFormat input_data; + Vector::RecursiveToUnifiedFormat(input, count, input_data); + + UnifiedVectorFormat states_data; + state_vector.ToUnifiedFormat(count, states_data); + auto states = UnifiedVectorFormat::GetData(states_data); + + auto &list_bind_data = aggr_input_data.bind_data->Cast(); + + for (idx_t i = 0; i < count; i++) { + auto &state = *states[states_data.sel->get_index(i)]; + list_bind_data.functions.AppendRow(aggr_input_data.allocator, state.linked_list, input_data, i); + } +} + +static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &, idx_t count) { + + UnifiedVectorFormat states_data; + states_vector.ToUnifiedFormat(count, states_data); + auto states_ptr = UnifiedVectorFormat::GetData(states_data); + + auto combined_ptr = FlatVector::GetData(combined); + for (idx_t i = 0; i < count; i++) { + + auto &state = *states_ptr[states_data.sel->get_index(i)]; + if (state.linked_list.total_capacity == 0) { + // NULL, no need to append + // this can happen when adding a FILTER to the grouping, e.g., + // LIST(i) FILTER (WHERE i <> 3) + continue; + } + + if (combined_ptr[i]->linked_list.total_capacity == 0) { + combined_ptr[i]->linked_list = state.linked_list; + continue; + } + + // append the linked list + combined_ptr[i]->linked_list.last_segment->next = state.linked_list.first_segment; + combined_ptr[i]->linked_list.last_segment = state.linked_list.last_segment; + combined_ptr[i]->linked_list.total_capacity += state.linked_list.total_capacity; + } +} + +static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset) { + + UnifiedVectorFormat states_data; + states_vector.ToUnifiedFormat(count, states_data); + auto states = UnifiedVectorFormat::GetData(states_data); + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + + auto &mask = FlatVector::Validity(result); + auto result_data = FlatVector::GetData(result); + size_t total_len = ListVector::GetListSize(result); + + auto &list_bind_data = aggr_input_data.bind_data->Cast(); + + // first iterate over all entries and set up the list entries, and get the newly required total length + for (idx_t i = 0; i < count; i++) { + + auto &state = *states[states_data.sel->get_index(i)]; + const auto rid = i + offset; + result_data[rid].offset = total_len; + if (state.linked_list.total_capacity == 0) { + mask.SetInvalid(rid); + result_data[rid].length = 0; + continue; + } + + // set the length and offset of this list in the result vector + auto total_capacity = state.linked_list.total_capacity; + result_data[rid].length = total_capacity; + total_len += total_capacity; + } + + // reserve capacity, then iterate over all entries again and copy over the data to the child vector + ListVector::Reserve(result, total_len); + auto &result_child = ListVector::GetEntry(result); + for (idx_t i = 0; i < count; i++) { + + auto &state = *states[states_data.sel->get_index(i)]; + const auto rid = i + offset; + if (state.linked_list.total_capacity == 0) { + continue; + } + + idx_t current_offset = result_data[rid].offset; + list_bind_data.functions.BuildListVector(state.linked_list, result_child, current_offset); + } + + ListVector::SetListSize(result, total_len); +} + +static void ListWindow(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, + idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, + Vector &result, idx_t rid, idx_t bias) { + + auto &list_bind_data = aggr_input_data.bind_data->Cast(); + LinkedList linked_list; + + // UPDATE step + + D_ASSERT(input_count == 1); + auto &input = inputs[0]; + + // FIXME: we unify more values than necessary (count is frame.end) + RecursiveUnifiedVectorFormat input_data; + Vector::RecursiveToUnifiedFormat(input, frame.end, input_data); + + for (idx_t i = frame.start; i < frame.end; i++) { + list_bind_data.functions.AppendRow(aggr_input_data.allocator, linked_list, input_data, i); + } + + // FINALIZE step + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto result_data = FlatVector::GetData(result); + size_t total_len = ListVector::GetListSize(result); + + // set the length and offset of this list in the result vector + result_data[rid].offset = total_len; + result_data[rid].length = linked_list.total_capacity; + D_ASSERT(linked_list.total_capacity != 0); + total_len += linked_list.total_capacity; + + // reserve capacity, then copy over the data to the child vector + ListVector::Reserve(result, total_len); + auto &result_child = ListVector::GetEntry(result); + idx_t offset = result_data[rid].offset; + list_bind_data.functions.BuildListVector(linked_list, result_child, offset); + + ListVector::SetListSize(result, total_len); +} + +unique_ptr ListBindFunction(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + D_ASSERT(arguments.size() == 1); + D_ASSERT(function.arguments.size() == 1); + + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + function.arguments[0] = LogicalTypeId::UNKNOWN; + function.return_type = LogicalType::SQLNULL; + return nullptr; + } + + function.return_type = LogicalType::LIST(arguments[0]->return_type); + return make_uniq(function.return_type); +} + +AggregateFunction ListFun::GetFunction() { + return AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, ListUpdateFunction, + ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, + ListWindow); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/regression/regr_avg.cpp b/src/duckdb/src/core_functions/aggregate/regression/regr_avg.cpp new file mode 100644 index 00000000..4136ab03 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/regression/regr_avg.cpp @@ -0,0 +1,64 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/core_functions/aggregate/regression_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { +struct RegrState { + double sum; + size_t count; +}; + +struct RegrAvgFunction { + template + static void Initialize(STATE &state) { + state.sum = 0; + state.count = 0; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.sum += source.sum; + target.count += source.count; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.sum / (double)state.count; + } + } + static bool IgnoreNull() { + return true; + } +}; +struct RegrAvgXFunction : RegrAvgFunction { + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + state.sum += x; + state.count++; + } +}; + +struct RegrAvgYFunction : RegrAvgFunction { + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + state.sum += y; + state.count++; + } +}; + +AggregateFunction RegrAvgxFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +AggregateFunction RegrAvgyFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/regression/regr_count.cpp b/src/duckdb/src/core_functions/aggregate/regression/regr_count.cpp new file mode 100644 index 00000000..333bef41 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/regression/regr_count.cpp @@ -0,0 +1,18 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/core_functions/aggregate/regression_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/core_functions/aggregate/regression/regr_count.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +AggregateFunction RegrCountFun::GetFunction() { + auto regr_count = AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::UINTEGER); + regr_count.name = "regr_count"; + regr_count.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return regr_count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/regression/regr_intercept.cpp b/src/duckdb/src/core_functions/aggregate/regression/regr_intercept.cpp new file mode 100644 index 00000000..a3a11745 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/regression/regr_intercept.cpp @@ -0,0 +1,63 @@ +//! AVG(y)-REGR_SLOPE(y,x)*AVG(x) + +#include "duckdb/core_functions/aggregate/regression_functions.hpp" +#include "duckdb/core_functions/aggregate/regression/regr_slope.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RegrInterceptState { + size_t count; + double sum_x; + double sum_y; + RegrSlopeState slope; +}; + +struct RegrInterceptOperation { + template + static void Initialize(STATE &state) { + state.count = 0; + state.sum_x = 0; + state.sum_y = 0; + RegrSlopeOperation::Initialize(state.slope); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + state.count++; + state.sum_x += x; + state.sum_y += y; + RegrSlopeOperation::Operation(state.slope, y, x, idata); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + target.count += source.count; + target.sum_x += source.sum_x; + target.sum_y += source.sum_y; + RegrSlopeOperation::Combine(source.slope, target.slope, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + return; + } + RegrSlopeOperation::Finalize(state.slope, target, finalize_data); + auto x_avg = state.sum_x / state.count; + auto y_avg = state.sum_y / state.count; + target = y_avg - target * x_avg; + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction RegrInterceptFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/regression/regr_r2.cpp b/src/duckdb/src/core_functions/aggregate/regression/regr_r2.cpp new file mode 100644 index 00000000..4d68225e --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/regression/regr_r2.cpp @@ -0,0 +1,72 @@ +// REGR_R2(y, x) +// Returns the coefficient of determination for non-null pairs in a group. +// It is computed for non-null pairs using the following formula: +// null if var_pop(x) = 0, else +// 1 if var_pop(y) = 0 and var_pop(x) <> 0, else +// power(corr(y,x), 2) + +#include "duckdb/core_functions/aggregate/algebraic/corr.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/core_functions/aggregate/regression_functions.hpp" + +namespace duckdb { +struct RegrR2State { + CorrState corr; + StddevState var_pop_x; + StddevState var_pop_y; +}; + +struct RegrR2Operation { + template + static void Initialize(STATE &state) { + CorrOperation::Initialize(state.corr); + STDDevBaseOperation::Initialize(state.var_pop_x); + STDDevBaseOperation::Initialize(state.var_pop_y); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + CorrOperation::Operation(state.corr, y, x, idata); + STDDevBaseOperation::Execute(state.var_pop_x, x); + STDDevBaseOperation::Execute(state.var_pop_y, y); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CorrOperation::Combine(source.corr, target.corr, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop_x, target.var_pop_x, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop_y, target.var_pop_y, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + auto var_pop_x = state.var_pop_x.count > 1 ? (state.var_pop_x.dsquared / state.var_pop_x.count) : 0; + if (!Value::DoubleIsFinite(var_pop_x)) { + throw OutOfRangeException("VARPOP(X) is out of range!"); + } + if (var_pop_x == 0) { + finalize_data.ReturnNull(); + return; + } + auto var_pop_y = state.var_pop_y.count > 1 ? (state.var_pop_y.dsquared / state.var_pop_y.count) : 0; + if (!Value::DoubleIsFinite(var_pop_y)) { + throw OutOfRangeException("VARPOP(Y) is out of range!"); + } + if (var_pop_y == 0) { + target = 1; + return; + } + CorrOperation::Finalize(state.corr, target, finalize_data); + target = pow(target, 2); + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction RegrR2Fun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/regression/regr_slope.cpp b/src/duckdb/src/core_functions/aggregate/regression/regr_slope.cpp new file mode 100644 index 00000000..1e86b011 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/regression/regr_slope.cpp @@ -0,0 +1,20 @@ +// REGR_SLOPE(y, x) +// Returns the slope of the linear regression line for non-null pairs in a group. +// It is computed for non-null pairs using the following formula: +// COVAR_POP(x,y) / VAR_POP(x) + +//! Input : Any numeric type +//! Output : Double + +#include "duckdb/core_functions/aggregate/regression/regr_slope.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/core_functions/aggregate/regression_functions.hpp" + +namespace duckdb { + +AggregateFunction RegrSlopeFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/regression/regr_sxx_syy.cpp b/src/duckdb/src/core_functions/aggregate/regression/regr_sxx_syy.cpp new file mode 100644 index 00000000..e789172d --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/regression/regr_sxx_syy.cpp @@ -0,0 +1,75 @@ +// REGR_SXX(y, x) +// Returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs. +// REGR_SYY(y, x) +// Returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs. + +#include "duckdb/core_functions/aggregate/regression/regr_count.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/core_functions/aggregate/regression_functions.hpp" + +namespace duckdb { + +struct RegrSState { + size_t count; + StddevState var_pop; +}; + +struct RegrBaseOperation { + template + static void Initialize(STATE &state) { + RegrCountFunction::Initialize(state.count); + STDDevBaseOperation::Initialize(state.var_pop); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + RegrCountFunction::Combine(source.count, target.count, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop, target.var_pop, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.var_pop.count == 0) { + finalize_data.ReturnNull(); + return; + } + auto var_pop = state.var_pop.count > 1 ? (state.var_pop.dsquared / state.var_pop.count) : 0; + if (!Value::DoubleIsFinite(var_pop)) { + throw OutOfRangeException("VARPOP is out of range!"); + } + RegrCountFunction::Finalize(state.count, target, finalize_data); + target *= var_pop; + } + + static bool IgnoreNull() { + return true; + } +}; + +struct RegrSXXOperation : RegrBaseOperation { + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + RegrCountFunction::Operation(state.count, y, x, idata); + STDDevBaseOperation::Execute(state.var_pop, x); + } +}; + +struct RegrSYYOperation : RegrBaseOperation { + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + RegrCountFunction::Operation(state.count, y, x, idata); + STDDevBaseOperation::Execute(state.var_pop, y); + } +}; + +AggregateFunction RegrSXXFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +AggregateFunction RegrSYYFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/regression/regr_sxy.cpp b/src/duckdb/src/core_functions/aggregate/regression/regr_sxy.cpp new file mode 100644 index 00000000..e3f3d4ae --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/regression/regr_sxy.cpp @@ -0,0 +1,53 @@ +// REGR_SXY(y, x) +// Returns REGR_COUNT(expr1, expr2) * COVAR_POP(expr1, expr2) for non-null pairs. + +#include "duckdb/core_functions/aggregate/regression/regr_count.hpp" +#include "duckdb/core_functions/aggregate/algebraic/covar.hpp" +#include "duckdb/core_functions/aggregate/regression_functions.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RegrSXyState { + size_t count; + CovarState cov_pop; +}; + +struct RegrSXYOperation { + template + static void Initialize(STATE &state) { + RegrCountFunction::Initialize(state.count); + CovarOperation::Initialize(state.cov_pop); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + RegrCountFunction::Operation(state.count, y, x, idata); + CovarOperation::Operation(state.cov_pop, y, x, idata); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); + RegrCountFunction::Combine(source.count, target.count, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + CovarPopOperation::Finalize(state.cov_pop, target, finalize_data); + auto cov_pop = target; + RegrCountFunction::Finalize(state.count, target, finalize_data); + target *= cov_pop; + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction RegrSXYFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/core_functions.cpp b/src/duckdb/src/core_functions/core_functions.cpp new file mode 100644 index 00000000..25a57f59 --- /dev/null +++ b/src/duckdb/src/core_functions/core_functions.cpp @@ -0,0 +1,50 @@ +#include "duckdb/core_functions/core_functions.hpp" +#include "duckdb/core_functions/function_list.hpp" +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" + +namespace duckdb { + +template +void FillExtraInfo(StaticFunctionDefinition &function, T &info) { + info.internal = true; + info.description = function.description; + info.parameter_names = StringUtil::Split(function.parameters, ","); + info.example = function.example; +} + +void CoreFunctions::RegisterFunctions(Catalog &catalog, CatalogTransaction transaction) { + auto functions = StaticFunctionDefinition::GetFunctionList(); + for (idx_t i = 0; functions[i].name; i++) { + auto &function = functions[i]; + if (function.get_function || function.get_function_set) { + // scalar function + ScalarFunctionSet result; + if (function.get_function) { + result.AddFunction(function.get_function()); + } else { + result = function.get_function_set(); + } + result.name = function.name; + CreateScalarFunctionInfo info(result); + FillExtraInfo(function, info); + catalog.CreateFunction(transaction, info); + } else if (function.get_aggregate_function || function.get_aggregate_function_set) { + // aggregate function + AggregateFunctionSet result; + if (function.get_aggregate_function) { + result.AddFunction(function.get_aggregate_function()); + } else { + result = function.get_aggregate_function_set(); + } + result.name = function.name; + CreateAggregateFunctionInfo info(result); + FillExtraInfo(function, info); + catalog.CreateFunction(transaction, info); + } else { + throw InternalException("Do not know how to register function of this type"); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/function_list.cpp b/src/duckdb/src/core_functions/function_list.cpp new file mode 100644 index 00000000..9e9ef048 --- /dev/null +++ b/src/duckdb/src/core_functions/function_list.cpp @@ -0,0 +1,370 @@ +#include "duckdb/core_functions/function_list.hpp" +#include "duckdb/core_functions/aggregate/algebraic_functions.hpp" +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/core_functions/aggregate/nested_functions.hpp" +#include "duckdb/core_functions/aggregate/regression_functions.hpp" +#include "duckdb/core_functions/scalar/bit_functions.hpp" +#include "duckdb/core_functions/scalar/blob_functions.hpp" +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/core_functions/scalar/enum_functions.hpp" +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/core_functions/scalar/math_functions.hpp" +#include "duckdb/core_functions/scalar/operators_functions.hpp" +#include "duckdb/core_functions/scalar/random_functions.hpp" +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/core_functions/scalar/struct_functions.hpp" +#include "duckdb/core_functions/scalar/union_functions.hpp" +#include "duckdb/core_functions/scalar/debug_functions.hpp" + +namespace duckdb { + +// Scalar Function +#define DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _NAME) \ + { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::GetFunction, nullptr, nullptr, nullptr } +#define DUCKDB_SCALAR_FUNCTION(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_SCALAR_FUNCTION_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) +// Scalar Function Set +#define DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _NAME) \ + { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, _PARAM::GetFunctions, nullptr, nullptr } +#define DUCKDB_SCALAR_FUNCTION_SET(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_SCALAR_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) +// Aggregate Function +#define DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _NAME) \ + { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, _PARAM::GetFunction, nullptr } +#define DUCKDB_AGGREGATE_FUNCTION(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_AGGREGATE_FUNCTION_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) +// Aggregate Function Set +#define DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _NAME) \ + { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, nullptr, _PARAM::GetFunctions } +#define DUCKDB_AGGREGATE_FUNCTION_SET(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) +#define FINAL_FUNCTION \ + { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr } + +// this list is generated by scripts/generate_functions.py +static StaticFunctionDefinition internal_functions[] = { + DUCKDB_SCALAR_FUNCTION(FactorialOperatorFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseAndFun), + DUCKDB_SCALAR_FUNCTION(PowOperatorFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListInnerProductFunAlias), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDistanceFunAlias), + DUCKDB_SCALAR_FUNCTION_SET(LeftShiftFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListCosineSimilarityFunAlias), + DUCKDB_SCALAR_FUNCTION_SET(RightShiftFun), + DUCKDB_SCALAR_FUNCTION_SET(AbsOperatorFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(PowOperatorFunAlias), + DUCKDB_SCALAR_FUNCTION(StartsWithOperatorFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AbsFun), + DUCKDB_SCALAR_FUNCTION(AcosFun), + DUCKDB_SCALAR_FUNCTION_SET(AgeFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(AggregateFun), + DUCKDB_SCALAR_FUNCTION(AliasFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ApplyFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ApproxCountDistinctFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ApproxQuantileFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgmaxFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgminFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(ArrayAggFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggrFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggregateFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayApplyFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayDistinctFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayFilterFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayReverseSortFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySliceFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySortFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayTransformFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayUniqueFun), + DUCKDB_SCALAR_FUNCTION(ASCIIFun), + DUCKDB_SCALAR_FUNCTION(AsinFun), + DUCKDB_SCALAR_FUNCTION(AtanFun), + DUCKDB_SCALAR_FUNCTION(Atan2Fun), + DUCKDB_AGGREGATE_FUNCTION_SET(AvgFun), + DUCKDB_SCALAR_FUNCTION_SET(BarFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(Base64Fun), + DUCKDB_SCALAR_FUNCTION_SET(BinFun), + DUCKDB_AGGREGATE_FUNCTION_SET(BitAndFun), + DUCKDB_SCALAR_FUNCTION_SET(BitCountFun), + DUCKDB_AGGREGATE_FUNCTION_SET(BitOrFun), + DUCKDB_SCALAR_FUNCTION(BitPositionFun), + DUCKDB_AGGREGATE_FUNCTION_SET(BitXorFun), + DUCKDB_SCALAR_FUNCTION(BitStringFun), + DUCKDB_AGGREGATE_FUNCTION_SET(BitstringAggFun), + DUCKDB_AGGREGATE_FUNCTION(BoolAndFun), + DUCKDB_AGGREGATE_FUNCTION(BoolOrFun), + DUCKDB_SCALAR_FUNCTION(CardinalityFun), + DUCKDB_SCALAR_FUNCTION(CbrtFun), + DUCKDB_SCALAR_FUNCTION_SET(CeilFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(CeilingFun), + DUCKDB_SCALAR_FUNCTION_SET(CenturyFun), + DUCKDB_SCALAR_FUNCTION(ChrFun), + DUCKDB_AGGREGATE_FUNCTION(CorrFun), + DUCKDB_SCALAR_FUNCTION(CosFun), + DUCKDB_SCALAR_FUNCTION(CotFun), + DUCKDB_AGGREGATE_FUNCTION(CovarPopFun), + DUCKDB_AGGREGATE_FUNCTION(CovarSampFun), + DUCKDB_SCALAR_FUNCTION(CurrentDatabaseFun), + DUCKDB_SCALAR_FUNCTION(CurrentDateFun), + DUCKDB_SCALAR_FUNCTION(CurrentQueryFun), + DUCKDB_SCALAR_FUNCTION(CurrentSchemaFun), + DUCKDB_SCALAR_FUNCTION(CurrentSchemasFun), + DUCKDB_SCALAR_FUNCTION(CurrentSettingFun), + DUCKDB_SCALAR_FUNCTION(DamerauLevenshteinFun), + DUCKDB_SCALAR_FUNCTION_SET(DateDiffFun), + DUCKDB_SCALAR_FUNCTION_SET(DatePartFun), + DUCKDB_SCALAR_FUNCTION_SET(DateSubFun), + DUCKDB_SCALAR_FUNCTION_SET(DateTruncFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatediffFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatepartFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatesubFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatetruncFun), + DUCKDB_SCALAR_FUNCTION_SET(DayFun), + DUCKDB_SCALAR_FUNCTION_SET(DayNameFun), + DUCKDB_SCALAR_FUNCTION_SET(DayOfMonthFun), + DUCKDB_SCALAR_FUNCTION_SET(DayOfWeekFun), + DUCKDB_SCALAR_FUNCTION_SET(DayOfYearFun), + DUCKDB_SCALAR_FUNCTION_SET(DecadeFun), + DUCKDB_SCALAR_FUNCTION(DecodeFun), + DUCKDB_SCALAR_FUNCTION(DegreesFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(Editdist3Fun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ElementAtFun), + DUCKDB_SCALAR_FUNCTION(EncodeFun), + DUCKDB_AGGREGATE_FUNCTION_SET(EntropyFun), + DUCKDB_SCALAR_FUNCTION(EnumCodeFun), + DUCKDB_SCALAR_FUNCTION(EnumFirstFun), + DUCKDB_SCALAR_FUNCTION(EnumLastFun), + DUCKDB_SCALAR_FUNCTION(EnumRangeFun), + DUCKDB_SCALAR_FUNCTION(EnumRangeBoundaryFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochMsFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochNsFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochUsFun), + DUCKDB_SCALAR_FUNCTION_SET(EraFun), + DUCKDB_SCALAR_FUNCTION(ErrorFun), + DUCKDB_SCALAR_FUNCTION(EvenFun), + DUCKDB_SCALAR_FUNCTION(ExpFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FactorialFun), + DUCKDB_AGGREGATE_FUNCTION(FAvgFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FilterFun), + DUCKDB_SCALAR_FUNCTION(ListFlattenFun), + DUCKDB_SCALAR_FUNCTION_SET(FloorFun), + DUCKDB_SCALAR_FUNCTION(FormatFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FormatreadabledecimalsizeFun), + DUCKDB_SCALAR_FUNCTION(FormatBytesFun), + DUCKDB_SCALAR_FUNCTION(FromBase64Fun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FromBinaryFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FromHexFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(FsumFun), + DUCKDB_SCALAR_FUNCTION(GammaFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(GcdFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(GenRandomUuidFun), + DUCKDB_SCALAR_FUNCTION_SET(GenerateSeriesFun), + DUCKDB_SCALAR_FUNCTION(GetBitFun), + DUCKDB_SCALAR_FUNCTION(CurrentTimeFun), + DUCKDB_SCALAR_FUNCTION(GetCurrentTimestampFun), + DUCKDB_SCALAR_FUNCTION_SET(GreatestFun), + DUCKDB_SCALAR_FUNCTION_SET(GreatestCommonDivisorFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(GroupConcatFun), + DUCKDB_SCALAR_FUNCTION(HammingFun), + DUCKDB_SCALAR_FUNCTION(HashFun), + DUCKDB_SCALAR_FUNCTION_SET(HexFun), + DUCKDB_AGGREGATE_FUNCTION_SET(HistogramFun), + DUCKDB_SCALAR_FUNCTION_SET(HoursFun), + DUCKDB_SCALAR_FUNCTION(InSearchPathFun), + DUCKDB_SCALAR_FUNCTION(InstrFun), + DUCKDB_SCALAR_FUNCTION_SET(IsFiniteFun), + DUCKDB_SCALAR_FUNCTION_SET(IsInfiniteFun), + DUCKDB_SCALAR_FUNCTION_SET(IsNanFun), + DUCKDB_SCALAR_FUNCTION_SET(ISODayOfWeekFun), + DUCKDB_SCALAR_FUNCTION_SET(ISOYearFun), + DUCKDB_SCALAR_FUNCTION(JaccardFun), + DUCKDB_SCALAR_FUNCTION(JaroSimilarityFun), + DUCKDB_SCALAR_FUNCTION(JaroWinklerSimilarityFun), + DUCKDB_SCALAR_FUNCTION_SET(JulianDayFun), + DUCKDB_AGGREGATE_FUNCTION(KahanSumFun), + DUCKDB_AGGREGATE_FUNCTION(KurtosisFun), + DUCKDB_SCALAR_FUNCTION_SET(LastDayFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(LcmFun), + DUCKDB_SCALAR_FUNCTION_SET(LeastFun), + DUCKDB_SCALAR_FUNCTION_SET(LeastCommonMultipleFun), + DUCKDB_SCALAR_FUNCTION(LeftFun), + DUCKDB_SCALAR_FUNCTION(LeftGraphemeFun), + DUCKDB_SCALAR_FUNCTION(LevenshteinFun), + DUCKDB_SCALAR_FUNCTION(LogGammaFun), + DUCKDB_AGGREGATE_FUNCTION(ListFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListAggrFun), + DUCKDB_SCALAR_FUNCTION(ListAggregateFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListApplyFun), + DUCKDB_SCALAR_FUNCTION_SET(ListCosineSimilarityFun), + DUCKDB_SCALAR_FUNCTION_SET(ListDistanceFun), + DUCKDB_SCALAR_FUNCTION(ListDistinctFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDotProductFun), + DUCKDB_SCALAR_FUNCTION(ListFilterFun), + DUCKDB_SCALAR_FUNCTION_SET(ListInnerProductFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListPackFun), + DUCKDB_SCALAR_FUNCTION_SET(ListReverseSortFun), + DUCKDB_SCALAR_FUNCTION_SET(ListSliceFun), + DUCKDB_SCALAR_FUNCTION_SET(ListSortFun), + DUCKDB_SCALAR_FUNCTION(ListTransformFun), + DUCKDB_SCALAR_FUNCTION(ListUniqueFun), + DUCKDB_SCALAR_FUNCTION(ListValueFun), + DUCKDB_SCALAR_FUNCTION(LnFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(LogFun), + DUCKDB_SCALAR_FUNCTION(Log10Fun), + DUCKDB_SCALAR_FUNCTION(Log2Fun), + DUCKDB_SCALAR_FUNCTION(LpadFun), + DUCKDB_SCALAR_FUNCTION_SET(LtrimFun), + DUCKDB_AGGREGATE_FUNCTION_SET(MadFun), + DUCKDB_SCALAR_FUNCTION_SET(MakeDateFun), + DUCKDB_SCALAR_FUNCTION(MakeTimeFun), + DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampFun), + DUCKDB_SCALAR_FUNCTION(MapFun), + DUCKDB_SCALAR_FUNCTION(MapConcatFun), + DUCKDB_SCALAR_FUNCTION(MapEntriesFun), + DUCKDB_SCALAR_FUNCTION(MapExtractFun), + DUCKDB_SCALAR_FUNCTION(MapFromEntriesFun), + DUCKDB_SCALAR_FUNCTION(MapKeysFun), + DUCKDB_SCALAR_FUNCTION(MapValuesFun), + DUCKDB_AGGREGATE_FUNCTION_SET(MaxFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MaxByFun), + DUCKDB_SCALAR_FUNCTION(MD5Fun), + DUCKDB_SCALAR_FUNCTION(MD5NumberFun), + DUCKDB_SCALAR_FUNCTION(MD5NumberLowerFun), + DUCKDB_SCALAR_FUNCTION(MD5NumberUpperFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MeanFun), + DUCKDB_AGGREGATE_FUNCTION_SET(MedianFun), + DUCKDB_SCALAR_FUNCTION_SET(MicrosecondsFun), + DUCKDB_SCALAR_FUNCTION_SET(MillenniumFun), + DUCKDB_SCALAR_FUNCTION_SET(MillisecondsFun), + DUCKDB_AGGREGATE_FUNCTION_SET(MinFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MinByFun), + DUCKDB_SCALAR_FUNCTION_SET(MinutesFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(MismatchesFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ModeFun), + DUCKDB_SCALAR_FUNCTION_SET(MonthFun), + DUCKDB_SCALAR_FUNCTION_SET(MonthNameFun), + DUCKDB_SCALAR_FUNCTION_SET(NextAfterFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(NowFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(OrdFun), + DUCKDB_SCALAR_FUNCTION(PiFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(PositionFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(PowFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(PowerFun), + DUCKDB_SCALAR_FUNCTION(PrintfFun), + DUCKDB_AGGREGATE_FUNCTION(ProductFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(QuantileFun), + DUCKDB_AGGREGATE_FUNCTION_SET(QuantileContFun), + DUCKDB_AGGREGATE_FUNCTION_SET(QuantileDiscFun), + DUCKDB_SCALAR_FUNCTION_SET(QuarterFun), + DUCKDB_SCALAR_FUNCTION(RadiansFun), + DUCKDB_SCALAR_FUNCTION(RandomFun), + DUCKDB_SCALAR_FUNCTION_SET(ListRangeFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(RegexpSplitToArrayFun), + DUCKDB_AGGREGATE_FUNCTION(RegrAvgxFun), + DUCKDB_AGGREGATE_FUNCTION(RegrAvgyFun), + DUCKDB_AGGREGATE_FUNCTION(RegrCountFun), + DUCKDB_AGGREGATE_FUNCTION(RegrInterceptFun), + DUCKDB_AGGREGATE_FUNCTION(RegrR2Fun), + DUCKDB_AGGREGATE_FUNCTION(RegrSlopeFun), + DUCKDB_AGGREGATE_FUNCTION(RegrSXXFun), + DUCKDB_AGGREGATE_FUNCTION(RegrSXYFun), + DUCKDB_AGGREGATE_FUNCTION(RegrSYYFun), + DUCKDB_SCALAR_FUNCTION_SET(RepeatFun), + DUCKDB_SCALAR_FUNCTION(ReplaceFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ReservoirQuantileFun), + DUCKDB_SCALAR_FUNCTION(ReverseFun), + DUCKDB_SCALAR_FUNCTION(RightFun), + DUCKDB_SCALAR_FUNCTION(RightGraphemeFun), + DUCKDB_SCALAR_FUNCTION_SET(RoundFun), + DUCKDB_SCALAR_FUNCTION(RowFun), + DUCKDB_SCALAR_FUNCTION(RpadFun), + DUCKDB_SCALAR_FUNCTION_SET(RtrimFun), + DUCKDB_SCALAR_FUNCTION_SET(SecondsFun), + DUCKDB_AGGREGATE_FUNCTION(StandardErrorOfTheMeanFun), + DUCKDB_SCALAR_FUNCTION(SetBitFun), + DUCKDB_SCALAR_FUNCTION(SetseedFun), + DUCKDB_SCALAR_FUNCTION(SHA256Fun), + DUCKDB_SCALAR_FUNCTION_SET(SignFun), + DUCKDB_SCALAR_FUNCTION_SET(SignBitFun), + DUCKDB_SCALAR_FUNCTION(SinFun), + DUCKDB_AGGREGATE_FUNCTION(SkewnessFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(SplitFun), + DUCKDB_SCALAR_FUNCTION(SqrtFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StartsWithFun), + DUCKDB_SCALAR_FUNCTION(StatsFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(StddevFun), + DUCKDB_AGGREGATE_FUNCTION(StdDevPopFun), + DUCKDB_AGGREGATE_FUNCTION(StdDevSampFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StrSplitFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(StrSplitRegexFun), + DUCKDB_SCALAR_FUNCTION_SET(StrfTimeFun), + DUCKDB_AGGREGATE_FUNCTION_SET(StringAggFun), + DUCKDB_SCALAR_FUNCTION(StringSplitFun), + DUCKDB_SCALAR_FUNCTION_SET(StringSplitRegexFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StringToArrayFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StrposFun), + DUCKDB_SCALAR_FUNCTION_SET(StrpTimeFun), + DUCKDB_SCALAR_FUNCTION(StructInsertFun), + DUCKDB_SCALAR_FUNCTION(StructPackFun), + DUCKDB_AGGREGATE_FUNCTION_SET(SumFun), + DUCKDB_AGGREGATE_FUNCTION_SET(SumNoOverflowFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(SumkahanFun), + DUCKDB_SCALAR_FUNCTION(TanFun), + DUCKDB_SCALAR_FUNCTION_SET(TimeBucketFun), + DUCKDB_SCALAR_FUNCTION_SET(TimezoneFun), + DUCKDB_SCALAR_FUNCTION_SET(TimezoneHourFun), + DUCKDB_SCALAR_FUNCTION_SET(TimezoneMinuteFun), + DUCKDB_SCALAR_FUNCTION_SET(ToBaseFun), + DUCKDB_SCALAR_FUNCTION(ToBase64Fun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToBinaryFun), + DUCKDB_SCALAR_FUNCTION(ToDaysFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToHexFun), + DUCKDB_SCALAR_FUNCTION(ToHoursFun), + DUCKDB_SCALAR_FUNCTION(ToMicrosecondsFun), + DUCKDB_SCALAR_FUNCTION(ToMillisecondsFun), + DUCKDB_SCALAR_FUNCTION(ToMinutesFun), + DUCKDB_SCALAR_FUNCTION(ToMonthsFun), + DUCKDB_SCALAR_FUNCTION(ToSecondsFun), + DUCKDB_SCALAR_FUNCTION(ToTimestampFun), + DUCKDB_SCALAR_FUNCTION(ToYearsFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(TodayFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(TransactionTimestampFun), + DUCKDB_SCALAR_FUNCTION(TranslateFun), + DUCKDB_SCALAR_FUNCTION_SET(TrimFun), + DUCKDB_SCALAR_FUNCTION_SET(TruncFun), + DUCKDB_SCALAR_FUNCTION_SET(TryStrpTimeFun), + DUCKDB_SCALAR_FUNCTION(CurrentTransactionIdFun), + DUCKDB_SCALAR_FUNCTION(TypeOfFun), + DUCKDB_SCALAR_FUNCTION(UnbinFun), + DUCKDB_SCALAR_FUNCTION(UnhexFun), + DUCKDB_SCALAR_FUNCTION(UnicodeFun), + DUCKDB_SCALAR_FUNCTION(UnionExtractFun), + DUCKDB_SCALAR_FUNCTION(UnionTagFun), + DUCKDB_SCALAR_FUNCTION(UnionValueFun), + DUCKDB_SCALAR_FUNCTION(UUIDFun), + DUCKDB_AGGREGATE_FUNCTION(VarPopFun), + DUCKDB_AGGREGATE_FUNCTION(VarSampFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(VarianceFun), + DUCKDB_SCALAR_FUNCTION(VectorTypeFun), + DUCKDB_SCALAR_FUNCTION(VersionFun), + DUCKDB_SCALAR_FUNCTION_SET(WeekFun), + DUCKDB_SCALAR_FUNCTION_SET(WeekDayFun), + DUCKDB_SCALAR_FUNCTION_SET(WeekOfYearFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseXorFun), + DUCKDB_SCALAR_FUNCTION_SET(YearFun), + DUCKDB_SCALAR_FUNCTION_SET(YearWeekFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseOrFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseNotFun), + FINAL_FUNCTION +}; + +StaticFunctionDefinition *StaticFunctionDefinition::GetFunctionList() { + return internal_functions; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/bit/bitstring.cpp b/src/duckdb/src/core_functions/scalar/bit/bitstring.cpp new file mode 100644 index 00000000..ee806c95 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/bit/bitstring.cpp @@ -0,0 +1,97 @@ +#include "duckdb/core_functions/scalar/bit_functions.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// BitStringFunction +//===--------------------------------------------------------------------===// +static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) { + if (n < 0) { + throw InvalidInputException("The bitstring length cannot be negative"); + } + if (idx_t(n) < input.GetSize()) { + throw InvalidInputException("Length must be equal or larger than input string"); + } + idx_t len; + Bit::TryGetBitStringSize(input, len, nullptr); // string verification + + len = Bit::ComputeBitstringLen(n); + string_t target = StringVector::EmptyString(result, len); + Bit::BitString(input, n, target); + target.Finalize(); + return target; + }); +} + +ScalarFunction BitStringFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction); +} + +//===--------------------------------------------------------------------===// +// get_bit +//===--------------------------------------------------------------------===// +struct GetBitOperator { + template + static inline TR Operation(TA input, TB n) { + if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { + throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), + NumericHelper::ToString(Bit::BitLength(input) - 1)); + } + return Bit::GetBit(input, n); + } +}; + +ScalarFunction GetBitFun::GetFunction() { + return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::INTEGER, + ScalarFunction::BinaryFunction); +} + +//===--------------------------------------------------------------------===// +// set_bit +//===--------------------------------------------------------------------===// +static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &result) { + TernaryExecutor::Execute( + args.data[0], args.data[1], args.data[2], result, args.size(), + [&](string_t input, int32_t n, int32_t new_value) { + if (new_value != 0 && new_value != 1) { + throw InvalidInputException("The new bit must be 1 or 0"); + } + if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { + throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), + NumericHelper::ToString(Bit::BitLength(input) - 1)); + } + string_t target = StringVector::EmptyString(result, input.GetSize()); + memcpy(target.GetDataWriteable(), input.GetData(), input.GetSize()); + Bit::SetBit(target, n, new_value); + return target; + }); +} + +ScalarFunction SetBitFun::GetFunction() { + return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, + SetBitOperation); +} + +//===--------------------------------------------------------------------===// +// bit_position +//===--------------------------------------------------------------------===// +struct BitPositionOperator { + template + static inline TR Operation(TA substring, TB input) { + if (substring.GetSize() > input.GetSize()) { + return 0; + } + return Bit::BitPosition(substring, input); + } +}; + +ScalarFunction BitPositionFun::GetFunction() { + return ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::INTEGER, + ScalarFunction::BinaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/blob/base64.cpp b/src/duckdb/src/core_functions/scalar/blob/base64.cpp new file mode 100644 index 00000000..3545f3b5 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/blob/base64.cpp @@ -0,0 +1,45 @@ +#include "duckdb/core_functions/scalar/blob_functions.hpp" +#include "duckdb/common/types/blob.hpp" + +namespace duckdb { + +struct Base64EncodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto result_str = StringVector::EmptyString(result, Blob::ToBase64Size(input)); + Blob::ToBase64(input, result_str.GetDataWriteable()); + result_str.Finalize(); + return result_str; + } +}; + +struct Base64DecodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto result_size = Blob::FromBase64Size(input); + auto result_blob = StringVector::EmptyString(result, result_size); + Blob::FromBase64(input, data_ptr_cast(result_blob.GetDataWriteable()), result_size); + result_blob.Finalize(); + return result_blob; + } +}; + +static void Base64EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // decode is also a nop cast, but requires verification if the provided string is actually + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); +} + +static void Base64DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // decode is also a nop cast, but requires verification if the provided string is actually + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); +} + +ScalarFunction ToBase64Fun::GetFunction() { + return ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, Base64EncodeFunction); +} + +ScalarFunction FromBase64Fun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, Base64DecodeFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/blob/encode.cpp b/src/duckdb/src/core_functions/scalar/blob/encode.cpp new file mode 100644 index 00000000..75e66cc9 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/blob/encode.cpp @@ -0,0 +1,39 @@ +#include "duckdb/core_functions/scalar/blob_functions.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +static void EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // encode is essentially a nop cast from varchar to blob + // we only need to reinterpret the data using the blob type + result.Reinterpret(args.data[0]); +} + +struct BlobDecodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + if (Utf8Proc::Analyze(input_data, input_length) == UnicodeType::INVALID) { + throw ConversionException( + "Failure in decode: could not convert blob to UTF8 string, the blob contained invalid UTF8 characters"); + } + return input; + } +}; + +static void DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // decode is also a nop cast, but requires verification if the provided string is actually + UnaryExecutor::Execute(args.data[0], result, args.size()); + StringVector::AddHeapReference(result, args.data[0]); +} + +ScalarFunction EncodeFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, EncodeFunction); +} + +ScalarFunction DecodeFun::GetFunction() { + return ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, DecodeFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/age.cpp b/src/duckdb/src/core_functions/scalar/date/age.cpp new file mode 100644 index 00000000..f8db919f --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/age.cpp @@ -0,0 +1,49 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" + +namespace duckdb { + +static void AgeFunctionStandard(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 1); + auto current_timestamp = Timestamp::GetCurrentTimestamp(); + + UnaryExecutor::ExecuteWithNulls(input.data[0], result, input.size(), + [&](timestamp_t input, ValidityMask &mask, idx_t idx) { + if (Timestamp::IsFinite(input)) { + return Interval::GetAge(current_timestamp, input); + } else { + mask.SetInvalid(idx); + return interval_t(); + } + }); +} + +static void AgeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 2); + + BinaryExecutor::ExecuteWithNulls( + input.data[0], input.data[1], result, input.size(), + [&](timestamp_t input1, timestamp_t input2, ValidityMask &mask, idx_t idx) { + if (Timestamp::IsFinite(input1) && Timestamp::IsFinite(input2)) { + return Interval::GetAge(input1, input2); + } else { + mask.SetInvalid(idx); + return interval_t(); + } + }); +} + +ScalarFunctionSet AgeFun::GetFunctions() { + ScalarFunctionSet age("age"); + age.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunctionStandard)); + age.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunction)); + return age; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/current.cpp b/src/duckdb/src/core_functions/scalar/date/current.cpp new file mode 100644 index 00000000..e8b069e6 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/current.cpp @@ -0,0 +1,54 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/transaction/meta_transaction.hpp" + +namespace duckdb { + +static timestamp_t GetTransactionTimestamp(ExpressionState &state) { + return MetaTransaction::Get(state.GetContext()).start_timestamp; +} + +static void CurrentTimeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 0); + auto val = Value::TIME(Timestamp::GetTime(GetTransactionTimestamp(state))); + result.Reference(val); +} + +static void CurrentDateFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 0); + + auto val = Value::DATE(Timestamp::GetDate(GetTransactionTimestamp(state))); + result.Reference(val); +} + +static void CurrentTimestampFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 0); + + auto val = Value::TIMESTAMPTZ(GetTransactionTimestamp(state)); + result.Reference(val); +} + +ScalarFunction CurrentTimeFun::GetFunction() { + ScalarFunction current_time({}, LogicalType::TIME, CurrentTimeFunction); + current_time.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return current_time; +} + +ScalarFunction CurrentDateFun::GetFunction() { + ScalarFunction current_date({}, LogicalType::DATE, CurrentDateFunction); + current_date.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return current_date; +} + +ScalarFunction GetCurrentTimestampFun::GetFunction() { + ScalarFunction current_timestamp({}, LogicalType::TIMESTAMP_TZ, CurrentTimestampFunction); + current_timestamp.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return current_timestamp; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/date_diff.cpp b/src/duckdb/src/core_functions/scalar/date/date_diff.cpp new file mode 100644 index 00000000..8fdd808d --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/date_diff.cpp @@ -0,0 +1,440 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +// This function is an implementation of the "period-crossing" date difference function from T-SQL +// https://docs.microsoft.com/en-us/sql/t-sql/functions/datediff-transact-sql?view=sql-server-ver15 +struct DateDiff { + template + static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { + BinaryExecutor::ExecuteWithNulls( + left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { + return OP::template Operation(startdate, enddate); + } else { + mask.SetInvalid(idx); + return TR(); + } + }); + } + + struct YearOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::ExtractYear(enddate) - Date::ExtractYear(startdate); + } + }; + + struct MonthOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + int32_t start_year, start_month, start_day; + Date::Convert(startdate, start_year, start_month, start_day); + int32_t end_year, end_month, end_day; + Date::Convert(enddate, end_year, end_month, end_day); + + return (end_year * 12 + end_month - 1) - (start_year * 12 + start_month - 1); + } + }; + + struct DayOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return TR(Date::EpochDays(enddate)) - TR(Date::EpochDays(startdate)); + } + }; + + struct DecadeOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::ExtractYear(enddate) / 10 - Date::ExtractYear(startdate) / 10; + } + }; + + struct CenturyOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::ExtractYear(enddate) / 100 - Date::ExtractYear(startdate) / 100; + } + }; + + struct MilleniumOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::ExtractYear(enddate) / 1000 - Date::ExtractYear(startdate) / 1000; + } + }; + + struct QuarterOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + int32_t start_year, start_month, start_day; + Date::Convert(startdate, start_year, start_month, start_day); + int32_t end_year, end_month, end_day; + Date::Convert(enddate, end_year, end_month, end_day); + + return (end_year * 12 + end_month - 1) / Interval::MONTHS_PER_QUARTER - + (start_year * 12 + start_month - 1) / Interval::MONTHS_PER_QUARTER; + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::Epoch(Date::GetMondayOfCurrentWeek(enddate)) / Interval::SECS_PER_WEEK - + Date::Epoch(Date::GetMondayOfCurrentWeek(startdate)) / Interval::SECS_PER_WEEK; + } + }; + + struct ISOYearOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::ExtractISOYearNumber(enddate) - Date::ExtractISOYearNumber(startdate); + } + }; + + struct MicrosecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::EpochMicroseconds(enddate) - Date::EpochMicroseconds(startdate); + } + }; + + struct MillisecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::EpochMicroseconds(enddate) / Interval::MICROS_PER_MSEC - + Date::EpochMicroseconds(startdate) / Interval::MICROS_PER_MSEC; + } + }; + + struct SecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::Epoch(enddate) - Date::Epoch(startdate); + } + }; + + struct MinutesOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::Epoch(enddate) / Interval::SECS_PER_MINUTE - + Date::Epoch(startdate) / Interval::SECS_PER_MINUTE; + } + }; + + struct HoursOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::Epoch(enddate) / Interval::SECS_PER_HOUR - Date::Epoch(startdate) / Interval::SECS_PER_HOUR; + } + }; +}; + +// TIMESTAMP specialisations +template <> +int64_t DateDiff::YearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return YearOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::MonthOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return MonthOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::DayOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return DayOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::DecadeOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return DecadeOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::CenturyOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return CenturyOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::MilleniumOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return MilleniumOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::QuarterOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return QuarterOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::WeekOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return WeekOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::ISOYearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return ISOYearOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::MicrosecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + const auto start = Timestamp::GetEpochMicroSeconds(startdate); + const auto end = Timestamp::GetEpochMicroSeconds(enddate); + return SubtractOperatorOverflowCheck::Operation(end, start); +} + +template <> +int64_t DateDiff::MillisecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return Timestamp::GetEpochMs(enddate) - Timestamp::GetEpochMs(startdate); +} + +template <> +int64_t DateDiff::SecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return Timestamp::GetEpochSeconds(enddate) - Timestamp::GetEpochSeconds(startdate); +} + +template <> +int64_t DateDiff::MinutesOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return Timestamp::GetEpochSeconds(enddate) / Interval::SECS_PER_MINUTE - + Timestamp::GetEpochSeconds(startdate) / Interval::SECS_PER_MINUTE; +} + +template <> +int64_t DateDiff::HoursOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return Timestamp::GetEpochSeconds(enddate) / Interval::SECS_PER_HOUR - + Timestamp::GetEpochSeconds(startdate) / Interval::SECS_PER_HOUR; +} + +// TIME specialisations +template <> +int64_t DateDiff::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"year\" not recognized"); +} + +template <> +int64_t DateDiff::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"month\" not recognized"); +} + +template <> +int64_t DateDiff::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"day\" not recognized"); +} + +template <> +int64_t DateDiff::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"decade\" not recognized"); +} + +template <> +int64_t DateDiff::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"century\" not recognized"); +} + +template <> +int64_t DateDiff::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"millennium\" not recognized"); +} + +template <> +int64_t DateDiff::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"quarter\" not recognized"); +} + +template <> +int64_t DateDiff::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"week\" not recognized"); +} + +template <> +int64_t DateDiff::ISOYearOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); +} + +template <> +int64_t DateDiff::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros - startdate.micros; +} + +template <> +int64_t DateDiff::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_MSEC - startdate.micros / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DateDiff::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_SEC - startdate.micros / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DateDiff::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_MINUTE - startdate.micros / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DateDiff::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_HOUR - startdate.micros / Interval::MICROS_PER_HOUR; +} + +template +static int64_t DifferenceDates(DatePartSpecifier type, TA startdate, TB enddate) { + switch (type) { + case DatePartSpecifier::YEAR: + return DateDiff::YearOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MONTH: + return DateDiff::MonthOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + return DateDiff::DayOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DECADE: + return DateDiff::DecadeOperator::template Operation(startdate, enddate); + case DatePartSpecifier::CENTURY: + return DateDiff::CenturyOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLENNIUM: + return DateDiff::MilleniumOperator::template Operation(startdate, enddate); + case DatePartSpecifier::QUARTER: + return DateDiff::QuarterOperator::template Operation(startdate, enddate); + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + return DateDiff::WeekOperator::template Operation(startdate, enddate); + case DatePartSpecifier::ISOYEAR: + return DateDiff::ISOYearOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MICROSECONDS: + return DateDiff::MicrosecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLISECONDS: + return DateDiff::MillisecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + return DateDiff::SecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MINUTE: + return DateDiff::MinutesOperator::template Operation(startdate, enddate); + case DatePartSpecifier::HOUR: + return DateDiff::HoursOperator::template Operation(startdate, enddate); + default: + throw NotImplementedException("Specifier type not implemented for DATEDIFF"); + } +} + +struct DateDiffTernaryOperator { + template + static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { + return DifferenceDates(GetDatePartSpecifier(part.GetString()), startdate, enddate); + } else { + mask.SetInvalid(idx); + return TR(); + } + } +}; + +template +static void DateDiffBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { + switch (type) { + case DatePartSpecifier::YEAR: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MONTH: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DECADE: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::CENTURY: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLENNIUM: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::QUARTER: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::ISOYEAR: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MICROSECONDS: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLISECONDS: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MINUTE: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::HOUR: + DateDiff::BinaryExecute(left, right, result, count); + break; + default: + throw NotImplementedException("Specifier type not implemented for DATEDIFF"); + } +} + +template +static void DateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + auto &part_arg = args.data[0]; + auto &start_arg = args.data[1]; + auto &end_arg = args.data[2]; + + if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // Common case of constant part. + if (ConstantVector::IsNull(part_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); + DateDiffBinaryExecutor(type, start_arg, end_arg, result, args.size()); + } + } else { + TernaryExecutor::ExecuteWithNulls( + part_arg, start_arg, end_arg, result, args.size(), + DateDiffTernaryOperator::Operation); + } +} + +ScalarFunctionSet DateDiffFun::GetFunctions() { + ScalarFunctionSet date_diff("date_diff"); + date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, + LogicalType::BIGINT, DateDiffFunction)); + date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, + LogicalType::BIGINT, DateDiffFunction)); + date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, + LogicalType::BIGINT, DateDiffFunction)); + return date_diff; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/date_part.cpp b/src/duckdb/src/core_functions/scalar/date/date_part.cpp new file mode 100644 index 00000000..dda67a1b --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/date_part.cpp @@ -0,0 +1,1876 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +DatePartSpecifier GetDateTypePartSpecifier(const string &specifier, LogicalType &type) { + const auto part = GetDatePartSpecifier(specifier); + switch (type.id()) { + case LogicalType::TIMESTAMP: + case LogicalType::TIMESTAMP_TZ: + return part; + case LogicalType::DATE: + switch (part) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::MONTH: + case DatePartSpecifier::DAY: + case DatePartSpecifier::DECADE: + case DatePartSpecifier::CENTURY: + case DatePartSpecifier::MILLENNIUM: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::ISOYEAR: + case DatePartSpecifier::WEEK: + case DatePartSpecifier::QUARTER: + case DatePartSpecifier::DOY: + case DatePartSpecifier::YEARWEEK: + case DatePartSpecifier::ERA: + case DatePartSpecifier::EPOCH: + case DatePartSpecifier::JULIAN_DAY: + return part; + default: + break; + } + break; + case LogicalType::TIME: + switch (part) { + case DatePartSpecifier::MICROSECONDS: + case DatePartSpecifier::MILLISECONDS: + case DatePartSpecifier::SECOND: + case DatePartSpecifier::MINUTE: + case DatePartSpecifier::HOUR: + case DatePartSpecifier::EPOCH: + case DatePartSpecifier::TIMEZONE: + case DatePartSpecifier::TIMEZONE_HOUR: + case DatePartSpecifier::TIMEZONE_MINUTE: + return part; + default: + break; + } + break; + case LogicalType::INTERVAL: + switch (part) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::MONTH: + case DatePartSpecifier::DAY: + case DatePartSpecifier::DECADE: + case DatePartSpecifier::CENTURY: + case DatePartSpecifier::QUARTER: + case DatePartSpecifier::MILLENNIUM: + case DatePartSpecifier::MICROSECONDS: + case DatePartSpecifier::MILLISECONDS: + case DatePartSpecifier::SECOND: + case DatePartSpecifier::MINUTE: + case DatePartSpecifier::HOUR: + case DatePartSpecifier::EPOCH: + return part; + default: + break; + } + break; + default: + break; + } + + throw NotImplementedException("\"%s\" units \"%s\" not recognized", EnumUtil::ToString(type.id()), specifier); +} + +template +static unique_ptr PropagateSimpleDatePartStatistics(vector &child_stats) { + // we can always propagate simple date part statistics + // since the min and max can never exceed these bounds + auto result = NumericStats::CreateEmpty(LogicalType::BIGINT); + result.CopyValidity(child_stats[0]); + NumericStats::SetMin(result, Value::BIGINT(MIN)); + NumericStats::SetMax(result, Value::BIGINT(MAX)); + return result.ToUnique(); +} + +struct DatePart { + template + static unique_ptr PropagateDatePartStatistics(vector &child_stats, + const LogicalType &stats_type = LogicalType::BIGINT) { + // we can only propagate complex date part stats if the child has stats + auto &nstats = child_stats[0]; + if (!NumericStats::HasMinMax(nstats)) { + return nullptr; + } + // run the operator on both the min and the max, this gives us the [min, max] bound + auto min = NumericStats::GetMin(nstats); + auto max = NumericStats::GetMax(nstats); + if (min > max) { + return nullptr; + } + // Infinities prevent us from computing generic ranges + if (!Value::IsFinite(min) || !Value::IsFinite(max)) { + return nullptr; + } + TR min_part = OP::template Operation(min); + TR max_part = OP::template Operation(max); + auto result = NumericStats::CreateEmpty(stats_type); + NumericStats::SetMin(result, Value(min_part)); + NumericStats::SetMax(result, Value(max_part)); + result.CopyValidity(child_stats[0]); + return result.ToUnique(); + } + + template + struct PartOperator { + template + static inline TR Operation(TA input, ValidityMask &mask, idx_t idx, void *dataptr) { + if (Value::IsFinite(input)) { + return OP::template Operation(input); + } else { + mask.SetInvalid(idx); + return TR(); + } + } + }; + + template + static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() >= 1); + using IOP = PartOperator; + UnaryExecutor::GenericExecute(input.data[0], result, input.size(), nullptr, true); + } + + struct YearOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractYear(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct MonthOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractMonth(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + // min/max of month operator is [1, 12] + return PropagateSimpleDatePartStatistics<1, 12>(input.child_stats); + } + }; + + struct DayOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractDay(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + // min/max of day operator is [1, 31] + return PropagateSimpleDatePartStatistics<1, 31>(input.child_stats); + } + }; + + struct DecadeOperator { + // From the PG docs: "The year field divided by 10" + template + static inline TR DecadeFromYear(TR yyyy) { + return yyyy / 10; + } + + template + static inline TR Operation(TA input) { + return DecadeFromYear(YearOperator::Operation(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct CenturyOperator { + // From the PG docs: + // "The first century starts at 0001-01-01 00:00:00 AD, although they did not know it at the time. + // This definition applies to all Gregorian calendar countries. + // There is no century number 0, you go from -1 century to 1 century. + // If you disagree with this, please write your complaint to: Pope, Cathedral Saint-Peter of Roma, Vatican." + // (To be fair, His Holiness had nothing to do with this - + // it was the lack of zero in the counting systems of the time...) + template + static inline TR CenturyFromYear(TR yyyy) { + if (yyyy > 0) { + return ((yyyy - 1) / 100) + 1; + } else { + return (yyyy / 100) - 1; + } + } + + template + static inline TR Operation(TA input) { + return CenturyFromYear(YearOperator::Operation(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct MillenniumOperator { + // See the century comment + template + static inline TR MillenniumFromYear(TR yyyy) { + if (yyyy > 0) { + return ((yyyy - 1) / 1000) + 1; + } else { + return (yyyy / 1000) - 1; + } + } + + template + static inline TR Operation(TA input) { + return MillenniumFromYear(YearOperator::Operation(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct QuarterOperator { + template + static inline TR QuarterFromMonth(TR mm) { + return (mm - 1) / Interval::MONTHS_PER_QUARTER + 1; + } + + template + static inline TR Operation(TA input) { + return QuarterFromMonth(Date::ExtractMonth(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + // min/max of quarter operator is [1, 4] + return PropagateSimpleDatePartStatistics<1, 4>(input.child_stats); + } + }; + + struct DayOfWeekOperator { + template + static inline TR DayOfWeekFromISO(TR isodow) { + // day of the week (Sunday = 0, Saturday = 6) + // turn sunday into 0 by doing mod 7 + return isodow % 7; + } + + template + static inline TR Operation(TA input) { + return DayOfWeekFromISO(Date::ExtractISODayOfTheWeek(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 6>(input.child_stats); + } + }; + + struct ISODayOfWeekOperator { + template + static inline TR Operation(TA input) { + // isodow (Monday = 1, Sunday = 7) + return Date::ExtractISODayOfTheWeek(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<1, 7>(input.child_stats); + } + }; + + struct DayOfYearOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractDayOfTheYear(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<1, 366>(input.child_stats); + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractISOWeekNumber(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<1, 54>(input.child_stats); + } + }; + + struct ISOYearOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractISOYearNumber(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct YearWeekOperator { + template + static inline TR YearWeekFromParts(TR yyyy, TR ww) { + return yyyy * 100 + ((yyyy > 0) ? ww : -ww); + } + + template + static inline TR Operation(TA input) { + int32_t yyyy, ww; + Date::ExtractISOYearWeek(input, yyyy, ww); + return YearWeekFromParts(yyyy, ww); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct EpochNanosecondsOperator { + template + static inline TR Operation(TA input) { + return input.micros * Interval::NANOS_PER_MICRO; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct EpochMicrosecondsOperator { + template + static inline TR Operation(TA input) { + return input.micros; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct EpochMillisOperator { + template + static inline TR Operation(TA input) { + return input.micros / Interval::MICROS_PER_MSEC; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + + static void Inverse(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 1); + + UnaryExecutor::Execute(input.data[0], result, input.size(), + [&](int64_t input) { return Timestamp::FromEpochMs(input); }); + } + }; + + struct MicrosecondsOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60000000>(input.child_stats); + } + }; + + struct MillisecondsOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60000>(input.child_stats); + } + }; + + struct SecondsOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); + } + }; + + struct MinutesOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); + } + }; + + struct HoursOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 24>(input.child_stats); + } + }; + + struct EpochOperator { + template + static inline TR Operation(TA input) { + return Date::Epoch(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); + } + }; + + struct EraOperator { + template + static inline TR EraFromYear(TR yyyy) { + return yyyy > 0 ? 1 : 0; + } + + template + static inline TR Operation(TA input) { + return EraFromYear(Date::ExtractYear(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 1>(input.child_stats); + } + }; + + struct TimezoneOperator { + template + static inline TR Operation(TA input) { + // Regular timestamps are UTC. + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); + } + }; + + struct JulianDayOperator { + template + static inline TR Operation(TA input) { + return Timestamp::GetJulianDay(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); + } + }; + + // These are all zero and have the same restrictions + using TimezoneHourOperator = TimezoneOperator; + using TimezoneMinuteOperator = TimezoneOperator; + + struct StructOperator { + using part_codes_t = vector; + using part_mask_t = uint64_t; + + enum MaskBits : uint8_t { + YMD = 1 << 0, + DOW = 1 << 1, + DOY = 1 << 2, + EPOCH = 1 << 3, + TIME = 1 << 4, + ZONE = 1 << 5, + ISO = 1 << 6, + JD = 1 << 7 + }; + + static part_mask_t GetMask(const part_codes_t &part_codes) { + part_mask_t mask = 0; + for (const auto &part_code : part_codes) { + switch (part_code) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::MONTH: + case DatePartSpecifier::DAY: + case DatePartSpecifier::DECADE: + case DatePartSpecifier::CENTURY: + case DatePartSpecifier::MILLENNIUM: + case DatePartSpecifier::QUARTER: + case DatePartSpecifier::ERA: + mask |= YMD; + break; + case DatePartSpecifier::YEARWEEK: + case DatePartSpecifier::WEEK: + case DatePartSpecifier::ISOYEAR: + mask |= ISO; + break; + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + mask |= DOW; + break; + case DatePartSpecifier::DOY: + mask |= DOY; + break; + case DatePartSpecifier::EPOCH: + mask |= EPOCH; + break; + case DatePartSpecifier::JULIAN_DAY: + mask |= JD; + break; + case DatePartSpecifier::MICROSECONDS: + case DatePartSpecifier::MILLISECONDS: + case DatePartSpecifier::SECOND: + case DatePartSpecifier::MINUTE: + case DatePartSpecifier::HOUR: + mask |= TIME; + break; + case DatePartSpecifier::TIMEZONE: + case DatePartSpecifier::TIMEZONE_HOUR: + case DatePartSpecifier::TIMEZONE_MINUTE: + mask |= ZONE; + break; + case DatePartSpecifier::INVALID: + throw InternalException("Invalid DatePartSpecifier for STRUCT mask!"); + } + } + return mask; + } + + template + static inline P HasPartValue(vector

part_values, DatePartSpecifier part) { + auto idx = size_t(part); + if (IsBigintDatepart(part)) { + return part_values[idx - size_t(DatePartSpecifier::BEGIN_BIGINT)]; + } else { + return part_values[idx - size_t(DatePartSpecifier::BEGIN_DOUBLE)]; + } + } + + using bigint_vec = vector; + using double_vec = vector; + + template + static inline void Operation(bigint_vec &bigint_values, double_vec &double_values, const TA &input, + const idx_t idx, const part_mask_t mask) { + int64_t *bigint_data; + // YMD calculations + int32_t yyyy = 1970; + int32_t mm = 0; + int32_t dd = 1; + if (mask & YMD) { + Date::Convert(input, yyyy, mm, dd); + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); + if (bigint_data) { + bigint_data[idx] = yyyy; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); + if (bigint_data) { + bigint_data[idx] = mm; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); + if (bigint_data) { + bigint_data[idx] = dd; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); + if (bigint_data) { + bigint_data[idx] = DecadeOperator::DecadeFromYear(yyyy); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); + if (bigint_data) { + bigint_data[idx] = CenturyOperator::CenturyFromYear(yyyy); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); + if (bigint_data) { + bigint_data[idx] = MillenniumOperator::MillenniumFromYear(yyyy); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); + if (bigint_data) { + bigint_data[idx] = QuarterOperator::QuarterFromMonth(mm); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ERA); + if (bigint_data) { + bigint_data[idx] = EraOperator::EraFromYear(yyyy); + } + } + + // Week calculations + if (mask & DOW) { + auto isodow = Date::ExtractISODayOfTheWeek(input); + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOW); + if (bigint_data) { + bigint_data[idx] = DayOfWeekOperator::DayOfWeekFromISO(isodow); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISODOW); + if (bigint_data) { + bigint_data[idx] = isodow; + } + } + + // ISO calculations + if (mask & ISO) { + int32_t ww = 0; + int32_t iyyy = 0; + Date::ExtractISOYearWeek(input, iyyy, ww); + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::WEEK); + if (bigint_data) { + bigint_data[idx] = ww; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISOYEAR); + if (bigint_data) { + bigint_data[idx] = iyyy; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEARWEEK); + if (bigint_data) { + bigint_data[idx] = YearWeekOperator::YearWeekFromParts(iyyy, ww); + } + } + + if (mask & EPOCH) { + auto double_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (double_data) { + double_data[idx] = Date::Epoch(input); + } + } + if (mask & DOY) { + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOY); + if (bigint_data) { + bigint_data[idx] = Date::ExtractDayOfTheYear(input); + } + } + if (mask & JD) { + auto double_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); + if (double_data) { + double_data[idx] = Date::ExtractJulianDay(input); + } + } + } + }; +}; + +template +static void LastYearFunction(DataChunk &args, ExpressionState &state, Vector &result) { + int32_t last_year = 0; + UnaryExecutor::ExecuteWithNulls(args.data[0], result, args.size(), + [&](T input, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(input)) { + return Date::ExtractYear(input, &last_year); + } else { + mask.SetInvalid(idx); + return 0; + } + }); +} + +template <> +int64_t DatePart::YearOperator::Operation(timestamp_t input) { + return YearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::YearOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_YEAR; +} + +template <> +int64_t DatePart::YearOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"year\" not recognized"); +} + +template <> +int64_t DatePart::MonthOperator::Operation(timestamp_t input) { + return MonthOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::MonthOperator::Operation(interval_t input) { + return input.months % Interval::MONTHS_PER_YEAR; +} + +template <> +int64_t DatePart::MonthOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"month\" not recognized"); +} + +template <> +int64_t DatePart::DayOperator::Operation(timestamp_t input) { + return DayOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::DayOperator::Operation(interval_t input) { + return input.days; +} + +template <> +int64_t DatePart::DayOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"day\" not recognized"); +} + +template <> +int64_t DatePart::DecadeOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_DECADE; +} + +template <> +int64_t DatePart::DecadeOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"decade\" not recognized"); +} + +template <> +int64_t DatePart::CenturyOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_CENTURY; +} + +template <> +int64_t DatePart::CenturyOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"century\" not recognized"); +} + +template <> +int64_t DatePart::MillenniumOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_MILLENIUM; +} + +template <> +int64_t DatePart::MillenniumOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"millennium\" not recognized"); +} + +template <> +int64_t DatePart::QuarterOperator::Operation(timestamp_t input) { + return QuarterOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::QuarterOperator::Operation(interval_t input) { + return MonthOperator::Operation(input) / Interval::MONTHS_PER_QUARTER + 1; +} + +template <> +int64_t DatePart::QuarterOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"quarter\" not recognized"); +} + +template <> +int64_t DatePart::DayOfWeekOperator::Operation(timestamp_t input) { + return DayOfWeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::DayOfWeekOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"dow\" not recognized"); +} + +template <> +int64_t DatePart::DayOfWeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"dow\" not recognized"); +} + +template <> +int64_t DatePart::ISODayOfWeekOperator::Operation(timestamp_t input) { + return ISODayOfWeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::ISODayOfWeekOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"isodow\" not recognized"); +} + +template <> +int64_t DatePart::ISODayOfWeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"isodow\" not recognized"); +} + +template <> +int64_t DatePart::DayOfYearOperator::Operation(timestamp_t input) { + return DayOfYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::DayOfYearOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"doy\" not recognized"); +} + +template <> +int64_t DatePart::DayOfYearOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"doy\" not recognized"); +} + +template <> +int64_t DatePart::WeekOperator::Operation(timestamp_t input) { + return WeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::WeekOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"week\" not recognized"); +} + +template <> +int64_t DatePart::WeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"week\" not recognized"); +} + +template <> +int64_t DatePart::ISOYearOperator::Operation(timestamp_t input) { + return ISOYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::ISOYearOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"isoyear\" not recognized"); +} + +template <> +int64_t DatePart::ISOYearOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); +} + +template <> +int64_t DatePart::YearWeekOperator::Operation(timestamp_t input) { + return YearWeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::YearWeekOperator::Operation(interval_t input) { + const auto yyyy = YearOperator::Operation(input); + const auto ww = WeekOperator::Operation(input); + return YearWeekOperator::YearWeekFromParts(yyyy, ww); +} + +template <> +int64_t DatePart::YearWeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"yearweek\" not recognized"); +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(timestamp_t input) { + return Timestamp::GetEpochNanoSeconds(input); +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(date_t input) { + return Date::EpochNanoseconds(input); +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(interval_t input) { + return Interval::GetNanoseconds(input); +} + +template <> +int64_t DatePart::EpochMicrosecondsOperator::Operation(timestamp_t input) { + return Timestamp::GetEpochMicroSeconds(input); +} + +template <> +int64_t DatePart::EpochMicrosecondsOperator::Operation(date_t input) { + return Date::EpochMicroseconds(input); +} + +template <> +int64_t DatePart::EpochMicrosecondsOperator::Operation(interval_t input) { + return Interval::GetMicro(input); +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(timestamp_t input) { + return Timestamp::GetEpochMs(input); +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(date_t input) { + return Date::EpochMilliseconds(input); +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(interval_t input) { + return Interval::GetMilli(input); +} + +template <> +int64_t DatePart::MicrosecondsOperator::Operation(timestamp_t input) { + auto time = Timestamp::GetTime(input); + // remove everything but the second & microsecond part + return time.micros % Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MicrosecondsOperator::Operation(interval_t input) { + // remove everything but the second & microsecond part + return input.micros % Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MicrosecondsOperator::Operation(dtime_t input) { + // remove everything but the second & microsecond part + return input.micros % Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MillisecondsOperator::Operation(timestamp_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DatePart::MillisecondsOperator::Operation(interval_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DatePart::MillisecondsOperator::Operation(dtime_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DatePart::SecondsOperator::Operation(timestamp_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DatePart::SecondsOperator::Operation(interval_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DatePart::SecondsOperator::Operation(dtime_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DatePart::MinutesOperator::Operation(timestamp_t input) { + auto time = Timestamp::GetTime(input); + // remove the hour part, and truncate to minutes + return (time.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MinutesOperator::Operation(interval_t input) { + // remove the hour part, and truncate to minutes + return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MinutesOperator::Operation(dtime_t input) { + // remove the hour part, and truncate to minutes + return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::HoursOperator::Operation(timestamp_t input) { + return Timestamp::GetTime(input).micros / Interval::MICROS_PER_HOUR; +} + +template <> +int64_t DatePart::HoursOperator::Operation(interval_t input) { + return input.micros / Interval::MICROS_PER_HOUR; +} + +template <> +int64_t DatePart::HoursOperator::Operation(dtime_t input) { + return input.micros / Interval::MICROS_PER_HOUR; +} + +template <> +double DatePart::EpochOperator::Operation(timestamp_t input) { + return Timestamp::GetEpochMicroSeconds(input) / double(Interval::MICROS_PER_SEC); +} + +template <> +double DatePart::EpochOperator::Operation(interval_t input) { + int64_t interval_years = input.months / Interval::MONTHS_PER_YEAR; + int64_t interval_days; + interval_days = Interval::DAYS_PER_YEAR * interval_years; + interval_days += Interval::DAYS_PER_MONTH * (input.months % Interval::MONTHS_PER_YEAR); + interval_days += input.days; + int64_t interval_epoch; + interval_epoch = interval_days * Interval::SECS_PER_DAY; + // we add 0.25 days per year to sort of account for leap days + interval_epoch += interval_years * (Interval::SECS_PER_DAY / 4); + return interval_epoch + input.micros / double(Interval::MICROS_PER_SEC); +} + +// TODO: We can't propagate interval statistics because we can't easily compare interval_t for order. +template <> +unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, + FunctionStatisticsInput &input) { + return nullptr; +} + +template <> +double DatePart::EpochOperator::Operation(dtime_t input) { + return input.micros / double(Interval::MICROS_PER_SEC); +} + +template <> +unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, + FunctionStatisticsInput &input) { + auto result = NumericStats::CreateEmpty(LogicalType::DOUBLE); + result.CopyValidity(input.child_stats[0]); + NumericStats::SetMin(result, Value::DOUBLE(0)); + NumericStats::SetMax(result, Value::DOUBLE(Interval::SECS_PER_DAY)); + return result.ToUnique(); +} + +template <> +int64_t DatePart::EraOperator::Operation(timestamp_t input) { + return EraOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::EraOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"era\" not recognized"); +} + +template <> +int64_t DatePart::EraOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"era\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneOperator::Operation(date_t input) { + throw NotImplementedException("\"date\" units \"timezone\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneOperator::Operation(interval_t input) { + throw NotImplementedException("\"interval\" units \"timezone\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneOperator::Operation(dtime_t input) { + return 0; +} + +template <> +double DatePart::JulianDayOperator::Operation(date_t input) { + return Date::ExtractJulianDay(input); +} + +template <> +double DatePart::JulianDayOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"julian\" not recognized"); +} + +template <> +double DatePart::JulianDayOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"julian\" not recognized"); +} + +template <> +void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const dtime_t &input, + const idx_t idx, const part_mask_t mask) { + int64_t *part_data; + if (mask & TIME) { + const auto micros = MicrosecondsOperator::Operation(input); + part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); + if (part_data) { + part_data[idx] = micros; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_MSEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_SEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); + if (part_data) { + part_data[idx] = MinutesOperator::Operation(input); + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); + if (part_data) { + part_data[idx] = HoursOperator::Operation(input); + } + } + + if (mask & EPOCH) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (part_data) { + part_data[idx] = EpochOperator::Operation(input); + ; + } + } + + if (mask & ZONE) { + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE); + if (part_data) { + part_data[idx] = 0; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_HOUR); + if (part_data) { + part_data[idx] = 0; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_MINUTE); + if (part_data) { + part_data[idx] = 0; + } + } +} + +template <> +void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const timestamp_t &input, + const idx_t idx, const part_mask_t mask) { + date_t d; + dtime_t t; + Timestamp::Convert(input, d, t); + + // Both define epoch, and the correct value is the sum. + // So mask it out and compute it separately. + Operation(bigint_values, double_values, d, idx, mask & ~EPOCH); + Operation(bigint_values, double_values, t, idx, mask & ~EPOCH); + + if (mask & EPOCH) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (part_data) { + part_data[idx] = EpochOperator::Operation(input); + } + } + + if (mask & JD) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); + if (part_data) { + part_data[idx] = JulianDayOperator::Operation(input); + } + } +} + +template <> +void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const interval_t &input, + const idx_t idx, const part_mask_t mask) { + int64_t *part_data; + if (mask & YMD) { + const auto mm = input.months % Interval::MONTHS_PER_YEAR; + part_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_YEAR; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); + if (part_data) { + part_data[idx] = mm; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); + if (part_data) { + part_data[idx] = input.days; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_DECADE; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_CENTURY; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_MILLENIUM; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); + if (part_data) { + part_data[idx] = mm / Interval::MONTHS_PER_QUARTER + 1; + } + } + + if (mask & TIME) { + const auto micros = MicrosecondsOperator::Operation(input); + part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); + if (part_data) { + part_data[idx] = micros; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_MSEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_SEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); + if (part_data) { + part_data[idx] = MinutesOperator::Operation(input); + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); + if (part_data) { + part_data[idx] = HoursOperator::Operation(input); + } + } + + if (mask & EPOCH) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (part_data) { + part_data[idx] = EpochOperator::Operation(input); + } + } +} + +template +static int64_t ExtractElement(DatePartSpecifier type, T element) { + switch (type) { + case DatePartSpecifier::YEAR: + return DatePart::YearOperator::template Operation(element); + case DatePartSpecifier::MONTH: + return DatePart::MonthOperator::template Operation(element); + case DatePartSpecifier::DAY: + return DatePart::DayOperator::template Operation(element); + case DatePartSpecifier::DECADE: + return DatePart::DecadeOperator::template Operation(element); + case DatePartSpecifier::CENTURY: + return DatePart::CenturyOperator::template Operation(element); + case DatePartSpecifier::MILLENNIUM: + return DatePart::MillenniumOperator::template Operation(element); + case DatePartSpecifier::QUARTER: + return DatePart::QuarterOperator::template Operation(element); + case DatePartSpecifier::DOW: + return DatePart::DayOfWeekOperator::template Operation(element); + case DatePartSpecifier::ISODOW: + return DatePart::ISODayOfWeekOperator::template Operation(element); + case DatePartSpecifier::DOY: + return DatePart::DayOfYearOperator::template Operation(element); + case DatePartSpecifier::WEEK: + return DatePart::WeekOperator::template Operation(element); + case DatePartSpecifier::ISOYEAR: + return DatePart::ISOYearOperator::template Operation(element); + case DatePartSpecifier::YEARWEEK: + return DatePart::YearWeekOperator::template Operation(element); + case DatePartSpecifier::MICROSECONDS: + return DatePart::MicrosecondsOperator::template Operation(element); + case DatePartSpecifier::MILLISECONDS: + return DatePart::MillisecondsOperator::template Operation(element); + case DatePartSpecifier::SECOND: + return DatePart::SecondsOperator::template Operation(element); + case DatePartSpecifier::MINUTE: + return DatePart::MinutesOperator::template Operation(element); + case DatePartSpecifier::HOUR: + return DatePart::HoursOperator::template Operation(element); + case DatePartSpecifier::ERA: + return DatePart::EraOperator::template Operation(element); + case DatePartSpecifier::TIMEZONE: + case DatePartSpecifier::TIMEZONE_HOUR: + case DatePartSpecifier::TIMEZONE_MINUTE: + return DatePart::TimezoneOperator::template Operation(element); + default: + throw NotImplementedException("Specifier type not implemented for DATEPART"); + } +} + +template +static void DatePartFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + auto &spec_arg = args.data[0]; + auto &date_arg = args.data[1]; + + BinaryExecutor::ExecuteWithNulls( + spec_arg, date_arg, result, args.size(), [&](string_t specifier, T date, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(date)) { + return ExtractElement(GetDatePartSpecifier(specifier.GetString()), date); + } else { + mask.SetInvalid(idx); + return int64_t(0); + } + }); +} + +static unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // If we are only looking for Julian Days for timestamps, + // then return doubles. + if (arguments[0]->HasParameter() || !arguments[0]->IsFoldable()) { + return nullptr; + } + + Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + const auto part_name = part_value.ToString(); + switch (GetDatePartSpecifier(part_name)) { + case DatePartSpecifier::JULIAN_DAY: + arguments.erase(arguments.begin()); + bound_function.arguments.erase(bound_function.arguments.begin()); + bound_function.name = "julian"; + bound_function.return_type = LogicalType::DOUBLE; + switch (arguments[0]->return_type.id()) { + case LogicalType::TIMESTAMP: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; + break; + case LogicalType::DATE: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; + break; + default: + throw BinderException("%s can only take DATE or TIMESTAMP arguments", bound_function.name); + } + break; + case DatePartSpecifier::EPOCH: + arguments.erase(arguments.begin()); + bound_function.arguments.erase(bound_function.arguments.begin()); + bound_function.name = "epoch"; + bound_function.return_type = LogicalType::DOUBLE; + switch (arguments[0]->return_type.id()) { + case LogicalType::TIMESTAMP: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + case LogicalType::DATE: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + case LogicalType::INTERVAL: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + case LogicalType::TIME: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + default: + throw BinderException("%s can only take temporal arguments", bound_function.name); + } + break; + default: + break; + } + + return nullptr; +} + +ScalarFunctionSet GetGenericDatePartFunction(scalar_function_t date_func, scalar_function_t ts_func, + scalar_function_t interval_func, function_statistics_t date_stats, + function_statistics_t ts_stats) { + ScalarFunctionSet operator_set; + operator_set.AddFunction( + ScalarFunction({LogicalType::DATE}, LogicalType::BIGINT, std::move(date_func), nullptr, nullptr, date_stats)); + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BIGINT, std::move(ts_func), nullptr, nullptr, ts_stats)); + operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, LogicalType::BIGINT, std::move(interval_func))); + return operator_set; +} + +template +static ScalarFunctionSet GetDatePartFunction() { + return GetGenericDatePartFunction( + DatePart::UnaryFunction, DatePart::UnaryFunction, + ScalarFunction::UnaryFunction, OP::template PropagateStatistics, + OP::template PropagateStatistics); +} + +ScalarFunctionSet GetGenericTimePartFunction(const LogicalType &result_type, scalar_function_t date_func, + scalar_function_t ts_func, scalar_function_t interval_func, + scalar_function_t time_func, function_statistics_t date_stats, + function_statistics_t ts_stats, function_statistics_t time_stats) { + ScalarFunctionSet operator_set; + operator_set.AddFunction( + ScalarFunction({LogicalType::DATE}, result_type, std::move(date_func), nullptr, nullptr, date_stats)); + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP}, result_type, std::move(ts_func), nullptr, nullptr, ts_stats)); + operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, result_type, std::move(interval_func))); + operator_set.AddFunction( + ScalarFunction({LogicalType::TIME}, result_type, std::move(time_func), nullptr, nullptr, time_stats)); + return operator_set; +} + +template +static ScalarFunctionSet GetTimePartFunction(const LogicalType &result_type = LogicalType::BIGINT) { + return GetGenericTimePartFunction( + result_type, DatePart::UnaryFunction, DatePart::UnaryFunction, + ScalarFunction::UnaryFunction, ScalarFunction::UnaryFunction, + OP::template PropagateStatistics, OP::template PropagateStatistics, + OP::template PropagateStatistics); +} + +struct LastDayOperator { + template + static inline TR Operation(TA input) { + int32_t yyyy, mm, dd; + Date::Convert(input, yyyy, mm, dd); + yyyy += (mm / 12); + mm %= 12; + ++mm; + return Date::FromDate(yyyy, mm, 1) - 1; + } +}; + +template <> +date_t LastDayOperator::Operation(timestamp_t input) { + return LastDayOperator::Operation(Timestamp::GetDate(input)); +} + +struct MonthNameOperator { + template + static inline TR Operation(TA input) { + return Date::MONTH_NAMES[DatePart::MonthOperator::Operation(input) - 1]; + } +}; + +struct DayNameOperator { + template + static inline TR Operation(TA input) { + return Date::DAY_NAMES[DatePart::DayOfWeekOperator::Operation(input)]; + } +}; + +struct StructDatePart { + using part_codes_t = vector; + + struct BindData : public VariableReturnBindData { + part_codes_t part_codes; + + explicit BindData(const LogicalType &stype, const part_codes_t &part_codes_p) + : VariableReturnBindData(stype), part_codes(part_codes_p) { + } + + unique_ptr Copy() const override { + return make_uniq(stype, part_codes); + } + }; + + static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // collect names and deconflict, construct return type + if (arguments[0]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[0]->IsFoldable()) { + throw BinderException("%s can only take constant lists of part names", bound_function.name); + } + + case_insensitive_set_t name_collision_set; + child_list_t struct_children; + part_codes_t part_codes; + + Value parts_list = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + if (parts_list.type().id() == LogicalTypeId::LIST) { + auto &list_children = ListValue::GetChildren(parts_list); + if (list_children.empty()) { + throw BinderException("%s requires non-empty lists of part names", bound_function.name); + } + for (const auto &part_value : list_children) { + if (part_value.IsNull()) { + throw BinderException("NULL struct entry name in %s", bound_function.name); + } + const auto part_name = part_value.ToString(); + const auto part_code = GetDateTypePartSpecifier(part_name, arguments[1]->return_type); + if (name_collision_set.find(part_name) != name_collision_set.end()) { + throw BinderException("Duplicate struct entry name \"%s\" in %s", part_name, bound_function.name); + } + name_collision_set.insert(part_name); + part_codes.emplace_back(part_code); + const auto part_type = IsBigintDatepart(part_code) ? LogicalType::BIGINT : LogicalType::DOUBLE; + struct_children.emplace_back(make_pair(part_name, part_type)); + } + } else { + throw BinderException("%s can only take constant lists of part names", bound_function.name); + } + + Function::EraseArgument(bound_function, arguments, 0); + bound_function.return_type = LogicalType::STRUCT(struct_children); + return make_uniq(bound_function.return_type, part_codes); + } + + template + static void Function(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + D_ASSERT(args.ColumnCount() == 1); + + const auto count = args.size(); + Vector &input = args.data[0]; + + // Type counts + const auto BIGINT_COUNT = size_t(DatePartSpecifier::BEGIN_DOUBLE) - size_t(DatePartSpecifier::BEGIN_BIGINT); + const auto DOUBLE_COUNT = size_t(DatePartSpecifier::BEGIN_INVALID) - size_t(DatePartSpecifier::BEGIN_DOUBLE); + DatePart::StructOperator::bigint_vec bigint_values(BIGINT_COUNT, nullptr); + DatePart::StructOperator::double_vec double_values(DOUBLE_COUNT, nullptr); + const auto part_mask = DatePart::StructOperator::GetMask(info.part_codes); + + auto &child_entries = StructVector::GetEntries(result); + + // The first computer of a part "owns" it + // and other requestors just reference the owner + vector owners(int(DatePartSpecifier::JULIAN_DAY) + 1, child_entries.size()); + for (size_t col = 0; col < child_entries.size(); ++col) { + const auto part_index = size_t(info.part_codes[col]); + if (owners[part_index] == child_entries.size()) { + owners[part_index] = col; + } + } + + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + if (ConstantVector::IsNull(input)) { + ConstantVector::SetNull(result, true); + } else { + ConstantVector::SetNull(result, false); + for (size_t col = 0; col < child_entries.size(); ++col) { + auto &child_entry = child_entries[col]; + ConstantVector::SetNull(*child_entry, false); + const auto part_index = size_t(info.part_codes[col]); + if (owners[part_index] == col) { + if (IsBigintDatepart(info.part_codes[col])) { + bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = + ConstantVector::GetData(*child_entry); + } else { + double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = + ConstantVector::GetData(*child_entry); + } + } + } + auto tdata = ConstantVector::GetData(input); + if (Value::IsFinite(tdata[0])) { + DatePart::StructOperator::Operation(bigint_values, double_values, tdata[0], 0, part_mask); + } else { + for (auto &child_entry : child_entries) { + ConstantVector::SetNull(*child_entry, true); + } + } + } + } else { + UnifiedVectorFormat rdata; + input.ToUnifiedFormat(count, rdata); + + const auto &arg_valid = rdata.validity; + auto tdata = UnifiedVectorFormat::GetData(rdata); + + // Start with a valid flat vector + result.SetVectorType(VectorType::FLAT_VECTOR); + auto &res_valid = FlatVector::Validity(result); + if (res_valid.GetData()) { + res_valid.SetAllValid(count); + } + + // Start with valid children + for (size_t col = 0; col < child_entries.size(); ++col) { + auto &child_entry = child_entries[col]; + child_entry->SetVectorType(VectorType::FLAT_VECTOR); + auto &child_validity = FlatVector::Validity(*child_entry); + if (child_validity.GetData()) { + child_validity.SetAllValid(count); + } + + // Pre-multiplex + const auto part_index = size_t(info.part_codes[col]); + if (owners[part_index] == col) { + if (IsBigintDatepart(info.part_codes[col])) { + bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = + FlatVector::GetData(*child_entry); + } else { + double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = + FlatVector::GetData(*child_entry); + } + } + } + + for (idx_t i = 0; i < count; ++i) { + const auto idx = rdata.sel->get_index(i); + if (arg_valid.RowIsValid(idx)) { + if (Value::IsFinite(tdata[idx])) { + DatePart::StructOperator::Operation(bigint_values, double_values, tdata[idx], i, part_mask); + } else { + for (auto &child_entry : child_entries) { + FlatVector::Validity(*child_entry).SetInvalid(i); + } + } + } else { + res_valid.SetInvalid(i); + for (auto &child_entry : child_entries) { + FlatVector::Validity(*child_entry).SetInvalid(i); + } + } + } + } + + // Reference any duplicate parts + for (size_t col = 0; col < child_entries.size(); ++col) { + const auto part_index = size_t(info.part_codes[col]); + const auto owner = owners[part_index]; + if (owner != col) { + child_entries[col]->Reference(*child_entries[owner]); + } + } + + result.Verify(count); + } + + static void SerializeFunction(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + D_ASSERT(bind_data_p); + auto &info = bind_data_p->Cast(); + serializer.WriteProperty(100, "stype", info.stype); + serializer.WriteProperty(101, "part_codes", info.part_codes); + } + + static unique_ptr DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) { + auto stype = deserializer.ReadProperty(100, "stype"); + auto part_codes = deserializer.ReadProperty>(101, "part_codes"); + return make_uniq(std::move(stype), std::move(part_codes)); + } + + template + static ScalarFunction GetFunction(const LogicalType &temporal_type) { + auto part_type = LogicalType::LIST(LogicalType::VARCHAR); + auto result_type = LogicalType::STRUCT({}); + ScalarFunction result({part_type, temporal_type}, result_type, Function, Bind); + result.serialize = SerializeFunction; + result.deserialize = DeserializeFunction; + return result; + } +}; + +ScalarFunctionSet YearFun::GetFunctions() { + return GetGenericDatePartFunction(LastYearFunction, LastYearFunction, + ScalarFunction::UnaryFunction, + DatePart::YearOperator::PropagateStatistics, + DatePart::YearOperator::PropagateStatistics); +} + +ScalarFunctionSet MonthFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DayFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DecadeFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet CenturyFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet MillenniumFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet QuarterFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DayOfWeekFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet ISODayOfWeekFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DayOfYearFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet WeekFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet ISOYearFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet EraFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet TimezoneFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet TimezoneHourFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet TimezoneMinuteFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet EpochFun::GetFunctions() { + return GetTimePartFunction(LogicalType::DOUBLE); +} + +ScalarFunctionSet EpochNsFun::GetFunctions() { + using OP = DatePart::EpochNanosecondsOperator; + auto operator_set = GetTimePartFunction(); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + return operator_set; +} + +ScalarFunctionSet EpochUsFun::GetFunctions() { + using OP = DatePart::EpochMicrosecondsOperator; + auto operator_set = GetTimePartFunction(); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + return operator_set; +} + +ScalarFunctionSet EpochMsFun::GetFunctions() { + using OP = DatePart::EpochMillisOperator; + auto operator_set = GetTimePartFunction(); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + + // Legacy inverse BIGINT => TIMESTAMP + operator_set.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, DatePart::EpochMillisOperator::Inverse)); + + return operator_set; +} + +ScalarFunctionSet MicrosecondsFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet MillisecondsFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet SecondsFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet MinutesFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet HoursFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet YearWeekFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DayOfMonthFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet WeekDayFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet WeekOfYearFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet LastDayFun::GetFunctions() { + ScalarFunctionSet last_day; + last_day.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::DATE, + DatePart::UnaryFunction)); + last_day.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DATE, + DatePart::UnaryFunction)); + return last_day; +} + +ScalarFunctionSet MonthNameFun::GetFunctions() { + ScalarFunctionSet monthname; + monthname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + monthname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + return monthname; +} + +ScalarFunctionSet DayNameFun::GetFunctions() { + ScalarFunctionSet dayname; + dayname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + dayname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + return dayname; +} + +ScalarFunctionSet JulianDayFun::GetFunctions() { + using OP = DatePart::JulianDayOperator; + + ScalarFunctionSet operator_set; + auto date_func = DatePart::UnaryFunction; + auto date_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::DATE}, LogicalType::DOUBLE, date_func, nullptr, nullptr, date_stats)); + auto ts_func = DatePart::UnaryFunction; + auto ts_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DOUBLE, ts_func, nullptr, nullptr, ts_stats)); + + return operator_set; +} + +ScalarFunctionSet DatePartFun::GetFunctions() { + ScalarFunctionSet date_part; + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + + // struct variants + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::DATE)); + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIMESTAMP)); + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME)); + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::INTERVAL)); + + return date_part; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/date_sub.cpp b/src/duckdb/src/core_functions/scalar/date/date_sub.cpp new file mode 100644 index 00000000..6d4c4e24 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/date_sub.cpp @@ -0,0 +1,454 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +struct DateSub { + static int64_t SubtractMicros(timestamp_t startdate, timestamp_t enddate) { + const auto start = Timestamp::GetEpochMicroSeconds(startdate); + const auto end = Timestamp::GetEpochMicroSeconds(enddate); + return SubtractOperatorOverflowCheck::Operation(end, start); + } + + template + static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { + BinaryExecutor::ExecuteWithNulls( + left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { + return OP::template Operation(startdate, enddate); + } else { + mask.SetInvalid(idx); + return TR(); + } + }); + } + + struct MonthOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + + if (start_ts > end_ts) { + return -MonthOperator::Operation(end_ts, start_ts); + } + // The number of complete months depends on whether end_ts is on the last day of the month. + date_t end_date; + dtime_t end_time; + Timestamp::Convert(end_ts, end_date, end_time); + + int32_t yyyy, mm, dd; + Date::Convert(end_date, yyyy, mm, dd); + const auto end_days = Date::MonthDays(yyyy, mm); + if (end_days == dd) { + // Now check whether the start day is after the end day + date_t start_date; + dtime_t start_time; + Timestamp::Convert(start_ts, start_date, start_time); + Date::Convert(start_date, yyyy, mm, dd); + if (dd > end_days || (dd == end_days && start_time < end_time)) { + // Move back to the same time on the last day of the (shorter) end month + start_date = Date::FromDate(yyyy, mm, end_days); + start_ts = Timestamp::FromDatetime(start_date, start_time); + } + } + + // Our interval difference will now give the correct result. + // Note that PG gives different interval subtraction results, + // so if we change this we will have to reimplement. + return Interval::GetAge(end_ts, start_ts).months; + } + }; + + struct QuarterOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_QUARTER; + } + }; + + struct YearOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_YEAR; + } + }; + + struct DecadeOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_DECADE; + } + }; + + struct CenturyOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_CENTURY; + } + }; + + struct MilleniumOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_MILLENIUM; + } + }; + + struct DayOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_DAY; + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_WEEK; + } + }; + + struct MicrosecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate); + } + }; + + struct MillisecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_MSEC; + } + }; + + struct SecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_SEC; + } + }; + + struct MinutesOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_MINUTE; + } + }; + + struct HoursOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_HOUR; + } + }; +}; + +// DATE specialisations +template <> +int64_t DateSub::YearOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return YearOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MonthOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MonthOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::DayOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return DayOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::DecadeOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return DecadeOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::CenturyOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return CenturyOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MilleniumOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MilleniumOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::QuarterOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return QuarterOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::WeekOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return WeekOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MicrosecondsOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MicrosecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MillisecondsOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MillisecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::SecondsOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return SecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MinutesOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MinutesOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::HoursOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return HoursOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +// TIME specialisations +template <> +int64_t DateSub::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"year\" not recognized"); +} + +template <> +int64_t DateSub::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"month\" not recognized"); +} + +template <> +int64_t DateSub::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"day\" not recognized"); +} + +template <> +int64_t DateSub::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"decade\" not recognized"); +} + +template <> +int64_t DateSub::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"century\" not recognized"); +} + +template <> +int64_t DateSub::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"millennium\" not recognized"); +} + +template <> +int64_t DateSub::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"quarter\" not recognized"); +} + +template <> +int64_t DateSub::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"week\" not recognized"); +} + +template <> +int64_t DateSub::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros - startdate.micros; +} + +template <> +int64_t DateSub::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return (enddate.micros - startdate.micros) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DateSub::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return (enddate.micros - startdate.micros) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DateSub::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { + return (enddate.micros - startdate.micros) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DateSub::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { + return (enddate.micros - startdate.micros) / Interval::MICROS_PER_HOUR; +} + +template +static int64_t SubtractDateParts(DatePartSpecifier type, TA startdate, TB enddate) { + switch (type) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::ISOYEAR: + return DateSub::YearOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MONTH: + return DateSub::MonthOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + return DateSub::DayOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DECADE: + return DateSub::DecadeOperator::template Operation(startdate, enddate); + case DatePartSpecifier::CENTURY: + return DateSub::CenturyOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLENNIUM: + return DateSub::MilleniumOperator::template Operation(startdate, enddate); + case DatePartSpecifier::QUARTER: + return DateSub::QuarterOperator::template Operation(startdate, enddate); + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + return DateSub::WeekOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MICROSECONDS: + return DateSub::MicrosecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLISECONDS: + return DateSub::MillisecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + return DateSub::SecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MINUTE: + return DateSub::MinutesOperator::template Operation(startdate, enddate); + case DatePartSpecifier::HOUR: + return DateSub::HoursOperator::template Operation(startdate, enddate); + default: + throw NotImplementedException("Specifier type not implemented for DATESUB"); + } +} + +struct DateSubTernaryOperator { + template + static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { + return SubtractDateParts(GetDatePartSpecifier(part.GetString()), startdate, enddate); + } else { + mask.SetInvalid(idx); + return TR(); + } + } +}; + +template +static void DateSubBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { + switch (type) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::ISOYEAR: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MONTH: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DECADE: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::CENTURY: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLENNIUM: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::QUARTER: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MICROSECONDS: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLISECONDS: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MINUTE: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::HOUR: + DateSub::BinaryExecute(left, right, result, count); + break; + default: + throw NotImplementedException("Specifier type not implemented for DATESUB"); + } +} + +template +static void DateSubFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + auto &part_arg = args.data[0]; + auto &start_arg = args.data[1]; + auto &end_arg = args.data[2]; + + if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // Common case of constant part. + if (ConstantVector::IsNull(part_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); + DateSubBinaryExecutor(type, start_arg, end_arg, result, args.size()); + } + } else { + TernaryExecutor::ExecuteWithNulls( + part_arg, start_arg, end_arg, result, args.size(), + DateSubTernaryOperator::Operation); + } +} + +ScalarFunctionSet DateSubFun::GetFunctions() { + ScalarFunctionSet date_sub("date_sub"); + date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, + LogicalType::BIGINT, DateSubFunction)); + date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, + LogicalType::BIGINT, DateSubFunction)); + date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, + LogicalType::BIGINT, DateSubFunction)); + return date_sub; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/date_trunc.cpp b/src/duckdb/src/core_functions/scalar/date/date_trunc.cpp new file mode 100644 index 00000000..0493c71b --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/date_trunc.cpp @@ -0,0 +1,734 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +struct DateTrunc { + template + static inline TR UnaryFunction(TA input) { + if (Value::IsFinite(input)) { + return OP::template Operation(input); + } else { + return Cast::template Operation(input); + } + } + + template + static inline void UnaryExecute(Vector &left, Vector &result, idx_t count) { + UnaryExecutor::Execute(left, result, count, UnaryFunction); + } + + struct MillenniumOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate((Date::ExtractYear(input) / 1000) * 1000, 1, 1); + } + }; + + struct CenturyOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate((Date::ExtractYear(input) / 100) * 100, 1, 1); + } + }; + + struct DecadeOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate((Date::ExtractYear(input) / 10) * 10, 1, 1); + } + }; + + struct YearOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate(Date::ExtractYear(input), 1, 1); + } + }; + + struct QuarterOperator { + template + static inline TR Operation(TA input) { + int32_t yyyy, mm, dd; + Date::Convert(input, yyyy, mm, dd); + mm = 1 + (((mm - 1) / 3) * 3); + return Date::FromDate(yyyy, mm, 1); + } + }; + + struct MonthOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate(Date::ExtractYear(input), Date::ExtractMonth(input), 1); + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA input) { + return Date::GetMondayOfCurrentWeek(input); + } + }; + + struct ISOYearOperator { + template + static inline TR Operation(TA input) { + date_t date = Date::GetMondayOfCurrentWeek(input); + date.days -= (Date::ExtractISOWeekNumber(date) - 1) * Interval::DAYS_PER_WEEK; + + return date; + } + }; + + struct DayOperator { + template + static inline TR Operation(TA input) { + return input; + } + }; + + struct HourOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + return Timestamp::FromDatetime(date, Time::FromTime(hour, 0, 0, 0)); + } + }; + + struct MinuteOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + return Timestamp::FromDatetime(date, Time::FromTime(hour, min, 0, 0)); + } + }; + + struct SecondOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, 0)); + } + }; + + struct MillisecondOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + micros -= micros % Interval::MICROS_PER_MSEC; + return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, micros)); + } + }; + + struct MicrosecondOperator { + template + static inline TR Operation(TA input) { + return input; + } + }; +}; + +// DATE specialisations +template <> +date_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { + return MillenniumOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::MillenniumOperator::Operation(date_t input) { + return Timestamp::FromDatetime(MillenniumOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { + return MillenniumOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { + return CenturyOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::CenturyOperator::Operation(date_t input) { + return Timestamp::FromDatetime(CenturyOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { + return CenturyOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { + return DecadeOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::DecadeOperator::Operation(date_t input) { + return Timestamp::FromDatetime(DecadeOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { + return DecadeOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::YearOperator::Operation(timestamp_t input) { + return YearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::YearOperator::Operation(date_t input) { + return Timestamp::FromDatetime(YearOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::YearOperator::Operation(timestamp_t input) { + return YearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { + return QuarterOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::QuarterOperator::Operation(date_t input) { + return Timestamp::FromDatetime(QuarterOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { + return QuarterOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::MonthOperator::Operation(timestamp_t input) { + return MonthOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::MonthOperator::Operation(date_t input) { + return Timestamp::FromDatetime(MonthOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::MonthOperator::Operation(timestamp_t input) { + return MonthOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::WeekOperator::Operation(timestamp_t input) { + return WeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::WeekOperator::Operation(date_t input) { + return Timestamp::FromDatetime(WeekOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::WeekOperator::Operation(timestamp_t input) { + return WeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::ISOYearOperator::Operation(timestamp_t input) { + return ISOYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::ISOYearOperator::Operation(date_t input) { + return Timestamp::FromDatetime(ISOYearOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::ISOYearOperator::Operation(timestamp_t input) { + return ISOYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::DayOperator::Operation(timestamp_t input) { + return DayOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::DayOperator::Operation(date_t input) { + return Timestamp::FromDatetime(DayOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::DayOperator::Operation(timestamp_t input) { + return DayOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::HourOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +timestamp_t DateTrunc::HourOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +date_t DateTrunc::HourOperator::Operation(timestamp_t input) { + return Timestamp::GetDate(HourOperator::Operation(input)); +} + +template <> +date_t DateTrunc::MinuteOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +timestamp_t DateTrunc::MinuteOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +date_t DateTrunc::MinuteOperator::Operation(timestamp_t input) { + return Timestamp::GetDate(HourOperator::Operation(input)); +} + +template <> +date_t DateTrunc::SecondOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +timestamp_t DateTrunc::SecondOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +date_t DateTrunc::SecondOperator::Operation(timestamp_t input) { + return Timestamp::GetDate(DayOperator::Operation(input)); +} + +template <> +date_t DateTrunc::MillisecondOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +timestamp_t DateTrunc::MillisecondOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +date_t DateTrunc::MillisecondOperator::Operation(timestamp_t input) { + return Timestamp::GetDate(MillisecondOperator::Operation(input)); +} + +template <> +date_t DateTrunc::MicrosecondOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +timestamp_t DateTrunc::MicrosecondOperator::Operation(date_t input) { + return DayOperator::Operation(input); +} + +template <> +date_t DateTrunc::MicrosecondOperator::Operation(timestamp_t input) { + return Timestamp::GetDate(MicrosecondOperator::Operation(input)); +} + +// INTERVAL specialisations +template <> +interval_t DateTrunc::MillenniumOperator::Operation(interval_t input) { + input.days = 0; + input.micros = 0; + input.months = (input.months / Interval::MONTHS_PER_MILLENIUM) * Interval::MONTHS_PER_MILLENIUM; + return input; +} + +template <> +interval_t DateTrunc::CenturyOperator::Operation(interval_t input) { + input.days = 0; + input.micros = 0; + input.months = (input.months / Interval::MONTHS_PER_CENTURY) * Interval::MONTHS_PER_CENTURY; + return input; +} + +template <> +interval_t DateTrunc::DecadeOperator::Operation(interval_t input) { + input.days = 0; + input.micros = 0; + input.months = (input.months / Interval::MONTHS_PER_DECADE) * Interval::MONTHS_PER_DECADE; + return input; +} + +template <> +interval_t DateTrunc::YearOperator::Operation(interval_t input) { + input.days = 0; + input.micros = 0; + input.months = (input.months / Interval::MONTHS_PER_YEAR) * Interval::MONTHS_PER_YEAR; + return input; +} + +template <> +interval_t DateTrunc::QuarterOperator::Operation(interval_t input) { + input.days = 0; + input.micros = 0; + input.months = (input.months / Interval::MONTHS_PER_QUARTER) * Interval::MONTHS_PER_QUARTER; + return input; +} + +template <> +interval_t DateTrunc::MonthOperator::Operation(interval_t input) { + input.days = 0; + input.micros = 0; + return input; +} + +template <> +interval_t DateTrunc::WeekOperator::Operation(interval_t input) { + input.micros = 0; + input.days = (input.days / Interval::DAYS_PER_WEEK) * Interval::DAYS_PER_WEEK; + return input; +} + +template <> +interval_t DateTrunc::ISOYearOperator::Operation(interval_t input) { + return YearOperator::Operation(input); +} + +template <> +interval_t DateTrunc::DayOperator::Operation(interval_t input) { + input.micros = 0; + return input; +} + +template <> +interval_t DateTrunc::HourOperator::Operation(interval_t input) { + input.micros = (input.micros / Interval::MICROS_PER_HOUR) * Interval::MICROS_PER_HOUR; + return input; +} + +template <> +interval_t DateTrunc::MinuteOperator::Operation(interval_t input) { + input.micros = (input.micros / Interval::MICROS_PER_MINUTE) * Interval::MICROS_PER_MINUTE; + return input; +} + +template <> +interval_t DateTrunc::SecondOperator::Operation(interval_t input) { + input.micros = (input.micros / Interval::MICROS_PER_SEC) * Interval::MICROS_PER_SEC; + return input; +} + +template <> +interval_t DateTrunc::MillisecondOperator::Operation(interval_t input) { + input.micros = (input.micros / Interval::MICROS_PER_MSEC) * Interval::MICROS_PER_MSEC; + return input; +} + +template <> +interval_t DateTrunc::MicrosecondOperator::Operation(interval_t input) { + return input; +} + +template +static TR TruncateElement(DatePartSpecifier type, TA element) { + if (!Value::IsFinite(element)) { + return Cast::template Operation(element); + } + + switch (type) { + case DatePartSpecifier::MILLENNIUM: + return DateTrunc::MillenniumOperator::Operation(element); + case DatePartSpecifier::CENTURY: + return DateTrunc::CenturyOperator::Operation(element); + case DatePartSpecifier::DECADE: + return DateTrunc::DecadeOperator::Operation(element); + case DatePartSpecifier::YEAR: + return DateTrunc::YearOperator::Operation(element); + case DatePartSpecifier::QUARTER: + return DateTrunc::QuarterOperator::Operation(element); + case DatePartSpecifier::MONTH: + return DateTrunc::MonthOperator::Operation(element); + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + return DateTrunc::WeekOperator::Operation(element); + case DatePartSpecifier::ISOYEAR: + return DateTrunc::ISOYearOperator::Operation(element); + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + return DateTrunc::DayOperator::Operation(element); + case DatePartSpecifier::HOUR: + return DateTrunc::HourOperator::Operation(element); + case DatePartSpecifier::MINUTE: + return DateTrunc::MinuteOperator::Operation(element); + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + return DateTrunc::SecondOperator::Operation(element); + case DatePartSpecifier::MILLISECONDS: + return DateTrunc::MillisecondOperator::Operation(element); + case DatePartSpecifier::MICROSECONDS: + return DateTrunc::MicrosecondOperator::Operation(element); + default: + throw NotImplementedException("Specifier type not implemented for DATETRUNC"); + } +} + +struct DateTruncBinaryOperator { + template + static inline TR Operation(TA specifier, TB date) { + return TruncateElement(GetDatePartSpecifier(specifier.GetString()), date); + } +}; + +template +static void DateTruncUnaryExecutor(DatePartSpecifier type, Vector &left, Vector &result, idx_t count) { + switch (type) { + case DatePartSpecifier::MILLENNIUM: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::CENTURY: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::DECADE: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::YEAR: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::QUARTER: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::MONTH: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::ISOYEAR: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::HOUR: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::MINUTE: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::MILLISECONDS: + DateTrunc::UnaryExecute(left, result, count); + break; + case DatePartSpecifier::MICROSECONDS: + DateTrunc::UnaryExecute(left, result, count); + break; + default: + throw NotImplementedException("Specifier type not implemented for DATETRUNC"); + } +} + +template +static void DateTruncFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + auto &part_arg = args.data[0]; + auto &date_arg = args.data[1]; + + if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // Common case of constant part. + if (ConstantVector::IsNull(part_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); + DateTruncUnaryExecutor(type, date_arg, result, args.size()); + } + } else { + BinaryExecutor::ExecuteStandard(part_arg, date_arg, result, + args.size()); + } +} + +template +static unique_ptr DateTruncStatistics(vector &child_stats) { + // we can only propagate date stats if the child has stats + auto &nstats = child_stats[1]; + if (!NumericStats::HasMinMax(nstats)) { + return nullptr; + } + // run the operator on both the min and the max, this gives us the [min, max] bound + auto min = NumericStats::GetMin(nstats); + auto max = NumericStats::GetMax(nstats); + if (min > max) { + return nullptr; + } + + // Infinite values are unmodified + auto min_part = DateTrunc::UnaryFunction(min); + auto max_part = DateTrunc::UnaryFunction(max); + + auto min_value = Value::CreateValue(min_part); + auto max_value = Value::CreateValue(max_part); + auto result = NumericStats::CreateEmpty(min_value.type()); + NumericStats::SetMin(result, min_value); + NumericStats::SetMax(result, max_value); + result.CopyValidity(child_stats[0]); + return result.ToUnique(); +} + +template +static unique_ptr PropagateDateTruncStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return DateTruncStatistics(input.child_stats); +} + +template +static function_statistics_t DateTruncStats(DatePartSpecifier type) { + switch (type) { + case DatePartSpecifier::MILLENNIUM: + return PropagateDateTruncStatistics; + case DatePartSpecifier::CENTURY: + return PropagateDateTruncStatistics; + case DatePartSpecifier::DECADE: + return PropagateDateTruncStatistics; + case DatePartSpecifier::YEAR: + return PropagateDateTruncStatistics; + case DatePartSpecifier::QUARTER: + return PropagateDateTruncStatistics; + case DatePartSpecifier::MONTH: + return PropagateDateTruncStatistics; + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + return PropagateDateTruncStatistics; + case DatePartSpecifier::ISOYEAR: + return PropagateDateTruncStatistics; + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + return PropagateDateTruncStatistics; + case DatePartSpecifier::HOUR: + return PropagateDateTruncStatistics; + case DatePartSpecifier::MINUTE: + return PropagateDateTruncStatistics; + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + return PropagateDateTruncStatistics; + case DatePartSpecifier::MILLISECONDS: + return PropagateDateTruncStatistics; + case DatePartSpecifier::MICROSECONDS: + return PropagateDateTruncStatistics; + default: + throw NotImplementedException("Specifier type not implemented for DATETRUNC statistics"); + } +} + +static unique_ptr DateTruncBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (!arguments[0]->IsFoldable()) { + return nullptr; + } + + // Rebind to return a date if we are truncating that far + Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + if (part_value.IsNull()) { + return nullptr; + } + const auto part_name = part_value.ToString(); + const auto part_code = GetDatePartSpecifier(part_name); + switch (part_code) { + case DatePartSpecifier::MILLENNIUM: + case DatePartSpecifier::CENTURY: + case DatePartSpecifier::DECADE: + case DatePartSpecifier::YEAR: + case DatePartSpecifier::QUARTER: + case DatePartSpecifier::MONTH: + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + case DatePartSpecifier::ISOYEAR: + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + switch (bound_function.arguments[1].id()) { + case LogicalType::TIMESTAMP: + bound_function.function = DateTruncFunction; + bound_function.statistics = DateTruncStats(part_code); + break; + case LogicalType::DATE: + bound_function.function = DateTruncFunction; + bound_function.statistics = DateTruncStats(part_code); + break; + default: + throw NotImplementedException("Temporal argument type for DATETRUNC"); + } + bound_function.return_type = LogicalType::DATE; + break; + default: + switch (bound_function.arguments[1].id()) { + case LogicalType::TIMESTAMP: + bound_function.statistics = DateTruncStats(part_code); + break; + case LogicalType::DATE: + bound_function.statistics = DateTruncStats(part_code); + break; + default: + throw NotImplementedException("Temporal argument type for DATETRUNC"); + } + break; + } + + return nullptr; +} + +ScalarFunctionSet DateTruncFun::GetFunctions() { + ScalarFunctionSet date_trunc("date_trunc"); + date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, + DateTruncFunction, DateTruncBind)); + date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::TIMESTAMP, + DateTruncFunction, DateTruncBind)); + date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::INTERVAL, + DateTruncFunction)); + return date_trunc; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/epoch.cpp b/src/duckdb/src/core_functions/scalar/date/epoch.cpp new file mode 100644 index 00000000..3de2d50a --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/epoch.cpp @@ -0,0 +1,31 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" + +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" + +namespace duckdb { + +struct EpochSecOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE sec) { + int64_t result; + if (!TryCast::Operation(sec * Interval::MICROS_PER_SEC, result)) { + throw ConversionException("Could not convert epoch seconds to TIMESTAMP WITH TIME ZONE"); + } + return timestamp_t(result); + } +}; + +static void EpochSecFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 1); + + UnaryExecutor::Execute(input.data[0], result, input.size()); +} + +ScalarFunction ToTimestampFun::GetFunction() { + // to_timestamp is an alias from Postgres that converts the time in seconds to a timestamp + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::TIMESTAMP_TZ, EpochSecFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/make_date.cpp b/src/duckdb/src/core_functions/scalar/date/make_date.cpp new file mode 100644 index 00000000..2eb5248b --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/make_date.cpp @@ -0,0 +1,123 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/senary_executor.hpp" + +#include + +namespace duckdb { + +struct MakeDateOperator { + template + static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd) { + return Date::FromDate(yyyy, mm, dd); + } +}; + +template +static void ExecuteMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 3); + auto &yyyy = input.data[0]; + auto &mm = input.data[1]; + auto &dd = input.data[2]; + + TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), + MakeDateOperator::Operation); +} + +template +static void ExecuteStructMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { + // this should be guaranteed by the binder + D_ASSERT(input.ColumnCount() == 1); + auto &vec = input.data[0]; + + auto &children = StructVector::GetEntries(vec); + D_ASSERT(children.size() == 3); + auto &yyyy = *children[0]; + auto &mm = *children[1]; + auto &dd = *children[2]; + + TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), Date::FromDate); +} + +struct MakeTimeOperator { + template + static RESULT_TYPE Operation(HH hh, MM mm, SS ss) { + int64_t secs = ss; + int64_t micros = std::round((ss - secs) * Interval::MICROS_PER_SEC); + if (!Time::IsValidTime(hh, mm, secs, micros)) { + throw ConversionException("Time out of range: %d:%d:%d.%d", hh, mm, secs, micros); + } + return Time::FromTime(hh, mm, secs, micros); + } +}; + +template +static void ExecuteMakeTime(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 3); + auto &yyyy = input.data[0]; + auto &mm = input.data[1]; + auto &dd = input.data[2]; + + TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), + MakeTimeOperator::Operation); +} + +struct MakeTimestampOperator { + template + static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd, HR hr, MN mn, SS ss) { + const auto d = MakeDateOperator::Operation(yyyy, mm, dd); + const auto t = MakeTimeOperator::Operation(hr, mn, ss); + return Timestamp::FromDatetime(d, t); + } + + template + static RESULT_TYPE Operation(T micros) { + return timestamp_t(micros); + } +}; + +template +static void ExecuteMakeTimestamp(DataChunk &input, ExpressionState &state, Vector &result) { + if (input.ColumnCount() == 1) { + auto func = MakeTimestampOperator::Operation; + UnaryExecutor::Execute(input.data[0], result, input.size(), func); + return; + } + + D_ASSERT(input.ColumnCount() == 6); + + auto func = MakeTimestampOperator::Operation; + SenaryExecutor::Execute(input, result, func); +} + +ScalarFunctionSet MakeDateFun::GetFunctions() { + ScalarFunctionSet make_date("make_date"); + make_date.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::DATE, ExecuteMakeDate)); + + child_list_t make_date_children { + {"year", LogicalType::BIGINT}, {"month", LogicalType::BIGINT}, {"day", LogicalType::BIGINT}}; + make_date.AddFunction( + ScalarFunction({LogicalType::STRUCT(make_date_children)}, LogicalType::DATE, ExecuteStructMakeDate)); + return make_date; +} + +ScalarFunction MakeTimeFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, LogicalType::TIME, + ExecuteMakeTime); +} + +ScalarFunctionSet MakeTimestampFun::GetFunctions() { + ScalarFunctionSet operator_set("make_timestamp"); + operator_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, + LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); + operator_set.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); + return operator_set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/strftime.cpp b/src/duckdb/src/core_functions/scalar/date/strftime.cpp new file mode 100644 index 00000000..a764c97e --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/strftime.cpp @@ -0,0 +1,261 @@ +#include "duckdb/function/scalar/strftime_format.hpp" + +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/core_functions/scalar/date_functions.hpp" + +#include +#include + +namespace duckdb { + +struct StrfTimeBindData : public FunctionData { + explicit StrfTimeBindData(StrfTimeFormat format_p, string format_string_p, bool is_null) + : format(std::move(format_p)), format_string(std::move(format_string_p)), is_null(is_null) { + } + + StrfTimeFormat format; + string format_string; + bool is_null; + + unique_ptr Copy() const override { + return make_uniq(format, format_string, is_null); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return format_string == other.format_string; + } +}; + +template +static unique_ptr StrfTimeBindFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto format_idx = REVERSED ? 0 : 1; + auto &format_arg = arguments[format_idx]; + if (format_arg->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!format_arg->IsFoldable()) { + throw InvalidInputException("strftime format must be a constant"); + } + Value options_str = ExpressionExecutor::EvaluateScalar(context, *format_arg); + auto format_string = options_str.GetValue(); + StrfTimeFormat format; + bool is_null = options_str.IsNull(); + if (!is_null) { + string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); + if (!error.empty()) { + throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); + } + } + return make_uniq(format, format_string, is_null); +} + +template +static void StrfTimeFunctionDate(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + if (info.is_null) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + info.format.ConvertDateVector(args.data[REVERSED ? 1 : 0], result, args.size()); +} + +template +static void StrfTimeFunctionTimestamp(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + if (info.is_null) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + info.format.ConvertTimestampVector(args.data[REVERSED ? 1 : 0], result, args.size()); +} + +ScalarFunctionSet StrfTimeFun::GetFunctions() { + ScalarFunctionSet strftime; + + strftime.AddFunction(ScalarFunction({LogicalType::DATE, LogicalType::VARCHAR}, LogicalType::VARCHAR, + StrfTimeFunctionDate, StrfTimeBindFunction)); + strftime.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::VARCHAR}, LogicalType::VARCHAR, + StrfTimeFunctionTimestamp, StrfTimeBindFunction)); + strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::VARCHAR, + StrfTimeFunctionDate, StrfTimeBindFunction)); + strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::VARCHAR, + StrfTimeFunctionTimestamp, StrfTimeBindFunction)); + return strftime; +} + +StrpTimeFormat::StrpTimeFormat() { +} + +StrpTimeFormat::StrpTimeFormat(const string &format_string) { + if (format_string.empty()) { + return; + } + StrTimeFormat::ParseFormatSpecifier(format_string, *this); +} + +struct StrpTimeBindData : public FunctionData { + StrpTimeBindData(const StrpTimeFormat &format, const string &format_string) + : formats(1, format), format_strings(1, format_string) { + } + + StrpTimeBindData(vector formats_p, vector format_strings_p) + : formats(std::move(formats_p)), format_strings(std::move(format_strings_p)) { + } + + vector formats; + vector format_strings; + + unique_ptr Copy() const override { + return make_uniq(formats, format_strings); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return format_strings == other.format_strings; + } +}; + +static unique_ptr StrpTimeBindFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw InvalidInputException("strptime format must be a constant"); + } + Value format_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + string format_string; + StrpTimeFormat format; + if (format_value.IsNull()) { + return make_uniq(format, format_string); + } else if (format_value.type().id() == LogicalTypeId::VARCHAR) { + format_string = format_value.ToString(); + format.format_specifier = format_string; + string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); + if (!error.empty()) { + throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); + } + if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { + bound_function.return_type = LogicalType::TIMESTAMP_TZ; + } + return make_uniq(format, format_string); + } else if (format_value.type() == LogicalType::LIST(LogicalType::VARCHAR)) { + const auto &children = ListValue::GetChildren(format_value); + if (children.empty()) { + throw InvalidInputException("strptime format list must not be empty"); + } + vector format_strings; + vector formats; + for (const auto &child : children) { + format_string = child.ToString(); + format.format_specifier = format_string; + string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); + if (!error.empty()) { + throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); + } + // If any format has UTC offsets, then we have to produce TSTZ + if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { + bound_function.return_type = LogicalType::TIMESTAMP_TZ; + } + format_strings.emplace_back(format_string); + formats.emplace_back(format); + } + return make_uniq(formats, format_strings); + } else { + throw InvalidInputException("strptime format must be a string"); + } +} + +struct StrpTimeFunction { + + static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + if (args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR && ConstantVector::IsNull(args.data[1])) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { + StrpTimeFormat::ParseResult result; + for (auto &format : info.formats) { + if (format.Parse(input, result)) { + return result.ToTimestamp(); + } + } + throw InvalidInputException(result.FormatError(input, info.formats[0].format_specifier)); + }); + } + + static void TryParse(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + if (args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR && ConstantVector::IsNull(args.data[1])) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + UnaryExecutor::ExecuteWithNulls( + args.data[0], result, args.size(), [&](string_t input, ValidityMask &mask, idx_t idx) { + timestamp_t result; + string error; + for (auto &format : info.formats) { + if (format.TryParseTimestamp(input, result, error)) { + return result; + } + } + + mask.SetInvalid(idx); + return timestamp_t(); + }); + } +}; + +ScalarFunctionSet StrpTimeFun::GetFunctions() { + ScalarFunctionSet strptime; + + const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); + auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, + StrpTimeFunction::Parse, StrpTimeBindFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + strptime.AddFunction(fun); + + fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::Parse, + StrpTimeBindFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + strptime.AddFunction(fun); + return strptime; +} + +ScalarFunctionSet TryStrpTimeFun::GetFunctions() { + ScalarFunctionSet try_strptime; + + const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); + auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, + StrpTimeFunction::TryParse, StrpTimeBindFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + try_strptime.AddFunction(fun); + + fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::TryParse, + StrpTimeBindFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + try_strptime.AddFunction(fun); + + return try_strptime; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/time_bucket.cpp b/src/duckdb/src/core_functions/scalar/date/time_bucket.cpp new file mode 100644 index 00000000..d317ea60 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/time_bucket.cpp @@ -0,0 +1,370 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/core_functions/scalar/date_functions.hpp" + +namespace duckdb { + +struct TimeBucket { + + // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... for TimescaleDB compatibility + // There are 10959 days between 1970-01-01 and 2000-01-03 + constexpr static const int64_t DEFAULT_ORIGIN_MICROS = 10959 * Interval::MICROS_PER_DAY; + // Use 2000-01-01 as origin when bucket_width is months, years, ... for TimescaleDB compatibility + // There are 360 months between 1970-01-01 and 2000-01-01 + constexpr static const int32_t DEFAULT_ORIGIN_MONTHS = 360; + + enum struct BucketWidthType { CONVERTIBLE_TO_MICROS, CONVERTIBLE_TO_MONTHS, UNCLASSIFIED }; + + static inline BucketWidthType ClassifyBucketWidth(const interval_t bucket_width) { + if (bucket_width.months == 0 && Interval::GetMicro(bucket_width) > 0) { + return BucketWidthType::CONVERTIBLE_TO_MICROS; + } else if (bucket_width.months > 0 && bucket_width.days == 0 && bucket_width.micros == 0) { + return BucketWidthType::CONVERTIBLE_TO_MONTHS; + } else { + return BucketWidthType::UNCLASSIFIED; + } + } + + static inline BucketWidthType ClassifyBucketWidthErrorThrow(const interval_t bucket_width) { + if (bucket_width.months == 0) { + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + if (bucket_width_micros <= 0) { + throw NotImplementedException("Period must be greater than 0"); + } + return BucketWidthType::CONVERTIBLE_TO_MICROS; + } else if (bucket_width.months != 0 && bucket_width.days == 0 && bucket_width.micros == 0) { + if (bucket_width.months < 0) { + throw NotImplementedException("Period must be greater than 0"); + } + return BucketWidthType::CONVERTIBLE_TO_MONTHS; + } else { + throw NotImplementedException("Month intervals cannot have day or time component"); + } + } + + template + static inline int32_t EpochMonths(T ts) { + date_t ts_date = Cast::template Operation(ts); + return (Date::ExtractYear(ts_date) - 1970) * 12 + Date::ExtractMonth(ts_date) - 1; + } + + static inline timestamp_t WidthConvertibleToMicrosCommon(int64_t bucket_width_micros, int64_t ts_micros, + int64_t origin_micros) { + origin_micros %= bucket_width_micros; + ts_micros = SubtractOperatorOverflowCheck::Operation(ts_micros, origin_micros); + + int64_t result_micros = (ts_micros / bucket_width_micros) * bucket_width_micros; + if (ts_micros < 0 && ts_micros % bucket_width_micros != 0) { + result_micros = + SubtractOperatorOverflowCheck::Operation(result_micros, bucket_width_micros); + } + result_micros += origin_micros; + + return Timestamp::FromEpochMicroSeconds(result_micros); + } + + static inline date_t WidthConvertibleToMonthsCommon(int32_t bucket_width_months, int32_t ts_months, + int32_t origin_months) { + origin_months %= bucket_width_months; + ts_months = SubtractOperatorOverflowCheck::Operation(ts_months, origin_months); + + int32_t result_months = (ts_months / bucket_width_months) * bucket_width_months; + if (ts_months < 0 && ts_months % bucket_width_months != 0) { + result_months = + SubtractOperatorOverflowCheck::Operation(result_months, bucket_width_months); + } + result_months += origin_months; + + int32_t year = + (result_months < 0 && result_months % 12 != 0) ? 1970 + result_months / 12 - 1 : 1970 + result_months / 12; + int32_t month = + (result_months < 0 && result_months % 12 != 0) ? result_months % 12 + 13 : result_months % 12 + 1; + + return Date::FromDate(year, month, 1); + } + + struct WidthConvertibleToMicrosBinaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); + return Cast::template Operation( + WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS)); + } + }; + + struct WidthConvertibleToMonthsBinaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int32_t ts_months = EpochMonths(ts); + return Cast::template Operation( + WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)); + } + }; + + struct BinaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts) { + BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); + switch (bucket_width_type) { + case BucketWidthType::CONVERTIBLE_TO_MICROS: + return WidthConvertibleToMicrosBinaryOperator::Operation(bucket_width, ts); + case BucketWidthType::CONVERTIBLE_TO_MONTHS: + return WidthConvertibleToMonthsBinaryOperator::Operation(bucket_width, ts); + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + }; + + struct OffsetWidthConvertibleToMicrosTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC offset) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + int64_t ts_micros = Timestamp::GetEpochMicroSeconds( + Interval::Add(Cast::template Operation(ts), Interval::Invert(offset))); + return Cast::template Operation(Interval::Add( + WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS), offset)); + } + }; + + struct OffsetWidthConvertibleToMonthsTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC offset) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int32_t ts_months = EpochMonths(Interval::Add(ts, Interval::Invert(offset))); + return Interval::Add(Cast::template Operation(WidthConvertibleToMonthsCommon( + bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)), + offset); + } + }; + + struct OffsetTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC offset) { + BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); + switch (bucket_width_type) { + case BucketWidthType::CONVERTIBLE_TO_MICROS: + return OffsetWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, + offset); + case BucketWidthType::CONVERTIBLE_TO_MONTHS: + return OffsetWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, + offset); + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + }; + + struct OriginWidthConvertibleToMicrosTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC origin) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); + int64_t origin_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(origin)); + return Cast::template Operation( + WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, origin_micros)); + } + }; + + struct OriginWidthConvertibleToMonthsTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC origin) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int32_t ts_months = EpochMonths(ts); + int32_t origin_months = EpochMonths(origin); + return Cast::template Operation( + WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, origin_months)); + } + }; + + struct OriginTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC origin, ValidityMask &mask, idx_t idx) { + if (!Value::IsFinite(origin)) { + mask.SetInvalid(idx); + return TR(); + } + BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); + switch (bucket_width_type) { + case BucketWidthType::CONVERTIBLE_TO_MICROS: + return OriginWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, + origin); + case BucketWidthType::CONVERTIBLE_TO_MONTHS: + return OriginWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, + origin); + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + }; +}; + +template +static void TimeBucketFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + + auto &bucket_width_arg = args.data[0]; + auto &ts_arg = args.data[1]; + + if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(bucket_width_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); + TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); + switch (bucket_width_type) { + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: + BinaryExecutor::Execute( + bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::WidthConvertibleToMicrosBinaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: + BinaryExecutor::Execute( + bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::WidthConvertibleToMonthsBinaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::UNCLASSIFIED: + BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::BinaryOperator::Operation); + break; + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + } else { + BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::BinaryOperator::Operation); + } +} + +template +static void TimeBucketOffsetFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + + auto &bucket_width_arg = args.data[0]; + auto &ts_arg = args.data[1]; + auto &offset_arg = args.data[2]; + + if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(bucket_width_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); + TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); + switch (bucket_width_type) { + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetWidthConvertibleToMicrosTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetWidthConvertibleToMonthsTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::UNCLASSIFIED: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetTernaryOperator::Operation); + break; + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + } else { + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetTernaryOperator::Operation); + } +} + +template +static void TimeBucketOriginFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + + auto &bucket_width_arg = args.data[0]; + auto &ts_arg = args.data[1]; + auto &origin_arg = args.data[2]; + + if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR && + origin_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(bucket_width_arg) || ConstantVector::IsNull(origin_arg) || + !Value::IsFinite(*ConstantVector::GetData(origin_arg))) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); + TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); + switch (bucket_width_type) { + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginWidthConvertibleToMicrosTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginWidthConvertibleToMonthsTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::UNCLASSIFIED: + TernaryExecutor::ExecuteWithNulls( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginTernaryOperator::Operation); + break; + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + } else { + TernaryExecutor::ExecuteWithNulls( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginTernaryOperator::Operation); + } +} + +ScalarFunctionSet TimeBucketFun::GetFunctions() { + ScalarFunctionSet time_bucket; + time_bucket.AddFunction( + ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE}, LogicalType::DATE, TimeBucketFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, + TimeBucketFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::INTERVAL}, + LogicalType::DATE, TimeBucketOffsetFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + LogicalType::TIMESTAMP, TimeBucketOffsetFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::DATE}, + LogicalType::DATE, TimeBucketOriginFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, + LogicalType::TIMESTAMP, TimeBucketOriginFunction)); + return time_bucket; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/date/to_interval.cpp b/src/duckdb/src/core_functions/scalar/date/to_interval.cpp new file mode 100644 index 00000000..5227272b --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/date/to_interval.cpp @@ -0,0 +1,150 @@ +#include "duckdb/core_functions/scalar/date_functions.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/operator/multiply.hpp" + +namespace duckdb { + +struct ToYearsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.days = 0; + result.micros = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_YEAR, + result.months)) { + throw OutOfRangeException("Interval value %d years out of range", input); + } + return result; + } +}; + +struct ToMonthsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = input; + result.days = 0; + result.micros = 0; + return result; + } +}; + +struct ToDaysOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = input; + result.micros = 0; + return result; + } +}; + +struct ToHoursOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_HOUR, + result.micros)) { + throw OutOfRangeException("Interval value %d hours out of range", input); + } + return result; + } +}; + +struct ToMinutesOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MINUTE, + result.micros)) { + throw OutOfRangeException("Interval value %d minutes out of range", input); + } + return result; + } +}; + +struct ToSecondsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_SEC, + result.micros)) { + throw OutOfRangeException("Interval value %d seconds out of range", input); + } + return result; + } +}; + +struct ToMilliSecondsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MSEC, + result.micros)) { + throw OutOfRangeException("Interval value %d milliseconds out of range", input); + } + return result; + } +}; + +struct ToMicroSecondsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + result.micros = input; + return result; + } +}; + +ScalarFunction ToYearsFun::GetFunction() { + return ScalarFunction({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +ScalarFunction ToMonthsFun::GetFunction() { + return ScalarFunction({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +ScalarFunction ToDaysFun::GetFunction() { + return ScalarFunction({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +ScalarFunction ToHoursFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +ScalarFunction ToMinutesFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +ScalarFunction ToSecondsFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +ScalarFunction ToMillisecondsFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +ScalarFunction ToMicrosecondsFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/debug/vector_type.cpp b/src/duckdb/src/core_functions/scalar/debug/vector_type.cpp new file mode 100644 index 00000000..0f2dc5e2 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/debug/vector_type.cpp @@ -0,0 +1,23 @@ +#include "duckdb/core_functions/scalar/debug_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +static void VectorTypeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto data = ConstantVector::GetData(result); + data[0] = StringVector::AddString(result, EnumUtil::ToString(input.data[0].GetVectorType())); +} + +ScalarFunction VectorTypeFun::GetFunction() { + return ScalarFunction("vector_type", // name of the function + {LogicalType::ANY}, // argument list + LogicalType::VARCHAR, // return type + VectorTypeFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/enum/enum_functions.cpp b/src/duckdb/src/core_functions/scalar/enum/enum_functions.cpp new file mode 100644 index 00000000..5722cf58 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/enum/enum_functions.cpp @@ -0,0 +1,169 @@ +#include "duckdb/core_functions/scalar/enum_functions.hpp" + +namespace duckdb { + +static void EnumFirstFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto types = input.GetTypes(); + D_ASSERT(types.size() == 1); + auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); + auto val = Value(enum_vector.GetValue(0)); + result.Reference(val); +} + +static void EnumLastFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto types = input.GetTypes(); + D_ASSERT(types.size() == 1); + auto enum_size = EnumType::GetSize(types[0]); + auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); + auto val = Value(enum_vector.GetValue(enum_size - 1)); + result.Reference(val); +} + +static void EnumRangeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto types = input.GetTypes(); + D_ASSERT(types.size() == 1); + auto enum_size = EnumType::GetSize(types[0]); + auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); + vector enum_values; + for (idx_t i = 0; i < enum_size; i++) { + enum_values.emplace_back(enum_vector.GetValue(i)); + } + auto val = Value::LIST(enum_values); + result.Reference(val); +} + +static void EnumRangeBoundaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto types = input.GetTypes(); + D_ASSERT(types.size() == 2); + idx_t start, end; + auto first_param = input.GetValue(0, 0); + auto second_param = input.GetValue(1, 0); + + auto &enum_vector = + first_param.IsNull() ? EnumType::GetValuesInsertOrder(types[1]) : EnumType::GetValuesInsertOrder(types[0]); + + if (first_param.IsNull()) { + start = 0; + } else { + start = first_param.GetValue(); + } + if (second_param.IsNull()) { + end = EnumType::GetSize(types[0]); + } else { + end = second_param.GetValue() + 1; + } + vector enum_values; + for (idx_t i = start; i < end; i++) { + enum_values.emplace_back(enum_vector.GetValue(i)); + } + Value val; + if (enum_values.empty()) { + val = Value::EMPTYLIST(LogicalType::VARCHAR); + } else { + val = Value::LIST(enum_values); + } + result.Reference(val); +} + +static void EnumCodeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.GetTypes().size() == 1); + result.Reinterpret(input.data[0]); +} + +static void CheckEnumParameter(const Expression &expr) { + if (expr.HasParameter()) { + throw ParameterNotResolvedException(); + } +} + +unique_ptr BindEnumFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + CheckEnumParameter(*arguments[0]); + if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) { + throw BinderException("This function needs an ENUM as an argument"); + } + return nullptr; +} + +unique_ptr BindEnumCodeFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + CheckEnumParameter(*arguments[0]); + if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) { + throw BinderException("This function needs an ENUM as an argument"); + } + + auto phy_type = EnumType::GetPhysicalType(arguments[0]->return_type); + switch (phy_type) { + case PhysicalType::UINT8: + bound_function.return_type = LogicalType(LogicalTypeId::UTINYINT); + break; + case PhysicalType::UINT16: + bound_function.return_type = LogicalType(LogicalTypeId::USMALLINT); + break; + case PhysicalType::UINT32: + bound_function.return_type = LogicalType(LogicalTypeId::UINTEGER); + break; + case PhysicalType::UINT64: + bound_function.return_type = LogicalType(LogicalTypeId::UBIGINT); + break; + default: + throw InternalException("Unsupported Enum Internal Type"); + } + + return nullptr; +} + +unique_ptr BindEnumRangeBoundaryFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + CheckEnumParameter(*arguments[0]); + CheckEnumParameter(*arguments[1]); + if (arguments[0]->return_type.id() != LogicalTypeId::ENUM && arguments[0]->return_type != LogicalType::SQLNULL) { + throw BinderException("This function needs an ENUM as an argument"); + } + if (arguments[1]->return_type.id() != LogicalTypeId::ENUM && arguments[1]->return_type != LogicalType::SQLNULL) { + throw BinderException("This function needs an ENUM as an argument"); + } + if (arguments[0]->return_type == LogicalType::SQLNULL && arguments[1]->return_type == LogicalType::SQLNULL) { + throw BinderException("This function needs an ENUM as an argument"); + } + if (arguments[0]->return_type.id() == LogicalTypeId::ENUM && + arguments[1]->return_type.id() == LogicalTypeId::ENUM && + arguments[0]->return_type != arguments[1]->return_type) { + throw BinderException("The parameters need to link to ONLY one enum OR be NULL "); + } + return nullptr; +} + +ScalarFunction EnumFirstFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumFirstFunction, BindEnumFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction EnumLastFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumLastFunction, BindEnumFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction EnumCodeFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::ANY, EnumCodeFunction, BindEnumCodeFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction EnumRangeFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeFunction, + BindEnumFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction EnumRangeBoundaryFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), + EnumRangeBoundaryFunction, BindEnumRangeBoundaryFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/alias.cpp b/src/duckdb/src/core_functions/scalar/generic/alias.cpp new file mode 100644 index 00000000..e7065ba5 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/alias.cpp @@ -0,0 +1,18 @@ +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +static void AliasFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + Value v(state.expr.alias.empty() ? func_expr.children[0]->GetName() : state.expr.alias); + result.Reference(v); +} + +ScalarFunction AliasFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/current_setting.cpp b/src/duckdb/src/core_functions/scalar/generic/current_setting.cpp new file mode 100644 index 00000000..5eb5c91a --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/current_setting.cpp @@ -0,0 +1,69 @@ +#include "duckdb/core_functions/scalar/generic_functions.hpp" + +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/catalog/catalog.hpp" +namespace duckdb { + +struct CurrentSettingBindData : public FunctionData { + explicit CurrentSettingBindData(Value value_p) : value(std::move(value_p)) { + } + + Value value; + +public: + unique_ptr Copy() const override { + return make_uniq(value); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return Value::NotDistinctFrom(value, other.value); + } +}; + +static void CurrentSettingFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + result.Reference(info.value); +} + +unique_ptr CurrentSettingBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + auto &key_child = arguments[0]; + if (key_child->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + if (key_child->return_type.id() != LogicalTypeId::VARCHAR || + key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { + throw ParserException("Key name for current_setting needs to be a constant string"); + } + Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); + D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); + auto &key_str = StringValue::Get(key_val); + if (key_val.IsNull() || key_str.empty()) { + throw ParserException("Key name for current_setting needs to be neither NULL nor empty"); + } + + auto key = StringUtil::Lower(key_str); + Value val; + if (!context.TryGetCurrentSetting(key, val)) { + Catalog::AutoloadExtensionByConfigName(context, key); + // If autoloader didn't throw, the config is now available + context.TryGetCurrentSetting(key, val); + } + + bound_function.return_type = val.type(); + return make_uniq(val); +} + +ScalarFunction CurrentSettingFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::ANY, CurrentSettingFunction, CurrentSettingBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/error.cpp b/src/duckdb/src/core_functions/scalar/generic/error.cpp new file mode 100644 index 00000000..5d38236b --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/error.cpp @@ -0,0 +1,21 @@ +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include + +namespace duckdb { + +struct ErrorOperator { + template + static inline TR Operation(const TA &input) { + throw Exception(input.GetString()); + } +}; + +ScalarFunction ErrorFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction); + // Set the function with side effects to avoid the optimization. + fun.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/hash.cpp b/src/duckdb/src/core_functions/scalar/generic/hash.cpp new file mode 100644 index 00000000..b99e9704 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/hash.cpp @@ -0,0 +1,19 @@ +#include "duckdb/core_functions/scalar/generic_functions.hpp" + +namespace duckdb { + +static void HashFunction(DataChunk &args, ExpressionState &state, Vector &result) { + args.Hash(result); + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +ScalarFunction HashFun::GetFunction() { + auto hash_fun = ScalarFunction({LogicalType::ANY}, LogicalType::HASH, HashFunction); + hash_fun.varargs = LogicalType::ANY; + hash_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return hash_fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/least.cpp b/src/duckdb/src/core_functions/scalar/generic/least.cpp new file mode 100644 index 00000000..42e2d68d --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/least.cpp @@ -0,0 +1,137 @@ +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/core_functions/scalar/generic_functions.hpp" + +namespace duckdb { + +template +struct LeastOperator { + template + static T Operation(T left, T right) { + return OP::Operation(left, right) ? left : right; + } +}; + +template +static void LeastGreatestFunction(DataChunk &args, ExpressionState &state, Vector &result) { + if (args.ColumnCount() == 1) { + // single input: nop + result.Reference(args.data[0]); + return; + } + auto result_type = VectorType::CONSTANT_VECTOR; + for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { + // non-constant input: result is not a constant vector + result_type = VectorType::FLAT_VECTOR; + } + if (IS_STRING) { + // for string vectors we add a reference to the heap of the children + StringVector::AddHeapReference(result, args.data[col_idx]); + } + } + + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + // copy over the first column + bool result_has_value[STANDARD_VECTOR_SIZE]; + { + UnifiedVectorFormat vdata; + args.data[0].ToUnifiedFormat(args.size(), vdata); + auto input_data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < args.size(); i++) { + auto vindex = vdata.sel->get_index(i); + if (vdata.validity.RowIsValid(vindex)) { + result_data[i] = input_data[vindex]; + result_has_value[i] = true; + } else { + result_has_value[i] = false; + } + } + } + // now handle the remainder of the columns + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + if (args.data[col_idx].GetVectorType() == VectorType::CONSTANT_VECTOR && + ConstantVector::IsNull(args.data[col_idx])) { + // ignore null vector + continue; + } + + UnifiedVectorFormat vdata; + args.data[col_idx].ToUnifiedFormat(args.size(), vdata); + + auto input_data = UnifiedVectorFormat::GetData(vdata); + if (!vdata.validity.AllValid()) { + // potential new null entries: have to check the null mask + for (idx_t i = 0; i < args.size(); i++) { + auto vindex = vdata.sel->get_index(i); + if (vdata.validity.RowIsValid(vindex)) { + // not a null entry: perform the operation and add to new set + auto ivalue = input_data[vindex]; + if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { + result_has_value[i] = true; + result_data[i] = ivalue; + } + } + } + } else { + // no new null entries: only need to perform the operation + for (idx_t i = 0; i < args.size(); i++) { + auto vindex = vdata.sel->get_index(i); + + auto ivalue = input_data[vindex]; + if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { + result_has_value[i] = true; + result_data[i] = ivalue; + } + } + } + } + for (idx_t i = 0; i < args.size(); i++) { + if (!result_has_value[i]) { + result_mask.SetInvalid(i); + } + } + result.SetVectorType(result_type); +} + +template +ScalarFunction GetLeastGreatestFunction(const LogicalType &type) { + return ScalarFunction({type}, type, LeastGreatestFunction, nullptr, nullptr, nullptr, nullptr, type, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING); +} + +template +static ScalarFunctionSet GetLeastGreatestFunctions() { + ScalarFunctionSet fun_set; + fun_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::BIGINT, LeastGreatestFunction, + nullptr, nullptr, nullptr, nullptr, LogicalType::BIGINT, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + fun_set.AddFunction(ScalarFunction( + {LogicalType::HUGEINT}, LogicalType::HUGEINT, LeastGreatestFunction, nullptr, nullptr, nullptr, + nullptr, LogicalType::HUGEINT, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + fun_set.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, LeastGreatestFunction, + nullptr, nullptr, nullptr, nullptr, LogicalType::DOUBLE, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + fun_set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, + LeastGreatestFunction, nullptr, nullptr, nullptr, nullptr, + LogicalType::VARCHAR, FunctionSideEffects::NO_SIDE_EFFECTS, + FunctionNullHandling::SPECIAL_HANDLING)); + + fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIMESTAMP)); + fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIME)); + fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::DATE)); + + fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIMESTAMP_TZ)); + fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIME_TZ)); + return fun_set; +} + +ScalarFunctionSet LeastFun::GetFunctions() { + return GetLeastGreatestFunctions(); +} + +ScalarFunctionSet GreatestFun::GetFunctions() { + return GetLeastGreatestFunctions(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/stats.cpp b/src/duckdb/src/core_functions/scalar/generic/stats.cpp new file mode 100644 index 00000000..d19dcfc9 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/stats.cpp @@ -0,0 +1,54 @@ +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +struct StatsBindData : public FunctionData { + explicit StatsBindData(string stats_p = string()) : stats(std::move(stats_p)) { + } + + string stats; + +public: + unique_ptr Copy() const override { + return make_uniq(stats); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return stats == other.stats; + } +}; + +static void StatsFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + if (info.stats.empty()) { + info.stats = "No statistics"; + } + Value v(info.stats); + result.Reference(v); +} + +unique_ptr StatsBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return make_uniq(); +} + +static unique_ptr StatsPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &bind_data = input.bind_data; + auto &info = bind_data->Cast(); + info.stats = child_stats[0].ToString(); + return nullptr; +} + +ScalarFunction StatsFun::GetFunction() { + ScalarFunction stats({LogicalType::ANY}, LogicalType::VARCHAR, StatsFunction, StatsBind, nullptr, + StatsPropagateStats); + stats.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + stats.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return stats; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/system_functions.cpp b/src/duckdb/src/core_functions/scalar/generic/system_functions.cpp new file mode 100644 index 00000000..1bfb8555 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/system_functions.cpp @@ -0,0 +1,108 @@ +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/main/database_manager.hpp" + +namespace duckdb { + +// current_query +static void CurrentQueryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Value val(state.GetContext().GetCurrentQuery()); + result.Reference(val); +} + +// current_schema +static void CurrentSchemaFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Value val(ClientData::Get(state.GetContext()).catalog_search_path->GetDefault().schema); + result.Reference(val); +} + +// current_database +static void CurrentDatabaseFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Value val(DatabaseManager::GetDefaultDatabase(state.GetContext())); + result.Reference(val); +} + +// current_schemas +static void CurrentSchemasFunction(DataChunk &input, ExpressionState &state, Vector &result) { + if (!input.AllConstant()) { + throw NotImplementedException("current_schemas requires a constant input"); + } + if (ConstantVector::IsNull(input.data[0])) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + auto implicit_schemas = *ConstantVector::GetData(input.data[0]); + vector schema_list; + auto &catalog_search_path = ClientData::Get(state.GetContext()).catalog_search_path; + auto &search_path = implicit_schemas ? catalog_search_path->Get() : catalog_search_path->GetSetPaths(); + std::transform(search_path.begin(), search_path.end(), std::back_inserter(schema_list), + [](const CatalogSearchEntry &s) -> Value { return Value(s.schema); }); + + auto val = Value::LIST(LogicalType::VARCHAR, schema_list); + result.Reference(val); +} + +// in_search_path +static void InSearchPathFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &context = state.GetContext(); + auto &search_path = ClientData::Get(context).catalog_search_path; + BinaryExecutor::Execute( + input.data[0], input.data[1], result, input.size(), [&](string_t db_name, string_t schema_name) { + return search_path->SchemaInSearchPath(context, db_name.GetString(), schema_name.GetString()); + }); +} + +// txid_current +static void TransactionIdCurrent(DataChunk &input, ExpressionState &state, Vector &result) { + auto &context = state.GetContext(); + auto &catalog = Catalog::GetCatalog(context, DatabaseManager::GetDefaultDatabase(context)); + auto &transaction = DuckTransaction::Get(context, catalog); + auto val = Value::BIGINT(transaction.start_time); + result.Reference(val); +} + +// version +static void VersionFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto val = Value(DuckDB::LibraryVersion()); + result.Reference(val); +} + +ScalarFunction CurrentQueryFun::GetFunction() { + ScalarFunction current_query({}, LogicalType::VARCHAR, CurrentQueryFunction); + current_query.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return current_query; +} + +ScalarFunction CurrentSchemaFun::GetFunction() { + return ScalarFunction({}, LogicalType::VARCHAR, CurrentSchemaFunction); +} + +ScalarFunction CurrentDatabaseFun::GetFunction() { + return ScalarFunction({}, LogicalType::VARCHAR, CurrentDatabaseFunction); +} + +ScalarFunction CurrentSchemasFun::GetFunction() { + auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); + return ScalarFunction({LogicalType::BOOLEAN}, varchar_list_type, CurrentSchemasFunction); +} + +ScalarFunction InSearchPathFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, InSearchPathFunction); +} + +ScalarFunction CurrentTransactionIdFun::GetFunction() { + return ScalarFunction({}, LogicalType::BIGINT, TransactionIdCurrent); +} + +ScalarFunction VersionFun::GetFunction() { + return ScalarFunction({}, LogicalType::VARCHAR, VersionFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/typeof.cpp b/src/duckdb/src/core_functions/scalar/generic/typeof.cpp new file mode 100644 index 00000000..a1b01f8c --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/typeof.cpp @@ -0,0 +1,16 @@ +#include "duckdb/core_functions/scalar/generic_functions.hpp" + +namespace duckdb { + +static void TypeOfFunction(DataChunk &args, ExpressionState &state, Vector &result) { + Value v(args.data[0].GetType().ToString()); + result.Reference(v); +} + +ScalarFunction TypeOfFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/array_slice.cpp b/src/duckdb/src/core_functions/scalar/list/array_slice.cpp new file mode 100644 index 00000000..a9b2aeec --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/array_slice.cpp @@ -0,0 +1,434 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/swap.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +struct ListSliceBindData : public FunctionData { + ListSliceBindData(const LogicalType &return_type_p, bool begin_is_empty_p, bool end_is_empty_p) + : return_type(return_type_p), begin_is_empty(begin_is_empty_p), end_is_empty(end_is_empty_p) { + } + ~ListSliceBindData() override; + + LogicalType return_type; + + bool begin_is_empty; + bool end_is_empty; + +public: + bool Equals(const FunctionData &other_p) const override; + unique_ptr Copy() const override; +}; + +ListSliceBindData::~ListSliceBindData() { +} + +bool ListSliceBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return return_type == other.return_type && begin_is_empty == other.begin_is_empty && + end_is_empty == other.end_is_empty; +} + +unique_ptr ListSliceBindData::Copy() const { + return make_uniq(return_type, begin_is_empty, end_is_empty); +} + +template +static int CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool svalid) { + if (step < 0) { + step = abs(step); + } + if (step == 0 && svalid) { + throw InvalidInputException("Slice step cannot be zero"); + } + if (step == 1) { + return end - begin; + } else if (static_cast(step) >= (end - begin)) { + return 1; + } + if ((end - begin) % step != 0) { + return (end - begin) / step + 1; + } + return (end - begin) / step; +} + +template +INDEX_TYPE ValueLength(const INPUT_TYPE &value) { + return 0; +} + +template <> +int64_t ValueLength(const list_entry_t &value) { + return value.length; +} + +template <> +int64_t ValueLength(const string_t &value) { + return LengthFun::Length(value); +} + +template +static void ClampIndex(INDEX_TYPE &index, const INPUT_TYPE &value, const INDEX_TYPE length, bool is_min) { + if (index < 0) { + index = (!is_min) ? index + 1 : index; + index = length + index; + return; + } else if (index > length) { + index = length; + } + return; +} + +template +static bool ClampSlice(const INPUT_TYPE &value, INDEX_TYPE &begin, INDEX_TYPE &end) { + // Clamp offsets + begin = (begin != 0 && begin != (INDEX_TYPE)NumericLimits::Minimum()) ? begin - 1 : begin; + + bool is_min = false; + if (begin == (INDEX_TYPE)NumericLimits::Minimum()) { + begin++; + is_min = true; + } + + const auto length = ValueLength(value); + if (begin < 0 && -begin > length && end < 0 && -end > length) { + begin = 0; + end = 0; + return true; + } + if (begin < 0 && -begin > length) { + begin = 0; + } + ClampIndex(begin, value, length, is_min); + ClampIndex(end, value, length, false); + end = MaxValue(begin, end); + + return true; +} + +template +INPUT_TYPE SliceValue(Vector &result, INPUT_TYPE input, INDEX_TYPE begin, INDEX_TYPE end) { + return input; +} + +template <> +list_entry_t SliceValue(Vector &result, list_entry_t input, int64_t begin, int64_t end) { + input.offset += begin; + input.length = end - begin; + return input; +} + +template <> +string_t SliceValue(Vector &result, string_t input, int64_t begin, int64_t end) { + // one-based - zero has strange semantics + return SubstringFun::SubstringUnicode(result, input, begin + 1, end - begin); +} + +template +INPUT_TYPE SliceValueWithSteps(Vector &result, SelectionVector &sel, INPUT_TYPE input, INDEX_TYPE begin, INDEX_TYPE end, + INDEX_TYPE step, idx_t &sel_idx) { + return input; +} + +template <> +list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entry_t input, int64_t begin, int64_t end, + int64_t step, idx_t &sel_idx) { + if (end - begin == 0) { + input.length = 0; + input.offset = sel_idx; + return input; + } + input.length = CalculateSliceLength(begin, end, step, true); + idx_t child_idx = input.offset + begin; + if (step < 0) { + child_idx = input.offset + end - 1; + } + input.offset = sel_idx; + for (idx_t i = 0; i < input.length; i++) { + sel.set_index(sel_idx, child_idx); + child_idx += step; + sel_idx++; + } + return input; +} + +template +static void ExecuteConstantSlice(Vector &result, Vector &str_vector, Vector &begin_vector, Vector &end_vector, + optional_ptr step_vector, const idx_t count, SelectionVector &sel, + idx_t &sel_idx, optional_ptr result_child_vector, bool begin_is_empty, + bool end_is_empty) { + auto result_data = ConstantVector::GetData(result); + auto str_data = ConstantVector::GetData(str_vector); + auto begin_data = ConstantVector::GetData(begin_vector); + auto end_data = ConstantVector::GetData(end_vector); + auto step_data = step_vector ? ConstantVector::GetData(*step_vector) : nullptr; + + auto str = str_data[0]; + auto begin = begin_is_empty ? 0 : begin_data[0]; + auto end = end_is_empty ? ValueLength(str) : end_data[0]; + auto step = step_data ? step_data[0] : 1; + + if (step < 0) { + swap(begin, end); + begin = end_is_empty ? 0 : begin; + end = begin_is_empty ? ValueLength(str) : end; + } + + auto str_valid = !ConstantVector::IsNull(str_vector); + auto begin_valid = !ConstantVector::IsNull(begin_vector); + auto end_valid = !ConstantVector::IsNull(end_vector); + auto step_valid = step_vector && !ConstantVector::IsNull(*step_vector); + + // Clamp offsets + bool clamp_result = false; + if (str_valid && begin_valid && end_valid && (step_valid || step == 1)) { + clamp_result = ClampSlice(str, begin, end); + } + + auto sel_length = 0; + bool sel_valid = false; + if (step_vector && step_valid && str_valid && begin_valid && end_valid && step != 1 && end - begin > 0) { + sel_length = CalculateSliceLength(begin, end, step, step_valid); + sel.Initialize(sel_length); + sel_valid = true; + } + + // Try to slice + if (!str_valid || !begin_valid || !end_valid || (step_vector && !step_valid) || !clamp_result) { + ConstantVector::SetNull(result, true); + } else if (step == 1) { + result_data[0] = SliceValue(result, str, begin, end); + } else { + result_data[0] = SliceValueWithSteps(result, sel, str, begin, end, step, sel_idx); + } + + if (sel_valid) { + result_child_vector->Slice(sel, sel_length); + ListVector::SetListSize(result, sel_length); + } +} + +template +static void ExecuteFlatSlice(Vector &result, Vector &list_vector, Vector &begin_vector, Vector &end_vector, + optional_ptr step_vector, const idx_t count, SelectionVector &sel, idx_t &sel_idx, + optional_ptr result_child_vector, bool begin_is_empty, bool end_is_empty) { + UnifiedVectorFormat list_data, begin_data, end_data, step_data; + idx_t sel_length = 0; + + list_vector.ToUnifiedFormat(count, list_data); + begin_vector.ToUnifiedFormat(count, begin_data); + end_vector.ToUnifiedFormat(count, end_data); + if (step_vector) { + step_vector->ToUnifiedFormat(count, step_data); + sel.Initialize(ListVector::GetListSize(list_vector)); + } + + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + + for (idx_t i = 0; i < count; ++i) { + auto list_idx = list_data.sel->get_index(i); + auto begin_idx = begin_data.sel->get_index(i); + auto end_idx = end_data.sel->get_index(i); + auto step_idx = step_vector ? step_data.sel->get_index(i) : 0; + + auto list_valid = list_data.validity.RowIsValid(list_idx); + auto begin_valid = begin_data.validity.RowIsValid(begin_idx); + auto end_valid = end_data.validity.RowIsValid(end_idx); + auto step_valid = step_vector && step_data.validity.RowIsValid(step_idx); + + if (!list_valid || !begin_valid || !end_valid || (step_vector && !step_valid)) { + result_mask.SetInvalid(i); + continue; + } + + auto sliced = reinterpret_cast(list_data.data)[list_idx]; + auto begin = begin_is_empty ? 0 : reinterpret_cast(begin_data.data)[begin_idx]; + auto end = end_is_empty ? ValueLength(sliced) + : reinterpret_cast(end_data.data)[end_idx]; + auto step = step_vector ? reinterpret_cast(step_data.data)[step_idx] : 1; + + if (step < 0) { + swap(begin, end); + begin = end_is_empty ? 0 : begin; + end = begin_is_empty ? ValueLength(sliced) : end; + } + + bool clamp_result = false; + if (step_valid || step == 1) { + clamp_result = ClampSlice(sliced, begin, end); + } + + auto length = 0; + if (end - begin > 0) { + length = CalculateSliceLength(begin, end, step, step_valid); + } + sel_length += length; + + if (!clamp_result) { + result_mask.SetInvalid(i); + } else if (!step_vector) { + result_data[i] = SliceValue(result, sliced, begin, end); + } else { + result_data[i] = + SliceValueWithSteps(result, sel, sliced, begin, end, step, sel_idx); + } + } + if (step_vector) { + SelectionVector new_sel(sel_length); + for (idx_t i = 0; i < sel_length; ++i) { + new_sel.set_index(i, sel.get_index(i)); + } + result_child_vector->Slice(new_sel, sel_length); + ListVector::SetListSize(result, sel_length); + } +} + +template +static void ExecuteSlice(Vector &result, Vector &list_or_str_vector, Vector &begin_vector, Vector &end_vector, + optional_ptr step_vector, const idx_t count, bool begin_is_empty, bool end_is_empty) { + optional_ptr result_child_vector; + if (step_vector) { + result_child_vector = &ListVector::GetEntry(result); + } + + SelectionVector sel; + idx_t sel_idx = 0; + + if (result.GetVectorType() == VectorType::CONSTANT_VECTOR) { + ExecuteConstantSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, + count, sel, sel_idx, result_child_vector, begin_is_empty, + end_is_empty); + } else { + ExecuteFlatSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, + count, sel, sel_idx, result_child_vector, begin_is_empty, + end_is_empty); + } + result.Verify(count); +} + +static void ArraySliceFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); + D_ASSERT(args.data.size() == 3 || args.data.size() == 4); + auto count = args.size(); + + Vector &list_or_str_vector = args.data[0]; + if (list_or_str_vector.GetType().id() == LogicalTypeId::SQLNULL) { + auto &result_validity = FlatVector::Validity(result); + result_validity.SetInvalid(0); + return; + } + + Vector &begin_vector = args.data[1]; + Vector &end_vector = args.data[2]; + + optional_ptr step_vector; + if (args.ColumnCount() == 4) { + step_vector = &args.data[3]; + } + + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto begin_is_empty = info.begin_is_empty; + auto end_is_empty = info.end_is_empty; + + result.SetVectorType(args.AllConstant() ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); + switch (result.GetType().id()) { + case LogicalTypeId::LIST: { + // Share the value dictionary as we are just going to slice it + if (list_or_str_vector.GetVectorType() != VectorType::FLAT_VECTOR && + list_or_str_vector.GetVectorType() != VectorType::CONSTANT_VECTOR) { + list_or_str_vector.Flatten(count); + } + ListVector::ReferenceEntry(result, list_or_str_vector); + ExecuteSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, count, + begin_is_empty, end_is_empty); + break; + } + case LogicalTypeId::VARCHAR: { + ExecuteSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, count, + begin_is_empty, end_is_empty); + break; + } + default: + throw NotImplementedException("Specifier type not implemented"); + } +} + +static bool CheckIfParamIsEmpty(duckdb::unique_ptr ¶m) { + bool is_empty = false; + if (param->return_type.id() == LogicalTypeId::LIST) { + auto empty_list = make_uniq(Value::LIST(LogicalType::INTEGER, vector())); + is_empty = param->Equals(*empty_list); + if (!is_empty) { + // if the param is not empty, the user has entered a list instead of a BIGINT + throw BinderException("The upper and lower bounds of the slice must be a BIGINT"); + } + } + return is_empty; +} + +static unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(arguments.size() == 3 || arguments.size() == 4); + D_ASSERT(bound_function.arguments.size() == 3 || bound_function.arguments.size() == 4); + + switch (arguments[0]->return_type.id()) { + case LogicalTypeId::LIST: + // The result is the same type + bound_function.return_type = arguments[0]->return_type; + break; + case LogicalTypeId::VARCHAR: + // string slice returns a string + if (bound_function.arguments.size() == 4) { + throw NotImplementedException( + "Slice with steps has not been implemented for string types, you can consider rewriting your query as " + "follows:\n SELECT array_to_string((str_split(string, '')[begin:end:step], '');"); + } + bound_function.return_type = arguments[0]->return_type; + for (idx_t i = 1; i < 3; i++) { + if (arguments[i]->return_type.id() != LogicalTypeId::LIST) { + bound_function.arguments[i] = LogicalType::BIGINT; + } + } + break; + case LogicalTypeId::SQLNULL: + case LogicalTypeId::UNKNOWN: + bound_function.arguments[0] = LogicalTypeId::UNKNOWN; + bound_function.return_type = LogicalType::SQLNULL; + break; + default: + throw BinderException("ARRAY_SLICE can only operate on LISTs and VARCHARs"); + } + + bool begin_is_empty = CheckIfParamIsEmpty(arguments[1]); + if (!begin_is_empty) { + bound_function.arguments[1] = LogicalType::BIGINT; + } + bool end_is_empty = CheckIfParamIsEmpty(arguments[2]); + if (!end_is_empty) { + bound_function.arguments[2] = LogicalType::BIGINT; + } + + return make_uniq(bound_function.return_type, begin_is_empty, end_is_empty); +} + +ScalarFunctionSet ListSliceFun::GetFunctions() { + // the arguments and return types are actually set in the binder function + ScalarFunction fun({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ArraySliceFunction, + ArraySliceBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + + ScalarFunctionSet set; + set.AddFunction(fun); + fun.arguments.push_back(LogicalType::BIGINT); + set.AddFunction(fun); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/flatten.cpp b/src/duckdb/src/core_functions/scalar/list/flatten.cpp new file mode 100644 index 00000000..34cd5515 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/flatten.cpp @@ -0,0 +1,134 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +void ListFlattenFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + + Vector &input = args.data[0]; + if (input.GetType().id() == LogicalTypeId::SQLNULL) { + result.Reference(input); + return; + } + + idx_t count = args.size(); + + UnifiedVectorFormat list_data; + input.ToUnifiedFormat(count, list_data); + auto list_entries = UnifiedVectorFormat::GetData(list_data); + auto &child_vector = ListVector::GetEntry(input); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_entries = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + if (child_vector.GetType().id() == LogicalTypeId::SQLNULL) { + for (idx_t i = 0; i < count; i++) { + auto list_index = list_data.sel->get_index(i); + if (!list_data.validity.RowIsValid(list_index)) { + result_validity.SetInvalid(i); + continue; + } + result_entries[i].offset = 0; + result_entries[i].length = 0; + } + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + return; + } + + auto child_size = ListVector::GetListSize(input); + UnifiedVectorFormat child_data; + child_vector.ToUnifiedFormat(child_size, child_data); + auto child_entries = UnifiedVectorFormat::GetData(child_data); + auto &data_vector = ListVector::GetEntry(child_vector); + + idx_t offset = 0; + for (idx_t i = 0; i < count; i++) { + auto list_index = list_data.sel->get_index(i); + if (!list_data.validity.RowIsValid(list_index)) { + result_validity.SetInvalid(i); + continue; + } + auto list_entry = list_entries[list_index]; + + idx_t source_offset = 0; + // Find first valid child list entry to get offset + for (idx_t j = 0; j < list_entry.length; j++) { + auto child_list_index = child_data.sel->get_index(list_entry.offset + j); + if (child_data.validity.RowIsValid(child_list_index)) { + source_offset = child_entries[child_list_index].offset; + break; + } + } + + idx_t length = 0; + // Find last valid child list entry to get length + for (idx_t j = list_entry.length - 1; j != (idx_t)-1; j--) { + auto child_list_index = child_data.sel->get_index(list_entry.offset + j); + if (child_data.validity.RowIsValid(child_list_index)) { + auto child_entry = child_entries[child_list_index]; + length = child_entry.offset + child_entry.length - source_offset; + break; + } + } + ListVector::Append(result, data_vector, source_offset + length, source_offset); + + result_entries[i].offset = offset; + result_entries[i].length = length; + offset += length; + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr ListFlattenBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 1); + + auto &input_type = arguments[0]->return_type; + bound_function.arguments[0] = input_type; + if (input_type.id() == LogicalTypeId::UNKNOWN) { + bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + D_ASSERT(input_type.id() == LogicalTypeId::LIST); + + auto child_type = ListType::GetChildType(input_type); + if (child_type.id() == LogicalType::SQLNULL) { + bound_function.return_type = input_type; + return make_uniq(bound_function.return_type); + } + if (child_type.id() == LogicalTypeId::UNKNOWN) { + bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + D_ASSERT(child_type.id() == LogicalTypeId::LIST); + + bound_function.return_type = child_type; + return make_uniq(bound_function.return_type); +} + +static unique_ptr ListFlattenStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &list_child_stats = ListStats::GetChildStats(child_stats[0]); + auto child_copy = list_child_stats.Copy(); + child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + return child_copy.ToUnique(); +} + +ScalarFunction ListFlattenFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::LIST(LogicalType::ANY))}, LogicalType::LIST(LogicalType::ANY), + ListFlattenFunction, ListFlattenBind, nullptr, ListFlattenStats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/src/core_functions/scalar/list/list_aggregates.cpp new file mode 100644 index 00000000..0f561b1f --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_aggregates.cpp @@ -0,0 +1,531 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/core_functions/aggregate/nested_functions.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +// FIXME: use a local state for each thread to increase performance? +// FIXME: benchmark the use of simple_update against using update (if applicable) + +static unique_ptr ListAggregatesBindFailure(ScalarFunction &bound_function) { + bound_function.arguments[0] = LogicalType::SQLNULL; + bound_function.return_type = LogicalType::SQLNULL; + return make_uniq(LogicalType::SQLNULL); +} + +struct ListAggregatesBindData : public FunctionData { + ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p); + ~ListAggregatesBindData() override; + + LogicalType stype; + unique_ptr aggr_expr; + + unique_ptr Copy() const override { + return make_uniq(stype, aggr_expr->Copy()); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return stype == other.stype && aggr_expr->Equals(*other.aggr_expr); + } + void Serialize(Serializer &serializer) const { + serializer.WriteProperty(1, "stype", stype); + serializer.WriteProperty(2, "aggr_expr", aggr_expr); + } + static unique_ptr Deserialize(Deserializer &deserializer) { + auto stype = deserializer.ReadProperty(1, "stype"); + auto aggr_expr = deserializer.ReadProperty>(2, "aggr_expr"); + auto result = make_uniq(std::move(stype), std::move(aggr_expr)); + return result; + } + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + auto bind_data = dynamic_cast(bind_data_p.get()); + serializer.WritePropertyWithDefault(100, "bind_data", bind_data, (const ListAggregatesBindData *)nullptr); + } + + static unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &bound_function) { + auto result = deserializer.ReadPropertyWithDefault>( + 100, "bind_data", unique_ptr(nullptr)); + if (!result) { + return ListAggregatesBindFailure(bound_function); + } + return std::move(result); + } +}; + +ListAggregatesBindData::ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p) + : stype(stype_p), aggr_expr(std::move(aggr_expr_p)) { +} + +ListAggregatesBindData::~ListAggregatesBindData() { +} + +struct StateVector { + StateVector(idx_t count_p, unique_ptr aggr_expr_p) + : count(count_p), aggr_expr(std::move(aggr_expr_p)), state_vector(Vector(LogicalType::POINTER, count_p)) { + } + + ~StateVector() { // NOLINT + // destroy objects within the aggregate states + auto &aggr = aggr_expr->Cast(); + if (aggr.function.destructor) { + ArenaAllocator allocator(Allocator::DefaultAllocator()); + AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + aggr.function.destructor(state_vector, aggr_input_data, count); + } + } + + idx_t count; + unique_ptr aggr_expr; + Vector state_vector; +}; + +struct FinalizeValueFunctor { + template + static Value FinalizeValue(T first) { + return Value::CreateValue(first); + } +}; + +struct FinalizeStringValueFunctor { + template + static Value FinalizeValue(T first) { + string_t value = first; + return Value::CreateValue(value); + } +}; + +struct AggregateFunctor { + template > + static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + } +}; + +struct DistinctFunctor { + template > + static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = (HistogramAggState **)sdata.data; + + auto result_data = FlatVector::GetData(result); + + idx_t offset = 0; + for (idx_t i = 0; i < count; i++) { + + auto state = states[sdata.sel->get_index(i)]; + result_data[i].offset = offset; + + if (!state->hist) { + result_data[i].length = 0; + continue; + } + + result_data[i].length = state->hist->size(); + offset += state->hist->size(); + + for (auto &entry : *state->hist) { + Value bucket_value = OP::template FinalizeValue(entry.first); + ListVector::PushBack(result, bucket_value); + } + } + result.Verify(count); + } +}; + +struct UniqueFunctor { + template > + static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = (HistogramAggState **)sdata.data; + + auto result_data = FlatVector::GetData(result); + + for (idx_t i = 0; i < count; i++) { + + auto state = states[sdata.sel->get_index(i)]; + + if (!state->hist) { + result_data[i] = 0; + continue; + } + + result_data[i] = state->hist->size(); + } + result.Verify(count); + } +}; + +template +static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto count = args.size(); + Vector &lists = args.data[0]; + + // set the result vector + result.SetVectorType(VectorType::FLAT_VECTOR); + auto &result_validity = FlatVector::Validity(result); + + if (lists.GetType().id() == LogicalTypeId::SQLNULL) { + result_validity.SetInvalid(0); + return; + } + + // get the aggregate function + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto &aggr = info.aggr_expr->Cast(); + ArenaAllocator allocator(Allocator::DefaultAllocator()); + AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + + D_ASSERT(aggr.function.update); + + auto lists_size = ListVector::GetListSize(lists); + auto &child_vector = ListVector::GetEntry(lists); + child_vector.Flatten(lists_size); + + UnifiedVectorFormat child_data; + child_vector.ToUnifiedFormat(lists_size, child_data); + + UnifiedVectorFormat lists_data; + lists.ToUnifiedFormat(count, lists_data); + auto list_entries = UnifiedVectorFormat::GetData(lists_data); + + // state_buffer holds the state for each list of this chunk + idx_t size = aggr.function.state_size(); + auto state_buffer = make_unsafe_uniq_array(size * count); + + // state vector for initialize and finalize + StateVector state_vector(count, info.aggr_expr->Copy()); + auto states = FlatVector::GetData(state_vector.state_vector); + + // state vector of STANDARD_VECTOR_SIZE holds the pointers to the states + Vector state_vector_update = Vector(LogicalType::POINTER); + auto states_update = FlatVector::GetData(state_vector_update); + + // selection vector pointing to the data + SelectionVector sel_vector(STANDARD_VECTOR_SIZE); + idx_t states_idx = 0; + + for (idx_t i = 0; i < count; i++) { + + // initialize the state for this list + auto state_ptr = state_buffer.get() + size * i; + states[i] = state_ptr; + aggr.function.initialize(states[i]); + + auto lists_index = lists_data.sel->get_index(i); + const auto &list_entry = list_entries[lists_index]; + + // nothing to do for this list + if (!lists_data.validity.RowIsValid(lists_index)) { + result_validity.SetInvalid(i); + continue; + } + + // skip empty list + if (list_entry.length == 0) { + continue; + } + + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + // states vector is full, update + if (states_idx == STANDARD_VECTOR_SIZE) { + // update the aggregate state(s) + Vector slice(child_vector, sel_vector, states_idx); + aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); + + // reset values + states_idx = 0; + } + + auto source_idx = child_data.sel->get_index(list_entry.offset + child_idx); + sel_vector.set_index(states_idx, source_idx); + states_update[states_idx] = state_ptr; + states_idx++; + } + } + + // update the remaining elements of the last list(s) + if (states_idx != 0) { + Vector slice(child_vector, sel_vector, states_idx); + aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); + } + + if (IS_AGGR) { + // finalize all the aggregate states + aggr.function.finalize(state_vector.state_vector, aggr_input_data, result, count, 0); + + } else { + // finalize manually to use the map + D_ASSERT(aggr.function.arguments.size() == 1); + auto key_type = aggr.function.arguments[0]; + + switch (key_type.InternalType()) { + case PhysicalType::BOOL: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT8: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT16: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT32: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT64: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::INT8: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::INT16: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::INT32: + if (key_type.id() == LogicalTypeId::DATE) { + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + } else { + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + } + break; + case PhysicalType::INT64: + switch (key_type.id()) { + case LogicalTypeId::TIME: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case LogicalTypeId::TIME_TZ: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case LogicalTypeId::TIMESTAMP: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case LogicalTypeId::TIMESTAMP_MS: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case LogicalTypeId::TIMESTAMP_NS: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case LogicalTypeId::TIMESTAMP_SEC: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case LogicalTypeId::TIMESTAMP_TZ: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + default: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + } + break; + case PhysicalType::FLOAT: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::DOUBLE: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::VARCHAR: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + default: + throw InternalException("Unimplemented histogram aggregate"); + } + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static void ListAggregateFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() >= 2); + ListAggregatesFunction(args, state, result); +} + +static void ListDistinctFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + ListAggregatesFunction(args, state, result); +} + +static void ListUniqueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + ListAggregatesFunction(args, state, result); +} + +template +static unique_ptr +ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_function, const LogicalType &list_child_type, + AggregateFunction &aggr_function, vector> &arguments) { + + // create the child expression and its type + vector> children; + auto expr = make_uniq(Value(list_child_type)); + children.push_back(std::move(expr)); + // push any extra arguments into the list aggregate bind + if (arguments.size() > 2) { + for (idx_t i = 2; i < arguments.size(); i++) { + children.push_back(std::move(arguments[i])); + } + arguments.resize(2); + } + + FunctionBinder function_binder(context); + auto bound_aggr_function = function_binder.BindAggregateFunction(aggr_function, std::move(children)); + bound_function.arguments[0] = LogicalType::LIST(bound_aggr_function->function.arguments[0]); + + if (IS_AGGR) { + bound_function.return_type = bound_aggr_function->function.return_type; + } + // check if the aggregate function consumed all the extra input arguments + if (bound_aggr_function->children.size() > 1) { + throw InvalidInputException( + "Aggregate function %s is not supported for list_aggr: extra arguments were not removed during bind", + bound_aggr_function->ToString()); + } + + return make_uniq(bound_function.return_type, std::move(bound_aggr_function)); +} + +template +static unique_ptr ListAggregatesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { + return ListAggregatesBindFailure(bound_function); + } + + bool is_parameter = arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN; + auto list_child_type = is_parameter ? LogicalTypeId::UNKNOWN : ListType::GetChildType(arguments[0]->return_type); + + string function_name = "histogram"; + if (IS_AGGR) { // get the name of the aggregate function + if (!arguments[1]->IsFoldable()) { + throw InvalidInputException("Aggregate function name must be a constant"); + } + // get the function name + Value function_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + function_name = function_value.ToString(); + } + + // look up the aggregate function in the catalog + QueryErrorContext error_context(nullptr, 0); + auto &func = Catalog::GetSystemCatalog(context).GetEntry( + context, DEFAULT_SCHEMA, function_name, error_context); + D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); + + if (is_parameter) { + bound_function.arguments[0] = LogicalTypeId::UNKNOWN; + bound_function.return_type = LogicalType::SQLNULL; + return nullptr; + } + + // find a matching aggregate function + string error; + vector types; + types.push_back(list_child_type); + // push any extra arguments into the type list + for (idx_t i = 2; i < arguments.size(); i++) { + types.push_back(arguments[i]->return_type); + } + + FunctionBinder function_binder(context); + auto best_function_idx = function_binder.BindFunction(func.name, func.functions, types, error); + if (best_function_idx == DConstants::INVALID_INDEX) { + throw BinderException("No matching aggregate function\n%s", error); + } + + // found a matching function, bind it as an aggregate + auto best_function = func.functions.GetFunctionByOffset(best_function_idx); + if (IS_AGGR) { + return ListAggregatesBindFunction(context, bound_function, list_child_type, best_function, arguments); + } + + // create the unordered map histogram function + D_ASSERT(best_function.arguments.size() == 1); + auto key_type = best_function.arguments[0]; + auto aggr_function = HistogramFun::GetHistogramUnorderedMap(key_type); + return ListAggregatesBindFunction(context, bound_function, list_child_type, aggr_function, arguments); +} + +static unique_ptr ListAggregateBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + // the list column and the name of the aggregate function + D_ASSERT(bound_function.arguments.size() >= 2); + D_ASSERT(arguments.size() >= 2); + + return ListAggregatesBind(context, bound_function, arguments); +} + +static unique_ptr ListDistinctBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + D_ASSERT(bound_function.arguments.size() == 1); + D_ASSERT(arguments.size() == 1); + bound_function.return_type = arguments[0]->return_type; + + return ListAggregatesBind<>(context, bound_function, arguments); +} + +static unique_ptr ListUniqueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + D_ASSERT(bound_function.arguments.size() == 1); + D_ASSERT(arguments.size() == 1); + bound_function.return_type = LogicalType::UBIGINT; + + return ListAggregatesBind<>(context, bound_function, arguments); +} + +ScalarFunction ListAggregateFun::GetFunction() { + auto result = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, LogicalType::ANY, + ListAggregateFunction, ListAggregateBind); + result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + result.varargs = LogicalType::ANY; + result.serialize = ListAggregatesBindData::Serialize; + result.deserialize = ListAggregatesBindData::Deserialize; + return result; +} + +ScalarFunction ListDistinctFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), + ListDistinctFunction, ListDistinctBind); +} + +ScalarFunction ListUniqueFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::UBIGINT, ListUniqueFunction, + ListUniqueBind); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_cosine_similarity.cpp b/src/duckdb/src/core_functions/scalar/list/list_cosine_similarity.cpp new file mode 100644 index 00000000..72607752 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_cosine_similarity.cpp @@ -0,0 +1,78 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include +#include + +namespace duckdb { + +template +static void ListCosineSimilarity(DataChunk &args, ExpressionState &, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + + auto count = args.size(); + auto &left = args.data[0]; + auto &right = args.data[1]; + auto left_count = ListVector::GetListSize(left); + auto right_count = ListVector::GetListSize(right); + + auto &left_child = ListVector::GetEntry(left); + auto &right_child = ListVector::GetEntry(right); + + D_ASSERT(left_child.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(right_child.GetVectorType() == VectorType::FLAT_VECTOR); + + if (!FlatVector::Validity(left_child).CheckAllValid(left_count)) { + throw InvalidInputException("list_cosine_similarity: left argument can not contain NULL values"); + } + + if (!FlatVector::Validity(right_child).CheckAllValid(right_count)) { + throw InvalidInputException("list_cosine_similarity: right argument can not contain NULL values"); + } + + auto left_data = FlatVector::GetData(left_child); + auto right_data = FlatVector::GetData(right_child); + + BinaryExecutor::Execute( + left, right, result, count, [&](list_entry_t left, list_entry_t right) { + if (left.length != right.length) { + throw InvalidInputException(StringUtil::Format( + "list_cosine_similarity: list dimensions must be equal, got left length %d and right length %d", + left.length, right.length)); + } + + auto dimensions = left.length; + + NUMERIC_TYPE distance = 0; + NUMERIC_TYPE norm_l = 0; + NUMERIC_TYPE norm_r = 0; + + auto l_ptr = left_data + left.offset; + auto r_ptr = right_data + right.offset; + for (idx_t i = 0; i < dimensions; i++) { + auto x = *l_ptr++; + auto y = *r_ptr++; + distance += x * y; + norm_l += x * x; + norm_r += y * y; + } + + auto similarity = distance / (std::sqrt(norm_l) * std::sqrt(norm_r)); + + // clamp to [-1, 1] to avoid floating point errors + return std::max(static_cast(-1), std::min(similarity, static_cast(1))); + }); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +ScalarFunctionSet ListCosineSimilarityFun::GetFunctions() { + ScalarFunctionSet set("list_cosine_similarity"); + set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::FLOAT), LogicalType::LIST(LogicalType::FLOAT)}, + LogicalType::FLOAT, ListCosineSimilarity)); + set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::DOUBLE), LogicalType::LIST(LogicalType::DOUBLE)}, + LogicalType::DOUBLE, ListCosineSimilarity)); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_distance.cpp b/src/duckdb/src/core_functions/scalar/list/list_distance.cpp new file mode 100644 index 00000000..aa70e4a1 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_distance.cpp @@ -0,0 +1,72 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include + +namespace duckdb { + +template +static void ListDistance(DataChunk &args, ExpressionState &, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + + auto count = args.size(); + auto &left = args.data[0]; + auto &right = args.data[1]; + auto left_count = ListVector::GetListSize(left); + auto right_count = ListVector::GetListSize(right); + + auto &left_child = ListVector::GetEntry(left); + auto &right_child = ListVector::GetEntry(right); + + D_ASSERT(left_child.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(right_child.GetVectorType() == VectorType::FLAT_VECTOR); + + if (!FlatVector::Validity(left_child).CheckAllValid(left_count)) { + throw InvalidInputException("list_distance: left argument can not contain NULL values"); + } + + if (!FlatVector::Validity(right_child).CheckAllValid(right_count)) { + throw InvalidInputException("list_distance: right argument can not contain NULL values"); + } + + auto left_data = FlatVector::GetData(left_child); + auto right_data = FlatVector::GetData(right_child); + + BinaryExecutor::Execute( + left, right, result, count, [&](list_entry_t left, list_entry_t right) { + if (left.length != right.length) { + throw InvalidInputException(StringUtil::Format( + "list_distance: list dimensions must be equal, got left length %d and right length %d", left.length, + right.length)); + } + + auto dimensions = left.length; + + NUMERIC_TYPE distance = 0; + + auto l_ptr = left_data + left.offset; + auto r_ptr = right_data + right.offset; + + for (idx_t i = 0; i < dimensions; i++) { + auto x = *l_ptr++; + auto y = *r_ptr++; + auto diff = x - y; + distance += diff * diff; + } + + return std::sqrt(distance); + }); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +ScalarFunctionSet ListDistanceFun::GetFunctions() { + ScalarFunctionSet set("list_distance"); + set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::FLOAT), LogicalType::LIST(LogicalType::FLOAT)}, + LogicalType::FLOAT, ListDistance)); + set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::DOUBLE), LogicalType::LIST(LogicalType::DOUBLE)}, + LogicalType::DOUBLE, ListDistance)); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_inner_product.cpp b/src/duckdb/src/core_functions/scalar/list/list_inner_product.cpp new file mode 100644 index 00000000..45293e0c --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_inner_product.cpp @@ -0,0 +1,70 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" + +namespace duckdb { + +template +static void ListInnerProduct(DataChunk &args, ExpressionState &, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + + auto count = args.size(); + auto &left = args.data[0]; + auto &right = args.data[1]; + auto left_count = ListVector::GetListSize(left); + auto right_count = ListVector::GetListSize(right); + + auto &left_child = ListVector::GetEntry(left); + auto &right_child = ListVector::GetEntry(right); + + D_ASSERT(left_child.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(right_child.GetVectorType() == VectorType::FLAT_VECTOR); + + if (!FlatVector::Validity(left_child).CheckAllValid(left_count)) { + throw InvalidInputException("list_inner_product: left argument can not contain NULL values"); + } + + if (!FlatVector::Validity(right_child).CheckAllValid(right_count)) { + throw InvalidInputException("list_inner_product: right argument can not contain NULL values"); + } + + auto left_data = FlatVector::GetData(left_child); + auto right_data = FlatVector::GetData(right_child); + + BinaryExecutor::Execute( + left, right, result, count, [&](list_entry_t left, list_entry_t right) { + if (left.length != right.length) { + throw InvalidInputException(StringUtil::Format( + "list_inner_product: list dimensions must be equal, got left length %d and right length %d", + left.length, right.length)); + } + + auto dimensions = left.length; + + NUMERIC_TYPE distance = 0; + + auto l_ptr = left_data + left.offset; + auto r_ptr = right_data + right.offset; + + for (idx_t i = 0; i < dimensions; i++) { + auto x = *l_ptr++; + auto y = *r_ptr++; + distance += x * y; + } + + return distance; + }); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +ScalarFunctionSet ListInnerProductFun::GetFunctions() { + ScalarFunctionSet set("list_inner_product"); + set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::FLOAT), LogicalType::LIST(LogicalType::FLOAT)}, + LogicalType::FLOAT, ListInnerProduct)); + set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::DOUBLE), LogicalType::LIST(LogicalType::DOUBLE)}, + LogicalType::DOUBLE, ListInnerProduct)); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_lambdas.cpp b/src/duckdb/src/core_functions/scalar/list/list_lambdas.cpp new file mode 100644 index 00000000..a9a8feba --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_lambdas.cpp @@ -0,0 +1,412 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_lambda_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +struct ListLambdaBindData : public FunctionData { + ListLambdaBindData(const LogicalType &stype_p, unique_ptr lambda_expr); + ~ListLambdaBindData() override; + + LogicalType stype; + unique_ptr lambda_expr; + +public: + bool Equals(const FunctionData &other_p) const override; + unique_ptr Copy() const override; + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + // auto &bind_data = bind_data_p->Cast(); + // serializer.WriteProperty(100, "stype", bind_data.stype); + // serializer.WritePropertyWithDefault(101, "lambda_expr", bind_data.lambda_expr, + // unique_ptr()); + throw NotImplementedException("FIXME: list lambda serialize"); + } + + static unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &function) { + // auto stype = deserializer.ReadProperty(100, "stype"); + // auto lambda_expr = + // deserializer.ReadPropertyWithDefault>(101, "lambda_expr", + // unique_ptr()); return make_uniq(stype, std::move(lambda_expr)); + throw NotImplementedException("FIXME: list lambda deserialize"); + } +}; + +ListLambdaBindData::ListLambdaBindData(const LogicalType &stype_p, unique_ptr lambda_expr_p) + : stype(stype_p), lambda_expr(std::move(lambda_expr_p)) { +} + +unique_ptr ListLambdaBindData::Copy() const { + return make_uniq(stype, lambda_expr ? lambda_expr->Copy() : nullptr); +} + +bool ListLambdaBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return Expression::Equals(lambda_expr, other.lambda_expr) && stype == other.stype; +} + +ListLambdaBindData::~ListLambdaBindData() { +} + +static void AppendTransformedToResult(Vector &lambda_vector, idx_t &elem_cnt, Vector &result) { + + // append the lambda_vector to the result list + UnifiedVectorFormat lambda_child_data; + lambda_vector.ToUnifiedFormat(elem_cnt, lambda_child_data); + ListVector::Append(result, lambda_vector, *lambda_child_data.sel, elem_cnt, 0); +} + +static void AppendFilteredToResult(Vector &lambda_vector, list_entry_t *result_entries, idx_t &elem_cnt, Vector &result, + idx_t &curr_list_len, idx_t &curr_list_offset, idx_t &appended_lists_cnt, + vector &lists_len, idx_t &curr_original_list_len, DataChunk &input_chunk) { + + idx_t true_count = 0; + SelectionVector true_sel(elem_cnt); + UnifiedVectorFormat lambda_data; + lambda_vector.ToUnifiedFormat(elem_cnt, lambda_data); + + auto lambda_values = UnifiedVectorFormat::GetData(lambda_data); + auto &lambda_validity = lambda_data.validity; + + // compute the new lengths and offsets, and create a selection vector + for (idx_t i = 0; i < elem_cnt; i++) { + auto entry = lambda_data.sel->get_index(i); + + while (appended_lists_cnt < lists_len.size() && lists_len[appended_lists_cnt] == 0) { + result_entries[appended_lists_cnt].offset = curr_list_offset; + result_entries[appended_lists_cnt].length = 0; + appended_lists_cnt++; + } + + // found a true value + if (lambda_validity.RowIsValid(entry) && lambda_values[entry]) { + true_sel.set_index(true_count++, i); + curr_list_len++; + } + + curr_original_list_len++; + + if (lists_len[appended_lists_cnt] == curr_original_list_len) { + result_entries[appended_lists_cnt].offset = curr_list_offset; + result_entries[appended_lists_cnt].length = curr_list_len; + curr_list_offset += curr_list_len; + appended_lists_cnt++; + curr_list_len = 0; + curr_original_list_len = 0; + } + } + + while (appended_lists_cnt < lists_len.size() && lists_len[appended_lists_cnt] == 0) { + result_entries[appended_lists_cnt].offset = curr_list_offset; + result_entries[appended_lists_cnt].length = 0; + appended_lists_cnt++; + } + + // slice to get the new lists and append them to the result + Vector new_lists(input_chunk.data[0], true_sel, true_count); + new_lists.Flatten(true_count); + UnifiedVectorFormat new_lists_child_data; + new_lists.ToUnifiedFormat(true_count, new_lists_child_data); + ListVector::Append(result, new_lists, *new_lists_child_data.sel, true_count, 0); +} + +static void ExecuteExpression(vector &types, vector &result_types, idx_t &elem_cnt, + SelectionVector &sel, vector &sel_vectors, DataChunk &input_chunk, + DataChunk &lambda_chunk, Vector &child_vector, DataChunk &args, + ExpressionExecutor &expr_executor) { + + input_chunk.SetCardinality(elem_cnt); + lambda_chunk.SetCardinality(elem_cnt); + + // set the list child vector + Vector slice(child_vector, sel, elem_cnt); + Vector second_slice(child_vector, sel, elem_cnt); + slice.Flatten(elem_cnt); + second_slice.Flatten(elem_cnt); + + input_chunk.data[0].Reference(slice); + input_chunk.data[1].Reference(second_slice); + + // set the other vectors + vector slices; + for (idx_t col_idx = 0; col_idx < args.ColumnCount() - 1; col_idx++) { + slices.emplace_back(args.data[col_idx + 1], sel_vectors[col_idx], elem_cnt); + slices[col_idx].Flatten(elem_cnt); + input_chunk.data[col_idx + 2].Reference(slices[col_idx]); + } + + // execute the lambda expression + expr_executor.Execute(input_chunk, lambda_chunk); +} + +template +static void ListLambdaFunction(DataChunk &args, ExpressionState &state, Vector &result) { + + // always at least the list argument + D_ASSERT(args.ColumnCount() >= 1); + + auto count = args.size(); + Vector &lists = args.data[0]; + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_entries = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + if (lists.GetType().id() == LogicalTypeId::SQLNULL) { + result_validity.SetInvalid(0); + return; + } + + // e.g. window functions in sub queries return dictionary vectors, which segfault on expression execution + // if not flattened first + for (idx_t i = 1; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::FLAT_VECTOR && + args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + args.data[i].Flatten(count); + } + } + + // get the lists data + UnifiedVectorFormat lists_data; + lists.ToUnifiedFormat(count, lists_data); + auto list_entries = UnifiedVectorFormat::GetData(lists_data); + + // get the lambda expression + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto &lambda_expr = info.lambda_expr; + + // get the child vector and child data + auto lists_size = ListVector::GetListSize(lists); + auto &child_vector = ListVector::GetEntry(lists); + child_vector.Flatten(lists_size); + UnifiedVectorFormat child_data; + child_vector.ToUnifiedFormat(lists_size, child_data); + + // to slice the child vector + SelectionVector sel(STANDARD_VECTOR_SIZE); + + // this vector never contains more than one element + vector result_types; + result_types.push_back(lambda_expr->return_type); + + // non-lambda parameter columns + vector columns; + vector indexes; + vector sel_vectors; + + vector types; + types.push_back(child_vector.GetType()); + types.push_back(child_vector.GetType()); + + // skip the list column + for (idx_t i = 1; i < args.ColumnCount(); i++) { + columns.emplace_back(); + args.data[i].ToUnifiedFormat(count, columns[i - 1]); + indexes.push_back(0); + sel_vectors.emplace_back(STANDARD_VECTOR_SIZE); + types.push_back(args.data[i].GetType()); + } + + // get the expression executor + ExpressionExecutor expr_executor(state.GetContext(), *lambda_expr); + + // these are only for the list_filter + vector lists_len; + idx_t curr_list_len = 0; + idx_t curr_list_offset = 0; + idx_t appended_lists_cnt = 0; + idx_t curr_original_list_len = 0; + + if (!IS_TRANSFORM) { + lists_len.reserve(count); + } + + DataChunk input_chunk; + DataChunk lambda_chunk; + input_chunk.InitializeEmpty(types); + lambda_chunk.Initialize(Allocator::DefaultAllocator(), result_types); + + // loop over the child entries and create chunks to be executed by the expression executor + idx_t elem_cnt = 0; + idx_t offset = 0; + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + + auto lists_index = lists_data.sel->get_index(row_idx); + const auto &list_entry = list_entries[lists_index]; + + // set the result to NULL for this row + if (!lists_data.validity.RowIsValid(lists_index)) { + result_validity.SetInvalid(row_idx); + if (!IS_TRANSFORM) { + lists_len.push_back(0); + } + continue; + } + + // set the length and offset of the resulting lists of list_transform + if (IS_TRANSFORM) { + result_entries[row_idx].offset = offset; + result_entries[row_idx].length = list_entry.length; + offset += list_entry.length; + } else { + lists_len.push_back(list_entry.length); + } + + // empty list, nothing to execute + if (list_entry.length == 0) { + continue; + } + + // get the data indexes + for (idx_t col_idx = 0; col_idx < args.ColumnCount() - 1; col_idx++) { + indexes[col_idx] = columns[col_idx].sel->get_index(row_idx); + } + + // iterate list elements and create transformed expression columns + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + // reached STANDARD_VECTOR_SIZE elements + if (elem_cnt == STANDARD_VECTOR_SIZE) { + lambda_chunk.Reset(); + ExecuteExpression(types, result_types, elem_cnt, sel, sel_vectors, input_chunk, lambda_chunk, + child_vector, args, expr_executor); + + auto &lambda_vector = lambda_chunk.data[0]; + + if (IS_TRANSFORM) { + AppendTransformedToResult(lambda_vector, elem_cnt, result); + } else { + AppendFilteredToResult(lambda_vector, result_entries, elem_cnt, result, curr_list_len, + curr_list_offset, appended_lists_cnt, lists_len, curr_original_list_len, + input_chunk); + } + elem_cnt = 0; + } + + // to slice the child vector + auto source_idx = child_data.sel->get_index(list_entry.offset + child_idx); + sel.set_index(elem_cnt, source_idx); + + // for each column, set the index of the selection vector to slice properly + for (idx_t col_idx = 0; col_idx < args.ColumnCount() - 1; col_idx++) { + sel_vectors[col_idx].set_index(elem_cnt, indexes[col_idx]); + } + elem_cnt++; + } + } + + lambda_chunk.Reset(); + ExecuteExpression(types, result_types, elem_cnt, sel, sel_vectors, input_chunk, lambda_chunk, child_vector, args, + expr_executor); + auto &lambda_vector = lambda_chunk.data[0]; + + if (IS_TRANSFORM) { + AppendTransformedToResult(lambda_vector, elem_cnt, result); + } else { + AppendFilteredToResult(lambda_vector, result_entries, elem_cnt, result, curr_list_len, curr_list_offset, + appended_lists_cnt, lists_len, curr_original_list_len, input_chunk); + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static void ListTransformFunction(DataChunk &args, ExpressionState &state, Vector &result) { + ListLambdaFunction<>(args, state, result); +} + +static void ListFilterFunction(DataChunk &args, ExpressionState &state, Vector &result) { + ListLambdaFunction(args, state, result); +} + +template +static unique_ptr ListLambdaBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto &bound_lambda_expr = arguments[1]->Cast(); + if (bound_lambda_expr.parameter_count != LAMBDA_PARAM_CNT) { + throw BinderException("Incorrect number of parameters in lambda function! " + bound_function.name + + " expects " + to_string(LAMBDA_PARAM_CNT) + " parameter(s)."); + } + + if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { + bound_function.arguments[0] = LogicalType::SQLNULL; + bound_function.return_type = LogicalType::SQLNULL; + return make_uniq(bound_function.return_type, nullptr); + } + + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + + D_ASSERT(arguments[0]->return_type.id() == LogicalTypeId::LIST); + + // get the lambda expression and put it in the bind info + auto lambda_expr = std::move(bound_lambda_expr.lambda_expr); + return make_uniq(bound_function.return_type, std::move(lambda_expr)); +} + +static unique_ptr ListTransformBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + // at least the list column and the lambda function + D_ASSERT(arguments.size() == 2); + if (arguments[1]->expression_class != ExpressionClass::BOUND_LAMBDA) { + throw BinderException("Invalid lambda expression!"); + } + + auto &bound_lambda_expr = arguments[1]->Cast(); + bound_function.return_type = LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type); + return ListLambdaBind<1>(context, bound_function, arguments); +} + +static unique_ptr ListFilterBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + // at least the list column and the lambda function + D_ASSERT(arguments.size() == 2); + if (arguments[1]->expression_class != ExpressionClass::BOUND_LAMBDA) { + throw BinderException("Invalid lambda expression!"); + } + + // try to cast to boolean, if the return type of the lambda filter expression is not already boolean + auto &bound_lambda_expr = arguments[1]->Cast(); + if (bound_lambda_expr.lambda_expr->return_type != LogicalType::BOOLEAN) { + auto cast_lambda_expr = + BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), LogicalType::BOOLEAN); + bound_lambda_expr.lambda_expr = std::move(cast_lambda_expr); + } + + bound_function.return_type = arguments[0]->return_type; + return ListLambdaBind<1>(context, bound_function, arguments); +} + +ScalarFunction ListTransformFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), + ListTransformFunction, ListTransformBind, nullptr, nullptr); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = ListLambdaBindData::Serialize; + fun.deserialize = ListLambdaBindData::Deserialize; + return fun; +} + +ScalarFunction ListFilterFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), + ListFilterFunction, ListFilterBind, nullptr, nullptr); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = ListLambdaBindData::Serialize; + fun.deserialize = ListLambdaBindData::Deserialize; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_sort.cpp b/src/duckdb/src/core_functions/scalar/list/list_sort.cpp new file mode 100644 index 00000000..02dc205e --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_sort.cpp @@ -0,0 +1,344 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/common/sort/sort.hpp" + +namespace duckdb { + +struct ListSortBindData : public FunctionData { + ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, const LogicalType &return_type_p, + const LogicalType &child_type_p, ClientContext &context_p); + ~ListSortBindData() override; + + OrderType order_type; + OrderByNullType null_order; + LogicalType return_type; + LogicalType child_type; + + vector types; + vector payload_types; + + ClientContext &context; + RowLayout payload_layout; + vector orders; + +public: + bool Equals(const FunctionData &other_p) const override; + unique_ptr Copy() const override; +}; + +ListSortBindData::ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, + const LogicalType &return_type_p, const LogicalType &child_type_p, + ClientContext &context_p) + : order_type(order_type_p), null_order(null_order_p), return_type(return_type_p), child_type(child_type_p), + context(context_p) { + + // get the vector types + types.emplace_back(LogicalType::USMALLINT); + types.emplace_back(child_type); + D_ASSERT(types.size() == 2); + + // get the payload types + payload_types.emplace_back(LogicalType::UINTEGER); + D_ASSERT(payload_types.size() == 1); + + // initialize the payload layout + payload_layout.Initialize(payload_types); + + // get the BoundOrderByNode + auto idx_col_expr = make_uniq_base(LogicalType::USMALLINT, 0); + auto lists_col_expr = make_uniq_base(child_type, 1); + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, std::move(idx_col_expr)); + orders.emplace_back(order_type, null_order, std::move(lists_col_expr)); +} + +unique_ptr ListSortBindData::Copy() const { + return make_uniq(order_type, null_order, return_type, child_type, context); +} + +bool ListSortBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return order_type == other.order_type && null_order == other.null_order; +} + +ListSortBindData::~ListSortBindData() { +} + +// create the key_chunk and the payload_chunk and sink them into the local_sort_state +void SinkDataChunk(Vector *child_vector, SelectionVector &sel, idx_t offset_lists_indices, vector &types, + vector &payload_types, Vector &payload_vector, LocalSortState &local_sort_state, + bool &data_to_sort, Vector &lists_indices) { + + // slice the child vector + Vector slice(*child_vector, sel, offset_lists_indices); + + // initialize and fill key_chunk + DataChunk key_chunk; + key_chunk.InitializeEmpty(types); + key_chunk.data[0].Reference(lists_indices); + key_chunk.data[1].Reference(slice); + key_chunk.SetCardinality(offset_lists_indices); + + // initialize and fill key_chunk and payload_chunk + DataChunk payload_chunk; + payload_chunk.InitializeEmpty(payload_types); + payload_chunk.data[0].Reference(payload_vector); + payload_chunk.SetCardinality(offset_lists_indices); + + key_chunk.Verify(); + payload_chunk.Verify(); + + // sink + key_chunk.Flatten(); + local_sort_state.SinkChunk(key_chunk, payload_chunk); + data_to_sort = true; +} + +static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() >= 1 && args.ColumnCount() <= 3); + auto count = args.size(); + Vector &input_lists = args.data[0]; + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto &result_validity = FlatVector::Validity(result); + + if (input_lists.GetType().id() == LogicalTypeId::SQLNULL) { + result_validity.SetInvalid(0); + return; + } + + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + // initialize the global and local sorting state + auto &buffer_manager = BufferManager::GetBufferManager(info.context); + GlobalSortState global_sort_state(buffer_manager, info.orders, info.payload_layout); + LocalSortState local_sort_state; + local_sort_state.Initialize(global_sort_state, buffer_manager); + + // this ensures that we do not change the order of the entries in the input chunk + VectorOperations::Copy(input_lists, result, count, 0, 0); + + // get the child vector + auto lists_size = ListVector::GetListSize(result); + auto &child_vector = ListVector::GetEntry(result); + UnifiedVectorFormat child_data; + child_vector.ToUnifiedFormat(lists_size, child_data); + + // get the lists data + UnifiedVectorFormat lists_data; + result.ToUnifiedFormat(count, lists_data); + auto list_entries = UnifiedVectorFormat::GetData(lists_data); + + // create the lists_indices vector, this contains an element for each list's entry, + // the element corresponds to the list's index, e.g. for [1, 2, 4], [5, 4] + // lists_indices contains [0, 0, 0, 1, 1] + Vector lists_indices(LogicalType::USMALLINT); + auto lists_indices_data = FlatVector::GetData(lists_indices); + + // create the payload_vector, this is just a vector containing incrementing integers + // this will later be used as the 'new' selection vector of the child_vector, after + // rearranging the payload according to the sorting order + Vector payload_vector(LogicalType::UINTEGER); + auto payload_vector_data = FlatVector::GetData(payload_vector); + + // selection vector pointing to the data of the child vector, + // used for slicing the child_vector correctly + SelectionVector sel(STANDARD_VECTOR_SIZE); + + idx_t offset_lists_indices = 0; + uint32_t incr_payload_count = 0; + bool data_to_sort = false; + + for (idx_t i = 0; i < count; i++) { + auto lists_index = lists_data.sel->get_index(i); + const auto &list_entry = list_entries[lists_index]; + + // nothing to do for this list + if (!lists_data.validity.RowIsValid(lists_index)) { + result_validity.SetInvalid(i); + continue; + } + + // empty list, no sorting required + if (list_entry.length == 0) { + continue; + } + + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + // lists_indices vector is full, sink + if (offset_lists_indices == STANDARD_VECTOR_SIZE) { + SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, + local_sort_state, data_to_sort, lists_indices); + offset_lists_indices = 0; + } + + auto source_idx = list_entry.offset + child_idx; + sel.set_index(offset_lists_indices, source_idx); + lists_indices_data[offset_lists_indices] = (uint32_t)i; + payload_vector_data[offset_lists_indices] = source_idx; + offset_lists_indices++; + incr_payload_count++; + } + } + + if (offset_lists_indices != 0) { + SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, + local_sort_state, data_to_sort, lists_indices); + } + + if (data_to_sort) { + // add local state to global state, which sorts the data + global_sort_state.AddLocalState(local_sort_state); + global_sort_state.PrepareMergePhase(); + + // selection vector that is to be filled with the 'sorted' payload + SelectionVector sel_sorted(incr_payload_count); + idx_t sel_sorted_idx = 0; + + // scan the sorted row data + PayloadScanner scanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state); + for (;;) { + DataChunk result_chunk; + result_chunk.Initialize(Allocator::DefaultAllocator(), info.payload_types); + result_chunk.SetCardinality(0); + scanner.Scan(result_chunk); + if (result_chunk.size() == 0) { + break; + } + + // construct the selection vector with the new order from the result vectors + Vector result_vector(result_chunk.data[0]); + auto result_data = FlatVector::GetData(result_vector); + auto row_count = result_chunk.size(); + + for (idx_t i = 0; i < row_count; i++) { + sel_sorted.set_index(sel_sorted_idx, result_data[i]); + D_ASSERT(result_data[i] < lists_size); + sel_sorted_idx++; + } + } + + D_ASSERT(sel_sorted_idx == incr_payload_count); + child_vector.Slice(sel_sorted, sel_sorted_idx); + child_vector.Flatten(sel_sorted_idx); + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr ListSortBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments, OrderType &order, + OrderByNullType &null_order) { + + LogicalType child_type; + if (arguments[0]->return_type == LogicalTypeId::UNKNOWN) { + bound_function.arguments[0] = LogicalTypeId::UNKNOWN; + bound_function.return_type = LogicalType::SQLNULL; + child_type = bound_function.return_type; + return make_uniq(order, null_order, bound_function.return_type, child_type, context); + } + + bound_function.arguments[0] = arguments[0]->return_type; + bound_function.return_type = arguments[0]->return_type; + child_type = ListType::GetChildType(arguments[0]->return_type); + + return make_uniq(order, null_order, bound_function.return_type, child_type, context); +} + +template +static T GetOrder(ClientContext &context, Expression &expr) { + if (!expr.IsFoldable()) { + throw InvalidInputException("Sorting order must be a constant"); + } + Value order_value = ExpressionExecutor::EvaluateScalar(context, expr); + auto order_name = StringUtil::Upper(order_value.ToString()); + return EnumUtil::FromString(order_name.c_str()); +} + +static unique_ptr ListNormalSortBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(!arguments.empty() && arguments.size() <= 3); + auto order = OrderType::ORDER_DEFAULT; + auto null_order = OrderByNullType::ORDER_DEFAULT; + + // get the sorting order + if (arguments.size() >= 2) { + order = GetOrder(context, *arguments[1]); + } + // get the null sorting order + if (arguments.size() == 3) { + null_order = GetOrder(context, *arguments[2]); + } + auto &config = DBConfig::GetConfig(context); + order = config.ResolveOrder(order); + null_order = config.ResolveNullOrder(order, null_order); + return ListSortBind(context, bound_function, arguments, order, null_order); +} + +static unique_ptr ListReverseSortBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto order = OrderType::ORDER_DEFAULT; + auto null_order = OrderByNullType::ORDER_DEFAULT; + + if (arguments.size() == 2) { + null_order = GetOrder(context, *arguments[1]); + } + auto &config = DBConfig::GetConfig(context); + order = config.ResolveOrder(order); + switch (order) { + case OrderType::ASCENDING: + order = OrderType::DESCENDING; + break; + case OrderType::DESCENDING: + order = OrderType::ASCENDING; + break; + default: + throw InternalException("Unexpected order type in list reverse sort"); + } + null_order = config.ResolveNullOrder(order, null_order); + return ListSortBind(context, bound_function, arguments, order, null_order); +} + +ScalarFunctionSet ListSortFun::GetFunctions() { + // one parameter: list + ScalarFunction sort({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), ListSortFunction, + ListNormalSortBind); + + // two parameters: list, order + ScalarFunction sort_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); + + // three parameters: list, order, null order + ScalarFunction sort_orders({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); + + ScalarFunctionSet list_sort; + list_sort.AddFunction(sort); + list_sort.AddFunction(sort_order); + list_sort.AddFunction(sort_orders); + return list_sort; +} + +ScalarFunctionSet ListReverseSortFun::GetFunctions() { + // one parameter: list + ScalarFunction sort_reverse({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), + ListSortFunction, ListReverseSortBind); + + // two parameters: list, null order + ScalarFunction sort_reverse_null_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListReverseSortBind); + + ScalarFunctionSet list_reverse_sort; + list_reverse_sort.AddFunction(sort_reverse); + list_reverse_sort.AddFunction(sort_reverse_null_order); + return list_reverse_sort; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_value.cpp b/src/duckdb/src/core_functions/scalar/list/list_value.cpp new file mode 100644 index 00000000..ea60d821 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_value.cpp @@ -0,0 +1,70 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto &child_type = ListType::GetChildType(result.GetType()); + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + for (idx_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::FLAT_VECTOR); + } + } + + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < args.size(); i++) { + result_data[i].offset = ListVector::GetListSize(result); + for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + auto val = args.GetValue(col_idx, i).DefaultCastAs(child_type); + ListVector::PushBack(result, val); + } + result_data[i].length = args.ColumnCount(); + } + result.Verify(args.size()); +} + +static unique_ptr ListValueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // collect names and deconflict, construct return type + LogicalType child_type = arguments.empty() ? LogicalType::SQLNULL : arguments[0]->return_type; + for (idx_t i = 1; i < arguments.size(); i++) { + child_type = LogicalType::MaxLogicalType(child_type, arguments[i]->return_type); + } + + // this is more for completeness reasons + bound_function.varargs = child_type; + bound_function.return_type = LogicalType::LIST(child_type); + return make_uniq(bound_function.return_type); +} + +unique_ptr ListValueStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + auto list_stats = ListStats::CreateEmpty(expr.return_type); + auto &list_child_stats = ListStats::GetChildStats(list_stats); + for (idx_t i = 0; i < child_stats.size(); i++) { + list_child_stats.Merge(child_stats[i]); + } + return list_stats.ToUnique(); +} + +ScalarFunction ListValueFun::GetFunction() { + // the arguments and return types are actually set in the binder function + ScalarFunction fun("list_value", {}, LogicalTypeId::LIST, ListValueFunction, ListValueBind, nullptr, + ListValueStats); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/range.cpp b/src/duckdb/src/core_functions/scalar/list/range.cpp new file mode 100644 index 00000000..4bb86853 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/range.cpp @@ -0,0 +1,275 @@ +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/timestamp.hpp" + +namespace duckdb { + +struct NumericRangeInfo { + using TYPE = int64_t; + using INCREMENT_TYPE = int64_t; + + static int64_t DefaultStart() { + return 0; + } + static int64_t DefaultIncrement() { + return 1; + } + + static uint64_t ListLength(int64_t start_value, int64_t end_value, int64_t increment_value, bool inclusive_bound) { + if (increment_value == 0) { + return 0; + } + if (start_value > end_value && increment_value > 0) { + return 0; + } + if (start_value < end_value && increment_value < 0) { + return 0; + } + hugeint_t total_diff = AbsValue(hugeint_t(end_value) - hugeint_t(start_value)); + hugeint_t increment = AbsValue(hugeint_t(increment_value)); + hugeint_t total_values = total_diff / increment; + if (total_diff % increment == 0) { + if (inclusive_bound) { + total_values += 1; + } + } else { + total_values += 1; + } + if (total_values > NumericLimits::Maximum()) { + throw InvalidInputException("Lists larger than 2^32 elements are not supported"); + } + return Hugeint::Cast(total_values); + } + + static void Increment(int64_t &input, int64_t increment) { + input += increment; + } +}; +struct TimestampRangeInfo { + using TYPE = timestamp_t; + using INCREMENT_TYPE = interval_t; + + static timestamp_t DefaultStart() { + throw InternalException("Default start not implemented for timestamp range"); + } + static interval_t DefaultIncrement() { + throw InternalException("Default increment not implemented for timestamp range"); + } + static uint64_t ListLength(timestamp_t start_value, timestamp_t end_value, interval_t increment_value, + bool inclusive_bound) { + bool is_positive = increment_value.months > 0 || increment_value.days > 0 || increment_value.micros > 0; + bool is_negative = increment_value.months < 0 || increment_value.days < 0 || increment_value.micros < 0; + if (!is_negative && !is_positive) { + // interval is 0: no result + return 0; + } + // We don't allow infinite bounds because they generate errors or infinite loops + if (!Timestamp::IsFinite(start_value) || !Timestamp::IsFinite(end_value)) { + throw InvalidInputException("Interval infinite bounds not supported"); + } + + if (is_negative && is_positive) { + // we don't allow a mix of + throw InvalidInputException("Interval with mix of negative/positive entries not supported"); + } + if (start_value > end_value && is_positive) { + return 0; + } + if (start_value < end_value && is_negative) { + return 0; + } + int64_t total_values = 0; + if (is_negative) { + // negative interval, start_value is going down + while (inclusive_bound ? start_value >= end_value : start_value > end_value) { + start_value = Interval::Add(start_value, increment_value); + total_values++; + if (total_values > NumericLimits::Maximum()) { + throw InvalidInputException("Lists larger than 2^32 elements are not supported"); + } + } + } else { + // positive interval, start_value is going up + while (inclusive_bound ? start_value <= end_value : start_value < end_value) { + start_value = Interval::Add(start_value, increment_value); + total_values++; + if (total_values > NumericLimits::Maximum()) { + throw InvalidInputException("Lists larger than 2^32 elements are not supported"); + } + } + } + return total_values; + } + + static void Increment(timestamp_t &input, interval_t increment) { + input = Interval::Add(input, increment); + } +}; + +template +class RangeInfoStruct { +public: + explicit RangeInfoStruct(DataChunk &args_p) : args(args_p) { + switch (args.ColumnCount()) { + case 1: + args.data[0].ToUnifiedFormat(args.size(), vdata[0]); + break; + case 2: + args.data[0].ToUnifiedFormat(args.size(), vdata[0]); + args.data[1].ToUnifiedFormat(args.size(), vdata[1]); + break; + case 3: + args.data[0].ToUnifiedFormat(args.size(), vdata[0]); + args.data[1].ToUnifiedFormat(args.size(), vdata[1]); + args.data[2].ToUnifiedFormat(args.size(), vdata[2]); + break; + default: + throw InternalException("Unsupported number of parameters for range"); + } + } + + bool RowIsValid(idx_t row_idx) { + for (idx_t i = 0; i < args.ColumnCount(); i++) { + auto idx = vdata[i].sel->get_index(row_idx); + if (!vdata[i].validity.RowIsValid(idx)) { + return false; + } + } + return true; + } + + typename OP::TYPE StartListValue(idx_t row_idx) { + if (args.ColumnCount() == 1) { + return OP::DefaultStart(); + } else { + auto data = (typename OP::TYPE *)vdata[0].data; + auto idx = vdata[0].sel->get_index(row_idx); + return data[idx]; + } + } + + typename OP::TYPE EndListValue(idx_t row_idx) { + idx_t vdata_idx = args.ColumnCount() == 1 ? 0 : 1; + auto data = (typename OP::TYPE *)vdata[vdata_idx].data; + auto idx = vdata[vdata_idx].sel->get_index(row_idx); + return data[idx]; + } + + typename OP::INCREMENT_TYPE ListIncrementValue(idx_t row_idx) { + if (args.ColumnCount() < 3) { + return OP::DefaultIncrement(); + } else { + auto data = (typename OP::INCREMENT_TYPE *)vdata[2].data; + auto idx = vdata[2].sel->get_index(row_idx); + return data[idx]; + } + } + + void GetListValues(idx_t row_idx, typename OP::TYPE &start_value, typename OP::TYPE &end_value, + typename OP::INCREMENT_TYPE &increment_value) { + start_value = StartListValue(row_idx); + end_value = EndListValue(row_idx); + increment_value = ListIncrementValue(row_idx); + } + + uint64_t ListLength(idx_t row_idx) { + typename OP::TYPE start_value; + typename OP::TYPE end_value; + typename OP::INCREMENT_TYPE increment_value; + GetListValues(row_idx, start_value, end_value, increment_value); + return OP::ListLength(start_value, end_value, increment_value, INCLUSIVE_BOUND); + } + +private: + DataChunk &args; + UnifiedVectorFormat vdata[3]; +}; + +template +static void ListRangeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + + RangeInfoStruct info(args); + idx_t args_size = 1; + auto result_type = VectorType::CONSTANT_VECTOR; + for (idx_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + args_size = args.size(); + result_type = VectorType::FLAT_VECTOR; + break; + } + } + auto list_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + int64_t total_size = 0; + for (idx_t i = 0; i < args_size; i++) { + if (!info.RowIsValid(i)) { + result_validity.SetInvalid(i); + list_data[i].offset = total_size; + list_data[i].length = 0; + } else { + list_data[i].offset = total_size; + list_data[i].length = info.ListLength(i); + total_size += list_data[i].length; + } + } + + // now construct the child vector of the list + ListVector::Reserve(result, total_size); + auto range_data = FlatVector::GetData(ListVector::GetEntry(result)); + idx_t total_idx = 0; + for (idx_t i = 0; i < args_size; i++) { + typename OP::TYPE start_value = info.StartListValue(i); + typename OP::INCREMENT_TYPE increment = info.ListIncrementValue(i); + + typename OP::TYPE range_value = start_value; + for (idx_t range_idx = 0; range_idx < list_data[i].length; range_idx++) { + if (range_idx > 0) { + OP::Increment(range_value, increment); + } + range_data[total_idx++] = range_value; + } + } + + ListVector::SetListSize(result, total_size); + result.SetVectorType(result_type); + + result.Verify(args.size()); +} + +ScalarFunctionSet ListRangeFun::GetFunctions() { + // the arguments and return types are actually set in the binder function + ScalarFunctionSet range_set; + range_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + range_set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + LogicalType::LIST(LogicalType::TIMESTAMP), + ListRangeFunction)); + return range_set; +} + +ScalarFunctionSet GenerateSeriesFun::GetFunctions() { + ScalarFunctionSet generate_series; + generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + generate_series.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + LogicalType::LIST(LogicalType::TIMESTAMP), + ListRangeFunction)); + return generate_series; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/cardinality.cpp b/src/duckdb/src/core_functions/scalar/map/cardinality.cpp new file mode 100644 index 00000000..8bf0dbd1 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/cardinality.cpp @@ -0,0 +1,50 @@ +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void CardinalityFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &map = args.data[0]; + UnifiedVectorFormat map_data; + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + map.ToUnifiedFormat(args.size(), map_data); + for (idx_t row = 0; row < args.size(); row++) { + auto list_entry = UnifiedVectorFormat::GetData(map_data)[map_data.sel->get_index(row)]; + result_data[row] = list_entry.length; + result_validity.Set(row, map_data.validity.RowIsValid(map_data.sel->get_index(row))); + } + + if (args.size() == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr CardinalityBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.size() != 1) { + throw BinderException("Cardinality must have exactly one arguments"); + } + + if (arguments[0]->return_type.id() != LogicalTypeId::MAP) { + throw BinderException("Cardinality can only operate on MAPs"); + } + + bound_function.return_type = LogicalType::UBIGINT; + return make_uniq(bound_function.return_type); +} + +ScalarFunction CardinalityFun::GetFunction() { + ScalarFunction fun({LogicalType::ANY}, LogicalType::UBIGINT, CardinalityFunction, CardinalityBind); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/map.cpp b/src/duckdb/src/core_functions/scalar/map/map.cpp new file mode 100644 index 00000000..b4c5669a --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/map.cpp @@ -0,0 +1,188 @@ +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +// Example: +// source: [1,2,3], expansion_factor: 4 +// target (result): [1,2,3,1,2,3,1,2,3,1,2,3] +static void CreateExpandedVector(const Vector &source, Vector &target, idx_t expansion_factor) { + idx_t count = ListVector::GetListSize(source); + auto &entry = ListVector::GetEntry(source); + + idx_t target_idx = 0; + for (idx_t copy = 0; copy < expansion_factor; copy++) { + for (idx_t key_idx = 0; key_idx < count; key_idx++) { + target.SetValue(target_idx, entry.GetValue(key_idx)); + target_idx++; + } + } + D_ASSERT(target_idx == count * expansion_factor); +} + +static void AlignVectorToReference(const Vector &original, const Vector &reference, idx_t tuple_count, Vector &result) { + auto original_length = ListVector::GetListSize(original); + auto new_length = ListVector::GetListSize(reference); + + Vector expanded_const(ListType::GetChildType(original.GetType()), new_length); + + auto expansion_factor = new_length / original_length; + if (expansion_factor != tuple_count) { + throw InvalidInputException("Error in MAP creation: key list and value list do not align. i.e. different " + "size or incompatible structure"); + } + CreateExpandedVector(original, expanded_const, expansion_factor); + result.Reference(expanded_const); +} + +static bool ListEntriesEqual(Vector &keys, Vector &values, idx_t count) { + auto key_count = ListVector::GetListSize(keys); + auto value_count = ListVector::GetListSize(values); + bool same_vector_type = keys.GetVectorType() == values.GetVectorType(); + + D_ASSERT(keys.GetType().id() == LogicalTypeId::LIST); + D_ASSERT(values.GetType().id() == LogicalTypeId::LIST); + + UnifiedVectorFormat keys_data; + UnifiedVectorFormat values_data; + + keys.ToUnifiedFormat(count, keys_data); + values.ToUnifiedFormat(count, values_data); + + auto keys_entries = UnifiedVectorFormat::GetData(keys_data); + auto values_entries = UnifiedVectorFormat::GetData(values_data); + + if (same_vector_type) { + const auto key_data = keys_data.data; + const auto value_data = values_data.data; + + if (keys.GetVectorType() == VectorType::CONSTANT_VECTOR) { + D_ASSERT(values.GetVectorType() == VectorType::CONSTANT_VECTOR); + // Only need to compare one entry in this case + return memcmp(key_data, value_data, sizeof(list_entry_t)) == 0; + } + + // Fast path if the vector types are equal, can just check if the entries are the same + if (key_count != value_count) { + return false; + } + return memcmp(key_data, value_data, count * sizeof(list_entry_t)) == 0; + } + + // Compare the list_entries one by one + for (idx_t i = 0; i < count; i++) { + auto keys_idx = keys_data.sel->get_index(i); + auto values_idx = values_data.sel->get_index(i); + + if (keys_entries[keys_idx] != values_entries[values_idx]) { + return false; + } + } + return true; +} + +static void MapFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); + + auto &key_vector = MapVector::GetKeys(result); + auto &value_vector = MapVector::GetValues(result); + auto result_data = ListVector::GetData(result); + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + if (args.data.empty()) { + ListVector::SetListSize(result, 0); + result_data->offset = 0; + result_data->length = 0; + result.Verify(args.size()); + return; + } + + bool keys_are_const = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR; + bool values_are_const = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR; + if (!keys_are_const || !values_are_const) { + result.SetVectorType(VectorType::FLAT_VECTOR); + } + + auto key_count = ListVector::GetListSize(args.data[0]); + auto value_count = ListVector::GetListSize(args.data[1]); + auto key_data = ListVector::GetData(args.data[0]); + auto value_data = ListVector::GetData(args.data[1]); + auto src_data = key_data; + + if (keys_are_const && !values_are_const) { + AlignVectorToReference(args.data[0], args.data[1], args.size(), key_vector); + src_data = value_data; + } else if (values_are_const && !keys_are_const) { + AlignVectorToReference(args.data[1], args.data[0], args.size(), value_vector); + } else { + if (!ListEntriesEqual(args.data[0], args.data[1], args.size())) { + throw InvalidInputException("Error in MAP creation: key list and value list do not align. i.e. different " + "size or incompatible structure"); + } + } + + ListVector::SetListSize(result, MaxValue(key_count, value_count)); + + result_data = ListVector::GetData(result); + for (idx_t i = 0; i < args.size(); i++) { + result_data[i] = src_data[i]; + } + + // check whether one of the vectors has already been referenced to an expanded vector in the case of const/non-const + // combination. If not, then referencing is still necessary + if (!(keys_are_const && !values_are_const)) { + key_vector.Reference(ListVector::GetEntry(args.data[0])); + } + if (!(values_are_const && !keys_are_const)) { + value_vector.Reference(ListVector::GetEntry(args.data[1])); + } + + MapVector::MapConversionVerify(result, args.size()); + result.Verify(args.size()); +} + +static unique_ptr MapBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + child_list_t child_types; + + if (arguments.size() != 2 && !arguments.empty()) { + throw Exception("We need exactly two lists for a map"); + } + if (arguments.size() == 2) { + if (arguments[0]->return_type.id() != LogicalTypeId::LIST) { + throw Exception("First argument is not a list"); + } + if (arguments[1]->return_type.id() != LogicalTypeId::LIST) { + throw Exception("Second argument is not a list"); + } + child_types.push_back(make_pair("key", arguments[0]->return_type)); + child_types.push_back(make_pair("value", arguments[1]->return_type)); + } + + if (arguments.empty()) { + auto empty = LogicalType::LIST(LogicalTypeId::SQLNULL); + child_types.push_back(make_pair("key", empty)); + child_types.push_back(make_pair("value", empty)); + } + + bound_function.return_type = + LogicalType::MAP(ListType::GetChildType(child_types[0].second), ListType::GetChildType(child_types[1].second)); + + return make_uniq(bound_function.return_type); +} + +ScalarFunction MapFun::GetFunction() { + //! the arguments and return types are actually set in the binder function + ScalarFunction fun({}, LogicalTypeId::MAP, MapFunction, MapBind); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/map_concat.cpp b/src/duckdb/src/core_functions/scalar/map/map_concat.cpp new file mode 100644 index 00000000..1a6a2702 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/map_concat.cpp @@ -0,0 +1,187 @@ +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/core_functions/scalar/map_functions.hpp" + +namespace duckdb { + +namespace { + +struct MapKeyIndexPair { + MapKeyIndexPair(idx_t map, idx_t key) : map_index(map), key_index(key) { + } + // The index of the map that this key comes from + idx_t map_index; + // The index within the maps key_list + idx_t key_index; +}; + +} // namespace + +vector GetListEntries(vector keys, vector values) { + D_ASSERT(keys.size() == values.size()); + vector entries; + for (idx_t i = 0; i < keys.size(); i++) { + child_list_t children; + children.emplace_back(make_pair("key", std::move(keys[i]))); + children.emplace_back(make_pair("value", std::move(values[i]))); + entries.push_back(Value::STRUCT(std::move(children))); + } + return entries; +} + +static void MapConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { + if (result.GetType().id() == LogicalTypeId::SQLNULL) { + // All inputs are NULL, just return NULL + auto &validity = FlatVector::Validity(result); + validity.SetInvalid(0); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + return; + } + D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); + auto count = args.size(); + + auto map_count = args.ColumnCount(); + vector map_formats(map_count); + for (idx_t i = 0; i < map_count; i++) { + auto &map = args.data[i]; + map.ToUnifiedFormat(count, map_formats[i]); + } + auto result_data = FlatVector::GetData(result); + + for (idx_t i = 0; i < count; i++) { + // Loop through all the maps per list + // we cant do better because all the entries of the child vector have to be contiguous + // so we cant start the next row before we have finished the one before it + auto &result_entry = result_data[i]; + vector index_to_map; + vector keys_list; + for (idx_t map_idx = 0; map_idx < map_count; map_idx++) { + if (args.data[map_idx].GetType().id() == LogicalTypeId::SQLNULL) { + continue; + } + auto &map_format = map_formats[map_idx]; + auto &keys = MapVector::GetKeys(args.data[map_idx]); + + auto index = map_format.sel->get_index(i); + auto entry = UnifiedVectorFormat::GetData(map_format)[index]; + + // Update the list for this row + for (idx_t list_idx = 0; list_idx < entry.length; list_idx++) { + auto key_index = entry.offset + list_idx; + auto key = keys.GetValue(key_index); + auto entry = std::find(keys_list.begin(), keys_list.end(), key); + if (entry == keys_list.end()) { + // Result list does not contain this value yet + keys_list.push_back(key); + index_to_map.emplace_back(map_idx, key_index); + } else { + // Result list already contains this, update where to find the value at + auto distance = std::distance(keys_list.begin(), entry); + auto &mapping = *(index_to_map.begin() + distance); + mapping.key_index = key_index; + mapping.map_index = map_idx; + } + } + } + vector values_list; + D_ASSERT(keys_list.size() == index_to_map.size()); + // Get the values from the mapping + for (auto &mapping : index_to_map) { + auto &map = args.data[mapping.map_index]; + auto &values = MapVector::GetValues(map); + values_list.push_back(values.GetValue(mapping.key_index)); + } + D_ASSERT(values_list.size() == keys_list.size()); + result_entry.offset = ListVector::GetListSize(result); + result_entry.length = values_list.size(); + auto list_entries = GetListEntries(std::move(keys_list), std::move(values_list)); + for (auto &list_entry : list_entries) { + ListVector::PushBack(result, list_entry); + } + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +static bool IsEmptyMap(const LogicalType &map) { + D_ASSERT(map.id() == LogicalTypeId::MAP); + auto &key_type = MapType::KeyType(map); + auto &value_type = MapType::ValueType(map); + return key_type.id() == LogicalType::SQLNULL && value_type.id() == LogicalType::SQLNULL; +} + +static unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + auto arg_count = arguments.size(); + if (arg_count < 2) { + throw InvalidInputException("The provided amount of arguments is incorrect, please provide 2 or more maps"); + } + + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + LogicalType expected = LogicalType::SQLNULL; + + bool is_null = true; + // Check and verify that all the maps are of the same type + for (idx_t i = 0; i < arg_count; i++) { + auto &arg = arguments[i]; + auto &map = arg->return_type; + if (map.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + if (map.id() == LogicalTypeId::SQLNULL) { + // The maps are allowed to be NULL + continue; + } + if (map.id() != LogicalTypeId::MAP) { + throw InvalidInputException("MAP_CONCAT only takes map arguments"); + } + is_null = false; + if (IsEmptyMap(map)) { + // Map is allowed to be empty + continue; + } + + if (expected.id() == LogicalTypeId::SQLNULL) { + expected = map; + } else if (map != expected) { + throw InvalidInputException( + "'value' type of map differs between arguments, expected '%s', found '%s' instead", expected.ToString(), + map.ToString()); + } + } + + if (expected.id() == LogicalTypeId::SQLNULL && is_null == false) { + expected = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL); + } + bound_function.return_type = expected; + return make_uniq(bound_function.return_type); +} + +ScalarFunction MapConcatFun::GetFunction() { + //! the arguments and return types are actually set in the binder function + ScalarFunction fun("map_concat", {}, LogicalTypeId::LIST, MapConcatFunction, MapConcatBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.varargs = LogicalType::ANY; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/map_entries.cpp b/src/duckdb/src/core_functions/scalar/map/map_entries.cpp new file mode 100644 index 00000000..caaeccee --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/map_entries.cpp @@ -0,0 +1,62 @@ +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +// Reverse of map_from_entries +static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + idx_t count = args.size(); + + result.Reinterpret(args.data[0]); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +static unique_ptr MapEntriesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + child_list_t child_types; + + if (arguments.size() != 1) { + throw InvalidInputException("Too many arguments provided, only expecting a single map"); + } + auto &map = arguments[0]->return_type; + + if (map.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + if (map.id() != LogicalTypeId::MAP) { + throw InvalidInputException("The provided argument is not a map"); + } + auto &key_type = MapType::KeyType(map); + auto &value_type = MapType::ValueType(map); + + child_types.push_back(make_pair("key", key_type)); + child_types.push_back(make_pair("value", value_type)); + + auto row_type = LogicalType::STRUCT(child_types); + + bound_function.return_type = LogicalType::LIST(row_type); + return make_uniq(bound_function.return_type); +} + +ScalarFunction MapEntriesFun::GetFunction() { + //! the arguments and return types are actually set in the binder function + ScalarFunction fun({}, LogicalTypeId::LIST, MapEntriesFunction, MapEntriesBind); + fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.varargs = LogicalType::ANY; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/map_extract.cpp b/src/duckdb/src/core_functions/scalar/map/map_extract.cpp new file mode 100644 index 00000000..2986a7f6 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/map_extract.cpp @@ -0,0 +1,148 @@ +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +struct MapKeyArgFunctor { + // MAP is a LIST(STRUCT(K,V)) + // meaning the MAP itself is a List, but the child vector that we're interested in (the keys) + // are a level deeper than the initial child vector + + static Vector &GetList(Vector &map) { + return map; + } + static idx_t GetListSize(Vector &map) { + return ListVector::GetListSize(map); + } + static Vector &GetEntry(Vector &map) { + return MapVector::GetKeys(map); + } +}; + +void FillResult(Vector &map, Vector &offsets, Vector &result, idx_t count) { + UnifiedVectorFormat map_data; + map.ToUnifiedFormat(count, map_data); + + UnifiedVectorFormat offset_data; + offsets.ToUnifiedFormat(count, offset_data); + + auto result_data = FlatVector::GetData(result); + auto entry_count = ListVector::GetListSize(map); + auto &values_entries = MapVector::GetValues(map); + UnifiedVectorFormat values_entry_data; + // Note: this vector can have a different size than the map + values_entries.ToUnifiedFormat(entry_count, values_entry_data); + + for (idx_t row = 0; row < count; row++) { + idx_t offset_idx = offset_data.sel->get_index(row); + auto offset = UnifiedVectorFormat::GetData(offset_data)[offset_idx]; + + // Get the current size of the list, for the offset + idx_t current_offset = ListVector::GetListSize(result); + if (!offset_data.validity.RowIsValid(offset_idx) || !offset) { + // Set the entry data for this result row + auto &entry = result_data[row]; + entry.length = 0; + entry.offset = current_offset; + continue; + } + // All list indices start at 1, reduce by 1 to get the actual index + offset--; + + // Get the 'values' list entry corresponding to the offset + idx_t value_index = map_data.sel->get_index(row); + auto &value_list_entry = UnifiedVectorFormat::GetData(map_data)[value_index]; + + // Add the values to the result + idx_t list_offset = value_list_entry.offset + offset; + // All keys are unique, only one will ever match + idx_t length = 1; + ListVector::Append(result, values_entries, length + list_offset, list_offset); + + // Set the entry data for this result row + auto &entry = result_data[row]; + entry.length = length; + entry.offset = current_offset; + } +} + +static void MapExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.data.size() == 2); + D_ASSERT(args.data[0].GetType().id() == LogicalTypeId::MAP); + result.SetVectorType(VectorType::FLAT_VECTOR); + + idx_t tuple_count = args.size(); + // Optimization: because keys are not allowed to be NULL, we can early-out + if (args.data[1].GetType().id() == LogicalTypeId::SQLNULL) { + //! We don't need to look through the map if the 'key' to look for is NULL + ListVector::SetListSize(result, 0); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto list_data = ConstantVector::GetData(result); + list_data->offset = 0; + list_data->length = 0; + result.Verify(tuple_count); + return; + } + + auto &map = args.data[0]; + auto &key = args.data[1]; + + UnifiedVectorFormat map_data; + + // Create the chunk we'll feed to ListPosition + DataChunk list_position_chunk; + vector chunk_types; + chunk_types.reserve(2); + chunk_types.push_back(map.GetType()); + chunk_types.push_back(key.GetType()); + list_position_chunk.InitializeEmpty(chunk_types.begin(), chunk_types.end()); + + // Populate it with the map keys list and the key vector + list_position_chunk.data[0].Reference(map); + list_position_chunk.data[1].Reference(key); + list_position_chunk.SetCardinality(tuple_count); + + Vector position_vector(LogicalType::LIST(LogicalType::INTEGER), tuple_count); + // We can pass around state as it's not used by ListPositionFunction anyways + ListContainsOrPosition(list_position_chunk, position_vector); + + FillResult(map, position_vector, result, tuple_count); + + if (tuple_count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + + result.Verify(tuple_count); +} + +static unique_ptr MapExtractBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.size() != 2) { + throw BinderException("MAP_EXTRACT must have exactly two arguments"); + } + if (arguments[0]->return_type.id() != LogicalTypeId::MAP) { + throw BinderException("MAP_EXTRACT can only operate on MAPs"); + } + auto &value_type = MapType::ValueType(arguments[0]->return_type); + + //! Here we have to construct the List Type that will be returned + bound_function.return_type = LogicalType::LIST(value_type); + auto key_type = MapType::KeyType(arguments[0]->return_type); + if (key_type.id() != LogicalTypeId::SQLNULL && arguments[1]->return_type.id() != LogicalTypeId::SQLNULL) { + bound_function.arguments[1] = MapType::KeyType(arguments[0]->return_type); + } + return make_uniq(value_type); +} + +ScalarFunction MapExtractFun::GetFunction() { + ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractFunction, MapExtractBind); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/map_from_entries.cpp b/src/duckdb/src/core_functions/scalar/map/map_from_entries.cpp new file mode 100644 index 00000000..be79503e --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/map_from_entries.cpp @@ -0,0 +1,60 @@ +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void MapFromEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto count = args.size(); + + result.Reinterpret(args.data[0]); + + MapVector::MapConversionVerify(result, count); + result.Verify(count); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr MapFromEntriesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.size() != 1) { + throw InvalidInputException("The input argument must be a list of structs."); + } + auto &list = arguments[0]->return_type; + + if (list.id() == LogicalTypeId::UNKNOWN) { + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + if (list.id() != LogicalTypeId::LIST) { + throw InvalidInputException("The provided argument is not a list of structs"); + } + auto &elem_type = ListType::GetChildType(list); + if (elem_type.id() != LogicalTypeId::STRUCT) { + throw InvalidInputException("The elements of the list must be structs"); + } + auto &children = StructType::GetChildTypes(elem_type); + if (children.size() != 2) { + throw InvalidInputException("The provided struct type should only contain 2 fields, a key and a value"); + } + + bound_function.return_type = LogicalType::MAP(elem_type); + return make_uniq(bound_function.return_type); +} + +ScalarFunction MapFromEntriesFun::GetFunction() { + //! the arguments and return types are actually set in the binder function + ScalarFunction fun({}, LogicalTypeId::MAP, MapFromEntriesFunction, MapFromEntriesBind); + fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.varargs = LogicalType::ANY; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/map_keys_values.cpp b/src/duckdb/src/core_functions/scalar/map/map_keys_values.cpp new file mode 100644 index 00000000..6c1e8efb --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/map_keys_values.cpp @@ -0,0 +1,98 @@ +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void MapKeyValueFunction(DataChunk &args, ExpressionState &state, Vector &result, + Vector &(*get_child_vector)(Vector &)) { + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto count = args.size(); + + auto &map = args.data[0]; + D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); + auto child = get_child_vector(map); + + auto &entries = ListVector::GetEntry(result); + entries.Reference(child); + + UnifiedVectorFormat map_data; + map.ToUnifiedFormat(count, map_data); + + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + FlatVector::SetData(result, map_data.data); + FlatVector::SetValidity(result, map_data.validity); + auto list_size = ListVector::GetListSize(map); + ListVector::SetListSize(result, list_size); + if (map.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + result.Slice(*map_data.sel, count); + } + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +static void MapKeysFunction(DataChunk &args, ExpressionState &state, Vector &result) { + MapKeyValueFunction(args, state, result, MapVector::GetKeys); +} + +static void MapValuesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + MapKeyValueFunction(args, state, result, MapVector::GetValues); +} + +static unique_ptr MapKeyValueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments, + const LogicalType &(*type_func)(const LogicalType &)) { + if (arguments.size() != 1) { + throw InvalidInputException("Too many arguments provided, only expecting a single map"); + } + auto &map = arguments[0]->return_type; + + if (map.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + if (map.id() != LogicalTypeId::MAP) { + throw InvalidInputException("The provided argument is not a map"); + } + + auto &type = type_func(map); + + bound_function.return_type = LogicalType::LIST(type); + return make_uniq(bound_function.return_type); +} + +static unique_ptr MapKeysBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return MapKeyValueBind(context, bound_function, arguments, MapType::KeyType); +} + +static unique_ptr MapValuesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return MapKeyValueBind(context, bound_function, arguments, MapType::ValueType); +} + +ScalarFunction MapKeysFun::GetFunction() { + //! the arguments and return types are actually set in the binder function + ScalarFunction fun({}, LogicalTypeId::LIST, MapKeysFunction, MapKeysBind); + fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.varargs = LogicalType::ANY; + return fun; +} + +ScalarFunction MapValuesFun::GetFunction() { + ScalarFunction fun({}, LogicalTypeId::LIST, MapValuesFunction, MapValuesBind); + fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.varargs = LogicalType::ANY; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/math/numeric.cpp b/src/duckdb/src/core_functions/scalar/math/numeric.cpp new file mode 100644 index 00000000..19c841b2 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/math/numeric.cpp @@ -0,0 +1,1280 @@ +#include "duckdb/core_functions/scalar/math_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/operator/abs.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/types/bit.hpp" +#include +#include + +namespace duckdb { + +template +static scalar_function_t GetScalarIntegerUnaryFunctionFixedReturn(const LogicalType &type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::UnaryFunction; + break; + default: + throw NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunctionFixedReturn"); + } + return function; +} + +//===--------------------------------------------------------------------===// +// nextafter +//===--------------------------------------------------------------------===// +struct NextAfterOperator { + template + static inline TR Operation(TA base, TB exponent) { + throw NotImplementedException("Unimplemented type for NextAfter Function"); + } + + template + static inline double Operation(double input, double approximate_to) { + return nextafter(input, approximate_to); + } + template + static inline float Operation(float input, float approximate_to) { + return nextafterf(input, approximate_to); + } +}; + +ScalarFunctionSet NextAfterFun::GetFunctions() { + ScalarFunctionSet next_after_fun; + next_after_fun.AddFunction( + ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::BinaryFunction)); + next_after_fun.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, + ScalarFunction::BinaryFunction)); + return next_after_fun; +} + +//===--------------------------------------------------------------------===// +// abs +//===--------------------------------------------------------------------===// +static unique_ptr PropagateAbsStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 1); + // can only propagate stats if the children have stats + auto &lstats = child_stats[0]; + Value new_min, new_max; + bool potential_overflow = true; + if (NumericStats::HasMinMax(lstats)) { + switch (expr.return_type.InternalType()) { + case PhysicalType::INT8: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + case PhysicalType::INT16: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + case PhysicalType::INT32: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + case PhysicalType::INT64: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + default: + return nullptr; + } + } + if (potential_overflow) { + new_min = Value(expr.return_type); + new_max = Value(expr.return_type); + } else { + // no potential overflow + + // compute stats + auto current_min = NumericStats::Min(lstats).GetValue(); + auto current_max = NumericStats::Max(lstats).GetValue(); + + int64_t min_val, max_val; + + if (current_min < 0 && current_max < 0) { + // if both min and max are below zero, then min=abs(cur_max) and max=abs(cur_min) + min_val = AbsValue(current_max); + max_val = AbsValue(current_min); + } else if (current_min < 0) { + D_ASSERT(current_max >= 0); + // if min is below zero and max is above 0, then min=0 and max=max(cur_max, abs(cur_min)) + min_val = 0; + max_val = MaxValue(AbsValue(current_min), current_max); + } else { + // if both current_min and current_max are > 0, then the abs is a no-op and can be removed entirely + *input.expr_ptr = std::move(input.expr.children[0]); + return child_stats[0].ToUnique(); + } + new_min = Value::Numeric(expr.return_type, min_val); + new_max = Value::Numeric(expr.return_type, max_val); + expr.function.function = ScalarFunction::GetScalarUnaryFunction(expr.return_type); + } + auto stats = NumericStats::CreateEmpty(expr.return_type); + NumericStats::SetMin(stats, new_min); + NumericStats::SetMax(stats, new_max); + stats.CopyValidity(lstats); + return stats.ToUnique(); +} + +template +unique_ptr DecimalUnaryOpBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); + break; + case PhysicalType::INT32: + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); + break; + case PhysicalType::INT64: + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); + break; + default: + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); + break; + } + bound_function.arguments[0] = decimal_type; + bound_function.return_type = decimal_type; + return nullptr; +} + +ScalarFunctionSet AbsOperatorFun::GetFunctions() { + ScalarFunctionSet abs; + for (auto &type : LogicalType::Numeric()) { + switch (type.id()) { + case LogicalTypeId::DECIMAL: + abs.AddFunction(ScalarFunction({type}, type, nullptr, DecimalUnaryOpBind)); + break; + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: { + ScalarFunction func({type}, type, ScalarFunction::GetScalarUnaryFunction(type)); + func.statistics = PropagateAbsStats; + abs.AddFunction(func); + break; + } + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::NopFunction)); + break; + default: + abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::GetScalarUnaryFunction(type))); + break; + } + } + return abs; +} + +//===--------------------------------------------------------------------===// +// bit_count +//===--------------------------------------------------------------------===// +struct BitCntOperator { + template + static inline TR Operation(TA input) { + using TU = typename std::make_unsigned::type; + TR count = 0; + for (auto value = TU(input); value; ++count) { + value &= (value - 1); + } + return count; + } +}; + +struct HugeIntBitCntOperator { + template + static inline TR Operation(TA input) { + using TU = typename std::make_unsigned::type; + TR count = 0; + + for (auto value = TU(input.upper); value; ++count) { + value &= (value - 1); + } + for (auto value = TU(input.lower); value; ++count) { + value &= (value - 1); + } + return count; + } +}; + +struct BitStringBitCntOperator { + template + static inline TR Operation(TA input) { + TR count = Bit::BitCount(input); + return count; + } +}; + +ScalarFunctionSet BitCountFun::GetFunctions() { + ScalarFunctionSet functions; + functions.AddFunction(ScalarFunction({LogicalType::TINYINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::SMALLINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::INTEGER}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + return functions; +} + +//===--------------------------------------------------------------------===// +// sign +//===--------------------------------------------------------------------===// +struct SignOperator { + template + static TR Operation(TA input) { + if (input == TA(0)) { + return 0; + } else if (input > TA(0)) { + return 1; + } else { + return -1; + } + } +}; + +template <> +int8_t SignOperator::Operation(float input) { + if (input == 0 || Value::IsNan(input)) { + return 0; + } else if (input > 0) { + return 1; + } else { + return -1; + } +} + +template <> +int8_t SignOperator::Operation(double input) { + if (input == 0 || Value::IsNan(input)) { + return 0; + } else if (input > 0) { + return 1; + } else { + return -1; + } +} + +ScalarFunctionSet SignFun::GetFunctions() { + ScalarFunctionSet sign; + for (auto &type : LogicalType::Numeric()) { + if (type.id() == LogicalTypeId::DECIMAL) { + continue; + } else { + sign.AddFunction( + ScalarFunction({type}, LogicalType::TINYINT, + ScalarFunction::GetScalarUnaryFunctionFixedReturn(type))); + } + } + return sign; +} + +//===--------------------------------------------------------------------===// +// ceil +//===--------------------------------------------------------------------===// +struct CeilOperator { + template + static inline TR Operation(TA left) { + return std::ceil(left); + } +}; + +template +static void GenericRoundFunctionDecimal(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + OP::template Operation(input, DecimalType::GetScale(func_expr.children[0]->return_type), result); +} + +template +unique_ptr BindGenericRoundFunctionDecimal(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // ceil essentially removes the scale + auto &decimal_type = arguments[0]->return_type; + auto scale = DecimalType::GetScale(decimal_type); + auto width = DecimalType::GetWidth(decimal_type); + if (scale == 0) { + bound_function.function = ScalarFunction::NopFunction; + } else { + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = GenericRoundFunctionDecimal; + break; + case PhysicalType::INT32: + bound_function.function = GenericRoundFunctionDecimal; + break; + case PhysicalType::INT64: + bound_function.function = GenericRoundFunctionDecimal; + break; + default: + bound_function.function = GenericRoundFunctionDecimal; + break; + } + } + bound_function.arguments[0] = decimal_type; + bound_function.return_type = LogicalType::DECIMAL(width, 0); + return nullptr; +} + +struct CeilDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + // below 0 we floor the number (e.g. -10.5 -> -10) + return input / power_of_ten; + } else { + // above 0 we ceil the number + return ((input - 1) / power_of_ten) + 1; + } + }); + } +}; + +ScalarFunctionSet CeilFun::GetFunctions() { + ScalarFunctionSet ceil; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t func = nullptr; + bind_scalar_function_t bind_func = nullptr; + if (type.IsIntegral()) { + // no ceil for integral numbers + continue; + } + switch (type.id()) { + case LogicalTypeId::FLOAT: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + break; + default: + throw InternalException("Unimplemented numeric type for function \"ceil\""); + } + ceil.AddFunction(ScalarFunction({type}, type, func, bind_func)); + } + return ceil; +} + +//===--------------------------------------------------------------------===// +// floor +//===--------------------------------------------------------------------===// +struct FloorOperator { + template + static inline TR Operation(TA left) { + return std::floor(left); + } +}; + +struct FloorDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + // below 0 we ceil the number (e.g. -10.5 -> -11) + return ((input + 1) / power_of_ten) - 1; + } else { + // above 0 we floor the number + return input / power_of_ten; + } + }); + } +}; + +ScalarFunctionSet FloorFun::GetFunctions() { + ScalarFunctionSet floor; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t func = nullptr; + bind_scalar_function_t bind_func = nullptr; + if (type.IsIntegral()) { + // no floor for integral numbers + continue; + } + switch (type.id()) { + case LogicalTypeId::FLOAT: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + break; + default: + throw InternalException("Unimplemented numeric type for function \"floor\""); + } + floor.AddFunction(ScalarFunction({type}, type, func, bind_func)); + } + return floor; +} + +//===--------------------------------------------------------------------===// +// trunc +//===--------------------------------------------------------------------===// +struct TruncOperator { + // Integer truncation is a NOP + template + static inline TR Operation(TA left) { + return std::trunc(left); + } +}; + +struct TruncDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + // Always floor + return (input / power_of_ten); + }); + } +}; + +ScalarFunctionSet TruncFun::GetFunctions() { + ScalarFunctionSet trunc; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t func = nullptr; + bind_scalar_function_t bind_func = nullptr; + // Truncation of integers gets generated by some tools (e.g., Tableau/JDBC:Postgres) + switch (type.id()) { + case LogicalTypeId::FLOAT: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + break; + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + func = ScalarFunction::NopFunction; + break; + default: + throw InternalException("Unimplemented numeric type for function \"trunc\""); + } + trunc.AddFunction(ScalarFunction({type}, type, func, bind_func)); + } + return trunc; +} + +//===--------------------------------------------------------------------===// +// round +//===--------------------------------------------------------------------===// +struct RoundOperatorPrecision { + template + static inline TR Operation(TA input, TB precision) { + double rounded_value; + if (precision < 0) { + double modifier = std::pow(10, -TA(precision)); + rounded_value = (std::round(input / modifier)) * modifier; + if (std::isinf(rounded_value) || std::isnan(rounded_value)) { + return 0; + } + } else { + double modifier = std::pow(10, TA(precision)); + rounded_value = (std::round(input * modifier)) / modifier; + if (std::isinf(rounded_value) || std::isnan(rounded_value)) { + return input; + } + } + return rounded_value; + } +}; + +struct RoundOperator { + template + static inline TR Operation(TA input) { + double rounded_value = round(input); + if (std::isinf(rounded_value) || std::isnan(rounded_value)) { + return input; + } + return rounded_value; + } +}; + +struct RoundDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]; + T addition = power_of_ten / 2; + // regular round rounds towards the nearest number + // in case of a tie we round away from zero + // i.e. -10.5 -> -11, 10.5 -> 11 + // we implement this by adding (positive) or subtracting (negative) 0.5 + // and then flooring the number + // e.g. 10.5 + 0.5 = 11, floor(11) = 11 + // 10.4 + 0.5 = 10.9, floor(10.9) = 10 + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; + } + return input / power_of_ten; + }); + } +}; + +struct RoundPrecisionFunctionData : public FunctionData { + explicit RoundPrecisionFunctionData(int32_t target_scale) : target_scale(target_scale) { + } + + int32_t target_scale; + + unique_ptr Copy() const override { + return make_uniq(target_scale); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return target_scale == other.target_scale; + } +}; + +template +static void DecimalRoundNegativePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); + auto width = DecimalType::GetWidth(func_expr.children[0]->return_type); + if (info.target_scale <= -int32_t(width)) { + // scale too big for width + result.SetVectorType(VectorType::CONSTANT_VECTOR); + result.SetValue(0, Value::INTEGER(0)); + return; + } + T divide_power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale + source_scale]; + T multiply_power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale]; + T addition = divide_power_of_ten / 2; + + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; + } + return input / divide_power_of_ten * multiply_power_of_ten; + }); +} + +template +static void DecimalRoundPositivePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); + T power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]; + T addition = power_of_ten / 2; + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; + } + return input / power_of_ten; + }); +} + +unique_ptr BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto &decimal_type = arguments[0]->return_type; + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); + } + Value val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]).DefaultCastAs(LogicalType::INTEGER); + if (val.IsNull()) { + throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); + } + // our new precision becomes the round value + // e.g. ROUND(DECIMAL(18,3), 1) -> DECIMAL(18,1) + // but ONLY if the round value is positive + // if it is negative the scale becomes zero + // i.e. ROUND(DECIMAL(18,3), -1) -> DECIMAL(18,0) + int32_t round_value = IntegerValue::Get(val); + uint8_t target_scale; + auto width = DecimalType::GetWidth(decimal_type); + auto scale = DecimalType::GetScale(decimal_type); + if (round_value < 0) { + target_scale = 0; + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + case PhysicalType::INT32: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + case PhysicalType::INT64: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + default: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + } + } else { + if (round_value >= (int32_t)scale) { + // if round_value is bigger than or equal to scale we do nothing + bound_function.function = ScalarFunction::NopFunction; + target_scale = scale; + } else { + target_scale = round_value; + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + case PhysicalType::INT32: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + case PhysicalType::INT64: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + default: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + } + } + } + bound_function.arguments[0] = decimal_type; + bound_function.return_type = LogicalType::DECIMAL(width, target_scale); + return make_uniq(round_value); +} + +ScalarFunctionSet RoundFun::GetFunctions() { + ScalarFunctionSet round; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t round_prec_func = nullptr; + scalar_function_t round_func = nullptr; + bind_scalar_function_t bind_func = nullptr; + bind_scalar_function_t bind_prec_func = nullptr; + if (type.IsIntegral()) { + // no round for integral numbers + continue; + } + switch (type.id()) { + case LogicalTypeId::FLOAT: + round_func = ScalarFunction::UnaryFunction; + round_prec_func = ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::DOUBLE: + round_func = ScalarFunction::UnaryFunction; + round_prec_func = ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + bind_prec_func = BindDecimalRoundPrecision; + break; + default: + throw InternalException("Unimplemented numeric type for function \"floor\""); + } + round.AddFunction(ScalarFunction({type}, type, round_func, bind_func)); + round.AddFunction(ScalarFunction({type, LogicalType::INTEGER}, type, round_prec_func, bind_prec_func)); + } + return round; +} + +//===--------------------------------------------------------------------===// +// exp +//===--------------------------------------------------------------------===// +struct ExpOperator { + template + static inline TR Operation(TA left) { + return std::exp(left); + } +}; + +ScalarFunction ExpFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// pow +//===--------------------------------------------------------------------===// +struct PowOperator { + template + static inline TR Operation(TA base, TB exponent) { + return std::pow(base, exponent); + } +}; + +ScalarFunction PowOperatorFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::BinaryFunction); +} + +//===--------------------------------------------------------------------===// +// sqrt +//===--------------------------------------------------------------------===// +struct SqrtOperator { + template + static inline TR Operation(TA input) { + if (input < 0) { + throw OutOfRangeException("cannot take square root of a negative number"); + } + return std::sqrt(input); + } +}; + +ScalarFunction SqrtFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// cbrt +//===--------------------------------------------------------------------===// +struct CbRtOperator { + template + static inline TR Operation(TA left) { + return std::cbrt(left); + } +}; + +ScalarFunction CbrtFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// ln +//===--------------------------------------------------------------------===// + +struct LnOperator { + template + static inline TR Operation(TA input) { + if (input < 0) { + throw OutOfRangeException("cannot take logarithm of a negative number"); + } + if (input == 0) { + throw OutOfRangeException("cannot take logarithm of zero"); + } + return std::log(input); + } +}; + +ScalarFunction LnFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// log +//===--------------------------------------------------------------------===// +struct Log10Operator { + template + static inline TR Operation(TA input) { + if (input < 0) { + throw OutOfRangeException("cannot take logarithm of a negative number"); + } + if (input == 0) { + throw OutOfRangeException("cannot take logarithm of zero"); + } + return std::log10(input); + } +}; + +ScalarFunction Log10Fun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// log2 +//===--------------------------------------------------------------------===// +struct Log2Operator { + template + static inline TR Operation(TA input) { + if (input < 0) { + throw OutOfRangeException("cannot take logarithm of a negative number"); + } + if (input == 0) { + throw OutOfRangeException("cannot take logarithm of zero"); + } + return std::log2(input); + } +}; + +ScalarFunction Log2Fun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// pi +//===--------------------------------------------------------------------===// +static void PiFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 0); + Value pi_value = Value::DOUBLE(PI); + result.Reference(pi_value); +} + +ScalarFunction PiFun::GetFunction() { + return ScalarFunction({}, LogicalType::DOUBLE, PiFunction); +} + +//===--------------------------------------------------------------------===// +// degrees +//===--------------------------------------------------------------------===// +struct DegreesOperator { + template + static inline TR Operation(TA left) { + return left * (180 / PI); + } +}; + +ScalarFunction DegreesFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// radians +//===--------------------------------------------------------------------===// +struct RadiansOperator { + template + static inline TR Operation(TA left) { + return left * (PI / 180); + } +}; + +ScalarFunction RadiansFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// isnan +//===--------------------------------------------------------------------===// +struct IsNanOperator { + template + static inline TR Operation(TA input) { + return Value::IsNan(input); + } +}; + +ScalarFunctionSet IsNanFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// signbit +//===--------------------------------------------------------------------===// +struct SignBitOperator { + template + static inline TR Operation(TA input) { + return std::signbit(input); + } +}; + +ScalarFunctionSet SignBitFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// isinf +//===--------------------------------------------------------------------===// +struct IsInfiniteOperator { + template + static inline TR Operation(TA input) { + return !Value::IsNan(input) && !Value::IsFinite(input); + } +}; + +template <> +bool IsInfiniteOperator::Operation(date_t input) { + return !Value::IsFinite(input); +} + +template <> +bool IsInfiniteOperator::Operation(timestamp_t input) { + return !Value::IsFinite(input); +} + +ScalarFunctionSet IsInfiniteFun::GetFunctions() { + ScalarFunctionSet funcs("isinf"); + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// isfinite +//===--------------------------------------------------------------------===// +struct IsFiniteOperator { + template + static inline TR Operation(TA input) { + return Value::IsFinite(input); + } +}; + +ScalarFunctionSet IsFiniteFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// sin +//===--------------------------------------------------------------------===// +template +struct NoInfiniteDoubleWrapper { + template + static RESULT_TYPE Operation(INPUT_TYPE input) { + if (DUCKDB_UNLIKELY(!Value::IsFinite(input))) { + if (Value::IsNan(input)) { + return input; + } + throw OutOfRangeException("input value %lf is out of range for numeric function", input); + } + return OP::template Operation(input); + } +}; + +struct SinOperator { + template + static inline TR Operation(TA input) { + return std::sin(input); + } +}; + +ScalarFunction SinFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); +} + +//===--------------------------------------------------------------------===// +// cos +//===--------------------------------------------------------------------===// +struct CosOperator { + template + static inline TR Operation(TA input) { + return (double)std::cos(input); + } +}; + +ScalarFunction CosFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); +} + +//===--------------------------------------------------------------------===// +// tan +//===--------------------------------------------------------------------===// +struct TanOperator { + template + static inline TR Operation(TA input) { + return (double)std::tan(input); + } +}; + +ScalarFunction TanFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); +} + +//===--------------------------------------------------------------------===// +// asin +//===--------------------------------------------------------------------===// +struct ASinOperator { + template + static inline TR Operation(TA input) { + if (input < -1 || input > 1) { + throw Exception("ASIN is undefined outside [-1,1]"); + } + return (double)std::asin(input); + } +}; + +ScalarFunction AsinFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); +} + +//===--------------------------------------------------------------------===// +// atan +//===--------------------------------------------------------------------===// +struct ATanOperator { + template + static inline TR Operation(TA input) { + return (double)std::atan(input); + } +}; + +ScalarFunction AtanFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// atan2 +//===--------------------------------------------------------------------===// +struct ATan2 { + template + static inline TR Operation(TA left, TB right) { + return (double)std::atan2(left, right); + } +}; + +ScalarFunction Atan2Fun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::BinaryFunction); +} + +//===--------------------------------------------------------------------===// +// acos +//===--------------------------------------------------------------------===// +struct ACos { + template + static inline TR Operation(TA input) { + return (double)std::acos(input); + } +}; + +ScalarFunction AcosFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); +} + +//===--------------------------------------------------------------------===// +// cot +//===--------------------------------------------------------------------===// +struct CotOperator { + template + static inline TR Operation(TA input) { + return 1.0 / (double)std::tan(input); + } +}; + +ScalarFunction CotFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); +} + +//===--------------------------------------------------------------------===// +// gamma +//===--------------------------------------------------------------------===// +struct GammaOperator { + template + static inline TR Operation(TA input) { + if (input == 0) { + throw OutOfRangeException("cannot take gamma of zero"); + } + return std::tgamma(input); + } +}; + +ScalarFunction GammaFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// gamma +//===--------------------------------------------------------------------===// +struct LogGammaOperator { + template + static inline TR Operation(TA input) { + if (input == 0) { + throw OutOfRangeException("cannot take log gamma of zero"); + } + return std::lgamma(input); + } +}; + +ScalarFunction LogGammaFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// factorial(), ! +//===--------------------------------------------------------------------===// +struct FactorialOperator { + template + static inline TR Operation(TA left) { + TR ret = 1; + for (TA i = 2; i <= left; i++) { + ret *= i; + } + return ret; + } +}; + +ScalarFunction FactorialOperatorFun::GetFunction() { + return ScalarFunction({LogicalType::INTEGER}, LogicalType::HUGEINT, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// even +//===--------------------------------------------------------------------===// +struct EvenOperator { + template + static inline TR Operation(TA left) { + double value; + if (left >= 0) { + value = std::ceil(left); + } else { + value = std::ceil(-left); + value = -value; + } + if (std::floor(value / 2) * 2 != value) { + if (left >= 0) { + return value += 1; + } + return value -= 1; + } + return value; + } +}; + +ScalarFunction EvenFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// gcd +//===--------------------------------------------------------------------===// + +// should be replaced with std::gcd in a newer C++ standard +template +TA GreatestCommonDivisor(TA left, TA right) { + TA a = left; + TA b = right; + + // This protects the following modulo operations from a corner case, + // where we would get a runtime error due to an integer overflow. + if ((left == NumericLimits::Minimum() && right == -1) || + (left == -1 && right == NumericLimits::Minimum())) { + return 1; + } + + while (true) { + if (a == 0) { + return TryAbsOperator::Operation(b); + } + b %= a; + + if (b == 0) { + return TryAbsOperator::Operation(a); + } + a %= b; + } +} + +struct GreatestCommonDivisorOperator { + template + static inline TR Operation(TA left, TB right) { + return GreatestCommonDivisor(left, right); + } +}; + +ScalarFunctionSet GreatestCommonDivisorFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction)); + funcs.AddFunction( + ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, + ScalarFunction::BinaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// lcm +//===--------------------------------------------------------------------===// + +// should be replaced with std::lcm in a newer C++ standard +struct LeastCommonMultipleOperator { + template + static inline TR Operation(TA left, TB right) { + if (left == 0 || right == 0) { + return 0; + } + TR result; + if (!TryMultiplyOperator::Operation(left, right / GreatestCommonDivisor(left, right), result)) { + throw OutOfRangeException("lcm value is out of range"); + } + return TryAbsOperator::Operation(result); + } +}; + +ScalarFunctionSet LeastCommonMultipleFun::GetFunctions() { + ScalarFunctionSet funcs; + + funcs.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction)); + funcs.AddFunction( + ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, + ScalarFunction::BinaryFunction)); + return funcs; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/operators/bitwise.cpp b/src/duckdb/src/core_functions/scalar/operators/bitwise.cpp new file mode 100644 index 00000000..f1604aa6 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/operators/bitwise.cpp @@ -0,0 +1,307 @@ +#include "duckdb/core_functions/scalar/operators_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/bit.hpp" + +namespace duckdb { + +template +static scalar_function_t GetScalarIntegerUnaryFunction(const LogicalType &type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UTINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::USMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UINTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UBIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::UnaryFunction; + break; + default: + throw NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunction"); + } + return function; +} + +template +static scalar_function_t GetScalarIntegerBinaryFunction(const LogicalType &type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::UTINYINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::USMALLINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::UINTEGER: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::UBIGINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::BinaryFunction; + break; + default: + throw NotImplementedException("Unimplemented type for GetScalarIntegerBinaryFunction"); + } + return function; +} + +//===--------------------------------------------------------------------===// +// & [bitwise_and] +//===--------------------------------------------------------------------===// +struct BitwiseANDOperator { + template + static inline TR Operation(TA left, TB right) { + return left & right; + } +}; + +static void BitwiseANDOperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { + string_t target = StringVector::EmptyString(result, rhs.GetSize()); + + Bit::BitwiseAnd(rhs, lhs, target); + return target; + }); +} + +ScalarFunctionSet BitwiseAndFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseANDOperation)); + return functions; +} + +//===--------------------------------------------------------------------===// +// | [bitwise_or] +//===--------------------------------------------------------------------===// +struct BitwiseOROperator { + template + static inline TR Operation(TA left, TB right) { + return left | right; + } +}; + +static void BitwiseOROperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { + string_t target = StringVector::EmptyString(result, rhs.GetSize()); + + Bit::BitwiseOr(rhs, lhs, target); + return target; + }); +} + +ScalarFunctionSet BitwiseOrFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseOROperation)); + return functions; +} + +//===--------------------------------------------------------------------===// +// # [bitwise_xor] +//===--------------------------------------------------------------------===// +struct BitwiseXOROperator { + template + static inline TR Operation(TA left, TB right) { + return left ^ right; + } +}; + +static void BitwiseXOROperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { + string_t target = StringVector::EmptyString(result, rhs.GetSize()); + + Bit::BitwiseXor(rhs, lhs, target); + return target; + }); +} + +ScalarFunctionSet BitwiseXorFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseXOROperation)); + return functions; +} + +//===--------------------------------------------------------------------===// +// ~ [bitwise_not] +//===--------------------------------------------------------------------===// +struct BitwiseNotOperator { + template + static inline TR Operation(TA input) { + return ~input; + } +}; + +static void BitwiseNOTOperation(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { + string_t target = StringVector::EmptyString(result, input.GetSize()); + + Bit::BitwiseNot(input, target); + return target; + }); +} + +ScalarFunctionSet BitwiseNotFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction(ScalarFunction({type}, type, GetScalarIntegerUnaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIT, BitwiseNOTOperation)); + return functions; +} + +//===--------------------------------------------------------------------===// +// << [bitwise_left_shift] +//===--------------------------------------------------------------------===// + +struct BitwiseShiftLeftOperator { + template + static inline TR Operation(TA input, TB shift) { + TA max_shift = TA(sizeof(TA) * 8); + if (input < 0) { + throw OutOfRangeException("Cannot left-shift negative number %s", NumericHelper::ToString(input)); + } + if (shift < 0) { + throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); + } + if (shift >= max_shift) { + if (input == 0) { + return 0; + } + throw OutOfRangeException("Left-shift value %s is out of range", NumericHelper::ToString(shift)); + } + if (shift == 0) { + return input; + } + TA max_value = (TA(1) << (max_shift - shift - 1)); + if (input >= max_value) { + throw OutOfRangeException("Overflow in left shift (%s << %s)", NumericHelper::ToString(input), + NumericHelper::ToString(shift)); + } + return input << shift; + } +}; + +static void BitwiseShiftLeftOperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { + int32_t max_shift = Bit::BitLength(input); + if (shift == 0) { + return input; + } + if (shift < 0) { + throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); + } + string_t target = StringVector::EmptyString(result, input.GetSize()); + + if (shift >= max_shift) { + Bit::SetEmptyBitString(target, input); + return target; + } + Bit::LeftShift(input, shift, target); + return target; + }); +} + +ScalarFunctionSet LeftShiftFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction( + ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftLeftOperation)); + return functions; +} + +//===--------------------------------------------------------------------===// +// >> [bitwise_right_shift] +//===--------------------------------------------------------------------===// +template +bool RightShiftInRange(T shift) { + return shift >= 0 && shift < T(sizeof(T) * 8); +} + +struct BitwiseShiftRightOperator { + template + static inline TR Operation(TA input, TB shift) { + return RightShiftInRange(shift) ? input >> shift : 0; + } +}; + +static void BitwiseShiftRightOperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { + int32_t max_shift = Bit::BitLength(input); + if (shift == 0) { + return input; + } + string_t target = StringVector::EmptyString(result, input.GetSize()); + if (shift < 0 || shift >= max_shift) { + Bit::SetEmptyBitString(target, input); + return target; + } + Bit::RightShift(input, shift, target); + return target; + }); +} + +ScalarFunctionSet RightShiftFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction( + ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftRightOperation)); + return functions; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/random/random.cpp b/src/duckdb/src/core_functions/scalar/random/random.cpp new file mode 100644 index 00000000..df2f3553 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/random/random.cpp @@ -0,0 +1,63 @@ +#include "duckdb/core_functions/scalar/random_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/random_engine.hpp" +#include "duckdb/common/types/uuid.hpp" + +namespace duckdb { + +struct RandomLocalState : public FunctionLocalState { + explicit RandomLocalState(uint32_t seed) : random_engine(seed) { + } + + RandomEngine random_engine; +}; + +static void RandomFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 0); + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < args.size(); i++) { + result_data[i] = lstate.random_engine.NextRandom(); + } +} + +static unique_ptr RandomInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + auto &random_engine = RandomEngine::Get(state.GetContext()); + lock_guard guard(random_engine.lock); + return make_uniq(random_engine.NextRandomInteger()); +} + +ScalarFunction RandomFun::GetFunction() { + ScalarFunction random("random", {}, LogicalType::DOUBLE, RandomFunction, nullptr, nullptr, nullptr, + RandomInitLocalState); + random.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return random; +} + +static void GenerateUUIDFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 0); + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + + for (idx_t i = 0; i < args.size(); i++) { + result_data[i] = UUID::GenerateRandomUUID(lstate.random_engine); + } +} + +ScalarFunction UUIDFun::GetFunction() { + ScalarFunction uuid_function({}, LogicalType::UUID, GenerateUUIDFunction, nullptr, nullptr, nullptr, + RandomInitLocalState); + // generate a random uuid + uuid_function.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return uuid_function; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/random/setseed.cpp b/src/duckdb/src/core_functions/scalar/random/setseed.cpp new file mode 100644 index 00000000..24f460eb --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/random/setseed.cpp @@ -0,0 +1,61 @@ +#include "duckdb/core_functions/scalar/random_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/random_engine.hpp" + +namespace duckdb { + +struct SetseedBindData : public FunctionData { + //! The client context for the function call + ClientContext &context; + + explicit SetseedBindData(ClientContext &context) : context(context) { + } + + unique_ptr Copy() const override { + return make_uniq(context); + } + + bool Equals(const FunctionData &other_p) const override { + return true; + } +}; + +static void SetSeedFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto &input = args.data[0]; + input.Flatten(args.size()); + + auto input_seeds = FlatVector::GetData(input); + uint32_t half_max = NumericLimits::Maximum() / 2; + + auto &random_engine = RandomEngine::Get(info.context); + for (idx_t i = 0; i < args.size(); i++) { + if (input_seeds[i] < -1.0 || input_seeds[i] > 1.0 || Value::IsNan(input_seeds[i])) { + throw Exception("SETSEED accepts seed values between -1.0 and 1.0, inclusive"); + } + uint32_t norm_seed = (input_seeds[i] + 1.0) * half_max; + random_engine.SetSeed(norm_seed); + } + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); +} + +unique_ptr SetSeedBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return make_uniq(context); +} + +ScalarFunction SetseedFun::GetFunction() { + ScalarFunction setseed("setseed", {LogicalType::DOUBLE}, LogicalType::SQLNULL, SetSeedFunction, SetSeedBind); + setseed.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + return setseed; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/ascii.cpp b/src/duckdb/src/core_functions/scalar/string/ascii.cpp new file mode 100644 index 00000000..5f41338b --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/ascii.cpp @@ -0,0 +1,24 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +struct AsciiOperator { + template + static inline TR Operation(const TA &input) { + auto str = input.GetData(); + if (Utf8Proc::Analyze(str, input.GetSize()) == UnicodeType::ASCII) { + return str[0]; + } + int utf8_bytes = 4; + return Utf8Proc::UTF8ToCodepoint(str, utf8_bytes); + } +}; + +ScalarFunction ASCIIFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER, + ScalarFunction::UnaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/bar.cpp b/src/duckdb/src/core_functions/scalar/string/bar.cpp new file mode 100644 index 00000000..d55ae672 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/bar.cpp @@ -0,0 +1,93 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/unicode_bar.hpp" +#include "duckdb/common/vector_operations/generic_executor.hpp" + +namespace duckdb { + +static string_t BarScalarFunction(double x, double min, double max, double max_width, string &result) { + static const char *FULL_BLOCK = UnicodeBar::FullBlock(); + static const char *const *PARTIAL_BLOCKS = UnicodeBar::PartialBlocks(); + static const idx_t PARTIAL_BLOCKS_COUNT = UnicodeBar::PartialBlocksCount(); + + if (!Value::IsFinite(max_width)) { + throw ValueOutOfRangeException("Max bar width must not be NaN or infinity"); + } + if (max_width < 1) { + throw ValueOutOfRangeException("Max bar width must be >= 1"); + } + if (max_width > 1000) { + throw ValueOutOfRangeException("Max bar width must be <= 1000"); + } + + double width; + + if (Value::IsNan(x) || Value::IsNan(min) || Value::IsNan(max) || x <= min) { + width = 0; + } else if (x >= max) { + width = max_width; + } else { + width = max_width * (x - min) / (max - min); + } + + if (!Value::IsFinite(width)) { + throw ValueOutOfRangeException("Bar width must not be NaN or infinity"); + } + + result.clear(); + + int32_t width_as_int = static_cast(width * PARTIAL_BLOCKS_COUNT); + idx_t full_blocks_count = (width_as_int / PARTIAL_BLOCKS_COUNT); + for (idx_t i = 0; i < full_blocks_count; i++) { + result += FULL_BLOCK; + } + + idx_t remaining = width_as_int % PARTIAL_BLOCKS_COUNT; + + if (remaining) { + result += PARTIAL_BLOCKS[remaining]; + } + + return string_t(result); +} + +static void BarFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); + auto &x_arg = args.data[0]; + auto &min_arg = args.data[1]; + auto &max_arg = args.data[2]; + string buffer; + + if (args.ColumnCount() == 3) { + GenericExecutor::ExecuteTernary, PrimitiveType, PrimitiveType, + PrimitiveType>( + x_arg, min_arg, max_arg, result, args.size(), + [&](PrimitiveType x, PrimitiveType min, PrimitiveType max) { + return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, 80, buffer)); + }); + } else { + auto &width_arg = args.data[3]; + GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, + PrimitiveType, PrimitiveType>( + x_arg, min_arg, max_arg, width_arg, result, args.size(), + [&](PrimitiveType x, PrimitiveType min, PrimitiveType max, + PrimitiveType width) { + return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, width.val, buffer)); + }); + } +} + +ScalarFunctionSet BarFun::GetFunctions() { + ScalarFunctionSet bar; + bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, + LogicalType::VARCHAR, BarFunction)); + bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, + LogicalType::VARCHAR, BarFunction)); + return bar; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/chr.cpp b/src/duckdb/src/core_functions/scalar/string/chr.cpp new file mode 100644 index 00000000..e7bb62e1 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/chr.cpp @@ -0,0 +1,48 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +struct ChrOperator { + static void GetCodepoint(int32_t input, char c[], int &utf8_bytes) { + if (input < 0 || !Utf8Proc::CodepointToUtf8(input, utf8_bytes, &c[0])) { + throw InvalidInputException("Invalid UTF8 Codepoint %d", input); + } + } + + template + static inline TR Operation(const TA &input) { + char c[5] = {'\0', '\0', '\0', '\0', '\0'}; + int utf8_bytes; + GetCodepoint(input, c, utf8_bytes); + return string_t(&c[0], utf8_bytes); + } +}; + +#ifdef DUCKDB_DEBUG_NO_INLINE +// the chr function depends on the data always being inlined (which is always possible, since it outputs max 4 bytes) +// to enable chr when string inlining is disabled we create a special function here +static void ChrFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &code_vec = args.data[0]; + + char c[5] = {'\0', '\0', '\0', '\0', '\0'}; + int utf8_bytes; + UnaryExecutor::Execute(code_vec, result, args.size(), [&](int32_t input) { + ChrOperator::GetCodepoint(input, c, utf8_bytes); + return StringVector::AddString(result, &c[0], utf8_bytes); + }); +} +#endif + +ScalarFunction ChrFun::GetFunction() { + return ScalarFunction("chr", {LogicalType::INTEGER}, LogicalType::VARCHAR, +#ifdef DUCKDB_DEBUG_NO_INLINE + ChrFunction +#else + ScalarFunction::UnaryFunction +#endif + ); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/damerau_levenshtein.cpp b/src/duckdb/src/core_functions/scalar/string/damerau_levenshtein.cpp new file mode 100644 index 00000000..20bb7dfc --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/damerau_levenshtein.cpp @@ -0,0 +1,104 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +// Using Lowrance-Wagner (LW) algorithm: https://doi.org/10.1145%2F321879.321880 +// Can't calculate as trivial modification to levenshtein algorithm +// as we need to potentially know about earlier in the string +static idx_t DamerauLevenshteinDistance(const string_t &source, const string_t &target) { + // costs associated with each type of edit, to aid readability + constexpr uint8_t COST_SUBSTITUTION = 1; + constexpr uint8_t COST_INSERTION = 1; + constexpr uint8_t COST_DELETION = 1; + constexpr uint8_t COST_TRANSPOSITION = 1; + const auto source_len = source.GetSize(); + const auto target_len = target.GetSize(); + + // If one string is empty, the distance equals the length of the other string + // either through target_len insertions + // or source_len deletions + if (source_len == 0) { + return target_len * COST_INSERTION; + } else if (target_len == 0) { + return source_len * COST_DELETION; + } + + const auto source_str = source.GetData(); + const auto target_str = target.GetData(); + + // larger than the largest possible value: + const auto inf = source_len * COST_DELETION + target_len * COST_INSERTION + 1; + // minimum edit distance from prefix of source string to prefix of target string + // same object as H in LW paper (with indices offset by 1) + vector> distance(source_len + 2, vector(target_len + 2, inf)); + // keeps track of the largest string indices of source string matching each character + // same as DA in LW paper + map largest_source_chr_matching; + + // initialise row/column corresponding to zero-length strings + // partial string -> empty requires a deletion for each character + for (idx_t source_idx = 0; source_idx <= source_len; source_idx++) { + distance[source_idx + 1][1] = source_idx * COST_DELETION; + } + // and empty -> partial string means simply inserting characters + for (idx_t target_idx = 1; target_idx <= target_len; target_idx++) { + distance[1][target_idx + 1] = target_idx * COST_INSERTION; + } + // loop through string indices - these are offset by 2 from distance indices + for (idx_t source_idx = 0; source_idx < source_len; source_idx++) { + // keeps track of the largest string indices of target string matching current source character + // same as DB in LW paper + idx_t largest_target_chr_matching; + largest_target_chr_matching = 0; + for (idx_t target_idx = 0; target_idx < target_len; target_idx++) { + // correspond to i1 and j1 in LW paper respectively + idx_t largest_source_chr_matching_target; + idx_t largest_target_chr_matching_source; + // cost associated to diagnanl shift in distance matrix + // corresponds to d in LW paper + uint8_t cost_diagonal_shift; + largest_source_chr_matching_target = largest_source_chr_matching[target_str[target_idx]]; + largest_target_chr_matching_source = largest_target_chr_matching; + // if characters match, diagonal move costs nothing and we update our largest target index + // otherwise move is substitution and costs as such + if (source_str[source_idx] == target_str[target_idx]) { + cost_diagonal_shift = 0; + largest_target_chr_matching = target_idx + 1; + } else { + cost_diagonal_shift = COST_SUBSTITUTION; + } + distance[source_idx + 2][target_idx + 2] = MinValue( + distance[source_idx + 1][target_idx + 1] + cost_diagonal_shift, + MinValue(distance[source_idx + 2][target_idx + 1] + COST_INSERTION, + MinValue(distance[source_idx + 1][target_idx + 2] + COST_DELETION, + distance[largest_source_chr_matching_target][largest_target_chr_matching_source] + + (source_idx - largest_source_chr_matching_target) * COST_DELETION + + COST_TRANSPOSITION + + (target_idx - largest_target_chr_matching_source) * COST_INSERTION))); + } + largest_source_chr_matching[source_str[source_idx]] = source_idx + 1; + } + return distance[source_len + 1][target_len + 1]; +} + +static int64_t DamerauLevenshteinScalarFunction(Vector &result, const string_t source, const string_t target) { + return (int64_t)DamerauLevenshteinDistance(source, target); +} + +static void DamerauLevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &source_vec = args.data[0]; + auto &target_vec = args.data[1]; + + BinaryExecutor::Execute( + source_vec, target_vec, result, args.size(), + [&](string_t source, string_t target) { return DamerauLevenshteinScalarFunction(result, source, target); }); +} + +ScalarFunction DamerauLevenshteinFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, + DamerauLevenshteinFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/format_bytes.cpp b/src/duckdb/src/core_functions/scalar/string/format_bytes.cpp new file mode 100644 index 00000000..b1a974f4 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/format_bytes.cpp @@ -0,0 +1,29 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +static void FormatBytesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](int64_t bytes) { + bool is_negative = bytes < 0; + idx_t unsigned_bytes; + if (bytes < 0) { + if (bytes == NumericLimits::Minimum()) { + unsigned_bytes = idx_t(NumericLimits::Maximum()) + 1; + } else { + unsigned_bytes = idx_t(-bytes); + } + } else { + unsigned_bytes = idx_t(bytes); + } + return StringVector::AddString(result, (is_negative ? "-" : "") + + StringUtil::BytesToHumanReadableString(unsigned_bytes)); + }); +} + +ScalarFunction FormatBytesFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, FormatBytesFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/hamming.cpp b/src/duckdb/src/core_functions/scalar/string/hamming.cpp new file mode 100644 index 00000000..892430da --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/hamming.cpp @@ -0,0 +1,45 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include +#include + +namespace duckdb { + +static int64_t MismatchesScalarFunction(Vector &result, const string_t str, string_t tgt) { + idx_t str_len = str.GetSize(); + idx_t tgt_len = tgt.GetSize(); + + if (str_len != tgt_len) { + throw InvalidInputException("Mismatch Function: Strings must be of equal length!"); + } + if (str_len < 1) { + throw InvalidInputException("Mismatch Function: Strings must be of length > 0!"); + } + + idx_t mismatches = 0; + auto str_str = str.GetData(); + auto tgt_str = tgt.GetData(); + + for (idx_t idx = 0; idx < str_len; ++idx) { + if (str_str[idx] != tgt_str[idx]) { + mismatches++; + } + } + return (int64_t)mismatches; +} + +static void MismatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &tgt_vec = args.data[1]; + + BinaryExecutor::Execute( + str_vec, tgt_vec, result, args.size(), + [&](string_t str, string_t tgt) { return MismatchesScalarFunction(result, str, tgt); }); +} + +ScalarFunction HammingFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, MismatchesFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/hex.cpp b/src/duckdb/src/core_functions/scalar/string/hex.cpp new file mode 100644 index 00000000..ffea5e31 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/hex.cpp @@ -0,0 +1,374 @@ +#include "duckdb/common/bit_utils.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/blob.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/core_functions/scalar/string_functions.hpp" + +namespace duckdb { + +static void WriteHexBytes(uint64_t x, char *&output, idx_t buffer_size) { + idx_t offset = buffer_size * 4; + + for (; offset >= 4; offset -= 4) { + uint8_t byte = (x >> (offset - 4)) & 0x0F; + *output = Blob::HEX_TABLE[byte]; + output++; + } +} + +static void WriteHugeIntHexBytes(hugeint_t x, char *&output, idx_t buffer_size) { + idx_t offset = buffer_size * 4; + auto upper = x.upper; + auto lower = x.lower; + + for (; offset >= 68; offset -= 4) { + uint8_t byte = (upper >> (offset - 68)) & 0x0F; + *output = Blob::HEX_TABLE[byte]; + output++; + } + + for (; offset >= 4; offset -= 4) { + uint8_t byte = (lower >> (offset - 4)) & 0x0F; + *output = Blob::HEX_TABLE[byte]; + output++; + } +} + +static void WriteBinBytes(uint64_t x, char *&output, idx_t buffer_size) { + idx_t offset = buffer_size; + for (; offset >= 1; offset -= 1) { + *output = ((x >> (offset - 1)) & 0x01) + '0'; + output++; + } +} + +static void WriteHugeIntBinBytes(hugeint_t x, char *&output, idx_t buffer_size) { + auto upper = x.upper; + auto lower = x.lower; + idx_t offset = buffer_size; + + for (; offset >= 65; offset -= 1) { + *output = ((upper >> (offset - 65)) & 0x01) + '0'; + output++; + } + + for (; offset >= 1; offset -= 1) { + *output = ((lower >> (offset - 1)) & 0x01) + '0'; + output++; + } +} + +struct HexStrOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + // Allocate empty space + auto target = StringVector::EmptyString(result, size * 2); + auto output = target.GetDataWriteable(); + + for (idx_t i = 0; i < size; ++i) { + *output = Blob::HEX_TABLE[(data[i] >> 4) & 0x0F]; + output++; + *output = Blob::HEX_TABLE[data[i] & 0x0F]; + output++; + } + + target.Finalize(); + return target; + } +}; + +struct HexIntegralOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + + idx_t num_leading_zero = CountZeros::Leading(input); + idx_t num_bits_to_check = 64 - num_leading_zero; + D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); + + idx_t buffer_size = (num_bits_to_check + 3) / 4; + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + D_ASSERT(buffer_size > 0); + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHexBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct HexHugeIntOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + + idx_t num_leading_zero = CountZeros::Leading(input); + idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + D_ASSERT(buffer_size > 0); + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHugeIntHexBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +template +static void ToHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + auto &input = args.data[0]; + idx_t count = args.size(); + UnaryExecutor::ExecuteString(input, result, count); +} + +struct BinaryStrOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + // Allocate empty space + auto target = StringVector::EmptyString(result, size * 8); + auto output = target.GetDataWriteable(); + + for (idx_t i = 0; i < size; ++i) { + uint8_t byte = data[i]; + for (idx_t i = 8; i >= 1; --i) { + *output = ((byte >> (i - 1)) & 0x01) + '0'; + output++; + } + } + + target.Finalize(); + return target; + } +}; + +struct BinaryIntegralOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + + idx_t num_leading_zero = CountZeros::Leading(input); + idx_t num_bits_to_check = 64 - num_leading_zero; + D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); + + idx_t buffer_size = num_bits_to_check; + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + D_ASSERT(buffer_size > 0); + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteBinBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct BinaryHugeIntOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + idx_t num_leading_zero = CountZeros::Leading(input); + idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHugeIntBinBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct FromHexOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + if (size > NumericLimits::Maximum()) { + throw InvalidInputException("Hexadecimal input length larger than 2^32 are not supported"); + } + + D_ASSERT(size <= NumericLimits::Maximum()); + auto buffer_size = (size + 1) / 2; + + // Allocate empty space + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + // Treated as a single byte + idx_t i = 0; + if (size % 2 != 0) { + *output = StringUtil::GetHexValue(data[i]); + i++; + output++; + } + + for (; i < size; i += 2) { + uint8_t major = StringUtil::GetHexValue(data[i]); + uint8_t minor = StringUtil::GetHexValue(data[i + 1]); + *output = (major << 4) | minor; + output++; + } + + target.Finalize(); + return target; + } +}; + +struct FromBinaryOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + if (size > NumericLimits::Maximum()) { + throw InvalidInputException("Binary input length larger than 2^32 are not supported"); + } + + D_ASSERT(size <= NumericLimits::Maximum()); + auto buffer_size = (size + 7) / 8; + + // Allocate empty space + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + // Treated as a single byte + idx_t i = 0; + if (size % 8 != 0) { + uint8_t byte = 0; + for (idx_t j = size % 8; j > 0; --j) { + byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); + i++; + } + *output = byte; + output++; + } + + while (i < size) { + uint8_t byte = 0; + for (idx_t j = 8; j > 0; --j) { + byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); + i++; + } + *output = byte; + output++; + } + + target.Finalize(); + return target; + } +}; + +template +static void ToBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + auto &input = args.data[0]; + idx_t count = args.size(); + UnaryExecutor::ExecuteString(input, result, count); +} + +static void FromBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); + auto &input = args.data[0]; + idx_t count = args.size(); + + UnaryExecutor::ExecuteString(input, result, count); +} + +static void FromHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); + auto &input = args.data[0]; + idx_t count = args.size(); + + UnaryExecutor::ExecuteString(input, result, count); +} + +ScalarFunctionSet HexFun::GetFunctions() { + ScalarFunctionSet to_hex; + to_hex.AddFunction( + ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToHexFunction)); + + to_hex.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToHexFunction)); + + to_hex.AddFunction( + ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, ToHexFunction)); + + to_hex.AddFunction( + ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, ToHexFunction)); + return to_hex; +} + +ScalarFunction UnhexFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, FromHexFunction); +} + +ScalarFunctionSet BinFun::GetFunctions() { + ScalarFunctionSet to_binary; + + to_binary.AddFunction( + ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToBinaryFunction)); + to_binary.AddFunction(ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, + ToBinaryFunction)); + to_binary.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToBinaryFunction)); + to_binary.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, + ToBinaryFunction)); + return to_binary; +} + +ScalarFunction UnbinFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, FromBinaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/instr.cpp b/src/duckdb/src/core_functions/scalar/string/instr.cpp new file mode 100644 index 00000000..becbbd48 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/instr.cpp @@ -0,0 +1,58 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "utf8proc.hpp" + +namespace duckdb { + +struct InstrOperator { + template + static inline TR Operation(TA haystack, TB needle) { + int64_t string_position = 0; + + auto location = ContainsFun::Find(haystack, needle); + if (location != DConstants::INVALID_INDEX) { + auto len = (utf8proc_ssize_t)location; + auto str = reinterpret_cast(haystack.GetData()); + D_ASSERT(len <= (utf8proc_ssize_t)haystack.GetSize()); + for (++string_position; len > 0; ++string_position) { + utf8proc_int32_t codepoint; + auto bytes = utf8proc_iterate(str, len, &codepoint); + str += bytes; + len -= bytes; + } + } + return string_position; + } +}; + +struct InstrAsciiOperator { + template + static inline TR Operation(TA haystack, TB needle) { + auto location = ContainsFun::Find(haystack, needle); + return location == DConstants::INVALID_INDEX ? 0 : location + 1; + } +}; + +static unique_ptr InStrPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 2); + // can only propagate stats if the children have stats + // for strpos, we only care if the FIRST string has unicode or not + if (!StringStats::CanContainUnicode(child_stats[0])) { + expr.function.function = ScalarFunction::BinaryFunction; + } + return nullptr; +} + +ScalarFunction InstrFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction, nullptr, nullptr, + InStrPropagateStats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/jaccard.cpp b/src/duckdb/src/core_functions/scalar/string/jaccard.cpp new file mode 100644 index 00000000..69024442 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/jaccard.cpp @@ -0,0 +1,65 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/map.hpp" + +#include + +namespace duckdb { + +static inline map GetSet(const string_t &str) { + auto map_of_chars = map {}; + idx_t str_len = str.GetSize(); + auto s = str.GetData(); + + for (idx_t pos = 0; pos < str_len; pos++) { + map_of_chars.insert(std::make_pair(s[pos], 1)); + } + return map_of_chars; +} + +static double JaccardSimilarity(const string_t &str, const string_t &txt) { + if (str.GetSize() < 1 || txt.GetSize() < 1) { + throw InvalidInputException("Jaccard Function: An argument too short!"); + } + map m_str, m_txt; + + m_str = GetSet(str); + m_txt = GetSet(txt); + + if (m_str.size() > m_txt.size()) { + m_str.swap(m_txt); + } + + for (auto const &achar : m_str) { + ++m_txt[achar.first]; + } + // m_txt.size is now size of union. + + idx_t size_intersect = 0; + for (const auto &apair : m_txt) { + if (apair.second > 1) { + size_intersect++; + } + } + + return (double)size_intersect / (double)m_txt.size(); +} + +static double JaccardScalarFunction(Vector &result, const string_t str, string_t tgt) { + return (double)JaccardSimilarity(str, tgt); +} + +static void JaccardFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &tgt_vec = args.data[1]; + + BinaryExecutor::Execute( + str_vec, tgt_vec, result, args.size(), + [&](string_t str, string_t tgt) { return JaccardScalarFunction(result, str, tgt); }); +} + +ScalarFunction JaccardFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaccardFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/jaro_winkler.cpp b/src/duckdb/src/core_functions/scalar/string/jaro_winkler.cpp new file mode 100644 index 00000000..3c54b411 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/jaro_winkler.cpp @@ -0,0 +1,71 @@ +#include "jaro_winkler.hpp" + +#include "duckdb/core_functions/scalar/string_functions.hpp" + +namespace duckdb { + +static inline double JaroScalarFunction(const string_t &s1, const string_t &s2) { + auto s1_begin = s1.GetData(); + auto s2_begin = s2.GetData(); + return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize()); +} + +static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2) { + auto s1_begin = s1.GetData(); + auto s2_begin = s2.GetData(); + return duckdb_jaro_winkler::jaro_winkler_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, + s2_begin + s2.GetSize()); +} + +template +static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_t count) { + auto val = constant.GetValue(0); + if (val.IsNull()) { + auto &result_validity = FlatVector::Validity(result); + result_validity.SetAllInvalid(count); + return; + } + + auto str_val = StringValue::Get(val); + auto cached = CACHED_SIMILARITY(str_val); + UnaryExecutor::Execute(other, result, count, [&](const string_t &other_str) { + auto other_str_begin = other_str.GetData(); + return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize()); + }); +} + +template > +static void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILARITY_FUNCTION fun) { + bool arg0_constant = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR; + bool arg1_constant = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR; + if (!(arg0_constant ^ arg1_constant)) { + // We can't optimize by caching one of the two strings + BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), fun); + return; + } + + if (arg0_constant) { + CachedFunction(args.data[0], args.data[1], result, args.size()); + } else { + CachedFunction(args.data[1], args.data[0], result, args.size()); + } +} + +static void JaroFunction(DataChunk &args, ExpressionState &state, Vector &result) { + TemplatedJaroWinklerFunction>(args, result, JaroScalarFunction); +} + +static void JaroWinklerFunction(DataChunk &args, ExpressionState &state, Vector &result) { + TemplatedJaroWinklerFunction>(args, result, + JaroWinklerScalarFunction); +} + +ScalarFunction JaroSimilarityFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction); +} + +ScalarFunction JaroWinklerSimilarityFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/left_right.cpp b/src/duckdb/src/core_functions/scalar/string/left_right.cpp new file mode 100644 index 00000000..886559b6 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/left_right.cpp @@ -0,0 +1,100 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/function/scalar/string_functions.hpp" + +#include +#include + +namespace duckdb { + +struct LeftRightUnicode { + template + static inline TR Operation(TA input) { + return LengthFun::Length(input); + } + + static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { + return SubstringFun::SubstringUnicode(result, input, offset, length); + } +}; + +struct LeftRightGrapheme { + template + static inline TR Operation(TA input) { + return LengthFun::GraphemeCount(input); + } + + static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { + return SubstringFun::SubstringGrapheme(result, input, offset, length); + } +}; + +template +static string_t LeftScalarFunction(Vector &result, const string_t str, int64_t pos) { + if (pos >= 0) { + return OP::Substring(result, str, 1, pos); + } + + int64_t num_characters = OP::template Operation(str); + pos = MaxValue(0, num_characters + pos); + return OP::Substring(result, str, 1, pos); +} + +template +static void LeftFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &pos_vec = args.data[1]; + + BinaryExecutor::Execute( + str_vec, pos_vec, result, args.size(), + [&](string_t str, int64_t pos) { return LeftScalarFunction(result, str, pos); }); +} + +ScalarFunction LeftFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, + LeftFunction); +} + +ScalarFunction LeftGraphemeFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, + LeftFunction); +} + +template +static string_t RightScalarFunction(Vector &result, const string_t str, int64_t pos) { + int64_t num_characters = OP::template Operation(str); + if (pos >= 0) { + int64_t len = MinValue(num_characters, pos); + int64_t start = num_characters - len + 1; + return OP::Substring(result, str, start, len); + } + + int64_t len = 0; + if (pos != std::numeric_limits::min()) { + len = num_characters - MinValue(num_characters, -pos); + } + int64_t start = num_characters - len + 1; + return OP::Substring(result, str, start, len); +} + +template +static void RightFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &pos_vec = args.data[1]; + BinaryExecutor::Execute( + str_vec, pos_vec, result, args.size(), + [&](string_t str, int64_t pos) { return RightScalarFunction(result, str, pos); }); +} + +ScalarFunction RightFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, + RightFunction); +} + +ScalarFunction RightGraphemeFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, + RightFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/levenshtein.cpp b/src/duckdb/src/core_functions/scalar/string/levenshtein.cpp new file mode 100644 index 00000000..13731e38 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/levenshtein.cpp @@ -0,0 +1,84 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/string_util.hpp" + +#include +#include + +namespace duckdb { + +// See: https://www.kdnuggets.com/2020/10/optimizing-levenshtein-distance-measuring-text-similarity.html +// And: Iterative 2-row algorithm: https://en.wikipedia.org/wiki/Levenshtein_distance +// Note: A first implementation using the array algorithm version resulted in an error raised by duckdb +// (too muach memory usage) + +static idx_t LevenshteinDistance(const string_t &txt, const string_t &tgt) { + auto txt_len = txt.GetSize(); + auto tgt_len = tgt.GetSize(); + + // If one string is empty, the distance equals the length of the other string + if (txt_len == 0) { + return tgt_len; + } else if (tgt_len == 0) { + return txt_len; + } + + auto txt_str = txt.GetData(); + auto tgt_str = tgt.GetData(); + + // Create two working vectors + vector distances0(tgt_len + 1, 0); + vector distances1(tgt_len + 1, 0); + + idx_t cost_substitution = 0; + idx_t cost_insertion = 0; + idx_t cost_deletion = 0; + + // initialize distances0 vector + // edit distance for an empty txt string is just the number of characters to delete from tgt + for (idx_t pos_tgt = 0; pos_tgt <= tgt_len; pos_tgt++) { + distances0[pos_tgt] = pos_tgt; + } + + for (idx_t pos_txt = 0; pos_txt < txt_len; pos_txt++) { + // calculate distances1 (current raw distances) from the previous row + + distances1[0] = pos_txt + 1; + + for (idx_t pos_tgt = 0; pos_tgt < tgt_len; pos_tgt++) { + cost_deletion = distances0[pos_tgt + 1] + 1; + cost_insertion = distances1[pos_tgt] + 1; + cost_substitution = distances0[pos_tgt]; + + if (txt_str[pos_txt] != tgt_str[pos_tgt]) { + cost_substitution += 1; + } + + distances1[pos_tgt + 1] = MinValue(cost_deletion, MinValue(cost_substitution, cost_insertion)); + } + // copy distances1 (current row) to distances0 (previous row) for next iteration + // since data in distances1 is always invalidated, a swap without copy is more efficient + distances0 = distances1; + } + + return distances0[tgt_len]; +} + +static int64_t LevenshteinScalarFunction(Vector &result, const string_t str, string_t tgt) { + return (int64_t)LevenshteinDistance(str, tgt); +} + +static void LevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &tgt_vec = args.data[1]; + + BinaryExecutor::Execute( + str_vec, tgt_vec, result, args.size(), + [&](string_t str, string_t tgt) { return LevenshteinScalarFunction(result, str, tgt); }); +} + +ScalarFunction LevenshteinFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, LevenshteinFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/md5.cpp b/src/duckdb/src/core_functions/scalar/string/md5.cpp new file mode 100644 index 00000000..6e7ac124 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/md5.cpp @@ -0,0 +1,86 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/crypto/md5.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" + +namespace duckdb { + +struct MD5Operator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto hash = StringVector::EmptyString(result, MD5Context::MD5_HASH_LENGTH_TEXT); + MD5Context context; + context.Add(input); + context.FinishHex(hash.GetDataWriteable()); + hash.Finalize(); + return hash; + } +}; + +struct MD5Number128Operator { + template + static RESULT_TYPE Operation(INPUT_TYPE input) { + data_t digest[MD5Context::MD5_HASH_LENGTH_BINARY]; + + MD5Context context; + context.Add(input); + context.Finish(digest); + return *reinterpret_cast(digest); + } +}; + +template +struct MD5Number64Operator { + template + static RESULT_TYPE Operation(INPUT_TYPE input) { + data_t digest[MD5Context::MD5_HASH_LENGTH_BINARY]; + + MD5Context context; + context.Add(input); + context.Finish(digest); + return *reinterpret_cast(&digest[lower ? 8 : 0]); + } +}; + +static void MD5Function(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + + UnaryExecutor::ExecuteString(input, result, args.size()); +} + +static void MD5NumberFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + + UnaryExecutor::Execute(input, result, args.size()); +} + +static void MD5NumberUpperFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + + UnaryExecutor::Execute>(input, result, args.size()); +} + +static void MD5NumberLowerFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + + UnaryExecutor::Execute>(input, result, args.size()); +} + +ScalarFunction MD5Fun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, MD5Function); +} + +ScalarFunction MD5NumberFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::HUGEINT, MD5NumberFunction); +} + +ScalarFunction MD5NumberUpperFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::UBIGINT, MD5NumberUpperFunction); +} + +ScalarFunction MD5NumberLowerFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::UBIGINT, MD5NumberLowerFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/pad.cpp b/src/duckdb/src/core_functions/scalar/string/pad.cpp new file mode 100644 index 00000000..3ff111ca --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/pad.cpp @@ -0,0 +1,143 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/pair.hpp" + +#include "utf8proc.hpp" + +namespace duckdb { + +static pair PadCountChars(const idx_t len, const char *data, const idx_t size) { + // Count how much of str will fit in the output + auto str = reinterpret_cast(data); + idx_t nbytes = 0; + idx_t nchars = 0; + for (; nchars < len && nbytes < size; ++nchars) { + utf8proc_int32_t codepoint; + auto bytes = utf8proc_iterate(str + nbytes, size - nbytes, &codepoint); + D_ASSERT(bytes > 0); + nbytes += bytes; + } + + return pair(nbytes, nchars); +} + +static bool InsertPadding(const idx_t len, const string_t &pad, vector &result) { + // Copy the padding until the output is long enough + auto data = pad.GetData(); + auto size = pad.GetSize(); + + // Check whether we need data that we don't have + if (len > 0 && size == 0) { + return false; + } + + // Insert characters until we have all we need. + auto str = reinterpret_cast(data); + idx_t nbytes = 0; + for (idx_t nchars = 0; nchars < len; ++nchars) { + // If we are at the end of the pad, flush all of it and loop back + if (nbytes >= size) { + result.insert(result.end(), data, data + size); + nbytes = 0; + } + + // Write the next character + utf8proc_int32_t codepoint; + auto bytes = utf8proc_iterate(str + nbytes, size - nbytes, &codepoint); + D_ASSERT(bytes > 0); + nbytes += bytes; + } + + // Flush the remaining pad + result.insert(result.end(), data, data + nbytes); + + return true; +} + +static string_t LeftPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { + // Reuse the buffer + result.clear(); + + // Get information about the base string + auto data_str = str.GetData(); + auto size_str = str.GetSize(); + + // Count how much of str will fit in the output + auto written = PadCountChars(len, data_str, size_str); + + // Left pad by the number of characters still needed + if (!InsertPadding(len - written.second, pad, result)) { + throw Exception("Insufficient padding in LPAD."); + } + + // Append as much of the original string as fits + result.insert(result.end(), data_str, data_str + written.first); + + return string_t(result.data(), result.size()); +} + +struct LeftPadOperator { + static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, + vector &result) { + return LeftPadFunction(str, len, pad, result); + } +}; + +static string_t RightPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { + // Reuse the buffer + result.clear(); + + // Get information about the base string + auto data_str = str.GetData(); + auto size_str = str.GetSize(); + + // Count how much of str will fit in the output + auto written = PadCountChars(len, data_str, size_str); + + // Append as much of the original string as fits + result.insert(result.end(), data_str, data_str + written.first); + + // Right pad by the number of characters still needed + if (!InsertPadding(len - written.second, pad, result)) { + throw Exception("Insufficient padding in RPAD."); + }; + + return string_t(result.data(), result.size()); +} + +struct RightPadOperator { + static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, + vector &result) { + return RightPadFunction(str, len, pad, result); + } +}; + +template +static void PadFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vector = args.data[0]; + auto &len_vector = args.data[1]; + auto &pad_vector = args.data[2]; + + vector buffer; + TernaryExecutor::Execute( + str_vector, len_vector, pad_vector, result, args.size(), [&](string_t str, int32_t len, string_t pad) { + len = MaxValue(len, 0); + return StringVector::AddString(result, OP::Operation(str, len, pad, buffer)); + }); +} + +ScalarFunction LpadFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, + PadFunction); +} + +ScalarFunction RpadFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, + PadFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/printf.cpp b/src/duckdb/src/core_functions/scalar/string/printf.cpp new file mode 100644 index 00000000..b71bedef --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/printf.cpp @@ -0,0 +1,171 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/limits.hpp" +#include "fmt/format.h" +#include "fmt/printf.h" + +namespace duckdb { + +struct FMTPrintf { + template + static string OP(const char *format_str, vector> &format_args) { + return duckdb_fmt::vsprintf( + format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); + } +}; + +struct FMTFormat { + template + static string OP(const char *format_str, vector> &format_args) { + return duckdb_fmt::vformat( + format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); + } +}; + +unique_ptr BindPrintfFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + for (idx_t i = 1; i < arguments.size(); i++) { + switch (arguments[i]->return_type.id()) { + case LogicalTypeId::BOOLEAN: + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::VARCHAR: + // these types are natively supported + bound_function.arguments.push_back(arguments[i]->return_type); + break; + case LogicalTypeId::DECIMAL: + // decimal type: add cast to double + bound_function.arguments.emplace_back(LogicalType::DOUBLE); + break; + case LogicalTypeId::UNKNOWN: + // parameter: accept any input and rebind later + bound_function.arguments.emplace_back(LogicalType::ANY); + break; + default: + // all other types: add cast to string + bound_function.arguments.emplace_back(LogicalType::VARCHAR); + break; + } + } + return nullptr; +} + +template +static void PrintfFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &format_string = args.data[0]; + auto &result_validity = FlatVector::Validity(result); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + result_validity.Initialize(args.size()); + for (idx_t i = 0; i < args.ColumnCount(); i++) { + switch (args.data[i].GetVectorType()) { + case VectorType::CONSTANT_VECTOR: + if (ConstantVector::IsNull(args.data[i])) { + // constant null! result is always NULL regardless of other input + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + break; + default: + // FLAT VECTOR, we can directly OR the nullmask + args.data[i].Flatten(args.size()); + result.SetVectorType(VectorType::FLAT_VECTOR); + result_validity.Combine(FlatVector::Validity(args.data[i]), args.size()); + break; + } + } + idx_t count = result.GetVectorType() == VectorType::CONSTANT_VECTOR ? 1 : args.size(); + + auto format_data = FlatVector::GetData(format_string); + auto result_data = FlatVector::GetData(result); + for (idx_t idx = 0; idx < count; idx++) { + if (result.GetVectorType() == VectorType::FLAT_VECTOR && FlatVector::IsNull(result, idx)) { + // this entry is NULL: skip it + continue; + } + + // first fetch the format string + auto fmt_idx = format_string.GetVectorType() == VectorType::CONSTANT_VECTOR ? 0 : idx; + auto format_string = format_data[fmt_idx].GetString(); + + // now gather all the format arguments + vector> format_args; + vector> string_args; + + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + auto &col = args.data[col_idx]; + idx_t arg_idx = col.GetVectorType() == VectorType::CONSTANT_VECTOR ? 0 : idx; + switch (col.GetType().id()) { + case LogicalTypeId::BOOLEAN: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::TINYINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::SMALLINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::INTEGER: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::BIGINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::FLOAT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::DOUBLE: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::VARCHAR: { + auto arg_data = FlatVector::GetData(col); + auto string_view = + duckdb_fmt::basic_string_view(arg_data[arg_idx].GetData(), arg_data[arg_idx].GetSize()); + format_args.emplace_back(duckdb_fmt::internal::make_arg(string_view)); + break; + } + default: + throw InternalException("Unexpected type for printf format"); + } + } + // finally actually perform the format + string dynamic_result = FORMAT_FUN::template OP(format_string.c_str(), format_args); + result_data[idx] = StringVector::AddString(result, dynamic_result); + } +} + +ScalarFunction PrintfFun::GetFunction() { + // duckdb_fmt::printf_context, duckdb_fmt::vsprintf + ScalarFunction printf_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, + PrintfFunction, BindPrintfFunction); + printf_fun.varargs = LogicalType::ANY; + return printf_fun; +} + +ScalarFunction FormatFun::GetFunction() { + // duckdb_fmt::format_context, duckdb_fmt::vformat + ScalarFunction format_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, + PrintfFunction, BindPrintfFunction); + format_fun.varargs = LogicalType::ANY; + return format_fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/repeat.cpp b/src/duckdb/src/core_functions/scalar/string/repeat.cpp new file mode 100644 index 00000000..4ff356e2 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/repeat.cpp @@ -0,0 +1,43 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include +#include + +namespace duckdb { + +static string_t RepeatScalarFunction(const string_t &str, const int64_t cnt, vector &result) { + // Get information about the repeated string + auto input_str = str.GetData(); + auto size_str = str.GetSize(); + + // Reuse the buffer + result.clear(); + for (auto remaining = cnt; remaining-- > 0;) { + result.insert(result.end(), input_str, input_str + size_str); + } + + return string_t(result.data(), result.size()); +} + +static void RepeatFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vector = args.data[0]; + auto &cnt_vector = args.data[1]; + + vector buffer; + BinaryExecutor::Execute( + str_vector, cnt_vector, result, args.size(), [&](string_t str, int64_t cnt) { + return StringVector::AddString(result, RepeatScalarFunction(str, cnt, buffer)); + }); +} + +ScalarFunctionSet RepeatFun::GetFunctions() { + ScalarFunctionSet repeat; + for (const auto &type : {LogicalType::VARCHAR, LogicalType::BLOB}) { + repeat.AddFunction(ScalarFunction({type, LogicalType::BIGINT}, type, RepeatFunction)); + } + return repeat; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/replace.cpp b/src/duckdb/src/core_functions/scalar/string/replace.cpp new file mode 100644 index 00000000..73ae3fea --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/replace.cpp @@ -0,0 +1,84 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" + +#include +#include +#include + +namespace duckdb { + +static idx_t NextNeedle(const char *input_haystack, idx_t size_haystack, const char *input_needle, + const idx_t size_needle) { + // Needle needs something to proceed + if (size_needle > 0) { + // Haystack should be bigger or equal size to the needle + for (idx_t string_position = 0; (size_haystack - string_position) >= size_needle; ++string_position) { + // Compare Needle to the Haystack + if ((memcmp(input_haystack + string_position, input_needle, size_needle) == 0)) { + return string_position; + } + } + } + // Did not find the needle + return size_haystack; +} + +static string_t ReplaceScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread, + vector &result) { + // Get information about the needle, the haystack and the "thread" + auto input_haystack = haystack.GetData(); + auto size_haystack = haystack.GetSize(); + + auto input_needle = needle.GetData(); + auto size_needle = needle.GetSize(); + + auto input_thread = thread.GetData(); + auto size_thread = thread.GetSize(); + + // Reuse the buffer + result.clear(); + + for (;;) { + // Append the non-matching characters + auto string_position = NextNeedle(input_haystack, size_haystack, input_needle, size_needle); + result.insert(result.end(), input_haystack, input_haystack + string_position); + input_haystack += string_position; + size_haystack -= string_position; + + // Stop when we have read the entire haystack + if (size_haystack == 0) { + break; + } + + // Replace the matching characters + result.insert(result.end(), input_thread, input_thread + size_thread); + input_haystack += size_needle; + size_haystack -= size_needle; + } + + return string_t(result.data(), result.size()); +} + +static void ReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &haystack_vector = args.data[0]; + auto &needle_vector = args.data[1]; + auto &thread_vector = args.data[2]; + + vector buffer; + TernaryExecutor::Execute( + haystack_vector, needle_vector, thread_vector, result, args.size(), + [&](string_t input_string, string_t needle_string, string_t thread_string) { + return StringVector::AddString(result, + ReplaceScalarFunction(input_string, needle_string, thread_string, buffer)); + }); +} + +ScalarFunction ReplaceFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, + ReplaceFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/reverse.cpp b/src/duckdb/src/core_functions/scalar/string/reverse.cpp new file mode 100644 index 00000000..95c3cf1a --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/reverse.cpp @@ -0,0 +1,56 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "utf8proc.hpp" + +#include + +namespace duckdb { + +//! Fast ASCII string reverse, returns false if the input data is not ascii +static bool StrReverseASCII(const char *input, idx_t n, char *output) { + for (idx_t i = 0; i < n; i++) { + if (input[i] & 0x80) { + // non-ascii character + return false; + } + output[n - i - 1] = input[i]; + } + return true; +} + +//! Unicode string reverse using grapheme breakers +static void StrReverseUnicode(const char *input, idx_t n, char *output) { + utf8proc_grapheme_callback(input, n, [&](size_t start, size_t end) { + memcpy(output + n - end, input + start, end - start); + return true; + }); +} + +struct ReverseOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + + auto target = StringVector::EmptyString(result, input_length); + auto target_data = target.GetDataWriteable(); + if (!StrReverseASCII(input_data, input_length, target_data)) { + StrReverseUnicode(input_data, input_length, target_data); + } + target.Finalize(); + return target; + } +}; + +static void ReverseFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); +} + +ScalarFunction ReverseFun::GetFunction() { + return ScalarFunction("reverse", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ReverseFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/sha256.cpp b/src/duckdb/src/core_functions/scalar/string/sha256.cpp new file mode 100644 index 00000000..efc09c05 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/sha256.cpp @@ -0,0 +1,32 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "mbedtls_wrapper.hpp" + +namespace duckdb { + +struct SHA256Operator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto hash = StringVector::EmptyString(result, duckdb_mbedtls::MbedTlsWrapper::SHA256_HASH_LENGTH_TEXT); + + duckdb_mbedtls::MbedTlsWrapper::SHA256State state; + state.AddString(input.GetString()); + state.FinishHex(hash.GetDataWriteable()); + + hash.Finalize(); + return hash; + } +}; + +static void SHA256Function(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + + UnaryExecutor::ExecuteString(input, result, args.size()); +} + +ScalarFunction SHA256Fun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, SHA256Function); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/starts_with.cpp b/src/duckdb/src/core_functions/scalar/string/starts_with.cpp new file mode 100644 index 00000000..c4661b91 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/starts_with.cpp @@ -0,0 +1,44 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +static bool StartsWith(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, + idx_t needle_size) { + D_ASSERT(needle_size > 0); + if (needle_size > haystack_size) { + // needle is bigger than haystack: haystack cannot start with needle + return false; + } + return memcmp(haystack, needle, needle_size) == 0; +} + +static bool StartsWith(const string_t &haystack_s, const string_t &needle_s) { + + auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); + auto haystack_size = haystack_s.GetSize(); + auto needle = const_uchar_ptr_cast(needle_s.GetData()); + auto needle_size = needle_s.GetSize(); + if (needle_size == 0) { + // empty needle: always true + return true; + } + return StartsWith(haystack, haystack_size, needle, needle_size); +} + +struct StartsWithOperator { + template + static inline TR Operation(TA left, TB right) { + return StartsWith(left, right); + } +}; + +ScalarFunction StartsWithOperatorFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + ScalarFunction::BinaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/string_split.cpp b/src/duckdb/src/core_functions/scalar/string/string_split.cpp new file mode 100644 index 00000000..7a41d3bd --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/string_split.cpp @@ -0,0 +1,196 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_size.hpp" +#include "duckdb/function/scalar/regexp.hpp" +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/string_functions.hpp" + +namespace duckdb { + +struct StringSplitInput { + StringSplitInput(Vector &result_list, Vector &result_child, idx_t offset) + : result_list(result_list), result_child(result_child), offset(offset) { + } + + Vector &result_list; + Vector &result_child; + idx_t offset; + + void AddSplit(const char *split_data, idx_t split_size, idx_t list_idx) { + auto list_entry = offset + list_idx; + if (list_entry >= ListVector::GetListCapacity(result_list)) { + ListVector::SetListSize(result_list, offset + list_idx); + ListVector::Reserve(result_list, ListVector::GetListCapacity(result_list) * 2); + } + FlatVector::GetData(result_child)[list_entry] = + StringVector::AddString(result_child, split_data, split_size); + } +}; + +struct RegularStringSplit { + static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, + idx_t &match_size, void *data) { + match_size = delim_size; + if (delim_size == 0) { + return 0; + } + return ContainsFun::Find(const_uchar_ptr_cast(input_data), input_size, const_uchar_ptr_cast(delim_data), + delim_size); + } +}; + +struct ConstantRegexpStringSplit { + static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, + idx_t &match_size, void *data) { + D_ASSERT(data); + auto regex = reinterpret_cast(data); + duckdb_re2::StringPiece match; + if (!regex->Match(duckdb_re2::StringPiece(input_data, input_size), 0, input_size, RE2::UNANCHORED, &match, 1)) { + return DConstants::INVALID_INDEX; + } + match_size = match.size(); + return match.data() - input_data; + } +}; + +struct RegexpStringSplit { + static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, + idx_t &match_size, void *data) { + duckdb_re2::RE2 regex(duckdb_re2::StringPiece(delim_data, delim_size)); + if (!regex.ok()) { + throw InvalidInputException(regex.error()); + } + return ConstantRegexpStringSplit::Find(input_data, input_size, delim_data, delim_size, match_size, ®ex); + } +}; + +struct StringSplitter { + template + static idx_t Split(string_t input, string_t delim, StringSplitInput &state, void *data) { + auto input_data = input.GetData(); + auto input_size = input.GetSize(); + auto delim_data = delim.GetData(); + auto delim_size = delim.GetSize(); + idx_t list_idx = 0; + while (input_size > 0) { + idx_t match_size = 0; + auto pos = OP::Find(input_data, input_size, delim_data, delim_size, match_size, data); + if (pos > input_size) { + break; + } + if (match_size == 0 && pos == 0) { + // special case: 0 length match and pos is 0 + // move to the next character + for (pos++; pos < input_size; pos++) { + if (LengthFun::IsCharacter(input_data[pos])) { + break; + } + } + if (pos == input_size) { + break; + } + } + D_ASSERT(input_size >= pos + match_size); + state.AddSplit(input_data, pos, list_idx); + + list_idx++; + input_data += (pos + match_size); + input_size -= (pos + match_size); + } + state.AddSplit(input_data, input_size, list_idx); + list_idx++; + return list_idx; + } +}; + +template +static void StringSplitExecutor(DataChunk &args, ExpressionState &state, Vector &result, void *data = nullptr) { + UnifiedVectorFormat input_data; + args.data[0].ToUnifiedFormat(args.size(), input_data); + auto inputs = UnifiedVectorFormat::GetData(input_data); + + UnifiedVectorFormat delim_data; + args.data[1].ToUnifiedFormat(args.size(), delim_data); + auto delims = UnifiedVectorFormat::GetData(delim_data); + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + + result.SetVectorType(VectorType::FLAT_VECTOR); + ListVector::SetListSize(result, 0); + + auto list_struct_data = FlatVector::GetData(result); + + // count all the splits and set up the list entries + auto &child_entry = ListVector::GetEntry(result); + auto &result_mask = FlatVector::Validity(result); + idx_t total_splits = 0; + for (idx_t i = 0; i < args.size(); i++) { + auto input_idx = input_data.sel->get_index(i); + auto delim_idx = delim_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(input_idx)) { + result_mask.SetInvalid(i); + continue; + } + StringSplitInput split_input(result, child_entry, total_splits); + if (!delim_data.validity.RowIsValid(delim_idx)) { + // delim is NULL: copy the complete entry + split_input.AddSplit(inputs[input_idx].GetData(), inputs[input_idx].GetSize(), 0); + list_struct_data[i].length = 1; + list_struct_data[i].offset = total_splits; + total_splits++; + continue; + } + auto list_length = StringSplitter::Split(inputs[input_idx], delims[delim_idx], split_input, data); + list_struct_data[i].length = list_length; + list_struct_data[i].offset = total_splits; + total_splits += list_length; + } + ListVector::SetListSize(result, total_splits); + D_ASSERT(ListVector::GetListSize(result) == total_splits); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static void StringSplitFunction(DataChunk &args, ExpressionState &state, Vector &result) { + StringSplitExecutor(args, state, result, nullptr); +} + +static void StringSplitRegexFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + if (info.constant_pattern) { + // fast path: pre-compiled regex + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + StringSplitExecutor(args, state, result, &lstate.constant_pattern); + } else { + // slow path: have to re-compile regex for every row + StringSplitExecutor(args, state, result); + } +} + +ScalarFunction StringSplitFun::GetFunction() { + auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); + + ScalarFunction string_split({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitFunction); + string_split.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return string_split; +} + +ScalarFunctionSet StringSplitRegexFun::GetFunctions() { + auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); + ScalarFunctionSet regexp_split; + ScalarFunction regex_fun({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitRegexFunction, + RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING); + regexp_split.AddFunction(regex_fun); + // regexp options + regex_fun.arguments.emplace_back(LogicalType::VARCHAR); + regexp_split.AddFunction(regex_fun); + return regexp_split; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/to_base.cpp b/src/duckdb/src/core_functions/scalar/string/to_base.cpp new file mode 100644 index 00000000..ad5e1088 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/to_base.cpp @@ -0,0 +1,66 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +static const char alphabet[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + +static unique_ptr ToBaseBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // If no min_length is specified, default to 0 + D_ASSERT(arguments.size() == 2 || arguments.size() == 3); + if (arguments.size() == 2) { + arguments.push_back(make_uniq_base(Value::INTEGER(0))); + } + return nullptr; +} + +static void ToBaseFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + auto &radix = args.data[1]; + auto &min_length = args.data[2]; + auto count = args.size(); + + TernaryExecutor::Execute( + input, radix, min_length, result, count, [&](int64_t input, int32_t radix, int32_t min_length) { + if (input < 0) { + throw InvalidInputException("'to_base' number must be greater than or equal to 0"); + } + if (radix < 2 || radix > 36) { + throw InvalidInputException("'to_base' radix must be between 2 and 36"); + } + if (min_length > 64 || min_length < 0) { + throw InvalidInputException("'to_base' min_length must be between 0 and 64"); + } + + char buf[64]; + char *end = buf + sizeof(buf); + char *ptr = end; + do { + *--ptr = alphabet[input % radix]; + input /= radix; + } while (input > 0); + + auto length = end - ptr; + while (length < min_length) { + *--ptr = '0'; + length++; + } + + return StringVector::AddString(result, ptr, end - ptr); + }); +} + +ScalarFunctionSet ToBaseFun::GetFunctions() { + ScalarFunctionSet set("to_base"); + + set.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER}, LogicalType::VARCHAR, ToBaseFunction, ToBaseBind)); + set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER, LogicalType::INTEGER}, + LogicalType::VARCHAR, ToBaseFunction, ToBaseBind)); + + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/translate.cpp b/src/duckdb/src/core_functions/scalar/string/translate.cpp new file mode 100644 index 00000000..44d8b4b1 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/translate.cpp @@ -0,0 +1,96 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" + +#include +#include +#include +#include + +namespace duckdb { + +static string_t TranslateScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread, + vector &result) { + // Get information about the haystack, the needle and the "thread" + auto input_haystack = haystack.GetData(); + auto size_haystack = haystack.GetSize(); + + auto input_needle = needle.GetData(); + auto size_needle = needle.GetSize(); + + auto input_thread = thread.GetData(); + auto size_thread = thread.GetSize(); + + // Reuse the buffer + result.clear(); + result.reserve(size_haystack); + + idx_t i = 0, j = 0; + int sz = 0, c_sz = 0; + + // Character to be replaced + unordered_map to_replace; + while (i < size_needle && j < size_thread) { + auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz); + input_needle += sz; + i += sz; + auto codepoint_thread = Utf8Proc::UTF8ToCodepoint(input_thread, sz); + input_thread += sz; + j += sz; + // Ignore unicode character that is existed in to_replace + if (to_replace.count(codepoint_needle) == 0) { + to_replace[codepoint_needle] = codepoint_thread; + } + } + + // Character to be deleted + unordered_set to_delete; + while (i < size_needle) { + auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz); + input_needle += sz; + i += sz; + // Add unicode character that will be deleted + if (to_replace.count(codepoint_needle) == 0) { + to_delete.insert(codepoint_needle); + } + } + + char c[5] = {'\0', '\0', '\0', '\0', '\0'}; + for (i = 0; i < size_haystack; i += sz) { + auto codepoint_haystack = Utf8Proc::UTF8ToCodepoint(input_haystack, sz); + if (to_replace.count(codepoint_haystack) != 0) { + Utf8Proc::CodepointToUtf8(to_replace[codepoint_haystack], c_sz, c); + result.insert(result.end(), c, c + c_sz); + } else if (to_delete.count(codepoint_haystack) == 0) { + result.insert(result.end(), input_haystack, input_haystack + sz); + } + input_haystack += sz; + } + + return string_t(result.data(), result.size()); +} + +static void TranslateFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &haystack_vector = args.data[0]; + auto &needle_vector = args.data[1]; + auto &thread_vector = args.data[2]; + + vector buffer; + TernaryExecutor::Execute( + haystack_vector, needle_vector, thread_vector, result, args.size(), + [&](string_t input_string, string_t needle_string, string_t thread_string) { + return StringVector::AddString(result, + TranslateScalarFunction(input_string, needle_string, thread_string, buffer)); + }); +} + +ScalarFunction TranslateFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, + TranslateFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/trim.cpp b/src/duckdb/src/core_functions/scalar/string/trim.cpp new file mode 100644 index 00000000..91e3b5dd --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/trim.cpp @@ -0,0 +1,154 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "utf8proc.hpp" + +#include + +namespace duckdb { + +template +struct TrimOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + utf8proc_int32_t codepoint; + auto str = reinterpret_cast(data); + + // Find the first character that is not left trimmed + idx_t begin = 0; + if (LTRIM) { + while (begin < size) { + auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); + D_ASSERT(bytes > 0); + if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { + break; + } + begin += bytes; + } + } + + // Find the last character that is not right trimmed + idx_t end; + if (RTRIM) { + end = begin; + for (auto next = begin; next < size;) { + auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); + D_ASSERT(bytes > 0); + next += bytes; + if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { + end = next; + } + } + } else { + end = size; + } + + // Copy the trimmed string + auto target = StringVector::EmptyString(result, end - begin); + auto output = target.GetDataWriteable(); + memcpy(output, data + begin, end - begin); + + target.Finalize(); + return target; + } +}; + +template +static void UnaryTrimFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::ExecuteString>(args.data[0], result, args.size()); +} + +static void GetIgnoredCodepoints(string_t ignored, unordered_set &ignored_codepoints) { + auto dataptr = reinterpret_cast(ignored.GetData()); + auto size = ignored.GetSize(); + idx_t pos = 0; + while (pos < size) { + utf8proc_int32_t codepoint; + pos += utf8proc_iterate(dataptr + pos, size - pos, &codepoint); + ignored_codepoints.insert(codepoint); + } +} + +template +static void BinaryTrimFunction(DataChunk &input, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + input.data[0], input.data[1], result, input.size(), [&](string_t input, string_t ignored) { + auto data = input.GetData(); + auto size = input.GetSize(); + + unordered_set ignored_codepoints; + GetIgnoredCodepoints(ignored, ignored_codepoints); + + utf8proc_int32_t codepoint; + auto str = reinterpret_cast(data); + + // Find the first character that is not left trimmed + idx_t begin = 0; + if (LTRIM) { + while (begin < size) { + auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); + if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { + break; + } + begin += bytes; + } + } + + // Find the last character that is not right trimmed + idx_t end; + if (RTRIM) { + end = begin; + for (auto next = begin; next < size;) { + auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); + D_ASSERT(bytes > 0); + next += bytes; + if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { + end = next; + } + } + } else { + end = size; + } + + // Copy the trimmed string + auto target = StringVector::EmptyString(result, end - begin); + auto output = target.GetDataWriteable(); + memcpy(output, data + begin, end - begin); + + target.Finalize(); + return target; + }); +} + +ScalarFunctionSet TrimFun::GetFunctions() { + ScalarFunctionSet trim; + trim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); + + trim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, + BinaryTrimFunction)); + return trim; +} + +ScalarFunctionSet LtrimFun::GetFunctions() { + ScalarFunctionSet ltrim; + ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); + ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, + BinaryTrimFunction)); + return ltrim; +} + +ScalarFunctionSet RtrimFun::GetFunctions() { + ScalarFunctionSet rtrim; + rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); + + rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, + BinaryTrimFunction)); + return rtrim; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/unicode.cpp b/src/duckdb/src/core_functions/scalar/string/unicode.cpp new file mode 100644 index 00000000..b621c532 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/unicode.cpp @@ -0,0 +1,28 @@ +#include "duckdb/core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "utf8proc.hpp" + +#include + +namespace duckdb { + +struct UnicodeOperator { + template + static inline TR Operation(const TA &input) { + auto str = reinterpret_cast(input.GetData()); + auto len = input.GetSize(); + utf8proc_int32_t codepoint; + (void)utf8proc_iterate(str, len, &codepoint); + return codepoint; + } +}; + +ScalarFunction UnicodeFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER, + ScalarFunction::UnaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/struct/struct_insert.cpp b/src/duckdb/src/core_functions/scalar/struct/struct_insert.cpp new file mode 100644 index 00000000..3d9753f5 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/struct/struct_insert.cpp @@ -0,0 +1,109 @@ +#include "duckdb/core_functions/scalar/struct_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +static void StructInsertFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &starting_vec = args.data[0]; + + starting_vec.Verify(args.size()); + + auto &starting_child_entries = StructVector::GetEntries(starting_vec); + auto &result_child_entries = StructVector::GetEntries(result); + + // Assign the starting vector entries to the result vector + for (size_t i = 0; i < starting_child_entries.size(); i++) { + auto &starting_child = starting_child_entries[i]; + result_child_entries[i]->Reference(*starting_child); + } + + // Assign the new entries to the result vector + for (size_t i = 1; i < args.ColumnCount(); i++) { + result_child_entries[starting_child_entries.size() + i - 1]->Reference(args.data[i]); + } + + result.Verify(args.size()); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr StructInsertBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + case_insensitive_set_t name_collision_set; + + if (arguments.empty()) { + throw Exception("Missing required arguments for struct_insert function."); + } + + if (LogicalTypeId::STRUCT != arguments[0]->return_type.id()) { + throw Exception("The first argument to struct_insert must be a STRUCT"); + } + + if (arguments.size() < 2) { + throw Exception("Can't insert nothing into a struct"); + } + + child_list_t new_struct_children; + + auto &existing_struct_children = StructType::GetChildTypes(arguments[0]->return_type); + + for (size_t i = 0; i < existing_struct_children.size(); i++) { + auto &child = existing_struct_children[i]; + name_collision_set.insert(child.first); + new_struct_children.push_back(make_pair(child.first, child.second)); + } + + // Loop through the additional arguments (name/value pairs) + for (idx_t i = 1; i < arguments.size(); i++) { + auto &child = arguments[i]; + if (child->alias.empty() && bound_function.name == "struct_insert") { + throw BinderException("Need named argument for struct insert, e.g. STRUCT_PACK(a := b)"); + } + if (name_collision_set.find(child->alias) != name_collision_set.end()) { + throw BinderException("Duplicate struct entry name \"%s\"", child->alias); + } + name_collision_set.insert(child->alias); + new_struct_children.push_back(make_pair(child->alias, arguments[i]->return_type)); + } + + // this is more for completeness reasons + bound_function.return_type = LogicalType::STRUCT(new_struct_children); + return make_uniq(bound_function.return_type); +} + +unique_ptr StructInsertStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + auto new_struct_stats = StructStats::CreateUnknown(expr.return_type); + + auto existing_count = StructType::GetChildCount(child_stats[0].GetType()); + auto existing_stats = StructStats::GetChildStats(child_stats[0]); + for (idx_t i = 0; i < existing_count; i++) { + StructStats::SetChildStats(new_struct_stats, i, existing_stats[i]); + } + auto new_count = StructType::GetChildCount(expr.return_type); + auto offset = new_count - child_stats.size(); + for (idx_t i = 1; i < child_stats.size(); i++) { + StructStats::SetChildStats(new_struct_stats, offset + i, child_stats[i]); + } + return new_struct_stats.ToUnique(); +} + +ScalarFunction StructInsertFun::GetFunction() { + // the arguments and return types are actually set in the binder function + ScalarFunction fun({}, LogicalTypeId::STRUCT, StructInsertFunction, StructInsertBind, nullptr, StructInsertStats); + fun.varargs = LogicalType::ANY; + fun.serialize = VariableReturnBindData::Serialize; + fun.deserialize = VariableReturnBindData::Deserialize; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/struct/struct_pack.cpp b/src/duckdb/src/core_functions/scalar/struct/struct_pack.cpp new file mode 100644 index 00000000..bd6787a3 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/struct/struct_pack.cpp @@ -0,0 +1,93 @@ +#include "duckdb/core_functions/scalar/struct_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +static void StructPackFunction(DataChunk &args, ExpressionState &state, Vector &result) { +#ifdef DEBUG + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + // this should never happen if the binder below is sane + D_ASSERT(args.ColumnCount() == StructType::GetChildTypes(info.stype).size()); +#endif + bool all_const = true; + auto &child_entries = StructVector::GetEntries(result); + for (size_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + all_const = false; + } + // same holds for this + child_entries[i]->Reference(args.data[i]); + } + result.SetVectorType(all_const ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); + + result.Verify(args.size()); +} + +template +static unique_ptr StructPackBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + case_insensitive_set_t name_collision_set; + + // collect names and deconflict, construct return type + if (arguments.empty()) { + throw Exception("Can't pack nothing into a struct"); + } + child_list_t struct_children; + for (idx_t i = 0; i < arguments.size(); i++) { + auto &child = arguments[i]; + string alias; + if (IS_STRUCT_PACK) { + if (child->alias.empty()) { + throw BinderException("Need named argument for struct pack, e.g. STRUCT_PACK(a := b)"); + } + alias = child->alias; + if (name_collision_set.find(alias) != name_collision_set.end()) { + throw BinderException("Duplicate struct entry name \"%s\"", alias); + } + name_collision_set.insert(alias); + } + struct_children.push_back(make_pair(alias, arguments[i]->return_type)); + } + + // this is more for completeness reasons + bound_function.return_type = LogicalType::STRUCT(struct_children); + return make_uniq(bound_function.return_type); +} + +unique_ptr StructPackStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + auto struct_stats = StructStats::CreateUnknown(expr.return_type); + for (idx_t i = 0; i < child_stats.size(); i++) { + StructStats::SetChildStats(struct_stats, i, child_stats[i]); + } + return struct_stats.ToUnique(); +} + +template +ScalarFunction GetStructPackFunction() { + ScalarFunction fun(IS_STRUCT_PACK ? "struct_pack" : "row", {}, LogicalTypeId::STRUCT, StructPackFunction, + StructPackBind, nullptr, StructPackStats); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = VariableReturnBindData::Serialize; + fun.deserialize = VariableReturnBindData::Deserialize; + return fun; +} + +ScalarFunction StructPackFun::GetFunction() { + return GetStructPackFunction(); +} + +ScalarFunction RowFun::GetFunction() { + return GetStructPackFunction(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/union/union_extract.cpp b/src/duckdb/src/core_functions/scalar/union/union_extract.cpp new file mode 100644 index 00000000..fe838cb1 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/union/union_extract.cpp @@ -0,0 +1,106 @@ +#include "duckdb/core_functions/scalar/union_functions.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" + +namespace duckdb { + +struct UnionExtractBindData : public FunctionData { + UnionExtractBindData(string key, idx_t index, LogicalType type) + : key(std::move(key)), index(index), type(std::move(type)) { + } + + string key; + idx_t index; + LogicalType type; + +public: + unique_ptr Copy() const override { + return make_uniq(key, index, type); + } + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return key == other.key && index == other.index && type == other.type; + } +}; + +static void UnionExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + // this should be guaranteed by the binder + auto &vec = args.data[0]; + vec.Verify(args.size()); + + D_ASSERT(info.index < UnionType::GetMemberCount(vec.GetType())); + auto &member = UnionVector::GetMember(vec, info.index); + result.Reference(member); + result.Verify(args.size()); +} + +static unique_ptr UnionExtractBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2); + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + D_ASSERT(LogicalTypeId::UNION == arguments[0]->return_type.id()); + idx_t union_member_count = UnionType::GetMemberCount(arguments[0]->return_type); + if (union_member_count == 0) { + throw InternalException("Can't extract something from an empty union"); + } + bound_function.arguments[0] = arguments[0]->return_type; + + auto &key_child = arguments[1]; + if (key_child->HasParameter()) { + throw ParameterNotResolvedException(); + } + + if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { + throw BinderException("Key name for union_extract needs to be a constant string"); + } + Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); + D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); + auto &key_str = StringValue::Get(key_val); + if (key_val.IsNull() || key_str.empty()) { + throw BinderException("Key name for union_extract needs to be neither NULL nor empty"); + } + string key = StringUtil::Lower(key_str); + + LogicalType return_type; + idx_t key_index = 0; + bool found_key = false; + + for (size_t i = 0; i < union_member_count; i++) { + auto &member_name = UnionType::GetMemberName(arguments[0]->return_type, i); + if (StringUtil::Lower(member_name) == key) { + found_key = true; + key_index = i; + return_type = UnionType::GetMemberType(arguments[0]->return_type, i); + break; + } + } + + if (!found_key) { + vector candidates; + candidates.reserve(union_member_count); + for (idx_t i = 0; i < union_member_count; i++) { + candidates.push_back(UnionType::GetMemberName(arguments[0]->return_type, i)); + } + auto closest_settings = StringUtil::TopNLevenshtein(candidates, key); + auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); + throw BinderException("Could not find key \"%s\" in union\n%s", key, message); + } + + bound_function.return_type = return_type; + return make_uniq(key, key_index, return_type); +} + +ScalarFunction UnionExtractFun::GetFunction() { + // the arguments and return types are actually set in the binder function + return ScalarFunction({LogicalTypeId::UNION, LogicalType::VARCHAR}, LogicalType::ANY, UnionExtractFunction, + UnionExtractBind, nullptr, nullptr); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/union/union_tag.cpp b/src/duckdb/src/core_functions/scalar/union/union_tag.cpp new file mode 100644 index 00000000..431df0ad --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/union/union_tag.cpp @@ -0,0 +1,58 @@ +#include "duckdb/core_functions/scalar/union_functions.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" + +namespace duckdb { + +static unique_ptr UnionTagBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + if (arguments.empty()) { + throw BinderException("Missing required arguments for union_tag function."); + } + + if (LogicalTypeId::UNKNOWN == arguments[0]->return_type.id()) { + throw ParameterNotResolvedException(); + } + + if (LogicalTypeId::UNION != arguments[0]->return_type.id()) { + throw BinderException("First argument to union_tag function must be a union type."); + } + + if (arguments.size() > 1) { + throw BinderException("Too many arguments, union_tag takes at most one argument."); + } + + auto member_count = UnionType::GetMemberCount(arguments[0]->return_type); + if (member_count == 0) { + // this should never happen, empty unions are not allowed + throw InternalException("Can't get tags from an empty union"); + } + + bound_function.arguments[0] = arguments[0]->return_type; + + auto varchar_vector = Vector(LogicalType::VARCHAR, member_count); + for (idx_t i = 0; i < member_count; i++) { + auto str = string_t(UnionType::GetMemberName(arguments[0]->return_type, i)); + FlatVector::GetData(varchar_vector)[i] = + str.IsInlined() ? str : StringVector::AddString(varchar_vector, str); + } + auto enum_type = LogicalType::ENUM(varchar_vector, member_count); + bound_function.return_type = enum_type; + + return nullptr; +} + +static void UnionTagFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalTypeId::ENUM); + result.Reinterpret(UnionVector::GetTags(args.data[0])); +} + +ScalarFunction UnionTagFun::GetFunction() { + return ScalarFunction({LogicalTypeId::UNION}, LogicalTypeId::ANY, UnionTagFunction, UnionTagBind, nullptr, + nullptr); // TODO: Statistics? +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/union/union_value.cpp b/src/duckdb/src/core_functions/scalar/union/union_value.cpp new file mode 100644 index 00000000..6ba7070a --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/union/union_value.cpp @@ -0,0 +1,68 @@ +#include "duckdb/core_functions/scalar/union_functions.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" + +namespace duckdb { + +struct UnionValueBindData : public FunctionData { + UnionValueBindData() { + } + +public: + unique_ptr Copy() const override { + return make_uniq(); + } + bool Equals(const FunctionData &other_p) const override { + return true; + } +}; + +static void UnionValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // Assign the new entries to the result vector + UnionVector::GetMember(result, 0).Reference(args.data[0]); + + // Set the result tag vector to a constant value + auto &tag_vector = UnionVector::GetTags(result); + tag_vector.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(tag_vector)[0] = 0; + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + + result.Verify(args.size()); +} + +static unique_ptr UnionValueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + if (arguments.size() != 1) { + throw BinderException("union_value takes exactly one argument"); + } + auto &child = arguments[0]; + + if (child->alias.empty()) { + throw BinderException("Need named argument for union tag, e.g. UNION_VALUE(a := b)"); + } + + child_list_t union_members; + + union_members.push_back(make_pair(child->alias, child->return_type)); + + bound_function.return_type = LogicalType::UNION(std::move(union_members)); + return make_uniq(bound_function.return_type); +} + +ScalarFunction UnionValueFun::GetFunction() { + ScalarFunction fun("union_value", {}, LogicalTypeId::UNION, UnionValueFunction, UnionValueBind, nullptr, nullptr); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = VariableReturnBindData::Serialize; + fun.deserialize = VariableReturnBindData::Deserialize; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/adaptive_filter.cpp b/src/duckdb/src/execution/adaptive_filter.cpp new file mode 100644 index 00000000..bbadeb63 --- /dev/null +++ b/src/duckdb/src/execution/adaptive_filter.cpp @@ -0,0 +1,90 @@ +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/execution/adaptive_filter.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +AdaptiveFilter::AdaptiveFilter(const Expression &expr) + : iteration_count(0), observe_interval(10), execute_interval(20), warmup(true) { + auto &conj_expr = expr.Cast(); + D_ASSERT(conj_expr.children.size() > 1); + for (idx_t idx = 0; idx < conj_expr.children.size(); idx++) { + permutation.push_back(idx); + if (idx != conj_expr.children.size() - 1) { + swap_likeliness.push_back(100); + } + } + right_random_border = 100 * (conj_expr.children.size() - 1); +} + +AdaptiveFilter::AdaptiveFilter(TableFilterSet *table_filters) + : iteration_count(0), observe_interval(10), execute_interval(20), warmup(true) { + for (auto &table_filter : table_filters->filters) { + permutation.push_back(table_filter.first); + swap_likeliness.push_back(100); + } + swap_likeliness.pop_back(); + right_random_border = 100 * (table_filters->filters.size() - 1); +} +void AdaptiveFilter::AdaptRuntimeStatistics(double duration) { + iteration_count++; + runtime_sum += duration; + + if (!warmup) { + // the last swap was observed + if (observe && iteration_count == observe_interval) { + // keep swap if runtime decreased, else reverse swap + if (prev_mean - (runtime_sum / iteration_count) <= 0) { + // reverse swap because runtime didn't decrease + std::swap(permutation[swap_idx], permutation[swap_idx + 1]); + + // decrease swap likeliness, but make sure there is always a small likeliness left + if (swap_likeliness[swap_idx] > 1) { + swap_likeliness[swap_idx] /= 2; + } + } else { + // keep swap because runtime decreased, reset likeliness + swap_likeliness[swap_idx] = 100; + } + observe = false; + + // reset values + iteration_count = 0; + runtime_sum = 0.0; + } else if (!observe && iteration_count == execute_interval) { + // save old mean to evaluate swap + prev_mean = runtime_sum / iteration_count; + + // get swap index and swap likeliness + std::uniform_int_distribution distribution(1, right_random_border); // a <= i <= b + idx_t random_number = distribution(generator) - 1; + + swap_idx = random_number / 100; // index to be swapped + idx_t likeliness = random_number - 100 * swap_idx; // random number between [0, 100) + + // check if swap is going to happen + if (swap_likeliness[swap_idx] > likeliness) { // always true for the first swap of an index + // swap + std::swap(permutation[swap_idx], permutation[swap_idx + 1]); + + // observe whether swap will be applied + observe = true; + } + + // reset values + iteration_count = 0; + runtime_sum = 0.0; + } + } else { + if (iteration_count == 5) { + // initially set all values + iteration_count = 0; + runtime_sum = 0.0; + observe = false; + warmup = false; + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/aggregate_hashtable.cpp b/src/duckdb/src/execution/aggregate_hashtable.cpp new file mode 100644 index 00000000..be5aa968 --- /dev/null +++ b/src/duckdb/src/execution/aggregate_hashtable.cpp @@ -0,0 +1,535 @@ +#include "duckdb/execution/aggregate_hashtable.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/row/tuple_data_iterator.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +using ValidityBytes = TupleDataLayout::ValidityBytes; + +GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, + vector group_types, vector payload_types, + const vector &bindings, + idx_t initial_capacity, idx_t radix_bits) + : GroupedAggregateHashTable(context, allocator, std::move(group_types), std::move(payload_types), + AggregateObject::CreateAggregateObjects(bindings), initial_capacity, radix_bits) { +} + +GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, + vector group_types) + : GroupedAggregateHashTable(context, allocator, std::move(group_types), {}, vector()) { +} + +GroupedAggregateHashTable::AggregateHTAppendState::AggregateHTAppendState() + : ht_offsets(LogicalType::UBIGINT), hash_salts(LogicalType::HASH), group_compare_vector(STANDARD_VECTOR_SIZE), + no_match_vector(STANDARD_VECTOR_SIZE), empty_vector(STANDARD_VECTOR_SIZE), new_groups(STANDARD_VECTOR_SIZE), + addresses(LogicalType::POINTER) { +} + +GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, + vector group_types_p, + vector payload_types_p, + vector aggregate_objects_p, + idx_t initial_capacity, idx_t radix_bits) + : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), + radix_bits(radix_bits), count(0), capacity(0), aggregate_allocator(make_shared(allocator)) { + + // Append hash column to the end and initialise the row layout + group_types_p.emplace_back(LogicalType::HASH); + layout.Initialize(std::move(group_types_p), std::move(aggregate_objects_p)); + + hash_offset = layout.GetOffsets()[layout.ColumnCount() - 1]; + + // Partitioned data and pointer table + InitializePartitionedData(); + Resize(initial_capacity); + + // Predicates + predicates.resize(layout.ColumnCount() - 1, ExpressionType::COMPARE_NOT_DISTINCT_FROM); + row_matcher.Initialize(true, layout, predicates); +} + +void GroupedAggregateHashTable::InitializePartitionedData() { + if (!partitioned_data || RadixPartitioning::RadixBits(partitioned_data->PartitionCount()) != radix_bits) { + D_ASSERT(!partitioned_data || partitioned_data->Count() == 0); + partitioned_data = + make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); + } else { + partitioned_data->Reset(); + } + + D_ASSERT(GetLayout().GetAggrWidth() == layout.GetAggrWidth()); + D_ASSERT(GetLayout().GetDataWidth() == layout.GetDataWidth()); + D_ASSERT(GetLayout().GetRowWidth() == layout.GetRowWidth()); + + partitioned_data->InitializeAppendState(state.append_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED); +} + +unique_ptr &GroupedAggregateHashTable::GetPartitionedData() { + return partitioned_data; +} + +shared_ptr GroupedAggregateHashTable::GetAggregateAllocator() { + return aggregate_allocator; +} + +GroupedAggregateHashTable::~GroupedAggregateHashTable() { + Destroy(); +} + +void GroupedAggregateHashTable::Destroy() { + if (!partitioned_data || partitioned_data->Count() == 0 || !layout.HasDestructor()) { + return; + } + + // There are aggregates with destructors: Call the destructor for each of the aggregates + // Currently does not happen because aggregate destructors are called while scanning in RadixPartitionedHashTable + // LCOV_EXCL_START + RowOperationsState row_state(*aggregate_allocator); + for (auto &data_collection : partitioned_data->GetPartitions()) { + if (data_collection->Count() == 0) { + continue; + } + TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false); + auto &row_locations = iterator.GetChunkState().row_locations; + do { + RowOperations::DestroyStates(row_state, layout, row_locations, iterator.GetCurrentChunkCount()); + } while (iterator.Next()); + data_collection->Reset(); + } + // LCOV_EXCL_STOP +} + +const TupleDataLayout &GroupedAggregateHashTable::GetLayout() const { + return partitioned_data->GetLayout(); +} + +idx_t GroupedAggregateHashTable::Count() const { + return count; +} + +idx_t GroupedAggregateHashTable::InitialCapacity() { + return STANDARD_VECTOR_SIZE * 2ULL; +} + +idx_t GroupedAggregateHashTable::GetCapacityForCount(idx_t count) { + count = MaxValue(InitialCapacity(), count); + return NextPowerOfTwo(count * LOAD_FACTOR); +} + +idx_t GroupedAggregateHashTable::Capacity() const { + return capacity; +} + +idx_t GroupedAggregateHashTable::ResizeThreshold() const { + return Capacity() / LOAD_FACTOR; +} + +idx_t GroupedAggregateHashTable::ApplyBitMask(hash_t hash) const { + return hash & bitmask; +} + +void GroupedAggregateHashTable::Verify() { +#ifdef DEBUG + idx_t total_count = 0; + for (idx_t i = 0; i < capacity; i++) { + const auto &entry = entries[i]; + if (!entry.IsOccupied()) { + continue; + } + auto hash = Load(entry.GetPointer() + hash_offset); + D_ASSERT(entry.GetSalt() == aggr_ht_entry_t::ExtractSalt(hash)); + total_count++; + } + D_ASSERT(total_count == Count()); +#endif +} + +void GroupedAggregateHashTable::ClearPointerTable() { + std::fill_n(entries, capacity, aggr_ht_entry_t(0)); +} + +void GroupedAggregateHashTable::ResetCount() { + count = 0; +} + +void GroupedAggregateHashTable::SetRadixBits(idx_t radix_bits_p) { + radix_bits = radix_bits_p; +} + +void GroupedAggregateHashTable::Resize(idx_t size) { + D_ASSERT(size >= STANDARD_VECTOR_SIZE); + D_ASSERT(IsPowerOfTwo(size)); + if (size < capacity) { + throw InternalException("Cannot downsize a hash table!"); + } + + capacity = size; + hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(aggr_ht_entry_t)); + entries = reinterpret_cast(hash_map.get()); + ClearPointerTable(); + bitmask = capacity - 1; + + if (Count() != 0) { + for (auto &data_collection : partitioned_data->GetPartitions()) { + if (data_collection->Count() == 0) { + continue; + } + TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::ALREADY_PINNED, false); + const auto row_locations = iterator.GetRowLocations(); + do { + for (idx_t i = 0; i < iterator.GetCurrentChunkCount(); i++) { + const auto &row_location = row_locations[i]; + const auto hash = Load(row_location + hash_offset); + + // Find an empty entry + auto entry_idx = ApplyBitMask(hash); + D_ASSERT(entry_idx == hash % capacity); + while (entries[entry_idx].IsOccupied() > 0) { + entry_idx++; + if (entry_idx >= capacity) { + entry_idx = 0; + } + } + auto &entry = entries[entry_idx]; + D_ASSERT(!entry.IsOccupied()); + entry.SetSalt(aggr_ht_entry_t::ExtractSalt(hash)); + entry.SetPointer(row_location); + D_ASSERT(entry.IsOccupied()); + } + } while (iterator.Next()); + } + } + + Verify(); +} + +idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload, AggregateType filter) { + unsafe_vector aggregate_filter; + + auto &aggregates = layout.GetAggregates(); + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &aggregate = aggregates[i]; + if (aggregate.aggr_type == filter) { + aggregate_filter.push_back(i); + } + } + return AddChunk(groups, payload, aggregate_filter); +} + +idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload, const unsafe_vector &filter) { + Vector hashes(LogicalType::HASH); + groups.Hash(hashes); + + return AddChunk(groups, hashes, payload, filter); +} + +idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashes, DataChunk &payload, + const unsafe_vector &filter) { + if (groups.size() == 0) { + return 0; + } + +#ifdef DEBUG + D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); + for (idx_t i = 0; i < groups.ColumnCount(); i++) { + D_ASSERT(groups.GetTypes()[i] == layout.GetTypes()[i]); + } +#endif + + const auto new_group_count = FindOrCreateGroups(groups, group_hashes, state.addresses, state.new_groups); + VectorOperations::AddInPlace(state.addresses, layout.GetAggrOffset(), payload.size()); + + // Now every cell has an entry, update the aggregates + auto &aggregates = layout.GetAggregates(); + idx_t filter_idx = 0; + idx_t payload_idx = 0; + RowOperationsState row_state(*aggregate_allocator); + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &aggr = aggregates[i]; + if (filter_idx >= filter.size() || i < filter[filter_idx]) { + // Skip all the aggregates that are not in the filter + payload_idx += aggr.child_count; + VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size()); + continue; + } + D_ASSERT(i == filter[filter_idx]); + + if (aggr.aggr_type != AggregateType::DISTINCT && aggr.filter) { + RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(i), aggr, state.addresses, payload, + payload_idx); + } else { + RowOperations::UpdateStates(row_state, aggr, state.addresses, payload, payload_idx, payload.size()); + } + + // Move to the next aggregate + payload_idx += aggr.child_count; + VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size()); + filter_idx++; + } + + Verify(); + return new_group_count; +} + +void GroupedAggregateHashTable::FetchAggregates(DataChunk &groups, DataChunk &result) { +#ifdef DEBUG + groups.Verify(); + D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); + for (idx_t i = 0; i < result.ColumnCount(); i++) { + D_ASSERT(result.data[i].GetType() == payload_types[i]); + } +#endif + + result.SetCardinality(groups); + if (groups.size() == 0) { + return; + } + + // find the groups associated with the addresses + // FIXME: this should not use the FindOrCreateGroups, creating them is unnecessary + Vector addresses(LogicalType::POINTER); + FindOrCreateGroups(groups, addresses); + // now fetch the aggregates + RowOperationsState row_state(*aggregate_allocator); + RowOperations::FinalizeStates(row_state, layout, addresses, result, 0); +} + +idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, Vector &group_hashes_v, + Vector &addresses_v, SelectionVector &new_groups_out) { + D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); + D_ASSERT(group_hashes_v.GetType() == LogicalType::HASH); + D_ASSERT(state.ht_offsets.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(state.ht_offsets.GetType() == LogicalType::UBIGINT); + D_ASSERT(addresses_v.GetType() == LogicalType::POINTER); + D_ASSERT(state.hash_salts.GetType() == LogicalType::HASH); + + // Need to fit the entire vector, and resize at threshold + if (Count() + groups.size() > capacity || Count() + groups.size() > ResizeThreshold()) { + Verify(); + Resize(capacity * 2); + } + D_ASSERT(capacity - Count() >= groups.size()); // we need to be able to fit at least one vector of data + + group_hashes_v.Flatten(groups.size()); + auto hashes = FlatVector::GetData(group_hashes_v); + + addresses_v.Flatten(groups.size()); + auto addresses = FlatVector::GetData(addresses_v); + + // Compute the entry in the table based on the hash using a modulo, + // and precompute the hash salts for faster comparison below + auto ht_offsets = FlatVector::GetData(state.ht_offsets); + auto hash_salts = FlatVector::GetData(state.hash_salts); + for (idx_t r = 0; r < groups.size(); r++) { + const auto &hash = hashes[r]; + ht_offsets[r] = ApplyBitMask(hash); + D_ASSERT(ht_offsets[r] == hash % capacity); + hash_salts[r] = aggr_ht_entry_t::ExtractSalt(hash); + } + + // we start out with all entries [0, 1, 2, ..., groups.size()] + const SelectionVector *sel_vector = FlatVector::IncrementalSelectionVector(); + + // Make a chunk that references the groups and the hashes and convert to unified format + if (state.group_chunk.ColumnCount() == 0) { + state.group_chunk.InitializeEmpty(layout.GetTypes()); + } + D_ASSERT(state.group_chunk.ColumnCount() == layout.GetTypes().size()); + for (idx_t grp_idx = 0; grp_idx < groups.ColumnCount(); grp_idx++) { + state.group_chunk.data[grp_idx].Reference(groups.data[grp_idx]); + } + state.group_chunk.data[groups.ColumnCount()].Reference(group_hashes_v); + state.group_chunk.SetCardinality(groups); + + // convert all vectors to unified format + auto &chunk_state = state.append_state.chunk_state; + TupleDataCollection::ToUnifiedFormat(chunk_state, state.group_chunk); + if (!state.group_data) { + state.group_data = make_unsafe_uniq_array(state.group_chunk.ColumnCount()); + } + TupleDataCollection::GetVectorData(chunk_state, state.group_data.get()); + + idx_t new_group_count = 0; + idx_t remaining_entries = groups.size(); + while (remaining_entries > 0) { + idx_t new_entry_count = 0; + idx_t need_compare_count = 0; + idx_t no_match_count = 0; + + // For each remaining entry, figure out whether or not it belongs to a full or empty group + for (idx_t i = 0; i < remaining_entries; i++) { + const auto index = sel_vector->get_index(i); + const auto &salt = hash_salts[index]; + auto &entry = entries[ht_offsets[index]]; + if (entry.IsOccupied()) { // Cell is occupied: Compare salts + if (entry.GetSalt() == salt) { + state.group_compare_vector.set_index(need_compare_count++, index); + } else { + state.no_match_vector.set_index(no_match_count++, index); + } + } else { // Cell is unoccupied + // Set salt (also marks as occupied) + entry.SetSalt(salt); + + // Update selection lists for outer loops + state.empty_vector.set_index(new_entry_count++, index); + new_groups_out.set_index(new_group_count++, index); + } + } + + if (new_entry_count != 0) { + // Append everything that belongs to an empty group + partitioned_data->AppendUnified(state.append_state, state.group_chunk, state.empty_vector, new_entry_count); + RowOperations::InitializeStates(layout, chunk_state.row_locations, + *FlatVector::IncrementalSelectionVector(), new_entry_count); + + // Set the entry pointers in the 1st part of the HT now that the data has been appended + const auto row_locations = FlatVector::GetData(chunk_state.row_locations); + const auto &row_sel = state.append_state.reverse_partition_sel; + for (idx_t new_entry_idx = 0; new_entry_idx < new_entry_count; new_entry_idx++) { + const auto index = state.empty_vector.get_index(new_entry_idx); + const auto row_idx = row_sel.get_index(index); + const auto &row_location = row_locations[row_idx]; + + auto &entry = entries[ht_offsets[index]]; + + entry.SetPointer(row_location); + addresses[index] = row_location; + } + } + + if (need_compare_count != 0) { + // Get the pointers to the rows that need to be compared + for (idx_t need_compare_idx = 0; need_compare_idx < need_compare_count; need_compare_idx++) { + const auto index = state.group_compare_vector.get_index(need_compare_idx); + const auto &entry = entries[ht_offsets[index]]; + addresses[index] = entry.GetPointer(); + } + + // Perform group comparisons + row_matcher.Match(state.group_chunk, chunk_state.vector_data, state.group_compare_vector, + need_compare_count, layout, addresses_v, &state.no_match_vector, no_match_count); + } + + // Linear probing: each of the entries that do not match move to the next entry in the HT + for (idx_t i = 0; i < no_match_count; i++) { + idx_t index = state.no_match_vector.get_index(i); + ht_offsets[index]++; + if (ht_offsets[index] >= capacity) { + ht_offsets[index] = 0; + } + } + sel_vector = &state.no_match_vector; + remaining_entries = no_match_count; + } + + count += new_group_count; + return new_group_count; +} + +// this is to support distinct aggregations where we need to record whether we +// have already seen a value for a group +idx_t GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &group_hashes, Vector &addresses_out, + SelectionVector &new_groups_out) { + return FindOrCreateGroupsInternal(groups, group_hashes, addresses_out, new_groups_out); +} + +void GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &addresses) { + // create a dummy new_groups sel vector + FindOrCreateGroups(groups, addresses, state.new_groups); +} + +idx_t GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &addresses_out, + SelectionVector &new_groups_out) { + Vector hashes(LogicalType::HASH); + groups.Hash(hashes); + return FindOrCreateGroups(groups, hashes, addresses_out, new_groups_out); +} + +struct FlushMoveState { + explicit FlushMoveState(TupleDataCollection &collection_p) + : collection(collection_p), hashes(LogicalType::HASH), group_addresses(LogicalType::POINTER), + new_groups_sel(STANDARD_VECTOR_SIZE) { + const auto &layout = collection.GetLayout(); + vector column_ids; + column_ids.reserve(layout.ColumnCount() - 1); + for (idx_t col_idx = 0; col_idx < layout.ColumnCount() - 1; col_idx++) { + column_ids.emplace_back(col_idx); + } + collection.InitializeScan(scan_state, column_ids, TupleDataPinProperties::DESTROY_AFTER_DONE); + collection.InitializeScanChunk(scan_state, groups); + hash_col_idx = layout.ColumnCount() - 1; + } + + bool Scan() { + if (collection.Scan(scan_state, groups)) { + collection.Gather(scan_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), + groups.size(), hash_col_idx, hashes, *FlatVector::IncrementalSelectionVector()); + return true; + } + + collection.FinalizePinState(scan_state.pin_state); + return false; + } + + TupleDataCollection &collection; + TupleDataScanState scan_state; + DataChunk groups; + + idx_t hash_col_idx; + Vector hashes; + + Vector group_addresses; + SelectionVector new_groups_sel; +}; + +void GroupedAggregateHashTable::Combine(GroupedAggregateHashTable &other) { + auto other_data = other.partitioned_data->GetUnpartitioned(); + Combine(*other_data); + + // Inherit ownership to all stored aggregate allocators + stored_allocators.emplace_back(other.aggregate_allocator); + for (const auto &stored_allocator : other.stored_allocators) { + stored_allocators.emplace_back(stored_allocator); + } +} + +void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data) { + D_ASSERT(other_data.GetLayout().GetAggrWidth() == layout.GetAggrWidth()); + D_ASSERT(other_data.GetLayout().GetDataWidth() == layout.GetDataWidth()); + D_ASSERT(other_data.GetLayout().GetRowWidth() == layout.GetRowWidth()); + + if (other_data.Count() == 0) { + return; + } + + FlushMoveState fm_state(other_data); + RowOperationsState row_state(*aggregate_allocator); + while (fm_state.Scan()) { + FindOrCreateGroups(fm_state.groups, fm_state.hashes, fm_state.group_addresses, fm_state.new_groups_sel); + RowOperations::CombineStates(row_state, layout, fm_state.scan_state.chunk_state.row_locations, + fm_state.group_addresses, fm_state.groups.size()); + if (layout.HasDestructor()) { + RowOperations::DestroyStates(row_state, layout, fm_state.scan_state.chunk_state.row_locations, + fm_state.groups.size()); + } + } + + Verify(); +} + +void GroupedAggregateHashTable::UnpinData() { + partitioned_data->FlushAppendState(state.append_state); + partitioned_data->Unpin(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/base_aggregate_hashtable.cpp b/src/duckdb/src/execution/base_aggregate_hashtable.cpp new file mode 100644 index 00000000..eec99f9e --- /dev/null +++ b/src/duckdb/src/execution/base_aggregate_hashtable.cpp @@ -0,0 +1,15 @@ +#include "duckdb/execution/base_aggregate_hashtable.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +BaseAggregateHashTable::BaseAggregateHashTable(ClientContext &context, Allocator &allocator, + const vector &aggregates, + vector payload_types_p) + : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), + payload_types(std::move(payload_types_p)) { + filter_set.Initialize(context, aggregates, payload_types); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/column_binding_resolver.cpp b/src/duckdb/src/execution/column_binding_resolver.cpp new file mode 100644 index 00000000..9dd72c59 --- /dev/null +++ b/src/duckdb/src/execution/column_binding_resolver.cpp @@ -0,0 +1,170 @@ +#include "duckdb/execution/column_binding_resolver.hpp" + +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_any_join.hpp" +#include "duckdb/planner/operator/logical_create_index.hpp" +#include "duckdb/planner/operator/logical_insert.hpp" +#include "duckdb/planner/operator/logical_extension_operator.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +ColumnBindingResolver::ColumnBindingResolver() { +} + +void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: { + // special case: comparison join + auto &comp_join = op.Cast(); + // first get the bindings of the LHS and resolve the LHS expressions + VisitOperator(*comp_join.children[0]); + for (auto &cond : comp_join.conditions) { + VisitExpression(&cond.left); + } + // visit the duplicate eliminated columns on the LHS, if any + for (auto &expr : comp_join.duplicate_eliminated_columns) { + VisitExpression(&expr); + } + // then get the bindings of the RHS and resolve the RHS expressions + VisitOperator(*comp_join.children[1]); + for (auto &cond : comp_join.conditions) { + VisitExpression(&cond.right); + } + // finally update the bindings with the result bindings of the join + bindings = op.GetColumnBindings(); + return; + } + case LogicalOperatorType::LOGICAL_ANY_JOIN: { + // ANY join, this join is different because we evaluate the expression on the bindings of BOTH join sides at + // once i.e. we set the bindings first to the bindings of the entire join, and then resolve the expressions of + // this operator + VisitOperatorChildren(op); + bindings = op.GetColumnBindings(); + auto &any_join = op.Cast(); + if (any_join.join_type == JoinType::SEMI || any_join.join_type == JoinType::ANTI) { + auto right_bindings = op.children[1]->GetColumnBindings(); + bindings.insert(bindings.end(), right_bindings.begin(), right_bindings.end()); + } + VisitOperatorExpressions(op); + return; + } + case LogicalOperatorType::LOGICAL_CREATE_INDEX: { + // CREATE INDEX statement, add the columns of the table with table index 0 to the binding set + // afterwards bind the expressions of the CREATE INDEX statement + auto &create_index = op.Cast(); + bindings = LogicalOperator::GenerateColumnBindings(0, create_index.table.GetColumns().LogicalColumnCount()); + VisitOperatorExpressions(op); + return; + } + case LogicalOperatorType::LOGICAL_GET: { + //! We first need to update the current set of bindings and then visit operator expressions + bindings = op.GetColumnBindings(); + VisitOperatorExpressions(op); + return; + } + case LogicalOperatorType::LOGICAL_INSERT: { + //! We want to execute the normal path, but also add a dummy 'excluded' binding if there is a + // ON CONFLICT DO UPDATE clause + auto &insert_op = op.Cast(); + if (insert_op.action_type != OnConflictAction::THROW) { + // Get the bindings from the children + VisitOperatorChildren(op); + auto column_count = insert_op.table.GetColumns().PhysicalColumnCount(); + auto dummy_bindings = LogicalOperator::GenerateColumnBindings(insert_op.excluded_table_index, column_count); + // Now insert our dummy bindings at the start of the bindings, + // so the first 'column_count' indices of the chunk are reserved for our 'excluded' columns + bindings.insert(bindings.begin(), dummy_bindings.begin(), dummy_bindings.end()); + if (insert_op.on_conflict_condition) { + VisitExpression(&insert_op.on_conflict_condition); + } + if (insert_op.do_update_condition) { + VisitExpression(&insert_op.do_update_condition); + } + VisitOperatorExpressions(op); + bindings = op.GetColumnBindings(); + return; + } + break; + } + case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: { + auto &ext_op = op.Cast(); + ext_op.ResolveColumnBindings(*this, bindings); + return; + } + default: + break; + } + + // general case + // first visit the children of this operator + VisitOperatorChildren(op); + // now visit the expressions of this operator to resolve any bound column references + VisitOperatorExpressions(op); + // finally update the current set of bindings to the current set of column bindings + bindings = op.GetColumnBindings(); +} + +unique_ptr ColumnBindingResolver::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + D_ASSERT(expr.depth == 0); + // check the current set of column bindings to see which index corresponds to the column reference + for (idx_t i = 0; i < bindings.size(); i++) { + if (expr.binding == bindings[i]) { + return make_uniq(expr.alias, expr.return_type, i); + } + } + // LCOV_EXCL_START + // could not bind the column reference, this should never happen and indicates a bug in the code + // generate an error message + string bound_columns = "["; + for (idx_t i = 0; i < bindings.size(); i++) { + if (i != 0) { + bound_columns += " "; + } + bound_columns += to_string(bindings[i].table_index) + "." + to_string(bindings[i].column_index); + } + bound_columns += "]"; + + throw InternalException("Failed to bind column reference \"%s\" [%d.%d] (bindings: %s)", expr.alias, + expr.binding.table_index, expr.binding.column_index, bound_columns); + // LCOV_EXCL_STOP +} + +unordered_set ColumnBindingResolver::VerifyInternal(LogicalOperator &op) { + unordered_set result; + for (auto &child : op.children) { + auto child_indexes = VerifyInternal(*child); + for (auto index : child_indexes) { + D_ASSERT(index != DConstants::INVALID_INDEX); + if (result.find(index) != result.end()) { + throw InternalException("Duplicate table index \"%lld\" found", index); + } + result.insert(index); + } + } + auto indexes = op.GetTableIndex(); + for (auto index : indexes) { + D_ASSERT(index != DConstants::INVALID_INDEX); + if (result.find(index) != result.end()) { + throw InternalException("Duplicate table index \"%lld\" found", index); + } + result.insert(index); + } + return result; +} + +void ColumnBindingResolver::Verify(LogicalOperator &op) { +#ifdef DEBUG + VerifyInternal(op); +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp new file mode 100644 index 00000000..8ee35871 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -0,0 +1,311 @@ +#include "duckdb/execution/expression_executor.hpp" + +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/planner/expression/list.hpp" + +namespace duckdb { + +ExpressionExecutor::ExpressionExecutor(ClientContext &context) : context(&context) { +} + +ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression *expression) + : ExpressionExecutor(context) { + D_ASSERT(expression); + AddExpression(*expression); +} + +ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression &expression) + : ExpressionExecutor(context) { + AddExpression(expression); +} + +ExpressionExecutor::ExpressionExecutor(ClientContext &context, const vector> &exprs) + : ExpressionExecutor(context) { + D_ASSERT(exprs.size() > 0); + for (auto &expr : exprs) { + AddExpression(*expr); + } +} + +ExpressionExecutor::ExpressionExecutor(const vector> &exprs) : context(nullptr) { + D_ASSERT(exprs.size() > 0); + for (auto &expr : exprs) { + AddExpression(*expr); + } +} + +ExpressionExecutor::ExpressionExecutor() : context(nullptr) { +} + +bool ExpressionExecutor::HasContext() { + return context; +} + +ClientContext &ExpressionExecutor::GetContext() { + if (!context) { + throw InternalException("Calling ExpressionExecutor::GetContext on an expression executor without a context"); + } + return *context; +} + +Allocator &ExpressionExecutor::GetAllocator() { + return context ? Allocator::Get(*context) : Allocator::DefaultAllocator(); +} + +void ExpressionExecutor::AddExpression(const Expression &expr) { + expressions.push_back(&expr); + auto state = make_uniq(); + Initialize(expr, *state); + state->Verify(); + states.push_back(std::move(state)); +} + +void ExpressionExecutor::Initialize(const Expression &expression, ExpressionExecutorState &state) { + state.executor = this; + state.root_state = InitializeState(expression, state); +} + +void ExpressionExecutor::Execute(DataChunk *input, DataChunk &result) { + SetChunk(input); + D_ASSERT(expressions.size() == result.ColumnCount()); + D_ASSERT(!expressions.empty()); + + for (idx_t i = 0; i < expressions.size(); i++) { + ExecuteExpression(i, result.data[i]); + } + result.SetCardinality(input ? input->size() : 1); + result.Verify(); +} + +void ExpressionExecutor::ExecuteExpression(DataChunk &input, Vector &result) { + SetChunk(&input); + ExecuteExpression(result); +} + +idx_t ExpressionExecutor::SelectExpression(DataChunk &input, SelectionVector &sel) { + D_ASSERT(expressions.size() == 1); + SetChunk(&input); + states[0]->profiler.BeginSample(); + idx_t selected_tuples = Select(*expressions[0], states[0]->root_state.get(), nullptr, input.size(), &sel, nullptr); + states[0]->profiler.EndSample(chunk ? chunk->size() : 0); + return selected_tuples; +} + +void ExpressionExecutor::ExecuteExpression(Vector &result) { + D_ASSERT(expressions.size() == 1); + ExecuteExpression(0, result); +} + +void ExpressionExecutor::ExecuteExpression(idx_t expr_idx, Vector &result) { + D_ASSERT(expr_idx < expressions.size()); + D_ASSERT(result.GetType().id() == expressions[expr_idx]->return_type.id()); + states[expr_idx]->profiler.BeginSample(); + Execute(*expressions[expr_idx], states[expr_idx]->root_state.get(), nullptr, chunk ? chunk->size() : 1, result); + states[expr_idx]->profiler.EndSample(chunk ? chunk->size() : 0); +} + +Value ExpressionExecutor::EvaluateScalar(ClientContext &context, const Expression &expr, bool allow_unfoldable) { + D_ASSERT(allow_unfoldable || expr.IsFoldable()); + D_ASSERT(expr.IsScalar()); + // use an ExpressionExecutor to execute the expression + ExpressionExecutor executor(context, expr); + + Vector result(expr.return_type); + executor.ExecuteExpression(result); + + D_ASSERT(allow_unfoldable || result.GetVectorType() == VectorType::CONSTANT_VECTOR); + auto result_value = result.GetValue(0); + D_ASSERT(result_value.type().InternalType() == expr.return_type.InternalType()); + return result_value; +} + +bool ExpressionExecutor::TryEvaluateScalar(ClientContext &context, const Expression &expr, Value &result) { + try { + result = EvaluateScalar(context, expr); + return true; + } catch (InternalException &ex) { + throw; + } catch (...) { + return false; + } +} + +void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t count) { + D_ASSERT(expr.return_type.id() == vector.GetType().id()); + vector.Verify(count); + if (expr.verification_stats) { + expr.verification_stats->Verify(vector, count); + } +} + +unique_ptr ExpressionExecutor::InitializeState(const Expression &expr, + ExpressionExecutorState &state) { + switch (expr.expression_class) { + case ExpressionClass::BOUND_REF: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_BETWEEN: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_CASE: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_CAST: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_COMPARISON: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_CONJUNCTION: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_CONSTANT: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_FUNCTION: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_OPERATOR: + return InitializeState(expr.Cast(), state); + case ExpressionClass::BOUND_PARAMETER: + return InitializeState(expr.Cast(), state); + default: + throw InternalException("Attempting to initialize state of expression of unknown type!"); + } +} + +void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, Vector &result) { +#ifdef DEBUG + //! The result Vector must be "clean" + if (result.GetVectorType() == VectorType::FLAT_VECTOR) { + D_ASSERT(FlatVector::Validity(result).CheckAllValid(count)); + } +#endif + + if (count == 0) { + return; + } + if (result.GetType().id() != expr.return_type.id()) { + throw InternalException( + "ExpressionExecutor::Execute called with a result vector of type %s that does not match expression type %s", + result.GetType(), expr.return_type); + } + switch (expr.expression_class) { + case ExpressionClass::BOUND_BETWEEN: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_REF: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_CASE: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_CAST: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_COMPARISON: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_CONJUNCTION: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_CONSTANT: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_FUNCTION: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_OPERATOR: + Execute(expr.Cast(), state, sel, count, result); + break; + case ExpressionClass::BOUND_PARAMETER: + Execute(expr.Cast(), state, sel, count, result); + break; + default: + throw InternalException("Attempting to execute expression of unknown type!"); + } + Verify(expr, result, count); +} + +idx_t ExpressionExecutor::Select(const Expression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { + if (count == 0) { + return 0; + } + D_ASSERT(true_sel || false_sel); + D_ASSERT(expr.return_type.id() == LogicalTypeId::BOOLEAN); + switch (expr.expression_class) { + case ExpressionClass::BOUND_BETWEEN: + return Select(expr.Cast(), state, sel, count, true_sel, false_sel); + case ExpressionClass::BOUND_COMPARISON: + return Select(expr.Cast(), state, sel, count, true_sel, false_sel); + case ExpressionClass::BOUND_CONJUNCTION: + return Select(expr.Cast(), state, sel, count, true_sel, false_sel); + default: + return DefaultSelect(expr, state, sel, count, true_sel, false_sel); + } +} + +template +static inline idx_t DefaultSelectLoop(const SelectionVector *bsel, const uint8_t *__restrict bdata, ValidityMask &mask, + const SelectionVector *sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + idx_t true_count = 0, false_count = 0; + for (idx_t i = 0; i < count; i++) { + auto bidx = bsel->get_index(i); + auto result_idx = sel->get_index(i); + if (bdata[bidx] > 0 && (NO_NULL || mask.RowIsValid(bidx))) { + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count++, result_idx); + } + } else { + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count++, result_idx); + } + } + } + if (HAS_TRUE_SEL) { + return true_count; + } else { + return count - false_count; + } +} + +template +static inline idx_t DefaultSelectSwitch(UnifiedVectorFormat &idata, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + if (true_sel && false_sel) { + return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), + idata.validity, sel, count, true_sel, false_sel); + } else if (true_sel) { + return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), + idata.validity, sel, count, true_sel, false_sel); + } else { + D_ASSERT(false_sel); + return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), + idata.validity, sel, count, true_sel, false_sel); + } +} + +idx_t ExpressionExecutor::DefaultSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { + // generic selection of boolean expression: + // resolve the true/false expression first + // then use that to generate the selection vector + bool intermediate_bools[STANDARD_VECTOR_SIZE]; + Vector intermediate(LogicalType::BOOLEAN, data_ptr_cast(intermediate_bools)); + Execute(expr, state, sel, count, intermediate); + + UnifiedVectorFormat idata; + intermediate.ToUnifiedFormat(count, idata); + + if (!sel) { + sel = FlatVector::IncrementalSelectionVector(); + } + if (!idata.validity.AllValid()) { + return DefaultSelectSwitch(idata, sel, count, true_sel, false_sel); + } else { + return DefaultSelectSwitch(idata, sel, count, true_sel, false_sel); + } +} + +vector> &ExpressionExecutor::GetStates() { + return states; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_between.cpp b/src/duckdb/src/execution/expression_executor/execute_between.cpp new file mode 100644 index 00000000..1fc618fe --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_between.cpp @@ -0,0 +1,152 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" + +namespace duckdb { + +struct BothInclusiveBetweenOperator { + template + static inline bool Operation(T input, T lower, T upper) { + return GreaterThanEquals::Operation(input, lower) && LessThanEquals::Operation(input, upper); + } +}; + +struct LowerInclusiveBetweenOperator { + template + static inline bool Operation(T input, T lower, T upper) { + return GreaterThanEquals::Operation(input, lower) && LessThan::Operation(input, upper); + } +}; + +struct UpperInclusiveBetweenOperator { + template + static inline bool Operation(T input, T lower, T upper) { + return GreaterThan::Operation(input, lower) && LessThanEquals::Operation(input, upper); + } +}; + +struct ExclusiveBetweenOperator { + template + static inline bool Operation(T input, T lower, T upper) { + return GreaterThan::Operation(input, lower) && LessThan::Operation(input, upper); + } +}; + +template +static idx_t BetweenLoopTypeSwitch(Vector &input, Vector &lower, Vector &upper, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + switch (input.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::INT16: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::INT32: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::INT64: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::INT128: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::UINT8: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::UINT16: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::UINT32: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::UINT64: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::FLOAT: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, false_sel); + case PhysicalType::DOUBLE: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::VARCHAR: + return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, + false_sel); + case PhysicalType::INTERVAL: + return TernaryExecutor::Select(input, lower, upper, sel, count, + true_sel, false_sel); + default: + throw InvalidTypeException(input.GetType(), "Invalid type for BETWEEN"); + } +} + +unique_ptr ExpressionExecutor::InitializeState(const BoundBetweenExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + result->AddChild(expr.input.get()); + result->AddChild(expr.lower.get()); + result->AddChild(expr.upper.get()); + result->Finalize(); + return result; +} + +void ExpressionExecutor::Execute(const BoundBetweenExpression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, Vector &result) { + // resolve the children + state->intermediate_chunk.Reset(); + + auto &input = state->intermediate_chunk.data[0]; + auto &lower = state->intermediate_chunk.data[1]; + auto &upper = state->intermediate_chunk.data[2]; + + Execute(*expr.input, state->child_states[0].get(), sel, count, input); + Execute(*expr.lower, state->child_states[1].get(), sel, count, lower); + Execute(*expr.upper, state->child_states[2].get(), sel, count, upper); + + Vector intermediate1(LogicalType::BOOLEAN); + Vector intermediate2(LogicalType::BOOLEAN); + + if (expr.upper_inclusive && expr.lower_inclusive) { + VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); + VectorOperations::LessThanEquals(input, upper, intermediate2, count); + } else if (expr.lower_inclusive) { + VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); + VectorOperations::LessThan(input, upper, intermediate2, count); + } else if (expr.upper_inclusive) { + VectorOperations::GreaterThan(input, lower, intermediate1, count); + VectorOperations::LessThanEquals(input, upper, intermediate2, count); + } else { + VectorOperations::GreaterThan(input, lower, intermediate1, count); + VectorOperations::LessThan(input, upper, intermediate2, count); + } + VectorOperations::And(intermediate1, intermediate2, result, count); +} + +idx_t ExpressionExecutor::Select(const BoundBetweenExpression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { + // resolve the children + Vector input(state->intermediate_chunk.data[0]); + Vector lower(state->intermediate_chunk.data[1]); + Vector upper(state->intermediate_chunk.data[2]); + + Execute(*expr.input, state->child_states[0].get(), sel, count, input); + Execute(*expr.lower, state->child_states[1].get(), sel, count, lower); + Execute(*expr.upper, state->child_states[2].get(), sel, count, upper); + + if (expr.upper_inclusive && expr.lower_inclusive) { + return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, + false_sel); + } else if (expr.lower_inclusive) { + return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, + false_sel); + } else if (expr.upper_inclusive) { + return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, + false_sel); + } else { + return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, false_sel); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_case.cpp b/src/duckdb/src/execution/expression_executor/execute_case.cpp new file mode 100644 index 00000000..16159a69 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_case.cpp @@ -0,0 +1,221 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" + +namespace duckdb { + +struct CaseExpressionState : public ExpressionState { + CaseExpressionState(const Expression &expr, ExpressionExecutorState &root) + : ExpressionState(expr, root), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE) { + } + + SelectionVector true_sel; + SelectionVector false_sel; +}; + +unique_ptr ExpressionExecutor::InitializeState(const BoundCaseExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + for (auto &case_check : expr.case_checks) { + result->AddChild(case_check.when_expr.get()); + result->AddChild(case_check.then_expr.get()); + } + result->AddChild(expr.else_expr.get()); + result->Finalize(); + return std::move(result); +} + +void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionState *state_p, const SelectionVector *sel, + idx_t count, Vector &result) { + auto &state = state_p->Cast(); + + state.intermediate_chunk.Reset(); + + // first execute the check expression + auto current_true_sel = &state.true_sel; + auto current_false_sel = &state.false_sel; + auto current_sel = sel; + idx_t current_count = count; + for (idx_t i = 0; i < expr.case_checks.size(); i++) { + auto &case_check = expr.case_checks[i]; + auto &intermediate_result = state.intermediate_chunk.data[i * 2 + 1]; + auto check_state = state.child_states[i * 2].get(); + auto then_state = state.child_states[i * 2 + 1].get(); + + idx_t tcount = + Select(*case_check.when_expr, check_state, current_sel, current_count, current_true_sel, current_false_sel); + if (tcount == 0) { + // everything is false: do nothing + continue; + } + idx_t fcount = current_count - tcount; + if (fcount == 0 && current_count == count) { + // everything is true in the first CHECK statement + // we can skip the entire case and only execute the TRUE side + Execute(*case_check.then_expr, then_state, sel, count, result); + return; + } else { + // we need to execute and then fill in the desired tuples in the result + Execute(*case_check.then_expr, then_state, current_true_sel, tcount, intermediate_result); + FillSwitch(intermediate_result, result, *current_true_sel, tcount); + } + // continue with the false tuples + current_sel = current_false_sel; + current_count = fcount; + if (fcount == 0) { + // everything is true: we are done + break; + } + } + if (current_count > 0) { + auto else_state = state.child_states.back().get(); + if (current_count == count) { + // everything was false, we can just evaluate the else expression directly + Execute(*expr.else_expr, else_state, sel, count, result); + return; + } else { + auto &intermediate_result = state.intermediate_chunk.data[expr.case_checks.size() * 2]; + + D_ASSERT(current_sel); + Execute(*expr.else_expr, else_state, current_sel, current_count, intermediate_result); + FillSwitch(intermediate_result, result, *current_sel, current_count); + } + } + if (sel) { + result.Slice(*sel, count); + } +} + +template +void TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { + result.SetVectorType(VectorType::FLAT_VECTOR); + auto res = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { + auto data = ConstantVector::GetData(vector); + if (ConstantVector::IsNull(vector)) { + for (idx_t i = 0; i < count; i++) { + result_mask.SetInvalid(sel.get_index(i)); + } + } else { + for (idx_t i = 0; i < count; i++) { + res[sel.get_index(i)] = *data; + } + } + } else { + UnifiedVectorFormat vdata; + vector.ToUnifiedFormat(count, vdata); + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto source_idx = vdata.sel->get_index(i); + auto res_idx = sel.get_index(i); + + res[res_idx] = data[source_idx]; + result_mask.Set(res_idx, vdata.validity.RowIsValid(source_idx)); + } + } +} + +void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { + result.SetVectorType(VectorType::FLAT_VECTOR); + auto &result_mask = FlatVector::Validity(result); + if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(vector)) { + for (idx_t i = 0; i < count; i++) { + result_mask.SetInvalid(sel.get_index(i)); + } + } + } else { + UnifiedVectorFormat vdata; + vector.ToUnifiedFormat(count, vdata); + if (vdata.validity.AllValid()) { + return; + } + for (idx_t i = 0; i < count; i++) { + auto source_idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(source_idx)) { + result_mask.SetInvalid(sel.get_index(i)); + } + } + } +} + +void ExpressionExecutor::FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { + switch (result.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::INT16: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::INT32: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::INT64: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::UINT8: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::UINT16: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::UINT32: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::UINT64: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::INT128: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::FLOAT: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::DOUBLE: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::INTERVAL: + TemplatedFillLoop(vector, result, sel, count); + break; + case PhysicalType::VARCHAR: + TemplatedFillLoop(vector, result, sel, count); + StringVector::AddHeapReference(result, vector); + break; + case PhysicalType::STRUCT: { + auto &vector_entries = StructVector::GetEntries(vector); + auto &result_entries = StructVector::GetEntries(result); + ValidityFillLoop(vector, result, sel, count); + D_ASSERT(vector_entries.size() == result_entries.size()); + for (idx_t i = 0; i < vector_entries.size(); i++) { + FillSwitch(*vector_entries[i], *result_entries[i], sel, count); + } + break; + } + case PhysicalType::LIST: { + idx_t offset = ListVector::GetListSize(result); + auto &list_child = ListVector::GetEntry(vector); + ListVector::Append(result, list_child, ListVector::GetListSize(vector)); + + // all the false offsets need to be incremented by true_child.count + TemplatedFillLoop(vector, result, sel, count); + if (offset == 0) { + break; + } + + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto result_idx = sel.get_index(i); + result_data[result_idx].offset += offset; + } + + Vector::Verify(result, sel, count); + break; + } + default: + throw NotImplementedException("Unimplemented type for case expression: %s", result.GetType().ToString()); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_cast.cpp b/src/duckdb/src/execution/expression_executor/execute_cast.cpp new file mode 100644 index 00000000..399b2cab --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_cast.cpp @@ -0,0 +1,43 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +unique_ptr ExpressionExecutor::InitializeState(const BoundCastExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + result->AddChild(expr.child.get()); + result->Finalize(); + if (expr.bound_cast.init_local_state) { + CastLocalStateParameters parameters(root.executor->GetContext(), expr.bound_cast.cast_data); + result->local_state = expr.bound_cast.init_local_state(parameters); + } + return std::move(result); +} + +void ExpressionExecutor::Execute(const BoundCastExpression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, Vector &result) { + auto lstate = ExecuteFunctionState::GetFunctionState(*state); + + // resolve the child + state->intermediate_chunk.Reset(); + + auto &child = state->intermediate_chunk.data[0]; + auto child_state = state->child_states[0].get(); + + Execute(*expr.child, child_state, sel, count, child); + if (expr.try_cast) { + string error_message; + CastParameters parameters(expr.bound_cast.cast_data.get(), false, &error_message, lstate); + expr.bound_cast.function(child, result, count, parameters); + } else { + // cast it to the type specified by the cast expression + D_ASSERT(result.GetType() == expr.return_type); + CastParameters parameters(expr.bound_cast.cast_data.get(), false, nullptr, lstate); + expr.bound_cast.function(child, result, count, parameters); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp new file mode 100644 index 00000000..7ca51462 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp @@ -0,0 +1,310 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" + +#include + +namespace duckdb { + +unique_ptr ExpressionExecutor::InitializeState(const BoundComparisonExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + result->AddChild(expr.left.get()); + result->AddChild(expr.right.get()); + result->Finalize(); + return result; +} + +void ExpressionExecutor::Execute(const BoundComparisonExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, Vector &result) { + // resolve the children + state->intermediate_chunk.Reset(); + auto &left = state->intermediate_chunk.data[0]; + auto &right = state->intermediate_chunk.data[1]; + + Execute(*expr.left, state->child_states[0].get(), sel, count, left); + Execute(*expr.right, state->child_states[1].get(), sel, count, right); + + switch (expr.type) { + case ExpressionType::COMPARE_EQUAL: + VectorOperations::Equals(left, right, result, count); + break; + case ExpressionType::COMPARE_NOTEQUAL: + VectorOperations::NotEquals(left, right, result, count); + break; + case ExpressionType::COMPARE_LESSTHAN: + VectorOperations::LessThan(left, right, result, count); + break; + case ExpressionType::COMPARE_GREATERTHAN: + VectorOperations::GreaterThan(left, right, result, count); + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + VectorOperations::LessThanEquals(left, right, result, count); + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + VectorOperations::GreaterThanEquals(left, right, result, count); + break; + case ExpressionType::COMPARE_DISTINCT_FROM: + VectorOperations::DistinctFrom(left, right, result, count); + break; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + VectorOperations::NotDistinctFrom(left, right, result, count); + break; + default: + throw InternalException("Unknown comparison type!"); + } +} + +template +static idx_t NestedSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + +template +static idx_t TemplatedSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + // the inplace loops take the result as the last parameter + switch (left.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT16: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT32: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT64: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT8: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT16: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT32: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::UINT64: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INT128: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::FLOAT: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::DOUBLE: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::INTERVAL: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::VARCHAR: + return BinaryExecutor::Select(left, right, sel, count, true_sel, false_sel); + case PhysicalType::LIST: + case PhysicalType::STRUCT: + return NestedSelectOperation(left, right, sel, count, true_sel, false_sel); + default: + throw InternalException("Invalid type for comparison"); + } +} + +struct NestedSelector { + // Select the matching rows for the values of a nested type that are not both NULL. + // Those semantics are the same as the corresponding non-distinct comparator + template + static idx_t Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + throw InvalidTypeException(left.GetType(), "Invalid operation for nested SELECT"); + } +}; + +template <> +idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::NestedEquals(left, right, sel, count, true_sel, false_sel); +} + +template <> +idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::NestedNotEquals(left, right, sel, count, true_sel, false_sel); +} + +template <> +idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::DistinctLessThan(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::DistinctLessThanEquals(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); +} + +template <> +idx_t NestedSelector::Select(Vector &left, Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); +} + +static inline idx_t SelectNotNull(Vector &left, Vector &right, const idx_t count, const SelectionVector &sel, + SelectionVector &maybe_vec, OptionalSelection &false_opt) { + + UnifiedVectorFormat lvdata, rvdata; + left.ToUnifiedFormat(count, lvdata); + right.ToUnifiedFormat(count, rvdata); + + auto &lmask = lvdata.validity; + auto &rmask = rvdata.validity; + + // For top-level comparisons, NULL semantics are in effect, + // so filter out any NULLs + idx_t remaining = 0; + if (lmask.AllValid() && rmask.AllValid()) { + // None are NULL, distinguish values. + for (idx_t i = 0; i < count; ++i) { + const auto idx = sel.get_index(i); + maybe_vec.set_index(remaining++, idx); + } + return remaining; + } + + // Slice the Vectors down to the rows that are not determined (i.e., neither is NULL) + SelectionVector slicer(count); + idx_t false_count = 0; + for (idx_t i = 0; i < count; ++i) { + const auto result_idx = sel.get_index(i); + const auto lidx = lvdata.sel->get_index(i); + const auto ridx = rvdata.sel->get_index(i); + if (!lmask.RowIsValid(lidx) || !rmask.RowIsValid(ridx)) { + false_opt.Append(false_count, result_idx); + } else { + // Neither is NULL, distinguish values. + slicer.set_index(remaining, i); + maybe_vec.set_index(remaining++, result_idx); + } + } + false_opt.Advance(false_count); + + if (remaining && remaining < count) { + left.Slice(slicer, remaining); + right.Slice(slicer, remaining); + } + + return remaining; +} + +static void ScatterSelection(SelectionVector *target, const idx_t count, const SelectionVector &dense_vec) { + if (target) { + for (idx_t i = 0; i < count; ++i) { + target->set_index(i, dense_vec.get_index(i)); + } + } +} + +template +static idx_t NestedSelectOperation(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + // The Select operations all use a dense pair of input vectors to partition + // a selection vector in a single pass. But to implement progressive comparisons, + // we have to make multiple passes, so we need to keep track of the original input positions + // and then scatter the output selections when we are done. + if (!sel) { + sel = FlatVector::IncrementalSelectionVector(); + } + + // Make buffered selections for progressive comparisons + // TODO: Remove unnecessary allocations + SelectionVector true_vec(count); + OptionalSelection true_opt(&true_vec); + + SelectionVector false_vec(count); + OptionalSelection false_opt(&false_vec); + + SelectionVector maybe_vec(count); + + // Handle NULL nested values + Vector l_not_null(left); + Vector r_not_null(right); + + auto match_count = SelectNotNull(l_not_null, r_not_null, count, *sel, maybe_vec, false_opt); + auto no_match_count = count - match_count; + count = match_count; + + // Now that we have handled the NULLs, we can use the recursive nested comparator for the rest. + match_count = NestedSelector::Select(l_not_null, r_not_null, maybe_vec, count, true_opt, false_opt); + no_match_count += (count - match_count); + + // Copy the buffered selections to the output selections + ScatterSelection(true_sel, match_count, true_vec); + ScatterSelection(false_sel, no_match_count, false_vec); + + return match_count; +} + +idx_t VectorOperations::Equals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel); +} + +idx_t VectorOperations::NotEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel); +} + +idx_t VectorOperations::GreaterThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel); +} + +idx_t VectorOperations::GreaterThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel); +} + +idx_t VectorOperations::LessThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedSelectOperation(right, left, sel, count, true_sel, false_sel); +} + +idx_t VectorOperations::LessThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + return TemplatedSelectOperation(right, left, sel, count, true_sel, false_sel); +} + +idx_t ExpressionExecutor::Select(const BoundComparisonExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + // resolve the children + state->intermediate_chunk.Reset(); + auto &left = state->intermediate_chunk.data[0]; + auto &right = state->intermediate_chunk.data[1]; + + Execute(*expr.left, state->child_states[0].get(), sel, count, left); + Execute(*expr.right, state->child_states[1].get(), sel, count, right); + + switch (expr.type) { + case ExpressionType::COMPARE_EQUAL: + return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel); + case ExpressionType::COMPARE_NOTEQUAL: + return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel); + case ExpressionType::COMPARE_LESSTHAN: + return VectorOperations::LessThan(left, right, sel, count, true_sel, false_sel); + case ExpressionType::COMPARE_GREATERTHAN: + return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, false_sel); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel); + case ExpressionType::COMPARE_DISTINCT_FROM: + return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, false_sel); + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, false_sel); + default: + throw InternalException("Unknown comparison type!"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp b/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp new file mode 100644 index 00000000..de1fc801 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp @@ -0,0 +1,144 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/execution/adaptive_filter.hpp" +#include "duckdb/common/chrono.hpp" + +#include + +namespace duckdb { + +struct ConjunctionState : public ExpressionState { + ConjunctionState(const Expression &expr, ExpressionExecutorState &root) : ExpressionState(expr, root) { + adaptive_filter = make_uniq(expr); + } + unique_ptr adaptive_filter; +}; + +unique_ptr ExpressionExecutor::InitializeState(const BoundConjunctionExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + for (auto &child : expr.children) { + result->AddChild(child.get()); + } + result->Finalize(); + return std::move(result); +} + +void ExpressionExecutor::Execute(const BoundConjunctionExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, Vector &result) { + // execute the children + state->intermediate_chunk.Reset(); + for (idx_t i = 0; i < expr.children.size(); i++) { + auto ¤t_result = state->intermediate_chunk.data[i]; + Execute(*expr.children[i], state->child_states[i].get(), sel, count, current_result); + if (i == 0) { + // move the result + result.Reference(current_result); + } else { + Vector intermediate(LogicalType::BOOLEAN); + // AND/OR together + switch (expr.type) { + case ExpressionType::CONJUNCTION_AND: + VectorOperations::And(current_result, result, intermediate, count); + break; + case ExpressionType::CONJUNCTION_OR: + VectorOperations::Or(current_result, result, intermediate, count); + break; + default: + throw InternalException("Unknown conjunction type!"); + } + result.Reference(intermediate); + } + } +} + +idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, ExpressionState *state_p, + const SelectionVector *sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + auto &state = state_p->Cast(); + + if (expr.type == ExpressionType::CONJUNCTION_AND) { + // get runtime statistics + auto start_time = high_resolution_clock::now(); + + const SelectionVector *current_sel = sel; + idx_t current_count = count; + idx_t false_count = 0; + + unique_ptr temp_true, temp_false; + if (false_sel) { + temp_false = make_uniq(STANDARD_VECTOR_SIZE); + } + if (!true_sel) { + temp_true = make_uniq(STANDARD_VECTOR_SIZE); + true_sel = temp_true.get(); + } + for (idx_t i = 0; i < expr.children.size(); i++) { + idx_t tcount = Select(*expr.children[state.adaptive_filter->permutation[i]], + state.child_states[state.adaptive_filter->permutation[i]].get(), current_sel, + current_count, true_sel, temp_false.get()); + idx_t fcount = current_count - tcount; + if (fcount > 0 && false_sel) { + // move failing tuples into the false_sel + // tuples passed, move them into the actual result vector + for (idx_t i = 0; i < fcount; i++) { + false_sel->set_index(false_count++, temp_false->get_index(i)); + } + } + current_count = tcount; + if (current_count == 0) { + break; + } + if (current_count < count) { + // tuples were filtered out: move on to using the true_sel to only evaluate passing tuples in subsequent + // iterations + current_sel = true_sel; + } + } + + // adapt runtime statistics + auto end_time = high_resolution_clock::now(); + state.adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); + return current_count; + } else { + // get runtime statistics + auto start_time = high_resolution_clock::now(); + + const SelectionVector *current_sel = sel; + idx_t current_count = count; + idx_t result_count = 0; + + unique_ptr temp_true, temp_false; + if (true_sel) { + temp_true = make_uniq(STANDARD_VECTOR_SIZE); + } + if (!false_sel) { + temp_false = make_uniq(STANDARD_VECTOR_SIZE); + false_sel = temp_false.get(); + } + for (idx_t i = 0; i < expr.children.size(); i++) { + idx_t tcount = Select(*expr.children[state.adaptive_filter->permutation[i]], + state.child_states[state.adaptive_filter->permutation[i]].get(), current_sel, + current_count, temp_true.get(), false_sel); + if (tcount > 0) { + if (true_sel) { + // tuples passed, move them into the actual result vector + for (idx_t i = 0; i < tcount; i++) { + true_sel->set_index(result_count++, temp_true->get_index(i)); + } + } + // now move on to check only the non-passing tuples + current_count -= tcount; + current_sel = false_sel; + } + } + + // adapt runtime statistics + auto end_time = high_resolution_clock::now(); + state.adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); + return result_count; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_constant.cpp b/src/duckdb/src/execution/expression_executor/execute_constant.cpp new file mode 100644 index 00000000..cd9b463b --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_constant.cpp @@ -0,0 +1,20 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +unique_ptr ExpressionExecutor::InitializeState(const BoundConstantExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + result->Finalize(); + return result; +} + +void ExpressionExecutor::Execute(const BoundConstantExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, Vector &result) { + D_ASSERT(expr.value.type() == expr.return_type); + result.Reference(expr.value); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_function.cpp b/src/duckdb/src/execution/expression_executor/execute_function.cpp new file mode 100644 index 00000000..58207ed4 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_function.cpp @@ -0,0 +1,86 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +ExecuteFunctionState::ExecuteFunctionState(const Expression &expr, ExpressionExecutorState &root) + : ExpressionState(expr, root) { +} + +ExecuteFunctionState::~ExecuteFunctionState() { +} + +unique_ptr ExpressionExecutor::InitializeState(const BoundFunctionExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + for (auto &child : expr.children) { + result->AddChild(child.get()); + } + result->Finalize(); + if (expr.function.init_local_state) { + result->local_state = expr.function.init_local_state(*result, expr, expr.bind_info.get()); + } + return std::move(result); +} + +static void VerifyNullHandling(const BoundFunctionExpression &expr, DataChunk &args, Vector &result) { +#ifdef DEBUG + if (args.data.empty() || expr.function.null_handling != FunctionNullHandling::DEFAULT_NULL_HANDLING) { + return; + } + + // Combine all the argument validity masks into a flat validity mask + idx_t count = args.size(); + ValidityMask combined_mask(count); + for (auto &arg : args.data) { + UnifiedVectorFormat arg_data; + arg.ToUnifiedFormat(count, arg_data); + + for (idx_t i = 0; i < count; i++) { + auto idx = arg_data.sel->get_index(i); + if (!arg_data.validity.RowIsValid(idx)) { + combined_mask.SetInvalid(i); + } + } + } + + // Default is that if any of the arguments are NULL, the result is also NULL + UnifiedVectorFormat result_data; + result.ToUnifiedFormat(count, result_data); + for (idx_t i = 0; i < count; i++) { + if (!combined_mask.RowIsValid(i)) { + auto idx = result_data.sel->get_index(i); + D_ASSERT(!result_data.validity.RowIsValid(idx)); + } + } +#endif +} + +void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, Vector &result) { + state->intermediate_chunk.Reset(); + auto &arguments = state->intermediate_chunk; + if (!state->types.empty()) { + for (idx_t i = 0; i < expr.children.size(); i++) { + D_ASSERT(state->types[i] == expr.children[i]->return_type); + Execute(*expr.children[i], state->child_states[i].get(), sel, count, arguments.data[i]); +#ifdef DEBUG + if (expr.children[i]->return_type.id() == LogicalTypeId::VARCHAR) { + arguments.data[i].UTFVerify(count); + } +#endif + } + arguments.Verify(); + } + arguments.SetCardinality(count); + + state->profiler.BeginSample(); + D_ASSERT(expr.function.function); + expr.function.function(arguments, *state, result); + state->profiler.EndSample(count); + + VerifyNullHandling(expr, arguments, result); + D_ASSERT(result.GetType() == expr.return_type); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_operator.cpp b/src/duckdb/src/execution/expression_executor/execute_operator.cpp new file mode 100644 index 00000000..1440b096 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_operator.cpp @@ -0,0 +1,138 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" + +namespace duckdb { + +unique_ptr ExpressionExecutor::InitializeState(const BoundOperatorExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + for (auto &child : expr.children) { + result->AddChild(child.get()); + } + result->Finalize(); + return result; +} + +void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, Vector &result) { + // special handling for special snowflake 'IN' + // IN has n children + if (expr.type == ExpressionType::COMPARE_IN || expr.type == ExpressionType::COMPARE_NOT_IN) { + if (expr.children.size() < 2) { + throw InvalidInputException("IN needs at least two children"); + } + + Vector left(expr.children[0]->return_type); + // eval left side + Execute(*expr.children[0], state->child_states[0].get(), sel, count, left); + + // init result to false + Vector intermediate(LogicalType::BOOLEAN); + Value false_val = Value::BOOLEAN(false); + intermediate.Reference(false_val); + + // in rhs is a list of constants + // for every child, OR the result of the comparision with the left + // to get the overall result. + for (idx_t child = 1; child < expr.children.size(); child++) { + Vector vector_to_check(expr.children[child]->return_type); + Vector comp_res(LogicalType::BOOLEAN); + + Execute(*expr.children[child], state->child_states[child].get(), sel, count, vector_to_check); + VectorOperations::Equals(left, vector_to_check, comp_res, count); + + if (child == 1) { + // first child: move to result + intermediate.Reference(comp_res); + } else { + // otherwise OR together + Vector new_result(LogicalType::BOOLEAN, true, false); + VectorOperations::Or(intermediate, comp_res, new_result, count); + intermediate.Reference(new_result); + } + } + if (expr.type == ExpressionType::COMPARE_NOT_IN) { + // NOT IN: invert result + VectorOperations::Not(intermediate, result, count); + } else { + // directly use the result + result.Reference(intermediate); + } + } else if (expr.type == ExpressionType::OPERATOR_COALESCE) { + SelectionVector sel_a(count); + SelectionVector sel_b(count); + SelectionVector slice_sel(count); + SelectionVector result_sel(count); + SelectionVector *next_sel = &sel_a; + const SelectionVector *current_sel = sel; + idx_t remaining_count = count; + idx_t next_count; + for (idx_t child = 0; child < expr.children.size(); child++) { + Vector vector_to_check(expr.children[child]->return_type); + Execute(*expr.children[child], state->child_states[child].get(), current_sel, remaining_count, + vector_to_check); + + UnifiedVectorFormat vdata; + vector_to_check.ToUnifiedFormat(remaining_count, vdata); + + idx_t result_count = 0; + next_count = 0; + for (idx_t i = 0; i < remaining_count; i++) { + auto base_idx = current_sel ? current_sel->get_index(i) : i; + auto idx = vdata.sel->get_index(i); + if (vdata.validity.RowIsValid(idx)) { + slice_sel.set_index(result_count, i); + result_sel.set_index(result_count++, base_idx); + } else { + next_sel->set_index(next_count++, base_idx); + } + } + if (result_count > 0) { + vector_to_check.Slice(slice_sel, result_count); + FillSwitch(vector_to_check, result, result_sel, result_count); + } + current_sel = next_sel; + next_sel = next_sel == &sel_a ? &sel_b : &sel_a; + remaining_count = next_count; + if (next_count == 0) { + break; + } + } + if (remaining_count > 0) { + for (idx_t i = 0; i < remaining_count; i++) { + FlatVector::SetNull(result, current_sel->get_index(i), true); + } + } + if (sel) { + result.Slice(*sel, count); + } else if (count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + } else if (expr.children.size() == 1) { + state->intermediate_chunk.Reset(); + auto &child = state->intermediate_chunk.data[0]; + + Execute(*expr.children[0], state->child_states[0].get(), sel, count, child); + switch (expr.type) { + case ExpressionType::OPERATOR_NOT: { + VectorOperations::Not(child, result, count); + break; + } + case ExpressionType::OPERATOR_IS_NULL: { + VectorOperations::IsNull(child, result, count); + break; + } + case ExpressionType::OPERATOR_IS_NOT_NULL: { + VectorOperations::IsNotNull(child, result, count); + break; + } + default: + throw NotImplementedException("Unsupported operator type with 1 child!"); + } + } else { + throw NotImplementedException("operator"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_parameter.cpp b/src/duckdb/src/execution/expression_executor/execute_parameter.cpp new file mode 100644 index 00000000..c03ca934 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_parameter.cpp @@ -0,0 +1,22 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" + +namespace duckdb { + +unique_ptr ExpressionExecutor::InitializeState(const BoundParameterExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + result->Finalize(); + return result; +} + +void ExpressionExecutor::Execute(const BoundParameterExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, Vector &result) { + D_ASSERT(expr.parameter_data); + D_ASSERT(expr.parameter_data->return_type == expr.return_type); + D_ASSERT(expr.parameter_data->GetValue().type() == expr.return_type); + result.Reference(expr.parameter_data->GetValue()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_reference.cpp b/src/duckdb/src/execution/expression_executor/execute_reference.cpp new file mode 100644 index 00000000..88fdfa63 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor/execute_reference.cpp @@ -0,0 +1,25 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +unique_ptr ExpressionExecutor::InitializeState(const BoundReferenceExpression &expr, + ExpressionExecutorState &root) { + auto result = make_uniq(expr, root); + result->Finalize(); + return result; +} + +void ExpressionExecutor::Execute(const BoundReferenceExpression &expr, ExpressionState *state, + const SelectionVector *sel, idx_t count, Vector &result) { + D_ASSERT(expr.index != DConstants::INVALID_INDEX); + D_ASSERT(expr.index < chunk->ColumnCount()); + + if (sel) { + result.Slice(chunk->data[expr.index], *sel, count); + } else { + result.Reference(chunk->data[expr.index]); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor_state.cpp b/src/duckdb/src/execution/expression_executor_state.cpp new file mode 100644 index 00000000..401e7615 --- /dev/null +++ b/src/duckdb/src/execution/expression_executor_state.cpp @@ -0,0 +1,52 @@ +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +void ExpressionState::AddChild(Expression *expr) { + types.push_back(expr->return_type); + child_states.push_back(ExpressionExecutor::InitializeState(*expr, root)); +} + +void ExpressionState::Finalize() { + if (!types.empty()) { + intermediate_chunk.Initialize(GetAllocator(), types); + } +} + +Allocator &ExpressionState::GetAllocator() { + return root.executor->GetAllocator(); +} + +bool ExpressionState::HasContext() { + return root.executor->HasContext(); +} + +ClientContext &ExpressionState::GetContext() { + if (!HasContext()) { + throw BinderException("Cannot use %s in this context", (expr.Cast()).function.name); + } + return root.executor->GetContext(); +} + +ExpressionState::ExpressionState(const Expression &expr, ExpressionExecutorState &root) : expr(expr), root(root) { +} + +ExpressionExecutorState::ExpressionExecutorState() : profiler() { +} + +void ExpressionState::Verify(ExpressionExecutorState &root_executor) { + D_ASSERT(&root_executor == &root); + for (auto &entry : child_states) { + entry->Verify(root_executor); + } +} + +void ExpressionExecutorState::Verify() { + D_ASSERT(executor); + root_state->Verify(*this); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp new file mode 100644 index 00000000..0d3cc1ac --- /dev/null +++ b/src/duckdb/src/execution/index/art/art.cpp @@ -0,0 +1,1132 @@ +#include "duckdb/execution/index/art/art.hpp" + +#include "duckdb/common/radix.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/arena_allocator.hpp" +#include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/prefix.hpp" +#include "duckdb/execution/index/art/leaf.hpp" +#include "duckdb/execution/index/art/node4.hpp" +#include "duckdb/execution/index/art/node16.hpp" +#include "duckdb/execution/index/art/node48.hpp" +#include "duckdb/execution/index/art/node256.hpp" +#include "duckdb/execution/index/art/iterator.hpp" +#include "duckdb/common/types/conflict_manager.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/storage/table_io_manager.hpp" + +#include + +namespace duckdb { + +struct ARTIndexScanState : public IndexScanState { + + //! Scan predicates (single predicate scan or range scan) + Value values[2]; + //! Expressions of the scan predicates + ExpressionType expressions[2]; + bool checked = false; + //! All scanned row IDs + vector result_ids; + Iterator iterator; +}; + +ART::ART(const vector &column_ids, TableIOManager &table_io_manager, + const vector> &unbound_expressions, const IndexConstraintType constraint_type, + AttachedDatabase &db, const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr, + const BlockPointer &pointer) + : Index(db, IndexType::ART, table_io_manager, column_ids, unbound_expressions, constraint_type), + allocators(allocators_ptr), owns_data(false) { + if (!Radix::IsLittleEndian()) { + throw NotImplementedException("ART indexes are not supported on big endian architectures"); + } + + // initialize all allocators + if (!allocators) { + owns_data = true; + auto &block_manager = table_io_manager.GetIndexBlockManager(); + + array, ALLOCATOR_COUNT> allocator_array = { + make_uniq(sizeof(Prefix), block_manager), + make_uniq(sizeof(Leaf), block_manager), + make_uniq(sizeof(Node4), block_manager), + make_uniq(sizeof(Node16), block_manager), + make_uniq(sizeof(Node48), block_manager), + make_uniq(sizeof(Node256), block_manager)}; + allocators = make_shared, ALLOCATOR_COUNT>>(std::move(allocator_array)); + } + + if (pointer.IsValid()) { + Deserialize(pointer); + } + + // validate the types of the key columns + for (idx_t i = 0; i < types.size(); i++) { + switch (types[i]) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + case PhysicalType::VARCHAR: + break; + default: + throw InvalidTypeException(logical_types[i], "Invalid type for index key."); + } + } +} + +//===--------------------------------------------------------------------===// +// Initialize Predicate Scans +//===--------------------------------------------------------------------===// + +unique_ptr ART::InitializeScanSinglePredicate(const Transaction &transaction, const Value &value, + const ExpressionType expression_type) { + // initialize point lookup + auto result = make_uniq(); + result->values[0] = value; + result->expressions[0] = expression_type; + return std::move(result); +} + +unique_ptr ART::InitializeScanTwoPredicates(const Transaction &transaction, const Value &low_value, + const ExpressionType low_expression_type, + const Value &high_value, + const ExpressionType high_expression_type) { + // initialize range lookup + auto result = make_uniq(); + result->values[0] = low_value; + result->expressions[0] = low_expression_type; + result->values[1] = high_value; + result->expressions[1] = high_expression_type; + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Keys +//===--------------------------------------------------------------------===// + +template +static void TemplatedGenerateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, vector &keys) { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + D_ASSERT(keys.size() >= count); + auto input_data = UnifiedVectorFormat::GetData(idata); + for (idx_t i = 0; i < count; i++) { + auto idx = idata.sel->get_index(i); + if (idata.validity.RowIsValid(idx)) { + ARTKey::CreateARTKey(allocator, input.GetType(), keys[i], input_data[idx]); + } else { + // we need to possibly reset the former key value in the keys vector + keys[i] = ARTKey(); + } + } +} + +template +static void ConcatenateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, vector &keys) { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + auto input_data = UnifiedVectorFormat::GetData(idata); + for (idx_t i = 0; i < count; i++) { + auto idx = idata.sel->get_index(i); + + // key is not NULL (no previous column entry was NULL) + if (!keys[i].Empty()) { + if (!idata.validity.RowIsValid(idx)) { + // this column entry is NULL, set whole key to NULL + keys[i] = ARTKey(); + } else { + auto other_key = ARTKey::CreateARTKey(allocator, input.GetType(), input_data[idx]); + keys[i].ConcatenateARTKey(allocator, other_key); + } + } + } +} + +void ART::GenerateKeys(ArenaAllocator &allocator, DataChunk &input, vector &keys) { + // generate keys for the first input column + switch (input.data[0].GetType().InternalType()) { + case PhysicalType::BOOL: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::INT8: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::INT16: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::INT32: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::INT64: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::INT128: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::UINT8: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::UINT16: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::UINT32: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::UINT64: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::FLOAT: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::DOUBLE: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + case PhysicalType::VARCHAR: + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + break; + default: + throw InternalException("Invalid type for index"); + } + + for (idx_t i = 1; i < input.ColumnCount(); i++) { + // for each of the remaining columns, concatenate + switch (input.data[i].GetType().InternalType()) { + case PhysicalType::BOOL: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::INT8: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::INT16: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::INT32: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::INT64: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::INT128: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::UINT8: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::UINT16: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::UINT32: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::UINT64: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::FLOAT: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::DOUBLE: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + case PhysicalType::VARCHAR: + ConcatenateKeys(allocator, input.data[i], input.size(), keys); + break; + default: + throw InternalException("Invalid type for index"); + } + } +} + +//===--------------------------------------------------------------------===// +// Construct from sorted data (only during CREATE (UNIQUE) INDEX statements) +//===--------------------------------------------------------------------===// + +struct KeySection { + KeySection(idx_t start_p, idx_t end_p, idx_t depth_p, data_t key_byte_p) + : start(start_p), end(end_p), depth(depth_p), key_byte(key_byte_p) {}; + KeySection(idx_t start_p, idx_t end_p, vector &keys, KeySection &key_section) + : start(start_p), end(end_p), depth(key_section.depth + 1), key_byte(keys[end_p].data[key_section.depth]) {}; + idx_t start; + idx_t end; + idx_t depth; + data_t key_byte; +}; + +void GetChildSections(vector &child_sections, vector &keys, KeySection &key_section) { + + idx_t child_start_idx = key_section.start; + for (idx_t i = key_section.start + 1; i <= key_section.end; i++) { + if (keys[i - 1].data[key_section.depth] != keys[i].data[key_section.depth]) { + child_sections.emplace_back(child_start_idx, i - 1, keys, key_section); + child_start_idx = i; + } + } + child_sections.emplace_back(child_start_idx, key_section.end, keys, key_section); +} + +bool Construct(ART &art, vector &keys, row_t *row_ids, Node &node, KeySection &key_section, + bool &has_constraint) { + + D_ASSERT(key_section.start < keys.size()); + D_ASSERT(key_section.end < keys.size()); + D_ASSERT(key_section.start <= key_section.end); + + auto &start_key = keys[key_section.start]; + auto &end_key = keys[key_section.end]; + + // increment the depth until we reach a leaf or find a mismatching byte + auto prefix_start = key_section.depth; + while (start_key.len != key_section.depth && start_key.ByteMatches(end_key, key_section.depth)) { + key_section.depth++; + } + + // we reached a leaf, i.e. all the bytes of start_key and end_key match + if (start_key.len == key_section.depth) { + // end_idx is inclusive + auto num_row_ids = key_section.end - key_section.start + 1; + + // check for possible constraint violation + auto single_row_id = num_row_ids == 1; + if (has_constraint && !single_row_id) { + return false; + } + + reference ref_node(node); + Prefix::New(art, ref_node, start_key, prefix_start, start_key.len - prefix_start); + if (single_row_id) { + Leaf::New(ref_node, row_ids[key_section.start]); + } else { + Leaf::New(art, ref_node, row_ids + key_section.start, num_row_ids); + } + return true; + } + + // create a new node and recurse + + // we will find at least two child entries of this node, otherwise we'd have reached a leaf + vector child_sections; + GetChildSections(child_sections, keys, key_section); + + // set the prefix + reference ref_node(node); + auto prefix_length = key_section.depth - prefix_start; + Prefix::New(art, ref_node, start_key, prefix_start, prefix_length); + + // set the node + auto node_type = Node::GetARTNodeTypeByCount(child_sections.size()); + Node::New(art, ref_node, node_type); + + // recurse on each child section + for (auto &child_section : child_sections) { + Node new_child; + auto no_violation = Construct(art, keys, row_ids, new_child, child_section, has_constraint); + Node::InsertChild(art, ref_node, child_section.key_byte, new_child); + if (!no_violation) { + return false; + } + } + return true; +} + +bool ART::ConstructFromSorted(idx_t count, vector &keys, Vector &row_identifiers) { + + // prepare the row_identifiers + row_identifiers.Flatten(count); + auto row_ids = FlatVector::GetData(row_identifiers); + + auto key_section = KeySection(0, count - 1, 0, 0); + auto has_constraint = IsUnique(); + if (!Construct(*this, keys, row_ids, tree, key_section, has_constraint)) { + return false; + } + +#ifdef DEBUG + D_ASSERT(!VerifyAndToStringInternal(true).empty()); + for (idx_t i = 0; i < count; i++) { + D_ASSERT(!keys[i].Empty()); + auto leaf = Lookup(tree, keys[i], 0); + D_ASSERT(Leaf::ContainsRowId(*this, *leaf, row_ids[i])); + } +#endif + + return true; +} + +//===--------------------------------------------------------------------===// +// Insert / Verification / Constraint Checking +//===--------------------------------------------------------------------===// +PreservedError ART::Insert(IndexLock &lock, DataChunk &input, Vector &row_ids) { + + D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); + D_ASSERT(logical_types[0] == input.data[0].GetType()); + + // generate the keys for the given input + ArenaAllocator arena_allocator(BufferAllocator::Get(db)); + vector keys(input.size()); + GenerateKeys(arena_allocator, input, keys); + + // get the corresponding row IDs + row_ids.Flatten(input.size()); + auto row_identifiers = FlatVector::GetData(row_ids); + + // now insert the elements into the index + idx_t failed_index = DConstants::INVALID_INDEX; + for (idx_t i = 0; i < input.size(); i++) { + if (keys[i].Empty()) { + continue; + } + + row_t row_id = row_identifiers[i]; + if (!Insert(tree, keys[i], 0, row_id)) { + // failed to insert because of constraint violation + failed_index = i; + break; + } + } + + // failed to insert because of constraint violation: remove previously inserted entries + if (failed_index != DConstants::INVALID_INDEX) { + for (idx_t i = 0; i < failed_index; i++) { + if (keys[i].Empty()) { + continue; + } + row_t row_id = row_identifiers[i]; + Erase(tree, keys[i], 0, row_id); + } + } + + if (failed_index != DConstants::INVALID_INDEX) { + return PreservedError(ConstraintException("PRIMARY KEY or UNIQUE constraint violated: duplicate key \"%s\"", + AppendRowError(input, failed_index))); + } + +#ifdef DEBUG + for (idx_t i = 0; i < input.size(); i++) { + if (keys[i].Empty()) { + continue; + } + + auto leaf = Lookup(tree, keys[i], 0); + D_ASSERT(Leaf::ContainsRowId(*this, *leaf, row_identifiers[i])); + } +#endif + + return PreservedError(); +} + +PreservedError ART::Append(IndexLock &lock, DataChunk &appended_data, Vector &row_identifiers) { + DataChunk expression_result; + expression_result.Initialize(Allocator::DefaultAllocator(), logical_types); + + // first resolve the expressions for the index + ExecuteExpressions(appended_data, expression_result); + + // now insert into the index + return Insert(lock, expression_result, row_identifiers); +} + +void ART::VerifyAppend(DataChunk &chunk) { + ConflictManager conflict_manager(VerifyExistenceType::APPEND, chunk.size()); + CheckConstraintsForChunk(chunk, conflict_manager); +} + +void ART::VerifyAppend(DataChunk &chunk, ConflictManager &conflict_manager) { + D_ASSERT(conflict_manager.LookupType() == VerifyExistenceType::APPEND); + CheckConstraintsForChunk(chunk, conflict_manager); +} + +bool ART::InsertToLeaf(Node &leaf, const row_t &row_id) { + + if (IsUnique()) { + return false; + } + + Leaf::Insert(*this, leaf, row_id); + return true; +} + +bool ART::Insert(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id) { + + // node is currently empty, create a leaf here with the key + if (!node.HasMetadata()) { + D_ASSERT(depth <= key.len); + reference ref_node(node); + Prefix::New(*this, ref_node, key, depth, key.len - depth); + Leaf::New(ref_node, row_id); + return true; + } + + auto node_type = node.GetType(); + + // insert the row ID into this leaf + if (node_type == NType::LEAF || node_type == NType::LEAF_INLINED) { + return InsertToLeaf(node, row_id); + } + + if (node_type != NType::PREFIX) { + D_ASSERT(depth < key.len); + auto child = node.GetChildMutable(*this, key[depth]); + + // recurse, if a child exists at key[depth] + if (child) { + bool success = Insert(*child, key, depth + 1, row_id); + node.ReplaceChild(*this, key[depth], *child); + return success; + } + + // insert a new leaf node at key[depth] + Node leaf_node; + reference ref_node(leaf_node); + if (depth + 1 < key.len) { + Prefix::New(*this, ref_node, key, depth + 1, key.len - depth - 1); + } + Leaf::New(ref_node, row_id); + Node::InsertChild(*this, node, key[depth], leaf_node); + return true; + } + + // this is a prefix node, traverse + reference next_node(node); + auto mismatch_position = Prefix::TraverseMutable(*this, next_node, key, depth); + + // prefix matches key + if (next_node.get().GetType() != NType::PREFIX) { + return Insert(next_node, key, depth, row_id); + } + + // prefix does not match the key, we need to create a new Node4; this new Node4 has two children, + // the remaining part of the prefix, and the new leaf + Node remaining_prefix; + auto prefix_byte = Prefix::GetByte(*this, next_node, mismatch_position); + Prefix::Split(*this, next_node, remaining_prefix, mismatch_position); + Node4::New(*this, next_node); + + // insert remaining prefix + Node4::InsertChild(*this, next_node, prefix_byte, remaining_prefix); + + // insert new leaf + Node leaf_node; + reference ref_node(leaf_node); + if (depth + 1 < key.len) { + Prefix::New(*this, ref_node, key, depth + 1, key.len - depth - 1); + } + Leaf::New(ref_node, row_id); + Node4::InsertChild(*this, next_node, key[depth], leaf_node); + return true; +} + +//===--------------------------------------------------------------------===// +// Drop and Delete +//===--------------------------------------------------------------------===// + +void ART::CommitDrop(IndexLock &index_lock) { + for (auto &allocator : *allocators) { + allocator->Reset(); + } + tree.Clear(); +} + +void ART::Delete(IndexLock &state, DataChunk &input, Vector &row_ids) { + + DataChunk expression; + expression.Initialize(Allocator::DefaultAllocator(), logical_types); + + // first resolve the expressions + ExecuteExpressions(input, expression); + + // then generate the keys for the given input + ArenaAllocator arena_allocator(BufferAllocator::Get(db)); + vector keys(expression.size()); + GenerateKeys(arena_allocator, expression, keys); + + // now erase the elements from the database + row_ids.Flatten(input.size()); + auto row_identifiers = FlatVector::GetData(row_ids); + + for (idx_t i = 0; i < input.size(); i++) { + if (keys[i].Empty()) { + continue; + } + Erase(tree, keys[i], 0, row_identifiers[i]); + } + +#ifdef DEBUG + // verify that we removed all row IDs + for (idx_t i = 0; i < input.size(); i++) { + if (keys[i].Empty()) { + continue; + } + + auto leaf = Lookup(tree, keys[i], 0); + if (leaf) { + D_ASSERT(!Leaf::ContainsRowId(*this, *leaf, row_identifiers[i])); + } + } +#endif +} + +void ART::Erase(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id) { + + if (!node.HasMetadata()) { + return; + } + + // handle prefix + reference next_node(node); + if (next_node.get().GetType() == NType::PREFIX) { + Prefix::TraverseMutable(*this, next_node, key, depth); + if (next_node.get().GetType() == NType::PREFIX) { + return; + } + } + + // delete a row ID from a leaf (root is leaf with possible prefix nodes) + if (next_node.get().GetType() == NType::LEAF || next_node.get().GetType() == NType::LEAF_INLINED) { + if (Leaf::Remove(*this, next_node, row_id)) { + Node::Free(*this, node); + } + return; + } + + D_ASSERT(depth < key.len); + auto child = next_node.get().GetChildMutable(*this, key[depth]); + if (child) { + D_ASSERT(child->HasMetadata()); + + auto temp_depth = depth + 1; + reference child_node(*child); + if (child_node.get().GetType() == NType::PREFIX) { + Prefix::TraverseMutable(*this, child_node, key, temp_depth); + if (child_node.get().GetType() == NType::PREFIX) { + return; + } + } + + if (child_node.get().GetType() == NType::LEAF || child_node.get().GetType() == NType::LEAF_INLINED) { + // leaf found, remove entry + if (Leaf::Remove(*this, child_node, row_id)) { + Node::DeleteChild(*this, next_node, node, key[depth]); + } + return; + } + + // recurse + Erase(*child, key, depth + 1, row_id); + next_node.get().ReplaceChild(*this, key[depth], *child); + } +} + +//===--------------------------------------------------------------------===// +// Point Query (Equal) +//===--------------------------------------------------------------------===// + +static ARTKey CreateKey(ArenaAllocator &allocator, PhysicalType type, Value &value) { + D_ASSERT(type == value.type().InternalType()); + switch (type) { + case PhysicalType::BOOL: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::INT8: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::INT16: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::INT32: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::INT64: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::UINT8: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::UINT16: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::UINT32: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::UINT64: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::INT128: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::FLOAT: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::DOUBLE: + return ARTKey::CreateARTKey(allocator, value.type(), value); + case PhysicalType::VARCHAR: + return ARTKey::CreateARTKey(allocator, value.type(), value); + default: + throw InternalException("Invalid type for the ART key"); + } +} + +bool ART::SearchEqual(ARTKey &key, idx_t max_count, vector &result_ids) { + + auto leaf = Lookup(tree, key, 0); + if (!leaf) { + return true; + } + return Leaf::GetRowIds(*this, *leaf, result_ids, max_count); +} + +void ART::SearchEqualJoinNoFetch(ARTKey &key, idx_t &result_size) { + + // we need to look for a leaf + auto leaf_node = Lookup(tree, key, 0); + if (!leaf_node) { + result_size = 0; + return; + } + + // we only perform index joins on PK/FK columns + D_ASSERT(leaf_node->GetType() == NType::LEAF_INLINED); + result_size = 1; + return; +} + +//===--------------------------------------------------------------------===// +// Lookup +//===--------------------------------------------------------------------===// + +optional_ptr ART::Lookup(const Node &node, const ARTKey &key, idx_t depth) { + + reference node_ref(node); + while (node_ref.get().HasMetadata()) { + + // traverse prefix, if exists + reference next_node(node_ref.get()); + if (next_node.get().GetType() == NType::PREFIX) { + Prefix::Traverse(*this, next_node, key, depth); + if (next_node.get().GetType() == NType::PREFIX) { + return nullptr; + } + } + + if (next_node.get().GetType() == NType::LEAF || next_node.get().GetType() == NType::LEAF_INLINED) { + return &next_node.get(); + } + + D_ASSERT(depth < key.len); + auto child = next_node.get().GetChild(*this, key[depth]); + if (!child) { + // prefix matches key, but no child at byte, ART/subtree does not contain key + return nullptr; + } + + // lookup in child node + node_ref = *child; + D_ASSERT(node_ref.get().HasMetadata()); + depth++; + } + + return nullptr; +} + +//===--------------------------------------------------------------------===// +// Greater Than and Less Than +//===--------------------------------------------------------------------===// + +bool ART::SearchGreater(ARTIndexScanState &state, ARTKey &key, bool equal, idx_t max_count, vector &result_ids) { + + if (!tree.HasMetadata()) { + return true; + } + Iterator &it = state.iterator; + + // find the lowest value that satisfies the predicate + if (!it.art) { + it.art = this; + if (!it.LowerBound(tree, key, equal, 0)) { + // early-out, if the maximum value in the ART is lower than the lower bound + return true; + } + } + + // after that we continue the scan; we don't need to check the bounds as any value following this value is + // automatically bigger and hence satisfies our predicate + ARTKey empty_key = ARTKey(); + return it.Scan(empty_key, max_count, result_ids, false); +} + +bool ART::SearchLess(ARTIndexScanState &state, ARTKey &upper_bound, bool equal, idx_t max_count, + vector &result_ids) { + + if (!tree.HasMetadata()) { + return true; + } + Iterator &it = state.iterator; + + if (!it.art) { + it.art = this; + // find the minimum value in the ART: we start scanning from this value + it.FindMinimum(tree); + // early-out, if the minimum value is higher than the upper bound + if (it.current_key > upper_bound) { + return true; + } + } + + // now continue the scan until we reach the upper bound + return it.Scan(upper_bound, max_count, result_ids, equal); +} + +//===--------------------------------------------------------------------===// +// Closed Range Query +//===--------------------------------------------------------------------===// + +bool ART::SearchCloseRange(ARTIndexScanState &state, ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, + bool right_equal, idx_t max_count, vector &result_ids) { + + Iterator &it = state.iterator; + + // find the first node that satisfies the left predicate + if (!it.art) { + it.art = this; + if (!it.LowerBound(tree, lower_bound, left_equal, 0)) { + // early-out, if the maximum value in the ART is lower than the lower bound + return true; + } + } + + // now continue the scan until we reach the upper bound + return it.Scan(upper_bound, max_count, result_ids, right_equal); +} + +bool ART::Scan(const Transaction &transaction, const DataTable &table, IndexScanState &state, const idx_t max_count, + vector &result_ids) { + + auto &scan_state = state.Cast(); + vector row_ids; + bool success; + + // FIXME: the key directly owning the data for a single key might be more efficient + D_ASSERT(scan_state.values[0].type().InternalType() == types[0]); + ArenaAllocator arena_allocator(Allocator::Get(db)); + auto key = CreateKey(arena_allocator, types[0], scan_state.values[0]); + + if (scan_state.values[1].IsNull()) { + + // single predicate + lock_guard l(lock); + switch (scan_state.expressions[0]) { + case ExpressionType::COMPARE_EQUAL: + success = SearchEqual(key, max_count, row_ids); + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + success = SearchGreater(scan_state, key, true, max_count, row_ids); + break; + case ExpressionType::COMPARE_GREATERTHAN: + success = SearchGreater(scan_state, key, false, max_count, row_ids); + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + success = SearchLess(scan_state, key, true, max_count, row_ids); + break; + case ExpressionType::COMPARE_LESSTHAN: + success = SearchLess(scan_state, key, false, max_count, row_ids); + break; + default: + throw InternalException("Index scan type not implemented"); + } + + } else { + + // two predicates + lock_guard l(lock); + + D_ASSERT(scan_state.values[1].type().InternalType() == types[0]); + auto upper_bound = CreateKey(arena_allocator, types[0], scan_state.values[1]); + + bool left_equal = scan_state.expressions[0] == ExpressionType ::COMPARE_GREATERTHANOREQUALTO; + bool right_equal = scan_state.expressions[1] == ExpressionType ::COMPARE_LESSTHANOREQUALTO; + success = SearchCloseRange(scan_state, key, upper_bound, left_equal, right_equal, max_count, row_ids); + } + + if (!success) { + return false; + } + if (row_ids.empty()) { + return true; + } + + // sort the row ids + sort(row_ids.begin(), row_ids.end()); + // duplicate eliminate the row ids and append them to the row ids of the state + result_ids.reserve(row_ids.size()); + + result_ids.push_back(row_ids[0]); + for (idx_t i = 1; i < row_ids.size(); i++) { + if (row_ids[i] != row_ids[i - 1]) { + result_ids.push_back(row_ids[i]); + } + } + return true; +} + +//===--------------------------------------------------------------------===// +// More Verification / Constraint Checking +//===--------------------------------------------------------------------===// + +string ART::GenerateErrorKeyName(DataChunk &input, idx_t row) { + + // FIXME: why exactly can we not pass the expression_chunk as an argument to this + // FIXME: function instead of re-executing? + // re-executing the expressions is not very fast, but we're going to throw, so we don't care + DataChunk expression_chunk; + expression_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); + ExecuteExpressions(input, expression_chunk); + + string key_name; + for (idx_t k = 0; k < expression_chunk.ColumnCount(); k++) { + if (k > 0) { + key_name += ", "; + } + key_name += unbound_expressions[k]->GetName() + ": " + expression_chunk.data[k].GetValue(row).ToString(); + } + return key_name; +} + +string ART::GenerateConstraintErrorMessage(VerifyExistenceType verify_type, const string &key_name) { + switch (verify_type) { + case VerifyExistenceType::APPEND: { + // APPEND to PK/UNIQUE table, but node/key already exists in PK/UNIQUE table + string type = IsPrimary() ? "primary key" : "unique"; + return StringUtil::Format( + "Duplicate key \"%s\" violates %s constraint. " + "If this is an unexpected constraint violation please double " + "check with the known index limitations section in our documentation (docs - sql - indexes).", + key_name, type); + } + case VerifyExistenceType::APPEND_FK: { + // APPEND_FK to FK table, node/key does not exist in PK/UNIQUE table + return StringUtil::Format( + "Violates foreign key constraint because key \"%s\" does not exist in the referenced table", key_name); + } + case VerifyExistenceType::DELETE_FK: { + // DELETE_FK that still exists in a FK table, i.e., not a valid delete + return StringUtil::Format("Violates foreign key constraint because key \"%s\" is still referenced by a foreign " + "key in a different table", + key_name); + } + default: + throw NotImplementedException("Type not implemented for VerifyExistenceType"); + } +} + +void ART::CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_manager) { + + // don't alter the index during constraint checking + lock_guard l(lock); + + // first resolve the expressions for the index + DataChunk expression_chunk; + expression_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); + ExecuteExpressions(input, expression_chunk); + + // generate the keys for the given input + ArenaAllocator arena_allocator(BufferAllocator::Get(db)); + vector keys(expression_chunk.size()); + GenerateKeys(arena_allocator, expression_chunk, keys); + + idx_t found_conflict = DConstants::INVALID_INDEX; + for (idx_t i = 0; found_conflict == DConstants::INVALID_INDEX && i < input.size(); i++) { + + if (keys[i].Empty()) { + if (conflict_manager.AddNull(i)) { + found_conflict = i; + } + continue; + } + + auto leaf = Lookup(tree, keys[i], 0); + if (!leaf) { + if (conflict_manager.AddMiss(i)) { + found_conflict = i; + } + continue; + } + + // when we find a node, we need to update the 'matches' and 'row_ids' + // NOTE: leaves can have more than one row_id, but for UNIQUE/PRIMARY KEY they will only have one + D_ASSERT(leaf->GetType() == NType::LEAF_INLINED); + if (conflict_manager.AddHit(i, leaf->GetRowId())) { + found_conflict = i; + } + } + + conflict_manager.FinishLookup(); + + if (found_conflict == DConstants::INVALID_INDEX) { + return; + } + + auto key_name = GenerateErrorKeyName(input, found_conflict); + auto exception_msg = GenerateConstraintErrorMessage(conflict_manager.LookupType(), key_name); + throw ConstraintException(exception_msg); +} + +//===--------------------------------------------------------------------===// +// Serialization +//===--------------------------------------------------------------------===// + +BlockPointer ART::Serialize(MetadataWriter &writer) { + + D_ASSERT(owns_data); + + // early-out, if all allocators are empty + if (!tree.HasMetadata()) { + root_block_pointer = BlockPointer(); + return root_block_pointer; + } + + lock_guard l(lock); + auto &block_manager = table_io_manager.GetIndexBlockManager(); + PartialBlockManager partial_block_manager(block_manager, CheckpointType::FULL_CHECKPOINT); + + vector allocator_pointers; + for (auto &allocator : *allocators) { + allocator_pointers.push_back(allocator->Serialize(partial_block_manager, writer)); + } + partial_block_manager.FlushPartialBlocks(); + + root_block_pointer = writer.GetBlockPointer(); + writer.Write(tree); + for (auto &allocator_pointer : allocator_pointers) { + writer.Write(allocator_pointer); + } + + return root_block_pointer; +} + +void ART::Deserialize(const BlockPointer &pointer) { + + D_ASSERT(pointer.IsValid()); + MetadataReader reader(table_io_manager.GetMetadataManager(), pointer); + tree = reader.Read(); + + for (idx_t i = 0; i < ALLOCATOR_COUNT; i++) { + (*allocators)[i]->Deserialize(reader.Read()); + } +} + +//===--------------------------------------------------------------------===// +// Vacuum +//===--------------------------------------------------------------------===// + +void ART::InitializeVacuum(ARTFlags &flags) { + + flags.vacuum_flags.reserve(allocators->size()); + for (auto &allocator : *allocators) { + flags.vacuum_flags.push_back(allocator->InitializeVacuum()); + } +} + +void ART::FinalizeVacuum(const ARTFlags &flags) { + + for (idx_t i = 0; i < allocators->size(); i++) { + if (flags.vacuum_flags[i]) { + (*allocators)[i]->FinalizeVacuum(); + } + } +} + +void ART::Vacuum(IndexLock &state) { + + D_ASSERT(owns_data); + + if (!tree.HasMetadata()) { + for (auto &allocator : *allocators) { + allocator->Reset(); + } + return; + } + + // holds true, if an allocator needs a vacuum, and false otherwise + ARTFlags flags; + InitializeVacuum(flags); + + // skip vacuum if no allocators require it + auto perform_vacuum = false; + for (const auto &vacuum_flag : flags.vacuum_flags) { + if (vacuum_flag) { + perform_vacuum = true; + break; + } + } + if (!perform_vacuum) { + return; + } + + // traverse the allocated memory of the tree to perform a vacuum + tree.Vacuum(*this, flags); + + // finalize the vacuum operation + FinalizeVacuum(flags); +} + +//===--------------------------------------------------------------------===// +// Merging +//===--------------------------------------------------------------------===// + +void ART::InitializeMerge(ARTFlags &flags) { + + D_ASSERT(owns_data); + + flags.merge_buffer_counts.reserve(allocators->size()); + for (auto &allocator : *allocators) { + flags.merge_buffer_counts.emplace_back(allocator->GetUpperBoundBufferId()); + } +} + +bool ART::MergeIndexes(IndexLock &state, Index &other_index) { + + auto &other_art = other_index.Cast(); + if (!other_art.tree.HasMetadata()) { + return true; + } + + if (other_art.owns_data) { + if (tree.HasMetadata()) { + // fully deserialize other_index, and traverse it to increment its buffer IDs + ARTFlags flags; + InitializeMerge(flags); + other_art.tree.InitializeMerge(other_art, flags); + } + + // merge the node storage + for (idx_t i = 0; i < allocators->size(); i++) { + (*allocators)[i]->Merge(*(*other_art.allocators)[i]); + } + } + + // merge the ARTs + if (!tree.Merge(*this, other_art.tree)) { + return false; + } + return true; +} + +//===--------------------------------------------------------------------===// +// Utility +//===--------------------------------------------------------------------===// + +string ART::VerifyAndToString(IndexLock &state, const bool only_verify) { + // FIXME: this can be improved by counting the allocations of each node type, + // FIXME: and by asserting that each fixed-size allocator lists an equal number of + // FIXME: allocations of that type + return VerifyAndToStringInternal(only_verify); +} + +string ART::VerifyAndToStringInternal(const bool only_verify) { + if (tree.HasMetadata()) { + return "ART: " + tree.VerifyAndToString(*this, only_verify); + } + return "[empty]"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/art_key.cpp b/src/duckdb/src/execution/index/art/art_key.cpp new file mode 100644 index 00000000..9cc26be2 --- /dev/null +++ b/src/duckdb/src/execution/index/art/art_key.cpp @@ -0,0 +1,106 @@ +#include "duckdb/execution/index/art/art_key.hpp" + +namespace duckdb { + +ARTKey::ARTKey() : len(0) { +} + +ARTKey::ARTKey(const data_ptr_t &data, const uint32_t &len) : len(len), data(data) { +} + +ARTKey::ARTKey(ArenaAllocator &allocator, const uint32_t &len) : len(len) { + data = allocator.Allocate(len); +} + +template <> +ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, string_t value) { + uint32_t len = value.GetSize() + 1; + auto data = allocator.Allocate(len); + memcpy(data, value.GetData(), len - 1); + + // FIXME: rethink this + if (type == LogicalType::BLOB || type == LogicalType::VARCHAR) { + // indexes cannot contain BLOBs (or BLOBs cast to VARCHARs) that contain null-terminated bytes + for (uint32_t i = 0; i < len - 1; i++) { + if (data[i] == '\0') { + throw NotImplementedException("Indexes cannot contain BLOBs that contain null-terminated bytes."); + } + } + } + + data[len - 1] = '\0'; + return ARTKey(data, len); +} + +template <> +ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, const char *value) { + return ARTKey::CreateARTKey(allocator, type, string_t(value, strlen(value))); +} + +template <> +void ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, string_t value) { + key.len = value.GetSize() + 1; + key.data = allocator.Allocate(key.len); + memcpy(key.data, value.GetData(), key.len - 1); + + // FIXME: rethink this + if (type == LogicalType::BLOB || type == LogicalType::VARCHAR) { + // indexes cannot contain BLOBs (or BLOBs cast to VARCHARs) that contain null-terminated bytes + for (uint32_t i = 0; i < key.len - 1; i++) { + if (key.data[i] == '\0') { + throw NotImplementedException("Indexes cannot contain BLOBs that contain null-terminated bytes."); + } + } + } + + key.data[key.len - 1] = '\0'; +} + +template <> +void ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, const char *value) { + ARTKey::CreateARTKey(allocator, type, key, string_t(value, strlen(value))); +} + +bool ARTKey::operator>(const ARTKey &k) const { + for (uint32_t i = 0; i < MinValue(len, k.len); i++) { + if (data[i] > k.data[i]) { + return true; + } else if (data[i] < k.data[i]) { + return false; + } + } + return len > k.len; +} + +bool ARTKey::operator>=(const ARTKey &k) const { + for (uint32_t i = 0; i < MinValue(len, k.len); i++) { + if (data[i] > k.data[i]) { + return true; + } else if (data[i] < k.data[i]) { + return false; + } + } + return len >= k.len; +} + +bool ARTKey::operator==(const ARTKey &k) const { + if (len != k.len) { + return false; + } + for (uint32_t i = 0; i < len; i++) { + if (data[i] != k.data[i]) { + return false; + } + } + return true; +} + +void ARTKey::ConcatenateARTKey(ArenaAllocator &allocator, ARTKey &other_key) { + + auto compound_data = allocator.Allocate(len + other_key.len); + memcpy(compound_data, data, len); + memcpy(compound_data + len, other_key.data, other_key.len); + len += other_key.len; + data = compound_data; +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/iterator.cpp b/src/duckdb/src/execution/index/art/iterator.cpp new file mode 100644 index 00000000..0d0290eb --- /dev/null +++ b/src/duckdb/src/execution/index/art/iterator.cpp @@ -0,0 +1,211 @@ +#include "duckdb/execution/index/art/iterator.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" +#include "duckdb/execution/index/art/prefix.hpp" + +namespace duckdb { + +bool IteratorKey::operator>(const ARTKey &key) const { + for (idx_t i = 0; i < MinValue(key_bytes.size(), key.len); i++) { + if (key_bytes[i] > key.data[i]) { + return true; + } else if (key_bytes[i] < key.data[i]) { + return false; + } + } + return key_bytes.size() > key.len; +} + +bool IteratorKey::operator>=(const ARTKey &key) const { + for (idx_t i = 0; i < MinValue(key_bytes.size(), key.len); i++) { + if (key_bytes[i] > key.data[i]) { + return true; + } else if (key_bytes[i] < key.data[i]) { + return false; + } + } + return key_bytes.size() >= key.len; +} + +bool IteratorKey::operator==(const ARTKey &key) const { + // NOTE: we only use this for finding the LowerBound, in which case the length + // has to be equal + D_ASSERT(key_bytes.size() == key.len); + for (idx_t i = 0; i < key_bytes.size(); i++) { + if (key_bytes[i] != key.data[i]) { + return false; + } + } + return true; +} + +bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, vector &result_ids, const bool equal) { + + bool has_next; + do { + if (!upper_bound.Empty()) { + // no more row IDs within the key bounds + if (equal) { + if (current_key > upper_bound) { + return true; + } + } else { + if (current_key >= upper_bound) { + return true; + } + } + } + + // copy all row IDs of this leaf into the result IDs (if they don't exceed max_count) + if (!Leaf::GetRowIds(*art, last_leaf, result_ids, max_count)) { + return false; + } + + // get the next leaf + has_next = Next(); + + } while (has_next); + + return true; +} + +void Iterator::FindMinimum(const Node &node) { + + D_ASSERT(node.HasMetadata()); + + // found the minimum + if (node.GetType() == NType::LEAF || node.GetType() == NType::LEAF_INLINED) { + last_leaf = node; + return; + } + + // traverse the prefix + if (node.GetType() == NType::PREFIX) { + auto &prefix = Node::Ref(*art, node, NType::PREFIX); + for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + current_key.Push(prefix.data[i]); + } + nodes.emplace(node, 0); + return FindMinimum(prefix.ptr); + } + + // go to the leftmost entry in the current node and recurse + uint8_t byte = 0; + auto next = node.GetNextChild(*art, byte); + D_ASSERT(next); + current_key.Push(byte); + nodes.emplace(node, byte); + FindMinimum(*next); +} + +bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth) { + + if (!node.HasMetadata()) { + return false; + } + + // we found the lower bound + if (node.GetType() == NType::LEAF || node.GetType() == NType::LEAF_INLINED) { + if (!equal && current_key == key) { + return Next(); + } + last_leaf = node; + return true; + } + + if (node.GetType() != NType::PREFIX) { + auto next_byte = key[depth]; + auto child = node.GetNextChild(*art, next_byte); + if (!child) { + // the key is greater than any key in this subtree + return Next(); + } + + current_key.Push(next_byte); + nodes.emplace(node, next_byte); + + if (next_byte > key[depth]) { + // we only need to find the minimum from here + // because all keys will be greater than the lower bound + FindMinimum(*child); + return true; + } + + // recurse into the child + return LowerBound(*child, key, equal, depth + 1); + } + + // resolve the prefix + auto &prefix = Node::Ref(*art, node, NType::PREFIX); + for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + current_key.Push(prefix.data[i]); + } + nodes.emplace(node, 0); + + for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + // the key down to this node is less than the lower bound, the next key will be + // greater than the lower bound + if (prefix.data[i] < key[depth + i]) { + return Next(); + } + // we only need to find the minimum from here + // because all keys will be greater than the lower bound + if (prefix.data[i] > key[depth + i]) { + FindMinimum(prefix.ptr); + return true; + } + } + + // recurse into the child + depth += prefix.data[Node::PREFIX_SIZE]; + return LowerBound(prefix.ptr, key, equal, depth); +} + +bool Iterator::Next() { + + while (!nodes.empty()) { + + auto &top = nodes.top(); + D_ASSERT(top.node.GetType() != NType::LEAF && top.node.GetType() != NType::LEAF_INLINED); + + if (top.node.GetType() == NType::PREFIX) { + PopNode(); + continue; + } + + if (top.byte == NumericLimits::Maximum()) { + // no node found: move up the tree, pop key byte of current node + PopNode(); + continue; + } + + top.byte++; + auto next_node = top.node.GetNextChild(*art, top.byte); + if (!next_node) { + PopNode(); + continue; + } + + current_key.Pop(1); + current_key.Push(top.byte); + + FindMinimum(*next_node); + return true; + } + return false; +} + +void Iterator::PopNode() { + if (nodes.top().node.GetType() == NType::PREFIX) { + auto &prefix = Node::Ref(*art, nodes.top().node, NType::PREFIX); + auto prefix_byte_count = prefix.data[Node::PREFIX_SIZE]; + current_key.Pop(prefix_byte_count); + } else { + current_key.Pop(1); + } + nodes.pop(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/leaf.cpp b/src/duckdb/src/execution/index/art/leaf.cpp new file mode 100644 index 00000000..1523dd99 --- /dev/null +++ b/src/duckdb/src/execution/index/art/leaf.cpp @@ -0,0 +1,347 @@ +#include "duckdb/execution/index/art/leaf.hpp" + +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +void Leaf::New(Node &node, const row_t row_id) { + + // we directly inline this row ID into the node pointer + D_ASSERT(row_id < MAX_ROW_ID_LOCAL); + node.Clear(); + node.SetMetadata(static_cast(NType::LEAF_INLINED)); + node.SetRowId(row_id); +} + +void Leaf::New(ART &art, reference &node, const row_t *row_ids, idx_t count) { + + D_ASSERT(count > 1); + + idx_t copy_count = 0; + while (count) { + node.get() = Node::GetAllocator(art, NType::LEAF).New(); + node.get().SetMetadata(static_cast(NType::LEAF)); + + auto &leaf = Node::RefMutable(art, node, NType::LEAF); + + leaf.count = MinValue((idx_t)Node::LEAF_SIZE, count); + + for (idx_t i = 0; i < leaf.count; i++) { + leaf.row_ids[i] = row_ids[copy_count + i]; + } + + copy_count += leaf.count; + count -= leaf.count; + + node = leaf.ptr; + leaf.ptr.Clear(); + } +} + +Leaf &Leaf::New(ART &art, Node &node) { + node = Node::GetAllocator(art, NType::LEAF).New(); + node.SetMetadata(static_cast(NType::LEAF)); + auto &leaf = Node::RefMutable(art, node, NType::LEAF); + + leaf.count = 0; + leaf.ptr.Clear(); + return leaf; +} + +void Leaf::Free(ART &art, Node &node) { + + Node current_node = node; + Node next_node; + while (current_node.HasMetadata()) { + next_node = Node::RefMutable(art, current_node, NType::LEAF).ptr; + Node::GetAllocator(art, NType::LEAF).Free(current_node); + current_node = next_node; + } + + node.Clear(); +} + +void Leaf::InitializeMerge(ART &art, Node &node, const ARTFlags &flags) { + + auto merge_buffer_count = flags.merge_buffer_counts[static_cast(NType::LEAF) - 1]; + + Node next_node = node; + node.IncreaseBufferId(merge_buffer_count); + + while (next_node.HasMetadata()) { + auto &leaf = Node::RefMutable(art, next_node, NType::LEAF); + next_node = leaf.ptr; + if (leaf.ptr.HasMetadata()) { + leaf.ptr.IncreaseBufferId(merge_buffer_count); + } + } +} + +void Leaf::Merge(ART &art, Node &l_node, Node &r_node) { + + D_ASSERT(l_node.HasMetadata() && r_node.HasMetadata()); + + // copy inlined row ID of r_node + if (r_node.GetType() == NType::LEAF_INLINED) { + Insert(art, l_node, r_node.GetRowId()); + r_node.Clear(); + return; + } + + // l_node has an inlined row ID, swap and insert + if (l_node.GetType() == NType::LEAF_INLINED) { + auto row_id = l_node.GetRowId(); + l_node = r_node; + Insert(art, l_node, row_id); + r_node.Clear(); + return; + } + + D_ASSERT(l_node.GetType() != NType::LEAF_INLINED); + D_ASSERT(r_node.GetType() != NType::LEAF_INLINED); + + reference l_node_ref(l_node); + reference l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); + + // find a non-full node + while (l_leaf.get().count == Node::LEAF_SIZE) { + l_node_ref = l_leaf.get().ptr; + + // the last leaf is full + if (!l_leaf.get().ptr.HasMetadata()) { + break; + } + l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); + } + + // store the last leaf and then append r_node + auto last_leaf_node = l_node_ref.get(); + l_node_ref.get() = r_node; + r_node.Clear(); + + // append the remaining row IDs of the last leaf node + if (last_leaf_node.HasMetadata()) { + // find the tail + l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); + while (l_leaf.get().ptr.HasMetadata()) { + l_leaf = Node::RefMutable(art, l_leaf.get().ptr, NType::LEAF); + } + // append the row IDs + auto &last_leaf = Node::RefMutable(art, last_leaf_node, NType::LEAF); + for (idx_t i = 0; i < last_leaf.count; i++) { + l_leaf = l_leaf.get().Append(art, last_leaf.row_ids[i]); + } + Node::GetAllocator(art, NType::LEAF).Free(last_leaf_node); + } +} + +void Leaf::Insert(ART &art, Node &node, const row_t row_id) { + + D_ASSERT(node.HasMetadata()); + + if (node.GetType() == NType::LEAF_INLINED) { + MoveInlinedToLeaf(art, node); + Insert(art, node, row_id); + return; + } + + // append to the tail + reference leaf = Node::RefMutable(art, node, NType::LEAF); + while (leaf.get().ptr.HasMetadata()) { + leaf = Node::RefMutable(art, leaf.get().ptr, NType::LEAF); + } + leaf.get().Append(art, row_id); +} + +bool Leaf::Remove(ART &art, reference &node, const row_t row_id) { + + D_ASSERT(node.get().HasMetadata()); + + if (node.get().GetType() == NType::LEAF_INLINED) { + if (node.get().GetRowId() == row_id) { + return true; + } + return false; + } + + reference leaf = Node::RefMutable(art, node, NType::LEAF); + + // inline the remaining row ID + if (leaf.get().count == 2) { + if (leaf.get().row_ids[0] == row_id || leaf.get().row_ids[1] == row_id) { + auto remaining_row_id = leaf.get().row_ids[0] == row_id ? leaf.get().row_ids[1] : leaf.get().row_ids[0]; + Node::Free(art, node); + New(node, remaining_row_id); + } + return false; + } + + // get the last row ID (the order within a leaf does not matter) + // because we want to overwrite the row ID to remove with that one + + // go to the tail and keep track of the previous leaf node + reference prev_leaf(leaf); + while (leaf.get().ptr.HasMetadata()) { + prev_leaf = leaf; + leaf = Node::RefMutable(art, leaf.get().ptr, NType::LEAF); + } + + auto last_idx = leaf.get().count; + auto last_row_id = leaf.get().row_ids[last_idx - 1]; + + // only one row ID in this leaf segment, free it + if (leaf.get().count == 1) { + Node::Free(art, prev_leaf.get().ptr); + if (last_row_id == row_id) { + return false; + } + } else { + leaf.get().count--; + } + + // find the row ID and copy the last row ID to that position + while (node.get().HasMetadata()) { + leaf = Node::RefMutable(art, node, NType::LEAF); + for (idx_t i = 0; i < leaf.get().count; i++) { + if (leaf.get().row_ids[i] == row_id) { + leaf.get().row_ids[i] = last_row_id; + return false; + } + } + node = leaf.get().ptr; + } + return false; +} + +idx_t Leaf::TotalCount(ART &art, const Node &node) { + + D_ASSERT(node.HasMetadata()); + if (node.GetType() == NType::LEAF_INLINED) { + return 1; + } + + idx_t count = 0; + reference node_ref(node); + while (node_ref.get().HasMetadata()) { + auto &leaf = Node::Ref(art, node_ref, NType::LEAF); + count += leaf.count; + node_ref = leaf.ptr; + } + return count; +} + +bool Leaf::GetRowIds(ART &art, const Node &node, vector &result_ids, idx_t max_count) { + + // adding more elements would exceed the maximum count + D_ASSERT(node.HasMetadata()); + if (result_ids.size() + TotalCount(art, node) > max_count) { + return false; + } + + if (node.GetType() == NType::LEAF_INLINED) { + // push back the inlined row ID of this leaf + result_ids.push_back(node.GetRowId()); + + } else { + // push back all the row IDs of this leaf + reference last_leaf_ref(node); + while (last_leaf_ref.get().HasMetadata()) { + auto &leaf = Node::Ref(art, last_leaf_ref, NType::LEAF); + for (idx_t i = 0; i < leaf.count; i++) { + result_ids.push_back(leaf.row_ids[i]); + } + last_leaf_ref = leaf.ptr; + } + } + + return true; +} + +bool Leaf::ContainsRowId(ART &art, const Node &node, const row_t row_id) { + + D_ASSERT(node.HasMetadata()); + + if (node.GetType() == NType::LEAF_INLINED) { + return node.GetRowId() == row_id; + } + + reference ref_node(node); + while (ref_node.get().HasMetadata()) { + auto &leaf = Node::Ref(art, ref_node, NType::LEAF); + for (idx_t i = 0; i < leaf.count; i++) { + if (leaf.row_ids[i] == row_id) { + return true; + } + } + ref_node = leaf.ptr; + } + + return false; +} + +string Leaf::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { + + if (node.GetType() == NType::LEAF_INLINED) { + return only_verify ? "" : "Leaf [count: 1, row ID: " + to_string(node.GetRowId()) + "]"; + } + + string str = ""; + + reference node_ref(node); + while (node_ref.get().HasMetadata()) { + + auto &leaf = Node::Ref(art, node_ref, NType::LEAF); + D_ASSERT(leaf.count <= Node::LEAF_SIZE); + + str += "Leaf [count: " + to_string(leaf.count) + ", row IDs: "; + for (idx_t i = 0; i < leaf.count; i++) { + str += to_string(leaf.row_ids[i]) + "-"; + } + str += "] "; + + node_ref = leaf.ptr; + } + return only_verify ? "" : str; +} + +void Leaf::Vacuum(ART &art, Node &node) { + + auto &allocator = Node::GetAllocator(art, NType::LEAF); + + reference node_ref(node); + while (node_ref.get().HasMetadata()) { + if (allocator.NeedsVacuum(node_ref)) { + node_ref.get() = allocator.VacuumPointer(node_ref); + node_ref.get().SetMetadata(static_cast(NType::LEAF)); + } + auto &leaf = Node::RefMutable(art, node_ref, NType::LEAF); + node_ref = leaf.ptr; + } +} + +void Leaf::MoveInlinedToLeaf(ART &art, Node &node) { + + D_ASSERT(node.GetType() == NType::LEAF_INLINED); + auto row_id = node.GetRowId(); + auto &leaf = New(art, node); + + leaf.count = 1; + leaf.row_ids[0] = row_id; +} + +Leaf &Leaf::Append(ART &art, const row_t row_id) { + + reference leaf(*this); + + // we need a new leaf node + if (leaf.get().count == Node::LEAF_SIZE) { + leaf = New(art, leaf.get().ptr); + } + + leaf.get().row_ids[leaf.get().count] = row_id; + leaf.get().count++; + return leaf.get(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node.cpp b/src/duckdb/src/execution/index/art/node.cpp new file mode 100644 index 00000000..5c82b748 --- /dev/null +++ b/src/duckdb/src/execution/index/art/node.cpp @@ -0,0 +1,518 @@ +#include "duckdb/execution/index/art/node.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/swap.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node256.hpp" +#include "duckdb/execution/index/art/node48.hpp" +#include "duckdb/execution/index/art/node16.hpp" +#include "duckdb/execution/index/art/node4.hpp" +#include "duckdb/execution/index/art/leaf.hpp" +#include "duckdb/execution/index/art/prefix.hpp" +#include "duckdb/storage/table_io_manager.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// New / Free +//===--------------------------------------------------------------------===// + +void Node::New(ART &art, Node &node, const NType type) { + + // NOTE: leaves and prefixes should not pass through this function + + switch (type) { + case NType::NODE_4: + Node4::New(art, node); + break; + case NType::NODE_16: + Node16::New(art, node); + break; + case NType::NODE_48: + Node48::New(art, node); + break; + case NType::NODE_256: + Node256::New(art, node); + break; + default: + throw InternalException("Invalid node type for New."); + } +} + +void Node::Free(ART &art, Node &node) { + + if (!node.HasMetadata()) { + return node.Clear(); + } + + // free the children of the nodes + auto type = node.GetType(); + switch (type) { + case NType::PREFIX: + // iterative + return Prefix::Free(art, node); + case NType::LEAF: + // iterative + return Leaf::Free(art, node); + case NType::NODE_4: + Node4::Free(art, node); + break; + case NType::NODE_16: + Node16::Free(art, node); + break; + case NType::NODE_48: + Node48::Free(art, node); + break; + case NType::NODE_256: + Node256::Free(art, node); + break; + case NType::LEAF_INLINED: + return node.Clear(); + } + + GetAllocator(art, type).Free(node); + node.Clear(); +} + +//===--------------------------------------------------------------------===// +// Get Allocators +//===--------------------------------------------------------------------===// + +FixedSizeAllocator &Node::GetAllocator(const ART &art, const NType type) { + return *(*art.allocators)[static_cast(type) - 1]; +} + +//===--------------------------------------------------------------------===// +// Inserts +//===--------------------------------------------------------------------===// + +void Node::ReplaceChild(const ART &art, const uint8_t byte, const Node child) const { + + switch (GetType()) { + case NType::NODE_4: + return RefMutable(art, *this, NType::NODE_4).ReplaceChild(byte, child); + case NType::NODE_16: + return RefMutable(art, *this, NType::NODE_16).ReplaceChild(byte, child); + case NType::NODE_48: + return RefMutable(art, *this, NType::NODE_48).ReplaceChild(byte, child); + case NType::NODE_256: + return RefMutable(art, *this, NType::NODE_256).ReplaceChild(byte, child); + default: + throw InternalException("Invalid node type for ReplaceChild."); + } +} + +void Node::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + + switch (node.GetType()) { + case NType::NODE_4: + return Node4::InsertChild(art, node, byte, child); + case NType::NODE_16: + return Node16::InsertChild(art, node, byte, child); + case NType::NODE_48: + return Node48::InsertChild(art, node, byte, child); + case NType::NODE_256: + return Node256::InsertChild(art, node, byte, child); + default: + throw InternalException("Invalid node type for InsertChild."); + } +} + +//===--------------------------------------------------------------------===// +// Deletes +//===--------------------------------------------------------------------===// + +void Node::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte) { + + switch (node.GetType()) { + case NType::NODE_4: + return Node4::DeleteChild(art, node, prefix, byte); + case NType::NODE_16: + return Node16::DeleteChild(art, node, byte); + case NType::NODE_48: + return Node48::DeleteChild(art, node, byte); + case NType::NODE_256: + return Node256::DeleteChild(art, node, byte); + default: + throw InternalException("Invalid node type for DeleteChild."); + } +} + +//===--------------------------------------------------------------------===// +// Get functions +//===--------------------------------------------------------------------===// + +optional_ptr Node::GetChild(ART &art, const uint8_t byte) const { + + D_ASSERT(HasMetadata()); + + switch (GetType()) { + case NType::NODE_4: + return Ref(art, *this, NType::NODE_4).GetChild(byte); + case NType::NODE_16: + return Ref(art, *this, NType::NODE_16).GetChild(byte); + case NType::NODE_48: + return Ref(art, *this, NType::NODE_48).GetChild(byte); + case NType::NODE_256: + return Ref(art, *this, NType::NODE_256).GetChild(byte); + default: + throw InternalException("Invalid node type for GetChild."); + } +} + +optional_ptr Node::GetChildMutable(ART &art, const uint8_t byte) const { + + D_ASSERT(HasMetadata()); + + switch (GetType()) { + case NType::NODE_4: + return RefMutable(art, *this, NType::NODE_4).GetChildMutable(byte); + case NType::NODE_16: + return RefMutable(art, *this, NType::NODE_16).GetChildMutable(byte); + case NType::NODE_48: + return RefMutable(art, *this, NType::NODE_48).GetChildMutable(byte); + case NType::NODE_256: + return RefMutable(art, *this, NType::NODE_256).GetChildMutable(byte); + default: + throw InternalException("Invalid node type for GetChildMutable."); + } +} + +optional_ptr Node::GetNextChild(ART &art, uint8_t &byte) const { + + D_ASSERT(HasMetadata()); + + switch (GetType()) { + case NType::NODE_4: + return Ref(art, *this, NType::NODE_4).GetNextChild(byte); + case NType::NODE_16: + return Ref(art, *this, NType::NODE_16).GetNextChild(byte); + case NType::NODE_48: + return Ref(art, *this, NType::NODE_48).GetNextChild(byte); + case NType::NODE_256: + return Ref(art, *this, NType::NODE_256).GetNextChild(byte); + default: + throw InternalException("Invalid node type for GetNextChild."); + } +} + +optional_ptr Node::GetNextChildMutable(ART &art, uint8_t &byte) const { + + D_ASSERT(HasMetadata()); + + switch (GetType()) { + case NType::NODE_4: + return RefMutable(art, *this, NType::NODE_4).GetNextChildMutable(byte); + case NType::NODE_16: + return RefMutable(art, *this, NType::NODE_16).GetNextChildMutable(byte); + case NType::NODE_48: + return RefMutable(art, *this, NType::NODE_48).GetNextChildMutable(byte); + case NType::NODE_256: + return RefMutable(art, *this, NType::NODE_256).GetNextChildMutable(byte); + default: + throw InternalException("Invalid node type for GetNextChildMutable."); + } +} + +//===--------------------------------------------------------------------===// +// Utility +//===--------------------------------------------------------------------===// + +string Node::VerifyAndToString(ART &art, const bool only_verify) const { + + D_ASSERT(HasMetadata()); + + if (GetType() == NType::LEAF || GetType() == NType::LEAF_INLINED) { + auto str = Leaf::VerifyAndToString(art, *this, only_verify); + return only_verify ? "" : "\n" + str; + } + if (GetType() == NType::PREFIX) { + auto str = Prefix::VerifyAndToString(art, *this, only_verify); + return only_verify ? "" : "\n" + str; + } + + string str = "Node" + to_string(GetCapacity()) + ": ["; + uint8_t byte = 0; + auto child = GetNextChild(art, byte); + + while (child) { + str += "(" + to_string(byte) + ", " + child->VerifyAndToString(art, only_verify) + ")"; + if (byte == NumericLimits::Maximum()) { + break; + } + + byte++; + child = GetNextChild(art, byte); + } + + return only_verify ? "" : "\n" + str + "]"; +} + +idx_t Node::GetCapacity() const { + + switch (GetType()) { + case NType::NODE_4: + return NODE_4_CAPACITY; + case NType::NODE_16: + return NODE_16_CAPACITY; + case NType::NODE_48: + return NODE_48_CAPACITY; + case NType::NODE_256: + return NODE_256_CAPACITY; + default: + throw InternalException("Invalid node type for GetCapacity."); + } +} + +NType Node::GetARTNodeTypeByCount(const idx_t count) { + + if (count <= NODE_4_CAPACITY) { + return NType::NODE_4; + } else if (count <= NODE_16_CAPACITY) { + return NType::NODE_16; + } else if (count <= NODE_48_CAPACITY) { + return NType::NODE_48; + } + return NType::NODE_256; +} + +//===--------------------------------------------------------------------===// +// Merging +//===--------------------------------------------------------------------===// + +void Node::InitializeMerge(ART &art, const ARTFlags &flags) { + + D_ASSERT(HasMetadata()); + + switch (GetType()) { + case NType::PREFIX: + // iterative + return Prefix::InitializeMerge(art, *this, flags); + case NType::LEAF: + // iterative + return Leaf::InitializeMerge(art, *this, flags); + case NType::NODE_4: + RefMutable(art, *this, NType::NODE_4).InitializeMerge(art, flags); + break; + case NType::NODE_16: + RefMutable(art, *this, NType::NODE_16).InitializeMerge(art, flags); + break; + case NType::NODE_48: + RefMutable(art, *this, NType::NODE_48).InitializeMerge(art, flags); + break; + case NType::NODE_256: + RefMutable(art, *this, NType::NODE_256).InitializeMerge(art, flags); + break; + case NType::LEAF_INLINED: + return; + } + + IncreaseBufferId(flags.merge_buffer_counts[static_cast(GetType()) - 1]); +} + +bool Node::Merge(ART &art, Node &other) { + + if (!HasMetadata()) { + *this = other; + other = Node(); + return true; + } + + return ResolvePrefixes(art, other); +} + +bool MergePrefixContainsOtherPrefix(ART &art, reference &l_node, reference &r_node, + idx_t &mismatch_position) { + + // r_node's prefix contains l_node's prefix + // l_node cannot be a leaf, otherwise the key represented by l_node would be a subset of another key + // which is not possible by our construction + D_ASSERT(l_node.get().GetType() != NType::LEAF && l_node.get().GetType() != NType::LEAF_INLINED); + + // test if the next byte (mismatch_position) in r_node (prefix) exists in l_node + auto mismatch_byte = Prefix::GetByte(art, r_node, mismatch_position); + auto child_node = l_node.get().GetChildMutable(art, mismatch_byte); + + // update the prefix of r_node to only consist of the bytes after mismatch_position + Prefix::Reduce(art, r_node, mismatch_position); + + if (!child_node) { + // insert r_node as a child of l_node at the empty position + Node::InsertChild(art, l_node, mismatch_byte, r_node); + r_node.get().Clear(); + return true; + } + + // recurse + return child_node->ResolvePrefixes(art, r_node); +} + +void MergePrefixesDiffer(ART &art, reference &l_node, reference &r_node, idx_t &mismatch_position) { + + // create a new node and insert both nodes as children + + Node l_child; + auto l_byte = Prefix::GetByte(art, l_node, mismatch_position); + Prefix::Split(art, l_node, l_child, mismatch_position); + Node4::New(art, l_node); + + // insert children + Node4::InsertChild(art, l_node, l_byte, l_child); + auto r_byte = Prefix::GetByte(art, r_node, mismatch_position); + Prefix::Reduce(art, r_node, mismatch_position); + Node4::InsertChild(art, l_node, r_byte, r_node); + + r_node.get().Clear(); +} + +bool Node::ResolvePrefixes(ART &art, Node &other) { + + // NOTE: we always merge into the left ART + + D_ASSERT(HasMetadata() && other.HasMetadata()); + + // case 1: both nodes have no prefix + if (GetType() != NType::PREFIX && other.GetType() != NType::PREFIX) { + return MergeInternal(art, other); + } + + reference l_node(*this); + reference r_node(other); + + idx_t mismatch_position = DConstants::INVALID_INDEX; + + // traverse prefixes + if (l_node.get().GetType() == NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { + + if (!Prefix::Traverse(art, l_node, r_node, mismatch_position)) { + return false; + } + // we already recurse because the prefixes matched (so far) + if (mismatch_position == DConstants::INVALID_INDEX) { + return true; + } + + } else { + + // l_prefix contains r_prefix + if (l_node.get().GetType() == NType::PREFIX) { + swap(*this, other); + } + mismatch_position = 0; + } + D_ASSERT(mismatch_position != DConstants::INVALID_INDEX); + + // case 2: one prefix contains the other prefix + if (l_node.get().GetType() != NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { + return MergePrefixContainsOtherPrefix(art, l_node, r_node, mismatch_position); + } + + // case 3: prefixes differ at a specific byte + MergePrefixesDiffer(art, l_node, r_node, mismatch_position); + return true; +} + +bool Node::MergeInternal(ART &art, Node &other) { + + D_ASSERT(HasMetadata() && other.HasMetadata()); + D_ASSERT(GetType() != NType::PREFIX && other.GetType() != NType::PREFIX); + + // always try to merge the smaller node into the bigger node + // because maybe there is enough free space in the bigger node to fit the smaller one + // without too much recursion + if (GetType() < other.GetType()) { + swap(*this, other); + } + + Node empty_node; + auto &l_node = *this; + auto &r_node = other; + + if (r_node.GetType() == NType::LEAF || r_node.GetType() == NType::LEAF_INLINED) { + D_ASSERT(l_node.GetType() == NType::LEAF || l_node.GetType() == NType::LEAF_INLINED); + + if (art.IsUnique()) { + return false; + } + + Leaf::Merge(art, l_node, r_node); + return true; + } + + uint8_t byte = 0; + auto r_child = r_node.GetNextChildMutable(art, byte); + + // while r_node still has children to merge + while (r_child) { + auto l_child = l_node.GetChildMutable(art, byte); + if (!l_child) { + // insert child at empty byte + InsertChild(art, l_node, byte, *r_child); + r_node.ReplaceChild(art, byte, empty_node); + + } else { + // recurse + if (!l_child->ResolvePrefixes(art, *r_child)) { + return false; + } + } + + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + r_child = r_node.GetNextChildMutable(art, byte); + } + + Free(art, r_node); + return true; +} + +//===--------------------------------------------------------------------===// +// Vacuum +//===--------------------------------------------------------------------===// + +void Node::Vacuum(ART &art, const ARTFlags &flags) { + + D_ASSERT(HasMetadata()); + + auto node_type = GetType(); + auto node_type_idx = static_cast(node_type); + + // iterative functions + if (node_type == NType::PREFIX) { + return Prefix::Vacuum(art, *this, flags); + } + if (node_type == NType::LEAF_INLINED) { + return; + } + if (node_type == NType::LEAF) { + if (flags.vacuum_flags[node_type_idx - 1]) { + Leaf::Vacuum(art, *this); + } + return; + } + + auto &allocator = GetAllocator(art, node_type); + auto needs_vacuum = flags.vacuum_flags[node_type_idx - 1] && allocator.NeedsVacuum(*this); + if (needs_vacuum) { + *this = allocator.VacuumPointer(*this); + SetMetadata(node_type_idx); + } + + // recursive functions + switch (node_type) { + case NType::NODE_4: + return RefMutable(art, *this, NType::NODE_4).Vacuum(art, flags); + case NType::NODE_16: + return RefMutable(art, *this, NType::NODE_16).Vacuum(art, flags); + case NType::NODE_48: + return RefMutable(art, *this, NType::NODE_48).Vacuum(art, flags); + case NType::NODE_256: + return RefMutable(art, *this, NType::NODE_256).Vacuum(art, flags); + default: + throw InternalException("Invalid node type for Vacuum."); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node16.cpp b/src/duckdb/src/execution/index/art/node16.cpp new file mode 100644 index 00000000..4541c124 --- /dev/null +++ b/src/duckdb/src/execution/index/art/node16.cpp @@ -0,0 +1,196 @@ +#include "duckdb/execution/index/art/node16.hpp" + +#include "duckdb/execution/index/art/node4.hpp" +#include "duckdb/execution/index/art/node48.hpp" + +namespace duckdb { + +Node16 &Node16::New(ART &art, Node &node) { + + node = Node::GetAllocator(art, NType::NODE_16).New(); + node.SetMetadata(static_cast(NType::NODE_16)); + auto &n16 = Node::RefMutable(art, node, NType::NODE_16); + + n16.count = 0; + return n16; +} + +void Node16::Free(ART &art, Node &node) { + + D_ASSERT(node.HasMetadata()); + auto &n16 = Node::RefMutable(art, node, NType::NODE_16); + + // free all children + for (idx_t i = 0; i < n16.count; i++) { + Node::Free(art, n16.children[i]); + } +} + +Node16 &Node16::GrowNode4(ART &art, Node &node16, Node &node4) { + + auto &n4 = Node::RefMutable(art, node4, NType::NODE_4); + auto &n16 = New(art, node16); + + n16.count = n4.count; + for (idx_t i = 0; i < n4.count; i++) { + n16.key[i] = n4.key[i]; + n16.children[i] = n4.children[i]; + } + + n4.count = 0; + Node::Free(art, node4); + return n16; +} + +Node16 &Node16::ShrinkNode48(ART &art, Node &node16, Node &node48) { + + auto &n16 = New(art, node16); + auto &n48 = Node::RefMutable(art, node48, NType::NODE_48); + + n16.count = 0; + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + D_ASSERT(n16.count <= Node::NODE_16_CAPACITY); + if (n48.child_index[i] != Node::EMPTY_MARKER) { + n16.key[n16.count] = i; + n16.children[n16.count] = n48.children[n48.child_index[i]]; + n16.count++; + } + } + + n48.count = 0; + Node::Free(art, node48); + return n16; +} + +void Node16::InitializeMerge(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < count; i++) { + children[i].InitializeMerge(art, flags); + } +} + +void Node16::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + + D_ASSERT(node.HasMetadata()); + auto &n16 = Node::RefMutable(art, node, NType::NODE_16); + + // ensure that there is no other child at the same byte + for (idx_t i = 0; i < n16.count; i++) { + D_ASSERT(n16.key[i] != byte); + } + + // insert new child node into node + if (n16.count < Node::NODE_16_CAPACITY) { + // still space, just insert the child + idx_t child_pos = 0; + while (child_pos < n16.count && n16.key[child_pos] < byte) { + child_pos++; + } + // move children backwards to make space + for (idx_t i = n16.count; i > child_pos; i--) { + n16.key[i] = n16.key[i - 1]; + n16.children[i] = n16.children[i - 1]; + } + + n16.key[child_pos] = byte; + n16.children[child_pos] = child; + n16.count++; + + } else { + // node is full, grow to Node48 + auto node16 = node; + Node48::GrowNode16(art, node, node16); + Node48::InsertChild(art, node, byte, child); + } +} + +void Node16::DeleteChild(ART &art, Node &node, const uint8_t byte) { + + D_ASSERT(node.HasMetadata()); + auto &n16 = Node::RefMutable(art, node, NType::NODE_16); + + idx_t child_pos = 0; + for (; child_pos < n16.count; child_pos++) { + if (n16.key[child_pos] == byte) { + break; + } + } + + D_ASSERT(child_pos < n16.count); + + // free the child and decrease the count + Node::Free(art, n16.children[child_pos]); + n16.count--; + + // potentially move any children backwards + for (idx_t i = child_pos; i < n16.count; i++) { + n16.key[i] = n16.key[i + 1]; + n16.children[i] = n16.children[i + 1]; + } + + // shrink node to Node4 + if (n16.count < Node::NODE_4_CAPACITY) { + auto node16 = node; + Node4::ShrinkNode16(art, node, node16); + } +} + +void Node16::ReplaceChild(const uint8_t byte, const Node child) { + for (idx_t i = 0; i < count; i++) { + if (key[i] == byte) { + children[i] = child; + return; + } + } +} + +optional_ptr Node16::GetChild(const uint8_t byte) const { + for (idx_t i = 0; i < count; i++) { + if (key[i] == byte) { + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +optional_ptr Node16::GetChildMutable(const uint8_t byte) { + for (idx_t i = 0; i < count; i++) { + if (key[i] == byte) { + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +optional_ptr Node16::GetNextChild(uint8_t &byte) const { + for (idx_t i = 0; i < count; i++) { + if (key[i] >= byte) { + byte = key[i]; + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +optional_ptr Node16::GetNextChildMutable(uint8_t &byte) { + for (idx_t i = 0; i < count; i++) { + if (key[i] >= byte) { + byte = key[i]; + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +void Node16::Vacuum(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < count; i++) { + children[i].Vacuum(art, flags); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node256.cpp b/src/duckdb/src/execution/index/art/node256.cpp new file mode 100644 index 00000000..b4d82fec --- /dev/null +++ b/src/duckdb/src/execution/index/art/node256.cpp @@ -0,0 +1,138 @@ +#include "duckdb/execution/index/art/node256.hpp" + +#include "duckdb/execution/index/art/node48.hpp" + +namespace duckdb { + +Node256 &Node256::New(ART &art, Node &node) { + + node = Node::GetAllocator(art, NType::NODE_256).New(); + node.SetMetadata(static_cast(NType::NODE_256)); + auto &n256 = Node::RefMutable(art, node, NType::NODE_256); + + n256.count = 0; + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + n256.children[i].Clear(); + } + + return n256; +} + +void Node256::Free(ART &art, Node &node) { + + D_ASSERT(node.HasMetadata()); + auto &n256 = Node::RefMutable(art, node, NType::NODE_256); + + if (!n256.count) { + return; + } + + // free all children + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + if (n256.children[i].HasMetadata()) { + Node::Free(art, n256.children[i]); + } + } +} + +Node256 &Node256::GrowNode48(ART &art, Node &node256, Node &node48) { + + auto &n48 = Node::RefMutable(art, node48, NType::NODE_48); + auto &n256 = New(art, node256); + + n256.count = n48.count; + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + if (n48.child_index[i] != Node::EMPTY_MARKER) { + n256.children[i] = n48.children[n48.child_index[i]]; + } else { + n256.children[i].Clear(); + } + } + + n48.count = 0; + Node::Free(art, node48); + return n256; +} + +void Node256::InitializeMerge(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + if (children[i].HasMetadata()) { + children[i].InitializeMerge(art, flags); + } + } +} + +void Node256::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + + D_ASSERT(node.HasMetadata()); + auto &n256 = Node::RefMutable(art, node, NType::NODE_256); + + // ensure that there is no other child at the same byte + D_ASSERT(!n256.children[byte].HasMetadata()); + + n256.count++; + D_ASSERT(n256.count <= Node::NODE_256_CAPACITY); + n256.children[byte] = child; +} + +void Node256::DeleteChild(ART &art, Node &node, const uint8_t byte) { + + D_ASSERT(node.HasMetadata()); + auto &n256 = Node::RefMutable(art, node, NType::NODE_256); + + // free the child and decrease the count + Node::Free(art, n256.children[byte]); + n256.count--; + + // shrink node to Node48 + if (n256.count <= Node::NODE_256_SHRINK_THRESHOLD) { + auto node256 = node; + Node48::ShrinkNode256(art, node, node256); + } +} + +optional_ptr Node256::GetChild(const uint8_t byte) const { + if (children[byte].HasMetadata()) { + return &children[byte]; + } + return nullptr; +} + +optional_ptr Node256::GetChildMutable(const uint8_t byte) { + if (children[byte].HasMetadata()) { + return &children[byte]; + } + return nullptr; +} + +optional_ptr Node256::GetNextChild(uint8_t &byte) const { + for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { + if (children[i].HasMetadata()) { + byte = i; + return &children[i]; + } + } + return nullptr; +} + +optional_ptr Node256::GetNextChildMutable(uint8_t &byte) { + for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { + if (children[i].HasMetadata()) { + byte = i; + return &children[i]; + } + } + return nullptr; +} + +void Node256::Vacuum(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + if (children[i].HasMetadata()) { + children[i].Vacuum(art, flags); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node4.cpp b/src/duckdb/src/execution/index/art/node4.cpp new file mode 100644 index 00000000..6438b744 --- /dev/null +++ b/src/duckdb/src/execution/index/art/node4.cpp @@ -0,0 +1,189 @@ +#include "duckdb/execution/index/art/node4.hpp" + +#include "duckdb/execution/index/art/prefix.hpp" +#include "duckdb/execution/index/art/node16.hpp" + +namespace duckdb { + +Node4 &Node4::New(ART &art, Node &node) { + + node = Node::GetAllocator(art, NType::NODE_4).New(); + node.SetMetadata(static_cast(NType::NODE_4)); + auto &n4 = Node::RefMutable(art, node, NType::NODE_4); + + n4.count = 0; + return n4; +} + +void Node4::Free(ART &art, Node &node) { + + D_ASSERT(node.HasMetadata()); + auto &n4 = Node::RefMutable(art, node, NType::NODE_4); + + // free all children + for (idx_t i = 0; i < n4.count; i++) { + Node::Free(art, n4.children[i]); + } +} + +Node4 &Node4::ShrinkNode16(ART &art, Node &node4, Node &node16) { + + auto &n4 = New(art, node4); + auto &n16 = Node::RefMutable(art, node16, NType::NODE_16); + + D_ASSERT(n16.count <= Node::NODE_4_CAPACITY); + n4.count = n16.count; + for (idx_t i = 0; i < n16.count; i++) { + n4.key[i] = n16.key[i]; + n4.children[i] = n16.children[i]; + } + + n16.count = 0; + Node::Free(art, node16); + return n4; +} + +void Node4::InitializeMerge(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < count; i++) { + children[i].InitializeMerge(art, flags); + } +} + +void Node4::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + + D_ASSERT(node.HasMetadata()); + auto &n4 = Node::RefMutable(art, node, NType::NODE_4); + + // ensure that there is no other child at the same byte + for (idx_t i = 0; i < n4.count; i++) { + D_ASSERT(n4.key[i] != byte); + } + + // insert new child node into node + if (n4.count < Node::NODE_4_CAPACITY) { + // still space, just insert the child + idx_t child_pos = 0; + while (child_pos < n4.count && n4.key[child_pos] < byte) { + child_pos++; + } + // move children backwards to make space + for (idx_t i = n4.count; i > child_pos; i--) { + n4.key[i] = n4.key[i - 1]; + n4.children[i] = n4.children[i - 1]; + } + + n4.key[child_pos] = byte; + n4.children[child_pos] = child; + n4.count++; + + } else { + // node is full, grow to Node16 + auto node4 = node; + Node16::GrowNode4(art, node, node4); + Node16::InsertChild(art, node, byte, child); + } +} + +void Node4::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte) { + + D_ASSERT(node.HasMetadata()); + auto &n4 = Node::RefMutable(art, node, NType::NODE_4); + + idx_t child_pos = 0; + for (; child_pos < n4.count; child_pos++) { + if (n4.key[child_pos] == byte) { + break; + } + } + + D_ASSERT(child_pos < n4.count); + D_ASSERT(n4.count > 1); + + // free the child and decrease the count + Node::Free(art, n4.children[child_pos]); + n4.count--; + + // potentially move any children backwards + for (idx_t i = child_pos; i < n4.count; i++) { + n4.key[i] = n4.key[i + 1]; + n4.children[i] = n4.children[i + 1]; + } + + // this is a one way node, compress + if (n4.count == 1) { + + // we need to keep track of the old node pointer + // because Concatenate() might overwrite that pointer while appending bytes to + // the prefix (and by doing so overwriting the subsequent node with + // new prefix nodes) + auto old_n4_node = node; + + // get only child and concatenate prefixes + auto child = *n4.GetChildMutable(n4.key[0]); + Prefix::Concatenate(art, prefix, n4.key[0], child); + + n4.count--; + Node::Free(art, old_n4_node); + } +} + +void Node4::ReplaceChild(const uint8_t byte, const Node child) { + for (idx_t i = 0; i < count; i++) { + if (key[i] == byte) { + children[i] = child; + return; + } + } +} + +optional_ptr Node4::GetChild(const uint8_t byte) const { + for (idx_t i = 0; i < count; i++) { + if (key[i] == byte) { + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +optional_ptr Node4::GetChildMutable(const uint8_t byte) { + for (idx_t i = 0; i < count; i++) { + if (key[i] == byte) { + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +optional_ptr Node4::GetNextChild(uint8_t &byte) const { + for (idx_t i = 0; i < count; i++) { + if (key[i] >= byte) { + byte = key[i]; + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +optional_ptr Node4::GetNextChildMutable(uint8_t &byte) { + for (idx_t i = 0; i < count; i++) { + if (key[i] >= byte) { + byte = key[i]; + D_ASSERT(children[i].HasMetadata()); + return &children[i]; + } + } + return nullptr; +} + +void Node4::Vacuum(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < count; i++) { + children[i].Vacuum(art, flags); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node48.cpp b/src/duckdb/src/execution/index/art/node48.cpp new file mode 100644 index 00000000..8ba21122 --- /dev/null +++ b/src/duckdb/src/execution/index/art/node48.cpp @@ -0,0 +1,198 @@ +#include "duckdb/execution/index/art/node48.hpp" + +#include "duckdb/execution/index/art/node16.hpp" +#include "duckdb/execution/index/art/node256.hpp" + +namespace duckdb { + +Node48 &Node48::New(ART &art, Node &node) { + + node = Node::GetAllocator(art, NType::NODE_48).New(); + node.SetMetadata(static_cast(NType::NODE_48)); + auto &n48 = Node::RefMutable(art, node, NType::NODE_48); + + n48.count = 0; + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + n48.child_index[i] = Node::EMPTY_MARKER; + } + for (idx_t i = 0; i < Node::NODE_48_CAPACITY; i++) { + n48.children[i].Clear(); + } + + return n48; +} + +void Node48::Free(ART &art, Node &node) { + + D_ASSERT(node.HasMetadata()); + auto &n48 = Node::RefMutable(art, node, NType::NODE_48); + + if (!n48.count) { + return; + } + + // free all children + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + if (n48.child_index[i] != Node::EMPTY_MARKER) { + Node::Free(art, n48.children[n48.child_index[i]]); + } + } +} + +Node48 &Node48::GrowNode16(ART &art, Node &node48, Node &node16) { + + auto &n16 = Node::RefMutable(art, node16, NType::NODE_16); + auto &n48 = New(art, node48); + + n48.count = n16.count; + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + n48.child_index[i] = Node::EMPTY_MARKER; + } + + for (idx_t i = 0; i < n16.count; i++) { + n48.child_index[n16.key[i]] = i; + n48.children[i] = n16.children[i]; + } + + // necessary for faster child insertion/deletion + for (idx_t i = n16.count; i < Node::NODE_48_CAPACITY; i++) { + n48.children[i].Clear(); + } + + n16.count = 0; + Node::Free(art, node16); + return n48; +} + +Node48 &Node48::ShrinkNode256(ART &art, Node &node48, Node &node256) { + + auto &n48 = New(art, node48); + auto &n256 = Node::RefMutable(art, node256, NType::NODE_256); + + n48.count = 0; + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + D_ASSERT(n48.count <= Node::NODE_48_CAPACITY); + if (n256.children[i].HasMetadata()) { + n48.child_index[i] = n48.count; + n48.children[n48.count] = n256.children[i]; + n48.count++; + } else { + n48.child_index[i] = Node::EMPTY_MARKER; + } + } + + // necessary for faster child insertion/deletion + for (idx_t i = n48.count; i < Node::NODE_48_CAPACITY; i++) { + n48.children[i].Clear(); + } + + n256.count = 0; + Node::Free(art, node256); + return n48; +} + +void Node48::InitializeMerge(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + if (child_index[i] != Node::EMPTY_MARKER) { + children[child_index[i]].InitializeMerge(art, flags); + } + } +} + +void Node48::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + + D_ASSERT(node.HasMetadata()); + auto &n48 = Node::RefMutable(art, node, NType::NODE_48); + + // ensure that there is no other child at the same byte + D_ASSERT(n48.child_index[byte] == Node::EMPTY_MARKER); + + // insert new child node into node + if (n48.count < Node::NODE_48_CAPACITY) { + // still space, just insert the child + idx_t child_pos = n48.count; + if (n48.children[child_pos].HasMetadata()) { + // find an empty position in the node list if the current position is occupied + child_pos = 0; + while (n48.children[child_pos].HasMetadata()) { + child_pos++; + } + } + n48.children[child_pos] = child; + n48.child_index[byte] = child_pos; + n48.count++; + + } else { + // node is full, grow to Node256 + auto node48 = node; + Node256::GrowNode48(art, node, node48); + Node256::InsertChild(art, node, byte, child); + } +} + +void Node48::DeleteChild(ART &art, Node &node, const uint8_t byte) { + + D_ASSERT(node.HasMetadata()); + auto &n48 = Node::RefMutable(art, node, NType::NODE_48); + + // free the child and decrease the count + Node::Free(art, n48.children[n48.child_index[byte]]); + n48.child_index[byte] = Node::EMPTY_MARKER; + n48.count--; + + // shrink node to Node16 + if (n48.count < Node::NODE_48_SHRINK_THRESHOLD) { + auto node48 = node; + Node16::ShrinkNode48(art, node, node48); + } +} + +optional_ptr Node48::GetChild(const uint8_t byte) const { + if (child_index[byte] != Node::EMPTY_MARKER) { + D_ASSERT(children[child_index[byte]].HasMetadata()); + return &children[child_index[byte]]; + } + return nullptr; +} + +optional_ptr Node48::GetChildMutable(const uint8_t byte) { + if (child_index[byte] != Node::EMPTY_MARKER) { + D_ASSERT(children[child_index[byte]].HasMetadata()); + return &children[child_index[byte]]; + } + return nullptr; +} + +optional_ptr Node48::GetNextChild(uint8_t &byte) const { + for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { + if (child_index[i] != Node::EMPTY_MARKER) { + byte = i; + D_ASSERT(children[child_index[i]].HasMetadata()); + return &children[child_index[i]]; + } + } + return nullptr; +} + +optional_ptr Node48::GetNextChildMutable(uint8_t &byte) { + for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { + if (child_index[i] != Node::EMPTY_MARKER) { + byte = i; + D_ASSERT(children[child_index[i]].HasMetadata()); + return &children[child_index[i]]; + } + } + return nullptr; +} + +void Node48::Vacuum(ART &art, const ARTFlags &flags) { + + for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + if (child_index[i] != Node::EMPTY_MARKER) { + children[child_index[i]].Vacuum(art, flags); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/prefix.cpp b/src/duckdb/src/execution/index/art/prefix.cpp new file mode 100644 index 00000000..a26ec85f --- /dev/null +++ b/src/duckdb/src/execution/index/art/prefix.cpp @@ -0,0 +1,370 @@ +#include "duckdb/execution/index/art/prefix.hpp" + +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/node.hpp" +#include "duckdb/common/swap.hpp" + +namespace duckdb { + +Prefix &Prefix::New(ART &art, Node &node) { + + node = Node::GetAllocator(art, NType::PREFIX).New(); + node.SetMetadata(static_cast(NType::PREFIX)); + + auto &prefix = Node::RefMutable(art, node, NType::PREFIX); + prefix.data[Node::PREFIX_SIZE] = 0; + return prefix; +} + +Prefix &Prefix::New(ART &art, Node &node, uint8_t byte, const Node &next) { + + node = Node::GetAllocator(art, NType::PREFIX).New(); + node.SetMetadata(static_cast(NType::PREFIX)); + + auto &prefix = Node::RefMutable(art, node, NType::PREFIX); + prefix.data[Node::PREFIX_SIZE] = 1; + prefix.data[0] = byte; + prefix.ptr = next; + return prefix; +} + +void Prefix::New(ART &art, reference &node, const ARTKey &key, const uint32_t depth, uint32_t count) { + + if (count == 0) { + return; + } + idx_t copy_count = 0; + + while (count) { + node.get() = Node::GetAllocator(art, NType::PREFIX).New(); + node.get().SetMetadata(static_cast(NType::PREFIX)); + auto &prefix = Node::RefMutable(art, node, NType::PREFIX); + + auto this_count = MinValue((uint32_t)Node::PREFIX_SIZE, count); + prefix.data[Node::PREFIX_SIZE] = (uint8_t)this_count; + memcpy(prefix.data, key.data + depth + copy_count, this_count); + + node = prefix.ptr; + copy_count += this_count; + count -= this_count; + } +} + +void Prefix::Free(ART &art, Node &node) { + + Node current_node = node; + Node next_node; + while (current_node.HasMetadata() && current_node.GetType() == NType::PREFIX) { + next_node = Node::RefMutable(art, current_node, NType::PREFIX).ptr; + Node::GetAllocator(art, NType::PREFIX).Free(current_node); + current_node = next_node; + } + + Node::Free(art, current_node); + node.Clear(); +} + +void Prefix::InitializeMerge(ART &art, Node &node, const ARTFlags &flags) { + + auto merge_buffer_count = flags.merge_buffer_counts[static_cast(NType::PREFIX) - 1]; + + Node next_node = node; + reference prefix = Node::RefMutable(art, next_node, NType::PREFIX); + + while (next_node.GetType() == NType::PREFIX) { + next_node = prefix.get().ptr; + if (prefix.get().ptr.GetType() == NType::PREFIX) { + prefix.get().ptr.IncreaseBufferId(merge_buffer_count); + prefix = Node::RefMutable(art, next_node, NType::PREFIX); + } + } + + node.IncreaseBufferId(merge_buffer_count); + prefix.get().ptr.InitializeMerge(art, flags); +} + +void Prefix::Concatenate(ART &art, Node &prefix_node, const uint8_t byte, Node &child_prefix_node) { + + D_ASSERT(prefix_node.HasMetadata() && child_prefix_node.HasMetadata()); + + // append a byte and a child_prefix to prefix + if (prefix_node.GetType() == NType::PREFIX) { + + // get the tail + reference prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); + D_ASSERT(prefix.get().ptr.HasMetadata()); + + while (prefix.get().ptr.GetType() == NType::PREFIX) { + prefix = Node::RefMutable(art, prefix.get().ptr, NType::PREFIX); + D_ASSERT(prefix.get().ptr.HasMetadata()); + } + + // append the byte + prefix = prefix.get().Append(art, byte); + + if (child_prefix_node.GetType() == NType::PREFIX) { + // append the child prefix + prefix.get().Append(art, child_prefix_node); + } else { + // set child_prefix_node to succeed prefix + prefix.get().ptr = child_prefix_node; + } + return; + } + + // create a new prefix node containing the byte, then append the child_prefix to it + if (prefix_node.GetType() != NType::PREFIX && child_prefix_node.GetType() == NType::PREFIX) { + + auto child_prefix = child_prefix_node; + auto &prefix = New(art, prefix_node, byte); + prefix.Append(art, child_prefix); + return; + } + + // neither prefix nor child_prefix are prefix nodes + // create a new prefix containing the byte + New(art, prefix_node, byte, child_prefix_node); +} + +idx_t Prefix::Traverse(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth) { + + D_ASSERT(prefix_node.get().HasMetadata()); + D_ASSERT(prefix_node.get().GetType() == NType::PREFIX); + + // compare prefix nodes to key bytes + while (prefix_node.get().GetType() == NType::PREFIX) { + auto &prefix = Node::Ref(art, prefix_node, NType::PREFIX); + for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + if (prefix.data[i] != key[depth]) { + return i; + } + depth++; + } + prefix_node = prefix.ptr; + D_ASSERT(prefix_node.get().HasMetadata()); + } + + return DConstants::INVALID_INDEX; +} + +idx_t Prefix::TraverseMutable(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth) { + + D_ASSERT(prefix_node.get().HasMetadata()); + D_ASSERT(prefix_node.get().GetType() == NType::PREFIX); + + // compare prefix nodes to key bytes + while (prefix_node.get().GetType() == NType::PREFIX) { + auto &prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); + for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + if (prefix.data[i] != key[depth]) { + return i; + } + depth++; + } + prefix_node = prefix.ptr; + D_ASSERT(prefix_node.get().HasMetadata()); + } + + return DConstants::INVALID_INDEX; +} + +bool Prefix::Traverse(ART &art, reference &l_node, reference &r_node, idx_t &mismatch_position) { + + auto &l_prefix = Node::RefMutable(art, l_node.get(), NType::PREFIX); + auto &r_prefix = Node::RefMutable(art, r_node.get(), NType::PREFIX); + + // compare prefix bytes + idx_t max_count = MinValue(l_prefix.data[Node::PREFIX_SIZE], r_prefix.data[Node::PREFIX_SIZE]); + for (idx_t i = 0; i < max_count; i++) { + if (l_prefix.data[i] != r_prefix.data[i]) { + mismatch_position = i; + break; + } + } + + if (mismatch_position == DConstants::INVALID_INDEX) { + + // prefixes match (so far) + if (l_prefix.data[Node::PREFIX_SIZE] == r_prefix.data[Node::PREFIX_SIZE]) { + return l_prefix.ptr.ResolvePrefixes(art, r_prefix.ptr); + } + + mismatch_position = max_count; + + // l_prefix contains r_prefix + if (r_prefix.ptr.GetType() != NType::PREFIX && r_prefix.data[Node::PREFIX_SIZE] == max_count) { + swap(l_node.get(), r_node.get()); + l_node = r_prefix.ptr; + + } else { + // r_prefix contains l_prefix + l_node = l_prefix.ptr; + } + } + + return true; +} + +void Prefix::Reduce(ART &art, Node &prefix_node, const idx_t n) { + + D_ASSERT(prefix_node.HasMetadata()); + D_ASSERT(n < Node::PREFIX_SIZE); + + reference prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); + + // free this prefix node + if (n == (idx_t)(prefix.get().data[Node::PREFIX_SIZE] - 1)) { + auto next_ptr = prefix.get().ptr; + D_ASSERT(next_ptr.HasMetadata()); + prefix.get().ptr.Clear(); + Node::Free(art, prefix_node); + prefix_node = next_ptr; + return; + } + + // shift by n bytes in the current prefix + for (idx_t i = 0; i < Node::PREFIX_SIZE - n - 1; i++) { + prefix.get().data[i] = prefix.get().data[n + i + 1]; + } + D_ASSERT(n < (idx_t)(prefix.get().data[Node::PREFIX_SIZE] - 1)); + prefix.get().data[Node::PREFIX_SIZE] -= n + 1; + + // append the remaining prefix bytes + prefix.get().Append(art, prefix.get().ptr); +} + +void Prefix::Split(ART &art, reference &prefix_node, Node &child_node, idx_t position) { + + D_ASSERT(prefix_node.get().HasMetadata()); + + auto &prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); + + // the split is at the last byte of this prefix, so the child_node contains all subsequent + // prefix nodes (prefix.ptr) (if any), and the count of this prefix decreases by one, + // then, we reference prefix.ptr, to overwrite it with a new node later + if (position + 1 == Node::PREFIX_SIZE) { + prefix.data[Node::PREFIX_SIZE]--; + prefix_node = prefix.ptr; + child_node = prefix.ptr; + return; + } + + // append the remaining bytes after the split + if (position + 1 < prefix.data[Node::PREFIX_SIZE]) { + reference child_prefix = New(art, child_node); + for (idx_t i = position + 1; i < prefix.data[Node::PREFIX_SIZE]; i++) { + child_prefix = child_prefix.get().Append(art, prefix.data[i]); + } + + D_ASSERT(prefix.ptr.HasMetadata()); + + if (prefix.ptr.GetType() == NType::PREFIX) { + child_prefix.get().Append(art, prefix.ptr); + } else { + // this is the last prefix node of the prefix + child_prefix.get().ptr = prefix.ptr; + } + } + + // this is the last prefix node of the prefix + if (position + 1 == prefix.data[Node::PREFIX_SIZE]) { + child_node = prefix.ptr; + } + + // set the new size of this node + prefix.data[Node::PREFIX_SIZE] = position; + + // no bytes left before the split, free this node + if (position == 0) { + prefix.ptr.Clear(); + Node::Free(art, prefix_node.get()); + return; + } + + // bytes left before the split, reference subsequent node + prefix_node = prefix.ptr; + return; +} + +string Prefix::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { + + // NOTE: we could do this recursively, but the function-call overhead can become kinda crazy + string str = ""; + + reference node_ref(node); + while (node_ref.get().GetType() == NType::PREFIX) { + + auto &prefix = Node::Ref(art, node_ref, NType::PREFIX); + D_ASSERT(prefix.data[Node::PREFIX_SIZE] != 0); + D_ASSERT(prefix.data[Node::PREFIX_SIZE] <= Node::PREFIX_SIZE); + + str += " prefix_bytes:["; + for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + str += to_string(prefix.data[i]) + "-"; + } + str += "] "; + + node_ref = prefix.ptr; + } + + auto subtree = node_ref.get().VerifyAndToString(art, only_verify); + return only_verify ? "" : str + subtree; +} + +void Prefix::Vacuum(ART &art, Node &node, const ARTFlags &flags) { + + bool flag_set = flags.vacuum_flags[static_cast(NType::PREFIX) - 1]; + auto &allocator = Node::GetAllocator(art, NType::PREFIX); + + reference node_ref(node); + while (node_ref.get().GetType() == NType::PREFIX) { + if (flag_set && allocator.NeedsVacuum(node_ref)) { + node_ref.get() = allocator.VacuumPointer(node_ref); + node_ref.get().SetMetadata(static_cast(NType::PREFIX)); + } + auto &prefix = Node::RefMutable(art, node_ref, NType::PREFIX); + node_ref = prefix.ptr; + } + + node_ref.get().Vacuum(art, flags); +} + +Prefix &Prefix::Append(ART &art, const uint8_t byte) { + + reference prefix(*this); + + // we need a new prefix node + if (prefix.get().data[Node::PREFIX_SIZE] == Node::PREFIX_SIZE) { + prefix = New(art, prefix.get().ptr); + } + + prefix.get().data[prefix.get().data[Node::PREFIX_SIZE]] = byte; + prefix.get().data[Node::PREFIX_SIZE]++; + return prefix.get(); +} + +void Prefix::Append(ART &art, Node other_prefix) { + + D_ASSERT(other_prefix.HasMetadata()); + + reference prefix(*this); + while (other_prefix.GetType() == NType::PREFIX) { + + // copy prefix bytes + auto &other = Node::RefMutable(art, other_prefix, NType::PREFIX); + for (idx_t i = 0; i < other.data[Node::PREFIX_SIZE]; i++) { + prefix = prefix.get().Append(art, other.data[i]); + } + + D_ASSERT(other.ptr.HasMetadata()); + + prefix.get().ptr = other.ptr; + Node::GetAllocator(art, NType::PREFIX).Free(other_prefix); + other_prefix = prefix.get().ptr; + } + + D_ASSERT(prefix.get().ptr.GetType() != NType::PREFIX); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/fixed_size_allocator.cpp b/src/duckdb/src/execution/index/fixed_size_allocator.cpp new file mode 100644 index 00000000..daa9dc15 --- /dev/null +++ b/src/duckdb/src/execution/index/fixed_size_allocator.cpp @@ -0,0 +1,323 @@ +#include "duckdb/execution/index/fixed_size_allocator.hpp" + +#include "duckdb/storage/metadata/metadata_reader.hpp" + +namespace duckdb { + +FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager) + : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), + metadata_manager(block_manager.GetMetadataManager()), segment_size(segment_size), total_segment_count(0) { + + if (segment_size > Storage::BLOCK_SIZE - sizeof(validity_t)) { + throw InternalException("The maximum segment size of fixed-size allocators is " + + to_string(Storage::BLOCK_SIZE - sizeof(validity_t))); + } + + // calculate how many segments fit into one buffer (available_segments_per_buffer) + + idx_t bits_per_value = sizeof(validity_t) * 8; + idx_t byte_count = 0; + + bitmask_count = 0; + available_segments_per_buffer = 0; + + while (byte_count < Storage::BLOCK_SIZE) { + if (!bitmask_count || (bitmask_count * bits_per_value) % available_segments_per_buffer == 0) { + // we need to add another validity_t value to the bitmask, to allow storing another + // bits_per_value segments on a buffer + bitmask_count++; + byte_count += sizeof(validity_t); + } + + auto remaining_bytes = Storage::BLOCK_SIZE - byte_count; + auto remaining_segments = MinValue(remaining_bytes / segment_size, bits_per_value); + + if (remaining_segments == 0) { + break; + } + + available_segments_per_buffer += remaining_segments; + byte_count += remaining_segments * segment_size; + } + + bitmask_offset = bitmask_count * sizeof(validity_t); +} + +IndexPointer FixedSizeAllocator::New() { + + // no more segments available + if (buffers_with_free_space.empty()) { + + // add a new buffer + auto buffer_id = GetAvailableBufferId(); + FixedSizeBuffer new_buffer(block_manager); + buffers.insert(make_pair(buffer_id, std::move(new_buffer))); + buffers_with_free_space.insert(buffer_id); + + // set the bitmask + D_ASSERT(buffers.find(buffer_id) != buffers.end()); + auto &buffer = buffers.find(buffer_id)->second; + ValidityMask mask(reinterpret_cast(buffer.Get())); + + // zero-initialize the bitmask to avoid leaking memory to disk + auto data = mask.GetData(); + for (idx_t i = 0; i < bitmask_count; i++) { + data[i] = 0; + } + + // initializing the bitmask of the new buffer + mask.SetAllValid(available_segments_per_buffer); + } + + // return a pointer to a free segment + D_ASSERT(!buffers_with_free_space.empty()); + auto buffer_id = uint32_t(*buffers_with_free_space.begin()); + + D_ASSERT(buffers.find(buffer_id) != buffers.end()); + auto &buffer = buffers.find(buffer_id)->second; + auto offset = buffer.GetOffset(bitmask_count); + + total_segment_count++; + buffer.segment_count++; + if (buffer.segment_count == available_segments_per_buffer) { + buffers_with_free_space.erase(buffer_id); + } + + // zero-initialize that segment + auto buffer_ptr = buffer.Get(); + auto offset_in_buffer = buffer_ptr + offset * segment_size + bitmask_offset; + memset(offset_in_buffer, 0, segment_size); + + return IndexPointer(buffer_id, offset); +} + +void FixedSizeAllocator::Free(const IndexPointer ptr) { + + auto buffer_id = ptr.GetBufferId(); + auto offset = ptr.GetOffset(); + + D_ASSERT(buffers.find(buffer_id) != buffers.end()); + auto &buffer = buffers.find(buffer_id)->second; + + auto bitmask_ptr = reinterpret_cast(buffer.Get()); + ValidityMask mask(bitmask_ptr); + D_ASSERT(!mask.RowIsValid(offset)); + mask.SetValid(offset); + + D_ASSERT(total_segment_count > 0); + D_ASSERT(buffer.segment_count > 0); + + // adjust the allocator fields + buffers_with_free_space.insert(buffer_id); + total_segment_count--; + buffer.segment_count--; +} + +void FixedSizeAllocator::Reset() { + for (auto &buffer : buffers) { + buffer.second.Destroy(); + } + buffers.clear(); + buffers_with_free_space.clear(); + total_segment_count = 0; +} + +idx_t FixedSizeAllocator::GetMemoryUsage() const { + idx_t memory_usage = 0; + for (auto &buffer : buffers) { + if (buffer.second.InMemory()) { + memory_usage += Storage::BLOCK_SIZE; + } + } + return memory_usage; +} + +idx_t FixedSizeAllocator::GetUpperBoundBufferId() const { + idx_t upper_bound_id = 0; + for (auto &buffer : buffers) { + if (buffer.first >= upper_bound_id) { + upper_bound_id = buffer.first + 1; + } + } + return upper_bound_id; +} + +void FixedSizeAllocator::Merge(FixedSizeAllocator &other) { + + D_ASSERT(segment_size == other.segment_size); + + // remember the buffer count and merge the buffers + idx_t upper_bound_id = GetUpperBoundBufferId(); + for (auto &buffer : other.buffers) { + buffers.insert(make_pair(buffer.first + upper_bound_id, std::move(buffer.second))); + } + other.buffers.clear(); + + // merge the buffers with free spaces + for (auto &buffer_id : other.buffers_with_free_space) { + buffers_with_free_space.insert(buffer_id + upper_bound_id); + } + other.buffers_with_free_space.clear(); + + // add the total allocations + total_segment_count += other.total_segment_count; +} + +bool FixedSizeAllocator::InitializeVacuum() { + + // NOTE: we do not vacuum buffers that are not in memory. We might consider changing this + // in the future, although buffers on disk should almost never be eligible for a vacuum + + if (total_segment_count == 0) { + Reset(); + return false; + } + + // remove all empty buffers + auto buffer_it = buffers.begin(); + while (buffer_it != buffers.end()) { + if (!buffer_it->second.segment_count) { + buffers_with_free_space.erase(buffer_it->first); + buffer_it->second.Destroy(); + buffer_it = buffers.erase(buffer_it); + } else { + buffer_it++; + } + } + + // determine if a vacuum is necessary + multimap temporary_vacuum_buffers; + D_ASSERT(vacuum_buffers.empty()); + idx_t available_segments_in_memory = 0; + + for (auto &buffer : buffers) { + buffer.second.vacuum = false; + if (buffer.second.InMemory()) { + auto available_segments_in_buffer = available_segments_per_buffer - buffer.second.segment_count; + available_segments_in_memory += available_segments_in_buffer; + temporary_vacuum_buffers.emplace(available_segments_in_buffer, buffer.first); + } + } + + // no buffers in memory + if (temporary_vacuum_buffers.empty()) { + return false; + } + + auto excess_buffer_count = available_segments_in_memory / available_segments_per_buffer; + + // calculate the vacuum threshold adaptively + D_ASSERT(excess_buffer_count < temporary_vacuum_buffers.size()); + idx_t memory_usage = GetMemoryUsage(); + idx_t excess_memory_usage = excess_buffer_count * Storage::BLOCK_SIZE; + auto excess_percentage = double(excess_memory_usage) / double(memory_usage); + auto threshold = double(VACUUM_THRESHOLD) / 100.0; + if (excess_percentage < threshold) { + return false; + } + + D_ASSERT(excess_buffer_count <= temporary_vacuum_buffers.size()); + D_ASSERT(temporary_vacuum_buffers.size() <= buffers.size()); + + // erasing from a multimap, we vacuum the buffers with the most free spaces (least full) + while (temporary_vacuum_buffers.size() != excess_buffer_count) { + temporary_vacuum_buffers.erase(temporary_vacuum_buffers.begin()); + } + + // adjust the buffers, and erase all to-be-vacuumed buffers from the available buffer list + for (auto &vacuum_buffer : temporary_vacuum_buffers) { + auto buffer_id = vacuum_buffer.second; + D_ASSERT(buffers.find(buffer_id) != buffers.end()); + buffers.find(buffer_id)->second.vacuum = true; + buffers_with_free_space.erase(buffer_id); + } + + for (auto &vacuum_buffer : temporary_vacuum_buffers) { + vacuum_buffers.insert(vacuum_buffer.second); + } + + return true; +} + +void FixedSizeAllocator::FinalizeVacuum() { + + for (auto &buffer_id : vacuum_buffers) { + D_ASSERT(buffers.find(buffer_id) != buffers.end()); + auto &buffer = buffers.find(buffer_id)->second; + D_ASSERT(buffer.InMemory()); + buffer.Destroy(); + buffers.erase(buffer_id); + } + vacuum_buffers.clear(); +} + +IndexPointer FixedSizeAllocator::VacuumPointer(const IndexPointer ptr) { + + // we do not need to adjust the bitmask of the old buffer, because we will free the entire + // buffer after the vacuum operation + + auto new_ptr = New(); + // new increases the allocation count, we need to counter that here + total_segment_count--; + + memcpy(Get(new_ptr), Get(ptr), segment_size); + return new_ptr; +} + +BlockPointer FixedSizeAllocator::Serialize(PartialBlockManager &partial_block_manager, MetadataWriter &writer) { + + for (auto &buffer : buffers) { + buffer.second.Serialize(partial_block_manager, available_segments_per_buffer, segment_size, bitmask_offset); + } + + auto block_pointer = writer.GetBlockPointer(); + writer.Write(segment_size); + writer.Write(static_cast(buffers.size())); + writer.Write(static_cast(buffers_with_free_space.size())); + + for (auto &buffer : buffers) { + writer.Write(buffer.first); + writer.Write(buffer.second.block_pointer); + writer.Write(buffer.second.segment_count); + writer.Write(buffer.second.allocation_size); + } + for (auto &buffer_id : buffers_with_free_space) { + writer.Write(buffer_id); + } + + return block_pointer; +} + +void FixedSizeAllocator::Deserialize(const BlockPointer &block_pointer) { + + MetadataReader reader(metadata_manager, block_pointer); + segment_size = reader.Read(); + auto buffer_count = reader.Read(); + auto buffers_with_free_space_count = reader.Read(); + + total_segment_count = 0; + + for (idx_t i = 0; i < buffer_count; i++) { + auto buffer_id = reader.Read(); + auto buffer_block_pointer = reader.Read(); + auto segment_count = reader.Read(); + auto allocation_size = reader.Read(); + FixedSizeBuffer new_buffer(block_manager, segment_count, allocation_size, buffer_block_pointer); + buffers.insert(make_pair(buffer_id, std::move(new_buffer))); + total_segment_count += segment_count; + } + for (idx_t i = 0; i < buffers_with_free_space_count; i++) { + buffers_with_free_space.insert(reader.Read()); + } +} + +idx_t FixedSizeAllocator::GetAvailableBufferId() const { + idx_t buffer_id = buffers.size(); + while (buffers.find(buffer_id) != buffers.end()) { + D_ASSERT(buffer_id > 0); + buffer_id--; + } + return buffer_id; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/fixed_size_buffer.cpp b/src/duckdb/src/execution/index/fixed_size_buffer.cpp new file mode 100644 index 00000000..0617f93d --- /dev/null +++ b/src/duckdb/src/execution/index/fixed_size_buffer.cpp @@ -0,0 +1,285 @@ +#include "duckdb/execution/index/fixed_size_buffer.hpp" + +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// PartialBlockForIndex +//===--------------------------------------------------------------------===// + +PartialBlockForIndex::PartialBlockForIndex(PartialBlockState state, BlockManager &block_manager, + const shared_ptr &block_handle) + : PartialBlock(state, block_manager, block_handle) { +} + +void PartialBlockForIndex::Flush(const idx_t free_space_left) { + FlushInternal(free_space_left); + block_handle = block_manager.ConvertToPersistent(state.block_id, std::move(block_handle)); + Clear(); +} + +void PartialBlockForIndex::Merge(PartialBlock &other, idx_t offset, idx_t other_size) { + throw InternalException("no merge for PartialBlockForIndex"); +} + +void PartialBlockForIndex::Clear() { + block_handle.reset(); +} + +//===--------------------------------------------------------------------===// +// FixedSizeBuffer +//===--------------------------------------------------------------------===// + +constexpr idx_t FixedSizeBuffer::BASE[]; +constexpr uint8_t FixedSizeBuffer::SHIFT[]; + +FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager) + : block_manager(block_manager), segment_count(0), allocation_size(0), dirty(false), vacuum(false), block_pointer(), + block_handle(nullptr) { + + auto &buffer_manager = block_manager.buffer_manager; + buffer_handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &block_handle); +} + +FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, + const BlockPointer &block_pointer) + : block_manager(block_manager), segment_count(segment_count), allocation_size(allocation_size), dirty(false), + vacuum(false), block_pointer(block_pointer) { + + D_ASSERT(block_pointer.IsValid()); + block_handle = block_manager.RegisterBlock(block_pointer.block_id); + D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); +} + +void FixedSizeBuffer::Destroy() { + if (InMemory()) { + // we can have multiple readers on a pinned block, and unpinning the buffer handle + // decrements the reader count on the underlying block handle (Destroy() unpins) + buffer_handle.Destroy(); + } + if (OnDisk()) { + // marking a block as modified decreases the reference count of multi-use blocks + block_manager.MarkBlockAsModified(block_pointer.block_id); + } +} + +void FixedSizeBuffer::Serialize(PartialBlockManager &partial_block_manager, const idx_t available_segments, + const idx_t segment_size, const idx_t bitmask_offset) { + + // we do not serialize a block that is already on disk and not in memory + if (!InMemory()) { + if (!OnDisk() || dirty) { + throw InternalException("invalid or missing buffer in FixedSizeAllocator"); + } + return; + } + + // we do not serialize a block that is already on disk and not dirty + if (!dirty && OnDisk()) { + return; + } + + if (dirty) { + // the allocation possibly changed + auto max_offset = GetMaxOffset(available_segments); + allocation_size = max_offset * segment_size + bitmask_offset; + } + + // the buffer is in memory, so we copied it onto a new buffer when pinning + D_ASSERT(InMemory() && !OnDisk()); + + // now we write the changes, first get a partial block allocation + PartialBlockAllocation allocation = partial_block_manager.GetBlockAllocation(allocation_size); + block_pointer.block_id = allocation.state.block_id; + block_pointer.offset = allocation.state.offset; + + auto &buffer_manager = block_manager.buffer_manager; + + if (allocation.partial_block) { + // copy to an existing partial block + D_ASSERT(block_pointer.offset > 0); + auto &p_block_for_index = allocation.partial_block->Cast(); + auto dst_handle = buffer_manager.Pin(p_block_for_index.block_handle); + memcpy(dst_handle.Ptr() + block_pointer.offset, buffer_handle.Ptr(), allocation_size); + SetUninitializedRegions(p_block_for_index, segment_size, block_pointer.offset, bitmask_offset); + + } else { + // create a new block that can potentially be used as a partial block + D_ASSERT(block_handle); + D_ASSERT(!block_pointer.offset); + auto p_block_for_index = make_uniq(allocation.state, block_manager, block_handle); + SetUninitializedRegions(*p_block_for_index, segment_size, block_pointer.offset, bitmask_offset); + allocation.partial_block = std::move(p_block_for_index); + } + + partial_block_manager.RegisterPartialBlock(std::move(allocation)); + + // resetting this buffer + buffer_handle.Destroy(); + block_handle = block_manager.RegisterBlock(block_pointer.block_id); + D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); + + // we persist any changes, so the buffer is no longer dirty + dirty = false; +} + +void FixedSizeBuffer::Pin() { + + auto &buffer_manager = block_manager.buffer_manager; + D_ASSERT(block_pointer.IsValid()); + D_ASSERT(block_handle && block_handle->BlockId() < MAXIMUM_BLOCK); + D_ASSERT(!dirty); + + buffer_handle = buffer_manager.Pin(block_handle); + + // we need to copy the (partial) data into a new (not yet disk-backed) buffer handle + shared_ptr new_block_handle; + auto new_buffer_handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block_handle); + + memcpy(new_buffer_handle.Ptr(), buffer_handle.Ptr() + block_pointer.offset, allocation_size); + + Destroy(); + buffer_handle = std::move(new_buffer_handle); + block_handle = new_block_handle; + block_pointer = BlockPointer(); +} + +uint32_t FixedSizeBuffer::GetOffset(const idx_t bitmask_count) { + + // get the bitmask data + auto bitmask_ptr = reinterpret_cast(Get()); + ValidityMask mask(bitmask_ptr); + auto data = mask.GetData(); + + // fills up a buffer sequentially before searching for free bits + if (mask.RowIsValid(segment_count)) { + mask.SetInvalid(segment_count); + return segment_count; + } + + for (idx_t entry_idx = 0; entry_idx < bitmask_count; entry_idx++) { + // get an entry with free bits + if (data[entry_idx] == 0) { + continue; + } + + // find the position of the free bit + auto entry = data[entry_idx]; + idx_t first_valid_bit = 0; + + // this loop finds the position of the rightmost set bit in entry and stores it + // in first_valid_bit + for (idx_t i = 0; i < 6; i++) { + // set the left half of the bits of this level to zero and test if the entry is still not zero + if (entry & BASE[i]) { + // first valid bit is in the rightmost s[i] bits + // permanently set the left half of the bits to zero + entry &= BASE[i]; + } else { + // first valid bit is in the leftmost s[i] bits + // shift by s[i] for the next iteration and add s[i] to the position of the rightmost set bit + entry >>= SHIFT[i]; + first_valid_bit += SHIFT[i]; + } + } + D_ASSERT(entry); + + auto prev_bits = entry_idx * sizeof(validity_t) * 8; + D_ASSERT(mask.RowIsValid(prev_bits + first_valid_bit)); + mask.SetInvalid(prev_bits + first_valid_bit); + return (prev_bits + first_valid_bit); + } + + throw InternalException("Invalid bitmask for FixedSizeAllocator"); +} + +uint32_t FixedSizeBuffer::GetMaxOffset(const idx_t available_segments) { + + // this function calls Get() on the buffer + D_ASSERT(InMemory()); + + // finds the maximum zero bit in a bitmask, and adds one to it, + // so that max_offset * segment_size = allocated_size of this bitmask's buffer + idx_t entry_size = sizeof(validity_t) * 8; + idx_t bitmask_count = available_segments / entry_size; + if (available_segments % entry_size != 0) { + bitmask_count++; + } + uint32_t max_offset = bitmask_count * sizeof(validity_t) * 8; + auto bits_in_last_entry = available_segments % (sizeof(validity_t) * 8); + + // get the bitmask data + auto bitmask_ptr = reinterpret_cast(Get()); + const ValidityMask mask(bitmask_ptr); + const auto data = mask.GetData(); + + D_ASSERT(bitmask_count > 0); + for (idx_t i = bitmask_count; i > 0; i--) { + + auto entry = data[i - 1]; + + // set all bits after bits_in_last_entry + if (i == bitmask_count) { + entry |= ~idx_t(0) << bits_in_last_entry; + } + + if (entry == ~idx_t(0)) { + max_offset -= sizeof(validity_t) * 8; + continue; + } + + // invert data[entry_idx] + auto entry_inv = ~entry; + idx_t first_valid_bit = 0; + + // then find the position of the LEFTMOST set bit + for (idx_t level = 0; level < 6; level++) { + + // set the right half of the bits of this level to zero and test if the entry is still not zero + if (entry_inv & ~BASE[level]) { + // first valid bit is in the leftmost s[level] bits + // shift by s[level] for the next iteration and add s[level] to the position of the leftmost set bit + entry_inv >>= SHIFT[level]; + first_valid_bit += SHIFT[level]; + } else { + // first valid bit is in the rightmost s[level] bits + // permanently set the left half of the bits to zero + entry_inv &= BASE[level]; + } + } + D_ASSERT(entry_inv); + max_offset -= sizeof(validity_t) * 8 - first_valid_bit; + D_ASSERT(!mask.RowIsValid(max_offset)); + return max_offset + 1; + } + + // there are no allocations in this buffer + throw InternalException("tried to serialize empty buffer"); +} + +void FixedSizeBuffer::SetUninitializedRegions(PartialBlockForIndex &p_block_for_index, const idx_t segment_size, + const idx_t offset, const idx_t bitmask_offset) { + + // this function calls Get() on the buffer + D_ASSERT(InMemory()); + + auto bitmask_ptr = reinterpret_cast(Get()); + ValidityMask mask(bitmask_ptr); + + idx_t i = 0; + idx_t max_offset = offset + allocation_size; + idx_t current_offset = offset + bitmask_offset; + while (current_offset < max_offset) { + + if (mask.RowIsValid(i)) { + D_ASSERT(current_offset + segment_size <= max_offset); + p_block_for_index.AddUninitializedRegion(current_offset, current_offset + segment_size); + } + current_offset += segment_size; + i++; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp new file mode 100644 index 00000000..dc17aa62 --- /dev/null +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -0,0 +1,1148 @@ +#include "duckdb/execution/join_hashtable.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/column/column_data_collection_segment.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +using ValidityBytes = JoinHashTable::ValidityBytes; +using ScanStructure = JoinHashTable::ScanStructure; +using ProbeSpill = JoinHashTable::ProbeSpill; +using ProbeSpillLocalState = JoinHashTable::ProbeSpillLocalAppendState; + +JoinHashTable::JoinHashTable(BufferManager &buffer_manager_p, const vector &conditions_p, + vector btypes, JoinType type_p) + : buffer_manager(buffer_manager_p), conditions(conditions_p), build_types(std::move(btypes)), entry_size(0), + tuple_size(0), vfound(Value::BOOLEAN(false)), join_type(type_p), finalized(false), has_null(false), + external(false), radix_bits(4), partition_start(0), partition_end(0) { + + for (auto &condition : conditions) { + D_ASSERT(condition.left->return_type == condition.right->return_type); + auto type = condition.left->return_type; + if (condition.comparison == ExpressionType::COMPARE_EQUAL || + condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + + // ensure that all equality conditions are at the front, + // and that all other conditions are at the back + D_ASSERT(equality_types.size() == condition_types.size()); + equality_types.push_back(type); + } + + predicates.push_back(condition.comparison); + null_values_are_equal.push_back(condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || + condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM); + + condition_types.push_back(type); + } + // at least one equality is necessary + D_ASSERT(!equality_types.empty()); + + // Types for the layout + vector layout_types(condition_types); + layout_types.insert(layout_types.end(), build_types.begin(), build_types.end()); + if (IsRightOuterJoin(join_type)) { + // full/right outer joins need an extra bool to keep track of whether or not a tuple has found a matching entry + // we place the bool before the NEXT pointer + layout_types.emplace_back(LogicalType::BOOLEAN); + } + layout_types.emplace_back(LogicalType::HASH); + layout.Initialize(layout_types, false); + row_matcher.Initialize(false, layout, predicates); + row_matcher_no_match_sel.Initialize(true, layout, predicates); + + const auto &offsets = layout.GetOffsets(); + tuple_size = offsets[condition_types.size() + build_types.size()]; + pointer_offset = offsets.back(); + entry_size = layout.GetRowWidth(); + + data_collection = make_uniq(buffer_manager, layout); + sink_collection = + make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); +} + +JoinHashTable::~JoinHashTable() { +} + +void JoinHashTable::Merge(JoinHashTable &other) { + { + lock_guard guard(data_lock); + data_collection->Combine(*other.data_collection); + } + + if (join_type == JoinType::MARK) { + auto &info = correlated_mark_join_info; + lock_guard mj_lock(info.mj_lock); + has_null = has_null || other.has_null; + if (!info.correlated_types.empty()) { + auto &other_info = other.correlated_mark_join_info; + info.correlated_counts->Combine(*other_info.correlated_counts); + } + } + + sink_collection->Combine(*other.sink_collection); +} + +void JoinHashTable::ApplyBitmask(Vector &hashes, idx_t count) { + if (hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { + D_ASSERT(!ConstantVector::IsNull(hashes)); + auto indices = ConstantVector::GetData(hashes); + *indices = *indices & bitmask; + } else { + hashes.Flatten(count); + auto indices = FlatVector::GetData(hashes); + for (idx_t i = 0; i < count; i++) { + indices[i] &= bitmask; + } + } +} + +void JoinHashTable::ApplyBitmask(Vector &hashes, const SelectionVector &sel, idx_t count, Vector &pointers) { + UnifiedVectorFormat hdata; + hashes.ToUnifiedFormat(count, hdata); + + auto hash_data = UnifiedVectorFormat::GetData(hdata); + auto result_data = FlatVector::GetData(pointers); + auto main_ht = reinterpret_cast(hash_map.get()); + for (idx_t i = 0; i < count; i++) { + auto rindex = sel.get_index(i); + auto hindex = hdata.sel->get_index(rindex); + auto hash = hash_data[hindex]; + result_data[rindex] = main_ht + (hash & bitmask); + } +} + +void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes) { + if (count == keys.size()) { + // no null values are filtered: use regular hash functions + VectorOperations::Hash(keys.data[0], hashes, keys.size()); + for (idx_t i = 1; i < equality_types.size(); i++) { + VectorOperations::CombineHash(hashes, keys.data[i], keys.size()); + } + } else { + // null values were filtered: use selection vector + VectorOperations::Hash(keys.data[0], hashes, sel, count); + for (idx_t i = 1; i < equality_types.size(); i++) { + VectorOperations::CombineHash(hashes, keys.data[i], sel, count); + } + } +} + +static idx_t FilterNullValues(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, + SelectionVector &result) { + idx_t result_count = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto key_idx = vdata.sel->get_index(idx); + if (vdata.validity.RowIsValid(key_idx)) { + result.set_index(result_count++, idx); + } + } + return result_count; +} + +void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChunk &keys, DataChunk &payload) { + D_ASSERT(!finalized); + D_ASSERT(keys.size() == payload.size()); + if (keys.size() == 0) { + return; + } + // special case: correlated mark join + if (join_type == JoinType::MARK && !correlated_mark_join_info.correlated_types.empty()) { + auto &info = correlated_mark_join_info; + lock_guard mj_lock(info.mj_lock); + // Correlated MARK join + // for the correlated mark join we need to keep track of COUNT(*) and COUNT(COLUMN) for each of the correlated + // columns push into the aggregate hash table + D_ASSERT(info.correlated_counts); + info.group_chunk.SetCardinality(keys); + for (idx_t i = 0; i < info.correlated_types.size(); i++) { + info.group_chunk.data[i].Reference(keys.data[i]); + } + if (info.correlated_payload.data.empty()) { + vector types; + types.push_back(keys.data[info.correlated_types.size()].GetType()); + info.correlated_payload.InitializeEmpty(types); + } + info.correlated_payload.SetCardinality(keys); + info.correlated_payload.data[0].Reference(keys.data[info.correlated_types.size()]); + info.correlated_counts->AddChunk(info.group_chunk, info.correlated_payload, AggregateType::NON_DISTINCT); + } + + // build a chunk to append to the data collection [keys, payload, (optional "found" boolean), hash] + DataChunk source_chunk; + source_chunk.InitializeEmpty(layout.GetTypes()); + for (idx_t i = 0; i < keys.ColumnCount(); i++) { + source_chunk.data[i].Reference(keys.data[i]); + } + idx_t col_offset = keys.ColumnCount(); + D_ASSERT(build_types.size() == payload.ColumnCount()); + for (idx_t i = 0; i < payload.ColumnCount(); i++) { + source_chunk.data[col_offset + i].Reference(payload.data[i]); + } + col_offset += payload.ColumnCount(); + if (IsRightOuterJoin(join_type)) { + // for FULL/RIGHT OUTER joins initialize the "found" boolean to false + source_chunk.data[col_offset].Reference(vfound); + col_offset++; + } + Vector hash_values(LogicalType::HASH); + source_chunk.data[col_offset].Reference(hash_values); + source_chunk.SetCardinality(keys); + + // ToUnifiedFormat the source chunk + TupleDataCollection::ToUnifiedFormat(append_state.chunk_state, source_chunk); + + // prepare the keys for processing + const SelectionVector *current_sel; + SelectionVector sel(STANDARD_VECTOR_SIZE); + idx_t added_count = PrepareKeys(keys, append_state.chunk_state.vector_data, current_sel, sel, true); + if (added_count < keys.size()) { + has_null = true; + } + if (added_count == 0) { + return; + } + + // hash the keys and obtain an entry in the list + // note that we only hash the keys used in the equality comparison + Hash(keys, *current_sel, added_count, hash_values); + + // Re-reference and ToUnifiedFormat the hash column after computing it + source_chunk.data[col_offset].Reference(hash_values); + hash_values.ToUnifiedFormat(source_chunk.size(), append_state.chunk_state.vector_data.back().unified); + + // We already called TupleDataCollection::ToUnifiedFormat, so we can AppendUnified here + sink_collection->AppendUnified(append_state, source_chunk, *current_sel, added_count); +} + +idx_t JoinHashTable::PrepareKeys(DataChunk &keys, vector &vector_data, + const SelectionVector *¤t_sel, SelectionVector &sel, bool build_side) { + // figure out which keys are NULL, and create a selection vector out of them + current_sel = FlatVector::IncrementalSelectionVector(); + idx_t added_count = keys.size(); + if (build_side && IsRightOuterJoin(join_type)) { + // in case of a right or full outer join, we cannot remove NULL keys from the build side + return added_count; + } + + for (idx_t col_idx = 0; col_idx < keys.ColumnCount(); col_idx++) { + if (!null_values_are_equal[col_idx]) { + auto &col_key_data = vector_data[col_idx].unified; + if (col_key_data.validity.AllValid()) { + continue; + } + added_count = FilterNullValues(col_key_data, *current_sel, added_count, sel); + // null values are NOT equal for this column, filter them out + current_sel = &sel; + } + } + return added_count; +} + +template +static inline void InsertHashesLoop(atomic pointers[], const hash_t indices[], const idx_t count, + const data_ptr_t key_locations[], const idx_t pointer_offset) { + for (idx_t i = 0; i < count; i++) { + const auto index = indices[i]; + if (PARALLEL) { + data_ptr_t head; + do { + head = pointers[index]; + Store(head, key_locations[i] + pointer_offset); + } while (!std::atomic_compare_exchange_weak(&pointers[index], &head, key_locations[i])); + } else { + // set prev in current key to the value (NOTE: this will be nullptr if there is none) + Store(pointers[index], key_locations[i] + pointer_offset); + + // set pointer to current tuple + pointers[index] = key_locations[i]; + } + } +} + +void JoinHashTable::InsertHashes(Vector &hashes, idx_t count, data_ptr_t key_locations[], bool parallel) { + D_ASSERT(hashes.GetType().id() == LogicalType::HASH); + + // use bitmask to get position in array + ApplyBitmask(hashes, count); + + hashes.Flatten(count); + D_ASSERT(hashes.GetVectorType() == VectorType::FLAT_VECTOR); + + auto pointers = reinterpret_cast *>(hash_map.get()); + auto indices = FlatVector::GetData(hashes); + + if (parallel) { + InsertHashesLoop(pointers, indices, count, key_locations, pointer_offset); + } else { + InsertHashesLoop(pointers, indices, count, key_locations, pointer_offset); + } +} + +void JoinHashTable::InitializePointerTable() { + idx_t capacity = PointerTableCapacity(Count()); + D_ASSERT(IsPowerOfTwo(capacity)); + + if (hash_map.get()) { + // There is already a hash map + auto current_capacity = hash_map.GetSize() / sizeof(data_ptr_t); + if (capacity > current_capacity) { + // Need more space + hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(data_ptr_t)); + } else { + // Just use the current hash map + capacity = current_capacity; + } + } else { + // Allocate a hash map + hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(data_ptr_t)); + } + D_ASSERT(hash_map.GetSize() == capacity * sizeof(data_ptr_t)); + + // initialize HT with all-zero entries + std::fill_n(reinterpret_cast(hash_map.get()), capacity, nullptr); + + bitmask = capacity - 1; +} + +void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool parallel) { + // Pointer table should be allocated + D_ASSERT(hash_map.get()); + + Vector hashes(LogicalType::HASH); + auto hash_data = FlatVector::GetData(hashes); + + TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED, chunk_idx_from, + chunk_idx_to, false); + const auto row_locations = iterator.GetRowLocations(); + do { + const auto count = iterator.GetCurrentChunkCount(); + for (idx_t i = 0; i < count; i++) { + hash_data[i] = Load(row_locations[i] + pointer_offset); + } + InsertHashes(hashes, count, row_locations, parallel); + } while (iterator.Next()); +} + +unique_ptr JoinHashTable::InitializeScanStructure(DataChunk &keys, TupleDataChunkState &key_state, + const SelectionVector *¤t_sel) { + D_ASSERT(Count() > 0); // should be handled before + D_ASSERT(finalized); + + // set up the scan structure + auto ss = make_uniq(*this, key_state); + + if (join_type != JoinType::INNER) { + ss->found_match = make_unsafe_uniq_array(STANDARD_VECTOR_SIZE); + memset(ss->found_match.get(), 0, sizeof(bool) * STANDARD_VECTOR_SIZE); + } + + // first prepare the keys for probing + TupleDataCollection::ToUnifiedFormat(key_state, keys); + ss->count = PrepareKeys(keys, key_state.vector_data, current_sel, ss->sel_vector, false); + return ss; +} + +unique_ptr JoinHashTable::Probe(DataChunk &keys, TupleDataChunkState &key_state, + Vector *precomputed_hashes) { + const SelectionVector *current_sel; + auto ss = InitializeScanStructure(keys, key_state, current_sel); + if (ss->count == 0) { + return ss; + } + + if (precomputed_hashes) { + ApplyBitmask(*precomputed_hashes, *current_sel, ss->count, ss->pointers); + } else { + // hash all the keys + Vector hashes(LogicalType::HASH); + Hash(keys, *current_sel, ss->count, hashes); + + // now initialize the pointers of the scan structure based on the hashes + ApplyBitmask(hashes, *current_sel, ss->count, ss->pointers); + } + + // create the selection vector linking to only non-empty entries + ss->InitializeSelectionVector(current_sel); + + return ss; +} + +ScanStructure::ScanStructure(JoinHashTable &ht_p, TupleDataChunkState &key_state_p) + : key_state(key_state_p), pointers(LogicalType::POINTER), sel_vector(STANDARD_VECTOR_SIZE), ht(ht_p), + finished(false) { +} + +void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) { + if (finished) { + return; + } + switch (ht.join_type) { + case JoinType::INNER: + case JoinType::RIGHT: + NextInnerJoin(keys, left, result); + break; + case JoinType::SEMI: + NextSemiJoin(keys, left, result); + break; + case JoinType::MARK: + NextMarkJoin(keys, left, result); + break; + case JoinType::ANTI: + NextAntiJoin(keys, left, result); + break; + case JoinType::OUTER: + case JoinType::LEFT: + NextLeftJoin(keys, left, result); + break; + case JoinType::SINGLE: + NextSingleJoin(keys, left, result); + break; + default: + throw InternalException("Unhandled join type in JoinHashTable"); + } +} + +idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel) { + // Start with the scan selection + for (idx_t i = 0; i < this->count; ++i) { + match_sel.set_index(i, this->sel_vector.get_index(i)); + } + idx_t no_match_count = 0; + + auto &matcher = no_match_sel ? ht.row_matcher_no_match_sel : ht.row_matcher; + return matcher.Match(keys, key_state.vector_data, match_sel, this->count, ht.layout, pointers, no_match_sel, + no_match_count); +} + +idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vector) { + while (true) { + // resolve the predicates for this set of keys + idx_t result_count = ResolvePredicates(keys, result_vector, nullptr); + + // after doing all the comparisons set the found_match vector + if (found_match) { + for (idx_t i = 0; i < result_count; i++) { + auto idx = result_vector.get_index(i); + found_match[idx] = true; + } + } + if (result_count > 0) { + return result_count; + } + // no matches found: check the next set of pointers + AdvancePointers(); + if (this->count == 0) { + return 0; + } + } +} + +void ScanStructure::AdvancePointers(const SelectionVector &sel, idx_t sel_count) { + // now for all the pointers, we move on to the next set of pointers + idx_t new_count = 0; + auto ptrs = FlatVector::GetData(this->pointers); + for (idx_t i = 0; i < sel_count; i++) { + auto idx = sel.get_index(i); + ptrs[idx] = Load(ptrs[idx] + ht.pointer_offset); + if (ptrs[idx]) { + this->sel_vector.set_index(new_count++, idx); + } + } + this->count = new_count; +} + +void ScanStructure::InitializeSelectionVector(const SelectionVector *¤t_sel) { + idx_t non_empty_count = 0; + auto ptrs = FlatVector::GetData(pointers); + auto cnt = count; + for (idx_t i = 0; i < cnt; i++) { + const auto idx = current_sel->get_index(i); + ptrs[idx] = Load(ptrs[idx]); + if (ptrs[idx]) { + sel_vector.set_index(non_empty_count++, idx); + } + } + count = non_empty_count; +} + +void ScanStructure::AdvancePointers() { + AdvancePointers(this->sel_vector, this->count); +} + +void ScanStructure::GatherResult(Vector &result, const SelectionVector &result_vector, + const SelectionVector &sel_vector, const idx_t count, const idx_t col_no) { + ht.data_collection->Gather(pointers, sel_vector, count, col_no, result, result_vector); +} + +void ScanStructure::GatherResult(Vector &result, const SelectionVector &sel_vector, const idx_t count, + const idx_t col_idx) { + GatherResult(result, *FlatVector::IncrementalSelectionVector(), sel_vector, count, col_idx); +} + +void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { + D_ASSERT(result.ColumnCount() == left.ColumnCount() + ht.build_types.size()); + if (this->count == 0) { + // no pointers left to chase + return; + } + + SelectionVector result_vector(STANDARD_VECTOR_SIZE); + + idx_t result_count = ScanInnerJoin(keys, result_vector); + if (result_count > 0) { + if (IsRightOuterJoin(ht.join_type)) { + // full/right outer join: mark join matches as FOUND in the HT + auto ptrs = FlatVector::GetData(pointers); + for (idx_t i = 0; i < result_count; i++) { + auto idx = result_vector.get_index(i); + // NOTE: threadsan reports this as a data race because this can be set concurrently by separate threads + // Technically it is, but it does not matter, since the only value that can be written is "true" + Store(true, ptrs[idx] + ht.tuple_size); + } + } + // matches were found + // construct the result + // on the LHS, we create a slice using the result vector + result.Slice(left, result_vector, result_count); + + // on the RHS, we need to fetch the data from the hash table + for (idx_t i = 0; i < ht.build_types.size(); i++) { + auto &vector = result.data[left.ColumnCount() + i]; + D_ASSERT(vector.GetType() == ht.build_types[i]); + GatherResult(vector, result_vector, result_count, i + ht.condition_types.size()); + } + AdvancePointers(); + } +} + +void ScanStructure::ScanKeyMatches(DataChunk &keys) { + // the semi-join, anti-join and mark-join we handle a differently from the inner join + // since there can be at most STANDARD_VECTOR_SIZE results + // we handle the entire chunk in one call to Next(). + // for every pointer, we keep chasing pointers and doing comparisons. + // this results in a boolean array indicating whether or not the tuple has a match + SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE); + while (this->count > 0) { + // resolve the predicates for the current set of pointers + idx_t match_count = ResolvePredicates(keys, match_sel, &no_match_sel); + idx_t no_match_count = this->count - match_count; + + // mark each of the matches as found + for (idx_t i = 0; i < match_count; i++) { + found_match[match_sel.get_index(i)] = true; + } + // continue searching for the ones where we did not find a match yet + AdvancePointers(no_match_sel, no_match_count); + } +} + +template +void ScanStructure::NextSemiOrAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { + D_ASSERT(left.ColumnCount() == result.ColumnCount()); + D_ASSERT(keys.size() == left.size()); + // create the selection vector from the matches that were found + SelectionVector sel(STANDARD_VECTOR_SIZE); + idx_t result_count = 0; + for (idx_t i = 0; i < keys.size(); i++) { + if (found_match[i] == MATCH) { + // part of the result + sel.set_index(result_count++, i); + } + } + // construct the final result + if (result_count > 0) { + // we only return the columns on the left side + // reference the columns of the left side from the result + result.Slice(left, sel, result_count); + } else { + D_ASSERT(result.size() == 0); + } +} + +void ScanStructure::NextSemiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { + // first scan for key matches + ScanKeyMatches(keys); + // then construct the result from all tuples with a match + NextSemiOrAntiJoin(keys, left, result); + + finished = true; +} + +void ScanStructure::NextAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { + // first scan for key matches + ScanKeyMatches(keys); + // then construct the result from all tuples that did not find a match + NextSemiOrAntiJoin(keys, left, result); + + finished = true; +} + +void ScanStructure::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &child, DataChunk &result) { + // for the initial set of columns we just reference the left side + result.SetCardinality(child); + for (idx_t i = 0; i < child.ColumnCount(); i++) { + result.data[i].Reference(child.data[i]); + } + auto &mark_vector = result.data.back(); + mark_vector.SetVectorType(VectorType::FLAT_VECTOR); + // first we set the NULL values from the join keys + // if there is any NULL in the keys, the result is NULL + auto bool_result = FlatVector::GetData(mark_vector); + auto &mask = FlatVector::Validity(mark_vector); + for (idx_t col_idx = 0; col_idx < join_keys.ColumnCount(); col_idx++) { + if (ht.null_values_are_equal[col_idx]) { + continue; + } + UnifiedVectorFormat jdata; + join_keys.data[col_idx].ToUnifiedFormat(join_keys.size(), jdata); + if (!jdata.validity.AllValid()) { + for (idx_t i = 0; i < join_keys.size(); i++) { + auto jidx = jdata.sel->get_index(i); + mask.Set(i, jdata.validity.RowIsValidUnsafe(jidx)); + } + } + } + // now set the remaining entries to either true or false based on whether a match was found + if (found_match) { + for (idx_t i = 0; i < child.size(); i++) { + bool_result[i] = found_match[i]; + } + } else { + memset(bool_result, 0, sizeof(bool) * child.size()); + } + // if the right side contains NULL values, the result of any FALSE becomes NULL + if (ht.has_null) { + for (idx_t i = 0; i < child.size(); i++) { + if (!bool_result[i]) { + mask.SetInvalid(i); + } + } + } +} + +void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &input, DataChunk &result) { + D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1); + D_ASSERT(result.data.back().GetType() == LogicalType::BOOLEAN); + // this method should only be called for a non-empty HT + D_ASSERT(ht.Count() > 0); + + ScanKeyMatches(keys); + if (ht.correlated_mark_join_info.correlated_types.empty()) { + ConstructMarkJoinResult(keys, input, result); + } else { + auto &info = ht.correlated_mark_join_info; + lock_guard mj_lock(info.mj_lock); + + // there are correlated columns + // first we fetch the counts from the aggregate hashtable corresponding to these entries + D_ASSERT(keys.ColumnCount() == info.group_chunk.ColumnCount() + 1); + info.group_chunk.SetCardinality(keys); + for (idx_t i = 0; i < info.group_chunk.ColumnCount(); i++) { + info.group_chunk.data[i].Reference(keys.data[i]); + } + info.correlated_counts->FetchAggregates(info.group_chunk, info.result_chunk); + + // for the initial set of columns we just reference the left side + result.SetCardinality(input); + for (idx_t i = 0; i < input.ColumnCount(); i++) { + result.data[i].Reference(input.data[i]); + } + // create the result matching vector + auto &last_key = keys.data.back(); + auto &result_vector = result.data.back(); + // first set the nullmask based on whether or not there were NULL values in the join key + result_vector.SetVectorType(VectorType::FLAT_VECTOR); + auto bool_result = FlatVector::GetData(result_vector); + auto &mask = FlatVector::Validity(result_vector); + switch (last_key.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: + if (ConstantVector::IsNull(last_key)) { + mask.SetAllInvalid(input.size()); + } + break; + case VectorType::FLAT_VECTOR: + mask.Copy(FlatVector::Validity(last_key), input.size()); + break; + default: { + UnifiedVectorFormat kdata; + last_key.ToUnifiedFormat(keys.size(), kdata); + for (idx_t i = 0; i < input.size(); i++) { + auto kidx = kdata.sel->get_index(i); + mask.Set(i, kdata.validity.RowIsValid(kidx)); + } + break; + } + } + + auto count_star = FlatVector::GetData(info.result_chunk.data[0]); + auto count = FlatVector::GetData(info.result_chunk.data[1]); + // set the entries to either true or false based on whether a match was found + for (idx_t i = 0; i < input.size(); i++) { + D_ASSERT(count_star[i] >= count[i]); + bool_result[i] = found_match ? found_match[i] : false; + if (!bool_result[i] && count_star[i] > count[i]) { + // RHS has NULL value and result is false: set to null + mask.SetInvalid(i); + } + if (count_star[i] == 0) { + // count == 0, set nullmask to false (we know the result is false now) + mask.SetValid(i); + } + } + } + finished = true; +} + +void ScanStructure::NextLeftJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { + // a LEFT OUTER JOIN is identical to an INNER JOIN except all tuples that do + // not have a match must return at least one tuple (with the right side set + // to NULL in every column) + NextInnerJoin(keys, left, result); + if (result.size() == 0) { + // no entries left from the normal join + // fill in the result of the remaining left tuples + // together with NULL values on the right-hand side + idx_t remaining_count = 0; + SelectionVector sel(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < left.size(); i++) { + if (!found_match[i]) { + sel.set_index(remaining_count++, i); + } + } + if (remaining_count > 0) { + // have remaining tuples + // slice the left side with tuples that did not find a match + result.Slice(left, sel, remaining_count); + + // now set the right side to NULL + for (idx_t i = left.ColumnCount(); i < result.ColumnCount(); i++) { + Vector &vec = result.data[i]; + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + } + finished = true; + } +} + +void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &input, DataChunk &result) { + // single join + // this join is similar to the semi join except that + // (1) we actually return data from the RHS and + // (2) we return NULL for that data if there is no match + idx_t result_count = 0; + SelectionVector result_sel(STANDARD_VECTOR_SIZE); + SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE); + while (this->count > 0) { + // resolve the predicates for the current set of pointers + idx_t match_count = ResolvePredicates(keys, match_sel, &no_match_sel); + idx_t no_match_count = this->count - match_count; + + // mark each of the matches as found + for (idx_t i = 0; i < match_count; i++) { + // found a match for this index + auto index = match_sel.get_index(i); + found_match[index] = true; + result_sel.set_index(result_count++, index); + } + // continue searching for the ones where we did not find a match yet + AdvancePointers(no_match_sel, no_match_count); + } + // reference the columns of the left side from the result + D_ASSERT(input.ColumnCount() > 0); + for (idx_t i = 0; i < input.ColumnCount(); i++) { + result.data[i].Reference(input.data[i]); + } + // now fetch the data from the RHS + for (idx_t i = 0; i < ht.build_types.size(); i++) { + auto &vector = result.data[input.ColumnCount() + i]; + // set NULL entries for every entry that was not found + for (idx_t j = 0; j < input.size(); j++) { + if (!found_match[j]) { + FlatVector::SetNull(vector, j, true); + } + } + // for the remaining values we fetch the values + GatherResult(vector, result_sel, result_sel, result_count, i + ht.condition_types.size()); + } + result.SetCardinality(input.size()); + + // like the SEMI, ANTI and MARK join types, the SINGLE join only ever does one pass over the HT per input chunk + finished = true; +} + +void JoinHashTable::ScanFullOuter(JoinHTScanState &state, Vector &addresses, DataChunk &result) { + // scan the HT starting from the current position and check which rows from the build side did not find a match + auto key_locations = FlatVector::GetData(addresses); + idx_t found_entries = 0; + + auto &iterator = state.iterator; + if (iterator.Done()) { + return; + } + + const auto row_locations = iterator.GetRowLocations(); + do { + const auto count = iterator.GetCurrentChunkCount(); + for (idx_t i = state.offset_in_chunk; i < count; i++) { + auto found_match = Load(row_locations[i] + tuple_size); + if (!found_match) { + key_locations[found_entries++] = row_locations[i]; + if (found_entries == STANDARD_VECTOR_SIZE) { + state.offset_in_chunk = i + 1; + break; + } + } + } + if (found_entries == STANDARD_VECTOR_SIZE) { + break; + } + state.offset_in_chunk = 0; + } while (iterator.Next()); + + // now gather from the found rows + if (found_entries == 0) { + return; + } + result.SetCardinality(found_entries); + idx_t left_column_count = result.ColumnCount() - build_types.size(); + const auto &sel_vector = *FlatVector::IncrementalSelectionVector(); + // set the left side as a constant NULL + for (idx_t i = 0; i < left_column_count; i++) { + Vector &vec = result.data[i]; + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + + // gather the values from the RHS + for (idx_t i = 0; i < build_types.size(); i++) { + auto &vector = result.data[left_column_count + i]; + D_ASSERT(vector.GetType() == build_types[i]); + const auto col_no = condition_types.size() + i; + data_collection->Gather(addresses, sel_vector, found_entries, col_no, vector, sel_vector); + } +} + +idx_t JoinHashTable::FillWithHTOffsets(JoinHTScanState &state, Vector &addresses) { + // iterate over HT + auto key_locations = FlatVector::GetData(addresses); + idx_t key_count = 0; + + auto &iterator = state.iterator; + const auto row_locations = iterator.GetRowLocations(); + do { + const auto count = iterator.GetCurrentChunkCount(); + for (idx_t i = 0; i < count; i++) { + key_locations[key_count + i] = row_locations[i]; + } + key_count += count; + } while (iterator.Next()); + + return key_count; +} + +bool JoinHashTable::RequiresExternalJoin(ClientConfig &config, vector> &local_hts) { + total_count = 0; + idx_t data_size = 0; + for (auto &ht : local_hts) { + auto &local_sink_collection = ht->GetSinkCollection(); + total_count += local_sink_collection.Count(); + data_size += local_sink_collection.SizeInBytes(); + } + + if (total_count == 0) { + return false; + } + + if (config.force_external) { + // Do 1 round per partition if forcing external join to test all code paths + const auto r = RadixPartitioning::NumberOfPartitions(radix_bits); + auto data_size_per_round = (data_size + r - 1) / r; + auto count_per_round = (total_count + r - 1) / r; + max_ht_size = data_size_per_round + PointerTableSize(count_per_round); + external = true; + } else { + auto ht_size = data_size + PointerTableSize(total_count); + external = ht_size > max_ht_size; + } + return external; +} + +void JoinHashTable::Unpartition() { + for (auto &partition : sink_collection->GetPartitions()) { + data_collection->Combine(*partition); + } +} + +bool JoinHashTable::RequiresPartitioning(ClientConfig &config, vector> &local_hts) { + D_ASSERT(total_count != 0); + D_ASSERT(external); + + idx_t num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + vector partition_counts(num_partitions, 0); + vector partition_sizes(num_partitions, 0); + for (auto &ht : local_hts) { + const auto &local_partitions = ht->GetSinkCollection().GetPartitions(); + for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { + auto &local_partition = local_partitions[partition_idx]; + partition_counts[partition_idx] += local_partition->Count(); + partition_sizes[partition_idx] += local_partition->SizeInBytes(); + } + } + + // Figure out if we can fit all single partitions in memory + idx_t max_partition_idx = 0; + idx_t max_partition_size = 0; + for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { + const auto &partition_count = partition_counts[partition_idx]; + const auto &partition_size = partition_sizes[partition_idx]; + auto partition_ht_size = partition_size + PointerTableSize(partition_count); + if (partition_ht_size > max_partition_size) { + max_partition_size = partition_ht_size; + max_partition_idx = partition_idx; + } + } + + if (config.force_external || max_partition_size > max_ht_size) { + const auto partition_count = partition_counts[max_partition_idx]; + const auto partition_size = partition_sizes[max_partition_idx]; + + const auto max_added_bits = RadixPartitioning::MAX_RADIX_BITS - radix_bits; + idx_t added_bits = config.force_external ? 2 : 1; + for (; added_bits < max_added_bits; added_bits++) { + double partition_multiplier = RadixPartitioning::NumberOfPartitions(added_bits); + + auto new_estimated_count = double(partition_count) / partition_multiplier; + auto new_estimated_size = double(partition_size) / partition_multiplier; + auto new_estimated_ht_size = new_estimated_size + PointerTableSize(new_estimated_count); + + if (config.force_external || new_estimated_ht_size <= double(max_ht_size) / 4) { + // Aim for an estimated partition size of max_ht_size / 4 + break; + } + } + radix_bits += added_bits; + sink_collection = + make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); + return true; + } else { + return false; + } +} + +void JoinHashTable::Partition(JoinHashTable &global_ht) { + auto new_sink_collection = + make_uniq(buffer_manager, layout, global_ht.radix_bits, layout.ColumnCount() - 1); + sink_collection->Repartition(*new_sink_collection); + sink_collection = std::move(new_sink_collection); + global_ht.Merge(*this); +} + +void JoinHashTable::Reset() { + data_collection->Reset(); + finalized = false; +} + +bool JoinHashTable::PrepareExternalFinalize() { + if (finalized) { + Reset(); + } + + const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + if (partition_end == num_partitions) { + return false; + } + + // Start where we left off + auto &partitions = sink_collection->GetPartitions(); + partition_start = partition_end; + + // Determine how many partitions we can do next (at least one) + idx_t count = 0; + idx_t data_size = 0; + idx_t partition_idx; + for (partition_idx = partition_start; partition_idx < num_partitions; partition_idx++) { + auto incl_count = count + partitions[partition_idx]->Count(); + auto incl_data_size = data_size + partitions[partition_idx]->SizeInBytes(); + auto incl_ht_size = incl_data_size + PointerTableSize(incl_count); + if (count > 0 && incl_ht_size > max_ht_size) { + break; + } + count = incl_count; + data_size = incl_data_size; + } + partition_end = partition_idx; + + // Move the partitions to the main data collection + for (partition_idx = partition_start; partition_idx < partition_end; partition_idx++) { + data_collection->Combine(*partitions[partition_idx]); + } + D_ASSERT(Count() == count); + + return true; +} + +static void CreateSpillChunk(DataChunk &spill_chunk, DataChunk &keys, DataChunk &payload, Vector &hashes) { + spill_chunk.Reset(); + idx_t spill_col_idx = 0; + for (idx_t col_idx = 0; col_idx < keys.ColumnCount(); col_idx++) { + spill_chunk.data[col_idx].Reference(keys.data[col_idx]); + } + spill_col_idx += keys.ColumnCount(); + for (idx_t col_idx = 0; col_idx < payload.data.size(); col_idx++) { + spill_chunk.data[spill_col_idx + col_idx].Reference(payload.data[col_idx]); + } + spill_col_idx += payload.ColumnCount(); + spill_chunk.data[spill_col_idx].Reference(hashes); +} + +unique_ptr JoinHashTable::ProbeAndSpill(DataChunk &keys, TupleDataChunkState &key_state, + DataChunk &payload, ProbeSpill &probe_spill, + ProbeSpillLocalAppendState &spill_state, + DataChunk &spill_chunk) { + // hash all the keys + Vector hashes(LogicalType::HASH); + Hash(keys, *FlatVector::IncrementalSelectionVector(), keys.size(), hashes); + + // find out which keys we can match with the current pinned partitions + SelectionVector true_sel; + SelectionVector false_sel; + true_sel.Initialize(); + false_sel.Initialize(); + auto true_count = RadixPartitioning::Select(hashes, FlatVector::IncrementalSelectionVector(), keys.size(), + radix_bits, partition_end, &true_sel, &false_sel); + auto false_count = keys.size() - true_count; + + CreateSpillChunk(spill_chunk, keys, payload, hashes); + + // can't probe these values right now, append to spill + spill_chunk.Slice(false_sel, false_count); + spill_chunk.Verify(); + probe_spill.Append(spill_chunk, spill_state); + + // slice the stuff we CAN probe right now + hashes.Slice(true_sel, true_count); + keys.Slice(true_sel, true_count); + payload.Slice(true_sel, true_count); + + const SelectionVector *current_sel; + auto ss = InitializeScanStructure(keys, key_state, current_sel); + if (ss->count == 0) { + return ss; + } + + // now initialize the pointers of the scan structure based on the hashes + ApplyBitmask(hashes, *current_sel, ss->count, ss->pointers); + + // create the selection vector linking to only non-empty entries + ss->InitializeSelectionVector(current_sel); + + return ss; +} + +ProbeSpill::ProbeSpill(JoinHashTable &ht, ClientContext &context, const vector &probe_types) + : ht(ht), context(context), probe_types(probe_types) { + auto remaining_count = ht.GetSinkCollection().Count(); + auto remaining_data_size = ht.GetSinkCollection().SizeInBytes(); + auto remaining_ht_size = remaining_data_size + ht.PointerTableSize(remaining_count); + if (remaining_ht_size <= ht.max_ht_size) { + // No need to partition as we will only have one more probe round + partitioned = false; + } else { + // More than one probe round to go, so we need to partition + partitioned = true; + global_partitions = + make_uniq(context, probe_types, ht.radix_bits, probe_types.size() - 1); + } + column_ids.reserve(probe_types.size()); + for (column_t column_id = 0; column_id < probe_types.size(); column_id++) { + column_ids.emplace_back(column_id); + } +} + +ProbeSpillLocalState ProbeSpill::RegisterThread() { + ProbeSpillLocalAppendState result; + lock_guard guard(lock); + if (partitioned) { + local_partitions.emplace_back(global_partitions->CreateShared()); + local_partition_append_states.emplace_back(make_uniq()); + local_partitions.back()->InitializeAppendState(*local_partition_append_states.back()); + + result.local_partition = local_partitions.back().get(); + result.local_partition_append_state = local_partition_append_states.back().get(); + } else { + local_spill_collections.emplace_back( + make_uniq(BufferManager::GetBufferManager(context), probe_types)); + local_spill_append_states.emplace_back(make_uniq()); + local_spill_collections.back()->InitializeAppend(*local_spill_append_states.back()); + + result.local_spill_collection = local_spill_collections.back().get(); + result.local_spill_append_state = local_spill_append_states.back().get(); + } + return result; +} + +void ProbeSpill::Append(DataChunk &chunk, ProbeSpillLocalAppendState &local_state) { + if (partitioned) { + local_state.local_partition->Append(*local_state.local_partition_append_state, chunk); + } else { + local_state.local_spill_collection->Append(*local_state.local_spill_append_state, chunk); + } +} + +void ProbeSpill::Finalize() { + if (partitioned) { + D_ASSERT(local_partitions.size() == local_partition_append_states.size()); + for (idx_t i = 0; i < local_partition_append_states.size(); i++) { + local_partitions[i]->FlushAppendState(*local_partition_append_states[i]); + } + for (auto &local_partition : local_partitions) { + global_partitions->Combine(*local_partition); + } + local_partitions.clear(); + local_partition_append_states.clear(); + } else { + if (local_spill_collections.empty()) { + global_spill_collection = + make_uniq(BufferManager::GetBufferManager(context), probe_types); + } else { + global_spill_collection = std::move(local_spill_collections[0]); + for (idx_t i = 1; i < local_spill_collections.size(); i++) { + global_spill_collection->Combine(*local_spill_collections[i]); + } + } + local_spill_collections.clear(); + local_spill_append_states.clear(); + } +} + +void ProbeSpill::PrepareNextProbe() { + if (partitioned) { + auto &partitions = global_partitions->GetPartitions(); + if (partitions.empty() || ht.partition_start == partitions.size()) { + // Can't probe, just make an empty one + global_spill_collection = + make_uniq(BufferManager::GetBufferManager(context), probe_types); + } else { + // Move specific partitions to the global spill collection + global_spill_collection = std::move(partitions[ht.partition_start]); + for (idx_t i = ht.partition_start + 1; i < ht.partition_end; i++) { + auto &partition = partitions[i]; + if (global_spill_collection->Count() == 0) { + global_spill_collection = std::move(partition); + } else { + global_spill_collection->Combine(*partition); + } + } + } + } + consumer = make_uniq(*global_spill_collection, column_ids); + consumer->InitializeScan(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp new file mode 100644 index 00000000..f64943fa --- /dev/null +++ b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp @@ -0,0 +1,189 @@ +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/execution/nested_loop_join.hpp" + +namespace duckdb { + +struct InitialNestedLoopJoin { + template + static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, + SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { + using MATCH_OP = ComparisonOperationWrapper; + + // initialize phase of nested loop join + // fill lvector and rvector with matches from the base vectors + UnifiedVectorFormat left_data, right_data; + left.ToUnifiedFormat(left_size, left_data); + right.ToUnifiedFormat(right_size, right_data); + + auto ldata = UnifiedVectorFormat::GetData(left_data); + auto rdata = UnifiedVectorFormat::GetData(right_data); + idx_t result_count = 0; + for (; rpos < right_size; rpos++) { + idx_t right_position = right_data.sel->get_index(rpos); + bool right_is_valid = right_data.validity.RowIsValid(right_position); + for (; lpos < left_size; lpos++) { + if (result_count == STANDARD_VECTOR_SIZE) { + // out of space! + return result_count; + } + idx_t left_position = left_data.sel->get_index(lpos); + bool left_is_valid = left_data.validity.RowIsValid(left_position); + if (MATCH_OP::Operation(ldata[left_position], rdata[right_position], !left_is_valid, !right_is_valid)) { + // emit tuple + lvector.set_index(result_count, lpos); + rvector.set_index(result_count, rpos); + result_count++; + } + } + lpos = 0; + } + return result_count; + } +}; + +struct RefineNestedLoopJoin { + template + static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, + SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { + using MATCH_OP = ComparisonOperationWrapper; + + UnifiedVectorFormat left_data, right_data; + left.ToUnifiedFormat(left_size, left_data); + right.ToUnifiedFormat(right_size, right_data); + + // refine phase of the nested loop join + // refine lvector and rvector based on matches of subsequent conditions (in case there are multiple conditions + // in the join) + D_ASSERT(current_match_count > 0); + auto ldata = UnifiedVectorFormat::GetData(left_data); + auto rdata = UnifiedVectorFormat::GetData(right_data); + idx_t result_count = 0; + for (idx_t i = 0; i < current_match_count; i++) { + auto lidx = lvector.get_index(i); + auto ridx = rvector.get_index(i); + auto left_idx = left_data.sel->get_index(lidx); + auto right_idx = right_data.sel->get_index(ridx); + bool left_is_valid = left_data.validity.RowIsValid(left_idx); + bool right_is_valid = right_data.validity.RowIsValid(right_idx); + if (MATCH_OP::Operation(ldata[left_idx], rdata[right_idx], !left_is_valid, !right_is_valid)) { + lvector.set_index(result_count, lidx); + rvector.set_index(result_count, ridx); + result_count++; + } + } + return result_count; + } +}; + +template +static idx_t NestedLoopJoinTypeSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, + idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, + idx_t current_match_count) { + switch (left.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, + current_match_count); + case PhysicalType::INT16: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, + current_match_count); + case PhysicalType::INT32: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, + current_match_count); + case PhysicalType::INT64: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, + current_match_count); + case PhysicalType::UINT8: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, + current_match_count); + case PhysicalType::UINT16: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case PhysicalType::UINT32: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case PhysicalType::UINT64: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case PhysicalType::INT128: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case PhysicalType::FLOAT: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, + current_match_count); + case PhysicalType::DOUBLE: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, + current_match_count); + case PhysicalType::INTERVAL: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case PhysicalType::VARCHAR: + return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + default: + throw InternalException("Unimplemented type for join!"); + } +} + +template +idx_t NestedLoopJoinComparisonSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, + idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, + idx_t current_match_count, ExpressionType comparison_type) { + D_ASSERT(left.GetType() == right.GetType()); + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: + return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case ExpressionType::COMPARE_NOTEQUAL: + return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case ExpressionType::COMPARE_LESSTHAN: + return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case ExpressionType::COMPARE_GREATERTHAN: + return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, + lvector, rvector, current_match_count); + case ExpressionType::COMPARE_DISTINCT_FROM: + return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, + rvector, current_match_count); + default: + throw NotImplementedException("Unimplemented comparison type for join!"); + } +} + +idx_t NestedLoopJoinInner::Perform(idx_t &lpos, idx_t &rpos, DataChunk &left_conditions, DataChunk &right_conditions, + SelectionVector &lvector, SelectionVector &rvector, + const vector &conditions) { + D_ASSERT(left_conditions.ColumnCount() == right_conditions.ColumnCount()); + if (lpos >= left_conditions.size() || rpos >= right_conditions.size()) { + return 0; + } + // for the first condition, lvector and rvector are not set yet + // we initialize them using the InitialNestedLoopJoin + idx_t match_count = NestedLoopJoinComparisonSwitch( + left_conditions.data[0], right_conditions.data[0], left_conditions.size(), right_conditions.size(), lpos, rpos, + lvector, rvector, 0, conditions[0].comparison); + // now resolve the rest of the conditions + for (idx_t i = 1; i < conditions.size(); i++) { + // check if we have run out of tuples to compare + if (match_count == 0) { + return 0; + } + // if not, get the vectors to compare + Vector &l = left_conditions.data[i]; + Vector &r = right_conditions.data[i]; + // then we refine the currently obtained results using the RefineNestedLoopJoin + match_count = NestedLoopJoinComparisonSwitch( + l, r, left_conditions.size(), right_conditions.size(), lpos, rpos, lvector, rvector, match_count, + conditions[i].comparison); + } + return match_count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp new file mode 100644 index 00000000..853037e5 --- /dev/null +++ b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp @@ -0,0 +1,165 @@ +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/nested_loop_join.hpp" + +namespace duckdb { + +template +static void TemplatedMarkJoin(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { + using MATCH_OP = ComparisonOperationWrapper; + + UnifiedVectorFormat left_data, right_data; + left.ToUnifiedFormat(lcount, left_data); + right.ToUnifiedFormat(rcount, right_data); + + auto ldata = UnifiedVectorFormat::GetData(left_data); + auto rdata = UnifiedVectorFormat::GetData(right_data); + for (idx_t i = 0; i < lcount; i++) { + if (found_match[i]) { + continue; + } + auto lidx = left_data.sel->get_index(i); + const auto left_null = !left_data.validity.RowIsValid(lidx); + if (!MATCH_OP::COMPARE_NULL && left_null) { + continue; + } + for (idx_t j = 0; j < rcount; j++) { + auto ridx = right_data.sel->get_index(j); + const auto right_null = !right_data.validity.RowIsValid(ridx); + if (!MATCH_OP::COMPARE_NULL && right_null) { + continue; + } + if (MATCH_OP::template Operation(ldata[lidx], rdata[ridx], left_null, right_null)) { + found_match[i] = true; + break; + } + } + } +} + +static void MarkJoinNested(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], + ExpressionType comparison_type) { + Vector left_reference(left.GetType()); + SelectionVector true_sel(rcount); + for (idx_t i = 0; i < lcount; i++) { + if (found_match[i]) { + continue; + } + ConstantVector::Reference(left_reference, left, i, rcount); + idx_t count; + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: + count = VectorOperations::Equals(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + case ExpressionType::COMPARE_NOTEQUAL: + count = VectorOperations::NotEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + case ExpressionType::COMPARE_LESSTHAN: + count = VectorOperations::LessThan(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + case ExpressionType::COMPARE_GREATERTHAN: + count = VectorOperations::GreaterThan(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + count = VectorOperations::LessThanEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + count = VectorOperations::GreaterThanEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + case ExpressionType::COMPARE_DISTINCT_FROM: + count = VectorOperations::DistinctFrom(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + count = VectorOperations::NotDistinctFrom(left_reference, right, nullptr, rcount, nullptr, nullptr); + break; + default: + throw InternalException("Unsupported comparison type for MarkJoinNested"); + } + if (count > 0) { + found_match[i] = true; + } + } +} + +template +static void MarkJoinSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { + switch (left.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::INT16: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::INT32: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::INT64: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::INT128: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::UINT8: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::UINT16: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::UINT32: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::UINT64: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::FLOAT: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::DOUBLE: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + case PhysicalType::VARCHAR: + return TemplatedMarkJoin(left, right, lcount, rcount, found_match); + default: + throw NotImplementedException("Unimplemented type for mark join!"); + } +} + +static void MarkJoinComparisonSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], + ExpressionType comparison_type) { + switch (left.GetType().InternalType()) { + case PhysicalType::STRUCT: + case PhysicalType::LIST: + return MarkJoinNested(left, right, lcount, rcount, found_match, comparison_type); + default: + break; + } + D_ASSERT(left.GetType() == right.GetType()); + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: + return MarkJoinSwitch(left, right, lcount, rcount, found_match); + case ExpressionType::COMPARE_NOTEQUAL: + return MarkJoinSwitch(left, right, lcount, rcount, found_match); + case ExpressionType::COMPARE_LESSTHAN: + return MarkJoinSwitch(left, right, lcount, rcount, found_match); + case ExpressionType::COMPARE_GREATERTHAN: + return MarkJoinSwitch(left, right, lcount, rcount, found_match); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return MarkJoinSwitch(left, right, lcount, rcount, found_match); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return MarkJoinSwitch(left, right, lcount, rcount, found_match); + case ExpressionType::COMPARE_DISTINCT_FROM: + return MarkJoinSwitch(left, right, lcount, rcount, found_match); + default: + throw NotImplementedException("Unimplemented comparison type for join!"); + } +} + +void NestedLoopJoinMark::Perform(DataChunk &left, ColumnDataCollection &right, bool found_match[], + const vector &conditions) { + // initialize a new temporary selection vector for the left chunk + // loop over all chunks in the RHS + ColumnDataScanState scan_state; + right.InitializeScan(scan_state); + + DataChunk scan_chunk; + right.InitializeScanChunk(scan_chunk); + + while (right.Scan(scan_state, scan_chunk)) { + for (idx_t i = 0; i < conditions.size(); i++) { + MarkJoinComparisonSwitch(left.data[i], scan_chunk.data[i], left.size(), scan_chunk.size(), found_match, + conditions[i].comparison); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp new file mode 100644 index 00000000..b2ef9478 --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp @@ -0,0 +1,85 @@ +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" + +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" + +namespace duckdb { + +AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_data, idx_t child_count, + idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, + Expression *filter) + : function(std::move(function)), + bind_data_wrapper(bind_data ? make_shared(bind_data->Copy()) : nullptr), + child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), + filter(filter) { +} + +AggregateObject::AggregateObject(BoundAggregateExpression *aggr) + : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), + AlignValue(aggr->function.state_size()), aggr->aggr_type, aggr->return_type.InternalType(), + aggr->filter.get()) { +} + +AggregateObject::AggregateObject(BoundWindowExpression &window) + : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), + AlignValue(window.aggregate->state_size()), AggregateType::NON_DISTINCT, + window.return_type.InternalType(), window.filter_expr.get()) { +} + +vector AggregateObject::CreateAggregateObjects(const vector &bindings) { + vector aggregates; + aggregates.reserve(aggregates.size()); + for (auto &binding : bindings) { + aggregates.emplace_back(binding); + } + return aggregates; +} + +AggregateFilterData::AggregateFilterData(ClientContext &context, Expression &filter_expr, + const vector &payload_types) + : filter_executor(context, &filter_expr), true_sel(STANDARD_VECTOR_SIZE) { + if (payload_types.empty()) { + return; + } + filtered_payload.Initialize(Allocator::Get(context), payload_types); +} + +idx_t AggregateFilterData::ApplyFilter(DataChunk &payload) { + filtered_payload.Reset(); + + auto count = filter_executor.SelectExpression(payload, true_sel); + filtered_payload.Slice(payload, true_sel, count); + return count; +} + +AggregateFilterDataSet::AggregateFilterDataSet() { +} + +void AggregateFilterDataSet::Initialize(ClientContext &context, const vector &aggregates, + const vector &payload_types) { + bool has_filters = false; + for (auto &aggregate : aggregates) { + if (aggregate.filter) { + has_filters = true; + break; + } + } + if (!has_filters) { + // no filters: nothing to do + return; + } + filter_data.resize(aggregates.size()); + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + auto &aggr = aggregates[aggr_idx]; + if (aggr.filter) { + filter_data[aggr_idx] = make_uniq(context, *aggr.filter, payload_types); + } + } +} + +AggregateFilterData &AggregateFilterDataSet::GetFilterData(idx_t aggr_idx) { + D_ASSERT(aggr_idx < filter_data.size()); + D_ASSERT(filter_data[aggr_idx]); + return *filter_data[aggr_idx]; +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp new file mode 100644 index 00000000..cd76deee --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp @@ -0,0 +1,216 @@ +#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +//! Shared information about a collection of distinct aggregates +DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector> &aggregates, + vector indices) + : indices(std::move(indices)), aggregates(aggregates) { + table_count = CreateTableIndexMap(); + + const idx_t aggregate_count = aggregates.size(); + + total_child_count = 0; + for (idx_t i = 0; i < aggregate_count; i++) { + auto &aggregate = aggregates[i]->Cast(); + + if (!aggregate.IsDistinct()) { + continue; + } + total_child_count += aggregate.children.size(); + } +} + +//! Stateful data for the distinct aggregates + +DistinctAggregateState::DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client) + : child_executor(client) { + + radix_states.resize(data.info.table_count); + distinct_output_chunks.resize(data.info.table_count); + + idx_t aggregate_count = data.info.aggregates.size(); + for (idx_t i = 0; i < aggregate_count; i++) { + auto &aggregate = data.info.aggregates[i]->Cast(); + + // Initialize the child executor and get the payload types for every aggregate + for (auto &child : aggregate.children) { + child_executor.AddExpression(*child); + } + if (!aggregate.IsDistinct()) { + continue; + } + D_ASSERT(data.info.table_map.count(i)); + idx_t table_idx = data.info.table_map.at(i); + if (data.radix_tables[table_idx] == nullptr) { + //! This table is unused because the aggregate shares its data with another + continue; + } + + // Get the global sinkstate for the aggregate + auto &radix_table = *data.radix_tables[table_idx]; + radix_states[table_idx] = radix_table.GetGlobalSinkState(client); + + // Fill the chunk_types (group_by + children) + vector chunk_types; + for (auto &group_type : data.grouped_aggregate_data[table_idx]->group_types) { + chunk_types.push_back(group_type); + } + + // This is used in Finalize to get the data from the radix table + distinct_output_chunks[table_idx] = make_uniq(); + distinct_output_chunks[table_idx]->Initialize(client, chunk_types); + } +} + +//! Persistent + shared (read-only) data for the distinct aggregates +DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info) + : DistinctAggregateData(info, {}, nullptr) { +} + +DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups, + const vector> *group_expressions) + : info(info) { + grouped_aggregate_data.resize(info.table_count); + radix_tables.resize(info.table_count); + grouping_sets.resize(info.table_count); + + for (auto &i : info.indices) { + auto &aggregate = info.aggregates[i]->Cast(); + + D_ASSERT(info.table_map.count(i)); + idx_t table_idx = info.table_map.at(i); + if (radix_tables[table_idx] != nullptr) { + //! This aggregate shares a table with another aggregate, and the table is already initialized + continue; + } + // The grouping set contains the indices of the chunk that correspond to the data vector + // that will be used to figure out in which bucket the payload should be put + auto &grouping_set = grouping_sets[table_idx]; + //! Populate the group with the children of the aggregate + for (auto &group : groups) { + grouping_set.insert(group); + } + idx_t group_by_size = group_expressions ? group_expressions->size() : 0; + for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) { + grouping_set.insert(set_idx + group_by_size); + } + // Create the hashtable for the aggregate + grouped_aggregate_data[table_idx] = make_uniq(); + grouped_aggregate_data[table_idx]->InitializeDistinct(info.aggregates[i], group_expressions); + radix_tables[table_idx] = + make_uniq(grouping_set, *grouped_aggregate_data[table_idx]); + + // Fill the chunk_types (only contains the payload of the distinct aggregates) + vector chunk_types; + for (auto &child_p : aggregate.children) { + chunk_types.push_back(child_p->return_type); + } + } +} + +using aggr_ref_t = reference; + +struct FindMatchingAggregate { + explicit FindMatchingAggregate(const aggr_ref_t &aggr) : aggr_r(aggr) { + } + bool operator()(const aggr_ref_t other_r) { + auto &other = other_r.get(); + auto &aggr = aggr_r.get(); + if (other.children.size() != aggr.children.size()) { + return false; + } + if (!Expression::Equals(aggr.filter, other.filter)) { + return false; + } + for (idx_t i = 0; i < aggr.children.size(); i++) { + auto &other_child = other.children[i]->Cast(); + auto &aggr_child = aggr.children[i]->Cast(); + if (other_child.index != aggr_child.index) { + return false; + } + } + return true; + } + const aggr_ref_t aggr_r; +}; + +idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() { + vector table_inputs; + + D_ASSERT(table_map.empty()); + for (auto &agg_idx : indices) { + D_ASSERT(agg_idx < aggregates.size()); + auto &aggregate = aggregates[agg_idx]->Cast(); + + auto matching_inputs = + std::find_if(table_inputs.begin(), table_inputs.end(), FindMatchingAggregate(std::ref(aggregate))); + if (matching_inputs != table_inputs.end()) { + //! Assign the existing table to the aggregate + idx_t found_idx = std::distance(table_inputs.begin(), matching_inputs); + table_map[agg_idx] = found_idx; + continue; + } + //! Create a new table and assign its index to the aggregate + table_map[agg_idx] = table_inputs.size(); + table_inputs.push_back(std::ref(aggregate)); + } + //! Every distinct aggregate needs to be assigned an index + D_ASSERT(table_map.size() == indices.size()); + //! There can not be more tables than there are distinct aggregates + D_ASSERT(table_inputs.size() <= indices.size()); + + return table_inputs.size(); +} + +bool DistinctAggregateCollectionInfo::AnyDistinct() const { + return !indices.empty(); +} + +const unsafe_vector &DistinctAggregateCollectionInfo::Indices() const { + return this->indices; +} + +static vector GetDistinctIndices(vector> &aggregates) { + vector distinct_indices; + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &aggregate = aggregates[i]; + auto &aggr = aggregate->Cast(); + if (aggr.IsDistinct()) { + distinct_indices.push_back(i); + } + } + return distinct_indices; +} + +unique_ptr +DistinctAggregateCollectionInfo::Create(vector> &aggregates) { + vector indices = GetDistinctIndices(aggregates); + if (indices.empty()) { + return nullptr; + } + return make_uniq(aggregates, std::move(indices)); +} + +bool DistinctAggregateData::IsDistinct(idx_t index) const { + bool is_distinct = !radix_tables.empty() && info.table_map.count(index); +#ifdef DEBUG + //! Make sure that if it is distinct, it's also in the indices + //! And if it's not distinct, that it's also not in the indices + bool found = false; + for (auto &idx : info.indices) { + if (idx == index) { + found = true; + break; + } + } + D_ASSERT(found == is_distinct); +#endif + return is_distinct; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp new file mode 100644 index 00000000..d0b52a60 --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp @@ -0,0 +1,96 @@ +#include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp" + +namespace duckdb { + +idx_t GroupedAggregateData::GroupCount() const { + return groups.size(); +} + +const vector> &GroupedAggregateData::GetGroupingFunctions() const { + return grouping_functions; +} + +void GroupedAggregateData::InitializeGroupby(vector> groups, + vector> expressions, + vector> grouping_functions) { + InitializeGroupbyGroups(std::move(groups)); + vector payload_types_filters; + + SetGroupingFunctions(grouping_functions); + + filter_count = 0; + for (auto &expr : expressions) { + D_ASSERT(expr->expression_class == ExpressionClass::BOUND_AGGREGATE); + D_ASSERT(expr->IsAggregate()); + auto &aggr = expr->Cast(); + bindings.push_back(&aggr); + + aggregate_return_types.push_back(aggr.return_type); + for (auto &child : aggr.children) { + payload_types.push_back(child->return_type); + } + if (aggr.filter) { + filter_count++; + payload_types_filters.push_back(aggr.filter->return_type); + } + if (!aggr.function.combine) { + throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); + } + aggregates.push_back(std::move(expr)); + } + for (const auto &pay_filters : payload_types_filters) { + payload_types.push_back(pay_filters); + } +} + +void GroupedAggregateData::InitializeDistinct(const unique_ptr &aggregate, + const vector> *groups_p) { + auto &aggr = aggregate->Cast(); + D_ASSERT(aggr.IsDistinct()); + + // Add the (empty in ungrouped case) groups of the aggregates + InitializeDistinctGroups(groups_p); + + // bindings.push_back(&aggr); + filter_count = 0; + aggregate_return_types.push_back(aggr.return_type); + for (idx_t i = 0; i < aggr.children.size(); i++) { + auto &child = aggr.children[i]; + group_types.push_back(child->return_type); + groups.push_back(child->Copy()); + payload_types.push_back(child->return_type); + if (aggr.filter) { + filter_count++; + } + } + if (!aggr.function.combine) { + throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); + } +} + +void GroupedAggregateData::InitializeDistinctGroups(const vector> *groups_p) { + if (!groups_p) { + return; + } + for (auto &expr : *groups_p) { + group_types.push_back(expr->return_type); + groups.push_back(expr->Copy()); + } +} + +void GroupedAggregateData::InitializeGroupbyGroups(vector> groups) { + // Add all the expressions of the group by clause + for (auto &expr : groups) { + group_types.push_back(expr->return_type); + } + this->groups = std::move(groups); +} + +void GroupedAggregateData::SetGroupingFunctions(vector> &functions) { + grouping_functions.reserve(functions.size()); + for (idx_t i = 0; i < functions.size(); i++) { + grouping_functions.push_back(std::move(functions[i])); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp new file mode 100644 index 00000000..dfbe9c0b --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp @@ -0,0 +1,876 @@ +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/aggregate_hashtable.hpp" +#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +HashAggregateGroupingData::HashAggregateGroupingData(GroupingSet &grouping_set_p, + const GroupedAggregateData &grouped_aggregate_data, + unique_ptr &info) + : table_data(grouping_set_p, grouped_aggregate_data) { + if (info) { + distinct_data = make_uniq(*info, grouping_set_p, &grouped_aggregate_data.groups); + } +} + +bool HashAggregateGroupingData::HasDistinct() const { + return distinct_data != nullptr; +} + +HashAggregateGroupingGlobalState::HashAggregateGroupingGlobalState(const HashAggregateGroupingData &data, + ClientContext &context) { + table_state = data.table_data.GetGlobalSinkState(context); + if (data.HasDistinct()) { + distinct_state = make_uniq(*data.distinct_data, context); + } +} + +HashAggregateGroupingLocalState::HashAggregateGroupingLocalState(const PhysicalHashAggregate &op, + const HashAggregateGroupingData &data, + ExecutionContext &context) { + table_state = data.table_data.GetLocalSinkState(context); + if (!data.HasDistinct()) { + return; + } + auto &distinct_data = *data.distinct_data; + + auto &distinct_indices = op.distinct_collection_info->Indices(); + D_ASSERT(!distinct_indices.empty()); + + distinct_states.resize(op.distinct_collection_info->aggregates.size()); + auto &table_map = op.distinct_collection_info->table_map; + + for (auto &idx : distinct_indices) { + idx_t table_idx = table_map[idx]; + auto &radix_table = distinct_data.radix_tables[table_idx]; + if (radix_table == nullptr) { + // This aggregate has identical input as another aggregate, so no table is created for it + continue; + } + // Initialize the states of the radix tables used for the distinct aggregates + distinct_states[table_idx] = radix_table->GetLocalSinkState(context); + } +} + +static vector CreateGroupChunkTypes(vector> &groups) { + set group_indices; + + if (groups.empty()) { + return {}; + } + + for (auto &group : groups) { + D_ASSERT(group->type == ExpressionType::BOUND_REF); + auto &bound_ref = group->Cast(); + group_indices.insert(bound_ref.index); + } + idx_t highest_index = *group_indices.rbegin(); + vector types(highest_index + 1, LogicalType::SQLNULL); + for (auto &group : groups) { + auto &bound_ref = group->Cast(); + types[bound_ref.index] = bound_ref.return_type; + } + return types; +} + +bool PhysicalHashAggregate::CanSkipRegularSink() const { + if (!filter_indexes.empty()) { + // If we have filters, we can't skip the regular sink, because we might lose groups otherwise. + return false; + } + if (grouped_aggregate_data.aggregates.empty()) { + // When there are no aggregates, we have to add to the main ht right away + return false; + } + if (!non_distinct_filter.empty()) { + return false; + } + return true; +} + +PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, + vector> expressions, idx_t estimated_cardinality) + : PhysicalHashAggregate(context, std::move(types), std::move(expressions), {}, estimated_cardinality) { +} + +PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, + vector> expressions, + vector> groups_p, idx_t estimated_cardinality) + : PhysicalHashAggregate(context, std::move(types), std::move(expressions), std::move(groups_p), {}, {}, + estimated_cardinality) { +} + +PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, + vector> expressions, + vector> groups_p, + vector grouping_sets_p, + vector> grouping_functions_p, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::HASH_GROUP_BY, std::move(types), estimated_cardinality), + grouping_sets(std::move(grouping_sets_p)) { + // get a list of all aggregates to be computed + const idx_t group_count = groups_p.size(); + if (grouping_sets.empty()) { + GroupingSet set; + for (idx_t i = 0; i < group_count; i++) { + set.insert(i); + } + grouping_sets.push_back(std::move(set)); + } + input_group_types = CreateGroupChunkTypes(groups_p); + + grouped_aggregate_data.InitializeGroupby(std::move(groups_p), std::move(expressions), + std::move(grouping_functions_p)); + + auto &aggregates = grouped_aggregate_data.aggregates; + // filter_indexes must be pre-built, not lazily instantiated in parallel... + // Because everything that lives in this class should be read-only at execution time + idx_t aggregate_input_idx = 0; + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &aggregate = aggregates[i]; + auto &aggr = aggregate->Cast(); + aggregate_input_idx += aggr.children.size(); + if (aggr.aggr_type == AggregateType::DISTINCT) { + distinct_filter.push_back(i); + } else if (aggr.aggr_type == AggregateType::NON_DISTINCT) { + non_distinct_filter.push_back(i); + } else { // LCOV_EXCL_START + throw NotImplementedException("AggregateType not implemented in PhysicalHashAggregate"); + } // LCOV_EXCL_STOP + } + + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &aggregate = aggregates[i]; + auto &aggr = aggregate->Cast(); + if (aggr.filter) { + auto &bound_ref_expr = aggr.filter->Cast(); + if (!filter_indexes.count(aggr.filter.get())) { + // Replace the bound reference expression's index with the corresponding index of the payload chunk + filter_indexes[aggr.filter.get()] = bound_ref_expr.index; + bound_ref_expr.index = aggregate_input_idx; + } + aggregate_input_idx++; + } + } + + distinct_collection_info = DistinctAggregateCollectionInfo::Create(grouped_aggregate_data.aggregates); + + for (idx_t i = 0; i < grouping_sets.size(); i++) { + groupings.emplace_back(grouping_sets[i], grouped_aggregate_data, distinct_collection_info); + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class HashAggregateGlobalSinkState : public GlobalSinkState { +public: + HashAggregateGlobalSinkState(const PhysicalHashAggregate &op, ClientContext &context) { + grouping_states.reserve(op.groupings.size()); + for (idx_t i = 0; i < op.groupings.size(); i++) { + auto &grouping = op.groupings[i]; + grouping_states.emplace_back(grouping, context); + } + vector filter_types; + for (auto &aggr : op.grouped_aggregate_data.aggregates) { + auto &aggregate = aggr->Cast(); + for (auto &child : aggregate.children) { + payload_types.push_back(child->return_type); + } + if (aggregate.filter) { + filter_types.push_back(aggregate.filter->return_type); + } + } + payload_types.reserve(payload_types.size() + filter_types.size()); + payload_types.insert(payload_types.end(), filter_types.begin(), filter_types.end()); + } + + vector grouping_states; + vector payload_types; + //! Whether or not the aggregate is finished + bool finished = false; +}; + +class HashAggregateLocalSinkState : public LocalSinkState { +public: + HashAggregateLocalSinkState(const PhysicalHashAggregate &op, ExecutionContext &context) { + + auto &payload_types = op.grouped_aggregate_data.payload_types; + if (!payload_types.empty()) { + aggregate_input_chunk.InitializeEmpty(payload_types); + } + + grouping_states.reserve(op.groupings.size()); + for (auto &grouping : op.groupings) { + grouping_states.emplace_back(op, grouping, context); + } + // The filter set is only needed here for the distinct aggregates + // the filtering of data for the regular aggregates is done within the hashtable + vector aggregate_objects; + for (auto &aggregate : op.grouped_aggregate_data.aggregates) { + auto &aggr = aggregate->Cast(); + aggregate_objects.emplace_back(&aggr); + } + + filter_set.Initialize(context.client, aggregate_objects, payload_types); + } + + DataChunk aggregate_input_chunk; + vector grouping_states; + AggregateFilterDataSet filter_set; +}; + +void PhysicalHashAggregate::SetMultiScan(GlobalSinkState &state) { + auto &gstate = state.Cast(); + for (auto &grouping_state : gstate.grouping_states) { + RadixPartitionedHashTable::SetMultiScan(*grouping_state.table_state); + if (!grouping_state.distinct_state) { + continue; + } + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +unique_ptr PhysicalHashAggregate::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(*this, context); +} + +unique_ptr PhysicalHashAggregate::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(*this, context); +} + +void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, + idx_t grouping_idx) const { + auto &sink = input.local_state.Cast(); + auto &global_sink = input.global_state.Cast(); + + auto &grouping_gstate = global_sink.grouping_states[grouping_idx]; + auto &grouping_lstate = sink.grouping_states[grouping_idx]; + auto &distinct_info = *distinct_collection_info; + + auto &distinct_state = grouping_gstate.distinct_state; + auto &distinct_data = groupings[grouping_idx].distinct_data; + + DataChunk empty_chunk; + + // Create an empty filter for Sink, since we don't need to update any aggregate states here + unsafe_vector empty_filter; + + for (idx_t &idx : distinct_info.indices) { + auto &aggregate = grouped_aggregate_data.aggregates[idx]->Cast(); + + D_ASSERT(distinct_info.table_map.count(idx)); + idx_t table_idx = distinct_info.table_map[idx]; + if (!distinct_data->radix_tables[table_idx]) { + continue; + } + D_ASSERT(distinct_data->radix_tables[table_idx]); + auto &radix_table = *distinct_data->radix_tables[table_idx]; + auto &radix_global_sink = *distinct_state->radix_states[table_idx]; + auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx]; + + InterruptState interrupt_state; + OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, interrupt_state}; + + if (aggregate.filter) { + DataChunk filter_chunk; + auto &filtered_data = sink.filter_set.GetFilterData(idx); + filter_chunk.InitializeEmpty(filtered_data.filtered_payload.GetTypes()); + + // Add the filter Vector (BOOL) + auto it = filter_indexes.find(aggregate.filter.get()); + D_ASSERT(it != filter_indexes.end()); + D_ASSERT(it->second < chunk.data.size()); + auto &filter_bound_ref = aggregate.filter->Cast(); + filter_chunk.data[filter_bound_ref.index].Reference(chunk.data[it->second]); + filter_chunk.SetCardinality(chunk.size()); + + // We cant use the AggregateFilterData::ApplyFilter method, because the chunk we need to + // apply the filter to also has the groups, and the filtered_data.filtered_payload does not have those. + SelectionVector sel_vec(STANDARD_VECTOR_SIZE); + idx_t count = filtered_data.filter_executor.SelectExpression(filter_chunk, sel_vec); + + if (count == 0) { + continue; + } + + // Because the 'input' chunk needs to be re-used after this, we need to create + // a duplicate of it, that we can apply the filter to + DataChunk filtered_input; + filtered_input.InitializeEmpty(chunk.GetTypes()); + + for (idx_t group_idx = 0; group_idx < grouped_aggregate_data.groups.size(); group_idx++) { + auto &group = grouped_aggregate_data.groups[group_idx]; + auto &bound_ref = group->Cast(); + filtered_input.data[bound_ref.index].Reference(chunk.data[bound_ref.index]); + } + for (idx_t child_idx = 0; child_idx < aggregate.children.size(); child_idx++) { + auto &child = aggregate.children[child_idx]; + auto &bound_ref = child->Cast(); + + filtered_input.data[bound_ref.index].Reference(chunk.data[bound_ref.index]); + } + filtered_input.Slice(sel_vec, count); + filtered_input.SetCardinality(count); + + radix_table.Sink(context, filtered_input, sink_input, empty_chunk, empty_filter); + } else { + radix_table.Sink(context, chunk, sink_input, empty_chunk, empty_filter); + } + } +} + +void PhysicalHashAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + for (idx_t i = 0; i < groupings.size(); i++) { + SinkDistinctGrouping(context, chunk, input, i); + } +} + +SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &local_state = input.local_state.Cast(); + auto &global_state = input.global_state.Cast(); + + if (distinct_collection_info) { + SinkDistinct(context, chunk, input); + } + + if (CanSkipRegularSink()) { + return SinkResultType::NEED_MORE_INPUT; + } + + DataChunk &aggregate_input_chunk = local_state.aggregate_input_chunk; + auto &aggregates = grouped_aggregate_data.aggregates; + idx_t aggregate_input_idx = 0; + + // Populate the aggregate child vectors + for (auto &aggregate : aggregates) { + auto &aggr = aggregate->Cast(); + for (auto &child_expr : aggr.children) { + D_ASSERT(child_expr->type == ExpressionType::BOUND_REF); + auto &bound_ref_expr = child_expr->Cast(); + D_ASSERT(bound_ref_expr.index < chunk.data.size()); + aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); + } + } + // Populate the filter vectors + for (auto &aggregate : aggregates) { + auto &aggr = aggregate->Cast(); + if (aggr.filter) { + auto it = filter_indexes.find(aggr.filter.get()); + D_ASSERT(it != filter_indexes.end()); + D_ASSERT(it->second < chunk.data.size()); + aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); + } + } + + aggregate_input_chunk.SetCardinality(chunk.size()); + aggregate_input_chunk.Verify(); + + // For every grouping set there is one radix_table + for (idx_t i = 0; i < groupings.size(); i++) { + auto &grouping_local_state = global_state.grouping_states[i]; + auto &grouping_global_state = local_state.grouping_states[i]; + InterruptState interrupt_state; + OperatorSinkInput sink_input {*grouping_local_state.table_state, *grouping_global_state.table_state, + interrupt_state}; + + auto &grouping = groupings[i]; + auto &table = grouping.table_data; + table.Sink(context, chunk, sink_input, aggregate_input_chunk, non_distinct_filter); + } + + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Combine +//===--------------------------------------------------------------------===// +void PhysicalHashAggregate::CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const { + + auto &global_sink = input.global_state.Cast(); + auto &sink = input.local_state.Cast(); + + if (!distinct_collection_info) { + return; + } + for (idx_t i = 0; i < groupings.size(); i++) { + auto &grouping_gstate = global_sink.grouping_states[i]; + auto &grouping_lstate = sink.grouping_states[i]; + + auto &distinct_data = groupings[i].distinct_data; + auto &distinct_state = grouping_gstate.distinct_state; + + const auto table_count = distinct_data->radix_tables.size(); + for (idx_t table_idx = 0; table_idx < table_count; table_idx++) { + if (!distinct_data->radix_tables[table_idx]) { + continue; + } + auto &radix_table = *distinct_data->radix_tables[table_idx]; + auto &radix_global_sink = *distinct_state->radix_states[table_idx]; + auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx]; + + radix_table.Combine(context, radix_global_sink, radix_local_sink); + } + } +} + +SinkCombineResultType PhysicalHashAggregate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &llstate = input.local_state.Cast(); + + OperatorSinkCombineInput combine_distinct_input {gstate, llstate, input.interrupt_state}; + CombineDistinct(context, combine_distinct_input); + + if (CanSkipRegularSink()) { + return SinkCombineResultType::FINISHED; + } + for (idx_t i = 0; i < groupings.size(); i++) { + auto &grouping_gstate = gstate.grouping_states[i]; + auto &grouping_lstate = llstate.grouping_states[i]; + + auto &grouping = groupings[i]; + auto &table = grouping.table_data; + table.Combine(context, *grouping_gstate.table_state, *grouping_lstate.table_state); + } + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +class HashAggregateFinalizeEvent : public BasePipelineEvent { +public: + //! "Regular" Finalize Event that is scheduled after combining the thread-local distinct HTs + HashAggregateFinalizeEvent(ClientContext &context, Pipeline *pipeline_p, const PhysicalHashAggregate &op_p, + HashAggregateGlobalSinkState &gstate_p) + : BasePipelineEvent(*pipeline_p), context(context), op(op_p), gstate(gstate_p) { + } + +public: + void Schedule() override; + +private: + ClientContext &context; + + const PhysicalHashAggregate &op; + HashAggregateGlobalSinkState &gstate; +}; + +class HashAggregateFinalizeTask : public ExecutorTask { +public: + HashAggregateFinalizeTask(ClientContext &context, Pipeline &pipeline, shared_ptr event_p, + const PhysicalHashAggregate &op, HashAggregateGlobalSinkState &state_p) + : ExecutorTask(pipeline.executor), context(context), pipeline(pipeline), event(std::move(event_p)), op(op), + gstate(state_p) { + } + +public: + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; + +private: + ClientContext &context; + Pipeline &pipeline; + shared_ptr event; + + const PhysicalHashAggregate &op; + HashAggregateGlobalSinkState &gstate; +}; + +void HashAggregateFinalizeEvent::Schedule() { + vector> tasks; + tasks.push_back(make_uniq(context, *pipeline, shared_from_this(), op, gstate)); + D_ASSERT(!tasks.empty()); + SetTasks(std::move(tasks)); +} + +TaskExecutionResult HashAggregateFinalizeTask::ExecuteTask(TaskExecutionMode mode) { + op.FinalizeInternal(pipeline, *event, context, gstate, false); + D_ASSERT(!gstate.finished); + gstate.finished = true; + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; +} + +class HashAggregateDistinctFinalizeEvent : public BasePipelineEvent { +public: + //! Distinct Finalize Event that is scheduled if we have distinct aggregates + HashAggregateDistinctFinalizeEvent(ClientContext &context, Pipeline &pipeline_p, const PhysicalHashAggregate &op_p, + HashAggregateGlobalSinkState &gstate_p) + : BasePipelineEvent(pipeline_p), context(context), op(op_p), gstate(gstate_p) { + } + +public: + void Schedule() override; + void FinishEvent() override; + +private: + void CreateGlobalSources(); + +private: + ClientContext &context; + + const PhysicalHashAggregate &op; + HashAggregateGlobalSinkState &gstate; + +public: + //! The GlobalSourceStates for all the radix tables of the distinct aggregates + vector>> global_source_states; +}; + +class HashAggregateDistinctFinalizeTask : public ExecutorTask { +public: + HashAggregateDistinctFinalizeTask(Pipeline &pipeline, shared_ptr event_p, const PhysicalHashAggregate &op, + HashAggregateGlobalSinkState &state_p) + : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), op(op), gstate(state_p) { + } + +public: + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; + +private: + void AggregateDistinctGrouping(const idx_t grouping_idx); + +private: + Pipeline &pipeline; + shared_ptr event; + + const PhysicalHashAggregate &op; + HashAggregateGlobalSinkState &gstate; +}; + +void HashAggregateDistinctFinalizeEvent::Schedule() { + CreateGlobalSources(); + + const idx_t n_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + vector> tasks; + for (idx_t i = 0; i < n_threads; i++) { + tasks.push_back(make_uniq(*pipeline, shared_from_this(), op, gstate)); + } + SetTasks(std::move(tasks)); +} + +void HashAggregateDistinctFinalizeEvent::CreateGlobalSources() { + auto &aggregates = op.grouped_aggregate_data.aggregates; + global_source_states.reserve(op.groupings.size()); + for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { + auto &grouping = op.groupings[grouping_idx]; + auto &distinct_data = *grouping.distinct_data; + + vector> aggregate_sources; + aggregate_sources.reserve(aggregates.size()); + for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { + auto &aggregate = aggregates[agg_idx]; + auto &aggr = aggregate->Cast(); + + if (!aggr.IsDistinct()) { + aggregate_sources.push_back(nullptr); + continue; + } + D_ASSERT(distinct_data.info.table_map.count(agg_idx)); + + auto table_idx = distinct_data.info.table_map.at(agg_idx); + auto &radix_table_p = distinct_data.radix_tables[table_idx]; + aggregate_sources.push_back(radix_table_p->GetGlobalSourceState(context)); + } + global_source_states.push_back(std::move(aggregate_sources)); + } +} + +void HashAggregateDistinctFinalizeEvent::FinishEvent() { + // Now that everything is added to the main ht, we can actually finalize + auto new_event = make_shared(context, pipeline.get(), op, gstate); + this->InsertEvent(std::move(new_event)); +} + +TaskExecutionResult HashAggregateDistinctFinalizeTask::ExecuteTask(TaskExecutionMode mode) { + for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { + AggregateDistinctGrouping(grouping_idx); + } + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; +} + +void HashAggregateDistinctFinalizeTask::AggregateDistinctGrouping(const idx_t grouping_idx) { + D_ASSERT(op.distinct_collection_info); + auto &info = *op.distinct_collection_info; + + auto &grouping_data = op.groupings[grouping_idx]; + auto &grouping_state = gstate.grouping_states[grouping_idx]; + D_ASSERT(grouping_state.distinct_state); + auto &distinct_state = *grouping_state.distinct_state; + auto &distinct_data = *grouping_data.distinct_data; + + auto &aggregates = info.aggregates; + + // Thread-local contexts + ThreadContext thread_context(executor.context); + ExecutionContext execution_context(executor.context, thread_context, &pipeline); + + // Sink state to sink into global HTs + InterruptState interrupt_state; + auto &global_sink_state = *grouping_state.table_state; + auto local_sink_state = grouping_data.table_data.GetLocalSinkState(execution_context); + OperatorSinkInput sink_input {global_sink_state, *local_sink_state, interrupt_state}; + + // Create a chunk that mimics the 'input' chunk in Sink, for storing the group vectors + DataChunk group_chunk; + if (!op.input_group_types.empty()) { + group_chunk.Initialize(executor.context, op.input_group_types); + } + + auto &groups = op.grouped_aggregate_data.groups; + const idx_t group_by_size = groups.size(); + + DataChunk aggregate_input_chunk; + if (!gstate.payload_types.empty()) { + aggregate_input_chunk.Initialize(executor.context, gstate.payload_types); + } + + auto &finalize_event = event->Cast(); + + idx_t payload_idx; + idx_t next_payload_idx = 0; + for (idx_t agg_idx = 0; agg_idx < op.grouped_aggregate_data.aggregates.size(); agg_idx++) { + auto &aggregate = aggregates[agg_idx]->Cast(); + + // Forward the payload idx + payload_idx = next_payload_idx; + next_payload_idx = payload_idx + aggregate.children.size(); + + // If aggregate is not distinct, skip it + if (!distinct_data.IsDistinct(agg_idx)) { + continue; + } + + D_ASSERT(distinct_data.info.table_map.count(agg_idx)); + const auto &table_idx = distinct_data.info.table_map.at(agg_idx); + auto &radix_table = distinct_data.radix_tables[table_idx]; + + auto &sink = *distinct_state.radix_states[table_idx]; + auto local_source = radix_table->GetLocalSourceState(execution_context); + OperatorSourceInput source_input {*finalize_event.global_source_states[grouping_idx][agg_idx], *local_source, + interrupt_state}; + + // Create a duplicate of the output_chunk, because of multi-threading we cant alter the original + DataChunk output_chunk; + output_chunk.Initialize(executor.context, distinct_state.distinct_output_chunks[table_idx]->GetTypes()); + + // Fetch all the data from the aggregate ht, and Sink it into the main ht + while (true) { + output_chunk.Reset(); + group_chunk.Reset(); + aggregate_input_chunk.Reset(); + + auto res = radix_table->GetData(execution_context, output_chunk, sink, source_input); + if (res == SourceResultType::FINISHED) { + D_ASSERT(output_chunk.size() == 0); + break; + } else if (res == SourceResultType::BLOCKED) { + throw InternalException( + "Unexpected interrupt from radix table GetData in HashAggregateDistinctFinalizeTask"); + } + + auto &grouped_aggregate_data = *distinct_data.grouped_aggregate_data[table_idx]; + for (idx_t group_idx = 0; group_idx < group_by_size; group_idx++) { + auto &group = grouped_aggregate_data.groups[group_idx]; + auto &bound_ref_expr = group->Cast(); + group_chunk.data[bound_ref_expr.index].Reference(output_chunk.data[group_idx]); + } + group_chunk.SetCardinality(output_chunk); + + for (idx_t child_idx = 0; child_idx < grouped_aggregate_data.groups.size() - group_by_size; child_idx++) { + aggregate_input_chunk.data[payload_idx + child_idx].Reference( + output_chunk.data[group_by_size + child_idx]); + } + aggregate_input_chunk.SetCardinality(output_chunk); + + // Sink it into the main ht + grouping_data.table_data.Sink(execution_context, group_chunk, sink_input, aggregate_input_chunk, {agg_idx}); + } + } + grouping_data.table_data.Combine(execution_context, global_sink_state, *local_sink_state); +} + +SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, + GlobalSinkState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + D_ASSERT(distinct_collection_info); + + for (idx_t i = 0; i < groupings.size(); i++) { + auto &grouping = groupings[i]; + auto &distinct_data = *grouping.distinct_data; + auto &distinct_state = *gstate.grouping_states[i].distinct_state; + + for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) { + if (!distinct_data.radix_tables[table_idx]) { + continue; + } + auto &radix_table = distinct_data.radix_tables[table_idx]; + auto &radix_state = *distinct_state.radix_states[table_idx]; + radix_table->Finalize(context, radix_state); + } + } + auto new_event = make_shared(context, pipeline, *this, gstate); + event.InsertEvent(std::move(new_event)); + return SinkFinalizeType::READY; +} + +SinkFinalizeType PhysicalHashAggregate::FinalizeInternal(Pipeline &pipeline, Event &event, ClientContext &context, + GlobalSinkState &gstate_p, bool check_distinct) const { + auto &gstate = gstate_p.Cast(); + + if (check_distinct && distinct_collection_info) { + // There are distinct aggregates + // If these are partitioned those need to be combined first + // Then we Finalize again, skipping this step + return FinalizeDistinct(pipeline, event, context, gstate_p); + } + + for (idx_t i = 0; i < groupings.size(); i++) { + auto &grouping = groupings[i]; + auto &grouping_gstate = gstate.grouping_states[i]; + grouping.table_data.Finalize(context, *grouping_gstate.table_state); + } + return SinkFinalizeType::READY; +} + +SinkFinalizeType PhysicalHashAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + return FinalizeInternal(pipeline, event, context, input.global_state, true); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class HashAggregateGlobalSourceState : public GlobalSourceState { +public: + HashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op) : op(op), state_index(0) { + for (auto &grouping : op.groupings) { + auto &rt = grouping.table_data; + radix_states.push_back(rt.GetGlobalSourceState(context)); + } + } + + const PhysicalHashAggregate &op; + mutex lock; + atomic state_index; + + vector> radix_states; + +public: + idx_t MaxThreads() override { + // If there are no tables, we only need one thread. + if (op.groupings.empty()) { + return 1; + } + + auto &ht_state = op.sink_state->Cast(); + idx_t partitions = 0; + for (size_t sidx = 0; sidx < op.groupings.size(); ++sidx) { + auto &grouping = op.groupings[sidx]; + auto &grouping_gstate = ht_state.grouping_states[sidx]; + partitions += grouping.table_data.NumberOfPartitions(*grouping_gstate.table_state); + } + return MaxValue(1, partitions); + } +}; + +unique_ptr PhysicalHashAggregate::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(context, *this); +} + +class HashAggregateLocalSourceState : public LocalSourceState { +public: + explicit HashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op) { + for (auto &grouping : op.groupings) { + auto &rt = grouping.table_data; + radix_states.push_back(rt.GetLocalSourceState(context)); + } + } + + vector> radix_states; +}; + +unique_ptr PhysicalHashAggregate::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(context, *this); +} + +SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &sink_gstate = sink_state->Cast(); + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + while (true) { + idx_t radix_idx = gstate.state_index; + if (radix_idx >= groupings.size()) { + break; + } + auto &grouping = groupings[radix_idx]; + auto &radix_table = grouping.table_data; + auto &grouping_gstate = sink_gstate.grouping_states[radix_idx]; + + InterruptState interrupt_state; + OperatorSourceInput source_input {*gstate.radix_states[radix_idx], *lstate.radix_states[radix_idx], + interrupt_state}; + auto res = radix_table.GetData(context, chunk, *grouping_gstate.table_state, source_input); + if (chunk.size() != 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } else if (res == SourceResultType::BLOCKED) { + throw InternalException("Unexpectedly Blocked from radix_table"); + } + + // move to the next table + lock_guard l(gstate.lock); + radix_idx++; + if (radix_idx > gstate.state_index) { + // we have not yet worked on the table + // move the global index forwards + gstate.state_index = radix_idx; + } + } + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +string PhysicalHashAggregate::ParamsToString() const { + string result; + auto &groups = grouped_aggregate_data.groups; + auto &aggregates = grouped_aggregate_data.aggregates; + for (idx_t i = 0; i < groups.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += groups[i]->GetName(); + } + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &aggregate = aggregates[i]->Cast(); + if (i > 0 || !groups.empty()) { + result += "\n"; + } + result += aggregates[i]->GetName(); + if (aggregate.filter) { + result += " Filter: " + aggregate.filter->GetName(); + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp new file mode 100644 index 00000000..fe7e6c46 --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp @@ -0,0 +1,224 @@ +#include "duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp" + +#include "duckdb/execution/perfect_aggregate_hashtable.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &context, vector types_p, + vector> aggregates_p, + vector> groups_p, + const vector> &group_stats, + vector required_bits_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::PERFECT_HASH_GROUP_BY, std::move(types_p), estimated_cardinality), + groups(std::move(groups_p)), aggregates(std::move(aggregates_p)), required_bits(std::move(required_bits_p)) { + D_ASSERT(groups.size() == group_stats.size()); + group_minima.reserve(group_stats.size()); + for (auto &stats : group_stats) { + D_ASSERT(stats); + auto &nstats = *stats; + D_ASSERT(NumericStats::HasMin(nstats)); + group_minima.push_back(NumericStats::Min(nstats)); + } + for (auto &expr : groups) { + group_types.push_back(expr->return_type); + } + + vector bindings; + vector payload_types_filters; + for (auto &expr : aggregates) { + D_ASSERT(expr->expression_class == ExpressionClass::BOUND_AGGREGATE); + D_ASSERT(expr->IsAggregate()); + auto &aggr = expr->Cast(); + bindings.push_back(&aggr); + + D_ASSERT(!aggr.IsDistinct()); + D_ASSERT(aggr.function.combine); + for (auto &child : aggr.children) { + payload_types.push_back(child->return_type); + } + if (aggr.filter) { + payload_types_filters.push_back(aggr.filter->return_type); + } + } + for (const auto &pay_filters : payload_types_filters) { + payload_types.push_back(pay_filters); + } + aggregate_objects = AggregateObject::CreateAggregateObjects(bindings); + + // filter_indexes must be pre-built, not lazily instantiated in parallel... + idx_t aggregate_input_idx = 0; + for (auto &aggregate : aggregates) { + auto &aggr = aggregate->Cast(); + aggregate_input_idx += aggr.children.size(); + } + for (auto &aggregate : aggregates) { + auto &aggr = aggregate->Cast(); + if (aggr.filter) { + auto &bound_ref_expr = aggr.filter->Cast(); + auto it = filter_indexes.find(aggr.filter.get()); + if (it == filter_indexes.end()) { + filter_indexes[aggr.filter.get()] = bound_ref_expr.index; + bound_ref_expr.index = aggregate_input_idx++; + } else { + ++aggregate_input_idx; + } + } + } +} + +unique_ptr PhysicalPerfectHashAggregate::CreateHT(Allocator &allocator, + ClientContext &context) const { + return make_uniq(context, allocator, group_types, payload_types, aggregate_objects, + group_minima, required_bits); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class PerfectHashAggregateGlobalState : public GlobalSinkState { +public: + PerfectHashAggregateGlobalState(const PhysicalPerfectHashAggregate &op, ClientContext &context) + : ht(op.CreateHT(Allocator::Get(context), context)) { + } + + //! The lock for updating the global aggregate state + mutex lock; + //! The global aggregate hash table + unique_ptr ht; +}; + +class PerfectHashAggregateLocalState : public LocalSinkState { +public: + PerfectHashAggregateLocalState(const PhysicalPerfectHashAggregate &op, ExecutionContext &context) + : ht(op.CreateHT(Allocator::Get(context.client), context.client)) { + group_chunk.InitializeEmpty(op.group_types); + if (!op.payload_types.empty()) { + aggregate_input_chunk.InitializeEmpty(op.payload_types); + } + } + + //! The local aggregate hash table + unique_ptr ht; + DataChunk group_chunk; + DataChunk aggregate_input_chunk; +}; + +unique_ptr PhysicalPerfectHashAggregate::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(*this, context); +} + +unique_ptr PhysicalPerfectHashAggregate::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(*this, context); +} + +SinkResultType PhysicalPerfectHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + DataChunk &group_chunk = lstate.group_chunk; + DataChunk &aggregate_input_chunk = lstate.aggregate_input_chunk; + + for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { + auto &group = groups[group_idx]; + D_ASSERT(group->type == ExpressionType::BOUND_REF); + auto &bound_ref_expr = group->Cast(); + group_chunk.data[group_idx].Reference(chunk.data[bound_ref_expr.index]); + } + idx_t aggregate_input_idx = 0; + for (auto &aggregate : aggregates) { + auto &aggr = aggregate->Cast(); + for (auto &child_expr : aggr.children) { + D_ASSERT(child_expr->type == ExpressionType::BOUND_REF); + auto &bound_ref_expr = child_expr->Cast(); + aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); + } + } + for (auto &aggregate : aggregates) { + auto &aggr = aggregate->Cast(); + if (aggr.filter) { + auto it = filter_indexes.find(aggr.filter.get()); + D_ASSERT(it != filter_indexes.end()); + aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); + } + } + + group_chunk.SetCardinality(chunk.size()); + + aggregate_input_chunk.SetCardinality(chunk.size()); + + group_chunk.Verify(); + aggregate_input_chunk.Verify(); + D_ASSERT(aggregate_input_chunk.ColumnCount() == 0 || group_chunk.size() == aggregate_input_chunk.size()); + + lstate.ht->AddChunk(group_chunk, aggregate_input_chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Combine +//===--------------------------------------------------------------------===// +SinkCombineResultType PhysicalPerfectHashAggregate::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &lstate = input.local_state.Cast(); + auto &gstate = input.global_state.Cast(); + + lock_guard l(gstate.lock); + gstate.ht->Combine(*lstate.ht); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class PerfectHashAggregateState : public GlobalSourceState { +public: + PerfectHashAggregateState() : ht_scan_position(0) { + } + + //! The current position to scan the HT for output tuples + idx_t ht_scan_position; +}; + +unique_ptr PhysicalPerfectHashAggregate::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(); +} + +SourceResultType PhysicalPerfectHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &state = input.global_state.Cast(); + auto &gstate = sink_state->Cast(); + + gstate.ht->Scan(state.ht_scan_position, chunk); + + if (chunk.size() > 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } else { + return SourceResultType::FINISHED; + } +} + +string PhysicalPerfectHashAggregate::ParamsToString() const { + string result; + for (idx_t i = 0; i < groups.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += groups[i]->GetName(); + } + for (idx_t i = 0; i < aggregates.size(); i++) { + if (i > 0 || !groups.empty()) { + result += "\n"; + } + result += aggregates[i]->GetName(); + auto &aggregate = aggregates[i]->Cast(); + if (aggregate.filter) { + result += " Filter: " + aggregate.filter->GetName(); + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp new file mode 100644 index 00000000..8b2b482c --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp @@ -0,0 +1,220 @@ +#include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" + +namespace duckdb { + +PhysicalStreamingWindow::PhysicalStreamingWindow(vector types, vector> select_list, + idx_t estimated_cardinality, PhysicalOperatorType type) + : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { +} + +class StreamingWindowGlobalState : public GlobalOperatorState { +public: + StreamingWindowGlobalState() : row_number(1) { + } + + //! The next row number. + std::atomic row_number; +}; + +class StreamingWindowState : public OperatorState { +public: + using StateBuffer = vector; + + StreamingWindowState() + : initialized(false), allocator(Allocator::DefaultAllocator()), + statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)) { + } + + ~StreamingWindowState() override { + for (size_t i = 0; i < aggregate_dtors.size(); ++i) { + auto dtor = aggregate_dtors[i]; + if (dtor) { + AggregateInputData aggr_input_data(aggregate_bind_data[i], allocator); + state_ptr = aggregate_states[i].data(); + dtor(statev, aggr_input_data, 1); + } + } + } + + void Initialize(ClientContext &context, DataChunk &input, const vector> &expressions) { + const_vectors.resize(expressions.size()); + aggregate_states.resize(expressions.size()); + aggregate_bind_data.resize(expressions.size(), nullptr); + aggregate_dtors.resize(expressions.size(), nullptr); + + for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { + auto &expr = *expressions[expr_idx]; + auto &wexpr = expr.Cast(); + switch (expr.GetExpressionType()) { + case ExpressionType::WINDOW_AGGREGATE: { + auto &aggregate = *wexpr.aggregate; + auto &state = aggregate_states[expr_idx]; + aggregate_bind_data[expr_idx] = wexpr.bind_info.get(); + aggregate_dtors[expr_idx] = aggregate.destructor; + state.resize(aggregate.state_size()); + aggregate.initialize(state.data()); + break; + } + case ExpressionType::WINDOW_FIRST_VALUE: { + // Just execute the expression once + ExpressionExecutor executor(context); + executor.AddExpression(*wexpr.children[0]); + DataChunk result; + result.Initialize(Allocator::Get(context), {wexpr.children[0]->return_type}); + executor.Execute(input, result); + + const_vectors[expr_idx] = make_uniq(result.GetValue(0, 0)); + break; + } + case ExpressionType::WINDOW_PERCENT_RANK: { + const_vectors[expr_idx] = make_uniq(Value((double)0)); + break; + } + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: { + const_vectors[expr_idx] = make_uniq(Value((int64_t)1)); + break; + } + default: + break; + } + } + initialized = true; + } + +public: + bool initialized; + vector> const_vectors; + ArenaAllocator allocator; + + // Aggregation + vector aggregate_states; + vector aggregate_bind_data; + vector aggregate_dtors; + data_ptr_t state_ptr; + Vector statev; +}; + +unique_ptr PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const { + return make_uniq(); +} + +unique_ptr PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const { + return make_uniq(); +} + +OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate_p, OperatorState &state_p) const { + auto &gstate = gstate_p.Cast(); + auto &state = state_p.Cast(); + state.allocator.Reset(); + + if (!state.initialized) { + state.Initialize(context.client, input, select_list); + } + // Put payload columns in place + for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) { + chunk.data[col_idx].Reference(input.data[col_idx]); + } + // Compute window function + const idx_t count = input.size(); + for (idx_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) { + idx_t col_idx = input.data.size() + expr_idx; + auto &expr = *select_list[expr_idx]; + auto &result = chunk.data[col_idx]; + switch (expr.GetExpressionType()) { + case ExpressionType::WINDOW_AGGREGATE: { + // Establish the aggregation environment + auto &wexpr = expr.Cast(); + auto &aggregate = *wexpr.aggregate; + auto &statev = state.statev; + state.state_ptr = state.aggregate_states[expr_idx].data(); + AggregateInputData aggr_input_data(wexpr.bind_info.get(), state.allocator); + + // Check for COUNT(*) + if (wexpr.children.empty()) { + D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); + auto data = FlatVector::GetData(result); + int64_t start_row = gstate.row_number; + for (idx_t i = 0; i < input.size(); ++i) { + data[i] = start_row + i; + } + break; + } + + // Compute the arguments + auto &allocator = Allocator::Get(context.client); + ExpressionExecutor executor(context.client); + vector payload_types; + for (auto &child : wexpr.children) { + payload_types.push_back(child->return_type); + executor.AddExpression(*child); + } + + DataChunk payload; + payload.Initialize(allocator, payload_types); + executor.Execute(input, payload); + + // Iterate through them using a single SV + payload.Flatten(); + DataChunk row; + row.Initialize(allocator, payload_types); + sel_t s = 0; + SelectionVector sel(&s); + row.Slice(sel, 1); + for (size_t col_idx = 0; col_idx < payload.ColumnCount(); ++col_idx) { + DictionaryVector::Child(row.data[col_idx]).Reference(payload.data[col_idx]); + } + + // Update the state and finalize it one row at a time. + for (idx_t i = 0; i < input.size(); ++i) { + sel.set_index(0, i); + aggregate.update(row.data.data(), aggr_input_data, row.ColumnCount(), statev, 1); + aggregate.finalize(statev, aggr_input_data, result, 1, i); + } + break; + } + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: { + // Reference constant vector + chunk.data[col_idx].Reference(*state.const_vectors[expr_idx]); + break; + } + case ExpressionType::WINDOW_ROW_NUMBER: { + // Set row numbers + int64_t start_row = gstate.row_number; + auto rdata = FlatVector::GetData(chunk.data[col_idx]); + for (idx_t i = 0; i < count; i++) { + rdata[i] = start_row + i; + } + break; + } + default: + throw NotImplementedException("%s for StreamingWindow", ExpressionTypeToString(expr.GetExpressionType())); + } + } + gstate.row_number += count; + chunk.SetCardinality(count); + return OperatorResultType::NEED_MORE_INPUT; +} + +string PhysicalStreamingWindow::ParamsToString() const { + string result; + for (idx_t i = 0; i < select_list.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += select_list[i]->GetName(); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp new file mode 100644 index 00000000..8db0ede3 --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -0,0 +1,633 @@ +#include "duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" +#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" +#include "duckdb/execution/radix_partitioned_hashtable.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +#include + +namespace duckdb { + +PhysicalUngroupedAggregate::PhysicalUngroupedAggregate(vector types, + vector> expressions, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::UNGROUPED_AGGREGATE, std::move(types), estimated_cardinality), + aggregates(std::move(expressions)) { + + distinct_collection_info = DistinctAggregateCollectionInfo::Create(aggregates); + if (!distinct_collection_info) { + return; + } + distinct_data = make_uniq(*distinct_collection_info); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +struct AggregateState { + explicit AggregateState(const vector> &aggregate_expressions) { + counts = make_uniq_array>(aggregate_expressions.size()); + for (idx_t i = 0; i < aggregate_expressions.size(); i++) { + auto &aggregate = aggregate_expressions[i]; + D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); + auto &aggr = aggregate->Cast(); + auto state = make_unsafe_uniq_array(aggr.function.state_size()); + aggr.function.initialize(state.get()); + aggregates.push_back(std::move(state)); + bind_data.push_back(aggr.bind_info.get()); + destructors.push_back(aggr.function.destructor); +#ifdef DEBUG + counts[i] = 0; +#endif + } + } + ~AggregateState() { + D_ASSERT(destructors.size() == aggregates.size()); + for (idx_t i = 0; i < destructors.size(); i++) { + if (!destructors[i]) { + continue; + } + Vector state_vector(Value::POINTER(CastPointerToValue(aggregates[i].get()))); + state_vector.SetVectorType(VectorType::FLAT_VECTOR); + + ArenaAllocator allocator(Allocator::DefaultAllocator()); + AggregateInputData aggr_input_data(bind_data[i], allocator); + destructors[i](state_vector, aggr_input_data, 1); + } + } + + void Move(AggregateState &other) { + other.aggregates = std::move(aggregates); + other.destructors = std::move(destructors); + } + + //! The aggregate values + vector> aggregates; + //! The bind data + vector bind_data; + //! The destructors + vector destructors; + //! Counts (used for verification) + unique_array> counts; +}; + +class UngroupedAggregateGlobalSinkState : public GlobalSinkState { +public: + UngroupedAggregateGlobalSinkState(const PhysicalUngroupedAggregate &op, ClientContext &client) + : state(op.aggregates), finished(false), allocator(BufferAllocator::Get(client)) { + if (op.distinct_data) { + distinct_state = make_uniq(*op.distinct_data, client); + } + } + + //! The lock for updating the global aggregate state + mutex lock; + //! The global aggregate state + AggregateState state; + //! Whether or not the aggregate is finished + bool finished; + //! The data related to the distinct aggregates (if there are any) + unique_ptr distinct_state; + //! Global arena allocator + ArenaAllocator allocator; +}; + +class UngroupedAggregateLocalSinkState : public LocalSinkState { +public: + UngroupedAggregateLocalSinkState(const PhysicalUngroupedAggregate &op, const vector &child_types, + GlobalSinkState &gstate_p, ExecutionContext &context) + : allocator(BufferAllocator::Get(context.client)), state(op.aggregates), child_executor(context.client), + aggregate_input_chunk(), filter_set() { + auto &gstate = gstate_p.Cast(); + + auto &allocator = BufferAllocator::Get(context.client); + InitializeDistinctAggregates(op, gstate, context); + + vector payload_types; + vector aggregate_objects; + for (auto &aggregate : op.aggregates) { + D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); + auto &aggr = aggregate->Cast(); + // initialize the payload chunk + for (auto &child : aggr.children) { + payload_types.push_back(child->return_type); + child_executor.AddExpression(*child); + } + aggregate_objects.emplace_back(&aggr); + } + if (!payload_types.empty()) { // for select count(*) from t; there is no payload at all + aggregate_input_chunk.Initialize(allocator, payload_types); + } + filter_set.Initialize(context.client, aggregate_objects, child_types); + } + + //! Local arena allocator + ArenaAllocator allocator; + //! The local aggregate state + AggregateState state; + //! The executor + ExpressionExecutor child_executor; + //! The payload chunk, containing all the Vectors for the aggregates + DataChunk aggregate_input_chunk; + //! Aggregate filter data set + AggregateFilterDataSet filter_set; + //! The local sink states of the distinct aggregates hash tables + vector> radix_states; + +public: + void Reset() { + aggregate_input_chunk.Reset(); + } + void InitializeDistinctAggregates(const PhysicalUngroupedAggregate &op, + const UngroupedAggregateGlobalSinkState &gstate, ExecutionContext &context) { + + if (!op.distinct_data) { + return; + } + auto &data = *op.distinct_data; + auto &state = *gstate.distinct_state; + D_ASSERT(!data.radix_tables.empty()); + + const idx_t aggregate_count = state.radix_states.size(); + radix_states.resize(aggregate_count); + + auto &distinct_info = *op.distinct_collection_info; + + for (auto &idx : distinct_info.indices) { + idx_t table_idx = distinct_info.table_map[idx]; + if (data.radix_tables[table_idx] == nullptr) { + // This aggregate has identical input as another aggregate, so no table is created for it + continue; + } + auto &radix_table = *data.radix_tables[table_idx]; + radix_states[table_idx] = radix_table.GetLocalSinkState(context); + } + } +}; + +bool PhysicalUngroupedAggregate::SinkOrderDependent() const { + for (auto &expr : aggregates) { + auto &aggr = expr->Cast(); + if (aggr.function.order_dependent == AggregateOrderDependent::ORDER_DEPENDENT) { + return true; + } + } + return false; +} + +unique_ptr PhysicalUngroupedAggregate::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(*this, context); +} + +unique_ptr PhysicalUngroupedAggregate::GetLocalSinkState(ExecutionContext &context) const { + D_ASSERT(sink_state); + auto &gstate = *sink_state; + return make_uniq(*this, children[0]->GetTypes(), gstate, context); +} + +void PhysicalUngroupedAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &sink = input.local_state.Cast(); + auto &global_sink = input.global_state.Cast(); + D_ASSERT(distinct_data); + auto &distinct_state = *global_sink.distinct_state; + auto &distinct_info = *distinct_collection_info; + auto &distinct_indices = distinct_info.Indices(); + + DataChunk empty_chunk; + + auto &distinct_filter = distinct_info.Indices(); + + for (auto &idx : distinct_indices) { + auto &aggregate = aggregates[idx]->Cast(); + + idx_t table_idx = distinct_info.table_map[idx]; + if (!distinct_data->radix_tables[table_idx]) { + // This distinct aggregate shares its data with another + continue; + } + D_ASSERT(distinct_data->radix_tables[table_idx]); + auto &radix_table = *distinct_data->radix_tables[table_idx]; + auto &radix_global_sink = *distinct_state.radix_states[table_idx]; + auto &radix_local_sink = *sink.radix_states[table_idx]; + OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, input.interrupt_state}; + + if (aggregate.filter) { + // The hashtable can apply a filter, but only on the payload + // And in our case, we need to filter the groups (the distinct aggr children) + + // Apply the filter before inserting into the hashtable + auto &filtered_data = sink.filter_set.GetFilterData(idx); + idx_t count = filtered_data.ApplyFilter(chunk); + filtered_data.filtered_payload.SetCardinality(count); + + radix_table.Sink(context, filtered_data.filtered_payload, sink_input, empty_chunk, distinct_filter); + } else { + radix_table.Sink(context, chunk, sink_input, empty_chunk, distinct_filter); + } + } +} + +SinkResultType PhysicalUngroupedAggregate::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &sink = input.local_state.Cast(); + + // perform the aggregation inside the local state + sink.Reset(); + + if (distinct_data) { + SinkDistinct(context, chunk, input); + } + + DataChunk &payload_chunk = sink.aggregate_input_chunk; + + idx_t payload_idx = 0; + idx_t next_payload_idx = 0; + + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + auto &aggregate = aggregates[aggr_idx]->Cast(); + + payload_idx = next_payload_idx; + next_payload_idx = payload_idx + aggregate.children.size(); + + if (aggregate.IsDistinct()) { + continue; + } + + idx_t payload_cnt = 0; + // resolve the filter (if any) + if (aggregate.filter) { + auto &filtered_data = sink.filter_set.GetFilterData(aggr_idx); + auto count = filtered_data.ApplyFilter(chunk); + + sink.child_executor.SetChunk(filtered_data.filtered_payload); + payload_chunk.SetCardinality(count); + } else { + sink.child_executor.SetChunk(chunk); + payload_chunk.SetCardinality(chunk); + } + +#ifdef DEBUG + sink.state.counts[aggr_idx] += payload_chunk.size(); +#endif + + // resolve the child expressions of the aggregate (if any) + for (idx_t i = 0; i < aggregate.children.size(); ++i) { + sink.child_executor.ExecuteExpression(payload_idx + payload_cnt, + payload_chunk.data[payload_idx + payload_cnt]); + payload_cnt++; + } + + auto start_of_input = payload_cnt == 0 ? nullptr : &payload_chunk.data[payload_idx]; + AggregateInputData aggr_input_data(aggregate.bind_info.get(), sink.allocator); + aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, + sink.state.aggregates[aggr_idx].get(), payload_chunk.size()); + } + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Combine +//===--------------------------------------------------------------------===// +void PhysicalUngroupedAggregate::CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + if (!distinct_data) { + return; + } + auto &distinct_state = gstate.distinct_state; + auto table_count = distinct_data->radix_tables.size(); + for (idx_t table_idx = 0; table_idx < table_count; table_idx++) { + D_ASSERT(distinct_data->radix_tables[table_idx]); + auto &radix_table = *distinct_data->radix_tables[table_idx]; + auto &radix_global_sink = *distinct_state->radix_states[table_idx]; + auto &radix_local_sink = *lstate.radix_states[table_idx]; + + radix_table.Combine(context, radix_global_sink, radix_local_sink); + } +} + +SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + D_ASSERT(!gstate.finished); + + // finalize: combine the local state into the global state + // all aggregates are combinable: we might be doing a parallel aggregate + // use the combine method to combine the partial aggregates + OperatorSinkCombineInput distinct_input {gstate, lstate, input.interrupt_state}; + CombineDistinct(context, distinct_input); + + lock_guard glock(gstate.lock); + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + auto &aggregate = aggregates[aggr_idx]->Cast(); + + if (aggregate.IsDistinct()) { + continue; + } + + Vector source_state(Value::POINTER(CastPointerToValue(lstate.state.aggregates[aggr_idx].get()))); + Vector dest_state(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); + + AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator); + aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); +#ifdef DEBUG + gstate.state.counts[aggr_idx] += lstate.state.counts[aggr_idx]; +#endif + } + lstate.allocator.Destroy(); + + auto &client_profiler = QueryProfiler::Get(context.client); + context.thread.profiler.Flush(*this, lstate.child_executor, "child_executor", 0); + client_profiler.Flush(context.thread.profiler); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +class UngroupedDistinctAggregateFinalizeEvent : public BasePipelineEvent { +public: + UngroupedDistinctAggregateFinalizeEvent(ClientContext &context, const PhysicalUngroupedAggregate &op_p, + UngroupedAggregateGlobalSinkState &gstate_p, Pipeline &pipeline_p) + : BasePipelineEvent(pipeline_p), context(context), op(op_p), gstate(gstate_p), tasks_scheduled(0), + tasks_done(0) { + } + +public: + void Schedule() override; + +private: + ClientContext &context; + + const PhysicalUngroupedAggregate &op; + UngroupedAggregateGlobalSinkState &gstate; + +public: + mutex lock; + idx_t tasks_scheduled; + idx_t tasks_done; + + vector> global_source_states; +}; + +class UngroupedDistinctAggregateFinalizeTask : public ExecutorTask { +public: + UngroupedDistinctAggregateFinalizeTask(Executor &executor, shared_ptr event_p, + const PhysicalUngroupedAggregate &op, + UngroupedAggregateGlobalSinkState &state_p) + : ExecutorTask(executor), event(std::move(event_p)), op(op), gstate(state_p), + allocator(BufferAllocator::Get(executor.context)) { + } + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; + +private: + void AggregateDistinct(); + +private: + shared_ptr event; + + const PhysicalUngroupedAggregate &op; + UngroupedAggregateGlobalSinkState &gstate; + + ArenaAllocator allocator; +}; + +void UngroupedDistinctAggregateFinalizeEvent::Schedule() { + D_ASSERT(gstate.distinct_state); + auto &aggregates = op.aggregates; + auto &distinct_data = *op.distinct_data; + + idx_t payload_idx = 0; + idx_t next_payload_idx = 0; + for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { + auto &aggregate = aggregates[agg_idx]->Cast(); + + // Forward the payload idx + payload_idx = next_payload_idx; + next_payload_idx = payload_idx + aggregate.children.size(); + + // If aggregate is not distinct, skip it + if (!distinct_data.IsDistinct(agg_idx)) { + global_source_states.push_back(nullptr); + continue; + } + D_ASSERT(distinct_data.info.table_map.count(agg_idx)); + + // Create global state for scanning + auto table_idx = distinct_data.info.table_map.at(agg_idx); + auto &radix_table_p = *distinct_data.radix_tables[table_idx]; + global_source_states.push_back(radix_table_p.GetGlobalSourceState(context)); + } + + const idx_t n_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + vector> tasks; + for (idx_t i = 0; i < n_threads; i++) { + tasks.push_back( + make_uniq(pipeline->executor, shared_from_this(), op, gstate)); + tasks_scheduled++; + } + SetTasks(std::move(tasks)); +} + +TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::ExecuteTask(TaskExecutionMode mode) { + AggregateDistinct(); + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; +} + +void UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() { + D_ASSERT(gstate.distinct_state); + auto &distinct_state = *gstate.distinct_state; + auto &distinct_data = *op.distinct_data; + + // Create thread-local copy of aggregate state + auto &aggregates = op.aggregates; + AggregateState state(aggregates); + + // Thread-local contexts + ThreadContext thread_context(executor.context); + ExecutionContext execution_context(executor.context, thread_context, nullptr); + + auto &finalize_event = event->Cast(); + + // Now loop through the distinct aggregates, scanning the distinct HTs + idx_t payload_idx = 0; + idx_t next_payload_idx = 0; + for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { + auto &aggregate = aggregates[agg_idx]->Cast(); + + // Forward the payload idx + payload_idx = next_payload_idx; + next_payload_idx = payload_idx + aggregate.children.size(); + + // If aggregate is not distinct, skip it + if (!distinct_data.IsDistinct(agg_idx)) { + continue; + } + + const auto table_idx = distinct_data.info.table_map.at(agg_idx); + auto &radix_table = *distinct_data.radix_tables[table_idx]; + auto lstate = radix_table.GetLocalSourceState(execution_context); + + auto &sink = *distinct_state.radix_states[table_idx]; + InterruptState interrupt_state; + OperatorSourceInput source_input {*finalize_event.global_source_states[agg_idx], *lstate, interrupt_state}; + + DataChunk output_chunk; + output_chunk.Initialize(executor.context, distinct_state.distinct_output_chunks[table_idx]->GetTypes()); + + DataChunk payload_chunk; + payload_chunk.InitializeEmpty(distinct_data.grouped_aggregate_data[table_idx]->group_types); + payload_chunk.SetCardinality(0); + + AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); + while (true) { + output_chunk.Reset(); + + auto res = radix_table.GetData(execution_context, output_chunk, sink, source_input); + if (res == SourceResultType::FINISHED) { + D_ASSERT(output_chunk.size() == 0); + break; + } else if (res == SourceResultType::BLOCKED) { + throw InternalException( + "Unexpected interrupt from radix table GetData in UngroupedDistinctAggregateFinalizeTask"); + } + + // We dont need to resolve the filter, we already did this in Sink + idx_t payload_cnt = aggregate.children.size(); + for (idx_t i = 0; i < payload_cnt; i++) { + payload_chunk.data[i].Reference(output_chunk.data[i]); + } + payload_chunk.SetCardinality(output_chunk); + +#ifdef DEBUG + gstate.state.counts[agg_idx] += payload_chunk.size(); +#endif + + // Update the aggregate state + auto start_of_input = payload_cnt ? &payload_chunk.data[0] : nullptr; + aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, + state.aggregates[agg_idx].get(), payload_chunk.size()); + } + } + + // After scanning the distinct HTs, we can combine the thread-local agg states with the thread-global + lock_guard guard(finalize_event.lock); + payload_idx = 0; + next_payload_idx = 0; + for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { + if (!distinct_data.IsDistinct(agg_idx)) { + continue; + } + + auto &aggregate = aggregates[agg_idx]->Cast(); + AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); + + Vector state_vec(Value::POINTER(CastPointerToValue(state.aggregates[agg_idx].get()))); + Vector combined_vec(Value::POINTER(CastPointerToValue(gstate.state.aggregates[agg_idx].get()))); + aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1); + } + + D_ASSERT(!gstate.finished); + if (++finalize_event.tasks_done == finalize_event.tasks_scheduled) { + gstate.finished = true; + } +} + +SinkFinalizeType PhysicalUngroupedAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, + GlobalSinkState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + D_ASSERT(distinct_data); + auto &distinct_state = *gstate.distinct_state; + + for (idx_t table_idx = 0; table_idx < distinct_data->radix_tables.size(); table_idx++) { + auto &radix_table_p = distinct_data->radix_tables[table_idx]; + auto &radix_state = *distinct_state.radix_states[table_idx]; + radix_table_p->Finalize(context, radix_state); + } + auto new_event = make_shared(context, *this, gstate, pipeline); + event.InsertEvent(std::move(new_event)); + return SinkFinalizeType::READY; +} + +SinkFinalizeType PhysicalUngroupedAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + + if (distinct_data) { + return FinalizeDistinct(pipeline, event, context, input.global_state); + } + + D_ASSERT(!gstate.finished); + gstate.finished = true; + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +void VerifyNullHandling(DataChunk &chunk, AggregateState &state, const vector> &aggregates) { +#ifdef DEBUG + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + auto &aggr = aggregates[aggr_idx]->Cast(); + if (state.counts[aggr_idx] == 0 && aggr.function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { + // Default is when 0 values go in, NULL comes out + UnifiedVectorFormat vdata; + chunk.data[aggr_idx].ToUnifiedFormat(1, vdata); + D_ASSERT(!vdata.validity.RowIsValid(vdata.sel->get_index(0))); + } + } +#endif +} + +SourceResultType PhysicalUngroupedAggregate::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gstate = sink_state->Cast(); + D_ASSERT(gstate.finished); + + // initialize the result chunk with the aggregate values + chunk.SetCardinality(1); + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + auto &aggregate = aggregates[aggr_idx]->Cast(); + + Vector state_vector(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); + AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator); + aggregate.function.finalize(state_vector, aggr_input_data, chunk.data[aggr_idx], 1, 0); + } + VerifyNullHandling(chunk, gstate.state, aggregates); + + return SourceResultType::FINISHED; +} + +string PhysicalUngroupedAggregate::ParamsToString() const { + string result; + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &aggregate = aggregates[i]->Cast(); + if (i > 0) { + result += "\n"; + } + result += aggregates[i]->GetName(); + if (aggregate.filter) { + result += " Filter: " + aggregate.filter->GetName(); + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp new file mode 100644 index 00000000..1b7d039f --- /dev/null +++ b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp @@ -0,0 +1,702 @@ +#include "duckdb/execution/operator/aggregate/physical_window.hpp" + +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/sort/partition_state.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/types/column/column_data_consumer.hpp" +#include "duckdb/common/types/row/row_data_collection_scanner.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/windows_undefs.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/window_executor.hpp" +#include "duckdb/execution/window_segment_tree.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" + +#include +#include +#include + +namespace duckdb { + +// Global sink state +class WindowGlobalSinkState : public GlobalSinkState { +public: + WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context) + : op(op), mode(DBConfig::GetConfig(context).options.window_mode) { + + D_ASSERT(op.select_list[0]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); + auto &wexpr = op.select_list[0]->Cast(); + + global_partition = + make_uniq(context, wexpr.partitions, wexpr.orders, op.children[0]->types, + wexpr.partitions_stats, op.estimated_cardinality); + } + + const PhysicalWindow &op; + unique_ptr global_partition; + WindowAggregationMode mode; +}; + +// Per-thread sink state +class WindowLocalSinkState : public LocalSinkState { +public: + WindowLocalSinkState(ClientContext &context, const WindowGlobalSinkState &gstate) + : local_partition(context, *gstate.global_partition) { + } + + void Sink(DataChunk &input_chunk) { + local_partition.Sink(input_chunk); + } + + void Combine() { + local_partition.Combine(); + } + + PartitionLocalSinkState local_partition; +}; + +// this implements a sorted window functions variant +PhysicalWindow::PhysicalWindow(vector types, vector> select_list_p, + idx_t estimated_cardinality, PhysicalOperatorType type) + : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list_p)) { + is_order_dependent = false; + for (auto &expr : select_list) { + D_ASSERT(expr->expression_class == ExpressionClass::BOUND_WINDOW); + auto &bound_window = expr->Cast(); + if (bound_window.partitions.empty() && bound_window.orders.empty()) { + is_order_dependent = true; + } + } +} + +static unique_ptr WindowExecutorFactory(BoundWindowExpression &wexpr, ClientContext &context, + const ValidityMask &partition_mask, + const ValidityMask &order_mask, const idx_t payload_count, + WindowAggregationMode mode) { + switch (wexpr.type) { + case ExpressionType::WINDOW_AGGREGATE: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask, mode); + case ExpressionType::WINDOW_ROW_NUMBER: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_RANK_DENSE: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_RANK: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_PERCENT_RANK: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_CUME_DIST: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_NTILE: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_LEAD: + case ExpressionType::WINDOW_LAG: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_FIRST_VALUE: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_LAST_VALUE: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + case ExpressionType::WINDOW_NTH_VALUE: + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + break; + default: + throw InternalException("Window aggregate type %s", ExpressionTypeToString(wexpr.type)); + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +SinkResultType PhysicalWindow::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + + lstate.Sink(chunk); + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalWindow::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &lstate = input.local_state.Cast(); + lstate.Combine(); + + return SinkCombineResultType::FINISHED; +} + +unique_ptr PhysicalWindow::GetLocalSinkState(ExecutionContext &context) const { + auto &gstate = sink_state->Cast(); + return make_uniq(context.client, gstate); +} + +unique_ptr PhysicalWindow::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(*this, context); +} + +SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &state = input.global_state.Cast(); + + // Did we get any data? + if (!state.global_partition->count) { + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // Do we have any sorting to schedule? + if (state.global_partition->rows) { + D_ASSERT(!state.global_partition->grouping_data); + return state.global_partition->rows->count ? SinkFinalizeType::READY : SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // Find the first group to sort + if (!state.global_partition->HasMergeTasks()) { + // Empty input! + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // Schedule all the sorts for maximum thread utilisation + auto new_event = make_shared(*state.global_partition, pipeline); + event.InsertEvent(std::move(new_event)); + + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class WindowPartitionSourceState; + +class WindowGlobalSourceState : public GlobalSourceState { +public: + using HashGroupSourcePtr = unique_ptr; + using ScannerPtr = unique_ptr; + using Task = std::pair; + + WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p); + + //! Get the next task + Task NextTask(idx_t hash_bin); + + //! Context for executing computations + ClientContext &context; + //! All the sunk data + WindowGlobalSinkState &gsink; + //! The next group to build. + atomic next_build; + //! The built groups + vector built; + //! Serialise access to the built hash groups + mutable mutex built_lock; + //! The number of unfinished tasks + atomic tasks_remaining; + +public: + idx_t MaxThreads() override { + return tasks_remaining; + } + +private: + Task CreateTask(idx_t hash_bin); + Task StealWork(); +}; + +WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p) + : context(context_p), gsink(gsink_p), next_build(0), tasks_remaining(0) { + auto &hash_groups = gsink.global_partition->hash_groups; + + auto &gpart = gsink.global_partition; + if (hash_groups.empty()) { + // OVER() + built.resize(1); + if (gpart->rows) { + tasks_remaining += gpart->rows->blocks.size(); + } + } else { + built.resize(hash_groups.size()); + idx_t batch_base = 0; + for (auto &hash_group : hash_groups) { + if (!hash_group) { + continue; + } + auto &global_sort_state = *hash_group->global_sort; + if (global_sort_state.sorted_blocks.empty()) { + continue; + } + + D_ASSERT(global_sort_state.sorted_blocks.size() == 1); + auto &sb = *global_sort_state.sorted_blocks[0]; + auto &sd = *sb.payload_data; + tasks_remaining += sd.data_blocks.size(); + + hash_group->batch_base = batch_base; + batch_base += sd.data_blocks.size(); + } + } +} + +// Per-bin evaluation state (build and evaluate) +class WindowPartitionSourceState { +public: + using HashGroupPtr = unique_ptr; + using ExecutorPtr = unique_ptr; + using Executors = vector; + + WindowPartitionSourceState(ClientContext &context, WindowGlobalSourceState &gsource) + : context(context), op(gsource.gsink.op), gsource(gsource), read_block_idx(0), unscanned(0) { + layout.Initialize(gsource.gsink.global_partition->payload_types); + } + + unique_ptr GetScanner() const; + void MaterializeSortedData(); + void BuildPartition(WindowGlobalSinkState &gstate, const idx_t hash_bin); + + ClientContext &context; + const PhysicalWindow &op; + WindowGlobalSourceState &gsource; + + HashGroupPtr hash_group; + //! The generated input chunks + unique_ptr rows; + unique_ptr heap; + RowLayout layout; + //! The partition boundary mask + vector partition_bits; + ValidityMask partition_mask; + //! The order boundary mask + vector order_bits; + ValidityMask order_mask; + //! External paging + bool external; + //! The current execution functions + Executors executors; + + //! The bin number + idx_t hash_bin; + + //! The next block to read. + mutable atomic read_block_idx; + //! The number of remaining unscanned blocks. + atomic unscanned; +}; + +void WindowPartitionSourceState::MaterializeSortedData() { + auto &global_sort_state = *hash_group->global_sort; + if (global_sort_state.sorted_blocks.empty()) { + return; + } + + // scan the sorted row data + D_ASSERT(global_sort_state.sorted_blocks.size() == 1); + auto &sb = *global_sort_state.sorted_blocks[0]; + + // Free up some memory before allocating more + sb.radix_sorting_data.clear(); + sb.blob_sorting_data = nullptr; + + // Move the sorting row blocks into our RDCs + auto &buffer_manager = global_sort_state.buffer_manager; + auto &sd = *sb.payload_data; + + // Data blocks are required + D_ASSERT(!sd.data_blocks.empty()); + auto &block = sd.data_blocks[0]; + rows = make_uniq(buffer_manager, block->capacity, block->entry_size); + rows->blocks = std::move(sd.data_blocks); + rows->count = std::accumulate(rows->blocks.begin(), rows->blocks.end(), idx_t(0), + [&](idx_t c, const unique_ptr &b) { return c + b->count; }); + + // Heap blocks are optional, but we want both for iteration. + if (!sd.heap_blocks.empty()) { + auto &block = sd.heap_blocks[0]; + heap = make_uniq(buffer_manager, block->capacity, block->entry_size); + heap->blocks = std::move(sd.heap_blocks); + hash_group.reset(); + } else { + heap = make_uniq(buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + } + heap->count = std::accumulate(heap->blocks.begin(), heap->blocks.end(), idx_t(0), + [&](idx_t c, const unique_ptr &b) { return c + b->count; }); +} + +unique_ptr WindowPartitionSourceState::GetScanner() const { + auto &gsink = *gsource.gsink.global_partition; + if ((gsink.rows && !hash_bin) || hash_bin < gsink.hash_groups.size()) { + const auto block_idx = read_block_idx++; + if (block_idx >= rows->blocks.size()) { + return nullptr; + } + // Second pass can flush + --gsource.tasks_remaining; + return make_uniq(*rows, *heap, layout, external, block_idx, true); + } + return nullptr; +} + +void WindowPartitionSourceState::BuildPartition(WindowGlobalSinkState &gstate, const idx_t hash_bin_p) { + // Get rid of any stale data + hash_bin = hash_bin_p; + + // There are three types of partitions: + // 1. No partition (no sorting) + // 2. One partition (sorting, but no hashing) + // 3. Multiple partitions (sorting and hashing) + + // How big is the partition? + auto &gpart = *gsource.gsink.global_partition; + idx_t count = 0; + if (hash_bin < gpart.hash_groups.size() && gpart.hash_groups[hash_bin]) { + count = gpart.hash_groups[hash_bin]->count; + } else if (gpart.rows && !hash_bin) { + count = gpart.count; + } else { + return; + } + + // Initialise masks to false + const auto bit_count = ValidityMask::ValidityMaskSize(count); + partition_bits.clear(); + partition_bits.resize(bit_count, 0); + partition_mask.Initialize(partition_bits.data()); + + order_bits.clear(); + order_bits.resize(bit_count, 0); + order_mask.Initialize(order_bits.data()); + + // Scan the sorted data into new Collections + external = gpart.external; + if (gpart.rows && !hash_bin) { + // Simple mask + partition_mask.SetValidUnsafe(0); + order_mask.SetValidUnsafe(0); + // No partition - align the heap blocks with the row blocks + rows = gpart.rows->CloneEmpty(gpart.rows->keep_pinned); + heap = gpart.strings->CloneEmpty(gpart.strings->keep_pinned); + RowDataCollectionScanner::AlignHeapBlocks(*rows, *heap, *gpart.rows, *gpart.strings, layout); + external = true; + } else if (hash_bin < gpart.hash_groups.size()) { + // Overwrite the collections with the sorted data + D_ASSERT(gpart.hash_groups[hash_bin].get()); + hash_group = std::move(gpart.hash_groups[hash_bin]); + hash_group->ComputeMasks(partition_mask, order_mask); + external = hash_group->global_sort->external; + MaterializeSortedData(); + } else { + return; + } + + // Create the executors for each function + executors.clear(); + for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { + D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); + auto &wexpr = op.select_list[expr_idx]->Cast(); + auto wexec = WindowExecutorFactory(wexpr, context, partition_mask, order_mask, count, gstate.mode); + executors.emplace_back(std::move(wexec)); + } + + // First pass over the input without flushing + DataChunk input_chunk; + input_chunk.Initialize(gpart.allocator, gpart.payload_types); + auto scanner = make_uniq(*rows, *heap, layout, external, false); + idx_t input_idx = 0; + while (true) { + input_chunk.Reset(); + scanner->Scan(input_chunk); + if (input_chunk.size() == 0) { + break; + } + + // TODO: Parallelization opportunity + for (auto &wexec : executors) { + wexec->Sink(input_chunk, input_idx, scanner->Count()); + } + input_idx += input_chunk.size(); + } + + // TODO: Parallelization opportunity + for (auto &wexec : executors) { + wexec->Finalize(); + } + + // External scanning assumes all blocks are swizzled. + scanner->ReSwizzle(); + + // Start the block countdown + unscanned = rows->blocks.size(); +} + +// Per-thread scan state +class WindowLocalSourceState : public LocalSourceState { +public: + using ReadStatePtr = unique_ptr; + using ReadStates = vector; + + explicit WindowLocalSourceState(WindowGlobalSourceState &gsource); + void UpdateBatchIndex(); + bool NextPartition(); + void Scan(DataChunk &chunk); + + //! The shared source state + WindowGlobalSourceState &gsource; + //! The current bin being processed + idx_t hash_bin; + //! The current batch index (for output reordering) + idx_t batch_index; + //! The current source being processed + optional_ptr partition_source; + //! The read cursor + unique_ptr scanner; + //! Buffer for the inputs + DataChunk input_chunk; + //! Executor read states. + ReadStates read_states; + //! Buffer for window results + DataChunk output_chunk; +}; + +WindowLocalSourceState::WindowLocalSourceState(WindowGlobalSourceState &gsource) + : gsource(gsource), hash_bin(gsource.built.size()), batch_index(0) { + auto &gsink = *gsource.gsink.global_partition; + auto &op = gsource.gsink.op; + + input_chunk.Initialize(gsink.allocator, gsink.payload_types); + + vector output_types; + for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { + D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); + auto &wexpr = op.select_list[expr_idx]->Cast(); + output_types.emplace_back(wexpr.return_type); + } + output_chunk.Initialize(Allocator::Get(gsource.context), output_types); +} + +WindowGlobalSourceState::Task WindowGlobalSourceState::CreateTask(idx_t hash_bin) { + // Build outside the lock so no one tries to steal before we are done. + auto partition_source = make_uniq(context, *this); + partition_source->BuildPartition(gsink, hash_bin); + Task result(partition_source.get(), partition_source->GetScanner()); + + // Is there any data to scan? + if (result.second) { + lock_guard built_guard(built_lock); + built[hash_bin] = std::move(partition_source); + + return result; + } + + return Task(); +} + +WindowGlobalSourceState::Task WindowGlobalSourceState::StealWork() { + for (idx_t hash_bin = 0; hash_bin < built.size(); ++hash_bin) { + lock_guard built_guard(built_lock); + auto &partition_source = built[hash_bin]; + if (!partition_source) { + continue; + } + + Task result(partition_source.get(), partition_source->GetScanner()); + + // Is there any data to scan? + if (result.second) { + return result; + } + } + + // Nothing to steal + return Task(); +} + +WindowGlobalSourceState::Task WindowGlobalSourceState::NextTask(idx_t hash_bin) { + auto &hash_groups = gsink.global_partition->hash_groups; + const auto bin_count = built.size(); + + // Flush unneeded data + if (hash_bin < bin_count) { + // Lock and delete when all blocks have been scanned + // We do this here instead of in NextScan so the WindowLocalSourceState + // has a chance to delete its state objects first, + // which may reference the partition_source + + // Delete data outside the lock in case it is slow + HashGroupSourcePtr killed; + lock_guard built_guard(built_lock); + auto &partition_source = built[hash_bin]; + if (partition_source && !partition_source->unscanned) { + killed = std::move(partition_source); + } + } + + hash_bin = next_build++; + if (hash_bin < bin_count) { + // Find a non-empty hash group. + for (; hash_bin < hash_groups.size(); hash_bin = next_build++) { + if (hash_groups[hash_bin] && hash_groups[hash_bin]->count) { + auto result = CreateTask(hash_bin); + if (result.second) { + return result; + } + } + } + + // OVER() doesn't have a hash_group + if (hash_groups.empty()) { + auto result = CreateTask(hash_bin); + if (result.second) { + return result; + } + } + } + + // Work stealing + while (!context.interrupted && tasks_remaining) { + auto result = StealWork(); + if (result.second) { + return result; + } + + // If there is nothing to steal but there are unfinished partitions, + // yield until any pending builds are done. + TaskScheduler::YieldThread(); + } + + return Task(); +} + +void WindowLocalSourceState::UpdateBatchIndex() { + D_ASSERT(partition_source); + D_ASSERT(scanner.get()); + + batch_index = partition_source->hash_group ? partition_source->hash_group->batch_base : 0; + batch_index += scanner->BlockIndex(); +} + +bool WindowLocalSourceState::NextPartition() { + // Release old states before the source + scanner.reset(); + read_states.clear(); + + // Get a partition_source that is not finished + while (!scanner) { + auto task = gsource.NextTask(hash_bin); + if (!task.first) { + return false; + } + partition_source = task.first; + scanner = std::move(task.second); + hash_bin = partition_source->hash_bin; + UpdateBatchIndex(); + } + + for (auto &wexec : partition_source->executors) { + read_states.emplace_back(wexec->GetExecutorState()); + } + + return true; +} + +void WindowLocalSourceState::Scan(DataChunk &result) { + D_ASSERT(scanner); + if (!scanner->Remaining()) { + lock_guard built_guard(gsource.built_lock); + --partition_source->unscanned; + scanner = partition_source->GetScanner(); + + if (!scanner) { + partition_source = nullptr; + read_states.clear(); + return; + } + + UpdateBatchIndex(); + } + + const auto position = scanner->Scanned(); + input_chunk.Reset(); + scanner->Scan(input_chunk); + + auto &executors = partition_source->executors; + output_chunk.Reset(); + for (idx_t expr_idx = 0; expr_idx < executors.size(); ++expr_idx) { + auto &executor = *executors[expr_idx]; + auto &lstate = *read_states[expr_idx]; + auto &result = output_chunk.data[expr_idx]; + executor.Evaluate(position, input_chunk, result, lstate); + } + output_chunk.SetCardinality(input_chunk); + output_chunk.Verify(); + + idx_t out_idx = 0; + result.SetCardinality(input_chunk); + for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); col_idx++) { + result.data[out_idx++].Reference(input_chunk.data[col_idx]); + } + for (idx_t col_idx = 0; col_idx < output_chunk.ColumnCount(); col_idx++) { + result.data[out_idx++].Reference(output_chunk.data[col_idx]); + } + result.Verify(); +} + +unique_ptr PhysicalWindow::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gsource_p) const { + auto &gsource = gsource_p.Cast(); + return make_uniq(gsource); +} + +unique_ptr PhysicalWindow::GetGlobalSourceState(ClientContext &context) const { + auto &gsink = sink_state->Cast(); + return make_uniq(context, gsink); +} + +bool PhysicalWindow::SupportsBatchIndex() const { + // We can only preserve order for single partitioning + // or work stealing causes out of order batch numbers + auto &wexpr = select_list[0]->Cast(); + return wexpr.partitions.empty() && !wexpr.orders.empty(); +} + +OrderPreservationType PhysicalWindow::SourceOrder() const { + return SupportsBatchIndex() ? OrderPreservationType::FIXED_ORDER : OrderPreservationType::NO_ORDER; +} + +idx_t PhysicalWindow::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, + LocalSourceState &lstate_p) const { + auto &lstate = lstate_p.Cast(); + return lstate.batch_index; +} + +SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &lsource = input.local_state.Cast(); + while (chunk.size() == 0) { + // Move to the next bin if we are done. + while (!lsource.scanner) { + if (!lsource.NextPartition()) { + return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; + } + } + + lsource.Scan(chunk); + } + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +string PhysicalWindow::ParamsToString() const { + string result; + for (idx_t i = 0; i < select_list.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += select_list[i]->GetName(); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/base_csv_reader.cpp b/src/duckdb/src/execution/operator/csv_scanner/base_csv_reader.cpp new file mode 100644 index 00000000..a1c3e57b --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/base_csv_reader.cpp @@ -0,0 +1,595 @@ +#include "duckdb/execution/operator/scan/csv/base_csv_reader.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/decimal_cast_operators.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/main/appender.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/storage/data_table.hpp" +#include "utf8proc_wrapper.hpp" +#include "utf8proc.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/main/error_manager.hpp" +#include "duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp" +#include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" +#include "duckdb/main/client_data.hpp" +#include +#include +#include +#include + +namespace duckdb { + +string BaseCSVReader::GetLineNumberStr(idx_t line_error, bool is_line_estimated, idx_t buffer_idx) { + // If an error happens during auto-detect it is an estimated line + string estimated = (is_line_estimated ? string(" (estimated)") : string("")); + return to_string(GetLineError(line_error, buffer_idx)) + estimated; +} + +BaseCSVReader::BaseCSVReader(ClientContext &context_p, CSVReaderOptions options_p, + const vector &requested_types) + : context(context_p), fs(FileSystem::GetFileSystem(context)), allocator(BufferAllocator::Get(context)), + options(std::move(options_p)) { +} + +BaseCSVReader::~BaseCSVReader() { +} + +unique_ptr BaseCSVReader::OpenCSV(ClientContext &context, const CSVReaderOptions &options_p) { + return CSVFileHandle::OpenFile(FileSystem::GetFileSystem(context), BufferAllocator::Get(context), + options_p.file_path, options_p.compression); +} + +void BaseCSVReader::InitParseChunk(idx_t num_cols) { + // adapt not null info + if (options.force_not_null.size() != num_cols) { + options.force_not_null.resize(num_cols, false); + } + if (num_cols == parse_chunk.ColumnCount()) { + parse_chunk.Reset(); + } else { + parse_chunk.Destroy(); + + // initialize the parse_chunk with a set of VARCHAR types + vector varchar_types(num_cols, LogicalType::VARCHAR); + parse_chunk.Initialize(allocator, varchar_types); + } +} + +void BaseCSVReader::InitializeProjection() { + for (idx_t i = 0; i < GetTypes().size(); i++) { + reader_data.column_ids.push_back(i); + reader_data.column_mapping.push_back(i); + } +} + +template +static bool TemplatedTryCastDateVector(map &options, Vector &input_vector, + Vector &result_vector, idx_t count, string &error_message, idx_t &line_error) { + D_ASSERT(input_vector.GetType().id() == LogicalTypeId::VARCHAR); + bool all_converted = true; + idx_t cur_line = 0; + UnaryExecutor::Execute(input_vector, result_vector, count, [&](string_t input) { + T result; + if (!OP::Operation(options, input, result, error_message)) { + line_error = cur_line; + all_converted = false; + } + cur_line++; + return result; + }); + return all_converted; +} + +struct TryCastDateOperator { + static bool Operation(map &options, string_t input, date_t &result, + string &error_message) { + return options[LogicalTypeId::DATE].TryParseDate(input, result, error_message); + } +}; + +struct TryCastTimestampOperator { + static bool Operation(map &options, string_t input, timestamp_t &result, + string &error_message) { + return options[LogicalTypeId::TIMESTAMP].TryParseTimestamp(input, result, error_message); + } +}; + +bool BaseCSVReader::TryCastDateVector(map &options, Vector &input_vector, + Vector &result_vector, idx_t count, string &error_message, idx_t &line_error) { + return TemplatedTryCastDateVector(options, input_vector, result_vector, count, + error_message, line_error); +} + +bool BaseCSVReader::TryCastTimestampVector(map &options, Vector &input_vector, + Vector &result_vector, idx_t count, string &error_message) { + idx_t line_error; + return TemplatedTryCastDateVector(options, input_vector, result_vector, + count, error_message, line_error); +} + +void BaseCSVReader::VerifyLineLength(idx_t line_size, idx_t buffer_idx) { + if (line_size > options.maximum_line_size) { + throw InvalidInputException( + "Error in file \"%s\" on line %s: Maximum line size of %llu bytes exceeded!", options.file_path, + GetLineNumberStr(parse_chunk.size(), linenr_estimated, buffer_idx).c_str(), options.maximum_line_size); + } +} + +template +bool TemplatedTryCastFloatingVector(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, idx_t count, + string &error_message, idx_t &line_error) { + D_ASSERT(input_vector.GetType().id() == LogicalTypeId::VARCHAR); + bool all_converted = true; + idx_t row = 0; + UnaryExecutor::Execute(input_vector, result_vector, count, [&](string_t input) { + T result; + if (!OP::Operation(input, result, &error_message)) { + line_error = row; + all_converted = false; + } else { + row++; + } + return result; + }); + return all_converted; +} + +template +bool TemplatedTryCastDecimalVector(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, idx_t count, + string &error_message, uint8_t width, uint8_t scale) { + D_ASSERT(input_vector.GetType().id() == LogicalTypeId::VARCHAR); + bool all_converted = true; + UnaryExecutor::Execute(input_vector, result_vector, count, [&](string_t input) { + T result; + if (!OP::Operation(input, result, &error_message, width, scale)) { + all_converted = false; + } + return result; + }); + return all_converted; +} + +void BaseCSVReader::AddValue(string_t str_val, idx_t &column, vector &escape_positions, bool has_quotes, + idx_t buffer_idx) { + auto length = str_val.GetSize(); + if (length == 0 && column == 0) { + row_empty = true; + } else { + row_empty = false; + } + if (!return_types.empty() && column == return_types.size() && length == 0) { + // skip a single trailing delimiter in last column + return; + } + if (column >= return_types.size()) { + if (options.ignore_errors) { + error_column_overflow = true; + return; + } else { + throw InvalidInputException( + "Error in file \"%s\", on line %s: expected %lld values per row, but got more. (%s)", options.file_path, + GetLineNumberStr(linenr, linenr_estimated, buffer_idx).c_str(), return_types.size(), + options.ToString()); + } + } + + // insert the line number into the chunk + idx_t row_entry = parse_chunk.size(); + + // test against null string, but only if the value was not quoted + if ((!(has_quotes && !options.allow_quoted_nulls) || return_types[column].id() != LogicalTypeId::VARCHAR) && + !options.force_not_null[column] && Equals::Operation(str_val, string_t(options.null_str))) { + FlatVector::SetNull(parse_chunk.data[column], row_entry, true); + } else { + auto &v = parse_chunk.data[column]; + auto parse_data = FlatVector::GetData(v); + if (!escape_positions.empty()) { + // remove escape characters (if any) + string old_val = str_val.GetString(); + string new_val = ""; + idx_t prev_pos = 0; + for (idx_t i = 0; i < escape_positions.size(); i++) { + idx_t next_pos = escape_positions[i]; + new_val += old_val.substr(prev_pos, next_pos - prev_pos); + prev_pos = ++next_pos; + } + new_val += old_val.substr(prev_pos, old_val.size() - prev_pos); + escape_positions.clear(); + parse_data[row_entry] = StringVector::AddStringOrBlob(v, string_t(new_val)); + } else { + parse_data[row_entry] = str_val; + } + } + + // move to the next column + column++; +} + +bool BaseCSVReader::AddRow(DataChunk &insert_chunk, idx_t &column, string &error_message, idx_t buffer_idx) { + linenr++; + + if (row_empty) { + row_empty = false; + if (return_types.size() != 1) { + if (mode == ParserMode::PARSING) { + FlatVector::SetNull(parse_chunk.data[0], parse_chunk.size(), false); + } + column = 0; + return false; + } + } + + // Error forwarded by 'ignore_errors' - originally encountered in 'AddValue' + if (error_column_overflow) { + D_ASSERT(options.ignore_errors); + error_column_overflow = false; + column = 0; + return false; + } + + if (column < return_types.size()) { + if (options.null_padding) { + for (; column < return_types.size(); column++) { + FlatVector::SetNull(parse_chunk.data[column], parse_chunk.size(), true); + } + } else if (options.ignore_errors) { + column = 0; + return false; + } else { + if (mode == ParserMode::SNIFFING_DATATYPES) { + error_message = "Error when adding line"; + return false; + } else { + throw InvalidInputException( + "Error in file \"%s\" on line %s: expected %lld values per row, but got %d.\nParser options:\n%s", + options.file_path, GetLineNumberStr(linenr, linenr_estimated, buffer_idx).c_str(), + return_types.size(), column, options.ToString()); + } + } + } + + parse_chunk.SetCardinality(parse_chunk.size() + 1); + + if (mode == ParserMode::PARSING_HEADER) { + return true; + } + + if (mode == ParserMode::SNIFFING_DATATYPES) { + return true; + } + + if (mode == ParserMode::PARSING && parse_chunk.size() == STANDARD_VECTOR_SIZE) { + Flush(insert_chunk, buffer_idx); + return true; + } + + column = 0; + return false; +} + +void BaseCSVReader::VerifyUTF8(idx_t col_idx, idx_t row_idx, DataChunk &chunk, int64_t offset) { + D_ASSERT(col_idx < chunk.data.size()); + D_ASSERT(row_idx < chunk.size()); + auto &v = chunk.data[col_idx]; + if (FlatVector::IsNull(v, row_idx)) { + return; + } + + auto parse_data = FlatVector::GetData(chunk.data[col_idx]); + auto s = parse_data[row_idx]; + auto utf_type = Utf8Proc::Analyze(s.GetData(), s.GetSize()); + if (utf_type == UnicodeType::INVALID) { + string col_name = to_string(col_idx); + if (col_idx < names.size()) { + col_name = "\"" + names[col_idx] + "\""; + } + int64_t error_line = linenr - (chunk.size() - row_idx) + 1 + offset; + D_ASSERT(error_line >= 0); + throw InvalidInputException("Error in file \"%s\" at line %llu in column \"%s\": " + "%s. Parser options:\n%s", + options.file_path, error_line, col_name, + ErrorManager::InvalidUnicodeError(s.GetString(), "CSV file"), options.ToString()); + } +} + +void BaseCSVReader::VerifyUTF8(idx_t col_idx) { + D_ASSERT(col_idx < parse_chunk.data.size()); + for (idx_t i = 0; i < parse_chunk.size(); i++) { + VerifyUTF8(col_idx, i, parse_chunk); + } +} + +bool TryCastDecimalVectorCommaSeparated(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, + idx_t count, string &error_message, const LogicalType &result_type) { + auto width = DecimalType::GetWidth(result_type); + auto scale = DecimalType::GetScale(result_type); + switch (result_type.InternalType()) { + case PhysicalType::INT16: + return TemplatedTryCastDecimalVector( + options, input_vector, result_vector, count, error_message, width, scale); + case PhysicalType::INT32: + return TemplatedTryCastDecimalVector( + options, input_vector, result_vector, count, error_message, width, scale); + case PhysicalType::INT64: + return TemplatedTryCastDecimalVector( + options, input_vector, result_vector, count, error_message, width, scale); + case PhysicalType::INT128: + return TemplatedTryCastDecimalVector( + options, input_vector, result_vector, count, error_message, width, scale); + default: + throw InternalException("Unimplemented physical type for decimal"); + } +} + +bool TryCastFloatingVectorCommaSeparated(CSVReaderOptions &options, Vector &input_vector, Vector &result_vector, + idx_t count, string &error_message, const LogicalType &result_type, + idx_t &line_error) { + switch (result_type.InternalType()) { + case PhysicalType::DOUBLE: + return TemplatedTryCastFloatingVector( + options, input_vector, result_vector, count, error_message, line_error); + case PhysicalType::FLOAT: + return TemplatedTryCastFloatingVector( + options, input_vector, result_vector, count, error_message, line_error); + default: + throw InternalException("Unimplemented physical type for floating"); + } +} + +// Location of erroneous value in the current parse chunk +struct ErrorLocation { + idx_t row_idx; + idx_t col_idx; + idx_t row_line; + + ErrorLocation(idx_t row_idx, idx_t col_idx, idx_t row_line) + : row_idx(row_idx), col_idx(col_idx), row_line(row_line) { + } +}; + +bool BaseCSVReader::Flush(DataChunk &insert_chunk, idx_t buffer_idx, bool try_add_line) { + if (parse_chunk.size() == 0) { + return true; + } + + bool conversion_error_ignored = false; + + // convert the columns in the parsed chunk to the types of the table + insert_chunk.SetCardinality(parse_chunk); + if (reader_data.column_ids.empty() && !reader_data.empty_columns) { + throw InternalException("BaseCSVReader::Flush called on a CSV reader that was not correctly initialized. Call " + "MultiFileReader::InitializeReader or InitializeProjection"); + } + D_ASSERT(reader_data.column_ids.size() == reader_data.column_mapping.size()); + for (idx_t c = 0; c < reader_data.column_ids.size(); c++) { + auto col_idx = reader_data.column_ids[c]; + auto result_idx = reader_data.column_mapping[c]; + auto &parse_vector = parse_chunk.data[col_idx]; + auto &result_vector = insert_chunk.data[result_idx]; + auto &type = result_vector.GetType(); + if (type.id() == LogicalTypeId::VARCHAR) { + // target type is varchar: no need to convert + // just test that all strings are valid utf-8 strings + VerifyUTF8(col_idx); + // reinterpret rather than reference so we can deal with user-defined types + result_vector.Reinterpret(parse_vector); + } else { + string error_message; + bool success; + idx_t line_error = 0; + bool target_type_not_varchar = false; + if (options.dialect_options.has_format[LogicalTypeId::DATE] && type.id() == LogicalTypeId::DATE) { + // use the date format to cast the chunk + success = TryCastDateVector(options.dialect_options.date_format, parse_vector, result_vector, + parse_chunk.size(), error_message, line_error); + } else if (options.dialect_options.has_format[LogicalTypeId::TIMESTAMP] && + type.id() == LogicalTypeId::TIMESTAMP) { + // use the date format to cast the chunk + success = TryCastTimestampVector(options.dialect_options.date_format, parse_vector, result_vector, + parse_chunk.size(), error_message); + } else if (options.decimal_separator != "." && + (type.id() == LogicalTypeId::FLOAT || type.id() == LogicalTypeId::DOUBLE)) { + success = TryCastFloatingVectorCommaSeparated(options, parse_vector, result_vector, parse_chunk.size(), + error_message, type, line_error); + } else if (options.decimal_separator != "." && type.id() == LogicalTypeId::DECIMAL) { + success = TryCastDecimalVectorCommaSeparated(options, parse_vector, result_vector, parse_chunk.size(), + error_message, type); + } else { + // target type is not varchar: perform a cast + target_type_not_varchar = true; + success = + VectorOperations::TryCast(context, parse_vector, result_vector, parse_chunk.size(), &error_message); + } + if (success) { + continue; + } + if (try_add_line) { + return false; + } + + string col_name = to_string(col_idx); + if (col_idx < names.size()) { + col_name = "\"" + names[col_idx] + "\""; + } + + // figure out the exact line number + if (target_type_not_varchar) { + UnifiedVectorFormat inserted_column_data; + result_vector.ToUnifiedFormat(parse_chunk.size(), inserted_column_data); + for (; line_error < parse_chunk.size(); line_error++) { + if (!inserted_column_data.validity.RowIsValid(line_error) && + !FlatVector::IsNull(parse_vector, line_error)) { + break; + } + } + } + + // The line_error must be summed with linenr (All lines emmited from this batch) + // But subtracted from the parse_chunk + D_ASSERT(line_error + linenr >= parse_chunk.size()); + line_error += linenr; + line_error -= parse_chunk.size(); + + auto error_line = GetLineError(line_error, buffer_idx); + + if (options.ignore_errors) { + conversion_error_ignored = true; + + } else if (options.auto_detect) { + throw InvalidInputException("%s in column %s, at line %llu.\n\nParser " + "options:\n%s.\n\nConsider either increasing the sample size " + "(SAMPLE_SIZE=X [X rows] or SAMPLE_SIZE=-1 [all rows]), " + "or skipping column conversion (ALL_VARCHAR=1)", + error_message, col_name, error_line, options.ToString()); + } else { + throw InvalidInputException("%s at line %llu in column %s. Parser options:\n%s ", error_message, + error_line, col_name, options.ToString()); + } + } + } + if (conversion_error_ignored) { + D_ASSERT(options.ignore_errors); + + SelectionVector succesful_rows(parse_chunk.size()); + idx_t sel_size = 0; + + // Keep track of failed cells + vector failed_cells; + + for (idx_t row_idx = 0; row_idx < parse_chunk.size(); row_idx++) { + + auto global_row_idx = row_idx + linenr - parse_chunk.size(); + auto row_line = GetLineError(global_row_idx, buffer_idx, false); + + bool row_failed = false; + for (idx_t c = 0; c < reader_data.column_ids.size(); c++) { + auto col_idx = reader_data.column_ids[c]; + auto result_idx = reader_data.column_mapping[c]; + + auto &parse_vector = parse_chunk.data[col_idx]; + auto &result_vector = insert_chunk.data[result_idx]; + + bool was_already_null = FlatVector::IsNull(parse_vector, row_idx); + if (!was_already_null && FlatVector::IsNull(result_vector, row_idx)) { + Increment(buffer_idx); + auto bla = GetLineError(global_row_idx, buffer_idx, false); + row_idx += bla; + row_idx -= bla; + row_failed = true; + failed_cells.emplace_back(row_idx, col_idx, row_line); + } + } + if (!row_failed) { + succesful_rows.set_index(sel_size++, row_idx); + } + } + + // Now do a second pass to produce the reject table entries + if (!failed_cells.empty() && !options.rejects_table_name.empty()) { + auto limit = options.rejects_limit; + + auto rejects = CSVRejectsTable::GetOrCreate(context, options.rejects_table_name); + lock_guard lock(rejects->write_lock); + + // short circuit if we already have too many rejects + if (limit == 0 || rejects->count < limit) { + auto &table = rejects->GetTable(context); + InternalAppender appender(context, table); + auto file_name = GetFileName(); + + for (auto &cell : failed_cells) { + if (limit != 0 && rejects->count >= limit) { + break; + } + rejects->count++; + + auto row_idx = cell.row_idx; + auto col_idx = cell.col_idx; + auto row_line = cell.row_line; + + auto col_name = to_string(col_idx); + if (col_idx < names.size()) { + col_name = "\"" + names[col_idx] + "\""; + } + + auto &parse_vector = parse_chunk.data[col_idx]; + auto parsed_str = FlatVector::GetData(parse_vector)[row_idx]; + auto &type = insert_chunk.data[col_idx].GetType(); + auto row_error_msg = StringUtil::Format("Could not convert string '%s' to '%s'", + parsed_str.GetString(), type.ToString()); + + // Add the row to the rejects table + appender.BeginRow(); + appender.Append(string_t(file_name)); + appender.Append(row_line); + appender.Append(col_idx); + appender.Append(string_t(col_name)); + appender.Append(parsed_str); + + if (!options.rejects_recovery_columns.empty()) { + child_list_t recovery_key; + for (auto &key_idx : options.rejects_recovery_column_ids) { + // Figure out if the recovery key is valid. + // If not, error out for real. + auto &component_vector = parse_chunk.data[key_idx]; + if (FlatVector::IsNull(component_vector, row_idx)) { + throw InvalidInputException("%s at line %llu in column %s. Parser options:\n%s ", + "Could not parse recovery column", row_line, col_name, + options.ToString()); + } + auto component = Value(FlatVector::GetData(component_vector)[row_idx]); + recovery_key.emplace_back(names[key_idx], component); + } + appender.Append(Value::STRUCT(recovery_key)); + } + + appender.Append(string_t(row_error_msg)); + appender.EndRow(); + } + appender.Close(); + } + } + + // Now slice the insert chunk to only include the succesful rows + insert_chunk.Slice(succesful_rows, sel_size); + } + parse_chunk.Reset(); + return true; +} + +void BaseCSVReader::SetNewLineDelimiter(bool carry, bool carry_followed_by_nl) { + if (options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { + if (options.dialect_options.new_line == NewLineIdentifier::MIX) { + return; + } + NewLineIdentifier this_line_identifier; + if (carry) { + if (carry_followed_by_nl) { + this_line_identifier = NewLineIdentifier::CARRY_ON; + } else { + this_line_identifier = NewLineIdentifier::SINGLE; + } + } else { + this_line_identifier = NewLineIdentifier::SINGLE; + } + if (options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { + options.dialect_options.new_line = this_line_identifier; + return; + } + if (options.dialect_options.new_line != this_line_identifier) { + options.dialect_options.new_line = NewLineIdentifier::MIX; + return; + } + options.dialect_options.new_line = this_line_identifier; + } +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffered_csv_reader.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffered_csv_reader.cpp new file mode 100644 index 00000000..55c9494c --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/buffered_csv_reader.cpp @@ -0,0 +1,434 @@ +#include "duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/error_manager.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/storage/data_table.hpp" +#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" + +#include +#include +#include +#include + +namespace duckdb { + +BufferedCSVReader::BufferedCSVReader(ClientContext &context, CSVReaderOptions options_p, + const vector &requested_types) + : BaseCSVReader(context, std::move(options_p), requested_types), buffer_size(0), position(0), start(0) { + file_handle = OpenCSV(context, options); + Initialize(requested_types); +} + +BufferedCSVReader::BufferedCSVReader(ClientContext &context, string filename, CSVReaderOptions options_p, + const vector &requested_types) + : BaseCSVReader(context, std::move(options_p), requested_types), buffer_size(0), position(0), start(0) { + options.file_path = std::move(filename); + file_handle = OpenCSV(context, options); + Initialize(requested_types); +} + +void BufferedCSVReader::Initialize(const vector &requested_types) { + if (options.auto_detect && options.file_options.union_by_name) { + // This is required for the sniffer to work on Union By Name + D_ASSERT(options.file_path == file_handle->GetFilePath()); + auto bm_file_handle = BaseCSVReader::OpenCSV(context, options); + auto csv_buffer_manager = make_shared(context, std::move(bm_file_handle), options); + CSVSniffer sniffer(options, csv_buffer_manager, state_machine_cache); + auto sniffer_result = sniffer.SniffCSV(); + return_types = sniffer_result.return_types; + names = sniffer_result.names; + if (return_types.empty()) { + throw InvalidInputException("Failed to detect column types from CSV: is the file a valid CSV file?"); + } + } else { + return_types = requested_types; + ResetBuffer(); + } + SkipRowsAndReadHeader(options.dialect_options.skip_rows, options.dialect_options.header); + InitParseChunk(return_types.size()); +} + +void BufferedCSVReader::ResetBuffer() { + buffer.reset(); + buffer_size = 0; + position = 0; + start = 0; + cached_buffers.clear(); +} + +void BufferedCSVReader::SkipRowsAndReadHeader(idx_t skip_rows, bool skip_header) { + for (idx_t i = 0; i < skip_rows; i++) { + // ignore skip rows + string read_line = file_handle->ReadLine(); + linenr++; + } + + if (skip_header) { + // ignore the first line as a header line + InitParseChunk(return_types.size()); + ParseCSV(ParserMode::PARSING_HEADER); + } +} + +string BufferedCSVReader::ColumnTypesError(case_insensitive_map_t sql_types_per_column, + const vector &names) { + for (idx_t i = 0; i < names.size(); i++) { + auto it = sql_types_per_column.find(names[i]); + if (it != sql_types_per_column.end()) { + sql_types_per_column.erase(names[i]); + continue; + } + } + if (sql_types_per_column.empty()) { + return string(); + } + string exception = "COLUMN_TYPES error: Columns with names: "; + for (auto &col : sql_types_per_column) { + exception += "\"" + col.first + "\","; + } + exception.pop_back(); + exception += " do not exist in the CSV File"; + return exception; +} + +void BufferedCSVReader::SkipEmptyLines() { + if (parse_chunk.data.size() == 1) { + // Empty lines are null data. + return; + } + for (; position < buffer_size; position++) { + if (!StringUtil::CharacterIsNewline(buffer[position])) { + return; + } + } +} + +void UpdateMaxLineLength(ClientContext &context, idx_t line_length) { + if (!context.client_data->debug_set_max_line_length) { + return; + } + if (line_length < context.client_data->debug_max_line_length) { + return; + } + context.client_data->debug_max_line_length = line_length; +} + +bool BufferedCSVReader::ReadBuffer(idx_t &start, idx_t &line_start) { + if (start > buffer_size) { + return false; + } + auto old_buffer = std::move(buffer); + + // the remaining part of the last buffer + idx_t remaining = buffer_size - start; + + idx_t buffer_read_size = INITIAL_BUFFER_SIZE_LARGE; + + while (remaining > buffer_read_size) { + buffer_read_size *= 2; + } + + // Check line length + if (remaining > options.maximum_line_size) { + throw InvalidInputException("Maximum line size of %llu bytes exceeded on line %s!", options.maximum_line_size, + GetLineNumberStr(linenr, linenr_estimated)); + } + + buffer = make_unsafe_uniq_array(buffer_read_size + remaining + 1); + buffer_size = remaining + buffer_read_size; + if (remaining > 0) { + // remaining from last buffer: copy it here + memcpy(buffer.get(), old_buffer.get() + start, remaining); + } + idx_t read_count = file_handle->Read(buffer.get() + remaining, buffer_read_size); + + bytes_in_chunk += read_count; + buffer_size = remaining + read_count; + buffer[buffer_size] = '\0'; + if (old_buffer) { + cached_buffers.push_back(std::move(old_buffer)); + } + start = 0; + position = remaining; + if (!bom_checked) { + bom_checked = true; + if (read_count >= 3 && buffer[0] == '\xEF' && buffer[1] == '\xBB' && buffer[2] == '\xBF') { + start += 3; + position += 3; + } + } + line_start = start; + + return read_count > 0; +} + +void BufferedCSVReader::ParseCSV(DataChunk &insert_chunk) { + string error_message; + if (!TryParseCSV(ParserMode::PARSING, insert_chunk, error_message)) { + throw InvalidInputException(error_message); + } +} + +void BufferedCSVReader::ParseCSV(ParserMode mode) { + DataChunk dummy_chunk; + string error_message; + if (!TryParseCSV(mode, dummy_chunk, error_message)) { + throw InvalidInputException(error_message); + } +} + +bool BufferedCSVReader::TryParseCSV(ParserMode parser_mode, DataChunk &insert_chunk, string &error_message) { + mode = parser_mode; + // used for parsing algorithm + bool finished_chunk = false; + idx_t column = 0; + idx_t offset = 0; + bool has_quotes = false; + vector escape_positions; + + idx_t line_start = position; + idx_t line_size = 0; + // read values into the buffer (if any) + if (position >= buffer_size) { + if (!ReadBuffer(start, line_start)) { + return true; + } + } + + // start parsing the first value + goto value_start; +value_start: + offset = 0; + /* state: value_start */ + // this state parses the first character of a value + if (buffer[position] == options.dialect_options.state_machine_options.quote) { + // quote: actual value starts in the next position + // move to in_quotes state + start = position + 1; + line_size++; + goto in_quotes; + } else { + // no quote, move to normal parsing state + start = position; + goto normal; + } +normal: + /* state: normal parsing state */ + // this state parses the remainder of a non-quoted value until we reach a delimiter or newline + do { + for (; position < buffer_size; position++) { + line_size++; + if (buffer[position] == options.dialect_options.state_machine_options.delimiter) { + // delimiter: end the value and add it to the chunk + goto add_value; + } else if (StringUtil::CharacterIsNewline(buffer[position])) { + // newline: add row + goto add_row; + } + } + } while (ReadBuffer(start, line_start)); + // file ends during normal scan: go to end state + goto final_state; +add_value: + AddValue(string_t(buffer.get() + start, position - start - offset), column, escape_positions, has_quotes); + // increase position by 1 and move start to the new position + offset = 0; + has_quotes = false; + start = ++position; + line_size++; + if (position >= buffer_size && !ReadBuffer(start, line_start)) { + // file ends right after delimiter, go to final state + goto final_state; + } + goto value_start; +add_row : { + // check type of newline (\r or \n) + bool carriage_return = buffer[position] == '\r'; + AddValue(string_t(buffer.get() + start, position - start - offset), column, escape_positions, has_quotes); + if (!error_message.empty()) { + return false; + } + VerifyLineLength(position - line_start); + + finished_chunk = AddRow(insert_chunk, column, error_message); + UpdateMaxLineLength(context, position - line_start); + if (!error_message.empty()) { + return false; + } + // increase position by 1 and move start to the new position + offset = 0; + has_quotes = false; + position++; + line_size = 0; + start = position; + line_start = position; + if (position >= buffer_size && !ReadBuffer(start, line_start)) { + // file ends right after delimiter, go to final state + goto final_state; + } + if (carriage_return) { + // \r newline, go to special state that parses an optional \n afterwards + goto carriage_return; + } else { + SetNewLineDelimiter(); + SkipEmptyLines(); + + start = position; + line_start = position; + if (position >= buffer_size && !ReadBuffer(start, line_start)) { + // file ends right after delimiter, go to final state + goto final_state; + } + // \n newline, move to value start + if (finished_chunk) { + return true; + } + goto value_start; + } +} +in_quotes: + /* state: in_quotes */ + // this state parses the remainder of a quoted value + has_quotes = true; + position++; + line_size++; + do { + for (; position < buffer_size; position++) { + line_size++; + if (buffer[position] == options.dialect_options.state_machine_options.quote) { + // quote: move to unquoted state + goto unquote; + } else if (buffer[position] == options.dialect_options.state_machine_options.escape) { + // escape: store the escaped position and move to handle_escape state + escape_positions.push_back(position - start); + goto handle_escape; + } + } + } while (ReadBuffer(start, line_start)); + // still in quoted state at the end of the file, error: + throw InvalidInputException("Error in file \"%s\" on line %s: unterminated quotes. (%s)", options.file_path, + GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); +unquote: + /* state: unquote */ + // this state handles the state directly after we unquote + // in this state we expect either another quote (entering the quoted state again, and escaping the quote) + // or a delimiter/newline, ending the current value and moving on to the next value + position++; + line_size++; + if (position >= buffer_size && !ReadBuffer(start, line_start)) { + // file ends right after unquote, go to final state + offset = 1; + goto final_state; + } + if (buffer[position] == options.dialect_options.state_machine_options.quote && + (options.dialect_options.state_machine_options.escape == '\0' || + options.dialect_options.state_machine_options.escape == options.dialect_options.state_machine_options.quote)) { + // escaped quote, return to quoted state and store escape position + escape_positions.push_back(position - start); + goto in_quotes; + } else if (buffer[position] == options.dialect_options.state_machine_options.delimiter) { + // delimiter, add value + offset = 1; + goto add_value; + } else if (StringUtil::CharacterIsNewline(buffer[position])) { + offset = 1; + goto add_row; + } else { + error_message = StringUtil::Format( + "Error in file \"%s\" on line %s: quote should be followed by end of value, end of " + "row or another quote. (%s)", + options.file_path, GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); + return false; + } +handle_escape: + /* state: handle_escape */ + // escape should be followed by a quote or another escape character + position++; + line_size++; + if (position >= buffer_size && !ReadBuffer(start, line_start)) { + error_message = StringUtil::Format( + "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, + GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); + return false; + } + if (buffer[position] != options.dialect_options.state_machine_options.quote && + buffer[position] != options.dialect_options.state_machine_options.escape) { + error_message = StringUtil::Format( + "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, + GetLineNumberStr(linenr, linenr_estimated).c_str(), options.ToString()); + return false; + } + // escape was followed by quote or escape, go back to quoted state + goto in_quotes; +carriage_return: + /* state: carriage_return */ + // this stage optionally skips a newline (\n) character, which allows \r\n to be interpreted as a single line + if (buffer[position] == '\n') { + SetNewLineDelimiter(true, true); + // newline after carriage return: skip + // increase position by 1 and move start to the new position + start = ++position; + line_size++; + + if (position >= buffer_size && !ReadBuffer(start, line_start)) { + // file ends right after delimiter, go to final state + goto final_state; + } + } else { + SetNewLineDelimiter(true, false); + } + if (finished_chunk) { + return true; + } + SkipEmptyLines(); + start = position; + line_start = position; + if (position >= buffer_size && !ReadBuffer(start, line_start)) { + // file ends right after delimiter, go to final state + goto final_state; + } + + goto value_start; +final_state: + if (finished_chunk) { + return true; + } + + if (column > 0 || position > start) { + // remaining values to be added to the chunk + AddValue(string_t(buffer.get() + start, position - start - offset), column, escape_positions, has_quotes); + VerifyLineLength(position - line_start); + + finished_chunk = AddRow(insert_chunk, column, error_message); + SkipEmptyLines(); + UpdateMaxLineLength(context, line_size); + if (!error_message.empty()) { + return false; + } + } + + // final stage, only reached after parsing the file is finished + // flush the parsed chunk and finalize parsing + if (mode == ParserMode::PARSING) { + Flush(insert_chunk); + } + + end_of_file_reached = true; + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/csv_buffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/csv_buffer.cpp new file mode 100644 index 00000000..27e916e4 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/csv_buffer.cpp @@ -0,0 +1,89 @@ +#include "duckdb/execution/operator/scan/csv/csv_buffer.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +CSVBuffer::CSVBuffer(ClientContext &context, idx_t buffer_size_p, CSVFileHandle &file_handle, + idx_t &global_csv_current_position, idx_t file_number_p) + : context(context), first_buffer(true), file_number(file_number_p), can_seek(file_handle.CanSeek()) { + AllocateBuffer(buffer_size_p); + auto buffer = Ptr(); + actual_buffer_size = file_handle.Read(buffer, buffer_size_p); + while (actual_buffer_size < buffer_size_p && !file_handle.FinishedReading()) { + // We keep reading until this block is full + actual_buffer_size += file_handle.Read(&buffer[actual_buffer_size], buffer_size_p - actual_buffer_size); + } + global_csv_start = global_csv_current_position; + // BOM check (https://en.wikipedia.org/wiki/Byte_order_mark) + if (actual_buffer_size >= 3 && buffer[0] == '\xEF' && buffer[1] == '\xBB' && buffer[2] == '\xBF') { + start_position += 3; + } + last_buffer = file_handle.FinishedReading(); +} + +CSVBuffer::CSVBuffer(CSVFileHandle &file_handle, ClientContext &context, idx_t buffer_size, + idx_t global_csv_current_position, idx_t file_number_p) + : context(context), global_csv_start(global_csv_current_position), file_number(file_number_p), + can_seek(file_handle.CanSeek()) { + AllocateBuffer(buffer_size); + auto buffer = handle.Ptr(); + actual_buffer_size = file_handle.Read(handle.Ptr(), buffer_size); + while (actual_buffer_size < buffer_size && !file_handle.FinishedReading()) { + // We keep reading until this block is full + actual_buffer_size += file_handle.Read(&buffer[actual_buffer_size], buffer_size - actual_buffer_size); + } + last_buffer = file_handle.FinishedReading(); +} + +shared_ptr CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_size, idx_t file_number_p) { + auto next_csv_buffer = + make_shared(file_handle, context, buffer_size, global_csv_start + actual_buffer_size, file_number_p); + if (next_csv_buffer->GetBufferSize() == 0) { + // We are done reading + return nullptr; + } + return next_csv_buffer; +} + +void CSVBuffer::AllocateBuffer(idx_t buffer_size) { + auto &buffer_manager = BufferManager::GetBufferManager(context); + bool can_destroy = can_seek; + handle = buffer_manager.Allocate(MaxValue(Storage::BLOCK_SIZE, buffer_size), can_destroy, &block); +} + +idx_t CSVBuffer::GetBufferSize() { + return actual_buffer_size; +} + +void CSVBuffer::Reload(CSVFileHandle &file_handle) { + AllocateBuffer(actual_buffer_size); + file_handle.Seek(global_csv_start); + file_handle.Read(handle.Ptr(), actual_buffer_size); +} + +unique_ptr CSVBuffer::Pin(CSVFileHandle &file_handle) { + auto &buffer_manager = BufferManager::GetBufferManager(context); + if (can_seek && block->IsUnloaded()) { + // We have to reload it from disk + block = nullptr; + Reload(file_handle); + } + return make_uniq(buffer_manager.Pin(block), actual_buffer_size, first_buffer, last_buffer, + global_csv_start, start_position, file_number); +} + +void CSVBuffer::Unpin() { + if (handle.IsValid()) { + handle.Destroy(); + } +} + +idx_t CSVBuffer::GetStart() { + return start_position; +} + +bool CSVBuffer::IsCSVFileLastBuffer() { + return last_buffer; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/csv_buffer_manager.cpp b/src/duckdb/src/execution/operator/csv_scanner/csv_buffer_manager.cpp new file mode 100644 index 00000000..408e7c85 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/csv_buffer_manager.cpp @@ -0,0 +1,90 @@ +#include "duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp" +#include "duckdb/execution/operator/scan/csv/csv_buffer.hpp" +namespace duckdb { + +CSVBufferManager::CSVBufferManager(ClientContext &context_p, unique_ptr file_handle_p, + const CSVReaderOptions &options, idx_t file_idx_p) + : file_handle(std::move(file_handle_p)), context(context_p), file_idx(file_idx_p), + buffer_size(CSVBuffer::CSV_BUFFER_SIZE) { + if (options.skip_rows_set) { + // Skip rows if they are set + skip_rows = options.dialect_options.skip_rows; + } + auto file_size = file_handle->FileSize(); + if (file_size > 0 && file_size < buffer_size) { + buffer_size = CSVBuffer::CSV_MINIMUM_BUFFER_SIZE; + } + if (options.buffer_size < buffer_size) { + buffer_size = options.buffer_size; + } + for (idx_t i = 0; i < skip_rows; i++) { + file_handle->ReadLine(); + } + Initialize(); +} + +void CSVBufferManager::UnpinBuffer(idx_t cache_idx) { + if (cache_idx < cached_buffers.size()) { + cached_buffers[cache_idx]->Unpin(); + } +} + +void CSVBufferManager::Initialize() { + if (cached_buffers.empty()) { + cached_buffers.emplace_back( + make_shared(context, buffer_size, *file_handle, global_csv_pos, file_idx)); + last_buffer = cached_buffers.front(); + } + start_pos = last_buffer->GetStart(); +} + +idx_t CSVBufferManager::GetStartPos() { + return start_pos; +} +bool CSVBufferManager::ReadNextAndCacheIt() { + D_ASSERT(last_buffer); + if (!last_buffer->IsCSVFileLastBuffer()) { + auto maybe_last_buffer = last_buffer->Next(*file_handle, buffer_size, file_idx); + if (!maybe_last_buffer) { + last_buffer->last_buffer = true; + return false; + } + last_buffer = std::move(maybe_last_buffer); + cached_buffers.emplace_back(last_buffer); + return true; + } + return false; +} + +unique_ptr CSVBufferManager::GetBuffer(const idx_t pos) { + while (pos >= cached_buffers.size()) { + if (done) { + return nullptr; + } + if (!ReadNextAndCacheIt()) { + done = true; + } + } + if (pos != 0) { + cached_buffers[pos - 1]->Unpin(); + } + return cached_buffers[pos]->Pin(*file_handle); +} + +bool CSVBufferIterator::Finished() { + return !cur_buffer_handle; +} + +void CSVBufferIterator::Reset() { + if (cur_buffer_handle) { + cur_buffer_handle.reset(); + } + if (cur_buffer_idx > 0) { + buffer_manager->UnpinBuffer(cur_buffer_idx - 1); + } + cur_buffer_idx = 0; + buffer_manager->Initialize(); + cur_pos = buffer_manager->GetStartPos(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/csv_file_handle.cpp b/src/duckdb/src/execution/operator/csv_scanner/csv_file_handle.cpp new file mode 100644 index 00000000..6462db94 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/csv_file_handle.cpp @@ -0,0 +1,95 @@ +#include "duckdb/execution/operator/scan/csv/csv_file_handle.hpp" + +namespace duckdb { + +CSVFileHandle::CSVFileHandle(FileSystem &fs, Allocator &allocator, unique_ptr file_handle_p, + const string &path_p, FileCompressionType compression) + : file_handle(std::move(file_handle_p)), path(path_p) { + can_seek = file_handle->CanSeek(); + on_disk_file = file_handle->OnDiskFile(); + file_size = file_handle->GetFileSize(); +} + +unique_ptr CSVFileHandle::OpenFileHandle(FileSystem &fs, Allocator &allocator, const string &path, + FileCompressionType compression) { + auto file_handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ, FileLockType::NO_LOCK, compression); + if (file_handle->CanSeek()) { + file_handle->Reset(); + } + return file_handle; +} + +unique_ptr CSVFileHandle::OpenFile(FileSystem &fs, Allocator &allocator, const string &path, + FileCompressionType compression) { + auto file_handle = CSVFileHandle::OpenFileHandle(fs, allocator, path, compression); + return make_uniq(fs, allocator, std::move(file_handle), path, compression); +} + +bool CSVFileHandle::CanSeek() { + return can_seek; +} + +void CSVFileHandle::Seek(idx_t position) { + if (!can_seek) { + throw InternalException("Cannot seek in this file"); + } + file_handle->Seek(position); +} + +bool CSVFileHandle::OnDiskFile() { + return on_disk_file; +} + +idx_t CSVFileHandle::FileSize() { + return file_size; +} + +bool CSVFileHandle::FinishedReading() { + return finished; +} + +idx_t CSVFileHandle::Read(void *buffer, idx_t nr_bytes) { + requested_bytes += nr_bytes; + // if this is a plain file source OR we can seek we are not caching anything + auto bytes_read = file_handle->Read(buffer, nr_bytes); + if (!finished) { + finished = bytes_read == 0; + } + return bytes_read; +} + +string CSVFileHandle::ReadLine() { + bool carriage_return = false; + string result; + char buffer[1]; + while (true) { + idx_t bytes_read = Read(buffer, 1); + if (bytes_read == 0) { + return result; + } + if (carriage_return) { + if (buffer[0] != '\n') { + if (!file_handle->CanSeek()) { + throw BinderException( + "Carriage return newlines not supported when reading CSV files in which we cannot seek"); + } + file_handle->Seek(file_handle->SeekPosition() - 1); + return result; + } + } + if (buffer[0] == '\n') { + return result; + } + if (buffer[0] != '\r') { + result += buffer[0]; + } else { + carriage_return = true; + } + } +} + +string CSVFileHandle::GetFilePath() { + return path; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/csv_reader_options.cpp b/src/duckdb/src/execution/operator/csv_scanner/csv_reader_options.cpp new file mode 100644 index 00000000..ceadfaa8 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/csv_reader_options.cpp @@ -0,0 +1,494 @@ +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/common/bind_helpers.hpp" +#include "duckdb/common/vector_size.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/multi_file_reader.hpp" + +namespace duckdb { + +static bool ParseBoolean(const Value &value, const string &loption); + +static bool ParseBoolean(const vector &set, const string &loption) { + if (set.empty()) { + // no option specified: default to true + return true; + } + if (set.size() > 1) { + throw BinderException("\"%s\" expects a single argument as a boolean value (e.g. TRUE or 1)", loption); + } + return ParseBoolean(set[0], loption); +} + +static bool ParseBoolean(const Value &value, const string &loption) { + + if (value.type().id() == LogicalTypeId::LIST) { + auto &children = ListValue::GetChildren(value); + return ParseBoolean(children, loption); + } + if (value.type() == LogicalType::FLOAT || value.type() == LogicalType::DOUBLE || + value.type().id() == LogicalTypeId::DECIMAL) { + throw BinderException("\"%s\" expects a boolean value (e.g. TRUE or 1)", loption); + } + return BooleanValue::Get(value.DefaultCastAs(LogicalType::BOOLEAN)); +} + +static string ParseString(const Value &value, const string &loption) { + if (value.IsNull()) { + return string(); + } + if (value.type().id() == LogicalTypeId::LIST) { + auto &children = ListValue::GetChildren(value); + if (children.size() != 1) { + throw BinderException("\"%s\" expects a single argument as a string value", loption); + } + return ParseString(children[0], loption); + } + if (value.type().id() != LogicalTypeId::VARCHAR) { + throw BinderException("\"%s\" expects a string argument!", loption); + } + return value.GetValue(); +} + +static int64_t ParseInteger(const Value &value, const string &loption) { + if (value.type().id() == LogicalTypeId::LIST) { + auto &children = ListValue::GetChildren(value); + if (children.size() != 1) { + // no option specified or multiple options specified + throw BinderException("\"%s\" expects a single argument as an integer value", loption); + } + return ParseInteger(children[0], loption); + } + return value.GetValue(); +} + +bool CSVReaderOptions::GetHeader() const { + return this->dialect_options.header; +} + +void CSVReaderOptions::SetHeader(bool input) { + this->dialect_options.header = input; + this->has_header = true; +} + +void CSVReaderOptions::SetCompression(const string &compression_p) { + this->compression = FileCompressionTypeFromString(compression_p); +} + +string CSVReaderOptions::GetEscape() const { + return std::string(1, this->dialect_options.state_machine_options.escape); +} + +void CSVReaderOptions::SetEscape(const string &input) { + auto escape_str = input; + if (escape_str.size() > 1) { + throw InvalidInputException("The escape option cannot exceed a size of 1 byte."); + } + if (escape_str.empty()) { + escape_str = string("\0", 1); + } + this->dialect_options.state_machine_options.escape = escape_str[0]; + this->has_escape = true; +} + +int64_t CSVReaderOptions::GetSkipRows() const { + return this->dialect_options.skip_rows; +} + +void CSVReaderOptions::SetSkipRows(int64_t skip_rows) { + dialect_options.skip_rows = skip_rows; + skip_rows_set = true; +} + +string CSVReaderOptions::GetDelimiter() const { + return std::string(1, this->dialect_options.state_machine_options.delimiter); +} + +void CSVReaderOptions::SetDelimiter(const string &input) { + auto delim_str = StringUtil::Replace(input, "\\t", "\t"); + if (delim_str.size() > 1) { + throw InvalidInputException("The delimiter option cannot exceed a size of 1 byte."); + } + this->has_delimiter = true; + if (input.empty()) { + delim_str = string("\0", 1); + } + this->dialect_options.state_machine_options.delimiter = delim_str[0]; +} + +string CSVReaderOptions::GetQuote() const { + return std::string(1, this->dialect_options.state_machine_options.quote); +} + +void CSVReaderOptions::SetQuote(const string "e_p) { + auto quote_str = quote_p; + if (quote_str.size() > 1) { + throw InvalidInputException("The quote option cannot exceed a size of 1 byte."); + } + if (quote_str.empty()) { + quote_str = string("\0", 1); + } + this->dialect_options.state_machine_options.quote = quote_str[0]; + this->has_quote = true; +} + +NewLineIdentifier CSVReaderOptions::GetNewline() const { + return dialect_options.new_line; +} + +void CSVReaderOptions::SetNewline(const string &input) { + if (input == "\\n" || input == "\\r") { + dialect_options.new_line = NewLineIdentifier::SINGLE; + } else if (input == "\\r\\n") { + dialect_options.new_line = NewLineIdentifier::CARRY_ON; + } else { + throw InvalidInputException("This is not accepted as a newline: " + input); + } + has_newline = true; +} + +void CSVReaderOptions::SetDateFormat(LogicalTypeId type, const string &format, bool read_format) { + string error; + if (read_format) { + error = StrTimeFormat::ParseFormatSpecifier(format, dialect_options.date_format[type]); + dialect_options.date_format[type].format_specifier = format; + } else { + error = StrTimeFormat::ParseFormatSpecifier(format, write_date_format[type]); + } + if (!error.empty()) { + throw InvalidInputException("Could not parse DATEFORMAT: %s", error.c_str()); + } + dialect_options.has_format[type] = true; +} + +void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, vector &expected_names) { + if (SetBaseOption(loption, value)) { + return; + } + if (loption == "auto_detect") { + auto_detect = ParseBoolean(value, loption); + } else if (loption == "sample_size") { + int64_t sample_size_option = ParseInteger(value, loption); + if (sample_size_option < 1 && sample_size_option != -1) { + throw BinderException("Unsupported parameter for SAMPLE_SIZE: cannot be smaller than 1"); + } + if (sample_size_option == -1) { + // If -1, we basically read the whole thing + sample_size_chunks = NumericLimits().Maximum(); + } else { + sample_size_chunks = sample_size_option / STANDARD_VECTOR_SIZE; + if (sample_size_option % STANDARD_VECTOR_SIZE != 0) { + sample_size_chunks++; + } + } + + } else if (loption == "skip") { + SetSkipRows(ParseInteger(value, loption)); + } else if (loption == "max_line_size" || loption == "maximum_line_size") { + maximum_line_size = ParseInteger(value, loption); + } else if (loption == "force_not_null") { + force_not_null = ParseColumnList(value, expected_names, loption); + } else if (loption == "date_format" || loption == "dateformat") { + string format = ParseString(value, loption); + SetDateFormat(LogicalTypeId::DATE, format, true); + } else if (loption == "timestamp_format" || loption == "timestampformat") { + string format = ParseString(value, loption); + SetDateFormat(LogicalTypeId::TIMESTAMP, format, true); + } else if (loption == "ignore_errors") { + ignore_errors = ParseBoolean(value, loption); + } else if (loption == "buffer_size") { + buffer_size = ParseInteger(value, loption); + if (buffer_size == 0) { + throw InvalidInputException("Buffer Size option must be higher than 0"); + } + } else if (loption == "decimal_separator") { + decimal_separator = ParseString(value, loption); + if (decimal_separator != "." && decimal_separator != ",") { + throw BinderException("Unsupported parameter for DECIMAL_SEPARATOR: should be '.' or ','"); + } + } else if (loption == "null_padding") { + null_padding = ParseBoolean(value, loption); + } else if (loption == "allow_quoted_nulls") { + allow_quoted_nulls = ParseBoolean(value, loption); + } else if (loption == "parallel") { + parallel_mode = ParseBoolean(value, loption) ? ParallelMode::PARALLEL : ParallelMode::SINGLE_THREADED; + } else if (loption == "rejects_table") { + // skip, handled in SetRejectsOptions + auto table_name = ParseString(value, loption); + if (table_name.empty()) { + throw BinderException("REJECTS_TABLE option cannot be empty"); + } + rejects_table_name = table_name; + } else if (loption == "rejects_recovery_columns") { + // Get the list of columns to use as a recovery key + auto &children = ListValue::GetChildren(value); + for (auto &child : children) { + auto col_name = child.GetValue(); + rejects_recovery_columns.push_back(col_name); + } + } else if (loption == "rejects_limit") { + int64_t limit = ParseInteger(value, loption); + if (limit < 0) { + throw BinderException("Unsupported parameter for REJECTS_LIMIT: cannot be negative"); + } + rejects_limit = limit; + } else { + throw BinderException("Unrecognized option for CSV reader \"%s\"", loption); + } +} + +void CSVReaderOptions::SetWriteOption(const string &loption, const Value &value) { + if (loption == "new_line") { + // Steal this from SetBaseOption so we can write different newlines (e.g., format JSON ARRAY) + write_newline = ParseString(value, loption); + return; + } + + if (SetBaseOption(loption, value)) { + return; + } + + if (loption == "force_quote") { + force_quote = ParseColumnList(value, name_list, loption); + } else if (loption == "date_format" || loption == "dateformat") { + string format = ParseString(value, loption); + SetDateFormat(LogicalTypeId::DATE, format, false); + } else if (loption == "timestamp_format" || loption == "timestampformat") { + string format = ParseString(value, loption); + if (StringUtil::Lower(format) == "iso") { + format = "%Y-%m-%dT%H:%M:%S.%fZ"; + } + SetDateFormat(LogicalTypeId::TIMESTAMP, format, false); + SetDateFormat(LogicalTypeId::TIMESTAMP_TZ, format, false); + } else if (loption == "prefix") { + prefix = ParseString(value, loption); + } else if (loption == "suffix") { + suffix = ParseString(value, loption); + } else { + throw BinderException("Unrecognized option CSV writer \"%s\"", loption); + } +} + +bool CSVReaderOptions::SetBaseOption(const string &loption, const Value &value) { + // Make sure this function was only called after the option was turned into lowercase + D_ASSERT(!std::any_of(loption.begin(), loption.end(), ::isupper)); + + if (StringUtil::StartsWith(loption, "delim") || StringUtil::StartsWith(loption, "sep")) { + SetDelimiter(ParseString(value, loption)); + } else if (loption == "quote") { + SetQuote(ParseString(value, loption)); + } else if (loption == "new_line") { + SetNewline(ParseString(value, loption)); + } else if (loption == "escape") { + SetEscape(ParseString(value, loption)); + } else if (loption == "header") { + SetHeader(ParseBoolean(value, loption)); + } else if (loption == "null" || loption == "nullstr") { + null_str = ParseString(value, loption); + } else if (loption == "encoding") { + auto encoding = StringUtil::Lower(ParseString(value, loption)); + if (encoding != "utf8" && encoding != "utf-8") { + throw BinderException("Copy is only supported for UTF-8 encoded files, ENCODING 'UTF-8'"); + } + } else if (loption == "compression") { + SetCompression(ParseString(value, loption)); + } else { + // unrecognized option in base CSV + return false; + } + return true; +} + +string CSVReaderOptions::ToString() const { + return " file=" + file_path + "\n delimiter='" + dialect_options.state_machine_options.delimiter + + (has_delimiter ? "'" : (auto_detect ? "' (auto detected)" : "' (default)")) + "\n quote='" + + dialect_options.state_machine_options.quote + + (has_quote ? "'" : (auto_detect ? "' (auto detected)" : "' (default)")) + "\n escape='" + + dialect_options.state_machine_options.escape + + (has_escape ? "'" : (auto_detect ? "' (auto detected)" : "' (default)")) + + "\n header=" + std::to_string(dialect_options.header) + + (has_header ? "" : (auto_detect ? " (auto detected)" : "' (default)")) + + "\n sample_size=" + std::to_string(sample_size_chunks * STANDARD_VECTOR_SIZE) + + "\n ignore_errors=" + std::to_string(ignore_errors) + "\n all_varchar=" + std::to_string(all_varchar); +} + +static Value StringVectorToValue(const vector &vec) { + vector content; + content.reserve(vec.size()); + for (auto &item : vec) { + content.push_back(Value(item)); + } + return Value::LIST(std::move(content)); +} + +static uint8_t GetCandidateSpecificity(const LogicalType &candidate_type) { + //! Const ht with accepted auto_types and their weights in specificity + const duckdb::unordered_map auto_type_candidates_specificity { + {(uint8_t)LogicalTypeId::VARCHAR, 0}, {(uint8_t)LogicalTypeId::TIMESTAMP, 1}, + {(uint8_t)LogicalTypeId::DATE, 2}, {(uint8_t)LogicalTypeId::TIME, 3}, + {(uint8_t)LogicalTypeId::DOUBLE, 4}, {(uint8_t)LogicalTypeId::FLOAT, 5}, + {(uint8_t)LogicalTypeId::BIGINT, 6}, {(uint8_t)LogicalTypeId::INTEGER, 7}, + {(uint8_t)LogicalTypeId::SMALLINT, 8}, {(uint8_t)LogicalTypeId::TINYINT, 9}, + {(uint8_t)LogicalTypeId::BOOLEAN, 10}, {(uint8_t)LogicalTypeId::SQLNULL, 11}}; + + auto id = (uint8_t)candidate_type.id(); + auto it = auto_type_candidates_specificity.find(id); + if (it == auto_type_candidates_specificity.end()) { + throw BinderException("Auto Type Candidate of type %s is not accepted as a valid input", + EnumUtil::ToString(candidate_type.id())); + } + return it->second; +} + +void CSVReaderOptions::FromNamedParameters(named_parameter_map_t &in, ClientContext &context, + vector &return_types, vector &names) { + for (auto &kv : in) { + if (MultiFileReader::ParseOption(kv.first, kv.second, file_options, context)) { + continue; + } + auto loption = StringUtil::Lower(kv.first); + if (loption == "columns") { + explicitly_set_columns = true; + auto &child_type = kv.second.type(); + if (child_type.id() != LogicalTypeId::STRUCT) { + throw BinderException("read_csv columns requires a struct as input"); + } + auto &struct_children = StructValue::GetChildren(kv.second); + D_ASSERT(StructType::GetChildCount(child_type) == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + auto &name = StructType::GetChildName(child_type, i); + auto &val = struct_children[i]; + names.push_back(name); + if (val.type().id() != LogicalTypeId::VARCHAR) { + throw BinderException("read_csv requires a type specification as string"); + } + return_types.emplace_back(TransformStringToLogicalType(StringValue::Get(val), context)); + } + if (names.empty()) { + throw BinderException("read_csv requires at least a single column as input!"); + } + } else if (loption == "auto_type_candidates") { + auto_type_candidates.clear(); + map candidate_types; + // We always have the extremes of Null and Varchar, so we can default to varchar if the + // sniffer is not able to confidently detect that column type + candidate_types[GetCandidateSpecificity(LogicalType::VARCHAR)] = LogicalType::VARCHAR; + candidate_types[GetCandidateSpecificity(LogicalType::SQLNULL)] = LogicalType::SQLNULL; + + auto &child_type = kv.second.type(); + if (child_type.id() != LogicalTypeId::LIST) { + throw BinderException("read_csv auto_types requires a list as input"); + } + auto &list_children = ListValue::GetChildren(kv.second); + if (list_children.empty()) { + throw BinderException("auto_type_candidates requires at least one type"); + } + for (auto &child : list_children) { + if (child.type().id() != LogicalTypeId::VARCHAR) { + throw BinderException("auto_type_candidates requires a type specification as string"); + } + auto candidate_type = TransformStringToLogicalType(StringValue::Get(child), context); + candidate_types[GetCandidateSpecificity(candidate_type)] = candidate_type; + } + for (auto &candidate_type : candidate_types) { + auto_type_candidates.emplace_back(candidate_type.second); + } + } else if (loption == "column_names" || loption == "names") { + if (!name_list.empty()) { + throw BinderException("read_csv_auto column_names/names can only be supplied once"); + } + if (kv.second.IsNull()) { + throw BinderException("read_csv_auto %s cannot be NULL", kv.first); + } + auto &children = ListValue::GetChildren(kv.second); + for (auto &child : children) { + name_list.push_back(StringValue::Get(child)); + } + } else if (loption == "column_types" || loption == "types" || loption == "dtypes") { + auto &child_type = kv.second.type(); + if (child_type.id() != LogicalTypeId::STRUCT && child_type.id() != LogicalTypeId::LIST) { + throw BinderException("read_csv_auto %s requires a struct or list as input", kv.first); + } + if (!sql_type_list.empty()) { + throw BinderException("read_csv_auto column_types/types/dtypes can only be supplied once"); + } + vector sql_type_names; + if (child_type.id() == LogicalTypeId::STRUCT) { + auto &struct_children = StructValue::GetChildren(kv.second); + D_ASSERT(StructType::GetChildCount(child_type) == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + auto &name = StructType::GetChildName(child_type, i); + auto &val = struct_children[i]; + if (val.type().id() != LogicalTypeId::VARCHAR) { + throw BinderException("read_csv_auto %s requires a type specification as string", kv.first); + } + sql_type_names.push_back(StringValue::Get(val)); + sql_types_per_column[name] = i; + } + } else { + auto &list_child = ListType::GetChildType(child_type); + if (list_child.id() != LogicalTypeId::VARCHAR) { + throw BinderException("read_csv_auto %s requires a list of types (varchar) as input", kv.first); + } + auto &children = ListValue::GetChildren(kv.second); + for (auto &child : children) { + sql_type_names.push_back(StringValue::Get(child)); + } + } + sql_type_list.reserve(sql_type_names.size()); + for (auto &sql_type : sql_type_names) { + auto def_type = TransformStringToLogicalType(sql_type, context); + if (def_type.id() == LogicalTypeId::USER) { + throw BinderException("Unrecognized type \"%s\" for read_csv_auto %s definition", sql_type, + kv.first); + } + sql_type_list.push_back(std::move(def_type)); + } + } else if (loption == "all_varchar") { + all_varchar = BooleanValue::Get(kv.second); + } else if (loption == "normalize_names") { + normalize_names = BooleanValue::Get(kv.second); + } else { + SetReadOption(loption, kv.second, names); + } + } +} + +//! This function is used to remember options set by the sniffer, for use in ReadCSVRelation +void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) { + if (has_delimiter) { + named_params["delim"] = Value(GetDelimiter()); + } + if (has_newline) { + named_params["newline"] = Value(EnumUtil::ToString(GetNewline())); + } + if (has_quote) { + named_params["quote"] = Value(GetQuote()); + } + if (has_escape) { + named_params["escape"] = Value(GetEscape()); + } + if (has_header) { + named_params["header"] = Value(GetHeader()); + } + named_params["max_line_size"] = Value::BIGINT(maximum_line_size); + if (skip_rows_set) { + named_params["skip"] = Value::BIGINT(GetSkipRows()); + } + named_params["null_padding"] = Value::BOOLEAN(null_padding); + if (!date_format.at(LogicalType::DATE).format_specifier.empty()) { + named_params["dateformat"] = Value(date_format.at(LogicalType::DATE).format_specifier); + } + if (!date_format.at(LogicalType::TIMESTAMP).format_specifier.empty()) { + named_params["timestampformat"] = Value(date_format.at(LogicalType::TIMESTAMP).format_specifier); + } + + named_params["normalize_names"] = Value::BOOLEAN(normalize_names); + if (!name_list.empty() && !named_params.count("column_names") && !named_params.count("names")) { + named_params["column_names"] = StringVectorToValue(name_list); + } + named_params["all_varchar"] = Value::BOOLEAN(all_varchar); + named_params["maximum_line_size"] = Value::BIGINT(maximum_line_size); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/csv_state_machine.cpp b/src/duckdb/src/execution/operator/csv_scanner/csv_state_machine.cpp new file mode 100644 index 00000000..785b74f4 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/csv_state_machine.cpp @@ -0,0 +1,35 @@ +#include "duckdb/execution/operator/scan/csv/csv_state_machine.hpp" +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "utf8proc_wrapper.hpp" +#include "duckdb/main/error_manager.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp" + +namespace duckdb { + +CSVStateMachine::CSVStateMachine(CSVReaderOptions &options_p, const CSVStateMachineOptions &state_machine_options, + shared_ptr buffer_manager_p, + CSVStateMachineCache &csv_state_machine_cache_p) + : csv_state_machine_cache(csv_state_machine_cache_p), options(options_p), + csv_buffer_iterator(std::move(buffer_manager_p)), + transition_array(csv_state_machine_cache.Get(state_machine_options)) { + dialect_options.state_machine_options = state_machine_options; + dialect_options.has_format = options.dialect_options.has_format; + dialect_options.date_format = options.dialect_options.date_format; + dialect_options.skip_rows = options.dialect_options.skip_rows; +} + +void CSVStateMachine::Reset() { + csv_buffer_iterator.Reset(); +} + +void CSVStateMachine::VerifyUTF8() { + auto utf_type = Utf8Proc::Analyze(value.c_str(), value.size()); + if (utf_type == UnicodeType::INVALID) { + int64_t error_line = cur_rows; + throw InvalidInputException("Error in file \"%s\" at line %llu: " + "%s. Parser options:\n%s", + options.file_path, error_line, ErrorManager::InvalidUnicodeError(value, "CSV file"), + options.ToString()); + } +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/csv_state_machine_cache.cpp b/src/duckdb/src/execution/operator/csv_scanner/csv_state_machine_cache.cpp new file mode 100644 index 00000000..4cf52f2b --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/csv_state_machine_cache.cpp @@ -0,0 +1,106 @@ +#include "duckdb/execution/operator/scan/csv/csv_state_machine.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp" + +namespace duckdb { + +void InitializeTransitionArray(unsigned char *transition_array, const uint8_t state) { + for (uint32_t i = 0; i < NUM_TRANSITIONS; i++) { + transition_array[i] = state; + } +} + +void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_options) { + D_ASSERT(state_machine_cache.find(state_machine_options) == state_machine_cache.end()); + // Initialize transition array with default values to the Standard option + auto &transition_array = state_machine_cache[state_machine_options]; + const uint8_t standard_state = static_cast(CSVState::STANDARD); + const uint8_t field_separator_state = static_cast(CSVState::DELIMITER); + const uint8_t record_separator_state = static_cast(CSVState::RECORD_SEPARATOR); + const uint8_t carriage_return_state = static_cast(CSVState::CARRIAGE_RETURN); + const uint8_t quoted_state = static_cast(CSVState::QUOTED); + const uint8_t unquoted_state = static_cast(CSVState::UNQUOTED); + const uint8_t escape_state = static_cast(CSVState::ESCAPE); + const uint8_t empty_line_state = static_cast(CSVState::EMPTY_LINE); + const uint8_t invalid_state = static_cast(CSVState::INVALID); + + for (uint32_t i = 0; i < NUM_STATES; i++) { + switch (i) { + case quoted_state: + InitializeTransitionArray(transition_array[i], quoted_state); + break; + case unquoted_state: + case invalid_state: + case escape_state: + InitializeTransitionArray(transition_array[i], invalid_state); + break; + default: + InitializeTransitionArray(transition_array[i], standard_state); + break; + } + } + + // Now set values depending on configuration + // 1) Standard State + transition_array[standard_state][static_cast(state_machine_options.delimiter)] = field_separator_state; + transition_array[standard_state][static_cast('\n')] = record_separator_state; + transition_array[standard_state][static_cast('\r')] = carriage_return_state; + transition_array[standard_state][static_cast(state_machine_options.quote)] = quoted_state; + // 2) Field Separator State + transition_array[field_separator_state][static_cast(state_machine_options.delimiter)] = + field_separator_state; + transition_array[field_separator_state][static_cast('\n')] = record_separator_state; + transition_array[field_separator_state][static_cast('\r')] = carriage_return_state; + transition_array[field_separator_state][static_cast(state_machine_options.quote)] = quoted_state; + // 3) Record Separator State + transition_array[record_separator_state][static_cast(state_machine_options.delimiter)] = + field_separator_state; + transition_array[record_separator_state][static_cast('\n')] = empty_line_state; + transition_array[record_separator_state][static_cast('\r')] = empty_line_state; + transition_array[record_separator_state][static_cast(state_machine_options.quote)] = quoted_state; + // 4) Carriage Return State + transition_array[carriage_return_state][static_cast('\n')] = record_separator_state; + transition_array[carriage_return_state][static_cast('\r')] = empty_line_state; + transition_array[carriage_return_state][static_cast(state_machine_options.escape)] = escape_state; + // 5) Quoted State + transition_array[quoted_state][static_cast(state_machine_options.quote)] = unquoted_state; + if (state_machine_options.quote != state_machine_options.escape) { + transition_array[quoted_state][static_cast(state_machine_options.escape)] = escape_state; + } + // 6) Unquoted State + transition_array[unquoted_state][static_cast('\n')] = record_separator_state; + transition_array[unquoted_state][static_cast('\r')] = carriage_return_state; + transition_array[unquoted_state][static_cast(state_machine_options.delimiter)] = field_separator_state; + if (state_machine_options.quote == state_machine_options.escape) { + transition_array[unquoted_state][static_cast(state_machine_options.escape)] = quoted_state; + } + // 7) Escaped State + transition_array[escape_state][static_cast(state_machine_options.quote)] = quoted_state; + transition_array[escape_state][static_cast(state_machine_options.escape)] = quoted_state; + // 8) Empty Line State + transition_array[empty_line_state][static_cast('\r')] = empty_line_state; + transition_array[empty_line_state][static_cast('\n')] = empty_line_state; +} + +CSVStateMachineCache::CSVStateMachineCache() { + for (auto quoterule : default_quote_rule) { + const auto "e_candidates = default_quote[static_cast(quoterule)]; + for (const auto "e : quote_candidates) { + for (const auto &delimiter : default_delimiter) { + const auto &escape_candidates = default_escape[static_cast(quoterule)]; + for (const auto &escape : escape_candidates) { + Insert({delimiter, quote, escape}); + } + } + } + } +} + +const state_machine_t &CSVStateMachineCache::Get(const CSVStateMachineOptions &state_machine_options) { + //! Custom State Machine, we need to create it and cache it first + if (state_machine_cache.find(state_machine_options) == state_machine_cache.end()) { + Insert(state_machine_options); + } + const auto &transition_array = state_machine_cache[state_machine_options]; + return transition_array; +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/parallel_csv_reader.cpp b/src/duckdb/src/execution/operator/csv_scanner/parallel_csv_reader.cpp new file mode 100644 index 00000000..73fe6726 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/parallel_csv_reader.cpp @@ -0,0 +1,685 @@ +#include "duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/storage/data_table.hpp" +#include "utf8proc_wrapper.hpp" +#include "utf8proc.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/execution/operator/scan/csv/csv_line_info.hpp" + +#include +#include +#include +#include + +namespace duckdb { + +ParallelCSVReader::ParallelCSVReader(ClientContext &context, CSVReaderOptions options_p, + unique_ptr buffer_p, idx_t first_pos_first_buffer_p, + const vector &requested_types, idx_t file_idx_p) + : BaseCSVReader(context, std::move(options_p), requested_types), file_idx(file_idx_p), + first_pos_first_buffer(first_pos_first_buffer_p) { + Initialize(requested_types); + SetBufferRead(std::move(buffer_p)); +} + +void ParallelCSVReader::Initialize(const vector &requested_types) { + return_types = requested_types; + InitParseChunk(return_types.size()); +} + +bool ParallelCSVReader::NewLineDelimiter(bool carry, bool carry_followed_by_nl, bool first_char) { + // Set the delimiter if not set yet. + SetNewLineDelimiter(carry, carry_followed_by_nl); + D_ASSERT(options.dialect_options.new_line == NewLineIdentifier::SINGLE || + options.dialect_options.new_line == NewLineIdentifier::CARRY_ON); + if (options.dialect_options.new_line == NewLineIdentifier::SINGLE) { + return (!carry) || (carry && !carry_followed_by_nl); + } + return (carry && carry_followed_by_nl) || (!carry && first_char); +} + +void ParallelCSVReader::SkipEmptyLines() { + idx_t new_pos_buffer = position_buffer; + if (parse_chunk.data.size() == 1) { + // Empty lines are null data. + return; + } + for (; new_pos_buffer < end_buffer; new_pos_buffer++) { + if (StringUtil::CharacterIsNewline((*buffer)[new_pos_buffer])) { + bool carrier_return = (*buffer)[new_pos_buffer] == '\r'; + new_pos_buffer++; + if (carrier_return && new_pos_buffer < buffer_size && (*buffer)[new_pos_buffer] == '\n') { + position_buffer++; + } + if (new_pos_buffer > end_buffer) { + return; + } + position_buffer = new_pos_buffer; + } else if ((*buffer)[new_pos_buffer] != ' ') { + return; + } + } +} + +bool ParallelCSVReader::SetPosition() { + if (buffer->buffer->is_first_buffer && start_buffer == position_buffer && start_buffer == first_pos_first_buffer) { + start_buffer = buffer->buffer->start_position; + position_buffer = start_buffer; + verification_positions.beginning_of_first_line = position_buffer; + verification_positions.end_of_last_line = position_buffer; + // First buffer doesn't need any setting + + if (options.dialect_options.header) { + for (; position_buffer < end_buffer; position_buffer++) { + if (StringUtil::CharacterIsNewline((*buffer)[position_buffer])) { + bool carrier_return = (*buffer)[position_buffer] == '\r'; + position_buffer++; + if (carrier_return && position_buffer < buffer_size && (*buffer)[position_buffer] == '\n') { + position_buffer++; + } + if (position_buffer > end_buffer) { + VerifyLineLength(position_buffer, buffer->batch_index); + return false; + } + SkipEmptyLines(); + if (verification_positions.beginning_of_first_line == 0) { + verification_positions.beginning_of_first_line = position_buffer; + } + VerifyLineLength(position_buffer, buffer->batch_index); + verification_positions.end_of_last_line = position_buffer; + return true; + } + } + VerifyLineLength(position_buffer, buffer->batch_index); + return false; + } + SkipEmptyLines(); + if (verification_positions.beginning_of_first_line == 0) { + verification_positions.beginning_of_first_line = position_buffer; + } + + verification_positions.end_of_last_line = position_buffer; + return true; + } + + // We have to move position up to next new line + idx_t end_buffer_real = end_buffer; + // Check if we already start in a valid line + string error_message; + bool successfully_read_first_line = false; + while (!successfully_read_first_line) { + DataChunk first_line_chunk; + first_line_chunk.Initialize(allocator, return_types); + // Ensure that parse_chunk has no gunk when trying to figure new line + parse_chunk.Reset(); + for (; position_buffer < end_buffer; position_buffer++) { + if (StringUtil::CharacterIsNewline((*buffer)[position_buffer])) { + bool carriage_return = (*buffer)[position_buffer] == '\r'; + bool carriage_return_followed = false; + position_buffer++; + if (position_buffer < end_buffer) { + if (carriage_return && (*buffer)[position_buffer] == '\n') { + carriage_return_followed = true; + position_buffer++; + } + } + if (NewLineDelimiter(carriage_return, carriage_return_followed, position_buffer - 1 == start_buffer)) { + break; + } + } + } + SkipEmptyLines(); + + if (position_buffer > buffer_size) { + break; + } + + auto pos_check = position_buffer == 0 ? position_buffer : position_buffer - 1; + if (position_buffer >= end_buffer && !StringUtil::CharacterIsNewline((*buffer)[pos_check])) { + break; + } + + if (position_buffer > end_buffer && options.dialect_options.new_line == NewLineIdentifier::CARRY_ON && + (*buffer)[pos_check] == '\n') { + break; + } + idx_t position_set = position_buffer; + start_buffer = position_buffer; + // We check if we can add this line + // disable the projection pushdown while reading the first line + // otherwise the first line parsing can be influenced by which columns we are reading + auto column_ids = std::move(reader_data.column_ids); + auto column_mapping = std::move(reader_data.column_mapping); + InitializeProjection(); + try { + successfully_read_first_line = TryParseSimpleCSV(first_line_chunk, error_message, true); + } catch (...) { + successfully_read_first_line = false; + } + // restore the projection pushdown + reader_data.column_ids = std::move(column_ids); + reader_data.column_mapping = std::move(column_mapping); + end_buffer = end_buffer_real; + start_buffer = position_set; + if (position_buffer >= end_buffer) { + if (successfully_read_first_line) { + position_buffer = position_set; + } + break; + } + position_buffer = position_set; + } + if (verification_positions.beginning_of_first_line == 0) { + verification_positions.beginning_of_first_line = position_buffer; + } + // Ensure that parse_chunk has no gunk when trying to figure new line + parse_chunk.Reset(); + + verification_positions.end_of_last_line = position_buffer; + finished = false; + return successfully_read_first_line; +} + +void ParallelCSVReader::SetBufferRead(unique_ptr buffer_read_p) { + if (!buffer_read_p->buffer) { + throw InternalException("ParallelCSVReader::SetBufferRead - CSVBufferRead does not have a buffer to read"); + } + position_buffer = buffer_read_p->buffer_start; + start_buffer = buffer_read_p->buffer_start; + end_buffer = buffer_read_p->buffer_end; + if (buffer_read_p->next_buffer) { + buffer_size = buffer_read_p->buffer->actual_size + buffer_read_p->next_buffer->actual_size; + } else { + buffer_size = buffer_read_p->buffer->actual_size; + } + buffer = std::move(buffer_read_p); + + reached_remainder_state = false; + verification_positions.beginning_of_first_line = 0; + verification_positions.end_of_last_line = 0; + finished = false; + D_ASSERT(end_buffer <= buffer_size); +} + +VerificationPositions ParallelCSVReader::GetVerificationPositions() { + verification_positions.beginning_of_first_line += buffer->buffer->csv_global_start; + verification_positions.end_of_last_line += buffer->buffer->csv_global_start; + return verification_positions; +} + +// If BufferRemainder returns false, it means we are done scanning this buffer and should go to the end_state +bool ParallelCSVReader::BufferRemainder() { + if (position_buffer >= end_buffer && !reached_remainder_state) { + // First time we finish the buffer piece we should scan here, we set the variables + // to allow this piece to be scanned up to the end of the buffer or the next new line + reached_remainder_state = true; + // end_buffer is allowed to go to buffer size to finish its last line + end_buffer = buffer_size; + } + if (position_buffer >= end_buffer) { + // buffer ends, return false + return false; + } + // we can still scan stuff, return true + return true; +} + +bool AllNewLine(string_t value, idx_t column_amount) { + auto value_str = value.GetString(); + if (value_str.empty() && column_amount == 1) { + // This is a one column (empty) + return false; + } + for (idx_t i = 0; i < value.GetSize(); i++) { + if (!StringUtil::CharacterIsNewline(value_str[i])) { + return false; + } + } + return true; +} + +bool ParallelCSVReader::TryParseSimpleCSV(DataChunk &insert_chunk, string &error_message, bool try_add_line) { + // If line is not set, we have to figure it out, we assume whatever is in the first line + if (options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { + idx_t cur_pos = position_buffer; + // we can start in the middle of a new line, so move a bit forward. + while (cur_pos < end_buffer) { + if (StringUtil::CharacterIsNewline((*buffer)[cur_pos])) { + cur_pos++; + } else { + break; + } + } + for (; cur_pos < end_buffer; cur_pos++) { + if (StringUtil::CharacterIsNewline((*buffer)[cur_pos])) { + bool carriage_return = (*buffer)[cur_pos] == '\r'; + bool carriage_return_followed = false; + cur_pos++; + if (cur_pos < end_buffer) { + if (carriage_return && (*buffer)[cur_pos] == '\n') { + carriage_return_followed = true; + cur_pos++; + } + } + SetNewLineDelimiter(carriage_return, carriage_return_followed); + break; + } + } + } + // used for parsing algorithm + if (start_buffer == buffer_size) { + // Nothing to read + finished = true; + return true; + } + D_ASSERT(end_buffer <= buffer_size); + bool finished_chunk = false; + idx_t column = 0; + idx_t offset = 0; + bool has_quotes = false; + + vector escape_positions; + if ((start_buffer == buffer->buffer_start || start_buffer == buffer->buffer_end) && !try_add_line) { + // First time reading this buffer piece + if (!SetPosition()) { + finished = true; + return true; + } + } + if (position_buffer == buffer_size) { + // Nothing to read + finished = true; + return true; + } + // Keep track of line size + idx_t line_start = position_buffer; + // start parsing the first value + goto value_start; + +value_start : { + /* state: value_start */ + if (!BufferRemainder()) { + goto final_state; + } + offset = 0; + + // this state parses the first character of a value + if ((*buffer)[position_buffer] == options.dialect_options.state_machine_options.quote) { + // quote: actual value starts in the next position + // move to in_quotes state + start_buffer = position_buffer + 1; + goto in_quotes; + } else { + // no quote, move to normal parsing state + start_buffer = position_buffer; + goto normal; + } +}; + +normal : { + /* state: normal parsing state */ + // this state parses the remainder of a non-quoted value until we reach a delimiter or newline + for (; position_buffer < end_buffer; position_buffer++) { + auto c = (*buffer)[position_buffer]; + if (c == options.dialect_options.state_machine_options.delimiter) { + // Check if previous character is a quote, if yes, this means we are in a non-initialized quoted value + // This only matters for when trying to figure out where csv lines start + if (position_buffer > 0 && try_add_line) { + if ((*buffer)[position_buffer - 1] == options.dialect_options.state_machine_options.quote) { + return false; + } + } + // delimiter: end the value and add it to the chunk + goto add_value; + } else if (StringUtil::CharacterIsNewline(c)) { + // Check if previous character is a quote, if yes, this means we are in a non-initialized quoted value + // This only matters for when trying to figure out where csv lines start + if (position_buffer > 0 && try_add_line) { + if ((*buffer)[position_buffer - 1] == options.dialect_options.state_machine_options.quote) { + return false; + } + } + // newline: add row + if (column > 0 || try_add_line || parse_chunk.data.size() == 1) { + goto add_row; + } + if (column == 0 && position_buffer == start_buffer) { + start_buffer++; + } + } + } + if (!BufferRemainder()) { + goto final_state; + } else { + goto normal; + } +}; + +add_value : { + /* state: Add value to string vector */ + AddValue(buffer->GetValue(start_buffer, position_buffer, offset), column, escape_positions, has_quotes, + buffer->local_batch_index); + // increase position by 1 and move start to the new position + offset = 0; + has_quotes = false; + start_buffer = ++position_buffer; + if (!BufferRemainder()) { + goto final_state; + } + goto value_start; +}; + +add_row : { + /* state: Add Row to Parse chunk */ + // check type of newline (\r or \n) + bool carriage_return = (*buffer)[position_buffer] == '\r'; + + AddValue(buffer->GetValue(start_buffer, position_buffer, offset), column, escape_positions, has_quotes, + buffer->local_batch_index); + if (try_add_line) { + bool success = column == insert_chunk.ColumnCount(); + if (success) { + idx_t cur_linenr = linenr; + AddRow(insert_chunk, column, error_message, buffer->local_batch_index); + success = Flush(insert_chunk, buffer->local_batch_index, true); + linenr = cur_linenr; + } + reached_remainder_state = false; + parse_chunk.Reset(); + return success; + } else { + VerifyLineLength(position_buffer - line_start, buffer->batch_index); + line_start = position_buffer; + finished_chunk = AddRow(insert_chunk, column, error_message, buffer->local_batch_index); + } + // increase position by 1 and move start to the new position + offset = 0; + has_quotes = false; + position_buffer++; + start_buffer = position_buffer; + verification_positions.end_of_last_line = position_buffer; + if (carriage_return) { + // \r newline, go to special state that parses an optional \n afterwards + // optionally skips a newline (\n) character, which allows \r\n to be interpreted as a single line + if (!BufferRemainder()) { + goto final_state; + } + if ((*buffer)[position_buffer] == '\n') { + if (options.dialect_options.new_line == NewLineIdentifier::SINGLE) { + error_message = "Wrong NewLine Identifier. Expecting \\r\\n"; + return false; + } + // newline after carriage return: skip + // increase position by 1 and move start to the new position + start_buffer = ++position_buffer; + + SkipEmptyLines(); + verification_positions.end_of_last_line = position_buffer; + start_buffer = position_buffer; + if (reached_remainder_state) { + goto final_state; + } + } else { + if (options.dialect_options.new_line == NewLineIdentifier::CARRY_ON) { + error_message = "Wrong NewLine Identifier. Expecting \\r or \\n"; + return false; + } + } + if (!BufferRemainder()) { + goto final_state; + } + if (reached_remainder_state || finished_chunk) { + goto final_state; + } + goto value_start; + } else { + if (options.dialect_options.new_line == NewLineIdentifier::CARRY_ON) { + error_message = "Wrong NewLine Identifier. Expecting \\r or \\n"; + return false; + } + if (reached_remainder_state) { + goto final_state; + } + if (!BufferRemainder()) { + goto final_state; + } + SkipEmptyLines(); + if (position_buffer - verification_positions.end_of_last_line > options.buffer_size) { + error_message = "Line does not fit in one buffer. Increase the buffer size."; + return false; + } + verification_positions.end_of_last_line = position_buffer; + start_buffer = position_buffer; + // \n newline, move to value start + if (finished_chunk) { + goto final_state; + } + goto value_start; + } +} +in_quotes: + /* state: in_quotes this state parses the remainder of a quoted value*/ + has_quotes = true; + position_buffer++; + for (; position_buffer < end_buffer; position_buffer++) { + auto c = (*buffer)[position_buffer]; + if (c == options.dialect_options.state_machine_options.quote) { + // quote: move to unquoted state + goto unquote; + } else if (c == options.dialect_options.state_machine_options.escape) { + // escape: store the escaped position and move to handle_escape state + escape_positions.push_back(position_buffer - start_buffer); + goto handle_escape; + } + } + if (!BufferRemainder()) { + if (buffer->buffer->is_last_buffer) { + if (try_add_line) { + return false; + } + // still in quoted state at the end of the file or at the end of a buffer when running multithreaded, error: + throw InvalidInputException("Error in file \"%s\" on line %s: unterminated quotes. (%s)", options.file_path, + GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), + options.ToString()); + } else { + goto final_state; + } + } else { + position_buffer--; + goto in_quotes; + } + +unquote : { + /* state: unquote: this state handles the state directly after we unquote*/ + // + // in this state we expect either another quote (entering the quoted state again, and escaping the quote) + // or a delimiter/newline, ending the current value and moving on to the next value + position_buffer++; + if (!BufferRemainder()) { + offset = 1; + goto final_state; + } + auto c = (*buffer)[position_buffer]; + if (c == options.dialect_options.state_machine_options.quote && + (options.dialect_options.state_machine_options.escape == '\0' || + options.dialect_options.state_machine_options.escape == options.dialect_options.state_machine_options.quote)) { + // escaped quote, return to quoted state and store escape position + escape_positions.push_back(position_buffer - start_buffer); + goto in_quotes; + } else if (c == options.dialect_options.state_machine_options.delimiter) { + // delimiter, add value + offset = 1; + goto add_value; + } else if (StringUtil::CharacterIsNewline(c)) { + offset = 1; + // FIXME: should this be an assertion? + D_ASSERT(try_add_line || (!try_add_line && column == parse_chunk.ColumnCount() - 1)); + goto add_row; + } else if (position_buffer >= end_buffer) { + // reached end of buffer + offset = 1; + goto final_state; + } else { + error_message = StringUtil::Format( + "Error in file \"%s\" on line %s: quote should be followed by end of value, end of " + "row or another quote. (%s). ", + options.file_path, GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), + options.ToString()); + return false; + } +} +handle_escape : { + /* state: handle_escape */ + // escape should be followed by a quote or another escape character + position_buffer++; + if (!BufferRemainder()) { + goto final_state; + } + if (position_buffer >= buffer_size && buffer->buffer->is_last_buffer) { + error_message = StringUtil::Format( + "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, + GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), options.ToString()); + return false; + } + if ((*buffer)[position_buffer] != options.dialect_options.state_machine_options.quote && + (*buffer)[position_buffer] != options.dialect_options.state_machine_options.escape) { + error_message = StringUtil::Format( + "Error in file \"%s\" on line %s: neither QUOTE nor ESCAPE is proceeded by ESCAPE. (%s)", options.file_path, + GetLineNumberStr(linenr, linenr_estimated, buffer->local_batch_index).c_str(), options.ToString()); + return false; + } + // escape was followed by quote or escape, go back to quoted state + goto in_quotes; +} +final_state : { + /* state: final_stage reached after we finished reading the end_buffer of the csv buffer */ + // reset end buffer + end_buffer = buffer->buffer_end; + if (position_buffer == end_buffer) { + reached_remainder_state = false; + } + if (finished_chunk) { + if (position_buffer >= end_buffer) { + if (position_buffer == end_buffer && StringUtil::CharacterIsNewline((*buffer)[position_buffer - 1]) && + position_buffer < buffer_size) { + // last position is a new line, we still have to go through one more line of this buffer + finished = false; + } else { + finished = true; + } + } + buffer->lines_read += insert_chunk.size(); + return true; + } + // If this is the last buffer, we have to read the last value + if (buffer->buffer->is_last_buffer || !buffer->next_buffer || + (buffer->next_buffer && buffer->next_buffer->is_last_buffer)) { + if (column > 0 || start_buffer != position_buffer || try_add_line || + (insert_chunk.data.size() == 1 && start_buffer != position_buffer)) { + // remaining values to be added to the chunk + auto str_value = buffer->GetValue(start_buffer, position_buffer, offset); + if (!AllNewLine(str_value, insert_chunk.data.size()) || offset == 0) { + AddValue(str_value, column, escape_positions, has_quotes, buffer->local_batch_index); + if (try_add_line) { + bool success = column == return_types.size(); + if (success) { + auto cur_linenr = linenr; + AddRow(insert_chunk, column, error_message, buffer->local_batch_index); + success = Flush(insert_chunk, buffer->local_batch_index); + linenr = cur_linenr; + } + parse_chunk.Reset(); + reached_remainder_state = false; + return success; + } else { + VerifyLineLength(position_buffer - line_start, buffer->batch_index); + line_start = position_buffer; + AddRow(insert_chunk, column, error_message, buffer->local_batch_index); + if (position_buffer - verification_positions.end_of_last_line > options.buffer_size) { + error_message = "Line does not fit in one buffer. Increase the buffer size."; + return false; + } + verification_positions.end_of_last_line = position_buffer; + } + } + } + } + // flush the parsed chunk and finalize parsing + if (mode == ParserMode::PARSING) { + Flush(insert_chunk, buffer->local_batch_index); + buffer->lines_read += insert_chunk.size(); + } + if (position_buffer - verification_positions.end_of_last_line > options.buffer_size) { + error_message = "Line does not fit in one buffer. Increase the buffer size."; + return false; + } + end_buffer = buffer_size; + SkipEmptyLines(); + end_buffer = buffer->buffer_end; + verification_positions.end_of_last_line = position_buffer; + if (position_buffer >= end_buffer) { + if (position_buffer >= end_buffer) { + if (position_buffer == end_buffer && StringUtil::CharacterIsNewline((*buffer)[position_buffer - 1]) && + position_buffer < buffer_size) { + // last position is a new line, we still have to go through one more line of this buffer + finished = false; + } else { + finished = true; + } + } + } + return true; +}; +} + +void ParallelCSVReader::ParseCSV(DataChunk &insert_chunk) { + string error_message; + if (!TryParseCSV(ParserMode::PARSING, insert_chunk, error_message)) { + throw InvalidInputException(error_message); + } +} + +idx_t ParallelCSVReader::GetLineError(idx_t line_error, idx_t buffer_idx, bool stop_at_first) { + while (true) { + if (buffer->line_info->CanItGetLine(file_idx, buffer_idx)) { + auto cur_start = verification_positions.beginning_of_first_line + buffer->buffer->csv_global_start; + return buffer->line_info->GetLine(buffer_idx, line_error, file_idx, cur_start, false, stop_at_first); + } + } +} + +void ParallelCSVReader::Increment(idx_t buffer_idx) { + return buffer->line_info->Increment(file_idx, buffer_idx); +} + +bool ParallelCSVReader::TryParseCSV(ParserMode mode) { + DataChunk dummy_chunk; + string error_message; + return TryParseCSV(mode, dummy_chunk, error_message); +} + +void ParallelCSVReader::ParseCSV(ParserMode mode) { + DataChunk dummy_chunk; + string error_message; + if (!TryParseCSV(mode, dummy_chunk, error_message)) { + throw InvalidInputException(error_message); + } +} + +bool ParallelCSVReader::TryParseCSV(ParserMode parser_mode, DataChunk &insert_chunk, string &error_message) { + mode = parser_mode; + return TryParseSimpleCSV(insert_chunk, error_message); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp new file mode 100644 index 00000000..cc3fc947 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp @@ -0,0 +1,61 @@ +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" + +namespace duckdb { + +CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr buffer_manager_p, + CSVStateMachineCache &state_machine_cache_p, bool explicit_set_columns_p) + : state_machine_cache(state_machine_cache_p), options(options_p), buffer_manager(std::move(buffer_manager_p)), + explicit_set_columns(explicit_set_columns_p) { + + // Check if any type is BLOB + for (auto &type : options.sql_type_list) { + if (type.id() == LogicalTypeId::BLOB) { + throw InvalidInputException( + "CSV auto-detect for blobs not supported: there may be invalid UTF-8 in the file"); + } + } + + // Initialize Format Candidates + for (const auto &format_template : format_template_candidates) { + auto &logical_type = format_template.first; + best_format_candidates[logical_type].clear(); + } +} + +SnifferResult CSVSniffer::SniffCSV() { + // 1. Dialect Detection + DetectDialect(); + if (explicit_set_columns) { + if (!candidates.empty()) { + options.dialect_options.state_machine_options = candidates[0]->dialect_options.state_machine_options; + options.dialect_options.new_line = candidates[0]->dialect_options.new_line; + } + // We do not need to run type and header detection as these were defined by the user + return SnifferResult(detected_types, names); + } + // 2. Type Detection + DetectTypes(); + // 3. Header Detection + DetectHeader(); + D_ASSERT(best_sql_types_candidates_per_column_idx.size() == names.size()); + // 4. Type Replacement + ReplaceTypes(); + // 5. Type Refinement + RefineTypes(); + // We are done, construct and return the result. + + // Set the CSV Options in the reference + options.dialect_options = best_candidate->dialect_options; + options.has_header = best_candidate->dialect_options.header; + options.skip_rows_set = options.dialect_options.skip_rows > 0; + if (options.has_header) { + options.dialect_options.true_start = best_start_with_header; + } else { + options.dialect_options.true_start = best_start_without_header; + } + + // Return the types and names + return SnifferResult(detected_types, names); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp new file mode 100644 index 00000000..add96c2d --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp @@ -0,0 +1,339 @@ +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +struct SniffDialect { + inline static void Initialize(CSVStateMachine &machine) { + machine.state = CSVState::STANDARD; + machine.previous_state = CSVState::STANDARD; + machine.pre_previous_state = CSVState::STANDARD; + machine.cur_rows = 0; + machine.column_count = 1; + } + + inline static bool Process(CSVStateMachine &machine, vector &sniffed_column_counts, char current_char, + idx_t current_pos) { + + D_ASSERT(sniffed_column_counts.size() == STANDARD_VECTOR_SIZE); + + if (machine.state == CSVState::INVALID) { + sniffed_column_counts.clear(); + return true; + } + machine.pre_previous_state = machine.previous_state; + machine.previous_state = machine.state; + + machine.state = static_cast( + machine.transition_array[static_cast(machine.state)][static_cast(current_char)]); + + bool carriage_return = machine.previous_state == CSVState::CARRIAGE_RETURN; + machine.column_count += machine.previous_state == CSVState::DELIMITER; + sniffed_column_counts[machine.cur_rows] = machine.column_count; + machine.cur_rows += + machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE; + machine.column_count -= (machine.column_count - 1) * (machine.previous_state == CSVState::RECORD_SEPARATOR); + + // It means our carriage return is actually a record separator + machine.cur_rows += machine.state != CSVState::RECORD_SEPARATOR && carriage_return; + machine.column_count -= + (machine.column_count - 1) * (machine.state != CSVState::RECORD_SEPARATOR && carriage_return); + + // Identify what is our line separator + machine.carry_on_separator = + (machine.state == CSVState::RECORD_SEPARATOR && carriage_return) || machine.carry_on_separator; + machine.single_record_separator = ((machine.state != CSVState::RECORD_SEPARATOR && carriage_return) || + (machine.state == CSVState::RECORD_SEPARATOR && !carriage_return)) || + machine.single_record_separator; + if (machine.cur_rows >= STANDARD_VECTOR_SIZE) { + // We sniffed enough rows + return true; + } + return false; + } + inline static void Finalize(CSVStateMachine &machine, vector &sniffed_column_counts) { + if (machine.state == CSVState::INVALID) { + return; + } + if (machine.cur_rows < STANDARD_VECTOR_SIZE && machine.state == CSVState::DELIMITER) { + sniffed_column_counts[machine.cur_rows] = ++machine.column_count; + } + if (machine.cur_rows < STANDARD_VECTOR_SIZE && machine.state != CSVState::EMPTY_LINE) { + sniffed_column_counts[machine.cur_rows++] = machine.column_count; + } + NewLineIdentifier suggested_newline; + if (machine.carry_on_separator) { + if (machine.single_record_separator) { + suggested_newline = NewLineIdentifier::MIX; + } else { + suggested_newline = NewLineIdentifier::CARRY_ON; + } + } else { + suggested_newline = NewLineIdentifier::SINGLE; + } + if (machine.options.dialect_options.new_line == NewLineIdentifier::NOT_SET) { + machine.dialect_options.new_line = suggested_newline; + } else { + if (machine.options.dialect_options.new_line != suggested_newline) { + // Invalidate this whole detection + machine.cur_rows = 0; + } + } + sniffed_column_counts.erase(sniffed_column_counts.begin() + machine.cur_rows, sniffed_column_counts.end()); + } +}; + +void CSVSniffer::GenerateCandidateDetectionSearchSpace(vector &delim_candidates, + vector "erule_candidates, + unordered_map> "e_candidates_map, + unordered_map> &escape_candidates_map) { + if (options.has_delimiter) { + // user provided a delimiter: use that delimiter + delim_candidates = {options.dialect_options.state_machine_options.delimiter}; + } else { + // no delimiter provided: try standard/common delimiters + delim_candidates = {',', '|', ';', '\t'}; + } + if (options.has_quote) { + // user provided quote: use that quote rule + quote_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {options.dialect_options.state_machine_options.quote}; + quote_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {options.dialect_options.state_machine_options.quote}; + quote_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {options.dialect_options.state_machine_options.quote}; + } else { + // no quote rule provided: use standard/common quotes + quote_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {'\"'}; + quote_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {'\"', '\''}; + quote_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {'\0'}; + } + if (options.has_escape) { + // user provided escape: use that escape rule + if (options.dialect_options.state_machine_options.escape == '\0') { + quoterule_candidates = {QuoteRule::QUOTES_RFC}; + } else { + quoterule_candidates = {QuoteRule::QUOTES_OTHER}; + } + escape_candidates_map[(uint8_t)quoterule_candidates[0]] = { + options.dialect_options.state_machine_options.escape}; + } else { + // no escape provided: try standard/common escapes + quoterule_candidates = {QuoteRule::QUOTES_RFC, QuoteRule::QUOTES_OTHER, QuoteRule::NO_QUOTES}; + } +} + +void CSVSniffer::GenerateStateMachineSearchSpace(vector> &csv_state_machines, + const vector &delimiter_candidates, + const vector "erule_candidates, + const unordered_map> "e_candidates_map, + const unordered_map> &escape_candidates_map) { + // Generate state machines for all option combinations + for (const auto quoterule : quoterule_candidates) { + const auto "e_candidates = quote_candidates_map.at((uint8_t)quoterule); + for (const auto "e : quote_candidates) { + for (const auto &delimiter : delimiter_candidates) { + const auto &escape_candidates = escape_candidates_map.at((uint8_t)quoterule); + for (const auto &escape : escape_candidates) { + D_ASSERT(buffer_manager); + CSVStateMachineOptions state_machine_options(delimiter, quote, escape); + csv_state_machines.emplace_back(make_uniq(options, state_machine_options, + buffer_manager, state_machine_cache)); + } + } + } + } +} + +void CSVSniffer::AnalyzeDialectCandidate(unique_ptr state_machine, idx_t &rows_read, + idx_t &best_consistent_rows, idx_t &prev_padding_count) { + // The sniffed_column_counts variable keeps track of the number of columns found for each row + vector sniffed_column_counts(STANDARD_VECTOR_SIZE); + + state_machine->csv_buffer_iterator.Process(*state_machine, sniffed_column_counts); + idx_t start_row = options.dialect_options.skip_rows; + idx_t consistent_rows = 0; + idx_t num_cols = sniffed_column_counts.empty() ? 0 : sniffed_column_counts[0]; + idx_t padding_count = 0; + bool allow_padding = options.null_padding; + if (sniffed_column_counts.size() > rows_read) { + rows_read = sniffed_column_counts.size(); + } + for (idx_t row = 0; row < sniffed_column_counts.size(); row++) { + if (sniffed_column_counts[row] == num_cols) { + consistent_rows++; + } else if (num_cols < sniffed_column_counts[row] && !options.skip_rows_set) { + // all rows up to this point will need padding + padding_count = 0; + // we use the maximum amount of num_cols that we find + num_cols = sniffed_column_counts[row]; + start_row = row + options.dialect_options.skip_rows; + consistent_rows = 1; + + } else if (num_cols >= sniffed_column_counts[row]) { + // we are missing some columns, we can parse this as long as we add padding + padding_count++; + } + } + + // Calculate the total number of consistent rows after adding padding. + consistent_rows += padding_count; + + // Whether there are more values (rows) available that are consistent, exceeding the current best. + bool more_values = (consistent_rows > best_consistent_rows && num_cols >= max_columns_found); + + // If additional padding is required when compared to the previous padding count. + bool require_more_padding = padding_count > prev_padding_count; + + // If less padding is now required when compared to the previous padding count. + bool require_less_padding = padding_count < prev_padding_count; + + // If there was only a single column before, and the new number of columns exceeds that. + bool single_column_before = max_columns_found < 2 && num_cols > max_columns_found; + + // If the number of rows is consistent with the calculated value after accounting for skipped rows and the + // start row. + bool rows_consistent = + start_row + consistent_rows - options.dialect_options.skip_rows == sniffed_column_counts.size(); + + // If there are more than one consistent row. + bool more_than_one_row = (consistent_rows > 1); + + // If there are more than one column. + bool more_than_one_column = (num_cols > 1); + + // If the start position is valid. + bool start_good = !candidates.empty() && (start_row <= candidates.front()->start_row); + + // If padding happened but it is not allowed. + bool invalid_padding = !allow_padding && padding_count > 0; + + // If rows are consistent and no invalid padding happens, this is the best suitable candidate if one of the + // following is valid: + // - There's a single column before. + // - There are more values and no additional padding is required. + // - There's more than one column and less padding is required. + if (rows_consistent && + (single_column_before || (more_values && !require_more_padding) || + (more_than_one_column && require_less_padding)) && + !invalid_padding) { + best_consistent_rows = consistent_rows; + max_columns_found = num_cols; + prev_padding_count = padding_count; + state_machine->start_row = start_row; + candidates.clear(); + state_machine->dialect_options.num_cols = num_cols; + candidates.emplace_back(std::move(state_machine)); + return; + } + // If there's more than one row and column, the start is good, rows are consistent, + // no additional padding is required, and there is no invalid padding, and there is not yet a candidate + // with the same quote, we add this state_machine as a suitable candidate. + if (more_than_one_row && more_than_one_column && start_good && rows_consistent && !require_more_padding && + !invalid_padding) { + bool same_quote_is_candidate = false; + for (auto &candidate : candidates) { + if (state_machine->dialect_options.state_machine_options.quote == + candidate->dialect_options.state_machine_options.quote) { + same_quote_is_candidate = true; + } + } + if (!same_quote_is_candidate) { + state_machine->start_row = start_row; + state_machine->dialect_options.num_cols = num_cols; + candidates.emplace_back(std::move(state_machine)); + } + } +} + +bool CSVSniffer::RefineCandidateNextChunk(CSVStateMachine &candidate) { + vector sniffed_column_counts(STANDARD_VECTOR_SIZE); + candidate.csv_buffer_iterator.Process(candidate, sniffed_column_counts); + bool allow_padding = options.null_padding; + + for (idx_t row = 0; row < sniffed_column_counts.size(); row++) { + if (max_columns_found != sniffed_column_counts[row] && !allow_padding) { + return false; + } + } + return true; +} + +void CSVSniffer::RefineCandidates() { + // It's very frequent that more than one dialect can parse a csv file, hence here we run one state machine + // fully on the whole sample dataset, when/if it fails we go to the next one. + if (candidates.empty()) { + // No candidates to refine + return; + } + if (candidates.size() == 1 || candidates[0]->csv_buffer_iterator.Finished()) { + // Only one candidate nothing to refine or all candidates already checked + return; + } + for (auto &cur_candidate : candidates) { + for (idx_t i = 1; i <= options.sample_size_chunks; i++) { + bool finished_file = cur_candidate->csv_buffer_iterator.Finished(); + if (finished_file || i == options.sample_size_chunks) { + // we finished the file or our chunk sample successfully: stop + auto successful_candidate = std::move(cur_candidate); + candidates.clear(); + candidates.emplace_back(std::move(successful_candidate)); + return; + } + cur_candidate->cur_rows = 0; + cur_candidate->column_count = 1; + if (!RefineCandidateNextChunk(*cur_candidate)) { + // This candidate failed, move to the next one + break; + } + } + } + candidates.clear(); + return; +} + +// Dialect Detection consists of five steps: +// 1. Generate a search space of all possible dialects +// 2. Generate a state machine for each dialect +// 3. Analyze the first chunk of the file and find the best dialect candidates +// 4. Analyze the remaining chunks of the file and find the best dialect candidate +void CSVSniffer::DetectDialect() { + // Variables for Dialect Detection + // Candidates for the delimiter + vector delim_candidates; + // Quote-Rule Candidates + vector quoterule_candidates; + // Candidates for the quote option + unordered_map> quote_candidates_map; + // Candidates for the escape option + unordered_map> escape_candidates_map; + escape_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {'\0', '\"', '\''}; + escape_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {'\\'}; + escape_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {'\0'}; + // Number of rows read + idx_t rows_read = 0; + // Best Number of consistent rows (i.e., presenting all columns) + idx_t best_consistent_rows = 0; + // If padding was necessary (i.e., rows are missing some columns, how many) + idx_t prev_padding_count = 0; + // Vector of CSV State Machines + vector> csv_state_machines; + + // Step 1: Generate search space + GenerateCandidateDetectionSearchSpace(delim_candidates, quoterule_candidates, quote_candidates_map, + escape_candidates_map); + // Step 2: Generate state machines + GenerateStateMachineSearchSpace(csv_state_machines, delim_candidates, quoterule_candidates, quote_candidates_map, + escape_candidates_map); + // Step 3: Analyze all candidates on the first chunk + for (auto &state_machine : csv_state_machines) { + state_machine->Reset(); + AnalyzeDialectCandidate(std::move(state_machine), rows_read, best_consistent_rows, prev_padding_count); + } + // Step 4: Loop over candidates and find if they can still produce good results for the remaining chunks + RefineCandidates(); + // if no dialect candidate was found, we throw an exception + if (candidates.empty()) { + throw InvalidInputException( + "Error in file \"%s\": CSV options could not be auto-detected. Consider setting parser options manually.", + options.file_path); + } +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp new file mode 100644 index 00000000..152f9baf --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp @@ -0,0 +1,171 @@ +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "utf8proc.hpp" + +namespace duckdb { + +// Helper function to generate column names +static string GenerateColumnName(const idx_t total_cols, const idx_t col_number, const string &prefix = "column") { + int max_digits = NumericHelper::UnsignedLength(total_cols - 1); + int digits = NumericHelper::UnsignedLength(col_number); + string leading_zeros = string(max_digits - digits, '0'); + string value = to_string(col_number); + return string(prefix + leading_zeros + value); +} + +// Helper function for UTF-8 aware space trimming +static string TrimWhitespace(const string &col_name) { + utf8proc_int32_t codepoint; + auto str = reinterpret_cast(col_name.c_str()); + idx_t size = col_name.size(); + // Find the first character that is not left trimmed + idx_t begin = 0; + while (begin < size) { + auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); + D_ASSERT(bytes > 0); + if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { + break; + } + begin += bytes; + } + + // Find the last character that is not right trimmed + idx_t end; + end = begin; + for (auto next = begin; next < col_name.size();) { + auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); + D_ASSERT(bytes > 0); + next += bytes; + if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { + end = next; + } + } + + // return the trimmed string + return col_name.substr(begin, end - begin); +} + +static string NormalizeColumnName(const string &col_name) { + // normalize UTF8 characters to NFKD + auto nfkd = utf8proc_NFKD(reinterpret_cast(col_name.c_str()), col_name.size()); + const string col_name_nfkd = string(const_char_ptr_cast(nfkd), strlen(const_char_ptr_cast(nfkd))); + free(nfkd); + + // only keep ASCII characters 0-9 a-z A-Z and replace spaces with regular whitespace + string col_name_ascii = ""; + for (idx_t i = 0; i < col_name_nfkd.size(); i++) { + if (col_name_nfkd[i] == '_' || (col_name_nfkd[i] >= '0' && col_name_nfkd[i] <= '9') || + (col_name_nfkd[i] >= 'A' && col_name_nfkd[i] <= 'Z') || + (col_name_nfkd[i] >= 'a' && col_name_nfkd[i] <= 'z')) { + col_name_ascii += col_name_nfkd[i]; + } else if (StringUtil::CharacterIsSpace(col_name_nfkd[i])) { + col_name_ascii += " "; + } + } + + // trim whitespace and replace remaining whitespace by _ + string col_name_trimmed = TrimWhitespace(col_name_ascii); + string col_name_cleaned = ""; + bool in_whitespace = false; + for (idx_t i = 0; i < col_name_trimmed.size(); i++) { + if (col_name_trimmed[i] == ' ') { + if (!in_whitespace) { + col_name_cleaned += "_"; + in_whitespace = true; + } + } else { + col_name_cleaned += col_name_trimmed[i]; + in_whitespace = false; + } + } + + // don't leave string empty; if not empty, make lowercase + if (col_name_cleaned.empty()) { + col_name_cleaned = "_"; + } else { + col_name_cleaned = StringUtil::Lower(col_name_cleaned); + } + + // prepend _ if name starts with a digit or is a reserved keyword + if (KeywordHelper::IsKeyword(col_name_cleaned) || (col_name_cleaned[0] >= '0' && col_name_cleaned[0] <= '9')) { + col_name_cleaned = "_" + col_name_cleaned; + } + return col_name_cleaned; +} +void CSVSniffer::DetectHeader() { + // information for header detection + bool first_row_consistent = true; + // check if header row is all null and/or consistent with detected column data types + bool first_row_nulls = true; + // This case will fail in dialect detection, so we assert here just for sanity + D_ASSERT(best_candidate->options.null_padding || + best_sql_types_candidates_per_column_idx.size() == best_header_row.size()); + for (idx_t col = 0; col < best_header_row.size(); col++) { + auto dummy_val = best_header_row[col]; + if (!dummy_val.IsNull()) { + first_row_nulls = false; + } + + // try cast to sql_type of column + const auto &sql_type = best_sql_types_candidates_per_column_idx[col].back(); + if (!TryCastValue(*best_candidate, dummy_val, sql_type)) { + first_row_consistent = false; + } + } + bool has_header; + if (!best_candidate->options.has_header) { + has_header = !first_row_consistent || first_row_nulls; + } else { + has_header = best_candidate->options.dialect_options.header; + } + // update parser info, and read, generate & set col_names based on previous findings + if (has_header) { + best_candidate->dialect_options.header = true; + case_insensitive_map_t name_collision_count; + + // get header names from CSV + for (idx_t col = 0; col < best_header_row.size(); col++) { + const auto &val = best_header_row[col]; + string col_name = val.ToString(); + + // generate name if field is empty + if (col_name.empty() || val.IsNull()) { + col_name = GenerateColumnName(best_candidate->dialect_options.num_cols, col); + } + + // normalize names or at least trim whitespace + if (best_candidate->options.normalize_names) { + col_name = NormalizeColumnName(col_name); + } else { + col_name = TrimWhitespace(col_name); + } + + // avoid duplicate header names + while (name_collision_count.find(col_name) != name_collision_count.end()) { + name_collision_count[col_name] += 1; + col_name = col_name + "_" + to_string(name_collision_count[col_name]); + } + names.push_back(col_name); + name_collision_count[col_name] = 0; + } + if (best_header_row.size() < best_candidate->dialect_options.num_cols && options.null_padding) { + for (idx_t col = best_header_row.size(); col < best_candidate->dialect_options.num_cols; col++) { + names.push_back(GenerateColumnName(best_candidate->dialect_options.num_cols, col)); + } + } else if (best_header_row.size() < best_candidate->dialect_options.num_cols) { + throw InternalException("Detected header has number of columns inferior to dialect detection"); + } + + } else { + best_candidate->dialect_options.header = false; + for (idx_t col = 0; col < best_candidate->dialect_options.num_cols; col++) { + names.push_back(GenerateColumnName(best_candidate->dialect_options.num_cols, col)); + } + } + + // If the user provided names, we must replace our header with the user provided names + for (idx_t i = 0; i < MinValue(names.size(), best_candidate->options.name_list.size()); i++) { + names[i] = best_candidate->options.name_list[i]; + } +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp new file mode 100644 index 00000000..c7a300cc --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp @@ -0,0 +1,415 @@ +#include "duckdb/common/operator/decimal_cast_operators.hpp" +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/string.hpp" + +namespace duckdb { +struct TryCastFloatingOperator { + template + static bool Operation(string_t input) { + T result; + string error_message; + return OP::Operation(input, result, &error_message); + } +}; + +struct TupleSniffing { + idx_t line_number; + idx_t position; + bool set = false; + vector values; +}; + +static bool StartsWithNumericDate(string &separator, const string &value) { + auto begin = value.c_str(); + auto end = begin + value.size(); + + // StrpTimeFormat::Parse will skip whitespace, so we can too + auto field1 = std::find_if_not(begin, end, StringUtil::CharacterIsSpace); + if (field1 == end) { + return false; + } + + // first numeric field must start immediately + if (!StringUtil::CharacterIsDigit(*field1)) { + return false; + } + auto literal1 = std::find_if_not(field1, end, StringUtil::CharacterIsDigit); + if (literal1 == end) { + return false; + } + + // second numeric field must exist + auto field2 = std::find_if(literal1, end, StringUtil::CharacterIsDigit); + if (field2 == end) { + return false; + } + auto literal2 = std::find_if_not(field2, end, StringUtil::CharacterIsDigit); + if (literal2 == end) { + return false; + } + + // third numeric field must exist + auto field3 = std::find_if(literal2, end, StringUtil::CharacterIsDigit); + if (field3 == end) { + return false; + } + + // second literal must match first + if (((field3 - literal2) != (field2 - literal1)) || strncmp(literal1, literal2, (field2 - literal1)) != 0) { + return false; + } + + // copy the literal as the separator, escaping percent signs + separator.clear(); + while (literal1 < field2) { + const auto literal_char = *literal1++; + if (literal_char == '%') { + separator.push_back(literal_char); + } + separator.push_back(literal_char); + } + + return true; +} + +string GenerateDateFormat(const string &separator, const char *format_template) { + string format_specifier = format_template; + auto amount_of_dashes = std::count(format_specifier.begin(), format_specifier.end(), '-'); + // All our date formats must have at least one - + D_ASSERT(amount_of_dashes); + string result; + result.reserve(format_specifier.size() - amount_of_dashes + (amount_of_dashes * separator.size())); + for (auto &character : format_specifier) { + if (character == '-') { + result += separator; + } else { + result += character; + } + } + return result; +} + +bool CSVSniffer::TryCastValue(CSVStateMachine &candidate, const Value &value, const LogicalType &sql_type) { + if (value.IsNull()) { + return true; + } + if (candidate.dialect_options.has_format.find(LogicalTypeId::DATE)->second && + sql_type.id() == LogicalTypeId::DATE) { + date_t result; + string error_message; + return candidate.dialect_options.date_format.find(LogicalTypeId::DATE) + ->second.TryParseDate(string_t(StringValue::Get(value)), result, error_message); + } + if (candidate.dialect_options.has_format.find(LogicalTypeId::TIMESTAMP)->second && + sql_type.id() == LogicalTypeId::TIMESTAMP) { + timestamp_t result; + string error_message; + return candidate.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP) + ->second.TryParseTimestamp(string_t(StringValue::Get(value)), result, error_message); + } + if (candidate.options.decimal_separator != "." && (sql_type.id() == LogicalTypeId::DOUBLE)) { + return TryCastFloatingOperator::Operation(StringValue::Get(value)); + } + Value new_value; + string error_message; + return value.TryCastAs(buffer_manager->context, sql_type, new_value, &error_message, true); +} + +void CSVSniffer::SetDateFormat(CSVStateMachine &candidate, const string &format_specifier, + const LogicalTypeId &sql_type) { + candidate.dialect_options.has_format[sql_type] = true; + auto &date_format = candidate.dialect_options.date_format[sql_type]; + date_format.format_specifier = format_specifier; + StrTimeFormat::ParseFormatSpecifier(date_format.format_specifier, date_format); +} + +struct SniffValue { + inline static void Initialize(CSVStateMachine &machine) { + machine.state = CSVState::STANDARD; + machine.previous_state = CSVState::STANDARD; + machine.pre_previous_state = CSVState::STANDARD; + machine.cur_rows = 0; + machine.value = ""; + machine.rows_read = 0; + } + + inline static bool Process(CSVStateMachine &machine, vector &sniffed_values, char current_char, + idx_t current_pos) { + + if ((machine.dialect_options.new_line == NewLineIdentifier::SINGLE && + (current_char == '\r' || current_char == '\n')) || + (machine.dialect_options.new_line == NewLineIdentifier::CARRY_ON && current_char == '\n')) { + machine.rows_read++; + } + + if ((machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || + (machine.state != CSVState::RECORD_SEPARATOR && machine.previous_state == CSVState::CARRIAGE_RETURN)) { + sniffed_values[machine.cur_rows].position = machine.line_start_pos; + sniffed_values[machine.cur_rows].set = true; + machine.line_start_pos = current_pos; + } + machine.pre_previous_state = machine.previous_state; + machine.previous_state = machine.state; + machine.state = static_cast( + machine.transition_array[static_cast(machine.state)][static_cast(current_char)]); + + bool carriage_return = machine.previous_state == CSVState::CARRIAGE_RETURN; + if (machine.previous_state == CSVState::DELIMITER || + (machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || + (machine.state != CSVState::RECORD_SEPARATOR && carriage_return)) { + // Started a new value + // Check if it's UTF-8 + machine.VerifyUTF8(); + if (machine.value.empty() || machine.value == machine.options.null_str) { + // We set empty == null value + sniffed_values[machine.cur_rows].values.push_back(Value(LogicalType::VARCHAR)); + } else { + sniffed_values[machine.cur_rows].values.push_back(Value(machine.value)); + } + sniffed_values[machine.cur_rows].line_number = machine.rows_read; + + machine.value = ""; + } + if (machine.state == CSVState::STANDARD || + (machine.state == CSVState::QUOTED && machine.previous_state == CSVState::QUOTED)) { + machine.value += current_char; + } + machine.cur_rows += + machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE; + // It means our carriage return is actually a record separator + machine.cur_rows += machine.state != CSVState::RECORD_SEPARATOR && carriage_return; + if (machine.cur_rows >= sniffed_values.size()) { + // We sniffed enough rows + return true; + } + return false; + } + + inline static void Finalize(CSVStateMachine &machine, vector &sniffed_values) { + if (machine.cur_rows < sniffed_values.size() && machine.state == CSVState::DELIMITER) { + // Started a new empty value + sniffed_values[machine.cur_rows].values.push_back(Value(machine.value)); + } + if (machine.cur_rows < sniffed_values.size() && machine.state != CSVState::EMPTY_LINE) { + machine.VerifyUTF8(); + sniffed_values[machine.cur_rows].line_number = machine.rows_read; + if (!sniffed_values[machine.cur_rows].set) { + sniffed_values[machine.cur_rows].position = machine.line_start_pos; + sniffed_values[machine.cur_rows].set = true; + } + + sniffed_values[machine.cur_rows++].values.push_back(Value(machine.value)); + } + sniffed_values.erase(sniffed_values.end() - (sniffed_values.size() - machine.cur_rows), sniffed_values.end()); + } +}; + +void CSVSniffer::DetectDateAndTimeStampFormats(CSVStateMachine &candidate, + map &has_format_candidates, + map> &format_candidates, + const LogicalType &sql_type, const string &separator, Value &dummy_val) { + // generate date format candidates the first time through + auto &type_format_candidates = format_candidates[sql_type.id()]; + const auto had_format_candidates = has_format_candidates[sql_type.id()]; + if (!has_format_candidates[sql_type.id()]) { + has_format_candidates[sql_type.id()] = true; + // order by preference + auto entry = format_template_candidates.find(sql_type.id()); + if (entry != format_template_candidates.end()) { + const auto &format_template_list = entry->second; + for (const auto &t : format_template_list) { + const auto format_string = GenerateDateFormat(separator, t); + // don't parse ISO 8601 + if (format_string.find("%Y-%m-%d") == string::npos) { + type_format_candidates.emplace_back(format_string); + } + } + } + // initialise the first candidate + candidate.dialect_options.has_format[sql_type.id()] = true; + // all formats are constructed to be valid + SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); + } + // check all formats and keep the first one that works + StrpTimeFormat::ParseResult result; + auto save_format_candidates = type_format_candidates; + while (!type_format_candidates.empty()) { + // avoid using exceptions for flow control... + auto ¤t_format = candidate.dialect_options.date_format[sql_type.id()]; + if (current_format.Parse(StringValue::Get(dummy_val), result)) { + break; + } + // doesn't work - move to the next one + type_format_candidates.pop_back(); + candidate.dialect_options.has_format[sql_type.id()] = (!type_format_candidates.empty()); + if (!type_format_candidates.empty()) { + SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); + } + } + // if none match, then this is not a value of type sql_type, + if (type_format_candidates.empty()) { + // so restore the candidates that did work. + // or throw them out if they were generated by this value. + if (had_format_candidates) { + type_format_candidates.swap(save_format_candidates); + if (!type_format_candidates.empty()) { + SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); + } + } else { + has_format_candidates[sql_type.id()] = false; + } + } +} + +void CSVSniffer::DetectTypes() { + idx_t min_varchar_cols = max_columns_found + 1; + vector return_types; + // check which info candidate leads to minimum amount of non-varchar columns... + for (auto &candidate : candidates) { + unordered_map> info_sql_types_candidates; + for (idx_t i = 0; i < candidate->dialect_options.num_cols; i++) { + info_sql_types_candidates[i] = candidate->options.auto_type_candidates; + } + map has_format_candidates; + map> format_candidates; + for (const auto &t : format_template_candidates) { + has_format_candidates[t.first] = false; + format_candidates[t.first].clear(); + } + D_ASSERT(candidate->dialect_options.num_cols > 0); + + // Set all return_types to VARCHAR so we can do datatype detection based on VARCHAR values + return_types.clear(); + return_types.assign(candidate->dialect_options.num_cols, LogicalType::VARCHAR); + + // Reset candidate for parsing + candidate->Reset(); + + // Parse chunk and read csv with info candidate + vector tuples(STANDARD_VECTOR_SIZE); + candidate->csv_buffer_iterator.Process(*candidate, tuples); + // Potentially Skip empty rows (I find this dirty, but it is what the original code does) + // The true line where parsing starts in reference to the csv file + idx_t true_line_start = 0; + idx_t true_pos = 0; + // The start point of the tuples + idx_t tuple_true_start = 0; + while (tuple_true_start < tuples.size()) { + if (tuples[tuple_true_start].values.empty() || + (tuples[tuple_true_start].values.size() == 1 && tuples[tuple_true_start].values[0].IsNull())) { + true_line_start = tuples[tuple_true_start].line_number; + true_pos = tuples[tuple_true_start].position; + tuple_true_start++; + } else { + break; + } + } + + // Potentially Skip Notes (I also find this dirty, but it is what the original code does) + while (tuple_true_start < tuples.size()) { + if (tuples[tuple_true_start].values.size() < max_columns_found && !options.null_padding) { + true_line_start = tuples[tuple_true_start].line_number; + true_pos = tuples[tuple_true_start].position; + tuple_true_start++; + } else { + break; + } + } + if (tuple_true_start < tuples.size()) { + true_pos = tuples[tuple_true_start].position; + } + if (tuple_true_start > 0) { + tuples.erase(tuples.begin(), tuples.begin() + tuple_true_start); + } + + idx_t row_idx = 0; + if (tuples.size() > 1 && (!options.has_header || (options.has_header && options.dialect_options.header))) { + // This means we have more than one row, hence we can use the first row to detect if we have a header + row_idx = 1; + } + if (!tuples.empty()) { + best_start_without_header = tuples[0].position - true_pos; + } + + // First line where we start our type detection + const idx_t start_idx_detection = row_idx; + for (; row_idx < tuples.size(); row_idx++) { + for (idx_t col = 0; col < tuples[row_idx].values.size(); col++) { + auto &col_type_candidates = info_sql_types_candidates[col]; + // col_type_candidates can't be empty since anything in a CSV file should at least be a string + // and we validate utf-8 compatibility when creating the type + D_ASSERT(!col_type_candidates.empty()); + auto cur_top_candidate = col_type_candidates.back(); + auto dummy_val = tuples[row_idx].values[col]; + // try cast from string to sql_type + while (col_type_candidates.size() > 1) { + const auto &sql_type = col_type_candidates.back(); + // try formatting for date types if the user did not specify one and it starts with numeric values. + string separator; + bool has_format_is_set = false; + auto format_iterator = candidate->dialect_options.has_format.find(sql_type.id()); + if (format_iterator != candidate->dialect_options.has_format.end()) { + has_format_is_set = format_iterator->second; + } + if (has_format_candidates.count(sql_type.id()) && + (!has_format_is_set || format_candidates[sql_type.id()].size() > 1) && !dummy_val.IsNull() && + StartsWithNumericDate(separator, StringValue::Get(dummy_val))) { + DetectDateAndTimeStampFormats(*candidate, has_format_candidates, format_candidates, sql_type, + separator, dummy_val); + } + // try cast from string to sql_type + if (TryCastValue(*candidate, dummy_val, sql_type)) { + break; + } else { + if (row_idx != start_idx_detection && cur_top_candidate == LogicalType::BOOLEAN) { + // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we + // immediately pop to varchar. + while (col_type_candidates.back() != LogicalType::VARCHAR) { + col_type_candidates.pop_back(); + } + break; + } + col_type_candidates.pop_back(); + } + } + } + } + + idx_t varchar_cols = 0; + + for (idx_t col = 0; col < info_sql_types_candidates.size(); col++) { + auto &col_type_candidates = info_sql_types_candidates[col]; + // check number of varchar columns + const auto &col_type = col_type_candidates.back(); + if (col_type == LogicalType::VARCHAR) { + varchar_cols++; + } + } + + // it's good if the dialect creates more non-varchar columns, but only if we sacrifice < 30% of best_num_cols. + if (varchar_cols < min_varchar_cols && info_sql_types_candidates.size() > (max_columns_found * 0.7)) { + // we have a new best_options candidate + if (true_line_start > 0) { + // Add empty rows to skip_rows + candidate->dialect_options.skip_rows += true_line_start; + } + best_candidate = std::move(candidate); + min_varchar_cols = varchar_cols; + best_sql_types_candidates_per_column_idx = info_sql_types_candidates; + best_format_candidates = format_candidates; + best_header_row = tuples[0].values; + best_start_with_header = tuples[0].position - true_pos; + } + } + // Assert that it's all good at this point. + D_ASSERT(best_candidate && !best_format_candidates.empty() && !best_header_row.empty()); + + for (const auto &best : best_format_candidates) { + if (!best.second.empty()) { + SetDateFormat(*best_candidate, best.second.back(), best.first); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp new file mode 100644 index 00000000..66f2547a --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp @@ -0,0 +1,196 @@ +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/execution/operator/scan/csv/base_csv_reader.hpp" +namespace duckdb { +struct Parse { + inline static void Initialize(CSVStateMachine &machine) { + machine.state = CSVState::STANDARD; + machine.previous_state = CSVState::STANDARD; + machine.pre_previous_state = CSVState::STANDARD; + + machine.cur_rows = 0; + machine.column_count = 0; + machine.value = ""; + } + + inline static bool Process(CSVStateMachine &machine, DataChunk &parse_chunk, char current_char, idx_t current_pos) { + + machine.pre_previous_state = machine.previous_state; + machine.previous_state = machine.state; + machine.state = static_cast( + machine.transition_array[static_cast(machine.state)][static_cast(current_char)]); + + bool carriage_return = machine.previous_state == CSVState::CARRIAGE_RETURN; + if (machine.previous_state == CSVState::DELIMITER || + (machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || + (machine.state != CSVState::RECORD_SEPARATOR && carriage_return)) { + // Started a new value + // Check if it's UTF-8 (Or not?) + machine.VerifyUTF8(); + auto &v = parse_chunk.data[machine.column_count++]; + auto parse_data = FlatVector::GetData(v); + auto &validity_mask = FlatVector::Validity(v); + if (machine.value.empty()) { + validity_mask.SetInvalid(machine.cur_rows); + } else { + parse_data[machine.cur_rows] = StringVector::AddStringOrBlob(v, string_t(machine.value)); + } + machine.value = ""; + } + if (((machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE) || + (machine.state != CSVState::RECORD_SEPARATOR && carriage_return)) && + machine.options.null_padding && machine.column_count < parse_chunk.ColumnCount()) { + // It's a new row, check if we need to pad stuff + while (machine.column_count < parse_chunk.ColumnCount()) { + auto &v = parse_chunk.data[machine.column_count++]; + auto &validity_mask = FlatVector::Validity(v); + validity_mask.SetInvalid(machine.cur_rows); + } + } + if (machine.state == CSVState::STANDARD || + (machine.state == CSVState::QUOTED && machine.previous_state == CSVState::QUOTED)) { + machine.value += current_char; + } + machine.cur_rows += + machine.previous_state == CSVState::RECORD_SEPARATOR && machine.state != CSVState::EMPTY_LINE; + machine.column_count -= machine.column_count * (machine.previous_state == CSVState::RECORD_SEPARATOR); + + // It means our carriage return is actually a record separator + machine.cur_rows += machine.state != CSVState::RECORD_SEPARATOR && carriage_return; + machine.column_count -= machine.column_count * (machine.state != CSVState::RECORD_SEPARATOR && carriage_return); + + if (machine.cur_rows >= STANDARD_VECTOR_SIZE) { + // We sniffed enough rows + return true; + } + return false; + } + + inline static void Finalize(CSVStateMachine &machine, DataChunk &parse_chunk) { + if (machine.cur_rows < STANDARD_VECTOR_SIZE && machine.state != CSVState::EMPTY_LINE) { + machine.VerifyUTF8(); + auto &v = parse_chunk.data[machine.column_count++]; + auto parse_data = FlatVector::GetData(v); + if (machine.value.empty()) { + auto &validity_mask = FlatVector::Validity(v); + validity_mask.SetInvalid(machine.cur_rows); + } else { + parse_data[machine.cur_rows] = StringVector::AddStringOrBlob(v, string_t(machine.value)); + } + while (machine.column_count < parse_chunk.ColumnCount()) { + auto &v_pad = parse_chunk.data[machine.column_count++]; + auto &validity_mask = FlatVector::Validity(v_pad); + validity_mask.SetInvalid(machine.cur_rows); + } + machine.cur_rows++; + } + parse_chunk.SetCardinality(machine.cur_rows); + } +}; + +bool CSVSniffer::TryCastVector(Vector &parse_chunk_col, idx_t size, const LogicalType &sql_type) { + // try vector-cast from string to sql_type + Vector dummy_result(sql_type); + if (best_candidate->dialect_options.has_format[LogicalTypeId::DATE] && sql_type == LogicalTypeId::DATE) { + // use the date format to cast the chunk + string error_message; + idx_t line_error; + return BaseCSVReader::TryCastDateVector(best_candidate->dialect_options.date_format, parse_chunk_col, + dummy_result, size, error_message, line_error); + } + if (best_candidate->dialect_options.has_format[LogicalTypeId::TIMESTAMP] && sql_type == LogicalTypeId::TIMESTAMP) { + // use the timestamp format to cast the chunk + string error_message; + return BaseCSVReader::TryCastTimestampVector(best_candidate->dialect_options.date_format, parse_chunk_col, + dummy_result, size, error_message); + } + // target type is not varchar: perform a cast + string error_message; + return VectorOperations::DefaultTryCast(parse_chunk_col, dummy_result, size, &error_message, true); +} + +void CSVSniffer::RefineTypes() { + // if data types were provided, exit here if number of columns does not match + detected_types.assign(best_candidate->dialect_options.num_cols, LogicalType::VARCHAR); + if (best_candidate->options.all_varchar) { + // return all types varchar + return; + } + DataChunk parse_chunk; + parse_chunk.Initialize(BufferAllocator::Get(buffer_manager->context), detected_types, STANDARD_VECTOR_SIZE); + for (idx_t i = 1; i < best_candidate->options.sample_size_chunks; i++) { + bool finished_file = best_candidate->csv_buffer_iterator.Finished(); + if (finished_file) { + // we finished the file: stop + // set sql types + detected_types.clear(); + for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { + LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); + if (best_sql_types_candidates_per_column_idx[column_idx].size() == + best_candidate->options.auto_type_candidates.size()) { + d_type = LogicalType::VARCHAR; + } + detected_types.push_back(d_type); + } + return; + } + best_candidate->csv_buffer_iterator.Process(*best_candidate, parse_chunk); + for (idx_t col = 0; col < parse_chunk.ColumnCount(); col++) { + vector &col_type_candidates = best_sql_types_candidates_per_column_idx[col]; + bool is_bool_type = col_type_candidates.back() == LogicalType::BOOLEAN; + while (col_type_candidates.size() > 1) { + const auto &sql_type = col_type_candidates.back(); + // narrow down the date formats + if (best_format_candidates.count(sql_type.id())) { + auto &best_type_format_candidates = best_format_candidates[sql_type.id()]; + auto save_format_candidates = best_type_format_candidates; + while (!best_type_format_candidates.empty()) { + if (TryCastVector(parse_chunk.data[col], parse_chunk.size(), sql_type)) { + break; + } + // doesn't work - move to the next one + best_type_format_candidates.pop_back(); + best_candidate->dialect_options.has_format[sql_type.id()] = + (!best_type_format_candidates.empty()); + if (!best_type_format_candidates.empty()) { + SetDateFormat(*best_candidate, best_type_format_candidates.back(), sql_type.id()); + } + } + // if none match, then this is not a column of type sql_type, + if (best_type_format_candidates.empty()) { + // so restore the candidates that did work. + best_type_format_candidates.swap(save_format_candidates); + if (!best_type_format_candidates.empty()) { + SetDateFormat(*best_candidate, best_type_format_candidates.back(), sql_type.id()); + } + } + } + if (TryCastVector(parse_chunk.data[col], parse_chunk.size(), sql_type)) { + break; + } else { + if (col_type_candidates.back() == LogicalType::BOOLEAN && is_bool_type) { + // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we + // immediately pop to varchar. + while (col_type_candidates.back() != LogicalType::VARCHAR) { + col_type_candidates.pop_back(); + } + break; + } + col_type_candidates.pop_back(); + } + } + } + // reset parse chunk for the next iteration + parse_chunk.Reset(); + } + detected_types.clear(); + // set sql types + for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { + LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); + if (best_sql_types_candidates_per_column_idx[column_idx].size() == + best_candidate->options.auto_type_candidates.size()) { + d_type = LogicalType::VARCHAR; + } + detected_types.push_back(d_type); + } +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp new file mode 100644 index 00000000..41988082 --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp @@ -0,0 +1,39 @@ +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp" + +namespace duckdb { +void CSVSniffer::ReplaceTypes() { + if (best_candidate->options.sql_type_list.empty()) { + return; + } + // user-defined types were supplied for certain columns + // override the types + if (!best_candidate->options.sql_types_per_column.empty()) { + // types supplied as name -> value map + idx_t found = 0; + for (idx_t i = 0; i < names.size(); i++) { + auto it = best_candidate->options.sql_types_per_column.find(names[i]); + if (it != best_candidate->options.sql_types_per_column.end()) { + best_sql_types_candidates_per_column_idx[i] = {best_candidate->options.sql_type_list[it->second]}; + found++; + } + } + if (!best_candidate->options.file_options.union_by_name && + found < best_candidate->options.sql_types_per_column.size()) { + string error_msg = BufferedCSVReader::ColumnTypesError(options.sql_types_per_column, names); + if (!error_msg.empty()) { + throw BinderException(error_msg); + } + } + return; + } + // types supplied as list + if (names.size() < best_candidate->options.sql_type_list.size()) { + throw BinderException("read_csv: %d types were provided, but CSV file only has %d columns", + best_candidate->options.sql_type_list.size(), names.size()); + } + for (idx_t i = 0; i < best_candidate->options.sql_type_list.size(); i++) { + best_sql_types_candidates_per_column_idx[i] = {best_candidate->options.sql_type_list[i]}; + } +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/filter/physical_filter.cpp b/src/duckdb/src/execution/operator/filter/physical_filter.cpp new file mode 100644 index 00000000..d70d5093 --- /dev/null +++ b/src/duckdb/src/execution/operator/filter/physical_filter.cpp @@ -0,0 +1,62 @@ +#include "duckdb/execution/operator/filter/physical_filter.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/parallel/thread_context.hpp" +namespace duckdb { + +PhysicalFilter::PhysicalFilter(vector types, vector> select_list, + idx_t estimated_cardinality) + : CachingPhysicalOperator(PhysicalOperatorType::FILTER, std::move(types), estimated_cardinality) { + D_ASSERT(select_list.size() > 0); + if (select_list.size() > 1) { + // create a big AND out of the expressions + auto conjunction = make_uniq(ExpressionType::CONJUNCTION_AND); + for (auto &expr : select_list) { + conjunction->children.push_back(std::move(expr)); + } + expression = std::move(conjunction); + } else { + expression = std::move(select_list[0]); + } +} + +class FilterState : public CachingOperatorState { +public: + explicit FilterState(ExecutionContext &context, Expression &expr) + : executor(context.client, expr), sel(STANDARD_VECTOR_SIZE) { + } + + ExpressionExecutor executor; + SelectionVector sel; + +public: + void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { + context.thread.profiler.Flush(op, executor, "filter", 0); + } +}; + +unique_ptr PhysicalFilter::GetOperatorState(ExecutionContext &context) const { + return make_uniq(context, *expression); +} + +OperatorResultType PhysicalFilter::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &state = state_p.Cast(); + idx_t result_count = state.executor.SelectExpression(input, state.sel); + if (result_count == input.size()) { + // nothing was filtered: skip adding any selection vectors + chunk.Reference(input); + } else { + chunk.Slice(input, state.sel, result_count); + } + return OperatorResultType::NEED_MORE_INPUT; +} + +string PhysicalFilter::ParamsToString() const { + auto result = expression->GetName(); + result += "\n[INFOSEPARATOR]\n"; + result += StringUtil::Format("EC: %llu", estimated_cardinality); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp new file mode 100644 index 00000000..2c0d6ac8 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp @@ -0,0 +1,76 @@ +#include "duckdb/execution/operator/helper/physical_batch_collector.hpp" + +#include "duckdb/common/types/batched_data_collection.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/materialized_query_result.hpp" + +namespace duckdb { + +PhysicalBatchCollector::PhysicalBatchCollector(PreparedStatementData &data) : PhysicalResultCollector(data) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class BatchCollectorGlobalState : public GlobalSinkState { +public: + BatchCollectorGlobalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { + } + + mutex glock; + BatchedDataCollection data; + unique_ptr result; +}; + +class BatchCollectorLocalState : public LocalSinkState { +public: + BatchCollectorLocalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { + } + + BatchedDataCollection data; +}; + +SinkResultType PhysicalBatchCollector::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &state = input.local_state.Cast(); + state.data.Append(chunk, state.partition_info.batch_index.GetIndex()); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalBatchCollector::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &state = input.local_state.Cast(); + + lock_guard lock(gstate.glock); + gstate.data.Merge(state.data); + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalBatchCollector::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + auto collection = gstate.data.FetchCollection(); + D_ASSERT(collection); + auto result = make_uniq(statement_type, properties, names, std::move(collection), + context.GetClientProperties()); + gstate.result = std::move(result); + return SinkFinalizeType::READY; +} + +unique_ptr PhysicalBatchCollector::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, *this); +} + +unique_ptr PhysicalBatchCollector::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr PhysicalBatchCollector::GetResult(GlobalSinkState &state) { + auto &gstate = state.Cast(); + D_ASSERT(gstate.result); + return std::move(gstate.result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_execute.cpp b/src/duckdb/src/execution/operator/helper/physical_execute.cpp new file mode 100644 index 00000000..4a076921 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_execute.cpp @@ -0,0 +1,20 @@ +#include "duckdb/execution/operator/helper/physical_execute.hpp" + +#include "duckdb/parallel/meta_pipeline.hpp" + +namespace duckdb { + +PhysicalExecute::PhysicalExecute(PhysicalOperator &plan) + : PhysicalOperator(PhysicalOperatorType::EXECUTE, plan.types, -1), plan(plan) { +} + +vector> PhysicalExecute::GetChildren() const { + return {plan}; +} + +void PhysicalExecute::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + // EXECUTE statement: build pipeline on child + meta_pipeline.Build(plan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp new file mode 100644 index 00000000..b5a43bfe --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp @@ -0,0 +1,46 @@ +#include "duckdb/execution/operator/helper/physical_explain_analyze.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/query_profiler.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class ExplainAnalyzeStateGlobalState : public GlobalSinkState { +public: + string analyzed_plan; +}; + +SinkResultType PhysicalExplainAnalyze::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + return SinkResultType::NEED_MORE_INPUT; +} + +SinkFinalizeType PhysicalExplainAnalyze::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &profiler = QueryProfiler::Get(context); + gstate.analyzed_plan = profiler.ToString(); + return SinkFinalizeType::READY; +} + +unique_ptr PhysicalExplainAnalyze::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalExplainAnalyze::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gstate = sink_state->Cast(); + + chunk.SetValue(0, 0, Value("analyzed_plan")); + chunk.SetValue(1, 0, Value(gstate.analyzed_plan)); + chunk.SetCardinality(1); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_limit.cpp b/src/duckdb/src/execution/operator/helper/physical_limit.cpp new file mode 100644 index 00000000..4fd75344 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_limit.cpp @@ -0,0 +1,226 @@ +#include "duckdb/execution/operator/helper/physical_limit.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/types/batched_data_collection.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +PhysicalLimit::PhysicalLimit(vector types, idx_t limit, idx_t offset, + unique_ptr limit_expression, unique_ptr offset_expression, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::LIMIT, std::move(types), estimated_cardinality), limit_value(limit), + offset_value(offset), limit_expression(std::move(limit_expression)), + offset_expression(std::move(offset_expression)) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class LimitGlobalState : public GlobalSinkState { +public: + explicit LimitGlobalState(ClientContext &context, const PhysicalLimit &op) : data(context, op.types, true) { + limit = 0; + offset = 0; + } + + mutex glock; + idx_t limit; + idx_t offset; + BatchedDataCollection data; +}; + +class LimitLocalState : public LocalSinkState { +public: + explicit LimitLocalState(ClientContext &context, const PhysicalLimit &op) + : current_offset(0), data(context, op.types, true) { + this->limit = op.limit_expression ? DConstants::INVALID_INDEX : op.limit_value; + this->offset = op.offset_expression ? DConstants::INVALID_INDEX : op.offset_value; + } + + idx_t current_offset; + idx_t limit; + idx_t offset; + BatchedDataCollection data; +}; + +unique_ptr PhysicalLimit::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr PhysicalLimit::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, *this); +} + +bool PhysicalLimit::ComputeOffset(ExecutionContext &context, DataChunk &input, idx_t &limit, idx_t &offset, + idx_t current_offset, idx_t &max_element, Expression *limit_expression, + Expression *offset_expression) { + if (limit != DConstants::INVALID_INDEX && offset != DConstants::INVALID_INDEX) { + max_element = limit + offset; + if ((limit == 0 || current_offset >= max_element) && !(limit_expression || offset_expression)) { + return false; + } + } + + // get the next chunk from the child + if (limit == DConstants::INVALID_INDEX) { + limit = 1ULL << 62ULL; + Value val = GetDelimiter(context, input, limit_expression); + if (!val.IsNull()) { + limit = val.GetValue(); + } + if (limit > 1ULL << 62ULL) { + throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", limit, 1ULL << 62ULL); + } + } + if (offset == DConstants::INVALID_INDEX) { + offset = 0; + Value val = GetDelimiter(context, input, offset_expression); + if (!val.IsNull()) { + offset = val.GetValue(); + } + if (offset > 1ULL << 62ULL) { + throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", offset, 1ULL << 62ULL); + } + } + max_element = limit + offset; + if (limit == 0 || current_offset >= max_element) { + return false; + } + return true; +} + +SinkResultType PhysicalLimit::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + + D_ASSERT(chunk.size() > 0); + auto &state = input.local_state.Cast(); + auto &limit = state.limit; + auto &offset = state.offset; + + idx_t max_element; + if (!ComputeOffset(context, chunk, limit, offset, state.current_offset, max_element, limit_expression.get(), + offset_expression.get())) { + return SinkResultType::FINISHED; + } + auto max_cardinality = max_element - state.current_offset; + if (max_cardinality < chunk.size()) { + chunk.SetCardinality(max_cardinality); + } + state.data.Append(chunk, state.partition_info.batch_index.GetIndex()); + state.current_offset += chunk.size(); + if (state.current_offset == max_element) { + return SinkResultType::FINISHED; + } + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalLimit::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &state = input.local_state.Cast(); + + lock_guard lock(gstate.glock); + gstate.limit = state.limit; + gstate.offset = state.offset; + gstate.data.Merge(state.data); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class LimitSourceState : public GlobalSourceState { +public: + LimitSourceState() { + initialized = false; + current_offset = 0; + } + + bool initialized; + idx_t current_offset; + BatchedChunkScanState scan_state; +}; + +unique_ptr PhysicalLimit::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(); +} + +SourceResultType PhysicalLimit::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + auto &gstate = sink_state->Cast(); + auto &state = input.global_state.Cast(); + while (state.current_offset < gstate.limit + gstate.offset) { + if (!state.initialized) { + gstate.data.InitializeScan(state.scan_state); + state.initialized = true; + } + gstate.data.Scan(state.scan_state, chunk); + if (chunk.size() == 0) { + return SourceResultType::FINISHED; + } + if (HandleOffset(chunk, state.current_offset, gstate.offset, gstate.limit)) { + break; + } + } + + return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; +} + +bool PhysicalLimit::HandleOffset(DataChunk &input, idx_t ¤t_offset, idx_t offset, idx_t limit) { + idx_t max_element = limit + offset; + if (limit == DConstants::INVALID_INDEX) { + max_element = DConstants::INVALID_INDEX; + } + idx_t input_size = input.size(); + if (current_offset < offset) { + // we are not yet at the offset point + if (current_offset + input.size() > offset) { + // however we will reach it in this chunk + // we have to copy part of the chunk with an offset + idx_t start_position = offset - current_offset; + auto chunk_count = MinValue(limit, input.size() - start_position); + SelectionVector sel(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < chunk_count; i++) { + sel.set_index(i, start_position + i); + } + // set up a slice of the input chunks + input.Slice(input, sel, chunk_count); + } else { + current_offset += input_size; + return false; + } + } else { + // have to copy either the entire chunk or part of it + idx_t chunk_count; + if (current_offset + input.size() >= max_element) { + // have to limit the count of the chunk + chunk_count = max_element - current_offset; + } else { + // we copy the entire chunk + chunk_count = input.size(); + } + // instead of copying we just change the pointer in the current chunk + input.Reference(input); + input.SetCardinality(chunk_count); + } + + current_offset += input_size; + return true; +} + +Value PhysicalLimit::GetDelimiter(ExecutionContext &context, DataChunk &input, Expression *expr) { + DataChunk limit_chunk; + vector types {expr->return_type}; + auto &allocator = Allocator::Get(context.client); + limit_chunk.Initialize(allocator, types); + ExpressionExecutor limit_executor(context.client, expr); + auto input_size = input.size(); + input.SetCardinality(1); + limit_executor.Execute(input, limit_chunk); + input.SetCardinality(input_size); + auto limit_value = limit_chunk.GetValue(0, 0); + return limit_value; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp b/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp new file mode 100644 index 00000000..a65cc219 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp @@ -0,0 +1,142 @@ +#include "duckdb/execution/operator/helper/physical_limit_percent.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/helper/physical_limit.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class LimitPercentGlobalState : public GlobalSinkState { +public: + explicit LimitPercentGlobalState(ClientContext &context, const PhysicalLimitPercent &op) + : current_offset(0), data(context, op.GetTypes()) { + if (!op.limit_expression) { + this->limit_percent = op.limit_percent; + is_limit_percent_delimited = true; + } else { + this->limit_percent = 100.0; + } + + if (!op.offset_expression) { + this->offset = op.offset_value; + is_offset_delimited = true; + } else { + this->offset = 0; + } + } + + idx_t current_offset; + double limit_percent; + idx_t offset; + ColumnDataCollection data; + + bool is_limit_percent_delimited = false; + bool is_offset_delimited = false; +}; + +unique_ptr PhysicalLimitPercent::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +SinkResultType PhysicalLimitPercent::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + D_ASSERT(chunk.size() > 0); + auto &state = input.global_state.Cast(); + auto &limit_percent = state.limit_percent; + auto &offset = state.offset; + + // get the next chunk from the child + if (!state.is_limit_percent_delimited) { + Value val = PhysicalLimit::GetDelimiter(context, chunk, limit_expression.get()); + if (!val.IsNull()) { + limit_percent = val.GetValue(); + } + if (limit_percent < 0.0) { + throw BinderException("Percentage value(%f) can't be negative", limit_percent); + } + state.is_limit_percent_delimited = true; + } + if (!state.is_offset_delimited) { + Value val = PhysicalLimit::GetDelimiter(context, chunk, offset_expression.get()); + if (!val.IsNull()) { + offset = val.GetValue(); + } + if (offset > 1ULL << 62ULL) { + throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", offset, 1ULL << 62ULL); + } + state.is_offset_delimited = true; + } + + if (!PhysicalLimit::HandleOffset(chunk, state.current_offset, offset, DConstants::INVALID_INDEX)) { + return SinkResultType::NEED_MORE_INPUT; + } + + state.data.Append(chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class LimitPercentOperatorState : public GlobalSourceState { +public: + explicit LimitPercentOperatorState(const PhysicalLimitPercent &op) + : limit(DConstants::INVALID_INDEX), current_offset(0) { + D_ASSERT(op.sink_state); + auto &gstate = op.sink_state->Cast(); + gstate.data.InitializeScan(scan_state); + } + + ColumnDataScanState scan_state; + idx_t limit; + idx_t current_offset; +}; + +unique_ptr PhysicalLimitPercent::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +SourceResultType PhysicalLimitPercent::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gstate = sink_state->Cast(); + auto &state = input.global_state.Cast(); + auto &percent_limit = gstate.limit_percent; + auto &offset = gstate.offset; + auto &limit = state.limit; + auto ¤t_offset = state.current_offset; + + if (gstate.is_limit_percent_delimited && limit == DConstants::INVALID_INDEX) { + idx_t count = gstate.data.Count(); + if (count > 0) { + count += offset; + } + if (Value::IsNan(percent_limit) || percent_limit < 0 || percent_limit > 100) { + throw OutOfRangeException("Limit percent out of range, should be between 0% and 100%"); + } + double limit_dbl = percent_limit / 100 * count; + if (limit_dbl > count) { + limit = count; + } else { + limit = idx_t(limit_dbl); + } + if (limit == 0) { + return SourceResultType::FINISHED; + } + } + + if (current_offset >= limit) { + return SourceResultType::FINISHED; + } + if (!gstate.data.Scan(state.scan_state, chunk)) { + return SourceResultType::FINISHED; + } + + PhysicalLimit::HandleOffset(chunk, current_offset, 0, limit); + + return SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_load.cpp b/src/duckdb/src/execution/operator/helper/physical_load.cpp new file mode 100644 index 00000000..62054206 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_load.cpp @@ -0,0 +1,17 @@ +#include "duckdb/execution/operator/helper/physical_load.hpp" +#include "duckdb/main/extension_helper.hpp" + +namespace duckdb { + +SourceResultType PhysicalLoad::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + if (info->load_type == LoadType::INSTALL || info->load_type == LoadType::FORCE_INSTALL) { + ExtensionHelper::InstallExtension(context.client, info->filename, info->load_type == LoadType::FORCE_INSTALL, + info->repository); + } else { + ExtensionHelper::LoadExternalExtension(context.client, info->filename); + } + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp new file mode 100644 index 00000000..650a8916 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp @@ -0,0 +1,84 @@ +#include "duckdb/execution/operator/helper/physical_materialized_collector.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +PhysicalMaterializedCollector::PhysicalMaterializedCollector(PreparedStatementData &data, bool parallel) + : PhysicalResultCollector(data), parallel(parallel) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class MaterializedCollectorGlobalState : public GlobalSinkState { +public: + mutex glock; + unique_ptr collection; + shared_ptr context; +}; + +class MaterializedCollectorLocalState : public LocalSinkState { +public: + unique_ptr collection; + ColumnDataAppendState append_state; +}; + +SinkResultType PhysicalMaterializedCollector::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + lstate.collection->Append(lstate.append_state, chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalMaterializedCollector::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + if (lstate.collection->Count() == 0) { + return SinkCombineResultType::FINISHED; + } + + lock_guard l(gstate.glock); + if (!gstate.collection) { + gstate.collection = std::move(lstate.collection); + } else { + gstate.collection->Combine(*lstate.collection); + } + + return SinkCombineResultType::FINISHED; +} + +unique_ptr PhysicalMaterializedCollector::GetGlobalSinkState(ClientContext &context) const { + auto state = make_uniq(); + state->context = context.shared_from_this(); + return std::move(state); +} + +unique_ptr PhysicalMaterializedCollector::GetLocalSinkState(ExecutionContext &context) const { + auto state = make_uniq(); + state->collection = make_uniq(Allocator::DefaultAllocator(), types); + state->collection->InitializeAppend(state->append_state); + return std::move(state); +} + +unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) { + auto &gstate = state.Cast(); + if (!gstate.collection) { + gstate.collection = make_uniq(Allocator::DefaultAllocator(), types); + } + auto result = make_uniq(statement_type, properties, names, std::move(gstate.collection), + gstate.context->GetClientProperties()); + return std::move(result); +} + +bool PhysicalMaterializedCollector::ParallelSink() const { + return parallel; +} + +bool PhysicalMaterializedCollector::SinkOrderDependent() const { + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_pragma.cpp b/src/duckdb/src/execution/operator/helper/physical_pragma.cpp new file mode 100644 index 00000000..24782f95 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_pragma.cpp @@ -0,0 +1,14 @@ +#include "duckdb/execution/operator/helper/physical_pragma.hpp" + +namespace duckdb { + +SourceResultType PhysicalPragma::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &client = context.client; + FunctionParameters parameters {info.parameters, info.named_parameters}; + function.function(client, parameters); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp new file mode 100644 index 00000000..784d6ada --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp @@ -0,0 +1,16 @@ +#include "duckdb/execution/operator/helper/physical_prepare.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +SourceResultType PhysicalPrepare::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &client = context.client; + + // store the prepared statement in the context + ClientData::Get(client).prepared_statements[name] = prepared; + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp b/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp new file mode 100644 index 00000000..ec846f1d --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp @@ -0,0 +1,73 @@ +#include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp" +#include "duckdb/execution/reservoir_sample.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class SampleGlobalSinkState : public GlobalSinkState { +public: + explicit SampleGlobalSinkState(Allocator &allocator, SampleOptions &options) { + if (options.is_percentage) { + auto percentage = options.sample_size.GetValue(); + if (percentage == 0) { + return; + } + sample = make_uniq(allocator, percentage, options.seed); + } else { + auto size = options.sample_size.GetValue(); + if (size == 0) { + return; + } + sample = make_uniq(allocator, size, options.seed); + } + } + + //! The lock for updating the global aggregate state + mutex lock; + //! The reservoir sample + unique_ptr sample; +}; + +unique_ptr PhysicalReservoirSample::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(Allocator::Get(context), *options); +} + +SinkResultType PhysicalReservoirSample::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + if (!gstate.sample) { + return SinkResultType::FINISHED; + } + // we implement reservoir sampling without replacement and exponential jumps here + // the algorithm is adopted from the paper Weighted random sampling with a reservoir by Pavlos S. Efraimidis et al. + // note that the original algorithm is about weighted sampling; this is a simplified approach for uniform sampling + lock_guard glock(gstate.lock); + gstate.sample->AddToReservoir(chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &sink = this->sink_state->Cast(); + if (!sink.sample) { + return SourceResultType::FINISHED; + } + auto sample_chunk = sink.sample->GetChunk(); + if (!sample_chunk) { + return SourceResultType::FINISHED; + } + chunk.Move(*sample_chunk); + + return SourceResultType::HAVE_MORE_OUTPUT; +} + +string PhysicalReservoirSample::ParamsToString() const { + return options->sample_size.ToString() + (options->is_percentage ? "%" : " rows"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_reset.cpp b/src/duckdb/src/execution/operator/helper/physical_reset.cpp new file mode 100644 index 00000000..6fd3b9f3 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_reset.cpp @@ -0,0 +1,74 @@ +#include "duckdb/execution/operator/helper/physical_reset.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +void PhysicalReset::ResetExtensionVariable(ExecutionContext &context, DBConfig &config, + ExtensionOption &extension_option) const { + if (extension_option.set_function) { + extension_option.set_function(context.client, scope, extension_option.default_value); + } + if (scope == SetScope::GLOBAL) { + config.ResetOption(name); + } else { + auto &client_config = ClientConfig::GetConfig(context.client); + client_config.set_variables[name] = extension_option.default_value; + } +} + +SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + auto &config = DBConfig::GetConfig(context.client); + if (config.options.lock_configuration) { + throw InvalidInputException("Cannot reset configuration option \"%s\" - the configuration has been locked", + name); + } + auto option = DBConfig::GetOptionByName(name); + if (!option) { + // check if this is an extra extension variable + auto entry = config.extension_parameters.find(name); + if (entry == config.extension_parameters.end()) { + Catalog::AutoloadExtensionByConfigName(context.client, name); + entry = config.extension_parameters.find(name); + D_ASSERT(entry != config.extension_parameters.end()); + } + ResetExtensionVariable(context, config, entry->second); + return SourceResultType::FINISHED; + } + + // Transform scope + SetScope variable_scope = scope; + if (variable_scope == SetScope::AUTOMATIC) { + if (option->set_local) { + variable_scope = SetScope::SESSION; + } else { + D_ASSERT(option->set_global); + variable_scope = SetScope::GLOBAL; + } + } + + switch (variable_scope) { + case SetScope::GLOBAL: { + if (!option->set_global) { + throw CatalogException("option \"%s\" cannot be reset globally", name); + } + auto &db = DatabaseInstance::GetDatabase(context.client); + config.ResetOption(&db, *option); + break; + } + case SetScope::SESSION: + if (!option->reset_local) { + throw CatalogException("option \"%s\" cannot be reset locally", name); + } + option->reset_local(context.client); + break; + default: + throw InternalException("Unsupported SetScope for variable"); + } + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp new file mode 100644 index 00000000..e0ac959a --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp @@ -0,0 +1,53 @@ +#include "duckdb/execution/operator/helper/physical_result_collector.hpp" + +#include "duckdb/execution/operator/helper/physical_batch_collector.hpp" +#include "duckdb/execution/operator/helper/physical_materialized_collector.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" + +namespace duckdb { + +PhysicalResultCollector::PhysicalResultCollector(PreparedStatementData &data) + : PhysicalOperator(PhysicalOperatorType::RESULT_COLLECTOR, {LogicalType::BOOLEAN}, 0), + statement_type(data.statement_type), properties(data.properties), plan(*data.plan), names(data.names) { + this->types = data.types; +} + +unique_ptr PhysicalResultCollector::GetResultCollector(ClientContext &context, + PreparedStatementData &data) { + if (!PhysicalPlanGenerator::PreserveInsertionOrder(context, *data.plan)) { + // the plan is not order preserving, so we just use the parallel materialized collector + return make_uniq_base(data, true); + } else if (!PhysicalPlanGenerator::UseBatchIndex(context, *data.plan)) { + // the plan is order preserving, but we cannot use the batch index: use a single-threaded result collector + return make_uniq_base(data, false); + } else { + // we care about maintaining insertion order and the sources all support batch indexes + // use a batch collector + return make_uniq_base(data); + } +} + +vector> PhysicalResultCollector::GetChildren() const { + return {plan}; +} + +void PhysicalResultCollector::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + // operator is a sink, build a pipeline + sink_state.reset(); + + D_ASSERT(children.empty()); + + // single operator: the operator becomes the data source of the current pipeline + auto &state = meta_pipeline.GetState(); + state.SetPipelineSource(current, *this); + + // we create a new pipeline starting from the child + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + child_meta_pipeline.Build(plan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_set.cpp b/src/duckdb/src/execution/operator/helper/physical_set.cpp new file mode 100644 index 00000000..8153ea54 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_set.cpp @@ -0,0 +1,77 @@ +#include "duckdb/execution/operator/helper/physical_set.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, + SetScope scope, const Value &value) { + auto &config = DBConfig::GetConfig(context); + auto &target_type = extension_option.type; + Value target_value = value.CastAs(context, target_type); + if (extension_option.set_function) { + extension_option.set_function(context, scope, target_value); + } + if (scope == SetScope::GLOBAL) { + config.SetOption(name, std::move(target_value)); + } else { + auto &client_config = ClientConfig::GetConfig(context); + client_config.set_variables[name] = std::move(target_value); + } +} + +SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + auto &config = DBConfig::GetConfig(context.client); + if (config.options.lock_configuration) { + throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", + name); + } + auto option = DBConfig::GetOptionByName(name); + if (!option) { + // check if this is an extra extension variable + auto entry = config.extension_parameters.find(name); + if (entry == config.extension_parameters.end()) { + Catalog::AutoloadExtensionByConfigName(context.client, name); + entry = config.extension_parameters.find(name); + D_ASSERT(entry != config.extension_parameters.end()); + } + SetExtensionVariable(context.client, entry->second, name, scope, value); + return SourceResultType::FINISHED; + } + SetScope variable_scope = scope; + if (variable_scope == SetScope::AUTOMATIC) { + if (option->set_local) { + variable_scope = SetScope::SESSION; + } else { + D_ASSERT(option->set_global); + variable_scope = SetScope::GLOBAL; + } + } + + Value input_val = value.CastAs(context.client, option->parameter_type); + switch (variable_scope) { + case SetScope::GLOBAL: { + if (!option->set_global) { + throw CatalogException("option \"%s\" cannot be set globally", name); + } + auto &db = DatabaseInstance::GetDatabase(context.client); + auto &config = DBConfig::GetConfig(context.client); + config.SetOption(&db, *option, input_val); + break; + } + case SetScope::SESSION: + if (!option->set_local) { + throw CatalogException("option \"%s\" cannot be set locally", name); + } + option->set_local(context.client, input_val); + break; + default: + throw InternalException("Unsupported SetScope for variable"); + } + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_streaming_limit.cpp b/src/duckdb/src/execution/operator/helper/physical_streaming_limit.cpp new file mode 100644 index 00000000..5fbd56a3 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_streaming_limit.cpp @@ -0,0 +1,71 @@ +#include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" +#include "duckdb/execution/operator/helper/physical_limit.hpp" + +namespace duckdb { + +PhysicalStreamingLimit::PhysicalStreamingLimit(vector types, idx_t limit, idx_t offset, + unique_ptr limit_expression, + unique_ptr offset_expression, idx_t estimated_cardinality, + bool parallel) + : PhysicalOperator(PhysicalOperatorType::STREAMING_LIMIT, std::move(types), estimated_cardinality), + limit_value(limit), offset_value(offset), limit_expression(std::move(limit_expression)), + offset_expression(std::move(offset_expression)), parallel(parallel) { +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +class StreamingLimitOperatorState : public OperatorState { +public: + explicit StreamingLimitOperatorState(const PhysicalStreamingLimit &op) { + this->limit = op.limit_expression ? DConstants::INVALID_INDEX : op.limit_value; + this->offset = op.offset_expression ? DConstants::INVALID_INDEX : op.offset_value; + } + + idx_t limit; + idx_t offset; +}; + +class StreamingLimitGlobalState : public GlobalOperatorState { +public: + StreamingLimitGlobalState() : current_offset(0) { + } + + std::atomic current_offset; +}; + +unique_ptr PhysicalStreamingLimit::GetOperatorState(ExecutionContext &context) const { + return make_uniq(*this); +} + +unique_ptr PhysicalStreamingLimit::GetGlobalOperatorState(ClientContext &context) const { + return make_uniq(); +} + +OperatorResultType PhysicalStreamingLimit::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate_p, OperatorState &state_p) const { + auto &gstate = gstate_p.Cast(); + auto &state = state_p.Cast(); + auto &limit = state.limit; + auto &offset = state.offset; + idx_t current_offset = gstate.current_offset.fetch_add(input.size()); + idx_t max_element; + if (!PhysicalLimit::ComputeOffset(context, input, limit, offset, current_offset, max_element, + limit_expression.get(), offset_expression.get())) { + return OperatorResultType::FINISHED; + } + if (PhysicalLimit::HandleOffset(input, current_offset, offset, limit)) { + chunk.Reference(input); + } + return OperatorResultType::NEED_MORE_INPUT; +} + +OrderPreservationType PhysicalStreamingLimit::OperatorOrder() const { + return OrderPreservationType::FIXED_ORDER; +} + +bool PhysicalStreamingLimit::ParallelOperator() const { + return parallel; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp b/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp new file mode 100644 index 00000000..cda2fa25 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp @@ -0,0 +1,75 @@ +#include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" +#include "duckdb/common/random_engine.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +PhysicalStreamingSample::PhysicalStreamingSample(vector types, SampleMethod method, double percentage, + int64_t seed, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::STREAMING_SAMPLE, std::move(types), estimated_cardinality), method(method), + percentage(percentage / 100), seed(seed) { +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +class StreamingSampleOperatorState : public OperatorState { +public: + explicit StreamingSampleOperatorState(int64_t seed) : random(seed) { + } + + RandomEngine random; +}; + +void PhysicalStreamingSample::SystemSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { + // system sampling: we throw one dice per chunk + auto &state = state_p.Cast(); + double rand = state.random.NextRandom(); + if (rand <= percentage) { + // rand is smaller than sample_size: output chunk + result.Reference(input); + } +} + +void PhysicalStreamingSample::BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { + // bernoulli sampling: we throw one dice per tuple + // then slice the result chunk + auto &state = state_p.Cast(); + idx_t result_count = 0; + SelectionVector sel(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < input.size(); i++) { + double rand = state.random.NextRandom(); + if (rand <= percentage) { + sel.set_index(result_count++, i); + } + } + if (result_count > 0) { + result.Slice(input, sel, result_count); + } +} + +unique_ptr PhysicalStreamingSample::GetOperatorState(ExecutionContext &context) const { + return make_uniq(seed); +} + +OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const { + switch (method) { + case SampleMethod::BERNOULLI_SAMPLE: + BernoulliSample(input, chunk, state); + break; + case SampleMethod::SYSTEM_SAMPLE: + SystemSample(input, chunk, state); + break; + default: + throw InternalException("Unsupported sample method for streaming sample"); + } + return OperatorResultType::NEED_MORE_INPUT; +} + +string PhysicalStreamingSample::ParamsToString() const { + return EnumUtil::ToString(method) + ": " + to_string(100 * percentage) + "%"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp new file mode 100644 index 00000000..cca98d85 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp @@ -0,0 +1,55 @@ +#include "duckdb/execution/operator/helper/physical_transaction.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/valid_checker.hpp" + +namespace duckdb { + +SourceResultType PhysicalTransaction::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &client = context.client; + + auto type = info->type; + if (type == TransactionType::COMMIT && ValidChecker::IsInvalidated(client.ActiveTransaction())) { + // transaction is invalidated - turn COMMIT into ROLLBACK + type = TransactionType::ROLLBACK; + } + switch (type) { + case TransactionType::BEGIN_TRANSACTION: { + if (client.transaction.IsAutoCommit()) { + // start the active transaction + // if autocommit is active, we have already called + // BeginTransaction by setting autocommit to false we + // prevent it from being closed after this query, hence + // preserving the transaction context for the next query + client.transaction.SetAutoCommit(false); + } else { + throw TransactionException("cannot start a transaction within a transaction"); + } + break; + } + case TransactionType::COMMIT: { + if (client.transaction.IsAutoCommit()) { + throw TransactionException("cannot commit - no transaction is active"); + } else { + // explicitly commit the current transaction + client.transaction.Commit(); + } + break; + } + case TransactionType::ROLLBACK: { + if (client.transaction.IsAutoCommit()) { + throw TransactionException("cannot rollback - no transaction is active"); + } else { + // explicitly rollback the current transaction + client.transaction.Rollback(); + } + break; + } + default: + throw NotImplementedException("Unrecognized transaction type!"); + } + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp b/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp new file mode 100644 index 00000000..7ab179ab --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp @@ -0,0 +1,92 @@ +#include "duckdb/execution/operator/helper/physical_vacuum.hpp" + +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/statistics/distinct_statistics.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { + +PhysicalVacuum::PhysicalVacuum(unique_ptr info_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::VACUUM, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info_p)) { +} + +class VacuumLocalSinkState : public LocalSinkState { +public: + explicit VacuumLocalSinkState(VacuumInfo &info) { + for (idx_t col_idx = 0; col_idx < info.columns.size(); col_idx++) { + column_distinct_stats.push_back(make_uniq()); + } + }; + + vector> column_distinct_stats; +}; + +unique_ptr PhysicalVacuum::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(*info); +} + +class VacuumGlobalSinkState : public GlobalSinkState { +public: + explicit VacuumGlobalSinkState(VacuumInfo &info) { + for (idx_t col_idx = 0; col_idx < info.columns.size(); col_idx++) { + column_distinct_stats.push_back(make_uniq()); + } + }; + + mutex stats_lock; + vector> column_distinct_stats; +}; + +unique_ptr PhysicalVacuum::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(*info); +} + +SinkResultType PhysicalVacuum::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + D_ASSERT(lstate.column_distinct_stats.size() == info->column_id_map.size()); + + for (idx_t col_idx = 0; col_idx < chunk.data.size(); col_idx++) { + if (!DistinctStatistics::TypeIsSupported(chunk.data[col_idx].GetType())) { + continue; + } + lstate.column_distinct_stats[col_idx]->Update(chunk.data[col_idx], chunk.size(), false); + } + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalVacuum::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + lock_guard lock(gstate.stats_lock); + D_ASSERT(gstate.column_distinct_stats.size() == lstate.column_distinct_stats.size()); + for (idx_t col_idx = 0; col_idx < gstate.column_distinct_stats.size(); col_idx++) { + gstate.column_distinct_stats[col_idx]->Merge(*lstate.column_distinct_stats[col_idx]); + } + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalVacuum::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &sink = input.global_state.Cast(); + + auto table = info->table; + for (idx_t col_idx = 0; col_idx < sink.column_distinct_stats.size(); col_idx++) { + table->GetStorage().SetDistinct(info->column_id_map.at(col_idx), + std::move(sink.column_distinct_stats[col_idx])); + } + + return SinkFinalizeType::READY; +} + +SourceResultType PhysicalVacuum::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + // NOP + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/outer_join_marker.cpp b/src/duckdb/src/execution/operator/join/outer_join_marker.cpp new file mode 100644 index 00000000..53d7de3d --- /dev/null +++ b/src/duckdb/src/execution/operator/join/outer_join_marker.cpp @@ -0,0 +1,108 @@ +#include "duckdb/execution/operator/join/outer_join_marker.hpp" + +namespace duckdb { + +OuterJoinMarker::OuterJoinMarker(bool enabled_p) : enabled(enabled_p), count(0) { +} + +void OuterJoinMarker::Initialize(idx_t count_p) { + if (!enabled) { + return; + } + this->count = count_p; + found_match = make_unsafe_uniq_array(count); + Reset(); +} + +void OuterJoinMarker::Reset() { + if (!enabled) { + return; + } + memset(found_match.get(), 0, sizeof(bool) * count); +} + +void OuterJoinMarker::SetMatch(idx_t position) { + if (!enabled) { + return; + } + D_ASSERT(position < count); + found_match[position] = true; +} + +void OuterJoinMarker::SetMatches(const SelectionVector &sel, idx_t count, idx_t base_idx) { + if (!enabled) { + return; + } + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto pos = base_idx + idx; + D_ASSERT(pos < this->count); + found_match[pos] = true; + } +} + +void OuterJoinMarker::ConstructLeftJoinResult(DataChunk &left, DataChunk &result) { + if (!enabled) { + return; + } + D_ASSERT(count == STANDARD_VECTOR_SIZE); + SelectionVector remaining_sel(STANDARD_VECTOR_SIZE); + idx_t remaining_count = 0; + for (idx_t i = 0; i < left.size(); i++) { + if (!found_match[i]) { + remaining_sel.set_index(remaining_count++, i); + } + } + if (remaining_count > 0) { + result.Slice(left, remaining_sel, remaining_count); + for (idx_t idx = left.ColumnCount(); idx < result.ColumnCount(); idx++) { + result.data[idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result.data[idx], true); + } + } +} + +idx_t OuterJoinMarker::MaxThreads() const { + return count / (STANDARD_VECTOR_SIZE * 10ULL); +} + +void OuterJoinMarker::InitializeScan(ColumnDataCollection &data, OuterJoinGlobalScanState &gstate) { + gstate.data = &data; + data.InitializeScan(gstate.global_scan); +} + +void OuterJoinMarker::InitializeScan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate) { + D_ASSERT(gstate.data); + lstate.match_sel.Initialize(STANDARD_VECTOR_SIZE); + gstate.data->InitializeScanChunk(lstate.scan_chunk); +} + +void OuterJoinMarker::Scan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate, DataChunk &result) { + D_ASSERT(gstate.data); + // fill in NULL values for the LHS + while (gstate.data->Scan(gstate.global_scan, lstate.local_scan, lstate.scan_chunk)) { + idx_t result_count = 0; + // figure out which tuples didn't find a match in the RHS + for (idx_t i = 0; i < lstate.scan_chunk.size(); i++) { + if (!found_match[lstate.local_scan.current_row_index + i]) { + lstate.match_sel.set_index(result_count++, i); + } + } + if (result_count > 0) { + // if there were any tuples that didn't find a match, output them + idx_t left_column_count = result.ColumnCount() - lstate.scan_chunk.ColumnCount(); + for (idx_t i = 0; i < left_column_count; i++) { + result.data[i].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result.data[i], true); + } + for (idx_t col_idx = left_column_count; col_idx < result.ColumnCount(); col_idx++) { + result.data[col_idx].Slice(lstate.scan_chunk.data[col_idx - left_column_count], lstate.match_sel, + result_count); + } + result.SetCardinality(result_count); + return; + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp new file mode 100644 index 00000000..7c7ad73e --- /dev/null +++ b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp @@ -0,0 +1,285 @@ +#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" + +#include "duckdb/common/types/row/row_layout.hpp" +#include "duckdb/execution/operator/join/physical_hash_join.hpp" + +namespace duckdb { + +PerfectHashJoinExecutor::PerfectHashJoinExecutor(const PhysicalHashJoin &join_p, JoinHashTable &ht_p, + PerfectHashJoinStats perfect_join_stats) + : join(join_p), ht(ht_p), perfect_join_statistics(std::move(perfect_join_stats)) { +} + +bool PerfectHashJoinExecutor::CanDoPerfectHashJoin() { + return perfect_join_statistics.is_build_small; +} + +//===--------------------------------------------------------------------===// +// Build +//===--------------------------------------------------------------------===// +bool PerfectHashJoinExecutor::BuildPerfectHashTable(LogicalType &key_type) { + // First, allocate memory for each build column + auto build_size = perfect_join_statistics.build_range + 1; + for (const auto &type : ht.build_types) { + perfect_hash_table.emplace_back(type, build_size); + } + + // and for duplicate_checking + bitmap_build_idx = make_unsafe_uniq_array(build_size); + memset(bitmap_build_idx.get(), 0, sizeof(bool) * build_size); // set false + + // Now fill columns with build data + + return FullScanHashTable(key_type); +} + +bool PerfectHashJoinExecutor::FullScanHashTable(LogicalType &key_type) { + auto &data_collection = ht.GetDataCollection(); + + // TODO: In a parallel finalize: One should exclusively lock and each thread should do one part of the code below. + Vector tuples_addresses(LogicalType::POINTER, ht.Count()); // allocate space for all the tuples + + idx_t key_count = 0; + if (data_collection.ChunkCount() > 0) { + JoinHTScanState join_ht_state(data_collection, 0, data_collection.ChunkCount(), + TupleDataPinProperties::KEEP_EVERYTHING_PINNED); + + // Go through all the blocks and fill the keys addresses + key_count = ht.FillWithHTOffsets(join_ht_state, tuples_addresses); + } + + // Scan the build keys in the hash table + Vector build_vector(key_type, key_count); + RowOperations::FullScanColumn(ht.layout, tuples_addresses, build_vector, key_count, 0); + + // Now fill the selection vector using the build keys and create a sequential vector + // TODO: add check for fast pass when probe is part of build domain + SelectionVector sel_build(key_count + 1); + SelectionVector sel_tuples(key_count + 1); + bool success = FillSelectionVectorSwitchBuild(build_vector, sel_build, sel_tuples, key_count); + + // early out + if (!success) { + return false; + } + if (unique_keys == perfect_join_statistics.build_range + 1 && !ht.has_null) { + perfect_join_statistics.is_build_dense = true; + } + key_count = unique_keys; // do not consider keys out of the range + + // Full scan the remaining build columns and fill the perfect hash table + const auto build_size = perfect_join_statistics.build_range + 1; + for (idx_t i = 0; i < ht.build_types.size(); i++) { + auto &vector = perfect_hash_table[i]; + D_ASSERT(vector.GetType() == ht.build_types[i]); + if (build_size > STANDARD_VECTOR_SIZE) { + auto &col_mask = FlatVector::Validity(vector); + col_mask.Initialize(build_size); + } + + const auto col_no = ht.condition_types.size() + i; + data_collection.Gather(tuples_addresses, sel_tuples, key_count, col_no, vector, sel_build); + } + + return true; +} + +bool PerfectHashJoinExecutor::FillSelectionVectorSwitchBuild(Vector &source, SelectionVector &sel_vec, + SelectionVector &seq_sel_vec, idx_t count) { + switch (source.GetType().InternalType()) { + case PhysicalType::INT8: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + case PhysicalType::INT16: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + case PhysicalType::INT32: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + case PhysicalType::INT64: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + case PhysicalType::UINT8: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + case PhysicalType::UINT16: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + case PhysicalType::UINT32: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + case PhysicalType::UINT64: + return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); + default: + throw NotImplementedException("Type not supported for perfect hash join"); + } +} + +template +bool PerfectHashJoinExecutor::TemplatedFillSelectionVectorBuild(Vector &source, SelectionVector &sel_vec, + SelectionVector &seq_sel_vec, idx_t count) { + if (perfect_join_statistics.build_min.IsNull() || perfect_join_statistics.build_max.IsNull()) { + return false; + } + auto min_value = perfect_join_statistics.build_min.GetValueUnsafe(); + auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); + UnifiedVectorFormat vector_data; + source.ToUnifiedFormat(count, vector_data); + auto data = reinterpret_cast(vector_data.data); + // generate the selection vector + for (idx_t i = 0, sel_idx = 0; i < count; ++i) { + auto data_idx = vector_data.sel->get_index(i); + auto input_value = data[data_idx]; + // add index to selection vector if value in the range + if (min_value <= input_value && input_value <= max_value) { + auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position + sel_vec.set_index(sel_idx, idx); + if (bitmap_build_idx[idx]) { + return false; + } else { + bitmap_build_idx[idx] = true; + unique_keys++; + } + seq_sel_vec.set_index(sel_idx++, i); + } + } + return true; +} + +//===--------------------------------------------------------------------===// +// Probe +//===--------------------------------------------------------------------===// +class PerfectHashJoinState : public OperatorState { +public: + PerfectHashJoinState(ClientContext &context, const PhysicalHashJoin &join) : probe_executor(context) { + join_keys.Initialize(Allocator::Get(context), join.condition_types); + for (auto &cond : join.conditions) { + probe_executor.AddExpression(*cond.left); + } + build_sel_vec.Initialize(STANDARD_VECTOR_SIZE); + probe_sel_vec.Initialize(STANDARD_VECTOR_SIZE); + seq_sel_vec.Initialize(STANDARD_VECTOR_SIZE); + } + + DataChunk join_keys; + ExpressionExecutor probe_executor; + SelectionVector build_sel_vec; + SelectionVector probe_sel_vec; + SelectionVector seq_sel_vec; +}; + +unique_ptr PerfectHashJoinExecutor::GetOperatorState(ExecutionContext &context) { + auto state = make_uniq(context.client, join); + return std::move(state); +} + +OperatorResultType PerfectHashJoinExecutor::ProbePerfectHashTable(ExecutionContext &context, DataChunk &input, + DataChunk &result, OperatorState &state_p) { + auto &state = state_p.Cast(); + // keeps track of how many probe keys have a match + idx_t probe_sel_count = 0; + + // fetch the join keys from the chunk + state.join_keys.Reset(); + state.probe_executor.Execute(input, state.join_keys); + // select the keys that are in the min-max range + auto &keys_vec = state.join_keys.data[0]; + auto keys_count = state.join_keys.size(); + // todo: add check for fast pass when probe is part of build domain + FillSelectionVectorSwitchProbe(keys_vec, state.build_sel_vec, state.probe_sel_vec, keys_count, probe_sel_count); + + // If build is dense and probe is in build's domain, just reference probe + if (perfect_join_statistics.is_build_dense && keys_count == probe_sel_count) { + result.Reference(input); + } else { + // otherwise, filter it out the values that do not match + result.Slice(input, state.probe_sel_vec, probe_sel_count, 0); + } + // on the build side, we need to fetch the data and build dictionary vectors with the sel_vec + for (idx_t i = 0; i < ht.build_types.size(); i++) { + auto &result_vector = result.data[input.ColumnCount() + i]; + D_ASSERT(result_vector.GetType() == ht.build_types[i]); + auto &build_vec = perfect_hash_table[i]; + result_vector.Reference(build_vec); + result_vector.Slice(state.build_sel_vec, probe_sel_count); + } + return OperatorResultType::NEED_MORE_INPUT; +} + +void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(Vector &source, SelectionVector &build_sel_vec, + SelectionVector &probe_sel_vec, idx_t count, + idx_t &probe_sel_count) { + switch (source.GetType().InternalType()) { + case PhysicalType::INT8: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + case PhysicalType::INT16: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + case PhysicalType::INT32: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + case PhysicalType::INT64: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + case PhysicalType::UINT8: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + case PhysicalType::UINT16: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + case PhysicalType::UINT32: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + case PhysicalType::UINT64: + TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); + break; + default: + throw NotImplementedException("Type not supported"); + } +} + +template +void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source, SelectionVector &build_sel_vec, + SelectionVector &probe_sel_vec, idx_t count, + idx_t &probe_sel_count) { + auto min_value = perfect_join_statistics.build_min.GetValueUnsafe(); + auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); + + UnifiedVectorFormat vector_data; + source.ToUnifiedFormat(count, vector_data); + auto data = reinterpret_cast(vector_data.data); + auto validity_mask = &vector_data.validity; + // build selection vector for non-dense build + if (validity_mask->AllValid()) { + for (idx_t i = 0, sel_idx = 0; i < count; ++i) { + // retrieve value from vector + auto data_idx = vector_data.sel->get_index(i); + auto input_value = data[data_idx]; + // add index to selection vector if value in the range + if (min_value <= input_value && input_value <= max_value) { + auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position + // check for matches in the build + if (bitmap_build_idx[idx]) { + build_sel_vec.set_index(sel_idx, idx); + probe_sel_vec.set_index(sel_idx++, i); + probe_sel_count++; + } + } + } + } else { + for (idx_t i = 0, sel_idx = 0; i < count; ++i) { + // retrieve value from vector + auto data_idx = vector_data.sel->get_index(i); + if (!validity_mask->RowIsValid(data_idx)) { + continue; + } + auto input_value = data[data_idx]; + // add index to selection vector if value in the range + if (min_value <= input_value && input_value <= max_value) { + auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position + // check for matches in the build + if (bitmap_build_idx[idx]) { + build_sel_vec.set_index(sel_idx, idx); + probe_sel_vec.set_index(sel_idx++, i); + probe_sel_count++; + } + } + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp new file mode 100644 index 00000000..0752f5bd --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -0,0 +1,875 @@ +#include "duckdb/execution/operator/join/physical_asof_join.hpp" + +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/sort/comparators.hpp" +#include "duckdb/common/sort/partition_state.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/join/outer_join_marker.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/thread_context.hpp" + +#include + +namespace duckdb { + +PhysicalAsOfJoin::PhysicalAsOfJoin(LogicalComparisonJoin &op, unique_ptr left, + unique_ptr right) + : PhysicalComparisonJoin(op, PhysicalOperatorType::ASOF_JOIN, std::move(op.conditions), op.join_type, + op.estimated_cardinality), + comparison_type(ExpressionType::INVALID) { + + // Convert the conditions partitions and sorts + for (auto &cond : conditions) { + D_ASSERT(cond.left->return_type == cond.right->return_type); + join_key_types.push_back(cond.left->return_type); + + auto left = cond.left->Copy(); + auto right = cond.right->Copy(); + switch (cond.comparison) { + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + null_sensitive.emplace_back(lhs_orders.size()); + lhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(left)); + rhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(right)); + comparison_type = cond.comparison; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_LESSTHAN: + // Always put NULLS LAST so they can be ignored. + null_sensitive.emplace_back(lhs_orders.size()); + lhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(left)); + rhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(right)); + comparison_type = cond.comparison; + break; + case ExpressionType::COMPARE_EQUAL: + null_sensitive.emplace_back(lhs_orders.size()); + // Fall through + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + lhs_partitions.emplace_back(std::move(left)); + rhs_partitions.emplace_back(std::move(right)); + break; + default: + throw NotImplementedException("Unsupported join condition for ASOF join"); + } + } + D_ASSERT(!lhs_orders.empty()); + D_ASSERT(!rhs_orders.empty()); + + children.push_back(std::move(left)); + children.push_back(std::move(right)); + + // Fill out the right projection map. + right_projection_map = op.right_projection_map; + if (right_projection_map.empty()) { + const auto right_count = children[1]->types.size(); + right_projection_map.reserve(right_count); + for (column_t i = 0; i < right_count; ++i) { + right_projection_map.emplace_back(i); + } + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class AsOfGlobalSinkState : public GlobalSinkState { +public: + AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) + : rhs_sink(context, op.rhs_partitions, op.rhs_orders, op.children[1]->types, {}, op.estimated_cardinality), + is_outer(IsRightOuterJoin(op.join_type)), has_null(false) { + } + + idx_t Count() const { + return rhs_sink.count; + } + + PartitionLocalSinkState *RegisterBuffer(ClientContext &context) { + lock_guard guard(lock); + lhs_buffers.emplace_back(make_uniq(context, *lhs_sink)); + return lhs_buffers.back().get(); + } + + PartitionGlobalSinkState rhs_sink; + + // One per partition + const bool is_outer; + vector right_outers; + bool has_null; + + // Left side buffering + unique_ptr lhs_sink; + + mutex lock; + vector> lhs_buffers; +}; + +class AsOfLocalSinkState : public LocalSinkState { +public: + explicit AsOfLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) + : local_partition(context, gstate_p) { + } + + void Sink(DataChunk &input_chunk) { + local_partition.Sink(input_chunk); + } + + void Combine() { + local_partition.Combine(); + } + + PartitionLocalSinkState local_partition; +}; + +unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr PhysicalAsOfJoin::GetLocalSinkState(ExecutionContext &context) const { + // We only sink the RHS + auto &gsink = sink_state->Cast(); + return make_uniq(context.client, gsink.rhs_sink); +} + +SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + + lstate.Sink(chunk); + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &lstate = input.local_state.Cast(); + lstate.Combine(); + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + + // The data is all in so we can initialise the left partitioning. + const vector> partitions_stats; + gstate.lhs_sink = make_uniq(context, lhs_partitions, lhs_orders, children[0]->types, + partitions_stats, 0); + gstate.lhs_sink->SyncPartitioning(gstate.rhs_sink); + + // Find the first group to sort + if (!gstate.rhs_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { + // Empty input! + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // Schedule all the sorts for maximum thread utilisation + auto new_event = make_shared(gstate.rhs_sink, pipeline); + event.InsertEvent(std::move(new_event)); + + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +class AsOfGlobalState : public GlobalOperatorState { +public: + explicit AsOfGlobalState(AsOfGlobalSinkState &gsink) { + // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple + auto &rhs_partition = gsink.rhs_sink; + auto &right_outers = gsink.right_outers; + right_outers.reserve(rhs_partition.hash_groups.size()); + for (const auto &hash_group : rhs_partition.hash_groups) { + right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); + right_outers.back().Initialize(hash_group->count); + } + } +}; + +unique_ptr PhysicalAsOfJoin::GetGlobalOperatorState(ClientContext &context) const { + auto &gsink = sink_state->Cast(); + return make_uniq(gsink); +} + +class AsOfLocalState : public CachingOperatorState { +public: + AsOfLocalState(ClientContext &context, const PhysicalAsOfJoin &op) + : context(context), allocator(Allocator::Get(context)), op(op), lhs_executor(context), + left_outer(IsLeftOuterJoin(op.join_type)), fetch_next_left(true) { + lhs_keys.Initialize(allocator, op.join_key_types); + for (const auto &cond : op.conditions) { + lhs_executor.AddExpression(*cond.left); + } + + lhs_payload.Initialize(allocator, op.children[0]->types); + lhs_sel.Initialize(); + left_outer.Initialize(STANDARD_VECTOR_SIZE); + + auto &gsink = op.sink_state->Cast(); + lhs_partition_sink = gsink.RegisterBuffer(context); + } + + bool Sink(DataChunk &input); + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk); + + ClientContext &context; + Allocator &allocator; + const PhysicalAsOfJoin &op; + + ExpressionExecutor lhs_executor; + DataChunk lhs_keys; + ValidityMask lhs_valid_mask; + SelectionVector lhs_sel; + DataChunk lhs_payload; + + OuterJoinMarker left_outer; + bool fetch_next_left; + + optional_ptr lhs_partition_sink; +}; + +bool AsOfLocalState::Sink(DataChunk &input) { + // Compute the join keys + lhs_keys.Reset(); + lhs_executor.Execute(input, lhs_keys); + + // Combine the NULLs + const auto count = input.size(); + lhs_valid_mask.Reset(); + for (auto col_idx : op.null_sensitive) { + auto &col = lhs_keys.data[col_idx]; + UnifiedVectorFormat unified; + col.ToUnifiedFormat(count, unified); + lhs_valid_mask.Combine(unified.validity, count); + } + + // Convert the mask to a selection vector + // and mark all the rows that cannot match for early return. + idx_t lhs_valid = 0; + const auto entry_count = lhs_valid_mask.EntryCount(count); + idx_t base_idx = 0; + left_outer.Reset(); + for (idx_t entry_idx = 0; entry_idx < entry_count;) { + const auto validity_entry = lhs_valid_mask.GetValidityEntry(entry_idx++); + const auto next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + for (; base_idx < next; ++base_idx) { + lhs_sel.set_index(lhs_valid++, base_idx); + left_outer.SetMatch(base_idx); + } + } else if (ValidityMask::NoneValid(validity_entry)) { + base_idx = next; + } else { + const auto start = base_idx; + for (; base_idx < next; ++base_idx) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + lhs_sel.set_index(lhs_valid++, base_idx); + left_outer.SetMatch(base_idx); + } + } + } + } + + // Slice the keys to the ones we can match + lhs_payload.Reset(); + if (lhs_valid == count) { + lhs_payload.Reference(input); + lhs_payload.SetCardinality(input); + } else { + lhs_payload.Slice(input, lhs_sel, lhs_valid); + lhs_payload.SetCardinality(lhs_valid); + + // Flush the ones that can't match + fetch_next_left = false; + } + + lhs_partition_sink->Sink(lhs_payload); + + return false; +} + +OperatorResultType AsOfLocalState::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk) { + input.Verify(); + Sink(input); + + // If there were any unmatchable rows, return them now so we can forget about them. + if (!fetch_next_left) { + fetch_next_left = true; + left_outer.ConstructLeftJoinResult(input, chunk); + left_outer.Reset(); + } + + // Just keep asking for data and buffering it + return OperatorResultType::NEED_MORE_INPUT; +} + +OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &lstate_p) const { + auto &gsink = sink_state->Cast(); + auto &lstate = lstate_p.Cast(); + + if (gsink.rhs_sink.count == 0) { + // empty RHS + if (!EmptyResultIfRHSIsEmpty()) { + ConstructEmptyJoinResult(join_type, gsink.has_null, input, chunk); + return OperatorResultType::NEED_MORE_INPUT; + } else { + return OperatorResultType::FINISHED; + } + } + + return lstate.ExecuteInternal(context, input, chunk); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class AsOfProbeBuffer { +public: + using Orders = vector; + + static bool IsExternal(ClientContext &context) { + return ClientConfig::GetConfig(context).force_external; + } + + AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op); + +public: + void ResolveJoin(bool *found_matches, idx_t *matches = nullptr); + bool Scanning() const { + return lhs_scanner.get(); + } + void BeginLeftScan(hash_t scan_bin); + bool NextLeft(); + void EndScan(); + + // resolve joins that output max N elements (SEMI, ANTI, MARK) + void ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk); + // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + void ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk); + // Chunk may be empty + void GetData(ExecutionContext &context, DataChunk &chunk); + bool HasMoreData() const { + return !fetch_next_left || (lhs_scanner && lhs_scanner->Remaining()); + } + + ClientContext &context; + Allocator &allocator; + const PhysicalAsOfJoin &op; + BufferManager &buffer_manager; + const bool force_external; + const idx_t memory_per_thread; + Orders lhs_orders; + + // LHS scanning + SelectionVector lhs_sel; + optional_ptr left_hash; + OuterJoinMarker left_outer; + unique_ptr left_itr; + unique_ptr lhs_scanner; + DataChunk lhs_payload; + + // RHS scanning + optional_ptr right_hash; + optional_ptr right_outer; + unique_ptr right_itr; + unique_ptr rhs_scanner; + DataChunk rhs_payload; + + idx_t lhs_match_count; + bool fetch_next_left; +}; + +AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op) + : context(context), allocator(Allocator::Get(context)), op(op), + buffer_manager(BufferManager::GetBufferManager(context)), force_external(IsExternal(context)), + memory_per_thread(op.GetMaxThreadMemory(context)), left_outer(IsLeftOuterJoin(op.join_type)), + fetch_next_left(true) { + vector> partition_stats; + Orders partitions; // Not used. + PartitionGlobalSinkState::GenerateOrderings(partitions, lhs_orders, op.lhs_partitions, op.lhs_orders, + partition_stats); + + // We sort the row numbers of the incoming block, not the rows + lhs_payload.Initialize(allocator, op.children[0]->types); + rhs_payload.Initialize(allocator, op.children[1]->types); + + lhs_sel.Initialize(); + left_outer.Initialize(STANDARD_VECTOR_SIZE); +} + +void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { + auto &gsink = op.sink_state->Cast(); + auto &lhs_sink = *gsink.lhs_sink; + const auto left_group = lhs_sink.bin_groups[scan_bin]; + if (left_group >= lhs_sink.bin_groups.size()) { + return; + } + + auto iterator_comp = ExpressionType::INVALID; + switch (op.comparison_type) { + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + iterator_comp = ExpressionType::COMPARE_LESSTHANOREQUALTO; + break; + case ExpressionType::COMPARE_GREATERTHAN: + iterator_comp = ExpressionType::COMPARE_LESSTHAN; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + iterator_comp = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + break; + case ExpressionType::COMPARE_LESSTHAN: + iterator_comp = ExpressionType::COMPARE_GREATERTHAN; + break; + default: + throw NotImplementedException("Unsupported comparison type for ASOF join"); + } + + left_hash = lhs_sink.hash_groups[left_group].get(); + auto &left_sort = *(left_hash->global_sort); + if (left_sort.sorted_blocks.empty()) { + return; + } + lhs_scanner = make_uniq(left_sort, false); + left_itr = make_uniq(left_sort, iterator_comp); + + // We are only probing the corresponding right side bin, which may be empty + // If they are empty, we leave the iterator as null so we can emit left matches + auto &rhs_sink = gsink.rhs_sink; + const auto right_group = rhs_sink.bin_groups[scan_bin]; + if (right_group < rhs_sink.bin_groups.size()) { + right_hash = rhs_sink.hash_groups[right_group].get(); + right_outer = gsink.right_outers.data() + right_group; + auto &right_sort = *(right_hash->global_sort); + right_itr = make_uniq(right_sort, iterator_comp); + rhs_scanner = make_uniq(right_sort, false); + } +} + +bool AsOfProbeBuffer::NextLeft() { + if (!HasMoreData()) { + return false; + } + + // Scan the next sorted chunk + lhs_payload.Reset(); + left_itr->SetIndex(lhs_scanner->Scanned()); + lhs_scanner->Scan(lhs_payload); + + return true; +} + +void AsOfProbeBuffer::EndScan() { + right_hash = nullptr; + right_itr.reset(); + rhs_scanner.reset(); + right_outer = nullptr; + + left_hash = nullptr; + left_itr.reset(); + lhs_scanner.reset(); +} + +void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { + // If there was no right partition, there are no matches + lhs_match_count = 0; + left_outer.Reset(); + if (!right_itr) { + return; + } + + const auto count = lhs_payload.size(); + const auto left_base = left_itr->GetIndex(); + // Searching for right <= left + for (idx_t i = 0; i < count; ++i) { + left_itr->SetIndex(left_base + i); + + // If right > left, then there is no match + if (!right_itr->Compare(*left_itr)) { + continue; + } + + // Exponential search forward for a non-matching value using radix iterators + // (We use exponential search to avoid thrashing the block manager on large probes) + idx_t bound = 1; + idx_t begin = right_itr->GetIndex(); + right_itr->SetIndex(begin + bound); + while (right_itr->GetIndex() < right_hash->count) { + if (right_itr->Compare(*left_itr)) { + // If right <= left, jump ahead + bound *= 2; + right_itr->SetIndex(begin + bound); + } else { + break; + } + } + + // Binary search for the first non-matching value using radix iterators + // The previous value (which we know exists) is the match + auto first = begin + bound / 2; + auto last = MinValue(begin + bound, right_hash->count); + while (first < last) { + const auto mid = first + (last - first) / 2; + right_itr->SetIndex(mid); + if (right_itr->Compare(*left_itr)) { + // If right <= left, new lower bound + first = mid + 1; + } else { + last = mid; + } + } + right_itr->SetIndex(--first); + + // Check partitions for strict equality + if (right_hash->ComparePartitions(*left_itr, *right_itr)) { + continue; + } + + // Emit match data + right_outer->SetMatch(first); + left_outer.SetMatch(i); + if (found_match) { + found_match[i] = true; + } + if (matches) { + matches[i] = first; + } + lhs_sel.set_index(lhs_match_count++, i); + } +} + +unique_ptr PhysicalAsOfJoin::GetOperatorState(ExecutionContext &context) const { + return make_uniq(context.client, *this); +} + +void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk) { + // perform the actual join + bool found_match[STANDARD_VECTOR_SIZE] = {false}; + ResolveJoin(found_match); + + // now construct the result based on the join result + switch (op.join_type) { + case JoinType::SEMI: + PhysicalJoin::ConstructSemiJoinResult(lhs_payload, chunk, found_match); + break; + case JoinType::ANTI: + PhysicalJoin::ConstructAntiJoinResult(lhs_payload, chunk, found_match); + break; + default: + throw NotImplementedException("Unimplemented join type for AsOf join"); + } +} + +void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk) { + // perform the actual join + idx_t matches[STANDARD_VECTOR_SIZE]; + ResolveJoin(nullptr, matches); + + for (idx_t i = 0; i < lhs_match_count; ++i) { + const auto idx = lhs_sel[i]; + const auto match_pos = matches[idx]; + // Skip to the range containing the match + while (match_pos >= rhs_scanner->Scanned()) { + rhs_payload.Reset(); + rhs_scanner->Scan(rhs_payload); + } + // Append the individual values + // TODO: Batch the copies + const auto source_offset = match_pos - (rhs_scanner->Scanned() - rhs_payload.size()); + for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { + const auto rhs_idx = op.right_projection_map[col_idx]; + auto &source = rhs_payload.data[rhs_idx]; + auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; + VectorOperations::Copy(source, target, source_offset + 1, source_offset, i); + } + } + + // Slice the left payload into the result + for (column_t i = 0; i < lhs_payload.ColumnCount(); ++i) { + chunk.data[i].Slice(lhs_payload.data[i], lhs_sel, lhs_match_count); + } + chunk.SetCardinality(lhs_match_count); + + // If we are doing a left join, come back for the NULLs + fetch_next_left = !left_outer.Enabled(); +} + +void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { + // Handle dangling left join results from current chunk + if (!fetch_next_left) { + fetch_next_left = true; + if (left_outer.Enabled()) { + // left join: before we move to the next chunk, see if we need to output any vectors that didn't + // have a match found + left_outer.ConstructLeftJoinResult(lhs_payload, chunk); + left_outer.Reset(); + } + return; + } + + // Stop if there is no more data + if (!NextLeft()) { + return; + } + + switch (op.join_type) { + case JoinType::SEMI: + case JoinType::ANTI: + case JoinType::MARK: + // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk + ResolveSimpleJoin(context, chunk); + break; + case JoinType::LEFT: + case JoinType::INNER: + case JoinType::RIGHT: + case JoinType::OUTER: + ResolveComplexJoin(context, chunk); + break; + default: + throw NotImplementedException("Unimplemented type for as-of join!"); + } +} + +class AsOfGlobalSourceState : public GlobalSourceState { +public: + explicit AsOfGlobalSourceState(AsOfGlobalSinkState &gsink_p) + : gsink(gsink_p), next_combine(0), combined(0), merged(0), mergers(0), next_left(0), flushed(0), next_right(0) { + } + + PartitionGlobalMergeStates &GetMergeStates() { + lock_guard guard(lock); + if (!merge_states) { + merge_states = make_uniq(*gsink.lhs_sink); + } + return *merge_states; + } + + AsOfGlobalSinkState &gsink; + //! The next buffer to combine + atomic next_combine; + //! The number of combined buffers + atomic combined; + //! The number of combined buffers + atomic merged; + //! The number of combined buffers + atomic mergers; + //! The next buffer to flush + atomic next_left; + //! The number of flushed buffers + atomic flushed; + //! The right outer output read position. + atomic next_right; + //! The merge handler + mutex lock; + unique_ptr merge_states; + +public: + idx_t MaxThreads() override { + return gsink.lhs_buffers.size(); + } +}; + +unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &context) const { + auto &gsink = sink_state->Cast(); + return make_uniq(gsink); +} + +class AsOfLocalSourceState : public LocalSourceState { +public: + using HashGroupPtr = unique_ptr; + + AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, ClientContext &client_p); + + // Return true if we were not interrupted (another thread died) + bool CombineLeftPartitions(); + bool MergeLeftPartitions(); + + idx_t BeginRightScan(const idx_t hash_bin); + + AsOfGlobalSourceState &gsource; + ClientContext &client; + + //! The left side partition being probed + AsOfProbeBuffer probe_buffer; + + //! The read partition + idx_t hash_bin; + HashGroupPtr hash_group; + //! The read cursor + unique_ptr scanner; + //! Pointer to the matches + const bool *found_match = {}; +}; + +AsOfLocalSourceState::AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, + ClientContext &client_p) + : gsource(gsource), client(client_p), probe_buffer(gsource.gsink.lhs_sink->context, op) { + gsource.mergers++; +} + +bool AsOfLocalSourceState::CombineLeftPartitions() { + const auto buffer_count = gsource.gsink.lhs_buffers.size(); + while (gsource.combined < buffer_count && !client.interrupted) { + const auto next_combine = gsource.next_combine++; + if (next_combine < buffer_count) { + gsource.gsink.lhs_buffers[next_combine]->Combine(); + ++gsource.combined; + } else { + TaskScheduler::GetScheduler(client).YieldThread(); + } + } + + return !client.interrupted; +} + +bool AsOfLocalSourceState::MergeLeftPartitions() { + PartitionGlobalMergeStates::Callback local_callback; + PartitionLocalMergeState local_merge(*gsource.gsink.lhs_sink); + gsource.GetMergeStates().ExecuteTask(local_merge, local_callback); + gsource.merged++; + while (gsource.merged < gsource.mergers && !client.interrupted) { + TaskScheduler::GetScheduler(client).YieldThread(); + } + return !client.interrupted; +} + +idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { + hash_bin = hash_bin_p; + + hash_group = std::move(gsource.gsink.rhs_sink.hash_groups[hash_bin]); + if (hash_group->global_sort->sorted_blocks.empty()) { + return 0; + } + scanner = make_uniq(*hash_group->global_sort); + found_match = gsource.gsink.right_outers[hash_bin].GetMatches(); + + return scanner->Remaining(); +} + +unique_ptr PhysicalAsOfJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + auto &gsource = gstate.Cast(); + return make_uniq(gsource, *this, context.client); +} + +SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gsource = input.global_state.Cast(); + auto &lsource = input.local_state.Cast(); + auto &rhs_sink = gsource.gsink.rhs_sink; + auto &client = context.client; + + // Step 1: Combine the partitions + if (!lsource.CombineLeftPartitions()) { + return SourceResultType::FINISHED; + } + + // Step 2: Sort on all threads + if (!lsource.MergeLeftPartitions()) { + return SourceResultType::FINISHED; + } + + // Step 3: Join the partitions + auto &lhs_sink = *gsource.gsink.lhs_sink; + const auto left_bins = lhs_sink.grouping_data ? lhs_sink.grouping_data->GetPartitions().size() : 1; + while (gsource.flushed < left_bins) { + // Make sure we have something to flush + if (!lsource.probe_buffer.Scanning()) { + const auto left_bin = gsource.next_left++; + if (left_bin < left_bins) { + // More to flush + lsource.probe_buffer.BeginLeftScan(left_bin); + } else if (!IsRightOuterJoin(join_type) || client.interrupted) { + return SourceResultType::FINISHED; + } else { + // Wait for all threads to finish + // TODO: How to implement a spin wait correctly? + // Returning BLOCKED seems to hang the system. + TaskScheduler::GetScheduler(client).YieldThread(); + continue; + } + } + + lsource.probe_buffer.GetData(context, chunk); + if (chunk.size()) { + return SourceResultType::HAVE_MORE_OUTPUT; + } else if (lsource.probe_buffer.HasMoreData()) { + // Join the next partition + continue; + } else { + lsource.probe_buffer.EndScan(); + gsource.flushed++; + } + } + + // Step 4: Emit right join matches + if (!IsRightOuterJoin(join_type)) { + return SourceResultType::FINISHED; + } + + auto &hash_groups = rhs_sink.hash_groups; + const auto right_groups = hash_groups.size(); + + DataChunk rhs_chunk; + rhs_chunk.Initialize(Allocator::Get(context.client), rhs_sink.payload_types); + SelectionVector rsel(STANDARD_VECTOR_SIZE); + + while (chunk.size() == 0) { + // Move to the next bin if we are done. + while (!lsource.scanner || !lsource.scanner->Remaining()) { + lsource.scanner.reset(); + lsource.hash_group.reset(); + auto hash_bin = gsource.next_right++; + if (hash_bin >= right_groups) { + return SourceResultType::FINISHED; + } + + for (; hash_bin < hash_groups.size(); hash_bin = gsource.next_right++) { + if (hash_groups[hash_bin]) { + break; + } + } + lsource.BeginRightScan(hash_bin); + } + const auto rhs_position = lsource.scanner->Scanned(); + lsource.scanner->Scan(rhs_chunk); + + const auto count = rhs_chunk.size(); + if (count == 0) { + return SourceResultType::FINISHED; + } + + // figure out which tuples didn't find a match in the RHS + auto found_match = lsource.found_match; + idx_t result_count = 0; + for (idx_t i = 0; i < count; i++) { + if (!found_match[rhs_position + i]) { + rsel.set_index(result_count++, i); + } + } + + if (result_count > 0) { + // if there were any tuples that didn't find a match, output them + const idx_t left_column_count = children[0]->types.size(); + for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } + for (idx_t col_idx = 0; col_idx < right_projection_map.size(); ++col_idx) { + const auto rhs_idx = right_projection_map[col_idx]; + chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); + } + chunk.SetCardinality(result_count); + break; + } + } + + return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp new file mode 100644 index 00000000..33a7947b --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp @@ -0,0 +1,264 @@ +#include "duckdb/execution/operator/join/physical_blockwise_nl_join.hpp" + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/join/outer_join_marker.hpp" +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" +#include "duckdb/execution/operator/join/physical_cross_product.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +PhysicalBlockwiseNLJoin::PhysicalBlockwiseNLJoin(LogicalOperator &op, unique_ptr left, + unique_ptr right, unique_ptr condition, + JoinType join_type, idx_t estimated_cardinality) + : PhysicalJoin(op, PhysicalOperatorType::BLOCKWISE_NL_JOIN, join_type, estimated_cardinality), + condition(std::move(condition)) { + children.push_back(std::move(left)); + children.push_back(std::move(right)); + // MARK and SINGLE joins not handled + D_ASSERT(join_type != JoinType::MARK); + D_ASSERT(join_type != JoinType::SINGLE); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class BlockwiseNLJoinLocalState : public LocalSinkState { +public: + BlockwiseNLJoinLocalState() { + } +}; + +class BlockwiseNLJoinGlobalState : public GlobalSinkState { +public: + explicit BlockwiseNLJoinGlobalState(ClientContext &context, const PhysicalBlockwiseNLJoin &op) + : right_chunks(context, op.children[1]->GetTypes()), right_outer(IsRightOuterJoin(op.join_type)) { + } + + mutex lock; + ColumnDataCollection right_chunks; + OuterJoinMarker right_outer; +}; + +unique_ptr PhysicalBlockwiseNLJoin::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr PhysicalBlockwiseNLJoin::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(); +} + +SinkResultType PhysicalBlockwiseNLJoin::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + lock_guard nl_lock(gstate.lock); + gstate.right_chunks.Append(chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +SinkFinalizeType PhysicalBlockwiseNLJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + gstate.right_outer.Initialize(gstate.right_chunks.Count()); + + if (gstate.right_chunks.Count() == 0 && EmptyResultIfRHSIsEmpty()) { + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +class BlockwiseNLJoinState : public CachingOperatorState { +public: + explicit BlockwiseNLJoinState(ExecutionContext &context, ColumnDataCollection &rhs, + const PhysicalBlockwiseNLJoin &op) + : cross_product(rhs), left_outer(IsLeftOuterJoin(op.join_type)), match_sel(STANDARD_VECTOR_SIZE), + executor(context.client, *op.condition) { + left_outer.Initialize(STANDARD_VECTOR_SIZE); + } + + CrossProductExecutor cross_product; + OuterJoinMarker left_outer; + SelectionVector match_sel; + ExpressionExecutor executor; + DataChunk intermediate_chunk; +}; + +unique_ptr PhysicalBlockwiseNLJoin::GetOperatorState(ExecutionContext &context) const { + auto &gstate = sink_state->Cast(); + auto result = make_uniq(context, gstate.right_chunks, *this); + if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { + vector intermediate_types; + for (auto &type : children[0]->types) { + intermediate_types.emplace_back(type); + } + for (auto &type : children[1]->types) { + intermediate_types.emplace_back(type); + } + result->intermediate_chunk.Initialize(Allocator::DefaultAllocator(), intermediate_types); + } + return std::move(result); +} + +OperatorResultType PhysicalBlockwiseNLJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, + DataChunk &chunk, GlobalOperatorState &gstate_p, + OperatorState &state_p) const { + D_ASSERT(input.size() > 0); + auto &state = state_p.Cast(); + auto &gstate = sink_state->Cast(); + + if (gstate.right_chunks.Count() == 0) { + // empty RHS + if (!EmptyResultIfRHSIsEmpty()) { + PhysicalComparisonJoin::ConstructEmptyJoinResult(join_type, false, input, chunk); + return OperatorResultType::NEED_MORE_INPUT; + } else { + return OperatorResultType::FINISHED; + } + } + + DataChunk *intermediate_chunk = &chunk; + if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { + intermediate_chunk = &state.intermediate_chunk; + intermediate_chunk->Reset(); + } + + // now perform the actual join + // we perform a cross product, then execute the expression directly on the cross product result + idx_t result_count = 0; + bool found_match[STANDARD_VECTOR_SIZE] = {false}; + + do { + auto result = state.cross_product.Execute(input, *intermediate_chunk); + if (result == OperatorResultType::NEED_MORE_INPUT) { + // exhausted input, have to pull new LHS chunk + if (state.left_outer.Enabled()) { + // left join: before we move to the next chunk, see if we need to output any vectors that didn't + // have a match found + state.left_outer.ConstructLeftJoinResult(input, *intermediate_chunk); + state.left_outer.Reset(); + } + + if (join_type == JoinType::SEMI) { + PhysicalJoin::ConstructSemiJoinResult(input, chunk, found_match); + } + if (join_type == JoinType::ANTI) { + PhysicalJoin::ConstructAntiJoinResult(input, chunk, found_match); + } + + return OperatorResultType::NEED_MORE_INPUT; + } + + // now perform the computation + result_count = state.executor.SelectExpression(*intermediate_chunk, state.match_sel); + + // handle anti and semi joins with different logic + if (result_count > 0) { + // found a match! + // handle anti semi join conditions first + if (join_type == JoinType::ANTI || join_type == JoinType::SEMI) { + if (state.cross_product.ScanLHS()) { + found_match[state.cross_product.PositionInChunk()] = true; + } else { + for (idx_t i = 0; i < result_count; i++) { + found_match[state.match_sel.get_index(i)] = true; + } + } + intermediate_chunk->Reset(); + // trick the loop to continue as semi and anti joins will never produce more output than + // the LHS cardinality + result_count = 0; + } else { + // check if the cross product is scanning the LHS or the RHS in its entirety + if (!state.cross_product.ScanLHS()) { + // set the match flags in the LHS + state.left_outer.SetMatches(state.match_sel, result_count); + // set the match flag in the RHS + gstate.right_outer.SetMatch(state.cross_product.ScanPosition() + + state.cross_product.PositionInChunk()); + } else { + // set the match flag in the LHS + state.left_outer.SetMatch(state.cross_product.PositionInChunk()); + // set the match flags in the RHS + gstate.right_outer.SetMatches(state.match_sel, result_count, state.cross_product.ScanPosition()); + } + intermediate_chunk->Slice(state.match_sel, result_count); + } + } else { + // no result: reset the chunk + intermediate_chunk->Reset(); + } + } while (result_count == 0); + + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +string PhysicalBlockwiseNLJoin::ParamsToString() const { + string extra_info = EnumUtil::ToString(join_type) + "\n"; + extra_info += condition->GetName(); + return extra_info; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class BlockwiseNLJoinGlobalScanState : public GlobalSourceState { +public: + explicit BlockwiseNLJoinGlobalScanState(const PhysicalBlockwiseNLJoin &op) : op(op) { + D_ASSERT(op.sink_state); + auto &sink = op.sink_state->Cast(); + sink.right_outer.InitializeScan(sink.right_chunks, scan_state); + } + + const PhysicalBlockwiseNLJoin &op; + OuterJoinGlobalScanState scan_state; + +public: + idx_t MaxThreads() override { + auto &sink = op.sink_state->Cast(); + return sink.right_outer.MaxThreads(); + } +}; + +class BlockwiseNLJoinLocalScanState : public LocalSourceState { +public: + explicit BlockwiseNLJoinLocalScanState(const PhysicalBlockwiseNLJoin &op, BlockwiseNLJoinGlobalScanState &gstate) { + D_ASSERT(op.sink_state); + auto &sink = op.sink_state->Cast(); + sink.right_outer.InitializeScan(gstate.scan_state, scan_state); + } + + OuterJoinLocalScanState scan_state; +}; + +unique_ptr PhysicalBlockwiseNLJoin::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +unique_ptr PhysicalBlockwiseNLJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(*this, gstate.Cast()); +} + +SourceResultType PhysicalBlockwiseNLJoin::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + D_ASSERT(IsRightOuterJoin(join_type)); + // check if we need to scan any unmatched tuples from the RHS for the full/right outer join + auto &sink = sink_state->Cast(); + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan chunks we still need to output + sink.right_outer.Scan(gstate.scan_state, lstate.scan_state, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp b/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp new file mode 100644 index 00000000..dc4e9b97 --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp @@ -0,0 +1,83 @@ +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +PhysicalComparisonJoin::PhysicalComparisonJoin(LogicalOperator &op, PhysicalOperatorType type, + vector conditions_p, JoinType join_type, + idx_t estimated_cardinality) + : PhysicalJoin(op, type, join_type, estimated_cardinality) { + conditions.resize(conditions_p.size()); + // we reorder conditions so the ones with COMPARE_EQUAL occur first + idx_t equal_position = 0; + idx_t other_position = conditions_p.size() - 1; + for (idx_t i = 0; i < conditions_p.size(); i++) { + if (conditions_p[i].comparison == ExpressionType::COMPARE_EQUAL || + conditions_p[i].comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + // COMPARE_EQUAL and COMPARE_NOT_DISTINCT_FROM, move to the start + conditions[equal_position++] = std::move(conditions_p[i]); + } else { + // other expression, move to the end + conditions[other_position--] = std::move(conditions_p[i]); + } + } +} + +string PhysicalComparisonJoin::ParamsToString() const { + string extra_info = EnumUtil::ToString(join_type) + "\n"; + for (auto &it : conditions) { + string op = ExpressionTypeToOperator(it.comparison); + extra_info += it.left->GetName() + " " + op + " " + it.right->GetName() + "\n"; + } + extra_info += "\n[INFOSEPARATOR]\n"; + extra_info += StringUtil::Format("EC: %llu\n", estimated_cardinality); + return extra_info; +} + +void PhysicalComparisonJoin::ConstructEmptyJoinResult(JoinType join_type, bool has_null, DataChunk &input, + DataChunk &result) { + // empty hash table, special case + if (join_type == JoinType::ANTI) { + // anti join with empty hash table, NOP join + // return the input + D_ASSERT(input.ColumnCount() == result.ColumnCount()); + result.Reference(input); + } else if (join_type == JoinType::MARK) { + // MARK join with empty hash table + D_ASSERT(join_type == JoinType::MARK); + D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1); + auto &result_vector = result.data.back(); + D_ASSERT(result_vector.GetType() == LogicalType::BOOLEAN); + // for every data vector, we just reference the child chunk + result.SetCardinality(input); + for (idx_t i = 0; i < input.ColumnCount(); i++) { + result.data[i].Reference(input.data[i]); + } + // for the MARK vector: + // if the HT has no NULL values (i.e. empty result set), return a vector that has false for every input + // entry if the HT has NULL values (i.e. result set had values, but all were NULL), return a vector that + // has NULL for every input entry + if (!has_null) { + auto bool_result = FlatVector::GetData(result_vector); + for (idx_t i = 0; i < result.size(); i++) { + bool_result[i] = false; + } + } else { + FlatVector::Validity(result_vector).SetAllInvalid(result.size()); + } + } else if (join_type == JoinType::LEFT || join_type == JoinType::OUTER || join_type == JoinType::SINGLE) { + // LEFT/FULL OUTER/SINGLE join and build side is empty + // for the LHS we reference the data + result.SetCardinality(input.size()); + for (idx_t i = 0; i < input.ColumnCount(); i++) { + result.data[i].Reference(input.data[i]); + } + // for the RHS + for (idx_t k = input.ColumnCount(); k < result.ColumnCount(); k++) { + result.data[k].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result.data[k], true); + } + } +} +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_cross_product.cpp b/src/duckdb/src/execution/operator/join/physical_cross_product.cpp new file mode 100644 index 00000000..a1175017 --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_cross_product.cpp @@ -0,0 +1,146 @@ +#include "duckdb/execution/operator/join/physical_cross_product.hpp" + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/operator/join/physical_join.hpp" + +namespace duckdb { + +PhysicalCrossProduct::PhysicalCrossProduct(vector types, unique_ptr left, + unique_ptr right, idx_t estimated_cardinality) + : CachingPhysicalOperator(PhysicalOperatorType::CROSS_PRODUCT, std::move(types), estimated_cardinality) { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class CrossProductGlobalState : public GlobalSinkState { +public: + explicit CrossProductGlobalState(ClientContext &context, const PhysicalCrossProduct &op) + : rhs_materialized(context, op.children[1]->GetTypes()) { + rhs_materialized.InitializeAppend(append_state); + } + + ColumnDataCollection rhs_materialized; + ColumnDataAppendState append_state; + mutex rhs_lock; +}; + +unique_ptr PhysicalCrossProduct::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +SinkResultType PhysicalCrossProduct::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &sink = input.global_state.Cast(); + lock_guard client_guard(sink.rhs_lock); + sink.rhs_materialized.Append(sink.append_state, chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +CrossProductExecutor::CrossProductExecutor(ColumnDataCollection &rhs) + : rhs(rhs), position_in_chunk(0), initialized(false), finished(false) { + rhs.InitializeScanChunk(scan_chunk); +} + +void CrossProductExecutor::Reset(DataChunk &input, DataChunk &output) { + initialized = true; + finished = false; + scan_input_chunk = false; + rhs.InitializeScan(scan_state); + position_in_chunk = 0; + scan_chunk.Reset(); +} + +bool CrossProductExecutor::NextValue(DataChunk &input, DataChunk &output) { + if (!initialized) { + // not initialized yet: initialize the scan + Reset(input, output); + } + position_in_chunk++; + idx_t chunk_size = scan_input_chunk ? input.size() : scan_chunk.size(); + if (position_in_chunk < chunk_size) { + return true; + } + // fetch the next chunk + rhs.Scan(scan_state, scan_chunk); + position_in_chunk = 0; + if (scan_chunk.size() == 0) { + return false; + } + // the way the cross product works is that we keep one chunk constantly referenced + // while iterating over the other chunk one value at a time + // the second one is the chunk we are "scanning" + + // for the engine, it is better if we emit larger chunks + // hence the chunk that we keep constantly referenced should be the larger of the two + scan_input_chunk = input.size() < scan_chunk.size(); + return true; +} + +OperatorResultType CrossProductExecutor::Execute(DataChunk &input, DataChunk &output) { + if (rhs.Count() == 0) { + // no RHS: empty result + return OperatorResultType::FINISHED; + } + if (!NextValue(input, output)) { + // ran out of entries on the RHS + // reset the RHS and move to the next chunk on the LHS + initialized = false; + return OperatorResultType::NEED_MORE_INPUT; + } + + // set up the constant chunk + auto &constant_chunk = scan_input_chunk ? scan_chunk : input; + auto col_count = constant_chunk.ColumnCount(); + auto col_offset = scan_input_chunk ? input.ColumnCount() : 0; + output.SetCardinality(constant_chunk.size()); + for (idx_t i = 0; i < col_count; i++) { + output.data[col_offset + i].Reference(constant_chunk.data[i]); + } + + // for the chunk that we are scanning, scan a single value from that chunk + auto &scan = scan_input_chunk ? input : scan_chunk; + col_count = scan.ColumnCount(); + col_offset = scan_input_chunk ? 0 : input.ColumnCount(); + for (idx_t i = 0; i < col_count; i++) { + ConstantVector::Reference(output.data[col_offset + i], scan.data[i], position_in_chunk, scan.size()); + } + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +class CrossProductOperatorState : public CachingOperatorState { +public: + explicit CrossProductOperatorState(ColumnDataCollection &rhs) : executor(rhs) { + } + + CrossProductExecutor executor; +}; + +unique_ptr PhysicalCrossProduct::GetOperatorState(ExecutionContext &context) const { + auto &sink = sink_state->Cast(); + return make_uniq(sink.rhs_materialized); +} + +OperatorResultType PhysicalCrossProduct::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &state = state_p.Cast(); + return state.executor.Execute(input, chunk); +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalCrossProduct::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); +} + +vector> PhysicalCrossProduct::GetSources() const { + return children[0]->GetSources(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_delim_join.cpp new file mode 100644 index 00000000..487fc35d --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_delim_join.cpp @@ -0,0 +1,151 @@ +#include "duckdb/execution/operator/join/physical_delim_join.hpp" + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/thread_context.hpp" + +namespace duckdb { + +PhysicalDelimJoin::PhysicalDelimJoin(vector types, unique_ptr original_join, + vector> delim_scans, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::DELIM_JOIN, std::move(types), estimated_cardinality), + join(std::move(original_join)), delim_scans(std::move(delim_scans)) { + D_ASSERT(join->children.size() == 2); + // now for the original join + // we take its left child, this is the side that we will duplicate eliminate + children.push_back(std::move(join->children[0])); + + // we replace it with a PhysicalColumnDataScan, that scans the ColumnDataCollection that we keep cached + // the actual chunk collection to scan will be created in the DelimJoinGlobalState + auto cached_chunk_scan = make_uniq( + children[0]->GetTypes(), PhysicalOperatorType::COLUMN_DATA_SCAN, estimated_cardinality); + join->children[0] = std::move(cached_chunk_scan); +} + +vector> PhysicalDelimJoin::GetChildren() const { + vector> result; + for (auto &child : children) { + result.push_back(*child); + } + result.push_back(*join); + result.push_back(*distinct); + return result; +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class DelimJoinGlobalState : public GlobalSinkState { +public: + explicit DelimJoinGlobalState(ClientContext &context, const PhysicalDelimJoin &delim_join) + : lhs_data(context, delim_join.children[0]->GetTypes()) { + D_ASSERT(delim_join.delim_scans.size() > 0); + // set up the delim join chunk to scan in the original join + auto &cached_chunk_scan = delim_join.join->children[0]->Cast(); + cached_chunk_scan.collection = &lhs_data; + } + + ColumnDataCollection lhs_data; + mutex lhs_lock; + + void Merge(ColumnDataCollection &input) { + lock_guard guard(lhs_lock); + lhs_data.Combine(input); + } +}; + +class DelimJoinLocalState : public LocalSinkState { +public: + explicit DelimJoinLocalState(ClientContext &context, const PhysicalDelimJoin &delim_join) + : lhs_data(context, delim_join.children[0]->GetTypes()) { + lhs_data.InitializeAppend(append_state); + } + + unique_ptr distinct_state; + ColumnDataCollection lhs_data; + ColumnDataAppendState append_state; + + void Append(DataChunk &input) { + lhs_data.Append(input); + } +}; + +unique_ptr PhysicalDelimJoin::GetGlobalSinkState(ClientContext &context) const { + auto state = make_uniq(context, *this); + distinct->sink_state = distinct->GetGlobalSinkState(context); + if (delim_scans.size() > 1) { + PhysicalHashAggregate::SetMultiScan(*distinct->sink_state); + } + return std::move(state); +} + +unique_ptr PhysicalDelimJoin::GetLocalSinkState(ExecutionContext &context) const { + auto state = make_uniq(context.client, *this); + state->distinct_state = distinct->GetLocalSinkState(context); + return std::move(state); +} + +SinkResultType PhysicalDelimJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + lstate.lhs_data.Append(lstate.append_state, chunk); + OperatorSinkInput distinct_sink_input {*distinct->sink_state, *lstate.distinct_state, input.interrupt_state}; + distinct->Sink(context, chunk, distinct_sink_input); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalDelimJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &lstate = input.local_state.Cast(); + auto &gstate = input.global_state.Cast(); + gstate.Merge(lstate.lhs_data); + + OperatorSinkCombineInput distinct_combine_input {*distinct->sink_state, *lstate.distinct_state, + input.interrupt_state}; + distinct->Combine(context, distinct_combine_input); + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalDelimJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, + OperatorSinkFinalizeInput &input) const { + // finalize the distinct HT + D_ASSERT(distinct); + + OperatorSinkFinalizeInput finalize_input {*distinct->sink_state, input.interrupt_state}; + distinct->Finalize(pipeline, event, client, finalize_input); + return SinkFinalizeType::READY; +} + +string PhysicalDelimJoin::ParamsToString() const { + return join->ParamsToString(); +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalDelimJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + op_state.reset(); + sink_state.reset(); + + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + child_meta_pipeline.Build(*children[0]); + + if (type == PhysicalOperatorType::DELIM_JOIN) { + // recurse into the actual join + // any pipelines in there depend on the main pipeline + // any scan of the duplicate eliminated data on the RHS depends on this pipeline + // we add an entry to the mapping of (PhysicalOperator*) -> (Pipeline*) + auto &state = meta_pipeline.GetState(); + for (auto &delim_scan : delim_scans) { + state.delim_join_dependencies.insert( + make_pair(delim_scan, reference(*child_meta_pipeline.GetBasePipeline()))); + } + join->BuildPipelines(current, meta_pipeline); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp new file mode 100644 index 00000000..d6faee75 --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp @@ -0,0 +1,929 @@ +#include "duckdb/execution/operator/join/physical_hash_join.hpp" + +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/storage_manager.hpp" + +namespace duckdb { + +PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr left, + unique_ptr right, vector cond, JoinType join_type, + const vector &left_projection_map, + const vector &right_projection_map_p, vector delim_types, + idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_stats) + : PhysicalComparisonJoin(op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality), + right_projection_map(right_projection_map_p), delim_types(std::move(delim_types)), + perfect_join_statistics(std::move(perfect_join_stats)) { + + children.push_back(std::move(left)); + children.push_back(std::move(right)); + + D_ASSERT(left_projection_map.empty()); + for (auto &condition : conditions) { + condition_types.push_back(condition.left->return_type); + } + + // for ANTI, SEMI and MARK join, we only need to store the keys, so for these the build types are empty + if (join_type != JoinType::ANTI && join_type != JoinType::SEMI && join_type != JoinType::MARK) { + build_types = LogicalOperator::MapTypes(children[1]->GetTypes(), right_projection_map); + } +} + +PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr left, + unique_ptr right, vector cond, JoinType join_type, + idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_state) + : PhysicalHashJoin(op, std::move(left), std::move(right), std::move(cond), join_type, {}, {}, {}, + estimated_cardinality, std::move(perfect_join_state)) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class HashJoinGlobalSinkState : public GlobalSinkState { +public: + HashJoinGlobalSinkState(const PhysicalHashJoin &op, ClientContext &context_p) + : context(context_p), finalized(false), scanned_data(false) { + hash_table = op.InitializeHashTable(context); + + // for perfect hash join + perfect_join_executor = make_uniq(op, *hash_table, op.perfect_join_statistics); + // for external hash join + external = ClientConfig::GetConfig(context).force_external; + // Set probe types + const auto &payload_types = op.children[0]->types; + probe_types.insert(probe_types.end(), op.condition_types.begin(), op.condition_types.end()); + probe_types.insert(probe_types.end(), payload_types.begin(), payload_types.end()); + probe_types.emplace_back(LogicalType::HASH); + } + + void ScheduleFinalize(Pipeline &pipeline, Event &event); + void InitializeProbeSpill(); + +public: + ClientContext &context; + //! Global HT used by the join + unique_ptr hash_table; + //! The perfect hash join executor (if any) + unique_ptr perfect_join_executor; + //! Whether or not the hash table has been finalized + bool finalized = false; + + //! Whether we are doing an external join + bool external; + + //! Hash tables built by each thread + mutex lock; + vector> local_hash_tables; + + //! Excess probe data gathered during Sink + vector probe_types; + unique_ptr probe_spill; + + //! Whether or not we have started scanning data using GetData + atomic scanned_data; +}; + +class HashJoinLocalSinkState : public LocalSinkState { +public: + HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context) : build_executor(context) { + auto &allocator = BufferAllocator::Get(context); + if (!op.right_projection_map.empty()) { + build_chunk.Initialize(allocator, op.build_types); + } + for (auto &cond : op.conditions) { + build_executor.AddExpression(*cond.right); + } + join_keys.Initialize(allocator, op.condition_types); + + hash_table = op.InitializeHashTable(context); + + hash_table->GetSinkCollection().InitializeAppendState(append_state); + } + +public: + PartitionedTupleDataAppendState append_state; + + DataChunk build_chunk; + DataChunk join_keys; + ExpressionExecutor build_executor; + + //! Thread-local HT + unique_ptr hash_table; +}; + +unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &context) const { + auto result = + make_uniq(BufferManager::GetBufferManager(context), conditions, build_types, join_type); + result->max_ht_size = double(0.6) * BufferManager::GetBufferManager(context).GetMaxMemory(); + if (!delim_types.empty() && join_type == JoinType::MARK) { + // correlated MARK join + if (delim_types.size() + 1 == conditions.size()) { + // the correlated MARK join has one more condition than the amount of correlated columns + // this is the case in a correlated ANY() expression + // in this case we need to keep track of additional entries, namely: + // - (1) the total amount of elements per group + // - (2) the amount of non-null elements per group + // we need these to correctly deal with the cases of either: + // - (1) the group being empty [in which case the result is always false, even if the comparison is NULL] + // - (2) the group containing a NULL value [in which case FALSE becomes NULL] + auto &info = result->correlated_mark_join_info; + + vector payload_types; + vector correlated_aggregates; + unique_ptr aggr; + + // jury-rigging the GroupedAggregateHashTable + // we need a count_star and a count to get counts with and without NULLs + + FunctionBinder function_binder(context); + aggr = function_binder.BindAggregateFunction(CountStarFun::GetFunction(), {}, nullptr, + AggregateType::NON_DISTINCT); + correlated_aggregates.push_back(&*aggr); + payload_types.push_back(aggr->return_type); + info.correlated_aggregates.push_back(std::move(aggr)); + + auto count_fun = CountFun::GetFunction(); + vector> children; + // this is a dummy but we need it to make the hash table understand whats going on + children.push_back(make_uniq_base(count_fun.return_type, 0)); + aggr = function_binder.BindAggregateFunction(count_fun, std::move(children), nullptr, + AggregateType::NON_DISTINCT); + correlated_aggregates.push_back(&*aggr); + payload_types.push_back(aggr->return_type); + info.correlated_aggregates.push_back(std::move(aggr)); + + auto &allocator = BufferAllocator::Get(context); + info.correlated_counts = make_uniq(context, allocator, delim_types, + payload_types, correlated_aggregates); + info.correlated_types = delim_types; + info.group_chunk.Initialize(allocator, delim_types); + info.result_chunk.Initialize(allocator, payload_types); + } + } + return result; +} + +unique_ptr PhysicalHashJoin::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(*this, context); +} + +unique_ptr PhysicalHashJoin::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(*this, context.client); +} + +SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + + // resolve the join keys for the right chunk + lstate.join_keys.Reset(); + lstate.build_executor.Execute(chunk, lstate.join_keys); + + // build the HT + auto &ht = *lstate.hash_table; + if (!right_projection_map.empty()) { + // there is a projection map: fill the build chunk with the projected columns + lstate.build_chunk.Reset(); + lstate.build_chunk.SetCardinality(chunk); + for (idx_t i = 0; i < right_projection_map.size(); i++) { + lstate.build_chunk.data[i].Reference(chunk.data[right_projection_map[i]]); + } + ht.Build(lstate.append_state, lstate.join_keys, lstate.build_chunk); + } else if (!build_types.empty()) { + // there is not a projected map: place the entire right chunk in the HT + ht.Build(lstate.append_state, lstate.join_keys, chunk); + } else { + // there are only keys: place an empty chunk in the payload + lstate.build_chunk.SetCardinality(chunk.size()); + ht.Build(lstate.append_state, lstate.join_keys, lstate.build_chunk); + } + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalHashJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + if (lstate.hash_table) { + lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); + lock_guard local_ht_lock(gstate.lock); + gstate.local_hash_tables.push_back(std::move(lstate.hash_table)); + } + auto &client_profiler = QueryProfiler::Get(context.client); + context.thread.profiler.Flush(*this, lstate.build_executor, "build_executor", 1); + client_profiler.Flush(context.thread.profiler); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +class HashJoinFinalizeTask : public ExecutorTask { +public: + HashJoinFinalizeTask(shared_ptr event_p, ClientContext &context, HashJoinGlobalSinkState &sink_p, + idx_t chunk_idx_from_p, idx_t chunk_idx_to_p, bool parallel_p) + : ExecutorTask(context), event(std::move(event_p)), sink(sink_p), chunk_idx_from(chunk_idx_from_p), + chunk_idx_to(chunk_idx_to_p), parallel(parallel_p) { + } + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + sink.hash_table->Finalize(chunk_idx_from, chunk_idx_to, parallel); + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; + } + +private: + shared_ptr event; + HashJoinGlobalSinkState &sink; + idx_t chunk_idx_from; + idx_t chunk_idx_to; + bool parallel; +}; + +class HashJoinFinalizeEvent : public BasePipelineEvent { +public: + HashJoinFinalizeEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink) + : BasePipelineEvent(pipeline_p), sink(sink) { + } + + HashJoinGlobalSinkState &sink; + +public: + void Schedule() override { + auto &context = pipeline->GetClientContext(); + + vector> finalize_tasks; + auto &ht = *sink.hash_table; + const auto chunk_count = ht.GetDataCollection().ChunkCount(); + const idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + if (num_threads == 1 || (ht.Count() < PARALLEL_CONSTRUCT_THRESHOLD && !context.config.verify_parallelism)) { + // Single-threaded finalize + finalize_tasks.push_back( + make_uniq(shared_from_this(), context, sink, 0, chunk_count, false)); + } else { + // Parallel finalize + auto chunks_per_thread = MaxValue((chunk_count + num_threads - 1) / num_threads, 1); + + idx_t chunk_idx = 0; + for (idx_t thread_idx = 0; thread_idx < num_threads; thread_idx++) { + auto chunk_idx_from = chunk_idx; + auto chunk_idx_to = MinValue(chunk_idx_from + chunks_per_thread, chunk_count); + finalize_tasks.push_back(make_uniq(shared_from_this(), context, sink, + chunk_idx_from, chunk_idx_to, true)); + chunk_idx = chunk_idx_to; + if (chunk_idx == chunk_count) { + break; + } + } + } + SetTasks(std::move(finalize_tasks)); + } + + void FinishEvent() override { + sink.hash_table->GetDataCollection().VerifyEverythingPinned(); + sink.hash_table->finalized = true; + } + + static constexpr const idx_t PARALLEL_CONSTRUCT_THRESHOLD = 1048576; +}; + +void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) { + if (hash_table->Count() == 0) { + hash_table->finalized = true; + return; + } + hash_table->InitializePointerTable(); + auto new_event = make_shared(pipeline, *this); + event.InsertEvent(std::move(new_event)); +} + +void HashJoinGlobalSinkState::InitializeProbeSpill() { + lock_guard guard(lock); + if (!probe_spill) { + probe_spill = make_uniq(*hash_table, context, probe_types); + } +} + +class HashJoinRepartitionTask : public ExecutorTask { +public: + HashJoinRepartitionTask(shared_ptr event_p, ClientContext &context, JoinHashTable &global_ht, + JoinHashTable &local_ht) + : ExecutorTask(context), event(std::move(event_p)), global_ht(global_ht), local_ht(local_ht) { + } + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + local_ht.Partition(global_ht); + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; + } + +private: + shared_ptr event; + + JoinHashTable &global_ht; + JoinHashTable &local_ht; +}; + +class HashJoinPartitionEvent : public BasePipelineEvent { +public: + HashJoinPartitionEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink, + vector> &local_hts) + : BasePipelineEvent(pipeline_p), sink(sink), local_hts(local_hts) { + } + + HashJoinGlobalSinkState &sink; + vector> &local_hts; + +public: + void Schedule() override { + auto &context = pipeline->GetClientContext(); + vector> partition_tasks; + partition_tasks.reserve(local_hts.size()); + for (auto &local_ht : local_hts) { + partition_tasks.push_back( + make_uniq(shared_from_this(), context, *sink.hash_table, *local_ht)); + } + SetTasks(std::move(partition_tasks)); + } + + void FinishEvent() override { + local_hts.clear(); + sink.hash_table->PrepareExternalFinalize(); + sink.ScheduleFinalize(*pipeline, *this); + } +}; + +SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &sink = input.global_state.Cast(); + auto &ht = *sink.hash_table; + + sink.external = ht.RequiresExternalJoin(context.config, sink.local_hash_tables); + if (sink.external) { + sink.perfect_join_executor.reset(); + if (ht.RequiresPartitioning(context.config, sink.local_hash_tables)) { + auto new_event = make_shared(pipeline, sink, sink.local_hash_tables); + event.InsertEvent(std::move(new_event)); + } else { + for (auto &local_ht : sink.local_hash_tables) { + ht.Merge(*local_ht); + } + sink.local_hash_tables.clear(); + sink.hash_table->PrepareExternalFinalize(); + sink.ScheduleFinalize(pipeline, event); + } + sink.finalized = true; + return SinkFinalizeType::READY; + } else { + for (auto &local_ht : sink.local_hash_tables) { + ht.Merge(*local_ht); + } + sink.local_hash_tables.clear(); + ht.Unpartition(); + } + + // check for possible perfect hash table + auto use_perfect_hash = sink.perfect_join_executor->CanDoPerfectHashJoin(); + if (use_perfect_hash) { + D_ASSERT(ht.equality_types.size() == 1); + auto key_type = ht.equality_types[0]; + use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(key_type); + } + // In case of a large build side or duplicates, use regular hash join + if (!use_perfect_hash) { + sink.perfect_join_executor.reset(); + sink.ScheduleFinalize(pipeline, event); + } + sink.finalized = true; + if (ht.Count() == 0 && EmptyResultIfRHSIsEmpty()) { + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +class HashJoinOperatorState : public CachingOperatorState { +public: + explicit HashJoinOperatorState(ClientContext &context) : probe_executor(context), initialized(false) { + } + + DataChunk join_keys; + TupleDataChunkState join_key_state; + + ExpressionExecutor probe_executor; + unique_ptr scan_structure; + unique_ptr perfect_hash_join_state; + + bool initialized; + JoinHashTable::ProbeSpillLocalAppendState spill_state; + //! Chunk to sink data into for external join + DataChunk spill_chunk; + +public: + void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { + context.thread.profiler.Flush(op, probe_executor, "probe_executor", 0); + } +}; + +unique_ptr PhysicalHashJoin::GetOperatorState(ExecutionContext &context) const { + auto &allocator = BufferAllocator::Get(context.client); + auto &sink = sink_state->Cast(); + auto state = make_uniq(context.client); + if (sink.perfect_join_executor) { + state->perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context); + } else { + state->join_keys.Initialize(allocator, condition_types); + for (auto &cond : conditions) { + state->probe_executor.AddExpression(*cond.left); + } + TupleDataCollection::InitializeChunkState(state->join_key_state, condition_types); + } + if (sink.external) { + state->spill_chunk.Initialize(allocator, sink.probe_types); + sink.InitializeProbeSpill(); + } + + return std::move(state); +} + +OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &state = state_p.Cast(); + auto &sink = sink_state->Cast(); + D_ASSERT(sink.finalized); + D_ASSERT(!sink.scanned_data); + + // some initialization for external hash join + if (sink.external && !state.initialized) { + if (!sink.probe_spill) { + sink.InitializeProbeSpill(); + } + state.spill_state = sink.probe_spill->RegisterThread(); + state.initialized = true; + } + + if (sink.hash_table->Count() == 0 && EmptyResultIfRHSIsEmpty()) { + return OperatorResultType::FINISHED; + } + + if (sink.perfect_join_executor) { + D_ASSERT(!sink.external); + return sink.perfect_join_executor->ProbePerfectHashTable(context, input, chunk, *state.perfect_hash_join_state); + } + + if (state.scan_structure) { + // still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) + state.scan_structure->Next(state.join_keys, input, chunk); + if (chunk.size() > 0) { + return OperatorResultType::HAVE_MORE_OUTPUT; + } + state.scan_structure = nullptr; + return OperatorResultType::NEED_MORE_INPUT; + } + + // probe the HT + if (sink.hash_table->Count() == 0) { + ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, input, chunk); + return OperatorResultType::NEED_MORE_INPUT; + } + + // resolve the join keys for the left chunk + state.join_keys.Reset(); + state.probe_executor.Execute(input, state.join_keys); + + // perform the actual probe + if (sink.external) { + state.scan_structure = sink.hash_table->ProbeAndSpill(state.join_keys, state.join_key_state, input, + *sink.probe_spill, state.spill_state, state.spill_chunk); + } else { + state.scan_structure = sink.hash_table->Probe(state.join_keys, state.join_key_state); + } + state.scan_structure->Next(state.join_keys, input, chunk); + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +enum class HashJoinSourceStage : uint8_t { INIT, BUILD, PROBE, SCAN_HT, DONE }; + +class HashJoinLocalSourceState; + +class HashJoinGlobalSourceState : public GlobalSourceState { +public: + HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context); + + //! Initialize this source state using the info in the sink + void Initialize(HashJoinGlobalSinkState &sink); + //! Try to prepare the next stage + void TryPrepareNextStage(HashJoinGlobalSinkState &sink); + //! Prepare the next build/probe/scan_ht stage for external hash join (must hold lock) + void PrepareBuild(HashJoinGlobalSinkState &sink); + void PrepareProbe(HashJoinGlobalSinkState &sink); + void PrepareScanHT(HashJoinGlobalSinkState &sink); + //! Assigns a task to a local source state + bool AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate); + + idx_t MaxThreads() override { + D_ASSERT(op.sink_state); + auto &gstate = op.sink_state->Cast(); + + idx_t count; + if (gstate.probe_spill) { + count = probe_count; + } else if (IsRightOuterJoin(op.join_type)) { + count = gstate.hash_table->Count(); + } else { + return 0; + } + return count / ((idx_t)STANDARD_VECTOR_SIZE * parallel_scan_chunk_count); + } + +public: + const PhysicalHashJoin &op; + + //! For synchronizing the external hash join + atomic global_stage; + mutex lock; + + //! For HT build synchronization + idx_t build_chunk_idx; + idx_t build_chunk_count; + idx_t build_chunk_done; + idx_t build_chunks_per_thread; + + //! For probe synchronization + idx_t probe_chunk_count; + idx_t probe_chunk_done; + + //! To determine the number of threads + idx_t probe_count; + idx_t parallel_scan_chunk_count; + + //! For full/outer synchronization + idx_t full_outer_chunk_idx; + idx_t full_outer_chunk_count; + idx_t full_outer_chunk_done; + idx_t full_outer_chunks_per_thread; +}; + +class HashJoinLocalSourceState : public LocalSourceState { +public: + HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator); + + //! Do the work this thread has been assigned + void ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); + //! Whether this thread has finished the work it has been assigned + bool TaskFinished(); + //! Build, probe and scan for external hash join + void ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate); + void ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); + void ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); + +public: + //! The stage that this thread was assigned work for + HashJoinSourceStage local_stage; + //! Vector with pointers here so we don't have to re-initialize + Vector addresses; + + //! Chunks assigned to this thread for building the pointer table + idx_t build_chunk_idx_from; + idx_t build_chunk_idx_to; + + //! Local scan state for probe spill + ColumnDataConsumerScanState probe_local_scan; + //! Chunks for holding the scanned probe collection + DataChunk probe_chunk; + DataChunk join_keys; + DataChunk payload; + TupleDataChunkState join_key_state; + //! Column indices to easily reference the join keys/payload columns in probe_chunk + vector join_key_indices; + vector payload_indices; + //! Scan structure for the external probe + unique_ptr scan_structure; + bool empty_ht_probe_in_progress; + + //! Chunks assigned to this thread for a full/outer scan + idx_t full_outer_chunk_idx_from; + idx_t full_outer_chunk_idx_to; + unique_ptr full_outer_scan_state; +}; + +unique_ptr PhysicalHashJoin::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this, context); +} + +unique_ptr PhysicalHashJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(*this, BufferAllocator::Get(context.client)); +} + +HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context) + : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), + probe_chunk_done(0), probe_count(op.children[0]->estimated_cardinality), + parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120) { +} + +void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { + lock_guard init_lock(lock); + if (global_stage != HashJoinSourceStage::INIT) { + // Another thread initialized + return; + } + + // Finalize the probe spill + if (sink.probe_spill) { + sink.probe_spill->Finalize(); + } + + global_stage = HashJoinSourceStage::PROBE; + TryPrepareNextStage(sink); +} + +void HashJoinGlobalSourceState::TryPrepareNextStage(HashJoinGlobalSinkState &sink) { + switch (global_stage.load()) { + case HashJoinSourceStage::BUILD: + if (build_chunk_done == build_chunk_count) { + sink.hash_table->GetDataCollection().VerifyEverythingPinned(); + sink.hash_table->finalized = true; + PrepareProbe(sink); + } + break; + case HashJoinSourceStage::PROBE: + if (probe_chunk_done == probe_chunk_count) { + if (IsRightOuterJoin(op.join_type)) { + PrepareScanHT(sink); + } else { + PrepareBuild(sink); + } + } + break; + case HashJoinSourceStage::SCAN_HT: + if (full_outer_chunk_done == full_outer_chunk_count) { + PrepareBuild(sink); + } + break; + default: + break; + } +} + +void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) { + D_ASSERT(global_stage != HashJoinSourceStage::BUILD); + auto &ht = *sink.hash_table; + + // Try to put the next partitions in the block collection of the HT + if (!sink.external || !ht.PrepareExternalFinalize()) { + global_stage = HashJoinSourceStage::DONE; + return; + } + + auto &data_collection = ht.GetDataCollection(); + if (data_collection.Count() == 0 && op.EmptyResultIfRHSIsEmpty()) { + PrepareBuild(sink); + return; + } + + build_chunk_idx = 0; + build_chunk_count = data_collection.ChunkCount(); + build_chunk_done = 0; + + auto num_threads = TaskScheduler::GetScheduler(sink.context).NumberOfThreads(); + build_chunks_per_thread = MaxValue((build_chunk_count + num_threads - 1) / num_threads, 1); + + ht.InitializePointerTable(); + + global_stage = HashJoinSourceStage::BUILD; +} + +void HashJoinGlobalSourceState::PrepareProbe(HashJoinGlobalSinkState &sink) { + sink.probe_spill->PrepareNextProbe(); + const auto &consumer = *sink.probe_spill->consumer; + + probe_chunk_count = consumer.Count() == 0 ? 0 : consumer.ChunkCount(); + probe_chunk_done = 0; + + global_stage = HashJoinSourceStage::PROBE; + if (probe_chunk_count == 0) { + TryPrepareNextStage(sink); + return; + } +} + +void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) { + D_ASSERT(global_stage != HashJoinSourceStage::SCAN_HT); + auto &ht = *sink.hash_table; + + auto &data_collection = ht.GetDataCollection(); + full_outer_chunk_idx = 0; + full_outer_chunk_count = data_collection.ChunkCount(); + full_outer_chunk_done = 0; + + auto num_threads = TaskScheduler::GetScheduler(sink.context).NumberOfThreads(); + full_outer_chunks_per_thread = MaxValue((full_outer_chunk_count + num_threads - 1) / num_threads, 1); + + global_stage = HashJoinSourceStage::SCAN_HT; +} + +bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate) { + D_ASSERT(lstate.TaskFinished()); + + lock_guard guard(lock); + switch (global_stage.load()) { + case HashJoinSourceStage::BUILD: + if (build_chunk_idx != build_chunk_count) { + lstate.local_stage = global_stage; + lstate.build_chunk_idx_from = build_chunk_idx; + build_chunk_idx = MinValue(build_chunk_count, build_chunk_idx + build_chunks_per_thread); + lstate.build_chunk_idx_to = build_chunk_idx; + return true; + } + break; + case HashJoinSourceStage::PROBE: + if (sink.probe_spill->consumer && sink.probe_spill->consumer->AssignChunk(lstate.probe_local_scan)) { + lstate.local_stage = global_stage; + lstate.empty_ht_probe_in_progress = false; + return true; + } + break; + case HashJoinSourceStage::SCAN_HT: + if (full_outer_chunk_idx != full_outer_chunk_count) { + lstate.local_stage = global_stage; + lstate.full_outer_chunk_idx_from = full_outer_chunk_idx; + full_outer_chunk_idx = + MinValue(full_outer_chunk_count, full_outer_chunk_idx + full_outer_chunks_per_thread); + lstate.full_outer_chunk_idx_to = full_outer_chunk_idx; + return true; + } + break; + case HashJoinSourceStage::DONE: + break; + default: + throw InternalException("Unexpected HashJoinSourceStage in AssignTask!"); + } + return false; +} + +HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator) + : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER) { + auto &chunk_state = probe_local_scan.current_chunk_state; + chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; + + auto &sink = op.sink_state->Cast(); + probe_chunk.Initialize(allocator, sink.probe_types); + join_keys.Initialize(allocator, op.condition_types); + payload.Initialize(allocator, op.children[0]->types); + TupleDataCollection::InitializeChunkState(join_key_state, op.condition_types); + + // Store the indices of the columns to reference them easily + idx_t col_idx = 0; + for (; col_idx < op.condition_types.size(); col_idx++) { + join_key_indices.push_back(col_idx); + } + for (; col_idx < sink.probe_types.size() - 1; col_idx++) { + payload_indices.push_back(col_idx); + } +} + +void HashJoinLocalSourceState::ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, + DataChunk &chunk) { + switch (local_stage) { + case HashJoinSourceStage::BUILD: + ExternalBuild(sink, gstate); + break; + case HashJoinSourceStage::PROBE: + ExternalProbe(sink, gstate, chunk); + break; + case HashJoinSourceStage::SCAN_HT: + ExternalScanHT(sink, gstate, chunk); + break; + default: + throw InternalException("Unexpected HashJoinSourceStage in ExecuteTask!"); + } +} + +bool HashJoinLocalSourceState::TaskFinished() { + switch (local_stage) { + case HashJoinSourceStage::INIT: + case HashJoinSourceStage::BUILD: + return true; + case HashJoinSourceStage::PROBE: + return scan_structure == nullptr && !empty_ht_probe_in_progress; + case HashJoinSourceStage::SCAN_HT: + return full_outer_scan_state == nullptr; + default: + throw InternalException("Unexpected HashJoinSourceStage in TaskFinished!"); + } +} + +void HashJoinLocalSourceState::ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate) { + D_ASSERT(local_stage == HashJoinSourceStage::BUILD); + + auto &ht = *sink.hash_table; + ht.Finalize(build_chunk_idx_from, build_chunk_idx_to, true); + + lock_guard guard(gstate.lock); + gstate.build_chunk_done += build_chunk_idx_to - build_chunk_idx_from; +} + +void HashJoinLocalSourceState::ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, + DataChunk &chunk) { + D_ASSERT(local_stage == HashJoinSourceStage::PROBE && sink.hash_table->finalized); + + if (scan_structure) { + // Still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) + scan_structure->Next(join_keys, payload, chunk); + if (chunk.size() != 0) { + return; + } + } + + if (scan_structure || empty_ht_probe_in_progress) { + // Previous probe is done + scan_structure = nullptr; + empty_ht_probe_in_progress = false; + sink.probe_spill->consumer->FinishChunk(probe_local_scan); + lock_guard lock(gstate.lock); + gstate.probe_chunk_done++; + return; + } + + // Scan input chunk for next probe + sink.probe_spill->consumer->ScanChunk(probe_local_scan, probe_chunk); + + // Get the probe chunk columns/hashes + join_keys.ReferenceColumns(probe_chunk, join_key_indices); + payload.ReferenceColumns(probe_chunk, payload_indices); + auto precomputed_hashes = &probe_chunk.data.back(); + + if (sink.hash_table->Count() == 0 && !gstate.op.EmptyResultIfRHSIsEmpty()) { + gstate.op.ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, payload, chunk); + empty_ht_probe_in_progress = true; + return; + } + + // Perform the probe + scan_structure = sink.hash_table->Probe(join_keys, join_key_state, precomputed_hashes); + scan_structure->Next(join_keys, payload, chunk); +} + +void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, + DataChunk &chunk) { + D_ASSERT(local_stage == HashJoinSourceStage::SCAN_HT); + + if (!full_outer_scan_state) { + full_outer_scan_state = make_uniq(sink.hash_table->GetDataCollection(), + full_outer_chunk_idx_from, full_outer_chunk_idx_to); + } + sink.hash_table->ScanFullOuter(*full_outer_scan_state, addresses, chunk); + + if (chunk.size() == 0) { + full_outer_scan_state = nullptr; + lock_guard guard(gstate.lock); + gstate.full_outer_chunk_done += full_outer_chunk_idx_to - full_outer_chunk_idx_from; + } +} + +SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &sink = sink_state->Cast(); + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + sink.scanned_data = true; + + if (!sink.external && !IsRightOuterJoin(join_type)) { + return SourceResultType::FINISHED; + } + + if (gstate.global_stage == HashJoinSourceStage::INIT) { + gstate.Initialize(sink); + } + + // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done + // Therefore, we loop until we've produced tuples, or until the operator is actually done + while (gstate.global_stage != HashJoinSourceStage::DONE && chunk.size() == 0) { + if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) { + lstate.ExecuteTask(sink, gstate, chunk); + } else { + lock_guard guard(gstate.lock); + gstate.TryPrepareNextStage(sink); + } + } + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp new file mode 100644 index 00000000..dfb61bda --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp @@ -0,0 +1,1050 @@ +#include "duckdb/execution/operator/join/physical_iejoin.hpp" + +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sort/sorted_block.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +#include + +namespace duckdb { + +PhysicalIEJoin::PhysicalIEJoin(LogicalComparisonJoin &op, unique_ptr left, + unique_ptr right, vector cond, JoinType join_type, + idx_t estimated_cardinality) + : PhysicalRangeJoin(op, PhysicalOperatorType::IE_JOIN, std::move(left), std::move(right), std::move(cond), + join_type, estimated_cardinality) { + + // 1. let L1 (resp. L2) be the array of column X (resp. Y) + D_ASSERT(conditions.size() >= 2); + lhs_orders.resize(2); + rhs_orders.resize(2); + for (idx_t i = 0; i < 2; ++i) { + auto &cond = conditions[i]; + D_ASSERT(cond.left->return_type == cond.right->return_type); + join_key_types.push_back(cond.left->return_type); + + // Convert the conditions to sort orders + auto left = cond.left->Copy(); + auto right = cond.right->Copy(); + auto sense = OrderType::INVALID; + + // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order + // 3. else if (op1 ∈ {<, ≤}) sort L1 in ascending order + // 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order + // 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order + switch (cond.comparison) { + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + sense = i ? OrderType::ASCENDING : OrderType::DESCENDING; + break; + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + sense = i ? OrderType::DESCENDING : OrderType::ASCENDING; + break; + default: + throw NotImplementedException("Unimplemented join type for IEJoin"); + } + lhs_orders[i].emplace_back(BoundOrderByNode(sense, OrderByNullType::NULLS_LAST, std::move(left))); + rhs_orders[i].emplace_back(BoundOrderByNode(sense, OrderByNullType::NULLS_LAST, std::move(right))); + } + + for (idx_t i = 2; i < conditions.size(); ++i) { + auto &cond = conditions[i]; + D_ASSERT(cond.left->return_type == cond.right->return_type); + join_key_types.push_back(cond.left->return_type); + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class IEJoinLocalState : public LocalSinkState { +public: + using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + + IEJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child) + : table(context, op, child) { + } + + //! The local sort state + LocalSortedTable table; +}; + +class IEJoinGlobalState : public GlobalSinkState { +public: + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; + +public: + IEJoinGlobalState(ClientContext &context, const PhysicalIEJoin &op) : child(0) { + tables.resize(2); + RowLayout lhs_layout; + lhs_layout.Initialize(op.children[0]->types); + vector lhs_order; + lhs_order.emplace_back(op.lhs_orders[0][0].Copy()); + tables[0] = make_uniq(context, lhs_order, lhs_layout); + + RowLayout rhs_layout; + rhs_layout.Initialize(op.children[1]->types); + vector rhs_order; + rhs_order.emplace_back(op.rhs_orders[0][0].Copy()); + tables[1] = make_uniq(context, rhs_order, rhs_layout); + } + + IEJoinGlobalState(IEJoinGlobalState &prev) + : GlobalSinkState(prev), tables(std::move(prev.tables)), child(prev.child + 1) { + } + + void Sink(DataChunk &input, IEJoinLocalState &lstate) { + auto &table = *tables[child]; + auto &global_sort_state = table.global_sort_state; + auto &local_sort_state = lstate.table.local_sort_state; + + // Sink the data into the local sort state + lstate.table.Sink(input, global_sort_state); + + // When sorting data reaches a certain size, we sort it + if (local_sort_state.SizeInBytes() >= table.memory_per_thread) { + local_sort_state.Sort(global_sort_state, true); + } + } + + vector> tables; + size_t child; +}; + +unique_ptr PhysicalIEJoin::GetGlobalSinkState(ClientContext &context) const { + D_ASSERT(!sink_state); + return make_uniq(context, *this); +} + +unique_ptr PhysicalIEJoin::GetLocalSinkState(ExecutionContext &context) const { + idx_t sink_child = 0; + if (sink_state) { + const auto &ie_sink = sink_state->Cast(); + sink_child = ie_sink.child; + } + return make_uniq(context.client, *this, sink_child); +} + +SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + gstate.Sink(chunk, lstate); + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + gstate.tables[gstate.child]->Combine(lstate.table); + auto &client_profiler = QueryProfiler::Get(context.client); + + context.thread.profiler.Flush(*this, lstate.table.executor, gstate.child ? "rhs_executor" : "lhs_executor", 1); + client_profiler.Flush(context.thread.profiler); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &table = *gstate.tables[gstate.child]; + auto &global_sort_state = table.global_sort_state; + + if ((gstate.child == 1 && IsRightOuterJoin(join_type)) || (gstate.child == 0 && IsLeftOuterJoin(join_type))) { + // for FULL/LEFT/RIGHT OUTER JOIN, initialize found_match to false for every tuple + table.IntializeMatches(); + } + if (gstate.child == 1 && global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + // Empty input! + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // Sort the current input child + table.Finalize(pipeline, event); + + // Move to the next input child + ++gstate.child; + + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +OperatorResultType PhysicalIEJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const { + return OperatorResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +struct IEJoinUnion { + using SortedTable = PhysicalRangeJoin::GlobalSortedTable; + + static idx_t AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, + int64_t base, const idx_t block_idx); + + static void Sort(SortedTable &table) { + auto &global_sort_state = table.global_sort_state; + global_sort_state.PrepareMergePhase(); + while (global_sort_state.sorted_blocks.size() > 1) { + global_sort_state.InitializeMergeRound(); + MergeSorter merge_sorter(global_sort_state, global_sort_state.buffer_manager); + merge_sorter.PerformInMergeRound(); + global_sort_state.CompleteMergeRound(true); + } + } + + template + static vector ExtractColumn(SortedTable &table, idx_t col_idx) { + vector result; + result.reserve(table.count); + + auto &gstate = table.global_sort_state; + auto &blocks = *gstate.sorted_blocks[0]->payload_data; + PayloadScanner scanner(blocks, gstate, false); + + DataChunk payload; + payload.Initialize(Allocator::DefaultAllocator(), gstate.payload_layout.GetTypes()); + for (;;) { + scanner.Scan(payload); + const auto count = payload.size(); + if (!count) { + break; + } + + const auto data_ptr = FlatVector::GetData(payload.data[col_idx]); + result.insert(result.end(), data_ptr, data_ptr + count); + } + + return result; + } + + IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, SortedTable &t2, + const idx_t b2); + + idx_t SearchL1(idx_t pos); + bool NextRow(); + + //! Inverted loop + idx_t JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel); + + //! L1 + unique_ptr l1; + //! L2 + unique_ptr l2; + + //! Li + vector li; + //! P + vector p; + + //! B + vector bit_array; + ValidityMask bit_mask; + + //! Bloom Filter + static constexpr idx_t BLOOM_CHUNK_BITS = 1024; + idx_t bloom_count; + vector bloom_array; + ValidityMask bloom_filter; + + //! Iteration state + idx_t n; + idx_t i; + idx_t j; + unique_ptr op1; + unique_ptr off1; + unique_ptr op2; + unique_ptr off2; + int64_t lrid; +}; + +idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, + int64_t base, const idx_t block_idx) { + LocalSortState local_sort_state; + local_sort_state.Initialize(marked.global_sort_state, marked.global_sort_state.buffer_manager); + + // Reading + const auto valid = table.count - table.has_null; + auto &gstate = table.global_sort_state; + PayloadScanner scanner(gstate, block_idx); + auto table_idx = block_idx * gstate.block_capacity; + + DataChunk scanned; + scanned.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); + + // Writing + auto types = local_sort_state.sort_layout->logical_types; + const idx_t payload_idx = types.size(); + + const auto &payload_types = local_sort_state.payload_layout->GetTypes(); + types.insert(types.end(), payload_types.begin(), payload_types.end()); + const idx_t rid_idx = types.size() - 1; + + DataChunk keys; + DataChunk payload; + keys.Initialize(Allocator::DefaultAllocator(), types); + + idx_t inserted = 0; + for (auto rid = base; table_idx < valid;) { + scanner.Scan(scanned); + + // NULLs are at the end, so stop when we reach them + auto scan_count = scanned.size(); + if (table_idx + scan_count > valid) { + scan_count = valid - table_idx; + scanned.SetCardinality(scan_count); + } + if (scan_count == 0) { + break; + } + table_idx += scan_count; + + // Compute the input columns from the payload + keys.Reset(); + keys.Split(payload, rid_idx); + executor.Execute(scanned, keys); + + // Mark the rid column + payload.data[0].Sequence(rid, increment, scan_count); + payload.SetCardinality(scan_count); + keys.Fuse(payload); + rid += increment * scan_count; + + // Sort on the sort columns (which will no longer be needed) + keys.Split(payload, payload_idx); + local_sort_state.SinkChunk(keys, payload); + inserted += scan_count; + keys.Fuse(payload); + + // Flush when we have enough data + if (local_sort_state.SizeInBytes() >= marked.memory_per_thread) { + local_sort_state.Sort(marked.global_sort_state, true); + } + } + marked.global_sort_state.AddLocalState(local_sort_state); + marked.count += inserted; + + return inserted; +} + +IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, + SortedTable &t2, const idx_t b2) + : n(0), i(0) { + // input : query Q with 2 join predicates t1.X op1 t2.X' and t1.Y op2 t2.Y', tables T, T' of sizes m and n resp. + // output: a list of tuple pairs (ti , tj) + // Note that T/T' are already sorted on X/X' and contain the payload data + // We only join the two block numbers and use the sizes of the blocks as the counts + + // 0. Filter out tables with no overlap + if (!t1.BlockSize(b1) || !t2.BlockSize(b2)) { + return; + } + + const auto &cmp1 = op.conditions[0].comparison; + SBIterator bounds1(t1.global_sort_state, cmp1); + SBIterator bounds2(t2.global_sort_state, cmp1); + + // t1.X[0] op1 t2.X'[-1] + bounds1.SetIndex(bounds1.block_capacity * b1); + bounds2.SetIndex(bounds2.block_capacity * b2 + t2.BlockSize(b2) - 1); + if (!bounds1.Compare(bounds2)) { + return; + } + + // 1. let L1 (resp. L2) be the array of column X (resp. Y ) + const auto &order1 = op.lhs_orders[0][0]; + const auto &order2 = op.lhs_orders[1][0]; + + // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order + // 3. else if (op1 ∈ {<, ≤}) sort L1 in ascending order + + // For the union algorithm, we make a unified table with the keys and the rids as the payload: + // X/X', Y/Y', R/R'/Li + // The first position is the sort key. + vector types; + types.emplace_back(order2.expression->return_type); + types.emplace_back(LogicalType::BIGINT); + RowLayout payload_layout; + payload_layout.Initialize(types); + + // Sort on the first expression + auto ref = make_uniq(order1.expression->return_type, 0); + vector orders; + orders.emplace_back(order1.type, order1.null_order, std::move(ref)); + + l1 = make_uniq(context, orders, payload_layout); + + // LHS has positive rids + ExpressionExecutor l_executor(context); + l_executor.AddExpression(*order1.expression); + l_executor.AddExpression(*order2.expression); + AppendKey(t1, l_executor, *l1, 1, 1, b1); + + // RHS has negative rids + ExpressionExecutor r_executor(context); + r_executor.AddExpression(*op.rhs_orders[0][0].expression); + r_executor.AddExpression(*op.rhs_orders[1][0].expression); + AppendKey(t2, r_executor, *l1, -1, -1, b2); + + if (l1->global_sort_state.sorted_blocks.empty()) { + return; + } + + Sort(*l1); + + op1 = make_uniq(l1->global_sort_state, cmp1); + off1 = make_uniq(l1->global_sort_state, cmp1); + + // We don't actually need the L1 column, just its sort key, which is in the sort blocks + li = ExtractColumn(*l1, types.size() - 1); + + // 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order + // 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order + + // We sort on Y/Y' to obtain the sort keys and the permutation array. + // For this we just need a two-column table of Y, P + types.clear(); + types.emplace_back(LogicalType::BIGINT); + payload_layout.Initialize(types); + + // Sort on the first expression + orders.clear(); + ref = make_uniq(order2.expression->return_type, 0); + orders.emplace_back(order2.type, order2.null_order, std::move(ref)); + + ExpressionExecutor executor(context); + executor.AddExpression(*orders[0].expression); + + l2 = make_uniq(context, orders, payload_layout); + for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { + base += AppendKey(*l1, executor, *l2, 1, base, block_idx); + } + + Sort(*l2); + + // We don't actually need the L2 column, just its sort key, which is in the sort blocks + + // 6. compute the permutation array P of L2 w.r.t. L1 + p = ExtractColumn(*l2, types.size() - 1); + + // 7. initialize bit-array B (|B| = n), and set all bits to 0 + n = l2->count.load(); + bit_array.resize(ValidityMask::EntryCount(n), 0); + bit_mask.Initialize(bit_array.data()); + + // Bloom filter + bloom_count = (n + (BLOOM_CHUNK_BITS - 1)) / BLOOM_CHUNK_BITS; + bloom_array.resize(ValidityMask::EntryCount(bloom_count), 0); + bloom_filter.Initialize(bloom_array.data()); + + // 11. for(i←1 to n) do + const auto &cmp2 = op.conditions[1].comparison; + op2 = make_uniq(l2->global_sort_state, cmp2); + off2 = make_uniq(l2->global_sort_state, cmp2); + i = 0; + j = 0; + (void)NextRow(); +} + +idx_t IEJoinUnion::SearchL1(idx_t pos) { + // Perform an exponential search in the appropriate direction + op1->SetIndex(pos); + + idx_t step = 1; + auto hi = pos; + auto lo = pos; + if (!op1->cmp) { + // Scan left for loose inequality + lo -= MinValue(step, lo); + step *= 2; + off1->SetIndex(lo); + while (lo > 0 && op1->Compare(*off1)) { + hi = lo; + lo -= MinValue(step, lo); + step *= 2; + off1->SetIndex(lo); + } + } else { + // Scan right for strict inequality + hi += MinValue(step, n - hi); + step *= 2; + off1->SetIndex(hi); + while (hi < n && !op1->Compare(*off1)) { + lo = hi; + hi += MinValue(step, n - hi); + step *= 2; + off1->SetIndex(hi); + } + } + + // Binary search the target area + while (lo < hi) { + const auto mid = lo + (hi - lo) / 2; + off1->SetIndex(mid); + if (op1->Compare(*off1)) { + hi = mid; + } else { + lo = mid + 1; + } + } + + off1->SetIndex(lo); + + return lo; +} + +bool IEJoinUnion::NextRow() { + for (; i < n; ++i) { + // 12. pos ← P[i] + auto pos = p[i]; + lrid = li[pos]; + if (lrid < 0) { + continue; + } + + // 16. B[pos] ← 1 + op2->SetIndex(i); + for (; off2->GetIndex() < n; ++(*off2)) { + if (!off2->Compare(*op2)) { + break; + } + const auto p2 = p[off2->GetIndex()]; + if (li[p2] < 0) { + // Only mark rhs matches. + bit_mask.SetValid(p2); + bloom_filter.SetValid(p2 / BLOOM_CHUNK_BITS); + } + } + + // 9. if (op1 ∈ {≤,≥} and op2 ∈ {≤,≥}) eqOff = 0 + // 10. else eqOff = 1 + // No, because there could be more than one equal value. + // Find the leftmost off1 where L1[pos] op1 L1[off1..n] + // These are the rows that satisfy the op1 condition + // and that is where we should start scanning B from + j = SearchL1(pos); + + return true; + } + return false; +} + +static idx_t NextValid(const ValidityMask &bits, idx_t j, const idx_t n) { + if (j >= n) { + return n; + } + + // We can do a first approximation by checking entries one at a time + // which gives 64:1. + idx_t entry_idx, idx_in_entry; + bits.GetEntryIndex(j, entry_idx, idx_in_entry); + auto entry = bits.GetValidityEntry(entry_idx++); + + // Trim the bits before the start position + entry &= (ValidityMask::ValidityBuffer::MAX_ENTRY << idx_in_entry); + + // Check the non-ragged entries + for (const auto entry_count = bits.EntryCount(n); entry_idx < entry_count; ++entry_idx) { + if (entry) { + for (; idx_in_entry < bits.BITS_PER_VALUE; ++idx_in_entry, ++j) { + if (bits.RowIsValid(entry, idx_in_entry)) { + return j; + } + } + } else { + j += bits.BITS_PER_VALUE - idx_in_entry; + } + + entry = bits.GetValidityEntry(entry_idx); + idx_in_entry = 0; + } + + // Check the final entry + for (; j < n; ++idx_in_entry, ++j) { + if (bits.RowIsValid(entry, idx_in_entry)) { + return j; + } + } + + return j; +} + +idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel) { + // 8. initialize join result as an empty list for tuple pairs + idx_t result_count = 0; + + // 11. for(i←1 to n) do + while (i < n) { + // 13. for (j ← pos+eqOff to n) do + for (;;) { + // 14. if B[j] = 1 then + + // Use the Bloom filter to find candidate blocks + while (j < n) { + auto bloom_begin = NextValid(bloom_filter, j / BLOOM_CHUNK_BITS, bloom_count) * BLOOM_CHUNK_BITS; + auto bloom_end = MinValue(n, bloom_begin + BLOOM_CHUNK_BITS); + + j = MaxValue(j, bloom_begin); + j = NextValid(bit_mask, j, bloom_end); + if (j < bloom_end) { + break; + } + } + + if (j >= n) { + break; + } + + // Filter out tuples with the same sign (they come from the same table) + const auto rrid = li[j]; + ++j; + + // 15. add tuples w.r.t. (L1[j], L1[i]) to join result + if (lrid > 0 && rrid < 0) { + lsel.set_index(result_count, sel_t(+lrid - 1)); + rsel.set_index(result_count, sel_t(-rrid - 1)); + ++result_count; + if (result_count == STANDARD_VECTOR_SIZE) { + // out of space! + return result_count; + } + } + } + ++i; + + if (!NextRow()) { + break; + } + } + + return result_count; +} + +class IEJoinLocalSourceState : public LocalSourceState { +public: + explicit IEJoinLocalSourceState(ClientContext &context, const PhysicalIEJoin &op) + : op(op), true_sel(STANDARD_VECTOR_SIZE), left_executor(context), right_executor(context), + left_matches(nullptr), right_matches(nullptr) { + auto &allocator = Allocator::Get(context); + unprojected.Initialize(allocator, op.unprojected_types); + + if (op.conditions.size() < 3) { + return; + } + + vector left_types; + vector right_types; + for (idx_t i = 2; i < op.conditions.size(); ++i) { + const auto &cond = op.conditions[i]; + + left_types.push_back(cond.left->return_type); + left_executor.AddExpression(*cond.left); + + right_types.push_back(cond.left->return_type); + right_executor.AddExpression(*cond.right); + } + + left_keys.Initialize(allocator, left_types); + right_keys.Initialize(allocator, right_types); + } + + idx_t SelectOuterRows(bool *matches) { + idx_t count = 0; + for (; outer_idx < outer_count; ++outer_idx) { + if (!matches[outer_idx]) { + true_sel.set_index(count++, outer_idx); + if (count >= STANDARD_VECTOR_SIZE) { + outer_idx++; + break; + } + } + } + + return count; + } + + const PhysicalIEJoin &op; + + // Joining + unique_ptr joiner; + + idx_t left_base; + idx_t left_block_index; + + idx_t right_base; + idx_t right_block_index; + + // Trailing predicates + SelectionVector true_sel; + + ExpressionExecutor left_executor; + DataChunk left_keys; + + ExpressionExecutor right_executor; + DataChunk right_keys; + + DataChunk unprojected; + + // Outer joins + idx_t outer_idx; + idx_t outer_count; + bool *left_matches; + bool *right_matches; +}; + +void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state_p) const { + auto &state = state_p.Cast(); + auto &ie_sink = sink_state->Cast(); + auto &left_table = *ie_sink.tables[0]; + auto &right_table = *ie_sink.tables[1]; + + const auto left_cols = children[0]->GetTypes().size(); + auto &chunk = state.unprojected; + do { + SelectionVector lsel(STANDARD_VECTOR_SIZE); + SelectionVector rsel(STANDARD_VECTOR_SIZE); + auto result_count = state.joiner->JoinComplexBlocks(lsel, rsel); + if (result_count == 0) { + // exhausted this pair + return; + } + + // found matches: extract them + + chunk.Reset(); + SliceSortedPayload(chunk, left_table.global_sort_state, state.left_block_index, lsel, result_count, 0); + SliceSortedPayload(chunk, right_table.global_sort_state, state.right_block_index, rsel, result_count, + left_cols); + chunk.SetCardinality(result_count); + + auto sel = FlatVector::IncrementalSelectionVector(); + if (conditions.size() > 2) { + // If there are more expressions to compute, + // split the result chunk into the left and right halves + // so we can compute the values for comparison. + const auto tail_cols = conditions.size() - 2; + + DataChunk right_chunk; + chunk.Split(right_chunk, left_cols); + state.left_executor.SetChunk(chunk); + state.right_executor.SetChunk(right_chunk); + + auto tail_count = result_count; + auto true_sel = &state.true_sel; + for (size_t cmp_idx = 0; cmp_idx < tail_cols; ++cmp_idx) { + auto &left = state.left_keys.data[cmp_idx]; + state.left_executor.ExecuteExpression(cmp_idx, left); + + auto &right = state.right_keys.data[cmp_idx]; + state.right_executor.ExecuteExpression(cmp_idx, right); + + if (tail_count < result_count) { + left.Slice(*sel, tail_count); + right.Slice(*sel, tail_count); + } + tail_count = SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, true_sel); + sel = true_sel; + } + chunk.Fuse(right_chunk); + + if (tail_count < result_count) { + result_count = tail_count; + chunk.Slice(*sel, result_count); + } + } + + // We need all of the data to compute other predicates, + // but we only return what is in the projection map + ProjectResult(chunk, result); + + // found matches: mark the found matches if required + if (left_table.found_match) { + for (idx_t i = 0; i < result_count; i++) { + left_table.found_match[state.left_base + lsel[sel->get_index(i)]] = true; + } + } + if (right_table.found_match) { + for (idx_t i = 0; i < result_count; i++) { + right_table.found_match[state.right_base + rsel[sel->get_index(i)]] = true; + } + } + result.Verify(); + } while (result.size() == 0); +} + +class IEJoinGlobalSourceState : public GlobalSourceState { +public: + explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op) + : op(op), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), right_outers(0), + next_right(0) { + } + + void Initialize(IEJoinGlobalState &sink_state) { + lock_guard initializing(lock); + if (initialized) { + return; + } + + // Compute the starting row for reach block + // (In theory these are all the same size, but you never know...) + auto &left_table = *sink_state.tables[0]; + const auto left_blocks = left_table.BlockCount(); + idx_t left_base = 0; + + for (size_t lhs = 0; lhs < left_blocks; ++lhs) { + left_bases.emplace_back(left_base); + left_base += left_table.BlockSize(lhs); + } + + auto &right_table = *sink_state.tables[1]; + const auto right_blocks = right_table.BlockCount(); + idx_t right_base = 0; + for (size_t rhs = 0; rhs < right_blocks; ++rhs) { + right_bases.emplace_back(right_base); + right_base += right_table.BlockSize(rhs); + } + + // Outer join block counts + if (left_table.found_match) { + left_outers = left_blocks; + } + + if (right_table.found_match) { + right_outers = right_blocks; + } + + // Ready for action + initialized = true; + } + +public: + idx_t MaxThreads() override { + // We can't leverage any more threads than block pairs. + const auto &sink_state = (op.sink_state->Cast()); + return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); + } + + void GetNextPair(ClientContext &client, IEJoinGlobalState &gstate, IEJoinLocalSourceState &lstate) { + auto &left_table = *gstate.tables[0]; + auto &right_table = *gstate.tables[1]; + + const auto left_blocks = left_table.BlockCount(); + const auto right_blocks = right_table.BlockCount(); + const auto pair_count = left_blocks * right_blocks; + + // Regular block + const auto i = next_pair++; + if (i < pair_count) { + const auto b1 = i / right_blocks; + const auto b2 = i % right_blocks; + + lstate.left_block_index = b1; + lstate.left_base = left_bases[b1]; + + lstate.right_block_index = b2; + lstate.right_base = right_bases[b2]; + + lstate.joiner = make_uniq(client, op, left_table, b1, right_table, b2); + return; + } + + // Outer joins + if (!left_outers && !right_outers) { + return; + } + + // Spin wait for regular blocks to finish(!) + while (completed < pair_count) { + std::this_thread::yield(); + } + + // Left outer blocks + const auto l = next_left++; + if (l < left_outers) { + lstate.joiner = nullptr; + lstate.left_block_index = l; + lstate.left_base = left_bases[l]; + + lstate.left_matches = left_table.found_match.get() + lstate.left_base; + lstate.outer_idx = 0; + lstate.outer_count = left_table.BlockSize(l); + return; + } else { + lstate.left_matches = nullptr; + } + + // Right outer block + const auto r = next_right++; + if (r < right_outers) { + lstate.joiner = nullptr; + lstate.right_block_index = r; + lstate.right_base = right_bases[r]; + + lstate.right_matches = right_table.found_match.get() + lstate.right_base; + lstate.outer_idx = 0; + lstate.outer_count = right_table.BlockSize(r); + return; + } else { + lstate.right_matches = nullptr; + } + } + + void PairCompleted(ClientContext &client, IEJoinGlobalState &gstate, IEJoinLocalSourceState &lstate) { + lstate.joiner.reset(); + ++completed; + GetNextPair(client, gstate, lstate); + } + + const PhysicalIEJoin &op; + + mutex lock; + bool initialized; + + // Join queue state + std::atomic next_pair; + std::atomic completed; + + // Block base row number + vector left_bases; + vector right_bases; + + // Outer joins + idx_t left_outers; + std::atomic next_left; + + idx_t right_outers; + std::atomic next_right; +}; + +unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +unique_ptr PhysicalIEJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(context.client, *this); +} + +SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &result, + OperatorSourceInput &input) const { + auto &ie_sink = sink_state->Cast(); + auto &ie_gstate = input.global_state.Cast(); + auto &ie_lstate = input.local_state.Cast(); + + ie_gstate.Initialize(ie_sink); + + if (!ie_lstate.joiner && !ie_lstate.left_matches && !ie_lstate.right_matches) { + ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); + } + + // Process INNER results + while (ie_lstate.joiner) { + ResolveComplexJoin(context, result, ie_lstate); + + if (result.size()) { + return SourceResultType::HAVE_MORE_OUTPUT; + } + + ie_gstate.PairCompleted(context.client, ie_sink, ie_lstate); + } + + // Process LEFT OUTER results + const auto left_cols = children[0]->GetTypes().size(); + while (ie_lstate.left_matches) { + const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); + if (!count) { + ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); + continue; + } + auto &chunk = ie_lstate.unprojected; + chunk.Reset(); + SliceSortedPayload(chunk, ie_sink.tables[0]->global_sort_state, ie_lstate.left_block_index, ie_lstate.true_sel, + count); + + // Fill in NULLs to the right + for (auto col_idx = left_cols; col_idx < chunk.ColumnCount(); ++col_idx) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } + + ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); + + return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + } + + // Process RIGHT OUTER results + while (ie_lstate.right_matches) { + const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); + if (!count) { + ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); + continue; + } + + auto &chunk = ie_lstate.unprojected; + chunk.Reset(); + SliceSortedPayload(chunk, ie_sink.tables[1]->global_sort_state, ie_lstate.right_block_index, ie_lstate.true_sel, + count, left_cols); + + // Fill in NULLs to the left + for (idx_t col_idx = 0; col_idx < left_cols; ++col_idx) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } + + ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); + + break; + } + + return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalIEJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + D_ASSERT(children.size() == 2); + if (meta_pipeline.HasRecursiveCTE()) { + throw NotImplementedException("IEJoins are not supported in recursive CTEs yet"); + } + + // becomes a source after both children fully sink their data + meta_pipeline.GetState().SetPipelineSource(current, *this); + + // Create one child meta pipeline that will hold the LHS and RHS pipelines + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + + // Build out LHS + auto lhs_pipeline = child_meta_pipeline.GetBasePipeline(); + children[0]->BuildPipelines(*lhs_pipeline, child_meta_pipeline); + + // Build out RHS + auto rhs_pipeline = child_meta_pipeline.CreatePipeline(); + children[1]->BuildPipelines(*rhs_pipeline, child_meta_pipeline); + + // Despite having the same sink, RHS and everything created after it need their own (same) PipelineFinishEvent + child_meta_pipeline.AddFinishEvent(rhs_pipeline); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_index_join.cpp b/src/duckdb/src/execution/operator/join/physical_index_join.cpp new file mode 100644 index 00000000..3035e0ec --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_index_join.cpp @@ -0,0 +1,242 @@ +#include "duckdb/execution/operator/join/physical_index_join.hpp" + +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/execution/index/art/art_key.hpp" + +namespace duckdb { + +class IndexJoinOperatorState : public CachingOperatorState { +public: + IndexJoinOperatorState(ClientContext &context, const PhysicalIndexJoin &op) + : probe_executor(context), arena_allocator(BufferAllocator::Get(context)), keys(STANDARD_VECTOR_SIZE) { + auto &allocator = Allocator::Get(context); + rhs_rows.resize(STANDARD_VECTOR_SIZE); + result_sizes.resize(STANDARD_VECTOR_SIZE); + + join_keys.Initialize(allocator, op.condition_types); + for (auto &cond : op.conditions) { + probe_executor.AddExpression(*cond.left); + } + if (!op.fetch_types.empty()) { + rhs_chunk.Initialize(allocator, op.fetch_types); + } + rhs_sel.Initialize(STANDARD_VECTOR_SIZE); + } + + bool first_fetch = true; + idx_t lhs_idx = 0; + idx_t rhs_idx = 0; + idx_t result_size = 0; + vector result_sizes; + DataChunk join_keys; + DataChunk rhs_chunk; + SelectionVector rhs_sel; + + //! Vector of rows that mush be fetched for every LHS key + vector> rhs_rows; + ExpressionExecutor probe_executor; + + ArenaAllocator arena_allocator; + vector keys; + unique_ptr fetch_state; + +public: + void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { + context.thread.profiler.Flush(op, probe_executor, "probe_executor", 0); + } +}; + +PhysicalIndexJoin::PhysicalIndexJoin(LogicalOperator &op, unique_ptr left, + unique_ptr right, vector cond, JoinType join_type, + const vector &left_projection_map_p, vector right_projection_map_p, + vector column_ids_p, Index &index_p, bool lhs_first, + idx_t estimated_cardinality) + : CachingPhysicalOperator(PhysicalOperatorType::INDEX_JOIN, std::move(op.types), estimated_cardinality), + left_projection_map(left_projection_map_p), right_projection_map(std::move(right_projection_map_p)), + index(index_p), conditions(std::move(cond)), join_type(join_type), lhs_first(lhs_first) { + D_ASSERT(right->type == PhysicalOperatorType::TABLE_SCAN); + auto &tbl_scan = right->Cast(); + column_ids = std::move(column_ids_p); + children.push_back(std::move(left)); + children.push_back(std::move(right)); + for (auto &condition : conditions) { + condition_types.push_back(condition.left->return_type); + } + //! Only add to fetch_ids columns that are not indexed + for (auto &index_id : index.column_ids) { + index_ids.insert(index_id); + } + + for (idx_t i = 0; i < column_ids.size(); i++) { + auto column_id = column_ids[i]; + auto it = index_ids.find(column_id); + if (it == index_ids.end()) { + fetch_ids.push_back(column_id); + if (column_id == COLUMN_IDENTIFIER_ROW_ID) { + fetch_types.emplace_back(LogicalType::ROW_TYPE); + } else { + fetch_types.push_back(tbl_scan.returned_types[column_id]); + } + } + } + if (right_projection_map.empty()) { + for (column_t i = 0; i < column_ids.size(); i++) { + right_projection_map.push_back(i); + } + } + if (left_projection_map.empty()) { + for (column_t i = 0; i < children[0]->types.size(); i++) { + left_projection_map.push_back(i); + } + } +} + +unique_ptr PhysicalIndexJoin::GetOperatorState(ExecutionContext &context) const { + return make_uniq(context.client, *this); +} + +void PhysicalIndexJoin::Output(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state_p) const { + auto &phy_tbl_scan = children[1]->Cast(); + auto &bind_tbl = phy_tbl_scan.bind_data->Cast(); + auto &transaction = DuckTransaction::Get(context.client, bind_tbl.table.catalog); + auto &state = state_p.Cast(); + + auto &tbl = bind_tbl.table.GetStorage(); + idx_t output_sel_idx = 0; + vector fetch_rows; + + while (output_sel_idx < STANDARD_VECTOR_SIZE && state.lhs_idx < input.size()) { + if (state.rhs_idx < state.result_sizes[state.lhs_idx]) { + state.rhs_sel.set_index(output_sel_idx++, state.lhs_idx); + if (!fetch_types.empty()) { + //! We need to collect the rows we want to fetch + fetch_rows.push_back(state.rhs_rows[state.lhs_idx][state.rhs_idx]); + } + state.rhs_idx++; + } else { + //! We are done with the matches from this LHS Key + state.rhs_idx = 0; + state.lhs_idx++; + } + } + //! Now we fetch the RHS data + if (!fetch_types.empty()) { + if (fetch_rows.empty()) { + return; + } + state.rhs_chunk.Reset(); + state.fetch_state = make_uniq(); + Vector row_ids(LogicalType::ROW_TYPE, data_ptr_cast(&fetch_rows[0])); + tbl.Fetch(transaction, state.rhs_chunk, fetch_ids, row_ids, output_sel_idx, *state.fetch_state); + } + + //! Now we actually produce our result chunk + idx_t left_offset = lhs_first ? 0 : right_projection_map.size(); + idx_t right_offset = lhs_first ? left_projection_map.size() : 0; + idx_t rhs_column_idx = 0; + for (idx_t i = 0; i < right_projection_map.size(); i++) { + auto it = index_ids.find(column_ids[right_projection_map[i]]); + if (it == index_ids.end()) { + chunk.data[right_offset + i].Reference(state.rhs_chunk.data[rhs_column_idx++]); + } else { + chunk.data[right_offset + i].Slice(state.join_keys.data[0], state.rhs_sel, output_sel_idx); + } + } + for (idx_t i = 0; i < left_projection_map.size(); i++) { + chunk.data[left_offset + i].Slice(input.data[left_projection_map[i]], state.rhs_sel, output_sel_idx); + } + + state.result_size = output_sel_idx; + chunk.SetCardinality(state.result_size); +} + +void PhysicalIndexJoin::GetRHSMatches(ExecutionContext &context, DataChunk &input, OperatorState &state_p) const { + + auto &state = state_p.Cast(); + auto &art = index.Cast(); + + // generate the keys for this chunk + state.arena_allocator.Reset(); + ART::GenerateKeys(state.arena_allocator, state.join_keys, state.keys); + + for (idx_t i = 0; i < input.size(); i++) { + state.rhs_rows[i].clear(); + if (!state.keys[i].Empty()) { + if (fetch_types.empty()) { + IndexLock lock; + index.InitializeLock(lock); + art.SearchEqualJoinNoFetch(state.keys[i], state.result_sizes[i]); + } else { + IndexLock lock; + index.InitializeLock(lock); + art.SearchEqual(state.keys[i], (idx_t)-1, state.rhs_rows[i]); + state.result_sizes[i] = state.rhs_rows[i].size(); + } + } else { + //! This is null so no matches + state.result_sizes[i] = 0; + } + } + for (idx_t i = input.size(); i < STANDARD_VECTOR_SIZE; i++) { + //! No LHS chunk value so result size is empty + state.result_sizes[i] = 0; + } +} + +OperatorResultType PhysicalIndexJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &state = state_p.Cast(); + + state.result_size = 0; + if (state.first_fetch) { + state.probe_executor.Execute(input, state.join_keys); + + //! Fill Matches for the current LHS chunk + GetRHSMatches(context, input, state_p); + state.first_fetch = false; + } + //! Check if we need to get a new LHS chunk + if (state.lhs_idx >= input.size()) { + state.lhs_idx = 0; + state.rhs_idx = 0; + state.first_fetch = true; + // reset the LHS chunk to reset the validity masks + state.join_keys.Reset(); + return OperatorResultType::NEED_MORE_INPUT; + } + //! Output vectors + if (state.lhs_idx < input.size()) { + Output(context, input, chunk, state_p); + } + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalIndexJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + // index join: we only continue into the LHS + // the right side is probed by the index join + // so we don't need to do anything in the pipeline with this child + meta_pipeline.GetState().AddPipelineOperator(current, *this); + children[0]->BuildPipelines(current, meta_pipeline); +} + +vector> PhysicalIndexJoin::GetSources() const { + return children[0]->GetSources(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_join.cpp b/src/duckdb/src/execution/operator/join/physical_join.cpp new file mode 100644 index 00000000..5bf8aebc --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_join.cpp @@ -0,0 +1,84 @@ +#include "duckdb/execution/operator/join/physical_join.hpp" + +#include "duckdb/execution/operator/join/physical_hash_join.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" + +namespace duckdb { + +PhysicalJoin::PhysicalJoin(LogicalOperator &op, PhysicalOperatorType type, JoinType join_type, + idx_t estimated_cardinality) + : CachingPhysicalOperator(type, op.types, estimated_cardinality), join_type(join_type) { +} + +bool PhysicalJoin::EmptyResultIfRHSIsEmpty() const { + // empty RHS with INNER, RIGHT or SEMI join means empty result set + switch (join_type) { + case JoinType::INNER: + case JoinType::RIGHT: + case JoinType::SEMI: + return true; + default: + return false; + } +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalJoin::BuildJoinPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline, PhysicalOperator &op) { + op.op_state.reset(); + op.sink_state.reset(); + + // 'current' is the probe pipeline: add this operator + auto &state = meta_pipeline.GetState(); + state.AddPipelineOperator(current, op); + + // save the last added pipeline to set up dependencies later (in case we need to add a child pipeline) + vector> pipelines_so_far; + meta_pipeline.GetPipelines(pipelines_so_far, false); + auto last_pipeline = pipelines_so_far.back().get(); + + // on the RHS (build side), we construct a child MetaPipeline with this operator as its sink + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, op); + child_meta_pipeline.Build(*op.children[1]); + + // continue building the current pipeline on the LHS (probe side) + op.children[0]->BuildPipelines(current, meta_pipeline); + + switch (op.type) { + case PhysicalOperatorType::POSITIONAL_JOIN: + // Positional joins are always outer + meta_pipeline.CreateChildPipeline(current, op, last_pipeline); + return; + case PhysicalOperatorType::CROSS_PRODUCT: + return; + default: + break; + } + + // Join can become a source operator if it's RIGHT/OUTER, or if the hash join goes out-of-core + bool add_child_pipeline = false; + auto &join_op = op.Cast(); + if (join_op.IsSource()) { + add_child_pipeline = true; + } + + if (add_child_pipeline) { + meta_pipeline.CreateChildPipeline(current, op, last_pipeline); + } +} + +void PhysicalJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); +} + +vector> PhysicalJoin::GetSources() const { + auto result = children[0]->GetSources(); + if (IsSource()) { + result.push_back(*this); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp new file mode 100644 index 00000000..9036748e --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp @@ -0,0 +1,466 @@ +#include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/nested_loop_join.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/execution/operator/join/outer_join_marker.hpp" + +namespace duckdb { + +PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(LogicalOperator &op, unique_ptr left, + unique_ptr right, vector cond, + JoinType join_type, idx_t estimated_cardinality) + : PhysicalComparisonJoin(op, PhysicalOperatorType::NESTED_LOOP_JOIN, std::move(cond), join_type, + estimated_cardinality) { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +bool PhysicalJoin::HasNullValues(DataChunk &chunk) { + for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); col_idx++) { + UnifiedVectorFormat vdata; + chunk.data[col_idx].ToUnifiedFormat(chunk.size(), vdata); + + if (vdata.validity.AllValid()) { + continue; + } + for (idx_t i = 0; i < chunk.size(); i++) { + auto idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(idx)) { + return true; + } + } + } + return false; +} + +template +static void ConstructSemiOrAntiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { + D_ASSERT(left.ColumnCount() == result.ColumnCount()); + // create the selection vector from the matches that were found + idx_t result_count = 0; + SelectionVector sel(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < left.size(); i++) { + if (found_match[i] == MATCH) { + sel.set_index(result_count++, i); + } + } + // construct the final result + if (result_count > 0) { + // we only return the columns on the left side + // project them using the result selection vector + // reference the columns of the left side from the result + result.Slice(left, sel, result_count); + } else { + result.SetCardinality(0); + } +} + +void PhysicalJoin::ConstructSemiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { + ConstructSemiOrAntiJoinResult(left, result, found_match); +} + +void PhysicalJoin::ConstructAntiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { + ConstructSemiOrAntiJoinResult(left, result, found_match); +} + +void PhysicalJoin::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &left, DataChunk &result, bool found_match[], + bool has_null) { + // for the initial set of columns we just reference the left side + result.SetCardinality(left); + for (idx_t i = 0; i < left.ColumnCount(); i++) { + result.data[i].Reference(left.data[i]); + } + auto &mark_vector = result.data.back(); + mark_vector.SetVectorType(VectorType::FLAT_VECTOR); + // first we set the NULL values from the join keys + // if there is any NULL in the keys, the result is NULL + auto bool_result = FlatVector::GetData(mark_vector); + auto &mask = FlatVector::Validity(mark_vector); + for (idx_t col_idx = 0; col_idx < join_keys.ColumnCount(); col_idx++) { + UnifiedVectorFormat jdata; + join_keys.data[col_idx].ToUnifiedFormat(join_keys.size(), jdata); + if (!jdata.validity.AllValid()) { + for (idx_t i = 0; i < join_keys.size(); i++) { + auto jidx = jdata.sel->get_index(i); + mask.Set(i, jdata.validity.RowIsValid(jidx)); + } + } + } + // now set the remaining entries to either true or false based on whether a match was found + if (found_match) { + for (idx_t i = 0; i < left.size(); i++) { + bool_result[i] = found_match[i]; + } + } else { + memset(bool_result, 0, sizeof(bool) * left.size()); + } + // if the right side contains NULL values, the result of any FALSE becomes NULL + if (has_null) { + for (idx_t i = 0; i < left.size(); i++) { + if (!bool_result[i]) { + mask.SetInvalid(i); + } + } + } +} + +bool PhysicalNestedLoopJoin::IsSupported(const vector &conditions, JoinType join_type) { + if (join_type == JoinType::MARK) { + return true; + } + for (auto &cond : conditions) { + if (cond.left->return_type.InternalType() == PhysicalType::STRUCT || + cond.left->return_type.InternalType() == PhysicalType::LIST) { + return false; + } + } + return true; +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class NestedLoopJoinLocalState : public LocalSinkState { +public: + explicit NestedLoopJoinLocalState(ClientContext &context, const vector &conditions) + : rhs_executor(context) { + vector condition_types; + for (auto &cond : conditions) { + rhs_executor.AddExpression(*cond.right); + condition_types.push_back(cond.right->return_type); + } + right_condition.Initialize(Allocator::Get(context), condition_types); + } + + //! The chunk holding the right condition + DataChunk right_condition; + //! The executor of the RHS condition + ExpressionExecutor rhs_executor; +}; + +class NestedLoopJoinGlobalState : public GlobalSinkState { +public: + explicit NestedLoopJoinGlobalState(ClientContext &context, const PhysicalNestedLoopJoin &op) + : right_payload_data(context, op.children[1]->types), right_condition_data(context, op.GetJoinTypes()), + has_null(false), right_outer(IsRightOuterJoin(op.join_type)) { + } + + mutex nj_lock; + //! Materialized data of the RHS + ColumnDataCollection right_payload_data; + //! Materialized join condition of the RHS + ColumnDataCollection right_condition_data; + //! Whether or not the RHS of the nested loop join has NULL values + atomic has_null; + //! A bool indicating for each tuple in the RHS if they found a match (only used in FULL OUTER JOIN) + OuterJoinMarker right_outer; +}; + +vector PhysicalNestedLoopJoin::GetJoinTypes() const { + vector result; + for (auto &op : conditions) { + result.push_back(op.right->return_type); + } + return result; +} + +SinkResultType PhysicalNestedLoopJoin::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &nlj_state = input.local_state.Cast(); + + // resolve the join expression of the right side + nlj_state.right_condition.Reset(); + nlj_state.rhs_executor.Execute(chunk, nlj_state.right_condition); + + // if we have not seen any NULL values yet, and we are performing a MARK join, check if there are NULL values in + // this chunk + if (join_type == JoinType::MARK && !gstate.has_null) { + if (HasNullValues(nlj_state.right_condition)) { + gstate.has_null = true; + } + } + + // append the payload data and the conditions + lock_guard nj_guard(gstate.nj_lock); + gstate.right_payload_data.Append(chunk); + gstate.right_condition_data.Append(nlj_state.right_condition); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalNestedLoopJoin::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &state = input.local_state.Cast(); + auto &client_profiler = QueryProfiler::Get(context.client); + + context.thread.profiler.Flush(*this, state.rhs_executor, "rhs_executor", 1); + client_profiler.Flush(context.thread.profiler); + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalNestedLoopJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + gstate.right_outer.Initialize(gstate.right_payload_data.Count()); + if (gstate.right_payload_data.Count() == 0 && EmptyResultIfRHSIsEmpty()) { + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + return SinkFinalizeType::READY; +} + +unique_ptr PhysicalNestedLoopJoin::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr PhysicalNestedLoopJoin::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, conditions); +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +class PhysicalNestedLoopJoinState : public CachingOperatorState { +public: + PhysicalNestedLoopJoinState(ClientContext &context, const PhysicalNestedLoopJoin &op, + const vector &conditions) + : fetch_next_left(true), fetch_next_right(false), lhs_executor(context), left_tuple(0), right_tuple(0), + left_outer(IsLeftOuterJoin(op.join_type)) { + vector condition_types; + for (auto &cond : conditions) { + lhs_executor.AddExpression(*cond.left); + condition_types.push_back(cond.left->return_type); + } + auto &allocator = Allocator::Get(context); + left_condition.Initialize(allocator, condition_types); + right_condition.Initialize(allocator, condition_types); + right_payload.Initialize(allocator, op.children[1]->GetTypes()); + left_outer.Initialize(STANDARD_VECTOR_SIZE); + } + + bool fetch_next_left; + bool fetch_next_right; + DataChunk left_condition; + //! The executor of the LHS condition + ExpressionExecutor lhs_executor; + + ColumnDataScanState condition_scan_state; + ColumnDataScanState payload_scan_state; + DataChunk right_condition; + DataChunk right_payload; + + idx_t left_tuple; + idx_t right_tuple; + + OuterJoinMarker left_outer; + +public: + void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { + context.thread.profiler.Flush(op, lhs_executor, "lhs_executor", 0); + } +}; + +unique_ptr PhysicalNestedLoopJoin::GetOperatorState(ExecutionContext &context) const { + return make_uniq(context.client, *this, conditions); +} + +OperatorResultType PhysicalNestedLoopJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, + DataChunk &chunk, GlobalOperatorState &gstate_p, + OperatorState &state_p) const { + auto &gstate = sink_state->Cast(); + + if (gstate.right_payload_data.Count() == 0) { + // empty RHS + if (!EmptyResultIfRHSIsEmpty()) { + ConstructEmptyJoinResult(join_type, gstate.has_null, input, chunk); + return OperatorResultType::NEED_MORE_INPUT; + } else { + return OperatorResultType::FINISHED; + } + } + + switch (join_type) { + case JoinType::SEMI: + case JoinType::ANTI: + case JoinType::MARK: + // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk + ResolveSimpleJoin(context, input, chunk, state_p); + return OperatorResultType::NEED_MORE_INPUT; + case JoinType::LEFT: + case JoinType::INNER: + case JoinType::OUTER: + case JoinType::RIGHT: + return ResolveComplexJoin(context, input, chunk, state_p); + default: + throw NotImplementedException("Unimplemented type for nested loop join!"); + } +} + +void PhysicalNestedLoopJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state_p) const { + auto &state = state_p.Cast(); + auto &gstate = sink_state->Cast(); + + // resolve the left join condition for the current chunk + state.left_condition.Reset(); + state.lhs_executor.Execute(input, state.left_condition); + + bool found_match[STANDARD_VECTOR_SIZE] = {false}; + NestedLoopJoinMark::Perform(state.left_condition, gstate.right_condition_data, found_match, conditions); + switch (join_type) { + case JoinType::MARK: + // now construct the mark join result from the found matches + PhysicalJoin::ConstructMarkJoinResult(state.left_condition, input, chunk, found_match, gstate.has_null); + break; + case JoinType::SEMI: + // construct the semi join result from the found matches + PhysicalJoin::ConstructSemiJoinResult(input, chunk, found_match); + break; + case JoinType::ANTI: + // construct the anti join result from the found matches + PhysicalJoin::ConstructAntiJoinResult(input, chunk, found_match); + break; + default: + throw NotImplementedException("Unimplemented type for simple nested loop join!"); + } +} + +OperatorResultType PhysicalNestedLoopJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, + DataChunk &chunk, OperatorState &state_p) const { + auto &state = state_p.Cast(); + auto &gstate = sink_state->Cast(); + + idx_t match_count; + do { + if (state.fetch_next_right) { + // we exhausted the chunk on the right: move to the next chunk on the right + state.left_tuple = 0; + state.right_tuple = 0; + state.fetch_next_right = false; + // check if we exhausted all chunks on the RHS + if (gstate.right_condition_data.Scan(state.condition_scan_state, state.right_condition)) { + if (!gstate.right_payload_data.Scan(state.payload_scan_state, state.right_payload)) { + throw InternalException("Nested loop join: payload and conditions are unaligned!?"); + } + if (state.right_condition.size() != state.right_payload.size()) { + throw InternalException("Nested loop join: payload and conditions are unaligned!?"); + } + } else { + // we exhausted all chunks on the right: move to the next chunk on the left + state.fetch_next_left = true; + if (state.left_outer.Enabled()) { + // left join: before we move to the next chunk, see if we need to output any vectors that didn't + // have a match found + state.left_outer.ConstructLeftJoinResult(input, chunk); + state.left_outer.Reset(); + } + return OperatorResultType::NEED_MORE_INPUT; + } + } + if (state.fetch_next_left) { + // resolve the left join condition for the current chunk + state.left_condition.Reset(); + state.lhs_executor.Execute(input, state.left_condition); + + state.left_tuple = 0; + state.right_tuple = 0; + gstate.right_condition_data.InitializeScan(state.condition_scan_state); + gstate.right_condition_data.Scan(state.condition_scan_state, state.right_condition); + + gstate.right_payload_data.InitializeScan(state.payload_scan_state); + gstate.right_payload_data.Scan(state.payload_scan_state, state.right_payload); + state.fetch_next_left = false; + } + // now we have a left and a right chunk that we can join together + // note that we only get here in the case of a LEFT, INNER or FULL join + auto &left_chunk = input; + auto &right_condition = state.right_condition; + auto &right_payload = state.right_payload; + + // sanity check + left_chunk.Verify(); + right_condition.Verify(); + right_payload.Verify(); + + // now perform the join + SelectionVector lvector(STANDARD_VECTOR_SIZE), rvector(STANDARD_VECTOR_SIZE); + match_count = NestedLoopJoinInner::Perform(state.left_tuple, state.right_tuple, state.left_condition, + right_condition, lvector, rvector, conditions); + // we have finished resolving the join conditions + if (match_count > 0) { + // we have matching tuples! + // construct the result + state.left_outer.SetMatches(lvector, match_count); + gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); + + chunk.Slice(input, lvector, match_count); + chunk.Slice(right_payload, rvector, match_count, input.ColumnCount()); + } + + // check if we exhausted the RHS, if we did we need to move to the next right chunk in the next iteration + if (state.right_tuple >= right_condition.size()) { + state.fetch_next_right = true; + } + } while (match_count == 0); + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class NestedLoopJoinGlobalScanState : public GlobalSourceState { +public: + explicit NestedLoopJoinGlobalScanState(const PhysicalNestedLoopJoin &op) : op(op) { + D_ASSERT(op.sink_state); + auto &sink = op.sink_state->Cast(); + sink.right_outer.InitializeScan(sink.right_payload_data, scan_state); + } + + const PhysicalNestedLoopJoin &op; + OuterJoinGlobalScanState scan_state; + +public: + idx_t MaxThreads() override { + auto &sink = op.sink_state->Cast(); + return sink.right_outer.MaxThreads(); + } +}; + +class NestedLoopJoinLocalScanState : public LocalSourceState { +public: + explicit NestedLoopJoinLocalScanState(const PhysicalNestedLoopJoin &op, NestedLoopJoinGlobalScanState &gstate) { + D_ASSERT(op.sink_state); + auto &sink = op.sink_state->Cast(); + sink.right_outer.InitializeScan(gstate.scan_state, scan_state); + } + + OuterJoinLocalScanState scan_state; +}; + +unique_ptr PhysicalNestedLoopJoin::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +unique_ptr PhysicalNestedLoopJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(*this, gstate.Cast()); +} + +SourceResultType PhysicalNestedLoopJoin::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + D_ASSERT(IsRightOuterJoin(join_type)); + // check if we need to scan any unmatched tuples from the RHS for the full/right outer join + auto &sink = sink_state->Cast(); + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan chunks we still need to output + sink.right_outer.Scan(gstate.scan_state, lstate.scan_state, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp new file mode 100644 index 00000000..42a2d187 --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp @@ -0,0 +1,763 @@ +#include "duckdb/execution/operator/join/physical_piecewise_merge_join.hpp" + +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/sort/comparators.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/join/outer_join_marker.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/thread_context.hpp" + +namespace duckdb { + +PhysicalPiecewiseMergeJoin::PhysicalPiecewiseMergeJoin(LogicalComparisonJoin &op, unique_ptr left, + unique_ptr right, vector cond, + JoinType join_type, idx_t estimated_cardinality) + : PhysicalRangeJoin(op, PhysicalOperatorType::PIECEWISE_MERGE_JOIN, std::move(left), std::move(right), + std::move(cond), join_type, estimated_cardinality) { + + for (auto &cond : conditions) { + D_ASSERT(cond.left->return_type == cond.right->return_type); + join_key_types.push_back(cond.left->return_type); + + // Convert the conditions to sort orders + auto left = cond.left->Copy(); + auto right = cond.right->Copy(); + switch (cond.comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + lhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(left)); + rhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(right)); + break; + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + lhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(left)); + rhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(right)); + break; + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_DISTINCT_FROM: + // Allowed in multi-predicate joins, but can't be first/sort. + D_ASSERT(!lhs_orders.empty()); + lhs_orders.emplace_back(OrderType::INVALID, OrderByNullType::NULLS_LAST, std::move(left)); + rhs_orders.emplace_back(OrderType::INVALID, OrderByNullType::NULLS_LAST, std::move(right)); + break; + + default: + // COMPARE EQUAL not supported with merge join + throw NotImplementedException("Unimplemented join type for merge join"); + } + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class MergeJoinLocalState : public LocalSinkState { +public: + explicit MergeJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child) + : table(context, op, child) { + } + + //! The local sort state + PhysicalRangeJoin::LocalSortedTable table; +}; + +class MergeJoinGlobalState : public GlobalSinkState { +public: + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; + +public: + MergeJoinGlobalState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op) { + RowLayout rhs_layout; + rhs_layout.Initialize(op.children[1]->types); + vector rhs_order; + rhs_order.emplace_back(op.rhs_orders[0].Copy()); + table = make_uniq(context, rhs_order, rhs_layout); + } + + inline idx_t Count() const { + return table->count; + } + + void Sink(DataChunk &input, MergeJoinLocalState &lstate) { + auto &global_sort_state = table->global_sort_state; + auto &local_sort_state = lstate.table.local_sort_state; + + // Sink the data into the local sort state + lstate.table.Sink(input, global_sort_state); + + // When sorting data reaches a certain size, we sort it + if (local_sort_state.SizeInBytes() >= table->memory_per_thread) { + local_sort_state.Sort(global_sort_state, true); + } + } + + unique_ptr table; +}; + +unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSinkState(ExecutionContext &context) const { + // We only sink the RHS + return make_uniq(context.client, *this, 1); +} + +SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + gstate.Sink(chunk, lstate); + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + gstate.table->Combine(lstate.table); + auto &client_profiler = QueryProfiler::Get(context.client); + + context.thread.profiler.Flush(*this, lstate.table.executor, "rhs_executor", 1); + client_profiler.Flush(context.thread.profiler); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &global_sort_state = gstate.table->global_sort_state; + + if (IsRightOuterJoin(join_type)) { + // for FULL/RIGHT OUTER JOIN, initialize found_match to false for every tuple + gstate.table->IntializeMatches(); + } + if (global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + // Empty input! + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // Sort the current input child + gstate.table->Finalize(pipeline, event); + + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +class PiecewiseMergeJoinState : public CachingOperatorState { +public: + using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + + PiecewiseMergeJoinState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op, bool force_external) + : context(context), allocator(Allocator::Get(context)), op(op), + buffer_manager(BufferManager::GetBufferManager(context)), force_external(force_external), + left_outer(IsLeftOuterJoin(op.join_type)), left_position(0), first_fetch(true), finished(true), + right_position(0), right_chunk_index(0), rhs_executor(context) { + vector condition_types; + for (auto &order : op.lhs_orders) { + condition_types.push_back(order.expression->return_type); + } + left_outer.Initialize(STANDARD_VECTOR_SIZE); + lhs_layout.Initialize(op.children[0]->types); + lhs_payload.Initialize(allocator, op.children[0]->types); + + lhs_order.emplace_back(op.lhs_orders[0].Copy()); + + // Set up shared data for multiple predicates + sel.Initialize(STANDARD_VECTOR_SIZE); + condition_types.clear(); + for (auto &order : op.rhs_orders) { + rhs_executor.AddExpression(*order.expression); + condition_types.push_back(order.expression->return_type); + } + rhs_keys.Initialize(allocator, condition_types); + } + + ClientContext &context; + Allocator &allocator; + const PhysicalPiecewiseMergeJoin &op; + BufferManager &buffer_manager; + bool force_external; + + // Block sorting + DataChunk lhs_payload; + OuterJoinMarker left_outer; + vector lhs_order; + RowLayout lhs_layout; + unique_ptr lhs_local_table; + unique_ptr lhs_global_state; + unique_ptr scanner; + + // Simple scans + idx_t left_position; + + // Complex scans + bool first_fetch; + bool finished; + idx_t right_position; + idx_t right_chunk_index; + idx_t right_base; + idx_t prev_left_index; + + // Secondary predicate shared data + SelectionVector sel; + DataChunk rhs_keys; + DataChunk rhs_input; + ExpressionExecutor rhs_executor; + vector payload_heap_handles; + +public: + void ResolveJoinKeys(DataChunk &input) { + // sort by join key + lhs_global_state = make_uniq(buffer_manager, lhs_order, lhs_layout); + lhs_local_table = make_uniq(context, op, 0); + lhs_local_table->Sink(input, *lhs_global_state); + + // Set external (can be forced with the PRAGMA) + lhs_global_state->external = force_external; + lhs_global_state->AddLocalState(lhs_local_table->local_sort_state); + lhs_global_state->PrepareMergePhase(); + while (lhs_global_state->sorted_blocks.size() > 1) { + MergeSorter merge_sorter(*lhs_global_state, buffer_manager); + merge_sorter.PerformInMergeRound(); + lhs_global_state->CompleteMergeRound(); + } + + // Scan the sorted payload + D_ASSERT(lhs_global_state->sorted_blocks.size() == 1); + + scanner = make_uniq(*lhs_global_state->sorted_blocks[0]->payload_data, *lhs_global_state); + lhs_payload.Reset(); + scanner->Scan(lhs_payload); + + // Recompute the sorted keys from the sorted input + lhs_local_table->keys.Reset(); + lhs_local_table->executor.Execute(lhs_payload, lhs_local_table->keys); + } + + void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { + if (lhs_local_table) { + context.thread.profiler.Flush(op, lhs_local_table->executor, "lhs_executor", 0); + } + } +}; + +unique_ptr PhysicalPiecewiseMergeJoin::GetOperatorState(ExecutionContext &context) const { + auto &config = ClientConfig::GetConfig(context.client); + return make_uniq(context.client, *this, config.force_external); +} + +static inline idx_t SortedBlockNotNull(const idx_t base, const idx_t count, const idx_t not_null) { + return MinValue(base + count, MaxValue(base, not_null)) - base; +} + +static int MergeJoinComparisonValue(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return -1; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return 0; + default: + throw InternalException("Unimplemented comparison type for merge join!"); + } +} + +struct BlockMergeInfo { + GlobalSortState &state; + //! The block being scanned + const idx_t block_idx; + //! The number of not-NULL values in the block (they are at the end) + const idx_t not_null; + //! The current offset in the block + idx_t &entry_idx; + SelectionVector result; + + BlockMergeInfo(GlobalSortState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) + : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { + } +}; + +static void MergeJoinPinSortingBlock(SBScanState &scan, const idx_t block_idx) { + scan.SetIndices(block_idx, 0); + scan.PinRadix(block_idx); + + auto &sd = *scan.sb->blob_sorting_data; + if (block_idx < sd.data_blocks.size()) { + scan.PinData(sd); + } +} + +static data_ptr_t MergeJoinRadixPtr(SBScanState &scan, const idx_t entry_idx) { + scan.entry_idx = entry_idx; + return scan.RadixPtr(); +} + +static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &rstate, bool *found_match, + const ExpressionType comparison) { + const auto cmp = MergeJoinComparisonValue(comparison); + + // The sort parameters should all be the same + auto &lsort = *lstate.lhs_global_state; + auto &rsort = rstate.table->global_sort_state; + D_ASSERT(lsort.sort_layout.all_constant == rsort.sort_layout.all_constant); + const auto all_constant = lsort.sort_layout.all_constant; + D_ASSERT(lsort.external == rsort.external); + const auto external = lsort.external; + + // There should only be one sorted block if they have been sorted + D_ASSERT(lsort.sorted_blocks.size() == 1); + SBScanState lread(lsort.buffer_manager, lsort); + lread.sb = lsort.sorted_blocks[0].get(); + + const idx_t l_block_idx = 0; + idx_t l_entry_idx = 0; + const auto lhs_not_null = lstate.lhs_local_table->count - lstate.lhs_local_table->has_null; + MergeJoinPinSortingBlock(lread, l_block_idx); + auto l_ptr = MergeJoinRadixPtr(lread, l_entry_idx); + + D_ASSERT(rsort.sorted_blocks.size() == 1); + SBScanState rread(rsort.buffer_manager, rsort); + rread.sb = rsort.sorted_blocks[0].get(); + + const auto cmp_size = lsort.sort_layout.comparison_size; + const auto entry_size = lsort.sort_layout.entry_size; + + idx_t right_base = 0; + for (idx_t r_block_idx = 0; r_block_idx < rread.sb->radix_sorting_data.size(); r_block_idx++) { + // we only care about the BIGGEST value in each of the RHS data blocks + // because we want to figure out if the LHS values are less than [or equal] to ANY value + // get the biggest value from the RHS chunk + MergeJoinPinSortingBlock(rread, r_block_idx); + + auto &rblock = *rread.sb->radix_sorting_data[r_block_idx]; + const auto r_not_null = + SortedBlockNotNull(right_base, rblock.count, rstate.table->count - rstate.table->has_null); + if (r_not_null == 0) { + break; + } + const auto r_entry_idx = r_not_null - 1; + right_base += rblock.count; + + auto r_ptr = MergeJoinRadixPtr(rread, r_entry_idx); + + // now we start from the current lpos value and check if we found a new value that is [<= OR <] the max RHS + // value + while (true) { + int comp_res; + if (all_constant) { + comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); + } else { + lread.entry_idx = l_entry_idx; + rread.entry_idx = r_entry_idx; + comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, lsort.sort_layout, external); + } + + if (comp_res <= cmp) { + // found a match for lpos, set it in the found_match vector + found_match[l_entry_idx] = true; + l_entry_idx++; + l_ptr += entry_size; + if (l_entry_idx >= lhs_not_null) { + // early out: we exhausted the entire LHS and they all match + return 0; + } + } else { + // we found no match: any subsequent value from the LHS we scan now will be bigger and thus also not + // match move to the next RHS chunk + break; + } + } + } + return 0; +} + +void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state_p) const { + auto &state = state_p.Cast(); + auto &gstate = sink_state->Cast(); + + state.ResolveJoinKeys(input); + auto &lhs_table = *state.lhs_local_table; + + // perform the actual join + bool found_match[STANDARD_VECTOR_SIZE]; + memset(found_match, 0, sizeof(found_match)); + MergeJoinSimpleBlocks(state, gstate, found_match, conditions[0].comparison); + + // use the sorted payload + const auto lhs_not_null = lhs_table.count - lhs_table.has_null; + auto &payload = state.lhs_payload; + + // now construct the result based on the join result + switch (join_type) { + case JoinType::MARK: { + // The only part of the join keys that is actually used is the validity mask. + // Since the payload is sorted, we can just set the tail end of the validity masks to invalid. + for (auto &key : lhs_table.keys.data) { + key.Flatten(lhs_table.keys.size()); + auto &mask = FlatVector::Validity(key); + if (mask.AllValid()) { + continue; + } + mask.SetAllValid(lhs_not_null); + for (idx_t i = lhs_not_null; i < lhs_table.count; ++i) { + mask.SetInvalid(i); + } + } + // So we make a set of keys that have the validity mask set for the + PhysicalJoin::ConstructMarkJoinResult(lhs_table.keys, payload, chunk, found_match, gstate.table->has_null); + break; + } + case JoinType::SEMI: + PhysicalJoin::ConstructSemiJoinResult(payload, chunk, found_match); + break; + case JoinType::ANTI: + PhysicalJoin::ConstructAntiJoinResult(payload, chunk, found_match); + break; + default: + throw NotImplementedException("Unimplemented join type for merge join"); + } +} + +static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const ExpressionType comparison, + idx_t &prev_left_index) { + const auto cmp = MergeJoinComparisonValue(comparison); + + // The sort parameters should all be the same + D_ASSERT(l.state.sort_layout.all_constant == r.state.sort_layout.all_constant); + const auto all_constant = r.state.sort_layout.all_constant; + D_ASSERT(l.state.external == r.state.external); + const auto external = l.state.external; + + // There should only be one sorted block if they have been sorted + D_ASSERT(l.state.sorted_blocks.size() == 1); + SBScanState lread(l.state.buffer_manager, l.state); + lread.sb = l.state.sorted_blocks[0].get(); + D_ASSERT(lread.sb->radix_sorting_data.size() == 1); + MergeJoinPinSortingBlock(lread, l.block_idx); + auto l_start = MergeJoinRadixPtr(lread, 0); + auto l_ptr = MergeJoinRadixPtr(lread, l.entry_idx); + + D_ASSERT(r.state.sorted_blocks.size() == 1); + SBScanState rread(r.state.buffer_manager, r.state); + rread.sb = r.state.sorted_blocks[0].get(); + + if (r.entry_idx >= r.not_null) { + return 0; + } + + MergeJoinPinSortingBlock(rread, r.block_idx); + auto r_ptr = MergeJoinRadixPtr(rread, r.entry_idx); + + const auto cmp_size = l.state.sort_layout.comparison_size; + const auto entry_size = l.state.sort_layout.entry_size; + + idx_t result_count = 0; + while (true) { + if (l.entry_idx < prev_left_index) { + // left side smaller: found match + l.result.set_index(result_count, sel_t(l.entry_idx)); + r.result.set_index(result_count, sel_t(r.entry_idx)); + result_count++; + // move left side forward + l.entry_idx++; + l_ptr += entry_size; + if (result_count == STANDARD_VECTOR_SIZE) { + // out of space! + break; + } + continue; + } + if (l.entry_idx < l.not_null) { + int comp_res; + if (all_constant) { + comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); + } else { + lread.entry_idx = l.entry_idx; + rread.entry_idx = r.entry_idx; + comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, l.state.sort_layout, external); + } + if (comp_res <= cmp) { + // left side smaller: found match + l.result.set_index(result_count, sel_t(l.entry_idx)); + r.result.set_index(result_count, sel_t(r.entry_idx)); + result_count++; + // move left side forward + l.entry_idx++; + l_ptr += entry_size; + if (result_count == STANDARD_VECTOR_SIZE) { + // out of space! + break; + } + continue; + } + } + + prev_left_index = l.entry_idx; + // right side smaller or equal, or left side exhausted: move + // right pointer forward reset left side to start + r.entry_idx++; + if (r.entry_idx >= r.not_null) { + break; + } + r_ptr += entry_size; + + l_ptr = l_start; + l.entry_idx = 0; + } + + return result_count; +} + +OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, + DataChunk &chunk, OperatorState &state_p) const { + auto &state = state_p.Cast(); + auto &gstate = sink_state->Cast(); + auto &rsorted = *gstate.table->global_sort_state.sorted_blocks[0]; + const auto left_cols = input.ColumnCount(); + const auto tail_cols = conditions.size() - 1; + + state.payload_heap_handles.clear(); + do { + if (state.first_fetch) { + state.ResolveJoinKeys(input); + + state.right_chunk_index = 0; + state.right_base = 0; + state.left_position = 0; + state.prev_left_index = 0; + state.right_position = 0; + state.first_fetch = false; + state.finished = false; + } + if (state.finished) { + if (state.left_outer.Enabled()) { + // left join: before we move to the next chunk, see if we need to output any vectors that didn't + // have a match found + state.left_outer.ConstructLeftJoinResult(state.lhs_payload, chunk); + state.left_outer.Reset(); + } + state.first_fetch = true; + state.finished = false; + return OperatorResultType::NEED_MORE_INPUT; + } + + auto &lhs_table = *state.lhs_local_table; + const auto lhs_not_null = lhs_table.count - lhs_table.has_null; + BlockMergeInfo left_info(*state.lhs_global_state, 0, state.left_position, lhs_not_null); + + const auto &rblock = *rsorted.radix_sorting_data[state.right_chunk_index]; + const auto rhs_not_null = + SortedBlockNotNull(state.right_base, rblock.count, gstate.table->count - gstate.table->has_null); + BlockMergeInfo right_info(gstate.table->global_sort_state, state.right_chunk_index, state.right_position, + rhs_not_null); + + idx_t result_count = + MergeJoinComplexBlocks(left_info, right_info, conditions[0].comparison, state.prev_left_index); + if (result_count == 0) { + // exhausted this chunk on the right side + // move to the next right chunk + state.left_position = 0; + state.right_position = 0; + state.right_base += rsorted.radix_sorting_data[state.right_chunk_index]->count; + state.right_chunk_index++; + if (state.right_chunk_index >= rsorted.radix_sorting_data.size()) { + state.finished = true; + } + } else { + // found matches: extract them + chunk.Reset(); + for (idx_t c = 0; c < state.lhs_payload.ColumnCount(); ++c) { + chunk.data[c].Slice(state.lhs_payload.data[c], left_info.result, result_count); + } + state.payload_heap_handles.push_back(SliceSortedPayload(chunk, right_info.state, right_info.block_idx, + right_info.result, result_count, left_cols)); + chunk.SetCardinality(result_count); + + auto sel = FlatVector::IncrementalSelectionVector(); + if (tail_cols) { + // If there are more expressions to compute, + // split the result chunk into the left and right halves + // so we can compute the values for comparison. + chunk.Split(state.rhs_input, left_cols); + state.rhs_executor.SetChunk(state.rhs_input); + state.rhs_keys.Reset(); + + auto tail_count = result_count; + for (size_t cmp_idx = 1; cmp_idx < conditions.size(); ++cmp_idx) { + Vector left(lhs_table.keys.data[cmp_idx]); + left.Slice(left_info.result, result_count); + + auto &right = state.rhs_keys.data[cmp_idx]; + state.rhs_executor.ExecuteExpression(cmp_idx, right); + + if (tail_count < result_count) { + left.Slice(*sel, tail_count); + right.Slice(*sel, tail_count); + } + tail_count = + SelectJoinTail(conditions[cmp_idx].comparison, left, right, sel, tail_count, &state.sel); + sel = &state.sel; + } + chunk.Fuse(state.rhs_input); + + if (tail_count < result_count) { + result_count = tail_count; + chunk.Slice(*sel, result_count); + } + } + + // found matches: mark the found matches if required + if (state.left_outer.Enabled()) { + for (idx_t i = 0; i < result_count; i++) { + state.left_outer.SetMatch(left_info.result[sel->get_index(i)]); + } + } + if (gstate.table->found_match) { + // Absolute position of the block + start position inside that block + for (idx_t i = 0; i < result_count; i++) { + gstate.table->found_match[state.right_base + right_info.result[sel->get_index(i)]] = true; + } + } + chunk.SetCardinality(result_count); + chunk.Verify(); + } + } while (chunk.size() == 0); + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +OperatorResultType PhysicalPiecewiseMergeJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, + DataChunk &chunk, GlobalOperatorState &gstate_p, + OperatorState &state) const { + auto &gstate = sink_state->Cast(); + + if (gstate.Count() == 0) { + // empty RHS + if (!EmptyResultIfRHSIsEmpty()) { + ConstructEmptyJoinResult(join_type, gstate.table->has_null, input, chunk); + return OperatorResultType::NEED_MORE_INPUT; + } else { + return OperatorResultType::FINISHED; + } + } + + input.Verify(); + switch (join_type) { + case JoinType::SEMI: + case JoinType::ANTI: + case JoinType::MARK: + // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk + ResolveSimpleJoin(context, input, chunk, state); + return OperatorResultType::NEED_MORE_INPUT; + case JoinType::LEFT: + case JoinType::INNER: + case JoinType::RIGHT: + case JoinType::OUTER: + return ResolveComplexJoin(context, input, chunk, state); + default: + throw NotImplementedException("Unimplemented type for piecewise merge loop join!"); + } +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class PiecewiseJoinScanState : public GlobalSourceState { +public: + explicit PiecewiseJoinScanState(const PhysicalPiecewiseMergeJoin &op) : op(op), right_outer_position(0) { + } + + mutex lock; + const PhysicalPiecewiseMergeJoin &op; + unique_ptr scanner; + idx_t right_outer_position; + +public: + idx_t MaxThreads() override { + auto &sink = op.sink_state->Cast(); + return sink.Count() / (STANDARD_VECTOR_SIZE * idx_t(10)); + } +}; + +unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, DataChunk &result, + OperatorSourceInput &input) const { + D_ASSERT(IsRightOuterJoin(join_type)); + // check if we need to scan any unmatched tuples from the RHS for the full/right outer join + auto &sink = sink_state->Cast(); + auto &state = input.global_state.Cast(); + + lock_guard l(state.lock); + if (!state.scanner) { + // Initialize scanner (if not yet initialized) + auto &sort_state = sink.table->global_sort_state; + if (sort_state.sorted_blocks.empty()) { + return SourceResultType::FINISHED; + } + state.scanner = make_uniq(*sort_state.sorted_blocks[0]->payload_data, sort_state); + } + + // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan the found_match for any chunks we + // still need to output + const auto found_match = sink.table->found_match.get(); + + DataChunk rhs_chunk; + rhs_chunk.Initialize(Allocator::Get(context.client), sink.table->global_sort_state.payload_layout.GetTypes()); + SelectionVector rsel(STANDARD_VECTOR_SIZE); + for (;;) { + // Read the next sorted chunk + state.scanner->Scan(rhs_chunk); + + const auto count = rhs_chunk.size(); + if (count == 0) { + return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + } + + idx_t result_count = 0; + // figure out which tuples didn't find a match in the RHS + for (idx_t i = 0; i < count; i++) { + if (!found_match[state.right_outer_position + i]) { + rsel.set_index(result_count++, i); + } + } + state.right_outer_position += count; + + if (result_count > 0) { + // if there were any tuples that didn't find a match, output them + const idx_t left_column_count = children[0]->types.size(); + for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { + result.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result.data[col_idx], true); + } + const idx_t right_column_count = children[1]->types.size(); + ; + for (idx_t col_idx = 0; col_idx < right_column_count; ++col_idx) { + result.data[left_column_count + col_idx].Slice(rhs_chunk.data[col_idx], rsel, result_count); + } + result.SetCardinality(result_count); + break; + } + } + + return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_positional_join.cpp b/src/duckdb/src/execution/operator/join/physical_positional_join.cpp new file mode 100644 index 00000000..bcf4b498 --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_positional_join.cpp @@ -0,0 +1,196 @@ +#include "duckdb/execution/operator/join/physical_positional_join.hpp" + +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/operator/join/physical_join.hpp" + +namespace duckdb { + +PhysicalPositionalJoin::PhysicalPositionalJoin(vector types, unique_ptr left, + unique_ptr right, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::POSITIONAL_JOIN, std::move(types), estimated_cardinality) { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class PositionalJoinGlobalState : public GlobalSinkState { +public: + explicit PositionalJoinGlobalState(ClientContext &context, const PhysicalPositionalJoin &op) + : rhs(context, op.children[1]->GetTypes()), initialized(false), source_offset(0), exhausted(false) { + rhs.InitializeAppend(append_state); + } + + ColumnDataCollection rhs; + ColumnDataAppendState append_state; + mutex rhs_lock; + + bool initialized; + ColumnDataScanState scan_state; + DataChunk source; + idx_t source_offset; + bool exhausted; + + void InitializeScan(); + idx_t Refill(); + idx_t CopyData(DataChunk &output, const idx_t count, const idx_t col_offset); + void Execute(DataChunk &input, DataChunk &output); + void GetData(DataChunk &output); +}; + +unique_ptr PhysicalPositionalJoin::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +SinkResultType PhysicalPositionalJoin::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &sink = input.global_state.Cast(); + lock_guard client_guard(sink.rhs_lock); + sink.rhs.Append(sink.append_state, chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +void PositionalJoinGlobalState::InitializeScan() { + if (!initialized) { + // not initialized yet: initialize the scan + initialized = true; + rhs.InitializeScanChunk(source); + rhs.InitializeScan(scan_state); + } +} + +idx_t PositionalJoinGlobalState::Refill() { + if (source_offset >= source.size()) { + if (!exhausted) { + source.Reset(); + rhs.Scan(scan_state, source); + } + source_offset = 0; + } + + const auto available = source.size() - source_offset; + if (!available) { + if (!exhausted) { + source.Reset(); + for (idx_t i = 0; i < source.ColumnCount(); ++i) { + auto &vec = source.data[i]; + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + exhausted = true; + } + } + + return available; +} + +idx_t PositionalJoinGlobalState::CopyData(DataChunk &output, const idx_t count, const idx_t col_offset) { + if (!source_offset && (source.size() >= count || exhausted)) { + // Fast track: aligned and has enough data + for (idx_t i = 0; i < source.ColumnCount(); ++i) { + output.data[col_offset + i].Reference(source.data[i]); + } + source_offset += count; + } else { + // Copy data + for (idx_t target_offset = 0; target_offset < count;) { + const auto needed = count - target_offset; + const auto available = exhausted ? needed : (source.size() - source_offset); + const auto copy_size = MinValue(needed, available); + const auto source_count = source_offset + copy_size; + for (idx_t i = 0; i < source.ColumnCount(); ++i) { + VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, + target_offset); + } + target_offset += copy_size; + source_offset += copy_size; + Refill(); + } + } + + return source.ColumnCount(); +} + +void PositionalJoinGlobalState::Execute(DataChunk &input, DataChunk &output) { + lock_guard client_guard(rhs_lock); + + // Reference the input and assume it will be full + const auto col_offset = input.ColumnCount(); + for (idx_t i = 0; i < col_offset; ++i) { + output.data[i].Reference(input.data[i]); + } + + // Copy or reference the RHS columns + const auto count = input.size(); + InitializeScan(); + Refill(); + CopyData(output, count, col_offset); + + output.SetCardinality(count); +} + +OperatorResultType PhysicalPositionalJoin::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &sink = sink_state->Cast(); + sink.Execute(input, chunk); + return OperatorResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +void PositionalJoinGlobalState::GetData(DataChunk &output) { + lock_guard client_guard(rhs_lock); + + InitializeScan(); + Refill(); + + // LHS exhausted + if (exhausted) { + // RHS exhausted too, so we are done + output.SetCardinality(0); + return; + } + + // LHS is all NULL + const auto col_offset = output.ColumnCount() - source.ColumnCount(); + for (idx_t i = 0; i < col_offset; ++i) { + auto &vec = output.data[i]; + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + + // RHS still has data, so copy it + const auto count = MinValue(STANDARD_VECTOR_SIZE, source.size() - source_offset); + CopyData(output, count, col_offset); + output.SetCardinality(count); +} + +SourceResultType PhysicalPositionalJoin::GetData(ExecutionContext &context, DataChunk &result, + OperatorSourceInput &input) const { + auto &sink = sink_state->Cast(); + sink.GetData(result); + + return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalPositionalJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); +} + +vector> PhysicalPositionalJoin::GetSources() const { + auto result = children[0]->GetSources(); + if (IsSource()) { + result.push_back(*this); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp new file mode 100644 index 00000000..07cbc327 --- /dev/null +++ b/src/duckdb/src/execution/operator/join/physical_range_join.cpp @@ -0,0 +1,383 @@ +#include "duckdb/execution/operator/join/physical_range_join.hpp" + +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/sort/comparators.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/thread_context.hpp" + +#include + +namespace duckdb { + +PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, + const idx_t child) + : op(op), executor(context), has_null(0), count(0) { + // Initialize order clause expression executor and key DataChunk + vector types; + for (const auto &cond : op.conditions) { + const auto &expr = child ? cond.right : cond.left; + executor.AddExpression(*expr); + + types.push_back(expr->return_type); + } + auto &allocator = Allocator::Get(context); + keys.Initialize(allocator, types); +} + +void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) { + // Initialize local state (if necessary) + if (!local_sort_state.initialized) { + local_sort_state.Initialize(global_sort_state, global_sort_state.buffer_manager); + } + + // Obtain sorting columns + keys.Reset(); + executor.Execute(input, keys); + + // Count the NULLs so we can exclude them later + has_null += MergeNulls(op.conditions); + count += keys.size(); + + // Only sort the primary key + DataChunk join_head; + join_head.data.emplace_back(keys.data[0]); + join_head.SetCardinality(keys.size()); + + // Sink the data into the local sort state + local_sort_state.SinkChunk(join_head, input); +} + +PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector &orders, + RowLayout &payload_layout) + : global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), count(0), + memory_per_thread(0) { + D_ASSERT(orders.size() == 1); + + // Set external (can be forced with the PRAGMA) + auto &config = ClientConfig::GetConfig(context); + global_sort_state.external = config.force_external; + memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context); +} + +void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { + global_sort_state.AddLocalState(ltable.local_sort_state); + has_null += ltable.has_null; + count += ltable.count; +} + +void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { + found_match = make_unsafe_uniq_array(Count()); + memset(found_match.get(), 0, sizeof(bool) * Count()); +} + +void PhysicalRangeJoin::GlobalSortedTable::Print() { + global_sort_state.Print(); +} + +class RangeJoinMergeTask : public ExecutorTask { +public: + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; + +public: + RangeJoinMergeTask(shared_ptr event_p, ClientContext &context, GlobalSortedTable &table) + : ExecutorTask(context), event(std::move(event_p)), context(context), table(table) { + } + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + // Initialize iejoin sorted and iterate until done + auto &global_sort_state = table.global_sort_state; + MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); + merge_sorter.PerformInMergeRound(); + event->FinishTask(); + + return TaskExecutionResult::TASK_FINISHED; + } + +private: + shared_ptr event; + ClientContext &context; + GlobalSortedTable &table; +}; + +class RangeJoinMergeEvent : public BasePipelineEvent { +public: + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; + +public: + RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p) + : BasePipelineEvent(pipeline_p), table(table_p) { + } + + GlobalSortedTable &table; + +public: + void Schedule() override { + auto &context = pipeline->GetClientContext(); + + // Schedule tasks equal to the number of threads, which will each merge multiple partitions + auto &ts = TaskScheduler::GetScheduler(context); + idx_t num_threads = ts.NumberOfThreads(); + + vector> iejoin_tasks; + for (idx_t tnum = 0; tnum < num_threads; tnum++) { + iejoin_tasks.push_back(make_uniq(shared_from_this(), context, table)); + } + SetTasks(std::move(iejoin_tasks)); + } + + void FinishEvent() override { + auto &global_sort_state = table.global_sort_state; + + global_sort_state.CompleteMergeRound(true); + if (global_sort_state.sorted_blocks.size() > 1) { + // Multiple blocks remaining: Schedule the next round + table.ScheduleMergeTasks(*pipeline, *this); + } + } +}; + +void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { + // Initialize global sort state for a round of merging + global_sort_state.InitializeMergeRound(); + auto new_event = make_shared(*this, pipeline); + event.InsertEvent(std::move(new_event)); +} + +void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) { + // Prepare for merge sort phase + global_sort_state.PrepareMergePhase(); + + // Start the merge phase or finish if a merge is not necessary + if (global_sort_state.sorted_blocks.size() > 1) { + ScheduleMergeTasks(pipeline, event); + } +} + +PhysicalRangeJoin::PhysicalRangeJoin(LogicalComparisonJoin &op, PhysicalOperatorType type, + unique_ptr left, unique_ptr right, + vector cond, JoinType join_type, idx_t estimated_cardinality) + : PhysicalComparisonJoin(op, type, std::move(cond), join_type, estimated_cardinality) { + // Reorder the conditions so that ranges are at the front. + // TODO: use stats to improve the choice? + // TODO: Prefer fixed length types? + if (conditions.size() > 1) { + vector conditions_p(conditions.size()); + std::swap(conditions_p, conditions); + idx_t range_position = 0; + idx_t other_position = conditions_p.size(); + for (idx_t i = 0; i < conditions_p.size(); ++i) { + switch (conditions_p[i].comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + conditions[range_position++] = std::move(conditions_p[i]); + break; + default: + conditions[--other_position] = std::move(conditions_p[i]); + break; + } + } + } + + children.push_back(std::move(left)); + children.push_back(std::move(right)); + + // Fill out the left projection map. + left_projection_map = op.left_projection_map; + if (left_projection_map.empty()) { + const auto left_count = children[0]->types.size(); + left_projection_map.reserve(left_count); + for (column_t i = 0; i < left_count; ++i) { + left_projection_map.emplace_back(i); + } + } + // Fill out the right projection map. + right_projection_map = op.right_projection_map; + if (right_projection_map.empty()) { + const auto right_count = children[1]->types.size(); + right_projection_map.reserve(right_count); + for (column_t i = 0; i < right_count; ++i) { + right_projection_map.emplace_back(i); + } + } + + // Construct the unprojected type layout from the children's types + unprojected_types = children[0]->GetTypes(); + auto &types = children[1]->GetTypes(); + unprojected_types.insert(unprojected_types.end(), types.begin(), types.end()); +} + +idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(const vector &conditions) { + // Merge the validity masks of the comparison keys into the primary + // Return the number of NULLs in the resulting chunk + D_ASSERT(keys.ColumnCount() > 0); + const auto count = keys.size(); + + size_t all_constant = 0; + for (auto &v : keys.data) { + if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { + ++all_constant; + } + } + + auto &primary = keys.data[0]; + if (all_constant == keys.data.size()) { + // Either all NULL or no NULLs + for (auto &v : keys.data) { + if (ConstantVector::IsNull(v)) { + ConstantVector::SetNull(primary, true); + return count; + } + } + return 0; + } else if (keys.ColumnCount() > 1) { + // Flatten the primary, as it will need to merge arbitrary validity masks + primary.Flatten(count); + auto &pvalidity = FlatVector::Validity(primary); + D_ASSERT(keys.ColumnCount() == conditions.size()); + for (size_t c = 1; c < keys.data.size(); ++c) { + // Skip comparisons that accept NULLs + if (conditions[c].comparison == ExpressionType::COMPARE_DISTINCT_FROM) { + continue; + } + // ToUnifiedFormat the rest, as the sort code will do this anyway. + auto &v = keys.data[c]; + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(count, vdata); + auto &vvalidity = vdata.validity; + if (vvalidity.AllValid()) { + continue; + } + pvalidity.EnsureWritable(); + switch (v.GetVectorType()) { + case VectorType::FLAT_VECTOR: { + // Merge entire entries + auto pmask = pvalidity.GetData(); + const auto entry_count = pvalidity.EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) { + pmask[entry_idx] &= vvalidity.GetValidityEntry(entry_idx); + } + break; + } + case VectorType::CONSTANT_VECTOR: + // All or nothing + if (ConstantVector::IsNull(v)) { + pvalidity.SetAllInvalid(count); + return count; + } + break; + default: + // One by one + for (idx_t i = 0; i < count; ++i) { + const auto idx = vdata.sel->get_index(i); + if (!vvalidity.RowIsValidUnsafe(idx)) { + pvalidity.SetInvalidUnsafe(i); + } + } + break; + } + } + return count - pvalidity.CountValid(count); + } else { + return count - VectorOperations::CountNotNull(primary, count); + } +} + +void PhysicalRangeJoin::ProjectResult(DataChunk &chunk, DataChunk &result) const { + const auto left_projected = left_projection_map.size(); + for (idx_t i = 0; i < left_projected; ++i) { + result.data[i].Reference(chunk.data[left_projection_map[i]]); + } + const auto left_width = children[0]->types.size(); + for (idx_t i = 0; i < right_projection_map.size(); ++i) { + result.data[left_projected + i].Reference(chunk.data[left_width + right_projection_map[i]]); + } + result.SetCardinality(chunk); +} + +BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, + const SelectionVector &result, const idx_t result_count, + const idx_t left_cols) { + // There should only be one sorted block if they have been sorted + D_ASSERT(state.sorted_blocks.size() == 1); + SBScanState read_state(state.buffer_manager, state); + read_state.sb = state.sorted_blocks[0].get(); + auto &sorted_data = *read_state.sb->payload_data; + + read_state.SetIndices(block_idx, 0); + read_state.PinData(sorted_data); + const auto data_ptr = read_state.DataPtr(sorted_data); + data_ptr_t heap_ptr = nullptr; + + // Set up a batch of pointers to scan data from + Vector addresses(LogicalType::POINTER, result_count); + auto data_pointers = FlatVector::GetData(addresses); + + // Set up the data pointers for the values that are actually referenced + const idx_t &row_width = sorted_data.layout.GetRowWidth(); + + auto prev_idx = result.get_index(0); + SelectionVector gsel(result_count); + idx_t addr_count = 0; + gsel.set_index(0, addr_count); + data_pointers[addr_count] = data_ptr + prev_idx * row_width; + for (idx_t i = 1; i < result_count; ++i) { + const auto row_idx = result.get_index(i); + if (row_idx != prev_idx) { + data_pointers[++addr_count] = data_ptr + row_idx * row_width; + prev_idx = row_idx; + } + gsel.set_index(i, addr_count); + } + ++addr_count; + + // Unswizzle the offsets back to pointers (if needed) + if (!sorted_data.layout.AllConstant() && state.external) { + heap_ptr = read_state.payload_heap_handle.Ptr(); + } + + // Deserialize the payload data + auto sel = FlatVector::IncrementalSelectionVector(); + for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) { + auto &col = payload.data[left_cols + col_no]; + RowOperations::Gather(addresses, *sel, col, *sel, addr_count, sorted_data.layout, col_no, 0, heap_ptr); + col.Slice(gsel, result_count); + } + + return std::move(read_state.payload_heap_handle); +} + +idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, + const SelectionVector *sel, idx_t count, SelectionVector *true_sel) { + switch (condition) { + case ExpressionType::COMPARE_NOTEQUAL: + return VectorOperations::NotEquals(left, right, sel, count, true_sel, nullptr); + case ExpressionType::COMPARE_LESSTHAN: + return VectorOperations::LessThan(left, right, sel, count, true_sel, nullptr); + case ExpressionType::COMPARE_GREATERTHAN: + return VectorOperations::GreaterThan(left, right, sel, count, true_sel, nullptr); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, nullptr); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, nullptr); + case ExpressionType::COMPARE_DISTINCT_FROM: + return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, nullptr); + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, nullptr); + case ExpressionType::COMPARE_EQUAL: + return VectorOperations::Equals(left, right, sel, count, true_sel, nullptr); + default: + throw InternalException("Unsupported comparison type for PhysicalRangeJoin"); + } + + return count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/order/physical_order.cpp b/src/duckdb/src/execution/operator/order/physical_order.cpp new file mode 100644 index 00000000..3806ae8a --- /dev/null +++ b/src/duckdb/src/execution/operator/order/physical_order.cpp @@ -0,0 +1,282 @@ +#include "duckdb/execution/operator/order/physical_order.hpp" + +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +PhysicalOrder::PhysicalOrder(vector types, vector orders, vector projections, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::ORDER_BY, std::move(types), estimated_cardinality), + orders(std::move(orders)), projections(std::move(projections)) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class OrderGlobalSinkState : public GlobalSinkState { +public: + OrderGlobalSinkState(BufferManager &buffer_manager, const PhysicalOrder &order, RowLayout &payload_layout) + : global_sort_state(buffer_manager, order.orders, payload_layout) { + } + + //! Global sort state + GlobalSortState global_sort_state; + //! Memory usage per thread + idx_t memory_per_thread; +}; + +class OrderLocalSinkState : public LocalSinkState { +public: + OrderLocalSinkState(ClientContext &context, const PhysicalOrder &op) : key_executor(context) { + // Initialize order clause expression executor and DataChunk + vector key_types; + for (auto &order : op.orders) { + key_types.push_back(order.expression->return_type); + key_executor.AddExpression(*order.expression); + } + auto &allocator = Allocator::Get(context); + keys.Initialize(allocator, key_types); + payload.Initialize(allocator, op.types); + } + +public: + //! The local sort state + LocalSortState local_sort_state; + //! Key expression executor, and chunk to hold the vectors + ExpressionExecutor key_executor; + DataChunk keys; + //! Payload chunk to hold the vectors + DataChunk payload; +}; + +unique_ptr PhysicalOrder::GetGlobalSinkState(ClientContext &context) const { + // Get the payload layout from the return types + RowLayout payload_layout; + payload_layout.Initialize(types); + auto state = make_uniq(BufferManager::GetBufferManager(context), *this, payload_layout); + // Set external (can be force with the PRAGMA) + state->global_sort_state.external = ClientConfig::GetConfig(context).force_external; + state->memory_per_thread = GetMaxThreadMemory(context); + return std::move(state); +} + +unique_ptr PhysicalOrder::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, *this); +} + +SinkResultType PhysicalOrder::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + auto &global_sort_state = gstate.global_sort_state; + auto &local_sort_state = lstate.local_sort_state; + + // Initialize local state (if necessary) + if (!local_sort_state.initialized) { + local_sort_state.Initialize(global_sort_state, BufferManager::GetBufferManager(context.client)); + } + + // Obtain sorting columns + auto &keys = lstate.keys; + keys.Reset(); + lstate.key_executor.Execute(chunk, keys); + + auto &payload = lstate.payload; + payload.ReferenceColumns(chunk, projections); + + // Sink the data into the local sort state + keys.Verify(); + chunk.Verify(); + local_sort_state.SinkChunk(keys, payload); + + // When sorting data reaches a certain size, we sort it + if (local_sort_state.SizeInBytes() >= gstate.memory_per_thread) { + local_sort_state.Sort(global_sort_state, true); + } + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalOrder::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + gstate.global_sort_state.AddLocalState(lstate.local_sort_state); + + return SinkCombineResultType::FINISHED; +} + +class PhysicalOrderMergeTask : public ExecutorTask { +public: + PhysicalOrderMergeTask(shared_ptr event_p, ClientContext &context, OrderGlobalSinkState &state) + : ExecutorTask(context), event(std::move(event_p)), context(context), state(state) { + } + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + // Initialize merge sorted and iterate until done + auto &global_sort_state = state.global_sort_state; + MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); + merge_sorter.PerformInMergeRound(); + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; + } + +private: + shared_ptr event; + ClientContext &context; + OrderGlobalSinkState &state; +}; + +class OrderMergeEvent : public BasePipelineEvent { +public: + OrderMergeEvent(OrderGlobalSinkState &gstate_p, Pipeline &pipeline_p) + : BasePipelineEvent(pipeline_p), gstate(gstate_p) { + } + + OrderGlobalSinkState &gstate; + +public: + void Schedule() override { + auto &context = pipeline->GetClientContext(); + + // Schedule tasks equal to the number of threads, which will each merge multiple partitions + auto &ts = TaskScheduler::GetScheduler(context); + idx_t num_threads = ts.NumberOfThreads(); + + vector> merge_tasks; + for (idx_t tnum = 0; tnum < num_threads; tnum++) { + merge_tasks.push_back(make_uniq(shared_from_this(), context, gstate)); + } + SetTasks(std::move(merge_tasks)); + } + + void FinishEvent() override { + auto &global_sort_state = gstate.global_sort_state; + + global_sort_state.CompleteMergeRound(); + if (global_sort_state.sorted_blocks.size() > 1) { + // Multiple blocks remaining: Schedule the next round + PhysicalOrder::ScheduleMergeTasks(*pipeline, *this, gstate); + } + } +}; + +SinkFinalizeType PhysicalOrder::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &state = input.global_state.Cast(); + auto &global_sort_state = state.global_sort_state; + + if (global_sort_state.sorted_blocks.empty()) { + // Empty input! + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // Prepare for merge sort phase + global_sort_state.PrepareMergePhase(); + + // Start the merge phase or finish if a merge is not necessary + if (global_sort_state.sorted_blocks.size() > 1) { + PhysicalOrder::ScheduleMergeTasks(pipeline, event, state); + } + return SinkFinalizeType::READY; +} + +void PhysicalOrder::ScheduleMergeTasks(Pipeline &pipeline, Event &event, OrderGlobalSinkState &state) { + // Initialize global sort state for a round of merging + state.global_sort_state.InitializeMergeRound(); + auto new_event = make_shared(state, pipeline); + event.InsertEvent(std::move(new_event)); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class PhysicalOrderGlobalSourceState : public GlobalSourceState { +public: + explicit PhysicalOrderGlobalSourceState(OrderGlobalSinkState &sink) : next_batch_index(0) { + auto &global_sort_state = sink.global_sort_state; + if (global_sort_state.sorted_blocks.empty()) { + total_batches = 0; + } else { + D_ASSERT(global_sort_state.sorted_blocks.size() == 1); + total_batches = global_sort_state.sorted_blocks[0]->payload_data->data_blocks.size(); + } + } + + idx_t MaxThreads() override { + return total_batches; + } + +public: + atomic next_batch_index; + idx_t total_batches; +}; + +unique_ptr PhysicalOrder::GetGlobalSourceState(ClientContext &context) const { + auto &sink = this->sink_state->Cast(); + return make_uniq(sink); +} + +class PhysicalOrderLocalSourceState : public LocalSourceState { +public: + explicit PhysicalOrderLocalSourceState(PhysicalOrderGlobalSourceState &gstate) + : batch_index(gstate.next_batch_index++) { + } + +public: + idx_t batch_index; + unique_ptr scanner; +}; + +unique_ptr PhysicalOrder::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + return make_uniq(gstate); +} + +SourceResultType PhysicalOrder::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + if (lstate.scanner && lstate.scanner->Remaining() == 0) { + lstate.batch_index = gstate.next_batch_index++; + lstate.scanner = nullptr; + } + + if (lstate.batch_index >= gstate.total_batches) { + return SourceResultType::FINISHED; + } + + if (!lstate.scanner) { + auto &sink = this->sink_state->Cast(); + auto &global_sort_state = sink.global_sort_state; + lstate.scanner = make_uniq(global_sort_state, lstate.batch_index, true); + } + + lstate.scanner->Scan(chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +idx_t PhysicalOrder::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, + LocalSourceState &lstate_p) const { + auto &lstate = lstate_p.Cast(); + return lstate.batch_index; +} + +string PhysicalOrder::ParamsToString() const { + string result = "ORDERS:\n"; + for (idx_t i = 0; i < orders.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += orders[i].expression->ToString() + " "; + result += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/order/physical_top_n.cpp b/src/duckdb/src/execution/operator/order/physical_top_n.cpp new file mode 100644 index 00000000..7c8f6202 --- /dev/null +++ b/src/duckdb/src/execution/operator/order/physical_top_n.cpp @@ -0,0 +1,516 @@ +#include "duckdb/execution/operator/order/physical_top_n.hpp" + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/types/row/row_layout.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/data_table.hpp" + +namespace duckdb { + +PhysicalTopN::PhysicalTopN(vector types, vector orders, idx_t limit, idx_t offset, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::TOP_N, std::move(types), estimated_cardinality), orders(std::move(orders)), + limit(limit), offset(offset) { +} + +//===--------------------------------------------------------------------===// +// Heaps +//===--------------------------------------------------------------------===// +class TopNHeap; + +struct TopNScanState { + unique_ptr scanner; + idx_t pos; + bool exclude_offset; +}; + +class TopNSortState { +public: + explicit TopNSortState(TopNHeap &heap); + + TopNHeap &heap; + unique_ptr local_state; + unique_ptr global_state; + idx_t count; + bool is_sorted; + +public: + void Initialize(); + void Append(DataChunk &sort_chunk, DataChunk &payload); + + void Sink(DataChunk &input); + void Finalize(); + + void Move(TopNSortState &other); + + void InitializeScan(TopNScanState &state, bool exclude_offset); + void Scan(TopNScanState &state, DataChunk &chunk); +}; + +class TopNHeap { +public: + TopNHeap(ClientContext &context, const vector &payload_types, const vector &orders, + idx_t limit, idx_t offset); + TopNHeap(ExecutionContext &context, const vector &payload_types, + const vector &orders, idx_t limit, idx_t offset); + TopNHeap(ClientContext &context, Allocator &allocator, const vector &payload_types, + const vector &orders, idx_t limit, idx_t offset); + + Allocator &allocator; + BufferManager &buffer_manager; + const vector &payload_types; + const vector &orders; + idx_t limit; + idx_t offset; + TopNSortState sort_state; + ExpressionExecutor executor; + DataChunk sort_chunk; + DataChunk compare_chunk; + DataChunk payload_chunk; + //! A set of boundary values that determine either the minimum or the maximum value we have to consider for our + //! top-n + DataChunk boundary_values; + //! Whether or not the boundary_values has been set. The boundary_values are only set after a reduce step + bool has_boundary_values; + + SelectionVector final_sel; + SelectionVector true_sel; + SelectionVector false_sel; + SelectionVector new_remaining_sel; + +public: + void Sink(DataChunk &input); + void Combine(TopNHeap &other); + void Reduce(); + void Finalize(); + + void ExtractBoundaryValues(DataChunk ¤t_chunk, DataChunk &prev_chunk); + + void InitializeScan(TopNScanState &state, bool exclude_offset); + void Scan(TopNScanState &state, DataChunk &chunk); + + bool CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload); +}; + +//===--------------------------------------------------------------------===// +// TopNSortState +//===--------------------------------------------------------------------===// +TopNSortState::TopNSortState(TopNHeap &heap) : heap(heap), count(0), is_sorted(false) { +} + +void TopNSortState::Initialize() { + RowLayout layout; + layout.Initialize(heap.payload_types); + auto &buffer_manager = heap.buffer_manager; + global_state = make_uniq(buffer_manager, heap.orders, layout); + local_state = make_uniq(); + local_state->Initialize(*global_state, buffer_manager); +} + +void TopNSortState::Append(DataChunk &sort_chunk, DataChunk &payload) { + D_ASSERT(!is_sorted); + if (heap.has_boundary_values) { + if (!heap.CheckBoundaryValues(sort_chunk, payload)) { + return; + } + } + + local_state->SinkChunk(sort_chunk, payload); + count += payload.size(); +} + +void TopNSortState::Sink(DataChunk &input) { + // compute the ordering values for the new chunk + heap.sort_chunk.Reset(); + heap.executor.Execute(input, heap.sort_chunk); + + // append the new chunk to what we have already + Append(heap.sort_chunk, input); +} + +void TopNSortState::Move(TopNSortState &other) { + local_state = std::move(other.local_state); + global_state = std::move(other.global_state); + count = other.count; + is_sorted = other.is_sorted; +} + +void TopNSortState::Finalize() { + D_ASSERT(!is_sorted); + global_state->AddLocalState(*local_state); + + global_state->PrepareMergePhase(); + while (global_state->sorted_blocks.size() > 1) { + MergeSorter merge_sorter(*global_state, heap.buffer_manager); + merge_sorter.PerformInMergeRound(); + global_state->CompleteMergeRound(); + } + is_sorted = true; +} + +void TopNSortState::InitializeScan(TopNScanState &state, bool exclude_offset) { + D_ASSERT(is_sorted); + if (global_state->sorted_blocks.empty()) { + state.scanner = nullptr; + } else { + D_ASSERT(global_state->sorted_blocks.size() == 1); + state.scanner = make_uniq(*global_state->sorted_blocks[0]->payload_data, *global_state); + } + state.pos = 0; + state.exclude_offset = exclude_offset && heap.offset > 0; +} + +void TopNSortState::Scan(TopNScanState &state, DataChunk &chunk) { + if (!state.scanner) { + return; + } + auto offset = heap.offset; + auto limit = heap.limit; + D_ASSERT(is_sorted); + while (chunk.size() == 0) { + state.scanner->Scan(chunk); + if (chunk.size() == 0) { + break; + } + idx_t start = state.pos; + idx_t end = state.pos + chunk.size(); + state.pos = end; + + idx_t chunk_start = 0; + idx_t chunk_end = chunk.size(); + if (state.exclude_offset) { + // we need to exclude all tuples before the OFFSET + // check if we should include anything + if (end <= offset) { + // end is smaller than offset: include nothing! + chunk.Reset(); + continue; + } else if (start < offset) { + // we need to slice + chunk_start = offset - start; + } + } + // check if we need to truncate at the offset + limit mark + if (start >= offset + limit) { + // we are finished + chunk_end = 0; + } else if (end > offset + limit) { + // the end extends past the offset + limit + // truncate the current chunk + chunk_end = offset + limit - start; + } + D_ASSERT(chunk_end - chunk_start <= STANDARD_VECTOR_SIZE); + if (chunk_end == chunk_start) { + chunk.Reset(); + break; + } else if (chunk_start > 0) { + SelectionVector sel(STANDARD_VECTOR_SIZE); + for (idx_t i = chunk_start; i < chunk_end; i++) { + sel.set_index(i - chunk_start, i); + } + chunk.Slice(sel, chunk_end - chunk_start); + } else if (chunk_end != chunk.size()) { + chunk.SetCardinality(chunk_end); + } + } +} + +//===--------------------------------------------------------------------===// +// TopNHeap +//===--------------------------------------------------------------------===// +TopNHeap::TopNHeap(ClientContext &context, Allocator &allocator, const vector &payload_types_p, + const vector &orders_p, idx_t limit, idx_t offset) + : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), payload_types(payload_types_p), + orders(orders_p), limit(limit), offset(offset), sort_state(*this), executor(context), has_boundary_values(false), + final_sel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE), + new_remaining_sel(STANDARD_VECTOR_SIZE) { + // initialize the executor and the sort_chunk + vector sort_types; + for (auto &order : orders) { + auto &expr = order.expression; + sort_types.push_back(expr->return_type); + executor.AddExpression(*expr); + } + payload_chunk.Initialize(allocator, payload_types); + sort_chunk.Initialize(allocator, sort_types); + compare_chunk.Initialize(allocator, sort_types); + boundary_values.Initialize(allocator, sort_types); + sort_state.Initialize(); +} + +TopNHeap::TopNHeap(ClientContext &context, const vector &payload_types, + const vector &orders, idx_t limit, idx_t offset) + : TopNHeap(context, BufferAllocator::Get(context), payload_types, orders, limit, offset) { +} + +TopNHeap::TopNHeap(ExecutionContext &context, const vector &payload_types, + const vector &orders, idx_t limit, idx_t offset) + : TopNHeap(context.client, Allocator::Get(context.client), payload_types, orders, limit, offset) { +} + +void TopNHeap::Sink(DataChunk &input) { + sort_state.Sink(input); +} + +void TopNHeap::Combine(TopNHeap &other) { + other.Finalize(); + + TopNScanState state; + other.InitializeScan(state, false); + while (true) { + payload_chunk.Reset(); + other.Scan(state, payload_chunk); + if (payload_chunk.size() == 0) { + break; + } + Sink(payload_chunk); + } + Reduce(); +} + +void TopNHeap::Finalize() { + sort_state.Finalize(); +} + +void TopNHeap::Reduce() { + idx_t min_sort_threshold = MaxValue(STANDARD_VECTOR_SIZE * 5ULL, 2ULL * (limit + offset)); + if (sort_state.count < min_sort_threshold) { + // only reduce when we pass two times the limit + offset, or 5 vectors (whichever comes first) + return; + } + sort_state.Finalize(); + TopNSortState new_state(*this); + new_state.Initialize(); + + TopNScanState state; + sort_state.InitializeScan(state, false); + + DataChunk new_chunk; + new_chunk.Initialize(allocator, payload_types); + + DataChunk *current_chunk = &new_chunk; + DataChunk *prev_chunk = &payload_chunk; + has_boundary_values = false; + while (true) { + current_chunk->Reset(); + Scan(state, *current_chunk); + if (current_chunk->size() == 0) { + ExtractBoundaryValues(*current_chunk, *prev_chunk); + break; + } + new_state.Sink(*current_chunk); + std::swap(current_chunk, prev_chunk); + } + + sort_state.Move(new_state); +} + +void TopNHeap::ExtractBoundaryValues(DataChunk ¤t_chunk, DataChunk &prev_chunk) { + // extract the last entry of the prev_chunk and set as minimum value + D_ASSERT(prev_chunk.size() > 0); + for (idx_t col_idx = 0; col_idx < current_chunk.ColumnCount(); col_idx++) { + ConstantVector::Reference(current_chunk.data[col_idx], prev_chunk.data[col_idx], prev_chunk.size() - 1, + prev_chunk.size()); + } + current_chunk.SetCardinality(1); + sort_chunk.Reset(); + executor.Execute(¤t_chunk, sort_chunk); + + boundary_values.Reset(); + boundary_values.Append(sort_chunk); + boundary_values.SetCardinality(1); + for (idx_t i = 0; i < boundary_values.ColumnCount(); i++) { + boundary_values.data[i].SetVectorType(VectorType::CONSTANT_VECTOR); + } + has_boundary_values = true; +} + +bool TopNHeap::CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload) { + // we have boundary values + // from these boundary values, determine which values we should insert (if any) + idx_t final_count = 0; + + SelectionVector remaining_sel(nullptr); + idx_t remaining_count = sort_chunk.size(); + for (idx_t i = 0; i < orders.size(); i++) { + if (remaining_sel.data()) { + compare_chunk.data[i].Slice(sort_chunk.data[i], remaining_sel, remaining_count); + } else { + compare_chunk.data[i].Reference(sort_chunk.data[i]); + } + bool is_last = i + 1 == orders.size(); + idx_t true_count; + if (orders[i].null_order == OrderByNullType::NULLS_LAST) { + if (orders[i].type == OrderType::ASCENDING) { + true_count = VectorOperations::DistinctLessThan(compare_chunk.data[i], boundary_values.data[i], + &remaining_sel, remaining_count, &true_sel, &false_sel); + } else { + true_count = VectorOperations::DistinctGreaterThanNullsFirst(compare_chunk.data[i], + boundary_values.data[i], &remaining_sel, + remaining_count, &true_sel, &false_sel); + } + } else { + D_ASSERT(orders[i].null_order == OrderByNullType::NULLS_FIRST); + if (orders[i].type == OrderType::ASCENDING) { + true_count = VectorOperations::DistinctLessThanNullsFirst(compare_chunk.data[i], + boundary_values.data[i], &remaining_sel, + remaining_count, &true_sel, &false_sel); + } else { + true_count = + VectorOperations::DistinctGreaterThan(compare_chunk.data[i], boundary_values.data[i], + &remaining_sel, remaining_count, &true_sel, &false_sel); + } + } + + if (true_count > 0) { + memcpy(final_sel.data() + final_count, true_sel.data(), true_count * sizeof(sel_t)); + final_count += true_count; + } + idx_t false_count = remaining_count - true_count; + if (false_count > 0) { + // check what we should continue to check + compare_chunk.data[i].Slice(sort_chunk.data[i], false_sel, false_count); + remaining_count = VectorOperations::NotDistinctFrom(compare_chunk.data[i], boundary_values.data[i], + &false_sel, false_count, &new_remaining_sel, nullptr); + if (is_last) { + memcpy(final_sel.data() + final_count, new_remaining_sel.data(), remaining_count * sizeof(sel_t)); + final_count += remaining_count; + } else { + remaining_sel.Initialize(new_remaining_sel); + } + } else { + break; + } + } + if (final_count == 0) { + return false; + } + if (final_count < sort_chunk.size()) { + sort_chunk.Slice(final_sel, final_count); + payload.Slice(final_sel, final_count); + } + return true; +} + +void TopNHeap::InitializeScan(TopNScanState &state, bool exclude_offset) { + sort_state.InitializeScan(state, exclude_offset); +} + +void TopNHeap::Scan(TopNScanState &state, DataChunk &chunk) { + sort_state.Scan(state, chunk); +} + +class TopNGlobalState : public GlobalSinkState { +public: + TopNGlobalState(ClientContext &context, const vector &payload_types, + const vector &orders, idx_t limit, idx_t offset) + : heap(context, payload_types, orders, limit, offset) { + } + + mutex lock; + TopNHeap heap; +}; + +class TopNLocalState : public LocalSinkState { +public: + TopNLocalState(ExecutionContext &context, const vector &payload_types, + const vector &orders, idx_t limit, idx_t offset) + : heap(context, payload_types, orders, limit, offset) { + } + + TopNHeap heap; +}; + +unique_ptr PhysicalTopN::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context, types, orders, limit, offset); +} + +unique_ptr PhysicalTopN::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, types, orders, limit, offset); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +SinkResultType PhysicalTopN::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + // append to the local sink state + auto &sink = input.local_state.Cast(); + sink.heap.Sink(chunk); + sink.heap.Reduce(); + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Combine +//===--------------------------------------------------------------------===// +SinkCombineResultType PhysicalTopN::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + // scan the local top N and append it to the global heap + lock_guard glock(gstate.lock); + gstate.heap.Combine(lstate.heap); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +SinkFinalizeType PhysicalTopN::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + // global finalize: compute the final top N + gstate.heap.Finalize(); + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class TopNOperatorState : public GlobalSourceState { +public: + TopNScanState state; + bool initialized = false; +}; + +unique_ptr PhysicalTopN::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(); +} + +SourceResultType PhysicalTopN::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + if (limit == 0) { + return SourceResultType::FINISHED; + } + auto &state = input.global_state.Cast(); + auto &gstate = sink_state->Cast(); + + if (!state.initialized) { + gstate.heap.InitializeScan(state.state, true); + state.initialized = true; + } + gstate.heap.Scan(state.state, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +string PhysicalTopN::ParamsToString() const { + string result; + result += "Top " + to_string(limit); + if (offset > 0) { + result += "\n"; + result += "Offset " + to_string(offset); + } + result += "\n[INFOSEPARATOR]"; + for (idx_t i = 0; i < orders.size(); i++) { + result += "\n"; + result += orders[i].expression->ToString() + " "; + result += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp b/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp new file mode 100644 index 00000000..a96bb225 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp @@ -0,0 +1,48 @@ +#include "duckdb/main/appender.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { + +TableCatalogEntry &CSVRejectsTable::GetTable(ClientContext &context) { + auto &temp_catalog = Catalog::GetCatalog(context, TEMP_CATALOG); + auto &table_entry = temp_catalog.GetEntry(context, TEMP_CATALOG, DEFAULT_SCHEMA, name); + return table_entry; +} + +shared_ptr CSVRejectsTable::GetOrCreate(ClientContext &context, const string &name) { + auto key = "CSV_REJECTS_TABLE_CACHE_ENTRY_" + StringUtil::Upper(name); + auto &cache = ObjectCache::GetObjectCache(context); + return cache.GetOrCreate(key, name); +} + +void CSVRejectsTable::InitializeTable(ClientContext &context, const ReadCSVData &data) { + // (Re)Create the temporary rejects table + auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); + auto info = make_uniq(TEMP_CATALOG, DEFAULT_SCHEMA, name); + info->temporary = true; + info->on_conflict = OnCreateConflict::ERROR_ON_CONFLICT; + info->columns.AddColumn(ColumnDefinition("file", LogicalType::VARCHAR)); + info->columns.AddColumn(ColumnDefinition("line", LogicalType::BIGINT)); + info->columns.AddColumn(ColumnDefinition("column", LogicalType::BIGINT)); + info->columns.AddColumn(ColumnDefinition("column_name", LogicalType::VARCHAR)); + info->columns.AddColumn(ColumnDefinition("parsed_value", LogicalType::VARCHAR)); + + if (!data.options.rejects_recovery_columns.empty()) { + child_list_t recovery_key_components; + for (auto &col_name : data.options.rejects_recovery_columns) { + recovery_key_components.emplace_back(col_name, LogicalType::VARCHAR); + } + info->columns.AddColumn(ColumnDefinition("recovery_columns", LogicalType::STRUCT(recovery_key_components))); + } + + info->columns.AddColumn(ColumnDefinition("error", LogicalType::VARCHAR)); + + catalog.CreateTable(context, std::move(info)); + + count = 0; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp new file mode 100644 index 00000000..f951ddc4 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp @@ -0,0 +1,213 @@ +#include "duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp" + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/types/batched_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" + +#include + +namespace duckdb { + +PhysicalBatchCopyToFile::PhysicalBatchCopyToFile(vector types, CopyFunction function_p, + unique_ptr bind_data_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::BATCH_COPY_TO_FILE, std::move(types), estimated_cardinality), + function(std::move(function_p)), bind_data(std::move(bind_data_p)) { + if (!function.flush_batch || !function.prepare_batch) { + throw InternalException( + "PhysicalBatchCopyToFile created for copy function that does not have prepare_batch/flush_batch defined"); + } +} + +//===--------------------------------------------------------------------===// +// States +//===--------------------------------------------------------------------===// +class BatchCopyToGlobalState : public GlobalSinkState { +public: + explicit BatchCopyToGlobalState(unique_ptr global_state) + : rows_copied(0), global_state(std::move(global_state)), any_flushing(false) { + } + + mutex lock; + //! The total number of rows copied to the file + atomic rows_copied; + //! Global copy state + unique_ptr global_state; + //! The prepared batch data by batch index - ready to flush + map> batch_data; + //! Lock for flushing to disk + mutex flush_lock; + //! Whether or not any threads are flushing (only one thread can flush at a time) + atomic any_flushing; + + void AddBatchData(idx_t batch_index, unique_ptr new_batch) { + // move the batch data to the set of prepared batch data + lock_guard l(lock); + auto entry = batch_data.insert(make_pair(batch_index, std::move(new_batch))); + if (!entry.second) { + throw InternalException("Duplicate batch index %llu encountered in PhysicalBatchCopyToFile", batch_index); + } + } +}; + +class BatchCopyToLocalState : public LocalSinkState { +public: + explicit BatchCopyToLocalState(unique_ptr local_state_p) + : local_state(std::move(local_state_p)), rows_copied(0) { + } + + //! Local copy state + unique_ptr local_state; + //! The current collection we are appending to + unique_ptr collection; + //! The append state of the collection + ColumnDataAppendState append_state; + //! How many rows have been copied in total + idx_t rows_copied; + //! The current batch index + optional_idx batch_index; + + void InitializeCollection(ClientContext &context, const PhysicalOperator &op) { + collection = make_uniq(BufferAllocator::Get(context), op.children[0]->types); + collection->InitializeAppend(append_state); + } +}; + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +SinkResultType PhysicalBatchCopyToFile::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &state = input.local_state.Cast(); + if (!state.collection) { + state.InitializeCollection(context.client, *this); + state.batch_index = state.partition_info.batch_index.GetIndex(); + } + state.rows_copied += chunk.size(); + state.collection->Append(state.append_state, chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalBatchCopyToFile::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &state = input.local_state.Cast(); + auto &gstate = input.global_state.Cast(); + gstate.rows_copied += state.rows_copied; + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +SinkFinalizeType PhysicalBatchCopyToFile::FinalFlush(ClientContext &context, GlobalSinkState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + idx_t min_batch_index = idx_t(NumericLimits::Maximum()); + FlushBatchData(context, gstate_p, min_batch_index); + if (function.copy_to_finalize) { + function.copy_to_finalize(context, *bind_data, *gstate.global_state); + + if (use_tmp_file) { + PhysicalCopyToFile::MoveTmpFile(context, file_path); + } + } + return SinkFinalizeType::READY; +} + +SinkFinalizeType PhysicalBatchCopyToFile::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + FinalFlush(context, input.global_state); + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Batch Data Handling +//===--------------------------------------------------------------------===// +void PhysicalBatchCopyToFile::PrepareBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t batch_index, + unique_ptr collection) const { + auto &gstate = gstate_p.Cast(); + + // prepare the batch + auto batch_data = function.prepare_batch(context, *bind_data, *gstate.global_state, std::move(collection)); + gstate.AddBatchData(batch_index, std::move(batch_data)); +} + +void PhysicalBatchCopyToFile::FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index) const { + auto &gstate = gstate_p.Cast(); + + // flush batch data to disk (if there are any to flush) + // grab the flush lock - we can only call flush_batch with this lock + // otherwise the data might end up in the wrong order + { + lock_guard l(gstate.flush_lock); + if (gstate.any_flushing) { + return; + } + gstate.any_flushing = true; + } + ActiveFlushGuard active_flush(gstate.any_flushing); + while (true) { + unique_ptr batch_data; + { + // fetch the next batch to flush (if any) + lock_guard l(gstate.lock); + if (gstate.batch_data.empty()) { + // no batch data left to flush + break; + } + auto entry = gstate.batch_data.begin(); + if (entry->first >= min_index) { + // this data is past the min_index - we cannot write it yet + break; + } + if (!entry->second) { + // this batch is in process of being prepared but is not ready yet + break; + } + batch_data = std::move(entry->second); + gstate.batch_data.erase(entry); + } + function.flush_batch(context, *bind_data, *gstate.global_state, *batch_data); + } +} + +//===--------------------------------------------------------------------===// +// Next Batch +//===--------------------------------------------------------------------===// +void PhysicalBatchCopyToFile::NextBatch(ExecutionContext &context, GlobalSinkState &gstate_p, + LocalSinkState &lstate) const { + auto &state = lstate.Cast(); + if (state.collection && state.collection->Count() > 0) { + // we finished processing this batch + // start flushing data + auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); + PrepareBatchData(context.client, gstate_p, state.batch_index.GetIndex(), std::move(state.collection)); + FlushBatchData(context.client, gstate_p, min_batch_index); + } + state.batch_index = lstate.partition_info.batch_index.GetIndex(); + + state.InitializeCollection(context.client, *this); +} + +unique_ptr PhysicalBatchCopyToFile::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(function.copy_to_initialize_local(context, *bind_data)); +} + +unique_ptr PhysicalBatchCopyToFile::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(function.copy_to_initialize_global(context, *bind_data, file_path)); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalBatchCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &g = sink_state->Cast(); + + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp new file mode 100644 index 00000000..d807e0e9 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp @@ -0,0 +1,457 @@ +#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp" + +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table/row_group_collection.hpp" +#include "duckdb/storage/table_io_manager.hpp" +#include "duckdb/transaction/local_storage.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +PhysicalBatchInsert::PhysicalBatchInsert(vector types, TableCatalogEntry &table, + physical_index_vector_t column_index_map, + vector> bound_defaults, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::BATCH_INSERT, std::move(types), estimated_cardinality), + column_index_map(std::move(column_index_map)), insert_table(&table), insert_types(table.GetTypes()), + bound_defaults(std::move(bound_defaults)) { +} + +PhysicalBatchInsert::PhysicalBatchInsert(LogicalOperator &op, SchemaCatalogEntry &schema, + unique_ptr info_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::BATCH_CREATE_TABLE_AS, op.types, estimated_cardinality), + insert_table(nullptr), schema(&schema), info(std::move(info_p)) { + PhysicalInsert::GetInsertInfo(*info, insert_types, bound_defaults); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// + +class CollectionMerger { +public: + explicit CollectionMerger(ClientContext &context) : context(context) { + } + + ClientContext &context; + vector> current_collections; + +public: + void AddCollection(unique_ptr collection) { + current_collections.push_back(std::move(collection)); + } + + bool Empty() { + return current_collections.empty(); + } + + unique_ptr Flush(OptimisticDataWriter &writer) { + if (Empty()) { + return nullptr; + } + unique_ptr new_collection = std::move(current_collections[0]); + if (current_collections.size() > 1) { + // we have gathered multiple collections: create one big collection and merge that + auto &types = new_collection->GetTypes(); + TableAppendState append_state; + new_collection->InitializeAppend(append_state); + + DataChunk scan_chunk; + scan_chunk.Initialize(context, types); + + vector column_ids; + for (idx_t i = 0; i < types.size(); i++) { + column_ids.push_back(i); + } + for (auto &collection : current_collections) { + if (!collection) { + continue; + } + TableScanState scan_state; + scan_state.Initialize(column_ids); + collection->InitializeScan(scan_state.local_state, column_ids, nullptr); + + while (true) { + scan_chunk.Reset(); + scan_state.local_state.ScanCommitted(scan_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); + if (scan_chunk.size() == 0) { + break; + } + auto new_row_group = new_collection->Append(scan_chunk, append_state); + if (new_row_group) { + writer.WriteNewRowGroup(*new_collection); + } + } + } + new_collection->FinalizeAppend(TransactionData(0, 0), append_state); + writer.WriteLastRowGroup(*new_collection); + } + current_collections.clear(); + return new_collection; + } +}; + +enum class RowGroupBatchType : uint8_t { FLUSHED, NOT_FLUSHED }; +struct RowGroupBatchEntry { + RowGroupBatchEntry(idx_t batch_idx, unique_ptr collection_p, RowGroupBatchType type) + : batch_idx(batch_idx), total_rows(collection_p->GetTotalRows()), collection(std::move(collection_p)), + type(type) { + } + + idx_t batch_idx; + idx_t total_rows; + unique_ptr collection; + RowGroupBatchType type; +}; + +class BatchInsertGlobalState : public GlobalSinkState { +public: + static constexpr const idx_t BATCH_FLUSH_THRESHOLD = LocalStorage::MERGE_THRESHOLD * 3; + +public: + explicit BatchInsertGlobalState(DuckTableEntry &table) : table(table), insert_count(0) { + } + + mutex lock; + DuckTableEntry &table; + idx_t insert_count; + vector collections; + idx_t next_start = 0; + bool optimistically_written = false; + + void FindMergeCollections(idx_t min_batch_index, optional_idx &merged_batch_index, + vector> &result) { + bool merge = false; + idx_t start_index = next_start; + idx_t current_idx; + idx_t total_count = 0; + for (current_idx = start_index; current_idx < collections.size(); current_idx++) { + auto &entry = collections[current_idx]; + if (entry.batch_idx >= min_batch_index) { + // this entry is AFTER the min_batch_index + // we might still find new entries! + break; + } + if (entry.type == RowGroupBatchType::FLUSHED) { + // already flushed: cannot flush anything here + if (total_count > 0) { + merge = true; + break; + } + start_index = current_idx + 1; + if (start_index > next_start) { + // avoid checking this segment again in the future + next_start = start_index; + } + total_count = 0; + continue; + } + // not flushed - add to set of indexes to flush + total_count += entry.total_rows; + if (total_count >= BATCH_FLUSH_THRESHOLD) { + merge = true; + break; + } + } + if (merge && total_count > 0) { + D_ASSERT(current_idx > start_index); + merged_batch_index = collections[start_index].batch_idx; + for (idx_t idx = start_index; idx < current_idx; idx++) { + auto &entry = collections[idx]; + if (!entry.collection || entry.type == RowGroupBatchType::FLUSHED) { + throw InternalException("Adding a row group collection that should not be flushed"); + } + result.push_back(std::move(entry.collection)); + entry.total_rows = total_count; + entry.type = RowGroupBatchType::FLUSHED; + } + if (start_index + 1 < current_idx) { + // erase all entries except the first one + collections.erase(collections.begin() + start_index + 1, collections.begin() + current_idx); + } + } + } + + unique_ptr MergeCollections(ClientContext &context, + vector> merge_collections, + OptimisticDataWriter &writer) { + D_ASSERT(!merge_collections.empty()); + CollectionMerger merger(context); + for (auto &collection : merge_collections) { + merger.AddCollection(std::move(collection)); + } + optimistically_written = true; + return merger.Flush(writer); + } + + void AddCollection(ClientContext &context, idx_t batch_index, idx_t min_batch_index, + unique_ptr current_collection, + optional_ptr writer = nullptr, + optional_ptr written_to_disk = nullptr) { + if (batch_index < min_batch_index) { + throw InternalException( + "Batch index of the added collection (%llu) is smaller than the min batch index (%llu)", batch_index, + min_batch_index); + } + auto new_count = current_collection->GetTotalRows(); + auto batch_type = + new_count < Storage::ROW_GROUP_SIZE ? RowGroupBatchType::NOT_FLUSHED : RowGroupBatchType::FLUSHED; + if (batch_type == RowGroupBatchType::FLUSHED && writer) { + writer->WriteLastRowGroup(*current_collection); + } + optional_idx merged_batch_index; + vector> merge_collections; + { + lock_guard l(lock); + insert_count += new_count; + + // add the collection to the batch index + RowGroupBatchEntry new_entry(batch_index, std::move(current_collection), batch_type); + + auto it = std::lower_bound( + collections.begin(), collections.end(), new_entry, + [&](const RowGroupBatchEntry &a, const RowGroupBatchEntry &b) { return a.batch_idx < b.batch_idx; }); + if (it != collections.end() && it->batch_idx == new_entry.batch_idx) { + throw InternalException( + "PhysicalBatchInsert::AddCollection error: batch index %d is present in multiple " + "collections. This occurs when " + "batch indexes are not uniquely distributed over threads", + batch_index); + } + collections.insert(it, std::move(new_entry)); + if (writer) { + FindMergeCollections(min_batch_index, merged_batch_index, merge_collections); + } + } + if (!merge_collections.empty()) { + // merge together the collections + D_ASSERT(writer); + auto final_collection = MergeCollections(context, std::move(merge_collections), *writer); + if (written_to_disk) { + *written_to_disk = true; + } + // add the merged-together collection to the set of batch indexes + { + lock_guard l(lock); + RowGroupBatchEntry new_entry(merged_batch_index.GetIndex(), std::move(final_collection), + RowGroupBatchType::FLUSHED); + auto it = std::lower_bound(collections.begin(), collections.end(), new_entry, + [&](const RowGroupBatchEntry &a, const RowGroupBatchEntry &b) { + return a.batch_idx < b.batch_idx; + }); + if (it->batch_idx != merged_batch_index.GetIndex()) { + throw InternalException("Merged batch index was no longer present in collection"); + } + it->collection = std::move(new_entry.collection); + } + } + } +}; + +class BatchInsertLocalState : public LocalSinkState { +public: + BatchInsertLocalState(ClientContext &context, const vector &types, + const vector> &bound_defaults) + : default_executor(context, bound_defaults), written_to_disk(false) { + insert_chunk.Initialize(Allocator::Get(context), types); + } + + DataChunk insert_chunk; + ExpressionExecutor default_executor; + idx_t current_index; + TableAppendState current_append_state; + unique_ptr current_collection; + optional_ptr writer; + bool written_to_disk; + + void CreateNewCollection(DuckTableEntry &table, const vector &insert_types) { + auto &table_info = table.GetStorage().info; + auto &block_manager = TableIOManager::Get(table.GetStorage()).GetBlockManagerForRowData(); + current_collection = make_uniq(table_info, block_manager, insert_types, MAX_ROW_ID); + current_collection->InitializeEmpty(); + current_collection->InitializeAppend(current_append_state); + written_to_disk = false; + } +}; + +unique_ptr PhysicalBatchInsert::GetGlobalSinkState(ClientContext &context) const { + optional_ptr table; + if (info) { + // CREATE TABLE AS + D_ASSERT(!insert_table); + auto &catalog = schema->catalog; + auto created_table = catalog.CreateTable(catalog.GetCatalogTransaction(context), *schema.get_mutable(), *info); + table = &created_table->Cast(); + } else { + D_ASSERT(insert_table); + D_ASSERT(insert_table->IsDuckTable()); + table = insert_table.get_mutable(); + } + auto result = make_uniq(table->Cast()); + return std::move(result); +} + +unique_ptr PhysicalBatchInsert::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, insert_types, bound_defaults); +} + +void PhysicalBatchInsert::NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const { + auto &gstate = state.Cast(); + auto &lstate = lstate_p.Cast(); + + auto &table = gstate.table; + auto batch_index = lstate.partition_info.batch_index.GetIndex(); + if (lstate.current_collection) { + if (lstate.current_index == batch_index) { + throw InternalException("NextBatch called with the same batch index?"); + } + // batch index has changed: move the old collection to the global state and create a new collection + TransactionData tdata(0, 0); + lstate.current_collection->FinalizeAppend(tdata, lstate.current_append_state); + gstate.AddCollection(context.client, lstate.current_index, lstate.partition_info.min_batch_index.GetIndex(), + std::move(lstate.current_collection), lstate.writer, &lstate.written_to_disk); + lstate.CreateNewCollection(table, insert_types); + } + lstate.current_index = batch_index; +} + +SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + auto &table = gstate.table; + PhysicalInsert::ResolveDefaults(table, chunk, column_index_map, lstate.default_executor, lstate.insert_chunk); + + auto batch_index = lstate.partition_info.batch_index.GetIndex(); + if (!lstate.current_collection) { + lock_guard l(gstate.lock); + // no collection yet: create a new one + lstate.CreateNewCollection(table, insert_types); + lstate.writer = &table.GetStorage().CreateOptimisticWriter(context.client); + } + + if (lstate.current_index != batch_index) { + throw InternalException("Current batch differs from batch - but NextBatch was not called!?"); + } + + table.GetStorage().VerifyAppendConstraints(table, context.client, lstate.insert_chunk); + + auto new_row_group = lstate.current_collection->Append(lstate.insert_chunk, lstate.current_append_state); + if (new_row_group) { + // we have already written to disk - flush the next row group as well + lstate.writer->WriteNewRowGroup(*lstate.current_collection); + lstate.written_to_disk = true; + } + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalBatchInsert::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + auto &client_profiler = QueryProfiler::Get(context.client); + context.thread.profiler.Flush(*this, lstate.default_executor, "default_executor", 1); + client_profiler.Flush(context.thread.profiler); + + if (!lstate.current_collection) { + return SinkCombineResultType::FINISHED; + } + + if (lstate.current_collection->GetTotalRows() > 0) { + TransactionData tdata(0, 0); + lstate.current_collection->FinalizeAppend(tdata, lstate.current_append_state); + gstate.AddCollection(context.client, lstate.current_index, lstate.partition_info.min_batch_index.GetIndex(), + std::move(lstate.current_collection)); + } + { + lock_guard l(gstate.lock); + gstate.table.GetStorage().FinalizeOptimisticWriter(context.client, *lstate.writer); + } + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalBatchInsert::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + + if (gstate.optimistically_written || gstate.insert_count >= LocalStorage::MERGE_THRESHOLD) { + // we have written data to disk optimistically or are inserting a large amount of data + // perform a final pass over all of the row groups and merge them together + vector> mergers; + unique_ptr current_merger; + + auto &storage = gstate.table.GetStorage(); + for (auto &entry : gstate.collections) { + if (entry.type == RowGroupBatchType::NOT_FLUSHED) { + // this collection has not been flushed: add it to the merge set + if (!current_merger) { + current_merger = make_uniq(context); + } + current_merger->AddCollection(std::move(entry.collection)); + } else { + // this collection has been flushed: it does not need to be merged + // create a separate collection merger only for this entry + if (current_merger) { + // we have small collections remaining: flush them + mergers.push_back(std::move(current_merger)); + current_merger.reset(); + } + auto larger_merger = make_uniq(context); + larger_merger->AddCollection(std::move(entry.collection)); + mergers.push_back(std::move(larger_merger)); + } + } + if (current_merger) { + mergers.push_back(std::move(current_merger)); + } + + // now that we have created all of the mergers, perform the actual merging + vector> final_collections; + final_collections.reserve(mergers.size()); + auto &writer = storage.CreateOptimisticWriter(context); + for (auto &merger : mergers) { + final_collections.push_back(merger->Flush(writer)); + } + storage.FinalizeOptimisticWriter(context, writer); + + // finally, merge the row groups into the local storage + for (auto &collection : final_collections) { + storage.LocalMerge(context, *collection); + } + } else { + // we are writing a small amount of data to disk + // append directly to transaction local storage + auto &table = gstate.table; + auto &storage = table.GetStorage(); + LocalAppendState append_state; + storage.InitializeLocalAppend(append_state, context); + auto &transaction = DuckTransaction::Get(context, table.catalog); + for (auto &entry : gstate.collections) { + entry.collection->Scan(transaction, [&](DataChunk &insert_chunk) { + storage.LocalAppend(append_state, table, context, insert_chunk); + return true; + }); + } + storage.FinalizeLocalAppend(append_state); + } + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// + +SourceResultType PhysicalBatchInsert::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &insert_gstate = sink_state->Cast(); + + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.insert_count)); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp new file mode 100644 index 00000000..c192a4d9 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -0,0 +1,245 @@ +#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/hive_partitioning.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/types/uuid.hpp" +#include "duckdb/common/string_util.hpp" + +#include + +namespace duckdb { + +class CopyToFunctionGlobalState : public GlobalSinkState { +public: + explicit CopyToFunctionGlobalState(unique_ptr global_state) + : rows_copied(0), last_file_offset(0), global_state(std::move(global_state)) { + } + mutex lock; + idx_t rows_copied; + idx_t last_file_offset; + unique_ptr global_state; + + //! shared state for HivePartitionedColumnData + shared_ptr partition_state; +}; + +class CopyToFunctionLocalState : public LocalSinkState { +public: + explicit CopyToFunctionLocalState(unique_ptr local_state) + : local_state(std::move(local_state)), writer_offset(0) { + } + unique_ptr global_state; + unique_ptr local_state; + + //! Buffers the tuples in partitions before writing + unique_ptr part_buffer; + unique_ptr part_buffer_append_state; + + idx_t writer_offset; +}; + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// + +void PhysicalCopyToFile::MoveTmpFile(ClientContext &context, const string &tmp_file_path) { + auto &fs = FileSystem::GetFileSystem(context); + auto file_path = tmp_file_path.substr(0, tmp_file_path.length() - 4); + if (fs.FileExists(file_path)) { + fs.RemoveFile(file_path); + } + fs.MoveFile(tmp_file_path, file_path); +} + +PhysicalCopyToFile::PhysicalCopyToFile(vector types, CopyFunction function_p, + unique_ptr bind_data, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::COPY_TO_FILE, std::move(types), estimated_cardinality), + function(std::move(function_p)), bind_data(std::move(bind_data)), parallel(false) { +} + +SinkResultType PhysicalCopyToFile::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &g = input.global_state.Cast(); + auto &l = input.local_state.Cast(); + + if (partition_output) { + l.part_buffer->Append(*l.part_buffer_append_state, chunk); + return SinkResultType::NEED_MORE_INPUT; + } + + { + lock_guard glock(g.lock); + g.rows_copied += chunk.size(); + } + function.copy_to_sink(context, *bind_data, per_thread_output ? *l.global_state : *g.global_state, *l.local_state, + chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +static void CreateDir(const string &dir_path, FileSystem &fs) { + if (!fs.DirectoryExists(dir_path)) { + fs.CreateDirectory(dir_path); + } +} + +static string CreateDirRecursive(const vector &cols, const vector &names, const vector &values, + string path, FileSystem &fs) { + CreateDir(path, fs); + + for (idx_t i = 0; i < cols.size(); i++) { + const auto &partition_col_name = names[cols[i]]; + const auto &partition_value = values[i]; + string p_dir = partition_col_name + "=" + partition_value.ToString(); + path = fs.JoinPath(path, p_dir); + CreateDir(path, fs); + } + + return path; +} + +SinkCombineResultType PhysicalCopyToFile::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &g = input.global_state.Cast(); + auto &l = input.local_state.Cast(); + + if (partition_output) { + auto &fs = FileSystem::GetFileSystem(context.client); + l.part_buffer->FlushAppendState(*l.part_buffer_append_state); + auto &partitions = l.part_buffer->GetPartitions(); + auto partition_key_map = l.part_buffer->GetReverseMap(); + + string trimmed_path = file_path; + StringUtil::RTrim(trimmed_path, fs.PathSeparator(trimmed_path)); + + for (idx_t i = 0; i < partitions.size(); i++) { + string hive_path = + CreateDirRecursive(partition_columns, names, partition_key_map[i]->values, trimmed_path, fs); + string full_path(filename_pattern.CreateFilename(fs, hive_path, function.extension, l.writer_offset)); + if (fs.FileExists(full_path) && !overwrite_or_ignore) { + throw IOException("failed to create " + full_path + + ", file exists! Enable OVERWRITE_OR_IGNORE option to force writing"); + } + // Create a writer for the current file + auto fun_data_global = function.copy_to_initialize_global(context.client, *bind_data, full_path); + auto fun_data_local = function.copy_to_initialize_local(context, *bind_data); + + for (auto &chunk : partitions[i]->Chunks()) { + function.copy_to_sink(context, *bind_data, *fun_data_global, *fun_data_local, chunk); + } + + function.copy_to_combine(context, *bind_data, *fun_data_global, *fun_data_local); + function.copy_to_finalize(context.client, *bind_data, *fun_data_global); + } + + return SinkCombineResultType::FINISHED; + } + + if (function.copy_to_combine) { + function.copy_to_combine(context, *bind_data, per_thread_output ? *l.global_state : *g.global_state, + *l.local_state); + + if (per_thread_output) { + function.copy_to_finalize(context.client, *bind_data, *l.global_state); + } + } + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalCopyToFile::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + if (per_thread_output || partition_output) { + // already happened in combine + return SinkFinalizeType::READY; + } + if (function.copy_to_finalize) { + function.copy_to_finalize(context, *bind_data, *gstate.global_state); + + if (use_tmp_file) { + D_ASSERT(!per_thread_output); // FIXME + D_ASSERT(!partition_output); // FIXME + MoveTmpFile(context, file_path); + } + } + return SinkFinalizeType::READY; +} + +unique_ptr PhysicalCopyToFile::GetLocalSinkState(ExecutionContext &context) const { + if (partition_output) { + auto state = make_uniq(nullptr); + { + auto &g = sink_state->Cast(); + lock_guard glock(g.lock); + state->writer_offset = g.last_file_offset++; + + state->part_buffer = make_uniq(context.client, expected_types, partition_columns, + g.partition_state); + state->part_buffer_append_state = make_uniq(); + state->part_buffer->InitializeAppendState(*state->part_buffer_append_state); + } + return std::move(state); + } + auto res = make_uniq(function.copy_to_initialize_local(context, *bind_data)); + if (per_thread_output) { + idx_t this_file_offset; + { + auto &g = sink_state->Cast(); + lock_guard glock(g.lock); + this_file_offset = g.last_file_offset++; + } + auto &fs = FileSystem::GetFileSystem(context.client); + string output_path(filename_pattern.CreateFilename(fs, file_path, function.extension, this_file_offset)); + if (fs.FileExists(output_path) && !overwrite_or_ignore) { + throw IOException("%s exists! Enable OVERWRITE_OR_IGNORE option to force writing", output_path); + } + res->global_state = function.copy_to_initialize_global(context.client, *bind_data, output_path); + } + return std::move(res); +} + +unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext &context) const { + + if (partition_output || per_thread_output) { + auto &fs = FileSystem::GetFileSystem(context); + + if (fs.FileExists(file_path) && !overwrite_or_ignore) { + throw IOException("%s exists! Enable OVERWRITE_OR_IGNORE option to force writing", file_path); + } + if (!fs.DirectoryExists(file_path)) { + fs.CreateDirectory(file_path); + } else if (!overwrite_or_ignore) { + idx_t n_files = 0; + fs.ListFiles(file_path, [&n_files](const string &path, bool) { n_files++; }); + if (n_files > 0) { + throw IOException("Directory %s is not empty! Enable OVERWRITE_OR_IGNORE option to force writing", + file_path); + } + } + + auto state = make_uniq(nullptr); + + if (partition_output) { + state->partition_state = make_shared(); + } + + return std::move(state); + } + + return make_uniq(function.copy_to_initialize_global(context, *bind_data, file_path)); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// + +SourceResultType PhysicalCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &g = sink_state->Cast(); + + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp new file mode 100644 index 00000000..a3e0a4c1 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp @@ -0,0 +1,102 @@ +#include "duckdb/execution/operator/persistent/physical_delete.hpp" + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/transaction/duck_transaction.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class DeleteGlobalState : public GlobalSinkState { +public: + explicit DeleteGlobalState(ClientContext &context, const vector &return_types) + : deleted_count(0), return_collection(context, return_types) { + } + + mutex delete_lock; + idx_t deleted_count; + ColumnDataCollection return_collection; +}; + +class DeleteLocalState : public LocalSinkState { +public: + DeleteLocalState(Allocator &allocator, const vector &table_types) { + delete_chunk.Initialize(allocator, table_types); + } + DataChunk delete_chunk; +}; + +SinkResultType PhysicalDelete::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &ustate = input.local_state.Cast(); + + // get rows and + auto &transaction = DuckTransaction::Get(context.client, table.db); + auto &row_identifiers = chunk.data[row_id_index]; + + vector column_ids; + for (idx_t i = 0; i < table.column_definitions.size(); i++) { + column_ids.emplace_back(i); + }; + auto cfs = ColumnFetchState(); + + lock_guard delete_guard(gstate.delete_lock); + if (return_chunk) { + row_identifiers.Flatten(chunk.size()); + table.Fetch(transaction, ustate.delete_chunk, column_ids, row_identifiers, chunk.size(), cfs); + gstate.return_collection.Append(ustate.delete_chunk); + } + gstate.deleted_count += table.Delete(tableref, context.client, row_identifiers, chunk.size()); + + return SinkResultType::NEED_MORE_INPUT; +} + +unique_ptr PhysicalDelete::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, GetTypes()); +} + +unique_ptr PhysicalDelete::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(Allocator::Get(context.client), table.GetTypes()); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class DeleteSourceState : public GlobalSourceState { +public: + explicit DeleteSourceState(const PhysicalDelete &op) { + if (op.return_chunk) { + D_ASSERT(op.sink_state); + auto &g = op.sink_state->Cast(); + g.return_collection.InitializeScan(scan_state); + } + } + + ColumnDataScanState scan_state; +}; + +unique_ptr PhysicalDelete::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +SourceResultType PhysicalDelete::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &state = input.global_state.Cast(); + auto &g = sink_state->Cast(); + if (!return_chunk) { + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(g.deleted_count)); + return SourceResultType::FINISHED; + } + + g.return_collection.Scan(state.scan_state, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_export.cpp b/src/duckdb/src/execution/operator/persistent/physical_export.cpp new file mode 100644 index 00000000..a24c25ab --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_export.cpp @@ -0,0 +1,215 @@ +#include "duckdb/execution/operator/persistent/physical_export.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/transaction/transaction.hpp" + +#include +#include + +namespace duckdb { + +using std::stringstream; + +static void WriteCatalogEntries(stringstream &ss, vector> &entries) { + for (auto &entry : entries) { + if (entry.get().internal) { + continue; + } + ss << entry.get().ToSQL() << std::endl; + } + ss << std::endl; +} + +static void WriteStringStreamToFile(FileSystem &fs, stringstream &ss, const string &path) { + auto ss_string = ss.str(); + auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW, + FileLockType::WRITE_LOCK); + fs.Write(*handle, (void *)ss_string.c_str(), ss_string.size()); + handle.reset(); +} + +static void WriteValueAsSQL(stringstream &ss, Value &val) { + if (val.type().IsNumeric()) { + ss << val.ToString(); + } else { + ss << "'" << val.ToString() << "'"; + } +} + +static void WriteCopyStatement(FileSystem &fs, stringstream &ss, CopyInfo &info, ExportedTableData &exported_table, + CopyFunction const &function) { + ss << "COPY "; + + if (exported_table.schema_name != DEFAULT_SCHEMA) { + ss << KeywordHelper::WriteOptionallyQuoted(exported_table.schema_name) << "."; + } + + ss << StringUtil::Format("%s FROM %s (", SQLIdentifier(exported_table.table_name), + SQLString(exported_table.file_path)); + + // write the copy options + ss << "FORMAT '" << info.format << "'"; + if (info.format == "csv") { + // insert default csv options, if not specified + if (info.options.find("header") == info.options.end()) { + info.options["header"].push_back(Value::INTEGER(1)); + } + if (info.options.find("delimiter") == info.options.end() && info.options.find("sep") == info.options.end() && + info.options.find("delim") == info.options.end()) { + info.options["delimiter"].push_back(Value(",")); + } + if (info.options.find("quote") == info.options.end()) { + info.options["quote"].push_back(Value("\"")); + } + } + for (auto ©_option : info.options) { + if (copy_option.first == "force_quote") { + continue; + } + ss << ", " << copy_option.first << " "; + if (copy_option.second.size() == 1) { + WriteValueAsSQL(ss, copy_option.second[0]); + } else { + // FIXME handle multiple options + throw NotImplementedException("FIXME: serialize list of options"); + } + } + ss << ");" << std::endl; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class ExportSourceState : public GlobalSourceState { +public: + ExportSourceState() : finished(false) { + } + + bool finished; +}; + +unique_ptr PhysicalExport::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(); +} + +SourceResultType PhysicalExport::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &state = input.global_state.Cast(); + if (state.finished) { + return SourceResultType::FINISHED; + } + + auto &ccontext = context.client; + auto &fs = FileSystem::GetFileSystem(ccontext); + + // gather all catalog types to export + vector> schemas; + vector> custom_types; + vector> sequences; + vector> tables; + vector> views; + vector> indexes; + vector> macros; + + auto schema_list = Catalog::GetSchemas(ccontext, info->catalog); + for (auto &schema_p : schema_list) { + auto &schema = schema_p.get(); + if (!schema.internal) { + schemas.push_back(schema); + } + schema.Scan(context.client, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { + if (entry.internal) { + return; + } + if (entry.type != CatalogType::TABLE_ENTRY) { + views.push_back(entry); + } + }); + schema.Scan(context.client, CatalogType::SEQUENCE_ENTRY, + [&](CatalogEntry &entry) { sequences.push_back(entry); }); + schema.Scan(context.client, CatalogType::TYPE_ENTRY, + [&](CatalogEntry &entry) { custom_types.push_back(entry); }); + schema.Scan(context.client, CatalogType::INDEX_ENTRY, [&](CatalogEntry &entry) { indexes.push_back(entry); }); + schema.Scan(context.client, CatalogType::MACRO_ENTRY, [&](CatalogEntry &entry) { + if (!entry.internal && entry.type == CatalogType::MACRO_ENTRY) { + macros.push_back(entry); + } + }); + schema.Scan(context.client, CatalogType::TABLE_MACRO_ENTRY, [&](CatalogEntry &entry) { + if (!entry.internal && entry.type == CatalogType::TABLE_MACRO_ENTRY) { + macros.push_back(entry); + } + }); + } + + // consider the order of tables because of foreign key constraint + for (idx_t i = 0; i < exported_tables.data.size(); i++) { + tables.push_back(exported_tables.data[i].entry); + } + + // order macro's by timestamp so nested macro's are imported nicely + sort(macros.begin(), macros.end(), [](const reference &lhs, const reference &rhs) { + return lhs.get().oid < rhs.get().oid; + }); + + // write the schema.sql file + // export order is SCHEMA -> SEQUENCE -> TABLE -> VIEW -> INDEX + + stringstream ss; + WriteCatalogEntries(ss, schemas); + WriteCatalogEntries(ss, custom_types); + WriteCatalogEntries(ss, sequences); + WriteCatalogEntries(ss, tables); + WriteCatalogEntries(ss, views); + WriteCatalogEntries(ss, indexes); + WriteCatalogEntries(ss, macros); + + WriteStringStreamToFile(fs, ss, fs.JoinPath(info->file_path, "schema.sql")); + + // write the load.sql file + // for every table, we write COPY INTO statement with the specified options + stringstream load_ss; + for (idx_t i = 0; i < exported_tables.data.size(); i++) { + auto exported_table_info = exported_tables.data[i].table_data; + WriteCopyStatement(fs, load_ss, *info, exported_table_info, function); + } + WriteStringStreamToFile(fs, load_ss, fs.JoinPath(info->file_path, "load.sql")); + state.finished = true; + + return SourceResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +SinkResultType PhysicalExport::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + // nop + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalExport::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + // EXPORT has an optional child + // we only need to schedule child pipelines if there is a child + auto &state = meta_pipeline.GetState(); + state.SetPipelineSource(current, *this); + if (children.empty()) { + return; + } + PhysicalOperator::BuildPipelines(current, meta_pipeline); +} + +vector> PhysicalExport::GetSources() const { + return {*this}; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_fixed_batch_copy.cpp b/src/duckdb/src/execution/operator/persistent/physical_fixed_batch_copy.cpp new file mode 100644 index 00000000..27597c45 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_fixed_batch_copy.cpp @@ -0,0 +1,496 @@ +#include "duckdb/execution/operator/persistent/physical_fixed_batch_copy.hpp" +#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/types/batched_data_collection.hpp" +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp" + +#include + +namespace duckdb { + +PhysicalFixedBatchCopy::PhysicalFixedBatchCopy(vector types, CopyFunction function_p, + unique_ptr bind_data_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::BATCH_COPY_TO_FILE, std::move(types), estimated_cardinality), + function(std::move(function_p)), bind_data(std::move(bind_data_p)) { + if (!function.flush_batch || !function.prepare_batch || !function.desired_batch_size) { + throw InternalException("PhysicalFixedBatchCopy created for copy function that does not have " + "prepare_batch/flush_batch/desired_batch_size defined"); + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class BatchCopyTask { +public: + virtual ~BatchCopyTask() { + } + + virtual void Execute(const PhysicalFixedBatchCopy &op, ClientContext &context, GlobalSinkState &gstate_p) = 0; +}; + +//===--------------------------------------------------------------------===// +// States +//===--------------------------------------------------------------------===// +class FixedBatchCopyGlobalState : public GlobalSinkState { +public: + explicit FixedBatchCopyGlobalState(unique_ptr global_state) + : rows_copied(0), global_state(std::move(global_state)), batch_size(0), scheduled_batch_index(0), + flushed_batch_index(0), any_flushing(false), any_finished(false) { + } + + mutex lock; + mutex flush_lock; + //! The total number of rows copied to the file + atomic rows_copied; + //! Global copy state + unique_ptr global_state; + //! The desired batch size (if any) + idx_t batch_size; + //! Unpartitioned batches - only used in case batch_size is required + map> raw_batches; + //! The prepared batch data by batch index - ready to flush + map> batch_data; + //! The index of the latest batch index that has been scheduled + atomic scheduled_batch_index; + //! The index of the latest batch index that has been flushed + atomic flushed_batch_index; + //! Whether or not any thread is flushing + atomic any_flushing; + //! Whether or not any threads are finished + atomic any_finished; + + void AddTask(unique_ptr task) { + lock_guard l(task_lock); + task_queue.push(std::move(task)); + } + + unique_ptr GetTask() { + lock_guard l(task_lock); + if (task_queue.empty()) { + return nullptr; + } + auto entry = std::move(task_queue.front()); + task_queue.pop(); + return entry; + } + + idx_t TaskCount() { + lock_guard l(task_lock); + return task_queue.size(); + } + + void AddBatchData(idx_t batch_index, unique_ptr new_batch) { + // move the batch data to the set of prepared batch data + lock_guard l(lock); + auto entry = batch_data.insert(make_pair(batch_index, std::move(new_batch))); + if (!entry.second) { + throw InternalException("Duplicate batch index %llu encountered in PhysicalFixedBatchCopy", batch_index); + } + } + +private: + mutex task_lock; + //! The task queue for the batch copy to file + queue> task_queue; +}; + +class FixedBatchCopyLocalState : public LocalSinkState { +public: + explicit FixedBatchCopyLocalState(unique_ptr local_state_p) + : local_state(std::move(local_state_p)), rows_copied(0) { + } + + //! Local copy state + unique_ptr local_state; + //! The current collection we are appending to + unique_ptr collection; + //! The append state of the collection + ColumnDataAppendState append_state; + //! How many rows have been copied in total + idx_t rows_copied; + //! The current batch index + optional_idx batch_index; + + void InitializeCollection(ClientContext &context, const PhysicalOperator &op) { + collection = make_uniq(BufferAllocator::Get(context), op.children[0]->types); + collection->InitializeAppend(append_state); + } +}; + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +SinkResultType PhysicalFixedBatchCopy::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &state = input.local_state.Cast(); + if (!state.collection) { + state.InitializeCollection(context.client, *this); + state.batch_index = state.partition_info.batch_index.GetIndex(); + } + state.rows_copied += chunk.size(); + state.collection->Append(state.append_state, chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalFixedBatchCopy::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &state = input.local_state.Cast(); + auto &gstate = input.global_state.Cast(); + gstate.rows_copied += state.rows_copied; + if (!gstate.any_finished) { + // signal that this thread is finished processing batches and that we should move on to Finalize + lock_guard l(gstate.lock); + gstate.any_finished = true; + } + ExecuteTasks(context.client, gstate); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// ProcessRemainingBatchesEvent +//===--------------------------------------------------------------------===// +class ProcessRemainingBatchesTask : public ExecutorTask { +public: + ProcessRemainingBatchesTask(Executor &executor, shared_ptr event_p, FixedBatchCopyGlobalState &state_p, + ClientContext &context, const PhysicalFixedBatchCopy &op) + : ExecutorTask(executor), event(std::move(event_p)), op(op), gstate(state_p), context(context) { + } + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + while (op.ExecuteTask(context, gstate)) { + op.FlushBatchData(context, gstate, 0); + } + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; + } + +private: + shared_ptr event; + const PhysicalFixedBatchCopy &op; + FixedBatchCopyGlobalState &gstate; + ClientContext &context; +}; + +class ProcessRemainingBatchesEvent : public BasePipelineEvent { +public: + ProcessRemainingBatchesEvent(const PhysicalFixedBatchCopy &op_p, FixedBatchCopyGlobalState &gstate_p, + Pipeline &pipeline_p, ClientContext &context) + : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), context(context) { + } + const PhysicalFixedBatchCopy &op; + FixedBatchCopyGlobalState &gstate; + ClientContext &context; + +public: + void Schedule() override { + vector> tasks; + for (idx_t i = 0; i < idx_t(TaskScheduler::GetScheduler(context).NumberOfThreads()); i++) { + auto process_task = + make_uniq(pipeline->executor, shared_from_this(), gstate, context, op); + tasks.push_back(std::move(process_task)); + } + D_ASSERT(!tasks.empty()); + SetTasks(std::move(tasks)); + } + + void FinishEvent() override { + //! Now that all batches are processed we finish flushing the file to disk + op.FinalFlush(context, gstate); + } +}; +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +SinkFinalizeType PhysicalFixedBatchCopy::FinalFlush(ClientContext &context, GlobalSinkState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + if (gstate.TaskCount() != 0) { + throw InternalException("Unexecuted tasks are remaining in PhysicalFixedBatchCopy::FinalFlush!?"); + } + idx_t min_batch_index = idx_t(NumericLimits::Maximum()); + FlushBatchData(context, gstate_p, min_batch_index); + if (gstate.scheduled_batch_index != gstate.flushed_batch_index) { + throw InternalException("Not all batches were flushed to disk - incomplete file?"); + } + if (function.copy_to_finalize) { + function.copy_to_finalize(context, *bind_data, *gstate.global_state); + + if (use_tmp_file) { + PhysicalCopyToFile::MoveTmpFile(context, file_path); + } + } + return SinkFinalizeType::READY; +} + +SinkFinalizeType PhysicalFixedBatchCopy::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + idx_t min_batch_index = idx_t(NumericLimits::Maximum()); + // repartition any remaining batches + RepartitionBatches(context, input.global_state, min_batch_index, true); + // check if we have multiple tasks to execute + if (gstate.TaskCount() <= 1) { + // we don't - just execute the remaining task and finish flushing to disk + ExecuteTasks(context, input.global_state); + FinalFlush(context, input.global_state); + return SinkFinalizeType::READY; + } + // we have multiple tasks remaining - launch an event to execute the tasks in parallel + auto new_event = make_shared(*this, gstate, pipeline, context); + event.InsertEvent(std::move(new_event)); + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Tasks +//===--------------------------------------------------------------------===// +class RepartitionedFlushTask : public BatchCopyTask { +public: + RepartitionedFlushTask() { + } + + void Execute(const PhysicalFixedBatchCopy &op, ClientContext &context, GlobalSinkState &gstate_p) override { + op.FlushBatchData(context, gstate_p, 0); + } +}; + +class PrepareBatchTask : public BatchCopyTask { +public: + PrepareBatchTask(idx_t batch_index, unique_ptr collection_p) + : batch_index(batch_index), collection(std::move(collection_p)) { + } + + idx_t batch_index; + unique_ptr collection; + + void Execute(const PhysicalFixedBatchCopy &op, ClientContext &context, GlobalSinkState &gstate_p) override { + auto &gstate = gstate_p.Cast(); + auto batch_data = + op.function.prepare_batch(context, *op.bind_data, *gstate.global_state, std::move(collection)); + gstate.AddBatchData(batch_index, std::move(batch_data)); + if (batch_index == gstate.flushed_batch_index) { + gstate.AddTask(make_uniq()); + } + } +}; + +//===--------------------------------------------------------------------===// +// Batch Data Handling +//===--------------------------------------------------------------------===// +void PhysicalFixedBatchCopy::AddRawBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t batch_index, + unique_ptr collection) const { + auto &gstate = gstate_p.Cast(); + + // add the batch index to the set of raw batches + lock_guard l(gstate.lock); + auto entry = gstate.raw_batches.insert(make_pair(batch_index, std::move(collection))); + if (!entry.second) { + throw InternalException("Duplicate batch index %llu encountered in PhysicalFixedBatchCopy", batch_index); + } +} + +static bool CorrectSizeForBatch(idx_t collection_size, idx_t desired_size) { + return idx_t(AbsValue(int64_t(collection_size) - int64_t(desired_size))) < STANDARD_VECTOR_SIZE; +} + +void PhysicalFixedBatchCopy::RepartitionBatches(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index, + bool final) const { + auto &gstate = gstate_p.Cast(); + + // repartition batches until the min index is reached + lock_guard l(gstate.lock); + if (gstate.raw_batches.empty()) { + return; + } + if (!final) { + if (gstate.any_finished) { + // we only repartition in ::NextBatch if all threads are still busy processing batches + // otherwise we might end up repartitioning a lot of data with only a few threads remaining + // which causes erratic performance + return; + } + // if this is not the final flush we first check if we have enough data to merge past the batch threshold + idx_t candidate_rows = 0; + for (auto entry = gstate.raw_batches.begin(); entry != gstate.raw_batches.end(); entry++) { + if (entry->first >= min_index) { + // we have exceeded the minimum batch + break; + } + candidate_rows += entry->second->Count(); + } + if (candidate_rows < gstate.batch_size) { + // not enough rows - cancel! + return; + } + } + // gather all collections we can repartition + idx_t max_batch_index = 0; + vector> collections; + for (auto entry = gstate.raw_batches.begin(); entry != gstate.raw_batches.end();) { + if (entry->first >= min_index) { + break; + } + max_batch_index = entry->first; + collections.push_back(std::move(entry->second)); + entry = gstate.raw_batches.erase(entry); + } + unique_ptr current_collection; + ColumnDataAppendState append_state; + // now perform the actual repartitioning + for (auto &collection : collections) { + if (!current_collection) { + if (CorrectSizeForBatch(collection->Count(), gstate.batch_size)) { + // the collection is ~approximately equal to the batch size (off by at most one vector) + // use it directly + gstate.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(collection))); + collection.reset(); + } else if (collection->Count() < gstate.batch_size) { + // the collection is smaller than the batch size - use it as a starting point + current_collection = std::move(collection); + collection.reset(); + } else { + // the collection is too large for a batch - we need to repartition + // create an empty collection + current_collection = make_uniq(BufferAllocator::Get(context), children[0]->types); + } + if (current_collection) { + current_collection->InitializeAppend(append_state); + } + } + if (!collection) { + // we have consumed the collection already - no need to append + continue; + } + // iterate the collection while appending + for (auto &chunk : collection->Chunks()) { + // append the chunk to the collection + current_collection->Append(append_state, chunk); + if (current_collection->Count() < gstate.batch_size) { + // the collection is still under the batch size - continue + continue; + } + // the collection is full - move it to the result and create a new one + gstate.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(current_collection))); + current_collection = make_uniq(BufferAllocator::Get(context), children[0]->types); + current_collection->InitializeAppend(append_state); + } + } + if (current_collection && current_collection->Count() > 0) { + // if there are any remaining batches that are not filled up to the batch size + // AND this is not the final collection + // re-add it to the set of raw (to-be-merged) batches + if (final || CorrectSizeForBatch(current_collection->Count(), gstate.batch_size)) { + gstate.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(current_collection))); + } else { + gstate.raw_batches[max_batch_index] = std::move(current_collection); + } + } +} + +void PhysicalFixedBatchCopy::FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index) const { + auto &gstate = gstate_p.Cast(); + + // flush batch data to disk (if there are any to flush) + // grab the flush lock - we can only call flush_batch with this lock + // otherwise the data might end up in the wrong order + { + lock_guard l(gstate.flush_lock); + if (gstate.any_flushing) { + return; + } + gstate.any_flushing = true; + } + ActiveFlushGuard active_flush(gstate.any_flushing); + while (true) { + unique_ptr batch_data; + { + lock_guard l(gstate.lock); + if (gstate.batch_data.empty()) { + // no batch data left to flush + break; + } + auto entry = gstate.batch_data.begin(); + if (entry->first != gstate.flushed_batch_index) { + // this entry is not yet ready to be flushed + break; + } + if (entry->first < gstate.flushed_batch_index) { + throw InternalException("Batch index was out of order!?"); + } + batch_data = std::move(entry->second); + gstate.batch_data.erase(entry); + } + function.flush_batch(context, *bind_data, *gstate.global_state, *batch_data); + gstate.flushed_batch_index++; + } +} + +//===--------------------------------------------------------------------===// +// Tasks +//===--------------------------------------------------------------------===// +bool PhysicalFixedBatchCopy::ExecuteTask(ClientContext &context, GlobalSinkState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + auto task = gstate.GetTask(); + if (!task) { + return false; + } + task->Execute(*this, context, gstate_p); + return true; +} + +void PhysicalFixedBatchCopy::ExecuteTasks(ClientContext &context, GlobalSinkState &gstate_p) const { + while (ExecuteTask(context, gstate_p)) { + } +} + +//===--------------------------------------------------------------------===// +// Next Batch +//===--------------------------------------------------------------------===// +void PhysicalFixedBatchCopy::NextBatch(ExecutionContext &context, GlobalSinkState &gstate_p, + LocalSinkState &lstate) const { + auto &state = lstate.Cast(); + if (state.collection && state.collection->Count() > 0) { + // we finished processing this batch + // start flushing data + auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); + // push the raw batch data into the set of unprocessed batches + AddRawBatchData(context.client, gstate_p, state.batch_index.GetIndex(), std::move(state.collection)); + // attempt to repartition to our desired batch size + RepartitionBatches(context.client, gstate_p, min_batch_index); + // execute a single batch task + ExecuteTask(context.client, gstate_p); + FlushBatchData(context.client, gstate_p, min_batch_index); + } + state.batch_index = lstate.partition_info.batch_index.GetIndex(); + + state.InitializeCollection(context.client, *this); +} + +unique_ptr PhysicalFixedBatchCopy::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(function.copy_to_initialize_local(context, *bind_data)); +} + +unique_ptr PhysicalFixedBatchCopy::GetGlobalSinkState(ClientContext &context) const { + auto result = + make_uniq(function.copy_to_initialize_global(context, *bind_data, file_path)); + result->batch_size = function.desired_batch_size(context, *bind_data); + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalFixedBatchCopy::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &g = sink_state->Cast(); + + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp new file mode 100644 index 00000000..8e91ae20 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp @@ -0,0 +1,550 @@ +#include "duckdb/execution/operator/persistent/physical_insert.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/storage/table_io_manager.hpp" +#include "duckdb/transaction/local_storage.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/common/types/conflict_manager.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/storage/table/append_state.hpp" + +namespace duckdb { + +PhysicalInsert::PhysicalInsert(vector types_p, TableCatalogEntry &table, + physical_index_vector_t column_index_map, + vector> bound_defaults, + vector> set_expressions, vector set_columns, + vector set_types, idx_t estimated_cardinality, bool return_chunk, + bool parallel, OnConflictAction action_type, + unique_ptr on_conflict_condition_p, + unique_ptr do_update_condition_p, unordered_set conflict_target_p, + vector columns_to_fetch_p) + : PhysicalOperator(PhysicalOperatorType::INSERT, std::move(types_p), estimated_cardinality), + column_index_map(std::move(column_index_map)), insert_table(&table), insert_types(table.GetTypes()), + bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk), parallel(parallel), + action_type(action_type), set_expressions(std::move(set_expressions)), set_columns(std::move(set_columns)), + set_types(std::move(set_types)), on_conflict_condition(std::move(on_conflict_condition_p)), + do_update_condition(std::move(do_update_condition_p)), conflict_target(std::move(conflict_target_p)), + columns_to_fetch(std::move(columns_to_fetch_p)) { + + if (action_type == OnConflictAction::THROW) { + return; + } + + D_ASSERT(this->set_expressions.size() == this->set_columns.size()); + + // One or more columns are referenced from the existing table, + // we use the 'insert_types' to figure out which types these columns have + types_to_fetch = vector(columns_to_fetch.size(), LogicalType::SQLNULL); + for (idx_t i = 0; i < columns_to_fetch.size(); i++) { + auto &id = columns_to_fetch[i]; + D_ASSERT(id < insert_types.size()); + types_to_fetch[i] = insert_types[id]; + } +} + +PhysicalInsert::PhysicalInsert(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info_p, + idx_t estimated_cardinality, bool parallel) + : PhysicalOperator(PhysicalOperatorType::CREATE_TABLE_AS, op.types, estimated_cardinality), insert_table(nullptr), + return_chunk(false), schema(&schema), info(std::move(info_p)), parallel(parallel), + action_type(OnConflictAction::THROW) { + GetInsertInfo(*info, insert_types, bound_defaults); +} + +void PhysicalInsert::GetInsertInfo(const BoundCreateTableInfo &info, vector &insert_types, + vector> &bound_defaults) { + auto &create_info = info.base->Cast(); + for (auto &col : create_info.columns.Physical()) { + insert_types.push_back(col.GetType()); + bound_defaults.push_back(make_uniq(Value(col.GetType()))); + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class InsertGlobalState : public GlobalSinkState { +public: + explicit InsertGlobalState(ClientContext &context, const vector &return_types, DuckTableEntry &table) + : table(table), insert_count(0), initialized(false), return_collection(context, return_types) { + } + + mutex lock; + DuckTableEntry &table; + idx_t insert_count; + bool initialized; + LocalAppendState append_state; + ColumnDataCollection return_collection; +}; + +class InsertLocalState : public LocalSinkState { +public: + InsertLocalState(ClientContext &context, const vector &types, + const vector> &bound_defaults) + : default_executor(context, bound_defaults) { + insert_chunk.Initialize(Allocator::Get(context), types); + } + + DataChunk insert_chunk; + ExpressionExecutor default_executor; + TableAppendState local_append_state; + unique_ptr local_collection; + optional_ptr writer; + // Rows that have been updated by a DO UPDATE conflict + unordered_set updated_global_rows; + // Rows in the transaction-local storage that have been updated by a DO UPDATE conflict + unordered_set updated_local_rows; + idx_t update_count = 0; +}; + +unique_ptr PhysicalInsert::GetGlobalSinkState(ClientContext &context) const { + optional_ptr table; + if (info) { + // CREATE TABLE AS + D_ASSERT(!insert_table); + auto &catalog = schema->catalog; + table = &catalog.CreateTable(catalog.GetCatalogTransaction(context), *schema.get_mutable(), *info) + ->Cast(); + } else { + D_ASSERT(insert_table); + D_ASSERT(insert_table->IsDuckTable()); + table = insert_table.get_mutable(); + } + auto result = make_uniq(context, GetTypes(), table->Cast()); + return std::move(result); +} + +unique_ptr PhysicalInsert::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, insert_types, bound_defaults); +} + +void PhysicalInsert::ResolveDefaults(const TableCatalogEntry &table, DataChunk &chunk, + const physical_index_vector_t &column_index_map, + ExpressionExecutor &default_executor, DataChunk &result) { + chunk.Flatten(); + default_executor.SetChunk(chunk); + + result.Reset(); + result.SetCardinality(chunk); + + if (!column_index_map.empty()) { + // columns specified by the user, use column_index_map + for (auto &col : table.GetColumns().Physical()) { + auto storage_idx = col.StorageOid(); + auto mapped_index = column_index_map[col.Physical()]; + if (mapped_index == DConstants::INVALID_INDEX) { + // insert default value + default_executor.ExecuteExpression(storage_idx, result.data[storage_idx]); + } else { + // get value from child chunk + D_ASSERT((idx_t)mapped_index < chunk.ColumnCount()); + D_ASSERT(result.data[storage_idx].GetType() == chunk.data[mapped_index].GetType()); + result.data[storage_idx].Reference(chunk.data[mapped_index]); + } + } + } else { + // no columns specified, just append directly + for (idx_t i = 0; i < result.ColumnCount(); i++) { + D_ASSERT(result.data[i].GetType() == chunk.data[i].GetType()); + result.data[i].Reference(chunk.data[i]); + } + } +} + +bool AllConflictsMeetCondition(DataChunk &result) { + auto data = FlatVector::GetData(result.data[0]); + for (idx_t i = 0; i < result.size(); i++) { + if (!data[i]) { + return false; + } + } + return true; +} + +void CheckOnConflictCondition(ExecutionContext &context, DataChunk &conflicts, const unique_ptr &condition, + DataChunk &result) { + ExpressionExecutor executor(context.client, *condition); + result.Initialize(context.client, {LogicalType::BOOLEAN}); + executor.Execute(conflicts, result); + result.SetCardinality(conflicts.size()); +} + +static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_chunk, DataChunk &input_chunk, + ClientContext &client, const PhysicalInsert &op) { + auto &types_to_fetch = op.types_to_fetch; + auto &insert_types = op.insert_types; + + if (types_to_fetch.empty()) { + // We have not scanned the initial table, so we can just duplicate the initial chunk + result.Initialize(client, input_chunk.GetTypes()); + result.Reference(input_chunk); + result.SetCardinality(input_chunk); + return; + } + vector combined_types; + combined_types.reserve(insert_types.size() + types_to_fetch.size()); + combined_types.insert(combined_types.end(), insert_types.begin(), insert_types.end()); + combined_types.insert(combined_types.end(), types_to_fetch.begin(), types_to_fetch.end()); + + result.Initialize(client, combined_types); + result.Reset(); + // Add the VALUES list + for (idx_t i = 0; i < insert_types.size(); i++) { + idx_t col_idx = i; + auto &other_col = input_chunk.data[i]; + auto &this_col = result.data[col_idx]; + D_ASSERT(other_col.GetType() == this_col.GetType()); + this_col.Reference(other_col); + } + // Add the columns from the original conflicting tuples + for (idx_t i = 0; i < types_to_fetch.size(); i++) { + idx_t col_idx = i + insert_types.size(); + auto &other_col = scan_chunk.data[i]; + auto &this_col = result.data[col_idx]; + D_ASSERT(other_col.GetType() == this_col.GetType()); + this_col.Reference(other_col); + } + // This is guaranteed by the requirement of a conflict target to have a condition or set expressions + // Only when we have any sort of condition or SET expression that references the existing table is this possible + // to not be true. + // We can have a SET expression without a conflict target ONLY if there is only 1 Index on the table + // In which case this also can't cause a discrepancy between existing tuple count and insert tuple count + D_ASSERT(input_chunk.size() == scan_chunk.size()); + result.SetCardinality(input_chunk.size()); +} + +static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, Vector &row_ids, + DataChunk &update_chunk, const PhysicalInsert &op) { + + auto &do_update_condition = op.do_update_condition; + auto &set_types = op.set_types; + auto &set_expressions = op.set_expressions; + // Check the optional condition for the DO UPDATE clause, to filter which rows will be updated + if (do_update_condition) { + DataChunk do_update_filter_result; + do_update_filter_result.Initialize(context.client, {LogicalType::BOOLEAN}); + ExpressionExecutor where_executor(context.client, *do_update_condition); + where_executor.Execute(chunk, do_update_filter_result); + do_update_filter_result.SetCardinality(chunk.size()); + + ManagedSelection selection(chunk.size()); + + auto where_data = FlatVector::GetData(do_update_filter_result.data[0]); + for (idx_t i = 0; i < chunk.size(); i++) { + if (where_data[i]) { + selection.Append(i); + } + } + if (selection.Count() != selection.Size()) { + // Not all conflicts met the condition, need to filter out the ones that don't + chunk.Slice(selection.Selection(), selection.Count()); + chunk.SetCardinality(selection.Count()); + // Also apply this Slice to the to-update row_ids + row_ids.Slice(selection.Selection(), selection.Count()); + } + } + + // Execute the SET expressions + update_chunk.Initialize(context.client, set_types); + ExpressionExecutor executor(context.client, set_expressions); + executor.Execute(chunk, update_chunk); + update_chunk.SetCardinality(chunk); +} + +template +static idx_t PerformOnConflictAction(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, + Vector &row_ids, const PhysicalInsert &op) { + + if (op.action_type == OnConflictAction::NOTHING) { + return 0; + } + auto &set_columns = op.set_columns; + + DataChunk update_chunk; + CreateUpdateChunk(context, chunk, table, row_ids, update_chunk, op); + + auto &data_table = table.GetStorage(); + // Perform the update, using the results of the SET expressions + if (GLOBAL) { + data_table.Update(table, context.client, row_ids, set_columns, update_chunk); + } else { + auto &local_storage = LocalStorage::Get(context.client, data_table.db); + // Perform the update, using the results of the SET expressions + local_storage.Update(data_table, row_ids, set_columns, update_chunk); + } + return update_chunk.size(); +} + +// TODO: should we use a hash table to keep track of this instead? +template +static void RegisterUpdatedRows(InsertLocalState &lstate, const Vector &row_ids, idx_t count) { + // Insert all rows, if any of the rows has already been updated before, we throw an error + auto data = FlatVector::GetData(row_ids); + + // The rowids in the transaction-local ART aren't final yet so we have to separately keep track of the two sets of + // rowids + unordered_set &updated_rows = GLOBAL ? lstate.updated_global_rows : lstate.updated_local_rows; + for (idx_t i = 0; i < count; i++) { + auto result = updated_rows.insert(data[i]); + if (result.second == false) { + throw InvalidInputException( + "ON CONFLICT DO UPDATE can not update the same row twice in the same command, Ensure that no rows " + "proposed for insertion within the same command have duplicate constrained values"); + } + } +} + +template +static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &context, InsertLocalState &lstate, + DataTable &data_table, const PhysicalInsert &op) { + auto &types_to_fetch = op.types_to_fetch; + auto &on_conflict_condition = op.on_conflict_condition; + auto &conflict_target = op.conflict_target; + auto &columns_to_fetch = op.columns_to_fetch; + + auto &local_storage = LocalStorage::Get(context.client, data_table.db); + + // We either want to do nothing, or perform an update when conflicts arise + ConflictInfo conflict_info(conflict_target); + ConflictManager conflict_manager(VerifyExistenceType::APPEND, lstate.insert_chunk.size(), &conflict_info); + if (GLOBAL) { + data_table.VerifyAppendConstraints(table, context.client, lstate.insert_chunk, &conflict_manager); + } else { + DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, lstate.insert_chunk, + &conflict_manager); + } + conflict_manager.Finalize(); + if (conflict_manager.ConflictCount() == 0) { + // No conflicts found, 0 updates performed + return 0; + } + auto &conflicts = conflict_manager.Conflicts(); + auto &row_ids = conflict_manager.RowIds(); + + DataChunk conflict_chunk; // contains only the conflicting values + DataChunk scan_chunk; // contains the original values, that caused the conflict + DataChunk combined_chunk; // contains conflict_chunk + scan_chunk (wide) + + // Filter out everything but the conflicting rows + conflict_chunk.Initialize(context.client, lstate.insert_chunk.GetTypes()); + conflict_chunk.Reference(lstate.insert_chunk); + conflict_chunk.Slice(conflicts.Selection(), conflicts.Count()); + conflict_chunk.SetCardinality(conflicts.Count()); + + // Holds the pins for the fetched rows + unique_ptr fetch_state; + if (!types_to_fetch.empty()) { + D_ASSERT(scan_chunk.size() == 0); + // When these values are required for the conditions or the SET expressions, + // then we scan the existing table for the conflicting tuples, using the rowids + scan_chunk.Initialize(context.client, types_to_fetch); + fetch_state = make_uniq(); + if (GLOBAL) { + auto &transaction = DuckTransaction::Get(context.client, table.catalog); + data_table.Fetch(transaction, scan_chunk, columns_to_fetch, row_ids, conflicts.Count(), *fetch_state); + } else { + local_storage.FetchChunk(data_table, row_ids, conflicts.Count(), columns_to_fetch, scan_chunk, + *fetch_state); + } + } + + // Splice the Input chunk and the fetched chunk together + CombineExistingAndInsertTuples(combined_chunk, scan_chunk, conflict_chunk, context.client, op); + + if (on_conflict_condition) { + DataChunk conflict_condition_result; + CheckOnConflictCondition(context, combined_chunk, on_conflict_condition, conflict_condition_result); + bool conditions_met = AllConflictsMeetCondition(conflict_condition_result); + if (!conditions_met) { + // Filter out the tuples that did pass the filter, then run the verify again + ManagedSelection sel(combined_chunk.size()); + auto data = FlatVector::GetData(conflict_condition_result.data[0]); + for (idx_t i = 0; i < combined_chunk.size(); i++) { + if (!data[i]) { + // Only populate the selection vector with the tuples that did not meet the condition + sel.Append(i); + } + } + combined_chunk.Slice(sel.Selection(), sel.Count()); + row_ids.Slice(sel.Selection(), sel.Count()); + if (GLOBAL) { + data_table.VerifyAppendConstraints(table, context.client, combined_chunk, nullptr); + } else { + DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, + lstate.insert_chunk, nullptr); + } + throw InternalException("The previous operation was expected to throw but didn't"); + } + } + + RegisterUpdatedRows(lstate, row_ids, combined_chunk.size()); + + idx_t updated_tuples = PerformOnConflictAction(context, combined_chunk, table, row_ids, op); + + // Remove the conflicting tuples from the insert chunk + SelectionVector sel_vec(lstate.insert_chunk.size()); + idx_t new_size = + SelectionVector::Inverted(conflicts.Selection(), sel_vec, conflicts.Count(), lstate.insert_chunk.size()); + lstate.insert_chunk.Slice(sel_vec, new_size); + lstate.insert_chunk.SetCardinality(new_size); + return updated_tuples; +} + +idx_t PhysicalInsert::OnConflictHandling(TableCatalogEntry &table, ExecutionContext &context, + InsertLocalState &lstate) const { + auto &data_table = table.GetStorage(); + if (action_type == OnConflictAction::THROW) { + data_table.VerifyAppendConstraints(table, context.client, lstate.insert_chunk, nullptr); + return 0; + } + // Check whether any conflicts arise, and if they all meet the conflict_target + condition + // If that's not the case - We throw the first error + idx_t updated_tuples = 0; + updated_tuples += HandleInsertConflicts(table, context, lstate, data_table, *this); + // Also check the transaction-local storage+ART so we can detect conflicts within this transaction + updated_tuples += HandleInsertConflicts(table, context, lstate, data_table, *this); + + return updated_tuples; +} + +SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + auto &table = gstate.table; + auto &storage = table.GetStorage(); + PhysicalInsert::ResolveDefaults(table, chunk, column_index_map, lstate.default_executor, lstate.insert_chunk); + + if (!parallel) { + if (!gstate.initialized) { + storage.InitializeLocalAppend(gstate.append_state, context.client); + gstate.initialized = true; + } + + idx_t updated_tuples = OnConflictHandling(table, context, lstate); + gstate.insert_count += lstate.insert_chunk.size(); + gstate.insert_count += updated_tuples; + storage.LocalAppend(gstate.append_state, table, context.client, lstate.insert_chunk, true); + + if (return_chunk) { + gstate.return_collection.Append(lstate.insert_chunk); + } + } else { + D_ASSERT(!return_chunk); + // parallel append + if (!lstate.local_collection) { + lock_guard l(gstate.lock); + auto &table_info = storage.info; + auto &block_manager = TableIOManager::Get(storage).GetBlockManagerForRowData(); + lstate.local_collection = + make_uniq(table_info, block_manager, insert_types, MAX_ROW_ID); + lstate.local_collection->InitializeEmpty(); + lstate.local_collection->InitializeAppend(lstate.local_append_state); + lstate.writer = &gstate.table.GetStorage().CreateOptimisticWriter(context.client); + } + OnConflictHandling(table, context, lstate); + + auto new_row_group = lstate.local_collection->Append(lstate.insert_chunk, lstate.local_append_state); + if (new_row_group) { + lstate.writer->WriteNewRowGroup(*lstate.local_collection); + } + } + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalInsert::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + auto &client_profiler = QueryProfiler::Get(context.client); + context.thread.profiler.Flush(*this, lstate.default_executor, "default_executor", 1); + client_profiler.Flush(context.thread.profiler); + + if (!parallel || !lstate.local_collection) { + return SinkCombineResultType::FINISHED; + } + + // parallel append: finalize the append + TransactionData tdata(0, 0); + lstate.local_collection->FinalizeAppend(tdata, lstate.local_append_state); + + auto append_count = lstate.local_collection->GetTotalRows(); + + lock_guard lock(gstate.lock); + gstate.insert_count += append_count; + if (append_count < Storage::ROW_GROUP_SIZE) { + // we have few rows - append to the local storage directly + auto &table = gstate.table; + auto &storage = table.GetStorage(); + storage.InitializeLocalAppend(gstate.append_state, context.client); + auto &transaction = DuckTransaction::Get(context.client, table.catalog); + lstate.local_collection->Scan(transaction, [&](DataChunk &insert_chunk) { + storage.LocalAppend(gstate.append_state, table, context.client, insert_chunk); + return true; + }); + storage.FinalizeLocalAppend(gstate.append_state); + } else { + // we have written rows to disk optimistically - merge directly into the transaction-local storage + gstate.table.GetStorage().FinalizeOptimisticWriter(context.client, *lstate.writer); + gstate.table.GetStorage().LocalMerge(context.client, *lstate.local_collection); + } + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalInsert::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + if (!parallel && gstate.initialized) { + auto &table = gstate.table; + auto &storage = table.GetStorage(); + storage.FinalizeLocalAppend(gstate.append_state); + } + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class InsertSourceState : public GlobalSourceState { +public: + explicit InsertSourceState(const PhysicalInsert &op) { + if (op.return_chunk) { + D_ASSERT(op.sink_state); + auto &g = op.sink_state->Cast(); + g.return_collection.InitializeScan(scan_state); + } + } + + ColumnDataScanState scan_state; +}; + +unique_ptr PhysicalInsert::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +SourceResultType PhysicalInsert::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &state = input.global_state.Cast(); + auto &insert_gstate = sink_state->Cast(); + if (!return_chunk) { + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.insert_count)); + return SourceResultType::FINISHED; + } + + insert_gstate.return_collection.Scan(state.scan_state, chunk); + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_update.cpp b/src/duckdb/src/execution/operator/persistent/physical_update.cpp new file mode 100644 index 00000000..fb7f1c30 --- /dev/null +++ b/src/duckdb/src/execution/operator/persistent/physical_update.cpp @@ -0,0 +1,187 @@ +#include "duckdb/execution/operator/persistent/physical_update.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/storage/data_table.hpp" + +namespace duckdb { + +PhysicalUpdate::PhysicalUpdate(vector types, TableCatalogEntry &tableref, DataTable &table, + vector columns, vector> expressions, + vector> bound_defaults, idx_t estimated_cardinality, + bool return_chunk) + : PhysicalOperator(PhysicalOperatorType::UPDATE, std::move(types), estimated_cardinality), tableref(tableref), + table(table), columns(std::move(columns)), expressions(std::move(expressions)), + bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class UpdateGlobalState : public GlobalSinkState { +public: + explicit UpdateGlobalState(ClientContext &context, const vector &return_types) + : updated_count(0), return_collection(context, return_types) { + } + + mutex lock; + idx_t updated_count; + unordered_set updated_columns; + ColumnDataCollection return_collection; +}; + +class UpdateLocalState : public LocalSinkState { +public: + UpdateLocalState(ClientContext &context, const vector> &expressions, + const vector &table_types, const vector> &bound_defaults) + : default_executor(context, bound_defaults) { + // initialize the update chunk + auto &allocator = Allocator::Get(context); + vector update_types; + update_types.reserve(expressions.size()); + for (auto &expr : expressions) { + update_types.push_back(expr->return_type); + } + update_chunk.Initialize(allocator, update_types); + // initialize the mock chunk + mock_chunk.Initialize(allocator, table_types); + } + + DataChunk update_chunk; + DataChunk mock_chunk; + ExpressionExecutor default_executor; +}; + +SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + DataChunk &update_chunk = lstate.update_chunk; + DataChunk &mock_chunk = lstate.mock_chunk; + + chunk.Flatten(); + lstate.default_executor.SetChunk(chunk); + + // update data in the base table + // the row ids are given to us as the last column of the child chunk + auto &row_ids = chunk.data[chunk.ColumnCount() - 1]; + update_chunk.Reset(); + update_chunk.SetCardinality(chunk); + + for (idx_t i = 0; i < expressions.size(); i++) { + if (expressions[i]->type == ExpressionType::VALUE_DEFAULT) { + // default expression, set to the default value of the column + lstate.default_executor.ExecuteExpression(columns[i].index, update_chunk.data[i]); + } else { + D_ASSERT(expressions[i]->type == ExpressionType::BOUND_REF); + // index into child chunk + auto &binding = expressions[i]->Cast(); + update_chunk.data[i].Reference(chunk.data[binding.index]); + } + } + + lock_guard glock(gstate.lock); + if (update_is_del_and_insert) { + // index update or update on complex type, perform a delete and an append instead + + // figure out which rows have not yet been deleted in this update + // this is required since we might see the same row_id multiple times + // in the case of an UPDATE query that e.g. has joins + auto row_id_data = FlatVector::GetData(row_ids); + SelectionVector sel(STANDARD_VECTOR_SIZE); + idx_t update_count = 0; + for (idx_t i = 0; i < update_chunk.size(); i++) { + auto row_id = row_id_data[i]; + if (gstate.updated_columns.find(row_id) == gstate.updated_columns.end()) { + gstate.updated_columns.insert(row_id); + sel.set_index(update_count++, i); + } + } + if (update_count != update_chunk.size()) { + // we need to slice here + update_chunk.Slice(sel, update_count); + } + table.Delete(tableref, context.client, row_ids, update_chunk.size()); + // for the append we need to arrange the columns in a specific manner (namely the "standard table order") + mock_chunk.SetCardinality(update_chunk); + for (idx_t i = 0; i < columns.size(); i++) { + mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); + } + table.LocalAppend(tableref, context.client, mock_chunk); + } else { + if (return_chunk) { + mock_chunk.SetCardinality(update_chunk); + for (idx_t i = 0; i < columns.size(); i++) { + mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); + } + } + table.Update(tableref, context.client, row_ids, columns, update_chunk); + } + + if (return_chunk) { + gstate.return_collection.Append(mock_chunk); + } + + gstate.updated_count += chunk.size(); + + return SinkResultType::NEED_MORE_INPUT; +} + +unique_ptr PhysicalUpdate::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, GetTypes()); +} + +unique_ptr PhysicalUpdate::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, expressions, table.GetTypes(), bound_defaults); +} + +SinkCombineResultType PhysicalUpdate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + auto &state = input.local_state.Cast(); + auto &client_profiler = QueryProfiler::Get(context.client); + context.thread.profiler.Flush(*this, state.default_executor, "default_executor", 1); + client_profiler.Flush(context.thread.profiler); + + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +class UpdateSourceState : public GlobalSourceState { +public: + explicit UpdateSourceState(const PhysicalUpdate &op) { + if (op.return_chunk) { + D_ASSERT(op.sink_state); + auto &g = op.sink_state->Cast(); + g.return_collection.InitializeScan(scan_state); + } + } + + ColumnDataScanState scan_state; +}; + +unique_ptr PhysicalUpdate::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(*this); +} + +SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &state = input.global_state.Cast(); + auto &g = sink_state->Cast(); + if (!return_chunk) { + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(g.updated_count)); + return SourceResultType::FINISHED; + } + + g.return_collection.Scan(state.scan_state, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp new file mode 100644 index 00000000..69d255bf --- /dev/null +++ b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp @@ -0,0 +1,82 @@ +#include "duckdb/execution/operator/projection/physical_pivot.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +PhysicalPivot::PhysicalPivot(vector types_p, unique_ptr child, + BoundPivotInfo bound_pivot_p) + : PhysicalOperator(PhysicalOperatorType::PIVOT, std::move(types_p), child->estimated_cardinality), + bound_pivot(std::move(bound_pivot_p)) { + children.push_back(std::move(child)); + for (idx_t p = 0; p < bound_pivot.pivot_values.size(); p++) { + auto entry = pivot_map.find(bound_pivot.pivot_values[p]); + if (entry != pivot_map.end()) { + continue; + } + pivot_map[bound_pivot.pivot_values[p]] = bound_pivot.group_count + p; + } + // extract the empty aggregate expressions + ArenaAllocator allocator(Allocator::DefaultAllocator()); + for (auto &aggr_expr : bound_pivot.aggregates) { + auto &aggr = aggr_expr->Cast(); + // for each aggregate, initialize an empty aggregate state and finalize it immediately + auto state = make_unsafe_uniq_array(aggr.function.state_size()); + aggr.function.initialize(state.get()); + Vector state_vector(Value::POINTER(CastPointerToValue(state.get()))); + Vector result_vector(aggr_expr->return_type); + AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + aggr.function.finalize(state_vector, aggr_input_data, result_vector, 1, 0); + empty_aggregates.push_back(result_vector.GetValue(0)); + } +} + +OperatorResultType PhysicalPivot::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const { + // copy the groups as-is + for (idx_t i = 0; i < bound_pivot.group_count; i++) { + chunk.data[i].Reference(input.data[i]); + } + auto pivot_column_lists = FlatVector::GetData(input.data.back()); + auto &pivot_column_values = ListVector::GetEntry(input.data.back()); + auto pivot_columns = FlatVector::GetData(pivot_column_values); + + // initialize all aggregate columns with the empty aggregate value + // if there are multiple aggregates the columns are in order of [AGGR1][AGGR2][AGGR1][AGGR2] + // so we need to alternate the empty_aggregate that we use + idx_t aggregate = 0; + for (idx_t c = bound_pivot.group_count; c < chunk.ColumnCount(); c++) { + chunk.data[c].Reference(empty_aggregates[aggregate]); + chunk.data[c].Flatten(input.size()); + aggregate++; + if (aggregate >= empty_aggregates.size()) { + aggregate = 0; + } + } + + // move the pivots to the given columns + for (idx_t r = 0; r < input.size(); r++) { + auto list = pivot_column_lists[r]; + for (idx_t l = 0; l < list.length; l++) { + // figure out the column value number of this list + auto &column_name = pivot_columns[list.offset + l]; + auto entry = pivot_map.find(column_name); + if (entry == pivot_map.end()) { + // column entry not found in map - that means this element is explicitly excluded from the pivot list + continue; + } + auto column_idx = entry->second; + for (idx_t aggr = 0; aggr < empty_aggregates.size(); aggr++) { + auto pivot_value_lists = FlatVector::GetData(input.data[bound_pivot.group_count + aggr]); + auto &pivot_value_child = ListVector::GetEntry(input.data[bound_pivot.group_count + aggr]); + if (list.offset != pivot_value_lists[r].offset || list.length != pivot_value_lists[r].length) { + throw InternalException("Pivot - unaligned lists between values and columns!?"); + } + chunk.data[column_idx + aggr].SetValue(r, pivot_value_child.GetValue(list.offset + l)); + } + } + } + chunk.SetCardinality(input.size()); + return OperatorResultType::NEED_MORE_INPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_projection.cpp b/src/duckdb/src/execution/operator/projection/physical_projection.cpp new file mode 100644 index 00000000..e6d915e3 --- /dev/null +++ b/src/duckdb/src/execution/operator/projection/physical_projection.cpp @@ -0,0 +1,80 @@ +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +class ProjectionState : public OperatorState { +public: + explicit ProjectionState(ExecutionContext &context, const vector> &expressions) + : executor(context.client, expressions) { + } + + ExpressionExecutor executor; + +public: + void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { + context.thread.profiler.Flush(op, executor, "projection", 0); + } +}; + +PhysicalProjection::PhysicalProjection(vector types, vector> select_list, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::PROJECTION, std::move(types), estimated_cardinality), + select_list(std::move(select_list)) { +} + +OperatorResultType PhysicalProjection::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &state = state_p.Cast(); + state.executor.Execute(input, chunk); + return OperatorResultType::NEED_MORE_INPUT; +} + +unique_ptr PhysicalProjection::GetOperatorState(ExecutionContext &context) const { + return make_uniq(context, select_list); +} + +unique_ptr +PhysicalProjection::CreateJoinProjection(vector proj_types, const vector &lhs_types, + const vector &rhs_types, const vector &left_projection_map, + const vector &right_projection_map, const idx_t estimated_cardinality) { + + vector> proj_selects; + proj_selects.reserve(proj_types.size()); + + if (left_projection_map.empty()) { + for (storage_t i = 0; i < lhs_types.size(); ++i) { + proj_selects.emplace_back(make_uniq(lhs_types[i], i)); + } + } else { + for (auto i : left_projection_map) { + proj_selects.emplace_back(make_uniq(lhs_types[i], i)); + } + } + const auto left_cols = lhs_types.size(); + + if (right_projection_map.empty()) { + for (storage_t i = 0; i < rhs_types.size(); ++i) { + proj_selects.emplace_back(make_uniq(rhs_types[i], left_cols + i)); + } + + } else { + for (auto i : right_projection_map) { + proj_selects.emplace_back(make_uniq(rhs_types[i], left_cols + i)); + } + } + + return make_uniq(std::move(proj_types), std::move(proj_selects), estimated_cardinality); +} + +string PhysicalProjection::ParamsToString() const { + string extra_info; + for (auto &expr : select_list) { + extra_info += expr->GetName() + "\n"; + } + return extra_info; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp new file mode 100644 index 00000000..d32094f7 --- /dev/null +++ b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp @@ -0,0 +1,115 @@ +#include "duckdb/execution/operator/projection/physical_tableinout_function.hpp" + +namespace duckdb { + +class TableInOutLocalState : public OperatorState { +public: + TableInOutLocalState() : row_index(0), new_row(true) { + } + + unique_ptr local_state; + idx_t row_index; + bool new_row; + DataChunk input_chunk; +}; + +class TableInOutGlobalState : public GlobalOperatorState { +public: + TableInOutGlobalState() { + } + + unique_ptr global_state; +}; + +PhysicalTableInOutFunction::PhysicalTableInOutFunction(vector types, TableFunction function_p, + unique_ptr bind_data_p, + vector column_ids_p, idx_t estimated_cardinality, + vector project_input_p) + : PhysicalOperator(PhysicalOperatorType::INOUT_FUNCTION, std::move(types), estimated_cardinality), + function(std::move(function_p)), bind_data(std::move(bind_data_p)), column_ids(std::move(column_ids_p)), + projected_input(std::move(project_input_p)) { +} + +unique_ptr PhysicalTableInOutFunction::GetOperatorState(ExecutionContext &context) const { + auto &gstate = op_state->Cast(); + auto result = make_uniq(); + if (function.init_local) { + TableFunctionInitInput input(bind_data.get(), column_ids, vector(), nullptr); + result->local_state = function.init_local(context, input, gstate.global_state.get()); + } + if (!projected_input.empty()) { + result->input_chunk.Initialize(context.client, children[0]->types); + } + return std::move(result); +} + +unique_ptr PhysicalTableInOutFunction::GetGlobalOperatorState(ClientContext &context) const { + auto result = make_uniq(); + if (function.init_global) { + TableFunctionInitInput input(bind_data.get(), column_ids, vector(), nullptr); + result->global_state = function.init_global(context, input); + } + return std::move(result); +} + +OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate_p, OperatorState &state_p) const { + auto &gstate = gstate_p.Cast(); + auto &state = state_p.Cast(); + TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); + if (projected_input.empty()) { + // straightforward case - no need to project input + return function.in_out_function(context, data, input, chunk); + } + // when project_input is set we execute the input function row-by-row + if (state.new_row) { + if (state.row_index >= input.size()) { + // finished processing this chunk + state.new_row = true; + state.row_index = 0; + return OperatorResultType::NEED_MORE_INPUT; + } + // we are processing a new row: fetch the data for the current row + state.input_chunk.Reset(); + D_ASSERT(input.ColumnCount() == state.input_chunk.ColumnCount()); + // set up the input data to the table in-out function + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + ConstantVector::Reference(state.input_chunk.data[col_idx], input.data[col_idx], state.row_index, 1); + } + state.input_chunk.SetCardinality(1); + state.row_index++; + state.new_row = false; + } + // set up the output data in "chunk" + D_ASSERT(chunk.ColumnCount() > projected_input.size()); + D_ASSERT(state.row_index > 0); + idx_t base_idx = chunk.ColumnCount() - projected_input.size(); + for (idx_t project_idx = 0; project_idx < projected_input.size(); project_idx++) { + auto source_idx = projected_input[project_idx]; + auto target_idx = base_idx + project_idx; + ConstantVector::Reference(chunk.data[target_idx], input.data[source_idx], state.row_index - 1, 1); + } + auto result = function.in_out_function(context, data, state.input_chunk, chunk); + if (result == OperatorResultType::FINISHED) { + return result; + } + if (result == OperatorResultType::NEED_MORE_INPUT) { + // we finished processing this row: move to the next row + state.new_row = true; + } + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +OperatorFinalizeResultType PhysicalTableInOutFunction::FinalExecute(ExecutionContext &context, DataChunk &chunk, + GlobalOperatorState &gstate_p, + OperatorState &state_p) const { + auto &gstate = gstate_p.Cast(); + auto &state = state_p.Cast(); + if (!projected_input.empty()) { + throw InternalException("FinalExecute not supported for project_input"); + } + TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); + return function.in_out_function_final(context, data, chunk); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp new file mode 100644 index 00000000..d5632f80 --- /dev/null +++ b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp @@ -0,0 +1,364 @@ +#include "duckdb/execution/operator/projection/physical_unnest.hpp" + +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" + +namespace duckdb { + +class UnnestOperatorState : public OperatorState { +public: + UnnestOperatorState(ClientContext &context, const vector> &select_list) + : current_row(0), list_position(0), longest_list_length(DConstants::INVALID_INDEX), first_fetch(true), + executor(context) { + + // for each UNNEST in the select_list, we add the child expression to the expression executor + // and set the return type in the list_data chunk, which will contain the evaluated expression results + vector list_data_types; + for (auto &exp : select_list) { + D_ASSERT(exp->type == ExpressionType::BOUND_UNNEST); + auto &bue = exp->Cast(); + list_data_types.push_back(bue.child->return_type); + executor.AddExpression(*bue.child.get()); + } + + auto &allocator = Allocator::Get(context); + list_data.Initialize(allocator, list_data_types); + + list_vector_data.resize(list_data.ColumnCount()); + list_child_data.resize(list_data.ColumnCount()); + } + + idx_t current_row; + idx_t list_position; + idx_t longest_list_length; + bool first_fetch; + + ExpressionExecutor executor; + DataChunk list_data; + vector list_vector_data; + vector list_child_data; + +public: + //! Reset the fields of the unnest operator state + void Reset(); + //! Set the longest list's length for the current row + void SetLongestListLength(); +}; + +void UnnestOperatorState::Reset() { + current_row = 0; + list_position = 0; + longest_list_length = DConstants::INVALID_INDEX; + first_fetch = true; +} + +void UnnestOperatorState::SetLongestListLength() { + longest_list_length = 0; + for (idx_t col_idx = 0; col_idx < list_data.ColumnCount(); col_idx++) { + + auto &vector_data = list_vector_data[col_idx]; + auto current_idx = vector_data.sel->get_index(current_row); + + if (vector_data.validity.RowIsValid(current_idx)) { + + // check if this list is longer + auto list_data_entries = UnifiedVectorFormat::GetData(vector_data); + auto list_entry = list_data_entries[current_idx]; + if (list_entry.length > longest_list_length) { + longest_list_length = list_entry.length; + } + } + } +} + +PhysicalUnnest::PhysicalUnnest(vector types, vector> select_list, + idx_t estimated_cardinality, PhysicalOperatorType type) + : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { + D_ASSERT(!this->select_list.empty()); +} + +static void UnnestNull(idx_t start, idx_t end, Vector &result) { + + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + auto &validity = FlatVector::Validity(result); + for (idx_t i = start; i < end; i++) { + validity.SetInvalid(i); + } + if (result.GetType().InternalType() == PhysicalType::STRUCT) { + auto &struct_children = StructVector::GetEntries(result); + for (auto &child : struct_children) { + UnnestNull(start, end, *child); + } + } +} + +template +static void TemplatedUnnest(UnifiedVectorFormat &vector_data, idx_t start, idx_t end, Vector &result) { + + auto source_data = UnifiedVectorFormat::GetData(vector_data); + auto &source_mask = vector_data.validity; + + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + + for (idx_t i = start; i < end; i++) { + auto source_idx = vector_data.sel->get_index(i); + auto target_idx = i - start; + if (source_mask.RowIsValid(source_idx)) { + result_data[target_idx] = source_data[source_idx]; + result_mask.SetValid(target_idx); + } else { + result_mask.SetInvalid(target_idx); + } + } +} + +static void UnnestValidity(UnifiedVectorFormat &vector_data, idx_t start, idx_t end, Vector &result) { + + auto &source_mask = vector_data.validity; + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + auto &result_mask = FlatVector::Validity(result); + + for (idx_t i = start; i < end; i++) { + auto source_idx = vector_data.sel->get_index(i); + auto target_idx = i - start; + result_mask.Set(target_idx, source_mask.RowIsValid(source_idx)); + } +} + +static void UnnestVector(UnifiedVectorFormat &child_vector_data, Vector &child_vector, idx_t list_size, idx_t start, + idx_t end, Vector &result) { + + D_ASSERT(child_vector.GetType() == result.GetType()); + switch (result.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::INT16: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::INT32: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::INT64: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::INT128: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::UINT8: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::UINT16: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::UINT32: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::UINT64: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::FLOAT: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::DOUBLE: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::INTERVAL: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::VARCHAR: + TemplatedUnnest(child_vector_data, start, end, result); + break; + case PhysicalType::LIST: { + // the child vector of result now references the child vector source + // FIXME: only reference relevant children (start - end) instead of all + auto &target = ListVector::GetEntry(result); + target.Reference(ListVector::GetEntry(child_vector)); + ListVector::SetListSize(result, ListVector::GetListSize(child_vector)); + // unnest + TemplatedUnnest(child_vector_data, start, end, result); + break; + } + case PhysicalType::STRUCT: { + auto &child_vector_entries = StructVector::GetEntries(child_vector); + auto &result_entries = StructVector::GetEntries(result); + + // set the validity mask for the 'outer' struct vector before unnesting its children + UnnestValidity(child_vector_data, start, end, result); + + for (idx_t i = 0; i < child_vector_entries.size(); i++) { + UnifiedVectorFormat child_vector_entries_data; + child_vector_entries[i]->ToUnifiedFormat(list_size, child_vector_entries_data); + UnnestVector(child_vector_entries_data, *child_vector_entries[i], list_size, start, end, + *result_entries[i]); + } + break; + } + default: + throw InternalException("Unimplemented type for UNNEST."); + } +} + +static void PrepareInput(UnnestOperatorState &state, DataChunk &input, + const vector> &select_list) { + + state.list_data.Reset(); + // execute the expressions inside each UNNEST in the select_list to get the list data + // execution results (lists) are kept in state.list_data chunk + state.executor.Execute(input, state.list_data); + + // verify incoming lists + state.list_data.Verify(); + D_ASSERT(input.size() == state.list_data.size()); + D_ASSERT(state.list_data.ColumnCount() == select_list.size()); + D_ASSERT(state.list_vector_data.size() == state.list_data.ColumnCount()); + D_ASSERT(state.list_child_data.size() == state.list_data.ColumnCount()); + + // get the UnifiedVectorFormat of each list_data vector (LIST vectors for the different UNNESTs) + // both for the vector itself and its child vector + for (idx_t col_idx = 0; col_idx < state.list_data.ColumnCount(); col_idx++) { + + auto &list_vector = state.list_data.data[col_idx]; + list_vector.ToUnifiedFormat(state.list_data.size(), state.list_vector_data[col_idx]); + + if (list_vector.GetType() == LogicalType::SQLNULL) { + // UNNEST(NULL): SQLNULL vectors don't have child vectors, but we need to point to the child vector of + // each vector, so we just get the UnifiedVectorFormat of the vector itself + auto &child_vector = list_vector; + child_vector.ToUnifiedFormat(0, state.list_child_data[col_idx]); + } else { + auto list_size = ListVector::GetListSize(list_vector); + auto &child_vector = ListVector::GetEntry(list_vector); + child_vector.ToUnifiedFormat(list_size, state.list_child_data[col_idx]); + } + } + + state.first_fetch = false; +} + +unique_ptr PhysicalUnnest::GetOperatorState(ExecutionContext &context) const { + return PhysicalUnnest::GetState(context, select_list); +} + +unique_ptr PhysicalUnnest::GetState(ExecutionContext &context, + const vector> &select_list) { + return make_uniq(context.client, select_list); +} + +OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state_p, + const vector> &select_list, + bool include_input) { + + auto &state = state_p.Cast(); + + do { + // reset validities, if previous loop iteration contained UNNEST(NULL) + if (include_input) { + chunk.Reset(); + } + + // prepare the input data by executing any expressions and getting the + // UnifiedVectorFormat of each LIST vector (list_vector_data) and its child vector (list_child_data) + if (state.first_fetch) { + PrepareInput(state, input, select_list); + } + + // finished with all rows of this input chunk, reset + if (state.current_row >= input.size()) { + state.Reset(); + return OperatorResultType::NEED_MORE_INPUT; + } + + // each UNNEST in the select_list contains a list (or NULL) for this row, find the longest list + // because this length determines how many times we need to repeat for the current row + if (state.longest_list_length == DConstants::INVALID_INDEX) { + state.SetLongestListLength(); + } + D_ASSERT(state.longest_list_length != DConstants::INVALID_INDEX); + + // we emit chunks of either STANDARD_VECTOR_SIZE or smaller + auto this_chunk_len = MinValue(STANDARD_VECTOR_SIZE, state.longest_list_length - state.list_position); + chunk.SetCardinality(this_chunk_len); + + // if we include other projection input columns, e.g. SELECT 1, UNNEST([1, 2]);, then + // we need to add them as a constant vector to the resulting chunk + // FIXME: emit multiple unnested rows. Currently, we never emit a chunk containing multiple unnested input rows, + // so setting a constant vector for the value at state.current_row is fine + idx_t col_offset = 0; + if (include_input) { + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + ConstantVector::Reference(chunk.data[col_idx], input.data[col_idx], state.current_row, input.size()); + } + col_offset = input.ColumnCount(); + } + + // unnest the lists + for (idx_t col_idx = 0; col_idx < state.list_data.ColumnCount(); col_idx++) { + + auto &result_vector = chunk.data[col_idx + col_offset]; + + if (state.list_data.data[col_idx].GetType() == LogicalType::SQLNULL) { + // UNNEST(NULL) + chunk.SetCardinality(0); + break; + } + + auto &vector_data = state.list_vector_data[col_idx]; + auto current_idx = vector_data.sel->get_index(state.current_row); + + if (!vector_data.validity.RowIsValid(current_idx)) { + UnnestNull(0, this_chunk_len, result_vector); + continue; + } + + auto list_data = UnifiedVectorFormat::GetData(vector_data); + auto list_entry = list_data[current_idx]; + + idx_t list_count = 0; + if (state.list_position < list_entry.length) { + // there are still list_count elements to unnest + list_count = MinValue(this_chunk_len, list_entry.length - state.list_position); + + auto &list_vector = state.list_data.data[col_idx]; + auto &child_vector = ListVector::GetEntry(list_vector); + auto list_size = ListVector::GetListSize(list_vector); + auto &child_vector_data = state.list_child_data[col_idx]; + + auto base_offset = list_entry.offset + state.list_position; + UnnestVector(child_vector_data, child_vector, list_size, base_offset, base_offset + list_count, + result_vector); + } + + // fill the rest with NULLs + if (list_count != this_chunk_len) { + UnnestNull(list_count, this_chunk_len, result_vector); + } + } + + chunk.Verify(); + + state.list_position += this_chunk_len; + if (state.list_position == state.longest_list_length) { + state.current_row++; + state.longest_list_length = DConstants::INVALID_INDEX; + state.list_position = 0; + } + + // we only emit one unnested row (that contains data) at a time + } while (chunk.size() == 0); + return OperatorResultType::HAVE_MORE_OUTPUT; +} + +OperatorResultType PhysicalUnnest::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &, OperatorState &state) const { + return ExecuteInternal(context, input, chunk, state, select_list); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp new file mode 100644 index 00000000..55c00b7a --- /dev/null +++ b/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp @@ -0,0 +1,98 @@ +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" + +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" + +namespace duckdb { + +PhysicalColumnDataScan::PhysicalColumnDataScan(vector types, PhysicalOperatorType op_type, + idx_t estimated_cardinality, + unique_ptr owned_collection_p) + : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(owned_collection_p.get()), + owned_collection(std::move(owned_collection_p)) { +} + +class PhysicalColumnDataScanState : public GlobalSourceState { +public: + explicit PhysicalColumnDataScanState() : initialized(false) { + } + + //! The current position in the scan + ColumnDataScanState scan_state; + bool initialized; +}; + +unique_ptr PhysicalColumnDataScan::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(); +} + +SourceResultType PhysicalColumnDataScan::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &state = input.global_state.Cast(); + if (collection->Count() == 0) { + return SourceResultType::FINISHED; + } + if (!state.initialized) { + collection->InitializeScan(state.scan_state); + state.initialized = true; + } + collection->Scan(state.scan_state, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalColumnDataScan::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + // check if there is any additional action we need to do depending on the type + auto &state = meta_pipeline.GetState(); + switch (type) { + case PhysicalOperatorType::DELIM_SCAN: { + auto entry = state.delim_join_dependencies.find(*this); + D_ASSERT(entry != state.delim_join_dependencies.end()); + // this chunk scan introduces a dependency to the current pipeline + // namely a dependency on the duplicate elimination pipeline to finish + auto delim_dependency = entry->second.get().shared_from_this(); + auto delim_sink = state.GetPipelineSink(*delim_dependency); + D_ASSERT(delim_sink); + D_ASSERT(delim_sink->type == PhysicalOperatorType::DELIM_JOIN); + auto &delim_join = delim_sink->Cast(); + current.AddDependency(delim_dependency); + state.SetPipelineSource(current, delim_join.distinct->Cast()); + return; + } + case PhysicalOperatorType::CTE_SCAN: { + break; + } + case PhysicalOperatorType::RECURSIVE_CTE_SCAN: + if (!meta_pipeline.HasRecursiveCTE()) { + throw InternalException("Recursive CTE scan found without recursive CTE node"); + } + break; + default: + break; + } + D_ASSERT(children.empty()); + state.SetPipelineSource(current, *this); +} + +string PhysicalColumnDataScan::ParamsToString() const { + string result = ""; + switch (type) { + case PhysicalOperatorType::CTE_SCAN: + case PhysicalOperatorType::RECURSIVE_CTE_SCAN: { + result += "\n[INFOSEPARATOR]\n"; + result += StringUtil::Format("idx: %llu", cte_index); + break; + } + default: + break; + } + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp new file mode 100644 index 00000000..1a620803 --- /dev/null +++ b/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp @@ -0,0 +1,13 @@ +#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" + +namespace duckdb { + +SourceResultType PhysicalDummyScan::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + // return a single row on the first call to the dummy scan + chunk.SetCardinality(1); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp b/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp new file mode 100644 index 00000000..2e7d006b --- /dev/null +++ b/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp @@ -0,0 +1,10 @@ +#include "duckdb/execution/operator/scan/physical_empty_result.hpp" + +namespace duckdb { + +SourceResultType PhysicalEmptyResult::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp new file mode 100644 index 00000000..5e08bc94 --- /dev/null +++ b/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp @@ -0,0 +1,63 @@ +#include "duckdb/execution/operator/scan/physical_expression_scan.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +class ExpressionScanState : public OperatorState { +public: + explicit ExpressionScanState(Allocator &allocator, const PhysicalExpressionScan &op) : expression_index(0) { + temp_chunk.Initialize(allocator, op.GetTypes()); + } + + //! The current position in the scan + idx_t expression_index; + //! Temporary chunk for evaluating expressions + DataChunk temp_chunk; +}; + +unique_ptr PhysicalExpressionScan::GetOperatorState(ExecutionContext &context) const { + return make_uniq(Allocator::Get(context.client), *this); +} + +OperatorResultType PhysicalExpressionScan::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &state = state_p.Cast(); + + for (; chunk.size() + input.size() <= STANDARD_VECTOR_SIZE && state.expression_index < expressions.size(); + state.expression_index++) { + state.temp_chunk.Reset(); + EvaluateExpression(context.client, state.expression_index, &input, state.temp_chunk); + chunk.Append(state.temp_chunk); + } + if (state.expression_index < expressions.size()) { + return OperatorResultType::HAVE_MORE_OUTPUT; + } else { + state.expression_index = 0; + return OperatorResultType::NEED_MORE_INPUT; + } +} + +void PhysicalExpressionScan::EvaluateExpression(ClientContext &context, idx_t expression_idx, DataChunk *child_chunk, + DataChunk &result) const { + ExpressionExecutor executor(context, expressions[expression_idx]); + if (child_chunk) { + child_chunk->Verify(); + executor.Execute(*child_chunk, result); + } else { + executor.Execute(result); + } +} + +bool PhysicalExpressionScan::IsFoldable() const { + for (auto &expr_list : expressions) { + for (auto &expr : expr_list) { + if (!expr->IsFoldable()) { + return false; + } + } + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp new file mode 100644 index 00000000..e2a67505 --- /dev/null +++ b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp @@ -0,0 +1,211 @@ +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/transaction/transaction.hpp" + +#include + +namespace duckdb { + +PhysicalPositionalScan::PhysicalPositionalScan(vector types, unique_ptr left, + unique_ptr right) + : PhysicalOperator(PhysicalOperatorType::POSITIONAL_SCAN, std::move(types), + MaxValue(left->estimated_cardinality, right->estimated_cardinality)) { + + // Manage the children ourselves + if (left->type == PhysicalOperatorType::TABLE_SCAN) { + child_tables.emplace_back(std::move(left)); + } else if (left->type == PhysicalOperatorType::POSITIONAL_SCAN) { + auto &left_scan = left->Cast(); + child_tables = std::move(left_scan.child_tables); + } else { + throw InternalException("Invalid left input for PhysicalPositionalScan"); + } + + if (right->type == PhysicalOperatorType::TABLE_SCAN) { + child_tables.emplace_back(std::move(right)); + } else if (right->type == PhysicalOperatorType::POSITIONAL_SCAN) { + auto &right_scan = right->Cast(); + auto &right_tables = right_scan.child_tables; + child_tables.reserve(child_tables.size() + right_tables.size()); + std::move(right_tables.begin(), right_tables.end(), std::back_inserter(child_tables)); + } else { + throw InternalException("Invalid right input for PhysicalPositionalScan"); + } +} + +class PositionalScanGlobalSourceState : public GlobalSourceState { +public: + PositionalScanGlobalSourceState(ClientContext &context, const PhysicalPositionalScan &op) { + for (const auto &table : op.child_tables) { + global_states.emplace_back(table->GetGlobalSourceState(context)); + } + } + + vector> global_states; + + idx_t MaxThreads() override { + return 1; + } +}; + +class PositionalTableScanner { +public: + PositionalTableScanner(ExecutionContext &context, PhysicalOperator &table_p, GlobalSourceState &gstate_p) + : table(table_p), global_state(gstate_p), source_offset(0), exhausted(false) { + local_state = table.GetLocalSourceState(context, gstate_p); + source.Initialize(Allocator::Get(context.client), table.types); + } + + idx_t Refill(ExecutionContext &context) { + if (source_offset >= source.size()) { + if (!exhausted) { + source.Reset(); + + InterruptState interrupt_state; + OperatorSourceInput source_input {global_state, *local_state, interrupt_state}; + auto source_result = table.GetData(context, source, source_input); + if (source_result == SourceResultType::BLOCKED) { + throw NotImplementedException( + "Unexpected interrupt from table Source in PositionalTableScanner refill"); + } + } + source_offset = 0; + } + + const auto available = source.size() - source_offset; + if (!available) { + if (!exhausted) { + source.Reset(); + for (idx_t i = 0; i < source.ColumnCount(); ++i) { + auto &vec = source.data[i]; + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + exhausted = true; + } + } + + return available; + } + + idx_t CopyData(ExecutionContext &context, DataChunk &output, const idx_t count, const idx_t col_offset) { + if (!source_offset && (source.size() >= count || exhausted)) { + // Fast track: aligned and has enough data + for (idx_t i = 0; i < source.ColumnCount(); ++i) { + output.data[col_offset + i].Reference(source.data[i]); + } + source_offset += count; + } else { + // Copy data + for (idx_t target_offset = 0; target_offset < count;) { + const auto needed = count - target_offset; + const auto available = exhausted ? needed : (source.size() - source_offset); + const auto copy_size = MinValue(needed, available); + const auto source_count = source_offset + copy_size; + for (idx_t i = 0; i < source.ColumnCount(); ++i) { + VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, + target_offset); + } + target_offset += copy_size; + source_offset += copy_size; + Refill(context); + } + } + + return source.ColumnCount(); + } + + double GetProgress(ClientContext &context) { + return table.GetProgress(context, global_state); + } + + PhysicalOperator &table; + GlobalSourceState &global_state; + unique_ptr local_state; + DataChunk source; + idx_t source_offset; + bool exhausted; +}; + +class PositionalScanLocalSourceState : public LocalSourceState { +public: + PositionalScanLocalSourceState(ExecutionContext &context, PositionalScanGlobalSourceState &gstate, + const PhysicalPositionalScan &op) { + for (size_t i = 0; i < op.child_tables.size(); ++i) { + auto &child = *op.child_tables[i]; + auto &global_state = *gstate.global_states[i]; + scanners.emplace_back(make_uniq(context, child, global_state)); + } + } + + vector> scanners; +}; + +unique_ptr PhysicalPositionalScan::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(context, gstate.Cast(), *this); +} + +unique_ptr PhysicalPositionalScan::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(context, *this); +} + +SourceResultType PhysicalPositionalScan::GetData(ExecutionContext &context, DataChunk &output, + OperatorSourceInput &input) const { + auto &lstate = input.local_state.Cast(); + + // Find the longest source block + idx_t count = 0; + for (auto &scanner : lstate.scanners) { + count = MaxValue(count, scanner->Refill(context)); + } + + // All done? + if (!count) { + return SourceResultType::FINISHED; + } + + // Copy or reference the source columns + idx_t col_offset = 0; + for (auto &scanner : lstate.scanners) { + col_offset += scanner->CopyData(context, output, count, col_offset); + } + + output.SetCardinality(count); + return SourceResultType::HAVE_MORE_OUTPUT; +} + +double PhysicalPositionalScan::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + + double result = child_tables[0]->GetProgress(context, *gstate.global_states[0]); + for (size_t t = 1; t < child_tables.size(); ++t) { + result = MinValue(result, child_tables[t]->GetProgress(context, *gstate.global_states[t])); + } + + return result; +} + +bool PhysicalPositionalScan::Equals(const PhysicalOperator &other_p) const { + if (type != other_p.type) { + return false; + } + + auto &other = other_p.Cast(); + if (child_tables.size() != other.child_tables.size()) { + return false; + } + for (size_t i = 0; i < child_tables.size(); ++i) { + if (!child_tables[i]->Equals(*other.child_tables[i])) { + return false; + } + } + + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp new file mode 100644 index 00000000..ba64b293 --- /dev/null +++ b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp @@ -0,0 +1,169 @@ +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/transaction/transaction.hpp" + +#include + +namespace duckdb { + +PhysicalTableScan::PhysicalTableScan(vector types, TableFunction function_p, + unique_ptr bind_data_p, vector returned_types_p, + vector column_ids_p, vector projection_ids_p, + vector names_p, unique_ptr table_filters_p, + idx_t estimated_cardinality, ExtraOperatorInfo extra_info) + : PhysicalOperator(PhysicalOperatorType::TABLE_SCAN, std::move(types), estimated_cardinality), + function(std::move(function_p)), bind_data(std::move(bind_data_p)), returned_types(std::move(returned_types_p)), + column_ids(std::move(column_ids_p)), projection_ids(std::move(projection_ids_p)), names(std::move(names_p)), + table_filters(std::move(table_filters_p)), extra_info(extra_info) { +} + +class TableScanGlobalSourceState : public GlobalSourceState { +public: + TableScanGlobalSourceState(ClientContext &context, const PhysicalTableScan &op) { + if (op.function.init_global) { + TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, op.table_filters.get()); + global_state = op.function.init_global(context, input); + if (global_state) { + max_threads = global_state->MaxThreads(); + } + } else { + max_threads = 1; + } + } + + idx_t max_threads = 0; + unique_ptr global_state; + + idx_t MaxThreads() override { + return max_threads; + } +}; + +class TableScanLocalSourceState : public LocalSourceState { +public: + TableScanLocalSourceState(ExecutionContext &context, TableScanGlobalSourceState &gstate, + const PhysicalTableScan &op) { + if (op.function.init_local) { + TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, op.table_filters.get()); + local_state = op.function.init_local(context, input, gstate.global_state.get()); + } + } + + unique_ptr local_state; +}; + +unique_ptr PhysicalTableScan::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(context, gstate.Cast(), *this); +} + +unique_ptr PhysicalTableScan::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(context, *this); +} + +SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + D_ASSERT(!column_ids.empty()); + auto &gstate = input.global_state.Cast(); + auto &state = input.local_state.Cast(); + + TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); + function.function(context.client, data, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +double PhysicalTableScan::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + if (function.table_scan_progress) { + return function.table_scan_progress(context, bind_data.get(), gstate.global_state.get()); + } + // if table_scan_progress is not implemented we don't support this function yet in the progress bar + return -1; +} + +idx_t PhysicalTableScan::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, + LocalSourceState &lstate) const { + D_ASSERT(SupportsBatchIndex()); + D_ASSERT(function.get_batch_index); + auto &gstate = gstate_p.Cast(); + auto &state = lstate.Cast(); + return function.get_batch_index(context.client, bind_data.get(), state.local_state.get(), + gstate.global_state.get()); +} + +string PhysicalTableScan::GetName() const { + return StringUtil::Upper(function.name + " " + function.extra_info); +} + +string PhysicalTableScan::ParamsToString() const { + string result; + if (function.to_string) { + result = function.to_string(bind_data.get()); + result += "\n[INFOSEPARATOR]\n"; + } + if (function.projection_pushdown) { + if (function.filter_prune) { + for (idx_t i = 0; i < projection_ids.size(); i++) { + const auto &column_id = column_ids[projection_ids[i]]; + if (column_id < names.size()) { + if (i > 0) { + result += "\n"; + } + result += names[column_id]; + } + } + } else { + for (idx_t i = 0; i < column_ids.size(); i++) { + const auto &column_id = column_ids[i]; + if (column_id < names.size()) { + if (i > 0) { + result += "\n"; + } + result += names[column_id]; + } + } + } + } + if (function.filter_pushdown && table_filters) { + result += "\n[INFOSEPARATOR]\n"; + result += "Filters: "; + for (auto &f : table_filters->filters) { + auto &column_index = f.first; + auto &filter = f.second; + if (column_index < names.size()) { + result += filter->ToString(names[column_ids[column_index]]); + result += "\n"; + } + } + } + if (!extra_info.file_filters.empty()) { + result += "\n[INFOSEPARATOR]\n"; + result += "File Filters: " + extra_info.file_filters; + } + result += "\n[INFOSEPARATOR]\n"; + result += StringUtil::Format("EC: %llu", estimated_cardinality); + return result; +} + +bool PhysicalTableScan::Equals(const PhysicalOperator &other_p) const { + if (type != other_p.type) { + return false; + } + auto &other = other_p.Cast(); + if (function.function != other.function.function) { + return false; + } + if (column_ids != other.column_ids) { + return false; + } + if (!FunctionData::Equals(bind_data.get(), other.bind_data.get())) { + return false; + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_alter.cpp b/src/duckdb/src/execution/operator/schema/physical_alter.cpp new file mode 100644 index 00000000..7fc00006 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_alter.cpp @@ -0,0 +1,17 @@ +#include "duckdb/execution/operator/schema/physical_alter.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalAlter::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + catalog.Alter(context.client, *info); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_attach.cpp b/src/duckdb/src/execution/operator/schema/physical_attach.cpp new file mode 100644 index 00000000..ac497d8a --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_attach.cpp @@ -0,0 +1,88 @@ +#include "duckdb/execution/operator/schema/physical_attach.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/database_path_and_type.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/parser/parsed_data/attach_info.hpp" +#include "duckdb/storage/storage_extension.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + // parse the options + auto &config = DBConfig::GetConfig(context.client); + AccessMode access_mode = config.options.access_mode; + string type; + string unrecognized_option; + for (auto &entry : info->options) { + if (entry.first == "readonly" || entry.first == "read_only") { + auto read_only = BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); + if (read_only) { + access_mode = AccessMode::READ_ONLY; + } else { + access_mode = AccessMode::READ_WRITE; + } + } else if (entry.first == "readwrite" || entry.first == "read_write") { + auto read_only = !BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); + if (read_only) { + access_mode = AccessMode::READ_ONLY; + } else { + access_mode = AccessMode::READ_WRITE; + } + } else if (entry.first == "type") { + type = StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR)); + } else if (unrecognized_option.empty()) { + unrecognized_option = entry.first; + } + } + auto &db = DatabaseInstance::GetDatabase(context.client); + if (type.empty()) { + // try to extract type from path + auto path_and_type = DBPathAndType::Parse(info->path, config); + type = path_and_type.type; + info->path = path_and_type.path; + } + + if (type.empty() && !unrecognized_option.empty()) { + throw BinderException("Unrecognized option for attach \"%s\"", unrecognized_option); + } + + // if we are loading a database type from an extension - check if that extension is loaded + if (!type.empty()) { + if (!Catalog::TryAutoLoad(context.client, type)) { + // FIXME: Here it might be preferrable to use an AutoLoadOrThrow kind of function + // so that either there will be success or a message to throw, and load will be + // attempted only once respecting the autoloading options + ExtensionHelper::LoadExternalExtension(context.client, type); + } + } + + // attach the database + auto &name = info->name; + const auto &path = info->path; + + if (name.empty()) { + auto &fs = FileSystem::GetFileSystem(context.client); + name = AttachedDatabase::ExtractDatabaseName(path, fs); + } + auto &db_manager = DatabaseManager::Get(context.client); + auto existing_db = db_manager.GetDatabaseFromPath(context.client, path); + if (existing_db) { + throw BinderException("Database \"%s\" is already attached with alias \"%s\"", path, existing_db->GetName()); + } + auto new_db = db.CreateAttachedDatabase(*info, type, access_mode); + new_db->Initialize(); + + db_manager.AddDatabase(context.client, std::move(new_db)); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp new file mode 100644 index 00000000..7c598dc8 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp @@ -0,0 +1,198 @@ +#include "duckdb/execution/operator/schema/physical_create_art_index.hpp" + +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/storage/index.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/node.hpp" +#include "duckdb/execution/index/art/leaf.hpp" + +namespace duckdb { + +PhysicalCreateARTIndex::PhysicalCreateARTIndex(LogicalOperator &op, TableCatalogEntry &table_p, + const vector &column_ids, unique_ptr info, + vector> unbound_expressions, + idx_t estimated_cardinality, const bool sorted) + : PhysicalOperator(PhysicalOperatorType::CREATE_INDEX, op.types, estimated_cardinality), + table(table_p.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), + sorted(sorted) { + // convert virtual column ids to storage column ids + for (auto &column_id : column_ids) { + storage_ids.push_back(table.GetColumns().LogicalToPhysical(LogicalIndex(column_id)).index); + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// + +class CreateARTIndexGlobalSinkState : public GlobalSinkState { +public: + //! Global index to be added to the table + unique_ptr global_index; +}; + +class CreateARTIndexLocalSinkState : public LocalSinkState { +public: + explicit CreateARTIndexLocalSinkState(ClientContext &context) : arena_allocator(Allocator::Get(context)) {}; + + unique_ptr local_index; + ArenaAllocator arena_allocator; + vector keys; + DataChunk key_chunk; + vector key_column_ids; +}; + +unique_ptr PhysicalCreateARTIndex::GetGlobalSinkState(ClientContext &context) const { + auto state = make_uniq(); + + // create the global index + auto &storage = table.GetStorage(); + state->global_index = make_uniq(storage_ids, TableIOManager::Get(storage), unbound_expressions, + info->constraint_type, storage.db); + + return (std::move(state)); +} + +unique_ptr PhysicalCreateARTIndex::GetLocalSinkState(ExecutionContext &context) const { + auto state = make_uniq(context.client); + + // create the local index + + auto &storage = table.GetStorage(); + state->local_index = make_uniq(storage_ids, TableIOManager::Get(storage), unbound_expressions, + info->constraint_type, storage.db); + + state->keys = vector(STANDARD_VECTOR_SIZE); + state->key_chunk.Initialize(Allocator::Get(context.client), state->local_index->logical_types); + + for (idx_t i = 0; i < state->key_chunk.ColumnCount(); i++) { + state->key_column_ids.push_back(i); + } + return std::move(state); +} + +SinkResultType PhysicalCreateARTIndex::SinkUnsorted(Vector &row_identifiers, OperatorSinkInput &input) const { + + auto &l_state = input.local_state.Cast(); + auto count = l_state.key_chunk.size(); + + // get the corresponding row IDs + row_identifiers.Flatten(count); + auto row_ids = FlatVector::GetData(row_identifiers); + + // insert the row IDs + auto &art = l_state.local_index->Cast(); + for (idx_t i = 0; i < count; i++) { + if (!art.Insert(art.tree, l_state.keys[i], 0, row_ids[i])) { + throw ConstraintException("Data contains duplicates on indexed column(s)"); + } + } + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkResultType PhysicalCreateARTIndex::SinkSorted(Vector &row_identifiers, OperatorSinkInput &input) const { + + auto &l_state = input.local_state.Cast(); + auto &storage = table.GetStorage(); + auto &l_index = l_state.local_index; + + // create an ART from the chunk + auto art = make_uniq(l_index->column_ids, l_index->table_io_manager, l_index->unbound_expressions, + l_index->constraint_type, storage.db, l_index->Cast().allocators); + if (!art->ConstructFromSorted(l_state.key_chunk.size(), l_state.keys, row_identifiers)) { + throw ConstraintException("Data contains duplicates on indexed column(s)"); + } + + // merge into the local ART + if (!l_index->MergeIndexes(*art)) { + throw ConstraintException("Data contains duplicates on indexed column(s)"); + } + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkResultType PhysicalCreateARTIndex::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + + D_ASSERT(chunk.ColumnCount() >= 2); + + // generate the keys for the given input + auto &l_state = input.local_state.Cast(); + l_state.key_chunk.ReferenceColumns(chunk, l_state.key_column_ids); + l_state.arena_allocator.Reset(); + ART::GenerateKeys(l_state.arena_allocator, l_state.key_chunk, l_state.keys); + + // insert the keys and their corresponding row IDs + auto &row_identifiers = chunk.data[chunk.ColumnCount() - 1]; + if (sorted) { + return SinkSorted(row_identifiers, input); + } + return SinkUnsorted(row_identifiers, input); +} + +SinkCombineResultType PhysicalCreateARTIndex::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + // merge the local index into the global index + if (!gstate.global_index->MergeIndexes(*lstate.local_index)) { + throw ConstraintException("Data contains duplicates on indexed column(s)"); + } + + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + + // here, we set the resulting global index as the newly created index of the table + auto &state = input.global_state.Cast(); + + // vacuum excess memory and verify + state.global_index->Vacuum(); + D_ASSERT(!state.global_index->VerifyAndToString(true).empty()); + + auto &storage = table.GetStorage(); + if (!storage.IsRoot()) { + throw TransactionException("Transaction conflict: cannot add an index to a table that has been altered!"); + } + + auto &schema = table.schema; + auto index_entry = schema.CreateIndex(context, *info, table).get(); + if (!index_entry) { + D_ASSERT(info->on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT); + // index already exists, but error ignored because of IF NOT EXISTS + return SinkFinalizeType::READY; + } + auto &index = index_entry->Cast(); + + index.index = state.global_index.get(); + index.info = storage.info; + for (auto &parsed_expr : info->parsed_expressions) { + index.parsed_expressions.push_back(parsed_expr->Copy()); + } + + // add index to storage + storage.info->indexes.AddIndex(std::move(state.global_index)); + return SinkFinalizeType::READY; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// + +SourceResultType PhysicalCreateARTIndex::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_function.cpp b/src/duckdb/src/execution/operator/schema/physical_create_function.cpp new file mode 100644 index 00000000..2521b208 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_create_function.cpp @@ -0,0 +1,19 @@ +#include "duckdb/execution/operator/schema/physical_create_function.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalCreateFunction::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + catalog.CreateFunction(context.client, *info); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp b/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp new file mode 100644 index 00000000..d5e340b8 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp @@ -0,0 +1,20 @@ +#include "duckdb/execution/operator/schema/physical_create_schema.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalCreateSchema::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + if (catalog.IsSystemCatalog()) { + throw BinderException("Cannot create schema in system catalog"); + } + catalog.CreateSchema(context.client, *info); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp b/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp new file mode 100644 index 00000000..80c4a2ff --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp @@ -0,0 +1,17 @@ +#include "duckdb/execution/operator/schema/physical_create_sequence.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalCreateSequence::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + catalog.CreateSequence(context.client, *info); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_table.cpp b/src/duckdb/src/execution/operator/schema/physical_create_table.cpp new file mode 100644 index 00000000..3220c435 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_create_table.cpp @@ -0,0 +1,27 @@ +#include "duckdb/execution/operator/schema/physical_create_table.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/data_table.hpp" + +namespace duckdb { + +PhysicalCreateTable::PhysicalCreateTable(LogicalOperator &op, SchemaCatalogEntry &schema, + unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::CREATE_TABLE, op.types, estimated_cardinality), schema(schema), + info(std::move(info)) { +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalCreateTable::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &catalog = schema.catalog; + catalog.CreateTable(catalog.GetCatalogTransaction(context.client), schema, *info); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_type.cpp b/src/duckdb/src/execution/operator/schema/physical_create_type.cpp new file mode 100644 index 00000000..68bc258b --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_create_type.cpp @@ -0,0 +1,85 @@ +#include "duckdb/execution/operator/schema/physical_create_type.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/common/string_map_set.hpp" + +namespace duckdb { + +PhysicalCreateType::PhysicalCreateType(unique_ptr info_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::CREATE_TYPE, {LogicalType::BIGINT}, estimated_cardinality), + info(std::move(info_p)) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class CreateTypeGlobalState : public GlobalSinkState { +public: + explicit CreateTypeGlobalState(ClientContext &context) : result(LogicalType::VARCHAR) { + } + Vector result; + idx_t size = 0; + idx_t capacity = STANDARD_VECTOR_SIZE; + string_set_t found_strings; +}; + +unique_ptr PhysicalCreateType::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context); +} + +SinkResultType PhysicalCreateType::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + idx_t total_row_count = gstate.size + chunk.size(); + if (total_row_count > NumericLimits::Maximum()) { + throw InvalidInputException("Attempted to create ENUM of size %llu, which exceeds the maximum size of %llu", + total_row_count, NumericLimits::Maximum()); + } + UnifiedVectorFormat sdata; + chunk.data[0].ToUnifiedFormat(chunk.size(), sdata); + + if (total_row_count > gstate.capacity) { + // We must resize our result vector + gstate.result.Resize(gstate.capacity, gstate.capacity * 2); + gstate.capacity *= 2; + } + + auto src_ptr = UnifiedVectorFormat::GetData(sdata); + auto result_ptr = FlatVector::GetData(gstate.result); + // Input vector has NULL value, we just throw an exception + for (idx_t i = 0; i < chunk.size(); i++) { + idx_t idx = sdata.sel->get_index(i); + if (!sdata.validity.RowIsValid(idx)) { + throw InvalidInputException("Attempted to create ENUM type with NULL value!"); + } + auto str = src_ptr[idx]; + auto entry = gstate.found_strings.find(src_ptr[idx]); + if (entry != gstate.found_strings.end()) { + // entry was already found - skip + continue; + } + auto owned_string = StringVector::AddStringOrBlob(gstate.result, str.GetData(), str.GetSize()); + gstate.found_strings.insert(owned_string); + result_ptr[gstate.size++] = owned_string; + } + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalCreateType::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + if (IsSink()) { + D_ASSERT(info->type == LogicalType::INVALID); + auto &g_sink_state = sink_state->Cast(); + info->type = LogicalType::ENUM(g_sink_state.result, g_sink_state.size); + } + + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + catalog.CreateType(context.client, *info); + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_view.cpp b/src/duckdb/src/execution/operator/schema/physical_create_view.cpp new file mode 100644 index 00000000..948adad1 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_create_view.cpp @@ -0,0 +1,17 @@ +#include "duckdb/execution/operator/schema/physical_create_view.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalCreateView::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + catalog.CreateView(context.client, *info); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_detach.cpp b/src/duckdb/src/execution/operator/schema/physical_detach.cpp new file mode 100644 index 00000000..480890c3 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_detach.cpp @@ -0,0 +1,22 @@ +#include "duckdb/execution/operator/schema/physical_detach.hpp" +#include "duckdb/parser/parsed_data/detach_info.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/storage/storage_extension.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalDetach::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &db_manager = DatabaseManager::Get(context.client); + db_manager.DetachDatabase(context.client, info->name, info->if_not_found); + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_drop.cpp b/src/duckdb/src/execution/operator/schema/physical_drop.cpp new file mode 100644 index 00000000..c15eb7c1 --- /dev/null +++ b/src/duckdb/src/execution/operator/schema/physical_drop.cpp @@ -0,0 +1,52 @@ +#include "duckdb/execution/operator/schema/physical_drop.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/main/settings.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalDrop::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + switch (info->type) { + case CatalogType::PREPARED_STATEMENT: { + // DEALLOCATE silently ignores errors + auto &statements = ClientData::Get(context.client).prepared_statements; + if (statements.find(info->name) != statements.end()) { + statements.erase(info->name); + } + break; + } + case CatalogType::SCHEMA_ENTRY: { + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + catalog.DropEntry(context.client, *info); + auto qualified_name = QualifiedName::Parse(info->name); + + // Check if the dropped schema was set as the current schema + auto &client_data = ClientData::Get(context.client); + auto &default_entry = client_data.catalog_search_path->GetDefault(); + auto ¤t_catalog = default_entry.catalog; + auto ¤t_schema = default_entry.schema; + D_ASSERT(info->name != DEFAULT_SCHEMA); + + if (info->catalog == current_catalog && current_schema == info->name) { + // Reset the schema to default + SchemaSetting::SetLocal(context.client, DEFAULT_SCHEMA); + } + break; + } + default: { + auto &catalog = Catalog::GetCatalog(context.client, info->catalog); + catalog.DropEntry(context.client, *info); + break; + } + } + + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/set/physical_cte.cpp b/src/duckdb/src/execution/operator/set/physical_cte.cpp new file mode 100644 index 00000000..67cfa021 --- /dev/null +++ b/src/duckdb/src/execution/operator/set/physical_cte.cpp @@ -0,0 +1,160 @@ +#include "duckdb/execution/operator/set/physical_cte.hpp" + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/aggregate_hashtable.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +PhysicalCTE::PhysicalCTE(string ctename, idx_t table_index, vector types, unique_ptr top, + unique_ptr bottom, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::CTE, std::move(types), estimated_cardinality), table_index(table_index), + ctename(std::move(ctename)) { + children.push_back(std::move(top)); + children.push_back(std::move(bottom)); +} + +PhysicalCTE::~PhysicalCTE() { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class CTEState : public GlobalSinkState { +public: + explicit CTEState(ClientContext &context, const PhysicalCTE &op) + : intermediate_table(context, op.children[1]->GetTypes()) { + } + ColumnDataCollection intermediate_table; + ColumnDataScanState scan_state; + bool initialized = false; + bool finished_scan = false; +}; + +unique_ptr PhysicalCTE::GetGlobalSinkState(ClientContext &context) const { + working_table->Reset(); + return make_uniq(context, *this); +} + +SinkResultType PhysicalCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + if (!gstate.finished_scan) { + working_table->Append(chunk); + } else { + gstate.intermediate_table.Append(chunk); + } + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalCTE::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + auto &gstate = sink_state->Cast(); + if (!gstate.initialized) { + gstate.intermediate_table.InitializeScan(gstate.scan_state); + gstate.finished_scan = false; + gstate.initialized = true; + } + if (!gstate.finished_scan) { + gstate.finished_scan = true; + ExecuteRecursivePipelines(context); + } + + gstate.intermediate_table.Scan(gstate.scan_state, chunk); + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +void PhysicalCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { + if (!recursive_meta_pipeline) { + throw InternalException("Missing meta pipeline for recursive CTE"); + } + + // get and reset pipelines + vector> pipelines; + recursive_meta_pipeline->GetPipelines(pipelines, true); + for (auto &pipeline : pipelines) { + auto sink = pipeline->GetSink(); + if (sink.get() != this) { + sink->sink_state.reset(); + } + for (auto &op_ref : pipeline->GetOperators()) { + auto &op = op_ref.get(); + op.op_state.reset(); + } + pipeline->ClearSource(); + } + + // get the MetaPipelines in the recursive_meta_pipeline and reschedule them + vector> meta_pipelines; + recursive_meta_pipeline->GetMetaPipelines(meta_pipelines, true, false); + auto &executor = recursive_meta_pipeline->GetExecutor(); + vector> events; + executor.ReschedulePipelines(meta_pipelines, events); + + while (true) { + executor.WorkOnTasks(); + if (executor.HasError()) { + executor.ThrowException(); + } + bool finished = true; + for (auto &event : events) { + if (!event->IsFinished()) { + finished = false; + break; + } + } + if (finished) { + // all pipelines finished: done! + break; + } + } +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + D_ASSERT(children.size() == 2); + op_state.reset(); + sink_state.reset(); + recursive_meta_pipeline.reset(); + + auto &state = meta_pipeline.GetState(); + state.SetPipelineSource(current, *this); + + auto &executor = meta_pipeline.GetExecutor(); + executor.AddMaterializedCTE(*this); + + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + child_meta_pipeline.Build(*children[0]); + + // the RHS is the recursive pipeline + recursive_meta_pipeline = make_shared(executor, state, this); + if (meta_pipeline.HasRecursiveCTE()) { + recursive_meta_pipeline->SetRecursiveCTE(); + } + recursive_meta_pipeline->Build(*children[1]); +} + +vector> PhysicalCTE::GetSources() const { + return {*this}; +} + +string PhysicalCTE::ParamsToString() const { + string result = ""; + result += "\n[INFOSEPARATOR]\n"; + result += ctename; + result += "\n[INFOSEPARATOR]\n"; + result += StringUtil::Format("idx: %llu", table_index); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp new file mode 100644 index 00000000..ea367ed4 --- /dev/null +++ b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp @@ -0,0 +1,207 @@ +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/aggregate_hashtable.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +PhysicalRecursiveCTE::PhysicalRecursiveCTE(string ctename, idx_t table_index, vector types, bool union_all, + unique_ptr top, unique_ptr bottom, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::RECURSIVE_CTE, std::move(types), estimated_cardinality), + ctename(std::move(ctename)), table_index(table_index), union_all(union_all) { + children.push_back(std::move(top)); + children.push_back(std::move(bottom)); +} + +PhysicalRecursiveCTE::~PhysicalRecursiveCTE() { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class RecursiveCTEState : public GlobalSinkState { +public: + explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) + : intermediate_table(context, op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE) { + ht = make_uniq(context, BufferAllocator::Get(context), op.types, + vector(), vector()); + } + + unique_ptr ht; + + bool intermediate_empty = true; + ColumnDataCollection intermediate_table; + ColumnDataScanState scan_state; + bool initialized = false; + bool finished_scan = false; + SelectionVector new_groups; +}; + +unique_ptr PhysicalRecursiveCTE::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const { + Vector dummy_addresses(LogicalType::POINTER); + + // Use the HT to eliminate duplicate rows + idx_t new_group_count = state.ht->FindOrCreateGroups(chunk, dummy_addresses, state.new_groups); + + // we only return entries we have not seen before (i.e. new groups) + chunk.Slice(state.new_groups, new_group_count); + + return new_group_count; +} + +SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + if (!union_all) { + idx_t match_count = ProbeHT(chunk, gstate); + if (match_count > 0) { + gstate.intermediate_table.Append(chunk); + } + } else { + gstate.intermediate_table.Append(chunk); + } + return SinkResultType::NEED_MORE_INPUT; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +SourceResultType PhysicalRecursiveCTE::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gstate = sink_state->Cast(); + if (!gstate.initialized) { + gstate.intermediate_table.InitializeScan(gstate.scan_state); + gstate.finished_scan = false; + gstate.initialized = true; + } + while (chunk.size() == 0) { + if (!gstate.finished_scan) { + // scan any chunks we have collected so far + gstate.intermediate_table.Scan(gstate.scan_state, chunk); + if (chunk.size() == 0) { + gstate.finished_scan = true; + } else { + break; + } + } else { + // we have run out of chunks + // now we need to recurse + // we set up the working table as the data we gathered in this iteration of the recursion + working_table->Reset(); + working_table->Combine(gstate.intermediate_table); + // and we clear the intermediate table + gstate.finished_scan = false; + gstate.intermediate_table.Reset(); + // now we need to re-execute all of the pipelines that depend on the recursion + ExecuteRecursivePipelines(context); + + // check if we obtained any results + // if not, we are done + if (gstate.intermediate_table.Count() == 0) { + gstate.finished_scan = true; + break; + } + // set up the scan again + gstate.intermediate_table.InitializeScan(gstate.scan_state); + } + } + + return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +void PhysicalRecursiveCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { + if (!recursive_meta_pipeline) { + throw InternalException("Missing meta pipeline for recursive CTE"); + } + D_ASSERT(recursive_meta_pipeline->HasRecursiveCTE()); + + // get and reset pipelines + vector> pipelines; + recursive_meta_pipeline->GetPipelines(pipelines, true); + for (auto &pipeline : pipelines) { + auto sink = pipeline->GetSink(); + if (sink.get() != this) { + sink->sink_state.reset(); + } + for (auto &op_ref : pipeline->GetOperators()) { + auto &op = op_ref.get(); + op.op_state.reset(); + } + pipeline->ClearSource(); + } + + // get the MetaPipelines in the recursive_meta_pipeline and reschedule them + vector> meta_pipelines; + recursive_meta_pipeline->GetMetaPipelines(meta_pipelines, true, false); + auto &executor = recursive_meta_pipeline->GetExecutor(); + vector> events; + executor.ReschedulePipelines(meta_pipelines, events); + + while (true) { + executor.WorkOnTasks(); + if (executor.HasError()) { + executor.ThrowException(); + } + bool finished = true; + for (auto &event : events) { + if (!event->IsFinished()) { + finished = false; + break; + } + } + if (finished) { + // all pipelines finished: done! + break; + } + } +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalRecursiveCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + op_state.reset(); + sink_state.reset(); + recursive_meta_pipeline.reset(); + + auto &state = meta_pipeline.GetState(); + state.SetPipelineSource(current, *this); + + auto &executor = meta_pipeline.GetExecutor(); + executor.AddRecursiveCTE(*this); + + // the LHS of the recursive CTE is our initial state + auto &initial_state_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + initial_state_pipeline.Build(*children[0]); + + // the RHS is the recursive pipeline + recursive_meta_pipeline = make_shared(executor, state, this); + recursive_meta_pipeline->SetRecursiveCTE(); + recursive_meta_pipeline->Build(*children[1]); +} + +vector> PhysicalRecursiveCTE::GetSources() const { + return {*this}; +} + +string PhysicalRecursiveCTE::ParamsToString() const { + string result = ""; + result += "\n[INFOSEPARATOR]\n"; + result += ctename; + result += "\n[INFOSEPARATOR]\n"; + result += StringUtil::Format("idx: %llu", table_index); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/set/physical_union.cpp b/src/duckdb/src/execution/operator/set/physical_union.cpp new file mode 100644 index 00000000..e40e00be --- /dev/null +++ b/src/duckdb/src/execution/operator/set/physical_union.cpp @@ -0,0 +1,67 @@ +#include "duckdb/execution/operator/set/physical_union.hpp" + +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/thread_context.hpp" + +namespace duckdb { + +PhysicalUnion::PhysicalUnion(vector types, unique_ptr top, + unique_ptr bottom, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::UNION, std::move(types), estimated_cardinality) { + children.push_back(std::move(top)); + children.push_back(std::move(bottom)); +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalUnion::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + op_state.reset(); + sink_state.reset(); + + // order matters if any of the downstream operators are order dependent, + // or if the sink preserves order, but does not support batch indices to do so + auto sink = meta_pipeline.GetSink(); + bool order_matters = false; + if (current.IsOrderDependent()) { + order_matters = true; + } + if (sink) { + if (sink->SinkOrderDependent() || sink->RequiresBatchIndex()) { + order_matters = true; + } + if (!sink->ParallelSink()) { + order_matters = true; + } + } + + // create a union pipeline that is identical to 'current' + auto union_pipeline = meta_pipeline.CreateUnionPipeline(current, order_matters); + + // continue with the current pipeline + children[0]->BuildPipelines(current, meta_pipeline); + + if (order_matters) { + // order matters, so 'union_pipeline' must come after all pipelines created by building out 'current' + meta_pipeline.AddDependenciesFrom(union_pipeline, union_pipeline, false); + } + + // build the union pipeline + children[1]->BuildPipelines(*union_pipeline, meta_pipeline); + + // Assign proper batch index to the union pipeline + // This needs to happen after the pipelines have been built because unions can be nested + meta_pipeline.AssignNextBatchIndex(union_pipeline); +} + +vector> PhysicalUnion::GetSources() const { + vector> result; + for (auto &child : children) { + auto child_sources = child->GetSources(); + result.insert(result.end(), child_sources.begin(), child_sources.end()); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp new file mode 100644 index 00000000..826b1d55 --- /dev/null +++ b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp @@ -0,0 +1,313 @@ +#include "duckdb/execution/perfect_aggregate_hashtable.hpp" + +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, Allocator &allocator, + const vector &group_types_p, + vector payload_types_p, + vector aggregate_objects_p, + vector group_minima_p, vector required_bits_p) + : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), + addresses(LogicalType::POINTER), required_bits(std::move(required_bits_p)), total_required_bits(0), + group_minima(std::move(group_minima_p)), sel(STANDARD_VECTOR_SIZE), + aggregate_allocator(make_uniq(allocator)) { + for (auto &group_bits : required_bits) { + total_required_bits += group_bits; + } + // the total amount of groups we allocate space for is 2^required_bits + total_groups = (uint64_t)1 << total_required_bits; + // we don't need to store the groups in a perfect hash table, since the group keys can be deduced by their location + grouping_columns = group_types_p.size(); + layout.Initialize(std::move(aggregate_objects_p)); + tuple_size = layout.GetRowWidth(); + + // allocate and null initialize the data + owned_data = make_unsafe_uniq_array(tuple_size * total_groups); + data = owned_data.get(); + + // set up the empty payloads for every tuple, and initialize the "occupied" flag to false + group_is_set = make_unsafe_uniq_array(total_groups); + memset(group_is_set.get(), 0, total_groups * sizeof(bool)); + + // initialize the hash table for each entry + auto address_data = FlatVector::GetData(addresses); + idx_t init_count = 0; + for (idx_t i = 0; i < total_groups; i++) { + address_data[init_count] = uintptr_t(data) + (tuple_size * i); + init_count++; + if (init_count == STANDARD_VECTOR_SIZE) { + RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count); + init_count = 0; + } + } + RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count); +} + +PerfectAggregateHashTable::~PerfectAggregateHashTable() { + Destroy(); +} + +template +static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value &min, uintptr_t *address_data, + idx_t current_shift, idx_t count) { + auto data = UnifiedVectorFormat::GetData(group_data); + auto min_val = min.GetValueUnsafe(); + if (!group_data.validity.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto index = group_data.sel->get_index(i); + // check if the value is NULL + // NULL groups are considered as "0" in the hash table + // that is to say, they have no effect on the position of the element (because 0 << shift is 0) + // we only need to handle non-null values here + if (group_data.validity.RowIsValid(index)) { + D_ASSERT(data[index] >= min_val); + uintptr_t adjusted_value = (data[index] - min_val) + 1; + address_data[i] += adjusted_value << current_shift; + } + } + } else { + // no null values: we can directly compute the addresses + for (idx_t i = 0; i < count; i++) { + auto index = group_data.sel->get_index(i); + uintptr_t adjusted_value = (data[index] - min_val) + 1; + address_data[i] += adjusted_value << current_shift; + } + } +} + +static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift, idx_t count) { + UnifiedVectorFormat vdata; + group.ToUnifiedFormat(count, vdata); + + switch (group.GetType().InternalType()) { + case PhysicalType::INT8: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + case PhysicalType::INT16: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + case PhysicalType::INT32: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + case PhysicalType::INT64: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + case PhysicalType::UINT8: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + case PhysicalType::UINT16: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + case PhysicalType::UINT32: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + case PhysicalType::UINT64: + ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); + break; + default: + throw InternalException("Unsupported group type for perfect aggregate hash table"); + } +} + +void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) { + // first we need to find the location in the HT of each of the groups + auto address_data = FlatVector::GetData(addresses); + // zero-initialize the address data + memset(address_data, 0, groups.size() * sizeof(uintptr_t)); + D_ASSERT(groups.ColumnCount() == group_minima.size()); + + // then compute the actual group location by iterating over each of the groups + idx_t current_shift = total_required_bits; + for (idx_t i = 0; i < groups.ColumnCount(); i++) { + current_shift -= required_bits[i]; + ComputeGroupLocation(groups.data[i], group_minima[i], address_data, current_shift, groups.size()); + } + // now we have the HT entry number for every tuple + // compute the actual pointer to the data by adding it to the base HT pointer and multiplying by the tuple size + for (idx_t i = 0; i < groups.size(); i++) { + const auto group = address_data[i]; + D_ASSERT(group < total_groups); + group_is_set[group] = true; + address_data[i] = uintptr_t(data) + group * tuple_size; + } + + // after finding the group location we update the aggregates + idx_t payload_idx = 0; + auto &aggregates = layout.GetAggregates(); + RowOperationsState row_state(*aggregate_allocator); + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + auto &aggregate = aggregates[aggr_idx]; + auto input_count = (idx_t)aggregate.child_count; + if (aggregate.filter) { + RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(aggr_idx), aggregate, addresses, + payload, payload_idx); + } else { + RowOperations::UpdateStates(row_state, aggregate, addresses, payload, payload_idx, payload.size()); + } + // move to the next aggregate + payload_idx += input_count; + VectorOperations::AddInPlace(addresses, aggregate.payload_size, payload.size()); + } +} + +void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) { + D_ASSERT(total_groups == other.total_groups); + D_ASSERT(tuple_size == other.tuple_size); + + Vector source_addresses(LogicalType::POINTER); + Vector target_addresses(LogicalType::POINTER); + auto source_addresses_ptr = FlatVector::GetData(source_addresses); + auto target_addresses_ptr = FlatVector::GetData(target_addresses); + + // iterate over all entries of both hash tables and call combine for all entries that can be combined + data_ptr_t source_ptr = other.data; + data_ptr_t target_ptr = data; + idx_t combine_count = 0; + RowOperationsState row_state(*aggregate_allocator); + for (idx_t i = 0; i < total_groups; i++) { + auto has_entry_source = other.group_is_set[i]; + // we only have any work to do if the source has an entry for this group + if (has_entry_source) { + group_is_set[i] = true; + source_addresses_ptr[combine_count] = source_ptr; + target_addresses_ptr[combine_count] = target_ptr; + combine_count++; + if (combine_count == STANDARD_VECTOR_SIZE) { + RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count); + combine_count = 0; + } + } + source_ptr += tuple_size; + target_ptr += tuple_size; + } + RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count); + + // FIXME: after moving the arena allocator, we currently have to ensure that the pointer is not nullptr, because the + // FIXME: Destroy()-function of the hash table expects an allocator in some cases (e.g., for sorted aggregates) + stored_allocators.push_back(std::move(other.aggregate_allocator)); + other.aggregate_allocator = make_uniq(allocator); +} + +template +static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, idx_t mask, idx_t shift, + idx_t entry_count, Vector &result) { + auto data = FlatVector::GetData(result); + auto &validity_mask = FlatVector::Validity(result); + auto min_data = min.GetValueUnsafe(); + for (idx_t i = 0; i < entry_count; i++) { + // extract the value of this group from the total group index + auto group_index = (group_values[i] >> shift) & mask; + if (group_index == 0) { + // if it is 0, the value is NULL + validity_mask.SetInvalid(i); + } else { + // otherwise we add the value (minus 1) to the min value + data[i] = min_data + group_index - 1; + } + } +} + +static void ReconstructGroupVector(uint32_t group_values[], Value &min, idx_t required_bits, idx_t shift, + idx_t entry_count, Vector &result) { + // construct the mask for this entry + idx_t mask = ((uint64_t)1 << required_bits) - 1; + switch (result.GetType().InternalType()) { + case PhysicalType::INT8: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + case PhysicalType::INT16: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + case PhysicalType::INT32: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + case PhysicalType::INT64: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + case PhysicalType::UINT8: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + case PhysicalType::UINT16: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + case PhysicalType::UINT32: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + case PhysicalType::UINT64: + ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); + break; + default: + throw InternalException("Invalid type for perfect aggregate HT group"); + } +} + +void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) { + auto data_pointers = FlatVector::GetData(addresses); + uint32_t group_values[STANDARD_VECTOR_SIZE]; + + // iterate over the HT until we either have exhausted the entire HT, or + idx_t entry_count = 0; + for (; scan_position < total_groups; scan_position++) { + if (group_is_set[scan_position]) { + // this group is set: add it to the set of groups to extract + data_pointers[entry_count] = data + tuple_size * scan_position; + group_values[entry_count] = scan_position; + entry_count++; + if (entry_count == STANDARD_VECTOR_SIZE) { + scan_position++; + break; + } + } + } + if (entry_count == 0) { + // no entries found + return; + } + // first reconstruct the groups from the group index + idx_t shift = total_required_bits; + for (idx_t i = 0; i < grouping_columns; i++) { + shift -= required_bits[i]; + ReconstructGroupVector(group_values, group_minima[i], required_bits[i], shift, entry_count, result.data[i]); + } + // then construct the payloads + result.SetCardinality(entry_count); + RowOperationsState row_state(*aggregate_allocator); + RowOperations::FinalizeStates(row_state, layout, addresses, result, grouping_columns); +} + +void PerfectAggregateHashTable::Destroy() { + // check if there is any destructor to call + bool has_destructor = false; + for (auto &aggr : layout.GetAggregates()) { + if (aggr.function.destructor) { + has_destructor = true; + } + } + if (!has_destructor) { + return; + } + // there are aggregates with destructors: loop over the hash table + // and call the destructor method for each of the aggregates + auto data_pointers = FlatVector::GetData(addresses); + idx_t count = 0; + + // iterate over all initialised slots of the hash table + RowOperationsState row_state(*aggregate_allocator); + data_ptr_t payload_ptr = data; + for (idx_t i = 0; i < total_groups; i++) { + if (group_is_set[i]) { + data_pointers[count++] = payload_ptr; + if (count == STANDARD_VECTOR_SIZE) { + RowOperations::DestroyStates(row_state, layout, addresses, count); + count = 0; + } + } + payload_ptr += tuple_size; + } + RowOperations::DestroyStates(row_state, layout, addresses, count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_operator.cpp b/src/duckdb/src/execution/physical_operator.cpp new file mode 100644 index 00000000..fae06215 --- /dev/null +++ b/src/duckdb/src/execution/physical_operator.cpp @@ -0,0 +1,305 @@ +#include "duckdb/execution/physical_operator.hpp" + +#include "duckdb/common/printer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +string PhysicalOperator::GetName() const { + return PhysicalOperatorToString(type); +} + +string PhysicalOperator::ToString() const { + TreeRenderer renderer; + return renderer.ToString(*this); +} + +// LCOV_EXCL_START +void PhysicalOperator::Print() const { + Printer::Print(ToString()); +} +// LCOV_EXCL_STOP + +vector> PhysicalOperator::GetChildren() const { + vector> result; + for (auto &child : children) { + result.push_back(*child); + } + return result; +} + +//===--------------------------------------------------------------------===// +// Operator +//===--------------------------------------------------------------------===// +// LCOV_EXCL_START +unique_ptr PhysicalOperator::GetOperatorState(ExecutionContext &context) const { + return make_uniq(); +} + +unique_ptr PhysicalOperator::GetGlobalOperatorState(ClientContext &context) const { + return make_uniq(); +} + +OperatorResultType PhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const { + throw InternalException("Calling Execute on a node that is not an operator!"); +} + +OperatorFinalizeResultType PhysicalOperator::FinalExecute(ExecutionContext &context, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const { + throw InternalException("Calling FinalExecute on a node that is not an operator!"); +} +// LCOV_EXCL_STOP + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +unique_ptr PhysicalOperator::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(); +} + +unique_ptr PhysicalOperator::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(); +} + +// LCOV_EXCL_START +SourceResultType PhysicalOperator::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + throw InternalException("Calling GetData on a node that is not a source!"); +} + +idx_t PhysicalOperator::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, + LocalSourceState &lstate) const { + throw InternalException("Calling GetBatchIndex on a node that does not support it"); +} + +double PhysicalOperator::GetProgress(ClientContext &context, GlobalSourceState &gstate) const { + return -1; +} +// LCOV_EXCL_STOP + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +// LCOV_EXCL_START +SinkResultType PhysicalOperator::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + throw InternalException("Calling Sink on a node that is not a sink!"); +} + +// LCOV_EXCL_STOP + +SinkCombineResultType PhysicalOperator::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { + return SinkCombineResultType::FINISHED; +} + +SinkFinalizeType PhysicalOperator::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + return SinkFinalizeType::READY; +} + +void PhysicalOperator::NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const { +} + +unique_ptr PhysicalOperator::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(); +} + +unique_ptr PhysicalOperator::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(); +} + +idx_t PhysicalOperator::GetMaxThreadMemory(ClientContext &context) { + // Memory usage per thread should scale with max mem / num threads + // We take 1/4th of this, to be conservative + idx_t max_memory = BufferManager::GetBufferManager(context).GetMaxMemory(); + idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + return (max_memory / num_threads) / 4; +} + +bool PhysicalOperator::OperatorCachingAllowed(ExecutionContext &context) { + if (!context.client.config.enable_caching_operators) { + return false; + } else if (!context.pipeline) { + return false; + } else if (!context.pipeline->GetSink()) { + return false; + } else if (context.pipeline->GetSink()->RequiresBatchIndex()) { + return false; + } else if (context.pipeline->IsOrderDependent()) { + return false; + } + + return true; +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalOperator::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + op_state.reset(); + + auto &state = meta_pipeline.GetState(); + if (IsSink()) { + // operator is a sink, build a pipeline + sink_state.reset(); + D_ASSERT(children.size() == 1); + + // single operator: the operator becomes the data source of the current pipeline + state.SetPipelineSource(current, *this); + + // we create a new pipeline starting from the child + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + child_meta_pipeline.Build(*children[0]); + } else { + // operator is not a sink! recurse in children + if (children.empty()) { + // source + state.SetPipelineSource(current, *this); + } else { + if (children.size() != 1) { + throw InternalException("Operator not supported in BuildPipelines"); + } + state.AddPipelineOperator(current, *this); + children[0]->BuildPipelines(current, meta_pipeline); + } + } +} + +vector> PhysicalOperator::GetSources() const { + vector> result; + if (IsSink()) { + D_ASSERT(children.size() == 1); + result.push_back(*this); + return result; + } else { + if (children.empty()) { + // source + result.push_back(*this); + return result; + } else { + if (children.size() != 1) { + throw InternalException("Operator not supported in GetSource"); + } + return children[0]->GetSources(); + } + } +} + +bool PhysicalOperator::AllSourcesSupportBatchIndex() const { + auto sources = GetSources(); + for (auto &source : sources) { + if (!source.get().SupportsBatchIndex()) { + return false; + } + } + return true; +} + +void PhysicalOperator::Verify() { +#ifdef DEBUG + auto sources = GetSources(); + D_ASSERT(!sources.empty()); + for (auto &child : children) { + child->Verify(); + } +#endif +} + +bool CachingPhysicalOperator::CanCacheType(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::LIST: + case LogicalTypeId::MAP: + return false; + case LogicalTypeId::STRUCT: { + auto &entries = StructType::GetChildTypes(type); + for (auto &entry : entries) { + if (!CanCacheType(entry.second)) { + return false; + } + } + return true; + } + default: + return true; + } +} + +CachingPhysicalOperator::CachingPhysicalOperator(PhysicalOperatorType type, vector types_p, + idx_t estimated_cardinality) + : PhysicalOperator(type, std::move(types_p), estimated_cardinality) { + + caching_supported = true; + for (auto &col_type : types) { + if (!CanCacheType(col_type)) { + caching_supported = false; + break; + } + } +} + +OperatorResultType CachingPhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state_p) const { + auto &state = state_p.Cast(); + + // Execute child operator + auto child_result = ExecuteInternal(context, input, chunk, gstate, state); + +#if STANDARD_VECTOR_SIZE >= 128 + if (!state.initialized) { + state.initialized = true; + state.can_cache_chunk = caching_supported && PhysicalOperator::OperatorCachingAllowed(context); + } + if (!state.can_cache_chunk) { + return child_result; + } + // TODO chunk size of 0 should not result in a cache being created! + if (chunk.size() < CACHE_THRESHOLD) { + // we have filtered out a significant amount of tuples + // add this chunk to the cache and continue + + if (!state.cached_chunk) { + state.cached_chunk = make_uniq(); + state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); + } + + state.cached_chunk->Append(chunk); + + if (state.cached_chunk->size() >= (STANDARD_VECTOR_SIZE - CACHE_THRESHOLD) || + child_result == OperatorResultType::FINISHED) { + // chunk cache full: return it + chunk.Move(*state.cached_chunk); + state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); + return child_result; + } else { + // chunk cache not full return empty result + chunk.Reset(); + } + } +#endif + + return child_result; +} + +OperatorFinalizeResultType CachingPhysicalOperator::FinalExecute(ExecutionContext &context, DataChunk &chunk, + GlobalOperatorState &gstate, + OperatorState &state_p) const { + auto &state = state_p.Cast(); + if (state.cached_chunk) { + chunk.Move(*state.cached_chunk); + state.cached_chunk.reset(); + } else { + chunk.SetCardinality(0); + } + return OperatorFinalizeResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp b/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp new file mode 100644 index 00000000..4160783e --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp @@ -0,0 +1,243 @@ +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp" +#include "duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" + +namespace duckdb { + +static uint32_t RequiredBitsForValue(uint32_t n) { + idx_t required_bits = 0; + while (n > 0) { + n >>= 1; + required_bits++; + } + return required_bits; +} + +template +hugeint_t GetRangeHugeint(const BaseStatistics &nstats) { + return Hugeint::Convert(NumericStats::GetMax(nstats)) - Hugeint::Convert(NumericStats::GetMin(nstats)); +} + +static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate &op, vector &bits_per_group) { + if (op.grouping_sets.size() > 1 || !op.grouping_functions.empty()) { + return false; + } + idx_t perfect_hash_bits = 0; + if (op.group_stats.empty()) { + op.group_stats.resize(op.groups.size()); + } + for (idx_t group_idx = 0; group_idx < op.groups.size(); group_idx++) { + auto &group = op.groups[group_idx]; + auto &stats = op.group_stats[group_idx]; + + switch (group->return_type.InternalType()) { + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + break; + default: + // we only support simple integer types for perfect hashing + return false; + } + // check if the group has stats available + auto &group_type = group->return_type; + if (!stats) { + // no stats, but we might still be able to use perfect hashing if the type is small enough + // for small types we can just set the stats to [type_min, type_max] + switch (group_type.InternalType()) { + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + break; + default: + // type is too large and there are no stats: skip perfect hashing + return false; + } + // construct stats with the min and max value of the type + stats = NumericStats::CreateUnknown(group_type).ToUnique(); + NumericStats::SetMin(*stats, Value::MinimumValue(group_type)); + NumericStats::SetMax(*stats, Value::MaximumValue(group_type)); + } + auto &nstats = *stats; + + if (!NumericStats::HasMinMax(nstats)) { + return false; + } + + if (NumericStats::Max(*stats) < NumericStats::Min(*stats)) { + // May result in underflow + return false; + } + + // we have a min and a max value for the stats: use that to figure out how many bits we have + // we add two here, one for the NULL value, and one to make the computation one-indexed + // (e.g. if min and max are the same, we still need one entry in total) + hugeint_t range_h; + switch (group_type.InternalType()) { + case PhysicalType::INT8: + range_h = GetRangeHugeint(nstats); + break; + case PhysicalType::INT16: + range_h = GetRangeHugeint(nstats); + break; + case PhysicalType::INT32: + range_h = GetRangeHugeint(nstats); + break; + case PhysicalType::INT64: + range_h = GetRangeHugeint(nstats); + break; + case PhysicalType::UINT8: + range_h = GetRangeHugeint(nstats); + break; + case PhysicalType::UINT16: + range_h = GetRangeHugeint(nstats); + break; + case PhysicalType::UINT32: + range_h = GetRangeHugeint(nstats); + break; + case PhysicalType::UINT64: + range_h = GetRangeHugeint(nstats); + break; + default: + throw InternalException("Unsupported type for perfect hash (should be caught before)"); + } + + uint64_t range; + if (!Hugeint::TryCast(range_h, range)) { + return false; + } + + // bail out on any range bigger than 2^32 + if (range >= NumericLimits::Maximum()) { + return false; + } + + range += 2; + // figure out how many bits we need + idx_t required_bits = RequiredBitsForValue(range); + bits_per_group.push_back(required_bits); + perfect_hash_bits += required_bits; + // check if we have exceeded the bits for the hash + if (perfect_hash_bits > ClientConfig::GetConfig(context).perfect_ht_threshold) { + // too many bits for perfect hash + return false; + } + } + for (auto &expression : op.expressions) { + auto &aggregate = expression->Cast(); + if (aggregate.IsDistinct() || !aggregate.function.combine) { + // distinct aggregates are not supported in perfect hash aggregates + return false; + } + } + return true; +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) { + unique_ptr groupby; + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); + + plan = ExtractAggregateExpressions(std::move(plan), op.expressions, op.groups); + + if (op.groups.empty() && op.grouping_sets.size() <= 1) { + // no groups, check if we can use a simple aggregation + // special case: aggregate entire columns together + bool use_simple_aggregation = true; + for (auto &expression : op.expressions) { + auto &aggregate = expression->Cast(); + if (!aggregate.function.simple_update) { + // unsupported aggregate for simple aggregation: use hash aggregation + use_simple_aggregation = false; + break; + } + } + if (use_simple_aggregation) { + groupby = make_uniq_base(op.types, std::move(op.expressions), + op.estimated_cardinality); + } else { + groupby = make_uniq_base( + context, op.types, std::move(op.expressions), op.estimated_cardinality); + } + } else { + // groups! create a GROUP BY aggregator + // use a perfect hash aggregate if possible + vector required_bits; + if (CanUsePerfectHashAggregate(context, op, required_bits)) { + groupby = make_uniq_base( + context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.group_stats), + std::move(required_bits), op.estimated_cardinality); + } else { + groupby = make_uniq_base( + context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.grouping_sets), + std::move(op.grouping_functions), op.estimated_cardinality); + } + } + groupby->children.push_back(std::move(plan)); + return groupby; +} + +unique_ptr +PhysicalPlanGenerator::ExtractAggregateExpressions(unique_ptr child, + vector> &aggregates, + vector> &groups) { + vector> expressions; + vector types; + + // bind sorted aggregates + for (auto &aggr : aggregates) { + auto &bound_aggr = aggr->Cast(); + if (bound_aggr.order_bys) { + // sorted aggregate! + FunctionBinder::BindSortedAggregate(context, bound_aggr, groups); + } + } + for (auto &group : groups) { + auto ref = make_uniq(group->return_type, expressions.size()); + types.push_back(group->return_type); + expressions.push_back(std::move(group)); + group = std::move(ref); + } + for (auto &aggr : aggregates) { + auto &bound_aggr = aggr->Cast(); + for (auto &child : bound_aggr.children) { + auto ref = make_uniq(child->return_type, expressions.size()); + types.push_back(child->return_type); + expressions.push_back(std::move(child)); + child = std::move(ref); + } + if (bound_aggr.filter) { + auto &filter = bound_aggr.filter; + auto ref = make_uniq(filter->return_type, expressions.size()); + types.push_back(filter->return_type); + expressions.push_back(std::move(filter)); + bound_aggr.filter = std::move(ref); + } + } + if (expressions.empty()) { + return child; + } + auto projection = + make_uniq(std::move(types), std::move(expressions), child->estimated_cardinality); + projection->children.push_back(std::move(child)); + return std::move(projection); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_any_join.cpp b/src/duckdb/src/execution/physical_plan/plan_any_join.cpp new file mode 100644 index 00000000..5e8ee622 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_any_join.cpp @@ -0,0 +1,20 @@ +#include "duckdb/execution/operator/join/physical_blockwise_nl_join.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_any_join.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalAnyJoin &op) { + // first visit the child nodes + D_ASSERT(op.children.size() == 2); + D_ASSERT(op.condition); + + auto left = CreatePlan(*op.children[0]); + auto right = CreatePlan(*op.children[1]); + + // create the blockwise NL join + return make_uniq(op, std::move(left), std::move(right), std::move(op.condition), + op.join_type, op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp new file mode 100644 index 00000000..927defa4 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp @@ -0,0 +1,125 @@ +#include "duckdb/execution/operator/aggregate/physical_window.hpp" +#include "duckdb/execution/operator/join/physical_asof_join.hpp" +#include "duckdb/execution/operator/join/physical_iejoin.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) { + // now visit the children + D_ASSERT(op.children.size() == 2); + idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context); + idx_t rhs_cardinality = op.children[1]->EstimateCardinality(context); + auto left = CreatePlan(*op.children[0]); + auto right = CreatePlan(*op.children[1]); + D_ASSERT(left && right); + + // Validate + vector equi_indexes; + auto asof_idx = op.conditions.size(); + for (size_t c = 0; c < op.conditions.size(); ++c) { + auto &cond = op.conditions[c]; + switch (cond.comparison) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + equi_indexes.emplace_back(c); + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_LESSTHAN: + D_ASSERT(asof_idx == op.conditions.size()); + asof_idx = c; + break; + default: + throw InternalException("Invalid ASOF JOIN comparison"); + } + } + D_ASSERT(asof_idx < op.conditions.size()); + + if (!ClientConfig::GetConfig(context).force_asof_iejoin) { + return make_uniq(op, std::move(left), std::move(right)); + } + + // Strip extra column from rhs projections + auto &right_projection_map = op.right_projection_map; + if (right_projection_map.empty()) { + const auto right_count = right->types.size(); + right_projection_map.reserve(right_count); + for (column_t i = 0; i < right_count; ++i) { + right_projection_map.emplace_back(i); + } + } + + // Debug implementation: IEJoin of Window + // LEAD(asof_column, 1, infinity) OVER (PARTITION BY equi_column... ORDER BY asof_column) AS asof_end + auto &asof_comp = op.conditions[asof_idx]; + auto &asof_column = asof_comp.right; + auto asof_type = asof_column->return_type; + auto asof_end = make_uniq(ExpressionType::WINDOW_LEAD, asof_type, nullptr, nullptr); + asof_end->children.emplace_back(asof_column->Copy()); + // TODO: If infinities are not supported for a type, fake them by looking at LHS statistics? + asof_end->offset_expr = make_uniq(Value::BIGINT(1)); + for (auto equi_idx : equi_indexes) { + asof_end->partitions.emplace_back(op.conditions[equi_idx].right->Copy()); + } + switch (asof_comp.comparison) { + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + asof_end->orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, asof_column->Copy()); + asof_end->default_expr = make_uniq(Value::Infinity(asof_type)); + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_LESSTHAN: + asof_end->orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_FIRST, asof_column->Copy()); + asof_end->default_expr = make_uniq(Value::NegativeInfinity(asof_type)); + break; + default: + throw InternalException("Invalid ASOF JOIN ordering for WINDOW"); + } + + asof_end->start = WindowBoundary::UNBOUNDED_PRECEDING; + asof_end->end = WindowBoundary::CURRENT_ROW_ROWS; + + vector> window_select; + window_select.emplace_back(std::move(asof_end)); + + auto &window_types = op.children[1]->types; + window_types.emplace_back(asof_type); + + auto window = make_uniq(window_types, std::move(window_select), rhs_cardinality); + window->children.emplace_back(std::move(right)); + + // IEJoin(left, window, conditions || asof_comp ~op asof_end) + JoinCondition asof_upper; + asof_upper.left = asof_comp.left->Copy(); + asof_upper.right = make_uniq(asof_type, window_types.size() - 1); + switch (asof_comp.comparison) { + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + asof_upper.comparison = ExpressionType::COMPARE_LESSTHAN; + break; + case ExpressionType::COMPARE_GREATERTHAN: + asof_upper.comparison = ExpressionType::COMPARE_LESSTHANOREQUALTO; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + asof_upper.comparison = ExpressionType::COMPARE_GREATERTHAN; + break; + case ExpressionType::COMPARE_LESSTHAN: + asof_upper.comparison = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + break; + default: + throw InternalException("Invalid ASOF JOIN comparison for IEJoin"); + } + + op.conditions.emplace_back(std::move(asof_upper)); + + return make_uniq(op, std::move(left), std::move(window), std::move(op.conditions), op.join_type, + lhs_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_column_data_get.cpp b/src/duckdb/src/execution/physical_plan/plan_column_data_get.cpp new file mode 100644 index 00000000..49e23c6f --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_column_data_get.cpp @@ -0,0 +1,17 @@ +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_column_data_get.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalColumnDataGet &op) { + D_ASSERT(op.children.size() == 0); + D_ASSERT(op.collection); + + // create a PhysicalChunkScan pointing towards the owned collection + auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, + op.estimated_cardinality, std::move(op.collection)); + return std::move(chunk_scan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp new file mode 100644 index 00000000..f03b4915 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp @@ -0,0 +1,357 @@ +#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" +#include "duckdb/execution/operator/join/physical_cross_product.hpp" +#include "duckdb/execution/operator/join/physical_hash_join.hpp" +#include "duckdb/execution/operator/join/physical_iejoin.hpp" +#include "duckdb/execution/operator/join/physical_index_join.hpp" +#include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" +#include "duckdb/execution/operator/join/physical_piecewise_merge_join.hpp" +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/execution/operator/join/physical_blockwise_nl_join.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" + +namespace duckdb { + +static bool CanPlanIndexJoin(ClientContext &context, TableScanBindData &bind_data, PhysicalTableScan &scan) { + auto &table = bind_data.table; + auto &transaction = DuckTransaction::Get(context, table.catalog); + auto &local_storage = LocalStorage::Get(transaction); + if (local_storage.Find(table.GetStorage())) { + // transaction local appends: skip index join + return false; + } + if (scan.table_filters && !scan.table_filters->filters.empty()) { + // table scan filters + return false; + } + return true; +} + +bool ExtractNumericValue(Value val, int64_t &result) { + if (!val.type().IsIntegral()) { + switch (val.type().InternalType()) { + case PhysicalType::INT16: + result = val.GetValueUnsafe(); + break; + case PhysicalType::INT32: + result = val.GetValueUnsafe(); + break; + case PhysicalType::INT64: + result = val.GetValueUnsafe(); + break; + default: + return false; + } + } else { + if (!val.DefaultTryCastAs(LogicalType::BIGINT)) { + return false; + } + result = val.GetValue(); + } + return true; +} + +void CheckForPerfectJoinOpt(LogicalComparisonJoin &op, PerfectHashJoinStats &join_state) { + // we only do this optimization for inner joins + if (op.join_type != JoinType::INNER) { + return; + } + // with one condition + if (op.conditions.size() != 1) { + return; + } + // with propagated statistics + if (op.join_stats.empty()) { + return; + } + for (auto &type : op.children[1]->types) { + switch (type.InternalType()) { + case PhysicalType::STRUCT: + case PhysicalType::LIST: + return; + default: + break; + } + } + // with equality condition and null values not equal + for (auto &&condition : op.conditions) { + if (condition.comparison != ExpressionType::COMPARE_EQUAL) { + return; + } + } + // with integral internal types + for (auto &&join_stat : op.join_stats) { + if (!TypeIsInteger(join_stat->GetType().InternalType()) || + join_stat->GetType().InternalType() == PhysicalType::INT128) { + // perfect join not possible for non-integral types or hugeint + return; + } + } + + // and when the build range is smaller than the threshold + auto &stats_build = *op.join_stats[0].get(); // lhs stats + if (!NumericStats::HasMinMax(stats_build)) { + return; + } + int64_t min_value, max_value; + if (!ExtractNumericValue(NumericStats::Min(stats_build), min_value) || + !ExtractNumericValue(NumericStats::Max(stats_build), max_value)) { + return; + } + int64_t build_range; + if (!TrySubtractOperator::Operation(max_value, min_value, build_range)) { + return; + } + + // Fill join_stats for invisible join + auto &stats_probe = *op.join_stats[1].get(); // rhs stats + if (!NumericStats::HasMinMax(stats_probe)) { + return; + } + + // The max size our build must have to run the perfect HJ + const idx_t MAX_BUILD_SIZE = 1000000; + join_state.probe_min = NumericStats::Min(stats_probe); + join_state.probe_max = NumericStats::Max(stats_probe); + join_state.build_min = NumericStats::Min(stats_build); + join_state.build_max = NumericStats::Max(stats_build); + join_state.estimated_cardinality = op.estimated_cardinality; + join_state.build_range = build_range; + if (join_state.build_range > MAX_BUILD_SIZE) { + return; + } + if (NumericStats::Min(stats_build) <= NumericStats::Min(stats_probe) && + NumericStats::Max(stats_probe) <= NumericStats::Max(stats_build)) { + join_state.is_probe_in_domain = true; + } + join_state.is_build_small = true; + return; +} + +static optional_ptr CanUseIndexJoin(TableScanBindData &tbl, Expression &expr) { + optional_ptr result; + tbl.table.GetStorage().info->indexes.Scan([&](Index &index) { + if (index.unbound_expressions.size() != 1) { + return false; + } + if (expr.alias == index.unbound_expressions[0]->alias) { + result = &index; + return true; + } + return false; + }); + return result; +} + +optional_ptr CheckIndexJoin(ClientContext &context, LogicalComparisonJoin &op, PhysicalOperator &plan, + Expression &condition) { + if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + return nullptr; + } + // check if one of the tables has an index on column + if (op.join_type != JoinType::INNER) { + return nullptr; + } + if (op.conditions.size() != 1) { + return nullptr; + } + // check if the child is (1) a table scan, and (2) has an index on the join condition + if (plan.type != PhysicalOperatorType::TABLE_SCAN) { + return nullptr; + } + auto &tbl_scan = plan.Cast(); + auto tbl_data = dynamic_cast(tbl_scan.bind_data.get()); + if (!tbl_data) { + return nullptr; + } + optional_ptr result; + if (CanPlanIndexJoin(context, *tbl_data, tbl_scan)) { + result = CanUseIndexJoin(*tbl_data, condition); + } + return result; +} + +static bool PlanIndexJoin(ClientContext &context, LogicalComparisonJoin &op, unique_ptr &plan, + unique_ptr &left, unique_ptr &right, + optional_ptr index, bool swap_condition = false) { + if (!index) { + return false; + } + + // index joins are disabled if enable_optimizer is false + if (!ClientConfig::GetConfig(context).enable_optimizer) { + return false; + } + + // index joins are disabled on default + auto force_index_join = ClientConfig::GetConfig(context).force_index_join; + if (!ClientConfig::GetConfig(context).enable_index_join && !force_index_join) { + return false; + } + + // check if the cardinality difference justifies an index join + auto index_join_is_applicable = left->estimated_cardinality < 0.01 * right->estimated_cardinality; + if (!index_join_is_applicable && !force_index_join) { + return false; + } + + // plan the index join + if (swap_condition) { + swap(op.conditions[0].left, op.conditions[0].right); + swap(op.left_projection_map, op.right_projection_map); + } + D_ASSERT(right->type == PhysicalOperatorType::TABLE_SCAN); + auto &tbl_scan = right->Cast(); + + plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), op.join_type, + op.left_projection_map, op.right_projection_map, tbl_scan.column_ids, *index, + !swap_condition, op.estimated_cardinality); + return true; +} + +static bool PlanIndexJoin(ClientContext &context, LogicalComparisonJoin &op, unique_ptr &plan, + unique_ptr &left, unique_ptr &right) { + if (op.conditions.empty()) { + return false; + } + // check if we can plan an index join on the RHS + auto right_index = CheckIndexJoin(context, op, *right, *op.conditions[0].right); + if (PlanIndexJoin(context, op, plan, left, right, right_index)) { + return true; + } + // else check if we can plan an index join on the left side + auto left_index = CheckIndexJoin(context, op, *left, *op.conditions[0].left); + if (PlanIndexJoin(context, op, plan, right, left, left_index, true)) { + return true; + } + return false; +} + +static void RewriteJoinCondition(Expression &expr, idx_t offset) { + if (expr.type == ExpressionType::BOUND_REF) { + auto &ref = expr.Cast(); + ref.index += offset; + } + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { RewriteJoinCondition(child, offset); }); +} + +unique_ptr PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoin &op) { + // now visit the children + D_ASSERT(op.children.size() == 2); + idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context); + idx_t rhs_cardinality = op.children[1]->EstimateCardinality(context); + auto left = CreatePlan(*op.children[0]); + auto right = CreatePlan(*op.children[1]); + left->estimated_cardinality = lhs_cardinality; + right->estimated_cardinality = rhs_cardinality; + D_ASSERT(left && right); + + if (op.conditions.empty()) { + // no conditions: insert a cross product + return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); + } + + bool has_equality = false; + size_t has_range = 0; + for (size_t c = 0; c < op.conditions.size(); ++c) { + auto &cond = op.conditions[c]; + switch (cond.comparison) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + has_equality = true; + break; + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + ++has_range; + break; + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_DISTINCT_FROM: + break; + default: + throw NotImplementedException("Unimplemented comparison join"); + } + } + + bool can_merge = has_range > 0; + bool can_iejoin = has_range >= 2 && recursive_cte_tables.empty(); + switch (op.join_type) { + case JoinType::SEMI: + case JoinType::ANTI: + case JoinType::MARK: + can_merge = can_merge && op.conditions.size() == 1; + can_iejoin = false; + break; + default: + break; + } + + // TODO: Extend PWMJ to handle all comparisons and projection maps + const auto prefer_range_joins = (ClientConfig::GetConfig(context).prefer_range_joins && can_iejoin); + + unique_ptr plan; + if (has_equality && !prefer_range_joins) { + // check if we can use an index join + if (PlanIndexJoin(context, op, plan, left, right)) { + return plan; + } + // Equality join with small number of keys : possible perfect join optimization + PerfectHashJoinStats perfect_join_stats; + CheckForPerfectJoinOpt(op, perfect_join_stats); + plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), + op.join_type, op.left_projection_map, op.right_projection_map, + std::move(op.mark_types), op.estimated_cardinality, perfect_join_stats); + + } else { + static constexpr const idx_t NESTED_LOOP_JOIN_THRESHOLD = 5; + if (left->estimated_cardinality <= NESTED_LOOP_JOIN_THRESHOLD || + right->estimated_cardinality <= NESTED_LOOP_JOIN_THRESHOLD) { + can_iejoin = false; + can_merge = false; + } + if (can_iejoin) { + plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), + op.join_type, op.estimated_cardinality); + } else if (can_merge) { + // range join: use piecewise merge join + plan = + make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), + op.join_type, op.estimated_cardinality); + } else if (PhysicalNestedLoopJoin::IsSupported(op.conditions, op.join_type)) { + // inequality join: use nested loop + plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), + op.join_type, op.estimated_cardinality); + } else { + for (auto &cond : op.conditions) { + RewriteJoinCondition(*cond.right, left->types.size()); + } + auto condition = JoinCondition::CreateExpression(std::move(op.conditions)); + plan = make_uniq(op, std::move(left), std::move(right), std::move(condition), + op.join_type, op.estimated_cardinality); + } + } + return plan; +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalComparisonJoin &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + return PlanAsOfJoin(op); + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + return PlanComparisonJoin(op); + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + return PlanDelimJoin(op); + default: + throw InternalException("Unrecognized operator type for LogicalComparisonJoin"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp b/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp new file mode 100644 index 00000000..0fe22751 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp @@ -0,0 +1,65 @@ +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" +#include "duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp" +#include "duckdb/execution/operator/persistent/physical_fixed_batch_copy.hpp" +#include "duckdb/planner/operator/logical_copy_to_file.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile &op) { + auto plan = CreatePlan(*op.children[0]); + bool preserve_insertion_order = PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); + bool supports_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); + auto &fs = FileSystem::GetFileSystem(context); + op.file_path = fs.ExpandPath(op.file_path); + if (op.use_tmp_file) { + op.file_path += ".tmp"; + } + if (op.per_thread_output || op.partition_output || !op.partition_columns.empty() || op.overwrite_or_ignore) { + // hive-partitioning/per-thread output does not care about insertion order, and does not support batch indexes + preserve_insertion_order = false; + supports_batch_index = false; + } + auto mode = CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE; + if (op.function.execution_mode) { + mode = op.function.execution_mode(preserve_insertion_order, supports_batch_index); + } + if (mode == CopyFunctionExecutionMode::BATCH_COPY_TO_FILE) { + if (!supports_batch_index) { + throw InternalException("BATCH_COPY_TO_FILE can only be used if batch indexes are supported"); + } + // batched copy to file + if (op.function.desired_batch_size) { + auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), + op.estimated_cardinality); + copy->file_path = op.file_path; + copy->use_tmp_file = op.use_tmp_file; + copy->children.push_back(std::move(plan)); + return std::move(copy); + } else { + auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), + op.estimated_cardinality); + copy->file_path = op.file_path; + copy->use_tmp_file = op.use_tmp_file; + copy->children.push_back(std::move(plan)); + return std::move(copy); + } + } + // COPY from select statement to file + auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), op.estimated_cardinality); + copy->file_path = op.file_path; + copy->use_tmp_file = op.use_tmp_file; + copy->overwrite_or_ignore = op.overwrite_or_ignore; + copy->filename_pattern = op.filename_pattern; + copy->per_thread_output = op.per_thread_output; + copy->partition_output = op.partition_output; + copy->partition_columns = op.partition_columns; + copy->names = op.names; + copy->expected_types = op.expected_types; + copy->parallel = mode == CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; + + copy->children.push_back(std::move(plan)); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_create.cpp b/src/duckdb/src/execution/physical_plan/plan_create.cpp new file mode 100644 index 00000000..42a1652e --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_create.cpp @@ -0,0 +1,42 @@ +#include "duckdb/execution/operator/schema/physical_create_function.hpp" +#include "duckdb/execution/operator/schema/physical_create_schema.hpp" +#include "duckdb/execution/operator/schema/physical_create_sequence.hpp" +#include "duckdb/execution/operator/schema/physical_create_type.hpp" +#include "duckdb/execution/operator/schema/physical_create_view.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/operator/logical_create.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreate &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_CREATE_VIEW: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_CREATE_MACRO: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_CREATE_TYPE: { + unique_ptr create = make_uniq( + unique_ptr_cast(std::move(op.info)), op.estimated_cardinality); + if (!op.children.empty()) { + D_ASSERT(op.children.size() == 1); + auto plan = CreatePlan(*op.children[0]); + create->children.push_back(std::move(plan)); + } + return create; + } + default: + throw NotImplementedException("Unimplemented type for logical simple create"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_create_index.cpp b/src/duckdb/src/execution/physical_plan/plan_create_index.cpp new file mode 100644 index 00000000..b8923d21 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_create_index.cpp @@ -0,0 +1,120 @@ +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/operator/filter/physical_filter.hpp" +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" +#include "duckdb/execution/operator/schema/physical_create_art_index.hpp" +#include "duckdb/execution/operator/order/physical_order.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_create_index.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/table_filter.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreateIndex &op) { + // generate a physical plan for the parallel index creation which consists of the following operators + // table scan - projection (for expression execution) - filter (NOT NULL) - order - create index + D_ASSERT(op.children.size() == 1); + auto table_scan = CreatePlan(*op.children[0]); + + // validate that all expressions contain valid scalar functions + // e.g. get_current_timestamp(), random(), and sequence values are not allowed as index keys + // because they make deletions and lookups unfeasible + for (idx_t i = 0; i < op.unbound_expressions.size(); i++) { + auto &expr = op.unbound_expressions[i]; + if (expr->HasSideEffects()) { + throw BinderException("Index keys cannot contain expressions with side " + "effects."); + } + } + + // If we get here without the plan and the index type is not ART, we throw an exception + // because we don't support any other index type yet. However an operator extension could have + // replaced this part of the plan with a different index creation operator. + if (op.info->index_type != IndexType::ART) { + throw BinderException("Index type not supported"); + } + + // table scan operator for index key columns and row IDs + dependencies.AddDependency(op.table); + + D_ASSERT(op.info->scan_types.size() - 1 <= op.info->names.size()); + D_ASSERT(op.info->scan_types.size() - 1 <= op.info->column_ids.size()); + + // projection to execute expressions on the key columns + + vector new_column_types; + vector> select_list; + for (idx_t i = 0; i < op.expressions.size(); i++) { + new_column_types.push_back(op.expressions[i]->return_type); + select_list.push_back(std::move(op.expressions[i])); + } + new_column_types.emplace_back(LogicalType::ROW_TYPE); + select_list.push_back(make_uniq(LogicalType::ROW_TYPE, op.info->scan_types.size() - 1)); + + auto projection = make_uniq(new_column_types, std::move(select_list), op.estimated_cardinality); + projection->children.push_back(std::move(table_scan)); + + // filter operator for IS_NOT_NULL on each key column + + vector filter_types; + vector> filter_select_list; + + for (idx_t i = 0; i < new_column_types.size() - 1; i++) { + filter_types.push_back(new_column_types[i]); + auto is_not_null_expr = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); + auto bound_ref = make_uniq(new_column_types[i], i); + is_not_null_expr->children.push_back(std::move(bound_ref)); + filter_select_list.push_back(std::move(is_not_null_expr)); + } + + auto null_filter = + make_uniq(std::move(filter_types), std::move(filter_select_list), op.estimated_cardinality); + null_filter->types.emplace_back(LogicalType::ROW_TYPE); + null_filter->children.push_back(std::move(projection)); + + // determine if we sort the data prior to index creation + // we don't sort, if either VARCHAR or compound key + auto perform_sorting = true; + if (op.unbound_expressions.size() > 1) { + perform_sorting = false; + } else if (op.unbound_expressions[0]->return_type.InternalType() == PhysicalType::VARCHAR) { + perform_sorting = false; + } + + // actual physical create index operator + + auto physical_create_index = + make_uniq(op, op.table, op.info->column_ids, std::move(op.info), + std::move(op.unbound_expressions), op.estimated_cardinality, perform_sorting); + + if (perform_sorting) { + + // optional order operator + vector orders; + vector projections; + for (idx_t i = 0; i < new_column_types.size() - 1; i++) { + auto col_expr = make_uniq_base(new_column_types[i], i); + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(col_expr)); + projections.emplace_back(i); + } + projections.emplace_back(new_column_types.size() - 1); + + auto physical_order = make_uniq(new_column_types, std::move(orders), std::move(projections), + op.estimated_cardinality); + physical_order->children.push_back(std::move(null_filter)); + + physical_create_index->children.push_back(std::move(physical_order)); + } else { + + // no ordering + physical_create_index->children.push_back(std::move(null_filter)); + } + + return std::move(physical_create_index); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_create_table.cpp b/src/duckdb/src/execution/physical_plan/plan_create_table.cpp new file mode 100644 index 00000000..c5f80933 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_create_table.cpp @@ -0,0 +1,50 @@ +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/execution/operator/schema/physical_create_table.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/execution/operator/persistent/physical_insert.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/operator/logical_create_table.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp" +#include "duckdb/planner/constraints/bound_check_constraint.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/catalog/duck_catalog.hpp" + +namespace duckdb { + +unique_ptr DuckCatalog::PlanCreateTableAs(ClientContext &context, LogicalCreateTable &op, + unique_ptr plan) { + bool parallel_streaming_insert = !PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); + bool use_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); + auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + unique_ptr create; + if (!parallel_streaming_insert && use_batch_index) { + create = make_uniq(op, op.schema, std::move(op.info), op.estimated_cardinality); + + } else { + create = make_uniq(op, op.schema, std::move(op.info), op.estimated_cardinality, + parallel_streaming_insert && num_threads > 1); + } + + D_ASSERT(op.children.size() == 1); + create->children.push_back(std::move(plan)); + return create; +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreateTable &op) { + const auto &create_info = op.info->base->Cast(); + auto &catalog = op.info->schema.catalog; + auto existing_entry = catalog.GetEntry(context, create_info.schema, create_info.table, + OnEntryNotFound::RETURN_NULL); + bool replace = op.info->Base().on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT; + if ((!existing_entry || replace) && !op.children.empty()) { + auto plan = CreatePlan(*op.children[0]); + return op.schema.catalog.PlanCreateTableAs(context, op, std::move(plan)); + } else { + return make_uniq(op, op.schema, std::move(op.info), op.estimated_cardinality); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_cross_product.cpp b/src/duckdb/src/execution/physical_plan/plan_cross_product.cpp new file mode 100644 index 00000000..dac22070 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_cross_product.cpp @@ -0,0 +1,15 @@ +#include "duckdb/execution/operator/join/physical_cross_product.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCrossProduct &op) { + D_ASSERT(op.children.size() == 2); + + auto left = CreatePlan(*op.children[0]); + auto right = CreatePlan(*op.children[1]); + return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_cte.cpp b/src/duckdb/src/execution/physical_plan/plan_cte.cpp new file mode 100644 index 00000000..0c0b0485 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_cte.cpp @@ -0,0 +1,33 @@ +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/operator/set/physical_cte.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalMaterializedCTE &op) { + D_ASSERT(op.children.size() == 2); + + // Create the working_table that the PhysicalCTE will use for evaluation. + auto working_table = std::make_shared(context, op.children[0]->types); + + // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator + recursive_cte_tables[op.table_index] = working_table; + + // Create the plan for the left side. This is the materialization. + auto left = CreatePlan(*op.children[0]); + // Initialize an empty vector to collect the scan operators. + materialized_ctes.insert(op.table_index); + auto right = CreatePlan(*op.children[1]); + + auto cte = make_uniq(op.ctename, op.table_index, op.children[1]->types, std::move(left), + std::move(right), op.estimated_cardinality); + cte->working_table = working_table; + + return std::move(cte); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_delete.cpp b/src/duckdb/src/execution/physical_plan/plan_delete.cpp new file mode 100644 index 00000000..3a748c00 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_delete.cpp @@ -0,0 +1,32 @@ +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/execution/operator/persistent/physical_delete.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_delete.hpp" +#include "duckdb/catalog/duck_catalog.hpp" + +namespace duckdb { + +unique_ptr DuckCatalog::PlanDelete(ClientContext &context, LogicalDelete &op, + unique_ptr plan) { + // get the index of the row_id column + auto &bound_ref = op.expressions[0]->Cast(); + + auto del = make_uniq(op.types, op.table, op.table.GetStorage(), bound_ref.index, + op.estimated_cardinality, op.return_chunk); + del->children.push_back(std::move(plan)); + return std::move(del); +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDelete &op) { + D_ASSERT(op.children.size() == 1); + D_ASSERT(op.expressions.size() == 1); + D_ASSERT(op.expressions[0]->type == ExpressionType::BOUND_REF); + + auto plan = CreatePlan(*op.children[0]); + + dependencies.AddDependency(op.table); + return op.table.catalog.PlanDelete(context, op, std::move(plan)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_delim_get.cpp b/src/duckdb/src/execution/physical_plan/plan_delim_get.cpp new file mode 100644 index 00000000..32ddeb2f --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_delim_get.cpp @@ -0,0 +1,16 @@ +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_delim_get.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDelimGet &op) { + D_ASSERT(op.children.empty()); + + // create a PhysicalChunkScan without an owned_collection, the collection will be added later + auto chunk_scan = + make_uniq(op.types, PhysicalOperatorType::DELIM_SCAN, op.estimated_cardinality); + return std::move(chunk_scan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp b/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp new file mode 100644 index 00000000..f30cb259 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp @@ -0,0 +1,50 @@ +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/join/physical_hash_join.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" + +namespace duckdb { + +static void GatherDelimScans(const PhysicalOperator &op, vector> &delim_scans) { + if (op.type == PhysicalOperatorType::DELIM_SCAN) { + delim_scans.push_back(op); + } + for (auto &child : op.children) { + GatherDelimScans(*child, delim_scans); + } +} + +unique_ptr PhysicalPlanGenerator::PlanDelimJoin(LogicalComparisonJoin &op) { + // first create the underlying join + auto plan = PlanComparisonJoin(op); + // this should create a join, not a cross product + D_ASSERT(plan && plan->type != PhysicalOperatorType::CROSS_PRODUCT); + // duplicate eliminated join + // first gather the scans on the duplicate eliminated data set from the RHS + vector> delim_scans; + GatherDelimScans(*plan->children[1], delim_scans); + if (delim_scans.empty()) { + // no duplicate eliminated scans in the RHS! + // in this case we don't need to create a delim join + // just push the normal join + return plan; + } + vector delim_types; + vector> distinct_groups, distinct_expressions; + for (auto &delim_expr : op.duplicate_eliminated_columns) { + D_ASSERT(delim_expr->type == ExpressionType::BOUND_REF); + auto &bound_ref = delim_expr->Cast(); + delim_types.push_back(bound_ref.return_type); + distinct_groups.push_back(make_uniq(bound_ref.return_type, bound_ref.index)); + } + // now create the duplicate eliminated join + auto delim_join = make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality); + // we still have to create the DISTINCT clause that is used to generate the duplicate eliminated chunk + delim_join->distinct = make_uniq(context, delim_types, std::move(distinct_expressions), + std::move(distinct_groups), op.estimated_cardinality); + return std::move(delim_join); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_distinct.cpp b/src/duckdb/src/execution/physical_plan/plan_distinct.cpp new file mode 100644 index 00000000..a2980e87 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_distinct.cpp @@ -0,0 +1,89 @@ +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDistinct &op) { + D_ASSERT(op.children.size() == 1); + auto child = CreatePlan(*op.children[0]); + auto &distinct_targets = op.distinct_targets; + D_ASSERT(child); + D_ASSERT(!distinct_targets.empty()); + + auto &types = child->GetTypes(); + vector> groups, aggregates, projections; + idx_t group_count = distinct_targets.size(); + unordered_map group_by_references; + vector aggregate_types; + // creates one group per distinct_target + for (idx_t i = 0; i < distinct_targets.size(); i++) { + auto &target = distinct_targets[i]; + if (target->type == ExpressionType::BOUND_REF) { + auto &bound_ref = target->Cast(); + group_by_references[bound_ref.index] = i; + } + aggregate_types.push_back(target->return_type); + groups.push_back(std::move(target)); + } + bool requires_projection = false; + if (types.size() != group_count) { + requires_projection = true; + } + // we need to create one aggregate per column in the select_list + for (idx_t i = 0; i < types.size(); ++i) { + auto logical_type = types[i]; + // check if we can directly refer to a group, or if we need to push an aggregate with FIRST + auto entry = group_by_references.find(i); + if (entry != group_by_references.end()) { + auto group_index = entry->second; + // entry is found: can directly refer to a group + projections.push_back(make_uniq(logical_type, group_index)); + if (group_index != i) { + // we require a projection only if this group element is out of order + requires_projection = true; + } + } else { + if (op.distinct_type == DistinctType::DISTINCT && op.order_by) { + throw InternalException("Entry that is not a group, but not a DISTINCT ON aggregate"); + } + // entry is not one of the groups: need to push a FIRST aggregate + auto bound = make_uniq(logical_type, i); + vector> first_children; + first_children.push_back(std::move(bound)); + + FunctionBinder function_binder(context); + auto first_aggregate = function_binder.BindAggregateFunction( + FirstFun::GetFunction(logical_type), std::move(first_children), nullptr, AggregateType::NON_DISTINCT); + first_aggregate->order_bys = op.order_by ? op.order_by->Copy() : nullptr; + // add the projection + projections.push_back(make_uniq(logical_type, group_count + aggregates.size())); + // push it to the list of aggregates + aggregate_types.push_back(logical_type); + aggregates.push_back(std::move(first_aggregate)); + requires_projection = true; + } + } + + child = ExtractAggregateExpressions(std::move(child), aggregates, groups); + + // we add a physical hash aggregation in the plan to select the distinct groups + auto groupby = make_uniq(context, aggregate_types, std::move(aggregates), std::move(groups), + child->estimated_cardinality); + groupby->children.push_back(std::move(child)); + if (!requires_projection) { + return std::move(groupby); + } + + // we add a physical projection on top of the aggregation to project all members in the select list + auto aggr_projection = make_uniq(types, std::move(projections), groupby->estimated_cardinality); + aggr_projection->children.push_back(std::move(groupby)); + return std::move(aggr_projection); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_dummy_scan.cpp b/src/duckdb/src/execution/physical_plan/plan_dummy_scan.cpp new file mode 100644 index 00000000..ed561cf4 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_dummy_scan.cpp @@ -0,0 +1,12 @@ +#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDummyScan &op) { + D_ASSERT(op.children.size() == 0); + return make_uniq(op.types, op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_empty_result.cpp b/src/duckdb/src/execution/physical_plan/plan_empty_result.cpp new file mode 100644 index 00000000..7313f034 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_empty_result.cpp @@ -0,0 +1,12 @@ +#include "duckdb/execution/operator/scan/physical_empty_result.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalEmptyResult &op) { + D_ASSERT(op.children.size() == 0); + return make_uniq(op.types, op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_execute.cpp b/src/duckdb/src/execution/physical_plan/plan_execute.cpp new file mode 100644 index 00000000..735cf9fc --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_execute.cpp @@ -0,0 +1,21 @@ +#include "duckdb/execution/operator/helper/physical_execute.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_execute.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExecute &op) { + if (!op.prepared->plan) { + D_ASSERT(op.children.size() == 1); + auto owned_plan = CreatePlan(*op.children[0]); + auto execute = make_uniq(*owned_plan); + execute->owned_plan = std::move(owned_plan); + execute->prepared = std::move(op.prepared); + return std::move(execute); + } else { + D_ASSERT(op.children.size() == 0); + return make_uniq(*op.prepared->plan); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_explain.cpp b/src/duckdb/src/execution/physical_plan/plan_explain.cpp new file mode 100644 index 00000000..867aa4ac --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_explain.cpp @@ -0,0 +1,63 @@ +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/operator/helper/physical_explain_analyze.hpp" +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/operator/logical_explain.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExplain &op) { + D_ASSERT(op.children.size() == 1); + auto logical_plan_opt = op.children[0]->ToString(); + auto plan = CreatePlan(*op.children[0]); + if (op.explain_type == ExplainType::EXPLAIN_ANALYZE) { + auto result = make_uniq(op.types); + result->children.push_back(std::move(plan)); + return std::move(result); + } + + op.physical_plan = plan->ToString(); + // the output of the explain + vector keys, values; + switch (ClientConfig::GetConfig(context).explain_output_type) { + case ExplainOutputType::OPTIMIZED_ONLY: + keys = {"logical_opt"}; + values = {logical_plan_opt}; + break; + case ExplainOutputType::PHYSICAL_ONLY: + keys = {"physical_plan"}; + values = {op.physical_plan}; + break; + default: + keys = {"logical_plan", "logical_opt", "physical_plan"}; + values = {op.logical_plan_unopt, logical_plan_opt, op.physical_plan}; + } + + // create a ColumnDataCollection from the output + auto &allocator = Allocator::Get(context); + vector plan_types {LogicalType::VARCHAR, LogicalType::VARCHAR}; + auto collection = + make_uniq(context, plan_types, ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); + + DataChunk chunk; + chunk.Initialize(allocator, op.types); + for (idx_t i = 0; i < keys.size(); i++) { + chunk.SetValue(0, chunk.size(), Value(keys[i])); + chunk.SetValue(1, chunk.size(), Value(values[i])); + chunk.SetCardinality(chunk.size() + 1); + if (chunk.size() == STANDARD_VECTOR_SIZE) { + collection->Append(chunk); + chunk.Reset(); + } + } + collection->Append(chunk); + + // create a chunk scan to output the result + auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, + op.estimated_cardinality, std::move(collection)); + return std::move(chunk_scan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_export.cpp b/src/duckdb/src/execution/physical_plan/plan_export.cpp new file mode 100644 index 00000000..3179ec6f --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_export.cpp @@ -0,0 +1,23 @@ +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/execution/operator/persistent/physical_export.hpp" +#include "duckdb/planner/operator/logical_export.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExport &op) { + auto &config = DBConfig::GetConfig(context); + if (!config.options.enable_external_access) { + throw PermissionException("Export is disabled through configuration"); + } + auto export_node = make_uniq(op.types, op.function, std::move(op.copy_info), + op.estimated_cardinality, op.exported_tables); + // plan the underlying copy statements, if any + if (!op.children.empty()) { + auto plan = CreatePlan(*op.children[0]); + export_node->children.push_back(std::move(plan)); + } + return std::move(export_node); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_expression_get.cpp b/src/duckdb/src/execution/physical_plan/plan_expression_get.cpp new file mode 100644 index 00000000..b0db627e --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_expression_get.cpp @@ -0,0 +1,39 @@ +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/operator/scan/physical_expression_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_expression_get.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExpressionGet &op) { + D_ASSERT(op.children.size() == 1); + auto plan = CreatePlan(*op.children[0]); + + auto expr_scan = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); + expr_scan->children.push_back(std::move(plan)); + if (!expr_scan->IsFoldable()) { + return std::move(expr_scan); + } + auto &allocator = Allocator::Get(context); + // simple expression scan (i.e. no subqueries to evaluate and no prepared statement parameters) + // we can evaluate all the expressions right now and turn this into a chunk collection scan + auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, + expr_scan->expressions.size()); + chunk_scan->owned_collection = make_uniq(context, op.types); + chunk_scan->collection = chunk_scan->owned_collection.get(); + + DataChunk chunk; + chunk.Initialize(allocator, op.types); + + ColumnDataAppendState append_state; + chunk_scan->owned_collection->InitializeAppend(append_state); + for (idx_t expression_idx = 0; expression_idx < expr_scan->expressions.size(); expression_idx++) { + chunk.Reset(); + expr_scan->EvaluateExpression(context, expression_idx, nullptr, chunk); + chunk_scan->owned_collection->Append(append_state, chunk); + } + return std::move(chunk_scan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_filter.cpp b/src/duckdb/src/execution/physical_plan/plan_filter.cpp new file mode 100644 index 00000000..ea87121a --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_filter.cpp @@ -0,0 +1,36 @@ +#include "duckdb/execution/operator/filter/physical_filter.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalFilter &op) { + D_ASSERT(op.children.size() == 1); + unique_ptr plan = CreatePlan(*op.children[0]); + if (!op.expressions.empty()) { + D_ASSERT(plan->types.size() > 0); + // create a filter if there is anything to filter + auto filter = make_uniq(plan->types, std::move(op.expressions), op.estimated_cardinality); + filter->children.push_back(std::move(plan)); + plan = std::move(filter); + } + if (!op.projection_map.empty()) { + // there is a projection map, generate a physical projection + vector> select_list; + for (idx_t i = 0; i < op.projection_map.size(); i++) { + select_list.push_back(make_uniq(op.types[i], op.projection_map[i])); + } + auto proj = make_uniq(op.types, std::move(select_list), op.estimated_cardinality); + proj->children.push_back(std::move(plan)); + plan = std::move(proj); + } + return plan; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_get.cpp b/src/duckdb/src/execution/physical_plan/plan_get.cpp new file mode 100644 index 00000000..e7ec5437 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_get.cpp @@ -0,0 +1,100 @@ +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/operator/projection/physical_tableinout_function.hpp" +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +namespace duckdb { + +unique_ptr CreateTableFilterSet(TableFilterSet &table_filters, vector &column_ids) { + // create the table filter map + auto table_filter_set = make_uniq(); + for (auto &table_filter : table_filters.filters) { + // find the relative column index from the absolute column index into the table + idx_t column_index = DConstants::INVALID_INDEX; + for (idx_t i = 0; i < column_ids.size(); i++) { + if (table_filter.first == column_ids[i]) { + column_index = i; + break; + } + } + if (column_index == DConstants::INVALID_INDEX) { + throw InternalException("Could not find column index for table filter"); + } + table_filter_set->filters[column_index] = std::move(table_filter.second); + } + return table_filter_set; +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { + if (!op.children.empty()) { + // this is for table producing functions that consume subquery results + D_ASSERT(op.children.size() == 1); + auto node = make_uniq(op.types, op.function, std::move(op.bind_data), op.column_ids, + op.estimated_cardinality, std::move(op.projected_input)); + node->children.push_back(CreatePlan(std::move(op.children[0]))); + return std::move(node); + } + if (!op.projected_input.empty()) { + throw InternalException("LogicalGet::project_input can only be set for table-in-out functions"); + } + + unique_ptr table_filters; + if (!op.table_filters.filters.empty()) { + table_filters = CreateTableFilterSet(op.table_filters, op.column_ids); + } + + if (op.function.dependency) { + op.function.dependency(dependencies, op.bind_data.get()); + } + // create the table scan node + if (!op.function.projection_pushdown) { + // function does not support projection pushdown + auto node = make_uniq(op.returned_types, op.function, std::move(op.bind_data), + op.returned_types, op.column_ids, vector(), op.names, + std::move(table_filters), op.estimated_cardinality, op.extra_info); + // first check if an additional projection is necessary + if (op.column_ids.size() == op.returned_types.size()) { + bool projection_necessary = false; + for (idx_t i = 0; i < op.column_ids.size(); i++) { + if (op.column_ids[i] != i) { + projection_necessary = true; + break; + } + } + if (!projection_necessary) { + // a projection is not necessary if all columns have been requested in-order + // in that case we just return the node + + return std::move(node); + } + } + // push a projection on top that does the projection + vector types; + vector> expressions; + for (auto &column_id : op.column_ids) { + if (column_id == COLUMN_IDENTIFIER_ROW_ID) { + types.emplace_back(LogicalType::BIGINT); + expressions.push_back(make_uniq(Value::BIGINT(0))); + } else { + auto type = op.returned_types[column_id]; + types.push_back(type); + expressions.push_back(make_uniq(type, column_id)); + } + } + + auto projection = + make_uniq(std::move(types), std::move(expressions), op.estimated_cardinality); + projection->children.push_back(std::move(node)); + return std::move(projection); + } else { + return make_uniq(op.types, op.function, std::move(op.bind_data), op.returned_types, + op.column_ids, op.projection_ids, op.names, std::move(table_filters), + op.estimated_cardinality, op.extra_info); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_insert.cpp b/src/duckdb/src/execution/physical_plan/plan_insert.cpp new file mode 100644 index 00000000..b0e35341 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_insert.cpp @@ -0,0 +1,113 @@ +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/execution/operator/persistent/physical_insert.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_insert.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/catalog/duck_catalog.hpp" + +namespace duckdb { + +static OrderPreservationType OrderPreservationRecursive(PhysicalOperator &op) { + if (op.IsSource()) { + return op.SourceOrder(); + } + for (auto &child : op.children) { + auto child_preservation = OrderPreservationRecursive(*child); + if (child_preservation != OrderPreservationType::INSERTION_ORDER) { + return child_preservation; + } + } + return OrderPreservationType::INSERTION_ORDER; +} + +bool PhysicalPlanGenerator::PreserveInsertionOrder(ClientContext &context, PhysicalOperator &plan) { + auto &config = DBConfig::GetConfig(context); + + auto preservation_type = OrderPreservationRecursive(plan); + if (preservation_type == OrderPreservationType::FIXED_ORDER) { + // always need to maintain preservation order + return true; + } + if (preservation_type == OrderPreservationType::NO_ORDER) { + // never need to preserve order + return false; + } + // preserve insertion order - check flags + if (!config.options.preserve_insertion_order) { + // preserving insertion order is disabled by config + return false; + } + return true; +} + +bool PhysicalPlanGenerator::PreserveInsertionOrder(PhysicalOperator &plan) { + return PreserveInsertionOrder(context, plan); +} + +bool PhysicalPlanGenerator::UseBatchIndex(ClientContext &context, PhysicalOperator &plan) { + // TODO: always preserve order if query contains ORDER BY + auto &scheduler = TaskScheduler::GetScheduler(context); + if (scheduler.NumberOfThreads() == 1) { + // batch index usage only makes sense if we are using multiple threads + return false; + } + if (!plan.AllSourcesSupportBatchIndex()) { + // batch index is not supported + return false; + } + return true; +} + +bool PhysicalPlanGenerator::UseBatchIndex(PhysicalOperator &plan) { + return UseBatchIndex(context, plan); +} + +unique_ptr DuckCatalog::PlanInsert(ClientContext &context, LogicalInsert &op, + unique_ptr plan) { + bool parallel_streaming_insert = !PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); + bool use_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); + auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + if (op.return_chunk) { + // not supported for RETURNING (yet?) + parallel_streaming_insert = false; + use_batch_index = false; + } + if (op.action_type != OnConflictAction::THROW) { + // We don't support ON CONFLICT clause in batch insertion operation currently + use_batch_index = false; + } + if (op.action_type == OnConflictAction::UPDATE) { + // When we potentially need to perform updates, we have to check that row is not updated twice + // that currently needs to be done for every chunk, which would add a huge bottleneck to parallelized insertion + parallel_streaming_insert = false; + } + unique_ptr insert; + if (use_batch_index && !parallel_streaming_insert) { + insert = make_uniq(op.types, op.table, op.column_index_map, std::move(op.bound_defaults), + op.estimated_cardinality); + } else { + insert = make_uniq( + op.types, op.table, op.column_index_map, std::move(op.bound_defaults), std::move(op.expressions), + std::move(op.set_columns), std::move(op.set_types), op.estimated_cardinality, op.return_chunk, + parallel_streaming_insert && num_threads > 1, op.action_type, std::move(op.on_conflict_condition), + std::move(op.do_update_condition), std::move(op.on_conflict_filter), std::move(op.columns_to_fetch)); + } + D_ASSERT(plan); + insert->children.push_back(std::move(plan)); + return insert; +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalInsert &op) { + unique_ptr plan; + if (!op.children.empty()) { + D_ASSERT(op.children.size() == 1); + plan = CreatePlan(*op.children[0]); + } + dependencies.AddDependency(op.table); + return op.table.catalog.PlanInsert(context, op, std::move(plan)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_limit.cpp b/src/duckdb/src/execution/physical_plan/plan_limit.cpp new file mode 100644 index 00000000..b241b389 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_limit.cpp @@ -0,0 +1,36 @@ +#include "duckdb/execution/operator/helper/physical_limit.hpp" +#include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalLimit &op) { + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); + + unique_ptr limit; + if (!PreserveInsertionOrder(*plan)) { + // use parallel streaming limit if insertion order is not important + limit = make_uniq(op.types, (idx_t)op.limit_val, op.offset_val, std::move(op.limit), + std::move(op.offset), op.estimated_cardinality, true); + } else { + // maintaining insertion order is important + if (UseBatchIndex(*plan)) { + // source supports batch index: use parallel batch limit + limit = make_uniq(op.types, (idx_t)op.limit_val, op.offset_val, std::move(op.limit), + std::move(op.offset), op.estimated_cardinality); + } else { + // source does not support batch index: use a non-parallel streaming limit + limit = make_uniq(op.types, (idx_t)op.limit_val, op.offset_val, std::move(op.limit), + std::move(op.offset), op.estimated_cardinality, false); + } + } + + limit->children.push_back(std::move(plan)); + return limit; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_limit_percent.cpp b/src/duckdb/src/execution/physical_plan/plan_limit_percent.cpp new file mode 100644 index 00000000..7caa5b32 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_limit_percent.cpp @@ -0,0 +1,18 @@ +#include "duckdb/execution/operator/helper/physical_limit_percent.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_limit_percent.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalLimitPercent &op) { + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); + + auto limit = make_uniq(op.types, op.limit_percent, op.offset_val, std::move(op.limit), + std::move(op.offset), op.estimated_cardinality); + limit->children.push_back(std::move(plan)); + return std::move(limit); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_order.cpp b/src/duckdb/src/execution/physical_plan/plan_order.cpp new file mode 100644 index 00000000..a7161ad3 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_order.cpp @@ -0,0 +1,28 @@ +#include "duckdb/execution/operator/order/physical_order.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_order.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalOrder &op) { + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); + if (!op.orders.empty()) { + vector projections; + if (op.projections.empty()) { + for (idx_t i = 0; i < plan->types.size(); i++) { + projections.push_back(i); + } + } else { + projections = std::move(op.projections); + } + auto order = + make_uniq(op.types, std::move(op.orders), std::move(projections), op.estimated_cardinality); + order->children.push_back(std::move(plan)); + plan = std::move(order); + } + return plan; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_pivot.cpp b/src/duckdb/src/execution/physical_plan/plan_pivot.cpp new file mode 100644 index 00000000..bca3bc90 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_pivot.cpp @@ -0,0 +1,14 @@ +#include "duckdb/execution/operator/projection/physical_pivot.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_pivot.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPivot &op) { + D_ASSERT(op.children.size() == 1); + auto child_plan = CreatePlan(*op.children[0]); + auto pivot = make_uniq(std::move(op.types), std::move(child_plan), std::move(op.bound_pivot)); + return std::move(pivot); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_positional_join.cpp b/src/duckdb/src/execution/physical_plan/plan_positional_join.cpp new file mode 100644 index 00000000..84e78808 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_positional_join.cpp @@ -0,0 +1,30 @@ +#include "duckdb/execution/operator/join/physical_positional_join.hpp" +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_positional_join.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPositionalJoin &op) { + D_ASSERT(op.children.size() == 2); + + auto left = CreatePlan(*op.children[0]); + auto right = CreatePlan(*op.children[1]); + switch (left->type) { + case PhysicalOperatorType::TABLE_SCAN: + case PhysicalOperatorType::POSITIONAL_SCAN: + switch (right->type) { + case PhysicalOperatorType::TABLE_SCAN: + case PhysicalOperatorType::POSITIONAL_SCAN: + return make_uniq(op.types, std::move(left), std::move(right)); + default: + break; + } + default: + break; + } + + return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_pragma.cpp b/src/duckdb/src/execution/physical_plan/plan_pragma.cpp new file mode 100644 index 00000000..7cd47ae2 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_pragma.cpp @@ -0,0 +1,11 @@ +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_pragma.hpp" + +#include "duckdb/execution/operator/helper/physical_pragma.hpp" +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPragma &op) { + return make_uniq(op.function, op.info, op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_prepare.cpp b/src/duckdb/src/execution/physical_plan/plan_prepare.cpp new file mode 100644 index 00000000..0a61e939 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_prepare.cpp @@ -0,0 +1,21 @@ +#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_prepare.hpp" +#include "duckdb/execution/operator/helper/physical_prepare.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPrepare &op) { + D_ASSERT(op.children.size() <= 1); + + // generate physical plan + if (!op.children.empty()) { + auto plan = CreatePlan(*op.children[0]); + op.prepared->types = plan->types; + op.prepared->plan = std::move(plan); + } + + return make_uniq(op.name, std::move(op.prepared), op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_projection.cpp b/src/duckdb/src/execution/physical_plan/plan_projection.cpp new file mode 100644 index 00000000..76fb36df --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_projection.cpp @@ -0,0 +1,44 @@ +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalProjection &op) { + D_ASSERT(op.children.size() == 1); + auto plan = CreatePlan(*op.children[0]); + +#ifdef DEBUG + for (auto &expr : op.expressions) { + D_ASSERT(!expr->IsWindow()); + D_ASSERT(!expr->IsAggregate()); + } +#endif + if (plan->types.size() == op.types.size()) { + // check if this projection can be omitted entirely + // this happens if a projection simply emits the columns in the same order + // e.g. PROJECTION(#0, #1, #2, #3, ...) + bool omit_projection = true; + for (idx_t i = 0; i < op.types.size(); i++) { + if (op.expressions[i]->type == ExpressionType::BOUND_REF) { + auto &bound_ref = op.expressions[i]->Cast(); + if (bound_ref.index == i) { + continue; + } + } + omit_projection = false; + break; + } + if (omit_projection) { + // the projection only directly projects the child' columns: omit it entirely + return plan; + } + } + + auto projection = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); + projection->children.push_back(std::move(plan)); + return std::move(projection); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp b/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp new file mode 100644 index 00000000..8ded97a8 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp @@ -0,0 +1,66 @@ +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" +#include "duckdb/planner/operator/logical_recursive_cte.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalRecursiveCTE &op) { + D_ASSERT(op.children.size() == 2); + + // Create the working_table that the PhysicalRecursiveCTE will use for evaluation. + auto working_table = std::make_shared(context, op.types); + + // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator + recursive_cte_tables[op.table_index] = working_table; + + auto left = CreatePlan(*op.children[0]); + auto right = CreatePlan(*op.children[1]); + + auto cte = make_uniq(op.ctename, op.table_index, op.types, op.union_all, std::move(left), + std::move(right), op.estimated_cardinality); + cte->working_table = working_table; + + return std::move(cte); +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCTERef &op) { + D_ASSERT(op.children.empty()); + + // Check if this LogicalCTERef is supposed to scan a materialized CTE. + if (op.materialized_cte == CTEMaterialize::CTE_MATERIALIZE_ALWAYS) { + // Lookup if there is a materialized CTE for the cte_index. + auto materialized_cte = materialized_ctes.find(op.cte_index); + + // If this check fails, this is a reference to a materialized recursive CTE. + if (materialized_cte != materialized_ctes.end()) { + auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::CTE_SCAN, + op.estimated_cardinality, op.cte_index); + + auto cte = recursive_cte_tables.find(op.cte_index); + if (cte == recursive_cte_tables.end()) { + throw InvalidInputException("Referenced materialized CTE does not exist."); + } + chunk_scan->collection = cte->second.get(); + + return std::move(chunk_scan); + } + } + + // CreatePlan of a LogicalRecursiveCTE must have happened before. + auto cte = recursive_cte_tables.find(op.cte_index); + if (cte == recursive_cte_tables.end()) { + throw InvalidInputException("Referenced recursive CTE does not exist."); + } + + auto chunk_scan = make_uniq( + cte->second.get()->Types(), PhysicalOperatorType::RECURSIVE_CTE_SCAN, op.estimated_cardinality, op.cte_index); + + chunk_scan->collection = cte->second.get(); + return std::move(chunk_scan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_reset.cpp b/src/duckdb/src/execution/physical_plan/plan_reset.cpp new file mode 100644 index 00000000..2f8aa369 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_reset.cpp @@ -0,0 +1,11 @@ +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_reset.hpp" +#include "duckdb/execution/operator/helper/physical_reset.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalReset &op) { + return make_uniq(op.name, op.scope, op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_sample.cpp b/src/duckdb/src/execution/physical_plan/plan_sample.cpp new file mode 100644 index 00000000..e13ef8eb --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_sample.cpp @@ -0,0 +1,37 @@ +#include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp" +#include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_sample.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSample &op) { + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); + + unique_ptr sample; + switch (op.sample_options->method) { + case SampleMethod::RESERVOIR_SAMPLE: + sample = make_uniq(op.types, std::move(op.sample_options), op.estimated_cardinality); + break; + case SampleMethod::SYSTEM_SAMPLE: + case SampleMethod::BERNOULLI_SAMPLE: + if (!op.sample_options->is_percentage) { + throw ParserException("Sample method %s cannot be used with a discrete sample count, either switch to " + "reservoir sampling or use a sample_size", + EnumUtil::ToString(op.sample_options->method)); + } + sample = make_uniq(op.types, op.sample_options->method, + op.sample_options->sample_size.GetValue(), + op.sample_options->seed, op.estimated_cardinality); + break; + default: + throw InternalException("Unimplemented sample method"); + } + sample->children.push_back(std::move(plan)); + return sample; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_set.cpp b/src/duckdb/src/execution/physical_plan/plan_set.cpp new file mode 100644 index 00000000..9325719b --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_set.cpp @@ -0,0 +1,11 @@ +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_set.hpp" +#include "duckdb/execution/operator/helper/physical_set.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSet &op) { + return make_uniq(op.name, op.value, op.scope, op.estimated_cardinality); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp b/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp new file mode 100644 index 00000000..f636eeee --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp @@ -0,0 +1,46 @@ +#include "duckdb/execution/operator/join/physical_hash_join.hpp" +#include "duckdb/execution/operator/set/physical_union.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSetOperation &op) { + D_ASSERT(op.children.size() == 2); + + auto left = CreatePlan(*op.children[0]); + auto right = CreatePlan(*op.children[1]); + + if (left->GetTypes() != right->GetTypes()) { + throw InvalidInputException("Type mismatch for SET OPERATION"); + } + + switch (op.type) { + case LogicalOperatorType::LOGICAL_UNION: + // UNION + return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); + default: { + // EXCEPT/INTERSECT + D_ASSERT(op.type == LogicalOperatorType::LOGICAL_EXCEPT || op.type == LogicalOperatorType::LOGICAL_INTERSECT); + auto &types = left->GetTypes(); + vector conditions; + // create equality condition for all columns + for (idx_t i = 0; i < types.size(); i++) { + JoinCondition cond; + cond.left = make_uniq(types[i], i); + cond.right = make_uniq(types[i], i); + cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + conditions.push_back(std::move(cond)); + } + // EXCEPT is ANTI join + // INTERSECT is SEMI join + PerfectHashJoinStats join_stats; // used in inner joins only + JoinType join_type = op.type == LogicalOperatorType::LOGICAL_EXCEPT ? JoinType::ANTI : JoinType::SEMI; + return make_uniq(op, std::move(left), std::move(right), std::move(conditions), join_type, + op.estimated_cardinality, join_stats); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_show_select.cpp b/src/duckdb/src/execution/physical_plan/plan_show_select.cpp new file mode 100644 index 00000000..a7392f96 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_show_select.cpp @@ -0,0 +1,47 @@ +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/parser/parsed_data/show_select_info.hpp" +#include "duckdb/planner/operator/logical_show.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalShow &op) { + DataChunk output; + output.Initialize(Allocator::Get(context), op.types); + + auto collection = make_uniq(context, op.types); + ColumnDataAppendState append_state; + collection->InitializeAppend(append_state); + for (idx_t column_idx = 0; column_idx < op.types_select.size(); column_idx++) { + auto type = op.types_select[column_idx]; + auto &name = op.aliases[column_idx]; + + // "name", TypeId::VARCHAR + output.SetValue(0, output.size(), Value(name)); + // "type", TypeId::VARCHAR + output.SetValue(1, output.size(), Value(type.ToString())); + // "null", TypeId::VARCHAR + output.SetValue(2, output.size(), Value("YES")); + // "pk", TypeId::BOOL + output.SetValue(3, output.size(), Value()); + // "dflt_value", TypeId::VARCHAR + output.SetValue(4, output.size(), Value()); + // "extra", TypeId::VARCHAR + output.SetValue(5, output.size(), Value()); + + output.SetCardinality(output.size() + 1); + if (output.size() == STANDARD_VECTOR_SIZE) { + collection->Append(append_state, output); + output.Reset(); + } + } + + collection->Append(append_state, output); + + // create a chunk scan to output the result + auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, + op.estimated_cardinality, std::move(collection)); + return std::move(chunk_scan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_simple.cpp b/src/duckdb/src/execution/physical_plan/plan_simple.cpp new file mode 100644 index 00000000..a32aef16 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_simple.cpp @@ -0,0 +1,51 @@ +#include "duckdb/execution/operator/helper/physical_load.hpp" +#include "duckdb/execution/operator/helper/physical_transaction.hpp" +#include "duckdb/execution/operator/helper/physical_vacuum.hpp" +#include "duckdb/execution/operator/schema/physical_alter.hpp" +#include "duckdb/execution/operator/schema/physical_attach.hpp" +#include "duckdb/execution/operator/schema/physical_create_schema.hpp" +#include "duckdb/execution/operator/schema/physical_create_sequence.hpp" +#include "duckdb/execution/operator/schema/physical_create_view.hpp" +#include "duckdb/execution/operator/schema/physical_detach.hpp" +#include "duckdb/execution/operator/schema/physical_drop.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSimple &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_ALTER: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_DROP: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_TRANSACTION: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_VACUUM: { + auto result = make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + if (!op.children.empty()) { + auto child = CreatePlan(*op.children[0]); + result->children.push_back(std::move(child)); + } + return std::move(result); + } + case LogicalOperatorType::LOGICAL_LOAD: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_ATTACH: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_DETACH: + return make_uniq(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + default: + throw NotImplementedException("Unimplemented type for logical simple operator"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_top_n.cpp b/src/duckdb/src/execution/physical_plan/plan_top_n.cpp new file mode 100644 index 00000000..b3043440 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_top_n.cpp @@ -0,0 +1,18 @@ +#include "duckdb/execution/operator/order/physical_top_n.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_top_n.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalTopN &op) { + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); + + auto top_n = + make_uniq(op.types, std::move(op.orders), (idx_t)op.limit, op.offset, op.estimated_cardinality); + top_n->children.push_back(std::move(plan)); + return std::move(top_n); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_unnest.cpp b/src/duckdb/src/execution/physical_plan/plan_unnest.cpp new file mode 100644 index 00000000..992da430 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_unnest.cpp @@ -0,0 +1,15 @@ +#include "duckdb/execution/operator/projection/physical_unnest.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_unnest.hpp" + +namespace duckdb { + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalUnnest &op) { + D_ASSERT(op.children.size() == 1); + auto plan = CreatePlan(*op.children[0]); + auto unnest = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); + unnest->children.push_back(std::move(plan)); + return std::move(unnest); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_update.cpp b/src/duckdb/src/execution/physical_plan/plan_update.cpp new file mode 100644 index 00000000..27890083 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_update.cpp @@ -0,0 +1,29 @@ +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/execution/operator/persistent/physical_update.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/operator/logical_update.hpp" +#include "duckdb/catalog/duck_catalog.hpp" + +namespace duckdb { + +unique_ptr DuckCatalog::PlanUpdate(ClientContext &context, LogicalUpdate &op, + unique_ptr plan) { + auto update = + make_uniq(op.types, op.table, op.table.GetStorage(), op.columns, std::move(op.expressions), + std::move(op.bound_defaults), op.estimated_cardinality, op.return_chunk); + + update->update_is_del_and_insert = op.update_is_del_and_insert; + update->children.push_back(std::move(plan)); + return std::move(update); +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalUpdate &op) { + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); + + dependencies.AddDependency(op.table); + return op.table.catalog.PlanUpdate(context, op, std::move(plan)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_window.cpp b/src/duckdb/src/execution/physical_plan/plan_window.cpp new file mode 100644 index 00000000..5e2e502c --- /dev/null +++ b/src/duckdb/src/execution/physical_plan/plan_window.cpp @@ -0,0 +1,129 @@ +#include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp" +#include "duckdb/execution/operator/aggregate/physical_window.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/operator/logical_window.hpp" + +#include + +namespace duckdb { + +static bool IsStreamingWindow(unique_ptr &expr) { + auto &wexpr = expr->Cast(); + if (!wexpr.partitions.empty() || !wexpr.orders.empty() || wexpr.ignore_nulls) { + return false; + } + switch (wexpr.type) { + // TODO: add more expression types here? + case ExpressionType::WINDOW_AGGREGATE: + // We can stream aggregates if they are "running totals" and don't use filters + return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS && + !wexpr.filter_expr; + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: + case ExpressionType::WINDOW_ROW_NUMBER: + return true; + default: + return false; + } +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { + D_ASSERT(op.children.size() == 1); + + auto plan = CreatePlan(*op.children[0]); +#ifdef DEBUG + for (auto &expr : op.expressions) { + D_ASSERT(expr->IsWindow()); + } +#endif + + op.estimated_cardinality = op.EstimateCardinality(context); + + // Slice types + auto types = op.types; + const auto output_idx = types.size() - op.expressions.size(); + types.resize(output_idx); + + // Identify streaming windows + vector blocking_windows; + vector streaming_windows; + for (idx_t expr_idx = 0; expr_idx < op.expressions.size(); expr_idx++) { + if (IsStreamingWindow(op.expressions[expr_idx])) { + streaming_windows.push_back(expr_idx); + } else { + blocking_windows.push_back(expr_idx); + } + } + + // Process the window functions by sharing the partition/order definitions + vector evaluation_order; + while (!blocking_windows.empty() || !streaming_windows.empty()) { + const bool process_streaming = blocking_windows.empty(); + auto &remaining = process_streaming ? streaming_windows : blocking_windows; + + // Find all functions that share the partitioning of the first remaining expression + const auto over_idx = remaining[0]; + auto &over_expr = op.expressions[over_idx]->Cast(); + + vector matching; + vector unprocessed; + for (const auto &expr_idx : remaining) { + D_ASSERT(op.expressions[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); + auto &wexpr = op.expressions[expr_idx]->Cast(); + if (over_expr.KeysAreCompatible(wexpr)) { + matching.emplace_back(expr_idx); + } else { + unprocessed.emplace_back(expr_idx); + } + } + remaining.swap(unprocessed); + + // Extract the matching expressions + vector> select_list; + for (const auto &expr_idx : matching) { + select_list.emplace_back(std::move(op.expressions[expr_idx])); + types.emplace_back(op.types[output_idx + expr_idx]); + } + + // Chain the new window operator on top of the plan + unique_ptr window; + if (process_streaming) { + window = make_uniq(types, std::move(select_list), op.estimated_cardinality); + } else { + window = make_uniq(types, std::move(select_list), op.estimated_cardinality); + } + window->children.push_back(std::move(plan)); + plan = std::move(window); + + // Remember the projection order if we changed it + if (!streaming_windows.empty() || !blocking_windows.empty() || !evaluation_order.empty()) { + evaluation_order.insert(evaluation_order.end(), matching.begin(), matching.end()); + } + } + + // Put everything back into place if it moved + if (!evaluation_order.empty()) { + vector> select_list(op.types.size()); + // The inputs don't move + for (idx_t i = 0; i < output_idx; ++i) { + select_list[i] = make_uniq(op.types[i], i); + } + // The outputs have been rearranged + for (idx_t i = 0; i < evaluation_order.size(); ++i) { + const auto expr_idx = evaluation_order[i] + output_idx; + select_list[expr_idx] = make_uniq(op.types[expr_idx], i + output_idx); + } + auto proj = make_uniq(op.types, std::move(select_list), op.estimated_cardinality); + proj->children.push_back(std::move(plan)); + plan = std::move(proj); + } + + return plan; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan_generator.cpp b/src/duckdb/src/execution/physical_plan_generator.cpp new file mode 100644 index 00000000..ab51b498 --- /dev/null +++ b/src/duckdb/src/execution/physical_plan_generator.cpp @@ -0,0 +1,233 @@ +#include "duckdb/execution/physical_plan_generator.hpp" + +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/column_binding_resolver.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/operator/logical_extension_operator.hpp" +#include "duckdb/planner/operator/list.hpp" + +namespace duckdb { + +class DependencyExtractor : public LogicalOperatorVisitor { +public: + explicit DependencyExtractor(DependencyList &dependencies) : dependencies(dependencies) { + } + +protected: + unique_ptr VisitReplace(BoundFunctionExpression &expr, unique_ptr *expr_ptr) override { + // extract dependencies from the bound function expression + if (expr.function.dependency) { + expr.function.dependency(expr, dependencies); + } + return nullptr; + } + +private: + DependencyList &dependencies; +}; + +PhysicalPlanGenerator::PhysicalPlanGenerator(ClientContext &context) : context(context) { +} + +PhysicalPlanGenerator::~PhysicalPlanGenerator() { +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(unique_ptr op) { + auto &profiler = QueryProfiler::Get(context); + + // first resolve column references + profiler.StartPhase("column_binding"); + ColumnBindingResolver resolver; + resolver.VisitOperator(*op); + profiler.EndPhase(); + + // now resolve types of all the operators + profiler.StartPhase("resolve_types"); + op->ResolveOperatorTypes(); + profiler.EndPhase(); + + // extract dependencies from the logical plan + DependencyExtractor extractor(dependencies); + extractor.VisitOperator(*op); + + // then create the main physical plan + profiler.StartPhase("create_plan"); + auto plan = CreatePlan(*op); + profiler.EndPhase(); + + plan->Verify(); + return plan; +} + +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalOperator &op) { + op.estimated_cardinality = op.EstimateCardinality(context); + unique_ptr plan = nullptr; + + switch (op.type) { + case LogicalOperatorType::LOGICAL_GET: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_PROJECTION: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_EMPTY_RESULT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_FILTER: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_WINDOW: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_UNNEST: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_LIMIT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_SAMPLE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_ORDER_BY: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_TOP_N: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_ANY_JOIN: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_INSERT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_DELETE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_CHUNK_GET: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_DELIM_GET: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_UPDATE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_CREATE_TABLE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_CREATE_INDEX: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_EXPLAIN: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_SHOW: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_DISTINCT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_PREPARE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_EXECUTE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_CREATE_VIEW: + case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: + case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: + case LogicalOperatorType::LOGICAL_CREATE_MACRO: + case LogicalOperatorType::LOGICAL_CREATE_TYPE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_PRAGMA: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_TRANSACTION: + case LogicalOperatorType::LOGICAL_ALTER: + case LogicalOperatorType::LOGICAL_DROP: + case LogicalOperatorType::LOGICAL_VACUUM: + case LogicalOperatorType::LOGICAL_LOAD: + case LogicalOperatorType::LOGICAL_ATTACH: + case LogicalOperatorType::LOGICAL_DETACH: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_CTE_REF: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_EXPORT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_SET: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_RESET: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_PIVOT: + plan = CreatePlan(op.Cast()); + break; + case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: + plan = op.Cast().CreatePlan(context, *this); + + if (!plan) { + throw InternalException("Missing PhysicalOperator for Extension Operator"); + } + break; + case LogicalOperatorType::LOGICAL_JOIN: + case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + case LogicalOperatorType::LOGICAL_INVALID: { + throw NotImplementedException("Unimplemented logical operator type!"); + } + } + if (!plan) { + throw InternalException("Physical plan generator - no plan generated"); + } + + plan->estimated_cardinality = op.estimated_cardinality; + + return plan; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp new file mode 100644 index 00000000..e6d3ad20 --- /dev/null +++ b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp @@ -0,0 +1,810 @@ +#include "duckdb/execution/radix_partitioned_hashtable.hpp" + +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/types/row/tuple_data_collection.hpp" +#include "duckdb/common/types/row/tuple_data_iterator.hpp" +#include "duckdb/execution/aggregate_hashtable.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +RadixPartitionedHashTable::RadixPartitionedHashTable(GroupingSet &grouping_set_p, const GroupedAggregateData &op_p) + : grouping_set(grouping_set_p), op(op_p) { + auto groups_count = op.GroupCount(); + for (idx_t i = 0; i < groups_count; i++) { + if (grouping_set.find(i) == grouping_set.end()) { + null_groups.push_back(i); + } + } + if (grouping_set.empty()) { + // Fake a single group with a constant value for aggregation without groups + group_types.emplace_back(LogicalType::TINYINT); + } + for (auto &entry : grouping_set) { + D_ASSERT(entry < op.group_types.size()); + group_types.push_back(op.group_types[entry]); + } + SetGroupingValues(); + + auto group_types_copy = group_types; + group_types_copy.emplace_back(LogicalType::HASH); + layout.Initialize(std::move(group_types_copy), AggregateObject::CreateAggregateObjects(op.bindings)); +} + +void RadixPartitionedHashTable::SetGroupingValues() { + // Compute the GROUPING values: + // For each parameter to the GROUPING clause, we check if the hash table groups on this particular group + // If it does, we return 0, otherwise we return 1 + // We then use bitshifts to combine these values + auto &grouping_functions = op.GetGroupingFunctions(); + for (auto &grouping : grouping_functions) { + int64_t grouping_value = 0; + D_ASSERT(grouping.size() < sizeof(int64_t) * 8); + for (idx_t i = 0; i < grouping.size(); i++) { + if (grouping_set.find(grouping[i]) == grouping_set.end()) { + // We don't group on this value! + grouping_value += (int64_t)1 << (grouping.size() - (i + 1)); + } + } + grouping_values.push_back(Value::BIGINT(grouping_value)); + } +} + +const TupleDataLayout &RadixPartitionedHashTable::GetLayout() const { + return layout; +} + +unique_ptr RadixPartitionedHashTable::CreateHT(ClientContext &context, const idx_t capacity, + const idx_t radix_bits) const { + return make_uniq(context, BufferAllocator::Get(context), group_types, op.payload_types, + op.bindings, capacity, radix_bits); +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +struct AggregatePartition { + explicit AggregatePartition(unique_ptr data_p) : data(std::move(data_p)), finalized(false) { + } + unique_ptr data; + atomic finalized; +}; + +class RadixHTGlobalSinkState; + +struct RadixHTConfig { +public: + explicit RadixHTConfig(ClientContext &context, RadixHTGlobalSinkState &sink); + + void SetRadixBits(idx_t radix_bits_p); + bool SetRadixBitsToExternal(); + idx_t GetRadixBits() const; + +private: + void SetRadixBitsInternal(const idx_t radix_bits_p, bool external); + static idx_t InitialSinkRadixBits(ClientContext &context); + static idx_t MaximumSinkRadixBits(ClientContext &context); + static idx_t ExternalRadixBits(const idx_t &maximum_sink_radix_bits_p); + static idx_t SinkCapacity(ClientContext &context); + +private: + //! Assume (1 << 15) = 32KB L1 cache per core, divided by two because hyperthreading + static constexpr const idx_t L1_CACHE_SIZE = 32768 / 2; + //! Assume (1 << 20) = 1MB L2 cache per core, divided by two because hyperthreading + static constexpr const idx_t L2_CACHE_SIZE = 1048576 / 2; + //! Assume (1 << 20) + (1 << 19) = 1.5MB L3 cache per core (shared), divided by two because hyperthreading + static constexpr const idx_t L3_CACHE_SIZE = 1572864 / 2; + + //! Sink radix bits to initialize with + static constexpr const idx_t MAXIMUM_INITIAL_SINK_RADIX_BITS = 3; + //! Maximum Sink radix bits (independent of threads) + static constexpr const idx_t MAXIMUM_FINAL_SINK_RADIX_BITS = 7; + //! By how many radix bits to increment if we go external + static constexpr const idx_t EXTERNAL_RADIX_BITS_INCREMENT = 3; + + //! The global sink state + RadixHTGlobalSinkState &sink; + //! Current thread-global sink radix bits + atomic sink_radix_bits; + //! Maximum Sink radix bits (set based on number of threads) + const idx_t maximum_sink_radix_bits; + //! Radix bits if we go external + const idx_t external_radix_bits; + +public: + //! Capacity of HTs during the Sink + const idx_t sink_capacity; + + //! If we fill this many blocks per partition, we trigger a repartition + static constexpr const double BLOCK_FILL_FACTOR = 1.8; + //! By how many bits to repartition if a repartition is triggered + static constexpr const idx_t REPARTITION_RADIX_BITS = 2; +}; + +class RadixHTGlobalSinkState : public GlobalSinkState { +public: + RadixHTGlobalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); + + //! Destroys aggregate states (if multi-scan) + ~RadixHTGlobalSinkState() override; + void Destroy(); + +public: + //! The radix HT + const RadixPartitionedHashTable &radix_ht; + //! Config for partitioning + RadixHTConfig config; + + //! Whether we've called Finalize + bool finalized; + //! Whether we are doing an external aggregation + atomic external; + //! Threads that have called Sink + atomic active_threads; + //! If any thread has called combine + atomic any_combined; + + //! Lock for uncombined_data/stored_allocators + mutex lock; + //! Uncombined partitioned data that will be put into the AggregatePartitions + unique_ptr uncombined_data; + //! Allocators used during the Sink/Finalize + vector> stored_allocators; + + //! Partitions that are finalized during GetData + vector> partitions; + + //! For synchronizing finalize tasks + atomic finalize_idx; + + //! Pin properties when scanning + TupleDataPinProperties scan_pin_properties; + //! Total count before combining + idx_t count_before_combining; +}; + +RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht_p) + : radix_ht(radix_ht_p), config(context, *this), finalized(false), external(false), active_threads(0), + any_combined(false), finalize_idx(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), + count_before_combining(0) { +} + +RadixHTGlobalSinkState::~RadixHTGlobalSinkState() { + Destroy(); +} + +// LCOV_EXCL_START +void RadixHTGlobalSinkState::Destroy() { + if (scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE || count_before_combining == 0 || + partitions.empty()) { + // Already destroyed / empty + return; + } + + TupleDataLayout layout = partitions[0]->data->GetLayout().Copy(); + if (!layout.HasDestructor()) { + return; // No destructors, exit + } + + // There are aggregates with destructors: Call the destructor for each of the aggregates + RowOperationsState row_state(*stored_allocators.back()); + for (auto &partition : partitions) { + auto &data_collection = *partition->data; + if (data_collection.Count() == 0) { + continue; + } + TupleDataChunkIterator iterator(data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false); + auto &row_locations = iterator.GetChunkState().row_locations; + do { + RowOperations::DestroyStates(row_state, layout, row_locations, iterator.GetCurrentChunkCount()); + } while (iterator.Next()); + data_collection.Reset(); + } +} +// LCOV_EXCL_STOP + +RadixHTConfig::RadixHTConfig(ClientContext &context, RadixHTGlobalSinkState &sink_p) + : sink(sink_p), sink_radix_bits(InitialSinkRadixBits(context)), + maximum_sink_radix_bits(MaximumSinkRadixBits(context)), + external_radix_bits(ExternalRadixBits(maximum_sink_radix_bits)), sink_capacity(SinkCapacity(context)) { +} + +void RadixHTConfig::SetRadixBits(idx_t radix_bits_p) { + SetRadixBitsInternal(MinValue(radix_bits_p, maximum_sink_radix_bits), false); +} + +bool RadixHTConfig::SetRadixBitsToExternal() { + SetRadixBitsInternal(external_radix_bits, true); + return sink.external; +} + +idx_t RadixHTConfig::GetRadixBits() const { + return sink_radix_bits; +} + +void RadixHTConfig::SetRadixBitsInternal(const idx_t radix_bits_p, bool external) { + if (sink_radix_bits >= radix_bits_p || sink.any_combined) { + return; + } + + lock_guard guard(sink.lock); + if (sink_radix_bits >= radix_bits_p || sink.any_combined) { + return; + } + + if (external) { + sink.external = true; + } + sink_radix_bits = radix_bits_p; + return; +} + +idx_t RadixHTConfig::InitialSinkRadixBits(ClientContext &context) { + const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + return MinValue(RadixPartitioning::RadixBits(NextPowerOfTwo(active_threads)), MAXIMUM_INITIAL_SINK_RADIX_BITS); +} + +idx_t RadixHTConfig::MaximumSinkRadixBits(ClientContext &context) { + const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + return MinValue(RadixPartitioning::RadixBits(NextPowerOfTwo(active_threads)), MAXIMUM_FINAL_SINK_RADIX_BITS); +} + +idx_t RadixHTConfig::ExternalRadixBits(const idx_t &maximum_sink_radix_bits_p) { + return MinValue(maximum_sink_radix_bits_p + EXTERNAL_RADIX_BITS_INCREMENT, MAXIMUM_FINAL_SINK_RADIX_BITS); +} + +idx_t RadixHTConfig::SinkCapacity(ClientContext &context) { + // Get active and maximum number of threads + const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + const auto max_threads = DBConfig::GetSystemMaxThreads(FileSystem::GetFileSystem(context)); + + // Compute cache size per active thread (assuming cache is shared) + const auto total_shared_cache_size = max_threads * L3_CACHE_SIZE; + const auto cache_per_active_thread = L1_CACHE_SIZE + L2_CACHE_SIZE + total_shared_cache_size / active_threads; + + // Divide cache per active thread by entry size, round up to next power of two, to get capacity + const auto size_per_entry = sizeof(aggr_ht_entry_t) * GroupedAggregateHashTable::LOAD_FACTOR; + const auto capacity = NextPowerOfTwo(cache_per_active_thread / size_per_entry); + + // Capacity must be at least the minimum capacity + return MaxValue(capacity, GroupedAggregateHashTable::InitialCapacity()); +} + +class RadixHTLocalSinkState : public LocalSinkState { +public: + RadixHTLocalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); + +public: + //! Thread-local HT that is re-used after abandoning + unique_ptr ht; + //! Chunk with group columns + DataChunk group_chunk; + + //! Data that is abandoned ends up here (only if we're doing external aggregation) + unique_ptr abandoned_data; +}; + +RadixHTLocalSinkState::RadixHTLocalSinkState(ClientContext &, const RadixPartitionedHashTable &radix_ht) { + // If there are no groups we create a fake group so everything has the same group + group_chunk.InitializeEmpty(radix_ht.group_types); + if (radix_ht.grouping_set.empty()) { + group_chunk.data[0].Reference(Value::TINYINT(42)); + } +} + +unique_ptr RadixPartitionedHashTable::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr RadixPartitionedHashTable::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context.client, *this); +} + +void RadixPartitionedHashTable::PopulateGroupChunk(DataChunk &group_chunk, DataChunk &input_chunk) const { + idx_t chunk_index = 0; + // Populate the group_chunk + for (auto &group_idx : grouping_set) { + // Retrieve the expression containing the index in the input chunk + auto &group = op.groups[group_idx]; + D_ASSERT(group->type == ExpressionType::BOUND_REF); + auto &bound_ref_expr = group->Cast(); + // Reference from input_chunk[group.index] -> group_chunk[chunk_index] + group_chunk.data[chunk_index++].Reference(input_chunk.data[bound_ref_expr.index]); + } + group_chunk.SetCardinality(input_chunk.size()); + group_chunk.Verify(); +} + +bool MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, RadixHTLocalSinkState &lstate) { + auto &config = gstate.config; + auto &ht = *lstate.ht; + auto &partitioned_data = ht.GetPartitionedData(); + + // Check if we're approaching the memory limit + const idx_t n_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + const idx_t limit = BufferManager::GetBufferManager(context).GetMaxMemory(); + const idx_t thread_limit = 0.6 * limit / n_threads; + if (ht.GetPartitionedData()->SizeInBytes() > thread_limit || context.config.force_external) { + if (gstate.config.SetRadixBitsToExternal()) { + // We're approaching the memory limit, unpin the data + if (!lstate.abandoned_data) { + lstate.abandoned_data = make_uniq( + BufferManager::GetBufferManager(context), gstate.radix_ht.GetLayout(), config.GetRadixBits(), + gstate.radix_ht.GetLayout().ColumnCount() - 1); + } + + ht.UnpinData(); + partitioned_data->Repartition(*lstate.abandoned_data); + ht.SetRadixBits(gstate.config.GetRadixBits()); + ht.InitializePartitionedData(); + return true; + } + } + + const auto partition_count = partitioned_data->PartitionCount(); + const auto current_radix_bits = RadixPartitioning::RadixBits(partition_count); + D_ASSERT(current_radix_bits <= config.GetRadixBits()); + + const auto row_size_per_partition = + partitioned_data->Count() * partitioned_data->GetLayout().GetRowWidth() / partition_count; + if (row_size_per_partition > config.BLOCK_FILL_FACTOR * Storage::BLOCK_SIZE) { + // We crossed our block filling threshold, try to increment radix bits + config.SetRadixBits(current_radix_bits + config.REPARTITION_RADIX_BITS); + } + + const auto global_radix_bits = config.GetRadixBits(); + if (current_radix_bits == global_radix_bits) { + return false; // We're already on the right number of radix bits + } + + // We're out-of-sync with the global radix bits, repartition + ht.UnpinData(); + auto old_partitioned_data = std::move(partitioned_data); + ht.SetRadixBits(global_radix_bits); + ht.InitializePartitionedData(); + old_partitioned_data->Repartition(*ht.GetPartitionedData()); + return true; +} + +void RadixPartitionedHashTable::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, + DataChunk &payload_input, const unsafe_vector &filter) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + if (!lstate.ht) { + lstate.ht = CreateHT(context.client, gstate.config.sink_capacity, gstate.config.GetRadixBits()); + gstate.active_threads++; + } + + auto &group_chunk = lstate.group_chunk; + PopulateGroupChunk(group_chunk, chunk); + + auto &ht = *lstate.ht; + ht.AddChunk(group_chunk, payload_input, filter); + + if (ht.Count() + STANDARD_VECTOR_SIZE < ht.ResizeThreshold()) { + return; // We can fit another chunk + } + + if (gstate.active_threads > 2) { + // 'Reset' the HT without taking its data, we can just keep appending to the same collection + // This only works because we never resize the HT + ht.ClearPointerTable(); + ht.ResetCount(); + // We don't do this when running with 1 or 2 threads, it only makes sense when there's many threads + } + + // Check if we need to repartition + auto repartitioned = MaybeRepartition(context.client, gstate, lstate); + + if (repartitioned && ht.Count() != 0) { + // We repartitioned, but we didn't clear the pointer table / reset the count because we're on 1 or 2 threads + ht.ClearPointerTable(); + ht.ResetCount(); + } + + // TODO: combine early and often +} + +void RadixPartitionedHashTable::Combine(ExecutionContext &context, GlobalSinkState &gstate_p, + LocalSinkState &lstate_p) const { + auto &gstate = gstate_p.Cast(); + auto &lstate = lstate_p.Cast(); + if (!lstate.ht) { + return; + } + + // Set any_combined, then check one last time whether we need to repartition + gstate.any_combined = true; + MaybeRepartition(context.client, gstate, lstate); + + auto &ht = *lstate.ht; + ht.UnpinData(); + + if (lstate.abandoned_data) { + D_ASSERT(gstate.external); + D_ASSERT(lstate.abandoned_data->PartitionCount() == lstate.ht->GetPartitionedData()->PartitionCount()); + D_ASSERT(lstate.abandoned_data->PartitionCount() == + RadixPartitioning::NumberOfPartitions(gstate.config.GetRadixBits())); + lstate.abandoned_data->Combine(*lstate.ht->GetPartitionedData()); + } else { + lstate.abandoned_data = std::move(ht.GetPartitionedData()); + } + + lock_guard guard(gstate.lock); + if (gstate.uncombined_data) { + gstate.uncombined_data->Combine(*lstate.abandoned_data); + } else { + gstate.uncombined_data = std::move(lstate.abandoned_data); + } + gstate.stored_allocators.emplace_back(ht.GetAggregateAllocator()); +} + +void RadixPartitionedHashTable::Finalize(ClientContext &, GlobalSinkState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + + if (gstate.uncombined_data) { + auto &uncombined_data = *gstate.uncombined_data; + gstate.count_before_combining = uncombined_data.Count(); + + // If true there is no need to combine, it was all done by a single thread in a single HT + const auto single_ht = !gstate.external && gstate.active_threads == 1; + + auto &uncombined_partition_data = uncombined_data.GetPartitions(); + const auto n_partitions = uncombined_partition_data.size(); + gstate.partitions.reserve(n_partitions); + for (idx_t i = 0; i < n_partitions; i++) { + gstate.partitions.emplace_back(make_uniq(std::move(uncombined_partition_data[i]))); + if (single_ht) { + gstate.finalize_idx++; + gstate.partitions.back()->finalized = true; + } + } + } else { + gstate.count_before_combining = 0; + } + + gstate.finalized = true; +} + +//===--------------------------------------------------------------------===// +// Source +//===--------------------------------------------------------------------===// +idx_t RadixPartitionedHashTable::NumberOfPartitions(GlobalSinkState &sink_p) const { + auto &sink = sink_p.Cast(); + return sink.partitions.size(); +} + +void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) { + auto &sink = sink_p.Cast(); + sink.scan_pin_properties = TupleDataPinProperties::UNPIN_AFTER_DONE; +} + +enum class RadixHTSourceTaskType : uint8_t { NO_TASK, FINALIZE, SCAN }; + +class RadixHTLocalSourceState; + +class RadixHTGlobalSourceState : public GlobalSourceState { +public: + RadixHTGlobalSourceState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); + + //! Assigns a task to a local source state + bool AssignTask(RadixHTGlobalSinkState &sink, RadixHTLocalSourceState &lstate); + +public: + //! The client context + ClientContext &context; + //! For synchronizing the source phase + atomic finished; + + //! Column ids for scanning + vector column_ids; + + //! For synchronizing scan tasks + atomic scan_idx; + atomic scan_done; +}; + +enum class RadixHTScanStatus : uint8_t { INIT, IN_PROGRESS, DONE }; + +class RadixHTLocalSourceState : public LocalSourceState { +public: + explicit RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht); + +public: + //! Do the work this thread has been assigned + void ExecuteTask(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk); + //! Whether this thread has finished the work it has been assigned + bool TaskFinished(); + +private: + //! Execute the finalize or scan task + void Finalize(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate); + void Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk); + +public: + //! Current task and index + RadixHTSourceTaskType task; + idx_t task_idx; + + //! Thread-local HT that is re-used to Finalize + unique_ptr ht; + //! Current status of a Scan + RadixHTScanStatus scan_status; + +private: + //! Allocator and layout for finalizing state + TupleDataLayout layout; + ArenaAllocator aggregate_allocator; + + //! State and chunk for scanning + TupleDataScanState scan_state; + DataChunk scan_chunk; +}; + +unique_ptr RadixPartitionedHashTable::GetGlobalSourceState(ClientContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr RadixPartitionedHashTable::GetLocalSourceState(ExecutionContext &context) const { + return make_uniq(context, *this); +} + +RadixHTGlobalSourceState::RadixHTGlobalSourceState(ClientContext &context_p, const RadixPartitionedHashTable &radix_ht) + : context(context_p), finished(false), scan_idx(0), scan_done(0) { + for (column_t column_id = 0; column_id < radix_ht.group_types.size(); column_id++) { + column_ids.push_back(column_id); + } +} + +bool RadixHTGlobalSourceState::AssignTask(RadixHTGlobalSinkState &sink, RadixHTLocalSourceState &lstate) { + D_ASSERT(lstate.scan_status != RadixHTScanStatus::IN_PROGRESS); + + const auto n_partitions = sink.partitions.size(); + if (finished) { + return false; + } + // We first try to assign a Scan task, then a Finalize task if that didn't work, without using any locks + + // We need an atomic compare-and-swap to assign a Scan task, because we need to only increment + // the 'scan_idx' atomic if the 'finalize' of that partition is true, i.e., ready to be scanned + bool scan_assigned = true; + do { + lstate.task_idx = scan_idx.load(); + if (lstate.task_idx >= n_partitions || !sink.partitions[lstate.task_idx]->finalized) { + scan_assigned = false; + break; + } + } while (!std::atomic_compare_exchange_weak(&scan_idx, &lstate.task_idx, lstate.task_idx + 1)); + + if (scan_assigned) { + // We successfully assigned a Scan task + D_ASSERT(lstate.task_idx < n_partitions && sink.partitions[lstate.task_idx]->finalized); + lstate.task = RadixHTSourceTaskType::SCAN; + lstate.scan_status = RadixHTScanStatus::INIT; + return true; + } + + // We didn't assign a Scan task + if (sink.finalize_idx >= n_partitions) { + return false; // No finalize tasks left + } + + // We can just increment the atomic here, much simpler than assigning the scan task + lstate.task_idx = sink.finalize_idx++; + if (lstate.task_idx < n_partitions) { + // We successfully assigned a Finalize task + lstate.task = RadixHTSourceTaskType::FINALIZE; + return true; + } + + // We didn't manage to assign a Finalize task + return false; +} + +RadixHTLocalSourceState::RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht) + : task(RadixHTSourceTaskType::NO_TASK), scan_status(RadixHTScanStatus::DONE), layout(radix_ht.GetLayout().Copy()), + aggregate_allocator(BufferAllocator::Get(context.client)) { + auto &allocator = BufferAllocator::Get(context.client); + auto scan_chunk_types = radix_ht.group_types; + for (auto &aggr_type : radix_ht.op.aggregate_return_types) { + scan_chunk_types.push_back(aggr_type); + } + scan_chunk.Initialize(allocator, scan_chunk_types); +} + +void RadixHTLocalSourceState::ExecuteTask(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, + DataChunk &chunk) { + switch (task) { + case RadixHTSourceTaskType::FINALIZE: + Finalize(sink, gstate); + break; + case RadixHTSourceTaskType::SCAN: + Scan(sink, gstate, chunk); + break; + default: + throw InternalException("Unexpected RadixHTSourceTaskType in ExecuteTask!"); + } +} + +void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate) { + D_ASSERT(task == RadixHTSourceTaskType::FINALIZE); + D_ASSERT(scan_status != RadixHTScanStatus::IN_PROGRESS); + + auto &partition = *sink.partitions[task_idx]; + if (partition.data->Count() == 0) { + partition.finalized = true; + return; + } + + if (!ht) { + // Create a HT with sufficient capacity + const auto capacity = GroupedAggregateHashTable::GetCapacityForCount(partition.data->Count()); + ht = sink.radix_ht.CreateHT(gstate.context, capacity, 0); + } else { + // We may want to resize here to the size of this partition, but for now we just assume uniform partition sizes + ht->InitializePartitionedData(); + ht->ClearPointerTable(); + ht->ResetCount(); + } + + // Now combine the uncombined data using this thread's HT + ht->Combine(*partition.data); + ht->UnpinData(); + + // Move the combined data back to the partition + partition.data = + make_uniq(BufferManager::GetBufferManager(gstate.context), sink.radix_ht.GetLayout()); + partition.data->Combine(*ht->GetPartitionedData()->GetPartitions()[0]); + + // Mark partition as ready to scan + partition.finalized = true; + + // Make sure this thread's aggregate allocator does not get lost + lock_guard guard(sink.lock); + sink.stored_allocators.emplace_back(ht->GetAggregateAllocator()); +} + +void RadixHTLocalSourceState::Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk) { + D_ASSERT(task == RadixHTSourceTaskType::SCAN); + D_ASSERT(scan_status != RadixHTScanStatus::DONE); + + auto &partition = *sink.partitions[task_idx]; + D_ASSERT(partition.finalized); + auto &data_collection = *partition.data; + + if (data_collection.Count() == 0) { + scan_status = RadixHTScanStatus::DONE; + if (++gstate.scan_done == sink.partitions.size()) { + gstate.finished = true; + } + return; + } + + if (scan_status == RadixHTScanStatus::INIT) { + data_collection.InitializeScan(scan_state, gstate.column_ids, sink.scan_pin_properties); + scan_status = RadixHTScanStatus::IN_PROGRESS; + } + + if (!data_collection.Scan(scan_state, scan_chunk)) { + scan_status = RadixHTScanStatus::DONE; + if (sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE) { + data_collection.Reset(); + } + return; + } + + if (data_collection.ScanComplete(scan_state)) { + if (++gstate.scan_done == sink.partitions.size()) { + gstate.finished = true; + } + } + + RowOperationsState row_state(aggregate_allocator); + const auto group_cols = layout.ColumnCount() - 1; + RowOperations::FinalizeStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk, group_cols); + + if (sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE && layout.HasDestructor()) { + RowOperations::DestroyStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk.size()); + } + + auto &radix_ht = sink.radix_ht; + idx_t chunk_index = 0; + for (auto &entry : radix_ht.grouping_set) { + chunk.data[entry].Reference(scan_chunk.data[chunk_index++]); + } + for (auto null_group : radix_ht.null_groups) { + chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[null_group], true); + } + D_ASSERT(radix_ht.grouping_set.size() + radix_ht.null_groups.size() == radix_ht.op.GroupCount()); + for (idx_t col_idx = 0; col_idx < radix_ht.op.aggregates.size(); col_idx++) { + chunk.data[radix_ht.op.GroupCount() + col_idx].Reference( + scan_chunk.data[radix_ht.group_types.size() + col_idx]); + } + D_ASSERT(radix_ht.op.grouping_functions.size() == radix_ht.grouping_values.size()); + for (idx_t i = 0; i < radix_ht.op.grouping_functions.size(); i++) { + chunk.data[radix_ht.op.GroupCount() + radix_ht.op.aggregates.size() + i].Reference(radix_ht.grouping_values[i]); + } + chunk.SetCardinality(scan_chunk); + D_ASSERT(chunk.size() != 0); +} + +bool RadixHTLocalSourceState::TaskFinished() { + switch (task) { + case RadixHTSourceTaskType::FINALIZE: + return true; + case RadixHTSourceTaskType::SCAN: + return scan_status == RadixHTScanStatus::DONE; + default: + D_ASSERT(task == RadixHTSourceTaskType::NO_TASK); + return true; + } +} + +SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, DataChunk &chunk, + GlobalSinkState &sink_p, OperatorSourceInput &input) const { + auto &sink = sink_p.Cast(); + D_ASSERT(sink.finalized); + + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + D_ASSERT(sink.scan_pin_properties == TupleDataPinProperties::UNPIN_AFTER_DONE || + sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE); + + if (gstate.finished) { + return SourceResultType::FINISHED; + } + + if (sink.count_before_combining == 0) { + if (grouping_set.empty()) { + // Special case hack to sort out aggregating from empty intermediates for aggregations without groups + D_ASSERT(chunk.ColumnCount() == null_groups.size() + op.aggregates.size() + op.grouping_functions.size()); + // For each column in the aggregates, set to initial state + chunk.SetCardinality(1); + for (auto null_group : null_groups) { + chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[null_group], true); + } + ArenaAllocator allocator(BufferAllocator::Get(context.client)); + for (idx_t i = 0; i < op.aggregates.size(); i++) { + D_ASSERT(op.aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); + auto &aggr = op.aggregates[i]->Cast(); + auto aggr_state = make_unsafe_uniq_array(aggr.function.state_size()); + aggr.function.initialize(aggr_state.get()); + + AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get()))); + aggr.function.finalize(state_vector, aggr_input_data, chunk.data[null_groups.size() + i], 1, 0); + if (aggr.function.destructor) { + aggr.function.destructor(state_vector, aggr_input_data, 1); + } + } + // Place the grouping values (all the groups of the grouping_set condensed into a single value) + // Behind the null groups + aggregates + for (idx_t i = 0; i < op.grouping_functions.size(); i++) { + chunk.data[null_groups.size() + op.aggregates.size() + i].Reference(grouping_values[i]); + } + } + gstate.finished = true; + return SourceResultType::FINISHED; + } + + while (!gstate.finished && chunk.size() == 0) { + if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) { + lstate.ExecuteTask(sink, gstate, chunk); + } + } + + if (chunk.size() != 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } else { + return SourceResultType::FINISHED; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/reservoir_sample.cpp b/src/duckdb/src/execution/reservoir_sample.cpp new file mode 100644 index 00000000..3522b71a --- /dev/null +++ b/src/duckdb/src/execution/reservoir_sample.cpp @@ -0,0 +1,233 @@ +#include "duckdb/execution/reservoir_sample.hpp" +#include "duckdb/common/pair.hpp" + +namespace duckdb { + +ReservoirSample::ReservoirSample(Allocator &allocator, idx_t sample_count, int64_t seed) + : BlockingSample(seed), sample_count(sample_count), reservoir(allocator) { +} + +void ReservoirSample::AddToReservoir(DataChunk &input) { + if (sample_count == 0) { + return; + } + // Input: A population V of n weighted items + // Output: A reservoir R with a size m + // 1: The first m items of V are inserted into R + // first we need to check if the reservoir already has "m" elements + if (reservoir.Count() < sample_count) { + if (FillReservoir(input) == 0) { + // entire chunk was consumed by reservoir + return; + } + } + // find the position of next_index relative to current_count + idx_t remaining = input.size(); + idx_t base_offset = 0; + while (true) { + idx_t offset = base_reservoir_sample.next_index - base_reservoir_sample.current_count; + if (offset >= remaining) { + // not in this chunk! increment current count and go to the next chunk + base_reservoir_sample.current_count += remaining; + return; + } + // in this chunk! replace the element + ReplaceElement(input, base_offset + offset); + // shift the chunk forward + remaining -= offset; + base_offset += offset; + } +} + +unique_ptr ReservoirSample::GetChunk() { + return reservoir.Fetch(); +} + +void ReservoirSample::ReplaceElement(DataChunk &input, idx_t index_in_chunk) { + // replace the entry in the reservoir + // 8. The item in R with the minimum key is replaced by item vi + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + reservoir.SetValue(col_idx, base_reservoir_sample.min_entry, input.GetValue(col_idx, index_in_chunk)); + } + base_reservoir_sample.ReplaceElement(); +} + +idx_t ReservoirSample::FillReservoir(DataChunk &input) { + idx_t chunk_count = input.size(); + input.Flatten(); + + // we have not: append to the reservoir + idx_t required_count; + if (reservoir.Count() + chunk_count >= sample_count) { + // have to limit the count of the chunk + required_count = sample_count - reservoir.Count(); + } else { + // we copy the entire chunk + required_count = chunk_count; + } + // instead of copying we just change the pointer in the current chunk + input.SetCardinality(required_count); + reservoir.Append(input); + + base_reservoir_sample.InitializeReservoir(reservoir.Count(), sample_count); + + // check if there are still elements remaining + // this happens if we are on a boundary + // for example, input.size() is 1024, but our sample size is 10 + if (required_count == chunk_count) { + // we are done here + return 0; + } + // we still need to process a part of the chunk + // create a selection vector of the remaining elements + SelectionVector sel(STANDARD_VECTOR_SIZE); + for (idx_t i = required_count; i < chunk_count; i++) { + sel.set_index(i - required_count, i); + } + // slice the input vector and continue + input.Slice(sel, chunk_count - required_count); + return input.size(); +} + +ReservoirSamplePercentage::ReservoirSamplePercentage(Allocator &allocator, double percentage, int64_t seed) + : BlockingSample(seed), allocator(allocator), sample_percentage(percentage / 100.0), current_count(0), + is_finalized(false) { + reservoir_sample_size = idx_t(sample_percentage * RESERVOIR_THRESHOLD); + current_sample = make_uniq(allocator, reservoir_sample_size, random.NextRandomInteger()); +} + +void ReservoirSamplePercentage::AddToReservoir(DataChunk &input) { + if (current_count + input.size() > RESERVOIR_THRESHOLD) { + // we don't have enough space in our current reservoir + // first check what we still need to append to the current sample + idx_t append_to_current_sample_count = RESERVOIR_THRESHOLD - current_count; + idx_t append_to_next_sample = input.size() - append_to_current_sample_count; + if (append_to_current_sample_count > 0) { + // we have elements remaining, first add them to the current sample + if (append_to_next_sample > 0) { + // we need to also add to the next sample + DataChunk new_chunk; + new_chunk.InitializeEmpty(input.GetTypes()); + new_chunk.Slice(input, *FlatVector::IncrementalSelectionVector(), append_to_current_sample_count); + new_chunk.Flatten(); + current_sample->AddToReservoir(new_chunk); + } else { + input.Flatten(); + input.SetCardinality(append_to_current_sample_count); + current_sample->AddToReservoir(input); + } + } + if (append_to_next_sample > 0) { + // slice the input for the remainder + SelectionVector sel(append_to_next_sample); + for (idx_t i = 0; i < append_to_next_sample; i++) { + sel.set_index(i, append_to_current_sample_count + i); + } + input.Slice(sel, append_to_next_sample); + } + // now our first sample is filled: append it to the set of finished samples + finished_samples.push_back(std::move(current_sample)); + + // allocate a new sample, and potentially add the remainder of the current input to that sample + current_sample = make_uniq(allocator, reservoir_sample_size, random.NextRandomInteger()); + if (append_to_next_sample > 0) { + current_sample->AddToReservoir(input); + } + current_count = append_to_next_sample; + } else { + // we can just append to the current sample + current_count += input.size(); + current_sample->AddToReservoir(input); + } +} + +unique_ptr ReservoirSamplePercentage::GetChunk() { + if (!is_finalized) { + Finalize(); + } + while (!finished_samples.empty()) { + auto &front = finished_samples.front(); + auto chunk = front->GetChunk(); + if (chunk && chunk->size() > 0) { + return chunk; + } + // move to the next sample + finished_samples.erase(finished_samples.begin()); + } + return nullptr; +} + +void ReservoirSamplePercentage::Finalize() { + // need to finalize the current sample, if any + if (current_count > 0) { + // create a new sample + auto new_sample_size = idx_t(round(sample_percentage * current_count)); + auto new_sample = make_uniq(allocator, new_sample_size, random.NextRandomInteger()); + while (true) { + auto chunk = current_sample->GetChunk(); + if (!chunk || chunk->size() == 0) { + break; + } + new_sample->AddToReservoir(*chunk); + } + finished_samples.push_back(std::move(new_sample)); + } + is_finalized = true; +} + +BaseReservoirSampling::BaseReservoirSampling(int64_t seed) : random(seed) { + next_index = 0; + min_threshold = 0; + min_entry = 0; + current_count = 0; +} + +BaseReservoirSampling::BaseReservoirSampling() : BaseReservoirSampling(-1) { +} + +void BaseReservoirSampling::InitializeReservoir(idx_t cur_size, idx_t sample_size) { + //! 1: The first m items of V are inserted into R + //! first we need to check if the reservoir already has "m" elements + if (cur_size == sample_size) { + //! 2. For each item vi ∈ R: Calculate a key ki = random(0, 1) + //! we then define the threshold to enter the reservoir T_w as the minimum key of R + //! we use a priority queue to extract the minimum key in O(1) time + for (idx_t i = 0; i < sample_size; i++) { + double k_i = random.NextRandom(); + reservoir_weights.emplace(-k_i, i); + } + SetNextEntry(); + } +} + +void BaseReservoirSampling::SetNextEntry() { + //! 4. Let r = random(0, 1) and Xw = log(r) / log(T_w) + auto &min_key = reservoir_weights.top(); + double t_w = -min_key.first; + double r = random.NextRandom(); + double x_w = log(r) / log(t_w); + //! 5. From the current item vc skip items until item vi , such that: + //! 6. wc +wc+1 +···+wi−1 < Xw <= wc +wc+1 +···+wi−1 +wi + //! since all our weights are 1 (uniform sampling), we can just determine the amount of elements to skip + min_threshold = t_w; + min_entry = min_key.second; + next_index = MaxValue(1, idx_t(round(x_w))); + current_count = 0; +} + +void BaseReservoirSampling::ReplaceElement() { + //! replace the entry in the reservoir + //! pop the minimum entry + reservoir_weights.pop(); + //! now update the reservoir + //! 8. Let tw = Tw i , r2 = random(tw,1) and vi’s key: ki = (r2)1/wi + //! 9. The new threshold Tw is the new minimum key of R + //! we generate a random number between (min_threshold, 1) + double r2 = random.NextRandom(min_threshold, 1); + //! now we insert the new weight into the reservoir + reservoir_weights.emplace(-r2, min_entry); + //! we update the min entry with the new min entry in the reservoir + SetNextEntry(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/window_executor.cpp b/src/duckdb/src/execution/window_executor.cpp new file mode 100644 index 00000000..fb094e8d --- /dev/null +++ b/src/duckdb/src/execution/window_executor.cpp @@ -0,0 +1,1285 @@ +#include "duckdb/execution/window_executor.hpp" + +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/operator/subtract.hpp" + +namespace duckdb { + +static idx_t FindNextStart(const ValidityMask &mask, idx_t l, const idx_t r, idx_t &n) { + if (mask.AllValid()) { + auto start = MinValue(l + n - 1, r); + n -= MinValue(n, r - l); + return start; + } + + while (l < r) { + // If l is aligned with the start of a block, and the block is blank, then skip forward one block. + idx_t entry_idx; + idx_t shift; + mask.GetEntryIndex(l, entry_idx, shift); + + const auto block = mask.GetValidityEntry(entry_idx); + if (mask.NoneValid(block) && !shift) { + l += ValidityMask::BITS_PER_VALUE; + continue; + } + + // Loop over the block + for (; shift < ValidityMask::BITS_PER_VALUE && l < r; ++shift, ++l) { + if (mask.RowIsValid(block, shift) && --n == 0) { + return MinValue(l, r); + } + } + } + + // Didn't find a start so return the end of the range + return r; +} + +static idx_t FindPrevStart(const ValidityMask &mask, const idx_t l, idx_t r, idx_t &n) { + if (mask.AllValid()) { + auto start = (r <= l + n) ? l : r - n; + n -= r - start; + return start; + } + + while (l < r) { + // If r is aligned with the start of a block, and the previous block is blank, + // then skip backwards one block. + idx_t entry_idx; + idx_t shift; + mask.GetEntryIndex(r - 1, entry_idx, shift); + + const auto block = mask.GetValidityEntry(entry_idx); + if (mask.NoneValid(block) && (shift + 1 == ValidityMask::BITS_PER_VALUE)) { + // r is nonzero (> l) and word aligned, so this will not underflow. + r -= ValidityMask::BITS_PER_VALUE; + continue; + } + + // Loop backwards over the block + // shift is probing r-1 >= l >= 0 + for (++shift; shift-- > 0; --r) { + if (mask.RowIsValid(block, shift) && --n == 0) { + return MaxValue(l, r - 1); + } + } + } + + // Didn't find a start so return the start of the range + return l; +} + +template +static T GetCell(const DataChunk &chunk, idx_t column, idx_t index) { + D_ASSERT(chunk.ColumnCount() > column); + auto &source = chunk.data[column]; + const auto data = FlatVector::GetData(source); + return data[index]; +} + +static bool CellIsNull(const DataChunk &chunk, idx_t column, idx_t index) { + D_ASSERT(chunk.ColumnCount() > column); + auto &source = chunk.data[column]; + return FlatVector::IsNull(source, index); +} + +static void CopyCell(const DataChunk &chunk, idx_t column, idx_t index, Vector &target, idx_t target_offset) { + D_ASSERT(chunk.ColumnCount() > column); + auto &source = chunk.data[column]; + VectorOperations::Copy(source, target, index + 1, index, target_offset); +} + +//===--------------------------------------------------------------------===// +// WindowColumnIterator +//===--------------------------------------------------------------------===// +template +struct WindowColumnIterator { + using iterator = WindowColumnIterator; + using iterator_category = std::random_access_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = T; + using reference = T; + using pointer = idx_t; + + explicit WindowColumnIterator(const WindowInputColumn &coll_p, pointer pos_p = 0) : coll(&coll_p), pos(pos_p) { + } + + // Forward iterator + inline reference operator*() const { + return coll->GetCell(pos); + } + inline explicit operator pointer() const { + return pos; + } + + inline iterator &operator++() { + ++pos; + return *this; + } + inline iterator operator++(int) { + auto result = *this; + ++(*this); + return result; + } + + // Bidirectional iterator + inline iterator &operator--() { + --pos; + return *this; + } + inline iterator operator--(int) { + auto result = *this; + --(*this); + return result; + } + + // Random Access + inline iterator &operator+=(difference_type n) { + pos += n; + return *this; + } + inline iterator &operator-=(difference_type n) { + pos -= n; + return *this; + } + + inline reference operator[](difference_type m) const { + return coll->GetCell(pos + m); + } + + friend inline iterator &operator+(const iterator &a, difference_type n) { + return iterator(a.coll, a.pos + n); + } + + friend inline iterator &operator-(const iterator &a, difference_type n) { + return iterator(a.coll, a.pos - n); + } + + friend inline iterator &operator+(difference_type n, const iterator &a) { + return a + n; + } + friend inline difference_type operator-(const iterator &a, const iterator &b) { + return difference_type(a.pos - b.pos); + } + + friend inline bool operator==(const iterator &a, const iterator &b) { + return a.pos == b.pos; + } + friend inline bool operator!=(const iterator &a, const iterator &b) { + return a.pos != b.pos; + } + friend inline bool operator<(const iterator &a, const iterator &b) { + return a.pos < b.pos; + } + friend inline bool operator<=(const iterator &a, const iterator &b) { + return a.pos <= b.pos; + } + friend inline bool operator>(const iterator &a, const iterator &b) { + return a.pos > b.pos; + } + friend inline bool operator>=(const iterator &a, const iterator &b) { + return a.pos >= b.pos; + } + +private: + optional_ptr coll; + pointer pos; +}; + +template +struct OperationCompare : public std::function { + inline bool operator()(const T &lhs, const T &val) const { + return OP::template Operation(lhs, val); + } +}; + +template +static idx_t FindTypedRangeBound(const WindowInputColumn &over, const idx_t order_begin, const idx_t order_end, + WindowInputExpression &boundary, const idx_t chunk_idx, const FrameBounds &prev) { + D_ASSERT(!boundary.CellIsNull(chunk_idx)); + const auto val = boundary.GetCell(chunk_idx); + + OperationCompare comp; + WindowColumnIterator begin(over, order_begin); + WindowColumnIterator end(over, order_end); + + if (order_begin < prev.start && prev.start < order_end) { + const auto first = over.GetCell(prev.start); + if (!comp(val, first)) { + // prev.first <= val, so we can start further forward + begin += (prev.start - order_begin); + } + } + if (order_begin <= prev.end && prev.end < order_end) { + const auto second = over.GetCell(prev.end); + if (!comp(second, val)) { + // val <= prev.second, so we can end further back + // (prev.second is the largest peer) + end -= (order_end - prev.end - 1); + } + } + + if (FROM) { + return idx_t(std::lower_bound(begin, end, val, comp)); + } else { + return idx_t(std::upper_bound(begin, end, val, comp)); + } +} + +template +static idx_t FindRangeBound(const WindowInputColumn &over, const idx_t order_begin, const idx_t order_end, + WindowInputExpression &boundary, const idx_t chunk_idx, const FrameBounds &prev) { + D_ASSERT(boundary.chunk.ColumnCount() == 1); + D_ASSERT(boundary.chunk.data[0].GetType().InternalType() == over.input_expr.ptype); + + switch (over.input_expr.ptype) { + case PhysicalType::INT8: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::INT16: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::INT32: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::INT64: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::UINT8: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::UINT16: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::UINT32: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::UINT64: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::INT128: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::FLOAT: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::DOUBLE: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case PhysicalType::INTERVAL: + return FindTypedRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + default: + throw InternalException("Unsupported column type for RANGE"); + } +} + +template +static idx_t FindOrderedRangeBound(const WindowInputColumn &over, const OrderType range_sense, const idx_t order_begin, + const idx_t order_end, WindowInputExpression &boundary, const idx_t chunk_idx, + const FrameBounds &prev) { + switch (range_sense) { + case OrderType::ASCENDING: + return FindRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + case OrderType::DESCENDING: + return FindRangeBound(over, order_begin, order_end, boundary, chunk_idx, prev); + default: + throw InternalException("Unsupported ORDER BY sense for RANGE"); + } +} + +struct WindowBoundariesState { + static inline bool IsScalar(const unique_ptr &expr) { + return expr ? expr->IsScalar() : true; + } + + static inline bool BoundaryNeedsPeer(const WindowBoundary &boundary) { + switch (boundary) { + case WindowBoundary::CURRENT_ROW_RANGE: + case WindowBoundary::EXPR_PRECEDING_RANGE: + case WindowBoundary::EXPR_FOLLOWING_RANGE: + return true; + default: + return false; + } + } + + WindowBoundariesState(BoundWindowExpression &wexpr, const idx_t input_size); + + void Update(const idx_t row_idx, const WindowInputColumn &range_collection, const idx_t chunk_idx, + WindowInputExpression &boundary_start, WindowInputExpression &boundary_end, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + void Bounds(DataChunk &bounds, idx_t row_idx, const WindowInputColumn &range, const idx_t count, + WindowInputExpression &boundary_start, WindowInputExpression &boundary_end, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + // Cached lookups + const ExpressionType type; + const idx_t input_size; + const WindowBoundary start_boundary; + const WindowBoundary end_boundary; + const size_t partition_count; + const size_t order_count; + const OrderType range_sense; + const bool has_preceding_range; + const bool has_following_range; + const bool needs_peer; + + idx_t next_pos = 0; + idx_t partition_start = 0; + idx_t partition_end = 0; + idx_t peer_start = 0; + idx_t peer_end = 0; + idx_t valid_start = 0; + idx_t valid_end = 0; + int64_t window_start = -1; + int64_t window_end = -1; + FrameBounds prev; +}; + +//===--------------------------------------------------------------------===// +// WindowBoundariesState +//===--------------------------------------------------------------------===// +void WindowBoundariesState::Update(const idx_t row_idx, const WindowInputColumn &range_collection, + const idx_t chunk_idx, WindowInputExpression &boundary_start, + WindowInputExpression &boundary_end, const ValidityMask &partition_mask, + const ValidityMask &order_mask) { + + if (partition_count + order_count > 0) { + + // determine partition and peer group boundaries to ultimately figure out window size + const auto is_same_partition = !partition_mask.RowIsValidUnsafe(row_idx); + const auto is_peer = !order_mask.RowIsValidUnsafe(row_idx); + const auto is_jump = (next_pos != row_idx); + + // when the partition changes, recompute the boundaries + if (!is_same_partition || is_jump) { + if (is_jump) { + idx_t n = 1; + partition_start = FindPrevStart(partition_mask, 0, row_idx + 1, n); + n = 1; + peer_start = FindPrevStart(order_mask, 0, row_idx + 1, n); + } else { + partition_start = row_idx; + peer_start = row_idx; + } + + // find end of partition + partition_end = input_size; + if (partition_count) { + idx_t n = 1; + partition_end = FindNextStart(partition_mask, partition_start + 1, input_size, n); + } + + // Find valid ordering values for the new partition + // so we can exclude NULLs from RANGE expression computations + valid_start = partition_start; + valid_end = partition_end; + + if ((valid_start < valid_end) && has_preceding_range) { + // Exclude any leading NULLs + if (range_collection.CellIsNull(valid_start)) { + idx_t n = 1; + valid_start = FindNextStart(order_mask, valid_start + 1, valid_end, n); + } + } + + if ((valid_start < valid_end) && has_following_range) { + // Exclude any trailing NULLs + if (range_collection.CellIsNull(valid_end - 1)) { + idx_t n = 1; + valid_end = FindPrevStart(order_mask, valid_start, valid_end, n); + } + + // Reset range hints + prev.start = valid_start; + prev.end = valid_end; + } + } else if (!is_peer) { + peer_start = row_idx; + } + + if (needs_peer) { + peer_end = partition_end; + if (order_count) { + idx_t n = 1; + peer_end = FindNextStart(order_mask, peer_start + 1, partition_end, n); + } + } + + } else { + // OVER() + partition_end = input_size; + peer_end = partition_end; + } + next_pos = row_idx + 1; + + // determine window boundaries depending on the type of expression + window_start = -1; + window_end = -1; + + switch (start_boundary) { + case WindowBoundary::UNBOUNDED_PRECEDING: + window_start = partition_start; + break; + case WindowBoundary::CURRENT_ROW_ROWS: + window_start = row_idx; + break; + case WindowBoundary::CURRENT_ROW_RANGE: + window_start = peer_start; + break; + case WindowBoundary::EXPR_PRECEDING_ROWS: { + if (!TrySubtractOperator::Operation(int64_t(row_idx), boundary_start.GetCell(chunk_idx), + window_start)) { + throw OutOfRangeException("Overflow computing ROWS PRECEDING start"); + } + break; + } + case WindowBoundary::EXPR_FOLLOWING_ROWS: { + if (!TryAddOperator::Operation(int64_t(row_idx), boundary_start.GetCell(chunk_idx), window_start)) { + throw OutOfRangeException("Overflow computing ROWS FOLLOWING start"); + } + break; + } + case WindowBoundary::EXPR_PRECEDING_RANGE: { + if (boundary_start.CellIsNull(chunk_idx)) { + window_start = peer_start; + } else { + prev.start = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, + boundary_start, chunk_idx, prev); + window_start = prev.start; + } + break; + } + case WindowBoundary::EXPR_FOLLOWING_RANGE: { + if (boundary_start.CellIsNull(chunk_idx)) { + window_start = peer_start; + } else { + prev.start = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, boundary_start, + chunk_idx, prev); + window_start = prev.start; + } + break; + } + default: + throw InternalException("Unsupported window start boundary"); + } + + switch (end_boundary) { + case WindowBoundary::CURRENT_ROW_ROWS: + window_end = row_idx + 1; + break; + case WindowBoundary::CURRENT_ROW_RANGE: + window_end = peer_end; + break; + case WindowBoundary::UNBOUNDED_FOLLOWING: + window_end = partition_end; + break; + case WindowBoundary::EXPR_PRECEDING_ROWS: + if (!TrySubtractOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), + window_end)) { + throw OutOfRangeException("Overflow computing ROWS PRECEDING end"); + } + break; + case WindowBoundary::EXPR_FOLLOWING_ROWS: + if (!TryAddOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), window_end)) { + throw OutOfRangeException("Overflow computing ROWS FOLLOWING end"); + } + break; + case WindowBoundary::EXPR_PRECEDING_RANGE: { + if (boundary_end.CellIsNull(chunk_idx)) { + window_end = peer_end; + } else { + prev.end = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, boundary_end, + chunk_idx, prev); + window_end = prev.end; + } + break; + } + case WindowBoundary::EXPR_FOLLOWING_RANGE: { + if (boundary_end.CellIsNull(chunk_idx)) { + window_end = peer_end; + } else { + prev.end = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, boundary_end, + chunk_idx, prev); + window_end = prev.end; + } + break; + } + default: + throw InternalException("Unsupported window end boundary"); + } + + // clamp windows to partitions if they should exceed + if (window_start < (int64_t)partition_start) { + window_start = partition_start; + } + if (window_start > (int64_t)partition_end) { + window_start = partition_end; + } + if (window_end < (int64_t)partition_start) { + window_end = partition_start; + } + if (window_end > (int64_t)partition_end) { + window_end = partition_end; + } + + if (window_start < 0 || window_end < 0) { + throw InternalException("Failed to compute window boundaries"); + } +} + +static bool HasPrecedingRange(BoundWindowExpression &wexpr) { + return (wexpr.start == WindowBoundary::EXPR_PRECEDING_RANGE || wexpr.end == WindowBoundary::EXPR_PRECEDING_RANGE); +} + +static bool HasFollowingRange(BoundWindowExpression &wexpr) { + return (wexpr.start == WindowBoundary::EXPR_FOLLOWING_RANGE || wexpr.end == WindowBoundary::EXPR_FOLLOWING_RANGE); +} + +WindowBoundariesState::WindowBoundariesState(BoundWindowExpression &wexpr, const idx_t input_size) + : type(wexpr.type), input_size(input_size), start_boundary(wexpr.start), end_boundary(wexpr.end), + partition_count(wexpr.partitions.size()), order_count(wexpr.orders.size()), + range_sense(wexpr.orders.empty() ? OrderType::INVALID : wexpr.orders[0].type), + has_preceding_range(HasPrecedingRange(wexpr)), has_following_range(HasFollowingRange(wexpr)), + needs_peer(BoundaryNeedsPeer(wexpr.end) || wexpr.type == ExpressionType::WINDOW_CUME_DIST) { +} + +void WindowBoundariesState::Bounds(DataChunk &bounds, idx_t row_idx, const WindowInputColumn &range, const idx_t count, + WindowInputExpression &boundary_start, WindowInputExpression &boundary_end, + const ValidityMask &partition_mask, const ValidityMask &order_mask) { + bounds.Reset(); + D_ASSERT(bounds.ColumnCount() == 6); + auto partition_begin_data = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); + auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); + auto peer_begin_data = FlatVector::GetData(bounds.data[PEER_BEGIN]); + auto peer_end_data = FlatVector::GetData(bounds.data[PEER_END]); + auto window_begin_data = FlatVector::GetData(bounds.data[WINDOW_BEGIN]); + auto window_end_data = FlatVector::GetData(bounds.data[WINDOW_END]); + for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { + Update(row_idx, range, chunk_idx, boundary_start, boundary_end, partition_mask, order_mask); + *partition_begin_data++ = partition_start; + *partition_end_data++ = partition_end; + if (needs_peer) { + *peer_begin_data++ = peer_start; + *peer_end_data++ = peer_end; + } + *window_begin_data++ = window_start; + *window_end_data++ = window_end; + } + bounds.SetCardinality(count); +} + +//===--------------------------------------------------------------------===// +// WindowExecutorBoundsState +//===--------------------------------------------------------------------===// +class WindowExecutorBoundsState : public WindowExecutorState { +public: + WindowExecutorBoundsState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t count, + const ValidityMask &partition_mask_p, const ValidityMask &order_mask_p); + ~WindowExecutorBoundsState() override { + } + + virtual void UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range); + + // Frame management + const ValidityMask &partition_mask; + const ValidityMask &order_mask; + DataChunk bounds; + WindowBoundariesState state; + + // evaluate boundaries if present. Parser has checked boundary types. + WindowInputExpression boundary_start; + WindowInputExpression boundary_end; +}; + +WindowExecutorBoundsState::WindowExecutorBoundsState(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask_p, + const ValidityMask &order_mask_p) + : partition_mask(partition_mask_p), order_mask(order_mask_p), state(wexpr, payload_count), + boundary_start(wexpr.start_expr.get(), context), boundary_end(wexpr.end_expr.get(), context) { + vector bounds_types(6, LogicalType(LogicalTypeId::UBIGINT)); + bounds.Initialize(Allocator::Get(context), bounds_types); +} + +void WindowExecutorBoundsState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) { + // Evaluate the row-level arguments + boundary_start.Execute(input_chunk); + boundary_end.Execute(input_chunk); + + const auto count = input_chunk.size(); + bounds.Reset(); + state.Bounds(bounds, row_idx, range, count, boundary_start, boundary_end, partition_mask, order_mask); +} + +//===--------------------------------------------------------------------===// +// WindowExecutor +//===--------------------------------------------------------------------===// +static void PrepareInputExpressions(vector> &exprs, ExpressionExecutor &executor, + DataChunk &chunk) { + if (exprs.empty()) { + return; + } + + vector types; + for (idx_t expr_idx = 0; expr_idx < exprs.size(); ++expr_idx) { + types.push_back(exprs[expr_idx]->return_type); + executor.AddExpression(*exprs[expr_idx]); + } + + if (!types.empty()) { + auto &allocator = executor.GetAllocator(); + chunk.Initialize(allocator, types); + } +} + +WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) + : wexpr(wexpr), context(context), payload_count(payload_count), partition_mask(partition_mask), + order_mask(order_mask), payload_collection(), payload_executor(context), + range((HasPrecedingRange(wexpr) || HasFollowingRange(wexpr)) ? wexpr.orders[0].expression.get() : nullptr, + context, payload_count) { + // TODO: child may be a scalar, don't need to materialize the whole collection then + + // evaluate inner expressions of window functions, could be more complex + PrepareInputExpressions(wexpr.children, payload_executor, payload_chunk); + + auto types = payload_chunk.GetTypes(); + if (!types.empty()) { + payload_collection.Initialize(Allocator::Get(context), types); + } +} + +unique_ptr WindowExecutor::GetExecutorState() const { + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +} + +//===--------------------------------------------------------------------===// +// WindowAggregateExecutor +//===--------------------------------------------------------------------===// +bool WindowAggregateExecutor::IsConstantAggregate() { + if (!wexpr.aggregate) { + return false; + } + + // COUNT(*) is already handled efficiently by segment trees. + if (wexpr.children.empty()) { + return false; + } + + /* + The default framing option is RANGE UNBOUNDED PRECEDING, which + is the same as RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT + ROW; it sets the frame to be all rows from the partition start + up through the current row's last peer (a row that the window's + ORDER BY clause considers equivalent to the current row; all + rows are peers if there is no ORDER BY). In general, UNBOUNDED + PRECEDING means that the frame starts with the first row of the + partition, and similarly UNBOUNDED FOLLOWING means that the + frame ends with the last row of the partition, regardless of + RANGE, ROWS or GROUPS mode. In ROWS mode, CURRENT ROW means that + the frame starts or ends with the current row; but in RANGE or + GROUPS mode it means that the frame starts or ends with the + current row's first or last peer in the ORDER BY ordering. The + offset PRECEDING and offset FOLLOWING options vary in meaning + depending on the frame mode. + */ + switch (wexpr.start) { + case WindowBoundary::UNBOUNDED_PRECEDING: + break; + case WindowBoundary::CURRENT_ROW_RANGE: + if (!wexpr.orders.empty()) { + return false; + } + break; + default: + return false; + } + + switch (wexpr.end) { + case WindowBoundary::UNBOUNDED_FOLLOWING: + break; + case WindowBoundary::CURRENT_ROW_RANGE: + if (!wexpr.orders.empty()) { + return false; + } + break; + default: + return false; + } + + return true; +} + +bool WindowAggregateExecutor::IsCustomAggregate() { + if (!wexpr.aggregate) { + return false; + } + + if (!AggregateObject(wexpr).function.window) { + return false; + } + + return (mode < WindowAggregationMode::COMBINE); +} + +void WindowExecutor::Evaluate(idx_t row_idx, DataChunk &input_chunk, Vector &result, + WindowExecutorState &lstate) const { + auto &lbstate = lstate.Cast(); + lbstate.UpdateBounds(row_idx, input_chunk, range); + + const auto count = input_chunk.size(); + EvaluateInternal(lstate, result, count, row_idx); + + result.Verify(count); +} + +WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t count, const ValidityMask &partition_mask, + const ValidityMask &order_mask, WindowAggregationMode mode) + : WindowExecutor(wexpr, context, count, partition_mask, order_mask), mode(mode), filter_executor(context) { + // TODO we could evaluate those expressions in parallel + + // Check for constant aggregate + if (IsConstantAggregate()) { + aggregator = + make_uniq(AggregateObject(wexpr), wexpr.return_type, partition_mask, count); + } else if (IsCustomAggregate()) { + aggregator = make_uniq(AggregateObject(wexpr), wexpr.return_type, count); + } else if (wexpr.aggregate) { + // build a segment tree for frame-adhering aggregates + // see http://www.vldb.org/pvldb/vol8/p1058-leis.pdf + aggregator = make_uniq(AggregateObject(wexpr), wexpr.return_type, count, mode); + } + + // evaluate the FILTER clause and stuff it into a large mask for compactness and reuse + if (wexpr.filter_expr) { + filter_executor.AddExpression(*wexpr.filter_expr); + filter_sel.Initialize(STANDARD_VECTOR_SIZE); + } +} + +void WindowAggregateExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { + idx_t filtered = 0; + SelectionVector *filtering = nullptr; + if (wexpr.filter_expr) { + filtering = &filter_sel; + filtered = filter_executor.SelectExpression(input_chunk, filter_sel); + } + + if (!wexpr.children.empty()) { + payload_chunk.Reset(); + payload_executor.Execute(input_chunk, payload_chunk); + payload_chunk.Verify(); + } else if (aggregator) { + // Zero-argument aggregate (e.g., COUNT(*) + payload_chunk.SetCardinality(input_chunk); + } + + D_ASSERT(aggregator); + aggregator->Sink(payload_chunk, filtering, filtered); + + WindowExecutor::Sink(input_chunk, input_idx, total_count); +} + +void WindowAggregateExecutor::Finalize() { + D_ASSERT(aggregator); + aggregator->Finalize(); +} + +class WindowAggregateState : public WindowExecutorBoundsState { +public: + WindowAggregateState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask, + const WindowAggregator &aggregator) + : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask), + aggregator_state(aggregator.GetLocalState()) { + } + +public: + unique_ptr aggregator_state; + + void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); +}; + +unique_ptr WindowAggregateExecutor::GetExecutorState() const { + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask, *aggregator); +} + +void WindowAggregateExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lastate = lstate.Cast(); + D_ASSERT(aggregator); + auto window_begin = FlatVector::GetData(lastate.bounds.data[WINDOW_BEGIN]); + auto window_end = FlatVector::GetData(lastate.bounds.data[WINDOW_END]); + aggregator->Evaluate(*lastate.aggregator_state, window_begin, window_end, result, count); +} + +//===--------------------------------------------------------------------===// +// WindowRowNumberExecutor +//===--------------------------------------------------------------------===// +WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lbstate = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); + auto rdata = FlatVector::GetData(result); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + rdata[i] = row_idx - partition_begin[i] + 1; + } +} + +//===--------------------------------------------------------------------===// +// WindowPeerState +//===--------------------------------------------------------------------===// +class WindowPeerState : public WindowExecutorBoundsState { +public: + WindowPeerState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) + : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask) { + } + +public: + uint64_t dense_rank = 1; + uint64_t rank_equal = 0; + uint64_t rank = 1; + + void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); +}; + +void WindowPeerState::NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx) { + if (partition_begin == row_idx) { + dense_rank = 1; + rank = 1; + rank_equal = 0; + } else if (peer_begin == row_idx) { + dense_rank++; + rank += rank_equal; + rank_equal = 0; + } + rank_equal++; +} + +WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) + : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +unique_ptr WindowRankExecutor::GetExecutorState() const { + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +} + +void WindowRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lpeer = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); + auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); + auto rdata = FlatVector::GetData(result); + + // Reset to "previous" row + lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; + lpeer.rank_equal = (row_idx - peer_begin[0]); + + for (idx_t i = 0; i < count; ++i, ++row_idx) { + lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); + rdata[i] = lpeer.rank; + } +} + +WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +unique_ptr WindowDenseRankExecutor::GetExecutorState() const { + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +} + +void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lpeer = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); + auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); + auto rdata = FlatVector::GetData(result); + + // Reset to "previous" row + lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; + lpeer.rank_equal = (row_idx - peer_begin[0]); + + // The previous dense rank is the number of order mask bits in [partition_begin, row_idx) + lpeer.dense_rank = 0; + + auto order_begin = partition_begin[0]; + idx_t begin_idx; + idx_t begin_offset; + order_mask.GetEntryIndex(order_begin, begin_idx, begin_offset); + + auto order_end = row_idx; + idx_t end_idx; + idx_t end_offset; + order_mask.GetEntryIndex(order_end, end_idx, end_offset); + + // If they are in the same entry, just loop + if (begin_idx == end_idx) { + const auto entry = order_mask.GetValidityEntry(begin_idx); + for (; begin_offset < end_offset; ++begin_offset) { + lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); + } + } else { + // Count the ragged bits at the start of the partition + if (begin_offset) { + const auto entry = order_mask.GetValidityEntry(begin_idx); + for (; begin_offset < order_mask.BITS_PER_VALUE; ++begin_offset) { + lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); + ++order_begin; + } + ++begin_idx; + } + + // Count the the aligned bits. + ValidityMask tail_mask(order_mask.GetData() + begin_idx); + lpeer.dense_rank += tail_mask.CountValid(order_end - order_begin); + } + + for (idx_t i = 0; i < count; ++i, ++row_idx) { + lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); + rdata[i] = lpeer.dense_rank; + } +} + +WindowPercentRankExecutor::WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +unique_ptr WindowPercentRankExecutor::GetExecutorState() const { + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +} + +void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lpeer = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); + auto partition_end = FlatVector::GetData(lpeer.bounds.data[PARTITION_END]); + auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); + auto rdata = FlatVector::GetData(result); + + // Reset to "previous" row + lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; + lpeer.rank_equal = (row_idx - peer_begin[0]); + + for (idx_t i = 0; i < count; ++i, ++row_idx) { + lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); + int64_t denom = partition_end[i] - partition_begin[i] - 1; + double percent_rank = denom > 0 ? ((double)lpeer.rank - 1) / denom : 0; + rdata[i] = percent_rank; + } +} + +//===--------------------------------------------------------------------===// +// WindowCumeDistExecutor +//===--------------------------------------------------------------------===// +WindowCumeDistExecutor::WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +void WindowCumeDistExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lbstate = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); + auto partition_end = FlatVector::GetData(lbstate.bounds.data[PARTITION_END]); + auto peer_end = FlatVector::GetData(lbstate.bounds.data[PEER_END]); + auto rdata = FlatVector::GetData(result); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + int64_t denom = partition_end[i] - partition_begin[i]; + double cume_dist = denom > 0 ? ((double)(peer_end[i] - partition_begin[i])) / denom : 0; + rdata[i] = cume_dist; + } +} + +//===--------------------------------------------------------------------===// +// WindowValueExecutor +//===--------------------------------------------------------------------===// +WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +void WindowValueExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { + // Single pass over the input to produce the global data. + // Vectorisation for the win... + + // Set up a validity mask for IGNORE NULLS + bool check_nulls = false; + if (wexpr.ignore_nulls) { + switch (wexpr.type) { + case ExpressionType::WINDOW_LEAD: + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_LAST_VALUE: + case ExpressionType::WINDOW_NTH_VALUE: + check_nulls = true; + break; + default: + break; + } + } + + if (!wexpr.children.empty()) { + payload_chunk.Reset(); + payload_executor.Execute(input_chunk, payload_chunk); + payload_chunk.Verify(); + payload_collection.Append(payload_chunk, true); + + // process payload chunks while they are still piping hot + if (check_nulls) { + const auto count = input_chunk.size(); + + UnifiedVectorFormat vdata; + payload_chunk.data[0].ToUnifiedFormat(count, vdata); + if (!vdata.validity.AllValid()) { + // Lazily materialise the contents when we find the first NULL + if (ignore_nulls.AllValid()) { + ignore_nulls.Initialize(total_count); + } + // Write to the current position + if (input_idx % ValidityMask::BITS_PER_VALUE == 0) { + // If we are at the edge of an output entry, just copy the entries + auto dst = ignore_nulls.GetData() + ignore_nulls.EntryCount(input_idx); + auto src = vdata.validity.GetData(); + for (auto entry_count = vdata.validity.EntryCount(count); entry_count-- > 0;) { + *dst++ = *src++; + } + } else { + // If not, we have ragged data and need to copy one bit at a time. + for (idx_t i = 0; i < count; ++i) { + ignore_nulls.Set(input_idx + i, vdata.validity.RowIsValid(i)); + } + } + } + } + } + + WindowExecutor::Sink(input_chunk, input_idx, total_count); +} + +void WindowNtileExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + D_ASSERT(payload_collection.ColumnCount() == 1); + auto &lbstate = lstate.Cast(); + auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); + auto partition_end = FlatVector::GetData(lbstate.bounds.data[PARTITION_END]); + auto rdata = FlatVector::GetData(result); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + if (CellIsNull(payload_collection, 0, row_idx)) { + FlatVector::SetNull(result, i, true); + } else { + auto n_param = GetCell(payload_collection, 0, row_idx); + if (n_param < 1) { + throw InvalidInputException("Argument for ntile must be greater than zero"); + } + // With thanks from SQLite's ntileValueFunc() + int64_t n_total = partition_end[i] - partition_begin[i]; + if (n_param > n_total) { + // more groups allowed than we have values + // map every entry to a unique group + n_param = n_total; + } + int64_t n_size = (n_total / n_param); + // find the row idx within the group + D_ASSERT(row_idx >= partition_begin[i]); + int64_t adjusted_row_idx = row_idx - partition_begin[i]; + // now compute the ntile + int64_t n_large = n_total - n_param * n_size; + int64_t i_small = n_large * (n_size + 1); + int64_t result_ntile; + + D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total); + + if (adjusted_row_idx < i_small) { + result_ntile = 1 + adjusted_row_idx / (n_size + 1); + } else { + result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size; + } + // result has to be between [1, NTILE] + D_ASSERT(result_ntile >= 1 && result_ntile <= n_param); + rdata[i] = result_ntile; + } + } +} + +//===--------------------------------------------------------------------===// +// WindowLeadLagState +//===--------------------------------------------------------------------===// +class WindowLeadLagState : public WindowExecutorBoundsState { +public: + WindowLeadLagState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) + : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask), + leadlag_offset(wexpr.offset_expr.get(), context), leadlag_default(wexpr.default_expr.get(), context) { + } + + void UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) override; + +public: + // LEAD/LAG Evaluation + WindowInputExpression leadlag_offset; + WindowInputExpression leadlag_default; +}; + +void WindowLeadLagState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) { + // Evaluate the row-level arguments + leadlag_offset.Execute(input_chunk); + leadlag_default.Execute(input_chunk); + + WindowExecutorBoundsState::UpdateBounds(row_idx, input_chunk, range); +} + +WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +unique_ptr WindowLeadLagExecutor::GetExecutorState() const { + return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +} + +void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &llstate = lstate.Cast(); + + auto partition_begin = FlatVector::GetData(llstate.bounds.data[PARTITION_BEGIN]); + auto partition_end = FlatVector::GetData(llstate.bounds.data[PARTITION_END]); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + int64_t offset = 1; + if (wexpr.offset_expr) { + offset = llstate.leadlag_offset.GetCell(i); + } + int64_t val_idx = (int64_t)row_idx; + if (wexpr.type == ExpressionType::WINDOW_LEAD) { + val_idx = AddOperatorOverflowCheck::Operation(val_idx, offset); + } else { + val_idx = SubtractOperatorOverflowCheck::Operation(val_idx, offset); + } + + idx_t delta = 0; + if (val_idx < (int64_t)row_idx) { + // Count backwards + delta = idx_t(row_idx - val_idx); + val_idx = FindPrevStart(ignore_nulls, partition_begin[i], row_idx, delta); + } else if (val_idx > (int64_t)row_idx) { + delta = idx_t(val_idx - row_idx); + val_idx = FindNextStart(ignore_nulls, row_idx + 1, partition_end[i], delta); + } + // else offset is zero, so don't move. + + if (!delta) { + CopyCell(payload_collection, 0, val_idx, result, i); + } else if (wexpr.default_expr) { + llstate.leadlag_default.CopyCell(result, i); + } else { + FlatVector::SetNull(result, i, true); + } + } +} + +WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lbstate = lstate.Cast(); + auto window_begin = FlatVector::GetData(lbstate.bounds.data[WINDOW_BEGIN]); + auto window_end = FlatVector::GetData(lbstate.bounds.data[WINDOW_END]); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + if (window_begin[i] >= window_end[i]) { + FlatVector::SetNull(result, i, true); + continue; + } + // Same as NTH_VALUE(..., 1) + idx_t n = 1; + const auto first_idx = FindNextStart(ignore_nulls, window_begin[i], window_end[i], n); + if (!n) { + CopyCell(payload_collection, 0, first_idx, result, i); + } else { + FlatVector::SetNull(result, i, true); + } + } +} + +WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +void WindowLastValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + auto &lbstate = lstate.Cast(); + auto window_begin = FlatVector::GetData(lbstate.bounds.data[WINDOW_BEGIN]); + auto window_end = FlatVector::GetData(lbstate.bounds.data[WINDOW_END]); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + if (window_begin[i] >= window_end[i]) { + FlatVector::SetNull(result, i, true); + continue; + } + idx_t n = 1; + const auto last_idx = FindPrevStart(ignore_nulls, window_begin[i], window_end[i], n); + if (!n) { + CopyCell(payload_collection, 0, last_idx, result, i); + } else { + FlatVector::SetNull(result, i, true); + } + } +} + +WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +} + +void WindowNthValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, + idx_t row_idx) const { + D_ASSERT(payload_collection.ColumnCount() == 2); + + auto &lbstate = lstate.Cast(); + auto window_begin = FlatVector::GetData(lbstate.bounds.data[WINDOW_BEGIN]); + auto window_end = FlatVector::GetData(lbstate.bounds.data[WINDOW_END]); + for (idx_t i = 0; i < count; ++i, ++row_idx) { + if (window_begin[i] >= window_end[i]) { + FlatVector::SetNull(result, i, true); + continue; + } + // Returns value evaluated at the row that is the n'th row of the window frame (counting from 1); + // returns NULL if there is no such row. + if (CellIsNull(payload_collection, 1, row_idx)) { + FlatVector::SetNull(result, i, true); + } else { + auto n_param = GetCell(payload_collection, 1, row_idx); + if (n_param < 1) { + FlatVector::SetNull(result, i, true); + } else { + auto n = idx_t(n_param); + const auto nth_index = FindNextStart(ignore_nulls, window_begin[i], window_end[i], n); + if (!n) { + CopyCell(payload_collection, 0, nth_index, result, i); + } else { + FlatVector::SetNull(result, i, true); + } + } + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/window_segment_tree.cpp b/src/duckdb/src/execution/window_segment_tree.cpp new file mode 100644 index 00000000..2a66399c --- /dev/null +++ b/src/duckdb/src/execution/window_segment_tree.cpp @@ -0,0 +1,656 @@ +#include "duckdb/execution/window_segment_tree.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// WindowAggregator +//===--------------------------------------------------------------------===// +WindowAggregatorState::WindowAggregatorState() : allocator(Allocator::DefaultAllocator()) { +} + +WindowAggregator::WindowAggregator(AggregateObject aggr, const LogicalType &result_type_p, idx_t partition_count_p) + : aggr(std::move(aggr)), result_type(result_type_p), partition_count(partition_count_p), + state_size(aggr.function.state_size()), filter_pos(0) { +} + +WindowAggregator::~WindowAggregator() { +} + +void WindowAggregator::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { + if (!inputs.ColumnCount() && payload_chunk.ColumnCount()) { + inputs.Initialize(Allocator::DefaultAllocator(), payload_chunk.GetTypes()); + } + if (inputs.ColumnCount()) { + inputs.Append(payload_chunk, true); + } + if (filter_sel) { + // Lazy instantiation + if (!filter_mask.IsMaskSet()) { + // Start with all invalid and set the ones that pass + filter_bits.resize(ValidityMask::ValidityMaskSize(partition_count), 0); + filter_mask.Initialize(filter_bits.data()); + } + for (idx_t f = 0; f < filtered; ++f) { + filter_mask.SetValid(filter_pos + filter_sel->get_index(f)); + } + filter_pos += payload_chunk.size(); + } +} + +void WindowAggregator::Finalize() { +} + +//===--------------------------------------------------------------------===// +// WindowConstantAggregate +//===--------------------------------------------------------------------===// +WindowConstantAggregator::WindowConstantAggregator(AggregateObject aggr, const LogicalType &result_type, + const ValidityMask &partition_mask, const idx_t count) + : WindowAggregator(std::move(aggr), result_type, count), partition(0), row(0), state(state_size), + statep(Value::POINTER(CastPointerToValue(state.data()))), + statef(Value::POINTER(CastPointerToValue(state.data()))) { + + statef.SetVectorType(VectorType::FLAT_VECTOR); // Prevent conversion of results to constants + + // Locate the partition boundaries + if (partition_mask.AllValid()) { + partition_offsets.emplace_back(0); + } else { + idx_t entry_idx; + idx_t shift; + for (idx_t start = 0; start < count;) { + partition_mask.GetEntryIndex(start, entry_idx, shift); + + // If start is aligned with the start of a block, + // and the block is blank, then skip forward one block. + const auto block = partition_mask.GetValidityEntry(entry_idx); + if (partition_mask.NoneValid(block) && !shift) { + start += ValidityMask::BITS_PER_VALUE; + continue; + } + + // Loop over the block + for (; shift < ValidityMask::BITS_PER_VALUE && start < count; ++shift, ++start) { + if (partition_mask.RowIsValid(block, shift)) { + partition_offsets.emplace_back(start); + } + } + } + } + + // Initialise the vector for caching the results + results = make_uniq(result_type, partition_offsets.size()); + partition_offsets.emplace_back(count); + + // Create an aggregate state for intermediate aggregates + gstate = make_uniq(); + + // Start the first aggregate + AggregateInit(); +} + +void WindowConstantAggregator::AggregateInit() { + aggr.function.initialize(state.data()); +} + +void WindowConstantAggregator::AggegateFinal(Vector &result, idx_t rid) { + AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); + aggr.function.finalize(statef, aggr_input_data, result, 1, rid); + + if (aggr.function.destructor) { + aggr.function.destructor(statef, aggr_input_data, 1); + } +} + +void WindowConstantAggregator::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { + const auto chunk_begin = row; + const auto chunk_end = chunk_begin + payload_chunk.size(); + + if (!inputs.ColumnCount() && payload_chunk.ColumnCount()) { + inputs.Initialize(Allocator::DefaultAllocator(), payload_chunk.GetTypes()); + } + + AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); + idx_t begin = 0; + idx_t filter_idx = 0; + auto partition_end = partition_offsets[partition + 1]; + while (row < chunk_end) { + if (row == partition_end) { + AggegateFinal(*results, partition++); + AggregateInit(); + partition_end = partition_offsets[partition + 1]; + } + partition_end = MinValue(partition_end, chunk_end); + auto end = partition_end - chunk_begin; + + inputs.Reset(); + if (filter_sel) { + // Slice to any filtered rows in [begin, end) + SelectionVector sel; + + // Find the first value in [begin, end) + for (; filter_idx < filtered; ++filter_idx) { + auto idx = filter_sel->get_index(filter_idx); + if (idx >= begin) { + break; + } + } + + // Find the first value in [end, filtered) + sel.Initialize(filter_sel->data() + filter_idx); + idx_t nsel = 0; + for (; filter_idx < filtered; ++filter_idx, ++nsel) { + auto idx = filter_sel->get_index(filter_idx); + if (idx >= end) { + break; + } + } + + if (nsel != inputs.size()) { + inputs.Slice(payload_chunk, sel, nsel); + } + } else { + // Slice to [begin, end) + if (begin) { + for (idx_t c = 0; c < payload_chunk.ColumnCount(); ++c) { + inputs.data[c].Slice(payload_chunk.data[c], begin, end); + } + } else { + inputs.Reference(payload_chunk); + } + inputs.SetCardinality(end - begin); + } + + // Aggregate the filtered rows into a single state + const auto count = inputs.size(); + if (aggr.function.simple_update) { + aggr.function.simple_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), state.data(), count); + } else { + aggr.function.update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, count); + } + + // Skip filtered rows too! + row += end - begin; + begin = end; + } +} + +void WindowConstantAggregator::Finalize() { + AggegateFinal(*results, partition++); +} + +class WindowConstantAggregatorState : public WindowAggregatorState { +public: + WindowConstantAggregatorState() : partition(0) { + matches.Initialize(); + } + ~WindowConstantAggregatorState() override { + } + +public: + //! The current result partition being read + idx_t partition; + //! Shared SV for evaluation + SelectionVector matches; +}; + +unique_ptr WindowConstantAggregator::GetLocalState() const { + return make_uniq(); +} + +void WindowConstantAggregator::Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, + Vector &target, idx_t count) const { + // Chunk up the constants and copy them one at a time + auto &lcstate = lstate.Cast(); + idx_t matched = 0; + idx_t target_offset = 0; + for (idx_t i = 0; i < count; ++i) { + const auto begin = begins[i]; + // Find the partition containing [begin, end) + while (partition_offsets[lcstate.partition + 1] <= begin) { + // Flush the previous partition's data + if (matched) { + VectorOperations::Copy(*results, target, lcstate.matches, matched, 0, target_offset); + target_offset += matched; + matched = 0; + } + ++lcstate.partition; + } + + lcstate.matches.set_index(matched++, lcstate.partition); + } + + // Flush the last partition + if (matched) { + VectorOperations::Copy(*results, target, lcstate.matches, matched, 0, target_offset); + } +} + +//===--------------------------------------------------------------------===// +// WindowCustomAggregator +//===--------------------------------------------------------------------===// +WindowCustomAggregator::WindowCustomAggregator(AggregateObject aggr, const LogicalType &result_type, idx_t count) + : WindowAggregator(std::move(aggr), result_type, count) { +} + +WindowCustomAggregator::~WindowCustomAggregator() { +} + +class WindowCustomAggregatorState : public WindowAggregatorState { +public: + explicit WindowCustomAggregatorState(const AggregateObject &aggr, DataChunk &inputs); + ~WindowCustomAggregatorState() override; + +public: + //! The aggregate function + const AggregateObject &aggr; + //! The aggregate function + DataChunk &inputs; + //! Data pointer that contains a single state, shared by all the custom evaluators + vector state; + //! Reused result state container for the window functions + Vector statef; + //! The frame boundaries, used for the window functions + FrameBounds frame; +}; + +WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject &aggr, DataChunk &inputs) + : aggr(aggr), inputs(inputs), state(aggr.function.state_size()), + statef(Value::POINTER(CastPointerToValue(state.data()))), frame(0, 0) { + // if we have a frame-by-frame method, share the single state + aggr.function.initialize(state.data()); +} + +WindowCustomAggregatorState::~WindowCustomAggregatorState() { + if (aggr.function.destructor) { + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + aggr.function.destructor(statef, aggr_input_data, 1); + } +} + +unique_ptr WindowCustomAggregator::GetLocalState() const { + return make_uniq(aggr, const_cast(inputs)); +} + +void WindowCustomAggregator::Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, + Vector &result, idx_t count) const { + // TODO: window should take a const Vector* + auto &lcstate = lstate.Cast(); + auto &frame = lcstate.frame; + auto params = lcstate.inputs.data.data(); + auto &rmask = FlatVector::Validity(result); + for (idx_t i = 0; i < count; ++i) { + const auto begin = begins[i]; + const auto end = ends[i]; + if (begin >= end) { + rmask.SetInvalid(i); + continue; + } + + // Frame boundaries + auto prev = frame; + frame = FrameBounds(begin, end); + + // Extract the range + AggregateInputData aggr_input_data(aggr.GetFunctionData(), lstate.allocator); + aggr.function.window(params, filter_mask, aggr_input_data, inputs.ColumnCount(), lcstate.state.data(), frame, + prev, result, i, 0); + } +} + +//===--------------------------------------------------------------------===// +// WindowSegmentTree +//===--------------------------------------------------------------------===// +WindowSegmentTree::WindowSegmentTree(AggregateObject aggr, const LogicalType &result_type, idx_t count, + WindowAggregationMode mode_p) + : WindowAggregator(std::move(aggr), result_type, count), internal_nodes(0), mode(mode_p) { +} + +void WindowSegmentTree::Finalize() { + gstate = GetLocalState(); + if (inputs.ColumnCount() > 0) { + if (aggr.function.combine && UseCombineAPI()) { + ConstructTree(); + } + } +} + +WindowSegmentTree::~WindowSegmentTree() { + if (!aggr.function.destructor) { + // nothing to destroy + return; + } + AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); + // call the destructor for all the intermediate states + data_ptr_t address_data[STANDARD_VECTOR_SIZE]; + Vector addresses(LogicalType::POINTER, data_ptr_cast(address_data)); + idx_t count = 0; + for (idx_t i = 0; i < internal_nodes; i++) { + address_data[count++] = data_ptr_t(levels_flat_native.get() + i * state_size); + if (count == STANDARD_VECTOR_SIZE) { + aggr.function.destructor(addresses, aggr_input_data, count); + count = 0; + } + } + if (count > 0) { + aggr.function.destructor(addresses, aggr_input_data, count); + } +} + +class WindowSegmentTreeState : public WindowAggregatorState { +public: + WindowSegmentTreeState(const AggregateObject &aggr, DataChunk &inputs, const ValidityMask &filter_mask); + ~WindowSegmentTreeState() override; + + void FlushStates(bool combining); + void ExtractFrame(idx_t begin, idx_t end, data_ptr_t current_state); + void WindowSegmentValue(const WindowSegmentTree &tree, idx_t l_idx, idx_t begin, idx_t end, + data_ptr_t current_state); + void Finalize(Vector &result, idx_t count); + +public: + //! The aggregate function + const AggregateObject &aggr; + //! The aggregate function + DataChunk &inputs; + //! The filtered rows in inputs + const ValidityMask &filter_mask; + //! The size of a single aggregate state + const idx_t state_size; + //! Data pointer that contains a single state, used for intermediate window segment aggregation + vector state; + //! Input data chunk, used for leaf segment aggregation + DataChunk leaves; + //! The filtered rows in inputs. + SelectionVector filter_sel; + //! A vector of pointers to "state", used for intermediate window segment aggregation + Vector statep; + //! Reused state pointers for combining segment tree levels + Vector statel; + //! Reused result state container for the window functions + Vector statef; + //! Count of buffered values + idx_t flush_count; +}; + +WindowSegmentTreeState::WindowSegmentTreeState(const AggregateObject &aggr, DataChunk &inputs, + const ValidityMask &filter_mask) + : aggr(aggr), inputs(inputs), filter_mask(filter_mask), state_size(aggr.function.state_size()), + state(state_size * STANDARD_VECTOR_SIZE), statep(LogicalType::POINTER), statel(LogicalType::POINTER), + statef(LogicalType::POINTER), flush_count(0) { + if (inputs.ColumnCount() > 0) { + leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); + filter_sel.Initialize(); + } + + // Build the finalise vector that just points to the result states + data_ptr_t state_ptr = state.data(); + D_ASSERT(statef.GetVectorType() == VectorType::FLAT_VECTOR); + statef.SetVectorType(VectorType::CONSTANT_VECTOR); + statef.Flatten(STANDARD_VECTOR_SIZE); + auto fdata = FlatVector::GetData(statef); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; ++i) { + fdata[i] = state_ptr; + state_ptr += state_size; + } +} + +WindowSegmentTreeState::~WindowSegmentTreeState() { +} + +unique_ptr WindowSegmentTree::GetLocalState() const { + return make_uniq(aggr, const_cast(inputs), filter_mask); +} + +void WindowSegmentTreeState::FlushStates(bool combining) { + if (!flush_count) { + return; + } + + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + if (combining) { + statel.Verify(flush_count); + aggr.function.combine(statel, statep, aggr_input_data, flush_count); + } else { + leaves.Reference(inputs); + leaves.Slice(filter_sel, flush_count); + aggr.function.update(&leaves.data[0], aggr_input_data, leaves.ColumnCount(), statep, flush_count); + } + + flush_count = 0; +} + +void WindowSegmentTreeState::ExtractFrame(idx_t begin, idx_t end, data_ptr_t state_ptr) { + const auto count = end - begin; + + // If we are not filtering, + // just update the shared dictionary selection to the range + // Otherwise set it to the input rows that pass the filter + auto states = FlatVector::GetData(statep); + if (filter_mask.AllValid()) { + for (idx_t i = 0; i < count; ++i) { + states[flush_count] = state_ptr; + filter_sel.set_index(flush_count++, begin + i); + if (flush_count >= STANDARD_VECTOR_SIZE) { + FlushStates(false); + } + } + } else { + for (idx_t i = begin; i < end; ++i) { + if (filter_mask.RowIsValid(i)) { + states[flush_count] = state_ptr; + filter_sel.set_index(flush_count++, i); + if (flush_count >= STANDARD_VECTOR_SIZE) { + FlushStates(false); + } + } + } + } +} + +void WindowSegmentTreeState::WindowSegmentValue(const WindowSegmentTree &tree, idx_t l_idx, idx_t begin, idx_t end, + data_ptr_t state_ptr) { + D_ASSERT(begin <= end); + if (begin == end || inputs.ColumnCount() == 0) { + return; + } + + const auto count = end - begin; + if (l_idx == 0) { + ExtractFrame(begin, end, state_ptr); + } else { + // find out where the states begin + auto begin_ptr = tree.levels_flat_native.get() + state_size * (begin + tree.levels_flat_start[l_idx - 1]); + // set up a vector of pointers that point towards the set of states + auto ldata = FlatVector::GetData(statel); + auto pdata = FlatVector::GetData(statep); + for (idx_t i = 0; i < count; i++) { + pdata[flush_count] = state_ptr; + ldata[flush_count++] = begin_ptr; + begin_ptr += state_size; + if (flush_count >= STANDARD_VECTOR_SIZE) { + FlushStates(true); + } + } + } +} +void WindowSegmentTreeState::Finalize(Vector &result, idx_t count) { + // Finalise the result aggregates + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + aggr.function.finalize(statef, aggr_input_data, result, count, 0); + + // Destruct the result aggregates + if (aggr.function.destructor) { + aggr.function.destructor(statef, aggr_input_data, count); + } +} + +void WindowSegmentTree::ConstructTree() { + D_ASSERT(inputs.ColumnCount() > 0); + + // Use a temporary scan state to build the tree + auto >state = gstate->Cast(); + + // compute space required to store internal nodes of segment tree + internal_nodes = 0; + idx_t level_nodes = inputs.size(); + do { + level_nodes = (level_nodes + (TREE_FANOUT - 1)) / TREE_FANOUT; + internal_nodes += level_nodes; + } while (level_nodes > 1); + levels_flat_native = make_unsafe_uniq_array(internal_nodes * state_size); + levels_flat_start.push_back(0); + + idx_t levels_flat_offset = 0; + idx_t level_current = 0; + // level 0 is data itself + idx_t level_size; + // iterate over the levels of the segment tree + while ((level_size = + (level_current == 0 ? inputs.size() : levels_flat_offset - levels_flat_start[level_current - 1])) > 1) { + for (idx_t pos = 0; pos < level_size; pos += TREE_FANOUT) { + // compute the aggregate for this entry in the segment tree + data_ptr_t state_ptr = levels_flat_native.get() + (levels_flat_offset * state_size); + aggr.function.initialize(state_ptr); + gtstate.WindowSegmentValue(*this, level_current, pos, MinValue(level_size, pos + TREE_FANOUT), state_ptr); + gtstate.FlushStates(level_current > 0); + + levels_flat_offset++; + } + + levels_flat_start.push_back(levels_flat_offset); + level_current++; + } + + // Corner case: single element in the window + if (levels_flat_offset == 0) { + aggr.function.initialize(levels_flat_native.get()); + } +} + +void WindowSegmentTree::Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, Vector &result, + idx_t count) const { + auto <state = lstate.Cast(); + const auto cant_combine = (!aggr.function.combine || !UseCombineAPI()); + auto fdata = FlatVector::GetData(ltstate.statef); + + // First pass: aggregate the segment tree nodes + // Share adjacent identical states + // We do this first because we want to share only tree aggregations + idx_t prev_begin = 1; + idx_t prev_end = 0; + auto ldata = FlatVector::GetData(ltstate.statel); + auto pdata = FlatVector::GetData(ltstate.statep); + data_ptr_t prev_state = nullptr; + for (idx_t rid = 0; rid < count; ++rid) { + auto state_ptr = fdata[rid]; + aggr.function.initialize(state_ptr); + + if (cant_combine) { + // Make sure we initialise all states + continue; + } + + auto begin = begins[rid]; + auto end = ends[rid]; + if (begin >= end) { + continue; + } + + // Skip level 0 + idx_t l_idx = 0; + for (; l_idx < levels_flat_start.size() + 1; l_idx++) { + idx_t parent_begin = begin / TREE_FANOUT; + idx_t parent_end = end / TREE_FANOUT; + if (prev_state && l_idx == 1 && begin == prev_begin && end == prev_end) { + // Just combine the previous top level result + ldata[ltstate.flush_count] = prev_state; + pdata[ltstate.flush_count] = state_ptr; + if (++ltstate.flush_count >= STANDARD_VECTOR_SIZE) { + ltstate.FlushStates(true); + } + break; + } + + if (l_idx == 1) { + prev_state = state_ptr; + prev_begin = begin; + prev_end = end; + } + + if (parent_begin == parent_end) { + if (l_idx) { + ltstate.WindowSegmentValue(*this, l_idx, begin, end, state_ptr); + } + break; + } + idx_t group_begin = parent_begin * TREE_FANOUT; + if (begin != group_begin) { + if (l_idx) { + ltstate.WindowSegmentValue(*this, l_idx, begin, group_begin + TREE_FANOUT, state_ptr); + } + parent_begin++; + } + idx_t group_end = parent_end * TREE_FANOUT; + if (end != group_end) { + if (l_idx) { + ltstate.WindowSegmentValue(*this, l_idx, group_end, end, state_ptr); + } + } + begin = parent_begin; + end = parent_end; + } + } + ltstate.FlushStates(true); + + // Second pass: aggregate the ragged leaves + // (or everything if we can't combine) + for (idx_t rid = 0; rid < count; ++rid) { + auto state_ptr = fdata[rid]; + + const auto begin = begins[rid]; + const auto end = ends[rid]; + if (begin >= end) { + continue; + } + + // Aggregate everything at once if we can't combine states + idx_t parent_begin = begin / TREE_FANOUT; + idx_t parent_end = end / TREE_FANOUT; + if (parent_begin == parent_end || cant_combine) { + ltstate.WindowSegmentValue(*this, 0, begin, end, state_ptr); + continue; + } + + idx_t group_begin = parent_begin * TREE_FANOUT; + if (begin != group_begin) { + ltstate.WindowSegmentValue(*this, 0, begin, group_begin + TREE_FANOUT, state_ptr); + parent_begin++; + } + idx_t group_end = parent_end * TREE_FANOUT; + if (end != group_end) { + ltstate.WindowSegmentValue(*this, 0, group_end, end, state_ptr); + } + } + ltstate.FlushStates(false); + + ltstate.Finalize(result, count); + + // Set the validity mask on the invalid rows + auto &rmask = FlatVector::Validity(result); + for (idx_t rid = 0; rid < count; ++rid) { + const auto begin = begins[rid]; + const auto end = ends[rid]; + + if (begin >= end) { + rmask.SetInvalid(rid); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/distributive/count.cpp b/src/duckdb/src/function/aggregate/distributive/count.cpp new file mode 100644 index 00000000..ec9d705b --- /dev/null +++ b/src/duckdb/src/function/aggregate/distributive/count.cpp @@ -0,0 +1,259 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +struct BaseCountFunction { + template + static void Initialize(STATE &state) { + state = 0; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target += source; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + target = state; + } +}; + +struct CountStarFunction : public BaseCountFunction { + template + static void Operation(STATE &state, AggregateInputData &, idx_t idx) { + state += 1; + } + + template + static void ConstantOperation(STATE &state, AggregateInputData &, idx_t count) { + state += count; + } + + template + static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, + idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, + Vector &result, idx_t rid, idx_t bias) { + D_ASSERT(input_count == 0); + auto data = FlatVector::GetData(result); + const auto begin = frame.start; + const auto end = frame.end; + // Slice to any filtered rows + if (!filter_mask.AllValid()) { + RESULT_TYPE filtered = 0; + for (auto i = begin; i < end; ++i) { + filtered += filter_mask.RowIsValid(i); + } + data[rid] = filtered; + } else { + data[rid] = end - begin; + } + } +}; + +struct CountFunction : public BaseCountFunction { + using STATE = int64_t; + + static void Operation(STATE &state) { + state += 1; + } + + static void ConstantOperation(STATE &state, idx_t count) { + state += count; + } + + static bool IgnoreNull() { + return true; + } + + static inline void CountFlatLoop(STATE **__restrict states, ValidityMask &mask, idx_t count) { + if (!mask.AllValid()) { + idx_t base_idx = 0; + auto entry_count = ValidityMask::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + auto validity_entry = mask.GetValidityEntry(entry_idx); + idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + // all valid: perform operation + for (; base_idx < next; base_idx++) { + CountFunction::Operation(*states[base_idx]); + } + } else if (ValidityMask::NoneValid(validity_entry)) { + // nothing valid: skip all + base_idx = next; + continue; + } else { + // partially valid: need to check individual elements for validity + idx_t start = base_idx; + for (; base_idx < next; base_idx++) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + CountFunction::Operation(*states[base_idx]); + } + } + } + } + } else { + for (idx_t i = 0; i < count; i++) { + CountFunction::Operation(*states[i]); + } + } + } + + static inline void CountScatterLoop(STATE **__restrict states, const SelectionVector &isel, + const SelectionVector &ssel, ValidityMask &mask, idx_t count) { + if (!mask.AllValid()) { + // potential NULL values + for (idx_t i = 0; i < count; i++) { + auto idx = isel.get_index(i); + auto sidx = ssel.get_index(i); + if (mask.RowIsValid(idx)) { + CountFunction::Operation(*states[sidx]); + } + } + } else { + // quick path: no NULL values + for (idx_t i = 0; i < count; i++) { + auto sidx = ssel.get_index(i); + CountFunction::Operation(*states[sidx]); + } + } + } + + static void CountScatter(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, + idx_t count) { + auto &input = inputs[0]; + if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) { + auto sdata = FlatVector::GetData(states); + CountFlatLoop(sdata, FlatVector::Validity(input), count); + } else { + UnifiedVectorFormat idata, sdata; + input.ToUnifiedFormat(count, idata); + states.ToUnifiedFormat(count, sdata); + CountScatterLoop(reinterpret_cast(sdata.data), *idata.sel, *sdata.sel, idata.validity, count); + } + } + + static inline void CountFlatUpdateLoop(STATE &result, ValidityMask &mask, idx_t count) { + idx_t base_idx = 0; + auto entry_count = ValidityMask::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + auto validity_entry = mask.GetValidityEntry(entry_idx); + idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + // all valid + result += next - base_idx; + base_idx = next; + } else if (ValidityMask::NoneValid(validity_entry)) { + // nothing valid: skip all + base_idx = next; + continue; + } else { + // partially valid: need to check individual elements for validity + idx_t start = base_idx; + for (; base_idx < next; base_idx++) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + result++; + } + } + } + } + } + + static inline void CountUpdateLoop(STATE &result, ValidityMask &mask, idx_t count, + const SelectionVector &sel_vector) { + if (mask.AllValid()) { + // no NULL values + result += count; + return; + } + for (idx_t i = 0; i < count; i++) { + auto idx = sel_vector.get_index(i); + if (mask.RowIsValid(idx)) { + result++; + } + } + } + + static void CountUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state_p, idx_t count) { + auto &input = inputs[0]; + auto &result = *reinterpret_cast(state_p); + switch (input.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + if (!ConstantVector::IsNull(input)) { + // if the constant is not null increment the state + result += count; + } + break; + } + case VectorType::FLAT_VECTOR: { + CountFlatUpdateLoop(result, FlatVector::Validity(input), count); + break; + } + case VectorType::SEQUENCE_VECTOR: { + // sequence vectors cannot have NULL values + result += count; + break; + } + default: { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + CountUpdateLoop(result, idata.validity, count, *idata.sel); + break; + } + } + } +}; + +AggregateFunction CountFun::GetFunction() { + AggregateFunction fun({LogicalType(LogicalTypeId::ANY)}, LogicalType::BIGINT, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, CountFunction::CountScatter, + AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, + FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate); + fun.name = "count"; + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +AggregateFunction CountStarFun::GetFunction() { + auto fun = AggregateFunction::NullaryAggregate(LogicalType::BIGINT); + fun.name = "count_star"; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.window = CountStarFunction::Window; + return fun; +} + +unique_ptr CountPropagateStats(ClientContext &context, BoundAggregateExpression &expr, + AggregateStatisticsInput &input) { + if (!expr.IsDistinct() && !input.child_stats[0].CanHaveNull()) { + // count on a column without null values: use count star + expr.function = CountStarFun::GetFunction(); + expr.function.name = "count_star"; + expr.children.clear(); + } + return nullptr; +} + +void CountFun::RegisterFunction(BuiltinFunctions &set) { + AggregateFunction count_function = CountFun::GetFunction(); + count_function.statistics = CountPropagateStats; + AggregateFunctionSet count("count"); + count.AddFunction(count_function); + // the count function can also be called without arguments + count_function.arguments.clear(); + count_function.statistics = nullptr; + count_function.window = CountStarFunction::Window; + count.AddFunction(count_function); + set.AddFunction(count); +} + +void CountStarFun::RegisterFunction(BuiltinFunctions &set) { + AggregateFunctionSet count("count_star"); + count.AddFunction(CountStarFun::GetFunction()); + set.AddFunction(count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/distributive/first.cpp b/src/duckdb/src/function/aggregate/distributive/first.cpp new file mode 100644 index 00000000..8f8df8a7 --- /dev/null +++ b/src/duckdb/src/function/aggregate/distributive/first.cpp @@ -0,0 +1,355 @@ +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +template +struct FirstState { + T value; + bool is_set; + bool is_null; +}; + +struct FirstFunctionBase { + template + static void Initialize(STATE &state) { + state.is_set = false; + state.is_null = false; + } + + static bool IgnoreNull() { + return false; + } +}; + +template +struct FirstFunction : public FirstFunctionBase { + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (LAST || !state.is_set) { + if (!unary_input.RowIsValid()) { + if (!SKIP_NULLS) { + state.is_set = true; + } + state.is_null = true; + } else { + state.is_set = true; + state.is_null = false; + state.value = input; + } + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + Operation(state, input, unary_input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!target.is_set) { + target = source; + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set || state.is_null) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +template +struct FirstFunctionString : public FirstFunctionBase { + template + static void SetValue(STATE &state, AggregateInputData &input_data, string_t value, bool is_null) { + if (LAST && state.is_set) { + Destroy(state, input_data); + } + if (is_null) { + if (!SKIP_NULLS) { + state.is_set = true; + state.is_null = true; + } + } else { + state.is_set = true; + state.is_null = false; + if (value.IsInlined()) { + state.value = value; + } else { + // non-inlined string, need to allocate space for it + auto len = value.GetSize(); + auto ptr = new char[len]; + memcpy(ptr, value.GetData(), len); + + state.value = string_t(ptr, len); + } + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (LAST || !state.is_set) { + SetValue(state, unary_input.input, input, !unary_input.RowIsValid()); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + Operation(state, input, unary_input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (source.is_set && (LAST || !target.is_set)) { + SetValue(target, input_data, source.value, source.is_null); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set || state.is_null) { + finalize_data.ReturnNull(); + } else { + target = StringVector::AddStringOrBlob(finalize_data.result, state.value); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.is_set && !state.is_null && !state.value.IsInlined()) { + delete[] state.value.GetData(); + } + } +}; + +struct FirstStateVector { + Vector *value; +}; + +template +struct FirstVectorFunction { + template + static void Initialize(STATE &state) { + state.value = nullptr; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.value) { + delete state.value; + } + } + static bool IgnoreNull() { + return SKIP_NULLS; + } + + template + static void SetValue(STATE &state, Vector &input, const idx_t idx) { + if (!state.value) { + state.value = new Vector(input.GetType()); + state.value->SetVectorType(VectorType::CONSTANT_VECTOR); + } + sel_t selv = idx; + SelectionVector sel(&selv); + VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); + } + + static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { + auto &input = inputs[0]; + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto states = UnifiedVectorFormat::GetData(sdata); + for (idx_t i = 0; i < count; i++) { + const auto idx = idata.sel->get_index(i); + if (SKIP_NULLS && !idata.validity.RowIsValid(idx)) { + continue; + } + auto &state = *states[sdata.sel->get_index(i)]; + if (LAST || !state.value) { + SetValue(state, input, i); + } + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.value && (LAST || !target.value)) { + SetValue(target, *source.value, 0); + } + } + + template + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (!state.value) { + finalize_data.ReturnNull(); + } else { + VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); + } + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function.arguments[0] = arguments[0]->return_type; + function.return_type = arguments[0]->return_type; + return nullptr; + } +}; + +template +static AggregateFunction GetFirstAggregateTemplated(LogicalType type) { + return AggregateFunction::UnaryAggregate, T, T, FirstFunction>(type, type); +} + +template +static AggregateFunction GetFirstFunction(const LogicalType &type); + +template +AggregateFunction GetDecimalFirstFunction(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::DECIMAL); + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetFirstFunction(LogicalType::SMALLINT); + case PhysicalType::INT32: + return GetFirstFunction(LogicalType::INTEGER); + case PhysicalType::INT64: + return GetFirstFunction(LogicalType::BIGINT); + default: + return GetFirstFunction(LogicalType::HUGEINT); + } +} + +template +static AggregateFunction GetFirstFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::TINYINT: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::SMALLINT: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::INTEGER: + case LogicalTypeId::DATE: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::BIGINT: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::UTINYINT: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::USMALLINT: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::UINTEGER: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::UBIGINT: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::HUGEINT: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::FLOAT: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::DOUBLE: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::INTERVAL: + return GetFirstAggregateTemplated(type); + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, + FirstFunctionString>(type, type); + case LogicalTypeId::DECIMAL: { + type.Verify(); + AggregateFunction function = GetDecimalFirstFunction(type); + function.arguments[0] = type; + function.return_type = type; + // TODO set_key here? + return function; + } + default: { + using OP = FirstVectorFunction; + return AggregateFunction({type}, type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, OP::Update, + AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateDestroy, nullptr, nullptr); + } + } +} + +AggregateFunction FirstFun::GetFunction(const LogicalType &type) { + auto fun = GetFirstFunction(type); + fun.name = "first"; + return fun; +} + +template +unique_ptr BindDecimalFirst(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + auto name = std::move(function.name); + function = GetFirstFunction(decimal_type); + function.name = std::move(name); + function.return_type = decimal_type; + return nullptr; +} + +template +static AggregateFunction GetFirstOperator(const LogicalType &type) { + if (type.id() == LogicalTypeId::DECIMAL) { + throw InternalException("FIXME: this shouldn't happen..."); + } + return GetFirstFunction(type); +} + +template +unique_ptr BindFirst(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto input_type = arguments[0]->return_type; + auto name = std::move(function.name); + function = GetFirstOperator(input_type); + function.name = std::move(name); + if (function.bind) { + return function.bind(context, function, arguments); + } else { + return nullptr; + } +} + +template +static void AddFirstOperator(AggregateFunctionSet &set) { + set.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, BindDecimalFirst)); + set.AddFunction(AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, BindFirst)); +} + +void FirstFun::RegisterFunction(BuiltinFunctions &set) { + AggregateFunctionSet first("first"); + AggregateFunctionSet last("last"); + AggregateFunctionSet any_value("any_value"); + + AddFirstOperator(first); + AddFirstOperator(last); + AddFirstOperator(any_value); + + set.AddFunction(first); + first.name = "arbitrary"; + set.AddFunction(first); + + set.AddFunction(last); + + set.AddFunction(any_value); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/distributive_functions.cpp b/src/duckdb/src/function/aggregate/distributive_functions.cpp new file mode 100644 index 00000000..5971861c --- /dev/null +++ b/src/duckdb/src/function/aggregate/distributive_functions.cpp @@ -0,0 +1,15 @@ +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/aggregate_function.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterDistributiveAggregates() { + Register(); + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp new file mode 100644 index 00000000..6e5f2196 --- /dev/null +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -0,0 +1,585 @@ +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" + +namespace duckdb { + +struct SortedAggregateBindData : public FunctionData { + SortedAggregateBindData(ClientContext &context, BoundAggregateExpression &expr) + : buffer_manager(BufferManager::GetBufferManager(context)), function(expr.function), + bind_info(std::move(expr.bind_info)), threshold(ClientConfig::GetConfig(context).ordered_aggregate_threshold), + external(ClientConfig::GetConfig(context).force_external) { + auto &children = expr.children; + arg_types.reserve(children.size()); + for (const auto &child : children) { + arg_types.emplace_back(child->return_type); + } + auto &order_bys = *expr.order_bys; + sort_types.reserve(order_bys.orders.size()); + for (auto &order : order_bys.orders) { + orders.emplace_back(order.Copy()); + sort_types.emplace_back(order.expression->return_type); + } + sorted_on_args = (children.size() == order_bys.orders.size()); + for (size_t i = 0; sorted_on_args && i < children.size(); ++i) { + sorted_on_args = children[i]->Equals(*order_bys.orders[i].expression); + } + } + + SortedAggregateBindData(const SortedAggregateBindData &other) + : buffer_manager(other.buffer_manager), function(other.function), arg_types(other.arg_types), + sort_types(other.sort_types), sorted_on_args(other.sorted_on_args), threshold(other.threshold), + external(other.external) { + if (other.bind_info) { + bind_info = other.bind_info->Copy(); + } + for (auto &order : other.orders) { + orders.emplace_back(order.Copy()); + } + } + + unique_ptr Copy() const override { + return make_uniq(*this); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + if (bind_info && other.bind_info) { + if (!bind_info->Equals(*other.bind_info)) { + return false; + } + } else if (bind_info || other.bind_info) { + return false; + } + if (function != other.function) { + return false; + } + if (orders.size() != other.orders.size()) { + return false; + } + for (size_t i = 0; i < orders.size(); ++i) { + if (!orders[i].Equals(other.orders[i])) { + return false; + } + } + return true; + } + + BufferManager &buffer_manager; + AggregateFunction function; + vector arg_types; + unique_ptr bind_info; + + vector orders; + vector sort_types; + bool sorted_on_args; + + //! The sort flush threshold + const idx_t threshold; + const bool external; +}; + +struct SortedAggregateState { + //! Default buffer size, optimised for small group to avoid blowing out memory. + static const idx_t BUFFER_CAPACITY = 16; + + SortedAggregateState() : count(0), nsel(0), offset(0) { + } + + static inline void InitializeBuffer(DataChunk &chunk, const vector &types) { + if (!chunk.ColumnCount() && !types.empty()) { + chunk.Initialize(Allocator::DefaultAllocator(), types, BUFFER_CAPACITY); + } + } + + //! Make sure the buffer is large enough for slicing + static inline void ResetBuffer(DataChunk &chunk, const vector &types) { + chunk.Reset(); + chunk.Destroy(); + chunk.Initialize(Allocator::DefaultAllocator(), types); + } + + void Flush(const SortedAggregateBindData &order_bind) { + if (ordering) { + return; + } + + ordering = make_uniq(order_bind.buffer_manager, order_bind.sort_types); + InitializeBuffer(sort_buffer, order_bind.sort_types); + ordering->Append(sort_buffer); + ResetBuffer(sort_buffer, order_bind.sort_types); + + if (!order_bind.sorted_on_args) { + arguments = make_uniq(order_bind.buffer_manager, order_bind.arg_types); + InitializeBuffer(arg_buffer, order_bind.arg_types); + arguments->Append(arg_buffer); + ResetBuffer(arg_buffer, order_bind.arg_types); + } + } + + void Update(const SortedAggregateBindData &order_bind, DataChunk &sort_chunk, DataChunk &arg_chunk) { + count += sort_chunk.size(); + + // Lazy instantiation of the buffer chunks + InitializeBuffer(sort_buffer, order_bind.sort_types); + if (!order_bind.sorted_on_args) { + InitializeBuffer(arg_buffer, order_bind.arg_types); + } + + if (sort_chunk.size() + sort_buffer.size() > STANDARD_VECTOR_SIZE) { + Flush(order_bind); + } + if (arguments) { + ordering->Append(sort_chunk); + arguments->Append(arg_chunk); + } else if (ordering) { + ordering->Append(sort_chunk); + } else if (order_bind.sorted_on_args) { + sort_buffer.Append(sort_chunk, true); + } else { + sort_buffer.Append(sort_chunk, true); + arg_buffer.Append(arg_chunk, true); + } + } + + void UpdateSlice(const SortedAggregateBindData &order_bind, DataChunk &sort_inputs, DataChunk &arg_inputs) { + count += nsel; + + // Lazy instantiation of the buffer chunks + InitializeBuffer(sort_buffer, order_bind.sort_types); + if (!order_bind.sorted_on_args) { + InitializeBuffer(arg_buffer, order_bind.arg_types); + } + + if (nsel + sort_buffer.size() > STANDARD_VECTOR_SIZE) { + Flush(order_bind); + } + if (arguments) { + sort_buffer.Reset(); + sort_buffer.Slice(sort_inputs, sel, nsel); + ordering->Append(sort_buffer); + + arg_buffer.Reset(); + arg_buffer.Slice(arg_inputs, sel, nsel); + arguments->Append(arg_buffer); + } else if (ordering) { + sort_buffer.Reset(); + sort_buffer.Slice(sort_inputs, sel, nsel); + ordering->Append(sort_buffer); + } else if (order_bind.sorted_on_args) { + sort_buffer.Append(sort_inputs, true, &sel, nsel); + } else { + sort_buffer.Append(sort_inputs, true, &sel, nsel); + arg_buffer.Append(arg_inputs, true, &sel, nsel); + } + + nsel = 0; + offset = 0; + } + + void Combine(SortedAggregateBindData &order_bind, SortedAggregateState &other) { + if (other.arguments) { + // Force CDC if the other has it + Flush(order_bind); + ordering->Combine(*other.ordering); + arguments->Combine(*other.arguments); + count += other.count; + } else if (other.ordering) { + // Force CDC if the other has it + Flush(order_bind); + ordering->Combine(*other.ordering); + count += other.count; + } else if (other.sort_buffer.size()) { + Update(order_bind, other.sort_buffer, other.arg_buffer); + } + } + + void PrefixSortBuffer(DataChunk &prefixed) { + for (column_t col_idx = 0; col_idx < sort_buffer.ColumnCount(); ++col_idx) { + prefixed.data[col_idx + 1].Reference(sort_buffer.data[col_idx]); + } + prefixed.SetCardinality(sort_buffer); + } + + void Finalize(const SortedAggregateBindData &order_bind, DataChunk &prefixed, LocalSortState &local_sort) { + if (arguments) { + ColumnDataScanState sort_state; + ordering->InitializeScan(sort_state); + ColumnDataScanState arg_state; + arguments->InitializeScan(arg_state); + for (sort_buffer.Reset(); ordering->Scan(sort_state, sort_buffer); sort_buffer.Reset()) { + PrefixSortBuffer(prefixed); + arg_buffer.Reset(); + arguments->Scan(arg_state, arg_buffer); + local_sort.SinkChunk(prefixed, arg_buffer); + } + ordering->Reset(); + arguments->Reset(); + } else if (ordering) { + ColumnDataScanState sort_state; + ordering->InitializeScan(sort_state); + for (sort_buffer.Reset(); ordering->Scan(sort_state, sort_buffer); sort_buffer.Reset()) { + PrefixSortBuffer(prefixed); + local_sort.SinkChunk(prefixed, sort_buffer); + } + ordering->Reset(); + } else if (order_bind.sorted_on_args) { + PrefixSortBuffer(prefixed); + local_sort.SinkChunk(prefixed, sort_buffer); + } else { + PrefixSortBuffer(prefixed); + local_sort.SinkChunk(prefixed, arg_buffer); + } + } + + idx_t count; + unique_ptr arguments; + unique_ptr ordering; + + DataChunk sort_buffer; + DataChunk arg_buffer; + + // Selection for scattering + SelectionVector sel; + idx_t nsel; + idx_t offset; +}; + +struct SortedAggregateFunction { + template + static void Initialize(STATE &state) { + new (&state) STATE(); + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.~STATE(); + } + + static void ProjectInputs(Vector inputs[], const SortedAggregateBindData &order_bind, idx_t input_count, + idx_t count, DataChunk &arg_chunk, DataChunk &sort_chunk) { + idx_t col = 0; + + if (!order_bind.sorted_on_args) { + arg_chunk.InitializeEmpty(order_bind.arg_types); + for (auto &dst : arg_chunk.data) { + dst.Reference(inputs[col++]); + } + arg_chunk.SetCardinality(count); + } + + sort_chunk.InitializeEmpty(order_bind.sort_types); + for (auto &dst : sort_chunk.data) { + dst.Reference(inputs[col++]); + } + sort_chunk.SetCardinality(count); + } + + static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, + idx_t count) { + const auto order_bind = aggr_input_data.bind_data->Cast(); + DataChunk arg_chunk; + DataChunk sort_chunk; + ProjectInputs(inputs, order_bind, input_count, count, arg_chunk, sort_chunk); + + const auto order_state = reinterpret_cast(state); + order_state->Update(order_bind, sort_chunk, arg_chunk); + } + + static void ScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, + idx_t count) { + if (!count) { + return; + } + + // Append the arguments to the two sub-collections + const auto &order_bind = aggr_input_data.bind_data->Cast(); + DataChunk arg_inputs; + DataChunk sort_inputs; + ProjectInputs(inputs, order_bind, input_count, count, arg_inputs, sort_inputs); + + // We have to scatter the chunks one at a time + // so build a selection vector for each one. + UnifiedVectorFormat svdata; + states.ToUnifiedFormat(count, svdata); + + // Size the selection vector for each state. + auto sdata = UnifiedVectorFormat::GetDataNoConst(svdata); + for (idx_t i = 0; i < count; ++i) { + auto sidx = svdata.sel->get_index(i); + auto order_state = sdata[sidx]; + order_state->nsel++; + } + + // Build the selection vector for each state. + vector sel_data(count); + idx_t start = 0; + for (idx_t i = 0; i < count; ++i) { + auto sidx = svdata.sel->get_index(i); + auto order_state = sdata[sidx]; + if (!order_state->offset) { + // First one + order_state->offset = start; + order_state->sel.Initialize(sel_data.data() + order_state->offset); + start += order_state->nsel; + } + sel_data[order_state->offset++] = sidx; + } + + // Append nonempty slices to the arguments + for (idx_t i = 0; i < count; ++i) { + auto sidx = svdata.sel->get_index(i); + auto order_state = sdata[sidx]; + if (!order_state->nsel) { + continue; + } + + order_state->UpdateSlice(order_bind, sort_inputs, arg_inputs); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + auto &order_bind = aggr_input_data.bind_data->Cast(); + auto &other = const_cast(source); + target.Combine(order_bind, other); + } + + static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, + idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, + Vector &result, idx_t rid, idx_t bias) { + throw InternalException("Sorted aggregates should not be generated for window clauses"); + } + + static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + const idx_t offset) { + auto &order_bind = aggr_input_data.bind_data->Cast(); + auto &buffer_manager = order_bind.buffer_manager; + RowLayout payload_layout; + payload_layout.Initialize(order_bind.arg_types); + DataChunk chunk; + chunk.Initialize(Allocator::DefaultAllocator(), order_bind.arg_types); + DataChunk sliced; + sliced.Initialize(Allocator::DefaultAllocator(), order_bind.arg_types); + + // Reusable inner state + vector agg_state(order_bind.function.state_size()); + Vector agg_state_vec(Value::POINTER(CastPointerToValue(agg_state.data()))); + + // State variables + auto bind_info = order_bind.bind_info.get(); + ArenaAllocator allocator(Allocator::DefaultAllocator()); + AggregateInputData aggr_bind_info(bind_info, allocator); + + // Inner aggregate APIs + auto initialize = order_bind.function.initialize; + auto destructor = order_bind.function.destructor; + auto simple_update = order_bind.function.simple_update; + auto update = order_bind.function.update; + auto finalize = order_bind.function.finalize; + + auto sdata = FlatVector::GetData(states); + + vector state_unprocessed(count, 0); + for (idx_t i = 0; i < count; ++i) { + state_unprocessed[i] = sdata[i]->count; + } + + // Sort the input payloads on (state_idx ASC, orders) + vector orders; + orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, + make_uniq(Value::USMALLINT(0)))); + for (const auto &order : order_bind.orders) { + orders.emplace_back(order.Copy()); + } + + auto global_sort = make_uniq(buffer_manager, orders, payload_layout); + global_sort->external = order_bind.external; + auto local_sort = make_uniq(); + local_sort->Initialize(*global_sort, global_sort->buffer_manager); + + DataChunk prefixed; + prefixed.Initialize(Allocator::DefaultAllocator(), global_sort->sort_layout.logical_types); + + // Go through the states accumulating values to sort until we hit the sort threshold + idx_t unsorted_count = 0; + idx_t sorted = 0; + for (idx_t finalized = 0; finalized < count;) { + if (unsorted_count < order_bind.threshold) { + auto state = sdata[finalized]; + prefixed.Reset(); + prefixed.data[0].Reference(Value::USMALLINT(finalized)); + state->Finalize(order_bind, prefixed, *local_sort); + unsorted_count += state_unprocessed[finalized]; + + // Go to the next aggregate unless this is the last one + if (++finalized < count) { + continue; + } + } + + // If they were all empty (filtering) flush them + // (This can only happen on the last range) + if (!unsorted_count) { + break; + } + + // Sort all the data + global_sort->AddLocalState(*local_sort); + global_sort->PrepareMergePhase(); + while (global_sort->sorted_blocks.size() > 1) { + global_sort->InitializeMergeRound(); + MergeSorter merge_sorter(*global_sort, global_sort->buffer_manager); + merge_sorter.PerformInMergeRound(); + global_sort->CompleteMergeRound(false); + } + + auto scanner = make_uniq(*global_sort); + initialize(agg_state.data()); + while (scanner->Remaining()) { + chunk.Reset(); + scanner->Scan(chunk); + idx_t consumed = 0; + + // Distribute the scanned chunk to the aggregates + while (consumed < chunk.size()) { + // Find the next aggregate that needs data + for (; !state_unprocessed[sorted]; ++sorted) { + // Finalize a single value at the next offset + agg_state_vec.SetVectorType(states.GetVectorType()); + finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); + if (destructor) { + destructor(agg_state_vec, aggr_bind_info, 1); + } + + initialize(agg_state.data()); + } + const auto input_count = MinValue(state_unprocessed[sorted], chunk.size() - consumed); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + sliced.data[col_idx].Slice(chunk.data[col_idx], consumed, consumed + input_count); + } + sliced.SetCardinality(input_count); + + // These are all simple updates, so use it if available + if (simple_update) { + simple_update(sliced.data.data(), aggr_bind_info, sliced.data.size(), agg_state.data(), + sliced.size()); + } else { + // We are only updating a constant state + agg_state_vec.SetVectorType(VectorType::CONSTANT_VECTOR); + update(sliced.data.data(), aggr_bind_info, sliced.data.size(), agg_state_vec, sliced.size()); + } + + consumed += input_count; + state_unprocessed[sorted] -= input_count; + } + } + + // Finalize the last state for this sort + agg_state_vec.SetVectorType(states.GetVectorType()); + finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); + if (destructor) { + destructor(agg_state_vec, aggr_bind_info, 1); + } + ++sorted; + + // Stop if we are done + if (finalized >= count) { + break; + } + + // Create a new sort + scanner.reset(); + global_sort = make_uniq(buffer_manager, orders, payload_layout); + global_sort->external = order_bind.external; + local_sort = make_uniq(); + local_sort->Initialize(*global_sort, global_sort->buffer_manager); + unsorted_count = 0; + } + + for (; sorted < count; ++sorted) { + initialize(agg_state.data()); + + // Finalize a single value at the next offset + agg_state_vec.SetVectorType(states.GetVectorType()); + finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); + + if (destructor) { + destructor(agg_state_vec, aggr_bind_info, 1); + } + } + + result.Verify(count); + } +}; + +void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, + const vector> &groups) { + if (!expr.order_bys || expr.order_bys->orders.empty() || expr.children.empty()) { + // not a sorted aggregate: return + return; + } + if (context.config.enable_optimizer) { + // for each ORDER BY - check if it is actually necessary + // expressions that are in the groups do not need to be ORDERED BY + // `ORDER BY` on a group has no effect, because for each aggregate, the group is unique + // similarly, we only need to ORDER BY each aggregate once + expression_set_t seen_expressions; + for (auto &target : groups) { + seen_expressions.insert(*target); + } + vector new_order_nodes; + for (auto &order_node : expr.order_bys->orders) { + if (seen_expressions.find(*order_node.expression) != seen_expressions.end()) { + // we do not need to order by this node + continue; + } + seen_expressions.insert(*order_node.expression); + new_order_nodes.push_back(std::move(order_node)); + } + if (new_order_nodes.empty()) { + expr.order_bys.reset(); + return; + } + expr.order_bys->orders = std::move(new_order_nodes); + } + auto &bound_function = expr.function; + auto &children = expr.children; + auto &order_bys = *expr.order_bys; + auto sorted_bind = make_uniq(context, expr); + + if (!sorted_bind->sorted_on_args) { + // The arguments are the children plus the sort columns. + for (auto &order : order_bys.orders) { + children.emplace_back(std::move(order.expression)); + } + } + + vector arguments; + arguments.reserve(children.size()); + for (const auto &child : children) { + arguments.emplace_back(child->return_type); + } + + // Replace the aggregate with the wrapper + AggregateFunction ordered_aggregate( + bound_function.name, arguments, bound_function.return_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + SortedAggregateFunction::ScatterUpdate, + AggregateFunction::StateCombine, + SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, + AggregateFunction::StateDestroy, nullptr, + SortedAggregateFunction::Window); + + expr.function = std::move(ordered_aggregate); + expr.bind_info = std::move(sorted_bind); + expr.order_bys.reset(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/built_in_functions.cpp b/src/duckdb/src/function/built_in_functions.cpp new file mode 100644 index 00000000..23372fcd --- /dev/null +++ b/src/duckdb/src/function/built_in_functions.cpp @@ -0,0 +1,88 @@ +#include "duckdb/function/built_in_functions.hpp" +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" +#include "duckdb/parser/parsed_data/create_collation_info.hpp" +#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" +#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" + +namespace duckdb { + +BuiltinFunctions::BuiltinFunctions(CatalogTransaction transaction, Catalog &catalog) + : transaction(transaction), catalog(catalog) { +} + +BuiltinFunctions::~BuiltinFunctions() { +} + +void BuiltinFunctions::AddCollation(string name, ScalarFunction function, bool combinable, + bool not_required_for_equality) { + CreateCollationInfo info(std::move(name), std::move(function), combinable, not_required_for_equality); + info.internal = true; + catalog.CreateCollation(transaction, info); +} + +void BuiltinFunctions::AddFunction(AggregateFunctionSet set) { + CreateAggregateFunctionInfo info(std::move(set)); + info.internal = true; + catalog.CreateFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(AggregateFunction function) { + CreateAggregateFunctionInfo info(std::move(function)); + info.internal = true; + catalog.CreateFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(PragmaFunction function) { + CreatePragmaFunctionInfo info(std::move(function)); + info.internal = true; + catalog.CreatePragmaFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(const string &name, PragmaFunctionSet functions) { + CreatePragmaFunctionInfo info(name, std::move(functions)); + info.internal = true; + catalog.CreatePragmaFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(ScalarFunction function) { + CreateScalarFunctionInfo info(std::move(function)); + info.internal = true; + catalog.CreateFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(const vector &names, ScalarFunction function) { // NOLINT: false positive + for (auto &name : names) { + function.name = name; + AddFunction(function); + } +} + +void BuiltinFunctions::AddFunction(ScalarFunctionSet set) { + CreateScalarFunctionInfo info(std::move(set)); + info.internal = true; + catalog.CreateFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(TableFunction function) { + CreateTableFunctionInfo info(std::move(function)); + info.internal = true; + catalog.CreateTableFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(TableFunctionSet set) { + CreateTableFunctionInfo info(std::move(set)); + info.internal = true; + catalog.CreateTableFunction(transaction, info); +} + +void BuiltinFunctions::AddFunction(CopyFunction function) { + CreateCopyFunctionInfo info(std::move(function)); + info.internal = true; + catalog.CreateCopyFunction(transaction, info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/bit_cast.cpp b/src/duckdb/src/function/cast/bit_cast.cpp new file mode 100644 index 00000000..aaffc12d --- /dev/null +++ b/src/duckdb/src/function/cast/bit_cast.cpp @@ -0,0 +1,49 @@ +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +BoundCastInfo DefaultCasts::BitCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + // Numerics + case LogicalTypeId::BOOLEAN: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::TINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::SMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::INTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::BIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::UTINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::USMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::UINTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::UBIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::HUGEINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::FLOAT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::DOUBLE: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::BLOB: + return BoundCastInfo(&VectorCastHelpers::StringCast); + + case LogicalTypeId::VARCHAR: + return BoundCastInfo(&VectorCastHelpers::StringCast); + + default: + return DefaultCasts::TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/blob_cast.cpp b/src/duckdb/src/function/cast/blob_cast.cpp new file mode 100644 index 00000000..170a733d --- /dev/null +++ b/src/duckdb/src/function/cast/blob_cast.cpp @@ -0,0 +1,22 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +BoundCastInfo DefaultCasts::BlobCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // blob to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::AGGREGATE_STATE: + return DefaultCasts::ReinterpretCast; + case LogicalTypeId::BIT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + + default: + return DefaultCasts::TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/cast_function_set.cpp b/src/duckdb/src/function/cast/cast_function_set.cpp new file mode 100644 index 00000000..cc152e62 --- /dev/null +++ b/src/duckdb/src/function/cast/cast_function_set.cpp @@ -0,0 +1,193 @@ + +#include "duckdb/function/cast/cast_function_set.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types/type_map.hpp" +#include "duckdb/function/cast_rules.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +BindCastInput::BindCastInput(CastFunctionSet &function_set, optional_ptr info, + optional_ptr context) + : function_set(function_set), info(info), context(context) { +} + +BoundCastInfo BindCastInput::GetCastFunction(const LogicalType &source, const LogicalType &target) { + GetCastFunctionInput input(context); + return function_set.GetCastFunction(source, target, input); +} + +BindCastFunction::BindCastFunction(bind_cast_function_t function_p, unique_ptr info_p) + : function(function_p), info(std::move(info_p)) { +} + +CastFunctionSet::CastFunctionSet() : map_info(nullptr) { + bind_functions.emplace_back(DefaultCasts::GetDefaultCastFunction); +} + +CastFunctionSet &CastFunctionSet::Get(ClientContext &context) { + return DBConfig::GetConfig(context).GetCastFunctions(); +} + +CastFunctionSet &CastFunctionSet::Get(DatabaseInstance &db) { + return DBConfig::GetConfig(db).GetCastFunctions(); +} + +BoundCastInfo CastFunctionSet::GetCastFunction(const LogicalType &source, const LogicalType &target, + GetCastFunctionInput &get_input) { + if (source == target) { + return DefaultCasts::NopCast; + } + // the first function is the default + // we iterate the set of bind functions backwards + for (idx_t i = bind_functions.size(); i > 0; i--) { + auto &bind_function = bind_functions[i - 1]; + BindCastInput input(*this, bind_function.info.get(), get_input.context); + auto result = bind_function.function(input, source, target); + if (result.function) { + // found a cast function! return it + return result; + } + } + // no cast found: return the default null cast + return DefaultCasts::TryVectorNullCast; +} + +struct MapCastNode { + MapCastNode(BoundCastInfo info, int64_t implicit_cast_cost) + : cast_info(std::move(info)), bind_function(nullptr), implicit_cast_cost(implicit_cast_cost) { + } + MapCastNode(bind_cast_function_t func, int64_t implicit_cast_cost) + : cast_info(nullptr), bind_function(func), implicit_cast_cost(implicit_cast_cost) { + } + + BoundCastInfo cast_info; + bind_cast_function_t bind_function; + int64_t implicit_cast_cost; +}; + +template +static auto RelaxedTypeMatch(type_map_t &map, const LogicalType &type) -> decltype(map.find(type)) { + D_ASSERT(map.find(type) == map.end()); // we shouldn't be here + switch (type.id()) { + case LogicalTypeId::LIST: + return map.find(LogicalType::LIST(LogicalType::ANY)); + case LogicalTypeId::STRUCT: + return map.find(LogicalType::STRUCT({{"any", LogicalType::ANY}})); + case LogicalTypeId::MAP: + for (auto it = map.begin(); it != map.end(); it++) { + const auto &entry_type = it->first; + if (entry_type.id() != LogicalTypeId::MAP) { + continue; + } + auto &entry_key_type = MapType::KeyType(entry_type); + auto &entry_val_type = MapType::ValueType(entry_type); + if ((entry_key_type == LogicalType::ANY || entry_key_type == MapType::KeyType(type)) && + (entry_val_type == LogicalType::ANY || entry_val_type == MapType::ValueType(type))) { + return it; + } + } + return map.end(); + case LogicalTypeId::UNION: + return map.find(LogicalType::UNION({{"any", LogicalType::ANY}})); + default: + return map.find(LogicalType::ANY); + } +} + +struct MapCastInfo : public BindCastInfo { +public: + const optional_ptr GetEntry(const LogicalType &source, const LogicalType &target) { + auto source_type_id_entry = casts.find(source.id()); + if (source_type_id_entry == casts.end()) { + source_type_id_entry = casts.find(LogicalTypeId::ANY); + if (source_type_id_entry == casts.end()) { + return nullptr; + } + } + + auto &source_type_entries = source_type_id_entry->second; + auto source_type_entry = source_type_entries.find(source); + if (source_type_entry == source_type_entries.end()) { + source_type_entry = RelaxedTypeMatch(source_type_entries, source); + if (source_type_entry == source_type_entries.end()) { + return nullptr; + } + } + + auto &target_type_id_entries = source_type_entry->second; + auto target_type_id_entry = target_type_id_entries.find(target.id()); + if (target_type_id_entry == target_type_id_entries.end()) { + target_type_id_entry = target_type_id_entries.find(LogicalTypeId::ANY); + if (target_type_id_entry == target_type_id_entries.end()) { + return nullptr; + } + } + + auto &target_type_entries = target_type_id_entry->second; + auto target_type_entry = target_type_entries.find(target); + if (target_type_entry == target_type_entries.end()) { + target_type_entry = RelaxedTypeMatch(target_type_entries, target); + if (target_type_entry == target_type_entries.end()) { + return nullptr; + } + } + + return &target_type_entry->second; + } + + void AddEntry(const LogicalType &source, const LogicalType &target, MapCastNode node) { + casts[source.id()][source][target.id()].insert(make_pair(target, std::move(node))); + } + +private: + type_id_map_t>>> casts; +}; + +int64_t CastFunctionSet::ImplicitCastCost(const LogicalType &source, const LogicalType &target) { + // check if a cast has been registered + if (map_info) { + auto entry = map_info->GetEntry(source, target); + if (entry) { + return entry->implicit_cast_cost; + } + } + // if not, fallback to the default implicit cast rules + return CastRules::ImplicitCast(source, target); +} + +BoundCastInfo MapCastFunction(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + D_ASSERT(input.info); + auto &map_info = input.info->Cast(); + auto entry = map_info.GetEntry(source, target); + if (entry) { + if (entry->bind_function) { + return entry->bind_function(input, source, target); + } + return entry->cast_info.Copy(); + } + return nullptr; +} + +void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function, + int64_t implicit_cast_cost) { + RegisterCastFunction(source, target, MapCastNode(std::move(function), implicit_cast_cost)); +} + +void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, + bind_cast_function_t bind_function, int64_t implicit_cast_cost) { + RegisterCastFunction(source, target, MapCastNode(bind_function, implicit_cast_cost)); +} + +void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, MapCastNode node) { + if (!map_info) { + // create the cast map and the cast map function + auto info = make_uniq(); + map_info = info.get(); + bind_functions.emplace_back(MapCastFunction, std::move(info)); + } + map_info->AddEntry(source, target, std::move(node)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/decimal_cast.cpp b/src/duckdb/src/function/cast/decimal_cast.cpp new file mode 100644 index 00000000..e10a7799 --- /dev/null +++ b/src/duckdb/src/function/cast/decimal_cast.cpp @@ -0,0 +1,289 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +#include "duckdb/common/vector_operations/general_cast.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +template +static bool FromDecimalCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &source_type = source.GetType(); + auto width = DecimalType::GetWidth(source_type); + auto scale = DecimalType::GetScale(source_type); + switch (source_type.InternalType()) { + case PhysicalType::INT16: + return VectorCastHelpers::TemplatedDecimalCast( + source, result, count, parameters.error_message, width, scale); + case PhysicalType::INT32: + return VectorCastHelpers::TemplatedDecimalCast( + source, result, count, parameters.error_message, width, scale); + case PhysicalType::INT64: + return VectorCastHelpers::TemplatedDecimalCast( + source, result, count, parameters.error_message, width, scale); + case PhysicalType::INT128: + return VectorCastHelpers::TemplatedDecimalCast( + source, result, count, parameters.error_message, width, scale); + default: + throw InternalException("Unimplemented internal type for decimal"); + } +} + +template +struct DecimalScaleInput { + DecimalScaleInput(Vector &result_p, FACTOR_TYPE factor_p) : result(result_p), factor(factor_p) { + } + DecimalScaleInput(Vector &result_p, LIMIT_TYPE limit_p, FACTOR_TYPE factor_p, string *error_message_p, + uint8_t source_width_p, uint8_t source_scale_p) + : result(result_p), limit(limit_p), factor(factor_p), error_message(error_message_p), + source_width(source_width_p), source_scale(source_scale_p) { + } + + Vector &result; + LIMIT_TYPE limit; + FACTOR_TYPE factor; + bool all_converted = true; + string *error_message; + uint8_t source_width; + uint8_t source_scale; +}; + +struct DecimalScaleUpOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (DecimalScaleInput *)dataptr; + return Cast::Operation(input) * data->factor; + } +}; + +struct DecimalScaleUpCheckOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (DecimalScaleInput *)dataptr; + if (input >= data->limit || input <= -data->limit) { + auto error = StringUtil::Format("Casting value \"%s\" to type %s failed: value is out of range!", + Decimal::ToString(input, data->source_width, data->source_scale), + data->result.GetType().ToString()); + return HandleVectorCastError::Operation(std::move(error), mask, idx, data->error_message, + data->all_converted); + } + return Cast::Operation(input) * data->factor; + } +}; + +template +bool TemplatedDecimalScaleUp(Vector &source, Vector &result, idx_t count, string *error_message) { + auto source_scale = DecimalType::GetScale(source.GetType()); + auto source_width = DecimalType::GetWidth(source.GetType()); + auto result_scale = DecimalType::GetScale(result.GetType()); + auto result_width = DecimalType::GetWidth(result.GetType()); + D_ASSERT(result_scale >= source_scale); + idx_t scale_difference = result_scale - source_scale; + DEST multiply_factor = POWERS_DEST::POWERS_OF_TEN[scale_difference]; + idx_t target_width = result_width - scale_difference; + if (source_width < target_width) { + DecimalScaleInput input(result, multiply_factor); + // type will always fit: no need to check limit + UnaryExecutor::GenericExecute(source, result, count, &input); + return true; + } else { + // type might not fit: check limit + auto limit = POWERS_SOURCE::POWERS_OF_TEN[target_width]; + DecimalScaleInput input(result, limit, multiply_factor, error_message, source_width, + source_scale); + UnaryExecutor::GenericExecute(source, result, count, &input, + error_message); + return input.all_converted; + } +} + +struct DecimalScaleDownOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (DecimalScaleInput *)dataptr; + return Cast::Operation(input / data->factor); + } +}; + +struct DecimalScaleDownCheckOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (DecimalScaleInput *)dataptr; + if (input >= data->limit || input <= -data->limit) { + auto error = StringUtil::Format("Casting value \"%s\" to type %s failed: value is out of range!", + Decimal::ToString(input, data->source_width, data->source_scale), + data->result.GetType().ToString()); + return HandleVectorCastError::Operation(std::move(error), mask, idx, data->error_message, + data->all_converted); + } + return Cast::Operation(input / data->factor); + } +}; + +template +bool TemplatedDecimalScaleDown(Vector &source, Vector &result, idx_t count, string *error_message) { + auto source_scale = DecimalType::GetScale(source.GetType()); + auto source_width = DecimalType::GetWidth(source.GetType()); + auto result_scale = DecimalType::GetScale(result.GetType()); + auto result_width = DecimalType::GetWidth(result.GetType()); + D_ASSERT(result_scale < source_scale); + idx_t scale_difference = source_scale - result_scale; + idx_t target_width = result_width + scale_difference; + SOURCE divide_factor = POWERS_SOURCE::POWERS_OF_TEN[scale_difference]; + if (source_width < target_width) { + DecimalScaleInput input(result, divide_factor); + // type will always fit: no need to check limit + UnaryExecutor::GenericExecute(source, result, count, &input); + return true; + } else { + // type might not fit: check limit + auto limit = POWERS_SOURCE::POWERS_OF_TEN[target_width]; + DecimalScaleInput input(result, limit, divide_factor, error_message, source_width, source_scale); + UnaryExecutor::GenericExecute(source, result, count, &input, + error_message); + return input.all_converted; + } +} + +template +static bool DecimalDecimalCastSwitch(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto source_scale = DecimalType::GetScale(source.GetType()); + auto result_scale = DecimalType::GetScale(result.GetType()); + source.GetType().Verify(); + result.GetType().Verify(); + + // we need to either multiply or divide by the difference in scales + if (result_scale >= source_scale) { + // multiply + switch (result.GetType().InternalType()) { + case PhysicalType::INT16: + return TemplatedDecimalScaleUp(source, result, count, + parameters.error_message); + case PhysicalType::INT32: + return TemplatedDecimalScaleUp(source, result, count, + parameters.error_message); + case PhysicalType::INT64: + return TemplatedDecimalScaleUp(source, result, count, + parameters.error_message); + case PhysicalType::INT128: + return TemplatedDecimalScaleUp(source, result, count, + parameters.error_message); + default: + throw NotImplementedException("Unimplemented internal type for decimal"); + } + } else { + // divide + switch (result.GetType().InternalType()) { + case PhysicalType::INT16: + return TemplatedDecimalScaleDown(source, result, count, + parameters.error_message); + case PhysicalType::INT32: + return TemplatedDecimalScaleDown(source, result, count, + parameters.error_message); + case PhysicalType::INT64: + return TemplatedDecimalScaleDown(source, result, count, + parameters.error_message); + case PhysicalType::INT128: + return TemplatedDecimalScaleDown(source, result, count, + parameters.error_message); + default: + throw NotImplementedException("Unimplemented internal type for decimal"); + } + } +} + +struct DecimalCastInput { + DecimalCastInput(Vector &result_p, uint8_t width_p, uint8_t scale_p) + : result(result_p), width(width_p), scale(scale_p) { + } + + Vector &result; + uint8_t width; + uint8_t scale; +}; + +struct StringCastFromDecimalOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = reinterpret_cast(dataptr); + return StringCastFromDecimal::Operation(input, data->width, data->scale, data->result); + } +}; + +template +static bool DecimalToStringCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &source_type = source.GetType(); + auto width = DecimalType::GetWidth(source_type); + auto scale = DecimalType::GetScale(source_type); + DecimalCastInput input(result, width, scale); + + UnaryExecutor::GenericExecute(source, result, count, (void *)&input); + return true; +} + +BoundCastInfo DefaultCasts::DecimalCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::BOOLEAN: + return FromDecimalCast; + case LogicalTypeId::TINYINT: + return FromDecimalCast; + case LogicalTypeId::SMALLINT: + return FromDecimalCast; + case LogicalTypeId::INTEGER: + return FromDecimalCast; + case LogicalTypeId::BIGINT: + return FromDecimalCast; + case LogicalTypeId::UTINYINT: + return FromDecimalCast; + case LogicalTypeId::USMALLINT: + return FromDecimalCast; + case LogicalTypeId::UINTEGER: + return FromDecimalCast; + case LogicalTypeId::UBIGINT: + return FromDecimalCast; + case LogicalTypeId::HUGEINT: + return FromDecimalCast; + case LogicalTypeId::DECIMAL: { + // decimal to decimal cast + // first we need to figure out the source and target internal types + switch (source.InternalType()) { + case PhysicalType::INT16: + return DecimalDecimalCastSwitch; + case PhysicalType::INT32: + return DecimalDecimalCastSwitch; + case PhysicalType::INT64: + return DecimalDecimalCastSwitch; + case PhysicalType::INT128: + return DecimalDecimalCastSwitch; + default: + throw NotImplementedException("Unimplemented internal type for decimal in decimal_decimal cast"); + } + } + case LogicalTypeId::FLOAT: + return FromDecimalCast; + case LogicalTypeId::DOUBLE: + return FromDecimalCast; + case LogicalTypeId::VARCHAR: { + switch (source.InternalType()) { + case PhysicalType::INT16: + return DecimalToStringCast; + case PhysicalType::INT32: + return DecimalToStringCast; + case PhysicalType::INT64: + return DecimalToStringCast; + case PhysicalType::INT128: + return DecimalToStringCast; + default: + throw InternalException("Unimplemented internal decimal type"); + } + } + default: + return DefaultCasts::TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp new file mode 100644 index 00000000..1b6bff01 --- /dev/null +++ b/src/duckdb/src/function/cast/default_casts.cpp @@ -0,0 +1,147 @@ +#include "duckdb/function/cast/default_casts.hpp" + +#include "duckdb/common/likely.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +BindCastInfo::~BindCastInfo() { +} + +BoundCastData::~BoundCastData() { +} + +BoundCastInfo::BoundCastInfo(cast_function_t function_p, unique_ptr cast_data_p, + init_cast_local_state_t init_local_state_p) + : function(function_p), init_local_state(init_local_state_p), cast_data(std::move(cast_data_p)) { +} + +BoundCastInfo BoundCastInfo::Copy() const { + return BoundCastInfo(function, cast_data ? cast_data->Copy() : nullptr, init_local_state); +} + +bool DefaultCasts::NopCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + result.Reference(source); + return true; +} + +static string UnimplementedCastMessage(const LogicalType &source_type, const LogicalType &target_type) { + return StringUtil::Format("Unimplemented type for cast (%s -> %s)", source_type.ToString(), target_type.ToString()); +} + +// NULL cast only works if all values in source are NULL, otherwise an unimplemented cast exception is thrown +bool DefaultCasts::TryVectorNullCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + bool success = true; + if (VectorOperations::HasNotNull(source, count)) { + HandleCastError::AssignError(UnimplementedCastMessage(source.GetType(), result.GetType()), + parameters.error_message); + success = false; + } + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return success; +} + +bool DefaultCasts::ReinterpretCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + result.Reinterpret(source); + return true; +} + +static bool AggregateStateToBlobCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + if (result.GetType().id() != LogicalTypeId::BLOB) { + throw TypeMismatchException(source.GetType(), result.GetType(), + "Cannot cast AGGREGATE_STATE to anything but BLOB"); + } + result.Reinterpret(source); + return true; +} + +static bool NullTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + // cast a NULL to another type, just copy the properties and change the type + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return true; +} + +BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + D_ASSERT(source != target); + + // first check if were casting to a union + if (source.id() != LogicalTypeId::UNION && source.id() != LogicalTypeId::SQLNULL && + target.id() == LogicalTypeId::UNION) { + return ImplicitToUnionCast(input, source, target); + } + + // else, switch on source type + switch (source.id()) { + case LogicalTypeId::BOOLEAN: + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + return NumericCastSwitch(input, source, target); + case LogicalTypeId::POINTER: + return PointerCastSwitch(input, source, target); + case LogicalTypeId::UUID: + return UUIDCastSwitch(input, source, target); + case LogicalTypeId::DECIMAL: + return DecimalCastSwitch(input, source, target); + case LogicalTypeId::DATE: + return DateCastSwitch(input, source, target); + case LogicalTypeId::TIME: + return TimeCastSwitch(input, source, target); + case LogicalTypeId::TIME_TZ: + return TimeTzCastSwitch(input, source, target); + case LogicalTypeId::TIMESTAMP: + return TimestampCastSwitch(input, source, target); + case LogicalTypeId::TIMESTAMP_TZ: + return TimestampTzCastSwitch(input, source, target); + case LogicalTypeId::TIMESTAMP_NS: + return TimestampNsCastSwitch(input, source, target); + case LogicalTypeId::TIMESTAMP_MS: + return TimestampMsCastSwitch(input, source, target); + case LogicalTypeId::TIMESTAMP_SEC: + return TimestampSecCastSwitch(input, source, target); + case LogicalTypeId::INTERVAL: + return IntervalCastSwitch(input, source, target); + case LogicalTypeId::VARCHAR: + return StringCastSwitch(input, source, target); + case LogicalTypeId::BLOB: + return BlobCastSwitch(input, source, target); + case LogicalTypeId::BIT: + return BitCastSwitch(input, source, target); + case LogicalTypeId::SQLNULL: + return NullTypeCast; + case LogicalTypeId::MAP: + return MapCastSwitch(input, source, target); + case LogicalTypeId::STRUCT: + return StructCastSwitch(input, source, target); + case LogicalTypeId::LIST: + return ListCastSwitch(input, source, target); + case LogicalTypeId::UNION: + return UnionCastSwitch(input, source, target); + case LogicalTypeId::ENUM: + return EnumCastSwitch(input, source, target); + case LogicalTypeId::AGGREGATE_STATE: + return AggregateStateToBlobCast; + default: + return nullptr; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/enum_casts.cpp b/src/duckdb/src/function/cast/enum_casts.cpp new file mode 100644 index 00000000..062e83d7 --- /dev/null +++ b/src/duckdb/src/function/cast/enum_casts.cpp @@ -0,0 +1,182 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" + +namespace duckdb { + +template +bool EnumEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + bool all_converted = true; + result.SetVectorType(VectorType::FLAT_VECTOR); + + auto &str_vec = EnumType::GetValuesInsertOrder(source.GetType()); + auto str_vec_ptr = FlatVector::GetData(str_vec); + + auto res_enum_type = result.GetType(); + + UnifiedVectorFormat vdata; + source.ToUnifiedFormat(count, vdata); + + auto source_data = UnifiedVectorFormat::GetData(vdata); + auto source_sel = vdata.sel; + auto source_mask = vdata.validity; + + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + + for (idx_t i = 0; i < count; i++) { + auto src_idx = source_sel->get_index(i); + if (!source_mask.RowIsValid(src_idx)) { + result_mask.SetInvalid(i); + continue; + } + auto key = EnumType::GetPos(res_enum_type, str_vec_ptr[source_data[src_idx]]); + if (key == -1) { + // key doesn't exist on result enum + if (!parameters.error_message) { + result_data[i] = HandleVectorCastError::Operation( + CastExceptionText(source_data[src_idx]), result_mask, i, + parameters.error_message, all_converted); + } else { + result_mask.SetInvalid(i); + } + continue; + } + result_data[i] = key; + } + return all_converted; +} + +template +BoundCastInfo EnumEnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + switch (target.InternalType()) { + case PhysicalType::UINT8: + return EnumEnumCast; + case PhysicalType::UINT16: + return EnumEnumCast; + case PhysicalType::UINT32: + return EnumEnumCast; + default: + throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); + } +} + +template +static bool EnumToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &enum_dictionary = EnumType::GetValuesInsertOrder(source.GetType()); + auto dictionary_data = FlatVector::GetData(enum_dictionary); + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + + UnifiedVectorFormat vdata; + source.ToUnifiedFormat(count, vdata); + + auto source_data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto source_idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(source_idx)) { + result_mask.SetInvalid(i); + continue; + } + auto enum_idx = source_data[source_idx]; + result_data[i] = dictionary_data[enum_idx]; + } + if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } else { + result.SetVectorType(VectorType::FLAT_VECTOR); + } + return true; +} + +struct EnumBoundCastData : public BoundCastData { + EnumBoundCastData(BoundCastInfo to_varchar_cast, BoundCastInfo from_varchar_cast) + : to_varchar_cast(std::move(to_varchar_cast)), from_varchar_cast(std::move(from_varchar_cast)) { + } + + BoundCastInfo to_varchar_cast; + BoundCastInfo from_varchar_cast; + +public: + unique_ptr Copy() const override { + return make_uniq(to_varchar_cast.Copy(), from_varchar_cast.Copy()); + } +}; + +unique_ptr BindEnumCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + auto to_varchar_cast = input.GetCastFunction(source, LogicalType::VARCHAR); + auto from_varchar_cast = input.GetCastFunction(LogicalType::VARCHAR, target); + return make_uniq(std::move(to_varchar_cast), std::move(from_varchar_cast)); +} + +struct EnumCastLocalState : public FunctionLocalState { +public: + unique_ptr to_varchar_local; + unique_ptr from_varchar_local; +}; + +static unique_ptr InitEnumCastLocalState(CastLocalStateParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto result = make_uniq(); + + if (cast_data.from_varchar_cast.init_local_state) { + CastLocalStateParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data); + result->from_varchar_local = cast_data.from_varchar_cast.init_local_state(from_varchar_params); + } + if (cast_data.to_varchar_cast.init_local_state) { + CastLocalStateParameters from_varchar_params(parameters, cast_data.to_varchar_cast.cast_data); + result->from_varchar_local = cast_data.to_varchar_cast.init_local_state(from_varchar_params); + } + return std::move(result); +} + +static bool EnumToAnyCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto &lstate = parameters.local_state->Cast(); + + Vector varchar_cast(LogicalType::VARCHAR, count); + + // cast to varchar + CastParameters to_varchar_params(parameters, cast_data.to_varchar_cast.cast_data, lstate.to_varchar_local); + cast_data.to_varchar_cast.function(source, varchar_cast, count, to_varchar_params); + + // cast from varchar to the target + CastParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data, lstate.from_varchar_local); + cast_data.from_varchar_cast.function(varchar_cast, result, count, from_varchar_params); + return true; +} + +BoundCastInfo DefaultCasts::EnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + auto enum_physical_type = source.InternalType(); + switch (target.id()) { + case LogicalTypeId::ENUM: { + // This means they are both ENUMs, but of different types. + switch (enum_physical_type) { + case PhysicalType::UINT8: + return EnumEnumCastSwitch(input, source, target); + case PhysicalType::UINT16: + return EnumEnumCastSwitch(input, source, target); + case PhysicalType::UINT32: + return EnumEnumCastSwitch(input, source, target); + default: + throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); + } + } + case LogicalTypeId::VARCHAR: + switch (enum_physical_type) { + case PhysicalType::UINT8: + return EnumToVarcharCast; + case PhysicalType::UINT16: + return EnumToVarcharCast; + case PhysicalType::UINT32: + return EnumToVarcharCast; + default: + throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); + } + default: { + return BoundCastInfo(EnumToAnyCast, BindEnumCast(input, source, target), InitEnumCastLocalState); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/list_casts.cpp b/src/duckdb/src/function/cast/list_casts.cpp new file mode 100644 index 00000000..a1c20b4c --- /dev/null +++ b/src/duckdb/src/function/cast/list_casts.cpp @@ -0,0 +1,139 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/cast/bound_cast_data.hpp" + +namespace duckdb { + +unique_ptr ListBoundCastData::BindListToListCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + vector child_cast_info; + auto &source_child_type = ListType::GetChildType(source); + auto &result_child_type = ListType::GetChildType(target); + auto child_cast = input.GetCastFunction(source_child_type, result_child_type); + return make_uniq(std::move(child_cast)); +} + +unique_ptr ListBoundCastData::InitListLocalState(CastLocalStateParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + if (!cast_data.child_cast_info.init_local_state) { + return nullptr; + } + CastLocalStateParameters child_parameters(parameters, cast_data.child_cast_info.cast_data); + return cast_data.child_cast_info.init_local_state(child_parameters); +} + +bool ListCast::ListToListCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + + // only handle constant and flat vectors here for now + if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(source.GetVectorType()); + ConstantVector::SetNull(result, ConstantVector::IsNull(source)); + + auto ldata = ConstantVector::GetData(source); + auto tdata = ConstantVector::GetData(result); + *tdata = *ldata; + } else { + source.Flatten(count); + result.SetVectorType(VectorType::FLAT_VECTOR); + FlatVector::SetValidity(result, FlatVector::Validity(source)); + + auto ldata = FlatVector::GetData(source); + auto tdata = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + tdata[i] = ldata[i]; + } + } + auto &source_cc = ListVector::GetEntry(source); + auto source_size = ListVector::GetListSize(source); + + ListVector::Reserve(result, source_size); + auto &append_vector = ListVector::GetEntry(result); + + CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); + bool all_succeeded = cast_data.child_cast_info.function(source_cc, append_vector, source_size, child_parameters); + ListVector::SetListSize(result, source_size); + D_ASSERT(ListVector::GetListSize(result) == source_size); + return all_succeeded; +} + +static bool ListToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; + // first cast the child vector to varchar + Vector varchar_list(LogicalType::LIST(LogicalType::VARCHAR), count); + ListCast::ListToListCast(source, varchar_list, count, parameters); + + // now construct the actual varchar vector + varchar_list.Flatten(count); + auto &child = ListVector::GetEntry(varchar_list); + auto list_data = FlatVector::GetData(varchar_list); + auto &validity = FlatVector::Validity(varchar_list); + + child.Flatten(count); + auto child_data = FlatVector::GetData(child); + auto &child_validity = FlatVector::Validity(child); + + auto result_data = FlatVector::GetData(result); + static constexpr const idx_t SEP_LENGTH = 2; + static constexpr const idx_t NULL_LENGTH = 4; + for (idx_t i = 0; i < count; i++) { + if (!validity.RowIsValid(i)) { + FlatVector::SetNull(result, i, true); + continue; + } + auto list = list_data[i]; + // figure out how long the result needs to be + idx_t list_length = 2; // "[" and "]" + for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { + auto idx = list.offset + list_idx; + if (list_idx > 0) { + list_length += SEP_LENGTH; // ", " + } + // string length, or "NULL" + list_length += child_validity.RowIsValid(idx) ? child_data[idx].GetSize() : NULL_LENGTH; + } + result_data[i] = StringVector::EmptyString(result, list_length); + auto dataptr = result_data[i].GetDataWriteable(); + auto offset = 0; + dataptr[offset++] = '['; + for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { + auto idx = list.offset + list_idx; + if (list_idx > 0) { + memcpy(dataptr + offset, ", ", SEP_LENGTH); + offset += SEP_LENGTH; + } + if (child_validity.RowIsValid(idx)) { + auto len = child_data[idx].GetSize(); + memcpy(dataptr + offset, child_data[idx].GetData(), len); + offset += len; + } else { + memcpy(dataptr + offset, "NULL", NULL_LENGTH); + offset += NULL_LENGTH; + } + } + dataptr[offset] = ']'; + result_data[i].Finalize(); + } + + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + return true; +} + +BoundCastInfo DefaultCasts::ListCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + switch (target.id()) { + case LogicalTypeId::LIST: + return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), + ListBoundCastData::InitListLocalState); + case LogicalTypeId::VARCHAR: + return BoundCastInfo( + ListToVarcharCast, + ListBoundCastData::BindListToListCast(input, source, LogicalType::LIST(LogicalType::VARCHAR)), + ListBoundCastData::InitListLocalState); + default: + return DefaultCasts::TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/map_cast.cpp b/src/duckdb/src/function/cast/map_cast.cpp new file mode 100644 index 00000000..20c14efa --- /dev/null +++ b/src/duckdb/src/function/cast/map_cast.cpp @@ -0,0 +1,94 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/cast/bound_cast_data.hpp" + +namespace duckdb { + +unique_ptr MapBoundCastData::BindMapToMapCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + vector child_cast_info; + auto source_key = MapType::KeyType(source); + auto target_key = MapType::KeyType(target); + auto source_val = MapType::ValueType(source); + auto target_val = MapType::ValueType(target); + auto key_cast = input.GetCastFunction(source_key, target_key); + auto value_cast = input.GetCastFunction(source_val, target_val); + return make_uniq(std::move(key_cast), std::move(value_cast)); +} + +static bool MapToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; + auto varchar_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); + Vector varchar_map(varchar_type, count); + + // since map's physical type is a list, the ListCast can be utilized + ListCast::ListToListCast(source, varchar_map, count, parameters); + + varchar_map.Flatten(count); + auto &validity = FlatVector::Validity(varchar_map); + auto &key_str = MapVector::GetKeys(varchar_map); + auto &val_str = MapVector::GetValues(varchar_map); + + key_str.Flatten(ListVector::GetListSize(source)); + val_str.Flatten(ListVector::GetListSize(source)); + + auto list_data = ListVector::GetData(varchar_map); + auto key_data = FlatVector::GetData(key_str); + auto val_data = FlatVector::GetData(val_str); + auto &key_validity = FlatVector::Validity(key_str); + auto &val_validity = FlatVector::Validity(val_str); + auto &struct_validity = FlatVector::Validity(ListVector::GetEntry(varchar_map)); + + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + if (!validity.RowIsValid(i)) { + FlatVector::SetNull(result, i, true); + continue; + } + auto list = list_data[i]; + string ret = "{"; + for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { + if (list_idx > 0) { + ret += ", "; + } + auto idx = list.offset + list_idx; + + if (!struct_validity.RowIsValid(idx)) { + ret += "NULL"; + continue; + } + if (!key_validity.RowIsValid(idx)) { + // throw InternalException("Error in map: key validity invalid?!"); + ret += "invalid"; + continue; + } + ret += key_data[idx].GetString(); + ret += "="; + ret += val_validity.RowIsValid(idx) ? val_data[idx].GetString() : "NULL"; + } + ret += "}"; + result_data[i] = StringVector::AddString(result, ret); + } + + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + return true; +} + +BoundCastInfo DefaultCasts::MapCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + switch (target.id()) { + case LogicalTypeId::MAP: + return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), + ListBoundCastData::InitListLocalState); + case LogicalTypeId::VARCHAR: { + auto varchar_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); + return BoundCastInfo(MapToVarcharCast, ListBoundCastData::BindListToListCast(input, source, varchar_type), + ListBoundCastData::InitListLocalState); + } + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/numeric_casts.cpp b/src/duckdb/src/function/cast/numeric_casts.cpp new file mode 100644 index 00000000..071b0604 --- /dev/null +++ b/src/duckdb/src/function/cast/numeric_casts.cpp @@ -0,0 +1,79 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" +#include "duckdb/common/operator/string_cast.hpp" +#include "duckdb/common/operator/numeric_cast.hpp" + +namespace duckdb { + +template +static BoundCastInfo InternalNumericCastSwitch(const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::BOOLEAN: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::TINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::SMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::INTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::BIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::UTINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::USMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::UINTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::UBIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::HUGEINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::FLOAT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::DOUBLE: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::DECIMAL: + return BoundCastInfo(&VectorCastHelpers::ToDecimalCast); + case LogicalTypeId::VARCHAR: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::BIT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + default: + return DefaultCasts::TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::NumericCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + switch (source.id()) { + case LogicalTypeId::BOOLEAN: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::TINYINT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::SMALLINT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::INTEGER: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::BIGINT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::UTINYINT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::USMALLINT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::UINTEGER: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::UBIGINT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::HUGEINT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::FLOAT: + return InternalNumericCastSwitch(source, target); + case LogicalTypeId::DOUBLE: + return InternalNumericCastSwitch(source, target); + default: + throw InternalException("NumericCastSwitch called with non-numeric argument"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/pointer_cast.cpp b/src/duckdb/src/function/cast/pointer_cast.cpp new file mode 100644 index 00000000..33eead29 --- /dev/null +++ b/src/duckdb/src/function/cast/pointer_cast.cpp @@ -0,0 +1,18 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +BoundCastInfo DefaultCasts::PointerCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // pointer to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + default: + return nullptr; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp new file mode 100644 index 00000000..5ef31d90 --- /dev/null +++ b/src/duckdb/src/function/cast/string_cast.cpp @@ -0,0 +1,415 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/cast/bound_cast_data.hpp" + +namespace duckdb { + +template +bool StringEnumCastLoop(const string_t *source_data, ValidityMask &source_mask, const LogicalType &source_type, + T *result_data, ValidityMask &result_mask, const LogicalType &result_type, idx_t count, + string *error_message, const SelectionVector *sel) { + bool all_converted = true; + for (idx_t i = 0; i < count; i++) { + idx_t source_idx = i; + if (sel) { + source_idx = sel->get_index(i); + } + if (source_mask.RowIsValid(source_idx)) { + auto pos = EnumType::GetPos(result_type, source_data[source_idx]); + if (pos == -1) { + result_data[i] = + HandleVectorCastError::Operation(CastExceptionText(source_data[source_idx]), + result_mask, i, error_message, all_converted); + } else { + result_data[i] = pos; + } + } else { + result_mask.SetInvalid(i); + } + } + return all_converted; +} + +template +bool StringEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); + switch (source.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + auto source_data = ConstantVector::GetData(source); + auto source_mask = ConstantVector::Validity(source); + auto result_data = ConstantVector::GetData(result); + auto &result_mask = ConstantVector::Validity(result); + + return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, + result.GetType(), 1, parameters.error_message, nullptr); + } + default: { + UnifiedVectorFormat vdata; + source.ToUnifiedFormat(count, vdata); + + result.SetVectorType(VectorType::FLAT_VECTOR); + + auto source_data = UnifiedVectorFormat::GetData(vdata); + auto source_sel = vdata.sel; + auto source_mask = vdata.validity; + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + + return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, + result.GetType(), count, parameters.error_message, source_sel); + } + } +} + +static BoundCastInfo VectorStringCastNumericSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::ENUM: { + switch (target.InternalType()) { + case PhysicalType::UINT8: + return StringEnumCast; + case PhysicalType::UINT16: + return StringEnumCast; + case PhysicalType::UINT32: + return StringEnumCast; + default: + throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); + } + } + case LogicalTypeId::BOOLEAN: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::TINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::SMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::INTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::BIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::UTINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::USMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::UINTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::UBIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::HUGEINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::FLOAT: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::DOUBLE: + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::INTERVAL: + return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); + case LogicalTypeId::DECIMAL: + return BoundCastInfo(&VectorCastHelpers::ToDecimalCast); + default: + return DefaultCasts::TryVectorNullCast; + } +} + +//===--------------------------------------------------------------------===// +// string -> list casting +//===--------------------------------------------------------------------===// +bool VectorStringToList::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, + Vector &result, ValidityMask &result_mask, idx_t count, + CastParameters ¶meters, const SelectionVector *sel) { + idx_t total_list_size = 0; + for (idx_t i = 0; i < count; i++) { + idx_t idx = i; + if (sel) { + idx = sel->get_index(i); + } + if (!source_mask.RowIsValid(idx)) { + continue; + } + total_list_size += VectorStringToList::CountPartsList(source_data[idx]); + } + + Vector varchar_vector(LogicalType::VARCHAR, total_list_size); + + ListVector::Reserve(result, total_list_size); + ListVector::SetListSize(result, total_list_size); + + auto list_data = ListVector::GetData(result); + auto child_data = FlatVector::GetData(varchar_vector); + + bool all_converted = true; + idx_t total = 0; + for (idx_t i = 0; i < count; i++) { + idx_t idx = i; + if (sel) { + idx = sel->get_index(i); + } + if (!source_mask.RowIsValid(idx)) { + result_mask.SetInvalid(i); + continue; + } + + list_data[i].offset = total; + if (!VectorStringToList::SplitStringList(source_data[idx], child_data, total, varchar_vector)) { + string text = "Type VARCHAR with value '" + source_data[idx].GetString() + + "' can't be cast to the destination type LIST"; + HandleVectorCastError::Operation(text, result_mask, idx, parameters.error_message, all_converted); + } + list_data[i].length = total - list_data[i].offset; // length is the amount of parts coming from this string + } + D_ASSERT(total_list_size == total); + + auto &result_child = ListVector::GetEntry(result); + auto &cast_data = parameters.cast_data->Cast(); + CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); + return cast_data.child_cast_info.function(varchar_vector, result_child, total_list_size, child_parameters) && + all_converted; +} + +static LogicalType InitVarcharStructType(const LogicalType &target) { + child_list_t child_types; + for (auto &child : StructType::GetChildTypes(target)) { + child_types.push_back(make_pair(child.first, LogicalType::VARCHAR)); + } + + return LogicalType::STRUCT(child_types); +} + +//===--------------------------------------------------------------------===// +// string -> struct casting +//===--------------------------------------------------------------------===// +bool VectorStringToStruct::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, + Vector &result, ValidityMask &result_mask, idx_t count, + CastParameters ¶meters, const SelectionVector *sel) { + auto varchar_struct_type = InitVarcharStructType(result.GetType()); + Vector varchar_vector(varchar_struct_type, count); + auto &child_vectors = StructVector::GetEntries(varchar_vector); + auto &result_children = StructVector::GetEntries(result); + + string_map_t child_names; + vector child_masks; + for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { + child_names.insert({StructType::GetChildName(result.GetType(), child_idx), child_idx}); + child_masks.emplace_back(&FlatVector::Validity(*child_vectors[child_idx])); + child_masks[child_idx]->SetAllInvalid(count); + } + + bool all_converted = true; + for (idx_t i = 0; i < count; i++) { + idx_t idx = i; + if (sel) { + idx = sel->get_index(i); + } + if (!source_mask.RowIsValid(idx)) { + result_mask.SetInvalid(i); + continue; + } + if (!VectorStringToStruct::SplitStruct(source_data[idx], child_vectors, i, child_names, child_masks)) { + string text = "Type VARCHAR with value '" + source_data[idx].GetString() + + "' can't be cast to the destination type STRUCT"; + for (auto &child_mask : child_masks) { + child_mask->SetInvalid(idx); // some values may have already been found and set valid + } + HandleVectorCastError::Operation(text, result_mask, idx, parameters.error_message, all_converted); + } + } + + auto &cast_data = parameters.cast_data->Cast(); + auto &lstate = parameters.local_state->Cast(); + D_ASSERT(cast_data.child_cast_info.size() == result_children.size()); + + for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { + auto &child_varchar_vector = *child_vectors[child_idx]; + auto &result_child_vector = *result_children[child_idx]; + auto &child_cast_info = cast_data.child_cast_info[child_idx]; + CastParameters child_parameters(parameters, child_cast_info.cast_data, lstate.local_states[child_idx]); + if (!child_cast_info.function(child_varchar_vector, result_child_vector, count, child_parameters)) { + all_converted = false; + } + } + return all_converted; +} + +//===--------------------------------------------------------------------===// +// string -> map casting +//===--------------------------------------------------------------------===// +unique_ptr InitMapCastLocalState(CastLocalStateParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto result = make_uniq(); + + if (cast_data.key_cast.init_local_state) { + CastLocalStateParameters child_params(parameters, cast_data.key_cast.cast_data); + result->key_state = cast_data.key_cast.init_local_state(child_params); + } + if (cast_data.value_cast.init_local_state) { + CastLocalStateParameters child_params(parameters, cast_data.value_cast.cast_data); + result->value_state = cast_data.value_cast.init_local_state(child_params); + } + return std::move(result); +} + +bool VectorStringToMap::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, + Vector &result, ValidityMask &result_mask, idx_t count, + CastParameters ¶meters, const SelectionVector *sel) { + idx_t total_elements = 0; + for (idx_t i = 0; i < count; i++) { + idx_t idx = i; + if (sel) { + idx = sel->get_index(i); + } + if (!source_mask.RowIsValid(idx)) { + continue; + } + total_elements += (VectorStringToMap::CountPartsMap(source_data[idx]) + 1) / 2; + } + + Vector varchar_key_vector(LogicalType::VARCHAR, total_elements); + Vector varchar_val_vector(LogicalType::VARCHAR, total_elements); + auto child_key_data = FlatVector::GetData(varchar_key_vector); + auto child_val_data = FlatVector::GetData(varchar_val_vector); + + ListVector::Reserve(result, total_elements); + ListVector::SetListSize(result, total_elements); + auto list_data = ListVector::GetData(result); + + bool all_converted = true; + idx_t total = 0; + for (idx_t i = 0; i < count; i++) { + idx_t idx = i; + if (sel) { + idx = sel->get_index(i); + } + if (!source_mask.RowIsValid(idx)) { + result_mask.SetInvalid(idx); + continue; + } + + list_data[i].offset = total; + if (!VectorStringToMap::SplitStringMap(source_data[idx], child_key_data, child_val_data, total, + varchar_key_vector, varchar_val_vector)) { + string text = "Type VARCHAR with value '" + source_data[idx].GetString() + + "' can't be cast to the destination type MAP"; + FlatVector::SetNull(result, idx, true); + HandleVectorCastError::Operation(text, result_mask, idx, parameters.error_message, all_converted); + } + list_data[i].length = total - list_data[i].offset; + } + D_ASSERT(total_elements == total); + + auto &result_key_child = MapVector::GetKeys(result); + auto &result_val_child = MapVector::GetValues(result); + auto &cast_data = parameters.cast_data->Cast(); + auto &lstate = parameters.local_state->Cast(); + + CastParameters key_params(parameters, cast_data.key_cast.cast_data, lstate.key_state); + if (!cast_data.key_cast.function(varchar_key_vector, result_key_child, total_elements, key_params)) { + all_converted = false; + } + CastParameters val_params(parameters, cast_data.value_cast.cast_data, lstate.value_state); + if (!cast_data.value_cast.function(varchar_val_vector, result_val_child, total_elements, val_params)) { + all_converted = false; + } + + auto &key_validity = FlatVector::Validity(result_key_child); + if (!all_converted) { + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + if (!result_mask.RowIsValid(row_idx)) { + continue; + } + auto list = list_data[row_idx]; + for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { + auto idx = list.offset + list_idx; + if (!key_validity.RowIsValid(idx)) { + result_mask.SetInvalid(row_idx); + } + } + } + } + MapVector::MapConversionVerify(result, count); + return all_converted; +} + +template +bool StringToNestedTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); + + switch (source.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + auto source_data = ConstantVector::GetData(source); + auto &source_mask = ConstantVector::Validity(source); + auto &result_mask = FlatVector::Validity(result); + auto ret = T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, 1, parameters, nullptr); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + return ret; + } + default: { + UnifiedVectorFormat unified_source; + + source.ToUnifiedFormat(count, unified_source); + auto source_sel = unified_source.sel; + auto source_data = UnifiedVectorFormat::GetData(unified_source); + auto &source_mask = unified_source.validity; + auto &result_mask = FlatVector::Validity(result); + + return T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, count, parameters, + source_sel); + } + } +} + +BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + switch (target.id()) { + case LogicalTypeId::DATE: + return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); + case LogicalTypeId::TIME: + return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); + case LogicalTypeId::TIME_TZ: + return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); + case LogicalTypeId::TIMESTAMP_NS: + return BoundCastInfo( + &VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::TIMESTAMP_SEC: + return BoundCastInfo( + &VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::TIMESTAMP_MS: + return BoundCastInfo( + &VectorCastHelpers::TryCastStrictLoop); + case LogicalTypeId::BLOB: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::BIT: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::UUID: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::SQLNULL: + return &DefaultCasts::TryVectorNullCast; + case LogicalTypeId::VARCHAR: + return &DefaultCasts::ReinterpretCast; + case LogicalTypeId::LIST: + // the second argument allows for a secondary casting function to be passed in the CastParameters + return BoundCastInfo( + &StringToNestedTypeCast, + ListBoundCastData::BindListToListCast(input, LogicalType::LIST(LogicalType::VARCHAR), target), + ListBoundCastData::InitListLocalState); + case LogicalTypeId::STRUCT: + return BoundCastInfo(&StringToNestedTypeCast, + StructBoundCastData::BindStructToStructCast(input, InitVarcharStructType(target), target), + StructBoundCastData::InitStructCastLocalState); + case LogicalTypeId::MAP: + return BoundCastInfo(&StringToNestedTypeCast, + MapBoundCastData::BindMapToMapCast( + input, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR), target), + InitMapCastLocalState); + default: + return VectorStringCastNumericSwitch(input, source, target); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/struct_cast.cpp b/src/duckdb/src/function/cast/struct_cast.cpp new file mode 100644 index 00000000..b40949ac --- /dev/null +++ b/src/duckdb/src/function/cast/struct_cast.cpp @@ -0,0 +1,169 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/cast/bound_cast_data.hpp" + +namespace duckdb { + +unique_ptr StructBoundCastData::BindStructToStructCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + vector child_cast_info; + auto &source_child_types = StructType::GetChildTypes(source); + auto &result_child_types = StructType::GetChildTypes(target); + + auto target_is_unnamed = StructType::IsUnnamed(target); + auto source_is_unnamed = StructType::IsUnnamed(source); + + if (source_child_types.size() != result_child_types.size()) { + throw TypeMismatchException(source, target, "Cannot cast STRUCTs of different size"); + } + for (idx_t i = 0; i < source_child_types.size(); i++) { + if (!target_is_unnamed && !source_is_unnamed && + !StringUtil::CIEquals(source_child_types[i].first, result_child_types[i].first)) { + throw TypeMismatchException(source, target, "Cannot cast STRUCTs with different names"); + } + auto child_cast = input.GetCastFunction(source_child_types[i].second, result_child_types[i].second); + child_cast_info.push_back(std::move(child_cast)); + } + return make_uniq(std::move(child_cast_info), target); +} + +unique_ptr StructBoundCastData::InitStructCastLocalState(CastLocalStateParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto result = make_uniq(); + + for (auto &entry : cast_data.child_cast_info) { + unique_ptr child_state; + if (entry.init_local_state) { + CastLocalStateParameters child_params(parameters, entry.cast_data); + child_state = entry.init_local_state(child_params); + } + result->local_states.push_back(std::move(child_state)); + } + return std::move(result); +} + +static bool StructToStructCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto &lstate = parameters.local_state->Cast(); + auto &source_child_types = StructType::GetChildTypes(source.GetType()); + auto &source_children = StructVector::GetEntries(source); + D_ASSERT(source_children.size() == StructType::GetChildTypes(result.GetType()).size()); + + auto &result_children = StructVector::GetEntries(result); + bool all_converted = true; + for (idx_t c_idx = 0; c_idx < source_child_types.size(); c_idx++) { + auto &result_child_vector = *result_children[c_idx]; + auto &source_child_vector = *source_children[c_idx]; + CastParameters child_parameters(parameters, cast_data.child_cast_info[c_idx].cast_data, + lstate.local_states[c_idx]); + if (!cast_data.child_cast_info[c_idx].function(source_child_vector, result_child_vector, count, + child_parameters)) { + all_converted = false; + } + } + if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, ConstantVector::IsNull(source)); + } else { + source.Flatten(count); + FlatVector::Validity(result) = FlatVector::Validity(source); + } + return all_converted; +} + +static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; + // first cast all child elements to varchar + auto &cast_data = parameters.cast_data->Cast(); + Vector varchar_struct(cast_data.target, count); + StructToStructCast(source, varchar_struct, count, parameters); + + // now construct the actual varchar vector + varchar_struct.Flatten(count); + auto &child_types = StructType::GetChildTypes(source.GetType()); + auto &children = StructVector::GetEntries(varchar_struct); + auto &validity = FlatVector::Validity(varchar_struct); + auto result_data = FlatVector::GetData(result); + static constexpr const idx_t SEP_LENGTH = 2; + static constexpr const idx_t NAME_SEP_LENGTH = 4; + static constexpr const idx_t NULL_LENGTH = 4; + for (idx_t i = 0; i < count; i++) { + if (!validity.RowIsValid(i)) { + FlatVector::SetNull(result, i, true); + continue; + } + idx_t string_length = 2; // {} + for (idx_t c = 0; c < children.size(); c++) { + if (c > 0) { + string_length += SEP_LENGTH; + } + children[c]->Flatten(count); + auto &child_validity = FlatVector::Validity(*children[c]); + auto data = FlatVector::GetData(*children[c]); + auto &name = child_types[c].first; + string_length += name.size() + NAME_SEP_LENGTH; // "'{name}': " + string_length += child_validity.RowIsValid(i) ? data[i].GetSize() : NULL_LENGTH; + } + result_data[i] = StringVector::EmptyString(result, string_length); + auto dataptr = result_data[i].GetDataWriteable(); + idx_t offset = 0; + dataptr[offset++] = '{'; + for (idx_t c = 0; c < children.size(); c++) { + if (c > 0) { + memcpy(dataptr + offset, ", ", SEP_LENGTH); + offset += SEP_LENGTH; + } + auto &child_validity = FlatVector::Validity(*children[c]); + auto data = FlatVector::GetData(*children[c]); + auto &name = child_types[c].first; + // "'{name}': " + dataptr[offset++] = '\''; + memcpy(dataptr + offset, name.c_str(), name.size()); + offset += name.size(); + dataptr[offset++] = '\''; + dataptr[offset++] = ':'; + dataptr[offset++] = ' '; + // value + if (child_validity.RowIsValid(i)) { + auto len = data[i].GetSize(); + memcpy(dataptr + offset, data[i].GetData(), len); + offset += len; + } else { + memcpy(dataptr + offset, "NULL", NULL_LENGTH); + offset += NULL_LENGTH; + } + } + dataptr[offset++] = '}'; + result_data[i].Finalize(); + } + + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + return true; +} + +BoundCastInfo DefaultCasts::StructCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + switch (target.id()) { + case LogicalTypeId::STRUCT: + return BoundCastInfo(StructToStructCast, StructBoundCastData::BindStructToStructCast(input, source, target), + StructBoundCastData::InitStructCastLocalState); + case LogicalTypeId::VARCHAR: { + // bind a cast in which we convert all child entries to VARCHAR entries + auto &struct_children = StructType::GetChildTypes(source); + child_list_t varchar_children; + for (auto &child_entry : struct_children) { + varchar_children.push_back(make_pair(child_entry.first, LogicalType::VARCHAR)); + } + auto varchar_type = LogicalType::STRUCT(varchar_children); + return BoundCastInfo(StructToVarcharCast, + StructBoundCastData::BindStructToStructCast(input, source, varchar_type), + StructBoundCastData::InitStructCastLocalState); + } + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/time_casts.cpp b/src/duckdb/src/function/cast/time_casts.cpp new file mode 100644 index 00000000..3586247a --- /dev/null +++ b/src/duckdb/src/function/cast/time_casts.cpp @@ -0,0 +1,181 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" +#include "duckdb/common/operator/string_cast.hpp" +namespace duckdb { + +BoundCastInfo DefaultCasts::DateCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // date to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + // date to timestamp + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::TIMESTAMP_NS: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::TIMESTAMP_SEC: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::TIMESTAMP_MS: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + default: + return TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::TimeCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // time to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::TIME_TZ: + // time to time with time zone + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + default: + return TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::TimeTzCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // time with time zone to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::TIME: + // time with time zone to time + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + default: + return TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::TimestampCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // timestamp to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::DATE: + // timestamp to date + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIME: + // timestamp to time + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIME_TZ: + // timestamp to time_tz + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIMESTAMP_TZ: + // timestamp (us) to timestamp with time zone + return ReinterpretCast; + case LogicalTypeId::TIMESTAMP_NS: + // timestamp (us) to timestamp (ns) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIMESTAMP_MS: + // timestamp (us) to timestamp (ms) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIMESTAMP_SEC: + // timestamp (us) to timestamp (s) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + default: + return TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::TimestampTzCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // timestamp with time zone to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::TIME_TZ: + // timestamp with time zone to time with time zone. + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIMESTAMP: + // timestamp with time zone to timestamp (us) + return ReinterpretCast; + default: + return TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::TimestampNsCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // timestamp (ns) to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::TIMESTAMP: + // timestamp (ns) to timestamp (us) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + default: + return TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::TimestampMsCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // timestamp (ms) to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::TIMESTAMP: + // timestamp (ms) to timestamp (us) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIMESTAMP_NS: + // timestamp (ms) to timestamp (ns) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + default: + return TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::TimestampSecCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // timestamp (sec) to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::TIMESTAMP_MS: + // timestamp (s) to timestamp (ms) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIMESTAMP: + // timestamp (s) to timestamp (us) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + case LogicalTypeId::TIMESTAMP_NS: + // timestamp (s) to timestamp (ns) + return BoundCastInfo( + &VectorCastHelpers::TemplatedCastLoop); + default: + return TryVectorNullCast; + } +} +BoundCastInfo DefaultCasts::IntervalCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // time to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/union/from_struct.cpp b/src/duckdb/src/function/cast/union/from_struct.cpp new file mode 100644 index 00000000..7e3a0ae6 --- /dev/null +++ b/src/duckdb/src/function/cast/union/from_struct.cpp @@ -0,0 +1,114 @@ +#include "duckdb/function/cast/bound_cast_data.hpp" + +namespace duckdb { + +bool StructToUnionCast::AllowImplicitCastFromStruct(const LogicalType &source, const LogicalType &target) { + if (source.id() != LogicalTypeId::STRUCT) { + return false; + } + auto target_fields = StructType::GetChildTypes(target); + auto fields = StructType::GetChildTypes(source); + if (target_fields.size() != fields.size()) { + // Struct should have the same amount of fields as the union + return false; + } + for (idx_t i = 0; i < target_fields.size(); i++) { + auto &target_field = target_fields[i].second; + auto &target_field_name = target_fields[i].first; + auto &field = fields[i].second; + auto &field_name = fields[i].first; + if (i == 0) { + // For the tag field we don't accept a type substitute as varchar + if (target_field != field) { + return false; + } + continue; + } + if (!StringUtil::CIEquals(target_field_name, field_name)) { + return false; + } + if (target_field != field && field != LogicalType::VARCHAR) { + // We allow the field to be VARCHAR, since unsupported types get cast to VARCHAR by EXPORT DATABASE (format + // PARQUET) i.e UNION(a BIT) becomes STRUCT(a VARCHAR) + return false; + } + } + return true; +} + +// Physical Cast execution + +bool StructToUnionCast::Cast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto &lstate = parameters.local_state->Cast(); + + D_ASSERT(source.GetType().id() == LogicalTypeId::STRUCT); + D_ASSERT(result.GetType().id() == LogicalTypeId::UNION); + D_ASSERT(cast_data.target.id() == LogicalTypeId::UNION); + + auto &source_children = StructVector::GetEntries(source); + auto &target_children = StructVector::GetEntries(result); + + for (idx_t i = 0; i < source_children.size(); i++) { + auto &result_child_vector = *target_children[i]; + auto &source_child_vector = *source_children[i]; + CastParameters child_parameters(parameters, cast_data.child_cast_info[i].cast_data, lstate.local_states[i]); + auto converted = + cast_data.child_cast_info[i].function(source_child_vector, result_child_vector, count, child_parameters); + (void)converted; + D_ASSERT(converted); + } + + auto check_tags = UnionVector::CheckUnionValidity(result, count); + switch (check_tags) { + case UnionInvalidReason::TAG_OUT_OF_RANGE: + throw ConversionException("One or more of the tags do not point to a valid union member"); + case UnionInvalidReason::VALIDITY_OVERLAP: + throw ConversionException("One or more rows in the produced UNION have validity set for more than 1 member"); + case UnionInvalidReason::TAG_MISMATCH: + throw ConversionException( + "One or more rows in the produced UNION have tags that don't point to the valid member"); + case UnionInvalidReason::VALID: + break; + default: + throw InternalException("Struct to union cast failed for unknown reason"); + } + + if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, ConstantVector::IsNull(source)); + } else { + source.Flatten(count); + FlatVector::Validity(result) = FlatVector::Validity(source); + } + result.Verify(count); + return true; +} + +// Bind cast + +unique_ptr StructToUnionCast::BindData(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + vector child_cast_info; + D_ASSERT(source.id() == LogicalTypeId::STRUCT); + D_ASSERT(target.id() == LogicalTypeId::UNION); + + auto result_child_count = StructType::GetChildCount(target); + D_ASSERT(result_child_count == StructType::GetChildCount(source)); + + for (idx_t i = 0; i < result_child_count; i++) { + auto &source_child = StructType::GetChildType(source, i); + auto &target_child = StructType::GetChildType(target, i); + + auto child_cast = input.GetCastFunction(source_child, target_child); + child_cast_info.push_back(std::move(child_cast)); + } + return make_uniq(std::move(child_cast_info), target); +} + +BoundCastInfo StructToUnionCast::Bind(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + auto cast_data = StructToUnionCast::BindData(input, source, target); + return BoundCastInfo(&StructToUnionCast::Cast, std::move(cast_data), StructBoundCastData::InitStructCastLocalState); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/union_casts.cpp b/src/duckdb/src/function/cast/union_casts.cpp new file mode 100644 index 00000000..28033245 --- /dev/null +++ b/src/duckdb/src/function/cast/union_casts.cpp @@ -0,0 +1,366 @@ +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/bound_cast_data.hpp" + +#include // for std::sort + +namespace duckdb { + +//-------------------------------------------------------------------------------------------------- +// ??? -> UNION +//-------------------------------------------------------------------------------------------------- +// if the source can be implicitly cast to a member of the target union, the cast is valid + +unique_ptr BindToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + D_ASSERT(target.id() == LogicalTypeId::UNION); + + vector candidates; + + for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(target); member_idx++) { + auto member_type = UnionType::GetMemberType(target, member_idx); + auto member_name = UnionType::GetMemberName(target, member_idx); + auto member_cast_cost = input.function_set.ImplicitCastCost(source, member_type); + if (member_cast_cost != -1) { + auto member_cast_info = input.GetCastFunction(source, member_type); + candidates.emplace_back(member_idx, member_name, member_type, member_cast_cost, + std::move(member_cast_info)); + } + }; + + // no possible casts found! + if (candidates.empty()) { + auto message = StringUtil::Format( + "Type %s can't be cast as %s. %s can't be implicitly cast to any of the union member types: ", + source.ToString(), target.ToString(), source.ToString()); + + auto member_count = UnionType::GetMemberCount(target); + for (idx_t member_idx = 0; member_idx < member_count; member_idx++) { + auto member_type = UnionType::GetMemberType(target, member_idx); + message += member_type.ToString(); + if (member_idx < member_count - 1) { + message += ", "; + } + } + throw CastException(message); + } + + // sort the candidate casts by cost + std::sort(candidates.begin(), candidates.end(), UnionBoundCastData::SortByCostAscending); + + // select the lowest possible cost cast + auto &selected_cast = candidates[0]; + auto selected_cost = candidates[0].cost; + + // check if the cast is ambiguous (2 or more casts have the same cost) + if (candidates.size() > 1 && candidates[1].cost == selected_cost) { + + // collect all the ambiguous types + auto message = StringUtil::Format( + "Type %s can't be cast as %s. The cast is ambiguous, multiple possible members in target: ", source, + target); + for (size_t i = 0; i < candidates.size(); i++) { + if (candidates[i].cost == selected_cost) { + message += StringUtil::Format("'%s (%s)'", candidates[i].name, candidates[i].type.ToString()); + if (i < candidates.size() - 1) { + message += ", "; + } + } + } + message += ". Disambiguate the target type by using the 'union_value( := )' function to promote the " + "source value to a single member union before casting."; + throw CastException(message); + } + + // otherwise, return the selected cast + return make_uniq(std::move(selected_cast)); +} + +unique_ptr InitToUnionLocalState(CastLocalStateParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + if (!cast_data.member_cast_info.init_local_state) { + return nullptr; + } + CastLocalStateParameters child_parameters(parameters, cast_data.member_cast_info.cast_data); + return cast_data.member_cast_info.init_local_state(child_parameters); +} + +static bool ToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + D_ASSERT(result.GetType().id() == LogicalTypeId::UNION); + auto &cast_data = parameters.cast_data->Cast(); + auto &selected_member_vector = UnionVector::GetMember(result, cast_data.tag); + + CastParameters child_parameters(parameters, cast_data.member_cast_info.cast_data, parameters.local_state); + if (!cast_data.member_cast_info.function(source, selected_member_vector, count, child_parameters)) { + return false; + } + + // cast succeeded, create union vector + UnionVector::SetToMember(result, cast_data.tag, selected_member_vector, count, true); + + result.Verify(count); + + return true; +} + +BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + + D_ASSERT(target.id() == LogicalTypeId::UNION); + if (StructToUnionCast::AllowImplicitCastFromStruct(source, target)) { + return StructToUnionCast::Bind(input, source, target); + } + auto cast_data = BindToUnionCast(input, source, target); + return BoundCastInfo(&ToUnionCast, std::move(cast_data), InitToUnionLocalState); +} + +//-------------------------------------------------------------------------------------------------- +// UNION -> UNION +//-------------------------------------------------------------------------------------------------- +// if the source member tags is a subset of the target member tags, and all the source members can be +// implicitly cast to the corresponding target members, the cast is valid. +// +// VALID: UNION(A, B) -> UNION(A, B, C) +// VALID: UNION(A, B) -> UNION(A, C) if B can be implicitly cast to C +// +// INVALID: UNION(A, B, C) -> UNION(A, B) +// INVALID: UNION(A, B) -> UNION(A, C) if B can't be implicitly cast to C +// INVALID: UNION(A, B, D) -> UNION(A, B, C) + +struct UnionUnionBoundCastData : public BoundCastData { + + // mapping from source member index to target member index + // these are always the same size as the source member count + // (since all source members must be present in the target) + vector tag_map; + vector member_casts; + + LogicalType target_type; + + UnionUnionBoundCastData(vector tag_map, vector member_casts, LogicalType target_type) + : tag_map(std::move(tag_map)), member_casts(std::move(member_casts)), target_type(std::move(target_type)) { + } + +public: + unique_ptr Copy() const override { + vector member_casts_copy; + for (auto &member_cast : member_casts) { + member_casts_copy.push_back(member_cast.Copy()); + } + return make_uniq(tag_map, std::move(member_casts_copy), target_type); + } +}; + +unique_ptr BindUnionToUnionCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + D_ASSERT(source.id() == LogicalTypeId::UNION); + D_ASSERT(target.id() == LogicalTypeId::UNION); + + auto source_member_count = UnionType::GetMemberCount(source); + + auto tag_map = vector(source_member_count); + auto member_casts = vector(); + + for (idx_t source_idx = 0; source_idx < source_member_count; source_idx++) { + auto &source_member_type = UnionType::GetMemberType(source, source_idx); + auto &source_member_name = UnionType::GetMemberName(source, source_idx); + + bool found = false; + for (idx_t target_idx = 0; target_idx < UnionType::GetMemberCount(target); target_idx++) { + auto &target_member_name = UnionType::GetMemberName(target, target_idx); + + // found a matching member + if (source_member_name == target_member_name) { + auto &target_member_type = UnionType::GetMemberType(target, target_idx); + tag_map[source_idx] = target_idx; + member_casts.push_back(input.GetCastFunction(source_member_type, target_member_type)); + found = true; + break; + } + } + if (!found) { + // no matching member tag found in the target set + auto message = + StringUtil::Format("Type %s can't be cast as %s. The member '%s' is not present in target union", + source.ToString(), target.ToString(), source_member_name); + throw CastException(message); + } + } + + return make_uniq(tag_map, std::move(member_casts), target); +} + +unique_ptr InitUnionToUnionLocalState(CastLocalStateParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto result = make_uniq(); + + for (auto &entry : cast_data.member_casts) { + unique_ptr child_state; + if (entry.init_local_state) { + CastLocalStateParameters child_params(parameters, entry.cast_data); + child_state = entry.init_local_state(child_params); + } + result->local_states.push_back(std::move(child_state)); + } + return std::move(result); +} + +static bool UnionToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &cast_data = parameters.cast_data->Cast(); + auto &lstate = parameters.local_state->Cast(); + + auto source_member_count = UnionType::GetMemberCount(source.GetType()); + auto target_member_count = UnionType::GetMemberCount(result.GetType()); + + auto target_member_is_mapped = vector(target_member_count); + + // Perform the casts from source to target members + for (idx_t member_idx = 0; member_idx < source_member_count; member_idx++) { + auto target_member_idx = cast_data.tag_map[member_idx]; + + auto &source_member_vector = UnionVector::GetMember(source, member_idx); + auto &target_member_vector = UnionVector::GetMember(result, target_member_idx); + auto &member_cast = cast_data.member_casts[member_idx]; + + CastParameters child_parameters(parameters, member_cast.cast_data, lstate.local_states[member_idx]); + if (!member_cast.function(source_member_vector, target_member_vector, count, child_parameters)) { + return false; + } + + target_member_is_mapped[target_member_idx] = true; + } + + // All member casts succeeded! + + // Set the unmapped target members to constant NULL. + // If we cast UNION(A, B) -> UNION(A, B, C) we need to invalidate C so that + // the invariants of the result union hold. (only member columns "selected" + // by the rowwise corresponding tag in the tag vector should be valid) + for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) { + if (!target_member_is_mapped[target_member_idx]) { + auto &target_member_vector = UnionVector::GetMember(result, target_member_idx); + target_member_vector.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(target_member_vector, true); + } + } + + // Update the tags in the result vector + auto &source_tag_vector = UnionVector::GetTags(source); + auto &result_tag_vector = UnionVector::GetTags(result); + + if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // Constant vector case optimization + result.SetVectorType(VectorType::CONSTANT_VECTOR); + if (ConstantVector::IsNull(source)) { + ConstantVector::SetNull(result, true); + } else { + // map the tag + auto source_tag = ConstantVector::GetData(source_tag_vector)[0]; + auto mapped_tag = cast_data.tag_map[source_tag]; + ConstantVector::GetData(result_tag_vector)[0] = mapped_tag; + } + } else { + // Otherwise, use the unified vector format to access the source vector. + + // Ensure that all the result members are flat vectors + // This is not always the case, e.g. when a member is cast using the default TryNullCast function + // the resulting member vector will be a constant null vector. + for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) { + UnionVector::GetMember(result, target_member_idx).Flatten(count); + } + + // We assume that a union tag vector validity matches the union vector validity. + UnifiedVectorFormat source_tag_format; + source_tag_vector.ToUnifiedFormat(count, source_tag_format); + + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + auto source_row_idx = source_tag_format.sel->get_index(row_idx); + if (source_tag_format.validity.RowIsValid(source_row_idx)) { + // map the tag + auto source_tag = (UnifiedVectorFormat::GetData(source_tag_format))[source_row_idx]; + auto target_tag = cast_data.tag_map[source_tag]; + FlatVector::GetData(result_tag_vector)[row_idx] = target_tag; + } else { + + // Issue: The members of the result is not always flatvectors + // In the case of TryNullCast, the result member is constant. + FlatVector::SetNull(result, row_idx, true); + } + } + } + + result.Verify(count); + + return true; +} + +static bool UnionToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; + // first cast all union members to varchar + auto &cast_data = parameters.cast_data->Cast(); + Vector varchar_union(cast_data.target_type, count); + + UnionToUnionCast(source, varchar_union, count, parameters); + + // now construct the actual varchar vector + varchar_union.Flatten(count); + auto &tag_vector = UnionVector::GetTags(source); + auto tag_vector_type = tag_vector.GetVectorType(); + if (tag_vector_type != VectorType::CONSTANT_VECTOR && tag_vector_type != VectorType::FLAT_VECTOR) { + tag_vector.Flatten(count); + } + + auto tags = FlatVector::GetData(tag_vector); + + auto &validity = FlatVector::Validity(varchar_union); + auto result_data = FlatVector::GetData(result); + + for (idx_t i = 0; i < count; i++) { + if (!validity.RowIsValid(i)) { + FlatVector::SetNull(result, i, true); + continue; + } + + auto &member = UnionVector::GetMember(varchar_union, tags[i]); + UnifiedVectorFormat member_vdata; + member.ToUnifiedFormat(count, member_vdata); + + auto mapped_idx = member_vdata.sel->get_index(i); + auto member_valid = member_vdata.validity.RowIsValid(mapped_idx); + if (member_valid) { + auto member_str = (UnifiedVectorFormat::GetData(member_vdata))[mapped_idx]; + result_data[i] = StringVector::AddString(result, member_str); + } else { + result_data[i] = StringVector::AddString(result, "NULL"); + } + } + + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + + result.Verify(count); + return true; +} + +BoundCastInfo DefaultCasts::UnionCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + D_ASSERT(source.id() == LogicalTypeId::UNION); + switch (target.id()) { + case LogicalTypeId::VARCHAR: { + // bind a cast in which we convert all members to VARCHAR first + child_list_t varchar_members; + for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(source); member_idx++) { + varchar_members.push_back(make_pair(UnionType::GetMemberName(source, member_idx), LogicalType::VARCHAR)); + } + auto varchar_type = LogicalType::UNION(std::move(varchar_members)); + return BoundCastInfo(UnionToVarcharCast, BindUnionToUnionCast(input, source, varchar_type), + InitUnionToUnionLocalState); + } + case LogicalTypeId::UNION: + return BoundCastInfo(UnionToUnionCast, BindUnionToUnionCast(input, source, target), InitUnionToUnionLocalState); + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/uuid_casts.cpp b/src/duckdb/src/function/cast/uuid_casts.cpp new file mode 100644 index 00000000..c2267b51 --- /dev/null +++ b/src/duckdb/src/function/cast/uuid_casts.cpp @@ -0,0 +1,18 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +BoundCastInfo DefaultCasts::UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + // uuid to varchar + return BoundCastInfo(&VectorCastHelpers::StringCast); + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/vector_cast_helpers.cpp b/src/duckdb/src/function/cast/vector_cast_helpers.cpp new file mode 100644 index 00000000..876f3841 --- /dev/null +++ b/src/duckdb/src/function/cast/vector_cast_helpers.cpp @@ -0,0 +1,353 @@ +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +// ------- Helper functions for splitting string nested types ------- +static bool IsNull(const char *buf, idx_t start_pos, Vector &child, idx_t row_idx) { + if (buf[start_pos] == 'N' && buf[start_pos + 1] == 'U' && buf[start_pos + 2] == 'L' && buf[start_pos + 3] == 'L') { + FlatVector::SetNull(child, row_idx, true); + return true; + } + return false; +} + +inline static void SkipWhitespace(const char *buf, idx_t &pos, idx_t len) { + while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { + pos++; + } +} + +static bool SkipToCloseQuotes(idx_t &pos, const char *buf, idx_t &len) { + char quote = buf[pos]; + pos++; + bool escaped = false; + + while (pos < len) { + if (buf[pos] == '\\') { + escaped = !escaped; + } else { + if (buf[pos] == quote && !escaped) { + return true; + } + escaped = false; + } + pos++; + } + return false; +} + +static bool SkipToClose(idx_t &idx, const char *buf, idx_t &len, idx_t &lvl, char close_bracket) { + idx++; + + while (idx < len) { + if (buf[idx] == '"' || buf[idx] == '\'') { + if (!SkipToCloseQuotes(idx, buf, len)) { + return false; + } + } else if (buf[idx] == '{') { + if (!SkipToClose(idx, buf, len, lvl, '}')) { + return false; + } + } else if (buf[idx] == '[') { + if (!SkipToClose(idx, buf, len, lvl, ']')) { + return false; + } + lvl++; + } else if (buf[idx] == close_bracket) { + if (close_bracket == ']') { + lvl--; + } + return true; + } + idx++; + } + return false; +} + +static idx_t StringTrim(const char *buf, idx_t &start_pos, idx_t pos) { + idx_t trailing_whitespace = 0; + while (StringUtil::CharacterIsSpace(buf[pos - trailing_whitespace - 1])) { + trailing_whitespace++; + } + if ((buf[start_pos] == '"' && buf[pos - trailing_whitespace - 1] == '"') || + (buf[start_pos] == '\'' && buf[pos - trailing_whitespace - 1] == '\'')) { + start_pos++; + trailing_whitespace++; + } + return (pos - trailing_whitespace); +} + +struct CountPartOperation { + idx_t count = 0; + + bool HandleKey(const char *buf, idx_t start_pos, idx_t pos) { + count++; + return true; + } + void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { + count++; + } +}; + +// ------- LIST SPLIT ------- +struct SplitStringListOperation { + SplitStringListOperation(string_t *child_data, idx_t &child_start, Vector &child) + : child_data(child_data), child_start(child_start), child(child) { + } + + string_t *child_data; + idx_t &child_start; + Vector &child; + + void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { + if ((pos - start_pos) == 4 && IsNull(buf, start_pos, child, child_start)) { + child_start++; + return; + } + if (start_pos > pos) { + pos = start_pos; + } + child_data[child_start] = StringVector::AddString(child, buf + start_pos, pos - start_pos); + child_start++; + } +}; + +template +static bool SplitStringListInternal(const string_t &input, OP &state) { + const char *buf = input.GetData(); + idx_t len = input.GetSize(); + idx_t lvl = 1; + idx_t pos = 0; + bool seen_value = false; + + SkipWhitespace(buf, pos, len); + if (pos == len || buf[pos] != '[') { + return false; + } + + SkipWhitespace(buf, ++pos, len); + idx_t start_pos = pos; + while (pos < len) { + if (buf[pos] == '[') { + if (!SkipToClose(pos, buf, len, ++lvl, ']')) { + return false; + } + } else if ((buf[pos] == '"' || buf[pos] == '\'') && pos == start_pos) { + SkipToCloseQuotes(pos, buf, len); + } else if (buf[pos] == '{') { + idx_t struct_lvl = 0; + SkipToClose(pos, buf, len, struct_lvl, '}'); + } else if (buf[pos] == ',' || buf[pos] == ']') { + idx_t trailing_whitespace = 0; + while (StringUtil::CharacterIsSpace(buf[pos - trailing_whitespace - 1])) { + trailing_whitespace++; + } + if (buf[pos] != ']' || start_pos != pos || seen_value) { + state.HandleValue(buf, start_pos, pos - trailing_whitespace); + seen_value = true; + } + if (buf[pos] == ']') { + lvl--; + break; + } + SkipWhitespace(buf, ++pos, len); + start_pos = pos; + continue; + } + pos++; + } + SkipWhitespace(buf, ++pos, len); + return (pos == len && lvl == 0); +} + +bool VectorStringToList::SplitStringList(const string_t &input, string_t *child_data, idx_t &child_start, + Vector &child) { + SplitStringListOperation state(child_data, child_start, child); + return SplitStringListInternal(input, state); +} + +idx_t VectorStringToList::CountPartsList(const string_t &input) { + CountPartOperation state; + SplitStringListInternal(input, state); + return state.count; +} + +// ------- MAP SPLIT ------- +struct SplitStringMapOperation { + SplitStringMapOperation(string_t *child_key_data, string_t *child_val_data, idx_t &child_start, Vector &varchar_key, + Vector &varchar_val) + : child_key_data(child_key_data), child_val_data(child_val_data), child_start(child_start), + varchar_key(varchar_key), varchar_val(varchar_val) { + } + + string_t *child_key_data; + string_t *child_val_data; + idx_t &child_start; + Vector &varchar_key; + Vector &varchar_val; + + bool HandleKey(const char *buf, idx_t start_pos, idx_t pos) { + if ((pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_key, child_start)) { + FlatVector::SetNull(varchar_val, child_start, true); + child_start++; + return false; + } + child_key_data[child_start] = StringVector::AddString(varchar_key, buf + start_pos, pos - start_pos); + return true; + } + + void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { + if ((pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_val, child_start)) { + child_start++; + return; + } + child_val_data[child_start] = StringVector::AddString(varchar_val, buf + start_pos, pos - start_pos); + child_start++; + } +}; + +template +static bool FindKeyOrValueMap(const char *buf, idx_t len, idx_t &pos, OP &state, bool key) { + auto start_pos = pos; + idx_t lvl = 0; + while (pos < len) { + if (buf[pos] == '"' || buf[pos] == '\'') { + SkipToCloseQuotes(pos, buf, len); + } else if (buf[pos] == '{') { + SkipToClose(pos, buf, len, lvl, '}'); + } else if (buf[pos] == '[') { + SkipToClose(pos, buf, len, lvl, ']'); + } else if (key && buf[pos] == '=') { + idx_t end_pos = StringTrim(buf, start_pos, pos); + return state.HandleKey(buf, start_pos, end_pos); // put string in KEY_child_vector + } else if (!key && (buf[pos] == ',' || buf[pos] == '}')) { + idx_t end_pos = StringTrim(buf, start_pos, pos); + state.HandleValue(buf, start_pos, end_pos); // put string in VALUE_child_vector + return true; + } + pos++; + } + return false; +} + +template +static bool SplitStringMapInternal(const string_t &input, OP &state) { + const char *buf = input.GetData(); + idx_t len = input.GetSize(); + idx_t pos = 0; + + SkipWhitespace(buf, pos, len); + if (pos == len || buf[pos] != '{') { + return false; + } + SkipWhitespace(buf, ++pos, len); + if (pos == len) { + return false; + } + if (buf[pos] == '}') { + SkipWhitespace(buf, ++pos, len); + return (pos == len); + } + while (pos < len) { + if (!FindKeyOrValueMap(buf, len, pos, state, true)) { + return false; + } + SkipWhitespace(buf, ++pos, len); + if (!FindKeyOrValueMap(buf, len, pos, state, false)) { + return false; + } + SkipWhitespace(buf, ++pos, len); + } + return true; +} + +bool VectorStringToMap::SplitStringMap(const string_t &input, string_t *child_key_data, string_t *child_val_data, + idx_t &child_start, Vector &varchar_key, Vector &varchar_val) { + SplitStringMapOperation state(child_key_data, child_val_data, child_start, varchar_key, varchar_val); + return SplitStringMapInternal(input, state); +} + +idx_t VectorStringToMap::CountPartsMap(const string_t &input) { + CountPartOperation state; + SplitStringMapInternal(input, state); + return state.count; +} + +// ------- STRUCT SPLIT ------- +static bool FindKeyStruct(const char *buf, idx_t len, idx_t &pos) { + while (pos < len) { + if (buf[pos] == ':') { + return true; + } + pos++; + } + return false; +} + +static bool FindValueStruct(const char *buf, idx_t len, idx_t &pos, Vector &varchar_child, idx_t &row_idx, + ValidityMask *child_mask) { + auto start_pos = pos; + idx_t lvl = 0; + while (pos < len) { + if (buf[pos] == '"' || buf[pos] == '\'') { + SkipToCloseQuotes(pos, buf, len); + } else if (buf[pos] == '{') { + SkipToClose(pos, buf, len, lvl, '}'); + } else if (buf[pos] == '[') { + SkipToClose(pos, buf, len, lvl, ']'); + } else if (buf[pos] == ',' || buf[pos] == '}') { + idx_t end_pos = StringTrim(buf, start_pos, pos); + if ((end_pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_child, row_idx)) { + return true; + } + FlatVector::GetData(varchar_child)[row_idx] = + StringVector::AddString(varchar_child, buf + start_pos, end_pos - start_pos); + child_mask->SetValid(row_idx); // any child not set to valid will remain invalid + return true; + } + pos++; + } + return false; +} + +bool VectorStringToStruct::SplitStruct(const string_t &input, vector> &varchar_vectors, + idx_t &row_idx, string_map_t &child_names, + vector &child_masks) { + const char *buf = input.GetData(); + idx_t len = input.GetSize(); + idx_t pos = 0; + idx_t child_idx; + + SkipWhitespace(buf, pos, len); + if (pos == len || buf[pos] != '{') { + return false; + } + SkipWhitespace(buf, ++pos, len); + if (buf[pos] == '}') { + pos++; + } else { + while (pos < len) { + auto key_start = pos; + if (!FindKeyStruct(buf, len, pos)) { + return false; + } + auto key_end = StringTrim(buf, key_start, pos); + string_t found_key(buf + key_start, key_end - key_start); + + auto it = child_names.find(found_key); + if (it == child_names.end()) { + return false; // false key + } + child_idx = it->second; + SkipWhitespace(buf, ++pos, len); + if (!FindValueStruct(buf, len, pos, *varchar_vectors[child_idx], row_idx, child_masks[child_idx])) { + return false; + } + SkipWhitespace(buf, ++pos, len); + } + } + SkipWhitespace(buf, pos, len); + return (pos == len); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast_rules.cpp b/src/duckdb/src/function/cast_rules.cpp new file mode 100644 index 00000000..f4ffb84c --- /dev/null +++ b/src/duckdb/src/function/cast_rules.cpp @@ -0,0 +1,320 @@ +#include "duckdb/function/cast_rules.hpp" + +namespace duckdb { + +//! The target type determines the preferred implicit casts +static int64_t TargetTypeCost(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::INTEGER: + return 103; + case LogicalTypeId::BIGINT: + return 101; + case LogicalTypeId::DOUBLE: + return 102; + case LogicalTypeId::HUGEINT: + return 120; + case LogicalTypeId::TIMESTAMP: + return 120; + case LogicalTypeId::VARCHAR: + return 149; + case LogicalTypeId::DECIMAL: + return 104; + case LogicalTypeId::STRUCT: + case LogicalTypeId::MAP: + case LogicalTypeId::LIST: + case LogicalTypeId::UNION: + return 160; + default: + return 110; + } +} + +static int64_t ImplicitCastTinyint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastSmallint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastInteger(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastBigint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastUTinyint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastUSmallint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastUInteger(const LogicalType &to) { + switch (to.id()) { + + case LogicalTypeId::UBIGINT: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastUBigint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastFloat(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::DOUBLE: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastDouble(const LogicalType &to) { + switch (to.id()) { + default: + return -1; + } +} + +static int64_t ImplicitCastDecimal(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastHugeint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + return TargetTypeCost(to); + default: + return -1; + } +} + +static int64_t ImplicitCastDate(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::TIMESTAMP: + return TargetTypeCost(to); + default: + return -1; + } +} + +int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) { + if (from.id() == LogicalTypeId::SQLNULL) { + // NULL expression can be cast to anything + return TargetTypeCost(to); + } + if (from.id() == LogicalTypeId::UNKNOWN) { + // parameter expression can be cast to anything for no cost + return 0; + } + if (to.id() == LogicalTypeId::ANY) { + // anything can be cast to ANY type for (almost no) cost + return 1; + } + if (from.GetAlias() != to.GetAlias()) { + // if aliases are different, an implicit cast is not possible + return -1; + } + if (from.id() == LogicalTypeId::LIST && to.id() == LogicalTypeId::LIST) { + // Lists can be cast if their child types can be cast + auto child_cost = ImplicitCast(ListType::GetChildType(from), ListType::GetChildType(to)); + if (child_cost >= 100) { + // subtract one from the cost because we prefer LIST[X] -> LIST[VARCHAR] over LIST[X] -> VARCHAR + child_cost--; + } + return child_cost; + } + if (from.id() == to.id()) { + // arguments match: do nothing + return 0; + } + if (from.id() == LogicalTypeId::BLOB && to.id() == LogicalTypeId::VARCHAR) { + // Implicit cast not allowed from BLOB to VARCHAR + return -1; + } + if (to.id() == LogicalTypeId::VARCHAR) { + // everything can be cast to VARCHAR, but this cast has a high cost + return TargetTypeCost(to); + } + + if (from.id() == LogicalTypeId::UNION && to.id() == LogicalTypeId::UNION) { + // Unions can be cast if the source tags are a subset of the target tags + // in which case the most expensive cost is used + int cost = -1; + for (idx_t from_member_idx = 0; from_member_idx < UnionType::GetMemberCount(from); from_member_idx++) { + auto &from_member_name = UnionType::GetMemberName(from, from_member_idx); + + bool found = false; + for (idx_t to_member_idx = 0; to_member_idx < UnionType::GetMemberCount(to); to_member_idx++) { + auto &to_member_name = UnionType::GetMemberName(to, to_member_idx); + + if (from_member_name == to_member_name) { + auto &from_member_type = UnionType::GetMemberType(from, from_member_idx); + auto &to_member_type = UnionType::GetMemberType(to, to_member_idx); + + int child_cost = ImplicitCast(from_member_type, to_member_type); + if (child_cost > cost) { + cost = child_cost; + } + found = true; + break; + } + } + if (!found) { + return -1; + } + } + return cost; + } + + if (to.id() == LogicalTypeId::UNION) { + // check that the union type is fully resolved. + if (to.AuxInfo() == nullptr) { + return -1; + } + // every type can be implicitly be cast to a union if the source type is a member of the union + for (idx_t i = 0; i < UnionType::GetMemberCount(to); i++) { + auto member = UnionType::GetMemberType(to, i); + if (from == member) { + return 0; + } + } + } + + if ((from.id() == LogicalTypeId::TIMESTAMP_SEC || from.id() == LogicalTypeId::TIMESTAMP_MS || + from.id() == LogicalTypeId::TIMESTAMP_NS) && + to.id() == LogicalTypeId::TIMESTAMP) { + //! Any timestamp type can be converted to the default (us) type at low cost + return 101; + } + if ((to.id() == LogicalTypeId::TIMESTAMP_SEC || to.id() == LogicalTypeId::TIMESTAMP_MS || + to.id() == LogicalTypeId::TIMESTAMP_NS) && + from.id() == LogicalTypeId::TIMESTAMP) { + //! Any timestamp type can be converted to the default (us) type at low cost + return 100; + } + switch (from.id()) { + case LogicalTypeId::TINYINT: + return ImplicitCastTinyint(to); + case LogicalTypeId::SMALLINT: + return ImplicitCastSmallint(to); + case LogicalTypeId::INTEGER: + return ImplicitCastInteger(to); + case LogicalTypeId::BIGINT: + return ImplicitCastBigint(to); + case LogicalTypeId::UTINYINT: + return ImplicitCastUTinyint(to); + case LogicalTypeId::USMALLINT: + return ImplicitCastUSmallint(to); + case LogicalTypeId::UINTEGER: + return ImplicitCastUInteger(to); + case LogicalTypeId::UBIGINT: + return ImplicitCastUBigint(to); + case LogicalTypeId::HUGEINT: + return ImplicitCastHugeint(to); + case LogicalTypeId::FLOAT: + return ImplicitCastFloat(to); + case LogicalTypeId::DOUBLE: + return ImplicitCastDouble(to); + case LogicalTypeId::DATE: + return ImplicitCastDate(to); + case LogicalTypeId::DECIMAL: + return ImplicitCastDecimal(to); + default: + return -1; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/compression_config.cpp b/src/duckdb/src/function/compression_config.cpp new file mode 100644 index 00000000..50e68297 --- /dev/null +++ b/src/duckdb/src/function/compression_config.cpp @@ -0,0 +1,94 @@ +#include "duckdb/main/config.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/common/pair.hpp" + +namespace duckdb { + +typedef CompressionFunction (*get_compression_function_t)(PhysicalType type); +typedef bool (*compression_supports_type_t)(PhysicalType type); + +struct DefaultCompressionMethod { + CompressionType type; + get_compression_function_t get_function; + compression_supports_type_t supports_type; +}; + +static DefaultCompressionMethod internal_compression_methods[] = { + {CompressionType::COMPRESSION_CONSTANT, ConstantFun::GetFunction, ConstantFun::TypeIsSupported}, + {CompressionType::COMPRESSION_UNCOMPRESSED, UncompressedFun::GetFunction, UncompressedFun::TypeIsSupported}, + {CompressionType::COMPRESSION_RLE, RLEFun::GetFunction, RLEFun::TypeIsSupported}, + {CompressionType::COMPRESSION_BITPACKING, BitpackingFun::GetFunction, BitpackingFun::TypeIsSupported}, + {CompressionType::COMPRESSION_DICTIONARY, DictionaryCompressionFun::GetFunction, + DictionaryCompressionFun::TypeIsSupported}, + {CompressionType::COMPRESSION_CHIMP, ChimpCompressionFun::GetFunction, ChimpCompressionFun::TypeIsSupported}, + {CompressionType::COMPRESSION_PATAS, PatasCompressionFun::GetFunction, PatasCompressionFun::TypeIsSupported}, + {CompressionType::COMPRESSION_FSST, FSSTFun::GetFunction, FSSTFun::TypeIsSupported}, + {CompressionType::COMPRESSION_AUTO, nullptr, nullptr}}; + +static optional_ptr FindCompressionFunction(CompressionFunctionSet &set, CompressionType type, + PhysicalType data_type) { + auto &functions = set.functions; + auto comp_entry = functions.find(type); + if (comp_entry != functions.end()) { + auto &type_functions = comp_entry->second; + auto type_entry = type_functions.find(data_type); + if (type_entry != type_functions.end()) { + return &type_entry->second; + } + } + return nullptr; +} + +static optional_ptr LoadCompressionFunction(CompressionFunctionSet &set, CompressionType type, + PhysicalType data_type) { + for (idx_t index = 0; internal_compression_methods[index].get_function; index++) { + const auto &method = internal_compression_methods[index]; + if (method.type == type) { + // found the correct compression type + if (!method.supports_type(data_type)) { + // but it does not support this data type: bail out + return nullptr; + } + // the type is supported: create the function and insert it into the set + auto function = method.get_function(data_type); + set.functions[type].insert(make_pair(data_type, function)); + return FindCompressionFunction(set, type, data_type); + } + } + throw InternalException("Unsupported compression function type"); +} + +static void TryLoadCompression(DBConfig &config, vector> &result, CompressionType type, + PhysicalType data_type) { + auto function = config.GetCompressionFunction(type, data_type); + if (!function) { + return; + } + result.push_back(*function); +} + +vector> DBConfig::GetCompressionFunctions(PhysicalType data_type) { + vector> result; + TryLoadCompression(*this, result, CompressionType::COMPRESSION_UNCOMPRESSED, data_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_RLE, data_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_BITPACKING, data_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_DICTIONARY, data_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_CHIMP, data_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_PATAS, data_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_FSST, data_type); + return result; +} + +optional_ptr DBConfig::GetCompressionFunction(CompressionType type, PhysicalType data_type) { + lock_guard l(compression_functions->lock); + // check if the function is already loaded + auto function = FindCompressionFunction(*compression_functions, type, data_type); + if (function) { + return function; + } + // else load the function + return LoadCompressionFunction(*compression_functions, type, data_type); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/function.cpp b/src/duckdb/src/function/function.cpp new file mode 100644 index 00000000..44547e88 --- /dev/null +++ b/src/duckdb/src/function/function.cpp @@ -0,0 +1,159 @@ +#include "duckdb/function/function.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +FunctionData::~FunctionData() { +} + +bool FunctionData::Equals(const FunctionData *left, const FunctionData *right) { + if (left == right) { + return true; + } + if (!left || !right) { + return false; + } + return left->Equals(*right); +} + +TableFunctionData::~TableFunctionData() { +} + +unique_ptr TableFunctionData::Copy() const { + throw InternalException("Copy not supported for TableFunctionData"); +} + +bool TableFunctionData::Equals(const FunctionData &other) const { + return false; +} + +Function::Function(string name_p) : name(std::move(name_p)) { +} +Function::~Function() { +} + +SimpleFunction::SimpleFunction(string name_p, vector arguments_p, LogicalType varargs_p) + : Function(std::move(name_p)), arguments(std::move(arguments_p)), varargs(std::move(varargs_p)) { +} + +SimpleFunction::~SimpleFunction() { +} + +string SimpleFunction::ToString() const { + return Function::CallToString(name, arguments); +} + +bool SimpleFunction::HasVarArgs() const { + return varargs.id() != LogicalTypeId::INVALID; +} + +SimpleNamedParameterFunction::SimpleNamedParameterFunction(string name_p, vector arguments_p, + LogicalType varargs_p) + : SimpleFunction(std::move(name_p), std::move(arguments_p), std::move(varargs_p)) { +} + +SimpleNamedParameterFunction::~SimpleNamedParameterFunction() { +} + +string SimpleNamedParameterFunction::ToString() const { + return Function::CallToString(name, arguments, named_parameters); +} + +bool SimpleNamedParameterFunction::HasNamedParameters() const { + return !named_parameters.empty(); +} + +BaseScalarFunction::BaseScalarFunction(string name_p, vector arguments_p, LogicalType return_type_p, + FunctionSideEffects side_effects, LogicalType varargs_p, + FunctionNullHandling null_handling) + : SimpleFunction(std::move(name_p), std::move(arguments_p), std::move(varargs_p)), + return_type(std::move(return_type_p)), side_effects(side_effects), null_handling(null_handling) { +} + +BaseScalarFunction::~BaseScalarFunction() { +} + +string BaseScalarFunction::ToString() const { + return Function::CallToString(name, arguments, return_type); +} + +// add your initializer for new functions here +void BuiltinFunctions::Initialize() { + RegisterTableScanFunctions(); + RegisterSQLiteFunctions(); + RegisterReadFunctions(); + RegisterTableFunctions(); + RegisterArrowFunctions(); + + RegisterDistributiveAggregates(); + + RegisterCompressedMaterializationFunctions(); + + RegisterGenericFunctions(); + RegisterOperators(); + RegisterSequenceFunctions(); + RegisterStringFunctions(); + RegisterNestedFunctions(); + + RegisterPragmaFunctions(); + + // initialize collations + AddCollation("nocase", LowerFun::GetFunction(), true); + AddCollation("noaccent", StripAccentsFun::GetFunction()); + AddCollation("nfc", NFCNormalizeFun::GetFunction()); +} + +hash_t BaseScalarFunction::Hash() const { + hash_t hash = return_type.Hash(); + for (auto &arg : arguments) { + hash = duckdb::CombineHash(hash, arg.Hash()); + } + return hash; +} + +string Function::CallToString(const string &name, const vector &arguments) { + string result = name + "("; + result += StringUtil::Join(arguments, arguments.size(), ", ", + [](const LogicalType &argument) { return argument.ToString(); }); + return result + ")"; +} + +string Function::CallToString(const string &name, const vector &arguments, + const LogicalType &return_type) { + string result = CallToString(name, arguments); + result += " -> " + return_type.ToString(); + return result; +} + +string Function::CallToString(const string &name, const vector &arguments, + const named_parameter_type_map_t &named_parameters) { + vector input_arguments; + input_arguments.reserve(arguments.size() + named_parameters.size()); + for (auto &arg : arguments) { + input_arguments.push_back(arg.ToString()); + } + for (auto &kv : named_parameters) { + input_arguments.push_back(StringUtil::Format("%s : %s", kv.first, kv.second.ToString())); + } + return StringUtil::Format("%s(%s)", name, StringUtil::Join(input_arguments, ", ")); +} + +void Function::EraseArgument(SimpleFunction &bound_function, vector> &arguments, + idx_t argument_index) { + if (bound_function.original_arguments.empty()) { + bound_function.original_arguments = bound_function.arguments; + } + D_ASSERT(arguments.size() == bound_function.arguments.size()); + D_ASSERT(argument_index < arguments.size()); + arguments.erase(arguments.begin() + argument_index); + bound_function.arguments.erase(bound_function.arguments.begin() + argument_index); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp new file mode 100644 index 00000000..aba67bb1 --- /dev/null +++ b/src/duckdb/src/function/function_binder.cpp @@ -0,0 +1,323 @@ +#include "duckdb/function/function_binder.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/function/cast_rules.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +FunctionBinder::FunctionBinder(ClientContext &context) : context(context) { +} + +int64_t FunctionBinder::BindVarArgsFunctionCost(const SimpleFunction &func, const vector &arguments) { + if (arguments.size() < func.arguments.size()) { + // not enough arguments to fulfill the non-vararg part of the function + return -1; + } + int64_t cost = 0; + for (idx_t i = 0; i < arguments.size(); i++) { + LogicalType arg_type = i < func.arguments.size() ? func.arguments[i] : func.varargs; + if (arguments[i] == arg_type) { + // arguments match: do nothing + continue; + } + int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], arg_type); + if (cast_cost >= 0) { + // we can implicitly cast, add the cost to the total cost + cost += cast_cost; + } else { + // we can't implicitly cast: throw an error + return -1; + } + } + return cost; +} + +int64_t FunctionBinder::BindFunctionCost(const SimpleFunction &func, const vector &arguments) { + if (func.HasVarArgs()) { + // special case varargs function + return BindVarArgsFunctionCost(func, arguments); + } + if (func.arguments.size() != arguments.size()) { + // invalid argument count: check the next function + return -1; + } + int64_t cost = 0; + for (idx_t i = 0; i < arguments.size(); i++) { + int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], func.arguments[i]); + if (cast_cost >= 0) { + // we can implicitly cast, add the cost to the total cost + cost += cast_cost; + } else { + // we can't implicitly cast: throw an error + return -1; + } + } + return cost; +} + +template +vector FunctionBinder::BindFunctionsFromArguments(const string &name, FunctionSet &functions, + const vector &arguments, string &error) { + idx_t best_function = DConstants::INVALID_INDEX; + int64_t lowest_cost = NumericLimits::Maximum(); + vector candidate_functions; + for (idx_t f_idx = 0; f_idx < functions.functions.size(); f_idx++) { + auto &func = functions.functions[f_idx]; + // check the arguments of the function + int64_t cost = BindFunctionCost(func, arguments); + if (cost < 0) { + // auto casting was not possible + continue; + } + if (cost == lowest_cost) { + candidate_functions.push_back(f_idx); + continue; + } + if (cost > lowest_cost) { + continue; + } + candidate_functions.clear(); + lowest_cost = cost; + best_function = f_idx; + } + if (best_function == DConstants::INVALID_INDEX) { + // no matching function was found, throw an error + string call_str = Function::CallToString(name, arguments); + string candidate_str = ""; + for (auto &f : functions.functions) { + candidate_str += "\t" + f.ToString() + "\n"; + } + error = StringUtil::Format("No function matches the given name and argument types '%s'. You might need to add " + "explicit type casts.\n\tCandidate functions:\n%s", + call_str, candidate_str); + return candidate_functions; + } + candidate_functions.push_back(best_function); + return candidate_functions; +} + +template +idx_t FunctionBinder::MultipleCandidateException(const string &name, FunctionSet &functions, + vector &candidate_functions, + const vector &arguments, string &error) { + D_ASSERT(functions.functions.size() > 1); + // there are multiple possible function definitions + // throw an exception explaining which overloads are there + string call_str = Function::CallToString(name, arguments); + string candidate_str = ""; + for (auto &conf : candidate_functions) { + T f = functions.GetFunctionByOffset(conf); + candidate_str += "\t" + f.ToString() + "\n"; + } + error = StringUtil::Format("Could not choose a best candidate function for the function call \"%s\". In order to " + "select one, please add explicit type casts.\n\tCandidate functions:\n%s", + call_str, candidate_str); + return DConstants::INVALID_INDEX; +} + +template +idx_t FunctionBinder::BindFunctionFromArguments(const string &name, FunctionSet &functions, + const vector &arguments, string &error) { + auto candidate_functions = BindFunctionsFromArguments(name, functions, arguments, error); + if (candidate_functions.empty()) { + // no candidates + return DConstants::INVALID_INDEX; + } + if (candidate_functions.size() > 1) { + // multiple candidates, check if there are any unknown arguments + bool has_parameters = false; + for (auto &arg_type : arguments) { + if (arg_type.id() == LogicalTypeId::UNKNOWN) { + //! there are! we could not resolve parameters in this case + throw ParameterNotResolvedException(); + } + } + if (!has_parameters) { + return MultipleCandidateException(name, functions, candidate_functions, arguments, error); + } + } + return candidate_functions[0]; +} + +idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, + const vector &arguments, string &error) { + return BindFunctionFromArguments(name, functions, arguments, error); +} + +idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, + const vector &arguments, string &error) { + return BindFunctionFromArguments(name, functions, arguments, error); +} + +idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, + const vector &arguments, string &error) { + return BindFunctionFromArguments(name, functions, arguments, error); +} + +idx_t FunctionBinder::BindFunction(const string &name, PragmaFunctionSet &functions, PragmaInfo &info, string &error) { + vector types; + for (auto &value : info.parameters) { + types.push_back(value.type()); + } + idx_t entry = BindFunctionFromArguments(name, functions, types, error); + if (entry == DConstants::INVALID_INDEX) { + throw BinderException(error); + } + auto candidate_function = functions.GetFunctionByOffset(entry); + // cast the input parameters + for (idx_t i = 0; i < info.parameters.size(); i++) { + auto target_type = + i < candidate_function.arguments.size() ? candidate_function.arguments[i] : candidate_function.varargs; + info.parameters[i] = info.parameters[i].CastAs(context, target_type); + } + return entry; +} + +vector FunctionBinder::GetLogicalTypesFromExpressions(vector> &arguments) { + vector types; + types.reserve(arguments.size()); + for (auto &argument : arguments) { + types.push_back(argument->return_type); + } + return types; +} + +idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, + vector> &arguments, string &error) { + auto types = GetLogicalTypesFromExpressions(arguments); + return BindFunction(name, functions, types, error); +} + +idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, + vector> &arguments, string &error) { + auto types = GetLogicalTypesFromExpressions(arguments); + return BindFunction(name, functions, types, error); +} + +idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, + vector> &arguments, string &error) { + auto types = GetLogicalTypesFromExpressions(arguments); + return BindFunction(name, functions, types, error); +} + +enum class LogicalTypeComparisonResult { IDENTICAL_TYPE, TARGET_IS_ANY, DIFFERENT_TYPES }; + +LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const LogicalType &target_type) { + if (target_type.id() == LogicalTypeId::ANY) { + return LogicalTypeComparisonResult::TARGET_IS_ANY; + } + if (source_type == target_type) { + return LogicalTypeComparisonResult::IDENTICAL_TYPE; + } + if (source_type.id() == LogicalTypeId::LIST && target_type.id() == LogicalTypeId::LIST) { + return RequiresCast(ListType::GetChildType(source_type), ListType::GetChildType(target_type)); + } + return LogicalTypeComparisonResult::DIFFERENT_TYPES; +} + +void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector> &children) { + for (idx_t i = 0; i < children.size(); i++) { + auto target_type = i < function.arguments.size() ? function.arguments[i] : function.varargs; + target_type.Verify(); + // don't cast lambda children, they get removed anyways + if (children[i]->return_type.id() == LogicalTypeId::LAMBDA) { + continue; + } + // check if the type of child matches the type of function argument + // if not we need to add a cast + auto cast_result = RequiresCast(children[i]->return_type, target_type); + // except for one special case: if the function accepts ANY argument + // in that case we don't add a cast + if (cast_result == LogicalTypeComparisonResult::DIFFERENT_TYPES) { + children[i] = BoundCastExpression::AddCastToType(context, std::move(children[i]), target_type); + } + } +} + +unique_ptr FunctionBinder::BindScalarFunction(const string &schema, const string &name, + vector> children, string &error, + bool is_operator, Binder *binder) { + // bind the function + auto &function = + Catalog::GetSystemCatalog(context).GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, schema, name); + D_ASSERT(function.type == CatalogType::SCALAR_FUNCTION_ENTRY); + return BindScalarFunction(function.Cast(), std::move(children), error, is_operator, + binder); +} + +unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogEntry &func, + vector> children, string &error, + bool is_operator, Binder *binder) { + // bind the function + idx_t best_function = BindFunction(func.name, func.functions, children, error); + if (best_function == DConstants::INVALID_INDEX) { + return nullptr; + } + + // found a matching function! + auto bound_function = func.functions.GetFunctionByOffset(best_function); + + if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { + for (auto &child : children) { + if (child->return_type == LogicalTypeId::SQLNULL) { + return make_uniq(Value(LogicalType::SQLNULL)); + } + if (!child->IsFoldable()) { + continue; + } + Value result; + if (!ExpressionExecutor::TryEvaluateScalar(context, *child, result)) { + continue; + } + if (result.IsNull()) { + return make_uniq(Value(LogicalType::SQLNULL)); + } + } + } + return BindScalarFunction(bound_function, std::move(children), is_operator); +} + +unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_function, + vector> children, + bool is_operator) { + unique_ptr bind_info; + if (bound_function.bind) { + bind_info = bound_function.bind(context, bound_function, children); + } + // check if we need to add casts to the children + CastToFunctionArguments(bound_function, children); + + // now create the function + auto return_type = bound_function.return_type; + return make_uniq(std::move(return_type), std::move(bound_function), std::move(children), + std::move(bind_info), is_operator); +} + +unique_ptr FunctionBinder::BindAggregateFunction(AggregateFunction bound_function, + vector> children, + unique_ptr filter, + AggregateType aggr_type) { + unique_ptr bind_info; + if (bound_function.bind) { + bind_info = bound_function.bind(context, bound_function, children); + // we may have lost some arguments in the bind + children.resize(MinValue(bound_function.arguments.size(), children.size())); + } + + // check if we need to add casts to the children + CastToFunctionArguments(bound_function, children); + + return make_uniq(std::move(bound_function), std::move(children), std::move(filter), + std::move(bind_info), aggr_type); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/function_set.cpp b/src/duckdb/src/function/function_set.cpp new file mode 100644 index 00000000..41d54c88 --- /dev/null +++ b/src/duckdb/src/function/function_set.cpp @@ -0,0 +1,92 @@ +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +ScalarFunctionSet::ScalarFunctionSet() : FunctionSet("") { +} + +ScalarFunctionSet::ScalarFunctionSet(string name) : FunctionSet(std::move(name)) { +} + +ScalarFunctionSet::ScalarFunctionSet(ScalarFunction fun) : FunctionSet(std::move(fun.name)) { + functions.push_back(std::move(fun)); +} + +ScalarFunction ScalarFunctionSet::GetFunctionByArguments(ClientContext &context, const vector &arguments) { + string error; + FunctionBinder binder(context); + idx_t index = binder.BindFunction(name, *this, arguments, error); + if (index == DConstants::INVALID_INDEX) { + throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), + error); + } + return GetFunctionByOffset(index); +} + +AggregateFunctionSet::AggregateFunctionSet() : FunctionSet("") { +} + +AggregateFunctionSet::AggregateFunctionSet(string name) : FunctionSet(std::move(name)) { +} + +AggregateFunctionSet::AggregateFunctionSet(AggregateFunction fun) : FunctionSet(std::move(fun.name)) { + functions.push_back(std::move(fun)); +} + +AggregateFunction AggregateFunctionSet::GetFunctionByArguments(ClientContext &context, + const vector &arguments) { + string error; + FunctionBinder binder(context); + idx_t index = binder.BindFunction(name, *this, arguments, error); + if (index == DConstants::INVALID_INDEX) { + // check if the arguments are a prefix of any of the arguments + // this is used for functions such as quantile or string_agg that delete part of their arguments during bind + // FIXME: we should come up with a better solution here + for (auto &func : functions) { + if (arguments.size() >= func.arguments.size()) { + continue; + } + bool is_prefix = true; + for (idx_t k = 0; k < arguments.size(); k++) { + if (arguments[k] != func.arguments[k]) { + is_prefix = false; + break; + } + } + if (is_prefix) { + return func; + } + } + throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), + error); + } + return GetFunctionByOffset(index); +} + +TableFunctionSet::TableFunctionSet(string name) : FunctionSet(std::move(name)) { +} + +TableFunctionSet::TableFunctionSet(TableFunction fun) : FunctionSet(std::move(fun.name)) { + functions.push_back(std::move(fun)); +} + +TableFunction TableFunctionSet::GetFunctionByArguments(ClientContext &context, const vector &arguments) { + string error; + FunctionBinder binder(context); + idx_t index = binder.BindFunction(name, *this, arguments, error); + if (index == DConstants::INVALID_INDEX) { + throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), + error); + } + return GetFunctionByOffset(index); +} + +PragmaFunctionSet::PragmaFunctionSet(string name) : FunctionSet(std::move(name)) { +} + +PragmaFunctionSet::PragmaFunctionSet(PragmaFunction fun) : FunctionSet(std::move(fun.name)) { + functions.push_back(std::move(fun)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/macro_function.cpp b/src/duckdb/src/function/macro_function.cpp new file mode 100644 index 00000000..ad92ee9f --- /dev/null +++ b/src/duckdb/src/function/macro_function.cpp @@ -0,0 +1,95 @@ + +#include "duckdb/function/macro_function.hpp" + +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/scalar_macro_function.hpp" +#include "duckdb/function/table_macro_function.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" + +namespace duckdb { + +// MacroFunction::MacroFunction(unique_ptr expression) : expression(std::move(expression)) {} + +MacroFunction::MacroFunction(MacroType type) : type(type) { +} + +string MacroFunction::ValidateArguments(MacroFunction ¯o_def, const string &name, FunctionExpression &function_expr, + vector> &positionals, + unordered_map> &defaults) { + + // separate positional and default arguments + for (auto &arg : function_expr.children) { + if (!arg->alias.empty()) { + // default argument + if (!macro_def.default_parameters.count(arg->alias)) { + return StringUtil::Format("Macro %s does not have default parameter %s!", name, arg->alias); + } else if (defaults.count(arg->alias)) { + return StringUtil::Format("Duplicate default parameters %s!", arg->alias); + } + defaults[arg->alias] = std::move(arg); + } else if (!defaults.empty()) { + return "Positional parameters cannot come after parameters with a default value!"; + } else { + // positional argument + positionals.push_back(std::move(arg)); + } + } + + // validate if the right number of arguments was supplied + string error; + auto ¶meters = macro_def.parameters; + if (parameters.size() != positionals.size()) { + error = StringUtil::Format( + "Macro function '%s(%s)' requires ", name, + StringUtil::Join(parameters, parameters.size(), ", ", [](const unique_ptr &p) { + return (p->Cast()).column_names[0]; + })); + error += parameters.size() == 1 ? "a single positional argument" + : StringUtil::Format("%i positional arguments", parameters.size()); + error += ", but "; + error += positionals.size() == 1 ? "a single positional argument was" + : StringUtil::Format("%i positional arguments were", positionals.size()); + error += " provided."; + return error; + } + + // Add the default values for parameters that have defaults, that were not explicitly assigned to + for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { + auto ¶meter_name = it->first; + auto ¶meter_default = it->second; + if (!defaults.count(parameter_name)) { + // This parameter was not set yet, set it with the default value + defaults[parameter_name] = parameter_default->Copy(); + } + } + + return error; +} + +void MacroFunction::CopyProperties(MacroFunction &other) const { + other.type = type; + for (auto ¶m : parameters) { + other.parameters.push_back(param->Copy()); + } + for (auto &kv : default_parameters) { + other.default_parameters[kv.first] = kv.second->Copy(); + } +} + +string MacroFunction::ToSQL(const string &schema, const string &name) const { + vector param_strings; + for (auto ¶m : parameters) { + param_strings.push_back(param->ToString()); + } + for (auto &named_param : default_parameters) { + param_strings.push_back(StringUtil::Format("%s := %s", named_param.first, named_param.second->ToString())); + } + + return StringUtil::Format("CREATE MACRO %s.%s(%s) AS ", schema, name, StringUtil::Join(param_strings, ", ")); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/pragma/pragma_functions.cpp b/src/duckdb/src/function/pragma/pragma_functions.cpp new file mode 100644 index 00000000..b4a3f6ce --- /dev/null +++ b/src/duckdb/src/function/pragma/pragma_functions.cpp @@ -0,0 +1,162 @@ +#include "duckdb/function/pragma/pragma_functions.hpp" + +#include "duckdb/common/enums/output_type.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/function/function_set.hpp" + +#include + +namespace duckdb { + +static void PragmaEnableProfilingStatement(ClientContext &context, const FunctionParameters ¶meters) { + auto &config = ClientConfig::GetConfig(context); + config.enable_profiler = true; + config.emit_profiler_output = true; +} + +void RegisterEnableProfiling(BuiltinFunctions &set) { + PragmaFunctionSet functions(""); + functions.AddFunction(PragmaFunction::PragmaStatement(string(), PragmaEnableProfilingStatement)); + + set.AddFunction("enable_profile", functions); + set.AddFunction("enable_profiling", functions); +} + +static void PragmaDisableProfiling(ClientContext &context, const FunctionParameters ¶meters) { + auto &config = ClientConfig::GetConfig(context); + config.enable_profiler = false; +} + +static void PragmaEnableProgressBar(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).enable_progress_bar = true; +} + +static void PragmaDisableProgressBar(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).enable_progress_bar = false; +} + +static void PragmaEnablePrintProgressBar(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).print_progress_bar = true; +} + +static void PragmaDisablePrintProgressBar(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).print_progress_bar = false; +} + +static void PragmaEnableVerification(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).query_verification_enabled = true; + ClientConfig::GetConfig(context).verify_serializer = true; +} + +static void PragmaDisableVerification(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).query_verification_enabled = false; + ClientConfig::GetConfig(context).verify_serializer = false; +} + +static void PragmaVerifySerializer(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).verify_serializer = true; +} + +static void PragmaDisableVerifySerializer(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).verify_serializer = false; +} + +static void PragmaEnableExternalVerification(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).verify_external = true; +} + +static void PragmaDisableExternalVerification(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).verify_external = false; +} + +static void PragmaEnableForceParallelism(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).verify_parallelism = true; +} + +static void PragmaEnableIndexJoin(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).enable_index_join = true; +} + +static void PragmaEnableForceIndexJoin(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).force_index_join = true; +} + +static void PragmaForceCheckpoint(ClientContext &context, const FunctionParameters ¶meters) { + DBConfig::GetConfig(context).options.force_checkpoint = true; +} + +static void PragmaDisableForceParallelism(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).verify_parallelism = false; +} + +static void PragmaEnableObjectCache(ClientContext &context, const FunctionParameters ¶meters) { + DBConfig::GetConfig(context).options.object_cache_enable = true; +} + +static void PragmaDisableObjectCache(ClientContext &context, const FunctionParameters ¶meters) { + DBConfig::GetConfig(context).options.object_cache_enable = false; +} + +static void PragmaEnableCheckpointOnShutdown(ClientContext &context, const FunctionParameters ¶meters) { + DBConfig::GetConfig(context).options.checkpoint_on_shutdown = true; +} + +static void PragmaDisableCheckpointOnShutdown(ClientContext &context, const FunctionParameters ¶meters) { + DBConfig::GetConfig(context).options.checkpoint_on_shutdown = false; +} + +static void PragmaEnableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).enable_optimizer = true; +} + +static void PragmaDisableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { + ClientConfig::GetConfig(context).enable_optimizer = false; +} + +void PragmaFunctions::RegisterFunction(BuiltinFunctions &set) { + RegisterEnableProfiling(set); + + set.AddFunction(PragmaFunction::PragmaStatement("disable_profile", PragmaDisableProfiling)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_profiling", PragmaDisableProfiling)); + + set.AddFunction(PragmaFunction::PragmaStatement("enable_verification", PragmaEnableVerification)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_verification", PragmaDisableVerification)); + + set.AddFunction(PragmaFunction::PragmaStatement("verify_external", PragmaEnableExternalVerification)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_external", PragmaDisableExternalVerification)); + + set.AddFunction(PragmaFunction::PragmaStatement("verify_serializer", PragmaVerifySerializer)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_serializer", PragmaDisableVerifySerializer)); + + set.AddFunction(PragmaFunction::PragmaStatement("verify_parallelism", PragmaEnableForceParallelism)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_parallelism", PragmaDisableForceParallelism)); + + set.AddFunction(PragmaFunction::PragmaStatement("enable_object_cache", PragmaEnableObjectCache)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_object_cache", PragmaDisableObjectCache)); + + set.AddFunction(PragmaFunction::PragmaStatement("enable_optimizer", PragmaEnableOptimizer)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_optimizer", PragmaDisableOptimizer)); + + set.AddFunction(PragmaFunction::PragmaStatement("enable_index_join", PragmaEnableIndexJoin)); + set.AddFunction(PragmaFunction::PragmaStatement("force_index_join", PragmaEnableForceIndexJoin)); + set.AddFunction(PragmaFunction::PragmaStatement("force_checkpoint", PragmaForceCheckpoint)); + + set.AddFunction(PragmaFunction::PragmaStatement("enable_progress_bar", PragmaEnableProgressBar)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_progress_bar", PragmaDisableProgressBar)); + + set.AddFunction(PragmaFunction::PragmaStatement("enable_print_progress_bar", PragmaEnablePrintProgressBar)); + set.AddFunction(PragmaFunction::PragmaStatement("disable_print_progress_bar", PragmaDisablePrintProgressBar)); + + set.AddFunction(PragmaFunction::PragmaStatement("enable_checkpoint_on_shutdown", PragmaEnableCheckpointOnShutdown)); + set.AddFunction( + PragmaFunction::PragmaStatement("disable_checkpoint_on_shutdown", PragmaDisableCheckpointOnShutdown)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/pragma/pragma_queries.cpp b/src/duckdb/src/function/pragma/pragma_queries.cpp new file mode 100644 index 00000000..4e45d0dd --- /dev/null +++ b/src/duckdb/src/function/pragma/pragma_queries.cpp @@ -0,0 +1,215 @@ +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/pragma/pragma_functions.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/qualified_name.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" +#include "duckdb/parser/statement/export_statement.hpp" + +namespace duckdb { + +string PragmaTableInfo(ClientContext &context, const FunctionParameters ¶meters) { + return StringUtil::Format("SELECT * FROM pragma_table_info('%s');", parameters.values[0].ToString()); +} + +string PragmaShowTables(ClientContext &context, const FunctionParameters ¶meters) { + // clang-format off + return R"EOF( + with "tables" as + ( + SELECT table_name as "name" + FROM duckdb_tables + where in_search_path(database_name, schema_name) + ), "views" as + ( + SELECT view_name as "name" + FROM duckdb_views + where in_search_path(database_name, schema_name) + ), db_objects as + ( + SELECT "name" FROM "tables" + UNION ALL + SELECT "name" FROM "views" + ) + SELECT "name" + FROM db_objects + ORDER BY "name";)EOF"; + // clang-format on +} + +string PragmaShowTablesExpanded(ClientContext &context, const FunctionParameters ¶meters) { + return R"( + SELECT + t.database_name AS database, + t.schema_name AS schema, + t.table_name AS name, + LIST(c.column_name order by c.column_index) AS column_names, + LIST(c.data_type order by c.column_index) AS column_types, + FIRST(t.temporary) AS temporary, + FROM duckdb_tables t + JOIN duckdb_columns c + USING (table_oid) + GROUP BY database, schema, name + + UNION ALL + + SELECT + v.database_name AS database, + v.schema_name AS schema, + v.view_name AS name, + LIST(c.column_name order by c.column_index) AS column_names, + LIST(c.data_type order by c.column_index) AS column_types, + FIRST(v.temporary) AS temporary, + FROM duckdb_views v + JOIN duckdb_columns c + ON (v.view_oid=c.table_oid) + GROUP BY database, schema, name + + ORDER BY database, schema, name + )"; +} + +string PragmaShowDatabases(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT database_name FROM duckdb_databases() WHERE NOT internal ORDER BY database_name;"; +} + +string PragmaAllProfiling(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT * FROM pragma_last_profiling_output() JOIN pragma_detailed_profiling_output() ON " + "(pragma_last_profiling_output.operator_id);"; +} + +string PragmaDatabaseList(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT * FROM pragma_database_list;"; +} + +string PragmaCollations(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT * FROM pragma_collations() ORDER BY 1;"; +} + +string PragmaFunctionsQuery(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT function_name AS name, upper(function_type) AS type, parameter_types AS parameters, varargs, " + "return_type, has_side_effects AS side_effects" + " FROM duckdb_functions()" + " WHERE function_type IN ('scalar', 'aggregate')" + " ORDER BY 1;"; +} + +string PragmaShow(ClientContext &context, const FunctionParameters ¶meters) { + // PRAGMA table_info but with some aliases + auto table = QualifiedName::Parse(parameters.values[0].ToString()); + + // clang-format off + string sql = R"( + SELECT + name AS "column_name", + type as "column_type", + CASE WHEN "notnull" THEN 'NO' ELSE 'YES' END AS "null", + (SELECT + MIN(CASE + WHEN constraint_type='PRIMARY KEY' THEN 'PRI' + WHEN constraint_type='UNIQUE' THEN 'UNI' + ELSE NULL END) + FROM duckdb_constraints() c + WHERE c.table_oid=cols.table_oid + AND list_contains(constraint_column_names, cols.column_name)) AS "key", + dflt_value AS "default", + NULL AS "extra" + FROM pragma_table_info('%func_param_table%') + LEFT JOIN duckdb_columns cols + ON cols.column_name = pragma_table_info.name + AND cols.table_name='%table_name%' + AND cols.schema_name='%table_schema%' + AND cols.database_name = '%table_database%' + ORDER BY column_index;)"; + // clang-format on + + sql = StringUtil::Replace(sql, "%func_param_table%", parameters.values[0].ToString()); + sql = StringUtil::Replace(sql, "%table_name%", table.name); + sql = StringUtil::Replace(sql, "%table_schema%", table.schema.empty() ? DEFAULT_SCHEMA : table.schema); + sql = StringUtil::Replace(sql, "%table_database%", + table.catalog.empty() ? DatabaseManager::GetDefaultDatabase(context) : table.catalog); + return sql; +} + +string PragmaVersion(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT * FROM pragma_version();"; +} + +string PragmaPlatform(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT * FROM pragma_platform();"; +} + +string PragmaImportDatabase(ClientContext &context, const FunctionParameters ¶meters) { + auto &config = DBConfig::GetConfig(context); + if (!config.options.enable_external_access) { + throw PermissionException("Import is disabled through configuration"); + } + auto &fs = FileSystem::GetFileSystem(context); + + string final_query; + // read the "shema.sql" and "load.sql" files + vector files = {"schema.sql", "load.sql"}; + for (auto &file : files) { + auto file_path = fs.JoinPath(parameters.values[0].ToString(), file); + auto handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, + FileSystem::DEFAULT_COMPRESSION); + auto fsize = fs.GetFileSize(*handle); + auto buffer = make_unsafe_uniq_array(fsize); + fs.Read(*handle, buffer.get(), fsize); + auto query = string(buffer.get(), fsize); + // Replace the placeholder with the path provided to IMPORT + if (file == "load.sql") { + Parser parser; + parser.ParseQuery(query); + auto copy_statements = std::move(parser.statements); + query.clear(); + for (auto &statement_p : copy_statements) { + D_ASSERT(statement_p->type == StatementType::COPY_STATEMENT); + auto &statement = statement_p->Cast(); + auto &info = *statement.info; + auto file_name = fs.ExtractName(info.file_path); + info.file_path = fs.JoinPath(parameters.values[0].ToString(), file_name); + query += statement.ToString() + ";"; + } + } + final_query += query; + } + return final_query; +} + +string PragmaDatabaseSize(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT * FROM pragma_database_size();"; +} + +string PragmaStorageInfo(ClientContext &context, const FunctionParameters ¶meters) { + return StringUtil::Format("SELECT * FROM pragma_storage_info('%s');", parameters.values[0].ToString()); +} + +string PragmaMetadataInfo(ClientContext &context, const FunctionParameters ¶meters) { + return "SELECT * FROM pragma_metadata_info();"; +} + +void PragmaQueries::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(PragmaFunction::PragmaCall("table_info", PragmaTableInfo, {LogicalType::VARCHAR})); + set.AddFunction(PragmaFunction::PragmaCall("storage_info", PragmaStorageInfo, {LogicalType::VARCHAR})); + set.AddFunction(PragmaFunction::PragmaCall("metadata_info", PragmaMetadataInfo, {})); + set.AddFunction(PragmaFunction::PragmaStatement("show_tables", PragmaShowTables)); + set.AddFunction(PragmaFunction::PragmaStatement("show_tables_expanded", PragmaShowTablesExpanded)); + set.AddFunction(PragmaFunction::PragmaStatement("show_databases", PragmaShowDatabases)); + set.AddFunction(PragmaFunction::PragmaStatement("database_list", PragmaDatabaseList)); + set.AddFunction(PragmaFunction::PragmaStatement("collations", PragmaCollations)); + set.AddFunction(PragmaFunction::PragmaCall("show", PragmaShow, {LogicalType::VARCHAR})); + set.AddFunction(PragmaFunction::PragmaStatement("version", PragmaVersion)); + set.AddFunction(PragmaFunction::PragmaStatement("platform", PragmaPlatform)); + set.AddFunction(PragmaFunction::PragmaStatement("database_size", PragmaDatabaseSize)); + set.AddFunction(PragmaFunction::PragmaStatement("functions", PragmaFunctionsQuery)); + set.AddFunction(PragmaFunction::PragmaCall("import_database", PragmaImportDatabase, {LogicalType::VARCHAR})); + set.AddFunction(PragmaFunction::PragmaStatement("all_profiling_output", PragmaAllProfiling)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/pragma_function.cpp b/src/duckdb/src/function/pragma_function.cpp new file mode 100644 index 00000000..531cdfb4 --- /dev/null +++ b/src/duckdb/src/function/pragma_function.cpp @@ -0,0 +1,45 @@ +#include "duckdb/function/pragma_function.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +PragmaFunction::PragmaFunction(string name, PragmaType pragma_type, pragma_query_t query, pragma_function_t function, + vector arguments, LogicalType varargs) + : SimpleNamedParameterFunction(std::move(name), std::move(arguments), std::move(varargs)), type(pragma_type), + query(query), function(function) { +} + +PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_query_t query, vector arguments, + LogicalType varargs) { + return PragmaFunction(name, PragmaType::PRAGMA_CALL, query, nullptr, std::move(arguments), std::move(varargs)); +} + +PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_function_t function, vector arguments, + LogicalType varargs) { + return PragmaFunction(name, PragmaType::PRAGMA_CALL, nullptr, function, std::move(arguments), std::move(varargs)); +} + +PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_query_t query) { + vector types; + return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, query, nullptr, std::move(types), LogicalType::INVALID); +} + +PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_function_t function) { + vector types; + return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, nullptr, function, std::move(types), + LogicalType::INVALID); +} + +string PragmaFunction::ToString() const { + switch (type) { + case PragmaType::PRAGMA_STATEMENT: + return StringUtil::Format("PRAGMA %s", name); + case PragmaType::PRAGMA_CALL: { + return StringUtil::Format("PRAGMA %s", SimpleNamedParameterFunction::ToString()); + } + default: + return "UNKNOWN"; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp new file mode 100644 index 00000000..173353fb --- /dev/null +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp @@ -0,0 +1,214 @@ +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/scalar/compressed_materialization_functions.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +static string IntegralCompressFunctionName(const LogicalType &result_type) { + return StringUtil::Format("__internal_compress_integral_%s", + StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); +} + +template +struct TemplatedIntegralCompress { + static inline RESULT_TYPE Operation(const INPUT_TYPE &input, const INPUT_TYPE &min_val) { + D_ASSERT(min_val <= input); + return input - min_val; + } +}; + +template +struct TemplatedIntegralCompress { + static inline RESULT_TYPE Operation(const hugeint_t &input, const hugeint_t &min_val) { + D_ASSERT(min_val <= input); + return (input - min_val).lower; + } +}; + +template +static void IntegralCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + D_ASSERT(args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR); + const auto min_val = ConstantVector::GetData(args.data[1])[0]; + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](const INPUT_TYPE &input) { + return TemplatedIntegralCompress::Operation(input, min_val); + }); +} + +template +static scalar_function_t GetIntegralCompressFunction(const LogicalType &input_type, const LogicalType &result_type) { + return IntegralCompressFunction; +} + +template +static scalar_function_t GetIntegralCompressFunctionResultSwitch(const LogicalType &input_type, + const LogicalType &result_type) { + switch (result_type.id()) { + case LogicalTypeId::UTINYINT: + return GetIntegralCompressFunction(input_type, result_type); + case LogicalTypeId::USMALLINT: + return GetIntegralCompressFunction(input_type, result_type); + case LogicalTypeId::UINTEGER: + return GetIntegralCompressFunction(input_type, result_type); + case LogicalTypeId::UBIGINT: + return GetIntegralCompressFunction(input_type, result_type); + default: + throw InternalException("Unexpected result type in GetIntegralCompressFunctionResultSwitch"); + } +} + +static scalar_function_t GetIntegralCompressFunctionInputSwitch(const LogicalType &input_type, + const LogicalType &result_type) { + switch (input_type.id()) { + case LogicalTypeId::SMALLINT: + return GetIntegralCompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::INTEGER: + return GetIntegralCompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::BIGINT: + return GetIntegralCompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::HUGEINT: + return GetIntegralCompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::USMALLINT: + return GetIntegralCompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::UINTEGER: + return GetIntegralCompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::UBIGINT: + return GetIntegralCompressFunctionResultSwitch(input_type, result_type); + default: + throw InternalException("Unexpected input type in GetIntegralCompressFunctionInputSwitch"); + } +} + +static string IntegralDecompressFunctionName(const LogicalType &result_type) { + return StringUtil::Format("__internal_decompress_integral_%s", + StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); +} + +template +static inline RESULT_TYPE TemplatedIntegralDecompress(const INPUT_TYPE &input, const RESULT_TYPE &min_val) { + return min_val + input; +} + +template +static void IntegralDecompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + D_ASSERT(args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR); + D_ASSERT(args.data[1].GetType() == result.GetType()); + const auto min_val = ConstantVector::GetData(args.data[1])[0]; + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](const INPUT_TYPE &input) { + return TemplatedIntegralDecompress(input, min_val); + }); +} + +template +static scalar_function_t GetIntegralDecompressFunction(const LogicalType &input_type, const LogicalType &result_type) { + return IntegralDecompressFunction; +} + +template +static scalar_function_t GetIntegralDecompressFunctionResultSwitch(const LogicalType &input_type, + const LogicalType &result_type) { + switch (result_type.id()) { + case LogicalTypeId::SMALLINT: + return GetIntegralDecompressFunction(input_type, result_type); + case LogicalTypeId::INTEGER: + return GetIntegralDecompressFunction(input_type, result_type); + case LogicalTypeId::BIGINT: + return GetIntegralDecompressFunction(input_type, result_type); + case LogicalTypeId::HUGEINT: + return GetIntegralDecompressFunction(input_type, result_type); + case LogicalTypeId::USMALLINT: + return GetIntegralDecompressFunction(input_type, result_type); + case LogicalTypeId::UINTEGER: + return GetIntegralDecompressFunction(input_type, result_type); + case LogicalTypeId::UBIGINT: + return GetIntegralDecompressFunction(input_type, result_type); + default: + throw InternalException("Unexpected input type in GetIntegralDecompressFunctionSetSwitch"); + } +} + +static scalar_function_t GetIntegralDecompressFunctionInputSwitch(const LogicalType &input_type, + const LogicalType &result_type) { + switch (input_type.id()) { + case LogicalTypeId::UTINYINT: + return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::USMALLINT: + return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::UINTEGER: + return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); + case LogicalTypeId::UBIGINT: + return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); + default: + throw InternalException("Unexpected result type in GetIntegralDecompressFunctionInputSwitch"); + } +} + +static void CMIntegralSerialize(Serializer &serializer, const optional_ptr bind_data, + const ScalarFunction &function) { + serializer.WriteProperty(100, "arguments", function.arguments); + serializer.WriteProperty(101, "return_type", function.return_type); +} + +template +unique_ptr CMIntegralDeserialize(Deserializer &deserializer, ScalarFunction &function) { + function.arguments = deserializer.ReadProperty>(100, "arguments"); + auto return_type = deserializer.ReadProperty(101, "return_type"); + function.function = GET_FUNCTION(function.arguments[0], return_type); + return nullptr; +} + +ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { + ScalarFunction result(IntegralCompressFunctionName(result_type), {input_type, input_type}, result_type, + GetIntegralCompressFunctionInputSwitch(input_type, result_type), + CompressedMaterializationFunctions::Bind); + result.serialize = CMIntegralSerialize; + result.deserialize = CMIntegralDeserialize; + return result; +} + +static ScalarFunctionSet GetIntegralCompressFunctionSet(const LogicalType &result_type) { + ScalarFunctionSet set(IntegralCompressFunctionName(result_type)); + for (const auto &input_type : LogicalType::Integral()) { + if (GetTypeIdSize(result_type.InternalType()) < GetTypeIdSize(input_type.InternalType())) { + set.AddFunction(CMIntegralCompressFun::GetFunction(input_type, result_type)); + } + } + return set; +} + +void CMIntegralCompressFun::RegisterFunction(BuiltinFunctions &set) { + for (const auto &result_type : CompressedMaterializationFunctions::IntegralTypes()) { + set.AddFunction(GetIntegralCompressFunctionSet(result_type)); + } +} + +ScalarFunction CMIntegralDecompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { + ScalarFunction result(IntegralDecompressFunctionName(result_type), {input_type, result_type}, result_type, + GetIntegralDecompressFunctionInputSwitch(input_type, result_type), + CompressedMaterializationFunctions::Bind); + result.serialize = CMIntegralSerialize; + result.deserialize = CMIntegralDeserialize; + return result; +} + +static ScalarFunctionSet GetIntegralDecompressFunctionSet(const LogicalType &result_type) { + ScalarFunctionSet set(IntegralDecompressFunctionName(result_type)); + for (const auto &input_type : CompressedMaterializationFunctions::IntegralTypes()) { + if (GetTypeIdSize(result_type.InternalType()) > GetTypeIdSize(input_type.InternalType())) { + set.AddFunction(CMIntegralDecompressFun::GetFunction(input_type, result_type)); + } + } + return set; +} + +void CMIntegralDecompressFun::RegisterFunction(BuiltinFunctions &set) { + for (const auto &result_type : LogicalType::Integral()) { + if (GetTypeIdSize(result_type.InternalType()) > 1) { + set.AddFunction(GetIntegralDecompressFunctionSet(result_type)); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp new file mode 100644 index 00000000..96e14fec --- /dev/null +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp @@ -0,0 +1,250 @@ +#include "duckdb/common/bswap.hpp" +#include "duckdb/function/scalar/compressed_materialization_functions.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +static string StringCompressFunctionName(const LogicalType &result_type) { + return StringUtil::Format("__internal_compress_string_%s", + StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); +} + +template +static inline void TemplatedReverseMemCpy(const data_ptr_t __restrict &dest, const const_data_ptr_t __restrict &src) { + for (idx_t i = 0; i < LENGTH; i++) { + dest[i] = src[LENGTH - 1 - i]; + } +} + +static inline void ReverseMemCpy(const data_ptr_t __restrict &dest, const const_data_ptr_t __restrict &src, + const idx_t &length) { + for (idx_t i = 0; i < length; i++) { + dest[i] = src[length - 1 - i]; + } +} + +template +static inline RESULT_TYPE StringCompressInternal(const string_t &input) { + RESULT_TYPE result; + const auto result_ptr = data_ptr_cast(&result); + if (sizeof(RESULT_TYPE) <= string_t::INLINE_LENGTH) { + TemplatedReverseMemCpy(result_ptr, const_data_ptr_cast(input.GetPrefix())); + } else if (input.IsInlined()) { + static constexpr auto REMAINDER = sizeof(RESULT_TYPE) - string_t::INLINE_LENGTH; + TemplatedReverseMemCpy(result_ptr + REMAINDER, const_data_ptr_cast(input.GetPrefix())); + memset(result_ptr, '\0', REMAINDER); + } else { + const auto remainder = sizeof(RESULT_TYPE) - input.GetSize(); + ReverseMemCpy(result_ptr + remainder, data_ptr_cast(input.GetPointer()), input.GetSize()); + memset(result_ptr, '\0', remainder); + } + result_ptr[0] = input.GetSize(); + return result; +} + +template +static inline RESULT_TYPE StringCompress(const string_t &input) { + D_ASSERT(input.GetSize() < sizeof(RESULT_TYPE)); + return StringCompressInternal(input); +} + +template +static inline RESULT_TYPE MiniStringCompress(const string_t &input) { + if (sizeof(RESULT_TYPE) <= string_t::INLINE_LENGTH) { + return input.GetSize() + *const_data_ptr_cast(input.GetPrefix()); + } else if (input.GetSize() == 0) { + return 0; + } else { + return input.GetSize() + *const_data_ptr_cast(input.GetPointer()); + } +} + +template <> +inline uint8_t StringCompress(const string_t &input) { + D_ASSERT(input.GetSize() <= sizeof(uint8_t)); + return MiniStringCompress(input); +} + +template +static void StringCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, args.size(), StringCompress); +} + +template +static scalar_function_t GetStringCompressFunction(const LogicalType &result_type) { + return StringCompressFunction; +} + +static scalar_function_t GetStringCompressFunctionSwitch(const LogicalType &result_type) { + switch (result_type.id()) { + case LogicalTypeId::UTINYINT: + return GetStringCompressFunction(result_type); + case LogicalTypeId::USMALLINT: + return GetStringCompressFunction(result_type); + case LogicalTypeId::UINTEGER: + return GetStringCompressFunction(result_type); + case LogicalTypeId::UBIGINT: + return GetStringCompressFunction(result_type); + case LogicalTypeId::HUGEINT: + return GetStringCompressFunction(result_type); + default: + throw InternalException("Unexpected type in GetStringCompressFunctionSwitch"); + } +} + +static string StringDecompressFunctionName() { + return "__internal_decompress_string"; +} + +struct StringDecompressLocalState : public FunctionLocalState { +public: + explicit StringDecompressLocalState(ClientContext &context) : allocator(Allocator::Get(context)) { + } + + static unique_ptr Init(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + return make_uniq(state.GetContext()); + } + +public: + ArenaAllocator allocator; +}; + +template +static inline string_t StringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { + const auto input_ptr = const_data_ptr_cast(&input); + string_t result(input_ptr[0]); + if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { + const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); + TemplatedReverseMemCpy(result_ptr, input_ptr); + memset(result_ptr + sizeof(INPUT_TYPE) - 1, '\0', string_t::INLINE_LENGTH - sizeof(INPUT_TYPE) + 1); + } else if (result.GetSize() <= string_t::INLINE_LENGTH) { + static constexpr auto REMAINDER = sizeof(INPUT_TYPE) - string_t::INLINE_LENGTH; + const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); + TemplatedReverseMemCpy(result_ptr, input_ptr + REMAINDER); + } else { + result.SetPointer(char_ptr_cast(allocator.Allocate(sizeof(INPUT_TYPE)))); + TemplatedReverseMemCpy(data_ptr_cast(result.GetPointer()), input_ptr); + memcpy(result.GetPrefixWriteable(), result.GetPointer(), string_t::PREFIX_LENGTH); + } + return result; +} + +template +static inline string_t MiniStringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { + if (input == 0) { + string_t result(uint32_t(0)); + memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); + return result; + } + + string_t result(1); + if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { + memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); + *data_ptr_cast(result.GetPrefixWriteable()) = input - 1; + } else { + result.SetPointer(char_ptr_cast(allocator.Allocate(1))); + *data_ptr_cast(result.GetPointer()) = input - 1; + memset(result.GetPrefixWriteable(), '\0', string_t::PREFIX_LENGTH); + *result.GetPrefixWriteable() = *result.GetPointer(); + } + return result; +} + +template <> +inline string_t StringDecompress(const uint8_t &input, ArenaAllocator &allocator) { + return MiniStringDecompress(input, allocator); +} + +template +static void StringDecompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &allocator = ExecuteFunctionState::GetFunctionState(state)->Cast().allocator; + allocator.Reset(); + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](const INPUT_TYPE &input) { + return StringDecompress(input, allocator); + }); +} + +template +static scalar_function_t GetStringDecompressFunction(const LogicalType &input_type) { + return StringDecompressFunction; +} + +static scalar_function_t GetStringDecompressFunctionSwitch(const LogicalType &input_type) { + switch (input_type.id()) { + case LogicalTypeId::UTINYINT: + return GetStringDecompressFunction(input_type); + case LogicalTypeId::USMALLINT: + return GetStringDecompressFunction(input_type); + case LogicalTypeId::UINTEGER: + return GetStringDecompressFunction(input_type); + case LogicalTypeId::UBIGINT: + return GetStringDecompressFunction(input_type); + case LogicalTypeId::HUGEINT: + return GetStringDecompressFunction(input_type); + default: + throw InternalException("Unexpected type in GetStringDecompressFunctionSwitch"); + } +} + +static void CMStringCompressSerialize(Serializer &serializer, const optional_ptr bind_data, + const ScalarFunction &function) { + serializer.WriteProperty(100, "arguments", function.arguments); + serializer.WriteProperty(101, "return_type", function.return_type); +} + +unique_ptr CMStringCompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { + function.arguments = deserializer.ReadProperty>(100, "arguments"); + auto return_type = deserializer.ReadProperty(101, "return_type"); + function.function = GetStringCompressFunctionSwitch(return_type); + return nullptr; +} + +ScalarFunction CMStringCompressFun::GetFunction(const LogicalType &result_type) { + ScalarFunction result(StringCompressFunctionName(result_type), {LogicalType::VARCHAR}, result_type, + GetStringCompressFunctionSwitch(result_type), CompressedMaterializationFunctions::Bind); + result.serialize = CMStringCompressSerialize; + result.deserialize = CMStringCompressDeserialize; + return result; +} + +void CMStringCompressFun::RegisterFunction(BuiltinFunctions &set) { + for (const auto &result_type : CompressedMaterializationFunctions::StringTypes()) { + set.AddFunction(CMStringCompressFun::GetFunction(result_type)); + } +} + +static void CMStringDecompressSerialize(Serializer &serializer, const optional_ptr bind_data, + const ScalarFunction &function) { + serializer.WriteProperty(100, "arguments", function.arguments); +} + +unique_ptr CMStringDecompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { + function.arguments = deserializer.ReadProperty>(100, "arguments"); + function.function = GetStringDecompressFunctionSwitch(function.arguments[0]); + return nullptr; +} + +ScalarFunction CMStringDecompressFun::GetFunction(const LogicalType &input_type) { + ScalarFunction result(StringDecompressFunctionName(), {input_type}, LogicalType::VARCHAR, + GetStringDecompressFunctionSwitch(input_type), CompressedMaterializationFunctions::Bind, + nullptr, nullptr, StringDecompressLocalState::Init); + result.serialize = CMStringDecompressSerialize; + result.deserialize = CMStringDecompressDeserialize; + return result; +} + +static ScalarFunctionSet GetStringDecompressFunctionSet() { + ScalarFunctionSet set(StringDecompressFunctionName()); + for (const auto &input_type : CompressedMaterializationFunctions::StringTypes()) { + set.AddFunction(CMStringDecompressFun::GetFunction(input_type)); + } + return set; +} + +void CMStringDecompressFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(GetStringDecompressFunctionSet()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/compressed_materialization_functions.cpp b/src/duckdb/src/function/scalar/compressed_materialization_functions.cpp new file mode 100644 index 00000000..456d1bb8 --- /dev/null +++ b/src/duckdb/src/function/scalar/compressed_materialization_functions.cpp @@ -0,0 +1,29 @@ +#include "duckdb/function/scalar/compressed_materialization_functions.hpp" + +namespace duckdb { + +const vector CompressedMaterializationFunctions::IntegralTypes() { + return {LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT}; +} + +const vector CompressedMaterializationFunctions::StringTypes() { + return {LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, + LogicalType::HUGEINT}; +} + +// LCOV_EXCL_START +unique_ptr CompressedMaterializationFunctions::Bind(ClientContext &context, + ScalarFunction &bound_function, + vector> &arguments) { + throw BinderException("Compressed materialization functions are for internal use only!"); +} +// LCOV_EXCL_STOP + +void BuiltinFunctions::RegisterCompressedMaterializationFunctions() { + Register(); + Register(); + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp new file mode 100644 index 00000000..05b0ebce --- /dev/null +++ b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp @@ -0,0 +1,107 @@ +#include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +struct ConstantOrNullBindData : public FunctionData { + explicit ConstantOrNullBindData(Value val) : value(std::move(val)) { + } + + Value value; + +public: + unique_ptr Copy() const override { + return make_uniq(value); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return value == other.value; + } +}; + +static void ConstantOrNullFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + result.Reference(info.value); + for (idx_t idx = 1; idx < args.ColumnCount(); idx++) { + switch (args.data[idx].GetVectorType()) { + case VectorType::FLAT_VECTOR: { + auto &input_mask = FlatVector::Validity(args.data[idx]); + if (!input_mask.AllValid()) { + // there are null values: need to merge them into the result + result.Flatten(args.size()); + auto &result_mask = FlatVector::Validity(result); + result_mask.Combine(input_mask, args.size()); + } + break; + } + case VectorType::CONSTANT_VECTOR: { + if (ConstantVector::IsNull(args.data[idx])) { + // input is constant null, return constant null + result.Reference(info.value); + ConstantVector::SetNull(result, true); + return; + } + break; + } + default: { + UnifiedVectorFormat vdata; + args.data[idx].ToUnifiedFormat(args.size(), vdata); + if (!vdata.validity.AllValid()) { + result.Flatten(args.size()); + auto &result_mask = FlatVector::Validity(result); + for (idx_t i = 0; i < args.size(); i++) { + if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) { + result_mask.SetInvalid(i); + } + } + } + break; + } + } + } +} + +ScalarFunction ConstantOrNull::GetFunction(const LogicalType &return_type) { + return ScalarFunction("constant_or_null", {return_type, LogicalType::ANY}, return_type, ConstantOrNullFunction); +} + +unique_ptr ConstantOrNull::Bind(Value value) { + return make_uniq(std::move(value)); +} + +bool ConstantOrNull::IsConstantOrNull(BoundFunctionExpression &expr, const Value &val) { + if (expr.function.name != "constant_or_null") { + return false; + } + D_ASSERT(expr.bind_info); + auto &bind_data = expr.bind_info->Cast(); + D_ASSERT(bind_data.value.type() == val.type()); + return bind_data.value == val; +} + +unique_ptr ConstantOrNullBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments[0]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[0]->IsFoldable()) { + throw BinderException("ConstantOrNull requires a constant input"); + } + D_ASSERT(arguments.size() >= 2); + auto value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + bound_function.return_type = arguments[0]->return_type; + return make_uniq(std::move(value)); +} + +void ConstantOrNull::RegisterFunction(BuiltinFunctions &set) { + auto fun = ConstantOrNull::GetFunction(LogicalType::ANY); + fun.bind = ConstantOrNullBind; + fun.varargs = LogicalType::ANY; + set.AddFunction(fun); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/generic_functions.cpp b/src/duckdb/src/function/scalar/generic_functions.cpp new file mode 100644 index 00000000..a128aa56 --- /dev/null +++ b/src/duckdb/src/function/scalar/generic_functions.cpp @@ -0,0 +1,10 @@ +#include "duckdb/function/scalar/generic_functions.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterGenericFunctions() { + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp new file mode 100644 index 00000000..064f4c28 --- /dev/null +++ b/src/duckdb/src/function/scalar/list/contains_or_position.cpp @@ -0,0 +1,81 @@ +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" + +namespace duckdb { + +static void ListContainsFunction(DataChunk &args, ExpressionState &state, Vector &result) { + (void)state; + return ListContainsOrPosition(args, result); +} + +static void ListPositionFunction(DataChunk &args, ExpressionState &state, Vector &result) { + (void)state; + return ListContainsOrPosition(args, result); +} + +template +static unique_ptr ListContainsOrPositionBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2); + + const auto &list = arguments[0]->return_type; // change to list + const auto &value = arguments[1]->return_type; + if (list.id() == LogicalTypeId::UNKNOWN) { + bound_function.return_type = RETURN_TYPE; + if (value.id() != LogicalTypeId::UNKNOWN) { + // only list is a parameter, cast it to a list of value type + bound_function.arguments[0] = LogicalType::LIST(value); + bound_function.arguments[1] = value; + } + } else if (value.id() == LogicalTypeId::UNKNOWN) { + // only value is a parameter: we expect the child type of list + auto const &child_type = ListType::GetChildType(list); + bound_function.arguments[0] = list; + bound_function.arguments[1] = child_type; + bound_function.return_type = RETURN_TYPE; + } else { + auto const &child_type = ListType::GetChildType(list); + auto max_child_type = LogicalType::MaxLogicalType(child_type, value); + auto list_type = LogicalType::LIST(max_child_type); + + bound_function.arguments[0] = list_type; + bound_function.arguments[1] = value == max_child_type ? value : max_child_type; + + // list_contains and list_position only differ in their return type + bound_function.return_type = RETURN_TYPE; + } + return make_uniq(bound_function.return_type); +} + +static unique_ptr ListContainsBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return ListContainsOrPositionBind(context, bound_function, arguments); +} + +static unique_ptr ListPositionBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return ListContainsOrPositionBind(context, bound_function, arguments); +} + +ScalarFunction ListContainsFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, // argument list + LogicalType::BOOLEAN, // return type + ListContainsFunction, ListContainsBind, nullptr); +} + +ScalarFunction ListPositionFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, // argument list + LogicalType::INTEGER, // return type + ListPositionFunction, ListPositionBind, nullptr); +} + +void ListContainsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction({"list_contains", "array_contains", "list_has", "array_has"}, GetFunction()); +} + +void ListPositionFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction({"list_position", "list_indexof", "array_position", "array_indexof"}, GetFunction()); +} +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_concat.cpp b/src/duckdb/src/function/scalar/list/list_concat.cpp new file mode 100644 index 00000000..0898f513 --- /dev/null +++ b/src/duckdb/src/function/scalar/list/list_concat.cpp @@ -0,0 +1,133 @@ +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +static void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + auto count = args.size(); + + Vector &lhs = args.data[0]; + Vector &rhs = args.data[1]; + if (lhs.GetType().id() == LogicalTypeId::SQLNULL) { + result.Reference(rhs); + return; + } + if (rhs.GetType().id() == LogicalTypeId::SQLNULL) { + result.Reference(lhs); + return; + } + + UnifiedVectorFormat lhs_data; + UnifiedVectorFormat rhs_data; + lhs.ToUnifiedFormat(count, lhs_data); + rhs.ToUnifiedFormat(count, rhs_data); + auto lhs_entries = UnifiedVectorFormat::GetData(lhs_data); + auto rhs_entries = UnifiedVectorFormat::GetData(rhs_data); + + auto lhs_list_size = ListVector::GetListSize(lhs); + auto rhs_list_size = ListVector::GetListSize(rhs); + auto &lhs_child = ListVector::GetEntry(lhs); + auto &rhs_child = ListVector::GetEntry(rhs); + UnifiedVectorFormat lhs_child_data; + UnifiedVectorFormat rhs_child_data; + lhs_child.ToUnifiedFormat(lhs_list_size, lhs_child_data); + rhs_child.ToUnifiedFormat(rhs_list_size, rhs_child_data); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_entries = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + idx_t offset = 0; + for (idx_t i = 0; i < count; i++) { + auto lhs_list_index = lhs_data.sel->get_index(i); + auto rhs_list_index = rhs_data.sel->get_index(i); + if (!lhs_data.validity.RowIsValid(lhs_list_index) && !rhs_data.validity.RowIsValid(rhs_list_index)) { + result_validity.SetInvalid(i); + continue; + } + result_entries[i].offset = offset; + result_entries[i].length = 0; + if (lhs_data.validity.RowIsValid(lhs_list_index)) { + const auto &lhs_entry = lhs_entries[lhs_list_index]; + result_entries[i].length += lhs_entry.length; + ListVector::Append(result, lhs_child, *lhs_child_data.sel, lhs_entry.offset + lhs_entry.length, + lhs_entry.offset); + } + if (rhs_data.validity.RowIsValid(rhs_list_index)) { + const auto &rhs_entry = rhs_entries[rhs_list_index]; + result_entries[i].length += rhs_entry.length; + ListVector::Append(result, rhs_child, *rhs_child_data.sel, rhs_entry.offset + rhs_entry.length, + rhs_entry.offset); + } + offset += result_entries[i].length; + } + D_ASSERT(ListVector::GetListSize(result) == offset); + + if (lhs.GetVectorType() == VectorType::CONSTANT_VECTOR && rhs.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr ListConcatBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2); + + auto &lhs = arguments[0]->return_type; + auto &rhs = arguments[1]->return_type; + if (lhs.id() == LogicalTypeId::UNKNOWN || rhs.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } else if (lhs.id() == LogicalTypeId::SQLNULL || rhs.id() == LogicalTypeId::SQLNULL) { + // we mimic postgres behaviour: list_concat(NULL, my_list) = my_list + auto return_type = rhs.id() == LogicalTypeId::SQLNULL ? lhs : rhs; + bound_function.arguments[0] = return_type; + bound_function.arguments[1] = return_type; + bound_function.return_type = return_type; + } else { + D_ASSERT(lhs.id() == LogicalTypeId::LIST); + D_ASSERT(rhs.id() == LogicalTypeId::LIST); + + // Resolve list type + LogicalType child_type = LogicalType::SQLNULL; + for (const auto &argument : arguments) { + child_type = LogicalType::MaxLogicalType(child_type, ListType::GetChildType(argument->return_type)); + } + auto list_type = LogicalType::LIST(child_type); + + bound_function.arguments[0] = list_type; + bound_function.arguments[1] = list_type; + bound_function.return_type = list_type; + } + return make_uniq(bound_function.return_type); +} + +static unique_ptr ListConcatStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + D_ASSERT(child_stats.size() == 2); + + auto &left_stats = child_stats[0]; + auto &right_stats = child_stats[1]; + + auto stats = left_stats.ToUnique(); + stats->Merge(right_stats); + + return stats; +} + +ScalarFunction ListConcatFun::GetFunction() { + // the arguments and return types are actually set in the binder function + auto fun = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, + LogicalType::LIST(LogicalType::ANY), ListConcatFunction, ListConcatBind, nullptr, + ListConcatStats); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +void ListConcatFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction({"list_concat", "list_cat", "array_concat", "array_cat"}, GetFunction()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_extract.cpp b/src/duckdb/src/function/scalar/list/list_extract.cpp new file mode 100644 index 00000000..e7d28d52 --- /dev/null +++ b/src/duckdb/src/function/scalar/list/list_extract.cpp @@ -0,0 +1,245 @@ +#include "duckdb/common/pair.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" + +namespace duckdb { + +template +void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVectorFormat &offsets_data, + Vector &child_vector, idx_t list_size, Vector &result) { + UnifiedVectorFormat child_format; + child_vector.ToUnifiedFormat(list_size, child_format); + + T *result_data; + + result.SetVectorType(VectorType::FLAT_VECTOR); + if (!VALIDITY_ONLY) { + result_data = FlatVector::GetData(result); + } + auto &result_mask = FlatVector::Validity(result); + + // heap-ref once + if (HEAP_REF) { + StringVector::AddHeapReference(result, child_vector); + } + + // this is lifted from ExecuteGenericLoop because we can't push the list child data into this otherwise + // should have gone with GetValue perhaps + auto child_data = UnifiedVectorFormat::GetData(child_format); + for (idx_t i = 0; i < count; i++) { + auto list_index = list_data.sel->get_index(i); + auto offsets_index = offsets_data.sel->get_index(i); + if (!list_data.validity.RowIsValid(list_index)) { + result_mask.SetInvalid(i); + continue; + } + if (!offsets_data.validity.RowIsValid(offsets_index)) { + result_mask.SetInvalid(i); + continue; + } + auto list_entry = (UnifiedVectorFormat::GetData(list_data))[list_index]; + auto offsets_entry = (UnifiedVectorFormat::GetData(offsets_data))[offsets_index]; + + // 1-based indexing + if (offsets_entry == 0) { + result_mask.SetInvalid(i); + continue; + } + offsets_entry = (offsets_entry > 0) ? offsets_entry - 1 : offsets_entry; + + idx_t child_offset; + if (offsets_entry < 0) { + if (offsets_entry < -int64_t(list_entry.length)) { + result_mask.SetInvalid(i); + continue; + } + child_offset = list_entry.offset + list_entry.length + offsets_entry; + } else { + if ((idx_t)offsets_entry >= list_entry.length) { + result_mask.SetInvalid(i); + continue; + } + child_offset = list_entry.offset + offsets_entry; + } + auto child_index = child_format.sel->get_index(child_offset); + if (child_format.validity.RowIsValid(child_index)) { + if (!VALIDITY_ONLY) { + result_data[i] = child_data[child_index]; + } + } else { + result_mask.SetInvalid(i); + } + } + if (count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} +static void ExecuteListExtractInternal(const idx_t count, UnifiedVectorFormat &list, UnifiedVectorFormat &offsets, + Vector &child_vector, idx_t list_size, Vector &result) { + D_ASSERT(child_vector.GetType() == result.GetType()); + switch (result.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::INT16: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::INT32: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::INT64: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::INT128: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::UINT8: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::UINT16: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::UINT32: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::UINT64: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::FLOAT: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::DOUBLE: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::VARCHAR: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::INTERVAL: + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + case PhysicalType::STRUCT: { + auto &entries = StructVector::GetEntries(child_vector); + auto &result_entries = StructVector::GetEntries(result); + D_ASSERT(entries.size() == result_entries.size()); + // extract the child entries of the struct + for (idx_t i = 0; i < entries.size(); i++) { + ExecuteListExtractInternal(count, list, offsets, *entries[i], list_size, *result_entries[i]); + } + // extract the validity mask + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + } + case PhysicalType::LIST: { + // nested list: we have to reference the child + auto &child_child_list = ListVector::GetEntry(child_vector); + + ListVector::GetEntry(result).Reference(child_child_list); + ListVector::SetListSize(result, ListVector::GetListSize(child_vector)); + ListExtractTemplate(count, list, offsets, child_vector, list_size, result); + break; + } + default: + throw NotImplementedException("Unimplemented type for LIST_EXTRACT"); + } +} + +static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, const idx_t count) { + D_ASSERT(list.GetType().id() == LogicalTypeId::LIST); + UnifiedVectorFormat list_data; + UnifiedVectorFormat offsets_data; + + list.ToUnifiedFormat(count, list_data); + offsets.ToUnifiedFormat(count, offsets_data); + ExecuteListExtractInternal(count, list_data, offsets_data, ListVector::GetEntry(list), + ListVector::GetListSize(list), result); + result.Verify(count); +} + +static void ExecuteStringExtract(Vector &result, Vector &input_vector, Vector &subscript_vector, const idx_t count) { + BinaryExecutor::Execute( + input_vector, subscript_vector, result, count, [&](string_t input_string, int64_t subscript) { + return SubstringFun::SubstringUnicode(result, input_string, subscript, 1); + }); +} + +static void ListExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + auto count = args.size(); + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + for (idx_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::FLAT_VECTOR); + } + } + + Vector &base = args.data[0]; + Vector &subscript = args.data[1]; + + switch (base.GetType().id()) { + case LogicalTypeId::LIST: + ExecuteListExtract(result, base, subscript, count); + break; + case LogicalTypeId::VARCHAR: + ExecuteStringExtract(result, base, subscript, count); + break; + case LogicalTypeId::SQLNULL: + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + break; + default: + throw NotImplementedException("Specifier type not implemented"); + } +} + +static unique_ptr ListExtractBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2); + D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id()); + // list extract returns the child type of the list as return type + bound_function.return_type = ListType::GetChildType(arguments[0]->return_type); + return make_uniq(bound_function.return_type); +} + +static unique_ptr ListExtractStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &list_child_stats = ListStats::GetChildStats(child_stats[0]); + auto child_copy = list_child_stats.Copy(); + // list_extract always pushes a NULL, since if the offset is out of range for a list it inserts a null + child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + return child_copy.ToUnique(); +} + +void ListExtractFun::RegisterFunction(BuiltinFunctions &set) { + // the arguments and return types are actually set in the binder function + ScalarFunction lfun({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, + ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); + + ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); + + ScalarFunctionSet list_extract("list_extract"); + list_extract.AddFunction(lfun); + list_extract.AddFunction(sfun); + set.AddFunction(list_extract); + + ScalarFunctionSet list_element("list_element"); + list_element.AddFunction(lfun); + list_element.AddFunction(sfun); + set.AddFunction(list_element); + + ScalarFunctionSet array_extract("array_extract"); + array_extract.AddFunction(lfun); + array_extract.AddFunction(sfun); + array_extract.AddFunction(StructExtractFun::GetFunction()); + set.AddFunction(array_extract); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_resize.cpp b/src/duckdb/src/function/scalar/list/list_resize.cpp new file mode 100644 index 00000000..107bb2df --- /dev/null +++ b/src/duckdb/src/function/scalar/list/list_resize.cpp @@ -0,0 +1,162 @@ +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +void ListResizeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.data[1].GetType().id() == LogicalTypeId::UBIGINT); + if (result.GetType().id() == LogicalTypeId::SQLNULL) { + FlatVector::SetNull(result, 0, true); + return; + } + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto count = args.size(); + + result.SetVectorType(VectorType::FLAT_VECTOR); + + auto &lists = args.data[0]; + auto &child = ListVector::GetEntry(args.data[0]); + auto &new_sizes = args.data[1]; + + UnifiedVectorFormat list_data; + lists.ToUnifiedFormat(count, list_data); + auto list_entries = UnifiedVectorFormat::GetData(list_data); + + UnifiedVectorFormat new_size_data; + new_sizes.ToUnifiedFormat(count, new_size_data); + auto new_size_entries = UnifiedVectorFormat::GetData(new_size_data); + + UnifiedVectorFormat child_data; + child.ToUnifiedFormat(count, child_data); + + // Find the new size of the result child vector + idx_t new_child_size = 0; + for (idx_t i = 0; i < count; i++) { + auto index = new_size_data.sel->get_index(i); + if (new_size_data.validity.RowIsValid(index)) { + new_child_size += new_size_entries[index]; + } + } + + // Create the default vector if it exists + UnifiedVectorFormat default_data; + optional_ptr default_vector; + if (args.ColumnCount() == 3) { + default_vector = &args.data[2]; + default_vector->ToUnifiedFormat(count, default_data); + default_vector->SetVectorType(VectorType::CONSTANT_VECTOR); + } + + ListVector::Reserve(result, new_child_size); + ListVector::SetListSize(result, new_child_size); + + auto result_entries = FlatVector::GetData(result); + auto &result_child = ListVector::GetEntry(result); + + // for each lists in the args + idx_t result_child_offset = 0; + for (idx_t args_index = 0; args_index < count; args_index++) { + auto l_index = list_data.sel->get_index(args_index); + auto new_index = new_size_data.sel->get_index(args_index); + + // set null if lists is null + if (!list_data.validity.RowIsValid(l_index)) { + FlatVector::SetNull(result, args_index, true); + continue; + } + + idx_t new_size_entry = 0; + if (new_size_data.validity.RowIsValid(new_index)) { + new_size_entry = new_size_entries[new_index]; + } + + // find the smallest size between lists and new_sizes + auto values_to_copy = MinValue(list_entries[l_index].length, new_size_entry); + + // set the result entry + result_entries[args_index].offset = result_child_offset; + result_entries[args_index].length = new_size_entry; + + // copy the values from the child vector + VectorOperations::Copy(child, result_child, list_entries[l_index].offset + values_to_copy, + list_entries[l_index].offset, result_child_offset); + result_child_offset += values_to_copy; + + // set default value if it exists + idx_t def_index = 0; + if (args.ColumnCount() == 3) { + def_index = default_data.sel->get_index(args_index); + } + + // if the new size is larger than the old size, fill in the default value + if (values_to_copy < new_size_entry) { + if (default_vector && default_data.validity.RowIsValid(def_index)) { + VectorOperations::Copy(*default_vector, result_child, new_size_entry - values_to_copy, def_index, + result_child_offset); + result_child_offset += new_size_entry - values_to_copy; + } else { + for (idx_t j = values_to_copy; j < new_size_entry; j++) { + FlatVector::SetNull(result_child, result_child_offset, true); + result_child_offset++; + } + } + } + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr ListResizeBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2 || arguments.size() == 3); + bound_function.arguments[1] = LogicalType::UBIGINT; + + // first argument is constant NULL + if (arguments[0]->return_type == LogicalType::SQLNULL) { + bound_function.arguments[0] = LogicalType::SQLNULL; + bound_function.return_type = LogicalType::SQLNULL; + return make_uniq(bound_function.return_type); + } + + // prepared statements + if (arguments[0]->return_type == LogicalType::UNKNOWN) { + bound_function.return_type = arguments[0]->return_type; + return make_uniq(bound_function.return_type); + } + + // default type does not match list type + if (bound_function.arguments.size() == 3 && + ListType::GetChildType(arguments[0]->return_type) != arguments[2]->return_type && + arguments[2]->return_type != LogicalTypeId::SQLNULL) { + bound_function.arguments[2] = ListType::GetChildType(arguments[0]->return_type); + } + + bound_function.return_type = arguments[0]->return_type; + return make_uniq(bound_function.return_type); +} + +void ListResizeFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunction sfun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY}, + LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); + sfun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + + ScalarFunction dfun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY, LogicalTypeId::ANY}, + LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); + dfun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + + ScalarFunctionSet list_resize("list_resize"); + list_resize.AddFunction(sfun); + list_resize.AddFunction(dfun); + set.AddFunction(list_resize); + + ScalarFunctionSet array_resize("array_resize"); + array_resize.AddFunction(sfun); + array_resize.AddFunction(dfun); + set.AddFunction(array_resize); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/nested_functions.cpp b/src/duckdb/src/function/scalar/nested_functions.cpp new file mode 100644 index 00000000..fa05ac7d --- /dev/null +++ b/src/duckdb/src/function/scalar/nested_functions.cpp @@ -0,0 +1,14 @@ +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterNestedFunctions() { + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operators.cpp b/src/duckdb/src/function/scalar/operators.cpp new file mode 100644 index 00000000..2862b13c --- /dev/null +++ b/src/duckdb/src/function/scalar/operators.cpp @@ -0,0 +1,14 @@ +#include "duckdb/function/scalar/operators.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterOperators() { + Register(); + Register(); + Register(); + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operators/add.cpp b/src/duckdb/src/function/scalar/operators/add.cpp new file mode 100644 index 00000000..6dcde063 --- /dev/null +++ b/src/duckdb/src/function/scalar/operators/add.cpp @@ -0,0 +1,237 @@ +#include "duckdb/common/operator/add.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/value.hpp" + +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/hugeint.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// + [add] +//===--------------------------------------------------------------------===// +template <> +float AddOperator::Operation(float left, float right) { + auto result = left + right; + return result; +} + +template <> +double AddOperator::Operation(double left, double right) { + auto result = left + right; + return result; +} + +template <> +interval_t AddOperator::Operation(interval_t left, interval_t right) { + left.months = AddOperatorOverflowCheck::Operation(left.months, right.months); + left.days = AddOperatorOverflowCheck::Operation(left.days, right.days); + left.micros = AddOperatorOverflowCheck::Operation(left.micros, right.micros); + return left; +} + +template <> +date_t AddOperator::Operation(date_t left, int32_t right) { + if (!Value::IsFinite(left)) { + return left; + } + int32_t days; + if (!TryAddOperator::Operation(left.days, right, days)) { + throw OutOfRangeException("Date out of range"); + } + date_t result(days); + if (!Value::IsFinite(result)) { + throw OutOfRangeException("Date out of range"); + } + return result; +} + +template <> +date_t AddOperator::Operation(int32_t left, date_t right) { + return AddOperator::Operation(right, left); +} + +template <> +timestamp_t AddOperator::Operation(date_t left, dtime_t right) { + if (left == date_t::infinity()) { + return timestamp_t::infinity(); + } else if (left == date_t::ninfinity()) { + return timestamp_t::ninfinity(); + } + timestamp_t result; + if (!Timestamp::TryFromDatetime(left, right, result)) { + throw OutOfRangeException("Timestamp out of range"); + } + return result; +} + +template <> +timestamp_t AddOperator::Operation(dtime_t left, date_t right) { + return AddOperator::Operation(right, left); +} + +template <> +date_t AddOperator::Operation(date_t left, interval_t right) { + return Interval::Add(left, right); +} + +template <> +date_t AddOperator::Operation(interval_t left, date_t right) { + return AddOperator::Operation(right, left); +} + +template <> +timestamp_t AddOperator::Operation(timestamp_t left, interval_t right) { + return Interval::Add(left, right); +} + +template <> +timestamp_t AddOperator::Operation(interval_t left, timestamp_t right) { + return AddOperator::Operation(right, left); +} + +//===--------------------------------------------------------------------===// +// + [add] with overflow check +//===--------------------------------------------------------------------===// +struct OverflowCheckedAddition { + template + static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { + UTYPE uresult = AddOperator::Operation(UTYPE(left), UTYPE(right)); + if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { + return false; + } + result = SRCTYPE(uresult); + return true; + } +}; + +template <> +bool TryAddOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { + return OverflowCheckedAddition::Operation(left, right, result); +} +template <> +bool TryAddOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { + return OverflowCheckedAddition::Operation(left, right, result); +} +template <> +bool TryAddOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { + return OverflowCheckedAddition::Operation(left, right, result); +} + +template <> +bool TryAddOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { + if (NumericLimits::Maximum() - left < right) { + return false; + } + return OverflowCheckedAddition::Operation(left, right, result); +} + +template <> +bool TryAddOperator::Operation(int8_t left, int8_t right, int8_t &result) { + return OverflowCheckedAddition::Operation(left, right, result); +} + +template <> +bool TryAddOperator::Operation(int16_t left, int16_t right, int16_t &result) { + return OverflowCheckedAddition::Operation(left, right, result); +} + +template <> +bool TryAddOperator::Operation(int32_t left, int32_t right, int32_t &result) { + return OverflowCheckedAddition::Operation(left, right, result); +} + +template <> +bool TryAddOperator::Operation(int64_t left, int64_t right, int64_t &result) { +#if (__GNUC__ >= 5) || defined(__clang__) + if (__builtin_add_overflow(left, right, &result)) { + return false; + } +#else + // https://blog.regehr.org/archives/1139 + result = int64_t((uint64_t)left + (uint64_t)right); + if ((left < 0 && right < 0 && result >= 0) || (left >= 0 && right >= 0 && result < 0)) { + return false; + } +#endif + return true; +} + +template <> +bool TryAddOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { + if (!Hugeint::AddInPlace(left, right)) { + return false; + } + result = left; + return true; +} + +//===--------------------------------------------------------------------===// +// add decimal with overflow check +//===--------------------------------------------------------------------===// +template +bool TryDecimalAddTemplated(T left, T right, T &result) { + if (right < 0) { + if (min - right > left) { + return false; + } + } else { + if (max - right < left) { + return false; + } + } + result = left + right; + return true; +} + +template <> +bool TryDecimalAdd::Operation(int16_t left, int16_t right, int16_t &result) { + return TryDecimalAddTemplated(left, right, result); +} + +template <> +bool TryDecimalAdd::Operation(int32_t left, int32_t right, int32_t &result) { + return TryDecimalAddTemplated(left, right, result); +} + +template <> +bool TryDecimalAdd::Operation(int64_t left, int64_t right, int64_t &result) { + return TryDecimalAddTemplated(left, right, result); +} + +template <> +bool TryDecimalAdd::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { + result = left + right; + if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { + return false; + } + return true; +} + +template <> +hugeint_t DecimalAddOverflowCheck::Operation(hugeint_t left, hugeint_t right) { + hugeint_t result; + if (!TryDecimalAdd::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in addition of DECIMAL(38) (%s + %s);", left.ToString(), right.ToString()); + } + return result; +} + +//===--------------------------------------------------------------------===// +// add time operator +//===--------------------------------------------------------------------===// +template <> +dtime_t AddTimeOperator::Operation(dtime_t left, interval_t right) { + date_t date(0); + return Interval::Add(left, right, date); +} + +template <> +dtime_t AddTimeOperator::Operation(interval_t left, dtime_t right) { + return AddTimeOperator::Operation(right, left); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operators/arithmetic.cpp b/src/duckdb/src/function/scalar/operators/arithmetic.cpp new file mode 100644 index 00000000..f316fb0a --- /dev/null +++ b/src/duckdb/src/function/scalar/operators/arithmetic.cpp @@ -0,0 +1,976 @@ +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/operator/numeric_binary_operators.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/function/scalar/operators.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +#include + +namespace duckdb { + +template +static scalar_function_t GetScalarIntegerFunction(PhysicalType type) { + scalar_function_t function; + switch (type) { + case PhysicalType::INT8: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::INT16: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::INT32: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::INT64: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::INT128: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::UINT8: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::UINT16: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::UINT32: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::UINT64: + function = &ScalarFunction::BinaryFunction; + break; + default: + throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction"); + } + return function; +} + +template +static scalar_function_t GetScalarBinaryFunction(PhysicalType type) { + scalar_function_t function; + switch (type) { + case PhysicalType::INT128: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::FLOAT: + function = &ScalarFunction::BinaryFunction; + break; + case PhysicalType::DOUBLE: + function = &ScalarFunction::BinaryFunction; + break; + default: + function = GetScalarIntegerFunction(type); + break; + } + return function; +} + +//===--------------------------------------------------------------------===// +// + [add] +//===--------------------------------------------------------------------===// +struct AddPropagateStatistics { + template + static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, + Value &new_max) { + T min, max; + // new min is min+min + if (!OP::Operation(NumericStats::GetMin(lstats), NumericStats::GetMin(rstats), min)) { + return true; + } + // new max is max+max + if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMax(rstats), max)) { + return true; + } + new_min = Value::Numeric(type, min); + new_max = Value::Numeric(type, max); + return false; + } +}; + +struct SubtractPropagateStatistics { + template + static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, + Value &new_max) { + T min, max; + if (!OP::Operation(NumericStats::GetMin(lstats), NumericStats::GetMax(rstats), min)) { + return true; + } + if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMin(rstats), max)) { + return true; + } + new_min = Value::Numeric(type, min); + new_max = Value::Numeric(type, max); + return false; + } +}; + +struct DecimalArithmeticBindData : public FunctionData { + DecimalArithmeticBindData() : check_overflow(true) { + } + + unique_ptr Copy() const override { + auto res = make_uniq(); + res->check_overflow = check_overflow; + return std::move(res); + } + + bool Equals(const FunctionData &other_p) const override { + auto other = other_p.Cast(); + return other.check_overflow == check_overflow; + } + + bool check_overflow; +}; + +template +static unique_ptr PropagateNumericStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 2); + // can only propagate stats if the children have stats + auto &lstats = child_stats[0]; + auto &rstats = child_stats[1]; + Value new_min, new_max; + bool potential_overflow = true; + if (NumericStats::HasMinMax(lstats) && NumericStats::HasMinMax(rstats)) { + switch (expr.return_type.InternalType()) { + case PhysicalType::INT8: + potential_overflow = + PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); + break; + case PhysicalType::INT16: + potential_overflow = + PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); + break; + case PhysicalType::INT32: + potential_overflow = + PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); + break; + case PhysicalType::INT64: + potential_overflow = + PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); + break; + default: + return nullptr; + } + } + if (potential_overflow) { + new_min = Value(expr.return_type); + new_max = Value(expr.return_type); + } else { + // no potential overflow: replace with non-overflowing operator + if (input.bind_data) { + auto &bind_data = input.bind_data->Cast(); + bind_data.check_overflow = false; + } + expr.function.function = GetScalarIntegerFunction(expr.return_type.InternalType()); + } + auto result = NumericStats::CreateEmpty(expr.return_type); + NumericStats::SetMin(result, new_min); + NumericStats::SetMax(result, new_max); + result.CombineValidity(lstats, rstats); + return result.ToUnique(); +} + +template +unique_ptr BindDecimalAddSubtract(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + auto bind_data = make_uniq(); + + // get the max width and scale of the input arguments + uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; + for (idx_t i = 0; i < arguments.size(); i++) { + if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { + continue; + } + uint8_t width, scale; + auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); + if (!can_convert) { + throw InternalException("Could not convert type %s to a decimal.", arguments[i]->return_type.ToString()); + } + max_width = MaxValue(width, max_width); + max_scale = MaxValue(scale, max_scale); + max_width_over_scale = MaxValue(width - scale, max_width_over_scale); + } + D_ASSERT(max_width > 0); + // for addition/subtraction, we add 1 to the width to ensure we don't overflow + auto required_width = MaxValue(max_scale + max_width_over_scale, max_width) + 1; + if (required_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64) { + // we don't automatically promote past the hugeint boundary to avoid the large hugeint performance penalty + bind_data->check_overflow = true; + required_width = Decimal::MAX_WIDTH_INT64; + } + if (required_width > Decimal::MAX_WIDTH_DECIMAL) { + // target width does not fit in decimal at all: truncate the scale and perform overflow detection + bind_data->check_overflow = true; + required_width = Decimal::MAX_WIDTH_DECIMAL; + } + // arithmetic between two decimal arguments: check the types of the input arguments + LogicalType result_type = LogicalType::DECIMAL(required_width, max_scale); + // we cast all input types to the specified type + for (idx_t i = 0; i < arguments.size(); i++) { + // first check if the cast is necessary + // if the argument has a matching scale and internal type as the output type, no casting is necessary + auto &argument_type = arguments[i]->return_type; + uint8_t width, scale; + argument_type.GetDecimalProperties(width, scale); + if (scale == DecimalType::GetScale(result_type) && argument_type.InternalType() == result_type.InternalType()) { + bound_function.arguments[i] = argument_type; + } else { + bound_function.arguments[i] = result_type; + } + } + bound_function.return_type = result_type; + // now select the physical function to execute + if (bind_data->check_overflow) { + bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + } else { + bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + } + if (result_type.InternalType() != PhysicalType::INT128) { + if (IS_SUBTRACT) { + bound_function.statistics = + PropagateNumericStats; + } else { + bound_function.statistics = PropagateNumericStats; + } + } + return std::move(bind_data); +} + +static void SerializeDecimalArithmetic(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "check_overflow", bind_data.check_overflow); + serializer.WriteProperty(101, "return_type", function.return_type); + serializer.WriteProperty(102, "arguments", function.arguments); +} + +// TODO this is partially duplicated from the bind +template +unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer, ScalarFunction &bound_function) { + + // // re-change the function pointers + auto check_overflow = deserializer.ReadProperty(100, "check_overflow"); + auto return_type = deserializer.ReadProperty(101, "return_type"); + auto arguments = deserializer.ReadProperty>(102, "arguments"); + if (check_overflow) { + bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); + } else { + bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); + } + bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again + bound_function.return_type = return_type; + bound_function.arguments = arguments; + + auto bind_data = make_uniq(); + bind_data->check_overflow = check_overflow; + return std::move(bind_data); +} + +unique_ptr NopDecimalBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + bound_function.return_type = arguments[0]->return_type; + bound_function.arguments[0] = arguments[0]->return_type; + return nullptr; +} + +ScalarFunction AddFun::GetFunction(const LogicalType &type) { + D_ASSERT(type.IsNumeric()); + if (type.id() == LogicalTypeId::DECIMAL) { + return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction, NopDecimalBind); + } else { + return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction); + } +} + +ScalarFunction AddFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { + if (left_type.IsNumeric() && left_type.id() == right_type.id()) { + if (left_type.id() == LogicalTypeId::DECIMAL) { + auto function = ScalarFunction("+", {left_type, right_type}, left_type, nullptr, + BindDecimalAddSubtract); + function.serialize = SerializeDecimalArithmetic; + function.deserialize = DeserializeDecimalArithmetic; + return function; + } else if (left_type.IsIntegral()) { + return ScalarFunction("+", {left_type, right_type}, left_type, + GetScalarIntegerFunction(left_type.InternalType()), nullptr, + nullptr, PropagateNumericStats); + } else { + return ScalarFunction("+", {left_type, right_type}, left_type, + GetScalarBinaryFunction(left_type.InternalType())); + } + } + + switch (left_type.id()) { + case LogicalTypeId::DATE: + if (right_type.id() == LogicalTypeId::INTEGER) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::TIME) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, + ScalarFunction::BinaryFunction); + } + break; + case LogicalTypeId::INTEGER: + if (right_type.id() == LogicalTypeId::DATE) { + return ScalarFunction("+", {left_type, right_type}, right_type, + ScalarFunction::BinaryFunction); + } + break; + case LogicalTypeId::INTERVAL: + if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::INTERVAL, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::DATE) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::TIME) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::TIME, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::TIMESTAMP) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, + ScalarFunction::BinaryFunction); + } + break; + case LogicalTypeId::TIME: + if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::TIME, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::DATE) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, + ScalarFunction::BinaryFunction); + } + break; + case LogicalTypeId::TIMESTAMP: + if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP, + ScalarFunction::BinaryFunction); + } + break; + default: + break; + } + // LCOV_EXCL_START + throw NotImplementedException("AddFun for types %s, %s", EnumUtil::ToString(left_type.id()), + EnumUtil::ToString(right_type.id())); + // LCOV_EXCL_STOP +} + +void AddFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunctionSet functions("+"); + for (auto &type : LogicalType::Numeric()) { + // unary add function is a nop, but only exists for numeric types + functions.AddFunction(GetFunction(type)); + // binary add function adds two numbers together + functions.AddFunction(GetFunction(type, type)); + } + // we can add integers to dates + functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTEGER)); + functions.AddFunction(GetFunction(LogicalType::INTEGER, LogicalType::DATE)); + // we can add intervals together + functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::INTERVAL)); + // we can add intervals to dates/times/timestamps + functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTERVAL)); + functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::DATE)); + + functions.AddFunction(GetFunction(LogicalType::TIME, LogicalType::INTERVAL)); + functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::TIME)); + + functions.AddFunction(GetFunction(LogicalType::TIMESTAMP, LogicalType::INTERVAL)); + functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::TIMESTAMP)); + + // we can add times to dates + functions.AddFunction(GetFunction(LogicalType::TIME, LogicalType::DATE)); + functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::TIME)); + + // we can add lists together + functions.AddFunction(ListConcatFun::GetFunction()); + + set.AddFunction(functions); + + functions.name = "add"; + set.AddFunction(functions); +} + +//===--------------------------------------------------------------------===// +// - [subtract] +//===--------------------------------------------------------------------===// +struct NegateOperator { + template + static bool CanNegate(T input) { + using Limits = std::numeric_limits; + return !(Limits::is_integer && Limits::is_signed && Limits::lowest() == input); + } + + template + static inline TR Operation(TA input) { + auto cast = (TR)input; + if (!CanNegate(cast)) { + throw OutOfRangeException("Overflow in negation of integer!"); + } + return -cast; + } +}; + +template <> +bool NegateOperator::CanNegate(float input) { + return true; +} + +template <> +bool NegateOperator::CanNegate(double input) { + return true; +} + +template <> +interval_t NegateOperator::Operation(interval_t input) { + interval_t result; + result.months = NegateOperator::Operation(input.months); + result.days = NegateOperator::Operation(input.days); + result.micros = NegateOperator::Operation(input.micros); + return result; +} + +struct DecimalNegateBindData : public FunctionData { + DecimalNegateBindData() : bound_type(LogicalTypeId::INVALID) { + } + + unique_ptr Copy() const override { + auto res = make_uniq(); + res->bound_type = bound_type; + return std::move(res); + } + + bool Equals(const FunctionData &other_p) const override { + auto other = other_p.Cast(); + return other.bound_type == bound_type; + } + + LogicalTypeId bound_type; +}; + +unique_ptr DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + auto bind_data = make_uniq(); + + auto &decimal_type = arguments[0]->return_type; + auto width = DecimalType::GetWidth(decimal_type); + if (width <= Decimal::MAX_WIDTH_INT16) { + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); + } else if (width <= Decimal::MAX_WIDTH_INT32) { + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); + } else if (width <= Decimal::MAX_WIDTH_INT64) { + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); + } else { + D_ASSERT(width <= Decimal::MAX_WIDTH_INT128); + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); + } + decimal_type.Verify(); + bound_function.arguments[0] = decimal_type; + bound_function.return_type = decimal_type; + return nullptr; +} + +struct NegatePropagateStatistics { + template + static bool Operation(LogicalType type, BaseStatistics &istats, Value &new_min, Value &new_max) { + auto max_value = NumericStats::GetMax(istats); + auto min_value = NumericStats::GetMin(istats); + if (!NegateOperator::CanNegate(min_value) || !NegateOperator::CanNegate(max_value)) { + return true; + } + // new min is -max + new_min = Value::Numeric(type, NegateOperator::Operation(max_value)); + // new max is -min + new_max = Value::Numeric(type, NegateOperator::Operation(min_value)); + return false; + } +}; + +static unique_ptr NegateBindStatistics(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 1); + // can only propagate stats if the children have stats + auto &istats = child_stats[0]; + Value new_min, new_max; + bool potential_overflow = true; + if (NumericStats::HasMinMax(istats)) { + switch (expr.return_type.InternalType()) { + case PhysicalType::INT8: + potential_overflow = + NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); + break; + case PhysicalType::INT16: + potential_overflow = + NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); + break; + case PhysicalType::INT32: + potential_overflow = + NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); + break; + case PhysicalType::INT64: + potential_overflow = + NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); + break; + default: + return nullptr; + } + } + if (potential_overflow) { + new_min = Value(expr.return_type); + new_max = Value(expr.return_type); + } + auto stats = NumericStats::CreateEmpty(expr.return_type); + NumericStats::SetMin(stats, new_min); + NumericStats::SetMax(stats, new_max); + stats.CopyValidity(istats); + return stats.ToUnique(); +} + +ScalarFunction SubtractFun::GetFunction(const LogicalType &type) { + if (type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("-", {type}, type, ScalarFunction::UnaryFunction); + } else if (type.id() == LogicalTypeId::DECIMAL) { + return ScalarFunction("-", {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); + } else { + D_ASSERT(type.IsNumeric()); + return ScalarFunction("-", {type}, type, ScalarFunction::GetScalarUnaryFunction(type), nullptr, + nullptr, NegateBindStatistics); + } +} + +ScalarFunction SubtractFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { + if (left_type.IsNumeric() && left_type.id() == right_type.id()) { + if (left_type.id() == LogicalTypeId::DECIMAL) { + auto function = + ScalarFunction("-", {left_type, right_type}, left_type, nullptr, + BindDecimalAddSubtract); + function.serialize = SerializeDecimalArithmetic; + function.deserialize = DeserializeDecimalArithmetic; + return function; + } else if (left_type.IsIntegral()) { + return ScalarFunction( + "-", {left_type, right_type}, left_type, + GetScalarIntegerFunction(left_type.InternalType()), nullptr, nullptr, + PropagateNumericStats); + + } else { + return ScalarFunction("-", {left_type, right_type}, left_type, + GetScalarBinaryFunction(left_type.InternalType())); + } + } + + switch (left_type.id()) { + case LogicalTypeId::DATE: + if (right_type.id() == LogicalTypeId::DATE) { + return ScalarFunction("-", {left_type, right_type}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction); + + } else if (right_type.id() == LogicalTypeId::INTEGER) { + return ScalarFunction("-", {left_type, right_type}, LogicalType::DATE, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("-", {left_type, right_type}, LogicalType::DATE, + ScalarFunction::BinaryFunction); + } + break; + case LogicalTypeId::TIMESTAMP: + if (right_type.id() == LogicalTypeId::TIMESTAMP) { + return ScalarFunction( + "-", {left_type, right_type}, LogicalType::INTERVAL, + ScalarFunction::BinaryFunction); + } else if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction( + "-", {left_type, right_type}, LogicalType::TIMESTAMP, + ScalarFunction::BinaryFunction); + } + break; + case LogicalTypeId::INTERVAL: + if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("-", {left_type, right_type}, LogicalType::INTERVAL, + ScalarFunction::BinaryFunction); + } + break; + case LogicalTypeId::TIME: + if (right_type.id() == LogicalTypeId::INTERVAL) { + return ScalarFunction("-", {left_type, right_type}, LogicalType::TIME, + ScalarFunction::BinaryFunction); + } + break; + default: + break; + } + // LCOV_EXCL_START + throw NotImplementedException("SubtractFun for types %s, %s", EnumUtil::ToString(left_type.id()), + EnumUtil::ToString(right_type.id())); + // LCOV_EXCL_STOP +} + +void SubtractFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunctionSet functions("-"); + for (auto &type : LogicalType::Numeric()) { + // unary subtract function, negates the input (i.e. multiplies by -1) + functions.AddFunction(GetFunction(type)); + // binary subtract function "a - b", subtracts b from a + functions.AddFunction(GetFunction(type, type)); + } + // we can subtract dates from each other + functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::DATE)); + // we can subtract integers from dates + functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTEGER)); + // we can subtract timestamps from each other + functions.AddFunction(GetFunction(LogicalType::TIMESTAMP, LogicalType::TIMESTAMP)); + // we can subtract intervals from each other + functions.AddFunction(GetFunction(LogicalType::INTERVAL, LogicalType::INTERVAL)); + // we can subtract intervals from dates/times/timestamps, but not the other way around + functions.AddFunction(GetFunction(LogicalType::DATE, LogicalType::INTERVAL)); + functions.AddFunction(GetFunction(LogicalType::TIME, LogicalType::INTERVAL)); + functions.AddFunction(GetFunction(LogicalType::TIMESTAMP, LogicalType::INTERVAL)); + // we can negate intervals + functions.AddFunction(GetFunction(LogicalType::INTERVAL)); + set.AddFunction(functions); + + functions.name = "subtract"; + set.AddFunction(functions); +} + +//===--------------------------------------------------------------------===// +// * [multiply] +//===--------------------------------------------------------------------===// +struct MultiplyPropagateStatistics { + template + static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, + Value &new_max) { + // statistics propagation on the multiplication is slightly less straightforward because of negative numbers + // the new min/max depend on the signs of the input types + // if both are positive the result is [lmin * rmin][lmax * rmax] + // if lmin/lmax are negative the result is [lmin * rmax][lmax * rmin] + // etc + // rather than doing all this switcheroo we just multiply all combinations of lmin/lmax with rmin/rmax + // and check what the minimum/maximum value is + T lvals[] {NumericStats::GetMin(lstats), NumericStats::GetMax(lstats)}; + T rvals[] {NumericStats::GetMin(rstats), NumericStats::GetMax(rstats)}; + T min = NumericLimits::Maximum(); + T max = NumericLimits::Minimum(); + // multiplications + for (idx_t l = 0; l < 2; l++) { + for (idx_t r = 0; r < 2; r++) { + T result; + if (!OP::Operation(lvals[l], rvals[r], result)) { + // potential overflow + return true; + } + if (result < min) { + min = result; + } + if (result > max) { + max = result; + } + } + } + new_min = Value::Numeric(type, min); + new_max = Value::Numeric(type, max); + return false; + } +}; + +unique_ptr BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + auto bind_data = make_uniq(); + + uint8_t result_width = 0, result_scale = 0; + uint8_t max_width = 0; + for (idx_t i = 0; i < arguments.size(); i++) { + if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { + continue; + } + uint8_t width, scale; + auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); + if (!can_convert) { + throw InternalException("Could not convert type %s to a decimal?", arguments[i]->return_type.ToString()); + } + if (width > max_width) { + max_width = width; + } + result_width += width; + result_scale += scale; + } + D_ASSERT(max_width > 0); + if (result_scale > Decimal::MAX_WIDTH_DECIMAL) { + throw OutOfRangeException( + "Needed scale %d to accurately represent the multiplication result, but this is out of range of the " + "DECIMAL type. Max scale is %d; could not perform an accurate multiplication. Either add a cast to DOUBLE, " + "or add an explicit cast to a decimal with a lower scale.", + result_scale, Decimal::MAX_WIDTH_DECIMAL); + } + if (result_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64 && + result_scale < Decimal::MAX_WIDTH_INT64) { + bind_data->check_overflow = true; + result_width = Decimal::MAX_WIDTH_INT64; + } + if (result_width > Decimal::MAX_WIDTH_DECIMAL) { + bind_data->check_overflow = true; + result_width = Decimal::MAX_WIDTH_DECIMAL; + } + LogicalType result_type = LogicalType::DECIMAL(result_width, result_scale); + // since our scale is the summation of our input scales, we do not need to cast to the result scale + // however, we might need to cast to the correct internal type + for (idx_t i = 0; i < arguments.size(); i++) { + auto &argument_type = arguments[i]->return_type; + if (argument_type.InternalType() == result_type.InternalType()) { + bound_function.arguments[i] = argument_type; + } else { + uint8_t width, scale; + if (!argument_type.GetDecimalProperties(width, scale)) { + scale = 0; + } + + bound_function.arguments[i] = LogicalType::DECIMAL(result_width, scale); + } + } + result_type.Verify(); + bound_function.return_type = result_type; + // now select the physical function to execute + if (bind_data->check_overflow) { + bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + } else { + bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + } + if (result_type.InternalType() != PhysicalType::INT128) { + bound_function.statistics = + PropagateNumericStats; + } + return std::move(bind_data); +} + +void MultiplyFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunctionSet functions("*"); + for (auto &type : LogicalType::Numeric()) { + if (type.id() == LogicalTypeId::DECIMAL) { + ScalarFunction function({type, type}, type, nullptr, BindDecimalMultiply); + function.serialize = SerializeDecimalArithmetic; + function.deserialize = DeserializeDecimalArithmetic; + functions.AddFunction(function); + } else if (TypeIsIntegral(type.InternalType())) { + functions.AddFunction(ScalarFunction( + {type, type}, type, GetScalarIntegerFunction(type.InternalType()), + nullptr, nullptr, + PropagateNumericStats)); + } else { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarBinaryFunction(type.InternalType()))); + } + } + functions.AddFunction( + ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::BinaryFunction)); + functions.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::INTERVAL}, LogicalType::INTERVAL, + ScalarFunction::BinaryFunction)); + set.AddFunction(functions); + + functions.name = "multiply"; + set.AddFunction(functions); +} + +//===--------------------------------------------------------------------===// +// / [divide] +//===--------------------------------------------------------------------===// +template <> +float DivideOperator::Operation(float left, float right) { + auto result = left / right; + return result; +} + +template <> +double DivideOperator::Operation(double left, double right) { + auto result = left / right; + return result; +} + +template <> +hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right) { + if (right.lower == 0 && right.upper == 0) { + throw InternalException("Hugeint division by zero!"); + } + return left / right; +} + +template <> +interval_t DivideOperator::Operation(interval_t left, int64_t right) { + left.days /= right; + left.months /= right; + left.micros /= right; + return left; +} + +struct BinaryNumericDivideWrapper { + template + static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { + if (left == NumericLimits::Minimum() && right == -1) { + throw OutOfRangeException("Overflow in division of %d / %d", left, right); + } else if (right == 0) { + mask.SetInvalid(idx); + return left; + } else { + return OP::template Operation(left, right); + } + } + + static bool AddsNulls() { + return true; + } +}; + +struct BinaryZeroIsNullWrapper { + template + static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { + if (right == 0) { + mask.SetInvalid(idx); + return left; + } else { + return OP::template Operation(left, right); + } + } + + static bool AddsNulls() { + return true; + } +}; + +struct BinaryZeroIsNullHugeintWrapper { + template + static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { + if (right.upper == 0 && right.lower == 0) { + mask.SetInvalid(idx); + return left; + } else { + return OP::template Operation(left, right); + } + } + + static bool AddsNulls() { + return true; + } +}; + +template +static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size()); +} + +template +static scalar_function_t GetBinaryFunctionIgnoreZero(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::SMALLINT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::INTEGER: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::BIGINT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::UTINYINT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::USMALLINT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::UINTEGER: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::UBIGINT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::HUGEINT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::FLOAT: + return BinaryScalarFunctionIgnoreZero; + case LogicalTypeId::DOUBLE: + return BinaryScalarFunctionIgnoreZero; + default: + throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction"); + } +} + +void DivideFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunctionSet fp_divide("/"); + fp_divide.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, + GetBinaryFunctionIgnoreZero(LogicalType::FLOAT))); + fp_divide.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + GetBinaryFunctionIgnoreZero(LogicalType::DOUBLE))); + fp_divide.AddFunction( + ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, + BinaryScalarFunctionIgnoreZero)); + set.AddFunction(fp_divide); + + ScalarFunctionSet full_divide("//"); + for (auto &type : LogicalType::Numeric()) { + if (type.id() == LogicalTypeId::DECIMAL) { + continue; + } else { + full_divide.AddFunction( + ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero(type))); + } + } + set.AddFunction(full_divide); + + full_divide.name = "divide"; + set.AddFunction(full_divide); +} + +//===--------------------------------------------------------------------===// +// % [modulo] +//===--------------------------------------------------------------------===// +template <> +float ModuloOperator::Operation(float left, float right) { + D_ASSERT(right != 0); + auto result = std::fmod(left, right); + return result; +} + +template <> +double ModuloOperator::Operation(double left, double right) { + D_ASSERT(right != 0); + auto result = std::fmod(left, right); + return result; +} + +template <> +hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right) { + if (right.lower == 0 && right.upper == 0) { + throw InternalException("Hugeint division by zero!"); + } + return left % right; +} + +void ModFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunctionSet functions("%"); + for (auto &type : LogicalType::Numeric()) { + if (type.id() == LogicalTypeId::DECIMAL) { + continue; + } else { + functions.AddFunction( + ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero(type))); + } + } + set.AddFunction(functions); + functions.name = "mod"; + set.AddFunction(functions); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operators/multiply.cpp b/src/duckdb/src/function/scalar/operators/multiply.cpp new file mode 100644 index 00000000..718f9a8e --- /dev/null +++ b/src/duckdb/src/function/scalar/operators/multiply.cpp @@ -0,0 +1,232 @@ +#include "duckdb/common/operator/multiply.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/windows_undefs.hpp" + +#include +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// * [multiply] +//===--------------------------------------------------------------------===// +template <> +float MultiplyOperator::Operation(float left, float right) { + auto result = left * right; + return result; +} + +template <> +double MultiplyOperator::Operation(double left, double right) { + auto result = left * right; + return result; +} + +template <> +interval_t MultiplyOperator::Operation(interval_t left, int64_t right) { + left.months = MultiplyOperatorOverflowCheck::Operation(left.months, right); + left.days = MultiplyOperatorOverflowCheck::Operation(left.days, right); + left.micros = MultiplyOperatorOverflowCheck::Operation(left.micros, right); + return left; +} + +template <> +interval_t MultiplyOperator::Operation(int64_t left, interval_t right) { + return MultiplyOperator::Operation(right, left); +} + +//===--------------------------------------------------------------------===// +// * [multiply] with overflow check +//===--------------------------------------------------------------------===// +struct OverflowCheckedMultiply { + template + static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { + UTYPE uresult = MultiplyOperator::Operation(UTYPE(left), UTYPE(right)); + if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { + return false; + } + result = SRCTYPE(uresult); + return true; + } +}; + +template <> +bool TryMultiplyOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { + return OverflowCheckedMultiply::Operation(left, right, result); +} +template <> +bool TryMultiplyOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { + return OverflowCheckedMultiply::Operation(left, right, result); +} +template <> +bool TryMultiplyOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { + return OverflowCheckedMultiply::Operation(left, right, result); +} +template <> +bool TryMultiplyOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { + if (left > right) { + std::swap(left, right); + } + if (left > NumericLimits::Maximum()) { + return false; + } + uint32_t c = right >> 32; + uint32_t d = NumericLimits::Maximum() & right; + uint64_t r = left * c; + uint64_t s = left * d; + if (r > NumericLimits::Maximum()) { + return false; + } + r <<= 32; + if (NumericLimits::Maximum() - s < r) { + return false; + } + return OverflowCheckedMultiply::Operation(left, right, result); +} + +template <> +bool TryMultiplyOperator::Operation(int8_t left, int8_t right, int8_t &result) { + return OverflowCheckedMultiply::Operation(left, right, result); +} + +template <> +bool TryMultiplyOperator::Operation(int16_t left, int16_t right, int16_t &result) { + return OverflowCheckedMultiply::Operation(left, right, result); +} + +template <> +bool TryMultiplyOperator::Operation(int32_t left, int32_t right, int32_t &result) { + return OverflowCheckedMultiply::Operation(left, right, result); +} + +template <> +bool TryMultiplyOperator::Operation(int64_t left, int64_t right, int64_t &result) { +#if (__GNUC__ >= 5) || defined(__clang__) + if (__builtin_mul_overflow(left, right, &result)) { + return false; + } +#else + if (left == std::numeric_limits::min()) { + if (right == 0) { + result = 0; + return true; + } + if (right == 1) { + result = left; + return true; + } + return false; + } + if (right == std::numeric_limits::min()) { + if (left == 0) { + result = 0; + return true; + } + if (left == 1) { + result = right; + return true; + } + return false; + } + uint64_t left_non_negative = uint64_t(std::abs(left)); + uint64_t right_non_negative = uint64_t(std::abs(right)); + // split values into 2 32-bit parts + uint64_t left_high_bits = left_non_negative >> 32; + uint64_t left_low_bits = left_non_negative & 0xffffffff; + uint64_t right_high_bits = right_non_negative >> 32; + uint64_t right_low_bits = right_non_negative & 0xffffffff; + + // check the high bits of both + // the high bits define the overflow + if (left_high_bits == 0) { + if (right_high_bits != 0) { + // only the right has high bits set + // multiply the high bits of right with the low bits of left + // multiply the low bits, and carry any overflow to the high bits + // then check for any overflow + auto low_low = left_low_bits * right_low_bits; + auto low_high = left_low_bits * right_high_bits; + auto high_bits = low_high + (low_low >> 32); + if (high_bits & 0xffffff80000000) { + // there is! abort + return false; + } + } + } else if (right_high_bits == 0) { + // only the left has high bits set + // multiply the high bits of left with the low bits of right + // multiply the low bits, and carry any overflow to the high bits + // then check for any overflow + auto low_low = left_low_bits * right_low_bits; + auto high_low = left_high_bits * right_low_bits; + auto high_bits = high_low + (low_low >> 32); + if (high_bits & 0xffffff80000000) { + // there is! abort + return false; + } + } else { + // both left and right have high bits set: guaranteed overflow + // abort! + return false; + } + // now we know that there is no overflow, we can just perform the multiplication + result = left * right; +#endif + return true; +} + +template <> +bool TryMultiplyOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { + return Hugeint::TryMultiply(left, right, result); +} + +//===--------------------------------------------------------------------===// +// multiply decimal with overflow check +//===--------------------------------------------------------------------===// +template +bool TryDecimalMultiplyTemplated(T left, T right, T &result) { + if (!TryMultiplyOperator::Operation(left, right, result) || result < min || result > max) { + return false; + } + return true; +} + +template <> +bool TryDecimalMultiply::Operation(int16_t left, int16_t right, int16_t &result) { + return TryDecimalMultiplyTemplated(left, right, result); +} + +template <> +bool TryDecimalMultiply::Operation(int32_t left, int32_t right, int32_t &result) { + return TryDecimalMultiplyTemplated(left, right, result); +} + +template <> +bool TryDecimalMultiply::Operation(int64_t left, int64_t right, int64_t &result) { + return TryDecimalMultiplyTemplated(left, right, result); +} + +template <> +bool TryDecimalMultiply::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { + result = left * right; + if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { + return false; + } + return true; +} + +template <> +hugeint_t DecimalMultiplyOverflowCheck::Operation(hugeint_t left, hugeint_t right) { + hugeint_t result; + if (!TryDecimalMultiply::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in multiplication of DECIMAL(38) (%s * %s). You might want to add an " + "explicit cast to a decimal with a smaller scale.", + left.ToString(), right.ToString()); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operators/subtract.cpp b/src/duckdb/src/function/scalar/operators/subtract.cpp new file mode 100644 index 00000000..c9918a38 --- /dev/null +++ b/src/duckdb/src/function/scalar/operators/subtract.cpp @@ -0,0 +1,222 @@ +#include "duckdb/common/operator/subtract.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// - [subtract] +//===--------------------------------------------------------------------===// +template <> +float SubtractOperator::Operation(float left, float right) { + auto result = left - right; + return result; +} + +template <> +double SubtractOperator::Operation(double left, double right) { + auto result = left - right; + return result; +} + +template <> +int64_t SubtractOperator::Operation(date_t left, date_t right) { + return int64_t(left.days) - int64_t(right.days); +} + +template <> +date_t SubtractOperator::Operation(date_t left, int32_t right) { + if (!Date::IsFinite(left)) { + return left; + } + int32_t days; + if (!TrySubtractOperator::Operation(left.days, right, days)) { + throw OutOfRangeException("Date out of range"); + } + + date_t result(days); + if (!Date::IsFinite(result)) { + throw OutOfRangeException("Date out of range"); + } + return result; +} + +template <> +interval_t SubtractOperator::Operation(interval_t left, interval_t right) { + interval_t result; + result.months = left.months - right.months; + result.days = left.days - right.days; + result.micros = left.micros - right.micros; + return result; +} + +template <> +date_t SubtractOperator::Operation(date_t left, interval_t right) { + return AddOperator::Operation(left, Interval::Invert(right)); +} + +template <> +timestamp_t SubtractOperator::Operation(timestamp_t left, interval_t right) { + return AddOperator::Operation(left, Interval::Invert(right)); +} + +template <> +interval_t SubtractOperator::Operation(timestamp_t left, timestamp_t right) { + return Interval::GetDifference(left, right); +} + +//===--------------------------------------------------------------------===// +// - [subtract] with overflow check +//===--------------------------------------------------------------------===// +struct OverflowCheckedSubtract { + template + static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { + UTYPE uresult = SubtractOperator::Operation(UTYPE(left), UTYPE(right)); + if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { + return false; + } + result = SRCTYPE(uresult); + return true; + } +}; + +template <> +bool TrySubtractOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { + if (right > left) { + return false; + } + return OverflowCheckedSubtract::Operation(left, right, result); +} + +template <> +bool TrySubtractOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { + if (right > left) { + return false; + } + return OverflowCheckedSubtract::Operation(left, right, result); +} + +template <> +bool TrySubtractOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { + if (right > left) { + return false; + } + return OverflowCheckedSubtract::Operation(left, right, result); +} + +template <> +bool TrySubtractOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { + if (right > left) { + return false; + } + return OverflowCheckedSubtract::Operation(left, right, result); +} + +template <> +bool TrySubtractOperator::Operation(int8_t left, int8_t right, int8_t &result) { + return OverflowCheckedSubtract::Operation(left, right, result); +} + +template <> +bool TrySubtractOperator::Operation(int16_t left, int16_t right, int16_t &result) { + return OverflowCheckedSubtract::Operation(left, right, result); +} + +template <> +bool TrySubtractOperator::Operation(int32_t left, int32_t right, int32_t &result) { + return OverflowCheckedSubtract::Operation(left, right, result); +} + +template <> +bool TrySubtractOperator::Operation(int64_t left, int64_t right, int64_t &result) { +#if (__GNUC__ >= 5) || defined(__clang__) + if (__builtin_sub_overflow(left, right, &result)) { + return false; + } +#else + if (right < 0) { + if (NumericLimits::Maximum() + right < left) { + return false; + } + } else { + if (NumericLimits::Minimum() + right > left) { + return false; + } + } + result = left - right; +#endif + return true; +} + +template <> +bool TrySubtractOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { + result = left; + return Hugeint::SubtractInPlace(result, right); +} + +//===--------------------------------------------------------------------===// +// subtract decimal with overflow check +//===--------------------------------------------------------------------===// +template +bool TryDecimalSubtractTemplated(T left, T right, T &result) { + if (right < 0) { + if (max + right < left) { + return false; + } + } else { + if (min + right > left) { + return false; + } + } + result = left - right; + return true; +} + +template <> +bool TryDecimalSubtract::Operation(int16_t left, int16_t right, int16_t &result) { + return TryDecimalSubtractTemplated(left, right, result); +} + +template <> +bool TryDecimalSubtract::Operation(int32_t left, int32_t right, int32_t &result) { + return TryDecimalSubtractTemplated(left, right, result); +} + +template <> +bool TryDecimalSubtract::Operation(int64_t left, int64_t right, int64_t &result) { + return TryDecimalSubtractTemplated(left, right, result); +} + +template <> +bool TryDecimalSubtract::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { + result = left - right; + if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { + return false; + } + return true; +} + +template <> +hugeint_t DecimalSubtractOverflowCheck::Operation(hugeint_t left, hugeint_t right) { + hugeint_t result; + if (!TryDecimalSubtract::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in subtract of DECIMAL(38) (%s - %s);", left.ToString(), right.ToString()); + } + return result; +} + +//===--------------------------------------------------------------------===// +// subtract time operator +//===--------------------------------------------------------------------===// +template <> +dtime_t SubtractTimeOperator::Operation(dtime_t left, interval_t right) { + right.micros = -right.micros; + return AddTimeOperator::Operation(left, right); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/pragma_functions.cpp b/src/duckdb/src/function/scalar/pragma_functions.cpp new file mode 100644 index 00000000..dfce174c --- /dev/null +++ b/src/duckdb/src/function/scalar/pragma_functions.cpp @@ -0,0 +1,10 @@ +#include "duckdb/function/pragma/pragma_functions.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterPragmaFunctions() { + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/sequence/nextval.cpp b/src/duckdb/src/function/scalar/sequence/nextval.cpp new file mode 100644 index 00000000..9b54e7ab --- /dev/null +++ b/src/duckdb/src/function/scalar/sequence/nextval.cpp @@ -0,0 +1,151 @@ +#include "duckdb/function/scalar/sequence_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/dependency_list.hpp" +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +struct NextvalBindData : public FunctionData { + explicit NextvalBindData(optional_ptr sequence) : sequence(sequence) { + } + + //! The sequence to use for the nextval computation; only if the sequence is a constant + optional_ptr sequence; + + unique_ptr Copy() const override { + return make_uniq(sequence); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return sequence == other.sequence; + } +}; + +struct CurrentSequenceValueOperator { + static int64_t Operation(DuckTransaction &transaction, SequenceCatalogEntry &seq) { + lock_guard seqlock(seq.lock); + int64_t result; + if (seq.usage_count == 0u) { + throw SequenceException("currval: sequence is not yet defined in this session"); + } + result = seq.last_value; + return result; + } +}; + +struct NextSequenceValueOperator { + static int64_t Operation(DuckTransaction &transaction, SequenceCatalogEntry &seq) { + lock_guard seqlock(seq.lock); + int64_t result; + result = seq.counter; + bool overflow = !TryAddOperator::Operation(seq.counter, seq.increment, seq.counter); + if (seq.cycle) { + if (overflow) { + seq.counter = seq.increment < 0 ? seq.max_value : seq.min_value; + } else if (seq.counter < seq.min_value) { + seq.counter = seq.max_value; + } else if (seq.counter > seq.max_value) { + seq.counter = seq.min_value; + } + } else { + if (result < seq.min_value || (overflow && seq.increment < 0)) { + throw SequenceException("nextval: reached minimum value of sequence \"%s\" (%lld)", seq.name, + seq.min_value); + } + if (result > seq.max_value || overflow) { + throw SequenceException("nextval: reached maximum value of sequence \"%s\" (%lld)", seq.name, + seq.max_value); + } + } + seq.last_value = result; + seq.usage_count++; + if (!seq.temporary) { + transaction.sequence_usage[&seq] = SequenceValue(seq.usage_count, seq.counter); + } + return result; + } +}; + +SequenceCatalogEntry &BindSequence(ClientContext &context, const string &name) { + auto qname = QualifiedName::Parse(name); + // fetch the sequence from the catalog + Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); + return Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); +} + +template +static void NextValFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto &input = args.data[0]; + + auto &context = state.GetContext(); + if (info.sequence) { + auto &sequence = *info.sequence; + auto &transaction = DuckTransaction::Get(context, sequence.catalog); + // sequence to use is hard coded + // increment the sequence + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < args.size(); i++) { + // get the next value from the sequence + result_data[i] = OP::Operation(transaction, sequence); + } + } else { + // sequence to use comes from the input + UnaryExecutor::Execute(input, result, args.size(), [&](string_t value) { + // fetch the sequence from the catalog + auto &sequence = BindSequence(context, value.GetString()); + // finally get the next value from the sequence + auto &transaction = DuckTransaction::Get(context, sequence.catalog); + return OP::Operation(transaction, sequence); + }); + } +} + +static unique_ptr NextValBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + optional_ptr sequence; + if (arguments[0]->IsFoldable()) { + // parameter to nextval function is a foldable constant + // evaluate the constant and perform the catalog lookup already + auto seqname = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + if (!seqname.IsNull()) { + sequence = &BindSequence(context, seqname.ToString()); + } + } + return make_uniq(sequence); +} + +static void NextValDependency(BoundFunctionExpression &expr, DependencyList &dependencies) { + auto &info = expr.bind_info->Cast(); + if (info.sequence) { + dependencies.AddDependency(*info.sequence); + } +} + +void NextvalFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunction next_val("nextval", {LogicalType::VARCHAR}, LogicalType::BIGINT, + NextValFunction, NextValBind, NextValDependency); + next_val.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + set.AddFunction(next_val); +} + +void CurrvalFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunction curr_val("currval", {LogicalType::VARCHAR}, LogicalType::BIGINT, + NextValFunction, NextValBind, NextValDependency); + curr_val.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS; + set.AddFunction(curr_val); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/sequence_functions.cpp b/src/duckdb/src/function/scalar/sequence_functions.cpp new file mode 100644 index 00000000..30b0c065 --- /dev/null +++ b/src/duckdb/src/function/scalar/sequence_functions.cpp @@ -0,0 +1,10 @@ +#include "duckdb/function/scalar/sequence_functions.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterSequenceFunctions() { + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/strftime_format.cpp b/src/duckdb/src/function/scalar/strftime_format.cpp new file mode 100644 index 00000000..589dd58b --- /dev/null +++ b/src/duckdb/src/function/scalar/strftime_format.cpp @@ -0,0 +1,1204 @@ +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" + +namespace duckdb { + +idx_t StrfTimepecifierSize(StrTimeSpecifier specifier) { + switch (specifier) { + case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: + case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: + return 3; + case StrTimeSpecifier::WEEKDAY_DECIMAL: + return 1; + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: + case StrTimeSpecifier::HOUR_24_PADDED: + case StrTimeSpecifier::HOUR_12_PADDED: + case StrTimeSpecifier::MINUTE_PADDED: + case StrTimeSpecifier::SECOND_PADDED: + case StrTimeSpecifier::AM_PM: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: + return 2; + case StrTimeSpecifier::NANOSECOND_PADDED: + return 9; + case StrTimeSpecifier::MICROSECOND_PADDED: + return 6; + case StrTimeSpecifier::MILLISECOND_PADDED: + return 3; + case StrTimeSpecifier::DAY_OF_YEAR_PADDED: + return 3; + default: + return 0; + } +} + +void StrTimeFormat::AddLiteral(string literal) { + constant_size += literal.size(); + literals.push_back(std::move(literal)); +} + +void StrTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { + AddLiteral(std::move(preceding_literal)); + specifiers.push_back(specifier); +} + +void StrfTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { + is_date_specifier.push_back(IsDateSpecifier(specifier)); + idx_t specifier_size = StrfTimepecifierSize(specifier); + if (specifier_size == 0) { + // variable length specifier + var_length_specifiers.push_back(specifier); + } else { + // constant size specifier + constant_size += specifier_size; + } + StrTimeFormat::AddFormatSpecifier(std::move(preceding_literal), specifier); +} + +idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date, dtime_t time, int32_t utc_offset, + const char *tz_name) { + switch (specifier) { + case StrTimeSpecifier::FULL_WEEKDAY_NAME: + return Date::DAY_NAMES[Date::ExtractISODayOfTheWeek(date) % 7].GetSize(); + case StrTimeSpecifier::FULL_MONTH_NAME: + return Date::MONTH_NAMES[Date::ExtractMonth(date) - 1].GetSize(); + case StrTimeSpecifier::YEAR_DECIMAL: { + auto year = Date::ExtractYear(date); + // Be consistent with WriteStandardSpecifier + if (0 <= year && year <= 9999) { + return 4; + } else { + return NumericHelper::SignedLength(year); + } + } + case StrTimeSpecifier::MONTH_DECIMAL: { + idx_t len = 1; + auto month = Date::ExtractMonth(date); + len += month >= 10; + return len; + } + case StrTimeSpecifier::UTC_OFFSET: + // ±HH or ±HH:MM + return (utc_offset % 60) ? 6 : 3; + case StrTimeSpecifier::TZ_NAME: + if (tz_name) { + return strlen(tz_name); + } + // empty for now + return 0; + case StrTimeSpecifier::HOUR_24_DECIMAL: + case StrTimeSpecifier::HOUR_12_DECIMAL: + case StrTimeSpecifier::MINUTE_DECIMAL: + case StrTimeSpecifier::SECOND_DECIMAL: { + // time specifiers + idx_t len = 1; + int32_t hour, min, sec, msec; + Time::Convert(time, hour, min, sec, msec); + switch (specifier) { + case StrTimeSpecifier::HOUR_24_DECIMAL: + len += hour >= 10; + break; + case StrTimeSpecifier::HOUR_12_DECIMAL: + hour = hour % 12; + if (hour == 0) { + hour = 12; + } + len += hour >= 10; + break; + case StrTimeSpecifier::MINUTE_DECIMAL: + len += min >= 10; + break; + case StrTimeSpecifier::SECOND_DECIMAL: + len += sec >= 10; + break; + default: + throw InternalException("Time specifier mismatch"); + } + return len; + } + case StrTimeSpecifier::DAY_OF_MONTH: + return NumericHelper::UnsignedLength(Date::ExtractDay(date)); + case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: + return NumericHelper::UnsignedLength(Date::ExtractDayOfTheYear(date)); + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: + return NumericHelper::UnsignedLength(AbsValue(Date::ExtractYear(date)) % 100); + default: + throw InternalException("Unimplemented specifier for GetSpecifierLength"); + } +} + +//! Returns the total length of the date formatted by this format specifier +idx_t StrfTimeFormat::GetLength(date_t date, dtime_t time, int32_t utc_offset, const char *tz_name) { + idx_t size = constant_size; + if (!var_length_specifiers.empty()) { + for (auto &specifier : var_length_specifiers) { + size += GetSpecifierLength(specifier, date, time, utc_offset, tz_name); + } + } + return size; +} + +char *StrfTimeFormat::WriteString(char *target, const string_t &str) { + idx_t size = str.GetSize(); + memcpy(target, str.GetData(), size); + return target + size; +} + +// write a value in the range of 0..99 unpadded (e.g. "1", "2", ... "98", "99") +char *StrfTimeFormat::Write2(char *target, uint8_t value) { + D_ASSERT(value < 100); + if (value >= 10) { + return WritePadded2(target, value); + } else { + *target = char(uint8_t('0') + value); + return target + 1; + } +} + +// write a value in the range of 0..99 padded to 2 digits +char *StrfTimeFormat::WritePadded2(char *target, uint32_t value) { + D_ASSERT(value < 100); + auto index = static_cast(value * 2); + *target++ = duckdb_fmt::internal::data::digits[index]; + *target++ = duckdb_fmt::internal::data::digits[index + 1]; + return target; +} + +// write a value in the range of 0..999 padded +char *StrfTimeFormat::WritePadded3(char *target, uint32_t value) { + D_ASSERT(value < 1000); + if (value >= 100) { + WritePadded2(target + 1, value % 100); + *target = char(uint8_t('0') + value / 100); + return target + 3; + } else { + *target = '0'; + target++; + return WritePadded2(target, value); + } +} + +// write a value in the range of 0..999999... padded to the given number of digits +char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding) { + D_ASSERT(padding > 1); + if (padding % 2) { + int decimals = value % 1000; + WritePadded3(target + padding - 3, decimals); + value /= 1000; + padding -= 3; + } + for (size_t i = 0; i < padding / 2; i++) { + int decimals = value % 100; + WritePadded2(target + padding - 2 * (i + 1), decimals); + value /= 100; + } + return target + padding; +} + +bool StrfTimeFormat::IsDateSpecifier(StrTimeSpecifier specifier) { + switch (specifier) { + case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: + case StrTimeSpecifier::FULL_WEEKDAY_NAME: + case StrTimeSpecifier::DAY_OF_YEAR_PADDED: + case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: + case StrTimeSpecifier::WEEKDAY_DECIMAL: + return true; + default: + return false; + } +} + +char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date, char *target) { + switch (specifier) { + case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: { + auto dow = Date::ExtractISODayOfTheWeek(date); + target = WriteString(target, Date::DAY_NAMES_ABBREVIATED[dow % 7]); + break; + } + case StrTimeSpecifier::FULL_WEEKDAY_NAME: { + auto dow = Date::ExtractISODayOfTheWeek(date); + target = WriteString(target, Date::DAY_NAMES[dow % 7]); + break; + } + case StrTimeSpecifier::WEEKDAY_DECIMAL: { + auto dow = Date::ExtractISODayOfTheWeek(date); + *target = char('0' + uint8_t(dow % 7)); + target++; + break; + } + case StrTimeSpecifier::DAY_OF_YEAR_PADDED: { + int32_t doy = Date::ExtractDayOfTheYear(date); + target = WritePadded3(target, doy); + break; + } + case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: + target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, true)); + break; + case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: + target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, false)); + break; + case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: { + uint32_t doy = Date::ExtractDayOfTheYear(date); + target += NumericHelper::UnsignedLength(doy); + NumericHelper::FormatUnsigned(doy, target); + break; + } + default: + throw InternalException("Unimplemented date specifier for strftime"); + } + return target; +} + +char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t data[], const char *tz_name, + size_t tz_len, char *target) { + // data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] msec, [7] utc + switch (specifier) { + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + target = WritePadded2(target, data[2]); + break; + case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: { + auto &month_name = Date::MONTH_NAMES_ABBREVIATED[data[1] - 1]; + return WriteString(target, month_name); + } + case StrTimeSpecifier::FULL_MONTH_NAME: { + auto &month_name = Date::MONTH_NAMES[data[1] - 1]; + return WriteString(target, month_name); + } + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + target = WritePadded2(target, data[1]); + break; + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: + target = WritePadded2(target, AbsValue(data[0]) % 100); + break; + case StrTimeSpecifier::YEAR_DECIMAL: + if (data[0] >= 0 && data[0] <= 9999) { + target = WritePadded(target, data[0], 4); + } else { + int32_t year = data[0]; + if (data[0] < 0) { + *target = '-'; + year = -year; + target++; + } + auto len = NumericHelper::UnsignedLength(year); + NumericHelper::FormatUnsigned(year, target + len); + target += len; + } + break; + case StrTimeSpecifier::HOUR_24_PADDED: { + target = WritePadded2(target, data[3]); + break; + } + case StrTimeSpecifier::HOUR_12_PADDED: { + int hour = data[3] % 12; + if (hour == 0) { + hour = 12; + } + target = WritePadded2(target, hour); + break; + } + case StrTimeSpecifier::AM_PM: + *target++ = data[3] >= 12 ? 'P' : 'A'; + *target++ = 'M'; + break; + case StrTimeSpecifier::MINUTE_PADDED: { + target = WritePadded2(target, data[4]); + break; + } + case StrTimeSpecifier::SECOND_PADDED: + target = WritePadded2(target, data[5]); + break; + case StrTimeSpecifier::NANOSECOND_PADDED: + target = WritePadded(target, data[6] * Interval::NANOS_PER_MICRO, 9); + break; + case StrTimeSpecifier::MICROSECOND_PADDED: + target = WritePadded(target, data[6], 6); + break; + case StrTimeSpecifier::MILLISECOND_PADDED: + target = WritePadded3(target, data[6] / Interval::MICROS_PER_MSEC); + break; + case StrTimeSpecifier::UTC_OFFSET: { + *target++ = (data[7] < 0) ? '-' : '+'; + + auto offset = abs(data[7]); + auto offset_hours = offset / Interval::MINS_PER_HOUR; + auto offset_minutes = offset % Interval::MINS_PER_HOUR; + target = WritePadded2(target, offset_hours); + if (offset_minutes) { + *target++ = ':'; + target = WritePadded2(target, offset_minutes); + } + break; + } + case StrTimeSpecifier::TZ_NAME: + if (tz_name) { + memcpy(target, tz_name, tz_len); + target += strlen(tz_name); + } + break; + case StrTimeSpecifier::DAY_OF_MONTH: { + target = Write2(target, data[2] % 100); + break; + } + case StrTimeSpecifier::MONTH_DECIMAL: { + target = Write2(target, data[1]); + break; + } + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: { + target = Write2(target, AbsValue(data[0]) % 100); + break; + } + case StrTimeSpecifier::HOUR_24_DECIMAL: { + target = Write2(target, data[3]); + break; + } + case StrTimeSpecifier::HOUR_12_DECIMAL: { + int hour = data[3] % 12; + if (hour == 0) { + hour = 12; + } + target = Write2(target, hour); + break; + } + case StrTimeSpecifier::MINUTE_DECIMAL: { + target = Write2(target, data[4]); + break; + } + case StrTimeSpecifier::SECOND_DECIMAL: { + target = Write2(target, data[5]); + break; + } + default: + throw InternalException("Unimplemented specifier for WriteStandardSpecifier in strftime"); + } + return target; +} + +void StrfTimeFormat::FormatString(date_t date, int32_t data[8], const char *tz_name, char *target) { + D_ASSERT(specifiers.size() + 1 == literals.size()); + idx_t i; + for (i = 0; i < specifiers.size(); i++) { + // first copy the current literal + memcpy(target, literals[i].c_str(), literals[i].size()); + target += literals[i].size(); + // now copy the specifier + if (is_date_specifier[i]) { + target = WriteDateSpecifier(specifiers[i], date, target); + } else { + auto tz_len = tz_name ? strlen(tz_name) : 0; + target = WriteStandardSpecifier(specifiers[i], data, tz_name, tz_len, target); + } + } + // copy the final literal into the target + memcpy(target, literals[i].c_str(), literals[i].size()); +} + +void StrfTimeFormat::FormatString(date_t date, dtime_t time, char *target) { + int32_t data[8]; // year, month, day, hour, min, sec, µs, offset + Date::Convert(date, data[0], data[1], data[2]); + Time::Convert(time, data[3], data[4], data[5], data[6]); + data[7] = 0; + + FormatString(date, data, nullptr, target); +} + +string StrfTimeFormat::Format(timestamp_t timestamp, const string &format_str) { + StrfTimeFormat format; + format.ParseFormatSpecifier(format_str, format); + + auto date = Timestamp::GetDate(timestamp); + auto time = Timestamp::GetTime(timestamp); + + auto len = format.GetLength(date, time, 0, nullptr); + auto result = make_unsafe_uniq_array(len); + format.FormatString(date, time, result.get()); + return string(result.get(), len); +} + +string StrTimeFormat::ParseFormatSpecifier(const string &format_string, StrTimeFormat &format) { + if (format_string.empty()) { + return "Empty format string"; + } + format.format_specifier = format_string; + format.specifiers.clear(); + format.literals.clear(); + format.numeric_width.clear(); + format.constant_size = 0; + idx_t pos = 0; + string current_literal; + for (idx_t i = 0; i < format_string.size(); i++) { + if (format_string[i] == '%') { + if (i + 1 == format_string.size()) { + return "Trailing format character %"; + } + if (i > pos) { + // push the previous string to the current literal + current_literal += format_string.substr(pos, i - pos); + } + char format_char = format_string[++i]; + if (format_char == '%') { + // special case: %% + // set the pos for the next literal and continue + pos = i; + continue; + } + StrTimeSpecifier specifier; + if (format_char == '-' && i + 1 < format_string.size()) { + format_char = format_string[++i]; + switch (format_char) { + case 'd': + specifier = StrTimeSpecifier::DAY_OF_MONTH; + break; + case 'm': + specifier = StrTimeSpecifier::MONTH_DECIMAL; + break; + case 'y': + specifier = StrTimeSpecifier::YEAR_WITHOUT_CENTURY; + break; + case 'H': + specifier = StrTimeSpecifier::HOUR_24_DECIMAL; + break; + case 'I': + specifier = StrTimeSpecifier::HOUR_12_DECIMAL; + break; + case 'M': + specifier = StrTimeSpecifier::MINUTE_DECIMAL; + break; + case 'S': + specifier = StrTimeSpecifier::SECOND_DECIMAL; + break; + case 'j': + specifier = StrTimeSpecifier::DAY_OF_YEAR_DECIMAL; + break; + default: + return "Unrecognized format for strftime/strptime: %-" + string(1, format_char); + } + } else { + switch (format_char) { + case 'a': + specifier = StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME; + break; + case 'A': + specifier = StrTimeSpecifier::FULL_WEEKDAY_NAME; + break; + case 'w': + specifier = StrTimeSpecifier::WEEKDAY_DECIMAL; + break; + case 'd': + specifier = StrTimeSpecifier::DAY_OF_MONTH_PADDED; + break; + case 'h': + case 'b': + specifier = StrTimeSpecifier::ABBREVIATED_MONTH_NAME; + break; + case 'B': + specifier = StrTimeSpecifier::FULL_MONTH_NAME; + break; + case 'm': + specifier = StrTimeSpecifier::MONTH_DECIMAL_PADDED; + break; + case 'y': + specifier = StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED; + break; + case 'Y': + specifier = StrTimeSpecifier::YEAR_DECIMAL; + break; + case 'H': + specifier = StrTimeSpecifier::HOUR_24_PADDED; + break; + case 'I': + specifier = StrTimeSpecifier::HOUR_12_PADDED; + break; + case 'p': + specifier = StrTimeSpecifier::AM_PM; + break; + case 'M': + specifier = StrTimeSpecifier::MINUTE_PADDED; + break; + case 'S': + specifier = StrTimeSpecifier::SECOND_PADDED; + break; + case 'n': + specifier = StrTimeSpecifier::NANOSECOND_PADDED; + break; + case 'f': + specifier = StrTimeSpecifier::MICROSECOND_PADDED; + break; + case 'g': + specifier = StrTimeSpecifier::MILLISECOND_PADDED; + break; + case 'z': + specifier = StrTimeSpecifier::UTC_OFFSET; + break; + case 'Z': + specifier = StrTimeSpecifier::TZ_NAME; + break; + case 'j': + specifier = StrTimeSpecifier::DAY_OF_YEAR_PADDED; + break; + case 'U': + specifier = StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST; + break; + case 'W': + specifier = StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST; + break; + case 'c': + case 'x': + case 'X': + case 'T': { + string subformat; + if (format_char == 'c') { + // %c: Locale’s appropriate date and time representation. + // we push the ISO timestamp representation here + subformat = "%Y-%m-%d %H:%M:%S"; + } else if (format_char == 'x') { + // %x - Locale’s appropriate date representation. + // we push the ISO date format here + subformat = "%Y-%m-%d"; + } else if (format_char == 'X' || format_char == 'T') { + // %X - Locale’s appropriate time representation. + // we push the ISO time format here + subformat = "%H:%M:%S"; + } + // parse the subformat in a separate format specifier + StrfTimeFormat locale_format; + string error = StrTimeFormat::ParseFormatSpecifier(subformat, locale_format); + D_ASSERT(error.empty()); + // add the previous literal to the first literal of the subformat + locale_format.literals[0] = std::move(current_literal) + locale_format.literals[0]; + current_literal = ""; + // now push the subformat into the current format specifier + for (idx_t i = 0; i < locale_format.specifiers.size(); i++) { + format.AddFormatSpecifier(std::move(locale_format.literals[i]), locale_format.specifiers[i]); + } + pos = i + 1; + continue; + } + default: + return "Unrecognized format for strftime/strptime: %" + string(1, format_char); + } + } + format.AddFormatSpecifier(std::move(current_literal), specifier); + current_literal = ""; + pos = i + 1; + } + } + // add the final literal + if (pos < format_string.size()) { + current_literal += format_string.substr(pos, format_string.size() - pos); + } + format.AddLiteral(std::move(current_literal)); + return string(); +} + +void StrfTimeFormat::ConvertDateVector(Vector &input, Vector &result, idx_t count) { + D_ASSERT(input.GetType().id() == LogicalTypeId::DATE); + D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); + UnaryExecutor::ExecuteWithNulls(input, result, count, + [&](date_t input, ValidityMask &mask, idx_t idx) { + if (Date::IsFinite(input)) { + dtime_t time(0); + idx_t len = GetLength(input, time, 0, nullptr); + string_t target = StringVector::EmptyString(result, len); + FormatString(input, time, target.GetDataWriteable()); + target.Finalize(); + return target; + } else { + mask.SetInvalid(idx); + return string_t(); + } + }); +} + +void StrfTimeFormat::ConvertTimestampVector(Vector &input, Vector &result, idx_t count) { + D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP || input.GetType().id() == LogicalTypeId::TIMESTAMP_TZ); + D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); + UnaryExecutor::ExecuteWithNulls( + input, result, count, [&](timestamp_t input, ValidityMask &mask, idx_t idx) { + if (Timestamp::IsFinite(input)) { + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + idx_t len = GetLength(date, time, 0, nullptr); + string_t target = StringVector::EmptyString(result, len); + FormatString(date, time, target.GetDataWriteable()); + target.Finalize(); + return target; + } else { + mask.SetInvalid(idx); + return string_t(); + } + }); +} + +void StrpTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { + numeric_width.push_back(NumericSpecifierWidth(specifier)); + StrTimeFormat::AddFormatSpecifier(std::move(preceding_literal), specifier); +} + +int StrpTimeFormat::NumericSpecifierWidth(StrTimeSpecifier specifier) { + switch (specifier) { + case StrTimeSpecifier::WEEKDAY_DECIMAL: + return 1; + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + case StrTimeSpecifier::DAY_OF_MONTH: + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + case StrTimeSpecifier::MONTH_DECIMAL: + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: + case StrTimeSpecifier::HOUR_24_PADDED: + case StrTimeSpecifier::HOUR_24_DECIMAL: + case StrTimeSpecifier::HOUR_12_PADDED: + case StrTimeSpecifier::HOUR_12_DECIMAL: + case StrTimeSpecifier::MINUTE_PADDED: + case StrTimeSpecifier::MINUTE_DECIMAL: + case StrTimeSpecifier::SECOND_PADDED: + case StrTimeSpecifier::SECOND_DECIMAL: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: + return 2; + case StrTimeSpecifier::MILLISECOND_PADDED: + case StrTimeSpecifier::DAY_OF_YEAR_PADDED: + case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: + return 3; + case StrTimeSpecifier::YEAR_DECIMAL: + return 4; + case StrTimeSpecifier::MICROSECOND_PADDED: + return 6; + case StrTimeSpecifier::NANOSECOND_PADDED: + return 9; + default: + return -1; + } +} + +enum class TimeSpecifierAMOrPM : uint8_t { TIME_SPECIFIER_NONE = 0, TIME_SPECIFIER_AM = 1, TIME_SPECIFIER_PM = 2 }; + +int32_t StrpTimeFormat::TryParseCollection(const char *data, idx_t &pos, idx_t size, const string_t collection[], + idx_t collection_count) const { + for (idx_t c = 0; c < collection_count; c++) { + auto &entry = collection[c]; + auto entry_data = entry.GetData(); + auto entry_size = entry.GetSize(); + // check if this entry matches + if (pos + entry_size > size) { + // too big: can't match + continue; + } + // compare the characters + idx_t i; + for (i = 0; i < entry_size; i++) { + if (std::tolower(entry_data[i]) != std::tolower(data[pos + i])) { + break; + } + } + if (i == entry_size) { + // full match + pos += entry_size; + return c; + } + } + return -1; +} + +//! Parses a timestamp using the given specifier +bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { + auto &result_data = result.data; + auto &error_message = result.error_message; + auto &error_position = result.error_position; + + // initialize the result + result_data[0] = 1900; + result_data[1] = 1; + result_data[2] = 1; + result_data[3] = 0; + result_data[4] = 0; + result_data[5] = 0; + result_data[6] = 0; + result_data[7] = 0; + + auto data = str.GetData(); + idx_t size = str.GetSize(); + // skip leading spaces + while (StringUtil::CharacterIsSpace(*data)) { + data++; + size--; + } + idx_t pos = 0; + TimeSpecifierAMOrPM ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_NONE; + + // Year offset state (Year+W/j) + auto offset_specifier = StrTimeSpecifier::WEEKDAY_DECIMAL; + uint64_t weekno = 0; + uint64_t weekday = 0; + uint64_t yearday = 0; + + for (idx_t i = 0;; i++) { + D_ASSERT(i < literals.size()); + // first compare the literal + const auto &literal = literals[i]; + for (size_t l = 0; l < literal.size();) { + // Match runs of spaces to runs of spaces. + if (StringUtil::CharacterIsSpace(literal[l])) { + if (!StringUtil::CharacterIsSpace(data[pos])) { + error_message = "Space does not match, expected " + literals[i]; + error_position = pos; + return false; + } + for (++pos; pos < size && StringUtil::CharacterIsSpace(data[pos]); ++pos) { + continue; + } + for (++l; l < literal.size() && StringUtil::CharacterIsSpace(literal[l]); ++l) { + continue; + } + continue; + } + // literal does not match + if (data[pos++] != literal[l++]) { + error_message = "Literal does not match, expected " + literal; + error_position = pos; + return false; + } + } + if (i == specifiers.size()) { + break; + } + // now parse the specifier + if (numeric_width[i] > 0) { + // numeric specifier: parse a number + uint64_t number = 0; + size_t start_pos = pos; + size_t end_pos = start_pos + numeric_width[i]; + while (pos < size && pos < end_pos && StringUtil::CharacterIsDigit(data[pos])) { + number = number * 10 + data[pos] - '0'; + pos++; + } + if (pos == start_pos) { + // expected a number here + error_message = "Expected a number"; + error_position = start_pos; + return false; + } + switch (specifiers[i]) { + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + case StrTimeSpecifier::DAY_OF_MONTH: + if (number < 1 || number > 31) { + error_message = "Day out of range, expected a value between 1 and 31"; + error_position = start_pos; + return false; + } + // day of the month + result_data[2] = number; + offset_specifier = specifiers[i]; + break; + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + case StrTimeSpecifier::MONTH_DECIMAL: + if (number < 1 || number > 12) { + error_message = "Month out of range, expected a value between 1 and 12"; + error_position = start_pos; + return false; + } + // month number + result_data[1] = number; + offset_specifier = specifiers[i]; + break; + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: + case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: + // year without century.. + // Python uses 69 as a crossover point (i.e. >= 69 is 19.., < 69 is 20..) + if (number >= 100) { + // %y only supports numbers between [0..99] + error_message = "Year without century out of range, expected a value between 0 and 99"; + error_position = start_pos; + return false; + } + if (number >= 69) { + result_data[0] = int32_t(1900 + number); + } else { + result_data[0] = int32_t(2000 + number); + } + break; + case StrTimeSpecifier::YEAR_DECIMAL: + // year as full number + result_data[0] = number; + break; + case StrTimeSpecifier::HOUR_24_PADDED: + case StrTimeSpecifier::HOUR_24_DECIMAL: + if (number >= 24) { + error_message = "Hour out of range, expected a value between 0 and 23"; + error_position = start_pos; + return false; + } + // hour as full number + result_data[3] = number; + break; + case StrTimeSpecifier::HOUR_12_PADDED: + case StrTimeSpecifier::HOUR_12_DECIMAL: + if (number < 1 || number > 12) { + error_message = "Hour12 out of range, expected a value between 1 and 12"; + error_position = start_pos; + return false; + } + // 12-hour number: start off by just storing the number + result_data[3] = number; + break; + case StrTimeSpecifier::MINUTE_PADDED: + case StrTimeSpecifier::MINUTE_DECIMAL: + if (number >= 60) { + error_message = "Minutes out of range, expected a value between 0 and 59"; + error_position = start_pos; + return false; + } + // minutes + result_data[4] = number; + break; + case StrTimeSpecifier::SECOND_PADDED: + case StrTimeSpecifier::SECOND_DECIMAL: + if (number >= 60) { + error_message = "Seconds out of range, expected a value between 0 and 59"; + error_position = start_pos; + return false; + } + // seconds + result_data[5] = number; + break; + case StrTimeSpecifier::NANOSECOND_PADDED: + D_ASSERT(number < Interval::NANOS_PER_SEC); // enforced by the length of the number + // microseconds (rounded) + result_data[6] = (number + Interval::NANOS_PER_MICRO / 2) / Interval::NANOS_PER_MICRO; + break; + case StrTimeSpecifier::MICROSECOND_PADDED: + D_ASSERT(number < Interval::MICROS_PER_SEC); // enforced by the length of the number + // microseconds + result_data[6] = number; + break; + case StrTimeSpecifier::MILLISECOND_PADDED: + D_ASSERT(number < Interval::MSECS_PER_SEC); // enforced by the length of the number + // microseconds + result_data[6] = number * Interval::MICROS_PER_MSEC; + break; + case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: + // m/d overrides WU/w but does not conflict + switch (offset_specifier) { + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + case StrTimeSpecifier::DAY_OF_MONTH: + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + case StrTimeSpecifier::MONTH_DECIMAL: + // Just validate, don't use + break; + case StrTimeSpecifier::WEEKDAY_DECIMAL: + // First offset specifier + offset_specifier = specifiers[i]; + break; + default: + error_message = "Multiple year offsets specified"; + error_position = start_pos; + return false; + } + if (number > 53) { + error_message = "Week out of range, expected a value between 0 and 53"; + error_position = start_pos; + return false; + } + weekno = number; + break; + case StrTimeSpecifier::WEEKDAY_DECIMAL: + if (number > 6) { + error_message = "Weekday out of range, expected a value between 0 and 6"; + error_position = start_pos; + return false; + } + weekday = number; + break; + case StrTimeSpecifier::DAY_OF_YEAR_PADDED: + case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: + // m/d overrides j but does not conflict + switch (offset_specifier) { + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + case StrTimeSpecifier::DAY_OF_MONTH: + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + case StrTimeSpecifier::MONTH_DECIMAL: + // Just validate, don't use + break; + case StrTimeSpecifier::WEEKDAY_DECIMAL: + // First offset specifier + offset_specifier = specifiers[i]; + break; + default: + error_message = "Multiple year offsets specified"; + error_position = start_pos; + return false; + } + if (number < 1 || number > 366) { + error_message = "Year day out of range, expected a value between 1 and 366"; + error_position = start_pos; + return false; + } + yearday = number; + break; + default: + throw NotImplementedException("Unsupported specifier for strptime"); + } + } else { + switch (specifiers[i]) { + case StrTimeSpecifier::AM_PM: { + // parse the next 2 characters + if (pos + 2 > size) { + // no characters left to parse + error_message = "Expected AM/PM"; + error_position = pos; + return false; + } + char pa_char = char(std::tolower(data[pos])); + char m_char = char(std::tolower(data[pos + 1])); + if (m_char != 'm') { + error_message = "Expected AM/PM"; + error_position = pos; + return false; + } + if (pa_char == 'p') { + ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_PM; + } else if (pa_char == 'a') { + ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_AM; + } else { + error_message = "Expected AM/PM"; + error_position = pos; + return false; + } + pos += 2; + break; + } + // we parse weekday names, but we don't use them as information + case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: + if (TryParseCollection(data, pos, size, Date::DAY_NAMES_ABBREVIATED, 7) < 0) { + error_message = "Expected an abbreviated day name (Mon, Tue, Wed, Thu, Fri, Sat, Sun)"; + error_position = pos; + return false; + } + break; + case StrTimeSpecifier::FULL_WEEKDAY_NAME: + if (TryParseCollection(data, pos, size, Date::DAY_NAMES, 7) < 0) { + error_message = "Expected a full day name (Monday, Tuesday, etc...)"; + error_position = pos; + return false; + } + break; + case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: { + int32_t month = TryParseCollection(data, pos, size, Date::MONTH_NAMES_ABBREVIATED, 12); + if (month < 0) { + error_message = "Expected an abbreviated month name (Jan, Feb, Mar, etc..)"; + error_position = pos; + return false; + } + result_data[1] = month + 1; + break; + } + case StrTimeSpecifier::FULL_MONTH_NAME: { + int32_t month = TryParseCollection(data, pos, size, Date::MONTH_NAMES, 12); + if (month < 0) { + error_message = "Expected a full month name (January, February, etc...)"; + error_position = pos; + return false; + } + result_data[1] = month + 1; + break; + } + case StrTimeSpecifier::UTC_OFFSET: { + int hour_offset, minute_offset; + if (!Timestamp::TryParseUTCOffset(data, pos, size, hour_offset, minute_offset)) { + error_message = "Expected +HH[MM] or -HH[MM]"; + error_position = pos; + return false; + } + result_data[7] = hour_offset * Interval::MINS_PER_HOUR + minute_offset; + break; + } + case StrTimeSpecifier::TZ_NAME: { + // skip leading spaces + while (pos < size && StringUtil::CharacterIsSpace(data[pos])) { + pos++; + } + const auto tz_begin = data + pos; + // stop when we encounter a non-tz character + while (pos < size && Timestamp::CharacterIsTimeZone(data[pos])) { + pos++; + } + const auto tz_end = data + pos; + // Can't fully validate without a list - caller's responsibility. + // But tz must not be empty. + if (tz_end == tz_begin) { + error_message = "Empty Time Zone name"; + error_position = tz_begin - data; + return false; + } + result.tz.assign(tz_begin, tz_end); + break; + } + default: + throw NotImplementedException("Unsupported specifier for strptime"); + } + } + } + // skip trailing spaces + while (pos < size && StringUtil::CharacterIsSpace(data[pos])) { + pos++; + } + if (pos != size) { + error_message = "Full specifier did not match: trailing characters"; + error_position = pos; + return false; + } + if (ampm != TimeSpecifierAMOrPM::TIME_SPECIFIER_NONE) { + if (result_data[3] > 12) { + error_message = + "Invalid hour: " + to_string(result_data[3]) + " AM/PM, expected an hour within the range [0..12]"; + return false; + } + // adjust the hours based on the AM or PM specifier + if (ampm == TimeSpecifierAMOrPM::TIME_SPECIFIER_AM) { + // AM: 12AM=0, 1AM=1, 2AM=2, ..., 11AM=11 + if (result_data[3] == 12) { + result_data[3] = 0; + } + } else { + // PM: 12PM=12, 1PM=13, 2PM=14, ..., 11PM=23 + if (result_data[3] != 12) { + result_data[3] += 12; + } + } + } + switch (offset_specifier) { + case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: + case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: { + // Adjust weekday to be 0-based for the week type + weekday = (weekday + 7 - int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % 7; + // Get the start of week 1, move back 7 days and then weekno * 7 + weekday gives the date + const auto jan1 = Date::FromDate(result_data[0], 1, 1); + auto yeardate = Date::GetMondayOfCurrentWeek(jan1); + yeardate -= int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST); + // Is there a week 0? + yeardate -= 7 * int(yeardate >= jan1); + yeardate += weekno * 7 + weekday; + Date::Convert(yeardate, result_data[0], result_data[1], result_data[2]); + break; + } + case StrTimeSpecifier::DAY_OF_YEAR_PADDED: + case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: { + auto yeardate = Date::FromDate(result_data[0], 1, 1); + yeardate += yearday - 1; + Date::Convert(yeardate, result_data[0], result_data[1], result_data[2]); + break; + } + case StrTimeSpecifier::DAY_OF_MONTH_PADDED: + case StrTimeSpecifier::DAY_OF_MONTH: + case StrTimeSpecifier::MONTH_DECIMAL_PADDED: + case StrTimeSpecifier::MONTH_DECIMAL: + // m/d overrides UWw/j + break; + default: + D_ASSERT(offset_specifier == StrTimeSpecifier::WEEKDAY_DECIMAL); + break; + } + + return true; +} + +StrpTimeFormat::ParseResult StrpTimeFormat::Parse(const string &format_string, const string &text) { + StrpTimeFormat format; + format.format_specifier = format_string; + string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); + if (!error.empty()) { + throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); + } + StrpTimeFormat::ParseResult result; + if (!format.Parse(text, result)) { + throw InvalidInputException("Failed to parse string \"%s\" with format specifier \"%s\"", text, format_string); + } + return result; +} + +string StrpTimeFormat::FormatStrpTimeError(const string &input, idx_t position) { + if (position == DConstants::INVALID_INDEX) { + return string(); + } + return input + "\n" + string(position, ' ') + "^"; +} + +date_t StrpTimeFormat::ParseResult::ToDate() { + return Date::FromDate(data[0], data[1], data[2]); +} + +bool StrpTimeFormat::ParseResult::TryToDate(date_t &result) { + return Date::TryFromDate(data[0], data[1], data[2], result); +} + +timestamp_t StrpTimeFormat::ParseResult::ToTimestamp() { + date_t date = Date::FromDate(data[0], data[1], data[2]); + const auto hour_offset = data[7] / Interval::MINS_PER_HOUR; + const auto mins_offset = data[7] % Interval::MINS_PER_HOUR; + dtime_t time = Time::FromTime(data[3] - hour_offset, data[4] - mins_offset, data[5], data[6]); + return Timestamp::FromDatetime(date, time); +} + +bool StrpTimeFormat::ParseResult::TryToTimestamp(timestamp_t &result) { + date_t date; + if (!TryToDate(date)) { + return false; + } + const auto hour_offset = data[7] / Interval::MINS_PER_HOUR; + const auto mins_offset = data[7] % Interval::MINS_PER_HOUR; + dtime_t time = Time::FromTime(data[3] - hour_offset, data[4] - mins_offset, data[5], data[6]); + return Timestamp::TryFromDatetime(date, time, result); +} + +string StrpTimeFormat::ParseResult::FormatError(string_t input, const string &format_specifier) { + return StringUtil::Format("Could not parse string \"%s\" according to format specifier \"%s\"\n%s\nError: %s", + input.GetString(), format_specifier, + FormatStrpTimeError(input.GetString(), error_position), error_message); +} + +bool StrpTimeFormat::TryParseDate(string_t input, date_t &result, string &error_message) const { + ParseResult parse_result; + if (!Parse(input, parse_result)) { + error_message = parse_result.FormatError(input, format_specifier); + return false; + } + return parse_result.TryToDate(result); +} + +bool StrpTimeFormat::TryParseTimestamp(string_t input, timestamp_t &result, string &error_message) const { + ParseResult parse_result; + if (!Parse(input, parse_result)) { + error_message = parse_result.FormatError(input, format_specifier); + return false; + } + return parse_result.TryToTimestamp(result); +} + +date_t StrpTimeFormat::ParseDate(string_t input) { + ParseResult result; + if (!Parse(input, result)) { + throw InvalidInputException(result.FormatError(input, format_specifier)); + } + return result.ToDate(); +} + +timestamp_t StrpTimeFormat::ParseTimestamp(string_t input) { + ParseResult result; + if (!Parse(input, result)) { + throw InvalidInputException(result.FormatError(input, format_specifier)); + } + return result.ToTimestamp(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/caseconvert.cpp b/src/duckdb/src/function/scalar/string/caseconvert.cpp new file mode 100644 index 00000000..73217252 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/caseconvert.cpp @@ -0,0 +1,177 @@ +#include "duckdb/function/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +#include "utf8proc.hpp" + +#include + +namespace duckdb { + +uint8_t UpperFun::ascii_to_upper_map[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, + 88, 89, 90, 91, 92, 93, 94, 95, 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, + 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, 128, 129, 130, 131, + 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, + 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, + 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, + 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, + 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254}; +uint8_t LowerFun::ascii_to_lower_map[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 97, + 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, + 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, + 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, + 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, + 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, + 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254}; + +template +static string_t ASCIICaseConvert(Vector &result, const char *input_data, idx_t input_length) { + idx_t output_length = input_length; + auto result_str = StringVector::EmptyString(result, output_length); + auto result_data = result_str.GetDataWriteable(); + for (idx_t i = 0; i < input_length; i++) { + result_data[i] = IS_UPPER ? UpperFun::ascii_to_upper_map[uint8_t(input_data[i])] + : LowerFun::ascii_to_lower_map[uint8_t(input_data[i])]; + } + result_str.Finalize(); + return result_str; +} + +template +static idx_t GetResultLength(const char *input_data, idx_t input_length) { + idx_t output_length = 0; + for (idx_t i = 0; i < input_length;) { + if (input_data[i] & 0x80) { + // unicode + int sz = 0; + int codepoint = utf8proc_codepoint(input_data + i, sz); + int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); + int new_sz = utf8proc_codepoint_length(converted_codepoint); + D_ASSERT(new_sz >= 0); + output_length += new_sz; + i += sz; + } else { + // ascii + output_length++; + i++; + } + } + return output_length; +} + +template +static void CaseConvert(const char *input_data, idx_t input_length, char *result_data) { + for (idx_t i = 0; i < input_length;) { + if (input_data[i] & 0x80) { + // non-ascii character + int sz = 0, new_sz = 0; + int codepoint = utf8proc_codepoint(input_data + i, sz); + int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); + auto success = utf8proc_codepoint_to_utf8(converted_codepoint, new_sz, result_data); + D_ASSERT(success); + (void)success; + result_data += new_sz; + i += sz; + } else { + // ascii + *result_data = IS_UPPER ? UpperFun::ascii_to_upper_map[uint8_t(input_data[i])] + : LowerFun::ascii_to_lower_map[uint8_t(input_data[i])]; + result_data++; + i++; + } + } +} + +idx_t LowerFun::LowerLength(const char *input_data, idx_t input_length) { + return GetResultLength(input_data, input_length); +} + +void LowerFun::LowerCase(const char *input_data, idx_t input_length, char *result_data) { + CaseConvert(input_data, input_length, result_data); +} + +template +static string_t UnicodeCaseConvert(Vector &result, const char *input_data, idx_t input_length) { + // first figure out the output length + idx_t output_length = GetResultLength(input_data, input_length); + auto result_str = StringVector::EmptyString(result, output_length); + auto result_data = result_str.GetDataWriteable(); + + CaseConvert(input_data, input_length, result_data); + result_str.Finalize(); + return result_str; +} + +template +struct CaseConvertOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + return UnicodeCaseConvert(result, input_data, input_length); + } +}; + +template +static void CaseConvertFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::ExecuteString>(args.data[0], result, args.size()); +} + +template +struct CaseConvertOperatorASCII { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + return ASCIICaseConvert(result, input_data, input_length); + } +}; + +template +static void CaseConvertFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::ExecuteString>(args.data[0], result, + args.size()); +} + +template +static unique_ptr CaseConvertPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 1); + // can only propagate stats if the children have stats + if (!StringStats::CanContainUnicode(child_stats[0])) { + expr.function.function = CaseConvertFunctionASCII; + } + return nullptr; +} + +ScalarFunction LowerFun::GetFunction() { + return ScalarFunction("lower", {LogicalType::VARCHAR}, LogicalType::VARCHAR, CaseConvertFunction, nullptr, + nullptr, CaseConvertPropagateStats); +} + +void LowerFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction({"lower", "lcase"}, LowerFun::GetFunction()); +} + +void UpperFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction({"upper", "ucase"}, + ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, CaseConvertFunction, nullptr, + nullptr, CaseConvertPropagateStats)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/concat.cpp b/src/duckdb/src/function/scalar/string/concat.cpp new file mode 100644 index 00000000..fc3b4114 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/concat.cpp @@ -0,0 +1,269 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar/string_functions.hpp" + +#include + +namespace duckdb { + +static void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + // iterate over the vectors to count how large the final string will be + idx_t constant_lengths = 0; + vector result_lengths(args.size(), 0); + for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + auto &input = args.data[col_idx]; + D_ASSERT(input.GetType().id() == LogicalTypeId::VARCHAR); + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(input)) { + // constant null, skip + continue; + } + auto input_data = ConstantVector::GetData(input); + constant_lengths += input_data->GetSize(); + } else { + // non-constant vector: set the result type to a flat vector + result.SetVectorType(VectorType::FLAT_VECTOR); + // now get the lengths of each of the input elements + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(args.size(), vdata); + + auto input_data = UnifiedVectorFormat::GetData(vdata); + // now add the length of each vector to the result length + for (idx_t i = 0; i < args.size(); i++) { + auto idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(idx)) { + continue; + } + result_lengths[i] += input_data[idx].GetSize(); + } + } + } + + // first we allocate the empty strings for each of the values + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < args.size(); i++) { + // allocate an empty string of the required size + idx_t str_length = constant_lengths + result_lengths[i]; + result_data[i] = StringVector::EmptyString(result, str_length); + // we reuse the result_lengths vector to store the currently appended size + result_lengths[i] = 0; + } + + // now that the empty space for the strings has been allocated, perform the concatenation + for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + auto &input = args.data[col_idx]; + + // loop over the vector and concat to all results + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // constant vector + if (ConstantVector::IsNull(input)) { + // constant null, skip + continue; + } + // append the constant vector to each of the strings + auto input_data = ConstantVector::GetData(input); + auto input_ptr = input_data->GetData(); + auto input_len = input_data->GetSize(); + for (idx_t i = 0; i < args.size(); i++) { + memcpy(result_data[i].GetDataWriteable() + result_lengths[i], input_ptr, input_len); + result_lengths[i] += input_len; + } + } else { + // standard vector + UnifiedVectorFormat idata; + input.ToUnifiedFormat(args.size(), idata); + + auto input_data = UnifiedVectorFormat::GetData(idata); + for (idx_t i = 0; i < args.size(); i++) { + auto idx = idata.sel->get_index(i); + if (!idata.validity.RowIsValid(idx)) { + continue; + } + auto input_ptr = input_data[idx].GetData(); + auto input_len = input_data[idx].GetSize(); + memcpy(result_data[i].GetDataWriteable() + result_lengths[i], input_ptr, input_len); + result_lengths[i] += input_len; + } + } + } + for (idx_t i = 0; i < args.size(); i++) { + result_data[i].Finalize(); + } +} + +static void ConcatOperator(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t a, string_t b) { + auto a_data = a.GetData(); + auto b_data = b.GetData(); + auto a_length = a.GetSize(); + auto b_length = b.GetSize(); + + auto target_length = a_length + b_length; + auto target = StringVector::EmptyString(result, target_length); + auto target_data = target.GetDataWriteable(); + + memcpy(target_data, a_data, a_length); + memcpy(target_data + a_length, b_data, b_length); + target.Finalize(); + return target; + }); +} + +static void TemplatedConcatWS(DataChunk &args, const string_t *sep_data, const SelectionVector &sep_sel, + const SelectionVector &rsel, idx_t count, Vector &result) { + vector result_lengths(args.size(), 0); + vector has_results(args.size(), false); + auto orrified_data = make_unsafe_uniq_array(args.ColumnCount() - 1); + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + args.data[col_idx].ToUnifiedFormat(args.size(), orrified_data[col_idx - 1]); + } + + // first figure out the lengths + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + auto &idata = orrified_data[col_idx - 1]; + + auto input_data = UnifiedVectorFormat::GetData(idata); + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + auto sep_idx = sep_sel.get_index(ridx); + auto idx = idata.sel->get_index(ridx); + if (!idata.validity.RowIsValid(idx)) { + continue; + } + if (has_results[ridx]) { + result_lengths[ridx] += sep_data[sep_idx].GetSize(); + } + result_lengths[ridx] += input_data[idx].GetSize(); + has_results[ridx] = true; + } + } + + // first we allocate the empty strings for each of the values + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + // allocate an empty string of the required size + result_data[ridx] = StringVector::EmptyString(result, result_lengths[ridx]); + // we reuse the result_lengths vector to store the currently appended size + result_lengths[ridx] = 0; + has_results[ridx] = false; + } + + // now that the empty space for the strings has been allocated, perform the concatenation + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + auto &idata = orrified_data[col_idx - 1]; + auto input_data = UnifiedVectorFormat::GetData(idata); + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + auto sep_idx = sep_sel.get_index(ridx); + auto idx = idata.sel->get_index(ridx); + if (!idata.validity.RowIsValid(idx)) { + continue; + } + if (has_results[ridx]) { + auto sep_size = sep_data[sep_idx].GetSize(); + auto sep_ptr = sep_data[sep_idx].GetData(); + memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], sep_ptr, sep_size); + result_lengths[ridx] += sep_size; + } + auto input_ptr = input_data[idx].GetData(); + auto input_len = input_data[idx].GetSize(); + memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], input_ptr, input_len); + result_lengths[ridx] += input_len; + has_results[ridx] = true; + } + } + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + result_data[ridx].Finalize(); + } +} + +static void ConcatWSFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &separator = args.data[0]; + UnifiedVectorFormat vdata; + separator.ToUnifiedFormat(args.size(), vdata); + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::FLAT_VECTOR); + break; + } + } + switch (separator.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + if (ConstantVector::IsNull(separator)) { + // constant NULL as separator: return constant NULL vector + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + // no null values + auto sel = FlatVector::IncrementalSelectionVector(); + TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, *sel, args.size(), result); + return; + } + default: { + // default case: loop over nullmask and create a non-null selection vector + idx_t not_null_count = 0; + SelectionVector not_null_vector(STANDARD_VECTOR_SIZE); + auto &result_mask = FlatVector::Validity(result); + for (idx_t i = 0; i < args.size(); i++) { + if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) { + result_mask.SetInvalid(i); + } else { + not_null_vector.set_index(not_null_count++, i); + } + } + TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, not_null_vector, + not_null_count, result); + return; + } + } +} + +void ConcatFun::RegisterFunction(BuiltinFunctions &set) { + // the concat operator and concat function have different behavior regarding NULLs + // this is strange but seems consistent with postgresql and mysql + // (sqlite does not support the concat function, only the concat operator) + + // the concat operator behaves as one would expect: any NULL value present results in a NULL + // i.e. NULL || 'hello' = NULL + // the concat function, however, treats NULL values as an empty string + // i.e. concat(NULL, 'hello') = 'hello' + // concat_ws functions similarly to the concat function, except the result is NULL if the separator is NULL + // if the separator is not NULL, however, NULL values are counted as empty string + // there is one separate rule: there are no separators added between NULL values + // so the NULL value and empty string are different! + // e.g.: + // concat_ws(',', NULL, NULL) = "" + // concat_ws(',', '', '') = "," + ScalarFunction concat = ScalarFunction("concat", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ConcatFunction); + concat.varargs = LogicalType::VARCHAR; + concat.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + set.AddFunction(concat); + + ScalarFunctionSet concat_op("||"); + concat_op.AddFunction( + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, ConcatOperator)); + concat_op.AddFunction(ScalarFunction({LogicalType::BLOB, LogicalType::BLOB}, LogicalType::BLOB, ConcatOperator)); + concat_op.AddFunction(ListConcatFun::GetFunction()); + for (auto &fun : concat_op.functions) { + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + } + set.AddFunction(concat_op); + + ScalarFunction concat_ws = ScalarFunction("concat_ws", {LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, ConcatWSFunction); + concat_ws.varargs = LogicalType::VARCHAR; + concat_ws.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + set.AddFunction(concat_ws); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/contains.cpp b/src/duckdb/src/function/scalar/string/contains.cpp new file mode 100644 index 00000000..fb68e00e --- /dev/null +++ b/src/duckdb/src/function/scalar/string/contains.cpp @@ -0,0 +1,165 @@ +#include "duckdb/function/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +template +static idx_t ContainsUnaligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, + idx_t base_offset) { + if (NEEDLE_SIZE > haystack_size) { + // needle is bigger than haystack: haystack cannot contain needle + return DConstants::INVALID_INDEX; + } + // contains for a small unaligned needle (3/5/6/7 bytes) + // we perform unsigned integer comparisons to check for equality of the entire needle in a single comparison + // this implementation is inspired by the memmem implementation of freebsd + + // first we set up the needle and the first NEEDLE_SIZE characters of the haystack as UNSIGNED integers + UNSIGNED needle_entry = 0; + UNSIGNED haystack_entry = 0; + const UNSIGNED start = (sizeof(UNSIGNED) * 8) - 8; + const UNSIGNED shift = (sizeof(UNSIGNED) - NEEDLE_SIZE) * 8; + for (int i = 0; i < NEEDLE_SIZE; i++) { + needle_entry |= UNSIGNED(needle[i]) << UNSIGNED(start - i * 8); + haystack_entry |= UNSIGNED(haystack[i]) << UNSIGNED(start - i * 8); + } + // now we perform the actual search + for (idx_t offset = NEEDLE_SIZE; offset < haystack_size; offset++) { + // for this position we first compare the haystack with the needle + if (haystack_entry == needle_entry) { + return base_offset + offset - NEEDLE_SIZE; + } + // now we adjust the haystack entry by + // (1) removing the left-most character (shift by 8) + // (2) adding the next character (bitwise or, with potential shift) + // this shift is only necessary if the needle size is not aligned with the unsigned integer size + // (e.g. needle size 3, unsigned integer size 4, we need to shift by 1) + haystack_entry = (haystack_entry << 8) | ((UNSIGNED(haystack[offset])) << shift); + } + if (haystack_entry == needle_entry) { + return base_offset + haystack_size - NEEDLE_SIZE; + } + return DConstants::INVALID_INDEX; +} + +template +static idx_t ContainsAligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, + idx_t base_offset) { + if (sizeof(UNSIGNED) > haystack_size) { + // needle is bigger than haystack: haystack cannot contain needle + return DConstants::INVALID_INDEX; + } + // contains for a small needle aligned with unsigned integer (2/4/8) + // similar to ContainsUnaligned, but simpler because we only need to do a reinterpret cast + auto needle_entry = Load(needle); + for (idx_t offset = 0; offset <= haystack_size - sizeof(UNSIGNED); offset++) { + // for this position we first compare the haystack with the needle + auto haystack_entry = Load(haystack + offset); + if (needle_entry == haystack_entry) { + return base_offset + offset; + } + } + return DConstants::INVALID_INDEX; +} + +idx_t ContainsGeneric(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, + idx_t needle_size, idx_t base_offset) { + if (needle_size > haystack_size) { + // needle is bigger than haystack: haystack cannot contain needle + return DConstants::INVALID_INDEX; + } + // this implementation is inspired by Raphael Javaux's faststrstr (https://github.com/RaphaelJ/fast_strstr) + // generic contains; note that we can't use strstr because we don't have null-terminated strings anymore + // we keep track of a shifting window sum of all characters with window size equal to needle_size + // this shifting sum is used to avoid calling into memcmp; + // we only need to call into memcmp when the window sum is equal to the needle sum + // when that happens, the characters are potentially the same and we call into memcmp to check if they are + uint32_t sums_diff = 0; + for (idx_t i = 0; i < needle_size; i++) { + sums_diff += haystack[i]; + sums_diff -= needle[i]; + } + idx_t offset = 0; + while (true) { + if (sums_diff == 0 && haystack[offset] == needle[0]) { + if (memcmp(haystack + offset, needle, needle_size) == 0) { + return base_offset + offset; + } + } + if (offset >= haystack_size - needle_size) { + return DConstants::INVALID_INDEX; + } + sums_diff -= haystack[offset]; + sums_diff += haystack[offset + needle_size]; + offset++; + } +} + +idx_t ContainsFun::Find(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, + idx_t needle_size) { + D_ASSERT(needle_size > 0); + // start off by performing a memchr to find the first character of the + auto location = memchr(haystack, needle[0], haystack_size); + if (location == nullptr) { + return DConstants::INVALID_INDEX; + } + idx_t base_offset = const_uchar_ptr_cast(location) - haystack; + haystack_size -= base_offset; + haystack = const_uchar_ptr_cast(location); + // switch algorithm depending on needle size + switch (needle_size) { + case 1: + return base_offset; + case 2: + return ContainsAligned(haystack, haystack_size, needle, base_offset); + case 3: + return ContainsUnaligned(haystack, haystack_size, needle, base_offset); + case 4: + return ContainsAligned(haystack, haystack_size, needle, base_offset); + case 5: + return ContainsUnaligned(haystack, haystack_size, needle, base_offset); + case 6: + return ContainsUnaligned(haystack, haystack_size, needle, base_offset); + case 7: + return ContainsUnaligned(haystack, haystack_size, needle, base_offset); + case 8: + return ContainsAligned(haystack, haystack_size, needle, base_offset); + default: + return ContainsGeneric(haystack, haystack_size, needle, needle_size, base_offset); + } +} + +idx_t ContainsFun::Find(const string_t &haystack_s, const string_t &needle_s) { + auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); + auto haystack_size = haystack_s.GetSize(); + auto needle = const_uchar_ptr_cast(needle_s.GetData()); + auto needle_size = needle_s.GetSize(); + if (needle_size == 0) { + // empty needle: always true + return 0; + } + return ContainsFun::Find(haystack, haystack_size, needle, needle_size); +} + +struct ContainsOperator { + template + static inline TR Operation(TA left, TB right) { + return ContainsFun::Find(left, right) != DConstants::INVALID_INDEX; + } +}; + +ScalarFunction ContainsFun::GetFunction() { + return ScalarFunction("contains", // name of the function + {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list + LogicalType::BOOLEAN, // return type + ScalarFunction::BinaryFunction); +} + +void ContainsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(GetFunction()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/length.cpp b/src/duckdb/src/function/scalar/string/length.cpp new file mode 100644 index 00000000..d4fe021b --- /dev/null +++ b/src/duckdb/src/function/scalar/string/length.cpp @@ -0,0 +1,141 @@ +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/common/types/bit.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "utf8proc.hpp" + +namespace duckdb { + +// length returns the number of unicode codepoints +struct StringLengthOperator { + template + static inline TR Operation(TA input) { + return LengthFun::Length(input); + } +}; + +struct GraphemeCountOperator { + template + static inline TR Operation(TA input) { + return LengthFun::GraphemeCount(input); + } +}; + +struct ArrayLengthOperator { + template + static inline TR Operation(TA input) { + return input.length; + } +}; + +struct ArrayLengthBinaryOperator { + template + static inline TR Operation(TA input, TB dimension) { + if (dimension != 1) { + throw NotImplementedException("array_length for dimensions other than 1 not implemented"); + } + return input.length; + } +}; + +// strlen returns the size in bytes +struct StrLenOperator { + template + static inline TR Operation(TA input) { + return input.GetSize(); + } +}; + +struct OctetLenOperator { + template + static inline TR Operation(TA input) { + return Bit::OctetLength(input); + } +}; + +// bitlen returns the size in bits +struct BitLenOperator { + template + static inline TR Operation(TA input) { + return 8 * input.GetSize(); + } +}; + +// bitstringlen returns the amount of bits in a bitstring +struct BitStringLenOperator { + template + static inline TR Operation(TA input) { + return Bit::BitLength(input); + } +}; + +static unique_ptr LengthPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 1); + // can only propagate stats if the children have stats + if (!StringStats::CanContainUnicode(child_stats[0])) { + expr.function.function = ScalarFunction::UnaryFunction; + } + return nullptr; +} + +static unique_ptr ListLengthBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments[0]->HasParameter()) { + throw ParameterNotResolvedException(); + } + bound_function.arguments[0] = arguments[0]->return_type; + return nullptr; +} + +void LengthFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunction array_length_unary = + ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction, ListLengthBind); + ScalarFunctionSet length("length"); + length.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction, nullptr, + nullptr, LengthPropagateStats)); + length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + length.AddFunction(array_length_unary); + set.AddFunction(length); + length.name = "len"; + set.AddFunction(length); + + ScalarFunctionSet length_grapheme("length_grapheme"); + length_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction, + nullptr, nullptr, LengthPropagateStats)); + set.AddFunction(length_grapheme); + + ScalarFunctionSet array_length("array_length"); + array_length.AddFunction(array_length_unary); + array_length.AddFunction(ScalarFunction( + {LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction, ListLengthBind)); + set.AddFunction(array_length); + + set.AddFunction(ScalarFunction("strlen", {LogicalType::VARCHAR}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + ScalarFunctionSet bit_length("bit_length"); + bit_length.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + bit_length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + set.AddFunction(bit_length); + // length for BLOB type + ScalarFunctionSet octet_length("octet_length"); + octet_length.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + octet_length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + set.AddFunction(octet_length); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/like.cpp b/src/duckdb/src/function/scalar/string/like.cpp new file mode 100644 index 00000000..eb9f2e34 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/like.cpp @@ -0,0 +1,543 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +struct StandardCharacterReader { + static char Operation(const char *data, idx_t pos) { + return data[pos]; + } +}; + +struct ASCIILCaseReader { + static char Operation(const char *data, idx_t pos) { + return (char)LowerFun::ascii_to_lower_map[(uint8_t)data[pos]]; + } +}; + +template +bool TemplatedLikeOperator(const char *sdata, idx_t slen, const char *pdata, idx_t plen, char escape) { + idx_t pidx = 0; + idx_t sidx = 0; + for (; pidx < plen && sidx < slen; pidx++) { + char pchar = READER::Operation(pdata, pidx); + char schar = READER::Operation(sdata, sidx); + if (HAS_ESCAPE && pchar == escape) { + pidx++; + if (pidx == plen) { + throw SyntaxException("Like pattern must not end with escape character!"); + } + if (pdata[pidx] != schar) { + return false; + } + sidx++; + } else if (pchar == UNDERSCORE) { + sidx++; + } else if (pchar == PERCENTAGE) { + pidx++; + while (pidx < plen && pdata[pidx] == PERCENTAGE) { + pidx++; + } + if (pidx == plen) { + return true; /* tail is acceptable */ + } + for (; sidx < slen; sidx++) { + if (TemplatedLikeOperator( + sdata + sidx, slen - sidx, pdata + pidx, plen - pidx, escape)) { + return true; + } + } + return false; + } else if (pchar == schar) { + sidx++; + } else { + return false; + } + } + while (pidx < plen && pdata[pidx] == PERCENTAGE) { + pidx++; + } + return pidx == plen && sidx == slen; +} + +struct LikeSegment { + explicit LikeSegment(string pattern) : pattern(std::move(pattern)) { + } + + string pattern; +}; + +struct LikeMatcher : public FunctionData { + LikeMatcher(string like_pattern_p, vector segments, bool has_start_percentage, bool has_end_percentage) + : like_pattern(std::move(like_pattern_p)), segments(std::move(segments)), + has_start_percentage(has_start_percentage), has_end_percentage(has_end_percentage) { + } + + bool Match(string_t &str) { + auto str_data = const_uchar_ptr_cast(str.GetData()); + auto str_len = str.GetSize(); + idx_t segment_idx = 0; + idx_t end_idx = segments.size() - 1; + if (!has_start_percentage) { + // no start sample_size: match the first part of the string directly + auto &segment = segments[0]; + if (str_len < segment.pattern.size()) { + return false; + } + if (memcmp(str_data, segment.pattern.c_str(), segment.pattern.size()) != 0) { + return false; + } + str_data += segment.pattern.size(); + str_len -= segment.pattern.size(); + segment_idx++; + if (segments.size() == 1) { + // only one segment, and it matches + // we have a match if there is an end sample_size, OR if the memcmp was an exact match (remaining str is + // empty) + return has_end_percentage || str_len == 0; + } + } + // main match loop: for every segment in the middle, use Contains to find the needle in the haystack + for (; segment_idx < end_idx; segment_idx++) { + auto &segment = segments[segment_idx]; + // find the pattern of the current segment + idx_t next_offset = ContainsFun::Find(str_data, str_len, const_uchar_ptr_cast(segment.pattern.c_str()), + segment.pattern.size()); + if (next_offset == DConstants::INVALID_INDEX) { + // could not find this pattern in the string: no match + return false; + } + idx_t offset = next_offset + segment.pattern.size(); + str_data += offset; + str_len -= offset; + } + if (!has_end_percentage) { + end_idx--; + // no end sample_size: match the final segment now + auto &segment = segments.back(); + if (str_len < segment.pattern.size()) { + return false; + } + if (memcmp(str_data + str_len - segment.pattern.size(), segment.pattern.c_str(), segment.pattern.size()) != + 0) { + return false; + } + return true; + } else { + auto &segment = segments.back(); + // find the pattern of the current segment + idx_t next_offset = ContainsFun::Find(str_data, str_len, const_uchar_ptr_cast(segment.pattern.c_str()), + segment.pattern.size()); + return next_offset != DConstants::INVALID_INDEX; + } + } + + static unique_ptr CreateLikeMatcher(string like_pattern, char escape = '\0') { + vector segments; + idx_t last_non_pattern = 0; + bool has_start_percentage = false; + bool has_end_percentage = false; + for (idx_t i = 0; i < like_pattern.size(); i++) { + auto ch = like_pattern[i]; + if (ch == escape || ch == '%' || ch == '_') { + // special character, push a constant pattern + if (i > last_non_pattern) { + segments.emplace_back(like_pattern.substr(last_non_pattern, i - last_non_pattern)); + } + last_non_pattern = i + 1; + if (ch == escape || ch == '_') { + // escape or underscore: could not create efficient like matcher + // FIXME: we could handle escaped percentages here + return nullptr; + } else { + // sample_size + if (i == 0) { + has_start_percentage = true; + } + if (i + 1 == like_pattern.size()) { + has_end_percentage = true; + } + } + } + } + if (last_non_pattern < like_pattern.size()) { + segments.emplace_back(like_pattern.substr(last_non_pattern, like_pattern.size() - last_non_pattern)); + } + if (segments.empty()) { + return nullptr; + } + return make_uniq(std::move(like_pattern), std::move(segments), has_start_percentage, + has_end_percentage); + } + + unique_ptr Copy() const override { + return make_uniq(like_pattern, segments, has_start_percentage, has_end_percentage); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return like_pattern == other.like_pattern; + } + +private: + string like_pattern; + vector segments; + bool has_start_percentage; + bool has_end_percentage; +}; + +static unique_ptr LikeBindFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // pattern is the second argument. If its constant, we can already prepare the pattern and store it for later. + D_ASSERT(arguments.size() == 2 || arguments.size() == 3); + if (arguments[1]->IsFoldable()) { + Value pattern_str = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + return LikeMatcher::CreateLikeMatcher(pattern_str.ToString()); + } + return nullptr; +} + +bool LikeOperatorFunction(const char *s, idx_t slen, const char *pattern, idx_t plen, char escape) { + return TemplatedLikeOperator<'%', '_', true>(s, slen, pattern, plen, escape); +} + +bool LikeOperatorFunction(const char *s, idx_t slen, const char *pattern, idx_t plen) { + return TemplatedLikeOperator<'%', '_', false>(s, slen, pattern, plen, '\0'); +} + +bool LikeOperatorFunction(string_t &s, string_t &pat) { + return LikeOperatorFunction(s.GetData(), s.GetSize(), pat.GetData(), pat.GetSize()); +} + +bool LikeOperatorFunction(string_t &s, string_t &pat, char escape) { + return LikeOperatorFunction(s.GetData(), s.GetSize(), pat.GetData(), pat.GetSize(), escape); +} + +bool LikeFun::Glob(const char *string, idx_t slen, const char *pattern, idx_t plen, bool allow_question_mark) { + idx_t sidx = 0; + idx_t pidx = 0; +main_loop : { + // main matching loop + while (sidx < slen && pidx < plen) { + char s = string[sidx]; + char p = pattern[pidx]; + switch (p) { + case '*': { + // asterisk: match any set of characters + // skip any subsequent asterisks + pidx++; + while (pidx < plen && pattern[pidx] == '*') { + pidx++; + } + // if the asterisk is the last character, the pattern always matches + if (pidx == plen) { + return true; + } + // recursively match the remainder of the pattern + for (; sidx < slen; sidx++) { + if (LikeFun::Glob(string + sidx, slen - sidx, pattern + pidx, plen - pidx)) { + return true; + } + } + return false; + } + case '?': + // when enabled: matches anything but null + if (allow_question_mark) { + break; + } + DUCKDB_EXPLICIT_FALLTHROUGH; + case '[': + pidx++; + goto parse_bracket; + case '\\': + // escape character, next character needs to match literally + pidx++; + // check that we still have a character remaining + if (pidx == plen) { + return false; + } + p = pattern[pidx]; + if (s != p) { + return false; + } + break; + default: + // not a control character: characters need to match literally + if (s != p) { + return false; + } + break; + } + sidx++; + pidx++; + } + while (pidx < plen && pattern[pidx] == '*') { + pidx++; + } + // we are finished only if we have consumed the full pattern + return pidx == plen && sidx == slen; +} +parse_bracket : { + // inside a bracket + if (pidx == plen) { + return false; + } + // check the first character + // if it is an exclamation mark we need to invert our logic + char p = pattern[pidx]; + char s = string[sidx]; + bool invert = false; + if (p == '!') { + invert = true; + pidx++; + } + bool found_match = invert; + idx_t start_pos = pidx; + bool found_closing_bracket = false; + // now check the remainder of the pattern + while (pidx < plen) { + p = pattern[pidx]; + // if the first character is a closing bracket, we match it literally + // otherwise it indicates an end of bracket + if (p == ']' && pidx > start_pos) { + // end of bracket found: we are done + found_closing_bracket = true; + pidx++; + break; + } + // we either match a range (a-b) or a single character (a) + // check if the next character is a dash + if (pidx + 1 == plen) { + // no next character! + break; + } + bool matches; + if (pattern[pidx + 1] == '-') { + // range! find the next character in the range + if (pidx + 2 == plen) { + break; + } + char next_char = pattern[pidx + 2]; + // check if the current character is within the range + matches = s >= p && s <= next_char; + // shift the pattern forward past the range + pidx += 3; + } else { + // no range! perform a direct match + matches = p == s; + // shift the pattern forward past the character + pidx++; + } + if (found_match == invert && matches) { + // found a match! set the found_matches flag + // we keep on pattern matching after this until we reach the end bracket + // however, we don't need to update the found_match flag anymore + found_match = !invert; + } + } + if (!found_closing_bracket) { + // no end of bracket: invalid pattern + return false; + } + if (!found_match) { + // did not match the bracket: return false; + return false; + } + // finished the bracket matching: move forward + sidx++; + goto main_loop; +} +} + +static char GetEscapeChar(string_t escape) { + // Only one escape character should be allowed + if (escape.GetSize() > 1) { + throw SyntaxException("Invalid escape string. Escape string must be empty or one character."); + } + return escape.GetSize() == 0 ? '\0' : *escape.GetData(); +} + +struct LikeEscapeOperator { + template + static inline bool Operation(TA str, TB pattern, TC escape) { + char escape_char = GetEscapeChar(escape); + return LikeOperatorFunction(str.GetData(), str.GetSize(), pattern.GetData(), pattern.GetSize(), escape_char); + } +}; + +struct NotLikeEscapeOperator { + template + static inline bool Operation(TA str, TB pattern, TC escape) { + return !LikeEscapeOperator::Operation(str, pattern, escape); + } +}; + +struct LikeOperator { + template + static inline TR Operation(TA str, TB pattern) { + return LikeOperatorFunction(str, pattern); + } +}; + +bool ILikeOperatorFunction(string_t &str, string_t &pattern, char escape = '\0') { + auto str_data = str.GetData(); + auto str_size = str.GetSize(); + auto pat_data = pattern.GetData(); + auto pat_size = pattern.GetSize(); + + // lowercase both the str and the pattern + idx_t str_llength = LowerFun::LowerLength(str_data, str_size); + auto str_ldata = make_unsafe_uniq_array(str_llength); + LowerFun::LowerCase(str_data, str_size, str_ldata.get()); + + idx_t pat_llength = LowerFun::LowerLength(pat_data, pat_size); + auto pat_ldata = make_unsafe_uniq_array(pat_llength); + LowerFun::LowerCase(pat_data, pat_size, pat_ldata.get()); + string_t str_lcase(str_ldata.get(), str_llength); + string_t pat_lcase(pat_ldata.get(), pat_llength); + return LikeOperatorFunction(str_lcase, pat_lcase, escape); +} + +struct ILikeEscapeOperator { + template + static inline bool Operation(TA str, TB pattern, TC escape) { + char escape_char = GetEscapeChar(escape); + return ILikeOperatorFunction(str, pattern, escape_char); + } +}; + +struct NotILikeEscapeOperator { + template + static inline bool Operation(TA str, TB pattern, TC escape) { + return !ILikeEscapeOperator::Operation(str, pattern, escape); + } +}; + +struct ILikeOperator { + template + static inline TR Operation(TA str, TB pattern) { + return ILikeOperatorFunction(str, pattern); + } +}; + +struct NotLikeOperator { + template + static inline TR Operation(TA str, TB pattern) { + return !LikeOperatorFunction(str, pattern); + } +}; + +struct NotILikeOperator { + template + static inline TR Operation(TA str, TB pattern) { + return !ILikeOperator::Operation(str, pattern); + } +}; + +struct ILikeOperatorASCII { + template + static inline TR Operation(TA str, TB pattern) { + return TemplatedLikeOperator<'%', '_', false, ASCIILCaseReader>(str.GetData(), str.GetSize(), pattern.GetData(), + pattern.GetSize(), '\0'); + } +}; + +struct NotILikeOperatorASCII { + template + static inline TR Operation(TA str, TB pattern) { + return !ILikeOperatorASCII::Operation(str, pattern); + } +}; + +struct GlobOperator { + template + static inline TR Operation(TA str, TB pattern) { + return LikeFun::Glob(str.GetData(), str.GetSize(), pattern.GetData(), pattern.GetSize()); + } +}; + +// This can be moved to the scalar_function class +template +static void LikeEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str = args.data[0]; + auto &pattern = args.data[1]; + auto &escape = args.data[2]; + + TernaryExecutor::Execute( + str, pattern, escape, result, args.size(), FUNC::template Operation); +} + +template +static unique_ptr ILikePropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() >= 1); + // can only propagate stats if the children have stats + if (!StringStats::CanContainUnicode(child_stats[0])) { + expr.function.function = ScalarFunction::BinaryFunction; + } + return nullptr; +} + +template +static void RegularLikeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + if (func_expr.bind_info) { + auto &matcher = func_expr.bind_info->Cast(); + // use fast like matcher + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](string_t input) { + return INVERT ? !matcher.Match(input) : matcher.Match(input); + }); + } else { + // use generic like matcher + BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, + input.size()); + } +} +void LikeFun::RegisterFunction(BuiltinFunctions &set) { + // like + set.AddFunction(GetLikeFunction()); + // not like + set.AddFunction(ScalarFunction("!~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + RegularLikeFunction, LikeBindFunction)); + // glob + set.AddFunction(ScalarFunction("~~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + ScalarFunction::BinaryFunction)); + // ilike + set.AddFunction(ScalarFunction("~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + ScalarFunction::BinaryFunction, nullptr, + nullptr, ILikePropagateStats)); + // not ilike + set.AddFunction(ScalarFunction("!~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + ScalarFunction::BinaryFunction, nullptr, + nullptr, ILikePropagateStats)); +} + +ScalarFunction LikeFun::GetLikeFunction() { + return ScalarFunction("~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + RegularLikeFunction, LikeBindFunction); +} + +void LikeEscapeFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(GetLikeEscapeFun()); + set.AddFunction({"not_like_escape"}, + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::BOOLEAN, LikeEscapeFunction)); + + set.AddFunction({"ilike_escape"}, ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::BOOLEAN, LikeEscapeFunction)); + set.AddFunction({"not_ilike_escape"}, + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::BOOLEAN, LikeEscapeFunction)); +} + +ScalarFunction LikeEscapeFun::GetLikeEscapeFun() { + return ScalarFunction("like_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::BOOLEAN, LikeEscapeFunction); +} +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/nfc_normalize.cpp b/src/duckdb/src/function/scalar/string/nfc_normalize.cpp new file mode 100644 index 00000000..28984335 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/nfc_normalize.cpp @@ -0,0 +1,38 @@ +#include "duckdb/function/scalar/string_functions.hpp" + +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +struct NFCNormalizeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + if (StripAccentsFun::IsAscii(input_data, input_length)) { + return input; + } + auto normalized_str = Utf8Proc::Normalize(input_data, input_length); + D_ASSERT(normalized_str); + auto result_str = StringVector::AddString(result, normalized_str); + free(normalized_str); + return result_str; + } +}; + +static void NFCNormalizeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + StringVector::AddHeapReference(result, args.data[0]); +} + +ScalarFunction NFCNormalizeFun::GetFunction() { + return ScalarFunction("nfc_normalize", {LogicalType::VARCHAR}, LogicalType::VARCHAR, NFCNormalizeFunction); +} + +void NFCNormalizeFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(NFCNormalizeFun::GetFunction()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/prefix.cpp b/src/duckdb/src/function/scalar/string/prefix.cpp new file mode 100644 index 00000000..d15c1e02 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/prefix.cpp @@ -0,0 +1,72 @@ +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/common/types/string_type.hpp" + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +static bool PrefixFunction(const string_t &str, const string_t &pattern); + +struct PrefixOperator { + template + static inline TR Operation(TA left, TB right) { + return PrefixFunction(left, right); + } +}; +static bool PrefixFunction(const string_t &str, const string_t &pattern) { + auto str_length = str.GetSize(); + auto patt_length = pattern.GetSize(); + if (patt_length > str_length) { + return false; + } + if (patt_length <= string_t::PREFIX_LENGTH) { + // short prefix + if (patt_length == 0) { + // length = 0, return true + return true; + } + + // prefix early out + const char *str_pref = str.GetPrefix(); + const char *patt_pref = pattern.GetPrefix(); + for (idx_t i = 0; i < patt_length; ++i) { + if (str_pref[i] != patt_pref[i]) { + return false; + } + } + return true; + } else { + // prefix early out + const char *str_pref = str.GetPrefix(); + const char *patt_pref = pattern.GetPrefix(); + for (idx_t i = 0; i < string_t::PREFIX_LENGTH; ++i) { + if (str_pref[i] != patt_pref[i]) { + // early out + return false; + } + } + // compare the rest of the prefix + const char *str_data = str.GetData(); + const char *patt_data = pattern.GetData(); + D_ASSERT(patt_length <= str_length); + for (idx_t i = string_t::PREFIX_LENGTH; i < patt_length; ++i) { + if (str_data[i] != patt_data[i]) { + return false; + } + } + return true; + } +} + +ScalarFunction PrefixFun::GetFunction() { + return ScalarFunction("prefix", // name of the function + {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list + LogicalType::BOOLEAN, // return type + ScalarFunction::BinaryFunction); +} + +void PrefixFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(GetFunction()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp.cpp b/src/duckdb/src/function/scalar/string/regexp.cpp new file mode 100644 index 00000000..d14ac7ac --- /dev/null +++ b/src/duckdb/src/function/scalar/string/regexp.cpp @@ -0,0 +1,458 @@ +#include "duckdb/function/scalar/regexp.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +using regexp_util::CreateStringPiece; +using regexp_util::Extract; +using regexp_util::ParseRegexOptions; +using regexp_util::TryParseConstantPattern; + +static bool RegexOptionsEquals(const duckdb_re2::RE2::Options &opt_a, const duckdb_re2::RE2::Options &opt_b) { + return opt_a.case_sensitive() == opt_b.case_sensitive(); +} + +RegexpBaseBindData::RegexpBaseBindData() : constant_pattern(false) { +} +RegexpBaseBindData::RegexpBaseBindData(duckdb_re2::RE2::Options options, string constant_string_p, + bool constant_pattern) + : options(options), constant_string(std::move(constant_string_p)), constant_pattern(constant_pattern) { +} + +RegexpBaseBindData::~RegexpBaseBindData() { +} + +bool RegexpBaseBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return constant_pattern == other.constant_pattern && constant_string == other.constant_string && + RegexOptionsEquals(options, other.options); +} + +unique_ptr RegexInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + auto &info = bind_data->Cast(); + if (info.constant_pattern) { + return make_uniq(info); + } + return nullptr; +} + +//===--------------------------------------------------------------------===// +// Regexp Matches +//===--------------------------------------------------------------------===// +RegexpMatchesBindData::RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string_p, + bool constant_pattern) + : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern) { + if (constant_pattern) { + auto pattern = make_uniq(constant_string, options); + if (!pattern->ok()) { + throw Exception(pattern->error()); + } + + range_success = pattern->PossibleMatchRange(&range_min, &range_max, 1000); + } else { + range_success = false; + } +} + +RegexpMatchesBindData::RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string_p, + bool constant_pattern, string range_min_p, string range_max_p, + bool range_success) + : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), range_min(std::move(range_min_p)), + range_max(std::move(range_max_p)), range_success(range_success) { +} + +unique_ptr RegexpMatchesBindData::Copy() const { + return make_uniq(options, constant_string, constant_pattern, range_min, range_max, + range_success); +} + +unique_ptr RegexpMatchesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // pattern is the second argument. If its constant, we can already prepare the pattern and store it for later. + D_ASSERT(arguments.size() == 2 || arguments.size() == 3); + RE2::Options options; + options.set_log_errors(false); + if (arguments.size() == 3) { + ParseRegexOptions(context, *arguments[2], options); + } + + string constant_string; + bool constant_pattern; + constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); + return make_uniq(options, std::move(constant_string), constant_pattern); +} + +struct RegexPartialMatch { + static inline bool Operation(const duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re) { + return duckdb_re2::RE2::PartialMatch(input, re); + } +}; + +struct RegexFullMatch { + static inline bool Operation(const duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re) { + return duckdb_re2::RE2::FullMatch(input, re); + } +}; + +template +static void RegexpMatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &strings = args.data[0]; + auto &patterns = args.data[1]; + + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + if (info.constant_pattern) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { + return OP::Operation(CreateStringPiece(input), lstate.constant_pattern); + }); + } else { + BinaryExecutor::Execute(strings, patterns, result, args.size(), + [&](string_t input, string_t pattern) { + RE2 re(CreateStringPiece(pattern), info.options); + if (!re.ok()) { + throw Exception(re.error()); + } + return OP::Operation(CreateStringPiece(input), re); + }); + } +} + +//===--------------------------------------------------------------------===// +// Regexp Replace +//===--------------------------------------------------------------------===// +RegexpReplaceBindData::RegexpReplaceBindData() : global_replace(false) { +} + +RegexpReplaceBindData::RegexpReplaceBindData(duckdb_re2::RE2::Options options, string constant_string_p, + bool constant_pattern, bool global_replace) + : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), global_replace(global_replace) { +} + +unique_ptr RegexpReplaceBindData::Copy() const { + auto copy = make_uniq(options, constant_string, constant_pattern, global_replace); + return std::move(copy); +} + +bool RegexpReplaceBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return RegexpBaseBindData::Equals(other) && global_replace == other.global_replace; +} + +static unique_ptr RegexReplaceBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto data = make_uniq(); + + data->constant_pattern = TryParseConstantPattern(context, *arguments[1], data->constant_string); + if (arguments.size() == 4) { + ParseRegexOptions(context, *arguments[3], data->options, &data->global_replace); + } + data->options.set_log_errors(false); + return std::move(data); +} + +static void RegexReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + auto &strings = args.data[0]; + auto &patterns = args.data[1]; + auto &replaces = args.data[2]; + + if (info.constant_pattern) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + BinaryExecutor::Execute( + strings, replaces, result, args.size(), [&](string_t input, string_t replace) { + std::string sstring = input.GetString(); + if (info.global_replace) { + RE2::GlobalReplace(&sstring, lstate.constant_pattern, CreateStringPiece(replace)); + } else { + RE2::Replace(&sstring, lstate.constant_pattern, CreateStringPiece(replace)); + } + return StringVector::AddString(result, sstring); + }); + } else { + TernaryExecutor::Execute( + strings, patterns, replaces, result, args.size(), [&](string_t input, string_t pattern, string_t replace) { + RE2 re(CreateStringPiece(pattern), info.options); + std::string sstring = input.GetString(); + if (info.global_replace) { + RE2::GlobalReplace(&sstring, re, CreateStringPiece(replace)); + } else { + RE2::Replace(&sstring, re, CreateStringPiece(replace)); + } + return StringVector::AddString(result, sstring); + }); + } +} + +//===--------------------------------------------------------------------===// +// Regexp Extract +//===--------------------------------------------------------------------===// +RegexpExtractBindData::RegexpExtractBindData() { +} + +RegexpExtractBindData::RegexpExtractBindData(duckdb_re2::RE2::Options options, string constant_string_p, + bool constant_pattern, string group_string_p) + : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), + group_string(std::move(group_string_p)), rewrite(group_string) { +} + +unique_ptr RegexpExtractBindData::Copy() const { + return make_uniq(options, constant_string, constant_pattern, group_string); +} + +bool RegexpExtractBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return RegexpBaseBindData::Equals(other) && group_string == other.group_string; +} + +static void RegexExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + const auto &info = func_expr.bind_info->Cast(); + + auto &strings = args.data[0]; + auto &patterns = args.data[1]; + if (info.constant_pattern) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { + return Extract(input, result, lstate.constant_pattern, info.rewrite); + }); + } else { + BinaryExecutor::Execute(strings, patterns, result, args.size(), + [&](string_t input, string_t pattern) { + RE2 re(CreateStringPiece(pattern), info.options); + return Extract(input, result, re, info.rewrite); + }); + } +} + +//===--------------------------------------------------------------------===// +// Regexp Extract Struct +//===--------------------------------------------------------------------===// +static void RegexExtractStructFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + + const auto count = args.size(); + auto &input = args.data[0]; + + auto &child_entries = StructVector::GetEntries(result); + const auto groupSize = child_entries.size(); + // Reference the 'input' StringBuffer, because we won't need to allocate new data + // for the result, all returned strings are substrings of the originals + for (auto &child_entry : child_entries) { + child_entry->SetAuxiliary(input.GetAuxiliary()); + } + + vector argv(groupSize); + vector groups(groupSize); + vector ws(groupSize); + for (size_t i = 0; i < groupSize; ++i) { + groups[i] = &argv[i]; + argv[i] = &ws[i]; + } + + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + if (ConstantVector::IsNull(input)) { + ConstantVector::SetNull(result, true); + } else { + ConstantVector::SetNull(result, false); + auto idata = ConstantVector::GetData(input); + auto str = CreateStringPiece(idata[0]); + auto match = duckdb_re2::RE2::PartialMatchN(str, lstate.constant_pattern, groups.data(), groups.size()); + for (size_t col = 0; col < child_entries.size(); ++col) { + auto &child_entry = child_entries[col]; + ConstantVector::SetNull(*child_entry, false); + auto &extracted = ws[col]; + auto cdata = ConstantVector::GetData(*child_entry); + cdata[0] = string_t(extracted.data(), match ? extracted.size() : 0); + } + } + } else { + UnifiedVectorFormat iunified; + input.ToUnifiedFormat(count, iunified); + + const auto &ivalidity = iunified.validity; + auto idata = UnifiedVectorFormat::GetData(iunified); + + // Start with a valid flat vector + result.SetVectorType(VectorType::FLAT_VECTOR); + + // Start with valid children + for (size_t col = 0; col < child_entries.size(); ++col) { + auto &child_entry = child_entries[col]; + child_entry->SetVectorType(VectorType::FLAT_VECTOR); + } + + for (idx_t i = 0; i < count; ++i) { + const auto idx = iunified.sel->get_index(i); + if (ivalidity.RowIsValid(idx)) { + auto str = CreateStringPiece(idata[idx]); + auto match = duckdb_re2::RE2::PartialMatchN(str, lstate.constant_pattern, groups.data(), groups.size()); + for (size_t col = 0; col < child_entries.size(); ++col) { + auto &child_entry = child_entries[col]; + auto cdata = FlatVector::GetData(*child_entry); + auto &extracted = ws[col]; + cdata[i] = string_t(extracted.data(), match ? extracted.size() : 0); + } + } else { + FlatVector::SetNull(result, i, true); + } + } + } +} + +static unique_ptr RegexExtractBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(arguments.size() >= 2); + + duckdb_re2::RE2::Options options; + + string constant_string; + bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); + + if (arguments.size() >= 4) { + ParseRegexOptions(context, *arguments[3], options); + } + + string group_string = "\\0"; + if (arguments.size() >= 3) { + if (arguments[2]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[2]->IsFoldable()) { + throw InvalidInputException("Group specification field must be a constant!"); + } + Value group = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); + if (group.IsNull()) { + group_string = ""; + } else if (group.type().id() == LogicalTypeId::LIST) { + if (!constant_pattern) { + throw BinderException("%s with LIST requires a constant pattern", bound_function.name); + } + auto &list_children = ListValue::GetChildren(group); + if (list_children.empty()) { + throw BinderException("%s requires non-empty lists of capture names", bound_function.name); + } + case_insensitive_set_t name_collision_set; + child_list_t struct_children; + for (const auto &child : list_children) { + if (child.IsNull()) { + throw BinderException("NULL group name in %s", bound_function.name); + } + const auto group_name = child.ToString(); + if (name_collision_set.find(group_name) != name_collision_set.end()) { + throw BinderException("Duplicate group name \"%s\" in %s", group_name, bound_function.name); + } + name_collision_set.insert(group_name); + struct_children.emplace_back(make_pair(group_name, LogicalType::VARCHAR)); + } + bound_function.return_type = LogicalType::STRUCT(struct_children); + + duckdb_re2::StringPiece constant_piece(constant_string.c_str(), constant_string.size()); + RE2 constant_pattern(constant_piece, options); + if (size_t(constant_pattern.NumberOfCapturingGroups()) < list_children.size()) { + throw BinderException("Not enough group names in %s", bound_function.name); + } + } else { + auto group_idx = group.GetValue(); + if (group_idx < 0 || group_idx > 9) { + throw InvalidInputException("Group index must be between 0 and 9!"); + } + group_string = "\\" + to_string(group_idx); + } + } + + return make_uniq(options, std::move(constant_string), constant_pattern, + std::move(group_string)); +} + +void RegexpFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunctionSet regexp_full_match("regexp_full_match"); + regexp_full_match.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegexpMatchesFunction, + RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + regexp_full_match.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, + LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + + ScalarFunctionSet regexp_partial_match("regexp_matches"); + regexp_partial_match.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegexpMatchesFunction, + RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + regexp_partial_match.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, + LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + + ScalarFunctionSet regexp_replace("regexp_replace"); + regexp_replace.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, RegexReplaceFunction, RegexReplaceBind, nullptr, + nullptr, RegexInitLocalState)); + regexp_replace.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, + RegexReplaceFunction, RegexReplaceBind, nullptr, nullptr, RegexInitLocalState)); + + ScalarFunctionSet regexp_extract("regexp_extract"); + regexp_extract.AddFunction( + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, RegexExtractFunction, + RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + regexp_extract.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::VARCHAR, RegexExtractFunction, + RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + regexp_extract.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, + RegexExtractFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + // REGEXP_EXTRACT(, , [[, ]...]) + regexp_extract.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR)}, LogicalType::VARCHAR, + RegexExtractStructFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + // REGEXP_EXTRACT(, , [[, ]...], ) + regexp_extract.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR), LogicalType::VARCHAR}, + LogicalType::VARCHAR, RegexExtractStructFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, + LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + + ScalarFunctionSet regexp_extract_all("regexp_extract_all"); + regexp_extract_all.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::LIST(LogicalType::VARCHAR), + RegexpExtractAll::Execute, RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, + LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + regexp_extract_all.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::LIST(LogicalType::VARCHAR), + RegexpExtractAll::Execute, RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, + LogicalType::INVALID, FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + regexp_extract_all.AddFunction( + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::VARCHAR), RegexpExtractAll::Execute, RegexpExtractAll::Bind, + nullptr, nullptr, RegexpExtractAll::InitLocalState, LogicalType::INVALID, + FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling::SPECIAL_HANDLING)); + + set.AddFunction(regexp_full_match); + set.AddFunction(regexp_partial_match); + set.AddFunction(regexp_replace); + set.AddFunction(regexp_extract); + set.AddFunction(regexp_extract_all); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp new file mode 100644 index 00000000..98b50a79 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp @@ -0,0 +1,243 @@ +#include "duckdb/function/scalar/regexp.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "re2/re2.h" + +namespace duckdb { + +using regexp_util::CreateStringPiece; +using regexp_util::Extract; +using regexp_util::ParseRegexOptions; +using regexp_util::TryParseConstantPattern; + +unique_ptr +RegexpExtractAll::InitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { + auto &info = bind_data->Cast(); + if (info.constant_pattern) { + return make_uniq(info, true); + } + return nullptr; +} + +// Forwards startpos automatically +bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t *startpos, + duckdb_re2::StringPiece *groups, int ngroups) { + + D_ASSERT(pattern.ok()); + D_ASSERT(pattern.NumberOfCapturingGroups() == ngroups); + + if (!pattern.Match(input, *startpos, input.size(), pattern.Anchored(), groups, ngroups + 1)) { + return false; + } + idx_t consumed = static_cast(groups[0].end() - (input.begin() + *startpos)); + if (!consumed) { + // Empty match found, have to manually forward the input + // to avoid an infinite loop + // FIXME: support unicode characters + consumed++; + while (*startpos + consumed < input.length() && !LengthFun::IsCharacter(input[*startpos + consumed])) { + consumed++; + } + } + *startpos += consumed; + return true; +} + +void ExtractSingleTuple(const string_t &string, duckdb_re2::RE2 &pattern, int32_t group, RegexStringPieceArgs &args, + Vector &result, idx_t row) { + auto input = CreateStringPiece(string); + + auto &child_vector = ListVector::GetEntry(result); + auto list_content = FlatVector::GetData(child_vector); + auto &child_validity = FlatVector::Validity(child_vector); + + auto current_list_size = ListVector::GetListSize(result); + auto current_list_capacity = ListVector::GetListCapacity(result); + + auto result_data = FlatVector::GetData(result); + auto &list_entry = result_data[row]; + list_entry.offset = current_list_size; + + if (group < 0) { + list_entry.length = 0; + return; + } + // If the requested group index is out of bounds + // we want to throw only if there is a match + bool throw_on_group_found = (idx_t)group > args.size; + + idx_t startpos = 0; + for (idx_t iteration = 0; ExtractAll(input, pattern, &startpos, args.group_buffer, args.size); iteration++) { + if (!iteration && throw_on_group_found) { + throw InvalidInputException("Pattern has %d groups. Cannot access group %d", args.size, group); + } + + // Make sure we have enough room for the new entries + if (current_list_size + 1 >= current_list_capacity) { + ListVector::Reserve(result, current_list_capacity * 2); + current_list_capacity = ListVector::GetListCapacity(result); + list_content = FlatVector::GetData(child_vector); + } + + // Write the captured groups into the list-child vector + auto &match_group = args.group_buffer[group]; + + idx_t child_idx = current_list_size; + if (match_group.empty()) { + // This group was not matched + list_content[child_idx] = string_t(string.GetData(), 0); + if (match_group.begin() == nullptr) { + // This group is optional + child_validity.SetInvalid(child_idx); + } + } else { + // Every group is a substring of the original, we can find out the offset using the pointer + // the 'match_group' address is guaranteed to be bigger than that of the source + D_ASSERT(const_char_ptr_cast(match_group.begin()) >= string.GetData()); + idx_t offset = match_group.begin() - string.GetData(); + list_content[child_idx] = string_t(string.GetData() + offset, match_group.size()); + } + current_list_size++; + if (startpos > input.size()) { + // Empty match found at the end of the string + break; + } + } + list_entry.length = current_list_size - list_entry.offset; + ListVector::SetListSize(result, current_list_size); +} + +int32_t GetGroupIndex(DataChunk &args, idx_t row, int32_t &result) { + if (args.ColumnCount() < 3) { + result = 0; + return true; + } + UnifiedVectorFormat format; + args.data[2].ToUnifiedFormat(args.size(), format); + idx_t index = format.sel->get_index(row); + if (!format.validity.RowIsValid(index)) { + return false; + } + result = UnifiedVectorFormat::GetData(format)[index]; + return true; +} + +duckdb_re2::RE2 &GetPattern(const RegexpBaseBindData &info, ExpressionState &state, + unique_ptr &pattern_p) { + if (info.constant_pattern) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + return lstate.constant_pattern; + } + D_ASSERT(pattern_p); + return *pattern_p; +} + +RegexStringPieceArgs &GetGroupsBuffer(const RegexpBaseBindData &info, ExpressionState &state, + unique_ptr &groups_p) { + if (info.constant_pattern) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + return lstate.group_buffer; + } + D_ASSERT(groups_p); + return *groups_p; +} + +void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + const auto &info = func_expr.bind_info->Cast(); + + auto &strings = args.data[0]; + auto &patterns = args.data[1]; + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto &output_child = ListVector::GetEntry(result); + + UnifiedVectorFormat strings_data; + strings.ToUnifiedFormat(args.size(), strings_data); + + UnifiedVectorFormat pattern_data; + patterns.ToUnifiedFormat(args.size(), pattern_data); + + ListVector::Reserve(result, STANDARD_VECTOR_SIZE); + // Reference the 'strings' StringBuffer, because we won't need to allocate new data + // for the result, all returned strings are substrings of the originals + output_child.SetAuxiliary(strings.GetAuxiliary()); + + // Avoid doing extra work if all the inputs are constant + idx_t tuple_count = args.AllConstant() ? 1 : args.size(); + + unique_ptr non_const_args; + unique_ptr stored_re; + if (!info.constant_pattern) { + non_const_args = make_uniq(); + } else { + // Verify that the constant pattern is valid + auto &re = GetPattern(info, state, stored_re); + auto group_count_p = re.NumberOfCapturingGroups(); + if (group_count_p == -1) { + throw InvalidInputException("Pattern failed to parse, error: '%s'", re.error()); + } + } + + for (idx_t row = 0; row < tuple_count; row++) { + bool pattern_valid = true; + if (!info.constant_pattern) { + // Check if the pattern is NULL or not, + // and compile the pattern if it's not constant + auto pattern_idx = pattern_data.sel->get_index(row); + if (!pattern_data.validity.RowIsValid(pattern_idx)) { + pattern_valid = false; + } else { + auto &pattern_p = UnifiedVectorFormat::GetData(pattern_data)[pattern_idx]; + auto pattern_strpiece = CreateStringPiece(pattern_p); + stored_re = make_uniq(pattern_strpiece, info.options); + + // Increase the size of the args buffer if needed + auto group_count_p = stored_re->NumberOfCapturingGroups(); + if (group_count_p == -1) { + throw InvalidInputException("Pattern failed to parse, error: '%s'", stored_re->error()); + } + non_const_args->SetSize(group_count_p); + } + } + + auto string_idx = strings_data.sel->get_index(row); + int32_t group_index; + if (!pattern_valid || !strings_data.validity.RowIsValid(string_idx) || !GetGroupIndex(args, row, group_index)) { + // If something is NULL, the result is NULL + // FIXME: do we even need 'SPECIAL_HANDLING'? + auto result_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + result_data[row].length = 0; + result_data[row].offset = ListVector::GetListSize(result); + result_validity.SetInvalid(row); + continue; + } + + auto &re = GetPattern(info, state, stored_re); + auto &groups = GetGroupsBuffer(info, state, non_const_args); + auto &string = UnifiedVectorFormat::GetData(strings_data)[string_idx]; + ExtractSingleTuple(string, re, group_index, groups, result, row); + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +unique_ptr RegexpExtractAll::Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(arguments.size() >= 2); + + duckdb_re2::RE2::Options options; + + string constant_string; + bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); + + if (arguments.size() >= 4) { + ParseRegexOptions(context, *arguments[3], options); + } + return make_uniq(options, std::move(constant_string), constant_pattern, ""); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp new file mode 100644 index 00000000..4e485195 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp @@ -0,0 +1,83 @@ +#include "duckdb/function/scalar/regexp.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +namespace regexp_util { + +bool TryParseConstantPattern(ClientContext &context, Expression &expr, string &constant_string) { + if (!expr.IsFoldable()) { + return false; + } + Value pattern_str = ExpressionExecutor::EvaluateScalar(context, expr); + if (!pattern_str.IsNull() && pattern_str.type().id() == LogicalTypeId::VARCHAR) { + constant_string = StringValue::Get(pattern_str); + return true; + } + return false; +} + +void ParseRegexOptions(const string &options, duckdb_re2::RE2::Options &result, bool *global_replace) { + for (idx_t i = 0; i < options.size(); i++) { + switch (options[i]) { + case 'c': + // case-sensitive matching + result.set_case_sensitive(true); + break; + case 'i': + // case-insensitive matching + result.set_case_sensitive(false); + break; + case 'l': + // literal matching + result.set_literal(true); + break; + case 'm': + case 'n': + case 'p': + // newline-sensitive matching + result.set_dot_nl(false); + break; + case 's': + // non-newline-sensitive matching + result.set_dot_nl(true); + break; + case 'g': + // global replace, only available for regexp_replace + if (global_replace) { + *global_replace = true; + } else { + throw InvalidInputException("Option 'g' (global replace) is only valid for regexp_replace"); + } + break; + case ' ': + case '\t': + case '\n': + // ignore whitespace + break; + default: + throw InvalidInputException("Unrecognized Regex option %c", options[i]); + } + } +} + +void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &target, bool *global_replace) { + if (expr.HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!expr.IsFoldable()) { + throw InvalidInputException("Regex options field must be a constant"); + } + Value options_str = ExpressionExecutor::EvaluateScalar(context, expr); + if (options_str.IsNull()) { + throw InvalidInputException("Regex options field must not be NULL"); + } + if (options_str.type().id() != LogicalTypeId::VARCHAR) { + throw InvalidInputException("Regex options field must be a string"); + } + ParseRegexOptions(StringValue::Get(options_str), target, global_replace); +} + +} // namespace regexp_util + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/strip_accents.cpp b/src/duckdb/src/function/scalar/string/strip_accents.cpp new file mode 100644 index 00000000..758c7264 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/strip_accents.cpp @@ -0,0 +1,47 @@ +#include "duckdb/function/scalar/string_functions.hpp" + +#include "utf8proc.hpp" + +namespace duckdb { + +bool StripAccentsFun::IsAscii(const char *input, idx_t n) { + for (idx_t i = 0; i < n; i++) { + if (input[i] & 0x80) { + // non-ascii character + return false; + } + } + return true; +} + +struct StripAccentsOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + if (StripAccentsFun::IsAscii(input.GetData(), input.GetSize())) { + return input; + } + + // non-ascii, perform collation + auto stripped = utf8proc_remove_accents((const utf8proc_uint8_t *)input.GetData(), input.GetSize()); + auto result_str = StringVector::AddString(result, const_char_ptr_cast(stripped)); + free(stripped); + return result_str; + } +}; + +static void StripAccentsFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + StringVector::AddHeapReference(result, args.data[0]); +} + +ScalarFunction StripAccentsFun::GetFunction() { + return ScalarFunction("strip_accents", {LogicalType::VARCHAR}, LogicalType::VARCHAR, StripAccentsFunction); +} + +void StripAccentsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(StripAccentsFun::GetFunction()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/substring.cpp b/src/duckdb/src/function/scalar/string/substring.cpp new file mode 100644 index 00000000..b0b2d816 --- /dev/null +++ b/src/duckdb/src/function/scalar/string/substring.cpp @@ -0,0 +1,339 @@ +#include "duckdb/function/scalar/string_functions.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" + +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "utf8proc.hpp" +#include "duckdb/common/types/blob.hpp" + +namespace duckdb { + +static const int64_t SUPPORTED_UPPER_BOUND = NumericLimits::Maximum(); +static const int64_t SUPPORTED_LOWER_BOUND = -SUPPORTED_UPPER_BOUND - 1; + +static inline void AssertInSupportedRange(idx_t input_size, int64_t offset, int64_t length) { + + if (input_size > (uint64_t)SUPPORTED_UPPER_BOUND) { + throw OutOfRangeException("Substring input size is too large (> %d)", SUPPORTED_UPPER_BOUND); + } + if (offset < SUPPORTED_LOWER_BOUND) { + throw OutOfRangeException("Substring offset outside of supported range (< %d)", SUPPORTED_LOWER_BOUND); + } + if (offset > SUPPORTED_UPPER_BOUND) { + throw OutOfRangeException("Substring offset outside of supported range (> %d)", SUPPORTED_UPPER_BOUND); + } + if (length < SUPPORTED_LOWER_BOUND) { + throw OutOfRangeException("Substring length outside of supported range (< %d)", SUPPORTED_LOWER_BOUND); + } + if (length > SUPPORTED_UPPER_BOUND) { + throw OutOfRangeException("Substring length outside of supported range (> %d)", SUPPORTED_UPPER_BOUND); + } +} + +string_t SubstringEmptyString(Vector &result) { + auto result_string = StringVector::EmptyString(result, 0); + result_string.Finalize(); + return result_string; +} + +string_t SubstringSlice(Vector &result, const char *input_data, int64_t offset, int64_t length) { + auto result_string = StringVector::EmptyString(result, length); + auto result_data = result_string.GetDataWriteable(); + memcpy(result_data, input_data + offset, length); + result_string.Finalize(); + return result_string; +} + +// compute start and end characters from the given input size and offset/length +bool SubstringStartEnd(int64_t input_size, int64_t offset, int64_t length, int64_t &start, int64_t &end) { + if (length == 0) { + return false; + } + if (offset > 0) { + // positive offset: scan from start + start = MinValue(input_size, offset - 1); + } else if (offset < 0) { + // negative offset: scan from end (i.e. start = end + offset) + start = MaxValue(input_size + offset, 0); + } else { + // offset = 0: special case, we start 1 character BEHIND the first character + start = 0; + length--; + if (length <= 0) { + return false; + } + } + if (length > 0) { + // positive length: go forward (i.e. end = start + offset) + end = MinValue(input_size, start + length); + } else { + // negative length: go backwards (i.e. end = start, start = start + length) + end = start; + start = MaxValue(0, start + length); + } + if (start == end) { + return false; + } + D_ASSERT(start < end); + return true; +} + +string_t SubstringASCII(Vector &result, string_t input, int64_t offset, int64_t length) { + auto input_data = input.GetData(); + auto input_size = input.GetSize(); + + AssertInSupportedRange(input_size, offset, length); + + int64_t start, end; + if (!SubstringStartEnd(input_size, offset, length, start, end)) { + return SubstringEmptyString(result); + } + return SubstringSlice(result, input_data, start, end - start); +} + +string_t SubstringFun::SubstringUnicode(Vector &result, string_t input, int64_t offset, int64_t length) { + auto input_data = input.GetData(); + auto input_size = input.GetSize(); + + AssertInSupportedRange(input_size, offset, length); + + if (length == 0) { + return SubstringEmptyString(result); + } + // first figure out which direction we need to scan + idx_t start_pos; + idx_t end_pos; + if (offset < 0) { + start_pos = 0; + end_pos = DConstants::INVALID_INDEX; + + // negative offset: scan backwards + int64_t start, end; + + // we express start and end as unicode codepoints from the back + offset--; + if (length < 0) { + // negative length + start = -offset - length; + end = -offset; + } else { + // positive length + start = -offset; + end = -offset - length; + } + if (end <= 0) { + end_pos = input_size; + } + int64_t current_character = 0; + for (idx_t i = input_size; i > 0; i--) { + if (LengthFun::IsCharacter(input_data[i - 1])) { + current_character++; + if (current_character == start) { + start_pos = i; + break; + } else if (current_character == end) { + end_pos = i; + } + } + } + while (!LengthFun::IsCharacter(input_data[start_pos])) { + start_pos++; + } + while (end_pos < input_size && !LengthFun::IsCharacter(input_data[end_pos])) { + end_pos++; + } + + if (end_pos == DConstants::INVALID_INDEX) { + return SubstringEmptyString(result); + } + } else { + start_pos = DConstants::INVALID_INDEX; + end_pos = input_size; + + // positive offset: scan forwards + int64_t start, end; + + // we express start and end as unicode codepoints from the front + offset--; + if (length < 0) { + // negative length + start = MaxValue(0, offset + length); + end = offset; + } else { + // positive length + start = MaxValue(0, offset); + end = offset + length; + } + + int64_t current_character = 0; + for (idx_t i = 0; i < input_size; i++) { + if (LengthFun::IsCharacter(input_data[i])) { + if (current_character == start) { + start_pos = i; + } else if (current_character == end) { + end_pos = i; + break; + } + current_character++; + } + } + if (start_pos == DConstants::INVALID_INDEX || end == 0 || end <= start) { + return SubstringEmptyString(result); + } + } + D_ASSERT(end_pos >= start_pos); + // after we have found these, we can slice the substring + return SubstringSlice(result, input_data, start_pos, end_pos - start_pos); +} + +string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t offset, int64_t length) { + auto input_data = input.GetData(); + auto input_size = input.GetSize(); + + AssertInSupportedRange(input_size, offset, length); + + // we don't know yet if the substring is ascii, but we assume it is (for now) + // first get the start and end as if this was an ascii string + int64_t start, end; + if (!SubstringStartEnd(input_size, offset, length, start, end)) { + return SubstringEmptyString(result); + } + + // now check if all the characters between 0 and end are ascii characters + // note that we scan one further to check for a potential combining diacritics (e.g. i + diacritic is ï) + bool is_ascii = true; + idx_t ascii_end = MinValue(end + 1, input_size); + for (idx_t i = 0; i < ascii_end; i++) { + if (input_data[i] & 0x80) { + // found a non-ascii character: eek + is_ascii = false; + break; + } + } + if (is_ascii) { + // all characters are ascii, we can just slice the substring + return SubstringSlice(result, input_data, start, end - start); + } + // if the characters are not ascii, we need to scan grapheme clusters + // first figure out which direction we need to scan + // offset = 0 case is taken care of in SubstringStartEnd + if (offset < 0) { + // negative offset, this case is more difficult + // we first need to count the number of characters in the string + idx_t num_characters = 0; + utf8proc_grapheme_callback(input_data, input_size, [&](size_t start, size_t end) { + num_characters++; + return true; + }); + // now call substring start and end again, but with the number of unicode characters this time + SubstringStartEnd(num_characters, offset, length, start, end); + } + + // now scan the graphemes of the string to find the positions of the start and end characters + int64_t current_character = 0; + idx_t start_pos = DConstants::INVALID_INDEX, end_pos = input_size; + utf8proc_grapheme_callback(input_data, input_size, [&](size_t gstart, size_t gend) { + if (current_character == start) { + start_pos = gstart; + } else if (current_character == end) { + end_pos = gstart; + return false; + } + current_character++; + return true; + }); + if (start_pos == DConstants::INVALID_INDEX) { + return SubstringEmptyString(result); + } + // after we have found these, we can slice the substring + return SubstringSlice(result, input_data, start_pos, end_pos - start_pos); +} + +struct SubstringUnicodeOp { + static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { + return SubstringFun::SubstringUnicode(result, input, offset, length); + } +}; + +struct SubstringGraphemeOp { + static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { + return SubstringFun::SubstringGrapheme(result, input, offset, length); + } +}; + +template +static void SubstringFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input_vector = args.data[0]; + auto &offset_vector = args.data[1]; + if (args.ColumnCount() == 3) { + auto &length_vector = args.data[2]; + + TernaryExecutor::Execute( + input_vector, offset_vector, length_vector, result, args.size(), + [&](string_t input_string, int64_t offset, int64_t length) { + return OP::Substring(result, input_string, offset, length); + }); + } else { + BinaryExecutor::Execute( + input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { + return OP::Substring(result, input_string, offset, NumericLimits::Maximum()); + }); + } +} + +static void SubstringFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input_vector = args.data[0]; + auto &offset_vector = args.data[1]; + if (args.ColumnCount() == 3) { + auto &length_vector = args.data[2]; + + TernaryExecutor::Execute( + input_vector, offset_vector, length_vector, result, args.size(), + [&](string_t input_string, int64_t offset, int64_t length) { + return SubstringASCII(result, input_string, offset, length); + }); + } else { + BinaryExecutor::Execute( + input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { + return SubstringASCII(result, input_string, offset, NumericLimits::Maximum()); + }); + } +} + +static unique_ptr SubstringPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + // can only propagate stats if the children have stats + // we only care about the stats of the first child (i.e. the string) + if (!StringStats::CanContainUnicode(child_stats[0])) { + expr.function.function = SubstringFunctionASCII; + } + return nullptr; +} + +void SubstringFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunctionSet substr("substring"); + substr.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::VARCHAR, SubstringFunction, nullptr, nullptr, + SubstringPropagateStats)); + substr.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, + SubstringFunction, nullptr, nullptr, + SubstringPropagateStats)); + set.AddFunction(substr); + substr.name = "substr"; + set.AddFunction(substr); + + ScalarFunctionSet substr_grapheme("substring_grapheme"); + substr_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::VARCHAR, SubstringFunction, nullptr, + nullptr, SubstringPropagateStats)); + substr_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, + SubstringFunction, nullptr, nullptr, + SubstringPropagateStats)); + set.AddFunction(substr_grapheme); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/suffix.cpp b/src/duckdb/src/function/scalar/string/suffix.cpp new file mode 100644 index 00000000..f497333a --- /dev/null +++ b/src/duckdb/src/function/scalar/string/suffix.cpp @@ -0,0 +1,47 @@ +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/common/types/string_type.hpp" + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +static bool SuffixFunction(const string_t &str, const string_t &suffix); + +struct SuffixOperator { + template + static inline TR Operation(TA left, TB right) { + return SuffixFunction(left, right); + } +}; + +static bool SuffixFunction(const string_t &str, const string_t &suffix) { + auto suffix_size = suffix.GetSize(); + auto str_size = str.GetSize(); + if (suffix_size > str_size) { + return false; + } + + auto suffix_data = suffix.GetData(); + auto str_data = str.GetData(); + int32_t suf_idx = suffix_size - 1; + idx_t str_idx = str_size - 1; + for (; suf_idx >= 0; --suf_idx, --str_idx) { + if (suffix_data[suf_idx] != str_data[str_idx]) { + return false; + } + } + return true; +} + +ScalarFunction SuffixFun::GetFunction() { + return ScalarFunction("suffix", // name of the function + {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list + LogicalType::BOOLEAN, // return type + ScalarFunction::BinaryFunction); +} + +void SuffixFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction({"suffix", "ends_with"}, GetFunction()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string_functions.cpp b/src/duckdb/src/function/scalar/string_functions.cpp new file mode 100644 index 00000000..88d7b716 --- /dev/null +++ b/src/duckdb/src/function/scalar/string_functions.cpp @@ -0,0 +1,21 @@ +#include "duckdb/function/scalar/string_functions.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterStringFunctions() { + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/struct/struct_extract.cpp b/src/duckdb/src/function/scalar/struct/struct_extract.cpp new file mode 100644 index 00000000..f9498275 --- /dev/null +++ b/src/duckdb/src/function/scalar/struct/struct_extract.cpp @@ -0,0 +1,122 @@ +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" + +namespace duckdb { + +struct StructExtractBindData : public FunctionData { + StructExtractBindData(string key, idx_t index, LogicalType type) + : key(std::move(key)), index(index), type(std::move(type)) { + } + + string key; + idx_t index; + LogicalType type; + +public: + unique_ptr Copy() const override { + return make_uniq(key, index, type); + } + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return key == other.key && index == other.index && type == other.type; + } +}; + +static void StructExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + // this should be guaranteed by the binder + auto &vec = args.data[0]; + + vec.Verify(args.size()); + auto &children = StructVector::GetEntries(vec); + D_ASSERT(info.index < children.size()); + auto &struct_child = children[info.index]; + result.Reference(*struct_child); + result.Verify(args.size()); +} + +static unique_ptr StructExtractBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2); + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + D_ASSERT(LogicalTypeId::STRUCT == arguments[0]->return_type.id()); + auto &struct_children = StructType::GetChildTypes(arguments[0]->return_type); + if (struct_children.empty()) { + throw InternalException("Can't extract something from an empty struct"); + } + bound_function.arguments[0] = arguments[0]->return_type; + + auto &key_child = arguments[1]; + if (key_child->HasParameter()) { + throw ParameterNotResolvedException(); + } + + if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { + throw BinderException("Key name for struct_extract needs to be a constant string"); + } + Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); + D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); + auto &key_str = StringValue::Get(key_val); + if (key_val.IsNull() || key_str.empty()) { + throw BinderException("Key name for struct_extract needs to be neither NULL nor empty"); + } + string key = StringUtil::Lower(key_str); + + LogicalType return_type; + idx_t key_index = 0; + bool found_key = false; + + for (size_t i = 0; i < struct_children.size(); i++) { + auto &child = struct_children[i]; + if (StringUtil::Lower(child.first) == key) { + found_key = true; + key_index = i; + return_type = child.second; + break; + } + } + + if (!found_key) { + vector candidates; + candidates.reserve(struct_children.size()); + for (auto &struct_child : struct_children) { + candidates.push_back(struct_child.first); + } + auto closest_settings = StringUtil::TopNLevenshtein(candidates, key); + auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); + throw BinderException("Could not find key \"%s\" in struct\n%s", key, message); + } + + bound_function.return_type = return_type; + return make_uniq(std::move(key), key_index, std::move(return_type)); +} + +static unique_ptr PropagateStructExtractStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &bind_data = input.bind_data; + + auto &info = bind_data->Cast(); + auto struct_child_stats = StructStats::GetChildStats(child_stats[0]); + return struct_child_stats[info.index].ToUnique(); +} + +ScalarFunction StructExtractFun::GetFunction() { + return ScalarFunction("struct_extract", {LogicalTypeId::STRUCT, LogicalType::VARCHAR}, LogicalType::ANY, + StructExtractFunction, StructExtractBind, nullptr, PropagateStructExtractStats); +} + +void StructExtractFun::RegisterFunction(BuiltinFunctions &set) { + // the arguments and return types are actually set in the binder function + auto fun = GetFunction(); + set.AddFunction(fun); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp new file mode 100644 index 00000000..e7125538 --- /dev/null +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -0,0 +1,365 @@ +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +// aggregate state export +struct ExportAggregateBindData : public FunctionData { + AggregateFunction aggr; + idx_t state_size; + + explicit ExportAggregateBindData(AggregateFunction aggr_p, idx_t state_size_p) + : aggr(std::move(aggr_p)), state_size(state_size_p) { + } + + unique_ptr Copy() const override { + return make_uniq(aggr, state_size); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return aggr == other.aggr && state_size == other.state_size; + } + + static ExportAggregateBindData &GetFrom(ExpressionState &state) { + auto &func_expr = state.expr.Cast(); + return func_expr.bind_info->Cast(); + } +}; + +struct CombineState : public FunctionLocalState { + idx_t state_size; + + unsafe_unique_array state_buffer0, state_buffer1; + Vector state_vector0, state_vector1; + + ArenaAllocator allocator; + + explicit CombineState(idx_t state_size_p) + : state_size(state_size_p), state_buffer0(make_unsafe_uniq_array(state_size_p)), + state_buffer1(make_unsafe_uniq_array(state_size_p)), + state_vector0(Value::POINTER(CastPointerToValue(state_buffer0.get()))), + state_vector1(Value::POINTER(CastPointerToValue(state_buffer1.get()))), + allocator(Allocator::DefaultAllocator()) { + } +}; + +static unique_ptr InitCombineState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + return make_uniq(bind_data.state_size); +} + +struct FinalizeState : public FunctionLocalState { + idx_t state_size; + unsafe_unique_array state_buffer; + Vector addresses; + + ArenaAllocator allocator; + + explicit FinalizeState(idx_t state_size_p) + : state_size(state_size_p), + state_buffer(make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * AlignValue(state_size_p))), + addresses(LogicalType::POINTER), allocator(Allocator::DefaultAllocator()) { + } +}; + +static unique_ptr InitFinalizeState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + return make_uniq(bind_data.state_size); +} + +static void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector &result) { + auto &bind_data = ExportAggregateBindData::GetFrom(state_p); + auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); + local_state.allocator.Reset(); + + D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); + D_ASSERT(input.data.size() == 1); + D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); + auto aligned_state_size = AlignValue(bind_data.state_size); + + auto state_vec_ptr = FlatVector::GetData(local_state.addresses); + + UnifiedVectorFormat state_data; + input.data[0].ToUnifiedFormat(input.size(), state_data); + for (idx_t i = 0; i < input.size(); i++) { + auto state_idx = state_data.sel->get_index(i); + auto state_entry = UnifiedVectorFormat::GetData(state_data) + state_idx; + auto target_ptr = char_ptr_cast(local_state.state_buffer.get()) + aligned_state_size * i; + + if (state_data.validity.RowIsValid(state_idx)) { + D_ASSERT(state_entry->GetSize() == bind_data.state_size); + memcpy((void *)target_ptr, state_entry->GetData(), bind_data.state_size); + } else { + // create a dummy state because finalize does not understand NULLs in its input + // we put the NULL back in explicitly below + bind_data.aggr.initialize(data_ptr_cast(target_ptr)); + } + state_vec_ptr[i] = data_ptr_cast(target_ptr); + } + + AggregateInputData aggr_input_data(nullptr, local_state.allocator); + bind_data.aggr.finalize(local_state.addresses, aggr_input_data, result, input.size(), 0); + + for (idx_t i = 0; i < input.size(); i++) { + auto state_idx = state_data.sel->get_index(i); + if (!state_data.validity.RowIsValid(state_idx)) { + FlatVector::SetNull(result, i, true); + } + } +} + +static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &result) { + auto &bind_data = ExportAggregateBindData::GetFrom(state_p); + auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); + local_state.allocator.Reset(); + + D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); + + D_ASSERT(input.data.size() == 2); + D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); + D_ASSERT(input.data[0].GetType() == result.GetType()); + + if (input.data[0].GetType().InternalType() != input.data[1].GetType().InternalType()) { + throw IOException("Aggregate state combine type mismatch, expect %s, got %s", + input.data[0].GetType().ToString(), input.data[1].GetType().ToString()); + } + + UnifiedVectorFormat state0_data, state1_data; + input.data[0].ToUnifiedFormat(input.size(), state0_data); + input.data[1].ToUnifiedFormat(input.size(), state1_data); + + auto result_ptr = FlatVector::GetData(result); + + for (idx_t i = 0; i < input.size(); i++) { + auto state0_idx = state0_data.sel->get_index(i); + auto state1_idx = state1_data.sel->get_index(i); + + auto &state0 = UnifiedVectorFormat::GetData(state0_data)[state0_idx]; + auto &state1 = UnifiedVectorFormat::GetData(state1_data)[state1_idx]; + + // if both are NULL, we return NULL. If either of them is not, the result is that one + if (!state0_data.validity.RowIsValid(state0_idx) && !state1_data.validity.RowIsValid(state1_idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + if (state0_data.validity.RowIsValid(state0_idx) && !state1_data.validity.RowIsValid(state1_idx)) { + result_ptr[i] = + StringVector::AddStringOrBlob(result, const_char_ptr_cast(state0.GetData()), bind_data.state_size); + continue; + } + if (!state0_data.validity.RowIsValid(state0_idx) && state1_data.validity.RowIsValid(state1_idx)) { + result_ptr[i] = + StringVector::AddStringOrBlob(result, const_char_ptr_cast(state1.GetData()), bind_data.state_size); + continue; + } + + // we actually have to combine + if (state0.GetSize() != bind_data.state_size || state1.GetSize() != bind_data.state_size) { + throw IOException("Aggregate state size mismatch, expect %llu, got %llu and %llu", bind_data.state_size, + state0.GetSize(), state1.GetSize()); + } + + memcpy(local_state.state_buffer0.get(), state0.GetData(), bind_data.state_size); + memcpy(local_state.state_buffer1.get(), state1.GetData(), bind_data.state_size); + + AggregateInputData aggr_input_data(nullptr, local_state.allocator); + bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); + + result_ptr[i] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(local_state.state_buffer1.get()), + bind_data.state_size); + } +} + +static unique_ptr BindAggregateState(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + // grab the aggregate type and bind the aggregate again + + // the aggregate name and types are in the logical type of the aggregate state, make sure its sane + auto &arg_return_type = arguments[0]->return_type; + for (auto &arg_type : bound_function.arguments) { + arg_type = arg_return_type; + } + + if (arg_return_type.id() != LogicalTypeId::AGGREGATE_STATE) { + throw BinderException("Can only FINALIZE aggregate state, not %s", arg_return_type.ToString()); + } + // combine + if (arguments.size() == 2 && arguments[0]->return_type != arguments[1]->return_type && + arguments[1]->return_type.id() != LogicalTypeId::BLOB) { + throw BinderException("Cannot COMBINE aggregate states from different functions, %s <> %s", + arguments[0]->return_type.ToString(), arguments[1]->return_type.ToString()); + } + + // following error states are only reachable when someone messes up creating the state_type + // which is impossible from SQL + + auto state_type = AggregateStateType::GetStateType(arg_return_type); + + // now we can look up the function in the catalog again and bind it + auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, + DEFAULT_SCHEMA, state_type.function_name); + if (func.type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { + throw InternalException("Could not find aggregate %s", state_type.function_name); + } + auto &aggr = func.Cast(); + + string error; + + FunctionBinder function_binder(context); + idx_t best_function = + function_binder.BindFunction(aggr.name, aggr.functions, state_type.bound_argument_types, error); + if (best_function == DConstants::INVALID_INDEX) { + throw InternalException("Could not re-bind exported aggregate %s: %s", state_type.function_name, error); + } + auto bound_aggr = aggr.functions.GetFunctionByOffset(best_function); + if (bound_aggr.bind) { + // FIXME: this is really hacky + // but the aggregate state export needs a rework around how it handles more complex aggregates anyway + vector> args; + args.reserve(state_type.bound_argument_types.size()); + for (auto &arg_type : state_type.bound_argument_types) { + args.push_back(make_uniq(Value(arg_type))); + } + auto bind_info = bound_aggr.bind(context, bound_aggr, args); + if (bind_info) { + throw BinderException("Aggregate function with bind info not supported yet in aggregate state export"); + } + } + + if (bound_aggr.return_type != state_type.return_type || bound_aggr.arguments != state_type.bound_argument_types) { + throw InternalException("Type mismatch for exported aggregate %s", state_type.function_name); + } + + if (bound_function.name == "finalize") { + bound_function.return_type = bound_aggr.return_type; + } else { + D_ASSERT(bound_function.name == "combine"); + bound_function.return_type = arg_return_type; + } + + return make_uniq(bound_aggr, bound_aggr.state_size()); +} + +static void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset) { + D_ASSERT(offset == 0); + auto &bind_data = aggr_input_data.bind_data->Cast(); + auto state_size = bind_data.aggregate->function.state_size(); + auto blob_ptr = FlatVector::GetData(result); + auto addresses_ptr = FlatVector::GetData(state); + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + auto data_ptr = addresses_ptr[row_idx]; + blob_ptr[row_idx] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(data_ptr), state_size); + } +} + +ExportAggregateFunctionBindData::ExportAggregateFunctionBindData(unique_ptr aggregate_p) { + D_ASSERT(aggregate_p->type == ExpressionType::BOUND_AGGREGATE); + aggregate = unique_ptr_cast(std::move(aggregate_p)); +} + +unique_ptr ExportAggregateFunctionBindData::Copy() const { + return make_uniq(aggregate->Copy()); +} + +bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return aggregate->Equals(*other.aggregate); +} + +static void ExportStateAggregateSerialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + throw NotImplementedException("FIXME: export state serialize"); +} + +static unique_ptr ExportStateAggregateDeserialize(Deserializer &deserializer, + AggregateFunction &function) { + throw NotImplementedException("FIXME: export state deserialize"); +} + +static void ExportStateScalarSerialize(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + throw NotImplementedException("FIXME: export state serialize"); +} + +static unique_ptr ExportStateScalarDeserialize(Deserializer &deserializer, ScalarFunction &function) { + throw NotImplementedException("FIXME: export state deserialize"); +} + +unique_ptr +ExportAggregateFunction::Bind(unique_ptr child_aggregate) { + auto &bound_function = child_aggregate->function; + if (!bound_function.combine) { + throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name); + } + if (bound_function.bind) { + throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom binders"); + } + if (bound_function.destructor) { + throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors"); + } + // this should be required + D_ASSERT(bound_function.state_size); + D_ASSERT(bound_function.finalize); + + D_ASSERT(child_aggregate->function.return_type.id() != LogicalTypeId::INVALID); +#ifdef DEBUG + for (auto &arg_type : child_aggregate->function.arguments) { + D_ASSERT(arg_type.id() != LogicalTypeId::INVALID); + } +#endif + auto export_bind_data = make_uniq(child_aggregate->Copy()); + aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.return_type, + child_aggregate->function.arguments); + auto return_type = LogicalType::AGGREGATE_STATE(std::move(state_type)); + + auto export_function = + AggregateFunction("aggregate_state_export_" + bound_function.name, bound_function.arguments, return_type, + bound_function.state_size, bound_function.initialize, bound_function.update, + bound_function.combine, ExportAggregateFinalize, bound_function.simple_update, + /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, + /* can't propagate statistics */ nullptr, nullptr); + export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + export_function.serialize = ExportStateAggregateSerialize; + export_function.deserialize = ExportStateAggregateDeserialize; + + return make_uniq(export_function, std::move(child_aggregate->children), + std::move(child_aggregate->filter), std::move(export_bind_data), + child_aggregate->aggr_type); +} + +ScalarFunction ExportAggregateFunction::GetFinalize() { + auto result = ScalarFunction("finalize", {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::INVALID, + AggregateStateFinalize, BindAggregateState, nullptr, nullptr, InitFinalizeState); + result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + result.serialize = ExportStateScalarSerialize; + result.deserialize = ExportStateScalarDeserialize; + return result; +} + +ScalarFunction ExportAggregateFunction::GetCombine() { + auto result = + ScalarFunction("combine", {LogicalTypeId::AGGREGATE_STATE, LogicalTypeId::ANY}, LogicalTypeId::AGGREGATE_STATE, + AggregateStateCombine, BindAggregateState, nullptr, nullptr, InitCombineState); + result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + result.serialize = ExportStateScalarSerialize; + result.deserialize = ExportStateScalarDeserialize; + return result; +} + +void ExportAggregateFunction::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(ExportAggregateFunction::GetCombine()); + set.AddFunction(ExportAggregateFunction::GetFinalize()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar_function.cpp b/src/duckdb/src/function/scalar_function.cpp new file mode 100644 index 00000000..250eee74 --- /dev/null +++ b/src/duckdb/src/function/scalar_function.cpp @@ -0,0 +1,65 @@ +#include "duckdb/function/scalar_function.hpp" + +namespace duckdb { + +FunctionLocalState::~FunctionLocalState() { +} + +ScalarFunction::ScalarFunction(string name, vector arguments, LogicalType return_type, + scalar_function_t function, bind_scalar_function_t bind, + dependency_function_t dependency, function_statistics_t statistics, + init_local_state_t init_local_state, LogicalType varargs, + FunctionSideEffects side_effects, FunctionNullHandling null_handling) + : BaseScalarFunction(std::move(name), std::move(arguments), std::move(return_type), side_effects, + std::move(varargs), null_handling), + function(std::move(function)), bind(bind), init_local_state(init_local_state), dependency(dependency), + statistics(statistics), serialize(nullptr), deserialize(nullptr) { +} + +ScalarFunction::ScalarFunction(vector arguments, LogicalType return_type, scalar_function_t function, + bind_scalar_function_t bind, dependency_function_t dependency, + function_statistics_t statistics, init_local_state_t init_local_state, + LogicalType varargs, FunctionSideEffects side_effects, + FunctionNullHandling null_handling) + : ScalarFunction(string(), std::move(arguments), std::move(return_type), std::move(function), bind, dependency, + statistics, init_local_state, std::move(varargs), side_effects, null_handling) { +} + +bool ScalarFunction::operator==(const ScalarFunction &rhs) const { + return name == rhs.name && arguments == rhs.arguments && return_type == rhs.return_type && varargs == rhs.varargs && + bind == rhs.bind && dependency == rhs.dependency && statistics == rhs.statistics; +} + +bool ScalarFunction::operator!=(const ScalarFunction &rhs) const { + return !(*this == rhs); +} + +bool ScalarFunction::Equal(const ScalarFunction &rhs) const { + // number of types + if (this->arguments.size() != rhs.arguments.size()) { + return false; + } + // argument types + for (idx_t i = 0; i < this->arguments.size(); ++i) { + if (this->arguments[i] != rhs.arguments[i]) { + return false; + } + } + // return type + if (this->return_type != rhs.return_type) { + return false; + } + // varargs + if (this->varargs != rhs.varargs) { + return false; + } + + return true; // they are equal +} + +void ScalarFunction::NopFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() >= 1); + result.Reference(input.data[0]); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar_macro_function.cpp b/src/duckdb/src/function/scalar_macro_function.cpp new file mode 100644 index 00000000..07f7b788 --- /dev/null +++ b/src/duckdb/src/function/scalar_macro_function.cpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar_macro_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/function/scalar_macro_function.hpp" + +#include "duckdb/function/macro_function.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" + +namespace duckdb { + +ScalarMacroFunction::ScalarMacroFunction(unique_ptr expression) + : MacroFunction(MacroType::SCALAR_MACRO), expression(std::move(expression)) { +} + +ScalarMacroFunction::ScalarMacroFunction(void) : MacroFunction(MacroType::SCALAR_MACRO) { +} + +unique_ptr ScalarMacroFunction::Copy() const { + auto result = make_uniq(); + result->expression = expression->Copy(); + CopyProperties(*result); + + return std::move(result); +} + +void RemoveQualificationRecursive(unique_ptr &expr) { + if (expr->GetExpressionType() == ExpressionType::COLUMN_REF) { + auto &col_ref = expr->Cast(); + auto &col_names = col_ref.column_names; + if (col_names.size() == 2 && col_names[0].find(DummyBinding::DUMMY_NAME) != string::npos) { + col_names.erase(col_names.begin()); + } + } else { + ParsedExpressionIterator::EnumerateChildren( + *expr, [](unique_ptr &child) { RemoveQualificationRecursive(child); }); + } +} + +string ScalarMacroFunction::ToSQL(const string &schema, const string &name) const { + // In case of nested macro's we need to fix it a bit + auto expression_copy = expression->Copy(); + RemoveQualificationRecursive(expression_copy); + return MacroFunction::ToSQL(schema, name) + StringUtil::Format("(%s);", expression_copy->ToString()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow.cpp b/src/duckdb/src/function/table/arrow.cpp new file mode 100644 index 00000000..9601e41d --- /dev/null +++ b/src/duckdb/src/function/table/arrow.cpp @@ -0,0 +1,408 @@ +#include "duckdb/common/arrow/arrow.hpp" + +#include "duckdb.hpp" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/vector_buffer.hpp" +#include "duckdb/function/table/arrow.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +unique_ptr ArrowTableFunction::GetArrowLogicalType(ArrowSchema &schema) { + auto format = string(schema.format); + if (format == "n") { + return make_uniq(LogicalType::SQLNULL); + } else if (format == "b") { + return make_uniq(LogicalType::BOOLEAN); + } else if (format == "c") { + return make_uniq(LogicalType::TINYINT); + } else if (format == "s") { + return make_uniq(LogicalType::SMALLINT); + } else if (format == "i") { + return make_uniq(LogicalType::INTEGER); + } else if (format == "l") { + return make_uniq(LogicalType::BIGINT); + } else if (format == "C") { + return make_uniq(LogicalType::UTINYINT); + } else if (format == "S") { + return make_uniq(LogicalType::USMALLINT); + } else if (format == "I") { + return make_uniq(LogicalType::UINTEGER); + } else if (format == "L") { + return make_uniq(LogicalType::UBIGINT); + } else if (format == "f") { + return make_uniq(LogicalType::FLOAT); + } else if (format == "g") { + return make_uniq(LogicalType::DOUBLE); + } else if (format[0] == 'd') { //! this can be either decimal128 or decimal 256 (e.g., d:38,0) + std::string parameters = format.substr(format.find(':')); + uint8_t width = std::stoi(parameters.substr(1, parameters.find(','))); + uint8_t scale = std::stoi(parameters.substr(parameters.find(',') + 1)); + if (width > 38) { + throw NotImplementedException("Unsupported Internal Arrow Type for Decimal %s", format); + } + return make_uniq(LogicalType::DECIMAL(width, scale)); + } else if (format == "u") { + return make_uniq(LogicalType::VARCHAR, ArrowVariableSizeType::NORMAL); + } else if (format == "U") { + return make_uniq(LogicalType::VARCHAR, ArrowVariableSizeType::SUPER_SIZE); + } else if (format == "tsn:") { + return make_uniq(LogicalTypeId::TIMESTAMP_NS); + } else if (format == "tsu:") { + return make_uniq(LogicalTypeId::TIMESTAMP); + } else if (format == "tsm:") { + return make_uniq(LogicalTypeId::TIMESTAMP_MS); + } else if (format == "tss:") { + return make_uniq(LogicalTypeId::TIMESTAMP_SEC); + } else if (format == "tdD") { + return make_uniq(LogicalType::DATE, ArrowDateTimeType::DAYS); + } else if (format == "tdm") { + return make_uniq(LogicalType::DATE, ArrowDateTimeType::MILLISECONDS); + } else if (format == "tts") { + return make_uniq(LogicalType::TIME, ArrowDateTimeType::SECONDS); + } else if (format == "ttm") { + return make_uniq(LogicalType::TIME, ArrowDateTimeType::MILLISECONDS); + } else if (format == "ttu") { + return make_uniq(LogicalType::TIME, ArrowDateTimeType::MICROSECONDS); + } else if (format == "ttn") { + return make_uniq(LogicalType::TIME, ArrowDateTimeType::NANOSECONDS); + } else if (format == "tDs") { + return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::SECONDS); + } else if (format == "tDm") { + return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MILLISECONDS); + } else if (format == "tDu") { + return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MICROSECONDS); + } else if (format == "tDn") { + return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::NANOSECONDS); + } else if (format == "tiD") { + return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::DAYS); + } else if (format == "tiM") { + return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MONTHS); + } else if (format == "tin") { + return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MONTH_DAY_NANO); + } else if (format == "+l") { + auto child_type = GetArrowLogicalType(*schema.children[0]); + auto list_type = + make_uniq(LogicalType::LIST(child_type->GetDuckType()), ArrowVariableSizeType::NORMAL); + list_type->AddChild(std::move(child_type)); + return list_type; + } else if (format == "+L") { + auto child_type = GetArrowLogicalType(*schema.children[0]); + auto list_type = + make_uniq(LogicalType::LIST(child_type->GetDuckType()), ArrowVariableSizeType::SUPER_SIZE); + list_type->AddChild(std::move(child_type)); + return list_type; + } else if (format[0] == '+' && format[1] == 'w') { + std::string parameters = format.substr(format.find(':') + 1); + idx_t fixed_size = std::stoi(parameters); + auto child_type = GetArrowLogicalType(*schema.children[0]); + auto list_type = make_uniq(LogicalType::LIST(child_type->GetDuckType()), fixed_size); + list_type->AddChild(std::move(child_type)); + return list_type; + } else if (format == "+s") { + child_list_t child_types; + vector> children; + for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { + children.emplace_back(GetArrowLogicalType(*schema.children[type_idx])); + child_types.emplace_back(schema.children[type_idx]->name, children.back()->GetDuckType()); + } + auto struct_type = make_uniq(LogicalType::STRUCT(std::move(child_types))); + struct_type->AssignChildren(std::move(children)); + return struct_type; + } else if (format[0] == '+' && format[1] == 'u') { + if (format[2] != 's') { + throw NotImplementedException("Unsupported Internal Arrow Type: \"%c\" Union", format[2]); + } + D_ASSERT(format[3] == ':'); + + std::string prefix = "+us:"; + // TODO: what are these type ids actually for? + auto type_ids = StringUtil::Split(format.substr(prefix.size()), ','); + + child_list_t members; + vector> children; + for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { + auto type = schema.children[type_idx]; + + children.emplace_back(GetArrowLogicalType(*type)); + members.emplace_back(type->name, children.back()->GetDuckType()); + } + + auto union_type = make_uniq(LogicalType::UNION(members)); + union_type->AssignChildren(std::move(children)); + return union_type; + } else if (format == "+m") { + auto &arrow_struct_type = *schema.children[0]; + D_ASSERT(arrow_struct_type.n_children == 2); + auto key_type = GetArrowLogicalType(*arrow_struct_type.children[0]); + auto value_type = GetArrowLogicalType(*arrow_struct_type.children[1]); + auto map_type = make_uniq(LogicalType::MAP(key_type->GetDuckType(), value_type->GetDuckType()), + ArrowVariableSizeType::NORMAL); + child_list_t key_value; + key_value.emplace_back(std::make_pair("key", key_type->GetDuckType())); + key_value.emplace_back(std::make_pair("value", value_type->GetDuckType())); + + auto inner_struct = + make_uniq(LogicalType::STRUCT(std::move(key_value)), ArrowVariableSizeType::NORMAL); + vector> children; + children.reserve(2); + children.push_back(std::move(key_type)); + children.push_back(std::move(value_type)); + inner_struct->AssignChildren(std::move(children)); + map_type->AddChild(std::move(inner_struct)); + return map_type; + } else if (format == "z") { + return make_uniq(LogicalType::BLOB, ArrowVariableSizeType::NORMAL); + } else if (format == "Z") { + return make_uniq(LogicalType::BLOB, ArrowVariableSizeType::SUPER_SIZE); + } else if (format[0] == 'w') { + std::string parameters = format.substr(format.find(':') + 1); + idx_t fixed_size = std::stoi(parameters); + return make_uniq(LogicalType::BLOB, fixed_size); + } else if (format[0] == 't' && format[1] == 's') { + // Timestamp with Timezone + // TODO right now we just get the UTC value. We probably want to support this properly in the future + if (format[2] == 'n') { + return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::NANOSECONDS); + } else if (format[2] == 'u') { + return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::MICROSECONDS); + } else if (format[2] == 'm') { + return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::MILLISECONDS); + } else if (format[2] == 's') { + return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::SECONDS); + } else { + throw NotImplementedException(" Timestamptz precision of not accepted"); + } + } else { + throw NotImplementedException("Unsupported Internal Arrow Type %s", format); + } +} + +void ArrowTableFunction::RenameArrowColumns(vector &names) { + unordered_map name_map; + for (auto &column_name : names) { + // put it all lower_case + auto low_column_name = StringUtil::Lower(column_name); + if (name_map.find(low_column_name) == name_map.end()) { + // Name does not exist yet + name_map[low_column_name]++; + } else { + // Name already exists, we add _x where x is the repetition number + string new_column_name = column_name + "_" + std::to_string(name_map[low_column_name]); + auto new_column_name_low = StringUtil::Lower(new_column_name); + while (name_map.find(new_column_name_low) != name_map.end()) { + // This name is already here due to a previous definition + name_map[low_column_name]++; + new_column_name = column_name + "_" + std::to_string(name_map[low_column_name]); + new_column_name_low = StringUtil::Lower(new_column_name); + } + column_name = new_column_name; + name_map[new_column_name_low]++; + } + } +} + +void ArrowTableFunction::PopulateArrowTableType(ArrowTableType &arrow_table, ArrowSchemaWrapper &schema_p, + vector &names, vector &return_types) { + for (idx_t col_idx = 0; col_idx < (idx_t)schema_p.arrow_schema.n_children; col_idx++) { + auto &schema = *schema_p.arrow_schema.children[col_idx]; + if (!schema.release) { + throw InvalidInputException("arrow_scan: released schema passed"); + } + auto arrow_type = GetArrowLogicalType(schema); + if (schema.dictionary) { + auto logical_type = arrow_type->GetDuckType(); + auto dictionary = GetArrowLogicalType(*schema.dictionary); + return_types.emplace_back(dictionary->GetDuckType()); + // The dictionary might have different attributes (size type, datetime precision, etc..) + arrow_type->SetDictionary(std::move(dictionary)); + } else { + return_types.emplace_back(arrow_type->GetDuckType()); + } + arrow_table.AddColumn(col_idx, std::move(arrow_type)); + auto format = string(schema.format); + auto name = string(schema.name); + if (name.empty()) { + name = string("v") + to_string(col_idx); + } + names.push_back(name); + } +} + +unique_ptr ArrowTableFunction::ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + if (input.inputs[0].IsNull() || input.inputs[1].IsNull() || input.inputs[2].IsNull()) { + throw BinderException("arrow_scan: pointers cannot be null"); + } + + auto stream_factory_ptr = input.inputs[0].GetPointer(); + auto stream_factory_produce = (stream_factory_produce_t)input.inputs[1].GetPointer(); // NOLINT + auto stream_factory_get_schema = (stream_factory_get_schema_t)input.inputs[2].GetPointer(); // NOLINT + + auto res = make_uniq(stream_factory_produce, stream_factory_ptr); + + auto &data = *res; + stream_factory_get_schema(stream_factory_ptr, data.schema_root); + PopulateArrowTableType(res->arrow_table, data.schema_root, names, return_types); + RenameArrowColumns(names); + res->all_types = return_types; + return std::move(res); +} + +unique_ptr ProduceArrowScan(const ArrowScanFunctionData &function, + const vector &column_ids, TableFilterSet *filters) { + //! Generate Projection Pushdown Vector + ArrowStreamParameters parameters; + D_ASSERT(!column_ids.empty()); + for (idx_t idx = 0; idx < column_ids.size(); idx++) { + auto col_idx = column_ids[idx]; + if (col_idx != COLUMN_IDENTIFIER_ROW_ID) { + auto &schema = *function.schema_root.arrow_schema.children[col_idx]; + parameters.projected_columns.projection_map[idx] = schema.name; + parameters.projected_columns.columns.emplace_back(schema.name); + } + } + parameters.filters = filters; + return function.scanner_producer(function.stream_factory_ptr, parameters); +} + +idx_t ArrowTableFunction::ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data_p) { + return context.db->NumberOfThreads(); +} + +bool ArrowTableFunction::ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, + ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state) { + lock_guard parallel_lock(parallel_state.main_mutex); + if (parallel_state.done) { + return false; + } + state.chunk_offset = 0; + state.batch_index = ++parallel_state.batch_index; + + auto current_chunk = parallel_state.stream->GetNextChunk(); + while (current_chunk->arrow_array.length == 0 && current_chunk->arrow_array.release) { + current_chunk = parallel_state.stream->GetNextChunk(); + } + state.chunk = std::move(current_chunk); + //! have we run out of chunks? we are done + if (!state.chunk->arrow_array.release) { + parallel_state.done = true; + return false; + } + return true; +} + +unique_ptr ArrowTableFunction::ArrowScanInitGlobal(ClientContext &context, + TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + auto result = make_uniq(); + result->stream = ProduceArrowScan(bind_data, input.column_ids, input.filters.get()); + result->max_threads = ArrowScanMaxThreads(context, input.bind_data.get()); + if (input.CanRemoveFilterColumns()) { + result->projection_ids = input.projection_ids; + for (const auto &col_idx : input.column_ids) { + if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { + result->scanned_types.emplace_back(LogicalType::ROW_TYPE); + } else { + result->scanned_types.push_back(bind_data.all_types[col_idx]); + } + } + } + return std::move(result); +} + +unique_ptr +ArrowTableFunction::ArrowScanInitLocalInternal(ClientContext &context, TableFunctionInitInput &input, + GlobalTableFunctionState *global_state_p) { + auto &global_state = global_state_p->Cast(); + auto current_chunk = make_uniq(); + auto result = make_uniq(std::move(current_chunk)); + result->column_ids = input.column_ids; + result->filters = input.filters.get(); + if (input.CanRemoveFilterColumns()) { + auto &asgs = global_state_p->Cast(); + result->all_columns.Initialize(context, asgs.scanned_types); + } + if (!ArrowScanParallelStateNext(context, input.bind_data.get(), *result, global_state)) { + return nullptr; + } + return std::move(result); +} + +unique_ptr ArrowTableFunction::ArrowScanInitLocal(ExecutionContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state_p) { + return ArrowScanInitLocalInternal(context.client, input, global_state_p); +} + +void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + if (!data_p.local_state) { + return; + } + auto &data = data_p.bind_data->CastNoConst(); // FIXME + auto &state = data_p.local_state->Cast(); + auto &global_state = data_p.global_state->Cast(); + + //! Out of tuples in this chunk + if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) { + if (!ArrowScanParallelStateNext(context, data_p.bind_data.get(), state, global_state)) { + return; + } + } + int64_t output_size = MinValue(STANDARD_VECTOR_SIZE, state.chunk->arrow_array.length - state.chunk_offset); + data.lines_read += output_size; + if (global_state.CanRemoveFilterColumns()) { + state.all_columns.Reset(); + state.all_columns.SetCardinality(output_size); + ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns, data.lines_read - output_size); + output.ReferenceColumns(state.all_columns, global_state.projection_ids); + } else { + output.SetCardinality(output_size); + ArrowToDuckDB(state, data.arrow_table.GetColumns(), output, data.lines_read - output_size); + } + + output.Verify(); + state.chunk_offset += output.size(); +} + +unique_ptr ArrowTableFunction::ArrowScanCardinality(ClientContext &context, const FunctionData *data) { + return make_uniq(); +} + +idx_t ArrowTableFunction::ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, + LocalTableFunctionState *local_state, + GlobalTableFunctionState *global_state) { + auto &state = local_state->Cast(); + return state.batch_index; +} + +void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunction arrow("arrow_scan", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, + ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); + arrow.cardinality = ArrowScanCardinality; + arrow.get_batch_index = ArrowGetBatchIndex; + arrow.projection_pushdown = true; + arrow.filter_pushdown = true; + arrow.filter_prune = true; + set.AddFunction(arrow); + + TableFunction arrow_dumb("arrow_scan_dumb", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, + ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); + arrow_dumb.cardinality = ArrowScanCardinality; + arrow_dumb.get_batch_index = ArrowGetBatchIndex; + arrow_dumb.projection_pushdown = false; + arrow_dumb.filter_pushdown = false; + arrow_dumb.filter_prune = false; + set.AddFunction(arrow_dumb); +} + +void BuiltinFunctions::RegisterArrowFunctions() { + ArrowTableFunction::RegisterFunction(*this); +} +} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp new file mode 100644 index 00000000..42d04d2e --- /dev/null +++ b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp @@ -0,0 +1,57 @@ +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +void ArrowTableType::AddColumn(idx_t index, unique_ptr type) { + D_ASSERT(arrow_convert_data.find(index) == arrow_convert_data.end()); + arrow_convert_data.emplace(std::make_pair(index, std::move(type))); +} + +const arrow_column_map_t &ArrowTableType::GetColumns() const { + return arrow_convert_data; +} + +void ArrowType::AddChild(unique_ptr child) { + children.emplace_back(std::move(child)); +} + +void ArrowType::AssignChildren(vector> children) { + D_ASSERT(this->children.empty()); + this->children = std::move(children); +} + +void ArrowType::SetDictionary(unique_ptr dictionary) { + D_ASSERT(!this->dictionary_type); + dictionary_type = std::move(dictionary); +} + +const ArrowType &ArrowType::GetDictionary() const { + D_ASSERT(dictionary_type); + return *dictionary_type; +} + +const LogicalType &ArrowType::GetDuckType() const { + return type; +} + +ArrowVariableSizeType ArrowType::GetSizeType() const { + return size_type; +} + +ArrowDateTimeType ArrowType::GetDateTimeType() const { + return date_time_precision; +} + +const ArrowType &ArrowType::operator[](idx_t index) const { + D_ASSERT(index < children.size()); + return *children[index]; +} + +idx_t ArrowType::FixedSize() const { + D_ASSERT(size_type == ArrowVariableSizeType::FIXED_SIZE); + return fixed_size; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp new file mode 100644 index 00000000..597e6c33 --- /dev/null +++ b/src/duckdb/src/function/table/arrow_conversion.cpp @@ -0,0 +1,860 @@ +#include "duckdb/function/table/arrow.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/arrow_aux_data.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void ShiftRight(unsigned char *ar, int size, int shift) { + int carry = 0; + while (shift--) { + for (int i = size - 1; i >= 0; --i) { + int next = (ar[i] & 1) ? 0x80 : 0; + ar[i] = carry | (ar[i] >> 1); + carry = next; + } + } +} + +template +T *ArrowBufferData(ArrowArray &array, idx_t buffer_idx) { + return (T *)array.buffers[buffer_idx]; // NOLINT +} + +static void GetValidityMask(ValidityMask &mask, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, + int64_t nested_offset = -1, bool add_null = false) { + // In certains we don't need to or cannot copy arrow's validity mask to duckdb. + // + // The conditions where we do want to copy arrow's mask to duckdb are: + // 1. nulls exist + // 2. n_buffers > 0, meaning the array's arrow type is not `null` + // 3. the validity buffer (the first buffer) is not a nullptr + if (array.null_count != 0 && array.n_buffers > 0 && array.buffers[0]) { + auto bit_offset = scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + bit_offset = nested_offset; + } + mask.EnsureWritable(); +#if STANDARD_VECTOR_SIZE > 64 + auto n_bitmask_bytes = (size + 8 - 1) / 8; + if (bit_offset % 8 == 0) { + //! just memcpy nullmask + memcpy((void *)mask.GetData(), ArrowBufferData(array, 0) + bit_offset / 8, n_bitmask_bytes); + } else { + //! need to re-align nullmask + vector temp_nullmask(n_bitmask_bytes + 1); + memcpy(temp_nullmask.data(), ArrowBufferData(array, 0) + bit_offset / 8, n_bitmask_bytes + 1); + ShiftRight(temp_nullmask.data(), n_bitmask_bytes + 1, + bit_offset % 8); //! why this has to be a right shift is a mystery to me + memcpy((void *)mask.GetData(), data_ptr_cast(temp_nullmask.data()), n_bitmask_bytes); + } +#else + auto byte_offset = bit_offset / 8; + auto source_data = ArrowBufferData(array, 0); + bit_offset %= 8; + for (idx_t i = 0; i < size; i++) { + mask.Set(i, source_data[byte_offset] & (1 << bit_offset)); + bit_offset++; + if (bit_offset == 8) { + bit_offset = 0; + byte_offset++; + } + } +#endif + } + if (add_null) { + //! We are setting a validity mask of the data part of dictionary vector + //! For some reason, Nulls are allowed to be indexes, hence we need to set the last element here to be null + //! We might have to resize the mask + mask.Resize(size, size + 1); + mask.SetInvalid(size); + } +} + +static void SetValidityMask(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, + int64_t nested_offset, bool add_null = false) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + auto &mask = FlatVector::Validity(vector); + GetValidityMask(mask, array, scan_state, size, nested_offset, add_null); +} + +static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset = -1, + ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); + +static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask) { + auto size_type = arrow_type.GetSizeType(); + idx_t list_size = 0; + SetValidityMask(vector, array, scan_state, size, nested_offset); + idx_t start_offset = 0; + idx_t cur_offset = 0; + if (size_type == ArrowVariableSizeType::FIXED_SIZE) { + auto fixed_size = arrow_type.FixedSize(); + //! Have to check validity mask before setting this up + idx_t offset = (scan_state.chunk_offset + array.offset) * fixed_size; + if (nested_offset != -1) { + offset = fixed_size * nested_offset; + } + start_offset = offset; + auto list_data = FlatVector::GetData(vector); + for (idx_t i = 0; i < size; i++) { + auto &le = list_data[i]; + le.offset = cur_offset; + le.length = fixed_size; + cur_offset += fixed_size; + } + list_size = start_offset + cur_offset; + } else if (size_type == ArrowVariableSizeType::NORMAL) { + auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; + if (nested_offset != -1) { + offsets = ArrowBufferData(array, 1) + nested_offset; + } + start_offset = offsets[0]; + auto list_data = FlatVector::GetData(vector); + for (idx_t i = 0; i < size; i++) { + auto &le = list_data[i]; + le.offset = cur_offset; + le.length = offsets[i + 1] - offsets[i]; + cur_offset += le.length; + } + list_size = offsets[size]; + } else { + auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; + if (nested_offset != -1) { + offsets = ArrowBufferData(array, 1) + nested_offset; + } + start_offset = offsets[0]; + auto list_data = FlatVector::GetData(vector); + for (idx_t i = 0; i < size; i++) { + auto &le = list_data[i]; + le.offset = cur_offset; + le.length = offsets[i + 1] - offsets[i]; + cur_offset += le.length; + } + list_size = offsets[size]; + } + list_size -= start_offset; + ListVector::Reserve(vector, list_size); + ListVector::SetListSize(vector, list_size); + auto &child_vector = ListVector::GetEntry(vector); + SetValidityMask(child_vector, *array.children[0], scan_state, list_size, start_offset); + auto &list_mask = FlatVector::Validity(vector); + if (parent_mask) { + //! Since this List is owned by a struct we must guarantee their validity map matches on Null + if (!parent_mask->AllValid()) { + for (idx_t i = 0; i < size; i++) { + if (!parent_mask->RowIsValid(i)) { + list_mask.SetInvalid(i); + } + } + } + } + if (list_size == 0 && start_offset == 0) { + ColumnArrowToDuckDB(child_vector, *array.children[0], scan_state, list_size, arrow_type[0], -1); + } else { + ColumnArrowToDuckDB(child_vector, *array.children[0], scan_state, list_size, arrow_type[0], start_offset); + } +} + +static void ArrowToDuckDBBlob(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset) { + auto size_type = arrow_type.GetSizeType(); + SetValidityMask(vector, array, scan_state, size, nested_offset); + if (size_type == ArrowVariableSizeType::FIXED_SIZE) { + auto fixed_size = arrow_type.FixedSize(); + //! Have to check validity mask before setting this up + idx_t offset = (scan_state.chunk_offset + array.offset) * fixed_size; + if (nested_offset != -1) { + offset = fixed_size * nested_offset; + } + auto cdata = ArrowBufferData(array, 1); + for (idx_t row_idx = 0; row_idx < size; row_idx++) { + if (FlatVector::IsNull(vector, row_idx)) { + continue; + } + auto bptr = cdata + offset; + auto blob_len = fixed_size; + FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); + offset += blob_len; + } + } else if (size_type == ArrowVariableSizeType::NORMAL) { + auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; + if (nested_offset != -1) { + offsets = ArrowBufferData(array, 1) + array.offset + nested_offset; + } + auto cdata = ArrowBufferData(array, 2); + for (idx_t row_idx = 0; row_idx < size; row_idx++) { + if (FlatVector::IsNull(vector, row_idx)) { + continue; + } + auto bptr = cdata + offsets[row_idx]; + auto blob_len = offsets[row_idx + 1] - offsets[row_idx]; + FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); + } + } else { + //! Check if last offset is higher than max uint32 + if (ArrowBufferData(array, 1)[array.length] > NumericLimits::Maximum()) { // LCOV_EXCL_START + throw ConversionException("DuckDB does not support Blobs over 4GB"); + } // LCOV_EXCL_STOP + auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; + if (nested_offset != -1) { + offsets = ArrowBufferData(array, 1) + array.offset + nested_offset; + } + auto cdata = ArrowBufferData(array, 2); + for (idx_t row_idx = 0; row_idx < size; row_idx++) { + if (FlatVector::IsNull(vector, row_idx)) { + continue; + } + auto bptr = cdata + offsets[row_idx]; + auto blob_len = offsets[row_idx + 1] - offsets[row_idx]; + FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); + } + } +} + +static void ArrowToDuckDBMapVerify(Vector &vector, idx_t count) { + auto valid_check = MapVector::CheckMapValidity(vector, count); + switch (valid_check) { + case MapInvalidReason::VALID: + break; + case MapInvalidReason::DUPLICATE_KEY: { + throw InvalidInputException("Arrow map contains duplicate key, which isn't supported by DuckDB map type"); + } + case MapInvalidReason::NULL_KEY: { + throw InvalidInputException("Arrow map contains NULL as map key, which isn't supported by DuckDB map type"); + } + case MapInvalidReason::NULL_KEY_LIST: { + throw InvalidInputException("Arrow map contains NULL as key list, which isn't supported by DuckDB map type"); + } + default: { + throw InternalException("MapInvalidReason not implemented"); + } + } +} + +template +static void SetVectorString(Vector &vector, idx_t size, char *cdata, T *offsets) { + auto strings = FlatVector::GetData(vector); + for (idx_t row_idx = 0; row_idx < size; row_idx++) { + if (FlatVector::IsNull(vector, row_idx)) { + continue; + } + auto cptr = cdata + offsets[row_idx]; + auto str_len = offsets[row_idx + 1] - offsets[row_idx]; + if (str_len > NumericLimits::Maximum()) { // LCOV_EXCL_START + throw ConversionException("DuckDB does not support Strings over 4GB"); + } // LCOV_EXCL_STOP + strings[row_idx] = string_t(cptr, str_len); + } +} + +static void DirectConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, int64_t nested_offset, + uint64_t parent_offset) { + auto internal_type = GetTypeIdSize(vector.GetType().InternalType()); + auto data_ptr = + ArrowBufferData(array, 1) + internal_type * (scan_state.chunk_offset + array.offset + parent_offset); + if (nested_offset != -1) { + data_ptr = ArrowBufferData(array, 1) + internal_type * (array.offset + nested_offset + parent_offset); + } + FlatVector::SetData(vector, data_ptr); +} + +template +static void TimeConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, int64_t nested_offset, + idx_t size, int64_t conversion) { + auto tgt_ptr = FlatVector::GetData(vector); + auto &validity_mask = FlatVector::Validity(vector); + auto src_ptr = (T *)array.buffers[1] + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = (T *)array.buffers[1] + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + if (!validity_mask.RowIsValid(row)) { + continue; + } + if (!TryMultiplyOperator::Operation((int64_t)src_ptr[row], conversion, tgt_ptr[row].micros)) { + throw ConversionException("Could not convert Time to Microsecond"); + } + } +} + +static void TimestampTZConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, + int64_t nested_offset, idx_t size, int64_t conversion) { + auto tgt_ptr = FlatVector::GetData(vector); + auto &validity_mask = FlatVector::Validity(vector); + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + if (!validity_mask.RowIsValid(row)) { + continue; + } + if (!TryMultiplyOperator::Operation(src_ptr[row], conversion, tgt_ptr[row].value)) { + throw ConversionException("Could not convert TimestampTZ to Microsecond"); + } + } +} + +static void IntervalConversionUs(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, + int64_t nested_offset, idx_t size, int64_t conversion) { + auto tgt_ptr = FlatVector::GetData(vector); + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + tgt_ptr[row].days = 0; + tgt_ptr[row].months = 0; + if (!TryMultiplyOperator::Operation(src_ptr[row], conversion, tgt_ptr[row].micros)) { + throw ConversionException("Could not convert Interval to Microsecond"); + } + } +} + +static void IntervalConversionMonths(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, + int64_t nested_offset, idx_t size) { + auto tgt_ptr = FlatVector::GetData(vector); + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + tgt_ptr[row].days = 0; + tgt_ptr[row].micros = 0; + tgt_ptr[row].months = src_ptr[row]; + } +} + +static void IntervalConversionMonthDayNanos(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, + int64_t nested_offset, idx_t size) { + auto tgt_ptr = FlatVector::GetData(vector); + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + tgt_ptr[row].days = src_ptr[row].days; + tgt_ptr[row].micros = src_ptr[row].nanoseconds / Interval::NANOS_PER_MICRO; + tgt_ptr[row].months = src_ptr[row].months; + } +} + +static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask, + uint64_t parent_offset) { + switch (vector.GetType().id()) { + case LogicalTypeId::SQLNULL: + vector.Reference(Value()); + break; + case LogicalTypeId::BOOLEAN: { + //! Arrow bit-packs boolean values + //! Lets first figure out where we are in the source array + auto src_ptr = ArrowBufferData(array, 1) + (scan_state.chunk_offset + array.offset) / 8; + + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + (nested_offset + array.offset) / 8; + } + auto tgt_ptr = (uint8_t *)FlatVector::GetData(vector); + int src_pos = 0; + idx_t cur_bit = scan_state.chunk_offset % 8; + if (nested_offset != -1) { + cur_bit = nested_offset % 8; + } + for (idx_t row = 0; row < size; row++) { + if ((src_ptr[src_pos] & (1 << cur_bit)) == 0) { + tgt_ptr[row] = 0; + } else { + tgt_ptr[row] = 1; + } + cur_bit++; + if (cur_bit == 8) { + src_pos++; + cur_bit = 0; + } + } + break; + } + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: { + DirectConversion(vector, array, scan_state, nested_offset, parent_offset); + break; + } + case LogicalTypeId::VARCHAR: { + auto size_type = arrow_type.GetSizeType(); + auto cdata = ArrowBufferData(array, 2); + if (size_type == ArrowVariableSizeType::SUPER_SIZE) { + auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; + if (nested_offset != -1) { + offsets = ArrowBufferData(array, 1) + array.offset + nested_offset; + } + SetVectorString(vector, size, cdata, offsets); + } else { + auto offsets = ArrowBufferData(array, 1) + array.offset + scan_state.chunk_offset; + if (nested_offset != -1) { + offsets = ArrowBufferData(array, 1) + array.offset + nested_offset; + } + SetVectorString(vector, size, cdata, offsets); + } + break; + } + case LogicalTypeId::DATE: { + + auto precision = arrow_type.GetDateTimeType(); + switch (precision) { + case ArrowDateTimeType::DAYS: { + DirectConversion(vector, array, scan_state, nested_offset, parent_offset); + break; + } + case ArrowDateTimeType::MILLISECONDS: { + //! convert date from nanoseconds to days + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + auto tgt_ptr = FlatVector::GetData(vector); + for (idx_t row = 0; row < size; row++) { + tgt_ptr[row] = date_t(int64_t(src_ptr[row]) / static_cast(1000 * 60 * 60 * 24)); + } + break; + } + default: + throw NotImplementedException("Unsupported precision for Date Type "); + } + break; + } + case LogicalTypeId::TIME: { + auto precision = arrow_type.GetDateTimeType(); + switch (precision) { + case ArrowDateTimeType::SECONDS: { + TimeConversion(vector, array, scan_state, nested_offset, size, 1000000); + break; + } + case ArrowDateTimeType::MILLISECONDS: { + TimeConversion(vector, array, scan_state, nested_offset, size, 1000); + break; + } + case ArrowDateTimeType::MICROSECONDS: { + TimeConversion(vector, array, scan_state, nested_offset, size, 1); + break; + } + case ArrowDateTimeType::NANOSECONDS: { + auto tgt_ptr = FlatVector::GetData(vector); + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + tgt_ptr[row].micros = src_ptr[row] / 1000; + } + break; + } + default: + throw NotImplementedException("Unsupported precision for Time Type "); + } + break; + } + case LogicalTypeId::TIMESTAMP_TZ: { + auto precision = arrow_type.GetDateTimeType(); + switch (precision) { + case ArrowDateTimeType::SECONDS: { + TimestampTZConversion(vector, array, scan_state, nested_offset, size, 1000000); + break; + } + case ArrowDateTimeType::MILLISECONDS: { + TimestampTZConversion(vector, array, scan_state, nested_offset, size, 1000); + break; + } + case ArrowDateTimeType::MICROSECONDS: { + DirectConversion(vector, array, scan_state, nested_offset, parent_offset); + break; + } + case ArrowDateTimeType::NANOSECONDS: { + auto tgt_ptr = FlatVector::GetData(vector); + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + tgt_ptr[row].value = src_ptr[row] / 1000; + } + break; + } + default: + throw NotImplementedException("Unsupported precision for TimestampTZ Type "); + } + break; + } + case LogicalTypeId::INTERVAL: { + auto precision = arrow_type.GetDateTimeType(); + switch (precision) { + case ArrowDateTimeType::SECONDS: { + IntervalConversionUs(vector, array, scan_state, nested_offset, size, 1000000); + break; + } + case ArrowDateTimeType::DAYS: + case ArrowDateTimeType::MILLISECONDS: { + IntervalConversionUs(vector, array, scan_state, nested_offset, size, 1000); + break; + } + case ArrowDateTimeType::MICROSECONDS: { + IntervalConversionUs(vector, array, scan_state, nested_offset, size, 1); + break; + } + case ArrowDateTimeType::NANOSECONDS: { + auto tgt_ptr = FlatVector::GetData(vector); + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + for (idx_t row = 0; row < size; row++) { + tgt_ptr[row].micros = src_ptr[row] / 1000; + tgt_ptr[row].days = 0; + tgt_ptr[row].months = 0; + } + break; + } + case ArrowDateTimeType::MONTHS: { + IntervalConversionMonths(vector, array, scan_state, nested_offset, size); + break; + } + case ArrowDateTimeType::MONTH_DAY_NANO: { + IntervalConversionMonthDayNanos(vector, array, scan_state, nested_offset, size); + break; + } + default: + throw NotImplementedException("Unsupported precision for Interval/Duration Type "); + } + break; + } + case LogicalTypeId::DECIMAL: { + auto val_mask = FlatVector::Validity(vector); + //! We have to convert from INT128 + auto src_ptr = ArrowBufferData(array, 1) + scan_state.chunk_offset + array.offset; + if (nested_offset != -1) { + src_ptr = ArrowBufferData(array, 1) + nested_offset + array.offset; + } + switch (vector.GetType().InternalType()) { + case PhysicalType::INT16: { + auto tgt_ptr = FlatVector::GetData(vector); + for (idx_t row = 0; row < size; row++) { + if (val_mask.RowIsValid(row)) { + auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); + D_ASSERT(result); + (void)result; + } + } + break; + } + case PhysicalType::INT32: { + auto tgt_ptr = FlatVector::GetData(vector); + for (idx_t row = 0; row < size; row++) { + if (val_mask.RowIsValid(row)) { + auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); + D_ASSERT(result); + (void)result; + } + } + break; + } + case PhysicalType::INT64: { + auto tgt_ptr = FlatVector::GetData(vector); + for (idx_t row = 0; row < size; row++) { + if (val_mask.RowIsValid(row)) { + auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); + D_ASSERT(result); + (void)result; + } + } + break; + } + case PhysicalType::INT128: { + FlatVector::SetData(vector, + ArrowBufferData(array, 1) + GetTypeIdSize(vector.GetType().InternalType()) * + (scan_state.chunk_offset + array.offset)); + break; + } + default: + throw NotImplementedException("Unsupported physical type for Decimal: %s", + TypeIdToString(vector.GetType().InternalType())); + } + break; + } + case LogicalTypeId::BLOB: { + ArrowToDuckDBBlob(vector, array, scan_state, size, arrow_type, nested_offset); + break; + } + case LogicalTypeId::LIST: { + ArrowToDuckDBList(vector, array, scan_state, size, arrow_type, nested_offset, parent_mask); + break; + } + case LogicalTypeId::MAP: { + ArrowToDuckDBList(vector, array, scan_state, size, arrow_type, nested_offset, parent_mask); + ArrowToDuckDBMapVerify(vector, size); + break; + } + case LogicalTypeId::STRUCT: { + //! Fill the children + auto &child_entries = StructVector::GetEntries(vector); + auto &struct_validity_mask = FlatVector::Validity(vector); + for (idx_t type_idx = 0; type_idx < static_cast(array.n_children); type_idx++) { + SetValidityMask(*child_entries[type_idx], *array.children[type_idx], scan_state, size, nested_offset); + if (!struct_validity_mask.AllValid()) { + auto &child_validity_mark = FlatVector::Validity(*child_entries[type_idx]); + for (idx_t i = 0; i < size; i++) { + if (!struct_validity_mask.RowIsValid(i)) { + child_validity_mark.SetInvalid(i); + } + } + } + ColumnArrowToDuckDB(*child_entries[type_idx], *array.children[type_idx], scan_state, size, + arrow_type[type_idx], nested_offset, &struct_validity_mask, array.offset); + } + break; + } + case LogicalTypeId::UNION: { + auto type_ids = ArrowBufferData(array, array.n_buffers == 1 ? 0 : 1); + D_ASSERT(type_ids); + auto members = UnionType::CopyMemberTypes(vector.GetType()); + + auto &validity_mask = FlatVector::Validity(vector); + + duckdb::vector children; + for (idx_t type_idx = 0; type_idx < static_cast(array.n_children); type_idx++) { + Vector child(members[type_idx].second); + auto arrow_array = array.children[type_idx]; + + SetValidityMask(child, *arrow_array, scan_state, size, nested_offset); + + ColumnArrowToDuckDB(child, *arrow_array, scan_state, size, arrow_type, nested_offset, &validity_mask); + + children.push_back(std::move(child)); + } + + for (idx_t row_idx = 0; row_idx < size; row_idx++) { + auto tag = type_ids[row_idx]; + + auto out_of_range = tag < 0 || tag >= array.n_children; + if (out_of_range) { + throw InvalidInputException("Arrow union tag out of range: %d", tag); + } + + const Value &value = children[tag].GetValue(row_idx); + vector.SetValue(row_idx, value.IsNull() ? Value() : Value::UNION(members, tag, value)); + } + + break; + } + default: + throw NotImplementedException("Unsupported type for arrow conversion: %s", vector.GetType().ToString()); + } +} + +template +static void SetSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { + auto indices = reinterpret_cast(indices_p); + for (idx_t row = 0; row < size; row++) { + sel.set_index(row, indices[row]); + } +} + +template +static void SetSelectionVectorLoopWithChecks(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { + + auto indices = reinterpret_cast(indices_p); + for (idx_t row = 0; row < size; row++) { + if (indices[row] > NumericLimits::Maximum()) { + throw ConversionException("DuckDB only supports indices that fit on an uint32"); + } + sel.set_index(row, indices[row]); + } +} + +template +static void SetMaskedSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, idx_t size, ValidityMask &mask, + idx_t last_element_pos) { + auto indices = reinterpret_cast(indices_p); + for (idx_t row = 0; row < size; row++) { + if (mask.RowIsValid(row)) { + sel.set_index(row, indices[row]); + } else { + //! Need to point out to last element + sel.set_index(row, last_element_pos); + } + } +} + +static void SetSelectionVector(SelectionVector &sel, data_ptr_t indices_p, LogicalType &logical_type, idx_t size, + ValidityMask *mask = nullptr, idx_t last_element_pos = 0) { + sel.Initialize(size); + + if (mask) { + switch (logical_type.id()) { + case LogicalTypeId::UTINYINT: + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + case LogicalTypeId::TINYINT: + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + case LogicalTypeId::USMALLINT: + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + case LogicalTypeId::SMALLINT: + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + case LogicalTypeId::UINTEGER: + if (last_element_pos > NumericLimits::Maximum()) { + //! Its guaranteed that our indices will point to the last element, so just throw an error + throw ConversionException("DuckDB only supports indices that fit on an uint32"); + } + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + case LogicalTypeId::INTEGER: + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + case LogicalTypeId::UBIGINT: + if (last_element_pos > NumericLimits::Maximum()) { + //! Its guaranteed that our indices will point to the last element, so just throw an error + throw ConversionException("DuckDB only supports indices that fit on an uint32"); + } + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + case LogicalTypeId::BIGINT: + if (last_element_pos > NumericLimits::Maximum()) { + //! Its guaranteed that our indices will point to the last element, so just throw an error + throw ConversionException("DuckDB only supports indices that fit on an uint32"); + } + SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); + break; + + default: + throw NotImplementedException("(Arrow) Unsupported type for selection vectors %s", logical_type.ToString()); + } + + } else { + switch (logical_type.id()) { + case LogicalTypeId::UTINYINT: + SetSelectionVectorLoop(sel, indices_p, size); + break; + case LogicalTypeId::TINYINT: + SetSelectionVectorLoop(sel, indices_p, size); + break; + case LogicalTypeId::USMALLINT: + SetSelectionVectorLoop(sel, indices_p, size); + break; + case LogicalTypeId::SMALLINT: + SetSelectionVectorLoop(sel, indices_p, size); + break; + case LogicalTypeId::UINTEGER: + SetSelectionVectorLoop(sel, indices_p, size); + break; + case LogicalTypeId::INTEGER: + SetSelectionVectorLoop(sel, indices_p, size); + break; + case LogicalTypeId::UBIGINT: + if (last_element_pos > NumericLimits::Maximum()) { + //! We need to check if our indexes fit in a uint32_t + SetSelectionVectorLoopWithChecks(sel, indices_p, size); + } else { + SetSelectionVectorLoop(sel, indices_p, size); + } + break; + case LogicalTypeId::BIGINT: + if (last_element_pos > NumericLimits::Maximum()) { + //! We need to check if our indexes fit in a uint32_t + SetSelectionVectorLoopWithChecks(sel, indices_p, size); + } else { + SetSelectionVectorLoop(sel, indices_p, size); + } + break; + default: + throw ConversionException("(Arrow) Unsupported type for selection vectors %s", logical_type.ToString()); + } + } +} + +static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, + idx_t size, const ArrowType &arrow_type, idx_t col_idx) { + SelectionVector sel; + auto &dict_vectors = scan_state.arrow_dictionary_vectors; + if (!dict_vectors.count(col_idx)) { + //! We need to set the dictionary data for this column + auto base_vector = make_uniq(vector.GetType(), array.dictionary->length); + SetValidityMask(*base_vector, *array.dictionary, scan_state, array.dictionary->length, 0, array.null_count > 0); + ColumnArrowToDuckDB(*base_vector, *array.dictionary, scan_state, array.dictionary->length, + arrow_type.GetDictionary()); + dict_vectors[col_idx] = std::move(base_vector); + } + auto dictionary_type = arrow_type.GetDuckType(); + //! Get Pointer to Indices of Dictionary + auto indices = ArrowBufferData(array, 1) + + GetTypeIdSize(dictionary_type.InternalType()) * (scan_state.chunk_offset + array.offset); + if (array.null_count > 0) { + ValidityMask indices_validity; + GetValidityMask(indices_validity, array, scan_state, size); + SetSelectionVector(sel, indices, dictionary_type, size, &indices_validity, array.dictionary->length); + } else { + SetSelectionVector(sel, indices, dictionary_type, size); + } + vector.Slice(*dict_vectors[col_idx], sel, size); +} + +void ArrowTableFunction::ArrowToDuckDB(ArrowScanLocalState &scan_state, const arrow_column_map_t &arrow_convert_data, + DataChunk &output, idx_t start, bool arrow_scan_is_projected) { + for (idx_t idx = 0; idx < output.ColumnCount(); idx++) { + auto col_idx = scan_state.column_ids[idx]; + + // If projection was not pushed down into the arrow scanner, but projection pushdown is enabled on the + // table function, we need to use original column ids here. + auto arrow_array_idx = arrow_scan_is_projected ? idx : col_idx; + + if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { + // This column is skipped by the projection pushdown + continue; + } + + auto &array = *scan_state.chunk->arrow_array.children[arrow_array_idx]; + if (!array.release) { + throw InvalidInputException("arrow_scan: released array passed"); + } + if (array.length != scan_state.chunk->arrow_array.length) { + throw InvalidInputException("arrow_scan: array length mismatch"); + } + // Make sure this Vector keeps the Arrow chunk alive in case we can zero-copy the data + if (scan_state.arrow_owned_data.find(idx) == scan_state.arrow_owned_data.end()) { + auto arrow_data = make_shared(); + arrow_data->arrow_array = scan_state.chunk->arrow_array; + scan_state.chunk->arrow_array.release = nullptr; + scan_state.arrow_owned_data[idx] = arrow_data; + } + + output.data[idx].GetBuffer()->SetAuxiliaryData(make_uniq(scan_state.arrow_owned_data[idx])); + + D_ASSERT(arrow_convert_data.find(col_idx) != arrow_convert_data.end()); + auto &arrow_type = *arrow_convert_data.at(col_idx); + if (array.dictionary) { + ColumnArrowToDuckDBDictionary(output.data[idx], array, scan_state, output.size(), arrow_type, col_idx); + } else { + SetValidityMask(output.data[idx], array, scan_state, output.size(), -1); + ColumnArrowToDuckDB(output.data[idx], array, scan_state, output.size(), arrow_type); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/checkpoint.cpp b/src/duckdb/src/function/table/checkpoint.cpp new file mode 100644 index 00000000..9fdd19a7 --- /dev/null +++ b/src/duckdb/src/function/table/checkpoint.cpp @@ -0,0 +1,69 @@ +#include "duckdb/function/table/range.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/transaction/transaction_manager.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CheckpointBindData : public FunctionData { + explicit CheckpointBindData(optional_ptr db) : db(db) { + } + + optional_ptr db; + +public: + unique_ptr Copy() const override { + return make_uniq(db); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return db == other.db; + } +}; + +static unique_ptr CheckpointBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + return_types.emplace_back(LogicalType::BOOLEAN); + names.emplace_back("Success"); + + optional_ptr db; + auto &db_manager = DatabaseManager::Get(context); + if (!input.inputs.empty()) { + if (input.inputs[0].IsNull()) { + throw BinderException("Database cannot be NULL"); + } + auto &db_name = StringValue::Get(input.inputs[0]); + db = db_manager.GetDatabase(context, db_name); + if (!db) { + throw BinderException("Database \"%s\" not found", db_name); + } + } else { + db = db_manager.GetDatabase(context, DatabaseManager::GetDefaultDatabase(context)); + } + return make_uniq(db); +} + +template +static void TemplatedCheckpointFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &transaction_manager = TransactionManager::Get(*bind_data.db.get_mutable()); + transaction_manager.Checkpoint(context, FORCE); +} + +void CheckpointFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunctionSet checkpoint("checkpoint"); + checkpoint.AddFunction(TableFunction({}, TemplatedCheckpointFunction, CheckpointBind)); + checkpoint.AddFunction(TableFunction({LogicalType::VARCHAR}, TemplatedCheckpointFunction, CheckpointBind)); + set.AddFunction(checkpoint); + + TableFunctionSet force_checkpoint("force_checkpoint"); + force_checkpoint.AddFunction(TableFunction({}, TemplatedCheckpointFunction, CheckpointBind)); + force_checkpoint.AddFunction( + TableFunction({LogicalType::VARCHAR}, TemplatedCheckpointFunction, CheckpointBind)); + set.AddFunction(force_checkpoint); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/copy_csv.cpp b/src/duckdb/src/function/table/copy_csv.cpp new file mode 100644 index 00000000..b1f1b9f1 --- /dev/null +++ b/src/duckdb/src/function/table/copy_csv.cpp @@ -0,0 +1,541 @@ +#include "duckdb/common/bind_helpers.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/common/serializer/write_stream.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" + +#include + +namespace duckdb { + +void AreOptionsEqual(char &str_1, char &str_2, const string &name_str_1, const string &name_str_2) { + if (str_1 == '\0' || str_2 == '\0') { + return; + } + if (str_1 == str_2) { + throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2); + } +} + +void SubstringDetection(char &str_1, string &str_2, const string &name_str_1, const string &name_str_2) { + if (str_1 == '\0' || str_2.empty()) { + return; + } + if (str_2.find(str_1) != string::npos) { + throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2); + } +} + +//===--------------------------------------------------------------------===// +// Bind +//===--------------------------------------------------------------------===// +void WriteQuoteOrEscape(WriteStream &writer, char quote_or_escape) { + if (quote_or_escape != '\0') { + writer.Write(quote_or_escape); + } +} + +void BaseCSVData::Finalize() { + // verify that the options are correct in the final pass + if (options.dialect_options.state_machine_options.escape == '\0') { + options.dialect_options.state_machine_options.escape = options.dialect_options.state_machine_options.quote; + } + // escape and delimiter must not be substrings of each other + if (options.has_delimiter && options.has_escape) { + AreOptionsEqual(options.dialect_options.state_machine_options.delimiter, + options.dialect_options.state_machine_options.escape, "DELIMITER", "ESCAPE"); + } + // delimiter and quote must not be substrings of each other + if (options.has_quote && options.has_delimiter) { + AreOptionsEqual(options.dialect_options.state_machine_options.quote, + options.dialect_options.state_machine_options.delimiter, "DELIMITER", "QUOTE"); + } + // escape and quote must not be substrings of each other (but can be the same) + if (options.dialect_options.state_machine_options.quote != options.dialect_options.state_machine_options.escape && + options.has_quote && options.has_escape) { + AreOptionsEqual(options.dialect_options.state_machine_options.quote, + options.dialect_options.state_machine_options.escape, "QUOTE", "ESCAPE"); + } + if (!options.null_str.empty()) { + // null string and delimiter must not be substrings of each other + if (options.has_delimiter) { + SubstringDetection(options.dialect_options.state_machine_options.delimiter, options.null_str, "DELIMITER", + "NULL"); + } + // quote/escape and nullstr must not be substrings of each other + if (options.has_quote) { + SubstringDetection(options.dialect_options.state_machine_options.quote, options.null_str, "QUOTE", "NULL"); + } + if (options.has_escape) { + SubstringDetection(options.dialect_options.state_machine_options.escape, options.null_str, "ESCAPE", + "NULL"); + } + } + + if (!options.prefix.empty() || !options.suffix.empty()) { + if (options.prefix.empty() || options.suffix.empty()) { + throw BinderException("COPY ... (FORMAT CSV) must have both PREFIX and SUFFIX, or none at all"); + } + if (options.dialect_options.header) { + throw BinderException("COPY ... (FORMAT CSV)'s HEADER cannot be combined with PREFIX/SUFFIX"); + } + } +} + +static unique_ptr WriteCSVBind(ClientContext &context, CopyInfo &info, vector &names, + vector &sql_types) { + auto bind_data = make_uniq(info.file_path, sql_types, names); + + // check all the options in the copy info + for (auto &option : info.options) { + auto loption = StringUtil::Lower(option.first); + auto &set = option.second; + bind_data->options.SetWriteOption(loption, ConvertVectorToValue(std::move(set))); + } + // verify the parsed options + if (bind_data->options.force_quote.empty()) { + // no FORCE_QUOTE specified: initialize to false + bind_data->options.force_quote.resize(names.size(), false); + } + bind_data->Finalize(); + + bind_data->requires_quotes = make_unsafe_uniq_array(256); + memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256); + bind_data->requires_quotes['\n'] = true; + bind_data->requires_quotes['\r'] = true; + bind_data->requires_quotes[bind_data->options.dialect_options.state_machine_options.delimiter] = true; + bind_data->requires_quotes[bind_data->options.dialect_options.state_machine_options.quote] = true; + + if (!bind_data->options.write_newline.empty()) { + bind_data->newline = bind_data->options.write_newline; + } + return std::move(bind_data); +} + +static unique_ptr ReadCSVBind(ClientContext &context, CopyInfo &info, vector &expected_names, + vector &expected_types) { + auto bind_data = make_uniq(); + bind_data->csv_types = expected_types; + bind_data->csv_names = expected_names; + bind_data->return_types = expected_types; + bind_data->return_names = expected_names; + bind_data->files = MultiFileReader::GetFileList(context, Value(info.file_path), "CSV"); + + auto &options = bind_data->options; + + // check all the options in the copy info + for (auto &option : info.options) { + auto loption = StringUtil::Lower(option.first); + auto &set = option.second; + options.SetReadOption(loption, ConvertVectorToValue(set), expected_names); + } + // verify the parsed options + if (options.force_not_null.empty()) { + // no FORCE_QUOTE specified: initialize to false + options.force_not_null.resize(expected_types.size(), false); + } + + // Look for rejects table options last + named_parameter_map_t options_map; + for (auto &option : info.options) { + options_map[option.first] = ConvertVectorToValue(std::move(option.second)); + } + options.file_path = bind_data->files[0]; + options.name_list = expected_names; + options.sql_type_list = expected_types; + for (idx_t i = 0; i < expected_types.size(); i++) { + options.sql_types_per_column[expected_names[i]] = i; + } + + bind_data->FinalizeRead(context); + + if (options.auto_detect) { + // We must run the sniffer. + auto file_handle = BaseCSVReader::OpenCSV(context, options); + auto buffer_manager = make_shared(context, std::move(file_handle), options); + CSVSniffer sniffer(options, buffer_manager, bind_data->state_machine_cache); + auto sniffer_result = sniffer.SniffCSV(); + bind_data->csv_types = sniffer_result.return_types; + bind_data->csv_names = sniffer_result.names; + bind_data->return_types = sniffer_result.return_types; + bind_data->return_names = sniffer_result.names; + } + return std::move(bind_data); +} + +//===--------------------------------------------------------------------===// +// Helper writing functions +//===--------------------------------------------------------------------===// +static string AddEscapes(char &to_be_escaped, const char &escape, const string &val) { + idx_t i = 0; + string new_val = ""; + idx_t found = val.find(to_be_escaped); + + while (found != string::npos) { + while (i < found) { + new_val += val[i]; + i++; + } + if (escape != '\0') { + new_val += escape; + found = val.find(to_be_escaped, found + 1); + } + } + while (i < val.length()) { + new_val += val[i]; + i++; + } + return new_val; +} + +static bool RequiresQuotes(WriteCSVData &csv_data, const char *str, idx_t len) { + auto &options = csv_data.options; + // check if the string is equal to the null string + if (len == options.null_str.size() && memcmp(str, options.null_str.c_str(), len) == 0) { + return true; + } + auto str_data = reinterpret_cast(str); + for (idx_t i = 0; i < len; i++) { + if (csv_data.requires_quotes[str_data[i]]) { + // this byte requires quotes - write a quoted string + return true; + } + } + // no newline, quote or delimiter in the string + // no quoting or escaping necessary + return false; +} + +static void WriteQuotedString(WriteStream &writer, WriteCSVData &csv_data, const char *str, idx_t len, + bool force_quote) { + auto &options = csv_data.options; + if (!force_quote) { + // force quote is disabled: check if we need to add quotes anyway + force_quote = RequiresQuotes(csv_data, str, len); + } + if (force_quote) { + // quoting is enabled: we might need to escape things in the string + bool requires_escape = false; + // simple CSV + // do a single loop to check for a quote or escape value + for (idx_t i = 0; i < len; i++) { + if (str[i] == options.dialect_options.state_machine_options.quote || + str[i] == options.dialect_options.state_machine_options.escape) { + requires_escape = true; + break; + } + } + + if (!requires_escape) { + // fast path: no need to escape anything + WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); + writer.WriteData(const_data_ptr_cast(str), len); + WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); + return; + } + + // slow path: need to add escapes + string new_val(str, len); + new_val = AddEscapes(options.dialect_options.state_machine_options.escape, + options.dialect_options.state_machine_options.escape, new_val); + if (options.dialect_options.state_machine_options.escape != + options.dialect_options.state_machine_options.quote) { + // need to escape quotes separately + new_val = AddEscapes(options.dialect_options.state_machine_options.quote, + options.dialect_options.state_machine_options.escape, new_val); + } + WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); + writer.WriteData(const_data_ptr_cast(new_val.c_str()), new_val.size()); + WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote); + } else { + writer.WriteData(const_data_ptr_cast(str), len); + } +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +struct LocalWriteCSVData : public LocalFunctionData { + //! The thread-local buffer to write data into + MemoryStream stream; + //! A chunk with VARCHAR columns to cast intermediates into + DataChunk cast_chunk; + //! If we've written any rows yet, allows us to prevent a trailing comma when writing JSON ARRAY + bool written_anything = false; +}; + +struct GlobalWriteCSVData : public GlobalFunctionData { + GlobalWriteCSVData(FileSystem &fs, const string &file_path, FileCompressionType compression) + : fs(fs), written_anything(false) { + handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW, + FileLockType::WRITE_LOCK, compression); + } + + //! Write generic data, e.g., CSV header + void WriteData(const_data_ptr_t data, idx_t size) { + lock_guard flock(lock); + handle->Write((void *)data, size); + } + + void WriteData(const char *data, idx_t size) { + WriteData(const_data_ptr_cast(data), size); + } + + //! Write rows + void WriteRows(const_data_ptr_t data, idx_t size, const string &newline) { + lock_guard flock(lock); + if (written_anything) { + handle->Write((void *)newline.c_str(), newline.length()); + } else { + written_anything = true; + } + handle->Write((void *)data, size); + } + + FileSystem &fs; + //! The mutex for writing to the physical file + mutex lock; + //! The file handle to write to + unique_ptr handle; + //! If we've written any rows yet, allows us to prevent a trailing comma when writing JSON ARRAY + bool written_anything; +}; + +static unique_ptr WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) { + auto &csv_data = bind_data.Cast(); + auto local_data = make_uniq(); + + // create the chunk with VARCHAR types + vector types; + types.resize(csv_data.options.name_list.size(), LogicalType::VARCHAR); + + local_data->cast_chunk.Initialize(Allocator::Get(context.client), types); + return std::move(local_data); +} + +static unique_ptr WriteCSVInitializeGlobal(ClientContext &context, FunctionData &bind_data, + const string &file_path) { + auto &csv_data = bind_data.Cast(); + auto &options = csv_data.options; + auto global_data = + make_uniq(FileSystem::GetFileSystem(context), file_path, options.compression); + + if (!options.prefix.empty()) { + global_data->WriteData(options.prefix.c_str(), options.prefix.size()); + } + + if (!(options.has_header && !options.dialect_options.header)) { + MemoryStream stream; + // write the header line to the file + for (idx_t i = 0; i < csv_data.options.name_list.size(); i++) { + if (i != 0) { + WriteQuoteOrEscape(stream, options.dialect_options.state_machine_options.delimiter); + } + WriteQuotedString(stream, csv_data, csv_data.options.name_list[i].c_str(), + csv_data.options.name_list[i].size(), false); + } + stream.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); + + global_data->WriteData(stream.GetData(), stream.GetPosition()); + } + + return std::move(global_data); +} + +static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_data, DataChunk &cast_chunk, + MemoryStream &writer, DataChunk &input, bool &written_anything) { + auto &csv_data = bind_data.Cast(); + auto &options = csv_data.options; + + // first cast the columns of the chunk to varchar + cast_chunk.Reset(); + cast_chunk.SetCardinality(input); + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + if (csv_data.sql_types[col_idx].id() == LogicalTypeId::VARCHAR) { + // VARCHAR, just reinterpret (cannot reference, because LogicalTypeId::VARCHAR is used by the JSON type too) + cast_chunk.data[col_idx].Reinterpret(input.data[col_idx]); + } else if (options.dialect_options.has_format[LogicalTypeId::DATE] && + csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) { + // use the date format to cast the chunk + csv_data.options.write_date_format[LogicalTypeId::DATE].ConvertDateVector( + input.data[col_idx], cast_chunk.data[col_idx], input.size()); + } else if (options.dialect_options.has_format[LogicalTypeId::TIMESTAMP] && + (csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP || + csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP_TZ)) { + // use the timestamp format to cast the chunk + csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].ConvertTimestampVector( + input.data[col_idx], cast_chunk.data[col_idx], input.size()); + } else { + // non varchar column, perform the cast + VectorOperations::Cast(context, input.data[col_idx], cast_chunk.data[col_idx], input.size()); + } + } + + cast_chunk.Flatten(); + // now loop over the vectors and output the values + for (idx_t row_idx = 0; row_idx < cast_chunk.size(); row_idx++) { + if (row_idx == 0 && !written_anything) { + written_anything = true; + } else { + writer.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); + } + // write values + for (idx_t col_idx = 0; col_idx < cast_chunk.ColumnCount(); col_idx++) { + if (col_idx != 0) { + WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.delimiter); + } + if (FlatVector::IsNull(cast_chunk.data[col_idx], row_idx)) { + // write null value + writer.WriteData(const_data_ptr_cast(options.null_str.c_str()), options.null_str.size()); + continue; + } + + // non-null value, fetch the string value from the cast chunk + auto str_data = FlatVector::GetData(cast_chunk.data[col_idx]); + // FIXME: we could gain some performance here by checking for certain types if they ever require quotes + // (e.g. integers only require quotes if the delimiter is a number, decimals only require quotes if the + // delimiter is a number or "." character) + WriteQuotedString(writer, csv_data, str_data[row_idx].GetData(), str_data[row_idx].GetSize(), + csv_data.options.force_quote[col_idx]); + } + } +} + +static void WriteCSVSink(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + LocalFunctionData &lstate, DataChunk &input) { + auto &csv_data = bind_data.Cast(); + auto &local_data = lstate.Cast(); + auto &global_state = gstate.Cast(); + + // write data into the local buffer + WriteCSVChunkInternal(context.client, bind_data, local_data.cast_chunk, local_data.stream, input, + local_data.written_anything); + + // check if we should flush what we have currently written + auto &writer = local_data.stream; + if (writer.GetPosition() >= csv_data.flush_size) { + global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); + writer.Rewind(); + local_data.written_anything = false; + } +} + +//===--------------------------------------------------------------------===// +// Combine +//===--------------------------------------------------------------------===// +static void WriteCSVCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + LocalFunctionData &lstate) { + auto &local_data = lstate.Cast(); + auto &global_state = gstate.Cast(); + auto &csv_data = bind_data.Cast(); + auto &writer = local_data.stream; + // flush the local writer + if (local_data.written_anything) { + global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); + writer.Rewind(); + } +} + +//===--------------------------------------------------------------------===// +// Finalize +//===--------------------------------------------------------------------===// +void WriteCSVFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) { + auto &global_state = gstate.Cast(); + auto &csv_data = bind_data.Cast(); + auto &options = csv_data.options; + + MemoryStream stream; + if (!options.suffix.empty()) { + stream.WriteData(const_data_ptr_cast(options.suffix.c_str()), options.suffix.size()); + } else if (global_state.written_anything) { + stream.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); + } + global_state.WriteData(stream.GetData(), stream.GetPosition()); + + global_state.handle->Close(); + global_state.handle.reset(); +} + +//===--------------------------------------------------------------------===// +// Execution Mode +//===--------------------------------------------------------------------===// +CopyFunctionExecutionMode WriteCSVExecutionMode(bool preserve_insertion_order, bool supports_batch_index) { + if (!preserve_insertion_order) { + return CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; + } + if (supports_batch_index) { + return CopyFunctionExecutionMode::BATCH_COPY_TO_FILE; + } + return CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE; +} +//===--------------------------------------------------------------------===// +// Prepare Batch +//===--------------------------------------------------------------------===// +struct WriteCSVBatchData : public PreparedBatchData { + //! The thread-local buffer to write data into + MemoryStream stream; +}; + +unique_ptr WriteCSVPrepareBatch(ClientContext &context, FunctionData &bind_data, + GlobalFunctionData &gstate, + unique_ptr collection) { + auto &csv_data = bind_data.Cast(); + + // create the cast chunk with VARCHAR types + vector types; + types.resize(csv_data.options.name_list.size(), LogicalType::VARCHAR); + DataChunk cast_chunk; + cast_chunk.Initialize(Allocator::Get(context), types); + + // write CSV chunks to the batch data + bool written_anything = false; + auto batch = make_uniq(); + for (auto &chunk : collection->Chunks()) { + WriteCSVChunkInternal(context, bind_data, cast_chunk, batch->stream, chunk, written_anything); + } + return std::move(batch); +} + +//===--------------------------------------------------------------------===// +// Flush Batch +//===--------------------------------------------------------------------===// +void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + PreparedBatchData &batch) { + auto &csv_batch = batch.Cast(); + auto &global_state = gstate.Cast(); + auto &csv_data = bind_data.Cast(); + auto &writer = csv_batch.stream; + global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); + writer.Rewind(); +} + +void CSVCopyFunction::RegisterFunction(BuiltinFunctions &set) { + CopyFunction info("csv"); + info.copy_to_bind = WriteCSVBind; + info.copy_to_initialize_local = WriteCSVInitializeLocal; + info.copy_to_initialize_global = WriteCSVInitializeGlobal; + info.copy_to_sink = WriteCSVSink; + info.copy_to_combine = WriteCSVCombine; + info.copy_to_finalize = WriteCSVFinalize; + info.execution_mode = WriteCSVExecutionMode; + info.prepare_batch = WriteCSVPrepareBatch; + info.flush_batch = WriteCSVFlushBatch; + + info.copy_from_bind = ReadCSVBind; + info.copy_from_function = ReadCSVTableFunction::GetFunction(); + + info.extension = "csv"; + + set.AddFunction(info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/glob.cpp b/src/duckdb/src/function/table/glob.cpp new file mode 100644 index 00000000..84dbc688 --- /dev/null +++ b/src/duckdb/src/function/table/glob.cpp @@ -0,0 +1,52 @@ +#include "duckdb/function/table/range.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/common/multi_file_reader.hpp" + +namespace duckdb { + +struct GlobFunctionBindData : public TableFunctionData { + vector files; +}; + +static unique_ptr GlobFunctionBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(); + result->files = MultiFileReader::GetFileList(context, input.inputs[0], "Globbing", FileGlobOptions::ALLOW_EMPTY); + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("file"); + return std::move(result); +} + +struct GlobFunctionState : public GlobalTableFunctionState { + GlobFunctionState() : current_idx(0) { + } + + idx_t current_idx; +}; + +static unique_ptr GlobFunctionInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void GlobFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.global_state->Cast(); + + idx_t count = 0; + idx_t next_idx = MinValue(state.current_idx + STANDARD_VECTOR_SIZE, bind_data.files.size()); + for (; state.current_idx < next_idx; state.current_idx++) { + output.data[0].SetValue(count, bind_data.files[state.current_idx]); + count++; + } + output.SetCardinality(count); +} + +void GlobTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunction glob_function("glob", {LogicalType::VARCHAR}, GlobFunction, GlobFunctionBind, GlobFunctionInit); + set.AddFunction(MultiFileReader::CreateFunctionSet(glob_function)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/pragma_detailed_profiling_output.cpp b/src/duckdb/src/function/table/pragma_detailed_profiling_output.cpp new file mode 100644 index 00000000..e3cae0b9 --- /dev/null +++ b/src/duckdb/src/function/table/pragma_detailed_profiling_output.cpp @@ -0,0 +1,173 @@ +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" + +namespace duckdb { + +struct PragmaDetailedProfilingOutputOperatorData : public GlobalTableFunctionState { + explicit PragmaDetailedProfilingOutputOperatorData() : initialized(false) { + } + + ColumnDataScanState scan_state; + bool initialized; +}; + +struct PragmaDetailedProfilingOutputData : public TableFunctionData { + explicit PragmaDetailedProfilingOutputData(vector &types) : types(types) { + } + unique_ptr collection; + vector types; +}; + +static unique_ptr PragmaDetailedProfilingOutputBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, + vector &names) { + names.emplace_back("OPERATOR_ID"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("ANNOTATION"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("ID"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("NAME"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("TIME"); + return_types.emplace_back(LogicalType::DOUBLE); + + names.emplace_back("CYCLES_PER_TUPLE"); + return_types.emplace_back(LogicalType::DOUBLE); + + names.emplace_back("SAMPLE_SIZE"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("INPUT_SIZE"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("EXTRA_INFO"); + return_types.emplace_back(LogicalType::VARCHAR); + + return make_uniq(return_types); +} + +unique_ptr PragmaDetailedProfilingOutputInit(ClientContext &context, + TableFunctionInitInput &input) { + return make_uniq(); +} + +// Insert a row into the given datachunk +static void SetValue(DataChunk &output, int index, int op_id, string annotation, int id, string name, double time, + int sample_counter, int tuple_counter, string extra_info) { + output.SetValue(0, index, op_id); + output.SetValue(1, index, std::move(annotation)); + output.SetValue(2, index, id); + output.SetValue(3, index, std::move(name)); +#if defined(RDTSC) + output.SetValue(4, index, Value(nullptr)); + output.SetValue(5, index, time); +#else + output.SetValue(4, index, time); + output.SetValue(5, index, Value(nullptr)); + +#endif + output.SetValue(6, index, sample_counter); + output.SetValue(7, index, tuple_counter); + output.SetValue(8, index, std::move(extra_info)); +} + +static void ExtractFunctions(ColumnDataCollection &collection, ExpressionInfo &info, DataChunk &chunk, int op_id, + int &fun_id) { + if (info.hasfunction) { + D_ASSERT(info.sample_tuples_count != 0); + SetValue(chunk, chunk.size(), op_id, "Function", fun_id++, info.function_name, + int(info.function_time) / double(info.sample_tuples_count), info.sample_tuples_count, + info.tuples_count, ""); + + chunk.SetCardinality(chunk.size() + 1); + if (chunk.size() == STANDARD_VECTOR_SIZE) { + collection.Append(chunk); + chunk.Reset(); + } + } + if (info.children.empty()) { + return; + } + // extract the children of this node + for (auto &child : info.children) { + ExtractFunctions(collection, *child, chunk, op_id, fun_id); + } +} + +static void PragmaDetailedProfilingOutputFunction(ClientContext &context, TableFunctionInput &data_p, + DataChunk &output) { + auto &state = data_p.global_state->Cast(); + auto &data = data_p.bind_data->CastNoConst(); + + if (!state.initialized) { + // create a ColumnDataCollection + auto collection = make_uniq(context, data.types); + + // create a chunk + DataChunk chunk; + chunk.Initialize(context, data.types); + + // Initialize ids + int operator_counter = 1; + int function_counter = 1; + int expression_counter = 1; + auto &client_data = ClientData::Get(context); + if (client_data.query_profiler_history->GetPrevProfilers().empty()) { + return; + } + // For each Operator + auto &tree_map = client_data.query_profiler_history->GetPrevProfilers().back().second->GetTreeMap(); + for (auto op : tree_map) { + // For each Expression Executor + for (auto &expr_executor : op.second.get().info.executors_info) { + // For each Expression tree + if (!expr_executor) { + continue; + } + for (auto &expr_timer : expr_executor->roots) { + D_ASSERT(expr_timer->sample_tuples_count != 0); + SetValue(chunk, chunk.size(), operator_counter, "ExpressionRoot", expression_counter++, + // Sometimes, cycle counter is not accurate, too big or too small. return 0 for + // those cases + expr_timer->name, int(expr_timer->time) / double(expr_timer->sample_tuples_count), + expr_timer->sample_tuples_count, expr_timer->tuples_count, expr_timer->extra_info); + // Increment cardinality + chunk.SetCardinality(chunk.size() + 1); + // Check whether data chunk is full or not + if (chunk.size() == STANDARD_VECTOR_SIZE) { + collection->Append(chunk); + chunk.Reset(); + } + // Extract all functions inside the tree + ExtractFunctions(*collection, *expr_timer->root, chunk, operator_counter, function_counter); + } + } + operator_counter++; + } + collection->Append(chunk); + data.collection = std::move(collection); + data.collection->InitializeScan(state.scan_state); + state.initialized = true; + } + + data.collection->Scan(state.scan_state, output); +} + +void PragmaDetailedProfilingOutput::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("pragma_detailed_profiling_output", {}, PragmaDetailedProfilingOutputFunction, + PragmaDetailedProfilingOutputBind, PragmaDetailedProfilingOutputInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/pragma_last_profiling_output.cpp b/src/duckdb/src/function/table/pragma_last_profiling_output.cpp new file mode 100644 index 00000000..c822f3dd --- /dev/null +++ b/src/duckdb/src/function/table/pragma_last_profiling_output.cpp @@ -0,0 +1,101 @@ +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" + +namespace duckdb { + +struct PragmaLastProfilingOutputOperatorData : public GlobalTableFunctionState { + PragmaLastProfilingOutputOperatorData() : initialized(false) { + } + + ColumnDataScanState scan_state; + bool initialized; +}; + +struct PragmaLastProfilingOutputData : public TableFunctionData { + explicit PragmaLastProfilingOutputData(vector &types) : types(types) { + } + unique_ptr collection; + vector types; +}; + +static unique_ptr PragmaLastProfilingOutputBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, + vector &names) { + names.emplace_back("OPERATOR_ID"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("NAME"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("TIME"); + return_types.emplace_back(LogicalType::DOUBLE); + + names.emplace_back("CARDINALITY"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("DESCRIPTION"); + return_types.emplace_back(LogicalType::VARCHAR); + + return make_uniq(return_types); +} + +static void SetValue(DataChunk &output, int index, int op_id, string name, double time, int64_t car, + string description) { + output.SetValue(0, index, op_id); + output.SetValue(1, index, std::move(name)); + output.SetValue(2, index, time); + output.SetValue(3, index, car); + output.SetValue(4, index, std::move(description)); +} + +unique_ptr PragmaLastProfilingOutputInit(ClientContext &context, + TableFunctionInitInput &input) { + return make_uniq(); +} + +static void PragmaLastProfilingOutputFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &state = data_p.global_state->Cast(); + auto &data = data_p.bind_data->CastNoConst(); + if (!state.initialized) { + // create a ColumnDataCollection + auto collection = make_uniq(context, data.types); + + DataChunk chunk; + chunk.Initialize(context, data.types); + int operator_counter = 1; + auto &client_data = ClientData::Get(context); + if (!client_data.query_profiler_history->GetPrevProfilers().empty()) { + auto &tree_map = client_data.query_profiler_history->GetPrevProfilers().back().second->GetTreeMap(); + for (auto op : tree_map) { + auto &tree_info = op.second.get(); + SetValue(chunk, chunk.size(), operator_counter++, tree_info.name, tree_info.info.time, + tree_info.info.elements, " "); + chunk.SetCardinality(chunk.size() + 1); + if (chunk.size() == STANDARD_VECTOR_SIZE) { + collection->Append(chunk); + chunk.Reset(); + } + } + } + collection->Append(chunk); + data.collection = std::move(collection); + data.collection->InitializeScan(state.scan_state); + state.initialized = true; + } + + data.collection->Scan(state.scan_state, output); +} + +void PragmaLastProfilingOutput::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("pragma_last_profiling_output", {}, PragmaLastProfilingOutputFunction, + PragmaLastProfilingOutputBind, PragmaLastProfilingOutputInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/range.cpp b/src/duckdb/src/function/table/range.cpp new file mode 100644 index 00000000..205ee3ac --- /dev/null +++ b/src/duckdb/src/function/table/range.cpp @@ -0,0 +1,279 @@ +#include "duckdb/function/table/range.hpp" +#include "duckdb/function/table/summary.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/types/timestamp.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Range (integers) +//===--------------------------------------------------------------------===// +struct RangeFunctionBindData : public TableFunctionData { + hugeint_t start; + hugeint_t end; + hugeint_t increment; + +public: + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return other.start == start && other.end == end && other.increment == increment; + } +}; + +template +static void GenerateRangeParameters(const vector &inputs, RangeFunctionBindData &result) { + for (auto &input : inputs) { + if (input.IsNull()) { + result.start = GENERATE_SERIES ? 1 : 0; + result.end = 0; + result.increment = 1; + return; + } + } + if (inputs.size() < 2) { + // single argument: only the end is specified + result.start = 0; + result.end = inputs[0].GetValue(); + } else { + // two arguments: first two arguments are start and end + result.start = inputs[0].GetValue(); + result.end = inputs[1].GetValue(); + } + if (inputs.size() < 3) { + result.increment = 1; + } else { + result.increment = inputs[2].GetValue(); + } + if (result.increment == 0) { + throw BinderException("interval cannot be 0!"); + } + if (result.start > result.end && result.increment > 0) { + throw BinderException("start is bigger than end, but increment is positive: cannot generate infinite series"); + } else if (result.start < result.end && result.increment < 0) { + throw BinderException("start is smaller than end, but increment is negative: cannot generate infinite series"); + } +} + +template +static unique_ptr RangeFunctionBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(); + auto &inputs = input.inputs; + GenerateRangeParameters(inputs, *result); + + return_types.emplace_back(LogicalType::BIGINT); + if (GENERATE_SERIES) { + // generate_series has inclusive bounds on the RHS + if (result->increment < 0) { + result->end = result->end - 1; + } else { + result->end = result->end + 1; + } + names.emplace_back("generate_series"); + } else { + names.emplace_back("range"); + } + return std::move(result); +} + +struct RangeFunctionState : public GlobalTableFunctionState { + RangeFunctionState() : current_idx(0) { + } + + int64_t current_idx; +}; + +static unique_ptr RangeFunctionInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void RangeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.global_state->Cast(); + + auto increment = bind_data.increment; + auto end = bind_data.end; + hugeint_t current_value = bind_data.start + increment * state.current_idx; + int64_t current_value_i64; + if (!Hugeint::TryCast(current_value, current_value_i64)) { + return; + } + int64_t offset = increment < 0 ? 1 : -1; + idx_t remaining = MinValue(Hugeint::Cast((end - current_value + (increment + offset)) / increment), + STANDARD_VECTOR_SIZE); + // set the result vector as a sequence vector + output.data[0].Sequence(current_value_i64, Hugeint::Cast(increment), remaining); + // increment the index pointer by the remaining count + state.current_idx += remaining; + output.SetCardinality(remaining); +} + +unique_ptr RangeCardinality(ClientContext &context, const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + idx_t cardinality = Hugeint::Cast((bind_data.end - bind_data.start) / bind_data.increment); + return make_uniq(cardinality, cardinality); +} + +//===--------------------------------------------------------------------===// +// Range (timestamp) +//===--------------------------------------------------------------------===// +struct RangeDateTimeBindData : public TableFunctionData { + timestamp_t start; + timestamp_t end; + interval_t increment; + bool inclusive_bound; + bool greater_than_check; + +public: + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return other.start == start && other.end == end && other.increment == increment && + other.inclusive_bound == inclusive_bound && other.greater_than_check == greater_than_check; + } + + bool Finished(timestamp_t current_value) const { + if (greater_than_check) { + if (inclusive_bound) { + return current_value > end; + } else { + return current_value >= end; + } + } else { + if (inclusive_bound) { + return current_value < end; + } else { + return current_value <= end; + } + } + } +}; + +template +static unique_ptr RangeDateTimeBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(); + auto &inputs = input.inputs; + D_ASSERT(inputs.size() == 3); + result->start = inputs[0].GetValue(); + result->end = inputs[1].GetValue(); + result->increment = inputs[2].GetValue(); + + // Infinities either cause errors or infinite loops, so just ban them + if (!Timestamp::IsFinite(result->start) || !Timestamp::IsFinite(result->end)) { + throw BinderException("RANGE with infinite bounds is not supported"); + } + + if (result->increment.months == 0 && result->increment.days == 0 && result->increment.micros == 0) { + throw BinderException("interval cannot be 0!"); + } + // all elements should point in the same direction + if (result->increment.months > 0 || result->increment.days > 0 || result->increment.micros > 0) { + if (result->increment.months < 0 || result->increment.days < 0 || result->increment.micros < 0) { + throw BinderException("RANGE with composite interval that has mixed signs is not supported"); + } + result->greater_than_check = true; + if (result->start > result->end) { + throw BinderException( + "start is bigger than end, but increment is positive: cannot generate infinite series"); + } + } else { + result->greater_than_check = false; + if (result->start < result->end) { + throw BinderException( + "start is smaller than end, but increment is negative: cannot generate infinite series"); + } + } + return_types.push_back(inputs[0].type()); + if (GENERATE_SERIES) { + // generate_series has inclusive bounds on the RHS + result->inclusive_bound = true; + names.emplace_back("generate_series"); + } else { + result->inclusive_bound = false; + names.emplace_back("range"); + } + return std::move(result); +} + +struct RangeDateTimeState : public GlobalTableFunctionState { + explicit RangeDateTimeState(timestamp_t start_p) : current_state(start_p) { + } + + timestamp_t current_state; + bool finished = false; +}; + +static unique_ptr RangeDateTimeInit(ClientContext &context, TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + return make_uniq(bind_data.start); +} + +static void RangeDateTimeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.global_state->Cast(); + if (state.finished) { + return; + } + + idx_t size = 0; + auto data = FlatVector::GetData(output.data[0]); + while (true) { + data[size++] = state.current_state; + state.current_state = + AddOperator::Operation(state.current_state, bind_data.increment); + if (bind_data.Finished(state.current_state)) { + state.finished = true; + break; + } + if (size >= STANDARD_VECTOR_SIZE) { + break; + } + } + output.SetCardinality(size); +} + +void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunctionSet range("range"); + + TableFunction range_function({LogicalType::BIGINT}, RangeFunction, RangeFunctionBind, RangeFunctionInit); + range_function.cardinality = RangeCardinality; + + // single argument range: (end) - implicit start = 0 and increment = 1 + range.AddFunction(range_function); + // two arguments range: (start, end) - implicit increment = 1 + range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; + range.AddFunction(range_function); + // three arguments range: (start, end, increment) + range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; + range.AddFunction(range_function); + range.AddFunction(TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + RangeDateTimeFunction, RangeDateTimeBind, RangeDateTimeInit)); + set.AddFunction(range); + // generate_series: similar to range, but inclusive instead of exclusive bounds on the RHS + TableFunctionSet generate_series("generate_series"); + range_function.bind = RangeFunctionBind; + range_function.arguments = {LogicalType::BIGINT}; + generate_series.AddFunction(range_function); + range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; + generate_series.AddFunction(range_function); + range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; + generate_series.AddFunction(range_function); + generate_series.AddFunction(TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + RangeDateTimeFunction, RangeDateTimeBind, RangeDateTimeInit)); + set.AddFunction(generate_series); +} + +void BuiltinFunctions::RegisterTableFunctions() { + CheckpointFunction::RegisterFunction(*this); + GlobTableFunction::RegisterFunction(*this); + RangeTableFunction::RegisterFunction(*this); + RepeatTableFunction::RegisterFunction(*this); + SummaryTableFunction::RegisterFunction(*this); + UnnestTableFunction::RegisterFunction(*this); + RepeatRowTableFunction::RegisterFunction(*this); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/read_csv.cpp b/src/duckdb/src/function/table/read_csv.cpp new file mode 100644 index 00000000..5a21874a --- /dev/null +++ b/src/duckdb/src/function/table/read_csv.cpp @@ -0,0 +1,1028 @@ +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/union_by_name.hpp" +#include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" +#include "duckdb/execution/operator/scan/csv/csv_line_info.hpp" +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include + +namespace duckdb { + +unique_ptr ReadCSV::OpenCSV(const string &file_path, FileCompressionType compression, + ClientContext &context) { + auto &fs = FileSystem::GetFileSystem(context); + auto &allocator = BufferAllocator::Get(context); + return CSVFileHandle::OpenFile(fs, allocator, file_path, compression); +} + +void ReadCSVData::FinalizeRead(ClientContext &context) { + BaseCSVData::Finalize(); + // Here we identify if we can run this CSV file on parallel or not. + bool not_supported_options = options.null_padding; + + auto number_of_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + //! If we have many csv files, we run single-threaded on each file and parallelize on the number of files + bool many_csv_files = files.size() > 1 && int64_t(files.size() * 2) >= number_of_threads; + if (options.parallel_mode != ParallelMode::PARALLEL && many_csv_files) { + single_threaded = true; + } + if (options.parallel_mode == ParallelMode::SINGLE_THREADED || not_supported_options || + options.dialect_options.new_line == NewLineIdentifier::MIX) { + // not supported for parallel CSV reading + single_threaded = true; + } + + // Validate rejects_table options + if (!options.rejects_table_name.empty()) { + if (!options.ignore_errors) { + throw BinderException("REJECTS_TABLE option is only supported when IGNORE_ERRORS is set to true"); + } + if (options.file_options.union_by_name) { + throw BinderException("REJECTS_TABLE option is not supported when UNION_BY_NAME is set to true"); + } + } + + if (!options.rejects_recovery_columns.empty()) { + if (options.rejects_table_name.empty()) { + throw BinderException( + "REJECTS_RECOVERY_COLUMNS option is only supported when REJECTS_TABLE is set to a table name"); + } + for (auto &recovery_col : options.rejects_recovery_columns) { + bool found = false; + for (idx_t col_idx = 0; col_idx < return_names.size(); col_idx++) { + if (StringUtil::CIEquals(return_names[col_idx], recovery_col)) { + options.rejects_recovery_column_ids.push_back(col_idx); + found = true; + break; + } + } + if (!found) { + throw BinderException("Unsupported parameter for REJECTS_RECOVERY_COLUMNS: column \"%s\" not found", + recovery_col); + } + } + } + + if (options.rejects_limit != 0) { + if (options.rejects_table_name.empty()) { + throw BinderException("REJECTS_LIMIT option is only supported when REJECTS_TABLE is set to a table name"); + } + } +} + +static unique_ptr ReadCSVBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + + auto result = make_uniq(); + auto &options = result->options; + result->files = MultiFileReader::GetFileList(context, input.inputs[0], "CSV"); + + options.FromNamedParameters(input.named_parameters, context, return_types, names); + bool explicitly_set_columns = options.explicitly_set_columns; + + options.file_options.AutoDetectHivePartitioning(result->files, context); + + if (!options.auto_detect && return_types.empty()) { + throw BinderException("read_csv requires columns to be specified through the 'columns' option. Use " + "read_csv_auto or set read_csv(..., " + "AUTO_DETECT=TRUE) to automatically guess columns."); + } + if (options.auto_detect) { + options.file_path = result->files[0]; + // Initialize Buffer Manager and Sniffer + auto file_handle = BaseCSVReader::OpenCSV(context, options); + result->buffer_manager = make_shared(context, std::move(file_handle), options); + CSVSniffer sniffer(options, result->buffer_manager, result->state_machine_cache, explicitly_set_columns); + auto sniffer_result = sniffer.SniffCSV(); + if (names.empty()) { + names = sniffer_result.names; + return_types = sniffer_result.return_types; + } else { + if (explicitly_set_columns) { + // The user has influenced the names, can't assume they are valid anymore + if (return_types.size() != names.size()) { + throw BinderException("The amount of names specified (%d) and the observed amount of types (%d) in " + "the file don't match", + names.size(), return_types.size()); + } + } else { + D_ASSERT(return_types.size() == names.size()); + } + } + + } else { + D_ASSERT(return_types.size() == names.size()); + } + result->csv_types = return_types; + result->csv_names = names; + + if (options.file_options.union_by_name) { + result->reader_bind = + MultiFileReader::BindUnionReader(context, return_types, names, *result, options); + if (result->union_readers.size() > 1) { + result->column_info.emplace_back(result->csv_names, result->csv_types); + for (idx_t i = 1; i < result->union_readers.size(); i++) { + result->column_info.emplace_back(result->union_readers[i]->names, + result->union_readers[i]->return_types); + } + } + if (!options.sql_types_per_column.empty()) { + auto exception = BufferedCSVReader::ColumnTypesError(options.sql_types_per_column, names); + if (!exception.empty()) { + throw BinderException(exception); + } + } + } else { + result->reader_bind = MultiFileReader::BindOptions(options.file_options, result->files, return_types, names); + } + result->return_types = return_types; + result->return_names = names; + result->FinalizeRead(context); + + return std::move(result); +} + +static unique_ptr ReadCSVAutoBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + input.named_parameters["auto_detect"] = Value::BOOLEAN(true); + return ReadCSVBind(context, input, return_types, names); +} + +//===--------------------------------------------------------------------===// +// Parallel CSV Reader CSV Global State +//===--------------------------------------------------------------------===// + +struct ParallelCSVGlobalState : public GlobalTableFunctionState { +public: + ParallelCSVGlobalState(ClientContext &context, shared_ptr buffer_manager_p, + const CSVReaderOptions &options, idx_t system_threads_p, const vector &files_path_p, + bool force_parallelism_p, vector column_ids_p) + : buffer_manager(std::move(buffer_manager_p)), system_threads(system_threads_p), + force_parallelism(force_parallelism_p), column_ids(std::move(column_ids_p)), + line_info(main_mutex, batch_to_tuple_end, tuple_start, tuple_end) { + current_file_path = files_path_p[0]; + CSVFileHandle *file_handle_ptr; + + if (!buffer_manager || (options.skip_rows_set && options.dialect_options.skip_rows > 0) || + buffer_manager->file_handle->GetFilePath() != current_file_path) { + // If our buffers are too small, and we skip too many rows there is a chance things will go over-buffer + // for now don't reuse the buffer manager + buffer_manager.reset(); + file_handle = ReadCSV::OpenCSV(current_file_path, options.compression, context); + file_handle_ptr = file_handle.get(); + } else { + file_handle_ptr = buffer_manager->file_handle.get(); + } + + file_size = file_handle_ptr->FileSize(); + first_file_size = file_size; + on_disk_file = file_handle_ptr->OnDiskFile(); + bytes_read = 0; + running_threads = MaxThreads(); + + // Initialize all the book-keeping variables + auto file_count = files_path_p.size(); + line_info.current_batches.resize(file_count); + line_info.lines_read.resize(file_count); + line_info.lines_errored.resize(file_count); + tuple_start.resize(file_count); + tuple_end.resize(file_count); + tuple_end_to_batch.resize(file_count); + batch_to_tuple_end.resize(file_count); + + // Initialize the lines read + line_info.lines_read[0][0] = options.dialect_options.skip_rows; + if (options.has_header && options.dialect_options.header) { + line_info.lines_read[0][0]++; + } + first_position = options.dialect_options.true_start; + next_byte = options.dialect_options.true_start; + } + explicit ParallelCSVGlobalState(idx_t system_threads_p) + : system_threads(system_threads_p), line_info(main_mutex, batch_to_tuple_end, tuple_start, tuple_end) { + running_threads = MaxThreads(); + } + + ~ParallelCSVGlobalState() override { + } + + //! How many bytes were read up to this point + atomic bytes_read; + //! Size of current file + idx_t file_size; + +public: + idx_t MaxThreads() const override; + //! Updates the CSV reader with the next buffer to read. Returns false if no more buffers are available. + bool Next(ClientContext &context, const ReadCSVData &bind_data, unique_ptr &reader); + //! Verify if the CSV File was read correctly + void Verify(); + + void UpdateVerification(VerificationPositions positions, idx_t file_number, idx_t batch_idx); + + void UpdateLinesRead(CSVBufferRead &buffer_read, idx_t file_idx); + + void DecrementThread(); + + bool Finished(); + + double GetProgress(const ReadCSVData &bind_data) const { + idx_t total_files = bind_data.files.size(); + + // get the progress WITHIN the current file + double progress; + if (file_size == 0) { + progress = 1.0; + } else { + progress = double(bytes_read) / double(file_size); + } + // now get the total percentage of files read + double percentage = double(file_index - 1) / total_files; + percentage += (double(1) / double(total_files)) * progress; + return percentage * 100; + } + +private: + //! File Handle for current file + shared_ptr buffer_manager; + + //! The index of the next file to read (i.e. current file + 1) + idx_t file_index = 1; + string current_file_path; + + //! Mutex to lock when getting next batch of bytes (Parallel Only) + mutex main_mutex; + //! Byte set from for last thread + idx_t next_byte = 0; + //! Size of first file + idx_t first_file_size = 0; + //! Whether or not this is an on-disk file + bool on_disk_file = true; + //! Basically max number of threads in DuckDB + idx_t system_threads; + //! Current batch index + idx_t batch_index = 0; + idx_t local_batch_index = 0; + + //! Forces parallelism for small CSV Files, should only be used for testing. + bool force_parallelism = false; + //! First Position of First Buffer + idx_t first_position = 0; + //! Current File Number + idx_t max_tuple_end = 0; + //! The vector stores positions where threads ended the last line they read in the CSV File, and the set stores + //! Positions where they started reading the first line. + vector> tuple_end; + vector> tuple_start; + //! Tuple end to batch + vector> tuple_end_to_batch; + //! Batch to Tuple End + vector> batch_to_tuple_end; + idx_t running_threads = 0; + //! The column ids to read + vector column_ids; + //! Line Info used in error messages + LineInfo line_info; + //! Current Buffer index + idx_t cur_buffer_idx = 0; + //! Only used if we don't run auto_detection first + unique_ptr file_handle; +}; + +idx_t ParallelCSVGlobalState::MaxThreads() const { + if (force_parallelism || !on_disk_file) { + return system_threads; + } + idx_t one_mb = 1000000; // We initialize max one thread per Mb + idx_t threads_per_mb = first_file_size / one_mb + 1; + if (threads_per_mb < system_threads || threads_per_mb == 1) { + return threads_per_mb; + } + + return system_threads; +} + +void ParallelCSVGlobalState::DecrementThread() { + lock_guard parallel_lock(main_mutex); + D_ASSERT(running_threads > 0); + running_threads--; +} + +bool ParallelCSVGlobalState::Finished() { + lock_guard parallel_lock(main_mutex); + return running_threads == 0; +} + +void ParallelCSVGlobalState::Verify() { + // All threads are done, we run some magic sweet verification code + lock_guard parallel_lock(main_mutex); + if (running_threads == 0) { + D_ASSERT(tuple_end.size() == tuple_start.size()); + for (idx_t i = 0; i < tuple_start.size(); i++) { + auto ¤t_tuple_end = tuple_end[i]; + auto ¤t_tuple_start = tuple_start[i]; + // figure out max value of last_pos + if (current_tuple_end.empty()) { + return; + } + auto max_value = *max_element(std::begin(current_tuple_end), std::end(current_tuple_end)); + for (idx_t tpl_idx = 0; tpl_idx < current_tuple_end.size(); tpl_idx++) { + auto last_pos = current_tuple_end[tpl_idx]; + auto first_pos = current_tuple_start.find(last_pos); + if (first_pos == current_tuple_start.end()) { + // this might be necessary due to carriage returns outside buffer scopes. + first_pos = current_tuple_start.find(last_pos + 1); + } + if (first_pos == current_tuple_start.end() && last_pos != max_value) { + auto batch_idx = tuple_end_to_batch[i][last_pos]; + auto problematic_line = line_info.GetLine(batch_idx); + throw InvalidInputException( + "CSV File not supported for multithreading. This can be a problematic line in your CSV File or " + "that this CSV can't be read in Parallel. Please, inspect if the line %llu is correct. If so, " + "please run single-threaded CSV Reading by setting parallel=false in the read_csv call.", + problematic_line); + } + } + } + } +} + +void LineInfo::Verify(idx_t file_idx, idx_t batch_idx, idx_t cur_first_pos) { + auto &tuple_start_set = tuple_start[file_idx]; + auto &processed_batches = batch_to_tuple_end[file_idx]; + auto &tuple_end_vec = tuple_end[file_idx]; + bool has_error = false; + idx_t problematic_line; + if (batch_idx == 0 || tuple_start_set.empty()) { + return; + } + for (idx_t cur_batch = 0; cur_batch < batch_idx - 1; cur_batch++) { + auto cur_end = tuple_end_vec[processed_batches[cur_batch]]; + auto first_pos = tuple_start_set.find(cur_end); + if (first_pos == tuple_start_set.end()) { + has_error = true; + problematic_line = GetLine(cur_batch); + break; + } + } + if (!has_error) { + auto cur_end = tuple_end_vec[processed_batches[batch_idx - 1]]; + if (cur_end != cur_first_pos) { + has_error = true; + problematic_line = GetLine(batch_idx); + } + } + if (has_error) { + throw InvalidInputException( + "CSV File not supported for multithreading. This can be a problematic line in your CSV File or " + "that this CSV can't be read in Parallel. Please, inspect if the line %llu is correct. If so, " + "please run single-threaded CSV Reading by setting parallel=false in the read_csv call.", + problematic_line); + } +} +bool ParallelCSVGlobalState::Next(ClientContext &context, const ReadCSVData &bind_data, + unique_ptr &reader) { + lock_guard parallel_lock(main_mutex); + if (!buffer_manager && file_handle) { + buffer_manager = make_shared(context, std::move(file_handle), bind_data.options); + } + if (!buffer_manager) { + return false; + } + auto current_buffer = buffer_manager->GetBuffer(cur_buffer_idx); + auto next_buffer = buffer_manager->GetBuffer(cur_buffer_idx + 1); + + if (!current_buffer) { + // This means we are done with the current file, we need to go to the next one (if exists). + if (file_index < bind_data.files.size()) { + current_file_path = bind_data.files[file_index]; + file_handle = ReadCSV::OpenCSV(current_file_path, bind_data.options.compression, context); + buffer_manager = + make_shared(context, std::move(file_handle), bind_data.options, file_index); + cur_buffer_idx = 0; + first_position = 0; + local_batch_index = 0; + + line_info.lines_read[file_index++][local_batch_index] = (bind_data.options.has_header ? 1 : 0); + + current_buffer = buffer_manager->GetBuffer(cur_buffer_idx); + next_buffer = buffer_manager->GetBuffer(cur_buffer_idx + 1); + } else { + // We are done scanning. + reader.reset(); + return false; + } + } + // set up the current buffer + line_info.current_batches[file_index - 1].insert(local_batch_index); + idx_t bytes_per_local_state = current_buffer->actual_size / MaxThreads() + 1; + auto result = make_uniq( + buffer_manager->GetBuffer(cur_buffer_idx), buffer_manager->GetBuffer(cur_buffer_idx + 1), next_byte, + next_byte + bytes_per_local_state, batch_index++, local_batch_index++, &line_info); + // move the byte index of the CSV reader to the next buffer + next_byte += bytes_per_local_state; + if (next_byte >= current_buffer->actual_size) { + // We replace the current buffer with the next buffer + next_byte = 0; + bytes_read += current_buffer->actual_size; + current_buffer = std::move(next_buffer); + cur_buffer_idx++; + if (current_buffer) { + // Next buffer gets the next-next buffer + next_buffer = buffer_manager->GetBuffer(cur_buffer_idx + 1); + } + } + if (!reader || reader->options.file_path != current_file_path) { + // we either don't have a reader, or the reader was created for a different file + // we need to create a new reader and instantiate it + if (file_index > 0 && file_index <= bind_data.union_readers.size() && bind_data.union_readers[file_index - 1]) { + // we are doing UNION BY NAME - fetch the options from the union reader for this file + auto &union_reader = *bind_data.union_readers[file_index - 1]; + reader = make_uniq(context, union_reader.options, std::move(result), first_position, + union_reader.GetTypes(), file_index - 1); + reader->names = union_reader.GetNames(); + } else if (file_index <= bind_data.column_info.size()) { + // Serialized Union By name + reader = make_uniq(context, bind_data.options, std::move(result), first_position, + bind_data.column_info[file_index - 1].types, file_index - 1); + reader->names = bind_data.column_info[file_index - 1].names; + } else { + // regular file - use the standard options + if (!result) { + return false; + } + reader = make_uniq(context, bind_data.options, std::move(result), first_position, + bind_data.csv_types, file_index - 1); + reader->names = bind_data.csv_names; + } + reader->options.file_path = current_file_path; + MultiFileReader::InitializeReader(*reader, bind_data.options.file_options, bind_data.reader_bind, + bind_data.return_types, bind_data.return_names, column_ids, nullptr, + bind_data.files.front(), context); + } else { + // update the current reader + reader->SetBufferRead(std::move(result)); + } + + return true; +} +void ParallelCSVGlobalState::UpdateVerification(VerificationPositions positions, idx_t file_number_p, idx_t batch_idx) { + lock_guard parallel_lock(main_mutex); + if (positions.end_of_last_line > max_tuple_end) { + max_tuple_end = positions.end_of_last_line; + } + tuple_end_to_batch[file_number_p][positions.end_of_last_line] = batch_idx; + batch_to_tuple_end[file_number_p][batch_idx] = tuple_end[file_number_p].size(); + tuple_start[file_number_p].insert(positions.beginning_of_first_line); + tuple_end[file_number_p].push_back(positions.end_of_last_line); +} + +void ParallelCSVGlobalState::UpdateLinesRead(CSVBufferRead &buffer_read, idx_t file_idx) { + auto batch_idx = buffer_read.local_batch_index; + auto lines_read = buffer_read.lines_read; + lock_guard parallel_lock(main_mutex); + line_info.current_batches[file_idx].erase(batch_idx); + line_info.lines_read[file_idx][batch_idx] += lines_read; +} + +bool LineInfo::CanItGetLine(idx_t file_idx, idx_t batch_idx) { + lock_guard parallel_lock(main_mutex); + if (current_batches.empty() || done) { + return true; + } + if (file_idx >= current_batches.size() || current_batches[file_idx].empty()) { + return true; + } + auto min_value = *current_batches[file_idx].begin(); + if (min_value >= batch_idx) { + return true; + } + return false; +} + +void LineInfo::Increment(idx_t file_idx, idx_t batch_idx) { + auto parallel_lock = duckdb::make_uniq>(main_mutex); + lines_errored[file_idx][batch_idx]++; +} + +// Returns the 1-indexed line number +idx_t LineInfo::GetLine(idx_t batch_idx, idx_t line_error, idx_t file_idx, idx_t cur_start, bool verify, + bool stop_at_first) { + unique_ptr> parallel_lock; + if (!verify) { + parallel_lock = duckdb::make_uniq>(main_mutex); + } + idx_t line_count = 0; + + if (!stop_at_first) { + // Figure out the amount of lines read in the current file + for (idx_t cur_batch_idx = 0; cur_batch_idx <= batch_idx; cur_batch_idx++) { + if (cur_batch_idx < batch_idx) { + line_count += lines_errored[file_idx][cur_batch_idx]; + } + line_count += lines_read[file_idx][cur_batch_idx]; + } + return line_count + line_error + 1; + } + + // Otherwise, check if we already have an error on another thread + if (done) { + // line count is 0-indexed, but we want to return 1-indexed + return first_line + 1; + } + for (idx_t i = 0; i <= batch_idx; i++) { + if (lines_read[file_idx].find(i) == lines_read[file_idx].end() && i != batch_idx) { + throw InternalException("Missing batch index on Parallel CSV Reader GetLine"); + } + line_count += lines_read[file_idx][i]; + } + + // before we are done, if this is not a call in Verify() we must check Verify up to this batch + if (!verify) { + Verify(file_idx, batch_idx, cur_start); + } + done = true; + first_line = line_count + line_error; + // line count is 0-indexed, but we want to return 1-indexed + return first_line + 1; +} + +static unique_ptr ParallelCSVInitGlobal(ClientContext &context, + TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->CastNoConst(); + if (bind_data.files.empty()) { + // This can happen when a filename based filter pushdown has eliminated all possible files for this scan. + return make_uniq(context.db->NumberOfThreads()); + } + bind_data.options.file_path = bind_data.files[0]; + auto buffer_manager = bind_data.buffer_manager; + return make_uniq(context, buffer_manager, bind_data.options, context.db->NumberOfThreads(), + bind_data.files, ClientConfig::GetConfig(context).verify_parallelism, + input.column_ids); +} + +//===--------------------------------------------------------------------===// +// Read CSV Local State +//===--------------------------------------------------------------------===// +struct ParallelCSVLocalState : public LocalTableFunctionState { +public: + explicit ParallelCSVLocalState(unique_ptr csv_reader_p) : csv_reader(std::move(csv_reader_p)) { + } + + //! The CSV reader + unique_ptr csv_reader; + CSVBufferRead previous_buffer; + bool done = false; +}; + +unique_ptr ParallelReadCSVInitLocal(ExecutionContext &context, TableFunctionInitInput &input, + GlobalTableFunctionState *global_state_p) { + auto &csv_data = input.bind_data->Cast(); + auto &global_state = global_state_p->Cast(); + unique_ptr csv_reader; + auto has_next = global_state.Next(context.client, csv_data, csv_reader); + if (!has_next) { + global_state.DecrementThread(); + csv_reader.reset(); + } + return make_uniq(std::move(csv_reader)); +} + +static void ParallelReadCSVFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &csv_global_state = data_p.global_state->Cast(); + auto &csv_local_state = data_p.local_state->Cast(); + + if (!csv_local_state.csv_reader) { + // no csv_reader was set, this can happen when a filename-based filter has filtered out all possible files + return; + } + + do { + if (output.size() != 0) { + MultiFileReader::FinalizeChunk(bind_data.reader_bind, csv_local_state.csv_reader->reader_data, output); + break; + } + if (csv_local_state.csv_reader->finished) { + auto verification_updates = csv_local_state.csv_reader->GetVerificationPositions(); + csv_global_state.UpdateVerification(verification_updates, + csv_local_state.csv_reader->buffer->buffer->file_idx, + csv_local_state.csv_reader->buffer->local_batch_index); + csv_global_state.UpdateLinesRead(*csv_local_state.csv_reader->buffer, csv_local_state.csv_reader->file_idx); + auto has_next = csv_global_state.Next(context, bind_data, csv_local_state.csv_reader); + if (csv_local_state.csv_reader) { + csv_local_state.csv_reader->linenr = 0; + } + if (!has_next) { + csv_global_state.DecrementThread(); + break; + } + } + csv_local_state.csv_reader->ParseCSV(output); + + } while (true); + if (csv_global_state.Finished()) { + csv_global_state.Verify(); + } +} + +//===--------------------------------------------------------------------===// +// Single-Threaded CSV Reader +//===--------------------------------------------------------------------===// +struct SingleThreadedCSVState : public GlobalTableFunctionState { + explicit SingleThreadedCSVState(idx_t total_files) : total_files(total_files), next_file(0), progress_in_files(0) { + } + + mutex csv_lock; + unique_ptr initial_reader; + //! The total number of files to read from + idx_t total_files; + //! The index of the next file to read (i.e. current file + 1) + atomic next_file; + //! How far along we are in reading the current set of open files + //! This goes from [0...next_file] * 100 + atomic progress_in_files; + //! The set of SQL types + vector csv_types; + //! The set of SQL names to be read from the file + vector csv_names; + //! The column ids to read + vector column_ids; + + idx_t MaxThreads() const override { + return total_files; + } + + double GetProgress(const ReadCSVData &bind_data) const { + D_ASSERT(total_files == bind_data.files.size()); + D_ASSERT(progress_in_files <= total_files * 100); + return (double(progress_in_files) / double(total_files)); + } + + unique_ptr GetCSVReader(ClientContext &context, ReadCSVData &bind_data, idx_t &file_index, + idx_t &total_size) { + return GetCSVReaderInternal(context, bind_data, file_index, total_size); + } + +private: + unique_ptr GetCSVReaderInternal(ClientContext &context, ReadCSVData &bind_data, + idx_t &file_index, idx_t &total_size) { + CSVReaderOptions options; + { + lock_guard l(csv_lock); + if (initial_reader) { + total_size = initial_reader->file_handle ? initial_reader->file_handle->FileSize() : 0; + return std::move(initial_reader); + } + if (next_file >= total_files) { + return nullptr; + } + options = bind_data.options; + file_index = next_file; + next_file++; + } + // reuse csv_readers was created during binding + unique_ptr result; + if (file_index < bind_data.union_readers.size() && bind_data.union_readers[file_index]) { + result = std::move(bind_data.union_readers[file_index]); + } else { + auto union_by_name = options.file_options.union_by_name; + options.file_path = bind_data.files[file_index]; + result = make_uniq(context, std::move(options), csv_types); + if (!union_by_name) { + result->names = csv_names; + } + MultiFileReader::InitializeReader(*result, bind_data.options.file_options, bind_data.reader_bind, + bind_data.return_types, bind_data.return_names, column_ids, nullptr, + bind_data.files.front(), context); + } + total_size = result->file_handle->FileSize(); + return result; + } +}; + +struct SingleThreadedCSVLocalState : public LocalTableFunctionState { +public: + explicit SingleThreadedCSVLocalState() : bytes_read(0), total_size(0), current_progress(0), file_index(0) { + } + + //! The CSV reader + unique_ptr csv_reader; + //! The current amount of bytes read by this reader + idx_t bytes_read; + //! The total amount of bytes in the file + idx_t total_size; + //! The current progress from 0..100 + idx_t current_progress; + //! The file index of this reader + idx_t file_index; +}; + +static unique_ptr SingleThreadedCSVInit(ClientContext &context, + TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->CastNoConst(); + auto result = make_uniq(bind_data.files.size()); + if (bind_data.files.empty()) { + // This can happen when a filename based filter pushdown has eliminated all possible files for this scan. + return std::move(result); + } else { + bind_data.options.file_path = bind_data.files[0]; + result->initial_reader = make_uniq(context, bind_data.options, bind_data.csv_types); + if (!bind_data.options.file_options.union_by_name) { + result->initial_reader->names = bind_data.csv_names; + } + if (bind_data.options.auto_detect) { + bind_data.options = result->initial_reader->options; + } + } + MultiFileReader::InitializeReader(*result->initial_reader, bind_data.options.file_options, bind_data.reader_bind, + bind_data.return_types, bind_data.return_names, input.column_ids, input.filters, + bind_data.files.front(), context); + for (auto &reader : bind_data.union_readers) { + if (!reader) { + continue; + } + MultiFileReader::InitializeReader(*reader, bind_data.options.file_options, bind_data.reader_bind, + bind_data.return_types, bind_data.return_names, input.column_ids, + input.filters, bind_data.files.front(), context); + } + result->column_ids = input.column_ids; + + if (!bind_data.options.file_options.union_by_name) { + // if we are reading multiple files - run auto-detect only on the first file + // UNLESS union by name is turned on - in that case we assume that different files have different schemas + // as such, we need to re-run the auto detection on each file + bind_data.options.auto_detect = false; + } + result->csv_types = bind_data.csv_types; + result->csv_names = bind_data.csv_names; + result->next_file = 1; + return std::move(result); +} + +unique_ptr SingleThreadedReadCSVInitLocal(ExecutionContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state_p) { + auto &bind_data = input.bind_data->CastNoConst(); + auto &data = global_state_p->Cast(); + auto result = make_uniq(); + result->csv_reader = data.GetCSVReader(context.client, bind_data, result->file_index, result->total_size); + return std::move(result); +} + +static void SingleThreadedCSVFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->CastNoConst(); + auto &data = data_p.global_state->Cast(); + auto &lstate = data_p.local_state->Cast(); + if (!lstate.csv_reader) { + // no csv_reader was set, this can happen when a filename-based filter has filtered out all possible files + return; + } + + do { + lstate.csv_reader->ParseCSV(output); + // update the number of bytes read + D_ASSERT(lstate.bytes_read <= lstate.csv_reader->bytes_in_chunk); + auto bytes_read = MinValue(lstate.total_size, lstate.csv_reader->bytes_in_chunk); + auto current_progress = lstate.total_size == 0 ? 100 : 100 * bytes_read / lstate.total_size; + if (current_progress > lstate.current_progress) { + if (current_progress > 100) { + throw InternalException("Progress should never exceed 100"); + } + data.progress_in_files += current_progress - lstate.current_progress; + lstate.current_progress = current_progress; + } + if (output.size() == 0) { + // exhausted this file, but we might have more files we can read + auto csv_reader = data.GetCSVReader(context, bind_data, lstate.file_index, lstate.total_size); + // add any left-over progress for this file to the progress bar + if (lstate.current_progress < 100) { + data.progress_in_files += 100 - lstate.current_progress; + } + // reset the current progress + lstate.current_progress = 0; + lstate.bytes_read = 0; + lstate.csv_reader = std::move(csv_reader); + if (!lstate.csv_reader) { + // no more files - we are done + return; + } + lstate.bytes_read = 0; + } else { + MultiFileReader::FinalizeChunk(bind_data.reader_bind, lstate.csv_reader->reader_data, output); + break; + } + } while (true); +} + +//===--------------------------------------------------------------------===// +// Read CSV Functions +//===--------------------------------------------------------------------===// +static unique_ptr ReadCSVInitGlobal(ClientContext &context, TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + + // Create the temporary rejects table + auto rejects_table = bind_data.options.rejects_table_name; + if (!rejects_table.empty()) { + CSVRejectsTable::GetOrCreate(context, rejects_table)->InitializeTable(context, bind_data); + } + if (bind_data.single_threaded) { + return SingleThreadedCSVInit(context, input); + } else { + return ParallelCSVInitGlobal(context, input); + } +} + +unique_ptr ReadCSVInitLocal(ExecutionContext &context, TableFunctionInitInput &input, + GlobalTableFunctionState *global_state_p) { + auto &csv_data = input.bind_data->Cast(); + if (csv_data.single_threaded) { + return SingleThreadedReadCSVInitLocal(context, input, global_state_p); + } else { + return ParallelReadCSVInitLocal(context, input, global_state_p); + } +} + +static void ReadCSVFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + if (bind_data.single_threaded) { + SingleThreadedCSVFunction(context, data_p, output); + } else { + ParallelReadCSVFunction(context, data_p, output); + } +} + +static idx_t CSVReaderGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, + LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state) { + auto &bind_data = bind_data_p->Cast(); + if (bind_data.single_threaded) { + auto &data = local_state->Cast(); + return data.file_index; + } + auto &data = local_state->Cast(); + return data.csv_reader->buffer->batch_index; +} + +static void ReadCSVAddNamedParameters(TableFunction &table_function) { + table_function.named_parameters["sep"] = LogicalType::VARCHAR; + table_function.named_parameters["delim"] = LogicalType::VARCHAR; + table_function.named_parameters["quote"] = LogicalType::VARCHAR; + table_function.named_parameters["new_line"] = LogicalType::VARCHAR; + table_function.named_parameters["escape"] = LogicalType::VARCHAR; + table_function.named_parameters["nullstr"] = LogicalType::VARCHAR; + table_function.named_parameters["columns"] = LogicalType::ANY; + table_function.named_parameters["auto_type_candidates"] = LogicalType::ANY; + table_function.named_parameters["header"] = LogicalType::BOOLEAN; + table_function.named_parameters["auto_detect"] = LogicalType::BOOLEAN; + table_function.named_parameters["sample_size"] = LogicalType::BIGINT; + table_function.named_parameters["all_varchar"] = LogicalType::BOOLEAN; + table_function.named_parameters["dateformat"] = LogicalType::VARCHAR; + table_function.named_parameters["timestampformat"] = LogicalType::VARCHAR; + table_function.named_parameters["normalize_names"] = LogicalType::BOOLEAN; + table_function.named_parameters["compression"] = LogicalType::VARCHAR; + table_function.named_parameters["skip"] = LogicalType::BIGINT; + table_function.named_parameters["max_line_size"] = LogicalType::VARCHAR; + table_function.named_parameters["maximum_line_size"] = LogicalType::VARCHAR; + table_function.named_parameters["ignore_errors"] = LogicalType::BOOLEAN; + table_function.named_parameters["rejects_table"] = LogicalType::VARCHAR; + table_function.named_parameters["rejects_limit"] = LogicalType::BIGINT; + table_function.named_parameters["rejects_recovery_columns"] = LogicalType::LIST(LogicalType::VARCHAR); + table_function.named_parameters["buffer_size"] = LogicalType::UBIGINT; + table_function.named_parameters["decimal_separator"] = LogicalType::VARCHAR; + table_function.named_parameters["parallel"] = LogicalType::BOOLEAN; + table_function.named_parameters["null_padding"] = LogicalType::BOOLEAN; + table_function.named_parameters["allow_quoted_nulls"] = LogicalType::BOOLEAN; + table_function.named_parameters["column_types"] = LogicalType::ANY; + table_function.named_parameters["dtypes"] = LogicalType::ANY; + table_function.named_parameters["types"] = LogicalType::ANY; + table_function.named_parameters["names"] = LogicalType::LIST(LogicalType::VARCHAR); + table_function.named_parameters["column_names"] = LogicalType::LIST(LogicalType::VARCHAR); + MultiFileReader::AddParameters(table_function); +} + +double CSVReaderProgress(ClientContext &context, const FunctionData *bind_data_p, + const GlobalTableFunctionState *global_state) { + auto &bind_data = bind_data_p->Cast(); + if (bind_data.single_threaded) { + auto &data = global_state->Cast(); + return data.GetProgress(bind_data); + } else { + auto &data = global_state->Cast(); + return data.GetProgress(bind_data); + } +} + +void CSVComplexFilterPushdown(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, + vector> &filters) { + auto &data = bind_data_p->Cast(); + auto reset_reader = + MultiFileReader::ComplexFilterPushdown(context, data.files, data.options.file_options, get, filters); + if (reset_reader) { + MultiFileReader::PruneReaders(data); + } +} + +unique_ptr CSVReaderCardinality(ClientContext &context, const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + idx_t per_file_cardinality = 0; + if (bind_data.buffer_manager && bind_data.buffer_manager->file_handle) { + auto estimated_row_width = (bind_data.csv_types.size() * 5); + per_file_cardinality = bind_data.buffer_manager->file_handle->FileSize() / estimated_row_width; + } else { + // determined through the scientific method as the average amount of rows in a CSV file + per_file_cardinality = 42; + } + return make_uniq(bind_data.files.size() * per_file_cardinality); +} + +static void CSVReaderSerialize(Serializer &serializer, const optional_ptr bind_data_p, + const TableFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "extra_info", function.extra_info); + serializer.WriteProperty(101, "csv_data", &bind_data); +} + +static unique_ptr CSVReaderDeserialize(Deserializer &deserializer, TableFunction &function) { + unique_ptr result; + deserializer.ReadProperty(100, "extra_info", function.extra_info); + deserializer.ReadProperty(101, "csv_data", result); + return std::move(result); +} + +TableFunction ReadCSVTableFunction::GetFunction() { + TableFunction read_csv("read_csv", {LogicalType::VARCHAR}, ReadCSVFunction, ReadCSVBind, ReadCSVInitGlobal, + ReadCSVInitLocal); + read_csv.table_scan_progress = CSVReaderProgress; + read_csv.pushdown_complex_filter = CSVComplexFilterPushdown; + read_csv.serialize = CSVReaderSerialize; + read_csv.deserialize = CSVReaderDeserialize; + read_csv.get_batch_index = CSVReaderGetBatchIndex; + read_csv.cardinality = CSVReaderCardinality; + read_csv.projection_pushdown = true; + ReadCSVAddNamedParameters(read_csv); + return read_csv; +} + +TableFunction ReadCSVTableFunction::GetAutoFunction() { + auto read_csv_auto = ReadCSVTableFunction::GetFunction(); + read_csv_auto.name = "read_csv_auto"; + read_csv_auto.bind = ReadCSVAutoBind; + return read_csv_auto; +} + +void ReadCSVTableFunction::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(MultiFileReader::CreateFunctionSet(ReadCSVTableFunction::GetFunction())); + set.AddFunction(MultiFileReader::CreateFunctionSet(ReadCSVTableFunction::GetAutoFunction())); +} + +unique_ptr ReadCSVReplacement(ClientContext &context, const string &table_name, ReplacementScanData *data) { + auto lower_name = StringUtil::Lower(table_name); + // remove any compression + if (StringUtil::EndsWith(lower_name, ".gz")) { + lower_name = lower_name.substr(0, lower_name.size() - 3); + } else if (StringUtil::EndsWith(lower_name, ".zst")) { + if (!Catalog::TryAutoLoad(context, "parquet")) { + throw MissingExtensionException("parquet extension is required for reading zst compressed file"); + } + lower_name = lower_name.substr(0, lower_name.size() - 4); + } + if (!StringUtil::EndsWith(lower_name, ".csv") && !StringUtil::Contains(lower_name, ".csv?") && + !StringUtil::EndsWith(lower_name, ".tsv") && !StringUtil::Contains(lower_name, ".tsv?")) { + return nullptr; + } + auto table_function = make_uniq(); + vector> children; + children.push_back(make_uniq(Value(table_name))); + table_function->function = make_uniq("read_csv_auto", std::move(children)); + + if (!FileSystem::HasGlob(table_name)) { + auto &fs = FileSystem::GetFileSystem(context); + table_function->alias = fs.ExtractBaseName(table_name); + } + + return std::move(table_function); +} + +void BuiltinFunctions::RegisterReadFunctions() { + CSVCopyFunction::RegisterFunction(*this); + ReadCSVTableFunction::RegisterFunction(*this); + auto &config = DBConfig::GetConfig(*transaction.db); + config.replacement_scans.emplace_back(ReadCSVReplacement); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/repeat.cpp b/src/duckdb/src/function/table/repeat.cpp new file mode 100644 index 00000000..b62cbe15 --- /dev/null +++ b/src/duckdb/src/function/table/repeat.cpp @@ -0,0 +1,57 @@ +#include "duckdb/function/table/range.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +struct RepeatFunctionData : public TableFunctionData { + RepeatFunctionData(Value value, idx_t target_count) : value(std::move(value)), target_count(target_count) { + } + + Value value; + idx_t target_count; +}; + +struct RepeatOperatorData : public GlobalTableFunctionState { + RepeatOperatorData() : current_count(0) { + } + idx_t current_count; +}; + +static unique_ptr RepeatBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + // the repeat function returns the type of the first argument + auto &inputs = input.inputs; + return_types.push_back(inputs[0].type()); + names.push_back(inputs[0].ToString()); + if (inputs[1].IsNull()) { + throw BinderException("Repeat second parameter cannot be NULL"); + } + return make_uniq(inputs[0], inputs[1].GetValue()); +} + +static unique_ptr RepeatInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void RepeatFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.global_state->Cast(); + + idx_t remaining = MinValue(bind_data.target_count - state.current_count, STANDARD_VECTOR_SIZE); + output.data[0].Reference(bind_data.value); + output.SetCardinality(remaining); + state.current_count += remaining; +} + +static unique_ptr RepeatCardinality(ClientContext &context, const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + return make_uniq(bind_data.target_count, bind_data.target_count); +} + +void RepeatTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunction repeat("repeat", {LogicalType::ANY, LogicalType::BIGINT}, RepeatFunction, RepeatBind, RepeatInit); + repeat.cardinality = RepeatCardinality; + set.AddFunction(repeat); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/repeat_row.cpp b/src/duckdb/src/function/table/repeat_row.cpp new file mode 100644 index 00000000..a81d8720 --- /dev/null +++ b/src/duckdb/src/function/table/repeat_row.cpp @@ -0,0 +1,67 @@ +#include "duckdb/function/table/range.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +struct RepeatRowFunctionData : public TableFunctionData { + RepeatRowFunctionData(vector values, idx_t target_count) + : values(std::move(values)), target_count(target_count) { + } + + const vector values; + idx_t target_count; +}; + +struct RepeatRowOperatorData : public GlobalTableFunctionState { + RepeatRowOperatorData() : current_count(0) { + } + idx_t current_count; +}; + +static unique_ptr RepeatRowBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto &inputs = input.inputs; + for (idx_t input_idx = 0; input_idx < inputs.size(); input_idx++) { + return_types.push_back(inputs[input_idx].type()); + names.push_back("column" + std::to_string(input_idx)); + } + auto entry = input.named_parameters.find("num_rows"); + if (entry == input.named_parameters.end()) { + throw BinderException("repeat_rows requires num_rows to be specified"); + } + if (inputs.empty()) { + throw BinderException("repeat_rows requires at least one column to be specified"); + } + return make_uniq(inputs, entry->second.GetValue()); +} + +static unique_ptr RepeatRowInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void RepeatRowFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.global_state->Cast(); + + idx_t remaining = MinValue(bind_data.target_count - state.current_count, STANDARD_VECTOR_SIZE); + for (idx_t val_idx = 0; val_idx < bind_data.values.size(); val_idx++) { + output.data[val_idx].Reference(bind_data.values[val_idx]); + } + output.SetCardinality(remaining); + state.current_count += remaining; +} + +static unique_ptr RepeatRowCardinality(ClientContext &context, const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + return make_uniq(bind_data.target_count, bind_data.target_count); +} + +void RepeatRowTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunction repeat_row("repeat_row", {}, RepeatRowFunction, RepeatRowBind, RepeatRowInit); + repeat_row.varargs = LogicalType::ANY; + repeat_row.named_parameters["num_rows"] = LogicalType::BIGINT; + repeat_row.cardinality = RepeatRowCardinality; + set.AddFunction(repeat_row); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/summary.cpp b/src/duckdb/src/function/table/summary.cpp new file mode 100644 index 00000000..d6c4615e --- /dev/null +++ b/src/duckdb/src/function/table/summary.cpp @@ -0,0 +1,52 @@ +#include "duckdb/function/table/summary.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/common/file_system.hpp" + +// this function makes not that much sense on its own but is a demo for table-parameter table-producing functions + +namespace duckdb { + +static unique_ptr SummaryFunctionBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("summary"); + + for (idx_t i = 0; i < input.input_table_types.size(); i++) { + return_types.push_back(input.input_table_types[i]); + names.emplace_back(input.input_table_names[i]); + } + + return make_uniq(); +} + +static OperatorResultType SummaryFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, + DataChunk &output) { + output.SetCardinality(input.size()); + + for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) { + string summary_val = "["; + + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + summary_val += input.GetValue(col_idx, row_idx).ToString(); + if (col_idx < input.ColumnCount() - 1) { + summary_val += ", "; + } + } + summary_val += "]"; + output.SetValue(0, row_idx, Value(summary_val)); + } + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + output.data[col_idx + 1].Reference(input.data[col_idx]); + } + return OperatorResultType::NEED_MORE_INPUT; +} + +void SummaryTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunction summary_function("summary", {LogicalType::TABLE}, nullptr, SummaryFunctionBind); + summary_function.in_out_function = SummaryFunction; + set.AddFunction(summary_function); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_columns.cpp b/src/duckdb/src/function/table/system/duckdb_columns.cpp new file mode 100644 index 00000000..42d8203e --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_columns.cpp @@ -0,0 +1,334 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/parser/constraints/not_null_constraint.hpp" + +#include + +namespace duckdb { + +struct DuckDBColumnsData : public GlobalTableFunctionState { + DuckDBColumnsData() : offset(0), column_offset(0) { + } + + vector> entries; + idx_t offset; + idx_t column_offset; +}; + +static unique_ptr DuckDBColumnsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("table_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("table_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("column_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("column_index"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("internal"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("column_default"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("is_nullable"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("data_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("data_type_id"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("character_maximum_length"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("numeric_precision"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("numeric_precision_radix"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("numeric_scale"); + return_types.emplace_back(LogicalType::INTEGER); + + return nullptr; +} + +unique_ptr DuckDBColumnsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and views and collect them + auto schemas = Catalog::GetAllSchemas(context); + for (auto &schema : schemas) { + schema.get().Scan(context, CatalogType::TABLE_ENTRY, + [&](CatalogEntry &entry) { result->entries.push_back(entry); }); + } + return std::move(result); +} + +class ColumnHelper { +public: + static unique_ptr Create(CatalogEntry &entry); + + virtual ~ColumnHelper() { + } + + virtual StandardEntry &Entry() = 0; + virtual idx_t NumColumns() = 0; + virtual const string &ColumnName(idx_t col) = 0; + virtual const LogicalType &ColumnType(idx_t col) = 0; + virtual const Value ColumnDefault(idx_t col) = 0; + virtual bool IsNullable(idx_t col) = 0; + + void WriteColumns(idx_t index, idx_t start_col, idx_t end_col, DataChunk &output); +}; + +class TableColumnHelper : public ColumnHelper { +public: + explicit TableColumnHelper(TableCatalogEntry &entry) : entry(entry) { + for (auto &constraint : entry.GetConstraints()) { + if (constraint->type == ConstraintType::NOT_NULL) { + auto ¬_null = *reinterpret_cast(constraint.get()); + not_null_cols.insert(not_null.index.index); + } + } + } + + StandardEntry &Entry() override { + return entry; + } + idx_t NumColumns() override { + return entry.GetColumns().LogicalColumnCount(); + } + const string &ColumnName(idx_t col) override { + return entry.GetColumn(LogicalIndex(col)).Name(); + } + const LogicalType &ColumnType(idx_t col) override { + return entry.GetColumn(LogicalIndex(col)).Type(); + } + const Value ColumnDefault(idx_t col) override { + auto &column = entry.GetColumn(LogicalIndex(col)); + if (column.Generated()) { + return Value(column.GeneratedExpression().ToString()); + } else if (column.DefaultValue()) { + return Value(column.DefaultValue()->ToString()); + } + return Value(); + } + bool IsNullable(idx_t col) override { + return not_null_cols.find(col) == not_null_cols.end(); + } + +private: + TableCatalogEntry &entry; + std::set not_null_cols; +}; + +class ViewColumnHelper : public ColumnHelper { +public: + explicit ViewColumnHelper(ViewCatalogEntry &entry) : entry(entry) { + } + + StandardEntry &Entry() override { + return entry; + } + idx_t NumColumns() override { + return entry.types.size(); + } + const string &ColumnName(idx_t col) override { + return entry.aliases[col]; + } + const LogicalType &ColumnType(idx_t col) override { + return entry.types[col]; + } + const Value ColumnDefault(idx_t col) override { + return Value(); + } + bool IsNullable(idx_t col) override { + return true; + } + +private: + ViewCatalogEntry &entry; +}; + +unique_ptr ColumnHelper::Create(CatalogEntry &entry) { + switch (entry.type) { + case CatalogType::TABLE_ENTRY: + return make_uniq(entry.Cast()); + case CatalogType::VIEW_ENTRY: + return make_uniq(entry.Cast()); + default: + throw NotImplementedException("Unsupported catalog type for duckdb_columns"); + } +} + +void ColumnHelper::WriteColumns(idx_t start_index, idx_t start_col, idx_t end_col, DataChunk &output) { + for (idx_t i = start_col; i < end_col; i++) { + auto index = start_index + (i - start_col); + auto &entry = Entry(); + + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, index, entry.catalog.GetName()); + // database_oid, BIGINT + output.SetValue(col++, index, Value::BIGINT(entry.catalog.GetOid())); + // schema_name, VARCHAR + output.SetValue(col++, index, entry.schema.name); + // schema_oid, BIGINT + output.SetValue(col++, index, Value::BIGINT(entry.schema.oid)); + // table_name, VARCHAR + output.SetValue(col++, index, entry.name); + // table_oid, BIGINT + output.SetValue(col++, index, Value::BIGINT(entry.oid)); + // column_name, VARCHAR + output.SetValue(col++, index, Value(ColumnName(i))); + // column_index, INTEGER + output.SetValue(col++, index, Value::INTEGER(i + 1)); + // internal, BOOLEAN + output.SetValue(col++, index, Value::BOOLEAN(entry.internal)); + // column_default, VARCHAR + output.SetValue(col++, index, Value(ColumnDefault(i))); + // is_nullable, BOOLEAN + output.SetValue(col++, index, Value::BOOLEAN(IsNullable(i))); + // data_type, VARCHAR + const LogicalType &type = ColumnType(i); + output.SetValue(col++, index, Value(type.ToString())); + // data_type_id, BIGINT + output.SetValue(col++, index, Value::BIGINT(int(type.id()))); + if (type == LogicalType::VARCHAR) { + // FIXME: need check constraints in place to set this correctly + // character_maximum_length, INTEGER + output.SetValue(col++, index, Value()); + } else { + // "character_maximum_length", PhysicalType::INTEGER + output.SetValue(col++, index, Value()); + } + + Value numeric_precision, numeric_scale, numeric_precision_radix; + switch (type.id()) { + case LogicalTypeId::DECIMAL: + numeric_precision = Value::INTEGER(DecimalType::GetWidth(type)); + numeric_scale = Value::INTEGER(DecimalType::GetScale(type)); + numeric_precision_radix = Value::INTEGER(10); + break; + case LogicalTypeId::HUGEINT: + numeric_precision = Value::INTEGER(128); + numeric_scale = Value::INTEGER(0); + numeric_precision_radix = Value::INTEGER(2); + break; + case LogicalTypeId::BIGINT: + numeric_precision = Value::INTEGER(64); + numeric_scale = Value::INTEGER(0); + numeric_precision_radix = Value::INTEGER(2); + break; + case LogicalTypeId::INTEGER: + numeric_precision = Value::INTEGER(32); + numeric_scale = Value::INTEGER(0); + numeric_precision_radix = Value::INTEGER(2); + break; + case LogicalTypeId::SMALLINT: + numeric_precision = Value::INTEGER(16); + numeric_scale = Value::INTEGER(0); + numeric_precision_radix = Value::INTEGER(2); + break; + case LogicalTypeId::TINYINT: + numeric_precision = Value::INTEGER(8); + numeric_scale = Value::INTEGER(0); + numeric_precision_radix = Value::INTEGER(2); + break; + case LogicalTypeId::FLOAT: + numeric_precision = Value::INTEGER(24); + numeric_scale = Value::INTEGER(0); + numeric_precision_radix = Value::INTEGER(2); + break; + case LogicalTypeId::DOUBLE: + numeric_precision = Value::INTEGER(53); + numeric_scale = Value::INTEGER(0); + numeric_precision_radix = Value::INTEGER(2); + break; + default: + numeric_precision = Value(); + numeric_scale = Value(); + numeric_precision_radix = Value(); + break; + } + + // numeric_precision, INTEGER + output.SetValue(col++, index, numeric_precision); + // numeric_precision_radix, INTEGER + output.SetValue(col++, index, numeric_precision_radix); + // numeric_scale, INTEGER + output.SetValue(col++, index, numeric_scale); + } +} + +void DuckDBColumnsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + + // We need to track the offset of the relation we're writing as well as the last column + // we wrote from that relation (if any); it's possible that we can fill up the output + // with a partial list of columns from a relation and will need to pick up processing the + // next chunk at the same spot. + idx_t next = data.offset; + idx_t column_offset = data.column_offset; + idx_t index = 0; + while (next < data.entries.size() && index < STANDARD_VECTOR_SIZE) { + auto column_helper = ColumnHelper::Create(data.entries[next].get()); + idx_t columns = column_helper->NumColumns(); + + // Check to see if we are going to exceed the maximum index for a DataChunk + if (index + (columns - column_offset) > STANDARD_VECTOR_SIZE) { + idx_t column_limit = column_offset + (STANDARD_VECTOR_SIZE - index); + output.SetCardinality(STANDARD_VECTOR_SIZE); + column_helper->WriteColumns(index, column_offset, column_limit, output); + + // Make the current column limit the column offset when we process the next chunk + column_offset = column_limit; + break; + } else { + // Otherwise, write all of the columns from the current relation and + // then move on to the next one. + output.SetCardinality(index + (columns - column_offset)); + column_helper->WriteColumns(index, column_offset, columns, output); + index += columns - column_offset; + next++; + column_offset = 0; + } + } + data.offset = next; + data.column_offset = column_offset; +} + +void DuckDBColumnsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_columns", {}, DuckDBColumnsFunction, DuckDBColumnsBind, DuckDBColumnsInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_constraints.cpp b/src/duckdb/src/function/table/system/duckdb_constraints.cpp new file mode 100644 index 00000000..560c6b3a --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_constraints.cpp @@ -0,0 +1,314 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/constraints/check_constraint.hpp" +#include "duckdb/parser/constraints/unique_constraint.hpp" +#include "duckdb/planner/constraints/bound_unique_constraint.hpp" +#include "duckdb/planner/constraints/bound_check_constraint.hpp" +#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" +#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" +#include "duckdb/storage/data_table.hpp" + +namespace duckdb { + +struct UniqueKeyInfo { + string schema; + string table; + vector columns; + + bool operator==(const UniqueKeyInfo &other) const { + return (schema == other.schema) && (table == other.table) && (columns == other.columns); + } +}; + +} // namespace duckdb + +namespace std { + +template <> +struct hash { + template + static size_t ComputeHash(const X &x) { + return hash()(x); + } + + size_t operator()(const duckdb::UniqueKeyInfo &j) const { + D_ASSERT(j.columns.size() > 0); + return ComputeHash(j.schema) + ComputeHash(j.table) + ComputeHash(j.columns[0].index); + } +}; + +} // namespace std + +namespace duckdb { + +struct DuckDBConstraintsData : public GlobalTableFunctionState { + DuckDBConstraintsData() : offset(0), constraint_offset(0), unique_constraint_offset(0) { + } + + vector> entries; + idx_t offset; + idx_t constraint_offset; + idx_t unique_constraint_offset; + unordered_map known_fk_unique_constraint_offsets; +}; + +static unique_ptr DuckDBConstraintsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("table_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("table_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("constraint_index"); + return_types.emplace_back(LogicalType::BIGINT); + + // CHECK, PRIMARY KEY or UNIQUE + names.emplace_back("constraint_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("constraint_text"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("expression"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("constraint_column_indexes"); + return_types.push_back(LogicalType::LIST(LogicalType::BIGINT)); + + names.emplace_back("constraint_column_names"); + return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); + + return nullptr; +} + +unique_ptr DuckDBConstraintsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and collect them + auto schemas = Catalog::GetAllSchemas(context); + + for (auto &schema : schemas) { + vector> entries; + + schema.get().Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { + if (entry.type == CatalogType::TABLE_ENTRY) { + entries.push_back(entry); + } + }); + + sort(entries.begin(), entries.end(), [&](CatalogEntry &x, CatalogEntry &y) { return (x.name < y.name); }); + + result->entries.insert(result->entries.end(), entries.begin(), entries.end()); + }; + + return std::move(result); +} + +void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset].get(); + D_ASSERT(entry.type == CatalogType::TABLE_ENTRY); + + auto &table = entry.Cast(); + auto &constraints = table.GetConstraints(); + bool is_duck_table = table.IsDuckTable(); + for (; data.constraint_offset < constraints.size() && count < STANDARD_VECTOR_SIZE; data.constraint_offset++) { + auto &constraint = constraints[data.constraint_offset]; + // return values: + // constraint_type, VARCHAR + // Processing this first due to shortcut (early continue) + string constraint_type; + switch (constraint->type) { + case ConstraintType::CHECK: + constraint_type = "CHECK"; + break; + case ConstraintType::UNIQUE: { + auto &unique = constraint->Cast(); + constraint_type = unique.is_primary_key ? "PRIMARY KEY" : "UNIQUE"; + break; + } + case ConstraintType::NOT_NULL: + constraint_type = "NOT NULL"; + break; + case ConstraintType::FOREIGN_KEY: { + if (!is_duck_table) { + continue; + } + auto &bound_constraints = table.GetBoundConstraints(); + auto &bound_foreign_key = bound_constraints[data.constraint_offset]->Cast(); + if (bound_foreign_key.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE) { + // Those are already covered by PRIMARY KEY and UNIQUE entries + continue; + } + constraint_type = "FOREIGN KEY"; + break; + } + default: + throw NotImplementedException("Unimplemented constraint for duckdb_constraints"); + } + + idx_t col = 0; + // database_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(table.schema.catalog.GetName())); + // database_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(table.schema.catalog.GetOid())); + // schema_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(table.schema.name)); + // schema_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(table.schema.oid)); + // table_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(table.name)); + // table_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(table.oid)); + + // constraint_index, BIGINT + UniqueKeyInfo uk_info; + + if (is_duck_table) { + auto &bound_constraint = *table.GetBoundConstraints()[data.constraint_offset]; + switch (bound_constraint.type) { + case ConstraintType::UNIQUE: { + auto &bound_unique = bound_constraint.Cast(); + uk_info = {table.schema.name, table.name, bound_unique.keys}; + break; + } + case ConstraintType::FOREIGN_KEY: { + const auto &bound_foreign_key = bound_constraint.Cast(); + const auto &info = bound_foreign_key.info; + // find the other table + auto table_entry = Catalog::GetEntry( + context, table.catalog.GetName(), info.schema, info.table, OnEntryNotFound::RETURN_NULL); + if (!table_entry) { + throw InternalException("dukdb_constraints: entry %s.%s referenced in foreign key not found", + info.schema, info.table); + } + vector index; + for (auto &key : info.pk_keys) { + index.push_back(table_entry->GetColumns().PhysicalToLogical(key)); + } + uk_info = {table_entry->schema.name, table_entry->name, index}; + break; + } + default: + break; + } + } + + if (uk_info.columns.empty()) { + output.SetValue(col++, count, Value::BIGINT(data.unique_constraint_offset++)); + } else { + auto known_unique_constraint_offset = data.known_fk_unique_constraint_offsets.find(uk_info); + if (known_unique_constraint_offset == data.known_fk_unique_constraint_offsets.end()) { + data.known_fk_unique_constraint_offsets.insert(make_pair(uk_info, data.unique_constraint_offset)); + output.SetValue(col++, count, Value::BIGINT(data.unique_constraint_offset)); + data.unique_constraint_offset++; + } else { + output.SetValue(col++, count, Value::BIGINT(known_unique_constraint_offset->second)); + } + } + output.SetValue(col++, count, Value(constraint_type)); + + // constraint_text, VARCHAR + output.SetValue(col++, count, Value(constraint->ToString())); + + // expression, VARCHAR + Value expression_text; + if (constraint->type == ConstraintType::CHECK) { + auto &check = constraint->Cast(); + expression_text = Value(check.expression->ToString()); + } + output.SetValue(col++, count, expression_text); + + vector column_index_list; + if (is_duck_table) { + auto &bound_constraint = *table.GetBoundConstraints()[data.constraint_offset]; + switch (bound_constraint.type) { + case ConstraintType::CHECK: { + auto &bound_check = bound_constraint.Cast(); + for (auto &col_idx : bound_check.bound_columns) { + column_index_list.push_back(table.GetColumns().PhysicalToLogical(col_idx)); + } + break; + } + case ConstraintType::UNIQUE: { + auto &bound_unique = bound_constraint.Cast(); + for (auto &col_idx : bound_unique.keys) { + column_index_list.push_back(col_idx); + } + break; + } + case ConstraintType::NOT_NULL: { + auto &bound_not_null = bound_constraint.Cast(); + column_index_list.push_back(table.GetColumns().PhysicalToLogical(bound_not_null.index)); + break; + } + case ConstraintType::FOREIGN_KEY: { + auto &bound_foreign_key = bound_constraint.Cast(); + for (auto &col_idx : bound_foreign_key.info.fk_keys) { + column_index_list.push_back(table.GetColumns().PhysicalToLogical(col_idx)); + } + break; + } + default: + throw NotImplementedException("Unimplemented constraint for duckdb_constraints"); + } + } + + vector index_list; + vector column_name_list; + for (auto column_index : column_index_list) { + index_list.push_back(Value::BIGINT(column_index.index)); + column_name_list.emplace_back(table.GetColumn(column_index).Name()); + } + + // constraint_column_indexes, LIST + output.SetValue(col++, count, Value::LIST(LogicalType::BIGINT, std::move(index_list))); + + // constraint_column_names, LIST + output.SetValue(col++, count, Value::LIST(LogicalType::VARCHAR, std::move(column_name_list))); + + count++; + } + if (data.constraint_offset >= constraints.size()) { + data.constraint_offset = 0; + data.offset++; + } + } + output.SetCardinality(count); +} + +void DuckDBConstraintsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_constraints", {}, DuckDBConstraintsFunction, DuckDBConstraintsBind, + DuckDBConstraintsInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_databases.cpp b/src/duckdb/src/function/table/system/duckdb_databases.cpp new file mode 100644 index 00000000..981c7f84 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_databases.cpp @@ -0,0 +1,89 @@ +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/attached_database.hpp" + +namespace duckdb { + +struct DuckDBDatabasesData : public GlobalTableFunctionState { + DuckDBDatabasesData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +static unique_ptr DuckDBDatabasesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("path"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("internal"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("type"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBDatabasesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and collect them and collect them + auto &db_manager = DatabaseManager::Get(context); + result->entries = db_manager.GetDatabases(context); + return std::move(result); +} + +void DuckDBDatabasesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset++]; + + auto &attached = entry.get().Cast(); + // return values: + + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, count, attached.GetName()); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(attached.oid)); + // path, VARCHAR + bool is_internal = attached.IsSystem() || attached.IsTemporary(); + Value db_path; + if (!is_internal) { + bool in_memory = attached.GetCatalog().InMemory(); + if (!in_memory) { + db_path = Value(attached.GetCatalog().GetDBPath()); + } + } + output.SetValue(col++, count, db_path); + // internal, BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(is_internal)); + // type, VARCHAR + output.SetValue(col++, count, Value(attached.GetCatalog().GetCatalogType())); + + count++; + } + output.SetCardinality(count); +} + +void DuckDBDatabasesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("duckdb_databases", {}, DuckDBDatabasesFunction, DuckDBDatabasesBind, DuckDBDatabasesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_dependencies.cpp b/src/duckdb/src/function/table/system/duckdb_dependencies.cpp new file mode 100644 index 00000000..336aa415 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_dependencies.cpp @@ -0,0 +1,121 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +struct DependencyInformation { + DependencyInformation(CatalogEntry &object, CatalogEntry &dependent, DependencyType type) + : object(object), dependent(dependent), type(type) { + } + + CatalogEntry &object; + CatalogEntry &dependent; + DependencyType type; +}; + +struct DuckDBDependenciesData : public GlobalTableFunctionState { + DuckDBDependenciesData() : offset(0) { + } + + vector entries; + idx_t offset; +}; + +static unique_ptr DuckDBDependenciesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("classid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("objid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("objsubid"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("refclassid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("refobjid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("refobjsubid"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("deptype"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBDependenciesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas and collect them + auto &catalog = Catalog::GetCatalog(context, INVALID_CATALOG); + if (catalog.IsDuckCatalog()) { + auto &duck_catalog = catalog.Cast(); + auto &dependency_manager = duck_catalog.GetDependencyManager(); + dependency_manager.Scan([&](CatalogEntry &obj, CatalogEntry &dependent, DependencyType type) { + result->entries.emplace_back(obj, dependent, type); + }); + } + + return std::move(result); +} + +void DuckDBDependenciesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset]; + + // return values: + // classid, LogicalType::BIGINT + output.SetValue(0, count, Value::BIGINT(0)); + // objid, LogicalType::BIGINT + output.SetValue(1, count, Value::BIGINT(entry.object.oid)); + // objsubid, LogicalType::INTEGER + output.SetValue(2, count, Value::INTEGER(0)); + // refclassid, LogicalType::BIGINT + output.SetValue(3, count, Value::BIGINT(0)); + // refobjid, LogicalType::BIGINT + output.SetValue(4, count, Value::BIGINT(entry.dependent.oid)); + // refobjsubid, LogicalType::INTEGER + output.SetValue(5, count, Value::INTEGER(0)); + // deptype, LogicalType::VARCHAR + string dependency_type_str; + switch (entry.type) { + case DependencyType::DEPENDENCY_REGULAR: + dependency_type_str = "n"; + break; + case DependencyType::DEPENDENCY_AUTOMATIC: + dependency_type_str = "a"; + break; + default: + throw NotImplementedException("Unimplemented dependency type"); + } + output.SetValue(6, count, Value(dependency_type_str)); + + data.offset++; + count++; + } + output.SetCardinality(count); +} + +void DuckDBDependenciesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_dependencies", {}, DuckDBDependenciesFunction, DuckDBDependenciesBind, + DuckDBDependenciesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_extensions.cpp b/src/duckdb/src/function/table/system/duckdb_extensions.cpp new file mode 100644 index 00000000..75cfc537 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_extensions.cpp @@ -0,0 +1,159 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/extension_helper.hpp" + +namespace duckdb { + +struct ExtensionInformation { + string name; + bool loaded = false; + bool installed = false; + string file_path; + string description; + vector aliases; +}; + +struct DuckDBExtensionsData : public GlobalTableFunctionState { + DuckDBExtensionsData() : offset(0) { + } + + vector entries; + idx_t offset; +}; + +static unique_ptr DuckDBExtensionsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("extension_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("loaded"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("installed"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("install_path"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("description"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("aliases"); + return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); + + return nullptr; +} + +unique_ptr DuckDBExtensionsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + auto &fs = FileSystem::GetFileSystem(context); + auto &db = DatabaseInstance::GetDatabase(context); + + map installed_extensions; + auto extension_count = ExtensionHelper::DefaultExtensionCount(); + auto alias_count = ExtensionHelper::ExtensionAliasCount(); + for (idx_t i = 0; i < extension_count; i++) { + auto extension = ExtensionHelper::GetDefaultExtension(i); + ExtensionInformation info; + info.name = extension.name; + info.installed = extension.statically_loaded; + info.loaded = false; + info.file_path = extension.statically_loaded ? "(BUILT-IN)" : string(); + info.description = extension.description; + for (idx_t k = 0; k < alias_count; k++) { + auto alias = ExtensionHelper::GetExtensionAlias(k); + if (info.name == alias.extension) { + info.aliases.emplace_back(alias.alias); + } + } + installed_extensions[info.name] = std::move(info); + } +#ifndef WASM_LOADABLE_EXTENSIONS + // scan the install directory for installed extensions + auto ext_directory = ExtensionHelper::ExtensionDirectory(context); + fs.ListFiles(ext_directory, [&](const string &path, bool is_directory) { + if (!StringUtil::EndsWith(path, ".duckdb_extension")) { + return; + } + ExtensionInformation info; + info.name = fs.ExtractBaseName(path); + info.loaded = false; + info.file_path = fs.JoinPath(ext_directory, path); + auto entry = installed_extensions.find(info.name); + if (entry == installed_extensions.end()) { + installed_extensions[info.name] = std::move(info); + } else { + if (!entry->second.loaded) { + entry->second.file_path = info.file_path; + } + entry->second.installed = true; + } + }); +#endif + // now check the list of currently loaded extensions + auto &loaded_extensions = db.LoadedExtensions(); + for (auto &ext_name : loaded_extensions) { + auto entry = installed_extensions.find(ext_name); + if (entry == installed_extensions.end()) { + ExtensionInformation info; + info.name = ext_name; + info.loaded = true; + installed_extensions[ext_name] = std::move(info); + } else { + entry->second.loaded = true; + } + } + + result->entries.reserve(installed_extensions.size()); + for (auto &kv : installed_extensions) { + result->entries.push_back(std::move(kv.second)); + } + return std::move(result); +} + +void DuckDBExtensionsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset]; + + // return values: + // extension_name LogicalType::VARCHAR + output.SetValue(0, count, Value(entry.name)); + // loaded LogicalType::BOOLEAN + output.SetValue(1, count, Value::BOOLEAN(entry.loaded)); + // installed LogicalType::BOOLEAN + output.SetValue(2, count, !entry.installed && entry.loaded ? Value() : Value::BOOLEAN(entry.installed)); + // install_path LogicalType::VARCHAR + output.SetValue(3, count, Value(entry.file_path)); + // description LogicalType::VARCHAR + output.SetValue(4, count, Value(entry.description)); + // aliases LogicalType::LIST(LogicalType::VARCHAR) + output.SetValue(5, count, Value::LIST(LogicalType::VARCHAR, entry.aliases)); + + data.offset++; + count++; + } + output.SetCardinality(count); +} + +void DuckDBExtensionsFun::RegisterFunction(BuiltinFunctions &set) { + TableFunctionSet functions("duckdb_extensions"); + functions.AddFunction(TableFunction({}, DuckDBExtensionsFunction, DuckDBExtensionsBind, DuckDBExtensionsInit)); + set.AddFunction(functions); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_functions.cpp b/src/duckdb/src/function/table/system/duckdb_functions.cpp new file mode 100644 index 00000000..2eb6499e --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_functions.cpp @@ -0,0 +1,518 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" +#include "duckdb/function/table_macro_function.hpp" +#include "duckdb/function/scalar_macro_function.hpp" + +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +struct DuckDBFunctionsData : public GlobalTableFunctionState { + DuckDBFunctionsData() : offset(0), offset_in_entry(0) { + } + + vector> entries; + idx_t offset; + idx_t offset_in_entry; +}; + +static unique_ptr DuckDBFunctionsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("function_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("function_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("description"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("return_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("parameters"); + return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); + + names.emplace_back("parameter_types"); + return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); + + names.emplace_back("varargs"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("macro_definition"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("has_side_effects"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("internal"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("function_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("example"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +static void ExtractFunctionsFromSchema(ClientContext &context, SchemaCatalogEntry &schema, + DuckDBFunctionsData &result) { + schema.Scan(context, CatalogType::SCALAR_FUNCTION_ENTRY, + [&](CatalogEntry &entry) { result.entries.push_back(entry); }); + schema.Scan(context, CatalogType::TABLE_FUNCTION_ENTRY, + [&](CatalogEntry &entry) { result.entries.push_back(entry); }); + schema.Scan(context, CatalogType::PRAGMA_FUNCTION_ENTRY, + [&](CatalogEntry &entry) { result.entries.push_back(entry); }); +} + +unique_ptr DuckDBFunctionsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and collect them and collect them + auto schemas = Catalog::GetAllSchemas(context); + for (auto &schema : schemas) { + ExtractFunctionsFromSchema(context, schema.get(), *result); + }; + + std::sort(result->entries.begin(), result->entries.end(), + [&](reference a, reference b) { + return (int32_t)a.get().type < (int32_t)b.get().type; + }); + return std::move(result); +} + +struct ScalarFunctionExtractor { + static idx_t FunctionCount(ScalarFunctionCatalogEntry &entry) { + return entry.functions.Size(); + } + + static Value GetFunctionType() { + return Value("scalar"); + } + + static Value GetReturnType(ScalarFunctionCatalogEntry &entry, idx_t offset) { + return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); + } + + static vector GetParameters(ScalarFunctionCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.functions.GetFunctionByOffset(offset).arguments.size(); i++) { + results.emplace_back("col" + to_string(i)); + } + return results; + } + + static Value GetParameterTypes(ScalarFunctionCatalogEntry &entry, idx_t offset) { + vector results; + auto fun = entry.functions.GetFunctionByOffset(offset); + for (idx_t i = 0; i < fun.arguments.size(); i++) { + results.emplace_back(fun.arguments[i].ToString()); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static Value GetVarArgs(ScalarFunctionCatalogEntry &entry, idx_t offset) { + auto fun = entry.functions.GetFunctionByOffset(offset); + return !fun.HasVarArgs() ? Value() : Value(fun.varargs.ToString()); + } + + static Value GetMacroDefinition(ScalarFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value HasSideEffects(ScalarFunctionCatalogEntry &entry, idx_t offset) { + return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).side_effects == + FunctionSideEffects::HAS_SIDE_EFFECTS); + } +}; + +struct AggregateFunctionExtractor { + static idx_t FunctionCount(AggregateFunctionCatalogEntry &entry) { + return entry.functions.Size(); + } + + static Value GetFunctionType() { + return Value("aggregate"); + } + + static Value GetReturnType(AggregateFunctionCatalogEntry &entry, idx_t offset) { + return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); + } + + static vector GetParameters(AggregateFunctionCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.functions.GetFunctionByOffset(offset).arguments.size(); i++) { + results.emplace_back("col" + to_string(i)); + } + return results; + } + + static Value GetParameterTypes(AggregateFunctionCatalogEntry &entry, idx_t offset) { + vector results; + auto fun = entry.functions.GetFunctionByOffset(offset); + for (idx_t i = 0; i < fun.arguments.size(); i++) { + results.emplace_back(fun.arguments[i].ToString()); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static Value GetVarArgs(AggregateFunctionCatalogEntry &entry, idx_t offset) { + auto fun = entry.functions.GetFunctionByOffset(offset); + return !fun.HasVarArgs() ? Value() : Value(fun.varargs.ToString()); + } + + static Value GetMacroDefinition(AggregateFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value HasSideEffects(AggregateFunctionCatalogEntry &entry, idx_t offset) { + return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).side_effects == + FunctionSideEffects::HAS_SIDE_EFFECTS); + } +}; + +struct MacroExtractor { + static idx_t FunctionCount(ScalarMacroCatalogEntry &entry) { + return 1; + } + + static Value GetFunctionType() { + return Value("macro"); + } + + static Value GetReturnType(ScalarMacroCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static vector GetParameters(ScalarMacroCatalogEntry &entry, idx_t offset) { + vector results; + for (auto ¶m : entry.function->parameters) { + D_ASSERT(param->type == ExpressionType::COLUMN_REF); + auto &colref = param->Cast(); + results.emplace_back(colref.GetColumnName()); + } + for (auto ¶m_entry : entry.function->default_parameters) { + results.emplace_back(param_entry.first); + } + return results; + } + + static Value GetParameterTypes(ScalarMacroCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.function->parameters.size(); i++) { + results.emplace_back(LogicalType::VARCHAR); + } + for (idx_t i = 0; i < entry.function->default_parameters.size(); i++) { + results.emplace_back(LogicalType::VARCHAR); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static Value GetVarArgs(ScalarMacroCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value GetMacroDefinition(ScalarMacroCatalogEntry &entry, idx_t offset) { + D_ASSERT(entry.function->type == MacroType::SCALAR_MACRO); + auto &func = entry.function->Cast(); + return func.expression->ToString(); + } + + static Value HasSideEffects(ScalarMacroCatalogEntry &entry, idx_t offset) { + return Value(); + } +}; + +struct TableMacroExtractor { + static idx_t FunctionCount(TableMacroCatalogEntry &entry) { + return 1; + } + + static Value GetFunctionType() { + return Value("table_macro"); + } + + static Value GetReturnType(TableMacroCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static vector GetParameters(TableMacroCatalogEntry &entry, idx_t offset) { + vector results; + for (auto ¶m : entry.function->parameters) { + D_ASSERT(param->type == ExpressionType::COLUMN_REF); + auto &colref = param->Cast(); + results.emplace_back(colref.GetColumnName()); + } + for (auto ¶m_entry : entry.function->default_parameters) { + results.emplace_back(param_entry.first); + } + return results; + } + + static Value GetParameterTypes(TableMacroCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.function->parameters.size(); i++) { + results.emplace_back(LogicalType::VARCHAR); + } + for (idx_t i = 0; i < entry.function->default_parameters.size(); i++) { + results.emplace_back(LogicalType::VARCHAR); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static Value GetVarArgs(TableMacroCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value GetMacroDefinition(TableMacroCatalogEntry &entry, idx_t offset) { + if (entry.function->type == MacroType::SCALAR_MACRO) { + auto &func = entry.function->Cast(); + return func.expression->ToString(); + } + return Value(); + } + + static Value HasSideEffects(TableMacroCatalogEntry &entry, idx_t offset) { + return Value(); + } +}; + +struct TableFunctionExtractor { + static idx_t FunctionCount(TableFunctionCatalogEntry &entry) { + return entry.functions.Size(); + } + + static Value GetFunctionType() { + return Value("table"); + } + + static Value GetReturnType(TableFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static vector GetParameters(TableFunctionCatalogEntry &entry, idx_t offset) { + vector results; + auto fun = entry.functions.GetFunctionByOffset(offset); + for (idx_t i = 0; i < fun.arguments.size(); i++) { + results.emplace_back("col" + to_string(i)); + } + for (auto ¶m : fun.named_parameters) { + results.emplace_back(param.first); + } + return results; + } + + static Value GetParameterTypes(TableFunctionCatalogEntry &entry, idx_t offset) { + vector results; + auto fun = entry.functions.GetFunctionByOffset(offset); + + for (idx_t i = 0; i < fun.arguments.size(); i++) { + results.emplace_back(fun.arguments[i].ToString()); + } + for (auto ¶m : fun.named_parameters) { + results.emplace_back(param.second.ToString()); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static Value GetVarArgs(TableFunctionCatalogEntry &entry, idx_t offset) { + auto fun = entry.functions.GetFunctionByOffset(offset); + return !fun.HasVarArgs() ? Value() : Value(fun.varargs.ToString()); + } + + static Value GetMacroDefinition(TableFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value HasSideEffects(TableFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } +}; + +struct PragmaFunctionExtractor { + static idx_t FunctionCount(PragmaFunctionCatalogEntry &entry) { + return entry.functions.Size(); + } + + static Value GetFunctionType() { + return Value("pragma"); + } + + static Value GetReturnType(PragmaFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static vector GetParameters(PragmaFunctionCatalogEntry &entry, idx_t offset) { + vector results; + auto fun = entry.functions.GetFunctionByOffset(offset); + + for (idx_t i = 0; i < fun.arguments.size(); i++) { + results.emplace_back("col" + to_string(i)); + } + for (auto ¶m : fun.named_parameters) { + results.emplace_back(param.first); + } + return results; + } + + static Value GetParameterTypes(PragmaFunctionCatalogEntry &entry, idx_t offset) { + vector results; + auto fun = entry.functions.GetFunctionByOffset(offset); + + for (idx_t i = 0; i < fun.arguments.size(); i++) { + results.emplace_back(fun.arguments[i].ToString()); + } + for (auto ¶m : fun.named_parameters) { + results.emplace_back(param.second.ToString()); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static Value GetVarArgs(PragmaFunctionCatalogEntry &entry, idx_t offset) { + auto fun = entry.functions.GetFunctionByOffset(offset); + return !fun.HasVarArgs() ? Value() : Value(fun.varargs.ToString()); + } + + static Value GetMacroDefinition(PragmaFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value HasSideEffects(PragmaFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } +}; + +template +bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &output, idx_t output_offset) { + auto &function = entry.Cast(); + idx_t col = 0; + + // database_name, LogicalType::VARCHAR + output.SetValue(col++, output_offset, Value(function.schema.catalog.GetName())); + + // schema_name, LogicalType::VARCHAR + output.SetValue(col++, output_offset, Value(function.schema.name)); + + // function_name, LogicalType::VARCHAR + output.SetValue(col++, output_offset, Value(function.name)); + + // function_type, LogicalType::VARCHAR + output.SetValue(col++, output_offset, Value(OP::GetFunctionType())); + + // function_description, LogicalType::VARCHAR + output.SetValue(col++, output_offset, entry.description.empty() ? Value() : entry.description); + + // return_type, LogicalType::VARCHAR + output.SetValue(col++, output_offset, OP::GetReturnType(function, function_idx)); + + // parameters, LogicalType::LIST(LogicalType::VARCHAR) + auto parameters = OP::GetParameters(function, function_idx); + for (idx_t param_idx = 0; param_idx < function.parameter_names.size() && param_idx < parameters.size(); + param_idx++) { + parameters[param_idx] = Value(function.parameter_names[param_idx]); + } + output.SetValue(col++, output_offset, Value::LIST(LogicalType::VARCHAR, std::move(parameters))); + + // parameter_types, LogicalType::LIST(LogicalType::VARCHAR) + output.SetValue(col++, output_offset, OP::GetParameterTypes(function, function_idx)); + + // varargs, LogicalType::VARCHAR + output.SetValue(col++, output_offset, OP::GetVarArgs(function, function_idx)); + + // macro_definition, LogicalType::VARCHAR + output.SetValue(col++, output_offset, OP::GetMacroDefinition(function, function_idx)); + + // has_side_effects, LogicalType::BOOLEAN + output.SetValue(col++, output_offset, OP::HasSideEffects(function, function_idx)); + + // internal, LogicalType::BOOLEAN + output.SetValue(col++, output_offset, Value::BOOLEAN(function.internal)); + + // function_oid, LogicalType::BIGINT + output.SetValue(col++, output_offset, Value::BIGINT(function.oid)); + + // example, LogicalType::VARCHAR + output.SetValue(col++, output_offset, entry.example.empty() ? Value() : entry.example); + + return function_idx + 1 == OP::FunctionCount(function); +} + +void DuckDBFunctionsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset].get().Cast(); + bool finished; + + switch (entry.type) { + case CatalogType::SCALAR_FUNCTION_ENTRY: + finished = ExtractFunctionData( + entry, data.offset_in_entry, output, count); + break; + case CatalogType::AGGREGATE_FUNCTION_ENTRY: + finished = ExtractFunctionData( + entry, data.offset_in_entry, output, count); + break; + case CatalogType::TABLE_MACRO_ENTRY: + finished = ExtractFunctionData(entry, data.offset_in_entry, + output, count); + break; + + case CatalogType::MACRO_ENTRY: + finished = ExtractFunctionData(entry, data.offset_in_entry, output, + count); + break; + case CatalogType::TABLE_FUNCTION_ENTRY: + finished = ExtractFunctionData( + entry, data.offset_in_entry, output, count); + break; + case CatalogType::PRAGMA_FUNCTION_ENTRY: + finished = ExtractFunctionData( + entry, data.offset_in_entry, output, count); + break; + default: + throw InternalException("FIXME: unrecognized function type in duckdb_functions"); + } + if (finished) { + // finished with this function, move to the next function + data.offset++; + data.offset_in_entry = 0; + } else { + // more functions remain + data.offset_in_entry++; + } + count++; + } + output.SetCardinality(count); +} + +void DuckDBFunctionsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("duckdb_functions", {}, DuckDBFunctionsFunction, DuckDBFunctionsBind, DuckDBFunctionsInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_indexes.cpp b/src/duckdb/src/function/table/system/duckdb_indexes.cpp new file mode 100644 index 00000000..92067396 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_indexes.cpp @@ -0,0 +1,134 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/index.hpp" + +namespace duckdb { + +struct DuckDBIndexesData : public GlobalTableFunctionState { + DuckDBIndexesData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +static unique_ptr DuckDBIndexesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("index_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("index_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("table_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("table_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("is_unique"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("is_primary"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("expressions"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("sql"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBIndexesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and collect them and collect them + auto schemas = Catalog::GetAllSchemas(context); + for (auto &schema : schemas) { + schema.get().Scan(context, CatalogType::INDEX_ENTRY, + [&](CatalogEntry &entry) { result->entries.push_back(entry); }); + }; + return std::move(result); +} + +void DuckDBIndexesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset++].get(); + + auto &index = entry.Cast(); + // return values: + + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, count, index.catalog.GetName()); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(index.catalog.GetOid())); + // schema_name, VARCHAR + output.SetValue(col++, count, Value(index.schema.name)); + // schema_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(index.schema.oid)); + // index_name, VARCHAR + output.SetValue(col++, count, Value(index.name)); + // index_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(index.oid)); + // find the table in the catalog + auto &table_entry = + index.schema.catalog.GetEntry(context, index.GetSchemaName(), index.GetTableName()); + // table_name, VARCHAR + output.SetValue(col++, count, Value(table_entry.name)); + // table_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(table_entry.oid)); + if (index.index) { + // is_unique, BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(index.index->IsUnique())); + // is_primary, BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(index.index->IsPrimary())); + } else { + output.SetValue(col++, count, Value()); + output.SetValue(col++, count, Value()); + } + // expressions, VARCHAR + output.SetValue(col++, count, Value()); + // sql, VARCHAR + auto sql = index.ToSQL(); + output.SetValue(col++, count, sql.empty() ? Value() : Value(std::move(sql))); + + count++; + } + output.SetCardinality(count); +} + +void DuckDBIndexesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_indexes", {}, DuckDBIndexesFunction, DuckDBIndexesBind, DuckDBIndexesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_keywords.cpp b/src/duckdb/src/function/table/system/duckdb_keywords.cpp new file mode 100644 index 00000000..35167a8b --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_keywords.cpp @@ -0,0 +1,78 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/parser.hpp" + +namespace duckdb { + +struct DuckDBKeywordsData : public GlobalTableFunctionState { + DuckDBKeywordsData() : offset(0) { + } + + vector entries; + idx_t offset; +}; + +static unique_ptr DuckDBKeywordsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("keyword_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("keyword_category"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBKeywordsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + result->entries = Parser::KeywordList(); + return std::move(result); +} + +void DuckDBKeywordsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset++]; + + // keyword_name, VARCHAR + output.SetValue(0, count, Value(entry.name)); + // keyword_category, VARCHAR + string category_name; + switch (entry.category) { + case KeywordCategory::KEYWORD_RESERVED: + category_name = "reserved"; + break; + case KeywordCategory::KEYWORD_UNRESERVED: + category_name = "unreserved"; + break; + case KeywordCategory::KEYWORD_TYPE_FUNC: + category_name = "type_function"; + break; + case KeywordCategory::KEYWORD_COL_NAME: + category_name = "column_name"; + break; + default: + throw InternalException("Unrecognized keyword category"); + } + output.SetValue(1, count, Value(std::move(category_name))); + + count++; + } + output.SetCardinality(count); +} + +void DuckDBKeywordsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("duckdb_keywords", {}, DuckDBKeywordsFunction, DuckDBKeywordsBind, DuckDBKeywordsInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_schemas.cpp b/src/duckdb/src/function/table/system/duckdb_schemas.cpp new file mode 100644 index 00000000..f0c4e6e4 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_schemas.cpp @@ -0,0 +1,88 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +struct DuckDBSchemasData : public GlobalTableFunctionState { + DuckDBSchemasData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +static unique_ptr DuckDBSchemasBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("internal"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("sql"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBSchemasInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas and collect them + result->entries = Catalog::GetAllSchemas(context); + + return std::move(result); +} + +void DuckDBSchemasFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset].get(); + + // return values: + idx_t col = 0; + // "oid", PhysicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(entry.oid)); + // database_name, VARCHAR + output.SetValue(col++, count, entry.catalog.GetName()); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(entry.catalog.GetOid())); + // "schema_name", PhysicalType::VARCHAR + output.SetValue(col++, count, Value(entry.name)); + // "internal", PhysicalType::BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(entry.internal)); + // "sql", PhysicalType::VARCHAR + output.SetValue(col++, count, Value()); + + data.offset++; + count++; + } + output.SetCardinality(count); +} + +void DuckDBSchemasFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_schemas", {}, DuckDBSchemasFunction, DuckDBSchemasBind, DuckDBSchemasInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_sequences.cpp b/src/duckdb/src/function/table/system/duckdb_sequences.cpp new file mode 100644 index 00000000..985b8094 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_sequences.cpp @@ -0,0 +1,132 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +struct DuckDBSequencesData : public GlobalTableFunctionState { + DuckDBSequencesData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +static unique_ptr DuckDBSequencesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("sequence_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("sequence_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("temporary"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("start_value"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("min_value"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("max_value"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("increment_by"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("cycle"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("last_value"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("sql"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBSequencesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and collect themand collect them + auto schemas = Catalog::GetAllSchemas(context); + for (auto &schema : schemas) { + schema.get().Scan(context, CatalogType::SEQUENCE_ENTRY, + [&](CatalogEntry &entry) { result->entries.push_back(entry.Cast()); }); + }; + return std::move(result); +} + +void DuckDBSequencesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &seq = data.entries[data.offset++].get(); + + // return values: + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, count, seq.catalog.GetName()); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(seq.catalog.GetOid())); + // schema_name, VARCHAR + output.SetValue(col++, count, Value(seq.schema.name)); + // schema_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(seq.schema.oid)); + // sequence_name, VARCHAR + output.SetValue(col++, count, Value(seq.name)); + // sequence_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(seq.oid)); + // temporary, BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(seq.temporary)); + // start_value, BIGINT + output.SetValue(col++, count, Value::BIGINT(seq.start_value)); + // min_value, BIGINT + output.SetValue(col++, count, Value::BIGINT(seq.min_value)); + // max_value, BIGINT + output.SetValue(col++, count, Value::BIGINT(seq.max_value)); + // increment_by, BIGINT + output.SetValue(col++, count, Value::BIGINT(seq.increment)); + // cycle, BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(seq.cycle)); + // last_value, BIGINT + output.SetValue(col++, count, seq.usage_count == 0 ? Value() : Value::BOOLEAN(seq.last_value)); + // sql, LogicalType::VARCHAR + output.SetValue(col++, count, Value(seq.ToSQL())); + + count++; + } + output.SetCardinality(count); +} + +void DuckDBSequencesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("duckdb_sequences", {}, DuckDBSequencesFunction, DuckDBSequencesBind, DuckDBSequencesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_settings.cpp b/src/duckdb/src/function/table/system/duckdb_settings.cpp new file mode 100644 index 00000000..c0fa8bbf --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_settings.cpp @@ -0,0 +1,105 @@ +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +struct DuckDBSettingValue { + string name; + string value; + string description; + string input_type; +}; + +struct DuckDBSettingsData : public GlobalTableFunctionState { + DuckDBSettingsData() : offset(0) { + } + + vector settings; + idx_t offset; +}; + +static unique_ptr DuckDBSettingsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("value"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("description"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("input_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBSettingsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + auto &config = DBConfig::GetConfig(context); + auto options_count = DBConfig::GetOptionCount(); + for (idx_t i = 0; i < options_count; i++) { + auto option = DBConfig::GetOptionByIndex(i); + D_ASSERT(option); + DuckDBSettingValue value; + value.name = option->name; + value.value = option->get_setting(context).ToString(); + value.description = option->description; + value.input_type = EnumUtil::ToString(option->parameter_type); + + result->settings.push_back(std::move(value)); + } + for (auto &ext_param : config.extension_parameters) { + Value setting_val; + string setting_str_val; + if (context.TryGetCurrentSetting(ext_param.first, setting_val)) { + setting_str_val = setting_val.ToString(); + } + DuckDBSettingValue value; + value.name = ext_param.first; + value.value = std::move(setting_str_val); + value.description = ext_param.second.description; + value.input_type = ext_param.second.type.ToString(); + + result->settings.push_back(std::move(value)); + } + return std::move(result); +} + +void DuckDBSettingsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.settings.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.settings.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.settings[data.offset++]; + + // return values: + // name, LogicalType::VARCHAR + output.SetValue(0, count, Value(entry.name)); + // value, LogicalType::VARCHAR + output.SetValue(1, count, Value(entry.value)); + // description, LogicalType::VARCHAR + output.SetValue(2, count, Value(entry.description)); + // input_type, LogicalType::VARCHAR + output.SetValue(3, count, Value(entry.input_type)); + count++; + } + output.SetCardinality(count); +} + +void DuckDBSettingsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("duckdb_settings", {}, DuckDBSettingsFunction, DuckDBSettingsBind, DuckDBSettingsInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_tables.cpp b/src/duckdb/src/function/table/system/duckdb_tables.cpp new file mode 100644 index 00000000..f73ade55 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_tables.cpp @@ -0,0 +1,164 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/constraints/unique_constraint.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table_storage_info.hpp" + +namespace duckdb { + +struct DuckDBTablesData : public GlobalTableFunctionState { + DuckDBTablesData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +static unique_ptr DuckDBTablesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("table_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("table_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("internal"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("temporary"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("has_primary_key"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("estimated_size"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("column_count"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("index_count"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("check_constraint_count"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("sql"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBTablesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and collect themand collect them + auto schemas = Catalog::GetAllSchemas(context); + for (auto &schema : schemas) { + schema.get().Scan(context, CatalogType::TABLE_ENTRY, + [&](CatalogEntry &entry) { result->entries.push_back(entry); }); + }; + return std::move(result); +} + +static bool TableHasPrimaryKey(TableCatalogEntry &table) { + for (auto &constraint : table.GetConstraints()) { + if (constraint->type == ConstraintType::UNIQUE) { + auto &unique = constraint->Cast(); + if (unique.is_primary_key) { + return true; + } + } + } + return false; +} + +static idx_t CheckConstraintCount(TableCatalogEntry &table) { + idx_t check_count = 0; + for (auto &constraint : table.GetConstraints()) { + if (constraint->type == ConstraintType::CHECK) { + check_count++; + } + } + return check_count; +} + +void DuckDBTablesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset++].get(); + + if (entry.type != CatalogType::TABLE_ENTRY) { + continue; + } + auto &table = entry.Cast(); + auto storage_info = table.GetStorageInfo(context); + // return values: + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, count, table.catalog.GetName()); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(table.catalog.GetOid())); + // schema_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(table.schema.name)); + // schema_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(table.schema.oid)); + // table_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(table.name)); + // table_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(table.oid)); + // internal, LogicalType::BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(table.internal)); + // temporary, LogicalType::BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(table.temporary)); + // has_primary_key, LogicalType::BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(TableHasPrimaryKey(table))); + // estimated_size, LogicalType::BIGINT + Value card_val = + storage_info.cardinality == DConstants::INVALID_INDEX ? Value() : Value::BIGINT(storage_info.cardinality); + output.SetValue(col++, count, card_val); + // column_count, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(table.GetColumns().LogicalColumnCount())); + // index_count, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(storage_info.index_info.size())); + // check_constraint_count, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(CheckConstraintCount(table))); + // sql, LogicalType::VARCHAR + output.SetValue(col++, count, Value(table.ToSQL())); + + count++; + } + output.SetCardinality(count); +} + +void DuckDBTablesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_tables", {}, DuckDBTablesFunction, DuckDBTablesBind, DuckDBTablesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp b/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp new file mode 100644 index 00000000..d4f043af --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp @@ -0,0 +1,59 @@ +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +struct DuckDBTemporaryFilesData : public GlobalTableFunctionState { + DuckDBTemporaryFilesData() : offset(0) { + } + + vector entries; + idx_t offset; +}; + +static unique_ptr DuckDBTemporaryFilesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("path"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("size"); + return_types.emplace_back(LogicalType::BIGINT); + + return nullptr; +} + +unique_ptr DuckDBTemporaryFilesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + result->entries = BufferManager::GetBufferManager(context).GetTemporaryFiles(); + return std::move(result); +} + +void DuckDBTemporaryFilesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset++]; + // return values: + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, count, entry.path); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(entry.size)); + count++; + } + output.SetCardinality(count); +} + +void DuckDBTemporaryFilesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_temporary_files", {}, DuckDBTemporaryFilesFunction, DuckDBTemporaryFilesBind, + DuckDBTemporaryFilesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_types.cpp b/src/duckdb/src/function/table/system/duckdb_types.cpp new file mode 100644 index 00000000..7647bcbd --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_types.cpp @@ -0,0 +1,188 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +struct DuckDBTypesData : public GlobalTableFunctionState { + DuckDBTypesData() : offset(0) { + } + + vector> entries; + idx_t offset; + unordered_set oids; +}; + +static unique_ptr DuckDBTypesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("type_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("type_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("type_size"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("logical_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + // NUMERIC, STRING, DATETIME, BOOLEAN, COMPOSITE, USER + names.emplace_back("type_category"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("internal"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("labels"); + return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); + + return nullptr; +} + +unique_ptr DuckDBTypesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + auto schemas = Catalog::GetAllSchemas(context); + for (auto &schema : schemas) { + schema.get().Scan(context, CatalogType::TYPE_ENTRY, + [&](CatalogEntry &entry) { result->entries.push_back(entry.Cast()); }); + }; + return std::move(result); +} + +void DuckDBTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &type_entry = data.entries[data.offset++].get(); + auto &type = type_entry.user_type; + + // return values: + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, count, type_entry.catalog.GetName()); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(type_entry.catalog.GetOid())); + // schema_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(type_entry.schema.name)); + // schema_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(type_entry.schema.oid)); + // type_oid, BIGINT + int64_t oid; + if (type_entry.internal) { + oid = int64_t(type.id()); + } else { + oid = type_entry.oid; + } + Value oid_val; + if (data.oids.find(oid) == data.oids.end()) { + data.oids.insert(oid); + oid_val = Value::BIGINT(oid); + } else { + oid_val = Value(); + } + output.SetValue(col++, count, oid_val); + // type_name, VARCHAR + output.SetValue(col++, count, Value(type_entry.name)); + // type_size, BIGINT + auto internal_type = type.InternalType(); + output.SetValue(col++, count, + internal_type == PhysicalType::INVALID ? Value() : Value::BIGINT(GetTypeIdSize(internal_type))); + // logical_type, VARCHAR + output.SetValue(col++, count, Value(EnumUtil::ToString(type.id()))); + // type_category, VARCHAR + string category; + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::HUGEINT: + category = "NUMERIC"; + break; + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::INTERVAL: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + category = "DATETIME"; + break; + case LogicalTypeId::CHAR: + case LogicalTypeId::VARCHAR: + category = "STRING"; + break; + case LogicalTypeId::BOOLEAN: + category = "BOOLEAN"; + break; + case LogicalTypeId::STRUCT: + case LogicalTypeId::LIST: + case LogicalTypeId::MAP: + case LogicalTypeId::UNION: + category = "COMPOSITE"; + break; + default: + break; + } + output.SetValue(col++, count, category.empty() ? Value() : Value(category)); + // internal, BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(type_entry.internal)); + // labels, VARCHAR[] + if (type.id() == LogicalTypeId::ENUM && type.AuxInfo()) { + auto data = FlatVector::GetData(EnumType::GetValuesInsertOrder(type)); + idx_t size = EnumType::GetSize(type); + + vector labels; + for (idx_t i = 0; i < size; i++) { + labels.emplace_back(data[i]); + } + + output.SetValue(col++, count, Value::LIST(labels)); + } else { + output.SetValue(col++, count, Value()); + } + + count++; + } + output.SetCardinality(count); +} + +void DuckDBTypesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_types", {}, DuckDBTypesFunction, DuckDBTypesBind, DuckDBTypesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_views.cpp b/src/duckdb/src/function/table/system/duckdb_views.cpp new file mode 100644 index 00000000..6375db25 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_views.cpp @@ -0,0 +1,116 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +struct DuckDBViewsData : public GlobalTableFunctionState { + DuckDBViewsData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +static unique_ptr DuckDBViewsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("schema_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("schema_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("view_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("view_oid"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("internal"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("temporary"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("column_count"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("sql"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBViewsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + // scan all the schemas for tables and collect them and collect them + auto schemas = Catalog::GetAllSchemas(context); + for (auto &schema : schemas) { + schema.get().Scan(context, CatalogType::VIEW_ENTRY, + [&](CatalogEntry &entry) { result->entries.push_back(entry); }); + }; + return std::move(result); +} + +void DuckDBViewsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.entries[data.offset++].get(); + + if (entry.type != CatalogType::VIEW_ENTRY) { + continue; + } + auto &view = entry.Cast(); + + // return values: + idx_t col = 0; + // database_name, VARCHAR + output.SetValue(col++, count, view.catalog.GetName()); + // database_oid, BIGINT + output.SetValue(col++, count, Value::BIGINT(view.catalog.GetOid())); + // schema_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(view.schema.name)); + // schema_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(view.schema.oid)); + // view_name, LogicalType::VARCHAR + output.SetValue(col++, count, Value(view.name)); + // view_oid, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(view.oid)); + // internal, LogicalType::BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(view.internal)); + // temporary, LogicalType::BOOLEAN + output.SetValue(col++, count, Value::BOOLEAN(view.temporary)); + // column_count, LogicalType::BIGINT + output.SetValue(col++, count, Value::BIGINT(view.types.size())); + // sql, LogicalType::VARCHAR + output.SetValue(col++, count, Value(view.ToSQL())); + + count++; + } + output.SetCardinality(count); +} + +void DuckDBViewsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_views", {}, DuckDBViewsFunction, DuckDBViewsBind, DuckDBViewsInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_collations.cpp b/src/duckdb/src/function/table/system/pragma_collations.cpp new file mode 100644 index 00000000..187ae174 --- /dev/null +++ b/src/duckdb/src/function/table/system/pragma_collations.cpp @@ -0,0 +1,58 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +struct PragmaCollateData : public GlobalTableFunctionState { + PragmaCollateData() : offset(0) { + } + + vector entries; + idx_t offset; +}; + +static unique_ptr PragmaCollateBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("collname"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr PragmaCollateInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + + auto schemas = Catalog::GetAllSchemas(context); + for (auto schema : schemas) { + schema.get().Scan(context, CatalogType::COLLATION_ENTRY, + [&](CatalogEntry &entry) { result->entries.push_back(entry.name); }); + } + return std::move(result); +} + +static void PragmaCollateFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, data.entries.size()); + output.SetCardinality(next - data.offset); + for (idx_t i = data.offset; i < next; i++) { + auto index = i - data.offset; + output.SetValue(0, index, Value(data.entries[i])); + } + + data.offset = next; +} + +void PragmaCollations::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("pragma_collations", {}, PragmaCollateFunction, PragmaCollateBind, PragmaCollateInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_database_size.cpp b/src/duckdb/src/function/table/system/pragma_database_size.cpp new file mode 100644 index 00000000..c10eae0c --- /dev/null +++ b/src/duckdb/src/function/table/system/pragma_database_size.cpp @@ -0,0 +1,97 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database_manager.hpp" + +namespace duckdb { + +struct PragmaDatabaseSizeData : public GlobalTableFunctionState { + PragmaDatabaseSizeData() : index(0) { + } + + idx_t index; + vector> databases; + Value memory_usage; + Value memory_limit; +}; + +static unique_ptr PragmaDatabaseSizeBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("database_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("database_size"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("block_size"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("total_blocks"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("used_blocks"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("free_blocks"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("wal_size"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("memory_usage"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("memory_limit"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr PragmaDatabaseSizeInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + result->databases = DatabaseManager::Get(context).GetDatabases(context); + auto &buffer_manager = BufferManager::GetBufferManager(context); + result->memory_usage = Value(StringUtil::BytesToHumanReadableString(buffer_manager.GetUsedMemory())); + auto max_memory = buffer_manager.GetMaxMemory(); + result->memory_limit = + max_memory == (idx_t)-1 ? Value("Unlimited") : Value(StringUtil::BytesToHumanReadableString(max_memory)); + + return std::move(result); +} + +void PragmaDatabaseSizeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + idx_t row = 0; + for (; data.index < data.databases.size() && row < STANDARD_VECTOR_SIZE; data.index++) { + auto &db = data.databases[data.index].get(); + if (db.IsSystem() || db.IsTemporary()) { + continue; + } + auto ds = db.GetCatalog().GetDatabaseSize(context); + idx_t col = 0; + output.data[col++].SetValue(row, Value(db.GetName())); + output.data[col++].SetValue(row, Value(StringUtil::BytesToHumanReadableString(ds.bytes))); + output.data[col++].SetValue(row, Value::BIGINT(ds.block_size)); + output.data[col++].SetValue(row, Value::BIGINT(ds.total_blocks)); + output.data[col++].SetValue(row, Value::BIGINT(ds.used_blocks)); + output.data[col++].SetValue(row, Value::BIGINT(ds.free_blocks)); + output.data[col++].SetValue( + row, ds.wal_size == idx_t(-1) ? Value() : Value(StringUtil::BytesToHumanReadableString(ds.wal_size))); + output.data[col++].SetValue(row, data.memory_usage); + output.data[col++].SetValue(row, data.memory_limit); + row++; + } + output.SetCardinality(row); +} + +void PragmaDatabaseSize::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("pragma_database_size", {}, PragmaDatabaseSizeFunction, PragmaDatabaseSizeBind, + PragmaDatabaseSizeInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_metadata_info.cpp b/src/duckdb/src/function/table/system/pragma_metadata_info.cpp new file mode 100644 index 00000000..92c18030 --- /dev/null +++ b/src/duckdb/src/function/table/system/pragma_metadata_info.cpp @@ -0,0 +1,83 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/storage/database_size.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/function/function_set.hpp" +namespace duckdb { + +struct PragmaMetadataFunctionData : public TableFunctionData { + explicit PragmaMetadataFunctionData() { + } + + vector metadata_info; +}; + +struct PragmaMetadataOperatorData : public GlobalTableFunctionState { + PragmaMetadataOperatorData() : offset(0) { + } + + idx_t offset; +}; + +static unique_ptr PragmaMetadataInfoBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("block_id"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("total_blocks"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("free_blocks"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("free_list"); + return_types.emplace_back(LogicalType::LIST(LogicalType::BIGINT)); + + string db_name = + input.inputs.empty() ? DatabaseManager::GetDefaultDatabase(context) : StringValue::Get(input.inputs[0]); + auto &catalog = Catalog::GetCatalog(context, db_name); + auto result = make_uniq(); + result->metadata_info = catalog.GetMetadataInfo(context); + return std::move(result); +} + +unique_ptr PragmaMetadataInfoInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void PragmaMetadataInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &data = data_p.global_state->Cast(); + idx_t count = 0; + while (data.offset < bind_data.metadata_info.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = bind_data.metadata_info[data.offset++]; + + idx_t col_idx = 0; + // block_id + output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); + // total_blocks + output.SetValue(col_idx++, count, Value::BIGINT(entry.total_blocks)); + // free_blocks + output.SetValue(col_idx++, count, Value::BIGINT(entry.free_list.size())); + // free_list + vector list_values; + for (auto &free_id : entry.free_list) { + list_values.push_back(Value::BIGINT(free_id)); + } + output.SetValue(col_idx++, count, Value::LIST(LogicalType::BIGINT, std::move(list_values))); + count++; + } + output.SetCardinality(count); +} + +void PragmaMetadataInfo::RegisterFunction(BuiltinFunctions &set) { + TableFunctionSet metadata_info("pragma_metadata_info"); + metadata_info.AddFunction( + TableFunction({}, PragmaMetadataInfoFunction, PragmaMetadataInfoBind, PragmaMetadataInfoInit)); + metadata_info.AddFunction(TableFunction({LogicalType::VARCHAR}, PragmaMetadataInfoFunction, PragmaMetadataInfoBind, + PragmaMetadataInfoInit)); + set.AddFunction(metadata_info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_storage_info.cpp b/src/duckdb/src/function/table/system/pragma_storage_info.cpp new file mode 100644 index 00000000..90c60d15 --- /dev/null +++ b/src/duckdb/src/function/table/system/pragma_storage_info.cpp @@ -0,0 +1,151 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/parser/qualified_name.hpp" +#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" +#include "duckdb/planner/constraints/bound_unique_constraint.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/planner/binder.hpp" + +#include + +namespace duckdb { + +struct PragmaStorageFunctionData : public TableFunctionData { + explicit PragmaStorageFunctionData(TableCatalogEntry &table_entry) : table_entry(table_entry) { + } + + TableCatalogEntry &table_entry; + vector column_segments_info; +}; + +struct PragmaStorageOperatorData : public GlobalTableFunctionState { + PragmaStorageOperatorData() : offset(0) { + } + + idx_t offset; +}; + +static unique_ptr PragmaStorageInfoBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("row_group_id"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("column_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("column_id"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("column_path"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("segment_id"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("segment_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("start"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("count"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("compression"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("stats"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("has_updates"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("persistent"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("block_id"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("block_offset"); + return_types.emplace_back(LogicalType::BIGINT); + + names.emplace_back("segment_info"); + return_types.emplace_back(LogicalType::VARCHAR); + + auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); + + // look up the table name in the catalog + Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); + auto &table_entry = Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); + auto result = make_uniq(table_entry); + result->column_segments_info = table_entry.GetColumnSegmentInfo(); + return std::move(result); +} + +unique_ptr PragmaStorageInfoInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &data = data_p.global_state->Cast(); + idx_t count = 0; + auto &columns = bind_data.table_entry.GetColumns(); + while (data.offset < bind_data.column_segments_info.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = bind_data.column_segments_info[data.offset++]; + + idx_t col_idx = 0; + // row_group_id + output.SetValue(col_idx++, count, Value::BIGINT(entry.row_group_index)); + // column_name + auto &col = columns.GetColumn(PhysicalIndex(entry.column_id)); + output.SetValue(col_idx++, count, Value(col.Name())); + // column_id + output.SetValue(col_idx++, count, Value::BIGINT(entry.column_id)); + // column_path + output.SetValue(col_idx++, count, Value(entry.column_path)); + // segment_id + output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_idx)); + // segment_type + output.SetValue(col_idx++, count, Value(entry.segment_type)); + // start + output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_start)); + // count + output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_count)); + // compression + output.SetValue(col_idx++, count, Value(entry.compression_type)); + // stats + output.SetValue(col_idx++, count, Value(entry.segment_stats)); + // has_updates + output.SetValue(col_idx++, count, Value::BOOLEAN(entry.has_updates)); + // persistent + output.SetValue(col_idx++, count, Value::BOOLEAN(entry.persistent)); + // block_id + // block_offset + if (entry.persistent) { + output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); + output.SetValue(col_idx++, count, Value::BIGINT(entry.block_offset)); + } else { + output.SetValue(col_idx++, count, Value()); + output.SetValue(col_idx++, count, Value()); + } + // segment_info + output.SetValue(col_idx++, count, Value(entry.segment_info)); + count++; + } + output.SetCardinality(count); +} + +void PragmaStorageInfo::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("pragma_storage_info", {LogicalType::VARCHAR}, PragmaStorageInfoFunction, + PragmaStorageInfoBind, PragmaStorageInfoInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_table_info.cpp b/src/duckdb/src/function/table/system/pragma_table_info.cpp new file mode 100644 index 00000000..38c36098 --- /dev/null +++ b/src/duckdb/src/function/table/system/pragma_table_info.cpp @@ -0,0 +1,185 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/parser/qualified_name.hpp" +#include "duckdb/parser/constraints/not_null_constraint.hpp" +#include "duckdb/parser/constraints/unique_constraint.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/binder.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" + +#include + +namespace duckdb { + +struct PragmaTableFunctionData : public TableFunctionData { + explicit PragmaTableFunctionData(CatalogEntry &entry_p) : entry(entry_p) { + } + + CatalogEntry &entry; +}; + +struct PragmaTableOperatorData : public GlobalTableFunctionState { + PragmaTableOperatorData() : offset(0) { + } + idx_t offset; +}; + +static unique_ptr PragmaTableInfoBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + + names.emplace_back("cid"); + return_types.emplace_back(LogicalType::INTEGER); + + names.emplace_back("name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("type"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("notnull"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("dflt_value"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("pk"); + return_types.emplace_back(LogicalType::BOOLEAN); + + auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); + + // look up the table name in the catalog + Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); + auto &entry = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, qname.catalog, qname.schema, qname.name); + return make_uniq(entry); +} + +unique_ptr PragmaTableInfoInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void CheckConstraints(TableCatalogEntry &table, const ColumnDefinition &column, bool &out_not_null, + bool &out_pk) { + out_not_null = false; + out_pk = false; + // check all constraints + // FIXME: this is pretty inefficient, it probably doesn't matter + for (auto &constraint : table.GetConstraints()) { + switch (constraint->type) { + case ConstraintType::NOT_NULL: { + auto ¬_null = constraint->Cast(); + if (not_null.index == column.Logical()) { + out_not_null = true; + } + break; + } + case ConstraintType::UNIQUE: { + auto &unique = constraint->Cast(); + if (unique.is_primary_key) { + if (unique.index == column.Logical()) { + out_pk = true; + } + if (std::find(unique.columns.begin(), unique.columns.end(), column.GetName()) != unique.columns.end()) { + out_pk = true; + } + } + break; + } + default: + break; + } + } +} + +static void PragmaTableInfoTable(PragmaTableOperatorData &data, TableCatalogEntry &table, DataChunk &output) { + if (data.offset >= table.GetColumns().LogicalColumnCount()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, table.GetColumns().LogicalColumnCount()); + output.SetCardinality(next - data.offset); + + for (idx_t i = data.offset; i < next; i++) { + bool not_null, pk; + auto index = i - data.offset; + auto &column = table.GetColumn(LogicalIndex(i)); + D_ASSERT(column.Oid() < (idx_t)NumericLimits::Maximum()); + CheckConstraints(table, column, not_null, pk); + + // return values: + // "cid", PhysicalType::INT32 + output.SetValue(0, index, Value::INTEGER((int32_t)column.Oid())); + // "name", PhysicalType::VARCHAR + output.SetValue(1, index, Value(column.Name())); + // "type", PhysicalType::VARCHAR + output.SetValue(2, index, Value(column.Type().ToString())); + // "notnull", PhysicalType::BOOL + output.SetValue(3, index, Value::BOOLEAN(not_null)); + // "dflt_value", PhysicalType::VARCHAR + Value def_value = column.DefaultValue() ? Value(column.DefaultValue()->ToString()) : Value(); + output.SetValue(4, index, def_value); + // "pk", PhysicalType::BOOL + output.SetValue(5, index, Value::BOOLEAN(pk)); + } + data.offset = next; +} + +static void PragmaTableInfoView(PragmaTableOperatorData &data, ViewCatalogEntry &view, DataChunk &output) { + if (data.offset >= view.types.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, view.types.size()); + output.SetCardinality(next - data.offset); + + for (idx_t i = data.offset; i < next; i++) { + auto index = i - data.offset; + auto type = view.types[i]; + auto &name = view.aliases[i]; + // return values: + // "cid", PhysicalType::INT32 + + output.SetValue(0, index, Value::INTEGER((int32_t)i)); + // "name", PhysicalType::VARCHAR + output.SetValue(1, index, Value(name)); + // "type", PhysicalType::VARCHAR + output.SetValue(2, index, Value(type.ToString())); + // "notnull", PhysicalType::BOOL + output.SetValue(3, index, Value::BOOLEAN(false)); + // "dflt_value", PhysicalType::VARCHAR + output.SetValue(4, index, Value()); + // "pk", PhysicalType::BOOL + output.SetValue(5, index, Value::BOOLEAN(false)); + } + data.offset = next; +} + +static void PragmaTableInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.global_state->Cast(); + switch (bind_data.entry.type) { + case CatalogType::TABLE_ENTRY: + PragmaTableInfoTable(state, bind_data.entry.Cast(), output); + break; + case CatalogType::VIEW_ENTRY: + PragmaTableInfoView(state, bind_data.entry.Cast(), output); + break; + default: + throw NotImplementedException("Unimplemented catalog type for pragma_table_info"); + } +} + +void PragmaTableInfo::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("pragma_table_info", {LogicalType::VARCHAR}, PragmaTableInfoFunction, + PragmaTableInfoBind, PragmaTableInfoInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp new file mode 100644 index 00000000..a0b85626 --- /dev/null +++ b/src/duckdb/src/function/table/system/test_all_types.cpp @@ -0,0 +1,276 @@ +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/function/table/system_functions.hpp" + +#include +#include + +namespace duckdb { + +struct TestAllTypesData : public GlobalTableFunctionState { + TestAllTypesData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { + vector result; + // scalar types/numerics + result.emplace_back(LogicalType::BOOLEAN, "bool"); + result.emplace_back(LogicalType::TINYINT, "tinyint"); + result.emplace_back(LogicalType::SMALLINT, "smallint"); + result.emplace_back(LogicalType::INTEGER, "int"); + result.emplace_back(LogicalType::BIGINT, "bigint"); + result.emplace_back(LogicalType::HUGEINT, "hugeint"); + result.emplace_back(LogicalType::UTINYINT, "utinyint"); + result.emplace_back(LogicalType::USMALLINT, "usmallint"); + result.emplace_back(LogicalType::UINTEGER, "uint"); + result.emplace_back(LogicalType::UBIGINT, "ubigint"); + result.emplace_back(LogicalType::DATE, "date"); + result.emplace_back(LogicalType::TIME, "time"); + result.emplace_back(LogicalType::TIMESTAMP, "timestamp"); + result.emplace_back(LogicalType::TIMESTAMP_S, "timestamp_s"); + result.emplace_back(LogicalType::TIMESTAMP_MS, "timestamp_ms"); + result.emplace_back(LogicalType::TIMESTAMP_NS, "timestamp_ns"); + result.emplace_back(LogicalType::TIME_TZ, "time_tz"); + result.emplace_back(LogicalType::TIMESTAMP_TZ, "timestamp_tz"); + result.emplace_back(LogicalType::FLOAT, "float"); + result.emplace_back(LogicalType::DOUBLE, "double"); + result.emplace_back(LogicalType::DECIMAL(4, 1), "dec_4_1"); + result.emplace_back(LogicalType::DECIMAL(9, 4), "dec_9_4"); + result.emplace_back(LogicalType::DECIMAL(18, 6), "dec_18_6"); + result.emplace_back(LogicalType::DECIMAL(38, 10), "dec38_10"); + result.emplace_back(LogicalType::UUID, "uuid"); + + // interval + interval_t min_interval; + min_interval.months = 0; + min_interval.days = 0; + min_interval.micros = 0; + + interval_t max_interval; + max_interval.months = 999; + max_interval.days = 999; + max_interval.micros = 999999999; + result.emplace_back(LogicalType::INTERVAL, "interval", Value::INTERVAL(min_interval), + Value::INTERVAL(max_interval)); + // strings/blobs/bitstrings + result.emplace_back(LogicalType::VARCHAR, "varchar", Value("🦆🦆🦆🦆🦆🦆"), + Value(string("goo\x00se", 6))); + result.emplace_back(LogicalType::BLOB, "blob", Value::BLOB("thisisalongblob\\x00withnullbytes"), + Value::BLOB("\\x00\\x00\\x00a")); + result.emplace_back(LogicalType::BIT, "bit", Value::BIT("0010001001011100010101011010111"), Value::BIT("10101")); + + // enums + Vector small_enum(LogicalType::VARCHAR, 2); + auto small_enum_ptr = FlatVector::GetData(small_enum); + small_enum_ptr[0] = StringVector::AddStringOrBlob(small_enum, "DUCK_DUCK_ENUM"); + small_enum_ptr[1] = StringVector::AddStringOrBlob(small_enum, "GOOSE"); + result.emplace_back(LogicalType::ENUM(small_enum, 2), "small_enum"); + + Vector medium_enum(LogicalType::VARCHAR, 300); + auto medium_enum_ptr = FlatVector::GetData(medium_enum); + for (idx_t i = 0; i < 300; i++) { + medium_enum_ptr[i] = StringVector::AddStringOrBlob(medium_enum, string("enum_") + to_string(i)); + } + result.emplace_back(LogicalType::ENUM(medium_enum, 300), "medium_enum"); + + if (use_large_enum) { + // this is a big one... not sure if we should push this one here, but it's required for completeness + Vector large_enum(LogicalType::VARCHAR, 70000); + auto large_enum_ptr = FlatVector::GetData(large_enum); + for (idx_t i = 0; i < 70000; i++) { + large_enum_ptr[i] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(i)); + } + result.emplace_back(LogicalType::ENUM(large_enum, 70000), "large_enum"); + } else { + Vector large_enum(LogicalType::VARCHAR, 2); + auto large_enum_ptr = FlatVector::GetData(large_enum); + large_enum_ptr[0] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(0)); + large_enum_ptr[1] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(69999)); + result.emplace_back(LogicalType::ENUM(large_enum, 2), "large_enum"); + } + + // arrays + auto int_list_type = LogicalType::LIST(LogicalType::INTEGER); + auto empty_int_list = Value::EMPTYLIST(LogicalType::INTEGER); + auto int_list = Value::LIST({Value::INTEGER(42), Value::INTEGER(999), Value(LogicalType::INTEGER), + Value(LogicalType::INTEGER), Value::INTEGER(-42)}); + result.emplace_back(int_list_type, "int_array", empty_int_list, int_list); + + auto double_list_type = LogicalType::LIST(LogicalType::DOUBLE); + auto empty_double_list = Value::EMPTYLIST(LogicalType::DOUBLE); + auto double_list = Value::LIST( + {Value::DOUBLE(42), Value::DOUBLE(NAN), Value::DOUBLE(std::numeric_limits::infinity()), + Value::DOUBLE(-std::numeric_limits::infinity()), Value(LogicalType::DOUBLE), Value::DOUBLE(-42)}); + result.emplace_back(double_list_type, "double_array", empty_double_list, double_list); + + auto date_list_type = LogicalType::LIST(LogicalType::DATE); + auto empty_date_list = Value::EMPTYLIST(LogicalType::DATE); + auto date_list = + Value::LIST({Value::DATE(date_t()), Value::DATE(date_t::infinity()), Value::DATE(date_t::ninfinity()), + Value(LogicalType::DATE), Value::DATE(Date::FromString("2022-05-12"))}); + result.emplace_back(date_list_type, "date_array", empty_date_list, date_list); + + auto timestamp_list_type = LogicalType::LIST(LogicalType::TIMESTAMP); + auto empty_timestamp_list = Value::EMPTYLIST(LogicalType::TIMESTAMP); + auto timestamp_list = Value::LIST({Value::TIMESTAMP(timestamp_t()), Value::TIMESTAMP(timestamp_t::infinity()), + Value::TIMESTAMP(timestamp_t::ninfinity()), Value(LogicalType::TIMESTAMP), + Value::TIMESTAMP(Timestamp::FromString("2022-05-12 16:23:45"))}); + result.emplace_back(timestamp_list_type, "timestamp_array", empty_timestamp_list, timestamp_list); + + auto timestamptz_list_type = LogicalType::LIST(LogicalType::TIMESTAMP_TZ); + auto empty_timestamptz_list = Value::EMPTYLIST(LogicalType::TIMESTAMP_TZ); + auto timestamptz_list = Value::LIST({Value::TIMESTAMPTZ(timestamp_t()), Value::TIMESTAMPTZ(timestamp_t::infinity()), + Value::TIMESTAMPTZ(timestamp_t::ninfinity()), Value(LogicalType::TIMESTAMP_TZ), + Value::TIMESTAMPTZ(Timestamp::FromString("2022-05-12 16:23:45-07"))}); + result.emplace_back(timestamptz_list_type, "timestamptz_array", empty_timestamptz_list, timestamptz_list); + + auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); + auto empty_varchar_list = Value::EMPTYLIST(LogicalType::VARCHAR); + auto varchar_list = + Value::LIST({Value("🦆🦆🦆🦆🦆🦆"), Value("goose"), Value(LogicalType::VARCHAR), Value("")}); + result.emplace_back(varchar_list_type, "varchar_array", empty_varchar_list, varchar_list); + + // nested arrays + auto nested_list_type = LogicalType::LIST(int_list_type); + auto empty_nested_list = Value::EMPTYLIST(int_list_type); + auto nested_int_list = Value::LIST({empty_int_list, int_list, Value(int_list_type), empty_int_list, int_list}); + result.emplace_back(nested_list_type, "nested_int_array", empty_nested_list, nested_int_list); + + // structs + child_list_t struct_type_list; + struct_type_list.push_back(make_pair("a", LogicalType::INTEGER)); + struct_type_list.push_back(make_pair("b", LogicalType::VARCHAR)); + auto struct_type = LogicalType::STRUCT(struct_type_list); + + child_list_t min_struct_list; + min_struct_list.push_back(make_pair("a", Value(LogicalType::INTEGER))); + min_struct_list.push_back(make_pair("b", Value(LogicalType::VARCHAR))); + auto min_struct_val = Value::STRUCT(std::move(min_struct_list)); + + child_list_t max_struct_list; + max_struct_list.push_back(make_pair("a", Value::INTEGER(42))); + max_struct_list.push_back(make_pair("b", Value("🦆🦆🦆🦆🦆🦆"))); + auto max_struct_val = Value::STRUCT(std::move(max_struct_list)); + + result.emplace_back(struct_type, "struct", min_struct_val, max_struct_val); + + // structs with lists + child_list_t struct_list_type_list; + struct_list_type_list.push_back(make_pair("a", int_list_type)); + struct_list_type_list.push_back(make_pair("b", varchar_list_type)); + auto struct_list_type = LogicalType::STRUCT(struct_list_type_list); + + child_list_t min_struct_vl_list; + min_struct_vl_list.push_back(make_pair("a", Value(int_list_type))); + min_struct_vl_list.push_back(make_pair("b", Value(varchar_list_type))); + auto min_struct_val_list = Value::STRUCT(std::move(min_struct_vl_list)); + + child_list_t max_struct_vl_list; + max_struct_vl_list.push_back(make_pair("a", int_list)); + max_struct_vl_list.push_back(make_pair("b", varchar_list)); + auto max_struct_val_list = Value::STRUCT(std::move(max_struct_vl_list)); + + result.emplace_back(struct_list_type, "struct_of_arrays", std::move(min_struct_val_list), + std::move(max_struct_val_list)); + + // array of structs + auto array_of_structs_type = LogicalType::LIST(struct_type); + auto min_array_of_struct_val = Value::EMPTYLIST(struct_type); + auto max_array_of_struct_val = Value::LIST({min_struct_val, max_struct_val, Value(struct_type)}); + result.emplace_back(array_of_structs_type, "array_of_structs", std::move(min_array_of_struct_val), + std::move(max_array_of_struct_val)); + + // map + auto map_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); + auto min_map_value = Value::MAP(ListType::GetChildType(map_type), vector()); + + child_list_t map_struct1; + map_struct1.push_back(make_pair("key", Value("key1"))); + map_struct1.push_back(make_pair("value", Value("🦆🦆🦆🦆🦆🦆"))); + child_list_t map_struct2; + map_struct2.push_back(make_pair("key", Value("key2"))); + map_struct2.push_back(make_pair("key", Value("goose"))); + + vector map_values; + map_values.push_back(Value::STRUCT(map_struct1)); + map_values.push_back(Value::STRUCT(map_struct2)); + + auto max_map_value = Value::MAP(ListType::GetChildType(map_type), map_values); + result.emplace_back(map_type, "map", std::move(min_map_value), std::move(max_map_value)); + + // union + child_list_t members = {{"name", LogicalType::VARCHAR}, {"age", LogicalType::SMALLINT}}; + auto union_type = LogicalType::UNION(members); + const Value &min = Value::UNION(members, 0, Value("Frank")); + const Value &max = Value::UNION(members, 1, Value::SMALLINT(5)); + result.emplace_back(union_type, "union", min, max); + + return result; +} + +struct TestAllTypesBindData : public TableFunctionData { + vector test_types; +}; + +static unique_ptr TestAllTypesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(); + bool use_large_enum = false; + auto entry = input.named_parameters.find("use_large_enum"); + if (entry != input.named_parameters.end()) { + use_large_enum = BooleanValue::Get(entry->second); + } + result->test_types = TestAllTypesFun::GetTestTypes(use_large_enum); + for (auto &test_type : result->test_types) { + return_types.push_back(test_type.type); + names.push_back(test_type.name); + } + return std::move(result); +} + +unique_ptr TestAllTypesInit(ClientContext &context, TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + auto result = make_uniq(); + // 3 rows: min, max and NULL + result->entries.resize(3); + // initialize the values + for (auto &test_type : bind_data.test_types) { + result->entries[0].push_back(test_type.min_value); + result->entries[1].push_back(test_type.max_value); + result->entries[2].emplace_back(test_type.type); + } + return std::move(result); +} + +void TestAllTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { + auto &vals = data.entries[data.offset++]; + for (idx_t col_idx = 0; col_idx < vals.size(); col_idx++) { + output.SetValue(col_idx, count, vals[col_idx]); + } + count++; + } + output.SetCardinality(count); +} + +void TestAllTypesFun::RegisterFunction(BuiltinFunctions &set) { + TableFunction test_all_types("test_all_types", {}, TestAllTypesFunction, TestAllTypesBind, TestAllTypesInit); + test_all_types.named_parameters["use_large_enum"] = LogicalType::BOOLEAN; + set.AddFunction(test_all_types); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/test_vector_types.cpp b/src/duckdb/src/function/table/system/test_vector_types.cpp new file mode 100644 index 00000000..7562a031 --- /dev/null +++ b/src/duckdb/src/function/table/system/test_vector_types.cpp @@ -0,0 +1,306 @@ +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/pair.hpp" + +namespace duckdb { + +// FLAT, CONSTANT, DICTIONARY, SEQUENCE +struct TestVectorBindData : public TableFunctionData { + vector types; + bool all_flat = false; +}; + +struct TestVectorTypesData : public GlobalTableFunctionState { + TestVectorTypesData() : offset(0) { + } + + vector> entries; + idx_t offset; +}; + +struct TestVectorInfo { + TestVectorInfo(const vector &types, const map &test_type_map, + vector> &entries) + : types(types), test_type_map(test_type_map), entries(entries) { + } + + const vector &types; + const map &test_type_map; + vector> &entries; +}; + +struct TestGeneratedValues { +public: + void AddColumn(vector values) { + if (!column_values.empty() && column_values[0].size() != values.size()) { + throw InternalException("Size mismatch when adding a column to TestGeneratedValues"); + } + column_values.push_back(std::move(values)); + } + + const Value &GetValue(idx_t row, idx_t column) const { + return column_values[column][row]; + } + + idx_t Rows() const { + return column_values.empty() ? 0 : column_values[0].size(); + } + + idx_t Columns() const { + return column_values.size(); + } + +private: + vector> column_values; +}; + +struct TestVectorFlat { + static constexpr const idx_t TEST_VECTOR_CARDINALITY = 3; + + static vector GenerateValues(TestVectorInfo &info, const LogicalType &type) { + vector result; + switch (type.InternalType()) { + case PhysicalType::STRUCT: { + vector> struct_children; + auto &child_types = StructType::GetChildTypes(type); + + struct_children.resize(TEST_VECTOR_CARDINALITY); + for (auto &child_type : child_types) { + auto child_values = GenerateValues(info, child_type.second); + + for (idx_t i = 0; i < child_values.size(); i++) { + struct_children[i].push_back(make_pair(child_type.first, std::move(child_values[i]))); + } + } + for (auto &struct_child : struct_children) { + result.push_back(Value::STRUCT(std::move(struct_child))); + } + break; + } + case PhysicalType::LIST: { + auto &child_type = ListType::GetChildType(type); + auto child_values = GenerateValues(info, child_type); + + result.push_back(Value::LIST(child_type, {child_values[0], child_values[1]})); + result.push_back(Value::LIST(child_type, {})); + result.push_back(Value::LIST(child_type, {child_values[2]})); + break; + } + default: { + auto entry = info.test_type_map.find(type.id()); + if (entry == info.test_type_map.end()) { + throw NotImplementedException("Unimplemented type for test_vector_types %s", type.ToString()); + } + result.push_back(entry->second.min_value); + result.push_back(entry->second.max_value); + result.emplace_back(type); + break; + } + } + return result; + } + + static TestGeneratedValues GenerateValues(TestVectorInfo &info) { + // generate the values for each column + TestGeneratedValues generated_values; + for (auto &type : info.types) { + generated_values.AddColumn(GenerateValues(info, type)); + } + return generated_values; + } + + static void Generate(TestVectorInfo &info) { + auto result_values = GenerateValues(info); + for (idx_t cur_row = 0; cur_row < result_values.Rows(); cur_row += STANDARD_VECTOR_SIZE) { + auto result = make_uniq(); + result->Initialize(Allocator::DefaultAllocator(), info.types); + auto cardinality = MinValue(STANDARD_VECTOR_SIZE, result_values.Rows() - cur_row); + for (idx_t c = 0; c < info.types.size(); c++) { + for (idx_t i = 0; i < cardinality; i++) { + result->data[c].SetValue(i, result_values.GetValue(cur_row + i, c)); + } + } + result->SetCardinality(cardinality); + info.entries.push_back(std::move(result)); + } + } +}; + +struct TestVectorConstant { + static void Generate(TestVectorInfo &info) { + auto values = TestVectorFlat::GenerateValues(info); + for (idx_t cur_row = 0; cur_row < TestVectorFlat::TEST_VECTOR_CARDINALITY; cur_row += STANDARD_VECTOR_SIZE) { + auto result = make_uniq(); + result->Initialize(Allocator::DefaultAllocator(), info.types); + auto cardinality = MinValue(STANDARD_VECTOR_SIZE, TestVectorFlat::TEST_VECTOR_CARDINALITY - cur_row); + for (idx_t c = 0; c < info.types.size(); c++) { + result->data[c].SetValue(0, values.GetValue(0, c)); + result->data[c].SetVectorType(VectorType::CONSTANT_VECTOR); + } + result->SetCardinality(cardinality); + + info.entries.push_back(std::move(result)); + } + } +}; + +struct TestVectorSequence { + static void GenerateVector(TestVectorInfo &info, const LogicalType &type, Vector &result) { + D_ASSERT(type == result.GetType()); + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + result.Sequence(3, 2, 3); + return; + default: + break; + } + switch (type.InternalType()) { + case PhysicalType::STRUCT: { + auto &child_entries = StructVector::GetEntries(result); + for (auto &child_entry : child_entries) { + GenerateVector(info, child_entry->GetType(), *child_entry); + } + break; + } + case PhysicalType::LIST: { + auto data = FlatVector::GetData(result); + data[0].offset = 0; + data[0].length = 2; + data[1].offset = 2; + data[1].length = 0; + data[2].offset = 2; + data[2].length = 1; + + GenerateVector(info, ListType::GetChildType(type), ListVector::GetEntry(result)); + ListVector::SetListSize(result, 3); + break; + } + default: { + auto entry = info.test_type_map.find(type.id()); + if (entry == info.test_type_map.end()) { + throw NotImplementedException("Unimplemented type for test_vector_types %s", type.ToString()); + } + result.SetValue(0, entry->second.min_value); + result.SetValue(1, entry->second.max_value); + result.SetValue(2, Value(type)); + break; + } + } + } + + static void Generate(TestVectorInfo &info) { +#if STANDARD_VECTOR_SIZE > 2 + auto result = make_uniq(); + result->Initialize(Allocator::DefaultAllocator(), info.types); + + for (idx_t c = 0; c < info.types.size(); c++) { + GenerateVector(info, info.types[c], result->data[c]); + } + result->SetCardinality(3); + info.entries.push_back(std::move(result)); +#endif + } +}; + +struct TestVectorDictionary { + static void Generate(TestVectorInfo &info) { + idx_t current_chunk = info.entries.size(); + + unordered_set slice_entries {1, 2}; + + TestVectorFlat::Generate(info); + idx_t current_idx = 0; + for (idx_t i = current_chunk; i < info.entries.size(); i++) { + auto &chunk = *info.entries[i]; + SelectionVector sel(STANDARD_VECTOR_SIZE); + idx_t sel_idx = 0; + for (idx_t k = 0; k < chunk.size(); k++) { + if (slice_entries.count(current_idx + k) > 0) { + sel.set_index(sel_idx++, k); + } + } + chunk.Slice(sel, sel_idx); + current_idx += chunk.size(); + } + } +}; + +static unique_ptr TestVectorTypesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(); + for (idx_t i = 0; i < input.inputs.size(); i++) { + string name = "test_vector"; + if (i > 0) { + name += to_string(i + 1); + } + auto &input_val = input.inputs[i]; + names.emplace_back(name); + return_types.push_back(input_val.type()); + result->types.push_back(input_val.type()); + } + for (auto &entry : input.named_parameters) { + if (entry.first == "all_flat") { + result->all_flat = BooleanValue::Get(entry.second); + } else { + throw InternalException("Unrecognized named parameter for test_vector_types"); + } + } + return std::move(result); +} + +unique_ptr TestVectorTypesInit(ClientContext &context, TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + + auto result = make_uniq(); + + auto test_types = TestAllTypesFun::GetTestTypes(); + + map test_type_map; + for (auto &test_type : test_types) { + test_type_map.insert(make_pair(test_type.type.id(), std::move(test_type))); + } + + TestVectorInfo info(bind_data.types, test_type_map, result->entries); + TestVectorFlat::Generate(info); + TestVectorConstant::Generate(info); + TestVectorDictionary::Generate(info); + TestVectorSequence::Generate(info); + for (auto &entry : result->entries) { + entry->Verify(); + } + if (bind_data.all_flat) { + for (auto &entry : result->entries) { + entry->Flatten(); + entry->Verify(); + } + } + return std::move(result); +} + +void TestVectorTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.entries.size()) { + // finished returning values + return; + } + output.Reference(*data.entries[data.offset]); + data.offset++; +} + +void TestVectorTypesFun::RegisterFunction(BuiltinFunctions &set) { + TableFunction test_vector_types("test_vector_types", {LogicalType::ANY}, TestVectorTypesFunction, + TestVectorTypesBind, TestVectorTypesInit); + test_vector_types.varargs = LogicalType::ANY; + test_vector_types.named_parameters["all_flat"] = LogicalType::BOOLEAN; + + set.AddFunction(std::move(test_vector_types)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp new file mode 100644 index 00000000..7b6a5b04 --- /dev/null +++ b/src/duckdb/src/function/table/system_functions.cpp @@ -0,0 +1,41 @@ +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +void BuiltinFunctions::RegisterSQLiteFunctions() { + PragmaVersion::RegisterFunction(*this); + PragmaPlatform::RegisterFunction(*this); + PragmaCollations::RegisterFunction(*this); + PragmaTableInfo::RegisterFunction(*this); + PragmaStorageInfo::RegisterFunction(*this); + PragmaMetadataInfo::RegisterFunction(*this); + PragmaDatabaseSize::RegisterFunction(*this); + PragmaLastProfilingOutput::RegisterFunction(*this); + PragmaDetailedProfilingOutput::RegisterFunction(*this); + + DuckDBColumnsFun::RegisterFunction(*this); + DuckDBConstraintsFun::RegisterFunction(*this); + DuckDBDatabasesFun::RegisterFunction(*this); + DuckDBFunctionsFun::RegisterFunction(*this); + DuckDBKeywordsFun::RegisterFunction(*this); + DuckDBIndexesFun::RegisterFunction(*this); + DuckDBSchemasFun::RegisterFunction(*this); + DuckDBDependenciesFun::RegisterFunction(*this); + DuckDBExtensionsFun::RegisterFunction(*this); + DuckDBSequencesFun::RegisterFunction(*this); + DuckDBSettingsFun::RegisterFunction(*this); + DuckDBTablesFun::RegisterFunction(*this); + DuckDBTemporaryFilesFun::RegisterFunction(*this); + DuckDBTypesFun::RegisterFunction(*this); + DuckDBViewsFun::RegisterFunction(*this); + TestAllTypesFun::RegisterFunction(*this); + TestVectorTypesFun::RegisterFunction(*this); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp new file mode 100644 index 00000000..302b9472 --- /dev/null +++ b/src/duckdb/src/function/table/table_scan.cpp @@ -0,0 +1,505 @@ +#include "duckdb/function/table/table_scan.hpp" + +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/transaction/local_storage.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/catalog/dependency_list.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Table Scan +//===--------------------------------------------------------------------===// +bool TableScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, + LocalTableFunctionState *local_state, GlobalTableFunctionState *gstate); + +struct TableScanLocalState : public LocalTableFunctionState { + //! The current position in the scan + TableScanState scan_state; + //! The DataChunk containing all read columns (even filter columns that are immediately removed) + DataChunk all_columns; +}; + +static storage_t GetStorageIndex(TableCatalogEntry &table, column_t column_id) { + if (column_id == DConstants::INVALID_INDEX) { + return column_id; + } + auto &col = table.GetColumn(LogicalIndex(column_id)); + return col.StorageOid(); +} + +struct TableScanGlobalState : public GlobalTableFunctionState { + TableScanGlobalState(ClientContext &context, const FunctionData *bind_data_p) { + D_ASSERT(bind_data_p); + auto &bind_data = bind_data_p->Cast(); + max_threads = bind_data.table.GetStorage().MaxThreads(context); + } + + ParallelTableScanState state; + idx_t max_threads; + + vector projection_ids; + vector scanned_types; + + idx_t MaxThreads() const override { + return max_threads; + } + + bool CanRemoveFilterColumns() const { + return !projection_ids.empty(); + } +}; + +static unique_ptr TableScanInitLocal(ExecutionContext &context, TableFunctionInitInput &input, + GlobalTableFunctionState *gstate) { + auto result = make_uniq(); + auto &bind_data = input.bind_data->Cast(); + vector column_ids = input.column_ids; + for (auto &col : column_ids) { + auto storage_idx = GetStorageIndex(bind_data.table, col); + col = storage_idx; + } + result->scan_state.Initialize(std::move(column_ids), input.filters.get()); + TableScanParallelStateNext(context.client, input.bind_data.get(), result.get(), gstate); + if (input.CanRemoveFilterColumns()) { + auto &tsgs = gstate->Cast(); + result->all_columns.Initialize(context.client, tsgs.scanned_types); + } + return std::move(result); +} + +unique_ptr TableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { + + D_ASSERT(input.bind_data); + auto &bind_data = input.bind_data->Cast(); + auto result = make_uniq(context, input.bind_data.get()); + bind_data.table.GetStorage().InitializeParallelScan(context, result->state); + if (input.CanRemoveFilterColumns()) { + result->projection_ids = input.projection_ids; + const auto &columns = bind_data.table.GetColumns(); + for (const auto &col_idx : input.column_ids) { + if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { + result->scanned_types.emplace_back(LogicalType::ROW_TYPE); + } else { + result->scanned_types.push_back(columns.GetColumn(LogicalIndex(col_idx)).Type()); + } + } + } + return std::move(result); +} + +static unique_ptr TableScanStatistics(ClientContext &context, const FunctionData *bind_data_p, + column_t column_id) { + auto &bind_data = bind_data_p->Cast(); + auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); + if (local_storage.Find(bind_data.table.GetStorage())) { + // we don't emit any statistics for tables that have outstanding transaction-local data + return nullptr; + } + return bind_data.table.GetStatistics(context, column_id); +} + +static void TableScanFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &gstate = data_p.global_state->Cast(); + auto &state = data_p.local_state->Cast(); + auto &transaction = DuckTransaction::Get(context, bind_data.table.catalog); + auto &storage = bind_data.table.GetStorage(); + do { + if (bind_data.is_create_index) { + storage.CreateIndexScan(state.scan_state, output, + TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); + } else if (gstate.CanRemoveFilterColumns()) { + state.all_columns.Reset(); + storage.Scan(transaction, state.all_columns, state.scan_state); + output.ReferenceColumns(state.all_columns, gstate.projection_ids); + } else { + storage.Scan(transaction, output, state.scan_state); + } + if (output.size() > 0) { + return; + } + if (!TableScanParallelStateNext(context, data_p.bind_data.get(), data_p.local_state.get(), + data_p.global_state.get())) { + return; + } + } while (true); +} + +bool TableScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, + LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state) { + auto &bind_data = bind_data_p->Cast(); + auto ¶llel_state = global_state->Cast(); + auto &state = local_state->Cast(); + auto &storage = bind_data.table.GetStorage(); + + return storage.NextParallelScan(context, parallel_state.state, state.scan_state); +} + +double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p, + const GlobalTableFunctionState *gstate_p) { + auto &bind_data = bind_data_p->Cast(); + auto &gstate = gstate_p->Cast(); + auto &storage = bind_data.table.GetStorage(); + idx_t total_rows = storage.GetTotalRows(); + if (total_rows == 0) { + //! Table is either empty or smaller than a vector size, so it is finished + return 100; + } + idx_t scanned_rows = gstate.state.scan_state.processed_rows; + scanned_rows += gstate.state.local_state.processed_rows; + auto percentage = 100 * (double(scanned_rows) / total_rows); + if (percentage > 100) { + //! In case the last chunk has less elements than STANDARD_VECTOR_SIZE, if our percentage is over 100 + //! It means we finished this table. + return 100; + } + return percentage; +} + +idx_t TableScanGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, + LocalTableFunctionState *local_state, GlobalTableFunctionState *gstate_p) { + auto &state = local_state->Cast(); + if (state.scan_state.table_state.row_group) { + return state.scan_state.table_state.batch_index; + } + if (state.scan_state.local_state.row_group) { + return state.scan_state.table_state.batch_index + state.scan_state.local_state.batch_index; + } + return 0; +} + +BindInfo TableScanGetBindInfo(const FunctionData *bind_data) { + return BindInfo(ScanType::TABLE); +} + +void TableScanDependency(DependencyList &entries, const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + entries.AddDependency(bind_data.table); +} + +unique_ptr TableScanCardinality(ClientContext &context, const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); + auto &storage = bind_data.table.GetStorage(); + idx_t estimated_cardinality = storage.info->cardinality + local_storage.AddedRows(bind_data.table.GetStorage()); + return make_uniq(storage.info->cardinality, estimated_cardinality); +} + +//===--------------------------------------------------------------------===// +// Index Scan +//===--------------------------------------------------------------------===// +struct IndexScanGlobalState : public GlobalTableFunctionState { + explicit IndexScanGlobalState(data_ptr_t row_id_data) : row_ids(LogicalType::ROW_TYPE, row_id_data) { + } + + Vector row_ids; + ColumnFetchState fetch_state; + TableScanState local_storage_state; + vector column_ids; + bool finished; +}; + +static unique_ptr IndexScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + data_ptr_t row_id_data = nullptr; + if (!bind_data.result_ids.empty()) { + row_id_data = (data_ptr_t)&bind_data.result_ids[0]; // NOLINT - this is not pretty + } + auto result = make_uniq(row_id_data); + auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); + + result->column_ids.reserve(input.column_ids.size()); + for (auto &id : input.column_ids) { + result->column_ids.push_back(GetStorageIndex(bind_data.table, id)); + } + result->local_storage_state.Initialize(result->column_ids, input.filters.get()); + local_storage.InitializeScan(bind_data.table.GetStorage(), result->local_storage_state.local_state, input.filters); + + result->finished = false; + return std::move(result); +} + +static void IndexScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.global_state->Cast(); + auto &transaction = DuckTransaction::Get(context, bind_data.table.catalog); + auto &local_storage = LocalStorage::Get(transaction); + + if (!state.finished) { + bind_data.table.GetStorage().Fetch(transaction, output, state.column_ids, state.row_ids, + bind_data.result_ids.size(), state.fetch_state); + state.finished = true; + } + if (output.size() == 0) { + local_storage.Scan(state.local_storage_state.local_state, state.column_ids, output); + } +} + +static void RewriteIndexExpression(Index &index, LogicalGet &get, Expression &expr, bool &rewrite_possible) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &bound_colref = expr.Cast(); + // bound column ref: rewrite to fit in the current set of bound column ids + bound_colref.binding.table_index = get.table_index; + column_t referenced_column = index.column_ids[bound_colref.binding.column_index]; + // search for the referenced column in the set of column_ids + for (idx_t i = 0; i < get.column_ids.size(); i++) { + if (get.column_ids[i] == referenced_column) { + bound_colref.binding.column_index = i; + return; + } + } + // column id not found in bound columns in the LogicalGet: rewrite not possible + rewrite_possible = false; + } + ExpressionIterator::EnumerateChildren( + expr, [&](Expression &child) { RewriteIndexExpression(index, get, child, rewrite_possible); }); +} + +void TableScanPushdownComplexFilter(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, + vector> &filters) { + auto &bind_data = bind_data_p->Cast(); + auto &table = bind_data.table; + auto &storage = table.GetStorage(); + + auto &config = ClientConfig::GetConfig(context); + if (!config.enable_optimizer) { + // we only push index scans if the optimizer is enabled + return; + } + if (bind_data.is_index_scan) { + return; + } + if (!get.table_filters.filters.empty()) { + // if there were filters before we can't convert this to an index scan + return; + } + if (!get.projection_ids.empty()) { + // if columns were pruned by RemoveUnusedColumns we can't convert this to an index scan, + // because index scan does not support filter_prune (yet) + return; + } + if (filters.empty()) { + // no indexes or no filters: skip the pushdown + return; + } + // behold + storage.info->indexes.Scan([&](Index &index) { + // first rewrite the index expression so the ColumnBindings align with the column bindings of the current table + + if (index.unbound_expressions.size() > 1) { + // NOTE: index scans are not (yet) supported for compound index keys + return false; + } + + auto index_expression = index.unbound_expressions[0]->Copy(); + bool rewrite_possible = true; + RewriteIndexExpression(index, get, *index_expression, rewrite_possible); + if (!rewrite_possible) { + // could not rewrite! + return false; + } + + Value low_value, high_value, equal_value; + ExpressionType low_comparison_type = ExpressionType::INVALID, high_comparison_type = ExpressionType::INVALID; + // try to find a matching index for any of the filter expressions + for (auto &filter : filters) { + auto &expr = *filter; + + // create a matcher for a comparison with a constant + ComparisonExpressionMatcher matcher; + // match on a comparison type + matcher.expr_type = make_uniq(); + // match on a constant comparison with the indexed expression + matcher.matchers.push_back(make_uniq(*index_expression)); + matcher.matchers.push_back(make_uniq()); + + matcher.policy = SetMatcher::Policy::UNORDERED; + + vector> bindings; + if (matcher.Match(expr, bindings)) { + // range or equality comparison with constant value + // we can use our index here + // bindings[0] = the expression + // bindings[1] = the index expression + // bindings[2] = the constant + auto &comparison = bindings[0].get().Cast(); + auto constant_value = bindings[2].get().Cast().value; + auto comparison_type = comparison.type; + if (comparison.left->type == ExpressionType::VALUE_CONSTANT) { + // the expression is on the right side, we flip them around + comparison_type = FlipComparisonExpression(comparison_type); + } + if (comparison_type == ExpressionType::COMPARE_EQUAL) { + // equality value + // equality overrides any other bounds so we just break here + equal_value = constant_value; + break; + } else if (comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || + comparison_type == ExpressionType::COMPARE_GREATERTHAN) { + // greater than means this is a lower bound + low_value = constant_value; + low_comparison_type = comparison_type; + } else { + // smaller than means this is an upper bound + high_value = constant_value; + high_comparison_type = comparison_type; + } + } else if (expr.type == ExpressionType::COMPARE_BETWEEN) { + // BETWEEN expression + auto &between = expr.Cast(); + if (!between.input->Equals(*index_expression)) { + // expression doesn't match the current index expression + continue; + } + if (between.lower->type != ExpressionType::VALUE_CONSTANT || + between.upper->type != ExpressionType::VALUE_CONSTANT) { + // not a constant comparison + continue; + } + low_value = (between.lower->Cast()).value; + low_comparison_type = between.lower_inclusive ? ExpressionType::COMPARE_GREATERTHANOREQUALTO + : ExpressionType::COMPARE_GREATERTHAN; + high_value = (between.upper->Cast()).value; + high_comparison_type = between.upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO + : ExpressionType::COMPARE_LESSTHAN; + break; + } + } + if (!equal_value.IsNull() || !low_value.IsNull() || !high_value.IsNull()) { + // we can scan this index using this predicate: try a scan + auto &transaction = Transaction::Get(context, bind_data.table.catalog); + unique_ptr index_state; + if (!equal_value.IsNull()) { + // equality predicate + index_state = + index.InitializeScanSinglePredicate(transaction, equal_value, ExpressionType::COMPARE_EQUAL); + } else if (!low_value.IsNull() && !high_value.IsNull()) { + // two-sided predicate + index_state = index.InitializeScanTwoPredicates(transaction, low_value, low_comparison_type, high_value, + high_comparison_type); + } else if (!low_value.IsNull()) { + // less than predicate + index_state = index.InitializeScanSinglePredicate(transaction, low_value, low_comparison_type); + } else { + D_ASSERT(!high_value.IsNull()); + index_state = index.InitializeScanSinglePredicate(transaction, high_value, high_comparison_type); + } + if (index.Scan(transaction, storage, *index_state, STANDARD_VECTOR_SIZE, bind_data.result_ids)) { + // use an index scan! + bind_data.is_index_scan = true; + get.function = TableScanFunction::GetIndexScanFunction(); + } else { + bind_data.result_ids.clear(); + } + return true; + } + return false; + }); +} + +string TableScanToString(const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + string result = bind_data.table.name; + return result; +} + +static void TableScanSerialize(Serializer &serializer, const optional_ptr bind_data_p, + const TableFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "catalog", bind_data.table.schema.catalog.GetName()); + serializer.WriteProperty(101, "schema", bind_data.table.schema.name); + serializer.WriteProperty(102, "table", bind_data.table.name); + serializer.WriteProperty(103, "is_index_scan", bind_data.is_index_scan); + serializer.WriteProperty(104, "is_create_index", bind_data.is_create_index); + serializer.WriteProperty(105, "result_ids", bind_data.result_ids); +} + +static unique_ptr TableScanDeserialize(Deserializer &deserializer, TableFunction &function) { + auto catalog = deserializer.ReadProperty(100, "catalog"); + auto schema = deserializer.ReadProperty(101, "schema"); + auto table = deserializer.ReadProperty(102, "table"); + auto &catalog_entry = + Catalog::GetEntry(deserializer.Get(), catalog, schema, table); + if (catalog_entry.type != CatalogType::TABLE_ENTRY) { + throw SerializationException("Cant find table for %s.%s", schema, table); + } + auto result = make_uniq(catalog_entry.Cast()); + deserializer.ReadProperty(103, "is_index_scan", result->is_index_scan); + deserializer.ReadProperty(104, "is_create_index", result->is_create_index); + deserializer.ReadProperty(105, "result_ids", result->result_ids); + return std::move(result); +} + +TableFunction TableScanFunction::GetIndexScanFunction() { + TableFunction scan_function("index_scan", {}, IndexScanFunction); + scan_function.init_local = nullptr; + scan_function.init_global = IndexScanInitGlobal; + scan_function.statistics = TableScanStatistics; + scan_function.dependency = TableScanDependency; + scan_function.cardinality = TableScanCardinality; + scan_function.pushdown_complex_filter = nullptr; + scan_function.to_string = TableScanToString; + scan_function.table_scan_progress = nullptr; + scan_function.get_batch_index = nullptr; + scan_function.projection_pushdown = true; + scan_function.filter_pushdown = false; + scan_function.serialize = TableScanSerialize; + scan_function.deserialize = TableScanDeserialize; + return scan_function; +} + +TableFunction TableScanFunction::GetFunction() { + TableFunction scan_function("seq_scan", {}, TableScanFunc); + scan_function.init_local = TableScanInitLocal; + scan_function.init_global = TableScanInitGlobal; + scan_function.statistics = TableScanStatistics; + scan_function.dependency = TableScanDependency; + scan_function.cardinality = TableScanCardinality; + scan_function.pushdown_complex_filter = TableScanPushdownComplexFilter; + scan_function.to_string = TableScanToString; + scan_function.table_scan_progress = TableScanProgress; + scan_function.get_batch_index = TableScanGetBatchIndex; + scan_function.get_batch_info = TableScanGetBindInfo; + scan_function.projection_pushdown = true; + scan_function.filter_pushdown = true; + scan_function.filter_prune = true; + scan_function.serialize = TableScanSerialize; + scan_function.deserialize = TableScanDeserialize; + return scan_function; +} + +optional_ptr TableScanFunction::GetTableEntry(const TableFunction &function, + const optional_ptr bind_data_p) { + if (function.function != TableScanFunc || !bind_data_p) { + return nullptr; + } + auto &bind_data = bind_data_p->Cast(); + return &bind_data.table; +} + +void TableScanFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunctionSet table_scan_set("seq_scan"); + table_scan_set.AddFunction(GetFunction()); + set.AddFunction(std::move(table_scan_set)); + + set.AddFunction(GetIndexScanFunction()); +} + +void BuiltinFunctions::RegisterTableScanFunctions() { + TableScanFunction::RegisterFunction(*this); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/unnest.cpp b/src/duckdb/src/function/table/unnest.cpp new file mode 100644 index 00000000..15f39508 --- /dev/null +++ b/src/duckdb/src/function/table/unnest.cpp @@ -0,0 +1,86 @@ +#include "duckdb/function/table/range.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/execution/operator/projection/physical_unnest.hpp" + +namespace duckdb { + +struct UnnestBindData : public FunctionData { + explicit UnnestBindData(LogicalType input_type_p) : input_type(std::move(input_type_p)) { + } + + LogicalType input_type; + +public: + unique_ptr Copy() const override { + return make_uniq(input_type); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return input_type == other.input_type; + } +}; + +struct UnnestGlobalState : public GlobalTableFunctionState { + UnnestGlobalState() { + } + + vector> select_list; + + idx_t MaxThreads() const override { + return GlobalTableFunctionState::MAX_THREADS; + } +}; + +struct UnnestLocalState : public LocalTableFunctionState { + UnnestLocalState() { + } + + unique_ptr operator_state; +}; + +static unique_ptr UnnestBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + if (input.input_table_types.size() != 1 || input.input_table_types[0].id() != LogicalTypeId::LIST) { + throw BinderException("UNNEST requires a single list as input"); + } + return_types.push_back(ListType::GetChildType(input.input_table_types[0])); + names.push_back(input.input_table_names[0]); + return make_uniq(input.input_table_types[0]); +} + +static unique_ptr UnnestLocalInit(ExecutionContext &context, TableFunctionInitInput &input, + GlobalTableFunctionState *global_state) { + auto &gstate = global_state->Cast(); + + auto result = make_uniq(); + result->operator_state = PhysicalUnnest::GetState(context, gstate.select_list); + return std::move(result); +} + +static unique_ptr UnnestInit(ClientContext &context, TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + auto result = make_uniq(); + auto ref = make_uniq(bind_data.input_type, 0); + auto bound_unnest = make_uniq(ListType::GetChildType(bind_data.input_type)); + bound_unnest->child = std::move(ref); + result->select_list.push_back(std::move(bound_unnest)); + return std::move(result); +} + +static OperatorResultType UnnestFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, + DataChunk &output) { + auto &state = data_p.global_state->Cast(); + auto &lstate = data_p.local_state->Cast(); + return PhysicalUnnest::ExecuteInternal(context, input, output, *lstate.operator_state, state.select_list, false); +} + +void UnnestTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunction unnest_function("unnest", {LogicalTypeId::TABLE}, nullptr, UnnestBind, UnnestInit, UnnestLocalInit); + unnest_function.in_out_function = UnnestFunction; + set.AddFunction(unnest_function); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp new file mode 100644 index 00000000..c1f391c6 --- /dev/null +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -0,0 +1,147 @@ +#ifndef DUCKDB_VERSION +#define DUCKDB_VERSION "0.9.0" +#endif +#ifndef DUCKDB_SOURCE_ID +#define DUCKDB_SOURCE_ID "0d84ccf478" +#endif +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/common/string_util.hpp" + +#include + +namespace duckdb { + +struct PragmaVersionData : public GlobalTableFunctionState { + PragmaVersionData() : finished(false) { + } + + bool finished; +}; + +static unique_ptr PragmaVersionBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("library_version"); + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("source_id"); + return_types.emplace_back(LogicalType::VARCHAR); + return nullptr; +} + +static unique_ptr PragmaVersionInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void PragmaVersionFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.finished) { + // finished returning values + return; + } + output.SetCardinality(1); + output.SetValue(0, 0, DuckDB::LibraryVersion()); + output.SetValue(1, 0, DuckDB::SourceID()); + data.finished = true; +} + +void PragmaVersion::RegisterFunction(BuiltinFunctions &set) { + TableFunction pragma_version("pragma_version", {}, PragmaVersionFunction); + pragma_version.bind = PragmaVersionBind; + pragma_version.init_global = PragmaVersionInit; + set.AddFunction(pragma_version); +} + +idx_t DuckDB::StandardVectorSize() { + return STANDARD_VECTOR_SIZE; +} + +const char *DuckDB::SourceID() { + return DUCKDB_SOURCE_ID; +} + +const char *DuckDB::LibraryVersion() { + return DUCKDB_VERSION; +} + +string DuckDB::Platform() { +#if defined(DUCKDB_CUSTOM_PLATFORM) + return DUCKDB_QUOTE_DEFINE(DUCKDB_CUSTOM_PLATFORM); +#endif +#if defined(DUCKDB_WASM_VERSION) + // DuckDB-Wasm requires CUSTOM_PLATFORM to be defined + static_assert(0, "DUCKDB_WASM_VERSION should rely on CUSTOM_PLATFORM being provided"); +#endif + string os = "linux"; +#if INTPTR_MAX == INT64_MAX + string arch = "amd64"; +#elif INTPTR_MAX == INT32_MAX + string arch = "i686"; +#else +#error Unknown pointer size or missing size macros! +#endif + string postfix = ""; + +#ifdef _WIN32 + os = "windows"; +#elif defined(__APPLE__) + os = "osx"; +#endif +#if defined(__aarch64__) || defined(__ARM_ARCH_ISA_A64) + arch = "arm64"; +#endif + +#if !defined(_GLIBCXX_USE_CXX11_ABI) || _GLIBCXX_USE_CXX11_ABI == 0 + if (os == "linux") { + postfix = "_gcc4"; + } +#endif +#if defined(__ANDROID__) + postfix += "_android"; // using + because it may also be gcc4 +#endif +#ifdef __MINGW32__ + postfix = "_mingw"; +#endif +// this is used for the windows R builds which use a separate build environment +#ifdef DUCKDB_PLATFORM_RTOOLS + postfix = "_rtools"; +#endif + return os + "_" + arch + postfix; +} + +struct PragmaPlatformData : public GlobalTableFunctionState { + PragmaPlatformData() : finished(false) { + } + + bool finished; +}; + +static unique_ptr PragmaPlatformBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("platform"); + return_types.emplace_back(LogicalType::VARCHAR); + return nullptr; +} + +static unique_ptr PragmaPlatformInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +static void PragmaPlatformFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.finished) { + // finished returning values + return; + } + output.SetCardinality(1); + output.SetValue(0, 0, DuckDB::Platform()); + data.finished = true; +} + +void PragmaPlatform::RegisterFunction(BuiltinFunctions &set) { + TableFunction pragma_platform("pragma_platform", {}, PragmaPlatformFunction); + pragma_platform.bind = PragmaPlatformBind; + pragma_platform.init_global = PragmaPlatformInit; + set.AddFunction(pragma_platform); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table_function.cpp b/src/duckdb/src/function/table_function.cpp new file mode 100644 index 00000000..4fcf8d82 --- /dev/null +++ b/src/duckdb/src/function/table_function.cpp @@ -0,0 +1,57 @@ +#include "duckdb/function/table_function.hpp" + +namespace duckdb { + +GlobalTableFunctionState::~GlobalTableFunctionState() { +} + +LocalTableFunctionState::~LocalTableFunctionState() { +} + +TableFunctionInfo::~TableFunctionInfo() { +} + +TableFunction::TableFunction(string name, vector arguments, table_function_t function, + table_function_bind_t bind, table_function_init_global_t init_global, + table_function_init_local_t init_local) + : SimpleNamedParameterFunction(std::move(name), std::move(arguments)), bind(bind), bind_replace(nullptr), + init_global(init_global), init_local(init_local), function(function), in_out_function(nullptr), + in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), + pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_batch_index(nullptr), + get_batch_info(nullptr), serialize(nullptr), deserialize(nullptr), projection_pushdown(false), + filter_pushdown(false), filter_prune(false) { +} + +TableFunction::TableFunction(const vector &arguments, table_function_t function, + table_function_bind_t bind, table_function_init_global_t init_global, + table_function_init_local_t init_local) + : TableFunction(string(), arguments, function, bind, init_global, init_local) { +} +TableFunction::TableFunction() + : SimpleNamedParameterFunction("", {}), bind(nullptr), bind_replace(nullptr), init_global(nullptr), + init_local(nullptr), function(nullptr), in_out_function(nullptr), statistics(nullptr), dependency(nullptr), + cardinality(nullptr), pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), + get_batch_index(nullptr), get_batch_info(nullptr), serialize(nullptr), deserialize(nullptr), + projection_pushdown(false), filter_pushdown(false), filter_prune(false) { +} + +bool TableFunction::Equal(const TableFunction &rhs) const { + // number of types + if (this->arguments.size() != rhs.arguments.size()) { + return false; + } + // argument types + for (idx_t i = 0; i < this->arguments.size(); ++i) { + if (this->arguments[i] != rhs.arguments[i]) { + return false; + } + } + // varargs + if (this->varargs != rhs.varargs) { + return false; + } + + return true; // they are equal +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table_macro_function.cpp b/src/duckdb/src/function/table_macro_function.cpp new file mode 100644 index 00000000..9fbb1792 --- /dev/null +++ b/src/duckdb/src/function/table_macro_function.cpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table_macro_function.hpp +// +// +//===----------------------------------------------------------------------===// +//! The SelectStatement of the view +#include "duckdb/function/table_macro_function.hpp" + +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +TableMacroFunction::TableMacroFunction(unique_ptr query_node) + : MacroFunction(MacroType::TABLE_MACRO), query_node(std::move(query_node)) { +} + +TableMacroFunction::TableMacroFunction(void) : MacroFunction(MacroType::TABLE_MACRO) { +} + +unique_ptr TableMacroFunction::Copy() const { + auto result = make_uniq(); + result->query_node = query_node->Copy(); + this->CopyProperties(*result); + return std::move(result); +} + +string TableMacroFunction::ToSQL(const string &schema, const string &name) const { + return MacroFunction::ToSQL(schema, name) + StringUtil::Format("TABLE (%s);", query_node->ToString()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/udf_function.cpp b/src/duckdb/src/function/udf_function.cpp new file mode 100644 index 00000000..3c03dbbe --- /dev/null +++ b/src/duckdb/src/function/udf_function.cpp @@ -0,0 +1,27 @@ +#include "duckdb/function/udf_function.hpp" + +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" + +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +void UDFWrapper::RegisterFunction(string name, vector args, LogicalType ret_type, + scalar_function_t udf_function, ClientContext &context, LogicalType varargs) { + + ScalarFunction scalar_function(std::move(name), std::move(args), std::move(ret_type), std::move(udf_function)); + scalar_function.varargs = std::move(varargs); + scalar_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + CreateScalarFunctionInfo info(scalar_function); + info.schema = DEFAULT_SCHEMA; + context.RegisterFunction(info); +} + +void UDFWrapper::RegisterAggrFunction(AggregateFunction aggr_function, ClientContext &context, LogicalType varargs) { + aggr_function.varargs = std::move(varargs); + CreateAggregateFunctionInfo info(std::move(aggr_function)); + context.RegisterFunction(info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h new file mode 100644 index 00000000..adc5208a --- /dev/null +++ b/src/duckdb/src/include/duckdb.h @@ -0,0 +1,2516 @@ +//===----------------------------------------------------------------------===// +// +// DuckDB +// +// duckdb.h +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +// duplicate of duckdb/main/winapi.hpp +#ifndef DUCKDB_API +#ifdef _WIN32 +#if defined(DUCKDB_BUILD_LIBRARY) && !defined(DUCKDB_BUILD_LOADABLE_EXTENSION) +#define DUCKDB_API __declspec(dllexport) +#else +#define DUCKDB_API __declspec(dllimport) +#endif +#else +#define DUCKDB_API +#endif +#endif + +// duplicate of duckdb/main/winapi.hpp +#ifndef DUCKDB_EXTENSION_API +#ifdef _WIN32 +#ifdef DUCKDB_BUILD_LOADABLE_EXTENSION +#define DUCKDB_EXTENSION_API __declspec(dllexport) +#else +#define DUCKDB_EXTENSION_API +#endif +#else +#define DUCKDB_EXTENSION_API __attribute__((visibility("default"))) +#endif +#endif + +// API versions +// if no explicit API version is defined, the latest API version is used +// Note that using older API versions (i.e. not using DUCKDB_API_LATEST) is deprecated. +// These will not be supported long-term, and will be removed in future versions. +#ifndef DUCKDB_API_0_3_1 +#define DUCKDB_API_0_3_1 1 +#endif +#ifndef DUCKDB_API_0_3_2 +#define DUCKDB_API_0_3_2 2 +#endif +#ifndef DUCKDB_API_LATEST +#define DUCKDB_API_LATEST DUCKDB_API_0_3_2 +#endif + +#ifndef DUCKDB_API_VERSION +#define DUCKDB_API_VERSION DUCKDB_API_LATEST +#endif + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +//===--------------------------------------------------------------------===// +// Type Information +//===--------------------------------------------------------------------===// +typedef uint64_t idx_t; + +typedef enum DUCKDB_TYPE { + DUCKDB_TYPE_INVALID = 0, + // bool + DUCKDB_TYPE_BOOLEAN, + // int8_t + DUCKDB_TYPE_TINYINT, + // int16_t + DUCKDB_TYPE_SMALLINT, + // int32_t + DUCKDB_TYPE_INTEGER, + // int64_t + DUCKDB_TYPE_BIGINT, + // uint8_t + DUCKDB_TYPE_UTINYINT, + // uint16_t + DUCKDB_TYPE_USMALLINT, + // uint32_t + DUCKDB_TYPE_UINTEGER, + // uint64_t + DUCKDB_TYPE_UBIGINT, + // float + DUCKDB_TYPE_FLOAT, + // double + DUCKDB_TYPE_DOUBLE, + // duckdb_timestamp, in microseconds + DUCKDB_TYPE_TIMESTAMP, + // duckdb_date + DUCKDB_TYPE_DATE, + // duckdb_time + DUCKDB_TYPE_TIME, + // duckdb_interval + DUCKDB_TYPE_INTERVAL, + // duckdb_hugeint + DUCKDB_TYPE_HUGEINT, + // const char* + DUCKDB_TYPE_VARCHAR, + // duckdb_blob + DUCKDB_TYPE_BLOB, + // decimal + DUCKDB_TYPE_DECIMAL, + // duckdb_timestamp, in seconds + DUCKDB_TYPE_TIMESTAMP_S, + // duckdb_timestamp, in milliseconds + DUCKDB_TYPE_TIMESTAMP_MS, + // duckdb_timestamp, in nanoseconds + DUCKDB_TYPE_TIMESTAMP_NS, + // enum type, only useful as logical type + DUCKDB_TYPE_ENUM, + // list type, only useful as logical type + DUCKDB_TYPE_LIST, + // struct type, only useful as logical type + DUCKDB_TYPE_STRUCT, + // map type, only useful as logical type + DUCKDB_TYPE_MAP, + // duckdb_hugeint + DUCKDB_TYPE_UUID, + // union type, only useful as logical type + DUCKDB_TYPE_UNION, + // duckdb_bit + DUCKDB_TYPE_BIT, +} duckdb_type; + +//! Days are stored as days since 1970-01-01 +//! Use the duckdb_from_date/duckdb_to_date function to extract individual information +typedef struct { + int32_t days; +} duckdb_date; + +typedef struct { + int32_t year; + int8_t month; + int8_t day; +} duckdb_date_struct; + +//! Time is stored as microseconds since 00:00:00 +//! Use the duckdb_from_time/duckdb_to_time function to extract individual information +typedef struct { + int64_t micros; +} duckdb_time; + +typedef struct { + int8_t hour; + int8_t min; + int8_t sec; + int32_t micros; +} duckdb_time_struct; + +//! Timestamps are stored as microseconds since 1970-01-01 +//! Use the duckdb_from_timestamp/duckdb_to_timestamp function to extract individual information +typedef struct { + int64_t micros; +} duckdb_timestamp; + +typedef struct { + duckdb_date_struct date; + duckdb_time_struct time; +} duckdb_timestamp_struct; + +typedef struct { + int32_t months; + int32_t days; + int64_t micros; +} duckdb_interval; + +//! Hugeints are composed in a (lower, upper) component +//! The value of the hugeint is upper * 2^64 + lower +//! For easy usage, the functions duckdb_hugeint_to_double/duckdb_double_to_hugeint are recommended +typedef struct { + uint64_t lower; + int64_t upper; +} duckdb_hugeint; + +typedef struct { + uint8_t width; + uint8_t scale; + + duckdb_hugeint value; +} duckdb_decimal; + +typedef struct { + char *data; + idx_t size; +} duckdb_string; + +/* + The internal data representation of a VARCHAR/BLOB column +*/ +typedef struct { + union { + struct { + uint32_t length; + char prefix[4]; + char *ptr; + } pointer; + struct { + uint32_t length; + char inlined[12]; + } inlined; + } value; +} duckdb_string_t; + +typedef struct { + void *data; + idx_t size; +} duckdb_blob; + +typedef struct { + uint64_t offset; + uint64_t length; +} duckdb_list_entry; + +typedef struct { +#if DUCKDB_API_VERSION < DUCKDB_API_0_3_2 + void *data; + bool *nullmask; + duckdb_type type; + char *name; +#else + // deprecated, use duckdb_column_data + void *__deprecated_data; + // deprecated, use duckdb_nullmask_data + bool *__deprecated_nullmask; + // deprecated, use duckdb_column_type + duckdb_type __deprecated_type; + // deprecated, use duckdb_column_name + char *__deprecated_name; +#endif + void *internal_data; +} duckdb_column; + +typedef struct { +#if DUCKDB_API_VERSION < DUCKDB_API_0_3_2 + idx_t column_count; + idx_t row_count; + idx_t rows_changed; + duckdb_column *columns; + char *error_message; +#else + // deprecated, use duckdb_column_count + idx_t __deprecated_column_count; + // deprecated, use duckdb_row_count + idx_t __deprecated_row_count; + // deprecated, use duckdb_rows_changed + idx_t __deprecated_rows_changed; + // deprecated, use duckdb_column_ family of functions + duckdb_column *__deprecated_columns; + // deprecated, use duckdb_result_error + char *__deprecated_error_message; +#endif + void *internal_data; +} duckdb_result; + +typedef struct _duckdb_database { + void *__db; +} * duckdb_database; +typedef struct _duckdb_connection { + void *__conn; +} * duckdb_connection; +typedef struct _duckdb_prepared_statement { + void *__prep; +} * duckdb_prepared_statement; +typedef struct _duckdb_extracted_statements { + void *__extrac; +} * duckdb_extracted_statements; +typedef struct _duckdb_pending_result { + void *__pend; +} * duckdb_pending_result; +typedef struct _duckdb_appender { + void *__appn; +} * duckdb_appender; +typedef struct _duckdb_arrow { + void *__arrw; +} * duckdb_arrow; +typedef struct _duckdb_arrow_stream { + void *__arrwstr; +} * duckdb_arrow_stream; +typedef struct _duckdb_config { + void *__cnfg; +} * duckdb_config; +typedef struct _duckdb_arrow_schema { + void *__arrs; +} * duckdb_arrow_schema; +typedef struct _duckdb_arrow_array { + void *__arra; +} * duckdb_arrow_array; +typedef struct _duckdb_logical_type { + void *__lglt; +} * duckdb_logical_type; +typedef struct _duckdb_data_chunk { + void *__dtck; +} * duckdb_data_chunk; +typedef struct _duckdb_vector { + void *__vctr; +} * duckdb_vector; +typedef struct _duckdb_value { + void *__val; +} * duckdb_value; + +typedef enum { DuckDBSuccess = 0, DuckDBError = 1 } duckdb_state; +typedef enum { + DUCKDB_PENDING_RESULT_READY = 0, + DUCKDB_PENDING_RESULT_NOT_READY = 1, + DUCKDB_PENDING_ERROR = 2, + DUCKDB_PENDING_NO_TASKS_AVAILABLE = 3 +} duckdb_pending_state; + +//===--------------------------------------------------------------------===// +// Open/Connect +//===--------------------------------------------------------------------===// + +/*! +Creates a new database or opens an existing database file stored at the the given path. +If no path is given a new in-memory database is created instead. +The instantiated database should be closed with 'duckdb_close' + +* path: Path to the database file on disk, or `nullptr` or `:memory:` to open an in-memory database. +* out_database: The result database object. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_open(const char *path, duckdb_database *out_database); + +/*! +Extended version of duckdb_open. Creates a new database or opens an existing database file stored at the the given path. + +* path: Path to the database file on disk, or `nullptr` or `:memory:` to open an in-memory database. +* out_database: The result database object. +* config: (Optional) configuration used to start up the database system. +* out_error: If set and the function returns DuckDBError, this will contain the reason why the start-up failed. +Note that the error must be freed using `duckdb_free`. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_open_ext(const char *path, duckdb_database *out_database, duckdb_config config, + char **out_error); + +/*! +Closes the specified database and de-allocates all memory allocated for that database. +This should be called after you are done with any database allocated through `duckdb_open`. +Note that failing to call `duckdb_close` (in case of e.g. a program crash) will not cause data corruption. +Still it is recommended to always correctly close a database object after you are done with it. + +* database: The database object to shut down. +*/ +DUCKDB_API void duckdb_close(duckdb_database *database); + +/*! +Opens a connection to a database. Connections are required to query the database, and store transactional state +associated with the connection. +The instantiated connection should be closed using 'duckdb_disconnect' + +* database: The database file to connect to. +* out_connection: The result connection object. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_connect(duckdb_database database, duckdb_connection *out_connection); + +/*! +Interrupt running query + +* connection: The connection to interruot +*/ +DUCKDB_API void duckdb_interrupt(duckdb_connection connection); + +/*! +Get progress of the running query + +* connection: The working connection +* returns: -1 if no progress or a percentage of the progress +*/ +DUCKDB_API double duckdb_query_progress(duckdb_connection connection); + +/*! +Closes the specified connection and de-allocates all memory allocated for that connection. + +* connection: The connection to close. +*/ +DUCKDB_API void duckdb_disconnect(duckdb_connection *connection); + +/*! +Returns the version of the linked DuckDB, with a version postfix for dev versions + +Usually used for developing C extensions that must return this for a compatibility check. +*/ +DUCKDB_API const char *duckdb_library_version(); + +//===--------------------------------------------------------------------===// +// Configuration +//===--------------------------------------------------------------------===// +/*! +Initializes an empty configuration object that can be used to provide start-up options for the DuckDB instance +through `duckdb_open_ext`. + +This will always succeed unless there is a malloc failure. + +* out_config: The result configuration object. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_create_config(duckdb_config *out_config); + +/*! +This returns the total amount of configuration options available for usage with `duckdb_get_config_flag`. + +This should not be called in a loop as it internally loops over all the options. + +* returns: The amount of config options available. +*/ +DUCKDB_API size_t duckdb_config_count(); + +/*! +Obtains a human-readable name and description of a specific configuration option. This can be used to e.g. +display configuration options. This will succeed unless `index` is out of range (i.e. `>= duckdb_config_count`). + +The result name or description MUST NOT be freed. + +* index: The index of the configuration option (between 0 and `duckdb_config_count`) +* out_name: A name of the configuration flag. +* out_description: A description of the configuration flag. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_get_config_flag(size_t index, const char **out_name, const char **out_description); + +/*! +Sets the specified option for the specified configuration. The configuration option is indicated by name. +To obtain a list of config options, see `duckdb_get_config_flag`. + +In the source code, configuration options are defined in `config.cpp`. + +This can fail if either the name is invalid, or if the value provided for the option is invalid. + +* duckdb_config: The configuration object to set the option on. +* name: The name of the configuration flag to set. +* option: The value to set the configuration flag to. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_set_config(duckdb_config config, const char *name, const char *option); + +/*! +Destroys the specified configuration option and de-allocates all memory allocated for the object. + +* config: The configuration object to destroy. +*/ +DUCKDB_API void duckdb_destroy_config(duckdb_config *config); + +//===--------------------------------------------------------------------===// +// Query Execution +//===--------------------------------------------------------------------===// +/*! +Executes a SQL query within a connection and stores the full (materialized) result in the out_result pointer. +If the query fails to execute, DuckDBError is returned and the error message can be retrieved by calling +`duckdb_result_error`. + +Note that after running `duckdb_query`, `duckdb_destroy_result` must be called on the result object even if the +query fails, otherwise the error stored within the result will not be freed correctly. + +* connection: The connection to perform the query in. +* query: The SQL query to run. +* out_result: The query result. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_query(duckdb_connection connection, const char *query, duckdb_result *out_result); + +/*! +Closes the result and de-allocates all memory allocated for that connection. + +* result: The result to destroy. +*/ +DUCKDB_API void duckdb_destroy_result(duckdb_result *result); + +/*! +Returns the column name of the specified column. The result should not need be freed; the column names will +automatically be destroyed when the result is destroyed. + +Returns `NULL` if the column is out of range. + +* result: The result object to fetch the column name from. +* col: The column index. +* returns: The column name of the specified column. +*/ +DUCKDB_API const char *duckdb_column_name(duckdb_result *result, idx_t col); + +/*! +Returns the column type of the specified column. + +Returns `DUCKDB_TYPE_INVALID` if the column is out of range. + +* result: The result object to fetch the column type from. +* col: The column index. +* returns: The column type of the specified column. +*/ +DUCKDB_API duckdb_type duckdb_column_type(duckdb_result *result, idx_t col); + +/*! +Returns the logical column type of the specified column. + +The return type of this call should be destroyed with `duckdb_destroy_logical_type`. + +Returns `NULL` if the column is out of range. + +* result: The result object to fetch the column type from. +* col: The column index. +* returns: The logical column type of the specified column. +*/ +DUCKDB_API duckdb_logical_type duckdb_column_logical_type(duckdb_result *result, idx_t col); + +/*! +Returns the number of columns present in a the result object. + +* result: The result object. +* returns: The number of columns present in the result object. +*/ +DUCKDB_API idx_t duckdb_column_count(duckdb_result *result); + +/*! +Returns the number of rows present in a the result object. + +* result: The result object. +* returns: The number of rows present in the result object. +*/ +DUCKDB_API idx_t duckdb_row_count(duckdb_result *result); + +/*! +Returns the number of rows changed by the query stored in the result. This is relevant only for INSERT/UPDATE/DELETE +queries. For other queries the rows_changed will be 0. + +* result: The result object. +* returns: The number of rows changed. +*/ +DUCKDB_API idx_t duckdb_rows_changed(duckdb_result *result); + +/*! +**DEPRECATED**: Prefer using `duckdb_result_get_chunk` instead. + +Returns the data of a specific column of a result in columnar format. + +The function returns a dense array which contains the result data. The exact type stored in the array depends on the +corresponding duckdb_type (as provided by `duckdb_column_type`). For the exact type by which the data should be +accessed, see the comments in [the types section](types) or the `DUCKDB_TYPE` enum. + +For example, for a column of type `DUCKDB_TYPE_INTEGER`, rows can be accessed in the following manner: +```c +int32_t *data = (int32_t *) duckdb_column_data(&result, 0); +printf("Data for row %d: %d\n", row, data[row]); +``` + +* result: The result object to fetch the column data from. +* col: The column index. +* returns: The column data of the specified column. +*/ +DUCKDB_API void *duckdb_column_data(duckdb_result *result, idx_t col); + +/*! +**DEPRECATED**: Prefer using `duckdb_result_get_chunk` instead. + +Returns the nullmask of a specific column of a result in columnar format. The nullmask indicates for every row +whether or not the corresponding row is `NULL`. If a row is `NULL`, the values present in the array provided +by `duckdb_column_data` are undefined. + +```c +int32_t *data = (int32_t *) duckdb_column_data(&result, 0); +bool *nullmask = duckdb_nullmask_data(&result, 0); +if (nullmask[row]) { + printf("Data for row %d: NULL\n", row); +} else { + printf("Data for row %d: %d\n", row, data[row]); +} +``` + +* result: The result object to fetch the nullmask from. +* col: The column index. +* returns: The nullmask of the specified column. +*/ +DUCKDB_API bool *duckdb_nullmask_data(duckdb_result *result, idx_t col); + +/*! +Returns the error message contained within the result. The error is only set if `duckdb_query` returns `DuckDBError`. + +The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_result` is called. + +* result: The result object to fetch the error from. +* returns: The error of the result. +*/ +DUCKDB_API const char *duckdb_result_error(duckdb_result *result); + +//===--------------------------------------------------------------------===// +// Result Functions +//===--------------------------------------------------------------------===// + +/*! +Fetches a data chunk from the duckdb_result. This function should be called repeatedly until the result is exhausted. + +The result must be destroyed with `duckdb_destroy_data_chunk`. + +This function supersedes all `duckdb_value` functions, as well as the `duckdb_column_data` and `duckdb_nullmask_data` +functions. It results in significantly better performance, and should be preferred in newer code-bases. + +If this function is used, none of the other result functions can be used and vice versa (i.e. this function cannot be +mixed with the legacy result functions). + +Use `duckdb_result_chunk_count` to figure out how many chunks there are in the result. + +* result: The result object to fetch the data chunk from. +* chunk_index: The chunk index to fetch from. +* returns: The resulting data chunk. Returns `NULL` if the chunk index is out of bounds. +*/ +DUCKDB_API duckdb_data_chunk duckdb_result_get_chunk(duckdb_result result, idx_t chunk_index); + +/*! +Checks if the type of the internal result is StreamQueryResult. + +* result: The result object to check. +* returns: Whether or not the result object is of the type StreamQueryResult +*/ +DUCKDB_API bool duckdb_result_is_streaming(duckdb_result result); + +/*! +Returns the number of data chunks present in the result. + +* result: The result object +* returns: Number of data chunks present in the result. +*/ +DUCKDB_API idx_t duckdb_result_chunk_count(duckdb_result result); + +// Safe fetch functions +// These functions will perform conversions if necessary. +// On failure (e.g. if conversion cannot be performed or if the value is NULL) a default value is returned. +// Note that these functions are slow since they perform bounds checking and conversion +// For fast access of values prefer using `duckdb_result_get_chunk` + +/*! + * returns: The boolean value at the specified location, or false if the value cannot be converted. + */ +DUCKDB_API bool duckdb_value_boolean(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The int8_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API int8_t duckdb_value_int8(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The int16_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API int16_t duckdb_value_int16(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The int32_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API int32_t duckdb_value_int32(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The int64_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API int64_t duckdb_value_int64(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The duckdb_hugeint value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API duckdb_hugeint duckdb_value_hugeint(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The duckdb_decimal value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API duckdb_decimal duckdb_value_decimal(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The uint8_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API uint8_t duckdb_value_uint8(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The uint16_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API uint16_t duckdb_value_uint16(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The uint32_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API uint32_t duckdb_value_uint32(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The uint64_t value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API uint64_t duckdb_value_uint64(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The float value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API float duckdb_value_float(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The double value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API double duckdb_value_double(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The duckdb_date value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API duckdb_date duckdb_value_date(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The duckdb_time value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API duckdb_time duckdb_value_time(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The duckdb_timestamp value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: The duckdb_interval value at the specified location, or 0 if the value cannot be converted. + */ +DUCKDB_API duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row); + +/*! +* DEPRECATED: use duckdb_value_string instead. This function does not work correctly if the string contains null bytes. +* returns: The text value at the specified location as a null-terminated string, or nullptr if the value cannot be +converted. The result must be freed with `duckdb_free`. +*/ +DUCKDB_API char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t row); + +/*!s +* returns: The string value at the specified location. +The result must be freed with `duckdb_free`. +*/ +DUCKDB_API duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row); + +/*! +* DEPRECATED: use duckdb_value_string_internal instead. This function does not work correctly if the string contains +null bytes. +* returns: The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. +If the column is NOT a VARCHAR column this function will return NULL. + +The result must NOT be freed. +*/ +DUCKDB_API char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row); + +/*! +* DEPRECATED: use duckdb_value_string_internal instead. This function does not work correctly if the string contains +null bytes. +* returns: The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. +If the column is NOT a VARCHAR column this function will return NULL. + +The result must NOT be freed. +*/ +DUCKDB_API duckdb_string duckdb_value_string_internal(duckdb_result *result, idx_t col, idx_t row); + +/*! +* returns: The duckdb_blob value at the specified location. Returns a blob with blob.data set to nullptr if the +value cannot be converted. The resulting "blob.data" must be freed with `duckdb_free.` +*/ +DUCKDB_API duckdb_blob duckdb_value_blob(duckdb_result *result, idx_t col, idx_t row); + +/*! + * returns: Returns true if the value at the specified index is NULL, and false otherwise. + */ +DUCKDB_API bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t row); + +//===--------------------------------------------------------------------===// +// Helpers +//===--------------------------------------------------------------------===// +/*! +Allocate `size` bytes of memory using the duckdb internal malloc function. Any memory allocated in this manner +should be freed using `duckdb_free`. + +* size: The number of bytes to allocate. +* returns: A pointer to the allocated memory region. +*/ +DUCKDB_API void *duckdb_malloc(size_t size); + +/*! +Free a value returned from `duckdb_malloc`, `duckdb_value_varchar` or `duckdb_value_blob`. + +* ptr: The memory region to de-allocate. +*/ +DUCKDB_API void duckdb_free(void *ptr); + +/*! +The internal vector size used by DuckDB. +This is the amount of tuples that will fit into a data chunk created by `duckdb_create_data_chunk`. + +* returns: The vector size. +*/ +DUCKDB_API idx_t duckdb_vector_size(); + +/*! +Whether or not the duckdb_string_t value is inlined. +This means that the data of the string does not have a separate allocation. + +*/ +DUCKDB_API bool duckdb_string_is_inlined(duckdb_string_t string); + +//===--------------------------------------------------------------------===// +// Date/Time/Timestamp Helpers +//===--------------------------------------------------------------------===// +/*! +Decompose a `duckdb_date` object into year, month and date (stored as `duckdb_date_struct`). + +* date: The date object, as obtained from a `DUCKDB_TYPE_DATE` column. +* returns: The `duckdb_date_struct` with the decomposed elements. +*/ +DUCKDB_API duckdb_date_struct duckdb_from_date(duckdb_date date); + +/*! +Re-compose a `duckdb_date` from year, month and date (`duckdb_date_struct`). + +* date: The year, month and date stored in a `duckdb_date_struct`. +* returns: The `duckdb_date` element. +*/ +DUCKDB_API duckdb_date duckdb_to_date(duckdb_date_struct date); + +/*! +Decompose a `duckdb_time` object into hour, minute, second and microsecond (stored as `duckdb_time_struct`). + +* time: The time object, as obtained from a `DUCKDB_TYPE_TIME` column. +* returns: The `duckdb_time_struct` with the decomposed elements. +*/ +DUCKDB_API duckdb_time_struct duckdb_from_time(duckdb_time time); + +/*! +Re-compose a `duckdb_time` from hour, minute, second and microsecond (`duckdb_time_struct`). + +* time: The hour, minute, second and microsecond in a `duckdb_time_struct`. +* returns: The `duckdb_time` element. +*/ +DUCKDB_API duckdb_time duckdb_to_time(duckdb_time_struct time); + +/*! +Decompose a `duckdb_timestamp` object into a `duckdb_timestamp_struct`. + +* ts: The ts object, as obtained from a `DUCKDB_TYPE_TIMESTAMP` column. +* returns: The `duckdb_timestamp_struct` with the decomposed elements. +*/ +DUCKDB_API duckdb_timestamp_struct duckdb_from_timestamp(duckdb_timestamp ts); + +/*! +Re-compose a `duckdb_timestamp` from a duckdb_timestamp_struct. + +* ts: The de-composed elements in a `duckdb_timestamp_struct`. +* returns: The `duckdb_timestamp` element. +*/ +DUCKDB_API duckdb_timestamp duckdb_to_timestamp(duckdb_timestamp_struct ts); + +//===--------------------------------------------------------------------===// +// Hugeint Helpers +//===--------------------------------------------------------------------===// +/*! +Converts a duckdb_hugeint object (as obtained from a `DUCKDB_TYPE_HUGEINT` column) into a double. + +* val: The hugeint value. +* returns: The converted `double` element. +*/ +DUCKDB_API double duckdb_hugeint_to_double(duckdb_hugeint val); + +/*! +Converts a double value to a duckdb_hugeint object. + +If the conversion fails because the double value is too big the result will be 0. + +* val: The double value. +* returns: The converted `duckdb_hugeint` element. +*/ +DUCKDB_API duckdb_hugeint duckdb_double_to_hugeint(double val); + +/*! +Converts a double value to a duckdb_decimal object. + +If the conversion fails because the double value is too big, or the width/scale are invalid the result will be 0. + +* val: The double value. +* returns: The converted `duckdb_decimal` element. +*/ +DUCKDB_API duckdb_decimal duckdb_double_to_decimal(double val, uint8_t width, uint8_t scale); + +//===--------------------------------------------------------------------===// +// Decimal Helpers +//===--------------------------------------------------------------------===// +/*! +Converts a duckdb_decimal object (as obtained from a `DUCKDB_TYPE_DECIMAL` column) into a double. + +* val: The decimal value. +* returns: The converted `double` element. +*/ +DUCKDB_API double duckdb_decimal_to_double(duckdb_decimal val); + +//===--------------------------------------------------------------------===// +// Prepared Statements +//===--------------------------------------------------------------------===// +// A prepared statement is a parameterized query that allows you to bind parameters to it. +// * This is useful to easily supply parameters to functions and avoid SQL injection attacks. +// * This is useful to speed up queries that you will execute several times with different parameters. +// Because the query will only be parsed, bound, optimized and planned once during the prepare stage, +// rather than once per execution. +// For example: +// SELECT * FROM tbl WHERE id=? +// Or a query with multiple parameters: +// SELECT * FROM tbl WHERE id=$1 OR name=$2 + +/*! +Create a prepared statement object from a query. + +Note that after calling `duckdb_prepare`, the prepared statement should always be destroyed using +`duckdb_destroy_prepare`, even if the prepare fails. + +If the prepare fails, `duckdb_prepare_error` can be called to obtain the reason why the prepare failed. + +* connection: The connection object +* query: The SQL query to prepare +* out_prepared_statement: The resulting prepared statement object +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_prepare(duckdb_connection connection, const char *query, + duckdb_prepared_statement *out_prepared_statement); + +/*! +Closes the prepared statement and de-allocates all memory allocated for the statement. + +* prepared_statement: The prepared statement to destroy. +*/ +DUCKDB_API void duckdb_destroy_prepare(duckdb_prepared_statement *prepared_statement); + +/*! +Returns the error message associated with the given prepared statement. +If the prepared statement has no error message, this returns `nullptr` instead. + +The error message should not be freed. It will be de-allocated when `duckdb_destroy_prepare` is called. + +* prepared_statement: The prepared statement to obtain the error from. +* returns: The error message, or `nullptr` if there is none. +*/ +DUCKDB_API const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement); + +/*! +Returns the number of parameters that can be provided to the given prepared statement. + +Returns 0 if the query was not successfully prepared. + +* prepared_statement: The prepared statement to obtain the number of parameters for. +*/ +DUCKDB_API idx_t duckdb_nparams(duckdb_prepared_statement prepared_statement); + +/*! +Returns the name used to identify the parameter +The returned string should be freed using `duckdb_free`. + +Returns NULL if the index is out of range for the provided prepared statement. + +* prepared_statement: The prepared statement for which to get the parameter name from. +*/ +const char *duckdb_parameter_name(duckdb_prepared_statement prepared_statement, idx_t index); + +/*! +Returns the parameter type for the parameter at the given index. + +Returns `DUCKDB_TYPE_INVALID` if the parameter index is out of range or the statement was not successfully prepared. + +* prepared_statement: The prepared statement. +* param_idx: The parameter index. +* returns: The parameter type +*/ +DUCKDB_API duckdb_type duckdb_param_type(duckdb_prepared_statement prepared_statement, idx_t param_idx); + +/*! +Clear the params bind to the prepared statement. +*/ +DUCKDB_API duckdb_state duckdb_clear_bindings(duckdb_prepared_statement prepared_statement); + +/*! +Binds a value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_value val); + +/*! +Retrieve the index of the parameter for the prepared statement, identified by name +*/ +DUCKDB_API duckdb_state duckdb_bind_parameter_index(duckdb_prepared_statement prepared_statement, idx_t *param_idx_out, + const char *name); + +/*! +Binds a bool value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_boolean(duckdb_prepared_statement prepared_statement, idx_t param_idx, bool val); + +/*! +Binds an int8_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_int8(duckdb_prepared_statement prepared_statement, idx_t param_idx, int8_t val); + +/*! +Binds an int16_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_int16(duckdb_prepared_statement prepared_statement, idx_t param_idx, int16_t val); + +/*! +Binds an int32_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_int32(duckdb_prepared_statement prepared_statement, idx_t param_idx, int32_t val); + +/*! +Binds an int64_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_int64(duckdb_prepared_statement prepared_statement, idx_t param_idx, int64_t val); + +/*! +Binds an duckdb_hugeint value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_hugeint(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_hugeint val); +/*! +Binds a duckdb_decimal value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_decimal(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_decimal val); + +/*! +Binds an uint8_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_uint8(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint8_t val); + +/*! +Binds an uint16_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_uint16(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint16_t val); + +/*! +Binds an uint32_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_uint32(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint32_t val); + +/*! +Binds an uint64_t value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_uint64(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint64_t val); + +/*! +Binds an float value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_float(duckdb_prepared_statement prepared_statement, idx_t param_idx, float val); + +/*! +Binds an double value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_double(duckdb_prepared_statement prepared_statement, idx_t param_idx, double val); + +/*! +Binds a duckdb_date value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_date(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_date val); + +/*! +Binds a duckdb_time value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_time(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_time val); + +/*! +Binds a duckdb_timestamp value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_timestamp(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_timestamp val); + +/*! +Binds a duckdb_interval value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_interval(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_interval val); + +/*! +Binds a null-terminated varchar value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_varchar(duckdb_prepared_statement prepared_statement, idx_t param_idx, + const char *val); + +/*! +Binds a varchar value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_varchar_length(duckdb_prepared_statement prepared_statement, idx_t param_idx, + const char *val, idx_t length); + +/*! +Binds a blob value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_blob(duckdb_prepared_statement prepared_statement, idx_t param_idx, + const void *data, idx_t length); + +/*! +Binds a NULL value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_null(duckdb_prepared_statement prepared_statement, idx_t param_idx); + +/*! +Executes the prepared statement with the given bound parameters, and returns a materialized query result. + +This method can be called multiple times for each prepared statement, and the parameters can be modified +between calls to this function. + +* prepared_statement: The prepared statement to execute. +* out_result: The query result. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_execute_prepared(duckdb_prepared_statement prepared_statement, + duckdb_result *out_result); + +/*! +Executes the prepared statement with the given bound parameters, and returns an arrow query result. + +* prepared_statement: The prepared statement to execute. +* out_result: The query result. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_execute_prepared_arrow(duckdb_prepared_statement prepared_statement, + duckdb_arrow *out_result); + +/*! +Scans the Arrow stream and creates a view with the given name. + +* connection: The connection on which to execute the scan. +* table_name: Name of the temporary view to create. +* arrow: Arrow stream wrapper. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_arrow_scan(duckdb_connection connection, const char *table_name, + duckdb_arrow_stream arrow); + +/*! +Scans the Arrow array and creates a view with the given name. + +* connection: The connection on which to execute the scan. +* table_name: Name of the temporary view to create. +* arrow_schema: Arrow schema wrapper. +* arrow_array: Arrow array wrapper. +* out_stream: Output array stream that wraps around the passed schema, for releasing/deleting once done. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_arrow_array_scan(duckdb_connection connection, const char *table_name, + duckdb_arrow_schema arrow_schema, duckdb_arrow_array arrow_array, + duckdb_arrow_stream *out_stream); + +//===--------------------------------------------------------------------===// +// Extract Statements +//===--------------------------------------------------------------------===// +// A query string can be extracted into multiple SQL statements. Each statement can be prepared and executed separately. + +/*! +Extract all statements from a query. +Note that after calling `duckdb_extract_statements`, the extracted statements should always be destroyed using +`duckdb_destroy_extracted`, even if no statements were extracted. +If the extract fails, `duckdb_extract_statements_error` can be called to obtain the reason why the extract failed. +* connection: The connection object +* query: The SQL query to extract +* out_extracted_statements: The resulting extracted statements object +* returns: The number of extracted statements or 0 on failure. +*/ +DUCKDB_API idx_t duckdb_extract_statements(duckdb_connection connection, const char *query, + duckdb_extracted_statements *out_extracted_statements); + +/*! +Prepare an extracted statement. +Note that after calling `duckdb_prepare_extracted_statement`, the prepared statement should always be destroyed using +`duckdb_destroy_prepare`, even if the prepare fails. +If the prepare fails, `duckdb_prepare_error` can be called to obtain the reason why the prepare failed. +* connection: The connection object +* extracted_statements: The extracted statements object +* index: The index of the extracted statement to prepare +* out_prepared_statement: The resulting prepared statement object +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_prepare_extracted_statement(duckdb_connection connection, + duckdb_extracted_statements extracted_statements, + idx_t index, + duckdb_prepared_statement *out_prepared_statement); +/*! +Returns the error message contained within the extracted statements. +The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_extracted` is called. +* result: The extracted statements to fetch the error from. +* returns: The error of the extracted statements. +*/ +DUCKDB_API const char *duckdb_extract_statements_error(duckdb_extracted_statements extracted_statements); + +/*! +De-allocates all memory allocated for the extracted statements. +* extracted_statements: The extracted statements to destroy. +*/ +DUCKDB_API void duckdb_destroy_extracted(duckdb_extracted_statements *extracted_statements); + +//===--------------------------------------------------------------------===// +// Pending Result Interface +//===--------------------------------------------------------------------===// +/*! +Executes the prepared statement with the given bound parameters, and returns a pending result. +The pending result represents an intermediate structure for a query that is not yet fully executed. +The pending result can be used to incrementally execute a query, returning control to the client between tasks. + +Note that after calling `duckdb_pending_prepared`, the pending result should always be destroyed using +`duckdb_destroy_pending`, even if this function returns DuckDBError. + +* prepared_statement: The prepared statement to execute. +* out_result: The pending query result. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_pending_prepared(duckdb_prepared_statement prepared_statement, + duckdb_pending_result *out_result); + +/*! +Executes the prepared statement with the given bound parameters, and returns a pending result. +This pending result will create a streaming duckdb_result when executed. +The pending result represents an intermediate structure for a query that is not yet fully executed. + +Note that after calling `duckdb_pending_prepared_streaming`, the pending result should always be destroyed using +`duckdb_destroy_pending`, even if this function returns DuckDBError. + +* prepared_statement: The prepared statement to execute. +* out_result: The pending query result. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_pending_prepared_streaming(duckdb_prepared_statement prepared_statement, + duckdb_pending_result *out_result); + +/*! +Closes the pending result and de-allocates all memory allocated for the result. + +* pending_result: The pending result to destroy. +*/ +DUCKDB_API void duckdb_destroy_pending(duckdb_pending_result *pending_result); + +/*! +Returns the error message contained within the pending result. + +The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_pending` is called. + +* result: The pending result to fetch the error from. +* returns: The error of the pending result. +*/ +DUCKDB_API const char *duckdb_pending_error(duckdb_pending_result pending_result); + +/*! +Executes a single task within the query, returning whether or not the query is ready. + +If this returns DUCKDB_PENDING_RESULT_READY, the duckdb_execute_pending function can be called to obtain the result. +If this returns DUCKDB_PENDING_RESULT_NOT_READY, the duckdb_pending_execute_task function should be called again. +If this returns DUCKDB_PENDING_ERROR, an error occurred during execution. + +The error message can be obtained by calling duckdb_pending_error on the pending_result. + +* pending_result: The pending result to execute a task within.. +* returns: The state of the pending result after the execution. +*/ +DUCKDB_API duckdb_pending_state duckdb_pending_execute_task(duckdb_pending_result pending_result); + +/*! +Fully execute a pending query result, returning the final query result. + +If duckdb_pending_execute_task has been called until DUCKDB_PENDING_RESULT_READY was returned, this will return fast. +Otherwise, all remaining tasks must be executed first. + +* pending_result: The pending result to execute. +* out_result: The result object. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_execute_pending(duckdb_pending_result pending_result, duckdb_result *out_result); + +/*! +Returns whether a duckdb_pending_state is finished executing. For example if `pending_state` is +DUCKDB_PENDING_RESULT_READY, this function will return true. + +* pending_state: The pending state on which to decide whether to finish execution. +* returns: Boolean indicating pending execution should be considered finished. +*/ +DUCKDB_API bool duckdb_pending_execution_is_finished(duckdb_pending_state pending_state); + +//===--------------------------------------------------------------------===// +// Value Interface +//===--------------------------------------------------------------------===// +/*! +Destroys the value and de-allocates all memory allocated for that type. + +* value: The value to destroy. +*/ +DUCKDB_API void duckdb_destroy_value(duckdb_value *value); + +/*! +Creates a value from a null-terminated string + +* value: The null-terminated string +* returns: The value. This must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_API duckdb_value duckdb_create_varchar(const char *text); + +/*! +Creates a value from a string + +* value: The text +* length: The length of the text +* returns: The value. This must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_API duckdb_value duckdb_create_varchar_length(const char *text, idx_t length); + +/*! +Creates a value from an int64 + +* value: The bigint value +* returns: The value. This must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_API duckdb_value duckdb_create_int64(int64_t val); + +/*! +Obtains a string representation of the given value. +The result must be destroyed with `duckdb_free`. + +* value: The value +* returns: The string value. This must be destroyed with `duckdb_free`. +*/ +DUCKDB_API char *duckdb_get_varchar(duckdb_value value); + +/*! +Obtains an int64 of the given value. + +* value: The value +* returns: The int64 value, or 0 if no conversion is possible +*/ +DUCKDB_API int64_t duckdb_get_int64(duckdb_value value); + +//===--------------------------------------------------------------------===// +// Logical Type Interface +//===--------------------------------------------------------------------===// + +/*! +Creates a `duckdb_logical_type` from a standard primitive type. +The resulting type should be destroyed with `duckdb_destroy_logical_type`. + +This should not be used with `DUCKDB_TYPE_DECIMAL`. + +* type: The primitive type to create. +* returns: The logical type. +*/ +DUCKDB_API duckdb_logical_type duckdb_create_logical_type(duckdb_type type); + +/*! +Creates a list type from its child type. +The resulting type should be destroyed with `duckdb_destroy_logical_type`. + +* type: The child type of list type to create. +* returns: The logical type. +*/ +DUCKDB_API duckdb_logical_type duckdb_create_list_type(duckdb_logical_type type); + +/*! +Creates a map type from its key type and value type. +The resulting type should be destroyed with `duckdb_destroy_logical_type`. + +* type: The key type and value type of map type to create. +* returns: The logical type. +*/ +DUCKDB_API duckdb_logical_type duckdb_create_map_type(duckdb_logical_type key_type, duckdb_logical_type value_type); + +/*! +Creates a UNION type from the passed types array +The resulting type should be destroyed with `duckdb_destroy_logical_type`. + +* types: The array of types that the union should consist of. +* type_amount: The size of the types array. +* returns: The logical type. +*/ +DUCKDB_API duckdb_logical_type duckdb_create_union_type(duckdb_logical_type member_types, const char **member_names, + idx_t member_count); + +/*! +Creates a STRUCT type from the passed member name and type arrays. +The resulting type should be destroyed with `duckdb_destroy_logical_type`. + +* member_types: The array of types that the struct should consist of. +* member_names: The array of names that the struct should consist of. +* member_count: The number of members that were specified for both arrays. +* returns: The logical type. +*/ +DUCKDB_API duckdb_logical_type duckdb_create_struct_type(duckdb_logical_type *member_types, const char **member_names, + idx_t member_count); + +/*! +Creates a `duckdb_logical_type` of type decimal with the specified width and scale +The resulting type should be destroyed with `duckdb_destroy_logical_type`. + +* width: The width of the decimal type +* scale: The scale of the decimal type +* returns: The logical type. +*/ +DUCKDB_API duckdb_logical_type duckdb_create_decimal_type(uint8_t width, uint8_t scale); + +/*! +Retrieves the type class of a `duckdb_logical_type`. + +* type: The logical type object +* returns: The type id +*/ +DUCKDB_API duckdb_type duckdb_get_type_id(duckdb_logical_type type); + +/*! +Retrieves the width of a decimal type. + +* type: The logical type object +* returns: The width of the decimal type +*/ +DUCKDB_API uint8_t duckdb_decimal_width(duckdb_logical_type type); + +/*! +Retrieves the scale of a decimal type. + +* type: The logical type object +* returns: The scale of the decimal type +*/ +DUCKDB_API uint8_t duckdb_decimal_scale(duckdb_logical_type type); + +/*! +Retrieves the internal storage type of a decimal type. + +* type: The logical type object +* returns: The internal type of the decimal type +*/ +DUCKDB_API duckdb_type duckdb_decimal_internal_type(duckdb_logical_type type); + +/*! +Retrieves the internal storage type of an enum type. + +* type: The logical type object +* returns: The internal type of the enum type +*/ +DUCKDB_API duckdb_type duckdb_enum_internal_type(duckdb_logical_type type); + +/*! +Retrieves the dictionary size of the enum type + +* type: The logical type object +* returns: The dictionary size of the enum type +*/ +DUCKDB_API uint32_t duckdb_enum_dictionary_size(duckdb_logical_type type); + +/*! +Retrieves the dictionary value at the specified position from the enum. + +The result must be freed with `duckdb_free` + +* type: The logical type object +* index: The index in the dictionary +* returns: The string value of the enum type. Must be freed with `duckdb_free`. +*/ +DUCKDB_API char *duckdb_enum_dictionary_value(duckdb_logical_type type, idx_t index); + +/*! +Retrieves the child type of the given list type. + +The result must be freed with `duckdb_destroy_logical_type` + +* type: The logical type object +* returns: The child type of the list type. Must be destroyed with `duckdb_destroy_logical_type`. +*/ +DUCKDB_API duckdb_logical_type duckdb_list_type_child_type(duckdb_logical_type type); + +/*! +Retrieves the key type of the given map type. + +The result must be freed with `duckdb_destroy_logical_type` + +* type: The logical type object +* returns: The key type of the map type. Must be destroyed with `duckdb_destroy_logical_type`. +*/ +DUCKDB_API duckdb_logical_type duckdb_map_type_key_type(duckdb_logical_type type); + +/*! +Retrieves the value type of the given map type. + +The result must be freed with `duckdb_destroy_logical_type` + +* type: The logical type object +* returns: The value type of the map type. Must be destroyed with `duckdb_destroy_logical_type`. +*/ +DUCKDB_API duckdb_logical_type duckdb_map_type_value_type(duckdb_logical_type type); + +/*! +Returns the number of children of a struct type. + +* type: The logical type object +* returns: The number of children of a struct type. +*/ +DUCKDB_API idx_t duckdb_struct_type_child_count(duckdb_logical_type type); + +/*! +Retrieves the name of the struct child. + +The result must be freed with `duckdb_free` + +* type: The logical type object +* index: The child index +* returns: The name of the struct type. Must be freed with `duckdb_free`. +*/ +DUCKDB_API char *duckdb_struct_type_child_name(duckdb_logical_type type, idx_t index); + +/*! +Retrieves the child type of the given struct type at the specified index. + +The result must be freed with `duckdb_destroy_logical_type` + +* type: The logical type object +* index: The child index +* returns: The child type of the struct type. Must be destroyed with `duckdb_destroy_logical_type`. +*/ +DUCKDB_API duckdb_logical_type duckdb_struct_type_child_type(duckdb_logical_type type, idx_t index); + +/*! +Returns the number of members that the union type has. + +* type: The logical type (union) object +* returns: The number of members of a union type. +*/ +DUCKDB_API idx_t duckdb_union_type_member_count(duckdb_logical_type type); + +/*! +Retrieves the name of the union member. + +The result must be freed with `duckdb_free` + +* type: The logical type object +* index: The child index +* returns: The name of the union member. Must be freed with `duckdb_free`. +*/ +DUCKDB_API char *duckdb_union_type_member_name(duckdb_logical_type type, idx_t index); + +/*! +Retrieves the child type of the given union member at the specified index. + +The result must be freed with `duckdb_destroy_logical_type` + +* type: The logical type object +* index: The child index +* returns: The child type of the union member. Must be destroyed with `duckdb_destroy_logical_type`. +*/ +DUCKDB_API duckdb_logical_type duckdb_union_type_member_type(duckdb_logical_type type, idx_t index); + +/*! +Destroys the logical type and de-allocates all memory allocated for that type. + +* type: The logical type to destroy. +*/ +DUCKDB_API void duckdb_destroy_logical_type(duckdb_logical_type *type); + +//===--------------------------------------------------------------------===// +// Data Chunk Interface +//===--------------------------------------------------------------------===// +/*! +Creates an empty DataChunk with the specified set of types. + +* types: An array of types of the data chunk. +* column_count: The number of columns. +* returns: The data chunk. +*/ +DUCKDB_API duckdb_data_chunk duckdb_create_data_chunk(duckdb_logical_type *types, idx_t column_count); + +/*! +Destroys the data chunk and de-allocates all memory allocated for that chunk. + +* chunk: The data chunk to destroy. +*/ +DUCKDB_API void duckdb_destroy_data_chunk(duckdb_data_chunk *chunk); + +/*! +Resets a data chunk, clearing the validity masks and setting the cardinality of the data chunk to 0. + +* chunk: The data chunk to reset. +*/ +DUCKDB_API void duckdb_data_chunk_reset(duckdb_data_chunk chunk); + +/*! +Retrieves the number of columns in a data chunk. + +* chunk: The data chunk to get the data from +* returns: The number of columns in the data chunk +*/ +DUCKDB_API idx_t duckdb_data_chunk_get_column_count(duckdb_data_chunk chunk); + +/*! +Retrieves the vector at the specified column index in the data chunk. + +The pointer to the vector is valid for as long as the chunk is alive. +It does NOT need to be destroyed. + +* chunk: The data chunk to get the data from +* returns: The vector +*/ +DUCKDB_API duckdb_vector duckdb_data_chunk_get_vector(duckdb_data_chunk chunk, idx_t col_idx); + +/*! +Retrieves the current number of tuples in a data chunk. + +* chunk: The data chunk to get the data from +* returns: The number of tuples in the data chunk +*/ +DUCKDB_API idx_t duckdb_data_chunk_get_size(duckdb_data_chunk chunk); + +/*! +Sets the current number of tuples in a data chunk. + +* chunk: The data chunk to set the size in +* size: The number of tuples in the data chunk +*/ +DUCKDB_API void duckdb_data_chunk_set_size(duckdb_data_chunk chunk, idx_t size); + +//===--------------------------------------------------------------------===// +// Vector Interface +//===--------------------------------------------------------------------===// +/*! +Retrieves the column type of the specified vector. + +The result must be destroyed with `duckdb_destroy_logical_type`. + +* vector: The vector get the data from +* returns: The type of the vector +*/ +DUCKDB_API duckdb_logical_type duckdb_vector_get_column_type(duckdb_vector vector); + +/*! +Retrieves the data pointer of the vector. + +The data pointer can be used to read or write values from the vector. +How to read or write values depends on the type of the vector. + +* vector: The vector to get the data from +* returns: The data pointer +*/ +DUCKDB_API void *duckdb_vector_get_data(duckdb_vector vector); + +/*! +Retrieves the validity mask pointer of the specified vector. + +If all values are valid, this function MIGHT return NULL! + +The validity mask is a bitset that signifies null-ness within the data chunk. +It is a series of uint64_t values, where each uint64_t value contains validity for 64 tuples. +The bit is set to 1 if the value is valid (i.e. not NULL) or 0 if the value is invalid (i.e. NULL). + +Validity of a specific value can be obtained like this: + +idx_t entry_idx = row_idx / 64; +idx_t idx_in_entry = row_idx % 64; +bool is_valid = validity_mask[entry_idx] & (1 << idx_in_entry); + +Alternatively, the (slower) duckdb_validity_row_is_valid function can be used. + +* vector: The vector to get the data from +* returns: The pointer to the validity mask, or NULL if no validity mask is present +*/ +DUCKDB_API uint64_t *duckdb_vector_get_validity(duckdb_vector vector); + +/*! +Ensures the validity mask is writable by allocating it. + +After this function is called, `duckdb_vector_get_validity` will ALWAYS return non-NULL. +This allows null values to be written to the vector, regardless of whether a validity mask was present before. + +* vector: The vector to alter +*/ +DUCKDB_API void duckdb_vector_ensure_validity_writable(duckdb_vector vector); + +/*! +Assigns a string element in the vector at the specified location. + +* vector: The vector to alter +* index: The row position in the vector to assign the string to +* str: The null-terminated string +*/ +DUCKDB_API void duckdb_vector_assign_string_element(duckdb_vector vector, idx_t index, const char *str); + +/*! +Assigns a string element in the vector at the specified location. + +* vector: The vector to alter +* index: The row position in the vector to assign the string to +* str: The string +* str_len: The length of the string (in bytes) +*/ +DUCKDB_API void duckdb_vector_assign_string_element_len(duckdb_vector vector, idx_t index, const char *str, + idx_t str_len); + +/*! +Retrieves the child vector of a list vector. + +The resulting vector is valid as long as the parent vector is valid. + +* vector: The vector +* returns: The child vector +*/ +DUCKDB_API duckdb_vector duckdb_list_vector_get_child(duckdb_vector vector); + +/*! +Returns the size of the child vector of the list + +* vector: The vector +* returns: The size of the child list +*/ +DUCKDB_API idx_t duckdb_list_vector_get_size(duckdb_vector vector); + +/*! +Sets the total size of the underlying child-vector of a list vector. + +* vector: The list vector. +* size: The size of the child list. +* returns: The duckdb state. Returns DuckDBError if the vector is nullptr. +*/ +DUCKDB_API duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size); + +/*! +Sets the total capacity of the underlying child-vector of a list. + +* vector: The list vector. +* required_capacity: the total capacity to reserve. +* return: The duckdb state. Returns DuckDBError if the vector is nullptr. +*/ +DUCKDB_API duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity); + +/*! +Retrieves the child vector of a struct vector. + +The resulting vector is valid as long as the parent vector is valid. + +* vector: The vector +* index: The child index +* returns: The child vector +*/ +DUCKDB_API duckdb_vector duckdb_struct_vector_get_child(duckdb_vector vector, idx_t index); + +//===--------------------------------------------------------------------===// +// Validity Mask Functions +//===--------------------------------------------------------------------===// +/*! +Returns whether or not a row is valid (i.e. not NULL) in the given validity mask. + +* validity: The validity mask, as obtained through `duckdb_vector_get_validity` +* row: The row index +* returns: true if the row is valid, false otherwise +*/ +DUCKDB_API bool duckdb_validity_row_is_valid(uint64_t *validity, idx_t row); + +/*! +In a validity mask, sets a specific row to either valid or invalid. + +Note that `duckdb_vector_ensure_validity_writable` should be called before calling `duckdb_vector_get_validity`, +to ensure that there is a validity mask to write to. + +* validity: The validity mask, as obtained through `duckdb_vector_get_validity`. +* row: The row index +* valid: Whether or not to set the row to valid, or invalid +*/ +DUCKDB_API void duckdb_validity_set_row_validity(uint64_t *validity, idx_t row, bool valid); + +/*! +In a validity mask, sets a specific row to invalid. + +Equivalent to `duckdb_validity_set_row_validity` with valid set to false. + +* validity: The validity mask +* row: The row index +*/ +DUCKDB_API void duckdb_validity_set_row_invalid(uint64_t *validity, idx_t row); + +/*! +In a validity mask, sets a specific row to valid. + +Equivalent to `duckdb_validity_set_row_validity` with valid set to true. + +* validity: The validity mask +* row: The row index +*/ +DUCKDB_API void duckdb_validity_set_row_valid(uint64_t *validity, idx_t row); + +//===--------------------------------------------------------------------===// +// Table Functions +//===--------------------------------------------------------------------===// +typedef void *duckdb_table_function; +typedef void *duckdb_bind_info; +typedef void *duckdb_init_info; +typedef void *duckdb_function_info; + +typedef void (*duckdb_table_function_bind_t)(duckdb_bind_info info); +typedef void (*duckdb_table_function_init_t)(duckdb_init_info info); +typedef void (*duckdb_table_function_t)(duckdb_function_info info, duckdb_data_chunk output); +typedef void (*duckdb_delete_callback_t)(void *data); + +/*! +Creates a new empty table function. + +The return value should be destroyed with `duckdb_destroy_table_function`. + +* returns: The table function object. +*/ +DUCKDB_API duckdb_table_function duckdb_create_table_function(); + +/*! +Destroys the given table function object. + +* table_function: The table function to destroy +*/ +DUCKDB_API void duckdb_destroy_table_function(duckdb_table_function *table_function); + +/*! +Sets the name of the given table function. + +* table_function: The table function +* name: The name of the table function +*/ +DUCKDB_API void duckdb_table_function_set_name(duckdb_table_function table_function, const char *name); + +/*! +Adds a parameter to the table function. + +* table_function: The table function +* type: The type of the parameter to add. +*/ +DUCKDB_API void duckdb_table_function_add_parameter(duckdb_table_function table_function, duckdb_logical_type type); + +/*! +Adds a named parameter to the table function. + +* table_function: The table function +* name: The name of the parameter +* type: The type of the parameter to add. +*/ +DUCKDB_API void duckdb_table_function_add_named_parameter(duckdb_table_function table_function, const char *name, + duckdb_logical_type type); + +/*! +Assigns extra information to the table function that can be fetched during binding, etc. + +* table_function: The table function +* extra_info: The extra information +* destroy: The callback that will be called to destroy the bind data (if any) +*/ +DUCKDB_API void duckdb_table_function_set_extra_info(duckdb_table_function table_function, void *extra_info, + duckdb_delete_callback_t destroy); + +/*! +Sets the bind function of the table function + +* table_function: The table function +* bind: The bind function +*/ +DUCKDB_API void duckdb_table_function_set_bind(duckdb_table_function table_function, duckdb_table_function_bind_t bind); + +/*! +Sets the init function of the table function + +* table_function: The table function +* init: The init function +*/ +DUCKDB_API void duckdb_table_function_set_init(duckdb_table_function table_function, duckdb_table_function_init_t init); + +/*! +Sets the thread-local init function of the table function + +* table_function: The table function +* init: The init function +*/ +DUCKDB_API void duckdb_table_function_set_local_init(duckdb_table_function table_function, + duckdb_table_function_init_t init); + +/*! +Sets the main function of the table function + +* table_function: The table function +* function: The function +*/ +DUCKDB_API void duckdb_table_function_set_function(duckdb_table_function table_function, + duckdb_table_function_t function); + +/*! +Sets whether or not the given table function supports projection pushdown. + +If this is set to true, the system will provide a list of all required columns in the `init` stage through +the `duckdb_init_get_column_count` and `duckdb_init_get_column_index` functions. +If this is set to false (the default), the system will expect all columns to be projected. + +* table_function: The table function +* pushdown: True if the table function supports projection pushdown, false otherwise. +*/ +DUCKDB_API void duckdb_table_function_supports_projection_pushdown(duckdb_table_function table_function, bool pushdown); + +/*! +Register the table function object within the given connection. + +The function requires at least a name, a bind function, an init function and a main function. + +If the function is incomplete or a function with this name already exists DuckDBError is returned. + +* con: The connection to register it in. +* function: The function pointer +* returns: Whether or not the registration was successful. +*/ +DUCKDB_API duckdb_state duckdb_register_table_function(duckdb_connection con, duckdb_table_function function); + +//===--------------------------------------------------------------------===// +// Table Function Bind +//===--------------------------------------------------------------------===// +/*! +Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info` + +* info: The info object +* returns: The extra info +*/ +DUCKDB_API void *duckdb_bind_get_extra_info(duckdb_bind_info info); + +/*! +Adds a result column to the output of the table function. + +* info: The info object +* name: The name of the column +* type: The logical type of the column +*/ +DUCKDB_API void duckdb_bind_add_result_column(duckdb_bind_info info, const char *name, duckdb_logical_type type); + +/*! +Retrieves the number of regular (non-named) parameters to the function. + +* info: The info object +* returns: The number of parameters +*/ +DUCKDB_API idx_t duckdb_bind_get_parameter_count(duckdb_bind_info info); + +/*! +Retrieves the parameter at the given index. + +The result must be destroyed with `duckdb_destroy_value`. + +* info: The info object +* index: The index of the parameter to get +* returns: The value of the parameter. Must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_API duckdb_value duckdb_bind_get_parameter(duckdb_bind_info info, idx_t index); + +/*! +Retrieves a named parameter with the given name. + +The result must be destroyed with `duckdb_destroy_value`. + +* info: The info object +* name: The name of the parameter +* returns: The value of the parameter. Must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_API duckdb_value duckdb_bind_get_named_parameter(duckdb_bind_info info, const char *name); + +/*! +Sets the user-provided bind data in the bind object. This object can be retrieved again during execution. + +* info: The info object +* extra_data: The bind data object. +* destroy: The callback that will be called to destroy the bind data (if any) +*/ +DUCKDB_API void duckdb_bind_set_bind_data(duckdb_bind_info info, void *bind_data, duckdb_delete_callback_t destroy); + +/*! +Sets the cardinality estimate for the table function, used for optimization. + +* info: The bind data object. +* is_exact: Whether or not the cardinality estimate is exact, or an approximation +*/ +DUCKDB_API void duckdb_bind_set_cardinality(duckdb_bind_info info, idx_t cardinality, bool is_exact); + +/*! +Report that an error has occurred while calling bind. + +* info: The info object +* error: The error message +*/ +DUCKDB_API void duckdb_bind_set_error(duckdb_bind_info info, const char *error); + +//===--------------------------------------------------------------------===// +// Table Function Init +//===--------------------------------------------------------------------===// + +/*! +Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info` + +* info: The info object +* returns: The extra info +*/ +DUCKDB_API void *duckdb_init_get_extra_info(duckdb_init_info info); + +/*! +Gets the bind data set by `duckdb_bind_set_bind_data` during the bind. + +Note that the bind data should be considered as read-only. +For tracking state, use the init data instead. + +* info: The info object +* returns: The bind data object +*/ +DUCKDB_API void *duckdb_init_get_bind_data(duckdb_init_info info); + +/*! +Sets the user-provided init data in the init object. This object can be retrieved again during execution. + +* info: The info object +* extra_data: The init data object. +* destroy: The callback that will be called to destroy the init data (if any) +*/ +DUCKDB_API void duckdb_init_set_init_data(duckdb_init_info info, void *init_data, duckdb_delete_callback_t destroy); + +/*! +Returns the number of projected columns. + +This function must be used if projection pushdown is enabled to figure out which columns to emit. + +* info: The info object +* returns: The number of projected columns. +*/ +DUCKDB_API idx_t duckdb_init_get_column_count(duckdb_init_info info); + +/*! +Returns the column index of the projected column at the specified position. + +This function must be used if projection pushdown is enabled to figure out which columns to emit. + +* info: The info object +* column_index: The index at which to get the projected column index, from 0..duckdb_init_get_column_count(info) +* returns: The column index of the projected column. +*/ +DUCKDB_API idx_t duckdb_init_get_column_index(duckdb_init_info info, idx_t column_index); + +/*! +Sets how many threads can process this table function in parallel (default: 1) + +* info: The info object +* max_threads: The maximum amount of threads that can process this table function +*/ +DUCKDB_API void duckdb_init_set_max_threads(duckdb_init_info info, idx_t max_threads); + +/*! +Report that an error has occurred while calling init. + +* info: The info object +* error: The error message +*/ +DUCKDB_API void duckdb_init_set_error(duckdb_init_info info, const char *error); + +//===--------------------------------------------------------------------===// +// Table Function +//===--------------------------------------------------------------------===// + +/*! +Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info` + +* info: The info object +* returns: The extra info +*/ +DUCKDB_API void *duckdb_function_get_extra_info(duckdb_function_info info); +/*! +Gets the bind data set by `duckdb_bind_set_bind_data` during the bind. + +Note that the bind data should be considered as read-only. +For tracking state, use the init data instead. + +* info: The info object +* returns: The bind data object +*/ +DUCKDB_API void *duckdb_function_get_bind_data(duckdb_function_info info); + +/*! +Gets the init data set by `duckdb_init_set_init_data` during the init. + +* info: The info object +* returns: The init data object +*/ +DUCKDB_API void *duckdb_function_get_init_data(duckdb_function_info info); + +/*! +Gets the thread-local init data set by `duckdb_init_set_init_data` during the local_init. + +* info: The info object +* returns: The init data object +*/ +DUCKDB_API void *duckdb_function_get_local_init_data(duckdb_function_info info); + +/*! +Report that an error has occurred while executing the function. + +* info: The info object +* error: The error message +*/ +DUCKDB_API void duckdb_function_set_error(duckdb_function_info info, const char *error); + +//===--------------------------------------------------------------------===// +// Replacement Scans +//===--------------------------------------------------------------------===// +typedef void *duckdb_replacement_scan_info; + +typedef void (*duckdb_replacement_callback_t)(duckdb_replacement_scan_info info, const char *table_name, void *data); + +/*! +Add a replacement scan definition to the specified database + +* db: The database object to add the replacement scan to +* replacement: The replacement scan callback +* extra_data: Extra data that is passed back into the specified callback +* delete_callback: The delete callback to call on the extra data, if any +*/ +DUCKDB_API void duckdb_add_replacement_scan(duckdb_database db, duckdb_replacement_callback_t replacement, + void *extra_data, duckdb_delete_callback_t delete_callback); + +/*! +Sets the replacement function name to use. If this function is called in the replacement callback, + the replacement scan is performed. If it is not called, the replacement callback is not performed. + +* info: The info object +* function_name: The function name to substitute. +*/ +DUCKDB_API void duckdb_replacement_scan_set_function_name(duckdb_replacement_scan_info info, const char *function_name); + +/*! +Adds a parameter to the replacement scan function. + +* info: The info object +* parameter: The parameter to add. +*/ +DUCKDB_API void duckdb_replacement_scan_add_parameter(duckdb_replacement_scan_info info, duckdb_value parameter); + +/*! +Report that an error has occurred while executing the replacement scan. + +* info: The info object +* error: The error message +*/ +DUCKDB_API void duckdb_replacement_scan_set_error(duckdb_replacement_scan_info info, const char *error); + +//===--------------------------------------------------------------------===// +// Appender +//===--------------------------------------------------------------------===// + +// Appenders are the most efficient way of loading data into DuckDB from within the C interface, and are recommended for +// fast data loading. The appender is much faster than using prepared statements or individual `INSERT INTO` statements. + +// Appends are made in row-wise format. For every column, a `duckdb_append_[type]` call should be made, after which +// the row should be finished by calling `duckdb_appender_end_row`. After all rows have been appended, +// `duckdb_appender_destroy` should be used to finalize the appender and clean up the resulting memory. + +// Note that `duckdb_appender_destroy` should always be called on the resulting appender, even if the function returns +// `DuckDBError`. + +/*! +Creates an appender object. + +* connection: The connection context to create the appender in. +* schema: The schema of the table to append to, or `nullptr` for the default schema. +* table: The table name to append to. +* out_appender: The resulting appender object. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_appender_create(duckdb_connection connection, const char *schema, const char *table, + duckdb_appender *out_appender); + +/*! +Returns the error message associated with the given appender. +If the appender has no error message, this returns `nullptr` instead. + +The error message should not be freed. It will be de-allocated when `duckdb_appender_destroy` is called. + +* appender: The appender to get the error from. +* returns: The error message, or `nullptr` if there is none. +*/ +DUCKDB_API const char *duckdb_appender_error(duckdb_appender appender); + +/*! +Flush the appender to the table, forcing the cache of the appender to be cleared and the data to be appended to the +base table. + +This should generally not be used unless you know what you are doing. Instead, call `duckdb_appender_destroy` when you +are done with the appender. + +* appender: The appender to flush. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_appender_flush(duckdb_appender appender); + +/*! +Close the appender, flushing all intermediate state in the appender to the table and closing it for further appends. + +This is generally not necessary. Call `duckdb_appender_destroy` instead. + +* appender: The appender to flush and close. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_appender_close(duckdb_appender appender); + +/*! +Close the appender and destroy it. Flushing all intermediate state in the appender to the table, and de-allocating +all memory associated with the appender. + +* appender: The appender to flush, close and destroy. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_appender_destroy(duckdb_appender *appender); + +/*! +A nop function, provided for backwards compatibility reasons. Does nothing. Only `duckdb_appender_end_row` is required. +*/ +DUCKDB_API duckdb_state duckdb_appender_begin_row(duckdb_appender appender); + +/*! +Finish the current row of appends. After end_row is called, the next row can be appended. + +* appender: The appender. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_appender_end_row(duckdb_appender appender); + +/*! +Append a bool value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_bool(duckdb_appender appender, bool value); + +/*! +Append an int8_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_int8(duckdb_appender appender, int8_t value); +/*! +Append an int16_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_int16(duckdb_appender appender, int16_t value); +/*! +Append an int32_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_int32(duckdb_appender appender, int32_t value); +/*! +Append an int64_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_int64(duckdb_appender appender, int64_t value); +/*! +Append a duckdb_hugeint value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_hugeint(duckdb_appender appender, duckdb_hugeint value); + +/*! +Append a uint8_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_uint8(duckdb_appender appender, uint8_t value); +/*! +Append a uint16_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_uint16(duckdb_appender appender, uint16_t value); +/*! +Append a uint32_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_uint32(duckdb_appender appender, uint32_t value); +/*! +Append a uint64_t value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_uint64(duckdb_appender appender, uint64_t value); + +/*! +Append a float value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_float(duckdb_appender appender, float value); +/*! +Append a double value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_double(duckdb_appender appender, double value); + +/*! +Append a duckdb_date value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_date(duckdb_appender appender, duckdb_date value); +/*! +Append a duckdb_time value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_time(duckdb_appender appender, duckdb_time value); +/*! +Append a duckdb_timestamp value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_timestamp(duckdb_appender appender, duckdb_timestamp value); +/*! +Append a duckdb_interval value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_interval(duckdb_appender appender, duckdb_interval value); + +/*! +Append a varchar value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_varchar(duckdb_appender appender, const char *val); +/*! +Append a varchar value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_varchar_length(duckdb_appender appender, const char *val, idx_t length); +/*! +Append a blob value to the appender. +*/ +DUCKDB_API duckdb_state duckdb_append_blob(duckdb_appender appender, const void *data, idx_t length); +/*! +Append a NULL value to the appender (of any type). +*/ +DUCKDB_API duckdb_state duckdb_append_null(duckdb_appender appender); + +/*! +Appends a pre-filled data chunk to the specified appender. + +The types of the data chunk must exactly match the types of the table, no casting is performed. +If the types do not match or the appender is in an invalid state, DuckDBError is returned. +If the append is successful, DuckDBSuccess is returned. + +* appender: The appender to append to. +* chunk: The data chunk to append. +* returns: The return state. +*/ +DUCKDB_API duckdb_state duckdb_append_data_chunk(duckdb_appender appender, duckdb_data_chunk chunk); + +//===--------------------------------------------------------------------===// +// Arrow Interface +//===--------------------------------------------------------------------===// +/*! +Executes a SQL query within a connection and stores the full (materialized) result in an arrow structure. +If the query fails to execute, DuckDBError is returned and the error message can be retrieved by calling +`duckdb_query_arrow_error`. + +Note that after running `duckdb_query_arrow`, `duckdb_destroy_arrow` must be called on the result object even if the +query fails, otherwise the error stored within the result will not be freed correctly. + +* connection: The connection to perform the query in. +* query: The SQL query to run. +* out_result: The query result. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_query_arrow(duckdb_connection connection, const char *query, duckdb_arrow *out_result); + +/*! +Fetch the internal arrow schema from the arrow result. + +* result: The result to fetch the schema from. +* out_schema: The output schema. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_query_arrow_schema(duckdb_arrow result, duckdb_arrow_schema *out_schema); + +/*! +Fetch the internal arrow schema from the prepared statement. + +* result: The prepared statement to fetch the schema from. +* out_schema: The output schema. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_prepared_arrow_schema(duckdb_prepared_statement prepared, + duckdb_arrow_schema *out_schema); + +/*! +Fetch an internal arrow array from the arrow result. + +This function can be called multiple time to get next chunks, which will free the previous out_array. +So consume the out_array before calling this function again. + +* result: The result to fetch the array from. +* out_array: The output array. +* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +*/ +DUCKDB_API duckdb_state duckdb_query_arrow_array(duckdb_arrow result, duckdb_arrow_array *out_array); + +/*! +Returns the number of columns present in a the arrow result object. + +* result: The result object. +* returns: The number of columns present in the result object. +*/ +DUCKDB_API idx_t duckdb_arrow_column_count(duckdb_arrow result); + +/*! +Returns the number of rows present in a the arrow result object. + +* result: The result object. +* returns: The number of rows present in the result object. +*/ +DUCKDB_API idx_t duckdb_arrow_row_count(duckdb_arrow result); + +/*! +Returns the number of rows changed by the query stored in the arrow result. This is relevant only for +INSERT/UPDATE/DELETE queries. For other queries the rows_changed will be 0. + +* result: The result object. +* returns: The number of rows changed. +*/ +DUCKDB_API idx_t duckdb_arrow_rows_changed(duckdb_arrow result); + +/*! +Returns the error message contained within the result. The error is only set if `duckdb_query_arrow` returns +`DuckDBError`. + +The error message should not be freed. It will be de-allocated when `duckdb_destroy_arrow` is called. + +* result: The result object to fetch the nullmask from. +* returns: The error of the result. +*/ +DUCKDB_API const char *duckdb_query_arrow_error(duckdb_arrow result); + +/*! +Closes the result and de-allocates all memory allocated for the arrow result. + +* result: The result to destroy. +*/ +DUCKDB_API void duckdb_destroy_arrow(duckdb_arrow *result); + +//===--------------------------------------------------------------------===// +// Threading Information +//===--------------------------------------------------------------------===// +typedef void *duckdb_task_state; + +/*! +Execute DuckDB tasks on this thread. + +Will return after `max_tasks` have been executed, or if there are no more tasks present. + +* database: The database object to execute tasks for +* max_tasks: The maximum amount of tasks to execute +*/ +DUCKDB_API void duckdb_execute_tasks(duckdb_database database, idx_t max_tasks); + +/*! +Creates a task state that can be used with duckdb_execute_tasks_state to execute tasks until + duckdb_finish_execution is called on the state. + +duckdb_destroy_state should be called on the result in order to free memory. + +* database: The database object to create the task state for +* returns: The task state that can be used with duckdb_execute_tasks_state. +*/ +DUCKDB_API duckdb_task_state duckdb_create_task_state(duckdb_database database); + +/*! +Execute DuckDB tasks on this thread. + +The thread will keep on executing tasks forever, until duckdb_finish_execution is called on the state. +Multiple threads can share the same duckdb_task_state. + +* state: The task state of the executor +*/ +DUCKDB_API void duckdb_execute_tasks_state(duckdb_task_state state); + +/*! +Execute DuckDB tasks on this thread. + +The thread will keep on executing tasks until either duckdb_finish_execution is called on the state, +max_tasks tasks have been executed or there are no more tasks to be executed. + +Multiple threads can share the same duckdb_task_state. + +* state: The task state of the executor +* max_tasks: The maximum amount of tasks to execute +* returns: The amount of tasks that have actually been executed +*/ +DUCKDB_API idx_t duckdb_execute_n_tasks_state(duckdb_task_state state, idx_t max_tasks); + +/*! +Finish execution on a specific task. + +* state: The task state to finish execution +*/ +DUCKDB_API void duckdb_finish_execution(duckdb_task_state state); + +/*! +Check if the provided duckdb_task_state has finished execution + +* state: The task state to inspect +* returns: Whether or not duckdb_finish_execution has been called on the task state +*/ +DUCKDB_API bool duckdb_task_state_is_finished(duckdb_task_state state); + +/*! +Destroys the task state returned from duckdb_create_task_state. + +Note that this should not be called while there is an active duckdb_execute_tasks_state running +on the task state. + +* state: The task state to clean up +*/ +DUCKDB_API void duckdb_destroy_task_state(duckdb_task_state state); + +/*! +Returns true if execution of the current query is finished. + +* con: The connection on which to check +*/ +DUCKDB_API bool duckdb_execution_is_finished(duckdb_connection con); + +//===--------------------------------------------------------------------===// +// Streaming Result Interface +//===--------------------------------------------------------------------===// + +/*! +Fetches a data chunk from the (streaming) duckdb_result. This function should be called repeatedly until the result is +exhausted. + +The result must be destroyed with `duckdb_destroy_data_chunk`. + +This function can only be used on duckdb_results created with 'duckdb_pending_prepared_streaming' + +If this function is used, none of the other result functions can be used and vice versa (i.e. this function cannot be +mixed with the legacy result functions or the materialized result functions). + +It is not known beforehand how many chunks will be returned by this result. + +* result: The result object to fetch the data chunk from. +* returns: The resulting data chunk. Returns `NULL` if the result has an error. +*/ +DUCKDB_API duckdb_data_chunk duckdb_stream_fetch_chunk(duckdb_result result); + +#ifdef __cplusplus +} +#endif diff --git a/src/duckdb/src/include/duckdb.hpp b/src/duckdb/src/include/duckdb.hpp new file mode 100644 index 00000000..168dfd00 --- /dev/null +++ b/src/duckdb/src/include/duckdb.hpp @@ -0,0 +1,14 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/connection.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/appender.hpp" diff --git a/src/duckdb/src/include/duckdb/catalog/catalog.hpp b/src/duckdb/src/include/duckdb/catalog/catalog.hpp new file mode 100644 index 00000000..f04f027b --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog.hpp @@ -0,0 +1,361 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/parser/query_error_context.hpp" +#include "duckdb/catalog/catalog_transaction.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" +#include + +namespace duckdb { +struct CreateSchemaInfo; +struct DropInfo; +struct BoundCreateTableInfo; +struct AlterTableInfo; +struct CreateTableFunctionInfo; +struct CreateCopyFunctionInfo; +struct CreatePragmaFunctionInfo; +struct CreateFunctionInfo; +struct CreateViewInfo; +struct CreateSequenceInfo; +struct CreateCollationInfo; +struct CreateIndexInfo; +struct CreateTypeInfo; +struct CreateTableInfo; +struct DatabaseSize; +struct MetadataBlockInfo; + +class AttachedDatabase; +class ClientContext; +class Transaction; + +class AggregateFunctionCatalogEntry; +class CollateCatalogEntry; +class SchemaCatalogEntry; +class TableCatalogEntry; +class ViewCatalogEntry; +class SequenceCatalogEntry; +class TableFunctionCatalogEntry; +class CopyFunctionCatalogEntry; +class PragmaFunctionCatalogEntry; +class CatalogSet; +class DatabaseInstance; +class DependencyManager; + +struct CatalogLookup; +struct CatalogEntryLookup; +struct SimilarCatalogEntry; + +class Binder; +class LogicalOperator; +class PhysicalOperator; +class LogicalCreateIndex; +class LogicalCreateTable; +class LogicalInsert; +class LogicalDelete; +class LogicalUpdate; +class CreateStatement; + +//! The Catalog object represents the catalog of the database. +class Catalog { +public: + explicit Catalog(AttachedDatabase &db); + virtual ~Catalog(); + +public: + //! Get the SystemCatalog from the ClientContext + DUCKDB_API static Catalog &GetSystemCatalog(ClientContext &context); + //! Get the SystemCatalog from the DatabaseInstance + DUCKDB_API static Catalog &GetSystemCatalog(DatabaseInstance &db); + //! Get the specified Catalog from the ClientContext + DUCKDB_API static Catalog &GetCatalog(ClientContext &context, const string &catalog_name); + //! Get the specified Catalog from the DatabaseInstance + DUCKDB_API static Catalog &GetCatalog(DatabaseInstance &db, const string &catalog_name); + //! Gets the specified Catalog from the database if it exists + DUCKDB_API static optional_ptr GetCatalogEntry(ClientContext &context, const string &catalog_name); + //! Get the specific Catalog from the AttachedDatabase + DUCKDB_API static Catalog &GetCatalog(AttachedDatabase &db); + + DUCKDB_API AttachedDatabase &GetAttached(); + DUCKDB_API DatabaseInstance &GetDatabase(); + + virtual bool IsDuckCatalog() { + return false; + } + virtual void Initialize(bool load_builtin) = 0; + + bool IsSystemCatalog() const; + bool IsTemporaryCatalog() const; + + //! Returns the current version of the catalog (incremented whenever anything changes, not stored between restarts) + DUCKDB_API idx_t GetCatalogVersion(); + //! Trigger a modification in the catalog, increasing the catalog version and returning the previous version + DUCKDB_API idx_t ModifyCatalog(); + + //! Returns the catalog name - based on how the catalog was attached + DUCKDB_API const string &GetName(); + DUCKDB_API idx_t GetOid(); + DUCKDB_API virtual string GetCatalogType() = 0; + + DUCKDB_API CatalogTransaction GetCatalogTransaction(ClientContext &context); + + //! Creates a schema in the catalog. + DUCKDB_API virtual optional_ptr CreateSchema(CatalogTransaction transaction, + CreateSchemaInfo &info) = 0; + DUCKDB_API optional_ptr CreateSchema(ClientContext &context, CreateSchemaInfo &info); + //! Creates a table in the catalog. + DUCKDB_API optional_ptr CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info); + DUCKDB_API optional_ptr CreateTable(ClientContext &context, BoundCreateTableInfo &info); + //! Creates a table in the catalog. + DUCKDB_API optional_ptr CreateTable(ClientContext &context, unique_ptr info); + //! Create a table function in the catalog + DUCKDB_API optional_ptr CreateTableFunction(CatalogTransaction transaction, + CreateTableFunctionInfo &info); + DUCKDB_API optional_ptr CreateTableFunction(ClientContext &context, CreateTableFunctionInfo &info); + // Kept for backwards compatibility + DUCKDB_API optional_ptr CreateTableFunction(ClientContext &context, + optional_ptr info); + //! Create a copy function in the catalog + DUCKDB_API optional_ptr CreateCopyFunction(CatalogTransaction transaction, + CreateCopyFunctionInfo &info); + DUCKDB_API optional_ptr CreateCopyFunction(ClientContext &context, CreateCopyFunctionInfo &info); + //! Create a pragma function in the catalog + DUCKDB_API optional_ptr CreatePragmaFunction(CatalogTransaction transaction, + CreatePragmaFunctionInfo &info); + DUCKDB_API optional_ptr CreatePragmaFunction(ClientContext &context, CreatePragmaFunctionInfo &info); + //! Create a scalar or aggregate function in the catalog + DUCKDB_API optional_ptr CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info); + DUCKDB_API optional_ptr CreateFunction(ClientContext &context, CreateFunctionInfo &info); + //! Creates a table in the catalog. + DUCKDB_API optional_ptr CreateView(CatalogTransaction transaction, CreateViewInfo &info); + DUCKDB_API optional_ptr CreateView(ClientContext &context, CreateViewInfo &info); + //! Creates a sequence in the catalog. + DUCKDB_API optional_ptr CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info); + DUCKDB_API optional_ptr CreateSequence(ClientContext &context, CreateSequenceInfo &info); + //! Creates a Enum in the catalog. + DUCKDB_API optional_ptr CreateType(CatalogTransaction transaction, CreateTypeInfo &info); + DUCKDB_API optional_ptr CreateType(ClientContext &context, CreateTypeInfo &info); + //! Creates a collation in the catalog + DUCKDB_API optional_ptr CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info); + DUCKDB_API optional_ptr CreateCollation(ClientContext &context, CreateCollationInfo &info); + //! Creates an index in the catalog + DUCKDB_API optional_ptr CreateIndex(CatalogTransaction transaction, CreateIndexInfo &info); + DUCKDB_API optional_ptr CreateIndex(ClientContext &context, CreateIndexInfo &info); + + //! Creates a table in the catalog. + DUCKDB_API optional_ptr CreateTable(CatalogTransaction transaction, SchemaCatalogEntry &schema, + BoundCreateTableInfo &info); + //! Create a table function in the catalog + DUCKDB_API optional_ptr + CreateTableFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, CreateTableFunctionInfo &info); + //! Create a copy function in the catalog + DUCKDB_API optional_ptr CreateCopyFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateCopyFunctionInfo &info); + //! Create a pragma function in the catalog + DUCKDB_API optional_ptr + CreatePragmaFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, CreatePragmaFunctionInfo &info); + //! Create a scalar or aggregate function in the catalog + DUCKDB_API optional_ptr CreateFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateFunctionInfo &info); + //! Creates a view in the catalog + DUCKDB_API optional_ptr CreateView(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateViewInfo &info); + //! Creates a table in the catalog. + DUCKDB_API optional_ptr CreateSequence(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateSequenceInfo &info); + //! Creates a enum in the catalog. + DUCKDB_API optional_ptr CreateType(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateTypeInfo &info); + //! Creates a collation in the catalog + DUCKDB_API optional_ptr CreateCollation(CatalogTransaction transaction, SchemaCatalogEntry &schema, + CreateCollationInfo &info); + + //! Drops an entry from the catalog + DUCKDB_API void DropEntry(ClientContext &context, DropInfo &info); + + //! Returns the schema object with the specified name, or throws an exception if it does not exist + DUCKDB_API SchemaCatalogEntry &GetSchema(ClientContext &context, const string &name, + QueryErrorContext error_context = QueryErrorContext()); + DUCKDB_API optional_ptr GetSchema(ClientContext &context, const string &name, + OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()); + DUCKDB_API SchemaCatalogEntry &GetSchema(CatalogTransaction transaction, const string &name, + QueryErrorContext error_context = QueryErrorContext()); + DUCKDB_API virtual optional_ptr + GetSchema(CatalogTransaction transaction, const string &schema_name, OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()) = 0; + DUCKDB_API static SchemaCatalogEntry &GetSchema(ClientContext &context, const string &catalog_name, + const string &schema_name, + QueryErrorContext error_context = QueryErrorContext()); + DUCKDB_API static optional_ptr GetSchema(ClientContext &context, const string &catalog_name, + const string &schema_name, + OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()); + //! Scans all the schemas in the system one-by-one, invoking the callback for each entry + DUCKDB_API virtual void ScanSchemas(ClientContext &context, std::function callback) = 0; + //! Gets the "schema.name" entry of the specified type, if entry does not exist behavior depends on OnEntryNotFound + DUCKDB_API optional_ptr GetEntry(ClientContext &context, CatalogType type, const string &schema, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()); + DUCKDB_API CatalogEntry &GetEntry(ClientContext &context, CatalogType type, const string &schema, + const string &name, QueryErrorContext error_context = QueryErrorContext()); + //! Gets the "catalog.schema.name" entry of the specified type, if entry does not exist behavior depends on + //! OnEntryNotFound + DUCKDB_API static optional_ptr GetEntry(ClientContext &context, CatalogType type, + const string &catalog, const string &schema, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()); + DUCKDB_API static CatalogEntry &GetEntry(ClientContext &context, CatalogType type, const string &catalog, + const string &schema, const string &name, + QueryErrorContext error_context = QueryErrorContext()); + + //! Gets the "schema.name" entry without a specified type, if entry does not exist an exception is thrown + DUCKDB_API CatalogEntry &GetEntry(ClientContext &context, const string &schema, const string &name); + + //! Fetches a logical type from the catalog + DUCKDB_API LogicalType GetType(ClientContext &context, const string &schema, const string &names, + OnEntryNotFound if_not_found); + + DUCKDB_API static LogicalType GetType(ClientContext &context, const string &catalog_name, const string &schema, + const string &name); + + template + optional_ptr GetEntry(ClientContext &context, const string &schema_name, const string &name, + OnEntryNotFound if_not_found, QueryErrorContext error_context = QueryErrorContext()) { + auto entry = GetEntry(context, T::Type, schema_name, name, if_not_found, error_context); + if (!entry) { + return nullptr; + } + if (entry->type != T::Type) { + throw CatalogException(error_context.FormatError("%s is not an %s", name, T::Name)); + } + return &entry->template Cast(); + } + template + T &GetEntry(ClientContext &context, const string &schema_name, const string &name, + QueryErrorContext error_context = QueryErrorContext()) { + auto entry = GetEntry(context, schema_name, name, OnEntryNotFound::THROW_EXCEPTION, error_context); + return *entry; + } + + //! Append a scalar or aggregate function to the catalog + DUCKDB_API optional_ptr AddFunction(ClientContext &context, CreateFunctionInfo &info); + + //! Alter an existing entry in the catalog. + DUCKDB_API void Alter(ClientContext &context, AlterInfo &info); + + virtual unique_ptr PlanCreateTableAs(ClientContext &context, LogicalCreateTable &op, + unique_ptr plan) = 0; + virtual unique_ptr PlanInsert(ClientContext &context, LogicalInsert &op, + unique_ptr plan) = 0; + virtual unique_ptr PlanDelete(ClientContext &context, LogicalDelete &op, + unique_ptr plan) = 0; + virtual unique_ptr PlanUpdate(ClientContext &context, LogicalUpdate &op, + unique_ptr plan) = 0; + virtual unique_ptr BindCreateIndex(Binder &binder, CreateStatement &stmt, TableCatalogEntry &table, + unique_ptr plan) = 0; + + virtual DatabaseSize GetDatabaseSize(ClientContext &context) = 0; + virtual vector GetMetadataInfo(ClientContext &context); + + virtual bool InMemory() = 0; + virtual string GetDBPath() = 0; + +public: + template + static optional_ptr GetEntry(ClientContext &context, const string &catalog_name, const string &schema_name, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()) { + auto entry = GetEntry(context, T::Type, catalog_name, schema_name, name, if_not_found, error_context); + if (!entry) { + return nullptr; + } + if (entry->type != T::Type) { + throw CatalogException(error_context.FormatError("%s is not an %s", name, T::Name)); + } + return &entry->template Cast(); + } + template + static T &GetEntry(ClientContext &context, const string &catalog_name, const string &schema_name, + const string &name, QueryErrorContext error_context = QueryErrorContext()) { + auto entry = + GetEntry(context, catalog_name, schema_name, name, OnEntryNotFound::THROW_EXCEPTION, error_context); + return *entry; + } + + DUCKDB_API vector> GetSchemas(ClientContext &context); + DUCKDB_API static vector> GetSchemas(ClientContext &context, + const string &catalog_name); + DUCKDB_API static vector> GetAllSchemas(ClientContext &context); + + virtual void Verify(); + + static CatalogException UnrecognizedConfigurationError(ClientContext &context, const string &name); + + //! Autoload the extension required for `configuration_name` or throw a CatalogException + static void AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name); + //! Autoload the extension required for `function_name` or throw a CatalogException + static bool AutoLoadExtensionByCatalogEntry(ClientContext &context, CatalogType type, const string &entry_name); + DUCKDB_API static bool TryAutoLoad(ClientContext &context, const string &extension_name) noexcept; + +protected: + //! Reference to the database + AttachedDatabase &db; + +private: + //! Lookup an entry in the schema, returning a lookup with the entry and schema if they exist + CatalogEntryLookup TryLookupEntryInternal(CatalogTransaction transaction, CatalogType type, const string &schema, + const string &name); + //! Calls LookupEntryInternal on the schema, trying other schemas if the schema is invalid. Sets + //! CatalogEntryLookup->error depending on if_not_found when no entry is found + CatalogEntryLookup TryLookupEntry(ClientContext &context, CatalogType type, const string &schema, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()); + //! Lookup an entry using TryLookupEntry, throws if entry not found and if_not_found == THROW_EXCEPTION + CatalogEntryLookup LookupEntry(ClientContext &context, CatalogType type, const string &schema, const string &name, + OnEntryNotFound if_not_found, QueryErrorContext error_context = QueryErrorContext()); + static CatalogEntryLookup TryLookupEntry(ClientContext &context, vector &lookups, CatalogType type, + const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()); + static CatalogEntryLookup TryLookupEntry(ClientContext &context, CatalogType type, const string &catalog, + const string &schema, const string &name, OnEntryNotFound if_not_found, + QueryErrorContext error_context); + + //! Return an exception with did-you-mean suggestion. + static CatalogException CreateMissingEntryException(ClientContext &context, const string &entry_name, + CatalogType type, + const reference_set_t &schemas, + QueryErrorContext error_context); + + //! Return the close entry name, the distance and the belonging schema. + static SimilarCatalogEntry SimilarEntryInSchemas(ClientContext &context, const string &entry_name, CatalogType type, + const reference_set_t &schemas); + + virtual void DropSchema(ClientContext &context, DropInfo &info) = 0; + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry.hpp new file mode 100644 index 00000000..452e61b5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry.hpp @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include + +namespace duckdb { +struct AlterInfo; +class Catalog; +class CatalogSet; +class ClientContext; +class SchemaCatalogEntry; +class Serializer; +class Deserializer; + +struct CreateInfo; + +//! Abstract base class of an entry in the catalog +class CatalogEntry { +public: + CatalogEntry(CatalogType type, Catalog &catalog, string name); + CatalogEntry(CatalogType type, string name, idx_t oid); + virtual ~CatalogEntry(); + + //! The oid of the entry + idx_t oid; + //! The type of this catalog entry + CatalogType type; + //! Reference to the catalog set this entry is stored in + optional_ptr set; + //! The name of the entry + string name; + //! Whether or not the object is deleted + bool deleted; + //! Whether or not the object is temporary and should not be added to the WAL + bool temporary; + //! Whether or not the entry is an internal entry (cannot be deleted, not dumped, etc) + bool internal; + //! Timestamp at which the catalog entry was created + atomic timestamp; + //! Child entry + unique_ptr child; + //! Parent entry (the node that dependents_map this node) + optional_ptr parent; + +public: + virtual unique_ptr AlterEntry(ClientContext &context, AlterInfo &info); + virtual void UndoAlter(ClientContext &context, AlterInfo &info); + + virtual unique_ptr Copy(ClientContext &context) const; + + virtual unique_ptr GetInfo() const; + + //! Sets the CatalogEntry as the new root entry (i.e. the newest entry) + // this is called on a rollback to an AlterEntry + virtual void SetAsRoot(); + + //! Convert the catalog entry to a SQL string that can be used to re-construct the catalog entry + virtual string ToSQL() const; + + virtual Catalog &ParentCatalog(); + virtual SchemaCatalogEntry &ParentSchema(); + + virtual void Verify(Catalog &catalog); + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class InCatalogEntry : public CatalogEntry { +public: + InCatalogEntry(CatalogType type, Catalog &catalog, string name); + ~InCatalogEntry() override; + + //! The catalog the entry belongs to + Catalog &catalog; + +public: + Catalog &ParentCatalog() override { + return catalog; + } + + void Verify(Catalog &catalog) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp new file mode 100644 index 00000000..d40dcb97 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/function_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" + +namespace duckdb { + +//! An aggregate function in the catalog +class AggregateFunctionCatalogEntry : public FunctionEntry { +public: + static constexpr const CatalogType Type = CatalogType::AGGREGATE_FUNCTION_ENTRY; + static constexpr const char *Name = "aggregate function"; + +public: + AggregateFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateAggregateFunctionInfo &info) + : FunctionEntry(CatalogType::AGGREGATE_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { + } + + //! The aggregate functions + AggregateFunctionSet functions; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/collate_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/collate_catalog_entry.hpp new file mode 100644 index 00000000..67e61293 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/collate_catalog_entry.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/collate_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/parser/parsed_data/create_collation_info.hpp" + +namespace duckdb { + +//! A collation catalog entry +class CollateCatalogEntry : public StandardEntry { +public: + static constexpr const CatalogType Type = CatalogType::COLLATION_ENTRY; + static constexpr const char *Name = "collation"; + +public: + CollateCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateCollationInfo &info) + : StandardEntry(CatalogType::COLLATION_ENTRY, schema, catalog, info.name), function(info.function), + combinable(info.combinable), not_required_for_equality(info.not_required_for_equality) { + } + + //! The collation function to push in case collation is required + ScalarFunction function; + //! Whether or not the collation can be combined with other collations. + bool combinable; + //! Whether or not the collation is required for equality comparisons or not. For many collations a binary + //! comparison for equality comparisons is correct, allowing us to skip the collation in these cases which greatly + //! speeds up processing. + bool not_required_for_equality; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/column_dependency_manager.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/column_dependency_manager.hpp new file mode 100644 index 00000000..51cc8286 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/column_dependency_manager.hpp @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/column_dependency_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/parser/column_list.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/common/stack.hpp" +#include "duckdb/common/index_map.hpp" + +namespace duckdb { + +//! Dependency Manager local to a table, responsible for keeping track of generated column dependencies + +class ColumnDependencyManager { +public: + DUCKDB_API ColumnDependencyManager(); + DUCKDB_API ~ColumnDependencyManager(); + ColumnDependencyManager(ColumnDependencyManager &&other) = default; + ColumnDependencyManager(const ColumnDependencyManager &other) = delete; + +public: + //! Get the bind order that ensures dependencies are resolved before dependents are + stack GetBindOrder(const ColumnList &columns); + + //! Adds a connection between the dependent and its dependencies + void AddGeneratedColumn(LogicalIndex index, const vector &indices, bool root = true); + //! Add a generated column from a column definition + void AddGeneratedColumn(const ColumnDefinition &column, const ColumnList &list); + + //! Removes the column(s) and outputs the new column indices + vector RemoveColumn(LogicalIndex index, idx_t column_amount); + + bool IsDependencyOf(LogicalIndex dependent, LogicalIndex dependency) const; + bool HasDependencies(LogicalIndex index) const; + const logical_index_set_t &GetDependencies(LogicalIndex index) const; + + bool HasDependents(LogicalIndex index) const; + const logical_index_set_t &GetDependents(LogicalIndex index) const; + +private: + void RemoveStandardColumn(LogicalIndex index); + void RemoveGeneratedColumn(LogicalIndex index); + + void AdjustSingle(LogicalIndex idx, idx_t offset); + // Clean up the gaps created by a Remove operation + vector CleanupInternals(idx_t column_amount); + +private: + //! A map of column dependency to generated column(s) + logical_index_map_t dependencies_map; + //! A map of generated column name to (potentially generated)column dependencies + logical_index_map_t dependents_map; + //! For resolve-order purposes, keep track of the 'direct' (not inherited) dependencies of a generated column + logical_index_map_t direct_dependencies; + logical_index_set_t deleted_columns; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp new file mode 100644 index 00000000..d65c324b --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/function/copy_function.hpp" + +namespace duckdb { + +class Catalog; +struct CreateCopyFunctionInfo; + +//! A table function in the catalog +class CopyFunctionCatalogEntry : public StandardEntry { +public: + static constexpr const CatalogType Type = CatalogType::COPY_FUNCTION_ENTRY; + static constexpr const char *Name = "copy function"; + +public: + CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateCopyFunctionInfo &info); + + //! The copy function + CopyFunction function; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_index_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_index_entry.hpp new file mode 100644 index 00000000..270c0748 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_index_entry.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/duck_index_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" + +namespace duckdb { + +//! An index catalog entry +class DuckIndexEntry : public IndexCatalogEntry { +public: + //! Create an IndexCatalogEntry and initialize storage for it + DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info); + ~DuckIndexEntry(); + + shared_ptr info; + +public: + string GetSchemaName() const override; + string GetTableName() const override; + //! This drops in-memory index data and marks all blocks on disk as free blocks, allowing to reclaim them + void CommitDrop(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_schema_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_schema_entry.hpp new file mode 100644 index 00000000..b1dea57b --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_schema_entry.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/dschema_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" + +namespace duckdb { + +//! A schema in the catalog +class DuckSchemaEntry : public SchemaCatalogEntry { +public: + DuckSchemaEntry(Catalog &catalog, string name, bool is_internal); + +private: + //! The catalog set holding the tables + CatalogSet tables; + //! The catalog set holding the indexes + CatalogSet indexes; + //! The catalog set holding the table functions + CatalogSet table_functions; + //! The catalog set holding the copy functions + CatalogSet copy_functions; + //! The catalog set holding the pragma functions + CatalogSet pragma_functions; + //! The catalog set holding the scalar and aggregate functions + CatalogSet functions; + //! The catalog set holding the sequences + CatalogSet sequences; + //! The catalog set holding the collations + CatalogSet collations; + //! The catalog set holding the types + CatalogSet types; + +public: + optional_ptr AddEntry(CatalogTransaction transaction, unique_ptr entry, + OnCreateConflict on_conflict); + optional_ptr AddEntryInternal(CatalogTransaction transaction, unique_ptr entry, + OnCreateConflict on_conflict, DependencyList dependencies); + + optional_ptr CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) override; + optional_ptr CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) override; + optional_ptr CreateIndex(ClientContext &context, CreateIndexInfo &info, + TableCatalogEntry &table) override; + optional_ptr CreateView(CatalogTransaction transaction, CreateViewInfo &info) override; + optional_ptr CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) override; + optional_ptr CreateTableFunction(CatalogTransaction transaction, + CreateTableFunctionInfo &info) override; + optional_ptr CreateCopyFunction(CatalogTransaction transaction, + CreateCopyFunctionInfo &info) override; + optional_ptr CreatePragmaFunction(CatalogTransaction transaction, + CreatePragmaFunctionInfo &info) override; + optional_ptr CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) override; + optional_ptr CreateType(CatalogTransaction transaction, CreateTypeInfo &info) override; + void Alter(ClientContext &context, AlterInfo &info) override; + void Scan(ClientContext &context, CatalogType type, const std::function &callback) override; + void Scan(CatalogType type, const std::function &callback) override; + void DropEntry(ClientContext &context, DropInfo &info) override; + optional_ptr GetEntry(CatalogTransaction transaction, CatalogType type, const string &name) override; + SimilarCatalogEntry GetSimilarEntry(CatalogTransaction transaction, CatalogType type, const string &name) override; + + void Verify(Catalog &catalog) override; + +private: + //! Get the catalog set for the specified type + CatalogSet &GetCatalogSet(CatalogType type); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp new file mode 100644 index 00000000..b67c9cdb --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/dtable_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { + +//! A table catalog entry +class DuckTableEntry : public TableCatalogEntry { +public: + //! Create a TableCatalogEntry and initialize storage for it + DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, + std::shared_ptr inherited_storage = nullptr); + +public: + unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; + void UndoAlter(ClientContext &context, AlterInfo &info) override; + //! Returns the underlying storage of the table + DataTable &GetStorage() override; + //! Returns a list of the bound constraints of the table + const vector> &GetBoundConstraints() override; + + //! Get statistics of a column (physical or virtual) within the table + unique_ptr GetStatistics(ClientContext &context, column_t column_id) override; + + unique_ptr Copy(ClientContext &context) const override; + + void SetAsRoot() override; + + void CommitAlter(string &column_name); + void CommitDrop(); + + TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) override; + + vector GetColumnSegmentInfo() override; + + TableStorageInfo GetStorageInfo(ClientContext &context) override; + + bool IsDuckTable() const override { + return true; + } + +private: + unique_ptr RenameColumn(ClientContext &context, RenameColumnInfo &info); + unique_ptr AddColumn(ClientContext &context, AddColumnInfo &info); + unique_ptr RemoveColumn(ClientContext &context, RemoveColumnInfo &info); + unique_ptr SetDefault(ClientContext &context, SetDefaultInfo &info); + unique_ptr ChangeColumnType(ClientContext &context, ChangeColumnTypeInfo &info); + unique_ptr SetNotNull(ClientContext &context, SetNotNullInfo &info); + unique_ptr DropNotNull(ClientContext &context, DropNotNullInfo &info); + unique_ptr AddForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info); + unique_ptr DropForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info); + + void UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, const vector &adjusted_indices, + const RemoveColumnInfo &info, CreateTableInfo &create_info, bool is_generated); + +private: + //! A reference to the underlying storage unit used for this table + std::shared_ptr storage; + //! A list of constraints that are part of this table + vector> bound_constraints; + //! Manages dependencies of the individual columns of the table + ColumnDependencyManager column_dependency_manager; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/function_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/function_entry.hpp new file mode 100644 index 00000000..69f4918c --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/function_entry.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/function_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/parser/parsed_data/create_function_info.hpp" + +namespace duckdb { + +//! An aggregate function in the catalog +class FunctionEntry : public StandardEntry { +public: + FunctionEntry(CatalogType type, Catalog &catalog, SchemaCatalogEntry &schema, CreateFunctionInfo &info) + : StandardEntry(type, schema, catalog, info.name) { + description = std::move(info.description); + parameter_names = std::move(info.parameter_names); + example = std::move(info.example); + } + + //! The description (if any) + string description; + //! Parameter names (if any) + vector parameter_names; + //! The example (if any) + string example; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/index_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/index_catalog_entry.hpp new file mode 100644 index 00000000..90979804 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/index_catalog_entry.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/index_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/storage/metadata/metadata_writer.hpp" + +namespace duckdb { + +struct DataTableInfo; +class Index; + +//! An index catalog entry +class IndexCatalogEntry : public StandardEntry { +public: + static constexpr const CatalogType Type = CatalogType::INDEX_ENTRY; + static constexpr const char *Name = "index"; + +public: + //! Create an IndexCatalogEntry and initialize storage for it + IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info); + + optional_ptr index; + string sql; + vector> expressions; + vector> parsed_expressions; + case_insensitive_map_t options; + +public: + unique_ptr GetInfo() const override; + string ToSQL() const override; + + virtual string GetSchemaName() const = 0; + virtual string GetTableName() const = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/list.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/list.hpp new file mode 100644 index 00000000..7f71bf74 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/list.hpp @@ -0,0 +1,12 @@ +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/macro_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/macro_catalog_entry.hpp new file mode 100644 index 00000000..938ffee2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/macro_catalog_entry.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/macro_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/catalog/catalog_entry/function_entry.hpp" +#include "duckdb/function/macro_function.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" + +namespace duckdb { + +//! A macro function in the catalog +class MacroCatalogEntry : public FunctionEntry { +public: + MacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info); + + //! The macro function + unique_ptr function; + +public: + unique_ptr GetInfo() const override; + + string ToSQL() const override { + return function->ToSQL(schema.name, name); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp new file mode 100644 index 00000000..5a6e5731 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/function_entry.hpp" +#include "duckdb/function/pragma_function.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +class Catalog; +struct CreatePragmaFunctionInfo; + +//! A table function in the catalog +class PragmaFunctionCatalogEntry : public FunctionEntry { +public: + static constexpr const CatalogType Type = CatalogType::PRAGMA_FUNCTION_ENTRY; + static constexpr const char *Name = "pragma function"; + +public: + PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreatePragmaFunctionInfo &info); + + //! The pragma functions + PragmaFunctionSet functions; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp new file mode 100644 index 00000000..a4440a86 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/function_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" + +namespace duckdb { + +//! A table function in the catalog +class ScalarFunctionCatalogEntry : public FunctionEntry { +public: + static constexpr const CatalogType Type = CatalogType::SCALAR_FUNCTION_ENTRY; + static constexpr const char *Name = "scalar function"; + +public: + ScalarFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateScalarFunctionInfo &info); + + //! The scalar functions + ScalarFunctionSet functions; + +public: + unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp new file mode 100644 index 00000000..36ec09b7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/macro_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/catalog/catalog_entry/macro_catalog_entry.hpp" + +namespace duckdb { + +//! A macro function in the catalog +class ScalarMacroCatalogEntry : public MacroCatalogEntry { +public: + static constexpr const CatalogType Type = CatalogType::MACRO_ENTRY; + static constexpr const char *Name = "macro function"; + +public: + ScalarMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/schema_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/schema_catalog_entry.hpp new file mode 100644 index 00000000..ead17570 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/schema_catalog_entry.hpp @@ -0,0 +1,98 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/schema_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/parser/query_error_context.hpp" + +namespace duckdb { +class ClientContext; + +class StandardEntry; +class TableCatalogEntry; +class TableFunctionCatalogEntry; +class SequenceCatalogEntry; + +enum class OnCreateConflict : uint8_t; + +struct AlterTableInfo; +struct CreateIndexInfo; +struct CreateFunctionInfo; +struct CreateCollationInfo; +struct CreateViewInfo; +struct BoundCreateTableInfo; +struct CreatePragmaFunctionInfo; +struct CreateSequenceInfo; +struct CreateSchemaInfo; +struct CreateTableFunctionInfo; +struct CreateCopyFunctionInfo; +struct CreateTypeInfo; + +struct DropInfo; + +//! A schema in the catalog +class SchemaCatalogEntry : public InCatalogEntry { +public: + static constexpr const CatalogType Type = CatalogType::SCHEMA_ENTRY; + static constexpr const char *Name = "schema"; + +public: + SchemaCatalogEntry(Catalog &catalog, string name, bool is_internal); + +public: + unique_ptr GetInfo() const override; + + //! Scan the specified catalog set, invoking the callback method for every entry + virtual void Scan(ClientContext &context, CatalogType type, + const std::function &callback) = 0; + //! Scan the specified catalog set, invoking the callback method for every committed entry + virtual void Scan(CatalogType type, const std::function &callback) = 0; + + string ToSQL() const override; + + //! Creates an index with the given name in the schema + virtual optional_ptr CreateIndex(ClientContext &context, CreateIndexInfo &info, + TableCatalogEntry &table) = 0; + //! Create a scalar or aggregate function within the given schema + virtual optional_ptr CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) = 0; + //! Creates a table with the given name in the schema + virtual optional_ptr CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) = 0; + //! Creates a view with the given name in the schema + virtual optional_ptr CreateView(CatalogTransaction transaction, CreateViewInfo &info) = 0; + //! Creates a sequence with the given name in the schema + virtual optional_ptr CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) = 0; + //! Create a table function within the given schema + virtual optional_ptr CreateTableFunction(CatalogTransaction transaction, + CreateTableFunctionInfo &info) = 0; + //! Create a copy function within the given schema + virtual optional_ptr CreateCopyFunction(CatalogTransaction transaction, + CreateCopyFunctionInfo &info) = 0; + //! Create a pragma function within the given schema + virtual optional_ptr CreatePragmaFunction(CatalogTransaction transaction, + CreatePragmaFunctionInfo &info) = 0; + //! Create a collation within the given schema + virtual optional_ptr CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) = 0; + //! Create a enum within the given schema + virtual optional_ptr CreateType(CatalogTransaction transaction, CreateTypeInfo &info) = 0; + + DUCKDB_API virtual optional_ptr GetEntry(CatalogTransaction transaction, CatalogType type, + const string &name) = 0; + DUCKDB_API virtual SimilarCatalogEntry GetSimilarEntry(CatalogTransaction transaction, CatalogType type, + const string &name); + + //! Drops an entry from the schema + virtual void DropEntry(ClientContext &context, DropInfo &info) = 0; + + //! Alters a catalog entry + virtual void Alter(ClientContext &context, AlterInfo &info) = 0; + + CatalogTransaction GetCatalogTransaction(ClientContext &context); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp new file mode 100644 index 00000000..631df988 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" + +namespace duckdb { + +struct SequenceValue { + SequenceValue() : usage_count(0), counter(-1) { + } + SequenceValue(uint64_t usage_count, int64_t counter) : usage_count(usage_count), counter(counter) { + } + + uint64_t usage_count; + int64_t counter; +}; + +//! A sequence catalog entry +class SequenceCatalogEntry : public StandardEntry { +public: + static constexpr const CatalogType Type = CatalogType::SEQUENCE_ENTRY; + static constexpr const char *Name = "sequence"; + +public: + //! Create a real TableCatalogEntry and initialize storage for it + SequenceCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateSequenceInfo &info); + + //! Lock for getting a value on the sequence + mutex lock; + //! The amount of times the sequence has been used + uint64_t usage_count; + //! The sequence counter + int64_t counter; + //! The most recently returned value + int64_t last_value; + //! The increment value + int64_t increment; + //! The minimum value of the sequence + int64_t start_value; + //! The minimum value of the sequence + int64_t min_value; + //! The maximum value of the sequence + int64_t max_value; + //! Whether or not the sequence cycles + bool cycle; + +public: + unique_ptr GetInfo() const override; + + string ToSQL() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp new file mode 100644 index 00000000..207e52e7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -0,0 +1,116 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/table_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/parser/column_list.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/planner/bound_constraint.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/catalog/catalog_entry/table_column_type.hpp" +#include "duckdb/catalog/catalog_entry/column_dependency_manager.hpp" + +namespace duckdb { + +class DataTable; +struct CreateTableInfo; +struct BoundCreateTableInfo; + +struct RenameColumnInfo; +struct AddColumnInfo; +struct RemoveColumnInfo; +struct SetDefaultInfo; +struct ChangeColumnTypeInfo; +struct AlterForeignKeyInfo; +struct SetNotNullInfo; +struct DropNotNullInfo; + +class TableFunction; +struct FunctionData; + +class TableColumnInfo; +struct ColumnSegmentInfo; +class TableStorageInfo; + +class LogicalGet; +class LogicalProjection; +class LogicalUpdate; + +//! A table catalog entry +class TableCatalogEntry : public StandardEntry { +public: + static constexpr const CatalogType Type = CatalogType::TABLE_ENTRY; + static constexpr const char *Name = "table"; + +public: + //! Create a TableCatalogEntry and initialize storage for it + DUCKDB_API TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info); + +public: + DUCKDB_API unique_ptr GetInfo() const override; + + DUCKDB_API bool HasGeneratedColumns() const; + + //! Returns whether or not a column with the given name exists + DUCKDB_API bool ColumnExists(const string &name); + //! Returns a reference to the column of the specified name. Throws an + //! exception if the column does not exist. + DUCKDB_API const ColumnDefinition &GetColumn(const string &name); + //! Returns a reference to the column of the specified logical index. Throws an + //! exception if the column does not exist. + DUCKDB_API const ColumnDefinition &GetColumn(LogicalIndex idx); + //! Returns a list of types of the table, excluding generated columns + DUCKDB_API vector GetTypes(); + //! Returns a list of the columns of the table + DUCKDB_API const ColumnList &GetColumns() const; + //! Returns the underlying storage of the table + virtual DataTable &GetStorage(); + //! Returns a list of the bound constraints of the table + virtual const vector> &GetBoundConstraints(); + + //! Returns a list of the constraints of the table + DUCKDB_API const vector> &GetConstraints(); + DUCKDB_API string ToSQL() const override; + + //! Get statistics of a column (physical or virtual) within the table + virtual unique_ptr GetStatistics(ClientContext &context, column_t column_id) = 0; + + //! Returns the column index of the specified column name. + //! If the column does not exist: + //! If if_column_exists is true, returns DConstants::INVALID_INDEX + //! If if_column_exists is false, throws an exception + DUCKDB_API LogicalIndex GetColumnIndex(string &name, bool if_exists = false); + + //! Returns the scan function that can be used to scan the given table + virtual TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) = 0; + + virtual bool IsDuckTable() const { + return false; + } + + DUCKDB_API static string ColumnsToSQL(const ColumnList &columns, const vector> &constraints); + + //! Returns a list of segment information for this table, if exists + virtual vector GetColumnSegmentInfo(); + + //! Returns the storage info of this table + virtual TableStorageInfo GetStorageInfo(ClientContext &context) = 0; + + virtual void BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, + ClientContext &context); + +protected: + //! A list of columns that are part of this table + ColumnList columns; + //! A list of constraints that are part of this table + vector> constraints; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_column_type.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_column_type.hpp new file mode 100644 index 00000000..4f21cde6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_column_type.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/table_column_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class TableColumnType : uint8_t { STANDARD = 0, GENERATED = 1 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp new file mode 100644 index 00000000..fc784e55 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/function_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/parser/parsed_data/create_table_function_info.hpp" + +namespace duckdb { + +//! A table function in the catalog +class TableFunctionCatalogEntry : public FunctionEntry { +public: + static constexpr const CatalogType Type = CatalogType::TABLE_FUNCTION_ENTRY; + static constexpr const char *Name = "table function"; + +public: + TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableFunctionInfo &info); + + //! The table function + TableFunctionSet functions; + +public: + unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp new file mode 100644 index 00000000..171d14e5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/macro_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/catalog/catalog_entry/macro_catalog_entry.hpp" + +namespace duckdb { + +//! A macro function in the catalog +class TableMacroCatalogEntry : public MacroCatalogEntry { +public: + static constexpr const CatalogType Type = CatalogType::TABLE_MACRO_ENTRY; + static constexpr const char *Name = "table macro function"; + +public: + TableMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/type_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/type_catalog_entry.hpp new file mode 100644 index 00000000..c5f61cfb --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/type_catalog_entry.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/type_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" + +namespace duckdb { + +//! A type catalog entry +class TypeCatalogEntry : public StandardEntry { +public: + static constexpr const CatalogType Type = CatalogType::TYPE_ENTRY; + static constexpr const char *Name = "type"; + +public: + //! Create a TypeCatalogEntry and initialize storage for it + TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info); + + LogicalType user_type; + +public: + unique_ptr GetInfo() const override; + + string ToSQL() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/view_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/view_catalog_entry.hpp new file mode 100644 index 00000000..459ed527 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/view_catalog_entry.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry/view_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class DataTable; +struct CreateViewInfo; + +//! A view catalog entry +class ViewCatalogEntry : public StandardEntry { +public: + static constexpr const CatalogType Type = CatalogType::VIEW_ENTRY; + static constexpr const char *Name = "view"; + +public: + //! Create a real TableCatalogEntry and initialize storage for it + ViewCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateViewInfo &info); + + //! The query of the view + unique_ptr query; + //! The SQL query (if any) + string sql; + //! The set of aliases associated with the view + vector aliases; + //! The returned types of the view + vector types; + +public: + unique_ptr GetInfo() const override; + + unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; + + unique_ptr Copy(ClientContext &context) const override; + + string ToSQL() const override; + +private: + void Initialize(CreateViewInfo &info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry_map.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry_map.hpp new file mode 100644 index 00000000..7557b3b7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry_map.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_entry_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { +class CatalogEntry; + +struct CatalogEntryHashFunction { + uint64_t operator()(const reference &a) const { + std::hash hash_func; + return hash_func((void *)&a.get()); + } +}; + +struct CatalogEntryEquality { + bool operator()(const reference &a, const reference &b) const { + return RefersToSameObject(a, b); + } +}; + +using catalog_entry_set_t = unordered_set, CatalogEntryHashFunction, CatalogEntryEquality>; + +template +using catalog_entry_map_t = unordered_map, T, CatalogEntryHashFunction, CatalogEntryEquality>; + +using catalog_entry_vector_t = vector>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_search_path.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_search_path.hpp new file mode 100644 index 00000000..479d6c52 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_search_path.hpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_search_path.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +class ClientContext; + +struct CatalogSearchEntry { + CatalogSearchEntry(string catalog, string schema); + + string catalog; + string schema; + +public: + string ToString() const; + static string ListToString(const vector &input); + static CatalogSearchEntry Parse(const string &input); + static vector ParseList(const string &input); + +private: + static CatalogSearchEntry ParseInternal(const string &input, idx_t &pos); + static string WriteOptionallyQuoted(const string &input); +}; + +enum class CatalogSetPathType { SET_SCHEMA, SET_SCHEMAS }; + +//! The schema search path, in order by which entries are searched if no schema entry is provided +class CatalogSearchPath { +public: + DUCKDB_API explicit CatalogSearchPath(ClientContext &client_p); + CatalogSearchPath(const CatalogSearchPath &other) = delete; + + DUCKDB_API void Set(CatalogSearchEntry new_value, CatalogSetPathType set_type); + DUCKDB_API void Set(vector new_paths, CatalogSetPathType set_type); + DUCKDB_API void Reset(); + + DUCKDB_API const vector &Get(); + const vector &GetSetPaths() { + return set_paths; + } + DUCKDB_API const CatalogSearchEntry &GetDefault(); + DUCKDB_API string GetDefaultSchema(const string &catalog); + DUCKDB_API string GetDefaultCatalog(const string &schema); + + DUCKDB_API vector GetSchemasForCatalog(const string &catalog); + DUCKDB_API vector GetCatalogsForSchema(const string &schema); + + DUCKDB_API bool SchemaInSearchPath(ClientContext &context, const string &catalog_name, const string &schema_name); + +private: + void SetPaths(vector new_paths); + + string GetSetName(CatalogSetPathType set_type); + +private: + ClientContext &context; + vector paths; + //! Only the paths that were explicitly set (minus the always included paths) + vector set_paths; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp new file mode 100644 index 00000000..dba2d661 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp @@ -0,0 +1,165 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/default/default_generator.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/transaction/transaction.hpp" +#include "duckdb/catalog/catalog_transaction.hpp" +#include "duckdb/catalog/similar_catalog_entry.hpp" +#include +#include + +namespace duckdb { +struct AlterInfo; + +class ClientContext; +class DependencyList; +struct MappingValue; +struct EntryIndex; + +class DuckCatalog; +class TableCatalogEntry; +class SequenceCatalogEntry; + +typedef unordered_map> set_lock_map_t; + +struct EntryValue { + EntryValue() { + throw InternalException("EntryValue called without a catalog entry"); + } + + explicit EntryValue(unique_ptr entry_p) : entry(std::move(entry_p)), reference_count(0) { + } + //! enable move constructors + EntryValue(EntryValue &&other) noexcept { + Swap(other); + } + EntryValue &operator=(EntryValue &&other) noexcept { + Swap(other); + return *this; + } + void Swap(EntryValue &other) { + std::swap(entry, other.entry); + idx_t count = reference_count; + reference_count = other.reference_count.load(); + other.reference_count = count; + } + + unique_ptr entry; + atomic reference_count; +}; + +//! The Catalog Set stores (key, value) map of a set of CatalogEntries +class CatalogSet { + friend class DependencyManager; + friend class EntryDropper; + friend struct EntryIndex; + +public: + DUCKDB_API explicit CatalogSet(Catalog &catalog, unique_ptr defaults = nullptr); + ~CatalogSet(); + + //! Create an entry in the catalog set. Returns whether or not it was + //! successful. + DUCKDB_API bool CreateEntry(CatalogTransaction transaction, const string &name, unique_ptr value, + DependencyList &dependencies); + DUCKDB_API bool CreateEntry(ClientContext &context, const string &name, unique_ptr value, + DependencyList &dependencies); + + DUCKDB_API bool AlterEntry(CatalogTransaction transaction, const string &name, AlterInfo &alter_info); + + DUCKDB_API bool DropEntry(CatalogTransaction transaction, const string &name, bool cascade, + bool allow_drop_internal = false); + DUCKDB_API bool DropEntry(ClientContext &context, const string &name, bool cascade, + bool allow_drop_internal = false); + + DUCKDB_API DuckCatalog &GetCatalog(); + + bool AlterOwnership(CatalogTransaction transaction, ChangeOwnershipInfo &info); + + void CleanupEntry(CatalogEntry &catalog_entry); + + //! Returns the entry with the specified name + DUCKDB_API optional_ptr GetEntry(CatalogTransaction transaction, const string &name); + DUCKDB_API optional_ptr GetEntry(ClientContext &context, const string &name); + + //! Gets the entry that is most similar to the given name (i.e. smallest levenshtein distance), or empty string if + //! none is found. The returned pair consists of the entry name and the distance (smaller means closer). + SimilarCatalogEntry SimilarEntry(CatalogTransaction transaction, const string &name); + + //! Rollback to be the currently valid entry for a certain catalog + //! entry + void Undo(CatalogEntry &entry); + + //! Scan the catalog set, invoking the callback method for every committed entry + DUCKDB_API void Scan(const std::function &callback); + //! Scan the catalog set, invoking the callback method for every entry + DUCKDB_API void Scan(CatalogTransaction transaction, const std::function &callback); + DUCKDB_API void Scan(ClientContext &context, const std::function &callback); + + template + vector> GetEntries(CatalogTransaction transaction) { + vector> result; + Scan(transaction, [&](CatalogEntry &entry) { result.push_back(entry.Cast()); }); + return result; + } + + DUCKDB_API bool HasConflict(CatalogTransaction transaction, transaction_t timestamp); + DUCKDB_API bool UseTimestamp(CatalogTransaction transaction, transaction_t timestamp); + + void UpdateTimestamp(CatalogEntry &entry, transaction_t timestamp); + + void Verify(Catalog &catalog); + +private: + //! Given a root entry, gets the entry valid for this transaction + CatalogEntry &GetEntryForTransaction(CatalogTransaction transaction, CatalogEntry ¤t); + CatalogEntry &GetCommittedEntry(CatalogEntry ¤t); + optional_ptr GetEntryInternal(CatalogTransaction transaction, const string &name, + EntryIndex *entry_index); + optional_ptr GetEntryInternal(CatalogTransaction transaction, EntryIndex &entry_index); + //! Drops an entry from the catalog set; must hold the catalog_lock to safely call this + void DropEntryInternal(CatalogTransaction transaction, EntryIndex entry_index, CatalogEntry &entry, bool cascade); + optional_ptr CreateEntryInternal(CatalogTransaction transaction, unique_ptr entry); + optional_ptr GetMapping(CatalogTransaction transaction, const string &name, bool get_latest = false); + void PutMapping(CatalogTransaction transaction, const string &name, EntryIndex entry_index); + void DeleteMapping(CatalogTransaction transaction, const string &name); + void DropEntryDependencies(CatalogTransaction transaction, EntryIndex &entry_index, CatalogEntry &entry, + bool cascade); + + //! Create all default entries + void CreateDefaultEntries(CatalogTransaction transaction, unique_lock &lock); + //! Attempt to create a default entry with the specified name. Returns the entry if successful, nullptr otherwise. + optional_ptr CreateDefaultEntry(CatalogTransaction transaction, const string &name, + unique_lock &lock); + + EntryIndex PutEntry(idx_t entry_index, unique_ptr entry); + void PutEntry(EntryIndex index, unique_ptr entry); + +private: + DuckCatalog &catalog; + //! The catalog lock is used to make changes to the data + mutex catalog_lock; + //! The set of catalog entries + unordered_map entries; + //! Mapping of string to catalog entry + case_insensitive_map_t> mapping; + //! The current catalog entry index + idx_t current_entry = 0; + //! The generator used to generate default internal entries + unique_ptr defaults; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_transaction.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_transaction.hpp new file mode 100644 index 00000000..47fa68a7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/catalog_transaction.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/catalog_transaction.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class Catalog; +class ClientContext; +class DatabaseInstance; +class Transaction; + +struct CatalogTransaction { + CatalogTransaction(Catalog &catalog, ClientContext &context); + CatalogTransaction(DatabaseInstance &db, transaction_t transaction_id_p, transaction_t start_time_p); + + optional_ptr db; + optional_ptr context; + optional_ptr transaction; + transaction_t transaction_id; + transaction_t start_time; + + ClientContext &GetContext(); + + static CatalogTransaction GetSystemTransaction(DatabaseInstance &db); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp new file mode 100644 index 00000000..817aefbc --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/default/builtin_types/types.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is generated by scripts/generate_builtin_types.py + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/array.hpp" + +namespace duckdb { + +struct DefaultType { + const char *name; + LogicalTypeId type; +}; + +using builtin_type_array = std::array; + +static constexpr const builtin_type_array BUILTIN_TYPES{{ + {"decimal", LogicalTypeId::DECIMAL}, + {"dec", LogicalTypeId::DECIMAL}, + {"numeric", LogicalTypeId::DECIMAL}, + {"time", LogicalTypeId::TIME}, + {"date", LogicalTypeId::DATE}, + {"timestamp", LogicalTypeId::TIMESTAMP}, + {"datetime", LogicalTypeId::TIMESTAMP}, + {"timestamp_us", LogicalTypeId::TIMESTAMP}, + {"timestamp_ms", LogicalTypeId::TIMESTAMP_MS}, + {"timestamp_ns", LogicalTypeId::TIMESTAMP_NS}, + {"timestamp_s", LogicalTypeId::TIMESTAMP_SEC}, + {"timestamptz", LogicalTypeId::TIMESTAMP_TZ}, + {"timetz", LogicalTypeId::TIME_TZ}, + {"interval", LogicalTypeId::INTERVAL}, + {"varchar", LogicalTypeId::VARCHAR}, + {"bpchar", LogicalTypeId::VARCHAR}, + {"string", LogicalTypeId::VARCHAR}, + {"char", LogicalTypeId::VARCHAR}, + {"nvarchar", LogicalTypeId::VARCHAR}, + {"text", LogicalTypeId::VARCHAR}, + {"blob", LogicalTypeId::BLOB}, + {"bytea", LogicalTypeId::BLOB}, + {"varbinary", LogicalTypeId::BLOB}, + {"binary", LogicalTypeId::BLOB}, + {"hugeint", LogicalTypeId::HUGEINT}, + {"int128", LogicalTypeId::HUGEINT}, + {"bigint", LogicalTypeId::BIGINT}, + {"oid", LogicalTypeId::BIGINT}, + {"long", LogicalTypeId::BIGINT}, + {"int8", LogicalTypeId::BIGINT}, + {"int64", LogicalTypeId::BIGINT}, + {"ubigint", LogicalTypeId::UBIGINT}, + {"uint64", LogicalTypeId::UBIGINT}, + {"integer", LogicalTypeId::INTEGER}, + {"int", LogicalTypeId::INTEGER}, + {"int4", LogicalTypeId::INTEGER}, + {"signed", LogicalTypeId::INTEGER}, + {"integral", LogicalTypeId::INTEGER}, + {"int32", LogicalTypeId::INTEGER}, + {"uinteger", LogicalTypeId::UINTEGER}, + {"uint32", LogicalTypeId::UINTEGER}, + {"smallint", LogicalTypeId::SMALLINT}, + {"int2", LogicalTypeId::SMALLINT}, + {"short", LogicalTypeId::SMALLINT}, + {"int16", LogicalTypeId::SMALLINT}, + {"usmallint", LogicalTypeId::USMALLINT}, + {"uint16", LogicalTypeId::USMALLINT}, + {"tinyint", LogicalTypeId::TINYINT}, + {"int1", LogicalTypeId::TINYINT}, + {"utinyint", LogicalTypeId::UTINYINT}, + {"uint8", LogicalTypeId::UTINYINT}, + {"struct", LogicalTypeId::STRUCT}, + {"row", LogicalTypeId::STRUCT}, + {"list", LogicalTypeId::LIST}, + {"map", LogicalTypeId::MAP}, + {"union", LogicalTypeId::UNION}, + {"bit", LogicalTypeId::BIT}, + {"bitstring", LogicalTypeId::BIT}, + {"boolean", LogicalTypeId::BOOLEAN}, + {"bool", LogicalTypeId::BOOLEAN}, + {"logical", LogicalTypeId::BOOLEAN}, + {"uuid", LogicalTypeId::UUID}, + {"guid", LogicalTypeId::UUID}, + {"enum", LogicalTypeId::ENUM}, + {"null", LogicalTypeId::SQLNULL}, + {"float", LogicalTypeId::FLOAT}, + {"real", LogicalTypeId::FLOAT}, + {"float4", LogicalTypeId::FLOAT}, + {"double", LogicalTypeId::DOUBLE}, + {"float8", LogicalTypeId::DOUBLE} +}}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_functions.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_functions.hpp new file mode 100644 index 00000000..c35d438f --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/default/default_functions.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/default/default_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/default/default_generator.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" + +namespace duckdb { +class SchemaCatalogEntry; + +struct DefaultMacro { + const char *schema; + const char *name; + const char *parameters[8]; + const char *macro; +}; + +class DefaultFunctionGenerator : public DefaultGenerator { +public: + DefaultFunctionGenerator(Catalog &catalog, SchemaCatalogEntry &schema); + + SchemaCatalogEntry &schema; + + DUCKDB_API static unique_ptr CreateInternalMacroInfo(DefaultMacro &default_macro); + DUCKDB_API static unique_ptr CreateInternalTableMacroInfo(DefaultMacro &default_macro); + +public: + unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; + vector GetDefaultEntries() override; + +private: + static unique_ptr CreateInternalTableMacroInfo(DefaultMacro &default_macro, + unique_ptr function); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_generator.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_generator.hpp new file mode 100644 index 00000000..6abc0ecc --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/default/default_generator.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/default/default_generator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/common/atomic.hpp" + +namespace duckdb { +class ClientContext; + +class DefaultGenerator { +public: + explicit DefaultGenerator(Catalog &catalog) : catalog(catalog), created_all_entries(false) { + } + virtual ~DefaultGenerator() { + } + + Catalog &catalog; + atomic created_all_entries; + +public: + //! Creates a default entry with the specified name, or returns nullptr if no such entry can be generated + virtual unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) = 0; + //! Get a list of all default entries in the generator + virtual vector GetDefaultEntries() = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_schemas.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_schemas.hpp new file mode 100644 index 00000000..673425c9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/default/default_schemas.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/default/default_schemas.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/default/default_generator.hpp" + +namespace duckdb { + +class DefaultSchemaGenerator : public DefaultGenerator { +public: + explicit DefaultSchemaGenerator(Catalog &catalog); + +public: + unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; + vector GetDefaultEntries() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_types.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_types.hpp new file mode 100644 index 00000000..83d90982 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/default/default_types.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/default/default_types.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/catalog/default/default_generator.hpp" + +namespace duckdb { +class SchemaCatalogEntry; + +class DefaultTypeGenerator : public DefaultGenerator { +public: + DefaultTypeGenerator(Catalog &catalog, SchemaCatalogEntry &schema); + + SchemaCatalogEntry &schema; + +public: + DUCKDB_API static LogicalTypeId GetDefaultType(const string &name); + + unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; + vector GetDefaultEntries() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_views.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_views.hpp new file mode 100644 index 00000000..6ebd29eb --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/default/default_views.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/default/default_views.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/catalog/default/default_generator.hpp" + +namespace duckdb { +class SchemaCatalogEntry; + +class DefaultViewGenerator : public DefaultGenerator { +public: + DefaultViewGenerator(Catalog &catalog, SchemaCatalogEntry &schema); + + SchemaCatalogEntry &schema; + +public: + unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; + vector GetDefaultEntries() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/dependency.hpp b/src/duckdb/src/include/duckdb/catalog/dependency.hpp new file mode 100644 index 00000000..987dd89b --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/dependency.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/dependency.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { +class CatalogEntry; + +enum class DependencyType { + DEPENDENCY_REGULAR = 0, + DEPENDENCY_AUTOMATIC = 1, + DEPENDENCY_OWNS = 2, + DEPENDENCY_OWNED_BY = 3 +}; + +struct Dependency { + Dependency(CatalogEntry &entry, DependencyType dependency_type = DependencyType::DEPENDENCY_REGULAR) + : // NOLINT: Allow implicit conversion from `CatalogEntry` + entry(entry), dependency_type(dependency_type) { + } + + //! The catalog entry this depends on + reference entry; + //! The type of dependency + DependencyType dependency_type; +}; + +struct DependencyHashFunction { + uint64_t operator()(const Dependency &a) const { + std::hash hash_func; + return hash_func((void *)&a.entry.get()); + } +}; + +struct DependencyEquality { + bool operator()(const Dependency &a, const Dependency &b) const { + return RefersToSameObject(a.entry, b.entry); + } +}; +using dependency_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/dependency_list.hpp b/src/duckdb/src/include/duckdb/catalog/dependency_list.hpp new file mode 100644 index 00000000..cf822097 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/dependency_list.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/dependency_list.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry_map.hpp" + +namespace duckdb { +class Catalog; +class CatalogEntry; + +//! The DependencyList +class DependencyList { + friend class DependencyManager; + +public: + DUCKDB_API void AddDependency(CatalogEntry &entry); + + DUCKDB_API void VerifyDependencies(Catalog &catalog, const string &name); + +private: + catalog_entry_set_t set; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp b/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp new file mode 100644 index 00000000..3ebab885 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/dependency_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/catalog/dependency.hpp" +#include "duckdb/catalog/catalog_entry_map.hpp" +#include "duckdb/catalog/catalog_transaction.hpp" + +#include + +namespace duckdb { +class DuckCatalog; +class ClientContext; +class DependencyList; + +//! The DependencyManager is in charge of managing dependencies between catalog entries +class DependencyManager { + friend class CatalogSet; + +public: + explicit DependencyManager(DuckCatalog &catalog); + + //! Erase the object from the DependencyManager; this should only happen when the object itself is destroyed + void EraseObject(CatalogEntry &object); + + //! Scans all dependencies, returning pairs of (object, dependent) + void Scan(const std::function &callback); + + void AddOwnership(CatalogTransaction transaction, CatalogEntry &owner, CatalogEntry &entry); + +private: + DuckCatalog &catalog; + //! Map of objects that DEPEND on [object], i.e. [object] can only be deleted when all entries in the dependency map + //! are deleted. + catalog_entry_map_t dependents_map; + //! Map of objects that the source object DEPENDS on, i.e. when any of the entries in the vector perform a CASCADE + //! drop then [object] is deleted as well + catalog_entry_map_t dependencies_map; + +private: + void AddObject(CatalogTransaction transaction, CatalogEntry &object, DependencyList &dependencies); + void DropObject(CatalogTransaction transaction, CatalogEntry &object, bool cascade); + void AlterObject(CatalogTransaction transaction, CatalogEntry &old_obj, CatalogEntry &new_obj); + void EraseObjectInternal(CatalogEntry &object); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/duck_catalog.hpp b/src/duckdb/src/include/duckdb/catalog/duck_catalog.hpp new file mode 100644 index 00000000..26cf9b86 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/duck_catalog.hpp @@ -0,0 +1,77 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/dcatalog.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +//! The Catalog object represents the catalog of the database. +class DuckCatalog : public Catalog { +public: + explicit DuckCatalog(AttachedDatabase &db); + ~DuckCatalog(); + +public: + bool IsDuckCatalog() override; + void Initialize(bool load_builtin) override; + string GetCatalogType() override { + return "duckdb"; + } + + DependencyManager &GetDependencyManager() { + return *dependency_manager; + } + mutex &GetWriteLock() { + return write_lock; + } + +public: + DUCKDB_API optional_ptr CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) override; + DUCKDB_API void ScanSchemas(ClientContext &context, std::function callback) override; + DUCKDB_API void ScanSchemas(std::function callback); + + DUCKDB_API optional_ptr + GetSchema(CatalogTransaction transaction, const string &schema_name, OnEntryNotFound if_not_found, + QueryErrorContext error_context = QueryErrorContext()) override; + + DUCKDB_API unique_ptr PlanCreateTableAs(ClientContext &context, LogicalCreateTable &op, + unique_ptr plan) override; + DUCKDB_API unique_ptr PlanInsert(ClientContext &context, LogicalInsert &op, + unique_ptr plan) override; + DUCKDB_API unique_ptr PlanDelete(ClientContext &context, LogicalDelete &op, + unique_ptr plan) override; + DUCKDB_API unique_ptr PlanUpdate(ClientContext &context, LogicalUpdate &op, + unique_ptr plan) override; + DUCKDB_API unique_ptr BindCreateIndex(Binder &binder, CreateStatement &stmt, + TableCatalogEntry &table, + unique_ptr plan) override; + + DatabaseSize GetDatabaseSize(ClientContext &context) override; + vector GetMetadataInfo(ClientContext &context) override; + + DUCKDB_API bool InMemory() override; + DUCKDB_API string GetDBPath() override; + +private: + DUCKDB_API void DropSchema(CatalogTransaction transaction, DropInfo &info); + DUCKDB_API void DropSchema(ClientContext &context, DropInfo &info) override; + optional_ptr CreateSchemaInternal(CatalogTransaction transaction, CreateSchemaInfo &info); + void Verify() override; + +private: + //! The DependencyManager manages dependencies between different catalog objects + unique_ptr dependency_manager; + //! Write lock for the catalog + mutex write_lock; + //! The catalog set holding the schemas + unique_ptr schemas; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/mapping_value.hpp b/src/duckdb/src/include/duckdb/catalog/mapping_value.hpp new file mode 100644 index 00000000..50d16c04 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/mapping_value.hpp @@ -0,0 +1,92 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/mapping_value.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" + +namespace duckdb { +struct AlterInfo; + +class ClientContext; + +struct EntryIndex { + EntryIndex() : catalog(nullptr), index(DConstants::INVALID_INDEX) { + } + EntryIndex(CatalogSet &catalog, idx_t index) : catalog(&catalog), index(index) { + auto entry = catalog.entries.find(index); + if (entry == catalog.entries.end()) { + throw InternalException("EntryIndex - Catalog entry not found in constructor!?"); + } + catalog.entries[index].reference_count++; + } + ~EntryIndex() { + if (!catalog) { + return; + } + auto entry = catalog->entries.find(index); + D_ASSERT(entry != catalog->entries.end()); + auto remaining_ref = --entry->second.reference_count; + if (remaining_ref == 0) { + catalog->entries.erase(index); + } + catalog = nullptr; + } + // disable copy constructors + EntryIndex(const EntryIndex &other) = delete; + EntryIndex &operator=(const EntryIndex &) = delete; + //! enable move constructors + EntryIndex(EntryIndex &&other) noexcept { + catalog = nullptr; + index = DConstants::INVALID_INDEX; + std::swap(catalog, other.catalog); + std::swap(index, other.index); + } + EntryIndex &operator=(EntryIndex &&other) noexcept { + std::swap(catalog, other.catalog); + std::swap(index, other.index); + return *this; + } + + unique_ptr &GetEntry() { + auto entry = catalog->entries.find(index); + if (entry == catalog->entries.end()) { + throw InternalException("EntryIndex - Catalog entry not found!?"); + } + return entry->second.entry; + } + idx_t GetIndex() { + return index; + } + EntryIndex Copy() { + if (catalog) { + return EntryIndex(*catalog, index); + } else { + return EntryIndex(); + } + } + +private: + CatalogSet *catalog; + idx_t index; +}; + +struct MappingValue { + explicit MappingValue(EntryIndex index_p) + : index(std::move(index_p)), timestamp(0), deleted(false), parent(nullptr) { + } + + EntryIndex index; + transaction_t timestamp; + bool deleted; + unique_ptr child; + MappingValue *parent; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/similar_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/similar_catalog_entry.hpp new file mode 100644 index 00000000..ca371153 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/similar_catalog_entry.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/similar_catalog_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class SchemaCatalogEntry; + +//! Return value of SimilarEntryInSchemas +struct SimilarCatalogEntry { + //! The entry name. Empty if absent + string name; + //! The distance to the given name. + idx_t distance = idx_t(-1); + //! The schema of the entry. + optional_ptr schema; + + bool Found() const { + return !name.empty(); + } + + DUCKDB_API string GetQualifiedName(bool qualify_catalog, bool qualify_schema) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/standard_entry.hpp b/src/duckdb/src/include/duckdb/catalog/standard_entry.hpp new file mode 100644 index 00000000..2bbc6992 --- /dev/null +++ b/src/duckdb/src/include/duckdb/catalog/standard_entry.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/catalog/standard_entry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry.hpp" + +namespace duckdb { +class SchemaCatalogEntry; + +//! A StandardEntry is a catalog entry that is a member of a schema +class StandardEntry : public InCatalogEntry { +public: + StandardEntry(CatalogType type, SchemaCatalogEntry &schema, Catalog &catalog, string name) + : InCatalogEntry(type, catalog, name), schema(schema) { + } + ~StandardEntry() override { + } + + //! The schema the entry belongs to + SchemaCatalogEntry &schema; + +public: + SchemaCatalogEntry &ParentSchema() override { + return schema; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/adbc/adbc-init.hpp b/src/duckdb/src/include/duckdb/common/adbc/adbc-init.hpp new file mode 100644 index 00000000..d8302c66 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/adbc/adbc-init.hpp @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#ifndef DUCKDB_ADBC_INIT +#define DUCKDB_ADBC_INIT + +#include "duckdb/common/adbc/adbc.hpp" + +#ifdef __cplusplus +extern "C" { +#endif + +//! We gotta leak the symbols of the init function +duckdb_adbc::AdbcStatusCode duckdb_adbc_init(size_t count, struct duckdb_adbc::AdbcDriver *driver, + struct duckdb_adbc::AdbcError *error); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/duckdb/src/include/duckdb/common/adbc/adbc.h b/src/duckdb/src/include/duckdb/common/adbc/adbc.h new file mode 100644 index 00000000..b5f9faf4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/adbc/adbc.h @@ -0,0 +1,1089 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// \file adbc.h ADBC: Arrow Database connectivity +/// +/// An Arrow-based interface between applications and database +/// drivers. ADBC aims to provide a vendor-independent API for SQL +/// and Substrait-based database access that is targeted at +/// analytics/OLAP use cases. +/// +/// This API is intended to be implemented directly by drivers and +/// used directly by client applications. To assist portability +/// between different vendors, a "driver manager" library is also +/// provided, which implements this same API, but dynamically loads +/// drivers internally and forwards calls appropriately. +/// +/// ADBC uses structs with free functions that operate on those +/// structs to model objects. +/// +/// In general, objects allow serialized access from multiple threads, +/// but not concurrent access. Specific implementations may permit +/// multiple threads. +/// +/// \version 1.0.0 + +#pragma once + +#include +#include +#include "duckdb/common/arrow/arrow.hpp" + +/// \defgroup Arrow C Data Interface +/// Definitions for the C Data Interface/C Stream Interface. +/// +/// See https://arrow.apache.org/docs/format/CDataInterface.html +/// +/// @{ + +//! @cond Doxygen_Suppress +namespace duckdb_adbc { + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ARROW_ADBC +#define ARROW_ADBC + +#ifndef ADBC_EXPORTING +#define ADBC_EXPORTING +#endif + +// Storage class macros for Windows +// Allow overriding/aliasing with application-defined macros +#if !defined(ADBC_EXPORT) +#if defined(_WIN32) +#if defined(ADBC_EXPORTING) +#define ADBC_EXPORT __declspec(dllexport) +#else +#define ADBC_EXPORT __declspec(dllimport) +#endif // defined(ADBC_EXPORTING) +#else +#define ADBC_EXPORT +#endif // defined(_WIN32) +#endif // !defined(ADBC_EXPORT) + +/// \defgroup adbc-error-handling Error Handling +/// ADBC uses integer error codes to signal errors. To provide more +/// detail about errors, functions may also return an AdbcError via an +/// optional out parameter, which can be inspected. If provided, it is +/// the responsibility of the caller to zero-initialize the AdbcError +/// value. +/// +/// @{ + +/// \brief Error codes for operations that may fail. +typedef uint8_t AdbcStatusCode; + +/// \brief No error. +#define ADBC_STATUS_OK 0 +/// \brief An unknown error occurred. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_UNKNOWN 1 +/// \brief The operation is not implemented or supported. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_NOT_IMPLEMENTED 2 +/// \brief A requested resource was not found. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_NOT_FOUND 3 +/// \brief A requested resource already exists. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_ALREADY_EXISTS 4 +/// \brief The arguments are invalid, likely a programming error. +/// +/// For instance, they may be of the wrong format, or out of range. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_INVALID_ARGUMENT 5 +/// \brief The preconditions for the operation are not met, likely a +/// programming error. +/// +/// For instance, the object may be uninitialized, or may have not +/// been fully configured. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_INVALID_STATE 6 +/// \brief Invalid data was processed (not a programming error). +/// +/// For instance, a division by zero may have occurred during query +/// execution. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_INVALID_DATA 7 +/// \brief The database's integrity was affected. +/// +/// For instance, a foreign key check may have failed, or a uniqueness +/// constraint may have been violated. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_INTEGRITY 8 +/// \brief An error internal to the driver or database occurred. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_INTERNAL 9 +/// \brief An I/O error occurred. +/// +/// For instance, a remote service may be unavailable. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_IO 10 +/// \brief The operation was cancelled, not due to a timeout. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_CANCELLED 11 +/// \brief The operation was cancelled due to a timeout. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_TIMEOUT 12 +/// \brief Authentication failed. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_UNAUTHENTICATED 13 +/// \brief The client is not authorized to perform the given operation. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_UNAUTHORIZED 14 + +/// \brief A detailed error message for an operation. +struct ADBC_EXPORT AdbcError { + /// \brief The error message. + char *message; + + /// \brief A vendor-specific error code, if applicable. + int32_t vendor_code; + + /// \brief A SQLSTATE error code, if provided, as defined by the + /// SQL:2003 standard. If not set, it should be set to + /// "\0\0\0\0\0". + char sqlstate[5]; + + /// \brief Release the contained error. + /// + /// Unlike other structures, this is an embedded callback to make it + /// easier for the driver manager and driver to cooperate. + void (*release)(struct AdbcError *error); +}; + +/// @} + +/// \defgroup adbc-constants Constants +/// @{ + +/// \brief ADBC revision 1.0.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver. +#define ADBC_VERSION_1_0_0 1000000 + +/// \brief Canonical option value for enabling an option. +/// +/// For use as the value in SetOption calls. +#define ADBC_OPTION_VALUE_ENABLED "true" +/// \brief Canonical option value for disabling an option. +/// +/// For use as the value in SetOption calls. +#define ADBC_OPTION_VALUE_DISABLED "false" + +/// \brief The database vendor/product name (e.g. the server name). +/// (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_VENDOR_NAME 0 +/// \brief The database vendor/product version (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_VENDOR_VERSION 1 +/// \brief The database vendor/product Arrow library version (type: +/// utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_VENDOR_ARROW_VERSION 2 + +/// \brief The driver name (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_DRIVER_NAME 100 +/// \brief The driver version (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_DRIVER_VERSION 101 +/// \brief The driver Arrow library version (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_DRIVER_ARROW_VERSION 102 + +/// \brief Return metadata on catalogs, schemas, tables, and columns. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_ALL 0 +/// \brief Return metadata on catalogs only. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_CATALOGS 1 +/// \brief Return metadata on catalogs and schemas. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_DB_SCHEMAS 2 +/// \brief Return metadata on catalogs, schemas, and tables. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_TABLES 3 +/// \brief Return metadata on catalogs, schemas, tables, and columns. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL + +/// \brief The name of the canonical option for whether autocommit is +/// enabled. +/// +/// \see AdbcConnectionSetOption +#define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit" + +/// \brief The name of the canonical option for whether the current +/// connection should be restricted to being read-only. +/// +/// \see AdbcConnectionSetOption +#define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly" + +/// \brief The name of the canonical option for setting the isolation +/// level of a transaction. +/// +/// Should only be used in conjunction with autocommit disabled and +/// AdbcConnectionCommit / AdbcConnectionRollback. If the desired +/// isolation level is not supported by a driver, it should return an +/// appropriate error. +/// +/// \see AdbcConnectionSetOption +#define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL "adbc.connection.transaction.isolation_level" + +/// \brief Use database or driver default isolation level +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_DEFAULT "adbc.connection.transaction.isolation.default" + +/// \brief The lowest isolation level. Dirty reads are allowed, so one +/// transaction may see not-yet-committed changes made by others. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_READ_UNCOMMITTED "adbc.connection.transaction.isolation.read_uncommitted" + +/// \brief Lock-based concurrency control keeps write locks until the +/// end of the transaction, but read locks are released as soon as a +/// SELECT is performed. Non-repeatable reads can occur in this +/// isolation level. +/// +/// More simply put, Read Committed is an isolation level that guarantees +/// that any data read is committed at the moment it is read. It simply +/// restricts the reader from seeing any intermediate, uncommitted, +/// 'dirty' reads. It makes no promise whatsoever that if the transaction +/// re-issues the read, it will find the same data; data is free to change +/// after it is read. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_READ_COMMITTED "adbc.connection.transaction.isolation.read_committed" + +/// \brief Lock-based concurrency control keeps read AND write locks +/// (acquired on selection data) until the end of the transaction. +/// +/// However, range-locks are not managed, so phantom reads can occur. +/// Write skew is possible at this isolation level in some systems. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_REPEATABLE_READ "adbc.connection.transaction.isolation.repeatable_read" + +/// \brief This isolation guarantees that all reads in the transaction +/// will see a consistent snapshot of the database and the transaction +/// should only successfully commit if no updates conflict with any +/// concurrent updates made since that snapshot. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_SNAPSHOT "adbc.connection.transaction.isolation.snapshot" + +/// \brief Serializability requires read and write locks to be released +/// only at the end of the transaction. This includes acquiring range- +/// locks when a select query uses a ranged WHERE clause to avoid +/// phantom reads. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_SERIALIZABLE "adbc.connection.transaction.isolation.serializable" + +/// \brief The central distinction between serializability and linearizability +/// is that serializability is a global property; a property of an entire +/// history of operations and transactions. Linearizability is a local +/// property; a property of a single operation/transaction. +/// +/// Linearizability can be viewed as a special case of strict serializability +/// where transactions are restricted to consist of a single operation applied +/// to a single object. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_LINEARIZABLE "adbc.connection.transaction.isolation.linearizable" + +/// \defgroup adbc-statement-ingestion Bulk Data Ingestion +/// While it is possible to insert data via prepared statements, it can +/// be more efficient to explicitly perform a bulk insert. For +/// compatible drivers, this can be accomplished by setting up and +/// executing a statement. Instead of setting a SQL query or Substrait +/// plan, bind the source data via AdbcStatementBind, and set the name +/// of the table to be created via AdbcStatementSetOption and the +/// options below. Then, call AdbcStatementExecute with a NULL for +/// the out parameter (to indicate you do not expect a result set). +/// +/// @{ + +/// \brief The name of the target table for a bulk insert. +/// +/// The driver should attempt to create the table if it does not +/// exist. If the table exists but has a different schema, +/// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be +/// appended to the target table. +#define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table" +/// \brief Whether to create (the default) or append. +#define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode" +/// \brief Create the table and insert data; error if the table exists. +#define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create" +/// \brief Do not create the table, and insert data; error if the +/// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match +/// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). +#define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append" + +/// @} + +/// @} + +/// \defgroup adbc-database Database Initialization +/// Clients first initialize a database, then create a connection +/// (below). This gives the implementation a place to initialize and +/// own any common connection state. For example, in-memory databases +/// can place ownership of the actual database in this object. +/// @{ + +/// \brief An instance of a database. +/// +/// Must be kept alive as long as any connections exist. +struct ADBC_EXPORT AdbcDatabase { + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void *private_data; + /// \brief The associated driver (used by the driver manager to help + /// track state). + struct AdbcDriver *private_driver; +}; + +/// @} + +/// \defgroup adbc-connection Connection Establishment +/// Functions for creating, using, and releasing database connections. +/// @{ + +/// \brief An active database connection. +/// +/// Provides methods for query execution, managing prepared +/// statements, using transactions, and so on. +/// +/// Connections are not required to be thread-safe, but they can be +/// used from multiple threads so long as clients take care to +/// serialize accesses to a connection. +struct ADBC_EXPORT AdbcConnection { + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void *private_data; + /// \brief The associated driver (used by the driver manager to help + /// track state). + struct AdbcDriver *private_driver; +}; + +/// @} + +/// \defgroup adbc-statement Managing Statements +/// Applications should first initialize a statement with +/// AdbcStatementNew. Then, the statement should be configured with +/// functions like AdbcStatementSetSqlQuery and +/// AdbcStatementSetOption. Finally, the statement can be executed +/// with AdbcStatementExecuteQuery (or call AdbcStatementPrepare first +/// to turn it into a prepared statement instead). +/// @{ + +/// \brief A container for all state needed to execute a database +/// query, such as the query itself, parameters for prepared +/// statements, driver parameters, etc. +/// +/// Statements may represent queries or prepared statements. +/// +/// Statements may be used multiple times and can be reconfigured +/// (e.g. they can be reused to execute multiple different queries). +/// However, executing a statement (and changing certain other state) +/// will invalidate result sets obtained prior to that execution. +/// +/// Multiple statements may be created from a single connection. +/// However, the driver may block or error if they are used +/// concurrently (whether from a single thread or multiple threads). +/// +/// Statements are not required to be thread-safe, but they can be +/// used from multiple threads so long as clients take care to +/// serialize accesses to a statement. +struct ADBC_EXPORT AdbcStatement { + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void *private_data; + + /// \brief The associated driver (used by the driver manager to help + /// track state). + struct AdbcDriver *private_driver; +}; + +/// \defgroup adbc-statement-partition Partitioned Results +/// Some backends may internally partition the results. These +/// partitions are exposed to clients who may wish to integrate them +/// with a threaded or distributed execution model, where partitions +/// can be divided among threads or machines and fetched in parallel. +/// +/// To use partitioning, execute the statement with +/// AdbcStatementExecutePartitions to get the partition descriptors. +/// Call AdbcConnectionReadPartition to turn the individual +/// descriptors into ArrowArrayStream instances. This may be done on +/// a different connection than the one the partition was created +/// with, or even in a different process on another machine. +/// +/// Drivers are not required to support partitioning. +/// +/// @{ + +/// \brief The partitions of a distributed/partitioned result set. +struct AdbcPartitions { + /// \brief The number of partitions. + size_t num_partitions; + + /// \brief The partitions of the result set, where each entry (up to + /// num_partitions entries) is an opaque identifier that can be + /// passed to AdbcConnectionReadPartition. + const uint8_t **partitions; + + /// \brief The length of each corresponding entry in partitions. + const size_t *partition_lengths; + + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void *private_data; + + /// \brief Release the contained partitions. + /// + /// Unlike other structures, this is an embedded callback to make it + /// easier for the driver manager and driver to cooperate. + void (*release)(struct AdbcPartitions *partitions); +}; + +/// @} + +/// @} + +/// \defgroup adbc-driver Driver Initialization +/// +/// These functions are intended to help support integration between a +/// driver and the driver manager. +/// @{ + +/// \brief An instance of an initialized database driver. +/// +/// This provides a common interface for vendor-specific driver +/// initialization routines. Drivers should populate this struct, and +/// applications can call ADBC functions through this struct, without +/// worrying about multiple definitions of the same symbol. +struct ADBC_EXPORT AdbcDriver { + /// \brief Opaque driver-defined state. + /// This field is NULL if the driver is unintialized/freed (but + /// it need not have a value even if the driver is initialized). + void *private_data; + /// \brief Opaque driver manager-defined state. + /// This field is NULL if the driver is unintialized/freed (but + /// it need not have a value even if the driver is initialized). + void *private_manager; + + /// \brief Release the driver and perform any cleanup. + /// + /// This is an embedded callback to make it easier for the driver + /// manager and driver to cooperate. + AdbcStatusCode (*release)(struct AdbcDriver *driver, struct AdbcError *error); + + AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase *, struct AdbcError *); + AdbcStatusCode (*DatabaseNew)(struct AdbcDatabase *, struct AdbcError *); + AdbcStatusCode (*DatabaseSetOption)(struct AdbcDatabase *, const char *, const char *, struct AdbcError *); + AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase *, struct AdbcError *); + + AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection *, struct AdbcError *); + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection *, uint32_t *, size_t, struct ArrowArrayStream *, + struct AdbcError *); + AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection *, int, const char *, const char *, const char *, + const char **, const char *, struct ArrowArrayStream *, struct AdbcError *); + AdbcStatusCode (*ConnectionGetTableSchema)(struct AdbcConnection *, const char *, const char *, const char *, + struct ArrowSchema *, struct AdbcError *); + AdbcStatusCode (*ConnectionGetTableTypes)(struct AdbcConnection *, struct ArrowArrayStream *, struct AdbcError *); + AdbcStatusCode (*ConnectionInit)(struct AdbcConnection *, struct AdbcDatabase *, struct AdbcError *); + AdbcStatusCode (*ConnectionNew)(struct AdbcConnection *, struct AdbcError *); + AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection *, const char *, const char *, struct AdbcError *); + AdbcStatusCode (*ConnectionReadPartition)(struct AdbcConnection *, const uint8_t *, size_t, + struct ArrowArrayStream *, struct AdbcError *); + AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection *, struct AdbcError *); + AdbcStatusCode (*ConnectionRollback)(struct AdbcConnection *, struct AdbcError *); + + AdbcStatusCode (*StatementBind)(struct AdbcStatement *, struct ArrowArray *, struct ArrowSchema *, + struct AdbcError *); + AdbcStatusCode (*StatementBindStream)(struct AdbcStatement *, struct ArrowArrayStream *, struct AdbcError *); + AdbcStatusCode (*StatementExecuteQuery)(struct AdbcStatement *, struct ArrowArrayStream *, int64_t *, + struct AdbcError *); + AdbcStatusCode (*StatementExecutePartitions)(struct AdbcStatement *, struct ArrowSchema *, struct AdbcPartitions *, + int64_t *, struct AdbcError *); + AdbcStatusCode (*StatementGetParameterSchema)(struct AdbcStatement *, struct ArrowSchema *, struct AdbcError *); + AdbcStatusCode (*StatementNew)(struct AdbcConnection *, struct AdbcStatement *, struct AdbcError *); + AdbcStatusCode (*StatementPrepare)(struct AdbcStatement *, struct AdbcError *); + AdbcStatusCode (*StatementRelease)(struct AdbcStatement *, struct AdbcError *); + AdbcStatusCode (*StatementSetOption)(struct AdbcStatement *, const char *, const char *, struct AdbcError *); + AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement *, const char *, struct AdbcError *); + AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement *, const uint8_t *, size_t, struct AdbcError *); +}; + +/// @} + +/// \addtogroup adbc-database +/// @{ + +/// \brief Allocate a new (but uninitialized) database. +/// +/// Callers pass in a zero-initialized AdbcDatabase. +/// +/// Drivers should allocate their internal data structure and set the private_data +/// field to point to the newly allocated struct. This struct should be released +/// when AdbcDatabaseRelease is called. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error); + +/// \brief Set a char* option. +/// +/// Options may be set before AdbcDatabaseInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, + struct AdbcError *error); + +/// \brief Finish setting options and initialize the database. +/// +/// Some drivers may support setting options after initialization +/// as well. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error); + +/// \brief Destroy this database. No connections may exist. +/// \param[in] database The database to release. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error); + +/// @} + +/// \addtogroup adbc-connection +/// @{ + +/// \brief Allocate a new (but uninitialized) connection. +/// +/// Callers pass in a zero-initialized AdbcConnection. +/// +/// Drivers should allocate their internal data structure and set the private_data +/// field to point to the newly allocated struct. This struct should be released +/// when AdbcConnectionRelease is called. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error); + +/// \brief Set a char* option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, + struct AdbcError *error); + +/// \brief Finish setting options and initialize the connection. +/// +/// Some drivers may support setting options after initialization +/// as well. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, + struct AdbcError *error); + +/// \brief Destroy this connection. +/// +/// \param[in] connection The connection to release. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error); + +/// \defgroup adbc-connection-metadata Metadata +/// Functions for retrieving metadata about the database. +/// +/// Generally, these functions return an ArrowArrayStream that can be +/// consumed to get the metadata as Arrow data. The returned metadata +/// has an expected schema given in the function docstring. Schema +/// fields are nullable unless otherwise marked. While no +/// AdbcStatement is used in these functions, the result set may count +/// as an active statement to the driver for the purposes of +/// concurrency management (e.g. if the driver has a limit on +/// concurrent active statements and it must execute a SQL query +/// internally in order to implement the metadata function). +/// +/// Some functions accept "search pattern" arguments, which are +/// strings that can contain the special character "%" to match zero +/// or more characters, or "_" to match exactly one character. (See +/// the documentation of DatabaseMetaData in JDBC or "Pattern Value +/// Arguments" in the ODBC documentation.) Escaping is not currently +/// supported. +/// +/// @{ + +/// \brief Get metadata about the database/driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ----------------------------|------------------------ +/// info_name | uint32 not null +/// info_value | INFO_SCHEMA +/// +/// INFO_SCHEMA is a dense union with members: +/// +/// Field Name (Type Code) | Field Type +/// ----------------------------|------------------------ +/// string_value (0) | utf8 +/// bool_value (1) | bool +/// int64_value (2) | int64 +/// int32_bitmask (3) | int32 +/// string_list (4) | list +/// int32_to_int32_list_map (5) | map> +/// +/// Each metadatum is identified by an integer code. The recognized +/// codes are defined as constants. Codes [0, 10_000) are reserved +/// for ADBC usage. Drivers/vendors will ignore requests for +/// unrecognized codes (the row will be omitted from the result). +/// +/// \param[in] connection The connection to query. +/// \param[in] info_codes A list of metadata codes to fetch, or NULL +/// to fetch all. +/// \param[in] info_codes_length The length of the info_codes +/// parameter. Ignored if info_codes is NULL. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, + struct ArrowArrayStream *out, struct AdbcError *error); + +/// \brief Get a hierarchical view of all catalogs, database schemas, +/// tables, and columns. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | catalog_name | utf8 | +/// | catalog_db_schemas | list | +/// +/// DB_SCHEMA_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | db_schema_name | utf8 | +/// | db_schema_tables | list | +/// +/// TABLE_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | table_name | utf8 not null | +/// | table_type | utf8 not null | +/// | table_columns | list | +/// | table_constraints | list | +/// +/// COLUMN_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|-------------------------|----------| +/// | column_name | utf8 not null | | +/// | ordinal_position | int32 | (1) | +/// | remarks | utf8 | (2) | +/// | xdbc_data_type | int16 | (3) | +/// | xdbc_type_name | utf8 | (3) | +/// | xdbc_column_size | int32 | (3) | +/// | xdbc_decimal_digits | int16 | (3) | +/// | xdbc_num_prec_radix | int16 | (3) | +/// | xdbc_nullable | int16 | (3) | +/// | xdbc_column_def | utf8 | (3) | +/// | xdbc_sql_data_type | int16 | (3) | +/// | xdbc_datetime_sub | int16 | (3) | +/// | xdbc_char_octet_length | int32 | (3) | +/// | xdbc_is_nullable | utf8 | (3) | +/// | xdbc_scope_catalog | utf8 | (3) | +/// | xdbc_scope_schema | utf8 | (3) | +/// | xdbc_scope_table | utf8 | (3) | +/// | xdbc_is_autoincrement | bool | (3) | +/// | xdbc_is_generatedcolumn | bool | (3) | +/// +/// 1. The column's ordinal position in the table (starting from 1). +/// 2. Database-specific description of the column. +/// 3. Optional value. Should be null if not supported by the driver. +/// xdbc_ values are meant to provide JDBC/ODBC-compatible metadata +/// in an agnostic manner. +/// +/// CONSTRAINT_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|-------------------------|----------| +/// | constraint_name | utf8 | | +/// | constraint_type | utf8 not null | (1) | +/// | constraint_column_names | list not null | (2) | +/// | constraint_column_usage | list | (3) | +/// +/// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'. +/// 2. The columns on the current table that are constrained, in +/// order. +/// 3. For FOREIGN KEY only, the referenced table and columns. +/// +/// USAGE_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | fk_catalog | utf8 | +/// | fk_db_schema | utf8 | +/// | fk_table | utf8 not null | +/// | fk_column_name | utf8 not null | +/// +/// \param[in] connection The database connection. +/// \param[in] depth The level of nesting to display. If 0, display +/// all levels. If 1, display only catalogs (i.e. catalog_schemas +/// will be null). If 2, display only catalogs and schemas +/// (i.e. db_schema_tables will be null), and so on. +/// \param[in] catalog Only show tables in the given catalog. If NULL, +/// do not filter by catalog. If an empty string, only show tables +/// without a catalog. May be a search pattern (see section +/// documentation). +/// \param[in] db_schema Only show tables in the given database schema. If +/// NULL, do not filter by database schema. If an empty string, only show +/// tables without a database schema. May be a search pattern (see section +/// documentation). +/// \param[in] table_name Only show tables with the given name. If NULL, do not +/// filter by name. May be a search pattern (see section documentation). +/// \param[in] table_type Only show tables matching one of the given table +/// types. If NULL, show tables of any type. Valid table types can be fetched +/// from GetTableTypes. Terminate the list with a NULL entry. +/// \param[in] column_name Only show columns with the given name. If +/// NULL, do not filter by name. May be a search pattern (see +/// section documentation). +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, + const char *db_schema, const char *table_name, const char **table_type, + const char *column_name, struct ArrowArrayStream *out, struct AdbcError *error); + +/// \brief Get the Arrow schema of a table. +/// +/// \param[in] connection The database connection. +/// \param[in] catalog The catalog (or nullptr if not applicable). +/// \param[in] db_schema The database schema (or nullptr if not applicable). +/// \param[in] table_name The table name. +/// \param[out] schema The table schema. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, + const char *db_schema, const char *table_name, struct ArrowSchema *schema, + struct AdbcError *error); + +/// \brief Get a list of table types in the database. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ---------------|-------------- +/// table_type | utf8 not null +/// +/// \param[in] connection The database connection. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *out, + struct AdbcError *error); + +/// @} + +/// \defgroup adbc-connection-partition Partitioned Results +/// Some databases may internally partition the results. These +/// partitions are exposed to clients who may wish to integrate them +/// with a threaded or distributed execution model, where partitions +/// can be divided among threads or machines for processing. +/// +/// Drivers are not required to support partitioning. +/// +/// Partitions are not ordered. If the result set is sorted, +/// implementations should return a single partition. +/// +/// @{ + +/// \brief Construct a statement for a partition of a query. The +/// results can then be read independently. +/// +/// A partition can be retrieved from AdbcPartitions. +/// +/// \param[in] connection The connection to use. This does not have +/// to be the same connection that the partition was created on. +/// \param[in] serialized_partition The partition descriptor. +/// \param[in] serialized_length The partition descriptor length. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, + size_t serialized_length, struct ArrowArrayStream *out, + struct AdbcError *error); + +/// @} + +/// \defgroup adbc-connection-transaction Transaction Semantics +/// +/// Connections start out in auto-commit mode by default (if +/// applicable for the given vendor). Use AdbcConnectionSetOption and +/// ADBC_CONNECTION_OPTION_AUTO_COMMIT to change this. +/// +/// @{ + +/// \brief Commit any pending transactions. Only used if autocommit is +/// disabled. +/// +/// Behavior is undefined if this is mixed with SQL transaction +/// statements. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error); + +/// \brief Roll back any pending transactions. Only used if autocommit +/// is disabled. +/// +/// Behavior is undefined if this is mixed with SQL transaction +/// statements. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error); + +/// @} + +/// @} + +/// \addtogroup adbc-statement +/// @{ + +/// \brief Create a new statement for a given connection. +/// +/// Callers pass in a zero-initialized AdbcStatement. +/// +/// Drivers should allocate their internal data structure and set the private_data +/// field to point to the newly allocated struct. This struct should be released +/// when AdbcStatementRelease is called. +ADBC_EXPORT +AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, + struct AdbcError *error); + +/// \brief Destroy a statement. +/// \param[in] statement The statement to release. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error); + +/// \brief Execute a statement and get the results. +/// +/// This invalidates any prior result sets. +/// +/// \param[in] statement The statement to execute. +/// \param[out] out The results. Pass NULL if the client does not +/// expect a result set. +/// \param[out] rows_affected The number of rows affected if known, +/// else -1. Pass NULL if the client does not want this information. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, + int64_t *rows_affected, struct AdbcError *error); + +/// \brief Turn this statement into a prepared statement to be +/// executed multiple times. +/// +/// This invalidates any prior result sets. +ADBC_EXPORT +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error); + +/// \defgroup adbc-statement-sql SQL Semantics +/// Functions for executing SQL queries, or querying SQL-related +/// metadata. Drivers are not required to support both SQL and +/// Substrait semantics. If they do, it may be via converting +/// between representations internally. +/// @{ + +/// \brief Set the SQL query to execute. +/// +/// The query can then be executed with AdbcStatementExecute. For +/// queries expected to be executed repeatedly, AdbcStatementPrepare +/// the statement first. +/// +/// \param[in] statement The statement. +/// \param[in] query The query to execute. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error); + +/// @} + +/// \defgroup adbc-statement-substrait Substrait Semantics +/// Functions for executing Substrait plans, or querying +/// Substrait-related metadata. Drivers are not required to support +/// both SQL and Substrait semantics. If they do, it may be via +/// converting between representations internally. +/// @{ + +/// \brief Set the Substrait plan to execute. +/// +/// The query can then be executed with AdbcStatementExecute. For +/// queries expected to be executed repeatedly, AdbcStatementPrepare +/// the statement first. +/// +/// \param[in] statement The statement. +/// \param[in] plan The serialized substrait.Plan to execute. +/// \param[in] length The length of the serialized plan. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, + struct AdbcError *error); + +/// @} + +/// \brief Bind Arrow data. This can be used for bulk inserts or +/// prepared statements. +/// +/// \param[in] statement The statement to bind to. +/// \param[in] values The values to bind. The driver will call the +/// release callback itself, although it may not do this until the +/// statement is released. +/// \param[in] schema The schema of the values to bind. +/// \param[out] error An optional location to return an error message +/// if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, + struct AdbcError *error); + +/// \brief Bind Arrow data. This can be used for bulk inserts or +/// prepared statements. +/// \param[in] statement The statement to bind to. +/// \param[in] stream The values to bind. The driver will call the +/// release callback itself, although it may not do this until the +/// statement is released. +/// \param[out] error An optional location to return an error message +/// if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, + struct AdbcError *error); + +/// \brief Get the schema for bound parameters. +/// +/// This retrieves an Arrow schema describing the number, names, and +/// types of the parameters in a parameterized statement. The fields +/// of the schema should be in order of the ordinal position of the +/// parameters; named parameters should appear only once. +/// +/// If the parameter does not have a name, or the name cannot be +/// determined, the name of the corresponding field in the schema will +/// be an empty string. If the type cannot be determined, the type of +/// the corresponding field will be NA (NullType). +/// +/// This should be called after AdbcStatementPrepare. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the schema cannot be determined. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcError *error); + +/// \brief Set a string option on a statement. +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, + struct AdbcError *error); + +/// \addtogroup adbc-statement-partition +/// @{ + +/// \brief Execute a statement and get the results as a partitioned +/// result set. +/// +/// \param[in] statement The statement to execute. +/// \param[out] schema The schema of the result set. +/// \param[out] partitions The result partitions. +/// \param[out] rows_affected The number of rows affected if known, +/// else -1. Pass NULL if the client does not want this information. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support +/// partitioned results +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcPartitions *partitions, int64_t *rows_affected, + struct AdbcError *error); + +/// @} + +/// @} + +/// \addtogroup adbc-driver +/// @{ + +/// \brief Common entry point for drivers via the driver manager +/// (which uses dlopen(3)/LoadLibrary). The driver manager is told +/// to load a library and call a function of this type to load the +/// driver. +/// +/// Although drivers may choose any name for this function, the +/// recommended name is "AdbcDriverInit". +/// +/// \param[in] version The ADBC revision to attempt to initialize (see +/// ADBC_VERSION_1_0_0). +/// \param[out] driver The table of function pointers to +/// initialize. Should be a pointer to the appropriate struct for +/// the given version (see the documentation for the version). +/// \param[out] error An optional location to return an error message +/// if necessary. +/// \return ADBC_STATUS_OK if the driver was initialized, or +/// ADBC_STATUS_NOT_IMPLEMENTED if the version is not supported. In +/// that case, clients may retry with a different version. +typedef AdbcStatusCode (*AdbcDriverInitFunc)(int version, void *driver, struct AdbcError *error); + +/// @} + +#endif // ADBC + +#ifdef __cplusplus +} +#endif +} // namespace duckdb_adbc diff --git a/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp new file mode 100644 index 00000000..bc7e37fb --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/adbc/adbc.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/adbc/adbc.h" + +#include + +namespace duckdb_adbc { + +AdbcStatusCode DatabaseNew(struct AdbcDatabase *database, struct AdbcError *error); + +AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, + struct AdbcError *error); + +AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *error); + +AdbcStatusCode DatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error); + +AdbcStatusCode ConnectionNew(struct AdbcConnection *connection, struct AdbcError *error); + +AdbcStatusCode ConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, + struct AdbcError *error); + +AdbcStatusCode ConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, + struct AdbcError *error); + +AdbcStatusCode ConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error); + +AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, + struct ArrowArrayStream *out, struct AdbcError *error); + +AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, + const char *db_schema, const char *table_name, const char **table_type, + const char *column_name, struct ArrowArrayStream *out, struct AdbcError *error); + +AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, const char *db_schema, + const char *table_name, struct ArrowSchema *schema, struct AdbcError *error); + +AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *out, + struct AdbcError *error); + +AdbcStatusCode ConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, + size_t serialized_length, struct ArrowArrayStream *out, struct AdbcError *error); + +AdbcStatusCode ConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error); + +AdbcStatusCode ConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error); + +AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, + struct AdbcError *error); + +AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcError *error); + +AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, + int64_t *rows_affected, struct AdbcError *error); + +AdbcStatusCode StatementPrepare(struct AdbcStatement *statement, struct AdbcError *error); + +AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error); + +AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, + struct AdbcError *error); + +AdbcStatusCode StatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, + struct AdbcError *error); + +AdbcStatusCode StatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, + struct AdbcError *error); + +AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcError *error); + +AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, + struct AdbcError *error); + +AdbcStatusCode StatementExecutePartitions(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcPartitions *partitions, int64_t *rows_affected, + struct AdbcError *error); + +//! This method should only be called when the string is guaranteed to not be NULL +void SetError(struct AdbcError *error, const std::string &message); +void SetError(struct AdbcError *error, const char *message); + +void InitializeADBCError(AdbcError *error); + +} // namespace duckdb_adbc diff --git a/src/duckdb/src/include/duckdb/common/adbc/driver_manager.h b/src/duckdb/src/include/duckdb/common/adbc/driver_manager.h new file mode 100644 index 00000000..b70ec3a7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/adbc/driver_manager.h @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "duckdb/common/adbc/adbc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ADBC_DRIVER_MANAGER_H +#define ADBC_DRIVER_MANAGER_H +namespace duckdb_adbc { +/// \brief Common entry point for drivers via the driver manager. +/// +/// The driver manager can fill in default implementations of some +/// ADBC functions for drivers. Drivers must implement a minimum level +/// of functionality for this to be possible, however, and some +/// functions must be implemented by the driver. +/// +/// \param[in] driver_name An identifier for the driver (e.g. a path to a +/// shared library on Linux). +/// \param[in] entrypoint An identifier for the entrypoint (e.g. the +/// symbol to call for AdbcDriverInitFunc on Linux). +/// \param[in] version The ADBC revision to attempt to initialize. +/// \param[out] driver The table of function pointers to initialize. +/// \param[out] error An optional location to return an error message +/// if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *driver, + struct AdbcError *error); + +/// \brief Common entry point for drivers via the driver manager. +/// +/// The driver manager can fill in default implementations of some +/// ADBC functions for drivers. Drivers must implement a minimum level +/// of functionality for this to be possible, however, and some +/// functions must be implemented by the driver. +/// +/// \param[in] init_func The entrypoint to call. +/// \param[in] version The ADBC revision to attempt to initialize. +/// \param[out] driver The table of function pointers to initialize. +/// \param[out] error An optional location to return an error message +/// if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *driver, + struct AdbcError *error); + +/// \brief Set the AdbcDriverInitFunc to use. +/// +/// This is an extension to the ADBC API. The driver manager shims +/// the AdbcDatabase* functions to allow you to specify the +/// driver/entrypoint dynamically. This function lets you set the +/// entrypoint explicitly, for applications that can dynamically +/// load drivers on their own. +ADBC_EXPORT +AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, + struct AdbcError *error); + +/// \brief Get a human-friendly description of a status code. +ADBC_EXPORT +const char *AdbcStatusCodeMessage(AdbcStatusCode code); + +#endif // ADBC_DRIVER_MANAGER_H + +#ifdef __cplusplus +} +#endif +} // namespace duckdb_adbc diff --git a/src/duckdb/src/include/duckdb/common/adbc/single_batch_array_stream.hpp b/src/duckdb/src/include/duckdb/common/adbc/single_batch_array_stream.hpp new file mode 100644 index 00000000..583ce276 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/adbc/single_batch_array_stream.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/adbc/adbc.h" + +namespace duckdb_adbc { + +struct SingleBatchArrayStream { + struct ArrowSchema schema; + struct ArrowArray batch; +}; + +AdbcStatusCode BatchToArrayStream(struct ArrowArray *values, struct ArrowSchema *schema, + struct ArrowArrayStream *stream, struct AdbcError *error); + +} // namespace duckdb_adbc diff --git a/src/duckdb/src/include/duckdb/common/algorithm.hpp b/src/duckdb/src/include/duckdb/common/algorithm.hpp new file mode 100644 index 00000000..0608166f --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/algorithm.hpp @@ -0,0 +1,12 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/algorithm.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include diff --git a/src/duckdb/src/include/duckdb/common/allocator.hpp b/src/duckdb/src/include/duckdb/common/allocator.hpp new file mode 100644 index 00000000..9d4528ae --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/allocator.hpp @@ -0,0 +1,162 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class Allocator; +class AttachedDatabase; +class ClientContext; +class DatabaseInstance; +class ExecutionContext; +class ThreadContext; + +struct AllocatorDebugInfo; + +struct PrivateAllocatorData { + PrivateAllocatorData(); + virtual ~PrivateAllocatorData(); + + unique_ptr debug_info; + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +typedef data_ptr_t (*allocate_function_ptr_t)(PrivateAllocatorData *private_data, idx_t size); +typedef void (*free_function_ptr_t)(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size); +typedef data_ptr_t (*reallocate_function_ptr_t)(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, + idx_t size); + +class AllocatedData { +public: + DUCKDB_API AllocatedData(); + DUCKDB_API AllocatedData(Allocator &allocator, data_ptr_t pointer, idx_t allocated_size); + DUCKDB_API ~AllocatedData(); + // disable copy constructors + AllocatedData(const AllocatedData &other) = delete; + AllocatedData &operator=(const AllocatedData &) = delete; + //! enable move constructors + DUCKDB_API AllocatedData(AllocatedData &&other) noexcept; + DUCKDB_API AllocatedData &operator=(AllocatedData &&) noexcept; + + data_ptr_t get() { + return pointer; + } + const_data_ptr_t get() const { + return pointer; + } + idx_t GetSize() const { + return allocated_size; + } + bool IsSet() { + return pointer; + } + void Reset(); + +private: + optional_ptr allocator; + data_ptr_t pointer; + idx_t allocated_size; +}; + +class Allocator { + // 281TB ought to be enough for anybody + static constexpr const idx_t MAXIMUM_ALLOC_SIZE = 281474976710656ULL; + +public: + DUCKDB_API Allocator(); + DUCKDB_API Allocator(allocate_function_ptr_t allocate_function_p, free_function_ptr_t free_function_p, + reallocate_function_ptr_t reallocate_function_p, + unique_ptr private_data); + Allocator &operator=(Allocator &&allocator) noexcept = delete; + DUCKDB_API ~Allocator(); + + DUCKDB_API data_ptr_t AllocateData(idx_t size); + DUCKDB_API void FreeData(data_ptr_t pointer, idx_t size); + DUCKDB_API data_ptr_t ReallocateData(data_ptr_t pointer, idx_t old_size, idx_t new_size); + + AllocatedData Allocate(idx_t size) { + return AllocatedData(*this, AllocateData(size), size); + } + static data_ptr_t DefaultAllocate(PrivateAllocatorData *private_data, idx_t size) { + return data_ptr_cast(malloc(size)); + } + static void DefaultFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { + free(pointer); + } + static data_ptr_t DefaultReallocate(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, + idx_t size) { + return data_ptr_cast(realloc(pointer, size)); + } + static Allocator &Get(ClientContext &context); + static Allocator &Get(DatabaseInstance &db); + static Allocator &Get(AttachedDatabase &db); + + PrivateAllocatorData *GetPrivateData() { + return private_data.get(); + } + + DUCKDB_API static Allocator &DefaultAllocator(); + DUCKDB_API static shared_ptr &DefaultAllocatorReference(); + + static void ThreadFlush(idx_t threshold); + +private: + allocate_function_ptr_t allocate_function; + free_function_ptr_t free_function; + reallocate_function_ptr_t reallocate_function; + + unique_ptr private_data; +}; + +template +T *AllocateArray(idx_t size) { + return (T *)Allocator::DefaultAllocator().AllocateData(size * sizeof(T)); +} + +template +void DeleteArray(T *ptr, idx_t size) { + Allocator::DefaultAllocator().FreeData(data_ptr_cast(ptr), size * sizeof(T)); +} + +template +T *AllocateObject(ARGS &&... args) { + auto data = Allocator::DefaultAllocator().AllocateData(sizeof(T)); + return new (data) T(std::forward(args)...); +} + +template +void DestroyObject(T *ptr) { + ptr->~T(); + Allocator::DefaultAllocator().FreeData(data_ptr_cast(ptr), sizeof(T)); +} + +//! The BufferAllocator is a wrapper around the global allocator class that sends any allocations made through the +//! buffer manager. This makes the buffer manager aware of the memory usage, allowing it to potentially free +//! other blocks to make space in memory. +//! Note that there is a cost to doing so (several atomic operations will be performed on allocation/free). +//! As such this class should be used primarily for larger allocations. +struct BufferAllocator { + DUCKDB_API static Allocator &Get(ClientContext &context); + DUCKDB_API static Allocator &Get(DatabaseInstance &db); + DUCKDB_API static Allocator &Get(AttachedDatabase &db); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/array.hpp b/src/duckdb/src/include/duckdb/common/array.hpp new file mode 100644 index 00000000..fc03178d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/array.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/array.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::array; +} diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp new file mode 100644 index 00000000..0961e595 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp @@ -0,0 +1,109 @@ +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/arrow/arrow_buffer.hpp" +#include "duckdb/main/client_properties.hpp" +#include "duckdb/common/array.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Arrow append data +//===--------------------------------------------------------------------===// +typedef void (*initialize_t)(ArrowAppendData &result, const LogicalType &type, idx_t capacity); +// append_data: The arrow array we're appending into +// input: The data we're appending +// from: The offset into the input we're scanning +// to: The last index of the input we're scanning +// input_size: The total size of the 'input' Vector. +typedef void (*append_vector_t)(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); +typedef void (*finalize_t)(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); + +// This struct is used to save state for appending a column +// afterwards the ownership is passed to the arrow array, as 'private_data' +// FIXME: we should separate the append state variables from the variables required by the ArrowArray into +// ArrowAppendState +struct ArrowAppendData { + explicit ArrowAppendData(ClientProperties &options_p) : options(options_p) { + } + // the buffers of the arrow vector + ArrowBuffer validity; + ArrowBuffer main_buffer; + ArrowBuffer aux_buffer; + + idx_t row_count = 0; + idx_t null_count = 0; + + // function pointers for construction + initialize_t initialize = nullptr; + append_vector_t append_vector = nullptr; + finalize_t finalize = nullptr; + + // child data (if any) + vector> child_data; + + // the arrow array C API data, only set after Finalize + unique_ptr array; + duckdb::array buffers = {{nullptr, nullptr, nullptr}}; + vector child_pointers; + + ClientProperties options; +}; + +//===--------------------------------------------------------------------===// +// Append Helper Functions +//===--------------------------------------------------------------------===// +static void GetBitPosition(idx_t row_idx, idx_t ¤t_byte, uint8_t ¤t_bit) { + current_byte = row_idx / 8; + current_bit = row_idx % 8; +} + +static void UnsetBit(uint8_t *data, idx_t current_byte, uint8_t current_bit) { + data[current_byte] &= ~((uint64_t)1 << current_bit); +} + +static void NextBit(idx_t ¤t_byte, uint8_t ¤t_bit) { + current_bit++; + if (current_bit == 8) { + current_byte++; + current_bit = 0; + } +} + +static void ResizeValidity(ArrowBuffer &buffer, idx_t row_count) { + auto byte_count = (row_count + 7) / 8; + buffer.resize(byte_count, 0xFF); +} + +static void SetNull(ArrowAppendData &append_data, uint8_t *validity_data, idx_t current_byte, uint8_t current_bit) { + UnsetBit(validity_data, current_byte, current_bit); + append_data.null_count++; +} + +static void AppendValidity(ArrowAppendData &append_data, UnifiedVectorFormat &format, idx_t from, idx_t to) { + // resize the buffer, filling the validity buffer with all valid values + idx_t size = to - from; + ResizeValidity(append_data.validity, append_data.row_count + size); + if (format.validity.AllValid()) { + // if all values are valid we don't need to do anything else + return; + } + + // otherwise we iterate through the validity mask + auto validity_data = (uint8_t *)append_data.validity.data(); + uint8_t current_bit; + idx_t current_byte; + GetBitPosition(append_data.row_count, current_byte, current_bit); + for (idx_t i = from; i < to; i++) { + auto source_idx = format.sel->get_index(i); + // append the validity mask + if (!format.validity.RowIsValid(source_idx)) { + SetNull(append_data, validity_data, current_byte, current_bit); + } + NextBit(current_byte, current_bit); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/bool_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/bool_data.hpp new file mode 100644 index 00000000..59313faa --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/bool_data.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "duckdb/common/arrow/appender/append_data.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +struct ArrowBoolData { +public: + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); + static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/enum_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/enum_data.hpp new file mode 100644 index 00000000..54cfdc0d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/enum_data.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include "duckdb/common/arrow/appender/append_data.hpp" +#include "duckdb/common/arrow/appender/scalar_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Enums +//===--------------------------------------------------------------------===// +template +struct ArrowEnumData : public ArrowScalarBaseData { + static idx_t GetLength(string_t input) { + return input.GetSize(); + } + static void WriteData(data_ptr_t target, string_t input) { + memcpy(target, input.GetData(), input.GetSize()); + } + static void EnumAppendVector(ArrowAppendData &append_data, const Vector &input, idx_t size) { + D_ASSERT(input.GetVectorType() == VectorType::FLAT_VECTOR); + + // resize the validity mask and set up the validity buffer for iteration + ResizeValidity(append_data.validity, append_data.row_count + size); + + // resize the offset buffer - the offset buffer holds the offsets into the child array + append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(uint32_t) * (size + 1)); + auto data = FlatVector::GetData(input); + auto offset_data = append_data.main_buffer.GetData(); + if (append_data.row_count == 0) { + // first entry + offset_data[0] = 0; + } + // now append the string data to the auxiliary buffer + // the auxiliary buffer's length depends on the string lengths, so we resize as required + auto last_offset = offset_data[append_data.row_count]; + for (idx_t i = 0; i < size; i++) { + auto offset_idx = append_data.row_count + i + 1; + + auto string_length = GetLength(data[i]); + + // append the offset data + auto current_offset = last_offset + string_length; + offset_data[offset_idx] = current_offset; + + // resize the string buffer if required, and write the string data + append_data.aux_buffer.resize(current_offset); + WriteData(append_data.aux_buffer.data() + last_offset, data[i]); + + last_offset = current_offset; + } + append_data.row_count += size; + } + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + result.main_buffer.reserve(capacity * sizeof(TGT)); + // construct the enum child data + auto enum_data = ArrowAppender::InitializeChild(LogicalType::VARCHAR, EnumType::GetSize(type), result.options); + EnumAppendVector(*enum_data, EnumType::GetValuesInsertOrder(type), EnumType::GetSize(type)); + result.child_data.push_back(std::move(enum_data)); + } + + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + result->n_buffers = 2; + result->buffers[1] = append_data.main_buffer.data(); + // finalize the enum child data, and assign it to the dictionary + result->dictionary = ArrowAppender::FinalizeChild(LogicalType::VARCHAR, *append_data.child_data[0]); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list.hpp new file mode 100644 index 00000000..48b94def --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list.hpp @@ -0,0 +1,8 @@ +#include "duckdb/common/arrow/appender/bool_data.hpp" +#include "duckdb/common/arrow/appender/enum_data.hpp" +#include "duckdb/common/arrow/appender/list_data.hpp" +#include "duckdb/common/arrow/appender/map_data.hpp" +#include "duckdb/common/arrow/appender/scalar_data.hpp" +#include "duckdb/common/arrow/appender/struct_data.hpp" +#include "duckdb/common/arrow/appender/union_data.hpp" +#include "duckdb/common/arrow/appender/varchar_data.hpp" diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp new file mode 100644 index 00000000..9507ac72 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "duckdb/common/arrow/appender/append_data.hpp" + +namespace duckdb { + +struct ArrowListData { +public: + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); + static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); + +public: + static void AppendOffsets(ArrowAppendData &append_data, UnifiedVectorFormat &format, idx_t from, idx_t to, + vector &child_sel); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp new file mode 100644 index 00000000..9bb31c2f --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/appender/append_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Maps +//===--------------------------------------------------------------------===// +struct ArrowMapData { +public: + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); + static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp new file mode 100644 index 00000000..2c2769e1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp @@ -0,0 +1,88 @@ +#pragma once + +#include "duckdb/common/arrow/appender/append_data.hpp" +#include "duckdb/function/table/arrow.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Scalar Types +//===--------------------------------------------------------------------===// +struct ArrowScalarConverter { + template + static TGT Operation(SRC input) { + return input; + } + + static bool SkipNulls() { + return false; + } + + template + static void SetNull(TGT &value) { + } +}; + +struct ArrowIntervalConverter { + template + static TGT Operation(SRC input) { + ArrowInterval result; + result.months = input.months; + result.days = input.days; + result.nanoseconds = input.micros * Interval::NANOS_PER_MICRO; + return result; + } + + static bool SkipNulls() { + return true; + } + + template + static void SetNull(TGT &value) { + } +}; + +template +struct ArrowScalarBaseData { + static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + D_ASSERT(to >= from); + idx_t size = to - from; + D_ASSERT(size <= input_size); + UnifiedVectorFormat format; + input.ToUnifiedFormat(input_size, format); + + // append the validity mask + AppendValidity(append_data, format, from, to); + + // append the main data + append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(TGT) * size); + auto data = UnifiedVectorFormat::GetData(format); + auto result_data = append_data.main_buffer.GetData(); + + for (idx_t i = from; i < to; i++) { + auto source_idx = format.sel->get_index(i); + auto result_idx = append_data.row_count + i - from; + + if (OP::SkipNulls() && !format.validity.RowIsValid(source_idx)) { + OP::template SetNull(result_data[result_idx]); + continue; + } + result_data[result_idx] = OP::template Operation(data[source_idx]); + } + append_data.row_count += size; + } +}; + +template +struct ArrowScalarData : public ArrowScalarBaseData { + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + result.main_buffer.reserve(capacity * sizeof(TGT)); + } + + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + result->n_buffers = 2; + result->buffers[1] = append_data.main_buffer.data(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/struct_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/struct_data.hpp new file mode 100644 index 00000000..de2514fc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/struct_data.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "duckdb/common/arrow/appender/append_data.hpp" +#include "duckdb/common/arrow/appender/scalar_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Structs +//===--------------------------------------------------------------------===// +struct ArrowStructData { +public: + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); + static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/union_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/union_data.hpp new file mode 100644 index 00000000..8b2850fc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/union_data.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "duckdb/common/arrow/appender/append_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Unions +//===--------------------------------------------------------------------===// +/** + * Based on https://arrow.apache.org/docs/format/Columnar.html#union-layout & + * https://arrow.apache.org/docs/format/CDataInterface.html + */ +struct ArrowUnionData { +public: + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); + static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp new file mode 100644 index 00000000..03984fc7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp @@ -0,0 +1,105 @@ +#pragma once + +#include "duckdb/common/arrow/appender/append_data.hpp" +#include "duckdb/common/arrow/appender/scalar_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Varchar +//===--------------------------------------------------------------------===// +struct ArrowVarcharConverter { + template + static idx_t GetLength(SRC input) { + return input.GetSize(); + } + + template + static void WriteData(data_ptr_t target, SRC input) { + memcpy(target, input.GetData(), input.GetSize()); + } +}; + +struct ArrowUUIDConverter { + template + static idx_t GetLength(SRC input) { + return UUID::STRING_SIZE; + } + + template + static void WriteData(data_ptr_t target, SRC input) { + UUID::ToString(input, char_ptr_cast(target)); + } +}; + +template +struct ArrowVarcharData { + static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { + result.main_buffer.reserve((capacity + 1) * sizeof(BUFTYPE)); + + result.aux_buffer.reserve(capacity); + } + + static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + idx_t size = to - from; + UnifiedVectorFormat format; + input.ToUnifiedFormat(input_size, format); + + // resize the validity mask and set up the validity buffer for iteration + ResizeValidity(append_data.validity, append_data.row_count + size); + auto validity_data = (uint8_t *)append_data.validity.data(); + + // resize the offset buffer - the offset buffer holds the offsets into the child array + append_data.main_buffer.resize(append_data.main_buffer.size() + sizeof(BUFTYPE) * (size + 1)); + auto data = UnifiedVectorFormat::GetData(format); + auto offset_data = append_data.main_buffer.GetData(); + if (append_data.row_count == 0) { + // first entry + offset_data[0] = 0; + } + // now append the string data to the auxiliary buffer + // the auxiliary buffer's length depends on the string lengths, so we resize as required + auto last_offset = offset_data[append_data.row_count]; + idx_t max_offset = append_data.row_count + to - from; + if (max_offset > NumericLimits::Maximum() && + append_data.options.arrow_offset_size == ArrowOffsetSize::REGULAR) { + throw InvalidInputException("Arrow Appender: The maximum total string size for regular string buffers is " + "%u but the offset of %lu exceeds this.", + NumericLimits::Maximum(), max_offset); + } + for (idx_t i = from; i < to; i++) { + auto source_idx = format.sel->get_index(i); + auto offset_idx = append_data.row_count + i + 1 - from; + + if (!format.validity.RowIsValid(source_idx)) { + uint8_t current_bit; + idx_t current_byte; + GetBitPosition(append_data.row_count + i - from, current_byte, current_bit); + SetNull(append_data, validity_data, current_byte, current_bit); + offset_data[offset_idx] = last_offset; + continue; + } + + auto string_length = OP::GetLength(data[source_idx]); + + // append the offset data + auto current_offset = last_offset + string_length; + offset_data[offset_idx] = current_offset; + + // resize the string buffer if required, and write the string data + append_data.aux_buffer.resize(current_offset); + OP::WriteData(append_data.aux_buffer.data() + last_offset, data[source_idx]); + + last_offset = current_offset; + } + append_data.row_count += size; + } + + static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { + result->n_buffers = 3; + result->buffers[1] = append_data.main_buffer.data(); + result->buffers[2] = append_data.aux_buffer.data(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow.hpp new file mode 100644 index 00000000..b2f613ec --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow.hpp @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arrow/arrow.hpp +// +// +//===----------------------------------------------------------------------===// + +#ifndef ARROW_FLAG_DICTIONARY_ORDERED + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char *format; + const char *name; + const char *metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema **children; + struct ArrowSchema *dictionary; + + // Release callback + void (*release)(struct ArrowSchema *); + // Opaque producer-specific data + void *private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void **buffers; + struct ArrowArray **children; + struct ArrowArray *dictionary; + + // Release callback + void (*release)(struct ArrowArray *); + // Opaque producer-specific data + void *private_data; +}; +#endif + +#ifndef ARROW_C_STREAM_INTERFACE +#define ARROW_C_STREAM_INTERFACE +// EXPERIMENTAL +struct ArrowArrayStream { + // Callback to get the stream type + // (will be the same for all arrays in the stream). + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + int (*get_schema)(struct ArrowArrayStream *, struct ArrowSchema *out); + // Callback to get the next array + // (if no error and the array is released, the stream has ended) + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + int (*get_next)(struct ArrowArrayStream *, struct ArrowArray *out); + + // Callback to get optional detailed error information. + // This must only be called if the last stream operation failed + // with a non-0 return code. The returned pointer is only valid until + // the next operation on this stream (including release). + // If unavailable, NULL is returned. + const char *(*get_last_error)(struct ArrowArrayStream *); + + // Release callback: release the stream's own resources. + // Note that arrays returned by `get_next` must be individually released. + void (*release)(struct ArrowArrayStream *); + // Opaque producer-specific data + void *private_data; +}; +#endif + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow_appender.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow_appender.hpp new file mode 100644 index 00000000..0d3edf7e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow_appender.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arrow/arrow_appender.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/arrow/arrow_converter.hpp" +#include "duckdb/common/arrow/arrow.hpp" + +namespace duckdb { + +struct ArrowAppendData; + +//! The ArrowAppender class can be used to incrementally construct an arrow array by appending data chunks into it +class ArrowAppender { +public: + DUCKDB_API ArrowAppender(vector types, idx_t initial_capacity, ClientProperties options); + DUCKDB_API ~ArrowAppender(); + + //! Append a data chunk to the underlying arrow array + DUCKDB_API void Append(DataChunk &input, idx_t from, idx_t to, idx_t input_size); + //! Returns the underlying arrow array + DUCKDB_API ArrowArray Finalize(); + +public: + static void ReleaseArray(ArrowArray *array); + static ArrowArray *FinalizeChild(const LogicalType &type, ArrowAppendData &append_data); + static unique_ptr InitializeChild(const LogicalType &type, idx_t capacity, + ClientProperties &options); + +private: + //! The types of the chunks that will be appended in + vector types; + //! The root arrow append data + vector> root_data; + //! The total row count that has been appended + idx_t row_count = 0; + + ClientProperties options; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow_buffer.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow_buffer.hpp new file mode 100644 index 00000000..e1624ef6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow_buffer.hpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arrow/arrow_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +struct ArrowSchema; + +namespace duckdb { + +struct ArrowBuffer { + static constexpr const idx_t MINIMUM_SHRINK_SIZE = 4096; + + ArrowBuffer() : dataptr(nullptr), count(0), capacity(0) { + } + ~ArrowBuffer() { + if (!dataptr) { + return; + } + free(dataptr); + dataptr = nullptr; + count = 0; + capacity = 0; + } + // disable copy constructors + ArrowBuffer(const ArrowBuffer &other) = delete; + ArrowBuffer &operator=(const ArrowBuffer &) = delete; + //! enable move constructors + ArrowBuffer(ArrowBuffer &&other) noexcept { + std::swap(dataptr, other.dataptr); + std::swap(count, other.count); + std::swap(capacity, other.capacity); + } + ArrowBuffer &operator=(ArrowBuffer &&other) noexcept { + std::swap(dataptr, other.dataptr); + std::swap(count, other.count); + std::swap(capacity, other.capacity); + return *this; + } + + void reserve(idx_t bytes) { // NOLINT + auto new_capacity = NextPowerOfTwo(bytes); + if (new_capacity <= capacity) { + return; + } + ReserveInternal(new_capacity); + } + + void resize(idx_t bytes) { // NOLINT + reserve(bytes); + count = bytes; + } + + void resize(idx_t bytes, data_t value) { // NOLINT + reserve(bytes); + for (idx_t i = count; i < bytes; i++) { + dataptr[i] = value; + } + count = bytes; + } + + idx_t size() { // NOLINT + return count; + } + + data_ptr_t data() { // NOLINT + return dataptr; + } + + template + T *GetData() { + return reinterpret_cast(data()); + } + +private: + void ReserveInternal(idx_t bytes) { + if (dataptr) { + dataptr = data_ptr_cast(realloc(dataptr, bytes)); + } else { + dataptr = data_ptr_cast(malloc(bytes)); + } + capacity = bytes; + } + +private: + data_ptr_t dataptr = nullptr; + idx_t count = 0; + idx_t capacity = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow_converter.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow_converter.hpp new file mode 100644 index 00000000..2f80fb4a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow_converter.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arrow/arrow_converter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/main/client_properties.hpp" + +namespace duckdb { + +struct ArrowConverter { + DUCKDB_API static void ToArrowSchema(ArrowSchema *out_schema, const vector &types, + const vector &names, const ClientProperties &options); + DUCKDB_API static void ToArrowArray(DataChunk &input, ArrowArray *out_array, ClientProperties options); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow_wrapper.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow_wrapper.hpp new file mode 100644 index 00000000..d022c33a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow_wrapper.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arrow/arrow_wrapper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/main/chunk_scan_state.hpp" +#include "duckdb/main/client_properties.hpp" + +//! Here we have the internal duckdb classes that interact with Arrow's Internal Header (i.e., duckdb/commons/arrow.hpp) +namespace duckdb { +class QueryResult; +class DataChunk; + +class ArrowSchemaWrapper { +public: + ArrowSchema arrow_schema; + + ArrowSchemaWrapper() { + arrow_schema.release = nullptr; + } + + ~ArrowSchemaWrapper(); +}; +class ArrowArrayWrapper { +public: + ArrowArray arrow_array; + ArrowArrayWrapper() { + arrow_array.length = 0; + arrow_array.release = nullptr; + } + ~ArrowArrayWrapper(); +}; + +class ArrowArrayStreamWrapper { +public: + ArrowArrayStream arrow_array_stream; + int64_t number_of_rows; + +public: + void GetSchema(ArrowSchemaWrapper &schema); + + shared_ptr GetNextChunk(); + + const char *GetError(); + + ~ArrowArrayStreamWrapper(); + ArrowArrayStreamWrapper() { + arrow_array_stream.release = nullptr; + } +}; + +class ArrowUtil { +public: + static bool TryFetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t chunk_size, ArrowArray *out, + idx_t &result_count, PreservedError &error); + static idx_t FetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t chunk_size, ArrowArray *out); + +private: + static bool TryFetchNext(QueryResult &result, unique_ptr &out, PreservedError &error); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/nanoarrow/nanoarrow.h b/src/duckdb/src/include/duckdb/common/arrow/nanoarrow/nanoarrow.h new file mode 100644 index 00000000..50814d51 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/nanoarrow/nanoarrow.h @@ -0,0 +1,462 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef NANOARROW_H_INCLUDED +#define NANOARROW_H_INCLUDED + +#include +#include +#include + +#include "duckdb/common/arrow/arrow.hpp" + +namespace duckdb_nanoarrow { + +/// \file Arrow C Implementation +/// +/// EXPERIMENTAL. Interface subject to change. + +/// \page object-model Object Model +/// +/// Except where noted, objects are not thread-safe and clients should +/// take care to serialize accesses to methods. +/// +/// Because this library is intended to be vendored, it provides full type +/// definitions and encourages clients to stack or statically allocate +/// where convenient. + +/// \defgroup nanoarrow-malloc Memory management +/// +/// Non-buffer members of a struct ArrowSchema and struct ArrowArray +/// must be allocated using ArrowMalloc() or ArrowRealloc() and freed +/// using ArrowFree for schemas and arrays allocated here. Buffer members +/// are allocated using an ArrowBufferAllocator. + +/// \brief Allocate like malloc() +void *ArrowMalloc(int64_t size); + +/// \brief Reallocate like realloc() +void *ArrowRealloc(void *ptr, int64_t size); + +/// \brief Free a pointer allocated using ArrowMalloc() or ArrowRealloc(). +void ArrowFree(void *ptr); + +/// \brief Array buffer allocation and deallocation +/// +/// Container for allocate, reallocate, and free methods that can be used +/// to customize allocation and deallocation of buffers when constructing +/// an ArrowArray. +struct ArrowBufferAllocator { + /// \brief Allocate a buffer or return NULL if it cannot be allocated + uint8_t *(*allocate)(struct ArrowBufferAllocator *allocator, int64_t size); + + /// \brief Reallocate a buffer or return NULL if it cannot be reallocated + uint8_t *(*reallocate)(struct ArrowBufferAllocator *allocator, uint8_t *ptr, int64_t old_size, int64_t new_size); + + /// \brief Deallocate a buffer allocated by this allocator + void (*free)(struct ArrowBufferAllocator *allocator, uint8_t *ptr, int64_t size); + + /// \brief Opaque data specific to the allocator + void *private_data; +}; + +/// \brief Return the default allocator +/// +/// The default allocator uses ArrowMalloc(), ArrowRealloc(), and +/// ArrowFree(). +struct ArrowBufferAllocator *ArrowBufferAllocatorDefault(); + +/// }@ + +/// \defgroup nanoarrow-errors Error handling primitives +/// Functions generally return an errno-compatible error code; functions that +/// need to communicate more verbose error information accept a pointer +/// to an ArrowError. This can be stack or statically allocated. The +/// content of the message is undefined unless an error code has been +/// returned. + +/// \brief Error type containing a UTF-8 encoded message. +struct ArrowError { + char message[1024]; +}; + +/// \brief Return code for success. +#define NANOARROW_OK 0 + +/// \brief Represents an errno-compatible error code +typedef int ArrowErrorCode; + +/// \brief Set the contents of an error using printf syntax +ArrowErrorCode ArrowErrorSet(struct ArrowError *error, const char *fmt, ...); + +/// \brief Get the contents of an error +const char *ArrowErrorMessage(struct ArrowError *error); + +/// }@ + +/// \defgroup nanoarrow-utils Utility data structures + +/// \brief An non-owning view of a string +struct ArrowStringView { + /// \brief A pointer to the start of the string + /// + /// If n_bytes is 0, this value may be NULL. + const char *data; + + /// \brief The size of the string in bytes, + /// + /// (Not including the null terminator.) + int64_t n_bytes; +}; + +/// \brief Arrow type enumerator +/// +/// These names are intended to map to the corresponding arrow::Type::type +/// enumerator; however, the numeric values are specifically not equal +/// (i.e., do not rely on numeric comparison). +enum ArrowType { + NANOARROW_TYPE_UNINITIALIZED = 0, + NANOARROW_TYPE_NA = 1, + NANOARROW_TYPE_BOOL, + NANOARROW_TYPE_UINT8, + NANOARROW_TYPE_INT8, + NANOARROW_TYPE_UINT16, + NANOARROW_TYPE_INT16, + NANOARROW_TYPE_UINT32, + NANOARROW_TYPE_INT32, + NANOARROW_TYPE_UINT64, + NANOARROW_TYPE_INT64, + NANOARROW_TYPE_HALF_FLOAT, + NANOARROW_TYPE_FLOAT, + NANOARROW_TYPE_DOUBLE, + NANOARROW_TYPE_STRING, + NANOARROW_TYPE_BINARY, + NANOARROW_TYPE_FIXED_SIZE_BINARY, + NANOARROW_TYPE_DATE32, + NANOARROW_TYPE_DATE64, + NANOARROW_TYPE_TIMESTAMP, + NANOARROW_TYPE_TIME32, + NANOARROW_TYPE_TIME64, + NANOARROW_TYPE_INTERVAL_MONTHS, + NANOARROW_TYPE_INTERVAL_DAY_TIME, + NANOARROW_TYPE_DECIMAL128, + NANOARROW_TYPE_DECIMAL256, + NANOARROW_TYPE_LIST, + NANOARROW_TYPE_STRUCT, + NANOARROW_TYPE_SPARSE_UNION, + NANOARROW_TYPE_DENSE_UNION, + NANOARROW_TYPE_DICTIONARY, + NANOARROW_TYPE_MAP, + NANOARROW_TYPE_EXTENSION, + NANOARROW_TYPE_FIXED_SIZE_LIST, + NANOARROW_TYPE_DURATION, + NANOARROW_TYPE_LARGE_STRING, + NANOARROW_TYPE_LARGE_BINARY, + NANOARROW_TYPE_LARGE_LIST, + NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO +}; + +/// \brief Arrow time unit enumerator +/// +/// These names and values map to the corresponding arrow::TimeUnit::type +/// enumerator. +enum ArrowTimeUnit { + NANOARROW_TIME_UNIT_SECOND = 0, + NANOARROW_TIME_UNIT_MILLI = 1, + NANOARROW_TIME_UNIT_MICRO = 2, + NANOARROW_TIME_UNIT_NANO = 3 +}; + +/// }@ + +/// \defgroup nanoarrow-schema Schema producer helpers +/// These functions allocate, copy, and destroy ArrowSchema structures + +/// \brief Initialize the fields of a schema +/// +/// Initializes the fields and release callback of schema_out. Caller +/// is responsible for calling the schema->release callback if +/// NANOARROW_OK is returned. +ArrowErrorCode ArrowSchemaInit(struct ArrowSchema *schema, enum ArrowType type); + +/// \brief Initialize the fields of a fixed-size schema +/// +/// Returns EINVAL for fixed_size <= 0 or for data_type that is not +/// NANOARROW_TYPE_FIXED_SIZE_BINARY or NANOARROW_TYPE_FIXED_SIZE_LIST. +ArrowErrorCode ArrowSchemaInitFixedSize(struct ArrowSchema *schema, enum ArrowType data_type, int32_t fixed_size); + +/// \brief Initialize the fields of a decimal schema +/// +/// Returns EINVAL for scale <= 0 or for data_type that is not +/// NANOARROW_TYPE_DECIMAL128 or NANOARROW_TYPE_DECIMAL256. +ArrowErrorCode ArrowSchemaInitDecimal(struct ArrowSchema *schema, enum ArrowType data_type, int32_t decimal_precision, + int32_t decimal_scale); + +/// \brief Initialize the fields of a time, timestamp, or duration schema +/// +/// Returns EINVAL for data_type that is not +/// NANOARROW_TYPE_TIME32, NANOARROW_TYPE_TIME64, +/// NANOARROW_TYPE_TIMESTAMP, or NANOARROW_TYPE_DURATION. The +/// timezone parameter must be NULL for a non-timestamp data_type. +ArrowErrorCode ArrowSchemaInitDateTime(struct ArrowSchema *schema, enum ArrowType data_type, + enum ArrowTimeUnit time_unit, const char *timezone); + +/// \brief Make a (recursive) copy of a schema +/// +/// Allocates and copies fields of schema into schema_out. +ArrowErrorCode ArrowSchemaDeepCopy(struct ArrowSchema *schema, struct ArrowSchema *schema_out); + +/// \brief Copy format into schema->format +/// +/// schema must have been allocated using ArrowSchemaInit or +/// ArrowSchemaDeepCopy. +ArrowErrorCode ArrowSchemaSetFormat(struct ArrowSchema *schema, const char *format); + +/// \brief Copy name into schema->name +/// +/// schema must have been allocated using ArrowSchemaInit or +/// ArrowSchemaDeepCopy. +ArrowErrorCode ArrowSchemaSetName(struct ArrowSchema *schema, const char *name); + +/// \brief Copy metadata into schema->metadata +/// +/// schema must have been allocated using ArrowSchemaInit or +/// ArrowSchemaDeepCopy. +ArrowErrorCode ArrowSchemaSetMetadata(struct ArrowSchema *schema, const char *metadata); + +/// \brief Allocate the schema->children array +/// +/// Includes the memory for each child struct ArrowSchema. +/// schema must have been allocated using ArrowSchemaInit or +/// ArrowSchemaDeepCopy. +ArrowErrorCode ArrowSchemaAllocateChildren(struct ArrowSchema *schema, int64_t n_children); + +/// \brief Allocate the schema->dictionary member +/// +/// schema must have been allocated using ArrowSchemaInit or +/// ArrowSchemaDeepCopy. +ArrowErrorCode ArrowSchemaAllocateDictionary(struct ArrowSchema *schema); + +/// \brief Reader for key/value pairs in schema metadata +struct ArrowMetadataReader { + const char *metadata; + int64_t offset; + int32_t remaining_keys; +}; + +/// \brief Initialize an ArrowMetadataReader +ArrowErrorCode ArrowMetadataReaderInit(struct ArrowMetadataReader *reader, const char *metadata); + +/// \brief Read the next key/value pair from an ArrowMetadataReader +ArrowErrorCode ArrowMetadataReaderRead(struct ArrowMetadataReader *reader, struct ArrowStringView *key_out, + struct ArrowStringView *value_out); + +/// \brief The number of bytes in in a key/value metadata string +int64_t ArrowMetadataSizeOf(const char *metadata); + +/// \brief Check for a key in schema metadata +char ArrowMetadataHasKey(const char *metadata, const char *key); + +/// \brief Extract a value from schema metadata +ArrowErrorCode ArrowMetadataGetValue(const char *metadata, const char *key, const char *default_value, + struct ArrowStringView *value_out); + +/// }@ + +/// \defgroup nanoarrow-schema-view Schema consumer helpers + +/// \brief A non-owning view of a parsed ArrowSchema +/// +/// Contains more readily extractable values than a raw ArrowSchema. +/// Clients can stack or statically allocate this structure but are +/// encouraged to use the provided getters to ensure forward +/// compatiblity. +struct ArrowSchemaView { + /// \brief A pointer to the schema represented by this view + struct ArrowSchema *schema; + + /// \brief The data type represented by the schema + /// + /// This value may be NANOARROW_TYPE_DICTIONARY if the schema has a + /// non-null dictionary member; datetime types are valid values. + /// This value will never be NANOARROW_TYPE_EXTENSION (see + /// extension_name and/or extension_metadata to check for + /// an extension type). + enum ArrowType data_type; + + /// \brief The storage data type represented by the schema + /// + /// This value will never be NANOARROW_TYPE_DICTIONARY, NANOARROW_TYPE_EXTENSION + /// or any datetime type. This value represents only the type required to + /// interpret the buffers in the array. + enum ArrowType storage_data_type; + + /// \brief The extension type name if it exists + /// + /// If the ARROW:extension:name key is present in schema.metadata, + /// extension_name.data will be non-NULL. + struct ArrowStringView extension_name; + + /// \brief The extension type metadata if it exists + /// + /// If the ARROW:extension:metadata key is present in schema.metadata, + /// extension_metadata.data will be non-NULL. + struct ArrowStringView extension_metadata; + + /// \brief The expected number of buffers in a paired ArrowArray + int32_t n_buffers; + + /// \brief The index of the validity buffer or -1 if one does not exist + int32_t validity_buffer_id; + + /// \brief The index of the offset buffer or -1 if one does not exist + int32_t offset_buffer_id; + + /// \brief The index of the data buffer or -1 if one does not exist + int32_t data_buffer_id; + + /// \brief The index of the type_ids buffer or -1 if one does not exist + int32_t type_id_buffer_id; + + /// \brief Format fixed size parameter + /// + /// This value is set when parsing a fixed-size binary or fixed-size + /// list schema; this value is undefined for other types. For a + /// fixed-size binary schema this value is in bytes; for a fixed-size + /// list schema this value refers to the number of child elements for + /// each element of the parent. + int32_t fixed_size; + + /// \brief Decimal bitwidth + /// + /// This value is set when parsing a decimal type schema; + /// this value is undefined for other types. + int32_t decimal_bitwidth; + + /// \brief Decimal precision + /// + /// This value is set when parsing a decimal type schema; + /// this value is undefined for other types. + int32_t decimal_precision; + + /// \brief Decimal scale + /// + /// This value is set when parsing a decimal type schema; + /// this value is undefined for other types. + int32_t decimal_scale; + + /// \brief Format time unit parameter + /// + /// This value is set when parsing a date/time type. The value is + /// undefined for other types. + enum ArrowTimeUnit time_unit; + + /// \brief Format timezone parameter + /// + /// This value is set when parsing a timestamp type and represents + /// the timezone format parameter. The ArrowStrintgView points to + /// data within the schema and the value is undefined for other types. + struct ArrowStringView timezone; + + /// \brief Union type ids parameter + /// + /// This value is set when parsing a union type and represents + /// type ids parameter. The ArrowStringView points to + /// data within the schema and the value is undefined for other types. + struct ArrowStringView union_type_ids; +}; + +/// \brief Initialize an ArrowSchemaView +ArrowErrorCode ArrowSchemaViewInit(struct ArrowSchemaView *schema_view, struct ArrowSchema *schema, + struct ArrowError *error); + +/// }@ + +/// \defgroup nanoarrow-buffer-builder Growable buffer builders + +/// \brief An owning mutable view of a buffer +struct ArrowBuffer { + /// \brief A pointer to the start of the buffer + /// + /// If capacity_bytes is 0, this value may be NULL. + uint8_t *data; + + /// \brief The size of the buffer in bytes + int64_t size_bytes; + + /// \brief The capacity of the buffer in bytes + int64_t capacity_bytes; + + /// \brief The allocator that will be used to reallocate and/or free the buffer + struct ArrowBufferAllocator *allocator; +}; + +/// \brief Initialize an ArrowBuffer +/// +/// Initialize a buffer with a NULL, zero-size buffer using the default +/// buffer allocator. +void ArrowBufferInit(struct ArrowBuffer *buffer); + +/// \brief Set a newly-initialized buffer's allocator +/// +/// Returns EINVAL if the buffer has already been allocated. +ArrowErrorCode ArrowBufferSetAllocator(struct ArrowBuffer *buffer, struct ArrowBufferAllocator *allocator); + +/// \brief Reset an ArrowBuffer +/// +/// Releases the buffer using the allocator's free method if +/// the buffer's data member is non-null, sets the data member +/// to NULL, and sets the buffer's size and capacity to 0. +void ArrowBufferReset(struct ArrowBuffer *buffer); + +/// \brief Move an ArrowBuffer +/// +/// Transfers the buffer data and lifecycle management to another +/// address and resets buffer. +void ArrowBufferMove(struct ArrowBuffer *buffer, struct ArrowBuffer *buffer_out); + +/// \brief Grow or shrink a buffer to a given capacity +/// +/// When shrinking the capacity of the buffer, the buffer is only reallocated +/// if shrink_to_fit is non-zero. Calling ArrowBufferResize() does not +/// adjust the buffer's size member except to ensure that the invariant +/// capacity >= size remains true. +ArrowErrorCode ArrowBufferResize(struct ArrowBuffer *buffer, int64_t new_capacity_bytes, char shrink_to_fit); + +/// \brief Ensure a buffer has at least a given additional capacity +/// +/// Ensures that the buffer has space to append at least +/// additional_size_bytes, overallocating when required. +ArrowErrorCode ArrowBufferReserve(struct ArrowBuffer *buffer, int64_t additional_size_bytes); + +/// \brief Write data to buffer and increment the buffer size +/// +/// This function does not check that buffer has the required capacity +void ArrowBufferAppendUnsafe(struct ArrowBuffer *buffer, const void *data, int64_t size_bytes); + +/// \brief Write data to buffer and increment the buffer size +/// +/// This function writes and ensures that the buffer has the required capacity, +/// possibly by reallocating the buffer. Like ArrowBufferReserve, this will +/// overallocate when reallocation is required. +ArrowErrorCode ArrowBufferAppend(struct ArrowBuffer *buffer, const void *data, int64_t size_bytes); + +/// }@ + +} // namespace duckdb_nanoarrow + +#endif // NANOARROW_H_INCLUDED diff --git a/src/duckdb/src/include/duckdb/common/arrow/nanoarrow/nanoarrow.hpp b/src/duckdb/src/include/duckdb/common/arrow/nanoarrow/nanoarrow.hpp new file mode 100644 index 00000000..d939bada --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/nanoarrow/nanoarrow.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "duckdb/common/arrow/nanoarrow/nanoarrow.h" + +// Bring in the symbols from duckdb_nanoarrow into duckdb +namespace duckdb { + +// using duckdb_nanoarrow::ArrowBuffer; //We have a variant of this that should be renamed +using duckdb_nanoarrow::ArrowBufferAllocator; +using duckdb_nanoarrow::ArrowError; +using duckdb_nanoarrow::ArrowSchemaView; +using duckdb_nanoarrow::ArrowStringView; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/result_arrow_wrapper.hpp b/src/duckdb/src/include/duckdb/common/arrow/result_arrow_wrapper.hpp new file mode 100644 index 00000000..629316de --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arrow/result_arrow_wrapper.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arrow/result_arrow_wrapper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/query_result.hpp" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/main/chunk_scan_state.hpp" + +namespace duckdb { +class ResultArrowArrayStreamWrapper { +public: + explicit ResultArrowArrayStreamWrapper(unique_ptr result, idx_t batch_size); + +public: + ArrowArrayStream stream; + unique_ptr result; + PreservedError last_error; + idx_t batch_size; + vector column_types; + vector column_names; + unique_ptr scan_state; + +private: + static int MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out); + static int MyStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out); + static void MyStreamRelease(struct ArrowArrayStream *stream); + static const char *MyStreamGetLastError(struct ArrowArrayStream *stream); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/assert.hpp b/src/duckdb/src/include/duckdb/common/assert.hpp new file mode 100644 index 00000000..62721f73 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/assert.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/assert.hpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/common/winapi.hpp" + +#pragma once + +#if (defined(DUCKDB_USE_STANDARD_ASSERT) || !defined(DEBUG)) && !defined(DUCKDB_FORCE_ASSERT) && !defined(__MVS__) + +#include +#define D_ASSERT assert +namespace duckdb { +DUCKDB_API void DuckDBAssertInternal(bool condition, const char *condition_name, const char *file, int linenr); +} + +#else +namespace duckdb { +DUCKDB_API void DuckDBAssertInternal(bool condition, const char *condition_name, const char *file, int linenr); +} + +#define D_ASSERT(condition) duckdb::DuckDBAssertInternal(bool(condition), #condition, __FILE__, __LINE__) + +#endif diff --git a/src/duckdb/src/include/duckdb/common/atomic.hpp b/src/duckdb/src/include/duckdb/common/atomic.hpp new file mode 100644 index 00000000..4445b311 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/atomic.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/atomic.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::atomic; +} diff --git a/src/duckdb/src/include/duckdb/common/bind_helpers.hpp b/src/duckdb/src/include/duckdb/common/bind_helpers.hpp new file mode 100644 index 00000000..99204762 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/bind_helpers.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/bind_helpers.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/common/vector.hpp" +#include "duckdb/common/common.hpp" + +namespace duckdb { + +class Value; + +Value ConvertVectorToValue(vector set); +vector ParseColumnList(const vector &set, vector &names, const string &option_name); +vector ParseColumnList(const Value &value, vector &names, const string &option_name); +vector ParseColumnsOrdered(const vector &set, vector &names, const string &loption); +vector ParseColumnsOrdered(const Value &value, vector &names, const string &loption); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/bit_utils.hpp b/src/duckdb/src/include/duckdb/common/bit_utils.hpp new file mode 100644 index 00000000..da4aad66 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/bit_utils.hpp @@ -0,0 +1,141 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/bit_utils.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/hugeint.hpp" + +#ifdef _MSC_VER +#define __restrict__ +#define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__ +#define __ORDER_LITTLE_ENDIAN__ 2 +#include +static inline int __builtin_ctzll(unsigned long long x) { +#ifdef _WIN64 + unsigned long ret; + _BitScanForward64(&ret, x); + return (int)ret; +#else + unsigned long low, high; + bool low_set = _BitScanForward(&low, (unsigned __int32)(x)) != 0; + _BitScanForward(&high, (unsigned __int32)(x >> 32)); + high += 32; + return low_set ? low : high; +#endif +} +static inline int __builtin_clzll(unsigned long long mask) { + unsigned long where; +// BitScanReverse scans from MSB to LSB for first set bit. +// Returns 0 if no set bit is found. +#if defined(_WIN64) + if (_BitScanReverse64(&where, mask)) + return static_cast(63 - where); +#elif defined(_WIN32) + // Scan the high 32 bits. + if (_BitScanReverse(&where, static_cast(mask >> 32))) + return static_cast(63 - (where + 32)); // Create a bit offset from the MSB. + // Scan the low 32 bits. + if (_BitScanReverse(&where, static_cast(mask))) + return static_cast(63 - where); +#else +#error "Implementation of __builtin_clzll required" +#endif + return 64; // Undefined Behavior. +} + +static inline int __builtin_ctz(unsigned int value) { + unsigned long trailing_zero = 0; + + if (_BitScanForward(&trailing_zero, value)) { + return trailing_zero; + } else { + // This is undefined, I better choose 32 than 0 + return 32; + } +} + +static inline int __builtin_clz(unsigned int value) { + unsigned long leading_zero = 0; + + if (_BitScanReverse(&leading_zero, value)) { + return 31 - leading_zero; + } else { + // Same remarks as above + return 32; + } +} + +#endif + +namespace duckdb { + +template +struct CountZeros {}; + +template <> +struct CountZeros { + inline static int Leading(uint32_t value) { + if (!value) { + return 32; + } + return __builtin_clz(value); + } + inline static int Trailing(uint32_t value) { + if (!value) { + return 32; + } + return __builtin_ctz(value); + } +}; + +template <> +struct CountZeros { + inline static int Leading(uint64_t value) { + if (!value) { + return 64; + } + return __builtin_clzll(value); + } + inline static int Trailing(uint64_t value) { + if (!value) { + return 64; + } + return __builtin_ctzll(value); + } +}; + +template <> +struct CountZeros { + inline static int Leading(hugeint_t value) { + const uint64_t upper = (uint64_t)value.upper; + const uint64_t lower = value.lower; + + if (upper) { + return __builtin_clzll(upper); + } else if (lower) { + return 64 + __builtin_clzll(lower); + } else { + return 128; + } + } + + inline static int Trailing(hugeint_t value) { + const uint64_t upper = (uint64_t)value.upper; + const uint64_t lower = value.lower; + + if (lower) { + return __builtin_ctzll(lower); + } else if (upper) { + return 64 + __builtin_ctzll(upper); + } else { + return 128; + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/bitpacking.hpp b/src/duckdb/src/include/duckdb/common/bitpacking.hpp new file mode 100644 index 00000000..43a20042 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/bitpacking.hpp @@ -0,0 +1,260 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/bitpacking.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "bitpackinghelpers.h" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/numeric_utils.hpp" + +namespace duckdb { + +using bitpacking_width_t = uint8_t; + +struct HugeIntPacker { + static void Pack(const hugeint_t *__restrict in, uint32_t *__restrict out, bitpacking_width_t width); + static void Unpack(const uint32_t *__restrict in, hugeint_t *__restrict out, bitpacking_width_t width); +}; + +class BitpackingPrimitives { + +public: + static constexpr const idx_t BITPACKING_ALGORITHM_GROUP_SIZE = 32; + static constexpr const idx_t BITPACKING_HEADER_SIZE = sizeof(uint64_t); + static constexpr const bool BYTE_ALIGNED = false; + + // To ensure enough data is available, use GetRequiredSize() to determine the correct size for dst buffer + // Note: input should be aligned to BITPACKING_ALGORITHM_GROUP_SIZE for good performance. + template + inline static void PackBuffer(data_ptr_t dst, T *src, idx_t count, bitpacking_width_t width) { + if (ASSUME_INPUT_ALIGNED) { + for (idx_t i = 0; i < count; i += BITPACKING_ALGORITHM_GROUP_SIZE) { + PackGroup(dst + (i * width) / 8, src + i, width); + } + } else { + idx_t misaligned_count = count % BITPACKING_ALGORITHM_GROUP_SIZE; + T tmp_buffer[BITPACKING_ALGORITHM_GROUP_SIZE]; // TODO maybe faster on the heap? + + count -= misaligned_count; + + for (idx_t i = 0; i < count; i += BITPACKING_ALGORITHM_GROUP_SIZE) { + PackGroup(dst + (i * width) / 8, src + i, width); + } + + // Input was not aligned to BITPACKING_ALGORITHM_GROUP_SIZE, we need a copy + if (misaligned_count) { + memcpy(tmp_buffer, src + count, misaligned_count * sizeof(T)); + PackGroup(dst + (count * width) / 8, tmp_buffer, width); + } + } + } + + // Unpacks a block of BITPACKING_ALGORITHM_GROUP_SIZE values + // Assumes both src and dst to be of the correct size + template + inline static void UnPackBuffer(data_ptr_t dst, data_ptr_t src, idx_t count, bitpacking_width_t width, + bool skip_sign_extension = false) { + + for (idx_t i = 0; i < count; i += BITPACKING_ALGORITHM_GROUP_SIZE) { + UnPackGroup(dst + i * sizeof(T), src + (i * width) / 8, width, skip_sign_extension); + } + } + + // Packs a block of BITPACKING_ALGORITHM_GROUP_SIZE values + template + inline static void PackBlock(data_ptr_t dst, T *src, bitpacking_width_t width) { + return PackGroup(dst, src, width); + } + + // Unpacks a block of BITPACKING_ALGORITHM_GROUP_SIZE values + template + inline static void UnPackBlock(data_ptr_t dst, data_ptr_t src, bitpacking_width_t width, + bool skip_sign_extension = false) { + return UnPackGroup(dst, src, width, skip_sign_extension); + } + + // Calculates the minimum required number of bits per value that can store all values + template ::IsSigned()> + inline static bitpacking_width_t MinimumBitWidth(T value) { + return FindMinimumBitWidth(value, value); + } + + // Calculates the minimum required number of bits per value that can store all values + template ::IsSigned()> + inline static bitpacking_width_t MinimumBitWidth(T *values, idx_t count) { + return FindMinimumBitWidth(values, count); + } + + // Calculates the minimum required number of bits per value that can store all values, + // given a predetermined minimum and maximum value of the buffer + template ::IsSigned()> + inline static bitpacking_width_t MinimumBitWidth(T minimum, T maximum) { + return FindMinimumBitWidth(minimum, maximum); + } + + inline static idx_t GetRequiredSize(idx_t count, bitpacking_width_t width) { + count = RoundUpToAlgorithmGroupSize(count); + return ((count * width) / 8); + } + + template + inline static T RoundUpToAlgorithmGroupSize(T num_to_round) { + int remainder = num_to_round % BITPACKING_ALGORITHM_GROUP_SIZE; + if (remainder == 0) { + return num_to_round; + } + + return num_to_round + BITPACKING_ALGORITHM_GROUP_SIZE - remainder; + } + +private: + template + static bitpacking_width_t FindMinimumBitWidth(T *values, idx_t count) { + T min_value = values[0]; + T max_value = values[0]; + + for (idx_t i = 1; i < count; i++) { + if (values[i] > max_value) { + max_value = values[i]; + } + + if (is_signed) { + if (values[i] < min_value) { + min_value = values[i]; + } + } + } + + return FindMinimumBitWidth(min_value, max_value); + } + + template + static bitpacking_width_t FindMinimumBitWidth(T min_value, T max_value) { + bitpacking_width_t bitwidth; + T value; + + if (is_signed) { + if (min_value == NumericLimits::Minimum()) { + // handle special case of the minimal value, as it cannot be negated like all other values. + return sizeof(T) * 8; + } else { + value = MaxValue((T)-min_value, max_value); + } + } else { + value = max_value; + } + + if (value == 0) { + return 0; + } + + if (is_signed) { + bitwidth = 1; + } else { + bitwidth = 0; + } + + while (value) { + bitwidth++; + value >>= 1; + } + + bitwidth = GetEffectiveWidth(bitwidth); + + // Assert results are correct +#ifdef DEBUG + if (bitwidth < sizeof(T) * 8 && bitwidth != 0) { + if (is_signed) { + D_ASSERT(max_value <= (T(1) << (bitwidth - 1)) - 1); + D_ASSERT(min_value >= (T(-1) * ((T(1) << (bitwidth - 1)) - 1) - 1)); + } else { + D_ASSERT(max_value <= (T(1) << (bitwidth)) - 1); + } + } +#endif + if (round_to_next_byte) { + return (bitwidth / 8 + (bitwidth % 8 != 0)) * 8; + } + return bitwidth; + } + + // Sign bit extension + template ::type> + static void SignExtend(data_ptr_t dst, bitpacking_width_t width) { + T const mask = T_U(1) << (width - 1); + for (idx_t i = 0; i < BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; ++i) { + T value = Load(dst + i * sizeof(T)); + value = value & ((T_U(1) << width) - T_U(1)); + T result = (value ^ mask) - mask; + Store(result, dst + i * sizeof(T)); + } + } + + // Prevent compression at widths that are ineffective + template + static bitpacking_width_t GetEffectiveWidth(bitpacking_width_t width) { + bitpacking_width_t bits_of_type = sizeof(T) * 8; + bitpacking_width_t type_size = sizeof(T); + if (width + type_size > bits_of_type) { + return bits_of_type; + } + return width; + } + + template + static inline void PackGroup(data_ptr_t dst, T *values, bitpacking_width_t width) { + if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastpack(reinterpret_cast(values), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastpack(reinterpret_cast(values), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastpack(reinterpret_cast(values), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastpack(reinterpret_cast(values), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value) { + HugeIntPacker::Pack(reinterpret_cast(values), reinterpret_cast(dst), width); + } else { + throw InternalException("Unsupported type for bitpacking"); + } + } + + template + static inline void UnPackGroup(data_ptr_t dst, data_ptr_t src, bitpacking_width_t width, + bool skip_sign_extension = false) { + if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastunpack(reinterpret_cast(src), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastunpack(reinterpret_cast(src), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastunpack(reinterpret_cast(src), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value || std::is_same::value) { + duckdb_fastpforlib::fastunpack(reinterpret_cast(src), reinterpret_cast(dst), + static_cast(width)); + } else if (std::is_same::value) { + HugeIntPacker::Unpack(reinterpret_cast(src), reinterpret_cast(dst), width); + } else { + throw InternalException("Unsupported type for bitpacking"); + } + + if (NumericLimits::IsSigned() && !skip_sign_extension && width > 0 && width < sizeof(T) * 8) { + SignExtend(dst, width); + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/bitset.hpp b/src/duckdb/src/include/duckdb/common/bitset.hpp new file mode 100644 index 00000000..60812350 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/bitset.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/bitset.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::bitset; +} diff --git a/src/duckdb/src/include/duckdb/common/box_renderer.hpp b/src/duckdb/src/include/duckdb/common/box_renderer.hpp new file mode 100644 index 00000000..df6c24c9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/box_renderer.hpp @@ -0,0 +1,120 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/box_renderer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/common/list.hpp" + +namespace duckdb { +class ColumnDataCollection; +class ColumnDataRowCollection; + +enum class ValueRenderAlignment { LEFT, MIDDLE, RIGHT }; +enum class RenderMode : uint8_t { ROWS, COLUMNS }; + +struct BoxRendererConfig { + // a max_width of 0 means we default to the terminal width + idx_t max_width = 0; + // the maximum amount of rows to render + idx_t max_rows = 20; + // the limit that is applied prior to rendering + // if we are rendering exactly "limit" rows then a question mark is rendered instead + idx_t limit = 0; + // the max col width determines the maximum size of a single column + // note that the max col width is only used if the result does not fit on the screen + idx_t max_col_width = 20; + //! how to render NULL values + string null_value = "NULL"; + //! Whether or not to render row-wise or column-wise + RenderMode render_mode = RenderMode::ROWS; + +#ifndef DUCKDB_ASCII_TREE_RENDERER + const char *LTCORNER = "\342\224\214"; // "┌"; + const char *RTCORNER = "\342\224\220"; // "┐"; + const char *LDCORNER = "\342\224\224"; // "└"; + const char *RDCORNER = "\342\224\230"; // "┘"; + + const char *MIDDLE = "\342\224\274"; // "┼"; + const char *TMIDDLE = "\342\224\254"; // "┬"; + const char *LMIDDLE = "\342\224\234"; // "├"; + const char *RMIDDLE = "\342\224\244"; // "┤"; + const char *DMIDDLE = "\342\224\264"; // "┴"; + + const char *VERTICAL = "\342\224\202"; // "│"; + const char *HORIZONTAL = "\342\224\200"; // "─"; + + const char *DOTDOTDOT = "\xE2\x80\xA6"; // "…"; + const char *DOT = "\xC2\xB7"; // "·"; + const idx_t DOTDOTDOT_LENGTH = 1; + +#else + // ASCII version + const char *LTCORNER = "<"; + const char *RTCORNER = ">"; + const char *LDCORNER = "<"; + const char *RDCORNER = ">"; + + const char *MIDDLE = "+"; + const char *TMIDDLE = "+"; + const char *LMIDDLE = "+"; + const char *RMIDDLE = "+"; + const char *DMIDDLE = "+"; + + const char *VERTICAL = "|"; + const char *HORIZONTAL = "-"; + + const char *DOTDOTDOT = "..."; // "..."; + const char *DOT = "."; // "."; + const idx_t DOTDOTDOT_LENGTH = 3; +#endif +}; + +class BoxRenderer { + static const idx_t SPLIT_COLUMN; + +public: + explicit BoxRenderer(BoxRendererConfig config_p = BoxRendererConfig()); + + string ToString(ClientContext &context, const vector &names, const ColumnDataCollection &op); + + void Render(ClientContext &context, const vector &names, const ColumnDataCollection &op, std::ostream &ss); + void Print(ClientContext &context, const vector &names, const ColumnDataCollection &op); + +private: + //! The configuration used for rendering + BoxRendererConfig config; + +private: + void RenderValue(std::ostream &ss, const string &value, idx_t column_width, + ValueRenderAlignment alignment = ValueRenderAlignment::MIDDLE); + string RenderType(const LogicalType &type); + ValueRenderAlignment TypeAlignment(const LogicalType &type); + string GetRenderValue(ColumnDataRowCollection &rows, idx_t c, idx_t r); + list FetchRenderCollections(ClientContext &context, const ColumnDataCollection &result, + idx_t top_rows, idx_t bottom_rows); + list PivotCollections(ClientContext &context, list input, + vector &column_names, vector &result_types, + idx_t row_count); + vector ComputeRenderWidths(const vector &names, const vector &result_types, + list &collections, idx_t min_width, idx_t max_width, + vector &column_map, idx_t &total_length); + void RenderHeader(const vector &names, const vector &result_types, + const vector &column_map, const vector &widths, const vector &boundaries, + idx_t total_length, bool has_results, std::ostream &ss); + void RenderValues(const list &collections, const vector &column_map, + const vector &widths, const vector &result_types, std::ostream &ss); + void RenderRowCount(string row_count_str, string shown_str, const string &column_count_str, + const vector &boundaries, bool has_hidden_rows, bool has_hidden_columns, + idx_t total_length, idx_t row_count, idx_t column_count, idx_t minimum_row_length, + std::ostream &ss); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/bswap.hpp b/src/duckdb/src/include/duckdb/common/bswap.hpp new file mode 100644 index 00000000..fae53969 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/bswap.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/bswap.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +#define BSWAP16(x) ((uint16_t)((((uint16_t)(x)&0xff00) >> 8) | (((uint16_t)(x)&0x00ff) << 8))) + +#define BSWAP32(x) \ + ((uint32_t)((((uint32_t)(x)&0xff000000) >> 24) | (((uint32_t)(x)&0x00ff0000) >> 8) | \ + (((uint32_t)(x)&0x0000ff00) << 8) | (((uint32_t)(x)&0x000000ff) << 24))) + +#define BSWAP64(x) \ + ((uint64_t)((((uint64_t)(x)&0xff00000000000000ull) >> 56) | (((uint64_t)(x)&0x00ff000000000000ull) >> 40) | \ + (((uint64_t)(x)&0x0000ff0000000000ull) >> 24) | (((uint64_t)(x)&0x000000ff00000000ull) >> 8) | \ + (((uint64_t)(x)&0x00000000ff000000ull) << 8) | (((uint64_t)(x)&0x0000000000ff0000ull) << 24) | \ + (((uint64_t)(x)&0x000000000000ff00ull) << 40) | (((uint64_t)(x)&0x00000000000000ffull) << 56))) + +template +static inline T BSwap(const T &x) { + static_assert(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8, + "Size of type must be 1, 2, 4, or 8 for BSwap"); + if (sizeof(T) == 1) { + return x; + } else if (sizeof(T) == 2) { + return BSWAP16(x); + } else if (sizeof(T) == 4) { + return BSWAP32(x); + } else { + return BSWAP64(x); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/case_insensitive_map.hpp b/src/duckdb/src/include/duckdb/common/case_insensitive_map.hpp new file mode 100644 index 00000000..30ef9bc9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/case_insensitive_map.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/case_insensitive_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +struct CaseInsensitiveStringHashFunction { + uint64_t operator()(const string &str) const { + return StringUtil::CIHash(str); + } +}; + +struct CaseInsensitiveStringEquality { + bool operator()(const string &a, const string &b) const { + return StringUtil::CIEquals(a, b); + } +}; + +template +using case_insensitive_map_t = + unordered_map; + +using case_insensitive_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/checksum.hpp b/src/duckdb/src/include/duckdb/common/checksum.hpp new file mode 100644 index 00000000..425d6b1f --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/checksum.hpp @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/checksum.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//! Compute a checksum over a buffer of size size +uint64_t Checksum(uint8_t *buffer, size_t size); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/chrono.hpp b/src/duckdb/src/include/duckdb/common/chrono.hpp new file mode 100644 index 00000000..797a867c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/chrono.hpp @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/chrono.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::chrono::duration; +using std::chrono::duration_cast; +using std::chrono::high_resolution_clock; +using std::chrono::milliseconds; +using std::chrono::system_clock; +using std::chrono::time_point; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/common.hpp b/src/duckdb/src/include/duckdb/common/common.hpp new file mode 100644 index 00000000..e337721c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/common.hpp @@ -0,0 +1,13 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/common.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/vector.hpp" diff --git a/src/duckdb/src/include/duckdb/common/compressed_file_system.hpp b/src/duckdb/src/include/duckdb/common/compressed_file_system.hpp new file mode 100644 index 00000000..3cdfb32d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/compressed_file_system.hpp @@ -0,0 +1,80 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/compressed_file_system.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/file_system.hpp" + +namespace duckdb { +class CompressedFile; + +struct StreamData { + // various buffers & pointers + bool write = false; + bool refresh = false; + unsafe_unique_array in_buff; + unsafe_unique_array out_buff; + data_ptr_t out_buff_start = nullptr; + data_ptr_t out_buff_end = nullptr; + data_ptr_t in_buff_start = nullptr; + data_ptr_t in_buff_end = nullptr; + + idx_t in_buf_size = 0; + idx_t out_buf_size = 0; +}; + +struct StreamWrapper { + DUCKDB_API virtual ~StreamWrapper(); + + DUCKDB_API virtual void Initialize(CompressedFile &file, bool write) = 0; + DUCKDB_API virtual bool Read(StreamData &stream_data) = 0; + DUCKDB_API virtual void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, + int64_t nr_bytes) = 0; + DUCKDB_API virtual void Close() = 0; +}; + +class CompressedFileSystem : public FileSystem { +public: + DUCKDB_API int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + DUCKDB_API int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + + DUCKDB_API void Reset(FileHandle &handle) override; + + DUCKDB_API int64_t GetFileSize(FileHandle &handle) override; + + DUCKDB_API bool OnDiskFile(FileHandle &handle) override; + DUCKDB_API bool CanSeek() override; + + DUCKDB_API virtual unique_ptr CreateStream() = 0; + DUCKDB_API virtual idx_t InBufferSize() = 0; + DUCKDB_API virtual idx_t OutBufferSize() = 0; +}; + +class CompressedFile : public FileHandle { +public: + DUCKDB_API CompressedFile(CompressedFileSystem &fs, unique_ptr child_handle_p, const string &path); + DUCKDB_API ~CompressedFile() override; + + CompressedFileSystem &compressed_fs; + unique_ptr child_handle; + //! Whether the file is opened for reading or for writing + bool write = false; + StreamData stream_data; + +public: + DUCKDB_API void Initialize(bool write); + DUCKDB_API int64_t ReadData(void *buffer, int64_t nr_bytes); + DUCKDB_API int64_t WriteData(data_ptr_t buffer, int64_t nr_bytes); + DUCKDB_API void Close() override; + +private: + unique_ptr stream_wrapper; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/constants.hpp b/src/duckdb/src/include/duckdb/common/constants.hpp new file mode 100644 index 00000000..03bded78 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/constants.hpp @@ -0,0 +1,105 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/constants.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include "duckdb/common/string.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/typedefs.hpp" + +namespace duckdb { +class Serializer; +class Deserializer; +class BinarySerializer; +class BinaryDeserializer; +class WriteStream; +class ReadStream; + +//! inline std directives that we use frequently +#ifndef DUCKDB_DEBUG_MOVE +using std::move; +#endif + +// NOTE: there is a copy of this in the Postgres' parser grammar (gram.y) +#define DEFAULT_SCHEMA "main" +#define INVALID_SCHEMA "" +#define INVALID_CATALOG "" +#define SYSTEM_CATALOG "system" +#define TEMP_CATALOG "temp" + +DUCKDB_API bool IsInvalidSchema(const string &str); +DUCKDB_API bool IsInvalidCatalog(const string &str); + +//! Special value used to signify the ROW ID of a table +DUCKDB_API extern const column_t COLUMN_IDENTIFIER_ROW_ID; +DUCKDB_API bool IsRowIdColumnId(column_t column_id); + +//! The maximum row identifier used in tables +extern const row_t MAX_ROW_ID; +//! Transaction-local row IDs start at MAX_ROW_ID +extern const row_t MAX_ROW_ID_LOCAL; + +extern const transaction_t TRANSACTION_ID_START; +extern const transaction_t MAX_TRANSACTION_ID; +extern const transaction_t MAXIMUM_QUERY_ID; +extern const transaction_t NOT_DELETED_ID; + +extern const double PI; + +struct DConstants { + //! The value used to signify an invalid index entry + static constexpr const idx_t INVALID_INDEX = idx_t(-1); +}; + +struct LogicalIndex { + explicit LogicalIndex(idx_t index) : index(index) { + } + + idx_t index; + + inline bool operator==(const LogicalIndex &rhs) const { + return index == rhs.index; + }; + inline bool operator!=(const LogicalIndex &rhs) const { + return index != rhs.index; + }; + inline bool operator<(const LogicalIndex &rhs) const { + return index < rhs.index; + }; + bool IsValid() { + return index != DConstants::INVALID_INDEX; + } +}; + +struct PhysicalIndex { + explicit PhysicalIndex(idx_t index) : index(index) { + } + + idx_t index; + + inline bool operator==(const PhysicalIndex &rhs) const { + return index == rhs.index; + }; + inline bool operator!=(const PhysicalIndex &rhs) const { + return index != rhs.index; + }; + inline bool operator<(const PhysicalIndex &rhs) const { + return index < rhs.index; + }; + bool IsValid() { + return index != DConstants::INVALID_INDEX; + } +}; + +DUCKDB_API bool IsPowerOfTwo(uint64_t v); +DUCKDB_API uint64_t NextPowerOfTwo(uint64_t v); +DUCKDB_API uint64_t PreviousPowerOfTwo(uint64_t v); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/crypto/md5.hpp b/src/duckdb/src/include/duckdb/common/crypto/md5.hpp new file mode 100644 index 00000000..856015c3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/crypto/md5.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/crypto/md5.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/string_type.hpp" + +namespace duckdb { + +class MD5Context { +public: + static constexpr idx_t MD5_HASH_LENGTH_BINARY = 16; + static constexpr idx_t MD5_HASH_LENGTH_TEXT = 32; + +public: + MD5Context(); + + void Add(const_data_ptr_t data, idx_t len) { + MD5Update(data, len); + } + void Add(const char *data); + void Add(string_t string) { + MD5Update(const_data_ptr_cast(string.GetData()), string.GetSize()); + } + void Add(const string &data) { + MD5Update(const_data_ptr_cast(data.c_str()), data.size()); + } + + //! Write the 16-byte (binary) digest to the specified location + void Finish(data_ptr_t out_digest); + //! Write the 32-character digest (in hexadecimal format) to the specified location + void FinishHex(char *out_digest); + //! Returns the 32-character digest (in hexadecimal format) as a string + string FinishHex(); + +private: + void MD5Update(const_data_ptr_t data, idx_t len); + static void DigestToBase16(const_data_ptr_t digest, char *zBuf); + + uint32_t buf[4]; + uint32_t bits[2]; + unsigned char in[64]; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/cycle_counter.hpp b/src/duckdb/src/include/duckdb/common/cycle_counter.hpp new file mode 100644 index 00000000..37cd67f3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/cycle_counter.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/cycle_counter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/helper.hpp" +#include "duckdb/common/chrono.hpp" + +namespace duckdb { + +//! The cycle counter can be used to measure elapsed cycles for a function, expression and ... +//! Optimized by sampling mechanism. Once per 100 times. +//! //Todo Can be optimized further by calling RDTSC once per sample +class CycleCounter { + friend struct ExpressionInfo; + friend struct ExpressionRootInfo; + static constexpr int SAMPLING_RATE = 50; + +public: + CycleCounter() { + } + // Next_sample determines if a sample needs to be taken, if so start the profiler + void BeginSample() { + if (current_count >= next_sample) { + tmp = Tick(); + } + } + + // End the sample + void EndSample(int chunk_size) { + if (current_count >= next_sample) { + time += Tick() - tmp; + } + if (current_count >= next_sample) { + next_sample = SAMPLING_RATE; + ++sample_count; + sample_tuples_count += chunk_size; + current_count = 0; + } else { + ++current_count; + } + tuples_count += chunk_size; + } + +private: + uint64_t Tick() const; + // current number on RDT register + uint64_t tmp; + // Elapsed cycles + uint64_t time = 0; + //! Count the number of time the executor called since last sampling + uint64_t current_count = 0; + //! Show the next sample + uint64_t next_sample = 0; + //! Count the number of samples + uint64_t sample_count = 0; + //! Count the number of tuples sampled + uint64_t sample_tuples_count = 0; + //! Count the number of ALL tuples + uint64_t tuples_count = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/deque.hpp b/src/duckdb/src/include/duckdb/common/deque.hpp new file mode 100644 index 00000000..f5c8ba99 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/deque.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/deque.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::deque; +} diff --git a/src/duckdb/src/include/duckdb/common/dl.hpp b/src/duckdb/src/include/duckdb/common/dl.hpp new file mode 100644 index 00000000..89494fc6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/dl.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/dl.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/windows.hpp" +#include "duckdb/common/local_file_system.hpp" +#include "duckdb/common/windows_util.hpp" + +#ifndef _WIN32 +#include +#else +#define RTLD_NOW 0 +#define RTLD_LOCAL 0 +#endif + +namespace duckdb { + +#ifdef _WIN32 + +inline void *dlopen(const char *file, int mode) { + D_ASSERT(file); + auto fpath = WindowsUtil::UTF8ToUnicode(file); + return (void *)LoadLibraryW(fpath.c_str()); +} + +inline void *dlsym(void *handle, const char *name) { + D_ASSERT(handle); + return (void *)GetProcAddress((HINSTANCE)handle, name); +} + +inline std::string GetDLError(void) { + return LocalFileSystem::GetLastErrorAsString(); +} + +#else + +inline std::string GetDLError(void) { + return dlerror(); +} + +#endif + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enum_class_hash.hpp b/src/duckdb/src/include/duckdb/common/enum_class_hash.hpp new file mode 100644 index 00000000..1de2c811 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enum_class_hash.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enum_class_hash.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +/* For compatibility with older C++ STL, an explicit hash class + is required for enums with C++ sets and maps */ +struct EnumClassHash { + template + std::size_t operator()(T t) const { + return static_cast(t); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp new file mode 100644 index 00000000..50cf642d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -0,0 +1,1070 @@ +//------------------------------------------------------------------------- +// This file is automatically generated by scripts/generate_enum_util.py +// Do not edit this file manually, your changes will be overwritten +// If you want to exclude an enum from serialization, add it to the blacklist in the script +// +// Note: The generated code will only work properly if the enum is a top level item in the duckdb namespace +// If the enum is nested in a class, or in another namespace, the generated code will not compile. +// You should move the enum to the duckdb namespace, manually write a specialization or add it to the blacklist +//------------------------------------------------------------------------- + + +#pragma once + +#include +#include "duckdb/common/string.hpp" + +namespace duckdb { + +struct EnumUtil { + // String -> Enum + template + static T FromString(const char *value) = delete; + + template + static T FromString(const string &value) { return FromString(value.c_str()); } + + // Enum -> String + template + static const char *ToChars(T value) = delete; + + template + static string ToString(T value) { return string(ToChars(value)); } +}; + +enum class AccessMode : uint8_t; + +enum class AggregateHandling : uint8_t; + +enum class AggregateOrderDependent : uint8_t; + +enum class AggregateType : uint8_t; + +enum class AlterForeignKeyType : uint8_t; + +enum class AlterScalarFunctionType : uint8_t; + +enum class AlterTableFunctionType : uint8_t; + +enum class AlterTableType : uint8_t; + +enum class AlterType : uint8_t; + +enum class AlterViewType : uint8_t; + +enum class AppenderType : uint8_t; + +enum class ArrowDateTimeType : uint8_t; + +enum class ArrowVariableSizeType : uint8_t; + +enum class BindingMode : uint8_t; + +enum class BitpackingMode : uint8_t; + +enum class BlockState : uint8_t; + +enum class CAPIResultSetType : uint8_t; + +enum class CSVState : uint8_t; + +enum class CTEMaterialize : uint8_t; + +enum class CatalogType : uint8_t; + +enum class CheckpointAbort : uint8_t; + +enum class ChunkInfoType : uint8_t; + +enum class ColumnDataAllocatorType : uint8_t; + +enum class ColumnDataScanProperties : uint8_t; + +enum class ColumnSegmentType : uint8_t; + +enum class CompressedMaterializationDirection : uint8_t; + +enum class CompressionType : uint8_t; + +enum class ConflictManagerMode : uint8_t; + +enum class ConstraintType : uint8_t; + +enum class DataFileType : uint8_t; + +enum class DatePartSpecifier : uint8_t; + +enum class DebugInitialize : uint8_t; + +enum class DefaultOrderByNullType : uint8_t; + +enum class DistinctType : uint8_t; + +enum class ErrorType : uint16_t; + +enum class ExceptionFormatValueType : uint8_t; + +enum class ExplainOutputType : uint8_t; + +enum class ExplainType : uint8_t; + +enum class ExpressionClass : uint8_t; + +enum class ExpressionType : uint8_t; + +enum class ExtensionLoadResult : uint8_t; + +enum class ExtraTypeInfoType : uint8_t; + +enum class FileBufferType : uint8_t; + +enum class FileCompressionType : uint8_t; + +enum class FileGlobOptions : uint8_t; + +enum class FileLockType : uint8_t; + +enum class FilterPropagateResult : uint8_t; + +enum class ForeignKeyType : uint8_t; + +enum class FunctionNullHandling : uint8_t; + +enum class FunctionSideEffects : uint8_t; + +enum class HLLStorageType : uint8_t; + +enum class IndexConstraintType : uint8_t; + +enum class IndexType : uint8_t; + +enum class InsertColumnOrder : uint8_t; + +enum class InterruptMode : uint8_t; + +enum class JoinRefType : uint8_t; + +enum class JoinType : uint8_t; + +enum class KeywordCategory : uint8_t; + +enum class LoadType : uint8_t; + +enum class LogicalOperatorType : uint8_t; + +enum class LogicalTypeId : uint8_t; + +enum class LookupResultType : uint8_t; + +enum class MacroType : uint8_t; + +enum class MapInvalidReason : uint8_t; + +enum class NType : uint8_t; + +enum class NewLineIdentifier : uint8_t; + +enum class OnConflictAction : uint8_t; + +enum class OnCreateConflict : uint8_t; + +enum class OnEntryNotFound : uint8_t; + +enum class OperatorFinalizeResultType : uint8_t; + +enum class OperatorResultType : uint8_t; + +enum class OptimizerType : uint32_t; + +enum class OrderByNullType : uint8_t; + +enum class OrderPreservationType : uint8_t; + +enum class OrderType : uint8_t; + +enum class OutputStream : uint8_t; + +enum class ParseInfoType : uint8_t; + +enum class ParserExtensionResultType : uint8_t; + +enum class ParserMode : uint8_t; + +enum class PartitionSortStage : uint8_t; + +enum class PartitionedColumnDataType : uint8_t; + +enum class PartitionedTupleDataType : uint8_t; + +enum class PendingExecutionResult : uint8_t; + +enum class PhysicalOperatorType : uint8_t; + +enum class PhysicalType : uint8_t; + +enum class PragmaType : uint8_t; + +enum class PreparedParamType : uint8_t; + +enum class ProfilerPrintFormat : uint8_t; + +enum class QueryNodeType : uint8_t; + +enum class QueryResultType : uint8_t; + +enum class QuoteRule : uint8_t; + +enum class RelationType : uint8_t; + +enum class RenderMode : uint8_t; + +enum class ResultModifierType : uint8_t; + +enum class SampleMethod : uint8_t; + +enum class SequenceInfo : uint8_t; + +enum class SetOperationType : uint8_t; + +enum class SetScope : uint8_t; + +enum class SetType : uint8_t; + +enum class SimplifiedTokenType : uint8_t; + +enum class SinkCombineResultType : uint8_t; + +enum class SinkFinalizeType : uint8_t; + +enum class SinkResultType : uint8_t; + +enum class SourceResultType : uint8_t; + +enum class StatementReturnType : uint8_t; + +enum class StatementType : uint8_t; + +enum class StatisticsType : uint8_t; + +enum class StatsInfo : uint8_t; + +enum class StrTimeSpecifier : uint8_t; + +enum class SubqueryType : uint8_t; + +enum class TableColumnType : uint8_t; + +enum class TableFilterType : uint8_t; + +enum class TableReferenceType : uint8_t; + +enum class TableScanType : uint8_t; + +enum class TaskExecutionMode : uint8_t; + +enum class TaskExecutionResult : uint8_t; + +enum class TimestampCastResult : uint8_t; + +enum class TransactionType : uint8_t; + +enum class TupleDataPinProperties : uint8_t; + +enum class UndoFlags : uint32_t; + +enum class UnionInvalidReason : uint8_t; + +enum class VectorAuxiliaryDataType : uint8_t; + +enum class VectorBufferType : uint8_t; + +enum class VectorType : uint8_t; + +enum class VerificationType : uint8_t; + +enum class VerifyExistenceType : uint8_t; + +enum class WALType : uint8_t; + +enum class WindowAggregationMode : uint32_t; + +enum class WindowBoundary : uint8_t; + + +template<> +const char* EnumUtil::ToChars(AccessMode value); + +template<> +const char* EnumUtil::ToChars(AggregateHandling value); + +template<> +const char* EnumUtil::ToChars(AggregateOrderDependent value); + +template<> +const char* EnumUtil::ToChars(AggregateType value); + +template<> +const char* EnumUtil::ToChars(AlterForeignKeyType value); + +template<> +const char* EnumUtil::ToChars(AlterScalarFunctionType value); + +template<> +const char* EnumUtil::ToChars(AlterTableFunctionType value); + +template<> +const char* EnumUtil::ToChars(AlterTableType value); + +template<> +const char* EnumUtil::ToChars(AlterType value); + +template<> +const char* EnumUtil::ToChars(AlterViewType value); + +template<> +const char* EnumUtil::ToChars(AppenderType value); + +template<> +const char* EnumUtil::ToChars(ArrowDateTimeType value); + +template<> +const char* EnumUtil::ToChars(ArrowVariableSizeType value); + +template<> +const char* EnumUtil::ToChars(BindingMode value); + +template<> +const char* EnumUtil::ToChars(BitpackingMode value); + +template<> +const char* EnumUtil::ToChars(BlockState value); + +template<> +const char* EnumUtil::ToChars(CAPIResultSetType value); + +template<> +const char* EnumUtil::ToChars(CSVState value); + +template<> +const char* EnumUtil::ToChars(CTEMaterialize value); + +template<> +const char* EnumUtil::ToChars(CatalogType value); + +template<> +const char* EnumUtil::ToChars(CheckpointAbort value); + +template<> +const char* EnumUtil::ToChars(ChunkInfoType value); + +template<> +const char* EnumUtil::ToChars(ColumnDataAllocatorType value); + +template<> +const char* EnumUtil::ToChars(ColumnDataScanProperties value); + +template<> +const char* EnumUtil::ToChars(ColumnSegmentType value); + +template<> +const char* EnumUtil::ToChars(CompressedMaterializationDirection value); + +template<> +const char* EnumUtil::ToChars(CompressionType value); + +template<> +const char* EnumUtil::ToChars(ConflictManagerMode value); + +template<> +const char* EnumUtil::ToChars(ConstraintType value); + +template<> +const char* EnumUtil::ToChars(DataFileType value); + +template<> +const char* EnumUtil::ToChars(DatePartSpecifier value); + +template<> +const char* EnumUtil::ToChars(DebugInitialize value); + +template<> +const char* EnumUtil::ToChars(DefaultOrderByNullType value); + +template<> +const char* EnumUtil::ToChars(DistinctType value); + +template<> +const char* EnumUtil::ToChars(ErrorType value); + +template<> +const char* EnumUtil::ToChars(ExceptionFormatValueType value); + +template<> +const char* EnumUtil::ToChars(ExplainOutputType value); + +template<> +const char* EnumUtil::ToChars(ExplainType value); + +template<> +const char* EnumUtil::ToChars(ExpressionClass value); + +template<> +const char* EnumUtil::ToChars(ExpressionType value); + +template<> +const char* EnumUtil::ToChars(ExtensionLoadResult value); + +template<> +const char* EnumUtil::ToChars(ExtraTypeInfoType value); + +template<> +const char* EnumUtil::ToChars(FileBufferType value); + +template<> +const char* EnumUtil::ToChars(FileCompressionType value); + +template<> +const char* EnumUtil::ToChars(FileGlobOptions value); + +template<> +const char* EnumUtil::ToChars(FileLockType value); + +template<> +const char* EnumUtil::ToChars(FilterPropagateResult value); + +template<> +const char* EnumUtil::ToChars(ForeignKeyType value); + +template<> +const char* EnumUtil::ToChars(FunctionNullHandling value); + +template<> +const char* EnumUtil::ToChars(FunctionSideEffects value); + +template<> +const char* EnumUtil::ToChars(HLLStorageType value); + +template<> +const char* EnumUtil::ToChars(IndexConstraintType value); + +template<> +const char* EnumUtil::ToChars(IndexType value); + +template<> +const char* EnumUtil::ToChars(InsertColumnOrder value); + +template<> +const char* EnumUtil::ToChars(InterruptMode value); + +template<> +const char* EnumUtil::ToChars(JoinRefType value); + +template<> +const char* EnumUtil::ToChars(JoinType value); + +template<> +const char* EnumUtil::ToChars(KeywordCategory value); + +template<> +const char* EnumUtil::ToChars(LoadType value); + +template<> +const char* EnumUtil::ToChars(LogicalOperatorType value); + +template<> +const char* EnumUtil::ToChars(LogicalTypeId value); + +template<> +const char* EnumUtil::ToChars(LookupResultType value); + +template<> +const char* EnumUtil::ToChars(MacroType value); + +template<> +const char* EnumUtil::ToChars(MapInvalidReason value); + +template<> +const char* EnumUtil::ToChars(NType value); + +template<> +const char* EnumUtil::ToChars(NewLineIdentifier value); + +template<> +const char* EnumUtil::ToChars(OnConflictAction value); + +template<> +const char* EnumUtil::ToChars(OnCreateConflict value); + +template<> +const char* EnumUtil::ToChars(OnEntryNotFound value); + +template<> +const char* EnumUtil::ToChars(OperatorFinalizeResultType value); + +template<> +const char* EnumUtil::ToChars(OperatorResultType value); + +template<> +const char* EnumUtil::ToChars(OptimizerType value); + +template<> +const char* EnumUtil::ToChars(OrderByNullType value); + +template<> +const char* EnumUtil::ToChars(OrderPreservationType value); + +template<> +const char* EnumUtil::ToChars(OrderType value); + +template<> +const char* EnumUtil::ToChars(OutputStream value); + +template<> +const char* EnumUtil::ToChars(ParseInfoType value); + +template<> +const char* EnumUtil::ToChars(ParserExtensionResultType value); + +template<> +const char* EnumUtil::ToChars(ParserMode value); + +template<> +const char* EnumUtil::ToChars(PartitionSortStage value); + +template<> +const char* EnumUtil::ToChars(PartitionedColumnDataType value); + +template<> +const char* EnumUtil::ToChars(PartitionedTupleDataType value); + +template<> +const char* EnumUtil::ToChars(PendingExecutionResult value); + +template<> +const char* EnumUtil::ToChars(PhysicalOperatorType value); + +template<> +const char* EnumUtil::ToChars(PhysicalType value); + +template<> +const char* EnumUtil::ToChars(PragmaType value); + +template<> +const char* EnumUtil::ToChars(PreparedParamType value); + +template<> +const char* EnumUtil::ToChars(ProfilerPrintFormat value); + +template<> +const char* EnumUtil::ToChars(QueryNodeType value); + +template<> +const char* EnumUtil::ToChars(QueryResultType value); + +template<> +const char* EnumUtil::ToChars(QuoteRule value); + +template<> +const char* EnumUtil::ToChars(RelationType value); + +template<> +const char* EnumUtil::ToChars(RenderMode value); + +template<> +const char* EnumUtil::ToChars(ResultModifierType value); + +template<> +const char* EnumUtil::ToChars(SampleMethod value); + +template<> +const char* EnumUtil::ToChars(SequenceInfo value); + +template<> +const char* EnumUtil::ToChars(SetOperationType value); + +template<> +const char* EnumUtil::ToChars(SetScope value); + +template<> +const char* EnumUtil::ToChars(SetType value); + +template<> +const char* EnumUtil::ToChars(SimplifiedTokenType value); + +template<> +const char* EnumUtil::ToChars(SinkCombineResultType value); + +template<> +const char* EnumUtil::ToChars(SinkFinalizeType value); + +template<> +const char* EnumUtil::ToChars(SinkResultType value); + +template<> +const char* EnumUtil::ToChars(SourceResultType value); + +template<> +const char* EnumUtil::ToChars(StatementReturnType value); + +template<> +const char* EnumUtil::ToChars(StatementType value); + +template<> +const char* EnumUtil::ToChars(StatisticsType value); + +template<> +const char* EnumUtil::ToChars(StatsInfo value); + +template<> +const char* EnumUtil::ToChars(StrTimeSpecifier value); + +template<> +const char* EnumUtil::ToChars(SubqueryType value); + +template<> +const char* EnumUtil::ToChars(TableColumnType value); + +template<> +const char* EnumUtil::ToChars(TableFilterType value); + +template<> +const char* EnumUtil::ToChars(TableReferenceType value); + +template<> +const char* EnumUtil::ToChars(TableScanType value); + +template<> +const char* EnumUtil::ToChars(TaskExecutionMode value); + +template<> +const char* EnumUtil::ToChars(TaskExecutionResult value); + +template<> +const char* EnumUtil::ToChars(TimestampCastResult value); + +template<> +const char* EnumUtil::ToChars(TransactionType value); + +template<> +const char* EnumUtil::ToChars(TupleDataPinProperties value); + +template<> +const char* EnumUtil::ToChars(UndoFlags value); + +template<> +const char* EnumUtil::ToChars(UnionInvalidReason value); + +template<> +const char* EnumUtil::ToChars(VectorAuxiliaryDataType value); + +template<> +const char* EnumUtil::ToChars(VectorBufferType value); + +template<> +const char* EnumUtil::ToChars(VectorType value); + +template<> +const char* EnumUtil::ToChars(VerificationType value); + +template<> +const char* EnumUtil::ToChars(VerifyExistenceType value); + +template<> +const char* EnumUtil::ToChars(WALType value); + +template<> +const char* EnumUtil::ToChars(WindowAggregationMode value); + +template<> +const char* EnumUtil::ToChars(WindowBoundary value); + + +template<> +AccessMode EnumUtil::FromString(const char *value); + +template<> +AggregateHandling EnumUtil::FromString(const char *value); + +template<> +AggregateOrderDependent EnumUtil::FromString(const char *value); + +template<> +AggregateType EnumUtil::FromString(const char *value); + +template<> +AlterForeignKeyType EnumUtil::FromString(const char *value); + +template<> +AlterScalarFunctionType EnumUtil::FromString(const char *value); + +template<> +AlterTableFunctionType EnumUtil::FromString(const char *value); + +template<> +AlterTableType EnumUtil::FromString(const char *value); + +template<> +AlterType EnumUtil::FromString(const char *value); + +template<> +AlterViewType EnumUtil::FromString(const char *value); + +template<> +AppenderType EnumUtil::FromString(const char *value); + +template<> +ArrowDateTimeType EnumUtil::FromString(const char *value); + +template<> +ArrowVariableSizeType EnumUtil::FromString(const char *value); + +template<> +BindingMode EnumUtil::FromString(const char *value); + +template<> +BitpackingMode EnumUtil::FromString(const char *value); + +template<> +BlockState EnumUtil::FromString(const char *value); + +template<> +CAPIResultSetType EnumUtil::FromString(const char *value); + +template<> +CSVState EnumUtil::FromString(const char *value); + +template<> +CTEMaterialize EnumUtil::FromString(const char *value); + +template<> +CatalogType EnumUtil::FromString(const char *value); + +template<> +CheckpointAbort EnumUtil::FromString(const char *value); + +template<> +ChunkInfoType EnumUtil::FromString(const char *value); + +template<> +ColumnDataAllocatorType EnumUtil::FromString(const char *value); + +template<> +ColumnDataScanProperties EnumUtil::FromString(const char *value); + +template<> +ColumnSegmentType EnumUtil::FromString(const char *value); + +template<> +CompressedMaterializationDirection EnumUtil::FromString(const char *value); + +template<> +CompressionType EnumUtil::FromString(const char *value); + +template<> +ConflictManagerMode EnumUtil::FromString(const char *value); + +template<> +ConstraintType EnumUtil::FromString(const char *value); + +template<> +DataFileType EnumUtil::FromString(const char *value); + +template<> +DatePartSpecifier EnumUtil::FromString(const char *value); + +template<> +DebugInitialize EnumUtil::FromString(const char *value); + +template<> +DefaultOrderByNullType EnumUtil::FromString(const char *value); + +template<> +DistinctType EnumUtil::FromString(const char *value); + +template<> +ErrorType EnumUtil::FromString(const char *value); + +template<> +ExceptionFormatValueType EnumUtil::FromString(const char *value); + +template<> +ExplainOutputType EnumUtil::FromString(const char *value); + +template<> +ExplainType EnumUtil::FromString(const char *value); + +template<> +ExpressionClass EnumUtil::FromString(const char *value); + +template<> +ExpressionType EnumUtil::FromString(const char *value); + +template<> +ExtensionLoadResult EnumUtil::FromString(const char *value); + +template<> +ExtraTypeInfoType EnumUtil::FromString(const char *value); + +template<> +FileBufferType EnumUtil::FromString(const char *value); + +template<> +FileCompressionType EnumUtil::FromString(const char *value); + +template<> +FileGlobOptions EnumUtil::FromString(const char *value); + +template<> +FileLockType EnumUtil::FromString(const char *value); + +template<> +FilterPropagateResult EnumUtil::FromString(const char *value); + +template<> +ForeignKeyType EnumUtil::FromString(const char *value); + +template<> +FunctionNullHandling EnumUtil::FromString(const char *value); + +template<> +FunctionSideEffects EnumUtil::FromString(const char *value); + +template<> +HLLStorageType EnumUtil::FromString(const char *value); + +template<> +IndexConstraintType EnumUtil::FromString(const char *value); + +template<> +IndexType EnumUtil::FromString(const char *value); + +template<> +InsertColumnOrder EnumUtil::FromString(const char *value); + +template<> +InterruptMode EnumUtil::FromString(const char *value); + +template<> +JoinRefType EnumUtil::FromString(const char *value); + +template<> +JoinType EnumUtil::FromString(const char *value); + +template<> +KeywordCategory EnumUtil::FromString(const char *value); + +template<> +LoadType EnumUtil::FromString(const char *value); + +template<> +LogicalOperatorType EnumUtil::FromString(const char *value); + +template<> +LogicalTypeId EnumUtil::FromString(const char *value); + +template<> +LookupResultType EnumUtil::FromString(const char *value); + +template<> +MacroType EnumUtil::FromString(const char *value); + +template<> +MapInvalidReason EnumUtil::FromString(const char *value); + +template<> +NType EnumUtil::FromString(const char *value); + +template<> +NewLineIdentifier EnumUtil::FromString(const char *value); + +template<> +OnConflictAction EnumUtil::FromString(const char *value); + +template<> +OnCreateConflict EnumUtil::FromString(const char *value); + +template<> +OnEntryNotFound EnumUtil::FromString(const char *value); + +template<> +OperatorFinalizeResultType EnumUtil::FromString(const char *value); + +template<> +OperatorResultType EnumUtil::FromString(const char *value); + +template<> +OptimizerType EnumUtil::FromString(const char *value); + +template<> +OrderByNullType EnumUtil::FromString(const char *value); + +template<> +OrderPreservationType EnumUtil::FromString(const char *value); + +template<> +OrderType EnumUtil::FromString(const char *value); + +template<> +OutputStream EnumUtil::FromString(const char *value); + +template<> +ParseInfoType EnumUtil::FromString(const char *value); + +template<> +ParserExtensionResultType EnumUtil::FromString(const char *value); + +template<> +ParserMode EnumUtil::FromString(const char *value); + +template<> +PartitionSortStage EnumUtil::FromString(const char *value); + +template<> +PartitionedColumnDataType EnumUtil::FromString(const char *value); + +template<> +PartitionedTupleDataType EnumUtil::FromString(const char *value); + +template<> +PendingExecutionResult EnumUtil::FromString(const char *value); + +template<> +PhysicalOperatorType EnumUtil::FromString(const char *value); + +template<> +PhysicalType EnumUtil::FromString(const char *value); + +template<> +PragmaType EnumUtil::FromString(const char *value); + +template<> +PreparedParamType EnumUtil::FromString(const char *value); + +template<> +ProfilerPrintFormat EnumUtil::FromString(const char *value); + +template<> +QueryNodeType EnumUtil::FromString(const char *value); + +template<> +QueryResultType EnumUtil::FromString(const char *value); + +template<> +QuoteRule EnumUtil::FromString(const char *value); + +template<> +RelationType EnumUtil::FromString(const char *value); + +template<> +RenderMode EnumUtil::FromString(const char *value); + +template<> +ResultModifierType EnumUtil::FromString(const char *value); + +template<> +SampleMethod EnumUtil::FromString(const char *value); + +template<> +SequenceInfo EnumUtil::FromString(const char *value); + +template<> +SetOperationType EnumUtil::FromString(const char *value); + +template<> +SetScope EnumUtil::FromString(const char *value); + +template<> +SetType EnumUtil::FromString(const char *value); + +template<> +SimplifiedTokenType EnumUtil::FromString(const char *value); + +template<> +SinkCombineResultType EnumUtil::FromString(const char *value); + +template<> +SinkFinalizeType EnumUtil::FromString(const char *value); + +template<> +SinkResultType EnumUtil::FromString(const char *value); + +template<> +SourceResultType EnumUtil::FromString(const char *value); + +template<> +StatementReturnType EnumUtil::FromString(const char *value); + +template<> +StatementType EnumUtil::FromString(const char *value); + +template<> +StatisticsType EnumUtil::FromString(const char *value); + +template<> +StatsInfo EnumUtil::FromString(const char *value); + +template<> +StrTimeSpecifier EnumUtil::FromString(const char *value); + +template<> +SubqueryType EnumUtil::FromString(const char *value); + +template<> +TableColumnType EnumUtil::FromString(const char *value); + +template<> +TableFilterType EnumUtil::FromString(const char *value); + +template<> +TableReferenceType EnumUtil::FromString(const char *value); + +template<> +TableScanType EnumUtil::FromString(const char *value); + +template<> +TaskExecutionMode EnumUtil::FromString(const char *value); + +template<> +TaskExecutionResult EnumUtil::FromString(const char *value); + +template<> +TimestampCastResult EnumUtil::FromString(const char *value); + +template<> +TransactionType EnumUtil::FromString(const char *value); + +template<> +TupleDataPinProperties EnumUtil::FromString(const char *value); + +template<> +UndoFlags EnumUtil::FromString(const char *value); + +template<> +UnionInvalidReason EnumUtil::FromString(const char *value); + +template<> +VectorAuxiliaryDataType EnumUtil::FromString(const char *value); + +template<> +VectorBufferType EnumUtil::FromString(const char *value); + +template<> +VectorType EnumUtil::FromString(const char *value); + +template<> +VerificationType EnumUtil::FromString(const char *value); + +template<> +VerifyExistenceType EnumUtil::FromString(const char *value); + +template<> +WALType EnumUtil::FromString(const char *value); + +template<> +WindowAggregationMode EnumUtil::FromString(const char *value); + +template<> +WindowBoundary EnumUtil::FromString(const char *value); + + +} diff --git a/src/duckdb/src/include/duckdb/common/enums/access_mode.hpp b/src/duckdb/src/include/duckdb/common/enums/access_mode.hpp new file mode 100644 index 00000000..bc1c6170 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/access_mode.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/access_mode.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class AccessMode : uint8_t { UNDEFINED = 0, AUTOMATIC = 1, READ_ONLY = 2, READ_WRITE = 3 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/aggregate_handling.hpp b/src/duckdb/src/include/duckdb/common/enums/aggregate_handling.hpp new file mode 100644 index 00000000..f11aec03 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/aggregate_handling.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/aggregate_handling.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===---- +enum class AggregateHandling : uint8_t { + STANDARD_HANDLING, // standard handling as in the SELECT clause + NO_AGGREGATES_ALLOWED, // no aggregates allowed: any aggregates in this node will result in an error + FORCE_AGGREGATES // force aggregates: any non-aggregate select list entry will become a GROUP +}; + +const char *ToString(AggregateHandling value); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/catalog_type.hpp b/src/duckdb/src/include/duckdb/common/enums/catalog_type.hpp new file mode 100644 index 00000000..2b9faa9c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/catalog_type.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/catalog_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Catalog Types +//===--------------------------------------------------------------------===// +enum class CatalogType : uint8_t { + INVALID = 0, + TABLE_ENTRY = 1, + SCHEMA_ENTRY = 2, + VIEW_ENTRY = 3, + INDEX_ENTRY = 4, + PREPARED_STATEMENT = 5, + SEQUENCE_ENTRY = 6, + COLLATION_ENTRY = 7, + TYPE_ENTRY = 8, + DATABASE_ENTRY = 9, + + // functions + TABLE_FUNCTION_ENTRY = 25, + SCALAR_FUNCTION_ENTRY = 26, + AGGREGATE_FUNCTION_ENTRY = 27, + PRAGMA_FUNCTION_ENTRY = 28, + COPY_FUNCTION_ENTRY = 29, + MACRO_ENTRY = 30, + TABLE_MACRO_ENTRY = 31, + + // version info + UPDATED_ENTRY = 50, + DELETED_ENTRY = 51, +}; + +DUCKDB_API string CatalogTypeToString(CatalogType type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp b/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp new file mode 100644 index 00000000..e884d52a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/compression_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +enum class CompressionType : uint8_t { + COMPRESSION_AUTO = 0, + COMPRESSION_UNCOMPRESSED = 1, + COMPRESSION_CONSTANT = 2, + COMPRESSION_RLE = 3, + COMPRESSION_DICTIONARY = 4, + COMPRESSION_PFOR_DELTA = 5, + COMPRESSION_BITPACKING = 6, + COMPRESSION_FSST = 7, + COMPRESSION_CHIMP = 8, + COMPRESSION_PATAS = 9, + COMPRESSION_COUNT // This has to stay the last entry of the type! +}; + +vector ListCompressionTypes(void); +CompressionType CompressionTypeFromString(const string &str); +string CompressionTypeToString(CompressionType type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/cte_materialize.hpp b/src/duckdb/src/include/duckdb/common/enums/cte_materialize.hpp new file mode 100644 index 00000000..356d298a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/cte_materialize.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/cte_materialize.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class CTEMaterialize : uint8_t { + CTE_MATERIALIZE_DEFAULT = 1, /* no option specified */ + CTE_MATERIALIZE_ALWAYS = 2, /* MATERIALIZED */ + CTE_MATERIALIZE_NEVER = 3 /* NOT MATERIALIZED */ +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/date_part_specifier.hpp b/src/duckdb/src/include/duckdb/common/enums/date_part_specifier.hpp new file mode 100644 index 00000000..54e2790e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/date_part_specifier.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/date_part_specifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class DatePartSpecifier : uint8_t { + // BIGINT values + YEAR, + MONTH, + DAY, + DECADE, + CENTURY, + MILLENNIUM, + MICROSECONDS, + MILLISECONDS, + SECOND, + MINUTE, + HOUR, + DOW, + ISODOW, + WEEK, + ISOYEAR, + QUARTER, + DOY, + YEARWEEK, + ERA, + TIMEZONE, + TIMEZONE_HOUR, + TIMEZONE_MINUTE, + + // DOUBLE values + EPOCH, + JULIAN_DAY, + + // Invalid + INVALID, + + // Type ranges + BEGIN_BIGINT = YEAR, + BEGIN_DOUBLE = EPOCH, + BEGIN_INVALID = INVALID, +}; + +inline bool IsBigintDatepart(DatePartSpecifier part_code) { + return size_t(part_code) < size_t(DatePartSpecifier::BEGIN_DOUBLE); +} + +DUCKDB_API bool TryGetDatePartSpecifier(const string &specifier, DatePartSpecifier &result); +DUCKDB_API DatePartSpecifier GetDatePartSpecifier(const string &specifier); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/debug_initialize.hpp b/src/duckdb/src/include/duckdb/common/enums/debug_initialize.hpp new file mode 100644 index 00000000..5cf794e0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/debug_initialize.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/debug_initialize.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class DebugInitialize : uint8_t { NO_INITIALIZE = 0, DEBUG_ZERO_INITIALIZE = 1, DEBUG_ONE_INITIALIZE = 2 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/expression_type.hpp b/src/duckdb/src/include/duckdb/common/enums/expression_type.hpp new file mode 100644 index 00000000..e404ef4a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/expression_type.hpp @@ -0,0 +1,215 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/expression_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Predicate Expression Operation Types +//===--------------------------------------------------------------------===// +enum class ExpressionType : uint8_t { + INVALID = 0, + + // explicitly cast left as right (right is integer in ValueType enum) + OPERATOR_CAST = 12, + // logical not operator + OPERATOR_NOT = 13, + // is null operator + OPERATOR_IS_NULL = 14, + // is not null operator + OPERATOR_IS_NOT_NULL = 15, + + // ----------------------------- + // Comparison Operators + // ----------------------------- + // equal operator between left and right + COMPARE_EQUAL = 25, + // compare initial boundary + COMPARE_BOUNDARY_START = COMPARE_EQUAL, + // inequal operator between left and right + COMPARE_NOTEQUAL = 26, + // less than operator between left and right + COMPARE_LESSTHAN = 27, + // greater than operator between left and right + COMPARE_GREATERTHAN = 28, + // less than equal operator between left and right + COMPARE_LESSTHANOREQUALTO = 29, + // greater than equal operator between left and right + COMPARE_GREATERTHANOREQUALTO = 30, + // IN operator [left IN (right1, right2, ...)] + COMPARE_IN = 35, + // NOT IN operator [left NOT IN (right1, right2, ...)] + COMPARE_NOT_IN = 36, + // IS DISTINCT FROM operator + COMPARE_DISTINCT_FROM = 37, + + COMPARE_BETWEEN = 38, + COMPARE_NOT_BETWEEN = 39, + // IS NOT DISTINCT FROM operator + COMPARE_NOT_DISTINCT_FROM = 40, + // compare final boundary + COMPARE_BOUNDARY_END = COMPARE_NOT_DISTINCT_FROM, + + // ----------------------------- + // Conjunction Operators + // ----------------------------- + CONJUNCTION_AND = 50, + CONJUNCTION_OR = 51, + + // ----------------------------- + // Values + // ----------------------------- + VALUE_CONSTANT = 75, + VALUE_PARAMETER = 76, + VALUE_TUPLE = 77, + VALUE_TUPLE_ADDRESS = 78, + VALUE_NULL = 79, + VALUE_VECTOR = 80, + VALUE_SCALAR = 81, + VALUE_DEFAULT = 82, + + // ----------------------------- + // Aggregates + // ----------------------------- + AGGREGATE = 100, + BOUND_AGGREGATE = 101, + GROUPING_FUNCTION = 102, + + // ----------------------------- + // Window Functions + // ----------------------------- + WINDOW_AGGREGATE = 110, + + WINDOW_RANK = 120, + WINDOW_RANK_DENSE = 121, + WINDOW_NTILE = 122, + WINDOW_PERCENT_RANK = 123, + WINDOW_CUME_DIST = 124, + WINDOW_ROW_NUMBER = 125, + + WINDOW_FIRST_VALUE = 130, + WINDOW_LAST_VALUE = 131, + WINDOW_LEAD = 132, + WINDOW_LAG = 133, + WINDOW_NTH_VALUE = 134, + + // ----------------------------- + // Functions + // ----------------------------- + FUNCTION = 140, + BOUND_FUNCTION = 141, + + // ----------------------------- + // Operators + // ----------------------------- + CASE_EXPR = 150, + OPERATOR_NULLIF = 151, + OPERATOR_COALESCE = 152, + ARRAY_EXTRACT = 153, + ARRAY_SLICE = 154, + STRUCT_EXTRACT = 155, + ARRAY_CONSTRUCTOR = 156, + ARROW = 157, + + // ----------------------------- + // Subquery IN/EXISTS + // ----------------------------- + SUBQUERY = 175, + + // ----------------------------- + // Parser + // ----------------------------- + STAR = 200, + TABLE_STAR = 201, + PLACEHOLDER = 202, + COLUMN_REF = 203, + FUNCTION_REF = 204, + TABLE_REF = 205, + + // ----------------------------- + // Miscellaneous + // ----------------------------- + CAST = 225, + BOUND_REF = 227, + BOUND_COLUMN_REF = 228, + BOUND_UNNEST = 229, + COLLATE = 230, + LAMBDA = 231, + POSITIONAL_REFERENCE = 232, + BOUND_LAMBDA_REF = 233 +}; + +//===--------------------------------------------------------------------===// +// Expression Class +//===--------------------------------------------------------------------===// +enum class ExpressionClass : uint8_t { + INVALID = 0, + //===--------------------------------------------------------------------===// + // Parsed Expressions + //===--------------------------------------------------------------------===// + AGGREGATE = 1, + CASE = 2, + CAST = 3, + COLUMN_REF = 4, + COMPARISON = 5, + CONJUNCTION = 6, + CONSTANT = 7, + DEFAULT = 8, + FUNCTION = 9, + OPERATOR = 10, + STAR = 11, + SUBQUERY = 13, + WINDOW = 14, + PARAMETER = 15, + COLLATE = 16, + LAMBDA = 17, + POSITIONAL_REFERENCE = 18, + BETWEEN = 19, + //===--------------------------------------------------------------------===// + // Bound Expressions + //===--------------------------------------------------------------------===// + BOUND_AGGREGATE = 25, + BOUND_CASE = 26, + BOUND_CAST = 27, + BOUND_COLUMN_REF = 28, + BOUND_COMPARISON = 29, + BOUND_CONJUNCTION = 30, + BOUND_CONSTANT = 31, + BOUND_DEFAULT = 32, + BOUND_FUNCTION = 33, + BOUND_OPERATOR = 34, + BOUND_PARAMETER = 35, + BOUND_REF = 36, + BOUND_SUBQUERY = 37, + BOUND_WINDOW = 38, + BOUND_BETWEEN = 39, + BOUND_UNNEST = 40, + BOUND_LAMBDA = 41, + BOUND_LAMBDA_REF = 42, + //===--------------------------------------------------------------------===// + // Miscellaneous + //===--------------------------------------------------------------------===// + BOUND_EXPRESSION = 50 +}; + +DUCKDB_API string ExpressionTypeToString(ExpressionType type); +string ExpressionTypeToOperator(ExpressionType type); + +// Operator String to ExpressionType (e.g. + => OPERATOR_ADD) +ExpressionType OperatorToExpressionType(const string &op); +//! Negate a comparison expression, turning e.g. = into !=, or < into >= +ExpressionType NegateComparisonExpression(ExpressionType type); +//! Flip a comparison expression, turning e.g. < into >, or = into = +ExpressionType FlipComparisonExpression(ExpressionType type); + +DUCKDB_API string ExpressionClassToString(ExpressionClass type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/file_compression_type.hpp b/src/duckdb/src/include/duckdb/common/enums/file_compression_type.hpp new file mode 100644 index 00000000..98fe5e75 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/file_compression_type.hpp @@ -0,0 +1,19 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/file_compression_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class FileCompressionType : uint8_t { AUTO_DETECT = 0, UNCOMPRESSED = 1, GZIP = 2, ZSTD = 3 }; + +FileCompressionType FileCompressionTypeFromString(const string &input); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/file_glob_options.hpp b/src/duckdb/src/include/duckdb/common/enums/file_glob_options.hpp new file mode 100644 index 00000000..94f528d5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/file_glob_options.hpp @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/file_glob_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class FileGlobOptions : uint8_t { + DISALLOW_EMPTY = 0, + ALLOW_EMPTY = 1, +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/filter_propagate_result.hpp b/src/duckdb/src/include/duckdb/common/enums/filter_propagate_result.hpp new file mode 100644 index 00000000..edf239eb --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/filter_propagate_result.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/filter_propagate_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class FilterPropagateResult : uint8_t { + NO_PRUNING_POSSIBLE = 0, + FILTER_ALWAYS_TRUE = 1, + FILTER_ALWAYS_FALSE = 2, + FILTER_TRUE_OR_NULL = 3, + FILTER_FALSE_OR_NULL = 4 +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/index_type.hpp b/src/duckdb/src/include/duckdb/common/enums/index_type.hpp new file mode 100644 index 00000000..a2cbd3f0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/index_type.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/index_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Index Types +//===--------------------------------------------------------------------===// +enum class IndexType : uint8_t { + INVALID = 0, // invalid index type + ART = 1, // Adaptive Radix Tree + EXTENSION = 100 // Extension index +}; + +//===--------------------------------------------------------------------===// +// Index Constraint Types +//===--------------------------------------------------------------------===// +enum class IndexConstraintType : uint8_t { + NONE = 0, // index is an index don't built to any constraint + UNIQUE = 1, // index is an index built to enforce a UNIQUE constraint + PRIMARY = 2, // index is an index built to enforce a PRIMARY KEY constraint + FOREIGN = 3 // index is an index built to enforce a FOREIGN KEY constraint +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/join_type.hpp b/src/duckdb/src/include/duckdb/common/enums/join_type.hpp new file mode 100644 index 00000000..c72cf001 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/join_type.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/join_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Join Types +//===--------------------------------------------------------------------===// +enum class JoinType : uint8_t { + INVALID = 0, // invalid join type + LEFT = 1, // left + RIGHT = 2, // right + INNER = 3, // inner + OUTER = 4, // outer + SEMI = 5, // SEMI join returns left side row ONLY if it has a join partner, no duplicates + ANTI = 6, // ANTI join returns left side row ONLY if it has NO join partner, no duplicates + MARK = 7, // MARK join returns marker indicating whether or not there is a join partner (true), there is no join + // partner (false) + SINGLE = 8 // SINGLE join is like LEFT OUTER JOIN, BUT returns at most one join partner per entry on the LEFT side + // (and NULL if no partner is found) +}; + +//! True if join is left or full outer join +bool IsLeftOuterJoin(JoinType type); + +//! True if join is rght or full outer join +bool IsRightOuterJoin(JoinType type); + +// **DEPRECATED**: Use EnumUtil directly instead. +string JoinTypeToString(JoinType type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/joinref_type.hpp b/src/duckdb/src/include/duckdb/common/enums/joinref_type.hpp new file mode 100644 index 00000000..85b8c66b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/joinref_type.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/joinref_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Join Reference Types +//===--------------------------------------------------------------------===// +enum class JoinRefType : uint8_t { + REGULAR, // Explicit conditions + NATURAL, // Implied conditions + CROSS, // No condition + POSITIONAL, // Positional condition + ASOF, // AsOf conditions + DEPENDENT, // Dependent join conditions +}; + +const char *ToString(JoinRefType value); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp b/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp new file mode 100644 index 00000000..9215d3e4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp @@ -0,0 +1,114 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/logical_operator_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Logical Operator Types +//===--------------------------------------------------------------------===// +enum class LogicalOperatorType : uint8_t { + LOGICAL_INVALID = 0, + LOGICAL_PROJECTION = 1, + LOGICAL_FILTER = 2, + LOGICAL_AGGREGATE_AND_GROUP_BY = 3, + LOGICAL_WINDOW = 4, + LOGICAL_UNNEST = 5, + LOGICAL_LIMIT = 6, + LOGICAL_ORDER_BY = 7, + LOGICAL_TOP_N = 8, + LOGICAL_COPY_TO_FILE = 10, + LOGICAL_DISTINCT = 11, + LOGICAL_SAMPLE = 12, + LOGICAL_LIMIT_PERCENT = 13, + LOGICAL_PIVOT = 14, + + // ----------------------------- + // Data sources + // ----------------------------- + LOGICAL_GET = 25, + LOGICAL_CHUNK_GET = 26, + LOGICAL_DELIM_GET = 27, + LOGICAL_EXPRESSION_GET = 28, + LOGICAL_DUMMY_SCAN = 29, + LOGICAL_EMPTY_RESULT = 30, + LOGICAL_CTE_REF = 31, + // ----------------------------- + // Joins + // ----------------------------- + LOGICAL_JOIN = 50, + LOGICAL_DELIM_JOIN = 51, + LOGICAL_COMPARISON_JOIN = 52, + LOGICAL_ANY_JOIN = 53, + LOGICAL_CROSS_PRODUCT = 54, + LOGICAL_POSITIONAL_JOIN = 55, + LOGICAL_ASOF_JOIN = 56, + LOGICAL_DEPENDENT_JOIN = 57, + // ----------------------------- + // SetOps + // ----------------------------- + LOGICAL_UNION = 75, + LOGICAL_EXCEPT = 76, + LOGICAL_INTERSECT = 77, + LOGICAL_RECURSIVE_CTE = 78, + LOGICAL_MATERIALIZED_CTE = 79, + + // ----------------------------- + // Updates + // ----------------------------- + LOGICAL_INSERT = 100, + LOGICAL_DELETE = 101, + LOGICAL_UPDATE = 102, + + // ----------------------------- + // Schema + // ----------------------------- + LOGICAL_ALTER = 125, + LOGICAL_CREATE_TABLE = 126, + LOGICAL_CREATE_INDEX = 127, + LOGICAL_CREATE_SEQUENCE = 128, + LOGICAL_CREATE_VIEW = 129, + LOGICAL_CREATE_SCHEMA = 130, + LOGICAL_CREATE_MACRO = 131, + LOGICAL_DROP = 132, + LOGICAL_PRAGMA = 133, + LOGICAL_TRANSACTION = 134, + LOGICAL_CREATE_TYPE = 135, + LOGICAL_ATTACH = 136, + LOGICAL_DETACH = 137, + + // ----------------------------- + // Explain + // ----------------------------- + LOGICAL_EXPLAIN = 150, + + // ----------------------------- + // Show + // ----------------------------- + LOGICAL_SHOW = 160, + + // ----------------------------- + // Helpers + // ----------------------------- + LOGICAL_PREPARE = 175, + LOGICAL_EXECUTE = 176, + LOGICAL_EXPORT = 177, + LOGICAL_VACUUM = 178, + LOGICAL_SET = 179, + LOGICAL_LOAD = 180, + LOGICAL_RESET = 181, + + LOGICAL_EXTENSION_OPERATOR = 255 +}; + +DUCKDB_API string LogicalOperatorToString(LogicalOperatorType type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/on_entry_not_found.hpp b/src/duckdb/src/include/duckdb/common/enums/on_entry_not_found.hpp new file mode 100644 index 00000000..092f9cd0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/on_entry_not_found.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/on_entry_not_found.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class OnEntryNotFound : uint8_t { THROW_EXCEPTION = 0, RETURN_NULL = 1 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp b/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp new file mode 100644 index 00000000..f7ada047 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/operator_result_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//! The OperatorResultType is used to indicate how data should flow around a regular (i.e. non-sink and non-source) +//! physical operator +//! There are four possible results: +//! NEED_MORE_INPUT means the operator is done with the current input and can consume more input if available +//! If there is more input the operator will be called with more input, otherwise the operator will not be called again. +//! HAVE_MORE_OUTPUT means the operator is not finished yet with the current input. +//! The operator will be called again with the same input. +//! FINISHED means the operator has finished the entire pipeline and no more processing is necessary. +//! The operator will not be called again, and neither will any other operators in this pipeline. +//! BLOCKED means the operator does not want to be called right now. e.g. because its currently doing async I/O. The +//! operator has set the interrupt state and the caller is expected to handle it. Note that intermediate operators +//! should currently not emit this state. +enum class OperatorResultType : uint8_t { NEED_MORE_INPUT, HAVE_MORE_OUTPUT, FINISHED, BLOCKED }; + +//! OperatorFinalizeResultType is used to indicate whether operators have finished flushing their cached results. +//! FINISHED means the operator has flushed all cached data. +//! HAVE_MORE_OUTPUT means the operator contains more results. +enum class OperatorFinalizeResultType : uint8_t { HAVE_MORE_OUTPUT, FINISHED }; + +//! SourceResultType is used to indicate the result of data being pulled out of a source. +//! There are three possible results: +//! HAVE_MORE_OUTPUT means the source has more output, this flag should only be set when data is returned, empty results +//! should only occur for the FINISHED and BLOCKED flags +//! FINISHED means the source is exhausted +//! BLOCKED means the source is currently blocked, e.g. by some async I/O +enum class SourceResultType : uint8_t { HAVE_MORE_OUTPUT, FINISHED, BLOCKED }; + +//! The SinkResultType is used to indicate the result of data flowing into a sink +//! There are three possible results: +//! NEED_MORE_INPUT means the sink needs more input +//! FINISHED means the sink is finished executing, and more input will not change the result any further +//! BLOCKED means the sink is currently blocked, e.g. by some async I/O. +enum class SinkResultType : uint8_t { NEED_MORE_INPUT, FINISHED, BLOCKED }; + +// todo comment +enum class SinkCombineResultType : uint8_t { FINISHED, BLOCKED }; + +//! The SinkFinalizeType is used to indicate the result of a Finalize call on a sink +//! There are two possible results: +//! READY means the sink is ready for further processing +//! NO_OUTPUT_POSSIBLE means the sink will never provide output, and any pipelines involving the sink can be skipped +//! BLOCKED means the finalize call to the sink is currently blocked, e.g. by some async I/O. +enum class SinkFinalizeType : uint8_t { READY, NO_OUTPUT_POSSIBLE, BLOCKED }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp new file mode 100644 index 00000000..873d9b2d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/optimizer_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class OptimizerType : uint32_t { + INVALID = 0, + EXPRESSION_REWRITER, + FILTER_PULLUP, + FILTER_PUSHDOWN, + REGEX_RANGE, + IN_CLAUSE, + JOIN_ORDER, + DELIMINATOR, + UNNEST_REWRITER, + UNUSED_COLUMNS, + STATISTICS_PROPAGATION, + COMMON_SUBEXPRESSIONS, + COMMON_AGGREGATE, + COLUMN_LIFETIME, + TOP_N, + COMPRESSED_MATERIALIZATION, + DUPLICATE_GROUPS, + REORDER_FILTER, + EXTENSION +}; + +string OptimizerTypeToString(OptimizerType type); +OptimizerType OptimizerTypeFromString(const string &str); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/order_preservation_type.hpp b/src/duckdb/src/include/duckdb/common/enums/order_preservation_type.hpp new file mode 100644 index 00000000..ef590c03 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/order_preservation_type.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// src/include/duckdb/common/enums/order_preservation_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Order Preservation Type +//===--------------------------------------------------------------------===// +enum class OrderPreservationType : uint8_t { + NO_ORDER, // the operator makes no guarantees on order preservation (i.e. it might re-order the entire input) + INSERTION_ORDER, // the operator maintains the order of the child operators + FIXED_ORDER // the operator outputs rows in a fixed order that must be maintained (e.g. ORDER BY) +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/order_type.hpp b/src/duckdb/src/include/duckdb/common/enums/order_type.hpp new file mode 100644 index 00000000..a23457d2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/order_type.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/order_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +enum class OrderType : uint8_t { INVALID = 0, ORDER_DEFAULT = 1, ASCENDING = 2, DESCENDING = 3 }; + +enum class OrderByNullType : uint8_t { INVALID = 0, ORDER_DEFAULT = 1, NULLS_FIRST = 2, NULLS_LAST = 3 }; + +enum class DefaultOrderByNullType : uint8_t { + INVALID = 0, + NULLS_FIRST = 2, + NULLS_LAST = 3, + NULLS_FIRST_ON_ASC_LAST_ON_DESC = 4, + NULLS_LAST_ON_ASC_FIRST_ON_DESC = 5 +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/output_type.hpp b/src/duckdb/src/include/duckdb/common/enums/output_type.hpp new file mode 100644 index 00000000..587eb81a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/output_type.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/output_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class ExplainOutputType : uint8_t { ALL = 0, OPTIMIZED_ONLY = 1, PHYSICAL_ONLY = 2 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/pending_execution_result.hpp b/src/duckdb/src/include/duckdb/common/enums/pending_execution_result.hpp new file mode 100644 index 00000000..e130e973 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/pending_execution_result.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/pending_execution_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class PendingExecutionResult : uint8_t { RESULT_READY, RESULT_NOT_READY, EXECUTION_ERROR, NO_TASKS_AVAILABLE }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/physical_operator_type.hpp b/src/duckdb/src/include/duckdb/common/enums/physical_operator_type.hpp new file mode 100644 index 00000000..5b15d39f --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/physical_operator_type.hpp @@ -0,0 +1,119 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/physical_operator_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Physical Operator Types +//===--------------------------------------------------------------------===// +enum class PhysicalOperatorType : uint8_t { + INVALID, + ORDER_BY, + LIMIT, + STREAMING_LIMIT, + LIMIT_PERCENT, + TOP_N, + WINDOW, + UNNEST, + UNGROUPED_AGGREGATE, + HASH_GROUP_BY, + PERFECT_HASH_GROUP_BY, + FILTER, + PROJECTION, + COPY_TO_FILE, + BATCH_COPY_TO_FILE, + FIXED_BATCH_COPY_TO_FILE, + RESERVOIR_SAMPLE, + STREAMING_SAMPLE, + STREAMING_WINDOW, + PIVOT, + + // ----------------------------- + // Scans + // ----------------------------- + TABLE_SCAN, + DUMMY_SCAN, + COLUMN_DATA_SCAN, + CHUNK_SCAN, + RECURSIVE_CTE_SCAN, + CTE_SCAN, + DELIM_SCAN, + EXPRESSION_SCAN, + POSITIONAL_SCAN, + // ----------------------------- + // Joins + // ----------------------------- + BLOCKWISE_NL_JOIN, + NESTED_LOOP_JOIN, + HASH_JOIN, + CROSS_PRODUCT, + PIECEWISE_MERGE_JOIN, + IE_JOIN, + DELIM_JOIN, + INDEX_JOIN, + POSITIONAL_JOIN, + ASOF_JOIN, + // ----------------------------- + // SetOps + // ----------------------------- + UNION, + RECURSIVE_CTE, + CTE, + + // ----------------------------- + // Updates + // ----------------------------- + INSERT, + BATCH_INSERT, + DELETE_OPERATOR, + UPDATE, + + // ----------------------------- + // Schema + // ----------------------------- + CREATE_TABLE, + CREATE_TABLE_AS, + BATCH_CREATE_TABLE_AS, + CREATE_INDEX, + ALTER, + CREATE_SEQUENCE, + CREATE_VIEW, + CREATE_SCHEMA, + CREATE_MACRO, + DROP, + PRAGMA, + TRANSACTION, + CREATE_TYPE, + ATTACH, + DETACH, + + // ----------------------------- + // Helpers + // ----------------------------- + EXPLAIN, + EXPLAIN_ANALYZE, + EMPTY_RESULT, + EXECUTE, + PREPARE, + VACUUM, + EXPORT, + SET, + LOAD, + INOUT_FUNCTION, + RESULT_COLLECTOR, + RESET, + EXTENSION +}; + +string PhysicalOperatorToString(PhysicalOperatorType type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp b/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp new file mode 100644 index 00000000..2e4e2b05 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/profiler_format.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class ProfilerPrintFormat : uint8_t { QUERY_TREE, JSON, QUERY_TREE_OPTIMIZER }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/relation_type.hpp b/src/duckdb/src/include/duckdb/common/enums/relation_type.hpp new file mode 100644 index 00000000..f1c5fd33 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/relation_type.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/relation_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Catalog Types +//===--------------------------------------------------------------------===// +enum class RelationType : uint8_t { + INVALID_RELATION, + TABLE_RELATION, + PROJECTION_RELATION, + FILTER_RELATION, + EXPLAIN_RELATION, + CROSS_PRODUCT_RELATION, + JOIN_RELATION, + AGGREGATE_RELATION, + SET_OPERATION_RELATION, + DISTINCT_RELATION, + LIMIT_RELATION, + ORDER_RELATION, + CREATE_VIEW_RELATION, + CREATE_TABLE_RELATION, + INSERT_RELATION, + VALUE_LIST_RELATION, + DELETE_RELATION, + UPDATE_RELATION, + WRITE_CSV_RELATION, + WRITE_PARQUET_RELATION, + READ_CSV_RELATION, + SUBQUERY_RELATION, + TABLE_FUNCTION_RELATION, + VIEW_RELATION, + QUERY_RELATION +}; + +string RelationTypeToString(RelationType type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/scan_options.hpp b/src/duckdb/src/include/duckdb/common/enums/scan_options.hpp new file mode 100644 index 00000000..3281aff4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/scan_options.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/scan_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class TableScanType : uint8_t { + //! Regular table scan: scan all tuples that are relevant for the current transaction + TABLE_SCAN_REGULAR = 0, + //! Scan all rows, including any deleted rows. Committed updates are merged in. + TABLE_SCAN_COMMITTED_ROWS = 1, + //! Scan all rows, including any deleted rows. Throws an exception if there are any uncommitted updates. + TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES = 2, + //! Scan all rows, excluding any permanently deleted rows. + //! Permanently deleted rows are rows which no transaction will ever need again. + TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED = 3 +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/set_operation_type.hpp b/src/duckdb/src/include/duckdb/common/enums/set_operation_type.hpp new file mode 100644 index 00000000..0cf5628e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/set_operation_type.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/set_operation_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class SetOperationType : uint8_t { NONE = 0, UNION = 1, EXCEPT = 2, INTERSECT = 3, UNION_BY_NAME = 4 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/set_scope.hpp b/src/duckdb/src/include/duckdb/common/enums/set_scope.hpp new file mode 100644 index 00000000..59420651 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/set_scope.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/set_scope.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class SetScope : uint8_t { + AUTOMATIC = 0, + LOCAL = 1, /* unused */ + SESSION = 2, + GLOBAL = 3 +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/set_type.hpp b/src/duckdb/src/include/duckdb/common/enums/set_type.hpp new file mode 100644 index 00000000..0b5eedd2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/set_type.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/set_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class SetType : uint8_t { SET = 0, RESET = 1 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp new file mode 100644 index 00000000..3ed9ba63 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/statement_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Statement Types +//===--------------------------------------------------------------------===// +enum class StatementType : uint8_t { + INVALID_STATEMENT, // invalid statement type + SELECT_STATEMENT, // select statement type + INSERT_STATEMENT, // insert statement type + UPDATE_STATEMENT, // update statement type + CREATE_STATEMENT, // create statement type + DELETE_STATEMENT, // delete statement type + PREPARE_STATEMENT, // prepare statement type + EXECUTE_STATEMENT, // execute statement type + ALTER_STATEMENT, // alter statement type + TRANSACTION_STATEMENT, // transaction statement type, + COPY_STATEMENT, // copy type + ANALYZE_STATEMENT, // analyze type + VARIABLE_SET_STATEMENT, // variable set statement type + CREATE_FUNC_STATEMENT, // create func statement type + EXPLAIN_STATEMENT, // explain statement type + DROP_STATEMENT, // DROP statement type + EXPORT_STATEMENT, // EXPORT statement type + PRAGMA_STATEMENT, // PRAGMA statement type + SHOW_STATEMENT, // SHOW statement type + VACUUM_STATEMENT, // VACUUM statement type + CALL_STATEMENT, // CALL statement type + SET_STATEMENT, // SET statement type + LOAD_STATEMENT, // LOAD statement type + RELATION_STATEMENT, + EXTENSION_STATEMENT, + LOGICAL_PLAN_STATEMENT, + ATTACH_STATEMENT, + DETACH_STATEMENT, + MULTI_STATEMENT + +}; + +DUCKDB_API string StatementTypeToString(StatementType type); + +enum class StatementReturnType : uint8_t { + QUERY_RESULT, // the statement returns a query result (e.g. for display to the user) + CHANGED_ROWS, // the statement returns a single row containing the number of changed rows (e.g. an insert stmt) + NOTHING // the statement returns nothing +}; + +string StatementReturnTypeToString(StatementReturnType type); + +//! A struct containing various properties of a SQL statement +struct StatementProperties { + StatementProperties() + : requires_valid_transaction(true), allow_stream_result(false), bound_all_parameters(true), + return_type(StatementReturnType::QUERY_RESULT), parameter_count(0) { + } + + //! The set of databases this statement will modify + unordered_set modified_databases; + //! Whether or not the statement requires a valid transaction. Almost all statements require this, with the + //! exception of + bool requires_valid_transaction; + //! Whether or not the result can be streamed to the client + bool allow_stream_result; + //! Whether or not all parameters have successfully had their types determined + bool bound_all_parameters; + //! What type of data the statement returns + StatementReturnType return_type; + //! The number of prepared statement parameters + idx_t parameter_count; + + bool IsReadOnly() { + return modified_databases.empty(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/subquery_type.hpp b/src/duckdb/src/include/duckdb/common/enums/subquery_type.hpp new file mode 100644 index 00000000..2415ef5b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/subquery_type.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/subquery_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Subquery Types +//===--------------------------------------------------------------------===// +enum class SubqueryType : uint8_t { + INVALID = 0, + SCALAR = 1, // Regular scalar subquery + EXISTS = 2, // EXISTS (SELECT...) + NOT_EXISTS = 3, // NOT EXISTS(SELECT...) + ANY = 4, // x = ANY(SELECT...) OR x IN (SELECT...) +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/tableref_type.hpp b/src/duckdb/src/include/duckdb/common/enums/tableref_type.hpp new file mode 100644 index 00000000..1f042c16 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/tableref_type.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/tableref_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Table Reference Types +//===--------------------------------------------------------------------===// +enum class TableReferenceType : uint8_t { + INVALID = 0, // invalid table reference type + BASE_TABLE = 1, // base table reference + SUBQUERY = 2, // output of a subquery + JOIN = 3, // output of join + TABLE_FUNCTION = 5, // table producing function + EXPRESSION_LIST = 6, // expression list + CTE = 7, // Recursive CTE + EMPTY = 8, // placeholder for empty FROM + PIVOT = 9 // pivot statement +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/undo_flags.hpp b/src/duckdb/src/include/duckdb/common/enums/undo_flags.hpp new file mode 100644 index 00000000..92816114 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/undo_flags.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/undo_flags.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class UndoFlags : uint32_t { // far too big but aligned (TM) + EMPTY_ENTRY = 0, + CATALOG_ENTRY = 1, + INSERT_TUPLE = 2, + DELETE_TUPLE = 3, + UPDATE_TUPLE = 4 +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/vector_type.hpp b/src/duckdb/src/include/duckdb/common/enums/vector_type.hpp new file mode 100644 index 00000000..155df5fe --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/vector_type.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/vector_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class VectorType : uint8_t { + FLAT_VECTOR, // Flat vectors represent a standard uncompressed vector + FSST_VECTOR, // Contains string data compressed with FSST + CONSTANT_VECTOR, // Constant vector represents a single constant + DICTIONARY_VECTOR, // Dictionary vector represents a selection vector on top of another vector + SEQUENCE_VECTOR // Sequence vector represents a sequence with a start point and an increment +}; + +string VectorTypeToString(VectorType type); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/wal_type.hpp b/src/duckdb/src/include/duckdb/common/enums/wal_type.hpp new file mode 100644 index 00000000..b18fa9f1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/wal_type.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/wal_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class WALType : uint8_t { + INVALID = 0, + // ----------------------------- + // Catalog + // ----------------------------- + CREATE_TABLE = 1, + DROP_TABLE = 2, + + CREATE_SCHEMA = 3, + DROP_SCHEMA = 4, + + CREATE_VIEW = 5, + DROP_VIEW = 6, + + CREATE_SEQUENCE = 8, + DROP_SEQUENCE = 9, + SEQUENCE_VALUE = 10, + + CREATE_MACRO = 11, + DROP_MACRO = 12, + + CREATE_TYPE = 13, + DROP_TYPE = 14, + + ALTER_INFO = 20, + + CREATE_TABLE_MACRO = 21, + DROP_TABLE_MACRO = 22, + + CREATE_INDEX = 23, + DROP_INDEX = 24, + + // ----------------------------- + // Data + // ----------------------------- + USE_TABLE = 25, + INSERT_TUPLE = 26, + DELETE_TUPLE = 27, + UPDATE_TUPLE = 28, + // ----------------------------- + // Flush + // ----------------------------- + CHECKPOINT = 99, + WAL_FLUSH = 100 +}; +} diff --git a/src/duckdb/src/include/duckdb/common/enums/window_aggregation_mode.hpp b/src/duckdb/src/include/duckdb/common/enums/window_aggregation_mode.hpp new file mode 100644 index 00000000..2bbd1013 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/window_aggregation_mode.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/window_aggregation_mode.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class WindowAggregationMode : uint32_t { + //! Use the window aggregate API if available + WINDOW = 0, + //! Don't use window, but use combine if available + COMBINE, + //! Don't use combine or window (compute each frame separately) + SEPARATE +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception.hpp b/src/duckdb/src/include/duckdb/common/exception.hpp new file mode 100644 index 00000000..cffde874 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/exception.hpp @@ -0,0 +1,477 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/exception.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception_format_value.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/typedefs.hpp" + +#include +#include + +namespace duckdb { +enum class PhysicalType : uint8_t; +struct LogicalType; +struct hugeint_t; + +inline void assert_restrict_function(const void *left_start, const void *left_end, const void *right_start, + const void *right_end, const char *fname, int linenr) { + // assert that the two pointers do not overlap +#ifdef DEBUG + if (!(left_end <= right_start || right_end <= left_start)) { + printf("ASSERT RESTRICT FAILED: %s:%d\n", fname, linenr); + D_ASSERT(0); + } +#endif +} + +#define ASSERT_RESTRICT(left_start, left_end, right_start, right_end) \ + assert_restrict_function(left_start, left_end, right_start, right_end, __FILE__, __LINE__) + +//===--------------------------------------------------------------------===// +// Exception Types +//===--------------------------------------------------------------------===// + +enum class ExceptionType { + INVALID = 0, // invalid type + OUT_OF_RANGE = 1, // value out of range error + CONVERSION = 2, // conversion/casting error + UNKNOWN_TYPE = 3, // unknown type + DECIMAL = 4, // decimal related + MISMATCH_TYPE = 5, // type mismatch + DIVIDE_BY_ZERO = 6, // divide by 0 + OBJECT_SIZE = 7, // object size exceeded + INVALID_TYPE = 8, // incompatible for operation + SERIALIZATION = 9, // serialization + TRANSACTION = 10, // transaction management + NOT_IMPLEMENTED = 11, // method not implemented + EXPRESSION = 12, // expression parsing + CATALOG = 13, // catalog related + PARSER = 14, // parser related + PLANNER = 15, // planner related + SCHEDULER = 16, // scheduler related + EXECUTOR = 17, // executor related + CONSTRAINT = 18, // constraint related + INDEX = 19, // index related + STAT = 20, // stat related + CONNECTION = 21, // connection related + SYNTAX = 22, // syntax related + SETTINGS = 23, // settings related + BINDER = 24, // binder related + NETWORK = 25, // network related + OPTIMIZER = 26, // optimizer related + NULL_POINTER = 27, // nullptr exception + IO = 28, // IO exception + INTERRUPT = 29, // interrupt + FATAL = 30, // Fatal exceptions are non-recoverable, and render the entire DB in an unusable state + INTERNAL = 31, // Internal exceptions indicate something went wrong internally (i.e. bug in the code base) + INVALID_INPUT = 32, // Input or arguments error + OUT_OF_MEMORY = 33, // out of memory + PERMISSION = 34, // insufficient permissions + PARAMETER_NOT_RESOLVED = 35, // parameter types could not be resolved + PARAMETER_NOT_ALLOWED = 36, // parameter types not allowed + DEPENDENCY = 37, // dependency + HTTP = 38, + MISSING_EXTENSION = 39, // Thrown when an extension is used but not loaded + AUTOLOAD = 40 // Thrown when an extension is used but not loaded +}; +class HTTPException; + +class Exception : public std::exception { +public: + DUCKDB_API explicit Exception(const string &msg); + DUCKDB_API Exception(ExceptionType exception_type, const string &message); + + ExceptionType type; + +public: + DUCKDB_API const char *what() const noexcept override; + DUCKDB_API const string &RawMessage() const; + + DUCKDB_API static string ExceptionTypeToString(ExceptionType type); + [[noreturn]] DUCKDB_API static void ThrowAsTypeWithMessage(ExceptionType type, const string &message, + const std::shared_ptr &original); + virtual std::shared_ptr Copy() const { + return make_shared(type, raw_message_); + } + DUCKDB_API const HTTPException &AsHTTPException() const; + + template + static string ConstructMessage(const string &msg, Args... params) { + const std::size_t num_args = sizeof...(Args); + if (num_args == 0) + return msg; + std::vector values; + return ConstructMessageRecursive(msg, values, params...); + } + + DUCKDB_API static string ConstructMessageRecursive(const string &msg, std::vector &values); + + template + static string ConstructMessageRecursive(const string &msg, std::vector &values, T param, + Args... params) { + values.push_back(ExceptionFormatValue::CreateFormatValue(param)); + return ConstructMessageRecursive(msg, values, params...); + } + + DUCKDB_API static bool UncaughtException(); + + DUCKDB_API static string GetStackTrace(int max_depth = 120); + static string FormatStackTrace(string message = "") { + return (message + "\n" + GetStackTrace()); + } + +private: + string exception_message_; + string raw_message_; +}; + +//===--------------------------------------------------------------------===// +// Exception derived classes +//===--------------------------------------------------------------------===// + +//! Exceptions that are StandardExceptions do NOT invalidate the current transaction when thrown +class StandardException : public Exception { +public: + DUCKDB_API StandardException(ExceptionType exception_type, const string &message); +}; + +class CatalogException : public StandardException { +public: + DUCKDB_API explicit CatalogException(const string &msg); + + template + explicit CatalogException(const string &msg, Args... params) : CatalogException(ConstructMessage(msg, params...)) { + } +}; + +class ConnectionException : public StandardException { +public: + DUCKDB_API explicit ConnectionException(const string &msg); + + template + explicit ConnectionException(const string &msg, Args... params) + : ConnectionException(ConstructMessage(msg, params...)) { + } +}; + +class ParserException : public StandardException { +public: + DUCKDB_API explicit ParserException(const string &msg); + + template + explicit ParserException(const string &msg, Args... params) : ParserException(ConstructMessage(msg, params...)) { + } +}; + +class PermissionException : public StandardException { +public: + DUCKDB_API explicit PermissionException(const string &msg); + + template + explicit PermissionException(const string &msg, Args... params) + : PermissionException(ConstructMessage(msg, params...)) { + } +}; + +class BinderException : public StandardException { +public: + DUCKDB_API explicit BinderException(const string &msg); + + template + explicit BinderException(const string &msg, Args... params) : BinderException(ConstructMessage(msg, params...)) { + } +}; + +class ConversionException : public Exception { +public: + DUCKDB_API explicit ConversionException(const string &msg); + + template + explicit ConversionException(const string &msg, Args... params) + : ConversionException(ConstructMessage(msg, params...)) { + } +}; + +class TransactionException : public Exception { +public: + DUCKDB_API explicit TransactionException(const string &msg); + + template + explicit TransactionException(const string &msg, Args... params) + : TransactionException(ConstructMessage(msg, params...)) { + } +}; + +class NotImplementedException : public Exception { +public: + DUCKDB_API explicit NotImplementedException(const string &msg); + + template + explicit NotImplementedException(const string &msg, Args... params) + : NotImplementedException(ConstructMessage(msg, params...)) { + } +}; + +class OutOfRangeException : public Exception { +public: + DUCKDB_API explicit OutOfRangeException(const string &msg); + + template + explicit OutOfRangeException(const string &msg, Args... params) + : OutOfRangeException(ConstructMessage(msg, params...)) { + } +}; + +class OutOfMemoryException : public Exception { +public: + DUCKDB_API explicit OutOfMemoryException(const string &msg); + + template + explicit OutOfMemoryException(const string &msg, Args... params) + : OutOfMemoryException(ConstructMessage(msg, params...)) { + } +}; + +class SyntaxException : public Exception { +public: + DUCKDB_API explicit SyntaxException(const string &msg); + + template + explicit SyntaxException(const string &msg, Args... params) : SyntaxException(ConstructMessage(msg, params...)) { + } +}; + +class ConstraintException : public Exception { +public: + DUCKDB_API explicit ConstraintException(const string &msg); + + template + explicit ConstraintException(const string &msg, Args... params) + : ConstraintException(ConstructMessage(msg, params...)) { + } +}; + +class DependencyException : public Exception { +public: + DUCKDB_API explicit DependencyException(const string &msg); + + template + explicit DependencyException(const string &msg, Args... params) + : DependencyException(ConstructMessage(msg, params...)) { + } +}; + +class IOException : public Exception { +public: + DUCKDB_API explicit IOException(const string &msg); + explicit IOException(ExceptionType exception_type, const string &msg) : Exception(exception_type, msg) { + } + + template + explicit IOException(const string &msg, Args... params) : IOException(ConstructMessage(msg, params...)) { + } +}; + +class MissingExtensionException : public Exception { +public: + DUCKDB_API explicit MissingExtensionException(const string &msg); + + template + explicit MissingExtensionException(const string &msg, Args... params) + : MissingExtensionException(ConstructMessage(msg, params...)) { + } +}; + +class AutoloadException : public Exception { +public: + DUCKDB_API explicit AutoloadException(const string &extension_name, Exception &e); + + template + explicit AutoloadException(const string &extension_name, Exception &e, Args... params) + : AutoloadException(ConstructMessage(extension_name, e, params...)) { + } + +protected: + Exception &wrapped_exception; +}; + +class HTTPException : public IOException { +public: + template + struct ResponseShape { + typedef int status; + }; + + template ::status = 0, typename... ARGS> + explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + : HTTPException(response.status, response.body, response.headers, response.reason, msg, params...) { + } + + template + struct ResponseWrapperShape { + typedef int code; + }; + template ::code = 0, typename... ARGS> + explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + : HTTPException(response.code, response.body, response.headers, response.error, msg, params...) { + } + + template + explicit HTTPException(int status_code, string response_body, HEADERS headers, const string &reason, + const string &msg, ARGS... params) + : IOException(ExceptionType::HTTP, ConstructMessage(msg, params...)), status_code(status_code), reason(reason), + response_body(std::move(response_body)) { + this->headers.insert(headers.begin(), headers.end()); + D_ASSERT(this->headers.size() > 0); + } + + std::shared_ptr Copy() const { + return make_shared(status_code, response_body, headers, reason, RawMessage()); + } + + const std::multimap GetHeaders() const { + return headers; + } + int GetStatusCode() const { + return status_code; + } + const string &GetResponseBody() const { + return response_body; + } + const string &GetReason() const { + return reason; + } + [[noreturn]] void Throw() const { + throw HTTPException(status_code, response_body, headers, reason, RawMessage()); + } + +private: + int status_code; + string reason; + string response_body; + std::multimap headers; +}; + +class SerializationException : public Exception { +public: + DUCKDB_API explicit SerializationException(const string &msg); + + template + explicit SerializationException(const string &msg, Args... params) + : SerializationException(ConstructMessage(msg, params...)) { + } +}; + +class SequenceException : public Exception { +public: + DUCKDB_API explicit SequenceException(const string &msg); + + template + explicit SequenceException(const string &msg, Args... params) + : SequenceException(ConstructMessage(msg, params...)) { + } +}; + +class InterruptException : public Exception { +public: + DUCKDB_API InterruptException(); +}; + +class FatalException : public Exception { +public: + explicit FatalException(const string &msg) : FatalException(ExceptionType::FATAL, msg) { + } + template + explicit FatalException(const string &msg, Args... params) : FatalException(ConstructMessage(msg, params...)) { + } + +protected: + DUCKDB_API explicit FatalException(ExceptionType type, const string &msg); + template + explicit FatalException(ExceptionType type, const string &msg, Args... params) + : FatalException(type, ConstructMessage(msg, params...)) { + } +}; + +class InternalException : public FatalException { +public: + DUCKDB_API explicit InternalException(const string &msg); + + template + explicit InternalException(const string &msg, Args... params) + : InternalException(ConstructMessage(msg, params...)) { + } +}; + +class InvalidInputException : public Exception { +public: + DUCKDB_API explicit InvalidInputException(const string &msg); + + template + explicit InvalidInputException(const string &msg, Args... params) + : InvalidInputException(ConstructMessage(msg, params...)) { + } +}; + +class CastException : public Exception { +public: + DUCKDB_API CastException(const PhysicalType origType, const PhysicalType newType); + DUCKDB_API CastException(const LogicalType &origType, const LogicalType &newType); + DUCKDB_API + CastException(const string &msg); //! Needed to be able to recreate the exception after it's been serialized +}; + +class InvalidTypeException : public Exception { +public: + DUCKDB_API InvalidTypeException(PhysicalType type, const string &msg); + DUCKDB_API InvalidTypeException(const LogicalType &type, const string &msg); + DUCKDB_API + InvalidTypeException(const string &msg); //! Needed to be able to recreate the exception after it's been serialized +}; + +class TypeMismatchException : public Exception { +public: + DUCKDB_API TypeMismatchException(const PhysicalType type_1, const PhysicalType type_2, const string &msg); + DUCKDB_API TypeMismatchException(const LogicalType &type_1, const LogicalType &type_2, const string &msg); + DUCKDB_API + TypeMismatchException(const string &msg); //! Needed to be able to recreate the exception after it's been serialized +}; + +class ValueOutOfRangeException : public Exception { +public: + DUCKDB_API ValueOutOfRangeException(const int64_t value, const PhysicalType origType, const PhysicalType newType); + DUCKDB_API ValueOutOfRangeException(const hugeint_t value, const PhysicalType origType, const PhysicalType newType); + DUCKDB_API ValueOutOfRangeException(const double value, const PhysicalType origType, const PhysicalType newType); + DUCKDB_API ValueOutOfRangeException(const PhysicalType varType, const idx_t length); + DUCKDB_API ValueOutOfRangeException( + const string &msg); //! Needed to be able to recreate the exception after it's been serialized +}; + +class ParameterNotAllowedException : public StandardException { +public: + DUCKDB_API explicit ParameterNotAllowedException(const string &msg); + + template + explicit ParameterNotAllowedException(const string &msg, Args... params) + : ParameterNotAllowedException(ConstructMessage(msg, params...)) { + } +}; + +//! Special exception that should be thrown in the binder if parameter types could not be resolved +//! This will cause prepared statements to be forcibly rebound with the actual parameter values +//! This exception is fatal if thrown outside of the binder (i.e. it should never be thrown outside of the binder) +class ParameterNotResolvedException : public Exception { +public: + DUCKDB_API explicit ParameterNotResolvedException(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp new file mode 100644 index 00000000..1834663e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/exception_format_value.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string.hpp" +#include "duckdb/common/hugeint.hpp" + +#include + +namespace duckdb { + +// Helper class to support custom overloading +// Escaping " and quoting the value with " +class SQLIdentifier { +public: + SQLIdentifier(const string &raw_string) : raw_string(raw_string) { + } + +public: + string raw_string; +}; + +// Helper class to support custom overloading +// Escaping ' and quoting the value with ' +class SQLString { +public: + SQLString(const string &raw_string) : raw_string(raw_string) { + } + +public: + string raw_string; +}; + +enum class PhysicalType : uint8_t; +struct LogicalType; + +enum class ExceptionFormatValueType : uint8_t { + FORMAT_VALUE_TYPE_DOUBLE, + FORMAT_VALUE_TYPE_INTEGER, + FORMAT_VALUE_TYPE_STRING +}; + +struct ExceptionFormatValue { + DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT + DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT + DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT + DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT + + ExceptionFormatValueType type; + + double dbl_val = 0; + int64_t int_val = 0; + string str_val; + +public: + template + static ExceptionFormatValue CreateFormatValue(T value) { + return int64_t(value); + } + static string Format(const string &msg, std::vector &values); +}; + +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLString value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLIdentifier value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(LogicalType value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value); +template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/extra_operator_info.hpp b/src/duckdb/src/include/duckdb/common/extra_operator_info.hpp new file mode 100644 index 00000000..194246bc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/extra_operator_info.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/extra_operator_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include "duckdb/common/operator/comparison_operators.hpp" + +namespace duckdb { + +class ExtraOperatorInfo { +public: + ExtraOperatorInfo() : file_filters("") { + } + ExtraOperatorInfo(ExtraOperatorInfo &extra_info) : file_filters(extra_info.file_filters) { + } + string file_filters; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp new file mode 100644 index 00000000..74895fd0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp @@ -0,0 +1,185 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/extra_type_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +//! Extra Type Info Type +enum class ExtraTypeInfoType : uint8_t { + INVALID_TYPE_INFO = 0, + GENERIC_TYPE_INFO = 1, + DECIMAL_TYPE_INFO = 2, + STRING_TYPE_INFO = 3, + LIST_TYPE_INFO = 4, + STRUCT_TYPE_INFO = 5, + ENUM_TYPE_INFO = 6, + USER_TYPE_INFO = 7, + AGGREGATE_STATE_TYPE_INFO = 8 +}; + +struct ExtraTypeInfo { + explicit ExtraTypeInfo(ExtraTypeInfoType type); + explicit ExtraTypeInfo(ExtraTypeInfoType type, string alias); + virtual ~ExtraTypeInfo(); + + ExtraTypeInfoType type; + string alias; + +public: + bool Equals(ExtraTypeInfo *other_p) const; + + virtual void Serialize(Serializer &serializer) const; + static shared_ptr Deserialize(Deserializer &source); + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + +protected: + virtual bool EqualsInternal(ExtraTypeInfo *other_p) const; +}; + +struct DecimalTypeInfo : public ExtraTypeInfo { + DecimalTypeInfo(uint8_t width_p, uint8_t scale_p); + + uint8_t width; + uint8_t scale; + +public: + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + +private: + DecimalTypeInfo(); +}; + +struct StringTypeInfo : public ExtraTypeInfo { + explicit StringTypeInfo(string collation_p); + + string collation; + +public: + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + +private: + StringTypeInfo(); +}; + +struct ListTypeInfo : public ExtraTypeInfo { + explicit ListTypeInfo(LogicalType child_type_p); + + LogicalType child_type; + +public: + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + +private: + ListTypeInfo(); +}; + +struct StructTypeInfo : public ExtraTypeInfo { + explicit StructTypeInfo(child_list_t child_types_p); + + child_list_t child_types; + +public: + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &deserializer); + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + +private: + StructTypeInfo(); +}; + +struct AggregateStateTypeInfo : public ExtraTypeInfo { + explicit AggregateStateTypeInfo(aggregate_state_t state_type_p); + + aggregate_state_t state_type; + +public: + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + +private: + AggregateStateTypeInfo(); +}; + +struct UserTypeInfo : public ExtraTypeInfo { + explicit UserTypeInfo(string name_p); + + string user_type_name; + +public: + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + +private: + UserTypeInfo(); +}; + +// If this type is primarily stored in the catalog or not. Enums from Pandas/Factors are not in the catalog. +enum EnumDictType : uint8_t { INVALID = 0, VECTOR_DICT = 1 }; + +struct EnumTypeInfo : public ExtraTypeInfo { + explicit EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p); + EnumTypeInfo(const EnumTypeInfo &) = delete; + EnumTypeInfo &operator=(const EnumTypeInfo &) = delete; + +public: + const EnumDictType &GetEnumDictType() const; + const Vector &GetValuesInsertOrder() const; + const idx_t &GetDictSize() const; + static PhysicalType DictType(idx_t size); + + static LogicalType CreateType(Vector &ordered_data, idx_t size); + + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + +protected: + // Equalities are only used in enums with different catalog entries + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + + Vector values_insert_order; + +private: + EnumDictType dict_type; + idx_t dict_size; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/fast_mem.hpp b/src/duckdb/src/include/duckdb/common/fast_mem.hpp new file mode 100644 index 00000000..dc2730f3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/fast_mem.hpp @@ -0,0 +1,1221 @@ +// DuckDB +// +// duckdb/common/fast_mem.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" + +template +static inline void MemcpyFixed(void *dest, const void *src) { + memcpy(dest, src, SIZE); +} + +template +static inline int MemcmpFixed(const void *str1, const void *str2) { + return memcmp(str1, str2, SIZE); +} + +template +static inline void MemsetFixed(void *ptr, int value) { + memset(ptr, value, SIZE); +} + +namespace duckdb { + +//! This templated memcpy is significantly faster than std::memcpy, +//! but only when you are calling memcpy with a const size in a loop. +//! For instance `while () { memcpy(, , const_size); ... }` +static inline void FastMemcpy(void *dest, const void *src, const size_t size) { + // LCOV_EXCL_START + switch (size) { + case 0: + return; + case 1: + return MemcpyFixed<1>(dest, src); + case 2: + return MemcpyFixed<2>(dest, src); + case 3: + return MemcpyFixed<3>(dest, src); + case 4: + return MemcpyFixed<4>(dest, src); + case 5: + return MemcpyFixed<5>(dest, src); + case 6: + return MemcpyFixed<6>(dest, src); + case 7: + return MemcpyFixed<7>(dest, src); + case 8: + return MemcpyFixed<8>(dest, src); + case 9: + return MemcpyFixed<9>(dest, src); + case 10: + return MemcpyFixed<10>(dest, src); + case 11: + return MemcpyFixed<11>(dest, src); + case 12: + return MemcpyFixed<12>(dest, src); + case 13: + return MemcpyFixed<13>(dest, src); + case 14: + return MemcpyFixed<14>(dest, src); + case 15: + return MemcpyFixed<15>(dest, src); + case 16: + return MemcpyFixed<16>(dest, src); + case 17: + return MemcpyFixed<17>(dest, src); + case 18: + return MemcpyFixed<18>(dest, src); + case 19: + return MemcpyFixed<19>(dest, src); + case 20: + return MemcpyFixed<20>(dest, src); + case 21: + return MemcpyFixed<21>(dest, src); + case 22: + return MemcpyFixed<22>(dest, src); + case 23: + return MemcpyFixed<23>(dest, src); + case 24: + return MemcpyFixed<24>(dest, src); + case 25: + return MemcpyFixed<25>(dest, src); + case 26: + return MemcpyFixed<26>(dest, src); + case 27: + return MemcpyFixed<27>(dest, src); + case 28: + return MemcpyFixed<28>(dest, src); + case 29: + return MemcpyFixed<29>(dest, src); + case 30: + return MemcpyFixed<30>(dest, src); + case 31: + return MemcpyFixed<31>(dest, src); + case 32: + return MemcpyFixed<32>(dest, src); + case 33: + return MemcpyFixed<33>(dest, src); + case 34: + return MemcpyFixed<34>(dest, src); + case 35: + return MemcpyFixed<35>(dest, src); + case 36: + return MemcpyFixed<36>(dest, src); + case 37: + return MemcpyFixed<37>(dest, src); + case 38: + return MemcpyFixed<38>(dest, src); + case 39: + return MemcpyFixed<39>(dest, src); + case 40: + return MemcpyFixed<40>(dest, src); + case 41: + return MemcpyFixed<41>(dest, src); + case 42: + return MemcpyFixed<42>(dest, src); + case 43: + return MemcpyFixed<43>(dest, src); + case 44: + return MemcpyFixed<44>(dest, src); + case 45: + return MemcpyFixed<45>(dest, src); + case 46: + return MemcpyFixed<46>(dest, src); + case 47: + return MemcpyFixed<47>(dest, src); + case 48: + return MemcpyFixed<48>(dest, src); + case 49: + return MemcpyFixed<49>(dest, src); + case 50: + return MemcpyFixed<50>(dest, src); + case 51: + return MemcpyFixed<51>(dest, src); + case 52: + return MemcpyFixed<52>(dest, src); + case 53: + return MemcpyFixed<53>(dest, src); + case 54: + return MemcpyFixed<54>(dest, src); + case 55: + return MemcpyFixed<55>(dest, src); + case 56: + return MemcpyFixed<56>(dest, src); + case 57: + return MemcpyFixed<57>(dest, src); + case 58: + return MemcpyFixed<58>(dest, src); + case 59: + return MemcpyFixed<59>(dest, src); + case 60: + return MemcpyFixed<60>(dest, src); + case 61: + return MemcpyFixed<61>(dest, src); + case 62: + return MemcpyFixed<62>(dest, src); + case 63: + return MemcpyFixed<63>(dest, src); + case 64: + return MemcpyFixed<64>(dest, src); + case 65: + return MemcpyFixed<65>(dest, src); + case 66: + return MemcpyFixed<66>(dest, src); + case 67: + return MemcpyFixed<67>(dest, src); + case 68: + return MemcpyFixed<68>(dest, src); + case 69: + return MemcpyFixed<69>(dest, src); + case 70: + return MemcpyFixed<70>(dest, src); + case 71: + return MemcpyFixed<71>(dest, src); + case 72: + return MemcpyFixed<72>(dest, src); + case 73: + return MemcpyFixed<73>(dest, src); + case 74: + return MemcpyFixed<74>(dest, src); + case 75: + return MemcpyFixed<75>(dest, src); + case 76: + return MemcpyFixed<76>(dest, src); + case 77: + return MemcpyFixed<77>(dest, src); + case 78: + return MemcpyFixed<78>(dest, src); + case 79: + return MemcpyFixed<79>(dest, src); + case 80: + return MemcpyFixed<80>(dest, src); + case 81: + return MemcpyFixed<81>(dest, src); + case 82: + return MemcpyFixed<82>(dest, src); + case 83: + return MemcpyFixed<83>(dest, src); + case 84: + return MemcpyFixed<84>(dest, src); + case 85: + return MemcpyFixed<85>(dest, src); + case 86: + return MemcpyFixed<86>(dest, src); + case 87: + return MemcpyFixed<87>(dest, src); + case 88: + return MemcpyFixed<88>(dest, src); + case 89: + return MemcpyFixed<89>(dest, src); + case 90: + return MemcpyFixed<90>(dest, src); + case 91: + return MemcpyFixed<91>(dest, src); + case 92: + return MemcpyFixed<92>(dest, src); + case 93: + return MemcpyFixed<93>(dest, src); + case 94: + return MemcpyFixed<94>(dest, src); + case 95: + return MemcpyFixed<95>(dest, src); + case 96: + return MemcpyFixed<96>(dest, src); + case 97: + return MemcpyFixed<97>(dest, src); + case 98: + return MemcpyFixed<98>(dest, src); + case 99: + return MemcpyFixed<99>(dest, src); + case 100: + return MemcpyFixed<100>(dest, src); + case 101: + return MemcpyFixed<101>(dest, src); + case 102: + return MemcpyFixed<102>(dest, src); + case 103: + return MemcpyFixed<103>(dest, src); + case 104: + return MemcpyFixed<104>(dest, src); + case 105: + return MemcpyFixed<105>(dest, src); + case 106: + return MemcpyFixed<106>(dest, src); + case 107: + return MemcpyFixed<107>(dest, src); + case 108: + return MemcpyFixed<108>(dest, src); + case 109: + return MemcpyFixed<109>(dest, src); + case 110: + return MemcpyFixed<110>(dest, src); + case 111: + return MemcpyFixed<111>(dest, src); + case 112: + return MemcpyFixed<112>(dest, src); + case 113: + return MemcpyFixed<113>(dest, src); + case 114: + return MemcpyFixed<114>(dest, src); + case 115: + return MemcpyFixed<115>(dest, src); + case 116: + return MemcpyFixed<116>(dest, src); + case 117: + return MemcpyFixed<117>(dest, src); + case 118: + return MemcpyFixed<118>(dest, src); + case 119: + return MemcpyFixed<119>(dest, src); + case 120: + return MemcpyFixed<120>(dest, src); + case 121: + return MemcpyFixed<121>(dest, src); + case 122: + return MemcpyFixed<122>(dest, src); + case 123: + return MemcpyFixed<123>(dest, src); + case 124: + return MemcpyFixed<124>(dest, src); + case 125: + return MemcpyFixed<125>(dest, src); + case 126: + return MemcpyFixed<126>(dest, src); + case 127: + return MemcpyFixed<127>(dest, src); + case 128: + return MemcpyFixed<128>(dest, src); + case 129: + return MemcpyFixed<129>(dest, src); + case 130: + return MemcpyFixed<130>(dest, src); + case 131: + return MemcpyFixed<131>(dest, src); + case 132: + return MemcpyFixed<132>(dest, src); + case 133: + return MemcpyFixed<133>(dest, src); + case 134: + return MemcpyFixed<134>(dest, src); + case 135: + return MemcpyFixed<135>(dest, src); + case 136: + return MemcpyFixed<136>(dest, src); + case 137: + return MemcpyFixed<137>(dest, src); + case 138: + return MemcpyFixed<138>(dest, src); + case 139: + return MemcpyFixed<139>(dest, src); + case 140: + return MemcpyFixed<140>(dest, src); + case 141: + return MemcpyFixed<141>(dest, src); + case 142: + return MemcpyFixed<142>(dest, src); + case 143: + return MemcpyFixed<143>(dest, src); + case 144: + return MemcpyFixed<144>(dest, src); + case 145: + return MemcpyFixed<145>(dest, src); + case 146: + return MemcpyFixed<146>(dest, src); + case 147: + return MemcpyFixed<147>(dest, src); + case 148: + return MemcpyFixed<148>(dest, src); + case 149: + return MemcpyFixed<149>(dest, src); + case 150: + return MemcpyFixed<150>(dest, src); + case 151: + return MemcpyFixed<151>(dest, src); + case 152: + return MemcpyFixed<152>(dest, src); + case 153: + return MemcpyFixed<153>(dest, src); + case 154: + return MemcpyFixed<154>(dest, src); + case 155: + return MemcpyFixed<155>(dest, src); + case 156: + return MemcpyFixed<156>(dest, src); + case 157: + return MemcpyFixed<157>(dest, src); + case 158: + return MemcpyFixed<158>(dest, src); + case 159: + return MemcpyFixed<159>(dest, src); + case 160: + return MemcpyFixed<160>(dest, src); + case 161: + return MemcpyFixed<161>(dest, src); + case 162: + return MemcpyFixed<162>(dest, src); + case 163: + return MemcpyFixed<163>(dest, src); + case 164: + return MemcpyFixed<164>(dest, src); + case 165: + return MemcpyFixed<165>(dest, src); + case 166: + return MemcpyFixed<166>(dest, src); + case 167: + return MemcpyFixed<167>(dest, src); + case 168: + return MemcpyFixed<168>(dest, src); + case 169: + return MemcpyFixed<169>(dest, src); + case 170: + return MemcpyFixed<170>(dest, src); + case 171: + return MemcpyFixed<171>(dest, src); + case 172: + return MemcpyFixed<172>(dest, src); + case 173: + return MemcpyFixed<173>(dest, src); + case 174: + return MemcpyFixed<174>(dest, src); + case 175: + return MemcpyFixed<175>(dest, src); + case 176: + return MemcpyFixed<176>(dest, src); + case 177: + return MemcpyFixed<177>(dest, src); + case 178: + return MemcpyFixed<178>(dest, src); + case 179: + return MemcpyFixed<179>(dest, src); + case 180: + return MemcpyFixed<180>(dest, src); + case 181: + return MemcpyFixed<181>(dest, src); + case 182: + return MemcpyFixed<182>(dest, src); + case 183: + return MemcpyFixed<183>(dest, src); + case 184: + return MemcpyFixed<184>(dest, src); + case 185: + return MemcpyFixed<185>(dest, src); + case 186: + return MemcpyFixed<186>(dest, src); + case 187: + return MemcpyFixed<187>(dest, src); + case 188: + return MemcpyFixed<188>(dest, src); + case 189: + return MemcpyFixed<189>(dest, src); + case 190: + return MemcpyFixed<190>(dest, src); + case 191: + return MemcpyFixed<191>(dest, src); + case 192: + return MemcpyFixed<192>(dest, src); + case 193: + return MemcpyFixed<193>(dest, src); + case 194: + return MemcpyFixed<194>(dest, src); + case 195: + return MemcpyFixed<195>(dest, src); + case 196: + return MemcpyFixed<196>(dest, src); + case 197: + return MemcpyFixed<197>(dest, src); + case 198: + return MemcpyFixed<198>(dest, src); + case 199: + return MemcpyFixed<199>(dest, src); + case 200: + return MemcpyFixed<200>(dest, src); + case 201: + return MemcpyFixed<201>(dest, src); + case 202: + return MemcpyFixed<202>(dest, src); + case 203: + return MemcpyFixed<203>(dest, src); + case 204: + return MemcpyFixed<204>(dest, src); + case 205: + return MemcpyFixed<205>(dest, src); + case 206: + return MemcpyFixed<206>(dest, src); + case 207: + return MemcpyFixed<207>(dest, src); + case 208: + return MemcpyFixed<208>(dest, src); + case 209: + return MemcpyFixed<209>(dest, src); + case 210: + return MemcpyFixed<210>(dest, src); + case 211: + return MemcpyFixed<211>(dest, src); + case 212: + return MemcpyFixed<212>(dest, src); + case 213: + return MemcpyFixed<213>(dest, src); + case 214: + return MemcpyFixed<214>(dest, src); + case 215: + return MemcpyFixed<215>(dest, src); + case 216: + return MemcpyFixed<216>(dest, src); + case 217: + return MemcpyFixed<217>(dest, src); + case 218: + return MemcpyFixed<218>(dest, src); + case 219: + return MemcpyFixed<219>(dest, src); + case 220: + return MemcpyFixed<220>(dest, src); + case 221: + return MemcpyFixed<221>(dest, src); + case 222: + return MemcpyFixed<222>(dest, src); + case 223: + return MemcpyFixed<223>(dest, src); + case 224: + return MemcpyFixed<224>(dest, src); + case 225: + return MemcpyFixed<225>(dest, src); + case 226: + return MemcpyFixed<226>(dest, src); + case 227: + return MemcpyFixed<227>(dest, src); + case 228: + return MemcpyFixed<228>(dest, src); + case 229: + return MemcpyFixed<229>(dest, src); + case 230: + return MemcpyFixed<230>(dest, src); + case 231: + return MemcpyFixed<231>(dest, src); + case 232: + return MemcpyFixed<232>(dest, src); + case 233: + return MemcpyFixed<233>(dest, src); + case 234: + return MemcpyFixed<234>(dest, src); + case 235: + return MemcpyFixed<235>(dest, src); + case 236: + return MemcpyFixed<236>(dest, src); + case 237: + return MemcpyFixed<237>(dest, src); + case 238: + return MemcpyFixed<238>(dest, src); + case 239: + return MemcpyFixed<239>(dest, src); + case 240: + return MemcpyFixed<240>(dest, src); + case 241: + return MemcpyFixed<241>(dest, src); + case 242: + return MemcpyFixed<242>(dest, src); + case 243: + return MemcpyFixed<243>(dest, src); + case 244: + return MemcpyFixed<244>(dest, src); + case 245: + return MemcpyFixed<245>(dest, src); + case 246: + return MemcpyFixed<246>(dest, src); + case 247: + return MemcpyFixed<247>(dest, src); + case 248: + return MemcpyFixed<248>(dest, src); + case 249: + return MemcpyFixed<249>(dest, src); + case 250: + return MemcpyFixed<250>(dest, src); + case 251: + return MemcpyFixed<251>(dest, src); + case 252: + return MemcpyFixed<252>(dest, src); + case 253: + return MemcpyFixed<253>(dest, src); + case 254: + return MemcpyFixed<254>(dest, src); + case 255: + return MemcpyFixed<255>(dest, src); + case 256: + return MemcpyFixed<256>(dest, src); + default: + memcpy(dest, src, size); + } + // LCOV_EXCL_STOP +} + +//! This templated memcmp is significantly faster than std::memcmp, +//! but only when you are calling memcmp with a const size in a loop. +//! For instance `while () { memcmp(, , const_size); ... }` +static inline int FastMemcmp(const void *str1, const void *str2, const size_t size) { + // LCOV_EXCL_START + switch (size) { + case 0: + return 0; + case 1: + return MemcmpFixed<1>(str1, str2); + case 2: + return MemcmpFixed<2>(str1, str2); + case 3: + return MemcmpFixed<3>(str1, str2); + case 4: + return MemcmpFixed<4>(str1, str2); + case 5: + return MemcmpFixed<5>(str1, str2); + case 6: + return MemcmpFixed<6>(str1, str2); + case 7: + return MemcmpFixed<7>(str1, str2); + case 8: + return MemcmpFixed<8>(str1, str2); + case 9: + return MemcmpFixed<9>(str1, str2); + case 10: + return MemcmpFixed<10>(str1, str2); + case 11: + return MemcmpFixed<11>(str1, str2); + case 12: + return MemcmpFixed<12>(str1, str2); + case 13: + return MemcmpFixed<13>(str1, str2); + case 14: + return MemcmpFixed<14>(str1, str2); + case 15: + return MemcmpFixed<15>(str1, str2); + case 16: + return MemcmpFixed<16>(str1, str2); + case 17: + return MemcmpFixed<17>(str1, str2); + case 18: + return MemcmpFixed<18>(str1, str2); + case 19: + return MemcmpFixed<19>(str1, str2); + case 20: + return MemcmpFixed<20>(str1, str2); + case 21: + return MemcmpFixed<21>(str1, str2); + case 22: + return MemcmpFixed<22>(str1, str2); + case 23: + return MemcmpFixed<23>(str1, str2); + case 24: + return MemcmpFixed<24>(str1, str2); + case 25: + return MemcmpFixed<25>(str1, str2); + case 26: + return MemcmpFixed<26>(str1, str2); + case 27: + return MemcmpFixed<27>(str1, str2); + case 28: + return MemcmpFixed<28>(str1, str2); + case 29: + return MemcmpFixed<29>(str1, str2); + case 30: + return MemcmpFixed<30>(str1, str2); + case 31: + return MemcmpFixed<31>(str1, str2); + case 32: + return MemcmpFixed<32>(str1, str2); + case 33: + return MemcmpFixed<33>(str1, str2); + case 34: + return MemcmpFixed<34>(str1, str2); + case 35: + return MemcmpFixed<35>(str1, str2); + case 36: + return MemcmpFixed<36>(str1, str2); + case 37: + return MemcmpFixed<37>(str1, str2); + case 38: + return MemcmpFixed<38>(str1, str2); + case 39: + return MemcmpFixed<39>(str1, str2); + case 40: + return MemcmpFixed<40>(str1, str2); + case 41: + return MemcmpFixed<41>(str1, str2); + case 42: + return MemcmpFixed<42>(str1, str2); + case 43: + return MemcmpFixed<43>(str1, str2); + case 44: + return MemcmpFixed<44>(str1, str2); + case 45: + return MemcmpFixed<45>(str1, str2); + case 46: + return MemcmpFixed<46>(str1, str2); + case 47: + return MemcmpFixed<47>(str1, str2); + case 48: + return MemcmpFixed<48>(str1, str2); + case 49: + return MemcmpFixed<49>(str1, str2); + case 50: + return MemcmpFixed<50>(str1, str2); + case 51: + return MemcmpFixed<51>(str1, str2); + case 52: + return MemcmpFixed<52>(str1, str2); + case 53: + return MemcmpFixed<53>(str1, str2); + case 54: + return MemcmpFixed<54>(str1, str2); + case 55: + return MemcmpFixed<55>(str1, str2); + case 56: + return MemcmpFixed<56>(str1, str2); + case 57: + return MemcmpFixed<57>(str1, str2); + case 58: + return MemcmpFixed<58>(str1, str2); + case 59: + return MemcmpFixed<59>(str1, str2); + case 60: + return MemcmpFixed<60>(str1, str2); + case 61: + return MemcmpFixed<61>(str1, str2); + case 62: + return MemcmpFixed<62>(str1, str2); + case 63: + return MemcmpFixed<63>(str1, str2); + case 64: + return MemcmpFixed<64>(str1, str2); + default: + return memcmp(str1, str2, size); + } + // LCOV_EXCL_STOP +} + +static inline void FastMemset(void *ptr, int value, size_t size) { + // LCOV_EXCL_START + switch (size) { + case 0: + return; + case 1: + return MemsetFixed<1>(ptr, value); + case 2: + return MemsetFixed<2>(ptr, value); + case 3: + return MemsetFixed<3>(ptr, value); + case 4: + return MemsetFixed<4>(ptr, value); + case 5: + return MemsetFixed<5>(ptr, value); + case 6: + return MemsetFixed<6>(ptr, value); + case 7: + return MemsetFixed<7>(ptr, value); + case 8: + return MemsetFixed<8>(ptr, value); + case 9: + return MemsetFixed<9>(ptr, value); + case 10: + return MemsetFixed<10>(ptr, value); + case 11: + return MemsetFixed<11>(ptr, value); + case 12: + return MemsetFixed<12>(ptr, value); + case 13: + return MemsetFixed<13>(ptr, value); + case 14: + return MemsetFixed<14>(ptr, value); + case 15: + return MemsetFixed<15>(ptr, value); + case 16: + return MemsetFixed<16>(ptr, value); + case 17: + return MemsetFixed<17>(ptr, value); + case 18: + return MemsetFixed<18>(ptr, value); + case 19: + return MemsetFixed<19>(ptr, value); + case 20: + return MemsetFixed<20>(ptr, value); + case 21: + return MemsetFixed<21>(ptr, value); + case 22: + return MemsetFixed<22>(ptr, value); + case 23: + return MemsetFixed<23>(ptr, value); + case 24: + return MemsetFixed<24>(ptr, value); + case 25: + return MemsetFixed<25>(ptr, value); + case 26: + return MemsetFixed<26>(ptr, value); + case 27: + return MemsetFixed<27>(ptr, value); + case 28: + return MemsetFixed<28>(ptr, value); + case 29: + return MemsetFixed<29>(ptr, value); + case 30: + return MemsetFixed<30>(ptr, value); + case 31: + return MemsetFixed<31>(ptr, value); + case 32: + return MemsetFixed<32>(ptr, value); + case 33: + return MemsetFixed<33>(ptr, value); + case 34: + return MemsetFixed<34>(ptr, value); + case 35: + return MemsetFixed<35>(ptr, value); + case 36: + return MemsetFixed<36>(ptr, value); + case 37: + return MemsetFixed<37>(ptr, value); + case 38: + return MemsetFixed<38>(ptr, value); + case 39: + return MemsetFixed<39>(ptr, value); + case 40: + return MemsetFixed<40>(ptr, value); + case 41: + return MemsetFixed<41>(ptr, value); + case 42: + return MemsetFixed<42>(ptr, value); + case 43: + return MemsetFixed<43>(ptr, value); + case 44: + return MemsetFixed<44>(ptr, value); + case 45: + return MemsetFixed<45>(ptr, value); + case 46: + return MemsetFixed<46>(ptr, value); + case 47: + return MemsetFixed<47>(ptr, value); + case 48: + return MemsetFixed<48>(ptr, value); + case 49: + return MemsetFixed<49>(ptr, value); + case 50: + return MemsetFixed<50>(ptr, value); + case 51: + return MemsetFixed<51>(ptr, value); + case 52: + return MemsetFixed<52>(ptr, value); + case 53: + return MemsetFixed<53>(ptr, value); + case 54: + return MemsetFixed<54>(ptr, value); + case 55: + return MemsetFixed<55>(ptr, value); + case 56: + return MemsetFixed<56>(ptr, value); + case 57: + return MemsetFixed<57>(ptr, value); + case 58: + return MemsetFixed<58>(ptr, value); + case 59: + return MemsetFixed<59>(ptr, value); + case 60: + return MemsetFixed<60>(ptr, value); + case 61: + return MemsetFixed<61>(ptr, value); + case 62: + return MemsetFixed<62>(ptr, value); + case 63: + return MemsetFixed<63>(ptr, value); + case 64: + return MemsetFixed<64>(ptr, value); + case 65: + return MemsetFixed<65>(ptr, value); + case 66: + return MemsetFixed<66>(ptr, value); + case 67: + return MemsetFixed<67>(ptr, value); + case 68: + return MemsetFixed<68>(ptr, value); + case 69: + return MemsetFixed<69>(ptr, value); + case 70: + return MemsetFixed<70>(ptr, value); + case 71: + return MemsetFixed<71>(ptr, value); + case 72: + return MemsetFixed<72>(ptr, value); + case 73: + return MemsetFixed<73>(ptr, value); + case 74: + return MemsetFixed<74>(ptr, value); + case 75: + return MemsetFixed<75>(ptr, value); + case 76: + return MemsetFixed<76>(ptr, value); + case 77: + return MemsetFixed<77>(ptr, value); + case 78: + return MemsetFixed<78>(ptr, value); + case 79: + return MemsetFixed<79>(ptr, value); + case 80: + return MemsetFixed<80>(ptr, value); + case 81: + return MemsetFixed<81>(ptr, value); + case 82: + return MemsetFixed<82>(ptr, value); + case 83: + return MemsetFixed<83>(ptr, value); + case 84: + return MemsetFixed<84>(ptr, value); + case 85: + return MemsetFixed<85>(ptr, value); + case 86: + return MemsetFixed<86>(ptr, value); + case 87: + return MemsetFixed<87>(ptr, value); + case 88: + return MemsetFixed<88>(ptr, value); + case 89: + return MemsetFixed<89>(ptr, value); + case 90: + return MemsetFixed<90>(ptr, value); + case 91: + return MemsetFixed<91>(ptr, value); + case 92: + return MemsetFixed<92>(ptr, value); + case 93: + return MemsetFixed<93>(ptr, value); + case 94: + return MemsetFixed<94>(ptr, value); + case 95: + return MemsetFixed<95>(ptr, value); + case 96: + return MemsetFixed<96>(ptr, value); + case 97: + return MemsetFixed<97>(ptr, value); + case 98: + return MemsetFixed<98>(ptr, value); + case 99: + return MemsetFixed<99>(ptr, value); + case 100: + return MemsetFixed<100>(ptr, value); + case 101: + return MemsetFixed<101>(ptr, value); + case 102: + return MemsetFixed<102>(ptr, value); + case 103: + return MemsetFixed<103>(ptr, value); + case 104: + return MemsetFixed<104>(ptr, value); + case 105: + return MemsetFixed<105>(ptr, value); + case 106: + return MemsetFixed<106>(ptr, value); + case 107: + return MemsetFixed<107>(ptr, value); + case 108: + return MemsetFixed<108>(ptr, value); + case 109: + return MemsetFixed<109>(ptr, value); + case 110: + return MemsetFixed<110>(ptr, value); + case 111: + return MemsetFixed<111>(ptr, value); + case 112: + return MemsetFixed<112>(ptr, value); + case 113: + return MemsetFixed<113>(ptr, value); + case 114: + return MemsetFixed<114>(ptr, value); + case 115: + return MemsetFixed<115>(ptr, value); + case 116: + return MemsetFixed<116>(ptr, value); + case 117: + return MemsetFixed<117>(ptr, value); + case 118: + return MemsetFixed<118>(ptr, value); + case 119: + return MemsetFixed<119>(ptr, value); + case 120: + return MemsetFixed<120>(ptr, value); + case 121: + return MemsetFixed<121>(ptr, value); + case 122: + return MemsetFixed<122>(ptr, value); + case 123: + return MemsetFixed<123>(ptr, value); + case 124: + return MemsetFixed<124>(ptr, value); + case 125: + return MemsetFixed<125>(ptr, value); + case 126: + return MemsetFixed<126>(ptr, value); + case 127: + return MemsetFixed<127>(ptr, value); + case 128: + return MemsetFixed<128>(ptr, value); + case 129: + return MemsetFixed<129>(ptr, value); + case 130: + return MemsetFixed<130>(ptr, value); + case 131: + return MemsetFixed<131>(ptr, value); + case 132: + return MemsetFixed<132>(ptr, value); + case 133: + return MemsetFixed<133>(ptr, value); + case 134: + return MemsetFixed<134>(ptr, value); + case 135: + return MemsetFixed<135>(ptr, value); + case 136: + return MemsetFixed<136>(ptr, value); + case 137: + return MemsetFixed<137>(ptr, value); + case 138: + return MemsetFixed<138>(ptr, value); + case 139: + return MemsetFixed<139>(ptr, value); + case 140: + return MemsetFixed<140>(ptr, value); + case 141: + return MemsetFixed<141>(ptr, value); + case 142: + return MemsetFixed<142>(ptr, value); + case 143: + return MemsetFixed<143>(ptr, value); + case 144: + return MemsetFixed<144>(ptr, value); + case 145: + return MemsetFixed<145>(ptr, value); + case 146: + return MemsetFixed<146>(ptr, value); + case 147: + return MemsetFixed<147>(ptr, value); + case 148: + return MemsetFixed<148>(ptr, value); + case 149: + return MemsetFixed<149>(ptr, value); + case 150: + return MemsetFixed<150>(ptr, value); + case 151: + return MemsetFixed<151>(ptr, value); + case 152: + return MemsetFixed<152>(ptr, value); + case 153: + return MemsetFixed<153>(ptr, value); + case 154: + return MemsetFixed<154>(ptr, value); + case 155: + return MemsetFixed<155>(ptr, value); + case 156: + return MemsetFixed<156>(ptr, value); + case 157: + return MemsetFixed<157>(ptr, value); + case 158: + return MemsetFixed<158>(ptr, value); + case 159: + return MemsetFixed<159>(ptr, value); + case 160: + return MemsetFixed<160>(ptr, value); + case 161: + return MemsetFixed<161>(ptr, value); + case 162: + return MemsetFixed<162>(ptr, value); + case 163: + return MemsetFixed<163>(ptr, value); + case 164: + return MemsetFixed<164>(ptr, value); + case 165: + return MemsetFixed<165>(ptr, value); + case 166: + return MemsetFixed<166>(ptr, value); + case 167: + return MemsetFixed<167>(ptr, value); + case 168: + return MemsetFixed<168>(ptr, value); + case 169: + return MemsetFixed<169>(ptr, value); + case 170: + return MemsetFixed<170>(ptr, value); + case 171: + return MemsetFixed<171>(ptr, value); + case 172: + return MemsetFixed<172>(ptr, value); + case 173: + return MemsetFixed<173>(ptr, value); + case 174: + return MemsetFixed<174>(ptr, value); + case 175: + return MemsetFixed<175>(ptr, value); + case 176: + return MemsetFixed<176>(ptr, value); + case 177: + return MemsetFixed<177>(ptr, value); + case 178: + return MemsetFixed<178>(ptr, value); + case 179: + return MemsetFixed<179>(ptr, value); + case 180: + return MemsetFixed<180>(ptr, value); + case 181: + return MemsetFixed<181>(ptr, value); + case 182: + return MemsetFixed<182>(ptr, value); + case 183: + return MemsetFixed<183>(ptr, value); + case 184: + return MemsetFixed<184>(ptr, value); + case 185: + return MemsetFixed<185>(ptr, value); + case 186: + return MemsetFixed<186>(ptr, value); + case 187: + return MemsetFixed<187>(ptr, value); + case 188: + return MemsetFixed<188>(ptr, value); + case 189: + return MemsetFixed<189>(ptr, value); + case 190: + return MemsetFixed<190>(ptr, value); + case 191: + return MemsetFixed<191>(ptr, value); + case 192: + return MemsetFixed<192>(ptr, value); + case 193: + return MemsetFixed<193>(ptr, value); + case 194: + return MemsetFixed<194>(ptr, value); + case 195: + return MemsetFixed<195>(ptr, value); + case 196: + return MemsetFixed<196>(ptr, value); + case 197: + return MemsetFixed<197>(ptr, value); + case 198: + return MemsetFixed<198>(ptr, value); + case 199: + return MemsetFixed<199>(ptr, value); + case 200: + return MemsetFixed<200>(ptr, value); + case 201: + return MemsetFixed<201>(ptr, value); + case 202: + return MemsetFixed<202>(ptr, value); + case 203: + return MemsetFixed<203>(ptr, value); + case 204: + return MemsetFixed<204>(ptr, value); + case 205: + return MemsetFixed<205>(ptr, value); + case 206: + return MemsetFixed<206>(ptr, value); + case 207: + return MemsetFixed<207>(ptr, value); + case 208: + return MemsetFixed<208>(ptr, value); + case 209: + return MemsetFixed<209>(ptr, value); + case 210: + return MemsetFixed<210>(ptr, value); + case 211: + return MemsetFixed<211>(ptr, value); + case 212: + return MemsetFixed<212>(ptr, value); + case 213: + return MemsetFixed<213>(ptr, value); + case 214: + return MemsetFixed<214>(ptr, value); + case 215: + return MemsetFixed<215>(ptr, value); + case 216: + return MemsetFixed<216>(ptr, value); + case 217: + return MemsetFixed<217>(ptr, value); + case 218: + return MemsetFixed<218>(ptr, value); + case 219: + return MemsetFixed<219>(ptr, value); + case 220: + return MemsetFixed<220>(ptr, value); + case 221: + return MemsetFixed<221>(ptr, value); + case 222: + return MemsetFixed<222>(ptr, value); + case 223: + return MemsetFixed<223>(ptr, value); + case 224: + return MemsetFixed<224>(ptr, value); + case 225: + return MemsetFixed<225>(ptr, value); + case 226: + return MemsetFixed<226>(ptr, value); + case 227: + return MemsetFixed<227>(ptr, value); + case 228: + return MemsetFixed<228>(ptr, value); + case 229: + return MemsetFixed<229>(ptr, value); + case 230: + return MemsetFixed<230>(ptr, value); + case 231: + return MemsetFixed<231>(ptr, value); + case 232: + return MemsetFixed<232>(ptr, value); + case 233: + return MemsetFixed<233>(ptr, value); + case 234: + return MemsetFixed<234>(ptr, value); + case 235: + return MemsetFixed<235>(ptr, value); + case 236: + return MemsetFixed<236>(ptr, value); + case 237: + return MemsetFixed<237>(ptr, value); + case 238: + return MemsetFixed<238>(ptr, value); + case 239: + return MemsetFixed<239>(ptr, value); + case 240: + return MemsetFixed<240>(ptr, value); + case 241: + return MemsetFixed<241>(ptr, value); + case 242: + return MemsetFixed<242>(ptr, value); + case 243: + return MemsetFixed<243>(ptr, value); + case 244: + return MemsetFixed<244>(ptr, value); + case 245: + return MemsetFixed<245>(ptr, value); + case 246: + return MemsetFixed<246>(ptr, value); + case 247: + return MemsetFixed<247>(ptr, value); + case 248: + return MemsetFixed<248>(ptr, value); + case 249: + return MemsetFixed<249>(ptr, value); + case 250: + return MemsetFixed<250>(ptr, value); + case 251: + return MemsetFixed<251>(ptr, value); + case 252: + return MemsetFixed<252>(ptr, value); + case 253: + return MemsetFixed<253>(ptr, value); + case 254: + return MemsetFixed<254>(ptr, value); + case 255: + return MemsetFixed<255>(ptr, value); + case 256: + return MemsetFixed<256>(ptr, value); + default: + memset(ptr, value, size); + } + // LCOV_EXCL_STOP +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/file_buffer.hpp b/src/duckdb/src/include/duckdb/common/file_buffer.hpp new file mode 100644 index 00000000..876aeab4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/file_buffer.hpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/file_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/enums/debug_initialize.hpp" + +namespace duckdb { +class Allocator; +struct FileHandle; + +enum class FileBufferType : uint8_t { BLOCK = 1, MANAGED_BUFFER = 2, TINY_BUFFER = 3 }; + +//! The FileBuffer represents a buffer that can be read or written to a Direct IO FileHandle. +class FileBuffer { +public: + //! Allocates a buffer of the specified size, with room for additional header bytes + //! (typically 8 bytes). On return, this->AllocSize() >= this->size >= user_size. + //! Our allocation size will always be page-aligned, which is necessary to support + //! DIRECT_IO + FileBuffer(Allocator &allocator, FileBufferType type, uint64_t user_size); + FileBuffer(FileBuffer &source, FileBufferType type); + + virtual ~FileBuffer(); + + Allocator &allocator; + //! The type of the buffer + FileBufferType type; + //! The buffer that users can write to + data_ptr_t buffer; + //! The size of the portion that users can write to, this is equivalent to internal_size - BLOCK_HEADER_SIZE + uint64_t size; + +public: + //! Read into the FileBuffer from the specified location. + void Read(FileHandle &handle, uint64_t location); + //! Write the contents of the FileBuffer to the specified location. + void Write(FileHandle &handle, uint64_t location); + + void Clear(); + + // Same rules as the constructor. We will add room for a header, in additio to + // the requested user bytes. We will then sector-align the result. + void Resize(uint64_t user_size); + + uint64_t AllocSize() const { + return internal_size; + } + data_ptr_t InternalBuffer() { + return internal_buffer; + } + + struct MemoryRequirement { + idx_t alloc_size; + idx_t header_size; + }; + + MemoryRequirement CalculateMemory(uint64_t user_size); + + void Initialize(DebugInitialize info); + +protected: + //! The pointer to the internal buffer that will be read or written, including the buffer header + data_ptr_t internal_buffer; + //! The aligned size as passed to the constructor. This is the size that is read or written to disk. + uint64_t internal_size; + + void ReallocBuffer(size_t malloc_size); + void Init(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/file_opener.hpp b/src/duckdb/src/include/duckdb/common/file_opener.hpp new file mode 100644 index 00000000..7268599a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/file_opener.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/file_opener.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string.hpp" +#include "duckdb/common/winapi.hpp" + +namespace duckdb { + +class ClientContext; +class Value; + +struct FileOpenerInfo { + string file_path; +}; + +//! Abstract type that provide client-specific context to FileSystem. +class FileOpener { +public: + FileOpener() { + } + virtual ~FileOpener() {}; + + virtual bool TryGetCurrentSetting(const string &key, Value &result, FileOpenerInfo &info); + virtual bool TryGetCurrentSetting(const string &key, Value &result) = 0; + virtual ClientContext *TryGetClientContext() = 0; + + DUCKDB_API static ClientContext *TryGetClientContext(FileOpener *opener); + DUCKDB_API static bool TryGetCurrentSetting(FileOpener *opener, const string &key, Value &result); + DUCKDB_API static bool TryGetCurrentSetting(FileOpener *opener, const string &key, Value &result, + FileOpenerInfo &info); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/file_system.hpp b/src/duckdb/src/include/duckdb/common/file_system.hpp new file mode 100644 index 00000000..1f06b3ed --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/file_system.hpp @@ -0,0 +1,253 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/file_system.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/enums/file_compression_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_buffer.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/file_glob_options.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include + +#undef CreateDirectory +#undef MoveFile +#undef RemoveDirectory + +namespace duckdb { +class AttachedDatabase; +class ClientContext; +class DatabaseInstance; +class FileOpener; +class FileSystem; + +enum class FileType { + //! Regular file + FILE_TYPE_REGULAR, + //! Directory + FILE_TYPE_DIR, + //! FIFO named pipe + FILE_TYPE_FIFO, + //! Socket + FILE_TYPE_SOCKET, + //! Symbolic link + FILE_TYPE_LINK, + //! Block device + FILE_TYPE_BLOCKDEV, + //! Character device + FILE_TYPE_CHARDEV, + //! Unknown or invalid file handle + FILE_TYPE_INVALID, +}; + +struct FileHandle { +public: + DUCKDB_API FileHandle(FileSystem &file_system, string path); + FileHandle(const FileHandle &) = delete; + DUCKDB_API virtual ~FileHandle(); + + DUCKDB_API int64_t Read(void *buffer, idx_t nr_bytes); + DUCKDB_API int64_t Write(void *buffer, idx_t nr_bytes); + DUCKDB_API void Read(void *buffer, idx_t nr_bytes, idx_t location); + DUCKDB_API void Write(void *buffer, idx_t nr_bytes, idx_t location); + DUCKDB_API void Seek(idx_t location); + DUCKDB_API void Reset(); + DUCKDB_API idx_t SeekPosition(); + DUCKDB_API void Sync(); + DUCKDB_API void Truncate(int64_t new_size); + DUCKDB_API string ReadLine(); + + DUCKDB_API bool CanSeek(); + DUCKDB_API bool OnDiskFile(); + DUCKDB_API idx_t GetFileSize(); + DUCKDB_API FileType GetType(); + + //! Closes the file handle. + DUCKDB_API virtual void Close() = 0; + + string GetPath() const { + return path; + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + +public: + FileSystem &file_system; + string path; +}; + +enum class FileLockType : uint8_t { NO_LOCK = 0, READ_LOCK = 1, WRITE_LOCK = 2 }; + +class FileFlags { +public: + //! Open file with read access + static constexpr uint8_t FILE_FLAGS_READ = 1 << 0; + //! Open file with write access + static constexpr uint8_t FILE_FLAGS_WRITE = 1 << 1; + //! Use direct IO when reading/writing to the file + static constexpr uint8_t FILE_FLAGS_DIRECT_IO = 1 << 2; + //! Create file if not exists, can only be used together with WRITE + static constexpr uint8_t FILE_FLAGS_FILE_CREATE = 1 << 3; + //! Always create a new file. If a file exists, the file is truncated. Cannot be used together with CREATE. + static constexpr uint8_t FILE_FLAGS_FILE_CREATE_NEW = 1 << 4; + //! Open file in append mode + static constexpr uint8_t FILE_FLAGS_APPEND = 1 << 5; +}; + +class FileSystem { +public: + DUCKDB_API virtual ~FileSystem(); + +public: + DUCKDB_API static constexpr FileLockType DEFAULT_LOCK = FileLockType::NO_LOCK; + DUCKDB_API static constexpr FileCompressionType DEFAULT_COMPRESSION = FileCompressionType::UNCOMPRESSED; + DUCKDB_API static FileSystem &GetFileSystem(ClientContext &context); + DUCKDB_API static FileSystem &GetFileSystem(DatabaseInstance &db); + DUCKDB_API static FileSystem &Get(AttachedDatabase &db); + + DUCKDB_API virtual unique_ptr OpenFile(const string &path, uint8_t flags, + FileLockType lock = DEFAULT_LOCK, + FileCompressionType compression = DEFAULT_COMPRESSION, + FileOpener *opener = nullptr); + + //! Read exactly nr_bytes from the specified location in the file. Fails if nr_bytes could not be read. This is + //! equivalent to calling SetFilePointer(location) followed by calling Read(). + DUCKDB_API virtual void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location); + //! Write exactly nr_bytes to the specified location in the file. Fails if nr_bytes could not be written. This is + //! equivalent to calling SetFilePointer(location) followed by calling Write(). + DUCKDB_API virtual void Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location); + //! Read nr_bytes from the specified file into the buffer, moving the file pointer forward by nr_bytes. Returns the + //! amount of bytes read. + DUCKDB_API virtual int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes); + //! Write nr_bytes from the buffer into the file, moving the file pointer forward by nr_bytes. + DUCKDB_API virtual int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes); + + //! Returns the file size of a file handle, returns -1 on error + DUCKDB_API virtual int64_t GetFileSize(FileHandle &handle); + //! Returns the file last modified time of a file handle, returns timespec with zero on all attributes on error + DUCKDB_API virtual time_t GetLastModifiedTime(FileHandle &handle); + //! Returns the file last modified time of a file handle, returns timespec with zero on all attributes on error + DUCKDB_API virtual FileType GetFileType(FileHandle &handle); + //! Truncate a file to a maximum size of new_size, new_size should be smaller than or equal to the current size of + //! the file + DUCKDB_API virtual void Truncate(FileHandle &handle, int64_t new_size); + + //! Check if a directory exists + DUCKDB_API virtual bool DirectoryExists(const string &directory); + //! Create a directory if it does not exist + DUCKDB_API virtual void CreateDirectory(const string &directory); + //! Recursively remove a directory and all files in it + DUCKDB_API virtual void RemoveDirectory(const string &directory); + //! List files in a directory, invoking the callback method for each one with (filename, is_dir) + DUCKDB_API virtual bool ListFiles(const string &directory, + const std::function &callback, + FileOpener *opener = nullptr); + + //! Move a file from source path to the target, StorageManager relies on this being an atomic action for ACID + //! properties + DUCKDB_API virtual void MoveFile(const string &source, const string &target); + //! Check if a file exists + DUCKDB_API virtual bool FileExists(const string &filename); + //! Check if path is pipe + DUCKDB_API virtual bool IsPipe(const string &filename); + //! Remove a file from disk + DUCKDB_API virtual void RemoveFile(const string &filename); + //! Sync a file handle to disk + DUCKDB_API virtual void FileSync(FileHandle &handle); + //! Sets the working directory + DUCKDB_API static void SetWorkingDirectory(const string &path); + //! Gets the working directory + DUCKDB_API static string GetWorkingDirectory(); + //! Gets the users home directory + DUCKDB_API static string GetHomeDirectory(optional_ptr opener); + //! Gets the users home directory + DUCKDB_API virtual string GetHomeDirectory(); + //! Expands a given path, including e.g. expanding the home directory of the user + DUCKDB_API static string ExpandPath(const string &path, optional_ptr opener); + //! Expands a given path, including e.g. expanding the home directory of the user + DUCKDB_API virtual string ExpandPath(const string &path); + //! Returns the system-available memory in bytes. Returns DConstants::INVALID_INDEX if the system function fails. + DUCKDB_API static idx_t GetAvailableMemory(); + //! Path separator for path + DUCKDB_API virtual string PathSeparator(const string &path); + //! Checks if path is starts with separator (i.e., '/' on UNIX '\\' on Windows) + DUCKDB_API bool IsPathAbsolute(const string &path); + //! Normalize an absolute path - the goal of normalizing is converting "\test.db" and "C:/test.db" into "C:\test.db" + //! so that the database system cache can correctly + DUCKDB_API string NormalizeAbsolutePath(const string &path); + //! Join two paths together + DUCKDB_API string JoinPath(const string &a, const string &path); + //! Convert separators in a path to the local separators (e.g. convert "/" into \\ on windows) + DUCKDB_API string ConvertSeparators(const string &path); + //! Extract the base name of a file (e.g. if the input is lib/example.dll the base name is 'example') + DUCKDB_API string ExtractBaseName(const string &path); + //! Extract the name of a file (e.g if the input is lib/example.dll the name is 'example.dll') + DUCKDB_API string ExtractName(const string &path); + + //! Returns the value of an environment variable - or the empty string if it is not set + DUCKDB_API static string GetEnvVariable(const string &name); + + //! Whether there is a glob in the string + DUCKDB_API static bool HasGlob(const string &str); + //! Runs a glob on the file system, returning a list of matching files + DUCKDB_API virtual vector Glob(const string &path, FileOpener *opener = nullptr); + DUCKDB_API vector GlobFiles(const string &path, ClientContext &context, + FileGlobOptions options = FileGlobOptions::DISALLOW_EMPTY); + + //! registers a sub-file system to handle certain file name prefixes, e.g. http:// etc. + DUCKDB_API virtual void RegisterSubSystem(unique_ptr sub_fs); + DUCKDB_API virtual void RegisterSubSystem(FileCompressionType compression_type, unique_ptr fs); + + //! Unregister a sub-filesystem by name + DUCKDB_API virtual void UnregisterSubSystem(const string &name); + + //! List registered sub-filesystems, including builtin ones + DUCKDB_API virtual vector ListSubSystems(); + + //! Whether or not a sub-system can handle a specific file path + DUCKDB_API virtual bool CanHandleFile(const string &fpath); + + //! Set the file pointer of a file handle to a specified location. Reads and writes will happen from this location + DUCKDB_API virtual void Seek(FileHandle &handle, idx_t location); + //! Reset a file to the beginning (equivalent to Seek(handle, 0) for simple files) + DUCKDB_API virtual void Reset(FileHandle &handle); + DUCKDB_API virtual idx_t SeekPosition(FileHandle &handle); + + //! Whether or not we can seek into the file + DUCKDB_API virtual bool CanSeek(); + //! Whether or not the FS handles plain files on disk. This is relevant for certain optimizations, as random reads + //! in a file on-disk are much cheaper than e.g. random reads in a file over the network + DUCKDB_API virtual bool OnDiskFile(FileHandle &handle); + + DUCKDB_API virtual unique_ptr OpenCompressedFile(unique_ptr handle, bool write); + + //! Create a LocalFileSystem. + DUCKDB_API static unique_ptr CreateLocal(); + + //! Return the name of the filesytem. Used for forming diagnosis messages. + DUCKDB_API virtual std::string GetName() const = 0; + + //! Whether or not a file is remote or local, based only on file path + DUCKDB_API static bool IsRemoteFile(const string &path); + + DUCKDB_API virtual void SetDisabledFileSystems(const vector &names); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/filename_pattern.hpp b/src/duckdb/src/include/duckdb/common/filename_pattern.hpp new file mode 100644 index 00000000..3795fc36 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/filename_pattern.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/filename_pattern.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/types/uuid.hpp" + +namespace duckdb { + +class FilenamePattern { + +public: + FilenamePattern() : _base("data_"), _pos(_base.length()), _uuid(false) { + } + ~FilenamePattern() { + } + +public: + void SetFilenamePattern(const string &pattern); + string CreateFilename(FileSystem &fs, const string &path, const string &extension, idx_t offset) const; + +private: + string _base; + idx_t _pos; + bool _uuid; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/fixed_size_map.hpp b/src/duckdb/src/include/duckdb/common/fixed_size_map.hpp new file mode 100644 index 00000000..02950e09 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/fixed_size_map.hpp @@ -0,0 +1,208 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/fixed_size_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/validity_mask.hpp" + +namespace duckdb { + +template +struct fixed_size_map_iterator_t; + +template +struct const_fixed_size_map_iterator_t; + +template +class fixed_size_map_t { + friend struct fixed_size_map_iterator_t; + friend struct const_fixed_size_map_iterator_t; + +public: + using key_type = idx_t; + using mapped_type = T; + +public: + explicit fixed_size_map_t(idx_t capacity_p = 0) : capacity(capacity_p) { + resize(capacity); + } + + idx_t size() const { + return count; + } + + void resize(idx_t capacity_p) { + capacity = capacity_p; + occupied = ValidityMask(capacity); + values = make_unsafe_uniq_array(capacity + 1); + clear(); + } + + void clear() { + count = 0; + occupied.SetAllInvalid(capacity); + } + + T &operator[](const idx_t &key) { + D_ASSERT(key < capacity); + count += 1 - occupied.RowIsValid(key); + occupied.SetValidUnsafe(key); + return values[key]; + } + + const T &operator[](const idx_t &key) const { + D_ASSERT(key < capacity); + return values[key]; + } + + fixed_size_map_iterator_t begin() { + return fixed_size_map_iterator_t(begin_internal(), *this); + } + + const_fixed_size_map_iterator_t begin() const { + return const_fixed_size_map_iterator_t(begin_internal(), *this); + } + + fixed_size_map_iterator_t end() { + return fixed_size_map_iterator_t(capacity, *this); + } + + const_fixed_size_map_iterator_t end() const { + return const_fixed_size_map_iterator_t(capacity, *this); + } + + fixed_size_map_iterator_t find(const idx_t &index) { + if (occupied.RowIsValid(index)) { + return fixed_size_map_iterator_t(index, *this); + } else { + return end(); + } + } + + const_fixed_size_map_iterator_t find(const idx_t &index) const { + if (occupied.RowIsValid(index)) { + return const_fixed_size_map_iterator_t(index, *this); + } else { + return end(); + } + } + +private: + idx_t begin_internal() const { + idx_t index; + for (index = 0; index < capacity; index++) { + if (occupied.RowIsValid(index)) { + break; + } + } + return index; + } + +private: + idx_t capacity; + idx_t count; + + ValidityMask occupied; + unsafe_unique_array values; +}; + +template +struct fixed_size_map_iterator_t { +public: + fixed_size_map_iterator_t(idx_t index_p, fixed_size_map_t &map_p) : map(map_p), current(index_p) { + } + + fixed_size_map_iterator_t &operator++() { + for (current++; current < map.capacity; current++) { + if (map.occupied.RowIsValidUnsafe(current)) { + break; + } + } + return *this; + } + + fixed_size_map_iterator_t operator++(int) { + fixed_size_map_iterator_t tmp = *this; + ++(*this); + return tmp; + } + + idx_t &GetKey() { + return current; + } + + const idx_t &GetKey() const { + return current; + } + + T &GetValue() { + return map.values[current]; + } + + const T &GetValue() const { + return map.values[current]; + } + + friend bool operator==(const fixed_size_map_iterator_t &a, const fixed_size_map_iterator_t &b) { + return a.current == b.current; + } + + friend bool operator!=(const fixed_size_map_iterator_t &a, const fixed_size_map_iterator_t &b) { + return !(a == b); + } + +private: + fixed_size_map_t ↦ + idx_t current; +}; + +template +struct const_fixed_size_map_iterator_t { +public: + const_fixed_size_map_iterator_t(idx_t index_p, const fixed_size_map_t &map_p) : map(map_p), current(index_p) { + } + + const_fixed_size_map_iterator_t &operator++() { + for (current++; current < map.capacity; current++) { + if (map.occupied.RowIsValidUnsafe(current)) { + break; + } + } + return *this; + } + + const_fixed_size_map_iterator_t operator++(int) { + const_fixed_size_map_iterator_t tmp = *this; + ++(*this); + return tmp; + } + + const idx_t &GetKey() const { + return current; + } + + const T &GetValue() const { + return map.values[current]; + } + + friend bool operator==(const const_fixed_size_map_iterator_t &a, const const_fixed_size_map_iterator_t &b) { + return a.current == b.current; + } + + friend bool operator!=(const const_fixed_size_map_iterator_t &a, const const_fixed_size_map_iterator_t &b) { + return !(a == b); + } + +private: + const fixed_size_map_t ↦ + idx_t current; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/fsst.hpp b/src/duckdb/src/include/duckdb/common/fsst.hpp new file mode 100644 index 00000000..f0cf6faf --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/fsst.hpp @@ -0,0 +1,19 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/fsst.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +class FSSTPrimitives { +public: + static string_t DecompressValue(void *duckdb_fsst_decoder, Vector &result, const char *compressed_string, + idx_t compressed_string_len); + static Value DecompressValue(void *duckdb_fsst_decoder, const char *compressed_string, idx_t compressed_string_len); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/fstream.hpp b/src/duckdb/src/include/duckdb/common/fstream.hpp new file mode 100644 index 00000000..ff866092 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/fstream.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/fstream.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace duckdb { +using std::endl; +using std::fstream; +using std::ifstream; +using std::ios; +using std::ios_base; +using std::ofstream; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/gzip_file_system.hpp b/src/duckdb/src/include/duckdb/common/gzip_file_system.hpp new file mode 100644 index 00000000..9d6f6a21 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/gzip_file_system.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/gzip_file_system.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/compressed_file_system.hpp" + +namespace duckdb { + +class GZipFileSystem : public CompressedFileSystem { + // 32 KB + static constexpr const idx_t BUFFER_SIZE = 1u << 15; + +public: + unique_ptr OpenCompressedFile(unique_ptr handle, bool write) override; + + std::string GetName() const override { + return "GZipFileSystem"; + } + + //! Verifies that a buffer contains a valid GZIP header + static void VerifyGZIPHeader(uint8_t gzip_hdr[], idx_t read_count); + //! Consumes a byte stream as a gzip string, returning the decompressed string + static string UncompressGZIPString(const string &in); + + unique_ptr CreateStream() override; + idx_t InBufferSize() override; + idx_t OutBufferSize() override; +}; + +static constexpr const uint8_t GZIP_COMPRESSION_DEFLATE = 0x08; + +static constexpr const uint8_t GZIP_FLAG_ASCII = 0x1; +static constexpr const uint8_t GZIP_FLAG_MULTIPART = 0x2; +static constexpr const uint8_t GZIP_FLAG_EXTRA = 0x4; +static constexpr const uint8_t GZIP_FLAG_NAME = 0x8; +static constexpr const uint8_t GZIP_FLAG_COMMENT = 0x10; +static constexpr const uint8_t GZIP_FLAG_ENCRYPT = 0x20; + +static constexpr const uint8_t GZIP_HEADER_MINSIZE = 10; +// MAXSIZE should be the same as input buffer size +static constexpr const idx_t GZIP_HEADER_MAXSIZE = 1u << 15; +static constexpr const uint8_t GZIP_FOOTER_SIZE = 8; + +static constexpr const unsigned char GZIP_FLAG_UNSUPPORTED = + GZIP_FLAG_ASCII | GZIP_FLAG_MULTIPART | GZIP_FLAG_COMMENT | GZIP_FLAG_ENCRYPT; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/helper.hpp b/src/duckdb/src/include/duckdb/common/helper.hpp new file mode 100644 index 00000000..048064fd --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/helper.hpp @@ -0,0 +1,213 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/helper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include +#include + +#ifdef _MSC_VER +#define suint64_t int64_t +#endif + +#if defined(_WIN32) || defined(_WIN64) +#define DUCKDB_WINDOWS +#elif defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) +#define DUCKDB_POSIX +#endif + +namespace duckdb { + +// explicit fallthrough for switch_statementss +#ifndef __has_cpp_attribute // For backwards compatibility +#define __has_cpp_attribute(x) 0 +#endif +#if __has_cpp_attribute(clang::fallthrough) +#define DUCKDB_EXPLICIT_FALLTHROUGH [[clang::fallthrough]] +#elif __has_cpp_attribute(gnu::fallthrough) +#define DUCKDB_EXPLICIT_FALLTHROUGH [[gnu::fallthrough]] +#else +#define DUCKDB_EXPLICIT_FALLTHROUGH +#endif + +template +struct AlwaysFalse { + static constexpr bool value = false; +}; + +template +using reference = std::reference_wrapper; + +template +struct __unique_if +{ + typedef unique_ptr<_Tp, std::default_delete<_Tp>, SAFE> __unique_single; +}; + +template +struct __unique_if<_Tp[]> +{ + typedef unique_ptr<_Tp[]> __unique_array_unknown_bound; +}; + +template +struct __unique_if<_Tp[_Np]> +{ + typedef void __unique_array_known_bound; +}; + +template +inline +typename __unique_if<_Tp, true>::__unique_single +make_uniq(_Args&&... __args) +{ + return unique_ptr<_Tp, std::default_delete<_Tp>, true>(new _Tp(std::forward<_Args>(__args)...)); +} + +template +inline +typename __unique_if<_Tp, false>::__unique_single +make_unsafe_uniq(_Args&&... __args) +{ + return unique_ptr<_Tp, std::default_delete<_Tp>, false>(new _Tp(std::forward<_Args>(__args)...)); +} + +template +inline unique_ptr<_Tp[], std::default_delete<_Tp>, true> +make_uniq_array(size_t __n) +{ + return unique_ptr<_Tp[], std::default_delete<_Tp>, true>(new _Tp[__n]()); +} + +template +inline unique_ptr<_Tp[], std::default_delete<_Tp>, false> +make_unsafe_uniq_array(size_t __n) +{ + return unique_ptr<_Tp[], std::default_delete<_Tp>, false>(new _Tp[__n]()); +} + +template + typename __unique_if<_Tp>::__unique_array_known_bound + make_uniq(_Args&&...) = delete; + + +template +unique_ptr make_uniq_base(Args &&... args) { + return unique_ptr(new T(std::forward(args)...)); +} + +#ifdef DUCKDB_ENABLE_DEPRECATED_API +template +unique_ptr make_unique_base(Args &&... args) { + return unique_ptr(new T(std::forward(args)...)); +} +#endif // DUCKDB_ENABLE_DEPRECATED_API + +template +unique_ptr unique_ptr_cast(unique_ptr src) { + return unique_ptr(static_cast(src.release())); +} + +struct SharedConstructor { + template + static shared_ptr Create(ARGS &&...args) { + return make_shared(std::forward(args)...); + } +}; + +struct UniqueConstructor { + template + static unique_ptr Create(ARGS &&...args) { + return make_uniq(std::forward(args)...); + } +}; + +#ifdef DUCKDB_DEBUG_MOVE +template +typename std::remove_reference::type&& move(T&& t) noexcept { + // the nonsensical sizeof check ensures this is never instantiated + static_assert(sizeof(T) == 0, "Use std::move instead of unqualified move or duckdb::move"); +} +#endif + +template +static duckdb::unique_ptr make_unique(_Args&&... __args) { +#ifndef DUCKDB_ENABLE_DEPRECATED_API + static_assert(sizeof(T) == 0, "Use make_uniq instead of make_unique!"); +#endif // DUCKDB_ENABLE_DEPRECATED_API + return unique_ptr(new T(std::forward<_Args>(__args)...)); +} + +template +T MaxValue(T a, T b) { + return a > b ? a : b; +} + +template +T MinValue(T a, T b) { + return a < b ? a : b; +} + +template +T AbsValue(T a) { + return a < 0 ? -a : a; +} + +//Align value (ceiling) +template +static inline T AlignValue(T n) { + return ((n + (val - 1)) / val) * val; +} + +template +static inline bool ValueIsAligned(T n) { + return (n % val) == 0; +} + +template +T SignValue(T a) { + return a < 0 ? -1 : 1; +} + +template +const T Load(const_data_ptr_t ptr) { + T ret; + memcpy(&ret, ptr, sizeof(ret)); + return ret; +} + +template +void Store(const T &val, data_ptr_t ptr) { + memcpy(ptr, (void *)&val, sizeof(val)); +} + +//! This assigns a shared pointer, but ONLY assigns if "target" is not equal to "source" +//! If this is often the case, this manner of assignment is significantly faster (~20X faster) +//! Since it avoids the need of an atomic incref/decref at the cost of a single pointer comparison +//! Benchmark: https://gist.github.com/Mytherin/4db3faa8e233c4a9b874b21f62bb4b96 +//! If the shared pointers are not the same, the penalty is very low (on the order of 1%~ slower) +//! This method should always be preferred if there is a (reasonable) chance that the pointers are the same +template +void AssignSharedPointer(shared_ptr &target, const shared_ptr &source) { + if (target.get() != source.get()) { + target = source; + } +} + +template +using const_reference = std::reference_wrapper; + +//! Returns whether or not two reference wrappers refer to the same object +template +bool RefersToSameObject(const reference &A, const reference &B) { + return &A.get() == &B.get(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/hive_partitioning.hpp b/src/duckdb/src/include/duckdb/common/hive_partitioning.hpp new file mode 100644 index 00000000..1ca824fa --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/hive_partitioning.hpp @@ -0,0 +1,124 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/hive_partitioning.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/partitioned_column_data.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "re2/re2.h" + +#include +#include + +namespace duckdb { + +class HivePartitioning { +public: + //! Parse a filename that follows the hive partitioning scheme + DUCKDB_API static std::map Parse(const string &filename); + DUCKDB_API static std::map Parse(const string &filename, duckdb_re2::RE2 ®ex); + //! Prunes a list of filenames based on a set of filters, can be used by TableFunctions in the + //! pushdown_complex_filter function to skip files with filename-based filters. Also removes the filters that always + //! evaluate to true. + DUCKDB_API static void ApplyFiltersToFileList(ClientContext &context, vector &files, + vector> &filters, + unordered_map &column_map, LogicalGet &get, + bool hive_enabled, bool filename_enabled); + + //! Returns the compiled regex pattern to match hive partitions + DUCKDB_API static const string REGEX_STRING; +}; + +struct HivePartitionKey { + //! Columns by which we want to partition + vector values; + //! Precomputed hash of values + hash_t hash; + + struct Hash { + std::size_t operator()(const HivePartitionKey &k) const { + return k.hash; + } + }; + + struct Equality { + bool operator()(const HivePartitionKey &a, const HivePartitionKey &b) const { + if (a.values.size() != b.values.size()) { + return false; + } + for (idx_t i = 0; i < a.values.size(); i++) { + if (!Value::NotDistinctFrom(a.values[i], b.values[i])) { + return false; + } + } + return true; + } + }; +}; + +//! Maps hive partitions to partition_ids +typedef unordered_map hive_partition_map_t; + +//! class shared between HivePartitionColumnData classes that synchronizes partition discovery between threads. +//! each HivePartitionedColumnData will hold a local copy of the key->partition map +class GlobalHivePartitionState { +public: + mutex lock; + hive_partition_map_t partition_map; + //! Used for incremental updating local copies of the partition map; + vector partitions; +}; + +class HivePartitionedColumnData : public PartitionedColumnData { +public: + HivePartitionedColumnData(ClientContext &context, vector types, vector partition_by_cols, + shared_ptr global_state = nullptr) + : PartitionedColumnData(PartitionedColumnDataType::HIVE, context, std::move(types)), + global_state(std::move(global_state)), group_by_columns(std::move(partition_by_cols)), + hashes_v(LogicalType::HASH) { + InitializeKeys(); + } + HivePartitionedColumnData(const HivePartitionedColumnData &other); + void ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) override; + + //! Reverse lookup map to reconstruct keys from a partition id + std::map GetReverseMap(); + +protected: + //! Create allocators for all currently registered partitions + void GrowAllocators(); + //! Create append states for all currently registered partitions + void GrowAppendState(PartitionedColumnDataAppendState &state); + //! Create and initialize partitions for all currently registered partitions + void GrowPartitions(PartitionedColumnDataAppendState &state); + //! Register a newly discovered partition + idx_t RegisterNewPartition(HivePartitionKey key, PartitionedColumnDataAppendState &state); + //! Copy the newly added entries in the global_state.map to the local_partition_map (requires lock!) + void SynchronizeLocalMap(); + +private: + void InitializeKeys(); + +protected: + //! Shared HivePartitionedColumnData should always have a global state to allow parallel key discovery + shared_ptr global_state; + //! Thread-local copy of the partition map + hive_partition_map_t local_partition_map; + //! The columns that make up the key + vector group_by_columns; + //! Thread-local pre-allocated vector for hashes + Vector hashes_v; + //! Thread-local pre-allocated HivePartitionKeys + vector keys; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/http_state.hpp b/src/duckdb/src/include/duckdb/common/http_state.hpp new file mode 100644 index 00000000..ca60606b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/http_state.hpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/http_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_opener.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { + +class CachedFileHandle; + +//! Represents a file that is intended to be fully downloaded, then used in parallel by multiple threads +class CachedFile : public std::enable_shared_from_this { + friend class CachedFileHandle; + +public: + unique_ptr GetHandle() { + auto this_ptr = shared_from_this(); + return make_uniq(this_ptr); + } + +private: + //! Cached Data + shared_ptr data; + //! Data capacity + uint64_t capacity = 0; + //! Lock for initializing the file + mutex lock; + //! When initialized is set to true, the file is safe for parallel reading without holding the lock + atomic initialized = {false}; +}; + +//! Handle to a CachedFile +class CachedFileHandle { +public: + explicit CachedFileHandle(shared_ptr &file_p); + + //! allocate a buffer for the file + void AllocateBuffer(idx_t size); + //! Indicate the file is fully downloaded and safe for parallel reading without lock + void SetInitialized(); + //! Grow buffer to new size, copying over `bytes_to_copy` to the new buffer + void GrowBuffer(idx_t new_capacity, idx_t bytes_to_copy); + //! Write to the buffer + void Write(const char *buffer, idx_t length, idx_t offset = 0); + + bool Initialized() { + return file->initialized; + } + const char *GetData() { + return file->data.get(); + } + uint64_t GetCapacity() { + return file->capacity; + } + +private: + unique_ptr> lock; + shared_ptr file; +}; + +class HTTPState { +public: + //! Reset all counters and cached files + void Reset(); + //! Get cache entry, create if not exists + shared_ptr &GetCachedFile(const string &path); + //! Helper function to get the HTTP state + static shared_ptr TryGetState(FileOpener *opener); + + bool IsEmpty() { + return head_count == 0 && get_count == 0 && put_count == 0 && post_count == 0 && total_bytes_received == 0 && + total_bytes_sent == 0; + } + + atomic head_count {0}; + atomic get_count {0}; + atomic put_count {0}; + atomic post_count {0}; + atomic total_bytes_received {0}; + atomic total_bytes_sent {0}; + +private: + //! Mutex to lock when getting the cached file(Parallel Only) + mutex cached_files_mutex; + //! In case of fully downloading the file, the cached files of this query + unordered_map> cached_files; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/hugeint.hpp b/src/duckdb/src/include/duckdb/common/hugeint.hpp new file mode 100644 index 00000000..e1bb454b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/hugeint.hpp @@ -0,0 +1,78 @@ +#pragma once + +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/string.hpp" +#include +#include "duckdb/common/typedefs.hpp" + +namespace duckdb { + +struct hugeint_t { +public: + uint64_t lower; + int64_t upper; + +public: + hugeint_t() = default; + DUCKDB_API hugeint_t(int64_t value); // NOLINT: Allow implicit conversion from `int64_t` + constexpr hugeint_t(int64_t upper, uint64_t lower) : lower(lower), upper(upper) { + } + constexpr hugeint_t(const hugeint_t &rhs) = default; + constexpr hugeint_t(hugeint_t &&rhs) = default; + hugeint_t &operator=(const hugeint_t &rhs) = default; + hugeint_t &operator=(hugeint_t &&rhs) = default; + + DUCKDB_API string ToString() const; + + // comparison operators + DUCKDB_API bool operator==(const hugeint_t &rhs) const; + DUCKDB_API bool operator!=(const hugeint_t &rhs) const; + DUCKDB_API bool operator<=(const hugeint_t &rhs) const; + DUCKDB_API bool operator<(const hugeint_t &rhs) const; + DUCKDB_API bool operator>(const hugeint_t &rhs) const; + DUCKDB_API bool operator>=(const hugeint_t &rhs) const; + + // arithmetic operators + DUCKDB_API hugeint_t operator+(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator-(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator*(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator/(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator%(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator-() const; + + // bitwise operators + DUCKDB_API hugeint_t operator>>(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator<<(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator&(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator|(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator^(const hugeint_t &rhs) const; + DUCKDB_API hugeint_t operator~() const; + + // in-place operators + DUCKDB_API hugeint_t &operator+=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator-=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator*=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator/=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator%=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator>>=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator<<=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator&=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator|=(const hugeint_t &rhs); + DUCKDB_API hugeint_t &operator^=(const hugeint_t &rhs); + + // boolean operators + DUCKDB_API explicit operator bool() const; + DUCKDB_API bool operator!() const; + + // cast operators + DUCKDB_API explicit operator uint8_t() const; + DUCKDB_API explicit operator uint16_t() const; + DUCKDB_API explicit operator uint32_t() const; + DUCKDB_API explicit operator uint64_t() const; + DUCKDB_API explicit operator int8_t() const; + DUCKDB_API explicit operator int16_t() const; + DUCKDB_API explicit operator int32_t() const; + DUCKDB_API explicit operator int64_t() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/index_map.hpp b/src/duckdb/src/include/duckdb/common/index_map.hpp new file mode 100644 index 00000000..eb97fd3d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/index_map.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/index_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +struct LogicalIndexHashFunction { + uint64_t operator()(const LogicalIndex &index) const { + return std::hash()(index.index); + } +}; + +struct PhysicalIndexHashFunction { + uint64_t operator()(const PhysicalIndex &index) const { + return std::hash()(index.index); + } +}; + +template +using logical_index_map_t = unordered_map; + +using logical_index_set_t = unordered_set; + +template +using physical_index_map_t = unordered_map; + +using physical_index_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/index_vector.hpp b/src/duckdb/src/include/duckdb/common/index_vector.hpp new file mode 100644 index 00000000..3f35b676 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/index_vector.hpp @@ -0,0 +1,84 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/index_vector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +template +class IndexVector { +public: + void push_back(T element) { + internal_vector.push_back(std::move(element)); + } + + T &operator[](INDEX_TYPE idx) { + return internal_vector[idx.index]; + } + + const T &operator[](INDEX_TYPE idx) const { + return internal_vector[idx.index]; + } + + idx_t size() const { + return internal_vector.size(); + } + + bool empty() const { + return internal_vector.empty(); + } + + void reserve(idx_t size) { + internal_vector.reserve(size); + } + + typename vector::iterator begin() { + return internal_vector.begin(); + } + typename vector::iterator end() { + return internal_vector.end(); + } + typename vector::const_iterator cbegin() { + return internal_vector.cbegin(); + } + typename vector::const_iterator cend() { + return internal_vector.cend(); + } + typename vector::const_iterator begin() const { + return internal_vector.begin(); + } + typename vector::const_iterator end() const { + return internal_vector.end(); + } + + void Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "internal_vector", internal_vector); + } + + static IndexVector Deserialize(Deserializer &deserializer) { + IndexVector result; + deserializer.ReadProperty(100, "internal_vector", result.internal_vector); + return result; + } + +private: + vector internal_vector; +}; + +template +using physical_index_vector_t = IndexVector; + +template +using logical_index_vector_t = IndexVector; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/likely.hpp b/src/duckdb/src/include/duckdb/common/likely.hpp new file mode 100644 index 00000000..aea5ccb6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/likely.hpp @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/likely.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#if __GNUC__ +#define DUCKDB_BUILTIN_EXPECT(cond, expected_value) (__builtin_expect(cond, expected_value)) +#else +#define DUCKDB_BUILTIN_EXPECT(cond, expected_value) (cond) +#endif + +#define DUCKDB_LIKELY(...) DUCKDB_BUILTIN_EXPECT((__VA_ARGS__), 1) +#define DUCKDB_UNLIKELY(...) DUCKDB_BUILTIN_EXPECT((__VA_ARGS__), 0) diff --git a/src/duckdb/src/include/duckdb/common/limits.hpp b/src/duckdb/src/include/duckdb/common/limits.hpp new file mode 100644 index 00000000..bdf44d0c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/limits.hpp @@ -0,0 +1,104 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/limits.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/types.hpp" + +#include + +// Undef annoying windows macro +#undef max + +#include + +namespace duckdb { + +template +struct NumericLimits { + static constexpr T Minimum() { + return std::numeric_limits::lowest(); + } + static constexpr T Maximum() { + return std::numeric_limits::max(); + } + DUCKDB_API static constexpr bool IsSigned() { + return std::is_signed::value; + } + DUCKDB_API static constexpr idx_t Digits(); +}; + +template <> +struct NumericLimits { + static constexpr hugeint_t Minimum() { + return {std::numeric_limits::lowest(), 1}; + }; + static constexpr hugeint_t Maximum() { + return {std::numeric_limits::max(), std::numeric_limits::max()}; + }; + static constexpr bool IsSigned() { + return true; + } + + static constexpr idx_t Digits() { + return 39; + } +}; + +template <> +constexpr idx_t NumericLimits::Digits() { + return 3; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 5; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 10; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 19; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 3; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 5; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 10; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 20; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 127; +} + +template <> +constexpr idx_t NumericLimits::Digits() { + return 250; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/list.hpp b/src/duckdb/src/include/duckdb/common/list.hpp new file mode 100644 index 00000000..21b20895 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/list.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/list.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::list; +} diff --git a/src/duckdb/src/include/duckdb/common/local_file_system.hpp b/src/duckdb/src/include/duckdb/common/local_file_system.hpp new file mode 100644 index 00000000..4babd04b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/local_file_system.hpp @@ -0,0 +1,100 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/local_file_system.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_system.hpp" + +namespace duckdb { + +class LocalFileSystem : public FileSystem { +public: + unique_ptr OpenFile(const string &path, uint8_t flags, FileLockType lock = FileLockType::NO_LOCK, + FileCompressionType compression = FileCompressionType::UNCOMPRESSED, + FileOpener *opener = nullptr) override; + + //! Read exactly nr_bytes from the specified location in the file. Fails if nr_bytes could not be read. This is + //! equivalent to calling SetFilePointer(location) followed by calling Read(). + void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + //! Write exactly nr_bytes to the specified location in the file. Fails if nr_bytes could not be written. This is + //! equivalent to calling SetFilePointer(location) followed by calling Write(). + void Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + //! Read nr_bytes from the specified file into the buffer, moving the file pointer forward by nr_bytes. Returns the + //! amount of bytes read. + int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + //! Write nr_bytes from the buffer into the file, moving the file pointer forward by nr_bytes. + int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + + //! Returns the file size of a file handle, returns -1 on error + int64_t GetFileSize(FileHandle &handle) override; + //! Returns the file last modified time of a file handle, returns timespec with zero on all attributes on error + time_t GetLastModifiedTime(FileHandle &handle) override; + //! Returns the file last modified time of a file handle, returns timespec with zero on all attributes on error + FileType GetFileType(FileHandle &handle) override; + //! Truncate a file to a maximum size of new_size, new_size should be smaller than or equal to the current size of + //! the file + void Truncate(FileHandle &handle, int64_t new_size) override; + + //! Check if a directory exists + bool DirectoryExists(const string &directory) override; + //! Create a directory if it does not exist + void CreateDirectory(const string &directory) override; + //! Recursively remove a directory and all files in it + void RemoveDirectory(const string &directory) override; + //! List files in a directory, invoking the callback method for each one with (filename, is_dir) + bool ListFiles(const string &directory, const std::function &callback, + FileOpener *opener = nullptr) override; + //! Move a file from source path to the target, StorageManager relies on this being an atomic action for ACID + //! properties + void MoveFile(const string &source, const string &target) override; + //! Check if a file exists + bool FileExists(const string &filename) override; + + //! Check if path is a pipe + bool IsPipe(const string &filename) override; + //! Remove a file from disk + void RemoveFile(const string &filename) override; + //! Sync a file handle to disk + void FileSync(FileHandle &handle) override; + + //! Runs a glob on the file system, returning a list of matching files + vector Glob(const string &path, FileOpener *opener = nullptr) override; + + bool CanHandleFile(const string &fpath) override { + //! Whether or not a sub-system can handle a specific file path + return false; + } + + //! Set the file pointer of a file handle to a specified location. Reads and writes will happen from this location + void Seek(FileHandle &handle, idx_t location) override; + //! Return the current seek posiiton in the file. + idx_t SeekPosition(FileHandle &handle) override; + + //! Whether or not we can seek into the file + bool CanSeek() override; + //! Whether or not the FS handles plain files on disk. This is relevant for certain optimizations, as random reads + //! in a file on-disk are much cheaper than e.g. random reads in a file over the network + bool OnDiskFile(FileHandle &handle) override; + + std::string GetName() const override { + return "LocalFileSystem"; + } + + //! Returns the last Win32 error, in string format. Returns an empty string if there is no error, or on non-Windows + //! systems. + static std::string GetLastErrorAsString(); + +private: + //! Set the file pointer of a file handle to a specified location. Reads and writes will happen from this location + void SetFilePointer(FileHandle &handle, idx_t location); + idx_t GetFilePointer(FileHandle &handle); + + vector FetchFileWithoutGlob(const string &path, FileOpener *opener, bool absolute_path); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/map.hpp b/src/duckdb/src/include/duckdb/common/map.hpp new file mode 100644 index 00000000..cdcfdcdf --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/map.hpp @@ -0,0 +1,16 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::map; +using std::multimap; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/memory_safety.hpp b/src/duckdb/src/include/duckdb/common/memory_safety.hpp new file mode 100644 index 00000000..7d4d5eab --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/memory_safety.hpp @@ -0,0 +1,15 @@ +#pragma once + +namespace duckdb { + +template +struct MemorySafety { +#ifdef DEBUG + // In DEBUG mode safety is always on + static constexpr bool enabled = true; +#else + static constexpr bool enabled = ENABLED; +#endif +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/multi_file_reader.hpp b/src/duckdb/src/include/duckdb/common/multi_file_reader.hpp new file mode 100644 index 00000000..c83fae5f --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/multi_file_reader.hpp @@ -0,0 +1,207 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/multi_file_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/multi_file_reader_options.hpp" +#include "duckdb/common/enums/file_glob_options.hpp" +#include "duckdb/common/union_by_name.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { +class TableFunction; +class TableFunctionSet; +class TableFilterSet; +class LogicalGet; +class Expression; +class ClientContext; +class DataChunk; + +struct HivePartitioningIndex { + HivePartitioningIndex(string value, idx_t index); + + string value; + idx_t index; + + DUCKDB_API void Serialize(Serializer &serializer) const; + DUCKDB_API static HivePartitioningIndex Deserialize(Deserializer &deserializer); +}; + +//! The bind data for the multi-file reader, obtained through MultiFileReader::BindReader +struct MultiFileReaderBindData { + //! The index of the filename column (if any) + idx_t filename_idx = DConstants::INVALID_INDEX; + //! The set of hive partitioning indexes (if any) + vector hive_partitioning_indexes; + + DUCKDB_API void Serialize(Serializer &serializer) const; + DUCKDB_API static MultiFileReaderBindData Deserialize(Deserializer &deserializer); +}; + +struct MultiFileFilterEntry { + idx_t index = DConstants::INVALID_INDEX; + bool is_constant = false; +}; + +struct MultiFileConstantEntry { + MultiFileConstantEntry(idx_t column_id, Value value_p) : column_id(column_id), value(std::move(value_p)) { + } + + //! The column id to apply the constant value to + idx_t column_id; + //! The constant value + Value value; +}; + +struct MultiFileReaderData { + //! The column ids to read from the file + vector column_ids; + //! The mapping of column id -> result column id + //! The result chunk will be filled as follows: chunk.data[column_mapping[i]] = ReadColumn(column_ids[i]); + vector column_mapping; + //! Whether or not there are no columns to read. This can happen when a file only consists of constants + bool empty_columns = false; + //! Filters can point to either (1) local columns in the file, or (2) constant values in the `constant_map` + //! This map specifies where the to-be-filtered value can be found + vector filter_map; + //! The set of table filters + optional_ptr filters; + //! The constants that should be applied at the various positions + vector constant_map; + //! Map of column_id -> cast, used when reading multiple files when files have diverging types + //! for the same column + unordered_map cast_map; +}; + +struct MultiFileReader { + //! Add the parameters for multi-file readers (e.g. union_by_name, filename) to a table function + DUCKDB_API static void AddParameters(TableFunction &table_function); + //! Performs any globbing for the multi-file reader and returns a list of files to be read + DUCKDB_API static vector GetFileList(ClientContext &context, const Value &input, const string &name, + FileGlobOptions options = FileGlobOptions::DISALLOW_EMPTY); + //! Parse the named parameters of a multi-file reader + DUCKDB_API static bool ParseOption(const string &key, const Value &val, MultiFileReaderOptions &options, + ClientContext &context); + //! Perform complex filter pushdown into the multi-file reader, potentially filtering out files that should be read + //! If "true" the first file has been eliminated + DUCKDB_API static bool ComplexFilterPushdown(ClientContext &context, vector &files, + const MultiFileReaderOptions &options, LogicalGet &get, + vector> &filters); + //! Bind the options of the multi-file reader, potentially emitting any extra columns that are required + DUCKDB_API static MultiFileReaderBindData BindOptions(MultiFileReaderOptions &options, const vector &files, + vector &return_types, vector &names); + //! Finalize the bind phase of the multi-file reader after we know (1) the required (output) columns, and (2) the + //! pushed down table filters + DUCKDB_API static void FinalizeBind(const MultiFileReaderOptions &file_options, + const MultiFileReaderBindData &options, const string &filename, + const vector &local_names, const vector &global_types, + const vector &global_names, const vector &global_column_ids, + MultiFileReaderData &reader_data, ClientContext &context); + //! Create all required mappings from the global types/names to the file-local types/names + DUCKDB_API static void CreateMapping(const string &file_name, const vector &local_types, + const vector &local_names, const vector &global_types, + const vector &global_names, const vector &global_column_ids, + optional_ptr filters, MultiFileReaderData &reader_data, + const string &initial_file); + //! Finalize the reading of a chunk - applying any constants that are required + DUCKDB_API static void FinalizeChunk(const MultiFileReaderBindData &bind_data, + const MultiFileReaderData &reader_data, DataChunk &chunk); + //! Creates a table function set from a single reader function (including e.g. list parameters, etc) + DUCKDB_API static TableFunctionSet CreateFunctionSet(TableFunction table_function); + + template + static MultiFileReaderBindData BindUnionReader(ClientContext &context, vector &return_types, + vector &names, RESULT_CLASS &result, + OPTIONS_CLASS &options) { + D_ASSERT(options.file_options.union_by_name); + vector union_col_names; + vector union_col_types; + // obtain the set of union column names + types by unifying the types of all of the files + // note that this requires opening readers for each file and reading the metadata of each file + auto union_readers = + UnionByName::UnionCols(context, result.files, union_col_types, union_col_names, options); + + std::move(union_readers.begin(), union_readers.end(), std::back_inserter(result.union_readers)); + // perform the binding on the obtained set of names + types + auto bind_data = + MultiFileReader::BindOptions(options.file_options, result.files, union_col_types, union_col_names); + names = union_col_names; + return_types = union_col_types; + result.Initialize(result.union_readers[0]); + D_ASSERT(names.size() == return_types.size()); + return bind_data; + } + + template + static MultiFileReaderBindData BindReader(ClientContext &context, vector &return_types, + vector &names, RESULT_CLASS &result, OPTIONS_CLASS &options) { + if (options.file_options.union_by_name) { + return BindUnionReader(context, return_types, names, result, options); + } else { + shared_ptr reader; + reader = make_shared(context, result.files[0], options); + return_types = reader->return_types; + names = reader->names; + result.Initialize(std::move(reader)); + return MultiFileReader::BindOptions(options.file_options, result.files, return_types, names); + } + } + + template + static void InitializeReader(READER_CLASS &reader, const MultiFileReaderOptions &options, + const MultiFileReaderBindData &bind_data, const vector &global_types, + const vector &global_names, const vector &global_column_ids, + optional_ptr table_filters, const string &initial_file, + ClientContext &context) { + FinalizeBind(options, bind_data, reader.GetFileName(), reader.GetNames(), global_types, global_names, + global_column_ids, reader.reader_data, context); + CreateMapping(reader.GetFileName(), reader.GetTypes(), reader.GetNames(), global_types, global_names, + global_column_ids, table_filters, reader.reader_data, initial_file); + reader.reader_data.filters = table_filters; + } + + template + static void PruneReaders(BIND_DATA &data) { + unordered_set file_set; + for (auto &file : data.files) { + file_set.insert(file); + } + + if (data.initial_reader) { + // check if the initial reader should still be read + auto entry = file_set.find(data.initial_reader->GetFileName()); + if (entry == file_set.end()) { + data.initial_reader.reset(); + } + } + for (idx_t r = 0; r < data.union_readers.size(); r++) { + if (!data.union_readers[r]) { + data.union_readers.erase(data.union_readers.begin() + r); + r--; + continue; + } + // check if the union reader should still be read or not + auto entry = file_set.find(data.union_readers[r]->GetFileName()); + if (entry == file_set.end()) { + data.union_readers.erase(data.union_readers.begin() + r); + r--; + continue; + } + } + } + +private: + static void CreateNameMapping(const string &file_name, const vector &local_types, + const vector &local_names, const vector &global_types, + const vector &global_names, const vector &global_column_ids, + MultiFileReaderData &reader_data, const string &initial_file); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/multi_file_reader_options.hpp b/src/duckdb/src/include/duckdb/common/multi_file_reader_options.hpp new file mode 100644 index 00000000..74a737f3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/multi_file_reader_options.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/multi_file_reader_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/hive_partitioning.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { +struct BindInfo; + +struct MultiFileReaderOptions { + bool filename = false; + bool hive_partitioning = false; + bool auto_detect_hive_partitioning = true; + bool union_by_name = false; + bool hive_types_autocast = true; + case_insensitive_map_t hive_types_schema; + + DUCKDB_API void Serialize(Serializer &serializer) const; + DUCKDB_API static MultiFileReaderOptions Deserialize(Deserializer &source); + DUCKDB_API void AddBatchInfo(BindInfo &bind_info) const; + DUCKDB_API void AutoDetectHivePartitioning(const vector &files, ClientContext &context); + DUCKDB_API static bool AutoDetectHivePartitioningInternal(const vector &files, ClientContext &context); + DUCKDB_API void AutoDetectHiveTypesInternal(const string &file, ClientContext &context); + DUCKDB_API void VerifyHiveTypesArePartitions(const std::map &partitions) const; + DUCKDB_API LogicalType GetHiveLogicalType(const string &hive_partition_column) const; + DUCKDB_API Value GetHivePartitionValue(const string &base, const string &entry, ClientContext &context) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/mutex.hpp b/src/duckdb/src/include/duckdb/common/mutex.hpp new file mode 100644 index 00000000..1758c0ff --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/mutex.hpp @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/mutex.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifdef __MVS__ +#include +#endif +#include + +namespace duckdb { +using std::lock_guard; +using std::mutex; +using std::unique_lock; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/named_parameter_map.hpp b/src/duckdb/src/include/duckdb/common/named_parameter_map.hpp new file mode 100644 index 00000000..b39e81e1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/named_parameter_map.hpp @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/named_parameter_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/types.hpp" +namespace duckdb { + +using named_parameter_type_map_t = case_insensitive_map_t; +using named_parameter_map_t = case_insensitive_map_t; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/numeric_utils.hpp b/src/duckdb/src/include/duckdb/common/numeric_utils.hpp new file mode 100644 index 00000000..73ee6191 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/numeric_utils.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/numeric_utils.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include "duckdb/common/hugeint.hpp" + +namespace duckdb { + +template +struct MakeSigned { + using type = typename std::make_signed::type; +}; + +template <> +struct MakeSigned { + using type = hugeint_t; +}; + +template +struct MakeUnsigned { + using type = typename std::make_unsigned::type; +}; + +// hugeint_t does not actually have an unsigned variant (yet), but this is required to make compression work +// if an unsigned variant gets implemented this (probably) can be changed without breaking anything +template <> +struct MakeUnsigned { + using type = hugeint_t; +}; + +template +struct IsIntegral { + static constexpr bool value = std::is_integral::value; +}; + +template <> +struct IsIntegral { + static constexpr bool value = true; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/opener_file_system.hpp b/src/duckdb/src/include/duckdb/common/opener_file_system.hpp new file mode 100644 index 00000000..28b35245 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/opener_file_system.hpp @@ -0,0 +1,122 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/opener_file_system.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_system.hpp" + +namespace duckdb { + +// The OpenerFileSystem is wrapper for a file system that pushes an appropriate FileOpener into the various API calls +class OpenerFileSystem : public FileSystem { +public: + virtual FileSystem &GetFileSystem() const = 0; + virtual optional_ptr GetOpener() const = 0; + + unique_ptr OpenFile(const string &path, uint8_t flags, FileLockType lock = FileLockType::NO_LOCK, + FileCompressionType compression = FileCompressionType::UNCOMPRESSED, + FileOpener *opener = nullptr) override { + if (opener) { + throw InternalException("OpenerFileSystem cannot take an opener - the opener is pushed automatically"); + } + return GetFileSystem().OpenFile(path, flags, lock, compression, GetOpener().get()); + } + + void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override { + GetFileSystem().Read(handle, buffer, nr_bytes, location); + }; + + void Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override { + GetFileSystem().Write(handle, buffer, nr_bytes, location); + } + + int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override { + return GetFileSystem().Read(handle, buffer, nr_bytes); + } + + int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override { + return GetFileSystem().Write(handle, buffer, nr_bytes); + } + + int64_t GetFileSize(FileHandle &handle) override { + return GetFileSystem().GetFileSize(handle); + } + time_t GetLastModifiedTime(FileHandle &handle) override { + return GetFileSystem().GetLastModifiedTime(handle); + } + FileType GetFileType(FileHandle &handle) override { + return GetFileSystem().GetFileType(handle); + } + + void Truncate(FileHandle &handle, int64_t new_size) override { + GetFileSystem().Truncate(handle, new_size); + } + + void FileSync(FileHandle &handle) override { + GetFileSystem().FileSync(handle); + } + + bool DirectoryExists(const string &directory) override { + return GetFileSystem().DirectoryExists(directory); + } + void CreateDirectory(const string &directory) override { + return GetFileSystem().CreateDirectory(directory); + } + + void RemoveDirectory(const string &directory) override { + return GetFileSystem().RemoveDirectory(directory); + } + + bool ListFiles(const string &directory, const std::function &callback, + FileOpener *opener = nullptr) override { + if (opener) { + throw InternalException("OpenerFileSystem cannot take an opener - the opener is pushed automatically"); + } + return GetFileSystem().ListFiles(directory, callback, GetOpener().get()); + } + + void MoveFile(const string &source, const string &target) override { + GetFileSystem().MoveFile(source, target); + } + + string GetHomeDirectory() override { + return FileSystem::GetHomeDirectory(GetOpener()); + } + + string ExpandPath(const string &path) override { + return FileSystem::ExpandPath(path, GetOpener()); + } + + bool FileExists(const string &filename) override { + return GetFileSystem().FileExists(filename); + } + + bool IsPipe(const string &filename) override { + return GetFileSystem().IsPipe(filename); + } + void RemoveFile(const string &filename) override { + GetFileSystem().RemoveFile(filename); + } + + string PathSeparator(const string &path) override { + return GetFileSystem().PathSeparator(path); + } + + vector Glob(const string &path, FileOpener *opener = nullptr) override { + if (opener) { + throw InternalException("OpenerFileSystem cannot take an opener - the opener is pushed automatically"); + } + return GetFileSystem().Glob(path, GetOpener().get()); + } + + std::string GetName() const override { + return "OpenerFileSystem - " + GetFileSystem().GetName(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/abs.hpp b/src/duckdb/src/include/duckdb/common/operator/abs.hpp new file mode 100644 index 00000000..26f48726 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/abs.hpp @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/abs.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" + +namespace duckdb { + +struct AbsOperator { + template + static inline TR Operation(TA input) { + return input < 0 ? -input : input; + } +}; + +template <> +inline hugeint_t AbsOperator::Operation(hugeint_t input) { + const hugeint_t zero(0); + return (input < zero) ? -input : input; +} + +struct TryAbsOperator { + template + static inline TR Operation(TA input) { + return AbsOperator::Operation(input); + } +}; + +template <> +inline int8_t TryAbsOperator::Operation(int8_t input) { + if (input == NumericLimits::Minimum()) { + throw OutOfRangeException("Overflow on abs(%d)", input); + } + return input < 0 ? -input : input; +} + +template <> +inline int16_t TryAbsOperator::Operation(int16_t input) { + if (input == NumericLimits::Minimum()) { + throw OutOfRangeException("Overflow on abs(%d)", input); + } + return input < 0 ? -input : input; +} + +template <> +inline int32_t TryAbsOperator::Operation(int32_t input) { + if (input == NumericLimits::Minimum()) { + throw OutOfRangeException("Overflow on abs(%d)", input); + } + return input < 0 ? -input : input; +} + +template <> +inline int64_t TryAbsOperator::Operation(int64_t input) { + if (input == NumericLimits::Minimum()) { + throw OutOfRangeException("Overflow on abs(%d)", input); + } + return input < 0 ? -input : input; +} + +template <> +inline dtime_t TryAbsOperator::Operation(dtime_t input) { + return dtime_t(TryAbsOperator::Operation(input.micros)); +} + +template <> +inline interval_t TryAbsOperator::Operation(interval_t input) { + return {TryAbsOperator::Operation(input.months), + TryAbsOperator::Operation(input.days), + TryAbsOperator::Operation(input.micros)}; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/add.hpp b/src/duckdb/src/include/duckdb/common/operator/add.hpp new file mode 100644 index 00000000..b37e5ca8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/add.hpp @@ -0,0 +1,129 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/add.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/type_util.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +struct AddOperator { + template + static inline TR Operation(TA left, TB right) { + return left + right; + } +}; + +template <> +float AddOperator::Operation(float left, float right); +template <> +double AddOperator::Operation(double left, double right); +template <> +date_t AddOperator::Operation(date_t left, int32_t right); +template <> +date_t AddOperator::Operation(int32_t left, date_t right); +template <> +timestamp_t AddOperator::Operation(date_t left, dtime_t right); +template <> +timestamp_t AddOperator::Operation(dtime_t left, date_t right); +template <> +interval_t AddOperator::Operation(interval_t left, interval_t right); +template <> +date_t AddOperator::Operation(date_t left, interval_t right); +template <> +date_t AddOperator::Operation(interval_t left, date_t right); +template <> +timestamp_t AddOperator::Operation(timestamp_t left, interval_t right); +template <> +timestamp_t AddOperator::Operation(interval_t left, timestamp_t right); + +struct TryAddOperator { + template + static inline bool Operation(TA left, TB right, TR &result) { + throw InternalException("Unimplemented type for TryAddOperator"); + } +}; + +template <> +bool TryAddOperator::Operation(uint8_t left, uint8_t right, uint8_t &result); +template <> +bool TryAddOperator::Operation(uint16_t left, uint16_t right, uint16_t &result); +template <> +bool TryAddOperator::Operation(uint32_t left, uint32_t right, uint32_t &result); +template <> +bool TryAddOperator::Operation(uint64_t left, uint64_t right, uint64_t &result); + +template <> +bool TryAddOperator::Operation(int8_t left, int8_t right, int8_t &result); +template <> +bool TryAddOperator::Operation(int16_t left, int16_t right, int16_t &result); +template <> +bool TryAddOperator::Operation(int32_t left, int32_t right, int32_t &result); +template <> +DUCKDB_API bool TryAddOperator::Operation(int64_t left, int64_t right, int64_t &result); +template <> +bool TryAddOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result); + +struct AddOperatorOverflowCheck { + template + static inline TR Operation(TA left, TB right) { + TR result; + if (!TryAddOperator::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in addition of %s (%s + %s)!", TypeIdToString(GetTypeId()), + NumericHelper::ToString(left), NumericHelper::ToString(right)); + } + return result; + } +}; + +struct TryDecimalAdd { + template + static inline bool Operation(TA left, TB right, TR &result) { + throw InternalException("Unimplemented type for TryDecimalAdd"); + } +}; + +template <> +bool TryDecimalAdd::Operation(int16_t left, int16_t right, int16_t &result); +template <> +bool TryDecimalAdd::Operation(int32_t left, int32_t right, int32_t &result); +template <> +bool TryDecimalAdd::Operation(int64_t left, int64_t right, int64_t &result); +template <> +bool TryDecimalAdd::Operation(hugeint_t left, hugeint_t right, hugeint_t &result); + +struct DecimalAddOverflowCheck { + template + static inline TR Operation(TA left, TB right) { + TR result; + if (!TryDecimalAdd::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in addition of DECIMAL(18) (%d + %d). You might want to add an " + "explicit cast to a bigger decimal.", + left, right); + } + return result; + } +}; + +template <> +hugeint_t DecimalAddOverflowCheck::Operation(hugeint_t left, hugeint_t right); + +struct AddTimeOperator { + template + static inline TR Operation(TA left, TB right); +}; + +template <> +dtime_t AddTimeOperator::Operation(dtime_t left, interval_t right); +template <> +dtime_t AddTimeOperator::Operation(interval_t left, dtime_t right); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/aggregate_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/aggregate_operators.hpp new file mode 100644 index 00000000..583d862d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/aggregate_operators.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/aggregate_operators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include "duckdb/common/operator/comparison_operators.hpp" + +namespace duckdb { + +struct Min { + template + static inline T Operation(T left, T right) { + return LessThan::Operation(left, right) ? left : right; + } +}; + +struct Max { + template + static inline T Operation(T left, T right) { + return GreaterThan::Operation(left, right) ? left : right; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp new file mode 100644 index 00000000..fdcc90cd --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp @@ -0,0 +1,815 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/cast_operators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/operator/convert_to_string.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { +struct ValidityMask; +class Vector; + +struct TryCast { + template + static inline bool Operation(SRC input, DST &result, bool strict = false) { + throw NotImplementedException("Unimplemented type for cast (%s -> %s)", GetTypeId(), GetTypeId()); + } +}; + +struct TryCastErrorMessage { + template + static inline bool Operation(SRC input, DST &result, string *error_message, bool strict = false) { + throw NotImplementedException("Unimplemented type for cast (%s -> %s)", GetTypeId(), GetTypeId()); + } +}; + +struct TryCastErrorMessageCommaSeparated { + template + static inline bool Operation(SRC input, DST &result, string *error_message, bool strict = false) { + throw NotImplementedException("Unimplemented type for cast (%s -> %s)", GetTypeId(), GetTypeId()); + } +}; + +template +static string CastExceptionText(SRC input) { + if (std::is_same()) { + return "Could not convert string '" + ConvertToString::Operation(input) + "' to " + + TypeIdToString(GetTypeId()); + } + if (TypeIsNumber() && TypeIsNumber()) { + return "Type " + TypeIdToString(GetTypeId()) + " with value " + ConvertToString::Operation(input) + + " can't be cast because the value is out of range for the destination type " + + TypeIdToString(GetTypeId()); + } + return "Type " + TypeIdToString(GetTypeId()) + " with value " + ConvertToString::Operation(input) + + " can't be cast to the destination type " + TypeIdToString(GetTypeId()); +} + +struct Cast { + template + static inline DST Operation(SRC input) { + DST result; + if (!TryCast::Operation(input, result)) { + throw InvalidInputException(CastExceptionText(input)); + } + return result; + } +}; + +struct HandleCastError { + static void AssignError(const string &error_message, string *error_message_ptr) { + if (!error_message_ptr) { + throw ConversionException(error_message); + } + if (error_message_ptr->empty()) { + *error_message_ptr = error_message; + } + } +}; + +//===--------------------------------------------------------------------===// +// Cast bool -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(bool input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(bool input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast int8_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int8_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast int16_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int16_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast int32_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int32_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast int64_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(int64_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast hugeint_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(hugeint_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast uint8_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint8_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast uint16_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint16_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast uint32_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint32_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast uint64_t -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(uint64_t input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast float -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(float input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(float input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// Cast double -> Numeric +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(double input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(double input, double &result, bool strict); + +//===--------------------------------------------------------------------===// +// String -> Numeric Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(string_t input, bool &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, int8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, int16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, int32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, int64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, uint8_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, uint16_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, uint32_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, uint64_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, hugeint_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, float &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, double &result, bool strict); +template <> +DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, float &result, string *error_message, bool strict); +template <> +DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, double &result, string *error_message, bool strict); +template <> +DUCKDB_API bool TryCastErrorMessageCommaSeparated::Operation(string_t input, float &result, string *error_message, + bool strict); +template <> +DUCKDB_API bool TryCastErrorMessageCommaSeparated::Operation(string_t input, double &result, string *error_message, + bool strict); + +//===--------------------------------------------------------------------===// +// Date Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(date_t input, date_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(date_t input, timestamp_t &result, bool strict); + +//===--------------------------------------------------------------------===// +// Time Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(dtime_t input, dtime_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(dtime_t input, dtime_tz_t &result, bool strict); + +//===--------------------------------------------------------------------===// +// Time With Time Zone Casts (Offset) +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(dtime_tz_t input, dtime_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(dtime_tz_t input, dtime_tz_t &result, bool strict); + +//===--------------------------------------------------------------------===// +// Timestamp Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(timestamp_t input, date_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(timestamp_t input, dtime_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(timestamp_t input, dtime_tz_t &result, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(timestamp_t input, timestamp_t &result, bool strict); + +//===--------------------------------------------------------------------===// +// Interval Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCast::Operation(interval_t input, interval_t &result, bool strict); + +//===--------------------------------------------------------------------===// +// String -> Date Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, date_t &result, string *error_message, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, date_t &result, bool strict); +template <> +date_t Cast::Operation(string_t input); +//===--------------------------------------------------------------------===// +// String -> Time Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, dtime_t &result, string *error_message, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, dtime_t &result, bool strict); +template <> +dtime_t Cast::Operation(string_t input); +//===--------------------------------------------------------------------===// +// String -> TimeTZ Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, dtime_tz_t &result, string *error_message, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, dtime_tz_t &result, bool strict); +template <> +dtime_tz_t Cast::Operation(string_t input); +//===--------------------------------------------------------------------===// +// String -> Timestamp Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, timestamp_t &result, string *error_message, bool strict); +template <> +DUCKDB_API bool TryCast::Operation(string_t input, timestamp_t &result, bool strict); +template <> +timestamp_t Cast::Operation(string_t input); +//===--------------------------------------------------------------------===// +// String -> Interval Casts +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, interval_t &result, string *error_message, bool strict); + +//===--------------------------------------------------------------------===// +// string -> Non-Standard Timestamps +//===--------------------------------------------------------------------===// +struct TryCastToTimestampNS { + template + static inline bool Operation(SRC input, DST &result, bool strict = false) { + throw InternalException("Unsupported type for try cast to timestamp (ns)"); + } +}; + +struct TryCastToTimestampMS { + template + static inline bool Operation(SRC input, DST &result, bool strict = false) { + throw InternalException("Unsupported type for try cast to timestamp (ms)"); + } +}; + +struct TryCastToTimestampSec { + template + static inline bool Operation(SRC input, DST &result, bool strict = false) { + throw InternalException("Unsupported type for try cast to timestamp (s)"); + } +}; + +template <> +DUCKDB_API bool TryCastToTimestampNS::Operation(string_t input, timestamp_t &result, bool strict); +template <> +DUCKDB_API bool TryCastToTimestampMS::Operation(string_t input, timestamp_t &result, bool strict); +template <> +DUCKDB_API bool TryCastToTimestampSec::Operation(string_t input, timestamp_t &result, bool strict); + +template <> +DUCKDB_API bool TryCastToTimestampNS::Operation(date_t input, timestamp_t &result, bool strict); +template <> +DUCKDB_API bool TryCastToTimestampMS::Operation(date_t input, timestamp_t &result, bool strict); +template <> +DUCKDB_API bool TryCastToTimestampSec::Operation(date_t input, timestamp_t &result, bool strict); + +//===--------------------------------------------------------------------===// +// Non-Standard Timestamps -> string/standard timestamp +//===--------------------------------------------------------------------===// + +struct CastFromTimestampNS { + template + static inline string_t Operation(SRC input, Vector &result) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastFromTimestampMS { + template + static inline string_t Operation(SRC input, Vector &result) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastFromTimestampSec { + template + static inline string_t Operation(SRC input, Vector &result) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastTimestampUsToMs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastTimestampUsToNs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastTimestampUsToSec { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastTimestampMsToUs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastTimestampMsToNs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to TIMESTAMP_NS could not be performed!"); + } +}; + +struct CastTimestampNsToUs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastTimestampSecToMs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to TIMESTAMP_MS could not be performed!"); + } +}; + +struct CastTimestampSecToUs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); + } +}; + +struct CastTimestampSecToNs { + template + static inline DST Operation(SRC input) { + throw duckdb::NotImplementedException("Cast to TIMESTAMP_NS could not be performed!"); + } +}; + +template <> +duckdb::timestamp_t CastTimestampUsToMs::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampUsToNs::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampUsToSec::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampMsToUs::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampMsToNs::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampNsToUs::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampSecToMs::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampSecToUs::Operation(duckdb::timestamp_t input); +template <> +duckdb::timestamp_t CastTimestampSecToNs::Operation(duckdb::timestamp_t input); + +template <> +duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_t input, Vector &result); +template <> +duckdb::string_t CastFromTimestampMS::Operation(duckdb::timestamp_t input, Vector &result); +template <> +duckdb::string_t CastFromTimestampSec::Operation(duckdb::timestamp_t input, Vector &result); + +//===--------------------------------------------------------------------===// +// Blobs +//===--------------------------------------------------------------------===// +struct CastFromBlob { + template + static inline string_t Operation(SRC input, Vector &result) { + throw duckdb::NotImplementedException("Cast from blob could not be performed!"); + } +}; +template <> +duckdb::string_t CastFromBlob::Operation(duckdb::string_t input, Vector &vector); + +struct CastFromBlobToBit { + template + static inline string_t Operation(SRC input, Vector &result) { + throw NotImplementedException("Cast from blob could not be performed!"); + } +}; +template <> +string_t CastFromBlobToBit::Operation(string_t input, Vector &result); + +struct TryCastToBlob { + template + static inline bool Operation(SRC input, DST &result, Vector &result_vector, string *error_message, + bool strict = false) { + throw InternalException("Unsupported type for try cast to blob"); + } +}; +template <> +bool TryCastToBlob::Operation(string_t input, string_t &result, Vector &result_vector, string *error_message, + bool strict); + +//===--------------------------------------------------------------------===// +// Bits +//===--------------------------------------------------------------------===// +struct CastFromBitToString { + template + static inline string_t Operation(SRC input, Vector &result) { + throw duckdb::NotImplementedException("Cast from bit could not be performed!"); + } +}; +template <> +duckdb::string_t CastFromBitToString::Operation(duckdb::string_t input, Vector &vector); + +struct CastFromBitToNumeric { + template + static inline bool Operation(SRC input, DST &result, bool strict = false) { + D_ASSERT(input.GetSize() > 1); + + // TODO: Allow conversion if the significant bytes of the bitstring can be cast to the target type + // Currently only allows bitstring -> numeric if the full bitstring fits inside the numeric type + if (input.GetSize() - 1 > sizeof(DST)) { + throw ConversionException("Bitstring doesn't fit inside of %s", GetTypeId()); + } + Bit::BitToNumeric(input, result); + return (true); + } +}; +template <> +bool CastFromBitToNumeric::Operation(string_t input, bool &result, bool strict); +template <> +bool CastFromBitToNumeric::Operation(string_t input, hugeint_t &result, bool strict); + +struct CastFromBitToBlob { + template + static inline string_t Operation(SRC input, Vector &result) { + D_ASSERT(input.GetSize() > 1); + return StringVector::AddStringOrBlob(result, Bit::BitToBlob(input)); + } +}; + +struct TryCastToBit { + template + static inline bool Operation(SRC input, DST &result, Vector &result_vector, string *error_message, + bool strict = false) { + throw InternalException("Unsupported type for try cast to bit"); + } +}; + +template <> +bool TryCastToBit::Operation(string_t input, string_t &result, Vector &result_vector, string *error_message, + bool strict); + +//===--------------------------------------------------------------------===// +// UUID +//===--------------------------------------------------------------------===// +struct CastFromUUID { + template + static inline string_t Operation(SRC input, Vector &result) { + throw duckdb::NotImplementedException("Cast from uuid could not be performed!"); + } +}; +template <> +duckdb::string_t CastFromUUID::Operation(duckdb::hugeint_t input, Vector &vector); + +struct TryCastToUUID { + template + static inline bool Operation(SRC input, DST &result, Vector &result_vector, string *error_message, + bool strict = false) { + throw InternalException("Unsupported type for try cast to uuid"); + } +}; + +template <> +DUCKDB_API bool TryCastToUUID::Operation(string_t input, hugeint_t &result, Vector &result_vector, + string *error_message, bool strict); + +//===--------------------------------------------------------------------===// +// Pointers +//===--------------------------------------------------------------------===// +struct CastFromPointer { + template + static inline string_t Operation(SRC input, Vector &result) { + throw duckdb::NotImplementedException("Cast from pointer could not be performed!"); + } +}; +template <> +duckdb::string_t CastFromPointer::Operation(uintptr_t input, Vector &vector); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp new file mode 100644 index 00000000..ef7203a2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp @@ -0,0 +1,228 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/comparison_operators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/string_type.hpp" + +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Comparison Operations +//===--------------------------------------------------------------------===// +struct Equals { + template + static inline bool Operation(const T &left, const T &right) { + return left == right; + } +}; +struct NotEquals { + template + static inline bool Operation(const T &left, const T &right) { + return !Equals::Operation(left, right); + } +}; + +struct GreaterThan { + template + static inline bool Operation(const T &left, const T &right) { + return left > right; + } +}; + +struct GreaterThanEquals { + template + static inline bool Operation(const T &left, const T &right) { + return !GreaterThan::Operation(right, left); + } +}; + +struct LessThan { + template + static inline bool Operation(const T &left, const T &right) { + return GreaterThan::Operation(right, left); + } +}; + +struct LessThanEquals { + template + static inline bool Operation(const T &left, const T &right) { + return !GreaterThan::Operation(left, right); + } +}; + +template <> +DUCKDB_API bool Equals::Operation(const float &left, const float &right); +template <> +DUCKDB_API bool Equals::Operation(const double &left, const double &right); + +template <> +DUCKDB_API bool GreaterThan::Operation(const float &left, const float &right); +template <> +DUCKDB_API bool GreaterThan::Operation(const double &left, const double &right); + +template <> +DUCKDB_API bool GreaterThanEquals::Operation(const float &left, const float &right); +template <> +DUCKDB_API bool GreaterThanEquals::Operation(const double &left, const double &right); + +// Distinct semantics are from Postgres record sorting. NULL = NULL and not-NULL < NULL +// Deferring to the non-distinct operations removes the need for further specialisation. +// TODO: To reverse the semantics, swap left_null and right_null for comparisons +struct DistinctFrom { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + if (left_null || right_null) { + return left_null != right_null; + } + return NotEquals::Operation(left, right); + } +}; + +struct NotDistinctFrom { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return !DistinctFrom::Operation(left, right, left_null, right_null); + } +}; + +struct DistinctGreaterThan { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + if (left_null || right_null) { + return !right_null; + } + return GreaterThan::Operation(left, right); + } +}; + +struct DistinctGreaterThanNullsFirst { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return DistinctGreaterThan::Operation(left, right, right_null, left_null); + } +}; + +struct DistinctGreaterThanEquals { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return !DistinctGreaterThan::Operation(right, left, right_null, left_null); + } +}; + +struct DistinctLessThan { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return DistinctGreaterThan::Operation(right, left, right_null, left_null); + } +}; + +struct DistinctLessThanNullsFirst { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return DistinctGreaterThan::Operation(right, left, left_null, right_null); + } +}; + +struct DistinctLessThanEquals { + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return !DistinctGreaterThan::Operation(left, right, left_null, right_null); + } +}; + +//===--------------------------------------------------------------------===// +// Comparison Operator Wrappers (so (Not)DistinctFrom have the same API) +//===--------------------------------------------------------------------===// +template +struct ComparisonOperationWrapper { + static constexpr const bool COMPARE_NULL = false; + + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + if (right_null || left_null) { + return false; + } + return OP::template Operation(left, right); + } +}; + +template <> +struct ComparisonOperationWrapper { + static constexpr const bool COMPARE_NULL = true; + + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return DistinctFrom::template Operation(left, right, left_null, right_null); + } +}; + +template <> +struct ComparisonOperationWrapper { + static constexpr const bool COMPARE_NULL = true; + + template + static inline bool Operation(const T &left, const T &right, bool left_null, bool right_null) { + return NotDistinctFrom::template Operation(left, right, left_null, right_null); + } +}; + +//===--------------------------------------------------------------------===// +// Specialized Boolean Comparison Operators +//===--------------------------------------------------------------------===// +template <> +inline bool GreaterThan::Operation(const bool &left, const bool &right) { + return !right && left; +} +//===--------------------------------------------------------------------===// +// Specialized String Comparison Operations +//===--------------------------------------------------------------------===// +template <> +inline bool Equals::Operation(const string_t &left, const string_t &right) { + return left == right; +} + +template <> +inline bool GreaterThan::Operation(const string_t &left, const string_t &right) { + return left > right; +} + +//===--------------------------------------------------------------------===// +// Specialized Interval Comparison Operators +//===--------------------------------------------------------------------===// +template <> +inline bool Equals::Operation(const interval_t &left, const interval_t &right) { + return Interval::Equals(left, right); +} +template <> +inline bool GreaterThan::Operation(const interval_t &left, const interval_t &right) { + return Interval::GreaterThan(left, right); +} + +inline bool operator<(const interval_t &lhs, const interval_t &rhs) { + return LessThan::Operation(lhs, rhs); +} + +//===--------------------------------------------------------------------===// +// Specialized Hugeint Comparison Operators +//===--------------------------------------------------------------------===// +template <> +inline bool Equals::Operation(const hugeint_t &left, const hugeint_t &right) { + return Hugeint::Equals(left, right); +} +template <> +inline bool GreaterThan::Operation(const hugeint_t &left, const hugeint_t &right) { + return Hugeint::GreaterThan(left, right); +} +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/constant_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/constant_operators.hpp new file mode 100644 index 00000000..7c975269 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/constant_operators.hpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/constant_operators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +struct PickLeft { + template + static inline T Operation(T left, T right) { + return left; + } +}; + +struct PickRight { + template + static inline T Operation(T left, T right) { + return right; + } +}; + +struct NOP { + template + static inline T Operation(T left) { + return left; + } +}; + +struct ConstantZero { + template + static inline T Operation(T left, T right) { + return 0; + } +}; + +struct ConstantOne { + template + static inline T Operation(T left, T right) { + return 1; + } +}; + +struct AddOne { + template + static inline T Operation(T left, T right) { + return right + 1; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/convert_to_string.hpp b/src/duckdb/src/include/duckdb/common/operator/convert_to_string.hpp new file mode 100644 index 00000000..c979265e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/convert_to_string.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/convert_to_string.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/type_util.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +struct ConvertToString { + template + static inline string Operation(SRC input) { + throw InternalException("Unrecognized type for ConvertToString %s", GetTypeId()); + } +}; + +template <> +DUCKDB_API string ConvertToString::Operation(bool input); +template <> +DUCKDB_API string ConvertToString::Operation(int8_t input); +template <> +DUCKDB_API string ConvertToString::Operation(int16_t input); +template <> +DUCKDB_API string ConvertToString::Operation(int32_t input); +template <> +DUCKDB_API string ConvertToString::Operation(int64_t input); +template <> +DUCKDB_API string ConvertToString::Operation(uint8_t input); +template <> +DUCKDB_API string ConvertToString::Operation(uint16_t input); +template <> +DUCKDB_API string ConvertToString::Operation(uint32_t input); +template <> +DUCKDB_API string ConvertToString::Operation(uint64_t input); +template <> +DUCKDB_API string ConvertToString::Operation(hugeint_t input); +template <> +DUCKDB_API string ConvertToString::Operation(float input); +template <> +DUCKDB_API string ConvertToString::Operation(double input); +template <> +DUCKDB_API string ConvertToString::Operation(interval_t input); +template <> +DUCKDB_API string ConvertToString::Operation(date_t input); +template <> +DUCKDB_API string ConvertToString::Operation(dtime_t input); +template <> +DUCKDB_API string ConvertToString::Operation(timestamp_t input); +template <> +DUCKDB_API string ConvertToString::Operation(string_t input); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/decimal_cast_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/decimal_cast_operators.hpp new file mode 100644 index 00000000..6610ccad --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/decimal_cast_operators.hpp @@ -0,0 +1,405 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/decimal_cast_operators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/operator/cast_operators.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Decimal Casts +//===--------------------------------------------------------------------===// +struct TryCastToDecimal { + template + static inline bool Operation(SRC input, DST &result, string *error_message, uint8_t width, uint8_t scale) { + throw NotImplementedException("Unimplemented type for TryCastToDecimal!"); + } +}; + +struct TryCastToDecimalCommaSeparated { + template + static inline bool Operation(SRC input, DST &result, string *error_message, uint8_t width, uint8_t scale) { + throw NotImplementedException("Unimplemented type for TryCastToDecimal!"); + } +}; + +struct TryCastFromDecimal { + template + static inline bool Operation(SRC input, DST &result, string *error_message, uint8_t width, uint8_t scale) { + throw NotImplementedException("Unimplemented type for TryCastFromDecimal!"); + } +}; + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> bool +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(bool input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(bool input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(bool input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(bool input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, bool &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, bool &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, bool &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, bool &result, string *error_message, uint8_t width, uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> int8_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int8_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int8_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int8_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int8_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, int8_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int8_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> int16_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int16_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int16_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int16_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int16_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, int16_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> int32_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int32_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int32_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int32_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int32_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, int32_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> int64_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int64_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int64_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int64_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(int64_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, int64_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> hugeint_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(hugeint_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(hugeint_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(hugeint_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(hugeint_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> uint8_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint8_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint8_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint8_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint8_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint8_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint8_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint8_t &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint8_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> uint16_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint16_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint16_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint16_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint16_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint16_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> uint32_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint32_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint32_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint32_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint32_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint32_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> uint64_t +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint64_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint64_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint64_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(uint64_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, uint64_t &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> float +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(float input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(float input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(float input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(float input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, float &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, float &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, float &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, float &result, string *error_message, uint8_t width, uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> double +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(double input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(double input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(double input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(double input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); + +template <> +bool TryCastFromDecimal::Operation(int16_t input, double &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int32_t input, double &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(int64_t input, double &result, string *error_message, uint8_t width, uint8_t scale); +template <> +bool TryCastFromDecimal::Operation(hugeint_t input, double &result, string *error_message, uint8_t width, + uint8_t scale); + +//===--------------------------------------------------------------------===// +// Cast Decimal <-> VARCHAR +//===--------------------------------------------------------------------===// +template <> +DUCKDB_API bool TryCastToDecimal::Operation(string_t input, int16_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(string_t input, int32_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(string_t input, int64_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimal::Operation(string_t input, hugeint_t &result, string *error_message, uint8_t width, + uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimalCommaSeparated::Operation(string_t input, int16_t &result, string *error_message, + uint8_t width, uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimalCommaSeparated::Operation(string_t input, int32_t &result, string *error_message, + uint8_t width, uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimalCommaSeparated::Operation(string_t input, int64_t &result, string *error_message, + uint8_t width, uint8_t scale); +template <> +DUCKDB_API bool TryCastToDecimalCommaSeparated::Operation(string_t input, hugeint_t &result, string *error_message, + uint8_t width, uint8_t scale); + +struct StringCastFromDecimal { + template + static inline string_t Operation(SRC input, uint8_t width, uint8_t scale, Vector &result) { + throw NotImplementedException("Unimplemented type for string cast!"); + } +}; + +template <> +string_t StringCastFromDecimal::Operation(int16_t input, uint8_t width, uint8_t scale, Vector &result); +template <> +string_t StringCastFromDecimal::Operation(int32_t input, uint8_t width, uint8_t scale, Vector &result); +template <> +string_t StringCastFromDecimal::Operation(int64_t input, uint8_t width, uint8_t scale, Vector &result); +template <> +string_t StringCastFromDecimal::Operation(hugeint_t input, uint8_t width, uint8_t scale, Vector &result); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/multiply.hpp b/src/duckdb/src/include/duckdb/common/operator/multiply.hpp new file mode 100644 index 00000000..5ee38f87 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/multiply.hpp @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/multiply.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/type_util.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +struct interval_t; + +struct MultiplyOperator { + template + static inline TR Operation(TA left, TB right) { + return left * right; + } +}; + +template <> +float MultiplyOperator::Operation(float left, float right); +template <> +double MultiplyOperator::Operation(double left, double right); +template <> +interval_t MultiplyOperator::Operation(interval_t left, int64_t right); +template <> +interval_t MultiplyOperator::Operation(int64_t left, interval_t right); + +struct TryMultiplyOperator { + template + static inline bool Operation(TA left, TB right, TR &result) { + throw InternalException("Unimplemented type for TryMultiplyOperator"); + } +}; + +template <> +bool TryMultiplyOperator::Operation(uint8_t left, uint8_t right, uint8_t &result); +template <> +bool TryMultiplyOperator::Operation(uint16_t left, uint16_t right, uint16_t &result); +template <> +bool TryMultiplyOperator::Operation(uint32_t left, uint32_t right, uint32_t &result); +template <> +bool TryMultiplyOperator::Operation(uint64_t left, uint64_t right, uint64_t &result); + +template <> +bool TryMultiplyOperator::Operation(int8_t left, int8_t right, int8_t &result); +template <> +bool TryMultiplyOperator::Operation(int16_t left, int16_t right, int16_t &result); +template <> +bool TryMultiplyOperator::Operation(int32_t left, int32_t right, int32_t &result); +template <> +DUCKDB_API bool TryMultiplyOperator::Operation(int64_t left, int64_t right, int64_t &result); +template <> +DUCKDB_API bool TryMultiplyOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result); + +struct MultiplyOperatorOverflowCheck { + template + static inline TR Operation(TA left, TB right) { + TR result; + if (!TryMultiplyOperator::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in multiplication of %s (%s * %s)!", TypeIdToString(GetTypeId()), + NumericHelper::ToString(left), NumericHelper::ToString(right)); + } + return result; + } +}; + +struct TryDecimalMultiply { + template + static inline bool Operation(TA left, TB right, TR &result) { + throw InternalException("Unimplemented type for TryDecimalMultiply"); + } +}; + +template <> +bool TryDecimalMultiply::Operation(int16_t left, int16_t right, int16_t &result); +template <> +bool TryDecimalMultiply::Operation(int32_t left, int32_t right, int32_t &result); +template <> +bool TryDecimalMultiply::Operation(int64_t left, int64_t right, int64_t &result); +template <> +bool TryDecimalMultiply::Operation(hugeint_t left, hugeint_t right, hugeint_t &result); + +struct DecimalMultiplyOverflowCheck { + template + static inline TR Operation(TA left, TB right) { + TR result; + if (!TryDecimalMultiply::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in multiplication of DECIMAL(18) (%d * %d). You might want to add an " + "explicit cast to a bigger decimal.", + left, right); + } + return result; + } +}; + +template <> +hugeint_t DecimalMultiplyOverflowCheck::Operation(hugeint_t left, hugeint_t right); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/numeric_binary_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/numeric_binary_operators.hpp new file mode 100644 index 00000000..c85cd4f7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/numeric_binary_operators.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/numeric_binary_operators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include + +namespace duckdb { + +struct DivideOperator { + template + static inline TR Operation(TA left, TB right) { + D_ASSERT(right != 0); // this should be checked before! + return left / right; + } +}; + +struct ModuloOperator { + template + static inline TR Operation(TA left, TB right) { + D_ASSERT(right != 0); + return left % right; + } +}; + +template <> +float DivideOperator::Operation(float left, float right); +template <> +double DivideOperator::Operation(double left, double right); +template <> +hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right); +template <> +interval_t DivideOperator::Operation(interval_t left, int64_t right); + +template <> +float ModuloOperator::Operation(float left, float right); +template <> +double ModuloOperator::Operation(double left, double right); +template <> +hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/numeric_cast.hpp b/src/duckdb/src/include/duckdb/common/operator/numeric_cast.hpp new file mode 100644 index 00000000..3e356baf --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/numeric_cast.hpp @@ -0,0 +1,473 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/numeric_cast.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/types/vector.hpp" +#include + +namespace duckdb { + +template +static bool TryCastWithOverflowCheck(SRC value, DST &result) { + if (!Value::IsFinite(value)) { + return false; + } + if (NumericLimits::IsSigned() != NumericLimits::IsSigned()) { + if (NumericLimits::IsSigned()) { + // signed to unsigned conversion + if (NumericLimits::Digits() > NumericLimits::Digits()) { + if (value < 0 || value > (SRC)NumericLimits::Maximum()) { + return false; + } + } else { + if (value < 0) { + return false; + } + } + result = (DST)value; + return true; + } else { + // unsigned to signed conversion + if (NumericLimits::Digits() >= NumericLimits::Digits()) { + if (value <= (SRC)NumericLimits::Maximum()) { + result = (DST)value; + return true; + } + return false; + } else { + result = (DST)value; + return true; + } + } + } else { + // same sign conversion + if (NumericLimits::Digits() >= NumericLimits::Digits()) { + result = (DST)value; + return true; + } else { + if (value < SRC(NumericLimits::Minimum()) || value > SRC(NumericLimits::Maximum())) { + return false; + } + result = (DST)value; + return true; + } + } +} + +template +bool TryCastWithOverflowCheckFloat(SRC value, T &result, SRC min, SRC max) { + if (!Value::IsFinite(value)) { + return false; + } + if (!(value >= min && value < max)) { + return false; + } + // PG FLOAT => INT casts use statistical rounding. + result = std::nearbyint(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(float value, int8_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -128.0f, 128.0f); +} + +template <> +bool TryCastWithOverflowCheck(float value, int16_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -32768.0f, 32768.0f); +} + +template <> +bool TryCastWithOverflowCheck(float value, int32_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -2147483648.0f, 2147483648.0f); +} + +template <> +bool TryCastWithOverflowCheck(float value, int64_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -9223372036854775808.0f, + 9223372036854775808.0f); +} + +template <> +bool TryCastWithOverflowCheck(float value, uint8_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0f, 256.0f); +} + +template <> +bool TryCastWithOverflowCheck(float value, uint16_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0f, 65536.0f); +} + +template <> +bool TryCastWithOverflowCheck(float value, uint32_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0f, 4294967296.0f); +} + +template <> +bool TryCastWithOverflowCheck(float value, uint64_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0f, 18446744073709551616.0f); +} + +template <> +bool TryCastWithOverflowCheck(double value, int8_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -128.0, 128.0); +} + +template <> +bool TryCastWithOverflowCheck(double value, int16_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -32768.0, 32768.0); +} + +template <> +bool TryCastWithOverflowCheck(double value, int32_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -2147483648.0, 2147483648.0); +} + +template <> +bool TryCastWithOverflowCheck(double value, int64_t &result) { + return TryCastWithOverflowCheckFloat(value, result, -9223372036854775808.0, 9223372036854775808.0); +} + +template <> +bool TryCastWithOverflowCheck(double value, uint8_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0, 256.0); +} + +template <> +bool TryCastWithOverflowCheck(double value, uint16_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0, 65536.0); +} + +template <> +bool TryCastWithOverflowCheck(double value, uint32_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0, 4294967296.0); +} + +template <> +bool TryCastWithOverflowCheck(double value, uint64_t &result) { + return TryCastWithOverflowCheckFloat(value, result, 0.0, 18446744073709551615.0); +} +template <> +bool TryCastWithOverflowCheck(float input, float &result) { + result = input; + return true; +} +template <> +bool TryCastWithOverflowCheck(float input, double &result) { + result = double(input); + return true; +} +template <> +bool TryCastWithOverflowCheck(double input, double &result) { + result = input; + return true; +} + +template <> +bool TryCastWithOverflowCheck(double input, float &result) { + if (!Value::IsFinite(input)) { + result = float(input); + return true; + } + auto res = float(input); + if (!Value::FloatIsFinite(input)) { + return false; + } + result = res; + return true; +} + +//===--------------------------------------------------------------------===// +// Cast Numeric -> bool +//===--------------------------------------------------------------------===// +template <> +bool TryCastWithOverflowCheck(bool value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(int8_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(int16_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(int32_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(int64_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(uint8_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(uint16_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(uint32_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(uint64_t value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(float value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(double value, bool &result) { + result = bool(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t input, bool &result) { + result = input.upper != 0 || input.lower != 0; + return true; +} + +//===--------------------------------------------------------------------===// +// Cast bool -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCastWithOverflowCheck(bool value, int8_t &result) { + result = int8_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, int16_t &result) { + result = int16_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, int32_t &result) { + result = int32_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, int64_t &result) { + result = int64_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, uint8_t &result) { + result = uint8_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, uint16_t &result) { + result = uint16_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, uint32_t &result) { + result = uint32_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, uint64_t &result) { + result = uint64_t(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, float &result) { + result = float(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool value, double &result) { + result = double(value); + return true; +} + +template <> +bool TryCastWithOverflowCheck(bool input, hugeint_t &result) { + result.upper = 0; + result.lower = input ? 1 : 0; + return true; +} + +//===--------------------------------------------------------------------===// +// Cast Numeric -> hugeint +//===--------------------------------------------------------------------===// +template <> +bool TryCastWithOverflowCheck(int8_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(int16_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(int32_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(int64_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(uint8_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(uint16_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(uint32_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(uint64_t value, hugeint_t &result) { + return Hugeint::TryConvert(value, result); +} + +template <> +bool TryCastWithOverflowCheck(float value, hugeint_t &result) { + return Hugeint::TryConvert(std::nearbyintf(value), result); +} + +template <> +bool TryCastWithOverflowCheck(double value, hugeint_t &result) { + return Hugeint::TryConvert(std::nearbyint(value), result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, hugeint_t &result) { + result = value; + return true; +} + +//===--------------------------------------------------------------------===// +// Cast Hugeint -> Numeric +//===--------------------------------------------------------------------===// +template <> +bool TryCastWithOverflowCheck(hugeint_t value, int8_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, int16_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, int32_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, int64_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, uint8_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, uint16_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, uint32_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, uint64_t &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, float &result) { + return Hugeint::TryCast(value, result); +} + +template <> +bool TryCastWithOverflowCheck(hugeint_t value, double &result) { + return Hugeint::TryCast(value, result); +} + +struct NumericTryCastToBit { + template + static inline string_t Operation(SRC input, Vector &result) { + return StringVector::AddStringOrBlob(result, Bit::NumericToBit(input)); + } +}; + +struct NumericTryCast { + template + static inline bool Operation(SRC input, DST &result, bool strict = false) { + return TryCastWithOverflowCheck(input, result); + } +}; + +struct NumericCast { + template + static inline DST Operation(SRC input) { + DST result; + if (!NumericTryCast::Operation(input, result)) { + throw InvalidInputException(CastExceptionText(input)); + } + return result; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/string_cast.hpp b/src/duckdb/src/include/duckdb/common/operator/string_cast.hpp new file mode 100644 index 00000000..3e688f53 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/string_cast.hpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/string_cast.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/string_type.hpp" + +namespace duckdb { + +//! StringCast +class Vector; + +struct StringCast { + template + static inline string_t Operation(SRC input, Vector &result) { + throw NotImplementedException("Unimplemented type for string cast!"); + } +}; + +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(bool input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(int8_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(int16_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(int32_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(int64_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(uint8_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(uint16_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(uint32_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(uint64_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(hugeint_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(float input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(double input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(interval_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(duckdb::string_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(date_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(dtime_t input, Vector &result); +template <> +DUCKDB_API duckdb::string_t StringCast::Operation(timestamp_t input, Vector &result); + +//! Temporary casting for Time Zone types. TODO: turn casting into functions. +struct StringCastTZ { + template + static inline string_t Operation(SRC input, Vector &vector) { + return StringCast::Operation(input, vector); + } +}; + +template <> +duckdb::string_t StringCastTZ::Operation(date_t input, Vector &result); +template <> +duckdb::string_t StringCastTZ::Operation(dtime_tz_t input, Vector &result); +template <> +duckdb::string_t StringCastTZ::Operation(timestamp_t input, Vector &result); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/subtract.hpp b/src/duckdb/src/include/duckdb/common/operator/subtract.hpp new file mode 100644 index 00000000..7f268e71 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/operator/subtract.hpp @@ -0,0 +1,126 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/operator/subtract.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/type_util.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +struct interval_t; +struct date_t; +struct timestamp_t; +struct dtime_t; + +struct SubtractOperator { + template + static inline TR Operation(TA left, TB right) { + return left - right; + } +}; + +template <> +float SubtractOperator::Operation(float left, float right); +template <> +double SubtractOperator::Operation(double left, double right); +template <> +interval_t SubtractOperator::Operation(interval_t left, interval_t right); +template <> +int64_t SubtractOperator::Operation(date_t left, date_t right); +template <> +date_t SubtractOperator::Operation(date_t left, int32_t right); +template <> +date_t SubtractOperator::Operation(date_t left, interval_t right); +template <> +timestamp_t SubtractOperator::Operation(timestamp_t left, interval_t right); +template <> +interval_t SubtractOperator::Operation(timestamp_t left, timestamp_t right); + +struct TrySubtractOperator { + template + static inline bool Operation(TA left, TB right, TR &result) { + throw InternalException("Unimplemented type for TrySubtractOperator"); + } +}; + +template <> +bool TrySubtractOperator::Operation(uint8_t left, uint8_t right, uint8_t &result); +template <> +bool TrySubtractOperator::Operation(uint16_t left, uint16_t right, uint16_t &result); +template <> +bool TrySubtractOperator::Operation(uint32_t left, uint32_t right, uint32_t &result); +template <> +bool TrySubtractOperator::Operation(uint64_t left, uint64_t right, uint64_t &result); + +template <> +bool TrySubtractOperator::Operation(int8_t left, int8_t right, int8_t &result); +template <> +bool TrySubtractOperator::Operation(int16_t left, int16_t right, int16_t &result); +template <> +bool TrySubtractOperator::Operation(int32_t left, int32_t right, int32_t &result); +template <> +bool TrySubtractOperator::Operation(int64_t left, int64_t right, int64_t &result); +template <> +bool TrySubtractOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result); + +struct SubtractOperatorOverflowCheck { + template + static inline TR Operation(TA left, TB right) { + TR result; + if (!TrySubtractOperator::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in subtraction of %s (%s - %s)!", TypeIdToString(GetTypeId()), + NumericHelper::ToString(left), NumericHelper::ToString(right)); + } + return result; + } +}; + +struct TryDecimalSubtract { + template + static inline bool Operation(TA left, TB right, TR &result) { + throw InternalException("Unimplemented type for TryDecimalSubtract"); + } +}; + +template <> +bool TryDecimalSubtract::Operation(int16_t left, int16_t right, int16_t &result); +template <> +bool TryDecimalSubtract::Operation(int32_t left, int32_t right, int32_t &result); +template <> +bool TryDecimalSubtract::Operation(int64_t left, int64_t right, int64_t &result); +template <> +bool TryDecimalSubtract::Operation(hugeint_t left, hugeint_t right, hugeint_t &result); + +struct DecimalSubtractOverflowCheck { + template + static inline TR Operation(TA left, TB right) { + TR result; + if (!TryDecimalSubtract::Operation(left, right, result)) { + throw OutOfRangeException("Overflow in subtract of DECIMAL(18) (%d - %d). You might want to add an " + "explicit cast to a bigger decimal.", + left, right); + } + return result; + } +}; + +template <> +hugeint_t DecimalSubtractOverflowCheck::Operation(hugeint_t left, hugeint_t right); + +struct SubtractTimeOperator { + template + static TR Operation(TA left, TB right); +}; + +template <> +dtime_t SubtractTimeOperator::Operation(dtime_t left, interval_t right); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/optional_idx.hpp b/src/duckdb/src/include/duckdb/common/optional_idx.hpp new file mode 100644 index 00000000..28c618f2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/optional_idx.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/optional_idx.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +class optional_idx { + static constexpr const idx_t INVALID_INDEX = idx_t(-1); + +public: + optional_idx() : index(INVALID_INDEX) { + } + optional_idx(idx_t index) : index(index) { // NOLINT: allow implicit conversion from idx_t + if (index == INVALID_INDEX) { + throw InternalException("optional_idx cannot be initialized with an invalid index"); + } + } + + static optional_idx Invalid() { + return INVALID_INDEX; + } + + bool IsValid() const { + return index != DConstants::INVALID_INDEX; + } + void Invalidate() { + index = INVALID_INDEX; + } + idx_t GetIndex() const { + if (index == INVALID_INDEX) { + throw InternalException("Attempting to get the index of an optional_idx that is not set"); + } + return index; + } + +private: + idx_t index; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/optional_ptr.hpp b/src/duckdb/src/include/duckdb/common/optional_ptr.hpp new file mode 100644 index 00000000..82b845a8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/optional_ptr.hpp @@ -0,0 +1,77 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/optional_ptr.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/unique_ptr.hpp" + +namespace duckdb { + +template +class optional_ptr { +public: + optional_ptr() : ptr(nullptr) { + } + optional_ptr(T *ptr_p) : ptr(ptr_p) { // NOLINT: allow implicit creation from pointer + } + optional_ptr(const unique_ptr &ptr_p) : ptr(ptr_p.get()) { // NOLINT: allow implicit creation from unique pointer + } + + void CheckValid() const { + if (!ptr) { + throw InternalException("Attempting to dereference an optional pointer that is not set"); + } + } + + operator bool() const { + return ptr; + } + T &operator*() { + CheckValid(); + return *ptr; + } + const T &operator*() const { + CheckValid(); + return *ptr; + } + T *operator->() { + CheckValid(); + return ptr; + } + const T *operator->() const { + CheckValid(); + return ptr; + } + T *get() { + // CheckValid(); + return ptr; + } + const T *get() const { + // CheckValid(); + return ptr; + } + // this looks dirty - but this is the default behavior of raw pointers + T *get_mutable() const { + // CheckValid(); + return ptr; + } + + bool operator==(const optional_ptr &rhs) const { + return ptr == rhs.ptr; + } + + bool operator!=(const optional_ptr &rhs) const { + return ptr != rhs.ptr; + } + +private: + T *ptr; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/pair.hpp b/src/duckdb/src/include/duckdb/common/pair.hpp new file mode 100644 index 00000000..22d50f70 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/pair.hpp @@ -0,0 +1,16 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/pair.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::make_pair; +using std::pair; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/perfect_map_set.hpp b/src/duckdb/src/include/duckdb/common/perfect_map_set.hpp new file mode 100644 index 00000000..1e735a63 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/perfect_map_set.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/perfect_map_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +struct PerfectHash { + inline std::size_t operator()(const idx_t &h) const { + return h; + } +}; + +struct PerfectEquality { + inline bool operator()(const idx_t &a, const idx_t &b) const { + return a == b; + } +}; + +template +using perfect_map_t = unordered_map; + +using perfect_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/pipe_file_system.hpp b/src/duckdb/src/include/duckdb/common/pipe_file_system.hpp new file mode 100644 index 00000000..8d050219 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/pipe_file_system.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/pipe_file_system.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_system.hpp" + +namespace duckdb { + +class PipeFileSystem : public FileSystem { +public: + static unique_ptr OpenPipe(unique_ptr handle); + + int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + + int64_t GetFileSize(FileHandle &handle) override; + + void Reset(FileHandle &handle) override; + bool OnDiskFile(FileHandle &handle) override { + return false; + }; + bool CanSeek() override { + return false; + } + void FileSync(FileHandle &handle) override; + + std::string GetName() const override { + return "PipeFileSystem"; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/preserved_error.hpp b/src/duckdb/src/include/duckdb/common/preserved_error.hpp new file mode 100644 index 00000000..c95cb780 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/preserved_error.hpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/preserved_error.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string.hpp" + +namespace duckdb { + +class PreservedError { +public: + //! Not initialized, default constructor + DUCKDB_API PreservedError(); + //! From std::exception + PreservedError(const std::exception &ex) + : initialized(true), type(ExceptionType::INVALID), raw_message(SanitizeErrorMessage(ex.what())), + exception_instance(nullptr) { + } + //! From a raw string + DUCKDB_API explicit PreservedError(const string &raw_message); + //! From an Exception + DUCKDB_API PreservedError(const Exception &exception); + +public: + //! Throw the error + [[noreturn]] DUCKDB_API void Throw(const string &prepended_message = "") const; + //! Get the internal exception type of the error + DUCKDB_API const ExceptionType &Type() const; + //! Allows adding addition information to the message + DUCKDB_API PreservedError &AddToMessage(const string &prepended_message); + //! Used in clients like C-API, creates the final message and returns a reference to it + DUCKDB_API const string &Message(); + //! Let's us do things like 'if (error)' + DUCKDB_API operator bool() const; + DUCKDB_API bool operator==(const PreservedError &other) const; + const shared_ptr &GetError() { + return exception_instance; + } + +private: + //! Whether this PreservedError contains an exception or not + bool initialized; + //! The ExceptionType of the preserved exception + ExceptionType type; + //! The message the exception was constructed with (does not contain the Exception Type) + string raw_message; + //! The final message (stored in the preserved error for compatibility reasons with C-API) + string final_message; + std::shared_ptr exception_instance; + +private: + DUCKDB_API static string SanitizeErrorMessage(string error); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/printer.hpp b/src/duckdb/src/include/duckdb/common/printer.hpp new file mode 100644 index 00000000..8ecc9f58 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/printer.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/printer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +enum class OutputStream : uint8_t { STREAM_STDOUT = 1, STREAM_STDERR = 2 }; + +//! Printer is a static class that allows printing to logs or stdout/stderr +class Printer { +public: + //! Print the object to the stream + DUCKDB_API static void Print(OutputStream stream, const string &str); + //! Print the object to stderr + DUCKDB_API static void Print(const string &str); + //! Print the formatted object to the stream + template + static void PrintF(OutputStream stream, const string &str, Args... params) { + Printer::Print(stream, StringUtil::Format(str, params...)); + } + //! Print the formatted object to stderr + template + static void PrintF(const string &str, Args... params) { + Printer::PrintF(OutputStream::STREAM_STDERR, str, std::forward(params)...); + } + //! Directly prints the string to stdout without a newline + DUCKDB_API static void RawPrint(OutputStream stream, const string &str); + //! Flush an output stream + DUCKDB_API static void Flush(OutputStream stream); + //! Whether or not we are printing to a terminal + DUCKDB_API static bool IsTerminal(OutputStream stream); + //! The terminal width + DUCKDB_API static idx_t TerminalWidth(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/profiler.hpp b/src/duckdb/src/include/duckdb/common/profiler.hpp new file mode 100644 index 00000000..aeb704bf --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/profiler.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/profiler.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/chrono.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +//! The profiler can be used to measure elapsed time +template +class BaseProfiler { +public: + //! Starts the timer + void Start() { + finished = false; + start = Tick(); + } + //! Finishes timing + void End() { + end = Tick(); + finished = true; + } + + //! Returns the elapsed time in seconds. If End() has been called, returns + //! the total elapsed time. Otherwise returns how far along the timer is + //! right now. + double Elapsed() const { + auto _end = finished ? end : Tick(); + return std::chrono::duration_cast>(_end - start).count(); + } + +private: + time_point Tick() const { + return T::now(); + } + time_point start; + time_point end; + bool finished = false; +}; + +using Profiler = BaseProfiler; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp b/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp new file mode 100644 index 00000000..d50434be --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/progress_bar/progress_bar_display.hpp" +#include "duckdb/common/unicode_bar.hpp" + +namespace duckdb { + +class TerminalProgressBarDisplay : public ProgressBarDisplay { +public: + TerminalProgressBarDisplay() { + } + ~TerminalProgressBarDisplay() override { + } + +public: + void Update(double percentage) override; + void Finish() override; + +private: + static constexpr const idx_t PARTIAL_BLOCK_COUNT = UnicodeBar::PartialBlocksCount(); +#ifndef DUCKDB_ASCII_TREE_RENDERER + const char *PROGRESS_EMPTY = " "; + const char *const *PROGRESS_PARTIAL = UnicodeBar::PartialBlocks(); + const char *PROGRESS_BLOCK = UnicodeBar::FullBlock(); + const char *PROGRESS_START = "\xE2\x96\x95"; + const char *PROGRESS_END = "\xE2\x96\x8F"; +#else + const char *PROGRESS_EMPTY = " "; + const char *const PROGRESS_PARTIAL[PARTIAL_BLOCK_COUNT] = {" ", " ", " ", " ", " ", " ", " ", " "}; + const char *PROGRESS_BLOCK = "="; + const char *PROGRESS_START = "["; + const char *PROGRESS_END = "]"; +#endif + static constexpr const idx_t PROGRESS_BAR_WIDTH = 60; + +private: + void PrintProgressInternal(int percentage); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar.hpp b/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar.hpp new file mode 100644 index 00000000..d3d77e15 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/progress_bar/progress_bar.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/execution/executor.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/profiler.hpp" +#include "duckdb/common/progress_bar/progress_bar_display.hpp" + +namespace duckdb { + +struct ClientConfig; +typedef unique_ptr (*progress_bar_display_create_func_t)(); + +class ProgressBar { +public: + static unique_ptr DefaultProgressBarDisplay(); + static void SystemOverrideCheck(ClientConfig &config); + + explicit ProgressBar( + Executor &executor, idx_t show_progress_after, + progress_bar_display_create_func_t create_display_func = ProgressBar::DefaultProgressBarDisplay); + + //! Starts the thread + void Start(); + //! Updates the progress bar and prints it to the screen + void Update(bool final); + //! Gets current percentage + double GetCurrentPercentage(); + + void PrintProgressInternal(int percentage); + void PrintProgress(int percentage); + void FinishProgressBarPrint(); + bool ShouldPrint(bool final) const; + bool PrintEnabled() const; + +private: + //! The executor + Executor &executor; + //! The profiler used to measure the time since the progress bar was started + Profiler profiler; + //! The time in ms after which to start displaying the progress bar + idx_t show_progress_after; + //! The current progress percentage + double current_percentage; + //! The display used to print the progress + unique_ptr display; + //! Whether or not profiling is supported for the current query + bool supported = true; + //! Whether the bar has already finished + bool finished = false; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar_display.hpp b/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar_display.hpp new file mode 100644 index 00000000..815272ca --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar_display.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/progress_bar/progress_bar_display.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +class ProgressBarDisplay { +public: + ProgressBarDisplay() { + } + virtual ~ProgressBarDisplay() { + } + +public: + virtual void Update(double percentage) = 0; + virtual void Finish() = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/queue.hpp b/src/duckdb/src/include/duckdb/common/queue.hpp new file mode 100644 index 00000000..d3e28d98 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/queue.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/queue.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::queue; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/radix.hpp b/src/duckdb/src/include/duckdb/common/radix.hpp new file mode 100644 index 00000000..46f6dc6d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/radix.hpp @@ -0,0 +1,189 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/radix.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/bswap.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/value.hpp" + +#include +#include // strlen() on Solaris +#include + +namespace duckdb { + +struct Radix { +public: + static inline bool IsLittleEndian() { + int n = 1; + if (*char_ptr_cast(&n) == 1) { + return true; + } else { + return false; + } + } + + template + static inline void EncodeData(data_ptr_t dataptr, T value) { + throw NotImplementedException("Cannot create data from this type"); + } + + static inline void EncodeStringDataPrefix(data_ptr_t dataptr, string_t value, idx_t prefix_len) { + auto len = value.GetSize(); + memcpy(dataptr, value.GetData(), MinValue(len, prefix_len)); + if (len < prefix_len) { + memset(dataptr + len, '\0', prefix_len - len); + } + } + + static inline uint8_t FlipSign(uint8_t key_byte) { + return key_byte ^ 128; + } + + static inline uint32_t EncodeFloat(float x) { + uint64_t buff; + + //! zero + if (x == 0) { + buff = 0; + buff |= (1u << 31); + return buff; + } + // nan + if (Value::IsNan(x)) { + return UINT_MAX; + } + //! infinity + if (x > FLT_MAX) { + return UINT_MAX - 1; + } + //! -infinity + if (x < -FLT_MAX) { + return 0; + } + buff = Load(const_data_ptr_cast(&x)); + if ((buff & (1u << 31)) == 0) { //! +0 and positive numbers + buff |= (1u << 31); + } else { //! negative numbers + buff = ~buff; //! complement 1 + } + + return buff; + } + + static inline uint64_t EncodeDouble(double x) { + uint64_t buff; + //! zero + if (x == 0) { + buff = 0; + buff += (1ull << 63); + return buff; + } + // nan + if (Value::IsNan(x)) { + return ULLONG_MAX; + } + //! infinity + if (x > DBL_MAX) { + return ULLONG_MAX - 1; + } + //! -infinity + if (x < -DBL_MAX) { + return 0; + } + buff = Load(const_data_ptr_cast(&x)); + if (buff < (1ull << 63)) { //! +0 and positive numbers + buff += (1ull << 63); + } else { //! negative numbers + buff = ~buff; //! complement 1 + } + return buff; + } +}; + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, bool value) { + Store(value ? 1 : 0, dataptr); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, int8_t value) { + Store(value, dataptr); + dataptr[0] = FlipSign(dataptr[0]); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, int16_t value) { + Store(BSwap(value), dataptr); + dataptr[0] = FlipSign(dataptr[0]); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, int32_t value) { + Store(BSwap(value), dataptr); + dataptr[0] = FlipSign(dataptr[0]); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, int64_t value) { + Store(BSwap(value), dataptr); + dataptr[0] = FlipSign(dataptr[0]); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, uint8_t value) { + Store(value, dataptr); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, uint16_t value) { + Store(BSwap(value), dataptr); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, uint32_t value) { + Store(BSwap(value), dataptr); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, uint64_t value) { + Store(BSwap(value), dataptr); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, hugeint_t value) { + EncodeData(dataptr, value.upper); + EncodeData(dataptr + sizeof(value.upper), value.lower); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, float value) { + uint32_t converted_value = EncodeFloat(value); + Store(BSwap(converted_value), dataptr); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, double value) { + uint64_t converted_value = EncodeDouble(value); + Store(BSwap(converted_value), dataptr); +} + +template <> +inline void Radix::EncodeData(data_ptr_t dataptr, interval_t value) { + EncodeData(dataptr, value.months); + dataptr += sizeof(value.months); + EncodeData(dataptr, value.days); + dataptr += sizeof(value.days); + EncodeData(dataptr, value.micros); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp new file mode 100644 index 00000000..70f0516b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp @@ -0,0 +1,145 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/radix_partitioning.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/types/column/partitioned_column_data.hpp" +#include "duckdb/common/types/row/partitioned_tuple_data.hpp" + +namespace duckdb { + +class BufferManager; +class Vector; +struct UnifiedVectorFormat; +struct SelectionVector; + +//! Generic radix partitioning functions +struct RadixPartitioning { +public: + //! 4096 partitions ought to be enough to go out-of-core properly + static constexpr const idx_t MAX_RADIX_BITS = 12; + + //! The number of partitions for a given number of radix bits + static inline constexpr idx_t NumberOfPartitions(idx_t radix_bits) { + return idx_t(1) << radix_bits; + } + + //! Inverse of NumberOfPartitions, given a number of partitions, get the number of radix bits + static inline idx_t RadixBits(idx_t n_partitions) { + D_ASSERT(IsPowerOfTwo(n_partitions)); + for (idx_t r = 0; r < sizeof(idx_t) * 8; r++) { + if (n_partitions == NumberOfPartitions(r)) { + return r; + } + } + throw InternalException("RadixPartitioning::RadixBits unable to find partition count!"); + } + + //! Radix bits begin after uint16_t because these bits are used as salt in the aggregate HT + static inline constexpr idx_t Shift(idx_t radix_bits) { + return (sizeof(hash_t) - sizeof(uint16_t)) * 8 - radix_bits; + } + + //! Mask of the radix bits of the hash + static inline constexpr hash_t Mask(idx_t radix_bits) { + return (hash_t(1 << radix_bits) - 1) << Shift(radix_bits); + } + + //! Select using a cutoff on the radix bits of the hash + static idx_t Select(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t radix_bits, idx_t cutoff, + SelectionVector *true_sel, SelectionVector *false_sel); +}; + +//! RadixPartitionedColumnData is a PartitionedColumnData that partitions input based on the radix of a hash +class RadixPartitionedColumnData : public PartitionedColumnData { +public: + RadixPartitionedColumnData(ClientContext &context, vector types, idx_t radix_bits, idx_t hash_col_idx); + RadixPartitionedColumnData(const RadixPartitionedColumnData &other); + ~RadixPartitionedColumnData() override; + + idx_t GetRadixBits() const { + return radix_bits; + } + +protected: + //===--------------------------------------------------------------------===// + // Radix Partitioning interface implementation + //===--------------------------------------------------------------------===// + idx_t BufferSize() const override { + switch (radix_bits) { + case 1: + case 2: + case 3: + case 4: + return GetBufferSize(1 << 1); + case 5: + return GetBufferSize(1 << 2); + case 6: + return GetBufferSize(1 << 3); + default: + return GetBufferSize(1 << 4); + } + } + + void InitializeAppendStateInternal(PartitionedColumnDataAppendState &state) const override; + void ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) override; + + static constexpr idx_t GetBufferSize(idx_t div) { + return STANDARD_VECTOR_SIZE / div == 0 ? 1 : STANDARD_VECTOR_SIZE / div; + } + +private: + //! The number of radix bits + const idx_t radix_bits; + //! The index of the column holding the hashes + const idx_t hash_col_idx; +}; + +//! RadixPartitionedTupleData is a PartitionedTupleData that partitions input based on the radix of a hash +class RadixPartitionedTupleData : public PartitionedTupleData { +public: + RadixPartitionedTupleData(BufferManager &buffer_manager, const TupleDataLayout &layout, idx_t radix_bits_p, + idx_t hash_col_idx_p); + RadixPartitionedTupleData(const RadixPartitionedTupleData &other); + ~RadixPartitionedTupleData() override; + + idx_t GetRadixBits() const { + return radix_bits; + } + +private: + void Initialize(); + +protected: + //===--------------------------------------------------------------------===// + // Radix Partitioning interface implementation + //===--------------------------------------------------------------------===// + void InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties) const override; + void ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input) override; + void ComputePartitionIndices(Vector &row_locations, idx_t count, Vector &partition_indices) const override; + idx_t MaxPartitionIndex() const override { + return RadixPartitioning::NumberOfPartitions(radix_bits) - 1; + } + + bool RepartitionReverseOrder() const override { + return true; + } + void RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, + PartitionedTupleData &new_partitioned_data, PartitionedTupleDataAppendState &state, + idx_t finished_partition_idx) const override; + +private: + //! The number of radix bits + const idx_t radix_bits; + //! The index of the column holding the hashes + const idx_t hash_col_idx; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/random_engine.hpp b/src/duckdb/src/include/duckdb/common/random_engine.hpp new file mode 100644 index 00000000..1185dbcb --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/random_engine.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/random_engine.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/mutex.hpp" + +#include + +namespace duckdb { +class ClientContext; +struct RandomState; + +struct RandomEngine { + RandomEngine(int64_t seed = -1); + ~RandomEngine(); + +public: + //! Generate a random number between min and max + double NextRandom(double min, double max); + + //! Generate a random number between 0 and 1 + double NextRandom(); + uint32_t NextRandomInteger(); + + void SetSeed(uint32_t seed); + + static RandomEngine &Get(ClientContext &context); + + mutex lock; + +private: + unique_ptr random_state; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/re2_regex.hpp b/src/duckdb/src/include/duckdb/common/re2_regex.hpp new file mode 100644 index 00000000..3c13179a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/re2_regex.hpp @@ -0,0 +1,73 @@ +// RE2 compatibility layer with std::regex + +#pragma once + +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/vector.hpp" +#include +#include + +namespace duckdb_re2 { +class RE2; + +enum class RegexOptions : uint8_t { NONE, CASE_INSENSITIVE }; + +class Regex { +public: + DUCKDB_API Regex(const std::string &pattern, RegexOptions options = RegexOptions::NONE); + Regex(const char *pattern, RegexOptions options = RegexOptions::NONE) : Regex(std::string(pattern)) { + } + const duckdb_re2::RE2 &GetRegex() const { + return *regex; + } + +private: + std::shared_ptr regex; +}; + +struct GroupMatch { + std::string text; + uint32_t position; + + const std::string &str() const { + return text; + } + operator std::string() const { + return text; + } +}; + +struct Match { + duckdb::vector groups; + + GroupMatch &GetGroup(uint64_t index) { + if (index >= groups.size()) { + throw std::runtime_error("RE2: Match index is out of range"); + } + return groups[index]; + } + + std::string str(uint64_t index) { + return GetGroup(index).text; + } + + uint64_t position(uint64_t index) { + return GetGroup(index).position; + } + + uint64_t length(uint64_t index) { + return GetGroup(index).text.size(); + } + + GroupMatch &operator[](uint64_t i) { + return GetGroup(i); + } +}; + +DUCKDB_API bool RegexSearch(const std::string &input, Match &match, const Regex ®ex); +DUCKDB_API bool RegexMatch(const std::string &input, Match &match, const Regex ®ex); +DUCKDB_API bool RegexMatch(const char *start, const char *end, Match &match, const Regex ®ex); +DUCKDB_API bool RegexMatch(const std::string &input, const Regex ®ex); +DUCKDB_API duckdb::vector RegexFindAll(const std::string &input, const Regex ®ex); + +} // namespace duckdb_re2 diff --git a/src/duckdb/src/include/duckdb/common/reference_map.hpp b/src/duckdb/src/include/duckdb/common/reference_map.hpp new file mode 100644 index 00000000..2d95f997 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/reference_map.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/reference_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { +class Expression; + +template +struct ReferenceHashFunction { + uint64_t operator()(const reference &ref) const { + return std::hash()((void *)&ref.get()); + } +}; + +template +struct ReferenceEquality { + bool operator()(const reference &a, const reference &b) const { + return &a.get() == &b.get(); + } +}; + +template +using reference_map_t = unordered_map, TGT, ReferenceHashFunction, ReferenceEquality>; + +template +using reference_set_t = unordered_set, ReferenceHashFunction, ReferenceEquality>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/row_operations/row_matcher.hpp b/src/duckdb/src/include/duckdb/common/row_operations/row_matcher.hpp new file mode 100644 index 00000000..006e2f6d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/row_operations/row_matcher.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/row_operations/row_matcher.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +class Vector; +class DataChunk; +class TupleDataLayout; +struct TupleDataVectorFormat; +struct SelectionVector; +struct MatchFunction; + +typedef idx_t (*match_function_t)(Vector &lhs_vector, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, + const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, + const idx_t col_idx, const vector &child_functions, + SelectionVector *no_match_sel, idx_t &no_match_count); + +struct MatchFunction { + match_function_t function; + vector child_functions; +}; + +struct RowMatcher { +public: + using Predicates = vector; + + //! Initializes the RowMatcher, filling match_functions using layout and predicates + void Initialize(const bool no_match_sel, const TupleDataLayout &layout, const Predicates &predicates); + //! Given a DataChunk on the LHS, on which we've called TupleDataCollection::ToUnifiedFormat, + //! we match it with rows on the RHS, according to the given layout and locations. + //! Initially, 'sel' has 'count' entries which point to what needs to be compared. + //! After matching is done, this returns how many matching entries there are, which 'sel' is modified to point to + idx_t Match(DataChunk &lhs, const vector &lhs_formats, SelectionVector &sel, idx_t count, + const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, SelectionVector *no_match_sel, + idx_t &no_match_count); + +private: + //! Gets the templated match function for a given column + MatchFunction GetMatchFunction(const bool no_match_sel, const LogicalType &type, const ExpressionType predicate); + template + MatchFunction GetMatchFunction(const LogicalType &type, const ExpressionType predicate); + template + MatchFunction GetMatchFunction(const ExpressionType predicate); + template + MatchFunction GetStructMatchFunction(const LogicalType &type, const ExpressionType predicate); + template + MatchFunction GetListMatchFunction(const ExpressionType predicate); + +private: + vector match_functions; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp new file mode 100644 index 00000000..095c8b98 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp @@ -0,0 +1,132 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/row_operations/row_operations.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +class ArenaAllocator; +struct AggregateObject; +struct AggregateFilterData; +class DataChunk; +class RowLayout; +class TupleDataLayout; +class RowDataCollection; +struct SelectionVector; +class StringHeap; +class Vector; +struct UnifiedVectorFormat; + +struct RowOperationsState { + explicit RowOperationsState(ArenaAllocator &allocator) : allocator(allocator) { + } + + ArenaAllocator &allocator; +}; + +// RowOperations contains a set of operations that operate on data using a RowLayout +struct RowOperations { + //===--------------------------------------------------------------------===// + // Aggregation Operators + //===--------------------------------------------------------------------===// + //! initialize - unaligned addresses + static void InitializeStates(TupleDataLayout &layout, Vector &addresses, const SelectionVector &sel, idx_t count); + //! destructor - unaligned addresses, updated + static void DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, idx_t count); + //! update - aligned addresses + static void UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, DataChunk &payload, + idx_t arg_idx, idx_t count); + //! filtered update - aligned addresses + static void UpdateFilteredStates(RowOperationsState &state, AggregateFilterData &filter_data, AggregateObject &aggr, + Vector &addresses, DataChunk &payload, idx_t arg_idx); + //! combine - unaligned addresses, updated + static void CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets, + idx_t count); + //! finalize - unaligned addresses, updated + static void FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, DataChunk &result, + idx_t aggr_idx); + + //===--------------------------------------------------------------------===// + // Read/Write Operators + //===--------------------------------------------------------------------===// + //! Scatter group data to the rows. Initialises the ValidityMask. + static void Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, + RowDataCollection &string_heap, const SelectionVector &sel, idx_t count); + //! Gather a single column. + //! If heap_ptr is not null, then the data is assumed to contain swizzled pointers, + //! which will be unswizzled in memory. + static void Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, + const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size = 0, + data_ptr_t heap_ptr = nullptr); + //! Full Scan an entire columns + static void FullScanColumn(const TupleDataLayout &layout, Vector &rows, Vector &col, idx_t count, idx_t col_idx); + + //===--------------------------------------------------------------------===// + // Comparison Operators + //===--------------------------------------------------------------------===// + //! Compare a block of key data against the row values to produce an updated selection that matches + //! and a second (optional) selection of non-matching values. + //! Returns the number of matches remaining in the selection. + using Predicates = vector; + + static idx_t Match(DataChunk &columns, UnifiedVectorFormat col_data[], const TupleDataLayout &layout, Vector &rows, + const Predicates &predicates, SelectionVector &sel, idx_t count, SelectionVector *no_match, + idx_t &no_match_count); + + //===--------------------------------------------------------------------===// + // Heap Operators + //===--------------------------------------------------------------------===// + //! Compute the entry sizes of a vector with variable size type (used before building heap buffer space). + static void ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, + const SelectionVector &sel, idx_t offset = 0); + //! Compute the entry sizes of vector data with variable size type (used before building heap buffer space). + static void ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, + idx_t ser_count, const SelectionVector &sel, idx_t offset = 0); + //! Scatter vector with variable size type to the heap. + static void HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, idx_t col_idx, + data_ptr_t *key_locations, data_ptr_t *validitymask_locations, idx_t offset = 0); + //! Scatter vector data with variable size type to the heap. + static void HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, + idx_t ser_count, idx_t col_idx, data_ptr_t *key_locations, + data_ptr_t *validitymask_locations, idx_t offset = 0); + //! Gather a single column with variable size type from the heap. + static void HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, const idx_t &col_idx, + data_ptr_t key_locations[], data_ptr_t validitymask_locations[]); + + //===--------------------------------------------------------------------===// + // Sorting Operators + //===--------------------------------------------------------------------===// + //! Scatter vector data to the rows in radix-sortable format. + static void RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, + data_ptr_t key_locations[], bool desc, bool has_null, bool nulls_first, idx_t prefix_len, + idx_t width, idx_t offset = 0); + + //===--------------------------------------------------------------------===// + // Out-of-Core Operators + //===--------------------------------------------------------------------===// + //! Swizzles blob pointers to offset within heap row + static void SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count); + //! Swizzles the base pointer of each row to offset within heap block + static void SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, + const idx_t count, const idx_t base_offset = 0); + //! Copies 'count' heap rows that are pointed to by the rows at 'row_ptr' to 'heap_ptr' and swizzles the pointers + static void CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, + data_ptr_t heap_ptr, const idx_t count); + + //! Unswizzles the base offset within heap block the rows to pointers + static void UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, + const data_ptr_t base_heap_ptr, const idx_t count); + //! Unswizzles all offsets back to pointers + static void UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, + const data_ptr_t base_heap_ptr, const idx_t count); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/binary_deserializer.hpp b/src/duckdb/src/include/duckdb/common/serializer/binary_deserializer.hpp new file mode 100644 index 00000000..9d8b11ec --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/binary_deserializer.hpp @@ -0,0 +1,155 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/binary_deserializer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/encoding_util.hpp" +#include "duckdb/common/serializer/read_stream.hpp" + +namespace duckdb { +class ClientContext; + +class BinaryDeserializer : public Deserializer { +public: + explicit BinaryDeserializer(ReadStream &stream) : stream(stream) { + deserialize_enum_from_string = false; + } + + template + unique_ptr Deserialize() { + OnObjectBegin(); + auto result = T::Deserialize(*this); + OnObjectEnd(); + D_ASSERT(nesting_level == 0); // make sure we are at the root level + return result; + } + + template + static unique_ptr Deserialize(ReadStream &stream) { + BinaryDeserializer deserializer(stream); + return deserializer.template Deserialize(); + } + + template + static unique_ptr Deserialize(ReadStream &stream, ClientContext &context, bound_parameter_map_t ¶meters) { + BinaryDeserializer deserializer(stream); + deserializer.Set(context); + deserializer.Set(parameters); + return deserializer.template Deserialize(); + } + + void Begin() { + OnObjectBegin(); + } + + void End() { + OnObjectEnd(); + D_ASSERT(nesting_level == 0); // make sure we are at the root level + } + + ReadStream &GetStream() { + return stream; + } + +private: + ReadStream &stream; + idx_t nesting_level = 0; + + // Allow peeking 1 field ahead + bool has_buffered_field = false; + field_id_t buffered_field = 0; + +private: + field_id_t PeekField() { + if (!has_buffered_field) { + buffered_field = ReadPrimitive(); + has_buffered_field = true; + } + return buffered_field; + } + void ConsumeField() { + if (!has_buffered_field) { + buffered_field = ReadPrimitive(); + } else { + has_buffered_field = false; + } + } + field_id_t NextField() { + if (has_buffered_field) { + has_buffered_field = false; + return buffered_field; + } + return ReadPrimitive(); + } + + void ReadData(data_ptr_t buffer, idx_t read_size) { + stream.ReadData(buffer, read_size); + } + + template + T ReadPrimitive() { + T value; + ReadData(data_ptr_cast(&value), sizeof(T)); + return value; + } + + template + T VarIntDecode() { + // FIXME: maybe we should pass a source to EncodingUtil instead + uint8_t buffer[16]; + idx_t varint_size; + for (varint_size = 0; varint_size < 16; varint_size++) { + ReadData(buffer + varint_size, 1); + if (!(buffer[varint_size] & 0x80)) { + varint_size++; + break; + } + } + T value; + auto read_size = EncodingUtil::DecodeLEB128(buffer, value); + D_ASSERT(read_size == varint_size); + (void)read_size; + return value; + } + + //===--------------------------------------------------------------------===// + // Nested Types Hooks + //===--------------------------------------------------------------------===// + void OnPropertyBegin(const field_id_t field_id, const char *tag) final; + void OnPropertyEnd() final; + bool OnOptionalPropertyBegin(const field_id_t field_id, const char *tag) final; + void OnOptionalPropertyEnd(bool present) final; + void OnObjectBegin() final; + void OnObjectEnd() final; + idx_t OnListBegin() final; + void OnListEnd() final; + bool OnNullableBegin() final; + void OnNullableEnd() final; + + //===--------------------------------------------------------------------===// + // Primitive Types + //===--------------------------------------------------------------------===// + bool ReadBool() final; + char ReadChar() final; + int8_t ReadSignedInt8() final; + uint8_t ReadUnsignedInt8() final; + int16_t ReadSignedInt16() final; + uint16_t ReadUnsignedInt16() final; + int32_t ReadSignedInt32() final; + uint32_t ReadUnsignedInt32() final; + int64_t ReadSignedInt64() final; + uint64_t ReadUnsignedInt64() final; + float ReadFloat() final; + double ReadDouble() final; + string ReadString() final; + hugeint_t ReadHugeInt() final; + void ReadDataPtr(data_ptr_t &ptr, idx_t count) final; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/binary_serializer.hpp b/src/duckdb/src/include/duckdb/common/serializer/binary_serializer.hpp new file mode 100644 index 00000000..c7d22e31 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/binary_serializer.hpp @@ -0,0 +1,112 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/binary_serializer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/write_stream.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/serializer/encoding_util.hpp" + +namespace duckdb { + +class BinarySerializer : public Serializer { +public: + explicit BinarySerializer(WriteStream &stream, bool serialize_default_values_p = false) : stream(stream) { + serialize_default_values = serialize_default_values_p; + serialize_enum_as_string = false; + } + +private: + struct DebugState { + unordered_set seen_field_tags; + unordered_set seen_field_ids; + vector> seen_fields; + }; + + void WriteData(const_data_ptr_t buffer, idx_t write_size) { + stream.WriteData(buffer, write_size); + } + + template + void Write(T element) { + static_assert(std::is_trivially_destructible(), "Write element must be trivially destructible"); + WriteData(const_data_ptr_cast(&element), sizeof(T)); + } + void WriteData(const char *ptr, idx_t write_size) { + WriteData(const_data_ptr_cast(ptr), write_size); + } + + template + void VarIntEncode(T value) { + uint8_t buffer[16]; + auto write_size = EncodingUtil::EncodeLEB128(buffer, value); + D_ASSERT(write_size <= sizeof(buffer)); + WriteData(buffer, write_size); + } + +public: + template + static void Serialize(const T &value, WriteStream &stream, bool serialize_default_values = false) { + BinarySerializer serializer(stream, serialize_default_values); + serializer.OnObjectBegin(); + value.Serialize(serializer); + serializer.OnObjectEnd(); + } + + void Begin() { + OnObjectBegin(); + } + void End() { + OnObjectEnd(); + } + +protected: + //------------------------------------------------------------------------- + // Nested Type Hooks + //------------------------------------------------------------------------- + // We serialize optional values as a message with a "present" flag, followed by the value. + void OnPropertyBegin(const field_id_t field_id, const char *tag) final; + void OnPropertyEnd() final; + void OnOptionalPropertyBegin(const field_id_t field_id, const char *tag, bool present) final; + void OnOptionalPropertyEnd(bool present) final; + void OnListBegin(idx_t count) final; + void OnListEnd() final; + void OnObjectBegin() final; + void OnObjectEnd() final; + void OnNullableBegin(bool present) final; + void OnNullableEnd() final; + + //------------------------------------------------------------------------- + // Primitive Types + //------------------------------------------------------------------------- + void WriteNull() final; + void WriteValue(char value) final; + void WriteValue(uint8_t value) final; + void WriteValue(int8_t value) final; + void WriteValue(uint16_t value) final; + void WriteValue(int16_t value) final; + void WriteValue(uint32_t value) final; + void WriteValue(int32_t value) final; + void WriteValue(uint64_t value) final; + void WriteValue(int64_t value) final; + void WriteValue(hugeint_t value) final; + void WriteValue(float value) final; + void WriteValue(double value) final; + void WriteValue(const string_t value) final; + void WriteValue(const string &value) final; + void WriteValue(const char *value) final; + void WriteValue(bool value) final; + void WriteDataPtr(const_data_ptr_t ptr, idx_t count) final; + +private: + vector debug_stack; + WriteStream &stream; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/buffered_file_reader.hpp b/src/duckdb/src/include/duckdb/common/serializer/buffered_file_reader.hpp new file mode 100644 index 00000000..34d78e16 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/buffered_file_reader.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/buffered_file_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/serializer/read_stream.hpp" + +namespace duckdb { + +class BufferedFileReader : public ReadStream { +public: + BufferedFileReader(FileSystem &fs, const char *path, FileLockType lock_type = FileLockType::READ_LOCK, + optional_ptr opener = nullptr); + + FileSystem &fs; + unsafe_unique_array data; + idx_t offset; + idx_t read_data; + unique_ptr handle; + +public: + void ReadData(data_ptr_t buffer, uint64_t read_size) override; + //! Returns true if the reader has finished reading the entire file + bool Finished(); + + idx_t FileSize() { + return file_size; + } + + void Seek(uint64_t location); + uint64_t CurrentOffset(); + +private: + idx_t file_size; + idx_t total_read; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/buffered_file_writer.hpp b/src/duckdb/src/include/duckdb/common/serializer/buffered_file_writer.hpp new file mode 100644 index 00000000..8ad0ffda --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/buffered_file_writer.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/buffered_file_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/serializer/write_stream.hpp" +#include "duckdb/common/file_system.hpp" + +namespace duckdb { + +#define FILE_BUFFER_SIZE 4096 + +class BufferedFileWriter : public WriteStream { +public: + static constexpr uint8_t DEFAULT_OPEN_FLAGS = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE; + + //! Serializes to a buffer allocated by the serializer, will expand when + //! writing past the initial threshold + DUCKDB_API BufferedFileWriter(FileSystem &fs, const string &path, uint8_t open_flags = DEFAULT_OPEN_FLAGS); + + FileSystem &fs; + string path; + unsafe_unique_array data; + idx_t offset; + idx_t total_written; + unique_ptr handle; + +public: + DUCKDB_API void WriteData(const_data_ptr_t buffer, idx_t write_size) override; + //! Flush the buffer to disk and sync the file to ensure writing is completed + DUCKDB_API void Sync(); + //! Flush the buffer to the file (without sync) + DUCKDB_API void Flush(); + //! Returns the current size of the file + DUCKDB_API int64_t GetFileSize(); + //! Truncate the size to a previous size (given that size <= GetFileSize()) + DUCKDB_API void Truncate(int64_t size); + + DUCKDB_API idx_t GetTotalWritten(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/deserialization_data.hpp b/src/duckdb/src/include/duckdb/common/serializer/deserialization_data.hpp new file mode 100644 index 00000000..c936644e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/deserialization_data.hpp @@ -0,0 +1,181 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/deserialization_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/stack.hpp" +#include "duckdb/planner/bound_parameter_map.hpp" + +namespace duckdb { +class ClientContext; +class Catalog; +class DatabaseInstance; +enum class ExpressionType : uint8_t; + +struct DeserializationData { + stack> contexts; + stack> databases; + stack enums; + stack> parameter_data; + stack> types; + + template + void Set(T entry) = delete; + + template + T Get() = delete; + + template + void Unset() = delete; + + template + inline void AssertNotEmpty(const stack &e) { + if (e.empty()) { + throw InternalException("DeserializationData - unexpected empty stack"); + } + } +}; + +template <> +inline void DeserializationData::Set(ExpressionType type) { + enums.push(idx_t(type)); +} + +template <> +inline ExpressionType DeserializationData::Get() { + AssertNotEmpty(enums); + return ExpressionType(enums.top()); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(enums); + enums.pop(); +} + +template <> +inline void DeserializationData::Set(LogicalOperatorType type) { + enums.push(idx_t(type)); +} + +template <> +inline LogicalOperatorType DeserializationData::Get() { + AssertNotEmpty(enums); + return LogicalOperatorType(enums.top()); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(enums); + enums.pop(); +} + +template <> +inline void DeserializationData::Set(CompressionType type) { + enums.push(idx_t(type)); +} + +template <> +inline CompressionType DeserializationData::Get() { + AssertNotEmpty(enums); + return CompressionType(enums.top()); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(enums); + enums.pop(); +} + +template <> +inline void DeserializationData::Set(CatalogType type) { + enums.push(idx_t(type)); +} + +template <> +inline CatalogType DeserializationData::Get() { + AssertNotEmpty(enums); + return CatalogType(enums.top()); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(enums); + enums.pop(); +} + +template <> +inline void DeserializationData::Set(ClientContext &context) { + contexts.push(context); +} + +template <> +inline ClientContext &DeserializationData::Get() { + AssertNotEmpty(contexts); + return contexts.top(); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(contexts); + contexts.pop(); +} + +template <> +inline void DeserializationData::Set(DatabaseInstance &db) { + databases.push(db); +} + +template <> +inline DatabaseInstance &DeserializationData::Get() { + AssertNotEmpty(databases); + return databases.top(); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(databases); + databases.pop(); +} + +template <> +inline void DeserializationData::Set(bound_parameter_map_t &context) { + parameter_data.push(context); +} + +template <> +inline bound_parameter_map_t &DeserializationData::Get() { + AssertNotEmpty(parameter_data); + return parameter_data.top(); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(parameter_data); + parameter_data.pop(); +} + +template <> +inline void DeserializationData::Set(LogicalType &type) { + types.emplace(type); +} + +template <> +inline LogicalType &DeserializationData::Get() { + AssertNotEmpty(types); + return types.top(); +} + +template <> +inline void DeserializationData::Unset() { + AssertNotEmpty(types); + types.pop(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp b/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp new file mode 100644 index 00000000..3014c51e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp @@ -0,0 +1,465 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/format_serializer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/serializer/serialization_traits.hpp" +#include "duckdb/common/serializer/deserialization_data.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +class Deserializer { +protected: + bool deserialize_enum_from_string = false; + DeserializationData data; + +public: + virtual ~Deserializer() { + } + + class List { + friend Deserializer; + + private: + Deserializer &deserializer; + explicit List(Deserializer &deserializer) : deserializer(deserializer) { + } + + public: + // Deserialize an element + template + T ReadElement(); + + // Deserialize an object + template + void ReadObject(FUNC f); + }; + +public: + // Read into an existing value + template + inline void ReadProperty(const field_id_t field_id, const char *tag, T &ret) { + OnPropertyBegin(field_id, tag); + ret = Read(); + OnPropertyEnd(); + } + + // Read and return a value + template + inline T ReadProperty(const field_id_t field_id, const char *tag) { + OnPropertyBegin(field_id, tag); + auto ret = Read(); + OnPropertyEnd(); + return ret; + } + + // Default Value return + template + inline T ReadPropertyWithDefault(const field_id_t field_id, const char *tag) { + if (!OnOptionalPropertyBegin(field_id, tag)) { + OnOptionalPropertyEnd(false); + return std::forward(SerializationDefaultValue::GetDefault()); + } + auto ret = Read(); + OnOptionalPropertyEnd(true); + return ret; + } + + template + inline T ReadPropertyWithDefault(const field_id_t field_id, const char *tag, T &&default_value) { + if (!OnOptionalPropertyBegin(field_id, tag)) { + OnOptionalPropertyEnd(false); + return std::forward(default_value); + } + auto ret = Read(); + OnOptionalPropertyEnd(true); + return ret; + } + + // Default value in place + template + inline void ReadPropertyWithDefault(const field_id_t field_id, const char *tag, T &ret) { + if (!OnOptionalPropertyBegin(field_id, tag)) { + ret = std::forward(SerializationDefaultValue::GetDefault()); + OnOptionalPropertyEnd(false); + return; + } + ret = Read(); + OnOptionalPropertyEnd(true); + } + + template + inline void ReadPropertyWithDefault(const field_id_t field_id, const char *tag, T &ret, T &&default_value) { + if (!OnOptionalPropertyBegin(field_id, tag)) { + ret = std::forward(default_value); + OnOptionalPropertyEnd(false); + return; + } + ret = Read(); + OnOptionalPropertyEnd(true); + } + + // Special case: + // Read into an existing data_ptr_t + inline void ReadProperty(const field_id_t field_id, const char *tag, data_ptr_t ret, idx_t count) { + OnPropertyBegin(field_id, tag); + ReadDataPtr(ret, count); + OnPropertyEnd(); + } + + // Try to read a property, if it is not present, continue, otherwise read and discard the value + template + inline void ReadDeletedProperty(const field_id_t field_id, const char *tag) { + // Try to read the property. If not present, great! + if (!OnOptionalPropertyBegin(field_id, tag)) { + OnOptionalPropertyEnd(false); + return; + } + // Otherwise read and discard the value + (void)Read(); + OnOptionalPropertyEnd(true); + } + + //! Set a serialization property + template + void Set(T entry) { + return data.Set(entry); + } + + //! Retrieve the last set serialization property of this type + template + T Get() { + return data.Get(); + } + + //! Unset a serialization property + template + void Unset() { + return data.Unset(); + } + + template + void ReadList(const field_id_t field_id, const char *tag, FUNC func) { + OnPropertyBegin(field_id, tag); + auto size = OnListBegin(); + List list {*this}; + for (idx_t i = 0; i < size; i++) { + func(list, i); + } + OnListEnd(); + OnPropertyEnd(); + } + + template + void ReadObject(const field_id_t field_id, const char *tag, FUNC func) { + OnPropertyBegin(field_id, tag); + OnObjectBegin(); + func(*this); + OnObjectEnd(); + OnPropertyEnd(); + } + +private: + // Deserialize anything implementing a Deserialize method + template + inline typename std::enable_if::value, T>::type Read() { + OnObjectBegin(); + auto val = T::Deserialize(*this); + OnObjectEnd(); + return val; + } + + template + inline typename std::enable_if::value, T>::type Read() { + using ELEMENT_TYPE = typename is_unique_ptr::ELEMENT_TYPE; + unique_ptr ptr = nullptr; + auto is_present = OnNullableBegin(); + if (is_present) { + OnObjectBegin(); + ptr = ELEMENT_TYPE::Deserialize(*this); + OnObjectEnd(); + } + OnNullableEnd(); + return ptr; + } + + // Deserialize shared_ptr + template + inline typename std::enable_if::value, T>::type Read() { + using ELEMENT_TYPE = typename is_shared_ptr::ELEMENT_TYPE; + shared_ptr ptr = nullptr; + auto is_present = OnNullableBegin(); + if (is_present) { + OnObjectBegin(); + ptr = ELEMENT_TYPE::Deserialize(*this); + OnObjectEnd(); + } + OnNullableEnd(); + return ptr; + } + + // Deserialize a vector + template + inline typename std::enable_if::value, T>::type Read() { + using ELEMENT_TYPE = typename is_vector::ELEMENT_TYPE; + T vec; + auto size = OnListBegin(); + for (idx_t i = 0; i < size; i++) { + vec.push_back(Read()); + } + OnListEnd(); + return vec; + } + + template + inline typename std::enable_if::value, T>::type Read() { + using ELEMENT_TYPE = typename is_unsafe_vector::ELEMENT_TYPE; + T vec; + auto size = OnListBegin(); + for (idx_t i = 0; i < size; i++) { + vec.push_back(Read()); + } + OnListEnd(); + + return vec; + } + + // Deserialize a map + template + inline typename std::enable_if::value, T>::type Read() { + using KEY_TYPE = typename is_unordered_map::KEY_TYPE; + using VALUE_TYPE = typename is_unordered_map::VALUE_TYPE; + + T map; + auto size = OnListBegin(); + for (idx_t i = 0; i < size; i++) { + OnObjectBegin(); + auto key = ReadProperty(0, "key"); + auto value = ReadProperty(1, "value"); + OnObjectEnd(); + map[std::move(key)] = std::move(value); + } + OnListEnd(); + return map; + } + + template + inline typename std::enable_if::value, T>::type Read() { + using KEY_TYPE = typename is_map::KEY_TYPE; + using VALUE_TYPE = typename is_map::VALUE_TYPE; + + T map; + auto size = OnListBegin(); + for (idx_t i = 0; i < size; i++) { + OnObjectBegin(); + auto key = ReadProperty(0, "key"); + auto value = ReadProperty(1, "value"); + OnObjectEnd(); + map[std::move(key)] = std::move(value); + } + OnListEnd(); + return map; + } + + // Deserialize an unordered set + template + inline typename std::enable_if::value, T>::type Read() { + using ELEMENT_TYPE = typename is_unordered_set::ELEMENT_TYPE; + auto size = OnListBegin(); + T set; + for (idx_t i = 0; i < size; i++) { + set.insert(Read()); + } + OnListEnd(); + return set; + } + + // Deserialize a set + template + inline typename std::enable_if::value, T>::type Read() { + using ELEMENT_TYPE = typename is_set::ELEMENT_TYPE; + auto size = OnListBegin(); + T set; + for (idx_t i = 0; i < size; i++) { + set.insert(Read()); + } + OnListEnd(); + return set; + } + + // Deserialize a pair + template + inline typename std::enable_if::value, T>::type Read() { + using FIRST_TYPE = typename is_pair::FIRST_TYPE; + using SECOND_TYPE = typename is_pair::SECOND_TYPE; + OnObjectBegin(); + auto first = ReadProperty(0, "first"); + auto second = ReadProperty(1, "second"); + OnObjectEnd(); + return std::make_pair(first, second); + } + + // Primitive types + // Deserialize a bool + template + inline typename std::enable_if::value, T>::type Read() { + return ReadBool(); + } + + // Deserialize a char + template + inline typename std::enable_if::value, T>::type Read() { + return ReadChar(); + } + + // Deserialize a int8_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadSignedInt8(); + } + + // Deserialize a uint8_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadUnsignedInt8(); + } + + // Deserialize a int16_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadSignedInt16(); + } + + // Deserialize a uint16_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadUnsignedInt16(); + } + + // Deserialize a int32_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadSignedInt32(); + } + + // Deserialize a uint32_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadUnsignedInt32(); + } + + // Deserialize a int64_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadSignedInt64(); + } + + // Deserialize a uint64_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadUnsignedInt64(); + } + + // Deserialize a float + template + inline typename std::enable_if::value, T>::type Read() { + return ReadFloat(); + } + + // Deserialize a double + template + inline typename std::enable_if::value, T>::type Read() { + return ReadDouble(); + } + + // Deserialize a string + template + inline typename std::enable_if::value, T>::type Read() { + return ReadString(); + } + + // Deserialize a Enum + template + inline typename std::enable_if::value, T>::type Read() { + if (deserialize_enum_from_string) { + auto str = ReadString(); + return EnumUtil::FromString(str.c_str()); + } else { + return (T)Read::type>(); + } + } + + // Deserialize a hugeint_t + template + inline typename std::enable_if::value, T>::type Read() { + return ReadHugeInt(); + } + + // Deserialize a LogicalIndex + template + inline typename std::enable_if::value, T>::type Read() { + return LogicalIndex(ReadUnsignedInt64()); + } + + // Deserialize a PhysicalIndex + template + inline typename std::enable_if::value, T>::type Read() { + return PhysicalIndex(ReadUnsignedInt64()); + } + +protected: + // Hooks for subclasses to override to implement custom behavior + virtual void OnPropertyBegin(const field_id_t field_id, const char *tag) = 0; + virtual void OnPropertyEnd() = 0; + virtual bool OnOptionalPropertyBegin(const field_id_t field_id, const char *tag) = 0; + virtual void OnOptionalPropertyEnd(bool present) = 0; + + virtual void OnObjectBegin() = 0; + virtual void OnObjectEnd() = 0; + virtual idx_t OnListBegin() = 0; + virtual void OnListEnd() = 0; + virtual bool OnNullableBegin() = 0; + virtual void OnNullableEnd() = 0; + + // Handle primitive types, a serializer needs to implement these. + virtual bool ReadBool() = 0; + virtual char ReadChar() { + throw NotImplementedException("ReadChar not implemented"); + } + virtual int8_t ReadSignedInt8() = 0; + virtual uint8_t ReadUnsignedInt8() = 0; + virtual int16_t ReadSignedInt16() = 0; + virtual uint16_t ReadUnsignedInt16() = 0; + virtual int32_t ReadSignedInt32() = 0; + virtual uint32_t ReadUnsignedInt32() = 0; + virtual int64_t ReadSignedInt64() = 0; + virtual uint64_t ReadUnsignedInt64() = 0; + virtual hugeint_t ReadHugeInt() = 0; + virtual float ReadFloat() = 0; + virtual double ReadDouble() = 0; + virtual string ReadString() = 0; + virtual void ReadDataPtr(data_ptr_t &ptr, idx_t count) = 0; +}; + +template +void Deserializer::List::ReadObject(FUNC f) { + deserializer.OnObjectBegin(); + f(deserializer); + deserializer.OnObjectEnd(); +} + +template +T Deserializer::List::ReadElement() { + return deserializer.Read(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp b/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp new file mode 100644 index 00000000..f30cf579 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp @@ -0,0 +1,132 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/encoding_util.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/typedefs.hpp" +#include + +namespace duckdb { + +struct EncodingUtil { + + // Encode unsigned integer, returns the number of bytes written + template + static idx_t EncodeUnsignedLEB128(data_ptr_t target, T value) { + static_assert(std::is_integral::value, "Must be integral"); + static_assert(std::is_unsigned::value, "Must be unsigned"); + static_assert(sizeof(T) <= sizeof(uint64_t), "Must be uint64_t or smaller"); + + idx_t offset = 0; + do { + uint8_t byte = value & 0x7F; + value >>= 7; + if (value != 0) { + byte |= 0x80; + } + target[offset++] = byte; + } while (value != 0); + return offset; + } + + // Decode unsigned integer, returns the number of bytes read + template + static idx_t DecodeUnsignedLEB128(const_data_ptr_t source, T &result) { + static_assert(std::is_integral::value, "Must be integral"); + static_assert(std::is_unsigned::value, "Must be unsigned"); + static_assert(sizeof(T) <= sizeof(uint64_t), "Must be uint64_t or smaller"); + + result = 0; + idx_t shift = 0; + idx_t offset = 0; + uint8_t byte; + do { + byte = source[offset++]; + result |= static_cast(byte & 0x7F) << shift; + shift += 7; + } while (byte & 0x80); + + return offset; + } + + // Encode signed integer, returns the number of bytes written + template + static idx_t EncodeSignedLEB128(data_ptr_t target, T value) { + static_assert(std::is_integral::value, "Must be integral"); + static_assert(std::is_signed::value, "Must be signed"); + static_assert(sizeof(T) <= sizeof(int64_t), "Must be int64_t or smaller"); + + idx_t offset = 0; + do { + uint8_t byte = value & 0x7F; + value >>= 7; + + // Determine whether more bytes are needed + if ((value == 0 && (byte & 0x40) == 0) || (value == -1 && (byte & 0x40))) { + target[offset++] = byte; + break; + } else { + byte |= 0x80; + target[offset++] = byte; + } + } while (true); + return offset; + } + + // Decode signed integer, returns the number of bytes read + template + static idx_t DecodeSignedLEB128(const_data_ptr_t source, T &result) { + static_assert(std::is_integral::value, "Must be integral"); + static_assert(std::is_signed::value, "Must be signed"); + static_assert(sizeof(T) <= sizeof(int64_t), "Must be int64_t or smaller"); + + // This is used to avoid undefined behavior when shifting into the sign bit + using unsigned_type = typename std::make_unsigned::type; + + result = 0; + idx_t shift = 0; + idx_t offset = 0; + + uint8_t byte; + do { + byte = source[offset++]; + result |= static_cast(byte & 0x7F) << shift; + shift += 7; + } while (byte & 0x80); + + // Sign-extend if the most significant bit of the last byte is set + if (shift < sizeof(T) * 8 && (byte & 0x40)) { + result |= -(static_cast(1) << shift); + } + return offset; + } + + template + static typename std::enable_if::value, idx_t>::type DecodeLEB128(const_data_ptr_t source, + T &result) { + return DecodeSignedLEB128(source, result); + } + + template + static typename std::enable_if::value, idx_t>::type DecodeLEB128(const_data_ptr_t source, + T &result) { + return DecodeUnsignedLEB128(source, result); + } + + template + static typename std::enable_if::value, idx_t>::type EncodeLEB128(data_ptr_t target, T value) { + return EncodeSignedLEB128(target, value); + } + + template + static typename std::enable_if::value, idx_t>::type EncodeLEB128(data_ptr_t target, T value) { + return EncodeUnsignedLEB128(target, value); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/memory_stream.hpp b/src/duckdb/src/include/duckdb/common/serializer/memory_stream.hpp new file mode 100644 index 00000000..16ee46f2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/memory_stream.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/memory_stream.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/serializer/write_stream.hpp" +#include "duckdb/common/serializer/read_stream.hpp" +#include "duckdb/common/typedefs.hpp" + +namespace duckdb { + +class MemoryStream : public WriteStream, public ReadStream { +private: + idx_t position; + idx_t capacity; + bool owns_data; + data_ptr_t data; + +public: + // Create a new owning MemoryStream with an internal backing buffer with the specified capacity. The stream will + // own the backing buffer, resize it when needed and free its memory when the stream is destroyed + explicit MemoryStream(idx_t capacity = 512); + + // Create a new non-owning MemoryStream over the specified external buffer and capacity. The stream will not take + // ownership of the backing buffer, will not attempt to resize it and will not free the memory when the stream + // is destroyed + explicit MemoryStream(data_ptr_t buffer, idx_t capacity); + + ~MemoryStream() override; + + // Write data to the stream. + // Throws if the write would exceed the capacity of the stream and the backing buffer is not owned by the stream + void WriteData(const_data_ptr_t buffer, idx_t write_size) override; + + // Read data from the stream. + // Throws if the read would exceed the capacity of the stream + void ReadData(data_ptr_t buffer, idx_t read_size) override; + + // Rewind the stream to the start, keeping the capacity and the backing buffer intact + void Rewind(); + + // Release ownership of the backing buffer and turn a owning stream into a non-owning one. + // The stream will no longer be responsible for freeing the data. + // The stream will also no longer attempt to automatically resize the buffer when the capacity is reached. + void Release(); + + // Get a pointer to the underlying backing buffer + data_ptr_t GetData() const; + + // Get the current position in the stream + idx_t GetPosition() const; + + // Get the capacity of the stream + idx_t GetCapacity() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/read_stream.hpp b/src/duckdb/src/include/duckdb/common/serializer/read_stream.hpp new file mode 100644 index 00000000..18528d56 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/read_stream.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/read_stream.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector.hpp" +#include + +namespace duckdb { + +class ReadStream { +public: + // Reads a set amount of data from the stream into the specified buffer and moves the stream forward accordingly + virtual void ReadData(data_ptr_t buffer, idx_t read_size) = 0; + + // Reads a type from the stream and moves the stream forward sizeof(T) bytes + // The type must be a standard layout type + template + T Read() { + static_assert(std::is_standard_layout(), "Read element must be a standard layout data type"); + T value; + ReadData(data_ptr_cast(&value), sizeof(T)); + return value; + } + + virtual ~ReadStream() { + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp new file mode 100644 index 00000000..4b079b63 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp @@ -0,0 +1,261 @@ +#pragma once +#include +#include + +#include "duckdb/common/vector.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { + +class Serializer; // Forward declare +class Deserializer; // Forward declare + +typedef uint16_t field_id_t; +const field_id_t MESSAGE_TERMINATOR_FIELD_ID = 0xFFFF; + +// Backport to c++11 +template +using void_t = void; + +// Check for anything implementing a `void Serialize(Serializer &Serializer)` method +template +struct has_serialize : std::false_type {}; +template +struct has_serialize< + T, typename std::enable_if< + std::is_same().Serialize(std::declval())), void>::value, T>::type> + : std::true_type {}; + +template +struct has_deserialize : std::false_type {}; + +// Accept `static unique_ptr Deserialize(Deserializer& deserializer)` +template +struct has_deserialize< + T, typename std::enable_if(Deserializer &)>::value, T>::type> + : std::true_type {}; + +// Accept `static shared_ptr Deserialize(Deserializer& deserializer)` +template +struct has_deserialize< + T, typename std::enable_if(Deserializer &)>::value, T>::type> + : std::true_type {}; + +// Accept `static T Deserialize(Deserializer& deserializer)` +template +struct has_deserialize< + T, typename std::enable_if::value, T>::type> + : std::true_type {}; + +// Check if T is a vector, and provide access to the inner type +template +struct is_vector : std::false_type {}; +template +struct is_vector> : std::true_type { + typedef T ELEMENT_TYPE; +}; + +template +struct is_unsafe_vector : std::false_type {}; +template +struct is_unsafe_vector> : std::true_type { + typedef T ELEMENT_TYPE; +}; + +// Check if T is a unordered map, and provide access to the inner type +template +struct is_unordered_map : std::false_type {}; +template +struct is_unordered_map> : std::true_type { + typedef typename std::tuple_element<0, std::tuple>::type KEY_TYPE; + typedef typename std::tuple_element<1, std::tuple>::type VALUE_TYPE; + typedef typename std::tuple_element<2, std::tuple>::type HASH_TYPE; + typedef typename std::tuple_element<3, std::tuple>::type EQUAL_TYPE; +}; + +template +struct is_map : std::false_type {}; +template +struct is_map> : std::true_type { + typedef typename std::tuple_element<0, std::tuple>::type KEY_TYPE; + typedef typename std::tuple_element<1, std::tuple>::type VALUE_TYPE; + typedef typename std::tuple_element<2, std::tuple>::type HASH_TYPE; + typedef typename std::tuple_element<3, std::tuple>::type EQUAL_TYPE; +}; + +template +struct is_unique_ptr : std::false_type {}; +template +struct is_unique_ptr> : std::true_type { + typedef T ELEMENT_TYPE; +}; + +template +struct is_shared_ptr : std::false_type {}; +template +struct is_shared_ptr> : std::true_type { + typedef T ELEMENT_TYPE; +}; + +template +struct is_optional_ptr : std::false_type {}; +template +struct is_optional_ptr> : std::true_type { + typedef T ELEMENT_TYPE; +}; + +template +struct is_pair : std::false_type {}; +template +struct is_pair> : std::true_type { + typedef T FIRST_TYPE; + typedef U SECOND_TYPE; +}; + +template +struct is_unordered_set : std::false_type {}; +template +struct is_unordered_set> : std::true_type { + typedef typename std::tuple_element<0, std::tuple>::type ELEMENT_TYPE; + typedef typename std::tuple_element<1, std::tuple>::type HASH_TYPE; + typedef typename std::tuple_element<2, std::tuple>::type EQUAL_TYPE; +}; + +template +struct is_set : std::false_type {}; +template +struct is_set> : std::true_type { + typedef typename std::tuple_element<0, std::tuple>::type ELEMENT_TYPE; + typedef typename std::tuple_element<1, std::tuple>::type HASH_TYPE; + typedef typename std::tuple_element<2, std::tuple>::type EQUAL_TYPE; +}; + +template +struct is_atomic : std::false_type {}; + +template +struct is_atomic> : std::true_type { + typedef T TYPE; +}; + +struct SerializationDefaultValue { + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + using INNER = typename is_atomic::TYPE; + return static_cast(GetDefault()); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + using INNER = typename is_atomic::TYPE; + return value == GetDefault(); + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return static_cast(0); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value == static_cast(0); + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return !value; + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return !value; + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return !value; + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value.empty(); + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value.empty(); + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value.empty(); + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value.empty(); + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value.empty(); + } + + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value.empty(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/serializer.hpp b/src/duckdb/src/include/duckdb/common/serializer/serializer.hpp new file mode 100644 index 00000000..db9f2d32 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/serializer.hpp @@ -0,0 +1,307 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/serializer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/serializer/serialization_traits.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +class Serializer { +protected: + bool serialize_enum_as_string = false; + bool serialize_default_values = false; + +public: + virtual ~Serializer() { + } + + class List { + friend Serializer; + + private: + Serializer &serializer; + explicit List(Serializer &serializer) : serializer(serializer) { + } + + public: + // Serialize an element + template + void WriteElement(const T &value); + + // Serialize an object + template + void WriteObject(FUNC f); + }; + +public: + // Serialize a value + template + void WriteProperty(const field_id_t field_id, const char *tag, const T &value) { + OnPropertyBegin(field_id, tag); + WriteValue(value); + OnPropertyEnd(); + } + + // Default value + template + void WritePropertyWithDefault(const field_id_t field_id, const char *tag, const T &value) { + // If current value is default, don't write it + if (!serialize_default_values && SerializationDefaultValue::IsDefault(value)) { + OnOptionalPropertyBegin(field_id, tag, false); + OnOptionalPropertyEnd(false); + return; + } + OnOptionalPropertyBegin(field_id, tag, true); + WriteValue(value); + OnOptionalPropertyEnd(true); + } + + template + void WritePropertyWithDefault(const field_id_t field_id, const char *tag, const T &value, const T &&default_value) { + // If current value is default, don't write it + if (!serialize_default_values && (value == default_value)) { + OnOptionalPropertyBegin(field_id, tag, false); + OnOptionalPropertyEnd(false); + return; + } + OnOptionalPropertyBegin(field_id, tag, true); + WriteValue(value); + OnOptionalPropertyEnd(true); + } + + // Special case: data_ptr_T + void WriteProperty(const field_id_t field_id, const char *tag, const_data_ptr_t ptr, idx_t count) { + OnPropertyBegin(field_id, tag); + WriteDataPtr(ptr, count); + OnPropertyEnd(); + } + + // Manually begin an object + template + void WriteObject(const field_id_t field_id, const char *tag, FUNC f) { + OnPropertyBegin(field_id, tag); + OnObjectBegin(); + f(*this); + OnObjectEnd(); + OnPropertyEnd(); + } + + template + void WriteList(const field_id_t field_id, const char *tag, idx_t count, FUNC func) { + OnPropertyBegin(field_id, tag); + OnListBegin(count); + List list {*this}; + for (idx_t i = 0; i < count; i++) { + func(list, i); + } + OnListEnd(); + OnPropertyEnd(); + } + +protected: + template + typename std::enable_if::value, void>::type WriteValue(const T value) { + if (serialize_enum_as_string) { + // Use the enum serializer to lookup tostring function + auto str = EnumUtil::ToChars(value); + WriteValue(str); + } else { + // Use the underlying type + WriteValue(static_cast::type>(value)); + } + } + + // Unique Pointer Ref + template + void WriteValue(const unique_ptr &ptr) { + WriteValue(ptr.get()); + } + + // Shared Pointer Ref + template + void WriteValue(const shared_ptr &ptr) { + WriteValue(ptr.get()); + } + + // Pointer + template + void WriteValue(const T *ptr) { + if (ptr == nullptr) { + OnNullableBegin(false); + OnNullableEnd(); + } else { + OnNullableBegin(true); + WriteValue(*ptr); + OnNullableEnd(); + } + } + + // Pair + template + void WriteValue(const std::pair &pair) { + OnObjectBegin(); + WriteProperty(0, "first", pair.first); + WriteProperty(1, "second", pair.second); + OnObjectEnd(); + } + + // Reference Wrapper + template + void WriteValue(const reference ref) { + WriteValue(ref.get()); + } + + // Vector + template + void WriteValue(const vector &vec) { + auto count = vec.size(); + OnListBegin(count); + for (auto &item : vec) { + WriteValue(item); + } + OnListEnd(); + } + + template + void WriteValue(const unsafe_vector &vec) { + auto count = vec.size(); + OnListBegin(count); + for (auto &item : vec) { + WriteValue(item); + } + OnListEnd(); + } + + // UnorderedSet + // Serialized the same way as a list/vector + template + void WriteValue(const duckdb::unordered_set &set) { + auto count = set.size(); + OnListBegin(count); + for (auto &item : set) { + WriteValue(item); + } + OnListEnd(); + } + + // Set + // Serialized the same way as a list/vector + template + void WriteValue(const duckdb::set &set) { + auto count = set.size(); + OnListBegin(count); + for (auto &item : set) { + WriteValue(item); + } + OnListEnd(); + } + + // Map + // serialized as a list of pairs + template + void WriteValue(const duckdb::unordered_map &map) { + auto count = map.size(); + OnListBegin(count); + for (auto &item : map) { + OnObjectBegin(); + WriteProperty(0, "key", item.first); + WriteProperty(1, "value", item.second); + OnObjectEnd(); + } + OnListEnd(); + } + + // Map + // serialized as a list of pairs + template + void WriteValue(const duckdb::map &map) { + auto count = map.size(); + OnListBegin(count); + for (auto &item : map) { + OnObjectBegin(); + WriteProperty(0, "key", item.first); + WriteProperty(1, "value", item.second); + OnObjectEnd(); + } + OnListEnd(); + } + + // class or struct implementing `Serialize(Serializer& Serializer)`; + template + typename std::enable_if::value>::type WriteValue(const T &value) { + OnObjectBegin(); + value.Serialize(*this); + OnObjectEnd(); + } + +protected: + // Hooks for subclasses to override to implement custom behavior + virtual void OnPropertyBegin(const field_id_t field_id, const char *tag) = 0; + virtual void OnPropertyEnd() = 0; + virtual void OnOptionalPropertyBegin(const field_id_t field_id, const char *tag, bool present) = 0; + virtual void OnOptionalPropertyEnd(bool present) = 0; + virtual void OnObjectBegin() = 0; + virtual void OnObjectEnd() = 0; + virtual void OnListBegin(idx_t count) = 0; + virtual void OnListEnd() = 0; + virtual void OnNullableBegin(bool present) = 0; + virtual void OnNullableEnd() = 0; + + // Handle primitive types, a serializer needs to implement these. + virtual void WriteNull() = 0; + virtual void WriteValue(char value) { + throw NotImplementedException("Write char value not implemented"); + } + virtual void WriteValue(bool value) = 0; + virtual void WriteValue(uint8_t value) = 0; + virtual void WriteValue(int8_t value) = 0; + virtual void WriteValue(uint16_t value) = 0; + virtual void WriteValue(int16_t value) = 0; + virtual void WriteValue(uint32_t value) = 0; + virtual void WriteValue(int32_t value) = 0; + virtual void WriteValue(uint64_t value) = 0; + virtual void WriteValue(int64_t value) = 0; + virtual void WriteValue(hugeint_t value) = 0; + virtual void WriteValue(float value) = 0; + virtual void WriteValue(double value) = 0; + virtual void WriteValue(const string_t value) = 0; + virtual void WriteValue(const string &value) = 0; + virtual void WriteValue(const char *str) = 0; + virtual void WriteDataPtr(const_data_ptr_t ptr, idx_t count) = 0; + void WriteValue(LogicalIndex value) { + WriteValue(value.index); + } + void WriteValue(PhysicalIndex value) { + WriteValue(value.index); + } +}; + +// We need to special case vector because elements of vector cannot be referenced +template <> +void Serializer::WriteValue(const vector &vec); + +// List Impl +template +void Serializer::List::WriteObject(FUNC f) { + serializer.OnObjectBegin(); + f(serializer); + serializer.OnObjectEnd(); +} + +template +void Serializer::List::WriteElement(const T &value) { + serializer.WriteValue(value); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/write_stream.hpp b/src/duckdb/src/include/duckdb/common/serializer/write_stream.hpp new file mode 100644 index 00000000..667400ab --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/write_stream.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/write_stream.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector.hpp" +#include + +namespace duckdb { + +class WriteStream { +public: + // Writes a set amount of data from the specified buffer into the stream and moves the stream forward accordingly + virtual void WriteData(const_data_ptr_t buffer, idx_t write_size) = 0; + + // Writes a type into the stream and moves the stream forward sizeof(T) bytes + // The type must be a standard layout type + template + void Write(T element) { + static_assert(std::is_standard_layout(), "Write element must be a standard layout data type"); + WriteData(const_data_ptr_cast(&element), sizeof(T)); + } + + virtual ~WriteStream() { + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/set.hpp b/src/duckdb/src/include/duckdb/common/set.hpp new file mode 100644 index 00000000..aee25d93 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/set.hpp @@ -0,0 +1,16 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::multiset; +using std::set; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/shared_ptr.hpp b/src/duckdb/src/include/duckdb/common/shared_ptr.hpp new file mode 100644 index 00000000..4d97075e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/shared_ptr.hpp @@ -0,0 +1,19 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/shared_ptr.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { + +using std::make_shared; +using std::shared_ptr; +using std::weak_ptr; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/comparators.hpp b/src/duckdb/src/include/duckdb/common/sort/comparators.hpp new file mode 100644 index 00000000..c05c16aa --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sort/comparators.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sort/comparators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/row/row_layout.hpp" + +namespace duckdb { + +struct SortLayout; +struct SBScanState; + +using ValidityBytes = RowLayout::ValidityBytes; + +struct Comparators { +public: + //! Whether a tie between two blobs can be broken + static bool TieIsBreakable(const idx_t &col_idx, const data_ptr_t &row_ptr, const SortLayout &sort_layout); + //! Compares the tuples that a being read from in the 'left' and 'right blocks during merge sort + //! (only in case we cannot simply 'memcmp' - if there are blob columns) + static int CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, + const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort); + //! Compare two blob values + static int CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type); + +private: + //! Compares two blob values that were initially tied by their prefix + static int BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, + const SortLayout &sort_layout, const bool &external); + //! Compare two fixed-size values + template + static int TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr); + + //! Compare two values at the pointers (can be recursive if nested type) + static int CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid); + //! Compares two fixed-size values at the given pointers + template + static int TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr); + //! Compares two string values at the given pointers + static int CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid); + //! Compares two struct values at the given pointers (recursive) + static int CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, + const child_list_t &types, bool valid); + //! Compare two list values at the pointers (can be recursive if nested type) + static int CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid); + //! Compares a list of fixed-size values + template + static int TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const ValidityBytes &left_validity, + const ValidityBytes &right_validity, const idx_t &count); + + //! Unwizzles an offset into a pointer + static void UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); + //! Swizzles a pointer into an offset + static void SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp b/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp new file mode 100644 index 00000000..d779511c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp @@ -0,0 +1,708 @@ +/* +pdqsort.h - Pattern-defeating quicksort. + +Copyright (c) 2021 Orson Peters + +This software is provided 'as-is', without any express or implied warranty. In no event will the +authors be held liable for any damages arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, including commercial +applications, and to alter it and redistribute it freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must not claim that you wrote the + original software. If you use this software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + +2. Altered source versions must be plainly marked as such, and must not be misrepresented as + being the original software. + +3. This notice may not be removed or altered from any source distribution. +*/ + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/unique_ptr.hpp" + +#include +#include +#include +#include +#include + +namespace duckdb_pdqsort { + +using duckdb::idx_t; +using duckdb::data_t; +using duckdb::data_ptr_t; +using duckdb::unique_ptr; +using duckdb::unique_array; +using duckdb::unsafe_unique_array; +using duckdb::make_uniq_array; +using duckdb::make_unsafe_uniq_array; +using duckdb::FastMemcpy; +using duckdb::FastMemcmp; + +enum { + // Partitions below this size are sorted using insertion sort. + insertion_sort_threshold = 24, + + // Partitions above this size use Tukey's ninther to select the pivot. + ninther_threshold = 128, + + // When we detect an already sorted partition, attempt an insertion sort that allows this + // amount of element moves before giving up. + partial_insertion_sort_limit = 8, + + // Must be multiple of 8 due to loop unrolling, and < 256 to fit in unsigned char. + block_size = 64, + + // Cacheline size, assumes power of two. + cacheline_size = 64 + +}; + +// Returns floor(log2(n)), assumes n > 0. +template +inline int log2(T n) { + int log = 0; + while (n >>= 1) { + ++log; + } + return log; +} + +struct PDQConstants { + PDQConstants(idx_t entry_size, idx_t comp_offset, idx_t comp_size, data_ptr_t end) + : entry_size(entry_size), comp_offset(comp_offset), comp_size(comp_size), + tmp_buf_ptr(make_unsafe_uniq_array(entry_size)), tmp_buf(tmp_buf_ptr.get()), + iter_swap_buf_ptr(make_unsafe_uniq_array(entry_size)), iter_swap_buf(iter_swap_buf_ptr.get()), + swap_offsets_buf_ptr(make_unsafe_uniq_array(entry_size)), + swap_offsets_buf(swap_offsets_buf_ptr.get()), end(end) { + } + + const duckdb::idx_t entry_size; + const idx_t comp_offset; + const idx_t comp_size; + + unsafe_unique_array tmp_buf_ptr; + const data_ptr_t tmp_buf; + + unsafe_unique_array iter_swap_buf_ptr; + const data_ptr_t iter_swap_buf; + + unsafe_unique_array swap_offsets_buf_ptr; + const data_ptr_t swap_offsets_buf; + + const data_ptr_t end; +}; + +struct PDQIterator { + PDQIterator(data_ptr_t ptr, const idx_t &entry_size) : ptr(ptr), entry_size(entry_size) { + } + + inline PDQIterator(const PDQIterator &other) : ptr(other.ptr), entry_size(other.entry_size) { + } + + inline const data_ptr_t &operator*() const { + return ptr; + } + + inline PDQIterator &operator++() { + ptr += entry_size; + return *this; + } + + inline PDQIterator &operator--() { + ptr -= entry_size; + return *this; + } + + inline PDQIterator operator++(int) { + auto tmp = *this; + ptr += entry_size; + return tmp; + } + + inline PDQIterator operator--(int) { + auto tmp = *this; + ptr -= entry_size; + return tmp; + } + + inline PDQIterator operator+(const idx_t &i) const { + auto result = *this; + result.ptr += i * entry_size; + return result; + } + + inline PDQIterator operator-(const idx_t &i) const { + PDQIterator result = *this; + result.ptr -= i * entry_size; + return result; + } + + inline PDQIterator &operator=(const PDQIterator &other) { + D_ASSERT(entry_size == other.entry_size); + ptr = other.ptr; + return *this; + } + + inline friend idx_t operator-(const PDQIterator &lhs, const PDQIterator &rhs) { + D_ASSERT((*lhs - *rhs) % lhs.entry_size == 0); + D_ASSERT(*lhs - *rhs >= 0); + return (*lhs - *rhs) / lhs.entry_size; + } + + inline friend bool operator<(const PDQIterator &lhs, const PDQIterator &rhs) { + return *lhs < *rhs; + } + + inline friend bool operator>(const PDQIterator &lhs, const PDQIterator &rhs) { + return *lhs > *rhs; + } + + inline friend bool operator>=(const PDQIterator &lhs, const PDQIterator &rhs) { + return *lhs >= *rhs; + } + + inline friend bool operator<=(const PDQIterator &lhs, const PDQIterator &rhs) { + return *lhs <= *rhs; + } + + inline friend bool operator==(const PDQIterator &lhs, const PDQIterator &rhs) { + return *lhs == *rhs; + } + + inline friend bool operator!=(const PDQIterator &lhs, const PDQIterator &rhs) { + return *lhs != *rhs; + } + +private: + data_ptr_t ptr; + const idx_t &entry_size; +}; + +static inline bool comp(const data_ptr_t &l, const data_ptr_t &r, const PDQConstants &constants) { + D_ASSERT(l == constants.tmp_buf || l == constants.swap_offsets_buf || l < constants.end); + D_ASSERT(r == constants.tmp_buf || r == constants.swap_offsets_buf || r < constants.end); + return FastMemcmp(l + constants.comp_offset, r + constants.comp_offset, constants.comp_size) < 0; +} + +static inline const data_ptr_t &GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { + D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); + FastMemcpy(constants.tmp_buf, src, constants.entry_size); + return constants.tmp_buf; +} + +static inline const data_ptr_t &SWAP_OFFSETS_GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { + D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); + FastMemcpy(constants.swap_offsets_buf, src, constants.entry_size); + return constants.swap_offsets_buf; +} + +static inline void MOVE(const data_ptr_t &dest, const data_ptr_t &src, const PDQConstants &constants) { + D_ASSERT(dest == constants.tmp_buf || dest == constants.swap_offsets_buf || dest < constants.end); + D_ASSERT(src == constants.tmp_buf || src == constants.swap_offsets_buf || src < constants.end); + FastMemcpy(dest, src, constants.entry_size); +} + +static inline void iter_swap(const PDQIterator &lhs, const PDQIterator &rhs, const PDQConstants &constants) { + D_ASSERT(*lhs < constants.end); + D_ASSERT(*rhs < constants.end); + FastMemcpy(constants.iter_swap_buf, *lhs, constants.entry_size); + FastMemcpy(*lhs, *rhs, constants.entry_size); + FastMemcpy(*rhs, constants.iter_swap_buf, constants.entry_size); +} + +// Sorts [begin, end) using insertion sort with the given comparison function. +inline void insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { + if (begin == end) { + return; + } + + for (PDQIterator cur = begin + 1; cur != end; ++cur) { + PDQIterator sift = cur; + PDQIterator sift_1 = cur - 1; + + // Compare first so we can avoid 2 moves for an element already positioned correctly. + if (comp(*sift, *sift_1, constants)) { + const auto &tmp = GET_TMP(*sift, constants); + + do { + MOVE(*sift--, *sift_1, constants); + } while (sift != begin && comp(tmp, *--sift_1, constants)); + + MOVE(*sift, tmp, constants); + } + } +} + +// Sorts [begin, end) using insertion sort with the given comparison function. Assumes +// *(begin - 1) is an element smaller than or equal to any element in [begin, end). +inline void unguarded_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { + if (begin == end) { + return; + } + + for (PDQIterator cur = begin + 1; cur != end; ++cur) { + PDQIterator sift = cur; + PDQIterator sift_1 = cur - 1; + + // Compare first so we can avoid 2 moves for an element already positioned correctly. + if (comp(*sift, *sift_1, constants)) { + const auto &tmp = GET_TMP(*sift, constants); + + do { + MOVE(*sift--, *sift_1, constants); + } while (comp(tmp, *--sift_1, constants)); + + MOVE(*sift, tmp, constants); + } + } +} + +// Attempts to use insertion sort on [begin, end). Will return false if more than +// partial_insertion_sort_limit elements were moved, and abort sorting. Otherwise it will +// successfully sort and return true. +inline bool partial_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { + if (begin == end) { + return true; + } + + std::size_t limit = 0; + for (PDQIterator cur = begin + 1; cur != end; ++cur) { + PDQIterator sift = cur; + PDQIterator sift_1 = cur - 1; + + // Compare first so we can avoid 2 moves for an element already positioned correctly. + if (comp(*sift, *sift_1, constants)) { + const auto &tmp = GET_TMP(*sift, constants); + + do { + MOVE(*sift--, *sift_1, constants); + } while (sift != begin && comp(tmp, *--sift_1, constants)); + + MOVE(*sift, tmp, constants); + limit += cur - sift; + } + + if (limit > partial_insertion_sort_limit) { + return false; + } + } + + return true; +} + +inline void sort2(const PDQIterator &a, const PDQIterator &b, const PDQConstants &constants) { + if (comp(*b, *a, constants)) { + iter_swap(a, b, constants); + } +} + +// Sorts the elements *a, *b and *c using comparison function comp. +inline void sort3(const PDQIterator &a, const PDQIterator &b, const PDQIterator &c, const PDQConstants &constants) { + sort2(a, b, constants); + sort2(b, c, constants); + sort2(a, b, constants); +} + +template +inline T *align_cacheline(T *p) { +#if defined(UINTPTR_MAX) && __cplusplus >= 201103L + std::uintptr_t ip = reinterpret_cast(p); +#else + std::size_t ip = reinterpret_cast(p); +#endif + ip = (ip + cacheline_size - 1) & -cacheline_size; + return reinterpret_cast(ip); +} + +inline void swap_offsets(const PDQIterator &first, const PDQIterator &last, unsigned char *offsets_l, + unsigned char *offsets_r, size_t num, bool use_swaps, const PDQConstants &constants) { + if (use_swaps) { + // This case is needed for the descending distribution, where we need + // to have proper swapping for pdqsort to remain O(n). + for (size_t i = 0; i < num; ++i) { + iter_swap(first + offsets_l[i], last - offsets_r[i], constants); + } + } else if (num > 0) { + PDQIterator l = first + offsets_l[0]; + PDQIterator r = last - offsets_r[0]; + const auto &tmp = SWAP_OFFSETS_GET_TMP(*l, constants); + MOVE(*l, *r, constants); + for (size_t i = 1; i < num; ++i) { + l = first + offsets_l[i]; + MOVE(*r, *l, constants); + r = last - offsets_r[i]; + MOVE(*l, *r, constants); + } + MOVE(*r, tmp, constants); + } +} + +// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal +// to the pivot are put in the right-hand partition. Returns the position of the pivot after +// partitioning and whether the passed sequence already was correctly partitioned. Assumes the +// pivot is a median of at least 3 elements and that [begin, end) is at least +// insertion_sort_threshold long. Uses branchless partitioning. +inline std::pair partition_right_branchless(const PDQIterator &begin, const PDQIterator &end, + const PDQConstants &constants) { + // Move pivot into local for speed. + const auto &pivot = GET_TMP(*begin, constants); + PDQIterator first = begin; + PDQIterator last = end; + + // Find the first element greater than or equal than the pivot (the median of 3 guarantees + // this exists). + while (comp(*++first, pivot, constants)) { + } + + // Find the first element strictly smaller than the pivot. We have to guard this search if + // there was no element before *first. + if (first - 1 == begin) { + while (first < last && !comp(*--last, pivot, constants)) { + } + } else { + while (!comp(*--last, pivot, constants)) { + } + } + + // If the first pair of elements that should be swapped to partition are the same element, + // the passed in sequence already was correctly partitioned. + bool already_partitioned = first >= last; + if (!already_partitioned) { + iter_swap(first, last, constants); + ++first; + + // The following branchless partitioning is derived from "BlockQuicksort: How Branch + // Mispredictions don’t affect Quicksort" by Stefan Edelkamp and Armin Weiss, but + // heavily micro-optimized. + unsigned char offsets_l_storage[block_size + cacheline_size]; + unsigned char offsets_r_storage[block_size + cacheline_size]; + unsigned char *offsets_l = align_cacheline(offsets_l_storage); + unsigned char *offsets_r = align_cacheline(offsets_r_storage); + + PDQIterator offsets_l_base = first; + PDQIterator offsets_r_base = last; + size_t num_l, num_r, start_l, start_r; + num_l = num_r = start_l = start_r = 0; + + while (first < last) { + // Fill up offset blocks with elements that are on the wrong side. + // First we determine how much elements are considered for each offset block. + size_t num_unknown = last - first; + size_t left_split = num_l == 0 ? (num_r == 0 ? num_unknown / 2 : num_unknown) : 0; + size_t right_split = num_r == 0 ? (num_unknown - left_split) : 0; + + // Fill the offset blocks. + if (left_split >= block_size) { + for (size_t i = 0; i < block_size;) { + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + } + } else { + for (size_t i = 0; i < left_split;) { + offsets_l[num_l] = i++; + num_l += !comp(*first, pivot, constants); + ++first; + } + } + + if (right_split >= block_size) { + for (size_t i = 0; i < block_size;) { + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + } + } else { + for (size_t i = 0; i < right_split;) { + offsets_r[num_r] = ++i; + num_r += comp(*--last, pivot, constants); + } + } + + // Swap elements and update block sizes and first/last boundaries. + size_t num = std::min(num_l, num_r); + swap_offsets(offsets_l_base, offsets_r_base, offsets_l + start_l, offsets_r + start_r, num, num_l == num_r, + constants); + num_l -= num; + num_r -= num; + start_l += num; + start_r += num; + + if (num_l == 0) { + start_l = 0; + offsets_l_base = first; + } + + if (num_r == 0) { + start_r = 0; + offsets_r_base = last; + } + } + + // We have now fully identified [first, last)'s proper position. Swap the last elements. + if (num_l) { + offsets_l += start_l; + while (num_l--) { + iter_swap(offsets_l_base + offsets_l[num_l], --last, constants); + } + first = last; + } + if (num_r) { + offsets_r += start_r; + while (num_r--) { + iter_swap(offsets_r_base - offsets_r[num_r], first, constants), ++first; + } + last = first; + } + } + + // Put the pivot in the right place. + PDQIterator pivot_pos = first - 1; + MOVE(*begin, *pivot_pos, constants); + MOVE(*pivot_pos, pivot, constants); + + return std::make_pair(pivot_pos, already_partitioned); +} + +// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal +// to the pivot are put in the right-hand partition. Returns the position of the pivot after +// partitioning and whether the passed sequence already was correctly partitioned. Assumes the +// pivot is a median of at least 3 elements and that [begin, end) is at least +// insertion_sort_threshold long. +inline std::pair partition_right(const PDQIterator &begin, const PDQIterator &end, + const PDQConstants &constants) { + // Move pivot into local for speed. + const auto &pivot = GET_TMP(*begin, constants); + + PDQIterator first = begin; + PDQIterator last = end; + + // Find the first element greater than or equal than the pivot (the median of 3 guarantees + // this exists). + while (comp(*++first, pivot, constants)) { + } + + // Find the first element strictly smaller than the pivot. We have to guard this search if + // there was no element before *first. + if (first - 1 == begin) { + while (first < last && !comp(*--last, pivot, constants)) { + } + } else { + while (!comp(*--last, pivot, constants)) { + } + } + + // If the first pair of elements that should be swapped to partition are the same element, + // the passed in sequence already was correctly partitioned. + bool already_partitioned = first >= last; + + // Keep swapping pairs of elements that are on the wrong side of the pivot. Previously + // swapped pairs guard the searches, which is why the first iteration is special-cased + // above. + while (first < last) { + iter_swap(first, last, constants); + while (comp(*++first, pivot, constants)) { + } + while (!comp(*--last, pivot, constants)) { + } + } + + // Put the pivot in the right place. + PDQIterator pivot_pos = first - 1; + MOVE(*begin, *pivot_pos, constants); + MOVE(*pivot_pos, pivot, constants); + + return std::make_pair(pivot_pos, already_partitioned); +} + +// Similar function to the one above, except elements equal to the pivot are put to the left of +// the pivot and it doesn't check or return if the passed sequence already was partitioned. +// Since this is rarely used (the many equal case), and in that case pdqsort already has O(n) +// performance, no block quicksort is applied here for simplicity. +inline PDQIterator partition_left(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { + const auto &pivot = GET_TMP(*begin, constants); + PDQIterator first = begin; + PDQIterator last = end; + + while (comp(pivot, *--last, constants)) { + } + + if (last + 1 == end) { + while (first < last && !comp(pivot, *++first, constants)) { + } + } else { + while (!comp(pivot, *++first, constants)) { + } + } + + while (first < last) { + iter_swap(first, last, constants); + while (comp(pivot, *--last, constants)) { + } + while (!comp(pivot, *++first, constants)) { + } + } + + PDQIterator pivot_pos = last; + MOVE(*begin, *pivot_pos, constants); + MOVE(*pivot_pos, pivot, constants); + + return pivot_pos; +} + +template +inline void pdqsort_loop(PDQIterator begin, const PDQIterator &end, const PDQConstants &constants, int bad_allowed, + bool leftmost = true) { + // Use a while loop for tail recursion elimination. + while (true) { + idx_t size = end - begin; + + // Insertion sort is faster for small arrays. + if (size < insertion_sort_threshold) { + if (leftmost) { + insertion_sort(begin, end, constants); + } else { + unguarded_insertion_sort(begin, end, constants); + } + return; + } + + // Choose pivot as median of 3 or pseudomedian of 9. + idx_t s2 = size / 2; + if (size > ninther_threshold) { + sort3(begin, begin + s2, end - 1, constants); + sort3(begin + 1, begin + (s2 - 1), end - 2, constants); + sort3(begin + 2, begin + (s2 + 1), end - 3, constants); + sort3(begin + (s2 - 1), begin + s2, begin + (s2 + 1), constants); + iter_swap(begin, begin + s2, constants); + } else { + sort3(begin + s2, begin, end - 1, constants); + } + + // If *(begin - 1) is the end of the right partition of a previous partition operation + // there is no element in [begin, end) that is smaller than *(begin - 1). Then if our + // pivot compares equal to *(begin - 1) we change strategy, putting equal elements in + // the left partition, greater elements in the right partition. We do not have to + // recurse on the left partition, since it's sorted (all equal). + if (!leftmost && !comp(*(begin - 1), *begin, constants)) { + begin = partition_left(begin, end, constants) + 1; + continue; + } + + // Partition and get results. + std::pair part_result = + Branchless ? partition_right_branchless(begin, end, constants) : partition_right(begin, end, constants); + PDQIterator pivot_pos = part_result.first; + bool already_partitioned = part_result.second; + + // Check for a highly unbalanced partition. + idx_t l_size = pivot_pos - begin; + idx_t r_size = end - (pivot_pos + 1); + bool highly_unbalanced = l_size < size / 8 || r_size < size / 8; + + // If we got a highly unbalanced partition we shuffle elements to break many patterns. + if (highly_unbalanced) { + // If we had too many bad partitions, switch to heapsort to guarantee O(n log n). + // if (--bad_allowed == 0) { + // std::make_heap(begin, end, comp); + // std::sort_heap(begin, end, comp); + // return; + // } + + if (l_size >= insertion_sort_threshold) { + iter_swap(begin, begin + l_size / 4, constants); + iter_swap(pivot_pos - 1, pivot_pos - l_size / 4, constants); + + if (l_size > ninther_threshold) { + iter_swap(begin + 1, begin + (l_size / 4 + 1), constants); + iter_swap(begin + 2, begin + (l_size / 4 + 2), constants); + iter_swap(pivot_pos - 2, pivot_pos - (l_size / 4 + 1), constants); + iter_swap(pivot_pos - 3, pivot_pos - (l_size / 4 + 2), constants); + } + } + + if (r_size >= insertion_sort_threshold) { + iter_swap(pivot_pos + 1, pivot_pos + (1 + r_size / 4), constants); + iter_swap(end - 1, end - r_size / 4, constants); + + if (r_size > ninther_threshold) { + iter_swap(pivot_pos + 2, pivot_pos + (2 + r_size / 4), constants); + iter_swap(pivot_pos + 3, pivot_pos + (3 + r_size / 4), constants); + iter_swap(end - 2, end - (1 + r_size / 4), constants); + iter_swap(end - 3, end - (2 + r_size / 4), constants); + } + } + } else { + // If we were decently balanced and we tried to sort an already partitioned + // sequence try to use insertion sort. + if (already_partitioned && partial_insertion_sort(begin, pivot_pos, constants) && + partial_insertion_sort(pivot_pos + 1, end, constants)) { + return; + } + } + + // Sort the left partition first using recursion and do tail recursion elimination for + // the right-hand partition. + pdqsort_loop(begin, pivot_pos, constants, bad_allowed, leftmost); + begin = pivot_pos + 1; + leftmost = false; + } +} + +inline void pdqsort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { + if (begin == end) { + return; + } + pdqsort_loop(begin, end, constants, log2(end - begin)); +} + +inline void pdqsort_branchless(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { + if (begin == end) { + return; + } + pdqsort_loop(begin, end, constants, log2(end - begin)); +} + +} // namespace duckdb_pdqsort diff --git a/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp b/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp new file mode 100644 index 00000000..cb497569 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp @@ -0,0 +1,229 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sort/partition_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/types/column/partitioned_column_data.hpp" +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" + +namespace duckdb { + +class PartitionGlobalHashGroup { +public: + using GlobalSortStatePtr = unique_ptr; + using Orders = vector; + using Types = vector; + + PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions, const Orders &orders, + const Types &payload_types, bool external); + + int ComparePartitions(const SBIterator &left, const SBIterator &right) const; + + void ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask); + + GlobalSortStatePtr global_sort; + atomic count; + idx_t batch_base; + + // Mask computation + SortLayout partition_layout; +}; + +class PartitionGlobalSinkState { +public: + using HashGroupPtr = unique_ptr; + using Orders = vector; + using Types = vector; + + using GroupingPartition = unique_ptr; + using GroupingAppend = unique_ptr; + + static void GenerateOrderings(Orders &partitions, Orders &orders, + const vector> &partition_bys, const Orders &order_bys, + const vector> &partitions_stats); + + PartitionGlobalSinkState(ClientContext &context, const vector> &partition_bys, + const vector &order_bys, const Types &payload_types, + const vector> &partitions_stats, idx_t estimated_cardinality); + + bool HasMergeTasks() const; + + unique_ptr CreatePartition(idx_t new_bits) const; + void SyncPartitioning(const PartitionGlobalSinkState &other); + + void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); + void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); + + ClientContext &context; + BufferManager &buffer_manager; + Allocator &allocator; + mutex lock; + + // OVER(PARTITION BY...) (hash grouping) + unique_ptr grouping_data; + //! Payload plus hash column + TupleDataLayout grouping_types; + //! The number of radix bits if this partition is being synced with another + idx_t fixed_bits; + + // OVER(...) (sorting) + Orders partitions; + Orders orders; + const Types payload_types; + vector hash_groups; + bool external; + // Reverse lookup from hash bins to non-empty hash groups + vector bin_groups; + + // OVER() (no sorting) + unique_ptr rows; + unique_ptr strings; + + // Threading + idx_t memory_per_thread; + idx_t max_bits; + atomic count; + +private: + void ResizeGroupingData(idx_t cardinality); + void SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); +}; + +class PartitionLocalSinkState { +public: + using LocalSortStatePtr = unique_ptr; + + PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p); + + // Global state + PartitionGlobalSinkState &gstate; + Allocator &allocator; + + // Shared expression evaluation + ExpressionExecutor executor; + DataChunk group_chunk; + DataChunk payload_chunk; + size_t sort_cols; + + // OVER(PARTITION BY...) (hash grouping) + unique_ptr local_partition; + unique_ptr local_append; + + // OVER(ORDER BY...) (only sorting) + LocalSortStatePtr local_sort; + + // OVER() (no sorting) + RowLayout payload_layout; + unique_ptr rows; + unique_ptr strings; + + //! Compute the hash values + void Hash(DataChunk &input_chunk, Vector &hash_vector); + //! Sink an input chunk + void Sink(DataChunk &input_chunk); + //! Merge the state into the global state. + void Combine(); +}; + +enum class PartitionSortStage : uint8_t { INIT, SCAN, PREPARE, MERGE, SORTED }; + +class PartitionLocalMergeState; + +class PartitionGlobalMergeState { +public: + using GroupDataPtr = unique_ptr; + + // OVER(PARTITION BY...) + PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data, hash_t hash_bin); + + // OVER(ORDER BY...) + explicit PartitionGlobalMergeState(PartitionGlobalSinkState &sink); + + bool IsSorted() const { + lock_guard guard(lock); + return stage == PartitionSortStage::SORTED; + } + + bool AssignTask(PartitionLocalMergeState &local_state); + bool TryPrepareNextStage(); + void CompleteTask(); + + PartitionGlobalSinkState &sink; + GroupDataPtr group_data; + PartitionGlobalHashGroup *hash_group; + vector column_ids; + TupleDataParallelScanState chunk_state; + GlobalSortState *global_sort; + const idx_t memory_per_thread; + const idx_t num_threads; + +private: + mutable mutex lock; + PartitionSortStage stage; + idx_t total_tasks; + idx_t tasks_assigned; + idx_t tasks_completed; +}; + +class PartitionLocalMergeState { +public: + explicit PartitionLocalMergeState(PartitionGlobalSinkState &gstate); + + bool TaskFinished() { + return finished; + } + + void Prepare(); + void Scan(); + void Merge(); + + void ExecuteTask(); + + PartitionGlobalMergeState *merge_state; + PartitionSortStage stage; + atomic finished; + + // Sorting buffers + ExpressionExecutor executor; + DataChunk sort_chunk; + DataChunk payload_chunk; +}; + +class PartitionGlobalMergeStates { +public: + struct Callback { + virtual bool HasError() const { + return false; + } + }; + + using PartitionGlobalMergeStatePtr = unique_ptr; + + explicit PartitionGlobalMergeStates(PartitionGlobalSinkState &sink); + + bool ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback); + + vector states; +}; + +class PartitionMergeEvent : public BasePipelineEvent { +public: + PartitionMergeEvent(PartitionGlobalSinkState &gstate_p, Pipeline &pipeline_p) + : BasePipelineEvent(pipeline_p), gstate(gstate_p), merge_states(gstate_p) { + } + + PartitionGlobalSinkState &gstate; + PartitionGlobalMergeStates merge_states; + +public: + void Schedule() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/sort.hpp b/src/duckdb/src/include/duckdb/common/sort/sort.hpp new file mode 100644 index 00000000..e8478d90 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sort/sort.hpp @@ -0,0 +1,207 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sort/sort.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/sort/sorted_block.hpp" +#include "duckdb/common/types/row/row_data_collection.hpp" +#include "duckdb/planner/bound_query_node.hpp" + +namespace duckdb { + +class RowLayout; +struct LocalSortState; + +struct SortConstants { + static constexpr idx_t VALUES_PER_RADIX = 256; + static constexpr idx_t MSD_RADIX_LOCATIONS = VALUES_PER_RADIX + 1; + static constexpr idx_t INSERTION_SORT_THRESHOLD = 24; + static constexpr idx_t MSD_RADIX_SORT_SIZE_THRESHOLD = 4; +}; + +struct SortLayout { +public: + SortLayout() { + } + explicit SortLayout(const vector &orders); + SortLayout GetPrefixComparisonLayout(idx_t num_prefix_cols) const; + +public: + idx_t column_count; + vector order_types; + vector order_by_null_types; + vector logical_types; + + bool all_constant; + vector constant_size; + vector column_sizes; + vector prefix_lengths; + vector stats; + vector has_null; + + idx_t comparison_size; + idx_t entry_size; + + RowLayout blob_layout; + unordered_map sorting_to_blob_col; +}; + +struct GlobalSortState { +public: + GlobalSortState(BufferManager &buffer_manager, const vector &orders, RowLayout &payload_layout); + + //! Add local state sorted data to this global state + void AddLocalState(LocalSortState &local_sort_state); + //! Prepares the GlobalSortState for the merge sort phase (after completing radix sort phase) + void PrepareMergePhase(); + //! Initializes the global sort state for another round of merging + void InitializeMergeRound(); + //! Completes the cascaded merge sort round. + //! Pass true if you wish to use the radix data for further comparisons. + void CompleteMergeRound(bool keep_radix_data = false); + //! Print the sorted data to the console. + void Print(); + +public: + //! The lock for updating the order global state + mutex lock; + //! The buffer manager + BufferManager &buffer_manager; + + //! Sorting and payload layouts + const SortLayout sort_layout; + const RowLayout payload_layout; + + //! Sorted data + vector> sorted_blocks; + vector>> sorted_blocks_temp; + unique_ptr odd_one_out; + + //! Pinned heap data (if sorting in memory) + vector> heap_blocks; + vector pinned_blocks; + + //! Capacity (number of rows) used to initialize blocks + idx_t block_capacity; + //! Whether we are doing an external sort + bool external; + + //! Progress in merge path stage + idx_t pair_idx; + idx_t num_pairs; + idx_t l_start; + idx_t r_start; +}; + +struct LocalSortState { +public: + LocalSortState(); + + //! Initialize the layouts and RowDataCollections + void Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p); + //! Sink one DataChunk into the local sort state + void SinkChunk(DataChunk &sort, DataChunk &payload); + //! Size of accumulated data in bytes + idx_t SizeInBytes() const; + //! Sort the data accumulated so far + void Sort(GlobalSortState &global_sort_state, bool reorder_heap); + //! Concatenate the blocks held by a RowDataCollection into a single block + static unique_ptr ConcatenateBlocks(RowDataCollection &row_data); + +private: + //! Sorts the data in the newly created SortedBlock + void SortInMemory(); + //! Re-order the local state after sorting + void ReOrder(GlobalSortState &gstate, bool reorder_heap); + //! Re-order a SortedData object after sorting + void ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, + bool reorder_heap); + +public: + //! Whether this local state has been initialized + bool initialized; + //! The buffer manager + BufferManager *buffer_manager; + //! The sorting and payload layouts + const SortLayout *sort_layout; + const RowLayout *payload_layout; + //! Radix/memcmp sortable data + unique_ptr radix_sorting_data; + //! Variable sized sorting data and accompanying heap + unique_ptr blob_sorting_data; + unique_ptr blob_sorting_heap; + //! Payload data and accompanying heap + unique_ptr payload_data; + unique_ptr payload_heap; + //! Sorted data + vector> sorted_blocks; + +private: + //! Selection vector and addresses for scattering the data to rows + const SelectionVector &sel_ptr = *FlatVector::IncrementalSelectionVector(); + Vector addresses = Vector(LogicalType::POINTER); +}; + +struct MergeSorter { +public: + MergeSorter(GlobalSortState &state, BufferManager &buffer_manager); + + //! Finds and merges partitions until the current cascaded merge round is finished + void PerformInMergeRound(); + +private: + //! The global sorting state + GlobalSortState &state; + //! The sorting and payload layouts + BufferManager &buffer_manager; + const SortLayout &sort_layout; + + //! The left and right reader + unique_ptr left; + unique_ptr right; + + //! Input and output blocks + unique_ptr left_input; + unique_ptr right_input; + SortedBlock *result; + +private: + //! Computes the left and right block that will be merged next (Merge Path partition) + void GetNextPartition(); + //! Finds the boundary of the next partition using binary search + void GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx); + //! Compare values within SortedBlocks using a global index + int CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx); + + //! Finds the next partition and merges it + void MergePartition(); + + //! Computes how the next 'count' tuples should be merged by setting the 'left_smaller' array + void ComputeMerge(const idx_t &count, bool left_smaller[]); + + //! Merges the radix sorting blocks according to the 'left_smaller' array + void MergeRadix(const idx_t &count, const bool left_smaller[]); + //! Merges SortedData according to the 'left_smaller' array + void MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, + const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices); + //! Merges constant size rows according to the 'left_smaller' array + void MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, idx_t &r_entry_idx, + const idx_t &r_count, RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, + const bool left_smaller[], idx_t &copied, const idx_t &count); + //! Flushes constant size rows into the result + void FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, + RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, + const idx_t &count); + //! Flushes blob rows and accompanying heap + void FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, + idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, + data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, BufferHandle &target_heap_handle, + data_ptr_t &target_heap_ptr, idx_t &copied, const idx_t &count); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp b/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp new file mode 100644 index 00000000..bee58bca --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp @@ -0,0 +1,243 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sort/sorted_block.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/sort/comparators.hpp" +#include "duckdb/common/types/row/row_data_collection_scanner.hpp" +#include "duckdb/common/types/row/row_layout.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" + +namespace duckdb { + +class BufferManager; +struct RowDataBlock; +struct SortLayout; +struct GlobalSortState; + +enum class SortedDataType { BLOB, PAYLOAD }; + +//! Object that holds sorted rows, and an accompanying heap if there are blobs +struct SortedData { +public: + SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, GlobalSortState &state); + //! Number of rows that this object holds + idx_t Count(); + //! Initialize new block to write to + void CreateBlock(); + //! Create a slice that holds the rows between the start and end indices + unique_ptr CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index); + //! Unswizzles all + void Unswizzle(); + +public: + const SortedDataType type; + //! Layout of this data + const RowLayout layout; + //! Data and heap blocks + vector> data_blocks; + vector> heap_blocks; + //! Whether the pointers in this sorted data are swizzled + bool swizzled; + +private: + //! The buffer manager + BufferManager &buffer_manager; + //! The global state + GlobalSortState &state; +}; + +//! Block that holds sorted rows: radix, blob and payload data +struct SortedBlock { +public: + SortedBlock(BufferManager &buffer_manager, GlobalSortState &gstate); + //! Number of rows that this object holds + idx_t Count() const; + //! Initialize this block to write data to + void InitializeWrite(); + //! Init new block to write to + void CreateBlock(); + //! Fill this sorted block by appending the blocks held by a vector of sorted blocks + void AppendSortedBlocks(vector> &sorted_blocks); + //! Locate the block and entry index of a row in this block, + //! given an index between 0 and the total number of rows in this block + void GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index); + //! Create a slice that holds the rows between the start and end indices + unique_ptr CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx); + + //! Size (in bytes) of the heap of this block + idx_t HeapSize() const; + //! Total size (in bytes) of this block + idx_t SizeInBytes() const; + +public: + //! Radix/memcmp sortable data + vector> radix_sorting_data; + //! Variable sized sorting data + unique_ptr blob_sorting_data; + //! Payload data + unique_ptr payload_data; + +private: + //! Buffer manager, global state, and sorting layout constants + BufferManager &buffer_manager; + GlobalSortState &state; + const SortLayout &sort_layout; + const RowLayout &payload_layout; +}; + +//! State used to scan a SortedBlock e.g. during merge sort +struct SBScanState { +public: + SBScanState(BufferManager &buffer_manager, GlobalSortState &state); + + void PinRadix(idx_t block_idx_to); + void PinData(SortedData &sd); + + data_ptr_t RadixPtr() const; + data_ptr_t DataPtr(SortedData &sd) const; + data_ptr_t HeapPtr(SortedData &sd) const; + data_ptr_t BaseHeapPtr(SortedData &sd) const; + + idx_t Remaining() const; + + void SetIndices(idx_t block_idx_to, idx_t entry_idx_to); + +public: + BufferManager &buffer_manager; + const SortLayout &sort_layout; + GlobalSortState &state; + + SortedBlock *sb; + + idx_t block_idx; + idx_t entry_idx; + + BufferHandle radix_handle; + + BufferHandle blob_sorting_data_handle; + BufferHandle blob_sorting_heap_handle; + + BufferHandle payload_data_handle; + BufferHandle payload_heap_handle; +}; + +//! Used to scan the data into DataChunks after sorting +struct PayloadScanner { +public: + PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush = true); + explicit PayloadScanner(GlobalSortState &global_sort_state, bool flush = true); + + //! Scan a single block + PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush = false); + + //! The type layout of the payload + inline const vector &GetPayloadTypes() const { + return scanner->GetTypes(); + } + + //! The number of rows scanned so far + inline idx_t Scanned() const { + return scanner->Scanned(); + } + + //! The number of remaining rows + inline idx_t Remaining() const { + return scanner->Remaining(); + } + + //! Scans the next data chunk from the sorted data + void Scan(DataChunk &chunk); + +private: + //! The sorted data being scanned + unique_ptr rows; + unique_ptr heap; + //! The actual scanner + unique_ptr scanner; +}; + +struct SBIterator { + static int ComparisonValue(ExpressionType comparison); + + SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p = 0); + + inline idx_t GetIndex() const { + return entry_idx; + } + + inline void SetIndex(idx_t entry_idx_p) { + const auto new_block_idx = entry_idx_p / block_capacity; + if (new_block_idx != scan.block_idx) { + scan.SetIndices(new_block_idx, 0); + if (new_block_idx < block_count) { + scan.PinRadix(scan.block_idx); + block_ptr = scan.RadixPtr(); + if (!all_constant) { + scan.PinData(*scan.sb->blob_sorting_data); + } + } + } + + scan.entry_idx = entry_idx_p % block_capacity; + entry_ptr = block_ptr + scan.entry_idx * entry_size; + entry_idx = entry_idx_p; + } + + inline SBIterator &operator++() { + if (++scan.entry_idx < block_capacity) { + entry_ptr += entry_size; + ++entry_idx; + } else { + SetIndex(entry_idx + 1); + } + + return *this; + } + + inline SBIterator &operator--() { + if (scan.entry_idx) { + --scan.entry_idx; + --entry_idx; + entry_ptr -= entry_size; + } else { + SetIndex(entry_idx - 1); + } + + return *this; + } + + inline bool Compare(const SBIterator &other) const { + int comp_res; + if (all_constant) { + comp_res = FastMemcmp(entry_ptr, other.entry_ptr, cmp_size); + } else { + comp_res = Comparators::CompareTuple(scan, other.scan, entry_ptr, other.entry_ptr, sort_layout, external); + } + + return comp_res <= cmp; + } + + // Fixed comparison parameters + const SortLayout &sort_layout; + const idx_t block_count; + const idx_t block_capacity; + const size_t cmp_size; + const size_t entry_size; + const bool all_constant; + const bool external; + const int cmp; + + // Iteration state + SBScanState scan; + idx_t entry_idx; + data_ptr_t block_ptr; + data_ptr_t entry_ptr; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/stack.hpp b/src/duckdb/src/include/duckdb/common/stack.hpp new file mode 100644 index 00000000..18e27080 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/stack.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/stack.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::stack; +} diff --git a/src/duckdb/src/include/duckdb/common/stack_checker.hpp b/src/duckdb/src/include/duckdb/common/stack_checker.hpp new file mode 100644 index 00000000..a2375e8e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/stack_checker.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/stack_checker.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +template +class StackChecker { +public: + StackChecker(RECURSIVE_CLASS &recursive_class_p, idx_t stack_usage_p) + : recursive_class(recursive_class_p), stack_usage(stack_usage_p) { + recursive_class.stack_depth += stack_usage; + } + ~StackChecker() { + recursive_class.stack_depth -= stack_usage; + } + StackChecker(StackChecker &&other) noexcept + : recursive_class(other.recursive_class), stack_usage(other.stack_usage) { + other.stack_usage = 0; + } + StackChecker(const StackChecker &) = delete; + +private: + RECURSIVE_CLASS &recursive_class; + idx_t stack_usage; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/string.hpp b/src/duckdb/src/include/duckdb/common/string.hpp new file mode 100644 index 00000000..ad717374 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/string.hpp @@ -0,0 +1,16 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/string.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace duckdb { +using std::string; +} diff --git a/src/duckdb/src/include/duckdb/common/string_map_set.hpp b/src/duckdb/src/include/duckdb/common/string_map_set.hpp new file mode 100644 index 00000000..00600c42 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/string_map_set.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/string_map_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +struct StringHash { + std::size_t operator()(const string_t &k) const { + return Hash(k); + } +}; + +struct StringEquality { + bool operator()(const string_t &a, const string_t &b) const { + return Equals::Operation(a, b); + } +}; + +template +using string_map_t = unordered_map; + +using string_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/string_util.hpp b/src/duckdb/src/include/duckdb/common/string_util.hpp new file mode 100644 index 00000000..7074387a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/string_util.hpp @@ -0,0 +1,245 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/string_util.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector.hpp" + +#include + +namespace duckdb { + +#ifndef DUCKDB_QUOTE_DEFINE +// Preprocessor trick to allow text to be converted to C-string / string +// Expecte use is: +// #ifdef SOME_DEFINE +// string str = DUCKDB_QUOTE_DEFINE(SOME_DEFINE) +// ...do something with str +// #endif SOME_DEFINE +#define DUCKDB_QUOTE_DEFINE_IMPL(x) #x +#define DUCKDB_QUOTE_DEFINE(x) DUCKDB_QUOTE_DEFINE_IMPL(x) +#endif + +/** + * String Utility Functions + * Note that these are not the most efficient implementations (i.e., they copy + * memory) and therefore they should only be used for debug messages and other + * such things. + */ +class StringUtil { +public: + static string GenerateRandomName(idx_t length = 16); + + static uint8_t GetHexValue(char c) { + if (c >= '0' && c <= '9') { + return c - '0'; + } + if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } + if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } + throw InvalidInputException("Invalid input for hex digit: %s", string(c, 1)); + } + + static uint8_t GetBinaryValue(char c) { + if (c >= '0' && c <= '1') { + return c - '0'; + } + throw InvalidInputException("Invalid input for binary digit: %s", string(c, 1)); + } + + static bool CharacterIsSpace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; + } + static bool CharacterIsNewline(char c) { + return c == '\n' || c == '\r'; + } + static bool CharacterIsDigit(char c) { + return c >= '0' && c <= '9'; + } + static bool CharacterIsHex(char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); + } + static char CharacterToLower(char c) { + if (c >= 'A' && c <= 'Z') { + return c - ('A' - 'a'); + } + return c; + } + static char CharacterIsAlpha(char c) { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); + } + static bool CharacterIsOperator(char c) { + if (c == '_') { + return false; + } + if (c >= '!' && c <= '/') { + return true; + } + if (c >= ':' && c <= '@') { + return true; + } + if (c >= '[' && c <= '`') { + return true; + } + if (c >= '{' && c <= '~') { + return true; + } + return false; + } + + template + static vector ConvertStrings(const vector &strings) { + vector result; + for (auto &string : strings) { + result.emplace_back(string); + } + return result; + } + + static vector ConvertToSQLIdentifiers(const vector &strings) { + return ConvertStrings(strings); + } + + static vector ConvertToSQLStrings(const vector &strings) { + return ConvertStrings(strings); + } + + //! Returns true if the needle string exists in the haystack + DUCKDB_API static bool Contains(const string &haystack, const string &needle); + + //! Returns true if the target string starts with the given prefix + DUCKDB_API static bool StartsWith(string str, string prefix); + + //! Returns true if the target string ends with the given suffix. + DUCKDB_API static bool EndsWith(const string &str, const string &suffix); + + //! Repeat a string multiple times + DUCKDB_API static string Repeat(const string &str, const idx_t n); + + //! Split the input string based on newline char + DUCKDB_API static vector Split(const string &str, char delimiter); + + //! Split the input string allong a quote. Note that any escaping is NOT supported. + DUCKDB_API static vector SplitWithQuote(const string &str, char delimiter = ',', char quote = '"'); + + //! Join multiple strings into one string. Components are concatenated by the given separator + DUCKDB_API static string Join(const vector &input, const string &separator); + + template + static string ToString(const vector &input, const string &separator) { + vector input_list; + for (auto &i : input) { + input_list.push_back(i.ToString()); + } + return StringUtil::Join(input_list, separator); + } + + //! Join multiple items of container with given size, transformed to string + //! using function, into one string using the given separator + template + static string Join(const C &input, S count, const string &separator, Func f) { + // The result + std::string result; + + // If the input isn't empty, append the first element. We do this so we + // don't need to introduce an if into the loop. + if (count > 0) { + result += f(input[0]); + } + + // Append the remaining input components, after the first + for (size_t i = 1; i < count; i++) { + result += separator + f(input[i]); + } + + return result; + } + + //! Return a string that formats the give number of bytes + DUCKDB_API static string BytesToHumanReadableString(idx_t bytes); + + //! Convert a string to uppercase + DUCKDB_API static string Upper(const string &str); + + //! Convert a string to lowercase + DUCKDB_API static string Lower(const string &str); + + DUCKDB_API static bool IsLower(const string &str); + + //! Case insensitive hash + DUCKDB_API static uint64_t CIHash(const string &str); + + //! Case insensitive equals + DUCKDB_API static bool CIEquals(const string &l1, const string &l2); + + //! Format a string using printf semantics + template + static string Format(const string fmt_str, Args... params) { + return Exception::ConstructMessage(fmt_str, params...); + } + + //! Split the input string into a vector of strings based on the split string + DUCKDB_API static vector Split(const string &input, const string &split); + + //! Remove the whitespace char in the left end of the string + DUCKDB_API static void LTrim(string &str); + //! Remove the whitespace char in the right end of the string + DUCKDB_API static void RTrim(string &str); + //! Remove the all chars from chars_to_trim char in the right end of the string + DUCKDB_API static void RTrim(string &str, const string &chars_to_trim); + //! Remove the whitespace char in the left and right end of the string + DUCKDB_API static void Trim(string &str); + + DUCKDB_API static string Replace(string source, const string &from, const string &to); + + //! Get the levenshtein distance from two strings + //! The not_equal_penalty is the penalty given when two characters in a string are not equal + //! The regular levenshtein distance has a not equal penalty of 1, which means changing a character is as expensive + //! as adding or removing one For similarity searches we often want to give extra weight to changing a character For + //! example: with an equal penalty of 1, "pg_am" is closer to "depdelay" than "depdelay_minutes" + //! with an equal penalty of 3, "depdelay_minutes" is closer to "depdelay" than to "pg_am" + DUCKDB_API static idx_t LevenshteinDistance(const string &s1, const string &s2, idx_t not_equal_penalty = 1); + + //! Returns the similarity score between two strings + DUCKDB_API static idx_t SimilarityScore(const string &s1, const string &s2); + //! Get the top-n strings (sorted by the given score distance) from a set of scores. + //! At least one entry is returned (if there is one). + //! Strings are only returned if they have a score less than the threshold. + DUCKDB_API static vector TopNStrings(vector> scores, idx_t n = 5, + idx_t threshold = 5); + //! Computes the levenshtein distance of each string in strings, and compares it to target, then returns TopNStrings + //! with the given params. + DUCKDB_API static vector TopNLevenshtein(const vector &strings, const string &target, idx_t n = 5, + idx_t threshold = 5); + DUCKDB_API static string CandidatesMessage(const vector &candidates, + const string &candidate = "Candidate bindings"); + + //! Generate an error message in the form of "{message_prefix}: nearest_string, nearest_string2, ... + //! Equivalent to calling TopNLevenshtein followed by CandidatesMessage + DUCKDB_API static string CandidatesErrorMessage(const vector &strings, const string &target, + const string &message_prefix, idx_t n = 5); + + //! Returns true if two null-terminated strings are equal or point to the same address. + //! Returns false if only one of the strings is nullptr + static bool Equals(const char *s1, const char *s2) { + if (s1 == s2) { + return true; + } + if (s1 == nullptr || s2 == nullptr) { + return false; + } + return strcmp(s1, s2) == 0; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/swap.hpp b/src/duckdb/src/include/duckdb/common/swap.hpp new file mode 100644 index 00000000..ef305da1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/swap.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/swap.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::swap; +} diff --git a/src/duckdb/src/include/duckdb/common/thread.hpp b/src/duckdb/src/include/duckdb/common/thread.hpp new file mode 100644 index 00000000..7540dfc5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/thread.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/thread.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::thread; +} diff --git a/src/duckdb/src/include/duckdb/common/to_string.hpp b/src/duckdb/src/include/duckdb/common/to_string.hpp new file mode 100644 index 00000000..f7afe6ee --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/to_string.hpp @@ -0,0 +1,13 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/to_string.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { +using std::to_string; +} diff --git a/src/duckdb/src/include/duckdb/common/tree_renderer.hpp b/src/duckdb/src/include/duckdb/common/tree_renderer.hpp new file mode 100644 index 00000000..1f2f9ea7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/tree_renderer.hpp @@ -0,0 +1,150 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/tree_renderer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/main/query_profiler.hpp" + +namespace duckdb { +class LogicalOperator; +class PhysicalOperator; +class Pipeline; +struct PipelineRenderNode; + +struct RenderTreeNode { + string name; + string extra_text; +}; + +struct RenderTree { + RenderTree(idx_t width, idx_t height); + + unique_ptr[]> nodes; + idx_t width; + idx_t height; + +public: + RenderTreeNode *GetNode(idx_t x, idx_t y); + void SetNode(idx_t x, idx_t y, unique_ptr node); + bool HasNode(idx_t x, idx_t y); + + idx_t GetPosition(idx_t x, idx_t y); +}; + +struct TreeRendererConfig { + void enable_detailed() { + MAX_EXTRA_LINES = 1000; + detailed = true; + } + + void enable_standard() { + MAX_EXTRA_LINES = 30; + detailed = false; + } + + idx_t MAXIMUM_RENDER_WIDTH = 240; + idx_t NODE_RENDER_WIDTH = 29; + idx_t MINIMUM_RENDER_WIDTH = 15; + idx_t MAX_EXTRA_LINES = 30; + bool detailed = false; + +#ifndef DUCKDB_ASCII_TREE_RENDERER + const char *LTCORNER = "\342\224\214"; // "┌"; + const char *RTCORNER = "\342\224\220"; // "┐"; + const char *LDCORNER = "\342\224\224"; // "└"; + const char *RDCORNER = "\342\224\230"; // "┘"; + + const char *MIDDLE = "\342\224\274"; // "┼"; + const char *TMIDDLE = "\342\224\254"; // "┬"; + const char *LMIDDLE = "\342\224\234"; // "├"; + const char *RMIDDLE = "\342\224\244"; // "┤"; + const char *DMIDDLE = "\342\224\264"; // "┴"; + + const char *VERTICAL = "\342\224\202"; // "│"; + const char *HORIZONTAL = "\342\224\200"; // "─"; +#else + // ASCII version + const char *LTCORNER = "<"; + const char *RTCORNER = ">"; + const char *LDCORNER = "<"; + const char *RDCORNER = ">"; + + const char *MIDDLE = "+"; + const char *TMIDDLE = "+"; + const char *LMIDDLE = "+"; + const char *RMIDDLE = "+"; + const char *DMIDDLE = "+"; + + const char *VERTICAL = "|"; + const char *HORIZONTAL = "-"; +#endif +}; + +class TreeRenderer { +public: + explicit TreeRenderer(TreeRendererConfig config_p = TreeRendererConfig()) : config(std::move(config_p)) { + } + + string ToString(const LogicalOperator &op); + string ToString(const PhysicalOperator &op); + string ToString(const QueryProfiler::TreeNode &op); + string ToString(const Pipeline &op); + + void Render(const LogicalOperator &op, std::ostream &ss); + void Render(const PhysicalOperator &op, std::ostream &ss); + void Render(const QueryProfiler::TreeNode &op, std::ostream &ss); + void Render(const Pipeline &op, std::ostream &ss); + + void ToStream(RenderTree &root, std::ostream &ss); + + void EnableDetailed() { + config.enable_detailed(); + } + void EnableStandard() { + config.enable_standard(); + } + +private: + unique_ptr CreateTree(const LogicalOperator &op); + unique_ptr CreateTree(const PhysicalOperator &op); + unique_ptr CreateTree(const QueryProfiler::TreeNode &op); + unique_ptr CreateTree(const Pipeline &op); + + string ExtraInfoSeparator(); + unique_ptr CreateRenderNode(string name, string extra_info); + unique_ptr CreateNode(const LogicalOperator &op); + unique_ptr CreateNode(const PhysicalOperator &op); + unique_ptr CreateNode(const QueryProfiler::TreeNode &op); + unique_ptr CreateNode(const PipelineRenderNode &op); + +private: + //! The configuration used for rendering + TreeRendererConfig config; + +private: + void RenderTopLayer(RenderTree &root, std::ostream &ss, idx_t y); + void RenderBoxContent(RenderTree &root, std::ostream &ss, idx_t y); + void RenderBottomLayer(RenderTree &root, std::ostream &ss, idx_t y); + + bool CanSplitOnThisChar(char l); + bool IsPadding(char l); + string RemovePadding(string l); + void SplitUpExtraInfo(const string &extra_info, vector &result); + void SplitStringBuffer(const string &source, vector &result); + + template + idx_t CreateRenderTreeRecursive(RenderTree &result, const T &op, idx_t x, idx_t y); + + template + unique_ptr CreateRenderTree(const T &op); + string ExtractExpressionsRecursive(ExpressionInfo &states); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/type_util.hpp b/src/duckdb/src/include/duckdb/common/type_util.hpp new file mode 100644 index 00000000..fd8f4094 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/type_util.hpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/type_util.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/datetime.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/interval.hpp" + +namespace duckdb { + +//! Returns the PhysicalType for the given type +template +PhysicalType GetTypeId() { + if (std::is_same()) { + return PhysicalType::BOOL; + } else if (std::is_same()) { + return PhysicalType::INT8; + } else if (std::is_same()) { + return PhysicalType::INT16; + } else if (std::is_same()) { + return PhysicalType::INT32; + } else if (std::is_same()) { + return PhysicalType::INT64; + } else if (std::is_same()) { + return PhysicalType::UINT8; + } else if (std::is_same()) { + return PhysicalType::UINT16; + } else if (std::is_same()) { + return PhysicalType::UINT32; + } else if (std::is_same()) { + return PhysicalType::UINT64; + } else if (std::is_same()) { + return PhysicalType::INT128; + } else if (std::is_same()) { + return PhysicalType::INT32; + } else if (std::is_same()) { + return PhysicalType::INT64; + } else if (std::is_same()) { + return PhysicalType::INT64; + } else if (std::is_same()) { + return PhysicalType::FLOAT; + } else if (std::is_same()) { + return PhysicalType::DOUBLE; + } else if (std::is_same() || std::is_same() || std::is_same()) { + return PhysicalType::VARCHAR; + } else if (std::is_same()) { + return PhysicalType::INTERVAL; + } else { + return PhysicalType::INVALID; + } +} + +template +bool TypeIsNumber() { + return std::is_integral() || std::is_floating_point() || std::is_same(); +} + +template +bool IsValidType() { + return GetTypeId() != PhysicalType::INVALID; +} + +template +bool IsIntegerType() { + return TypeIsIntegral(GetTypeId()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/typedefs.hpp b/src/duckdb/src/include/duckdb/common/typedefs.hpp new file mode 100644 index 00000000..1d24002d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/typedefs.hpp @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/typedefs.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { + +//! a saner size_t for loop indices etc +typedef uint64_t idx_t; + +//! The type used for row identifiers +typedef int64_t row_t; + +//! The type used for hashes +typedef uint64_t hash_t; + +//! data pointers +typedef uint8_t data_t; +typedef data_t *data_ptr_t; +typedef const data_t *const_data_ptr_t; + +//! Type used for the selection vector +typedef uint32_t sel_t; +//! Type used for transaction timestamps +typedef idx_t transaction_t; + +//! Type used for column identifiers +typedef idx_t column_t; +//! Type used for storage (column) identifiers +typedef idx_t storage_t; + +template +data_ptr_t data_ptr_cast(SRC *src) { + return reinterpret_cast(src); +} + +template +const_data_ptr_t const_data_ptr_cast(const SRC *src) { + return reinterpret_cast(src); +} + +template +char *char_ptr_cast(SRC *src) { + return reinterpret_cast(src); +} + +template +const char *const_char_ptr_cast(const SRC *src) { + return reinterpret_cast(src); +} + +template +const unsigned char *const_uchar_ptr_cast(const SRC *src) { + return reinterpret_cast(src); +} + +template +uintptr_t CastPointerToValue(SRC *src) { + return uintptr_t(src); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types.hpp b/src/duckdb/src/include/duckdb/common/types.hpp new file mode 100644 index 00000000..a99bbd80 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types.hpp @@ -0,0 +1,464 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/vector.hpp" + +#include + +namespace duckdb { + +class Serializer; +class Deserializer; +class Value; +class TypeCatalogEntry; +class Vector; +class ClientContext; + +struct string_t; + +template +using child_list_t = vector>; +//! FIXME: this should be a single_thread_ptr +template +using buffer_ptr = shared_ptr; + +template +buffer_ptr make_buffer(Args &&...args) { + return make_shared(std::forward(args)...); +} + +struct list_entry_t { + list_entry_t() = default; + list_entry_t(uint64_t offset, uint64_t length) : offset(offset), length(length) { + } + inline constexpr bool operator != (const list_entry_t &other) const { + return !(*this == other); + } + inline constexpr bool operator == (const list_entry_t &other) const { + return offset == other.offset && length == other.length; + } + + uint64_t offset; + uint64_t length; +}; + +using union_tag_t = uint8_t; + +//===--------------------------------------------------------------------===// +// Internal Types +//===--------------------------------------------------------------------===// + +// taken from arrow's type.h +enum class PhysicalType : uint8_t { + ///// A NULL type having no physical storage + //NA = 0, + + /// Boolean as 8 bit "bool" value + BOOL = 1, + + /// Unsigned 8-bit little-endian integer + UINT8 = 2, + + /// Signed 8-bit little-endian integer + INT8 = 3, + + /// Unsigned 16-bit little-endian integer + UINT16 = 4, + + /// Signed 16-bit little-endian integer + INT16 = 5, + + /// Unsigned 32-bit little-endian integer + UINT32 = 6, + + /// Signed 32-bit little-endian integer + INT32 = 7, + + /// Unsigned 64-bit little-endian integer + UINT64 = 8, + + /// Signed 64-bit little-endian integer + INT64 = 9, + + ///// 2-byte floating point value + //HALF_FLOAT = 10, + + /// 4-byte floating point value + FLOAT = 11, + + /// 8-byte floating point value + DOUBLE = 12, + + ///// UTF8 variable-length string as List + //STRING = 13, + + ///// Variable-length bytes (no guarantee of UTF8-ness) + //BINARY = 14, + + ///// Fixed-size binary. Each value occupies the same number of bytes + //FIXED_SIZE_BINARY = 15, + + ///// int32_t days since the UNIX epoch + //DATE32 = 16, + + ///// int64_t milliseconds since the UNIX epoch + //DATE64 = 17, + + ///// Exact timestamp encoded with int64 since UNIX epoch + ///// Default unit millisecond + //TIMESTAMP = 18, + + ///// Time as signed 32-bit integer, representing either seconds or + ///// milliseconds since midnight + //TIME32 = 19, + + ///// Time as signed 64-bit integer, representing either microseconds or + ///// nanoseconds since midnight + //TIME64 = 20, + + /// YEAR_MONTH or DAY_TIME interval in SQL style + INTERVAL = 21, + + /// Precision- and scale-based decimal type. Storage type depends on the + /// parameters. + // DECIMAL = 22, + + /// A list of some logical data type + LIST = 23, + + /// Struct of logical types + STRUCT = 24, + + ///// Unions of logical types + //UNION = 25, + + ///// Dictionary-encoded type, also called "categorical" or "factor" + ///// in other programming languages. Holds the dictionary value + ///// type but not the dictionary itself, which is part of the + ///// ArrayData struct + //DICTIONARY = 26, + + ///// Custom data type, implemented by user + //EXTENSION = 28, + + ///// Fixed size list of some logical type + //FIXED_SIZE_LIST = 29, + + ///// Measure of elapsed time in either seconds, milliseconds, microseconds + ///// or nanoseconds. + //DURATION = 30, + + ///// Like STRING, but with 64-bit offsets + //LARGE_STRING = 31, + + ///// Like BINARY, but with 64-bit offsets + //LARGE_BINARY = 32, + + ///// Like LIST, but with 64-bit offsets + //LARGE_LIST = 33, + + /// DuckDB Extensions + VARCHAR = 200, // our own string representation, different from STRING and LARGE_STRING above + INT128 = 204, // 128-bit integers + UNKNOWN = 205, // Unknown physical type of user defined types + /// Boolean as 1 bit, LSB bit-packed ordering + BIT = 206, + + INVALID = 255 +}; + +//===--------------------------------------------------------------------===// +// SQL Types +//===--------------------------------------------------------------------===// +enum class LogicalTypeId : uint8_t { + INVALID = 0, + SQLNULL = 1, /* NULL type, used for constant NULL */ + UNKNOWN = 2, /* unknown type, used for parameter expressions */ + ANY = 3, /* ANY type, used for functions that accept any type as parameter */ + USER = 4, /* A User Defined Type (e.g., ENUMs before the binder) */ + BOOLEAN = 10, + TINYINT = 11, + SMALLINT = 12, + INTEGER = 13, + BIGINT = 14, + DATE = 15, + TIME = 16, + TIMESTAMP_SEC = 17, + TIMESTAMP_MS = 18, + TIMESTAMP = 19, //! us + TIMESTAMP_NS = 20, + DECIMAL = 21, + FLOAT = 22, + DOUBLE = 23, + CHAR = 24, + VARCHAR = 25, + BLOB = 26, + INTERVAL = 27, + UTINYINT = 28, + USMALLINT = 29, + UINTEGER = 30, + UBIGINT = 31, + TIMESTAMP_TZ = 32, + TIME_TZ = 34, + BIT = 36, + + HUGEINT = 50, + POINTER = 51, + VALIDITY = 53, + UUID = 54, + + STRUCT = 100, + LIST = 101, + MAP = 102, + TABLE = 103, + ENUM = 104, + AGGREGATE_STATE = 105, + LAMBDA = 106, + UNION = 107 +}; + + +struct ExtraTypeInfo; + +struct aggregate_state_t; + +struct LogicalType { + DUCKDB_API LogicalType(); + DUCKDB_API LogicalType(LogicalTypeId id); // NOLINT: Allow implicit conversion from `LogicalTypeId` + DUCKDB_API LogicalType(LogicalTypeId id, shared_ptr type_info); + DUCKDB_API LogicalType(const LogicalType &other); + DUCKDB_API LogicalType(LogicalType &&other) noexcept; + + DUCKDB_API ~LogicalType(); + + inline LogicalTypeId id() const { + return id_; + } + inline PhysicalType InternalType() const { + return physical_type_; + } + inline const ExtraTypeInfo *AuxInfo() const { + return type_info_.get(); + } + + inline shared_ptr GetAuxInfoShrPtr() const { + return type_info_; + } + + inline void CopyAuxInfo(const LogicalType& other) { + type_info_ = other.type_info_; + } + bool EqualTypeInfo(const LogicalType& rhs) const; + + // copy assignment + inline LogicalType& operator=(const LogicalType &other) { + id_ = other.id_; + physical_type_ = other.physical_type_; + type_info_ = other.type_info_; + return *this; + } + // move assignment + inline LogicalType& operator=(LogicalType&& other) noexcept { + id_ = other.id_; + physical_type_ = other.physical_type_; + std::swap(type_info_, other.type_info_); + return *this; + } + + DUCKDB_API bool operator==(const LogicalType &rhs) const; + inline bool operator!=(const LogicalType &rhs) const { + return !(*this == rhs); + } + + DUCKDB_API void Serialize(Serializer &serializer) const; + DUCKDB_API static LogicalType Deserialize(Deserializer &deserializer); + + + static bool TypeIsTimestamp(LogicalTypeId id) { + return (id == LogicalTypeId::TIMESTAMP || + id == LogicalTypeId::TIMESTAMP_MS || + id == LogicalTypeId::TIMESTAMP_NS || + id == LogicalTypeId::TIMESTAMP_SEC || + id == LogicalTypeId::TIMESTAMP_TZ); + } + static bool TypeIsTimestamp(const LogicalType& type) { + return TypeIsTimestamp(type.id()); + } + DUCKDB_API string ToString() const; + DUCKDB_API bool IsIntegral() const; + DUCKDB_API bool IsNumeric() const; + DUCKDB_API hash_t Hash() const; + DUCKDB_API void SetAlias(string alias); + DUCKDB_API bool HasAlias() const; + DUCKDB_API string GetAlias() const; + + DUCKDB_API static LogicalType MaxLogicalType(const LogicalType &left, const LogicalType &right); + + //! Gets the decimal properties of a numeric type. Fails if the type is not numeric. + DUCKDB_API bool GetDecimalProperties(uint8_t &width, uint8_t &scale) const; + + DUCKDB_API void Verify() const; + + DUCKDB_API bool IsValid() const; + +private: + LogicalTypeId id_; + PhysicalType physical_type_; + shared_ptr type_info_; + +private: + PhysicalType GetInternalType(); + +public: + static constexpr const LogicalTypeId SQLNULL = LogicalTypeId::SQLNULL; + static constexpr const LogicalTypeId UNKNOWN = LogicalTypeId::UNKNOWN; + static constexpr const LogicalTypeId BOOLEAN = LogicalTypeId::BOOLEAN; + static constexpr const LogicalTypeId TINYINT = LogicalTypeId::TINYINT; + static constexpr const LogicalTypeId UTINYINT = LogicalTypeId::UTINYINT; + static constexpr const LogicalTypeId SMALLINT = LogicalTypeId::SMALLINT; + static constexpr const LogicalTypeId USMALLINT = LogicalTypeId::USMALLINT; + static constexpr const LogicalTypeId INTEGER = LogicalTypeId::INTEGER; + static constexpr const LogicalTypeId UINTEGER = LogicalTypeId::UINTEGER; + static constexpr const LogicalTypeId BIGINT = LogicalTypeId::BIGINT; + static constexpr const LogicalTypeId UBIGINT = LogicalTypeId::UBIGINT; + static constexpr const LogicalTypeId FLOAT = LogicalTypeId::FLOAT; + static constexpr const LogicalTypeId DOUBLE = LogicalTypeId::DOUBLE; + static constexpr const LogicalTypeId DATE = LogicalTypeId::DATE; + static constexpr const LogicalTypeId TIMESTAMP = LogicalTypeId::TIMESTAMP; + static constexpr const LogicalTypeId TIMESTAMP_S = LogicalTypeId::TIMESTAMP_SEC; + static constexpr const LogicalTypeId TIMESTAMP_MS = LogicalTypeId::TIMESTAMP_MS; + static constexpr const LogicalTypeId TIMESTAMP_NS = LogicalTypeId::TIMESTAMP_NS; + static constexpr const LogicalTypeId TIME = LogicalTypeId::TIME; + static constexpr const LogicalTypeId TIMESTAMP_TZ = LogicalTypeId::TIMESTAMP_TZ; + static constexpr const LogicalTypeId TIME_TZ = LogicalTypeId::TIME_TZ; + static constexpr const LogicalTypeId VARCHAR = LogicalTypeId::VARCHAR; + static constexpr const LogicalTypeId ANY = LogicalTypeId::ANY; + static constexpr const LogicalTypeId BLOB = LogicalTypeId::BLOB; + static constexpr const LogicalTypeId BIT = LogicalTypeId::BIT; + static constexpr const LogicalTypeId INTERVAL = LogicalTypeId::INTERVAL; + static constexpr const LogicalTypeId HUGEINT = LogicalTypeId::HUGEINT; + static constexpr const LogicalTypeId UUID = LogicalTypeId::UUID; + static constexpr const LogicalTypeId HASH = LogicalTypeId::UBIGINT; + static constexpr const LogicalTypeId POINTER = LogicalTypeId::POINTER; + static constexpr const LogicalTypeId TABLE = LogicalTypeId::TABLE; + static constexpr const LogicalTypeId LAMBDA = LogicalTypeId::LAMBDA; + static constexpr const LogicalTypeId INVALID = LogicalTypeId::INVALID; + static constexpr const LogicalTypeId ROW_TYPE = LogicalTypeId::BIGINT; + + // explicitly allowing these functions to be capitalized to be in-line with the remaining functions + DUCKDB_API static LogicalType DECIMAL(int width, int scale); // NOLINT + DUCKDB_API static LogicalType VARCHAR_COLLATION(string collation); // NOLINT + DUCKDB_API static LogicalType LIST(const LogicalType &child); // NOLINT + DUCKDB_API static LogicalType STRUCT(child_list_t children); // NOLINT + DUCKDB_API static LogicalType AGGREGATE_STATE(aggregate_state_t state_type); // NOLINT + DUCKDB_API static LogicalType MAP(const LogicalType &child); // NOLINT + DUCKDB_API static LogicalType MAP(LogicalType key, LogicalType value); // NOLINT + DUCKDB_API static LogicalType UNION( child_list_t members); // NOLINT + DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT + // DEPRECATED - provided for backwards compatibility + DUCKDB_API static LogicalType ENUM(const string &enum_name, Vector &ordered_data, idx_t size); // NOLINT + DUCKDB_API static LogicalType USER(const string &user_type_name); // NOLINT + //! A list of all NUMERIC types (integral and floating point types) + DUCKDB_API static const vector Numeric(); + //! A list of all INTEGRAL types + DUCKDB_API static const vector Integral(); + //! A list of ALL SQL types + DUCKDB_API static const vector AllTypes(); +}; + +struct DecimalType { + DUCKDB_API static uint8_t GetWidth(const LogicalType &type); + DUCKDB_API static uint8_t GetScale(const LogicalType &type); + DUCKDB_API static uint8_t MaxWidth(); +}; + +struct StringType { + DUCKDB_API static string GetCollation(const LogicalType &type); +}; + +struct ListType { + DUCKDB_API static const LogicalType &GetChildType(const LogicalType &type); +}; + +struct UserType { + DUCKDB_API static const string &GetTypeName(const LogicalType &type); +}; + +struct EnumType { + DUCKDB_API static int64_t GetPos(const LogicalType &type, const string_t& key); + DUCKDB_API static const Vector &GetValuesInsertOrder(const LogicalType &type); + DUCKDB_API static idx_t GetSize(const LogicalType &type); + DUCKDB_API static const string GetValue(const Value &val); + DUCKDB_API static PhysicalType GetPhysicalType(const LogicalType &type); + DUCKDB_API static string_t GetString(const LogicalType &type, idx_t pos); +}; + +struct StructType { + DUCKDB_API static const child_list_t &GetChildTypes(const LogicalType &type); + DUCKDB_API static const LogicalType &GetChildType(const LogicalType &type, idx_t index); + DUCKDB_API static const string &GetChildName(const LogicalType &type, idx_t index); + DUCKDB_API static idx_t GetChildCount(const LogicalType &type); + DUCKDB_API static bool IsUnnamed(const LogicalType &type); +}; + +struct MapType { + DUCKDB_API static const LogicalType &KeyType(const LogicalType &type); + DUCKDB_API static const LogicalType &ValueType(const LogicalType &type); +}; + +struct UnionType { + DUCKDB_API static const idx_t MAX_UNION_MEMBERS = 256; + DUCKDB_API static idx_t GetMemberCount(const LogicalType &type); + DUCKDB_API static const LogicalType &GetMemberType(const LogicalType &type, idx_t index); + DUCKDB_API static const string &GetMemberName(const LogicalType &type, idx_t index); + DUCKDB_API static const child_list_t CopyMemberTypes(const LogicalType &type); +}; + +struct AggregateStateType { + DUCKDB_API static const string GetTypeName(const LogicalType &type); + DUCKDB_API static const aggregate_state_t &GetStateType(const LogicalType &type); +}; + +// **DEPRECATED**: Use EnumUtil directly instead. +DUCKDB_API string LogicalTypeIdToString(LogicalTypeId type); + +DUCKDB_API LogicalTypeId TransformStringToLogicalTypeId(const string &str); + +DUCKDB_API LogicalType TransformStringToLogicalType(const string &str); + +DUCKDB_API LogicalType TransformStringToLogicalType(const string &str, ClientContext &context); + +//! The PhysicalType used by the row identifiers column +extern const PhysicalType ROW_TYPE; + +DUCKDB_API string TypeIdToString(PhysicalType type); +DUCKDB_API idx_t GetTypeIdSize(PhysicalType type); +DUCKDB_API bool TypeIsConstantSize(PhysicalType type); +DUCKDB_API bool TypeIsIntegral(PhysicalType type); +DUCKDB_API bool TypeIsNumeric(PhysicalType type); +DUCKDB_API bool TypeIsInteger(PhysicalType type); + +bool ApproxEqual(float l, float r); +bool ApproxEqual(double l, double r); + +struct aggregate_state_t { + aggregate_state_t() {} + aggregate_state_t(string function_name_p, LogicalType return_type_p, vector bound_argument_types_p) : function_name(std::move(function_name_p)), return_type(std::move(return_type_p)), bound_argument_types(std::move(bound_argument_types_p)) { + } + + string function_name; + LogicalType return_type; + vector bound_argument_types; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/arrow_aux_data.hpp b/src/duckdb/src/include/duckdb/common/types/arrow_aux_data.hpp new file mode 100644 index 00000000..cf6c176f --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/arrow_aux_data.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/arrow_aux_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/vector_buffer.hpp" +#include "duckdb/common/arrow/arrow_wrapper.hpp" + +namespace duckdb { + +struct ArrowAuxiliaryData : public VectorAuxiliaryData { + static constexpr const VectorAuxiliaryDataType TYPE = VectorAuxiliaryDataType::ARROW_AUXILIARY; + explicit ArrowAuxiliaryData(shared_ptr arrow_array_p) + : VectorAuxiliaryData(VectorAuxiliaryDataType::ARROW_AUXILIARY), arrow_array(std::move(arrow_array_p)) { + } + ~ArrowAuxiliaryData() override { + } + + shared_ptr arrow_array; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp new file mode 100644 index 00000000..7d1c751f --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/batched_chunk_collection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/map.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" + +namespace duckdb { +class BufferManager; +class ClientContext; + +struct BatchedChunkScanState { + map>::iterator iterator; + ColumnDataScanState scan_state; +}; + +//! A BatchedDataCollection holds a number of data entries that are partitioned by batch index +//! Scans over a BatchedDataCollection are ordered by batch index +class BatchedDataCollection { +public: + DUCKDB_API BatchedDataCollection(ClientContext &context, vector types, bool buffer_managed = false); + + //! Appends a datachunk with the given batch index to the batched collection + DUCKDB_API void Append(DataChunk &input, idx_t batch_index); + + //! Merge the other batched chunk collection into this batched collection + DUCKDB_API void Merge(BatchedDataCollection &other); + + //! Initialize a scan over the batched chunk collection + DUCKDB_API void InitializeScan(BatchedChunkScanState &state); + + //! Scan a chunk from the batched chunk collection, in-order of batch index + DUCKDB_API void Scan(BatchedChunkScanState &state, DataChunk &output); + + //! Fetch a column data collection from the batched data collection - this consumes all of the data stored within + DUCKDB_API unique_ptr FetchCollection(); + + DUCKDB_API string ToString() const; + DUCKDB_API void Print() const; + +private: + struct CachedCollection { + idx_t batch_index = DConstants::INVALID_INDEX; + ColumnDataCollection *collection = nullptr; + ColumnDataAppendState append_state; + }; + + ClientContext &context; + vector types; + bool buffer_managed; + //! The data of the batched chunk collection - a set of batch_index -> ColumnDataCollection pointers + map> data; + //! The last batch collection that was inserted into + CachedCollection last_collection; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/bit.hpp b/src/duckdb/src/include/duckdb/common/types/bit.hpp new file mode 100644 index 00000000..0eb82a4e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/bit.hpp @@ -0,0 +1,143 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/bit.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/string_type.hpp" + +namespace duckdb { + +//! The Bit class is a static class that holds helper functions for the BIT type. +class Bit { +public: + //! Returns the number of bits in the bit string + DUCKDB_API static idx_t BitLength(string_t bits); + //! Returns the number of set bits in the bit string + DUCKDB_API static idx_t BitCount(string_t bits); + //! Returns the number of bytes in the bit string + DUCKDB_API static idx_t OctetLength(string_t bits); + //! Extracts the nth bit from bit string; the first (leftmost) bit is indexed 0 + DUCKDB_API static idx_t GetBit(string_t bit_string, idx_t n); + //! Sets the nth bit in bit string to newvalue; the first (leftmost) bit is indexed 0 + DUCKDB_API static void SetBit(string_t &bit_string, idx_t n, idx_t new_value); + //! Returns first starting index of the specified substring within bits, or zero if it's not present. + DUCKDB_API static idx_t BitPosition(string_t substring, string_t bits); + //! Converts bits to a string, writing the output to the designated output string. + //! The string needs to have space for at least GetStringSize(bits) bytes. + DUCKDB_API static void ToString(string_t bits, char *output); + DUCKDB_API static string ToString(string_t str); + //! Returns the bit size of a string -> bit conversion + DUCKDB_API static bool TryGetBitStringSize(string_t str, idx_t &result_size, string *error_message); + //! Convert a string to a bit. This function should ONLY be called after calling GetBitSize, since it does NOT + //! perform data validation. + DUCKDB_API static void ToBit(string_t str, string_t &output); + + DUCKDB_API static string ToBit(string_t str); + + //! output needs to have enough space allocated before calling this function (blob size + 1) + DUCKDB_API static void BlobToBit(string_t blob, string_t &output); + + DUCKDB_API static string BlobToBit(string_t blob); + + //! output_str needs to have enough space allocated before calling this function (sizeof(T) + 1) + template + static void NumericToBit(T numeric, string_t &output_str); + + template + static string NumericToBit(T numeric); + + //! bit is expected to fit inside of output num (bit size <= sizeof(T) + 1) + template + static void BitToNumeric(string_t bit, T &output_num); + + template + static T BitToNumeric(string_t bit); + + //! bit is expected to fit inside of output_blob (bit size = output_blob + 1) + static void BitToBlob(string_t bit, string_t &output_blob); + + static string BitToBlob(string_t bit); + + //! Creates a new bitstring of determined length + DUCKDB_API static void BitString(const string_t &input, const idx_t &len, string_t &result); + DUCKDB_API static void SetEmptyBitString(string_t &target, string_t &input); + DUCKDB_API static void SetEmptyBitString(string_t &target, idx_t len); + DUCKDB_API static idx_t ComputeBitstringLen(idx_t len); + + DUCKDB_API static void RightShift(const string_t &bit_string, const idx_t &shif, string_t &result); + DUCKDB_API static void LeftShift(const string_t &bit_string, const idx_t &shift, string_t &result); + DUCKDB_API static void BitwiseAnd(const string_t &rhs, const string_t &lhs, string_t &result); + DUCKDB_API static void BitwiseOr(const string_t &rhs, const string_t &lhs, string_t &result); + DUCKDB_API static void BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result); + DUCKDB_API static void BitwiseNot(const string_t &rhs, string_t &result); + + DUCKDB_API static void Verify(const string_t &input); + +private: + static void Finalize(string_t &str); + static idx_t GetBitInternal(string_t bit_string, idx_t n); + static void SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value); + static idx_t GetBitIndex(idx_t n); + static uint8_t GetFirstByte(const string_t &str); +}; + +//===--------------------------------------------------------------------===// +// Bit Template definitions +//===--------------------------------------------------------------------===// +template +void Bit::NumericToBit(T numeric, string_t &output_str) { + D_ASSERT(output_str.GetSize() >= sizeof(T) + 1); + + auto output = output_str.GetDataWriteable(); + auto data = const_data_ptr_cast(&numeric); + + *output = 0; // set padding to 0 + ++output; + for (idx_t idx = 0; idx < sizeof(T); ++idx) { + output[idx] = data[sizeof(T) - idx - 1]; + } + Bit::Finalize(output_str); +} + +template +string Bit::NumericToBit(T numeric) { + auto bit_len = sizeof(T) + 1; + auto buffer = make_unsafe_uniq_array(bit_len); + string_t output_str(buffer.get(), bit_len); + Bit::NumericToBit(numeric, output_str); + return output_str.GetString(); +} + +template +T Bit::BitToNumeric(string_t bit) { + T output; + Bit::BitToNumeric(bit, output); + return (output); +} + +template +void Bit::BitToNumeric(string_t bit, T &output_num) { + D_ASSERT(bit.GetSize() <= sizeof(T) + 1); + + output_num = 0; + auto data = const_data_ptr_cast(bit.GetData()); + auto output = data_ptr_cast(&output_num); + + idx_t padded_byte_idx = sizeof(T) - bit.GetSize() + 1; + output[sizeof(T) - 1 - padded_byte_idx] = GetFirstByte(bit); + for (idx_t idx = padded_byte_idx + 1; idx < sizeof(T); ++idx) { + output[sizeof(T) - 1 - idx] = data[1 + idx - padded_byte_idx]; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/blob.hpp b/src/duckdb/src/include/duckdb/common/types/blob.hpp new file mode 100644 index 00000000..26f2facc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/blob.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/blob.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +//! The Blob class is a static class that holds helper functions for the Blob type. +class Blob { +public: + // map of integer -> hex value + static constexpr const char *HEX_TABLE = "0123456789ABCDEF"; + // reverse map of byte -> integer value, or -1 for invalid hex values + static const int HEX_MAP[256]; + //! map of index -> base64 character + static constexpr const char *BASE64_MAP = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + //! padding character used in base64 encoding + static constexpr const char BASE64_PADDING = '='; + +public: + //! Returns the string size of a blob -> string conversion + DUCKDB_API static idx_t GetStringSize(string_t blob); + //! Converts a blob to a string, writing the output to the designated output string. + //! The string needs to have space for at least GetStringSize(blob) bytes. + DUCKDB_API static void ToString(string_t blob, char *output); + //! Convert a blob object to a string + DUCKDB_API static string ToString(string_t blob); + + //! Returns the blob size of a string -> blob conversion + DUCKDB_API static bool TryGetBlobSize(string_t str, idx_t &result_size, string *error_message); + DUCKDB_API static idx_t GetBlobSize(string_t str); + //! Convert a string to a blob. This function should ONLY be called after calling GetBlobSize, since it does NOT + //! perform data validation. + DUCKDB_API static void ToBlob(string_t str, data_ptr_t output); + //! Convert a string object to a blob + DUCKDB_API static string ToBlob(string_t str); + + // base 64 conversion functions + //! Returns the string size of a blob -> base64 conversion + DUCKDB_API static idx_t ToBase64Size(string_t blob); + //! Converts a blob to a base64 string, output should have space for at least ToBase64Size(blob) bytes + DUCKDB_API static void ToBase64(string_t blob, char *output); + + //! Returns the string size of a base64 string -> blob conversion + DUCKDB_API static idx_t FromBase64Size(string_t str); + //! Converts a base64 string to a blob, output should have space for at least FromBase64Size(blob) bytes + DUCKDB_API static void FromBase64(string_t str, data_ptr_t output, idx_t output_size); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/cast_helpers.hpp b/src/duckdb/src/include/duckdb/common/types/cast_helpers.hpp new file mode 100644 index 00000000..43903c28 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/cast_helpers.hpp @@ -0,0 +1,558 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/cast_helpers.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/vector.hpp" +#include "fmt/format.h" + +namespace duckdb { + +//! NumericHelper is a static class that holds helper functions for integers/doubles +class NumericHelper { +public: + static constexpr uint8_t CACHED_POWERS_OF_TEN = 20; + static const int64_t POWERS_OF_TEN[CACHED_POWERS_OF_TEN]; + static const double DOUBLE_POWERS_OF_TEN[40]; + +public: + template + static int UnsignedLength(T value); + template + static int SignedLength(SIGNED value) { + int sign = -(value < 0); + UNSIGNED unsigned_value = (value ^ sign) - sign; + return UnsignedLength(unsigned_value) - sign; + } + + // Formats value in reverse and returns a pointer to the beginning. + template + static char *FormatUnsigned(T value, char *ptr) { + while (value >= 100) { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". + auto index = static_cast((value % 100) * 2); + value /= 100; + *--ptr = duckdb_fmt::internal::data::digits[index + 1]; + *--ptr = duckdb_fmt::internal::data::digits[index]; + } + if (value < 10) { + *--ptr = static_cast('0' + value); + return ptr; + } + auto index = static_cast(value * 2); + *--ptr = duckdb_fmt::internal::data::digits[index + 1]; + *--ptr = duckdb_fmt::internal::data::digits[index]; + return ptr; + } + + template + static string_t FormatSigned(SIGNED value, Vector &vector) { + int sign = -(value < 0); + UNSIGNED unsigned_value = UNSIGNED(value ^ sign) - sign; + int length = UnsignedLength(unsigned_value) - sign; + string_t result = StringVector::EmptyString(vector, length); + auto dataptr = result.GetDataWriteable(); + auto endptr = dataptr + length; + endptr = FormatUnsigned(unsigned_value, endptr); + if (sign) { + *--endptr = '-'; + } + result.Finalize(); + return result; + } + + template + static std::string ToString(T value) { + return std::to_string(value); + } +}; + +template <> +int NumericHelper::UnsignedLength(uint8_t value); +template <> +int NumericHelper::UnsignedLength(uint16_t value); +template <> +int NumericHelper::UnsignedLength(uint32_t value); +template <> +int NumericHelper::UnsignedLength(uint64_t value); + +template <> +std::string NumericHelper::ToString(hugeint_t value); + +struct DecimalToString { + template + static int DecimalLength(SIGNED value, uint8_t width, uint8_t scale) { + if (scale == 0) { + // scale is 0: regular number + return NumericHelper::SignedLength(value); + } + // length is max of either: + // scale + 2 OR + // integer length + 1 + // scale + 2 happens when the number is in the range of (-1, 1) + // in that case we print "0.XXX", which is the scale, plus "0." (2 chars) + // integer length + 1 happens when the number is outside of that range + // in that case we print the integer number, but with one extra character ('.') + auto extra_characters = width > scale ? 2 : 1; + return MaxValue(scale + extra_characters + (value < 0 ? 1 : 0), + NumericHelper::SignedLength(value) + 1); + } + + template + static void FormatDecimal(SIGNED value, uint8_t width, uint8_t scale, char *dst, idx_t len) { + char *end = dst + len; + if (value < 0) { + value = -value; + *dst = '-'; + } + if (scale == 0) { + NumericHelper::FormatUnsigned(value, end); + return; + } + // we write two numbers: + // the numbers BEFORE the decimal (major) + // and the numbers AFTER the decimal (minor) + UNSIGNED minor = value % (UNSIGNED)NumericHelper::POWERS_OF_TEN[scale]; + UNSIGNED major = value / (UNSIGNED)NumericHelper::POWERS_OF_TEN[scale]; + // write the number after the decimal + dst = NumericHelper::FormatUnsigned(minor, end); + // (optionally) pad with zeros and add the decimal point + while (dst > (end - scale)) { + *--dst = '0'; + } + *--dst = '.'; + // now write the part before the decimal + D_ASSERT(width > scale || major == 0); + if (width > scale) { + // there are numbers after the comma + dst = NumericHelper::FormatUnsigned(major, dst); + } + } + + template + static string_t Format(SIGNED value, uint8_t width, uint8_t scale, Vector &vector) { + int len = DecimalLength(value, width, scale); + string_t result = StringVector::EmptyString(vector, len); + FormatDecimal(value, width, scale, result.GetDataWriteable(), len); + result.Finalize(); + return result; + } +}; + +struct HugeintToStringCast { + static int UnsignedLength(hugeint_t value) { + D_ASSERT(value.upper >= 0); + if (value.upper == 0) { + return NumericHelper::UnsignedLength(value.lower); + } + // search the length using the POWERS_OF_TEN array + // the length has to be between [17] and [38], because the hugeint is bigger than 2^63 + // we use the same approach as above, but split a bit more because comparisons for hugeints are more expensive + if (value >= Hugeint::POWERS_OF_TEN[27]) { + // [27..38] + if (value >= Hugeint::POWERS_OF_TEN[32]) { + if (value >= Hugeint::POWERS_OF_TEN[36]) { + int length = 37; + length += value >= Hugeint::POWERS_OF_TEN[37]; + length += value >= Hugeint::POWERS_OF_TEN[38]; + return length; + } else { + int length = 33; + length += value >= Hugeint::POWERS_OF_TEN[33]; + length += value >= Hugeint::POWERS_OF_TEN[34]; + length += value >= Hugeint::POWERS_OF_TEN[35]; + return length; + } + } else { + if (value >= Hugeint::POWERS_OF_TEN[30]) { + int length = 31; + length += value >= Hugeint::POWERS_OF_TEN[31]; + length += value >= Hugeint::POWERS_OF_TEN[32]; + return length; + } else { + int length = 28; + length += value >= Hugeint::POWERS_OF_TEN[28]; + length += value >= Hugeint::POWERS_OF_TEN[29]; + return length; + } + } + } else { + // [17..27] + if (value >= Hugeint::POWERS_OF_TEN[22]) { + // [22..27] + if (value >= Hugeint::POWERS_OF_TEN[25]) { + int length = 26; + length += value >= Hugeint::POWERS_OF_TEN[26]; + return length; + } else { + int length = 23; + length += value >= Hugeint::POWERS_OF_TEN[23]; + length += value >= Hugeint::POWERS_OF_TEN[24]; + return length; + } + } else { + // [17..22] + if (value >= Hugeint::POWERS_OF_TEN[20]) { + int length = 21; + length += value >= Hugeint::POWERS_OF_TEN[21]; + return length; + } else { + int length = 18; + length += value >= Hugeint::POWERS_OF_TEN[18]; + length += value >= Hugeint::POWERS_OF_TEN[19]; + return length; + } + } + } + } + + // Formats value in reverse and returns a pointer to the beginning. + static char *FormatUnsigned(hugeint_t value, char *ptr) { + while (value.upper > 0) { + // while integer division is slow, hugeint division is MEGA slow + // we want to avoid doing as many divisions as possible + // for that reason we start off doing a division by a large power of ten that uint64_t can hold + // (100000000000000000) - this is the third largest + // the reason we don't use the largest is because that can result in an overflow inside the division + // function + uint64_t remainder; + value = Hugeint::DivModPositive(value, 100000000000000000ULL, remainder); + + auto startptr = ptr; + // now we format the remainder: note that we need to pad with zero's in case + // the remainder is small (i.e. less than 10000000000000000) + ptr = NumericHelper::FormatUnsigned(remainder, ptr); + + int format_length = startptr - ptr; + // pad with zero + for (int i = format_length; i < 17; i++) { + *--ptr = '0'; + } + } + // once the value falls in the range of a uint64_t, fallback to formatting as uint64_t to avoid hugeint division + return NumericHelper::FormatUnsigned(value.lower, ptr); + } + + static string_t FormatSigned(hugeint_t value, Vector &vector) { + int negative = value.upper < 0; + if (negative) { + Hugeint::NegateInPlace(value); + } + int length = UnsignedLength(value) + negative; + string_t result = StringVector::EmptyString(vector, length); + auto dataptr = result.GetDataWriteable(); + auto endptr = dataptr + length; + if (value.upper == 0) { + // small value: format as uint64_t + endptr = NumericHelper::FormatUnsigned(value.lower, endptr); + } else { + endptr = FormatUnsigned(value, endptr); + } + if (negative) { + *--endptr = '-'; + } + D_ASSERT(endptr == dataptr); + result.Finalize(); + return result; + } + + static int DecimalLength(hugeint_t value, uint8_t width, uint8_t scale) { + int negative; + if (value.upper < 0) { + Hugeint::NegateInPlace(value); + negative = 1; + } else { + negative = 0; + } + if (scale == 0) { + // scale is 0: regular number + return UnsignedLength(value) + negative; + } + // length is max of either: + // scale + 2 OR + // integer length + 1 + // scale + 2 happens when the number is in the range of (-1, 1) + // in that case we print "0.XXX", which is the scale, plus "0." (2 chars) + // integer length + 1 happens when the number is outside of that range + // in that case we print the integer number, but with one extra character ('.') + auto extra_numbers = width > scale ? 2 : 1; + return MaxValue(scale + extra_numbers, UnsignedLength(value) + 1) + negative; + } + + static void FormatDecimal(hugeint_t value, uint8_t width, uint8_t scale, char *dst, int len) { + auto endptr = dst + len; + + int negative = value.upper < 0; + if (negative) { + Hugeint::NegateInPlace(value); + *dst = '-'; + dst++; + } + if (scale == 0) { + // with scale=0 we format the number as a regular number + FormatUnsigned(value, endptr); + return; + } + + // we write two numbers: + // the numbers BEFORE the decimal (major) + // and the numbers AFTER the decimal (minor) + hugeint_t minor; + hugeint_t major = Hugeint::DivMod(value, Hugeint::POWERS_OF_TEN[scale], minor); + + // write the number after the decimal + dst = FormatUnsigned(minor, endptr); + // (optionally) pad with zeros and add the decimal point + while (dst > (endptr - scale)) { + *--dst = '0'; + } + *--dst = '.'; + // now write the part before the decimal + D_ASSERT(width > scale || major == 0); + if (width > scale) { + dst = FormatUnsigned(major, dst); + } + } + + static string_t FormatDecimal(hugeint_t value, uint8_t width, uint8_t scale, Vector &vector) { + int length = DecimalLength(value, width, scale); + string_t result = StringVector::EmptyString(vector, length); + + auto dst = result.GetDataWriteable(); + + FormatDecimal(value, width, scale, dst, length); + + result.Finalize(); + return result; + } +}; + +struct DateToStringCast { + static idx_t Length(int32_t date[], idx_t &year_length, bool &add_bc) { + // format is YYYY-MM-DD with optional (BC) at the end + // regular length is 10 + idx_t length = 6; + year_length = 4; + add_bc = false; + if (date[0] <= 0) { + // add (BC) suffix + length += 5; + date[0] = -date[0] + 1; + add_bc = true; + } + + // potentially add extra characters depending on length of year + year_length += date[0] >= 10000; + year_length += date[0] >= 100000; + year_length += date[0] >= 1000000; + year_length += date[0] >= 10000000; + length += year_length; + return length; + } + + static void Format(char *data, int32_t date[], idx_t year_length, bool add_bc) { + // now we write the string, first write the year + auto endptr = data + year_length; + endptr = NumericHelper::FormatUnsigned(date[0], endptr); + // add optional leading zeros + while (endptr > data) { + *--endptr = '0'; + } + // now write the month and day + auto ptr = data + year_length; + for (int i = 1; i <= 2; i++) { + ptr[0] = '-'; + if (date[i] < 10) { + ptr[1] = '0'; + ptr[2] = '0' + date[i]; + } else { + auto index = static_cast(date[i] * 2); + ptr[1] = duckdb_fmt::internal::data::digits[index]; + ptr[2] = duckdb_fmt::internal::data::digits[index + 1]; + } + ptr += 3; + } + // optionally add BC to the end of the date + if (add_bc) { + memcpy(ptr, " (BC)", 5); + } + } +}; + +struct TimeToStringCast { + //! Format microseconds to a buffer of length 6. Returns the number of trailing zeros + static int32_t FormatMicros(uint32_t microseconds, char micro_buffer[]) { + char *endptr = micro_buffer + 6; + endptr = NumericHelper::FormatUnsigned(microseconds, endptr); + while (endptr > micro_buffer) { + *--endptr = '0'; + } + idx_t trailing_zeros = 0; + for (idx_t i = 5; i > 0; i--) { + if (micro_buffer[i] != '0') { + break; + } + trailing_zeros++; + } + return trailing_zeros; + } + + static idx_t Length(int32_t time[], char micro_buffer[]) { + // format is HH:MM:DD.MS + // microseconds come after the time with a period separator + idx_t length; + if (time[3] == 0) { + // no microseconds + // format is HH:MM:DD + length = 8; + } else { + length = 15; + // for microseconds, we truncate any trailing zeros (i.e. "90000" becomes ".9") + // first write the microseconds to the microsecond buffer + // we write backwards and pad with zeros to the left + // now we figure out how many digits we need to include by looking backwards + // and checking how many zeros we encounter + length -= FormatMicros(time[3], micro_buffer); + } + return length; + } + + static void FormatTwoDigits(char *ptr, int32_t value) { + D_ASSERT(value >= 0 && value <= 99); + if (value < 10) { + ptr[0] = '0'; + ptr[1] = '0' + value; + } else { + auto index = static_cast(value * 2); + ptr[0] = duckdb_fmt::internal::data::digits[index]; + ptr[1] = duckdb_fmt::internal::data::digits[index + 1]; + } + } + + static void Format(char *data, idx_t length, int32_t time[], char micro_buffer[]) { + // first write hour, month and day + auto ptr = data; + ptr[2] = ':'; + ptr[5] = ':'; + for (int i = 0; i <= 2; i++) { + FormatTwoDigits(ptr, time[i]); + ptr += 3; + } + if (length > 8) { + // write the micro seconds at the end + data[8] = '.'; + memcpy(data + 9, micro_buffer, length - 9); + } + } +}; + +struct IntervalToStringCast { + static void FormatSignedNumber(int64_t value, char buffer[], idx_t &length) { + int sign = -(value < 0); + uint64_t unsigned_value = (value ^ sign) - sign; + length += NumericHelper::UnsignedLength(unsigned_value) - sign; + auto endptr = buffer + length; + endptr = NumericHelper::FormatUnsigned(unsigned_value, endptr); + if (sign) { + *--endptr = '-'; + } + } + + static void FormatTwoDigits(int64_t value, char buffer[], idx_t &length) { + TimeToStringCast::FormatTwoDigits(buffer + length, value); + length += 2; + } + + static void FormatIntervalValue(int32_t value, char buffer[], idx_t &length, const char *name, idx_t name_len) { + if (value == 0) { + return; + } + if (length != 0) { + // space if there is already something in the buffer + buffer[length++] = ' '; + } + FormatSignedNumber(value, buffer, length); + // append the name together with a potential "s" (for plurals) + memcpy(buffer + length, name, name_len); + length += name_len; + if (value != 1) { + buffer[length++] = 's'; + } + } + + //! Formats an interval to a buffer, the buffer should be >=70 characters + //! years: 17 characters (max value: "-2147483647 years") + //! months: 9 (max value: "12 months") + //! days: 16 characters (max value: "-2147483647 days") + //! time: 24 characters (max value: -2562047788:00:00.123456) + //! spaces between all characters (+3 characters) + //! Total: 70 characters + //! Returns the length of the interval + static idx_t Format(interval_t interval, char buffer[]) { + idx_t length = 0; + if (interval.months != 0) { + int32_t years = interval.months / 12; + int32_t months = interval.months - years * 12; + // format the years and months + FormatIntervalValue(years, buffer, length, " year", 5); + FormatIntervalValue(months, buffer, length, " month", 6); + } + if (interval.days != 0) { + // format the days + FormatIntervalValue(interval.days, buffer, length, " day", 4); + } + if (interval.micros != 0) { + if (length != 0) { + // space if there is already something in the buffer + buffer[length++] = ' '; + } + int64_t micros = interval.micros; + if (micros < 0) { + // negative time: append negative sign + buffer[length++] = '-'; + } else { + micros = -micros; + } + int64_t hour = -(micros / Interval::MICROS_PER_HOUR); + micros += hour * Interval::MICROS_PER_HOUR; + int64_t min = -(micros / Interval::MICROS_PER_MINUTE); + micros += min * Interval::MICROS_PER_MINUTE; + int64_t sec = -(micros / Interval::MICROS_PER_SEC); + micros += sec * Interval::MICROS_PER_SEC; + micros = -micros; + + if (hour < 10) { + buffer[length++] = '0'; + } + FormatSignedNumber(hour, buffer, length); + buffer[length++] = ':'; + FormatTwoDigits(min, buffer, length); + buffer[length++] = ':'; + FormatTwoDigits(sec, buffer, length); + if (micros != 0) { + buffer[length++] = '.'; + auto trailing_zeros = TimeToStringCast::FormatMicros(micros, buffer + length); + length += 6 - trailing_zeros; + } + } else if (length == 0) { + // empty interval: default to 00:00:00 + memcpy(buffer, "00:00:00", 8); + return 8; + } + return length; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/chunk_collection.hpp b/src/duckdb/src/include/duckdb/common/types/chunk_collection.hpp new file mode 100644 index 00000000..e3d1ddaa --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/chunk_collection.hpp @@ -0,0 +1,137 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/chunk_collection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/order_type.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/winapi.hpp" + +namespace duckdb { +class Allocator; +class ClientContext; + +//! A ChunkCollection represents a set of DataChunks that all have the same +//! types +/*! + A ChunkCollection represents a set of DataChunks concatenated together in a + list. Individual values of the collection can be iterated over using the + iterator. It is also possible to iterate directly over the chunks for more + direct access. +*/ +class ChunkCollection { +public: + explicit ChunkCollection(Allocator &allocator); + explicit ChunkCollection(ClientContext &context); + + //! The types of columns in the ChunkCollection + vector &Types() { + return types; + } + const vector &Types() const { + return types; + } + + //! The amount of rows in the ChunkCollection + const idx_t &Count() const { + return count; + } + + //! The amount of columns in the ChunkCollection + idx_t ColumnCount() const { + return types.size(); + } + + //! Append a new DataChunk directly to this ChunkCollection + DUCKDB_API void Append(DataChunk &new_chunk); + + //! Append a new DataChunk directly to this ChunkCollection + DUCKDB_API void Append(unique_ptr new_chunk); + + //! Append another ChunkCollection directly to this ChunkCollection + DUCKDB_API void Append(ChunkCollection &other); + + //! Merge is like Append but messes up the order and destroys the other collection + DUCKDB_API void Merge(ChunkCollection &other); + + //! Fuse adds new columns to the right of the collection + DUCKDB_API void Fuse(ChunkCollection &other); + + DUCKDB_API void Verify(); + + //! Gets the value of the column at the specified index + DUCKDB_API Value GetValue(idx_t column, idx_t index); + //! Sets the value of the column at the specified index + DUCKDB_API void SetValue(idx_t column, idx_t index, const Value &value); + + //! Copy a single cell to a target vector + DUCKDB_API void CopyCell(idx_t column, idx_t index, Vector &target, idx_t target_offset); + + DUCKDB_API string ToString() const; + DUCKDB_API void Print() const; + + //! Gets a reference to the chunk at the given index + DataChunk &GetChunkForRow(idx_t row_index) { + return *chunks[LocateChunk(row_index)]; + } + + //! Gets a reference to the chunk at the given index + DataChunk &GetChunk(idx_t chunk_index) { + D_ASSERT(chunk_index < chunks.size()); + return *chunks[chunk_index]; + } + const DataChunk &GetChunk(idx_t chunk_index) const { + D_ASSERT(chunk_index < chunks.size()); + return *chunks[chunk_index]; + } + + const vector> &Chunks() { + return chunks; + } + + idx_t ChunkCount() const { + return chunks.size(); + } + + void Reset() { + count = 0; + chunks.clear(); + types.clear(); + } + + unique_ptr Fetch() { + if (ChunkCount() == 0) { + return nullptr; + } + + auto res = std::move(chunks[0]); + chunks.erase(chunks.begin() + 0); + return res; + } + + //! Locates the chunk that belongs to the specific index + idx_t LocateChunk(idx_t index) { + idx_t result = index / STANDARD_VECTOR_SIZE; + D_ASSERT(result < chunks.size()); + return result; + } + + Allocator &GetAllocator() { + return allocator; + } + +private: + Allocator &allocator; + //! The total amount of elements in the collection + idx_t count; + //! The set of data chunks in the collection + vector> chunks; + //! The types of the ChunkCollection + vector types; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp new file mode 100644 index 00000000..2b04390a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp @@ -0,0 +1,104 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/column/column_data_allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" + +namespace duckdb { + +struct ChunkMetaData; +struct VectorMetaData; + +struct BlockMetaData { + //! The underlying block handle + shared_ptr handle; + //! How much space is currently used within the block + uint32_t size; + //! How much space is available in the block + uint32_t capacity; + + uint32_t Capacity(); +}; + +class ColumnDataAllocator { +public: + explicit ColumnDataAllocator(Allocator &allocator); + explicit ColumnDataAllocator(BufferManager &buffer_manager); + ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type); + ColumnDataAllocator(ColumnDataAllocator &allocator); + + //! Returns an allocator object to allocate with. This returns the allocator in IN_MEMORY_ALLOCATOR, and a buffer + //! allocator in case of BUFFER_MANAGER_ALLOCATOR. + Allocator &GetAllocator(); + //! Returns the allocator type + ColumnDataAllocatorType GetType() { + return type; + } + void MakeShared() { + shared = true; + } + bool IsShared() const { + return shared; + } + idx_t BlockCount() const { + return blocks.size(); + } + idx_t SizeInBytes() const { + idx_t total_size = 0; + for (const auto &block : blocks) { + total_size += block.size; + } + return total_size; + } + +public: + void AllocateData(idx_t size, uint32_t &block_id, uint32_t &offset, ChunkManagementState *chunk_state); + + void Initialize(ColumnDataAllocator &other); + void InitializeChunkState(ChunkManagementState &state, ChunkMetaData &meta_data); + data_ptr_t GetDataPointer(ChunkManagementState &state, uint32_t block_id, uint32_t offset); + void UnswizzlePointers(ChunkManagementState &state, Vector &result, idx_t v_offset, uint16_t count, + uint32_t block_id, uint32_t offset); + + //! Deletes the block with the given id + void DeleteBlock(uint32_t block_id); + +private: + void AllocateEmptyBlock(idx_t size); + BufferHandle AllocateBlock(idx_t size); + BufferHandle Pin(uint32_t block_id); + + bool HasBlocks() const { + return !blocks.empty(); + } + +private: + void AllocateBuffer(idx_t size, uint32_t &block_id, uint32_t &offset, ChunkManagementState *chunk_state); + void AllocateMemory(idx_t size, uint32_t &block_id, uint32_t &offset, ChunkManagementState *chunk_state); + void AssignPointer(uint32_t &block_id, uint32_t &offset, data_ptr_t pointer); + +private: + ColumnDataAllocatorType type; + union { + //! The allocator object (if this is a IN_MEMORY_ALLOCATOR) + Allocator *allocator; + //! The buffer manager (if this is a BUFFER_MANAGER_ALLOCATOR) + BufferManager *buffer_manager; + } alloc; + //! The set of blocks used by the column data collection + vector blocks; + //! The set of allocated data + vector allocated_data; + //! Whether this ColumnDataAllocator is shared across ColumnDataCollections that allocate in parallel + bool shared = false; + //! Lock used in case this ColumnDataAllocator is shared across threads + mutex lock; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp new file mode 100644 index 00000000..0fcd5f4a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp @@ -0,0 +1,226 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/column/column_data_collection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types/column/column_data_collection_iterators.hpp" + +namespace duckdb { +class BufferManager; +class BlockHandle; +class ClientContext; +struct ColumnDataCopyFunction; +class ColumnDataAllocator; +class ColumnDataCollection; +class ColumnDataCollectionSegment; +class ColumnDataRowCollection; + +//! The ColumnDataCollection represents a set of (buffer-managed) data stored in columnar format +//! It is efficient to read and scan +class ColumnDataCollection { +public: + //! Constructs an in-memory column data collection from an allocator + DUCKDB_API ColumnDataCollection(Allocator &allocator, vector types); + //! Constructs an empty (but valid) in-memory column data collection from an allocator + DUCKDB_API ColumnDataCollection(Allocator &allocator); + //! Constructs a buffer-managed column data collection + DUCKDB_API ColumnDataCollection(BufferManager &buffer_manager, vector types); + //! Constructs either an in-memory or a buffer-managed column data collection + DUCKDB_API ColumnDataCollection(ClientContext &context, vector types, + ColumnDataAllocatorType type = ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); + //! Creates a column data collection that inherits the blocks to write to. This allows blocks to be shared + //! between multiple column data collections and prevents wasting space. + //! Note that after one CDC inherits blocks from another, the other + //! cannot be written to anymore (i.e. we take ownership of the half-written blocks). + DUCKDB_API ColumnDataCollection(ColumnDataCollection &parent); + DUCKDB_API ColumnDataCollection(shared_ptr allocator, vector types); + DUCKDB_API ~ColumnDataCollection(); + +public: + //! The types of columns in the ColumnDataCollection + vector &Types() { + return types; + } + const vector &Types() const { + return types; + } + + //! The amount of rows in the ColumnDataCollection + const idx_t &Count() const { + return count; + } + + //! The amount of columns in the ColumnDataCollection + idx_t ColumnCount() const { + return types.size(); + } + + //! The size (in bytes) of this ColumnDataCollection + idx_t SizeInBytes() const; + + //! Get the allocator + DUCKDB_API Allocator &GetAllocator() const; + + //! Initializes an Append state - useful for optimizing many appends made to the same column data collection + DUCKDB_API void InitializeAppend(ColumnDataAppendState &state); + //! Append a DataChunk to this ColumnDataCollection using the specified append state + DUCKDB_API void Append(ColumnDataAppendState &state, DataChunk &new_chunk); + + //! Initializes a chunk with the correct types that can be used to call Scan + DUCKDB_API void InitializeScanChunk(DataChunk &chunk) const; + //! Initializes a chunk with the correct types for a given scan state + DUCKDB_API void InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const; + //! Initializes a Scan state for scanning all columns + DUCKDB_API void + InitializeScan(ColumnDataScanState &state, + ColumnDataScanProperties properties = ColumnDataScanProperties::ALLOW_ZERO_COPY) const; + //! Initializes a Scan state for scanning a subset of the columns + DUCKDB_API void + InitializeScan(ColumnDataScanState &state, vector column_ids, + ColumnDataScanProperties properties = ColumnDataScanProperties::ALLOW_ZERO_COPY) const; + //! Initialize a parallel scan over the column data collection over all columns + DUCKDB_API void + InitializeScan(ColumnDataParallelScanState &state, + ColumnDataScanProperties properties = ColumnDataScanProperties::ALLOW_ZERO_COPY) const; + //! Initialize a parallel scan over the column data collection over a subset of the columns + DUCKDB_API void + InitializeScan(ColumnDataParallelScanState &state, vector column_ids, + ColumnDataScanProperties properties = ColumnDataScanProperties::ALLOW_ZERO_COPY) const; + //! Scans a DataChunk from the ColumnDataCollection + DUCKDB_API bool Scan(ColumnDataScanState &state, DataChunk &result) const; + //! Scans a DataChunk from the ColumnDataCollection + DUCKDB_API bool Scan(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, DataChunk &result) const; + + //! Append a DataChunk directly to this ColumnDataCollection - calls InitializeAppend and Append internally + DUCKDB_API void Append(DataChunk &new_chunk); + + //! Appends the other ColumnDataCollection to this, destroying the other data collection + DUCKDB_API void Combine(ColumnDataCollection &other); + + DUCKDB_API void Verify(); + + DUCKDB_API string ToString() const; + DUCKDB_API void Print() const; + + DUCKDB_API void Reset(); + + //! Returns the number of data chunks present in the ColumnDataCollection + DUCKDB_API idx_t ChunkCount() const; + //! Fetch an individual chunk from the ColumnDataCollection + DUCKDB_API void FetchChunk(idx_t chunk_idx, DataChunk &result) const; + + //! Constructs a class that can be iterated over to fetch individual chunks + //! Iterating over this is syntactic sugar over just calling Scan + DUCKDB_API ColumnDataChunkIterationHelper Chunks() const; + //! Constructs a class that can be iterated over to fetch individual chunks + //! Only the column indexes specified in the column_ids list are scanned + DUCKDB_API ColumnDataChunkIterationHelper Chunks(vector column_ids) const; + + //! Constructs a class that can be iterated over to fetch individual rows + //! Note that row iteration is slow, and the `.Chunks()` method should be used instead + DUCKDB_API ColumnDataRowIterationHelper Rows() const; + + //! Returns a materialized set of all of the rows in the column data collection + //! Note that usage of this is slow - avoid using this unless the amount of rows is small, or if you do not care + //! about performance + DUCKDB_API ColumnDataRowCollection GetRows() const; + + //! Compare two column data collections to another. If they are equal according to result equality rules, + //! return true. That means null values are equal, and approx equality is used for floating point values. + //! If they are not equal, return false and fill in the error message. + static bool ResultEquals(const ColumnDataCollection &left, const ColumnDataCollection &right, string &error_message, + bool ordered = false); + + //! Obtains the next scan index to scan from + bool NextScanIndex(ColumnDataScanState &state, idx_t &chunk_index, idx_t &segment_index, idx_t &row_index) const; + //! Scans at the indices (obtained from NextScanIndex) + void ScanAtIndex(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, DataChunk &result, + idx_t chunk_index, idx_t segment_index, idx_t row_index) const; + + //! Initialize the column data collection + void Initialize(vector types); + + //! Get references to the string heaps in this ColumnDataCollection + vector> GetHeapReferences(); + //! Get the allocator type of this ColumnDataCollection + ColumnDataAllocatorType GetAllocatorType() const; + + //! Get a vector of the segments in this ColumnDataCollection + const vector> &GetSegments() const; + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! Creates a new segment within the ColumnDataCollection + void CreateSegment(); + + static ColumnDataCopyFunction GetCopyFunction(const LogicalType &type); + +private: + //! The Column Data Allocator + buffer_ptr allocator; + //! The types of the stored entries + vector types; + //! The number of entries stored in the column data collection + idx_t count; + //! The data segments of the column data collection + vector> segments; + //! The set of copy functions + vector copy_functions; + //! When the column data collection is marked as finished - new tuples can no longer be appended to it + bool finished_append; +}; + +//! The ColumnDataRowCollection represents a set of materialized rows, as obtained from the ColumnDataCollection +class ColumnDataRowCollection { +public: + DUCKDB_API ColumnDataRowCollection(const ColumnDataCollection &collection); + +public: + DUCKDB_API Value GetValue(idx_t column, idx_t index) const; + +public: + // container API + bool empty() const { + return rows.empty(); + } + idx_t size() const { + return rows.size(); + } + + DUCKDB_API ColumnDataRow &operator[](idx_t i); + DUCKDB_API const ColumnDataRow &operator[](idx_t i) const; + + vector::iterator begin() { + return rows.begin(); + } + vector::iterator end() { + return rows.end(); + } + vector::const_iterator cbegin() const { + return rows.cbegin(); + } + vector::const_iterator cend() const { + return rows.cend(); + } + vector::const_iterator begin() const { + return rows.begin(); + } + vector::const_iterator end() const { + return rows.end(); + } + +private: + vector rows; + vector> chunks; + ColumnDataScanState scan_state; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp new file mode 100644 index 00000000..3ceb700b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp @@ -0,0 +1,86 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/column/column_data_collection_iterators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_scan_states.hpp" + +namespace duckdb { +class ColumnDataCollection; + +class ColumnDataChunkIterationHelper { +public: + DUCKDB_API ColumnDataChunkIterationHelper(const ColumnDataCollection &collection, vector column_ids); + +private: + const ColumnDataCollection &collection; + vector column_ids; + +private: + class ColumnDataChunkIterator; + + class ColumnDataChunkIterator { + public: + DUCKDB_API explicit ColumnDataChunkIterator(const ColumnDataCollection *collection_p, + vector column_ids); + + const ColumnDataCollection *collection; + ColumnDataScanState scan_state; + shared_ptr scan_chunk; + idx_t row_index; + + public: + DUCKDB_API void Next(); + + DUCKDB_API ColumnDataChunkIterator &operator++(); + DUCKDB_API bool operator!=(const ColumnDataChunkIterator &other) const; + DUCKDB_API DataChunk &operator*() const; + }; + +public: + ColumnDataChunkIterator begin() { + return ColumnDataChunkIterator(&collection, column_ids); + } + ColumnDataChunkIterator end() { + return ColumnDataChunkIterator(nullptr, vector()); + } +}; + +class ColumnDataRowIterationHelper { +public: + DUCKDB_API ColumnDataRowIterationHelper(const ColumnDataCollection &collection); + +private: + const ColumnDataCollection &collection; + +private: + class ColumnDataRowIterator; + + class ColumnDataRowIterator { + public: + DUCKDB_API explicit ColumnDataRowIterator(const ColumnDataCollection *collection_p); + + const ColumnDataCollection *collection; + ColumnDataScanState scan_state; + shared_ptr scan_chunk; + ColumnDataRow current_row; + + public: + void Next(); + + DUCKDB_API ColumnDataRowIterator &operator++(); + DUCKDB_API bool operator!=(const ColumnDataRowIterator &other) const; + DUCKDB_API const ColumnDataRow &operator*() const; + }; + +public: + DUCKDB_API ColumnDataRowIterator begin(); + DUCKDB_API ColumnDataRowIterator end(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_segment.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_segment.hpp new file mode 100644 index 00000000..5f69e211 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_segment.hpp @@ -0,0 +1,145 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/column/column_data_collection_segment.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_allocator.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" + +namespace duckdb { + +struct VectorChildIndex { + explicit VectorChildIndex(idx_t index = DConstants::INVALID_INDEX) : index(index) { + } + + idx_t index; + + bool IsValid() { + return index != DConstants::INVALID_INDEX; + } +}; + +struct VectorDataIndex { + explicit VectorDataIndex(idx_t index = DConstants::INVALID_INDEX) : index(index) { + } + + idx_t index; + + bool IsValid() { + return index != DConstants::INVALID_INDEX; + } +}; + +struct SwizzleMetaData { + SwizzleMetaData(VectorDataIndex child_index_p, uint16_t offset_p, uint16_t count_p) + : child_index(child_index_p), offset(offset_p), count(count_p) { + } + //! Index of block storing heap + VectorDataIndex child_index; + //! Offset into the string_t vector + uint16_t offset; + //! Number of strings starting at 'offset' that have strings stored in the block with index 'child_index' + uint16_t count; +}; + +struct VectorMetaData { + //! Where the vector data lives + uint32_t block_id; + uint32_t offset; + //! The number of entries present in this vector + uint16_t count; + //! Meta data about string pointers + vector swizzle_data; + + //! Child data of this vector (used only for lists and structs) + //! Note: child indices are stored with one layer of indirection + //! The child_index here refers to the `child_indices` array in the ColumnDataCollectionSegment + //! The entry in the child_indices array then refers to the actual `VectorMetaData` index + //! In case of structs, the child_index refers to the FIRST child in the `child_indices` array + //! Subsequent children are stored consecutively, i.e. + //! first child: segment.child_indices[child_index + 0] + //! nth child : segment.child_indices[child_index + (n - 1)] + VectorChildIndex child_index; + //! Next vector entry (in case there is more data - used only in case of children of lists) + VectorDataIndex next_data; +}; + +struct ChunkMetaData { + //! The set of vectors of the chunk + vector vector_data; + //! The block ids referenced by the chunk + unordered_set block_ids; + //! The number of entries in the chunk + uint16_t count; +}; + +class ColumnDataCollectionSegment { +public: + ColumnDataCollectionSegment(shared_ptr allocator, vector types_p); + + shared_ptr allocator; + //! The types of the chunks + vector types; + //! The number of entries in the internal column data + idx_t count; + //! Set of chunk meta data + vector chunk_data; + //! Set of vector meta data + vector vector_data; + //! The set of child indices + vector child_indices; + //! The string heap for the column data collection (only used for IN_MEMORY_ALLOCATOR) + shared_ptr heap; + +public: + void AllocateNewChunk(); + //! Allocate space for a vector of a specific type in the segment + VectorDataIndex AllocateVector(const LogicalType &type, ChunkMetaData &chunk_data, + ChunkManagementState *chunk_state = nullptr, + VectorDataIndex prev_index = VectorDataIndex()); + //! Allocate space for a vector during append + VectorDataIndex AllocateVector(const LogicalType &type, ChunkMetaData &chunk_data, + ColumnDataAppendState &append_state, VectorDataIndex prev_index = VectorDataIndex()); + //! Allocate space for string data during append (BUFFER_MANAGER_ALLOCATOR only) + VectorDataIndex AllocateStringHeap(idx_t size, ChunkMetaData &chunk_meta, ColumnDataAppendState &append_state, + VectorDataIndex prev_index = VectorDataIndex()); + + void InitializeChunkState(idx_t chunk_index, ChunkManagementState &state); + void ReadChunk(idx_t chunk_index, ChunkManagementState &state, DataChunk &chunk, + const vector &column_ids); + + idx_t ReadVector(ChunkManagementState &state, VectorDataIndex vector_index, Vector &result); + + VectorDataIndex GetChildIndex(VectorChildIndex index, idx_t child_entry = 0); + VectorChildIndex AddChildIndex(VectorDataIndex index); + VectorChildIndex ReserveChildren(idx_t child_count); + void SetChildIndex(VectorChildIndex base_idx, idx_t child_number, VectorDataIndex index); + + VectorMetaData &GetVectorData(VectorDataIndex index) { + D_ASSERT(index.index < vector_data.size()); + return vector_data[index.index]; + } + + idx_t ChunkCount() const; + idx_t SizeInBytes() const; + + void FetchChunk(idx_t chunk_idx, DataChunk &result); + void FetchChunk(idx_t chunk_idx, DataChunk &result, const vector &column_ids); + + void Verify(); + + static idx_t GetDataSize(idx_t type_size); + static validity_t *GetValidityPointer(data_ptr_t base_ptr, idx_t type_size); + +private: + idx_t ReadVectorInternal(ChunkManagementState &state, VectorDataIndex vector_index, Vector &result); + VectorDataIndex AllocateVectorInternal(const LogicalType &type, ChunkMetaData &chunk_meta, + ChunkManagementState *chunk_state); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_consumer.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_consumer.hpp new file mode 100644 index 00000000..d056cf08 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_consumer.hpp @@ -0,0 +1,86 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/column/column_data_consumer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/types/column/column_data_collection_segment.hpp" +#include "duckdb/common/types/column/column_data_scan_states.hpp" + +namespace duckdb { + +struct ColumnDataConsumerScanState { + ColumnDataAllocator *allocator = nullptr; + ChunkManagementState current_chunk_state; + idx_t chunk_index; +}; + +//! ColumnDataConsumer can scan a ColumnDataCollection, and consume it in the process, i.e., read blocks are deleted +class ColumnDataConsumer { +public: + struct ChunkReference { + public: + ChunkReference(ColumnDataCollectionSegment *segment_p, uint32_t chunk_index_p); + uint32_t GetMinimumBlockID() const; + friend bool operator<(const ChunkReference &lhs, const ChunkReference &rhs) { + // Sort by allocator first + if (lhs.segment->allocator.get() != rhs.segment->allocator.get()) { + return lhs.segment->allocator.get() < rhs.segment->allocator.get(); + } + // Then by minimum block id + return lhs.GetMinimumBlockID() < rhs.GetMinimumBlockID(); + } + + public: + ColumnDataCollectionSegment *segment; + uint32_t chunk_index_in_segment; + }; + +public: + ColumnDataConsumer(ColumnDataCollection &collection, vector column_ids); + + idx_t Count() const { + return collection.Count(); + } + + idx_t ChunkCount() const { + return chunk_count; + } + +public: + //! Initialize the scan of the ColumnDataCollection + void InitializeScan(); + //! Assign a chunk to the scan state + bool AssignChunk(ColumnDataConsumerScanState &state); + //! Scan the assigned chunk + void ScanChunk(ColumnDataConsumerScanState &state, DataChunk &chunk) const; + //! Indicate that scanning the chunk is done + void FinishChunk(ColumnDataConsumerScanState &state); + +private: + void ConsumeChunks(idx_t delete_index_start, idx_t delete_index_end); + +private: + mutex lock; + //! The collection being scanned + ColumnDataCollection &collection; + //! The column ids to scan + vector column_ids; + //! The number of chunk references + idx_t chunk_count; + //! The chunks (in order) to be scanned + vector chunk_references; + //! Current index into "chunks" + idx_t current_chunk_index; + //! Chunks currently in progress + unordered_set chunks_in_progress; + //! The data has been consumed up to this chunk index + idx_t chunk_delete_index; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp new file mode 100644 index 00000000..c809520c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp @@ -0,0 +1,82 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/column/column_data_scan_states.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/mutex.hpp" + +namespace duckdb { + +enum class ColumnDataAllocatorType : uint8_t { + //! Use a buffer manager to allocate large chunks of memory that vectors then use + BUFFER_MANAGER_ALLOCATOR, + //! Use an in-memory allocator, allocating data for every chunk + //! This causes the column data collection to allocate blocks that are not tied to a buffer manager + IN_MEMORY_ALLOCATOR, + //! Use a buffer manager to allocate vectors, but use a StringHeap for strings + HYBRID +}; + +enum class ColumnDataScanProperties : uint8_t { + INVALID, + //! Allow zero copy scans - this introduces a dependency on the resulting vector on the scan state of the column + //! data collection, which means vectors might not be valid anymore after the next chunk is scanned. + ALLOW_ZERO_COPY, + //! Disallow zero-copy scans, always copying data into the target vector + //! As a result, data scanned will be valid even after the column data collection is destroyed + DISALLOW_ZERO_COPY +}; + +struct ChunkManagementState { + unordered_map handles; + ColumnDataScanProperties properties = ColumnDataScanProperties::INVALID; +}; + +struct ColumnDataAppendState { + ChunkManagementState current_chunk_state; + vector vector_data; +}; + +struct ColumnDataScanState { + ChunkManagementState current_chunk_state; + idx_t segment_index; + idx_t chunk_index; + idx_t current_row_index; + idx_t next_row_index; + ColumnDataScanProperties properties; + vector column_ids; +}; + +struct ColumnDataParallelScanState { + ColumnDataScanState scan_state; + mutex lock; +}; + +struct ColumnDataLocalScanState { + ChunkManagementState current_chunk_state; + idx_t current_segment_index = DConstants::INVALID_INDEX; + idx_t current_row_index; +}; + +class ColumnDataRow { +public: + ColumnDataRow(DataChunk &chunk, idx_t row_index, idx_t base_index); + + DataChunk &chunk; + idx_t row_index; + idx_t base_index; + +public: + Value GetValue(idx_t column_index) const; + idx_t RowIndex() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/partitioned_column_data.hpp b/src/duckdb/src/include/duckdb/common/types/column/partitioned_column_data.hpp new file mode 100644 index 00000000..3defecb0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/column/partitioned_column_data.hpp @@ -0,0 +1,125 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/column/partitioned_column_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/perfect_map_set.hpp" +#include "duckdb/common/types/column/column_data_allocator.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" + +namespace duckdb { + +//! Local state for parallel partitioning +struct PartitionedColumnDataAppendState { +public: + PartitionedColumnDataAppendState() : partition_indices(LogicalType::UBIGINT) { + } + +public: + Vector partition_indices; + SelectionVector partition_sel; + perfect_map_t partition_entries; + DataChunk slice_chunk; + + vector> partition_buffers; + vector> partition_append_states; +}; + +enum class PartitionedColumnDataType : uint8_t { + INVALID, + //! Radix partitioning on a hash column + RADIX, + //! Hive-style multi-field partitioning + HIVE +}; + +//! Shared allocators for parallel partitioning +struct PartitionColumnDataAllocators { + mutex lock; + vector> allocators; +}; + +//! PartitionedColumnData represents partitioned columnar data, which serves as an interface for different types of +//! partitioning, e.g., radix, hive +class PartitionedColumnData { +public: + unique_ptr CreateShared(); + virtual ~PartitionedColumnData(); + +public: + //! Initializes a local state for parallel partitioning that can be merged into this PartitionedColumnData + void InitializeAppendState(PartitionedColumnDataAppendState &state) const; + //! Appends a DataChunk to this PartitionedColumnData + void Append(PartitionedColumnDataAppendState &state, DataChunk &input); + //! Flushes any remaining data in the append state into this PartitionedColumnData + void FlushAppendState(PartitionedColumnDataAppendState &state); + //! Combine another PartitionedColumnData into this PartitionedColumnData + void Combine(PartitionedColumnData &other); + //! Get the partitions in this PartitionedColumnData + vector> &GetPartitions(); + +protected: + //===--------------------------------------------------------------------===// + // Partitioning type implementation interface + //===--------------------------------------------------------------------===// + //! Size of the buffers in the append states for this type of partitioning (default 128) + virtual idx_t BufferSize() const { + return MinValue(128, STANDARD_VECTOR_SIZE); + } + //! Initialize a PartitionedColumnDataAppendState for this type of partitioning (optional) + virtual void InitializeAppendStateInternal(PartitionedColumnDataAppendState &state) const { + } + //! Compute the partition indices for this type of partitioning for the input DataChunk and store them in the + //! `partition_data` of the local state. If this type creates partitions on the fly (for, e.g., hive), this + //! function is also in charge of creating new partitions and mapping the input data to a partition index + virtual void ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { + throw NotImplementedException("ComputePartitionIndices for this type of PartitionedColumnData"); + } + +protected: + //! PartitionedColumnData can only be instantiated by derived classes + PartitionedColumnData(PartitionedColumnDataType type, ClientContext &context, vector types); + PartitionedColumnData(const PartitionedColumnData &other); + + //! If the buffer is half full, we append to the partition + inline idx_t HalfBufferSize() const { + D_ASSERT(IsPowerOfTwo(BufferSize())); + return BufferSize() / 2; + } + //! Create a new shared allocator + void CreateAllocator(); + //! Create a collection for a specific a partition + unique_ptr CreatePartitionCollection(idx_t partition_index) const { + return make_uniq(allocators->allocators[partition_index], types); + } + //! Create a DataChunk used for buffering appends to the partition + unique_ptr CreatePartitionBuffer() const; + +protected: + PartitionedColumnDataType type; + ClientContext &context; + vector types; + + mutex lock; + shared_ptr allocators; + vector> partitions; + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/conflict_manager.hpp b/src/duckdb/src/include/duckdb/common/types/conflict_manager.hpp new file mode 100644 index 00000000..59024b8a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/conflict_manager.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class Index; +class ConflictInfo; + +enum class ConflictManagerMode : uint8_t { + SCAN, // gather conflicts without throwing + THROW // throw on the conflicts that were not found during the scan +}; + +enum class LookupResultType : uint8_t { LOOKUP_MISS, LOOKUP_HIT, LOOKUP_NULL }; + +class ConflictManager { +public: + ConflictManager(VerifyExistenceType lookup_type, idx_t input_size, + optional_ptr conflict_info = nullptr); + +public: + void SetIndexCount(idx_t count); + // These methods return a boolean indicating whether we should throw or not + bool AddMiss(idx_t chunk_index); + bool AddHit(idx_t chunk_index, row_t row_id); + bool AddNull(idx_t chunk_index); + VerifyExistenceType LookupType() const; + // This should be called before using the conflicts selection vector + void Finalize(); + idx_t ConflictCount() const; + const ManagedSelection &Conflicts() const; + Vector &RowIds(); + const ConflictInfo &GetConflictInfo() const; + void FinishLookup(); + void SetMode(ConflictManagerMode mode); + +private: + bool IsConflict(LookupResultType type); + const unordered_set &InternalConflictSet() const; + Vector &InternalRowIds(); + Vector &InternalIntermediate(); + ManagedSelection &InternalSelection(); + bool SingleIndexTarget() const; + bool ShouldThrow(idx_t chunk_index) const; + bool ShouldIgnoreNulls() const; + void AddConflictInternal(idx_t chunk_index, row_t row_id); + void AddToConflictSet(idx_t chunk_index); + +private: + VerifyExistenceType lookup_type; + idx_t input_size; + optional_ptr conflict_info; + idx_t index_count; + bool finalized = false; + ManagedSelection conflicts; + unique_ptr row_ids; + // Used to check if a given conflict is part of the conflict target or not + unique_ptr> conflict_set; + // Contains 'input_size' booleans, indicating if a given index in the input chunk has a conflict + unique_ptr intermediate_vector; + // Mapping from chunk_index to row_id + vector row_id_map; + // Whether we have already found the one conflict target we're interested in + bool single_index_finished = false; + ConflictManagerMode mode; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/constraint_conflict_info.hpp b/src/duckdb/src/include/duckdb/common/types/constraint_conflict_info.hpp new file mode 100644 index 00000000..a00bd3ee --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/constraint_conflict_info.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +class Index; + +class ConflictInfo { +public: + ConflictInfo(const unordered_set &column_ids, bool only_check_unique = true) + : column_ids(column_ids), only_check_unique(only_check_unique) { + } + const unordered_set &column_ids; + +public: + bool ConflictTargetMatches(Index &index) const; + void VerifyAllConflictsMeetCondition() const; + +public: + bool only_check_unique = true; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/data_chunk.hpp b/src/duckdb/src/include/duckdb/common/types/data_chunk.hpp new file mode 100644 index 00000000..fdb44cdd --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/data_chunk.hpp @@ -0,0 +1,170 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/data_chunk.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/winapi.hpp" + +namespace duckdb { +class Allocator; +class ClientContext; +class ExecutionContext; +class VectorCache; +class Serializer; +class Deserializer; + +//! A Data Chunk represents a set of vectors. +/*! + The data chunk class is the intermediate representation used by the + execution engine of DuckDB. It effectively represents a subset of a relation. + It holds a set of vectors that all have the same length. + + DataChunk is initialized using the DataChunk::Initialize function by + providing it with a vector of TypeIds for the Vector members. By default, + this function will also allocate a chunk of memory in the DataChunk for the + vectors and all the vectors will be referencing vectors to the data owned by + the chunk. The reason for this behavior is that the underlying vectors can + become referencing vectors to other chunks as well (i.e. in the case an + operator does not alter the data, such as a Filter operator which only adds a + selection vector). + + In addition to holding the data of the vectors, the DataChunk also owns the + selection vector that underlying vectors can point to. +*/ +class DataChunk { +public: + //! Creates an empty DataChunk + DUCKDB_API DataChunk(); + DUCKDB_API ~DataChunk(); + + //! The vectors owned by the DataChunk. + vector data; + +public: + inline idx_t size() const { // NOLINT + return count; + } + inline idx_t ColumnCount() const { + return data.size(); + } + inline void SetCardinality(idx_t count_p) { + D_ASSERT(count_p <= capacity); + this->count = count_p; + } + inline void SetCardinality(const DataChunk &other) { + SetCardinality(other.size()); + } + inline void SetCapacity(idx_t capacity_p) { + this->capacity = capacity_p; + } + inline void SetCapacity(const DataChunk &other) { + SetCapacity(other.capacity); + } + + DUCKDB_API Value GetValue(idx_t col_idx, idx_t index) const; + DUCKDB_API void SetValue(idx_t col_idx, idx_t index, const Value &val); + + //! Returns true if all vectors in the DataChunk are constant + DUCKDB_API bool AllConstant() const; + + //! Set the DataChunk to reference another data chunk + DUCKDB_API void Reference(DataChunk &chunk); + //! Set the DataChunk to own the data of data chunk, destroying the other chunk in the process + DUCKDB_API void Move(DataChunk &chunk); + + //! Initializes the DataChunk with the specified types to an empty DataChunk + //! This will create one vector of the specified type for each LogicalType in the + //! types list. The vector will be referencing vector to the data owned by + //! the DataChunk. + DUCKDB_API void Initialize(Allocator &allocator, const vector &types, + idx_t capacity = STANDARD_VECTOR_SIZE); + DUCKDB_API void Initialize(ClientContext &context, const vector &types, + idx_t capacity = STANDARD_VECTOR_SIZE); + //! Initializes an empty DataChunk with the given types. The vectors will *not* have any data allocated for them. + DUCKDB_API void InitializeEmpty(const vector &types); + + DUCKDB_API void InitializeEmpty(vector::const_iterator begin, vector::const_iterator end); + DUCKDB_API void Initialize(Allocator &allocator, vector::const_iterator begin, + vector::const_iterator end, idx_t capacity = STANDARD_VECTOR_SIZE); + DUCKDB_API void Initialize(ClientContext &context, vector::const_iterator begin, + vector::const_iterator end, idx_t capacity = STANDARD_VECTOR_SIZE); + + //! Append the other DataChunk to this one. The column count and types of + //! the two DataChunks have to match exactly. Throws an exception if there + //! is not enough space in the chunk and resize is not allowed. + DUCKDB_API void Append(const DataChunk &other, bool resize = false, SelectionVector *sel = nullptr, + idx_t count = 0); + + //! Destroy all data and columns owned by this DataChunk + DUCKDB_API void Destroy(); + + //! Copies the data from this vector to another vector. + DUCKDB_API void Copy(DataChunk &other, idx_t offset = 0) const; + DUCKDB_API void Copy(DataChunk &other, const SelectionVector &sel, const idx_t source_count, + const idx_t offset = 0) const; + + //! Splits the DataChunk in two + DUCKDB_API void Split(DataChunk &other, idx_t split_idx); + + //! Fuses a DataChunk onto the right of this one, and destroys the other. Inverse of Split. + DUCKDB_API void Fuse(DataChunk &other); + + //! Makes this DataChunk reference the specified columns in the other DataChunk + DUCKDB_API void ReferenceColumns(DataChunk &other, const vector &column_ids); + + //! Turn all the vectors from the chunk into flat vectors + DUCKDB_API void Flatten(); + + // FIXME: this is DUCKDB_API, might need conversion back to regular unique ptr? + DUCKDB_API unsafe_unique_array ToUnifiedFormat(); + + DUCKDB_API void Slice(const SelectionVector &sel_vector, idx_t count); + + //! Slice all Vectors from other.data[i] to data[i + 'col_offset'] + //! Turning all Vectors into Dictionary Vectors, using 'sel' + DUCKDB_API void Slice(DataChunk &other, const SelectionVector &sel, idx_t count, idx_t col_offset = 0); + + //! Resets the DataChunk to its state right after the DataChunk::Initialize + //! function was called. This sets the count to 0, and resets each member + //! Vector to point back to the data owned by this DataChunk. + DUCKDB_API void Reset(); + + DUCKDB_API void Serialize(Serializer &serializer) const; + DUCKDB_API void Deserialize(Deserializer &source); + + //! Hashes the DataChunk to the target vector + DUCKDB_API void Hash(Vector &result); + //! Hashes specific vectors of the DataChunk to the target vector + DUCKDB_API void Hash(vector &column_ids, Vector &result); + + //! Returns a list of types of the vectors of this data chunk + DUCKDB_API vector GetTypes(); + + //! Converts this DataChunk to a printable string representation + DUCKDB_API string ToString() const; + DUCKDB_API void Print() const; + + DataChunk(const DataChunk &) = delete; + + //! Verify that the DataChunk is in a consistent, not corrupt state. DEBUG + //! FUNCTION ONLY! + DUCKDB_API void Verify(); + +private: + //! The amount of tuples stored in the data chunk + idx_t count; + //! The amount of tuples that can be stored in the data chunk + idx_t capacity; + //! Vector caches, used to store data when ::Initialize is called + vector vector_caches; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/date.hpp b/src/duckdb/src/include/duckdb/common/types/date.hpp new file mode 100644 index 00000000..8e9bfac8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/date.hpp @@ -0,0 +1,224 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/date.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/limits.hpp" + +#include + +namespace duckdb { + +struct timestamp_t; + +//! Type used to represent dates (days since 1970-01-01) +struct date_t { // NOLINT + int32_t days; + + date_t() = default; + explicit inline date_t(int32_t days_p) : days(days_p) { + } + + // explicit conversion + explicit inline operator int32_t() const { + return days; + } + + // comparison operators + inline bool operator==(const date_t &rhs) const { + return days == rhs.days; + }; + inline bool operator!=(const date_t &rhs) const { + return days != rhs.days; + }; + inline bool operator<=(const date_t &rhs) const { + return days <= rhs.days; + }; + inline bool operator<(const date_t &rhs) const { + return days < rhs.days; + }; + inline bool operator>(const date_t &rhs) const { + return days > rhs.days; + }; + inline bool operator>=(const date_t &rhs) const { + return days >= rhs.days; + }; + + // arithmetic operators + inline date_t operator+(const int32_t &days) const { + return date_t(this->days + days); + }; + inline date_t operator-(const int32_t &days) const { + return date_t(this->days - days); + }; + + // in-place operators + inline date_t &operator+=(const int32_t &days) { + this->days += days; + return *this; + }; + inline date_t &operator-=(const int32_t &days) { + this->days -= days; + return *this; + }; + + // special values + static inline date_t infinity() { // NOLINT + return date_t(NumericLimits::Maximum()); + } // NOLINT + static inline date_t ninfinity() { // NOLINT + return date_t(-NumericLimits::Maximum()); + } // NOLINT + static inline date_t epoch() { // NOLINT + return date_t(0); + } // NOLINT +}; + +//! The Date class is a static class that holds helper functions for the Date type. +class Date { +public: + static const char *PINF; // NOLINT + static const char *NINF; // NOLINT + static const char *EPOCH; // NOLINT + + static const string_t MONTH_NAMES[12]; + static const string_t MONTH_NAMES_ABBREVIATED[12]; + static const string_t DAY_NAMES[7]; + static const string_t DAY_NAMES_ABBREVIATED[7]; + static const int32_t NORMAL_DAYS[13]; + static const int32_t CUMULATIVE_DAYS[13]; + static const int32_t LEAP_DAYS[13]; + static const int32_t CUMULATIVE_LEAP_DAYS[13]; + static const int32_t CUMULATIVE_YEAR_DAYS[401]; + static const int8_t MONTH_PER_DAY_OF_YEAR[365]; + static const int8_t LEAP_MONTH_PER_DAY_OF_YEAR[366]; + + // min date is 5877642-06-25 (BC) (-2^31+2) + constexpr static const int32_t DATE_MIN_YEAR = -5877641; + constexpr static const int32_t DATE_MIN_MONTH = 6; + constexpr static const int32_t DATE_MIN_DAY = 25; + // max date is 5881580-07-10 (2^31-2) + constexpr static const int32_t DATE_MAX_YEAR = 5881580; + constexpr static const int32_t DATE_MAX_MONTH = 7; + constexpr static const int32_t DATE_MAX_DAY = 10; + constexpr static const int32_t EPOCH_YEAR = 1970; + + constexpr static const int32_t YEAR_INTERVAL = 400; + constexpr static const int32_t DAYS_PER_YEAR_INTERVAL = 146097; + +public: + //! Convert a string in the format "YYYY-MM-DD" to a date object + DUCKDB_API static date_t FromString(const string &str, bool strict = false); + //! Convert a string in the format "YYYY-MM-DD" to a date object + DUCKDB_API static date_t FromCString(const char *str, idx_t len, bool strict = false); + //! Convert a date object to a string in the format "YYYY-MM-DD" + DUCKDB_API static string ToString(date_t date); + //! Try to convert text in a buffer to a date; returns true if parsing was successful + //! If the date was a "special" value, the special flag will be set. + DUCKDB_API static bool TryConvertDate(const char *buf, idx_t len, idx_t &pos, date_t &result, bool &special, + bool strict = false); + + //! Create a string "YYYY-MM-DD" from a specified (year, month, day) + //! combination + DUCKDB_API static string Format(int32_t year, int32_t month, int32_t day); + + //! Extract the year, month and day from a given date object + DUCKDB_API static void Convert(date_t date, int32_t &out_year, int32_t &out_month, int32_t &out_day); + //! Create a Date object from a specified (year, month, day) combination + DUCKDB_API static date_t FromDate(int32_t year, int32_t month, int32_t day); + DUCKDB_API static bool TryFromDate(int32_t year, int32_t month, int32_t day, date_t &result); + + //! Returns true if (year) is a leap year, and false otherwise + DUCKDB_API static bool IsLeapYear(int32_t year); + + //! Returns true if the specified (year, month, day) combination is a valid + //! date + DUCKDB_API static bool IsValid(int32_t year, int32_t month, int32_t day); + + //! Returns true if the specified date is finite + static inline bool IsFinite(date_t date) { + return date != date_t::infinity() && date != date_t::ninfinity(); + } + + //! The max number of days in a month of a given year + DUCKDB_API static int32_t MonthDays(int32_t year, int32_t month); + + //! Extract the epoch from the date (seconds since 1970-01-01) + DUCKDB_API static int64_t Epoch(date_t date); + //! Extract the epoch from the date (nanoseconds since 1970-01-01) + DUCKDB_API static int64_t EpochNanoseconds(date_t date); + //! Extract the epoch from the date (microseconds since 1970-01-01) + DUCKDB_API static int64_t EpochMicroseconds(date_t date); + //! Extract the epoch from the date (milliseconds since 1970-01-01) + DUCKDB_API static int64_t EpochMilliseconds(date_t date); + //! Convert the epoch (seconds since 1970-01-01) to a date_t + DUCKDB_API static date_t EpochToDate(int64_t epoch); + + //! Extract the number of days since epoch (days since 1970-01-01) + DUCKDB_API static int32_t EpochDays(date_t date); + //! Convert the epoch number of days to a date_t + DUCKDB_API static date_t EpochDaysToDate(int32_t epoch); + + //! Extract year of a date entry + DUCKDB_API static int32_t ExtractYear(date_t date); + //! Extract year of a date entry, but optimized to first try the last year found + DUCKDB_API static int32_t ExtractYear(date_t date, int32_t *last_year); + DUCKDB_API static int32_t ExtractYear(timestamp_t ts, int32_t *last_year); + //! Extract month of a date entry + DUCKDB_API static int32_t ExtractMonth(date_t date); + //! Extract day of a date entry + DUCKDB_API static int32_t ExtractDay(date_t date); + //! Extract the day of the week (1-7) + DUCKDB_API static int32_t ExtractISODayOfTheWeek(date_t date); + //! Extract the day of the year + DUCKDB_API static int32_t ExtractDayOfTheYear(date_t date); + //! Extract the day of the year + DUCKDB_API static int64_t ExtractJulianDay(date_t date); + //! Extract the ISO week number + //! ISO weeks start on Monday and the first week of a year + //! contains January 4 of that year. + //! In the ISO week-numbering system, it is possible for early-January dates + //! to be part of the 52nd or 53rd week of the previous year. + DUCKDB_API static void ExtractISOYearWeek(date_t date, int32_t &year, int32_t &week); + DUCKDB_API static int32_t ExtractISOWeekNumber(date_t date); + DUCKDB_API static int32_t ExtractISOYearNumber(date_t date); + //! Extract the week number as Python handles it. + //! Either Monday or Sunday is the first day of the week, + //! and any date before the first Monday/Sunday returns week 0 + //! This is a bit more consistent because week numbers in a year are always incrementing + DUCKDB_API static int32_t ExtractWeekNumberRegular(date_t date, bool monday_first = true); + //! Returns the date of the monday of the current week. + DUCKDB_API static date_t GetMondayOfCurrentWeek(date_t date); + + //! Helper function to parse two digits from a string (e.g. "30" -> 30, "03" -> 3, "3" -> 3) + DUCKDB_API static bool ParseDoubleDigit(const char *buf, idx_t len, idx_t &pos, int32_t &result); + + DUCKDB_API static string ConversionError(const string &str); + DUCKDB_API static string ConversionError(string_t str); + +private: + static void ExtractYearOffset(int32_t &n, int32_t &year, int32_t &year_offset); +}; + +} // namespace duckdb + +namespace std { + +//! Date +template <> +struct hash { + std::size_t operator()(const duckdb::date_t &k) const { + using std::hash; + return hash()((int32_t)k); + } +}; +} // namespace std diff --git a/src/duckdb/src/include/duckdb/common/types/datetime.hpp b/src/duckdb/src/include/duckdb/common/types/datetime.hpp new file mode 100644 index 00000000..5e13b610 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/datetime.hpp @@ -0,0 +1,153 @@ +#pragma once + +#include "duckdb/common/common.hpp" + +#include + +namespace duckdb { + +//! Type used to represent time (microseconds) +struct dtime_t { // NOLINT + int64_t micros; + + dtime_t() = default; + explicit inline dtime_t(int64_t micros_p) : micros(micros_p) { + } + inline dtime_t &operator=(int64_t micros_p) { + micros = micros_p; + return *this; + } + + // explicit conversion + explicit inline operator int64_t() const { + return micros; + } + explicit inline operator double() const { + return micros; + } + + // comparison operators + inline bool operator==(const dtime_t &rhs) const { + return micros == rhs.micros; + }; + inline bool operator!=(const dtime_t &rhs) const { + return micros != rhs.micros; + }; + inline bool operator<=(const dtime_t &rhs) const { + return micros <= rhs.micros; + }; + inline bool operator<(const dtime_t &rhs) const { + return micros < rhs.micros; + }; + inline bool operator>(const dtime_t &rhs) const { + return micros > rhs.micros; + }; + inline bool operator>=(const dtime_t &rhs) const { + return micros >= rhs.micros; + }; + + // arithmetic operators + inline dtime_t operator+(const int64_t µs) const { + return dtime_t(this->micros + micros); + }; + inline dtime_t operator+(const double µs) const { + return dtime_t(this->micros + int64_t(micros)); + }; + inline dtime_t operator-(const int64_t µs) const { + return dtime_t(this->micros - micros); + }; + inline dtime_t operator*(const idx_t &copies) const { + return dtime_t(this->micros * copies); + }; + inline dtime_t operator/(const idx_t &copies) const { + return dtime_t(this->micros / copies); + }; + inline int64_t operator-(const dtime_t &other) const { + return this->micros - other.micros; + }; + + // in-place operators + inline dtime_t &operator+=(const int64_t µs) { + this->micros += micros; + return *this; + }; + inline dtime_t &operator-=(const int64_t µs) { + this->micros -= micros; + return *this; + }; + inline dtime_t &operator+=(const dtime_t &other) { + this->micros += other.micros; + return *this; + }; + + // special values + static inline dtime_t allballs() { // NOLINT + return dtime_t(0); + } // NOLINT +}; + +struct dtime_tz_t { // NOLINT + static constexpr const int TIME_BITS = 40; + static constexpr const int OFFSET_BITS = 24; + static constexpr const uint64_t OFFSET_MASK = ~uint64_t(0) >> TIME_BITS; + static constexpr const int32_t MAX_OFFSET = 1559 * 60 * 60; + static constexpr const int32_t MIN_OFFSET = -MAX_OFFSET; + + uint64_t bits; + + dtime_tz_t() = default; + + inline dtime_tz_t(dtime_t t, int32_t offset) + : bits((uint64_t(t.micros) << OFFSET_BITS) | uint64_t(offset + MAX_OFFSET)) { + } + + inline dtime_t time() const { // NOLINT + return dtime_t(bits >> OFFSET_BITS); + } + + inline int32_t offset() const { // NOLINT + return int32_t(bits & OFFSET_MASK) - MAX_OFFSET; + } + + // comparison operators + inline bool operator==(const dtime_tz_t &rhs) const { + return bits == rhs.bits; + }; + inline bool operator!=(const dtime_tz_t &rhs) const { + return bits != rhs.bits; + }; + inline bool operator<=(const dtime_tz_t &rhs) const { + return bits <= rhs.bits; + }; + inline bool operator<(const dtime_tz_t &rhs) const { + return bits < rhs.bits; + }; + inline bool operator>(const dtime_tz_t &rhs) const { + return bits > rhs.bits; + }; + inline bool operator>=(const dtime_tz_t &rhs) const { + return bits >= rhs.bits; + }; +}; + +} // namespace duckdb + +namespace std { + +//! Time +template <> +struct hash { + std::size_t operator()(const duckdb::dtime_t &k) const { + using std::hash; + return hash()((int64_t)k); + } +}; + +template <> +struct hash { + std::size_t operator()(const duckdb::dtime_tz_t &k) const { + using std::hash; + return hash()(k.bits); + } +}; +} // namespace std diff --git a/src/duckdb/src/include/duckdb/common/types/decimal.hpp b/src/duckdb/src/include/duckdb/common/types/decimal.hpp new file mode 100644 index 00000000..e08e544a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/decimal.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/decimal.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" + +namespace duckdb { + +template +struct DecimalWidth {}; + +template <> +struct DecimalWidth { + static constexpr uint8_t max = 4; +}; + +template <> +struct DecimalWidth { + static constexpr uint8_t max = 9; +}; + +template <> +struct DecimalWidth { + static constexpr uint8_t max = 18; +}; + +template <> +struct DecimalWidth { + static constexpr uint8_t max = 38; +}; + +//! The Decimal class is a static class that holds helper functions for the Decimal type +class Decimal { +public: + static constexpr uint8_t MAX_WIDTH_INT16 = DecimalWidth::max; + static constexpr uint8_t MAX_WIDTH_INT32 = DecimalWidth::max; + static constexpr uint8_t MAX_WIDTH_INT64 = DecimalWidth::max; + static constexpr uint8_t MAX_WIDTH_INT128 = DecimalWidth::max; + static constexpr uint8_t MAX_WIDTH_DECIMAL = MAX_WIDTH_INT128; + +public: + static string ToString(int16_t value, uint8_t width, uint8_t scale); + static string ToString(int32_t value, uint8_t width, uint8_t scale); + static string ToString(int64_t value, uint8_t width, uint8_t scale); + static string ToString(hugeint_t value, uint8_t width, uint8_t scale); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/hash.hpp b/src/duckdb/src/include/duckdb/common/types/hash.hpp new file mode 100644 index 00000000..337705a8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/hash.hpp @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/hash.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +struct string_t; +struct interval_t; + +// efficient hash function that maximizes the avalanche effect and minimizes +// bias +// see: https://nullprogram.com/blog/2018/07/31/ + +inline hash_t murmurhash64(uint64_t x) { + x ^= x >> 32; + x *= 0xd6e8feb86659fd93U; + x ^= x >> 32; + x *= 0xd6e8feb86659fd93U; + x ^= x >> 32; + return x; +} + +inline hash_t murmurhash32(uint32_t x) { + return murmurhash64(x); +} + +template +hash_t Hash(T value) { + return murmurhash32(value); +} + +//! Combine two hashes by XORing them +inline hash_t CombineHash(hash_t left, hash_t right) { + return left ^ right; +} + +template <> +DUCKDB_API hash_t Hash(uint64_t val); +template <> +DUCKDB_API hash_t Hash(int64_t val); +template <> +DUCKDB_API hash_t Hash(hugeint_t val); +template <> +DUCKDB_API hash_t Hash(float val); +template <> +DUCKDB_API hash_t Hash(double val); +template <> +DUCKDB_API hash_t Hash(const char *val); +template <> +DUCKDB_API hash_t Hash(char *val); +template <> +DUCKDB_API hash_t Hash(string_t val); +template <> +DUCKDB_API hash_t Hash(interval_t val); +DUCKDB_API hash_t Hash(const char *val, size_t size); +DUCKDB_API hash_t Hash(uint8_t *val, size_t size); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp new file mode 100644 index 00000000..babe5c8b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp @@ -0,0 +1,164 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/hugeint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/type_util.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +//! The Hugeint class contains static operations for the INT128 type +class Hugeint { +public: + //! Convert a hugeint object to a string + static string ToString(hugeint_t input); + + template + DUCKDB_API static bool TryCast(hugeint_t input, T &result); + + template + static T Cast(hugeint_t input) { + T result = 0; + TryCast(input, result); + return result; + } + + template + static bool TryConvert(T value, hugeint_t &result); + + template + static hugeint_t Convert(T value) { + hugeint_t result; + if (!TryConvert(value, result)) { // LCOV_EXCL_START + throw ValueOutOfRangeException(double(value), GetTypeId(), GetTypeId()); + } // LCOV_EXCL_STOP + return result; + } + + static void NegateInPlace(hugeint_t &input) { + if (input.upper == NumericLimits::Minimum() && input.lower == 0) { + throw OutOfRangeException("HUGEINT is out of range"); + } + input.lower = NumericLimits::Maximum() - input.lower + 1; + input.upper = -1 - input.upper + (input.lower == 0); + } + static hugeint_t Negate(hugeint_t input) { + NegateInPlace(input); + return input; + } + + static bool TryMultiply(hugeint_t lhs, hugeint_t rhs, hugeint_t &result); + + static hugeint_t Add(hugeint_t lhs, hugeint_t rhs); + static hugeint_t Subtract(hugeint_t lhs, hugeint_t rhs); + static hugeint_t Multiply(hugeint_t lhs, hugeint_t rhs); + static hugeint_t Divide(hugeint_t lhs, hugeint_t rhs); + static hugeint_t Modulo(hugeint_t lhs, hugeint_t rhs); + + // DivMod -> returns the result of the division (lhs / rhs), and fills up the remainder + static hugeint_t DivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &remainder); + // DivMod but lhs MUST be positive, and rhs is a uint64_t + static hugeint_t DivModPositive(hugeint_t lhs, uint64_t rhs, uint64_t &remainder); + + static bool AddInPlace(hugeint_t &lhs, hugeint_t rhs); + static bool SubtractInPlace(hugeint_t &lhs, hugeint_t rhs); + + // comparison operators + // note that everywhere here we intentionally use bitwise ops + // this is because they seem to be consistently much faster (benchmarked on a Macbook Pro) + static bool Equals(hugeint_t lhs, hugeint_t rhs) { + int lower_equals = lhs.lower == rhs.lower; + int upper_equals = lhs.upper == rhs.upper; + return lower_equals & upper_equals; + } + static bool NotEquals(hugeint_t lhs, hugeint_t rhs) { + int lower_not_equals = lhs.lower != rhs.lower; + int upper_not_equals = lhs.upper != rhs.upper; + return lower_not_equals | upper_not_equals; + } + static bool GreaterThan(hugeint_t lhs, hugeint_t rhs) { + int upper_bigger = lhs.upper > rhs.upper; + int upper_equal = lhs.upper == rhs.upper; + int lower_bigger = lhs.lower > rhs.lower; + return upper_bigger | (upper_equal & lower_bigger); + } + static bool GreaterThanEquals(hugeint_t lhs, hugeint_t rhs) { + int upper_bigger = lhs.upper > rhs.upper; + int upper_equal = lhs.upper == rhs.upper; + int lower_bigger_equals = lhs.lower >= rhs.lower; + return upper_bigger | (upper_equal & lower_bigger_equals); + } + static bool LessThan(hugeint_t lhs, hugeint_t rhs) { + int upper_smaller = lhs.upper < rhs.upper; + int upper_equal = lhs.upper == rhs.upper; + int lower_smaller = lhs.lower < rhs.lower; + return upper_smaller | (upper_equal & lower_smaller); + } + static bool LessThanEquals(hugeint_t lhs, hugeint_t rhs) { + int upper_smaller = lhs.upper < rhs.upper; + int upper_equal = lhs.upper == rhs.upper; + int lower_smaller_equals = lhs.lower <= rhs.lower; + return upper_smaller | (upper_equal & lower_smaller_equals); + } + static const hugeint_t POWERS_OF_TEN[40]; +}; + +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, int8_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, int16_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, int32_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, int64_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, uint8_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, uint16_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, uint32_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, uint64_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, hugeint_t &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, float &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, double &result); +template <> +DUCKDB_API bool Hugeint::TryCast(hugeint_t input, long double &result); + +template <> +bool Hugeint::TryConvert(int8_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(int16_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(int32_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(int64_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(uint8_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(uint16_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(uint32_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(uint64_t value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(float value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(double value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(long double value, hugeint_t &result); +template <> +bool Hugeint::TryConvert(const char *value, hugeint_t &result); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/hyperloglog.hpp b/src/duckdb/src/include/duckdb/common/types/hyperloglog.hpp new file mode 100644 index 00000000..c46ce72a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/hyperloglog.hpp @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/hyperloglog.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/types/vector.hpp" +#include "hyperloglog.hpp" + +namespace duckdb_hll { +struct robj; +} + +namespace duckdb { + +enum class HLLStorageType : uint8_t { UNCOMPRESSED = 1 }; + +class Serializer; +class Deserializer; + +//! The HyperLogLog class holds a HyperLogLog counter for approximate cardinality counting +class HyperLogLog { +public: + HyperLogLog(); + ~HyperLogLog(); + // implicit copying of HyperLogLog is not allowed + HyperLogLog(const HyperLogLog &) = delete; + + //! Adds an element of the specified size to the HyperLogLog counter + void Add(data_ptr_t element, idx_t size); + //! Return the count of this HyperLogLog counter + idx_t Count() const; + //! Merge this HyperLogLog counter with another counter to create a new one + unique_ptr Merge(HyperLogLog &other); + HyperLogLog *MergePointer(HyperLogLog &other); + //! Merge a set of HyperLogLogs to create one big one + static unique_ptr Merge(HyperLogLog logs[], idx_t count); + //! Get the size (in bytes) of a HLL + static idx_t GetSize(); + //! Get pointer to the HLL + data_ptr_t GetPtr() const; + //! Get copy of the HLL + unique_ptr Copy(); + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + //! Compute HLL hashes over vdata, and store them in 'hashes' + //! Then, compute register indices and prefix lengths, and also store them in 'hashes' as a pair of uint32_t + static void ProcessEntries(UnifiedVectorFormat &vdata, const LogicalType &type, uint64_t hashes[], uint8_t counts[], + idx_t count); + //! Add the indices and counts to the logs + static void AddToLogs(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[], + HyperLogLog **logs[], const SelectionVector *log_sel); + //! Add the indices and counts to THIS log + void AddToLog(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[]); + +private: + explicit HyperLogLog(duckdb_hll::robj *hll); + + duckdb_hll::robj *hll; + mutex lock; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/interval.hpp b/src/duckdb/src/include/duckdb/common/types/interval.hpp new file mode 100644 index 00000000..9c22ab35 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/interval.hpp @@ -0,0 +1,151 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/interval.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" + +namespace duckdb { + +struct dtime_t; +struct date_t; +struct timestamp_t; + +class Serializer; +class Deserializer; + +struct interval_t { + int32_t months; + int32_t days; + int64_t micros; + + inline bool operator==(const interval_t &rhs) const { + return this->days == rhs.days && this->months == rhs.months && this->micros == rhs.micros; + } + + // Serialization + void Serialize(Serializer &serializer) const; + static interval_t Deserialize(Deserializer &source); +}; + +//! The Interval class is a static class that holds helper functions for the Interval +//! type. +class Interval { +public: + static constexpr const int32_t MONTHS_PER_MILLENIUM = 12000; + static constexpr const int32_t MONTHS_PER_CENTURY = 1200; + static constexpr const int32_t MONTHS_PER_DECADE = 120; + static constexpr const int32_t MONTHS_PER_YEAR = 12; + static constexpr const int32_t MONTHS_PER_QUARTER = 3; + static constexpr const int32_t DAYS_PER_WEEK = 7; + //! only used for interval comparison/ordering purposes, in which case a month counts as 30 days + static constexpr const int64_t DAYS_PER_MONTH = 30; + static constexpr const int64_t DAYS_PER_YEAR = 365; + static constexpr const int64_t MSECS_PER_SEC = 1000; + static constexpr const int32_t SECS_PER_MINUTE = 60; + static constexpr const int32_t MINS_PER_HOUR = 60; + static constexpr const int32_t HOURS_PER_DAY = 24; + static constexpr const int32_t SECS_PER_HOUR = SECS_PER_MINUTE * MINS_PER_HOUR; + static constexpr const int32_t SECS_PER_DAY = SECS_PER_HOUR * HOURS_PER_DAY; + static constexpr const int32_t SECS_PER_WEEK = SECS_PER_DAY * DAYS_PER_WEEK; + + static constexpr const int64_t MICROS_PER_MSEC = 1000; + static constexpr const int64_t MICROS_PER_SEC = MICROS_PER_MSEC * MSECS_PER_SEC; + static constexpr const int64_t MICROS_PER_MINUTE = MICROS_PER_SEC * SECS_PER_MINUTE; + static constexpr const int64_t MICROS_PER_HOUR = MICROS_PER_MINUTE * MINS_PER_HOUR; + static constexpr const int64_t MICROS_PER_DAY = MICROS_PER_HOUR * HOURS_PER_DAY; + static constexpr const int64_t MICROS_PER_WEEK = MICROS_PER_DAY * DAYS_PER_WEEK; + static constexpr const int64_t MICROS_PER_MONTH = MICROS_PER_DAY * DAYS_PER_MONTH; + + static constexpr const int64_t NANOS_PER_MICRO = 1000; + static constexpr const int64_t NANOS_PER_MSEC = NANOS_PER_MICRO * MICROS_PER_MSEC; + static constexpr const int64_t NANOS_PER_SEC = NANOS_PER_MSEC * MSECS_PER_SEC; + static constexpr const int64_t NANOS_PER_MINUTE = NANOS_PER_SEC * SECS_PER_MINUTE; + static constexpr const int64_t NANOS_PER_HOUR = NANOS_PER_MINUTE * MINS_PER_HOUR; + static constexpr const int64_t NANOS_PER_DAY = NANOS_PER_HOUR * HOURS_PER_DAY; + static constexpr const int64_t NANOS_PER_WEEK = NANOS_PER_DAY * DAYS_PER_WEEK; + +public: + //! Convert a string to an interval object + static bool FromString(const string &str, interval_t &result); + //! Convert a string to an interval object + static bool FromCString(const char *str, idx_t len, interval_t &result, string *error_message, bool strict); + //! Convert an interval object to a string + static string ToString(const interval_t &val); + + //! Convert milliseconds to a normalised interval + DUCKDB_API static interval_t FromMicro(int64_t micros); + + //! Get Interval in milliseconds + static int64_t GetMilli(const interval_t &val); + + //! Get Interval in microseconds + static int64_t GetMicro(const interval_t &val); + + //! Get Interval in Nanoseconds + static int64_t GetNanoseconds(const interval_t &val); + + //! Returns the age between two timestamps (including 30 day months) + static interval_t GetAge(timestamp_t timestamp_1, timestamp_t timestamp_2); + + //! Returns the exact difference between two timestamps (days and seconds) + static interval_t GetDifference(timestamp_t timestamp_1, timestamp_t timestamp_2); + + //! Returns the inverted interval + static interval_t Invert(interval_t interval); + + //! Add an interval to a date + static date_t Add(date_t left, interval_t right); + //! Add an interval to a timestamp + static timestamp_t Add(timestamp_t left, interval_t right); + //! Add an interval to a time. In case the time overflows or underflows, modify the date by the overflow. + //! For example if we go from 23:00 to 02:00, we add a day to the date + static dtime_t Add(dtime_t left, interval_t right, date_t &date); + + //! Comparison operators + inline static bool Equals(const interval_t &left, const interval_t &right); + inline static bool GreaterThan(const interval_t &left, const interval_t &right); +}; +static void NormalizeIntervalEntries(interval_t input, int64_t &months, int64_t &days, int64_t µs) { + int64_t extra_months_d = input.days / Interval::DAYS_PER_MONTH; + int64_t extra_months_micros = input.micros / Interval::MICROS_PER_MONTH; + input.days -= extra_months_d * Interval::DAYS_PER_MONTH; + input.micros -= extra_months_micros * Interval::MICROS_PER_MONTH; + + int64_t extra_days_micros = input.micros / Interval::MICROS_PER_DAY; + input.micros -= extra_days_micros * Interval::MICROS_PER_DAY; + + months = input.months + extra_months_d + extra_months_micros; + days = input.days + extra_days_micros; + micros = input.micros; +} + +bool Interval::Equals(const interval_t &left, const interval_t &right) { + return left.months == right.months && left.days == right.days && left.micros == right.micros; +} + +bool Interval::GreaterThan(const interval_t &left, const interval_t &right) { + int64_t lmonths, ldays, lmicros; + int64_t rmonths, rdays, rmicros; + NormalizeIntervalEntries(left, lmonths, ldays, lmicros); + NormalizeIntervalEntries(right, rmonths, rdays, rmicros); + + if (lmonths > rmonths) { + return true; + } else if (lmonths < rmonths) { + return false; + } + if (ldays > rdays) { + return true; + } else if (ldays < rdays) { + return false; + } + return lmicros > rmicros; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/list_segment.hpp b/src/duckdb/src/include/duckdb/common/types/list_segment.hpp new file mode 100644 index 00000000..ea4c2ad8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/list_segment.hpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/list_segment.hpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/types/vector.hpp" + +#pragma once + +namespace duckdb { + +struct ListSegment { + constexpr const static idx_t INITIAL_CAPACITY = 4; + + uint16_t count; + uint16_t capacity; + ListSegment *next; +}; +struct LinkedList { + LinkedList() : total_capacity(0), first_segment(nullptr), last_segment(nullptr) {}; + LinkedList(idx_t total_capacity_p, ListSegment *first_segment_p, ListSegment *last_segment_p) + : total_capacity(total_capacity_p), first_segment(first_segment_p), last_segment(last_segment_p) { + } + + idx_t total_capacity; + ListSegment *first_segment; + ListSegment *last_segment; +}; + +// forward declarations +struct ListSegmentFunctions; +typedef ListSegment *(*create_segment_t)(const ListSegmentFunctions &functions, ArenaAllocator &allocator, + uint16_t capacity); +typedef void (*write_data_to_segment_t)(const ListSegmentFunctions &functions, ArenaAllocator &allocator, + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, + idx_t &entry_idx); +typedef void (*read_data_from_segment_t)(const ListSegmentFunctions &functions, const ListSegment *segment, + Vector &result, idx_t &total_count); + +struct ListSegmentFunctions { + create_segment_t create_segment; + write_data_to_segment_t write_data; + read_data_from_segment_t read_data; + + vector child_functions; + + void AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &input_data, + idx_t &entry_idx) const; + void BuildListVector(const LinkedList &linked_list, Vector &result, idx_t &initial_total_count) const; +}; + +void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type); +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/null_value.hpp b/src/duckdb/src/include/duckdb/common/types/null_value.hpp new file mode 100644 index 00000000..ed96b721 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/null_value.hpp @@ -0,0 +1,80 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/null_value.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/windows_undefs.hpp" + +#include +#include +#include + +namespace duckdb { + +//! Placeholder to insert in Vectors or to use for hashing NULLs +template +inline T NullValue() { + return std::numeric_limits::min(); +} + +constexpr const char str_nil[2] = {'\200', '\0'}; + +template <> +inline const char *NullValue() { + D_ASSERT(str_nil[0] == '\200' && str_nil[1] == '\0'); + return str_nil; +} + +template <> +inline string_t NullValue() { + return string_t(NullValue()); +} + +template <> +inline char *NullValue() { + return (char *)NullValue(); // NOLINT +} + +template <> +inline string NullValue() { + return string(NullValue()); +} + +template <> +inline interval_t NullValue() { + interval_t null_value; + null_value.days = NullValue(); + null_value.months = NullValue(); + null_value.micros = NullValue(); + return null_value; +} + +template <> +inline hugeint_t NullValue() { + hugeint_t min; + min.lower = 0; + min.upper = std::numeric_limits::min(); + return min; +} + +template <> +inline float NullValue() { + return NAN; +} + +template <> +inline double NullValue() { + return NAN; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp new file mode 100644 index 00000000..f3118723 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp @@ -0,0 +1,187 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/partitioned_tuple_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/fixed_size_map.hpp" +#include "duckdb/common/perfect_map_set.hpp" +#include "duckdb/common/types/row/tuple_data_allocator.hpp" +#include "duckdb/common/types/row/tuple_data_collection.hpp" + +namespace duckdb { + +//! Local state for parallel partitioning +struct PartitionedTupleDataAppendState { +public: + PartitionedTupleDataAppendState() : partition_indices(LogicalType::UBIGINT) { + } + +public: + Vector partition_indices; + SelectionVector partition_sel; + SelectionVector reverse_partition_sel; + + static constexpr idx_t MAP_THRESHOLD = 256; + perfect_map_t partition_entries; + fixed_size_map_t fixed_partition_entries; + + vector> partition_pin_states; + TupleDataChunkState chunk_state; +}; + +enum class PartitionedTupleDataType : uint8_t { + INVALID, + //! Radix partitioning on a hash column + RADIX +}; + +//! Shared allocators for parallel partitioning +struct PartitionTupleDataAllocators { + mutex lock; + vector> allocators; +}; + +//! PartitionedTupleData represents partitioned row data, which serves as an interface for different types of +//! partitioning, e.g., radix, hive +class PartitionedTupleData { +public: + virtual ~PartitionedTupleData(); + +public: + //! Get the layout of this PartitionedTupleData + const TupleDataLayout &GetLayout() const; + //! Get the partitioning type of this PartitionedTupleData + PartitionedTupleDataType GetType() const; + //! Initializes a local state for parallel partitioning that can be merged into this PartitionedTupleData + void InitializeAppendState(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; + //! Appends a DataChunk to this PartitionedTupleData + void Append(PartitionedTupleDataAppendState &state, DataChunk &input, + const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), + const idx_t append_count = DConstants::INVALID_INDEX); + //! Appends a DataChunk to this PartitionedTupleData + //! - ToUnifiedFormat has already been called + void AppendUnified(PartitionedTupleDataAppendState &state, DataChunk &input, + const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), + const idx_t append_count = DConstants::INVALID_INDEX); + //! Appends rows to this PartitionedTupleData + void Append(PartitionedTupleDataAppendState &state, TupleDataChunkState &input, const idx_t count); + //! Flushes any remaining data in the append state into this PartitionedTupleData + void FlushAppendState(PartitionedTupleDataAppendState &state); + //! Combine another PartitionedTupleData into this PartitionedTupleData + void Combine(PartitionedTupleData &other); + //! Resets this PartitionedTupleData + void Reset(); + //! Repartition this PartitionedTupleData into the new PartitionedTupleData + void Repartition(PartitionedTupleData &new_partitioned_data); + //! Unpins the data + void Unpin(); + //! Get the partitions in this PartitionedTupleData + vector> &GetPartitions(); + //! Get the data of this PartitionedTupleData as a single unpartitioned TupleDataCollection + unique_ptr GetUnpartitioned(); + //! Get the count of this PartitionedTupleData + idx_t Count() const; + //! Get the size (in bytes) of this PartitionedTupleData + idx_t SizeInBytes() const; + //! Get the number of partitions of this PartitionedTupleData + idx_t PartitionCount() const; + //! Converts this PartitionedTupleData to a string representation + string ToString(); + //! Prints the string representation of this PartitionedTupleData + void Print(); + +protected: + //===--------------------------------------------------------------------===// + // Partitioning type implementation interface + //===--------------------------------------------------------------------===// + //! Initialize a PartitionedTupleDataAppendState for this type of partitioning (optional) + virtual void InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties) const { + } + //! Compute the partition indices for this type of partitioning for the input DataChunk and store them in the + //! `partition_data` of the local state. If this type creates partitions on the fly (for, e.g., hive), this + //! function is also in charge of creating new partitions and mapping the input data to a partition index + virtual void ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input) { + throw NotImplementedException("ComputePartitionIndices for this type of PartitionedTupleData"); + } + //! Compute partition indices from rows (similar to function above) + virtual void ComputePartitionIndices(Vector &row_locations, idx_t append_count, Vector &partition_indices) const { + throw NotImplementedException("ComputePartitionIndices for this type of PartitionedTupleData"); + } + //! Maximum partition index (optional) + virtual idx_t MaxPartitionIndex() const { + return DConstants::INVALID_INDEX; + } + + //! Whether or not to iterate over the original partitions in reverse order when repartitioning (optional) + virtual bool RepartitionReverseOrder() const { + return false; + } + //! Finalize states while repartitioning - useful for unpinning blocks that are no longer needed (optional) + virtual void RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, + PartitionedTupleData &new_partitioned_data, + PartitionedTupleDataAppendState &state, idx_t finished_partition_idx) const { + } + +protected: + //! PartitionedTupleData can only be instantiated by derived classes + PartitionedTupleData(PartitionedTupleDataType type, BufferManager &buffer_manager, const TupleDataLayout &layout); + PartitionedTupleData(const PartitionedTupleData &other); + + //! Create a new shared allocator + void CreateAllocator(); + //! Whether to use fixed size map or regular marp + bool UseFixedSizeMap() const; + //! Builds a selection vector in the Append state for the partitions + //! - returns true if everything belongs to the same partition - stores partition index in single_partition_idx + void BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, + const idx_t append_count); + template + void BuildPartitionSel(PartitionedTupleDataAppendState &state, MAP_TYPE &partition_entries, + const SelectionVector &append_sel, const idx_t append_count); + //! Builds out the buffer space in the partitions + void BuildBufferSpace(PartitionedTupleDataAppendState &state); + template + void BuildBufferSpace(PartitionedTupleDataAppendState &state, const MAP_TYPE &partition_entries); + //! Create a collection for a specific a partition + unique_ptr CreatePartitionCollection(idx_t partition_index) const { + if (allocators) { + return make_uniq(allocators->allocators[partition_index]); + } else { + return make_uniq(buffer_manager, layout); + } + } + //! Verify count/data size of this PartitionedTupleData + void Verify() const; + +protected: + PartitionedTupleDataType type; + BufferManager &buffer_manager; + const TupleDataLayout layout; + idx_t count; + idx_t data_size; + + mutex lock; + shared_ptr allocators; + vector> partitions; + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/row_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/row/row_data_collection.hpp new file mode 100644 index 00000000..acdd6de3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/row_data_collection.hpp @@ -0,0 +1,128 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/row_data_collection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" + +namespace duckdb { + +struct RowDataBlock { +public: + RowDataBlock(BufferManager &buffer_manager, idx_t capacity, idx_t entry_size) + : capacity(capacity), entry_size(entry_size), count(0), byte_offset(0) { + idx_t size = MaxValue(Storage::BLOCK_SIZE, capacity * entry_size); + buffer_manager.Allocate(size, false, &block); + D_ASSERT(BufferManager::GetAllocSize(size) == block->GetMemoryUsage()); + } + explicit RowDataBlock(idx_t entry_size) : entry_size(entry_size) { + } + //! The buffer block handle + shared_ptr block; + //! Capacity (number of entries) and entry size that fit in this block + idx_t capacity; + const idx_t entry_size; + //! Number of entries currently in this block + idx_t count; + //! Write offset (if variable size entries) + idx_t byte_offset; + +private: + //! Implicit copying is not allowed + RowDataBlock(const RowDataBlock &) = delete; + +public: + unique_ptr Copy() { + auto result = make_uniq(entry_size); + result->block = block; + result->capacity = capacity; + result->count = count; + result->byte_offset = byte_offset; + return result; + } +}; + +struct BlockAppendEntry { + BlockAppendEntry(data_ptr_t baseptr, idx_t count) : baseptr(baseptr), count(count) { + } + data_ptr_t baseptr; + idx_t count; +}; + +class RowDataCollection { +public: + RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, bool keep_pinned = false); + + unique_ptr CloneEmpty(bool keep_pinned = false) const { + return make_uniq(buffer_manager, block_capacity, entry_size, keep_pinned); + } + + //! BufferManager + BufferManager &buffer_manager; + //! The total number of stored entries + idx_t count; + //! The number of entries per block + idx_t block_capacity; + //! Size of entries in the blocks + idx_t entry_size; + //! The blocks holding the main data + vector> blocks; + //! The blocks that this collection currently has pinned + vector pinned_blocks; + //! Whether the blocks should stay pinned (necessary for e.g. a heap) + const bool keep_pinned; + +public: + idx_t AppendToBlock(RowDataBlock &block, BufferHandle &handle, vector &append_entries, + idx_t remaining, idx_t entry_sizes[]); + RowDataBlock &CreateBlock(); + vector Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], + const SelectionVector *sel = FlatVector::IncrementalSelectionVector()); + + void Merge(RowDataCollection &other); + + void Clear() { + blocks.clear(); + pinned_blocks.clear(); + count = 0; + } + + //! The size (in bytes) of this RowDataCollection + idx_t SizeInBytes() const { + VerifyBlockSizes(); + idx_t size = 0; + for (auto &block : blocks) { + size += block->block->GetMemoryUsage(); + } + return size; + } + + //! Verifies that the block sizes are correct (Debug only) + void VerifyBlockSizes() const { +#ifdef DEBUG + for (auto &block : blocks) { + D_ASSERT(block->block->GetMemoryUsage() == BufferManager::GetAllocSize(block->capacity * entry_size)); + } +#endif + } + + static inline idx_t EntriesPerBlock(idx_t width) { + return Storage::BLOCK_SIZE / width; + } + +private: + mutex rdc_lock; + + //! Copying is not allowed + RowDataCollection(const RowDataCollection &) = delete; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/row_data_collection_scanner.hpp b/src/duckdb/src/include/duckdb/common/types/row/row_data_collection_scanner.hpp new file mode 100644 index 00000000..69753181 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/row_data_collection_scanner.hpp @@ -0,0 +1,121 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/row_data_collection_scanner.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/row/row_layout.hpp" + +namespace duckdb { + +class BufferHandle; +class RowDataCollection; +struct RowDataBlock; +class DataChunk; + +//! Used to scan the data into DataChunks after sorting +struct RowDataCollectionScanner { +public: + using Types = vector; + + struct ScanState { + explicit ScanState(const RowDataCollectionScanner &scanner_p) : scanner(scanner_p), block_idx(0), entry_idx(0) { + } + + void PinData(); + + //! The data layout + const RowDataCollectionScanner &scanner; + + idx_t block_idx; + idx_t entry_idx; + + BufferHandle data_handle; + BufferHandle heap_handle; + + // We must pin ALL blocks we are going to gather from + vector pinned_blocks; + }; + + //! Ensure that heap blocks correspond to row blocks + static void AlignHeapBlocks(RowDataCollection &dst_block_collection, RowDataCollection &dst_string_heap, + RowDataCollection &src_block_collection, RowDataCollection &src_string_heap, + const RowLayout &layout); + + RowDataCollectionScanner(RowDataCollection &rows, RowDataCollection &heap, const RowLayout &layout, bool external, + bool flush = true); + + // Single block scan + RowDataCollectionScanner(RowDataCollection &rows, RowDataCollection &heap, const RowLayout &layout, bool external, + idx_t block_idx, bool flush); + + //! The type layout of the payload + inline const vector &GetTypes() const { + return layout.GetTypes(); + } + + //! The number of rows in the collection + inline idx_t Count() const { + return total_count; + } + + //! The number of rows scanned so far + inline idx_t Scanned() const { + return total_scanned; + } + + //! The number of remaining rows + inline idx_t Remaining() const { + return total_count - total_scanned; + } + + //! The number of remaining rows + inline idx_t BlockIndex() const { + return read_state.block_idx; + } + + //! Swizzle the blocks for external scanning + //! Swizzling is all or nothing, so if we have scanned previously, + //! we need to re-swizzle. + void ReSwizzle(); + + void SwizzleBlock(RowDataBlock &data_block, RowDataBlock &heap_block); + + //! Scans the next data chunk from the sorted data + void Scan(DataChunk &chunk); + + //! Resets to the start and updates the flush flag + void Reset(bool flush = true); + +private: + //! The row data being scanned + RowDataCollection &rows; + //! The row heap being scanned + RowDataCollection &heap; + //! The data layout + const RowLayout layout; + //! Read state + ScanState read_state; + //! The total count of sorted_data + idx_t total_count; + //! The number of rows scanned so far + idx_t total_scanned; + //! Addresses used to gather from the sorted data + Vector addresses = Vector(LogicalType::POINTER); + //! Whether the blocks can be flushed to disk + const bool external; + //! Whether to flush the blocks after scanning + bool flush; + //! Whether we are unswizzling the blocks + const bool unswizzling; + + //! Checks that the newest block is valid + void ValidateUnscannedBlock() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/row_layout.hpp b/src/duckdb/src/include/duckdb/common/types/row/row_layout.hpp new file mode 100644 index 00000000..702cba4d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/row_layout.hpp @@ -0,0 +1,82 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/row_layout.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class RowLayout { +public: + friend class TupleDataLayout; + using ValidityBytes = TemplatedValidityMask; + + //! Creates an empty RowLayout + RowLayout(); + +public: + //! Initializes the RowLayout with the specified types to an empty RowLayout + void Initialize(vector types, bool align = true); + //! Returns the number of data columns + inline idx_t ColumnCount() const { + return types.size(); + } + //! Returns a list of the column types for this data chunk + inline const vector &GetTypes() const { + return types; + } + //! Returns the total width required for each row, including padding + inline idx_t GetRowWidth() const { + return row_width; + } + //! Returns the offset to the start of the data + inline idx_t GetDataOffset() const { + return flag_width; + } + //! Returns the total width required for the data, including padding + inline idx_t GetDataWidth() const { + return data_width; + } + //! Returns the offset to the start of the aggregates + inline idx_t GetAggrOffset() const { + return flag_width + data_width; + } + //! Returns the column offsets into each row + inline const vector &GetOffsets() const { + return offsets; + } + //! Returns whether all columns in this layout are constant size + inline bool AllConstant() const { + return all_constant; + } + inline idx_t GetHeapOffset() const { + return heap_pointer_offset; + } + +private: + //! The types of the data columns + vector types; + //! The width of the validity header + idx_t flag_width; + //! The width of the data portion + idx_t data_width; + //! The width of the entire row + idx_t row_width; + //! The offsets to the columns and aggregate data in each row + vector offsets; + //! Whether all columns in this layout are constant size + bool all_constant; + //! Offset to the pointer to the heap for each row + idx_t heap_pointer_offset; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp new file mode 100644 index 00000000..93eedfa0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp @@ -0,0 +1,122 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/tuple_data_allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/row/tuple_data_layout.hpp" +#include "duckdb/common/types/row/tuple_data_states.hpp" + +namespace duckdb { + +struct TupleDataSegment; +struct TupleDataChunk; +struct TupleDataChunkPart; + +struct TupleDataBlock { +public: + TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p); + + //! Disable copy constructors + TupleDataBlock(const TupleDataBlock &other) = delete; + TupleDataBlock &operator=(const TupleDataBlock &) = delete; + + //! Enable move constructors + TupleDataBlock(TupleDataBlock &&other) noexcept; + TupleDataBlock &operator=(TupleDataBlock &&) noexcept; + +public: + //! Remaining capacity (in bytes) + idx_t RemainingCapacity() const { + D_ASSERT(size <= capacity); + return capacity - size; + } + + //! Remaining capacity (in rows) + idx_t RemainingCapacity(idx_t row_width) const { + return RemainingCapacity() / row_width; + } + +public: + //! The underlying row block + shared_ptr handle; + //! Capacity (in bytes) + idx_t capacity; + //! Occupied size (in bytes) + idx_t size; +}; + +class TupleDataAllocator { +public: + TupleDataAllocator(BufferManager &buffer_manager, const TupleDataLayout &layout); + TupleDataAllocator(TupleDataAllocator &allocator); + + //! Get the buffer manager + BufferManager &GetBufferManager(); + //! Get the buffer allocator + Allocator &GetAllocator(); + //! Get the layout + const TupleDataLayout &GetLayout() const; + //! Number of row blocks + idx_t RowBlockCount() const; + //! Number of heap blocks + idx_t HeapBlockCount() const; + +public: + //! Builds out the chunks for next append, given the metadata in the append state + void Build(TupleDataSegment &segment, TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + const idx_t append_offset, const idx_t append_count); + //! Initializes a chunk, making its pointers valid + void InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + idx_t chunk_idx, bool init_heap); + static void RecomputeHeapPointers(Vector &old_heap_ptrs, const SelectionVector &old_heap_sel, + const data_ptr_t row_locations[], Vector &new_heap_ptrs, const idx_t offset, + const idx_t count, const TupleDataLayout &layout, const idx_t base_col_offset); + //! Releases or stores any handles in the management state that are no longer required + void ReleaseOrStoreHandles(TupleDataPinState &state, TupleDataSegment &segment, TupleDataChunk &chunk, + bool release_heap); + //! Releases or stores ALL handles in the management state + void ReleaseOrStoreHandles(TupleDataPinState &state, TupleDataSegment &segment); + +private: + //! Builds out a single part (grabs the lock) + TupleDataChunkPart BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, + const idx_t append_offset, const idx_t append_count, TupleDataChunk &chunk); + //! Internal function for InitializeChunkState + void InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, idx_t offset, + bool recompute, bool init_heap_pointers, bool init_heap_sizes, + unsafe_vector> &parts); + //! Internal function for ReleaseOrStoreHandles + static void ReleaseOrStoreHandlesInternal(TupleDataSegment &segment, + unsafe_vector &pinned_row_handles, + perfect_map_t &handles, const perfect_set_t &block_ids, + unsafe_vector &blocks, TupleDataPinProperties properties); + //! Pins the given row block + BufferHandle &PinRowBlock(TupleDataPinState &state, const TupleDataChunkPart &part); + //! Pins the given heap block + BufferHandle &PinHeapBlock(TupleDataPinState &state, const TupleDataChunkPart &part); + //! Gets the pointer to the rows for the given chunk part + data_ptr_t GetRowPointer(TupleDataPinState &state, const TupleDataChunkPart &part); + //! Gets the base pointer to the heap for the given chunk part + data_ptr_t GetBaseHeapPointer(TupleDataPinState &state, const TupleDataChunkPart &part); + +private: + //! The buffer manager + BufferManager &buffer_manager; + //! The layout of the data + const TupleDataLayout layout; + //! Blocks storing the fixed-size rows + unsafe_vector row_blocks; + //! Blocks storing the variable-size data of the fixed-size rows (e.g., string, list) + unsafe_vector heap_blocks; + + //! Re-usable arrays used while building buffer space + unsafe_vector> chunk_parts; + unsafe_vector> chunk_part_indices; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp new file mode 100644 index 00000000..34efa90d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp @@ -0,0 +1,249 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/tuple_data_collection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/row/tuple_data_layout.hpp" +#include "duckdb/common/types/row/tuple_data_segment.hpp" +#include "duckdb/common/types/row/tuple_data_states.hpp" + +namespace duckdb { + +class TupleDataAllocator; +struct TupleDataScatterFunction; +struct TupleDataGatherFunction; +struct RowOperationsState; + +typedef void (*tuple_data_scatter_function_t)(const Vector &source, const TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const TupleDataLayout &layout, const Vector &row_locations, + Vector &heap_locations, const idx_t col_idx, + const UnifiedVectorFormat &list_format, + const vector &child_functions); + +struct TupleDataScatterFunction { + tuple_data_scatter_function_t function; + vector child_functions; +}; + +typedef void (*tuple_data_gather_function_t)(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, + const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, + const SelectionVector &target_sel, Vector &list_vector, + const vector &child_functions); + +struct TupleDataGatherFunction { + tuple_data_gather_function_t function; + vector child_functions; +}; + +//! TupleDataCollection represents a set of buffer-managed data stored in row format +//! FIXME: rename to RowDataCollection after we phase it out +class TupleDataCollection { + friend class TupleDataChunkIterator; + friend class PartitionedTupleData; + +public: + //! Constructs a TupleDataCollection with the specified layout + TupleDataCollection(BufferManager &buffer_manager, const TupleDataLayout &layout); + //! Constructs a TupleDataCollection with the same (shared) allocator + explicit TupleDataCollection(shared_ptr allocator); + + ~TupleDataCollection(); + +public: + //! The layout of the stored rows + const TupleDataLayout &GetLayout() const; + //! The number of rows stored in the tuple data collection + const idx_t &Count() const; + //! The number of chunks stored in the tuple data collection + idx_t ChunkCount() const; + //! The size (in bytes) of the blocks held by this tuple data collection + idx_t SizeInBytes() const; + //! Unpins all held pins + void Unpin(); + + //! Gets the scatter function for the given type + static TupleDataScatterFunction GetScatterFunction(const LogicalType &type, bool within_list = false); + //! Gets the gather function for the given type + static TupleDataGatherFunction GetGatherFunction(const LogicalType &type, bool within_list = false); + + //! Initializes an Append state - useful for optimizing many appends made to the same tuple data collection + void InitializeAppend(TupleDataAppendState &append_state, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE); + //! Initializes an Append state - useful for optimizing many appends made to the same tuple data collection + void InitializeAppend(TupleDataAppendState &append_state, vector column_ids, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE); + //! Initializes the Pin state of an Append state + //! - Useful for optimizing many appends made to the same tuple data collection + void InitializeAppend(TupleDataPinState &pin_state, + TupleDataPinProperties = TupleDataPinProperties::UNPIN_AFTER_DONE); + //! Initializes the Chunk state of an Append state + //! - Useful for optimizing many appends made to the same tuple data collection + void InitializeChunkState(TupleDataChunkState &chunk_state, vector column_ids = {}); + //! Initializes the Chunk state of an Append state + //! - Useful for optimizing many appends made to the same tuple data collection + static void InitializeChunkState(TupleDataChunkState &chunk_state, const vector &types, + vector column_ids = {}); + //! Append a DataChunk directly to this TupleDataCollection - calls InitializeAppend and Append internally + void Append(DataChunk &new_chunk, const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), + idx_t append_count = DConstants::INVALID_INDEX); + //! Append a DataChunk directly to this TupleDataCollection - calls InitializeAppend and Append internally + void Append(DataChunk &new_chunk, vector column_ids, + const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), + const idx_t append_count = DConstants::INVALID_INDEX); + //! Append a DataChunk to this TupleDataCollection using the specified Append state + void Append(TupleDataAppendState &append_state, DataChunk &new_chunk, + const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), + const idx_t append_count = DConstants::INVALID_INDEX); + //! Append a DataChunk to this TupleDataCollection using the specified pin and Chunk states + void Append(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, DataChunk &new_chunk, + const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), + const idx_t append_count = DConstants::INVALID_INDEX); + //! Append a DataChunk to this TupleDataCollection using the specified pin and Chunk states + //! - ToUnifiedFormat has already been called + void AppendUnified(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, DataChunk &new_chunk, + const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), + const idx_t append_count = DConstants::INVALID_INDEX); + + //! Creates a UnifiedVectorFormat in the given Chunk state for the given DataChunk + static void ToUnifiedFormat(TupleDataChunkState &chunk_state, DataChunk &new_chunk); + //! Gets the UnifiedVectorFormat from the Chunk state as an array + static void GetVectorData(const TupleDataChunkState &chunk_state, UnifiedVectorFormat result[]); + //! Computes the heap sizes for the new DataChunk that will be appended + static void ComputeHeapSizes(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, + const SelectionVector &append_sel, const idx_t append_count); + + //! Builds out the buffer space for the specified Chunk state + void Build(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, const idx_t append_offset, + const idx_t append_count); + //! Scatters the given DataChunk to the rows in the specified Chunk state + void Scatter(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, const SelectionVector &append_sel, + const idx_t append_count) const; + //! Scatters the given Vector to the given column id to the rows in the specified Chunk state + void Scatter(TupleDataChunkState &chunk_state, const Vector &source, const column_t column_id, + const SelectionVector &append_sel, const idx_t append_count) const; + //! Copy rows from input to the built Chunk state + void CopyRows(TupleDataChunkState &chunk_state, TupleDataChunkState &input, const SelectionVector &append_sel, + const idx_t append_count) const; + + //! Finalizes the Pin state, releasing or storing blocks + void FinalizePinState(TupleDataPinState &pin_state, TupleDataSegment &segment); + //! Finalizes the Pin state, releasing or storing blocks + void FinalizePinState(TupleDataPinState &pin_state); + + //! Appends the other TupleDataCollection to this, destroying the other data collection + void Combine(TupleDataCollection &other); + //! Appends the other TupleDataCollection to this, destroying the other data collection + void Combine(unique_ptr other); + //! Resets the TupleDataCollection, clearing all data + void Reset(); + + //! Initializes a chunk with the correct types that can be used to call Append/Scan + void InitializeChunk(DataChunk &chunk) const; + //! Initializes a chunk with the correct types for a given scan state + void InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const; + //! Initializes a Scan state for scanning all columns + void InitializeScan(TupleDataScanState &state, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; + //! Initializes a Scan state for scanning a subset of the columns + void InitializeScan(TupleDataScanState &state, vector column_ids, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; + //! Initialize a parallel scan over the tuple data collection over all columns + void InitializeScan(TupleDataParallelScanState &state, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; + //! Initialize a parallel scan over the tuple data collection over a subset of the columns + void InitializeScan(TupleDataParallelScanState &gstate, vector column_ids, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; + //! Scans a DataChunk from the TupleDataCollection + bool Scan(TupleDataScanState &state, DataChunk &result); + //! Scans a DataChunk from the TupleDataCollection + bool Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate, DataChunk &result); + //! Whether the last scan has been completed on this TupleDataCollection + bool ScanComplete(const TupleDataScanState &state) const; + + //! Gathers a DataChunk from the TupleDataCollection, given the specific row locations (requires full pin) + void Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, DataChunk &result, + const SelectionVector &target_sel) const; + //! Gathers a DataChunk (only the columns given by column_ids) from the TupleDataCollection, + //! given the specific row locations (requires full pin) + void Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, + const vector &column_ids, DataChunk &result, const SelectionVector &target_sel) const; + //! Gathers a Vector (from the given column id) from the TupleDataCollection + //! given the specific row locations (requires full pin) + void Gather(Vector &row_locations, const SelectionVector &sel, const idx_t scan_count, const column_t column_id, + Vector &result, const SelectionVector &target_sel) const; + + //! Converts this TupleDataCollection to a string representation + string ToString(); + //! Prints the string representation of this TupleDataCollection + void Print(); + + //! Verify that all blocks are pinned + void VerifyEverythingPinned() const; + +private: + //! Initializes the TupleDataCollection (called by the constructor) + void Initialize(); + //! Gets all column ids + void GetAllColumnIDs(vector &column_ids); + //! Adds a segment to this TupleDataCollection + void AddSegment(TupleDataSegment &&segment); + + //! Computes the heap sizes for the specific Vector that will be appended + static void ComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, TupleDataVectorFormat &source, + const SelectionVector &append_sel, const idx_t append_count); + //! Computes the heap sizes for the specific Vector that will be appended (within a list) + static void WithinListHeapComputeSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, const SelectionVector &append_sel, + const idx_t append_count, const UnifiedVectorFormat &list_data); + //! Computes the heap sizes for the fixed-size type Vector that will be appended (within a list) + static void ComputeFixedWithinListHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, const SelectionVector &append_sel, + const idx_t append_count, const UnifiedVectorFormat &list_data); + //! Computes the heap sizes for the string Vector that will be appended (within a list) + static void StringWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const UnifiedVectorFormat &list_data); + //! Computes the heap sizes for the struct Vector that will be appended (within a list) + static void StructWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, + const SelectionVector &append_sel, const idx_t append_count, + const UnifiedVectorFormat &list_data); + //! Computes the heap sizes for the list Vector that will be appended (within a list) + static void ListWithinListComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, + TupleDataVectorFormat &source_format, const SelectionVector &append_sel, + const idx_t append_count, const UnifiedVectorFormat &list_data); + + //! Get the next segment/chunk index for the scan + bool NextScanIndex(TupleDataScanState &scan_state, idx_t &segment_index, idx_t &chunk_index); + //! Scans the chunk at the given segment/chunk indices + void ScanAtIndex(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, const vector &column_ids, + idx_t segment_index, idx_t chunk_index, DataChunk &result); + + //! Verify count/data size of this collection + void Verify() const; + +private: + //! The layout of the TupleDataCollection + const TupleDataLayout layout; + //! The TupleDataAllocator + shared_ptr allocator; + //! The number of entries stored in the TupleDataCollection + idx_t count; + //! The size (in bytes) of this TupleDataCollection + idx_t data_size; + //! The data segments of the TupleDataCollection + unsafe_vector segments; + //! The set of scatter functions + vector scatter_functions; + //! The set of gather functions + vector gather_functions; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_iterator.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_iterator.hpp new file mode 100644 index 00000000..40a8eb1a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_iterator.hpp @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/tuple_data_iterator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/row/tuple_data_collection.hpp" + +namespace duckdb { + +class TupleDataChunkIterator { +public: + //! Creates a TupleDataChunkIterator that iterates over all DataChunks in the TupleDataCollection + TupleDataChunkIterator(TupleDataCollection &collection, TupleDataPinProperties properties, bool init_heap); + //! Creates a TupleDataChunkIterator that iterates over the specified DataChunk range in the TupleDataCollection + TupleDataChunkIterator(TupleDataCollection &collection, TupleDataPinProperties properties, idx_t chunk_idx_from, + idx_t chunk_idx_to, bool init_heap); + +public: + //! Whether the iterator is done + bool Done() const; + //! Fetches the next STANDARD_VECTOR_SIZE row locations (and heap locations/sizes if init_heap is true) + bool Next(); + //! Resets the scan indices to the start + void Reset(); + //! Get the count of the current "DataChunk" + idx_t GetCurrentChunkCount() const; + //! Get the Chunk state of the scan state of this iterator + TupleDataChunkState &GetChunkState(); + //! Get the array holding the row locations + data_ptr_t *GetRowLocations(); + //! Get the array holding the heap locations + data_ptr_t *GetHeapLocations(); + //! Get the array holding the heap sizes + idx_t *GetHeapSizes(); + +private: + //! Initializes the row locations (and heap locations/sizes if init_heap is true) at the current scan indices + void InitializeCurrentChunk(); + +private: + //! The collection being iterated over + TupleDataCollection &collection; + //! Whether or not to fetch the heap locations/sizes while iterating + bool init_heap; + + //! Start indices + idx_t start_segment_idx; + idx_t start_chunk_idx; + //! End indices + idx_t end_segment_idx; + idx_t end_chunk_idx; + + //! Current scan state and scan indices + TupleDataScanState state; + idx_t current_segment_idx; + idx_t current_chunk_idx; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_layout.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_layout.hpp new file mode 100644 index 00000000..f0d90e47 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_layout.hpp @@ -0,0 +1,120 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/tuple_data_layout.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/row/row_layout.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class TupleDataLayout { +public: + using Aggregates = vector; + using ValidityBytes = TemplatedValidityMask; + + //! Creates an empty TupleDataLayout + TupleDataLayout(); + //! Create a copy of this TupleDataLayout + TupleDataLayout Copy() const; + +public: + //! Initializes the TupleDataLayout with the specified types and aggregates to an empty TupleDataLayout + void Initialize(vector types_p, Aggregates aggregates_p, bool align = true, bool heap_offset = true); + //! Initializes the TupleDataLayout with the specified types to an empty TupleDataLayout + void Initialize(vector types, bool align = true, bool heap_offset = true); + //! Initializes the TupleDataLayout with the specified aggregates to an empty TupleDataLayout + void Initialize(Aggregates aggregates_p, bool align = true, bool heap_offset = true); + + //! Returns the number of data columns + inline idx_t ColumnCount() const { + return types.size(); + } + //! Returns a list of the column types for this data chunk + inline const vector &GetTypes() const { + return types; + } + //! Returns the number of aggregates + inline idx_t AggregateCount() const { + return aggregates.size(); + } + //! Returns a list of the aggregates for this data chunk + inline Aggregates &GetAggregates() { + return aggregates; + } + //! Returns a map from column id to the struct TupleDataLayout + const inline TupleDataLayout &GetStructLayout(idx_t col_idx) const { + D_ASSERT(struct_layouts->find(col_idx) != struct_layouts->end()); + return struct_layouts->find(col_idx)->second; + } + //! Returns the total width required for each row, including padding + inline idx_t GetRowWidth() const { + return row_width; + } + //! Returns the offset to the start of the data + inline idx_t GetDataOffset() const { + return flag_width; + } + //! Returns the total width required for the data, including padding + inline idx_t GetDataWidth() const { + return data_width; + } + //! Returns the offset to the start of the aggregates + inline idx_t GetAggrOffset() const { + return flag_width + data_width; + } + //! Returns the total width required for the aggregates, including padding + inline idx_t GetAggrWidth() const { + return aggr_width; + } + //! Returns the column offsets into each row + inline const vector &GetOffsets() const { + return offsets; + } + //! Returns whether all columns in this layout are constant size + inline bool AllConstant() const { + return all_constant; + } + //! Gets offset to where heap size is stored + inline idx_t GetHeapSizeOffset() const { + return heap_size_offset; + } + //! Returns whether any of the aggregates have a destructor + inline bool HasDestructor() const { + return has_destructor; + } + +private: + //! The types of the data columns + vector types; + //! The aggregate functions + Aggregates aggregates; + //! Structs are a recursive TupleDataLayout + unique_ptr> struct_layouts; + //! The width of the validity header + idx_t flag_width; + //! The width of the data portion + idx_t data_width; + //! The width of the aggregate state portion + idx_t aggr_width; + //! The width of the entire row + idx_t row_width; + //! The offsets to the columns and aggregate data in each row + vector offsets; + //! Whether all columns in this layout are constant size + bool all_constant; + //! Offset to the heap size of every row + idx_t heap_size_offset; + //! Whether any of the aggregates have a destructor + bool has_destructor; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp new file mode 100644 index 00000000..7b8489f2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp @@ -0,0 +1,129 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/tuple_data_segment.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/perfect_map_set.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +class TupleDataAllocator; +class TupleDataLayout; + +struct TupleDataChunkPart { +public: + TupleDataChunkPart(mutex &lock); + + //! Disable copy constructors + TupleDataChunkPart(const TupleDataChunkPart &other) = delete; + TupleDataChunkPart &operator=(const TupleDataChunkPart &) = delete; + + //! Enable move constructors + TupleDataChunkPart(TupleDataChunkPart &&other) noexcept; + TupleDataChunkPart &operator=(TupleDataChunkPart &&) noexcept; + + static constexpr const uint32_t INVALID_INDEX = (uint32_t)-1; + +public: + //! Index/offset of the row block + uint32_t row_block_index; + uint32_t row_block_offset; + //! Pointer/index/offset of the heap block + uint32_t heap_block_index; + uint32_t heap_block_offset; + data_ptr_t base_heap_ptr; + //! Total heap size for this chunk part + uint32_t total_heap_size; + //! Tuple count for this chunk part + uint32_t count; + //! Lock for recomputing heap pointers (owned by TupleDataChunk) + reference lock; +}; + +struct TupleDataChunk { +public: + TupleDataChunk(); + + //! Disable copy constructors + TupleDataChunk(const TupleDataChunk &other) = delete; + TupleDataChunk &operator=(const TupleDataChunk &) = delete; + + //! Enable move constructors + TupleDataChunk(TupleDataChunk &&other) noexcept; + TupleDataChunk &operator=(TupleDataChunk &&) noexcept; + + //! Add a part to this chunk + void AddPart(TupleDataChunkPart &&part, const TupleDataLayout &layout); + //! Tries to merge the last chunk part into the second-to-last one + void MergeLastChunkPart(const TupleDataLayout &layout); + //! Verify counts of the parts in this chunk + void Verify() const; + +public: + //! The parts of this chunk + unsafe_vector parts; + //! The row block ids referenced by the chunk + perfect_set_t row_block_ids; + //! The heap block ids referenced by the chunk + perfect_set_t heap_block_ids; + //! Tuple count for this chunk + idx_t count; + //! Lock for recomputing heap pointers + unsafe_unique_ptr lock; +}; + +struct TupleDataSegment { +public: + explicit TupleDataSegment(shared_ptr allocator); + + ~TupleDataSegment(); + + //! Disable copy constructors + TupleDataSegment(const TupleDataSegment &other) = delete; + TupleDataSegment &operator=(const TupleDataSegment &) = delete; + + //! Enable move constructors + TupleDataSegment(TupleDataSegment &&other) noexcept; + TupleDataSegment &operator=(TupleDataSegment &&) noexcept; + + //! The number of chunks in this segment + idx_t ChunkCount() const; + //! The size (in bytes) of this segment + idx_t SizeInBytes() const; + //! Unpins all held pins + void Unpin(); + + //! Verify counts of the chunks in this segment + void Verify() const; + //! Verify that all blocks in this segment are pinned + void VerifyEverythingPinned() const; + +public: + //! The allocator for this segment + shared_ptr allocator; + //! The chunks of this segment + unsafe_vector chunks; + //! The tuple count of this segment + idx_t count; + //! The data size of this segment + idx_t data_size; + + //! Lock for modifying pinned_handles + mutex pinned_handles_lock; + //! Where handles to row blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED + unsafe_vector pinned_row_handles; + //! Where handles to heap blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED + unsafe_vector pinned_heap_handles; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp new file mode 100644 index 00000000..6d29d36c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/row/tuple_data_states.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/perfect_map_set.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +enum class TupleDataPinProperties : uint8_t { + INVALID, + //! Keeps all passed blocks pinned while scanning/iterating over the chunks (for both reading/writing) + KEEP_EVERYTHING_PINNED, + //! Unpins blocks after they are done (for both reading/writing) + UNPIN_AFTER_DONE, + //! Destroys blocks after they are done (for reading only) + DESTROY_AFTER_DONE, + //! Assumes all blocks are already pinned (for reading only) + ALREADY_PINNED +}; + +struct TupleDataPinState { + perfect_map_t row_handles; + perfect_map_t heap_handles; + TupleDataPinProperties properties = TupleDataPinProperties::INVALID; +}; + +struct CombinedListData { + UnifiedVectorFormat combined_data; + list_entry_t combined_list_entries[STANDARD_VECTOR_SIZE]; + buffer_ptr selection_data; +}; + +struct TupleDataVectorFormat { + const SelectionVector *original_sel; + SelectionVector original_owned_sel; + + UnifiedVectorFormat unified; + vector children; + unique_ptr combined_list_data; +}; + +struct TupleDataChunkState { + vector vector_data; + vector column_ids; + + Vector row_locations = Vector(LogicalType::POINTER); + Vector heap_locations = Vector(LogicalType::POINTER); + Vector heap_sizes = Vector(LogicalType::UBIGINT); +}; + +struct TupleDataAppendState { + TupleDataPinState pin_state; + TupleDataChunkState chunk_state; +}; + +struct TupleDataScanState { + TupleDataPinState pin_state; + TupleDataChunkState chunk_state; + idx_t segment_index = DConstants::INVALID_INDEX; + idx_t chunk_index = DConstants::INVALID_INDEX; +}; + +struct TupleDataParallelScanState { + TupleDataScanState scan_state; + mutex lock; +}; + +using TupleDataLocalScanState = TupleDataScanState; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/sel_cache.hpp b/src/duckdb/src/include/duckdb/common/types/sel_cache.hpp new file mode 100644 index 00000000..d0c096e1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/sel_cache.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/sel_cache.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/vector_buffer.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { + +//! Selection vector cache used for caching vector slices +struct SelCache { + unordered_map> cache; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp new file mode 100644 index 00000000..ec403fa2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp @@ -0,0 +1,210 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/selection_vector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector_size.hpp" + +namespace duckdb { +class VectorBuffer; + +struct SelectionData { + DUCKDB_API explicit SelectionData(idx_t count); + + unsafe_unique_array owned_data; +}; + +struct SelectionVector { + SelectionVector() : sel_vector(nullptr) { + } + explicit SelectionVector(sel_t *sel) { + Initialize(sel); + } + explicit SelectionVector(idx_t count) { + Initialize(count); + } + SelectionVector(idx_t start, idx_t count) { + Initialize(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < count; i++) { + set_index(i, start + i); + } + } + SelectionVector(const SelectionVector &sel_vector) { + Initialize(sel_vector); + } + explicit SelectionVector(buffer_ptr data) { + Initialize(std::move(data)); + } + SelectionVector &operator=(SelectionVector &&other) { + sel_vector = other.sel_vector; + other.sel_vector = nullptr; + selection_data = std::move(other.selection_data); + return *this; + } + +public: + static idx_t Inverted(const SelectionVector &src, SelectionVector &dst, idx_t source_size, idx_t count) { + idx_t src_idx = 0; + idx_t dst_idx = 0; + for (idx_t i = 0; i < count; i++) { + if (src_idx < source_size && src.get_index(src_idx) == i) { + src_idx++; + // This index is selected by 'src', skip it in 'dst' + continue; + } + // This index does not exist in 'src', add it to the selection of 'dst' + dst.set_index(dst_idx++, i); + } + return dst_idx; + } + + void Initialize(sel_t *sel) { + selection_data.reset(); + sel_vector = sel; + } + void Initialize(idx_t count = STANDARD_VECTOR_SIZE) { + selection_data = make_shared(count); + sel_vector = selection_data->owned_data.get(); + } + void Initialize(buffer_ptr data) { + selection_data = std::move(data); + sel_vector = selection_data->owned_data.get(); + } + void Initialize(const SelectionVector &other) { + selection_data = other.selection_data; + sel_vector = other.sel_vector; + } + + inline void set_index(idx_t idx, idx_t loc) { + sel_vector[idx] = loc; + } + inline void swap(idx_t i, idx_t j) { + sel_t tmp = sel_vector[i]; + sel_vector[i] = sel_vector[j]; + sel_vector[j] = tmp; + } + inline idx_t get_index(idx_t idx) const { + return sel_vector ? sel_vector[idx] : idx; + } + sel_t *data() { + return sel_vector; + } + const sel_t *data() const { + return sel_vector; + } + buffer_ptr sel_data() { + return selection_data; + } + buffer_ptr Slice(const SelectionVector &sel, idx_t count) const; + + string ToString(idx_t count = 0) const; + void Print(idx_t count = 0) const; + + inline sel_t &operator[](idx_t index) const { + return sel_vector[index]; + } + +private: + sel_t *sel_vector; + buffer_ptr selection_data; +}; + +class OptionalSelection { +public: + explicit inline OptionalSelection(SelectionVector *sel_p) { + Initialize(sel_p); + } + void Initialize(SelectionVector *sel_p) { + sel = sel_p; + if (sel) { + vec.Initialize(sel->data()); + sel = &vec; + } + } + + inline operator SelectionVector *() { + return sel; + } + + inline void Append(idx_t &count, const idx_t idx) { + if (sel) { + sel->set_index(count, idx); + } + ++count; + } + + inline void Advance(idx_t completed) { + if (sel) { + sel->Initialize(sel->data() + completed); + } + } + +private: + SelectionVector *sel; + SelectionVector vec; +}; + +// Contains a selection vector, combined with a count +class ManagedSelection { +public: + explicit inline ManagedSelection(idx_t size, bool initialize = true) + : initialized(initialize), size(size), internal_opt_selvec(nullptr) { + count = 0; + if (!initialized) { + return; + } + sel_vec.Initialize(size); + internal_opt_selvec.Initialize(&sel_vec); + } + +public: + bool Initialized() const { + return initialized; + } + void Initialize(idx_t size) { + D_ASSERT(!initialized); + this->size = size; + sel_vec.Initialize(size); + internal_opt_selvec.Initialize(&sel_vec); + initialized = true; + } + + inline idx_t operator[](idx_t index) const { + D_ASSERT(index < size); + return sel_vec.get_index(index); + } + inline bool IndexMapsToLocation(idx_t idx, idx_t location) const { + return idx < count && sel_vec.get_index(idx) == location; + } + inline void Append(const idx_t idx) { + internal_opt_selvec.Append(count, idx); + } + inline idx_t Count() const { + return count; + } + inline idx_t Size() const { + return size; + } + inline const SelectionVector &Selection() const { + return sel_vec; + } + inline SelectionVector &Selection() { + return sel_vec; + } + +private: + bool initialized = false; + idx_t count; + idx_t size; + SelectionVector sel_vec; + OptionalSelection internal_opt_selvec; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/string_heap.hpp b/src/duckdb/src/include/duckdb/common/types/string_heap.hpp new file mode 100644 index 00000000..59dbb3f0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/string_heap.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/string_heap.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/storage/arena_allocator.hpp" + +namespace duckdb { +//! A string heap is the owner of a set of strings, strings can be inserted into +//! it On every insert, a pointer to the inserted string is returned The +//! returned pointer will remain valid until the StringHeap is destroyed +class StringHeap { +public: + DUCKDB_API StringHeap(Allocator &allocator = Allocator::DefaultAllocator()); + + DUCKDB_API void Destroy(); + DUCKDB_API void Move(StringHeap &other); + + //! Add a string to the string heap, returns a pointer to the string + DUCKDB_API string_t AddString(const char *data, idx_t len); + //! Add a string to the string heap, returns a pointer to the string + DUCKDB_API string_t AddString(const char *data); + //! Add a string to the string heap, returns a pointer to the string + DUCKDB_API string_t AddString(const string &data); + //! Add a string to the string heap, returns a pointer to the string + DUCKDB_API string_t AddString(const string_t &data); + //! Add a blob to the string heap; blobs can be non-valid UTF8 + DUCKDB_API string_t AddBlob(const string_t &data); + //! Add a blob to the string heap; blobs can be non-valid UTF8 + DUCKDB_API string_t AddBlob(const char *data, idx_t len); + //! Allocates space for an empty string of size "len" on the heap + DUCKDB_API string_t EmptyString(idx_t len); + + //! Size of strings + DUCKDB_API idx_t SizeInBytes() const; + +private: + ArenaAllocator allocator; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/string_type.hpp b/src/duckdb/src/include/duckdb/common/types/string_type.hpp new file mode 100644 index 00000000..5cec0eea --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/string_type.hpp @@ -0,0 +1,221 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/string_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/helper.hpp" + +#include + +namespace duckdb { + +struct string_t { + friend struct StringComparisonOperators; + friend class StringSegment; + +public: + static constexpr idx_t PREFIX_BYTES = 4 * sizeof(char); + static constexpr idx_t INLINE_BYTES = 12 * sizeof(char); + static constexpr idx_t HEADER_SIZE = sizeof(uint32_t) + PREFIX_BYTES; +#ifndef DUCKDB_DEBUG_NO_INLINE + static constexpr idx_t PREFIX_LENGTH = PREFIX_BYTES; + static constexpr idx_t INLINE_LENGTH = INLINE_BYTES; +#else + static constexpr idx_t PREFIX_LENGTH = 0; + static constexpr idx_t INLINE_LENGTH = 0; +#endif + + string_t() = default; + explicit string_t(uint32_t len) { + value.inlined.length = len; + } + string_t(const char *data, uint32_t len) { + value.inlined.length = len; + D_ASSERT(data || GetSize() == 0); + if (IsInlined()) { + // zero initialize the prefix first + // this makes sure that strings with length smaller than 4 still have an equal prefix + memset(value.inlined.inlined, 0, INLINE_BYTES); + if (GetSize() == 0) { + return; + } + // small string: inlined + memcpy(value.inlined.inlined, data, GetSize()); + } else { + // large string: store pointer +#ifndef DUCKDB_DEBUG_NO_INLINE + memcpy(value.pointer.prefix, data, PREFIX_LENGTH); +#else + memset(value.pointer.prefix, 0, PREFIX_BYTES); +#endif + value.pointer.ptr = (char *)data; // NOLINT + } + } + string_t(const char *data) : string_t(data, strlen(data)) { // NOLINT: Allow implicit conversion from `const char*` + } + string_t(const string &value) + : string_t(value.c_str(), value.size()) { // NOLINT: Allow implicit conversion from `const char*` + } + + bool IsInlined() const { + return GetSize() <= INLINE_LENGTH; + } + + const char *GetData() const { + return IsInlined() ? const_char_ptr_cast(value.inlined.inlined) : value.pointer.ptr; + } + const char *GetDataUnsafe() const { + return GetData(); + } + + char *GetDataWriteable() const { + return IsInlined() ? (char *)value.inlined.inlined : value.pointer.ptr; // NOLINT + } + + const char *GetPrefix() const { + return value.pointer.prefix; + } + + char *GetPrefixWriteable() const { + return (char *)value.pointer.prefix; + } + + idx_t GetSize() const { + return value.inlined.length; + } + + string GetString() const { + return string(GetData(), GetSize()); + } + + explicit operator string() const { + return GetString(); + } + + char *GetPointer() const { + D_ASSERT(!IsInlined()); + return value.pointer.ptr; + } + + void SetPointer(char *new_ptr) { + D_ASSERT(!IsInlined()); + value.pointer.ptr = new_ptr; + } + + void Finalize() { + // set trailing NULL byte + if (GetSize() <= INLINE_LENGTH) { + // fill prefix with zeros if the length is smaller than the prefix length + for (idx_t i = GetSize(); i < INLINE_BYTES; i++) { + value.inlined.inlined[i] = '\0'; + } + } else { + // copy the data into the prefix +#ifndef DUCKDB_DEBUG_NO_INLINE + auto dataptr = GetData(); + memcpy(value.pointer.prefix, dataptr, PREFIX_LENGTH); +#else + memset(value.pointer.prefix, 0, PREFIX_BYTES); +#endif + } + } + + void Verify() const; + void VerifyNull() const; + + struct StringComparisonOperators { + static inline bool Equals(const string_t &a, const string_t &b) { +#ifdef DUCKDB_DEBUG_NO_INLINE + if (a.GetSize() != b.GetSize()) + return false; + return (memcmp(a.GetData(), b.GetData(), a.GetSize()) == 0); +#endif + uint64_t A = Load(const_data_ptr_cast(&a)); + uint64_t B = Load(const_data_ptr_cast(&b)); + if (A != B) { + // Either length or prefix are different -> not equal + return false; + } + // they have the same length and same prefix! + A = Load(const_data_ptr_cast(&a) + 8u); + B = Load(const_data_ptr_cast(&b) + 8u); + if (A == B) { + // either they are both inlined (so compare equal) or point to the same string (so compare equal) + return true; + } + if (!a.IsInlined()) { + // 'long' strings of the same length -> compare pointed value + if (memcmp(a.value.pointer.ptr, b.value.pointer.ptr, a.GetSize()) == 0) { + return true; + } + } + // either they are short string of same length but different content + // or they point to string with different content + // either way, they can't represent the same underlying string + return false; + } + // compare up to shared length. if still the same, compare lengths + static bool GreaterThan(const string_t &left, const string_t &right) { + const uint32_t left_length = left.GetSize(); + const uint32_t right_length = right.GetSize(); + const uint32_t min_length = std::min(left_length, right_length); + +#ifndef DUCKDB_DEBUG_NO_INLINE + uint32_t A = Load(const_data_ptr_cast(left.GetPrefix())); + uint32_t B = Load(const_data_ptr_cast(right.GetPrefix())); + + // Utility to move 0xa1b2c3d4 into 0xd4c3b2a1, basically inverting the order byte-a-byte + auto bswap = [](uint32_t v) -> uint32_t { + uint32_t t1 = (v >> 16u) | (v << 16u); + uint32_t t2 = t1 & 0x00ff00ff; + uint32_t t3 = t1 & 0xff00ff00; + return (t2 << 8u) | (t3 >> 8u); + }; + + // Check on prefix ----- + // We dont' need to mask since: + // if the prefix is greater(after bswap), it will stay greater regardless of the extra bytes + // if the prefix is smaller(after bswap), it will stay smaller regardless of the extra bytes + // if the prefix is equal, the extra bytes are guaranteed to be /0 for the shorter one + + if (A != B) + return bswap(A) > bswap(B); +#endif + auto memcmp_res = memcmp(left.GetData(), right.GetData(), min_length); + return memcmp_res > 0 || (memcmp_res == 0 && left_length > right_length); + } + }; + + bool operator==(const string_t &r) const { + return StringComparisonOperators::Equals(*this, r); + } + + bool operator>(const string_t &r) const { + return StringComparisonOperators::GreaterThan(*this, r); + } + bool operator<(const string_t &r) const { + return r > *this; + } + +private: + union { + struct { + uint32_t length; + char prefix[4]; + char *ptr; + } pointer; + struct { + uint32_t length; + char inlined[12]; + } inlined; + } value; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/time.hpp b/src/duckdb/src/include/duckdb/common/types/time.hpp new file mode 100644 index 00000000..c19f2a01 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/time.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/time.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/winapi.hpp" + +namespace duckdb { + +struct dtime_t; +struct dtime_tz_t; + +//! The Time class is a static class that holds helper functions for the Time +//! type. +class Time { +public: + //! Convert a string in the format "hh:mm:ss" to a time object + DUCKDB_API static dtime_t FromString(const string &str, bool strict = false); + DUCKDB_API static dtime_t FromCString(const char *buf, idx_t len, bool strict = false); + DUCKDB_API static bool TryConvertTime(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict = false); + DUCKDB_API static bool TryConvertTimeTZ(const char *buf, idx_t len, idx_t &pos, dtime_tz_t &result, + bool strict = false); + //! Format is ±[HH]HH[:MM[:SS]] + DUCKDB_API static bool TryParseUTCOffset(const char *str, idx_t &pos, idx_t len, int32_t &offset); + + //! Convert a time object to a string in the format "hh:mm:ss" + DUCKDB_API static string ToString(dtime_t time); + //! Convert a UTC offset to ±HH[:MM] + DUCKDB_API static string ToUTCOffset(int hour_offset, int minute_offset); + + DUCKDB_API static dtime_t FromTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds = 0); + + //! Extract the time from a given timestamp object + DUCKDB_API static void Convert(dtime_t time, int32_t &out_hour, int32_t &out_min, int32_t &out_sec, + int32_t &out_micros); + + DUCKDB_API static string ConversionError(const string &str); + DUCKDB_API static string ConversionError(string_t str); + + DUCKDB_API static dtime_t FromTimeMs(int64_t time_ms); + DUCKDB_API static dtime_t FromTimeNs(int64_t time_ns); + + DUCKDB_API static bool IsValidTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds); + +private: + static bool TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/timestamp.hpp b/src/duckdb/src/include/duckdb/common/types/timestamp.hpp new file mode 100644 index 00000000..0bb5101a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/timestamp.hpp @@ -0,0 +1,206 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/timestamp.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/winapi.hpp" + +#include + +namespace duckdb { + +struct date_t; +struct dtime_t; + +//! Type used to represent timestamps (seconds,microseconds,milliseconds or nanoseconds since 1970-01-01) +struct timestamp_t { // NOLINT + int64_t value; + + timestamp_t() = default; + explicit inline constexpr timestamp_t(int64_t value_p) : value(value_p) { + } + inline timestamp_t &operator=(int64_t value_p) { + value = value_p; + return *this; + } + + // explicit conversion + explicit inline operator int64_t() const { + return value; + } + + // comparison operators + inline bool operator==(const timestamp_t &rhs) const { + return value == rhs.value; + }; + inline bool operator!=(const timestamp_t &rhs) const { + return value != rhs.value; + }; + inline bool operator<=(const timestamp_t &rhs) const { + return value <= rhs.value; + }; + inline bool operator<(const timestamp_t &rhs) const { + return value < rhs.value; + }; + inline bool operator>(const timestamp_t &rhs) const { + return value > rhs.value; + }; + inline bool operator>=(const timestamp_t &rhs) const { + return value >= rhs.value; + }; + + // arithmetic operators + timestamp_t operator+(const double &value) const; + int64_t operator-(const timestamp_t &other) const; + + // in-place operators + timestamp_t &operator+=(const int64_t &delta); + timestamp_t &operator-=(const int64_t &delta); + + // special values + static constexpr timestamp_t infinity() { // NOLINT + return timestamp_t(NumericLimits::Maximum()); + } // NOLINT + static constexpr timestamp_t ninfinity() { // NOLINT + return timestamp_t(-NumericLimits::Maximum()); + } // NOLINT + static constexpr inline timestamp_t epoch() { // NOLINT + return timestamp_t(0); + } // NOLINT +}; + +struct timestamp_tz_t : public timestamp_t { // NOLINT +}; +struct timestamp_ns_t : public timestamp_t { // NOLINT +}; +struct timestamp_ms_t : public timestamp_t { // NOLINT +}; +struct timestamp_sec_t : public timestamp_t { // NOLINT +}; + +enum class TimestampCastResult : uint8_t { SUCCESS, ERROR_INCORRECT_FORMAT, ERROR_NON_UTC_TIMEZONE }; + +//! The Timestamp class is a static class that holds helper functions for the Timestamp +//! type. +class Timestamp { +public: + // min timestamp is 290308-12-22 (BC) + constexpr static const int32_t MIN_YEAR = -290308; + constexpr static const int32_t MIN_MONTH = 12; + constexpr static const int32_t MIN_DAY = 22; + +public: + //! Convert a string in the format "YYYY-MM-DD hh:mm:ss[.f][-+TH[:tm]]" to a timestamp object + DUCKDB_API static timestamp_t FromString(const string &str); + //! Convert a string where the offset can also be a time zone string: / [A_Za-z0-9/_]+/ + //! If has_offset is true, then the result is an instant that was offset from UTC + //! If the tz is not empty, the result is still an instant, but the parts can be extracted and applied to the TZ + DUCKDB_API static bool TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &result, bool &has_offset, + string_t &tz); + DUCKDB_API static TimestampCastResult TryConvertTimestamp(const char *str, idx_t len, timestamp_t &result); + DUCKDB_API static timestamp_t FromCString(const char *str, idx_t len); + //! Convert a date object to a string in the format "YYYY-MM-DD hh:mm:ss" + DUCKDB_API static string ToString(timestamp_t timestamp); + + DUCKDB_API static date_t GetDate(timestamp_t timestamp); + + DUCKDB_API static dtime_t GetTime(timestamp_t timestamp); + //! Create a Timestamp object from a specified (date, time) combination + DUCKDB_API static timestamp_t FromDatetime(date_t date, dtime_t time); + DUCKDB_API static bool TryFromDatetime(date_t date, dtime_t time, timestamp_t &result); + + //! Is the character a valid part of a time zone name? + static inline bool CharacterIsTimeZone(char c) { + return StringUtil::CharacterIsAlpha(c) || StringUtil::CharacterIsDigit(c) || c == '_' || c == '/' || c == '+' || + c == '-'; + } + + //! Is the timestamp finite or infinite? + static inline bool IsFinite(timestamp_t timestamp) { + return timestamp != timestamp_t::infinity() && timestamp != timestamp_t::ninfinity(); + } + + //! Extract the date and time from a given timestamp object + DUCKDB_API static void Convert(timestamp_t date, date_t &out_date, dtime_t &out_time); + //! Returns current timestamp + DUCKDB_API static timestamp_t GetCurrentTimestamp(); + + //! Convert the epoch (in sec) to a timestamp + DUCKDB_API static timestamp_t FromEpochSeconds(int64_t ms); + //! Convert the epoch (in ms) to a timestamp + DUCKDB_API static timestamp_t FromEpochMs(int64_t ms); + //! Convert the epoch (in microseconds) to a timestamp + DUCKDB_API static timestamp_t FromEpochMicroSeconds(int64_t micros); + //! Convert the epoch (in nanoseconds) to a timestamp + DUCKDB_API static timestamp_t FromEpochNanoSeconds(int64_t micros); + + //! Convert the epoch (in seconds) to a timestamp + DUCKDB_API static int64_t GetEpochSeconds(timestamp_t timestamp); + //! Convert the epoch (in ms) to a timestamp + DUCKDB_API static int64_t GetEpochMs(timestamp_t timestamp); + //! Convert a timestamp to epoch (in microseconds) + DUCKDB_API static int64_t GetEpochMicroSeconds(timestamp_t timestamp); + //! Convert a timestamp to epoch (in nanoseconds) + DUCKDB_API static int64_t GetEpochNanoSeconds(timestamp_t timestamp); + //! Convert a timestamp to a Julian Day + DUCKDB_API static double GetJulianDay(timestamp_t timestamp); + + DUCKDB_API static bool TryParseUTCOffset(const char *str, idx_t &pos, idx_t len, int &hour_offset, + int &minute_offset); + + DUCKDB_API static string ConversionError(const string &str); + DUCKDB_API static string ConversionError(string_t str); + DUCKDB_API static string UnsupportedTimezoneError(const string &str); + DUCKDB_API static string UnsupportedTimezoneError(string_t str); +}; + +} // namespace duckdb + +namespace std { + +//! Timestamp +template <> +struct hash { + std::size_t operator()(const duckdb::timestamp_t &k) const { + using std::hash; + return hash()((int64_t)k); + } +}; +template <> +struct hash { + std::size_t operator()(const duckdb::timestamp_ms_t &k) const { + using std::hash; + return hash()((int64_t)k); + } +}; +template <> +struct hash { + std::size_t operator()(const duckdb::timestamp_ns_t &k) const { + using std::hash; + return hash()((int64_t)k); + } +}; +template <> +struct hash { + std::size_t operator()(const duckdb::timestamp_sec_t &k) const { + using std::hash; + return hash()((int64_t)k); + } +}; +template <> +struct hash { + std::size_t operator()(const duckdb::timestamp_tz_t &k) const { + using std::hash; + return hash()((int64_t)k); + } +}; +} // namespace std diff --git a/src/duckdb/src/include/duckdb/common/types/type_map.hpp b/src/duckdb/src/include/duckdb/common/types/type_map.hpp new file mode 100644 index 00000000..0d3566fc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/type_map.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/type_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +struct LogicalTypeHashFunction { + uint64_t operator()(const LogicalType &type) const { + return (uint64_t)type.Hash(); + } +}; + +struct LogicalTypeEquality { + bool operator()(const LogicalType &a, const LogicalType &b) const { + return a == b; + } +}; + +template +using type_map_t = unordered_map; + +using type_set_t = unordered_set; + +struct LogicalTypeIdHashFunction { + uint64_t operator()(const LogicalTypeId &type_id) const { + return duckdb::Hash((uint8_t)type_id); + } +}; + +struct LogicalTypeIdEquality { + bool operator()(const LogicalTypeId &a, const LogicalTypeId &b) const { + return a == b; + } +}; + +template +using type_id_map_t = unordered_map; + +using type_id_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/uuid.hpp b/src/duckdb/src/include/duckdb/common/types/uuid.hpp new file mode 100644 index 00000000..36a70d73 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/uuid.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/uuid.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/string_type.hpp" + +namespace duckdb { +class ClientContext; +struct RandomEngine; + +//! The UUID class contains static operations for the UUID type +class UUID { +public: + constexpr static const uint8_t STRING_SIZE = 36; + //! Convert a uuid string to a hugeint object + static bool FromString(string str, hugeint_t &result); + //! Convert a uuid string to a hugeint object + static bool FromCString(const char *str, idx_t len, hugeint_t &result) { + return FromString(string(str, 0, len), result); + } + //! Convert a hugeint object to a uuid style string + static void ToString(hugeint_t input, char *buf); + + //! Convert a hugeint object to a uuid style string + static hugeint_t GenerateRandomUUID(RandomEngine &engine); + static hugeint_t GenerateRandomUUID(); + + //! Convert a hugeint object to a uuid style string + static string ToString(hugeint_t input) { + char buff[STRING_SIZE]; + ToString(input, buff); + return string(buff, STRING_SIZE); + } + + static hugeint_t FromString(string str) { + hugeint_t result; + FromString(str, result); + return result; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/validity_mask.hpp b/src/duckdb/src/include/duckdb/common/types/validity_mask.hpp new file mode 100644 index 00000000..d23c862c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/validity_mask.hpp @@ -0,0 +1,343 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/validity_mask.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector_size.hpp" + +namespace duckdb { +struct ValidityMask; + +template +struct TemplatedValidityData { + static constexpr const int BITS_PER_VALUE = sizeof(V) * 8; + static constexpr const V MAX_ENTRY = ~V(0); + +public: + inline explicit TemplatedValidityData(idx_t count) { + auto entry_count = EntryCount(count); + owned_data = make_unsafe_uniq_array(entry_count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + owned_data[entry_idx] = MAX_ENTRY; + } + } + inline TemplatedValidityData(const V *validity_mask, idx_t count) { + D_ASSERT(validity_mask); + auto entry_count = EntryCount(count); + owned_data = make_unsafe_uniq_array(entry_count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + owned_data[entry_idx] = validity_mask[entry_idx]; + } + } + + unsafe_unique_array owned_data; + +public: + static inline idx_t EntryCount(idx_t count) { + return (count + (BITS_PER_VALUE - 1)) / BITS_PER_VALUE; + } +}; + +using validity_t = uint64_t; + +struct ValidityData : TemplatedValidityData { +public: + DUCKDB_API explicit ValidityData(idx_t count); + DUCKDB_API ValidityData(const ValidityMask &original, idx_t count); +}; + +//! Type used for validity masks +template +struct TemplatedValidityMask { + using ValidityBuffer = TemplatedValidityData; + +public: + static constexpr const int BITS_PER_VALUE = ValidityBuffer::BITS_PER_VALUE; + static constexpr const int STANDARD_ENTRY_COUNT = (STANDARD_VECTOR_SIZE + (BITS_PER_VALUE - 1)) / BITS_PER_VALUE; + static constexpr const int STANDARD_MASK_SIZE = STANDARD_ENTRY_COUNT * sizeof(validity_t); + +public: + inline TemplatedValidityMask() : validity_mask(nullptr) { + } + inline explicit TemplatedValidityMask(idx_t max_count) { + Initialize(max_count); + } + inline explicit TemplatedValidityMask(V *ptr) : validity_mask(ptr) { + } + inline TemplatedValidityMask(const TemplatedValidityMask &original, idx_t count) { + Copy(original, count); + } + + static inline idx_t ValidityMaskSize(idx_t count = STANDARD_VECTOR_SIZE) { + return ValidityBuffer::EntryCount(count) * sizeof(V); + } + inline bool AllValid() const { + return !validity_mask; + } + inline bool CheckAllValid(idx_t count) const { + return CountValid(count) == count; + } + + inline bool CheckAllValid(idx_t to, idx_t from) const { + if (AllValid()) { + return true; + } + for (idx_t i = from; i < to; i++) { + if (!RowIsValid(i)) { + return false; + } + } + return true; + } + + idx_t CountValid(const idx_t count) const { + if (AllValid() || count == 0) { + return count; + } + + idx_t valid = 0; + const auto entry_count = EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count;) { + auto entry = GetValidityEntry(entry_idx++); + // Handle ragged end (if not exactly multiple of BITS_PER_VALUE) + if (entry_idx == entry_count && count % BITS_PER_VALUE != 0) { + idx_t idx_in_entry; + GetEntryIndex(count, entry_idx, idx_in_entry); + for (idx_t i = 0; i < idx_in_entry; ++i) { + valid += idx_t(RowIsValid(entry, i)); + } + break; + } + + // Handle all set + if (AllValid(entry)) { + valid += BITS_PER_VALUE; + continue; + } + + // Count partial entry (Kernighan's algorithm) + while (entry) { + entry &= (entry - 1); + ++valid; + } + } + + return valid; + } + + inline V *GetData() const { + return validity_mask; + } + inline void Reset() { + validity_mask = nullptr; + validity_data.reset(); + } + + static inline idx_t EntryCount(idx_t count) { + return ValidityBuffer::EntryCount(count); + } + inline V GetValidityEntry(idx_t entry_idx) const { + if (!validity_mask) { + return ValidityBuffer::MAX_ENTRY; + } + return GetValidityEntryUnsafe(entry_idx); + } + inline V &GetValidityEntryUnsafe(idx_t entry_idx) const { + return validity_mask[entry_idx]; + } + static inline bool AllValid(V entry) { + return entry == ValidityBuffer::MAX_ENTRY; + } + static inline bool NoneValid(V entry) { + return entry == 0; + } + static inline bool RowIsValid(const V &entry, const idx_t &idx_in_entry) { + return entry & (V(1) << V(idx_in_entry)); + } + static inline void GetEntryIndex(idx_t row_idx, idx_t &entry_idx, idx_t &idx_in_entry) { + entry_idx = row_idx / BITS_PER_VALUE; + idx_in_entry = row_idx % BITS_PER_VALUE; + } + //! Get an entry that has first-n bits set as valid and rest set as invalid + static inline V EntryWithValidBits(idx_t n) { + if (n == 0) { + return V(0); + } + return ValidityBuffer::MAX_ENTRY >> (BITS_PER_VALUE - n); + } + static inline idx_t SizeInBytes(idx_t n) { + return (n + BITS_PER_VALUE - 1) / BITS_PER_VALUE; + } + + //! RowIsValidUnsafe should only be used if AllValid() is false: it achieves the same as RowIsValid but skips a + //! not-null check + inline bool RowIsValidUnsafe(idx_t row_idx) const { + D_ASSERT(validity_mask); + idx_t entry_idx, idx_in_entry; + GetEntryIndex(row_idx, entry_idx, idx_in_entry); + auto entry = GetValidityEntry(entry_idx); + return RowIsValid(entry, idx_in_entry); + } + + //! Returns true if a row is valid (i.e. not null), false otherwise + inline bool RowIsValid(idx_t row_idx) const { + if (!validity_mask) { + return true; + } + return RowIsValidUnsafe(row_idx); + } + + //! Same as SetValid, but skips a null check on validity_mask + inline void SetValidUnsafe(idx_t row_idx) { + D_ASSERT(validity_mask); + idx_t entry_idx, idx_in_entry; + GetEntryIndex(row_idx, entry_idx, idx_in_entry); + validity_mask[entry_idx] |= (V(1) << V(idx_in_entry)); + } + + //! Marks the entry at the specified row index as valid (i.e. not-null) + inline void SetValid(idx_t row_idx) { + if (!validity_mask) { + // if AllValid() we don't need to do anything + // the row is already valid + return; + } + SetValidUnsafe(row_idx); + } + + //! Marks the bit at the specified entry as invalid (i.e. null) + inline void SetInvalidUnsafe(idx_t entry_idx, idx_t idx_in_entry) { + D_ASSERT(validity_mask); + validity_mask[entry_idx] &= ~(V(1) << V(idx_in_entry)); + } + + //! Marks the bit at the specified row index as invalid (i.e. null) + inline void SetInvalidUnsafe(idx_t row_idx) { + idx_t entry_idx, idx_in_entry; + GetEntryIndex(row_idx, entry_idx, idx_in_entry); + SetInvalidUnsafe(entry_idx, idx_in_entry); + } + + //! Marks the entry at the specified row index as invalid (i.e. null) + inline void SetInvalid(idx_t row_idx) { + if (!validity_mask) { + D_ASSERT(row_idx <= STANDARD_VECTOR_SIZE); + Initialize(STANDARD_VECTOR_SIZE); + } + SetInvalidUnsafe(row_idx); + } + + //! Mark the entry at the specified index as either valid or invalid (non-null or null) + inline void Set(idx_t row_idx, bool valid) { + if (valid) { + SetValid(row_idx); + } else { + SetInvalid(row_idx); + } + } + + //! Ensure the validity mask is writable, allocating space if it is not initialized + inline void EnsureWritable() { + if (!validity_mask) { + Initialize(); + } + } + + //! Marks exactly "count" bits in the validity mask as invalid (null) + inline void SetAllInvalid(idx_t count) { + EnsureWritable(); + if (count == 0) { + return; + } + auto last_entry_index = ValidityBuffer::EntryCount(count) - 1; + for (idx_t i = 0; i < last_entry_index; i++) { + validity_mask[i] = 0; + } + auto last_entry_bits = count % static_cast(BITS_PER_VALUE); + validity_mask[last_entry_index] = (last_entry_bits == 0) ? 0 : (ValidityBuffer::MAX_ENTRY << (last_entry_bits)); + } + + //! Marks exactly "count" bits in the validity mask as valid (not null) + inline void SetAllValid(idx_t count) { + EnsureWritable(); + if (count == 0) { + return; + } + auto last_entry_index = ValidityBuffer::EntryCount(count) - 1; + for (idx_t i = 0; i < last_entry_index; i++) { + validity_mask[i] = ValidityBuffer::MAX_ENTRY; + } + auto last_entry_bits = count % static_cast(BITS_PER_VALUE); + validity_mask[last_entry_index] |= + (last_entry_bits == 0) ? ValidityBuffer::MAX_ENTRY : ~(ValidityBuffer::MAX_ENTRY << (last_entry_bits)); + } + + inline bool IsMaskSet() const { + if (validity_mask) { + return true; + } + return false; + } + +public: + inline void Initialize(validity_t *validity) { + validity_data.reset(); + validity_mask = validity; + } + inline void Initialize(const TemplatedValidityMask &other) { + validity_mask = other.validity_mask; + validity_data = other.validity_data; + } + inline void Initialize(idx_t count = STANDARD_VECTOR_SIZE) { + validity_data = make_buffer(count); + validity_mask = validity_data->owned_data.get(); + } + inline void Copy(const TemplatedValidityMask &other, idx_t count) { + if (other.AllValid()) { + validity_data = nullptr; + validity_mask = nullptr; + } else { + validity_data = make_buffer(other.validity_mask, count); + validity_mask = validity_data->owned_data.get(); + } + } + +protected: + V *validity_mask; + buffer_ptr validity_data; +}; + +struct ValidityMask : public TemplatedValidityMask { +public: + inline ValidityMask() : TemplatedValidityMask(nullptr) { + } + inline explicit ValidityMask(idx_t max_count) : TemplatedValidityMask(max_count) { + } + inline explicit ValidityMask(validity_t *ptr) : TemplatedValidityMask(ptr) { + } + inline ValidityMask(const ValidityMask &original, idx_t count) : TemplatedValidityMask(original, count) { + } + +public: + DUCKDB_API void Resize(idx_t old_size, idx_t new_size); + + DUCKDB_API void SliceInPlace(const ValidityMask &other, idx_t target_offset, idx_t source_offset, idx_t count); + DUCKDB_API void Slice(const ValidityMask &other, idx_t source_offset, idx_t count); + DUCKDB_API void Combine(const ValidityMask &other, idx_t count); + DUCKDB_API string ToString(idx_t count) const; + + DUCKDB_API static bool IsAligned(idx_t count); + + void Write(WriteStream &writer, idx_t count); + void Read(ReadStream &reader, idx_t count); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/value.hpp b/src/duckdb/src/include/duckdb/common/types/value.hpp new file mode 100644 index 00000000..1d748169 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/value.hpp @@ -0,0 +1,545 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/value.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/datetime.hpp" +#include "duckdb/common/types/interval.hpp" + +namespace duckdb { + +class CastFunctionSet; +struct GetCastFunctionInput; +struct ExtraValueInfo; + +//! The Value object holds a single arbitrary value of any type that can be +//! stored in the database. +class Value { + friend struct StringValue; + friend struct StructValue; + friend struct ListValue; + friend struct UnionValue; + +public: + //! Create an empty NULL value of the specified type + DUCKDB_API explicit Value(LogicalType type = LogicalType::SQLNULL); + //! Create an INTEGER value + DUCKDB_API Value(int32_t val); // NOLINT: Allow implicit conversion from `int32_t` + //! Create a BIGINT value + DUCKDB_API Value(int64_t val); // NOLINT: Allow implicit conversion from `int64_t` + //! Create a FLOAT value + DUCKDB_API Value(float val); // NOLINT: Allow implicit conversion from `float` + //! Create a DOUBLE value + DUCKDB_API Value(double val); // NOLINT: Allow implicit conversion from `double` + //! Create a VARCHAR value + DUCKDB_API Value(const char *val); // NOLINT: Allow implicit conversion from `const char *` + //! Create a NULL value + DUCKDB_API Value(std::nullptr_t val); // NOLINT: Allow implicit conversion from `nullptr_t` + //! Create a VARCHAR value + DUCKDB_API Value(string_t val); // NOLINT: Allow implicit conversion from `string_t` + //! Create a VARCHAR value + DUCKDB_API Value(string val); // NOLINT: Allow implicit conversion from `string` + //! Copy constructor + DUCKDB_API Value(const Value &other); + //! Move constructor + DUCKDB_API Value(Value &&other) noexcept; + //! Destructor + DUCKDB_API ~Value(); + + // copy assignment + DUCKDB_API Value &operator=(const Value &other); + // move assignment + DUCKDB_API Value &operator=(Value &&other) noexcept; + + inline LogicalType &GetTypeMutable() { + return type_; + } + inline const LogicalType &type() const { // NOLINT + return type_; + } + inline bool IsNull() const { + return is_null; + } + + //! Create the lowest possible value of a given type (numeric only) + DUCKDB_API static Value MinimumValue(const LogicalType &type); + //! Create the highest possible value of a given type (numeric only) + DUCKDB_API static Value MaximumValue(const LogicalType &type); + //! Create the negative infinite value of a given type (numeric only) + DUCKDB_API static Value NegativeInfinity(const LogicalType &type); + //! Create the positive infinite value of a given type (numeric only) + DUCKDB_API static Value Infinity(const LogicalType &type); + //! Create a Numeric value of the specified type with the specified value + DUCKDB_API static Value Numeric(const LogicalType &type, int64_t value); + DUCKDB_API static Value Numeric(const LogicalType &type, hugeint_t value); + + //! Create a tinyint Value from a specified value + DUCKDB_API static Value BOOLEAN(int8_t value); + //! Create a tinyint Value from a specified value + DUCKDB_API static Value TINYINT(int8_t value); + //! Create a smallint Value from a specified value + DUCKDB_API static Value SMALLINT(int16_t value); + //! Create an integer Value from a specified value + DUCKDB_API static Value INTEGER(int32_t value); + //! Create a bigint Value from a specified value + DUCKDB_API static Value BIGINT(int64_t value); + //! Create an unsigned tinyint Value from a specified value + DUCKDB_API static Value UTINYINT(uint8_t value); + //! Create an unsigned smallint Value from a specified value + DUCKDB_API static Value USMALLINT(uint16_t value); + //! Create an unsigned integer Value from a specified value + DUCKDB_API static Value UINTEGER(uint32_t value); + //! Create an unsigned bigint Value from a specified value + DUCKDB_API static Value UBIGINT(uint64_t value); + //! Create a hugeint Value from a specified value + DUCKDB_API static Value HUGEINT(hugeint_t value); + //! Create a uuid Value from a specified value + DUCKDB_API static Value UUID(const string &value); + //! Create a uuid Value from a specified value + DUCKDB_API static Value UUID(hugeint_t value); + //! Create a hash Value from a specified value + DUCKDB_API static Value HASH(hash_t value); + //! Create a pointer Value from a specified value + DUCKDB_API static Value POINTER(uintptr_t value); + //! Create a date Value from a specified date + DUCKDB_API static Value DATE(date_t date); + //! Create a date Value from a specified date + DUCKDB_API static Value DATE(int32_t year, int32_t month, int32_t day); + //! Create a time Value from a specified time + DUCKDB_API static Value TIME(dtime_t time); + DUCKDB_API static Value TIMETZ(dtime_tz_t time); + //! Create a time Value from a specified time + DUCKDB_API static Value TIME(int32_t hour, int32_t min, int32_t sec, int32_t micros); + //! Create a timestamp Value from a specified date/time combination + DUCKDB_API static Value TIMESTAMP(date_t date, dtime_t time); + //! Create a timestamp Value from a specified timestamp + DUCKDB_API static Value TIMESTAMP(timestamp_t timestamp); + DUCKDB_API static Value TIMESTAMPNS(timestamp_t timestamp); + DUCKDB_API static Value TIMESTAMPMS(timestamp_t timestamp); + DUCKDB_API static Value TIMESTAMPSEC(timestamp_t timestamp); + DUCKDB_API static Value TIMESTAMPTZ(timestamp_t timestamp); + //! Create a timestamp Value from a specified timestamp in separate values + DUCKDB_API static Value TIMESTAMP(int32_t year, int32_t month, int32_t day, int32_t hour, int32_t min, int32_t sec, + int32_t micros); + DUCKDB_API static Value INTERVAL(int32_t months, int32_t days, int64_t micros); + DUCKDB_API static Value INTERVAL(interval_t interval); + + // Create a enum Value from a specified uint value + DUCKDB_API static Value ENUM(uint64_t value, const LogicalType &original_type); + + // Decimal values + DUCKDB_API static Value DECIMAL(int16_t value, uint8_t width, uint8_t scale); + DUCKDB_API static Value DECIMAL(int32_t value, uint8_t width, uint8_t scale); + DUCKDB_API static Value DECIMAL(int64_t value, uint8_t width, uint8_t scale); + DUCKDB_API static Value DECIMAL(hugeint_t value, uint8_t width, uint8_t scale); + //! Create a float Value from a specified value + DUCKDB_API static Value FLOAT(float value); + //! Create a double Value from a specified value + DUCKDB_API static Value DOUBLE(double value); + //! Create a struct value with given list of entries + DUCKDB_API static Value STRUCT(child_list_t values); + //! Create a list value with the given entries, list type is inferred from children + //! Cannot be called with an empty list, use either EMPTYLIST or LIST with a type instead + DUCKDB_API static Value LIST(vector values); + //! Create a list value with the given entries + DUCKDB_API static Value LIST(const LogicalType &child_type, vector values); + //! Create an empty list with the specified child-type + DUCKDB_API static Value EMPTYLIST(const LogicalType &child_type); + //! Create a map value with the given entries + DUCKDB_API static Value MAP(const LogicalType &child_type, vector values); + //! Create a union value from a selected value and a tag from a set of alternatives. + DUCKDB_API static Value UNION(child_list_t members, uint8_t tag, Value value); + + //! Create a blob Value from a data pointer and a length: no bytes are interpreted + DUCKDB_API static Value BLOB(const_data_ptr_t data, idx_t len); + static Value BLOB_RAW(const string &data) { // NOLINT + return Value::BLOB(const_data_ptr_cast(data.c_str()), data.size()); + } + //! Creates a blob by casting a specified string to a blob (i.e. interpreting \x characters) + DUCKDB_API static Value BLOB(const string &data); + //! Creates a bitstring by casting a specified string to a bitstring + DUCKDB_API static Value BIT(const_data_ptr_t data, idx_t len); + DUCKDB_API static Value BIT(const string &data); + + template + T GetValue() const; + template + static Value CreateValue(T value) { + static_assert(AlwaysFalse::value, "No specialization exists for this type"); + return Value(nullptr); + } + // Returns the internal value. Unlike GetValue(), this method does not perform casting, and assumes T matches the + // type of the value. Only use this if you know what you are doing. + template + T GetValueUnsafe() const; + //! Returns a reference to the internal value. This can only be used for primitive types. + template + T &GetReferenceUnsafe(); + + //! Return a copy of this value + Value Copy() const { + return Value(*this); + } + + //! Hashes the Value + DUCKDB_API hash_t Hash() const; + //! Convert this value to a string + DUCKDB_API string ToString() const; + //! Convert this value to a SQL-parseable string + DUCKDB_API string ToSQLString() const; + + DUCKDB_API uintptr_t GetPointer() const; + + //! Cast this value to another type, throws exception if its not possible + DUCKDB_API Value CastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, + bool strict = false) const; + DUCKDB_API Value CastAs(ClientContext &context, const LogicalType &target_type, bool strict = false) const; + DUCKDB_API Value DefaultCastAs(const LogicalType &target_type, bool strict = false) const; + //! Tries to cast this value to another type, and stores the result in "new_value" + DUCKDB_API bool TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, + Value &new_value, string *error_message, bool strict = false) const; + DUCKDB_API bool TryCastAs(ClientContext &context, const LogicalType &target_type, Value &new_value, + string *error_message, bool strict = false) const; + DUCKDB_API bool DefaultTryCastAs(const LogicalType &target_type, Value &new_value, string *error_message, + bool strict = false) const; + //! Tries to cast this value to another type, and stores the result in THIS value again + DUCKDB_API bool TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, + bool strict = false); + DUCKDB_API bool TryCastAs(ClientContext &context, const LogicalType &target_type, bool strict = false); + DUCKDB_API bool DefaultTryCastAs(const LogicalType &target_type, bool strict = false); + + DUCKDB_API void Reinterpret(LogicalType new_type); + + //! Serializes a Value to a stand-alone binary blob + DUCKDB_API void Serialize(Serializer &serializer) const; + //! Deserializes a Value from a blob + DUCKDB_API static Value Deserialize(Deserializer &deserializer); + + //===--------------------------------------------------------------------===// + // Comparison Operators + //===--------------------------------------------------------------------===// + DUCKDB_API bool operator==(const Value &rhs) const; + DUCKDB_API bool operator!=(const Value &rhs) const; + DUCKDB_API bool operator<(const Value &rhs) const; + DUCKDB_API bool operator>(const Value &rhs) const; + DUCKDB_API bool operator<=(const Value &rhs) const; + DUCKDB_API bool operator>=(const Value &rhs) const; + + DUCKDB_API bool operator==(const int64_t &rhs) const; + DUCKDB_API bool operator!=(const int64_t &rhs) const; + DUCKDB_API bool operator<(const int64_t &rhs) const; + DUCKDB_API bool operator>(const int64_t &rhs) const; + DUCKDB_API bool operator<=(const int64_t &rhs) const; + DUCKDB_API bool operator>=(const int64_t &rhs) const; + + DUCKDB_API static bool FloatIsFinite(float value); + DUCKDB_API static bool DoubleIsFinite(double value); + template + static bool IsNan(T value) { + throw InternalException("Unimplemented template type for Value::IsNan"); + } + template + static bool IsFinite(T value) { + return true; + } + DUCKDB_API static bool StringIsValid(const char *str, idx_t length); + static bool StringIsValid(const string &str) { + return StringIsValid(str.c_str(), str.size()); + } + + //! Returns true if the values are (approximately) equivalent. Note this is NOT the SQL equivalence. For this + //! function, NULL values are equivalent and floating point values that are close are equivalent. + DUCKDB_API static bool ValuesAreEqual(CastFunctionSet &set, GetCastFunctionInput &get_input, + const Value &result_value, const Value &value); + DUCKDB_API static bool ValuesAreEqual(ClientContext &context, const Value &result_value, const Value &value); + DUCKDB_API static bool DefaultValuesAreEqual(const Value &result_value, const Value &value); + //! Returns true if the values are not distinct from each other, following SQL semantics for NOT DISTINCT FROM. + DUCKDB_API static bool NotDistinctFrom(const Value &lvalue, const Value &rvalue); + + friend std::ostream &operator<<(std::ostream &out, const Value &val) { + out << val.ToString(); + return out; + } + DUCKDB_API void Print() const; + +private: + //! The logical of the value + LogicalType type_; // NOLINT + + //! Whether or not the value is NULL + bool is_null; + + //! The value of the object, if it is of a constant size Type + union Val { + int8_t boolean; + int8_t tinyint; + int16_t smallint; + int32_t integer; + int64_t bigint; + uint8_t utinyint; + uint16_t usmallint; + uint32_t uinteger; + uint64_t ubigint; + hugeint_t hugeint; + float float_; // NOLINT + double double_; // NOLINT + uintptr_t pointer; + uint64_t hash; + date_t date; + dtime_t time; + dtime_tz_t timetz; + timestamp_t timestamp; + interval_t interval; + } value_; // NOLINT + + shared_ptr value_info_; // NOLINT + +private: + template + T GetValueInternal() const; +}; + +//===--------------------------------------------------------------------===// +// Type-specific getters +//===--------------------------------------------------------------------===// +// Note that these are equivalent to calling GetValueUnsafe, meaning no cast will be performed +// instead, an assertion will be triggered if the value is not of the correct type +struct BooleanValue { + DUCKDB_API static bool Get(const Value &value); +}; + +struct TinyIntValue { + DUCKDB_API static int8_t Get(const Value &value); +}; + +struct SmallIntValue { + DUCKDB_API static int16_t Get(const Value &value); +}; + +struct IntegerValue { + DUCKDB_API static int32_t Get(const Value &value); +}; + +struct BigIntValue { + DUCKDB_API static int64_t Get(const Value &value); +}; + +struct HugeIntValue { + DUCKDB_API static hugeint_t Get(const Value &value); +}; + +struct UTinyIntValue { + DUCKDB_API static uint8_t Get(const Value &value); +}; + +struct USmallIntValue { + DUCKDB_API static uint16_t Get(const Value &value); +}; + +struct UIntegerValue { + DUCKDB_API static uint32_t Get(const Value &value); +}; + +struct UBigIntValue { + DUCKDB_API static uint64_t Get(const Value &value); +}; + +struct FloatValue { + DUCKDB_API static float Get(const Value &value); +}; + +struct DoubleValue { + DUCKDB_API static double Get(const Value &value); +}; + +struct StringValue { + DUCKDB_API static const string &Get(const Value &value); +}; + +struct DateValue { + DUCKDB_API static date_t Get(const Value &value); +}; + +struct TimeValue { + DUCKDB_API static dtime_t Get(const Value &value); +}; + +struct TimestampValue { + DUCKDB_API static timestamp_t Get(const Value &value); +}; + +struct IntervalValue { + DUCKDB_API static interval_t Get(const Value &value); +}; + +struct StructValue { + DUCKDB_API static const vector &GetChildren(const Value &value); +}; + +struct ListValue { + DUCKDB_API static const vector &GetChildren(const Value &value); +}; + +struct UnionValue { + DUCKDB_API static const Value &GetValue(const Value &value); + DUCKDB_API static uint8_t GetTag(const Value &value); + DUCKDB_API static const LogicalType &GetType(const Value &value); +}; + +//! Return the internal integral value for any type that is stored as an integral value internally +//! This can be used on values of type integer, uinteger, but also date, timestamp, decimal, etc +struct IntegralValue { + static hugeint_t Get(const Value &value); +}; + +template <> +Value DUCKDB_API Value::CreateValue(bool value); +template <> +Value DUCKDB_API Value::CreateValue(uint8_t value); +template <> +Value DUCKDB_API Value::CreateValue(uint16_t value); +template <> +Value DUCKDB_API Value::CreateValue(uint32_t value); +template <> +Value DUCKDB_API Value::CreateValue(uint64_t value); +template <> +Value DUCKDB_API Value::CreateValue(int8_t value); +template <> +Value DUCKDB_API Value::CreateValue(int16_t value); +template <> +Value DUCKDB_API Value::CreateValue(int32_t value); +template <> +Value DUCKDB_API Value::CreateValue(int64_t value); +template <> +Value DUCKDB_API Value::CreateValue(hugeint_t value); +template <> +Value DUCKDB_API Value::CreateValue(date_t value); +template <> +Value DUCKDB_API Value::CreateValue(dtime_t value); +template <> +Value DUCKDB_API Value::CreateValue(dtime_tz_t value); +template <> +Value DUCKDB_API Value::CreateValue(timestamp_t value); +template <> +Value DUCKDB_API Value::CreateValue(timestamp_sec_t value); +template <> +Value DUCKDB_API Value::CreateValue(timestamp_ms_t value); +template <> +Value DUCKDB_API Value::CreateValue(timestamp_ns_t value); +template <> +Value DUCKDB_API Value::CreateValue(timestamp_tz_t value); +template <> +Value DUCKDB_API Value::CreateValue(const char *value); +template <> +Value DUCKDB_API Value::CreateValue(string value); +template <> +Value DUCKDB_API Value::CreateValue(string_t value); +template <> +Value DUCKDB_API Value::CreateValue(float value); +template <> +Value DUCKDB_API Value::CreateValue(double value); +template <> +Value DUCKDB_API Value::CreateValue(interval_t value); +template <> +Value DUCKDB_API Value::CreateValue(Value value); + +template <> +DUCKDB_API bool Value::GetValue() const; +template <> +DUCKDB_API int8_t Value::GetValue() const; +template <> +DUCKDB_API int16_t Value::GetValue() const; +template <> +DUCKDB_API int32_t Value::GetValue() const; +template <> +DUCKDB_API int64_t Value::GetValue() const; +template <> +DUCKDB_API uint8_t Value::GetValue() const; +template <> +DUCKDB_API uint16_t Value::GetValue() const; +template <> +DUCKDB_API uint32_t Value::GetValue() const; +template <> +DUCKDB_API uint64_t Value::GetValue() const; +template <> +DUCKDB_API hugeint_t Value::GetValue() const; +template <> +DUCKDB_API string Value::GetValue() const; +template <> +DUCKDB_API float Value::GetValue() const; +template <> +DUCKDB_API double Value::GetValue() const; +template <> +DUCKDB_API date_t Value::GetValue() const; +template <> +DUCKDB_API dtime_t Value::GetValue() const; +template <> +DUCKDB_API timestamp_t Value::GetValue() const; +template <> +DUCKDB_API interval_t Value::GetValue() const; +template <> +DUCKDB_API Value Value::GetValue() const; + +template <> +DUCKDB_API bool Value::GetValueUnsafe() const; +template <> +DUCKDB_API int8_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API int16_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API int32_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API int64_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API hugeint_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API uint8_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API uint16_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API uint32_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API uint64_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API string Value::GetValueUnsafe() const; +template <> +DUCKDB_API string_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API float Value::GetValueUnsafe() const; +template <> +DUCKDB_API double Value::GetValueUnsafe() const; +template <> +DUCKDB_API date_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API dtime_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API timestamp_t Value::GetValueUnsafe() const; +template <> +DUCKDB_API interval_t Value::GetValueUnsafe() const; + +template <> +DUCKDB_API bool Value::IsNan(float input); +template <> +DUCKDB_API bool Value::IsNan(double input); + +template <> +DUCKDB_API bool Value::IsFinite(float input); +template <> +DUCKDB_API bool Value::IsFinite(double input); +template <> +DUCKDB_API bool Value::IsFinite(date_t input); +template <> +DUCKDB_API bool Value::IsFinite(timestamp_t input); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/value_map.hpp b/src/duckdb/src/include/duckdb/common/types/value_map.hpp new file mode 100644 index 00000000..dfe4d77d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/value_map.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/value_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +struct ValueHashFunction { + uint64_t operator()(const Value &value) const { + return (uint64_t)value.Hash(); + } +}; + +struct ValueEquality { + bool operator()(const Value &a, const Value &b) const { + return Value::NotDistinctFrom(a, b); + } +}; + +template +using value_map_t = unordered_map; + +using value_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/vector.hpp b/src/duckdb/src/include/duckdb/common/types/vector.hpp new file mode 100644 index 00000000..885aabee --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/vector.hpp @@ -0,0 +1,505 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/vector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/bitset.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/vector_type.hpp" +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/types/vector_buffer.hpp" +#include "duckdb/common/vector_size.hpp" + +namespace duckdb { + +struct UnifiedVectorFormat { + const SelectionVector *sel; + data_ptr_t data; + ValidityMask validity; + SelectionVector owned_sel; + + template + static inline const T *GetData(const UnifiedVectorFormat &format) { + return reinterpret_cast(format.data); + } + template + static inline T *GetDataNoConst(UnifiedVectorFormat &format) { + return reinterpret_cast(format.data); + } +}; + +struct RecursiveUnifiedVectorFormat { + UnifiedVectorFormat unified; + vector children; +}; + +class VectorCache; +class VectorStructBuffer; +class VectorListBuffer; + +struct SelCache; + +struct ConsecutiveChildListInfo { + ConsecutiveChildListInfo() : is_constant(true), needs_slicing(false), child_list_info(list_entry_t(0, 0)) { + } + bool is_constant; + bool needs_slicing; + list_entry_t child_list_info; +}; + +//! Vector of values of a specified PhysicalType. +class Vector { + friend struct ConstantVector; + friend struct DictionaryVector; + friend struct FlatVector; + friend struct ListVector; + friend struct StringVector; + friend struct FSSTVector; + friend struct StructVector; + friend struct UnionVector; + friend struct SequenceVector; + + friend class DataChunk; + friend class VectorCacheBuffer; + +public: + //! Create a vector that references the other vector + DUCKDB_API Vector(Vector &other); + //! Create a vector that slices another vector + DUCKDB_API explicit Vector(Vector &other, const SelectionVector &sel, idx_t count); + //! Create a vector that slices another vector between a pair of offsets + DUCKDB_API explicit Vector(Vector &other, idx_t offset, idx_t end); + //! Create a vector of size one holding the passed on value + DUCKDB_API explicit Vector(const Value &value); + //! Create a vector of size tuple_count (non-standard) + DUCKDB_API explicit Vector(LogicalType type, idx_t capacity = STANDARD_VECTOR_SIZE); + //! Create an empty standard vector with a type, equivalent to calling Vector(type, true, false) + DUCKDB_API explicit Vector(const VectorCache &cache); + //! Create a non-owning vector that references the specified data + DUCKDB_API Vector(LogicalType type, data_ptr_t dataptr); + //! Create an owning vector that holds at most STANDARD_VECTOR_SIZE entries. + /*! + Create a new vector + If create_data is true, the vector will be an owning empty vector. + If zero_data is true, the allocated data will be zero-initialized. + */ + DUCKDB_API Vector(LogicalType type, bool create_data, bool zero_data, idx_t capacity = STANDARD_VECTOR_SIZE); + // implicit copying of Vectors is not allowed + Vector(const Vector &) = delete; + // but moving of vectors is allowed + DUCKDB_API Vector(Vector &&other) noexcept; + +public: + //! Create a vector that references the specified value. + DUCKDB_API void Reference(const Value &value); + //! Causes this vector to reference the data held by the other vector. + //! The type of the "other" vector should match the type of this vector + DUCKDB_API void Reference(const Vector &other); + //! Reinterpret the data of the other vector as the type of this vector + //! Note that this takes the data of the other vector as-is and places it in this vector + //! Without changing the type of this vector + DUCKDB_API void Reinterpret(const Vector &other); + + //! Causes this vector to reference the data held by the other vector, changes the type if required. + DUCKDB_API void ReferenceAndSetType(const Vector &other); + + //! Resets a vector from a vector cache. + //! This turns the vector back into an empty FlatVector with STANDARD_VECTOR_SIZE entries. + //! The VectorCache is used so this can be done without requiring any allocations. + DUCKDB_API void ResetFromCache(const VectorCache &cache); + + //! Creates a reference to a slice of the other vector + DUCKDB_API void Slice(Vector &other, idx_t offset, idx_t end); + //! Creates a reference to a slice of the other vector + DUCKDB_API void Slice(Vector &other, const SelectionVector &sel, idx_t count); + //! Turns the vector into a dictionary vector with the specified dictionary + DUCKDB_API void Slice(const SelectionVector &sel, idx_t count); + //! Slice the vector, keeping the result around in a cache or potentially using the cache instead of slicing + DUCKDB_API void Slice(const SelectionVector &sel, idx_t count, SelCache &cache); + + //! Creates the data of this vector with the specified type. Any data that + //! is currently in the vector is destroyed. + DUCKDB_API void Initialize(bool zero_data = false, idx_t capacity = STANDARD_VECTOR_SIZE); + + //! Converts this Vector to a printable string representation + DUCKDB_API string ToString(idx_t count) const; + DUCKDB_API void Print(idx_t count) const; + + DUCKDB_API string ToString() const; + DUCKDB_API void Print() const; + + //! Flatten the vector, removing any compression and turning it into a FLAT_VECTOR + DUCKDB_API void Flatten(idx_t count); + DUCKDB_API void Flatten(const SelectionVector &sel, idx_t count); + //! Creates a UnifiedVectorFormat of a vector + //! The UnifiedVectorFormat allows efficient reading of vectors regardless of their vector type + //! It contains (1) a data pointer, (2) a validity mask, and (3) a selection vector + //! Access to the individual vector elements can be performed through data_pointer[sel_idx[i]]/validity[sel_idx[i]] + //! The most common vector types (flat, constant & dictionary) can be converted to the canonical format "for free" + //! ToUnifiedFormat was originally called Orrify, as a tribute to Orri Erling who came up with it + DUCKDB_API void ToUnifiedFormat(idx_t count, UnifiedVectorFormat &data); + //! Recursively calls UnifiedVectorFormat on a vector and its child vectors (for nested types) + static void RecursiveToUnifiedFormat(Vector &input, idx_t count, RecursiveUnifiedVectorFormat &data); + + //! Turn the vector into a sequence vector + DUCKDB_API void Sequence(int64_t start, int64_t increment, idx_t count); + + //! Verify that the Vector is in a consistent, not corrupt state. DEBUG + //! FUNCTION ONLY! + DUCKDB_API void Verify(idx_t count); + //! Asserts that the CheckMapValidity returns MapInvalidReason::VALID + DUCKDB_API static void VerifyMap(Vector &map, const SelectionVector &sel, idx_t count); + DUCKDB_API static void VerifyUnion(Vector &map, const SelectionVector &sel, idx_t count); + DUCKDB_API static void Verify(Vector &vector, const SelectionVector &sel, idx_t count); + DUCKDB_API void UTFVerify(idx_t count); + DUCKDB_API void UTFVerify(const SelectionVector &sel, idx_t count); + + //! Returns the [index] element of the Vector as a Value. + DUCKDB_API Value GetValue(idx_t index) const; + //! Sets the [index] element of the Vector to the specified Value. + DUCKDB_API void SetValue(idx_t index, const Value &val); + + inline void SetAuxiliary(buffer_ptr new_buffer) { + auxiliary = std::move(new_buffer); + }; + + //! This functions resizes the vector + DUCKDB_API void Resize(idx_t cur_size, idx_t new_size); + + DUCKDB_API void Serialize(Serializer &serializer, idx_t count); + DUCKDB_API void Deserialize(Deserializer &deserializer, idx_t count); + + // Getters + inline VectorType GetVectorType() const { + return vector_type; + } + inline const LogicalType &GetType() const { + return type; + } + inline data_ptr_t GetData() { + return data; + } + + inline buffer_ptr GetAuxiliary() { + return auxiliary; + } + + inline buffer_ptr GetBuffer() { + return buffer; + } + + // Setters + DUCKDB_API void SetVectorType(VectorType vector_type); + +private: + //! Returns the [index] element of the Vector as a Value. + static Value GetValue(const Vector &v, idx_t index); + //! Returns the [index] element of the Vector as a Value. + static Value GetValueInternal(const Vector &v, idx_t index); + +protected: + //! The vector type specifies how the data of the vector is physically stored (i.e. if it is a single repeated + //! constant, if it is compressed) + VectorType vector_type; + //! The type of the elements stored in the vector (e.g. integer, float) + LogicalType type; + //! A pointer to the data. + data_ptr_t data; + //! The validity mask of the vector + ValidityMask validity; + //! The main buffer holding the data of the vector + buffer_ptr buffer; + //! The buffer holding auxiliary data of the vector + //! e.g. a string vector uses this to store strings + buffer_ptr auxiliary; +}; + +//! The DictionaryBuffer holds a selection vector +class VectorChildBuffer : public VectorBuffer { +public: + explicit VectorChildBuffer(Vector vector) + : VectorBuffer(VectorBufferType::VECTOR_CHILD_BUFFER), data(std::move(vector)) { + } + +public: + Vector data; +}; + +struct ConstantVector { + static inline const_data_ptr_t GetData(const Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR || + vector.GetVectorType() == VectorType::FLAT_VECTOR); + return vector.data; + } + static inline data_ptr_t GetData(Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR || + vector.GetVectorType() == VectorType::FLAT_VECTOR); + return vector.data; + } + template + static inline const T *GetData(const Vector &vector) { + return (const T *)ConstantVector::GetData(vector); + } + template + static inline T *GetData(Vector &vector) { + return (T *)ConstantVector::GetData(vector); + } + static inline bool IsNull(const Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + return !vector.validity.RowIsValid(0); + } + DUCKDB_API static void SetNull(Vector &vector, bool is_null); + static inline ValidityMask &Validity(Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + return vector.validity; + } + DUCKDB_API static const SelectionVector *ZeroSelectionVector(idx_t count, SelectionVector &owned_sel); + DUCKDB_API static const SelectionVector *ZeroSelectionVector(); + //! Turns "vector" into a constant vector by referencing a value within the source vector + DUCKDB_API static void Reference(Vector &vector, Vector &source, idx_t position, idx_t count); + + static const sel_t ZERO_VECTOR[STANDARD_VECTOR_SIZE]; +}; + +struct DictionaryVector { + static inline const SelectionVector &SelVector(const Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::DICTIONARY_VECTOR); + return ((const DictionaryBuffer &)*vector.buffer).GetSelVector(); + } + static inline SelectionVector &SelVector(Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::DICTIONARY_VECTOR); + return ((DictionaryBuffer &)*vector.buffer).GetSelVector(); + } + static inline const Vector &Child(const Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::DICTIONARY_VECTOR); + return ((const VectorChildBuffer &)*vector.auxiliary).data; + } + static inline Vector &Child(Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::DICTIONARY_VECTOR); + return ((VectorChildBuffer &)*vector.auxiliary).data; + } +}; + +struct FlatVector { + static inline data_ptr_t GetData(Vector &vector) { + return ConstantVector::GetData(vector); + } + template + static inline const T *GetData(const Vector &vector) { + return ConstantVector::GetData(vector); + } + template + static inline T *GetData(Vector &vector) { + return ConstantVector::GetData(vector); + } + static inline void SetData(Vector &vector, data_ptr_t data) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + vector.data = data; + } + template + static inline T GetValue(Vector &vector, idx_t idx) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + return FlatVector::GetData(vector)[idx]; + } + static inline const ValidityMask &Validity(const Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + return vector.validity; + } + static inline ValidityMask &Validity(Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + return vector.validity; + } + static inline void SetValidity(Vector &vector, ValidityMask &new_validity) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + vector.validity.Initialize(new_validity); + } + DUCKDB_API static void SetNull(Vector &vector, idx_t idx, bool is_null); + static inline bool IsNull(const Vector &vector, idx_t idx) { + D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); + return !vector.validity.RowIsValid(idx); + } + DUCKDB_API static const SelectionVector *IncrementalSelectionVector(); +}; + +struct ListVector { + static inline list_entry_t *GetData(Vector &v) { + if (v.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(v); + return GetData(child); + } + return FlatVector::GetData(v); + } + //! Gets a reference to the underlying child-vector of a list + DUCKDB_API static const Vector &GetEntry(const Vector &vector); + //! Gets a reference to the underlying child-vector of a list + DUCKDB_API static Vector &GetEntry(Vector &vector); + //! Gets the total size of the underlying child-vector of a list + DUCKDB_API static idx_t GetListSize(const Vector &vector); + //! Sets the total size of the underlying child-vector of a list + DUCKDB_API static void SetListSize(Vector &vec, idx_t size); + //! Gets the total capacity of the underlying child-vector of a list + DUCKDB_API static idx_t GetListCapacity(const Vector &vector); + //! Sets the total capacity of the underlying child-vector of a list + DUCKDB_API static void Reserve(Vector &vec, idx_t required_capacity); + DUCKDB_API static void Append(Vector &target, const Vector &source, idx_t source_size, idx_t source_offset = 0); + DUCKDB_API static void Append(Vector &target, const Vector &source, const SelectionVector &sel, idx_t source_size, + idx_t source_offset = 0); + DUCKDB_API static void PushBack(Vector &target, const Value &insert); + //! Returns the child_vector of list starting at offset until offset + count, and its length + DUCKDB_API static idx_t GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count); + //! Returns information to only copy a section of a list child vector + DUCKDB_API static ConsecutiveChildListInfo GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count); + //! Slice and flatten a child vector to only contain a consecutive subsection of the child entries + DUCKDB_API static void GetConsecutiveChildSelVector(Vector &list, SelectionVector &sel, idx_t offset, idx_t count); + //! Share the entry of the other list vector + DUCKDB_API static void ReferenceEntry(Vector &vector, Vector &other); +}; + +struct StringVector { + //! Add a string to the string heap of the vector (auxiliary data) + DUCKDB_API static string_t AddString(Vector &vector, const char *data, idx_t len); + //! Add a string or a blob to the string heap of the vector (auxiliary data) + //! This function is the same as ::AddString, except the added data does not need to be valid UTF8 + DUCKDB_API static string_t AddStringOrBlob(Vector &vector, const char *data, idx_t len); + //! Add a string to the string heap of the vector (auxiliary data) + DUCKDB_API static string_t AddString(Vector &vector, const char *data); + //! Add a string to the string heap of the vector (auxiliary data) + DUCKDB_API static string_t AddString(Vector &vector, string_t data); + //! Add a string to the string heap of the vector (auxiliary data) + DUCKDB_API static string_t AddString(Vector &vector, const string &data); + //! Add a string or a blob to the string heap of the vector (auxiliary data) + //! This function is the same as ::AddString, except the added data does not need to be valid UTF8 + DUCKDB_API static string_t AddStringOrBlob(Vector &vector, string_t data); + //! Allocates an empty string of the specified size, and returns a writable pointer that can be used to store the + //! result of an operation + DUCKDB_API static string_t EmptyString(Vector &vector, idx_t len); + //! Adds a reference to a handle that stores strings of this vector + DUCKDB_API static void AddHandle(Vector &vector, BufferHandle handle); + //! Adds a reference to an unspecified vector buffer that stores strings of this vector + DUCKDB_API static void AddBuffer(Vector &vector, buffer_ptr buffer); + //! Add a reference from this vector to the string heap of the provided vector + DUCKDB_API static void AddHeapReference(Vector &vector, Vector &other); +}; + +struct FSSTVector { + static inline const ValidityMask &Validity(const Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::FSST_VECTOR); + return vector.validity; + } + static inline ValidityMask &Validity(Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::FSST_VECTOR); + return vector.validity; + } + static inline void SetValidity(Vector &vector, ValidityMask &new_validity) { + D_ASSERT(vector.GetVectorType() == VectorType::FSST_VECTOR); + vector.validity.Initialize(new_validity); + } + static inline const_data_ptr_t GetCompressedData(const Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::FSST_VECTOR); + return vector.data; + } + static inline data_ptr_t GetCompressedData(Vector &vector) { + D_ASSERT(vector.GetVectorType() == VectorType::FSST_VECTOR); + return vector.data; + } + template + static inline const T *GetCompressedData(const Vector &vector) { + return (const T *)FSSTVector::GetCompressedData(vector); + } + template + static inline T *GetCompressedData(Vector &vector) { + return (T *)FSSTVector::GetCompressedData(vector); + } + //! Decompresses an FSST_VECTOR into a FLAT_VECTOR. Note: validity is not copied. + static void DecompressVector(const Vector &src, Vector &dst, idx_t src_offset, idx_t dst_offset, idx_t copy_count, + const SelectionVector *sel); + + DUCKDB_API static string_t AddCompressedString(Vector &vector, string_t data); + DUCKDB_API static string_t AddCompressedString(Vector &vector, const char *data, idx_t len); + DUCKDB_API static void RegisterDecoder(Vector &vector, buffer_ptr &duckdb_fsst_decoder); + DUCKDB_API static void *GetDecoder(const Vector &vector); + //! Setting the string count is required to be able to correctly flatten the vector + DUCKDB_API static void SetCount(Vector &vector, idx_t count); + DUCKDB_API static idx_t GetCount(Vector &vector); +}; + +enum class MapInvalidReason : uint8_t { VALID, NULL_KEY_LIST, NULL_KEY, DUPLICATE_KEY }; + +struct MapVector { + DUCKDB_API static const Vector &GetKeys(const Vector &vector); + DUCKDB_API static const Vector &GetValues(const Vector &vector); + DUCKDB_API static Vector &GetKeys(Vector &vector); + DUCKDB_API static Vector &GetValues(Vector &vector); + DUCKDB_API static MapInvalidReason + CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel = *FlatVector::IncrementalSelectionVector()); + DUCKDB_API static void MapConversionVerify(Vector &vector, idx_t count); +}; + +struct StructVector { + DUCKDB_API static const vector> &GetEntries(const Vector &vector); + DUCKDB_API static vector> &GetEntries(Vector &vector); +}; + +enum class UnionInvalidReason : uint8_t { VALID, TAG_OUT_OF_RANGE, NO_MEMBERS, VALIDITY_OVERLAP, TAG_MISMATCH }; + +struct UnionVector { + // Unions are stored as structs, but the first child is always the "tag" + // vector, specifying the currently selected member for that row. + // The remaining children are the members of the union. + // INVARIANTS: + // 1. Only one member vector (the one "selected" by the tag) can be + // non-NULL in each row. + // + // 2. The validity of the tag vector always matches the validity of the + // union vector itself. + // + // 3. For each tag in the tag vector, 0 <= tag < |members| + + //! Get the tag vector of a union vector + DUCKDB_API static const Vector &GetTags(const Vector &v); + DUCKDB_API static Vector &GetTags(Vector &v); + + //! Get the tag at the specific index of the union vector + DUCKDB_API static union_tag_t GetTag(const Vector &vector, idx_t index); + + //! Get the member vector of a union vector by index + DUCKDB_API static const Vector &GetMember(const Vector &vector, idx_t member_index); + DUCKDB_API static Vector &GetMember(Vector &vector, idx_t member_index); + + //! Set every entry in the UnionVector to a specific member. + //! This is useful to set the entire vector to a single member, e.g. when "creating" + //! a union to return in a function, when you only have one alternative to return. + //! if 'keep_tags_for_null' is false, the tags will be set to NULL where the member is NULL. + //! (the validity of the tag vector will match the selected member vector) + //! otherwise, they are all set to the 'tag'. + //! This will also handle invalidation of the non-selected members + DUCKDB_API static void SetToMember(Vector &vector, union_tag_t tag, Vector &member_vector, idx_t count, + bool keep_tags_for_null); + + DUCKDB_API static UnionInvalidReason + CheckUnionValidity(Vector &vector, idx_t count, + const SelectionVector &sel = *FlatVector::IncrementalSelectionVector()); +}; + +struct SequenceVector { + static void GetSequence(const Vector &vector, int64_t &start, int64_t &increment, int64_t &sequence_count) { + D_ASSERT(vector.GetVectorType() == VectorType::SEQUENCE_VECTOR); + auto data = (int64_t *)vector.buffer->GetData(); + start = data[0]; + increment = data[1]; + sequence_count = data[2]; + } + static void GetSequence(const Vector &vector, int64_t &start, int64_t &increment) { + int64_t sequence_count; + GetSequence(vector, start, increment, sequence_count); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/vector_buffer.hpp b/src/duckdb/src/include/duckdb/common/types/vector_buffer.hpp new file mode 100644 index 00000000..94341102 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/vector_buffer.hpp @@ -0,0 +1,284 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/vector_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/types/string_heap.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" + +namespace duckdb { + +class BufferHandle; +class VectorBuffer; +class Vector; + +enum class VectorBufferType : uint8_t { + STANDARD_BUFFER, // standard buffer, holds a single array of data + DICTIONARY_BUFFER, // dictionary buffer, holds a selection vector + VECTOR_CHILD_BUFFER, // vector child buffer: holds another vector + STRING_BUFFER, // string buffer, holds a string heap + FSST_BUFFER, // fsst compressed string buffer, holds a string heap, fsst symbol table and a string count + STRUCT_BUFFER, // struct buffer, holds a ordered mapping from name to child vector + LIST_BUFFER, // list buffer, holds a single flatvector child + MANAGED_BUFFER, // managed buffer, holds a buffer managed by the buffermanager + OPAQUE_BUFFER // opaque buffer, can be created for example by the parquet reader +}; + +enum class VectorAuxiliaryDataType : uint8_t { + ARROW_AUXILIARY // Holds Arrow Chunks that this vector depends on +}; + +struct VectorAuxiliaryData { + explicit VectorAuxiliaryData(VectorAuxiliaryDataType type_p) + : type(type_p) { + + }; + VectorAuxiliaryDataType type; + + virtual ~VectorAuxiliaryData() { + } + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast vector auxiliary data to type - type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast vector auxiliary data to type - type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +//! The VectorBuffer is a class used by the vector to hold its data +class VectorBuffer { +public: + explicit VectorBuffer(VectorBufferType type) : buffer_type(type) { + } + explicit VectorBuffer(idx_t data_size) : buffer_type(VectorBufferType::STANDARD_BUFFER) { + if (data_size > 0) { + data = make_unsafe_uniq_array(data_size); + } + } + explicit VectorBuffer(unsafe_unique_array data_p) + : buffer_type(VectorBufferType::STANDARD_BUFFER), data(std::move(data_p)) { + } + virtual ~VectorBuffer() { + } + VectorBuffer() { + } + +public: + data_ptr_t GetData() { + return data.get(); + } + + void SetData(unsafe_unique_array new_data) { + data = std::move(new_data); + } + + VectorAuxiliaryData *GetAuxiliaryData() { + return aux_data.get(); + } + + void SetAuxiliaryData(unique_ptr aux_data_p) { + aux_data = std::move(aux_data_p); + } + + void MoveAuxiliaryData(VectorBuffer &source_buffer) { + SetAuxiliaryData(std::move(source_buffer.aux_data)); + } + + static buffer_ptr CreateStandardVector(PhysicalType type, idx_t capacity = STANDARD_VECTOR_SIZE); + static buffer_ptr CreateConstantVector(PhysicalType type); + static buffer_ptr CreateConstantVector(const LogicalType &logical_type); + static buffer_ptr CreateStandardVector(const LogicalType &logical_type, + idx_t capacity = STANDARD_VECTOR_SIZE); + + inline VectorBufferType GetBufferType() const { + return buffer_type; + } + + inline VectorAuxiliaryDataType GetAuxiliaryDataType() const { + return aux_data->type; + } + +protected: + VectorBufferType buffer_type; + unique_ptr aux_data; + unsafe_unique_array data; + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +//! The DictionaryBuffer holds a selection vector +class DictionaryBuffer : public VectorBuffer { +public: + explicit DictionaryBuffer(const SelectionVector &sel) + : VectorBuffer(VectorBufferType::DICTIONARY_BUFFER), sel_vector(sel) { + } + explicit DictionaryBuffer(buffer_ptr data) + : VectorBuffer(VectorBufferType::DICTIONARY_BUFFER), sel_vector(std::move(data)) { + } + explicit DictionaryBuffer(idx_t count = STANDARD_VECTOR_SIZE) + : VectorBuffer(VectorBufferType::DICTIONARY_BUFFER), sel_vector(count) { + } + +public: + const SelectionVector &GetSelVector() const { + return sel_vector; + } + SelectionVector &GetSelVector() { + return sel_vector; + } + void SetSelVector(const SelectionVector &vector) { + this->sel_vector.Initialize(vector); + } + +private: + SelectionVector sel_vector; +}; + +class VectorStringBuffer : public VectorBuffer { +public: + VectorStringBuffer(); + explicit VectorStringBuffer(VectorBufferType type); + +public: + string_t AddString(const char *data, idx_t len) { + return heap.AddString(data, len); + } + string_t AddString(string_t data) { + return heap.AddString(data); + } + string_t AddBlob(string_t data) { + return heap.AddBlob(data.GetData(), data.GetSize()); + } + string_t EmptyString(idx_t len) { + return heap.EmptyString(len); + } + + void AddHeapReference(buffer_ptr heap) { + references.push_back(std::move(heap)); + } + +private: + //! The string heap of this buffer + StringHeap heap; + // References to additional vector buffers referenced by this string buffer + vector> references; +}; + +class VectorFSSTStringBuffer : public VectorStringBuffer { +public: + VectorFSSTStringBuffer(); + +public: + void AddDecoder(buffer_ptr &duckdb_fsst_decoder_p) { + duckdb_fsst_decoder = duckdb_fsst_decoder_p; + } + void *GetDecoder() { + return duckdb_fsst_decoder.get(); + } + void SetCount(idx_t count) { + total_string_count = count; + } + idx_t GetCount() { + return total_string_count; + } + +private: + buffer_ptr duckdb_fsst_decoder; + idx_t total_string_count = 0; +}; + +class VectorStructBuffer : public VectorBuffer { +public: + VectorStructBuffer(); + explicit VectorStructBuffer(const LogicalType &struct_type, idx_t capacity = STANDARD_VECTOR_SIZE); + VectorStructBuffer(Vector &other, const SelectionVector &sel, idx_t count); + ~VectorStructBuffer() override; + +public: + const vector> &GetChildren() const { + return children; + } + vector> &GetChildren() { + return children; + } + +private: + //! child vectors used for nested data + vector> children; +}; + +class VectorListBuffer : public VectorBuffer { +public: + explicit VectorListBuffer(unique_ptr vector, idx_t initial_capacity = STANDARD_VECTOR_SIZE); + explicit VectorListBuffer(const LogicalType &list_type, idx_t initial_capacity = STANDARD_VECTOR_SIZE); + ~VectorListBuffer() override; + +public: + Vector &GetChild() { + return *child; + } + void Reserve(idx_t to_reserve); + + void Append(const Vector &to_append, idx_t to_append_size, idx_t source_offset = 0); + void Append(const Vector &to_append, const SelectionVector &sel, idx_t to_append_size, idx_t source_offset = 0); + + void PushBack(const Value &insert); + + idx_t GetSize() { + return size; + } + + idx_t GetCapacity() { + return capacity; + } + + void SetCapacity(idx_t new_capacity); + void SetSize(idx_t new_size); + +private: + //! child vectors used for nested data + unique_ptr child; + idx_t capacity = 0; + idx_t size = 0; +}; + +//! The ManagedVectorBuffer holds a buffer handle +class ManagedVectorBuffer : public VectorBuffer { +public: + explicit ManagedVectorBuffer(BufferHandle handle); + ~ManagedVectorBuffer() override; + +private: + BufferHandle handle; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/vector_cache.hpp b/src/duckdb/src/include/duckdb/common/types/vector_cache.hpp new file mode 100644 index 00000000..333950ce --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/vector_cache.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/vector_cache.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/vector_buffer.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +class Allocator; +class Vector; + +//! The VectorCache holds cached data that allows for re-use of the same memory by vectors +class VectorCache { +public: + //! Instantiate a vector cache with the given type and capacity + DUCKDB_API explicit VectorCache(Allocator &allocator, const LogicalType &type, + idx_t capacity = STANDARD_VECTOR_SIZE); + + buffer_ptr buffer; + +public: + void ResetFromCache(Vector &result) const; + + const LogicalType &GetType() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/unicode_bar.hpp b/src/duckdb/src/include/duckdb/common/unicode_bar.hpp new file mode 100644 index 00000000..781d6553 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/unicode_bar.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/unicode_bar.hpp +// +// +//===----------------------------------------------------------------------===// + +namespace duckdb { +struct UnicodeBar { +private: + static constexpr idx_t PARTIAL_BLOCKS_COUNT = 8; + +public: + static constexpr idx_t PartialBlocksCount() { + return PARTIAL_BLOCKS_COUNT; + } + + static const char *const *PartialBlocks() { + static const char *PARTIAL_BLOCKS[PARTIAL_BLOCKS_COUNT] = {" ", + "\xE2\x96\x8F", + "\xE2\x96\x8E", + "\xE2\x96\x8D", + "\xE2\x96\x8C", + "\xE2\x96\x8B", + "\xE2\x96\x8A", + "\xE2\x96\x89"}; + return PARTIAL_BLOCKS; + } + + static const char *FullBlock() { + return "\xE2\x96\x88"; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/union_by_name.hpp b/src/duckdb/src/include/duckdb/common/union_by_name.hpp new file mode 100644 index 00000000..d92ed71a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/union_by_name.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/union_by_name.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +class UnionByName { +public: + static void CombineUnionTypes(const vector &new_names, const vector &new_types, + vector &union_col_types, vector &union_col_names, + case_insensitive_map_t &union_names_map); + + //! Union all files(readers) by their col names + template + static vector> UnionCols(ClientContext &context, const vector &files, + vector &union_col_types, + vector &union_col_names, OPTION_TYPE &options) { + vector> union_readers; + case_insensitive_map_t union_names_map; + for (idx_t file_idx = 0; file_idx < files.size(); ++file_idx) { + const auto file_name = files[file_idx]; + auto reader = make_uniq(context, file_name, options); + + auto &col_names = reader->GetNames(); + auto &sql_types = reader->GetTypes(); + CombineUnionTypes(col_names, sql_types, union_col_types, union_col_names, union_names_map); + union_readers.push_back(std::move(reader)); + } + return union_readers; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/unique_ptr.hpp b/src/duckdb/src/include/duckdb/common/unique_ptr.hpp new file mode 100644 index 00000000..f81270fb --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/unique_ptr.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/memory_safety.hpp" + +#include +#include + +namespace duckdb { + +template , bool SAFE = true> +class unique_ptr : public std::unique_ptr<_Tp, _Dp> { +public: + using original = std::unique_ptr<_Tp, _Dp>; + using original::original; + +private: + static inline void AssertNotNull(const bool null) { +#if defined(DUCKDB_DEBUG_NO_SAFETY) || defined(DUCKDB_CLANG_TIDY) + return; +#else + if (DUCKDB_UNLIKELY(null)) { + throw duckdb::InternalException("Attempted to dereference unique_ptr that is NULL!"); + } +#endif + } + +public: + typename std::add_lvalue_reference<_Tp>::type operator*() const { + const auto ptr = original::get(); + if (MemorySafety::enabled) { + AssertNotNull(!ptr); + } + return *ptr; + } + + typename original::pointer operator->() const { + const auto ptr = original::get(); + if (MemorySafety::enabled) { + AssertNotNull(!ptr); + } + return ptr; + } + +#ifdef DUCKDB_CLANG_TIDY + // This is necessary to tell clang-tidy that it reinitializes the variable after a move + [[clang::reinitializes]] +#endif + inline void + reset(typename original::pointer ptr = typename original::pointer()) noexcept { + original::reset(ptr); + } +}; + +template +class unique_ptr<_Tp[], _Dp, SAFE> : public std::unique_ptr<_Tp[], std::default_delete<_Tp[]>> { +public: + using original = std::unique_ptr<_Tp[], std::default_delete<_Tp[]>>; + using original::original; + +private: + static inline void AssertNotNull(const bool null) { +#if defined(DUCKDB_DEBUG_NO_SAFETY) || defined(DUCKDB_CLANG_TIDY) + return; +#else + if (DUCKDB_UNLIKELY(null)) { + throw duckdb::InternalException("Attempted to dereference unique_ptr that is NULL!"); + } +#endif + } + +public: + typename std::add_lvalue_reference<_Tp>::type operator[](size_t __i) const { + const auto ptr = original::get(); + if (MemorySafety::enabled) { + AssertNotNull(!ptr); + } + return ptr[__i]; + } +}; + +template +using unique_array = unique_ptr, true>; + +template +using unsafe_unique_array = unique_ptr, false>; + +template +using unsafe_unique_ptr = unique_ptr, false>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/unordered_map.hpp b/src/duckdb/src/include/duckdb/common/unordered_map.hpp new file mode 100644 index 00000000..9e6d2b82 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/unordered_map.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/unordered_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::unordered_map; +} diff --git a/src/duckdb/src/include/duckdb/common/unordered_set.hpp b/src/duckdb/src/include/duckdb/common/unordered_set.hpp new file mode 100644 index 00000000..6d9defc0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/unordered_set.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/unordered_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +using std::unordered_set; +} diff --git a/src/duckdb/src/include/duckdb/common/value_operations/value_operations.hpp b/src/duckdb/src/include/duckdb/common/value_operations/value_operations.hpp new file mode 100644 index 00000000..055f4e4d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/value_operations/value_operations.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/value_operations/value_operations.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +struct ValueOperations { + //===--------------------------------------------------------------------===// + // Comparison Operations + //===--------------------------------------------------------------------===// + // A == B + static bool Equals(const Value &left, const Value &right); + // A != B + static bool NotEquals(const Value &left, const Value &right); + // A > B + static bool GreaterThan(const Value &left, const Value &right); + // A >= B + static bool GreaterThanEquals(const Value &left, const Value &right); + // A < B + static bool LessThan(const Value &left, const Value &right); + // A <= B + static bool LessThanEquals(const Value &left, const Value &right); + //===--------------------------------------------------------------------===// + // Distinction Operations + //===--------------------------------------------------------------------===// + // A == B, NULLs equal + static bool NotDistinctFrom(const Value &left, const Value &right); + // A != B, NULLs equal + static bool DistinctFrom(const Value &left, const Value &right); + // A > B, NULLs last + static bool DistinctGreaterThan(const Value &left, const Value &right); + // A >= B, NULLs last + static bool DistinctGreaterThanEquals(const Value &left, const Value &right); + // A < B, NULLs last + static bool DistinctLessThan(const Value &left, const Value &right); + // A <= B, NULLs last + static bool DistinctLessThanEquals(const Value &left, const Value &right); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector.hpp b/src/duckdb/src/include/duckdb/common/vector.hpp new file mode 100644 index 00000000..66a6bc73 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector.hpp @@ -0,0 +1,108 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/memory_safety.hpp" +#include + +namespace duckdb { + +template +class vector : public std::vector<_Tp, std::allocator<_Tp>> { +public: + using original = std::vector<_Tp, std::allocator<_Tp>>; + using original::original; + using size_type = typename original::size_type; + using const_reference = typename original::const_reference; + using reference = typename original::reference; + +private: + static inline void AssertIndexInBounds(idx_t index, idx_t size) { +#if defined(DUCKDB_DEBUG_NO_SAFETY) || defined(DUCKDB_CLANG_TIDY) + return; +#else + if (DUCKDB_UNLIKELY(index >= size)) { + throw InternalException("Attempted to access index %ld within vector of size %ld", index, size); + } +#endif + } + +public: +#ifdef DUCKDB_CLANG_TIDY + // This is necessary to tell clang-tidy that it reinitializes the variable after a move + [[clang::reinitializes]] +#endif + inline void + clear() noexcept { + original::clear(); + } + + // Because we create the other constructor, the implicitly created constructor + // gets deleted, so we have to be explicit + vector() = default; + vector(original &&other) : original(std::move(other)) { + } + template + vector(vector<_Tp, _SAFE> &&other) : original(std::move(other)) { + } + + template + inline typename original::reference get(typename original::size_type __n) { + if (MemorySafety<_SAFE>::enabled) { + AssertIndexInBounds(__n, original::size()); + } + return original::operator[](__n); + } + + template + inline typename original::const_reference get(typename original::size_type __n) const { + if (MemorySafety<_SAFE>::enabled) { + AssertIndexInBounds(__n, original::size()); + } + return original::operator[](__n); + } + + typename original::reference operator[](typename original::size_type __n) { + return get(__n); + } + typename original::const_reference operator[](typename original::size_type __n) const { + return get(__n); + } + + typename original::reference front() { + return get(0); + } + + typename original::const_reference front() const { + return get(0); + } + + typename original::reference back() { + if (MemorySafety::enabled && original::empty()) { + throw InternalException("'back' called on an empty vector!"); + } + return get(original::size() - 1); + } + + typename original::const_reference back() const { + if (MemorySafety::enabled && original::empty()) { + throw InternalException("'back' called on an empty vector!"); + } + return get(original::size() - 1); + } +}; + +template +using unsafe_vector = vector; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp new file mode 100644 index 00000000..4a80ac81 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp @@ -0,0 +1,405 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/aggregate_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/aggregate_state.hpp" + +namespace duckdb { + +// structs +struct AggregateInputData; +struct FrameBounds { + FrameBounds() : start(0), end(0) {}; + FrameBounds(idx_t start, idx_t end) : start(start), end(end) {}; + idx_t start = 0; + idx_t end = 0; +}; + +class AggregateExecutor { +private: + template + static inline void NullaryFlatLoop(STATE_TYPE **__restrict states, AggregateInputData &aggr_input_data, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + OP::template Operation(*states[i], aggr_input_data, i); + } + } + + template + static inline void NullaryScatterLoop(STATE_TYPE **__restrict states, AggregateInputData &aggr_input_data, + const SelectionVector &ssel, idx_t count) { + + for (idx_t i = 0; i < count; i++) { + auto sidx = ssel.get_index(i); + OP::template Operation(*states[sidx], aggr_input_data, sidx); + } + } + + template + static inline void UnaryFlatLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, + STATE_TYPE **__restrict states, ValidityMask &mask, idx_t count) { + if (OP::IgnoreNull() && !mask.AllValid()) { + AggregateUnaryInput input(aggr_input_data, mask); + auto &base_idx = input.input_idx; + base_idx = 0; + auto entry_count = ValidityMask::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + auto validity_entry = mask.GetValidityEntry(entry_idx); + idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + // all valid: perform operation + for (; base_idx < next; base_idx++) { + OP::template Operation(*states[base_idx], idata[base_idx], input); + } + } else if (ValidityMask::NoneValid(validity_entry)) { + // nothing valid: skip all + base_idx = next; + continue; + } else { + // partially valid: need to check individual elements for validity + idx_t start = base_idx; + for (; base_idx < next; base_idx++) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + OP::template Operation(*states[base_idx], idata[base_idx], + input); + } + } + } + } + } else { + AggregateUnaryInput input(aggr_input_data, mask); + auto &i = input.input_idx; + for (i = 0; i < count; i++) { + OP::template Operation(*states[i], idata[i], input); + } + } + } + + template + static inline void UnaryScatterLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, + STATE_TYPE **__restrict states, const SelectionVector &isel, + const SelectionVector &ssel, ValidityMask &mask, idx_t count) { + if (OP::IgnoreNull() && !mask.AllValid()) { + // potential NULL values and NULL values are ignored + AggregateUnaryInput input(aggr_input_data, mask); + for (idx_t i = 0; i < count; i++) { + input.input_idx = isel.get_index(i); + auto sidx = ssel.get_index(i); + if (mask.RowIsValid(input.input_idx)) { + OP::template Operation(*states[sidx], idata[input.input_idx], input); + } + } + } else { + // quick path: no NULL values or NULL values are not ignored + AggregateUnaryInput input(aggr_input_data, mask); + for (idx_t i = 0; i < count; i++) { + input.input_idx = isel.get_index(i); + auto sidx = ssel.get_index(i); + OP::template Operation(*states[sidx], idata[input.input_idx], input); + } + } + } + + template + static inline void UnaryFlatUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, + STATE_TYPE *__restrict state, idx_t count, ValidityMask &mask) { + AggregateUnaryInput input(aggr_input_data, mask); + auto &base_idx = input.input_idx; + base_idx = 0; + auto entry_count = ValidityMask::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + auto validity_entry = mask.GetValidityEntry(entry_idx); + idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (!OP::IgnoreNull() || ValidityMask::AllValid(validity_entry)) { + // all valid: perform operation + for (; base_idx < next; base_idx++) { + OP::template Operation(*state, idata[base_idx], input); + } + } else if (ValidityMask::NoneValid(validity_entry)) { + // nothing valid: skip all + base_idx = next; + continue; + } else { + // partially valid: need to check individual elements for validity + idx_t start = base_idx; + for (; base_idx < next; base_idx++) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + OP::template Operation(*state, idata[base_idx], input); + } + } + } + } + } + + template + static inline void UnaryUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, + STATE_TYPE *__restrict state, idx_t count, ValidityMask &mask, + const SelectionVector &__restrict sel_vector) { + AggregateUnaryInput input(aggr_input_data, mask); + if (OP::IgnoreNull() && !mask.AllValid()) { + // potential NULL values and NULL values are ignored + for (idx_t i = 0; i < count; i++) { + input.input_idx = sel_vector.get_index(i); + if (mask.RowIsValid(input.input_idx)) { + OP::template Operation(*state, idata[input.input_idx], input); + } + } + } else { + // quick path: no NULL values or NULL values are not ignored + for (idx_t i = 0; i < count; i++) { + input.input_idx = sel_vector.get_index(i); + OP::template Operation(*state, idata[input.input_idx], input); + } + } + } + + template + static inline void BinaryScatterLoop(const A_TYPE *__restrict adata, AggregateInputData &aggr_input_data, + const B_TYPE *__restrict bdata, STATE_TYPE **__restrict states, idx_t count, + const SelectionVector &asel, const SelectionVector &bsel, + const SelectionVector &ssel, ValidityMask &avalidity, + ValidityMask &bvalidity) { + AggregateBinaryInput input(aggr_input_data, avalidity, bvalidity); + if (OP::IgnoreNull() && (!avalidity.AllValid() || !bvalidity.AllValid())) { + // potential NULL values and NULL values are ignored + for (idx_t i = 0; i < count; i++) { + input.lidx = asel.get_index(i); + input.ridx = bsel.get_index(i); + auto sidx = ssel.get_index(i); + if (avalidity.RowIsValid(input.lidx) && bvalidity.RowIsValid(input.ridx)) { + OP::template Operation(*states[sidx], adata[input.lidx], + bdata[input.ridx], input); + } + } + } else { + // quick path: no NULL values or NULL values are not ignored + for (idx_t i = 0; i < count; i++) { + input.lidx = asel.get_index(i); + input.ridx = bsel.get_index(i); + auto sidx = ssel.get_index(i); + OP::template Operation(*states[sidx], adata[input.lidx], + bdata[input.ridx], input); + } + } + } + + template + static inline void BinaryUpdateLoop(const A_TYPE *__restrict adata, AggregateInputData &aggr_input_data, + const B_TYPE *__restrict bdata, STATE_TYPE *__restrict state, idx_t count, + const SelectionVector &asel, const SelectionVector &bsel, + ValidityMask &avalidity, ValidityMask &bvalidity) { + AggregateBinaryInput input(aggr_input_data, avalidity, bvalidity); + if (OP::IgnoreNull() && (!avalidity.AllValid() || !bvalidity.AllValid())) { + // potential NULL values and NULL values are ignored + for (idx_t i = 0; i < count; i++) { + input.lidx = asel.get_index(i); + input.ridx = bsel.get_index(i); + if (avalidity.RowIsValid(input.lidx) && bvalidity.RowIsValid(input.ridx)) { + OP::template Operation(*state, adata[input.lidx], bdata[input.ridx], + input); + } + } + } else { + // quick path: no NULL values or NULL values are not ignored + for (idx_t i = 0; i < count; i++) { + input.lidx = asel.get_index(i); + input.ridx = bsel.get_index(i); + OP::template Operation(*state, adata[input.lidx], bdata[input.ridx], + input); + } + } + } + +public: + template + static void NullaryScatter(Vector &states, AggregateInputData &aggr_input_data, idx_t count) { + if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { + auto sdata = ConstantVector::GetData(states); + OP::template ConstantOperation(**sdata, aggr_input_data, count); + } else if (states.GetVectorType() == VectorType::FLAT_VECTOR) { + auto sdata = FlatVector::GetData(states); + NullaryFlatLoop(sdata, aggr_input_data, count); + } else { + UnifiedVectorFormat sdata; + states.ToUnifiedFormat(count, sdata); + NullaryScatterLoop((STATE_TYPE **)sdata.data, aggr_input_data, *sdata.sel, count); + } + } + + template + static void NullaryUpdate(data_ptr_t state, AggregateInputData &aggr_input_data, idx_t count) { + OP::template ConstantOperation(*reinterpret_cast(state), aggr_input_data, count); + } + + template + static void UnaryScatter(Vector &input, Vector &states, AggregateInputData &aggr_input_data, idx_t count) { + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR && + states.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (OP::IgnoreNull() && ConstantVector::IsNull(input)) { + // constant NULL input in function that ignores NULL values + return; + } + // regular constant: get first state + auto idata = ConstantVector::GetData(input); + auto sdata = ConstantVector::GetData(states); + AggregateUnaryInput input_data(aggr_input_data, ConstantVector::Validity(input)); + OP::template ConstantOperation(**sdata, *idata, input_data, count); + } else if (input.GetVectorType() == VectorType::FLAT_VECTOR && + states.GetVectorType() == VectorType::FLAT_VECTOR) { + auto idata = FlatVector::GetData(input); + auto sdata = FlatVector::GetData(states); + UnaryFlatLoop(idata, aggr_input_data, sdata, FlatVector::Validity(input), + count); + } else { + UnifiedVectorFormat idata, sdata; + input.ToUnifiedFormat(count, idata); + states.ToUnifiedFormat(count, sdata); + UnaryScatterLoop(UnifiedVectorFormat::GetData(idata), + aggr_input_data, (STATE_TYPE **)sdata.data, *idata.sel, + *sdata.sel, idata.validity, count); + } + } + + template + static void UnaryUpdate(Vector &input, AggregateInputData &aggr_input_data, data_ptr_t state, idx_t count) { + switch (input.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + if (OP::IgnoreNull() && ConstantVector::IsNull(input)) { + return; + } + auto idata = ConstantVector::GetData(input); + AggregateUnaryInput input_data(aggr_input_data, ConstantVector::Validity(input)); + OP::template ConstantOperation(*reinterpret_cast(state), *idata, + input_data, count); + break; + } + case VectorType::FLAT_VECTOR: { + auto idata = FlatVector::GetData(input); + UnaryFlatUpdateLoop(idata, aggr_input_data, (STATE_TYPE *)state, count, + FlatVector::Validity(input)); + break; + } + default: { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + UnaryUpdateLoop(UnifiedVectorFormat::GetData(idata), + aggr_input_data, (STATE_TYPE *)state, count, idata.validity, + *idata.sel); + break; + } + } + } + + template + static void BinaryScatter(AggregateInputData &aggr_input_data, Vector &a, Vector &b, Vector &states, idx_t count) { + UnifiedVectorFormat adata, bdata, sdata; + + a.ToUnifiedFormat(count, adata); + b.ToUnifiedFormat(count, bdata); + states.ToUnifiedFormat(count, sdata); + + BinaryScatterLoop( + UnifiedVectorFormat::GetData(adata), aggr_input_data, UnifiedVectorFormat::GetData(bdata), + (STATE_TYPE **)sdata.data, count, *adata.sel, *bdata.sel, *sdata.sel, adata.validity, bdata.validity); + } + + template + static void BinaryUpdate(AggregateInputData &aggr_input_data, Vector &a, Vector &b, data_ptr_t state, idx_t count) { + UnifiedVectorFormat adata, bdata; + + a.ToUnifiedFormat(count, adata); + b.ToUnifiedFormat(count, bdata); + + BinaryUpdateLoop( + UnifiedVectorFormat::GetData(adata), aggr_input_data, UnifiedVectorFormat::GetData(bdata), + (STATE_TYPE *)state, count, *adata.sel, *bdata.sel, adata.validity, bdata.validity); + } + + template + static void Combine(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { + D_ASSERT(source.GetType().id() == LogicalTypeId::POINTER && target.GetType().id() == LogicalTypeId::POINTER); + auto sdata = FlatVector::GetData(source); + auto tdata = FlatVector::GetData(target); + + for (idx_t i = 0; i < count; i++) { + OP::template Combine(*sdata[i], *tdata[i], aggr_input_data); + } + } + + template + static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset) { + if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + auto sdata = ConstantVector::GetData(states); + auto rdata = ConstantVector::GetData(result); + AggregateFinalizeData finalize_data(result, aggr_input_data); + OP::template Finalize(**sdata, *rdata, finalize_data); + } else { + D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR); + result.SetVectorType(VectorType::FLAT_VECTOR); + + auto sdata = FlatVector::GetData(states); + auto rdata = FlatVector::GetData(result); + AggregateFinalizeData finalize_data(result, aggr_input_data); + for (idx_t i = 0; i < count; i++) { + finalize_data.result_idx = i + offset; + OP::template Finalize(*sdata[i], rdata[finalize_data.result_idx], + finalize_data); + } + } + } + + template + static void VoidFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset) { + if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + auto sdata = ConstantVector::GetData(states); + AggregateFinalizeData finalize_data(result, aggr_input_data); + OP::template Finalize(**sdata, finalize_data); + } else { + D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR); + result.SetVectorType(VectorType::FLAT_VECTOR); + + auto sdata = FlatVector::GetData(states); + AggregateFinalizeData finalize_data(result, aggr_input_data); + for (idx_t i = 0; i < count; i++) { + finalize_data.result_idx = i + offset; + OP::template Finalize(*sdata[i], finalize_data); + } + } + } + + template + static void UnaryWindow(Vector &input, const ValidityMask &ifilter, AggregateInputData &aggr_input_data, + data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, Vector &result, + idx_t rid, idx_t bias) { + + auto idata = FlatVector::GetData(input) - bias; + const auto &ivalid = FlatVector::Validity(input); + OP::template Window( + idata, ifilter, ivalid, aggr_input_data, *reinterpret_cast(state), frame, prev, result, rid, bias); + } + + template + static void Destroy(Vector &states, AggregateInputData &aggr_input_data, idx_t count) { + auto sdata = FlatVector::GetData(states); + for (idx_t i = 0; i < count; i++) { + OP::template Destroy(*sdata[i], aggr_input_data); + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/binary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/binary_executor.hpp new file mode 100644 index 00000000..55c10bb2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/binary_executor.hpp @@ -0,0 +1,520 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/binary_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include + +namespace duckdb { + +struct DefaultNullCheckOperator { + template + static inline bool Operation(LEFT_TYPE left, RIGHT_TYPE right) { + return false; + } +}; + +struct BinaryStandardOperatorWrapper { + template + static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { + return OP::template Operation(left, right); + } + + static bool AddsNulls() { + return false; + } +}; + +struct BinarySingleArgumentOperatorWrapper { + template + static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { + return OP::template Operation(left, right); + } + + static bool AddsNulls() { + return false; + } +}; + +struct BinaryLambdaWrapper { + template + static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { + return fun(left, right); + } + + static bool AddsNulls() { + return false; + } +}; + +struct BinaryLambdaWrapperWithNulls { + template + static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { + return fun(left, right, mask, idx); + } + + static bool AddsNulls() { + return true; + } +}; + +struct BinaryExecutor { + template + static void ExecuteFlatLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + RESULT_TYPE *__restrict result_data, idx_t count, ValidityMask &mask, FUNC fun) { + if (!LEFT_CONSTANT) { + ASSERT_RESTRICT(ldata, ldata + count, result_data, result_data + count); + } + if (!RIGHT_CONSTANT) { + ASSERT_RESTRICT(rdata, rdata + count, result_data, result_data + count); + } + + if (!mask.AllValid()) { + idx_t base_idx = 0; + auto entry_count = ValidityMask::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + auto validity_entry = mask.GetValidityEntry(entry_idx); + idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + // all valid: perform operation + for (; base_idx < next; base_idx++) { + auto lentry = ldata[LEFT_CONSTANT ? 0 : base_idx]; + auto rentry = rdata[RIGHT_CONSTANT ? 0 : base_idx]; + result_data[base_idx] = + OPWRAPPER::template Operation( + fun, lentry, rentry, mask, base_idx); + } + } else if (ValidityMask::NoneValid(validity_entry)) { + // nothing valid: skip all + base_idx = next; + continue; + } else { + // partially valid: need to check individual elements for validity + idx_t start = base_idx; + for (; base_idx < next; base_idx++) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + auto lentry = ldata[LEFT_CONSTANT ? 0 : base_idx]; + auto rentry = rdata[RIGHT_CONSTANT ? 0 : base_idx]; + result_data[base_idx] = + OPWRAPPER::template Operation( + fun, lentry, rentry, mask, base_idx); + } + } + } + } + } else { + for (idx_t i = 0; i < count; i++) { + auto lentry = ldata[LEFT_CONSTANT ? 0 : i]; + auto rentry = rdata[RIGHT_CONSTANT ? 0 : i]; + result_data[i] = OPWRAPPER::template Operation( + fun, lentry, rentry, mask, i); + } + } + } + + template + static void ExecuteConstant(Vector &left, Vector &right, Vector &result, FUNC fun) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + auto ldata = ConstantVector::GetData(left); + auto rdata = ConstantVector::GetData(right); + auto result_data = ConstantVector::GetData(result); + + if (ConstantVector::IsNull(left) || ConstantVector::IsNull(right)) { + ConstantVector::SetNull(result, true); + return; + } + *result_data = OPWRAPPER::template Operation( + fun, *ldata, *rdata, ConstantVector::Validity(result), 0); + } + + template + static void ExecuteFlat(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + auto ldata = FlatVector::GetData(left); + auto rdata = FlatVector::GetData(right); + + if ((LEFT_CONSTANT && ConstantVector::IsNull(left)) || (RIGHT_CONSTANT && ConstantVector::IsNull(right))) { + // either left or right is constant NULL: result is constant NULL + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + if (LEFT_CONSTANT) { + if (OPWRAPPER::AddsNulls()) { + result_validity.Copy(FlatVector::Validity(right), count); + } else { + FlatVector::SetValidity(result, FlatVector::Validity(right)); + } + } else if (RIGHT_CONSTANT) { + if (OPWRAPPER::AddsNulls()) { + result_validity.Copy(FlatVector::Validity(left), count); + } else { + FlatVector::SetValidity(result, FlatVector::Validity(left)); + } + } else { + if (OPWRAPPER::AddsNulls()) { + result_validity.Copy(FlatVector::Validity(left), count); + if (result_validity.AllValid()) { + result_validity.Copy(FlatVector::Validity(right), count); + } else { + result_validity.Combine(FlatVector::Validity(right), count); + } + } else { + FlatVector::SetValidity(result, FlatVector::Validity(left)); + result_validity.Combine(FlatVector::Validity(right), count); + } + } + ExecuteFlatLoop( + ldata, rdata, result_data, count, result_validity, fun); + } + + template + static void ExecuteGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + RESULT_TYPE *__restrict result_data, const SelectionVector *__restrict lsel, + const SelectionVector *__restrict rsel, idx_t count, ValidityMask &lvalidity, + ValidityMask &rvalidity, ValidityMask &result_validity, FUNC fun) { + if (!lvalidity.AllValid() || !rvalidity.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto lindex = lsel->get_index(i); + auto rindex = rsel->get_index(i); + if (lvalidity.RowIsValid(lindex) && rvalidity.RowIsValid(rindex)) { + auto lentry = ldata[lindex]; + auto rentry = rdata[rindex]; + result_data[i] = OPWRAPPER::template Operation( + fun, lentry, rentry, result_validity, i); + } else { + result_validity.SetInvalid(i); + } + } + } else { + for (idx_t i = 0; i < count; i++) { + auto lentry = ldata[lsel->get_index(i)]; + auto rentry = rdata[rsel->get_index(i)]; + result_data[i] = OPWRAPPER::template Operation( + fun, lentry, rentry, result_validity, i); + } + } + } + + template + static void ExecuteGeneric(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + UnifiedVectorFormat ldata, rdata; + + left.ToUnifiedFormat(count, ldata); + right.ToUnifiedFormat(count, rdata); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + ExecuteGenericLoop( + UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), + result_data, ldata.sel, rdata.sel, count, ldata.validity, rdata.validity, FlatVector::Validity(result), + fun); + } + + template + static void ExecuteSwitch(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + auto left_vector_type = left.GetVectorType(); + auto right_vector_type = right.GetVectorType(); + if (left_vector_type == VectorType::CONSTANT_VECTOR && right_vector_type == VectorType::CONSTANT_VECTOR) { + ExecuteConstant(left, right, result, fun); + } else if (left_vector_type == VectorType::FLAT_VECTOR && right_vector_type == VectorType::CONSTANT_VECTOR) { + ExecuteFlat(left, right, result, + count, fun); + } else if (left_vector_type == VectorType::CONSTANT_VECTOR && right_vector_type == VectorType::FLAT_VECTOR) { + ExecuteFlat(left, right, result, + count, fun); + } else if (left_vector_type == VectorType::FLAT_VECTOR && right_vector_type == VectorType::FLAT_VECTOR) { + ExecuteFlat(left, right, result, + count, fun); + } else { + ExecuteGeneric(left, right, result, count, fun); + } + } + +public: + template > + static void Execute(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + ExecuteSwitch(left, right, result, count, + fun); + } + + template + static void Execute(Vector &left, Vector &right, Vector &result, idx_t count) { + ExecuteSwitch(left, right, result, count, false); + } + + template + static void ExecuteStandard(Vector &left, Vector &right, Vector &result, idx_t count) { + ExecuteSwitch(left, right, result, + count, false); + } + + template > + static void ExecuteWithNulls(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + ExecuteSwitch(left, right, result, + count, fun); + } + +public: + template + static idx_t SelectConstant(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + auto ldata = ConstantVector::GetData(left); + auto rdata = ConstantVector::GetData(right); + + // both sides are constant, return either 0 or the count + // in this case we do not fill in the result selection vector at all + if (ConstantVector::IsNull(left) || ConstantVector::IsNull(right) || !OP::Operation(*ldata, *rdata)) { + if (false_sel) { + for (idx_t i = 0; i < count; i++) { + false_sel->set_index(i, sel->get_index(i)); + } + } + return 0; + } else { + if (true_sel) { + for (idx_t i = 0; i < count; i++) { + true_sel->set_index(i, sel->get_index(i)); + } + } + return count; + } + } + + template + static inline idx_t SelectFlatLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *sel, idx_t count, ValidityMask &validity_mask, + SelectionVector *true_sel, SelectionVector *false_sel) { + idx_t true_count = 0, false_count = 0; + idx_t base_idx = 0; + auto entry_count = ValidityMask::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + auto validity_entry = validity_mask.GetValidityEntry(entry_idx); + idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + // all valid: perform operation + for (; base_idx < next; base_idx++) { + idx_t result_idx = sel->get_index(base_idx); + idx_t lidx = LEFT_CONSTANT ? 0 : base_idx; + idx_t ridx = RIGHT_CONSTANT ? 0 : base_idx; + bool comparison_result = OP::Operation(ldata[lidx], rdata[ridx]); + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count, result_idx); + true_count += comparison_result; + } + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count, result_idx); + false_count += !comparison_result; + } + } + } else if (ValidityMask::NoneValid(validity_entry)) { + // nothing valid: skip all + if (HAS_FALSE_SEL) { + for (; base_idx < next; base_idx++) { + idx_t result_idx = sel->get_index(base_idx); + false_sel->set_index(false_count, result_idx); + false_count++; + } + } + base_idx = next; + continue; + } else { + // partially valid: need to check individual elements for validity + idx_t start = base_idx; + for (; base_idx < next; base_idx++) { + idx_t result_idx = sel->get_index(base_idx); + idx_t lidx = LEFT_CONSTANT ? 0 : base_idx; + idx_t ridx = RIGHT_CONSTANT ? 0 : base_idx; + bool comparison_result = ValidityMask::RowIsValid(validity_entry, base_idx - start) && + OP::Operation(ldata[lidx], rdata[ridx]); + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count, result_idx); + true_count += comparison_result; + } + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count, result_idx); + false_count += !comparison_result; + } + } + } + } + if (HAS_TRUE_SEL) { + return true_count; + } else { + return count - false_count; + } + } + + template + static inline idx_t SelectFlatLoopSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *sel, idx_t count, ValidityMask &mask, + SelectionVector *true_sel, SelectionVector *false_sel) { + if (true_sel && false_sel) { + return SelectFlatLoop( + ldata, rdata, sel, count, mask, true_sel, false_sel); + } else if (true_sel) { + return SelectFlatLoop( + ldata, rdata, sel, count, mask, true_sel, false_sel); + } else { + D_ASSERT(false_sel); + return SelectFlatLoop( + ldata, rdata, sel, count, mask, true_sel, false_sel); + } + } + + template + static idx_t SelectFlat(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + auto ldata = FlatVector::GetData(left); + auto rdata = FlatVector::GetData(right); + + if (LEFT_CONSTANT && ConstantVector::IsNull(left)) { + if (false_sel) { + for (idx_t i = 0; i < count; i++) { + false_sel->set_index(i, sel->get_index(i)); + } + } + return 0; + } + if (RIGHT_CONSTANT && ConstantVector::IsNull(right)) { + if (false_sel) { + for (idx_t i = 0; i < count; i++) { + false_sel->set_index(i, sel->get_index(i)); + } + } + return 0; + } + + if (LEFT_CONSTANT) { + return SelectFlatLoopSwitch( + ldata, rdata, sel, count, FlatVector::Validity(right), true_sel, false_sel); + } else if (RIGHT_CONSTANT) { + return SelectFlatLoopSwitch( + ldata, rdata, sel, count, FlatVector::Validity(left), true_sel, false_sel); + } else { + ValidityMask combined_mask = FlatVector::Validity(left); + combined_mask.Combine(FlatVector::Validity(right), count); + return SelectFlatLoopSwitch( + ldata, rdata, sel, count, combined_mask, true_sel, false_sel); + } + } + + template + static inline idx_t + SelectGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, + const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lvalidity, + ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { + idx_t true_count = 0, false_count = 0; + for (idx_t i = 0; i < count; i++) { + auto result_idx = result_sel->get_index(i); + auto lindex = lsel->get_index(i); + auto rindex = rsel->get_index(i); + if ((NO_NULL || (lvalidity.RowIsValid(lindex) && rvalidity.RowIsValid(rindex))) && + OP::Operation(ldata[lindex], rdata[rindex])) { + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count++, result_idx); + } + } else { + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count++, result_idx); + } + } + } + if (HAS_TRUE_SEL) { + return true_count; + } else { + return count - false_count; + } + } + template + static inline idx_t + SelectGenericLoopSelSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, + const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lvalidity, + ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { + if (true_sel && false_sel) { + return SelectGenericLoop( + ldata, rdata, lsel, rsel, result_sel, count, lvalidity, rvalidity, true_sel, false_sel); + } else if (true_sel) { + return SelectGenericLoop( + ldata, rdata, lsel, rsel, result_sel, count, lvalidity, rvalidity, true_sel, false_sel); + } else { + D_ASSERT(false_sel); + return SelectGenericLoop( + ldata, rdata, lsel, rsel, result_sel, count, lvalidity, rvalidity, true_sel, false_sel); + } + } + + template + static inline idx_t + SelectGenericLoopSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, + const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, + const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lvalidity, + ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { + if (!lvalidity.AllValid() || !rvalidity.AllValid()) { + return SelectGenericLoopSelSwitch( + ldata, rdata, lsel, rsel, result_sel, count, lvalidity, rvalidity, true_sel, false_sel); + } else { + return SelectGenericLoopSelSwitch( + ldata, rdata, lsel, rsel, result_sel, count, lvalidity, rvalidity, true_sel, false_sel); + } + } + + template + static idx_t SelectGeneric(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + UnifiedVectorFormat ldata, rdata; + + left.ToUnifiedFormat(count, ldata); + right.ToUnifiedFormat(count, rdata); + + return SelectGenericLoopSwitch( + UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), ldata.sel, + rdata.sel, sel, count, ldata.validity, rdata.validity, true_sel, false_sel); + } + + template + static idx_t Select(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + if (!sel) { + sel = FlatVector::IncrementalSelectionVector(); + } + if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && + right.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return SelectConstant(left, right, sel, count, true_sel, false_sel); + } else if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && + right.GetVectorType() == VectorType::FLAT_VECTOR) { + return SelectFlat(left, right, sel, count, true_sel, false_sel); + } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && + right.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return SelectFlat(left, right, sel, count, true_sel, false_sel); + } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && + right.GetVectorType() == VectorType::FLAT_VECTOR) { + return SelectFlat(left, right, sel, count, true_sel, false_sel); + } else { + return SelectGeneric(left, right, sel, count, true_sel, false_sel); + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/general_cast.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/general_cast.hpp new file mode 100644 index 00000000..9ddfe4fc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/general_cast.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/general_cast.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +struct HandleVectorCastError { + template + static RESULT_TYPE Operation(string error_message, ValidityMask &mask, idx_t idx, string *error_message_ptr, + bool &all_converted) { + HandleCastError::AssignError(error_message, error_message_ptr); + all_converted = false; + mask.SetInvalid(idx); + return NullValue(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/generic_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/generic_executor.hpp new file mode 100644 index 00000000..55f84e04 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/generic_executor.hpp @@ -0,0 +1,377 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/generic_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include + +namespace duckdb { + +struct PrimitiveTypeState { + UnifiedVectorFormat main_data; + + void PrepareVector(Vector &input, idx_t count) { + input.ToUnifiedFormat(count, main_data); + } +}; + +template +struct PrimitiveType { + PrimitiveType() { + } + PrimitiveType(INPUT_TYPE val) : val(val) { + } // NOLINT: allow implicit cast + + INPUT_TYPE val; + + using STRUCT_STATE = PrimitiveTypeState; + + static bool ConstructType(STRUCT_STATE &state, idx_t i, PrimitiveType &result) { + auto &vdata = state.main_data; + auto idx = vdata.sel->get_index(i); + auto ptr = UnifiedVectorFormat::GetData(vdata); + result.val = ptr[idx]; + return true; + } + + static void AssignResult(Vector &result, idx_t i, PrimitiveType value) { + auto result_data = FlatVector::GetData(result); + result_data[i] = value.val; + } +}; + +template +struct StructTypeState { + UnifiedVectorFormat main_data; + UnifiedVectorFormat child_data[CHILD_COUNT]; + + void PrepareVector(Vector &input, idx_t count) { + auto &entries = StructVector::GetEntries(input); + + input.ToUnifiedFormat(count, main_data); + + for (idx_t i = 0; i < CHILD_COUNT; i++) { + entries[i]->ToUnifiedFormat(count, child_data[i]); + } + } +}; + +template +struct StructTypeUnary { + A_TYPE a_val; + + using STRUCT_STATE = StructTypeState<1>; + + static bool ConstructType(STRUCT_STATE &state, idx_t i, StructTypeUnary &result) { + auto &a_data = state.child_data[0]; + auto a_idx = a_data.sel->get_index(i); + if (!a_data.validity.RowIsValid(a_idx)) { + return false; + } + auto a_ptr = UnifiedVectorFormat::GetData(a_data); + result.a_val = a_ptr[a_idx]; + return true; + } + + static void AssignResult(Vector &result, idx_t i, StructTypeUnary value) { + auto &entries = StructVector::GetEntries(result); + + auto a_data = FlatVector::GetData(*entries[0]); + a_data[i] = value.a_val; + } +}; + +template +struct StructTypeBinary { + A_TYPE a_val; + B_TYPE b_val; + + using STRUCT_STATE = StructTypeState<2>; + + static bool ConstructType(STRUCT_STATE &state, idx_t i, StructTypeBinary &result) { + auto &a_data = state.child_data[0]; + auto &b_data = state.child_data[1]; + + auto a_idx = a_data.sel->get_index(i); + auto b_idx = b_data.sel->get_index(i); + if (!a_data.validity.RowIsValid(a_idx) || !b_data.validity.RowIsValid(b_idx)) { + return false; + } + auto a_ptr = UnifiedVectorFormat::GetData(a_data); + auto b_ptr = UnifiedVectorFormat::GetData(b_data); + result.a_val = a_ptr[a_idx]; + result.b_val = b_ptr[b_idx]; + return true; + } + + static void AssignResult(Vector &result, idx_t i, StructTypeBinary value) { + auto &entries = StructVector::GetEntries(result); + + auto a_data = FlatVector::GetData(*entries[0]); + auto b_data = FlatVector::GetData(*entries[1]); + a_data[i] = value.a_val; + b_data[i] = value.b_val; + } +}; + +template +struct StructTypeTernary { + A_TYPE a_val; + B_TYPE b_val; + C_TYPE c_val; + + using STRUCT_STATE = StructTypeState<3>; + + static bool ConstructType(STRUCT_STATE &state, idx_t i, StructTypeTernary &result) { + auto &a_data = state.child_data[0]; + auto &b_data = state.child_data[1]; + auto &c_data = state.child_data[2]; + + auto a_idx = a_data.sel->get_index(i); + auto b_idx = b_data.sel->get_index(i); + auto c_idx = c_data.sel->get_index(i); + if (!a_data.validity.RowIsValid(a_idx) || !b_data.validity.RowIsValid(b_idx) || + !c_data.validity.RowIsValid(c_idx)) { + return false; + } + auto a_ptr = UnifiedVectorFormat::GetData(a_data); + auto b_ptr = UnifiedVectorFormat::GetData(b_data); + auto c_ptr = UnifiedVectorFormat::GetData(c_data); + result.a_val = a_ptr[a_idx]; + result.b_val = b_ptr[b_idx]; + result.c_val = c_ptr[c_idx]; + return true; + } + + static void AssignResult(Vector &result, idx_t i, StructTypeTernary value) { + auto &entries = StructVector::GetEntries(result); + + auto a_data = FlatVector::GetData(*entries[0]); + auto b_data = FlatVector::GetData(*entries[1]); + auto c_data = FlatVector::GetData(*entries[2]); + a_data[i] = value.a_val; + b_data[i] = value.b_val; + c_data[i] = value.c_val; + } +}; + +template +struct StructTypeQuaternary { + A_TYPE a_val; + B_TYPE b_val; + C_TYPE c_val; + D_TYPE d_val; + + using STRUCT_STATE = StructTypeState<4>; + + static bool ConstructType(STRUCT_STATE &state, idx_t i, + StructTypeQuaternary &result) { + auto &a_data = state.child_data[0]; + auto &b_data = state.child_data[1]; + auto &c_data = state.child_data[2]; + auto &d_data = state.child_data[3]; + + auto a_idx = a_data.sel->get_index(i); + auto b_idx = b_data.sel->get_index(i); + auto c_idx = c_data.sel->get_index(i); + auto d_idx = d_data.sel->get_index(i); + if (!a_data.validity.RowIsValid(a_idx) || !b_data.validity.RowIsValid(b_idx) || + !c_data.validity.RowIsValid(c_idx) || !d_data.validity.RowIsValid(d_idx)) { + return false; + } + auto a_ptr = UnifiedVectorFormat::GetData(a_data); + auto b_ptr = UnifiedVectorFormat::GetData(b_data); + auto c_ptr = UnifiedVectorFormat::GetData(c_data); + auto d_ptr = UnifiedVectorFormat::GetData(d_data); + result.a_val = a_ptr[a_idx]; + result.b_val = b_ptr[b_idx]; + result.c_val = c_ptr[c_idx]; + result.d_val = d_ptr[d_idx]; + return true; + } + + static void AssignResult(Vector &result, idx_t i, StructTypeQuaternary value) { + auto &entries = StructVector::GetEntries(result); + + auto a_data = FlatVector::GetData(*entries[0]); + auto b_data = FlatVector::GetData(*entries[1]); + auto c_data = FlatVector::GetData(*entries[2]); + auto d_data = FlatVector::GetData(*entries[3]); + + a_data[i] = value.a_val; + b_data[i] = value.b_val; + c_data[i] = value.c_val; + d_data[i] = value.d_val; + } +}; + +//! The GenericExecutor can handle struct types in addition to primitive types +struct GenericExecutor { +private: + template + static void ExecuteUnaryInternal(Vector &input, Vector &result, idx_t count, FUNC &fun) { + auto constant = input.GetVectorType() == VectorType::CONSTANT_VECTOR; + + typename A_TYPE::STRUCT_STATE state; + state.PrepareVector(input, count); + + for (idx_t i = 0; i < (constant ? 1 : count); i++) { + auto idx = state.main_data.sel->get_index(i); + if (!state.main_data.validity.RowIsValid(idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + A_TYPE input; + if (!A_TYPE::ConstructType(state, i, input)) { + FlatVector::SetNull(result, i, true); + continue; + } + RESULT_TYPE::AssignResult(result, i, fun(input)); + } + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + } + + template + static void ExecuteBinaryInternal(Vector &a, Vector &b, Vector &result, idx_t count, FUNC &fun) { + auto constant = + a.GetVectorType() == VectorType::CONSTANT_VECTOR && b.GetVectorType() == VectorType::CONSTANT_VECTOR; + + typename A_TYPE::STRUCT_STATE a_state; + typename B_TYPE::STRUCT_STATE b_state; + a_state.PrepareVector(a, count); + b_state.PrepareVector(b, count); + + for (idx_t i = 0; i < (constant ? 1 : count); i++) { + auto a_idx = a_state.main_data.sel->get_index(i); + auto b_idx = a_state.main_data.sel->get_index(i); + if (!a_state.main_data.validity.RowIsValid(a_idx) || !b_state.main_data.validity.RowIsValid(b_idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + A_TYPE a_val; + B_TYPE b_val; + if (!A_TYPE::ConstructType(a_state, i, a_val) || !B_TYPE::ConstructType(b_state, i, b_val)) { + FlatVector::SetNull(result, i, true); + continue; + } + RESULT_TYPE::AssignResult(result, i, fun(a_val, b_val)); + } + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + } + + template + static void ExecuteTernaryInternal(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUNC &fun) { + auto constant = a.GetVectorType() == VectorType::CONSTANT_VECTOR && + b.GetVectorType() == VectorType::CONSTANT_VECTOR && + c.GetVectorType() == VectorType::CONSTANT_VECTOR; + + typename A_TYPE::STRUCT_STATE a_state; + typename B_TYPE::STRUCT_STATE b_state; + typename C_TYPE::STRUCT_STATE c_state; + + a_state.PrepareVector(a, count); + b_state.PrepareVector(b, count); + c_state.PrepareVector(c, count); + + for (idx_t i = 0; i < (constant ? 1 : count); i++) { + auto a_idx = a_state.main_data.sel->get_index(i); + auto b_idx = a_state.main_data.sel->get_index(i); + auto c_idx = a_state.main_data.sel->get_index(i); + if (!a_state.main_data.validity.RowIsValid(a_idx) || !b_state.main_data.validity.RowIsValid(b_idx) || + !c_state.main_data.validity.RowIsValid(c_idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + A_TYPE a_val; + B_TYPE b_val; + C_TYPE c_val; + if (!A_TYPE::ConstructType(a_state, i, a_val) || !B_TYPE::ConstructType(b_state, i, b_val) || + !C_TYPE::ConstructType(c_state, i, c_val)) { + FlatVector::SetNull(result, i, true); + continue; + } + RESULT_TYPE::AssignResult(result, i, fun(a_val, b_val, c_val)); + } + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + } + + template + static void ExecuteQuaternaryInternal(Vector &a, Vector &b, Vector &c, Vector &d, Vector &result, idx_t count, + FUNC &fun) { + auto constant = + a.GetVectorType() == VectorType::CONSTANT_VECTOR && b.GetVectorType() == VectorType::CONSTANT_VECTOR && + c.GetVectorType() == VectorType::CONSTANT_VECTOR && d.GetVectorType() == VectorType::CONSTANT_VECTOR; + + typename A_TYPE::STRUCT_STATE a_state; + typename B_TYPE::STRUCT_STATE b_state; + typename C_TYPE::STRUCT_STATE c_state; + typename D_TYPE::STRUCT_STATE d_state; + + a_state.PrepareVector(a, count); + b_state.PrepareVector(b, count); + c_state.PrepareVector(c, count); + d_state.PrepareVector(d, count); + + for (idx_t i = 0; i < (constant ? 1 : count); i++) { + auto a_idx = a_state.main_data.sel->get_index(i); + auto b_idx = a_state.main_data.sel->get_index(i); + auto c_idx = a_state.main_data.sel->get_index(i); + auto d_idx = a_state.main_data.sel->get_index(i); + if (!a_state.main_data.validity.RowIsValid(a_idx) || !b_state.main_data.validity.RowIsValid(b_idx) || + !c_state.main_data.validity.RowIsValid(c_idx) || !d_state.main_data.validity.RowIsValid(d_idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + A_TYPE a_val; + B_TYPE b_val; + C_TYPE c_val; + D_TYPE d_val; + if (!A_TYPE::ConstructType(a_state, i, a_val) || !B_TYPE::ConstructType(b_state, i, b_val) || + !C_TYPE::ConstructType(c_state, i, c_val) || !D_TYPE::ConstructType(d_state, i, d_val)) { + FlatVector::SetNull(result, i, true); + continue; + } + RESULT_TYPE::AssignResult(result, i, fun(a_val, b_val, c_val, d_val)); + } + if (constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + } + +public: + template > + static void ExecuteUnary(Vector &input, Vector &result, idx_t count, FUNC fun) { + ExecuteUnaryInternal(input, result, count, fun); + } + template > + static void ExecuteBinary(Vector &a, Vector &b, Vector &result, idx_t count, FUNC fun) { + ExecuteBinaryInternal(a, b, result, count, fun); + } + template > + static void ExecuteTernary(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUNC fun) { + ExecuteTernaryInternal(a, b, c, result, count, fun); + } + template > + static void ExecuteQuaternary(Vector &a, Vector &b, Vector &c, Vector &d, Vector &result, idx_t count, FUNC fun) { + ExecuteQuaternaryInternal(a, b, c, d, result, count, fun); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/senary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/senary_executor.hpp new file mode 100644 index 00000000..49019916 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/senary_executor.hpp @@ -0,0 +1,102 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/senary_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" + +#include + +namespace duckdb { + +struct SenaryExecutor { + static const size_t NCOLS = 6; + + template > + static void Execute(DataChunk &input, Vector &result, FUN fun) { + D_ASSERT(input.ColumnCount() >= NCOLS); + const auto count = input.size(); + + bool all_constant = true; + bool any_null = false; + for (const auto &v : input.data) { + if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(v)) { + any_null = true; + } + } else { + all_constant = false; + break; + } + } + + if (all_constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + if (any_null) { + ConstantVector::SetNull(result, true); + } else { + auto adata = ConstantVector::GetData(input.data[0]); + auto bdata = ConstantVector::GetData(input.data[1]); + auto cdata = ConstantVector::GetData(input.data[2]); + auto ddata = ConstantVector::GetData(input.data[3]); + auto edata = ConstantVector::GetData(input.data[4]); + auto fdata = ConstantVector::GetData(input.data[5]); + auto result_data = ConstantVector::GetData(result); + result_data[0] = fun(*adata, *bdata, *cdata, *ddata, *edata, *fdata); + } + } else { + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + bool all_valid = true; + vector vdata(NCOLS); + for (size_t c = 0; c < NCOLS; ++c) { + input.data[c].ToUnifiedFormat(count, vdata[c]); + all_valid = all_valid && vdata[c].validity.AllValid(); + } + + auto adata = (const TA *)(vdata[0].data); + auto bdata = (const TB *)(vdata[1].data); + auto cdata = (const TC *)(vdata[2].data); + auto ddata = (const TD *)(vdata[3].data); + auto edata = (const TE *)(vdata[4].data); + auto fdata = (const TF *)(vdata[5].data); + + vector idx(NCOLS); + if (all_valid) { + for (idx_t r = 0; r < count; ++r) { + for (size_t c = 0; c < NCOLS; ++c) { + idx[c] = vdata[c].sel->get_index(r); + } + result_data[r] = + fun(adata[idx[0]], bdata[idx[1]], cdata[idx[2]], ddata[idx[3]], edata[idx[4]], fdata[idx[5]]); + } + } else { + for (idx_t r = 0; r < count; ++r) { + all_valid = true; + for (size_t c = 0; c < NCOLS; ++c) { + idx[c] = vdata[c].sel->get_index(r); + if (!vdata[c].validity.RowIsValid(idx[c])) { + result_validity.SetInvalid(r); + all_valid = false; + break; + } + } + if (all_valid) { + result_data[r] = fun(adata[idx[0]], bdata[idx[1]], cdata[idx[2]], ddata[idx[3]], edata[idx[4]], + fdata[idx[5]]); + } + } + } + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/septenary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/septenary_executor.hpp new file mode 100644 index 00000000..f727c0cc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/septenary_executor.hpp @@ -0,0 +1,104 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/septenary_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" + +#include + +namespace duckdb { + +struct SeptenaryExecutor { + static const size_t NCOLS = 7; + + template > + static void Execute(DataChunk &input, Vector &result, FUN fun) { + D_ASSERT(input.ColumnCount() >= NCOLS); + const auto count = input.size(); + + bool all_constant = true; + bool any_null = false; + for (const auto &v : input.data) { + if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(v)) { + any_null = true; + } + } else { + all_constant = false; + break; + } + } + + if (all_constant) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + if (any_null) { + ConstantVector::SetNull(result, true); + } else { + auto adata = ConstantVector::GetData(input.data[0]); + auto bdata = ConstantVector::GetData(input.data[1]); + auto cdata = ConstantVector::GetData(input.data[2]); + auto ddata = ConstantVector::GetData(input.data[3]); + auto edata = ConstantVector::GetData(input.data[4]); + auto fdata = ConstantVector::GetData(input.data[5]); + auto gdata = ConstantVector::GetData(input.data[6]); + auto result_data = ConstantVector::GetData(result); + result_data[0] = fun(*adata, *bdata, *cdata, *ddata, *edata, *fdata, *gdata); + } + } else { + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + bool all_valid = true; + vector vdata(NCOLS); + for (size_t c = 0; c < NCOLS; ++c) { + input.data[c].ToUnifiedFormat(count, vdata[c]); + all_valid = all_valid && vdata[c].validity.AllValid(); + } + + auto adata = (const TA *)(vdata[0].data); + auto bdata = (const TB *)(vdata[1].data); + auto cdata = (const TC *)(vdata[2].data); + auto ddata = (const TD *)(vdata[3].data); + auto edata = (const TE *)(vdata[4].data); + auto fdata = (const TF *)(vdata[5].data); + auto gdata = (const TG *)(vdata[6].data); + + vector idx(NCOLS); + if (all_valid) { + for (idx_t r = 0; r < count; ++r) { + for (size_t c = 0; c < NCOLS; ++c) { + idx[c] = vdata[c].sel->get_index(r); + } + result_data[r] = fun(adata[idx[0]], bdata[idx[1]], cdata[idx[2]], ddata[idx[3]], edata[idx[4]], + fdata[idx[5]], gdata[idx[6]]); + } + } else { + for (idx_t r = 0; r < count; ++r) { + all_valid = true; + for (size_t c = 0; c < NCOLS; ++c) { + idx[c] = vdata[c].sel->get_index(r); + if (!vdata[c].validity.RowIsValid(idx[c])) { + result_validity.SetInvalid(r); + all_valid = false; + break; + } + } + if (all_valid) { + result_data[r] = fun(adata[idx[0]], bdata[idx[1]], cdata[idx[2]], ddata[idx[3]], edata[idx[4]], + fdata[idx[5]], gdata[idx[6]]); + } + } + } + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/ternary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/ternary_executor.hpp new file mode 100644 index 00000000..5fea2ab4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/ternary_executor.hpp @@ -0,0 +1,208 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/ternary_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include + +namespace duckdb { + +template +struct TernaryStandardOperatorWrapper { + template + static inline RESULT_TYPE Operation(FUN fun, A_TYPE a, B_TYPE b, C_TYPE c, ValidityMask &mask, idx_t idx) { + return OP::template Operation(a, b, c); + } +}; + +struct TernaryLambdaWrapper { + template + static inline RESULT_TYPE Operation(FUN fun, A_TYPE a, B_TYPE b, C_TYPE c, ValidityMask &mask, idx_t idx) { + return fun(a, b, c); + } +}; + +struct TernaryLambdaWrapperWithNulls { + template + static inline RESULT_TYPE Operation(FUN fun, A_TYPE a, B_TYPE b, C_TYPE c, ValidityMask &mask, idx_t idx) { + return fun(a, b, c, mask, idx); + } +}; + +struct TernaryExecutor { +private: + template + static inline void ExecuteLoop(const A_TYPE *__restrict adata, const B_TYPE *__restrict bdata, + const C_TYPE *__restrict cdata, RESULT_TYPE *__restrict result_data, idx_t count, + const SelectionVector &asel, const SelectionVector &bsel, + const SelectionVector &csel, ValidityMask &avalidity, ValidityMask &bvalidity, + ValidityMask &cvalidity, ValidityMask &result_validity, FUN fun) { + if (!avalidity.AllValid() || !bvalidity.AllValid() || !cvalidity.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto aidx = asel.get_index(i); + auto bidx = bsel.get_index(i); + auto cidx = csel.get_index(i); + if (avalidity.RowIsValid(aidx) && bvalidity.RowIsValid(bidx) && cvalidity.RowIsValid(cidx)) { + result_data[i] = OPWRAPPER::template Operation( + fun, adata[aidx], bdata[bidx], cdata[cidx], result_validity, i); + } else { + result_validity.SetInvalid(i); + } + } + } else { + for (idx_t i = 0; i < count; i++) { + auto aidx = asel.get_index(i); + auto bidx = bsel.get_index(i); + auto cidx = csel.get_index(i); + result_data[i] = OPWRAPPER::template Operation( + fun, adata[aidx], bdata[bidx], cdata[cidx], result_validity, i); + } + } + } + +public: + template + static void ExecuteGeneric(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUN fun) { + if (a.GetVectorType() == VectorType::CONSTANT_VECTOR && b.GetVectorType() == VectorType::CONSTANT_VECTOR && + c.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + if (ConstantVector::IsNull(a) || ConstantVector::IsNull(b) || ConstantVector::IsNull(c)) { + ConstantVector::SetNull(result, true); + } else { + auto adata = ConstantVector::GetData(a); + auto bdata = ConstantVector::GetData(b); + auto cdata = ConstantVector::GetData(c); + auto result_data = ConstantVector::GetData(result); + auto &result_validity = ConstantVector::Validity(result); + result_data[0] = OPWRAPPER::template Operation( + fun, adata[0], bdata[0], cdata[0], result_validity, 0); + } + } else { + result.SetVectorType(VectorType::FLAT_VECTOR); + + UnifiedVectorFormat adata, bdata, cdata; + a.ToUnifiedFormat(count, adata); + b.ToUnifiedFormat(count, bdata); + c.ToUnifiedFormat(count, cdata); + + ExecuteLoop( + UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), + UnifiedVectorFormat::GetData(cdata), FlatVector::GetData(result), count, + *adata.sel, *bdata.sel, *cdata.sel, adata.validity, bdata.validity, cdata.validity, + FlatVector::Validity(result), fun); + } + } + + template > + static void Execute(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUN fun) { + ExecuteGeneric(a, b, c, result, count, fun); + } + + template + static void ExecuteStandard(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count) { + ExecuteGeneric, bool>(a, b, c, result, + count, false); + } + + template > + static void ExecuteWithNulls(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUN fun) { + ExecuteGeneric(a, b, c, result, count, + fun); + } + +private: + template + static inline idx_t SelectLoop(const A_TYPE *__restrict adata, const B_TYPE *__restrict bdata, + const C_TYPE *__restrict cdata, const SelectionVector *result_sel, idx_t count, + const SelectionVector &asel, const SelectionVector &bsel, + const SelectionVector &csel, ValidityMask &avalidity, ValidityMask &bvalidity, + ValidityMask &cvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { + idx_t true_count = 0, false_count = 0; + for (idx_t i = 0; i < count; i++) { + auto result_idx = result_sel->get_index(i); + auto aidx = asel.get_index(i); + auto bidx = bsel.get_index(i); + auto cidx = csel.get_index(i); + bool comparison_result = + (NO_NULL || (avalidity.RowIsValid(aidx) && bvalidity.RowIsValid(bidx) && cvalidity.RowIsValid(cidx))) && + OP::Operation(adata[aidx], bdata[bidx], cdata[cidx]); + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count, result_idx); + true_count += comparison_result; + } + if (HAS_FALSE_SEL) { + false_sel->set_index(false_count, result_idx); + false_count += !comparison_result; + } + } + if (HAS_TRUE_SEL) { + return true_count; + } else { + return count - false_count; + } + } + + template + static inline idx_t SelectLoopSelSwitch(UnifiedVectorFormat &adata, UnifiedVectorFormat &bdata, + UnifiedVectorFormat &cdata, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + if (true_sel && false_sel) { + return SelectLoop( + UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), + UnifiedVectorFormat::GetData(cdata), sel, count, *adata.sel, *bdata.sel, *cdata.sel, + adata.validity, bdata.validity, cdata.validity, true_sel, false_sel); + } else if (true_sel) { + return SelectLoop( + UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), + UnifiedVectorFormat::GetData(cdata), sel, count, *adata.sel, *bdata.sel, *cdata.sel, + adata.validity, bdata.validity, cdata.validity, true_sel, false_sel); + } else { + D_ASSERT(false_sel); + return SelectLoop( + UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), + UnifiedVectorFormat::GetData(cdata), sel, count, *adata.sel, *bdata.sel, *cdata.sel, + adata.validity, bdata.validity, cdata.validity, true_sel, false_sel); + } + } + + template + static inline idx_t SelectLoopSwitch(UnifiedVectorFormat &adata, UnifiedVectorFormat &bdata, + UnifiedVectorFormat &cdata, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + if (!adata.validity.AllValid() || !bdata.validity.AllValid() || !cdata.validity.AllValid()) { + return SelectLoopSelSwitch(adata, bdata, cdata, sel, count, true_sel, + false_sel); + } else { + return SelectLoopSelSwitch(adata, bdata, cdata, sel, count, true_sel, + false_sel); + } + } + +public: + template + static idx_t Select(Vector &a, Vector &b, Vector &c, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { + if (!sel) { + sel = FlatVector::IncrementalSelectionVector(); + } + UnifiedVectorFormat adata, bdata, cdata; + a.ToUnifiedFormat(count, adata); + b.ToUnifiedFormat(count, bdata); + c.ToUnifiedFormat(count, cdata); + + return SelectLoopSwitch(adata, bdata, cdata, sel, count, true_sel, false_sel); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/unary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/unary_executor.hpp new file mode 100644 index 00000000..d7f16f78 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/unary_executor.hpp @@ -0,0 +1,217 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/unary_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include + +namespace duckdb { + +struct UnaryOperatorWrapper { + template + static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + return OP::template Operation(input); + } +}; + +struct UnaryLambdaWrapper { + template + static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto fun = (FUNC *)dataptr; + return (*fun)(input); + } +}; + +struct GenericUnaryWrapper { + template + static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + return OP::template Operation(input, mask, idx, dataptr); + } +}; + +struct UnaryLambdaWrapperWithNulls { + template + static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto fun = (FUNC *)dataptr; + return (*fun)(input, mask, idx); + } +}; + +template +struct UnaryStringOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto vector = (Vector *)dataptr; + return OP::template Operation(input, *vector); + } +}; + +struct UnaryExecutor { +private: + template + static inline void ExecuteLoop(const INPUT_TYPE *__restrict ldata, RESULT_TYPE *__restrict result_data, idx_t count, + const SelectionVector *__restrict sel_vector, ValidityMask &mask, + ValidityMask &result_mask, void *dataptr, bool adds_nulls) { +#ifdef DEBUG + // ldata may point to a compressed dictionary buffer which can be smaller than ldata + count + idx_t max_index = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = sel_vector->get_index(i); + max_index = MaxValue(max_index, idx); + } + ASSERT_RESTRICT(ldata, ldata + max_index, result_data, result_data + count); +#endif + + if (!mask.AllValid()) { + result_mask.EnsureWritable(); + for (idx_t i = 0; i < count; i++) { + auto idx = sel_vector->get_index(i); + if (mask.RowIsValidUnsafe(idx)) { + result_data[i] = + OPWRAPPER::template Operation(ldata[idx], result_mask, i, dataptr); + } else { + result_mask.SetInvalid(i); + } + } + } else { + if (adds_nulls) { + result_mask.EnsureWritable(); + } + for (idx_t i = 0; i < count; i++) { + auto idx = sel_vector->get_index(i); + result_data[i] = + OPWRAPPER::template Operation(ldata[idx], result_mask, i, dataptr); + } + } + } + + template + static inline void ExecuteFlat(const INPUT_TYPE *__restrict ldata, RESULT_TYPE *__restrict result_data, idx_t count, + ValidityMask &mask, ValidityMask &result_mask, void *dataptr, bool adds_nulls) { + ASSERT_RESTRICT(ldata, ldata + count, result_data, result_data + count); + + if (!mask.AllValid()) { + if (!adds_nulls) { + result_mask.Initialize(mask); + } else { + result_mask.Copy(mask, count); + } + idx_t base_idx = 0; + auto entry_count = ValidityMask::EntryCount(count); + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { + auto validity_entry = mask.GetValidityEntry(entry_idx); + idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + // all valid: perform operation + for (; base_idx < next; base_idx++) { + result_data[base_idx] = OPWRAPPER::template Operation( + ldata[base_idx], result_mask, base_idx, dataptr); + } + } else if (ValidityMask::NoneValid(validity_entry)) { + // nothing valid: skip all + base_idx = next; + continue; + } else { + // partially valid: need to check individual elements for validity + idx_t start = base_idx; + for (; base_idx < next; base_idx++) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + D_ASSERT(mask.RowIsValid(base_idx)); + result_data[base_idx] = OPWRAPPER::template Operation( + ldata[base_idx], result_mask, base_idx, dataptr); + } + } + } + } + } else { + if (adds_nulls) { + result_mask.EnsureWritable(); + } + for (idx_t i = 0; i < count; i++) { + result_data[i] = + OPWRAPPER::template Operation(ldata[i], result_mask, i, dataptr); + } + } + } + + template + static inline void ExecuteStandard(Vector &input, Vector &result, idx_t count, void *dataptr, bool adds_nulls) { + switch (input.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto result_data = ConstantVector::GetData(result); + auto ldata = ConstantVector::GetData(input); + + if (ConstantVector::IsNull(input)) { + ConstantVector::SetNull(result, true); + } else { + ConstantVector::SetNull(result, false); + *result_data = OPWRAPPER::template Operation( + *ldata, ConstantVector::Validity(result), 0, dataptr); + } + break; + } + case VectorType::FLAT_VECTOR: { + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto ldata = FlatVector::GetData(input); + + ExecuteFlat(ldata, result_data, count, FlatVector::Validity(input), + FlatVector::Validity(result), dataptr, adds_nulls); + break; + } + default: { + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto ldata = UnifiedVectorFormat::GetData(vdata); + + ExecuteLoop(ldata, result_data, count, vdata.sel, vdata.validity, + FlatVector::Validity(result), dataptr, adds_nulls); + break; + } + } + } + +public: + template + static void Execute(Vector &input, Vector &result, idx_t count) { + ExecuteStandard(input, result, count, nullptr, false); + } + + template > + static void Execute(Vector &input, Vector &result, idx_t count, FUNC fun) { + ExecuteStandard(input, result, count, (void *)&fun, false); + } + + template + static void GenericExecute(Vector &input, Vector &result, idx_t count, void *dataptr, bool adds_nulls = false) { + ExecuteStandard(input, result, count, dataptr, adds_nulls); + } + + template > + static void ExecuteWithNulls(Vector &input, Vector &result, idx_t count, FUNC fun) { + ExecuteStandard(input, result, count, (void *)&fun, + true); + } + + template + static void ExecuteString(Vector &input, Vector &result, idx_t count) { + UnaryExecutor::GenericExecute>(input, result, count, + (void *)&result); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/vector_operations.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/vector_operations.hpp new file mode 100644 index 00000000..f4966df4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_operations/vector_operations.hpp @@ -0,0 +1,184 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_operations/vector_operations.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/vector.hpp" + +#include + +namespace duckdb { +class CastFunctionSet; +struct GetCastFunctionInput; + +// VectorOperations contains a set of operations that operate on sets of +// vectors. In general, the operators must all have the same type, otherwise an +// exception is thrown. Note that the functions underneath use restrict +// pointers, hence the data that the vectors point to (and hence the vector +// themselves) should not be equal! For example, if you call the function Add(A, +// B, A) then ASSERT_RESTRICT will be triggered. Instead call AddInPlace(A, B) +// or Add(A, B, C) +struct VectorOperations { + //===--------------------------------------------------------------------===// + // In-Place Operators + //===--------------------------------------------------------------------===// + //! left += delta + static void AddInPlace(Vector &left, int64_t delta, idx_t count); + + //===--------------------------------------------------------------------===// + // NULL Operators + //===--------------------------------------------------------------------===// + //! result = IS NOT NULL(input) + static void IsNotNull(Vector &arg, Vector &result, idx_t count); + //! result = IS NULL (input) + static void IsNull(Vector &input, Vector &result, idx_t count); + // Returns whether or not arg vector has a NULL value + static bool HasNull(Vector &input, idx_t count); + static bool HasNotNull(Vector &input, idx_t count); + //! Count the number of not-NULL values. + static idx_t CountNotNull(Vector &input, const idx_t count); + + //===--------------------------------------------------------------------===// + // Boolean Operations + //===--------------------------------------------------------------------===// + // result = left && right + static void And(Vector &left, Vector &right, Vector &result, idx_t count); + // result = left || right + static void Or(Vector &left, Vector &right, Vector &result, idx_t count); + // result = NOT(left) + static void Not(Vector &left, Vector &result, idx_t count); + + //===--------------------------------------------------------------------===// + // Comparison Operations + //===--------------------------------------------------------------------===// + // result = left == right + static void Equals(Vector &left, Vector &right, Vector &result, idx_t count); + // result = left != right + static void NotEquals(Vector &left, Vector &right, Vector &result, idx_t count); + // result = left > right + static void GreaterThan(Vector &left, Vector &right, Vector &result, idx_t count); + // result = left >= right + static void GreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); + // result = left < right + static void LessThan(Vector &left, Vector &right, Vector &result, idx_t count); + // result = left <= right + static void LessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); + + // result = A != B with nulls being equal + static void DistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count); + // result := A == B with nulls being equal + static void NotDistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count); + // result := A > B with nulls being maximal + static void DistinctGreaterThan(Vector &left, Vector &right, Vector &result, idx_t count); + // result := A >= B with nulls being maximal + static void DistinctGreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); + // result := A < B with nulls being maximal + static void DistinctLessThan(Vector &left, Vector &right, Vector &result, idx_t count); + // result := A <= B with nulls being maximal + static void DistinctLessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); + + //===--------------------------------------------------------------------===// + // Select Comparisons + //===--------------------------------------------------------------------===// + static idx_t Equals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel); + static idx_t NotEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + static idx_t GreaterThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + static idx_t GreaterThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + static idx_t LessThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + static idx_t LessThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + + // true := A != B with nulls being equal + static idx_t DistinctFrom(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + // true := A == B with nulls being equal + static idx_t NotDistinctFrom(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + // true := A > B with nulls being maximal + static idx_t DistinctGreaterThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + // true := A >= B with nulls being maximal + static idx_t DistinctGreaterThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + // true := A < B with nulls being maximal + static idx_t DistinctLessThan(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + // true := A <= B with nulls being maximal + static idx_t DistinctLessThanEquals(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + + // true := A > B with nulls being minimal + static idx_t DistinctGreaterThanNullsFirst(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + // true := A < B with nulls being minimal + static idx_t DistinctLessThanNullsFirst(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + + //===--------------------------------------------------------------------===// + // Nested Comparisons + //===--------------------------------------------------------------------===// + // true := A != B with nulls being equal + static idx_t NestedNotEquals(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + // true := A == B with nulls being equal + static idx_t NestedEquals(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + + //===--------------------------------------------------------------------===// + // Hash functions + //===--------------------------------------------------------------------===// + // hashes = HASH(input) + static void Hash(Vector &input, Vector &hashes, idx_t count); + static void Hash(Vector &input, Vector &hashes, const SelectionVector &rsel, idx_t count); + // hashes ^= HASH(input) + static void CombineHash(Vector &hashes, Vector &input, idx_t count); + static void CombineHash(Vector &hashes, Vector &input, const SelectionVector &rsel, idx_t count); + + //===--------------------------------------------------------------------===// + // Generate functions + //===--------------------------------------------------------------------===// + static void GenerateSequence(Vector &result, idx_t count, int64_t start = 0, int64_t increment = 1); + static void GenerateSequence(Vector &result, idx_t count, const SelectionVector &sel, int64_t start = 0, + int64_t increment = 1); + //===--------------------------------------------------------------------===// + // Helpers + //===--------------------------------------------------------------------===// + //! Cast the data from the source type to the target type. Any elements that could not be converted are turned into + //! NULLs. If any elements cannot be converted, returns false and fills in the error_message. If no error message is + //! provided, an exception is thrown instead. + DUCKDB_API static bool TryCast(CastFunctionSet &set, GetCastFunctionInput &input, Vector &source, Vector &result, + idx_t count, string *error_message, bool strict = false); + DUCKDB_API static bool DefaultTryCast(Vector &source, Vector &result, idx_t count, string *error_message, + bool strict = false); + DUCKDB_API static bool TryCast(ClientContext &context, Vector &source, Vector &result, idx_t count, + string *error_message, bool strict = false); + //! Cast the data from the source type to the target type. Throws an exception if the cast fails. + DUCKDB_API static void Cast(ClientContext &context, Vector &source, Vector &result, idx_t count, + bool strict = false); + DUCKDB_API static void DefaultCast(Vector &source, Vector &result, idx_t count, bool strict = false); + + // Copy the data of to the target vector + static void Copy(const Vector &source, Vector &target, idx_t source_count, idx_t source_offset, + idx_t target_offset); + static void Copy(const Vector &source, Vector &target, const SelectionVector &sel, idx_t source_count, + idx_t source_offset, idx_t target_offset); + + // Copy the data of to the target location, setting null values to + // NullValue. Used to store data without separate NULL mask. + static void WriteToStorage(Vector &source, idx_t count, data_ptr_t target); + // Reads the data of to the target vector, setting the nullmask + // for any NullValue of source. Used to go back from storage to a proper vector + static void ReadFromStorage(data_ptr_t source, idx_t count, Vector &result); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_size.hpp b/src/duckdb/src/include/duckdb/common/vector_size.hpp new file mode 100644 index 00000000..db28133a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/vector_size.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/vector_size.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" + +namespace duckdb { + +//! The vector size used in the execution engine +#ifndef STANDARD_VECTOR_SIZE +#define STANDARD_VECTOR_SIZE 2048 +#endif + +#if ((STANDARD_VECTOR_SIZE & (STANDARD_VECTOR_SIZE - 1)) != 0) +#error Vector size should be a power of two +#endif + +//! Zero selection vector: completely filled with the value 0 [READ ONLY] +extern const sel_t ZERO_VECTOR[STANDARD_VECTOR_SIZE]; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp new file mode 100644 index 00000000..69990bcb --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp @@ -0,0 +1,84 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/virtual_file_system.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +// bunch of wrappers to allow registering protocol handlers +class VirtualFileSystem : public FileSystem { +public: + VirtualFileSystem(); + + unique_ptr OpenFile(const string &path, uint8_t flags, FileLockType lock = FileLockType::NO_LOCK, + FileCompressionType compression = FileCompressionType::UNCOMPRESSED, + FileOpener *opener = nullptr) override; + + void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + void Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + + int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + + int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + + int64_t GetFileSize(FileHandle &handle) override; + time_t GetLastModifiedTime(FileHandle &handle) override; + FileType GetFileType(FileHandle &handle) override; + + void Truncate(FileHandle &handle, int64_t new_size) override; + + void FileSync(FileHandle &handle) override; + + // need to look up correct fs for this + bool DirectoryExists(const string &directory) override; + void CreateDirectory(const string &directory) override; + + void RemoveDirectory(const string &directory) override; + + bool ListFiles(const string &directory, const std::function &callback, + FileOpener *opener = nullptr) override; + + void MoveFile(const string &source, const string &target) override; + + bool FileExists(const string &filename) override; + + bool IsPipe(const string &filename) override; + virtual void RemoveFile(const string &filename) override; + + virtual vector Glob(const string &path, FileOpener *opener = nullptr) override; + + void RegisterSubSystem(unique_ptr fs) override; + + void UnregisterSubSystem(const string &name) override; + + void RegisterSubSystem(FileCompressionType compression_type, unique_ptr fs) override; + + vector ListSubSystems() override; + + std::string GetName() const override; + + void SetDisabledFileSystems(const vector &names) override; + + string PathSeparator(const string &path) override; + +private: + FileSystem &FindFileSystem(const string &path); + FileSystem &FindFileSystemInternal(const string &path); + +private: + vector> sub_systems; + map> compressed_fs; + const unique_ptr default_fs; + unordered_set disabled_file_systems; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/winapi.hpp b/src/duckdb/src/include/duckdb/common/winapi.hpp new file mode 100644 index 00000000..1b77335c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/winapi.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/winapi.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifndef DUCKDB_API +#if defined(_WIN32) && !defined(__MINGW32__) +#if defined(DUCKDB_BUILD_LIBRARY) && !defined(DUCKDB_BUILD_LOADABLE_EXTENSION) +#define DUCKDB_API __declspec(dllexport) +#else +#define DUCKDB_API __declspec(dllimport) +#endif +#else +#define DUCKDB_API +#endif +#endif + +#ifndef DUCKDB_EXTENSION_API +#ifdef _WIN32 +#ifdef DUCKDB_BUILD_LOADABLE_EXTENSION +#define DUCKDB_EXTENSION_API __declspec(dllexport) +#else +#define DUCKDB_EXTENSION_API +#endif +#else +#define DUCKDB_EXTENSION_API __attribute__((visibility("default"))) +#endif +#endif diff --git a/src/duckdb/src/include/duckdb/common/windows.hpp b/src/duckdb/src/include/duckdb/common/windows.hpp new file mode 100644 index 00000000..dc1b5653 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/windows.hpp @@ -0,0 +1,19 @@ +#pragma once + +#if defined(_WIN32) + +#ifndef NOMINMAX +#define NOMINMAX +#endif + +#ifndef _WINSOCKAPI_ +#define _WINSOCKAPI_ +#endif + +#include + +#undef CreateDirectory +#undef MoveFile +#undef RemoveDirectory + +#endif diff --git a/src/duckdb/src/include/duckdb/common/windows_undefs.hpp b/src/duckdb/src/include/duckdb/common/windows_undefs.hpp new file mode 100644 index 00000000..7b4866d8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/windows_undefs.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/windows_undefs.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifdef WIN32 + +#ifdef min +#undef min +#endif + +#ifdef max +#undef max +#endif + +#ifdef ERROR +#undef ERROR +#endif + +#ifdef small +#undef small +#endif + +#ifdef CreateDirectory +#undef CreateDirectory +#endif + +#ifdef MoveFile +#undef MoveFile +#endif + +#ifdef RemoveDirectory +#undef RemoveDirectory +#endif + +#endif diff --git a/src/duckdb/src/include/duckdb/common/windows_util.hpp b/src/duckdb/src/include/duckdb/common/windows_util.hpp new file mode 100644 index 00000000..18a02d09 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/windows_util.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/windows_util.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/windows.hpp" + +namespace duckdb { + +#ifdef DUCKDB_WINDOWS +class WindowsUtil { +public: + //! Windows helper functions + static std::wstring UTF8ToUnicode(const char *input); + static string UnicodeToUTF8(LPCWSTR input); + static string UTF8ToMBCS(const char *input, bool use_ansi = false); +}; +#endif + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/corr.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/corr.hpp new file mode 100644 index 00000000..0d595b11 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/corr.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/algebraic/corr.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/core_functions/aggregate/algebraic/covar.hpp" +#include "duckdb/core_functions/aggregate/algebraic/stddev.hpp" + +namespace duckdb { + +struct CorrState { + CovarState cov_pop; + StddevState dev_pop_x; + StddevState dev_pop_y; +}; + +// Returns the correlation coefficient for non-null pairs in a group. +// CORR(y, x) = COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y)) +struct CorrOperation { + template + static void Initialize(STATE &state) { + CovarOperation::Initialize(state.cov_pop); + STDDevBaseOperation::Initialize(state.dev_pop_x); + STDDevBaseOperation::Initialize(state.dev_pop_y); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + CovarOperation::Operation(state.cov_pop, y, x, idata); + STDDevBaseOperation::Execute(state.dev_pop_x, x); + STDDevBaseOperation::Execute(state.dev_pop_y, y); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); + STDDevBaseOperation::Combine(source.dev_pop_x, target.dev_pop_x, aggr_input_data); + STDDevBaseOperation::Combine(source.dev_pop_y, target.dev_pop_y, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.cov_pop.count == 0 || state.dev_pop_x.count == 0 || state.dev_pop_y.count == 0) { + finalize_data.ReturnNull(); + } else { + auto cov = state.cov_pop.co_moment / state.cov_pop.count; + auto std_x = state.dev_pop_x.count > 1 ? sqrt(state.dev_pop_x.dsquared / state.dev_pop_x.count) : 0; + if (!Value::DoubleIsFinite(std_x)) { + throw OutOfRangeException("STDDEV_POP for X is out of range!"); + } + auto std_y = state.dev_pop_y.count > 1 ? sqrt(state.dev_pop_y.dsquared / state.dev_pop_y.count) : 0; + if (!Value::DoubleIsFinite(std_y)) { + throw OutOfRangeException("STDDEV_POP for Y is out of range!"); + } + if (std_x * std_y == 0) { + finalize_data.ReturnNull(); + return; + } + target = cov / (std_x * std_y); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/covar.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/covar.hpp new file mode 100644 index 00000000..79d7db1e --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/covar.hpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/algebraic/covar.hpp +// +// +//===----------------------------------------------------------------------===// +// COVAR_POP(y,x) + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" + +namespace duckdb { + +struct CovarState { + uint64_t count; + double meanx; + double meany; + double co_moment; +}; + +struct CovarOperation { + template + static void Initialize(STATE &state) { + state.count = 0; + state.meanx = 0; + state.meany = 0; + state.co_moment = 0; + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + // update running mean and d^2 + const uint64_t n = ++(state.count); + + const double dx = (x - state.meanx); + const double meanx = state.meanx + dx / n; + + const double dy = (y - state.meany); + const double meany = state.meany + dy / n; + + // Schubert and Gertz SSDBM 2018 (4.3) + const double C = state.co_moment + dx * (y - meany); + + state.meanx = meanx; + state.meany = meany; + state.co_moment = C; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (target.count == 0) { + target = source; + } else if (source.count > 0) { + const auto count = target.count + source.count; + const auto meanx = (source.count * source.meanx + target.count * target.meanx) / count; + const auto meany = (source.count * source.meany + target.count * target.meany) / count; + + // Schubert and Gertz SSDBM 2018, equation 21 + const auto deltax = target.meanx - source.meanx; + const auto deltay = target.meany - source.meany; + target.co_moment = + source.co_moment + target.co_moment + deltax * deltay * source.count * target.count / count; + target.meanx = meanx; + target.meany = meany; + target.count = count; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct CovarPopOperation : public CovarOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.co_moment / state.count; + } + } +}; + +struct CovarSampOperation : public CovarOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count < 2) { + finalize_data.ReturnNull(); + } else { + target = state.co_moment / (state.count - 1); + } + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/stddev.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/stddev.hpp new file mode 100644 index 00000000..88063eea --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic/stddev.hpp @@ -0,0 +1,147 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/algebraic/stddev.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include + +namespace duckdb { + +struct StddevState { + uint64_t count; // n + double mean; // M1 + double dsquared; // M2 +}; + +// Streaming approximate standard deviation using Welford's +// method, DOI: 10.2307/1266577 +struct STDDevBaseOperation { + template + static void Initialize(STATE &state) { + state.count = 0; + state.mean = 0; + state.dsquared = 0; + } + + template + static void Execute(STATE &state, const INPUT_TYPE &input) { + // update running mean and d^2 + state.count++; + const double mean_differential = (input - state.mean) / state.count; + const double new_mean = state.mean + mean_differential; + const double dsquared_increment = (input - new_mean) * (input - state.mean); + const double new_dsquared = state.dsquared + dsquared_increment; + + state.mean = new_mean; + state.dsquared = new_dsquared; + + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + Execute(state, input); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (target.count == 0) { + target = source; + } else if (source.count > 0) { + const auto count = target.count + source.count; + const auto mean = (source.count * source.mean + target.count * target.mean) / count; + const auto delta = source.mean - target.mean; + target.dsquared = + source.dsquared + target.dsquared + delta * delta * source.count * target.count / count; + target.mean = mean; + target.count = count; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct VarSampOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count <= 1) { + finalize_data.ReturnNull(); + } else { + target = state.dsquared / (state.count - 1); + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("VARSAMP is out of range!"); + } + } + } +}; + +struct VarPopOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.count > 1 ? (state.dsquared / state.count) : 0; + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("VARPOP is out of range!"); + } + } + } +}; + +struct STDDevSampOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count <= 1) { + finalize_data.ReturnNull(); + } else { + target = sqrt(state.dsquared / (state.count - 1)); + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("STDDEV_SAMP is out of range!"); + } + } + } +}; + +struct STDDevPopOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.count > 1 ? sqrt(state.dsquared / state.count) : 0; + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("STDDEV_POP is out of range!"); + } + } + } +}; + +struct StandardErrorOfTheMeanOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = sqrt(state.dsquared / state.count) / sqrt((state.count)); + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("SEM is out of range!"); + } + } + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic_functions.hpp new file mode 100644 index 00000000..41ecd18a --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/algebraic_functions.hpp @@ -0,0 +1,126 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/algebraic_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AvgFun { + static constexpr const char *Name = "avg"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Calculates the average value for all tuples in x."; + static constexpr const char *Example = "SUM(x) / COUNT(*)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct MeanFun { + using ALIAS = AvgFun; + + static constexpr const char *Name = "mean"; +}; + +struct CorrFun { + static constexpr const char *Name = "corr"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the correlation coefficient for non-null pairs in a group."; + static constexpr const char *Example = "COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y))"; + + static AggregateFunction GetFunction(); +}; + +struct CovarPopFun { + static constexpr const char *Name = "covar_pop"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the population covariance of input values."; + static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)"; + + static AggregateFunction GetFunction(); +}; + +struct CovarSampFun { + static constexpr const char *Name = "covar_samp"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the sample covariance for non-null pairs in a group."; + static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / (COUNT(*) - 1)"; + + static AggregateFunction GetFunction(); +}; + +struct FAvgFun { + static constexpr const char *Name = "favg"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Calculates the average using a more accurate floating point summation (Kahan Sum)"; + static constexpr const char *Example = "favg(A)"; + + static AggregateFunction GetFunction(); +}; + +struct StandardErrorOfTheMeanFun { + static constexpr const char *Name = "sem"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the standard error of the mean"; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct StdDevPopFun { + static constexpr const char *Name = "stddev_pop"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the population standard deviation."; + static constexpr const char *Example = "sqrt(var_pop(x))"; + + static AggregateFunction GetFunction(); +}; + +struct StdDevSampFun { + static constexpr const char *Name = "stddev_samp"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the sample standard deviation"; + static constexpr const char *Example = "sqrt(var_samp(x))"; + + static AggregateFunction GetFunction(); +}; + +struct StddevFun { + using ALIAS = StdDevSampFun; + + static constexpr const char *Name = "stddev"; +}; + +struct VarPopFun { + static constexpr const char *Name = "var_pop"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the population variance."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct VarSampFun { + static constexpr const char *Name = "var_samp"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the sample variance of all input values."; + static constexpr const char *Example = "(SUM(x^2) - SUM(x)^2 / COUNT(x)) / (COUNT(x) - 1)"; + + static AggregateFunction GetFunction(); +}; + +struct VarianceFun { + using ALIAS = VarSampFun; + + static constexpr const char *Name = "variance"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/distributive_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/distributive_functions.hpp new file mode 100644 index 00000000..a454620a --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/distributive_functions.hpp @@ -0,0 +1,231 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/distributive_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ApproxCountDistinctFun { + static constexpr const char *Name = "approx_count_distinct"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the approximate count of distinct elements using HyperLogLog."; + static constexpr const char *Example = "approx_count_distinct(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ArgMinFun { + static constexpr const char *Name = "arg_min"; + static constexpr const char *Parameters = "arg,val"; + static constexpr const char *Description = "Finds the row with the minimum val. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_min(A,B)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ArgminFun { + using ALIAS = ArgMinFun; + + static constexpr const char *Name = "argmin"; +}; + +struct MinByFun { + using ALIAS = ArgMinFun; + + static constexpr const char *Name = "min_by"; +}; + +struct ArgMaxFun { + static constexpr const char *Name = "arg_max"; + static constexpr const char *Parameters = "arg,val"; + static constexpr const char *Description = "Finds the row with the maximum val. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_max(A,B)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ArgmaxFun { + using ALIAS = ArgMaxFun; + + static constexpr const char *Name = "argmax"; +}; + +struct MaxByFun { + using ALIAS = ArgMaxFun; + + static constexpr const char *Name = "max_by"; +}; + +struct BitAndFun { + static constexpr const char *Name = "bit_and"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the bitwise AND of all bits in a given expression."; + static constexpr const char *Example = "bit_and(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BitOrFun { + static constexpr const char *Name = "bit_or"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the bitwise OR of all bits in a given expression."; + static constexpr const char *Example = "bit_or(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BitXorFun { + static constexpr const char *Name = "bit_xor"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the bitwise XOR of all bits in a given expression."; + static constexpr const char *Example = "bit_xor(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BitstringAggFun { + static constexpr const char *Name = "bitstring_agg"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns a bitstring with bits set for each distinct value."; + static constexpr const char *Example = "bitstring_agg(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BoolAndFun { + static constexpr const char *Name = "bool_and"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns TRUE if every input value is TRUE, otherwise FALSE."; + static constexpr const char *Example = "bool_and(A)"; + + static AggregateFunction GetFunction(); +}; + +struct BoolOrFun { + static constexpr const char *Name = "bool_or"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns TRUE if any input value is TRUE, otherwise FALSE."; + static constexpr const char *Example = "bool_or(A)"; + + static AggregateFunction GetFunction(); +}; + +struct EntropyFun { + static constexpr const char *Name = "entropy"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the log-2 entropy of count input-values."; + static constexpr const char *Example = ""; + + static AggregateFunctionSet GetFunctions(); +}; + +struct KahanSumFun { + static constexpr const char *Name = "kahan_sum"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Calculates the sum using a more accurate floating point summation (Kahan Sum)."; + static constexpr const char *Example = "kahan_sum(A)"; + + static AggregateFunction GetFunction(); +}; + +struct FsumFun { + using ALIAS = KahanSumFun; + + static constexpr const char *Name = "fsum"; +}; + +struct SumkahanFun { + using ALIAS = KahanSumFun; + + static constexpr const char *Name = "sumkahan"; +}; + +struct KurtosisFun { + static constexpr const char *Name = "kurtosis"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the excess kurtosis (Fisher’s definition) of all input values, with a bias correction according to the sample size"; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct MinFun { + static constexpr const char *Name = "min"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the minimum value present in arg."; + static constexpr const char *Example = "min(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct MaxFun { + static constexpr const char *Name = "max"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the maximum value present in arg."; + static constexpr const char *Example = "max(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ProductFun { + static constexpr const char *Name = "product"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Calculates the product of all tuples in arg."; + static constexpr const char *Example = "product(A)"; + + static AggregateFunction GetFunction(); +}; + +struct SkewnessFun { + static constexpr const char *Name = "skewness"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the skewness of all input values."; + static constexpr const char *Example = "skewness(A)"; + + static AggregateFunction GetFunction(); +}; + +struct StringAggFun { + static constexpr const char *Name = "string_agg"; + static constexpr const char *Parameters = "str,arg"; + static constexpr const char *Description = "Concatenates the column string values with an optional separator."; + static constexpr const char *Example = "string_agg(A, '-')"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct GroupConcatFun { + using ALIAS = StringAggFun; + + static constexpr const char *Name = "group_concat"; +}; + +struct SumFun { + static constexpr const char *Name = "sum"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Calculates the sum value for all tuples in arg."; + static constexpr const char *Example = "sum(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct SumNoOverflowFun { + static constexpr const char *Name = "sum_no_overflow"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Calculates the sum value for all tuples in arg without overflow checks."; + static constexpr const char *Example = "sum_no_overflow(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/holistic_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/holistic_functions.hpp new file mode 100644 index 00000000..3627874e --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/holistic_functions.hpp @@ -0,0 +1,87 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/holistic_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ApproxQuantileFun { + static constexpr const char *Name = "approx_quantile"; + static constexpr const char *Parameters = "x,pos"; + static constexpr const char *Description = "Computes the approximate quantile using T-Digest."; + static constexpr const char *Example = "approx_quantile(A,0.5)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct MadFun { + static constexpr const char *Name = "mad"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the median absolute deviation for the values within x. NULL values are ignored. Temporal types return a positive INTERVAL. "; + static constexpr const char *Example = "MEDIAN(ABS(x-MEDIAN(x)))"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct MedianFun { + static constexpr const char *Name = "median"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the middle value of the set. NULL values are ignored. For even value counts, quantitiative values are averaged and ordinal values return the lower value."; + static constexpr const char *Example = "QUANTILE_CONT(x, 0.5)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ModeFun { + static constexpr const char *Name = "mode"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the most frequent value for the values within x. NULL values are ignored."; + static constexpr const char *Example = ""; + + static AggregateFunctionSet GetFunctions(); +}; + +struct QuantileDiscFun { + static constexpr const char *Name = "quantile_disc"; + static constexpr const char *Parameters = "x,pos"; + static constexpr const char *Description = "Returns the exact quantile number between 0 and 1 . If pos is a LIST of FLOATs, then the result is a LIST of the corresponding exact quantiles."; + static constexpr const char *Example = ""; + + static AggregateFunctionSet GetFunctions(); +}; + +struct QuantileFun { + using ALIAS = QuantileDiscFun; + + static constexpr const char *Name = "quantile"; +}; + +struct QuantileContFun { + static constexpr const char *Name = "quantile_cont"; + static constexpr const char *Parameters = "x,pos"; + static constexpr const char *Description = "Returns the intepolated quantile number between 0 and 1 . If pos is a LIST of FLOATs, then the result is a LIST of the corresponding intepolated quantiles. "; + static constexpr const char *Example = ""; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ReservoirQuantileFun { + static constexpr const char *Name = "reservoir_quantile"; + static constexpr const char *Parameters = "x,quantile,sample_size"; + static constexpr const char *Description = "Gives the approximate quantile using reservoir sampling, the sample size is optional and uses 8192 as a default size."; + static constexpr const char *Example = "reservoir_quantile(A,0.5,1024)"; + + static AggregateFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/nested_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/nested_functions.hpp new file mode 100644 index 00000000..b6511fec --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/nested_functions.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/nested_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct HistogramFun { + static constexpr const char *Name = "histogram"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns a LIST of STRUCTs with the fields bucket and count."; + static constexpr const char *Example = "histogram(A)"; + + static AggregateFunctionSet GetFunctions(); + static AggregateFunction GetHistogramUnorderedMap(LogicalType &type); +}; + +struct ListFun { + static constexpr const char *Name = "list"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns a LIST containing all the values of a column."; + static constexpr const char *Example = "list(A)"; + + static AggregateFunction GetFunction(); +}; + +struct ArrayAggFun { + using ALIAS = ListFun; + + static constexpr const char *Name = "array_agg"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/regression/regr_count.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/regression/regr_count.hpp new file mode 100644 index 00000000..926cc393 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/regression/regr_count.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/regression/regr_count.hpp +// +// +//===----------------------------------------------------------------------===// +// REGR_COUNT(y, x) + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/core_functions/aggregate/algebraic/covar.hpp" +#include "duckdb/core_functions/aggregate/algebraic/stddev.hpp" + +namespace duckdb { + +struct RegrCountFunction { + template + static void Initialize(STATE &state) { + state = 0; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target += source; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + target = state; + } + static bool IgnoreNull() { + return true; + } + template + static void Operation(STATE &state, const A_TYPE &, const B_TYPE &, AggregateBinaryInput &) { + state += 1; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/regression/regr_slope.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/regression/regr_slope.hpp new file mode 100644 index 00000000..0563a9a5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/regression/regr_slope.hpp @@ -0,0 +1,61 @@ +// REGR_SLOPE(y, x) +// Returns the slope of the linear regression line for non-null pairs in a group. +// It is computed for non-null pairs using the following formula: +// COVAR_POP(x,y) / VAR_POP(x) + +//! Input : Any numeric type +//! Output : Double + +#pragma once +#include "duckdb/core_functions/aggregate/algebraic/stddev.hpp" +#include "duckdb/core_functions/aggregate/algebraic/covar.hpp" + +namespace duckdb { + +struct RegrSlopeState { + CovarState cov_pop; + StddevState var_pop; +}; + +struct RegrSlopeOperation { + template + static void Initialize(STATE &state) { + CovarOperation::Initialize(state.cov_pop); + STDDevBaseOperation::Initialize(state.var_pop); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + CovarOperation::Operation(state.cov_pop, y, x,idata); + STDDevBaseOperation::Execute(state.var_pop, x); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop, target.var_pop, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.cov_pop.count == 0 || state.var_pop.count == 0) { + finalize_data.ReturnNull(); + } else { + auto cov = state.cov_pop.co_moment / state.cov_pop.count; + auto var_pop = state.var_pop.count > 1 ? (state.var_pop.dsquared / state.var_pop.count) : 0; + if (!Value::DoubleIsFinite(var_pop)) { + throw OutOfRangeException("VARPOP is out of range!"); + } + if (var_pop == 0) { + finalize_data.ReturnNull(); + return; + } + target = cov / var_pop; + } + } + + static bool IgnoreNull() { + return true; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/regression_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/regression_functions.hpp new file mode 100644 index 00000000..70cd5f07 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/regression_functions.hpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/regression_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RegrAvgxFun { + static constexpr const char *Name = "regr_avgx"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the average of the independent variable for non-null pairs in a group, where x is the independent variable and y is the dependent variable."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct RegrAvgyFun { + static constexpr const char *Name = "regr_avgy"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the average of the dependent variable for non-null pairs in a group, where x is the independent variable and y is the dependent variable."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct RegrCountFun { + static constexpr const char *Name = "regr_count"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the number of non-null number pairs in a group."; + static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrInterceptFun { + static constexpr const char *Name = "regr_intercept"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the intercept of the univariate linear regression line for non-null pairs in a group."; + static constexpr const char *Example = "AVG(y)-REGR_SLOPE(y,x)*AVG(x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrR2Fun { + static constexpr const char *Name = "regr_r2"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the coefficient of determination for non-null pairs in a group."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct RegrSlopeFun { + static constexpr const char *Name = "regr_slope"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the slope of the linear regression line for non-null pairs in a group."; + static constexpr const char *Example = "COVAR_POP(x,y) / VAR_POP(x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrSXXFun { + static constexpr const char *Name = "regr_sxx"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = ""; + static constexpr const char *Example = "REGR_COUNT(y, x) * VAR_POP(x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrSXYFun { + static constexpr const char *Name = "regr_sxy"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the population covariance of input values"; + static constexpr const char *Example = "REGR_COUNT(y, x) * COVAR_POP(y, x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrSYYFun { + static constexpr const char *Name = "regr_syy"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = ""; + static constexpr const char *Example = "REGR_COUNT(y, x) * VAR_POP(y)"; + + static AggregateFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp b/src/duckdb/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp new file mode 100644 index 00000000..45f533a7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp @@ -0,0 +1,163 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/sum_helpers.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +static inline void KahanAddInternal(double input, double &summed, double &err) { + double diff = input - err; + double newval = summed + diff; + err = (newval - summed) - diff; + summed = newval; +} + +template +struct SumState { + bool isset; + T value; + + void Initialize() { + this->isset = false; + } + + void Combine(const SumState &other) { + this->isset = other.isset || this->isset; + this->value += other.value; + } +}; + +struct KahanSumState { + bool isset; + double value; + double err; + + void Initialize() { + this->isset = false; + this->err = 0.0; + } + + void Combine(const KahanSumState &other) { + this->isset = other.isset || this->isset; + KahanAddInternal(other.value, this->value, this->err); + KahanAddInternal(other.err, this->value, this->err); + } +}; + +struct RegularAdd { + template + static void AddNumber(STATE &state, T input) { + state.value += input; + } + + template + static void AddConstant(STATE &state, T input, idx_t count) { + state.value += input * count; + } +}; + +struct KahanAdd { + template + static void AddNumber(STATE &state, T input) { + KahanAddInternal(input, state.value, state.err); + } + + template + static void AddConstant(STATE &state, T input, idx_t count) { + KahanAddInternal(input * count, state.value, state.err); + } +}; + +struct HugeintAdd { + static void AddValue(hugeint_t &result, uint64_t value, int positive) { + // integer summation taken from Tim Gubner et al. - Efficient Query Processing + // with Optimistically Compressed Hash Tables & Strings in the USSR + + // add the value to the lower part of the hugeint + result.lower += value; + // now handle overflows + int overflow = result.lower < value; + // we consider two situations: + // (1) input[idx] is positive, and current value is lower than value: overflow + // (2) input[idx] is negative, and current value is higher than value: underflow + if (!(overflow ^ positive)) { + // in the case of an overflow or underflow we either increment or decrement the upper base + // positive: +1, negative: -1 + result.upper += -1 + 2 * positive; + } + } + + template + static void AddNumber(STATE &state, T input) { + AddValue(state.value, uint64_t(input), input >= 0); + } + + template + static void AddConstant(STATE &state, T input, idx_t count) { + // add a constant X number of times + // fast path: check if value * count fits into a uint64_t + // note that we check if value * VECTOR_SIZE fits in a uint64_t to avoid having to actually do a division + // this is still a pretty high number (18014398509481984) so most positive numbers will fit + if (input >= 0 && uint64_t(input) < (NumericLimits::Maximum() / STANDARD_VECTOR_SIZE)) { + // if it does just multiply it and add the value + uint64_t value = uint64_t(input) * count; + AddValue(state.value, value, 1); + } else { + // if it doesn't fit we have two choices + // either we loop over count and add the values individually + // or we convert to a hugeint and multiply the hugeint + // the problem is that hugeint multiplication is expensive + // hence we switch here: with a low count we do the loop + // with a high count we do the hugeint multiplication + if (count < 8) { + for (idx_t i = 0; i < count; i++) { + AddValue(state.value, uint64_t(input), input >= 0); + } + } else { + hugeint_t addition = hugeint_t(input) * count; + state.value += addition; + } + } + } +}; + +template +struct BaseSumOperation { + template + static void Initialize(STATE &state) { + state.value = 0; + STATEOP::template Initialize(state); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + STATEOP::template Combine(source, target, aggr_input_data); + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + STATEOP::template AddValues(state, 1); + ADDOP::template AddNumber(state, input); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { + STATEOP::template AddValues(state, count); + ADDOP::template AddConstant(state, input, count); + } + + static bool IgnoreNull() { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/core_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/core_functions.hpp new file mode 100644 index 00000000..3705be14 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/core_functions.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/core_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +class Catalog; +struct CatalogTransaction; + +struct CoreFunctions { + static void RegisterFunctions(Catalog &catalog, CatalogTransaction transaction); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/function_list.hpp b/src/duckdb/src/include/duckdb/core_functions/function_list.hpp new file mode 100644 index 00000000..87eb5f8c --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/function_list.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/function_list.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.hpp" + +namespace duckdb { + +typedef ScalarFunction (*get_scalar_function_t)(); +typedef ScalarFunctionSet (*get_scalar_function_set_t)(); +typedef AggregateFunction (*get_aggregate_function_t)(); +typedef AggregateFunctionSet (*get_aggregate_function_set_t)(); + +struct StaticFunctionDefinition { + const char *name; + const char *parameters; + const char *description; + const char *example; + get_scalar_function_t get_function; + get_scalar_function_set_t get_function_set; + get_aggregate_function_t get_aggregate_function; + get_aggregate_function_set_t get_aggregate_function_set; + + static StaticFunctionDefinition *GetFunctionList(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/bit_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/bit_functions.hpp new file mode 100644 index 00000000..c114d72a --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/bit_functions.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/bit_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct GetBitFun { + static constexpr const char *Name = "get_bit"; + static constexpr const char *Parameters = "bitstring,index"; + static constexpr const char *Description = "Extracts the nth bit from bitstring; the first (leftmost) bit is indexed 0"; + static constexpr const char *Example = "get_bit('0110010'::BIT, 2)"; + + static ScalarFunction GetFunction(); +}; + +struct SetBitFun { + static constexpr const char *Name = "set_bit"; + static constexpr const char *Parameters = "bitstring,index,new_value"; + static constexpr const char *Description = "Sets the nth bit in bitstring to newvalue; the first (leftmost) bit is indexed 0. Returns a new bitstring"; + static constexpr const char *Example = "set_bit('0110010'::BIT, 2, 0)"; + + static ScalarFunction GetFunction(); +}; + +struct BitPositionFun { + static constexpr const char *Name = "bit_position"; + static constexpr const char *Parameters = "substring,bitstring"; + static constexpr const char *Description = "Returns first starting index of the specified substring within bits, or zero if it is not present. The first (leftmost) bit is indexed 1"; + static constexpr const char *Example = "bit_position('010'::BIT, '1110101'::BIT)"; + + static ScalarFunction GetFunction(); +}; + +struct BitStringFun { + static constexpr const char *Name = "bitstring"; + static constexpr const char *Parameters = "bitstring,length"; + static constexpr const char *Description = "Pads the bitstring until the specified length"; + static constexpr const char *Example = "bitstring('1010'::BIT, 7)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/blob_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/blob_functions.hpp new file mode 100644 index 00000000..0b1c9483 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/blob_functions.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/blob_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct DecodeFun { + static constexpr const char *Name = "decode"; + static constexpr const char *Parameters = "blob"; + static constexpr const char *Description = "Convert blob to varchar. Fails if blob is not valid utf-8"; + static constexpr const char *Example = "decode('\\xC3\\xBC'::BLOB)"; + + static ScalarFunction GetFunction(); +}; + +struct EncodeFun { + static constexpr const char *Name = "encode"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Convert varchar to blob. Converts utf-8 characters into literal encoding"; + static constexpr const char *Example = "encode('my_string_with_ü')"; + + static ScalarFunction GetFunction(); +}; + +struct FromBase64Fun { + static constexpr const char *Name = "from_base64"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Convert a base64 encoded string to a character string"; + static constexpr const char *Example = "from_base64('QQ==')"; + + static ScalarFunction GetFunction(); +}; + +struct ToBase64Fun { + static constexpr const char *Name = "to_base64"; + static constexpr const char *Parameters = "blob"; + static constexpr const char *Description = "Convert a blob to a base64 encoded string"; + static constexpr const char *Example = "base64('A'::blob)"; + + static ScalarFunction GetFunction(); +}; + +struct Base64Fun { + using ALIAS = ToBase64Fun; + + static constexpr const char *Name = "base64"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/date_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/date_functions.hpp new file mode 100644 index 00000000..1664e62b --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/date_functions.hpp @@ -0,0 +1,573 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/date_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AgeFun { + static constexpr const char *Name = "age"; + static constexpr const char *Parameters = "timestamp,timestamp"; + static constexpr const char *Description = "Subtract arguments, resulting in the time difference between the two timestamps"; + static constexpr const char *Example = "age(TIMESTAMP '2001-04-10', TIMESTAMP '1992-09-20')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CenturyFun { + static constexpr const char *Name = "century"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the century component from a date or timestamp"; + static constexpr const char *Example = "century(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CurrentDateFun { + static constexpr const char *Name = "current_date"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current date"; + static constexpr const char *Example = "current_date()"; + + static ScalarFunction GetFunction(); +}; + +struct TodayFun { + using ALIAS = CurrentDateFun; + + static constexpr const char *Name = "today"; +}; + +struct DateDiffFun { + static constexpr const char *Name = "date_diff"; + static constexpr const char *Parameters = "part,startdate,enddate"; + static constexpr const char *Description = "The number of partition boundaries between the timestamps"; + static constexpr const char *Example = "date_diff('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatediffFun { + using ALIAS = DateDiffFun; + + static constexpr const char *Name = "datediff"; +}; + +struct DatePartFun { + static constexpr const char *Name = "date_part"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Get subfield (equivalent to extract)"; + static constexpr const char *Example = "date_part('minute', TIMESTAMP '1992-09-20 20:38:40')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatepartFun { + using ALIAS = DatePartFun; + + static constexpr const char *Name = "datepart"; +}; + +struct DateSubFun { + static constexpr const char *Name = "date_sub"; + static constexpr const char *Parameters = "part,startdate,enddate"; + static constexpr const char *Description = "The number of complete partitions between the timestamps"; + static constexpr const char *Example = "date_sub('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatesubFun { + using ALIAS = DateSubFun; + + static constexpr const char *Name = "datesub"; +}; + +struct DateTruncFun { + static constexpr const char *Name = "date_trunc"; + static constexpr const char *Parameters = "part,timestamp"; + static constexpr const char *Description = "Truncate to specified precision"; + static constexpr const char *Example = "date_trunc('hour', TIMESTAMPTZ '1992-09-20 20:38:40')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatetruncFun { + using ALIAS = DateTruncFun; + + static constexpr const char *Name = "datetrunc"; +}; + +struct DayFun { + static constexpr const char *Name = "day"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the day component from a date or timestamp"; + static constexpr const char *Example = "day(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayNameFun { + static constexpr const char *Name = "dayname"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "The (English) name of the weekday"; + static constexpr const char *Example = "dayname(TIMESTAMP '1992-03-22')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayOfMonthFun { + static constexpr const char *Name = "dayofmonth"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the dayofmonth component from a date or timestamp"; + static constexpr const char *Example = "dayofmonth(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayOfWeekFun { + static constexpr const char *Name = "dayofweek"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the dayofweek component from a date or timestamp"; + static constexpr const char *Example = "dayofweek(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayOfYearFun { + static constexpr const char *Name = "dayofyear"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the dayofyear component from a date or timestamp"; + static constexpr const char *Example = "dayofyear(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DecadeFun { + static constexpr const char *Name = "decade"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the decade component from a date or timestamp"; + static constexpr const char *Example = "decade(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochFun { + static constexpr const char *Name = "epoch"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component from a temporal type"; + static constexpr const char *Example = "epoch(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochMsFun { + static constexpr const char *Name = "epoch_ms"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component in milliseconds from a temporal type"; + static constexpr const char *Example = "epoch_ms(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochUsFun { + static constexpr const char *Name = "epoch_us"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component in microseconds from a temporal type"; + static constexpr const char *Example = "epoch_us(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochNsFun { + static constexpr const char *Name = "epoch_ns"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component in nanoseconds from a temporal type"; + static constexpr const char *Example = "epoch_ns(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EraFun { + static constexpr const char *Name = "era"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the era component from a date or timestamp"; + static constexpr const char *Example = "era(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CurrentTimeFun { + static constexpr const char *Name = "get_current_time"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current time"; + static constexpr const char *Example = "get_current_time()"; + + static ScalarFunction GetFunction(); +}; + +struct GetCurrentTimestampFun { + static constexpr const char *Name = "get_current_timestamp"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current timestamp"; + static constexpr const char *Example = "get_current_timestamp()"; + + static ScalarFunction GetFunction(); +}; + +struct NowFun { + using ALIAS = GetCurrentTimestampFun; + + static constexpr const char *Name = "now"; +}; + +struct TransactionTimestampFun { + using ALIAS = GetCurrentTimestampFun; + + static constexpr const char *Name = "transaction_timestamp"; +}; + +struct HoursFun { + static constexpr const char *Name = "hour"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the hour component from a date or timestamp"; + static constexpr const char *Example = "hour(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ISODayOfWeekFun { + static constexpr const char *Name = "isodow"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the isodow component from a date or timestamp"; + static constexpr const char *Example = "isodow(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ISOYearFun { + static constexpr const char *Name = "isoyear"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the isoyear component from a date or timestamp"; + static constexpr const char *Example = "isoyear(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct JulianDayFun { + static constexpr const char *Name = "julian"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the Julian Day number from a date or timestamp"; + static constexpr const char *Example = "julian(timestamp '2006-01-01 12:00')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct LastDayFun { + static constexpr const char *Name = "last_day"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Returns the last day of the month"; + static constexpr const char *Example = "last_day(TIMESTAMP '1992-03-22 01:02:03.1234')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MakeDateFun { + static constexpr const char *Name = "make_date"; + static constexpr const char *Parameters = "year,month,day"; + static constexpr const char *Description = "The date for the given parts"; + static constexpr const char *Example = "make_date(1992, 9, 20)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MakeTimeFun { + static constexpr const char *Name = "make_time"; + static constexpr const char *Parameters = "hour,minute,seconds"; + static constexpr const char *Description = "The time for the given parts"; + static constexpr const char *Example = "make_time(13, 34, 27.123456)"; + + static ScalarFunction GetFunction(); +}; + +struct MakeTimestampFun { + static constexpr const char *Name = "make_timestamp"; + static constexpr const char *Parameters = "year,month,day,hour,minute,seconds"; + static constexpr const char *Description = "The timestamp for the given parts"; + static constexpr const char *Example = "make_timestamp(1992, 9, 20, 13, 34, 27.123456)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MicrosecondsFun { + static constexpr const char *Name = "microsecond"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the microsecond component from a date or timestamp"; + static constexpr const char *Example = "microsecond(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MillenniumFun { + static constexpr const char *Name = "millennium"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the millennium component from a date or timestamp"; + static constexpr const char *Example = "millennium(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MillisecondsFun { + static constexpr const char *Name = "millisecond"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the millisecond component from a date or timestamp"; + static constexpr const char *Example = "millisecond(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MinutesFun { + static constexpr const char *Name = "minute"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the minute component from a date or timestamp"; + static constexpr const char *Example = "minute(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MonthFun { + static constexpr const char *Name = "month"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the month component from a date or timestamp"; + static constexpr const char *Example = "month(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MonthNameFun { + static constexpr const char *Name = "monthname"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "The (English) name of the month"; + static constexpr const char *Example = "monthname(TIMESTAMP '1992-09-20')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct QuarterFun { + static constexpr const char *Name = "quarter"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the quarter component from a date or timestamp"; + static constexpr const char *Example = "quarter(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SecondsFun { + static constexpr const char *Name = "second"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the second component from a date or timestamp"; + static constexpr const char *Example = "second(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct StrfTimeFun { + static constexpr const char *Name = "strftime"; + static constexpr const char *Parameters = "text,format"; + static constexpr const char *Description = "Converts timestamp to string according to the format string"; + static constexpr const char *Example = "strftime(timestamp '1992-01-01 20:38:40', '%a, %-d %B %Y - %I:%M:%S %p')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct StrpTimeFun { + static constexpr const char *Name = "strptime"; + static constexpr const char *Parameters = "text,format"; + static constexpr const char *Description = "Converts string to timestamp with time zone according to the format string if %Z is specified"; + static constexpr const char *Example = "strptime('Wed, 1 January 1992 - 08:38:40 PST', '%a, %-d %B %Y - %H:%M:%S %Z')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimeBucketFun { + static constexpr const char *Name = "time_bucket"; + static constexpr const char *Parameters = "bucket_width,timestamp,origin"; + static constexpr const char *Description = "Truncate TIMESTAMPTZ by the specified interval bucket_width. Buckets are aligned relative to origin TIMESTAMPTZ. The origin defaults to 2000-01-03 00:00:00+00 for buckets that do not include a month or year interval, and to 2000-01-01 00:00:00+00 for month and year buckets"; + static constexpr const char *Example = "time_bucket(INTERVAL '2 weeks', TIMESTAMP '1992-04-20 15:26:00-07', TIMESTAMP '1992-04-01 00:00:00-07')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimezoneFun { + static constexpr const char *Name = "timezone"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the timezone component from a date or timestamp"; + static constexpr const char *Example = "timezone(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimezoneHourFun { + static constexpr const char *Name = "timezone_hour"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the timezone_hour component from a date or timestamp"; + static constexpr const char *Example = "timezone_hour(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimezoneMinuteFun { + static constexpr const char *Name = "timezone_minute"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the timezone_minute component from a date or timestamp"; + static constexpr const char *Example = "timezone_minute(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ToDaysFun { + static constexpr const char *Name = "to_days"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a day interval"; + static constexpr const char *Example = "to_days(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToHoursFun { + static constexpr const char *Name = "to_hours"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a hour interval"; + static constexpr const char *Example = "to_hours(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMicrosecondsFun { + static constexpr const char *Name = "to_microseconds"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a microsecond interval"; + static constexpr const char *Example = "to_microseconds(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMillisecondsFun { + static constexpr const char *Name = "to_milliseconds"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a millisecond interval"; + static constexpr const char *Example = "to_milliseconds(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMinutesFun { + static constexpr const char *Name = "to_minutes"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a minute interval"; + static constexpr const char *Example = "to_minutes(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMonthsFun { + static constexpr const char *Name = "to_months"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a month interval"; + static constexpr const char *Example = "to_months(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToSecondsFun { + static constexpr const char *Name = "to_seconds"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a second interval"; + static constexpr const char *Example = "to_seconds(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToTimestampFun { + static constexpr const char *Name = "to_timestamp"; + static constexpr const char *Parameters = "sec"; + static constexpr const char *Description = "Converts secs since epoch to a timestamp with time zone"; + static constexpr const char *Example = "to_timestamp(1284352323.5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToYearsFun { + static constexpr const char *Name = "to_years"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a year interval"; + static constexpr const char *Example = "to_years(5)"; + + static ScalarFunction GetFunction(); +}; + +struct TryStrpTimeFun { + static constexpr const char *Name = "try_strptime"; + static constexpr const char *Parameters = "text,format"; + static constexpr const char *Description = "Converts string to timestamp using the format string (timestamp with time zone if %Z is specified). Returns NULL on failure"; + static constexpr const char *Example = "try_strptime('Wed, 1 January 1992 - 08:38:40 PM', '%a, %-d %B %Y - %I:%M:%S %p')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct WeekFun { + static constexpr const char *Name = "week"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the week component from a date or timestamp"; + static constexpr const char *Example = "week(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct WeekDayFun { + static constexpr const char *Name = "weekday"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the weekday component from a date or timestamp"; + static constexpr const char *Example = "weekday(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct WeekOfYearFun { + static constexpr const char *Name = "weekofyear"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the weekofyear component from a date or timestamp"; + static constexpr const char *Example = "weekofyear(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct YearFun { + static constexpr const char *Name = "year"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the year component from a date or timestamp"; + static constexpr const char *Example = "year(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct YearWeekFun { + static constexpr const char *Name = "yearweek"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the yearweek component from a date or timestamp"; + static constexpr const char *Example = "yearweek(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/debug_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/debug_functions.hpp new file mode 100644 index 00000000..5c83d51c --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/debug_functions.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/debug_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct VectorTypeFun { + static constexpr const char *Name = "vector_type"; + static constexpr const char *Parameters = "col"; + static constexpr const char *Description = "Returns the VectorType of a given column"; + static constexpr const char *Example = "vector_type(col)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/enum_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/enum_functions.hpp new file mode 100644 index 00000000..66c7d681 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/enum_functions.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/enum_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct EnumFirstFun { + static constexpr const char *Name = "enum_first"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns the first value of the input enum type"; + static constexpr const char *Example = "enum_first(NULL::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct EnumLastFun { + static constexpr const char *Name = "enum_last"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns the last value of the input enum type"; + static constexpr const char *Example = "enum_last(NULL::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct EnumCodeFun { + static constexpr const char *Name = "enum_code"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns the numeric value backing the given enum value"; + static constexpr const char *Example = "enum_code('happy'::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct EnumRangeFun { + static constexpr const char *Name = "enum_range"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns all values of the input enum type as an array"; + static constexpr const char *Example = "enum_range(NULL::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct EnumRangeBoundaryFun { + static constexpr const char *Name = "enum_range_boundary"; + static constexpr const char *Parameters = "start,end"; + static constexpr const char *Description = "Returns the range between the two given enum values as an array. The values must be of the same enum type. When the first parameter is NULL, the result starts with the first value of the enum type. When the second parameter is NULL, the result ends with the last value of the enum type"; + static constexpr const char *Example = "enum_range_boundary(NULL, 'happy'::mood)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/generic_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/generic_functions.hpp new file mode 100644 index 00000000..970e6195 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/generic_functions.hpp @@ -0,0 +1,153 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/generic_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AliasFun { + static constexpr const char *Name = "alias"; + static constexpr const char *Parameters = "expr"; + static constexpr const char *Description = "Returns the name of a given expression"; + static constexpr const char *Example = "alias(42 + 1)"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentSettingFun { + static constexpr const char *Name = "current_setting"; + static constexpr const char *Parameters = "setting_name"; + static constexpr const char *Description = "Returns the current value of the configuration setting"; + static constexpr const char *Example = "current_setting('access_mode')"; + + static ScalarFunction GetFunction(); +}; + +struct ErrorFun { + static constexpr const char *Name = "error"; + static constexpr const char *Parameters = "message"; + static constexpr const char *Description = "Throws the given error message"; + static constexpr const char *Example = "error('access_mode')"; + + static ScalarFunction GetFunction(); +}; + +struct HashFun { + static constexpr const char *Name = "hash"; + static constexpr const char *Parameters = "param"; + static constexpr const char *Description = "Returns an integer with the hash of the value. Note that this is not a cryptographic hash"; + static constexpr const char *Example = "hash('🦆')"; + + static ScalarFunction GetFunction(); +}; + +struct LeastFun { + static constexpr const char *Name = "least"; + static constexpr const char *Parameters = "arg1, arg2, ..."; + static constexpr const char *Description = "Returns the lowest value of the set of input parameters"; + static constexpr const char *Example = "least(42, 84)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct GreatestFun { + static constexpr const char *Name = "greatest"; + static constexpr const char *Parameters = "arg1, arg2, ..."; + static constexpr const char *Description = "Returns the highest value of the set of input parameters"; + static constexpr const char *Example = "greatest(42, 84)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct StatsFun { + static constexpr const char *Name = "stats"; + static constexpr const char *Parameters = "expression"; + static constexpr const char *Description = "Returns a string with statistics about the expression. Expression can be a column, constant, or SQL expression"; + static constexpr const char *Example = "stats(5)"; + + static ScalarFunction GetFunction(); +}; + +struct TypeOfFun { + static constexpr const char *Name = "typeof"; + static constexpr const char *Parameters = "expression"; + static constexpr const char *Description = "Returns the name of the data type of the result of the expression"; + static constexpr const char *Example = "typeof('abc')"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentQueryFun { + static constexpr const char *Name = "current_query"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current query as a string"; + static constexpr const char *Example = "current_query()"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentSchemaFun { + static constexpr const char *Name = "current_schema"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the name of the currently active schema. Default is main"; + static constexpr const char *Example = "current_schema()"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentSchemasFun { + static constexpr const char *Name = "current_schemas"; + static constexpr const char *Parameters = "include_implicit"; + static constexpr const char *Description = "Returns list of schemas. Pass a parameter of True to include implicit schemas"; + static constexpr const char *Example = "current_schemas(true)"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentDatabaseFun { + static constexpr const char *Name = "current_database"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the name of the currently active database"; + static constexpr const char *Example = "current_database()"; + + static ScalarFunction GetFunction(); +}; + +struct InSearchPathFun { + static constexpr const char *Name = "in_search_path"; + static constexpr const char *Parameters = "database_name,schema_name"; + static constexpr const char *Description = "Returns whether or not the database/schema are in the search path"; + static constexpr const char *Example = "in_search_path('memory', 'main')"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentTransactionIdFun { + static constexpr const char *Name = "txid_current"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current transaction’s ID (a BIGINT). It will assign a new one if the current transaction does not have one already"; + static constexpr const char *Example = "txid_current()"; + + static ScalarFunction GetFunction(); +}; + +struct VersionFun { + static constexpr const char *Name = "version"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the currently active version of DuckDB in this format: v0.3.2 "; + static constexpr const char *Example = "version()"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/list_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/list_functions.hpp new file mode 100644 index 00000000..c1f2dc3e --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/list_functions.hpp @@ -0,0 +1,273 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/list_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ListFlattenFun { + static constexpr const char *Name = "flatten"; + static constexpr const char *Parameters = "nested_list"; + static constexpr const char *Description = "Flatten a nested list by one level"; + static constexpr const char *Example = "flatten([[1, 2, 3], [4, 5]])"; + + static ScalarFunction GetFunction(); +}; + +struct ListAggregateFun { + static constexpr const char *Name = "list_aggregate"; + static constexpr const char *Parameters = "list,name"; + static constexpr const char *Description = "Executes the aggregate function name on the elements of list"; + static constexpr const char *Example = "list_aggregate([1, 2, NULL], 'min')"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayAggregateFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "array_aggregate"; +}; + +struct ListAggrFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "list_aggr"; +}; + +struct ArrayAggrFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "array_aggr"; +}; + +struct AggregateFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "aggregate"; +}; + +struct ListDistinctFun { + static constexpr const char *Name = "list_distinct"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Removes all duplicates and NULLs from a list. Does not preserve the original order"; + static constexpr const char *Example = "list_distinct([1, 1, NULL, -3, 1, 5])"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayDistinctFun { + using ALIAS = ListDistinctFun; + + static constexpr const char *Name = "array_distinct"; +}; + +struct ListUniqueFun { + static constexpr const char *Name = "list_unique"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Counts the unique elements of a list"; + static constexpr const char *Example = "list_unique([1, 1, NULL, -3, 1, 5])"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayUniqueFun { + using ALIAS = ListUniqueFun; + + static constexpr const char *Name = "array_unique"; +}; + +struct ListValueFun { + static constexpr const char *Name = "list_value"; + static constexpr const char *Parameters = "any,..."; + static constexpr const char *Description = "Create a LIST containing the argument values"; + static constexpr const char *Example = "list_value(4, 5, 6)"; + + static ScalarFunction GetFunction(); +}; + +struct ListPackFun { + using ALIAS = ListValueFun; + + static constexpr const char *Name = "list_pack"; +}; + +struct ListSliceFun { + static constexpr const char *Name = "list_slice"; + static constexpr const char *Parameters = "list,begin,end[,step]"; + static constexpr const char *Description = "Extract a sublist using slice conventions. Negative values are accepted"; + static constexpr const char *Example = "list_slice(l, 2, 4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArraySliceFun { + using ALIAS = ListSliceFun; + + static constexpr const char *Name = "array_slice"; +}; + +struct ListSortFun { + static constexpr const char *Name = "list_sort"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Sorts the elements of the list"; + static constexpr const char *Example = "list_sort([3, 6, 1, 2])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArraySortFun { + using ALIAS = ListSortFun; + + static constexpr const char *Name = "array_sort"; +}; + +struct ListReverseSortFun { + static constexpr const char *Name = "list_reverse_sort"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Sorts the elements of the list in reverse order"; + static constexpr const char *Example = "list_reverse_sort([3, 6, 1, 2])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayReverseSortFun { + using ALIAS = ListReverseSortFun; + + static constexpr const char *Name = "array_reverse_sort"; +}; + +struct ListTransformFun { + static constexpr const char *Name = "list_transform"; + static constexpr const char *Parameters = "list,lambda"; + static constexpr const char *Description = "Returns a list that is the result of applying the lambda function to each element of the input list. See the Lambda Functions section for more details"; + static constexpr const char *Example = "list_transform([1, 2, 3], x -> x + 1)"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayTransformFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "array_transform"; +}; + +struct ListApplyFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "list_apply"; +}; + +struct ArrayApplyFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "array_apply"; +}; + +struct ApplyFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "apply"; +}; + +struct ListFilterFun { + static constexpr const char *Name = "list_filter"; + static constexpr const char *Parameters = "list,lambda"; + static constexpr const char *Description = "Constructs a list from those elements of the input list for which the lambda function returns true"; + static constexpr const char *Example = "list_filter([3, 4, 5], x -> x > 4)"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayFilterFun { + using ALIAS = ListFilterFun; + + static constexpr const char *Name = "array_filter"; +}; + +struct FilterFun { + using ALIAS = ListFilterFun; + + static constexpr const char *Name = "filter"; +}; + +struct GenerateSeriesFun { + static constexpr const char *Name = "generate_series"; + static constexpr const char *Parameters = "start,stop,step"; + static constexpr const char *Description = "Create a list of values between start and stop - the stop parameter is inclusive"; + static constexpr const char *Example = "generate_series(2, 5, 3)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListRangeFun { + static constexpr const char *Name = "range"; + static constexpr const char *Parameters = "start,stop,step"; + static constexpr const char *Description = "Create a list of values between start and stop - the stop parameter is exclusive"; + static constexpr const char *Example = "range(2, 5, 3)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListCosineSimilarityFun { + static constexpr const char *Name = "list_cosine_similarity"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the cosine similarity between two lists"; + static constexpr const char *Example = "list_cosine_similarity([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListCosineSimilarityFunAlias { + using ALIAS = ListCosineSimilarityFun; + + static constexpr const char *Name = "<=>"; +}; + +struct ListDistanceFun { + static constexpr const char *Name = "list_distance"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the distance between two lists"; + static constexpr const char *Example = "list_distance([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListDistanceFunAlias { + using ALIAS = ListDistanceFun; + + static constexpr const char *Name = "<->"; +}; + +struct ListInnerProductFun { + static constexpr const char *Name = "list_inner_product"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the inner product between two lists"; + static constexpr const char *Example = "list_inner_product([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListDotProductFun { + using ALIAS = ListInnerProductFun; + + static constexpr const char *Name = "list_dot_product"; +}; + +struct ListInnerProductFunAlias { + using ALIAS = ListInnerProductFun; + + static constexpr const char *Name = "<#>"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/map_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/map_functions.hpp new file mode 100644 index 00000000..21d5ec1b --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/map_functions.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/map_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CardinalityFun { + static constexpr const char *Name = "cardinality"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the size of the map (or the number of entries in the map)"; + static constexpr const char *Example = "cardinality( map([4, 2], ['a', 'b']) );"; + + static ScalarFunction GetFunction(); +}; + +struct MapFun { + static constexpr const char *Name = "map"; + static constexpr const char *Parameters = "keys,values"; + static constexpr const char *Description = "Creates a map from a set of keys and values"; + static constexpr const char *Example = "map(['key1', 'key2'], ['val1', 'val2'])"; + + static ScalarFunction GetFunction(); +}; + +struct MapEntriesFun { + static constexpr const char *Name = "map_entries"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the map entries as a list of keys/values"; + static constexpr const char *Example = "map_entries(map(['key'], ['val']))"; + + static ScalarFunction GetFunction(); +}; + +struct MapExtractFun { + static constexpr const char *Name = "map_extract"; + static constexpr const char *Parameters = "map,key"; + static constexpr const char *Description = "Returns a list containing the value for a given key or an empty list if the key is not contained in the map. The type of the key provided in the second parameter must match the type of the map’s keys else an error is returned"; + static constexpr const char *Example = "map_extract(map(['key'], ['val']), 'key')"; + + static ScalarFunction GetFunction(); +}; + +struct ElementAtFun { + using ALIAS = MapExtractFun; + + static constexpr const char *Name = "element_at"; +}; + +struct MapFromEntriesFun { + static constexpr const char *Name = "map_from_entries"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns a map created from the entries of the array"; + static constexpr const char *Example = "map_from_entries([{k: 5, v: 'val1'}, {k: 3, v: 'val2'}]);"; + + static ScalarFunction GetFunction(); +}; + +struct MapConcatFun { + static constexpr const char *Name = "map_concat"; + static constexpr const char *Parameters = "any,..."; + static constexpr const char *Description = "Returns a map created from merging the input maps, on key collision the value is taken from the last map with that key"; + static constexpr const char *Example = "map_concat(map([1,2], ['a', 'b']), map([2,3], ['c', 'd']));"; + + static ScalarFunction GetFunction(); +}; + +struct MapKeysFun { + static constexpr const char *Name = "map_keys"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the keys of a map as a list"; + static constexpr const char *Example = "map_keys(map(['key'], ['val']))"; + + static ScalarFunction GetFunction(); +}; + +struct MapValuesFun { + static constexpr const char *Name = "map_values"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the values of a map as a list"; + static constexpr const char *Example = "map_values(map(['key'], ['val']))"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/math_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/math_functions.hpp new file mode 100644 index 00000000..771ffa17 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/math_functions.hpp @@ -0,0 +1,396 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/math_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AbsOperatorFun { + static constexpr const char *Name = "@"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Absolute value"; + static constexpr const char *Example = "abs(-17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct AbsFun { + using ALIAS = AbsOperatorFun; + + static constexpr const char *Name = "abs"; +}; + +struct PowOperatorFun { + static constexpr const char *Name = "**"; + static constexpr const char *Parameters = "x,y"; + static constexpr const char *Description = "Computes x to the power of y"; + static constexpr const char *Example = "pow(2, 3)"; + + static ScalarFunction GetFunction(); +}; + +struct PowFun { + using ALIAS = PowOperatorFun; + + static constexpr const char *Name = "pow"; +}; + +struct PowerFun { + using ALIAS = PowOperatorFun; + + static constexpr const char *Name = "power"; +}; + +struct PowOperatorFunAlias { + using ALIAS = PowOperatorFun; + + static constexpr const char *Name = "^"; +}; + +struct FactorialOperatorFun { + static constexpr const char *Name = "!__postfix"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Factorial of x. Computes the product of the current integer and all integers below it"; + static constexpr const char *Example = "4!"; + + static ScalarFunction GetFunction(); +}; + +struct FactorialFun { + using ALIAS = FactorialOperatorFun; + + static constexpr const char *Name = "factorial"; +}; + +struct AcosFun { + static constexpr const char *Name = "acos"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the arccosine of x"; + static constexpr const char *Example = "acos(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct AsinFun { + static constexpr const char *Name = "asin"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the arcsine of x"; + static constexpr const char *Example = "asin(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct AtanFun { + static constexpr const char *Name = "atan"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the arctangent of x"; + static constexpr const char *Example = "atan(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct Atan2Fun { + static constexpr const char *Name = "atan2"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Computes the arctangent (y, x)"; + static constexpr const char *Example = "atan2(1.0, 0.0)"; + + static ScalarFunction GetFunction(); +}; + +struct BitCountFun { + static constexpr const char *Name = "bit_count"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the number of bits that are set"; + static constexpr const char *Example = "bit_count(31)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CbrtFun { + static constexpr const char *Name = "cbrt"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the cube root of x"; + static constexpr const char *Example = "cbrt(8)"; + + static ScalarFunction GetFunction(); +}; + +struct CeilFun { + static constexpr const char *Name = "ceil"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Rounds the number up"; + static constexpr const char *Example = "ceil(17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CeilingFun { + using ALIAS = CeilFun; + + static constexpr const char *Name = "ceiling"; +}; + +struct CosFun { + static constexpr const char *Name = "cos"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the cos of x"; + static constexpr const char *Example = "cos(90)"; + + static ScalarFunction GetFunction(); +}; + +struct CotFun { + static constexpr const char *Name = "cot"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the cotangent of x"; + static constexpr const char *Example = "cot(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct DegreesFun { + static constexpr const char *Name = "degrees"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Converts radians to degrees"; + static constexpr const char *Example = "degrees(pi())"; + + static ScalarFunction GetFunction(); +}; + +struct EvenFun { + static constexpr const char *Name = "even"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Rounds x to next even number by rounding away from zero"; + static constexpr const char *Example = "even(2.9)"; + + static ScalarFunction GetFunction(); +}; + +struct ExpFun { + static constexpr const char *Name = "exp"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes e to the power of x"; + static constexpr const char *Example = "exp(1)"; + + static ScalarFunction GetFunction(); +}; + +struct FloorFun { + static constexpr const char *Name = "floor"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Rounds the number down"; + static constexpr const char *Example = "floor(17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct IsFiniteFun { + static constexpr const char *Name = "isfinite"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns true if the floating point value is finite, false otherwise"; + static constexpr const char *Example = "isfinite(5.5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct IsInfiniteFun { + static constexpr const char *Name = "isinf"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns true if the floating point value is infinite, false otherwise"; + static constexpr const char *Example = "isinf('Infinity'::float)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct IsNanFun { + static constexpr const char *Name = "isnan"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns true if the floating point value is not a number, false otherwise"; + static constexpr const char *Example = "isnan('NaN'::FLOAT)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct GammaFun { + static constexpr const char *Name = "gamma"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Interpolation of (x-1) factorial (so decimal inputs are allowed)"; + static constexpr const char *Example = "gamma(5.5)"; + + static ScalarFunction GetFunction(); +}; + +struct GreatestCommonDivisorFun { + static constexpr const char *Name = "greatest_common_divisor"; + static constexpr const char *Parameters = "x,y"; + static constexpr const char *Description = "Computes the greatest common divisor of x and y"; + static constexpr const char *Example = "greatest_common_divisor(42, 57)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct GcdFun { + using ALIAS = GreatestCommonDivisorFun; + + static constexpr const char *Name = "gcd"; +}; + +struct LeastCommonMultipleFun { + static constexpr const char *Name = "least_common_multiple"; + static constexpr const char *Parameters = "x,y"; + static constexpr const char *Description = "Computes the least common multiple of x and y"; + static constexpr const char *Example = "least_common_multiple(42, 57)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct LcmFun { + using ALIAS = LeastCommonMultipleFun; + + static constexpr const char *Name = "lcm"; +}; + +struct LogGammaFun { + static constexpr const char *Name = "lgamma"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the log of the gamma function"; + static constexpr const char *Example = "lgamma(2)"; + + static ScalarFunction GetFunction(); +}; + +struct LnFun { + static constexpr const char *Name = "ln"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the natural logarithm of x"; + static constexpr const char *Example = "ln(2)"; + + static ScalarFunction GetFunction(); +}; + +struct Log2Fun { + static constexpr const char *Name = "log2"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the 2-log of x"; + static constexpr const char *Example = "log2(8)"; + + static ScalarFunction GetFunction(); +}; + +struct Log10Fun { + static constexpr const char *Name = "log10"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the 10-log of x"; + static constexpr const char *Example = "log10(1000)"; + + static ScalarFunction GetFunction(); +}; + +struct LogFun { + using ALIAS = Log10Fun; + + static constexpr const char *Name = "log"; +}; + +struct NextAfterFun { + static constexpr const char *Name = "nextafter"; + static constexpr const char *Parameters = "x, y"; + static constexpr const char *Description = "Returns the next floating point value after x in the direction of y"; + static constexpr const char *Example = "nextafter(1::float, 2::float)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct PiFun { + static constexpr const char *Name = "pi"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the value of pi"; + static constexpr const char *Example = "pi()"; + + static ScalarFunction GetFunction(); +}; + +struct RadiansFun { + static constexpr const char *Name = "radians"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Converts degrees to radians"; + static constexpr const char *Example = "radians(90)"; + + static ScalarFunction GetFunction(); +}; + +struct RoundFun { + static constexpr const char *Name = "round"; + static constexpr const char *Parameters = "x,precision"; + static constexpr const char *Description = "Rounds x to s decimal places"; + static constexpr const char *Example = "round(42.4332, 2)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SignFun { + static constexpr const char *Name = "sign"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the sign of x as -1, 0 or 1"; + static constexpr const char *Example = "sign(-349)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SignBitFun { + static constexpr const char *Name = "signbit"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns whether the signbit is set or not"; + static constexpr const char *Example = "signbit(-0.0)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SinFun { + static constexpr const char *Name = "sin"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the sin of x"; + static constexpr const char *Example = "sin(90)"; + + static ScalarFunction GetFunction(); +}; + +struct SqrtFun { + static constexpr const char *Name = "sqrt"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the square root of x"; + static constexpr const char *Example = "sqrt(4)"; + + static ScalarFunction GetFunction(); +}; + +struct TanFun { + static constexpr const char *Name = "tan"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the tan of x"; + static constexpr const char *Example = "tan(90)"; + + static ScalarFunction GetFunction(); +}; + +struct TruncFun { + static constexpr const char *Name = "trunc"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Truncates the number"; + static constexpr const char *Example = "trunc(17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/operators_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/operators_functions.hpp new file mode 100644 index 00000000..908ec939 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/operators_functions.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/operators_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct BitwiseAndFun { + static constexpr const char *Name = "&"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Bitwise AND"; + static constexpr const char *Example = "91 & 15"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BitwiseOrFun { + static constexpr const char *Name = "|"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Bitwise OR"; + static constexpr const char *Example = "32 | 3"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BitwiseNotFun { + static constexpr const char *Name = "~"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Bitwise NOT"; + static constexpr const char *Example = "~15"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct LeftShiftFun { + static constexpr const char *Name = "<<"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Bitwise shift left"; + static constexpr const char *Example = "1 << 4"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct RightShiftFun { + static constexpr const char *Name = ">>"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Bitwise shift right"; + static constexpr const char *Example = "8 >> 2"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BitwiseXorFun { + static constexpr const char *Name = "xor"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Bitwise XOR"; + static constexpr const char *Example = "xor(17, 5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/random_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/random_functions.hpp new file mode 100644 index 00000000..995c3df6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/random_functions.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/random_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RandomFun { + static constexpr const char *Name = "random"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns a random number between 0 and 1"; + static constexpr const char *Example = "random()"; + + static ScalarFunction GetFunction(); +}; + +struct SetseedFun { + static constexpr const char *Name = "setseed"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Sets the seed to be used for the random function"; + static constexpr const char *Example = "setseed(0.42)"; + + static ScalarFunction GetFunction(); +}; + +struct UUIDFun { + static constexpr const char *Name = "uuid"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns a random UUID similar to this: eeccb8c5-9943-b2bb-bb5e-222f4e14b687"; + static constexpr const char *Example = "uuid()"; + + static ScalarFunction GetFunction(); +}; + +struct GenRandomUuidFun { + using ALIAS = UUIDFun; + + static constexpr const char *Name = "gen_random_uuid"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/string_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/string_functions.hpp new file mode 100644 index 00000000..5cf6ab2d --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/string_functions.hpp @@ -0,0 +1,474 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/string_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct StartsWithOperatorFun { + static constexpr const char *Name = "^@"; + static constexpr const char *Parameters = "string,search_string"; + static constexpr const char *Description = "Returns true if string begins with search_string"; + static constexpr const char *Example = "starts_with('abc','a')"; + + static ScalarFunction GetFunction(); +}; + +struct StartsWithFun { + using ALIAS = StartsWithOperatorFun; + + static constexpr const char *Name = "starts_with"; +}; + +struct ASCIIFun { + static constexpr const char *Name = "ascii"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Returns an integer that represents the Unicode code point of the first character of the string"; + static constexpr const char *Example = "ascii('Ω')"; + + static ScalarFunction GetFunction(); +}; + +struct BarFun { + static constexpr const char *Name = "bar"; + static constexpr const char *Parameters = "x,min,max,width"; + static constexpr const char *Description = "Draws a band whose width is proportional to (x - min) and equal to width characters when x = max. width defaults to 80"; + static constexpr const char *Example = "bar(5, 0, 20, 10)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BinFun { + static constexpr const char *Name = "bin"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts the value to binary representation"; + static constexpr const char *Example = "bin(42)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ToBinaryFun { + using ALIAS = BinFun; + + static constexpr const char *Name = "to_binary"; +}; + +struct ChrFun { + static constexpr const char *Name = "chr"; + static constexpr const char *Parameters = "code_point"; + static constexpr const char *Description = "Returns a character which is corresponding the ASCII code value or Unicode code point"; + static constexpr const char *Example = "chr(65)"; + + static ScalarFunction GetFunction(); +}; + +struct DamerauLevenshteinFun { + static constexpr const char *Name = "damerau_levenshtein"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "Extension of Levenshtein distance to also include transposition of adjacent characters as an allowed edit operation. In other words, the minimum number of edit operations (insertions, deletions, substitutions or transpositions) required to change one string to another. Different case is considered different"; + static constexpr const char *Example = "damerau_levenshtein('hello', 'world')"; + + static ScalarFunction GetFunction(); +}; + +struct FormatFun { + static constexpr const char *Name = "format"; + static constexpr const char *Parameters = "format,parameters..."; + static constexpr const char *Description = "Formats a string using fmt syntax"; + static constexpr const char *Example = "format('Benchmark \"{}\" took {} seconds', 'CSV', 42)"; + + static ScalarFunction GetFunction(); +}; + +struct FormatBytesFun { + static constexpr const char *Name = "format_bytes"; + static constexpr const char *Parameters = "bytes"; + static constexpr const char *Description = "Converts bytes to a human-readable presentation (e.g. 16000 -> 16KB)"; + static constexpr const char *Example = "format_bytes(1000 * 16)"; + + static ScalarFunction GetFunction(); +}; + +struct FormatreadabledecimalsizeFun { + using ALIAS = FormatBytesFun; + + static constexpr const char *Name = "formatReadableDecimalSize"; +}; + +struct HammingFun { + static constexpr const char *Name = "hamming"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The number of positions with different characters for 2 strings of equal length. Different case is considered different"; + static constexpr const char *Example = "hamming('duck','luck')"; + + static ScalarFunction GetFunction(); +}; + +struct MismatchesFun { + using ALIAS = HammingFun; + + static constexpr const char *Name = "mismatches"; +}; + +struct HexFun { + static constexpr const char *Name = "hex"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts the value to hexadecimal representation"; + static constexpr const char *Example = "hex(42)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ToHexFun { + using ALIAS = HexFun; + + static constexpr const char *Name = "to_hex"; +}; + +struct InstrFun { + static constexpr const char *Name = "instr"; + static constexpr const char *Parameters = "haystack,needle"; + static constexpr const char *Description = "Returns location of first occurrence of needle in haystack, counting from 1. Returns 0 if no match found"; + static constexpr const char *Example = "instr('test test','es')"; + + static ScalarFunction GetFunction(); +}; + +struct StrposFun { + using ALIAS = InstrFun; + + static constexpr const char *Name = "strpos"; +}; + +struct PositionFun { + using ALIAS = InstrFun; + + static constexpr const char *Name = "position"; +}; + +struct JaccardFun { + static constexpr const char *Name = "jaccard"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The Jaccard similarity between two strings. Different case is considered different. Returns a number between 0 and 1"; + static constexpr const char *Example = "jaccard('duck','luck')"; + + static ScalarFunction GetFunction(); +}; + +struct JaroSimilarityFun { + static constexpr const char *Name = "jaro_similarity"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The Jaro similarity between two strings. Different case is considered different. Returns a number between 0 and 1"; + static constexpr const char *Example = "jaro_similarity('duck','duckdb')"; + + static ScalarFunction GetFunction(); +}; + +struct JaroWinklerSimilarityFun { + static constexpr const char *Name = "jaro_winkler_similarity"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The Jaro-Winkler similarity between two strings. Different case is considered different. Returns a number between 0 and 1"; + static constexpr const char *Example = "jaro_winkler_similarity('duck','duckdb')"; + + static ScalarFunction GetFunction(); +}; + +struct LeftFun { + static constexpr const char *Name = "left"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the left-most count characters"; + static constexpr const char *Example = "left('Hello🦆', 2)"; + + static ScalarFunction GetFunction(); +}; + +struct LeftGraphemeFun { + static constexpr const char *Name = "left_grapheme"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the left-most count grapheme clusters"; + static constexpr const char *Example = "left_grapheme('🤦🏼‍♂️🤦🏽‍♀️', 1)"; + + static ScalarFunction GetFunction(); +}; + +struct LevenshteinFun { + static constexpr const char *Name = "levenshtein"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The minimum number of single-character edits (insertions, deletions or substitutions) required to change one string to the other. Different case is considered different"; + static constexpr const char *Example = "levenshtein('duck','db')"; + + static ScalarFunction GetFunction(); +}; + +struct Editdist3Fun { + using ALIAS = LevenshteinFun; + + static constexpr const char *Name = "editdist3"; +}; + +struct LpadFun { + static constexpr const char *Name = "lpad"; + static constexpr const char *Parameters = "string,count,character"; + static constexpr const char *Description = "Pads the string with the character from the left until it has count characters"; + static constexpr const char *Example = "lpad('hello', 10, '>')"; + + static ScalarFunction GetFunction(); +}; + +struct LtrimFun { + static constexpr const char *Name = "ltrim"; + static constexpr const char *Parameters = "string,characters"; + static constexpr const char *Description = "Removes any occurrences of any of the characters from the left side of the string"; + static constexpr const char *Example = "ltrim('>>>>test<<', '><')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MD5Fun { + static constexpr const char *Name = "md5"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Returns the MD5 hash of the value as a string"; + static constexpr const char *Example = "md5('123')"; + + static ScalarFunction GetFunction(); +}; + +struct MD5NumberFun { + static constexpr const char *Name = "md5_number"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Returns the MD5 hash of the value as an INT128"; + static constexpr const char *Example = "md5_number('123')"; + + static ScalarFunction GetFunction(); +}; + +struct MD5NumberLowerFun { + static constexpr const char *Name = "md5_number_lower"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Returns the MD5 hash of the value as an INT128"; + static constexpr const char *Example = "md5_number_lower('123')"; + + static ScalarFunction GetFunction(); +}; + +struct MD5NumberUpperFun { + static constexpr const char *Name = "md5_number_upper"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Returns the MD5 hash of the value as an INT128"; + static constexpr const char *Example = "md5_number_upper('123')"; + + static ScalarFunction GetFunction(); +}; + +struct PrintfFun { + static constexpr const char *Name = "printf"; + static constexpr const char *Parameters = "format,parameters..."; + static constexpr const char *Description = "Formats a string using printf syntax"; + static constexpr const char *Example = "printf('Benchmark \"%s\" took %d seconds', 'CSV', 42)"; + + static ScalarFunction GetFunction(); +}; + +struct RepeatFun { + static constexpr const char *Name = "repeat"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Repeats the string count number of times"; + static constexpr const char *Example = "repeat('A', 5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ReplaceFun { + static constexpr const char *Name = "replace"; + static constexpr const char *Parameters = "string,source,target"; + static constexpr const char *Description = "Replaces any occurrences of the source with target in string"; + static constexpr const char *Example = "replace('hello', 'l', '-')"; + + static ScalarFunction GetFunction(); +}; + +struct ReverseFun { + static constexpr const char *Name = "reverse"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Reverses the string"; + static constexpr const char *Example = "reverse('hello')"; + + static ScalarFunction GetFunction(); +}; + +struct RightFun { + static constexpr const char *Name = "right"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the right-most count characters"; + static constexpr const char *Example = "right('Hello🦆', 3)"; + + static ScalarFunction GetFunction(); +}; + +struct RightGraphemeFun { + static constexpr const char *Name = "right_grapheme"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the right-most count grapheme clusters"; + static constexpr const char *Example = "right_grapheme('🤦🏼‍♂️🤦🏽‍♀️', 1)"; + + static ScalarFunction GetFunction(); +}; + +struct RpadFun { + static constexpr const char *Name = "rpad"; + static constexpr const char *Parameters = "string,count,character"; + static constexpr const char *Description = "Pads the string with the character from the right until it has count characters"; + static constexpr const char *Example = "rpad('hello', 10, '<')"; + + static ScalarFunction GetFunction(); +}; + +struct RtrimFun { + static constexpr const char *Name = "rtrim"; + static constexpr const char *Parameters = "string,characters"; + static constexpr const char *Description = "Removes any occurrences of any of the characters from the right side of the string"; + static constexpr const char *Example = "rtrim('>>>>test<<', '><')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SHA256Fun { + static constexpr const char *Name = "sha256"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Returns the SHA256 hash of the value"; + static constexpr const char *Example = "sha256('hello')"; + + static ScalarFunction GetFunction(); +}; + +struct StringSplitFun { + static constexpr const char *Name = "string_split"; + static constexpr const char *Parameters = "string,separator"; + static constexpr const char *Description = "Splits the string along the separator"; + static constexpr const char *Example = "string_split('hello-world', '-')"; + + static ScalarFunction GetFunction(); +}; + +struct StrSplitFun { + using ALIAS = StringSplitFun; + + static constexpr const char *Name = "str_split"; +}; + +struct StringToArrayFun { + using ALIAS = StringSplitFun; + + static constexpr const char *Name = "string_to_array"; +}; + +struct SplitFun { + using ALIAS = StringSplitFun; + + static constexpr const char *Name = "split"; +}; + +struct StringSplitRegexFun { + static constexpr const char *Name = "string_split_regex"; + static constexpr const char *Parameters = "string,separator"; + static constexpr const char *Description = "Splits the string along the regex"; + static constexpr const char *Example = "string_split_regex('hello␣world; 42', ';?␣')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct StrSplitRegexFun { + using ALIAS = StringSplitRegexFun; + + static constexpr const char *Name = "str_split_regex"; +}; + +struct RegexpSplitToArrayFun { + using ALIAS = StringSplitRegexFun; + + static constexpr const char *Name = "regexp_split_to_array"; +}; + +struct TranslateFun { + static constexpr const char *Name = "translate"; + static constexpr const char *Parameters = "string,from,to"; + static constexpr const char *Description = "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted"; + static constexpr const char *Example = "translate('12345', '143', 'ax')"; + + static ScalarFunction GetFunction(); +}; + +struct TrimFun { + static constexpr const char *Name = "trim"; + static constexpr const char *Parameters = "string,characters"; + static constexpr const char *Description = "Removes any occurrences of any of the characters from either side of the string"; + static constexpr const char *Example = "trim('>>>>test<<', '><')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct UnbinFun { + static constexpr const char *Name = "unbin"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts a value from binary representation to a blob"; + static constexpr const char *Example = "unbin('0110')"; + + static ScalarFunction GetFunction(); +}; + +struct FromBinaryFun { + using ALIAS = UnbinFun; + + static constexpr const char *Name = "from_binary"; +}; + +struct UnhexFun { + static constexpr const char *Name = "unhex"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts a value from hexadecimal representation to a blob"; + static constexpr const char *Example = "unhex('2A')"; + + static ScalarFunction GetFunction(); +}; + +struct FromHexFun { + using ALIAS = UnhexFun; + + static constexpr const char *Name = "from_hex"; +}; + +struct UnicodeFun { + static constexpr const char *Name = "unicode"; + static constexpr const char *Parameters = "str"; + static constexpr const char *Description = "Returns the unicode codepoint of the first character of the string"; + static constexpr const char *Example = "unicode('ü')"; + + static ScalarFunction GetFunction(); +}; + +struct OrdFun { + using ALIAS = UnicodeFun; + + static constexpr const char *Name = "ord"; +}; + +struct ToBaseFun { + static constexpr const char *Name = "to_base"; + static constexpr const char *Parameters = "number,radix,min_length"; + static constexpr const char *Description = "Converts a value to a string in the given base radix, optionally padding with leading zeros to the minimum length"; + static constexpr const char *Example = "to_base(42, 16)"; + + static ScalarFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/struct_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/struct_functions.hpp new file mode 100644 index 00000000..b83c5e95 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/struct_functions.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/struct_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct StructInsertFun { + static constexpr const char *Name = "struct_insert"; + static constexpr const char *Parameters = "struct,any"; + static constexpr const char *Description = "Adds field(s)/value(s) to an existing STRUCT with the argument values. The entry name(s) will be the bound variable name(s)"; + static constexpr const char *Example = "struct_insert({'a': 1}, b := 2)"; + + static ScalarFunction GetFunction(); +}; + +struct StructPackFun { + static constexpr const char *Name = "struct_pack"; + static constexpr const char *Parameters = "any"; + static constexpr const char *Description = "Creates a STRUCT containing the argument values. The entry name will be the bound variable name"; + static constexpr const char *Example = "struct_pack(i := 4, s := 'string')"; + + static ScalarFunction GetFunction(); +}; + +struct RowFun { + static constexpr const char *Name = "row"; + static constexpr const char *Parameters = "any"; + static constexpr const char *Description = "Creates an unnamed STRUCT containing the argument values."; + static constexpr const char *Example = "row(4, 'hello')"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/core_functions/scalar/union_functions.hpp b/src/duckdb/src/include/duckdb/core_functions/scalar/union_functions.hpp new file mode 100644 index 00000000..8b869460 --- /dev/null +++ b/src/duckdb/src/include/duckdb/core_functions/scalar/union_functions.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/union_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct UnionExtractFun { + static constexpr const char *Name = "union_extract"; + static constexpr const char *Parameters = "union,tag"; + static constexpr const char *Description = "Extract the value with the named tags from the union. NULL if the tag is not currently selected"; + static constexpr const char *Example = "union_extract(s, 'k')"; + + static ScalarFunction GetFunction(); +}; + +struct UnionTagFun { + static constexpr const char *Name = "union_tag"; + static constexpr const char *Parameters = "union"; + static constexpr const char *Description = "Retrieve the currently selected tag of the union as an ENUM"; + static constexpr const char *Example = "union_tag(union_value(k := 'foo'))"; + + static ScalarFunction GetFunction(); +}; + +struct UnionValueFun { + static constexpr const char *Name = "union_value"; + static constexpr const char *Parameters = "tag"; + static constexpr const char *Description = "Create a single member UNION containing the argument value. The tag of the value will be the bound variable name"; + static constexpr const char *Example = "union_value(k := 'hello')"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/adaptive_filter.hpp b/src/duckdb/src/include/duckdb/execution/adaptive_filter.hpp new file mode 100644 index 00000000..61bd50c0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/adaptive_filter.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/adaptive_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression/list.hpp" + +#include +namespace duckdb { + +class AdaptiveFilter { +public: + explicit AdaptiveFilter(const Expression &expr); + explicit AdaptiveFilter(TableFilterSet *table_filters); + void AdaptRuntimeStatistics(double duration); + vector permutation; + +private: + //! used for adaptive expression reordering + idx_t iteration_count; + idx_t swap_idx; + idx_t right_random_border; + idx_t observe_interval; + idx_t execute_interval; + double runtime_sum; + double prev_mean; + bool observe; + bool warmup; + vector swap_likeliness; + std::default_random_engine generator; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/aggregate_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/aggregate_hashtable.hpp new file mode 100644 index 00000000..b8a4a8fa --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/aggregate_hashtable.hpp @@ -0,0 +1,209 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/aggregate_hashtable.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/row_operations/row_matcher.hpp" +#include "duckdb/common/types/row/partitioned_tuple_data.hpp" +#include "duckdb/execution/base_aggregate_hashtable.hpp" +#include "duckdb/storage/arena_allocator.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" + +namespace duckdb { + +class BlockHandle; +class BufferHandle; + +struct FlushMoveState; + +//! GroupedAggregateHashTable is a linear probing HT that is used for computing +//! aggregates +/*! + GroupedAggregateHashTable is a HT that is used for computing aggregates. It takes + as input the set of groups and the types of the aggregates to compute and + stores them in the HT. It uses linear probing for collision resolution. +*/ + +struct aggr_ht_entry_t { +public: + explicit aggr_ht_entry_t(hash_t value_p) : value(value_p) { + } + + inline bool IsOccupied() const { + return value != 0; + } + + inline data_ptr_t GetPointer() const { + D_ASSERT(IsOccupied()); + return reinterpret_cast(value & POINTER_MASK); + } + inline void SetPointer(const data_ptr_t &pointer) { + // Pointer shouldn't use upper bits + D_ASSERT((reinterpret_cast(pointer) & SALT_MASK) == 0); + // Value should have all 1's in the pointer area + D_ASSERT((value & POINTER_MASK) == POINTER_MASK); + // Set upper bits to 1 in pointer so the salt stays intact + value &= reinterpret_cast(pointer) | SALT_MASK; + } + + static inline hash_t ExtractSalt(const hash_t &hash) { + // Leaves upper bits intact, sets lower bits to all 1's + return hash | POINTER_MASK; + } + inline hash_t GetSalt() const { + return ExtractSalt(value); + } + inline void SetSalt(const hash_t &salt) { + // Shouldn't be occupied when we set this + D_ASSERT(!IsOccupied()); + // Salt should have all 1's in the pointer field + D_ASSERT((salt & POINTER_MASK) == POINTER_MASK); + // No need to mask, just put the whole thing there + value = salt; + } + +private: + //! Upper 16 bits are salt + static constexpr const hash_t SALT_MASK = 0xFFFF000000000000; + //! Lower 48 bits are the pointer + static constexpr const hash_t POINTER_MASK = 0x0000FFFFFFFFFFFF; + + hash_t value; +}; + +class GroupedAggregateHashTable : public BaseAggregateHashTable { +public: + GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, vector group_types, + vector payload_types, const vector &aggregates, + idx_t initial_capacity = InitialCapacity(), idx_t radix_bits = 0); + GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, vector group_types, + vector payload_types, vector aggregates, + idx_t initial_capacity = InitialCapacity(), idx_t radix_bits = 0); + GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, vector group_types); + ~GroupedAggregateHashTable() override; + +public: + //! The hash table load factor, when a resize is triggered + constexpr static float LOAD_FACTOR = 1.5; + + //! Get the layout of this HT + const TupleDataLayout &GetLayout() const; + //! Number of groups in the HT + idx_t Count() const; + //! Initial capacity of the HT + static idx_t InitialCapacity(); + //! Capacity that can hold 'count' entries without resizing + static idx_t GetCapacityForCount(idx_t count); + //! Current capacity of the HT + idx_t Capacity() const; + //! Threshold at which to resize the HT + idx_t ResizeThreshold() const; + + //! Add the given data to the HT, computing the aggregates grouped by the + //! data in the group chunk. When resize = true, aggregates will not be + //! computed but instead just assigned. + idx_t AddChunk(DataChunk &groups, DataChunk &payload, const unsafe_vector &filter); + idx_t AddChunk(DataChunk &groups, Vector &group_hashes, DataChunk &payload, const unsafe_vector &filter); + idx_t AddChunk(DataChunk &groups, DataChunk &payload, AggregateType filter); + + //! Fetch the aggregates for specific groups from the HT and place them in the result + void FetchAggregates(DataChunk &groups, DataChunk &result); + + //! Finds or creates groups in the hashtable using the specified group keys. The addresses vector will be filled + //! with pointers to the groups in the hash table, and the new_groups selection vector will point to the newly + //! created groups. The return value is the amount of newly created groups. + idx_t FindOrCreateGroups(DataChunk &groups, Vector &group_hashes, Vector &addresses_out, + SelectionVector &new_groups_out); + idx_t FindOrCreateGroups(DataChunk &groups, Vector &addresses_out, SelectionVector &new_groups_out); + void FindOrCreateGroups(DataChunk &groups, Vector &addresses_out); + + unique_ptr &GetPartitionedData(); + shared_ptr GetAggregateAllocator(); + + //! Resize the HT to the specified size. Must be larger than the current size. + void Resize(idx_t size); + //! Resets the pointer table of the HT to all 0's + void ClearPointerTable(); + //! Resets the group count to 0 + void ResetCount(); + //! Set the radix bits for this HT + void SetRadixBits(idx_t radix_bits); + //! Initializes the PartitionedTupleData + void InitializePartitionedData(); + + //! Executes the filter(if any) and update the aggregates + void Combine(GroupedAggregateHashTable &other); + void Combine(TupleDataCollection &other_data); + + //! Unpins the data blocks + void UnpinData(); + +private: + //! Efficiently matches groups + RowMatcher row_matcher; + + //! Append state + struct AggregateHTAppendState { + AggregateHTAppendState(); + + PartitionedTupleDataAppendState append_state; + + Vector ht_offsets; + Vector hash_salts; + SelectionVector group_compare_vector; + SelectionVector no_match_vector; + SelectionVector empty_vector; + SelectionVector new_groups; + Vector addresses; + unsafe_unique_array group_data; + DataChunk group_chunk; + } state; + + //! The number of radix bits to partition by + idx_t radix_bits; + //! The data of the HT + unique_ptr partitioned_data; + + //! Predicates for matching groups (always ExpressionType::COMPARE_EQUAL) + vector predicates; + + //! The number of groups in the HT + idx_t count; + //! The capacity of the HT. This can be increased using GroupedAggregateHashTable::Resize + idx_t capacity; + //! The hash map (pointer table) of the HT: allocated data and pointer into it + AllocatedData hash_map; + aggr_ht_entry_t *entries; + //! Offset of the hash column in the rows + idx_t hash_offset; + //! Bitmask for getting relevant bits from the hashes to determine the position + hash_t bitmask; + + //! The active arena allocator used by the aggregates for their internal state + shared_ptr aggregate_allocator; + //! Owning arena allocators that this HT has data from + vector> stored_allocators; + +private: + //! Disabled the copy constructor + GroupedAggregateHashTable(const GroupedAggregateHashTable &) = delete; + //! Destroy the HT + void Destroy(); + + //! Apply bitmask to get the entry in the HT + inline idx_t ApplyBitMask(hash_t hash) const; + + //! Does the actual group matching / creation + idx_t FindOrCreateGroupsInternal(DataChunk &groups, Vector &group_hashes, Vector &addresses, + SelectionVector &new_groups); + + //! Verify the pointer table of the HT + void Verify(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/base_aggregate_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/base_aggregate_hashtable.hpp new file mode 100644 index 00000000..df4f7288 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/base_aggregate_hashtable.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/base_aggregate_hashtable.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/row/tuple_data_layout.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" + +namespace duckdb { +class BufferManager; + +class BaseAggregateHashTable { +public: + BaseAggregateHashTable(ClientContext &context, Allocator &allocator, const vector &aggregates, + vector payload_types); + virtual ~BaseAggregateHashTable() { + } + +protected: + Allocator &allocator; + BufferManager &buffer_manager; + //! A helper for managing offsets into the data buffers + TupleDataLayout layout; + //! The types of the payload columns stored in the hashtable + vector payload_types; + //! Intermediate structures and data for aggregate filters + AggregateFilterDataSet filter_set; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/column_binding_resolver.hpp b/src/duckdb/src/include/duckdb/execution/column_binding_resolver.hpp new file mode 100644 index 00000000..fa9fd3f9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/column_binding_resolver.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/column_binding_resolver.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +//! The ColumnBindingResolver resolves ColumnBindings into base tables +//! (table_index, column_index) into physical indices into the DataChunks that +//! are used within the execution engine +class ColumnBindingResolver : public LogicalOperatorVisitor { +public: + ColumnBindingResolver(); + + void VisitOperator(LogicalOperator &op) override; + static void Verify(LogicalOperator &op); + +protected: + vector bindings; + + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + static unordered_set VerifyInternal(LogicalOperator &op); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/execution_context.hpp b/src/duckdb/src/include/duckdb/execution/execution_context.hpp new file mode 100644 index 00000000..597094c3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/execution_context.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/execution_context.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class ClientContext; +class ThreadContext; +class Pipeline; + +class ExecutionContext { +public: + ExecutionContext(ClientContext &client_p, ThreadContext &thread_p, optional_ptr pipeline_p) + : client(client_p), thread(thread_p), pipeline(pipeline_p) { + } + + //! The client-global context; caution needs to be taken when used in parallel situations + ClientContext &client; + //! The thread-local context for this execution + ThreadContext &thread; + //! Reference to the pipeline for this execution, can be used for example by operators determine caching strategy + optional_ptr pipeline; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/executor.hpp b/src/duckdb/src/include/duckdb/execution/executor.hpp new file mode 100644 index 00000000..eb9f73e4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/executor.hpp @@ -0,0 +1,163 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/pending_execution_result.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/parallel/pipeline.hpp" + +namespace duckdb { +class ClientContext; +class DataChunk; +class PhysicalOperator; +class PipelineExecutor; +class OperatorState; +class QueryProfiler; +class ThreadContext; +class Task; + +struct PipelineEventStack; +struct ProducerToken; +struct ScheduleEventData; + +class Executor { + friend class Pipeline; + friend class PipelineTask; + friend class PipelineBuildState; + +public: + explicit Executor(ClientContext &context); + ~Executor(); + + ClientContext &context; + +public: + static Executor &Get(ClientContext &context); + + void Initialize(PhysicalOperator &physical_plan); + void Initialize(unique_ptr physical_plan); + + void CancelTasks(); + PendingExecutionResult ExecuteTask(); + + void Reset(); + + vector GetTypes(); + + unique_ptr FetchChunk(); + + //! Push a new error + void PushError(PreservedError exception); + + //! True if an error has been thrown + bool HasError(); + //! Throw the exception that was pushed using PushError. + //! Should only be called if HasError returns true + void ThrowException(); + + //! Work on tasks for this specific executor, until there are no tasks remaining + void WorkOnTasks(); + + //! Flush a thread context into the client context + void Flush(ThreadContext &context); + + //! Reschedules a task that was blocked + void RescheduleTask(shared_ptr &task); + + //! Add the task to be rescheduled + void AddToBeRescheduled(shared_ptr &task); + + //! Returns the progress of the pipelines + bool GetPipelinesProgress(double ¤t_progress); + + void CompletePipeline() { + completed_pipelines++; + } + ProducerToken &GetToken() { + return *producer; + } + void AddEvent(shared_ptr event); + + void AddRecursiveCTE(PhysicalOperator &rec_cte); + void AddMaterializedCTE(PhysicalOperator &mat_cte); + void ReschedulePipelines(const vector> &pipelines, vector> &events); + + //! Whether or not the root of the pipeline is a result collector object + bool HasResultCollector(); + //! Returns the query result - can only be used if `HasResultCollector` returns true + unique_ptr GetResult(); + + //! Returns true if all pipelines have been completed + bool ExecutionIsFinished(); + +private: + void InitializeInternal(PhysicalOperator &physical_plan); + + void ScheduleEvents(const vector> &meta_pipelines); + static void ScheduleEventsInternal(ScheduleEventData &event_data); + + static void VerifyScheduledEvents(const ScheduleEventData &event_data); + static void VerifyScheduledEventsInternal(const idx_t i, const vector &vertices, vector &visited, + vector &recursion_stack); + + static void SchedulePipeline(const shared_ptr &pipeline, ScheduleEventData &event_data); + + bool NextExecutor(); + + shared_ptr CreateChildPipeline(Pipeline ¤t, PhysicalOperator &op); + + void VerifyPipeline(Pipeline &pipeline); + void VerifyPipelines(); + +private: + optional_ptr physical_plan; + unique_ptr owned_plan; + + mutex executor_lock; + mutex error_lock; + //! All pipelines of the query plan + vector> pipelines; + //! The root pipelines of the query + vector> root_pipelines; + //! The recursive CTE's in this query plan + vector> recursive_ctes; + //! The materialized CTE's in this query plan + vector> materialized_ctes; + //! The pipeline executor for the root pipeline + unique_ptr root_executor; + //! The current root pipeline index + idx_t root_pipeline_idx; + //! The producer of this query + unique_ptr producer; + //! Exceptions that occurred during the execution of the current query + vector exceptions; + //! List of events + vector> events; + //! The query profiler + shared_ptr profiler; + + //! The amount of completed pipelines of the query + atomic completed_pipelines; + //! The total amount of pipelines in the query + idx_t total_pipelines; + //! Whether or not execution is cancelled + bool cancelled; + + //! The last pending execution result (if any) + PendingExecutionResult execution_result; + //! The current task in process (if any) + shared_ptr task; + + //! Task that have been descheduled + unordered_map> to_be_rescheduled_tasks; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/expression_executor.hpp b/src/duckdb/src/include/duckdb/execution/expression_executor.hpp new file mode 100644 index 00000000..a3430435 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/expression_executor.hpp @@ -0,0 +1,163 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/expression_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/planner/bound_tokens.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { +class Allocator; +class ExecutionContext; + +//! ExpressionExecutor is responsible for executing a set of expressions and storing the result in a data chunk +class ExpressionExecutor { + friend class Index; + friend class CreateIndexLocalSinkState; + +public: + DUCKDB_API explicit ExpressionExecutor(ClientContext &context); + DUCKDB_API ExpressionExecutor(ClientContext &context, const Expression *expression); + DUCKDB_API ExpressionExecutor(ClientContext &context, const Expression &expression); + DUCKDB_API ExpressionExecutor(ClientContext &context, const vector> &expressions); + ExpressionExecutor(ExpressionExecutor &&) = delete; + + //! The expressions of the executor + vector expressions; + //! The data chunk of the current physical operator, used to resolve + //! column references and determines the output cardinality + DataChunk *chunk = nullptr; + +public: + bool HasContext(); + ClientContext &GetContext(); + Allocator &GetAllocator(); + + //! Add an expression to the set of to-be-executed expressions of the executor + DUCKDB_API void AddExpression(const Expression &expr); + + //! Execute the set of expressions with the given input chunk and store the result in the output chunk + DUCKDB_API void Execute(DataChunk *input, DataChunk &result); + inline void Execute(DataChunk &input, DataChunk &result) { + Execute(&input, result); + } + inline void Execute(DataChunk &result) { + Execute(nullptr, result); + } + + //! Execute the ExpressionExecutor and put the result in the result vector; this should only be used for expression + //! executors with a single expression + DUCKDB_API void ExecuteExpression(DataChunk &input, Vector &result); + //! Execute the ExpressionExecutor and put the result in the result vector; this should only be used for expression + //! executors with a single expression + DUCKDB_API void ExecuteExpression(Vector &result); + //! Execute the ExpressionExecutor and generate a selection vector from all true values in the result; this should + //! only be used with a single boolean expression + DUCKDB_API idx_t SelectExpression(DataChunk &input, SelectionVector &sel); + + //! Execute the expression with index `expr_idx` and store the result in the result vector + DUCKDB_API void ExecuteExpression(idx_t expr_idx, Vector &result); + //! Evaluate a scalar expression and fold it into a single value + DUCKDB_API static Value EvaluateScalar(ClientContext &context, const Expression &expr, + bool allow_unfoldable = false); + //! Try to evaluate a scalar expression and fold it into a single value, returns false if an exception is thrown + DUCKDB_API static bool TryEvaluateScalar(ClientContext &context, const Expression &expr, Value &result); + + //! Initialize the state of a given expression + static unique_ptr InitializeState(const Expression &expr, ExpressionExecutorState &state); + + inline void SetChunk(DataChunk *chunk) { + this->chunk = chunk; + } + inline void SetChunk(DataChunk &chunk) { + SetChunk(&chunk); + } + + DUCKDB_API vector> &GetStates(); + +protected: + void Initialize(const Expression &expr, ExpressionExecutorState &state); + + static unique_ptr InitializeState(const BoundReferenceExpression &expr, + ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundBetweenExpression &expr, + ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundCaseExpression &expr, ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundCastExpression &expr, ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundComparisonExpression &expr, + ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundConjunctionExpression &expr, + ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundConstantExpression &expr, + ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundFunctionExpression &expr, + ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundOperatorExpression &expr, + ExpressionExecutorState &state); + static unique_ptr InitializeState(const BoundParameterExpression &expr, + ExpressionExecutorState &state); + + void Execute(const Expression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + + void Execute(const BoundBetweenExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + void Execute(const BoundCaseExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + void Execute(const BoundCastExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + + void Execute(const BoundComparisonExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + void Execute(const BoundConjunctionExpression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, Vector &result); + void Execute(const BoundConstantExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + void Execute(const BoundFunctionExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + void Execute(const BoundOperatorExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + void Execute(const BoundParameterExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + void Execute(const BoundReferenceExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + Vector &result); + + //! Execute the (boolean-returning) expression and generate a selection vector with all entries that are "true" in + //! the result + idx_t Select(const Expression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + idx_t DefaultSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + + idx_t Select(const BoundBetweenExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + idx_t Select(const BoundComparisonExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel); + idx_t Select(const BoundConjunctionExpression &expr, ExpressionState *state, const SelectionVector *sel, + idx_t count, SelectionVector *true_sel, SelectionVector *false_sel); + + //! Verify that the output of a step in the ExpressionExecutor is correct + void Verify(const Expression &expr, Vector &result, idx_t count); + + void FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count); + +private: + //! Client context + optional_ptr context; + //! The states of the expression executor; this holds any intermediates and temporary states of expressions + vector> states; + +private: + // it is possible to create an expression executor without a ClientContext - but it should be avoided + DUCKDB_API ExpressionExecutor(); + DUCKDB_API ExpressionExecutor(const vector> &exprs); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp b/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp new file mode 100644 index 00000000..c40908ed --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/expression_executor_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/cycle_counter.hpp" +#include "duckdb/function/function.hpp" + +namespace duckdb { +class Expression; +class ExpressionExecutor; +struct ExpressionExecutorState; +struct FunctionLocalState; + +struct ExpressionState { + ExpressionState(const Expression &expr, ExpressionExecutorState &root); + virtual ~ExpressionState() { + } + + const Expression &expr; + ExpressionExecutorState &root; + vector> child_states; + vector types; + DataChunk intermediate_chunk; + CycleCounter profiler; + +public: + void AddChild(Expression *expr); + void Finalize(); + Allocator &GetAllocator(); + bool HasContext(); + DUCKDB_API ClientContext &GetContext(); + + void Verify(ExpressionExecutorState &root); + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct ExecuteFunctionState : public ExpressionState { + ExecuteFunctionState(const Expression &expr, ExpressionExecutorState &root); + ~ExecuteFunctionState(); + + unique_ptr local_state; + +public: + static optional_ptr GetFunctionState(ExpressionState &state) { + return state.Cast().local_state.get(); + } +}; + +struct ExpressionExecutorState { + ExpressionExecutorState(); + + unique_ptr root_state; + ExpressionExecutor *executor = nullptr; + CycleCounter profiler; + + void Verify(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp new file mode 100644 index 00000000..c3fc090d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp @@ -0,0 +1,151 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/art.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/index.hpp" +#include "duckdb/execution/index/art/node.hpp" +#include "duckdb/common/array.hpp" + +namespace duckdb { + +// classes +enum class VerifyExistenceType : uint8_t { + APPEND = 0, // appends to a table + APPEND_FK = 1, // appends to a table that has a foreign key + DELETE_FK = 2 // delete from a table that has a foreign key +}; +class ConflictManager; +class ARTKey; +class FixedSizeAllocator; + +// structs +struct ARTIndexScanState; +struct ARTFlags { + vector vacuum_flags; + vector merge_buffer_counts; +}; + +class ART : public Index { +public: + //! FixedSizeAllocator count of the ART + static constexpr uint8_t ALLOCATOR_COUNT = 6; + +public: + //! Constructs an ART + ART(const vector &column_ids, TableIOManager &table_io_manager, + const vector> &unbound_expressions, const IndexConstraintType constraint_type, + AttachedDatabase &db, + const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr = nullptr, + const BlockPointer &block = BlockPointer()); + + //! Root of the tree + Node tree = Node(); + //! Fixed-size allocators holding the ART nodes + shared_ptr, ALLOCATOR_COUNT>> allocators; + //! True, if the ART owns its data + bool owns_data; + +public: + //! Initialize a single predicate scan on the index with the given expression and column IDs + unique_ptr InitializeScanSinglePredicate(const Transaction &transaction, const Value &value, + const ExpressionType expression_type) override; + //! Initialize a two predicate scan on the index with the given expression and column IDs + unique_ptr InitializeScanTwoPredicates(const Transaction &transaction, const Value &low_value, + const ExpressionType low_expression_type, + const Value &high_value, + const ExpressionType high_expression_type) override; + //! Performs a lookup on the index, fetching up to max_count result IDs. Returns true if all row IDs were fetched, + //! and false otherwise + bool Scan(const Transaction &transaction, const DataTable &table, IndexScanState &state, const idx_t max_count, + vector &result_ids) override; + + //! Called when data is appended to the index. The lock obtained from InitializeLock must be held + PreservedError Append(IndexLock &lock, DataChunk &entries, Vector &row_identifiers) override; + //! Verify that data can be appended to the index without a constraint violation + void VerifyAppend(DataChunk &chunk) override; + //! Verify that data can be appended to the index without a constraint violation using the conflict manager + void VerifyAppend(DataChunk &chunk, ConflictManager &conflict_manager) override; + //! Deletes all data from the index. The lock obtained from InitializeLock must be held + void CommitDrop(IndexLock &index_lock) override; + //! Delete a chunk of entries from the index. The lock obtained from InitializeLock must be held + void Delete(IndexLock &lock, DataChunk &entries, Vector &row_identifiers) override; + //! Insert a chunk of entries into the index + PreservedError Insert(IndexLock &lock, DataChunk &data, Vector &row_ids) override; + + //! Construct an ART from a vector of sorted keys + bool ConstructFromSorted(idx_t count, vector &keys, Vector &row_identifiers); + + //! Search equal values and fetches the row IDs + bool SearchEqual(ARTKey &key, idx_t max_count, vector &result_ids); + //! Search equal values used for joins that do not need to fetch data + void SearchEqualJoinNoFetch(ARTKey &key, idx_t &result_size); + + //! Serializes the index and returns the pair of block_id offset positions + BlockPointer Serialize(MetadataWriter &writer) override; + + //! Merge another index into this index. The lock obtained from InitializeLock must be held, and the other + //! index must also be locked during the merge + bool MergeIndexes(IndexLock &state, Index &other_index) override; + + //! Traverses an ART and vacuums the qualifying nodes. The lock obtained from InitializeLock must be held + void Vacuum(IndexLock &state) override; + + //! Generate ART keys for an input chunk + static void GenerateKeys(ArenaAllocator &allocator, DataChunk &input, vector &keys); + + //! Generate a string containing all the expressions and their respective values that violate a constraint + string GenerateErrorKeyName(DataChunk &input, idx_t row); + //! Generate the matching error message for a constraint violation + string GenerateConstraintErrorMessage(VerifyExistenceType verify_type, const string &key_name); + //! Performs constraint checking for a chunk of input data + void CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_manager) override; + + //! Returns the string representation of the ART, or only traverses and verifies the index + string VerifyAndToString(IndexLock &state, const bool only_verify) override; + + //! Find the node with a matching key, or return nullptr if not found + optional_ptr Lookup(const Node &node, const ARTKey &key, idx_t depth); + //! Insert a key into the tree + bool Insert(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id); + +private: + //! Insert a row ID into a leaf + bool InsertToLeaf(Node &leaf, const row_t &row_id); + //! Erase a key from the tree (if a leaf has more than one value) or erase the leaf itself + void Erase(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id); + + //! Returns all row IDs belonging to a key greater (or equal) than the search key + bool SearchGreater(ARTIndexScanState &state, ARTKey &key, bool equal, idx_t max_count, vector &result_ids); + //! Returns all row IDs belonging to a key less (or equal) than the upper_bound + bool SearchLess(ARTIndexScanState &state, ARTKey &upper_bound, bool equal, idx_t max_count, + vector &result_ids); + //! Returns all row IDs belonging to a key within the range of lower_bound and upper_bound + bool SearchCloseRange(ARTIndexScanState &state, ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, + bool right_equal, idx_t max_count, vector &result_ids); + + //! Initializes a merge operation by returning a set containing the buffer count of each fixed-size allocator + void InitializeMerge(ARTFlags &flags); + + //! Initializes a vacuum operation by calling the initialize operation of the respective + //! node allocator, and returns a vector containing either true, if the allocator at + //! the respective position qualifies, or false, if not + void InitializeVacuum(ARTFlags &flags); + //! Finalizes a vacuum operation by calling the finalize operation of all qualifying + //! fixed size allocators + void FinalizeVacuum(const ARTFlags &flags); + + //! Internal function to return the string representation of the ART, + //! or only traverses and verifies the index + string VerifyAndToStringInternal(const bool only_verify); + + //! Deserialize the allocators of the ART + void Deserialize(const BlockPointer &pointer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art_key.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art_key.hpp new file mode 100644 index 00000000..0bb9c0f0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/art_key.hpp @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/art_key.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/radix.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/storage/arena_allocator.hpp" + +namespace duckdb { + +class ARTKey { +public: + ARTKey(); + ARTKey(const data_ptr_t &data, const uint32_t &len); + ARTKey(ArenaAllocator &allocator, const uint32_t &len); + + uint32_t len; + data_ptr_t data; + +public: + template + static inline ARTKey CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, T element) { + auto data = ARTKey::CreateData(allocator, element); + return ARTKey(data, sizeof(element)); + } + + template + static inline ARTKey CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, const Value &element) { + return CreateARTKey(allocator, type, element.GetValueUnsafe()); + } + + template + static inline void CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, T element) { + key.data = ARTKey::CreateData(allocator, element); + key.len = sizeof(element); + } + + template + static inline void CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, + const Value element) { + key.data = ARTKey::CreateData(allocator, element.GetValueUnsafe()); + key.len = sizeof(element); + } + +public: + data_t &operator[](size_t i) { + return data[i]; + } + const data_t &operator[](size_t i) const { + return data[i]; + } + bool operator>(const ARTKey &k) const; + bool operator>=(const ARTKey &k) const; + bool operator==(const ARTKey &k) const; + + inline bool ByteMatches(const ARTKey &other, const uint32_t &depth) const { + return data[depth] == other[depth]; + } + inline bool Empty() const { + return len == 0; + } + void ConcatenateARTKey(ArenaAllocator &allocator, ARTKey &concat_key); + +private: + template + static inline data_ptr_t CreateData(ArenaAllocator &allocator, T value) { + auto data = allocator.Allocate(sizeof(value)); + Radix::EncodeData(data, value); + return data; + } +}; + +template <> +ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, string_t value); +template <> +ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, const char *value); +template <> +void ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, string_t value); +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp new file mode 100644 index 00000000..437a818b --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp @@ -0,0 +1,82 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/iterator.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/common/stack.hpp" +#include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/leaf.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +//! Keeps track of the byte leading to the currently active child of the node +struct IteratorEntry { + IteratorEntry(Node node, uint8_t byte) : node(node), byte(byte) { + } + + Node node; + uint8_t byte = 0; +}; + +//! Keeps track of the current key in the iterator leading down to the top node in the stack +class IteratorKey { +public: + //! Pushes a byte into the current key + inline void Push(const uint8_t key_byte) { + key_bytes.push_back(key_byte); + } + //! Pops n bytes from the current key + inline void Pop(const idx_t n) { + key_bytes.resize(key_bytes.size() - n); + } + + //! Subscript operator + inline uint8_t &operator[](idx_t idx) { + D_ASSERT(idx < key_bytes.size()); + return key_bytes[idx]; + } + //! Greater than operator + bool operator>(const ARTKey &key) const; + //! Greater than or equal to operator + bool operator>=(const ARTKey &key) const; + //! Equal to operator + bool operator==(const ARTKey &key) const; + +private: + vector key_bytes; +}; + +class Iterator { +public: + //! Holds the current key leading down to the top node on the stack + IteratorKey current_key; + //! Pointer to the ART + optional_ptr art = nullptr; + + //! Scans the tree, starting at the current top node on the stack, and ending at upper_bound. + //! If upper_bound is the empty ARTKey, than there is no upper bound + bool Scan(const ARTKey &upper_bound, const idx_t max_count, vector &result_ids, const bool equal); + //! Finds the minimum (leaf) of the current subtree + void FindMinimum(const Node &node); + //! Finds the lower bound of the ART and adds the nodes to the stack. Returns false, if the lower + //! bound exceeds the maximum value of the ART + bool LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth); + +private: + //! Stack of nodes from the root to the currently active node + stack nodes; + //! Last visited leaf node + Node last_leaf = Node(); + + //! Goes to the next leaf in the ART and sets it as last_leaf, + //! returns false if there is no next leaf + bool Next(); + //! Pop the top node from the stack of iterator entries and adjust the current key + void PopNode(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp new file mode 100644 index 00000000..a981b61b --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/leaf.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/index/fixed_size_allocator.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +// classes +class MetadataWriter; +class MetadataReader; + +// structs +struct BlockPointer; + +//! The LEAF is a special node type that contains a count, up to LEAF_SIZE row IDs, +//! and a Node pointer. If this pointer is set, then it must point to another LEAF, +//! creating a chain of leaf nodes storing row IDs. +//! This class also contains functionality for nodes of type LEAF_INLINED, in which case we store the +//! row ID directly in the node pointer. +class Leaf { +public: + //! Delete copy constructors, as any Leaf can never own its memory + Leaf(const Leaf &) = delete; + Leaf &operator=(const Leaf &) = delete; + + //! The number of row IDs in this leaf + uint8_t count; + //! Up to LEAF_SIZE row IDs + row_t row_ids[Node::LEAF_SIZE]; + //! A pointer to the next LEAF node + Node ptr; + +public: + //! Inline a row ID into a node pointer + static void New(Node &node, const row_t row_id); + //! Get a new chain of leaf nodes, might cause new buffer allocations, + //! with the node parameter holding the tail of the chain + static void New(ART &art, reference &node, const row_t *row_ids, idx_t count); + //! Get a new leaf node without any data + static Leaf &New(ART &art, Node &node); + //! Free the leaf (chain) + static void Free(ART &art, Node &node); + + //! Initializes a merge by incrementing the buffer IDs of the leaf (chain) + static void InitializeMerge(ART &art, Node &node, const ARTFlags &flags); + //! Merge leaf (chains) and free all copied leaf nodes + static void Merge(ART &art, Node &l_node, Node &r_node); + + //! Insert a row ID into a leaf + static void Insert(ART &art, Node &node, const row_t row_id); + //! Remove a row ID from a leaf. Returns true, if the leaf is empty after the removal + static bool Remove(ART &art, reference &node, const row_t row_id); + + //! Get the total count of row IDs in the chain of leaves + static idx_t TotalCount(ART &art, const Node &node); + //! Fill the result_ids vector with the row IDs of this leaf chain, if the total count does not exceed max_count + static bool GetRowIds(ART &art, const Node &node, vector &result_ids, const idx_t max_count); + //! Returns whether the leaf contains the row ID + static bool ContainsRowId(ART &art, const Node &node, const row_t row_id); + + //! Returns the string representation of the leaf (chain), or only traverses and verifies the leaf (chain) + static string VerifyAndToString(ART &art, const Node &node, const bool only_verify); + + //! Vacuum the leaf (chain) + static void Vacuum(ART &art, Node &node); + +private: + //! Moves the inlined row ID onto a leaf + static void MoveInlinedToLeaf(ART &art, Node &node); + //! Appends the row ID to this leaf, or creates a subsequent leaf, if this node is full + Leaf &Append(ART &art, const row_t row_id); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp new file mode 100644 index 00000000..2e170d41 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp @@ -0,0 +1,133 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/execution/index/index_pointer.hpp" +#include "duckdb/execution/index/fixed_size_allocator.hpp" + +namespace duckdb { + +// classes +enum class NType : uint8_t { + PREFIX = 1, + LEAF = 2, + NODE_4 = 3, + NODE_16 = 4, + NODE_48 = 5, + NODE_256 = 6, + LEAF_INLINED = 7, +}; + +class ART; +class Prefix; +class MetadataReader; +class MetadataWriter; + +// structs +struct BlockPointer; +struct ARTFlags; +struct MetaBlockPointer; + +//! The Node is the pointer class of the ART index. +//! It inherits from the IndexPointer, and adds ART-specific functionality +class Node : public IndexPointer { +public: + //! Node thresholds + static constexpr uint8_t NODE_48_SHRINK_THRESHOLD = 12; + static constexpr uint8_t NODE_256_SHRINK_THRESHOLD = 36; + //! Node sizes + static constexpr uint8_t NODE_4_CAPACITY = 4; + static constexpr uint8_t NODE_16_CAPACITY = 16; + static constexpr uint8_t NODE_48_CAPACITY = 48; + static constexpr uint16_t NODE_256_CAPACITY = 256; + //! Other constants + static constexpr uint8_t EMPTY_MARKER = 48; + static constexpr uint8_t LEAF_SIZE = 4; + static constexpr uint8_t PREFIX_SIZE = 15; + static constexpr idx_t AND_ROW_ID = 0x00FFFFFFFFFFFFFF; + +public: + //! Get a new pointer to a node, might cause a new buffer allocation, and initialize it + static void New(ART &art, Node &node, const NType type); + //! Free the node (and its subtree) + static void Free(ART &art, Node &node); + + //! Get references to the allocator + static FixedSizeAllocator &GetAllocator(const ART &art, const NType type); + //! Get a (immutable) reference to the node. If dirty is false, then T should be a const class + template + static inline const NODE &Ref(const ART &art, const Node ptr, const NType type) { + return *(GetAllocator(art, type).Get(ptr, false)); + } + //! Get a (const) reference to the node. If dirty is false, then T should be a const class + template + static inline NODE &RefMutable(const ART &art, const Node ptr, const NType type) { + return *(GetAllocator(art, type).Get(ptr)); + } + + //! Replace the child node at byte + void ReplaceChild(const ART &art, const uint8_t byte, const Node child) const; + //! Insert the child node at byte + static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); + //! Delete the child node at byte + static void DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte); + + //! Get the child (immutable) for the respective byte in the node + optional_ptr GetChild(ART &art, const uint8_t byte) const; + //! Get the child for the respective byte in the node + optional_ptr GetChildMutable(ART &art, const uint8_t byte) const; + //! Get the first child (immutable) that is greater or equal to the specific byte + optional_ptr GetNextChild(ART &art, uint8_t &byte) const; + //! Get the first child that is greater or equal to the specific byte + optional_ptr GetNextChildMutable(ART &art, uint8_t &byte) const; + + //! Returns the string representation of the node, or only traverses and verifies the node and its subtree + string VerifyAndToString(ART &art, const bool only_verify) const; + //! Returns the capacity of the node + idx_t GetCapacity() const; + //! Returns the matching node type for a given count + static NType GetARTNodeTypeByCount(const idx_t count); + + //! Initializes a merge by incrementing the buffer IDs of a node and its subtree + void InitializeMerge(ART &art, const ARTFlags &flags); + //! Merge another node into this node + bool Merge(ART &art, Node &other); + //! Merge two nodes by first resolving their prefixes + bool ResolvePrefixes(ART &art, Node &other); + //! Merge two nodes that have no prefix or the same prefix + bool MergeInternal(ART &art, Node &other); + + //! Vacuum all nodes that exceed their respective vacuum thresholds + void Vacuum(ART &art, const ARTFlags &flags); + + //! Get the row ID (8th to 63rd bit) + inline row_t GetRowId() const { + return Get() & AND_ROW_ID; + } + //! Set the row ID (8th to 63rd bit) + inline void SetRowId(const row_t row_id) { + Set((Get() & AND_METADATA) | row_id); + } + + //! Returns the type of the node, which is held in the metadata + inline NType GetType() const { + return NType(GetMetadata()); + } + + //! Assign operator + inline void operator=(const IndexPointer &ptr) { + Set(ptr.Get()); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node16.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node16.hpp new file mode 100644 index 00000000..36d85e83 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/node16.hpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/node16.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/index/fixed_size_allocator.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +//! Node16 holds up to 16 Node children sorted by their key byte +class Node16 { +public: + //! Delete copy constructors, as any Node16 can never own its memory + Node16(const Node16 &) = delete; + Node16 &operator=(const Node16 &) = delete; + + //! Number of non-null children + uint8_t count; + //! Array containing all partial key bytes + uint8_t key[Node::NODE_16_CAPACITY]; + //! Node pointers to the child nodes + Node children[Node::NODE_16_CAPACITY]; + +public: + //! Get a new Node16, might cause a new buffer allocation, and initialize it + static Node16 &New(ART &art, Node &node); + //! Free the node (and its subtree) + static void Free(ART &art, Node &node); + + //! Initializes all the fields of the node while growing a Node4 to a Node16 + static Node16 &GrowNode4(ART &art, Node &node16, Node &node4); + //! Initializes all fields of the node while shrinking a Node48 to a Node16 + static Node16 &ShrinkNode48(ART &art, Node &node16, Node &node48); + + //! Initializes a merge by incrementing the buffer IDs of the node + void InitializeMerge(ART &art, const ARTFlags &flags); + + //! Insert a child node at byte + static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); + //! Delete the child node at byte + static void DeleteChild(ART &art, Node &node, const uint8_t byte); + + //! Replace the child node at byte + void ReplaceChild(const uint8_t byte, const Node child); + + //! Get the (immutable) child for the respective byte in the node + optional_ptr GetChild(const uint8_t byte) const; + //! Get the child for the respective byte in the node + optional_ptr GetChildMutable(const uint8_t byte); + //! Get the first (immutable) child that is greater or equal to the specific byte + optional_ptr GetNextChild(uint8_t &byte) const; + //! Get the first child that is greater or equal to the specific byte + optional_ptr GetNextChildMutable(uint8_t &byte); + + //! Vacuum the children of the node + void Vacuum(ART &art, const ARTFlags &flags); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node256.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node256.hpp new file mode 100644 index 00000000..bd6e9c3d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/node256.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/node256.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/index/fixed_size_allocator.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +//! Node256 holds up to 256 Node children which can be directly indexed by the key byte +class Node256 { +public: + //! Delete copy constructors, as any Node256 can never own its memory + Node256(const Node256 &) = delete; + Node256 &operator=(const Node256 &) = delete; + + //! Number of non-null children + uint16_t count; + //! Node pointers to the child nodes + Node children[Node::NODE_256_CAPACITY]; + +public: + //! Get a new Node256, might cause a new buffer allocation, and initialize it + static Node256 &New(ART &art, Node &node); + //! Free the node (and its subtree) + static void Free(ART &art, Node &node); + + //! Initializes all the fields of the node while growing a Node48 to a Node256 + static Node256 &GrowNode48(ART &art, Node &node256, Node &node48); + + //! Initializes a merge by incrementing the buffer IDs of the node + void InitializeMerge(ART &art, const ARTFlags &flags); + + //! Insert a child node at byte + static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); + //! Delete the child node at byte + static void DeleteChild(ART &art, Node &node, const uint8_t byte); + + //! Replace the child node at byte + inline void ReplaceChild(const uint8_t byte, const Node child) { + children[byte] = child; + } + + //! Get the (immutable) child for the respective byte in the node + optional_ptr GetChild(const uint8_t byte) const; + //! Get the child for the respective byte in the node + optional_ptr GetChildMutable(const uint8_t byte); + //! Get the first (immutable) child that is greater or equal to the specific byte + optional_ptr GetNextChild(uint8_t &byte) const; + //! Get the first child that is greater or equal to the specific byte + optional_ptr GetNextChildMutable(uint8_t &byte); + + //! Vacuum the children of the node + void Vacuum(ART &art, const ARTFlags &flags); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node4.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node4.hpp new file mode 100644 index 00000000..86952b85 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/node4.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/node4.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/index/fixed_size_allocator.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +//! Node4 holds up to four Node children sorted by their key byte +class Node4 { +public: + //! Delete copy constructors, as any Node4 can never own its memory + Node4(const Node4 &) = delete; + Node4 &operator=(const Node4 &) = delete; + + //! Number of non-null children + uint8_t count; + //! Array containing all partial key bytes + uint8_t key[Node::NODE_4_CAPACITY]; + //! Node pointers to the child nodes + Node children[Node::NODE_4_CAPACITY]; + +public: + //! Get a new Node4, might cause a new buffer allocation, and initialize it + static Node4 &New(ART &art, Node &node); + //! Free the node (and its subtree) + static void Free(ART &art, Node &node); + + //! Initializes all fields of the node while shrinking a Node16 to a Node4 + static Node4 &ShrinkNode16(ART &art, Node &node4, Node &node16); + + //! Initializes a merge by incrementing the buffer IDs of the child nodes + void InitializeMerge(ART &art, const ARTFlags &flags); + + //! Insert a child node at byte + static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); + //! Delete the child node at byte + static void DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte); + + //! Replace the child node at byte + void ReplaceChild(const uint8_t byte, const Node child); + + //! Get the (immutable) child for the respective byte in the node + optional_ptr GetChild(const uint8_t byte) const; + //! Get the child for the respective byte in the node + optional_ptr GetChildMutable(const uint8_t byte); + //! Get the first (immutable) child that is greater or equal to the specific byte + optional_ptr GetNextChild(uint8_t &byte) const; + //! Get the first child that is greater or equal to the specific byte + optional_ptr GetNextChildMutable(uint8_t &byte); + + //! Vacuum the children of the node + void Vacuum(ART &art, const ARTFlags &flags); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node48.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node48.hpp new file mode 100644 index 00000000..f57eea2f --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/node48.hpp @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/node48.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/index/fixed_size_allocator.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +//! Node48 holds up to 48 Node children. It contains a child_index array which can be directly indexed by the key +//! byte, and which contains the position of the child node in the children array +class Node48 { +public: + //! Delete copy constructors, as any Node48 can never own its memory + Node48(const Node48 &) = delete; + Node48 &operator=(const Node48 &) = delete; + + //! Number of non-null children + uint8_t count; + //! Array containing all possible partial key bytes, those not set have an EMPTY_MARKER + uint8_t child_index[Node::NODE_256_CAPACITY]; + //! Node pointers to the child nodes + Node children[Node::NODE_48_CAPACITY]; + +public: + //! Get a new Node48, might cause a new buffer allocation, and initialize it + static Node48 &New(ART &art, Node &node); + //! Free the node (and its subtree) + static void Free(ART &art, Node &node); + + //! Initializes all the fields of the node while growing a Node16 to a Node48 + static Node48 &GrowNode16(ART &art, Node &node48, Node &node16); + //! Initializes all fields of the node while shrinking a Node256 to a Node48 + static Node48 &ShrinkNode256(ART &art, Node &node48, Node &node256); + + //! Initializes a merge by incrementing the buffer IDs of the node + void InitializeMerge(ART &art, const ARTFlags &flags); + + //! Insert a child node at byte + static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); + //! Delete the child node at byte + static void DeleteChild(ART &art, Node &node, const uint8_t byte); + + //! Replace the child node at byte + inline void ReplaceChild(const uint8_t byte, const Node child) { + D_ASSERT(child_index[byte] != Node::EMPTY_MARKER); + children[child_index[byte]] = child; + } + + //! Get the (immutable) child for the respective byte in the node + optional_ptr GetChild(const uint8_t byte) const; + //! Get the child for the respective byte in the node + optional_ptr GetChildMutable(const uint8_t byte); + //! Get the first (immutable) child that is greater or equal to the specific byte + optional_ptr GetNextChild(uint8_t &byte) const; + //! Get the first child that is greater or equal to the specific byte + optional_ptr GetNextChildMutable(uint8_t &byte); + + //! Vacuum the children of the node + void Vacuum(ART &art, const ARTFlags &flags); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp new file mode 100644 index 00000000..7c3068fb --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/art/prefix.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/execution/index/fixed_size_allocator.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/node.hpp" + +namespace duckdb { + +// classes +class ARTKey; + +//! The Prefix is a special node type that contains up to PREFIX_SIZE bytes, and one byte for the count, +//! and a Node pointer. This pointer either points to a prefix node or another Node. +class Prefix { +public: + //! Delete copy constructors, as any Prefix can never own its memory + Prefix(const Prefix &) = delete; + Prefix &operator=(const Prefix &) = delete; + + //! Up to PREFIX_SIZE bytes of prefix data and the count + uint8_t data[Node::PREFIX_SIZE + 1]; + //! A pointer to the next Node + Node ptr; + +public: + //! Get a new empty prefix node, might cause a new buffer allocation + static Prefix &New(ART &art, Node &node); + //! Create a new prefix node containing a single byte and a pointer to a next node + static Prefix &New(ART &art, Node &node, uint8_t byte, const Node &next = Node()); + //! Get a new chain of prefix nodes, might cause new buffer allocations, + //! with the node parameter holding the tail of the chain + static void New(ART &art, reference &node, const ARTKey &key, const uint32_t depth, uint32_t count); + //! Free the node (and its subtree) + static void Free(ART &art, Node &node); + + //! Initializes a merge by incrementing the buffer ID of the prefix and its child node(s) + static void InitializeMerge(ART &art, Node &node, const ARTFlags &flags); + + //! Appends a byte and a child_prefix to prefix. If there is no prefix, than it pushes the + //! byte on top of child_prefix. If there is no child_prefix, then it creates a new + //! prefix node containing that byte + static void Concatenate(ART &art, Node &prefix_node, const uint8_t byte, Node &child_prefix_node); + //! Traverse a prefix and a key until (1) encountering a non-prefix node, or (2) encountering + //! a mismatching byte, in which case depth indexes the mismatching byte in the key + static idx_t Traverse(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth); + //! Traverse a prefix and a key until (1) encountering a non-prefix node, or (2) encountering + //! a mismatching byte, in which case depth indexes the mismatching byte in the key + static idx_t TraverseMutable(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth); + //! Traverse two prefixes to find (1) that they match (so far), or (2) that they have a mismatching position, + //! or (3) that one prefix contains the other prefix. This function aids in merging Nodes, and, therefore, + //! the nodes are not const + static bool Traverse(ART &art, reference &l_node, reference &r_node, idx_t &mismatch_position); + //! Returns the byte at position + static inline uint8_t GetByte(const ART &art, const Node &prefix_node, const idx_t position) { + auto &prefix = Node::Ref(art, prefix_node, NType::PREFIX); + D_ASSERT(position < Node::PREFIX_SIZE); + D_ASSERT(position < prefix.data[Node::PREFIX_SIZE]); + return prefix.data[position]; + } + //! Removes the first n bytes from the prefix and shifts all subsequent bytes in the + //! prefix node(s) by n. Frees empty prefix nodes + static void Reduce(ART &art, Node &prefix_node, const idx_t n); + //! Splits the prefix at position. prefix_node then references the ptr (if any bytes left before + //! the split), or stays unchanged (no bytes left before the split). child_node references + //! the node after the split, which is either a new prefix node, or ptr + static void Split(ART &art, reference &prefix_node, Node &child_node, idx_t position); + + //! Returns the string representation of the node, or only traverses and verifies the node and its subtree + static string VerifyAndToString(ART &art, const Node &node, const bool only_verify); + + //! Vacuum the child of the node + static void Vacuum(ART &art, Node &node, const ARTFlags &flags); + +private: + //! Appends the byte to this prefix node, or creates a subsequent prefix node, + //! if this node is full + Prefix &Append(ART &art, const uint8_t byte); + //! Appends the other_prefix and all its subsequent prefix nodes to this prefix node. + //! Also frees all copied/appended nodes + void Append(ART &art, Node other_prefix); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp b/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp new file mode 100644 index 00000000..6ec5f2d8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp @@ -0,0 +1,120 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/fixed_size_allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/metadata/metadata_manager.hpp" +#include "duckdb/storage/metadata/metadata_writer.hpp" +#include "duckdb/execution/index/fixed_size_buffer.hpp" +#include "duckdb/execution/index/index_pointer.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/map.hpp" + +namespace duckdb { + +//! The FixedSizeAllocator provides pointers to fixed-size memory segments of pre-allocated memory buffers. +//! The pointers are IndexPointers, and the leftmost byte (metadata) must always be zero. +//! It is also possible to directly request a C++ pointer to the underlying segment of an index pointer. +class FixedSizeAllocator { +public: + //! We can vacuum 10% or more of the total in-memory footprint + static constexpr uint8_t VACUUM_THRESHOLD = 10; + +public: + FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager); + + //! Block manager of the database instance + BlockManager &block_manager; + //! Buffer manager of the database instance + BufferManager &buffer_manager; + //! Metadata manager for (de)serialization + MetadataManager &metadata_manager; + +public: + //! Get a new IndexPointer to a segment, might cause a new buffer allocation + IndexPointer New(); + //! Free the segment of the IndexPointer + void Free(const IndexPointer ptr); + //! Returns a pointer of type T to a segment. If dirty is false, then T should be a const class + template + inline T *Get(const IndexPointer ptr, const bool dirty = true) { + return (T *)Get(ptr, dirty); + } + + //! Resets the allocator, e.g., during 'DELETE FROM table' + void Reset(); + + //! Returns the in-memory usage in bytes + inline idx_t GetMemoryUsage() const; + + //! Returns the upper bound of the available buffer IDs, i.e., upper_bound > max_buffer_id + idx_t GetUpperBoundBufferId() const; + //! Merge another FixedSizeAllocator into this allocator. Both must have the same segment size + void Merge(FixedSizeAllocator &other); + + //! Initialize a vacuum operation, and return true, if the allocator needs a vacuum + bool InitializeVacuum(); + //! Finalize a vacuum operation by freeing all vacuumed buffers + void FinalizeVacuum(); + //! Returns true, if an IndexPointer qualifies for a vacuum operation, and false otherwise + inline bool NeedsVacuum(const IndexPointer ptr) const { + if (vacuum_buffers.find(ptr.GetBufferId()) != vacuum_buffers.end()) { + return true; + } + return false; + } + //! Vacuums an IndexPointer + IndexPointer VacuumPointer(const IndexPointer ptr); + + //! Serializes all in-memory buffers and the metadata + BlockPointer Serialize(PartialBlockManager &partial_block_manager, MetadataWriter &writer); + //! Deserializes all metadata + void Deserialize(const BlockPointer &block_pointer); + +private: + //! Allocation size of one segment in a buffer + //! We only need this value to calculate bitmask_count, bitmask_offset, and + //! available_segments_per_buffer + idx_t segment_size; + + //! Number of validity_t values in the bitmask + idx_t bitmask_count; + //! First starting byte of the payload (segments) + idx_t bitmask_offset; + //! Number of possible segment allocations per buffer + idx_t available_segments_per_buffer; + + //! Total number of allocated segments in all buffers + //! We can recalculate this by iterating over all buffers + idx_t total_segment_count; + + //! Buffers containing the segments + unordered_map buffers; + //! Buffers with free space + unordered_set buffers_with_free_space; + //! Buffers qualifying for a vacuum (helper field to allow for fast NeedsVacuum checks) + unordered_set vacuum_buffers; + +private: + //! Returns the data_ptr_t to a segment, and sets the dirty flag of the buffer containing that segment + inline data_ptr_t Get(const IndexPointer ptr, const bool dirty = true) { + D_ASSERT(ptr.GetOffset() < available_segments_per_buffer); + D_ASSERT(buffers.find(ptr.GetBufferId()) != buffers.end()); + auto &buffer = buffers.find(ptr.GetBufferId())->second; + auto buffer_ptr = buffer.Get(dirty); + return buffer_ptr + ptr.GetOffset() * segment_size + bitmask_offset; + } + //! Returns an available buffer id + idx_t GetAvailableBufferId() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp b/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp new file mode 100644 index 00000000..5156f201 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/fixed_size_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/partial_block_manager.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" +#include "duckdb/storage/block_manager.hpp" + +namespace duckdb { + +class FixedSizeAllocator; +class MetadataWriter; + +struct PartialBlockForIndex : public PartialBlock { +public: + PartialBlockForIndex(PartialBlockState state, BlockManager &block_manager, + const shared_ptr &block_handle); + ~PartialBlockForIndex() override {}; + +public: + void Flush(const idx_t free_space_left) override; + void Clear() override; + void Merge(PartialBlock &other, idx_t offset, idx_t other_size) override; +}; + +//! A fixed-size buffer holds fixed-size segments of data. It lazily deserializes a buffer, if on-disk and not +//! yet in memory, and it only serializes dirty and non-written buffers to disk during +//! serialization. +class FixedSizeBuffer { +public: + //! Constants for fast offset calculations in the bitmask + static constexpr idx_t BASE[] = {0x00000000FFFFFFFF, 0x0000FFFF, 0x00FF, 0x0F, 0x3, 0x1}; + static constexpr uint8_t SHIFT[] = {32, 16, 8, 4, 2, 1}; + +public: + //! Constructor for a new in-memory buffer + explicit FixedSizeBuffer(BlockManager &block_manager); + //! Constructor for deserializing buffer metadata from disk + FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, + const BlockPointer &block_pointer); + + //! Block manager of the database instance + BlockManager &block_manager; + + //! The number of allocated segments + idx_t segment_count; + //! The size of allocated memory in this buffer (necessary for copying while pinning) + idx_t allocation_size; + + //! True: the in-memory buffer is no longer consistent with a (possibly existing) copy on disk + bool dirty; + //! True: can be vacuumed after the vacuum operation + bool vacuum; + + //! Partial block id and offset + BlockPointer block_pointer; + +public: + //! Returns true, if the buffer is in-memory + inline bool InMemory() const { + return buffer_handle.IsValid(); + } + //! Returns true, if the block is on-disk + inline bool OnDisk() const { + return block_pointer.IsValid(); + } + //! Returns a pointer to the buffer in memory, and calls Deserialize, if the buffer is not in memory + inline data_ptr_t Get(const bool dirty_p = true) { + if (!InMemory()) { + Pin(); + } + if (dirty_p) { + dirty = dirty_p; + } + return buffer_handle.Ptr(); + } + //! Destroys the in-memory buffer and the on-disk block + void Destroy(); + //! Serializes a buffer (if dirty or not on disk) + void Serialize(PartialBlockManager &partial_block_manager, const idx_t available_segments, const idx_t segment_size, + const idx_t bitmask_offset); + //! Pin a buffer (if not in-memory) + void Pin(); + //! Returns the first free offset in a bitmask + uint32_t GetOffset(const idx_t bitmask_count); + +private: + //! The buffer handle of the in-memory buffer + BufferHandle buffer_handle; + //! The block handle of the on-disk buffer + shared_ptr block_handle; + +private: + //! Returns the maximum non-free offset in a bitmask + uint32_t GetMaxOffset(const idx_t available_segments_per_buffer); + //! Sets all uninitialized regions of a buffer in the respective partial block allocation + void SetUninitializedRegions(PartialBlockForIndex &p_block_for_index, const idx_t segment_size, const idx_t offset, + const idx_t bitmask_offset); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/index_pointer.hpp b/src/duckdb/src/include/duckdb/execution/index/index_pointer.hpp new file mode 100644 index 00000000..3b7f0c75 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/index/index_pointer.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/index/index_pointer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/typedefs.hpp" + +namespace duckdb { + +class IndexPointer { +public: + //! Bit-shifting + static constexpr idx_t SHIFT_OFFSET = 32; + static constexpr idx_t SHIFT_METADATA = 56; + //! AND operations + static constexpr idx_t AND_OFFSET = 0x0000000000FFFFFF; + static constexpr idx_t AND_BUFFER_ID = 0x00000000FFFFFFFF; + static constexpr idx_t AND_METADATA = 0xFF00000000000000; + +public: + //! Constructs an empty IndexPointer + IndexPointer() : data(0) {}; + //! Constructs an in-memory IndexPointer with a buffer ID and an offset + IndexPointer(const uint32_t buffer_id, const uint32_t offset) : data(0) { + auto shifted_offset = ((idx_t)offset) << SHIFT_OFFSET; + data += shifted_offset; + data += buffer_id; + }; + +public: + //! Get data (all 64 bits) + inline idx_t Get() const { + return data; + } + //! Set data (all 64 bits) + inline void Set(const idx_t data_p) { + data = data_p; + } + + //! Returns false, if the metadata is empty + inline bool HasMetadata() const { + return data & AND_METADATA; + } + //! Get metadata (zero to 7th bit) + inline uint8_t GetMetadata() const { + return data >> SHIFT_METADATA; + } + //! Set metadata (zero to 7th bit) + inline void SetMetadata(const uint8_t metadata) { + data += (idx_t)metadata << SHIFT_METADATA; + } + + //! Get the offset (8th to 23rd bit) + inline idx_t GetOffset() const { + auto offset = data >> SHIFT_OFFSET; + return offset & AND_OFFSET; + } + //! Get the buffer ID (24th to 63rd bit) + inline idx_t GetBufferId() const { + return data & AND_BUFFER_ID; + } + + //! Resets the IndexPointer + inline void Clear() { + data = 0; + } + + //! Adds an idx_t to a buffer ID, the rightmost 32 bits of data contain the buffer ID + inline void IncreaseBufferId(const idx_t summand) { + data += summand; + } + + //! Comparison operator + inline bool operator==(const IndexPointer &ptr) const { + return data == ptr.data; + } + +private: + //! Data holds all the information contained in an IndexPointer + //! [0 - 7: metadata, + //! 8 - 23: offset, 24 - 63: buffer ID] + //! NOTE: we do not use bit fields because when using bit fields Windows compiles + //! the IndexPointer class into 16 bytes instead of the intended 8 bytes, doubling the + //! space requirements + //! https://learn.microsoft.com/en-us/cpp/cpp/cpp-bit-fields?view=msvc-170 + idx_t data; +}; + +static_assert(sizeof(IndexPointer) == sizeof(idx_t), "Invalid size for IndexPointer."); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp new file mode 100644 index 00000000..a602ae18 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp @@ -0,0 +1,334 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/join_hashtable.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/common/types/column/column_data_consumer.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/row/tuple_data_iterator.hpp" +#include "duckdb/common/types/row/tuple_data_layout.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/execution/aggregate_hashtable.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/storage/storage_info.hpp" + +namespace duckdb { + +class BufferManager; +class BufferHandle; +class ColumnDataCollection; +struct ColumnDataAppendState; +struct ClientConfig; + +struct JoinHTScanState { +public: + JoinHTScanState(TupleDataCollection &collection, idx_t chunk_idx_from, idx_t chunk_idx_to, + TupleDataPinProperties properties = TupleDataPinProperties::ALREADY_PINNED) + : iterator(collection, properties, chunk_idx_from, chunk_idx_to, false), offset_in_chunk(0) { + } + + TupleDataChunkIterator iterator; + idx_t offset_in_chunk; + +private: + //! Implicit copying is not allowed + JoinHTScanState(const JoinHTScanState &) = delete; +}; + +//! JoinHashTable is a linear probing HT that is used for computing joins +/*! + The JoinHashTable concatenates incoming chunks inside a linked list of + data ptrs. The storage looks like this internally. + [SERIALIZED ROW][NEXT POINTER] + [SERIALIZED ROW][NEXT POINTER] + There is a separate hash map of pointers that point into this table. + This is what is used to resolve the hashes. + [POINTER] + [POINTER] + [POINTER] + The pointers are either NULL +*/ +class JoinHashTable { +public: + using ValidityBytes = TemplatedValidityMask; + + //! Scan structure that can be used to resume scans, as a single probe can + //! return 1024*N values (where N is the size of the HT). This is + //! returned by the JoinHashTable::Scan function and can be used to resume a + //! probe. + struct ScanStructure { + TupleDataChunkState &key_state; + Vector pointers; + idx_t count; + SelectionVector sel_vector; + // whether or not the given tuple has found a match + unsafe_unique_array found_match; + JoinHashTable &ht; + bool finished; + + explicit ScanStructure(JoinHashTable &ht, TupleDataChunkState &key_state); + //! Get the next batch of data from the scan structure + void Next(DataChunk &keys, DataChunk &left, DataChunk &result); + + private: + //! Next operator for the inner join + void NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &result); + //! Next operator for the semi join + void NextSemiJoin(DataChunk &keys, DataChunk &left, DataChunk &result); + //! Next operator for the anti join + void NextAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result); + //! Next operator for the left outer join + void NextLeftJoin(DataChunk &keys, DataChunk &left, DataChunk &result); + //! Next operator for the mark join + void NextMarkJoin(DataChunk &keys, DataChunk &left, DataChunk &result); + //! Next operator for the single join + void NextSingleJoin(DataChunk &keys, DataChunk &left, DataChunk &result); + + //! Scan the hashtable for matches of the specified keys, setting the found_match[] array to true or false + //! for every tuple + void ScanKeyMatches(DataChunk &keys); + template + void NextSemiOrAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result); + + void ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &child, DataChunk &result); + + idx_t ScanInnerJoin(DataChunk &keys, SelectionVector &result_vector); + + public: + void InitializeSelectionVector(const SelectionVector *¤t_sel); + void AdvancePointers(); + void AdvancePointers(const SelectionVector &sel, idx_t sel_count); + void GatherResult(Vector &result, const SelectionVector &result_vector, const SelectionVector &sel_vector, + const idx_t count, const idx_t col_idx); + void GatherResult(Vector &result, const SelectionVector &sel_vector, const idx_t count, const idx_t col_idx); + idx_t ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel); + }; + +public: + JoinHashTable(BufferManager &buffer_manager, const vector &conditions, + vector build_types, JoinType type); + ~JoinHashTable(); + + //! Add the given data to the HT + void Build(PartitionedTupleDataAppendState &append_state, DataChunk &keys, DataChunk &input); + //! Merge another HT into this one + void Merge(JoinHashTable &other); + //! Combines the partitions in sink_collection into data_collection, as if it were not partitioned + void Unpartition(); + //! Initialize the pointer table for the probe + void InitializePointerTable(); + //! Finalize the build of the HT, constructing the actual hash table and making the HT ready for probing. + //! Finalize must be called before any call to Probe, and after Finalize is called Build should no longer be + //! ever called. + void Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool parallel); + //! Probe the HT with the given input chunk, resulting in the given result + unique_ptr Probe(DataChunk &keys, TupleDataChunkState &key_state, + Vector *precomputed_hashes = nullptr); + //! Scan the HT to construct the full outer join result + void ScanFullOuter(JoinHTScanState &state, Vector &addresses, DataChunk &result); + + //! Fill the pointer with all the addresses from the hashtable for full scan + idx_t FillWithHTOffsets(JoinHTScanState &state, Vector &addresses); + + idx_t Count() const { + return data_collection->Count(); + } + idx_t SizeInBytes() const { + return data_collection->SizeInBytes(); + } + + PartitionedTupleData &GetSinkCollection() { + return *sink_collection; + } + + TupleDataCollection &GetDataCollection() { + return *data_collection; + } + + //! BufferManager + BufferManager &buffer_manager; + //! The join conditions + const vector &conditions; + //! The types of the keys used in equality comparison + vector equality_types; + //! The types of the keys + vector condition_types; + //! The types of all conditions + vector build_types; + //! The comparison predicates + vector predicates; + //! Data column layout + TupleDataLayout layout; + //! Efficiently matches rows + RowMatcher row_matcher; + RowMatcher row_matcher_no_match_sel; + //! The size of an entry as stored in the HashTable + idx_t entry_size; + //! The total tuple size + idx_t tuple_size; + //! Next pointer offset in tuple + idx_t pointer_offset; + //! A constant false column for initialising right outer joins + Vector vfound; + //! The join type of the HT + JoinType join_type; + //! Whether or not the HT has been finalized + bool finalized; + //! Whether or not any of the key elements contain NULL + bool has_null; + //! Bitmask for getting relevant bits from the hashes to determine the position + uint64_t bitmask; + + struct { + mutex mj_lock; + //! The types of the duplicate eliminated columns, only used in correlated MARK JOIN for flattening + //! ANY()/ALL() expressions + vector correlated_types; + //! The aggregate expression nodes used by the HT + vector> correlated_aggregates; + //! The HT that holds the group counts for every correlated column + unique_ptr correlated_counts; + //! Group chunk used for aggregating into correlated_counts + DataChunk group_chunk; + //! Payload chunk used for aggregating into correlated_counts + DataChunk correlated_payload; + //! Result chunk used for aggregating into correlated_counts + DataChunk result_chunk; + } correlated_mark_join_info; + +private: + unique_ptr InitializeScanStructure(DataChunk &keys, TupleDataChunkState &key_state, + const SelectionVector *¤t_sel); + void Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes); + + //! Apply a bitmask to the hashes + void ApplyBitmask(Vector &hashes, idx_t count); + void ApplyBitmask(Vector &hashes, const SelectionVector &sel, idx_t count, Vector &pointers); + +private: + //! Insert the given set of locations into the HT with the given set of hashes + void InsertHashes(Vector &hashes, idx_t count, data_ptr_t key_locations[], bool parallel); + + idx_t PrepareKeys(DataChunk &keys, vector &vector_data, const SelectionVector *¤t_sel, + SelectionVector &sel, bool build_side); + + //! Lock for combining data_collection when merging HTs + mutex data_lock; + //! Partitioned data collection that the data is sunk into when building + unique_ptr sink_collection; + //! The DataCollection holding the main data of the hash table + unique_ptr data_collection; + //! The hash map of the HT, created after finalization + AllocatedData hash_map; + //! Whether or not NULL values are considered equal in each of the comparisons + vector null_values_are_equal; + + //! Copying not allowed + JoinHashTable(const JoinHashTable &) = delete; + +public: + //===--------------------------------------------------------------------===// + // External Join + //===--------------------------------------------------------------------===// + struct ProbeSpillLocalAppendState { + //! Local partition and append state (if partitioned) + PartitionedColumnData *local_partition; + PartitionedColumnDataAppendState *local_partition_append_state; + //! Local spill and append state (if not partitioned) + ColumnDataCollection *local_spill_collection; + ColumnDataAppendState *local_spill_append_state; + }; + //! ProbeSpill represents materialized probe-side data that could not be probed during PhysicalHashJoin::Execute + //! because the HashTable did not fit in memory. The ProbeSpill is not partitioned if the remaining data can be + //! dealt with in just 1 more round of probing, otherwise it is radix partitioned in the same way as the HashTable + struct ProbeSpill { + public: + ProbeSpill(JoinHashTable &ht, ClientContext &context, const vector &probe_types); + + public: + //! Create a state for a new thread + ProbeSpillLocalAppendState RegisterThread(); + //! Append a chunk to this ProbeSpill + void Append(DataChunk &chunk, ProbeSpillLocalAppendState &local_state); + //! Finalize by merging the thread-local accumulated data + void Finalize(); + + public: + //! Prepare the next probe round + void PrepareNextProbe(); + //! Scans and consumes the ColumnDataCollection + unique_ptr consumer; + + private: + JoinHashTable &ht; + mutex lock; + ClientContext &context; + + //! Whether the probe data is partitioned + bool partitioned; + //! The types of the probe DataChunks + const vector &probe_types; + //! The column ids + vector column_ids; + + //! The partitioned probe data (if partitioned) and append states + unique_ptr global_partitions; + vector> local_partitions; + vector> local_partition_append_states; + + //! The probe data (if not partitioned) and append states + unique_ptr global_spill_collection; + vector> local_spill_collections; + vector> local_spill_append_states; + }; + + //! Whether we are doing an external hash join + bool external; + //! The current number of radix bits used to partition + idx_t radix_bits; + //! The max size of the HT + idx_t max_ht_size; + //! Total count + idx_t total_count; + + //! Capacity of the pointer table given the ht count + //! (minimum of 1024 to prevent collision chance for small HT's) + static idx_t PointerTableCapacity(idx_t count) { + return MaxValue(NextPowerOfTwo(count * 2), 1 << 10); + } + //! Size of the pointer table (in bytes) + static idx_t PointerTableSize(idx_t count) { + return PointerTableCapacity(count) * sizeof(data_ptr_t); + } + + //! Whether we need to do an external join + bool RequiresExternalJoin(ClientConfig &config, vector> &local_hts); + //! Computes partition sizes and number of radix bits (called before scheduling partition tasks) + bool RequiresPartitioning(ClientConfig &config, vector> &local_hts); + //! Partition this HT + void Partition(JoinHashTable &global_ht); + + //! Delete blocks that belong to the current partitioned HT + void Reset(); + //! Build HT for the next partitioned probe round + bool PrepareExternalFinalize(); + //! Probe whatever we can, sink the rest into a thread-local HT + unique_ptr ProbeAndSpill(DataChunk &keys, TupleDataChunkState &key_state, DataChunk &payload, + ProbeSpill &probe_spill, ProbeSpillLocalAppendState &spill_state, + DataChunk &spill_chunk); + +private: + //! First and last partition of the current probe round + idx_t partition_start; + idx_t partition_end; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/nested_loop_join.hpp b/src/duckdb/src/include/duckdb/execution/nested_loop_join.hpp new file mode 100644 index 00000000..b5f36ee0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/nested_loop_join.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/nested_loop_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" + +namespace duckdb { +class ColumnDataCollection; + +struct NestedLoopJoinInner { + static idx_t Perform(idx_t <uple, idx_t &rtuple, DataChunk &left_conditions, DataChunk &right_conditions, + SelectionVector &lvector, SelectionVector &rvector, const vector &conditions); +}; + +struct NestedLoopJoinMark { + static void Perform(DataChunk &left, ColumnDataCollection &right, bool found_match[], + const vector &conditions); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/aggregate_object.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/aggregate_object.hpp new file mode 100644 index 00000000..874a02a2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/aggregate_object.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/aggregate_object.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/aggregate_function.hpp" + +namespace duckdb { + +class BoundAggregateExpression; +class BoundWindowExpression; + +struct FunctionDataWrapper { + FunctionDataWrapper(unique_ptr function_data_p) : function_data(std::move(function_data_p)) { + } + + unique_ptr function_data; +}; + +struct AggregateObject { + AggregateObject(AggregateFunction function, FunctionData *bind_data, idx_t child_count, idx_t payload_size, + AggregateType aggr_type, PhysicalType return_type, Expression *filter = nullptr); + explicit AggregateObject(BoundAggregateExpression *aggr); + explicit AggregateObject(BoundWindowExpression &window); + + FunctionData *GetFunctionData() const { + return bind_data_wrapper ? bind_data_wrapper->function_data.get() : nullptr; + } + + AggregateFunction function; + shared_ptr bind_data_wrapper; + idx_t child_count; + idx_t payload_size; + AggregateType aggr_type; + PhysicalType return_type; + Expression *filter = nullptr; + +public: + bool IsDistinct() const { + return aggr_type == AggregateType::DISTINCT; + } + static vector CreateAggregateObjects(const vector &bindings); +}; + +struct AggregateFilterData { + AggregateFilterData(ClientContext &context, Expression &filter_expr, const vector &payload_types); + + idx_t ApplyFilter(DataChunk &payload); + + ExpressionExecutor filter_executor; + DataChunk filtered_payload; + SelectionVector true_sel; +}; + +struct AggregateFilterDataSet { + AggregateFilterDataSet(); + + vector> filter_data; + +public: + void Initialize(ClientContext &context, const vector &aggregates, + const vector &payload_types); + + AggregateFilterData &GetFilterData(idx_t aggr_idx); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp new file mode 100644 index 00000000..772d116a --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp" +#include "duckdb/execution/radix_partitioned_hashtable.hpp" + +namespace duckdb { + +class GroupedAggregateData; + +struct DistinctAggregateCollectionInfo { +public: + DistinctAggregateCollectionInfo(const vector> &aggregates, vector indices); + +public: + // The indices of the aggregates that are distinct + unsafe_vector indices; + // The amount of radix_tables that are occupied + idx_t table_count; + //! Occupied tables, not equal to indices if aggregates share input data + vector table_indices; + //! This indirection is used to allow two aggregates to share the same input data + unordered_map table_map; + const vector> &aggregates; + // Total amount of children of the distinct aggregates + idx_t total_child_count; + +public: + static unique_ptr Create(vector> &aggregates); + const unsafe_vector &Indices() const; + bool AnyDistinct() const; + +private: + //! Returns the amount of tables that are occupied + idx_t CreateTableIndexMap(); +}; + +struct DistinctAggregateData { +public: + DistinctAggregateData(const DistinctAggregateCollectionInfo &info); + DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups, + const vector> *group_expressions); + //! The data used by the hashtables + vector> grouped_aggregate_data; + //! The hashtables + vector> radix_tables; + //! The groups (arguments) + vector grouping_sets; + const DistinctAggregateCollectionInfo &info; + +public: + bool IsDistinct(idx_t index) const; +}; + +struct DistinctAggregateState { +public: + DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client); + + //! The executor + ExpressionExecutor child_executor; + //! The global sink states of the hash tables + vector> radix_states; + //! Output chunks to receive distinct data from hashtables + vector> distinct_output_chunks; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp new file mode 100644 index 00000000..42269420 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/parser/group_by_node.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +class GroupedAggregateData { +public: + GroupedAggregateData() { + } + //! The groups + vector> groups; + //! The set of GROUPING functions + vector> grouping_functions; + //! The group types + vector group_types; + + //! The aggregates that have to be computed + vector> aggregates; + //! The payload types + vector payload_types; + //! The aggregate return types + vector aggregate_return_types; + //! Pointers to the aggregates + vector bindings; + idx_t filter_count; + +public: + idx_t GroupCount() const; + + const vector> &GetGroupingFunctions() const; + + void InitializeGroupby(vector> groups, vector> expressions, + vector> grouping_functions); + + //! Initialize a GroupedAggregateData object for use with distinct aggregates + void InitializeDistinct(const unique_ptr &aggregate, const vector> *groups_p); + +private: + void InitializeDistinctGroups(const vector> *groups); + void InitializeGroupbyGroups(vector> groups); + void SetGroupingFunctions(vector> &functions); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp new file mode 100644 index 00000000..d36b17c0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp @@ -0,0 +1,156 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" +#include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/execution/radix_partitioned_hashtable.hpp" +#include "duckdb/parser/group_by_node.hpp" +#include "duckdb/storage/data_table.hpp" + +namespace duckdb { + +class ClientContext; +class BufferManager; +class PhysicalHashAggregate; + +struct HashAggregateGroupingData { +public: + HashAggregateGroupingData(GroupingSet &grouping_set_p, const GroupedAggregateData &grouped_aggregate_data, + unique_ptr &info); + +public: + RadixPartitionedHashTable table_data; + unique_ptr distinct_data; + +public: + bool HasDistinct() const; +}; + +struct HashAggregateGroupingGlobalState { +public: + HashAggregateGroupingGlobalState(const HashAggregateGroupingData &data, ClientContext &context); + // Radix state of the GROUPING_SET ht + unique_ptr table_state; + // State of the DISTINCT aggregates of this GROUPING_SET + unique_ptr distinct_state; +}; + +struct HashAggregateGroupingLocalState { +public: + HashAggregateGroupingLocalState(const PhysicalHashAggregate &op, const HashAggregateGroupingData &data, + ExecutionContext &context); + +public: + // Radix state of the GROUPING_SET ht + unique_ptr table_state; + // Local states of the DISTINCT aggregates hashtables + vector> distinct_states; +}; + +//! PhysicalHashAggregate is a group-by and aggregate implementation that uses a hash table to perform the grouping +//! This only contains read-only variables, anything that is stateful instead gets stored in the Global/Local states +class PhysicalHashAggregate : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::HASH_GROUP_BY; + +public: + PhysicalHashAggregate(ClientContext &context, vector types, vector> expressions, + idx_t estimated_cardinality); + PhysicalHashAggregate(ClientContext &context, vector types, vector> expressions, + vector> groups, idx_t estimated_cardinality); + PhysicalHashAggregate(ClientContext &context, vector types, vector> expressions, + vector> groups, vector grouping_sets, + vector> grouping_functions, idx_t estimated_cardinality); + + //! The grouping sets + GroupedAggregateData grouped_aggregate_data; + + vector grouping_sets; + //! The radix partitioned hash tables (one per grouping set) + vector groupings; + unique_ptr distinct_collection_info; + //! A recreation of the input chunk, with nulls for everything that isnt a group + vector input_group_types; + + // Filters given to Sink and friends + unsafe_vector non_distinct_filter; + unsafe_vector distinct_filter; + + unordered_map filter_indexes; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + bool ParallelSource() const override { + return true; + } + + OrderPreservationType SourceOrder() const override { + return OrderPreservationType::NO_ORDER; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + SinkFinalizeType FinalizeInternal(Pipeline &pipeline, Event &event, ClientContext &context, GlobalSinkState &gstate, + bool check_distinct) const; + + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } + + bool SinkOrderDependent() const override { + return false; + } + +public: + string ParamsToString() const override; + //! Toggle multi-scan capability on a hash table, which prevents the scan of the aggregate from being destructive + //! If this is not toggled the GetData method will destroy the hash table as it is scanning it + static void SetMultiScan(GlobalSinkState &state); + +private: + //! When we only have distinct aggregates, we can delay adding groups to the main ht + bool CanSkipRegularSink() const; + + //! Finalize the distinct aggregates + SinkFinalizeType FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, + GlobalSinkState &gstate) const; + //! Combine the distinct aggregates + void CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const; + //! Sink the distinct aggregates for a single grouping + void SinkDistinctGrouping(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, + idx_t grouping_idx) const; + //! Sink the distinct aggregates + void SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const; + //! Create groups in the main ht for groups that would otherwise get filtered out completely + SinkResultType SinkGroupsOnly(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate, + DataChunk &input) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp new file mode 100644 index 00000000..8fa1a231 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp @@ -0,0 +1,86 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/execution/base_aggregate_hashtable.hpp" + +namespace duckdb { +class ClientContext; +class PerfectAggregateHashTable; + +//! PhysicalPerfectHashAggregate performs a group-by and aggregation using a perfect hash table +class PhysicalPerfectHashAggregate : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PERFECT_HASH_GROUP_BY; + +public: + PhysicalPerfectHashAggregate(ClientContext &context, vector types, + vector> aggregates, vector> groups, + const vector> &group_stats, vector required_bits, + idx_t estimated_cardinality); + + //! The groups + vector> groups; + //! The aggregates that have to be computed + vector> aggregates; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + OrderPreservationType SourceOrder() const override { + return OrderPreservationType::NO_ORDER; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + string ParamsToString() const override; + + //! Create a perfect aggregate hash table for this node + unique_ptr CreateHT(Allocator &allocator, ClientContext &context) const; + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } + + bool SinkOrderDependent() const override { + return false; + } + +public: + //! The group types + vector group_types; + //! The payload types + vector payload_types; + //! The aggregates to be computed + vector aggregate_objects; + //! The minimum value of each of the groups + vector group_minima; + //! The number of bits we need to completely cover each of the groups + vector required_bits; + + unordered_map filter_indexes; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_streaming_window.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_streaming_window.hpp new file mode 100644 index 00000000..512950c1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_streaming_window.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/physical_window.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! PhysicalStreamingWindow implements streaming window functions (i.e. with an empty OVER clause) +class PhysicalStreamingWindow : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::STREAMING_WINDOW; + +public: + PhysicalStreamingWindow(vector types, vector> select_list, + idx_t estimated_cardinality, + PhysicalOperatorType type = PhysicalOperatorType::STREAMING_WINDOW); + + //! The projection list of the WINDOW statement + vector> select_list; + +public: + unique_ptr GetGlobalOperatorState(ClientContext &context) const override; + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + OrderPreservationType OperatorOrder() const override { + return OrderPreservationType::FIXED_ORDER; + } + + string ParamsToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp new file mode 100644 index 00000000..c4ca5480 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp" +#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" +#include "duckdb/parser/group_by_node.hpp" +#include "duckdb/execution/radix_partitioned_hashtable.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { + +//! PhysicalUngroupedAggregate is an aggregate operator that can only perform aggregates (1) without any groups, (2) +//! without any DISTINCT aggregates, and (3) when all aggregates are combineable +class PhysicalUngroupedAggregate : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::UNGROUPED_AGGREGATE; + +public: + PhysicalUngroupedAggregate(vector types, vector> expressions, + idx_t estimated_cardinality); + + //! The aggregates that have to be computed + vector> aggregates; + unique_ptr distinct_data; + unique_ptr distinct_collection_info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + string ParamsToString() const override; + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } + + bool SinkOrderDependent() const override; + +private: + //! Finalize the distinct aggregates + SinkFinalizeType FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, + GlobalSinkState &gstate) const; + //! Combine the distinct aggregates + void CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const; + //! Sink the distinct aggregates + void SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp new file mode 100644 index 00000000..5bf8fdcc --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp @@ -0,0 +1,80 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/aggregate/physical_window.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/pipeline.hpp" + +namespace duckdb { + +//! PhysicalWindow implements window functions +//! It assumes that all functions have a common partitioning and ordering +class PhysicalWindow : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::WINDOW; + +public: + PhysicalWindow(vector types, vector> select_list, idx_t estimated_cardinality, + PhysicalOperatorType type = PhysicalOperatorType::WINDOW); + + //! The projection list of the WINDOW statement (may contain aggregates) + vector> select_list; + //! Whether or not the window is order dependent (only true if all window functions contain neither an order nor a + //! partition clause) + bool is_order_dependent; + +public: + // Source interface + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + idx_t GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, + LocalSourceState &lstate) const override; + + bool IsSource() const override { + return true; + } + bool ParallelSource() const override { + return true; + } + + bool SupportsBatchIndex() const override; + OrderPreservationType SourceOrder() const override; + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return !is_order_dependent; + } + + bool SinkOrderDependent() const override { + return is_order_dependent; + } + +public: + idx_t MaxThreads(ClientContext &context); + + string ParamsToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/filter/physical_filter.hpp b/src/duckdb/src/include/duckdb/execution/operator/filter/physical_filter.hpp new file mode 100644 index 00000000..57189d59 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/filter/physical_filter.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/filter/physical_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! PhysicalFilter represents a filter operator. It removes non-matching tuples +//! from the result. Note that it does not physically change the data, it only +//! adds a selection vector to the chunk. +class PhysicalFilter : public CachingPhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::FILTER; + +public: + PhysicalFilter(vector types, vector> select_list, idx_t estimated_cardinality); + + //! The filter expression + unique_ptr expression; + +public: + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + bool ParallelOperator() const override { + return true; + } + + string ParamsToString() const override; + +protected: + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp new file mode 100644 index 00000000..04c15209 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_batch_collector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/helper/physical_result_collector.hpp" + +namespace duckdb { + +class PhysicalBatchCollector : public PhysicalResultCollector { +public: + PhysicalBatchCollector(PreparedStatementData &data); + +public: + unique_ptr GetResult(GlobalSinkState &state) override; + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool RequiresBatchIndex() const override { + return true; + } + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_execute.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_execute.hpp new file mode 100644 index 00000000..46b1e38c --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_execute.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_execute.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/main/prepared_statement_data.hpp" + +namespace duckdb { + +class PhysicalExecute : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::EXECUTE; + +public: + explicit PhysicalExecute(PhysicalOperator &plan); + + PhysicalOperator &plan; + unique_ptr owned_plan; + shared_ptr prepared; + +public: + vector> GetChildren() const override; + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_explain_analyze.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_explain_analyze.hpp new file mode 100644 index 00000000..9ecd71db --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_explain_analyze.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_explain_analyze.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class PhysicalExplainAnalyze : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::EXPLAIN_ANALYZE; + +public: + PhysicalExplainAnalyze(vector types) + : PhysicalOperator(PhysicalOperatorType::EXPLAIN_ANALYZE, std::move(types), 1) { + } + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink Interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit.hpp new file mode 100644 index 00000000..2282dbfd --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit.hpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_limit.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! PhyisicalLimit represents the LIMIT operator +class PhysicalLimit : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::LIMIT; + +public: + PhysicalLimit(vector types, idx_t limit, idx_t offset, unique_ptr limit_expression, + unique_ptr offset_expression, idx_t estimated_cardinality); + + idx_t limit_value; + idx_t offset_value; + unique_ptr limit_expression; + unique_ptr offset_expression; + +public: + bool SinkOrderDependent() const override { + return true; + } + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink Interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } + + bool RequiresBatchIndex() const override { + return true; + } + +public: + static bool ComputeOffset(ExecutionContext &context, DataChunk &input, idx_t &limit, idx_t &offset, + idx_t current_offset, idx_t &max_element, Expression *limit_expression, + Expression *offset_expression); + static bool HandleOffset(DataChunk &input, idx_t ¤t_offset, idx_t offset, idx_t limit); + static Value GetDelimiter(ExecutionContext &context, DataChunk &input, Expression *expr); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit_percent.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit_percent.hpp new file mode 100644 index 00000000..93c04ed8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit_percent.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_limit_percent.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! PhyisicalLimitPercent represents the LIMIT PERCENT operator +class PhysicalLimitPercent : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::LIMIT_PERCENT; + +public: + PhysicalLimitPercent(vector types, double limit_percent, idx_t offset, + unique_ptr limit_expression, unique_ptr offset_expression, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::LIMIT_PERCENT, std::move(types), estimated_cardinality), + limit_percent(limit_percent), offset_value(offset), limit_expression(std::move(limit_expression)), + offset_expression(std::move(offset_expression)) { + } + + double limit_percent; + idx_t offset_value; + unique_ptr limit_expression; + unique_ptr offset_expression; + +public: + bool SinkOrderDependent() const override { + return true; + } + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + bool IsSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_load.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_load.hpp new file mode 100644 index 00000000..4511ff01 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_load.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_vacuum.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/load_info.hpp" + +namespace duckdb { + +//! PhysicalLoad represents an extension LOAD operation +class PhysicalLoad : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::LOAD; + +public: + explicit PhysicalLoad(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::LOAD, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp new file mode 100644 index 00000000..4e6efaf8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_materialized_collector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/helper/physical_result_collector.hpp" + +namespace duckdb { + +class PhysicalMaterializedCollector : public PhysicalResultCollector { +public: + PhysicalMaterializedCollector(PreparedStatementData &data, bool parallel); + + bool parallel; + +public: + unique_ptr GetResult(GlobalSinkState &state) override; + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool ParallelSink() const override; + bool SinkOrderDependent() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_pragma.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_pragma.hpp new file mode 100644 index 00000000..ff2c721d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_pragma.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_pragma.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/function/pragma_function.hpp" + +namespace duckdb { + +//! PhysicalPragma represents the PRAGMA operator +class PhysicalPragma : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PRAGMA; + +public: + PhysicalPragma(PragmaFunction function_p, PragmaInfo info_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::PRAGMA, {LogicalType::BOOLEAN}, estimated_cardinality), + function(std::move(function_p)), info(std::move(info_p)) { + } + + //! The pragma function to call + PragmaFunction function; + //! The context of the call + PragmaInfo info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp new file mode 100644 index 00000000..01552842 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_prepare.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/common/enums/physical_operator_type.hpp" +#include "duckdb/main/prepared_statement_data.hpp" + +namespace duckdb { + +class PhysicalPrepare : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PREPARE; + +public: + PhysicalPrepare(string name, shared_ptr prepared, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::PREPARE, {LogicalType::BOOLEAN}, estimated_cardinality), name(name), + prepared(std::move(prepared)) { + } + + string name; + shared_ptr prepared; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reservoir_sample.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reservoir_sample.hpp new file mode 100644 index 00000000..40991b72 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reservoir_sample.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_reservoir_sample.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" + +namespace duckdb { + +//! PhysicalReservoirSample represents a sample taken using reservoir sampling, which is a blocking sampling method +class PhysicalReservoirSample : public PhysicalOperator { +public: + PhysicalReservoirSample(vector types, unique_ptr options, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::RESERVOIR_SAMPLE, std::move(types), estimated_cardinality), + options(std::move(options)) { + } + + unique_ptr options; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool ParallelSink() const override { + return true; + } + + bool IsSink() const override { + return true; + } + + string ParamsToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp new file mode 100644 index 00000000..c3e69e4a --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_reset.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" + +namespace duckdb { + +struct DBConfig; +struct ExtensionOption; + +//! PhysicalReset represents a RESET operation (e.g. RESET a = 42) +class PhysicalReset : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::RESET; + +public: + PhysicalReset(const std::string &name_p, SetScope scope_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::RESET, {LogicalType::BOOLEAN}, estimated_cardinality), name(name_p), + scope(scope_p) { + } + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + const std::string name; + const SetScope scope; + +private: + void ResetExtensionVariable(ExecutionContext &context, DBConfig &config, ExtensionOption &extension_option) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp new file mode 100644 index 00000000..123c53ba --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_result_collector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/common/enums/statement_type.hpp" + +namespace duckdb { +class PreparedStatementData; + +//! PhysicalResultCollector is an abstract class that is used to generate the final result of a query +class PhysicalResultCollector : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::RESULT_COLLECTOR; + +public: + explicit PhysicalResultCollector(PreparedStatementData &data); + + StatementType statement_type; + StatementProperties properties; + PhysicalOperator &plan; + vector names; + +public: + static unique_ptr GetResultCollector(ClientContext &context, PreparedStatementData &data); + +public: + //! The final method used to fetch the query result from this operator + virtual unique_ptr GetResult(GlobalSinkState &state) = 0; + + bool IsSink() const override { + return true; + } + +public: + vector> GetChildren() const override; + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp new file mode 100644 index 00000000..e5334563 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" + +namespace duckdb { + +struct DBConfig; +struct ExtensionOption; + +//! PhysicalSet represents a SET operation (e.g. SET a = 42) +class PhysicalSet : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::SET; + +public: + PhysicalSet(const std::string &name_p, Value value_p, SetScope scope_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::SET, {LogicalType::BOOLEAN}, estimated_cardinality), name(name_p), + value(value_p), scope(scope_p) { + } + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + + static void SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, + SetScope scope, const Value &value); + +public: + const std::string name; + const Value value; + const SetScope scope; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_limit.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_limit.hpp new file mode 100644 index 00000000..a12b4e24 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_limit.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_streaming_limit.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class PhysicalStreamingLimit : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::STREAMING_LIMIT; + +public: + PhysicalStreamingLimit(vector types, idx_t limit, idx_t offset, + unique_ptr limit_expression, unique_ptr offset_expression, + idx_t estimated_cardinality, bool parallel); + + idx_t limit_value; + idx_t offset_value; + unique_ptr limit_expression; + unique_ptr offset_expression; + bool parallel; + +public: + // Operator interface + unique_ptr GetOperatorState(ExecutionContext &context) const override; + unique_ptr GetGlobalOperatorState(ClientContext &context) const override; + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + OrderPreservationType OperatorOrder() const override; + bool ParallelOperator() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_sample.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_sample.hpp new file mode 100644 index 00000000..40445cf5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_sample.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_streaming_sample.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" + +namespace duckdb { + +//! PhysicalStreamingSample represents a streaming sample using either system or bernoulli sampling +class PhysicalStreamingSample : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::STREAMING_SAMPLE; + +public: + PhysicalStreamingSample(vector types, SampleMethod method, double percentage, int64_t seed, + idx_t estimated_cardinality); + + SampleMethod method; + double percentage; + int64_t seed; + +public: + // Operator interface + unique_ptr GetOperatorState(ExecutionContext &context) const override; + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + bool ParallelOperator() const override { + return true; + } + + string ParamsToString() const override; + +private: + void SystemSample(DataChunk &input, DataChunk &result, OperatorState &state) const; + void BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_transaction.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_transaction.hpp new file mode 100644 index 00000000..3a2fee92 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_transaction.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_transaction.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/transaction_info.hpp" + +namespace duckdb { + +//! PhysicalTransaction represents a transaction operator (e.g. BEGIN or COMMIT) +class PhysicalTransaction : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::TRANSACTION; + +public: + explicit PhysicalTransaction(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::TRANSACTION, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_vacuum.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_vacuum.hpp new file mode 100644 index 00000000..77238a04 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_vacuum.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_vacuum.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" + +namespace duckdb { + +//! PhysicalVacuum represents a VACUUM operation (i.e. VACUUM or ANALYZE) +class PhysicalVacuum : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::VACUUM; + +public: + PhysicalVacuum(unique_ptr info, idx_t estimated_cardinality); + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return info->has_table; + } + + bool ParallelSink() const override { + return IsSink(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp new file mode 100644 index 00000000..b90a2676 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/outer_join_marker.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +struct OuterJoinGlobalScanState { + mutex lock; + ColumnDataCollection *data = nullptr; + ColumnDataParallelScanState global_scan; +}; + +struct OuterJoinLocalScanState { + DataChunk scan_chunk; + SelectionVector match_sel; + ColumnDataLocalScanState local_scan; +}; + +class OuterJoinMarker { +public: + explicit OuterJoinMarker(bool enabled); + + bool Enabled() { + return enabled; + } + //! Initializes the outer join counter + void Initialize(idx_t count); + //! Resets the outer join counter + void Reset(); + + //! Sets an indiivdual match + void SetMatch(idx_t position); + + //! Sets multiple matches + void SetMatches(const SelectionVector &sel, idx_t count, idx_t base_idx = 0); + + //! Constructs a left-join result based on which tuples have not found matches + void ConstructLeftJoinResult(DataChunk &left, DataChunk &result); + + //! Returns the maximum number of threads that can be associated with an right-outer join scan + idx_t MaxThreads() const; + + //! Initialize a scan + void InitializeScan(ColumnDataCollection &data, OuterJoinGlobalScanState &gstate); + + //! Initialize a local scan + void InitializeScan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate); + + //! Perform the scan + void Scan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate, DataChunk &result); + + //! Read-only matches vector + const bool *GetMatches() const { + return found_match.get(); + } + +private: + bool enabled; + unsafe_unique_array found_match; + idx_t count; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp new file mode 100644 index 00000000..33fcb6a2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/perfect_hash_join_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/execution/join_hashtable.hpp" +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +class HashJoinOperatorState; +class HashJoinGlobalSinkState; +class PhysicalHashJoin; + +struct PerfectHashJoinStats { + Value build_min; + Value build_max; + Value probe_min; + Value probe_max; + bool is_build_small = false; + bool is_build_dense = false; + bool is_probe_in_domain = false; + idx_t build_range = 0; + idx_t estimated_cardinality = 0; +}; + +//! PhysicalHashJoin represents a hash loop join between two tables +class PerfectHashJoinExecutor { + using PerfectHashTable = vector; + +public: + explicit PerfectHashJoinExecutor(const PhysicalHashJoin &join, JoinHashTable &ht, PerfectHashJoinStats pjoin_stats); + +public: + bool CanDoPerfectHashJoin(); + + unique_ptr GetOperatorState(ExecutionContext &context); + OperatorResultType ProbePerfectHashTable(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state); + bool BuildPerfectHashTable(LogicalType &type); + +private: + void FillSelectionVectorSwitchProbe(Vector &source, SelectionVector &build_sel_vec, SelectionVector &probe_sel_vec, + idx_t count, idx_t &probe_sel_count); + template + void TemplatedFillSelectionVectorProbe(Vector &source, SelectionVector &build_sel_vec, + SelectionVector &probe_sel_vec, idx_t count, idx_t &prob_sel_count); + + bool FillSelectionVectorSwitchBuild(Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, + idx_t count); + template + bool TemplatedFillSelectionVectorBuild(Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, + idx_t count); + bool FullScanHashTable(LogicalType &key_type); + +private: + const PhysicalHashJoin &join; + JoinHashTable &ht; + //! Columnar perfect hash table + PerfectHashTable perfect_hash_table; + //! Build and probe statistics + PerfectHashJoinStats perfect_join_statistics; + //! Stores the occurences of each value in the build side + unsafe_unique_array bitmap_build_idx; + //! Stores the number of unique keys in the build side + idx_t unique_keys = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp new file mode 100644 index 00000000..467eafe6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp @@ -0,0 +1,84 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_asof_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" + +namespace duckdb { + +//! PhysicalAsOfJoin represents an as-of join between two tables +class PhysicalAsOfJoin : public PhysicalComparisonJoin { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::ASOF_JOIN; + +public: + PhysicalAsOfJoin(LogicalComparisonJoin &op, unique_ptr left, unique_ptr right); + + vector join_key_types; + vector null_sensitive; + ExpressionType comparison_type; + + // Equalities + vector> lhs_partitions; + vector> rhs_partitions; + + // Inequality Only + vector lhs_orders; + vector rhs_orders; + + // Projection mappings + vector right_projection_map; + +public: + // Operator Interface + unique_ptr GetGlobalOperatorState(ClientContext &context) const override; + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + bool ParallelOperator() const override { + return true; + } + +protected: + // CachingOperator Interface + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + // Source interface + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + bool ParallelSource() const override { + return true; + } + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_blockwise_nl_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_blockwise_nl_join.hpp new file mode 100644 index 00000000..0485cd27 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_blockwise_nl_join.hpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_blockwise_nl_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/operator/join/physical_join.hpp" + +namespace duckdb { + +//! PhysicalBlockwiseNLJoin represents a nested loop join between two tables on arbitrary expressions. This is different +//! from the PhysicalNestedLoopJoin in that it does not require expressions to be comparisons between the LHS and the +//! RHS. +class PhysicalBlockwiseNLJoin : public PhysicalJoin { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::BLOCKWISE_NL_JOIN; + +public: + PhysicalBlockwiseNLJoin(LogicalOperator &op, unique_ptr left, unique_ptr right, + unique_ptr condition, JoinType join_type, idx_t estimated_cardinality); + + unique_ptr condition; + +public: + // Operator Interface + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + bool ParallelOperator() const override { + return true; + } + +protected: + // CachingOperatorState Interface + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return IsRightOuterJoin(join_type); + } + bool ParallelSource() const override { + return true; + } + +public: + // Sink interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + +public: + string ParamsToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_comparison_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_comparison_join.hpp new file mode 100644 index 00000000..b049c80f --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_comparison_join.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_comparison_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/join/physical_join.hpp" + +namespace duckdb { +class ColumnDataCollection; +struct ColumnDataScanState; + +//! PhysicalJoin represents the base class of the join operators +class PhysicalComparisonJoin : public PhysicalJoin { +public: + PhysicalComparisonJoin(LogicalOperator &op, PhysicalOperatorType type, vector cond, + JoinType join_type, idx_t estimated_cardinality); + + vector conditions; + +public: + string ParamsToString() const override; + + //! Construct the join result of a join with an empty RHS + static void ConstructEmptyJoinResult(JoinType type, bool has_null, DataChunk &input, DataChunk &result); + //! Construct the remainder of a Full Outer Join based on which tuples in the RHS found no match + static void ConstructFullOuterJoinResult(bool *found_match, ColumnDataCollection &input, DataChunk &result, + ColumnDataScanState &scan_state); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_cross_product.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_cross_product.hpp new file mode 100644 index 00000000..8091bb3c --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_cross_product.hpp @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_cross_product.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +//! PhysicalCrossProduct represents a cross product between two tables +class PhysicalCrossProduct : public CachingPhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CROSS_PRODUCT; + +public: + PhysicalCrossProduct(vector types, unique_ptr left, + unique_ptr right, idx_t estimated_cardinality); + +public: + // Operator Interface + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + OrderPreservationType OperatorOrder() const override { + return OrderPreservationType::NO_ORDER; + } + bool ParallelOperator() const override { + return true; + } + +protected: + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + bool SinkOrderDependent() const override { + return false; + } + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + vector> GetSources() const override; +}; + +class CrossProductExecutor { +public: + explicit CrossProductExecutor(ColumnDataCollection &rhs); + + OperatorResultType Execute(DataChunk &input, DataChunk &output); + + // returns if the left side is scanned as a constant vector + bool ScanLHS() { + return scan_input_chunk; + } + + // returns the position in the chunk of chunk scanned as a constant input vector + idx_t PositionInChunk() { + return position_in_chunk; + } + + idx_t ScanPosition() { + return scan_state.current_row_index; + } + +private: + void Reset(DataChunk &input, DataChunk &output); + bool NextValue(DataChunk &input, DataChunk &output); + +private: + ColumnDataCollection &rhs; + ColumnDataScanState scan_state; + DataChunk scan_chunk; + idx_t position_in_chunk; + bool initialized; + bool finished; + bool scan_input_chunk; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_delim_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_delim_join.hpp new file mode 100644 index 00000000..0e3a7754 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_delim_join.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_delim_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { +class PhysicalHashAggregate; + +//! PhysicalDelimJoin represents a join where the LHS will be duplicate eliminated and pushed into a +//! PhysicalColumnDataScan in the RHS. +class PhysicalDelimJoin : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::DELIM_JOIN; + +public: + PhysicalDelimJoin(vector types, unique_ptr original_join, + vector> delim_scans, idx_t estimated_cardinality); + + unique_ptr join; + unique_ptr distinct; + vector> delim_scans; + +public: + vector> GetChildren() const override; + +public: + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + OrderPreservationType SourceOrder() const override { + return OrderPreservationType::NO_ORDER; + } + bool SinkOrderDependent() const override { + return false; + } + string ParamsToString() const override; + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp new file mode 100644 index 00000000..3947347d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_hash_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/execution/join_hashtable.hpp" +#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/operator/logical_join.hpp" + +namespace duckdb { + +//! PhysicalHashJoin represents a hash loop join between two tables +class PhysicalHashJoin : public PhysicalComparisonJoin { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::HASH_JOIN; + +public: + PhysicalHashJoin(LogicalOperator &op, unique_ptr left, unique_ptr right, + vector cond, JoinType join_type, const vector &left_projection_map, + const vector &right_projection_map, vector delim_types, + idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_stats); + PhysicalHashJoin(LogicalOperator &op, unique_ptr left, unique_ptr right, + vector cond, JoinType join_type, idx_t estimated_cardinality, + PerfectHashJoinStats join_state); + + //! Initialize HT for this operator + unique_ptr InitializeHashTable(ClientContext &context) const; + + vector right_projection_map; + //! The types of the keys + vector condition_types; + //! The types of all conditions + vector build_types; + //! Duplicate eliminated types; only used for delim_joins (i.e. correlated subqueries) + vector delim_types; + //! Used in perfect hash join + PerfectHashJoinStats perfect_join_statistics; + +public: + // Operator Interface + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + bool ParallelOperator() const override { + return true; + } + +protected: + // CachingOperator Interface + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + //! Becomes a source when it is an external join + bool IsSource() const override { + return true; + } + + bool ParallelSource() const override { + return true; + } + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp new file mode 100644 index 00000000..c7d251ad --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/join/physical_range_join.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" + +namespace duckdb { + +//! PhysicalIEJoin represents a two inequality range join between +//! two tables +class PhysicalIEJoin : public PhysicalRangeJoin { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::IE_JOIN; + +public: + PhysicalIEJoin(LogicalComparisonJoin &op, unique_ptr left, unique_ptr right, + vector cond, JoinType join_type, idx_t estimated_cardinality); + + vector join_key_types; + vector> lhs_orders; + vector> rhs_orders; + +public: + // CachingOperator Interface + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + // Source interface + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + bool ParallelSource() const override { + return true; + } + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + +private: + // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + void ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_index_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_index_join.hpp new file mode 100644 index 00000000..ebe2cb23 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_index_join.hpp @@ -0,0 +1,79 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_index_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/operator/logical_join.hpp" +#include "duckdb/storage/index.hpp" + +namespace duckdb { + +//! PhysicalIndexJoin represents an index join between two tables +class PhysicalIndexJoin : public CachingPhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::INDEX_JOIN; + +public: + PhysicalIndexJoin(LogicalOperator &op, unique_ptr left, unique_ptr right, + vector cond, JoinType join_type, const vector &left_projection_map, + vector right_projection_map, vector column_ids, Index &index, bool lhs_first, + idx_t estimated_cardinality); + + //! Columns from RHS used in the query + vector column_ids; + //! Columns to be fetched + vector fetch_ids; + //! Types of fetch columns + vector fetch_types; + //! Columns indexed by index + unordered_set index_ids; + //! Projected ids from LHS + vector left_projection_map; + //! Projected ids from RHS + vector right_projection_map; + //! The types of the keys + vector condition_types; + //! The types of all conditions + vector build_types; + //! Index used for join + Index &index; + + vector conditions; + + JoinType join_type; + //! In case we swap rhs with lhs we need to output columns related to rhs first. + bool lhs_first = true; + +public: + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + OrderPreservationType OperatorOrder() const override { + return OrderPreservationType::NO_ORDER; + } + bool ParallelOperator() const override { + return true; + } + +protected: + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + vector> GetSources() const override; + +private: + void GetRHSMatches(ExecutionContext &context, DataChunk &input, OperatorState &state_p) const; + //! Fills result chunk + void Output(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state_p) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_join.hpp new file mode 100644 index 00000000..bcc25eba --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_join.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" + +namespace duckdb { + +//! PhysicalJoin represents the base class of the join operators +class PhysicalJoin : public CachingPhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::INVALID; + +public: + PhysicalJoin(LogicalOperator &op, PhysicalOperatorType type, JoinType join_type, idx_t estimated_cardinality); + + JoinType join_type; + +public: + bool EmptyResultIfRHSIsEmpty() const; + + static bool HasNullValues(DataChunk &chunk); + static void ConstructSemiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]); + static void ConstructAntiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]); + static void ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &left, DataChunk &result, bool found_match[], + bool has_null); + +public: + static void BuildJoinPipelines(Pipeline ¤t, MetaPipeline &confluent_pipelines, PhysicalOperator &op); + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + vector> GetSources() const override; + + OrderPreservationType SourceOrder() const override { + return OrderPreservationType::NO_ORDER; + } + OrderPreservationType OperatorOrder() const override { + return OrderPreservationType::NO_ORDER; + } + bool SinkOrderDependent() const override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp new file mode 100644 index 00000000..31060230 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp @@ -0,0 +1,82 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_nested_loop_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" + +namespace duckdb { + +//! PhysicalNestedLoopJoin represents a nested loop join between two tables +class PhysicalNestedLoopJoin : public PhysicalComparisonJoin { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::NESTED_LOOP_JOIN; + +public: + PhysicalNestedLoopJoin(LogicalOperator &op, unique_ptr left, unique_ptr right, + vector cond, JoinType join_type, idx_t estimated_cardinality); + +public: + // Operator Interface + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + bool ParallelOperator() const override { + return true; + } + +protected: + // CachingOperator Interface + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return IsRightOuterJoin(join_type); + } + bool ParallelSource() const override { + return true; + } + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + + static bool IsSupported(const vector &conditions, JoinType join_type); + +public: + //! Returns a list of the types of the join conditions + vector GetJoinTypes() const; + +private: + // resolve joins that output max N elements (SEMI, ANTI, MARK) + void ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state) const; + // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + OperatorResultType ResolveComplexJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp new file mode 100644 index 00000000..59c3ca54 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp @@ -0,0 +1,82 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/join/physical_range_join.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" + +namespace duckdb { + +class MergeJoinGlobalState; + +//! PhysicalPiecewiseMergeJoin represents a piecewise merge loop join between +//! two tables +class PhysicalPiecewiseMergeJoin : public PhysicalRangeJoin { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PIECEWISE_MERGE_JOIN; + +public: + PhysicalPiecewiseMergeJoin(LogicalComparisonJoin &op, unique_ptr left, + unique_ptr right, vector cond, JoinType join_type, + idx_t estimated_cardinality); + + vector join_key_types; + vector lhs_orders; + vector rhs_orders; + +public: + // Operator Interface + unique_ptr GetOperatorState(ExecutionContext &context) const override; + + bool ParallelOperator() const override { + return true; + } + +protected: + // CachingOperator Interface + OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return IsRightOuterJoin(join_type); + } + bool ParallelSource() const override { + return true; + } + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + +private: + // resolve joins that output max N elements (SEMI, ANTI, MARK) + void ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state) const; + // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + OperatorResultType ResolveComplexJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_positional_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_positional_join.hpp new file mode 100644 index 00000000..409229ef --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_positional_join.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_positional_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" + +namespace duckdb { + +//! PhysicalPositionalJoin represents a cross product between two tables +class PhysicalPositionalJoin : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::POSITIONAL_JOIN; + +public: + PhysicalPositionalJoin(vector types, unique_ptr left, + unique_ptr right, idx_t estimated_cardinality); + +public: + // Operator Interface + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + bool IsSink() const override { + return true; + } + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + vector> GetSources() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp new file mode 100644 index 00000000..ef9f0343 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp @@ -0,0 +1,118 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/join/physical_comparison_join.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" +#include "duckdb/common/sort/sort.hpp" + +namespace duckdb { + +struct GlobalSortState; + +//! PhysicalRangeJoin represents one or more inequality range join predicates between +//! two tables +class PhysicalRangeJoin : public PhysicalComparisonJoin { +public: + class LocalSortedTable { + public: + LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child); + + void Sink(DataChunk &input, GlobalSortState &global_sort_state); + + inline void Sort(GlobalSortState &global_sort_state) { + local_sort_state.Sort(global_sort_state, true); + } + + //! The hosting operator + const PhysicalRangeJoin &op; + //! The local sort state + LocalSortState local_sort_state; + //! Local copy of the sorting expression executor + ExpressionExecutor executor; + //! Holds a vector of incoming sorting columns + DataChunk keys; + //! The number of NULL values + idx_t has_null; + //! The total number of rows + idx_t count; + + private: + // Merge the NULLs of all non-DISTINCT predicates into the primary so they sort to the end. + idx_t MergeNulls(const vector &conditions); + }; + + class GlobalSortedTable { + public: + GlobalSortedTable(ClientContext &context, const vector &orders, RowLayout &payload_layout); + + inline idx_t Count() const { + return count; + } + + inline idx_t BlockCount() const { + if (global_sort_state.sorted_blocks.empty()) { + return 0; + } + D_ASSERT(global_sort_state.sorted_blocks.size() == 1); + return global_sort_state.sorted_blocks[0]->radix_sorting_data.size(); + } + + inline idx_t BlockSize(idx_t i) const { + return global_sort_state.sorted_blocks[0]->radix_sorting_data[i]->count; + } + + void Combine(LocalSortedTable <able); + void IntializeMatches(); + void Print(); + + //! Starts the sorting process. + void Finalize(Pipeline &pipeline, Event &event); + //! Schedules tasks to merge sort the current child's data during a Finalize phase + void ScheduleMergeTasks(Pipeline &pipeline, Event &event); + + GlobalSortState global_sort_state; + //! Whether or not the RHS has NULL values + atomic has_null; + //! The total number of rows in the RHS + atomic count; + //! A bool indicating for each tuple in the RHS if they found a match (only used in FULL OUTER JOIN) + unsafe_unique_array found_match; + //! Memory usage per thread + idx_t memory_per_thread; + }; + +public: + PhysicalRangeJoin(LogicalComparisonJoin &op, PhysicalOperatorType type, unique_ptr left, + unique_ptr right, vector cond, JoinType join_type, + idx_t estimated_cardinality); + + // Projection mappings + using ProjectionMapping = vector; + ProjectionMapping left_projection_map; + ProjectionMapping right_projection_map; + + //! The full set of types (left + right child) + vector unprojected_types; + +public: + // Gather the result values and slice the payload columns to those values. + // Returns a buffer handle to the pinned heap block (if any) + static BufferHandle SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, + const SelectionVector &result, const idx_t result_count, + const idx_t left_cols = 0); + // Apply a tail condition to the current selection + static idx_t SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, + const SelectionVector *sel, idx_t count, SelectionVector *true_sel); + + //! Utility to project full width internal chunks to projected results + void ProjectResult(DataChunk &chunk, DataChunk &result) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/order/physical_order.hpp b/src/duckdb/src/include/duckdb/execution/operator/order/physical_order.hpp new file mode 100644 index 00000000..cb8a1774 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/order/physical_order.hpp @@ -0,0 +1,84 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/order/physical_order.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/planner/bound_query_node.hpp" + +namespace duckdb { + +class OrderGlobalSinkState; + +//! Physically re-orders the input data +class PhysicalOrder : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::ORDER_BY; + +public: + PhysicalOrder(vector types, vector orders, vector projections, + idx_t estimated_cardinality); + + //! Input data + vector orders; + vector projections; + +public: + // Source interface + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + idx_t GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, + LocalSourceState &lstate) const override; + + bool IsSource() const override { + return true; + } + + bool ParallelSource() const override { + return true; + } + + bool SupportsBatchIndex() const override { + return true; + } + + OrderPreservationType SourceOrder() const override { + return OrderPreservationType::FIXED_ORDER; + } + +public: + // Sink interface + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + bool SinkOrderDependent() const override { + return false; + } + +public: + string ParamsToString() const override; + + //! Schedules tasks to merge the data during the Finalize phase + static void ScheduleMergeTasks(Pipeline &pipeline, Event &event, OrderGlobalSinkState &state); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/order/physical_top_n.hpp b/src/duckdb/src/include/duckdb/execution/operator/order/physical_top_n.hpp new file mode 100644 index 00000000..87e743f4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/order/physical_top_n.hpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/order/physical_top_n.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/bound_query_node.hpp" + +namespace duckdb { + +//! Represents a physical ordering of the data. Note that this will not change +//! the data but only add a selection vector. +class PhysicalTopN : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::TOP_N; + +public: + PhysicalTopN(vector types, vector orders, idx_t limit, idx_t offset, + idx_t estimated_cardinality); + + vector orders; + idx_t limit; + idx_t offset; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + OrderPreservationType SourceOrder() const override { + return OrderPreservationType::FIXED_ORDER; + } + +public: + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } + + string ParamsToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/csv_rejects_table.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/csv_rejects_table.hpp new file mode 100644 index 00000000..10f57b11 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/csv_rejects_table.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include "duckdb.hpp" +#ifndef DUCKDB_AMALGAMATION +#include "duckdb/storage/object_cache.hpp" +#endif + +namespace duckdb { + +struct ReadCSVData; + +class CSVRejectsTable : public ObjectCacheEntry { +public: + CSVRejectsTable(string name) : name(name), count(0) { + } + ~CSVRejectsTable() override = default; + mutex write_lock; + string name; + idx_t count; + + static shared_ptr GetOrCreate(ClientContext &context, const string &name); + + void InitializeTable(ClientContext &context, const ReadCSVData &options); + TableCatalogEntry &GetTable(ClientContext &context); + +public: + static string ObjectType() { + return "csv_rejects_table_cache"; + } + + string GetObjectType() override { + return ObjectType(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp new file mode 100644 index 00000000..e91e4d92 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/filename_pattern.hpp" + +namespace duckdb { + +//! Copy the contents of a query into a table +class PhysicalBatchCopyToFile : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::BATCH_COPY_TO_FILE; + +public: + PhysicalBatchCopyToFile(vector types, CopyFunction function, unique_ptr bind_data, + idx_t estimated_cardinality); + + CopyFunction function; + unique_ptr bind_data; + string file_path; + bool use_tmp_file; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + void NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const override; + + bool RequiresBatchIndex() const override { + return true; + } + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } + +private: + void PrepareBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t batch_index, + unique_ptr collection) const; + void FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index) const; + SinkFinalizeType FinalFlush(ClientContext &context, GlobalSinkState &gstate_p) const; +}; + +struct ActiveFlushGuard { + explicit ActiveFlushGuard(atomic &bool_value_p) : bool_value(bool_value_p) { + bool_value = true; + } + ~ActiveFlushGuard() { + bool_value = false; + } + + atomic &bool_value; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp new file mode 100644 index 00000000..bc87f59f --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_batch_insert.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/persistent/physical_insert.hpp" + +namespace duckdb { + +class PhysicalBatchInsert : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::BATCH_INSERT; + +public: + //! INSERT INTO + PhysicalBatchInsert(vector types, TableCatalogEntry &table, + physical_index_vector_t column_index_map, vector> bound_defaults, + idx_t estimated_cardinality); + //! CREATE TABLE AS + PhysicalBatchInsert(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info, + idx_t estimated_cardinality); + + //! The map from insert column index to table column index + physical_index_vector_t column_index_map; + //! The table to insert into + optional_ptr insert_table; + //! The insert types + vector insert_types; + //! The default expressions of the columns for which no value is provided + vector> bound_defaults; + //! Table schema, in case of CREATE TABLE AS + optional_ptr schema; + //! Create table info, in case of CREATE TABLE AS + unique_ptr info; + // Which action to perform on conflict + OnConflictAction action_type; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + void NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool RequiresBatchIndex() const override { + return true; + } + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp new file mode 100644 index 00000000..55e64dca --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_copy_to_file.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/filename_pattern.hpp" + +namespace duckdb { + +//! Copy the contents of a query into a table +class PhysicalCopyToFile : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::COPY_TO_FILE; + +public: + PhysicalCopyToFile(vector types, CopyFunction function, unique_ptr bind_data, + idx_t estimated_cardinality); + + CopyFunction function; + unique_ptr bind_data; + string file_path; + bool use_tmp_file; + FilenamePattern filename_pattern; + bool overwrite_or_ignore; + bool parallel; + bool per_thread_output; + + bool partition_output; + vector partition_columns; + vector names; + vector expected_types; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + + bool SinkOrderDependent() const override { + return true; + } + + bool ParallelSink() const override { + return per_thread_output || partition_output || parallel; + } + + static void MoveTmpFile(ClientContext &context, const string &tmp_file_path); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_delete.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_delete.hpp new file mode 100644 index 00000000..387df24d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_delete.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_delete.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { +class DataTable; + +//! Physically delete data from a table +class PhysicalDelete : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::DELETE_OPERATOR; + +public: + PhysicalDelete(vector types, TableCatalogEntry &tableref, DataTable &table, idx_t row_id_index, + idx_t estimated_cardinality, bool return_chunk) + : PhysicalOperator(PhysicalOperatorType::DELETE_OPERATOR, std::move(types), estimated_cardinality), + tableref(tableref), table(table), row_id_index(row_id_index), return_chunk(return_chunk) { + } + + TableCatalogEntry &tableref; + DataTable &table; + idx_t row_id_index; + bool return_chunk; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_export.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_export.hpp new file mode 100644 index 00000000..54622e4d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_export.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_export.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/parser/parsed_data/exported_table_data.hpp" + +namespace duckdb { +//! Parse a file from disk using a specified copy function and return the set of chunks retrieved from the file +class PhysicalExport : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::EXPORT; + +public: + PhysicalExport(vector types, CopyFunction function, unique_ptr info, + idx_t estimated_cardinality, BoundExportData exported_tables) + : PhysicalOperator(PhysicalOperatorType::EXPORT, std::move(types), estimated_cardinality), + function(std::move(function)), info(std::move(info)), exported_tables(std::move(exported_tables)) { + } + + //! The copy function to use to read the file + CopyFunction function; + //! The binding info containing the set of options for reading the file + unique_ptr info; + //! The table info for each table that will be exported + BoundExportData exported_tables; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + bool ParallelSink() const override { + return true; + } + bool IsSink() const override { + return true; + } + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + vector> GetSources() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_fixed_batch_copy.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_fixed_batch_copy.hpp new file mode 100644 index 00000000..b8ca7329 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_fixed_batch_copy.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_fixed_batch_copy.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/filename_pattern.hpp" + +namespace duckdb { + +class PhysicalFixedBatchCopy : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::FIXED_BATCH_COPY_TO_FILE; + +public: + PhysicalFixedBatchCopy(vector types, CopyFunction function, unique_ptr bind_data, + idx_t estimated_cardinality); + + CopyFunction function; + unique_ptr bind_data; + string file_path; + bool use_tmp_file; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + void NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const override; + + bool RequiresBatchIndex() const override { + return true; + } + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return true; + } + +public: + void AddRawBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t batch_index, + unique_ptr collection) const; + void RepartitionBatches(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index, + bool final = false) const; + void FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index) const; + bool ExecuteTask(ClientContext &context, GlobalSinkState &gstate_p) const; + void ExecuteTasks(ClientContext &context, GlobalSinkState &gstate_p) const; + SinkFinalizeType FinalFlush(ClientContext &context, GlobalSinkState &gstate_p) const; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_insert.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_insert.hpp new file mode 100644 index 00000000..8609a928 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_insert.hpp @@ -0,0 +1,123 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_insert.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/common/index_vector.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" + +namespace duckdb { + +class InsertLocalState; + +//! Physically insert a set of data into a table +class PhysicalInsert : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::INSERT; + +public: + //! INSERT INTO + PhysicalInsert(vector types, TableCatalogEntry &table, physical_index_vector_t column_index_map, + vector> bound_defaults, vector> set_expressions, + vector set_columns, vector set_types, idx_t estimated_cardinality, + bool return_chunk, bool parallel, OnConflictAction action_type, + unique_ptr on_conflict_condition, unique_ptr do_update_condition, + unordered_set on_conflict_filter, vector columns_to_fetch); + //! CREATE TABLE AS + PhysicalInsert(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info, + idx_t estimated_cardinality, bool parallel); + + //! The map from insert column index to table column index + physical_index_vector_t column_index_map; + //! The table to insert into + optional_ptr insert_table; + //! The insert types + vector insert_types; + //! The default expressions of the columns for which no value is provided + vector> bound_defaults; + //! If the returning statement is present, return the whole chunk + bool return_chunk; + //! Table schema, in case of CREATE TABLE AS + optional_ptr schema; + //! Create table info, in case of CREATE TABLE AS + unique_ptr info; + //! Whether or not the INSERT can be executed in parallel + //! This insert is not order preserving if executed in parallel + bool parallel; + // Which action to perform on conflict + OnConflictAction action_type; + + // The DO UPDATE set expressions, if 'action_type' is UPDATE + vector> set_expressions; + // Which columns are targeted by the set expressions + vector set_columns; + // The types of the columns targeted by a SET expression + vector set_types; + + // Condition for the ON CONFLICT clause + unique_ptr on_conflict_condition; + // Condition for the DO UPDATE clause + unique_ptr do_update_condition; + // The column ids to apply the ON CONFLICT on + unordered_set conflict_target; + + // Column ids from the original table to fetch + vector columns_to_fetch; + // Matching types to the column ids to fetch + vector types_to_fetch; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + + bool ParallelSink() const override { + return parallel; + } + + bool SinkOrderDependent() const override { + return true; + } + +public: + static void GetInsertInfo(const BoundCreateTableInfo &info, vector &insert_types, + vector> &bound_defaults); + static void ResolveDefaults(const TableCatalogEntry &table, DataChunk &chunk, + const physical_index_vector_t &column_index_map, + ExpressionExecutor &defaults_executor, DataChunk &result); + +protected: + void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_chunk, DataChunk &input_chunk, + ClientContext &client) const; + //! Returns the amount of updated tuples + void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, Vector &row_ids, + DataChunk &result) const; + idx_t OnConflictHandling(TableCatalogEntry &table, ExecutionContext &context, InsertLocalState &lstate) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_update.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_update.hpp new file mode 100644 index 00000000..0cccd98d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_update.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/persistent/physical_update.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { +class DataTable; + +//! Physically update data in a table +class PhysicalUpdate : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::UPDATE; + +public: + PhysicalUpdate(vector types, TableCatalogEntry &tableref, DataTable &table, + vector columns, vector> expressions, + vector> bound_defaults, idx_t estimated_cardinality, bool return_chunk); + + TableCatalogEntry &tableref; + DataTable &table; + vector columns; + vector> expressions; + vector> bound_defaults; + bool update_is_del_and_insert; + //! If the returning statement is present, return the whole chunk + bool return_chunk; + +public: + // Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_pivot.hpp b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_pivot.hpp new file mode 100644 index 00000000..badb073d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_pivot.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/projection/physical_pivot.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/planner/tableref/bound_pivotref.hpp" + +namespace duckdb { + +//! PhysicalPivot implements the physical PIVOT operation +class PhysicalPivot : public PhysicalOperator { +public: + PhysicalPivot(vector types, unique_ptr child, BoundPivotInfo bound_pivot); + + BoundPivotInfo bound_pivot; + //! The map for pivot value -> column index + string_map_t pivot_map; + //! The empty aggregate values + vector empty_aggregates; + +public: + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + bool ParallelOperator() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_projection.hpp b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_projection.hpp new file mode 100644 index 00000000..d2650b98 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_projection.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/projection/physical_projection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class PhysicalProjection : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PROJECTION; + +public: + PhysicalProjection(vector types, vector> select_list, + idx_t estimated_cardinality); + + vector> select_list; + +public: + unique_ptr GetOperatorState(ExecutionContext &context) const override; + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + bool ParallelOperator() const override { + return true; + } + + string ParamsToString() const override; + + static unique_ptr + CreateJoinProjection(vector proj_types, const vector &lhs_types, + const vector &rhs_types, const vector &left_projection_map, + const vector &right_projection_map, const idx_t estimated_cardinality); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp new file mode 100644 index 00000000..659e5cc4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/projection/physical_tableinout_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/function/table_function.hpp" + +namespace duckdb { + +class PhysicalTableInOutFunction : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::INOUT_FUNCTION; + +public: + PhysicalTableInOutFunction(vector types, TableFunction function_p, + unique_ptr bind_data_p, vector column_ids_p, + idx_t estimated_cardinality, vector projected_input); + +public: + unique_ptr GetOperatorState(ExecutionContext &context) const override; + unique_ptr GetGlobalOperatorState(ClientContext &context) const override; + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + OperatorFinalizeResultType FinalExecute(ExecutionContext &context, DataChunk &chunk, GlobalOperatorState &gstate, + OperatorState &state) const override; + + bool ParallelOperator() const override { + return true; + } + + bool RequiresFinalExecute() const override { + return function.in_out_function_final; + } + +private: + //! The table function + TableFunction function; + //! Bind data of the function + unique_ptr bind_data; + //! The set of column ids to fetch + vector column_ids; + //! The set of input columns to project out + vector projected_input; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_unnest.hpp b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_unnest.hpp new file mode 100644 index 00000000..19400135 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_unnest.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/projection/physical_unnest.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! PhysicalUnnest implements the physical UNNEST operation +class PhysicalUnnest : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::UNNEST; + +public: + PhysicalUnnest(vector types, vector> select_list, idx_t estimated_cardinality, + PhysicalOperatorType type = PhysicalOperatorType::UNNEST); + + //! The projection list of the UNNEST + //! E.g. SELECT 1, UNNEST([1]), UNNEST([2, 3]); has two UNNESTs in its select_list + vector> select_list; + +public: + unique_ptr GetOperatorState(ExecutionContext &context) const override; + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + bool ParallelOperator() const override { + return true; + } + +public: + static unique_ptr GetState(ExecutionContext &context, + const vector> &select_list); + //! Executes the UNNEST operator internally and emits a chunk of unnested data. If include_input is set, then + //! the resulting chunk also contains vectors for all non-UNNEST columns in the projection. If include_input is + //! not set, then the UNNEST behaves as a table function and only emits the unnested data. + static OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + OperatorState &state, const vector> &select_list, + bool include_input = true); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/base_csv_reader.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/base_csv_reader.hpp new file mode 100644 index 00000000..ea214f8a --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/base_csv_reader.hpp @@ -0,0 +1,123 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/base_csv_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/enums/file_compression_type.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/execution/operator/scan/csv/csv_line_info.hpp" + +#include + +namespace duckdb { +struct CopyInfo; +struct CSVFileHandle; +struct FileHandle; +struct StrpTimeFormat; + +class FileOpener; +class FileSystem; + +enum class ParserMode : uint8_t { PARSING = 0, SNIFFING_DATATYPES = 1, PARSING_HEADER = 2 }; + +//! Buffered CSV reader is a class that reads values from a stream and parses them as a CSV file +class BaseCSVReader { +public: + BaseCSVReader(ClientContext &context, CSVReaderOptions options, + const vector &requested_types = vector()); + virtual ~BaseCSVReader(); + + ClientContext &context; + FileSystem &fs; + Allocator &allocator; + CSVReaderOptions options; + vector return_types; + vector names; + MultiFileReaderData reader_data; + + idx_t linenr = 0; + bool linenr_estimated = false; + + bool row_empty = false; + idx_t sample_chunk_idx = 0; + bool jumping_samples = false; + bool end_of_file_reached = false; + bool bom_checked = false; + + idx_t bytes_in_chunk = 0; + double bytes_per_line_avg = 0; + + DataChunk parse_chunk; + + ParserMode mode; + +public: + const string &GetFileName() { + return options.file_path; + } + const vector &GetNames() { + return names; + } + const vector &GetTypes() { + return return_types; + } + //! Get the 1-indexed global line number for the given local error line + virtual idx_t GetLineError(idx_t line_error, idx_t buffer_idx, bool stop_at_first = true) { + return line_error + 1; + }; + + virtual void Increment(idx_t buffer_idx) { + return; + } + + //! Initialize projection indices to select all columns + void InitializeProjection(); + + static unique_ptr OpenCSV(ClientContext &context, const CSVReaderOptions &options); + + static bool TryCastDateVector(map &options, Vector &input_vector, + Vector &result_vector, idx_t count, string &error_message, idx_t &line_error); + + static bool TryCastTimestampVector(map &options, Vector &input_vector, + Vector &result_vector, idx_t count, string &error_message); + +protected: + //! Initializes the parse_chunk with varchar columns and aligns info with new number of cols + void InitParseChunk(idx_t num_cols); + //! Adds a value to the current row + void AddValue(string_t str_val, idx_t &column, vector &escape_positions, bool has_quotes, + idx_t buffer_idx = 0); + //! Adds a row to the insert_chunk, returns true if the chunk is filled as a result of this row being added + bool AddRow(DataChunk &insert_chunk, idx_t &column, string &error_message, idx_t buffer_idx = 0); + //! Finalizes a chunk, parsing all values that have been added so far and adding them to the insert_chunk + bool Flush(DataChunk &insert_chunk, idx_t buffer_idx = 0, bool try_add_line = false); + + void VerifyUTF8(idx_t col_idx); + void VerifyUTF8(idx_t col_idx, idx_t row_idx, DataChunk &chunk, int64_t offset = 0); + string GetLineNumberStr(idx_t linenr, bool linenr_estimated, idx_t buffer_idx = 0); + + //! Sets the newline delimiter + void SetNewLineDelimiter(bool carry = false, bool carry_followed_by_nl = false); + + //! Verifies that the line length did not go over a pre-defined limit. + void VerifyLineLength(idx_t line_size, idx_t buffer_idx = 0); + +protected: + //! Whether or not the current row's columns have overflown return_types.size() + bool error_column_overflow = false; + //! Number of sniffed columns - only used when auto-detecting +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp new file mode 100644 index 00000000..0dc15677 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp" +#include "duckdb/execution/operator/scan/csv/base_csv_reader.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp" + +namespace duckdb { +struct CopyInfo; +struct CSVFileHandle; +struct FileHandle; +struct StrpTimeFormat; + +class FileOpener; +class FileSystem; + +//! Buffered CSV reader is a class that reads values from a stream and parses them as a CSV file +class BufferedCSVReader : public BaseCSVReader { + //! Initial buffer read size; can be extended for long lines + static constexpr idx_t INITIAL_BUFFER_SIZE = 16384; + //! Larger buffer size for non disk files + static constexpr idx_t INITIAL_BUFFER_SIZE_LARGE = 10000000; // 10MB + +public: + BufferedCSVReader(ClientContext &context, CSVReaderOptions options, + const vector &requested_types = vector()); + BufferedCSVReader(ClientContext &context, string filename, CSVReaderOptions options, + const vector &requested_types = vector()); + virtual ~BufferedCSVReader() { + } + + unsafe_unique_array buffer; + idx_t buffer_size; + idx_t position; + idx_t start = 0; + + vector> cached_buffers; + + unique_ptr file_handle; + //! CSV State Machine Cache + CSVStateMachineCache state_machine_cache; + +public: + //! Extract a single DataChunk from the CSV file and stores it in insert_chunk + void ParseCSV(DataChunk &insert_chunk); + static string ColumnTypesError(case_insensitive_map_t sql_types_per_column, const vector &names); + +private: + //! Initialize Parser + void Initialize(const vector &requested_types); + //! Skips skip_rows, reads header row from input stream + void SkipRowsAndReadHeader(idx_t skip_rows, bool skip_header); + //! Resets the buffer + void ResetBuffer(); + //! Reads a new buffer from the CSV file if the current one has been exhausted + bool ReadBuffer(idx_t &start, idx_t &line_start); + //! Try to parse a single datachunk from the file. Throws an exception if anything goes wrong. + void ParseCSV(ParserMode mode); + //! Extract a single DataChunk from the CSV file and stores it in insert_chunk + bool TryParseCSV(ParserMode mode, DataChunk &insert_chunk, string &error_message); + //! Skip Empty lines for tables with over one column + void SkipEmptyLines(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_buffer.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_buffer.hpp new file mode 100644 index 00000000..c49168e2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_buffer.hpp @@ -0,0 +1,110 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/execution/operator/scan/csv/csv_file_handle.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" + +namespace duckdb { + +class CSVBufferHandle { +public: + CSVBufferHandle(BufferHandle handle_p, idx_t actual_size_p, const bool is_first_buffer_p, + const bool is_final_buffer_p, idx_t csv_global_state_p, idx_t start_position_p, idx_t file_idx_p) + : handle(std::move(handle_p)), actual_size(actual_size_p), is_first_buffer(is_first_buffer_p), + is_last_buffer(is_final_buffer_p), csv_global_start(csv_global_state_p), start_position(start_position_p), + file_idx(file_idx_p) {}; + CSVBufferHandle() + : actual_size(0), is_first_buffer(false), is_last_buffer(false), csv_global_start(0), start_position(0), + file_idx(0) {}; + //! Handle created during allocation + BufferHandle handle; + const idx_t actual_size; + const bool is_first_buffer; + const bool is_last_buffer; + const idx_t csv_global_start; + const idx_t start_position; + const idx_t file_idx; + inline char *Ptr() { + return char_ptr_cast(handle.Ptr()); + } +}; + +//! CSV Buffers are parts of a decompressed CSV File. +//! For a decompressed file of 100Mb. With our Buffer size set to 32Mb, we would generate 4 buffers. +//! One for the first 32Mb, second and third for the other 32Mb, and the last one with 4 Mb +//! These buffers are actually used for sniffing and parsing! +class CSVBuffer { +public: + //! Constructor for Initial Buffer + CSVBuffer(ClientContext &context, idx_t buffer_size_p, CSVFileHandle &file_handle, + idx_t &global_csv_current_position, idx_t file_number); + + //! Constructor for `Next()` Buffers + CSVBuffer(CSVFileHandle &file_handle, ClientContext &context, idx_t buffer_size, idx_t global_csv_current_position, + idx_t file_number_p); + + //! Creates a new buffer with the next part of the CSV File + shared_ptr Next(CSVFileHandle &file_handle, idx_t buffer_size, idx_t file_number); + + //! Gets the buffer actual size + idx_t GetBufferSize(); + + //! Gets the start position of the buffer, only relevant for the first time it's scanned + idx_t GetStart(); + + //! If this buffer is the last buffer of the CSV File + bool IsCSVFileLastBuffer(); + + //! Allocates internal buffer, sets 'block' and 'handle' variables. + void AllocateBuffer(idx_t buffer_size); + + void Reload(CSVFileHandle &file_handle); + //! Wrapper for the Pin Function, if it can seek, it means that the buffer might have been destroyed, hence we must + //! Scan it from the disk file again. + unique_ptr Pin(CSVFileHandle &file_handle); + //! Wrapper for the unpin + void Unpin(); + char *Ptr() { + return char_ptr_cast(handle.Ptr()); + } + + static constexpr idx_t CSV_BUFFER_SIZE = 32000000; // 32MB + //! In case the file has a size < 32MB, we will use this size instead + //! This is to avoid mallocing a lot of memory for a small file + //! And if it's a compressed file we can't use the actual size of the file + static constexpr idx_t CSV_MINIMUM_BUFFER_SIZE = 10000000; // 10MB + //! If this is the last buffer of the CSV File + bool last_buffer = false; + +private: + ClientContext &context; + //! Actual size can be smaller than the buffer size in case we allocate it too optimistically. + idx_t actual_buffer_size; + //! We need to check for Byte Order Mark, to define the start position of this buffer + //! https://en.wikipedia.org/wiki/Byte_order_mark#UTF-8 + idx_t start_position = 0; + //! If this is the first buffer of the CSV File + bool first_buffer = false; + //! Global position from the CSV File where this buffer starts + idx_t global_csv_start = 0; + //! Number of the file that is in this buffer + idx_t file_number = 0; + //! If we can seek in the file or not. + //! If we can't seek, this means we can't destroy the buffers + bool can_seek; + //! -------- Allocated Block ---------// + //! Block created in allocation + shared_ptr block; + BufferHandle handle; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp new file mode 100644 index 00000000..169ac0be --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp @@ -0,0 +1,103 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/client_context.hpp" +#include "duckdb/execution/operator/scan/csv/csv_file_handle.hpp" +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" + +namespace duckdb { +class CSVBuffer; +class CSVStateMachine; + +//! This class is used to manage the CSV buffers. Buffers are cached when used for auto detection. +//! When parsing, buffer are not cached and just returned. +class CSVBufferManager { +public: + CSVBufferManager(ClientContext &context, unique_ptr file_handle, const CSVReaderOptions &options, + idx_t file_idx = 0); + //! Returns a buffer from a buffer id (starting from 0). If it's in the auto-detection then we cache new buffers + //! Otherwise we remove them from the cache if they are already there, or just return them bypassing the cache. + unique_ptr GetBuffer(const idx_t pos); + //! Returns the starting position of the first buffer + idx_t GetStartPos(); + //! unique_ptr to the file handle, gets stolen after sniffing + unique_ptr file_handle; + //! Initializes the buffer manager, during it's construction/reset + void Initialize(); + + void UnpinBuffer(idx_t cache_idx); + + ClientContext &context; + idx_t skip_rows = 0; + idx_t file_idx; + bool done = false; + +private: + //! Reads next buffer in reference to cached_buffers.front() + bool ReadNextAndCacheIt(); + vector> cached_buffers; + shared_ptr last_buffer; + idx_t global_csv_pos = 0; + //! The size of the buffer, if the csv file has a smaller size than this, we will use that instead to malloc less + idx_t buffer_size; + //! Starting position of first buffer + idx_t start_pos = 0; +}; + +class CSVBufferIterator { +public: + explicit CSVBufferIterator(shared_ptr buffer_manager_p) + : buffer_manager(std::move(buffer_manager_p)) { + cur_pos = buffer_manager->GetStartPos(); + }; + + //! This functions templates an operation over the CSV File + template + inline bool Process(CSVStateMachine &machine, T &result) { + + OP::Initialize(machine); + //! If current buffer is not set we try to get a new one + if (!cur_buffer_handle) { + cur_pos = 0; + if (cur_buffer_idx == 0) { + cur_pos = buffer_manager->GetStartPos(); + } + cur_buffer_handle = buffer_manager->GetBuffer(cur_buffer_idx++); + D_ASSERT(cur_buffer_handle); + } + while (cur_buffer_handle) { + char *buffer_handle_ptr = cur_buffer_handle->Ptr(); + while (cur_pos < cur_buffer_handle->actual_size) { + if (OP::Process(machine, result, buffer_handle_ptr[cur_pos], cur_pos)) { + //! Not-Done Processing the File, but the Operator is happy! + OP::Finalize(machine, result); + return false; + } + cur_pos++; + } + cur_buffer_handle = buffer_manager->GetBuffer(cur_buffer_idx++); + cur_pos = 0; + } + //! Done Processing the File + OP::Finalize(machine, result); + return true; + } + //! Returns true if the iterator is finished + bool Finished(); + //! Resets the iterator + void Reset(); + +private: + idx_t cur_pos = 0; + idx_t cur_buffer_idx = 0; + shared_ptr buffer_manager; + unique_ptr cur_buffer_handle; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_file_handle.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_file_handle.hpp new file mode 100644 index 00000000..c27538fe --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_file_handle.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_file_handle.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/allocator.hpp" + +namespace duckdb { +class Allocator; +class FileSystem; + +struct CSVFileHandle { +public: + CSVFileHandle(FileSystem &fs, Allocator &allocator, unique_ptr file_handle_p, const string &path_p, + FileCompressionType compression); + + mutex main_mutex; + +public: + bool CanSeek(); + void Seek(idx_t position); + bool OnDiskFile(); + + idx_t FileSize(); + + bool FinishedReading(); + + idx_t Read(void *buffer, idx_t nr_bytes); + + string ReadLine(); + + string GetFilePath(); + + static unique_ptr OpenFileHandle(FileSystem &fs, Allocator &allocator, const string &path, + FileCompressionType compression); + static unique_ptr OpenFile(FileSystem &fs, Allocator &allocator, const string &path, + FileCompressionType compression); + +private: + unique_ptr file_handle; + string path; + bool can_seek = false; + bool on_disk_file = false; + idx_t file_size = 0; + + idx_t requested_bytes = 0; + //! If we finished reading the file + bool finished = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_line_info.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_line_info.hpp new file mode 100644 index 00000000..b5713283 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_line_info.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_line_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { +struct LineInfo { +public: + explicit LineInfo(mutex &main_mutex_p, vector> &batch_to_tuple_end_p, + vector> &tuple_start_p, vector> &tuple_end_p) + : main_mutex(main_mutex_p), batch_to_tuple_end(batch_to_tuple_end_p), tuple_start(tuple_start_p), + tuple_end(tuple_end_p) {}; + bool CanItGetLine(idx_t file_idx, idx_t batch_idx); + + //! Return the 1-indexed line number + idx_t GetLine(idx_t batch_idx, idx_t line_error = 0, idx_t file_idx = 0, idx_t cur_start = 0, bool verify = true, + bool stop_at_first = true); + //! In case an error happened we have to increment the lines read of that batch + void Increment(idx_t file_idx, idx_t batch_idx); + //! Verify if the CSV File was read correctly from [0,batch_idx] batches. + void Verify(idx_t file_idx, idx_t batch_idx, idx_t cur_first_pos); + //! Lines read per batch, > + vector> lines_read; + //! Lines read per batch, > + vector> lines_errored; + //! Set of batches that have been initialized but are not yet finished. + vector> current_batches; + //! Pointer to CSV Reader Mutex + mutex &main_mutex; + //! Pointer Batch to Tuple End + vector> &batch_to_tuple_end; + //! Pointer Batch to Tuple Start + vector> &tuple_start; + //! Pointer Batch to Tuple End + vector> &tuple_end; + //! If we already threw an exception on a previous thread. + bool done = false; + idx_t first_line = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_reader_options.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_reader_options.hpp new file mode 100644 index 00000000..b002c468 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_reader_options.hpp @@ -0,0 +1,210 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_reader_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/csv_buffer.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/multi_file_reader_options.hpp" + +namespace duckdb { + +enum class NewLineIdentifier : uint8_t { + SINGLE = 1, // Either \r or \n + CARRY_ON = 2, // \r\n + MIX = 3, // Hippie-Land, can't run it multithreaded + NOT_SET = 4 +}; + +enum class ParallelMode { AUTOMATIC = 0, PARALLEL = 1, SINGLE_THREADED = 2 }; + +//! Struct that holds the configuration of a CSV State Machine +//! Basically which char, quote and escape were used to generate it. +struct CSVStateMachineOptions { + CSVStateMachineOptions() {}; + CSVStateMachineOptions(char delimiter_p, char quote_p, char escape_p) + : delimiter(delimiter_p), quote(quote_p), escape(escape_p) {}; + + //! Delimiter to separate columns within each line + char delimiter = ','; + //! Quote used for columns that contain reserved characters, e.g ' + char quote = '\"'; + //! Escape character to escape quote character + char escape = '\0'; + + bool operator==(const CSVStateMachineOptions &other) const { + return delimiter == other.delimiter && quote == other.quote && escape == other.escape; + } +}; + +struct DialectOptions { + CSVStateMachineOptions state_machine_options; + //! New Line separator + NewLineIdentifier new_line = NewLineIdentifier::NOT_SET; + //! Expected number of columns + idx_t num_cols = 0; + //! Whether or not the file has a header line + bool header = false; + //! The date format to use (if any is specified) + map date_format = {{LogicalTypeId::DATE, {}}, {LogicalTypeId::TIMESTAMP, {}}}; + //! Whether or not a type format is specified + map has_format = {{LogicalTypeId::DATE, false}, {LogicalTypeId::TIMESTAMP, false}}; + //! How many leading rows to skip + idx_t skip_rows = 0; + //! True start of the first CSV Buffer (After skipping empty lines, headers, notes and so on) + idx_t true_start = 0; +}; + +struct CSVReaderOptions { + //===--------------------------------------------------------------------===// + // CommonCSVOptions + //===--------------------------------------------------------------------===// + //! See struct above. + DialectOptions dialect_options; + //! Whether or not a delimiter was defined by the user + bool has_delimiter = false; + //! Whether or not a new_line was defined by the user + bool has_newline = false; + //! Whether or not a quote was defined by the user + bool has_quote = false; + //! Whether or not an escape character was defined by the user + bool has_escape = false; + //! Whether or not a header information was given by the user + bool has_header = false; + //! Whether or not we should ignore InvalidInput errors + bool ignore_errors = false; + //! Rejects table name + string rejects_table_name; + //! Rejects table entry limit (0 = no limit) + idx_t rejects_limit = 0; + //! Columns to use as recovery key for rejected rows when reading with ignore_errors = true + vector rejects_recovery_columns; + //! Index of the recovery columns + vector rejects_recovery_column_ids; + //! Number of samples to buffer + idx_t buffer_sample_size = STANDARD_VECTOR_SIZE * 50; + //! Specifies the string that represents a null value + string null_str; + //! Whether file is compressed or not, and if so which compression type + //! AUTO_DETECT (default; infer from file extension) + FileCompressionType compression = FileCompressionType::AUTO_DETECT; + //! Option to convert quoted values to NULL values + bool allow_quoted_nulls = true; + + //===--------------------------------------------------------------------===// + // CSVAutoOptions + //===--------------------------------------------------------------------===// + //! SQL Type list mapping of name to SQL type index in sql_type_list + case_insensitive_map_t sql_types_per_column; + //! User-defined SQL type list + vector sql_type_list; + //! User-defined name list + vector name_list; + //! Types considered as candidates for auto detection ordered by descending specificity (~ from high to low) + vector auto_type_candidates = {LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::DATE, + LogicalType::TIME, LogicalType::DOUBLE, LogicalType::BIGINT, + LogicalType::BOOLEAN, LogicalType::SQLNULL}; + + //===--------------------------------------------------------------------===// + // ReadCSVOptions + //===--------------------------------------------------------------------===// + //! Whether or not the skip_rows is set by the user + bool skip_rows_set = false; + //! Maximum CSV line size: specified because if we reach this amount, we likely have wrong delimiters (default: 2MB) + //! note that this is the guaranteed line length that will succeed, longer lines may be accepted if slightly above + idx_t maximum_line_size = 2097152; + //! Whether or not header names shall be normalized + bool normalize_names = false; + //! True, if column with that index must skip null check + vector force_not_null; + //! Number of sample chunks used in auto-detection + idx_t sample_size_chunks = 20480 / STANDARD_VECTOR_SIZE; + //! Consider all columns to be of type varchar + bool all_varchar = false; + //! Whether or not to automatically detect dialect and datatypes + bool auto_detect = false; + //! The file path of the CSV file to read + string file_path; + //! Multi-file reader options + MultiFileReaderOptions file_options; + //! Buffer Size (Parallel Scan) + idx_t buffer_size = CSVBuffer::CSV_BUFFER_SIZE; + //! Decimal separator when reading as numeric + string decimal_separator = "."; + //! Whether or not to pad rows that do not have enough columns with NULL values + bool null_padding = false; + + //! If we are running the parallel version of the CSV Reader. In general, the system should always auto-detect + //! When it can't execute a parallel run before execution. However, there are (rather specific) situations where + //! setting up this manually might be important + ParallelMode parallel_mode; + //===--------------------------------------------------------------------===// + // WriteCSVOptions + //===--------------------------------------------------------------------===// + //! True, if column with that index must be quoted + vector force_quote; + //! Prefix/suffix/custom newline the entire file once (enables writing of files as JSON arrays) + string prefix; + string suffix; + string write_newline; + + //! The date format to use (if any is specified) + map date_format = {{LogicalTypeId::DATE, {}}, {LogicalTypeId::TIMESTAMP, {}}}; + //! The date format to use for writing (if any is specified) + map write_date_format = {{LogicalTypeId::DATE, {}}, {LogicalTypeId::TIMESTAMP, {}}}; + //! Whether or not a type format is specified + map has_format = {{LogicalTypeId::DATE, false}, {LogicalTypeId::TIMESTAMP, false}}; + + void Serialize(Serializer &serializer) const; + static CSVReaderOptions Deserialize(Deserializer &deserializer); + + void SetCompression(const string &compression); + + bool GetHeader() const; + void SetHeader(bool has_header); + + string GetEscape() const; + void SetEscape(const string &escape); + + int64_t GetSkipRows() const; + void SetSkipRows(int64_t rows); + + string GetQuote() const; + void SetQuote(const string "e); + void SetDelimiter(const string &delimiter); + string GetDelimiter() const; + + NewLineIdentifier GetNewline() const; + void SetNewline(const string &input); + //! Set an option that is supported by both reading and writing functions, called by + //! the SetReadOption and SetWriteOption methods + bool SetBaseOption(const string &loption, const Value &value); + + //! loption - lowercase string + //! set - argument(s) to the option + //! expected_names - names expected if the option is "columns" + void SetReadOption(const string &loption, const Value &value, vector &expected_names); + void SetWriteOption(const string &loption, const Value &value); + void SetDateFormat(LogicalTypeId type, const string &format, bool read_format); + void ToNamedParameters(named_parameter_map_t &out); + void FromNamedParameters(named_parameter_map_t &in, ClientContext &context, vector &return_types, + vector &names); + + string ToString() const; + + named_parameter_map_t OutputReadSettings(); + +public: + //! Whether columns were explicitly provided through named parameters + bool explicitly_set_columns = false; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_sniffer.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_sniffer.hpp new file mode 100644 index 00000000..01ed8560 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_sniffer.hpp @@ -0,0 +1,129 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_sniffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/csv_state_machine.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/execution/operator/scan/csv/quote_rules.hpp" + +namespace duckdb { +//! Struct to store the result of the Sniffer +struct SnifferResult { + SnifferResult(vector return_types_p, vector names_p) + : return_types(std::move(return_types_p)), names(std::move(names_p)) { + } + //! Return Types that were detected + vector return_types; + //! Column Names that were detected + vector names; +}; + +//! Sniffer that detects Header, Dialect and Types of CSV Files +class CSVSniffer { +public: + explicit CSVSniffer(CSVReaderOptions &options_p, shared_ptr buffer_manager_p, + CSVStateMachineCache &state_machine_cache, bool explicit_set_columns = false); + + //! Main method that sniffs the CSV file, returns the types, names and options as a result + //! CSV Sniffing consists of five steps: + //! 1. Dialect Detection: Generate the CSV Options (delimiter, quote, escape, etc.) + //! 2. Type Detection: Figures out the types of the columns (For one chunk) + //! 3. Header Detection: Figures out if the CSV file has a header and produces the names of the columns + //! 4. Type Replacement: Replaces the types of the columns if the user specified them + //! 5. Type Refinement: Refines the types of the columns for the remaining chunks + SnifferResult SniffCSV(); + +private: + //! CSV State Machine Cache + CSVStateMachineCache &state_machine_cache; + //! Highest number of columns found + idx_t max_columns_found = 0; + //! Current Candidates being considered + vector> candidates; + //! Reference to original CSV Options, it will be modified as a result of the sniffer. + CSVReaderOptions &options; + //! Buffer being used on sniffer + shared_ptr buffer_manager; + + //! ------------------------------------------------------// + //! ----------------- Dialect Detection ----------------- // + //! ------------------------------------------------------// + //! First phase of auto detection: detect CSV dialect (i.e. delimiter, quote rules, etc) + void DetectDialect(); + //! Functions called in the main DetectDialect(); function + //! 1. Generates the search space candidates for the dialect + void GenerateCandidateDetectionSearchSpace(vector &delim_candidates, vector "erule_candidates, + unordered_map> "e_candidates_map, + unordered_map> &escape_candidates_map); + //! 2. Generates the search space candidates for the state machines + void GenerateStateMachineSearchSpace(vector> &csv_state_machines, + const vector &delimiter_candidates, + const vector "erule_candidates, + const unordered_map> "e_candidates_map, + const unordered_map> &escape_candidates_map); + //! 3. Analyzes if dialect candidate is a good candidate to be considered, if so, it adds it to the candidates + void AnalyzeDialectCandidate(unique_ptr, idx_t &rows_read, idx_t &best_consistent_rows, + idx_t &prev_padding_count); + //! 4. Refine Candidates over remaining chunks + void RefineCandidates(); + //! Checks if candidate still produces good values for the next chunk + bool RefineCandidateNextChunk(CSVStateMachine &candidate); + + //! ------------------------------------------------------// + //! ------------------- Type Detection ------------------ // + //! ------------------------------------------------------// + //! Second phase of auto detection: detect types, format template candidates + //! ordered by descending specificity (~ from high to low) + void DetectTypes(); + //! Change the date format for the type to the string + //! Try to cast a string value to the specified sql type + bool TryCastValue(CSVStateMachine &candidate, const Value &value, const LogicalType &sql_type); + void SetDateFormat(CSVStateMachine &candidate, const string &format_specifier, const LogicalTypeId &sql_type); + //! Functions that performs detection for date and timestamp formats + void DetectDateAndTimeStampFormats(CSVStateMachine &candidate, map &has_format_candidates, + map> &format_candidates, + const LogicalType &sql_type, const string &separator, Value &dummy_val); + + //! Variables for Type Detection + //! Format Candidates for Date and Timestamp Types + const map> format_template_candidates = { + {LogicalTypeId::DATE, {"%m-%d-%Y", "%m-%d-%y", "%d-%m-%Y", "%d-%m-%y", "%Y-%m-%d", "%y-%m-%d"}}, + {LogicalTypeId::TIMESTAMP, + {"%Y-%m-%d %H:%M:%S.%f", "%m-%d-%Y %I:%M:%S %p", "%m-%d-%y %I:%M:%S %p", "%d-%m-%Y %H:%M:%S", + "%d-%m-%y %H:%M:%S", "%Y-%m-%d %H:%M:%S", "%y-%m-%d %H:%M:%S"}}, + }; + unordered_map> best_sql_types_candidates_per_column_idx; + map> best_format_candidates; + unique_ptr best_candidate; + idx_t best_start_with_header = 0; + idx_t best_start_without_header = 0; + vector best_header_row; + + //! ------------------------------------------------------// + //! ------------------ Header Detection ----------------- // + //! ------------------------------------------------------// + void DetectHeader(); + vector names; + //! If Column Names and Types have been explicitly set + const bool explicit_set_columns; + + //! ------------------------------------------------------// + //! ------------------ Type Replacement ----------------- // + //! ------------------------------------------------------// + void ReplaceTypes(); + + //! ------------------------------------------------------// + //! ------------------ Type Refinement ------------------ // + //! ------------------------------------------------------// + void RefineTypes(); + bool TryCastVector(Vector &parse_chunk_col, idx_t size, const LogicalType &sql_type); + vector detected_types; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_state_machine.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_state_machine.hpp new file mode 100644 index 00000000..b4d82c96 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_state_machine.hpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_state_machine.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp" + +namespace duckdb { + +//! All States of CSV Parsing +enum class CSVState : uint8_t { + STANDARD = 0, //! Regular unquoted field state + DELIMITER = 1, //! State after encountering a field separator (e.g., ;) + RECORD_SEPARATOR = 2, //! State after encountering a record separator (i.e., \n) + CARRIAGE_RETURN = 3, //! State after encountering a carriage return(i.e., \r) + QUOTED = 4, //! State when inside a quoted field + UNQUOTED = 5, //! State when leaving a quoted field + ESCAPE = 6, //! State when encountering an escape character (e.g., \) + EMPTY_LINE = 7, //! State when encountering an empty line (i.e., \r\r \n\n, \n\r) + INVALID = 8 //! Got to an Invalid State, this should error. +}; + +//! The CSV State Machine comprises a state transition array (STA). +//! The STA indicates the current state of parsing based on both the current and preceding characters. +//! This reveals whether we are dealing with a Field, a New Line, a Delimiter, and so forth. +//! The STA's creation depends on the provided quote, character, and delimiter options for that state machine. +//! The motivation behind implementing an STA is to remove branching in regular CSV Parsing by predicting and detecting +//! the states. Note: The State Machine is currently utilized solely in the CSV Sniffer. +class CSVStateMachine { +public: + explicit CSVStateMachine(CSVReaderOptions &options_p, const CSVStateMachineOptions &state_machine_options, + shared_ptr buffer_manager_p, + CSVStateMachineCache &csv_state_machine_cache_p); + //! Resets the state machine, so it can be used again + void Reset(); + + //! Aux Function for string UTF8 Verification + void VerifyUTF8(); + + CSVStateMachineCache &csv_state_machine_cache; + + const CSVReaderOptions &options; + CSVBufferIterator csv_buffer_iterator; + //! Stores identified start row for this file (e.g., a file can start with garbage like notes, before the header) + idx_t start_row = 0; + //! The Transition Array is a Finite State Machine + //! It holds the transitions of all states, on all 256 possible different characters + const state_machine_t &transition_array; + + //! Both these variables are used for new line identifier detection + bool single_record_separator = false; + bool carry_on_separator = false; + + //! Variables Used for Sniffing + CSVState state; + CSVState previous_state; + CSVState pre_previous_state; + idx_t cur_rows; + idx_t column_count; + string value; + idx_t rows_read; + idx_t line_start_pos = 0; + + //! Dialect options resulting from sniffing + DialectOptions dialect_options; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp new file mode 100644 index 00000000..f63024cb --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp" +#include "duckdb/execution/operator/scan/csv/quote_rules.hpp" + +namespace duckdb { +static constexpr uint32_t NUM_STATES = 9; +static constexpr uint32_t NUM_TRANSITIONS = 256; +typedef uint8_t state_machine_t[NUM_STATES][NUM_TRANSITIONS]; + +//! Hash function used in out state machine cache, it hashes and combines all options used to generate a state machine +struct HashCSVStateMachineConfig { + size_t operator()(CSVStateMachineOptions const &config) const noexcept { + auto h_delimiter = Hash(config.delimiter); + auto h_quote = Hash(config.quote); + auto h_escape = Hash(config.escape); + return CombineHash(h_delimiter, CombineHash(h_quote, h_escape)); + } +}; + +//! The CSVStateMachineCache caches state machines, although small ~2kb, the actual creation of multiple State Machines +//! can become a bottleneck on sniffing, when reading very small csv files. +//! Hence the cache stores State Machines based on their different delimiter|quote|escape options. +class CSVStateMachineCache { +public: + CSVStateMachineCache(); + ~CSVStateMachineCache() {}; + //! Gets a state machine from the cache, if it's not from one the default options + //! It first caches it, then returns it. + const state_machine_t &Get(const CSVStateMachineOptions &state_machine_options); + +private: + void Insert(const CSVStateMachineOptions &state_machine_options); + //! Cache on delimiter|quote|escape + unordered_map state_machine_cache; + //! Default value for options used to intialize CSV State Machine Cache + const vector default_delimiter = {',', '|', ';', '\t'}; + const vector> default_quote = {{'\"'}, {'\"', '\''}, {'\0'}}; + const vector default_quote_rule = {QuoteRule::QUOTES_RFC, QuoteRule::QUOTES_OTHER, QuoteRule::NO_QUOTES}; + const vector> default_escape = {{'\0', '\"', '\''}, {'\\'}, {'\0'}}; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp new file mode 100644 index 00000000..511df229 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp @@ -0,0 +1,167 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/base_csv_reader.hpp" +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/execution/operator/scan/csv/csv_file_handle.hpp" +#include "duckdb/execution/operator/scan/csv/csv_buffer.hpp" +#include "duckdb/execution/operator/scan/csv/csv_line_info.hpp" + +#include +#include + +namespace duckdb { + +struct CSVBufferRead { + CSVBufferRead(unique_ptr buffer_p, idx_t buffer_start_p, idx_t buffer_end_p, idx_t batch_index, + idx_t local_batch_index_p, optional_ptr line_info_p) + : buffer(std::move(buffer_p)), line_info(line_info_p), buffer_start(buffer_start_p), buffer_end(buffer_end_p), + batch_index(batch_index), local_batch_index(local_batch_index_p) { + D_ASSERT(buffer); + if (buffer_end > buffer->actual_size) { + buffer_end = buffer->actual_size; + } + } + + CSVBufferRead(unique_ptr buffer_p, unique_ptr nxt_buffer_p, idx_t buffer_start_p, + idx_t buffer_end_p, idx_t batch_index, idx_t local_batch_index, optional_ptr line_info_p) + : CSVBufferRead(std::move(buffer_p), buffer_start_p, buffer_end_p, batch_index, local_batch_index, + line_info_p) { + next_buffer = std::move(nxt_buffer_p); + } + + CSVBufferRead() : buffer_start(0), buffer_end(NumericLimits::Maximum()) {}; + + const char &operator[](size_t i) const { + if (i < buffer->actual_size) { + auto buffer_ptr = buffer->Ptr(); + return buffer_ptr[i]; + } + auto next_ptr = next_buffer->Ptr(); + return next_ptr[i - buffer->actual_size]; + } + + string_t GetValue(idx_t start_buffer, idx_t position_buffer, idx_t offset) { + idx_t length = position_buffer - start_buffer - offset; + // 1) It's all in the current buffer + if (start_buffer + length <= buffer->actual_size) { + auto buffer_ptr = buffer->Ptr(); + return string_t(buffer_ptr + start_buffer, length); + } else if (start_buffer >= buffer->actual_size) { + // 2) It's all in the next buffer + D_ASSERT(next_buffer); + D_ASSERT(next_buffer->actual_size >= length + (start_buffer - buffer->actual_size)); + auto buffer_ptr = next_buffer->Ptr(); + return string_t(buffer_ptr + (start_buffer - buffer->actual_size), length); + } else { + // 3) It starts in the current buffer and ends in the next buffer + D_ASSERT(next_buffer); + auto intersection = make_unsafe_uniq_array(length); + idx_t cur_pos = 0; + auto buffer_ptr = buffer->Ptr(); + for (idx_t i = start_buffer; i < buffer->actual_size; i++) { + intersection[cur_pos++] = buffer_ptr[i]; + } + idx_t nxt_buffer_pos = 0; + auto next_buffer_ptr = next_buffer->Ptr(); + for (; cur_pos < length; cur_pos++) { + intersection[cur_pos] = next_buffer_ptr[nxt_buffer_pos++]; + } + intersections.emplace_back(std::move(intersection)); + return string_t(intersections.back().get(), length); + } + } + + unique_ptr buffer; + unique_ptr next_buffer; + vector> intersections; + optional_ptr line_info; + + idx_t buffer_start; + idx_t buffer_end; + idx_t batch_index; + idx_t local_batch_index; + idx_t lines_read = 0; +}; + +struct VerificationPositions { + idx_t beginning_of_first_line = 0; + idx_t end_of_last_line = 0; +}; + +//! CSV Reader for Parallel Reading +class ParallelCSVReader : public BaseCSVReader { +public: + ParallelCSVReader(ClientContext &context, CSVReaderOptions options, unique_ptr buffer, + idx_t first_pos_first_buffer, const vector &requested_types, idx_t file_idx_p); + virtual ~ParallelCSVReader() { + } + + //! Current Position (Relative to the Buffer) + idx_t position_buffer = 0; + + //! Start of the piece of the buffer this thread should read + idx_t start_buffer = 0; + //! End of the piece of this buffer this thread should read + idx_t end_buffer = NumericLimits::Maximum(); + //! The actual buffer size + idx_t buffer_size = 0; + + //! If this flag is set, it means we are about to try to read our last row. + bool reached_remainder_state = false; + + bool finished = false; + + unique_ptr buffer; + + idx_t file_idx; + + VerificationPositions GetVerificationPositions(); + + //! Position of the first read line and last read line for verification purposes + VerificationPositions verification_positions; + +public: + void SetBufferRead(unique_ptr buffer); + //! Extract a single DataChunk from the CSV file and stores it in insert_chunk + void ParseCSV(DataChunk &insert_chunk); + + idx_t GetLineError(idx_t line_error, idx_t buffer_idx, bool stop_at_first = true) override; + void Increment(idx_t buffer_idx) override; + +private: + //! Initialize Parser + void Initialize(const vector &requested_types); + //! Try to parse a single datachunk from the file. Throws an exception if anything goes wrong. + void ParseCSV(ParserMode mode); + //! Try to parse a single datachunk from the file. Returns whether or not the parsing is successful + bool TryParseCSV(ParserMode mode); + //! Extract a single DataChunk from the CSV file and stores it in insert_chunk + bool TryParseCSV(ParserMode mode, DataChunk &insert_chunk, string &error_message); + //! Sets Position depending on the byte_start of this thread + bool SetPosition(); + //! Called when scanning the 1st buffer, skips empty lines + void SkipEmptyLines(); + //! When a buffer finishes reading its piece, it still can try to scan up to the real end of the buffer + //! Up to finding a new line. This function sets the buffer_end and marks a boolean variable + //! when changing the buffer end the first time. + //! It returns FALSE if the parser should jump to the final state of parsing or not + bool BufferRemainder(); + + bool NewLineDelimiter(bool carry, bool carry_followed_by_nl, bool first_char); + + //! Parses a CSV file with a one-byte delimiter, escape and quote character + bool TryParseSimpleCSV(DataChunk &insert_chunk, string &error_message, bool try_add_line = false); + + //! First Position of First Buffer + idx_t first_pos_first_buffer = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/csv/quote_rules.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/quote_rules.hpp new file mode 100644 index 00000000..4dc76717 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/csv/quote_rules.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/csv/quote_rules.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/vector.hpp" + +namespace duckdb { +//! Different Rules regarding possible combinations of Quote and Escape Values for CSV Dialects. +//! Each rule has a comment on the possible combinations. +enum class QuoteRule : uint8_t { + QUOTES_RFC = 0, //! quote = " escape = (\0 || " || ') + QUOTES_OTHER = 1, //! quote = ( " || ' ) escape = '\\' + NO_QUOTES = 2 //! quote = \0 escape = \0 +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_column_data_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_column_data_scan.hpp new file mode 100644 index 00000000..17dc8edb --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_column_data_scan.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/physical_column_data_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +//! The PhysicalColumnDataScan scans a ColumnDataCollection +class PhysicalColumnDataScan : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::COLUMN_DATA_SCAN; + +public: + PhysicalColumnDataScan(vector types, PhysicalOperatorType op_type, idx_t estimated_cardinality, + unique_ptr owned_collection = nullptr); + + PhysicalColumnDataScan(vector types, PhysicalOperatorType op_type, idx_t estimated_cardinality, + idx_t cte_index) + : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(nullptr), + cte_index(cte_index) { + } + + // the column data collection to scan + optional_ptr collection; + //! Owned column data collection, if any + unique_ptr owned_collection; + + idx_t cte_index; + +public: + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + + string ParamsToString() const override; + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_dummy_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_dummy_scan.hpp new file mode 100644 index 00000000..3e20636e --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_dummy_scan.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/physical_dummy_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +class PhysicalDummyScan : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::DUMMY_SCAN; + +public: + explicit PhysicalDummyScan(vector types, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::DUMMY_SCAN, std::move(types), estimated_cardinality) { + } + +public: + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_empty_result.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_empty_result.hpp new file mode 100644 index 00000000..9d3b0daa --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_empty_result.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/physical_empty_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +class PhysicalEmptyResult : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::EMPTY_RESULT; + +public: + explicit PhysicalEmptyResult(vector types, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::EMPTY_RESULT, std::move(types), estimated_cardinality) { + } + +public: + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_expression_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_expression_scan.hpp new file mode 100644 index 00000000..1588d524 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_expression_scan.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/physical_expression_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! The PhysicalExpressionScan scans a set of expressions +class PhysicalExpressionScan : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::EXPRESSION_SCAN; + +public: + PhysicalExpressionScan(vector types, vector>> expressions, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::EXPRESSION_SCAN, std::move(types), estimated_cardinality), + expressions(std::move(expressions)) { + } + + //! The set of expressions to scan + vector>> expressions; + +public: + unique_ptr GetOperatorState(ExecutionContext &context) const override; + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const override; + + bool ParallelOperator() const override { + return true; + } + +public: + bool IsFoldable() const; + void EvaluateExpression(ClientContext &context, idx_t expression_idx, DataChunk *child_chunk, + DataChunk &result) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_positional_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_positional_scan.hpp new file mode 100644 index 00000000..c142009d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_positional_scan.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/physical_positional_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/data_table.hpp" + +namespace duckdb { + +//! Represents a scan of a base table +class PhysicalPositionalScan : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::POSITIONAL_SCAN; + +public: + //! Regular Table Scan + PhysicalPositionalScan(vector types, unique_ptr left, + unique_ptr right); + + //! The child table functions + vector> child_tables; + +public: + bool Equals(const PhysicalOperator &other) const override; + +public: + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + double GetProgress(ClientContext &context, GlobalSourceState &gstate) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp new file mode 100644 index 00000000..e00a18ec --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/scan/physical_table_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/common/extra_operator_info.hpp" + +namespace duckdb { + +//! Represents a scan of a base table +class PhysicalTableScan : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::TABLE_SCAN; + +public: + //! Table scan that immediately projects out filter columns that are unused in the remainder of the query plan + PhysicalTableScan(vector types, TableFunction function, unique_ptr bind_data, + vector returned_types, vector column_ids, vector projection_ids, + vector names, unique_ptr table_filters, idx_t estimated_cardinality, + ExtraOperatorInfo extra_info); + + //! The table function + TableFunction function; + //! Bind data of the function + unique_ptr bind_data; + //! The types of ALL columns that can be returned by the table function + vector returned_types; + //! The column ids used within the table function + vector column_ids; + //! The projected-out column ids + vector projection_ids; + //! The names of the columns + vector names; + //! The table filters + unique_ptr table_filters; + //! Currently stores any filters applied to file names (as strings) + ExtraOperatorInfo extra_info; + +public: + string GetName() const override; + string ParamsToString() const override; + + bool Equals(const PhysicalOperator &other) const override; + +public: + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + unique_ptr GetGlobalSourceState(ClientContext &context) const override; + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + idx_t GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, + LocalSourceState &lstate) const override; + + bool IsSource() const override { + return true; + } + bool ParallelSource() const override { + return true; + } + + bool SupportsBatchIndex() const override { + return function.get_batch_index != nullptr; + } + + double GetProgress(ClientContext &context, GlobalSourceState &gstate) const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_alter.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_alter.hpp new file mode 100644 index 00000000..dfa05fa7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_alter.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_alter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" + +namespace duckdb { + +//! PhysicalAlter represents an ALTER TABLE command +class PhysicalAlter : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::ALTER; + +public: + explicit PhysicalAlter(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::ALTER, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_attach.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_attach.hpp new file mode 100644 index 00000000..5fab731d --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_attach.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_attach.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/attach_info.hpp" + +namespace duckdb { + +//! PhysicalLoad represents an extension LOAD operation +class PhysicalAttach : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::ATTACH; + +public: + explicit PhysicalAttach(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::ATTACH, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_art_index.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_art_index.hpp new file mode 100644 index 00000000..fb66fe39 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_art_index.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_create_art_index.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" + +#include "duckdb/storage/data_table.hpp" + +#include + +namespace duckdb { +class DuckTableEntry; + +//! Physical CREATE (UNIQUE) INDEX statement +class PhysicalCreateARTIndex : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CREATE_INDEX; + +public: + PhysicalCreateARTIndex(LogicalOperator &op, TableCatalogEntry &table, const vector &column_ids, + unique_ptr info, vector> unbound_expressions, + idx_t estimated_cardinality, const bool sorted); + + //! The table to create the index for + DuckTableEntry &table; + //! The list of column IDs required for the index + vector storage_ids; + //! Info for index creation + unique_ptr info; + //! Unbound expressions to be used in the optimizer + vector> unbound_expressions; + //! Whether the pipeline sorts the data prior to index creation + const bool sorted; + +public: + //! Source interface, NOP for this operator + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + //! Sink interface, thread-local sink states + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + //! Sink interface, global sink state + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + //! Sink for unsorted data: insert iteratively + SinkResultType SinkUnsorted(Vector &row_identifiers, OperatorSinkInput &input) const; + //! Sink for sorted data: build + merge + SinkResultType SinkSorted(Vector &row_identifiers, OperatorSinkInput &input) const; + + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const override; + + bool IsSink() const override { + return true; + } + bool ParallelSink() const override { + return true; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_function.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_function.hpp new file mode 100644 index 00000000..e28f229a --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_function.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_create_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" + +namespace duckdb { + +//! PhysicalCreateFunction represents a CREATE FUNCTION command +class PhysicalCreateFunction : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CREATE_MACRO; + +public: + explicit PhysicalCreateFunction(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::CREATE_MACRO, {LogicalType::BIGINT}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_schema.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_schema.hpp new file mode 100644 index 00000000..c3b2e395 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_schema.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_create_schema.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" + +namespace duckdb { + +//! PhysicalCreateSchema represents a CREATE SCHEMA command +class PhysicalCreateSchema : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CREATE_SCHEMA; + +public: + explicit PhysicalCreateSchema(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::CREATE_SCHEMA, {LogicalType::BIGINT}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_sequence.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_sequence.hpp new file mode 100644 index 00000000..13a3e1da --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_sequence.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_create_sequence.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" + +namespace duckdb { + +//! PhysicalCreateSequence represents a CREATE SEQUENCE command +class PhysicalCreateSequence : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CREATE_SEQUENCE; + +public: + explicit PhysicalCreateSequence(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::CREATE_SEQUENCE, {LogicalType::BIGINT}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_table.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_table.hpp new file mode 100644 index 00000000..18b0339e --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_table.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_create_table.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" + +namespace duckdb { + +//! Physically CREATE TABLE statement +class PhysicalCreateTable : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CREATE_TABLE; + +public: + PhysicalCreateTable(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info, + idx_t estimated_cardinality); + + //! Schema to insert to + SchemaCatalogEntry &schema; + //! Table name to create + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_type.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_type.hpp new file mode 100644 index 00000000..1d5a5036 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_type.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_create_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" + +namespace duckdb { + +//! PhysicalCreateType represents a CREATE TYPE command +class PhysicalCreateType : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CREATE_TYPE; + +public: + explicit PhysicalCreateType(unique_ptr info, idx_t estimated_cardinality); + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + bool IsSink() const override { + return !children.empty(); + } + + bool ParallelSink() const override { + return false; + } + + bool SinkOrderDependent() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_view.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_view.hpp new file mode 100644 index 00000000..9533ad76 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_view.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_create_view.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" + +namespace duckdb { + +//! PhysicalCreateView represents a CREATE VIEW command +class PhysicalCreateView : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CREATE_VIEW; + +public: + explicit PhysicalCreateView(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::CREATE_VIEW, {LogicalType::BIGINT}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_detach.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_detach.hpp new file mode 100644 index 00000000..61a05a0f --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_detach.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_detach.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/detach_info.hpp" + +namespace duckdb { + +class PhysicalDetach : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::DETACH; + +public: + explicit PhysicalDetach(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::DETACH, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_drop.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_drop.hpp new file mode 100644 index 00000000..ac7dbb55 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_drop.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/schema/physical_drop.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" + +namespace duckdb { + +//! PhysicalDrop represents a DROP [...] command +class PhysicalDrop : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::DROP; + +public: + explicit PhysicalDrop(unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::DROP, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/set/physical_cte.hpp b/src/duckdb/src/include/duckdb/execution/operator/set/physical_cte.hpp new file mode 100644 index 00000000..6babb8dd --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/set/physical_cte.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/set/physical_cte.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +class RecursiveCTEState; + +class PhysicalCTE : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CTE; + +public: + PhysicalCTE(string ctename, idx_t table_index, vector types, unique_ptr top, + unique_ptr bottom, idx_t estimated_cardinality); + ~PhysicalCTE() override; + + std::shared_ptr working_table; + shared_ptr recursive_meta_pipeline; + + idx_t table_index; + string ctename; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + + string ParamsToString() const override; + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + + vector> GetSources() const override; + +private: + void ExecuteRecursivePipelines(ExecutionContext &context) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp new file mode 100644 index 00000000..b40e2019 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/set/physical_recursive_cte.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +class RecursiveCTEState; + +class PhysicalRecursiveCTE : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::RECURSIVE_CTE; + +public: + PhysicalRecursiveCTE(string ctename, idx_t table_index, vector types, bool union_all, + unique_ptr top, unique_ptr bottom, + idx_t estimated_cardinality); + ~PhysicalRecursiveCTE() override; + + string ctename; + idx_t table_index; + + bool union_all; + std::shared_ptr working_table; + shared_ptr recursive_meta_pipeline; + +public: + // Source interface + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } + +public: + // Sink interface + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + + unique_ptr GetGlobalSinkState(ClientContext &context) const override; + + bool IsSink() const override { + return true; + } + + string ParamsToString() const override; + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + + vector> GetSources() const override; + +private: + //! Probe Hash Table and eliminate duplicate rows + idx_t ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const; + + void ExecuteRecursivePipelines(ExecutionContext &context) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/set/physical_union.hpp b/src/duckdb/src/include/duckdb/execution/operator/set/physical_union.hpp new file mode 100644 index 00000000..05e4647c --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/set/physical_union.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/set/physical_union.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +class PhysicalUnion : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::UNION; + +public: + PhysicalUnion(vector types, unique_ptr top, unique_ptr bottom, + idx_t estimated_cardinality); + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; + vector> GetSources() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/perfect_aggregate_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/perfect_aggregate_hashtable.hpp new file mode 100644 index 00000000..211c27b3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/perfect_aggregate_hashtable.hpp @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/perfect_aggregate_hashtable.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/base_aggregate_hashtable.hpp" +#include "duckdb/storage/arena_allocator.hpp" + +namespace duckdb { + +class PerfectAggregateHashTable : public BaseAggregateHashTable { +public: + PerfectAggregateHashTable(ClientContext &context, Allocator &allocator, const vector &group_types, + vector payload_types_p, vector aggregate_objects, + vector group_minima, vector required_bits); + ~PerfectAggregateHashTable() override; + +public: + //! Add the given data to the HT + void AddChunk(DataChunk &groups, DataChunk &payload); + + //! Combines the target perfect aggregate HT into this one + void Combine(PerfectAggregateHashTable &other); + + //! Scan the HT starting from the scan_position + void Scan(idx_t &scan_position, DataChunk &result); + +protected: + Vector addresses; + //! The required bits per group + vector required_bits; + //! The total required bits for the HT (this determines the max capacity) + idx_t total_required_bits; + //! The total amount of groups + idx_t total_groups; + //! The tuple size + idx_t tuple_size; + //! The number of grouping columns + idx_t grouping_columns; + + // The actual pointer to the data + data_ptr_t data; + //! The owned data of the HT + unsafe_unique_array owned_data; + //! Information on whether or not a specific group has any entries + unsafe_unique_array group_is_set; + + //! The minimum values for each of the group columns + vector group_minima; + + //! Reused selection vector + SelectionVector sel; + + //! The active arena allocator used by the aggregates for their internal state + unique_ptr aggregate_allocator; + //! Owning arena allocators that this HT has data from + vector> stored_allocators; + +private: + //! Destroy the perfect aggregate HT (called automatically by the destructor) + void Destroy(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/physical_operator.hpp b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp new file mode 100644 index 00000000..e1304673 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp @@ -0,0 +1,246 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/physical_operator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/operator_result_type.hpp" +#include "duckdb/common/enums/physical_operator_type.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/common/optional_idx.hpp" +#include "duckdb/execution/physical_operator_states.hpp" +#include "duckdb/common/enums/order_preservation_type.hpp" + +namespace duckdb { +class Event; +class Executor; +class PhysicalOperator; +class Pipeline; +class PipelineBuildState; +class MetaPipeline; + +//! PhysicalOperator is the base class of the physical operators present in the +//! execution plan +class PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::INVALID; + +public: + PhysicalOperator(PhysicalOperatorType type, vector types, idx_t estimated_cardinality) + : type(type), types(std::move(types)), estimated_cardinality(estimated_cardinality) { + } + + virtual ~PhysicalOperator() { + } + + //! The physical operator type + PhysicalOperatorType type; + //! The set of children of the operator + vector> children; + //! The types returned by this physical operator + vector types; + //! The estimated cardinality of this physical operator + idx_t estimated_cardinality; + + //! The global sink state of this operator + unique_ptr sink_state; + //! The global state of this operator + unique_ptr op_state; + //! Lock for (re)setting any of the operator states + mutex lock; + +public: + virtual string GetName() const; + virtual string ParamsToString() const { + return ""; + } + virtual string ToString() const; + void Print() const; + virtual vector> GetChildren() const; + + //! Return a vector of the types that will be returned by this operator + const vector &GetTypes() const { + return types; + } + + virtual bool Equals(const PhysicalOperator &other) const { + return false; + } + + virtual void Verify(); + +public: + // Operator interface + virtual unique_ptr GetOperatorState(ExecutionContext &context) const; + virtual unique_ptr GetGlobalOperatorState(ClientContext &context) const; + virtual OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const; + virtual OperatorFinalizeResultType FinalExecute(ExecutionContext &context, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const; + + virtual bool ParallelOperator() const { + return false; + } + + virtual bool RequiresFinalExecute() const { + return false; + } + + //! The influence the operator has on order (insertion order means no influence) + virtual OrderPreservationType OperatorOrder() const { + return OrderPreservationType::INSERTION_ORDER; + } + +public: + // Source interface + virtual unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const; + virtual unique_ptr GetGlobalSourceState(ClientContext &context) const; + virtual SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const; + + virtual idx_t GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, + LocalSourceState &lstate) const; + + virtual bool IsSource() const { + return false; + } + + virtual bool ParallelSource() const { + return false; + } + + virtual bool SupportsBatchIndex() const { + return false; + } + + //! The type of order emitted by the operator (as a source) + virtual OrderPreservationType SourceOrder() const { + return OrderPreservationType::INSERTION_ORDER; + } + + //! Returns the current progress percentage, or a negative value if progress bars are not supported + virtual double GetProgress(ClientContext &context, GlobalSourceState &gstate) const; + +public: + // Sink interface + + //! The sink method is called constantly with new input, as long as new input is available. Note that this method + //! CAN be called in parallel, proper locking is needed when accessing data inside the GlobalSinkState. + virtual SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const; + // The combine is called when a single thread has completed execution of its part of the pipeline, it is the final + // time that a specific LocalSinkState is accessible. This method can be called in parallel while other Sink() or + // Combine() calls are active on the same GlobalSinkState. + virtual SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const; + //! The finalize is called when ALL threads are finished execution. It is called only once per pipeline, and is + //! entirely single threaded. + //! If Finalize returns SinkResultType::FINISHED, the sink is marked as finished + virtual SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const; + //! For sinks with RequiresBatchIndex set to true, when a new batch starts being processed this method is called + //! This allows flushing of the current batch (e.g. to disk) TODO: should this be able to block too? + virtual void NextBatch(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate_p) const; + + virtual unique_ptr GetLocalSinkState(ExecutionContext &context) const; + virtual unique_ptr GetGlobalSinkState(ClientContext &context) const; + + //! The maximum amount of memory the operator should use per thread. + static idx_t GetMaxThreadMemory(ClientContext &context); + + //! Whether operator caching is allowed in the current execution context + static bool OperatorCachingAllowed(ExecutionContext &context); + + virtual bool IsSink() const { + return false; + } + + virtual bool ParallelSink() const { + return false; + } + + virtual bool RequiresBatchIndex() const { + return false; + } + + //! Whether or not the sink operator depends on the order of the input chunks + //! If this is set to true, we cannot do things like caching intermediate vectors + virtual bool SinkOrderDependent() const { + return false; + } + +public: + // Pipeline construction + virtual vector> GetSources() const; + bool AllSourcesSupportBatchIndex() const; + + virtual void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline); + +public: + template + TARGET &Cast() { + if (TARGET::TYPE != PhysicalOperatorType::INVALID && type != TARGET::TYPE) { + throw InternalException("Failed to cast physical operator to type - physical operator type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (TARGET::TYPE != PhysicalOperatorType::INVALID && type != TARGET::TYPE) { + throw InternalException("Failed to cast physical operator to type - physical operator type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +//! Contains state for the CachingPhysicalOperator +class CachingOperatorState : public OperatorState { +public: + ~CachingOperatorState() override { + } + + void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { + } + + unique_ptr cached_chunk; + bool initialized = false; + //! Whether or not the chunk can be cached + bool can_cache_chunk = false; +}; + +//! Base class that caches output from child Operator class. Note that Operators inheriting from this class should also +//! inherit their state class from the CachingOperatorState. +class CachingPhysicalOperator : public PhysicalOperator { +public: + static constexpr const idx_t CACHE_THRESHOLD = 64; + CachingPhysicalOperator(PhysicalOperatorType type, vector types, idx_t estimated_cardinality); + + bool caching_supported; + +public: + OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const final; + OperatorFinalizeResultType FinalExecute(ExecutionContext &context, DataChunk &chunk, GlobalOperatorState &gstate, + OperatorState &state) const final; + + bool RequiresFinalExecute() const final { + return caching_supported; + } + +protected: + //! Child classes need to implement the ExecuteInternal method instead of the Execute + virtual OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &state) const = 0; + +private: + bool CanCacheType(const LogicalType &type); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp b/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp new file mode 100644 index 00000000..e3963efb --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp @@ -0,0 +1,181 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/physical_operator_states.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/operator_result_type.hpp" +#include "duckdb/common/enums/physical_operator_type.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" + +namespace duckdb { +class Event; +class Executor; +class PhysicalOperator; +class Pipeline; +class PipelineBuildState; +class MetaPipeline; +class InterruptState; + +struct SourcePartitionInfo { + //! The current batch index + //! This is only set in case RequiresBatchIndex() is true, and the source has support for it (SupportsBatchIndex()) + //! Otherwise this is left on INVALID_INDEX + //! The batch index is a globally unique, increasing index that should be used to maintain insertion order + //! //! in conjunction with parallelism + optional_idx batch_index; + //! The minimum batch index that any thread is currently actively reading + optional_idx min_batch_index; +}; + +// LCOV_EXCL_START +class OperatorState { +public: + virtual ~OperatorState() { + } + + virtual void Finalize(const PhysicalOperator &op, ExecutionContext &context) { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class GlobalOperatorState { +public: + virtual ~GlobalOperatorState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class GlobalSinkState { +public: + GlobalSinkState() : state(SinkFinalizeType::READY) { + } + virtual ~GlobalSinkState() { + } + + SinkFinalizeType state; + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class LocalSinkState { +public: + virtual ~LocalSinkState() { + } + + //! Source partition info + SourcePartitionInfo partition_info; + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class GlobalSourceState { +public: + virtual ~GlobalSourceState() { + } + + virtual idx_t MaxThreads() { + return 1; + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class LocalSourceState { +public: + virtual ~LocalSourceState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct OperatorSinkInput { + GlobalSinkState &global_state; + LocalSinkState &local_state; + InterruptState &interrupt_state; +}; + +struct OperatorSourceInput { + GlobalSourceState &global_state; + LocalSourceState &local_state; + InterruptState &interrupt_state; +}; + +struct OperatorSinkCombineInput { + GlobalSinkState &global_state; + LocalSinkState &local_state; + InterruptState &interrupt_state; +}; + +struct OperatorSinkFinalizeInput { + GlobalSinkState &global_state; + InterruptState &interrupt_state; +}; + +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp b/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp new file mode 100644 index 00000000..f17dd1b8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp @@ -0,0 +1,108 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/physical_plan_generator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/logical_tokens.hpp" +#include "duckdb/planner/operator/logical_limit_percent.hpp" +#include "duckdb/catalog/dependency_list.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { +class ClientContext; +class ColumnDataCollection; + +//! The physical plan generator generates a physical execution plan from a +//! logical query plan +class PhysicalPlanGenerator { +public: + explicit PhysicalPlanGenerator(ClientContext &context); + ~PhysicalPlanGenerator(); + + DependencyList dependencies; + //! Recursive CTEs require at least one ChunkScan, referencing the working_table. + //! This data structure is used to establish it. + unordered_map> recursive_cte_tables; + //! Materialized CTE ids must be collected. + unordered_set materialized_ctes; + +public: + //! Creates a plan from the logical operator. This involves resolving column bindings and generating physical + //! operator nodes. + unique_ptr CreatePlan(unique_ptr logical); + + //! Whether or not we can (or should) use a batch-index based operator for executing the given sink + static bool UseBatchIndex(ClientContext &context, PhysicalOperator &plan); + //! Whether or not we should preserve insertion order for executing the given sink + static bool PreserveInsertionOrder(ClientContext &context, PhysicalOperator &plan); + +protected: + unique_ptr CreatePlan(LogicalOperator &op); + + unique_ptr CreatePlan(LogicalAggregate &op); + unique_ptr CreatePlan(LogicalAnyJoin &op); + unique_ptr CreatePlan(LogicalColumnDataGet &op); + unique_ptr CreatePlan(LogicalComparisonJoin &op); + unique_ptr CreatePlan(LogicalCreate &op); + unique_ptr CreatePlan(LogicalCreateTable &op); + unique_ptr CreatePlan(LogicalCreateIndex &op); + unique_ptr CreatePlan(LogicalCrossProduct &op); + unique_ptr CreatePlan(LogicalDelete &op); + unique_ptr CreatePlan(LogicalDelimGet &op); + unique_ptr CreatePlan(LogicalDistinct &op); + unique_ptr CreatePlan(LogicalDummyScan &expr); + unique_ptr CreatePlan(LogicalEmptyResult &op); + unique_ptr CreatePlan(LogicalExpressionGet &op); + unique_ptr CreatePlan(LogicalExport &op); + unique_ptr CreatePlan(LogicalFilter &op); + unique_ptr CreatePlan(LogicalGet &op); + unique_ptr CreatePlan(LogicalLimit &op); + unique_ptr CreatePlan(LogicalLimitPercent &op); + unique_ptr CreatePlan(LogicalOrder &op); + unique_ptr CreatePlan(LogicalTopN &op); + unique_ptr CreatePlan(LogicalPositionalJoin &op); + unique_ptr CreatePlan(LogicalProjection &op); + unique_ptr CreatePlan(LogicalInsert &op); + unique_ptr CreatePlan(LogicalCopyToFile &op); + unique_ptr CreatePlan(LogicalExplain &op); + unique_ptr CreatePlan(LogicalSetOperation &op); + unique_ptr CreatePlan(LogicalUpdate &op); + unique_ptr CreatePlan(LogicalPrepare &expr); + unique_ptr CreatePlan(LogicalWindow &expr); + unique_ptr CreatePlan(LogicalExecute &op); + unique_ptr CreatePlan(LogicalPragma &op); + unique_ptr CreatePlan(LogicalSample &op); + unique_ptr CreatePlan(LogicalSet &op); + unique_ptr CreatePlan(LogicalReset &op); + unique_ptr CreatePlan(LogicalShow &op); + unique_ptr CreatePlan(LogicalSimple &op); + unique_ptr CreatePlan(LogicalUnnest &op); + unique_ptr CreatePlan(LogicalRecursiveCTE &op); + unique_ptr CreatePlan(LogicalMaterializedCTE &op); + unique_ptr CreatePlan(LogicalCTERef &op); + unique_ptr CreatePlan(LogicalPivot &op); + + unique_ptr PlanAsOfJoin(LogicalComparisonJoin &op); + unique_ptr PlanComparisonJoin(LogicalComparisonJoin &op); + unique_ptr PlanDelimJoin(LogicalComparisonJoin &op); + unique_ptr ExtractAggregateExpressions(unique_ptr child, + vector> &expressions, + vector> &groups); + +private: + bool PreserveInsertionOrder(PhysicalOperator &plan); + bool UseBatchIndex(PhysicalOperator &plan); + +private: + ClientContext &context; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/radix_partitioned_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/radix_partitioned_hashtable.hpp new file mode 100644 index 00000000..c9827a35 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/radix_partitioned_hashtable.hpp @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/radix_partitioned_hashtable.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/row/tuple_data_layout.hpp" +#include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp" +#include "duckdb/parser/group_by_node.hpp" + +namespace duckdb { + +class GroupedAggregateHashTable; +struct AggregatePartition; + +class RadixPartitionedHashTable { +public: + RadixPartitionedHashTable(GroupingSet &grouping_set, const GroupedAggregateData &op); + unique_ptr CreateHT(ClientContext &context, const idx_t capacity, + const idx_t radix_bits) const; + +public: + GroupingSet &grouping_set; + //! The indices specified in the groups_count that do not appear in the grouping_set + unsafe_vector null_groups; + const GroupedAggregateData &op; + vector group_types; + //! The GROUPING values that belong to this hash table + vector grouping_values; + +public: + //! Sink Interface + unique_ptr GetGlobalSinkState(ClientContext &context) const; + unique_ptr GetLocalSinkState(ExecutionContext &context) const; + + void Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, DataChunk &aggregate_input_chunk, + const unsafe_vector &filter) const; + void Combine(ExecutionContext &context, GlobalSinkState &gstate, LocalSinkState &lstate) const; + void Finalize(ClientContext &context, GlobalSinkState &gstate) const; + +public: + //! Source interface + unique_ptr GetGlobalSourceState(ClientContext &context) const; + unique_ptr GetLocalSourceState(ExecutionContext &context) const; + + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, GlobalSinkState &sink, + OperatorSourceInput &input) const; + + const TupleDataLayout &GetLayout() const; + idx_t NumberOfPartitions(GlobalSinkState &sink) const; + static void SetMultiScan(GlobalSinkState &sink); + +private: + void SetGroupingValues(); + void PopulateGroupChunk(DataChunk &group_chunk, DataChunk &input_chunk) const; + + TupleDataLayout layout; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/reservoir_sample.hpp b/src/duckdb/src/include/duckdb/execution/reservoir_sample.hpp new file mode 100644 index 00000000..7af05d5c --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/reservoir_sample.hpp @@ -0,0 +1,122 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/reservoir_sample.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/random_engine.hpp" +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/common/queue.hpp" + +namespace duckdb { + +class BaseReservoirSampling { +public: + explicit BaseReservoirSampling(int64_t seed); + BaseReservoirSampling(); + + void InitializeReservoir(idx_t cur_size, idx_t sample_size); + + void SetNextEntry(); + + void ReplaceElement(); + + //! The random generator + RandomEngine random; + //! Priority queue of [random element, index] for each of the elements in the sample + std::priority_queue> reservoir_weights; + //! The next element to sample + idx_t next_index; + //! The reservoir threshold of the current min entry + double min_threshold; + //! The reservoir index of the current min entry + idx_t min_entry; + //! The current count towards next index (i.e. we will replace an entry in next_index - current_count tuples) + idx_t current_count; +}; + +class BlockingSample { +public: + explicit BlockingSample(int64_t seed) : base_reservoir_sample(seed), random(base_reservoir_sample.random) { + } + virtual ~BlockingSample() { + } + + //! Add a chunk of data to the sample + virtual void AddToReservoir(DataChunk &input) = 0; + + //! Fetches a chunk from the sample. Note that this method is destructive and should only be used after the + // sample is completely built. + virtual unique_ptr GetChunk() = 0; + +protected: + //! The reservoir sampling + BaseReservoirSampling base_reservoir_sample; + RandomEngine &random; +}; + +//! The reservoir sample class maintains a streaming sample of fixed size "sample_count" +class ReservoirSample : public BlockingSample { +public: + ReservoirSample(Allocator &allocator, idx_t sample_count, int64_t seed); + + //! Add a chunk of data to the sample + void AddToReservoir(DataChunk &input) override; + + //! Fetches a chunk from the sample. Note that this method is destructive and should only be used after the + //! sample is completely built. + unique_ptr GetChunk() override; + +private: + //! Replace a single element of the input + void ReplaceElement(DataChunk &input, idx_t index_in_chunk); + + //! Fills the reservoir up until sample_count entries, returns how many entries are still required + idx_t FillReservoir(DataChunk &input); + +private: + //! The size of the reservoir sample + idx_t sample_count; + //! The current reservoir + ChunkCollection reservoir; +}; + +//! The reservoir sample sample_size class maintains a streaming sample of variable size +class ReservoirSamplePercentage : public BlockingSample { + constexpr static idx_t RESERVOIR_THRESHOLD = 100000; + +public: + ReservoirSamplePercentage(Allocator &allocator, double percentage, int64_t seed); + + //! Add a chunk of data to the sample + void AddToReservoir(DataChunk &input) override; + + //! Fetches a chunk from the sample. Note that this method is destructive and should only be used after the + //! sample is completely built. + unique_ptr GetChunk() override; + +private: + void Finalize(); + +private: + Allocator &allocator; + //! The sample_size to sample + double sample_percentage; + //! The fixed sample size of the sub-reservoirs + idx_t reservoir_sample_size; + //! The current sample + unique_ptr current_sample; + //! The set of finished samples of the reservoir sample + vector> finished_samples; + //! The amount of tuples that have been processed so far + idx_t current_count = 0; + //! Whether or not the stream is finalized. The stream is automatically finalized on the first call to GetChunk(); + bool is_finalized; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/window_executor.hpp b/src/duckdb/src/include/duckdb/execution/window_executor.hpp new file mode 100644 index 00000000..beadd3d6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/window_executor.hpp @@ -0,0 +1,313 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/window_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/window_segment_tree.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +namespace duckdb { + +struct WindowInputExpression { + static void PrepareInputExpression(Expression &expr, ExpressionExecutor &executor, DataChunk &chunk) { + vector types; + types.push_back(expr.return_type); + executor.AddExpression(expr); + + auto &allocator = executor.GetAllocator(); + chunk.Initialize(allocator, types); + } + + WindowInputExpression(optional_ptr expr_p, ClientContext &context) + : expr(expr_p), ptype(PhysicalType::INVALID), scalar(true), executor(context) { + if (expr) { + PrepareInputExpression(*expr, executor, chunk); + ptype = expr->return_type.InternalType(); + scalar = expr->IsScalar(); + } + } + + void Execute(DataChunk &input_chunk) { + if (expr) { + chunk.Reset(); + executor.Execute(input_chunk, chunk); + chunk.Verify(); + } + } + + template + inline T GetCell(idx_t i) const { + D_ASSERT(!chunk.data.empty()); + const auto data = FlatVector::GetData(chunk.data[0]); + return data[scalar ? 0 : i]; + } + + inline bool CellIsNull(idx_t i) const { + D_ASSERT(!chunk.data.empty()); + if (chunk.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR) { + return ConstantVector::IsNull(chunk.data[0]); + } + return FlatVector::IsNull(chunk.data[0], i); + } + + inline void CopyCell(Vector &target, idx_t target_offset) const { + D_ASSERT(!chunk.data.empty()); + auto &source = chunk.data[0]; + auto source_offset = scalar ? 0 : target_offset; + VectorOperations::Copy(source, target, source_offset + 1, source_offset, target_offset); + } + + optional_ptr expr; + PhysicalType ptype; + bool scalar; + ExpressionExecutor executor; + DataChunk chunk; +}; + +struct WindowInputColumn { + WindowInputColumn(Expression *expr_p, ClientContext &context, idx_t capacity_p) + : input_expr(expr_p, context), count(0), capacity(capacity_p) { + if (input_expr.expr) { + target = make_uniq(input_expr.chunk.data[0].GetType(), capacity); + } + } + + void Append(DataChunk &input_chunk) { + if (input_expr.expr) { + const auto source_count = input_chunk.size(); + D_ASSERT(count + source_count <= capacity); + if (!input_expr.scalar || !count) { + input_expr.Execute(input_chunk); + auto &source = input_expr.chunk.data[0]; + VectorOperations::Copy(source, *target, source_count, 0, count); + } + count += source_count; + } + } + + inline bool CellIsNull(idx_t i) const { + D_ASSERT(target); + D_ASSERT(i < count); + return FlatVector::IsNull(*target, input_expr.scalar ? 0 : i); + } + + template + inline T GetCell(idx_t i) const { + D_ASSERT(target); + D_ASSERT(i < count); + const auto data = FlatVector::GetData(*target); + return data[input_expr.scalar ? 0 : i]; + } + + WindowInputExpression input_expr; + +private: + unique_ptr target; + idx_t count; + idx_t capacity; +}; + +// Column indexes of the bounds chunk +enum WindowBounds : uint8_t { PARTITION_BEGIN, PARTITION_END, PEER_BEGIN, PEER_END, WINDOW_BEGIN, WINDOW_END }; + +class WindowExecutorState { +public: + WindowExecutorState() {}; + virtual ~WindowExecutorState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class WindowExecutor { +public: + WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + virtual ~WindowExecutor() { + } + + virtual void Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { + range.Append(input_chunk); + } + + virtual void Finalize() { + } + + virtual unique_ptr GetExecutorState() const; + + void Evaluate(idx_t row_idx, DataChunk &input_chunk, Vector &result, WindowExecutorState &lstate) const; + +protected: + // The function + BoundWindowExpression &wexpr; + ClientContext &context; + const idx_t payload_count; + const ValidityMask &partition_mask; + const ValidityMask &order_mask; + + // Expression collections + DataChunk payload_collection; + ExpressionExecutor payload_executor; + DataChunk payload_chunk; + + // evaluate RANGE expressions, if needed + WindowInputColumn range; + + virtual void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const = 0; +}; + +class WindowAggregateExecutor : public WindowExecutor { +public: + bool IsConstantAggregate(); + bool IsCustomAggregate(); + + WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask, + WindowAggregationMode mode); + + void Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) override; + void Finalize() override; + + unique_ptr GetExecutorState() const override; + + const WindowAggregationMode mode; + +protected: + ExpressionExecutor filter_executor; + SelectionVector filter_sel; + + // aggregate computation algorithm + unique_ptr aggregator; + + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +class WindowRowNumberExecutor : public WindowExecutor { +public: + WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +// Base class for non-aggregate functions that use peer boundaries +class WindowRankExecutor : public WindowExecutor { +public: + WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + unique_ptr GetExecutorState() const override; + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +class WindowDenseRankExecutor : public WindowExecutor { +public: + WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + unique_ptr GetExecutorState() const override; + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +class WindowPercentRankExecutor : public WindowExecutor { +public: + WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + unique_ptr GetExecutorState() const override; + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +class WindowCumeDistExecutor : public WindowExecutor { +public: + WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +// Base class for non-aggregate functions that have a payload +class WindowValueExecutor : public WindowExecutor { +public: + WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + void Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) override; + +protected: + // IGNORE NULLS + ValidityMask ignore_nulls; +}; + +// +class WindowNtileExecutor : public WindowValueExecutor { +public: + WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; +class WindowLeadLagExecutor : public WindowValueExecutor { +public: + WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + unique_ptr GetExecutorState() const override; + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +class WindowFirstValueExecutor : public WindowValueExecutor { +public: + WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +class WindowLastValueExecutor : public WindowValueExecutor { +public: + WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +class WindowNthValueExecutor : public WindowValueExecutor { +public: + WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + +protected: + void EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, idx_t row_idx) const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/window_segment_tree.hpp b/src/duckdb/src/include/duckdb/execution/window_segment_tree.hpp new file mode 100644 index 00000000..12786a19 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/window_segment_tree.hpp @@ -0,0 +1,152 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/window_segment_tree.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/common/enums/window_aggregation_mode.hpp" +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" + +namespace duckdb { + +class WindowAggregatorState { +public: + WindowAggregatorState(); + virtual ~WindowAggregatorState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + + //! Allocator for aggregates + ArenaAllocator allocator; +}; + +class WindowAggregator { +public: + WindowAggregator(AggregateObject aggr, const LogicalType &result_type_p, idx_t partition_count); + virtual ~WindowAggregator(); + + // Build + virtual void Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered); + virtual void Finalize(); + + // Probe + virtual unique_ptr GetLocalState() const = 0; + virtual void Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, Vector &result, + idx_t count) const = 0; + +protected: + AggregateObject aggr; + //! The result type of the window function + LogicalType result_type; + + //! The cardinality of the partition + const idx_t partition_count; + //! The size of a single aggregate state + const idx_t state_size; + //! Partition data chunk + DataChunk inputs; + + //! The filtered rows in inputs. + vector filter_bits; + ValidityMask filter_mask; + idx_t filter_pos; + //! The state used by the aggregator to build. + unique_ptr gstate; +}; + +class WindowConstantAggregator : public WindowAggregator { +public: + WindowConstantAggregator(AggregateObject aggr, const LogicalType &result_type_p, const ValidityMask &partition_mask, + const idx_t count); + ~WindowConstantAggregator() override { + } + + void Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) override; + void Finalize() override; + + unique_ptr GetLocalState() const override; + void Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, Vector &result, + idx_t count) const override; + +private: + void AggregateInit(); + void AggegateFinal(Vector &result, idx_t rid); + + //! Partition starts + vector partition_offsets; + //! Aggregate results + unique_ptr results; + //! The current result partition being built/read + idx_t partition; + //! The current input row being built/read + idx_t row; + //! Data pointer that contains a single state, used for intermediate window segment aggregation + vector state; + //! A vector of pointers to "state", used for intermediate window segment aggregation + Vector statep; + //! Reused result state container for the window functions + Vector statef; +}; + +class WindowCustomAggregator : public WindowAggregator { +public: + WindowCustomAggregator(AggregateObject aggr, const LogicalType &result_type_p, idx_t partition_count); + ~WindowCustomAggregator() override; + + unique_ptr GetLocalState() const override; + void Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, Vector &result, + idx_t count) const override; +}; + +class WindowSegmentTree : public WindowAggregator { +public: + WindowSegmentTree(AggregateObject aggr, const LogicalType &result_type, idx_t count, WindowAggregationMode mode_p); + ~WindowSegmentTree() override; + + void Finalize() override; + + unique_ptr GetLocalState() const override; + void Evaluate(WindowAggregatorState &lstate, const idx_t *begins, const idx_t *ends, Vector &result, + idx_t count) const override; + +public: + void ConstructTree(); + + //! Use the combine API, if available + inline bool UseCombineAPI() const { + return mode < WindowAggregationMode::SEPARATE; + } + + //! The actual window segment tree: an array of aggregate states that represent all the intermediate nodes + unsafe_unique_array levels_flat_native; + //! For each level, the starting location in the levels_flat_native array + vector levels_flat_start; + + //! The total number of internal nodes of the tree, stored in levels_flat_native + idx_t internal_nodes; + + //! Use the combine API, if available + WindowAggregationMode mode; + + // TREE_FANOUT needs to cleanly divide STANDARD_VECTOR_SIZE + static constexpr idx_t TREE_FANOUT = 16; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/aggregate/distributive_functions.hpp b/src/duckdb/src/include/duckdb/function/aggregate/distributive_functions.hpp new file mode 100644 index 00000000..cba1a7de --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/aggregate/distributive_functions.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/aggregate/distributive_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +struct CountStarFun { + static AggregateFunction GetFunction(); + + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct CountFun { + static AggregateFunction GetFunction(); + + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct FirstFun { + static AggregateFunction GetFunction(const LogicalType &type); + + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp new file mode 100644 index 00000000..56b712a3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp @@ -0,0 +1,267 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/aggregate_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/vector_operations/aggregate_executor.hpp" +#include "duckdb/function/aggregate_state.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! The type used for sizing hashed aggregate function states +typedef idx_t (*aggregate_size_t)(); +//! The type used for initializing hashed aggregate function states +typedef void (*aggregate_initialize_t)(data_ptr_t state); +//! The type used for updating hashed aggregate functions +typedef void (*aggregate_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + Vector &state, idx_t count); +//! The type used for combining hashed aggregate states +typedef void (*aggregate_combine_t)(Vector &state, Vector &combined, AggregateInputData &aggr_input_data, idx_t count); +//! The type used for finalizing hashed aggregate function payloads +typedef void (*aggregate_finalize_t)(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset); +//! The type used for propagating statistics in aggregate functions (optional) +typedef unique_ptr (*aggregate_statistics_t)(ClientContext &context, BoundAggregateExpression &expr, + AggregateStatisticsInput &input); +//! Binds the scalar function and creates the function data +typedef unique_ptr (*bind_aggregate_function_t)(ClientContext &context, AggregateFunction &function, + vector> &arguments); +//! The type used for the aggregate destructor method. NOTE: this method is used in destructors and MAY NOT throw. +typedef void (*aggregate_destructor_t)(Vector &state, AggregateInputData &aggr_input_data, idx_t count); + +//! The type used for updating simple (non-grouped) aggregate functions +typedef void (*aggregate_simple_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + data_ptr_t state, idx_t count); + +//! The type used for updating complex windowed aggregate functions (optional) +typedef void (*aggregate_window_t)(Vector inputs[], const ValidityMask &filter_mask, + AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, + const FrameBounds &frame, const FrameBounds &prev, Vector &result, idx_t rid, + idx_t bias); + +typedef void (*aggregate_serialize_t)(Serializer &serializer, const optional_ptr bind_data, + const AggregateFunction &function); +typedef unique_ptr (*aggregate_deserialize_t)(Deserializer &deserializer, AggregateFunction &function); + +class AggregateFunction : public BaseScalarFunction { +public: + AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, + aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, + aggregate_combine_t combine, aggregate_finalize_t finalize, + FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, + aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, + aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, + aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr, + aggregate_deserialize_t deserialize = nullptr) + : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, + LogicalType(LogicalTypeId::INVALID), null_handling), + state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), + simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), + serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + } + + AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, + aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, + aggregate_combine_t combine, aggregate_finalize_t finalize, + aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, + aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, + aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr, + aggregate_deserialize_t deserialize = nullptr) + : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, + LogicalType(LogicalTypeId::INVALID)), + state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), + simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), + serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + } + + AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, + aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, + aggregate_finalize_t finalize, + FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, + aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, + aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, + aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr, + aggregate_deserialize_t deserialize = nullptr) + : AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize, + null_handling, simple_update, bind, destructor, statistics, window, serialize, + deserialize) { + } + + AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, + aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, + aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr, + bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, + aggregate_statistics_t statistics = nullptr, aggregate_window_t window = nullptr, + aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) + : AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize, + FunctionNullHandling::DEFAULT_NULL_HANDLING, simple_update, bind, destructor, statistics, + window, serialize, deserialize) { + } + //! The hashed aggregate state sizing function + aggregate_size_t state_size; + //! The hashed aggregate state initialization function + aggregate_initialize_t initialize; + //! The hashed aggregate update state function + aggregate_update_t update; + //! The hashed aggregate combine states function + aggregate_combine_t combine; + //! The hashed aggregate finalization function + aggregate_finalize_t finalize; + //! The simple aggregate update function (may be null) + aggregate_simple_update_t simple_update; + //! The windowed aggregate frame update function (may be null) + aggregate_window_t window; + + //! The bind function (may be null) + bind_aggregate_function_t bind; + //! The destructor method (may be null) + aggregate_destructor_t destructor; + + //! The statistics propagation function (may be null) + aggregate_statistics_t statistics; + + aggregate_serialize_t serialize; + aggregate_deserialize_t deserialize; + //! Whether or not the aggregate is order dependent + AggregateOrderDependent order_dependent; + + bool operator==(const AggregateFunction &rhs) const { + return state_size == rhs.state_size && initialize == rhs.initialize && update == rhs.update && + combine == rhs.combine && finalize == rhs.finalize && window == rhs.window; + } + bool operator!=(const AggregateFunction &rhs) const { + return !(*this == rhs); + } + +public: + template + static AggregateFunction NullaryAggregate(LogicalType return_type) { + return AggregateFunction( + {}, return_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateFunction::NullaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::NullaryUpdate); + } + + template + static AggregateFunction + UnaryAggregate(const LogicalType &input_type, LogicalType return_type, + FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING) { + return AggregateFunction( + {input_type}, return_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, AggregateFunction::UnaryScatterUpdate, + AggregateFunction::StateCombine, AggregateFunction::StateFinalize, + null_handling, AggregateFunction::UnaryUpdate); + } + + template + static AggregateFunction UnaryAggregateDestructor(LogicalType input_type, LogicalType return_type) { + auto aggregate = UnaryAggregate(input_type, return_type); + aggregate.destructor = AggregateFunction::StateDestroy; + return aggregate; + } + + template + static AggregateFunction BinaryAggregate(const LogicalType &a_type, const LogicalType &b_type, + LogicalType return_type) { + return AggregateFunction({a_type, b_type}, return_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateFunction::BinaryScatterUpdate, + AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, + AggregateFunction::BinaryUpdate); + } + +public: + template + static idx_t StateSize() { + return sizeof(STATE); + } + + template + static void StateInitialize(data_ptr_t state) { + OP::Initialize(*reinterpret_cast(state)); + } + + template + static void NullaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + Vector &states, idx_t count) { + D_ASSERT(input_count == 0); + AggregateExecutor::NullaryScatter(states, aggr_input_data, count); + } + + template + static void NullaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, + idx_t count) { + D_ASSERT(input_count == 0); + AggregateExecutor::NullaryUpdate(state, aggr_input_data, count); + } + + template + static void UnaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + Vector &states, idx_t count) { + D_ASSERT(input_count == 1); + AggregateExecutor::UnaryScatter(inputs[0], states, aggr_input_data, count); + } + + template + static void UnaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, + idx_t count) { + D_ASSERT(input_count == 1); + AggregateExecutor::UnaryUpdate(inputs[0], aggr_input_data, state, count); + } + + template + static void UnaryWindow(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, + idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, + Vector &result, idx_t rid, idx_t bias) { + D_ASSERT(input_count == 1); + AggregateExecutor::UnaryWindow(inputs[0], filter_mask, aggr_input_data, + state, frame, prev, result, rid, bias); + } + + template + static void BinaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + Vector &states, idx_t count) { + D_ASSERT(input_count == 2); + AggregateExecutor::BinaryScatter(aggr_input_data, inputs[0], inputs[1], states, + count); + } + + template + static void BinaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, + idx_t count) { + D_ASSERT(input_count == 2); + AggregateExecutor::BinaryUpdate(aggr_input_data, inputs[0], inputs[1], state, count); + } + + template + static void StateCombine(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { + AggregateExecutor::Combine(source, target, aggr_input_data, count); + } + + template + static void StateFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset) { + AggregateExecutor::Finalize(states, aggr_input_data, result, count, offset); + } + + template + static void StateVoidFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset) { + AggregateExecutor::VoidFinalize(states, aggr_input_data, result, count, offset); + } + + template + static void StateDestroy(Vector &states, AggregateInputData &aggr_input_data, idx_t count) { + AggregateExecutor::Destroy(states, aggr_input_data, count); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/aggregate_state.hpp b/src/duckdb/src/include/duckdb/function/aggregate_state.hpp new file mode 100644 index 00000000..66b7338c --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/aggregate_state.hpp @@ -0,0 +1,95 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/aggregate_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/node_statistics.hpp" + +namespace duckdb { + +enum class AggregateType : uint8_t { NON_DISTINCT = 1, DISTINCT = 2 }; +//! Whether or not the input order influences the result of the aggregate +enum class AggregateOrderDependent : uint8_t { ORDER_DEPENDENT = 1, NOT_ORDER_DEPENDENT = 2 }; + +class BoundAggregateExpression; + +struct AggregateInputData { + AggregateInputData(optional_ptr bind_data_p, ArenaAllocator &allocator_p) + : bind_data(bind_data_p), allocator(allocator_p) { + } + optional_ptr bind_data; + ArenaAllocator &allocator; +}; + +struct AggregateUnaryInput { + AggregateUnaryInput(AggregateInputData &input_p, ValidityMask &input_mask_p) + : input(input_p), input_mask(input_mask_p), input_idx(0) { + } + + AggregateInputData &input; + ValidityMask &input_mask; + idx_t input_idx; + + inline bool RowIsValid() { + return input_mask.RowIsValid(input_idx); + } +}; + +struct AggregateBinaryInput { + AggregateBinaryInput(AggregateInputData &input_p, ValidityMask &left_mask_p, ValidityMask &right_mask_p) + : input(input_p), left_mask(left_mask_p), right_mask(right_mask_p) { + } + + AggregateInputData &input; + ValidityMask &left_mask; + ValidityMask &right_mask; + idx_t lidx; + idx_t ridx; +}; + +struct AggregateFinalizeData { + AggregateFinalizeData(Vector &result_p, AggregateInputData &input_p) + : result(result_p), input(input_p), result_idx(0) { + } + + Vector &result; + AggregateInputData &input; + idx_t result_idx; + + inline void ReturnNull() { + switch (result.GetVectorType()) { + case VectorType::FLAT_VECTOR: + FlatVector::SetNull(result, result_idx, true); + break; + case VectorType::CONSTANT_VECTOR: + ConstantVector::SetNull(result, true); + break; + default: + throw InternalException("Invalid result vector type for aggregate"); + } + } + + inline string_t ReturnString(string_t value) { + return StringVector::AddStringOrBlob(result, value); + } +}; + +struct AggregateStatisticsInput { + AggregateStatisticsInput(optional_ptr bind_data_p, vector &child_stats_p, + optional_ptr node_stats_p) + : bind_data(bind_data_p), child_stats(child_stats_p), node_stats(node_stats_p) { + } + + optional_ptr bind_data; + vector &child_stats; + optional_ptr node_stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/built_in_functions.hpp b/src/duckdb/src/include/duckdb/function/built_in_functions.hpp new file mode 100644 index 00000000..2e57f317 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/built_in_functions.hpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/built_in_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function.hpp" +#include "duckdb/catalog/catalog_transaction.hpp" + +namespace duckdb { + +class BuiltinFunctions { +public: + BuiltinFunctions(CatalogTransaction transaction, Catalog &catalog); + ~BuiltinFunctions(); + + //! Initialize a catalog with all built-in functions + void Initialize(); + +public: + void AddFunction(AggregateFunctionSet set); + void AddFunction(AggregateFunction function); + void AddFunction(ScalarFunctionSet set); + void AddFunction(PragmaFunction function); + void AddFunction(const string &name, PragmaFunctionSet functions); + void AddFunction(ScalarFunction function); + void AddFunction(const vector &names, ScalarFunction function); + void AddFunction(TableFunctionSet set); + void AddFunction(TableFunction function); + void AddFunction(CopyFunction function); + + void AddCollation(string name, ScalarFunction function, bool combinable = false, + bool not_required_for_equality = false); + +private: + CatalogTransaction transaction; + Catalog &catalog; + +private: + template + void Register() { + T::RegisterFunction(*this); + } + + // table-producing functions + void RegisterTableScanFunctions(); + void RegisterSQLiteFunctions(); + void RegisterReadFunctions(); + void RegisterTableFunctions(); + void RegisterArrowFunctions(); + + // aggregates + void RegisterDistributiveAggregates(); + + // scalar functions + void RegisterCompressedMaterializationFunctions(); + void RegisterGenericFunctions(); + void RegisterOperators(); + void RegisterStringFunctions(); + void RegisterNestedFunctions(); + void RegisterSequenceFunctions(); + + // pragmas + void RegisterPragmaFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/cast/bound_cast_data.hpp b/src/duckdb/src/include/duckdb/function/cast/bound_cast_data.hpp new file mode 100644 index 00000000..9a81dfe6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/cast/bound_cast_data.hpp @@ -0,0 +1,116 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/cast/bound_cast_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/cast/default_casts.hpp" + +namespace duckdb { + +struct ListBoundCastData : public BoundCastData { + explicit ListBoundCastData(BoundCastInfo child_cast) : child_cast_info(std::move(child_cast)) { + } + + BoundCastInfo child_cast_info; + static unique_ptr BindListToListCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static unique_ptr InitListLocalState(CastLocalStateParameters ¶meters); + +public: + unique_ptr Copy() const override { + return make_uniq(child_cast_info.Copy()); + } +}; + +struct ListCast { + static bool ListToListCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters); +}; + +struct StructBoundCastData : public BoundCastData { + StructBoundCastData(vector child_casts, LogicalType target_p) + : child_cast_info(std::move(child_casts)), target(std::move(target_p)) { + } + + vector child_cast_info; + LogicalType target; + + static unique_ptr BindStructToStructCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static unique_ptr InitStructCastLocalState(CastLocalStateParameters ¶meters); + +public: + unique_ptr Copy() const override { + vector copy_info; + for (auto &info : child_cast_info) { + copy_info.push_back(info.Copy()); + } + return make_uniq(std::move(copy_info), target); + } +}; + +struct StructCastLocalState : public FunctionLocalState { +public: + vector> local_states; +}; + +struct MapBoundCastData : public BoundCastData { + MapBoundCastData(BoundCastInfo key_cast, BoundCastInfo value_cast) + : key_cast(std::move(key_cast)), value_cast(std::move(value_cast)) { + } + + BoundCastInfo key_cast; + BoundCastInfo value_cast; + + static unique_ptr BindMapToMapCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + +public: + unique_ptr Copy() const override { + return make_uniq(key_cast.Copy(), value_cast.Copy()); + } +}; + +struct MapCastLocalState : public FunctionLocalState { +public: + unique_ptr key_state; + unique_ptr value_state; +}; + +struct UnionBoundCastData : public BoundCastData { + UnionBoundCastData(union_tag_t member_idx, string name, LogicalType type, int64_t cost, + BoundCastInfo member_cast_info) + : tag(member_idx), name(std::move(name)), type(std::move(type)), cost(cost), + member_cast_info(std::move(member_cast_info)) { + } + + union_tag_t tag; + string name; + LogicalType type; + int64_t cost; + BoundCastInfo member_cast_info; + +public: + unique_ptr Copy() const override { + return make_uniq(tag, name, type, cost, member_cast_info.Copy()); + } + + static bool SortByCostAscending(const UnionBoundCastData &left, const UnionBoundCastData &right) { + return left.cost < right.cost; + } +}; + +struct StructToUnionCast { +public: + static bool AllowImplicitCastFromStruct(const LogicalType &source, const LogicalType &target); + static bool Cast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters); + static unique_ptr BindData(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static BoundCastInfo Bind(BindCastInput &input, const LogicalType &source, const LogicalType &target); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/cast/cast_function_set.hpp b/src/duckdb/src/include/duckdb/function/cast/cast_function_set.hpp new file mode 100644 index 00000000..3395725c --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/cast/cast_function_set.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/cast/cast_function_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/cast/default_casts.hpp" + +namespace duckdb { +struct MapCastInfo; +struct MapCastNode; + +typedef BoundCastInfo (*bind_cast_function_t)(BindCastInput &input, const LogicalType &source, + const LogicalType &target); +typedef int64_t (*implicit_cast_cost_t)(const LogicalType &from, const LogicalType &to); + +struct GetCastFunctionInput { + GetCastFunctionInput(optional_ptr context = nullptr) : context(context) { + } + GetCastFunctionInput(ClientContext &context) : context(&context) { + } + + optional_ptr context; +}; + +struct BindCastFunction { + BindCastFunction(bind_cast_function_t function, + unique_ptr info = nullptr); // NOLINT: allow implicit cast + + bind_cast_function_t function; + unique_ptr info; +}; + +class CastFunctionSet { +public: + CastFunctionSet(); + +public: + DUCKDB_API static CastFunctionSet &Get(ClientContext &context); + DUCKDB_API static CastFunctionSet &Get(DatabaseInstance &db); + + //! Returns a cast function (from source -> target) + //! Note that this always returns a function - since a cast is ALWAYS possible if the value is NULL + DUCKDB_API BoundCastInfo GetCastFunction(const LogicalType &source, const LogicalType &target, + GetCastFunctionInput &input); + //! Returns the implicit cast cost of casting from source -> target + //! -1 means an implicit cast is not possible + DUCKDB_API int64_t ImplicitCastCost(const LogicalType &source, const LogicalType &target); + //! Register a new cast function from source to target + DUCKDB_API void RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function, + int64_t implicit_cast_cost = -1); + DUCKDB_API void RegisterCastFunction(const LogicalType &source, const LogicalType &target, + bind_cast_function_t bind, int64_t implicit_cast_cost = -1); + +private: + vector bind_functions; + //! If any custom cast functions have been defined using RegisterCastFunction, this holds the map + optional_ptr map_info; + +private: + void RegisterCastFunction(const LogicalType &source, const LogicalType &target, MapCastNode node); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp new file mode 100644 index 00000000..97e9ff90 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp @@ -0,0 +1,162 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/cast/default_casts.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/function/scalar_function.hpp" + +namespace duckdb { + +class CastFunctionSet; +struct FunctionLocalState; + +//! Extra data that can be attached to a bind function of a cast, and is available during binding +struct BindCastInfo { + DUCKDB_API virtual ~BindCastInfo(); + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +//! Extra data that can be returned by the bind of a cast, and is available during execution of a cast +struct BoundCastData { + DUCKDB_API virtual ~BoundCastData(); + + DUCKDB_API virtual unique_ptr Copy() const = 0; + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct CastParameters { + CastParameters() { + } + CastParameters(BoundCastData *cast_data, bool strict, string *error_message, + optional_ptr local_state) + : cast_data(cast_data), strict(strict), error_message(error_message), local_state(local_state) { + } + CastParameters(CastParameters &parent, optional_ptr cast_data, + optional_ptr local_state) + : cast_data(cast_data), strict(parent.strict), error_message(parent.error_message), local_state(local_state) { + } + + //! The bound cast data (if any) + optional_ptr cast_data; + //! whether or not to enable strict casting + bool strict = false; + // out: error message in case cast has failed + string *error_message = nullptr; + //! Local state + optional_ptr local_state; +}; + +struct CastLocalStateParameters { + CastLocalStateParameters(optional_ptr context_p, optional_ptr cast_data_p) + : context(context_p), cast_data(cast_data_p) { + } + CastLocalStateParameters(ClientContext &context_p, optional_ptr cast_data_p) + : context(&context_p), cast_data(cast_data_p) { + } + CastLocalStateParameters(CastLocalStateParameters &parent, optional_ptr cast_data_p) + : context(parent.context), cast_data(cast_data_p) { + } + + optional_ptr context; + //! The bound cast data (if any) + optional_ptr cast_data; +}; + +typedef bool (*cast_function_t)(Vector &source, Vector &result, idx_t count, CastParameters ¶meters); +typedef unique_ptr (*init_cast_local_state_t)(CastLocalStateParameters ¶meters); + +struct BoundCastInfo { + DUCKDB_API + BoundCastInfo( + cast_function_t function, unique_ptr cast_data = nullptr, + init_cast_local_state_t init_local_state = nullptr); // NOLINT: allow explicit cast from cast_function_t + cast_function_t function; + init_cast_local_state_t init_local_state; + unique_ptr cast_data; + +public: + BoundCastInfo Copy() const; +}; + +struct BindCastInput { + DUCKDB_API BindCastInput(CastFunctionSet &function_set, optional_ptr info, + optional_ptr context); + + CastFunctionSet &function_set; + optional_ptr info; + optional_ptr context; + +public: + DUCKDB_API BoundCastInfo GetCastFunction(const LogicalType &source, const LogicalType &target); +}; + +struct DefaultCasts { + DUCKDB_API static BoundCastInfo GetDefaultCastFunction(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + + DUCKDB_API static bool NopCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters); + DUCKDB_API static bool TryVectorNullCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters); + DUCKDB_API static bool ReinterpretCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters); + +private: + static BoundCastInfo BlobCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo BitCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo DateCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo DecimalCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo EnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo IntervalCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo ListCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo NumericCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo MapCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo PointerCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo StringCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo StructCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo TimeCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo TimeTzCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo TimestampCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static BoundCastInfo TimestampTzCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static BoundCastInfo TimestampNsCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static BoundCastInfo TimestampMsCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static BoundCastInfo TimestampSecCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target); + static BoundCastInfo UnionCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + + static BoundCastInfo ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, + const LogicalType &target); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/cast/vector_cast_helpers.hpp b/src/duckdb/src/include/duckdb/function/cast/vector_cast_helpers.hpp new file mode 100644 index 00000000..c7138f97 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/cast/vector_cast_helpers.hpp @@ -0,0 +1,226 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/cast/vector_cast_helpers.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/general_cast.hpp" +#include "duckdb/common/operator/decimal_cast_operators.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/string_map_set.hpp" + +namespace duckdb { + +template +struct VectorStringCastOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto result = (Vector *)dataptr; + return OP::template Operation(input, *result); + } +}; + +struct VectorTryCastData { + VectorTryCastData(Vector &result_p, string *error_message_p, bool strict_p) + : result(result_p), error_message(error_message_p), strict(strict_p) { + } + + Vector &result; + string *error_message; + bool strict; + bool all_converted = true; +}; + +template +struct VectorTryCastOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + RESULT_TYPE output; + if (DUCKDB_LIKELY(OP::template Operation(input, output))) { + return output; + } + auto data = (VectorTryCastData *)dataptr; + return HandleVectorCastError::Operation(CastExceptionText(input), mask, + idx, data->error_message, data->all_converted); + } +}; + +template +struct VectorTryCastStrictOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (VectorTryCastData *)dataptr; + RESULT_TYPE output; + if (DUCKDB_LIKELY(OP::template Operation(input, output, data->strict))) { + return output; + } + return HandleVectorCastError::Operation(CastExceptionText(input), mask, + idx, data->error_message, data->all_converted); + } +}; + +template +struct VectorTryCastErrorOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (VectorTryCastData *)dataptr; + RESULT_TYPE output; + if (DUCKDB_LIKELY( + OP::template Operation(input, output, data->error_message, data->strict))) { + return output; + } + bool has_error = data->error_message && !data->error_message->empty(); + return HandleVectorCastError::Operation( + has_error ? *data->error_message : CastExceptionText(input), mask, idx, + data->error_message, data->all_converted); + } +}; + +template +struct VectorTryCastStringOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (VectorTryCastData *)dataptr; + RESULT_TYPE output; + if (DUCKDB_LIKELY(OP::template Operation(input, output, data->result, + data->error_message, data->strict))) { + return output; + } + return HandleVectorCastError::Operation(CastExceptionText(input), mask, + idx, data->error_message, data->all_converted); + } +}; + +struct VectorDecimalCastData { + VectorDecimalCastData(string *error_message_p, uint8_t width_p, uint8_t scale_p) + : error_message(error_message_p), width(width_p), scale(scale_p) { + } + + string *error_message; + uint8_t width; + uint8_t scale; + bool all_converted = true; +}; + +template +struct VectorDecimalCastOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + auto data = (VectorDecimalCastData *)dataptr; + RESULT_TYPE result_value; + if (!OP::template Operation(input, result_value, data->error_message, data->width, + data->scale)) { + return HandleVectorCastError::Operation("Failed to cast decimal value", mask, idx, + data->error_message, data->all_converted); + } + return result_value; + } +}; + +struct VectorCastHelpers { + template + static bool TemplatedCastLoop(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + UnaryExecutor::Execute(source, result, count); + return true; + } + + template + static bool TemplatedTryCastLoop(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + VectorTryCastData input(result, parameters.error_message, parameters.strict); + UnaryExecutor::GenericExecute(source, result, count, &input, parameters.error_message); + return input.all_converted; + } + + template + static bool TryCastLoop(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + return TemplatedTryCastLoop>(source, result, count, parameters); + } + + template + static bool TryCastStrictLoop(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + return TemplatedTryCastLoop>(source, result, count, parameters); + } + + template + static bool TryCastErrorLoop(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + return TemplatedTryCastLoop>(source, result, count, parameters); + } + + template + static bool TryCastStringLoop(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + return TemplatedTryCastLoop>(source, result, count, parameters); + } + + template + static bool StringCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + D_ASSERT(result.GetType().InternalType() == PhysicalType::VARCHAR); + UnaryExecutor::GenericExecute>(source, result, count, + (void *)&result); + return true; + } + + template + static bool TemplatedDecimalCast(Vector &source, Vector &result, idx_t count, string *error_message, uint8_t width, + uint8_t scale) { + VectorDecimalCastData input(error_message, width, scale); + UnaryExecutor::GenericExecute>(source, result, count, (void *)&input, + error_message); + return input.all_converted; + } + + template + static bool ToDecimalCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &result_type = result.GetType(); + auto width = DecimalType::GetWidth(result_type); + auto scale = DecimalType::GetScale(result_type); + switch (result_type.InternalType()) { + case PhysicalType::INT16: + return TemplatedDecimalCast(source, result, count, parameters.error_message, + width, scale); + case PhysicalType::INT32: + return TemplatedDecimalCast(source, result, count, parameters.error_message, + width, scale); + case PhysicalType::INT64: + return TemplatedDecimalCast(source, result, count, parameters.error_message, + width, scale); + case PhysicalType::INT128: + return TemplatedDecimalCast(source, result, count, parameters.error_message, + width, scale); + default: + throw InternalException("Unimplemented internal type for decimal"); + } + } +}; + +struct VectorStringToList { + static idx_t CountPartsList(const string_t &input); + static bool SplitStringList(const string_t &input, string_t *child_data, idx_t &child_start, Vector &child); + static bool StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, Vector &result, + ValidityMask &result_mask, idx_t count, CastParameters ¶meters, + const SelectionVector *sel); +}; + +struct VectorStringToStruct { + static bool SplitStruct(const string_t &input, vector> &varchar_vectors, idx_t &row_idx, + string_map_t &child_names, vector &child_masks); + static bool StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, Vector &result, + ValidityMask &result_mask, idx_t count, CastParameters ¶meters, + const SelectionVector *sel); +}; + +struct VectorStringToMap { + static idx_t CountPartsMap(const string_t &input); + static bool SplitStringMap(const string_t &input, string_t *child_key_data, string_t *child_val_data, + idx_t &child_start, Vector &varchar_key, Vector &varchar_val); + static bool StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, Vector &result, + ValidityMask &result_mask, idx_t count, CastParameters ¶meters, + const SelectionVector *sel); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/cast_rules.hpp b/src/duckdb/src/include/duckdb/function/cast_rules.hpp new file mode 100644 index 00000000..ff61efed --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/cast_rules.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/cast_rules.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" + +namespace duckdb { +//! Contains a list of rules for casting +class CastRules { +public: + //! Returns the cost of performing an implicit cost from "from" to "to", or -1 if an implicit cast is not possible + static int64_t ImplicitCast(const LogicalType &from, const LogicalType &to); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/compression/compression.hpp b/src/duckdb/src/include/duckdb/function/compression/compression.hpp new file mode 100644 index 00000000..6339ed9f --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/compression/compression.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/compression/compression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/compression_function.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ConstantFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +struct UncompressedFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +struct RLEFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +struct BitpackingFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +struct DictionaryCompressionFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +struct ChimpCompressionFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +struct PatasCompressionFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +struct FSSTFun { + static CompressionFunction GetFunction(PhysicalType type); + static bool TypeIsSupported(PhysicalType type); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/compression_function.hpp b/src/duckdb/src/include/duckdb/function/compression_function.hpp new file mode 100644 index 00000000..24bb45e0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/compression_function.hpp @@ -0,0 +1,261 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/compression_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/common/enums/compression_type.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/storage/data_pointer.hpp" + +namespace duckdb { +class DatabaseInstance; +class ColumnData; +class ColumnDataCheckpointer; +class ColumnSegment; +class SegmentStatistics; +struct ColumnSegmentState; + +struct ColumnFetchState; +struct ColumnScanState; +struct SegmentScanState; + +struct AnalyzeState { + virtual ~AnalyzeState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct CompressionState { + virtual ~CompressionState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct CompressedSegmentState { + virtual ~CompressedSegmentState() { + } + + //! Display info for PRAGMA storage_info + virtual string GetSegmentInfo() const { // LCOV_EXCL_START + return ""; + } // LCOV_EXCL_STOP + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct CompressionAppendState { + explicit CompressionAppendState(BufferHandle handle_p) : handle(std::move(handle_p)) { + } + virtual ~CompressionAppendState() { + } + + BufferHandle handle; + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +//! The analyze functions are used to determine whether or not to use this compression method +//! The system first determines the potential compression methods to use based on the physical type of the column +//! After that the following steps are taken: +//! 1. The init_analyze is called to initialize the analyze state of every candidate compression method +//! 2. The analyze method is called with all of the input data in the order in which it must be stored. +//! analyze can return "false". In that case, the compression method is taken out of consideration early. +//! 3. The final_analyze method is called, which should return a score for the compression method + +//! The system then decides which compression function to use based on the analyzed score (returned from final_analyze) +typedef unique_ptr (*compression_init_analyze_t)(ColumnData &col_data, PhysicalType type); +typedef bool (*compression_analyze_t)(AnalyzeState &state, Vector &input, idx_t count); +typedef idx_t (*compression_final_analyze_t)(AnalyzeState &state); + +//===--------------------------------------------------------------------===// +// Compress +//===--------------------------------------------------------------------===// +typedef unique_ptr (*compression_init_compression_t)(ColumnDataCheckpointer &checkpointer, + unique_ptr state); +typedef void (*compression_compress_data_t)(CompressionState &state, Vector &scan_vector, idx_t count); +typedef void (*compression_compress_finalize_t)(CompressionState &state); + +//===--------------------------------------------------------------------===// +// Uncompress / Scan +//===--------------------------------------------------------------------===// +typedef unique_ptr (*compression_init_segment_scan_t)(ColumnSegment &segment); + +//! Function prototype used for reading an entire vector (STANDARD_VECTOR_SIZE) +typedef void (*compression_scan_vector_t)(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, + Vector &result); +//! Function prototype used for reading an arbitrary ('scan_count') number of values +typedef void (*compression_scan_partial_t)(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, + Vector &result, idx_t result_offset); +//! Function prototype used for reading a single value +typedef void (*compression_fetch_row_t)(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx); +//! Function prototype used for skipping 'skip_count' values, non-trivial if random-access is not supported for the +//! compressed data. +typedef void (*compression_skip_t)(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count); + +//===--------------------------------------------------------------------===// +// Append (optional) +//===--------------------------------------------------------------------===// +typedef unique_ptr (*compression_init_segment_t)( + ColumnSegment &segment, block_id_t block_id, optional_ptr segment_state); +typedef unique_ptr (*compression_init_append_t)(ColumnSegment &segment); +typedef idx_t (*compression_append_t)(CompressionAppendState &append_state, ColumnSegment &segment, + SegmentStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count); +typedef idx_t (*compression_finalize_append_t)(ColumnSegment &segment, SegmentStatistics &stats); +typedef void (*compression_revert_append_t)(ColumnSegment &segment, idx_t start_row); + +//===--------------------------------------------------------------------===// +// Serialization (optional) +//===--------------------------------------------------------------------===// +//! Function prototype for serializing the segment state +typedef unique_ptr (*compression_serialize_state_t)(ColumnSegment &segment); +//! Function prototype for deserializing the segment state +typedef unique_ptr (*compression_deserialize_state_t)(Deserializer &deserializer); +//! Function prototype for cleaning up the segment state when the column data is dropped +typedef void (*compression_cleanup_state_t)(ColumnSegment &segment); + +class CompressionFunction { +public: + CompressionFunction(CompressionType type, PhysicalType data_type, compression_init_analyze_t init_analyze, + compression_analyze_t analyze, compression_final_analyze_t final_analyze, + compression_init_compression_t init_compression, compression_compress_data_t compress, + compression_compress_finalize_t compress_finalize, compression_init_segment_scan_t init_scan, + compression_scan_vector_t scan_vector, compression_scan_partial_t scan_partial, + compression_fetch_row_t fetch_row, compression_skip_t skip, + compression_init_segment_t init_segment = nullptr, + compression_init_append_t init_append = nullptr, compression_append_t append = nullptr, + compression_finalize_append_t finalize_append = nullptr, + compression_revert_append_t revert_append = nullptr, + compression_serialize_state_t serialize_state = nullptr, + compression_deserialize_state_t deserialize_state = nullptr, + compression_cleanup_state_t cleanup_state = nullptr) + : type(type), data_type(data_type), init_analyze(init_analyze), analyze(analyze), final_analyze(final_analyze), + init_compression(init_compression), compress(compress), compress_finalize(compress_finalize), + init_scan(init_scan), scan_vector(scan_vector), scan_partial(scan_partial), fetch_row(fetch_row), skip(skip), + init_segment(init_segment), init_append(init_append), append(append), finalize_append(finalize_append), + revert_append(revert_append), serialize_state(serialize_state), deserialize_state(deserialize_state), + cleanup_state(cleanup_state) { + } + + //! Compression type + CompressionType type; + //! The data type this function can compress + PhysicalType data_type; + + //! Analyze step: determine which compression function is the most effective + //! init_analyze is called once to set up the analyze state + compression_init_analyze_t init_analyze; + //! analyze is called several times (once per vector in the row group) + //! analyze should return true, unless compression is no longer possible with this compression method + //! in that case false should be returned + compression_analyze_t analyze; + //! final_analyze should return the score of the compression function + //! ideally this is the exact number of bytes required to store the data + //! this is not required/enforced: it can be an estimate as well + //! also this function can return DConstants::INVALID_INDEX to skip this compression method + compression_final_analyze_t final_analyze; + + //! Compression step: actually compress the data + //! init_compression is called once to set up the comperssion state + compression_init_compression_t init_compression; + //! compress is called several times (once per vector in the row group) + compression_compress_data_t compress; + //! compress_finalize is called after + compression_compress_finalize_t compress_finalize; + + //! init_scan is called to set up the scan state + compression_init_segment_scan_t init_scan; + //! scan_vector scans an entire vector using the scan state + compression_scan_vector_t scan_vector; + //! scan_partial scans a subset of a vector + //! this can request > vector_size as well + //! this is used if a vector crosses segment boundaries, or for child columns of lists + compression_scan_partial_t scan_partial; + //! fetch an individual row from the compressed vector + //! used for index lookups + compression_fetch_row_t fetch_row; + //! Skip forward in the compressed segment + compression_skip_t skip; + + // Append functions + //! This only really needs to be defined for uncompressed segments + + //! Initialize a compressed segment (optional) + compression_init_segment_t init_segment; + //! Initialize the append state (optional) + compression_init_append_t init_append; + //! Append to the compressed segment (optional) + compression_append_t append; + //! Finalize an append to the segment + compression_finalize_append_t finalize_append; + //! Revert append (optional) + compression_revert_append_t revert_append; + + // State serialize functions + //! This is only necessary if the segment state has information that must be written to disk in the metadata + + //! Serialize the segment state to the metadata (optional) + compression_serialize_state_t serialize_state; + //! Deserialize the segment state to the metadata (optional) + compression_deserialize_state_t deserialize_state; + //! Cleanup the segment state (optional) + compression_cleanup_state_t cleanup_state; +}; + +//! The set of compression functions +struct CompressionFunctionSet { + mutex lock; + map> functions; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/copy_function.hpp b/src/duckdb/src/include/duckdb/function/copy_function.hpp new file mode 100644 index 00000000..1dd25eb0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/copy_function.hpp @@ -0,0 +1,142 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/copy_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" + +namespace duckdb { + +class Binder; +struct BoundStatement; +class ColumnDataCollection; +class ExecutionContext; + +struct LocalFunctionData { + virtual ~LocalFunctionData() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct GlobalFunctionData { + virtual ~GlobalFunctionData() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct PreparedBatchData { + virtual ~PreparedBatchData() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +enum class CopyFunctionExecutionMode { REGULAR_COPY_TO_FILE, PARALLEL_COPY_TO_FILE, BATCH_COPY_TO_FILE }; + +typedef BoundStatement (*copy_to_plan_t)(Binder &binder, CopyStatement &stmt); +typedef unique_ptr (*copy_to_bind_t)(ClientContext &context, CopyInfo &info, vector &names, + vector &sql_types); +typedef unique_ptr (*copy_to_initialize_local_t)(ExecutionContext &context, FunctionData &bind_data); +typedef unique_ptr (*copy_to_initialize_global_t)(ClientContext &context, FunctionData &bind_data, + const string &file_path); +typedef void (*copy_to_sink_t)(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + LocalFunctionData &lstate, DataChunk &input); +typedef void (*copy_to_combine_t)(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + LocalFunctionData &lstate); +typedef void (*copy_to_finalize_t)(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate); + +typedef void (*copy_to_serialize_t)(Serializer &serializer, const FunctionData &bind_data, + const CopyFunction &function); + +typedef unique_ptr (*copy_to_deserialize_t)(Deserializer &deserializer, CopyFunction &function); + +typedef unique_ptr (*copy_from_bind_t)(ClientContext &context, CopyInfo &info, + vector &expected_names, + vector &expected_types); +typedef CopyFunctionExecutionMode (*copy_to_execution_mode_t)(bool preserve_insertion_order, bool supports_batch_index); + +typedef unique_ptr (*copy_prepare_batch_t)(ClientContext &context, FunctionData &bind_data, + GlobalFunctionData &gstate, + unique_ptr collection); +typedef void (*copy_flush_batch_t)(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + PreparedBatchData &batch); +typedef idx_t (*copy_desired_batch_size_t)(ClientContext &context, FunctionData &bind_data); + +enum class CopyTypeSupport { SUPPORTED, LOSSY, UNSUPPORTED }; + +typedef CopyTypeSupport (*copy_supports_type_t)(const LogicalType &type); + +class CopyFunction : public Function { +public: + explicit CopyFunction(string name) + : Function(name), plan(nullptr), copy_to_bind(nullptr), copy_to_initialize_local(nullptr), + copy_to_initialize_global(nullptr), copy_to_sink(nullptr), copy_to_combine(nullptr), + copy_to_finalize(nullptr), execution_mode(nullptr), prepare_batch(nullptr), flush_batch(nullptr), + desired_batch_size(nullptr), serialize(nullptr), deserialize(nullptr), supports_type(nullptr), + copy_from_bind(nullptr) { + } + + //! Plan rewrite copy function + copy_to_plan_t plan; + + copy_to_bind_t copy_to_bind; + copy_to_initialize_local_t copy_to_initialize_local; + copy_to_initialize_global_t copy_to_initialize_global; + copy_to_sink_t copy_to_sink; + copy_to_combine_t copy_to_combine; + copy_to_finalize_t copy_to_finalize; + copy_to_execution_mode_t execution_mode; + + copy_prepare_batch_t prepare_batch; + copy_flush_batch_t flush_batch; + copy_desired_batch_size_t desired_batch_size; + + copy_to_serialize_t serialize; + copy_to_deserialize_t deserialize; + + copy_supports_type_t supports_type; + + copy_from_bind_t copy_from_bind; + TableFunction copy_from_function; + + string extension; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/function.hpp b/src/duckdb/src/include/duckdb/function/function.hpp new file mode 100644 index 00000000..9a9bb9e4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/function.hpp @@ -0,0 +1,169 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/named_parameter_map.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/main/external_dependencies.hpp" +#include "duckdb/parser/column_definition.hpp" + +namespace duckdb { +class CatalogEntry; +class Catalog; +class ClientContext; +class Expression; +class ExpressionExecutor; +class Transaction; + +class AggregateFunction; +class AggregateFunctionSet; +class CopyFunction; +class PragmaFunction; +class PragmaFunctionSet; +class ScalarFunctionSet; +class ScalarFunction; +class TableFunctionSet; +class TableFunction; +class SimpleFunction; + +struct PragmaInfo; + +//! The default null handling is NULL in, NULL out +enum class FunctionNullHandling : uint8_t { DEFAULT_NULL_HANDLING = 0, SPECIAL_HANDLING = 1 }; +enum class FunctionSideEffects : uint8_t { NO_SIDE_EFFECTS = 0, HAS_SIDE_EFFECTS = 1 }; + +struct FunctionData { + DUCKDB_API virtual ~FunctionData(); + + DUCKDB_API virtual unique_ptr Copy() const = 0; + DUCKDB_API virtual bool Equals(const FunctionData &other) const = 0; + DUCKDB_API static bool Equals(const FunctionData *left, const FunctionData *right); + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + // FIXME: this function should be removed in the future + template + TARGET &CastNoConst() const { + return const_cast(reinterpret_cast(*this)); + } +}; + +struct TableFunctionData : public FunctionData { + // used to pass on projections to table functions that support them. NB, can contain COLUMN_IDENTIFIER_ROW_ID + vector column_ids; + + DUCKDB_API virtual ~TableFunctionData(); + + DUCKDB_API unique_ptr Copy() const override; + DUCKDB_API bool Equals(const FunctionData &other) const override; +}; + +struct PyTableFunctionData : public TableFunctionData { + //! External dependencies of this table function + unique_ptr external_dependency; +}; + +struct FunctionParameters { + vector values; + named_parameter_map_t named_parameters; +}; + +//! Function is the base class used for any type of function (scalar, aggregate or simple function) +class Function { +public: + DUCKDB_API explicit Function(string name); + DUCKDB_API virtual ~Function(); + + //! The name of the function + string name; + //! Additional Information to specify function from it's name + string extra_info; + +public: + //! Returns the formatted string name(arg1, arg2, ...) + DUCKDB_API static string CallToString(const string &name, const vector &arguments); + //! Returns the formatted string name(arg1, arg2..) -> return_type + DUCKDB_API static string CallToString(const string &name, const vector &arguments, + const LogicalType &return_type); + //! Returns the formatted string name(arg1, arg2.., np1=a, np2=b, ...) + DUCKDB_API static string CallToString(const string &name, const vector &arguments, + const named_parameter_type_map_t &named_parameters); + + //! Used in the bind to erase an argument from a function + DUCKDB_API static void EraseArgument(SimpleFunction &bound_function, vector> &arguments, + idx_t argument_index); +}; + +class SimpleFunction : public Function { +public: + DUCKDB_API SimpleFunction(string name, vector arguments, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); + DUCKDB_API ~SimpleFunction() override; + + //! The set of arguments of the function + vector arguments; + //! The set of original arguments of the function - only set if Function::EraseArgument is called + //! Used for (de)serialization purposes + vector original_arguments; + //! The type of varargs to support, or LogicalTypeId::INVALID if the function does not accept variable length + //! arguments + LogicalType varargs; + +public: + DUCKDB_API virtual string ToString() const; + + DUCKDB_API bool HasVarArgs() const; +}; + +class SimpleNamedParameterFunction : public SimpleFunction { +public: + DUCKDB_API SimpleNamedParameterFunction(string name, vector arguments, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); + DUCKDB_API ~SimpleNamedParameterFunction() override; + + //! The named parameters of the function + named_parameter_type_map_t named_parameters; + +public: + DUCKDB_API string ToString() const override; + DUCKDB_API bool HasNamedParameters() const; +}; + +class BaseScalarFunction : public SimpleFunction { +public: + DUCKDB_API BaseScalarFunction(string name, vector arguments, LogicalType return_type, + FunctionSideEffects side_effects, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID), + FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); + DUCKDB_API ~BaseScalarFunction() override; + + //! Return type of the function + LogicalType return_type; + //! Whether or not the function has side effects (e.g. sequence increments, random() functions, NOW()). Functions + //! with side-effects cannot be constant-folded. + FunctionSideEffects side_effects; + //! How this function handles NULL values + FunctionNullHandling null_handling; + +public: + DUCKDB_API hash_t Hash() const; + + DUCKDB_API string ToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/function_binder.hpp b/src/duckdb/src/include/duckdb/function/function_binder.hpp new file mode 100644 index 00000000..2d02b7ee --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/function_binder.hpp @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/function_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +//! The FunctionBinder class is responsible for binding functions +class FunctionBinder { +public: + DUCKDB_API explicit FunctionBinder(ClientContext &context); + + ClientContext &context; + +public: + //! Bind a scalar function from the set of functions and input arguments. Returns the index of the chosen function, + //! returns DConstants::INVALID_INDEX and sets error if none could be found + DUCKDB_API idx_t BindFunction(const string &name, ScalarFunctionSet &functions, + const vector &arguments, string &error); + DUCKDB_API idx_t BindFunction(const string &name, ScalarFunctionSet &functions, + vector> &arguments, string &error); + //! Bind an aggregate function from the set of functions and input arguments. Returns the index of the chosen + //! function, returns DConstants::INVALID_INDEX and sets error if none could be found + DUCKDB_API idx_t BindFunction(const string &name, AggregateFunctionSet &functions, + const vector &arguments, string &error); + DUCKDB_API idx_t BindFunction(const string &name, AggregateFunctionSet &functions, + vector> &arguments, string &error); + //! Bind a table function from the set of functions and input arguments. Returns the index of the chosen + //! function, returns DConstants::INVALID_INDEX and sets error if none could be found + DUCKDB_API idx_t BindFunction(const string &name, TableFunctionSet &functions, const vector &arguments, + string &error); + DUCKDB_API idx_t BindFunction(const string &name, TableFunctionSet &functions, + vector> &arguments, string &error); + //! Bind a pragma function from the set of functions and input arguments + DUCKDB_API idx_t BindFunction(const string &name, PragmaFunctionSet &functions, PragmaInfo &info, string &error); + + DUCKDB_API unique_ptr BindScalarFunction(const string &schema, const string &name, + vector> children, string &error, + bool is_operator = false, Binder *binder = nullptr); + DUCKDB_API unique_ptr BindScalarFunction(ScalarFunctionCatalogEntry &function, + vector> children, string &error, + bool is_operator = false, Binder *binder = nullptr); + + DUCKDB_API unique_ptr BindScalarFunction(ScalarFunction bound_function, + vector> children, + bool is_operator = false); + + DUCKDB_API unique_ptr + BindAggregateFunction(AggregateFunction bound_function, vector> children, + unique_ptr filter = nullptr, + AggregateType aggr_type = AggregateType::NON_DISTINCT); + + DUCKDB_API static void BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, + const vector> &groups); + +private: + //! Cast a set of expressions to the arguments of this function + void CastToFunctionArguments(SimpleFunction &function, vector> &children); + int64_t BindVarArgsFunctionCost(const SimpleFunction &func, const vector &arguments); + int64_t BindFunctionCost(const SimpleFunction &func, const vector &arguments); + + template + vector BindFunctionsFromArguments(const string &name, FunctionSet &functions, + const vector &arguments, string &error); + + template + idx_t MultipleCandidateException(const string &name, FunctionSet &functions, vector &candidate_functions, + const vector &arguments, string &error); + + template + idx_t BindFunctionFromArguments(const string &name, FunctionSet &functions, const vector &arguments, + string &error); + + vector GetLogicalTypesFromExpressions(vector> &arguments); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/function_serialization.hpp b/src/duckdb/src/include/duckdb/function/function_serialization.hpp new file mode 100644 index 00000000..ac244f38 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/function_serialization.hpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/function_serialization.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/client_context.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +class FunctionSerializer { +public: + template + static void Serialize(Serializer &serializer, const FUNC &function, optional_ptr bind_info) { + D_ASSERT(!function.name.empty()); + serializer.WriteProperty(500, "name", function.name); + serializer.WriteProperty(501, "arguments", function.arguments); + serializer.WriteProperty(502, "original_arguments", function.original_arguments); + bool has_serialize = function.serialize; + serializer.WriteProperty(503, "has_serialize", has_serialize); + if (has_serialize) { + serializer.WriteObject(504, "function_data", + [&](Serializer &obj) { function.serialize(obj, bind_info, function); }); + D_ASSERT(function.deserialize); + } + } + + template + static FUNC DeserializeFunction(ClientContext &context, CatalogType catalog_type, const string &name, + vector arguments, vector original_arguments) { + auto &func_catalog = Catalog::GetEntry(context, catalog_type, SYSTEM_CATALOG, DEFAULT_SCHEMA, name); + if (func_catalog.type != catalog_type) { + throw InternalException("DeserializeFunction - cant find catalog entry for function %s", name); + } + auto &functions = func_catalog.Cast(); + auto function = functions.functions.GetFunctionByArguments( + context, original_arguments.empty() ? arguments : original_arguments); + function.arguments = std::move(arguments); + function.original_arguments = std::move(original_arguments); + return function; + } + + template + static pair DeserializeBase(Deserializer &deserializer, CatalogType catalog_type) { + auto &context = deserializer.Get(); + auto name = deserializer.ReadProperty(500, "name"); + auto arguments = deserializer.ReadProperty>(501, "arguments"); + auto original_arguments = deserializer.ReadProperty>(502, "original_arguments"); + auto function = DeserializeFunction(context, catalog_type, name, std::move(arguments), + std::move(original_arguments)); + auto has_serialize = deserializer.ReadProperty(503, "has_serialize"); + return make_pair(std::move(function), has_serialize); + } + + template + static unique_ptr FunctionDeserialize(Deserializer &deserializer, FUNC &function) { + if (!function.deserialize) { + throw SerializationException("Function requires deserialization but no deserialization function for %s", + function.name); + } + unique_ptr result; + deserializer.ReadObject(504, "function_data", + [&](Deserializer &obj) { result = function.deserialize(obj, function); }); + return result; + } + + template + static pair> Deserialize(Deserializer &deserializer, CatalogType catalog_type, + vector> &children, + LogicalType return_type) { + auto &context = deserializer.Get(); + auto entry = DeserializeBase(deserializer, catalog_type); + auto &function = entry.first; + auto has_serialize = entry.second; + + unique_ptr bind_data; + if (has_serialize) { + bind_data = FunctionDeserialize(deserializer, function); + } else if (function.bind) { + try { + bind_data = function.bind(context, function, children); + } catch (Exception &ex) { + // FIXME + throw SerializationException("Error during bind of function in deserialization: %s", ex.what()); + } + } + function.return_type = std::move(return_type); + return make_pair(std::move(function), std::move(bind_data)); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/function_set.hpp b/src/duckdb/src/include/duckdb/function/function_set.hpp new file mode 100644 index 00000000..38989ad9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/function_set.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/function_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/function/pragma_function.hpp" + +namespace duckdb { + +template +class FunctionSet { +public: + explicit FunctionSet(string name) : name(name) { + } + + //! The name of the function set + string name; + //! The set of functions. + vector functions; + +public: + void AddFunction(T function) { + functions.push_back(std::move(function)); + } + idx_t Size() { + return functions.size(); + } + T GetFunctionByOffset(idx_t offset) { + D_ASSERT(offset < functions.size()); + return functions[offset]; + } + T &GetFunctionReferenceByOffset(idx_t offset) { + D_ASSERT(offset < functions.size()); + return functions[offset]; + } + bool MergeFunctionSet(FunctionSet new_functions) { + D_ASSERT(!new_functions.functions.empty()); + bool need_rewrite_entry = false; + for (auto &new_func : new_functions.functions) { + bool can_add = true; + for (auto &func : functions) { + if (new_func.Equal(func)) { + can_add = false; + break; + } + } + if (can_add) { + functions.push_back(new_func); + need_rewrite_entry = true; + } + } + return need_rewrite_entry; + } +}; + +class ScalarFunctionSet : public FunctionSet { +public: + DUCKDB_API explicit ScalarFunctionSet(); + DUCKDB_API explicit ScalarFunctionSet(string name); + DUCKDB_API explicit ScalarFunctionSet(ScalarFunction fun); + + DUCKDB_API ScalarFunction GetFunctionByArguments(ClientContext &context, const vector &arguments); +}; + +class AggregateFunctionSet : public FunctionSet { +public: + DUCKDB_API explicit AggregateFunctionSet(); + DUCKDB_API explicit AggregateFunctionSet(string name); + DUCKDB_API explicit AggregateFunctionSet(AggregateFunction fun); + + DUCKDB_API AggregateFunction GetFunctionByArguments(ClientContext &context, const vector &arguments); +}; + +class TableFunctionSet : public FunctionSet { +public: + DUCKDB_API explicit TableFunctionSet(string name); + DUCKDB_API explicit TableFunctionSet(TableFunction fun); + + TableFunction GetFunctionByArguments(ClientContext &context, const vector &arguments); +}; + +class PragmaFunctionSet : public FunctionSet { +public: + DUCKDB_API explicit PragmaFunctionSet(string name); + DUCKDB_API explicit PragmaFunctionSet(PragmaFunction fun); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/macro_function.hpp b/src/duckdb/src/include/duckdb/function/macro_function.hpp new file mode 100644 index 00000000..0a9bdb5a --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/macro_function.hpp @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/macro_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/query_node.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" + +namespace duckdb { + +enum class MacroType : uint8_t { VOID_MACRO = 0, TABLE_MACRO = 1, SCALAR_MACRO = 2 }; + +class MacroFunction { +public: + explicit MacroFunction(MacroType type); + + //! The type + MacroType type; + //! The positional parameters + vector> parameters; + //! The default parameters and their associated values + unordered_map> default_parameters; + +public: + virtual ~MacroFunction() { + } + + void CopyProperties(MacroFunction &other) const; + + virtual unique_ptr Copy() const = 0; + + static string ValidateArguments(MacroFunction ¯o_function, const string &name, + FunctionExpression &function_expr, + vector> &positionals, + unordered_map> &defaults); + + virtual string ToSQL(const string &schema, const string &name) const; + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast macro to type - macro type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast macro to type - macro type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/pragma/pragma_functions.hpp b/src/duckdb/src/include/duckdb/function/pragma/pragma_functions.hpp new file mode 100644 index 00000000..5989d6de --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/pragma/pragma_functions.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/pragma/pragma_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/pragma_function.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +struct PragmaQueries { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaFunctions { + static void RegisterFunction(BuiltinFunctions &set); +}; + +string PragmaShow(ClientContext &context, const FunctionParameters ¶meters); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/pragma_function.hpp b/src/duckdb/src/include/duckdb/function/pragma_function.hpp new file mode 100644 index 00000000..85e86baa --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/pragma_function.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/pragma_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { +class ClientContext; + +//! Return a substitute query to execute instead of this pragma statement +typedef string (*pragma_query_t)(ClientContext &context, const FunctionParameters ¶meters); +//! Execute the main pragma function +typedef void (*pragma_function_t)(ClientContext &context, const FunctionParameters ¶meters); + +//! Pragma functions are invoked by calling PRAGMA x +//! Pragma functions come in three types: +//! * Call: function call, e.g. PRAGMA table_info('tbl') +//! -> call statements can take multiple parameters +//! * Statement: statement without parameters, e.g. PRAGMA show_tables +//! -> this is similar to a call pragma but without parameters +//! Pragma functions can either return a new query to execute (pragma_query_t) +//! or they can +class PragmaFunction : public SimpleNamedParameterFunction { +public: + // Call + DUCKDB_API static PragmaFunction PragmaCall(const string &name, pragma_query_t query, vector arguments, + LogicalType varargs = LogicalType::INVALID); + DUCKDB_API static PragmaFunction PragmaCall(const string &name, pragma_function_t function, + vector arguments, + LogicalType varargs = LogicalType::INVALID); + // Statement + DUCKDB_API static PragmaFunction PragmaStatement(const string &name, pragma_query_t query); + DUCKDB_API static PragmaFunction PragmaStatement(const string &name, pragma_function_t function); + + DUCKDB_API string ToString() const override; + +public: + PragmaType type; + + pragma_query_t query; + pragma_function_t function; + named_parameter_type_map_t named_parameters; + +private: + PragmaFunction(string name, PragmaType pragma_type, pragma_query_t query, pragma_function_t function, + vector arguments, LogicalType varargs); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/replacement_scan.hpp b/src/duckdb/src/include/duckdb/function/replacement_scan.hpp new file mode 100644 index 00000000..3447b96d --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/replacement_scan.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/replacement_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +class ClientContext; +class TableRef; + +struct ReplacementScanData { + virtual ~ReplacementScanData() { + } +}; + +typedef unique_ptr (*replacement_scan_t)(ClientContext &context, const string &table_name, + ReplacementScanData *data); + +//! Replacement table scans are automatically attempted when a table name cannot be found in the schema +//! This allows you to do e.g. SELECT * FROM 'filename.csv', and automatically convert this into a CSV scan +struct ReplacementScan { + explicit ReplacementScan(replacement_scan_t function, unique_ptr data_p = nullptr) + : function(function), data(std::move(data_p)) { + } + + replacement_scan_t function; + unique_ptr data; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp new file mode 100644 index 00000000..aab5dede --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/compressed_materialization_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/built_in_functions.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CompressedMaterializationFunctions { + //! The types we compress integral types to + static const vector IntegralTypes(); + //! The types we compress strings to + static const vector StringTypes(); + + static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); +}; + +//! Needed for (de)serialization without binding +enum class CompressedMaterializationDirection : uint8_t { INVALID = 0, COMPRESS = 1, DECOMPRESS = 2 }; + +struct CMIntegralCompressFun { + static ScalarFunction GetFunction(const LogicalType &input_type, const LogicalType &result_type); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct CMIntegralDecompressFun { + static ScalarFunction GetFunction(const LogicalType &input_type, const LogicalType &result_type); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct CMStringCompressFun { + static ScalarFunction GetFunction(const LogicalType &result_type); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct CMStringDecompressFun { + static ScalarFunction GetFunction(const LogicalType &input_type); + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/generic_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/generic_functions.hpp new file mode 100644 index 00000000..cef16195 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/generic_functions.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/generic_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { +class BoundFunctionExpression; + +struct ConstantOrNull { + static ScalarFunction GetFunction(const LogicalType &return_type); + static unique_ptr Bind(Value value); + static bool IsConstantOrNull(BoundFunctionExpression &expr, const Value &val); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ExportAggregateFunctionBindData : public FunctionData { + unique_ptr aggregate; + explicit ExportAggregateFunctionBindData(unique_ptr aggregate_p); + unique_ptr Copy() const override; + bool Equals(const FunctionData &other_p) const override; +}; + +struct ExportAggregateFunction { + static unique_ptr Bind(unique_ptr child_aggregate); + static ScalarFunction GetCombine(); + static ScalarFunction GetFinalize(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/list/contains_or_position.hpp b/src/duckdb/src/include/duckdb/function/scalar/list/contains_or_position.hpp new file mode 100644 index 00000000..0eceb782 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/list/contains_or_position.hpp @@ -0,0 +1,138 @@ +#pragma once + +namespace duckdb { + +template +static void TemplatedContainsOrPosition(DataChunk &args, Vector &result, bool is_nested = false) { + D_ASSERT(args.ColumnCount() == 2); + auto count = args.size(); + Vector &list = LIST_ACCESSOR::GetList(args.data[0]); + Vector &value_vector = args.data[1]; + + // Create a result vector of type RETURN_TYPE + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_entries = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + if (list.GetType().id() == LogicalTypeId::SQLNULL) { + result_validity.SetInvalid(0); + return; + } + + auto list_size = LIST_ACCESSOR::GetListSize(list); + auto &child_vector = LIST_ACCESSOR::GetEntry(list); + + UnifiedVectorFormat child_data; + child_vector.ToUnifiedFormat(list_size, child_data); + + UnifiedVectorFormat list_data; + list.ToUnifiedFormat(count, list_data); + auto list_entries = UnifiedVectorFormat::GetData(list_data); + + UnifiedVectorFormat value_data; + value_vector.ToUnifiedFormat(count, value_data); + + // not required for a comparison of nested types + auto child_value = UnifiedVectorFormat::GetData(child_data); + auto values = UnifiedVectorFormat::GetData(value_data); + + for (idx_t i = 0; i < count; i++) { + auto list_index = list_data.sel->get_index(i); + auto value_index = value_data.sel->get_index(i); + + if (!list_data.validity.RowIsValid(list_index) || !value_data.validity.RowIsValid(value_index)) { + result_validity.SetInvalid(i); + continue; + } + + const auto &list_entry = list_entries[list_index]; + + result_entries[i] = OP::Initialize(); + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + + auto child_value_idx = child_data.sel->get_index(list_entry.offset + child_idx); + if (!child_data.validity.RowIsValid(child_value_idx)) { + continue; + } + + if (!is_nested) { + if (Equals::Operation(child_value[child_value_idx], values[value_index])) { + result_entries[i] = OP::UpdateResultEntries(child_idx); + break; // Found value in list, no need to look further + } + } else { + // FIXME: using Value is less efficient than modifying the vector comparison code + // to more efficiently compare nested types + + // Note: When using GetValue we don't first apply the selection vector + // because it is already done inside GetValue + auto lvalue = child_vector.GetValue(list_entry.offset + child_idx); + auto rvalue = value_vector.GetValue(i); + if (Value::NotDistinctFrom(lvalue, rvalue)) { + result_entries[i] = OP::UpdateResultEntries(child_idx); + break; // Found value in list, no need to look further + } + } + } + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +template +void ListContainsOrPosition(DataChunk &args, Vector &result) { + const auto physical_type = args.data[1].GetType().InternalType(); + switch (physical_type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::INT16: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::INT32: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::INT64: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::INT128: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::UINT8: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::UINT16: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::UINT32: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::UINT64: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::FLOAT: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::DOUBLE: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::VARCHAR: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::INTERVAL: + TemplatedContainsOrPosition(args, result); + break; + case PhysicalType::STRUCT: + case PhysicalType::LIST: + TemplatedContainsOrPosition(args, result, true); + break; + default: + throw NotImplementedException("This function has not been implemented for logical type %s", + TypeIdToString(physical_type)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/nested_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/nested_functions.hpp new file mode 100644 index 00000000..8fc34a3c --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/nested_functions.hpp @@ -0,0 +1,111 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/nested_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/function/built_in_functions.hpp" +#include "duckdb/function/scalar/list/contains_or_position.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +struct ListArgFunctor { + static Vector &GetList(Vector &list) { + return list; + } + static idx_t GetListSize(Vector &list) { + return ListVector::GetListSize(list); + } + static Vector &GetEntry(Vector &list) { + return ListVector::GetEntry(list); + } +}; + +struct ContainsFunctor { + static inline bool Initialize() { + return false; + } + static inline bool UpdateResultEntries(idx_t child_idx) { + return true; + } +}; + +struct PositionFunctor { + static inline int32_t Initialize() { + return 0; + } + static inline int32_t UpdateResultEntries(idx_t child_idx) { + return child_idx + 1; + } +}; + +struct VariableReturnBindData : public FunctionData { + LogicalType stype; + + explicit VariableReturnBindData(LogicalType stype_p) : stype(std::move(stype_p)) { + } + + unique_ptr Copy() const override { + return make_uniq(stype); + } + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return stype == other.stype; + } + static void Serialize(Serializer &serializer, const optional_ptr bind_data, + const ScalarFunction &function) { + auto &info = bind_data->Cast(); + serializer.WriteProperty(100, "variable_return_type", info.stype); + } + + static unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &bound_function) { + auto stype = deserializer.ReadProperty(100, "variable_return_type"); + return make_uniq(std::move(stype)); + } +}; + +template > +struct HistogramAggState { + MAP_TYPE *hist; +}; + +struct ListExtractFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ListConcatFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ListContainsFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ListPositionFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ListResizeFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct StructExtractFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/operators.hpp b/src/duckdb/src/include/duckdb/function/scalar/operators.hpp new file mode 100644 index 00000000..90689a33 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/operators.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/operators.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +struct AddFun { + static ScalarFunction GetFunction(const LogicalType &type); + static ScalarFunction GetFunction(const LogicalType &left_type, const LogicalType &right_type); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct SubtractFun { + static ScalarFunction GetFunction(const LogicalType &type); + static ScalarFunction GetFunction(const LogicalType &left_type, const LogicalType &right_type); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct MultiplyFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DivideFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ModFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp b/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp new file mode 100644 index 00000000..208e033e --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp @@ -0,0 +1,159 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/regexp.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" +#include "re2/re2.h" +#include "duckdb/function/built_in_functions.hpp" +#include "re2/stringpiece.h" + +namespace duckdb { + +namespace regexp_util { + +bool TryParseConstantPattern(ClientContext &context, Expression &expr, string &constant_string); +void ParseRegexOptions(const string &options, duckdb_re2::RE2::Options &result, bool *global_replace = nullptr); +void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &target, bool *global_replace = nullptr); + +inline duckdb_re2::StringPiece CreateStringPiece(const string_t &input) { + return duckdb_re2::StringPiece(input.GetData(), input.GetSize()); +} + +inline string_t Extract(const string_t &input, Vector &result, const RE2 &re, const duckdb_re2::StringPiece &rewrite) { + string extracted; + RE2::Extract(input.GetString(), re, rewrite, &extracted); + return StringVector::AddString(result, extracted.c_str(), extracted.size()); +} + +} // namespace regexp_util + +struct RegexpExtractAll { + static void Execute(DataChunk &args, ExpressionState &state, Vector &result); + static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); + static unique_ptr InitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data); +}; + +struct RegexpBaseBindData : public FunctionData { + RegexpBaseBindData(); + RegexpBaseBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern = true); + virtual ~RegexpBaseBindData(); + + duckdb_re2::RE2::Options options; + string constant_string; + bool constant_pattern; + + virtual bool Equals(const FunctionData &other_p) const override; +}; + +struct RegexpMatchesBindData : public RegexpBaseBindData { + RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern); + RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern, + string range_min, string range_max, bool range_success); + + string range_min; + string range_max; + bool range_success; + + unique_ptr Copy() const override; +}; + +struct RegexpReplaceBindData : public RegexpBaseBindData { + RegexpReplaceBindData(); + RegexpReplaceBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern, + bool global_replace); + + bool global_replace; + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other_p) const override; +}; + +struct RegexpExtractBindData : public RegexpBaseBindData { + RegexpExtractBindData(); + RegexpExtractBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern, + string group_string); + + string group_string; + duckdb_re2::StringPiece rewrite; + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other_p) const override; +}; + +struct RegexStringPieceArgs { + RegexStringPieceArgs() : size(0), capacity(0), group_buffer(nullptr) { + } + void Init(idx_t size) { + this->size = size; + // Allocate for one extra, for the all-encompassing match group + this->capacity = size + 1; + group_buffer = AllocateArray(capacity); + } + void SetSize(idx_t size) { + this->size = size; + if (size + 1 > capacity) { + Clear(); + Init(size); + } + } + + RegexStringPieceArgs &operator=(RegexStringPieceArgs &&other) { + std::swap(this->size, other.size); + std::swap(this->capacity, other.capacity); + std::swap(this->group_buffer, other.group_buffer); + return *this; + } + + ~RegexStringPieceArgs() { + Clear(); + } + +private: + void Clear() { + DeleteArray(group_buffer, capacity); + group_buffer = nullptr; + + size = 0; + capacity = 0; + } + +public: + idx_t size; + //! The currently allocated capacity for the groups + idx_t capacity; + //! Used by ExtractAll to pre-allocate the storage for the groups + duckdb_re2::StringPiece *group_buffer; +}; + +struct RegexLocalState : public FunctionLocalState { + explicit RegexLocalState(RegexpBaseBindData &info, bool extract_all = false) + : constant_pattern(duckdb_re2::StringPiece(info.constant_string.c_str(), info.constant_string.size()), + info.options) { + if (extract_all) { + auto group_count_p = constant_pattern.NumberOfCapturingGroups(); + if (group_count_p != -1) { + group_buffer.Init(group_count_p); + } + } + D_ASSERT(info.constant_pattern); + } + + RE2 constant_pattern; + //! Used by regexp_extract_all to pre-allocate the args + RegexStringPieceArgs group_buffer; +}; + +unique_ptr RegexInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data); +unique_ptr RegexpMatchesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/sequence_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/sequence_functions.hpp new file mode 100644 index 00000000..ee21a651 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/sequence_functions.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/sequence_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +struct NextvalFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct CurrvalFun { + static void RegisterFunction(BuiltinFunctions &set); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/strftime_format.hpp b/src/duckdb/src/include/duckdb/function/scalar/strftime_format.hpp new file mode 100644 index 00000000..5f96d03e --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/strftime_format.hpp @@ -0,0 +1,168 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/strftime.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/common/vector.hpp" + +#include + +namespace duckdb { + +enum class StrTimeSpecifier : uint8_t { + ABBREVIATED_WEEKDAY_NAME = 0, // %a - Abbreviated weekday name. (Sun, Mon, ...) + FULL_WEEKDAY_NAME = 1, // %A Full weekday name. (Sunday, Monday, ...) + WEEKDAY_DECIMAL = 2, // %w - Weekday as a decimal number. (0, 1, ..., 6) + DAY_OF_MONTH_PADDED = 3, // %d - Day of the month as a zero-padded decimal. (01, 02, ..., 31) + DAY_OF_MONTH = 4, // %-d - Day of the month as a decimal number. (1, 2, ..., 30) + ABBREVIATED_MONTH_NAME = 5, // %b - Abbreviated month name. (Jan, Feb, ..., Dec) + FULL_MONTH_NAME = 6, // %B - Full month name. (January, February, ...) + MONTH_DECIMAL_PADDED = 7, // %m - Month as a zero-padded decimal number. (01, 02, ..., 12) + MONTH_DECIMAL = 8, // %-m - Month as a decimal number. (1, 2, ..., 12) + YEAR_WITHOUT_CENTURY_PADDED = 9, // %y - Year without century as a zero-padded decimal number. (00, 01, ..., 99) + YEAR_WITHOUT_CENTURY = 10, // %-y - Year without century as a decimal number. (0, 1, ..., 99) + YEAR_DECIMAL = 11, // %Y - Year with century as a decimal number. (2013, 2019 etc.) + HOUR_24_PADDED = 12, // %H - Hour (24-hour clock) as a zero-padded decimal number. (00, 01, ..., 23) + HOUR_24_DECIMAL = 13, // %-H - Hour (24-hour clock) as a decimal number. (0, 1, ..., 23) + HOUR_12_PADDED = 14, // %I - Hour (12-hour clock) as a zero-padded decimal number. (01, 02, ..., 12) + HOUR_12_DECIMAL = 15, // %-I - Hour (12-hour clock) as a decimal number. (1, 2, ... 12) + AM_PM = 16, // %p - Locale’s AM or PM. (AM, PM) + MINUTE_PADDED = 17, // %M - Minute as a zero-padded decimal number. (00, 01, ..., 59) + MINUTE_DECIMAL = 18, // %-M - Minute as a decimal number. (0, 1, ..., 59) + SECOND_PADDED = 19, // %S - Second as a zero-padded decimal number. (00, 01, ..., 59) + SECOND_DECIMAL = 20, // %-S - Second as a decimal number. (0, 1, ..., 59) + MICROSECOND_PADDED = 21, // %f - Microsecond as a decimal number, zero-padded on the left. (000000 - 999999) + MILLISECOND_PADDED = 22, // %g - Millisecond as a decimal number, zero-padded on the left. (000 - 999) + UTC_OFFSET = 23, // %z - UTC offset in the form +HHMM or -HHMM. ( ) + TZ_NAME = 24, // %Z - Time zone name. ( ) + DAY_OF_YEAR_PADDED = 25, // %j - Day of the year as a zero-padded decimal number. (001, 002, ..., 366) + DAY_OF_YEAR_DECIMAL = 26, // %-j - Day of the year as a decimal number. (1, 2, ..., 366) + WEEK_NUMBER_PADDED_SUN_FIRST = + 27, // %U - Week number of the year (Sunday as the first day of the week). All days in a new year preceding the + // first Sunday are considered to be in week 0. (00, 01, ..., 53) + WEEK_NUMBER_PADDED_MON_FIRST = + 28, // %W - Week number of the year (Monday as the first day of the week). All days in a new year preceding the + // first Monday are considered to be in week 0. (00, 01, ..., 53) + LOCALE_APPROPRIATE_DATE_AND_TIME = + 29, // %c - Locale’s appropriate date and time representation. (Mon Sep 30 07:06:05 2013) + LOCALE_APPROPRIATE_DATE = 30, // %x - Locale’s appropriate date representation. (09/30/13) + LOCALE_APPROPRIATE_TIME = 31, // %X - Locale’s appropriate time representation. (07:06:05) + NANOSECOND_PADDED = 32 // %n - Nanosecond as a decimal number, zero-padded on the left. (000000000 - 999999999) +}; + +struct StrTimeFormat { +public: + virtual ~StrTimeFormat() { + } + + DUCKDB_API static string ParseFormatSpecifier(const string &format_string, StrTimeFormat &format); + + inline bool HasFormatSpecifier(StrTimeSpecifier s) const { + return std::find(specifiers.begin(), specifiers.end(), s) != specifiers.end(); + } + + //! The full format specifier, for error messages + string format_specifier; + +protected: + //! The format specifiers + vector specifiers; + //! The literals that appear in between the format specifiers + //! The following must hold: literals.size() = specifiers.size() + 1 + //! Format is literals[0], specifiers[0], literals[1], ..., specifiers[n - 1], literals[n] + vector literals; + //! The constant size that appears in the format string + idx_t constant_size = 0; + //! The max numeric width of the specifier (if it is parsed as a number), or -1 if it is not a number + vector numeric_width; + +protected: + void AddLiteral(string literal); + DUCKDB_API virtual void AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier); +}; + +struct StrfTimeFormat : public StrTimeFormat { + DUCKDB_API idx_t GetLength(date_t date, dtime_t time, int32_t utc_offset, const char *tz_name); + + DUCKDB_API void FormatString(date_t date, int32_t data[8], const char *tz_name, char *target); + void FormatString(date_t date, dtime_t time, char *target); + + DUCKDB_API static string Format(timestamp_t timestamp, const string &format); + + DUCKDB_API void ConvertDateVector(Vector &input, Vector &result, idx_t count); + DUCKDB_API void ConvertTimestampVector(Vector &input, Vector &result, idx_t count); + +protected: + //! The variable-length specifiers. To determine total string size, these need to be checked. + vector var_length_specifiers; + //! Whether or not the current specifier is a special "date" specifier (i.e. one that requires a date_t object to + //! generate) + vector is_date_specifier; + +protected: + DUCKDB_API void AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) override; + static idx_t GetSpecifierLength(StrTimeSpecifier specifier, date_t date, dtime_t time, int32_t utc_offset, + const char *tz_name); + char *WriteString(char *target, const string_t &str); + char *Write2(char *target, uint8_t value); + char *WritePadded2(char *target, uint32_t value); + char *WritePadded3(char *target, uint32_t value); + char *WritePadded(char *target, uint32_t value, size_t padding); + bool IsDateSpecifier(StrTimeSpecifier specifier); + char *WriteDateSpecifier(StrTimeSpecifier specifier, date_t date, char *target); + char *WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t data[], const char *tz_name, size_t tz_len, + char *target); +}; + +struct StrpTimeFormat : public StrTimeFormat { +public: + StrpTimeFormat(); + + //! Type-safe parsing argument + struct ParseResult { + int32_t data[8]; // year, month, day, hour, min, sec, µs, offset + string tz; + string error_message; + idx_t error_position = DConstants::INVALID_INDEX; + + date_t ToDate(); + timestamp_t ToTimestamp(); + + bool TryToDate(date_t &result); + bool TryToTimestamp(timestamp_t &result); + + DUCKDB_API string FormatError(string_t input, const string &format_specifier); + }; + +public: + DUCKDB_API static ParseResult Parse(const string &format, const string &text); + + DUCKDB_API bool Parse(string_t str, ParseResult &result) const; + + DUCKDB_API bool TryParseDate(string_t str, date_t &result, string &error_message) const; + DUCKDB_API bool TryParseTimestamp(string_t str, timestamp_t &result, string &error_message) const; + + date_t ParseDate(string_t str); + timestamp_t ParseTimestamp(string_t str); + + void Serialize(Serializer &serializer) const; + static StrpTimeFormat Deserialize(Deserializer &deserializer); + +protected: + static string FormatStrpTimeError(const string &input, idx_t position); + DUCKDB_API void AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) override; + int NumericSpecifierWidth(StrTimeSpecifier specifier); + int32_t TryParseCollection(const char *data, idx_t &pos, idx_t size, const string_t collection[], + idx_t collection_count) const; + +private: + explicit StrpTimeFormat(const string &format_string); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/string_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/string_functions.hpp new file mode 100644 index 00000000..293dcec9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/string_functions.hpp @@ -0,0 +1,130 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar/string_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" +#include "utf8proc.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace re2 { +class RE2; +} + +namespace duckdb { + +struct LowerFun { + static uint8_t ascii_to_lower_map[]; + + //! Returns the length of the result string obtained from lowercasing the given input (in bytes) + static idx_t LowerLength(const char *input_data, idx_t input_length); + //! Lowercases the string to the target output location, result_data must have space for at least LowerLength bytes + static void LowerCase(const char *input_data, idx_t input_length, char *result_data); + + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct UpperFun { + static uint8_t ascii_to_upper_map[]; + + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct StripAccentsFun { + static bool IsAscii(const char *input, idx_t n); + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ConcatFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct LengthFun { + static void RegisterFunction(BuiltinFunctions &set); + static inline bool IsCharacter(char c) { + return (c & 0xc0) != 0x80; + } + + template + static inline TR Length(TA input) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + TR length = 0; + for (idx_t i = 0; i < input_length; i++) { + length += IsCharacter(input_data[i]); + } + return length; + } + + template + static inline TR GraphemeCount(TA input) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + for (idx_t i = 0; i < input_length; i++) { + if (input_data[i] & 0x80) { + int64_t length = 0; + // non-ascii character: use grapheme iterator on remainder of string + utf8proc_grapheme_callback(input_data, input_length, [&](size_t start, size_t end) { + length++; + return true; + }); + return length; + } + } + return input_length; + } +}; + +struct LikeFun { + static ScalarFunction GetLikeFunction(); + static void RegisterFunction(BuiltinFunctions &set); + DUCKDB_API static bool Glob(const char *s, idx_t slen, const char *pattern, idx_t plen, + bool allow_question_mark = true); +}; + +struct LikeEscapeFun { + static ScalarFunction GetLikeEscapeFun(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct NFCNormalizeFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct SubstringFun { + static void RegisterFunction(BuiltinFunctions &set); + static string_t SubstringUnicode(Vector &result, string_t input, int64_t offset, int64_t length); + static string_t SubstringGrapheme(Vector &result, string_t input, int64_t offset, int64_t length); +}; + +struct PrefixFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct SuffixFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ContainsFun { + static ScalarFunction GetFunction(); + static void RegisterFunction(BuiltinFunctions &set); + static idx_t Find(const string_t &haystack, const string_t &needle); + static idx_t Find(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, + idx_t needle_size); +}; + +struct RegexpFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar_function.hpp b/src/duckdb/src/include/duckdb/function/scalar_function.hpp new file mode 100644 index 00000000..79b5a87d --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar_function.hpp @@ -0,0 +1,216 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { + +struct FunctionLocalState { + DUCKDB_API virtual ~FunctionLocalState(); + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +class Binder; +class BoundFunctionExpression; +class DependencyList; +class ScalarFunctionCatalogEntry; + +struct FunctionStatisticsInput { + FunctionStatisticsInput(BoundFunctionExpression &expr_p, optional_ptr bind_data_p, + vector &child_stats_p, unique_ptr *expr_ptr_p) + : expr(expr_p), bind_data(bind_data_p), child_stats(child_stats_p), expr_ptr(expr_ptr_p) { + } + + BoundFunctionExpression &expr; + optional_ptr bind_data; + vector &child_stats; + unique_ptr *expr_ptr; +}; + +//! The type used for scalar functions +typedef std::function scalar_function_t; +//! Binds the scalar function and creates the function data +typedef unique_ptr (*bind_scalar_function_t)(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); +typedef unique_ptr (*init_local_state_t)(ExpressionState &state, + const BoundFunctionExpression &expr, + FunctionData *bind_data); +typedef unique_ptr (*function_statistics_t)(ClientContext &context, FunctionStatisticsInput &input); +//! Adds the dependencies of this BoundFunctionExpression to the set of dependencies +typedef void (*dependency_function_t)(BoundFunctionExpression &expr, DependencyList &dependencies); + +typedef void (*function_serialize_t)(Serializer &serializer, const optional_ptr bind_data, + const ScalarFunction &function); +typedef unique_ptr (*function_deserialize_t)(Deserializer &deserializer, ScalarFunction &function); + +class ScalarFunction : public BaseScalarFunction { +public: + DUCKDB_API ScalarFunction(string name, vector arguments, LogicalType return_type, + scalar_function_t function, bind_scalar_function_t bind = nullptr, + dependency_function_t dependency = nullptr, function_statistics_t statistics = nullptr, + init_local_state_t init_local_state = nullptr, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID), + FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, + FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); + + DUCKDB_API ScalarFunction(vector arguments, LogicalType return_type, scalar_function_t function, + bind_scalar_function_t bind = nullptr, dependency_function_t dependency = nullptr, + function_statistics_t statistics = nullptr, init_local_state_t init_local_state = nullptr, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID), + FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, + FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); + + //! The main scalar function to execute + scalar_function_t function; + //! The bind function (if any) + bind_scalar_function_t bind; + //! Init thread local state for the function (if any) + init_local_state_t init_local_state; + //! The dependency function (if any) + dependency_function_t dependency; + //! The statistics propagation function (if any) + function_statistics_t statistics; + + function_serialize_t serialize; + function_deserialize_t deserialize; + + DUCKDB_API bool operator==(const ScalarFunction &rhs) const; + DUCKDB_API bool operator!=(const ScalarFunction &rhs) const; + + DUCKDB_API bool Equal(const ScalarFunction &rhs) const; + +public: + DUCKDB_API static void NopFunction(DataChunk &input, ExpressionState &state, Vector &result); + + template + static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() >= 1); + UnaryExecutor::Execute(input.data[0], result, input.size()); + } + + template + static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 2); + BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, input.size()); + } + + template + static void TernaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 3); + TernaryExecutor::ExecuteStandard(input.data[0], input.data[1], input.data[2], result, + input.size()); + } + +public: + template + static scalar_function_t GetScalarUnaryFunction(LogicalType type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UTINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::USMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UINTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UBIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::FLOAT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + function = &ScalarFunction::UnaryFunction; + break; + default: + throw InternalException("Unimplemented type for GetScalarUnaryFunction"); + } + return function; + } + + template + static scalar_function_t GetScalarUnaryFunctionFixedReturn(LogicalType type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UTINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::USMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UINTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UBIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::FLOAT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + function = &ScalarFunction::UnaryFunction; + break; + default: + throw InternalException("Unimplemented type for GetScalarUnaryFunctionFixedReturn"); + } + return function; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar_macro_function.hpp b/src/duckdb/src/include/duckdb/function/scalar_macro_function.hpp new file mode 100644 index 00000000..5e678162 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar_macro_function.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/scalar_macro_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once +//! The SelectStatement of the view +#include "duckdb/function/macro_function.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" + +namespace duckdb { + +class ScalarMacroFunction : public MacroFunction { +public: + static constexpr const MacroType TYPE = MacroType::SCALAR_MACRO; + +public: + explicit ScalarMacroFunction(unique_ptr expression); + ScalarMacroFunction(void); + + //! The macro expression + unique_ptr expression; + +public: + unique_ptr Copy() const override; + + string ToSQL(const string &schema, const string &name) const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/arrow.hpp b/src/duckdb/src/include/duckdb/function/table/arrow.hpp new file mode 100644 index 00000000..fae51810 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/arrow.hpp @@ -0,0 +1,155 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table/arrow.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/table_function.hpp" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/thread.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/function/built_in_functions.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" + +namespace duckdb { + +struct ArrowInterval { + int32_t months; + int32_t days; + int64_t nanoseconds; + + inline bool operator==(const ArrowInterval &rhs) const { + return this->days == rhs.days && this->months == rhs.months && this->nanoseconds == rhs.nanoseconds; + } +}; + +struct ArrowProjectedColumns { + unordered_map projection_map; + vector columns; +}; + +struct ArrowStreamParameters { + ArrowProjectedColumns projected_columns; + TableFilterSet *filters; +}; + +typedef unique_ptr (*stream_factory_produce_t)(uintptr_t stream_factory_ptr, + ArrowStreamParameters ¶meters); +typedef void (*stream_factory_get_schema_t)(uintptr_t stream_factory_ptr, ArrowSchemaWrapper &schema); + +struct ArrowScanFunctionData : public PyTableFunctionData { +public: + ArrowScanFunctionData(stream_factory_produce_t scanner_producer_p, uintptr_t stream_factory_ptr_p) + : lines_read(0), stream_factory_ptr(stream_factory_ptr_p), scanner_producer(scanner_producer_p) { + } + vector all_types; + atomic lines_read; + ArrowSchemaWrapper schema_root; + idx_t rows_per_thread; + //! Pointer to the scanner factory + uintptr_t stream_factory_ptr; + //! Pointer to the scanner factory produce + stream_factory_produce_t scanner_producer; + //! Arrow table data + ArrowTableType arrow_table; +}; + +struct ArrowScanLocalState : public LocalTableFunctionState { + explicit ArrowScanLocalState(unique_ptr current_chunk) : chunk(current_chunk.release()) { + } + + unique_ptr stream; + shared_ptr chunk; + // This vector hold the Arrow Vectors owned by DuckDB to allow for zero-copy + // Note that only DuckDB can release these vectors + unordered_map> arrow_owned_data; + idx_t chunk_offset = 0; + idx_t batch_index = 0; + vector column_ids; + //! Store child vectors for Arrow Dictionary Vectors (col-idx,vector) + unordered_map> arrow_dictionary_vectors; + TableFilterSet *filters = nullptr; + //! The DataChunk containing all read columns (even filter columns that are immediately removed) + DataChunk all_columns; +}; + +struct ArrowScanGlobalState : public GlobalTableFunctionState { + unique_ptr stream; + mutex main_mutex; + idx_t max_threads = 1; + idx_t batch_index = 0; + bool done = false; + + vector projection_ids; + vector scanned_types; + + idx_t MaxThreads() const override { + return max_threads; + } + + bool CanRemoveFilterColumns() const { + return !projection_ids.empty(); + } +}; + +struct ArrowTableFunction { +public: + static void RegisterFunction(BuiltinFunctions &set); + +public: + //! Binds an arrow table + static unique_ptr ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names); + //! Actual conversion from Arrow to DuckDB + static void ArrowToDuckDB(ArrowScanLocalState &scan_state, const arrow_column_map_t &arrow_convert_data, + DataChunk &output, idx_t start, bool arrow_scan_is_projected = true); + + //! Get next scan state + static bool ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, + ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state); + + //! Initialize Global State + static unique_ptr ArrowScanInitGlobal(ClientContext &context, + TableFunctionInitInput &input); + + //! Initialize Local State + static unique_ptr ArrowScanInitLocalInternal(ClientContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state); + static unique_ptr ArrowScanInitLocal(ExecutionContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state); + + //! Scan Function + static void ArrowScanFunction(ClientContext &context, TableFunctionInput &data, DataChunk &output); + static void PopulateArrowTableType(ArrowTableType &arrow_table, ArrowSchemaWrapper &schema_p, vector &names, + vector &return_types); + +protected: + //! Defines Maximum Number of Threads + static idx_t ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data); + + //! Allows parallel Create Table / Insertion + static idx_t ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, + LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state); + + //! -----Utility Functions:----- + //! Gets Arrow Table's Cardinality + static unique_ptr ArrowScanCardinality(ClientContext &context, const FunctionData *bind_data); + //! Gets the progress on the table scan, used for Progress Bars + static double ArrowProgress(ClientContext &context, const FunctionData *bind_data, + const GlobalTableFunctionState *global_state); + //! Renames repeated columns and case sensitive columns + static void RenameArrowColumns(vector &names); + //! Helper function to get the DuckDB logical type + static unique_ptr GetArrowLogicalType(ArrowSchema &schema); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp new file mode 100644 index 00000000..bd15f89d --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table/arrow_duck_schema.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/unique_ptr.hpp" + +namespace duckdb { +//===--------------------------------------------------------------------===// +// Arrow Variable Size Types +//===--------------------------------------------------------------------===// +enum class ArrowVariableSizeType : uint8_t { FIXED_SIZE = 0, NORMAL = 1, SUPER_SIZE = 2 }; + +//===--------------------------------------------------------------------===// +// Arrow Time/Date Types +//===--------------------------------------------------------------------===// +enum class ArrowDateTimeType : uint8_t { + MILLISECONDS = 0, + MICROSECONDS = 1, + NANOSECONDS = 2, + SECONDS = 3, + DAYS = 4, + MONTHS = 5, + MONTH_DAY_NANO = 6 +}; + +class ArrowType { +public: + //! From a DuckDB type + ArrowType(LogicalType type_p) + : type(std::move(type_p)), size_type(ArrowVariableSizeType::NORMAL), + date_time_precision(ArrowDateTimeType::DAYS) {}; + + //! From a DuckDB type + fixed_size + ArrowType(LogicalType type_p, idx_t fixed_size_p) + : type(std::move(type_p)), size_type(ArrowVariableSizeType::FIXED_SIZE), + date_time_precision(ArrowDateTimeType::DAYS), fixed_size(fixed_size_p) {}; + + //! From a DuckDB type + variable size type + ArrowType(LogicalType type_p, ArrowVariableSizeType size_type_p) + : type(std::move(type_p)), size_type(size_type_p), date_time_precision(ArrowDateTimeType::DAYS) {}; + + //! From a DuckDB type + datetime type + ArrowType(LogicalType type_p, ArrowDateTimeType date_time_precision_p) + : type(std::move(type_p)), size_type(ArrowVariableSizeType::NORMAL), + date_time_precision(date_time_precision_p) {}; + + void AddChild(unique_ptr child); + + void AssignChildren(vector> children); + + const LogicalType &GetDuckType() const; + + ArrowVariableSizeType GetSizeType() const; + + idx_t FixedSize() const; + + void SetDictionary(unique_ptr dictionary); + + ArrowDateTimeType GetDateTimeType() const; + + const ArrowType &GetDictionary() const; + + const ArrowType &operator[](idx_t index) const; + +private: + LogicalType type; + //! If we have a nested type, their children's type. + vector> children; + //! If its a variable size type (e.g., strings, blobs, lists) holds which type it is + ArrowVariableSizeType size_type; + //! If this is a date/time holds its precision + ArrowDateTimeType date_time_precision; + //! Only for size types with fixed size + idx_t fixed_size = 0; + //! Hold the optional type if the array is a dictionary + unique_ptr dictionary_type; +}; + +using arrow_column_map_t = unordered_map>; + +struct ArrowTableType { +public: + void AddColumn(idx_t index, unique_ptr type); + const arrow_column_map_t &GetColumns() const; + +private: + arrow_column_map_t arrow_convert_data; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/list.hpp b/src/duckdb/src/include/duckdb/function/table/list.hpp new file mode 100644 index 00000000..efeade1d --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/list.hpp @@ -0,0 +1,4 @@ +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/function/table/range.hpp" +#include "duckdb/function/table/summary.hpp" diff --git a/src/duckdb/src/include/duckdb/function/table/range.hpp b/src/duckdb/src/include/duckdb/function/table/range.hpp new file mode 100644 index 00000000..d1dfd04a --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/range.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table/range.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/table_function.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +struct CheckpointFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct GlobTableFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct RangeTableFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct RepeatTableFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct RepeatRowTableFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct UnnestTableFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/read_csv.hpp b/src/duckdb/src/include/duckdb/function/table/read_csv.hpp new file mode 100644 index 00000000..c3a22fe6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/read_csv.hpp @@ -0,0 +1,120 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table/read_csv.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp" +#include "duckdb/execution/operator/scan/csv/csv_buffer.hpp" +#include "duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp" +#include "duckdb/execution/operator/scan/csv/csv_file_handle.hpp" +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp" +#include "duckdb/function/built_in_functions.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp" + +namespace duckdb { + +class ReadCSV { +public: + static unique_ptr OpenCSV(const string &file_path, FileCompressionType compression, + ClientContext &context); +}; + +struct BaseCSVData : public TableFunctionData { + virtual ~BaseCSVData() { + } + //! The file path of the CSV file to read or write + vector files; + //! The CSV reader options + CSVReaderOptions options; + //! Offsets for generated columns + idx_t filename_col_idx; + idx_t hive_partition_col_idx; + + void Finalize(); +}; + +struct WriteCSVData : public BaseCSVData { + WriteCSVData(string file_path, vector sql_types, vector names) + : sql_types(std::move(sql_types)) { + files.push_back(std::move(file_path)); + options.name_list = std::move(names); + } + + //! The SQL types to write + vector sql_types; + //! The newline string to write + string newline = "\n"; + //! The size of the CSV file (in bytes) that we buffer before we flush it to disk + idx_t flush_size = 4096 * 8; + //! For each byte whether or not the CSV file requires quotes when containing the byte + unsafe_unique_array requires_quotes; +}; + +struct ColumnInfo { + ColumnInfo() { + } + ColumnInfo(vector names_p, vector types_p) { + names = std::move(names_p); + types = std::move(types_p); + } + void Serialize(Serializer &serializer) const; + static ColumnInfo Deserialize(Deserializer &deserializer); + + vector names; + vector types; +}; + +struct ReadCSVData : public BaseCSVData { + //! The expected SQL types to read from the file + vector csv_types; + //! The expected SQL names to be read from the file + vector csv_names; + //! The expected SQL types to be returned from the read - including added constants (e.g. filename, hive partitions) + vector return_types; + //! The expected SQL names to be returned from the read - including added constants (e.g. filename, hive partitions) + vector return_names; + //! The buffer manager (if any): this is used when automatic detection is used during binding. + //! In this case, some CSV buffers have already been read and can be reused. + shared_ptr buffer_manager; + unique_ptr initial_reader; + //! The union readers are created (when csv union_by_name option is on) during binding + //! Those readers can be re-used during ReadCSVFunction + vector> union_readers; + //! Whether or not the single-threaded reader should be used + bool single_threaded = false; + //! Reader bind data + MultiFileReaderBindData reader_bind; + vector column_info; + //! The CSVStateMachineCache caches state machines created for sniffing and parsing csv files + //! We cache them because when reading very small csv files, the cost of creating all the possible + //! State machines for sniffing becomes a major bottleneck. + CSVStateMachineCache state_machine_cache; + + void Initialize(unique_ptr &reader) { + this->initial_reader = std::move(reader); + } + void FinalizeRead(ClientContext &context); + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +struct CSVCopyFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct ReadCSVTableFunction { + static TableFunction GetFunction(); + static TableFunction GetAutoFunction(); + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/summary.hpp b/src/duckdb/src/include/duckdb/function/table/summary.hpp new file mode 100644 index 00000000..8d4300e1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/summary.hpp @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table/summary.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/table_function.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +struct SummaryTableFunction { + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp new file mode 100644 index 00000000..f6ff5308 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp @@ -0,0 +1,136 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table/system_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/table_function.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { + +struct PragmaCollations { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaTableInfo { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaStorageInfo { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaMetadataInfo { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaLastProfilingOutput { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaDetailedProfilingOutput { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaVersion { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaPlatform { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct PragmaDatabaseSize { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBSchemasFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBColumnsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBConstraintsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBDatabasesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBDependenciesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBExtensionsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBFunctionsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBKeywordsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBIndexesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBSequencesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBSettingsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBTablesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBTemporaryFilesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBTypesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct DuckDBViewsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +struct TestType { + TestType(LogicalType type_p, string name_p) + : type(std::move(type_p)), name(std::move(name_p)), min_value(Value::MinimumValue(type)), + max_value(Value::MaximumValue(type)) { + } + TestType(LogicalType type_p, string name_p, Value min, Value max) + : type(std::move(type_p)), name(std::move(name_p)), min_value(std::move(min)), max_value(std::move(max)) { + } + + LogicalType type; + string name; + Value min_value; + Value max_value; +}; + +struct TestAllTypesFun { + static void RegisterFunction(BuiltinFunctions &set); + static vector GetTestTypes(bool large_enum = false); +}; + +struct TestVectorTypesFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/table_scan.hpp b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp new file mode 100644 index 00000000..0045c466 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table/table_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/table_function.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/function/built_in_functions.hpp" + +namespace duckdb { +class DuckTableEntry; +class TableCatalogEntry; + +struct TableScanBindData : public TableFunctionData { + explicit TableScanBindData(DuckTableEntry &table) : table(table), is_index_scan(false), is_create_index(false) { + } + + //! The table to scan + DuckTableEntry &table; + + //! Whether or not the table scan is an index scan + bool is_index_scan; + //! Whether or not the table scan is for index creation + bool is_create_index; + //! The row ids to fetch (in case of an index scan) + vector result_ids; + +public: + bool Equals(const FunctionData &other_p) const override { + auto &other = (const TableScanBindData &)other_p; + return &other.table == &table && result_ids == other.result_ids; + } +}; + +//! The table scan function represents a sequential scan over one of DuckDB's base tables. +struct TableScanFunction { + static void RegisterFunction(BuiltinFunctions &set); + static TableFunction GetFunction(); + static TableFunction GetIndexScanFunction(); + static optional_ptr GetTableEntry(const TableFunction &function, + const optional_ptr bind_data); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table_function.hpp b/src/duckdb/src/include/duckdb/function/table_function.hpp new file mode 100644 index 00000000..ef537304 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table_function.hpp @@ -0,0 +1,284 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/operator_result_type.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/planner/bind_context.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/storage/statistics/node_statistics.hpp" + +#include + +namespace duckdb { + +class BaseStatistics; +class DependencyList; +class LogicalGet; +class TableFilterSet; + +struct TableFunctionInfo { + DUCKDB_API virtual ~TableFunctionInfo(); + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct GlobalTableFunctionState { +public: + // value returned from MaxThreads when as many threads as possible should be used + constexpr static const int64_t MAX_THREADS = 999999999; + +public: + DUCKDB_API virtual ~GlobalTableFunctionState(); + + virtual idx_t MaxThreads() const { + return 1; + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct LocalTableFunctionState { + DUCKDB_API virtual ~LocalTableFunctionState(); + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct TableFunctionBindInput { + TableFunctionBindInput(vector &inputs, named_parameter_map_t &named_parameters, + vector &input_table_types, vector &input_table_names, + optional_ptr info) + : inputs(inputs), named_parameters(named_parameters), input_table_types(input_table_types), + input_table_names(input_table_names), info(info) { + } + + vector &inputs; + named_parameter_map_t &named_parameters; + vector &input_table_types; + vector &input_table_names; + optional_ptr info; +}; + +struct TableFunctionInitInput { + TableFunctionInitInput(optional_ptr bind_data_p, const vector &column_ids_p, + const vector &projection_ids_p, optional_ptr filters_p) + : bind_data(bind_data_p), column_ids(column_ids_p), projection_ids(projection_ids_p), filters(filters_p) { + } + + optional_ptr bind_data; + const vector &column_ids; + const vector projection_ids; + optional_ptr filters; + + bool CanRemoveFilterColumns() const { + if (projection_ids.empty()) { + // Not set, can't remove filter columns + return false; + } else if (projection_ids.size() == column_ids.size()) { + // Filter column is used in remainder of plan, can't remove + return false; + } else { + // Less columns need to be projected out than that we scan + return true; + } + } +}; + +struct TableFunctionInput { +public: + TableFunctionInput(optional_ptr bind_data_p, + optional_ptr local_state_p, + optional_ptr global_state_p) + : bind_data(bind_data_p), local_state(local_state_p), global_state(global_state_p) { + } + +public: + optional_ptr bind_data; + optional_ptr local_state; + optional_ptr global_state; +}; + +enum ScanType { TABLE, PARQUET }; + +struct BindInfo { +public: + explicit BindInfo(ScanType type_p) : type(type_p) {}; + unordered_map options; + ScanType type; + void InsertOption(const string &name, Value value) { + if (options.find(name) != options.end()) { + throw InternalException("This option already exists"); + } + options[name] = std::move(value); + } + template + T GetOption(const string &name) { + if (options.find(name) == options.end()) { + throw InternalException("This option does not exist"); + } + return options[name].GetValue(); + } + template + vector GetOptionList(const string &name) { + if (options.find(name) == options.end()) { + throw InternalException("This option does not exist"); + } + auto option = options[name]; + if (option.type().id() != LogicalTypeId::LIST) { + throw InternalException("This option is not a list"); + } + vector result; + auto list_children = ListValue::GetChildren(option); + for (auto &child : list_children) { + result.emplace_back(child.GetValue()); + } + return result; + } +}; + +typedef unique_ptr (*table_function_bind_t)(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names); +typedef unique_ptr (*table_function_bind_replace_t)(ClientContext &context, TableFunctionBindInput &input); +typedef unique_ptr (*table_function_init_global_t)(ClientContext &context, + TableFunctionInitInput &input); +typedef unique_ptr (*table_function_init_local_t)(ExecutionContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state); +typedef unique_ptr (*table_statistics_t)(ClientContext &context, const FunctionData *bind_data, + column_t column_index); +typedef void (*table_function_t)(ClientContext &context, TableFunctionInput &data, DataChunk &output); +typedef OperatorResultType (*table_in_out_function_t)(ExecutionContext &context, TableFunctionInput &data, + DataChunk &input, DataChunk &output); +typedef OperatorFinalizeResultType (*table_in_out_function_final_t)(ExecutionContext &context, TableFunctionInput &data, + DataChunk &output); +typedef idx_t (*table_function_get_batch_index_t)(ClientContext &context, const FunctionData *bind_data, + LocalTableFunctionState *local_state, + GlobalTableFunctionState *global_state); + +typedef BindInfo (*table_function_get_bind_info)(const FunctionData *bind_data); + +typedef double (*table_function_progress_t)(ClientContext &context, const FunctionData *bind_data, + const GlobalTableFunctionState *global_state); +typedef void (*table_function_dependency_t)(DependencyList &dependencies, const FunctionData *bind_data); +typedef unique_ptr (*table_function_cardinality_t)(ClientContext &context, + const FunctionData *bind_data); +typedef void (*table_function_pushdown_complex_filter_t)(ClientContext &context, LogicalGet &get, + FunctionData *bind_data, + vector> &filters); +typedef string (*table_function_to_string_t)(const FunctionData *bind_data); + +typedef void (*table_function_serialize_t)(Serializer &serializer, const optional_ptr bind_data, + const TableFunction &function); +typedef unique_ptr (*table_function_deserialize_t)(Deserializer &deserializer, TableFunction &function); + +class TableFunction : public SimpleNamedParameterFunction { +public: + DUCKDB_API + TableFunction(string name, vector arguments, table_function_t function, + table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, + table_function_init_local_t init_local = nullptr); + DUCKDB_API + TableFunction(const vector &arguments, table_function_t function, table_function_bind_t bind = nullptr, + table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); + DUCKDB_API TableFunction(); + + //! Bind function + //! This function is used for determining the return type of a table producing function and returning bind data + //! The returned FunctionData object should be constant and should not be changed during execution. + table_function_bind_t bind; + //! (Optional) Bind replace function + //! This function is called before the regular bind function. It allows returning a TableRef will be used to + //! to generate a logical plan that replaces the LogicalGet of a regularly bound TableFunction. The BindReplace can + //! also return a nullptr to indicate a regular bind needs to be performed instead. + table_function_bind_replace_t bind_replace; + //! (Optional) global init function + //! Initialize the global operator state of the function. + //! The global operator state is used to keep track of the progress in the table function and is shared between + //! all threads working on the table function. + table_function_init_global_t init_global; + //! (Optional) local init function + //! Initialize the local operator state of the function. + //! The local operator state is used to keep track of the progress in the table function and is thread-local. + table_function_init_local_t init_local; + //! The main function + table_function_t function; + //! The table in-out function (if this is an in-out function) + table_in_out_function_t in_out_function; + //! The table in-out final function (if this is an in-out function) + table_in_out_function_final_t in_out_function_final; + //! (Optional) statistics function + //! Returns the statistics of a specified column + table_statistics_t statistics; + //! (Optional) dependency function + //! Sets up which catalog entries this table function depend on + table_function_dependency_t dependency; + //! (Optional) cardinality function + //! Returns the expected cardinality of this scan + table_function_cardinality_t cardinality; + //! (Optional) pushdown a set of arbitrary filter expressions, rather than only simple comparisons with a constant + //! Any functions remaining in the expression list will be pushed as a regular filter after the scan + table_function_pushdown_complex_filter_t pushdown_complex_filter; + //! (Optional) function for rendering the operator to a string in profiling output + table_function_to_string_t to_string; + //! (Optional) return how much of the table we have scanned up to this point (% of the data) + table_function_progress_t table_scan_progress; + //! (Optional) returns the current batch index of the current scan operator + table_function_get_batch_index_t get_batch_index; + //! (Optional) returns the extra batch info, currently only used for the substrait extension + table_function_get_bind_info get_batch_info; + + table_function_serialize_t serialize; + table_function_deserialize_t deserialize; + bool verify_serialization = true; + + //! Whether or not the table function supports projection pushdown. If not supported a projection will be added + //! that filters out unused columns. + bool projection_pushdown; + //! Whether or not the table function supports filter pushdown. If not supported a filter will be added + //! that applies the table filter directly. + bool filter_pushdown; + //! Whether or not the table function can immediately prune out filter columns that are unused in the remainder of + //! the query plan, e.g., "SELECT i FROM tbl WHERE j = 42;" - j does not need to leave the table function at all + bool filter_prune; + //! Additional function info, passed to the bind + shared_ptr function_info; + + DUCKDB_API bool Equal(const TableFunction &rhs) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table_macro_function.hpp b/src/duckdb/src/include/duckdb/function/table_macro_function.hpp new file mode 100644 index 00000000..8ac33d1e --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/table_macro_function.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/table_macro_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/macro_function.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" + +namespace duckdb { + +class TableMacroFunction : public MacroFunction { +public: + static constexpr const MacroType TYPE = MacroType::TABLE_MACRO; + +public: + explicit TableMacroFunction(unique_ptr query_node); + TableMacroFunction(void); + + //! The main query node + unique_ptr query_node; + +public: + unique_ptr Copy() const override; + + string ToSQL(const string &schema, const string &name) const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/udf_function.hpp b/src/duckdb/src/include/duckdb/function/udf_function.hpp new file mode 100644 index 00000000..65569293 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/udf_function.hpp @@ -0,0 +1,379 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/function/udf_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/aggregate_function.hpp" + +namespace duckdb { + +struct UDFWrapper { +public: + template + inline static scalar_function_t CreateScalarFunction(const string &name, TR (*udf_func)(Args...)) { + const std::size_t num_template_argc = sizeof...(Args); + switch (num_template_argc) { + case 1: + return CreateUnaryFunction(name, udf_func); + case 2: + return CreateBinaryFunction(name, udf_func); + case 3: + return CreateTernaryFunction(name, udf_func); + default: // LCOV_EXCL_START + throw std::runtime_error("UDF function only supported until ternary!"); + } // LCOV_EXCL_STOP + } + + template + inline static scalar_function_t CreateScalarFunction(const string &name, vector args, + LogicalType ret_type, TR (*udf_func)(Args...)) { + if (!TypesMatch(ret_type)) { // LCOV_EXCL_START + throw std::runtime_error("Return type doesn't match with the first template type."); + } // LCOV_EXCL_STOP + + const std::size_t num_template_types = sizeof...(Args); + if (num_template_types != args.size()) { // LCOV_EXCL_START + throw std::runtime_error( + "The number of templated types should be the same quantity of the LogicalType arguments."); + } // LCOV_EXCL_STOP + + switch (num_template_types) { + case 1: + return CreateUnaryFunction(name, args, ret_type, udf_func); + case 2: + return CreateBinaryFunction(name, args, ret_type, udf_func); + case 3: + return CreateTernaryFunction(name, args, ret_type, udf_func); + default: // LCOV_EXCL_START + throw std::runtime_error("UDF function only supported until ternary!"); + } // LCOV_EXCL_STOP + } + + template + inline static void RegisterFunction(const string &name, scalar_function_t udf_function, ClientContext &context, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID)) { + vector arguments; + GetArgumentTypesRecursive(arguments); + + LogicalType ret_type = GetArgumentType(); + + RegisterFunction(name, arguments, ret_type, udf_function, context, varargs); + } + + static void RegisterFunction(string name, vector args, LogicalType ret_type, + scalar_function_t udf_function, ClientContext &context, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); + + //--------------------------------- Aggregate UDFs ------------------------------------// + template + inline static AggregateFunction CreateAggregateFunction(const string &name) { + return CreateUnaryAggregateFunction(name); + } + + template + inline static AggregateFunction CreateAggregateFunction(const string &name) { + return CreateBinaryAggregateFunction(name); + } + + template + inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type, + LogicalType input_type) { + if (!TypesMatch(ret_type)) { // LCOV_EXCL_START + throw std::runtime_error("The return argument don't match!"); + } // LCOV_EXCL_STOP + + if (!TypesMatch(input_type)) { // LCOV_EXCL_START + throw std::runtime_error("The input argument don't match!"); + } // LCOV_EXCL_STOP + + return CreateUnaryAggregateFunction(name, ret_type, input_type); + } + + template + inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type, + LogicalType input_typeA, LogicalType input_typeB) { + if (!TypesMatch(ret_type)) { // LCOV_EXCL_START + throw std::runtime_error("The return argument don't match!"); + } + + if (!TypesMatch(input_typeA)) { + throw std::runtime_error("The first input argument don't match!"); + } + + if (!TypesMatch(input_typeB)) { + throw std::runtime_error("The second input argument don't match!"); + } // LCOV_EXCL_STOP + + return CreateBinaryAggregateFunction(name, ret_type, input_typeA, input_typeB); + } + + //! A generic CreateAggregateFunction ---------------------------------------------------------------------------// + inline static AggregateFunction + CreateAggregateFunction(string name, vector arguments, LogicalType return_type, + aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, + aggregate_combine_t combine, aggregate_finalize_t finalize, + aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, + aggregate_destructor_t destructor = nullptr) { + + AggregateFunction aggr_function(std::move(name), std::move(arguments), std::move(return_type), state_size, + initialize, update, combine, finalize, simple_update, bind, destructor); + aggr_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return aggr_function; + } + + static void RegisterAggrFunction(AggregateFunction aggr_function, ClientContext &context, + LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); + +private: + //-------------------------------- Templated functions --------------------------------// + struct UnaryUDFExecutor { + template + static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + typedef RESULT_TYPE (*unary_function_t)(INPUT_TYPE); + auto udf = (unary_function_t)dataptr; + return udf(input); + } + }; + + template + inline static scalar_function_t CreateUnaryFunction(const string &name, TR (*udf_func)(TA)) { + scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { + UnaryExecutor::GenericExecute(input.data[0], result, input.size(), + (void *)udf_func); + }; + return udf_function; + } + + template + inline static scalar_function_t CreateBinaryFunction(const string &name, TR (*udf_func)(TA, TB)) { + scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { + BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size(), udf_func); + }; + return udf_function; + } + + template + inline static scalar_function_t CreateTernaryFunction(const string &name, TR (*udf_func)(TA, TB, TC)) { + scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { + TernaryExecutor::Execute(input.data[0], input.data[1], input.data[2], result, input.size(), + udf_func); + }; + return udf_function; + } + + template + inline static scalar_function_t CreateUnaryFunction(const string &name, + TR (*udf_func)(Args...)) { // LCOV_EXCL_START + throw std::runtime_error("Incorrect number of arguments for unary function"); + } // LCOV_EXCL_STOP + + template + inline static scalar_function_t CreateBinaryFunction(const string &name, + TR (*udf_func)(Args...)) { // LCOV_EXCL_START + throw std::runtime_error("Incorrect number of arguments for binary function"); + } // LCOV_EXCL_STOP + + template + inline static scalar_function_t CreateTernaryFunction(const string &name, + TR (*udf_func)(Args...)) { // LCOV_EXCL_START + throw std::runtime_error("Incorrect number of arguments for ternary function"); + } // LCOV_EXCL_STOP + + template + inline static LogicalType GetArgumentType() { + if (std::is_same()) { + return LogicalType(LogicalTypeId::BOOLEAN); + } else if (std::is_same()) { + return LogicalType(LogicalTypeId::TINYINT); + } else if (std::is_same()) { + return LogicalType(LogicalTypeId::SMALLINT); + } else if (std::is_same()) { + return LogicalType(LogicalTypeId::INTEGER); + } else if (std::is_same()) { + return LogicalType(LogicalTypeId::BIGINT); + } else if (std::is_same()) { + return LogicalType(LogicalTypeId::FLOAT); + } else if (std::is_same()) { + return LogicalType(LogicalTypeId::DOUBLE); + } else if (std::is_same()) { + return LogicalType(LogicalTypeId::VARCHAR); + } else { // LCOV_EXCL_START + throw std::runtime_error("Unrecognized type!"); + } // LCOV_EXCL_STOP + } + + template + inline static void GetArgumentTypesRecursive(vector &arguments) { + arguments.push_back(GetArgumentType()); + GetArgumentTypesRecursive(arguments); + } + + template + inline static void GetArgumentTypesRecursive(vector &arguments) { + arguments.push_back(GetArgumentType()); + } + +private: + //-------------------------------- Argumented functions --------------------------------// + + template + inline static scalar_function_t CreateUnaryFunction(const string &name, vector args, + LogicalType ret_type, + TR (*udf_func)(Args...)) { // LCOV_EXCL_START + throw std::runtime_error("Incorrect number of arguments for unary function"); + } // LCOV_EXCL_STOP + + template + inline static scalar_function_t CreateUnaryFunction(const string &name, vector args, + LogicalType ret_type, TR (*udf_func)(TA)) { + if (args.size() != 1) { // LCOV_EXCL_START + throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 1!"); + } + if (!TypesMatch(args[0])) { + throw std::runtime_error("The first arguments don't match!"); + } // LCOV_EXCL_STOP + + scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { + UnaryExecutor::GenericExecute(input.data[0], result, input.size(), + (void *)udf_func); + }; + return udf_function; + } + + template + inline static scalar_function_t CreateBinaryFunction(const string &name, vector args, + LogicalType ret_type, + TR (*udf_func)(Args...)) { // LCOV_EXCL_START + throw std::runtime_error("Incorrect number of arguments for binary function"); + } // LCOV_EXCL_STOP + + template + inline static scalar_function_t CreateBinaryFunction(const string &name, vector args, + LogicalType ret_type, TR (*udf_func)(TA, TB)) { + if (args.size() != 2) { // LCOV_EXCL_START + throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 2!"); + } + if (!TypesMatch(args[0])) { + throw std::runtime_error("The first arguments don't match!"); + } + if (!TypesMatch(args[1])) { + throw std::runtime_error("The second arguments don't match!"); + } // LCOV_EXCL_STOP + + scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size(), udf_func); + }; + return udf_function; + } + + template + inline static scalar_function_t CreateTernaryFunction(const string &name, vector args, + LogicalType ret_type, + TR (*udf_func)(Args...)) { // LCOV_EXCL_START + throw std::runtime_error("Incorrect number of arguments for ternary function"); + } // LCOV_EXCL_STOP + + template + inline static scalar_function_t CreateTernaryFunction(const string &name, vector args, + LogicalType ret_type, TR (*udf_func)(TA, TB, TC)) { + if (args.size() != 3) { // LCOV_EXCL_START + throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 3!"); + } + if (!TypesMatch(args[0])) { + throw std::runtime_error("The first arguments don't match!"); + } + if (!TypesMatch(args[1])) { + throw std::runtime_error("The second arguments don't match!"); + } + if (!TypesMatch(args[2])) { + throw std::runtime_error("The second arguments don't match!"); + } // LCOV_EXCL_STOP + + scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { + TernaryExecutor::Execute(input.data[0], input.data[1], input.data[2], result, input.size(), + udf_func); + }; + return udf_function; + } + + template + inline static bool TypesMatch(const LogicalType &sql_type) { + switch (sql_type.id()) { + case LogicalTypeId::BOOLEAN: + return std::is_same(); + case LogicalTypeId::TINYINT: + return std::is_same(); + case LogicalTypeId::SMALLINT: + return std::is_same(); + case LogicalTypeId::INTEGER: + return std::is_same(); + case LogicalTypeId::BIGINT: + return std::is_same(); + case LogicalTypeId::DATE: + return std::is_same(); + case LogicalTypeId::TIME: + return std::is_same(); + case LogicalTypeId::TIME_TZ: + return std::is_same(); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_TZ: + return std::is_same(); + case LogicalTypeId::FLOAT: + return std::is_same(); + case LogicalTypeId::DOUBLE: + return std::is_same(); + case LogicalTypeId::VARCHAR: + case LogicalTypeId::CHAR: + case LogicalTypeId::BLOB: + return std::is_same(); + default: // LCOV_EXCL_START + throw std::runtime_error("Type is not supported!"); + } // LCOV_EXCL_STOP + } + +private: + //-------------------------------- Aggregate functions --------------------------------// + template + inline static AggregateFunction CreateUnaryAggregateFunction(const string &name) { + LogicalType return_type = GetArgumentType(); + LogicalType input_type = GetArgumentType(); + return CreateUnaryAggregateFunction(name, return_type, input_type); + } + + template + inline static AggregateFunction CreateUnaryAggregateFunction(const string &name, LogicalType ret_type, + LogicalType input_type) { + AggregateFunction aggr_function = + AggregateFunction::UnaryAggregate(input_type, ret_type); + aggr_function.name = name; + return aggr_function; + } + + template + inline static AggregateFunction CreateBinaryAggregateFunction(const string &name) { + LogicalType return_type = GetArgumentType(); + LogicalType input_typeA = GetArgumentType(); + LogicalType input_typeB = GetArgumentType(); + return CreateBinaryAggregateFunction(name, return_type, input_typeA, input_typeB); + } + + template + inline static AggregateFunction CreateBinaryAggregateFunction(const string &name, LogicalType ret_type, + LogicalType input_typeA, LogicalType input_typeB) { + AggregateFunction aggr_function = + AggregateFunction::BinaryAggregate(input_typeA, input_typeB, ret_type); + aggr_function.name = name; + return aggr_function; + } +}; // end UDFWrapper + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/appender.hpp b/src/duckdb/src/include/duckdb/main/appender.hpp new file mode 100644 index 00000000..4fa513a7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/appender.hpp @@ -0,0 +1,184 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/appender.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/main/table_description.hpp" + +namespace duckdb { + +class ColumnDataCollection; +class ClientContext; +class DuckDB; +class TableCatalogEntry; +class Connection; + +enum class AppenderType : uint8_t { + LOGICAL, // Cast input -> LogicalType + PHYSICAL // Cast input -> PhysicalType +}; + +//! The Appender class can be used to append elements to a table. +class BaseAppender { +protected: + //! The amount of tuples that will be gathered in the column data collection before flushing + static constexpr const idx_t FLUSH_COUNT = STANDARD_VECTOR_SIZE * 100; + + Allocator &allocator; + //! The append types + vector types; + //! The buffered data for the append + unique_ptr collection; + //! Internal chunk used for appends + DataChunk chunk; + //! The current column to append to + idx_t column = 0; + //! The type of the appender + AppenderType appender_type; + +protected: + DUCKDB_API BaseAppender(Allocator &allocator, AppenderType type); + DUCKDB_API BaseAppender(Allocator &allocator, vector types, AppenderType type); + +public: + DUCKDB_API virtual ~BaseAppender(); + + //! Begins a new row append, after calling this the other AppendX() functions + //! should be called the correct amount of times. After that, + //! EndRow() should be called. + DUCKDB_API void BeginRow(); + //! Finishes appending the current row. + DUCKDB_API void EndRow(); + + // Append functions + template + void Append(T value) { + throw Exception("Undefined type for Appender::Append!"); + } + + DUCKDB_API void Append(const char *value, uint32_t length); + + // prepared statements + template + void AppendRow(Args... args) { + BeginRow(); + AppendRowRecursive(args...); + } + + //! Commit the changes made by the appender. + DUCKDB_API void Flush(); + //! Flush the changes made by the appender and close it. The appender cannot be used after this point + DUCKDB_API void Close(); + + vector &GetTypes() { + return types; + } + idx_t CurrentColumn() { + return column; + } + DUCKDB_API void AppendDataChunk(DataChunk &value); + +protected: + void Destructor(); + virtual void FlushInternal(ColumnDataCollection &collection) = 0; + void InitializeChunk(); + void FlushChunk(); + + template + void AppendValueInternal(T value); + template + void AppendValueInternal(Vector &vector, SRC input); + template + void AppendDecimalValueInternal(Vector &vector, SRC input); + + void AppendRowRecursive() { + EndRow(); + } + + template + void AppendRowRecursive(T value, Args... args) { + Append(value); + AppendRowRecursive(args...); + } + + void AppendValue(const Value &value); +}; + +class Appender : public BaseAppender { + //! A reference to a database connection that created this appender + shared_ptr context; + //! The table description (including column names) + unique_ptr description; + +public: + DUCKDB_API Appender(Connection &con, const string &schema_name, const string &table_name); + DUCKDB_API Appender(Connection &con, const string &table_name); + DUCKDB_API ~Appender() override; + +protected: + void FlushInternal(ColumnDataCollection &collection) override; +}; + +class InternalAppender : public BaseAppender { + //! The client context + ClientContext &context; + //! The internal table entry to append to + TableCatalogEntry &table; + +public: + DUCKDB_API InternalAppender(ClientContext &context, TableCatalogEntry &table); + DUCKDB_API ~InternalAppender() override; + +protected: + void FlushInternal(ColumnDataCollection &collection) override; +}; + +template <> +DUCKDB_API void BaseAppender::Append(bool value); +template <> +DUCKDB_API void BaseAppender::Append(int8_t value); +template <> +DUCKDB_API void BaseAppender::Append(int16_t value); +template <> +DUCKDB_API void BaseAppender::Append(int32_t value); +template <> +DUCKDB_API void BaseAppender::Append(int64_t value); +template <> +DUCKDB_API void BaseAppender::Append(hugeint_t value); +template <> +DUCKDB_API void BaseAppender::Append(uint8_t value); +template <> +DUCKDB_API void BaseAppender::Append(uint16_t value); +template <> +DUCKDB_API void BaseAppender::Append(uint32_t value); +template <> +DUCKDB_API void BaseAppender::Append(uint64_t value); +template <> +DUCKDB_API void BaseAppender::Append(float value); +template <> +DUCKDB_API void BaseAppender::Append(double value); +template <> +DUCKDB_API void BaseAppender::Append(date_t value); +template <> +DUCKDB_API void BaseAppender::Append(dtime_t value); +template <> +DUCKDB_API void BaseAppender::Append(timestamp_t value); +template <> +DUCKDB_API void BaseAppender::Append(interval_t value); +template <> +DUCKDB_API void BaseAppender::Append(const char *value); +template <> +DUCKDB_API void BaseAppender::Append(string_t value); +template <> +DUCKDB_API void BaseAppender::Append(Value value); +template <> +DUCKDB_API void BaseAppender::Append(std::nullptr_t value); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/attached_database.hpp b/src/duckdb/src/include/duckdb/main/attached_database.hpp new file mode 100644 index 00000000..14b6f6e0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/attached_database.hpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/attached_database.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/catalog/catalog_entry.hpp" + +namespace duckdb { +class Catalog; +class DatabaseInstance; +class StorageManager; +class TransactionManager; +class StorageExtension; + +struct AttachInfo; + +enum class AttachedDatabaseType { + READ_WRITE_DATABASE, + READ_ONLY_DATABASE, + SYSTEM_DATABASE, + TEMP_DATABASE, +}; + +//! The AttachedDatabase represents an attached database instance +class AttachedDatabase : public CatalogEntry { +public: + //! Create the built-in system attached database (without storage) + explicit AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType type = AttachedDatabaseType::SYSTEM_DATABASE); + //! Create an attached database instance with the specified name and storage + AttachedDatabase(DatabaseInstance &db, Catalog &catalog, string name, string file_path, AccessMode access_mode); + //! Create an attached database instance with the specified storage extension + AttachedDatabase(DatabaseInstance &db, Catalog &catalog, StorageExtension &ext, string name, AttachInfo &info, + AccessMode access_mode); + ~AttachedDatabase() override; + + void Initialize(); + + Catalog &ParentCatalog() override; + StorageManager &GetStorageManager(); + Catalog &GetCatalog(); + TransactionManager &GetTransactionManager(); + DatabaseInstance &GetDatabase() { + return db; + } + const string &GetName() const { + return name; + } + bool IsSystem() const; + bool IsTemporary() const; + bool IsReadOnly() const; + bool IsInitialDatabase() const; + void SetInitialDatabase(); + + static string ExtractDatabaseName(const string &dbpath, FileSystem &fs); + +private: + DatabaseInstance &db; + unique_ptr storage; + unique_ptr catalog; + unique_ptr transaction_manager; + AttachedDatabaseType type; + optional_ptr parent_catalog; + bool is_initial_database = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp new file mode 100644 index 00000000..cd4ae626 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/capi/capi_internal.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/main/appender.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +#include +#include + +#ifdef _WIN32 +#ifndef strdup +#define strdup _strdup +#endif +#endif + +namespace duckdb { + +struct DatabaseData { + unique_ptr database; +}; + +struct PreparedStatementWrapper { + //! Map of name -> values + case_insensitive_map_t values; + unique_ptr statement; +}; + +struct ExtractStatementsWrapper { + vector> statements; + string error; +}; + +struct PendingStatementWrapper { + unique_ptr statement; + bool allow_streaming; +}; + +struct ArrowResultWrapper { + unique_ptr result; + unique_ptr current_chunk; + ClientProperties options; +}; + +struct AppenderWrapper { + unique_ptr appender; + string error; +}; + +enum class CAPIResultSetType : uint8_t { + CAPI_RESULT_TYPE_NONE = 0, + CAPI_RESULT_TYPE_MATERIALIZED, + CAPI_RESULT_TYPE_STREAMING, + CAPI_RESULT_TYPE_DEPRECATED +}; + +struct DuckDBResultData { + //! The underlying query result + unique_ptr result; + // Results can only use either the new API or the old API, not a mix of the two + // They start off as "none" and switch to one or the other when an API method is used + CAPIResultSetType result_set_type; +}; + +duckdb_type ConvertCPPTypeToC(const LogicalType &type); +LogicalTypeId ConvertCTypeToCPP(duckdb_type c_type); +idx_t GetCTypeSize(duckdb_type type); +duckdb_state duckdb_translate_result(unique_ptr result, duckdb_result *out); +bool deprecated_materialize_result(duckdb_result *result); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/capi/cast/from_decimal.hpp b/src/duckdb/src/include/duckdb/main/capi/cast/from_decimal.hpp new file mode 100644 index 00000000..de3f49ec --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/capi/cast/from_decimal.hpp @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/capi/capi_cast_from_decimal.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/capi/cast/utils.hpp" + +namespace duckdb { + +//! DECIMAL -> ? +template +bool CastDecimalCInternal(duckdb_result *source, RESULT_TYPE &result, idx_t col, idx_t row) { + auto result_data = (duckdb::DuckDBResultData *)source->internal_data; + auto &query_result = result_data->result; + auto &source_type = query_result->types[col]; + auto width = duckdb::DecimalType::GetWidth(source_type); + auto scale = duckdb::DecimalType::GetScale(source_type); + void *source_address = UnsafeFetchPtr(source, col, row); + switch (source_type.InternalType()) { + case duckdb::PhysicalType::INT16: + return duckdb::TryCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), + result, nullptr, width, scale); + case duckdb::PhysicalType::INT32: + return duckdb::TryCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), + result, nullptr, width, scale); + case duckdb::PhysicalType::INT64: + return duckdb::TryCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), + result, nullptr, width, scale); + case duckdb::PhysicalType::INT128: + return duckdb::TryCastFromDecimal::Operation( + UnsafeFetchFromPtr(source_address), result, nullptr, width, scale); + default: + throw duckdb::InternalException("Unimplemented internal type for decimal"); + } +} + +//! DECIMAL -> VARCHAR +template <> +bool CastDecimalCInternal(duckdb_result *source, duckdb_string &result, idx_t col, idx_t row); + +//! DECIMAL -> DECIMAL (internal fetch) +template <> +bool CastDecimalCInternal(duckdb_result *source, duckdb_decimal &result, idx_t col, idx_t row); + +//! DECIMAL -> ... +template +RESULT_TYPE TryCastDecimalCInternal(duckdb_result *source, idx_t col, idx_t row) { + RESULT_TYPE result_value; + try { + if (!CastDecimalCInternal(source, result_value, col, row)) { + return FetchDefaultValue::Operation(); + } + } catch (...) { + return FetchDefaultValue::Operation(); + } + return result_value; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/capi/cast/generic.hpp b/src/duckdb/src/include/duckdb/main/capi/cast/generic.hpp new file mode 100644 index 00000000..3aa708f4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/capi/cast/generic.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/capi/cast/generic_cast.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/date.hpp" + +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/capi/cast/utils.hpp" +#include "duckdb/main/capi/cast/from_decimal.hpp" + +namespace duckdb { + +template +RESULT_TYPE GetInternalCValue(duckdb_result *result, idx_t col, idx_t row) { + if (!CanFetchValue(result, col, row)) { + return FetchDefaultValue::Operation(); + } + switch (result->__deprecated_columns[col].__deprecated_type) { + case DUCKDB_TYPE_BOOLEAN: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_TINYINT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_SMALLINT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_INTEGER: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_BIGINT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_UTINYINT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_USMALLINT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_UINTEGER: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_UBIGINT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_FLOAT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_DOUBLE: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_DATE: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_TIME: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_TIMESTAMP: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_HUGEINT: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_DECIMAL: + return TryCastDecimalCInternal(result, col, row); + case DUCKDB_TYPE_INTERVAL: + return TryCastCInternal(result, col, row); + case DUCKDB_TYPE_VARCHAR: + return TryCastCInternal>(result, col, row); + case DUCKDB_TYPE_BLOB: + return TryCastCInternal(result, col, row); + default: { // LCOV_EXCL_START + // invalid type for C to C++ conversion + D_ASSERT(0); + return FetchDefaultValue::Operation(); + } // LCOV_EXCL_STOP + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/capi/cast/to_decimal.hpp b/src/duckdb/src/include/duckdb/main/capi/cast/to_decimal.hpp new file mode 100644 index 00000000..049423d5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/capi/cast/to_decimal.hpp @@ -0,0 +1,134 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/capi/capi_cast_from_decimal.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/capi/cast/utils.hpp" + +namespace duckdb { + +template +struct ToCDecimalCastWrapper { + template + static bool Operation(SOURCE_TYPE input, duckdb_decimal &result, std::string *error, uint8_t width, uint8_t scale) { + throw NotImplementedException("Type not implemented for CDecimalCastWrapper"); + } +}; + +//! Hugeint +template <> +struct ToCDecimalCastWrapper { + template + static bool Operation(SOURCE_TYPE input, duckdb_decimal &result, std::string *error, uint8_t width, uint8_t scale) { + hugeint_t intermediate_result; + + if (!TryCastToDecimal::Operation(input, intermediate_result, error, width, scale)) { + result = FetchDefaultValue::Operation(); + return false; + } + result.scale = scale; + result.width = width; + + duckdb_hugeint hugeint_value; + hugeint_value.upper = intermediate_result.upper; + hugeint_value.lower = intermediate_result.lower; + result.value = hugeint_value; + return true; + } +}; + +//! FIXME: reduce duplication here by just matching on the signed-ness of the type +//! INTERNAL_TYPE = int16_t +template <> +struct ToCDecimalCastWrapper { + template + static bool Operation(SOURCE_TYPE input, duckdb_decimal &result, std::string *error, uint8_t width, uint8_t scale) { + int16_t intermediate_result; + + if (!TryCastToDecimal::Operation(input, intermediate_result, error, width, scale)) { + result = FetchDefaultValue::Operation(); + return false; + } + hugeint_t hugeint_result = Hugeint::Convert(intermediate_result); + + result.scale = scale; + result.width = width; + + duckdb_hugeint hugeint_value; + hugeint_value.upper = hugeint_result.upper; + hugeint_value.lower = hugeint_result.lower; + result.value = hugeint_value; + return true; + } +}; +//! INTERNAL_TYPE = int32_t +template <> +struct ToCDecimalCastWrapper { + template + static bool Operation(SOURCE_TYPE input, duckdb_decimal &result, std::string *error, uint8_t width, uint8_t scale) { + int32_t intermediate_result; + + if (!TryCastToDecimal::Operation(input, intermediate_result, error, width, scale)) { + result = FetchDefaultValue::Operation(); + return false; + } + hugeint_t hugeint_result = Hugeint::Convert(intermediate_result); + + result.scale = scale; + result.width = width; + + duckdb_hugeint hugeint_value; + hugeint_value.upper = hugeint_result.upper; + hugeint_value.lower = hugeint_result.lower; + result.value = hugeint_value; + return true; + } +}; +//! INTERNAL_TYPE = int64_t +template <> +struct ToCDecimalCastWrapper { + template + static bool Operation(SOURCE_TYPE input, duckdb_decimal &result, std::string *error, uint8_t width, uint8_t scale) { + int64_t intermediate_result; + + if (!TryCastToDecimal::Operation(input, intermediate_result, error, width, scale)) { + result = FetchDefaultValue::Operation(); + return false; + } + hugeint_t hugeint_result = Hugeint::Convert(intermediate_result); + + result.scale = scale; + result.width = width; + + duckdb_hugeint hugeint_value; + hugeint_value.upper = hugeint_result.upper; + hugeint_value.lower = hugeint_result.lower; + result.value = hugeint_value; + return true; + } +}; + +template +duckdb_decimal TryCastToDecimalCInternal(SOURCE_TYPE source, uint8_t width, uint8_t scale) { + duckdb_decimal result; + try { + if (!OP::template Operation(source, result, nullptr, width, scale)) { + return FetchDefaultValue::Operation(); + } + } catch (...) { + return FetchDefaultValue::Operation(); + } + return result; +} + +template +duckdb_decimal TryCastToDecimalCInternal(duckdb_result *result, idx_t col, idx_t row, uint8_t width, uint8_t scale) { + return TryCastToDecimalCInternal(UnsafeFetch(result, col, row), width, scale); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/capi/cast/utils.hpp b/src/duckdb/src/include/duckdb/main/capi/cast/utils.hpp new file mode 100644 index 00000000..fa9a73c6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/capi/cast/utils.hpp @@ -0,0 +1,124 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/capi/cast/utils.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/string_cast.hpp" +#include "duckdb/common/operator/decimal_cast_operators.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Unsafe Fetch (for internal use only) +//===--------------------------------------------------------------------===// +template +T UnsafeFetchFromPtr(void *pointer) { + return *((T *)pointer); +} + +template +void *UnsafeFetchPtr(duckdb_result *result, idx_t col, idx_t row) { + D_ASSERT(row < result->__deprecated_row_count); + return (void *)&(((T *)result->__deprecated_columns[col].__deprecated_data)[row]); +} + +template +T UnsafeFetch(duckdb_result *result, idx_t col, idx_t row) { + return UnsafeFetchFromPtr(UnsafeFetchPtr(result, col, row)); +} + +//===--------------------------------------------------------------------===// +// Fetch Default Value +//===--------------------------------------------------------------------===// +struct FetchDefaultValue { + template + static T Operation() { + return 0; + } +}; + +template <> +duckdb_decimal FetchDefaultValue::Operation(); +template <> +date_t FetchDefaultValue::Operation(); +template <> +dtime_t FetchDefaultValue::Operation(); +template <> +timestamp_t FetchDefaultValue::Operation(); +template <> +interval_t FetchDefaultValue::Operation(); +template <> +char *FetchDefaultValue::Operation(); +template <> +duckdb_string FetchDefaultValue::Operation(); +template <> +duckdb_blob FetchDefaultValue::Operation(); + +//===--------------------------------------------------------------------===// +// String Casts +//===--------------------------------------------------------------------===// +template +struct FromCStringCastWrapper { + template + static bool Operation(SOURCE_TYPE input_str, RESULT_TYPE &result) { + string_t input(input_str); + return OP::template Operation(input, result); + } +}; + +template +struct ToCStringCastWrapper { + template + static bool Operation(SOURCE_TYPE input, RESULT_TYPE &result) { + Vector result_vector(LogicalType::VARCHAR, nullptr); + auto result_string = OP::template Operation(input, result_vector); + auto result_size = result_string.GetSize(); + auto result_data = result_string.GetData(); + + char *allocated_data = char_ptr_cast(duckdb_malloc(result_size + 1)); + memcpy(allocated_data, result_data, result_size); + allocated_data[result_size] = '\0'; + result.data = allocated_data; + result.size = result_size; + return true; + } +}; + +//===--------------------------------------------------------------------===// +// Blob Casts +//===--------------------------------------------------------------------===// +struct FromCBlobCastWrapper { + template + static bool Operation(SOURCE_TYPE input_str, RESULT_TYPE &result) { + return false; + } +}; + +template <> +bool FromCBlobCastWrapper::Operation(duckdb_blob input, duckdb_string &result); + +template +RESULT_TYPE TryCastCInternal(duckdb_result *result, idx_t col, idx_t row) { + RESULT_TYPE result_value; + try { + if (!OP::template Operation(UnsafeFetch(result, col, row), + result_value)) { + return FetchDefaultValue::Operation(); + } + } catch (...) { + return FetchDefaultValue::Operation(); + } + return result_value; +} + +} // namespace duckdb + +bool CanFetchValue(duckdb_result *result, idx_t col, idx_t row); +bool CanUseDeprecatedFetch(duckdb_result *result, idx_t col, idx_t row); diff --git a/src/duckdb/src/include/duckdb/main/chunk_scan_state.hpp b/src/duckdb/src/include/duckdb/main/chunk_scan_state.hpp new file mode 100644 index 00000000..8849b09b --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/chunk_scan_state.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include "duckdb/common/vector.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/preserved_error.hpp" + +namespace duckdb { + +class DataChunk; + +//! Abstract chunk fetcher +class ChunkScanState { +public: + explicit ChunkScanState(); + virtual ~ChunkScanState(); + +public: + ChunkScanState(const ChunkScanState &other) = delete; + ChunkScanState(ChunkScanState &&other) = default; + ChunkScanState &operator=(const ChunkScanState &other) = delete; + ChunkScanState &operator=(ChunkScanState &&other) = default; + +public: + virtual bool LoadNextChunk(PreservedError &error) = 0; + virtual bool HasError() const = 0; + virtual PreservedError &GetError() = 0; + virtual const vector &Types() const = 0; + virtual const vector &Names() const = 0; + idx_t CurrentOffset() const; + idx_t RemainingInChunk() const; + DataChunk &CurrentChunk(); + bool ChunkIsEmpty() const; + bool Finished() const; + bool ScanStarted() const; + void IncreaseOffset(idx_t increment, bool unsafe = false); + +protected: + idx_t offset = 0; + bool finished = false; + unique_ptr current_chunk; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/chunk_scan_state/query_result.hpp b/src/duckdb/src/include/duckdb/main/chunk_scan_state/query_result.hpp new file mode 100644 index 00000000..d6f21a50 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/chunk_scan_state/query_result.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include "duckdb/main/chunk_scan_state.hpp" +#include "duckdb/common/preserved_error.hpp" + +namespace duckdb { + +class QueryResult; + +class QueryResultChunkScanState : public ChunkScanState { +public: + QueryResultChunkScanState(QueryResult &result); + ~QueryResultChunkScanState(); + +public: + bool LoadNextChunk(PreservedError &error) override; + bool HasError() const override; + PreservedError &GetError() override; + const vector &Types() const override; + const vector &Names() const override; + +private: + bool InternalLoad(PreservedError &error); + +private: + QueryResult &result; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/client_config.hpp b/src/duckdb/src/include/duckdb/main/client_config.hpp new file mode 100644 index 00000000..56c67f47 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/client_config.hpp @@ -0,0 +1,127 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/client_config.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/output_type.hpp" +#include "duckdb/common/enums/profiler_format.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/progress_bar/progress_bar.hpp" + +namespace duckdb { +class ClientContext; +class PhysicalResultCollector; +class PreparedStatementData; + +typedef std::function(ClientContext &context, PreparedStatementData &data)> + get_result_collector_t; + +struct ClientConfig { + //! The home directory used by the system (if any) + string home_directory; + //! If the query profiler is enabled or not. + bool enable_profiler = false; + //! If detailed query profiling is enabled + bool enable_detailed_profiling = false; + //! The format to print query profiling information in (default: query_tree), if enabled. + ProfilerPrintFormat profiler_print_format = ProfilerPrintFormat::QUERY_TREE; + //! The file to save query profiling information to, instead of printing it to the console + //! (empty = print to console) + string profiler_save_location; + + //! Allows suppressing profiler output, even if enabled. We turn on the profiler on all test runs but don't want + //! to output anything + bool emit_profiler_output = true; + + //! system-wide progress bar disable. + const char *system_progress_bar_disable_reason = nullptr; + //! If the progress bar is enabled or not. + bool enable_progress_bar = false; + //! If the print of the progress bar is enabled + bool print_progress_bar = true; + //! The wait time before showing the progress bar + int wait_time = 2000; + + //! Preserve identifier case while parsing. + //! If false, all unquoted identifiers are lower-cased (e.g. "MyTable" -> "mytable"). + bool preserve_identifier_case = true; + //! The maximum expression depth limit in the parser + idx_t max_expression_depth = 1000; + + //! Whether or not aggressive query verification is enabled + bool query_verification_enabled = false; + //! Whether or not verification of external operators is enabled, used for testing + bool verify_external = false; + //! Whether or not we should verify the serializer + bool verify_serializer = false; + //! Enable the running of optimizers + bool enable_optimizer = true; + //! Enable caching operators + bool enable_caching_operators = true; + //! Force parallelism of small tables, used for testing + bool verify_parallelism = false; + //! Enable the optimizer to consider index joins, which are disabled on default + bool enable_index_join = false; + //! Force index join independent of table cardinality, used for testing + bool force_index_join = false; + //! Force out-of-core computation for operators that support it, used for testing + bool force_external = false; + //! Force disable cross product generation when hyper graph isn't connected, used for testing + bool force_no_cross_product = false; + //! Force use of IEJoin to implement AsOfJoin, used for testing + bool force_asof_iejoin = false; + //! Use range joins for inequalities, even if there are equality predicates + bool prefer_range_joins = false; + //! If this context should also try to use the available replacement scans + //! True by default + bool use_replacement_scans = true; + //! Maximum bits allowed for using a perfect hash table (i.e. the perfect HT can hold up to 2^perfect_ht_threshold + //! elements) + idx_t perfect_ht_threshold = 12; + //! The maximum number of rows to accumulate before sorting ordered aggregates. + idx_t ordered_aggregate_threshold = (idx_t(1) << 18); + + //! Callback to create a progress bar display + progress_bar_display_create_func_t display_create_func = nullptr; + + //! Override for the default extension repository + string custom_extension_repo = ""; + //! Override for the default autoload extensoin repository + string autoinstall_extension_repo = ""; + + //! The explain output type used when none is specified (default: PHYSICAL_ONLY) + ExplainOutputType explain_output_type = ExplainOutputType::PHYSICAL_ONLY; + + //! The maximum amount of pivot columns + idx_t pivot_limit = 100000; + + //! The threshold at which we switch from using filtered aggregates to LIST with a dedicated pivot operator + idx_t pivot_filter_threshold = 10; + + //! Whether or not the "/" division operator defaults to integer division or floating point division + bool integer_division = false; + + //! Generic options + case_insensitive_map_t set_variables; + + //! Function that is used to create the result collector for a materialized result + //! Defaults to PhysicalMaterializedCollector + get_result_collector_t result_collector = nullptr; + +public: + static ClientConfig &GetConfig(ClientContext &context); + static const ClientConfig &GetConfig(const ClientContext &context); + + bool AnyVerification() { + return query_verification_enabled || verify_external || verify_serializer; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/client_context.hpp b/src/duckdb/src/include/duckdb/main/client_context.hpp new file mode 100644 index 00000000..16ead95d --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/client_context.hpp @@ -0,0 +1,297 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/client_context.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/common/enums/pending_execution_result.hpp" +#include "duckdb/common/deque.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/main/prepared_statement.hpp" +#include "duckdb/main/stream_query_result.hpp" +#include "duckdb/main/table_description.hpp" +#include "duckdb/transaction/transaction_context.hpp" +#include "duckdb/main/pending_query_result.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/external_dependencies.hpp" +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/main/client_properties.hpp" + +namespace duckdb { +class Appender; +class Catalog; +class CatalogSearchPath; +class ColumnDataCollection; +class DatabaseInstance; +class FileOpener; +class LogicalOperator; +class PreparedStatementData; +class Relation; +class BufferedFileWriter; +class QueryProfiler; +class ClientContextLock; +struct CreateScalarFunctionInfo; +class ScalarFunctionCatalogEntry; +struct ActiveQueryContext; +struct ParserOptions; +struct ClientData; + +struct PendingQueryParameters { + //! Prepared statement parameters (if any) + optional_ptr> parameters; + //! Whether or not a stream result should be allowed + bool allow_stream_result = false; +}; + +//! ClientContextState is virtual base class for ClientContext-local (or Query-Local, using QueryEnd callback) state +//! e.g. caches that need to live as long as a ClientContext or Query. +class ClientContextState { +public: + virtual ~ClientContextState() {}; + virtual void QueryEnd() = 0; +}; + +//! The ClientContext holds information relevant to the current client session +//! during execution +class ClientContext : public std::enable_shared_from_this { + friend class PendingQueryResult; + friend class StreamQueryResult; + friend class DuckTransactionManager; + +public: + DUCKDB_API explicit ClientContext(shared_ptr db); + DUCKDB_API ~ClientContext(); + + //! The database that this client is connected to + shared_ptr db; + //! Whether or not the query is interrupted + atomic interrupted; + //! External Objects (e.g., Python objects) that views depend of + unordered_map>> external_dependencies; + //! Set of optional states (e.g. Caches) that can be held by the ClientContext + unordered_map> registered_state; + //! The client configuration + ClientConfig config; + //! The set of client-specific data + unique_ptr client_data; + //! Data for the currently running transaction + TransactionContext transaction; + +public: + MetaTransaction &ActiveTransaction() { + return transaction.ActiveTransaction(); + } + + //! Interrupt execution of a query + DUCKDB_API void Interrupt(); + //! Enable query profiling + DUCKDB_API void EnableProfiling(); + //! Disable query profiling + DUCKDB_API void DisableProfiling(); + + //! Issue a query, returning a QueryResult. The QueryResult can be either a StreamQueryResult or a + //! MaterializedQueryResult. The StreamQueryResult will only be returned in the case of a successful SELECT + //! statement. + DUCKDB_API unique_ptr Query(const string &query, bool allow_stream_result); + DUCKDB_API unique_ptr Query(unique_ptr statement, bool allow_stream_result); + + //! Issues a query to the database and returns a Pending Query Result. Note that "query" may only contain + //! a single statement. + DUCKDB_API unique_ptr PendingQuery(const string &query, bool allow_stream_result); + //! Issues a query to the database and returns a Pending Query Result + DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, + bool allow_stream_result); + + //! Destroy the client context + DUCKDB_API void Destroy(); + + //! Get the table info of a specific table, or nullptr if it cannot be found + DUCKDB_API unique_ptr TableInfo(const string &schema_name, const string &table_name); + //! Appends a DataChunk to the specified table. Returns whether or not the append was successful. + DUCKDB_API void Append(TableDescription &description, ColumnDataCollection &collection); + //! Try to bind a relation in the current client context; either throws an exception or fills the result_columns + //! list with the set of returned columns + DUCKDB_API void TryBindRelation(Relation &relation, vector &result_columns); + + //! Execute a relation + DUCKDB_API unique_ptr PendingQuery(const shared_ptr &relation, + bool allow_stream_result); + DUCKDB_API unique_ptr Execute(const shared_ptr &relation); + + //! Prepare a query + DUCKDB_API unique_ptr Prepare(const string &query); + //! Directly prepare a SQL statement + DUCKDB_API unique_ptr Prepare(unique_ptr statement); + + //! Create a pending query result from a prepared statement with the given name and set of parameters + //! It is possible that the prepared statement will be re-bound. This will generally happen if the catalog is + //! modified in between the prepared statement being bound and the prepared statement being run. + DUCKDB_API unique_ptr PendingQuery(const string &query, + shared_ptr &prepared, + const PendingQueryParameters ¶meters); + + //! Execute a prepared statement with the given name and set of parameters + //! It is possible that the prepared statement will be re-bound. This will generally happen if the catalog is + //! modified in between the prepared statement being bound and the prepared statement being run. + DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, + case_insensitive_map_t &values, bool allow_stream_result = true); + DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, + const PendingQueryParameters ¶meters); + + //! Gets current percentage of the query's progress, returns 0 in case the progress bar is disabled. + DUCKDB_API double GetProgress(); + + //! Register function in the temporary schema + DUCKDB_API void RegisterFunction(CreateFunctionInfo &info); + + //! Parse statements from a query + DUCKDB_API vector> ParseStatements(const string &query); + + //! Extract the logical plan of a query + DUCKDB_API unique_ptr ExtractPlan(const string &query); + DUCKDB_API void HandlePragmaStatements(vector> &statements); + + //! Runs a function with a valid transaction context, potentially starting a transaction if the context is in auto + //! commit mode. + DUCKDB_API void RunFunctionInTransaction(const std::function &fun, + bool requires_valid_transaction = true); + //! Same as RunFunctionInTransaction, but does not obtain a lock on the client context or check for validation + DUCKDB_API void RunFunctionInTransactionInternal(ClientContextLock &lock, const std::function &fun, + bool requires_valid_transaction = true); + + //! Equivalent to CURRENT_SETTING(key) SQL function. + DUCKDB_API bool TryGetCurrentSetting(const std::string &key, Value &result); + + //! Returns the parser options for this client context + DUCKDB_API ParserOptions GetParserOptions() const; + + DUCKDB_API unique_ptr Fetch(ClientContextLock &lock, StreamQueryResult &result); + + //! Whether or not the given result object (streaming query result or pending query result) is active + DUCKDB_API bool IsActiveResult(ClientContextLock &lock, BaseQueryResult *result); + + //! Returns the current executor + Executor &GetExecutor(); + + //! Returns the current query string (if any) + const string &GetCurrentQuery(); + + //! Fetch a list of table names that are required for a given query + DUCKDB_API unordered_set GetTableNames(const string &query); + + DUCKDB_API ClientProperties GetClientProperties() const; + + //! Returns true if execution of the current query is finished + DUCKDB_API bool ExecutionIsFinished(); + +private: + //! Parse statements and resolve pragmas from a query + bool ParseStatements(ClientContextLock &lock, const string &query, vector> &result, + PreservedError &error); + //! Issues a query to the database and returns a Pending Query Result + unique_ptr PendingQueryInternal(ClientContextLock &lock, unique_ptr statement, + const PendingQueryParameters ¶meters, bool verify = true); + unique_ptr ExecutePendingQueryInternal(ClientContextLock &lock, PendingQueryResult &query); + + //! Parse statements from a query + vector> ParseStatementsInternal(ClientContextLock &lock, const string &query); + //! Perform aggressive query verification of a SELECT statement. Only called when query_verification_enabled is + //! true. + PreservedError VerifyQuery(ClientContextLock &lock, const string &query, unique_ptr statement); + + void InitialCleanup(ClientContextLock &lock); + //! Internal clean up, does not lock. Caller must hold the context_lock. + void CleanupInternal(ClientContextLock &lock, BaseQueryResult *result = nullptr, + bool invalidate_transaction = false); + unique_ptr PendingStatementOrPreparedStatement(ClientContextLock &lock, const string &query, + unique_ptr statement, + shared_ptr &prepared, + const PendingQueryParameters ¶meters); + unique_ptr PendingPreparedStatement(ClientContextLock &lock, + shared_ptr statement_p, + const PendingQueryParameters ¶meters); + + //! Internally prepare a SQL statement. Caller must hold the context_lock. + shared_ptr + CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr statement, + optional_ptr> values = nullptr); + unique_ptr PendingStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, + const PendingQueryParameters ¶meters); + unique_ptr RunStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, bool allow_stream_result, + bool verify = true); + unique_ptr PrepareInternal(ClientContextLock &lock, unique_ptr statement); + void LogQueryInternal(ClientContextLock &lock, const string &query); + + unique_ptr FetchResultInternal(ClientContextLock &lock, PendingQueryResult &pending); + unique_ptr FetchInternal(ClientContextLock &lock, Executor &executor, BaseQueryResult &result); + + unique_ptr LockContext(); + + void BeginTransactionInternal(ClientContextLock &lock, bool requires_valid_transaction); + void BeginQueryInternal(ClientContextLock &lock, const string &query); + PreservedError EndQueryInternal(ClientContextLock &lock, bool success, bool invalidate_transaction); + + PendingExecutionResult ExecuteTaskInternal(ClientContextLock &lock, PendingQueryResult &result); + + unique_ptr PendingStatementOrPreparedStatementInternal( + ClientContextLock &lock, const string &query, unique_ptr statement, + shared_ptr &prepared, const PendingQueryParameters ¶meters); + + unique_ptr PendingQueryPreparedInternal(ClientContextLock &lock, const string &query, + shared_ptr &prepared, + const PendingQueryParameters ¶meters); + + unique_ptr PendingQueryInternal(ClientContextLock &, const shared_ptr &relation, + bool allow_stream_result); + +private: + //! Lock on using the ClientContext in parallel + mutex context_lock; + //! The currently active query context + unique_ptr active_query; + //! The current query progress + atomic query_progress; +}; + +class ClientContextLock { +public: + explicit ClientContextLock(mutex &context_lock) : client_guard(context_lock) { + } + + ~ClientContextLock() { + } + +private: + lock_guard client_guard; +}; + +class ClientContextWrapper { +public: + explicit ClientContextWrapper(const shared_ptr &context) + : client_context(context) { + + }; + shared_ptr GetContext() { + auto actual_context = client_context.lock(); + if (!actual_context) { + throw ConnectionException("Connection has already been closed"); + } + return actual_context; + } + +private: + std::weak_ptr client_context; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/client_context_file_opener.hpp b/src/duckdb/src/include/duckdb/main/client_context_file_opener.hpp new file mode 100644 index 00000000..d956825d --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/client_context_file_opener.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/client_context_file_opener.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_opener.hpp" + +namespace duckdb { + +class ClientContext; + +//! ClientContext-specific FileOpener implementation. +//! This object is owned by ClientContext and never outlives it. +class ClientContextFileOpener : public FileOpener { +public: + explicit ClientContextFileOpener(ClientContext &context_p) : context(context_p) { + } + + bool TryGetCurrentSetting(const string &key, Value &result, FileOpenerInfo &info) override; + bool TryGetCurrentSetting(const string &key, Value &result) override; + + ClientContext *TryGetClientContext() override { + return &context; + }; + +private: + ClientContext &context; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/client_data.hpp b/src/duckdb/src/include/duckdb/main/client_data.hpp new file mode 100644 index 00000000..4eb38578 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/client_data.hpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/client_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/output_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/execution/operator/scan/csv/csv_state_machine_cache.hpp" + +namespace duckdb { +class AttachedDatabase; +class BufferedFileWriter; +class ClientContext; +class CatalogSearchPath; +class FileOpener; +class FileSystem; +class HTTPState; +class QueryProfiler; +class QueryProfilerHistory; +class PreparedStatementData; +class SchemaCatalogEntry; +struct RandomEngine; + +struct ClientData { + explicit ClientData(ClientContext &context); + ~ClientData(); + + //! Query profiler + shared_ptr profiler; + //! QueryProfiler History + unique_ptr query_profiler_history; + + //! The set of temporary objects that belong to this client + shared_ptr temporary_objects; + //! The set of bound prepared statements that belong to this client + case_insensitive_map_t> prepared_statements; + + //! The writer used to log queries (if logging is enabled) + unique_ptr log_query_writer; + //! The random generator used by random(). Its seed value can be set by setseed(). + unique_ptr random_engine; + + //! The catalog search path + unique_ptr catalog_search_path; + + //! The file opener of the client context + unique_ptr file_opener; + + //! HTTP State in this query + shared_ptr http_state; + + //! The clients' file system wrapper + unique_ptr client_file_system; + + //! The file search path + string file_search_path; + + //! The Max Line Length Size of Last Query Executed on a CSV File. (Only used for testing) + //! FIXME: this should not be done like this + bool debug_set_max_line_length = false; + idx_t debug_max_line_length = 0; + +public: + DUCKDB_API static ClientData &Get(ClientContext &context); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/client_properties.hpp b/src/duckdb/src/include/duckdb/main/client_properties.hpp new file mode 100644 index 00000000..238d96c2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/client_properties.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/client_properties.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { +enum ArrowOffsetSize { REGULAR, LARGE }; + +//! A set of properties from the client context that can be used to interpret the query result +struct ClientProperties { + ClientProperties(string time_zone_p, ArrowOffsetSize arrow_offset_size_p) + : time_zone(std::move(time_zone_p)), arrow_offset_size(arrow_offset_size_p) { + } + ClientProperties() {}; + string time_zone = "UTC"; + ArrowOffsetSize arrow_offset_size = ArrowOffsetSize::REGULAR; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/config.hpp b/src/duckdb/src/include/duckdb/main/config.hpp new file mode 100644 index 00000000..d9cd3a20 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/config.hpp @@ -0,0 +1,266 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/config.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/access_mode.hpp" +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/compression_type.hpp" +#include "duckdb/common/enums/optimizer_type.hpp" +#include "duckdb/common/enums/order_type.hpp" +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/common/enums/window_aggregation_mode.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/storage/compression/bitpacking.hpp" +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/replacement_scan.hpp" +#include "duckdb/optimizer/optimizer_extension.hpp" +#include "duckdb/parser/parser_extension.hpp" +#include "duckdb/planner/operator_extension.hpp" +#include "duckdb/main/client_properties.hpp" + +namespace duckdb { +class BufferPool; +class CastFunctionSet; +class ClientContext; +class ErrorManager; +class CompressionFunction; +class TableFunctionRef; +class OperatorExtension; +class StorageExtension; +class ExtensionCallback; + +struct CompressionFunctionSet; +struct DBConfig; + +enum class CheckpointAbort : uint8_t { + NO_ABORT = 0, + DEBUG_ABORT_BEFORE_TRUNCATE = 1, + DEBUG_ABORT_BEFORE_HEADER = 2, + DEBUG_ABORT_AFTER_FREE_LIST_WRITE = 3 +}; + +typedef void (*set_global_function_t)(DatabaseInstance *db, DBConfig &config, const Value ¶meter); +typedef void (*set_local_function_t)(ClientContext &context, const Value ¶meter); +typedef void (*reset_global_function_t)(DatabaseInstance *db, DBConfig &config); +typedef void (*reset_local_function_t)(ClientContext &context); +typedef Value (*get_setting_function_t)(ClientContext &context); + +struct ConfigurationOption { + const char *name; + const char *description; + LogicalTypeId parameter_type; + set_global_function_t set_global; + set_local_function_t set_local; + reset_global_function_t reset_global; + reset_local_function_t reset_local; + get_setting_function_t get_setting; +}; + +typedef void (*set_option_callback_t)(ClientContext &context, SetScope scope, Value ¶meter); + +struct ExtensionOption { + ExtensionOption(string description_p, LogicalType type_p, set_option_callback_t set_function_p, + Value default_value_p) + : description(std::move(description_p)), type(std::move(type_p)), set_function(set_function_p), + default_value(std::move(default_value_p)) { + } + + string description; + LogicalType type; + set_option_callback_t set_function; + Value default_value; +}; + +struct DBConfigOptions { + //! Database file path. May be empty for in-memory mode + string database_path; + //! Database type. If empty, automatically extracted from `database_path`, where a `type:path` syntax is expected + string database_type; + //! Access mode of the database (AUTOMATIC, READ_ONLY or READ_WRITE) + AccessMode access_mode = AccessMode::AUTOMATIC; + //! Checkpoint when WAL reaches this size (default: 16MB) + idx_t checkpoint_wal_size = 1 << 24; + //! Whether or not to use Direct IO, bypassing operating system buffers + bool use_direct_io = false; + //! Whether extensions should be loaded on start-up + bool load_extensions = true; +#ifdef DUCKDB_EXTENSION_AUTOLOAD_DEFAULT + //! Whether known extensions are allowed to be automatically loaded when a query depends on them + bool autoload_known_extensions = DUCKDB_EXTENSION_AUTOLOAD_DEFAULT; +#else + bool autoload_known_extensions = false; +#endif +#ifdef DUCKDB_EXTENSION_AUTOINSTALL_DEFAULT + //! Whether known extensions are allowed to be automatically installed when a query depends on them + bool autoinstall_known_extensions = DUCKDB_EXTENSION_AUTOINSTALL_DEFAULT; +#else + bool autoinstall_known_extensions = false; +#endif + //! The maximum memory used by the database system (in bytes). Default: 80% of System available memory + idx_t maximum_memory = (idx_t)-1; + //! The maximum amount of CPU threads used by the database system. Default: all available. + idx_t maximum_threads = (idx_t)-1; + //! The number of external threads that work on DuckDB tasks. Default: none. + idx_t external_threads = 0; + //! Whether or not to create and use a temporary directory to store intermediates that do not fit in memory + bool use_temporary_directory = true; + //! Directory to store temporary structures that do not fit in memory + string temporary_directory; + //! The collation type of the database + string collation = string(); + //! The order type used when none is specified (default: ASC) + OrderType default_order_type = OrderType::ASCENDING; + //! Null ordering used when none is specified (default: NULLS LAST) + DefaultOrderByNullType default_null_order = DefaultOrderByNullType::NULLS_LAST; + //! enable COPY and related commands + bool enable_external_access = true; + //! Whether or not object cache is used + bool object_cache_enable = false; + //! Whether or not the global http metadata cache is used + bool http_metadata_cache_enable = false; + //! Force checkpoint when CHECKPOINT is called or on shutdown, even if no changes have been made + bool force_checkpoint = false; + //! Run a checkpoint on successful shutdown and delete the WAL, to leave only a single database file behind + bool checkpoint_on_shutdown = true; + //! Debug flag that decides when a checkpoing should be aborted. Only used for testing purposes. + CheckpointAbort checkpoint_abort = CheckpointAbort::NO_ABORT; + //! Initialize the database with the standard set of DuckDB functions + //! You should probably not touch this unless you know what you are doing + bool initialize_default_database = true; + //! The set of disabled optimizers (default empty) + set disabled_optimizers; + //! Force a specific compression method to be used when checkpointing (if available) + CompressionType force_compression = CompressionType::COMPRESSION_AUTO; + //! Force a specific bitpacking mode to be used when using the bitpacking compression method + BitpackingMode force_bitpacking_mode = BitpackingMode::AUTO; + //! Debug setting for window aggregation mode: (window, combine, separate) + WindowAggregationMode window_mode = WindowAggregationMode::WINDOW; + //! Whether or not preserving insertion order should be preserved + bool preserve_insertion_order = true; + //! Whether Arrow Arrays use Large or Regular buffers + ArrowOffsetSize arrow_offset_size = ArrowOffsetSize::REGULAR; + //! Database configuration variables as controlled by SET + case_insensitive_map_t set_variables; + //! Database configuration variable default values; + case_insensitive_map_t set_variable_defaults; + //! Directory to store extension binaries in + string extension_directory; + //! Whether unsigned extensions should be loaded + bool allow_unsigned_extensions = false; + //! Enable emitting FSST Vectors + bool enable_fsst_vectors = false; + //! Start transactions immediately in all attached databases - instead of lazily when a database is referenced + bool immediate_transaction_mode = false; + //! Debug setting - how to initialize blocks in the storage layer when allocating + DebugInitialize debug_initialize = DebugInitialize::NO_INITIALIZE; + //! The set of unrecognized (other) options + unordered_map unrecognized_options; + //! Whether or not the configuration settings can be altered + bool lock_configuration = false; + //! Whether to print bindings when printing the plan (debug mode only) + static bool debug_print_bindings; + //! The peak allocation threshold at which to flush the allocator after completing a task (1 << 27, ~128MB) + idx_t allocator_flush_threshold = 134217728; + + bool operator==(const DBConfigOptions &other) const; +}; + +struct DBConfig { + friend class DatabaseInstance; + friend class StorageManager; + +public: + DUCKDB_API DBConfig(); + DUCKDB_API DBConfig(std::unordered_map &config_dict, bool read_only); + DUCKDB_API ~DBConfig(); + + mutex config_lock; + //! Replacement table scans are automatically attempted when a table name cannot be found in the schema + vector replacement_scans; + + //! Extra parameters that can be SET for loaded extensions + case_insensitive_map_t extension_parameters; + //! The FileSystem to use, can be overwritten to allow for injecting custom file systems for testing purposes (e.g. + //! RamFS or something similar) + unique_ptr file_system; + //! The allocator used by the system + unique_ptr allocator; + //! Database configuration options + DBConfigOptions options; + //! Extensions made to the parser + vector parser_extensions; + //! Extensions made to the optimizer + vector optimizer_extensions; + //! Error manager + unique_ptr error_manager; + //! A reference to the (shared) default allocator (Allocator::DefaultAllocator) + shared_ptr default_allocator; + //! Extensions made to binder + vector> operator_extensions; + //! Extensions made to storage + case_insensitive_map_t> storage_extensions; + //! A buffer pool can be shared across multiple databases (if desired). + shared_ptr buffer_pool; + //! Set of callbacks that can be installed by extensions + vector> extension_callbacks; + +public: + DUCKDB_API static DBConfig &GetConfig(ClientContext &context); + DUCKDB_API static DBConfig &GetConfig(DatabaseInstance &db); + DUCKDB_API static DBConfig &Get(AttachedDatabase &db); + DUCKDB_API static const DBConfig &GetConfig(const ClientContext &context); + DUCKDB_API static const DBConfig &GetConfig(const DatabaseInstance &db); + DUCKDB_API static vector GetOptions(); + DUCKDB_API static idx_t GetOptionCount(); + DUCKDB_API static vector GetOptionNames(); + + DUCKDB_API void AddExtensionOption(const string &name, string description, LogicalType parameter, + const Value &default_value = Value(), set_option_callback_t function = nullptr); + //! Fetch an option by index. Returns a pointer to the option, or nullptr if out of range + DUCKDB_API static ConfigurationOption *GetOptionByIndex(idx_t index); + //! Fetch an option by name. Returns a pointer to the option, or nullptr if none exists. + DUCKDB_API static ConfigurationOption *GetOptionByName(const string &name); + + DUCKDB_API void SetOption(const ConfigurationOption &option, const Value &value); + DUCKDB_API void SetOption(DatabaseInstance *db, const ConfigurationOption &option, const Value &value); + DUCKDB_API void SetOptionByName(const string &name, const Value &value); + DUCKDB_API void ResetOption(DatabaseInstance *db, const ConfigurationOption &option); + DUCKDB_API void SetOption(const string &name, Value value); + DUCKDB_API void ResetOption(const string &name); + + DUCKDB_API static idx_t ParseMemoryLimit(const string &arg); + + //! Return the list of possible compression functions for the specific physical type + DUCKDB_API vector> GetCompressionFunctions(PhysicalType data_type); + //! Return the compression function for the specified compression type/physical type combo + DUCKDB_API optional_ptr GetCompressionFunction(CompressionType type, PhysicalType data_type); + + bool operator==(const DBConfig &other); + bool operator!=(const DBConfig &other); + + DUCKDB_API CastFunctionSet &GetCastFunctions(); + static idx_t GetSystemMaxThreads(FileSystem &fs); + void SetDefaultMaxThreads(); + void SetDefaultMaxMemory(); + + OrderType ResolveOrder(OrderType order_type) const; + OrderByNullType ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const; + +private: + unique_ptr compression_functions; + unique_ptr cast_functions; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/connection.hpp b/src/duckdb/src/include/duckdb/main/connection.hpp new file mode 100644 index 00000000..29048a9e --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/connection.hpp @@ -0,0 +1,238 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/connection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/profiler_format.hpp" +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/function/udf_function.hpp" +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/main/pending_query_result.hpp" +#include "duckdb/main/prepared_statement.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/relation.hpp" +#include "duckdb/main/stream_query_result.hpp" +#include "duckdb/main/table_description.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class ColumnDataCollection; +class ClientContext; + +class DatabaseInstance; +class DuckDB; +class LogicalOperator; +class SelectStatement; +struct CSVReaderOptions; + +typedef void (*warning_callback)(std::string); + +//! A connection to a database. This represents a (client) connection that can +//! be used to query the database. +class Connection { +public: + DUCKDB_API explicit Connection(DuckDB &database); + DUCKDB_API explicit Connection(DatabaseInstance &database); + DUCKDB_API ~Connection(); + + shared_ptr context; + warning_callback warning_cb; + +public: + //! Returns query profiling information for the current query + DUCKDB_API string GetProfilingInformation(ProfilerPrintFormat format = ProfilerPrintFormat::QUERY_TREE); + + //! Interrupt execution of the current query + DUCKDB_API void Interrupt(); + + //! Enable query profiling + DUCKDB_API void EnableProfiling(); + //! Disable query profiling + DUCKDB_API void DisableProfiling(); + + DUCKDB_API void SetWarningCallback(warning_callback); + + //! Enable aggressive verification/testing of queries, should only be used in testing + DUCKDB_API void EnableQueryVerification(); + DUCKDB_API void DisableQueryVerification(); + //! Force parallel execution, even for smaller tables. Should only be used in testing. + DUCKDB_API void ForceParallelism(); + + //! Issues a query to the database and returns a QueryResult. This result can be either a StreamQueryResult or a + //! MaterializedQueryResult. The result can be stepped through with calls to Fetch(). Note that there can only be + //! one active StreamQueryResult per Connection object. Calling SendQuery() will invalidate any previously existing + //! StreamQueryResult. + DUCKDB_API unique_ptr SendQuery(const string &query); + //! Issues a query to the database and materializes the result (if necessary). Always returns a + //! MaterializedQueryResult. + DUCKDB_API unique_ptr Query(const string &query); + //! Issues a query to the database and materializes the result (if necessary). Always returns a + //! MaterializedQueryResult. + DUCKDB_API unique_ptr Query(unique_ptr statement); + // prepared statements + template + unique_ptr Query(const string &query, Args... args) { + vector values; + return QueryParamsRecursive(query, values, args...); + } + + //! Issues a query to the database and returns a Pending Query Result. Note that "query" may only contain + //! a single statement. + DUCKDB_API unique_ptr PendingQuery(const string &query, bool allow_stream_result = false); + //! Issues a query to the database and returns a Pending Query Result + DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, + bool allow_stream_result = false); + + //! Prepare the specified query, returning a prepared statement object + DUCKDB_API unique_ptr Prepare(const string &query); + //! Prepare the specified statement, returning a prepared statement object + DUCKDB_API unique_ptr Prepare(unique_ptr statement); + + //! Get the table info of a specific table (in the default schema), or nullptr if it cannot be found + DUCKDB_API unique_ptr TableInfo(const string &table_name); + //! Get the table info of a specific table, or nullptr if it cannot be found + DUCKDB_API unique_ptr TableInfo(const string &schema_name, const string &table_name); + + //! Extract a set of SQL statements from a specific query + DUCKDB_API vector> ExtractStatements(const string &query); + //! Extract the logical plan that corresponds to a query + DUCKDB_API unique_ptr ExtractPlan(const string &query); + + //! Appends a DataChunk to the specified table + DUCKDB_API void Append(TableDescription &description, DataChunk &chunk); + //! Appends a ColumnDataCollection to the specified table + DUCKDB_API void Append(TableDescription &description, ColumnDataCollection &collection); + + //! Returns a relation that produces a table from this connection + DUCKDB_API shared_ptr Table(const string &tname); + DUCKDB_API shared_ptr Table(const string &schema_name, const string &table_name); + //! Returns a relation that produces a view from this connection + DUCKDB_API shared_ptr View(const string &tname); + DUCKDB_API shared_ptr View(const string &schema_name, const string &table_name); + //! Returns a relation that calls a specified table function + DUCKDB_API shared_ptr TableFunction(const string &tname); + DUCKDB_API shared_ptr TableFunction(const string &tname, const vector &values, + const named_parameter_map_t &named_parameters); + DUCKDB_API shared_ptr TableFunction(const string &tname, const vector &values); + //! Returns a relation that produces values + DUCKDB_API shared_ptr Values(const vector> &values); + DUCKDB_API shared_ptr Values(const vector> &values, const vector &column_names, + const string &alias = "values"); + DUCKDB_API shared_ptr Values(const string &values); + DUCKDB_API shared_ptr Values(const string &values, const vector &column_names, + const string &alias = "values"); + + //! Reads CSV file + DUCKDB_API shared_ptr ReadCSV(const string &csv_file); + DUCKDB_API shared_ptr ReadCSV(const string &csv_file, named_parameter_map_t &&options); + DUCKDB_API shared_ptr ReadCSV(const string &csv_file, const vector &columns); + + //! Reads Parquet file + DUCKDB_API shared_ptr ReadParquet(const string &parquet_file, bool binary_as_string); + //! Returns a relation from a query + DUCKDB_API shared_ptr RelationFromQuery(const string &query, const string &alias = "queryrelation", + const string &error = "Expected a single SELECT statement"); + DUCKDB_API shared_ptr RelationFromQuery(unique_ptr select_stmt, + const string &alias = "queryrelation"); + + //! Returns a substrait BLOB from a valid query + DUCKDB_API string GetSubstrait(const string &query); + //! Returns a Query Result from a substrait blob + DUCKDB_API unique_ptr FromSubstrait(const string &proto); + //! Returns a substrait BLOB from a valid query + DUCKDB_API string GetSubstraitJSON(const string &query); + //! Returns a Query Result from a substrait JSON + DUCKDB_API unique_ptr FromSubstraitJSON(const string &json); + DUCKDB_API void BeginTransaction(); + DUCKDB_API void Commit(); + DUCKDB_API void Rollback(); + DUCKDB_API void SetAutoCommit(bool auto_commit); + DUCKDB_API bool IsAutoCommit(); + DUCKDB_API bool HasActiveTransaction(); + + //! Fetch a list of table names that are required for a given query + DUCKDB_API unordered_set GetTableNames(const string &query); + + template + void CreateScalarFunction(const string &name, TR (*udf_func)(Args...)) { + scalar_function_t function = UDFWrapper::CreateScalarFunction(name, udf_func); + UDFWrapper::RegisterFunction(name, function, *context); + } + + template + void CreateScalarFunction(const string &name, vector args, LogicalType ret_type, + TR (*udf_func)(Args...)) { + scalar_function_t function = UDFWrapper::CreateScalarFunction(name, args, ret_type, udf_func); + UDFWrapper::RegisterFunction(name, args, ret_type, function, *context); + } + + template + void CreateVectorizedFunction(const string &name, scalar_function_t udf_func, + LogicalType varargs = LogicalType::INVALID) { + UDFWrapper::RegisterFunction(name, udf_func, *context, std::move(varargs)); + } + + void CreateVectorizedFunction(const string &name, vector args, LogicalType ret_type, + scalar_function_t udf_func, LogicalType varargs = LogicalType::INVALID) { + UDFWrapper::RegisterFunction(name, std::move(args), std::move(ret_type), udf_func, *context, + std::move(varargs)); + } + + //------------------------------------- Aggreate Functions ----------------------------------------// + template + void CreateAggregateFunction(const string &name) { + AggregateFunction function = UDFWrapper::CreateAggregateFunction(name); + UDFWrapper::RegisterAggrFunction(function, *context); + } + + template + void CreateAggregateFunction(const string &name) { + AggregateFunction function = UDFWrapper::CreateAggregateFunction(name); + UDFWrapper::RegisterAggrFunction(function, *context); + } + + template + void CreateAggregateFunction(const string &name, LogicalType ret_type, LogicalType input_typeA) { + AggregateFunction function = + UDFWrapper::CreateAggregateFunction(name, ret_type, input_typeA); + UDFWrapper::RegisterAggrFunction(function, *context); + } + + template + void CreateAggregateFunction(const string &name, LogicalType ret_type, LogicalType input_typeA, + LogicalType input_typeB) { + AggregateFunction function = + UDFWrapper::CreateAggregateFunction(name, ret_type, input_typeA, input_typeB); + UDFWrapper::RegisterAggrFunction(function, *context); + } + + void CreateAggregateFunction(const string &name, vector arguments, LogicalType return_type, + aggregate_size_t state_size, aggregate_initialize_t initialize, + aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, + aggregate_simple_update_t simple_update = nullptr, + bind_aggregate_function_t bind = nullptr, + aggregate_destructor_t destructor = nullptr) { + AggregateFunction function = + UDFWrapper::CreateAggregateFunction(name, arguments, return_type, state_size, initialize, update, combine, + finalize, simple_update, bind, destructor); + UDFWrapper::RegisterAggrFunction(function, *context); + } + +private: + unique_ptr QueryParamsRecursive(const string &query, vector &values); + + template + unique_ptr QueryParamsRecursive(const string &query, vector &values, T value, Args... args) { + values.push_back(Value::CreateValue(value)); + return QueryParamsRecursive(query, values, args...); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/connection_manager.hpp b/src/duckdb/src/include/duckdb/main/connection_manager.hpp new file mode 100644 index 00000000..11495742 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/connection_manager.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/connection_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +class ClientContext; +class DatabaseInstance; + +class ConnectionManager { +public: + ConnectionManager() { + } + + void AddConnection(ClientContext &context) { + lock_guard lock(connections_lock); + connections.insert(make_pair(&context, weak_ptr(context.shared_from_this()))); + } + + void RemoveConnection(ClientContext &context) { + lock_guard lock(connections_lock); + connections.erase(&context); + } + + vector> GetConnectionList() { + vector> result; + for (auto &it : connections) { + auto connection = it.second.lock(); + if (!connection) { + connections.erase(it.first); + continue; + } else { + result.push_back(std::move(connection)); + } + } + + return result; + } + + ClientContext *GetConnection(DatabaseInstance *db); + + static ConnectionManager &Get(DatabaseInstance &db); + static ConnectionManager &Get(ClientContext &context); + +public: + mutex connections_lock; + unordered_map> connections; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/database.hpp b/src/duckdb/src/include/duckdb/main/database.hpp new file mode 100644 index 00000000..8a89b730 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/database.hpp @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/database.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/config.hpp" +#include "duckdb/main/valid_checker.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/main/extension.hpp" + +namespace duckdb { +class BufferManager; +class DatabaseManager; +class StorageManager; +class Catalog; +class TransactionManager; +class ConnectionManager; +class FileSystem; +class TaskScheduler; +class ObjectCache; +struct AttachInfo; + +class DatabaseInstance : public std::enable_shared_from_this { + friend class DuckDB; + +public: + DUCKDB_API DatabaseInstance(); + DUCKDB_API ~DatabaseInstance(); + + DBConfig config; + +public: + BufferPool &GetBufferPool(); + DUCKDB_API BufferManager &GetBufferManager(); + DUCKDB_API DatabaseManager &GetDatabaseManager(); + DUCKDB_API FileSystem &GetFileSystem(); + DUCKDB_API TaskScheduler &GetScheduler(); + DUCKDB_API ObjectCache &GetObjectCache(); + DUCKDB_API ConnectionManager &GetConnectionManager(); + DUCKDB_API ValidChecker &GetValidChecker(); + DUCKDB_API void SetExtensionLoaded(const std::string &extension_name); + + idx_t NumberOfThreads(); + + DUCKDB_API static DatabaseInstance &GetDatabase(ClientContext &context); + + DUCKDB_API const unordered_set &LoadedExtensions(); + DUCKDB_API bool ExtensionIsLoaded(const std::string &name); + + DUCKDB_API bool TryGetCurrentSetting(const std::string &key, Value &result); + + unique_ptr CreateAttachedDatabase(AttachInfo &info, const string &type, AccessMode access_mode); + +private: + void Initialize(const char *path, DBConfig *config); + void CreateMainDatabase(); + + void Configure(DBConfig &config); + +private: + unique_ptr buffer_manager; + unique_ptr db_manager; + unique_ptr scheduler; + unique_ptr object_cache; + unique_ptr connection_manager; + unordered_set loaded_extensions; + ValidChecker db_validity; +}; + +//! The database object. This object holds the catalog and all the +//! database-specific meta information. +class DuckDB { +public: + DUCKDB_API explicit DuckDB(const char *path = nullptr, DBConfig *config = nullptr); + DUCKDB_API explicit DuckDB(const string &path, DBConfig *config = nullptr); + DUCKDB_API explicit DuckDB(DatabaseInstance &instance); + + DUCKDB_API ~DuckDB(); + + //! Reference to the actual database instance + shared_ptr instance; + +public: + template + void LoadExtension() { + T extension; + if (ExtensionIsLoaded(extension.Name())) { + return; + } + extension.Load(*this); + instance->SetExtensionLoaded(extension.Name()); + } + + DUCKDB_API FileSystem &GetFileSystem(); + + DUCKDB_API idx_t NumberOfThreads(); + DUCKDB_API static const char *SourceID(); + DUCKDB_API static const char *LibraryVersion(); + DUCKDB_API static idx_t StandardVectorSize(); + DUCKDB_API static string Platform(); + DUCKDB_API bool ExtensionIsLoaded(const std::string &name); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/database_manager.hpp b/src/duckdb/src/include/duckdb/main/database_manager.hpp new file mode 100644 index 00000000..f5360a37 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/database_manager.hpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/database_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" + +namespace duckdb { +class AttachedDatabase; +class Catalog; +class CatalogSet; +class ClientContext; +class DatabaseInstance; + +//! The DatabaseManager is a class that sits at the root of all attached databases +class DatabaseManager { + friend class Catalog; + +public: + explicit DatabaseManager(DatabaseInstance &db); + ~DatabaseManager(); + +public: + static DatabaseManager &Get(DatabaseInstance &db); + static DatabaseManager &Get(ClientContext &db); + static DatabaseManager &Get(AttachedDatabase &db); + + void InitializeSystemCatalog(); + //! Get an attached database with the given name + optional_ptr GetDatabase(ClientContext &context, const string &name); + //! Add a new attached database to the database manager + void AddDatabase(ClientContext &context, unique_ptr db); + void DetachDatabase(ClientContext &context, const string &name, OnEntryNotFound if_not_found); + //! Returns a reference to the system catalog + Catalog &GetSystemCatalog(); + static const string &GetDefaultDatabase(ClientContext &context); + void SetDefaultDatabase(ClientContext &context, const string &new_value); + + optional_ptr GetDatabaseFromPath(ClientContext &context, const string &path); + vector> GetDatabases(ClientContext &context); + + transaction_t GetNewQueryNumber() { + return current_query_number++; + } + transaction_t ActiveQueryNumber() const { + return current_query_number; + } + idx_t ModifyCatalog() { + return catalog_version++; + } + bool HasDefaultDatabase() { + return !default_database.empty(); + } + +private: + //! The system database is a special database that holds system entries (e.g. functions) + unique_ptr system; + //! The set of attached databases + unique_ptr databases; + //! The global catalog version, incremented whenever anything changes in the catalog + atomic catalog_version; + //! The current query number + atomic current_query_number; + //! The current default database + string default_database; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/database_path_and_type.hpp b/src/duckdb/src/include/duckdb/main/database_path_and_type.hpp new file mode 100644 index 00000000..e1dc5d46 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/database_path_and_type.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/database_path_and_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include "duckdb/main/config.hpp" + +namespace duckdb { + +struct DBPathAndType { + + //! Parse database extension type and rest of path from combined form (type:path) + static DBPathAndType Parse(const string &combined_path, const DBConfig &config); + + const string path; + const string type; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/db_instance_cache.hpp b/src/duckdb/src/include/duckdb/main/db_instance_cache.hpp new file mode 100644 index 00000000..46cc3d98 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/db_instance_cache.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/db_instance_cache.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/connection_manager.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/function/replacement_scan.hpp" + +namespace duckdb { +class DBInstanceCache { +public: + DBInstanceCache() {}; + //! Gets a DB Instance from the cache if already exists (Fails if the configurations do not match) + shared_ptr GetInstance(const string &database, const DBConfig &config_dict); + + //! Creates and caches a new DB Instance (Fails if a cached instance already exists) + shared_ptr CreateInstance(const string &database, DBConfig &config_dict, bool cache_instance = true); + + //! Creates and caches a new DB Instance (Fails if a cached instance already exists) + shared_ptr GetOrCreateInstance(const string &database, DBConfig &config_dict, bool cache_instance); + +private: + //! A map with the cached instances + unordered_map> db_instances; + + //! Lock to alter cache + mutex cache_lock; + +private: + shared_ptr GetInstanceInternal(const string &database, const DBConfig &config_dict); + shared_ptr CreateInstanceInternal(const string &database, DBConfig &config_dict, bool cache_instance); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/error_manager.hpp b/src/duckdb/src/include/duckdb/main/error_manager.hpp new file mode 100644 index 00000000..2fd63b36 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/error_manager.hpp @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/error_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/map.hpp" + +namespace duckdb { +class ClientContext; +class DatabaseInstance; + +enum class ErrorType : uint16_t { + // error message types + UNSIGNED_EXTENSION = 0, + INVALIDATED_TRANSACTION = 1, + INVALIDATED_DATABASE = 2, + + // this should always be the last value + ERROR_COUNT, + INVALID = 65535, +}; + +//! The error manager class is responsible for formatting error messages +//! It allows for error messages to be overridden by extensions and clients +class ErrorManager { +public: + template + string FormatException(ErrorType error_type, Args... params) { + vector values; + return FormatExceptionRecursive(error_type, values, params...); + } + + DUCKDB_API string FormatExceptionRecursive(ErrorType error_type, vector &values); + + template + string FormatExceptionRecursive(ErrorType error_type, vector &values, T param, + Args... params) { + values.push_back(ExceptionFormatValue::CreateFormatValue(param)); + return FormatExceptionRecursive(error_type, values, params...); + } + + template + static string FormatException(ClientContext &context, ErrorType error_type, Args... params) { + return Get(context).FormatException(error_type, params...); + } + + DUCKDB_API static string InvalidUnicodeError(const string &input, const string &context); + + //! Adds a custom error for a specific error type + void AddCustomError(ErrorType type, string new_error); + + DUCKDB_API static ErrorManager &Get(ClientContext &context); + DUCKDB_API static ErrorManager &Get(DatabaseInstance &context); + +private: + map custom_errors; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension.hpp b/src/duckdb/src/include/duckdb/main/extension.hpp new file mode 100644 index 00000000..5c57996d --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/extension.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/extension.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/winapi.hpp" + +namespace duckdb { +class DuckDB; + +//! The Extension class is the base class used to define extensions +class Extension { +public: + DUCKDB_API virtual ~Extension(); + + DUCKDB_API virtual void Load(DuckDB &db) = 0; + DUCKDB_API virtual std::string Name() = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension/generated_extension_loader.hpp b/src/duckdb/src/include/duckdb/main/extension/generated_extension_loader.hpp new file mode 100644 index 00000000..2b4a6662 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/extension/generated_extension_loader.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/extension/generated_extension_loader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/database.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/vector.hpp" + +#if defined(GENERATED_EXTENSION_HEADERS) and !defined(DUCKDB_AMALGAMATION) +#include "generated_extension_headers.hpp" +#include "duckdb/common/common.hpp" + +namespace duckdb { + +//! Looks through the CMake-generated list of extensions that are linked into DuckDB currently to try load +bool TryLoadLinkedExtension(DuckDB &db, const string &extension); +extern vector linked_extensions; +extern vector loaded_extension_test_paths; + +} // namespace duckdb +#endif diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp new file mode 100644 index 00000000..97e8adc8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -0,0 +1,299 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/extension_entries.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" + +// NOTE: this file is generated by scripts/generate_extensions_function.py. Check out the check-load-install-extensions +// job in .github/workflows/LinuxRelease.yml on how to use it + +namespace duckdb { + +struct ExtensionEntry { + char name[48]; + char extension[48]; +}; + +static constexpr ExtensionEntry EXTENSION_FUNCTIONS[] = { + {"->>", "json"}, + {"array_to_json", "json"}, + {"create_fts_index", "fts"}, + {"current_localtime", "icu"}, + {"current_localtimestamp", "icu"}, + {"dbgen", "tpch"}, + {"drop_fts_index", "fts"}, + {"dsdgen", "tpcds"}, + {"excel_text", "excel"}, + {"from_json", "json"}, + {"from_json_strict", "json"}, + {"from_substrait", "substrait"}, + {"from_substrait_json", "substrait"}, + {"fuzz_all_functions", "sqlsmith"}, + {"fuzzyduck", "sqlsmith"}, + {"get_substrait", "substrait"}, + {"get_substrait_json", "substrait"}, + {"host", "inet"}, + {"iceberg_metadata", "iceberg"}, + {"iceberg_scan", "iceberg"}, + {"iceberg_snapshots", "iceberg"}, + {"icu_calendar_names", "icu"}, + {"icu_sort_key", "icu"}, + {"json", "json"}, + {"json_array", "json"}, + {"json_array_length", "json"}, + {"json_contains", "json"}, + {"json_deserialize_sql", "json"}, + {"json_execute_serialized_sql", "json"}, + {"json_extract", "json"}, + {"json_extract_path", "json"}, + {"json_extract_path_text", "json"}, + {"json_extract_string", "json"}, + {"json_group_array", "json"}, + {"json_group_object", "json"}, + {"json_group_structure", "json"}, + {"json_keys", "json"}, + {"json_merge_patch", "json"}, + {"json_object", "json"}, + {"json_quote", "json"}, + {"json_serialize_sql", "json"}, + {"json_structure", "json"}, + {"json_transform", "json"}, + {"json_transform_strict", "json"}, + {"json_type", "json"}, + {"json_valid", "json"}, + {"load_aws_credentials", "aws"}, + {"make_timestamptz", "icu"}, + {"parquet_metadata", "parquet"}, + {"parquet_scan", "parquet"}, + {"parquet_schema", "parquet"}, + {"pg_timezone_names", "icu"}, + {"postgres_attach", "postgres_scanner"}, + {"postgres_scan", "postgres_scanner"}, + {"postgres_scan_pushdown", "postgres_scanner"}, + {"read_json", "json"}, + {"read_json_auto", "json"}, + {"read_json_objects", "json"}, + {"read_json_objects_auto", "json"}, + {"read_ndjson", "json"}, + {"read_ndjson_auto", "json"}, + {"read_ndjson_objects", "json"}, + {"read_parquet", "parquet"}, + {"reduce_sql_statement", "sqlsmith"}, + {"row_to_json", "json"}, + {"scan_arrow_ipc", "arrow"}, + {"sql_auto_complete", "autocomplete"}, + {"sqlite_attach", "sqlite_scanner"}, + {"sqlite_scan", "sqlite_scanner"}, + {"sqlsmith", "sqlsmith"}, + {"st_area", "spatial"}, + {"st_area_spheroid", "spatial"}, + {"st_asgeojson", "spatial"}, + {"st_ashexwkb", "spatial"}, + {"st_astext", "spatial"}, + {"st_aswkb", "spatial"}, + {"st_boundary", "spatial"}, + {"st_buffer", "spatial"}, + {"st_centroid", "spatial"}, + {"st_collect", "spatial"}, + {"st_collectionextract", "spatial"}, + {"st_contains", "spatial"}, + {"st_containsproperly", "spatial"}, + {"st_convexhull", "spatial"}, + {"st_coveredby", "spatial"}, + {"st_covers", "spatial"}, + {"st_crosses", "spatial"}, + {"st_difference", "spatial"}, + {"st_dimension", "spatial"}, + {"st_disjoint", "spatial"}, + {"st_distance", "spatial"}, + {"st_distance_spheroid", "spatial"}, + {"st_drivers", "spatial"}, + {"st_dwithin", "spatial"}, + {"st_dwithin_spheroid", "spatial"}, + {"st_endpoint", "spatial"}, + {"st_envelope", "spatial"}, + {"st_envelope_agg", "spatial"}, + {"st_equals", "spatial"}, + {"st_extent", "spatial"}, + {"st_exteriorring", "spatial"}, + {"st_flipcoordinates", "spatial"}, + {"st_geometrytype", "spatial"}, + {"st_geomfromgeojson", "spatial"}, + {"st_geomfromhexewkb", "spatial"}, + {"st_geomfromhexwkb", "spatial"}, + {"st_geomfromtext", "spatial"}, + {"st_geomfromwkb", "spatial"}, + {"st_intersection", "spatial"}, + {"st_intersection_agg", "spatial"}, + {"st_intersects", "spatial"}, + {"st_intersects_extent", "spatial"}, + {"st_isclosed", "spatial"}, + {"st_isempty", "spatial"}, + {"st_isring", "spatial"}, + {"st_issimple", "spatial"}, + {"st_isvalid", "spatial"}, + {"st_length", "spatial"}, + {"st_length_spheroid", "spatial"}, + {"st_linestring2dfromwkb", "spatial"}, + {"st_list_proj_crs", "spatial"}, + {"st_makeline", "spatial"}, + {"st_ngeometries", "spatial"}, + {"st_ninteriorrings", "spatial"}, + {"st_normalize", "spatial"}, + {"st_npoints", "spatial"}, + {"st_numgeometries", "spatial"}, + {"st_numinteriorrings", "spatial"}, + {"st_numpoints", "spatial"}, + {"st_overlaps", "spatial"}, + {"st_perimeter", "spatial"}, + {"st_perimeter_spheroid", "spatial"}, + {"st_point", "spatial"}, + {"st_point2d", "spatial"}, + {"st_point2dfromwkb", "spatial"}, + {"st_point3d", "spatial"}, + {"st_point4d", "spatial"}, + {"st_pointn", "spatial"}, + {"st_pointonsurface", "spatial"}, + {"st_polygon2dfromwkb", "spatial"}, + {"st_reverse", "spatial"}, + {"st_read", "spatial"}, + {"st_readosm", "spatial"}, + {"st_reduceprecision", "spatial"}, + {"st_removerepeatedpoints", "spatial"}, + {"st_simplify", "spatial"}, + {"st_simplifypreservetopology", "spatial"}, + {"st_startpoint", "spatial"}, + {"st_touches", "spatial"}, + {"st_transform", "spatial"}, + {"st_union", "spatial"}, + {"st_union_agg", "spatial"}, + {"st_within", "spatial"}, + {"st_x", "spatial"}, + {"st_xmax", "spatial"}, + {"st_xmin", "spatial"}, + {"st_y", "spatial"}, + {"st_ymax", "spatial"}, + {"st_ymin", "spatial"}, + {"stem", "fts"}, + {"text", "excel"}, + {"to_arrow_ipc", "arrow"}, + {"to_json", "json"}, + {"tpcds", "tpcds"}, + {"tpcds_answers", "tpcds"}, + {"tpcds_queries", "tpcds"}, + {"tpch", "tpch"}, + {"tpch_answers", "tpch"}, + {"tpch_queries", "tpch"}, + {"visualize_diff_profiling_output", "visualizer"}, + {"visualize_json_profiling_output", "visualizer"}, + {"visualize_last_profiling_output", "visualizer"}, +}; // END_OF_EXTENSION_FUNCTIONS + +static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { + {"azure_storage_connection_string", "azure"}, + {"binary_as_string", "parquet"}, + {"calendar", "icu"}, + {"force_download", "httpfs"}, + {"http_retries", "httpfs"}, + {"http_retry_backoff", "httpfs"}, + {"http_retry_wait_ms", "httpfs"}, + {"http_timeout", "httpfs"}, + {"s3_access_key_id", "httpfs"}, + {"s3_endpoint", "httpfs"}, + {"s3_region", "httpfs"}, + {"s3_secret_access_key", "httpfs"}, + {"s3_session_token", "httpfs"}, + {"s3_uploader_max_filesize", "httpfs"}, + {"s3_uploader_max_parts_per_file", "httpfs"}, + {"s3_uploader_thread_limit", "httpfs"}, + {"s3_url_compatibility_mode", "httpfs"}, + {"s3_url_style", "httpfs"}, + {"s3_use_ssl", "httpfs"}, + {"sqlite_all_varchar", "sqlite_scanner"}, + {"timezone", "icu"}, +}; // END_OF_EXTENSION_SETTINGS + +// Note: these are currently hardcoded in scripts/generate_extensions_function.py +// TODO: automate by passing though to script via duckdb +static constexpr ExtensionEntry EXTENSION_COPY_FUNCTIONS[] = {{"parquet", "parquet"}, + {"json", "json"}}; // END_OF_EXTENSION_COPY_FUNCTIONS + +// Note: these are currently hardcoded in scripts/generate_extensions_function.py +// TODO: automate by passing though to script via duckdb +static constexpr ExtensionEntry EXTENSION_TYPES[] = { + {"json", "json"}, {"inet", "inet"}, {"geometry", "spatial"}}; // END_OF_EXTENSION_TYPES + +// Note: these are currently hardcoded in scripts/generate_extensions_function.py +// TODO: automate by passing though to script via duckdb +static constexpr ExtensionEntry EXTENSION_COLLATIONS[] = { + {"af", "icu"}, {"am", "icu"}, {"ar", "icu"}, {"ar_sa", "icu"}, {"as", "icu"}, {"az", "icu"}, + {"be", "icu"}, {"bg", "icu"}, {"bn", "icu"}, {"bo", "icu"}, {"br", "icu"}, {"bs", "icu"}, + {"ca", "icu"}, {"ceb", "icu"}, {"chr", "icu"}, {"cs", "icu"}, {"cy", "icu"}, {"da", "icu"}, + {"de", "icu"}, {"de_at", "icu"}, {"dsb", "icu"}, {"dz", "icu"}, {"ee", "icu"}, {"el", "icu"}, + {"en", "icu"}, {"en_us", "icu"}, {"eo", "icu"}, {"es", "icu"}, {"et", "icu"}, {"fa", "icu"}, + {"fa_af", "icu"}, {"ff", "icu"}, {"fi", "icu"}, {"fil", "icu"}, {"fo", "icu"}, {"fr", "icu"}, + {"fr_ca", "icu"}, {"fy", "icu"}, {"ga", "icu"}, {"gl", "icu"}, {"gu", "icu"}, {"ha", "icu"}, + {"haw", "icu"}, {"he", "icu"}, {"he_il", "icu"}, {"hi", "icu"}, {"hr", "icu"}, {"hsb", "icu"}, + {"hu", "icu"}, {"hy", "icu"}, {"id", "icu"}, {"id_id", "icu"}, {"ig", "icu"}, {"is", "icu"}, + {"it", "icu"}, {"ja", "icu"}, {"ka", "icu"}, {"kk", "icu"}, {"kl", "icu"}, {"km", "icu"}, + {"kn", "icu"}, {"ko", "icu"}, {"kok", "icu"}, {"ku", "icu"}, {"ky", "icu"}, {"lb", "icu"}, + {"lkt", "icu"}, {"ln", "icu"}, {"lo", "icu"}, {"lt", "icu"}, {"lv", "icu"}, {"mk", "icu"}, + {"ml", "icu"}, {"mn", "icu"}, {"mr", "icu"}, {"ms", "icu"}, {"mt", "icu"}, {"my", "icu"}, + {"nb", "icu"}, {"nb_no", "icu"}, {"ne", "icu"}, {"nl", "icu"}, {"nn", "icu"}, {"om", "icu"}, + {"or", "icu"}, {"pa", "icu"}, {"pa_in", "icu"}, {"pl", "icu"}, {"ps", "icu"}, {"pt", "icu"}, + {"ro", "icu"}, {"ru", "icu"}, {"sa", "icu"}, {"se", "icu"}, {"si", "icu"}, {"sk", "icu"}, + {"sl", "icu"}, {"smn", "icu"}, {"sq", "icu"}, {"sr", "icu"}, {"sr_ba", "icu"}, {"sr_me", "icu"}, + {"sr_rs", "icu"}, {"sv", "icu"}, {"sw", "icu"}, {"ta", "icu"}, {"te", "icu"}, {"th", "icu"}, + {"tk", "icu"}, {"to", "icu"}, {"tr", "icu"}, {"ug", "icu"}, {"uk", "icu"}, {"ur", "icu"}, + {"uz", "icu"}, {"vi", "icu"}, {"wae", "icu"}, {"wo", "icu"}, {"xh", "icu"}, {"yi", "icu"}, + {"yo", "icu"}, {"yue", "icu"}, {"yue_cn", "icu"}, {"zh", "icu"}, {"zh_cn", "icu"}, {"zh_hk", "icu"}, + {"zh_mo", "icu"}, {"zh_sg", "icu"}, {"zh_tw", "icu"}, {"zu", "icu"}}; // END_OF_EXTENSION_COLLATIONS + +// Note: these are currently hardcoded in scripts/generate_extensions_function.py +// TODO: automate by passing though to script via duckdb +static constexpr ExtensionEntry EXTENSION_FILE_PREFIXES[] = { + {"http://", "httpfs"}, {"https://", "httpfs"}, {"s3://", "httpfs"}, + // {"azure://", "azure"} +}; // END_OF_EXTENSION_FILE_PREFIXES + +// Note: these are currently hardcoded in scripts/generate_extensions_function.py +// TODO: automate by passing though to script via duckdb +static constexpr ExtensionEntry EXTENSION_FILE_POSTFIXES[] = { + {".parquet", "parquet"}, {".json", "json"}, {".jsonl", "json"}, {".ndjson", "json"}, + {".shp", "spatial"}, {".gpkg", "spatial"}, {".fgb", "spatial"}}; // END_OF_EXTENSION_FILE_POSTFIXES + +// Note: these are currently hardcoded in scripts/generate_extensions_function.py +// TODO: automate by passing though to script via duckdb +static constexpr ExtensionEntry EXTENSION_FILE_CONTAINS[] = {{".parquet?", "parquet"}, + {".json?", "json"}, + {".ndjson?", ".jsonl?"}, + {".jsonl?", ".ndjson?"}}; // EXTENSION_FILE_CONTAINS + +static constexpr const char *AUTOLOADABLE_EXTENSIONS[] = { + // "azure", + "arrow", + "aws", + "autocomplete", + "excel", + "fts", + "httpfs", + // "inet", + // "icu", + "json", + "parquet", + "postgres_scanner", + // "spatial", TODO: table function isnt always autoloaded so test fails + "sqlsmith", + "sqlite_scanner", + "tpcds", + "tpch", + "visualizer", +}; // END_OF_AUTOLOADABLE_EXTENSIONS + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension_helper.hpp b/src/duckdb/src/include/duckdb/main/extension_helper.hpp new file mode 100644 index 00000000..1d2c5463 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/extension_helper.hpp @@ -0,0 +1,123 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/extension_helper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include "duckdb.hpp" +#include "duckdb/main/extension_entries.hpp" + +namespace duckdb { +class DuckDB; + +enum class ExtensionLoadResult : uint8_t { LOADED_EXTENSION = 0, EXTENSION_UNKNOWN = 1, NOT_LOADED = 2 }; + +struct DefaultExtension { + const char *name; + const char *description; + bool statically_loaded; +}; + +struct ExtensionAlias { + const char *alias; + const char *extension; +}; + +struct ExtensionInitResult { + string filename; + string basename; + + void *lib_hdl; +}; + +class ExtensionHelper { +public: + static void LoadAllExtensions(DuckDB &db); + + static ExtensionLoadResult LoadExtension(DuckDB &db, const std::string &extension); + + static void InstallExtension(ClientContext &context, const string &extension, bool force_install, + const string &respository = ""); + static void InstallExtension(DBConfig &config, FileSystem &fs, const string &extension, bool force_install, + const string &respository = ""); + static void LoadExternalExtension(ClientContext &context, const string &extension); + static void LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension, + optional_ptr client_config); + + //! Autoload an extension by name. Depending on the current settings, this will either load or install+load + static void AutoLoadExtension(ClientContext &context, const string &extension_name); + DUCKDB_API static bool TryAutoLoadExtension(ClientContext &context, const string &extension_name) noexcept; + + static string ExtensionDirectory(ClientContext &context); + static string ExtensionDirectory(DBConfig &config, FileSystem &fs); + static string ExtensionUrlTemplate(optional_ptr config, const string &repository); + static string ExtensionFinalizeUrlTemplate(const string &url, const string &name); + + static idx_t DefaultExtensionCount(); + static DefaultExtension GetDefaultExtension(idx_t index); + + static idx_t ExtensionAliasCount(); + static ExtensionAlias GetExtensionAlias(idx_t index); + + static const vector GetPublicKeys(); + + // Returns extension name, or empty string if not a replacement open path + static string ExtractExtensionPrefixFromPath(const string &path); + + //! Apply any known extension aliases + static string ApplyExtensionAlias(string extension_name); + + static string GetExtensionName(const string &extension); + static bool IsFullPath(const string &extension); + + //! Lookup a name in an ExtensionEntry list + template + static string FindExtensionInEntries(const string &name, const ExtensionEntry (&entries)[N]) { + auto lcase = StringUtil::Lower(name); + + auto it = + std::find_if(entries, entries + N, [&](const ExtensionEntry &element) { return element.name == lcase; }); + + if (it != entries + N && it->name == lcase) { + return it->extension; + } + return ""; + } + + //! Whether an extension can be autoloaded (i.e. it's registered as an autoloadable extension in + //! extension_entries.hpp) + static bool CanAutoloadExtension(const string &ext_name); + + //! Utility functions for creating meaningful error messages regarding missing extensions + static string WrapAutoLoadExtensionErrorMsg(ClientContext &context, const string &base_error, + const string &extension_name); + static string AddExtensionInstallHintToErrorMsg(ClientContext &context, const string &base_error, + const string &extension_name); + +private: + static void InstallExtensionInternal(DBConfig &config, ClientConfig *client_config, FileSystem &fs, + const string &local_path, const string &extension, bool force_install, + const string &repository); + static const vector PathComponents(); + static bool AllowAutoInstall(const string &extension); + static ExtensionInitResult InitialLoad(DBConfig &config, FileSystem &fs, const string &extension, + optional_ptr client_config); + static bool TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension, ExtensionInitResult &result, + string &error, optional_ptr client_config); + //! For tagged releases we use the tag, else we use the git commit hash + static const string GetVersionDirectoryName(); + //! Version tags occur with and without 'v', tag in extension path is always with 'v' + static const string NormalizeVersionTag(const string &version_tag); + static bool IsRelease(const string &version_tag); + static bool CreateSuggestions(const string &extension_name, string &message); + +private: + static ExtensionLoadResult LoadExtensionInternal(DuckDB &db, const std::string &extension, bool initial_load); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension_util.hpp b/src/duckdb/src/include/duckdb/main/extension_util.hpp new file mode 100644 index 00000000..2ada8410 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/extension_util.hpp @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/extension_util.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { +struct CreateMacroInfo; +struct CreateCollationInfo; +class DatabaseInstance; + +//! The ExtensionUtil class contains methods that are useful for extensions +class ExtensionUtil { +public: + //! Register a new scalar function - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, ScalarFunction function); + //! Register a new scalar function set - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, ScalarFunctionSet function); + //! Register a new aggregate function - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, AggregateFunction function); + //! Register a new aggregate function set - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, AggregateFunctionSet function); + //! Register a new table function - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, TableFunction function); + //! Register a new table function set - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, TableFunctionSet function); + //! Register a new pragma function - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, PragmaFunction function); + //! Register a new pragma function set - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, PragmaFunctionSet function); + //! Register a new copy function - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, CopyFunction function); + //! Register a new macro function - throw an exception if the function already exists + DUCKDB_API static void RegisterFunction(DatabaseInstance &db, CreateMacroInfo &info); + + //! Register a new collation + DUCKDB_API static void RegisterCollation(DatabaseInstance &db, CreateCollationInfo &info); + + //! Returns a reference to the function in the catalog - throws an exception if it does not exist + DUCKDB_API static ScalarFunctionCatalogEntry &GetFunction(DatabaseInstance &db, const string &name); + DUCKDB_API static TableFunctionCatalogEntry &GetTableFunction(DatabaseInstance &db, const string &name); + + //! Add a function overload + DUCKDB_API static void AddFunctionOverload(DatabaseInstance &db, ScalarFunction function); + DUCKDB_API static void AddFunctionOverload(DatabaseInstance &db, ScalarFunctionSet function); + + DUCKDB_API static void AddFunctionOverload(DatabaseInstance &db, TableFunctionSet function); + + //! Registers a new type + DUCKDB_API static void RegisterType(DatabaseInstance &db, string type_name, LogicalType type); + + //! Registers a cast between two types + DUCKDB_API static void RegisterCastFunction(DatabaseInstance &db, const LogicalType &source, + const LogicalType &target, BoundCastInfo function, + int64_t implicit_cast_cost = -1); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/external_dependencies.hpp b/src/duckdb/src/include/duckdb/main/external_dependencies.hpp new file mode 100644 index 00000000..2632dd8a --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/external_dependencies.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/external_dependencies.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +enum ExternalDependenciesType { PYTHON_DEPENDENCY }; +class ExternalDependency { +public: + explicit ExternalDependency(ExternalDependenciesType type_p) : type(type_p) {}; + virtual ~ExternalDependency() {}; + ExternalDependenciesType type; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp b/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp new file mode 100644 index 00000000..334a7e99 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/materialized_query_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/main/query_result.hpp" + +namespace duckdb { + +class ClientContext; + +class MaterializedQueryResult : public QueryResult { +public: + static constexpr const QueryResultType TYPE = QueryResultType::MATERIALIZED_RESULT; + +public: + friend class ClientContext; + //! Creates a successful query result with the specified names and types + DUCKDB_API MaterializedQueryResult(StatementType statement_type, StatementProperties properties, + vector names, unique_ptr collection, + ClientProperties client_properties); + //! Creates an unsuccessful query result with error condition + DUCKDB_API explicit MaterializedQueryResult(PreservedError error); + +public: + //! Fetches a DataChunk from the query result. + //! This will consume the result (i.e. the result can only be scanned once with this function) + DUCKDB_API unique_ptr Fetch() override; + DUCKDB_API unique_ptr FetchRaw() override; + //! Converts the QueryResult to a string + DUCKDB_API string ToString() override; + DUCKDB_API string ToBox(ClientContext &context, const BoxRendererConfig &config) override; + + //! Gets the (index) value of the (column index) column. + //! Note: this is very slow. Scanning over the underlying collection is much faster. + DUCKDB_API Value GetValue(idx_t column, idx_t index); + + template + T GetValue(idx_t column, idx_t index) { + auto value = GetValue(column, index); + return (T)value.GetValue(); + } + + DUCKDB_API idx_t RowCount() const; + + //! Returns a reference to the underlying column data collection + ColumnDataCollection &Collection(); + +private: + unique_ptr collection; + //! Row collection, only created if GetValue is called + unique_ptr row_collection; + //! Scan state for Fetch calls + ColumnDataScanState scan_state; + bool scan_initialized; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/pending_query_result.hpp b/src/duckdb/src/include/duckdb/main/pending_query_result.hpp new file mode 100644 index 00000000..672388c4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/pending_query_result.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/pending_query_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/query_result.hpp" +#include "duckdb/common/enums/pending_execution_result.hpp" +#include "duckdb/execution/executor.hpp" + +namespace duckdb { +class ClientContext; +class ClientContextLock; +class PreparedStatementData; + +class PendingQueryResult : public BaseQueryResult { + friend class ClientContext; + +public: + static constexpr const QueryResultType TYPE = QueryResultType::PENDING_RESULT; + +public: + DUCKDB_API PendingQueryResult(shared_ptr context, PreparedStatementData &statement, + vector types, bool allow_stream_result); + DUCKDB_API explicit PendingQueryResult(PreservedError error_message); + DUCKDB_API ~PendingQueryResult(); + +public: + //! Executes a single task within the query, returning whether or not the query is ready. + //! If this returns RESULT_READY, the Execute function can be called to obtain a pointer to the result. + //! If this returns RESULT_NOT_READY, the ExecuteTask function should be called again. + //! If this returns EXECUTION_ERROR, an error occurred during execution. + //! If this returns NO_TASKS_AVAILABLE, this means currently no meaningful work can be done by the current executor, + //! but tasks may become available in the future. + //! The error message can be obtained by calling GetError() on the PendingQueryResult. + DUCKDB_API PendingExecutionResult ExecuteTask(); + + //! Returns the result of the query as an actual query result. + //! This returns (mostly) instantly if ExecuteTask has been called until RESULT_READY was returned. + DUCKDB_API unique_ptr Execute(); + + DUCKDB_API void Close(); + + //! Function to determine whether execution is considered finished + DUCKDB_API static bool IsFinished(PendingExecutionResult result); + +private: + shared_ptr context; + bool allow_stream_result; + +private: + void CheckExecutableInternal(ClientContextLock &lock); + + PendingExecutionResult ExecuteTaskInternal(ClientContextLock &lock); + unique_ptr ExecuteInternal(ClientContextLock &lock); + unique_ptr LockContext(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement.hpp new file mode 100644 index 00000000..8c5d29c7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/prepared_statement.hpp @@ -0,0 +1,178 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/prepared_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/winapi.hpp" +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/main/pending_query_result.hpp" +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { +class ClientContext; +class PreparedStatementData; + +//! A prepared statement +class PreparedStatement { +public: + //! Create a successfully prepared prepared statement object with the given name + DUCKDB_API PreparedStatement(shared_ptr context, shared_ptr data, + string query, idx_t n_param, case_insensitive_map_t named_param_map); + //! Create a prepared statement that was not successfully prepared + DUCKDB_API explicit PreparedStatement(PreservedError error); + + DUCKDB_API ~PreparedStatement(); + +public: + //! The client context this prepared statement belongs to + shared_ptr context; + //! The prepared statement data + shared_ptr data; + //! The query that is being prepared + string query; + //! Whether or not the statement was successfully prepared + bool success; + //! The error message (if success = false) + PreservedError error; + //! The amount of bound parameters + idx_t n_param; + //! The (optional) named parameters + case_insensitive_map_t named_param_map; + +public: + //! Returns the stored error message + DUCKDB_API const string &GetError(); + //! Returns the stored error object + DUCKDB_API PreservedError &GetErrorObject(); + //! Returns whether or not an error occurred + DUCKDB_API bool HasError() const; + //! Returns the number of columns in the result + DUCKDB_API idx_t ColumnCount(); + //! Returns the statement type of the underlying prepared statement object + DUCKDB_API StatementType GetStatementType(); + //! Returns the underlying statement properties + DUCKDB_API StatementProperties GetStatementProperties(); + //! Returns the result SQL types of the prepared statement + DUCKDB_API const vector &GetTypes(); + //! Returns the result names of the prepared statement + DUCKDB_API const vector &GetNames(); + //! Returns the map of parameter index to the expected type of parameter + DUCKDB_API case_insensitive_map_t GetExpectedParameterTypes() const; + + //! Create a pending query result of the prepared statement with the given set of arguments + template + unique_ptr PendingQuery(Args... args) { + vector values; + return PendingQueryRecursive(values, args...); + } + + //! Create a pending query result of the prepared statement with the given set of arguments + DUCKDB_API unique_ptr PendingQuery(vector &values, bool allow_stream_result = true); + + //! Create a pending query result of the prepared statement with the given set named arguments + DUCKDB_API unique_ptr PendingQuery(case_insensitive_map_t &named_values, + bool allow_stream_result = true); + + //! Execute the prepared statement with the given set of values + DUCKDB_API unique_ptr Execute(vector &values, bool allow_stream_result = true); + + //! Execute the prepared statement with the given set of named+unnamed values + DUCKDB_API unique_ptr Execute(case_insensitive_map_t &named_values, + bool allow_stream_result = true); + + //! Execute the prepared statement with the given set of arguments + template + unique_ptr Execute(Args... args) { + vector values; + return ExecuteRecursive(values, args...); + } + + template + static string ExcessValuesException(const case_insensitive_map_t ¶meters, + case_insensitive_map_t &values) { + // Too many values + set excess_set; + for (auto &pair : values) { + auto &name = pair.first; + if (!parameters.count(name)) { + excess_set.insert(name); + } + } + vector excess_values; + for (auto &val : excess_set) { + excess_values.push_back(val); + } + return StringUtil::Format("Parameter argument/count mismatch, identifiers of the excess parameters: %s", + StringUtil::Join(excess_values, ", ")); + } + + template + static string MissingValuesException(const case_insensitive_map_t ¶meters, + case_insensitive_map_t &values) { + // Missing values + set missing_set; + for (auto &pair : parameters) { + auto &name = pair.first; + if (!values.count(name)) { + missing_set.insert(name); + } + } + vector missing_values; + for (auto &val : missing_set) { + missing_values.push_back(val); + } + return StringUtil::Format("Values were not provided for the following prepared statement parameters: %s", + StringUtil::Join(missing_values, ", ")); + } + + template + static void VerifyParameters(case_insensitive_map_t &provided, + const case_insensitive_map_t &expected) { + if (expected.size() == provided.size()) { + // Same amount of identifiers, if + for (auto &pair : expected) { + auto &identifier = pair.first; + if (!provided.count(identifier)) { + throw InvalidInputException(MissingValuesException(expected, provided)); + } + } + return; + } + // Mismatch in expected and provided parameters/values + if (expected.size() > provided.size()) { + throw InvalidInputException(MissingValuesException(expected, provided)); + } else { + D_ASSERT(provided.size() > expected.size()); + throw InvalidInputException(ExcessValuesException(expected, provided)); + } + } + +private: + unique_ptr PendingQueryRecursive(vector &values) { + return PendingQuery(values); + } + + template + unique_ptr PendingQueryRecursive(vector &values, T value, Args... args) { + values.push_back(Value::CreateValue(value)); + return PendingQueryRecursive(values, args...); + } + + unique_ptr ExecuteRecursive(vector &values) { + return Execute(values); + } + + template + unique_ptr ExecuteRecursive(vector &values, T value, Args... args) { + values.push_back(Value::CreateValue(value)); + return ExecuteRecursive(values, args...); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp new file mode 100644 index 00000000..ce0b7d1a --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/prepared_statement_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/planner/expression/bound_parameter_data.hpp" +#include "duckdb/planner/bound_parameter_map.hpp" + +namespace duckdb { +class CatalogEntry; +class ClientContext; +class PhysicalOperator; +class SQLStatement; + +class PreparedStatementData { +public: + DUCKDB_API explicit PreparedStatementData(StatementType type); + DUCKDB_API ~PreparedStatementData(); + + StatementType statement_type; + //! The unbound SQL statement that was prepared + unique_ptr unbound_statement; + //! The fully prepared physical plan of the prepared statement + unique_ptr plan; + + //! The result names of the transaction + vector names; + //! The result types of the transaction + vector types; + + //! The statement properties + StatementProperties properties; + + //! The catalog version of when the prepared statement was bound + //! If this version is lower than the current catalog version, we have to rebind the prepared statement + idx_t catalog_version; + //! The map of parameter index to the actual value entry + bound_parameter_map_t value_map; + +public: + void CheckParameterCount(idx_t parameter_count); + //! Whether or not the prepared statement data requires the query to rebound for the given parameters + bool RequireRebind(ClientContext &context, optional_ptr> values); + //! Bind a set of values to the prepared statement data + DUCKDB_API void Bind(case_insensitive_map_t values); + //! Get the expected SQL Type of the bound parameter + DUCKDB_API LogicalType GetType(const string &identifier); + //! Try to get the expected SQL Type of the bound parameter + DUCKDB_API bool TryGetType(const string &identifier, LogicalType &result); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_profiler.hpp b/src/duckdb/src/include/duckdb/main/query_profiler.hpp new file mode 100644 index 00000000..2ea69dc6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/query_profiler.hpp @@ -0,0 +1,265 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/query_profiler.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/profiler_format.hpp" +#include "duckdb/common/profiler.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/common/reference_map.hpp" +#include +#include "duckdb/common/pair.hpp" +#include "duckdb/common/deque.hpp" + +namespace duckdb { +class ClientContext; +class ExpressionExecutor; +class PhysicalOperator; +class SQLStatement; + +//! The ExpressionInfo keeps information related to an expression +struct ExpressionInfo { + explicit ExpressionInfo() : hasfunction(false) { + } + // A vector of children + vector> children; + // Extract ExpressionInformation from a given expression state + void ExtractExpressionsRecursive(unique_ptr &state); + + //! Whether or not expression has function + bool hasfunction; + //! The function Name + string function_name; + //! The function time + uint64_t function_time = 0; + //! Count the number of ALL tuples + uint64_t tuples_count = 0; + //! Count the number of tuples sampled + uint64_t sample_tuples_count = 0; +}; + +//! The ExpressionRootInfo keeps information related to the root of an expression tree +struct ExpressionRootInfo { + ExpressionRootInfo(ExpressionExecutorState &executor, string name); + + //! Count the number of time the executor called + uint64_t total_count = 0; + //! Count the number of time the executor called since last sampling + uint64_t current_count = 0; + //! Count the number of samples + uint64_t sample_count = 0; + //! Count the number of tuples in all samples + uint64_t sample_tuples_count = 0; + //! Count the number of tuples processed by this executor + uint64_t tuples_count = 0; + //! A vector which contain the pointer to root of each expression tree + unique_ptr root; + //! Name + string name; + //! Elapsed time + double time; + //! Extra Info + string extra_info; +}; + +struct ExpressionExecutorInfo { + explicit ExpressionExecutorInfo() {}; + explicit ExpressionExecutorInfo(ExpressionExecutor &executor, const string &name, int id); + + //! A vector which contain the pointer to all ExpressionRootInfo + vector> roots; + //! Id, it will be used as index for executors_info vector + int id; +}; + +struct OperatorInformation { + explicit OperatorInformation(double time_ = 0, idx_t elements_ = 0) : time(time_), elements(elements_) { + } + + double time = 0; + idx_t elements = 0; + string name; + //! A vector of Expression Executor Info + vector> executors_info; +}; + +//! The OperatorProfiler measures timings of individual operators +class OperatorProfiler { + friend class QueryProfiler; + +public: + DUCKDB_API explicit OperatorProfiler(bool enabled); + + DUCKDB_API void StartOperator(optional_ptr phys_op); + DUCKDB_API void EndOperator(optional_ptr chunk); + DUCKDB_API void Flush(const PhysicalOperator &phys_op, ExpressionExecutor &expression_executor, const string &name, + int id); + + ~OperatorProfiler() { + } + +private: + void AddTiming(const PhysicalOperator &op, double time, idx_t elements); + + //! Whether or not the profiler is enabled + bool enabled; + //! The timer used to time the execution time of the individual Physical Operators + Profiler op; + //! The stack of Physical Operators that are currently active + optional_ptr active_operator; + //! A mapping of physical operators to recorded timings + reference_map_t timings; +}; + +//! The QueryProfiler can be used to measure timings of queries +class QueryProfiler { +public: + DUCKDB_API QueryProfiler(ClientContext &context); + +public: + struct TreeNode { + PhysicalOperatorType type; + string name; + string extra_info; + OperatorInformation info; + vector> children; + idx_t depth = 0; + }; + + // Propagate save_location, enabled, detailed_enabled and automatic_print_format. + void Propagate(QueryProfiler &qp); + + using TreeMap = reference_map_t>; + +private: + unique_ptr CreateTree(const PhysicalOperator &root, idx_t depth = 0); + void Render(const TreeNode &node, std::ostream &str) const; + +public: + DUCKDB_API bool IsEnabled() const; + DUCKDB_API bool IsDetailedEnabled() const; + DUCKDB_API ProfilerPrintFormat GetPrintFormat() const; + DUCKDB_API bool PrintOptimizerOutput() const; + DUCKDB_API string GetSaveLocation() const; + + DUCKDB_API static QueryProfiler &Get(ClientContext &context); + + DUCKDB_API void StartQuery(string query, bool is_explain_analyze = false, bool start_at_optimizer = false); + DUCKDB_API void EndQuery(); + + DUCKDB_API void StartExplainAnalyze(); + + //! Adds the timings gathered by an OperatorProfiler to this query profiler + DUCKDB_API void Flush(OperatorProfiler &profiler); + + DUCKDB_API void StartPhase(string phase); + DUCKDB_API void EndPhase(); + + DUCKDB_API void Initialize(const PhysicalOperator &root); + + DUCKDB_API string QueryTreeToString() const; + DUCKDB_API void QueryTreeToStream(std::ostream &str) const; + DUCKDB_API void Print(); + + //! return the printed as a string. Unlike ToString, which is always formatted as a string, + //! the return value is formatted based on the current print format (see GetPrintFormat()). + DUCKDB_API string ToString() const; + + DUCKDB_API string ToJSON() const; + DUCKDB_API void WriteToFile(const char *path, string &info) const; + + idx_t OperatorSize() { + return tree_map.size(); + } + + void Finalize(TreeNode &node); + +private: + ClientContext &context; + + //! Whether or not the query profiler is running + bool running; + //! The lock used for flushing information from a thread into the global query profiler + mutex flush_lock; + + //! Whether or not the query requires profiling + bool query_requires_profiling; + + //! The root of the query tree + unique_ptr root; + //! The query string + string query; + //! The timer used to time the execution time of the entire query + Profiler main_query; + //! A map of a Physical Operator pointer to a tree node + TreeMap tree_map; + //! Whether or not we are running as part of a explain_analyze query + bool is_explain_analyze; + +public: + const TreeMap &GetTreeMap() const { + return tree_map; + } + +private: + //! The timer used to time the individual phases of the planning process + Profiler phase_profiler; + //! A mapping of the phase names to the timings + using PhaseTimingStorage = unordered_map; + PhaseTimingStorage phase_timings; + using PhaseTimingItem = PhaseTimingStorage::value_type; + //! The stack of currently active phases + vector phase_stack; + +private: + vector GetOrderedPhaseTimings() const; + + //! Check whether or not an operator type requires query profiling. If none of the ops in a query require profiling + //! no profiling information is output. + bool OperatorRequiresProfiling(PhysicalOperatorType op_type); +}; + +//! The QueryProfilerHistory can be used to access the profiler of previous queries +class QueryProfilerHistory { +private: + static constexpr uint64_t DEFAULT_SIZE = 20; + + //! Previous Query profilers + deque>> prev_profilers; + //! Previous Query profilers size + uint64_t prev_profilers_size = DEFAULT_SIZE; + +public: + deque>> &GetPrevProfilers() { + return prev_profilers; + } + QueryProfilerHistory() { + } + + void SetPrevProfilersSize(uint64_t prevProfilersSize) { + prev_profilers_size = prevProfilersSize; + } + uint64_t GetPrevProfilersSize() const { + return prev_profilers_size; + } + +public: + void SetProfilerHistorySize(uint64_t size) { + this->prev_profilers_size = size; + } + void ResetProfilerHistorySize() { + this->prev_profilers_size = DEFAULT_SIZE; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_result.hpp b/src/duckdb/src/include/duckdb/main/query_result.hpp new file mode 100644 index 00000000..3c5088b2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/query_result.hpp @@ -0,0 +1,206 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/query_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/main/client_properties.hpp" + +namespace duckdb { +struct BoxRendererConfig; + +enum class QueryResultType : uint8_t { MATERIALIZED_RESULT, STREAM_RESULT, PENDING_RESULT }; + +class BaseQueryResult { +public: + //! Creates a successful query result with the specified names and types + DUCKDB_API BaseQueryResult(QueryResultType type, StatementType statement_type, StatementProperties properties, + vector types, vector names); + //! Creates an unsuccessful query result with error condition + DUCKDB_API BaseQueryResult(QueryResultType type, PreservedError error); + DUCKDB_API virtual ~BaseQueryResult(); + + //! The type of the result (MATERIALIZED or STREAMING) + QueryResultType type; + //! The type of the statement that created this result + StatementType statement_type; + //! Properties of the statement + StatementProperties properties; + //! The SQL types of the result + vector types; + //! The names of the result + vector names; + +public: + [[noreturn]] DUCKDB_API void ThrowError(const string &prepended_message = "") const; + DUCKDB_API void SetError(PreservedError error); + DUCKDB_API bool HasError() const; + DUCKDB_API const ExceptionType &GetErrorType() const; + DUCKDB_API const std::string &GetError(); + DUCKDB_API PreservedError &GetErrorObject(); + DUCKDB_API idx_t ColumnCount(); + +protected: + //! Whether or not execution was successful + bool success; + //! The error (in case execution was not successful) + PreservedError error; +}; + +//! The QueryResult object holds the result of a query. It can either be a MaterializedQueryResult, in which case the +//! result contains the entire result set, or a StreamQueryResult in which case the Fetch method can be called to +//! incrementally fetch data from the database. +class QueryResult : public BaseQueryResult { +public: + //! Creates a successful query result with the specified names and types + DUCKDB_API QueryResult(QueryResultType type, StatementType statement_type, StatementProperties properties, + vector types, vector names, ClientProperties client_properties); + //! Creates an unsuccessful query result with error condition + DUCKDB_API QueryResult(QueryResultType type, PreservedError error); + DUCKDB_API virtual ~QueryResult() override; + + //! Properties from the client context + ClientProperties client_properties; + //! The next result (if any) + unique_ptr next; + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast query result to type - query result type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast query result to type - query result type mismatch"); + } + return reinterpret_cast(*this); + } + +public: + //! Returns the name of the column for the given index + DUCKDB_API const string &ColumnName(idx_t index) const; + //! Fetches a DataChunk of normalized (flat) vectors from the query result. + //! Returns nullptr if there are no more results to fetch. + DUCKDB_API virtual unique_ptr Fetch(); + //! Fetches a DataChunk from the query result. The vectors are not normalized and hence any vector types can be + //! returned. + DUCKDB_API virtual unique_ptr FetchRaw() = 0; + //! Converts the QueryResult to a string + DUCKDB_API virtual string ToString() = 0; + //! Converts the QueryResult to a box-rendered string + DUCKDB_API virtual string ToBox(ClientContext &context, const BoxRendererConfig &config); + //! Prints the QueryResult to the console + DUCKDB_API void Print(); + //! Returns true if the two results are identical; false otherwise. Note that this method is destructive; it calls + //! Fetch() until both results are exhausted. The data in the results will be lost. + DUCKDB_API bool Equals(QueryResult &other); + + bool TryFetch(unique_ptr &result, PreservedError &error) { + try { + result = Fetch(); + return success; + } catch (const Exception &ex) { + error = PreservedError(ex); + return false; + } catch (std::exception &ex) { + error = PreservedError(ex); + return false; + } catch (...) { + error = PreservedError("Unknown error in Fetch"); + return false; + } + } + +private: + class QueryResultIterator; + class QueryResultRow { + public: + explicit QueryResultRow(QueryResultIterator &iterator_p, idx_t row_idx) : iterator(iterator_p), row(0) { + } + + QueryResultIterator &iterator; + idx_t row; + + template + T GetValue(idx_t col_idx) const { + return iterator.chunk->GetValue(col_idx, row).GetValue(); + } + }; + //! The row-based query result iterator. Invoking the + class QueryResultIterator { + public: + explicit QueryResultIterator(optional_ptr result_p) + : current_row(*this, 0), result(result_p), base_row(0) { + if (result) { + chunk = shared_ptr(result->Fetch().release()); + if (!chunk) { + result = nullptr; + } + } + } + + QueryResultRow current_row; + shared_ptr chunk; + optional_ptr result; + idx_t base_row; + + public: + void Next() { + if (!chunk) { + return; + } + current_row.row++; + if (current_row.row >= chunk->size()) { + base_row += chunk->size(); + chunk = shared_ptr(result->Fetch().release()); + current_row.row = 0; + if (!chunk || chunk->size() == 0) { + // exhausted all rows + base_row = 0; + result = nullptr; + chunk.reset(); + } + } + } + + QueryResultIterator &operator++() { + Next(); + return *this; + } + bool operator!=(const QueryResultIterator &other) const { + return result != other.result || base_row != other.base_row || current_row.row != other.current_row.row; + } + const QueryResultRow &operator*() const { + return current_row; + } + }; + +public: + QueryResultIterator begin() { + return QueryResultIterator(this); + } + QueryResultIterator end() { + return QueryResultIterator(nullptr); + } + +protected: + DUCKDB_API string HeaderToString(); + +private: + QueryResult(const QueryResult &) = delete; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation.hpp b/src/duckdb/src/include/duckdb/main/relation.hpp new file mode 100644 index 00000000..89a66d02 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation.hpp @@ -0,0 +1,193 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enums/relation_type.hpp" +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/enums/joinref_type.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/common/named_parameter_map.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/external_dependencies.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { +struct BoundStatement; + +class ClientContextWrapper; +class Binder; +class LogicalOperator; +class QueryNode; +class TableRef; + +class Relation : public std::enable_shared_from_this { +public: + Relation(const std::shared_ptr &context, RelationType type) : context(context), type(type) { + } + Relation(ClientContextWrapper &context, RelationType type) : context(context.GetContext()), type(type) { + } + virtual ~Relation() { + } + + ClientContextWrapper context; + + RelationType type; + + shared_ptr extra_dependencies; + +public: + DUCKDB_API virtual const vector &Columns() = 0; + DUCKDB_API virtual unique_ptr GetQueryNode(); + DUCKDB_API virtual BoundStatement Bind(Binder &binder); + DUCKDB_API virtual string GetAlias(); + + DUCKDB_API unique_ptr ExecuteOrThrow(); + DUCKDB_API unique_ptr Execute(); + DUCKDB_API string ToString(); + DUCKDB_API virtual string ToString(idx_t depth) = 0; + + DUCKDB_API void Print(); + DUCKDB_API void Head(idx_t limit = 10); + + DUCKDB_API shared_ptr CreateView(const string &name, bool replace = true, bool temporary = false); + DUCKDB_API shared_ptr CreateView(const string &schema_name, const string &name, bool replace = true, + bool temporary = false); + DUCKDB_API unique_ptr Query(const string &sql); + DUCKDB_API unique_ptr Query(const string &name, const string &sql); + + //! Explain the query plan of this relation + DUCKDB_API unique_ptr Explain(ExplainType type = ExplainType::EXPLAIN_STANDARD); + + DUCKDB_API virtual unique_ptr GetTableRef(); + virtual bool IsReadOnly() { + return true; + } + +public: + // PROJECT + DUCKDB_API shared_ptr Project(const string &select_list); + DUCKDB_API shared_ptr Project(const string &expression, const string &alias); + DUCKDB_API shared_ptr Project(const string &select_list, const vector &aliases); + DUCKDB_API shared_ptr Project(const vector &expressions); + DUCKDB_API shared_ptr Project(const vector &expressions, const vector &aliases); + DUCKDB_API shared_ptr Project(vector> expressions, + const vector &aliases); + + // FILTER + DUCKDB_API shared_ptr Filter(const string &expression); + DUCKDB_API shared_ptr Filter(unique_ptr expression); + DUCKDB_API shared_ptr Filter(const vector &expressions); + + // LIMIT + DUCKDB_API shared_ptr Limit(int64_t n, int64_t offset = 0); + + // ORDER + DUCKDB_API shared_ptr Order(const string &expression); + DUCKDB_API shared_ptr Order(const vector &expressions); + DUCKDB_API shared_ptr Order(vector expressions); + + // JOIN operation + DUCKDB_API shared_ptr Join(const shared_ptr &other, const string &condition, + JoinType type = JoinType::INNER, JoinRefType ref_type = JoinRefType::REGULAR); + shared_ptr Join(const shared_ptr &other, vector> condition, + JoinType type = JoinType::INNER, JoinRefType ref_type = JoinRefType::REGULAR); + + // CROSS PRODUCT operation + DUCKDB_API shared_ptr CrossProduct(const shared_ptr &other, + JoinRefType join_ref_type = JoinRefType::CROSS); + + // SET operations + DUCKDB_API shared_ptr Union(const shared_ptr &other); + DUCKDB_API shared_ptr Except(const shared_ptr &other); + DUCKDB_API shared_ptr Intersect(const shared_ptr &other); + + // DISTINCT operation + DUCKDB_API shared_ptr Distinct(); + + // AGGREGATES + DUCKDB_API shared_ptr Aggregate(const string &aggregate_list); + DUCKDB_API shared_ptr Aggregate(const vector &aggregates); + DUCKDB_API shared_ptr Aggregate(const string &aggregate_list, const string &group_list); + DUCKDB_API shared_ptr Aggregate(const vector &aggregates, const vector &groups); + DUCKDB_API shared_ptr Aggregate(vector> expressions, + const string &group_list); + + // ALIAS + DUCKDB_API shared_ptr Alias(const string &alias); + + //! Insert the data from this relation into a table + DUCKDB_API shared_ptr InsertRel(const string &schema_name, const string &table_name); + DUCKDB_API void Insert(const string &table_name); + DUCKDB_API void Insert(const string &schema_name, const string &table_name); + //! Insert a row (i.e.,list of values) into a table + DUCKDB_API void Insert(const vector> &values); + //! Create a table and insert the data from this relation into that table + DUCKDB_API shared_ptr CreateRel(const string &schema_name, const string &table_name); + DUCKDB_API void Create(const string &table_name); + DUCKDB_API void Create(const string &schema_name, const string &table_name); + + //! Write a relation to a CSV file + DUCKDB_API shared_ptr + WriteCSVRel(const string &csv_file, + case_insensitive_map_t> options = case_insensitive_map_t>()); + DUCKDB_API void WriteCSV(const string &csv_file, + case_insensitive_map_t> options = case_insensitive_map_t>()); + //! Write a relation to a Parquet file + DUCKDB_API shared_ptr + WriteParquetRel(const string &parquet_file, + case_insensitive_map_t> options = case_insensitive_map_t>()); + DUCKDB_API void + WriteParquet(const string &parquet_file, + case_insensitive_map_t> options = case_insensitive_map_t>()); + + //! Update a table, can only be used on a TableRelation + DUCKDB_API virtual void Update(const string &update, const string &condition = string()); + //! Delete from a table, can only be used on a TableRelation + DUCKDB_API virtual void Delete(const string &condition = string()); + //! Create a relation from calling a table in/out function on the input relation + //! Create a relation from calling a table in/out function on the input relation + DUCKDB_API shared_ptr TableFunction(const std::string &fname, const vector &values); + DUCKDB_API shared_ptr TableFunction(const std::string &fname, const vector &values, + const named_parameter_map_t &named_parameters); + +public: + //! Whether or not the relation inherits column bindings from its child or not, only relevant for binding + virtual bool InheritsColumnBindings() { + return false; + } + virtual Relation *ChildRelation() { + return nullptr; + } + DUCKDB_API vector> GetAllDependencies(); + +protected: + DUCKDB_API string RenderWhitespace(idx_t depth); + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/aggregate_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/aggregate_relation.hpp new file mode 100644 index 00000000..45daad4a --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/aggregate_relation.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/aggregate_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/group_by_node.hpp" + +namespace duckdb { + +class AggregateRelation : public Relation { +public: + DUCKDB_API AggregateRelation(shared_ptr child, vector> expressions); + DUCKDB_API AggregateRelation(shared_ptr child, vector> expressions, + GroupByNode groups); + DUCKDB_API AggregateRelation(shared_ptr child, vector> expressions, + vector> groups); + + vector> expressions; + GroupByNode groups; + vector columns; + shared_ptr child; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp new file mode 100644 index 00000000..44ca800a --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/create_table_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class CreateTableRelation : public Relation { +public: + CreateTableRelation(shared_ptr child, string schema_name, string table_name); + + shared_ptr child; + string schema_name; + string table_name; + vector columns; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp new file mode 100644 index 00000000..cb826a86 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/create_view_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class CreateViewRelation : public Relation { +public: + CreateViewRelation(shared_ptr child, string view_name, bool replace, bool temporary); + CreateViewRelation(shared_ptr child, string schema_name, string view_name, bool replace, bool temporary); + + shared_ptr child; + string schema_name; + string view_name; + bool replace; + bool temporary; + vector columns; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/cross_product_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/cross_product_relation.hpp new file mode 100644 index 00000000..82fc828b --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/cross_product_relation.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/cross_product_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/common/enums/joinref_type.hpp" + +namespace duckdb { + +class CrossProductRelation : public Relation { +public: + DUCKDB_API CrossProductRelation(shared_ptr left, shared_ptr right, + JoinRefType join_ref_type = JoinRefType::CROSS); + + shared_ptr left; + shared_ptr right; + JoinRefType ref_type; + vector columns; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + + unique_ptr GetTableRef() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp new file mode 100644 index 00000000..2e0c6564 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/delete_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class DeleteRelation : public Relation { +public: + DeleteRelation(ClientContextWrapper &context, unique_ptr condition, string schema_name, + string table_name); + + vector columns; + unique_ptr condition; + string schema_name; + string table_name; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/distinct_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/distinct_relation.hpp new file mode 100644 index 00000000..18209e1a --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/distinct_relation.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/distinct_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class DistinctRelation : public Relation { +public: + explicit DistinctRelation(shared_ptr child); + + shared_ptr child; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + +public: + bool InheritsColumnBindings() override { + return true; + } + Relation *ChildRelation() override { + return child.get(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp new file mode 100644 index 00000000..bdc24097 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/explain_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class ExplainRelation : public Relation { +public: + explicit ExplainRelation(shared_ptr child, ExplainType type = ExplainType::EXPLAIN_STANDARD); + + shared_ptr child; + vector columns; + ExplainType type; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/filter_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/filter_relation.hpp new file mode 100644 index 00000000..2a87a02d --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/filter_relation.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/filter_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class FilterRelation : public Relation { +public: + DUCKDB_API FilterRelation(shared_ptr child, unique_ptr condition); + + unique_ptr condition; + shared_ptr child; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + +public: + bool InheritsColumnBindings() override { + return true; + } + Relation *ChildRelation() override { + return child.get(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp new file mode 100644 index 00000000..3695cde7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/insert_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class InsertRelation : public Relation { +public: + InsertRelation(shared_ptr child, string schema_name, string table_name); + + shared_ptr child; + string schema_name; + string table_name; + vector columns; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/join_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/join_relation.hpp new file mode 100644 index 00000000..76436d5c --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/join_relation.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/join_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/common/enums/joinref_type.hpp" + +namespace duckdb { + +class JoinRelation : public Relation { +public: + DUCKDB_API JoinRelation(shared_ptr left, shared_ptr right, + unique_ptr condition, JoinType type, + JoinRefType join_ref_type = JoinRefType::REGULAR); + DUCKDB_API JoinRelation(shared_ptr left, shared_ptr right, vector using_columns, + JoinType type, JoinRefType join_ref_type = JoinRefType::REGULAR); + + shared_ptr left; + shared_ptr right; + unique_ptr condition; + vector using_columns; + JoinType join_type; + JoinRefType join_ref_type; + vector columns; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + + unique_ptr GetTableRef() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/limit_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/limit_relation.hpp new file mode 100644 index 00000000..4edc1ae4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/limit_relation.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/limit_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class LimitRelation : public Relation { +public: + DUCKDB_API LimitRelation(shared_ptr child, int64_t limit, int64_t offset); + + int64_t limit; + int64_t offset; + shared_ptr child; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + +public: + bool InheritsColumnBindings() override { + return true; + } + Relation *ChildRelation() override { + return child.get(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/order_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/order_relation.hpp new file mode 100644 index 00000000..604ea689 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/order_relation.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/order_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/result_modifier.hpp" + +namespace duckdb { + +class OrderRelation : public Relation { +public: + DUCKDB_API OrderRelation(shared_ptr child, vector orders); + + vector orders; + shared_ptr child; + vector columns; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + +public: + bool InheritsColumnBindings() override { + return true; + } + Relation *ChildRelation() override { + return child.get(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/projection_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/projection_relation.hpp new file mode 100644 index 00000000..11110a7c --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/projection_relation.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/projection_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class ProjectionRelation : public Relation { +public: + DUCKDB_API ProjectionRelation(shared_ptr child, vector> expressions, + vector aliases); + + vector> expressions; + vector columns; + shared_ptr child; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp new file mode 100644 index 00000000..3a2728bf --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/query_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { +class SelectStatement; + +class QueryRelation : public Relation { +public: + QueryRelation(const std::shared_ptr &context, unique_ptr select_stmt, string alias); + ~QueryRelation(); + + unique_ptr select_stmt; + string alias; + vector columns; + +public: + static unique_ptr ParseStatement(ClientContext &context, const string &query, const string &error); + unique_ptr GetQueryNode() override; + unique_ptr GetTableRef() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + +private: + unique_ptr GetSelectStatement(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/read_csv_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/read_csv_relation.hpp new file mode 100644 index 00000000..fc2f9818 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/read_csv_relation.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/read_csv_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/main/relation/table_function_relation.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +class ReadCSVRelation : public TableFunctionRelation { +public: + ReadCSVRelation(const shared_ptr &context, const string &csv_file, vector columns, + string alias = string()); + ReadCSVRelation(const shared_ptr &context, const string &csv_file, named_parameter_map_t &&options, + string alias = string()); + + string alias; + bool auto_detect; + +public: + string GetAlias() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/read_json_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/read_json_relation.hpp new file mode 100644 index 00000000..b50a25fd --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/read_json_relation.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "duckdb/main/relation/table_function_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/named_parameter_map.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class ReadJSONRelation : public TableFunctionRelation { +public: + ReadJSONRelation(const shared_ptr &context, string json_file, named_parameter_map_t options, + bool auto_detect, string alias = ""); + string json_file; + string alias; + +public: + string GetAlias() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/setop_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/setop_relation.hpp new file mode 100644 index 00000000..b65eb802 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/setop_relation.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/setop_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" + +namespace duckdb { + +class SetOpRelation : public Relation { +public: + SetOpRelation(shared_ptr left, shared_ptr right, SetOperationType setop_type); + + shared_ptr left; + shared_ptr right; + SetOperationType setop_type; + vector columns; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/subquery_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/subquery_relation.hpp new file mode 100644 index 00000000..16e9b7a8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/subquery_relation.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/subquery_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class SubqueryRelation : public Relation { +public: + SubqueryRelation(shared_ptr child, string alias); + + shared_ptr child; + string alias; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + +public: + bool InheritsColumnBindings() override { + return child->InheritsColumnBindings(); + } + Relation *ChildRelation() override { + return child->ChildRelation(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/table_function_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/table_function_relation.hpp new file mode 100644 index 00000000..4dce6d95 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/table_function_relation.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/table_function_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class TableFunctionRelation : public Relation { +public: + TableFunctionRelation(const std::shared_ptr &context, string name, vector parameters, + named_parameter_map_t named_parameters, shared_ptr input_relation_p = nullptr, + bool auto_init = true); + + TableFunctionRelation(const std::shared_ptr &context, string name, vector parameters, + shared_ptr input_relation_p = nullptr, bool auto_init = true); + + string name; + vector parameters; + named_parameter_map_t named_parameters; + vector columns; + shared_ptr input_relation; + +public: + unique_ptr GetQueryNode() override; + unique_ptr GetTableRef() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + void AddNamedParameter(const string &name, Value argument); + void SetNamedParameters(named_parameter_map_t &&named_parameters); + +private: + void InitializeColumns(); + +private: + //! Whether or not to auto initialize the columns on construction + bool auto_initialize; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp new file mode 100644 index 00000000..77a950ce --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/table_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/main/table_description.hpp" + +namespace duckdb { + +class TableRelation : public Relation { +public: + TableRelation(const std::shared_ptr &context, unique_ptr description); + + unique_ptr description; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + + unique_ptr GetTableRef() override; + + void Update(const string &update, const string &condition = string()) override; + void Delete(const string &condition = string()) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp new file mode 100644 index 00000000..1cb14222 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/update_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class UpdateRelation : public Relation { +public: + UpdateRelation(ClientContextWrapper &context, unique_ptr condition, string schema_name, + string table_name, vector update_columns, vector> expressions); + + vector columns; + unique_ptr condition; + string schema_name; + string table_name; + vector update_columns; + vector> expressions; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/value_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/value_relation.hpp new file mode 100644 index 00000000..b8aa47c0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/value_relation.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/value_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class ValueRelation : public Relation { +public: + ValueRelation(const std::shared_ptr &context, const vector> &values, + vector names, string alias = "values"); + ValueRelation(const std::shared_ptr &context, const string &values, vector names, + string alias = "values"); + + vector>> expressions; + vector names; + vector columns; + string alias; + +public: + unique_ptr GetQueryNode() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; + + unique_ptr GetTableRef() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/view_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/view_relation.hpp new file mode 100644 index 00000000..8a8afa26 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/view_relation.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/view_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class ViewRelation : public Relation { +public: + ViewRelation(const std::shared_ptr &context, string schema_name, string view_name); + + string schema_name; + string view_name; + vector columns; + +public: + unique_ptr GetQueryNode() override; + unique_ptr GetTableRef() override; + + const vector &Columns() override; + string ToString(idx_t depth) override; + string GetAlias() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp new file mode 100644 index 00000000..99d2ebe8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/write_csv_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class WriteCSVRelation : public Relation { +public: + WriteCSVRelation(shared_ptr child, string csv_file, case_insensitive_map_t> options); + + shared_ptr child; + string csv_file; + vector columns; + case_insensitive_map_t> options; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp new file mode 100644 index 00000000..c67da3e2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/relation/write_csv_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class WriteParquetRelation : public Relation { +public: + WriteParquetRelation(shared_ptr child, string parquet_file, + case_insensitive_map_t> options); + + shared_ptr child; + string parquet_file; + vector columns; + case_insensitive_map_t> options; + +public: + BoundStatement Bind(Binder &binder) override; + const vector &Columns() override; + string ToString(idx_t depth) override; + bool IsReadOnly() override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp new file mode 100644 index 00000000..65d2b6c4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -0,0 +1,555 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/settings.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { +class ClientContext; +class DatabaseInstance; +struct DBConfig; + +struct AccessModeSetting { + static constexpr const char *Name = "access_mode"; + static constexpr const char *Description = "Access mode of the database (AUTOMATIC, READ_ONLY or READ_WRITE)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct CheckpointThresholdSetting { + static constexpr const char *Name = "checkpoint_threshold"; + static constexpr const char *Description = + "The WAL size threshold at which to automatically trigger a checkpoint (e.g. 1GB)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct DebugCheckpointAbort { + static constexpr const char *Name = "debug_checkpoint_abort"; + static constexpr const char *Description = + "DEBUG SETTING: trigger an abort while checkpointing for testing purposes"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct DebugForceExternal { + static constexpr const char *Name = "debug_force_external"; + static constexpr const char *Description = + "DEBUG SETTING: force out-of-core computation for operators that support it, used for testing"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct DebugForceNoCrossProduct { + static constexpr const char *Name = "debug_force_no_cross_product"; + static constexpr const char *Description = + "DEBUG SETTING: Force disable cross product generation when hyper graph isn't connected, used for testing"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct OrderedAggregateThreshold { + static constexpr const char *Name = "ordered_aggregate_threshold"; // NOLINT + static constexpr const char *Description = // NOLINT + "The number of rows to accumulate before sorting, used for tuning"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::UBIGINT; // NOLINT + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct DebugAsOfIEJoin { + static constexpr const char *Name = "debug_asof_iejoin"; // NOLINT + static constexpr const char *Description = "DEBUG SETTING: force use of IEJoin to implement AsOf joins"; // NOLINT + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; // NOLINT + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct PreferRangeJoins { + static constexpr const char *Name = "prefer_range_joins"; // NOLINT + static constexpr const char *Description = "Force use of range joins with mixed predicates"; // NOLINT + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; // NOLINT + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct DebugWindowMode { + static constexpr const char *Name = "debug_window_mode"; + static constexpr const char *Description = "DEBUG SETTING: switch window mode to use"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct DefaultCollationSetting { + static constexpr const char *Name = "default_collation"; + static constexpr const char *Description = "The collation setting used when none is specified"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct DefaultOrderSetting { + static constexpr const char *Name = "default_order"; + static constexpr const char *Description = "The order type used when none is specified (ASC or DESC)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct DefaultNullOrderSetting { + static constexpr const char *Name = "default_null_order"; + static constexpr const char *Description = "Null ordering used when none is specified (NULLS_FIRST or NULLS_LAST)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct DisabledFileSystemsSetting { + static constexpr const char *Name = "disabled_filesystems"; + static constexpr const char *Description = "Disable specific file systems preventing access (e.g. LocalFileSystem)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct DisabledOptimizersSetting { + static constexpr const char *Name = "disabled_optimizers"; + static constexpr const char *Description = "DEBUG SETTING: disable a specific set of optimizers (comma separated)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct EnableExternalAccessSetting { + static constexpr const char *Name = "enable_external_access"; + static constexpr const char *Description = + "Allow the database to access external state (through e.g. loading/installing modules, COPY TO/FROM, CSV " + "readers, pandas replacement scans, etc)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct EnableFSSTVectors { + static constexpr const char *Name = "enable_fsst_vectors"; + static constexpr const char *Description = + "Allow scans on FSST compressed segments to emit compressed vectors to utilize late decompression"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct AllowUnsignedExtensionsSetting { + static constexpr const char *Name = "allow_unsigned_extensions"; + static constexpr const char *Description = "Allow to load extensions with invalid or missing signatures"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct CustomExtensionRepository { + static constexpr const char *Name = "custom_extension_repository"; + static constexpr const char *Description = "Overrides the custom endpoint for remote extension installation"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct AutoloadExtensionRepository { + static constexpr const char *Name = "autoinstall_extension_repository"; + static constexpr const char *Description = + "Overrides the custom endpoint for extension installation on autoloading"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct AutoinstallKnownExtensions { + static constexpr const char *Name = "autoinstall_known_extensions"; + static constexpr const char *Description = + "Whether known extensions are allowed to be automatically installed when a query depends on them"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct AutoloadKnownExtensions { + static constexpr const char *Name = "autoload_known_extensions"; + static constexpr const char *Description = + "Whether known extensions are allowed to be automatically loaded when a query depends on them"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct EnableObjectCacheSetting { + static constexpr const char *Name = "enable_object_cache"; + static constexpr const char *Description = "Whether or not object cache is used to cache e.g. Parquet metadata"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct EnableHTTPMetadataCacheSetting { + static constexpr const char *Name = "enable_http_metadata_cache"; + static constexpr const char *Description = "Whether or not the global http metadata is used to cache HTTP metadata"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static Value GetSetting(ClientContext &context); +}; + +struct EnableProfilingSetting { + static constexpr const char *Name = "enable_profiling"; + static constexpr const char *Description = + "Enables profiling, and sets the output format (JSON, QUERY_TREE, QUERY_TREE_OPTIMIZER)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct EnableProgressBarSetting { + static constexpr const char *Name = "enable_progress_bar"; + static constexpr const char *Description = + "Enables the progress bar, printing progress to the terminal for long queries"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct EnableProgressBarPrintSetting { + static constexpr const char *Name = "enable_progress_bar_print"; + static constexpr const char *Description = + "Controls the printing of the progress bar, when 'enable_progress_bar' is true"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct ExplainOutputSetting { + static constexpr const char *Name = "explain_output"; + static constexpr const char *Description = "Output of EXPLAIN statements (ALL, OPTIMIZED_ONLY, PHYSICAL_ONLY)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct ExtensionDirectorySetting { + static constexpr const char *Name = "extension_directory"; + static constexpr const char *Description = "Set the directory to store extensions in"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct ExternalThreadsSetting { + static constexpr const char *Name = "external_threads"; + static constexpr const char *Description = "The number of external threads that work on DuckDB tasks."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BIGINT; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct FileSearchPathSetting { + static constexpr const char *Name = "file_search_path"; + static constexpr const char *Description = "A comma separated list of directories to search for input files"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct ForceCompressionSetting { + static constexpr const char *Name = "force_compression"; + static constexpr const char *Description = "DEBUG SETTING: forces a specific compression method to be used"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct ForceBitpackingModeSetting { + static constexpr const char *Name = "force_bitpacking_mode"; + static constexpr const char *Description = "DEBUG SETTING: forces a specific bitpacking mode"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct HomeDirectorySetting { + static constexpr const char *Name = "home_directory"; + static constexpr const char *Description = "Sets the home directory used by the system"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct IntegerDivisionSetting { + static constexpr const char *Name = "integer_division"; + static constexpr const char *Description = + "Whether or not the / operator defaults to integer division, or to floating point division"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct LogQueryPathSetting { + static constexpr const char *Name = "log_query_path"; + static constexpr const char *Description = + "Specifies the path to which queries should be logged (default: empty string, queries are not logged)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct LockConfigurationSetting { + static constexpr const char *Name = "lock_configuration"; + static constexpr const char *Description = "Whether or not the configuration can be altered"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct ImmediateTransactionModeSetting { + static constexpr const char *Name = "immediate_transaction_mode"; + static constexpr const char *Description = + "Whether transactions should be started lazily when needed, or immediately when BEGIN TRANSACTION is called"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct MaximumExpressionDepthSetting { + static constexpr const char *Name = "max_expression_depth"; + static constexpr const char *Description = + "The maximum expression depth limit in the parser. WARNING: increasing this setting and using very deep " + "expressions might lead to stack overflow errors."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::UBIGINT; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct MaximumMemorySetting { + static constexpr const char *Name = "max_memory"; + static constexpr const char *Description = "The maximum memory of the system (e.g. 1GB)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct PasswordSetting { + static constexpr const char *Name = "password"; + static constexpr const char *Description = "The password to use. Ignored for legacy compatibility."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct PerfectHashThresholdSetting { + static constexpr const char *Name = "perfect_ht_threshold"; + static constexpr const char *Description = "Threshold in bytes for when to use a perfect hash table (default: 12)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BIGINT; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct PivotFilterThreshold { + static constexpr const char *Name = "pivot_filter_threshold"; + static constexpr const char *Description = + "The threshold to switch from using filtered aggregates to LIST with a dedicated pivot operator"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BIGINT; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct PivotLimitSetting { + static constexpr const char *Name = "pivot_limit"; + static constexpr const char *Description = + "The maximum number of pivot columns in a pivot statement (default: 100000)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BIGINT; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct PreserveIdentifierCase { + static constexpr const char *Name = "preserve_identifier_case"; + static constexpr const char *Description = + "Whether or not to preserve the identifier case, instead of always lowercasing all non-quoted identifiers"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct PreserveInsertionOrder { + static constexpr const char *Name = "preserve_insertion_order"; + static constexpr const char *Description = + "Whether or not to preserve insertion order. If set to false the system is allowed to re-order any results " + "that do not contain ORDER BY clauses."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct ExportLargeBufferArrow { + static constexpr const char *Name = "arrow_large_buffer_size"; + static constexpr const char *Description = + "If arrow buffers for strings, blobs, uuids and bits should be exported using large buffers"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BOOLEAN; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct ProfilerHistorySize { + static constexpr const char *Name = "profiler_history_size"; + static constexpr const char *Description = "Sets the profiler history size"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BIGINT; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct ProfileOutputSetting { + static constexpr const char *Name = "profile_output"; + static constexpr const char *Description = + "The file to which profile output should be saved, or empty to print to the terminal"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct ProfilingModeSetting { + static constexpr const char *Name = "profiling_mode"; + static constexpr const char *Description = "The profiling mode (STANDARD or DETAILED)"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct ProgressBarTimeSetting { + static constexpr const char *Name = "progress_bar_time"; + static constexpr const char *Description = + "Sets the time (in milliseconds) how long a query needs to take before we start printing a progress bar"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BIGINT; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct SchemaSetting { + static constexpr const char *Name = "schema"; + static constexpr const char *Description = + "Sets the default search schema. Equivalent to setting search_path to a single value."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct SearchPathSetting { + static constexpr const char *Name = "search_path"; + static constexpr const char *Description = + "Sets the default catalog search path as a comma-separated list of values"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(ClientContext &context); +}; + +struct TempDirectorySetting { + static constexpr const char *Name = "temp_directory"; + static constexpr const char *Description = "Set the directory to which to write temp files"; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct ThreadsSetting { + static constexpr const char *Name = "threads"; + static constexpr const char *Description = "The number of total threads used by the system."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::BIGINT; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct UsernameSetting { + static constexpr const char *Name = "username"; + static constexpr const char *Description = "The username to use. Ignored for legacy compatibility."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +struct FlushAllocatorSetting { + static constexpr const char *Name = "allocator_flush_threshold"; + static constexpr const char *Description = + "Peak allocation threshold at which to flush the allocator after completing a task."; + static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(ClientContext &context); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/stream_query_result.hpp b/src/duckdb/src/include/duckdb/main/stream_query_result.hpp new file mode 100644 index 00000000..0373f307 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/stream_query_result.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/stream_query_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/winapi.hpp" +#include "duckdb/main/query_result.hpp" + +namespace duckdb { + +class ClientContext; +class ClientContextLock; +class Executor; +class MaterializedQueryResult; +class PreparedStatementData; + +class StreamQueryResult : public QueryResult { + friend class ClientContext; + +public: + static constexpr const QueryResultType TYPE = QueryResultType::STREAM_RESULT; + +public: + //! Create a successful StreamQueryResult. StreamQueryResults should always be successful initially (it makes no + //! sense to stream an error). + DUCKDB_API StreamQueryResult(StatementType statement_type, StatementProperties properties, + shared_ptr context, vector types, vector names); + DUCKDB_API ~StreamQueryResult() override; + +public: + //! Fetches a DataChunk from the query result. + DUCKDB_API unique_ptr FetchRaw() override; + //! Converts the QueryResult to a string + DUCKDB_API string ToString() override; + //! Materializes the query result and turns it into a materialized query result + DUCKDB_API unique_ptr Materialize(); + + DUCKDB_API bool IsOpen(); + + //! Closes the StreamQueryResult + DUCKDB_API void Close(); + + //! The client context this StreamQueryResult belongs to + shared_ptr context; + +private: + unique_ptr LockContext(); + void CheckExecutableInternal(ClientContextLock &lock); + bool IsOpenInternal(ClientContextLock &lock); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/table_description.hpp b/src/duckdb/src/include/duckdb/main/table_description.hpp new file mode 100644 index 00000000..2a35b4f1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/table_description.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/table_description.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/column_definition.hpp" + +namespace duckdb { + +struct TableDescription { + //! The schema of the table + string schema; + //! The table name of the table + string table; + //! The columns of the table + vector columns; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/valid_checker.hpp b/src/duckdb/src/include/duckdb/main/valid_checker.hpp new file mode 100644 index 00000000..bf563509 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/valid_checker.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/valid_checker.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/mutex.hpp" + +namespace duckdb { +class DatabaseInstance; +class MetaTransaction; + +class ValidChecker { +public: + ValidChecker(); + + DUCKDB_API static ValidChecker &Get(DatabaseInstance &db); + DUCKDB_API static ValidChecker &Get(MetaTransaction &transaction); + + DUCKDB_API void Invalidate(string error); + DUCKDB_API bool IsInvalidated(); + DUCKDB_API string InvalidatedMessage(); + + template + static bool IsInvalidated(T &o) { + return Get(o).IsInvalidated(); + } + template + static void Invalidate(T &o, string error) { + Get(o).Invalidate(std::move(error)); + } + + template + static string InvalidatedMessage(T &o) { + return Get(o).InvalidatedMessage(); + } + +private: + //! Set to true if a fatal exception has occurred + mutex invalidate_lock; + atomic is_invalidated; + string invalidated_msg; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/column_binding_replacer.hpp b/src/duckdb/src/include/duckdb/optimizer/column_binding_replacer.hpp new file mode 100644 index 00000000..63fb16e8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/column_binding_replacer.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/column_binding_replacer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +struct ReplacementBinding { +public: + ReplacementBinding(ColumnBinding old_binding, ColumnBinding new_binding); + ReplacementBinding(ColumnBinding old_binding, ColumnBinding new_binding, LogicalType new_type); + +public: + ColumnBinding old_binding; + ColumnBinding new_binding; + + bool replace_type; + LogicalType new_type; +}; + +//! The ColumnBindingReplacer updates column bindings (e.g., after changing the operator plan), utility for optimizers +class ColumnBindingReplacer : LogicalOperatorVisitor { +public: + ColumnBindingReplacer(); + +public: + //! Update each operator of the plan + void VisitOperator(LogicalOperator &op) override; + //! Visit an expression and update its column bindings + void VisitExpression(unique_ptr *expression) override; + +public: + //! Contains all bindings that need to be updated + vector replacement_bindings; + + //! Do not recurse further than this operator (optional) + optional_ptr stop_operator; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/column_lifetime_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/column_lifetime_optimizer.hpp new file mode 100644 index 00000000..4b2f33d5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/column_lifetime_optimizer.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/column_lifetime_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +class BoundColumnRefExpression; + +//! The ColumnLifetimeAnalyzer optimizer traverses the logical operator tree and ensures that columns are removed from +//! the plan when no longer required +class ColumnLifetimeAnalyzer : public LogicalOperatorVisitor { +public: + explicit ColumnLifetimeAnalyzer(bool is_root = false) : everything_referenced(is_root) { + } + + void VisitOperator(LogicalOperator &op) override; + +protected: + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + unique_ptr VisitReplace(BoundReferenceExpression &expr, unique_ptr *expr_ptr) override; + +private: + //! Whether or not all the columns are referenced. This happens in the case of the root expression (because the + //! output implicitly refers all the columns below it) + bool everything_referenced; + //! The set of column references + column_binding_set_t column_references; + +private: + void StandardVisitOperator(LogicalOperator &op); + + void ExtractUnusedColumnBindings(vector bindings, column_binding_set_t &unused_bindings); + void GenerateProjectionMap(vector bindings, column_binding_set_t &unused_bindings, + vector &map); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/common_aggregate_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/common_aggregate_optimizer.hpp new file mode 100644 index 00000000..ff021592 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/common_aggregate_optimizer.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/common_aggregate_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/column_binding_map.hpp" + +namespace duckdb { +//! The CommonAggregateOptimizer optimizer eliminates duplicate aggregates from aggregate nodes +class CommonAggregateOptimizer : public LogicalOperatorVisitor { +public: + void VisitOperator(LogicalOperator &op) override; + +private: + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + + void ExtractCommonAggregates(LogicalAggregate &aggr); + +private: + column_binding_map_t aggregate_map; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp b/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp new file mode 100644 index 00000000..688d8a9d --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp @@ -0,0 +1,132 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/compressed_materialization.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/function/scalar/compressed_materialization_functions.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +class LogicalOperator; +struct JoinCondition; + +struct CMChildInfo { +public: + CMChildInfo(LogicalOperator &op, const column_binding_set_t &referenced_bindings); + +public: + //! Bindings and types before compressing + vector bindings_before; + vector &types; + //! Whether the input binding is eligible for compression + vector can_compress; + + //! Bindings after compressing (projection on top) + vector bindings_after; +}; + +struct CMBindingInfo { +public: + explicit CMBindingInfo(ColumnBinding binding, const LogicalType &type); + +public: + ColumnBinding binding; + + //! Type before compressing + LogicalType type; + bool needs_decompression; + unique_ptr stats; +}; + +struct CompressedMaterializationInfo { +public: + CompressedMaterializationInfo(LogicalOperator &op, vector &&child_idxs, + const column_binding_set_t &referenced_bindings); + +public: + //! Mapping from incoming bindings to outgoing bindings + column_binding_map_t binding_map; + + //! Operator child info + vector child_idxs; + vector child_info; +}; + +struct CompressExpression { +public: + CompressExpression(unique_ptr expression, unique_ptr stats); + +public: + unique_ptr expression; + unique_ptr stats; +}; + +typedef column_binding_map_t> statistics_map_t; + +//! The CompressedMaterialization optimizer compressed columns using projections, based on available statistics, +//! but only if the data enters a materializing operator +class CompressedMaterialization { +public: + explicit CompressedMaterialization(ClientContext &context, Binder &binder, statistics_map_t &&statistics_map); + + void Compress(unique_ptr &op); + +private: + //! Depth-first traversal of the plan + void CompressInternal(unique_ptr &op); + + //! Compress materializing operators + void CompressAggregate(unique_ptr &op); + void CompressDistinct(unique_ptr &op); + void CompressOrder(unique_ptr &op); + + //! Update statistics after compressing + void UpdateAggregateStats(unique_ptr &op); + void UpdateOrderStats(unique_ptr &op); + + //! Adds bindings referenced in expression to referenced_bindings + static void GetReferencedBindings(const Expression &expression, column_binding_set_t &referenced_bindings); + //! Updates CMBindingInfo in the binding_map in info + void UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, bool needs_decompression); + + //! Create (de)compress projections around the operator + void CreateProjections(unique_ptr &op, CompressedMaterializationInfo &info); + bool TryCompressChild(CompressedMaterializationInfo &info, const CMChildInfo &child_info, + vector> &compress_expressions); + void CreateCompressProjection(unique_ptr &child_op, + vector> &&compress_exprs, + CompressedMaterializationInfo &info, CMChildInfo &child_info); + void CreateDecompressProjection(unique_ptr &op, CompressedMaterializationInfo &info); + + //! Create expressions that apply a scalar compression function + unique_ptr GetCompressExpression(const ColumnBinding &binding, const LogicalType &type, + const bool &can_compress); + unique_ptr GetCompressExpression(unique_ptr input, const BaseStatistics &stats); + unique_ptr GetIntegralCompress(unique_ptr input, const BaseStatistics &stats); + unique_ptr GetStringCompress(unique_ptr input, const BaseStatistics &stats); + + //! Create an expression that applies a scalar decompression function + unique_ptr GetDecompressExpression(unique_ptr input, const LogicalType &result_type, + const BaseStatistics &stats); + unique_ptr GetIntegralDecompress(unique_ptr input, const LogicalType &result_type, + const BaseStatistics &stats); + unique_ptr GetStringDecompress(unique_ptr input, const BaseStatistics &stats); + +private: + ClientContext &context; + Binder &binder; + statistics_map_t statistics_map; + unordered_set compression_table_indices; + unordered_set decompression_table_indices; + optional_ptr root; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/cse_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/cse_optimizer.hpp new file mode 100644 index 00000000..feb05db6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/cse_optimizer.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/cse_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" + +namespace duckdb { +class Binder; +struct CSEReplacementState; + +//! The CommonSubExpression optimizer traverses the expressions of a LogicalOperator to look for duplicate expressions +//! if there are any, it pushes a projection under the operator that resolves these expressions +class CommonSubExpressionOptimizer : public LogicalOperatorVisitor { +public: + explicit CommonSubExpressionOptimizer(Binder &binder) : binder(binder) { + } + +public: + void VisitOperator(LogicalOperator &op) override; + +private: + //! First iteration: count how many times each expression occurs + void CountExpressions(Expression &expr, CSEReplacementState &state); + //! Second iteration: perform the actual replacement of the duplicate expressions with common subexpressions nodes + void PerformCSEReplacement(unique_ptr &expr, CSEReplacementState &state); + + //! Main method to extract common subexpressions + void ExtractCommonSubExpresions(LogicalOperator &op); + +private: + Binder &binder; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/deliminator.hpp b/src/duckdb/src/include/duckdb/optimizer/deliminator.hpp new file mode 100644 index 00000000..482c4064 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/deliminator.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/deliminator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/column_binding_replacer.hpp" + +namespace duckdb { + +struct DelimCandidate; + +//! The Deliminator optimizer traverses the logical operator tree and removes any redundant DelimGets/DelimJoins +class Deliminator { +public: + Deliminator() { + } + //! Perform DelimJoin elimination + unique_ptr Optimize(unique_ptr op); + +private: + //! Finds DelimJoins and their corresponding DelimGets + void FindCandidates(unique_ptr &op, vector &candidates); + void FindJoinWithDelimGet(unique_ptr &op, DelimCandidate &candidate); + //! Remove joins with a DelimGet + bool RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, + unique_ptr &join, bool &all_equality_conditions); + bool RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, + unique_ptr &join, + const vector &replacement_bindings); + +private: + optional_ptr root; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/expression_heuristics.hpp b/src/duckdb/src/include/duckdb/optimizer/expression_heuristics.hpp new file mode 100644 index 00000000..94c88e85 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/expression_heuristics.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/expression_heuristics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { + +class ExpressionHeuristics : public LogicalOperatorVisitor { +public: + explicit ExpressionHeuristics(Optimizer &optimizer) : optimizer(optimizer) { + } + + Optimizer &optimizer; + unique_ptr root; + +public: + //! Search for filters to be reordered + unique_ptr Rewrite(unique_ptr op); + //! Reorder the expressions of a filter + void ReorderExpressions(vector> &expressions); + //! Return the cost of an expression + idx_t Cost(Expression &expr); + + unique_ptr VisitReplace(BoundConjunctionExpression &expr, unique_ptr *expr_ptr) override; + //! Override this function to search for filter operators + void VisitOperator(LogicalOperator &op) override; + +private: + unordered_map function_costs = { + {"+", 5}, {"-", 5}, {"&", 5}, {"#", 5}, + {">>", 5}, {"<<", 5}, {"abs", 5}, {"*", 10}, + {"%", 10}, {"/", 15}, {"date_part", 20}, {"year", 20}, + {"round", 100}, {"~~", 200}, {"!~~", 200}, {"regexp_matches", 200}, + {"||", 200}}; + + idx_t ExpressionCost(BoundBetweenExpression &expr); + idx_t ExpressionCost(BoundCaseExpression &expr); + idx_t ExpressionCost(BoundCastExpression &expr); + idx_t ExpressionCost(BoundComparisonExpression &expr); + idx_t ExpressionCost(BoundConjunctionExpression &expr); + idx_t ExpressionCost(BoundFunctionExpression &expr); + idx_t ExpressionCost(BoundOperatorExpression &expr, ExpressionType &expr_type); + idx_t ExpressionCost(PhysicalType return_type, idx_t multiplier); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/expression_rewriter.hpp b/src/duckdb/src/include/duckdb/optimizer/expression_rewriter.hpp new file mode 100644 index 00000000..e0bf154d --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/expression_rewriter.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/expression_rewriter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { +class ClientContext; + +//! The ExpressionRewriter performs a set of fixed rewrite rules on the expressions that occur in a SQL statement +class ExpressionRewriter : public LogicalOperatorVisitor { +public: + explicit ExpressionRewriter(ClientContext &context) : context(context) { + } + +public: + //! The set of rules as known by the Expression Rewriter + vector> rules; + + ClientContext &context; + +public: + void VisitOperator(LogicalOperator &op) override; + void VisitExpression(unique_ptr *expression) override; + + // Generates either a constant_or_null(child) expression + static unique_ptr ConstantOrNull(unique_ptr child, Value value); + static unique_ptr ConstantOrNull(vector> children, Value value); + +private: + //! Apply a set of rules to a specific expression + static unique_ptr ApplyRules(LogicalOperator &op, const vector> &rules, + unique_ptr expr, bool &changes_made, bool is_root = false); + + optional_ptr op; + vector> to_apply_rules; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp new file mode 100644 index 00000000..3764915d --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp @@ -0,0 +1,126 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/filter_combiner.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" + +#include "duckdb/storage/data_table.hpp" +#include +#include + +namespace duckdb { +class Optimizer; + +enum class ValueComparisonResult { PRUNE_LEFT, PRUNE_RIGHT, UNSATISFIABLE_CONDITION, PRUNE_NOTHING }; +enum class FilterResult { UNSATISFIABLE, SUCCESS, UNSUPPORTED }; + +//! The FilterCombiner combines several filters and generates a logically equivalent set that is more efficient +//! Amongst others: +//! (1) it prunes obsolete filter conditions: i.e. [X > 5 and X > 7] => [X > 7] +//! (2) it generates new filters for expressions in the same equivalence set: i.e. [X = Y and X = 500] => [Y = 500] +//! (3) it prunes branches that have unsatisfiable filters: i.e. [X = 5 AND X > 6] => FALSE, prune branch +class FilterCombiner { +public: + explicit FilterCombiner(ClientContext &context); + explicit FilterCombiner(Optimizer &optimizer); + + ClientContext &context; + +public: + struct ExpressionValueInformation { + Value constant; + ExpressionType comparison_type; + }; + + FilterResult AddFilter(unique_ptr expr); + + void GenerateFilters(const std::function filter)> &callback); + bool HasFilters(); + TableFilterSet GenerateTableScanFilters(vector &column_ids); + // vector> GenerateZonemapChecks(vector &column_ids, vector> + // &pushed_filters); + +private: + FilterResult AddFilter(Expression &expr); + FilterResult AddBoundComparisonFilter(Expression &expr); + FilterResult AddTransitiveFilters(BoundComparisonExpression &comparison); + unique_ptr FindTransitiveFilter(Expression &expr); + // unordered_map> + // FindZonemapChecks(vector &column_ids, unordered_set ¬_constants, Expression *filter); + Expression &GetNode(Expression &expr); + idx_t GetEquivalenceSet(Expression &expr); + FilterResult AddConstantComparison(vector &info_list, ExpressionValueInformation info); + // + // //! Functions used to push and generate OR Filters + // void LookUpConjunctions(Expression *expr); + // bool BFSLookUpConjunctions(BoundConjunctionExpression *conjunction); + // void VerifyOrsToPush(Expression &expr); + // + // bool UpdateConjunctionFilter(BoundComparisonExpression *comparison_expr); + // bool UpdateFilterByColumn(BoundColumnRefExpression *column_ref, BoundComparisonExpression *comparison_expr); + // void GenerateORFilters(TableFilterSet &table_filter, vector &column_ids); + // + // template + // void GenerateConjunctionFilter(BoundConjunctionExpression *conjunction, ConjunctionFilter *last_conj_filter) { + // auto new_filter = NextConjunctionFilter(conjunction); + // auto conj_filter_ptr = (ConjunctionFilter *)new_filter.get(); + // last_conj_filter->child_filters.push_back(std::move(new_filter)); + // last_conj_filter = conj_filter_ptr; + // } + // + // template + // unique_ptr NextConjunctionFilter(BoundConjunctionExpression *conjunction) { + // unique_ptr conj_filter = make_uniq(); + // for (auto &expr : conjunction->children) { + // auto comp_expr = (BoundComparisonExpression *)expr.get(); + // auto &const_expr = + // (comp_expr->left->type == ExpressionType::VALUE_CONSTANT) ? *comp_expr->left : *comp_expr->right; + // auto const_value = ExpressionExecutor::EvaluateScalar(const_expr); + // auto const_filter = make_uniq(comp_expr->type, const_value); + // conj_filter->child_filters.push_back(std::move(const_filter)); + // } + // return std::move(conj_filter); + // } + +private: + vector> remaining_filters; + + expression_map_t> stored_expressions; + expression_map_t equivalence_set_map; + unordered_map> constant_values; + unordered_map>> equivalence_map; + idx_t set_index = 0; + // + // //! Structures used for OR Filters + // + // struct ConjunctionsToPush { + // BoundConjunctionExpression *root_or; + // + // // only preserve AND if there is a single column in the expression + // bool preserve_and = true; + // + // // conjunction chain for this column + // vector> conjunctions; + // }; + // + // expression_map_t>> map_col_conjunctions; + // vector vec_colref_insertion_order; + // + // BoundConjunctionExpression *cur_root_or; + // BoundConjunctionExpression *cur_conjunction; + // + // BoundColumnRefExpression *cur_colref_to_push; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp new file mode 100644 index 00000000..a35fbaab --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/filter_pullup.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class FilterPullup { +public: + explicit FilterPullup(bool pullup = false, bool add_column = false) + : can_pullup(pullup), can_add_column(add_column) { + } + + //! Perform filter pullup + unique_ptr Rewrite(unique_ptr op); + +private: + vector> filters_expr_pullup; + + // only pull up filters when there is a fork + bool can_pullup = false; + + // identifiy case the branch is a set operation (INTERSECT or EXCEPT) + bool can_add_column = false; + +private: + // Generate logical filters pulled up + unique_ptr GeneratePullupFilter(unique_ptr child, + vector> &expressions); + + //! Pull up a LogicalFilter op + unique_ptr PullupFilter(unique_ptr op); + + //! Pull up filter in a LogicalProjection op + unique_ptr PullupProjection(unique_ptr op); + + //! Pull up filter in a LogicalCrossProduct op + unique_ptr PullupCrossProduct(unique_ptr op); + + unique_ptr PullupJoin(unique_ptr op); + + // PPullup filter in a left join + unique_ptr PullupFromLeft(unique_ptr op); + + // Pullup filter in a inner join + unique_ptr PullupInnerJoin(unique_ptr op); + + // Pullup filter in LogicalIntersect or LogicalExcept op + unique_ptr PullupSetOperation(unique_ptr op); + + unique_ptr PullupBothSide(unique_ptr op); + + // Finish pull up at this operator + unique_ptr FinishPullup(unique_ptr op); + + // special treatment for SetOperations and projections + void ProjectSetOperation(LogicalProjection &proj); + +}; // end FilterPullup + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp new file mode 100644 index 00000000..1b9a4482 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/filter_pushdown.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +class Optimizer; + +class FilterPushdown { +public: + explicit FilterPushdown(Optimizer &optimizer); + + //! Perform filter pushdown + unique_ptr Rewrite(unique_ptr op); + //! Return a reference to the client context (from the optimizer) + ClientContext &GetContext(); + + struct Filter { + unordered_set bindings; + unique_ptr filter; + + Filter() { + } + explicit Filter(unique_ptr filter) : filter(std::move(filter)) { + } + + void ExtractBindings(); + }; + +private: + vector> filters; + Optimizer &optimizer; + + //! Push down a LogicalAggregate op + unique_ptr PushdownAggregate(unique_ptr op); + //! Push down a LogicalFilter op + unique_ptr PushdownFilter(unique_ptr op); + //! Push down a LogicalCrossProduct op + unique_ptr PushdownCrossProduct(unique_ptr op); + //! Push down a join operator + unique_ptr PushdownJoin(unique_ptr op); + //! Push down a LogicalProjection op + unique_ptr PushdownProjection(unique_ptr op); + //! Push down a LogicalSetOperation op + unique_ptr PushdownSetOperation(unique_ptr op); + //! Push down a LogicalGet op + unique_ptr PushdownGet(unique_ptr op); + //! Push down a LogicalLimit op + unique_ptr PushdownLimit(unique_ptr op); + // Pushdown an inner join + unique_ptr PushdownInnerJoin(unique_ptr op, unordered_set &left_bindings, + unordered_set &right_bindings); + // Pushdown a left join + unique_ptr PushdownLeftJoin(unique_ptr op, unordered_set &left_bindings, + unordered_set &right_bindings); + // Pushdown a mark join + unique_ptr PushdownMarkJoin(unique_ptr op, unordered_set &left_bindings, + unordered_set &right_bindings); + // Pushdown a single join + unique_ptr PushdownSingleJoin(unique_ptr op, unordered_set &left_bindings, + unordered_set &right_bindings); + + // AddLogicalFilter used to add an extra LogicalFilter at this level, + // because in some cases, some expressions can not be pushed down. + unique_ptr AddLogicalFilter(unique_ptr op, + vector> expressions); + //! Push any remaining filters into a LogicalFilter at this level + unique_ptr PushFinalFilters(unique_ptr op); + // Finish pushing down at this operator, creating a LogicalFilter to store any of the stored filters and recursively + // pushing down into its children (if any) + unique_ptr FinishPushdown(unique_ptr op); + //! Adds a filter to the set of filters. Returns FilterResult::UNSATISFIABLE if the subtree should be stripped, or + //! FilterResult::SUCCESS otherwise + FilterResult AddFilter(unique_ptr expr); + //! Generate filters from the current set of filters stored in the FilterCombiner + void GenerateFilters(); + //! if there are filters in this FilterPushdown node, push them into the combiner + void PushFilters(); + + FilterCombiner combiner; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/in_clause_rewriter.hpp b/src/duckdb/src/include/duckdb/optimizer/in_clause_rewriter.hpp new file mode 100644 index 00000000..91e4e1f8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/in_clause_rewriter.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/in_clause_rewriter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator_visitor.hpp" + +namespace duckdb { +class ClientContext; +class Optimizer; + +class InClauseRewriter : public LogicalOperatorVisitor { +public: + explicit InClauseRewriter(ClientContext &context, Optimizer &optimizer) : context(context), optimizer(optimizer) { + } + + ClientContext &context; + Optimizer &optimizer; + unique_ptr root; + +public: + unique_ptr Rewrite(unique_ptr op); + + unique_ptr VisitReplace(BoundOperatorExpression &expr, unique_ptr *expr_ptr) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp new file mode 100644 index 00000000..dc8de7d8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp @@ -0,0 +1,95 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/cardinality_estimator.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/optimizer/join_order/query_graph.hpp" + +#include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" + +namespace duckdb { + +struct FilterInfo; + +struct RelationsToTDom { + //! column binding sets that are equivalent in a join plan. + //! if you have A.x = B.y and B.y = C.z, then one set is {A.x, B.y, C.z}. + column_binding_set_t equivalent_relations; + //! the estimated total domains of the equivalent relations determined using HLL + idx_t tdom_hll; + //! the estimated total domains of each relation without using HLL + idx_t tdom_no_hll; + bool has_tdom_hll; + vector filters; + vector column_names; + + RelationsToTDom(const column_binding_set_t &column_binding_set) + : equivalent_relations(column_binding_set), tdom_hll(0), tdom_no_hll(NumericLimits::Maximum()), + has_tdom_hll(false) {}; +}; + +struct Subgraph2Denominator { + unordered_set relations; + double denom; + + Subgraph2Denominator() : relations(), denom(1) {}; +}; + +class CardinalityHelper { +public: + CardinalityHelper() { + } + CardinalityHelper(double cardinality_before_filters, double filter_string) + : cardinality_before_filters(cardinality_before_filters), filter_strength(filter_string) {}; + +public: + double cardinality_before_filters; + double filter_strength; + + vector table_names_joined; + vector column_names; +}; + +class CardinalityEstimator { +public: + explicit CardinalityEstimator() {}; + +private: + vector relations_to_tdoms; + unordered_map relation_set_2_cardinality; + JoinRelationSetManager set_manager; + vector relation_stats; + +public: + void RemoveEmptyTotalDomains(); + void UpdateTotalDomains(optional_ptr set, RelationStats &stats); + void InitEquivalentRelations(const vector> &filter_infos); + + void InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats); + + //! cost model needs estimated cardinalities to the fraction since the formula captures + //! distinct count selectivities and multiplicities. Hence the template + template + T EstimateCardinalityWithSet(JoinRelationSet &new_set); + + //! used for debugging. + void AddRelationNamesToTdoms(vector &stats); + void PrintRelationToTdomInfo(); + +private: + bool SingleColumnFilter(FilterInfo &filter_info); + vector DetermineMatchingEquivalentSets(FilterInfo *filter_info); + //! Given a filter, add the column bindings to the matching equivalent set at the index + //! given in matching equivalent sets. + //! If there are multiple equivalence sets, they are merged. + void AddToEquivalenceSets(FilterInfo *filter_info, vector matching_equivalent_sets); + void AddRelationTdom(FilterInfo &filter_info); + bool EmptyFilter(FilterInfo &filter_info); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/cost_model.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/cost_model.hpp new file mode 100644 index 00000000..d2ff77d9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/cost_model.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/cost_model.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/optimizer/join_order/cardinality_estimator.hpp" + +namespace duckdb { + +class QueryGraphManager; + +class CostModel { +public: + CostModel(QueryGraphManager &query_graph_manager); + +private: + //! query graph storing relation manager information + QueryGraphManager &query_graph_manager; + +public: + void InitCostModel(); + + //! Compute cost of a join relation set + double ComputeCost(JoinNode &left, JoinNode &right); + + //! Cardinality Estimator used to calculate cost + CardinalityEstimator cardinality_estimator; + +private: +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/estimated_properties.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/estimated_properties.hpp new file mode 100644 index 00000000..7404d6ce --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/estimated_properties.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/estimated_properties.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/storage/statistics/distinct_statistics.hpp" +#include "duckdb/planner/table_filter.hpp" + +namespace duckdb { + +class EstimatedProperties { +public: + EstimatedProperties(double cardinality, double cost) : cardinality(cardinality), cost(cost) {}; + EstimatedProperties() : cardinality(0), cost(0) {}; + + template + T GetCardinality() const { + throw NotImplementedException("Unsupported type for GetCardinality"); + } + template + T GetCost() const { + throw NotImplementedException("Unsupported type for GetCost"); + } + void SetCost(double new_cost); + void SetCardinality(double cardinality); + +private: + double cardinality; + double cost; + +public: + unique_ptr Copy(); +}; + +template <> +double EstimatedProperties::GetCardinality() const; + +template <> +idx_t EstimatedProperties::GetCardinality() const; + +template <> +double EstimatedProperties::GetCost() const; + +template <> +idx_t EstimatedProperties::GetCost() const; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/join_node.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/join_node.hpp new file mode 100644 index 00000000..a1772102 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/join_node.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/join_node.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/query_graph.hpp" + +namespace duckdb { + +struct NeighborInfo; + +class JoinNode { +public: + //! Represents a node in the join plan + JoinRelationSet &set; + //! information on how left and right are connected + optional_ptr info; + //! left and right plans + optional_ptr left; + optional_ptr right; + + //! The cost of the join node. The cost is stored here so that the cost of + //! a join node stays in sync with how the join node is constructed. Storing the cost in an unordered_set + //! in the cost model is error prone. If the plan enumerator join node is updated and not the cost model + //! the whole Join Order Optimizer can start exhibiting undesired behavior. + double cost; + //! used only to populate logical operators with estimated caridnalities after the best join plan has been found. + idx_t cardinality; + + //! Create an intermediate node in the join tree. base_cardinality = estimated_props.cardinality + JoinNode(JoinRelationSet &set, optional_ptr info, JoinNode &left, JoinNode &right, double cost); + + //! Create a leaf node in the join tree + //! set cost to 0 for leaf nodes + //! cost will be the cost to *produce* an intermediate table + JoinNode(JoinRelationSet &set); + + bool operator==(const JoinNode &other) { + return other.set.ToString().compare(set.ToString()) == 0; + } + +private: +public: + void PrintJoinNode(); + string ToString(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/join_order_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/join_order_optimizer.hpp new file mode 100644 index 00000000..720ec4ab --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/join_order_optimizer.hpp @@ -0,0 +1,104 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/join_order_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/optimizer/join_order/query_graph_manager.hpp" +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/cardinality_estimator.hpp" +#include "duckdb/optimizer/join_order/query_graph.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" + +#include + +namespace duckdb { + +class JoinOrderOptimizer { +public: + explicit JoinOrderOptimizer(ClientContext &context) : context(context), query_graph_manager(context) { + } + + //! Perform join reordering inside a plan + unique_ptr Optimize(unique_ptr plan, optional_ptr stats = nullptr); + + unique_ptr CreateJoinTree(JoinRelationSet &set, + const vector> &possible_connections, JoinNode &left, + JoinNode &right); + +private: + ClientContext &context; + + //! manages the query graph, relations, and edges between relations + QueryGraphManager query_graph_manager; + + //! The optimal join plan found for the specific JoinRelationSet* + unordered_map> plans; + + //! The set of filters extracted from the query graph + vector> filters; + //! The set of filter infos created from the extracted filters + vector> filter_infos; + //! A map of all expressions a given expression has to be equivalent to. This is used to add "implied join edges". + //! i.e. in the join A=B AND B=C, the equivalence set of {B} is {A, C}, thus we can add an implied join edge {A = C} + expression_map_t> equivalence_sets; + + CardinalityEstimator cardinality_estimator; + + bool full_plan_found; + bool must_update_full_plan; + unordered_set join_nodes_in_full_plan; + + //! Extract the bindings referred to by an Expression + bool ExtractBindings(Expression &expression, unordered_set &bindings); + + //! Get column bindings from a filter + void GetColumnBinding(Expression &expression, ColumnBinding &binding); + + //! Traverse the query tree to find (1) base relations, (2) existing join conditions and (3) filters that can be + //! rewritten into joins. Returns true if there are joins in the tree that can be reordered, false otherwise. + bool ExtractJoinRelations(LogicalOperator &input_op, vector> &filter_operators, + optional_ptr parent = nullptr); + + //! Emit a pair as a potential join candidate. Returns the best plan found for the (left, right) connection (either + //! the newly created plan, or an existing plan) + JoinNode &EmitPair(JoinRelationSet &left, JoinRelationSet &right, const vector> &info); + //! Tries to emit a potential join candidate pair. Returns false if too many pairs have already been emitted, + //! cancelling the dynamic programming step. + bool TryEmitPair(JoinRelationSet &left, JoinRelationSet &right, const vector> &info); + + bool EnumerateCmpRecursive(JoinRelationSet &left, JoinRelationSet &right, unordered_set &exclusion_set); + //! Emit a relation set node + bool EmitCSG(JoinRelationSet &node); + //! Enumerate the possible connected subgraphs that can be joined together in the join graph + bool EnumerateCSGRecursive(JoinRelationSet &node, unordered_set &exclusion_set); + //! Rewrite a logical query plan given the join plan + unique_ptr RewritePlan(unique_ptr plan, JoinNode &node); + //! Generate cross product edges inside the side + void GenerateCrossProducts(); + //! Perform the join order solving + void SolveJoinOrder(); + //! Solve the join order exactly using dynamic programming. Returns true if it was completed successfully (i.e. did + //! not time-out) + bool SolveJoinOrderExactly(); + //! Solve the join order approximately using a greedy algorithm + void SolveJoinOrderApproximately(); + + void UpdateDPTree(JoinNode &new_plan); + + void UpdateJoinNodesInFullPlan(JoinNode &node); + bool NodeInFullPlan(JoinNode &node); + + GenerateJoinRelation GenerateJoins(vector> &extracted_relations, JoinNode &node); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/join_relation.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/join_relation.hpp new file mode 100644 index 00000000..aee248e6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/join_relation.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/join_relation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +//! Set of relations, used in the join graph. +struct JoinRelationSet { + JoinRelationSet(unsafe_unique_array relations, idx_t count) : relations(std::move(relations)), count(count) { + } + + string ToString() const; + + unsafe_unique_array relations; + idx_t count; + + static bool IsSubset(JoinRelationSet &super, JoinRelationSet &sub); +}; + +//! The JoinRelationTree is a structure holding all the created JoinRelationSet objects and allowing fast lookup on to +//! them +class JoinRelationSetManager { +public: + //! Contains a node with a JoinRelationSet and child relations + // FIXME: this structure is inefficient, could use a bitmap for lookup instead (todo: profile) + struct JoinRelationTreeNode { + unique_ptr relation; + unordered_map> children; + }; + +public: + //! Create or get a JoinRelationSet from a single node with the given index + JoinRelationSet &GetJoinRelation(idx_t index); + //! Create or get a JoinRelationSet from a set of relation bindings + JoinRelationSet &GetJoinRelation(const unordered_set &bindings); + //! Create or get a JoinRelationSet from a (sorted, duplicate-free!) list of relations + JoinRelationSet &GetJoinRelation(unsafe_unique_array relations, idx_t count); + //! Union two sets of relations together and create a new relation set + JoinRelationSet &Union(JoinRelationSet &left, JoinRelationSet &right); + // //! Create the set difference of left \ right (i.e. all elements in left that are not in right) + // JoinRelationSet *Difference(JoinRelationSet *left, JoinRelationSet *right); + +private: + JoinRelationTreeNode root; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/plan_enumerator.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/plan_enumerator.hpp new file mode 100644 index 00000000..19d6a40f --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/plan_enumerator.hpp @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/plan_enumerator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/cardinality_estimator.hpp" +#include "duckdb/optimizer/join_order/query_graph.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/optimizer/join_order/cost_model.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" + +#include + +namespace duckdb { + +class QueryGraphManager; + +class PlanEnumerator { +public: + explicit PlanEnumerator(QueryGraphManager &query_graph_manager, CostModel &cost_model, + const QueryGraphEdges &query_graph) + : query_graph(query_graph), query_graph_manager(query_graph_manager), cost_model(cost_model), + full_plan_found(false), must_update_full_plan(false) { + } + + //! Perform the join order solving + unique_ptr SolveJoinOrder(); + void InitLeafPlans(); + + static unique_ptr BuildSideProbeSideSwaps(unique_ptr plan); + +private: + QueryGraphEdges const &query_graph; + //! The total amount of join pairs that have been considered + idx_t pairs = 0; + //! The set of edges used in the join optimizer + QueryGraphManager &query_graph_manager; + //! Cost model to evaluate cost of joins + CostModel &cost_model; + //! A map to store the optimal join plan found for a specific JoinRelationSet* + reference_map_t> plans; + + bool full_plan_found; + bool must_update_full_plan; + unordered_set join_nodes_in_full_plan; + + unique_ptr CreateJoinTree(JoinRelationSet &set, + const vector> &possible_connections, JoinNode &left, + JoinNode &right); + + //! Emit a pair as a potential join candidate. Returns the best plan found for the (left, right) connection (either + //! the newly created plan, or an existing plan) + JoinNode &EmitPair(JoinRelationSet &left, JoinRelationSet &right, const vector> &info); + //! Tries to emit a potential join candidate pair. Returns false if too many pairs have already been emitted, + //! cancelling the dynamic programming step. + bool TryEmitPair(JoinRelationSet &left, JoinRelationSet &right, const vector> &info); + + bool EnumerateCmpRecursive(JoinRelationSet &left, JoinRelationSet &right, unordered_set &exclusion_set); + //! Emit a relation set node + bool EmitCSG(JoinRelationSet &node); + //! Enumerate the possible connected subgraphs that can be joined together in the join graph + bool EnumerateCSGRecursive(JoinRelationSet &node, unordered_set &exclusion_set); + //! Generate cross product edges inside the side + void GenerateCrossProducts(); + + //! Solve the join order exactly using dynamic programming. Returns true if it was completed successfully (i.e. did + //! not time-out) + bool SolveJoinOrderExactly(); + //! Solve the join order approximately using a greedy algorithm + void SolveJoinOrderApproximately(); + + void UpdateDPTree(JoinNode &new_plan); + + void UpdateJoinNodesInFullPlan(JoinNode &node); + bool NodeInFullPlan(JoinNode &node); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph.hpp new file mode 100644 index 00000000..92895e2c --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph.hpp @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/query_graph.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/optimizer/join_order/relation_manager.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/planner/column_binding.hpp" + +#include + +namespace duckdb { + +struct FilterInfo; + +struct NeighborInfo { + NeighborInfo(optional_ptr neighbor) : neighbor(neighbor) { + } + + optional_ptr neighbor; + vector> filters; +}; + +//! The QueryGraph contains edges between relations and allows edges to be created/queried +class QueryGraphEdges { +public: + //! Contains a node with info about neighboring relations and child edge infos + struct QueryEdge { + vector> neighbors; + unordered_map> children; + }; + +public: + string ToString() const; + void Print(); + + //! Returns a connection if there is an edge that connects these two sets, or nullptr otherwise + const vector> GetConnections(JoinRelationSet &node, JoinRelationSet &other) const; + //! Enumerate the neighbors of a specific node that do not belong to any of the exclusion_set. Note that if a + //! neighbor has multiple nodes, this function will return the lowest entry in that set. + const vector GetNeighbors(JoinRelationSet &node, unordered_set &exclusion_set) const; + + //! Enumerate all neighbors of a given JoinRelationSet node + void EnumerateNeighbors(JoinRelationSet &node, const std::function &callback) const; + //! Create an edge in the edge_set + void CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr info); + +private: + //! Get the QueryEdge of a specific node + optional_ptr GetQueryEdge(JoinRelationSet &left); + + void EnumerateNeighborsDFS(JoinRelationSet &node, reference info, idx_t index, + const std::function &callback) const; + + QueryEdge root; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp new file mode 100644 index 00000000..5a887f5d --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp @@ -0,0 +1,113 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/query_graph_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/query_graph.hpp" +#include "duckdb/optimizer/join_order/relation_manager.hpp" +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/logical_operator.hpp" + +#include + +namespace duckdb { + +struct GenerateJoinRelation { + GenerateJoinRelation(optional_ptr set, unique_ptr op_p) + : set(set), op(std::move(op_p)) { + } + + optional_ptr set; + unique_ptr op; +}; + +//! Filter info struct that is used by the cardinality estimator to set the initial cardinality +//! but is also eventually transformed into a query edge. +struct FilterInfo { + FilterInfo(unique_ptr filter, JoinRelationSet &set, idx_t filter_index) + : filter(std::move(filter)), set(set), filter_index(filter_index) { + } + + unique_ptr filter; + JoinRelationSet &set; + idx_t filter_index; + optional_ptr left_set; + optional_ptr right_set; + ColumnBinding left_binding; + ColumnBinding right_binding; +}; + +//! The QueryGraphManager manages the process of extracting the reorderable and nonreorderable operations +//! from the logical plan and creating the intermediate structures needed by the plan enumerator. +//! When the plan enumerator finishes, the Query Graph Manger can then recreate the logical plan. +class QueryGraphManager { +public: + QueryGraphManager(ClientContext &context) : relation_manager(context), context(context) { + } + + //! manage relations and the logical operators they represent + RelationManager relation_manager; + + //! A structure holding all the created JoinRelationSet objects + JoinRelationSetManager set_manager; + + ClientContext &context; + + //! Extract the join relations, optimizing non-reoderable relations when encountered + bool Build(LogicalOperator &op); + + //! Reconstruct the logical plan using the plan found by the plan enumerator + unique_ptr Reconstruct(unique_ptr plan, JoinNode &node); + + //! Get a reference to the QueryGraphEdges structure that stores edges between + //! nodes and hypernodes. + const QueryGraphEdges &GetQueryGraphEdges() const; + + //! Get a list of the join filters in the join plan than eventually are + //! transformed into the query graph edges + const vector> &GetFilterBindings() const; + + //! Plan enumerator may not find a full plan and therefore will need to create cross + //! products to create edges. + void CreateQueryGraphCrossProduct(JoinRelationSet &left, JoinRelationSet &right); + + //! after join order optimization, we perform build side probe side optimizations. + //! (Basically we put lower expected cardinality columns on the build side, and larger + //! tables on the probe side) + unique_ptr LeftRightOptimizations(unique_ptr op); + +private: + vector> filter_operators; + + //! Filter information including the column_bindings that join filters + //! used by the cardinality estimator to estimate distinct counts + vector> filters_and_bindings; + + QueryGraphEdges query_graph; + + void GetColumnBinding(Expression &expression, ColumnBinding &binding); + + bool ExtractBindings(Expression &expression, unordered_set &bindings); + bool LeftCardLessThanRight(LogicalOperator &op); + + void CreateHyperGraphEdges(); + + GenerateJoinRelation GenerateJoins(vector> &extracted_relations, JoinNode &node); + + unique_ptr RewritePlan(unique_ptr plan, JoinNode &node); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp new file mode 100644 index 00000000..adaf487f --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/relation_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/cardinality_estimator.hpp" +#include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" + +namespace duckdb { + +struct FilterInfo; + +//! Represents a single relation and any metadata accompanying that relation +struct SingleJoinRelation { + LogicalOperator &op; + optional_ptr parent; + RelationStats stats; + + SingleJoinRelation(LogicalOperator &op, optional_ptr parent) : op(op), parent(parent) { + } + SingleJoinRelation(LogicalOperator &op, optional_ptr parent, RelationStats stats) + : op(op), parent(parent), stats(stats) { + } +}; + +class RelationManager { +public: + explicit RelationManager(ClientContext &context) : context(context) { + } + + idx_t NumRelations(); + + bool ExtractJoinRelations(LogicalOperator &input_op, vector> &filter_operators, + optional_ptr parent = nullptr); + + //! for each join filter in the logical plan op, extract the relations that are referred to on + //! both sides of the join filter, along with the tables & indexes. + vector> ExtractEdges(LogicalOperator &op, + vector> &filter_operators, + JoinRelationSetManager &set_manager); + + //! Extract the set of relations referred to inside an expression + bool ExtractBindings(Expression &expression, unordered_set &bindings); + void AddRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats); + + void AddAggregateRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats); + vector> GetRelations(); + + const vector GetRelationStats(); + //! A mapping of base table index -> index into relations array (relation number) + unordered_map relation_mapping; + + void PrintRelationStats(); + +private: + ClientContext &context; + //! Set of all relations considered in the join optimizer + vector> relations; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp new file mode 100644 index 00000000..86279cdd --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/statistics_extractor.hpp +// +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class CardinalityEstimator; + +struct DistinctCount { + idx_t distinct_count; + bool from_hll; +}; + +struct ExpressionBinding { + bool found_expression = false; + ColumnBinding child_binding; + bool expression_is_constant = false; +}; + +struct RelationStats { + // column_id -> estimated distinct count for column + vector column_distinct_count; + idx_t cardinality; + double filter_strength = 1; + bool stats_initialized = false; + + // for debug, column names and tables + vector column_names; + string table_name; + + RelationStats() : cardinality(1), filter_strength(1), stats_initialized(false) { + } +}; + +class RelationStatisticsHelper { +public: + static constexpr double DEFAULT_SELECTIVITY = 0.2; + +public: + static idx_t InspectConjunctionAND(idx_t cardinality, idx_t column_index, ConjunctionAndFilter &filter, + BaseStatistics &base_stats); + // static idx_t InspectConjunctionOR(idx_t cardinality, idx_t column_index, ConjunctionOrFilter &filter, + // BaseStatistics &base_stats); + //! Extract Statistics from a LogicalGet. + static RelationStats ExtractGetStats(LogicalGet &get, ClientContext &context); + static RelationStats ExtractDelimGetStats(LogicalDelimGet &delim_get, ClientContext &context); + //! Create the statistics for a projection using the statistics of the operator that sits underneath the + //! projection. Then also create statistics for any extra columns the projection creates. + static RelationStats ExtractDummyScanStats(LogicalDummyScan &dummy_scan, ClientContext &context); + static RelationStats ExtractExpressionGetStats(LogicalExpressionGet &expression_get, ClientContext &context); + //! All relation extractors for blocking relations + static RelationStats ExtractProjectionStats(LogicalProjection &proj, RelationStats &child_stats); + static RelationStats ExtractAggregationStats(LogicalAggregate &aggr, RelationStats &child_stats); + static RelationStats ExtractWindowStats(LogicalWindow &window, RelationStats &child_stats); + //! Called after reordering a query plan with potentially 2+ relations. + static RelationStats CombineStatsOfReorderableOperator(vector &bindings, + vector relation_stats); + //! Called after reordering a query plan with potentially 2+ relations. + static RelationStats CombineStatsOfNonReorderableOperator(LogicalOperator &op, vector child_stats); + static void CopyRelationStats(RelationStats &to, const RelationStats &from); + +private: +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/expression_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/expression_matcher.hpp new file mode 100644 index 00000000..88113cc8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/expression_matcher.hpp @@ -0,0 +1,139 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/matcher/expression_matcher.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/optimizer/matcher/expression_type_matcher.hpp" +#include "duckdb/optimizer/matcher/set_matcher.hpp" +#include "duckdb/optimizer/matcher/type_matcher.hpp" +#include "duckdb/optimizer/matcher/function_matcher.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! The ExpressionMatcher class contains a set of matchers that can be used to pattern match Expressions +class ExpressionMatcher { +public: + explicit ExpressionMatcher(ExpressionClass type = ExpressionClass::INVALID) : expr_class(type) { + } + virtual ~ExpressionMatcher() { + } + + //! Checks if the given expression matches this ExpressionMatcher. If it does, the expression is appended to the + //! bindings list and true is returned. Otherwise, false is returned. + virtual bool Match(Expression &expr, vector> &bindings); + + //! The ExpressionClass of the to-be-matched expression. ExpressionClass::INVALID for ANY. + ExpressionClass expr_class; + //! Matcher for the ExpressionType of the operator (nullptr for ANY) + unique_ptr expr_type; + //! Matcher for the return_type of the expression (nullptr for ANY) + unique_ptr type; +}; + +//! The ExpressionEqualityMatcher matches on equality with another (given) expression +class ExpressionEqualityMatcher : public ExpressionMatcher { +public: + explicit ExpressionEqualityMatcher(Expression &expr) + : ExpressionMatcher(ExpressionClass::INVALID), expression(expr) { + } + + bool Match(Expression &expr, vector> &bindings) override; + +private: + const Expression &expression; +}; + +class ConstantExpressionMatcher : public ExpressionMatcher { +public: + ConstantExpressionMatcher() : ExpressionMatcher(ExpressionClass::BOUND_CONSTANT) { + } +}; + +class CaseExpressionMatcher : public ExpressionMatcher { +public: + CaseExpressionMatcher() : ExpressionMatcher(ExpressionClass::BOUND_CASE) { + } + + bool Match(Expression &expr_, vector> &bindings) override; +}; + +class ComparisonExpressionMatcher : public ExpressionMatcher { +public: + ComparisonExpressionMatcher() + : ExpressionMatcher(ExpressionClass::BOUND_COMPARISON), policy(SetMatcher::Policy::INVALID) { + } + //! The matchers for the child expressions + vector> matchers; + //! The set matcher matching policy to use + SetMatcher::Policy policy; + + bool Match(Expression &expr_, vector> &bindings) override; +}; + +class CastExpressionMatcher : public ExpressionMatcher { +public: + CastExpressionMatcher() : ExpressionMatcher(ExpressionClass::BOUND_CAST) { + } + //! The matcher for the child expressions + unique_ptr matcher; + + bool Match(Expression &expr_, vector> &bindings) override; +}; + +class InClauseExpressionMatcher : public ExpressionMatcher { +public: + InClauseExpressionMatcher() : ExpressionMatcher(ExpressionClass::BOUND_OPERATOR) { + } + //! The matchers for the child expressions + vector> matchers; + //! The set matcher matching policy to use + SetMatcher::Policy policy; + + bool Match(Expression &expr_, vector> &bindings) override; +}; + +class ConjunctionExpressionMatcher : public ExpressionMatcher { +public: + ConjunctionExpressionMatcher() + : ExpressionMatcher(ExpressionClass::BOUND_CONJUNCTION), policy(SetMatcher::Policy::INVALID) { + } + //! The matchers for the child expressions + vector> matchers; + //! The set matcher matching policy to use + SetMatcher::Policy policy; + + bool Match(Expression &expr_, vector> &bindings) override; +}; + +class FunctionExpressionMatcher : public ExpressionMatcher { +public: + FunctionExpressionMatcher() : ExpressionMatcher(ExpressionClass::BOUND_FUNCTION) { + } + //! The matchers for the child expressions + vector> matchers; + //! The set matcher matching policy to use + SetMatcher::Policy policy; + //! The function name to match + unique_ptr function; + + bool Match(Expression &expr_, vector> &bindings) override; +}; + +//! The FoldableConstant matcher matches any expression that is foldable into a constant by the ExpressionExecutor (i.e. +//! scalar but not aggregate/window/parameter) +class FoldableConstantMatcher : public ExpressionMatcher { +public: + FoldableConstantMatcher() : ExpressionMatcher(ExpressionClass::INVALID) { + } + + bool Match(Expression &expr, vector> &bindings) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/expression_type_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/expression_type_matcher.hpp new file mode 100644 index 00000000..2f884577 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/expression_type_matcher.hpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/matcher/expression_type_matcher.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/vector.hpp" + +#include + +namespace duckdb { + +//! The ExpressionTypeMatcher class contains a set of matchers that can be used to pattern match ExpressionTypes +class ExpressionTypeMatcher { +public: + virtual ~ExpressionTypeMatcher() { + } + + virtual bool Match(ExpressionType type) = 0; +}; + +//! The SpecificExpressionTypeMatcher class matches a single specified Expression type +class SpecificExpressionTypeMatcher : public ExpressionTypeMatcher { +public: + explicit SpecificExpressionTypeMatcher(ExpressionType type) : type(type) { + } + + bool Match(ExpressionType type) override { + return type == this->type; + } + +private: + ExpressionType type; +}; + +//! The ManyExpressionTypeMatcher class matches a set of ExpressionTypes +class ManyExpressionTypeMatcher : public ExpressionTypeMatcher { +public: + explicit ManyExpressionTypeMatcher(vector types) : types(std::move(types)) { + } + + bool Match(ExpressionType type) override { + return std::find(types.begin(), types.end(), type) != types.end(); + } + +private: + vector types; +}; + +//! The ComparisonExpressionTypeMatcher class matches a comparison expression +class ComparisonExpressionTypeMatcher : public ExpressionTypeMatcher { +public: + bool Match(ExpressionType type) override { + return type == ExpressionType::COMPARE_EQUAL || type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || + type == ExpressionType::COMPARE_LESSTHANOREQUALTO || type == ExpressionType::COMPARE_LESSTHAN || + type == ExpressionType::COMPARE_GREATERTHAN; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/function_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/function_matcher.hpp new file mode 100644 index 00000000..6acf2bbf --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/function_matcher.hpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/matcher/function_matcher.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/unordered_set.hpp" +#include + +namespace duckdb { + +//! The FunctionMatcher class contains a set of matchers that can be used to pattern match specific functions +class FunctionMatcher { +public: + virtual ~FunctionMatcher() { + } + + virtual bool Match(string &name) = 0; + + static bool Match(unique_ptr &matcher, string &name) { + if (!matcher) { + return true; + } + return matcher->Match(name); + } +}; + +//! The SpecificFunctionMatcher class matches a single specified function name +class SpecificFunctionMatcher : public FunctionMatcher { +public: + explicit SpecificFunctionMatcher(string name) : name(std::move(name)) { + } + + bool Match(string &name) override { + return name == this->name; + } + +private: + string name; +}; + +//! The ManyFunctionMatcher class matches a set of functions +class ManyFunctionMatcher : public FunctionMatcher { +public: + explicit ManyFunctionMatcher(unordered_set names) : names(std::move(names)) { + } + + bool Match(string &name) override { + return names.find(name) != names.end(); + } + +private: + unordered_set names; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/logical_operator_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/logical_operator_matcher.hpp new file mode 100644 index 00000000..fe6b548f --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/logical_operator_matcher.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/matcher/logical_operator_matcher.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/logical_operator_type.hpp" + +namespace duckdb { + +//! The LogicalOperatorMatcher class contains a set of matchers that can be used to match LogicalOperators +class LogicalOperatorMatcher { +public: + virtual ~LogicalOperatorMatcher() { + } + + virtual bool Match(LogicalOperatorType type) = 0; +}; + +//! The SpecificLogicalTypeMatcher class matches only a single specified LogicalOperatorType +class SpecificLogicalTypeMatcher : public LogicalOperatorMatcher { +public: + explicit SpecificLogicalTypeMatcher(LogicalOperatorType type) : type(type) { + } + + bool Match(LogicalOperatorType type) override { + return type == this->type; + } + +private: + LogicalOperatorType type; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/set_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/set_matcher.hpp new file mode 100644 index 00000000..a709a1b5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/set_matcher.hpp @@ -0,0 +1,128 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/matcher/set_matcher.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +class SetMatcher { +public: + //! The policy used by the SetMatcher + enum class Policy { + //! All entries have to be matched, and the matches have to be ordered + ORDERED, + //! All entries have to be matched, but the order of the matches does not matter + UNORDERED, + //! Only some entries have to be matched, the order of the matches does not matter + SOME, + //! Only some entries have to be matched. The order of the matches does matter. + SOME_ORDERED, + //! Not initialized + INVALID + }; + + /* The double {{}} in the intializer for excluded_entries is intentional, workaround for bug in gcc-4.9 */ + template + static bool MatchRecursive(vector> &matchers, vector> &entries, + vector> &bindings, unordered_set excluded_entries, idx_t m_idx = 0) { + if (m_idx == matchers.size()) { + // matched all matchers! + return true; + } + // try to find a match for the current matcher (m_idx) + idx_t previous_binding_count = bindings.size(); + for (idx_t e_idx = 0; e_idx < entries.size(); e_idx++) { + // first check if this entry has already been matched + if (excluded_entries.find(e_idx) != excluded_entries.end()) { + // it has been matched: skip this entry + continue; + } + // otherwise check if the current matcher matches this entry + if (matchers[m_idx]->Match(entries[e_idx], bindings)) { + // m_idx matches e_idx! + // check if we can find a complete match for this path + // first add e_idx to the new set of excluded entries + unordered_set new_excluded_entries; + new_excluded_entries = excluded_entries; + new_excluded_entries.insert(e_idx); + // then match the next matcher in the set + if (MatchRecursive(matchers, entries, bindings, new_excluded_entries, m_idx + 1)) { + // we found a match for this path! success + return true; + } else { + // we did not find a match! remove any bindings we added in the call to Match() + bindings.erase(bindings.begin() + previous_binding_count, bindings.end()); + } + } + } + return false; + } + + template + static bool Match(vector> &matchers, vector> &entries, + vector> &bindings, Policy policy) { + if (policy == Policy::ORDERED) { + // ordered policy, count has to match + if (matchers.size() != entries.size()) { + return false; + } + // now entries have to match in order + for (idx_t i = 0; i < matchers.size(); i++) { + if (!matchers[i]->Match(entries[i], bindings)) { + return false; + } + } + return true; + } else if (policy == Policy::SOME_ORDERED) { + if (entries.size() < matchers.size()) { + return false; + } + // now provided entries have to match in order + for (idx_t i = 0; i < matchers.size(); i++) { + if (!matchers[i]->Match(entries[i], bindings)) { + return false; + } + } + return true; + } else { + if (policy == Policy::UNORDERED && matchers.size() != entries.size()) { + // unordered policy, count does not match: no match + return false; + } else if (policy == Policy::SOME && matchers.size() > entries.size()) { + // some policy, every matcher has to match a unique entry + // this is not possible if there are more matchers than entries + return false; + } + // now perform the actual matching + // every matcher has to match a UNIQUE entry + // we perform this matching in a recursive way + unordered_set excluded_entries; + if (!MatchRecursive(matchers, entries, bindings, excluded_entries)) { + return false; + } + return true; + } + } + + template + static bool Match(vector> &matchers, vector> &entries, + vector> &bindings, Policy policy) { + // convert vector of unique_ptr to vector of normal pointers + vector> ptr_entries; + for (auto &entry : entries) { + ptr_entries.push_back(*entry); + } + // then just call the normal match function + return Match(matchers, ptr_entries, bindings, policy); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher.hpp new file mode 100644 index 00000000..5cd9d73e --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/matcher/type_matcher.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" + +namespace duckdb { + +//! The TypeMatcher class contains a set of matchers that can be used to pattern match TypeIds for Rules +class TypeMatcher { +public: + virtual ~TypeMatcher() { + } + + virtual bool Match(const LogicalType &type) = 0; +}; + +//! The SpecificTypeMatcher class matches only a single specified type +class SpecificTypeMatcher : public TypeMatcher { +public: + explicit SpecificTypeMatcher(LogicalType type) : type(type) { + } + + bool Match(const LogicalType &type_p) override { + return type_p == this->type; + } + +private: + LogicalType type; +}; + +//! The NumericTypeMatcher class matches any numeric type (DECIMAL, INTEGER, etc...) +class NumericTypeMatcher : public TypeMatcher { +public: + bool Match(const LogicalType &type) override { + return type.IsNumeric(); + } +}; + +//! The IntegerTypeMatcher class matches only integer types (INTEGER, SMALLINT, TINYINT, BIGINT) +class IntegerTypeMatcher : public TypeMatcher { +public: + bool Match(const LogicalType &type) override { + return type.IsIntegral(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher_id.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher_id.hpp new file mode 100644 index 00000000..9bd83f13 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher_id.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/matcher/type_matcher_id.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/optimizer/matcher/type_matcher.hpp" +namespace duckdb { + +//! The TypeMatcherId class contains a set of matchers that can be used to pattern match TypeIds for Rules +class TypeMatcherId : public TypeMatcher { +public: + explicit TypeMatcherId(LogicalTypeId type_id_p) : type_id(type_id_p) { + } + + bool Match(const LogicalType &type) override { + return type.id() == this->type_id; + } + +private: + LogicalTypeId type_id; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/optimizer.hpp new file mode 100644 index 00000000..312fcc49 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/optimizer.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/common/enums/optimizer_type.hpp" + +#include + +namespace duckdb { +class Binder; + +class Optimizer { +public: + Optimizer(Binder &binder, ClientContext &context); + + //! Optimize a plan by running specialized optimizers + unique_ptr Optimize(unique_ptr plan); + //! Return a reference to the client context of this optimizer + ClientContext &GetContext(); + + ClientContext &context; + Binder &binder; + ExpressionRewriter rewriter; + +private: + void RunOptimizer(OptimizerType type, const std::function &callback); + void Verify(LogicalOperator &op); + +private: + unique_ptr plan; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/optimizer_extension.hpp b/src/duckdb/src/include/duckdb/optimizer/optimizer_extension.hpp new file mode 100644 index 00000000..bd2b0c8f --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/optimizer_extension.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/optimizer_extension.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! The OptimizerExtensionInfo holds static information relevant to the optimizer extension +struct OptimizerExtensionInfo { + virtual ~OptimizerExtensionInfo() { + } +}; + +typedef void (*optimize_function_t)(ClientContext &context, OptimizerExtensionInfo *info, + unique_ptr &plan); + +class OptimizerExtension { +public: + //! The parse function of the parser extension. + //! Takes a query string as input and returns ParserExtensionParseData (on success) or an error + optimize_function_t optimize_function; + + //! Additional parser info passed to the parse function + shared_ptr optimizer_info; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/regex_range_filter.hpp b/src/duckdb/src/include/duckdb/optimizer/regex_range_filter.hpp new file mode 100644 index 00000000..36714f26 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/regex_range_filter.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/regex_range_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +class Optimizer; + +class RegexRangeFilter { +public: + RegexRangeFilter() { + } + //! Perform filter pushdown + unique_ptr Rewrite(unique_ptr op); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/remove_duplicate_groups.hpp b/src/duckdb/src/include/duckdb/optimizer/remove_duplicate_groups.hpp new file mode 100644 index 00000000..2abf8bb3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/remove_duplicate_groups.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/remove_duplicate_groups.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" + +namespace duckdb { + +class BoundColumnRefExpression; + +//! The RemoveDuplicateGroups optimizer traverses the logical operator tree and removes any duplicate aggregate groups +//! Duplicate groups may be introduced when joins columns are removed, e.g., by Deliminator or RemoveUnusedColumns +class RemoveDuplicateGroups : public LogicalOperatorVisitor { +public: + RemoveDuplicateGroups() { + } + + void VisitOperator(LogicalOperator &op) override; + +private: + void VisitAggregate(LogicalAggregate &aggr); + +protected: + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + +private: + //! The map of column references + column_binding_map_t>> column_references; + //! Stored expressions (kept around so we don't have dangling pointers) + vector> stored_expressions; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp b/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp new file mode 100644 index 00000000..629efc5d --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/remove_unused_columns.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +class Binder; +class BoundColumnRefExpression; +class ClientContext; + +//! The RemoveUnusedColumns optimizer traverses the logical operator tree and removes any columns that are not required +class RemoveUnusedColumns : public LogicalOperatorVisitor { +public: + RemoveUnusedColumns(Binder &binder, ClientContext &context, bool is_root = false) + : binder(binder), context(context), everything_referenced(is_root) { + } + + void VisitOperator(LogicalOperator &op) override; + +protected: + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + unique_ptr VisitReplace(BoundReferenceExpression &expr, unique_ptr *expr_ptr) override; + +private: + Binder &binder; + ClientContext &context; + //! Whether or not all the columns are referenced. This happens in the case of the root expression (because the + //! output implicitly refers all the columns below it) + bool everything_referenced; + //! The map of column references + column_binding_map_t> column_references; + +private: + template + void ClearUnusedExpressions(vector &list, idx_t table_idx, bool replace = true); + + //! Perform a replacement of the ColumnBinding, iterating over all the currently found column references and + //! replacing the bindings + void ReplaceBinding(ColumnBinding current_binding, ColumnBinding new_binding); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule.hpp b/src/duckdb/src/include/duckdb/optimizer/rule.hpp new file mode 100644 index 00000000..a04549fa --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/optimizer/matcher/logical_operator_matcher.hpp" + +namespace duckdb { +class ExpressionRewriter; + +class Rule { +public: + explicit Rule(ExpressionRewriter &rewriter) : rewriter(rewriter) { + } + virtual ~Rule() { + } + + //! The expression rewriter this rule belongs to + ExpressionRewriter &rewriter; + //! The root + unique_ptr logical_root; + //! The expression matcher of the rule + unique_ptr root; + + ClientContext &GetContext() const; + virtual unique_ptr Apply(LogicalOperator &op, vector> &bindings, + bool &fixed_point, bool is_root) = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/arithmetic_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/arithmetic_simplification.hpp new file mode 100644 index 00000000..cea57a9f --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/arithmetic_simplification.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/arithmetic_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The Arithmetic Simplification rule applies arithmetic expressions to which the answer is known (e.g. X + 0 => X, X * +// 0 => 0) +class ArithmeticSimplificationRule : public Rule { +public: + explicit ArithmeticSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/case_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/case_simplification.hpp new file mode 100644 index 00000000..5c99d3e0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/case_simplification.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/case_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The Case Simplification rule rewrites cases with a constant check (i.e. [CASE WHEN 1=1 THEN x ELSE y END] => x) +class CaseSimplificationRule : public Rule { +public: + explicit CaseSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/comparison_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/comparison_simplification.hpp new file mode 100644 index 00000000..b1e477d8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/comparison_simplification.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/comparison_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The Comparison Simplification rule rewrites comparisons with a constant NULL (i.e. [x = NULL] => [NULL]) +class ComparisonSimplificationRule : public Rule { +public: + explicit ComparisonSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/conjunction_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/conjunction_simplification.hpp new file mode 100644 index 00000000..7f22821e --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/conjunction_simplification.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/conjunction_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The Conjunction Simplification rule rewrites conjunctions with a constant +class ConjunctionSimplificationRule : public Rule { +public: + explicit ConjunctionSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; + + unique_ptr RemoveExpression(BoundConjunctionExpression &conj, const Expression &expr); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/constant_folding.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/constant_folding.hpp new file mode 100644 index 00000000..aba6ce9c --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/constant_folding.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/constant_folding.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// Fold any constant scalar expressions into a single constant (i.e. [2 + 2] => [4], [2 = 2] => [True], etc...) +class ConstantFoldingRule : public Rule { +public: + explicit ConstantFoldingRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/date_part_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/date_part_simplification.hpp new file mode 100644 index 00000000..e328acd2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/date_part_simplification.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/date_part_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The DatePart Simplification rule rewrites date_part with a constant specifier into a specialized function (e.g. +// date_part('year', x) => year(x)) +class DatePartSimplificationRule : public Rule { +public: + explicit DatePartSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/distributivity.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/distributivity.hpp new file mode 100644 index 00000000..b0c29b7c --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/distributivity.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/distributivity.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" +#include "duckdb/parser/expression_map.hpp" + +namespace duckdb { + +// (X AND B) OR (X AND C) OR (X AND D) = X AND (B OR C OR D) +class DistributivityRule : public Rule { +public: + explicit DistributivityRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; + +private: + void AddExpressionSet(Expression &expr, expression_set_t &set); + unique_ptr ExtractExpression(BoundConjunctionExpression &conj, idx_t idx, Expression &expr); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/empty_needle_removal.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/empty_needle_removal.hpp new file mode 100644 index 00000000..3ce21e77 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/empty_needle_removal.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/empty_needle_removal.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The Empty_needle_removal Optimization rule folds some foldable ConstantExpression +//(e.g.: PREFIX('xyz', '') is TRUE, PREFIX(NULL, '') is NULL, so rewrite PREFIX(x, '') to TRUE_OR_NULL(x) +class EmptyNeedleRemovalRule : public Rule { +public: + explicit EmptyNeedleRemovalRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/enum_comparison.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/enum_comparison.hpp new file mode 100644 index 00000000..10d89330 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/enum_comparison.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/enum_comparison.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The Enum Comparison rule rewrites cases where two Enums are compared on an equality check +class EnumComparisonRule : public Rule { +public: + explicit EnumComparisonRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/equal_or_null_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/equal_or_null_simplification.hpp new file mode 100644 index 00000000..f80c1dfe --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/equal_or_null_simplification.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/equal_or_null_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// Rewrite +// a=b OR (a IS NULL AND b IS NULL) to a IS NOT DISTINCT FROM b +class EqualOrNullSimplification : public Rule { +public: + explicit EqualOrNullSimplification(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/in_clause_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/in_clause_simplification.hpp new file mode 100644 index 00000000..0afabeae --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/in_clause_simplification.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/in_clause_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The in clause simplification rule rewrites cases where left is a column ref with a cast and right are constant values +class InClauseSimplificationRule : public Rule { +public: + explicit InClauseSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/like_optimizations.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/like_optimizations.hpp new file mode 100644 index 00000000..c2bcf6a7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/like_optimizations.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/like_optimizations.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" +#include "duckdb/function/scalar/string_functions.hpp" + +namespace duckdb { + +// The Like Optimization rule rewrites LIKE to optimized scalar functions (e.g.: prefix, suffix, and contains) +class LikeOptimizationRule : public Rule { +public: + explicit LikeOptimizationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; + + unique_ptr ApplyRule(BoundFunctionExpression &expr, ScalarFunction function, string pattern, + bool is_not_like); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp new file mode 100644 index 00000000..e10f2727 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp @@ -0,0 +1,13 @@ +#include "duckdb/optimizer/rule/arithmetic_simplification.hpp" +#include "duckdb/optimizer/rule/case_simplification.hpp" +#include "duckdb/optimizer/rule/comparison_simplification.hpp" +#include "duckdb/optimizer/rule/conjunction_simplification.hpp" +#include "duckdb/optimizer/rule/constant_folding.hpp" +#include "duckdb/optimizer/rule/date_part_simplification.hpp" +#include "duckdb/optimizer/rule/distributivity.hpp" +#include "duckdb/optimizer/rule/empty_needle_removal.hpp" +#include "duckdb/optimizer/rule/like_optimizations.hpp" +#include "duckdb/optimizer/rule/move_constants.hpp" +#include "duckdb/optimizer/rule/enum_comparison.hpp" +#include "duckdb/optimizer/rule/regex_optimizations.hpp" +#include "duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp" diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/move_constants.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/move_constants.hpp new file mode 100644 index 00000000..343fe590 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/move_constants.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/move_constants.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// The MoveConstantsRule moves constants to the same side of an expression, e.g. if we have an expression x + 1 = 5000 +// then this will turn it into x = 4999. +class MoveConstantsRule : public Rule { +public: + explicit MoveConstantsRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp new file mode 100644 index 00000000..0de9a95d --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" +#include "duckdb/parser/expression_map.hpp" + +namespace duckdb { + +class OrderedAggregateOptimizer : public Rule { +public: + explicit OrderedAggregateOptimizer(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/regex_optimizations.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/regex_optimizations.hpp new file mode 100644 index 00000000..1f914ec1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/regex_optimizations.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/like_optimizations.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" +#include "duckdb/function/scalar/string_functions.hpp" + +namespace duckdb { + +class RegexOptimizationRule : public Rule { +public: + explicit RegexOptimizationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; + + unique_ptr ApplyRule(BoundFunctionExpression *expr, ScalarFunction function, string pattern, + bool is_not_like); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/statistics_propagator.hpp b/src/duckdb/src/include/duckdb/optimizer/statistics_propagator.hpp new file mode 100644 index 00000000..75699099 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/statistics_propagator.hpp @@ -0,0 +1,116 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/statistics_propagator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/planner/bound_tokens.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/logical_tokens.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/node_statistics.hpp" + +namespace duckdb { + +class Optimizer; +class ClientContext; +class LogicalOperator; +class TableFilter; +struct BoundOrderByNode; + +class StatisticsPropagator { +public: + explicit StatisticsPropagator(Optimizer &optimizer); + + unique_ptr PropagateStatistics(unique_ptr &node_ptr); + + column_binding_map_t> GetStatisticsMap() { + return std::move(statistics_map); + } + +private: + //! Propagate statistics through an operator + unique_ptr PropagateStatistics(LogicalOperator &node, unique_ptr *node_ptr); + + unique_ptr PropagateStatistics(LogicalFilter &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalGet &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalJoin &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalPositionalJoin &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalProjection &op, unique_ptr *node_ptr); + void PropagateStatistics(LogicalComparisonJoin &op, unique_ptr *node_ptr); + void PropagateStatistics(LogicalAnyJoin &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalSetOperation &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalAggregate &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalCrossProduct &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalLimit &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalOrder &op, unique_ptr *node_ptr); + unique_ptr PropagateStatistics(LogicalWindow &op, unique_ptr *node_ptr); + + unique_ptr PropagateChildren(LogicalOperator &node, unique_ptr *node_ptr); + + //! Return statistics from a constant value + unique_ptr StatisticsFromValue(const Value &input); + //! Run a comparison with two sets of statistics, returns if the comparison will always returns true/false or not + FilterPropagateResult PropagateComparison(BaseStatistics &left, BaseStatistics &right, ExpressionType comparison); + + //! Update filter statistics from a filter with a constant + void UpdateFilterStatistics(BaseStatistics &input, ExpressionType comparison_type, const Value &constant); + //! Update statistics from a filter between two stats + void UpdateFilterStatistics(BaseStatistics &lstats, BaseStatistics &rstats, ExpressionType comparison_type); + //! Update filter statistics from a generic comparison + void UpdateFilterStatistics(Expression &left, Expression &right, ExpressionType comparison_type); + //! Update filter statistics from an expression + void UpdateFilterStatistics(Expression &condition); + //! Set the statistics of a specific column binding to not contain null values + void SetStatisticsNotNull(ColumnBinding binding); + + //! Run a comparison between the statistics and the table filter; returns the prune result + FilterPropagateResult PropagateTableFilter(BaseStatistics &stats, TableFilter &filter); + //! Update filter statistics from a TableFilter + void UpdateFilterStatistics(BaseStatistics &input, TableFilter &filter); + + //! Add cardinalities together (i.e. new max is stats.max + new_stats.max): used for union + void AddCardinalities(unique_ptr &stats, NodeStatistics &new_stats); + //! Multiply the cardinalities together (i.e. new max cardinality is stats.max * new_stats.max): used for + //! joins/cross products + void MultiplyCardinalities(unique_ptr &stats, NodeStatistics &new_stats); + //! Creates and pushes down a filter based on join statistics + void CreateFilterFromJoinStats(unique_ptr &child, unique_ptr &expr, + const BaseStatistics &stats_before, const BaseStatistics &stats_after); + + unique_ptr PropagateExpression(unique_ptr &expr); + unique_ptr PropagateExpression(Expression &expr, unique_ptr *expr_ptr); + + unique_ptr PropagateExpression(BoundAggregateExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundBetweenExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundCaseExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundCastExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundConjunctionExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundFunctionExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundComparisonExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundConstantExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundColumnRefExpression &expr, unique_ptr *expr_ptr); + unique_ptr PropagateExpression(BoundOperatorExpression &expr, unique_ptr *expr_ptr); + + void ReplaceWithEmptyResult(unique_ptr &node); + + bool ExpressionIsConstant(Expression &expr, const Value &val); + bool ExpressionIsConstantOrNull(Expression &expr, const Value &val); + +private: + Optimizer &optimizer; + ClientContext &context; + //! The map of ColumnBinding -> statistics for the various nodes + column_binding_map_t> statistics_map; + //! Node stats for the current node + unique_ptr node_stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/topn_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/topn_optimizer.hpp new file mode 100644 index 00000000..94ebaed2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/topn_optimizer.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/topn_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { +class LogicalOperator; +class Optimizer; + +class TopN { +public: + //! Optimize ORDER BY + LIMIT to TopN + unique_ptr Optimize(unique_ptr op); + //! Whether we can perform the optimization on this operator + static bool CanOptimize(LogicalOperator &op); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp b/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp new file mode 100644 index 00000000..d51a0080 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/unnest_rewriter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/common/pair.hpp" + +namespace duckdb { + +class Optimizer; + +struct ReplaceBinding { + ReplaceBinding() {}; + ReplaceBinding(ColumnBinding old_binding, ColumnBinding new_binding) + : old_binding(old_binding), new_binding(new_binding) { + } + ColumnBinding old_binding; + ColumnBinding new_binding; +}; + +struct LHSBinding { + LHSBinding() {}; + LHSBinding(ColumnBinding binding, LogicalType type) : binding(binding), type(type) { + } + ColumnBinding binding; + LogicalType type; + string alias; +}; + +//! The UnnestRewriterPlanUpdater updates column bindings after changing the operator plan +class UnnestRewriterPlanUpdater : LogicalOperatorVisitor { +public: + UnnestRewriterPlanUpdater() { + } + //! Update each operator of the plan after moving an UNNEST into a projection + void VisitOperator(LogicalOperator &op) override; + //! Visit an expression and update its column bindings after moving and UNNEST into a projection + void VisitExpression(unique_ptr *expression) override; + + //! Contains all bindings that need to be updated + vector replace_bindings; + //! Stores the table index of the former child of the LOGICAL_UNNEST + idx_t overwritten_tbl_idx; +}; + +//! The UnnestRewriter optimizer traverses the logical operator tree and rewrites duplicate +//! eliminated joins that contain UNNESTs by moving the UNNESTs into the projection of +//! the SELECT +class UnnestRewriter { +public: + UnnestRewriter() { + } + //! Rewrite duplicate eliminated joins with UNNESTs + unique_ptr Optimize(unique_ptr op); + +private: + //! Find delim joins that contain an UNNEST + void FindCandidates(unique_ptr *op_ptr, vector *> &candidates); + //! Rewrite a delim join that contains an UNNEST + bool RewriteCandidate(unique_ptr *candidate); + //! Update the bindings of the RHS sequence of LOGICAL_PROJECTION(s) + void UpdateRHSBindings(unique_ptr *plan_ptr, unique_ptr *candidate, + UnnestRewriterPlanUpdater &updater); + //! Update the bindings of the BOUND_UNNEST expression of the LOGICAL_UNNEST + void UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, unique_ptr *candidate); + + //! Store all delim columns of the delim join + void GetDelimColumns(LogicalOperator &op); + //! Store all LHS expressions of the LOGICAL_PROJECTION + void GetLHSExpressions(LogicalOperator &op); + + //! Keep track of the delim columns to find the correct UNNEST column + vector delim_columns; + //! Store the column bindings of the LHS child of the LOGICAL_DELIM_JOIN + vector lhs_bindings; + //! Stores the table index of the former child of the LOGICAL_UNNEST + idx_t overwritten_tbl_idx; + //! The number of distinct columns to unnest + idx_t distinct_unnest_count; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/base_pipeline_event.hpp b/src/duckdb/src/include/duckdb/parallel/base_pipeline_event.hpp new file mode 100644 index 00000000..a4d909eb --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/base_pipeline_event.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/base_pipeline_event.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/pipeline.hpp" + +namespace duckdb { + +//! A BasePipelineEvent is used as the basis of any event that belongs to a specific pipeline +class BasePipelineEvent : public Event { +public: + explicit BasePipelineEvent(shared_ptr pipeline); + explicit BasePipelineEvent(Pipeline &pipeline); + + void PrintPipeline() override { + pipeline->Print(); + } + + //! The pipeline that this event belongs to + shared_ptr pipeline; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/concurrentqueue.hpp b/src/duckdb/src/include/duckdb/parallel/concurrentqueue.hpp new file mode 100644 index 00000000..b83896b4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/concurrentqueue.hpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/concurrentqueue.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifndef DUCKDB_NO_THREADS +#include "concurrentqueue.h" +#else + +#include +#include +#include + +namespace duckdb_moodycamel { + +template +class ConcurrentQueue; +template +class BlockingConcurrentQueue; + +struct ProducerToken { + //! Constructor + template + explicit ProducerToken(ConcurrentQueue &); + //! Constructor + template + explicit ProducerToken(BlockingConcurrentQueue &); + //! Constructor + ProducerToken(ProducerToken &&) { + } + //! Is valid token? + inline bool valid() const { + return true; + } +}; + +template +class ConcurrentQueue { +private: + //! The queue + std::queue> q; + +public: + //! Constructor + ConcurrentQueue() = default; + //! Constructor + explicit ConcurrentQueue(size_t capacity) { + q.reserve(capacity); + } + + //! Enqueue item + template + bool enqueue(U &&item) { + q.push(std::forward(item)); + return true; + } + //! Try to dequeue an item + bool try_dequeue(T &item) { + if (q.empty()) { + return false; + } + item = std::move(q.front()); + q.pop(); + return true; + } +}; + +} // namespace duckdb_moodycamel + +#endif diff --git a/src/duckdb/src/include/duckdb/parallel/event.hpp b/src/duckdb/src/include/duckdb/parallel/event.hpp new file mode 100644 index 00000000..1cfee691 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/event.hpp @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/event.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +class Executor; +class Task; + +class Event : public std::enable_shared_from_this { +public: + explicit Event(Executor &executor); + virtual ~Event() = default; + +public: + virtual void Schedule() = 0; + //! Called right after the event is finished + virtual void FinishEvent() { + } + //! Called after the event is entirely finished + virtual void FinalizeFinish() { + } + + void FinishTask(); + void Finish(); + + void AddDependency(Event &event); + bool HasDependencies() const { + return total_dependencies != 0; + } + const vector &GetParentsVerification() const; + + void CompleteDependency(); + + void SetTasks(vector> tasks); + + void InsertEvent(shared_ptr replacement_event); + + bool IsFinished() const { + return finished; + } + + virtual void PrintPipeline() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + +protected: + Executor &executor; + //! The current threads working on the event + atomic finished_tasks; + //! The maximum amount of threads that can work on the event + atomic total_tasks; + + //! The amount of completed dependencies + //! The event can only be started after the dependencies have finished executing + atomic finished_dependencies; + //! The total amount of dependencies + idx_t total_dependencies; + + //! The events that depend on this event to run + vector> parents; + //! Raw pointers to the parents (used for verification only) + vector parents_raw; + + //! Whether or not the event is finished executing + atomic finished; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/interrupt.hpp b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp new file mode 100644 index 00000000..5805e158 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// src/include/duckdb/parallel/interrupt.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/parallel/task.hpp" +#include +#include + +namespace duckdb { + +//! InterruptMode specifies how operators should block/unblock, note that this will happen transparently to the +//! operator, as the operator only needs to return a BLOCKED result and call the callback using the InterruptState. +//! NO_INTERRUPTS: No blocking mode is specified, an error will be thrown when the operator blocks. Should only be used +//! when manually calling operators of which is known they will never block. +//! TASK: A weak pointer to a task is provided. On the callback, this task will be signalled. If the Task has +//! been deleted, this callback becomes a NOP. This is the preferred way to await blocked pipelines. +//! BLOCKING: The caller has blocked awaiting some synchronization primitive to wait for the callback. +enum class InterruptMode : uint8_t { NO_INTERRUPTS, TASK, BLOCKING }; + +//! Synchronization primitive used to await a callback in InterruptMode::BLOCKING. +struct InterruptDoneSignalState { + //! Called by the callback to signal the interrupt is over + void Signal(); + //! Await the callback signalling the interrupt is over + void Await(); + +protected: + mutex lock; + std::condition_variable cv; + bool done = false; +}; + +//! State required to make the callback after some asynchronous operation within an operator source / sink. +class InterruptState { +public: + //! Default interrupt state will be set to InterruptMode::NO_INTERRUPTS and throw an error on use of Callback() + InterruptState(); + //! Register the task to be interrupted and set mode to InterruptMode::TASK, the preferred way to handle interrupts + InterruptState(weak_ptr task); + //! Register signal state and set mode to InterruptMode::BLOCKING, used for code paths without Task. + InterruptState(weak_ptr done_signal); + + //! Perform the callback to indicate the Interrupt is over + DUCKDB_API void Callback() const; + +protected: + //! Current interrupt mode + InterruptMode mode; + //! Task ptr for InterruptMode::TASK + weak_ptr current_task; + //! Signal state for InterruptMode::BLOCKING + weak_ptr signal_state; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/meta_pipeline.hpp b/src/duckdb/src/include/duckdb/parallel/meta_pipeline.hpp new file mode 100644 index 00000000..82b92ba5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/meta_pipeline.hpp @@ -0,0 +1,110 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/meta_pipeline.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" + +namespace duckdb { + +class PhysicalRecursiveCTE; + +struct PipelineFinishGroup { + explicit PipelineFinishGroup(Pipeline *group_base_p) : group_base(group_base_p) { + } + Pipeline *group_base; + unordered_set group_members; +}; + +//! MetaPipeline represents a set of pipelines that all have the same sink +class MetaPipeline : public std::enable_shared_from_this { + //! We follow these rules when building: + //! 1. For joins, build out the blocking side before going down the probe side + //! - The current streaming pipeline will have a dependency on it (dependency across MetaPipelines) + //! - Unions of this streaming pipeline will automatically inherit this dependency + //! 2. Build child pipelines last (e.g., Hash Join becomes source after probe is done: scan HT for FULL OUTER JOIN) + //! - 'last' means after building out all other pipelines associated with this operator + //! - The child pipeline automatically has dependencies (within this MetaPipeline) on: + //! * The 'current' streaming pipeline + //! * And all pipelines that were added to the MetaPipeline after 'current' +public: + //! Create a MetaPipeline with the given sink + explicit MetaPipeline(Executor &executor, PipelineBuildState &state, PhysicalOperator *sink); + +public: + //! Get the Executor for this MetaPipeline + Executor &GetExecutor() const; + //! Get the PipelineBuildState for this MetaPipeline + PipelineBuildState &GetState() const; + //! Get the sink operator for this MetaPipeline + optional_ptr GetSink() const; + + //! Get the initial pipeline of this MetaPipeline + shared_ptr &GetBasePipeline(); + //! Get the pipelines of this MetaPipeline + void GetPipelines(vector> &result, bool recursive); + //! Get the MetaPipeline children of this MetaPipeline + void GetMetaPipelines(vector> &result, bool recursive, bool skip); + //! Get the dependencies (within this MetaPipeline) of the given Pipeline + const vector *GetDependencies(Pipeline *dependant) const; + //! Whether this MetaPipeline has a recursive CTE + bool HasRecursiveCTE() const; + //! Set the flag that this MetaPipeline is a recursive CTE pipeline + void SetRecursiveCTE(); + //! Assign a batch index to the given pipeline + void AssignNextBatchIndex(Pipeline *pipeline); + //! Let 'dependant' depend on all pipeline that were created since 'start', + //! where 'including' determines whether 'start' is added to the dependencies + void AddDependenciesFrom(Pipeline *dependant, Pipeline *start, bool including); + //! Make sure that the given pipeline has its own PipelineFinishEvent (e.g., for IEJoin - double Finalize) + void AddFinishEvent(Pipeline *pipeline); + //! Whether the pipeline needs its own PipelineFinishEvent + bool HasFinishEvent(Pipeline *pipeline) const; + //! Whether this pipeline is part of a PipelineFinishEvent + optional_ptr GetFinishGroup(Pipeline *pipeline) const; + +public: + //! Build the MetaPipeline with 'op' as the first operator (excl. the shared sink) + void Build(PhysicalOperator &op); + //! Ready all the pipelines (recursively) + void Ready(); + + //! Create an empty pipeline within this MetaPipeline + Pipeline *CreatePipeline(); + //! Create a union pipeline (clone of 'current') + Pipeline *CreateUnionPipeline(Pipeline ¤t, bool order_matters); + //! Create a child pipeline op 'current' starting at 'op', + //! where 'last_pipeline' is the last pipeline added before building out 'current' + void CreateChildPipeline(Pipeline ¤t, PhysicalOperator &op, Pipeline *last_pipeline); + //! Create a MetaPipeline child that 'current' depends on + MetaPipeline &CreateChildMetaPipeline(Pipeline ¤t, PhysicalOperator &op); + +private: + //! The executor for all MetaPipelines in the query plan + Executor &executor; + //! The PipelineBuildState for all MetaPipelines in the query plan + PipelineBuildState &state; + //! The sink of all pipelines within this MetaPipeline + optional_ptr sink; + //! Whether this MetaPipeline is a the recursive pipeline of a recursive CTE + bool recursive_cte; + //! All pipelines with a different source, but the same sink + vector> pipelines; + //! Dependencies within this MetaPipeline + unordered_map> dependencies; + //! Other MetaPipelines that this MetaPipeline depends on + vector> children; + //! Next batch index + idx_t next_batch_index; + //! Pipelines (other than the base pipeline) that need their own PipelineFinishEvent (e.g., for IEJoin) + unordered_set finish_pipelines; + //! Mapping from pipeline (e.g., child or union) to finish pipeline + unordered_map finish_map; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline.hpp new file mode 100644 index 00000000..27f9fa65 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/pipeline.hpp @@ -0,0 +1,139 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/pipeline.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/common/reference_map.hpp" + +namespace duckdb { + +class Executor; +class Event; +class MetaPipeline; + +class PipelineBuildState { +public: + //! How much to increment batch indexes when multiple pipelines share the same source + constexpr static idx_t BATCH_INCREMENT = 10000000000000; + +public: + //! Duplicate eliminated join scan dependencies + reference_map_t> delim_join_dependencies; + //! Materialized CTE scan dependencies + reference_map_t> cte_dependencies; + +public: + void SetPipelineSource(Pipeline &pipeline, PhysicalOperator &op); + void SetPipelineSink(Pipeline &pipeline, optional_ptr op, idx_t sink_pipeline_count); + void SetPipelineOperators(Pipeline &pipeline, vector> operators); + void AddPipelineOperator(Pipeline &pipeline, PhysicalOperator &op); + shared_ptr CreateChildPipeline(Executor &executor, Pipeline &pipeline, PhysicalOperator &op); + + optional_ptr GetPipelineSource(Pipeline &pipeline); + optional_ptr GetPipelineSink(Pipeline &pipeline); + vector> GetPipelineOperators(Pipeline &pipeline); +}; + +//! The Pipeline class represents an execution pipeline starting at a +class Pipeline : public std::enable_shared_from_this { + friend class Executor; + friend class PipelineExecutor; + friend class PipelineEvent; + friend class PipelineFinishEvent; + friend class PipelineBuildState; + friend class MetaPipeline; + +public: + explicit Pipeline(Executor &execution_context); + + Executor &executor; + +public: + ClientContext &GetClientContext(); + + void AddDependency(shared_ptr &pipeline); + + void Ready(); + void Reset(); + void ResetSink(); + void ResetSource(bool force); + void ClearSource(); + void Schedule(shared_ptr &event); + + string ToString() const; + void Print() const; + void PrintDependencies() const; + + //! Returns query progress + bool GetProgress(double ¤t_percentage, idx_t &estimated_cardinality); + + //! Returns a list of all operators (including source and sink) involved in this pipeline + vector> GetOperators(); + vector> GetOperators() const; + + optional_ptr GetSink() { + return sink; + } + + optional_ptr GetSource() { + return source; + } + + //! Returns whether any of the operators in the pipeline care about preserving order + bool IsOrderDependent() const; + + //! Registers a new batch index for a pipeline executor - returns the current minimum batch index + idx_t RegisterNewBatchIndex(); + + //! Updates the batch index of a pipeline (and returns the new minimum batch index) + idx_t UpdateBatchIndex(idx_t old_index, idx_t new_index); + +private: + //! Whether or not the pipeline has been readied + bool ready; + //! Whether or not the pipeline has been initialized + atomic initialized; + //! The source of this pipeline + optional_ptr source; + //! The chain of intermediate operators + vector> operators; + //! The sink (i.e. destination) for data; this is e.g. a hash table to-be-built + optional_ptr sink; + + //! The global source state + unique_ptr source_state; + + //! The parent pipelines (i.e. pipelines that are dependent on this pipeline to finish) + vector> parents; + //! The dependencies of this pipeline + vector> dependencies; + + //! The base batch index of this pipeline + idx_t base_batch_index = 0; + //! Lock for accessing the set of batch indexes + mutex batch_lock; + //! The set of batch indexes that are currently being processed + //! Despite batch indexes being unique - this is a multiset + //! The reason is that when we start a new pipeline we insert the current minimum batch index as a placeholder + //! Which leads to duplicate entries in the set of active batch indexes + multiset batch_indexes; + +private: + void ScheduleSequentialTask(shared_ptr &event); + bool LaunchScanTasks(shared_ptr &event, idx_t max_threads); + + bool ScheduleParallel(shared_ptr &event); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_complete_event.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_complete_event.hpp new file mode 100644 index 00000000..753aa851 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_complete_event.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/pipeline_complete_event.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/event.hpp" + +namespace duckdb { +class Executor; + +class PipelineCompleteEvent : public Event { +public: + PipelineCompleteEvent(Executor &executor, bool complete_pipeline_p); + + bool complete_pipeline; + +public: + void Schedule() override; + void FinalizeFinish() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_event.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_event.hpp new file mode 100644 index 00000000..7af51d48 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_event.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/pipeline_event.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/base_pipeline_event.hpp" + +namespace duckdb { + +//! A PipelineEvent is responsible for scheduling a pipeline +class PipelineEvent : public BasePipelineEvent { +public: + PipelineEvent(shared_ptr pipeline); + +public: + void Schedule() override; + void FinishEvent() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp new file mode 100644 index 00000000..6169425f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp @@ -0,0 +1,151 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/pipeline_executor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/common/stack.hpp" + +#include + +namespace duckdb { +class Executor; + +//! The result of executing a PipelineExecutor +enum class PipelineExecuteResult { + //! PipelineExecutor is fully executed: the source is completely exhausted + FINISHED, + //! PipelineExecutor is not yet fully executed and can be called again immediately + NOT_FINISHED, + //! The PipelineExecutor was interrupted and should not be called again until the interrupt is handled as specified + //! in the InterruptMode + INTERRUPTED +}; + +//! The Pipeline class represents an execution pipeline +class PipelineExecutor { +public: + PipelineExecutor(ClientContext &context, Pipeline &pipeline); + + //! Fully execute a pipeline with a source and a sink until the source is completely exhausted + PipelineExecuteResult Execute(); + //! Execute a pipeline with a source and a sink until finished, or until max_chunks were processed from the source + //! Returns true if execution is finished, false if Execute should be called again + PipelineExecuteResult Execute(idx_t max_chunks); + + //! Push a single input DataChunk into the pipeline. + //! Returns either OperatorResultType::NEED_MORE_INPUT or OperatorResultType::FINISHED + //! If OperatorResultType::FINISHED is returned, more input will not change the result anymore + OperatorResultType ExecutePush(DataChunk &input); + //! Called after depleting the source: finalizes the execution of this pipeline executor + //! This should only be called once per PipelineExecutor. + PipelineExecuteResult PushFinalize(); + + //! Initializes a chunk with the types that will flow out of ExecutePull + void InitializeChunk(DataChunk &chunk); + //! Execute a pipeline without a sink, and retrieve a single DataChunk + //! Returns an empty chunk when finished. + void ExecutePull(DataChunk &result); + //! Called after depleting the source using ExecutePull + //! This flushes profiler states + void PullFinalize(); + + //! Registers the task in the interrupt_state to allow Source/Sink operators to block the task + void SetTaskForInterrupts(weak_ptr current_task); + +private: + //! The pipeline to process + Pipeline &pipeline; + //! The thread context of this executor + ThreadContext thread; + //! The total execution context of this executor + ExecutionContext context; + + //! Intermediate chunks for the operators + vector> intermediate_chunks; + //! Intermediate states for the operators + vector> intermediate_states; + + //! The local source state + unique_ptr local_source_state; + //! The local sink state (if any) + unique_ptr local_sink_state; + //! The interrupt state, holding required information for sink/source operators to block + InterruptState interrupt_state; + + //! The final chunk used for moving data into the sink + DataChunk final_chunk; + + //! The operators that are not yet finished executing and have data remaining + //! If the stack of in_process_operators is empty, we fetch from the source instead + stack in_process_operators; + //! Whether or not the pipeline has been finalized (used for verification only) + bool finalized = false; + //! Whether or not the pipeline has finished processing + int32_t finished_processing_idx = -1; + //! Whether or not this pipeline requires keeping track of the batch index of the source + bool requires_batch_index = false; + + //! Source has indicated it is exhausted + bool exhausted_source = false; + //! Flushing of intermediate operators has started + bool started_flushing = false; + //! Flushing of caching operators is done + bool done_flushing = false; + + //! This flag is set when the pipeline gets interrupted by the Sink -> the final_chunk should be re-sink-ed. + bool remaining_sink_chunk = false; + + //! Current operator being flushed + idx_t flushing_idx; + //! Whether the current flushing_idx should be flushed: this needs to be stored to make flushing code re-entrant + bool should_flush_current_idx = true; + +private: + void StartOperator(PhysicalOperator &op); + void EndOperator(PhysicalOperator &op, optional_ptr chunk); + + //! Reset the operator index to the first operator + void GoToSource(idx_t ¤t_idx, idx_t initial_idx); + SourceResultType FetchFromSource(DataChunk &result); + + void FinishProcessing(int32_t operator_idx = -1); + bool IsFinished(); + + //! Wrappers for sink/source calls to respective operators + SourceResultType GetData(DataChunk &chunk, OperatorSourceInput &input); + SinkResultType Sink(DataChunk &chunk, OperatorSinkInput &input); + + OperatorResultType ExecutePushInternal(DataChunk &input, idx_t initial_idx = 0); + //! Pushes a chunk through the pipeline and returns a single result chunk + //! Returns whether or not a new input chunk is needed, or whether or not we are finished + OperatorResultType Execute(DataChunk &input, DataChunk &result, idx_t initial_index = 0); + + //! Tries to flush all state from intermediate operators. Will return true if all state is flushed, false in the + //! case of a blocked sink. + bool TryFlushCachingOperators(); + + static bool CanCacheType(const LogicalType &type); + void CacheChunk(DataChunk &input, idx_t operator_idx); + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + //! Debugging state: number of times blocked + int debug_blocked_sink_count = 0; + int debug_blocked_source_count = 0; + int debug_blocked_combine_count = 0; + //! Number of times the Sink/Source will block before actually returning data + int debug_blocked_target_count = 1; +#endif +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_finish_event.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_finish_event.hpp new file mode 100644 index 00000000..1a486a61 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_finish_event.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/pipeline_finish_event.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/base_pipeline_event.hpp" + +namespace duckdb { +class Executor; + +class PipelineFinishEvent : public BasePipelineEvent { +public: + explicit PipelineFinishEvent(shared_ptr pipeline); + +public: + void Schedule() override; + void FinishEvent() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_initialize_event.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_initialize_event.hpp new file mode 100644 index 00000000..664717dc --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_initialize_event.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/pipeline_finish_event.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/base_pipeline_event.hpp" + +namespace duckdb { + +class Executor; + +class PipelineInitializeEvent : public BasePipelineEvent { +public: + explicit PipelineInitializeEvent(shared_ptr pipeline); + +public: + void Schedule() override; + void FinishEvent() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/task.hpp b/src/duckdb/src/include/duckdb/parallel/task.hpp new file mode 100644 index 00000000..2245c74a --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/task.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/task.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { +class ClientContext; +class Executor; +class Task; +class DatabaseInstance; +struct ProducerToken; + +enum class TaskExecutionMode : uint8_t { PROCESS_ALL, PROCESS_PARTIAL }; + +enum class TaskExecutionResult : uint8_t { TASK_FINISHED, TASK_NOT_FINISHED, TASK_ERROR, TASK_BLOCKED }; + +//! Generic parallel task +class Task : public std::enable_shared_from_this { +public: + virtual ~Task() { + } + + //! Execute the task in the specified execution mode + //! If mode is PROCESS_ALL, Execute should always finish processing and return TASK_FINISHED + //! If mode is PROCESS_PARTIAL, Execute can return TASK_NOT_FINISHED, in which case Execute will be called again + //! In case of an error, TASK_ERROR is returned + //! In case the task has interrupted, BLOCKED is returned. + virtual TaskExecutionResult Execute(TaskExecutionMode mode) = 0; + + //! Descheduling a task ensures the task is not executed, but remains available for rescheduling as long as + //! required, generally until some code in an operator calls the InterruptState::Callback() method of a state of the + //! InterruptMode::TASK mode. + virtual void Deschedule() { + throw InternalException("Cannot deschedule task of base Task class"); + }; + + //! Ensures a task is rescheduled to the correct queue + virtual void Reschedule() { + throw InternalException("Cannot reschedule task of base Task class"); + } +}; + +//! Execute a task within an executor, including exception handling +//! This should be used within queries +class ExecutorTask : public Task { +public: + ExecutorTask(Executor &executor); + ExecutorTask(ClientContext &context); + virtual ~ExecutorTask(); + + void Deschedule() override; + void Reschedule() override; + + Executor &executor; + +public: + virtual TaskExecutionResult ExecuteTask(TaskExecutionMode mode) = 0; + TaskExecutionResult Execute(TaskExecutionMode mode) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/task_counter.hpp b/src/duckdb/src/include/duckdb/parallel/task_counter.hpp new file mode 100644 index 00000000..53683b5d --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/task_counter.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/task_counter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/task_scheduler.hpp" + +namespace duckdb { + +class TaskCounter { +public: + explicit TaskCounter(TaskScheduler &scheduler_p) + : scheduler(scheduler_p), token(scheduler_p.CreateProducer()), task_count(0), tasks_completed(0) { + } + + virtual void AddTask(shared_ptr task) { + ++task_count; + scheduler.ScheduleTask(*token, std::move(task)); + } + + virtual void FinishTask() const { + ++tasks_completed; + } + + virtual void Finish() { + while (tasks_completed < task_count) { + shared_ptr task; + if (scheduler.GetTaskFromProducer(*token, task)) { + task->Execute(); + task.reset(); + } + } + } + +private: + TaskScheduler &scheduler; + unique_ptr token; + size_t task_count; + mutable atomic tasks_completed; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp b/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp new file mode 100644 index 00000000..b54da51c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/task_scheduler.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/parallel/task.hpp" +#include "duckdb/common/atomic.hpp" + +namespace duckdb { + +struct ConcurrentQueue; +struct QueueProducerToken; +class ClientContext; +class DatabaseInstance; +class TaskScheduler; + +struct SchedulerThread; + +struct ProducerToken { + ProducerToken(TaskScheduler &scheduler, unique_ptr token); + ~ProducerToken(); + + TaskScheduler &scheduler; + unique_ptr token; + mutex producer_lock; +}; + +//! The TaskScheduler is responsible for managing tasks and threads +class TaskScheduler { + // timeout for semaphore wait, default 5ms + constexpr static int64_t TASK_TIMEOUT_USECS = 5000; + +public: + explicit TaskScheduler(DatabaseInstance &db); + ~TaskScheduler(); + + DUCKDB_API static TaskScheduler &GetScheduler(ClientContext &context); + DUCKDB_API static TaskScheduler &GetScheduler(DatabaseInstance &db); + + unique_ptr CreateProducer(); + //! Schedule a task to be executed by the task scheduler + void ScheduleTask(ProducerToken &producer, shared_ptr task); + //! Fetches a task from a specific producer, returns true if successful or false if no tasks were available + bool GetTaskFromProducer(ProducerToken &token, shared_ptr &task); + //! Run tasks forever until "marker" is set to false, "marker" must remain valid until the thread is joined + void ExecuteForever(atomic *marker); + //! Run tasks until `marker` is set to false, `max_tasks` have been completed, or until there are no more tasks + //! available. Returns the number of tasks that were completed. + idx_t ExecuteTasks(atomic *marker, idx_t max_tasks); + //! Run tasks until `max_tasks` have been completed, or until there are no more tasks available + void ExecuteTasks(idx_t max_tasks); + + //! Sets the amount of active threads executing tasks for the system; n-1 background threads will be launched. + //! The main thread will also be used for execution + void SetThreads(int32_t n); + //! Returns the number of threads + DUCKDB_API int32_t NumberOfThreads(); + + //! Send signals to n threads, signalling for them to wake up and attempt to execute a task + void Signal(idx_t n); + + //! Yield to other threads + static void YieldThread(); + + //! Set the allocator flush threshold + void SetAllocatorFlushTreshold(idx_t threshold); + +private: + void SetThreadsInternal(int32_t n); + +private: + DatabaseInstance &db; + //! The task queue + unique_ptr queue; + //! Lock for modifying the thread count + mutex thread_lock; + //! The active background threads of the task scheduler + vector> threads; + //! Markers used by the various threads, if the markers are set to "false" the thread execution is stopped + vector>> markers; + //! The threshold after which to flush the allocator after completing a task + atomic allocator_flush_threshold; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/thread_context.hpp b/src/duckdb/src/include/duckdb/parallel/thread_context.hpp new file mode 100644 index 00000000..90097e7f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/thread_context.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/thread_context.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/query_profiler.hpp" + +namespace duckdb { +class ClientContext; + +//! The ThreadContext holds thread-local info for parallel usage +class ThreadContext { +public: + explicit ThreadContext(ClientContext &context); + + //! The operator profiler for the individual thread context + OperatorProfiler profiler; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/base_expression.hpp b/src/duckdb/src/include/duckdb/parser/base_expression.hpp new file mode 100644 index 00000000..08c3e1db --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/base_expression.hpp @@ -0,0 +1,103 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/base_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +//! The BaseExpression class is a base class that can represent any expression +//! part of a SQL statement. +class BaseExpression { +public: + //! Create an Expression + BaseExpression(ExpressionType type, ExpressionClass expression_class) + : type(type), expression_class(expression_class) { + } + virtual ~BaseExpression() { + } + + //! Returns the type of the expression + ExpressionType GetExpressionType() const { + return type; + } + //! Returns the class of the expression + ExpressionClass GetExpressionClass() const { + return expression_class; + } + + //! Type of the expression + ExpressionType type; + //! The expression class of the node + ExpressionClass expression_class; + //! The alias of the expression, + string alias; + +public: + //! Returns true if this expression is an aggregate or not. + /*! + Examples: + + (1) SUM(a) + 1 -- True + + (2) a + 1 -- False + */ + virtual bool IsAggregate() const = 0; + //! Returns true if the expression has a window function or not + virtual bool IsWindow() const = 0; + //! Returns true if the query contains a subquery + virtual bool HasSubquery() const = 0; + //! Returns true if expression does not contain a group ref or col ref or parameter + virtual bool IsScalar() const = 0; + //! Returns true if the expression has a parameter + virtual bool HasParameter() const = 0; + + //! Get the name of the expression + virtual string GetName() const; + //! Convert the Expression to a String + virtual string ToString() const = 0; + //! Print the expression to stdout + void Print() const; + + //! Creates a hash value of this expression. It is important that if two expressions are identical (i.e. + //! Expression::Equals() returns true), that their hash value is identical as well. + virtual hash_t Hash() const = 0; + //! Returns true if this expression is equal to another expression + virtual bool Equals(const BaseExpression &other) const; + + static bool Equals(const BaseExpression &left, const BaseExpression &right) { + return left.Equals(right); + } + bool operator==(const BaseExpression &rhs) { + return Equals(rhs); + } + + virtual void Verify() const; + +public: + template + TARGET &Cast() { + if (expression_class != TARGET::TYPE) { + throw InternalException("Failed to cast expression to type - expression type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (expression_class != TARGET::TYPE) { + throw InternalException("Failed to cast expression to type - expression type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/column_definition.hpp b/src/duckdb/src/include/duckdb/parser/column_definition.hpp new file mode 100644 index 00000000..1510186a --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/column_definition.hpp @@ -0,0 +1,102 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/column_definition.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/enums/compression_type.hpp" +#include "duckdb/catalog/catalog_entry/table_column_type.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +struct RenameColumnInfo; +struct RenameTableInfo; + +class ColumnDefinition; + +//! A column of a table. +class ColumnDefinition { +public: + DUCKDB_API ColumnDefinition(string name, LogicalType type); + DUCKDB_API ColumnDefinition(string name, LogicalType type, unique_ptr expression, + TableColumnType category); + +public: + //! default_value + const unique_ptr &DefaultValue() const; + void SetDefaultValue(unique_ptr default_value); + + //! type + DUCKDB_API const LogicalType &Type() const; + LogicalType &TypeMutable(); + void SetType(const LogicalType &type); + + //! name + DUCKDB_API const string &Name() const; + void SetName(const string &name); + + //! compression_type + const duckdb::CompressionType &CompressionType() const; + void SetCompressionType(duckdb::CompressionType compression_type); + + //! storage_oid + const storage_t &StorageOid() const; + void SetStorageOid(storage_t storage_oid); + + LogicalIndex Logical() const; + PhysicalIndex Physical() const; + + //! oid + const column_t &Oid() const; + void SetOid(column_t oid); + + //! category + const TableColumnType &Category() const; + //! Whether this column is a Generated Column + bool Generated() const; + DUCKDB_API ColumnDefinition Copy() const; + + DUCKDB_API void Serialize(Serializer &serializer) const; + DUCKDB_API static ColumnDefinition Deserialize(Deserializer &deserializer); + + //===--------------------------------------------------------------------===// + // Generated Columns (VIRTUAL) + //===--------------------------------------------------------------------===// + + ParsedExpression &GeneratedExpressionMutable(); + const ParsedExpression &GeneratedExpression() const; + void SetGeneratedExpression(unique_ptr expression); + void ChangeGeneratedExpressionType(const LogicalType &type); + void GetListOfDependencies(vector &dependencies) const; + + string GetName() const; + + LogicalType GetType() const; + +private: + //! The name of the entry + string name; + //! The type of the column + LogicalType type; + //! Compression Type used for this column + duckdb::CompressionType compression_type = duckdb::CompressionType::COMPRESSION_AUTO; + //! The index of the column in the storage of the table + storage_t storage_oid = DConstants::INVALID_INDEX; + //! The index of the column in the table + idx_t oid = DConstants::INVALID_INDEX; + //! The category of the column + TableColumnType category = TableColumnType::STANDARD; + //! The default value of the column (for non-generated columns) + //! The generated column expression (for generated columns) + unique_ptr expression; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/column_list.hpp b/src/duckdb/src/include/duckdb/parser/column_list.hpp new file mode 100644 index 00000000..d692986b --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/column_list.hpp @@ -0,0 +1,129 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/column_list.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/column_definition.hpp" + +namespace duckdb { + +//! A set of column definitions +class ColumnList { +public: + class ColumnListIterator; + +public: + DUCKDB_API ColumnList(bool allow_duplicate_names = false); + DUCKDB_API explicit ColumnList(vector columns, bool allow_duplicate_names = false); + + DUCKDB_API void AddColumn(ColumnDefinition column); + void Finalize(); + + DUCKDB_API const ColumnDefinition &GetColumn(LogicalIndex index) const; + DUCKDB_API const ColumnDefinition &GetColumn(PhysicalIndex index) const; + DUCKDB_API const ColumnDefinition &GetColumn(const string &name) const; + DUCKDB_API ColumnDefinition &GetColumnMutable(LogicalIndex index); + DUCKDB_API ColumnDefinition &GetColumnMutable(PhysicalIndex index); + DUCKDB_API ColumnDefinition &GetColumnMutable(const string &name); + DUCKDB_API vector GetColumnNames() const; + DUCKDB_API vector GetColumnTypes() const; + + DUCKDB_API bool ColumnExists(const string &name) const; + + DUCKDB_API LogicalIndex GetColumnIndex(string &column_name) const; + DUCKDB_API PhysicalIndex LogicalToPhysical(LogicalIndex index) const; + DUCKDB_API LogicalIndex PhysicalToLogical(PhysicalIndex index) const; + + idx_t LogicalColumnCount() const { + return columns.size(); + } + idx_t PhysicalColumnCount() const { + return physical_columns.size(); + } + bool empty() const { + return columns.empty(); + } + + ColumnList Copy() const; + void Serialize(Serializer &serializer) const; + static ColumnList Deserialize(Deserializer &deserializer); + + DUCKDB_API ColumnListIterator Logical() const; + DUCKDB_API ColumnListIterator Physical() const; + + void SetAllowDuplicates(bool allow_duplicates) { + allow_duplicate_names = allow_duplicates; + } + +private: + vector columns; + //! A map of column name to column index + case_insensitive_map_t name_map; + //! The set of physical columns + vector physical_columns; + //! Allow duplicate names or not + bool allow_duplicate_names; + +private: + void AddToNameMap(ColumnDefinition &column); + +public: + // logical iterator + class ColumnListIterator { + public: + ColumnListIterator(const ColumnList &list, bool physical) : list(list), physical(physical) { + } + + private: + const ColumnList &list; + bool physical; + + private: + class ColumnLogicalIteratorInternal { + public: + ColumnLogicalIteratorInternal(const ColumnList &list, bool physical, idx_t pos, idx_t end) + : list(list), physical(physical), pos(pos), end(end) { + } + + const ColumnList &list; + bool physical; + idx_t pos; + idx_t end; + + public: + ColumnLogicalIteratorInternal &operator++() { + pos++; + return *this; + } + bool operator!=(const ColumnLogicalIteratorInternal &other) const { + return pos != other.pos || end != other.end || &list != &other.list; + } + const ColumnDefinition &operator*() const { + if (physical) { + return list.GetColumn(PhysicalIndex(pos)); + } else { + return list.GetColumn(LogicalIndex(pos)); + } + } + }; + + public: + idx_t Size() { + return physical ? list.PhysicalColumnCount() : list.LogicalColumnCount(); + } + + ColumnLogicalIteratorInternal begin() { + return ColumnLogicalIteratorInternal(list, physical, 0, Size()); + } + ColumnLogicalIteratorInternal end() { + return ColumnLogicalIteratorInternal(list, physical, Size(), Size()); + } + }; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp new file mode 100644 index 00000000..80d7c897 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/common_table_expression_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/common/enums/cte_materialize.hpp" + +namespace duckdb { + +class SelectStatement; + +struct CommonTableExpressionInfo { + vector aliases; + unique_ptr query; + CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + unique_ptr Copy(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraint.hpp new file mode 100644 index 00000000..2d19d5ec --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/constraint.hpp @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +class Serializer; +class Deserializer; + +//===--------------------------------------------------------------------===// +// Constraint Types +//===--------------------------------------------------------------------===// +enum class ConstraintType : uint8_t { + INVALID = 0, // invalid constraint type + NOT_NULL = 1, // NOT NULL constraint + CHECK = 2, // CHECK constraint + UNIQUE = 3, // UNIQUE constraint + FOREIGN_KEY = 4, // FOREIGN KEY constraint +}; + +enum class ForeignKeyType : uint8_t { + FK_TYPE_PRIMARY_KEY_TABLE = 0, // main table + FK_TYPE_FOREIGN_KEY_TABLE = 1, // referencing table + FK_TYPE_SELF_REFERENCE_TABLE = 2 // self refrencing table +}; + +struct ForeignKeyInfo { + ForeignKeyType type; + string schema; + //! if type is FK_TYPE_FOREIGN_KEY_TABLE, means main key table, if type is FK_TYPE_PRIMARY_KEY_TABLE, means foreign + //! key table + string table; + //! The set of main key table's column's index + vector pk_keys; + //! The set of foreign key table's column's index + vector fk_keys; +}; + +//! Constraint is the base class of any type of table constraint. +class Constraint { +public: + DUCKDB_API explicit Constraint(ConstraintType type); + DUCKDB_API virtual ~Constraint(); + + ConstraintType type; + +public: + DUCKDB_API virtual string ToString() const = 0; + DUCKDB_API void Print() const; + + DUCKDB_API virtual unique_ptr Copy() const = 0; + + DUCKDB_API virtual void Serialize(Serializer &serializer) const; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast constraint to type - constraint type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast constraint to type - constraint type mismatch"); + } + return reinterpret_cast(*this); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/constraints/check_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/check_constraint.hpp new file mode 100644 index 00000000..cba2a61f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/constraints/check_constraint.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/constraints/check_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +//! The CheckConstraint contains an expression that must evaluate to TRUE for +//! every row in a table +class CheckConstraint : public Constraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::CHECK; + +public: + DUCKDB_API explicit CheckConstraint(unique_ptr expression); + + unique_ptr expression; + +public: + DUCKDB_API string ToString() const override; + + DUCKDB_API unique_ptr Copy() const override; + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/constraints/foreign_key_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/foreign_key_constraint.hpp new file mode 100644 index 00000000..7601a232 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/constraints/foreign_key_constraint.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/constraints/foreign_key_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/constraint.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class ForeignKeyConstraint : public Constraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::FOREIGN_KEY; + +public: + DUCKDB_API ForeignKeyConstraint(vector pk_columns, vector fk_columns, ForeignKeyInfo info); + + //! The set of main key table's columns + vector pk_columns; + //! The set of foreign key table's columns + vector fk_columns; + ForeignKeyInfo info; + +public: + DUCKDB_API string ToString() const override; + + DUCKDB_API unique_ptr Copy() const override; + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); + +private: + ForeignKeyConstraint(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/constraints/list.hpp b/src/duckdb/src/include/duckdb/parser/constraints/list.hpp new file mode 100644 index 00000000..bec392cf --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/constraints/list.hpp @@ -0,0 +1,4 @@ +#include "duckdb/parser/constraints/check_constraint.hpp" +#include "duckdb/parser/constraints/not_null_constraint.hpp" +#include "duckdb/parser/constraints/unique_constraint.hpp" +#include "duckdb/parser/constraints/foreign_key_constraint.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/constraints/not_null_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/not_null_constraint.hpp new file mode 100644 index 00000000..d148065d --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/constraints/not_null_constraint.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/constraints/not_null_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/constraint.hpp" + +namespace duckdb { + +class NotNullConstraint : public Constraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::NOT_NULL; + +public: + DUCKDB_API explicit NotNullConstraint(LogicalIndex index); + DUCKDB_API ~NotNullConstraint() override; + + //! Column index this constraint pertains to + LogicalIndex index; + +public: + DUCKDB_API string ToString() const override; + + DUCKDB_API unique_ptr Copy() const override; + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp new file mode 100644 index 00000000..66073108 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/constraints/unique_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/constraint.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class UniqueConstraint : public Constraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::UNIQUE; + +public: + DUCKDB_API UniqueConstraint(LogicalIndex index, bool is_primary_key); + DUCKDB_API UniqueConstraint(vector columns, bool is_primary_key); + + //! The index of the column for which this constraint holds. Only used when the constraint relates to a single + //! column, equal to DConstants::INVALID_INDEX if not used + LogicalIndex index; + //! The set of columns for which this constraint holds by name. Only used when the index field is not used. + vector columns; + //! Whether or not this is a PRIMARY KEY constraint, or a UNIQUE constraint. + bool is_primary_key; + +public: + DUCKDB_API string ToString() const override; + + DUCKDB_API unique_ptr Copy() const override; + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); + +private: + UniqueConstraint(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/between_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/between_expression.hpp new file mode 100644 index 00000000..d8e47493 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/between_expression.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/between_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class BetweenExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BETWEEN; + +public: + DUCKDB_API BetweenExpression(unique_ptr input, unique_ptr lower, + unique_ptr upper); + + unique_ptr input; + unique_ptr lower; + unique_ptr upper; + +public: + string ToString() const override; + + static bool Equal(const BetweenExpression &a, const BetweenExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + static string ToString(const T &entry) { + return "(" + entry.input->ToString() + " BETWEEN " + entry.lower->ToString() + " AND " + + entry.upper->ToString() + ")"; + } + +private: + BetweenExpression(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/bound_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/bound_expression.hpp new file mode 100644 index 00000000..abbf57ef --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/bound_expression.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/bound_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! BoundExpression is an intermediate dummy class used by the binder. It is a ParsedExpression but holds an Expression. +//! It represents a successfully bound expression. It is used in the Binder to prevent re-binding of already bound parts +//! when dealing with subqueries. +class BoundExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_EXPRESSION; + +public: + BoundExpression(unique_ptr expr); + + unique_ptr expr; + +public: + static unique_ptr &GetExpression(ParsedExpression &expr); + + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + hash_t Hash() const override; + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/case_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/case_expression.hpp new file mode 100644 index 00000000..8478f05b --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/case_expression.hpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/case_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +struct CaseCheck { + unique_ptr when_expr; + unique_ptr then_expr; + + void Serialize(Serializer &serializer) const; + static CaseCheck Deserialize(Deserializer &deserializer); +}; + +//! The CaseExpression represents a CASE expression in the query +class CaseExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::CASE; + +public: + DUCKDB_API CaseExpression(); + + vector case_checks; + unique_ptr else_expr; + +public: + string ToString() const override; + + static bool Equal(const CaseExpression &a, const CaseExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + static string ToString(const T &entry) { + string case_str = "CASE "; + for (auto &check : entry.case_checks) { + case_str += " WHEN (" + check.when_expr->ToString() + ")"; + case_str += " THEN (" + check.then_expr->ToString() + ")"; + } + case_str += " ELSE " + entry.else_expr->ToString(); + case_str += " END"; + return case_str; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/cast_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/cast_expression.hpp new file mode 100644 index 00000000..7c48e0dd --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/cast_expression.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/cast_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +//! CastExpression represents a type cast from one SQL type to another SQL type +class CastExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::CAST; + +public: + DUCKDB_API CastExpression(LogicalType target, unique_ptr child, bool try_cast = false); + + //! The child of the cast expression + unique_ptr child; + //! The type to cast to + LogicalType cast_type; + //! Whether or not this is a try_cast expression + bool try_cast; + +public: + string ToString() const override; + + static bool Equal(const CastExpression &a, const CastExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + static string ToString(const T &entry) { + return (entry.try_cast ? "TRY_CAST(" : "CAST(") + entry.child->ToString() + " AS " + + entry.cast_type.ToString() + ")"; + } + +private: + CastExpression(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/collate_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/collate_expression.hpp new file mode 100644 index 00000000..6f3bcac6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/collate_expression.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/collate_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +//! CollateExpression represents a COLLATE statement +class CollateExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::COLLATE; + +public: + CollateExpression(string collation, unique_ptr child); + + //! The child of the cast expression + unique_ptr child; + //! The collation clause + string collation; + +public: + string ToString() const override; + + static bool Equal(const CollateExpression &a, const CollateExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + CollateExpression(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/columnref_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/columnref_expression.hpp new file mode 100644 index 00000000..f50717ef --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/columnref_expression.hpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/columnref_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +//! Represents a reference to a column from either the FROM clause or from an +//! alias +class ColumnRefExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::COLUMN_REF; + +public: + //! Specify both the column and table name + ColumnRefExpression(string column_name, string table_name); + //! Only specify the column name, the table name will be derived later + explicit ColumnRefExpression(string column_name); + //! Specify a set of names + explicit ColumnRefExpression(vector column_names); + + //! The stack of names in order of which they appear (column_names[0].column_names[1].column_names[2]....) + vector column_names; + +public: + bool IsQualified() const; + const string &GetColumnName() const; + const string &GetTableName() const; + bool IsScalar() const override { + return false; + } + + string GetName() const override; + string ToString() const override; + + static bool Equal(const ColumnRefExpression &a, const ColumnRefExpression &b); + hash_t Hash() const override; + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + ColumnRefExpression(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/comparison_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/comparison_expression.hpp new file mode 100644 index 00000000..9c9a82e0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/comparison_expression.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/comparison_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { +//! ComparisonExpression represents a boolean comparison (e.g. =, >=, <>). Always returns a boolean +//! and has two children. +class ComparisonExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::COMPARISON; + +public: + DUCKDB_API ComparisonExpression(ExpressionType type, unique_ptr left, + unique_ptr right); + + unique_ptr left; + unique_ptr right; + +public: + string ToString() const override; + + static bool Equal(const ComparisonExpression &a, const ComparisonExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + static string ToString(const T &entry) { + return StringUtil::Format("(%s %s %s)", entry.left->ToString(), ExpressionTypeToOperator(entry.type), + entry.right->ToString()); + } + +private: + explicit ComparisonExpression(ExpressionType type); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/conjunction_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/conjunction_expression.hpp new file mode 100644 index 00000000..748bcb01 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/conjunction_expression.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/conjunction_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +//! Represents a conjunction (AND/OR) +class ConjunctionExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::CONJUNCTION; + +public: + DUCKDB_API explicit ConjunctionExpression(ExpressionType type); + DUCKDB_API ConjunctionExpression(ExpressionType type, vector> children); + DUCKDB_API ConjunctionExpression(ExpressionType type, unique_ptr left, + unique_ptr right); + + vector> children; + +public: + void AddExpression(unique_ptr expr); + + string ToString() const override; + + static bool Equal(const ConjunctionExpression &a, const ConjunctionExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + static string ToString(const T &entry) { + string result = "(" + entry.children[0]->ToString(); + for (idx_t i = 1; i < entry.children.size(); i++) { + result += " " + ExpressionTypeToOperator(entry.type) + " " + entry.children[i]->ToString(); + } + return result + ")"; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/constant_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/constant_expression.hpp new file mode 100644 index 00000000..9da6085f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/constant_expression.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/constant_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/value.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +//! ConstantExpression represents a constant value in the query +class ConstantExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::CONSTANT; + +public: + DUCKDB_API explicit ConstantExpression(Value val); + + //! The constant value referenced + Value value; + +public: + string ToString() const override; + + static bool Equal(const ConstantExpression &a, const ConstantExpression &b); + hash_t Hash() const override; + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + ConstantExpression(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/default_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/default_expression.hpp new file mode 100644 index 00000000..e7bc89b7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/default_expression.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/default_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { +//! Represents the default value of a column +class DefaultExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::DEFAULT; + +public: + DefaultExpression(); + +public: + bool IsScalar() const override { + return false; + } + + string ToString() const override; + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp new file mode 100644 index 00000000..d11eb8b6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp @@ -0,0 +1,125 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/function_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/vector.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/result_modifier.hpp" + +namespace duckdb { +//! Represents a function call +class FunctionExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::FUNCTION; + +public: + DUCKDB_API FunctionExpression(string catalog_name, string schema_name, const string &function_name, + vector> children, + unique_ptr filter = nullptr, + unique_ptr order_bys = nullptr, bool distinct = false, + bool is_operator = false, bool export_state = false); + DUCKDB_API FunctionExpression(const string &function_name, vector> children, + unique_ptr filter = nullptr, + unique_ptr order_bys = nullptr, bool distinct = false, + bool is_operator = false, bool export_state = false); + + //! Catalog of the function + string catalog; + //! Schema of the function + string schema; + //! Function name + string function_name; + //! Whether or not the function is an operator, only used for rendering + bool is_operator; + //! List of arguments to the function + vector> children; + //! Whether or not the aggregate function is distinct, only used for aggregates + bool distinct; + //! Expression representing a filter, only used for aggregates + unique_ptr filter; + //! Modifier representing an ORDER BY, only used for aggregates + unique_ptr order_bys; + //! whether this function should export its state or not + bool export_state; + +public: + string ToString() const override; + + unique_ptr Copy() const override; + + static bool Equal(const FunctionExpression &a, const FunctionExpression &b); + hash_t Hash() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + void Verify() const override; + +public: + template + static string ToString(const T &entry, const string &schema, const string &function_name, bool is_operator = false, + bool distinct = false, BASE *filter = nullptr, ORDER_MODIFIER *order_bys = nullptr, + bool export_state = false, bool add_alias = false) { + if (is_operator) { + // built-in operator + D_ASSERT(!distinct); + if (entry.children.size() == 1) { + if (StringUtil::Contains(function_name, "__postfix")) { + return "((" + entry.children[0]->ToString() + ")" + + StringUtil::Replace(function_name, "__postfix", "") + ")"; + } else { + return function_name + "(" + entry.children[0]->ToString() + ")"; + } + } else if (entry.children.size() == 2) { + return StringUtil::Format("(%s %s %s)", entry.children[0]->ToString(), function_name, + entry.children[1]->ToString()); + } + } + // standard function call + string result = schema.empty() ? function_name : schema + "." + function_name; + result += "("; + if (distinct) { + result += "DISTINCT "; + } + result += StringUtil::Join(entry.children, entry.children.size(), ", ", [&](const unique_ptr &child) { + return child->alias.empty() || !add_alias + ? child->ToString() + : StringUtil::Format("%s := %s", SQLIdentifier(child->alias), child->ToString()); + }); + // ordered aggregate + if (order_bys && !order_bys->orders.empty()) { + if (entry.children.empty()) { + result += ") WITHIN GROUP ("; + } + result += " ORDER BY "; + for (idx_t i = 0; i < order_bys->orders.size(); i++) { + if (i > 0) { + result += ", "; + } + result += order_bys->orders[i].ToString(); + } + } + result += ")"; + + // filtered aggregate + if (filter) { + result += " FILTER (WHERE " + filter->ToString() + ")"; + } + + if (export_state) { + result += " EXPORT_STATE"; + } + + return result; + } + +private: + FunctionExpression(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/lambda_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/lambda_expression.hpp new file mode 100644 index 00000000..5eae6e94 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/lambda_expression.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/lambda_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +//! LambdaExpression represents either: +//! 1. A lambda operator that can be used for e.g. mapping an expression to a list +//! 2. An OperatorExpression with the "->" operator +//! Lambda expressions are written in the form of "params -> expr", e.g. "x -> x + 1" +class LambdaExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::LAMBDA; + +public: + LambdaExpression(unique_ptr lhs, unique_ptr expr); + + // we need the context to determine if this is a list of column references or an expression (for JSON) + unique_ptr lhs; + + vector> params; + unique_ptr expr; + +public: + string ToString() const override; + + static bool Equal(const LambdaExpression &a, const LambdaExpression &b); + hash_t Hash() const override; + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + LambdaExpression(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/list.hpp b/src/duckdb/src/include/duckdb/parser/expression/list.hpp new file mode 100644 index 00000000..7eb8bc8f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/list.hpp @@ -0,0 +1,18 @@ +#include "duckdb/parser/expression/between_expression.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/collate_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/default_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/lambda_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp new file mode 100644 index 00000000..5867860e --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp @@ -0,0 +1,130 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/operator_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/qualified_name.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" + +namespace duckdb { +//! Represents a built-in operator expression +class OperatorExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::OPERATOR; + +public: + DUCKDB_API explicit OperatorExpression(ExpressionType type, unique_ptr left = nullptr, + unique_ptr right = nullptr); + DUCKDB_API OperatorExpression(ExpressionType type, vector> children); + + vector> children; + +public: + string ToString() const override; + + static bool Equal(const OperatorExpression &a, const OperatorExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + static string ToString(const T &entry) { + auto op = ExpressionTypeToOperator(entry.type); + if (!op.empty()) { + // use the operator string to represent the operator + D_ASSERT(entry.children.size() == 2); + return entry.children[0]->ToString() + " " + op + " " + entry.children[1]->ToString(); + } + switch (entry.type) { + case ExpressionType::COMPARE_IN: + case ExpressionType::COMPARE_NOT_IN: { + string op_type = entry.type == ExpressionType::COMPARE_IN ? " IN " : " NOT IN "; + string in_child = entry.children[0]->ToString(); + string child_list = "("; + for (idx_t i = 1; i < entry.children.size(); i++) { + if (i > 1) { + child_list += ", "; + } + child_list += entry.children[i]->ToString(); + } + child_list += ")"; + return "(" + in_child + op_type + child_list + ")"; + } + case ExpressionType::OPERATOR_NOT: { + string result = "("; + result += ExpressionTypeToString(entry.type); + result += " "; + result += StringUtil::Join(entry.children, entry.children.size(), ", ", + [](const unique_ptr &child) { return child->ToString(); }); + result += ")"; + return result; + } + case ExpressionType::GROUPING_FUNCTION: + case ExpressionType::OPERATOR_COALESCE: { + string result = ExpressionTypeToString(entry.type); + result += "("; + result += StringUtil::Join(entry.children, entry.children.size(), ", ", + [](const unique_ptr &child) { return child->ToString(); }); + result += ")"; + return result; + } + case ExpressionType::OPERATOR_IS_NULL: + return "(" + entry.children[0]->ToString() + " IS NULL)"; + case ExpressionType::OPERATOR_IS_NOT_NULL: + return "(" + entry.children[0]->ToString() + " IS NOT NULL)"; + case ExpressionType::ARRAY_EXTRACT: + return entry.children[0]->ToString() + "[" + entry.children[1]->ToString() + "]"; + case ExpressionType::ARRAY_SLICE: { + string begin = entry.children[1]->ToString(); + if (begin == "[]") { + begin = ""; + } + string end = entry.children[2]->ToString(); + if (end == "[]") { + if (entry.children.size() == 4) { + end = "-"; + } else { + end = ""; + } + } + if (entry.children.size() == 4) { + return entry.children[0]->ToString() + "[" + begin + ":" + end + ":" + entry.children[3]->ToString() + + "]"; + } + return entry.children[0]->ToString() + "[" + begin + ":" + end + "]"; + } + case ExpressionType::STRUCT_EXTRACT: { + if (entry.children[1]->type != ExpressionType::VALUE_CONSTANT) { + return string(); + } + auto child_string = entry.children[1]->ToString(); + D_ASSERT(child_string.size() >= 3); + D_ASSERT(child_string[0] == '\'' && child_string[child_string.size() - 1] == '\''); + return StringUtil::Format("(%s).%s", entry.children[0]->ToString(), + SQLIdentifier(child_string.substr(1, child_string.size() - 2))); + } + case ExpressionType::ARRAY_CONSTRUCTOR: { + string result = "(ARRAY["; + result += StringUtil::Join(entry.children, entry.children.size(), ", ", + [](const unique_ptr &child) { return child->ToString(); }); + result += "])"; + return result; + } + default: + throw InternalException("Unrecognized operator type"); + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/parameter_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/parameter_expression.hpp new file mode 100644 index 00000000..ac97ecbb --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/parameter_expression.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/parameter_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +// Parameters come in three different types: +// auto-increment: +// token: '?' +// name: - +// number: 0 +// positional: +// token: '$' +// name: - +// number: +// named: +// token: '$' +// name: +// number: 0 +enum class PreparedParamType : uint8_t { AUTO_INCREMENT, POSITIONAL, NAMED, INVALID }; + +class ParameterExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::PARAMETER; + +public: + ParameterExpression(); + + string identifier; + +public: + bool IsScalar() const override { + return true; + } + bool HasParameter() const override { + return true; + } + + string ToString() const override; + + static bool Equal(const ParameterExpression &a, const ParameterExpression &b); + + unique_ptr Copy() const override; + hash_t Hash() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/positional_reference_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/positional_reference_expression.hpp new file mode 100644 index 00000000..e2da1973 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/positional_reference_expression.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/positional_reference_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { +class PositionalReferenceExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::POSITIONAL_REFERENCE; + +public: + DUCKDB_API PositionalReferenceExpression(idx_t index); + + idx_t index; + +public: + bool IsScalar() const override { + return false; + } + + string ToString() const override; + + static bool Equal(const PositionalReferenceExpression &a, const PositionalReferenceExpression &b); + unique_ptr Copy() const override; + hash_t Hash() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + PositionalReferenceExpression(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/star_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/star_expression.hpp new file mode 100644 index 00000000..83f61585 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/star_expression.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/star_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +//! Represents a * expression in the SELECT clause +class StarExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::STAR; + +public: + StarExpression(string relation_name = string()); + + //! The relation name in case of tbl.*, or empty if this is a normal * + string relation_name; + //! List of columns to exclude from the STAR expression + case_insensitive_set_t exclude_list; + //! List of columns to replace with another expression + case_insensitive_map_t> replace_list; + //! The expression to select the columns (regular expression or list) + unique_ptr expr; + //! Whether or not this is a COLUMNS expression + bool columns = false; + +public: + string ToString() const override; + + static bool Equal(const StarExpression &a, const StarExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/subquery_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/subquery_expression.hpp new file mode 100644 index 00000000..b01bda76 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/subquery_expression.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/subquery_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/subquery_type.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/statement/select_statement.hpp" + +namespace duckdb { + +//! Represents a subquery +class SubqueryExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::SUBQUERY; + +public: + SubqueryExpression(); + + //! The actual subquery + unique_ptr subquery; + //! The subquery type + SubqueryType subquery_type; + //! the child expression to compare with (in case of IN, ANY, ALL operators, empty for EXISTS queries and scalar + //! subquery) + unique_ptr child; + //! The comparison type of the child expression with the subquery (in case of ANY, ALL operators), empty otherwise + ExpressionType comparison_type; + +public: + bool HasSubquery() const override { + return true; + } + bool IsScalar() const override { + return false; + } + + string ToString() const override; + + static bool Equal(const SubqueryExpression &a, const SubqueryExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp new file mode 100644 index 00000000..14a42304 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp @@ -0,0 +1,216 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression/window_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +enum class WindowBoundary : uint8_t { + INVALID = 0, + UNBOUNDED_PRECEDING = 1, + UNBOUNDED_FOLLOWING = 2, + CURRENT_ROW_RANGE = 3, + CURRENT_ROW_ROWS = 4, + EXPR_PRECEDING_ROWS = 5, + EXPR_FOLLOWING_ROWS = 6, + EXPR_PRECEDING_RANGE = 7, + EXPR_FOLLOWING_RANGE = 8 +}; + +const char *ToString(WindowBoundary value); + +//! The WindowExpression represents a window function in the query. They are a special case of aggregates which is why +//! they inherit from them. +class WindowExpression : public ParsedExpression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::WINDOW; + +public: + WindowExpression(ExpressionType type, string catalog_name, string schema_name, const string &function_name); + + //! Catalog of the aggregate function + string catalog; + //! Schema of the aggregate function + string schema; + //! Name of the aggregate function + string function_name; + //! The child expression of the main window function + vector> children; + //! The set of expressions to partition by + vector> partitions; + //! The set of ordering clauses + vector orders; + //! Expression representing a filter, only used for aggregates + unique_ptr filter_expr; + //! True to ignore NULL values + bool ignore_nulls; + //! The window boundaries + WindowBoundary start = WindowBoundary::INVALID; + WindowBoundary end = WindowBoundary::INVALID; + + unique_ptr start_expr; + unique_ptr end_expr; + //! Offset and default expressions for WINDOW_LEAD and WINDOW_LAG functions + unique_ptr offset_expr; + unique_ptr default_expr; + +public: + bool IsWindow() const override { + return true; + } + + //! Convert the Expression to a String + string ToString() const override; + + static bool Equal(const WindowExpression &a, const WindowExpression &b); + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + static ExpressionType WindowToExpressionType(string &fun_name); + +public: + template + static string ToString(const T &entry, const string &schema, const string &function_name) { + // Start with function call + string result = schema.empty() ? function_name : schema + "." + function_name; + result += "("; + if (entry.children.size()) { + result += StringUtil::Join(entry.children, entry.children.size(), ", ", + [](const unique_ptr &child) { return child->ToString(); }); + } + // Lead/Lag extra arguments + if (entry.offset_expr.get()) { + result += ", "; + result += entry.offset_expr->ToString(); + } + if (entry.default_expr.get()) { + result += ", "; + result += entry.default_expr->ToString(); + } + // IGNORE NULLS + if (entry.ignore_nulls) { + result += " IGNORE NULLS"; + } + // FILTER + if (entry.filter_expr) { + result += ") FILTER (WHERE " + entry.filter_expr->ToString(); + } + + // Over clause + result += ") OVER ("; + string sep; + + // Partitions + if (!entry.partitions.empty()) { + result += "PARTITION BY "; + result += StringUtil::Join(entry.partitions, entry.partitions.size(), ", ", + [](const unique_ptr &partition) { return partition->ToString(); }); + sep = " "; + } + + // Orders + if (!entry.orders.empty()) { + result += sep; + result += "ORDER BY "; + result += StringUtil::Join(entry.orders, entry.orders.size(), ", ", + [](const ORDER_NODE &order) { return order.ToString(); }); + sep = " "; + } + + // Rows/Range + string units = "ROWS"; + string from; + switch (entry.start) { + case WindowBoundary::CURRENT_ROW_RANGE: + case WindowBoundary::CURRENT_ROW_ROWS: + from = "CURRENT ROW"; + units = (entry.start == WindowBoundary::CURRENT_ROW_RANGE) ? "RANGE" : "ROWS"; + break; + case WindowBoundary::UNBOUNDED_PRECEDING: + if (entry.end != WindowBoundary::CURRENT_ROW_RANGE) { + from = "UNBOUNDED PRECEDING"; + } + break; + case WindowBoundary::EXPR_PRECEDING_ROWS: + case WindowBoundary::EXPR_PRECEDING_RANGE: + from = entry.start_expr->ToString() + " PRECEDING"; + units = (entry.start == WindowBoundary::EXPR_PRECEDING_RANGE) ? "RANGE" : "ROWS"; + break; + case WindowBoundary::EXPR_FOLLOWING_ROWS: + case WindowBoundary::EXPR_FOLLOWING_RANGE: + from = entry.start_expr->ToString() + " FOLLOWING"; + units = (entry.start == WindowBoundary::EXPR_FOLLOWING_RANGE) ? "RANGE" : "ROWS"; + break; + default: + throw InternalException("Unrecognized FROM in WindowExpression"); + } + + string to; + switch (entry.end) { + case WindowBoundary::CURRENT_ROW_RANGE: + if (entry.start != WindowBoundary::UNBOUNDED_PRECEDING) { + to = "CURRENT ROW"; + units = "RANGE"; + } + break; + case WindowBoundary::CURRENT_ROW_ROWS: + to = "CURRENT ROW"; + units = "ROWS"; + break; + case WindowBoundary::UNBOUNDED_PRECEDING: + to = "UNBOUNDED PRECEDING"; + break; + case WindowBoundary::UNBOUNDED_FOLLOWING: + to = "UNBOUNDED FOLLOWING"; + break; + case WindowBoundary::EXPR_PRECEDING_ROWS: + case WindowBoundary::EXPR_PRECEDING_RANGE: + to = entry.end_expr->ToString() + " PRECEDING"; + units = (entry.end == WindowBoundary::EXPR_PRECEDING_RANGE) ? "RANGE" : "ROWS"; + break; + case WindowBoundary::EXPR_FOLLOWING_ROWS: + case WindowBoundary::EXPR_FOLLOWING_RANGE: + to = entry.end_expr->ToString() + " FOLLOWING"; + units = (entry.end == WindowBoundary::EXPR_FOLLOWING_RANGE) ? "RANGE" : "ROWS"; + break; + default: + throw InternalException("Unrecognized TO in WindowExpression"); + } + + if (!from.empty() || !to.empty()) { + result += sep + units; + } + if (!from.empty() && !to.empty()) { + result += " BETWEEN "; + result += from; + result += " AND "; + result += to; + } else if (!from.empty()) { + result += " "; + result += from; + } else if (!to.empty()) { + result += " "; + result += to; + } + + result += ")"; + + return result; + } + +private: + explicit WindowExpression(ExpressionType type); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression_map.hpp b/src/duckdb/src/include/duckdb/parser/expression_map.hpp new file mode 100644 index 00000000..75ecd78e --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression_map.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/parser/base_expression.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { +class Expression; + +template +struct ExpressionHashFunction { + uint64_t operator()(const reference &expr) const { + return (uint64_t)expr.get().Hash(); + } +}; + +template +struct ExpressionEquality { + bool operator()(const reference &a, const reference &b) const { + return a.get().Equals(b.get()); + } +}; + +template +using expression_map_t = + unordered_map, T, ExpressionHashFunction, ExpressionEquality>; + +using expression_set_t = + unordered_set, ExpressionHashFunction, ExpressionEquality>; + +template +using parsed_expression_map_t = unordered_map, T, ExpressionHashFunction, + ExpressionEquality>; + +using parsed_expression_set_t = unordered_set, ExpressionHashFunction, + ExpressionEquality>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression_util.hpp b/src/duckdb/src/include/duckdb/parser/expression_util.hpp new file mode 100644 index 00000000..24112d1f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/expression_util.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/expression_util.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/base_expression.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +class ParsedExpression; +class Expression; + +class ExpressionUtil { +public: + //! ListEquals: check if a list of two expressions is equal (order is important) + static bool ListEquals(const vector> &a, + const vector> &b); + static bool ListEquals(const vector> &a, const vector> &b); + //! SetEquals: check if two sets of expressions are equal (order is not important) + static bool SetEquals(const vector> &a, const vector> &b); + static bool SetEquals(const vector> &a, const vector> &b); + +private: + template + static bool ExpressionListEquals(const vector> &a, const vector> &b); + template + static bool ExpressionSetEquals(const vector> &a, const vector> &b); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/group_by_node.hpp b/src/duckdb/src/include/duckdb/parser/group_by_node.hpp new file mode 100644 index 00000000..02a8eecc --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/group_by_node.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/group_by_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +using GroupingSet = set; + +class GroupByNode { +public: + //! The total set of all group expressions + vector> group_expressions; + //! The different grouping sets as they map to the group expressions + vector grouping_sets; + +public: + GroupByNode Copy() { + GroupByNode node; + node.group_expressions.reserve(group_expressions.size()); + for (auto &expr : group_expressions) { + node.group_expressions.push_back(expr->Copy()); + } + node.grouping_sets = grouping_sets; + return node; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/keyword_helper.hpp b/src/duckdb/src/include/duckdb/parser/keyword_helper.hpp new file mode 100644 index 00000000..05c1664f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/keyword_helper.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/keyword_helper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +class KeywordHelper { +public: + //! Returns true if the given text matches a keyword of the parser + static bool IsKeyword(const string &text); + + static string EscapeQuotes(const string &text, char quote = '"'); + + //! Returns true if the given string needs to be quoted when written as an identifier + static bool RequiresQuotes(const string &text, bool allow_caps = true); + + //! Writes a string that is quoted + static string WriteQuoted(const string &text, char quote = '\''); + + //! Writes a string that is optionally quoted + escaped so it can be used as an identifier + static string WriteOptionallyQuoted(const string &text, char quote = '"', bool allow_caps = true); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp new file mode 100644 index 00000000..2f7e5a00 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp @@ -0,0 +1,79 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/alter_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" + +namespace duckdb { + +enum class AlterType : uint8_t { + INVALID = 0, + ALTER_TABLE = 1, + ALTER_VIEW = 2, + ALTER_SEQUENCE = 3, + CHANGE_OWNERSHIP = 4, + ALTER_SCALAR_FUNCTION = 5, + ALTER_TABLE_FUNCTION = 6 +}; + +struct AlterEntryData { + AlterEntryData() { + } + AlterEntryData(string catalog_p, string schema_p, string name_p, OnEntryNotFound if_not_found) + : catalog(std::move(catalog_p)), schema(std::move(schema_p)), name(std::move(name_p)), + if_not_found(if_not_found) { + } + + string catalog; + string schema; + string name; + OnEntryNotFound if_not_found; +}; + +struct AlterInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::ALTER_INFO; + +public: + AlterInfo(AlterType type, string catalog, string schema, string name, OnEntryNotFound if_not_found); + ~AlterInfo() override; + + AlterType type; + //! if exists + OnEntryNotFound if_not_found; + //! Catalog name to alter + string catalog; + //! Schema name to alter + string schema; + //! Entry name to alter + string name; + //! Allow altering internal entries + bool allow_internal; + +public: + virtual CatalogType GetCatalogType() const = 0; + virtual unique_ptr Copy() const = 0; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + virtual string GetColumnName() const { + return ""; + }; + + AlterEntryData GetAlterEntryData() const; + +protected: + explicit AlterInfo(AlterType type); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_scalar_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_scalar_function_info.hpp new file mode 100644 index 00000000..d7e87afb --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_scalar_function_info.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/alter_scalar_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/parser/parsed_data/alter_info.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Alter Scalar Function +//===--------------------------------------------------------------------===// +enum class AlterScalarFunctionType : uint8_t { INVALID = 0, ADD_FUNCTION_OVERLOADS = 1 }; + +struct AlterScalarFunctionInfo : public AlterInfo { + AlterScalarFunctionInfo(AlterScalarFunctionType type, AlterEntryData data); + virtual ~AlterScalarFunctionInfo() override; + + AlterScalarFunctionType alter_scalar_function_type; + +public: + CatalogType GetCatalogType() const override; +}; + +//===--------------------------------------------------------------------===// +// AddScalarFunctionOverloadInfo +//===--------------------------------------------------------------------===// +struct AddScalarFunctionOverloadInfo : public AlterScalarFunctionInfo { + AddScalarFunctionOverloadInfo(AlterEntryData data, ScalarFunctionSet new_overloads); + ~AddScalarFunctionOverloadInfo() override; + + ScalarFunctionSet new_overloads; + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_function_info.hpp new file mode 100644 index 00000000..ef08f556 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_function_info.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/alter_scalar_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/parser/parsed_data/alter_info.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Alter Table Function +//===--------------------------------------------------------------------===// +enum class AlterTableFunctionType : uint8_t { INVALID = 0, ADD_FUNCTION_OVERLOADS = 1 }; + +struct AlterTableFunctionInfo : public AlterInfo { + AlterTableFunctionInfo(AlterTableFunctionType type, AlterEntryData data); + virtual ~AlterTableFunctionInfo() override; + + AlterTableFunctionType alter_table_function_type; + +public: + CatalogType GetCatalogType() const override; +}; + +//===--------------------------------------------------------------------===// +// AddTableFunctionOverloadInfo +//===--------------------------------------------------------------------===// +struct AddTableFunctionOverloadInfo : public AlterTableFunctionInfo { + AddTableFunctionOverloadInfo(AlterEntryData data, TableFunctionSet new_overloads); + ~AddTableFunctionOverloadInfo() override; + + TableFunctionSet new_overloads; + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp new file mode 100644 index 00000000..6001e9a3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp @@ -0,0 +1,311 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/alter_table_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/alter_info.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" + +namespace duckdb { + +enum class AlterForeignKeyType : uint8_t { AFT_ADD = 0, AFT_DELETE = 1 }; + +//===--------------------------------------------------------------------===// +// Change Ownership +//===--------------------------------------------------------------------===// +struct ChangeOwnershipInfo : public AlterInfo { + ChangeOwnershipInfo(CatalogType entry_catalog_type, string entry_catalog, string entry_schema, string entry_name, + string owner_schema, string owner_name, OnEntryNotFound if_not_found); + + // Catalog type refers to the entry type, since this struct is usually built from an + // ALTER . OWNED BY . statement + // here it is only possible to know the type of who is to be owned + CatalogType entry_catalog_type; + + string owner_schema; + string owner_name; + +public: + CatalogType GetCatalogType() const override; + unique_ptr Copy() const override; +}; + +//===--------------------------------------------------------------------===// +// Alter Table +//===--------------------------------------------------------------------===// +enum class AlterTableType : uint8_t { + INVALID = 0, + RENAME_COLUMN = 1, + RENAME_TABLE = 2, + ADD_COLUMN = 3, + REMOVE_COLUMN = 4, + ALTER_COLUMN_TYPE = 5, + SET_DEFAULT = 6, + FOREIGN_KEY_CONSTRAINT = 7, + SET_NOT_NULL = 8, + DROP_NOT_NULL = 9 +}; + +struct AlterTableInfo : public AlterInfo { + AlterTableInfo(AlterTableType type, AlterEntryData data); + ~AlterTableInfo() override; + + AlterTableType alter_table_type; + +public: + CatalogType GetCatalogType() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + AlterTableInfo(AlterTableType type); +}; + +//===--------------------------------------------------------------------===// +// RenameColumnInfo +//===--------------------------------------------------------------------===// +struct RenameColumnInfo : public AlterTableInfo { + RenameColumnInfo(AlterEntryData data, string old_name_p, string new_name_p); + ~RenameColumnInfo() override; + + //! Column old name + string old_name; + //! Column new name + string new_name; + +public: + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + RenameColumnInfo(); +}; + +//===--------------------------------------------------------------------===// +// RenameTableInfo +//===--------------------------------------------------------------------===// +struct RenameTableInfo : public AlterTableInfo { + RenameTableInfo(AlterEntryData data, string new_name); + ~RenameTableInfo() override; + + //! Relation new name + string new_table_name; + +public: + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + RenameTableInfo(); +}; + +//===--------------------------------------------------------------------===// +// AddColumnInfo +//===--------------------------------------------------------------------===// +struct AddColumnInfo : public AlterTableInfo { + AddColumnInfo(AlterEntryData data, ColumnDefinition new_column, bool if_column_not_exists); + ~AddColumnInfo() override; + + //! New column + ColumnDefinition new_column; + //! Whether or not an error should be thrown if the column exist + bool if_column_not_exists; + +public: + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + explicit AddColumnInfo(ColumnDefinition new_column); +}; + +//===--------------------------------------------------------------------===// +// RemoveColumnInfo +//===--------------------------------------------------------------------===// +struct RemoveColumnInfo : public AlterTableInfo { + RemoveColumnInfo(AlterEntryData data, string removed_column, bool if_column_exists, bool cascade); + ~RemoveColumnInfo() override; + + //! The column to remove + string removed_column; + //! Whether or not an error should be thrown if the column does not exist + bool if_column_exists; + //! Whether or not the column should be removed if a dependency conflict arises (used by GENERATED columns) + bool cascade; + +public: + unique_ptr Copy() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + string GetColumnName() const override { + return removed_column; + } + +private: + RemoveColumnInfo(); +}; + +//===--------------------------------------------------------------------===// +// ChangeColumnTypeInfo +//===--------------------------------------------------------------------===// +struct ChangeColumnTypeInfo : public AlterTableInfo { + ChangeColumnTypeInfo(AlterEntryData data, string column_name, LogicalType target_type, + unique_ptr expression); + ~ChangeColumnTypeInfo() override; + + //! The column name to alter + string column_name; + //! The target type of the column + LogicalType target_type; + //! The expression used for data conversion + unique_ptr expression; + +public: + unique_ptr Copy() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + string GetColumnName() const override { + return column_name; + }; + +private: + ChangeColumnTypeInfo(); +}; + +//===--------------------------------------------------------------------===// +// SetDefaultInfo +//===--------------------------------------------------------------------===// +struct SetDefaultInfo : public AlterTableInfo { + SetDefaultInfo(AlterEntryData data, string column_name, unique_ptr new_default); + ~SetDefaultInfo() override; + + //! The column name to alter + string column_name; + //! The expression used for data conversion + unique_ptr expression; + +public: + unique_ptr Copy() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + SetDefaultInfo(); +}; + +//===--------------------------------------------------------------------===// +// AlterForeignKeyInfo +//===--------------------------------------------------------------------===// +struct AlterForeignKeyInfo : public AlterTableInfo { + AlterForeignKeyInfo(AlterEntryData data, string fk_table, vector pk_columns, vector fk_columns, + vector pk_keys, vector fk_keys, AlterForeignKeyType type); + ~AlterForeignKeyInfo() override; + + string fk_table; + vector pk_columns; + vector fk_columns; + vector pk_keys; + vector fk_keys; + AlterForeignKeyType type; + +public: + unique_ptr Copy() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + AlterForeignKeyInfo(); +}; + +//===--------------------------------------------------------------------===// +// SetNotNullInfo +//===--------------------------------------------------------------------===// +struct SetNotNullInfo : public AlterTableInfo { + SetNotNullInfo(AlterEntryData data, string column_name); + ~SetNotNullInfo() override; + + //! The column name to alter + string column_name; + +public: + unique_ptr Copy() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + SetNotNullInfo(); +}; + +//===--------------------------------------------------------------------===// +// DropNotNullInfo +//===--------------------------------------------------------------------===// +struct DropNotNullInfo : public AlterTableInfo { + DropNotNullInfo(AlterEntryData data, string column_name); + ~DropNotNullInfo() override; + + //! The column name to alter + string column_name; + +public: + unique_ptr Copy() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + DropNotNullInfo(); +}; + +//===--------------------------------------------------------------------===// +// Alter View +//===--------------------------------------------------------------------===// +enum class AlterViewType : uint8_t { INVALID = 0, RENAME_VIEW = 1 }; + +struct AlterViewInfo : public AlterInfo { + AlterViewInfo(AlterViewType type, AlterEntryData data); + ~AlterViewInfo() override; + + AlterViewType alter_view_type; + +public: + CatalogType GetCatalogType() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + AlterViewInfo(AlterViewType type); +}; + +//===--------------------------------------------------------------------===// +// RenameViewInfo +//===--------------------------------------------------------------------===// +struct RenameViewInfo : public AlterViewInfo { + RenameViewInfo(AlterEntryData data, string new_name); + ~RenameViewInfo() override; + + //! Relation new name + string new_view_name; + +public: + unique_ptr Copy() const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + RenameViewInfo(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp new file mode 100644 index 00000000..9732f744 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/attach_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +struct AttachInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::ATTACH_INFO; + +public: + AttachInfo() : ParseInfo(TYPE) { + } + + //! The alias of the attached database + string name; + //! The path to the attached database + string path; + //! Set of (key, value) options + unordered_map options; + +public: + unique_ptr Copy() const; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp new file mode 100644 index 00000000..75e8322a --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/copy_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +struct CopyInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::COPY_INFO; + +public: + CopyInfo() : ParseInfo(TYPE), catalog(INVALID_CATALOG), schema(DEFAULT_SCHEMA) { + } + + //! The catalog name to copy to/from + string catalog; + //! The schema name to copy to/from + string schema; + //! The table name to copy to/from + string table; + //! List of columns to copy to/from + vector select_list; + //! Whether or not this is a copy to file (false) or copy from a file (true) + bool is_from; + //! The file format of the external file + string format; + //! The file path to copy to/from + string file_path; + //! Set of (key, value) options + case_insensitive_map_t> options; + +public: + unique_ptr Copy() const { + auto result = make_uniq(); + result->catalog = catalog; + result->schema = schema; + result->table = table; + result->select_list = select_list; + result->file_path = file_path; + result->is_from = is_from; + result->format = format; + result->options = options; + return result; + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_aggregate_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_aggregate_function_info.hpp new file mode 100644 index 00000000..bcec8d7c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_aggregate_function_info.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_aggregate_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_function_info.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CreateAggregateFunctionInfo : public CreateFunctionInfo { + explicit CreateAggregateFunctionInfo(AggregateFunction function); + explicit CreateAggregateFunctionInfo(AggregateFunctionSet set); + + AggregateFunctionSet functions; + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_collation_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_collation_info.hpp new file mode 100644 index 00000000..09dea057 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_collation_info.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_collation_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/function/scalar_function.hpp" + +namespace duckdb { + +struct CreateCollationInfo : public CreateInfo { + DUCKDB_API CreateCollationInfo(string name_p, ScalarFunction function_p, bool combinable_p, + bool not_required_for_equality_p); + + //! The name of the collation + string name; + //! The collation function to push in case collation is required + ScalarFunction function; + //! Whether or not the collation can be combined with other collations. + bool combinable; + //! Whether or not the collation is required for equality comparisons or not. For many collations a binary + //! comparison for equality comparisons is correct, allowing us to skip the collation in these cases which greatly + //! speeds up processing. + bool not_required_for_equality; + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_copy_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_copy_function_info.hpp new file mode 100644 index 00000000..40eafbc9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_copy_function_info.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_copy_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/function/copy_function.hpp" + +namespace duckdb { + +struct CreateCopyFunctionInfo : public CreateInfo { + DUCKDB_API explicit CreateCopyFunctionInfo(CopyFunction function); + + //! Function name + string name; + //! The table function + CopyFunction function; + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp new file mode 100644 index 00000000..1e50583c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/function/function.hpp" + +namespace duckdb { + +struct CreateFunctionInfo : public CreateInfo { + explicit CreateFunctionInfo(CatalogType type, string schema = DEFAULT_SCHEMA) : CreateInfo(type, schema) { + D_ASSERT(type == CatalogType::SCALAR_FUNCTION_ENTRY || type == CatalogType::AGGREGATE_FUNCTION_ENTRY || + type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY || + type == CatalogType::MACRO_ENTRY || type == CatalogType::TABLE_MACRO_ENTRY); + } + + //! Function name + string name; + //! The description (if any) + string description; + //! Parameter names (if any) + vector parameter_names; + //! The example (if any) + string example; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_index_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_index_info.hpp new file mode 100644 index 00000000..a4d2181f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_index_info.hpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_index_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/common/enums/index_type.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/function/table_function.hpp" + +namespace duckdb { + +struct CreateIndexInfo : public CreateInfo { + CreateIndexInfo() : CreateInfo(CatalogType::INDEX_ENTRY) { + } + + //! Index Type (e.g., B+-tree, Skip-List, ...) + IndexType index_type; + //! Name of the Index + string index_name; + //! Name of the Index type + string index_type_name; + //! Index Constraint Type + IndexConstraintType constraint_type; + //! The table to create the index on + string table; + //! Set of expressions to index by + vector> expressions; + vector> parsed_expressions; + + //! Types used for the CREATE INDEX scan + vector scan_types; + //! The names of the columns, used for the CREATE INDEX scan + vector names; + //! Column IDs needed for index creation + vector column_ids; + + //! Options values (WITH ...) + case_insensitive_map_t options; + +public: + DUCKDB_API unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_info.hpp new file mode 100644 index 00000000..3fd94128 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_info.hpp @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" + +namespace duckdb { +struct AlterInfo; + +enum class OnCreateConflict : uint8_t { + // Standard: throw error + ERROR_ON_CONFLICT, + // CREATE IF NOT EXISTS, silently do nothing on conflict + IGNORE_ON_CONFLICT, + // CREATE OR REPLACE + REPLACE_ON_CONFLICT, + // Update on conflict - only support for functions. Add a function overload if the function already exists. + ALTER_ON_CONFLICT +}; + +struct CreateInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::CREATE_INFO; + +public: + explicit CreateInfo(CatalogType type, string schema = DEFAULT_SCHEMA, string catalog_p = INVALID_CATALOG) + : ParseInfo(TYPE), type(type), catalog(std::move(catalog_p)), schema(schema), + on_conflict(OnCreateConflict::ERROR_ON_CONFLICT), temporary(false), internal(false) { + } + ~CreateInfo() override { + } + + //! The to-be-created catalog type + CatalogType type; + //! The catalog name of the entry + string catalog; + //! The schema name of the entry + string schema; + //! What to do on create conflict + OnCreateConflict on_conflict; + //! Whether or not the entry is temporary + bool temporary; + //! Whether or not the entry is an internal entry + bool internal; + //! The SQL string of the CREATE statement + string sql; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + virtual unique_ptr Copy() const = 0; + + DUCKDB_API void CopyProperties(CreateInfo &other) const; + //! Generates an alter statement from the create statement - used for OnCreateConflict::ALTER_ON_CONFLICT + DUCKDB_API virtual unique_ptr GetAlterInfo() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_macro_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_macro_info.hpp new file mode 100644 index 00000000..d51ddfa7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_macro_info.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_macro_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_function_info.hpp" +#include "duckdb/function/macro_function.hpp" + +namespace duckdb { + +struct CreateMacroInfo : public CreateFunctionInfo { + CreateMacroInfo(CatalogType type); + + unique_ptr function; + +public: + unique_ptr Copy() const override; + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp new file mode 100644 index 00000000..eae55880 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_pragma_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_function_info.hpp" +#include "duckdb/function/pragma_function.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CreatePragmaFunctionInfo : public CreateFunctionInfo { + DUCKDB_API explicit CreatePragmaFunctionInfo(PragmaFunction function); + DUCKDB_API CreatePragmaFunctionInfo(string name, PragmaFunctionSet functions_); + + PragmaFunctionSet functions; + +public: + DUCKDB_API unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_scalar_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_scalar_function_info.hpp new file mode 100644 index 00000000..12aee605 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_scalar_function_info.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_scalar_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_function_info.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CreateScalarFunctionInfo : public CreateFunctionInfo { + DUCKDB_API explicit CreateScalarFunctionInfo(ScalarFunction function); + DUCKDB_API explicit CreateScalarFunctionInfo(ScalarFunctionSet set); + + ScalarFunctionSet functions; + +public: + DUCKDB_API unique_ptr Copy() const override; + DUCKDB_API unique_ptr GetAlterInfo() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_schema_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_schema_info.hpp new file mode 100644 index 00000000..90764587 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_schema_info.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_schema_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" + +namespace duckdb { + +struct CreateSchemaInfo : public CreateInfo { + CreateSchemaInfo() : CreateInfo(CatalogType::SCHEMA_ENTRY) { + } + +public: + unique_ptr Copy() const override { + auto result = make_uniq(); + CopyProperties(*result); + return std::move(result); + } + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp new file mode 100644 index 00000000..de8fcc57 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_sequence_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/common/limits.hpp" + +namespace duckdb { + +enum class SequenceInfo : uint8_t { + // Sequence start + SEQ_START, + // Sequence increment + SEQ_INC, + // Sequence minimum value + SEQ_MIN, + // Sequence maximum value + SEQ_MAX, + // Sequence cycle option + SEQ_CYCLE, + // Sequence owner table + SEQ_OWN +}; + +struct CreateSequenceInfo : public CreateInfo { + CreateSequenceInfo(); + + //! Sequence name to create + string name; + //! Usage count of the sequence + uint64_t usage_count; + //! The increment value + int64_t increment; + //! The minimum value of the sequence + int64_t min_value; + //! The maximum value of the sequence + int64_t max_value; + //! The start value of the sequence + int64_t start_value; + //! Whether or not the sequence cycles + bool cycle; + +public: + unique_ptr Copy() const override; + +public: + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_function_info.hpp new file mode 100644 index 00000000..1e64b6c6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_function_info.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_table_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_function_info.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CreateTableFunctionInfo : public CreateFunctionInfo { + DUCKDB_API explicit CreateTableFunctionInfo(TableFunction function); + DUCKDB_API explicit CreateTableFunctionInfo(TableFunctionSet set); + + //! The table functions + TableFunctionSet functions; + +public: + DUCKDB_API unique_ptr Copy() const override; + DUCKDB_API unique_ptr GetAlterInfo() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp new file mode 100644 index 00000000..395f78ba --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_table_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/catalog/catalog_entry/column_dependency_manager.hpp" +#include "duckdb/parser/column_list.hpp" + +namespace duckdb { +class SchemaCatalogEntry; + +struct CreateTableInfo : public CreateInfo { + DUCKDB_API CreateTableInfo(); + DUCKDB_API CreateTableInfo(string catalog, string schema, string name); + DUCKDB_API CreateTableInfo(SchemaCatalogEntry &schema, string name); + + //! Table name to insert to + string table; + //! List of columns of the table + ColumnList columns; + //! List of constraints on the table + vector> constraints; + //! CREATE TABLE from QUERY + unique_ptr query; + +public: + DUCKDB_API unique_ptr Copy() const override; + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp new file mode 100644 index 00000000..c8ab662f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_type_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/statement/select_statement.hpp" + +namespace duckdb { + +struct CreateTypeInfo : public CreateInfo { + CreateTypeInfo(); + CreateTypeInfo(string name_p, LogicalType type_p); + + //! Name of the Type + string name; + //! Logical Type + LogicalType type; + //! Used by create enum from query + unique_ptr query; + +public: + unique_ptr Copy() const override; + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_view_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_view_info.hpp new file mode 100644 index 00000000..4f9ff34b --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_view_info.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/create_view_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/parser/statement/select_statement.hpp" + +namespace duckdb { +class SchemaCatalogEntry; + +struct CreateViewInfo : public CreateInfo { + CreateViewInfo(); + CreateViewInfo(SchemaCatalogEntry &schema, string view_name); + CreateViewInfo(string catalog_p, string schema_p, string view_name); + + //! Table name to insert to + string view_name; + //! Aliases of the view + vector aliases; + //! Return types + vector types; + //! The SelectStatement of the view + unique_ptr query; + +public: + unique_ptr Copy() const override; + + //! Gets a bound CreateViewInfo object from a SELECT statement and a view name, schema name, etc + DUCKDB_API static unique_ptr FromSelect(ClientContext &context, unique_ptr info); + //! Gets a bound CreateViewInfo object from a CREATE VIEW statement + DUCKDB_API static unique_ptr FromCreateView(ClientContext &context, const string &sql); + + DUCKDB_API void Serialize(Serializer &serializer) const override; + DUCKDB_API static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/detach_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/detach_info.hpp new file mode 100644 index 00000000..8c0fdedc --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/detach_info.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/detach_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" + +namespace duckdb { + +struct DetachInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::DETACH_INFO; + +public: + DetachInfo(); + + //! The alias of the attached database + string name; + //! Whether to throw an exception if alias is not found + OnEntryNotFound if_not_found; + +public: + unique_ptr Copy() const; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/drop_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/drop_info.hpp new file mode 100644 index 00000000..57322fa5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/drop_info.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/drop_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" + +namespace duckdb { + +struct DropInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::DROP_INFO; + +public: + DropInfo(); + + //! The catalog type to drop + CatalogType type; + //! Catalog name to drop from, if any + string catalog; + //! Schema name to drop from, if any + string schema; + //! Element name to drop + string name; + //! Ignore if the entry does not exist instead of failing + OnEntryNotFound if_not_found = OnEntryNotFound::THROW_EXCEPTION; + //! Cascade drop (drop all dependents instead of throwing an error if there + //! are any) + bool cascade = false; + //! Allow dropping of internal system entries + bool allow_drop_internal = false; + +public: + unique_ptr Copy() const; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/exported_table_data.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/exported_table_data.hpp new file mode 100644 index 00000000..a5246b4e --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/exported_table_data.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/export_table_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { +class TableCatalogEntry; + +struct ExportedTableData { + //! Name of the exported table + string table_name; + + //! Name of the schema + string schema_name; + + //! Name of the database + string database_name; + + //! Path to be exported + string file_path; +}; + +struct ExportedTableInfo { + ExportedTableInfo(TableCatalogEntry &entry, ExportedTableData table_data) + : entry(entry), table_data(std::move(table_data)) { + } + + TableCatalogEntry &entry; + ExportedTableData table_data; +}; + +struct BoundExportData : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::BOUND_EXPORT_DATA; + +public: + BoundExportData() : ParseInfo(TYPE) { + } + + vector data; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/load_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/load_info.hpp new file mode 100644 index 00000000..b595fa76 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/load_info.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/load_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" + +namespace duckdb { + +enum class LoadType : uint8_t { LOAD, INSTALL, FORCE_INSTALL }; + +struct LoadInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::LOAD_INFO; + +public: + LoadInfo() : ParseInfo(TYPE) { + } + + string filename; + string repository; + LoadType load_type; + +public: + unique_ptr Copy() const { + auto result = make_uniq(); + result->filename = filename; + result->repository = repository; + result->load_type = load_type; + return result; + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/parse_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/parse_info.hpp new file mode 100644 index 00000000..3cacbd7a --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/parse_info.hpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/parse_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +enum class ParseInfoType : uint8_t { + ALTER_INFO, + ATTACH_INFO, + COPY_INFO, + CREATE_INFO, + DETACH_INFO, + DROP_INFO, + BOUND_EXPORT_DATA, + LOAD_INFO, + PRAGMA_INFO, + SHOW_SELECT_INFO, + TRANSACTION_INFO, + VACUUM_INFO +}; + +struct ParseInfo { + explicit ParseInfo(ParseInfoType info_type) : info_type(info_type) { + } + virtual ~ParseInfo() { + } + + ParseInfoType info_type; + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp new file mode 100644 index 00000000..56735880 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/pragma_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/named_parameter_map.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +enum class PragmaType : uint8_t { PRAGMA_STATEMENT, PRAGMA_CALL }; + +struct PragmaInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::PRAGMA_INFO; + +public: + PragmaInfo() : ParseInfo(TYPE) { + } + + //! Name of the PRAGMA statement + string name; + //! Parameter list (if any) + vector parameters; + //! Named parameter list (if any) + named_parameter_map_t named_parameters; + +public: + unique_ptr Copy() const { + auto result = make_uniq(); + result->name = name; + result->parameters = parameters; + result->named_parameters = named_parameters; + return result; + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp new file mode 100644 index 00000000..201469bc --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/sample_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +enum class SampleMethod : uint8_t { SYSTEM_SAMPLE = 0, BERNOULLI_SAMPLE = 1, RESERVOIR_SAMPLE = 2 }; + +// **DEPRECATED**: Use EnumUtil directly instead. +string SampleMethodToString(SampleMethod method); + +struct SampleOptions { + Value sample_size; + bool is_percentage; + SampleMethod method; + int64_t seed = -1; + + unique_ptr Copy(); + static bool Equals(SampleOptions *a, SampleOptions *b); + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/show_select_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/show_select_info.hpp new file mode 100644 index 00000000..117e8fd1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/show_select_info.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/show_select_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +struct ShowSelectInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::SHOW_SELECT_INFO; + +public: + ShowSelectInfo() : ParseInfo(TYPE) { + } + + //! Types of projected columns + vector types; + //! The QueryNode of select query + unique_ptr query; + //! Aliases of projected columns + vector aliases; + //! Whether or not we are requesting a summary or a describe + bool is_summary; + + unique_ptr Copy() { + auto result = make_uniq(); + result->types = types; + result->query = query->Copy(); + result->aliases = aliases; + result->is_summary = is_summary; + return result; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/transaction_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/transaction_info.hpp new file mode 100644 index 00000000..59c68914 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/transaction_info.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/transaction_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" + +namespace duckdb { + +enum class TransactionType : uint8_t { INVALID, BEGIN_TRANSACTION, COMMIT, ROLLBACK }; + +struct TransactionInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::TRANSACTION_INFO; + +public: + explicit TransactionInfo(TransactionType type); + + //! The type of transaction statement + TransactionType type; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + TransactionInfo(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp new file mode 100644 index 00000000..08fdda15 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/vacuum_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/planner/tableref/bound_basetableref.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class Serializer; +class Deserializer; + +struct VacuumOptions { + VacuumOptions() : vacuum(false), analyze(false) { + } + + bool vacuum; + bool analyze; + + void Serialize(Serializer &serializer) const; + static VacuumOptions Deserialize(Deserializer &deserializer); +}; + +struct VacuumInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::VACUUM_INFO; + +public: + explicit VacuumInfo(VacuumOptions options); + + const VacuumOptions options; + +public: + bool has_table; + unique_ptr ref; + optional_ptr table; + unordered_map column_id_map; + vector columns; + +public: + unique_ptr Copy(); + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp b/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp new file mode 100644 index 00000000..148b7ea9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/base_expression.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/qualified_name.hpp" +#include "duckdb/parser/expression_util.hpp" + +namespace duckdb { +class Deserializer; +class Serializer; + +//! The ParsedExpression class is a base class that can represent any expression +//! part of a SQL statement. +/*! + The ParsedExpression class is a base class that can represent any expression + part of a SQL statement. This is, for example, a column reference in a SELECT + clause, but also operators, aggregates or filters. The Expression is emitted by the parser and does not contain any + information about bindings to the catalog or to the types. ParsedExpressions are transformed into regular Expressions + in the Binder. + */ +class ParsedExpression : public BaseExpression { +public: + //! Create an Expression + ParsedExpression(ExpressionType type, ExpressionClass expression_class) : BaseExpression(type, expression_class) { + } + + //! The location in the query (if any) + idx_t query_location = DConstants::INVALID_INDEX; + +public: + bool IsAggregate() const override; + bool IsWindow() const override; + bool HasSubquery() const override; + bool IsScalar() const override; + bool HasParameter() const override; + + bool Equals(const BaseExpression &other) const override; + hash_t Hash() const override; + + //! Create a copy of this expression + virtual unique_ptr Copy() const = 0; + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + + static bool Equals(const unique_ptr &left, const unique_ptr &right); + static bool ListEquals(const vector> &left, + const vector> &right); + +protected: + //! Copy base Expression properties from another expression to this one, + //! used in Copy method + void CopyProperties(const ParsedExpression &other) { + type = other.type; + expression_class = other.expression_class; + alias = other.alias; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_expression_iterator.hpp b/src/duckdb/src/include/duckdb/parser/parsed_expression_iterator.hpp new file mode 100644 index 00000000..b8953876 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_expression_iterator.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_expression_iterator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/tokens.hpp" + +#include + +namespace duckdb { + +class ParsedExpressionIterator { +public: + static void EnumerateChildren(const ParsedExpression &expression, + const std::function &callback); + static void EnumerateChildren(ParsedExpression &expr, const std::function &callback); + static void EnumerateChildren(ParsedExpression &expr, + const std::function &child)> &callback); + + static void EnumerateTableRefChildren(TableRef &ref, + const std::function &child)> &callback); + static void EnumerateQueryNodeChildren(QueryNode &node, + const std::function &child)> &callback); + + static void EnumerateQueryNodeModifiers(QueryNode &node, + const std::function &child)> &callback); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parser.hpp b/src/duckdb/src/include/duckdb/parser/parser.hpp new file mode 100644 index 00000000..4ba14181 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parser.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parser.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/column_list.hpp" +#include "duckdb/parser/simplified_token.hpp" +#include "duckdb/parser/parser_options.hpp" + +namespace duckdb_libpgquery { +struct PGNode; +struct PGList; +} // namespace duckdb_libpgquery + +namespace duckdb { + +class GroupByNode; + +//! The parser is responsible for parsing the query and converting it into a set +//! of parsed statements. The parsed statements can then be converted into a +//! plan and executed. +class Parser { +public: + Parser(ParserOptions options = ParserOptions()); + + //! The parsed SQL statements from an invocation to ParseQuery. + vector> statements; + +public: + //! Attempts to parse a query into a series of SQL statements. Returns + //! whether or not the parsing was successful. If the parsing was + //! successful, the parsed statements will be stored in the statements + //! variable. + void ParseQuery(const string &query); + + //! Tokenize a query, returning the raw tokens together with their locations + static vector Tokenize(const string &query); + + //! Returns true if the given text matches a keyword of the parser + static bool IsKeyword(const string &text); + //! Returns a list of all keywords in the parser + static vector KeywordList(); + + //! Parses a list of expressions (i.e. the list found in a SELECT clause) + DUCKDB_API static vector> ParseExpressionList(const string &select_list, + ParserOptions options = ParserOptions()); + //! Parses a list of GROUP BY expressions + static GroupByNode ParseGroupByList(const string &group_by, ParserOptions options = ParserOptions()); + //! Parses a list as found in an ORDER BY expression (i.e. including optional ASCENDING/DESCENDING modifiers) + static vector ParseOrderList(const string &select_list, ParserOptions options = ParserOptions()); + //! Parses an update list (i.e. the list found in the SET clause of an UPDATE statement) + static void ParseUpdateList(const string &update_list, vector &update_columns, + vector> &expressions, + ParserOptions options = ParserOptions()); + //! Parses a VALUES list (i.e. the list of expressions after a VALUES clause) + static vector>> ParseValuesList(const string &value_list, + ParserOptions options = ParserOptions()); + //! Parses a column list (i.e. as found in a CREATE TABLE statement) + static ColumnList ParseColumnList(const string &column_list, ParserOptions options = ParserOptions()); + + static bool StripUnicodeSpaces(const string &query_str, string &new_query); + +private: + ParserOptions options; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parser_extension.hpp b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp new file mode 100644 index 00000000..50fec976 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parser_extension.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/function/table_function.hpp" + +namespace duckdb { + +//! The ParserExtensionInfo holds static information relevant to the parser extension +//! It is made available in the parse_function, and will be kept alive as long as the database system is kept alive +struct ParserExtensionInfo { + virtual ~ParserExtensionInfo() { + } +}; + +//===--------------------------------------------------------------------===// +// Parse +//===--------------------------------------------------------------------===// +enum class ParserExtensionResultType : uint8_t { PARSE_SUCCESSFUL, DISPLAY_ORIGINAL_ERROR, DISPLAY_EXTENSION_ERROR }; + +//! The ParserExtensionParseData holds the result of a successful parse step +//! It will be passed along to the subsequent plan function +struct ParserExtensionParseData { + virtual ~ParserExtensionParseData() { + } + + virtual unique_ptr Copy() const = 0; +}; + +struct ParserExtensionParseResult { + ParserExtensionParseResult() : type(ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR) { + } + ParserExtensionParseResult(string error_p) + : type(ParserExtensionResultType::DISPLAY_EXTENSION_ERROR), error(std::move(error_p)) { + } + ParserExtensionParseResult(unique_ptr parse_data_p) + : type(ParserExtensionResultType::PARSE_SUCCESSFUL), parse_data(std::move(parse_data_p)) { + } + + //! Whether or not parsing was successful + ParserExtensionResultType type; + //! The parse data (if successful) + unique_ptr parse_data; + //! The error message (if unsuccessful) + string error; +}; + +typedef ParserExtensionParseResult (*parse_function_t)(ParserExtensionInfo *info, const string &query); +//===--------------------------------------------------------------------===// +// Plan +//===--------------------------------------------------------------------===// +struct ParserExtensionPlanResult { + //! The table function to execute + TableFunction function; + //! Parameters to the function + vector parameters; + //! The set of databases that will be modified by this statement (empty for a read-only statement) + unordered_set modified_databases; + //! Whether or not the statement requires a valid transaction to be executed + bool requires_valid_transaction = true; + //! What type of result set the statement returns + StatementReturnType return_type = StatementReturnType::NOTHING; +}; + +typedef ParserExtensionPlanResult (*plan_function_t)(ParserExtensionInfo *info, ClientContext &context, + unique_ptr parse_data); + +//===--------------------------------------------------------------------===// +// ParserExtension +//===--------------------------------------------------------------------===// +class ParserExtension { +public: + //! The parse function of the parser extension. + //! Takes a query string as input and returns ParserExtensionParseData (on success) or an error + parse_function_t parse_function; + + //! The plan function of the parser extension + //! Takes as input the result of the parse_function, and outputs various properties of the resulting plan + plan_function_t plan_function; + + //! Additional parser info passed to the parse function + shared_ptr parser_info; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parser_options.hpp b/src/duckdb/src/include/duckdb/parser/parser_options.hpp new file mode 100644 index 00000000..d388fb11 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parser_options.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parser_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { +class ParserExtension; + +struct ParserOptions { + bool preserve_identifier_case = true; + bool integer_division = false; + idx_t max_expression_depth = 1000; + const vector *extensions = nullptr; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/qualified_name.hpp b/src/duckdb/src/include/duckdb/parser/qualified_name.hpp new file mode 100644 index 00000000..d389cffb --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/qualified_name.hpp @@ -0,0 +1,91 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/qualified_name.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +struct QualifiedName { + string catalog; + string schema; + string name; + + //! Parse the (optional) schema and a name from a string in the format of e.g. "schema"."table"; if there is no dot + //! the schema will be set to INVALID_SCHEMA + static QualifiedName Parse(const string &input) { + string catalog; + string schema; + string name; + idx_t idx = 0; + vector entries; + string entry; + normal: + //! quote + for (; idx < input.size(); idx++) { + if (input[idx] == '"') { + idx++; + goto quoted; + } else if (input[idx] == '.') { + goto separator; + } + entry += input[idx]; + } + goto end; + separator: + entries.push_back(entry); + entry = ""; + idx++; + goto normal; + quoted: + //! look for another quote + for (; idx < input.size(); idx++) { + if (input[idx] == '"') { + //! unquote + idx++; + goto normal; + } + entry += input[idx]; + } + throw ParserException("Unterminated quote in qualified name!"); + end: + if (entries.empty()) { + catalog = INVALID_CATALOG; + schema = INVALID_SCHEMA; + name = entry; + } else if (entries.size() == 1) { + catalog = INVALID_CATALOG; + schema = entries[0]; + name = entry; + } else if (entries.size() == 2) { + catalog = entries[0]; + schema = entries[1]; + name = entry; + } else { + throw ParserException("Expected catalog.entry, schema.entry or entry: too many entries found"); + } + return QualifiedName {catalog, schema, name}; + } +}; + +struct QualifiedColumnName { + QualifiedColumnName() { + } + QualifiedColumnName(string table_p, string column_p) : table(std::move(table_p)), column(std::move(column_p)) { + } + + string schema; + string table; + string column; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp b/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp new file mode 100644 index 00000000..1a105209 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/qualified_name_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/qualified_name.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +struct QualifiedColumnHashFunction { + uint64_t operator()(const QualifiedColumnName &a) const { + std::hash str_hasher; + return str_hasher(a.schema) ^ str_hasher(a.table) ^ str_hasher(a.column); + } +}; + +struct QualifiedColumnEquality { + bool operator()(const QualifiedColumnName &a, const QualifiedColumnName &b) const { + return a.schema == b.schema && a.table == b.table && a.column == b.column; + } +}; + +using qualified_column_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_error_context.hpp b/src/duckdb/src/include/duckdb/parser/query_error_context.hpp new file mode 100644 index 00000000..79a5dc40 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_error_context.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_error_context.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/exception_format_value.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class SQLStatement; + +class QueryErrorContext { +public: + explicit QueryErrorContext(optional_ptr statement_ = nullptr, + idx_t query_location_ = DConstants::INVALID_INDEX) + : statement(statement_), query_location(query_location_) { + } + + //! The query statement + optional_ptr statement; + //! The location in which the error should be thrown + idx_t query_location; + +public: + DUCKDB_API static string Format(const string &query, const string &error_message, int error_location); + + DUCKDB_API string FormatErrorRecursive(const string &msg, vector &values); + template + string FormatErrorRecursive(const string &msg, vector &values, T param, Args... params) { + values.push_back(ExceptionFormatValue::CreateFormatValue(param)); + return FormatErrorRecursive(msg, values, params...); + } + + template + string FormatError(const string &msg, Args... params) { + vector values; + return FormatErrorRecursive(msg, values, params...); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node.hpp new file mode 100644 index 00000000..cdd2d728 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node.hpp @@ -0,0 +1,104 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/common_table_expression_info.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +class Deserializer; +class Serializer; + +enum class QueryNodeType : uint8_t { + SELECT_NODE = 1, + SET_OPERATION_NODE = 2, + BOUND_SUBQUERY_NODE = 3, + RECURSIVE_CTE_NODE = 4, + CTE_NODE = 5 +}; + +struct CommonTableExpressionInfo; + +class CommonTableExpressionMap { +public: + CommonTableExpressionMap(); + + case_insensitive_map_t> map; + +public: + string ToString() const; + CommonTableExpressionMap Copy() const; + + void Serialize(Serializer &serializer) const; + // static void Deserialize(Deserializer &deserializer, CommonTableExpressionMap &ret); + static CommonTableExpressionMap Deserialize(Deserializer &deserializer); +}; + +class QueryNode { +public: + explicit QueryNode(QueryNodeType type) : type(type) { + } + virtual ~QueryNode() { + } + + //! The type of the query node, either SetOperation or Select + QueryNodeType type; + //! The set of result modifiers associated with this query node + vector> modifiers; + //! CTEs (used by SelectNode and SetOperationNode) + CommonTableExpressionMap cte_map; + + virtual const vector> &GetSelectList() const = 0; + +public: + //! Convert the query node to a string + virtual string ToString() const = 0; + + virtual bool Equals(const QueryNode *other) const; + + //! Create a copy of this QueryNode + virtual unique_ptr Copy() const = 0; + + string ResultModifiersToString() const; + + //! Adds a distinct modifier to the query node + void AddDistinct(); + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + //! Copy base QueryNode properties from another expression to this one, + //! used in Copy method + void CopyProperties(QueryNode &other) const; + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast query node to type - query node type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast query node to type - query node type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp new file mode 100644 index 00000000..b41b194a --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/cte_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class CTENode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; + +public: + CTENode() : QueryNode(QueryNodeType::CTE_NODE) { + } + + string ctename; + //! The query of the CTE + unique_ptr query; + //! Child + unique_ptr child; + //! Aliases of the CTE node + vector aliases; + + const vector> &GetSelectList() const override { + return query->GetSelectList(); + } + +public: + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + //! Create a copy of this SelectNode + unique_ptr Copy() const override; + + //! Serializes a QueryNode to a stand-alone binary blob + //! Deserializes a blob back into a QueryNode + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp new file mode 100644 index 00000000..94bfd343 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp @@ -0,0 +1,4 @@ +#include "duckdb/parser/query_node/recursive_cte_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp new file mode 100644 index 00000000..a69bf003 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/recursive_cte_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class RecursiveCTENode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::RECURSIVE_CTE_NODE; + +public: + RecursiveCTENode() : QueryNode(QueryNodeType::RECURSIVE_CTE_NODE) { + } + + string ctename; + bool union_all; + //! The left side of the set operation + unique_ptr left; + //! The right side of the set operation + unique_ptr right; + //! Aliases of the recursive CTE node + vector aliases; + + const vector> &GetSelectList() const override { + return left->GetSelectList(); + } + +public: + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + //! Create a copy of this SelectNode + unique_ptr Copy() const override; + + //! Serializes a QueryNode to a stand-alone binary blob + //! Deserializes a blob back into a QueryNode + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp new file mode 100644 index 00000000..62aa9c0b --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/select_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" +#include "duckdb/parser/group_by_node.hpp" +#include "duckdb/common/enums/aggregate_handling.hpp" + +namespace duckdb { + +//! SelectNode represents a standard SELECT statement +class SelectNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::SELECT_NODE; + +public: + DUCKDB_API SelectNode(); + + //! The projection list + vector> select_list; + //! The FROM clause + unique_ptr from_table; + //! The WHERE clause + unique_ptr where_clause; + //! list of groups + GroupByNode groups; + //! HAVING clause + unique_ptr having; + //! QUALIFY clause + unique_ptr qualify; + //! Aggregate handling during binding + AggregateHandling aggregate_handling; + //! The SAMPLE clause + unique_ptr sample; + + const vector> &GetSelectList() const override { + return select_list; + } + +public: + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + + //! Create a copy of this SelectNode + unique_ptr Copy() const override; + + //! Serializes a QueryNode to a stand-alone binary blob + + //! Deserializes a blob back into a QueryNode + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp new file mode 100644 index 00000000..2582cdb8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/set_operation_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class SetOperationNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::SET_OPERATION_NODE; + +public: + SetOperationNode() : QueryNode(QueryNodeType::SET_OPERATION_NODE) { + } + + //! The type of set operation + SetOperationType setop_type = SetOperationType::NONE; + //! The left side of the set operation + unique_ptr left; + //! The right side of the set operation + unique_ptr right; + + const vector> &GetSelectList() const override { + return left->GetSelectList(); + } + +public: + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + //! Create a copy of this SelectNode + unique_ptr Copy() const override; + + //! Serializes a QueryNode to a stand-alone binary blob + //! Deserializes a blob back into a QueryNode + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/result_modifier.hpp b/src/duckdb/src/include/duckdb/parser/result_modifier.hpp new file mode 100644 index 00000000..f1d79652 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/result_modifier.hpp @@ -0,0 +1,170 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/result_modifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/order_type.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { +class Deserializer; +class Serializer; + +enum class ResultModifierType : uint8_t { + LIMIT_MODIFIER = 1, + ORDER_MODIFIER = 2, + DISTINCT_MODIFIER = 3, + LIMIT_PERCENT_MODIFIER = 4 +}; + +const char *ToString(ResultModifierType value); +ResultModifierType ResultModifierFromString(const char *value); + +//! A ResultModifier +class ResultModifier { +public: + explicit ResultModifier(ResultModifierType type) : type(type) { + } + virtual ~ResultModifier() { + } + + ResultModifierType type; + +public: + //! Returns true if the two result modifiers are equivalent + virtual bool Equals(const ResultModifier &other) const; + + //! Create a copy of this ResultModifier + virtual unique_ptr Copy() const = 0; + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast result modifier to type - result modifier type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast result modifier to type - result modifier type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +//! Single node in ORDER BY statement +struct OrderByNode { + OrderByNode(OrderType type, OrderByNullType null_order, unique_ptr expression) + : type(type), null_order(null_order), expression(std::move(expression)) { + } + + //! Sort order, ASC or DESC + OrderType type; + //! The NULL sort order, NULLS_FIRST or NULLS_LAST + OrderByNullType null_order; + //! Expression to order by + unique_ptr expression; + +public: + string ToString() const; + + void Serialize(Serializer &serializer) const; + static OrderByNode Deserialize(Deserializer &deserializer); +}; + +class LimitModifier : public ResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::LIMIT_MODIFIER; + +public: + LimitModifier() : ResultModifier(ResultModifierType::LIMIT_MODIFIER) { + } + + //! LIMIT count + unique_ptr limit; + //! OFFSET + unique_ptr offset; + +public: + bool Equals(const ResultModifier &other) const override; + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +class OrderModifier : public ResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::ORDER_MODIFIER; + +public: + OrderModifier() : ResultModifier(ResultModifierType::ORDER_MODIFIER) { + } + + //! List of order nodes + vector orders; + +public: + bool Equals(const ResultModifier &other) const override; + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + static bool Equals(const unique_ptr &left, const unique_ptr &right); +}; + +class DistinctModifier : public ResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::DISTINCT_MODIFIER; + +public: + DistinctModifier() : ResultModifier(ResultModifierType::DISTINCT_MODIFIER) { + } + + //! list of distinct on targets (if any) + vector> distinct_on_targets; + +public: + bool Equals(const ResultModifier &other) const override; + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +class LimitPercentModifier : public ResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::LIMIT_PERCENT_MODIFIER; + +public: + LimitPercentModifier() : ResultModifier(ResultModifierType::LIMIT_PERCENT_MODIFIER) { + } + + //! LIMIT % + unique_ptr limit; + //! OFFSET + unique_ptr offset; + +public: + bool Equals(const ResultModifier &other) const override; + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/simplified_token.hpp b/src/duckdb/src/include/duckdb/parser/simplified_token.hpp new file mode 100644 index 00000000..e7993871 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/simplified_token.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/simplified_token.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +//! Simplified tokens are a simplified (dense) representation of the lexer +//! Used for simple syntax highlighting in the tests +enum class SimplifiedTokenType : uint8_t { + SIMPLIFIED_TOKEN_IDENTIFIER, + SIMPLIFIED_TOKEN_NUMERIC_CONSTANT, + SIMPLIFIED_TOKEN_STRING_CONSTANT, + SIMPLIFIED_TOKEN_OPERATOR, + SIMPLIFIED_TOKEN_KEYWORD, + SIMPLIFIED_TOKEN_COMMENT +}; + +struct SimplifiedToken { + SimplifiedTokenType type; + idx_t start; +}; + +enum class KeywordCategory : uint8_t { KEYWORD_RESERVED, KEYWORD_UNRESERVED, KEYWORD_TYPE_FUNC, KEYWORD_COL_NAME }; + +struct ParserKeyword { + string name; + KeywordCategory category; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/sql_statement.hpp b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp new file mode 100644 index 00000000..d8ebc1a2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/sql_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/named_parameter_map.hpp" + +namespace duckdb { + +//! SQLStatement is the base class of any type of SQL statement. +class SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::INVALID_STATEMENT; + +public: + explicit SQLStatement(StatementType type) : type(type) { + } + virtual ~SQLStatement() { + } + + //! The statement type + StatementType type; + //! The statement location within the query string + idx_t stmt_location = 0; + //! The statement length within the query string + idx_t stmt_length = 0; + //! The number of prepared statement parameters (if any) + idx_t n_param = 0; + //! The map of named parameter to param index (if n_param and any named) + case_insensitive_map_t named_param_map; + //! The query text that corresponds to this SQL statement + string query; + +protected: + SQLStatement(const SQLStatement &other) = default; + +public: + virtual string ToString() const { + throw InternalException("ToString not supported for this type of SQLStatement: '%s'", + StatementTypeToString(type)); + } + //! Create a copy of this SelectStatement + DUCKDB_API virtual unique_ptr Copy() const = 0; + +public: +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE && TARGET::TYPE != StatementType::INVALID_STATEMENT) { + throw InternalException("Failed to cast statement to type - statement type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE && TARGET::TYPE != StatementType::INVALID_STATEMENT) { + throw InternalException("Failed to cast statement to type - statement type mismatch"); + } + return reinterpret_cast(*this); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/alter_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/alter_statement.hpp new file mode 100644 index 00000000..67bd648e --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/alter_statement.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/alter_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class AlterStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::ALTER_STATEMENT; + +public: + AlterStatement(); + + unique_ptr info; + +protected: + AlterStatement(const AlterStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/attach_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/attach_statement.hpp new file mode 100644 index 00000000..3f1adcc1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/attach_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/attach_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/attach_info.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class AttachStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::ATTACH_STATEMENT; + +public: + AttachStatement(); + + unique_ptr info; + +protected: + AttachStatement(const AttachStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/call_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/call_statement.hpp new file mode 100644 index 00000000..7fd8fa51 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/call_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/call_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class CallStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::CALL_STATEMENT; + +public: + CallStatement(); + + unique_ptr function; + +protected: + CallStatement(const CallStatement &other); + +public: + unique_ptr Copy() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/copy_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/copy_statement.hpp new file mode 100644 index 00000000..4ac59c5b --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/copy_statement.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/copy_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class CopyStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::COPY_STATEMENT; + +public: + CopyStatement(); + + unique_ptr info; + // The SQL statement used instead of a table when copying data out to a file + unique_ptr select_statement; + string ToString() const override; + string CopyOptionsToString(const string &format, const case_insensitive_map_t> &options) const; + +protected: + CopyStatement(const CopyStatement &other); + +public: + DUCKDB_API unique_ptr Copy() const override; + +private: +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/create_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/create_statement.hpp new file mode 100644 index 00000000..362af264 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/create_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/create_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class CreateStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::CREATE_STATEMENT; + +public: + CreateStatement(); + + unique_ptr info; + +protected: + CreateStatement(const CreateStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/delete_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/delete_statement.hpp new file mode 100644 index 00000000..b1dcd72f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/delete_statement.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/delete_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +class DeleteStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::DELETE_STATEMENT; + +public: + DeleteStatement(); + + unique_ptr condition; + unique_ptr table; + vector> using_clauses; + vector> returning_list; + //! CTEs + CommonTableExpressionMap cte_map; + +protected: + DeleteStatement(const DeleteStatement &other); + +public: + string ToString() const override; + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/detach_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/detach_statement.hpp new file mode 100644 index 00000000..3d1380cd --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/detach_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/detach_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/detach_info.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class DetachStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::DETACH_STATEMENT; + +public: + DetachStatement(); + + unique_ptr info; + +protected: + DetachStatement(const DetachStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/drop_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/drop_statement.hpp new file mode 100644 index 00000000..d8e71e65 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/drop_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/drop_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class DropStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::DROP_STATEMENT; + +public: + DropStatement(); + + unique_ptr info; + +protected: + DropStatement(const DropStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp new file mode 100644 index 00000000..81ca82f5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/execute_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class ExecuteStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::EXECUTE_STATEMENT; + +public: + ExecuteStatement(); + + string name; + case_insensitive_map_t> named_values; + +protected: + ExecuteStatement(const ExecuteStatement &other); + +public: + unique_ptr Copy() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/explain_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/explain_statement.hpp new file mode 100644 index 00000000..f6f6086a --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/explain_statement.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/explain_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +enum class ExplainType : uint8_t { EXPLAIN_STANDARD, EXPLAIN_ANALYZE }; + +class ExplainStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::EXPLAIN_STATEMENT; + +public: + explicit ExplainStatement(unique_ptr stmt, ExplainType explain_type = ExplainType::EXPLAIN_STANDARD); + + unique_ptr stmt; + ExplainType explain_type; + +protected: + ExplainStatement(const ExplainStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/export_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/export_statement.hpp new file mode 100644 index 00000000..58215e39 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/export_statement.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/export_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" + +namespace duckdb { + +class ExportStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::EXPORT_STATEMENT; + +public: + explicit ExportStatement(unique_ptr info); + + unique_ptr info; + string database; + +protected: + ExportStatement(const ExportStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/extension_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/extension_statement.hpp new file mode 100644 index 00000000..c75f0668 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/extension_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/extension_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parser_extension.hpp" + +namespace duckdb { + +class ExtensionStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::EXTENSION_STATEMENT; + +public: + ExtensionStatement(ParserExtension extension, unique_ptr parse_data); + + //! The ParserExtension this statement was generated from + ParserExtension extension; + //! The parse data for this specific statement + unique_ptr parse_data; + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/insert_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/insert_statement.hpp new file mode 100644 index 00000000..0f6a6624 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/insert_statement.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/insert_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/statement/update_statement.hpp" + +namespace duckdb { +class ExpressionListRef; +class UpdateSetInfo; + +enum class OnConflictAction : uint8_t { + THROW, + NOTHING, + UPDATE, + REPLACE // Only used in transform/bind step, changed to UPDATE later +}; + +enum class InsertColumnOrder : uint8_t { INSERT_BY_POSITION = 0, INSERT_BY_NAME = 1 }; + +class OnConflictInfo { +public: + OnConflictInfo(); + +public: + unique_ptr Copy() const; + +public: + OnConflictAction action_type; + + vector indexed_columns; + //! The SET information (if action_type == UPDATE) + unique_ptr set_info; + //! The condition determining whether we apply the DO .. for conflicts that arise + unique_ptr condition; + +protected: + OnConflictInfo(const OnConflictInfo &other); +}; + +class InsertStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::INSERT_STATEMENT; + +public: + InsertStatement(); + + //! The select statement to insert from + unique_ptr select_statement; + //! Column names to insert into + vector columns; + + //! Table name to insert to + string table; + //! Schema name to insert to + string schema; + //! The catalog name to insert to + string catalog; + + //! keep track of optional returningList if statement contains a RETURNING keyword + vector> returning_list; + + unique_ptr on_conflict_info; + unique_ptr table_ref; + + //! CTEs + CommonTableExpressionMap cte_map; + + //! Whether or not this a DEFAULT VALUES + bool default_values = false; + + //! INSERT BY POSITION or INSERT BY NAME + InsertColumnOrder column_order = InsertColumnOrder::INSERT_BY_POSITION; + +protected: + InsertStatement(const InsertStatement &other); + +public: + static string OnConflictActionToString(OnConflictAction action); + string ToString() const override; + unique_ptr Copy() const override; + + //! If the INSERT statement is inserted DIRECTLY from a values list (i.e. INSERT INTO tbl VALUES (...)) this returns + //! the expression list Otherwise, this returns NULL + optional_ptr GetValuesList() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/list.hpp b/src/duckdb/src/include/duckdb/parser/statement/list.hpp new file mode 100644 index 00000000..b85cc7c6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/list.hpp @@ -0,0 +1,24 @@ +#include "duckdb/parser/statement/alter_statement.hpp" +#include "duckdb/parser/statement/attach_statement.hpp" +#include "duckdb/parser/statement/call_statement.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/parser/statement/detach_statement.hpp" +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/parser/statement/execute_statement.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/parser/statement/extension_statement.hpp" +#include "duckdb/parser/statement/export_statement.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/load_statement.hpp" +#include "duckdb/parser/statement/logical_plan_statement.hpp" +#include "duckdb/parser/statement/pragma_statement.hpp" +#include "duckdb/parser/statement/prepare_statement.hpp" +#include "duckdb/parser/statement/relation_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/statement/set_statement.hpp" +#include "duckdb/parser/statement/show_statement.hpp" +#include "duckdb/parser/statement/transaction_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/statement/vacuum_statement.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/statement/load_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/load_statement.hpp new file mode 100644 index 00000000..2333e8f9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/load_statement.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/load_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parsed_data/load_info.hpp" + +namespace duckdb { + +class LoadStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::LOAD_STATEMENT; + +public: + LoadStatement(); + +protected: + LoadStatement(const LoadStatement &other); + +public: + unique_ptr Copy() const override; + + unique_ptr info; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/logical_plan_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/logical_plan_statement.hpp new file mode 100644 index 00000000..811d1b86 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/logical_plan_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/logical_plan_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalPlanStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::LOGICAL_PLAN_STATEMENT; + +public: + explicit LogicalPlanStatement(unique_ptr plan_p) + : SQLStatement(StatementType::LOGICAL_PLAN_STATEMENT), plan(std::move(plan_p)) {}; + + unique_ptr plan; + +public: + unique_ptr Copy() const override { + throw NotImplementedException("PLAN_STATEMENT"); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/multi_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/multi_statement.hpp new file mode 100644 index 00000000..ef1dfc1c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/multi_statement.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/multi_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class MultiStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::MULTI_STATEMENT; + +public: + MultiStatement(); + + vector> statements; + +protected: + MultiStatement(const MultiStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/pragma_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/pragma_statement.hpp new file mode 100644 index 00000000..a1b1e9a2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/pragma_statement.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/pragma_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class PragmaStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::PRAGMA_STATEMENT; + +public: + PragmaStatement(); + + unique_ptr info; + +protected: + PragmaStatement(const PragmaStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp new file mode 100644 index 00000000..e3df292d --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/prepare_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class PrepareStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::PREPARE_STATEMENT; + +public: + PrepareStatement(); + + unique_ptr statement; + string name; + +protected: + PrepareStatement(const PrepareStatement &other); + +public: + unique_ptr Copy() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/relation_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/relation_statement.hpp new file mode 100644 index 00000000..f23bd747 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/relation_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/relation_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/main/relation.hpp" + +namespace duckdb { + +class RelationStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::RELATION_STATEMENT; + +public: + explicit RelationStatement(shared_ptr relation); + + shared_ptr relation; + +protected: + RelationStatement(const RelationStatement &other) = default; + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp new file mode 100644 index 00000000..94581e2c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/select_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +class QueryNode; +class Serializer; +class Deserializer; + +//! SelectStatement is a typical SELECT clause +class SelectStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::SELECT_STATEMENT; + +public: + SelectStatement() : SQLStatement(StatementType::SELECT_STATEMENT) { + } + + //! The main query node + unique_ptr node; + +protected: + SelectStatement(const SelectStatement &other); + +public: + //! Convert the SELECT statement to a string + DUCKDB_API string ToString() const override; + //! Create a copy of this SelectStatement + DUCKDB_API unique_ptr Copy() const override; + //! Whether or not the statements are equivalent + bool Equals(const SQLStatement &other) const; + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp new file mode 100644 index 00000000..0e1c9cef --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/set_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/common/enums/set_type.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +class SetStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::SET_STATEMENT; + +protected: + SetStatement(std::string name_p, SetScope scope_p, SetType type_p); + SetStatement(const SetStatement &other) = default; + +public: + unique_ptr Copy() const override; + +public: + std::string name; + SetScope scope; + SetType set_type; +}; + +class SetVariableStatement : public SetStatement { +public: + SetVariableStatement(std::string name_p, Value value_p, SetScope scope_p); + +protected: + SetVariableStatement(const SetVariableStatement &other) = default; + +public: + unique_ptr Copy() const override; + +public: + Value value; +}; + +class ResetVariableStatement : public SetStatement { +public: + ResetVariableStatement(std::string name_p, SetScope scope_p); + +protected: + ResetVariableStatement(const ResetVariableStatement &other) = default; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/show_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/show_statement.hpp new file mode 100644 index 00000000..ed5021de --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/show_statement.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/show_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parsed_data/show_select_info.hpp" + +namespace duckdb { + +class ShowStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::SHOW_STATEMENT; + +public: + ShowStatement(); + + unique_ptr info; + +protected: + ShowStatement(const ShowStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/transaction_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/transaction_statement.hpp new file mode 100644 index 00000000..f34ac05f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/transaction_statement.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/transaction_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parsed_data/transaction_info.hpp" + +namespace duckdb { + +class TransactionStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::TRANSACTION_STATEMENT; + +public: + explicit TransactionStatement(TransactionType type); + + unique_ptr info; + +protected: + TransactionStatement(const TransactionStatement &other); + +public: + unique_ptr Copy() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/update_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/update_statement.hpp new file mode 100644 index 00000000..5b156c5a --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/update_statement.hpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/update_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +class UpdateSetInfo { +public: + UpdateSetInfo(); + +public: + unique_ptr Copy() const; + +public: + // The condition that needs to be met to perform the update + unique_ptr condition; + // The columns to update + vector columns; + // The set expressions to execute + vector> expressions; + +protected: + UpdateSetInfo(const UpdateSetInfo &other); +}; + +class UpdateStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::UPDATE_STATEMENT; + +public: + UpdateStatement(); + + unique_ptr table; + unique_ptr from_table; + //! keep track of optional returningList if statement contains a RETURNING keyword + vector> returning_list; + unique_ptr set_info; + //! CTEs + CommonTableExpressionMap cte_map; + +protected: + UpdateStatement(const UpdateStatement &other); + +public: + string ToString() const override; + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/vacuum_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/vacuum_statement.hpp new file mode 100644 index 00000000..f84fbff0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/vacuum_statement.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/vacuum_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" + +namespace duckdb { + +class VacuumStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::VACUUM_STATEMENT; + +public: + explicit VacuumStatement(const VacuumOptions &options); + + unique_ptr info; + +protected: + VacuumStatement(const VacuumStatement &other); + +public: + unique_ptr Copy() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref.hpp b/src/duckdb/src/include/duckdb/parser/tableref.hpp new file mode 100644 index 00000000..ee4b8820 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref.hpp @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/tableref_type.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" + +namespace duckdb { + +//! Represents a generic expression that returns a table. +class TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::INVALID; + +public: + explicit TableRef(TableReferenceType type) : type(type) { + } + virtual ~TableRef() { + } + + TableReferenceType type; + string alias; + //! Sample options (if any) + unique_ptr sample; + //! The location in the query (if any) + idx_t query_location = DConstants::INVALID_INDEX; + +public: + //! Convert the object to a string + virtual string ToString() const = 0; + string BaseToString(string result) const; + string BaseToString(string result, const vector &column_name_alias) const; + void Print(); + + virtual bool Equals(const TableRef &other) const; + static bool Equals(const unique_ptr &left, const unique_ptr &right); + + virtual unique_ptr Copy() = 0; + + //! Copy the properties of this table ref to the target + void CopyProperties(TableRef &target) const; + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE && TARGET::TYPE != TableReferenceType::INVALID) { + throw InternalException("Failed to cast constraint to type - constraint type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE && TARGET::TYPE != TableReferenceType::INVALID) { + throw InternalException("Failed to cast constraint to type - constraint type mismatch"); + } + return reinterpret_cast(*this); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp new file mode 100644 index 00000000..4401fb28 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref/basetableref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/tableref.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +//! Represents a TableReference to a base table in the schema +class BaseTableRef : public TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::BASE_TABLE; + +public: + BaseTableRef() + : TableRef(TableReferenceType::BASE_TABLE), catalog_name(INVALID_CATALOG), schema_name(INVALID_SCHEMA) { + } + + //! The catalog name + string catalog_name; + //! Schema name + string schema_name; + //! Table name + string table_name; + //! Aliases for the column names + vector column_name_alias; + +public: + string ToString() const override; + bool Equals(const TableRef &other_p) const override; + + unique_ptr Copy() override; + + //! Deserializes a blob back into a BaseTableRef + void Serialize(Serializer &serializer) const override; + + static unique_ptr Deserialize(Deserializer &source); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/emptytableref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/emptytableref.hpp new file mode 100644 index 00000000..1f4758e4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/emptytableref.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref/emptytableref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/tableref.hpp" + +namespace duckdb { + +class EmptyTableRef : public TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::EMPTY; + +public: + EmptyTableRef() : TableRef(TableReferenceType::EMPTY) { + } + +public: + string ToString() const override; + bool Equals(const TableRef &other_p) const override; + + unique_ptr Copy() override; + + //! Deserializes a blob back into a DummyTableRef + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/expressionlistref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/expressionlistref.hpp new file mode 100644 index 00000000..7d0bd86e --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/expressionlistref.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref/expressionlistref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +//! Represents an expression list as generated by a VALUES statement +class ExpressionListRef : public TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::EXPRESSION_LIST; + +public: + ExpressionListRef() : TableRef(TableReferenceType::EXPRESSION_LIST) { + } + + //! Value list, only used for VALUES statement + vector>> values; + //! Expected SQL types + vector expected_types; + //! The set of expected names + vector expected_names; + +public: + string ToString() const override; + bool Equals(const TableRef &other_p) const override; + + unique_ptr Copy() override; + + //! Deserializes a blob back into a ExpressionListRef + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp new file mode 100644 index 00000000..1f53f5c8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref/joinref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enums/joinref_type.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +//! Represents a JOIN between two expressions +class JoinRef : public TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::JOIN; + +public: + explicit JoinRef(JoinRefType ref_type = JoinRefType::REGULAR) + : TableRef(TableReferenceType::JOIN), type(JoinType::INNER), ref_type(ref_type) { + } + + //! The left hand side of the join + unique_ptr left; + //! The right hand side of the join + unique_ptr right; + //! The join condition + unique_ptr condition; + //! The join type + JoinType type; + //! Join condition type + JoinRefType ref_type; + //! The set of USING columns (if any) + vector using_columns; + +public: + string ToString() const override; + bool Equals(const TableRef &other_p) const override; + + unique_ptr Copy() override; + + //! Deserializes a blob back into a JoinRef + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/list.hpp b/src/duckdb/src/include/duckdb/parser/tableref/list.hpp new file mode 100644 index 00000000..5500ff22 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/list.hpp @@ -0,0 +1,7 @@ +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/tableref/emptytableref.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/tableref/pivotref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/pivotref.hpp new file mode 100644 index 00000000..b4a5421f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/pivotref.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref/pivotref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" + +namespace duckdb { + +struct PivotColumnEntry { + //! The set of values to match on + vector values; + //! The star expression (UNPIVOT only) + unique_ptr star_expr; + //! The alias of the pivot column entry + string alias; + + bool Equals(const PivotColumnEntry &other) const; + PivotColumnEntry Copy() const; + + void Serialize(Serializer &serializer) const; + static PivotColumnEntry Deserialize(Deserializer &source); +}; + +struct PivotValueElement { + vector values; + string name; +}; + +struct PivotColumn { + //! The set of expressions to pivot on + vector> pivot_expressions; + //! The set of unpivot names + vector unpivot_names; + //! The set of values to pivot on + vector entries; + //! The enum to read pivot values from (if any) + string pivot_enum; + //! Subquery (if any) - used during transform only + unique_ptr subquery; + + string ToString() const; + bool Equals(const PivotColumn &other) const; + PivotColumn Copy() const; + + void Serialize(Serializer &serializer) const; + static PivotColumn Deserialize(Deserializer &source); +}; + +//! Represents a PIVOT or UNPIVOT expression +class PivotRef : public TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::PIVOT; + +public: + explicit PivotRef() : TableRef(TableReferenceType::PIVOT), include_nulls(false) { + } + + //! The source table of the pivot + unique_ptr source; + //! The aggregates to compute over the pivot (PIVOT only) + vector> aggregates; + //! The names of the unpivot expressions (UNPIVOT only) + vector unpivot_names; + //! The set of pivots + vector pivots; + //! The groups to pivot over. If none are specified all columns not included in the pivots/aggregate are chosen. + vector groups; + //! Aliases for the column names + vector column_name_alias; + //! Whether or not to include nulls in the result (UNPIVOT only) + bool include_nulls; + //! The set of values to pivot on (bound pivot only) + vector bound_pivot_values; + //! The set of bound group names (bound pivot only) + vector bound_group_names; + //! The set of bound aggregate names (bound pivot only) + vector bound_aggregate_names; + +public: + string ToString() const override; + bool Equals(const TableRef &other_p) const override; + + unique_ptr Copy() override; + + //! Deserializes a blob back into a PivotRef + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/subqueryref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/subqueryref.hpp new file mode 100644 index 00000000..5e562ced --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/subqueryref.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref/subqueryref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/tableref.hpp" + +namespace duckdb { +//! Represents a subquery +class SubqueryRef : public TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::SUBQUERY; + +private: + SubqueryRef(); + +public: + DUCKDB_API explicit SubqueryRef(unique_ptr subquery, string alias = string()); + + //! The subquery + unique_ptr subquery; + //! Aliases for the column names + vector column_name_alias; + +public: + string ToString() const override; + bool Equals(const TableRef &other_p) const override; + + unique_ptr Copy() override; + + //! Deserializes a blob back into a SubqueryRef + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp new file mode 100644 index 00000000..7d782ca0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tableref/table_function_ref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/main/external_dependencies.hpp" + +namespace duckdb { +//! Represents a Table producing function +class TableFunctionRef : public TableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::TABLE_FUNCTION; + +public: + DUCKDB_API TableFunctionRef(); + + unique_ptr function; + vector column_name_alias; + + // if the function takes a subquery as argument its in here + unique_ptr subquery; + + // External dependencies of this table function + unique_ptr external_dependency; + +public: + string ToString() const override; + + bool Equals(const TableRef &other_p) const override; + + unique_ptr Copy() override; + + //! Deserializes a blob back into a BaseTableRef + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tokens.hpp b/src/duckdb/src/include/duckdb/parser/tokens.hpp new file mode 100644 index 00000000..083630dd --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/tokens.hpp @@ -0,0 +1,106 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/tokens.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Statements +//===--------------------------------------------------------------------===// +class SQLStatement; + +class AlterStatement; +class AttachStatement; +class CallStatement; +class CopyStatement; +class CreateStatement; +class DetachStatement; +class DeleteStatement; +class DropStatement; +class ExtensionStatement; +class InsertStatement; +class SelectStatement; +class TransactionStatement; +class UpdateStatement; +class PrepareStatement; +class ExecuteStatement; +class PragmaStatement; +class ShowStatement; +class ExplainStatement; +class ExportStatement; +class VacuumStatement; +class RelationStatement; +class SetStatement; +class SetVariableStatement; +class ResetVariableStatement; +class LoadStatement; +class LogicalPlanStatement; +class MultiStatement; + +//===--------------------------------------------------------------------===// +// Query Node +//===--------------------------------------------------------------------===// +class QueryNode; +class SelectNode; +class SetOperationNode; +class RecursiveCTENode; +class CTENode; + +//===--------------------------------------------------------------------===// +// Expressions +//===--------------------------------------------------------------------===// +class ParsedExpression; + +class BetweenExpression; +class CaseExpression; +class CastExpression; +class CollateExpression; +class ColumnRefExpression; +class ComparisonExpression; +class ConjunctionExpression; +class ConstantExpression; +class DefaultExpression; +class FunctionExpression; +class LambdaExpression; +class OperatorExpression; +class ParameterExpression; +class PositionalReferenceExpression; +class StarExpression; +class SubqueryExpression; +class WindowExpression; + +//===--------------------------------------------------------------------===// +// Constraints +//===--------------------------------------------------------------------===// +class Constraint; + +class NotNullConstraint; +class CheckConstraint; +class UniqueConstraint; +class ForeignKeyConstraint; + +//===--------------------------------------------------------------------===// +// TableRefs +//===--------------------------------------------------------------------===// +class TableRef; + +class BaseTableRef; +class JoinRef; +class SubqueryRef; +class TableFunctionRef; +class EmptyTableRef; +class ExpressionListRef; +class PivotRef; + +//===--------------------------------------------------------------------===// +// Other +//===--------------------------------------------------------------------===// +struct SampleOptions; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/transformer.hpp b/src/duckdb/src/include/duckdb/parser/transformer.hpp new file mode 100644 index 00000000..d16dfb8c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/transformer.hpp @@ -0,0 +1,364 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/transformer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/stack_checker.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/parser/group_by_node.hpp" +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/parser/qualified_name.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/parser/tokens.hpp" +#include "nodes/parsenodes.hpp" +#include "nodes/primnodes.hpp" +#include "pg_definitions.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" + +namespace duckdb { + +class ColumnDefinition; +struct OrderByNode; +struct CopyInfo; +struct CommonTableExpressionInfo; +struct GroupingExpressionMap; +class OnConflictInfo; +class UpdateSetInfo; +struct ParserOptions; +struct PivotColumn; + +//! The transformer class is responsible for transforming the internal Postgres +//! parser representation into the DuckDB representation +class Transformer { + friend class StackChecker; + + struct CreatePivotEntry { + string enum_name; + unique_ptr base; + unique_ptr column; + unique_ptr subquery; + }; + +public: + explicit Transformer(ParserOptions &options); + explicit Transformer(Transformer &parent); + ~Transformer(); + + //! Transforms a Postgres parse tree into a set of SQL Statements + bool TransformParseTree(duckdb_libpgquery::PGList *tree, vector> &statements); + string NodetypeToString(duckdb_libpgquery::PGNodeTag type); + + idx_t ParamCount() const; + +private: + optional_ptr parent; + //! Parser options + ParserOptions &options; + //! The current prepared statement parameter index + idx_t prepared_statement_parameter_index = 0; + //! Map from named parameter to parameter index; + case_insensitive_map_t named_param_map; + //! Last parameter type + PreparedParamType last_param_type = PreparedParamType::INVALID; + //! Holds window expressions defined by name. We need those when transforming the expressions referring to them. + unordered_map window_clauses; + //! The set of pivot entries to create + vector> pivot_entries; + //! Sets of stored CTEs, if any + vector stored_cte_map; + //! Whether or not we are currently binding a window definition + bool in_window_definition = false; + + void Clear(); + bool InWindowDefinition(); + + Transformer &RootTransformer(); + const Transformer &RootTransformer() const; + void SetParamCount(idx_t new_count); + void SetParam(const string &name, idx_t index, PreparedParamType type); + bool GetParam(const string &name, idx_t &index, PreparedParamType type); + + void AddPivotEntry(string enum_name, unique_ptr source, unique_ptr column, + unique_ptr subquery); + unique_ptr GenerateCreateEnumStmt(unique_ptr entry); + bool HasPivotEntries(); + idx_t PivotEntryCount(); + vector> &GetPivotEntries(); + void PivotEntryCheck(const string &type); + void ExtractCTEsRecursive(CommonTableExpressionMap &cte_map); + +private: + //! Transforms a Postgres statement into a single SQL statement + unique_ptr TransformStatement(duckdb_libpgquery::PGNode &stmt); + //! Transforms a Postgres statement into a single SQL statement + unique_ptr TransformStatementInternal(duckdb_libpgquery::PGNode &stmt); + //===--------------------------------------------------------------------===// + // Statement transformation + //===--------------------------------------------------------------------===// + //! Transform a Postgres duckdb_libpgquery::T_PGSelectStmt node into a SelectStatement + unique_ptr TransformSelect(optional_ptr node, bool is_select = true); + //! Transform a Postgres duckdb_libpgquery::T_PGSelectStmt node into a SelectStatement + unique_ptr TransformSelect(duckdb_libpgquery::PGSelectStmt &select, bool is_select = true); + //! Transform a Postgres T_AlterStmt node into a AlterStatement + unique_ptr TransformAlter(duckdb_libpgquery::PGAlterTableStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGRenameStmt node into a RenameStatement + unique_ptr TransformRename(duckdb_libpgquery::PGRenameStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGCreateStmt node into a CreateStatement + unique_ptr TransformCreateTable(duckdb_libpgquery::PGCreateStmt &node); + //! Transform a Postgres duckdb_libpgquery::T_PGCreateStmt node into a CreateStatement + unique_ptr TransformCreateTableAs(duckdb_libpgquery::PGCreateTableAsStmt &stmt); + //! Transform a Postgres node into a CreateStatement + unique_ptr TransformCreateSchema(duckdb_libpgquery::PGCreateSchemaStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGCreateSeqStmt node into a CreateStatement + unique_ptr TransformCreateSequence(duckdb_libpgquery::PGCreateSeqStmt &node); + //! Transform a Postgres duckdb_libpgquery::T_PGViewStmt node into a CreateStatement + unique_ptr TransformCreateView(duckdb_libpgquery::PGViewStmt &node); + //! Transform a Postgres duckdb_libpgquery::T_PGIndexStmt node into CreateStatement + unique_ptr TransformCreateIndex(duckdb_libpgquery::PGIndexStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGCreateFunctionStmt node into CreateStatement + unique_ptr TransformCreateFunction(duckdb_libpgquery::PGCreateFunctionStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGCreateTypeStmt node into CreateStatement + unique_ptr TransformCreateType(duckdb_libpgquery::PGCreateTypeStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGAlterSeqStmt node into CreateStatement + unique_ptr TransformAlterSequence(duckdb_libpgquery::PGAlterSeqStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGDropStmt node into a Drop[Table,Schema]Statement + unique_ptr TransformDrop(duckdb_libpgquery::PGDropStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGInsertStmt node into a InsertStatement + unique_ptr TransformInsert(duckdb_libpgquery::PGInsertStmt &stmt); + + //! Transform a Postgres duckdb_libpgquery::T_PGOnConflictClause node into a OnConflictInfo + unique_ptr TransformOnConflictClause(duckdb_libpgquery::PGOnConflictClause *node, + const string &relname); + //! Transform a ON CONFLICT shorthand into a OnConflictInfo + unique_ptr DummyOnConflictClause(duckdb_libpgquery::PGOnConflictActionAlias type, + const string &relname); + //! Transform a Postgres duckdb_libpgquery::T_PGCopyStmt node into a CopyStatement + unique_ptr TransformCopy(duckdb_libpgquery::PGCopyStmt &stmt); + void TransformCopyOptions(CopyInfo &info, optional_ptr options); + //! Transform a Postgres duckdb_libpgquery::T_PGTransactionStmt node into a TransactionStatement + unique_ptr TransformTransaction(duckdb_libpgquery::PGTransactionStmt &stmt); + //! Transform a Postgres T_DeleteStatement node into a DeleteStatement + unique_ptr TransformDelete(duckdb_libpgquery::PGDeleteStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGUpdateStmt node into a UpdateStatement + unique_ptr TransformUpdate(duckdb_libpgquery::PGUpdateStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGPragmaStmt node into a PragmaStatement + unique_ptr TransformPragma(duckdb_libpgquery::PGPragmaStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGExportStmt node into a ExportStatement + unique_ptr TransformExport(duckdb_libpgquery::PGExportStmt &stmt); + //! Transform a Postgres duckdb_libpgquery::T_PGImportStmt node into a PragmaStatement + unique_ptr TransformImport(duckdb_libpgquery::PGImportStmt &stmt); + unique_ptr TransformExplain(duckdb_libpgquery::PGExplainStmt &stmt); + unique_ptr TransformVacuum(duckdb_libpgquery::PGVacuumStmt &stmt); + unique_ptr TransformShow(duckdb_libpgquery::PGVariableShowStmt &stmt); + unique_ptr TransformShowSelect(duckdb_libpgquery::PGVariableShowSelectStmt &stmt); + unique_ptr TransformAttach(duckdb_libpgquery::PGAttachStmt &stmt); + unique_ptr TransformDetach(duckdb_libpgquery::PGDetachStmt &stmt); + unique_ptr TransformUse(duckdb_libpgquery::PGUseStmt &stmt); + + unique_ptr TransformPrepare(duckdb_libpgquery::PGPrepareStmt &stmt); + unique_ptr TransformExecute(duckdb_libpgquery::PGExecuteStmt &stmt); + unique_ptr TransformCall(duckdb_libpgquery::PGCallStmt &stmt); + unique_ptr TransformDeallocate(duckdb_libpgquery::PGDeallocateStmt &stmt); + unique_ptr TransformPivotStatement(duckdb_libpgquery::PGSelectStmt &select); + unique_ptr CreatePivotStatement(unique_ptr statement); + PivotColumn TransformPivotColumn(duckdb_libpgquery::PGPivot &pivot); + vector TransformPivotList(duckdb_libpgquery::PGList &list); + + //===--------------------------------------------------------------------===// + // SetStatement Transform + //===--------------------------------------------------------------------===// + unique_ptr TransformSet(duckdb_libpgquery::PGVariableSetStmt &set); + unique_ptr TransformSetVariable(duckdb_libpgquery::PGVariableSetStmt &stmt); + unique_ptr TransformResetVariable(duckdb_libpgquery::PGVariableSetStmt &stmt); + + unique_ptr TransformCheckpoint(duckdb_libpgquery::PGCheckPointStmt &stmt); + unique_ptr TransformLoad(duckdb_libpgquery::PGLoadStmt &stmt); + + //===--------------------------------------------------------------------===// + // Query Node Transform + //===--------------------------------------------------------------------===// + //! Transform a Postgres duckdb_libpgquery::T_PGSelectStmt node into a QueryNode + unique_ptr TransformSelectNode(duckdb_libpgquery::PGSelectStmt &select); + unique_ptr TransformSelectInternal(duckdb_libpgquery::PGSelectStmt &select); + void TransformModifiers(duckdb_libpgquery::PGSelectStmt &stmt, QueryNode &node); + + //===--------------------------------------------------------------------===// + // Expression Transform + //===--------------------------------------------------------------------===// + //! Transform a Postgres boolean expression into an Expression + unique_ptr TransformBoolExpr(duckdb_libpgquery::PGBoolExpr &root); + //! Transform a Postgres case expression into an Expression + unique_ptr TransformCase(duckdb_libpgquery::PGCaseExpr &root); + //! Transform a Postgres type cast into an Expression + unique_ptr TransformTypeCast(duckdb_libpgquery::PGTypeCast &root); + //! Transform a Postgres coalesce into an Expression + unique_ptr TransformCoalesce(duckdb_libpgquery::PGAExpr &root); + //! Transform a Postgres column reference into an Expression + unique_ptr TransformColumnRef(duckdb_libpgquery::PGColumnRef &root); + //! Transform a Postgres constant value into an Expression + unique_ptr TransformValue(duckdb_libpgquery::PGValue val); + //! Transform a Postgres operator into an Expression + unique_ptr TransformAExpr(duckdb_libpgquery::PGAExpr &root); + unique_ptr TransformAExprInternal(duckdb_libpgquery::PGAExpr &root); + //! Transform a Postgres abstract expression into an Expression + unique_ptr TransformExpression(optional_ptr node); + unique_ptr TransformExpression(duckdb_libpgquery::PGNode &node); + //! Transform a Postgres function call into an Expression + unique_ptr TransformFuncCall(duckdb_libpgquery::PGFuncCall &root); + //! Transform a Postgres boolean expression into an Expression + unique_ptr TransformInterval(duckdb_libpgquery::PGIntervalConstant &root); + //! Transform a Postgres lambda node [e.g. (x, y) -> x + y] into a lambda expression + unique_ptr TransformLambda(duckdb_libpgquery::PGLambdaFunction &node); + //! Transform a Postgres array access node (e.g. x[1] or x[1:3]) + unique_ptr TransformArrayAccess(duckdb_libpgquery::PGAIndirection &node); + //! Transform a positional reference (e.g. #1) + unique_ptr TransformPositionalReference(duckdb_libpgquery::PGPositionalReference &node); + unique_ptr TransformStarExpression(duckdb_libpgquery::PGAStar &node); + unique_ptr TransformBooleanTest(duckdb_libpgquery::PGBooleanTest &node); + + //! Transform a Postgres constant value into an Expression + unique_ptr TransformConstant(duckdb_libpgquery::PGAConst &c); + unique_ptr TransformGroupingFunction(duckdb_libpgquery::PGGroupingFunc &n); + unique_ptr TransformResTarget(duckdb_libpgquery::PGResTarget &root); + unique_ptr TransformNullTest(duckdb_libpgquery::PGNullTest &root); + unique_ptr TransformParamRef(duckdb_libpgquery::PGParamRef &node); + unique_ptr TransformNamedArg(duckdb_libpgquery::PGNamedArgExpr &root); + + //! Transform multi assignment reference into an Expression + unique_ptr TransformMultiAssignRef(duckdb_libpgquery::PGMultiAssignRef &root); + + unique_ptr TransformSQLValueFunction(duckdb_libpgquery::PGSQLValueFunction &node); + + unique_ptr TransformSubquery(duckdb_libpgquery::PGSubLink &root); + //===--------------------------------------------------------------------===// + // Constraints transform + //===--------------------------------------------------------------------===// + unique_ptr TransformConstraint(duckdb_libpgquery::PGListCell *cell); + + unique_ptr TransformConstraint(duckdb_libpgquery::PGListCell *cell, ColumnDefinition &column, + idx_t index); + + //===--------------------------------------------------------------------===// + // Update transform + //===--------------------------------------------------------------------===// + unique_ptr TransformUpdateSetInfo(duckdb_libpgquery::PGList *target_list, + duckdb_libpgquery::PGNode *where_clause); + + //===--------------------------------------------------------------------===// + // Index transform + //===--------------------------------------------------------------------===// + vector> TransformIndexParameters(duckdb_libpgquery::PGList &list, + const string &relation_name); + + //===--------------------------------------------------------------------===// + // Collation transform + //===--------------------------------------------------------------------===// + unique_ptr TransformCollateExpr(duckdb_libpgquery::PGCollateClause &collate); + + string TransformCollation(optional_ptr collate); + + ColumnDefinition TransformColumnDefinition(duckdb_libpgquery::PGColumnDef &cdef); + //===--------------------------------------------------------------------===// + // Helpers + //===--------------------------------------------------------------------===// + OnCreateConflict TransformOnConflict(duckdb_libpgquery::PGOnCreateConflict conflict); + string TransformAlias(duckdb_libpgquery::PGAlias *root, vector &column_name_alias); + vector TransformStringList(duckdb_libpgquery::PGList *list); + void TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map, + vector> &materialized_ctes); + static unique_ptr TransformMaterializedCTE(unique_ptr root, + vector> &materialized_ctes); + unique_ptr TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &node, + CommonTableExpressionInfo &info); + + unique_ptr TransformUnaryOperator(const string &op, unique_ptr child); + unique_ptr TransformBinaryOperator(string op, unique_ptr left, + unique_ptr right); + static bool ConstructConstantFromExpression(const ParsedExpression &expr, Value &value); + //===--------------------------------------------------------------------===// + // TableRef transform + //===--------------------------------------------------------------------===// + //! Transform a Postgres node into a TableRef + unique_ptr TransformTableRefNode(duckdb_libpgquery::PGNode &n); + //! Transform a Postgres FROM clause into a TableRef + unique_ptr TransformFrom(optional_ptr root); + //! Transform a Postgres table reference into a TableRef + unique_ptr TransformRangeVar(duckdb_libpgquery::PGRangeVar &root); + //! Transform a Postgres table-producing function into a TableRef + unique_ptr TransformRangeFunction(duckdb_libpgquery::PGRangeFunction &root); + //! Transform a Postgres join node into a TableRef + unique_ptr TransformJoin(duckdb_libpgquery::PGJoinExpr &root); + //! Transform a Postgres pivot node into a TableRef + unique_ptr TransformPivot(duckdb_libpgquery::PGPivotExpr &root); + //! Transform a table producing subquery into a TableRef + unique_ptr TransformRangeSubselect(duckdb_libpgquery::PGRangeSubselect &root); + //! Transform a VALUES list into a set of expressions + unique_ptr TransformValuesList(duckdb_libpgquery::PGList *list); + + //! Transform a range var into a (schema) qualified name + QualifiedName TransformQualifiedName(duckdb_libpgquery::PGRangeVar &root); + + //! Transform a Postgres TypeName string into a LogicalType + LogicalType TransformTypeName(duckdb_libpgquery::PGTypeName &name); + + //! Transform a Postgres GROUP BY expression into a list of Expression + bool TransformGroupBy(optional_ptr group, SelectNode &result); + void TransformGroupByNode(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, SelectNode &result, + vector &result_sets); + void AddGroupByExpression(unique_ptr expression, GroupingExpressionMap &map, GroupByNode &result, + vector &result_set); + void TransformGroupByExpression(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, GroupByNode &result, + vector &result_set); + //! Transform a Postgres ORDER BY expression into an OrderByDescription + bool TransformOrderBy(duckdb_libpgquery::PGList *order, vector &result); + + //! Transform a Postgres SELECT clause into a list of Expressions + void TransformExpressionList(duckdb_libpgquery::PGList &list, vector> &result); + + //! Transform a Postgres PARTITION BY/ORDER BY specification into lists of expressions + void TransformWindowDef(duckdb_libpgquery::PGWindowDef &window_spec, WindowExpression &expr, + const char *window_name = nullptr); + //! Transform a Postgres window frame specification into frame expressions + void TransformWindowFrame(duckdb_libpgquery::PGWindowDef &window_spec, WindowExpression &expr); + + unique_ptr TransformSampleOptions(optional_ptr options); + //! Returns true if an expression is only a star (i.e. "*", without any other decorators) + bool ExpressionIsEmptyStar(ParsedExpression &expr); + + OnEntryNotFound TransformOnEntryNotFound(bool missing_ok); + + Vector PGListToVector(optional_ptr column_list, idx_t &size); + vector TransformConflictTarget(duckdb_libpgquery::PGList &list); + +private: + //! Current stack depth + idx_t stack_depth; + + void InitializeStackCheck(); + StackChecker StackCheck(idx_t extra_stack = 1); + +public: + template + static T &PGCast(duckdb_libpgquery::PGNode &node) { + return reinterpret_cast(node); + } + template + static optional_ptr PGPointerCast(void *ptr) { + return optional_ptr(reinterpret_cast(ptr)); + } +}; + +vector ReadPgListToString(duckdb_libpgquery::PGList *column_list); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bind_context.hpp b/src/duckdb/src/include/duckdb/planner/bind_context.hpp new file mode 100644 index 00000000..a0f5461c --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bind_context.hpp @@ -0,0 +1,166 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bind_context.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/qualified_name_set.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/table_binding.hpp" + +namespace duckdb { +class Binder; +class LogicalGet; +class BoundQueryNode; + +class StarExpression; + +class TableCatalogEntry; +class TableFunctionCatalogEntry; + +struct UsingColumnSet { + string primary_binding; + unordered_set bindings; +}; + +//! The BindContext object keeps track of all the tables and columns that are +//! encountered during the binding process. +class BindContext { +public: + //! Keep track of recursive CTE references + case_insensitive_map_t> cte_references; + +public: + //! Given a column name, find the matching table it belongs to. Throws an + //! exception if no table has a column of the given name. + string GetMatchingBinding(const string &column_name); + //! Like GetMatchingBinding, but instead of throwing an error if multiple tables have the same binding it will + //! return a list of all the matching ones + unordered_set GetMatchingBindings(const string &column_name); + //! Like GetMatchingBindings, but returns the top 3 most similar bindings (in levenshtein distance) instead of the + //! matching ones + vector GetSimilarBindings(const string &column_name); + + optional_ptr GetCTEBinding(const string &ctename); + //! Binds a column expression to the base table. Returns the bound expression + //! or throws an exception if the column could not be bound. + BindResult BindColumn(ColumnRefExpression &colref, idx_t depth); + string BindColumn(PositionalReferenceExpression &ref, string &table_name, string &column_name); + unique_ptr PositionToColumn(PositionalReferenceExpression &ref); + + unique_ptr ExpandGeneratedColumn(const string &table_name, const string &column_name); + + unique_ptr CreateColumnReference(const string &table_name, const string &column_name); + unique_ptr CreateColumnReference(const string &schema_name, const string &table_name, + const string &column_name); + unique_ptr CreateColumnReference(const string &catalog_name, const string &schema_name, + const string &table_name, const string &column_name); + + //! Generate column expressions for all columns that are present in the + //! referenced tables. This is used to resolve the * expression in a + //! selection list. + void GenerateAllColumnExpressions(StarExpression &expr, vector> &new_select_list); + //! Check if the given (binding, column_name) is in the exclusion/replacement lists. + //! Returns true if it is in one of these lists, and should therefore be skipped. + bool CheckExclusionList(StarExpression &expr, const string &column_name, + vector> &new_select_list, + case_insensitive_set_t &excluded_columns); + + const vector> &GetBindingsList() { + return bindings_list; + } + + void GetTypesAndNames(vector &result_names, vector &result_types); + + //! Adds a base table with the given alias to the BindContext. + void AddBaseTable(idx_t index, const string &alias, const vector &names, const vector &types, + vector &bound_column_ids, StandardEntry *entry, bool add_row_id = true); + //! Adds a call to a table function with the given alias to the BindContext. + void AddTableFunction(idx_t index, const string &alias, const vector &names, + const vector &types, vector &bound_column_ids, StandardEntry *entry); + //! Adds a table view with a given alias to the BindContext. + void AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, ViewCatalogEntry *view); + //! Adds a subquery with a given alias to the BindContext. + void AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery); + //! Adds a subquery with a given alias to the BindContext. + void AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery); + //! Adds a binding to a catalog entry with a given alias to the BindContext. + void AddEntryBinding(idx_t index, const string &alias, const vector &names, + const vector &types, StandardEntry &entry); + //! Adds a base table with the given alias to the BindContext. + void AddGenericBinding(idx_t index, const string &alias, const vector &names, + const vector &types); + + //! Adds a base table with the given alias to the CTE BindContext. + //! We need this to correctly bind recursive CTEs with multiple references. + void AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types); + + //! Add an implicit join condition (e.g. USING (x)) + void AddUsingBinding(const string &column_name, UsingColumnSet &set); + + void AddUsingBindingSet(unique_ptr set); + + //! Returns any using column set for the given column name, or nullptr if there is none. On conflict (multiple using + //! column sets with the same name) throw an exception. + optional_ptr GetUsingBinding(const string &column_name); + //! Returns any using column set for the given column name, or nullptr if there is none + optional_ptr GetUsingBinding(const string &column_name, const string &binding_name); + //! Erase a using binding from the set of using bindings + void RemoveUsingBinding(const string &column_name, UsingColumnSet &set); + //! Transfer a using binding from one bind context to this bind context + void TransferUsingBinding(BindContext ¤t_context, optional_ptr current_set, + UsingColumnSet &new_set, const string &binding, const string &using_column); + + //! Fetch the actual column name from the given binding, or throws if none exists + //! This can be different from "column_name" because of case insensitivity + //! (e.g. "column_name" might return "COLUMN_NAME") + string GetActualColumnName(const string &binding, const string &column_name); + + case_insensitive_map_t> GetCTEBindings() { + return cte_bindings; + } + void SetCTEBindings(case_insensitive_map_t> bindings) { + cte_bindings = bindings; + } + + //! Alias a set of column names for the specified table, using the original names if there are not enough aliases + //! specified. + static vector AliasColumnNames(const string &table_name, const vector &names, + const vector &column_aliases); + + //! Add all the bindings from a BindContext to this BindContext. The other BindContext is destroyed in the process. + void AddContext(BindContext other); + //! For semi and anti joins we remove the binding context of the right table after binding the condition. + void RemoveContext(vector> &other_bindings_list); + + //! Gets a binding of the specified name. Returns a nullptr and sets the out_error if the binding could not be + //! found. + optional_ptr GetBinding(const string &name, string &out_error); + +private: + void AddBinding(const string &alias, unique_ptr binding); + +private: + //! The set of bindings + case_insensitive_map_t> bindings; + //! The list of bindings in insertion order + vector> bindings_list; + //! The set of columns used in USING join conditions + case_insensitive_map_t> using_columns; + //! Using column sets + vector> using_column_sets; + + //! The set of CTE bindings + case_insensitive_map_t> cte_bindings; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp new file mode 100644 index 00000000..90a7ae94 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -0,0 +1,375 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/tokens.hpp" +#include "duckdb/planner/bind_context.hpp" +#include "duckdb/planner/bound_statement.hpp" +#include "duckdb/planner/bound_tokens.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/joinside.hpp" +#include "duckdb/common/reference_map.hpp" + +namespace duckdb { +class BoundResultModifier; +class BoundSelectNode; +class ClientContext; +class ExpressionBinder; +class LimitModifier; +class OrderBinder; +class TableCatalogEntry; +class ViewCatalogEntry; +class TableMacroCatalogEntry; +class UpdateSetInfo; +class LogicalProjection; + +class ColumnList; +class ExternalDependency; +class TableFunction; +class TableStorageInfo; + +struct CreateInfo; +struct BoundCreateTableInfo; +struct BoundCreateFunctionInfo; +struct CommonTableExpressionInfo; +struct BoundParameterMap; + +enum class BindingMode : uint8_t { STANDARD_BINDING, EXTRACT_NAMES }; + +struct CorrelatedColumnInfo { + ColumnBinding binding; + LogicalType type; + string name; + idx_t depth; + + CorrelatedColumnInfo(ColumnBinding binding, LogicalType type_p, string name_p, idx_t depth) + : binding(binding), type(std::move(type_p)), name(std::move(name_p)), depth(depth) { + } + explicit CorrelatedColumnInfo(BoundColumnRefExpression &expr) + : CorrelatedColumnInfo(expr.binding, expr.return_type, expr.GetName(), expr.depth) { + } + + bool operator==(const CorrelatedColumnInfo &rhs) const { + return binding == rhs.binding; + } +}; + +//! Bind the parsed query tree to the actual columns present in the catalog. +/*! + The binder is responsible for binding tables and columns to actual physical + tables and columns in the catalog. In the process, it also resolves types of + all expressions. +*/ +class Binder : public std::enable_shared_from_this { + friend class ExpressionBinder; + friend class RecursiveDependentJoinPlanner; + +public: + DUCKDB_API static shared_ptr CreateBinder(ClientContext &context, optional_ptr parent = nullptr, + bool inherit_ctes = true); + + //! The client context + ClientContext &context; + //! A mapping of names to common table expressions + case_insensitive_map_t> CTE_bindings; // NOLINT + //! The CTEs that have already been bound + reference_set_t bound_ctes; + //! The bind context + BindContext bind_context; + //! The set of correlated columns bound by this binder (FIXME: this should probably be an unordered_set and not a + //! vector) + vector correlated_columns; + //! The set of parameter expressions bound by this binder + optional_ptr parameters; + //! Statement properties + StatementProperties properties; + //! The alias for the currently processing subquery, if it exists + string alias; + //! Macro parameter bindings (if any) + optional_ptr macro_binding; + //! The intermediate lambda bindings to bind nested lambdas (if any) + optional_ptr> lambda_bindings; + +public: + DUCKDB_API BoundStatement Bind(SQLStatement &statement); + DUCKDB_API BoundStatement Bind(QueryNode &node); + + unique_ptr BindCreateTableInfo(unique_ptr info); + unique_ptr BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema); + + vector> BindCreateIndexExpressions(TableCatalogEntry &table, CreateIndexInfo &info); + + void BindCreateViewInfo(CreateViewInfo &base); + SchemaCatalogEntry &BindSchema(CreateInfo &info); + SchemaCatalogEntry &BindCreateFunctionInfo(CreateInfo &info); + + //! Check usage, and cast named parameters to their types + static void BindNamedParameters(named_parameter_type_map_t &types, named_parameter_map_t &values, + QueryErrorContext &error_context, string &func_name); + + unique_ptr Bind(TableRef &ref); + unique_ptr CreatePlan(BoundTableRef &ref); + + //! Generates an unused index for a table + idx_t GenerateTableIndex(); + + //! Add a common table expression to the binder + void AddCTE(const string &name, CommonTableExpressionInfo &cte); + //! Find a common table expression by name; returns nullptr if none exists + optional_ptr FindCTE(const string &name, bool skip = false); + + bool CTEIsAlreadyBound(CommonTableExpressionInfo &cte); + + //! Add the view to the set of currently bound views - used for detecting recursive view definitions + void AddBoundView(ViewCatalogEntry &view); + + void PushExpressionBinder(ExpressionBinder &binder); + void PopExpressionBinder(); + void SetActiveBinder(ExpressionBinder &binder); + ExpressionBinder &GetActiveBinder(); + bool HasActiveBinder(); + + vector> &GetActiveBinders(); + + void MergeCorrelatedColumns(vector &other); + //! Add a correlated column to this binder (if it does not exist) + void AddCorrelatedColumn(const CorrelatedColumnInfo &info); + + string FormatError(ParsedExpression &expr_context, const string &message); + string FormatError(TableRef &ref_context, const string &message); + + string FormatErrorRecursive(idx_t query_location, const string &message, vector &values); + template + string FormatErrorRecursive(idx_t query_location, const string &msg, vector &values, T param, + ARGS... params) { + values.push_back(ExceptionFormatValue::CreateFormatValue(param)); + return FormatErrorRecursive(query_location, msg, values, params...); + } + + template + string FormatError(idx_t query_location, const string &msg, ARGS... params) { + vector values; + return FormatErrorRecursive(query_location, msg, values, params...); + } + + unique_ptr BindUpdateSet(LogicalOperator &op, unique_ptr root, + UpdateSetInfo &set_info, TableCatalogEntry &table, + vector &columns); + void BindDoUpdateSetExpressions(const string &table_alias, LogicalInsert &insert, UpdateSetInfo &set_info, + TableCatalogEntry &table, TableStorageInfo &storage_info); + void BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &table, InsertStatement &stmt); + + static void BindSchemaOrCatalog(ClientContext &context, string &catalog, string &schema); + static void BindLogicalType(ClientContext &context, LogicalType &type, optional_ptr catalog = nullptr, + const string &schema = INVALID_SCHEMA); + + bool HasMatchingBinding(const string &table_name, const string &column_name, string &error_message); + bool HasMatchingBinding(const string &schema_name, const string &table_name, const string &column_name, + string &error_message); + bool HasMatchingBinding(const string &catalog_name, const string &schema_name, const string &table_name, + const string &column_name, string &error_message); + + void SetBindingMode(BindingMode mode); + BindingMode GetBindingMode(); + void AddTableName(string table_name); + const unordered_set &GetTableNames(); + optional_ptr GetRootStatement() { + return root_statement; + } + + void SetCanContainNulls(bool can_contain_nulls); + +private: + //! The parent binder (if any) + shared_ptr parent; + //! The vector of active binders + vector> active_binders; + //! The count of bound_tables + idx_t bound_tables; + //! Whether or not the binder has any unplanned dependent joins that still need to be planned/flattened + bool has_unplanned_dependent_joins = false; + //! Whether or not outside dependent joins have been planned and flattened + bool is_outside_flattened = true; + //! Whether CTEs should reference the parent binder (if it exists) + bool inherit_ctes = true; + //! Whether or not the binder can contain NULLs as the root of expressions + bool can_contain_nulls = false; + //! The root statement of the query that is currently being parsed + optional_ptr root_statement; + //! Binding mode + BindingMode mode = BindingMode::STANDARD_BINDING; + //! Table names extracted for BindingMode::EXTRACT_NAMES + unordered_set table_names; + //! The set of bound views + reference_set_t bound_views; + +private: + //! Get the root binder (binder with no parent) + Binder *GetRootBinder(); + //! Determine the depth of the binder + idx_t GetBinderDepth() const; + //! Bind the expressions of generated columns to check for errors + void BindGeneratedColumns(BoundCreateTableInfo &info); + //! Bind the default values of the columns of a table + void BindDefaultValues(const ColumnList &columns, vector> &bound_defaults); + //! Bind a limit value (LIMIT or OFFSET) + unique_ptr BindDelimiter(ClientContext &context, OrderBinder &order_binder, + unique_ptr delimiter, const LogicalType &type, + Value &delimiter_value); + + //! Move correlated expressions from the child binder to this binder + void MoveCorrelatedExpressions(Binder &other); + + //! Tries to bind the table name with replacement scans + unique_ptr BindWithReplacementScan(ClientContext &context, const string &table_name, + BaseTableRef &ref); + + BoundStatement Bind(SelectStatement &stmt); + BoundStatement Bind(InsertStatement &stmt); + BoundStatement Bind(CopyStatement &stmt); + BoundStatement Bind(DeleteStatement &stmt); + BoundStatement Bind(UpdateStatement &stmt); + BoundStatement Bind(CreateStatement &stmt); + BoundStatement Bind(DropStatement &stmt); + BoundStatement Bind(AlterStatement &stmt); + BoundStatement Bind(PrepareStatement &stmt); + BoundStatement Bind(ExecuteStatement &stmt); + BoundStatement Bind(TransactionStatement &stmt); + BoundStatement Bind(PragmaStatement &stmt); + BoundStatement Bind(ExplainStatement &stmt); + BoundStatement Bind(VacuumStatement &stmt); + BoundStatement Bind(RelationStatement &stmt); + BoundStatement Bind(ShowStatement &stmt); + BoundStatement Bind(CallStatement &stmt); + BoundStatement Bind(ExportStatement &stmt); + BoundStatement Bind(ExtensionStatement &stmt); + BoundStatement Bind(SetStatement &stmt); + BoundStatement Bind(SetVariableStatement &stmt); + BoundStatement Bind(ResetVariableStatement &stmt); + BoundStatement Bind(LoadStatement &stmt); + BoundStatement Bind(LogicalPlanStatement &stmt); + BoundStatement Bind(AttachStatement &stmt); + BoundStatement Bind(DetachStatement &stmt); + + BoundStatement BindReturning(vector> returning_list, TableCatalogEntry &table, + const string &alias, idx_t update_table_index, + unique_ptr child_operator, BoundStatement result); + + unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, idx_t depth); + + unique_ptr BindNode(SelectNode &node); + unique_ptr BindNode(SetOperationNode &node); + unique_ptr BindNode(RecursiveCTENode &node); + unique_ptr BindNode(CTENode &node); + unique_ptr BindNode(QueryNode &node); + + unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); + unique_ptr CreatePlan(BoundRecursiveCTENode &node); + unique_ptr CreatePlan(BoundCTENode &node); + unique_ptr CreatePlan(BoundSelectNode &statement); + unique_ptr CreatePlan(BoundSetOperationNode &node); + unique_ptr CreatePlan(BoundQueryNode &node); + + unique_ptr Bind(BaseTableRef &ref); + unique_ptr Bind(JoinRef &ref); + unique_ptr Bind(SubqueryRef &ref, optional_ptr cte = nullptr); + unique_ptr Bind(TableFunctionRef &ref); + unique_ptr Bind(EmptyTableRef &ref); + unique_ptr Bind(ExpressionListRef &ref); + unique_ptr Bind(PivotRef &expr); + + unique_ptr BindPivot(PivotRef &expr, vector> all_columns); + unique_ptr BindUnpivot(Binder &child_binder, PivotRef &expr, + vector> all_columns, + unique_ptr &where_clause); + unique_ptr BindBoundPivot(PivotRef &expr); + + bool BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, + vector> &expressions, vector &arguments, + vector ¶meters, named_parameter_map_t &named_parameters, + unique_ptr &subquery, string &error); + bool BindTableInTableOutFunction(vector> &expressions, + unique_ptr &subquery, string &error); + unique_ptr BindTableFunction(TableFunction &function, vector parameters); + unique_ptr + BindTableFunctionInternal(TableFunction &table_function, const string &function_name, vector parameters, + named_parameter_map_t named_parameters, vector input_table_types, + vector input_table_names, const vector &column_name_alias, + unique_ptr external_dependency); + + unique_ptr CreatePlan(BoundBaseTableRef &ref); + unique_ptr CreatePlan(BoundJoinRef &ref); + unique_ptr CreatePlan(BoundSubqueryRef &ref); + unique_ptr CreatePlan(BoundTableFunction &ref); + unique_ptr CreatePlan(BoundEmptyTableRef &ref); + unique_ptr CreatePlan(BoundExpressionListRef &ref); + unique_ptr CreatePlan(BoundCTERef &ref); + unique_ptr CreatePlan(BoundPivotRef &ref); + + BoundStatement BindCopyTo(CopyStatement &stmt); + BoundStatement BindCopyFrom(CopyStatement &stmt); + + void BindModifiers(OrderBinder &order_binder, QueryNode &statement, BoundQueryNode &result); + void BindModifierTypes(BoundQueryNode &result, const vector &sql_types, idx_t projection_index); + + BoundStatement BindSummarize(ShowStatement &stmt); + unique_ptr BindLimit(OrderBinder &order_binder, LimitModifier &limit_mod); + unique_ptr BindLimitPercent(OrderBinder &order_binder, LimitPercentModifier &limit_mod); + unique_ptr BindOrderExpression(OrderBinder &order_binder, unique_ptr expr); + + unique_ptr PlanFilter(unique_ptr condition, unique_ptr root); + + void PlanSubqueries(unique_ptr &expr, unique_ptr &root); + unique_ptr PlanSubquery(BoundSubqueryExpression &expr, unique_ptr &root); + unique_ptr PlanLateralJoin(unique_ptr left, unique_ptr right, + vector &correlated_columns, + JoinType join_type = JoinType::INNER, + unique_ptr condition = nullptr); + + unique_ptr CastLogicalOperatorToTypes(vector &source_types, + vector &target_types, + unique_ptr op); + + string FindBinding(const string &using_column, const string &join_side); + bool TryFindBinding(const string &using_column, const string &join_side, string &result); + + void AddUsingBindingSet(unique_ptr set); + string RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, + const string &column_name, const string &join_side); + + void AddCTEMap(CommonTableExpressionMap &cte_map); + + void ExpandStarExpressions(vector> &select_list, + vector> &new_select_list); + void ExpandStarExpression(unique_ptr expr, vector> &new_select_list); + bool FindStarExpression(unique_ptr &expr, StarExpression **star, bool is_root, bool in_columns); + void ReplaceStarExpression(unique_ptr &expr, unique_ptr &replacement); + void BindWhereStarExpression(unique_ptr &expr); + + //! If only a schema name is provided (e.g. "a.b") then figure out if "a" is a schema or a catalog name + void BindSchemaOrCatalog(string &catalog_name, string &schema_name); + SchemaCatalogEntry &BindCreateSchema(CreateInfo &info); + + unique_ptr BindSelectNode(SelectNode &statement, unique_ptr from_table); + +public: + // This should really be a private constructor, but make_shared does not allow it... + // If you are thinking about calling this, you should probably call Binder::CreateBinder + Binder(bool i_know_what_i_am_doing, ClientContext &context, shared_ptr parent, bool inherit_ctes); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_constraint.hpp b/src/duckdb/src/include/duckdb/planner/bound_constraint.hpp new file mode 100644 index 00000000..c8205e48 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bound_constraint.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bound_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { +//! Bound equivalent of Constraint +class BoundConstraint { +public: + explicit BoundConstraint(ConstraintType type) : type(type) {}; + virtual ~BoundConstraint() { + } + + ConstraintType type; + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast constraint to type - bound constraint type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast constraint to type - bound constraint type mismatch"); + } + return reinterpret_cast(*this); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_parameter_map.hpp b/src/duckdb/src/include/duckdb/planner/bound_parameter_map.hpp new file mode 100644 index 00000000..ab5ef410 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bound_parameter_map.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bound_parameter_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/expression/bound_parameter_data.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +class ParameterExpression; +class BoundParameterExpression; + +using bound_parameter_map_t = case_insensitive_map_t>; + +struct BoundParameterMap { +public: + explicit BoundParameterMap(case_insensitive_map_t ¶meter_data); + +public: + LogicalType GetReturnType(const string &identifier); + + bound_parameter_map_t *GetParametersPtr(); + + const bound_parameter_map_t &GetParameters(); + + const case_insensitive_map_t &GetParameterData(); + + unique_ptr BindParameterExpression(ParameterExpression &expr); + +private: + shared_ptr CreateOrGetData(const string &identifier); + void CreateNewParameter(const string &id, const shared_ptr ¶m_data); + +private: + bound_parameter_map_t parameters; + // Pre-provided parameter data if populated + case_insensitive_map_t ¶meter_data; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp new file mode 100644 index 00000000..cd5a78b6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bound_query_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +//! Bound equivalent of QueryNode +class BoundQueryNode { +public: + explicit BoundQueryNode(QueryNodeType type) : type(type) { + } + virtual ~BoundQueryNode() { + } + + //! The type of the query node, either SetOperation or Select + QueryNodeType type; + //! The result modifiers that should be applied to this query node + vector> modifiers; + + //! The names returned by this QueryNode. + vector names; + //! The types returned by this QueryNode. + vector types; + +public: + virtual idx_t GetRootIndex() = 0; + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast bound query node to type - query node type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast bound query node to type - query node type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_result_modifier.hpp b/src/duckdb/src/include/duckdb/planner/bound_result_modifier.hpp new file mode 100644 index 00000000..a44922d7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bound_result_modifier.hpp @@ -0,0 +1,135 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bound_result_modifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/limits.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/planner/bound_statement.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +//! A ResultModifier +class BoundResultModifier { +public: + explicit BoundResultModifier(ResultModifierType type); + virtual ~BoundResultModifier(); + + ResultModifierType type; + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast result modifier to type - result modifier type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast result modifier to type - result modifier type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +struct BoundOrderByNode { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::ORDER_MODIFIER; + +public: + BoundOrderByNode(OrderType type, OrderByNullType null_order, unique_ptr expression); + BoundOrderByNode(OrderType type, OrderByNullType null_order, unique_ptr expression, + unique_ptr stats); + + OrderType type; + OrderByNullType null_order; + unique_ptr expression; + unique_ptr stats; + +public: + BoundOrderByNode Copy() const; + bool Equals(const BoundOrderByNode &other) const; + string ToString() const; + + void Serialize(Serializer &serializer) const; + static BoundOrderByNode Deserialize(Deserializer &deserializer); +}; + +class BoundLimitModifier : public BoundResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::LIMIT_MODIFIER; + +public: + BoundLimitModifier(); + + //! LIMIT + int64_t limit_val = NumericLimits::Maximum(); + //! OFFSET + int64_t offset_val = 0; + //! Expression in case limit is not constant + unique_ptr limit; + //! Expression in case limit is not constant + unique_ptr offset; +}; + +class BoundOrderModifier : public BoundResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::ORDER_MODIFIER; + +public: + BoundOrderModifier(); + + //! List of order nodes + vector orders; + + unique_ptr Copy() const; + static bool Equals(const BoundOrderModifier &left, const BoundOrderModifier &right); + static bool Equals(const unique_ptr &left, const unique_ptr &right); + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +enum class DistinctType : uint8_t { DISTINCT = 0, DISTINCT_ON = 1 }; + +class BoundDistinctModifier : public BoundResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::DISTINCT_MODIFIER; + +public: + BoundDistinctModifier(); + + //! Whether or not this is a DISTINCT or DISTINCT ON + DistinctType distinct_type; + //! list of distinct on targets + vector> target_distincts; +}; + +class BoundLimitPercentModifier : public BoundResultModifier { +public: + static constexpr const ResultModifierType TYPE = ResultModifierType::LIMIT_PERCENT_MODIFIER; + +public: + BoundLimitPercentModifier(); + + //! LIMIT % + double limit_percent = 100.0; + //! OFFSET + int64_t offset_val = 0; + //! Expression in case limit is not constant + unique_ptr limit; + //! Expression in case limit is not constant + unique_ptr offset; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp new file mode 100644 index 00000000..bb1f7bfe --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bound_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class LogicalOperator; +struct LogicalType; + +struct BoundStatement { + unique_ptr plan; + vector types; + vector names; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp b/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp new file mode 100644 index 00000000..0a831c54 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bound_tableref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/tableref_type.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" + +namespace duckdb { + +class BoundTableRef { +public: + explicit BoundTableRef(TableReferenceType type) : type(type) { + } + virtual ~BoundTableRef() { + } + + //! The type of table reference + TableReferenceType type; + //! The sample options (if any) + unique_ptr sample; + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); + } + return reinterpret_cast(*this); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp new file mode 100644 index 00000000..b864b25a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/bound_tokens.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Query Node +//===--------------------------------------------------------------------===// +class BoundQueryNode; +class BoundSelectNode; +class BoundSetOperationNode; +class BoundRecursiveCTENode; +class BoundCTENode; + +//===--------------------------------------------------------------------===// +// Expressions +//===--------------------------------------------------------------------===// +class Expression; + +class BoundAggregateExpression; +class BoundBetweenExpression; +class BoundCaseExpression; +class BoundCastExpression; +class BoundColumnRefExpression; +class BoundComparisonExpression; +class BoundConjunctionExpression; +class BoundConstantExpression; +class BoundDefaultExpression; +class BoundFunctionExpression; +class BoundOperatorExpression; +class BoundParameterExpression; +class BoundReferenceExpression; +class BoundSubqueryExpression; +class BoundUnnestExpression; +class BoundWindowExpression; + +//===--------------------------------------------------------------------===// +// TableRefs +//===--------------------------------------------------------------------===// +class BoundTableRef; + +class BoundBaseTableRef; +class BoundJoinRef; +class BoundSubqueryRef; +class BoundTableFunction; +class BoundEmptyTableRef; +class BoundExpressionListRef; +class BoundCTERef; +class BoundPivotRef; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/column_binding.hpp b/src/duckdb/src/include/duckdb/planner/column_binding.hpp new file mode 100644 index 00000000..0972e1b4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/column_binding.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/column_binding.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/to_string.hpp" + +#include + +namespace duckdb { +class Serializer; +class Deserializer; + +struct ColumnBinding { + idx_t table_index; + // This index is local to a Binding, and has no meaning outside of the context of the Binding that created it + idx_t column_index; + + ColumnBinding() : table_index(DConstants::INVALID_INDEX), column_index(DConstants::INVALID_INDEX) { + } + ColumnBinding(idx_t table, idx_t column) : table_index(table), column_index(column) { + } + + string ToString() const { + return "#[" + to_string(table_index) + "." + to_string(column_index) + "]"; + } + + bool operator==(const ColumnBinding &rhs) const { + return table_index == rhs.table_index && column_index == rhs.column_index; + } + + bool operator!=(const ColumnBinding &rhs) const { + return !(*this == rhs); + } + + void Serialize(Serializer &serializer) const; + static ColumnBinding Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/column_binding_map.hpp b/src/duckdb/src/include/duckdb/planner/column_binding_map.hpp new file mode 100644 index 00000000..8c1b84f0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/column_binding_map.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/column_binding_map.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/column_binding.hpp" + +namespace duckdb { + +struct ColumnBindingHashFunction { + uint64_t operator()(const ColumnBinding &a) const { + return CombineHash(Hash(a.table_index), Hash(a.column_index)); + } +}; + +struct ColumnBindingEquality { + bool operator()(const ColumnBinding &a, const ColumnBinding &b) const { + return a == b; + } +}; + +template +using column_binding_map_t = unordered_map; + +using column_binding_set_t = unordered_set; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/constraints/bound_check_constraint.hpp b/src/duckdb/src/include/duckdb/planner/constraints/bound_check_constraint.hpp new file mode 100644 index 00000000..8dd34328 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/constraints/bound_check_constraint.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/constraints/bound_check_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/bound_constraint.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/index_map.hpp" + +namespace duckdb { + +//! The CheckConstraint contains an expression that must evaluate to TRUE for +//! every row in a table +class BoundCheckConstraint : public BoundConstraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::CHECK; + +public: + BoundCheckConstraint() : BoundConstraint(ConstraintType::CHECK) { + } + + //! The expression + unique_ptr expression; + //! The columns used by the CHECK constraint + physical_index_set_t bound_columns; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/constraints/bound_foreign_key_constraint.hpp b/src/duckdb/src/include/duckdb/planner/constraints/bound_foreign_key_constraint.hpp new file mode 100644 index 00000000..6f90dee5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/constraints/bound_foreign_key_constraint.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/constraints/bound_foreign_key_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/bound_constraint.hpp" +#include "duckdb/common/index_map.hpp" + +namespace duckdb { + +class BoundForeignKeyConstraint : public BoundConstraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::FOREIGN_KEY; + +public: + BoundForeignKeyConstraint(ForeignKeyInfo info_p, physical_index_set_t pk_key_set_p, + physical_index_set_t fk_key_set_p) + : BoundConstraint(ConstraintType::FOREIGN_KEY), info(std::move(info_p)), pk_key_set(std::move(pk_key_set_p)), + fk_key_set(std::move(fk_key_set_p)) { +#ifdef DEBUG + D_ASSERT(info.pk_keys.size() == pk_key_set.size()); + for (auto &key : info.pk_keys) { + D_ASSERT(pk_key_set.find(key) != pk_key_set.end()); + } + D_ASSERT(info.fk_keys.size() == fk_key_set.size()); + for (auto &key : info.fk_keys) { + D_ASSERT(fk_key_set.find(key) != fk_key_set.end()); + } +#endif + } + + ForeignKeyInfo info; + //! The same keys but stored as an unordered set + physical_index_set_t pk_key_set; + //! The same keys but stored as an unordered set + physical_index_set_t fk_key_set; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/constraints/bound_not_null_constraint.hpp b/src/duckdb/src/include/duckdb/planner/constraints/bound_not_null_constraint.hpp new file mode 100644 index 00000000..915517cf --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/constraints/bound_not_null_constraint.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/constraints/bound_not_null_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_constraint.hpp" + +namespace duckdb { + +class BoundNotNullConstraint : public BoundConstraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::NOT_NULL; + +public: + explicit BoundNotNullConstraint(PhysicalIndex index) : BoundConstraint(ConstraintType::NOT_NULL), index(index) { + } + + //! Column index this constraint pertains to + PhysicalIndex index; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp b/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp new file mode 100644 index 00000000..4c7468d6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/constraints/bound_unique_constraint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/bound_constraint.hpp" +#include "duckdb/common/index_map.hpp" + +namespace duckdb { + +class BoundUniqueConstraint : public BoundConstraint { +public: + static constexpr const ConstraintType TYPE = ConstraintType::UNIQUE; + +public: + BoundUniqueConstraint(vector keys, logical_index_set_t key_set, bool is_primary_key) + : BoundConstraint(ConstraintType::UNIQUE), keys(std::move(keys)), key_set(std::move(key_set)), + is_primary_key(is_primary_key) { +#ifdef DEBUG + D_ASSERT(this->keys.size() == this->key_set.size()); + for (auto &key : this->keys) { + D_ASSERT(this->key_set.find(key) != this->key_set.end()); + } +#endif + } + + //! The keys that define the unique constraint + vector keys; + //! The same keys but stored as an unordered set + logical_index_set_t key_set; + //! Whether or not the unique constraint is a primary key + bool is_primary_key; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/constraints/list.hpp b/src/duckdb/src/include/duckdb/planner/constraints/list.hpp new file mode 100644 index 00000000..396d0177 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/constraints/list.hpp @@ -0,0 +1,4 @@ +#include "duckdb/planner/constraints/bound_check_constraint.hpp" +#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" +#include "duckdb/planner/constraints/bound_unique_constraint.hpp" +#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/expression.hpp b/src/duckdb/src/include/duckdb/planner/expression.hpp new file mode 100644 index 00000000..208e9b5c --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression.hpp @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/base_expression.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { +class BaseStatistics; +class ClientContext; + +//! The Expression class represents a bound Expression with a return type +class Expression : public BaseExpression { +public: + Expression(ExpressionType type, ExpressionClass expression_class, LogicalType return_type); + ~Expression() override; + + //! The return type of the expression + LogicalType return_type; + //! Expression statistics (if any) - ONLY USED FOR VERIFICATION + unique_ptr verification_stats; + +public: + bool IsAggregate() const override; + bool IsWindow() const override; + bool HasSubquery() const override; + bool IsScalar() const override; + bool HasParameter() const override; + virtual bool HasSideEffects() const; + virtual bool PropagatesNullValues() const; + virtual bool IsFoldable() const; + + hash_t Hash() const override; + + bool Equals(const BaseExpression &other) const override { + if (!BaseExpression::Equals(other)) { + return false; + } + return return_type == ((Expression &)other).return_type; + } + static bool Equals(const Expression &left, const Expression &right) { + return left.Equals(right); + } + static bool Equals(const unique_ptr &left, const unique_ptr &right); + static bool ListEquals(const vector> &left, const vector> &right); + //! Create a copy of this expression + virtual unique_ptr Copy() = 0; + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + //! Copy base Expression properties from another expression to this one, + //! used in Copy method + void CopyProperties(Expression &other) { + type = other.type; + expression_class = other.expression_class; + alias = other.alias; + return_type = other.return_type; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp new file mode 100644 index 00000000..ba4293d0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_aggregate_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include + +namespace duckdb { + +class BoundAggregateExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_AGGREGATE; + +public: + BoundAggregateExpression(AggregateFunction function, vector> children, + unique_ptr filter, unique_ptr bind_info, + AggregateType aggr_type); + + //! The bound function expression + AggregateFunction function; + //! List of arguments to the function + vector> children; + //! The bound function data (if any) + unique_ptr bind_info; + //! The aggregate type (distinct or non-distinct) + AggregateType aggr_type; + + //! Filter for this aggregate + unique_ptr filter; + //! The order by expression for this aggregate - if any + unique_ptr order_bys; + +public: + bool IsDistinct() const { + return aggr_type == AggregateType::DISTINCT; + } + + bool IsAggregate() const override { + return true; + } + bool IsFoldable() const override { + return false; + } + bool PropagatesNullValues() const override; + + string ToString() const override; + + hash_t Hash() const override; + bool Equals(const BaseExpression &other) const override; + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_between_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_between_expression.hpp new file mode 100644 index 00000000..21c301ef --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_between_expression.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_between_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class BoundBetweenExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_BETWEEN; + +public: + BoundBetweenExpression(unique_ptr input, unique_ptr lower, unique_ptr upper, + bool lower_inclusive, bool upper_inclusive); + + unique_ptr input; + unique_ptr lower; + unique_ptr upper; + bool lower_inclusive; + bool upper_inclusive; + +public: + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + ExpressionType LowerComparisonType() { + return lower_inclusive ? ExpressionType::COMPARE_GREATERTHANOREQUALTO : ExpressionType::COMPARE_GREATERTHAN; + } + ExpressionType UpperComparisonType() { + return upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO : ExpressionType::COMPARE_LESSTHAN; + } + +private: + BoundBetweenExpression(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_case_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_case_expression.hpp new file mode 100644 index 00000000..c3a5abe5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_case_expression.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_case_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +struct BoundCaseCheck { + unique_ptr when_expr; + unique_ptr then_expr; + + void Serialize(Serializer &serializer) const; + static BoundCaseCheck Deserialize(Deserializer &deserializer); +}; + +class BoundCaseExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_CASE; + +public: + BoundCaseExpression(LogicalType type); + BoundCaseExpression(unique_ptr when_expr, unique_ptr then_expr, + unique_ptr else_expr); + + vector case_checks; + unique_ptr else_expr; + +public: + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_cast_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_cast_expression.hpp new file mode 100644 index 00000000..595f6267 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_cast_expression.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_cast_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" +#include "duckdb/function/cast/default_casts.hpp" + +namespace duckdb { + +class BoundCastExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_CAST; + +public: + BoundCastExpression(unique_ptr child, LogicalType target_type, BoundCastInfo bound_cast, + bool try_cast = false); + + //! The child type + unique_ptr child; + //! Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of throwing an error. + bool try_cast; + //! The bound cast info + BoundCastInfo bound_cast; + +public: + LogicalType source_type() { + D_ASSERT(child->return_type.IsValid()); + return child->return_type; + } + + //! Cast an expression to the specified SQL type, using only the built-in SQL casts + static unique_ptr AddDefaultCastToType(unique_ptr expr, const LogicalType &target_type, + bool try_cast = false); + //! Cast an expression to the specified SQL type if required + DUCKDB_API static unique_ptr AddCastToType(ClientContext &context, unique_ptr expr, + const LogicalType &target_type, bool try_cast = false); + //! Returns true if a cast is invertible (i.e. CAST(s -> t -> s) = s for all values of s). This is not true for e.g. + //! boolean casts, because that can be e.g. -1 -> TRUE -> 1. This is necessary to prevent some optimizer bugs. + static bool CastIsInvertible(const LogicalType &source_type, const LogicalType &target_type); + + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + BoundCastExpression(ClientContext &context, unique_ptr child, LogicalType target_type); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_columnref_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_columnref_expression.hpp new file mode 100644 index 00000000..4d356c02 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_columnref_expression.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_columnref_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! A BoundColumnRef expression represents a ColumnRef expression that was bound to an actual table and column index. It +//! is not yet executable, however. The ColumnBindingResolver transforms the BoundColumnRefExpressions into +//! BoundExpressions, which refer to indexes into the physical chunks that pass through the executor. +class BoundColumnRefExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_COLUMN_REF; + +public: + BoundColumnRefExpression(LogicalType type, ColumnBinding binding, idx_t depth = 0); + BoundColumnRefExpression(string alias, LogicalType type, ColumnBinding binding, idx_t depth = 0); + + //! Column index set by the binder, used to generate the final BoundExpression + ColumnBinding binding; + //! The subquery depth (i.e. depth 0 = current query, depth 1 = parent query, depth 2 = parent of parent, etc...). + //! This is only non-zero for correlated expressions inside subqueries. + idx_t depth; + +public: + bool IsScalar() const override { + return false; + } + bool IsFoldable() const override { + return false; + } + + string ToString() const override; + string GetName() const override; + + bool Equals(const BaseExpression &other) const override; + hash_t Hash() const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_comparison_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_comparison_expression.hpp new file mode 100644 index 00000000..9e34e3cf --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_comparison_expression.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_comparison_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class BoundComparisonExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_COMPARISON; + +public: + BoundComparisonExpression(ExpressionType type, unique_ptr left, unique_ptr right); + + unique_ptr left; + unique_ptr right; + +public: + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + static LogicalType BindComparison(LogicalType left_type, LogicalType right_type); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_conjunction_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_conjunction_expression.hpp new file mode 100644 index 00000000..8e66963b --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_conjunction_expression.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_conjunction_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class BoundConjunctionExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_CONJUNCTION; + +public: + explicit BoundConjunctionExpression(ExpressionType type); + BoundConjunctionExpression(ExpressionType type, unique_ptr left, unique_ptr right); + + vector> children; + +public: + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + bool PropagatesNullValues() const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_constant_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_constant_expression.hpp new file mode 100644 index 00000000..f11056aa --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_constant_expression.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_constant_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/value.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class BoundConstantExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_CONSTANT; + +public: + explicit BoundConstantExpression(Value value); + + Value value; + +public: + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + hash_t Hash() const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_default_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_default_expression.hpp new file mode 100644 index 00000000..bac032ec --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_default_expression.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_default_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class BoundDefaultExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_DEFAULT; + +public: + explicit BoundDefaultExpression(LogicalType type = LogicalType()) + : Expression(ExpressionType::VALUE_DEFAULT, ExpressionClass::BOUND_DEFAULT, type) { + } + +public: + bool IsScalar() const override { + return false; + } + bool IsFoldable() const override { + return false; + } + + string ToString() const override { + return "DEFAULT"; + } + + unique_ptr Copy() override { + return make_uniq(return_type); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_function_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_function_expression.hpp new file mode 100644 index 00000000..27201cb5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_function_expression.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_function_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { +class ScalarFunctionCatalogEntry; + +//! Represents a function call that has been bound to a base function +class BoundFunctionExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_FUNCTION; + +public: + BoundFunctionExpression(LogicalType return_type, ScalarFunction bound_function, + vector> arguments, unique_ptr bind_info, + bool is_operator = false); + + //! The bound function expression + ScalarFunction function; + //! List of child-expressions of the function + vector> children; + //! The bound function data (if any) + unique_ptr bind_info; + //! Whether or not the function is an operator, only used for rendering + bool is_operator; + +public: + bool HasSideEffects() const override; + bool IsFoldable() const override; + string ToString() const override; + bool PropagatesNullValues() const override; + hash_t Hash() const override; + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + void Verify() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_lambda_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_lambda_expression.hpp new file mode 100644 index 00000000..680650c2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_lambda_expression.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_lambda_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" +#include "duckdb/parser/expression/lambda_expression.hpp" + +namespace duckdb { + +class BoundLambdaExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_LAMBDA; + +public: + BoundLambdaExpression(ExpressionType type_p, LogicalType return_type_p, unique_ptr lambda_expr_p, + idx_t parameter_count_p); + + unique_ptr lambda_expr; + vector> captures; + idx_t parameter_count; + +public: + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_lambdaref_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_lambdaref_expression.hpp new file mode 100644 index 00000000..6ede2532 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_lambdaref_expression.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_lambdaref_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! A BoundLambdaRef expression represents a LambdaRef expression that was bound to an lambda parameter +//! in the lambda bindings vector. When capturing lambdas the BoundLambdaRef becomes a +//! BoundReferenceExpresssion, indexing the corresponding lambda parameter in the lambda bindings vector, +//! which refers to the physical chunk of the lambda parameter during execution. +class BoundLambdaRefExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_LAMBDA_REF; + +public: + BoundLambdaRefExpression(LogicalType type, ColumnBinding binding, idx_t lambda_index, idx_t depth = 0); + BoundLambdaRefExpression(string alias, LogicalType type, ColumnBinding binding, idx_t lambda_index, + idx_t depth = 0); + //! Column index set by the binder, used to generate the final BoundExpression + ColumnBinding binding; + //! The index of the lambda parameter in the lambda bindings vector + idx_t lambda_index; + //! The subquery depth (i.e. depth 0 = current query, depth 1 = parent query, depth 2 = parent of parent, etc...). + //! This is only non-zero for correlated expressions inside subqueries. + idx_t depth; + +public: + bool IsScalar() const override { + return false; + } + bool IsFoldable() const override { + return false; + } + + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + hash_t Hash() const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_operator_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_operator_expression.hpp new file mode 100644 index 00000000..1c929654 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_operator_expression.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_operator_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class BoundOperatorExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_OPERATOR; + +public: + BoundOperatorExpression(ExpressionType type, LogicalType return_type); + + vector> children; + +public: + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_data.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_data.hpp new file mode 100644 index 00000000..b1cac717 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_data.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_parameter_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +struct BoundParameterData { +public: + BoundParameterData() { + } + explicit BoundParameterData(Value val) : value(std::move(val)), return_type(value.type()) { + } + +private: + Value value; + +public: + LogicalType return_type; + +public: + void SetValue(Value val) { + value = std::move(val); + } + + const Value &GetValue() const { + return value; + } + + void Serialize(Serializer &serializer) const; + static shared_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_expression.hpp new file mode 100644 index 00000000..77bfa03d --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_expression.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_parameter_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/bound_parameter_map.hpp" + +namespace duckdb { + +class BoundParameterExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_PARAMETER; + +public: + explicit BoundParameterExpression(const string &identifier); + + string identifier; + shared_ptr parameter_data; + +public: + //! Invalidate a bound parameter expression - forcing a rebind on any subsequent filters + DUCKDB_API static void Invalidate(Expression &expr); + //! Invalidate all parameters within an expression + DUCKDB_API static void InvalidateRecursive(Expression &expr); + + bool IsScalar() const override; + bool HasParameter() const override; + bool IsFoldable() const override; + + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + hash_t Hash() const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + BoundParameterExpression(bound_parameter_map_t &global_parameter_set, string identifier, LogicalType return_type, + shared_ptr parameter_data); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_reference_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_reference_expression.hpp new file mode 100644 index 00000000..e8bc7b4a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_reference_expression.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_reference_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! A BoundReferenceExpression represents a physical index into a DataChunk +class BoundReferenceExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_REF; + +public: + BoundReferenceExpression(string alias, LogicalType type, idx_t index); + BoundReferenceExpression(LogicalType type, storage_t index); + + //! Index used to access data in the chunks + storage_t index; + +public: + bool IsScalar() const override { + return false; + } + bool IsFoldable() const override { + return false; + } + + string ToString() const override; + + hash_t Hash() const override; + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp new file mode 100644 index 00000000..73cfd2fa --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_subquery_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/subquery_type.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class BoundSubqueryExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_SUBQUERY; + +public: + explicit BoundSubqueryExpression(LogicalType return_type); + + bool IsCorrelated() { + return binder->correlated_columns.size() > 0; + } + + //! The binder used to bind the subquery node + shared_ptr binder; + //! The bound subquery node + unique_ptr subquery; + //! The subquery type + SubqueryType subquery_type; + //! the child expression to compare with (in case of IN, ANY, ALL operators) + unique_ptr child; + //! The comparison type of the child expression with the subquery (in case of ANY, ALL operators) + ExpressionType comparison_type; + //! The LogicalType of the subquery result. Only used for ANY expressions. + LogicalType child_type; + //! The target LogicalType of the subquery result (i.e. to which type it should be casted, if child_type <> + //! child_target). Only used for ANY expressions. + LogicalType child_target; + +public: + bool HasSubquery() const override { + return true; + } + bool IsScalar() const override { + return false; + } + bool IsFoldable() const override { + return false; + } + + string ToString() const override; + + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + bool PropagatesNullValues() const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_unnest_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_unnest_expression.hpp new file mode 100644 index 00000000..6db10d92 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_unnest_expression.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_unnest_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! Represents a function call that has been bound to a base function +class BoundUnnestExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_UNNEST; + +public: + explicit BoundUnnestExpression(LogicalType return_type); + + unique_ptr child; + +public: + bool IsFoldable() const override; + string ToString() const override; + + hash_t Hash() const override; + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_window_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_window_expression.hpp new file mode 100644 index 00000000..a538344a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_window_expression.hpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression/bound_window_expression.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { +class AggregateFunction; + +class BoundWindowExpression : public Expression { +public: + static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_WINDOW; + +public: + BoundWindowExpression(ExpressionType type, LogicalType return_type, unique_ptr aggregate, + unique_ptr bind_info); + + //! The bound aggregate function + unique_ptr aggregate; + //! The bound function info + unique_ptr bind_info; + //! The child expressions of the main window function + vector> children; + //! The set of expressions to partition by + vector> partitions; + //! Statistics belonging to the partitions expressions + vector> partitions_stats; + //! The set of ordering clauses + vector orders; + //! Expression representing a filter, only used for aggregates + unique_ptr filter_expr; + //! True to ignore NULL values + bool ignore_nulls; + //! The window boundaries + WindowBoundary start = WindowBoundary::INVALID; + WindowBoundary end = WindowBoundary::INVALID; + + unique_ptr start_expr; + unique_ptr end_expr; + //! Offset and default expressions for WINDOW_LEAD and WINDOW_LAG functions + unique_ptr offset_expr; + unique_ptr default_expr; + +public: + bool IsWindow() const override { + return true; + } + bool IsFoldable() const override { + return false; + } + + string ToString() const override; + + bool KeysAreCompatible(const BoundWindowExpression &other) const; + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/list.hpp b/src/duckdb/src/include/duckdb/planner/expression/list.hpp new file mode 100644 index 00000000..0a0e9efe --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression/list.hpp @@ -0,0 +1,18 @@ +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_default_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_lambda_expression.hpp" +#include "duckdb/planner/expression/bound_lambdaref_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp new file mode 100644 index 00000000..11e5a882 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp @@ -0,0 +1,167 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/stack_checker.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/tokens.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +class Binder; +class ClientContext; +class QueryNode; + +class ScalarFunctionCatalogEntry; +class AggregateFunctionCatalogEntry; +class ScalarMacroCatalogEntry; +class CatalogEntry; +class SimpleFunction; + +struct DummyBinding; + +struct BoundColumnReferenceInfo { + string name; + idx_t query_location; +}; + +struct BindResult { + BindResult() { + } + explicit BindResult(string error) : error(error) { + } + explicit BindResult(unique_ptr expr) : expression(std::move(expr)) { + } + + bool HasError() { + return !error.empty(); + } + + unique_ptr expression; + string error; +}; + +class ExpressionBinder { + friend class StackChecker; + +public: + ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder = false); + virtual ~ExpressionBinder(); + + //! The target type that should result from the binder. If the result is not of this type, a cast to this type will + //! be added. Defaults to INVALID. + LogicalType target_type; + + optional_ptr macro_binding; + optional_ptr> lambda_bindings; + +public: + unique_ptr Bind(unique_ptr &expr, optional_ptr result_type = nullptr, + bool root_expression = true); + + //! Returns whether or not any columns have been bound by the expression binder + bool HasBoundColumns() { + return !bound_columns.empty(); + } + const vector &GetBoundColumns() { + return bound_columns; + } + + string Bind(unique_ptr &expr, idx_t depth, bool root_expression = false); + + unique_ptr CreateStructExtract(unique_ptr base, string field_name); + unique_ptr CreateStructPack(ColumnRefExpression &colref); + BindResult BindQualifiedColumnName(ColumnRefExpression &colref, const string &table_name); + + unique_ptr QualifyColumnName(const string &column_name, string &error_message); + unique_ptr QualifyColumnName(ColumnRefExpression &colref, string &error_message); + + // Bind table names to ColumnRefExpressions + void QualifyColumnNames(unique_ptr &expr); + static void QualifyColumnNames(Binder &binder, unique_ptr &expr); + + static unique_ptr PushCollation(ClientContext &context, unique_ptr source, + const string &collation, bool equality_only = false); + static void TestCollation(ClientContext &context, const string &collation); + + bool BindCorrelatedColumns(unique_ptr &expr); + + void BindChild(unique_ptr &expr, idx_t depth, string &error); + static void ExtractCorrelatedExpressions(Binder &binder, Expression &expr); + + static bool ContainsNullType(const LogicalType &type); + static LogicalType ExchangeNullType(const LogicalType &type); + static bool ContainsType(const LogicalType &type, LogicalTypeId target); + static LogicalType ExchangeType(const LogicalType &type, LogicalTypeId target, LogicalType new_type); + + virtual bool QualifyColumnAlias(const ColumnRefExpression &colref); + + //! Bind the given expresion. Unlike Bind(), this does *not* mute the given ParsedExpression. + //! Exposed to be used from sub-binders that aren't subclasses of ExpressionBinder. + virtual BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false); + + void ReplaceMacroParametersRecursive(unique_ptr &expr); + +private: + //! Maximum stack depth + static constexpr const idx_t MAXIMUM_STACK_DEPTH = 128; + //! Current stack depth + idx_t stack_depth = DConstants::INVALID_INDEX; + + void InitializeStackCheck(); + StackChecker StackCheck(const ParsedExpression &expr, idx_t extra_stack = 1); + +protected: + BindResult BindExpression(BetweenExpression &expr, idx_t depth); + BindResult BindExpression(CaseExpression &expr, idx_t depth); + BindResult BindExpression(CollateExpression &expr, idx_t depth); + BindResult BindExpression(CastExpression &expr, idx_t depth); + BindResult BindExpression(ColumnRefExpression &expr, idx_t depth); + BindResult BindExpression(ComparisonExpression &expr, idx_t depth); + BindResult BindExpression(ConjunctionExpression &expr, idx_t depth); + BindResult BindExpression(ConstantExpression &expr, idx_t depth); + BindResult BindExpression(FunctionExpression &expr, idx_t depth, unique_ptr &expr_ptr); + BindResult BindExpression(LambdaExpression &expr, idx_t depth, const bool is_lambda, + const LogicalType &list_child_type); + BindResult BindExpression(OperatorExpression &expr, idx_t depth); + BindResult BindExpression(ParameterExpression &expr, idx_t depth); + BindResult BindExpression(SubqueryExpression &expr, idx_t depth); + BindResult BindPositionalReference(unique_ptr &expr, idx_t depth, bool root_expression); + + void TransformCapturedLambdaColumn(unique_ptr &original, unique_ptr &replacement, + vector> &captures, LogicalType &list_child_type); + void CaptureLambdaColumns(vector> &captures, LogicalType &list_child_type, + unique_ptr &expr); + + static unique_ptr GetSQLValueFunction(const string &column_name); + +protected: + virtual BindResult BindGroupingFunction(OperatorExpression &op, idx_t depth); + virtual BindResult BindFunction(FunctionExpression &expr, ScalarFunctionCatalogEntry &function, idx_t depth); + virtual BindResult BindLambdaFunction(FunctionExpression &expr, ScalarFunctionCatalogEntry &function, idx_t depth); + virtual BindResult BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, idx_t depth); + virtual BindResult BindUnnest(FunctionExpression &expr, idx_t depth, bool root_expression); + virtual BindResult BindMacro(FunctionExpression &expr, ScalarMacroCatalogEntry ¯o, idx_t depth, + unique_ptr &expr_ptr); + + virtual string UnsupportedAggregateMessage(); + virtual string UnsupportedUnnestMessage(); + + Binder &binder; + ClientContext &context; + optional_ptr stored_binder; + vector bound_columns; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/aggregate_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/aggregate_binder.hpp new file mode 100644 index 00000000..84075d9f --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/aggregate_binder.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/aggregate_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! The AggregateBinder is responsible for binding aggregate statements extracted from a SELECT clause (by the +//! SelectBinder) +class AggregateBinder : public ExpressionBinder { + friend class SelectBinder; + +public: + AggregateBinder(Binder &binder, ClientContext &context); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/alter_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/alter_binder.hpp new file mode 100644 index 00000000..2633a092 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/alter_binder.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/alter_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { +class TableCatalogEntry; + +//! The ALTER binder is responsible for binding an expression within alter statements +class AlterBinder : public ExpressionBinder { +public: + AlterBinder(Binder &binder, ClientContext &context, TableCatalogEntry &table, vector &bound_columns, + LogicalType target_type); + + TableCatalogEntry &table; + vector &bound_columns; + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + BindResult BindColumn(ColumnRefExpression &expr); + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/base_select_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/base_select_binder.hpp new file mode 100644 index 00000000..c05e1d8d --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/base_select_binder.hpp @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/base_select_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { +class BoundColumnRefExpression; +class WindowExpression; + +class BoundSelectNode; + +struct BoundGroupInformation { + parsed_expression_map_t map; + case_insensitive_map_t alias_map; + unordered_map collated_groups; +}; + +//! The BaseSelectBinder is the base binder of the SELECT, HAVING and QUALIFY binders. It can bind aggregates and window +//! functions. +class BaseSelectBinder : public ExpressionBinder { +public: + BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, + case_insensitive_map_t alias_map); + BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info); + + bool BoundAggregates() { + return bound_aggregate; + } + void ResetBindings() { + this->bound_aggregate = false; + this->bound_columns.clear(); + } + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + BindResult BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, idx_t depth) override; + + bool inside_window; + bool bound_aggregate = false; + + BoundSelectNode &node; + BoundGroupInformation &info; + case_insensitive_map_t alias_map; + +protected: + BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth); + BindResult BindGroupingFunction(OperatorExpression &op, idx_t depth) override; + BindResult BindWindow(WindowExpression &expr, idx_t depth); + + idx_t TryBindGroup(ParsedExpression &expr, idx_t depth); + BindResult BindGroup(ParsedExpression &expr, idx_t depth, idx_t group_index); + + bool QualifyColumnAlias(const ColumnRefExpression &colref) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/check_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/check_binder.hpp new file mode 100644 index 00000000..51bd9c29 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/check_binder.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/check_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/common/index_map.hpp" +#include "duckdb/parser/column_list.hpp" + +namespace duckdb { +//! The CHECK binder is responsible for binding an expression within a CHECK constraint +class CheckBinder : public ExpressionBinder { +public: + CheckBinder(Binder &binder, ClientContext &context, string table, const ColumnList &columns, + physical_index_set_t &bound_columns); + + string table; + const ColumnList &columns; + physical_index_set_t &bound_columns; + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + BindResult BindCheckColumn(ColumnRefExpression &expr); + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp new file mode 100644 index 00000000..ce95ac73 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/column_alias_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +class BoundSelectNode; +class ColumnRefExpression; + +//! A helper binder for WhereBinder and HavingBinder which support alias as a columnref. +class ColumnAliasBinder { +public: + ColumnAliasBinder(BoundSelectNode &node, const case_insensitive_map_t &alias_map); + + BindResult BindAlias(ExpressionBinder &enclosing_binder, ColumnRefExpression &expr, idx_t depth, + bool root_expression); + +private: + BoundSelectNode &node; + const case_insensitive_map_t &alias_map; + unordered_set visited_select_indexes; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/constant_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/constant_binder.hpp new file mode 100644 index 00000000..026b3b60 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/constant_binder.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/constant_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! The Constant binder can bind ONLY constant foldable expressions (i.e. no subqueries, column refs, etc) +class ConstantBinder : public ExpressionBinder { +public: + ConstantBinder(Binder &binder, ClientContext &context, string clause); + + //! The location where this binder is used, used for error messages + string clause; + +protected: + BindResult BindExpression(unique_ptr &expr, idx_t depth, bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp new file mode 100644 index 00000000..10e32b6c --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/group_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { +class ConstantExpression; +class ColumnRefExpression; + +//! The GROUP binder is responsible for binding expressions in the GROUP BY clause +class GroupBinder : public ExpressionBinder { +public: + GroupBinder(Binder &binder, ClientContext &context, SelectNode &node, idx_t group_index, + case_insensitive_map_t &alias_map, case_insensitive_map_t &group_alias_map); + + //! The unbound root expression + unique_ptr unbound_expression; + //! The group index currently being bound + idx_t bind_index; + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) override; + + string UnsupportedAggregateMessage() override; + + BindResult BindSelectRef(idx_t entry); + BindResult BindColumnRef(ColumnRefExpression &expr); + BindResult BindConstant(ConstantExpression &expr); + + SelectNode &node; + case_insensitive_map_t &alias_map; + case_insensitive_map_t &group_alias_map; + unordered_set used_aliases; + + idx_t group_index; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/having_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/having_binder.hpp new file mode 100644 index 00000000..113bc68a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/having_binder.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/having_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder/base_select_binder.hpp" +#include "duckdb/planner/expression_binder/column_alias_binder.hpp" +#include "duckdb/common/enums/aggregate_handling.hpp" + +namespace duckdb { + +//! The HAVING binder is responsible for binding an expression within the HAVING clause of a SQL statement +class HavingBinder : public BaseSelectBinder { +public: + HavingBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, + case_insensitive_map_t &alias_map, AggregateHandling aggregate_handling); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + +private: + BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression); + + ColumnAliasBinder column_alias_binder; + AggregateHandling aggregate_handling; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/index_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/index_binder.hpp new file mode 100644 index 00000000..fce742ff --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/index_binder.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/index_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { +class BoundColumnRefExpression; + +//! The IndexBinder is responsible for binding an expression within an index statement +class IndexBinder : public ExpressionBinder { +public: + IndexBinder(Binder &binder, ClientContext &context, optional_ptr table = nullptr, + optional_ptr info = nullptr); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + string UnsupportedAggregateMessage() override; + +private: + // only for WAL replay + optional_ptr table; + optional_ptr info; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/insert_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/insert_binder.hpp new file mode 100644 index 00000000..555445a2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/insert_binder.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/insert_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! The INSERT binder is responsible for binding expressions within the VALUES of an INSERT statement +class InsertBinder : public ExpressionBinder { +public: + InsertBinder(Binder &binder, ClientContext &context); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp new file mode 100644 index 00000000..eb68a0cd --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/lateral_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +class ColumnAliasBinder; + +//! The LATERAL binder is responsible for binding an expression within a LATERAL join +class LateralBinder : public ExpressionBinder { +public: + LateralBinder(Binder &binder, ClientContext &context); + + bool HasCorrelatedColumns() const { + return !correlated_columns.empty(); + } + + static void ReduceExpressionDepth(LogicalOperator &op, const vector &info); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; + +private: + BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression); + void ExtractCorrelatedColumns(Expression &expr); + +private: + vector correlated_columns; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/order_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/order_binder.hpp new file mode 100644 index 00000000..830952cc --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/order_binder.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/order_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +namespace duckdb { +class Binder; +class Expression; +class SelectNode; + +//! The ORDER binder is responsible for binding an expression within the ORDER BY clause of a SQL statement +class OrderBinder { +public: + OrderBinder(vector binders, idx_t projection_index, case_insensitive_map_t &alias_map, + parsed_expression_map_t &projection_map, idx_t max_count); + OrderBinder(vector binders, idx_t projection_index, SelectNode &node, + case_insensitive_map_t &alias_map, parsed_expression_map_t &projection_map); + +public: + unique_ptr Bind(unique_ptr expr); + + idx_t MaxCount() const { + return max_count; + } + bool HasExtraList() const { + return extra_list; + } + const vector &GetBinders() const { + return binders; + } + + unique_ptr CreateExtraReference(unique_ptr expr); + +private: + unique_ptr CreateProjectionReference(ParsedExpression &expr, idx_t index); + unique_ptr BindConstant(ParsedExpression &expr, const Value &val); + +private: + vector binders; + idx_t projection_index; + idx_t max_count; + vector> *extra_list; + case_insensitive_map_t &alias_map; + parsed_expression_map_t &projection_map; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/qualify_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/qualify_binder.hpp new file mode 100644 index 00000000..ef5d6e59 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/qualify_binder.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/qualify_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder/base_select_binder.hpp" +#include "duckdb/planner/expression_binder/column_alias_binder.hpp" + +namespace duckdb { + +//! The QUALIFY binder is responsible for binding an expression within the QUALIFY clause of a SQL statement +class QualifyBinder : public BaseSelectBinder { +public: + QualifyBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, + case_insensitive_map_t &alias_map); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + +private: + BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression); + + ColumnAliasBinder column_alias_binder; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/relation_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/relation_binder.hpp new file mode 100644 index 00000000..ebf44990 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/relation_binder.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/relation_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! The relation binder is a binder used to bind expressions in the relation API +class RelationBinder : public ExpressionBinder { +public: + RelationBinder(Binder &binder, ClientContext &context, string op); + + string op; + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/returning_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/returning_binder.hpp new file mode 100644 index 00000000..bf8e838d --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/returning_binder.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/returning_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! The RETURNING binder is responsible for binding expressions within the RETURNING statement +class ReturningBinder : public ExpressionBinder { +public: + ReturningBinder(Binder &binder, ClientContext &context); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp new file mode 100644 index 00000000..9c2ca8a7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/select_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder/base_select_binder.hpp" + +namespace duckdb { + +//! The SELECT binder is responsible for binding an expression within the SELECT clause of a SQL statement +class SelectBinder : public BaseSelectBinder { +public: + SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, + case_insensitive_map_t alias_map); + SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info); + + bool HasExpandedExpressions() { + return !expanded_expressions.empty(); + } + vector> &ExpandedExpressions() { + return expanded_expressions; + } + +protected: + BindResult BindUnnest(FunctionExpression &function, idx_t depth, bool root_expression) override; + + idx_t unnest_level = 0; + vector> expanded_expressions; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/table_function_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/table_function_binder.hpp new file mode 100644 index 00000000..7f844858 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/table_function_binder.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/table_function_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! The Table function binder can bind standard table function parameters (i.e. non-table-in-out functions) +class TableFunctionBinder : public ExpressionBinder { +public: + TableFunctionBinder(Binder &binder, ClientContext &context); + +protected: + BindResult BindColumnReference(ColumnRefExpression &expr, idx_t depth, bool root_expression); + BindResult BindExpression(unique_ptr &expr, idx_t depth, bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/update_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/update_binder.hpp new file mode 100644 index 00000000..383800bb --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/update_binder.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/update_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! The UPDATE binder is responsible for binding an expression within an UPDATE statement +class UpdateBinder : public ExpressionBinder { +public: + UpdateBinder(Binder &binder, ClientContext &context); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp new file mode 100644 index 00000000..e04651e1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/where_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +class ColumnAliasBinder; + +//! The WHERE binder is responsible for binding an expression within the WHERE clause of a SQL statement +class WhereBinder : public ExpressionBinder { +public: + WhereBinder(Binder &binder, ClientContext &context, optional_ptr column_alias_binder = nullptr); + +protected: + BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression = false) override; + + string UnsupportedAggregateMessage() override; + +private: + BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression); + + optional_ptr column_alias_binder; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp new file mode 100644 index 00000000..bf3b204a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_iterator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/planner/expression.hpp" + +#include + +namespace duckdb { +class BoundQueryNode; +class BoundTableRef; + +class ExpressionIterator { +public: + static void EnumerateChildren(const Expression &expression, + const std::function &callback); + static void EnumerateChildren(Expression &expression, const std::function &callback); + static void EnumerateChildren(Expression &expression, + const std::function &child)> &callback); + + static void EnumerateExpression(unique_ptr &expr, + const std::function &callback); + + static void EnumerateTableRefChildren(BoundTableRef &ref, const std::function &callback); + static void EnumerateQueryNodeChildren(BoundQueryNode &node, + const std::function &callback); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/extension_callback.hpp b/src/duckdb/src/include/duckdb/planner/extension_callback.hpp new file mode 100644 index 00000000..3665df80 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/extension_callback.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/extension_callback.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { +class DatabaseInstance; + +class ExtensionCallback { +public: + virtual ~ExtensionCallback() { + } + + //! Called after an extension is finished loading + virtual void OnExtensionLoaded(DatabaseInstance &db, const string &name) { + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/conjunction_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/conjunction_filter.hpp new file mode 100644 index 00000000..470093af --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/filter/conjunction_filter.hpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/conjunction_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { +class ConjunctionFilter : public TableFilter { +public: + ConjunctionFilter(TableFilterType filter_type_p) : TableFilter(filter_type_p) { + } + + virtual ~ConjunctionFilter() { + } + + //! The filters of this conjunction + vector> child_filters; + +public: + virtual FilterPropagateResult CheckStatistics(BaseStatistics &stats) = 0; + virtual string ToString(const string &column_name) = 0; + + virtual bool Equals(const TableFilter &other) const { + return TableFilter::Equals(other); + } +}; + +class ConjunctionOrFilter : public ConjunctionFilter { +public: + static constexpr const TableFilterType TYPE = TableFilterType::CONJUNCTION_OR; + +public: + ConjunctionOrFilter(); + +public: + FilterPropagateResult CheckStatistics(BaseStatistics &stats) override; + string ToString(const string &column_name) override; + bool Equals(const TableFilter &other) const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +class ConjunctionAndFilter : public ConjunctionFilter { +public: + static constexpr const TableFilterType TYPE = TableFilterType::CONJUNCTION_AND; + +public: + ConjunctionAndFilter(); + +public: + FilterPropagateResult CheckStatistics(BaseStatistics &stats) override; + string ToString(const string &column_name) override; + bool Equals(const TableFilter &other) const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/constant_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/constant_filter.hpp new file mode 100644 index 00000000..b4ceed04 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/filter/constant_filter.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/constant_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/enums/expression_type.hpp" + +namespace duckdb { + +class ConstantFilter : public TableFilter { +public: + static constexpr const TableFilterType TYPE = TableFilterType::CONSTANT_COMPARISON; + +public: + ConstantFilter(ExpressionType comparison_type, Value constant); + + //! The comparison type (e.g. COMPARE_EQUAL, COMPARE_GREATERTHAN, COMPARE_LESSTHAN, ...) + ExpressionType comparison_type; + //! The constant value to filter on + Value constant; + +public: + FilterPropagateResult CheckStatistics(BaseStatistics &stats) override; + string ToString(const string &column_name) override; + bool Equals(const TableFilter &other) const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/null_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/null_filter.hpp new file mode 100644 index 00000000..6553f77e --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/filter/null_filter.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/null_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/table_filter.hpp" + +namespace duckdb { + +class IsNullFilter : public TableFilter { +public: + static constexpr const TableFilterType TYPE = TableFilterType::IS_NULL; + +public: + IsNullFilter(); + +public: + FilterPropagateResult CheckStatistics(BaseStatistics &stats) override; + string ToString(const string &column_name) override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +class IsNotNullFilter : public TableFilter { +public: + static constexpr const TableFilterType TYPE = TableFilterType::IS_NOT_NULL; + +public: + IsNotNullFilter(); + +public: + FilterPropagateResult CheckStatistics(BaseStatistics &stats) override; + string ToString(const string &column_name) override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/joinside.hpp b/src/duckdb/src/include/duckdb/planner/joinside.hpp new file mode 100644 index 00000000..aa655e3e --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/joinside.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/joinside.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! JoinCondition represents a left-right comparison join condition +struct JoinCondition { +public: + JoinCondition() { + } + + //! Turns the JoinCondition into an expression; note that this destroys the JoinCondition as the expression inherits + //! the left/right expressions + static unique_ptr CreateExpression(JoinCondition cond); + static unique_ptr CreateExpression(vector conditions); + + void Serialize(Serializer &serializer) const; + static JoinCondition Deserialize(Deserializer &deserializer); + +public: + unique_ptr left; + unique_ptr right; + ExpressionType comparison; +}; + +class JoinSide { +public: + enum JoinValue : uint8_t { NONE, LEFT, RIGHT, BOTH }; + + JoinSide() = default; + constexpr JoinSide(JoinValue val) : value(val) { // NOLINT: Allow implicit conversion from `join_value` + } + + bool operator==(JoinSide a) const { + return value == a.value; + } + bool operator!=(JoinSide a) const { + return value != a.value; + } + + static JoinSide CombineJoinSide(JoinSide left, JoinSide right); + static JoinSide GetJoinSide(idx_t table_binding, const unordered_set &left_bindings, + const unordered_set &right_bindings); + static JoinSide GetJoinSide(Expression &expression, const unordered_set &left_bindings, + const unordered_set &right_bindings); + static JoinSide GetJoinSide(const unordered_set &bindings, const unordered_set &left_bindings, + const unordered_set &right_bindings); + +private: + JoinValue value; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/logical_operator.hpp b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp new file mode 100644 index 00000000..ccb9cd12 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp @@ -0,0 +1,101 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/logical_operator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/logical_operator_type.hpp" +#include "duckdb/optimizer/join_order/estimated_properties.hpp" +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" + +#include +#include + +namespace duckdb { + +//! LogicalOperator is the base class of the logical operators present in the +//! logical query tree +class LogicalOperator { +public: + explicit LogicalOperator(LogicalOperatorType type); + LogicalOperator(LogicalOperatorType type, vector> expressions); + virtual ~LogicalOperator(); + + //! The type of the logical operator + LogicalOperatorType type; + //! The set of children of the operator + vector> children; + //! The set of expressions contained within the operator, if any + vector> expressions; + //! The types returned by this logical operator. Set by calling LogicalOperator::ResolveTypes. + vector types; + //! Estimated Cardinality + idx_t estimated_cardinality; + bool has_estimated_cardinality; + +public: + virtual vector GetColumnBindings(); + static vector GenerateColumnBindings(idx_t table_idx, idx_t column_count); + static vector MapTypes(const vector &types, const vector &projection_map); + static vector MapBindings(const vector &types, const vector &projection_map); + + //! Resolve the types of the logical operator and its children + void ResolveOperatorTypes(); + + virtual string GetName() const; + virtual string ParamsToString() const; + virtual string ToString() const; + DUCKDB_API void Print(); + //! Debug method: verify that the integrity of expressions & child nodes are maintained + virtual void Verify(ClientContext &context); + + void AddChild(unique_ptr child); + virtual idx_t EstimateCardinality(ClientContext &context); + + virtual void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + + virtual unique_ptr Copy(ClientContext &context) const; + + virtual bool RequireOptimizer() const { + return true; + } + + //! Allows LogicalOperators to opt out of serialization + virtual bool SupportSerialization() const { + return true; + }; + + //! Returns the set of table indexes of this operator + virtual vector GetTableIndex() const; + +protected: + //! Resolve types for this specific operator + virtual void ResolveTypes() = 0; + +public: + template + TARGET &Cast() { + if (TARGET::TYPE != LogicalOperatorType::LOGICAL_INVALID && type != TARGET::TYPE) { + throw InternalException("Failed to cast logical operator to type - logical operator type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (TARGET::TYPE != LogicalOperatorType::LOGICAL_INVALID && type != TARGET::TYPE) { + throw InternalException("Failed to cast logical operator to type - logical operator type mismatch"); + } + return reinterpret_cast(*this); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/logical_operator_visitor.hpp b/src/duckdb/src/include/duckdb/planner/logical_operator_visitor.hpp new file mode 100644 index 00000000..feb8ca7e --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/logical_operator_visitor.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/logical_operator_visitor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/planner/bound_tokens.hpp" +#include "duckdb/planner/logical_tokens.hpp" + +#include + +namespace duckdb { +//! The LogicalOperatorVisitor is an abstract base class that implements the +//! Visitor pattern on LogicalOperator. +class LogicalOperatorVisitor { +public: + virtual ~LogicalOperatorVisitor() {}; + + virtual void VisitOperator(LogicalOperator &op); + virtual void VisitExpression(unique_ptr *expression); + + static void EnumerateExpressions(LogicalOperator &op, + const std::function *child)> &callback); + +protected: + //! Automatically calls the Visit method for LogicalOperator children of the current operator. Can be overloaded to + //! change this behavior. + void VisitOperatorChildren(LogicalOperator &op); + //! Automatically calls the Visit method for Expression children of the current operator. Can be overloaded to + //! change this behavior. + void VisitOperatorExpressions(LogicalOperator &op); + + // The VisitExpressionChildren method is called at the end of every call to VisitExpression to recursively visit all + // expressions in an expression tree. It can be overloaded to prevent automatically visiting the entire tree. + virtual void VisitExpressionChildren(Expression &expression); + + virtual unique_ptr VisitReplace(BoundAggregateExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundBetweenExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundCaseExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundCastExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundComparisonExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundConjunctionExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundConstantExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundDefaultExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundFunctionExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundOperatorExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundReferenceExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundParameterExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundWindowExpression &expr, unique_ptr *expr_ptr); + virtual unique_ptr VisitReplace(BoundUnnestExpression &expr, unique_ptr *expr_ptr); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/logical_tokens.hpp b/src/duckdb/src/include/duckdb/planner/logical_tokens.hpp new file mode 100644 index 00000000..0d033ef6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/logical_tokens.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/logical_tokens.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +class LogicalOperator; + +class LogicalAggregate; +class LogicalAnyJoin; +class LogicalColumnDataGet; +class LogicalComparisonJoin; +class LogicalCopyToFile; +class LogicalCreate; +class LogicalCreateTable; +class LogicalCreateIndex; +class LogicalCreateTable; +class LogicalCrossProduct; +class LogicalCTERef; +class LogicalDelete; +class LogicalDelimGet; +class LogicalDistinct; +class LogicalDummyScan; +class LogicalEmptyResult; +class LogicalExecute; +class LogicalExplain; +class LogicalExport; +class LogicalExpressionGet; +class LogicalFilter; +class LogicalGet; +class LogicalInsert; +class LogicalJoin; +class LogicalLimit; +class LogicalOrder; +class LogicalPivot; +class LogicalPositionalJoin; +class LogicalPragma; +class LogicalPrepare; +class LogicalProjection; +class LogicalRecursiveCTE; +class LogicalMaterializedCTE; +class LogicalSetOperation; +class LogicalSample; +class LogicalShow; +class LogicalSimple; +class LogicalSet; +class LogicalReset; +class LogicalTopN; +class LogicalUnnest; +class LogicalUpdate; +class LogicalWindow; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/list.hpp b/src/duckdb/src/include/duckdb/planner/operator/list.hpp new file mode 100644 index 00000000..e782f8ed --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/list.hpp @@ -0,0 +1,44 @@ +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_any_join.hpp" +#include "duckdb/planner/operator/logical_column_data_get.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_copy_to_file.hpp" +#include "duckdb/planner/operator/logical_create.hpp" +#include "duckdb/planner/operator/logical_create_index.hpp" +#include "duckdb/planner/operator/logical_create_table.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" +#include "duckdb/planner/operator/logical_delete.hpp" +#include "duckdb/planner/operator/logical_delim_get.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_execute.hpp" +#include "duckdb/planner/operator/logical_explain.hpp" +#include "duckdb/planner/operator/logical_export.hpp" +#include "duckdb/planner/operator/logical_expression_get.hpp" +#include "duckdb/planner/operator/logical_extension_operator.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_insert.hpp" +#include "duckdb/planner/operator/logical_join.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" +#include "duckdb/planner/operator/logical_limit_percent.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/planner/operator/logical_order.hpp" +#include "duckdb/planner/operator/logical_pivot.hpp" +#include "duckdb/planner/operator/logical_positional_join.hpp" +#include "duckdb/planner/operator/logical_pragma.hpp" +#include "duckdb/planner/operator/logical_prepare.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_recursive_cte.hpp" +#include "duckdb/planner/operator/logical_reset.hpp" +#include "duckdb/planner/operator/logical_sample.hpp" +#include "duckdb/planner/operator/logical_set.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/operator/logical_show.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" +#include "duckdb/planner/operator/logical_top_n.hpp" +#include "duckdb/planner/operator/logical_unnest.hpp" +#include "duckdb/planner/operator/logical_update.hpp" +#include "duckdb/planner/operator/logical_window.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_aggregate.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_aggregate.hpp new file mode 100644 index 00000000..b126ba5f --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_aggregate.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_aggregate.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/parser/group_by_node.hpp" + +namespace duckdb { + +//! LogicalAggregate represents an aggregate operation with (optional) GROUP BY +//! operator. +class LogicalAggregate : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY; + +public: + LogicalAggregate(idx_t group_index, idx_t aggregate_index, vector> select_list); + + //! The table index for the groups of the LogicalAggregate + idx_t group_index; + //! The table index for the aggregates of the LogicalAggregate + idx_t aggregate_index; + //! The table index for the GROUPING function calls of the LogicalAggregate + idx_t groupings_index; + //! The set of groups (optional). + vector> groups; + //! The set of grouping sets (optional). + vector grouping_sets; + //! The list of grouping function calls (optional) + vector> grouping_functions; + //! Group statistics (optional) + vector> group_stats; + +public: + string ParamsToString() const override; + + vector GetColumnBindings() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + idx_t EstimateCardinality(ClientContext &context) override; + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_any_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_any_join.hpp new file mode 100644 index 00000000..d97015ab --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_any_join.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_any_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/planner/operator/logical_join.hpp" + +namespace duckdb { + +//! LogicalAnyJoin represents a join with an arbitrary expression as JoinCondition +class LogicalAnyJoin : public LogicalJoin { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_ANY_JOIN; + +public: + explicit LogicalAnyJoin(JoinType type); + + //! The JoinCondition on which this join is performed + unique_ptr condition; + +public: + string ParamsToString() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp new file mode 100644 index 00000000..55c09970 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_column_data_get.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalColumnDataGet represents a scan operation from a ColumnDataCollection +class LogicalColumnDataGet : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_CHUNK_GET; + +public: + LogicalColumnDataGet(idx_t table_index, vector types, unique_ptr collection); + + //! The table index in the current bind context + idx_t table_index; + //! The types of the chunk + vector chunk_types; + //! The chunk collection to scan + unique_ptr collection; + +public: + vector GetColumnBindings() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override { + // types are resolved in the constructor + this->types = chunk_types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_comparison_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_comparison_join.hpp new file mode 100644 index 00000000..abcfafb0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_comparison_join.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_comparison_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/enums/joinref_type.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/joinside.hpp" +#include "duckdb/planner/operator/logical_join.hpp" + +namespace duckdb { + +//! LogicalComparisonJoin represents a join that involves comparisons between the LHS and RHS +class LogicalComparisonJoin : public LogicalJoin { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INVALID; + +public: + explicit LogicalComparisonJoin(JoinType type, + LogicalOperatorType logical_type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN); + + //! The conditions of the join + vector conditions; + //! Used for duplicate-eliminated MARK joins + vector mark_types; + //! The set of columns that will be duplicate eliminated from the LHS and pushed into the RHS + vector> duplicate_eliminated_columns; + +public: + string ParamsToString() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +public: + static unique_ptr CreateJoin(ClientContext &context, JoinType type, JoinRefType ref_type, + unique_ptr left_child, + unique_ptr right_child, + unique_ptr condition); + static unique_ptr CreateJoin(ClientContext &context, JoinType type, JoinRefType ref_type, + unique_ptr left_child, + unique_ptr right_child, + vector conditions, + vector> arbitrary_expressions); + + static void ExtractJoinConditions(ClientContext &context, JoinType type, unique_ptr &left_child, + unique_ptr &right_child, unique_ptr condition, + vector &conditions, + vector> &arbitrary_expressions); + static void ExtractJoinConditions(ClientContext &context, JoinType type, unique_ptr &left_child, + unique_ptr &right_child, + vector> &expressions, vector &conditions, + vector> &arbitrary_expressions); + static void ExtractJoinConditions(ClientContext &context, JoinType type, unique_ptr &left_child, + unique_ptr &right_child, + const unordered_set &left_bindings, + const unordered_set &right_bindings, + vector> &expressions, vector &conditions, + vector> &arbitrary_expressions); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_copy_to_file.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_copy_to_file.hpp new file mode 100644 index 00000000..8bac2757 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_copy_to_file.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_copy_to_file.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/filename_pattern.hpp" +#include "duckdb/common/local_file_system.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalCopyToFile : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_COPY_TO_FILE; + +public: + LogicalCopyToFile(CopyFunction function, unique_ptr bind_data) + : LogicalOperator(LogicalOperatorType::LOGICAL_COPY_TO_FILE), function(function), + bind_data(std::move(bind_data)) { + } + CopyFunction function; + unique_ptr bind_data; + std::string file_path; + bool use_tmp_file; + FilenamePattern filename_pattern; + bool overwrite_or_ignore; + bool per_thread_output; + + bool partition_output; + vector partition_columns; + vector names; + vector expected_types; + +public: + idx_t EstimateCardinality(ClientContext &context) override; + //! Skips the serialization check in VerifyPlan + bool SupportSerialization() const override { + return false; + } + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + void ResolveTypes() override { + types.emplace_back(LogicalType::BIGINT); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_create.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_create.hpp new file mode 100644 index 00000000..02429563 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_create.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_create.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/parser/parsed_data/create_info.hpp" + +namespace duckdb { + +//! LogicalCreate represents a CREATE operator +class LogicalCreate : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INVALID; + +public: + LogicalCreate(LogicalOperatorType type, unique_ptr info, + optional_ptr schema = nullptr); + + optional_ptr schema; + unique_ptr info; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override; + +protected: + void ResolveTypes() override; + +private: + LogicalCreate(LogicalOperatorType type, ClientContext &context, unique_ptr info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_create_index.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_create_index.hpp new file mode 100644 index 00000000..e9925cb4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_create_index.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_create_index.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/function/table_function.hpp" + +namespace duckdb { + +class LogicalCreateIndex : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_CREATE_INDEX; + +public: + LogicalCreateIndex(unique_ptr info_p, vector> expressions_p, + TableCatalogEntry &table_p); + + // Info for index creation + unique_ptr info; + + //! The table to create the index for + TableCatalogEntry &table; + + //! Unbound expressions to be used in the optimizer + vector> unbound_expressions; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + void ResolveTypes() override; + +private: + LogicalCreateIndex(ClientContext &context, unique_ptr info, vector> expressions); + + TableCatalogEntry &BindTable(ClientContext &context, CreateIndexInfo &info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_create_table.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_create_table.hpp new file mode 100644 index 00000000..294d285b --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_create_table.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_create_table.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalCreateTable : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_CREATE_TABLE; + +public: + LogicalCreateTable(SchemaCatalogEntry &schema, unique_ptr info); + + //! Schema to insert to + SchemaCatalogEntry &schema; + //! Create Table information + unique_ptr info; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override; + +protected: + void ResolveTypes() override; + +private: + LogicalCreateTable(ClientContext &context, unique_ptr info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_cross_product.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_cross_product.hpp new file mode 100644 index 00000000..47d49af6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_cross_product.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_cross_product.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/operator/logical_unconditional_join.hpp" + +namespace duckdb { + +//! LogicalCrossProduct represents a cross product between two relations +class LogicalCrossProduct : public LogicalUnconditionalJoin { + LogicalCrossProduct() : LogicalUnconditionalJoin(LogicalOperatorType::LOGICAL_CROSS_PRODUCT) {}; + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_CROSS_PRODUCT; + +public: + LogicalCrossProduct(unique_ptr left, unique_ptr right); + +public: + static unique_ptr Create(unique_ptr left, unique_ptr right); + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_cteref.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_cteref.hpp new file mode 100644 index 00000000..da943b01 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_cteref.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_cteref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/chunk_collection.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/common/enums/cte_materialize.hpp" + +namespace duckdb { + +//! LogicalCTERef represents a reference to a recursive CTE +class LogicalCTERef : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_CTE_REF; + +public: + LogicalCTERef(idx_t table_index, idx_t cte_index, vector types, vector colnames, + CTEMaterialize materialized_cte) + : LogicalOperator(LogicalOperatorType::LOGICAL_CTE_REF), table_index(table_index), cte_index(cte_index), + materialized_cte(materialized_cte) { + D_ASSERT(types.size() > 0); + chunk_types = types; + bound_columns = colnames; + } + + vector bound_columns; + //! The table index in the current bind context + idx_t table_index; + //! CTE index + idx_t cte_index; + //! The types of the chunk + vector chunk_types; + //! Does this operator read a materialized CTE? + CTEMaterialize materialized_cte; + +public: + vector GetColumnBindings() override { + return GenerateColumnBindings(table_index, chunk_types.size()); + } + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override { + // types are resolved in the constructor + this->types = chunk_types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_delete.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_delete.hpp new file mode 100644 index 00000000..005d955b --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_delete.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_delete.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { +class TableCatalogEntry; + +class LogicalDelete : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_DELETE; + +public: + explicit LogicalDelete(TableCatalogEntry &table, idx_t table_index); + + TableCatalogEntry &table; + idx_t table_index; + bool return_chunk; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override; + vector GetTableIndex() const override; + string GetName() const override; + +protected: + vector GetColumnBindings() override; + void ResolveTypes() override; + +private: + LogicalDelete(ClientContext &context, const unique_ptr &table_info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_delim_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_delim_get.hpp new file mode 100644 index 00000000..63421ee4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_delim_get.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_delim_get.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalDelimGet represents a duplicate eliminated scan belonging to a DelimJoin +class LogicalDelimGet : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_DELIM_GET; + +public: + LogicalDelimGet(idx_t table_index, vector types) + : LogicalOperator(LogicalOperatorType::LOGICAL_DELIM_GET), table_index(table_index) { + D_ASSERT(types.size() > 0); + chunk_types = types; + } + + //! The table index in the current bind context + idx_t table_index; + //! The types of the chunk + vector chunk_types; + +public: + vector GetColumnBindings() override { + return GenerateColumnBindings(table_index, chunk_types.size()); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override { + // types are resolved in the constructor + this->types = chunk_types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp new file mode 100644 index 00000000..769aa180 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_dependent_join.hpp +// +// logical_dependent_join represents a logical operator for lateral joins that +// is planned but not yet flattened +// +// This construct only exists during planning and should not exist in the plan +// once flattening is complete. Although the same information can be kept in the +// join itself, creating a new construct makes the code cleaner and easier to +// understand. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" + +namespace duckdb { + +class LogicalDependentJoin : public LogicalComparisonJoin { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_DEPENDENT_JOIN; + +public: + explicit LogicalDependentJoin(unique_ptr left, unique_ptr right, + vector correlated_columns, JoinType type, + unique_ptr condition); + + //! The conditions of the join + unique_ptr join_condition; + //! The list of columns that have correlations with the right + vector correlated_columns; + +public: + static unique_ptr Create(unique_ptr left, unique_ptr right, + vector correlated_columns, JoinType type, + unique_ptr condition); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_distinct.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_distinct.hpp new file mode 100644 index 00000000..9ced2e4e --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_distinct.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_distinct.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" + +namespace duckdb { + +//! LogicalDistinct filters duplicate entries from its child operator +class LogicalDistinct : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_DISTINCT; + +public: + explicit LogicalDistinct(DistinctType distinct_type); + explicit LogicalDistinct(vector> targets, DistinctType distinct_type); + + //! Whether or not this is a DISTINCT or DISTINCT ON + DistinctType distinct_type; + //! The set of distinct targets + vector> distinct_targets; + //! The order by modifier (optional, only for distinct on) + unique_ptr order_by; + +public: + string ParamsToString() const override; + + vector GetColumnBindings() override { + return children[0]->GetColumnBindings(); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_dummy_scan.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_dummy_scan.hpp new file mode 100644 index 00000000..7adc63f8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_dummy_scan.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_dummy_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalDummyScan represents a dummy scan returning a single row +class LogicalDummyScan : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_DUMMY_SCAN; + +public: + explicit LogicalDummyScan(idx_t table_index) + : LogicalOperator(LogicalOperatorType::LOGICAL_DUMMY_SCAN), table_index(table_index) { + } + + idx_t table_index; + +public: + vector GetColumnBindings() override { + return {ColumnBinding(table_index, 0)}; + } + + idx_t EstimateCardinality(ClientContext &context) override { + return 1; + } + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override { + if (types.size() == 0) { + types.emplace_back(LogicalType::INTEGER); + } + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_empty_result.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_empty_result.hpp new file mode 100644 index 00000000..5185ddc6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_empty_result.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_empty_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalEmptyResult returns an empty result. This is created by the optimizer if it can reason that certain parts of +//! the tree will always return an empty result. +class LogicalEmptyResult : public LogicalOperator { + LogicalEmptyResult(); + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_EMPTY_RESULT; + +public: + explicit LogicalEmptyResult(unique_ptr op); + + //! The set of return types of the empty result + vector return_types; + //! The columns that would be bound at this location (if the subtree was not optimized away) + vector bindings; + +public: + vector GetColumnBindings() override { + return bindings; + } + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + idx_t EstimateCardinality(ClientContext &context) override { + return 0; + } + +protected: + void ResolveTypes() override { + this->types = return_types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_execute.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_execute.hpp new file mode 100644 index 00000000..62eb28f2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_execute.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_execute.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalExecute : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_EXECUTE; + +public: + explicit LogicalExecute(shared_ptr prepared_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_EXECUTE), prepared(std::move(prepared_p)) { + D_ASSERT(prepared); + types = prepared->types; + } + + shared_ptr prepared; + +public: + //! Skips the serialization check in VerifyPlan + bool SupportSerialization() const override { + return false; + } + +protected: + void ResolveTypes() override { + // already resolved + } + vector GetColumnBindings() override { + return GenerateColumnBindings(0, types.size()); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_explain.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_explain.hpp new file mode 100644 index 00000000..dad4d898 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_explain.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_explain.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalExplain : public LogicalOperator { + LogicalExplain(ExplainType explain_type) + : LogicalOperator(LogicalOperatorType::LOGICAL_EXPLAIN), explain_type(explain_type) {}; + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_EXPLAIN; + +public: + LogicalExplain(unique_ptr plan, ExplainType explain_type) + : LogicalOperator(LogicalOperatorType::LOGICAL_EXPLAIN), explain_type(explain_type) { + children.push_back(std::move(plan)); + } + + ExplainType explain_type; + string physical_plan; + string logical_plan_unopt; + string logical_plan_opt; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override { + return 3; + } + //! Skips the serialization check in VerifyPlan + bool SupportSerialization() const override { + return false; + } + +protected: + void ResolveTypes() override { + types = {LogicalType::VARCHAR, LogicalType::VARCHAR}; + } + vector GetColumnBindings() override { + return {ColumnBinding(0, 0), ColumnBinding(0, 1)}; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_export.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_export.hpp new file mode 100644 index 00000000..3de951e9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_export.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_export.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/parser/parsed_data/exported_table_data.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/function/copy_function.hpp" + +namespace duckdb { + +class LogicalExport : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_EXPORT; + +public: + LogicalExport(CopyFunction function, unique_ptr copy_info, BoundExportData exported_tables) + : LogicalOperator(LogicalOperatorType::LOGICAL_EXPORT), function(function), copy_info(std::move(copy_info)), + exported_tables(std::move(exported_tables)) { + } + CopyFunction function; + unique_ptr copy_info; + BoundExportData exported_tables; + +public: +protected: + void ResolveTypes() override { + types.emplace_back(LogicalType::BOOLEAN); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_expression_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_expression_get.hpp new file mode 100644 index 00000000..70120431 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_expression_get.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_expression_get.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalExpressionGet represents a scan operation over a set of to-be-executed expressions +class LogicalExpressionGet : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_EXPRESSION_GET; + +public: + LogicalExpressionGet(idx_t table_index, vector types, + vector>> expressions) + : LogicalOperator(LogicalOperatorType::LOGICAL_EXPRESSION_GET), table_index(table_index), expr_types(types), + expressions(std::move(expressions)) { + } + + //! The table index in the current bind context + idx_t table_index; + //! The types of the expressions + vector expr_types; + //! The set of expressions + vector>> expressions; + +public: + vector GetColumnBindings() override { + return GenerateColumnBindings(table_index, expr_types.size()); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + idx_t EstimateCardinality(ClientContext &context) override { + return expressions.size(); + } + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override { + // types are resolved in the constructor + this->types = expr_types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_extension_operator.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_extension_operator.hpp new file mode 100644 index 00000000..3fc6b6e7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_extension_operator.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_extension_operator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/operator_extension.hpp" + +namespace duckdb { + +class ColumnBindingResolver; + +struct LogicalExtensionOperator : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR; + +public: + LogicalExtensionOperator() : LogicalOperator(LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR) { + } + LogicalExtensionOperator(vector> expressions) + : LogicalOperator(LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR, std::move(expressions)) { + } + + virtual void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + virtual unique_ptr CreatePlan(ClientContext &context, PhysicalPlanGenerator &generator) = 0; + + virtual void ResolveColumnBindings(ColumnBindingResolver &res, vector &bindings); + virtual string GetExtensionName() const; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_filter.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_filter.hpp new file mode 100644 index 00000000..acd5771b --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_filter.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalFilter represents a filter operation (e.g. WHERE or HAVING clause) +class LogicalFilter : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_FILTER; + +public: + explicit LogicalFilter(unique_ptr expression); + LogicalFilter(); + + vector projection_map; + +public: + vector GetColumnBindings() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + bool SplitPredicates() { + return SplitPredicates(expressions); + } + //! Splits up the predicates of the LogicalFilter into a set of predicates + //! separated by AND Returns whether or not any splits were made + static bool SplitPredicates(vector> &expressions); + +protected: + void ResolveTypes() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp new file mode 100644 index 00000000..246a83cb --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_get.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/table_function.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/extra_operator_info.hpp" + +namespace duckdb { + +//! LogicalGet represents a scan operation from a data source +class LogicalGet : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_GET; + +public: + LogicalGet(idx_t table_index, TableFunction function, unique_ptr bind_data, + vector returned_types, vector returned_names); + + //! The table index in the current bind context + idx_t table_index; + //! The function that is called + TableFunction function; + //! The bind data of the function + unique_ptr bind_data; + //! The types of ALL columns that can be returned by the table function + vector returned_types; + //! The names of ALL columns that can be returned by the table function + vector names; + //! Bound column IDs + vector column_ids; + //! Columns that are used outside of the scan + vector projection_ids; + //! Filters pushed down for table scan + TableFilterSet table_filters; + //! The set of input parameters for the table function + vector parameters; + //! The set of named input parameters for the table function + named_parameter_map_t named_parameters; + //! The set of named input table types for the table-in table-out function + vector input_table_types; + //! The set of named input table names for the table-in table-out function + vector input_table_names; + //! For a table-in-out function, the set of projected input columns + vector projected_input; + //! Currently stores File Filters (as strings) applied by hive partitioning/complex filter pushdown + //! Stored so the can be included in explain output + ExtraOperatorInfo extra_info; + + string GetName() const override; + string ParamsToString() const override; + //! Returns the underlying table that is being scanned, or nullptr if there is none + optional_ptr GetTable() const; + +public: + vector GetColumnBindings() override; + idx_t EstimateCardinality(ClientContext &context) override; + + vector GetTableIndex() const override; + //! Skips the serialization check in VerifyPlan + bool SupportSerialization() const override { + return function.verify_serialization; + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + void ResolveTypes() override; + +private: + LogicalGet(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp new file mode 100644 index 00000000..615d4aa1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp @@ -0,0 +1,77 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_insert.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/common/index_vector.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" + +namespace duckdb { +class TableCatalogEntry; + +class Index; + +//! LogicalInsert represents an insertion of data into a base table +class LogicalInsert : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INSERT; + +public: + LogicalInsert(TableCatalogEntry &table, idx_t table_index); + + vector>> insert_values; + //! The insertion map ([table_index -> index in result, or DConstants::INVALID_INDEX if not specified]) + physical_index_vector_t column_index_map; + //! The expected types for the INSERT statement (obtained from the column types) + vector expected_types; + //! The base table to insert into + TableCatalogEntry &table; + idx_t table_index; + //! if returning option is used, return actual chunk to projection + bool return_chunk; + //! The default statements used by the table + vector> bound_defaults; + + //! Which action to take on conflict + OnConflictAction action_type; + // The types that the DO UPDATE .. SET (expressions) are cast to + vector expected_set_types; + // The (distinct) column ids to apply the ON CONFLICT on + unordered_set on_conflict_filter; + // The WHERE clause of the conflict_target (ON CONFLICT .. WHERE ) + unique_ptr on_conflict_condition; + // The WHERE clause of the DO UPDATE clause + unique_ptr do_update_condition; + // The columns targeted by the DO UPDATE SET expressions + vector set_columns; + // The types of the columns targeted by the DO UPDATE SET expressions + vector set_types; + // The table_index referring to the column references qualified with 'excluded' + idx_t excluded_table_index; + // The columns to fetch from the 'destination' table + vector columns_to_fetch; + // The columns to fetch from the 'source' table + vector source_columns; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + vector GetColumnBindings() override; + void ResolveTypes() override; + + idx_t EstimateCardinality(ClientContext &context) override; + vector GetTableIndex() const override; + string GetName() const override; + +private: + LogicalInsert(ClientContext &context, const unique_ptr table_info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_join.hpp new file mode 100644 index 00000000..a6a63de6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_join.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +//! LogicalJoin represents a join between two relations +class LogicalJoin : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INVALID; + +public: + explicit LogicalJoin(JoinType type, LogicalOperatorType logical_type = LogicalOperatorType::LOGICAL_JOIN); + + // Gets the set of table references that are reachable from this node + static void GetTableReferences(LogicalOperator &op, unordered_set &bindings); + static void GetExpressionBindings(Expression &expr, unordered_set &bindings); + + //! The type of the join (INNER, OUTER, etc...) + JoinType join_type; + //! Table index used to refer to the MARK column (in case of a MARK join) + idx_t mark_index {}; + //! The columns of the LHS that are output by the join + vector left_projection_map; + //! The columns of the RHS that are output by the join + vector right_projection_map; + //! Join Keys statistics (optional) + vector> join_stats; + +public: + vector GetColumnBindings() override; + +protected: + void ResolveTypes() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_limit.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_limit.hpp new file mode 100644 index 00000000..1bec099d --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_limit.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_limit.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalLimit represents a LIMIT clause +class LogicalLimit : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_LIMIT; + +public: + LogicalLimit(int64_t limit_val, int64_t offset_val, unique_ptr limit, unique_ptr offset); + + //! Limit and offset values in case they are constants, used in optimizations. + int64_t limit_val; + int64_t offset_val; + //! The maximum amount of elements to emit + unique_ptr limit; + //! The offset from the start to begin emitting elements + unique_ptr offset; + +public: + vector GetColumnBindings() override; + idx_t EstimateCardinality(ClientContext &context) override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_limit_percent.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_limit_percent.hpp new file mode 100644 index 00000000..da2c6710 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_limit_percent.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_limit_percent.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalLimitPercent represents a LIMIT PERCENT clause +class LogicalLimitPercent : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_LIMIT_PERCENT; + +public: + LogicalLimitPercent(double limit_percent, int64_t offset_val, unique_ptr limit, + unique_ptr offset) + : LogicalOperator(LogicalOperatorType::LOGICAL_LIMIT_PERCENT), limit_percent(limit_percent), + offset_val(offset_val), limit(std::move(limit)), offset(std::move(offset)) { + } + + //! Limit percent and offset values in case they are constants, used in optimizations. + double limit_percent; + int64_t offset_val; + //! The maximum amount of elements to emit + unique_ptr limit; + //! The offset from the start to begin emitting elements + unique_ptr offset; + +public: + vector GetColumnBindings() override { + return children[0]->GetColumnBindings(); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + idx_t EstimateCardinality(ClientContext &context) override; + +protected: + void ResolveTypes() override { + types = children[0]->types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_materialized_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_materialized_cte.hpp new file mode 100644 index 00000000..0f5018c3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_materialized_cte.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_materialized_cte.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalMaterializedCTE : public LogicalOperator { + explicit LogicalMaterializedCTE() : LogicalOperator(LogicalOperatorType::LOGICAL_MATERIALIZED_CTE) { + } + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_MATERIALIZED_CTE; + +public: + LogicalMaterializedCTE(string ctename, idx_t table_index, idx_t column_count, unique_ptr cte, + unique_ptr child) + : LogicalOperator(LogicalOperatorType::LOGICAL_MATERIALIZED_CTE), table_index(table_index), + column_count(column_count), ctename(ctename) { + children.push_back(std::move(cte)); + children.push_back(std::move(child)); + } + + idx_t table_index; + idx_t column_count; + string ctename; + +public: + vector GetColumnBindings() override { + return children[1]->GetColumnBindings(); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + vector GetTableIndex() const override; + +protected: + void ResolveTypes() override { + types = children[1]->types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_order.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_order.hpp new file mode 100644 index 00000000..d0080b59 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_order.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_order.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +//! LogicalOrder represents an ORDER BY clause, sorting the data +class LogicalOrder : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_ORDER_BY; + +public: + explicit LogicalOrder(vector orders); + + vector orders; + vector projections; + +public: + vector GetColumnBindings() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + string ParamsToString() const override; + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_pivot.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_pivot.hpp new file mode 100644 index 00000000..26d569da --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_pivot.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_pivot.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/planner/tableref/bound_pivotref.hpp" + +namespace duckdb { + +class LogicalPivot : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_PIVOT; + +public: + LogicalPivot(idx_t pivot_idx, unique_ptr plan, BoundPivotInfo info); + + idx_t pivot_index; + //! The bound pivot info + BoundPivotInfo bound_pivot; + +public: + vector GetColumnBindings() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override; + +private: + LogicalPivot(); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_positional_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_positional_join.hpp new file mode 100644 index 00000000..c3aea8e7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_positional_join.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_positional_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/operator/logical_unconditional_join.hpp" + +namespace duckdb { + +//! LogicalPositionalJoin represents a row-wise join between two relations +class LogicalPositionalJoin : public LogicalUnconditionalJoin { + LogicalPositionalJoin() : LogicalUnconditionalJoin(LogicalOperatorType::LOGICAL_POSITIONAL_JOIN) {}; + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_POSITIONAL_JOIN; + +public: + LogicalPositionalJoin(unique_ptr left, unique_ptr right); + +public: + static unique_ptr Create(unique_ptr left, unique_ptr right); + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_pragma.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_pragma.hpp new file mode 100644 index 00000000..901ff7b8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_pragma.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_pragma.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/pragma_function.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalSimple represents a simple logical operator that only passes on the parse info +class LogicalPragma : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_PRAGMA; + +public: + LogicalPragma(PragmaFunction function_p, PragmaInfo info_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_PRAGMA), function(std::move(function_p)), + info(std::move(info_p)) { + } + + //! The pragma function to call + PragmaFunction function; + //! The context of the call + PragmaInfo info; + +public: + idx_t EstimateCardinality(ClientContext &context) override; + //! Skips the serialization check in VerifyPlan + bool SupportSerialization() const override { + return false; + } + +protected: + void ResolveTypes() override { + types.emplace_back(LogicalType::BOOLEAN); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_prepare.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_prepare.hpp new file mode 100644 index 00000000..7977027b --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_prepare.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_prepare.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class TableCatalogEntry; + +class LogicalPrepare : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_PREPARE; + +public: + LogicalPrepare(string name, shared_ptr prepared, unique_ptr logical_plan) + : LogicalOperator(LogicalOperatorType::LOGICAL_PREPARE), name(name), prepared(std::move(prepared)) { + if (logical_plan) { + children.push_back(std::move(logical_plan)); + } + } + + string name; + shared_ptr prepared; + +public: + idx_t EstimateCardinality(ClientContext &context) override; + //! Skips the serialization check in VerifyPlan + bool SupportSerialization() const override { + return false; + } + +protected: + void ResolveTypes() override { + types.emplace_back(LogicalType::BOOLEAN); + } + + bool RequireOptimizer() const override { + if (!prepared->properties.bound_all_parameters) { + return false; + } + return children[0]->RequireOptimizer(); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_projection.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_projection.hpp new file mode 100644 index 00000000..54a3dfef --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_projection.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_projection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalProjection represents the projection list in a SELECT clause +class LogicalProjection : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_PROJECTION; + +public: + LogicalProjection(idx_t table_index, vector> select_list); + + idx_t table_index; + +public: + vector GetColumnBindings() override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_recursive_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_recursive_cte.hpp new file mode 100644 index 00000000..0fa74bd6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_recursive_cte.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_recursive_cte.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalRecursiveCTE : public LogicalOperator { + LogicalRecursiveCTE() : LogicalOperator(LogicalOperatorType::LOGICAL_RECURSIVE_CTE) { + } + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_RECURSIVE_CTE; + +public: + LogicalRecursiveCTE(string ctename, idx_t table_index, idx_t column_count, bool union_all, + unique_ptr top, unique_ptr bottom) + : LogicalOperator(LogicalOperatorType::LOGICAL_RECURSIVE_CTE), union_all(union_all), ctename(ctename), + table_index(table_index), column_count(column_count) { + children.push_back(std::move(top)); + children.push_back(std::move(bottom)); + } + + bool union_all; + string ctename; + idx_t table_index; + idx_t column_count; + +public: + vector GetColumnBindings() override { + return GenerateColumnBindings(table_index, column_count); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override { + types = children[0]->types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_reset.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_reset.hpp new file mode 100644 index 00000000..693795e6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_reset.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_reset.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/function/copy_function.hpp" + +namespace duckdb { + +class LogicalReset : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_RESET; + +public: + LogicalReset(std::string name_p, SetScope scope_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_RESET), name(name_p), scope(scope_p) { + } + + std::string name; + SetScope scope; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override; + +protected: + void ResolveTypes() override { + types.emplace_back(LogicalType::BOOLEAN); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_sample.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_sample.hpp new file mode 100644 index 00000000..7bce8140 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_sample.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_sample.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" + +namespace duckdb { + +//! LogicalSample represents a SAMPLE clause +class LogicalSample : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_SAMPLE; + +public: + LogicalSample(unique_ptr sample_options_p, unique_ptr child); + + //! The sample options + unique_ptr sample_options; + +public: + vector GetColumnBindings() override; + idx_t EstimateCardinality(ClientContext &context) override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + void ResolveTypes() override; + +private: + LogicalSample(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_set.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_set.hpp new file mode 100644 index 00000000..185f16a4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_set.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_set.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/function/copy_function.hpp" + +namespace duckdb { + +class LogicalSet : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_SET; + +public: + LogicalSet(std::string name_p, Value value_p, SetScope scope_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_SET), name(name_p), value(value_p), scope(scope_p) { + } + + std::string name; + Value value; + SetScope scope; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override; + +protected: + void ResolveTypes() override { + types.emplace_back(LogicalType::BOOLEAN); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_set_operation.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_set_operation.hpp new file mode 100644 index 00000000..1d2840a3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_set_operation.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_set_operation.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalSetOperation : public LogicalOperator { + LogicalSetOperation(idx_t table_index, idx_t column_count, LogicalOperatorType type) + : LogicalOperator(type), table_index(table_index), column_count(column_count) { + } + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INVALID; + +public: + LogicalSetOperation(idx_t table_index, idx_t column_count, unique_ptr top, + unique_ptr bottom, LogicalOperatorType type) + : LogicalOperator(type), table_index(table_index), column_count(column_count) { + D_ASSERT(type == LogicalOperatorType::LOGICAL_UNION || type == LogicalOperatorType::LOGICAL_EXCEPT || + type == LogicalOperatorType::LOGICAL_INTERSECT); + children.push_back(std::move(top)); + children.push_back(std::move(bottom)); + } + + idx_t table_index; + idx_t column_count; + +public: + vector GetColumnBindings() override { + return GenerateColumnBindings(table_index, column_count); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override { + types = children[0]->types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_show.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_show.hpp new file mode 100644 index 00000000..f3b5ee10 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_show.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_show.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class LogicalShow : public LogicalOperator { + LogicalShow() : LogicalOperator(LogicalOperatorType::LOGICAL_SHOW) {}; + +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_SHOW; + +public: + explicit LogicalShow(unique_ptr plan) : LogicalOperator(LogicalOperatorType::LOGICAL_SHOW) { + children.push_back(std::move(plan)); + } + + vector types_select; + vector aliases; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +protected: + void ResolveTypes() override { + types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, + LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; + } + vector GetColumnBindings() override { + return GenerateColumnBindings(0, types.size()); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_simple.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_simple.hpp new file mode 100644 index 00000000..f1b59597 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_simple.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_simple.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalSimple represents a simple logical operator that only passes on the parse info +class LogicalSimple : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INVALID; + +public: + LogicalSimple(LogicalOperatorType type, unique_ptr info) : LogicalOperator(type), info(std::move(info)) { + } + + unique_ptr info; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + idx_t EstimateCardinality(ClientContext &context) override; + +protected: + void ResolveTypes() override { + types.emplace_back(LogicalType::BOOLEAN); + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_top_n.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_top_n.hpp new file mode 100644 index 00000000..745ca629 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_top_n.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_top_n.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalTopN represents a comibination of ORDER BY and LIMIT clause, using Min/Max Heap +class LogicalTopN : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_TOP_N; + +public: + LogicalTopN(vector orders, int64_t limit, int64_t offset) + : LogicalOperator(LogicalOperatorType::LOGICAL_TOP_N), orders(std::move(orders)), limit(limit), offset(offset) { + } + + vector orders; + //! The maximum amount of elements to emit + int64_t limit; + //! The offset from the start to begin emitting elements + int64_t offset; + +public: + vector GetColumnBindings() override { + return children[0]->GetColumnBindings(); + } + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override; + +protected: + void ResolveTypes() override { + types = children[0]->types; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_unconditional_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_unconditional_join.hpp new file mode 100644 index 00000000..c25509bb --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_unconditional_join.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_unconditional_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalUnconditionalJoin represents a join between two relations +//! where the join condition is implicit (cross product, position, etc.) +class LogicalUnconditionalJoin : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INVALID; + +public: + explicit LogicalUnconditionalJoin(LogicalOperatorType logical_type) : LogicalOperator(logical_type) {}; + +public: + LogicalUnconditionalJoin(LogicalOperatorType logical_type, unique_ptr left, + unique_ptr right); + +public: + vector GetColumnBindings() override; + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_unnest.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_unnest.hpp new file mode 100644 index 00000000..7a5f3d46 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_unnest.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_unnest.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalUnnest represents the logical UNNEST operator. +class LogicalUnnest : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_UNNEST; + +public: + explicit LogicalUnnest(idx_t unnest_index) + : LogicalOperator(LogicalOperatorType::LOGICAL_UNNEST), unnest_index(unnest_index) { + } + + idx_t unnest_index; + +public: + vector GetColumnBindings() override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_update.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_update.hpp new file mode 100644 index 00000000..9215ff69 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_update.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_update.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { +class TableCatalogEntry; + +class LogicalUpdate : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_UPDATE; + +public: + explicit LogicalUpdate(TableCatalogEntry &table); + + //! The base table to update + TableCatalogEntry &table; + //! table catalog index + idx_t table_index; + //! if returning option is used, return the update chunk + bool return_chunk; + vector columns; + vector> bound_defaults; + bool update_is_del_and_insert; + +public: + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + idx_t EstimateCardinality(ClientContext &context) override; + string GetName() const override; + +protected: + vector GetColumnBindings() override; + void ResolveTypes() override; + +private: + LogicalUpdate(ClientContext &context, const unique_ptr &table_info); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_window.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_window.hpp new file mode 100644 index 00000000..45a0667d --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_window.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_window.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! LogicalAggregate represents an aggregate operation with (optional) GROUP BY +//! operator. +class LogicalWindow : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_WINDOW; + +public: + explicit LogicalWindow(idx_t window_index) + : LogicalOperator(LogicalOperatorType::LOGICAL_WINDOW), window_index(window_index) { + } + + idx_t window_index; + +public: + vector GetColumnBindings() override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + vector GetTableIndex() const override; + string GetName() const override; + +protected: + void ResolveTypes() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator_extension.hpp b/src/duckdb/src/include/duckdb/planner/operator_extension.hpp new file mode 100644 index 00000000..0dc22fae --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator_extension.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator_extension.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +//! The OperatorExtensionInfo holds static information relevant to the operator extension +struct OperatorExtensionInfo { + virtual ~OperatorExtensionInfo() { + } +}; + +typedef BoundStatement (*bind_function_t)(ClientContext &context, Binder &binder, OperatorExtensionInfo *info, + SQLStatement &statement); + +// forward declaration to avoid circular reference +struct LogicalExtensionOperator; + +class OperatorExtension { +public: + bind_function_t Bind; + + //! Additional info passed to the CreatePlan & Bind functions + shared_ptr operator_info; + + virtual std::string GetName() = 0; + virtual unique_ptr Deserialize(Deserializer &deserializer) = 0; + + virtual ~OperatorExtension() { + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/parsed_data/bound_create_function_info.hpp b/src/duckdb/src/include/duckdb/planner/parsed_data/bound_create_function_info.hpp new file mode 100644 index 00000000..7224037c --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/parsed_data/bound_create_function_info.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/parsed_data/bound_create_function_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_macro_info.hpp" + +namespace duckdb { +class CatalogEntry; + +struct BoundCreateFunctionInfo { + explicit BoundCreateFunctionInfo(SchemaCatalogEntry &schema, unique_ptr base) + : schema(schema), base(std::move(base)) { + } + + //! The schema to create the table in + SchemaCatalogEntry &schema; + //! The base CreateInfo object + unique_ptr base; + + CreateMacroInfo &Base() { + return (CreateMacroInfo &)*base; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp b/src/duckdb/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp new file mode 100644 index 00000000..c5c8778f --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/parsed_data/bound_create_table_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/planner/bound_constraint.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/storage/table/persistent_table_data.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/catalog/catalog_entry/table_column_type.hpp" +#include "duckdb/catalog/catalog_entry/column_dependency_manager.hpp" +#include "duckdb/storage/table/table_index_list.hpp" +#include "duckdb/catalog/dependency_list.hpp" + +namespace duckdb { +class CatalogEntry; + +struct BoundCreateTableInfo { + explicit BoundCreateTableInfo(SchemaCatalogEntry &schema, unique_ptr base_p) + : schema(schema), base(std::move(base_p)) { + D_ASSERT(base); + } + + //! The schema to create the table in + SchemaCatalogEntry &schema; + //! The base CreateInfo object + unique_ptr base; + //! Column dependency manager of the table + ColumnDependencyManager column_dependency_manager; + //! List of constraints on the table + vector> constraints; + //! List of bound constraints on the table + vector> bound_constraints; + //! Bound default values + vector> bound_defaults; + //! Dependents of the table (in e.g. default values) + DependencyList dependencies; + //! The existing table data on disk (if any) + unique_ptr data; + //! CREATE TABLE from QUERY + unique_ptr query; + //! Indexes created by this table + vector indexes; + + CreateTableInfo &Base() { + D_ASSERT(base); + return (CreateTableInfo &)*base; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/planner.hpp b/src/duckdb/src/include/duckdb/planner/planner.hpp new file mode 100644 index 00000000..98556783 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/planner.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/planner.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/bound_parameter_map.hpp" + +namespace duckdb { +class ClientContext; +class PreparedStatementData; + +//! The planner creates a logical query plan from the parsed SQL statements +//! using the Binder and LogicalPlanGenerator. +class Planner { + friend class Binder; + +public: + explicit Planner(ClientContext &context); + +public: + unique_ptr plan; + vector names; + vector types; + case_insensitive_map_t parameter_data; + + shared_ptr binder; + ClientContext &context; + + StatementProperties properties; + bound_parameter_map_t value_map; + +public: + void CreatePlan(unique_ptr statement); + static void VerifyPlan(ClientContext &context, unique_ptr &op, + optional_ptr map = nullptr); + +private: + void CreatePlan(SQLStatement &statement); + shared_ptr PrepareSQLStatement(unique_ptr statement); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/pragma_handler.hpp b/src/duckdb/src/include/duckdb/planner/pragma_handler.hpp new file mode 100644 index 00000000..81ba353a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/pragma_handler.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/pragma_handler.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/parser/statement/pragma_statement.hpp" + +namespace duckdb { +class ClientContext; +class ClientContextLock; +class SQLStatement; +struct PragmaInfo; + +//! Pragma handler is responsible for converting certain pragma statements into new queries +class PragmaHandler { +public: + explicit PragmaHandler(ClientContext &context); + + void HandlePragmaStatements(ClientContextLock &lock, vector> &statements); + +private: + ClientContext &context; + +private: + //! Handles a pragma statement, returns whether the statement was expanded, if it was expanded the 'resulting_query' + //! contains the statement(s) to replace the current one + bool HandlePragma(SQLStatement *statement, string &resulting_query); + + void HandlePragmaStatementsInternal(vector> &statements); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp new file mode 100644 index 00000000..00b3f94a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/query_node/bound_cte_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_query_node.hpp" + +namespace duckdb { + +class BoundCTENode : public BoundQueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; + +public: + BoundCTENode() : BoundQueryNode(QueryNodeType::CTE_NODE) { + } + + //! Keep track of the CTE name this node represents + string ctename; + + //! The cte node + unique_ptr query; + //! The child node + unique_ptr child; + //! Index used by the set operation + idx_t setop_index; + //! The binder used by the query side of the CTE + shared_ptr query_binder; + //! The binder used by the child side of the CTE + shared_ptr child_binder; + +public: + idx_t GetRootIndex() override { + return child->GetRootIndex(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp new file mode 100644 index 00000000..0e9b3ddd --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/query_node/bound_recursive_cte_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_query_node.hpp" + +namespace duckdb { + +//! Bound equivalent of SetOperationNode +class BoundRecursiveCTENode : public BoundQueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::RECURSIVE_CTE_NODE; + +public: + BoundRecursiveCTENode() : BoundQueryNode(QueryNodeType::RECURSIVE_CTE_NODE) { + } + + //! Keep track of the CTE name this node represents + string ctename; + + bool union_all; + //! The left side of the set operation + unique_ptr left; + //! The right side of the set operation + unique_ptr right; + + //! Index used by the set operation + idx_t setop_index; + //! The binder used by the left side of the set operation + shared_ptr left_binder; + //! The binder used by the right side of the set operation + shared_ptr right_binder; + +public: + idx_t GetRootIndex() override { + return setop_index; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp new file mode 100644 index 00000000..af16afc5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp @@ -0,0 +1,103 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/query_node/bound_select_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" +#include "duckdb/parser/group_by_node.hpp" + +namespace duckdb { + +class BoundGroupByNode { +public: + //! The total set of all group expressions + vector> group_expressions; + //! The different grouping sets as they map to the group expressions + vector grouping_sets; +}; + +struct BoundUnnestNode { + //! The index of the UNNEST node + idx_t index; + //! The set of expressions + vector> expressions; +}; + +//! Bound equivalent of SelectNode +class BoundSelectNode : public BoundQueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::SELECT_NODE; + +public: + BoundSelectNode() : BoundQueryNode(QueryNodeType::SELECT_NODE) { + } + + //! The original unparsed expressions. This is exported after binding, because the binding might change the + //! expressions (e.g. when a * clause is present) + vector> original_expressions; + + //! The projection list + vector> select_list; + //! The FROM clause + unique_ptr from_table; + //! The WHERE clause + unique_ptr where_clause; + //! list of groups + BoundGroupByNode groups; + //! HAVING clause + unique_ptr having; + //! QUALIFY clause + unique_ptr qualify; + //! SAMPLE clause + unique_ptr sample_options; + + //! The amount of columns in the final result + idx_t column_count; + + //! Index used by the LogicalProjection + idx_t projection_index; + + //! Group index used by the LogicalAggregate (only used if HasAggregation is true) + idx_t group_index; + //! Table index for the projection child of the group op + idx_t group_projection_index; + //! Aggregate index used by the LogicalAggregate (only used if HasAggregation is true) + idx_t aggregate_index; + //! Index used for GROUPINGS column references + idx_t groupings_index; + //! Aggregate functions to compute (only used if HasAggregation is true) + vector> aggregates; + + //! GROUPING function calls + vector> grouping_functions; + + //! Map from aggregate function to aggregate index (used to eliminate duplicate aggregates) + expression_map_t aggregate_map; + + //! Window index used by the LogicalWindow (only used if HasWindow is true) + idx_t window_index; + //! Window functions to compute (only used if HasWindow is true) + vector> windows; + + //! Unnest expression + unordered_map unnests; + + //! Index of pruned node + idx_t prune_index; + bool need_prune = false; + +public: + idx_t GetRootIndex() override { + return need_prune ? prune_index : projection_index; + } +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp new file mode 100644 index 00000000..d6c8f95d --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/query_node/bound_set_operation_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_query_node.hpp" + +namespace duckdb { + +//! Bound equivalent of SetOperationNode +class BoundSetOperationNode : public BoundQueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::SET_OPERATION_NODE; + +public: + BoundSetOperationNode() : BoundQueryNode(QueryNodeType::SET_OPERATION_NODE) { + } + + //! The type of set operation + SetOperationType setop_type = SetOperationType::NONE; + //! The left side of the set operation + unique_ptr left; + //! The right side of the set operation + unique_ptr right; + + //! Index used by the set operation + idx_t setop_index; + //! The binder used by the left side of the set operation + shared_ptr left_binder; + //! The binder used by the right side of the set operation + shared_ptr right_binder; + + //! Exprs used by the UNION BY NAME opeartons to add a new projection + vector> left_reorder_exprs; + vector> right_reorder_exprs; + + //! The exprs of the child node may be rearranged(UNION BY NAME), + //! this vector records the new index of the expression after rearrangement + //! used by GatherAlias(...) function to create new reorder index + vector left_reorder_idx; + vector right_reorder_idx; + +public: + idx_t GetRootIndex() override { + return setop_index; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/list.hpp b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp new file mode 100644 index 00000000..5c7dbda9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp @@ -0,0 +1,4 @@ +#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" +#include "duckdb/planner/query_node/bound_cte_node.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/query_node/bound_set_operation_node.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp new file mode 100644 index 00000000..ce1cdc37 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/subquery/flatten_dependent_join.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! The FlattenDependentJoins class is responsible for pushing the dependent join down into the plan to create a +//! flattened subquery +struct FlattenDependentJoins { + FlattenDependentJoins(Binder &binder, const vector &correlated, bool perform_delim = true, + bool any_join = false); + + //! Detects which Logical Operators have correlated expressions that they are dependent upon, filling the + //! has_correlated_expressions map. + bool DetectCorrelatedExpressions(LogicalOperator *op, bool lateral = false, idx_t lateral_depth = 0); + + //! Push the dependent join down a LogicalOperator + unique_ptr PushDownDependentJoin(unique_ptr plan); + + Binder &binder; + ColumnBinding base_binding; + idx_t delim_offset; + idx_t data_offset; + unordered_map has_correlated_expressions; + column_binding_map_t correlated_map; + column_binding_map_t replacement_map; + const vector &correlated_columns; + vector delim_types; + + bool perform_delim; + bool any_join; + +private: + unique_ptr PushDownDependentJoinInternal(unique_ptr plan, + bool &parent_propagate_null_values, idx_t lateral_depth); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp new file mode 100644 index 00000000..6b238ffc --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/subquery/has_correlated_expressions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! Helper class to recursively detect correlated expressions inside a single LogicalOperator +class HasCorrelatedExpressions : public LogicalOperatorVisitor { +public: + explicit HasCorrelatedExpressions(const vector &correlated, bool lateral = false, + idx_t lateral_depth = 0); + + void VisitOperator(LogicalOperator &op) override; + + bool has_correlated_expressions; + bool lateral; + +protected: + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override; + + const vector &correlated_columns; + // Tracks number of nested laterals + idx_t lateral_depth; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/subquery/recursive_dependent_join_planner.hpp b/src/duckdb/src/include/duckdb/planner/subquery/recursive_dependent_join_planner.hpp new file mode 100644 index 00000000..ea7b627b --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/subquery/recursive_dependent_join_planner.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/subquery/recursive_dependent_join_planner.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator_visitor.hpp" + +namespace duckdb { + +class Binder; + +/* + * Recursively plan subqueries and flatten dependent joins from outermost to innermost (like peeling an onion). + */ +class RecursiveDependentJoinPlanner : public LogicalOperatorVisitor { +public: + explicit RecursiveDependentJoinPlanner(Binder &binder) : binder(binder) { + } + void VisitOperator(LogicalOperator &op) override; + unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override; + +private: + unique_ptr root; + Binder &binder; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/subquery/rewrite_correlated_expressions.hpp b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_correlated_expressions.hpp new file mode 100644 index 00000000..0de0597c --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_correlated_expressions.hpp @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/subquery/rewrite_correlated_expressions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! Helper class to rewrite correlated expressions within a single LogicalOperator +class RewriteCorrelatedExpressions : public LogicalOperatorVisitor { +public: + RewriteCorrelatedExpressions(ColumnBinding base_binding, column_binding_map_t &correlated_map, + idx_t lateral_depth, bool recursive_rewrite = false); + + void VisitOperator(LogicalOperator &op) override; + +protected: + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override; + +private: + //! Helper class used to recursively rewrite correlated expressions within nested subqueries. + class RewriteCorrelatedRecursive { + public: + RewriteCorrelatedRecursive(BoundSubqueryExpression &parent, ColumnBinding base_binding, + column_binding_map_t &correlated_map); + void RewriteJoinRefRecursive(BoundTableRef &ref); + void RewriteCorrelatedSubquery(BoundSubqueryExpression &expr); + void RewriteCorrelatedExpressions(Expression &child); + + BoundSubqueryExpression &parent; + ColumnBinding base_binding; + column_binding_map_t &correlated_map; + }; + +private: + ColumnBinding base_binding; + column_binding_map_t &correlated_map; + // To keep track of the number of dependent joins encountered + idx_t lateral_depth; + // This flag is used to determine if the rewrite should recursively update the bindings for all + // bound columns ref in the plan, and update the depths to match the new source + bool recursive_rewrite; +}; + +//! Helper class that rewrites COUNT aggregates into a CASE expression turning NULL into 0 after a LEFT OUTER JOIN +class RewriteCountAggregates : public LogicalOperatorVisitor { +public: + explicit RewriteCountAggregates(column_binding_map_t &replacement_map); + + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + + column_binding_map_t &replacement_map; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/table_binding.hpp b/src/duckdb/src/include/duckdb/planner/table_binding.hpp new file mode 100644 index 00000000..4d6679c1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/table_binding.hpp @@ -0,0 +1,142 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/table_binding.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/catalog/catalog_entry/table_column_type.hpp" + +namespace duckdb { +class BindContext; +class BoundQueryNode; +class ColumnRefExpression; +class SubqueryRef; +class LogicalGet; +class TableCatalogEntry; +class TableFunctionCatalogEntry; +class BoundTableFunction; +class StandardEntry; +struct ColumnBinding; + +enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY }; + +//! A Binding represents a binding to a table, table-producing function or subquery with a specified table index. +struct Binding { + Binding(BindingType binding_type, const string &alias, vector types, vector names, + idx_t index); + virtual ~Binding() = default; + + //! The type of Binding + BindingType binding_type; + //! The alias of the binding + string alias; + //! The table index of the binding + idx_t index; + //! The types of the bound columns + vector types; + //! Column names of the subquery + vector names; + //! Name -> index for the names + case_insensitive_map_t name_map; + +public: + bool TryGetBindingIndex(const string &column_name, column_t &column_index); + column_t GetBindingIndex(const string &column_name); + bool HasMatchingBinding(const string &column_name); + virtual string ColumnNotFoundError(const string &column_name) const; + virtual BindResult Bind(ColumnRefExpression &colref, idx_t depth); + virtual optional_ptr GetStandardEntry(); + +public: + template + TARGET &Cast() { + if (binding_type != TARGET::TYPE) { + throw InternalException("Failed to cast binding to type - binding type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (binding_type != TARGET::TYPE) { + throw InternalException("Failed to cast binding to type - binding type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +struct EntryBinding : public Binding { +public: + static constexpr const BindingType TYPE = BindingType::CATALOG_ENTRY; + +public: + EntryBinding(const string &alias, vector types, vector names, idx_t index, + StandardEntry &entry); + StandardEntry &entry; + +public: + optional_ptr GetStandardEntry() override; +}; + +//! TableBinding is exactly like the Binding, except it keeps track of which columns were bound in the linked LogicalGet +//! node for projection pushdown purposes. +struct TableBinding : public Binding { +public: + static constexpr const BindingType TYPE = BindingType::TABLE; + +public: + TableBinding(const string &alias, vector types, vector names, + vector &bound_column_ids, optional_ptr entry, idx_t index, + bool add_row_id = false); + + //! A reference to the set of bound column ids + vector &bound_column_ids; + //! The underlying catalog entry (if any) + optional_ptr entry; + +public: + unique_ptr ExpandGeneratedColumn(const string &column_name); + BindResult Bind(ColumnRefExpression &colref, idx_t depth) override; + optional_ptr GetStandardEntry() override; + string ColumnNotFoundError(const string &column_name) const override; + // These are columns that are present in the name_map, appearing in the order that they're bound + const vector &GetBoundColumnIds() const; + +protected: + ColumnBinding GetColumnBinding(column_t column_index); +}; + +//! DummyBinding is like the Binding, except the alias and index are set by default. Used for binding lambdas and macro +//! parameters. +struct DummyBinding : public Binding { +public: + static constexpr const BindingType TYPE = BindingType::DUMMY; + // NOTE: changing this string conflicts with the storage version + static constexpr const char *DUMMY_NAME = "0_macro_parameters"; + +public: + DummyBinding(vector types_p, vector names_p, string dummy_name_p); + + //! Arguments + vector> *arguments; + //! The name of the dummy binding + string dummy_name; + +public: + BindResult Bind(ColumnRefExpression &colref, idx_t depth) override; + BindResult Bind(ColumnRefExpression &colref, idx_t lambda_index, idx_t depth); + + //! Given the parameter colref, returns a copy of the argument that was supplied for this parameter + unique_ptr ParamToArg(ColumnRefExpression &colref); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/table_filter.hpp b/src/duckdb/src/include/duckdb/planner/table_filter.hpp new file mode 100644 index 00000000..368b13ed --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/table_filter.hpp @@ -0,0 +1,102 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/table_filter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" + +namespace duckdb { +class BaseStatistics; + +enum class TableFilterType : uint8_t { + CONSTANT_COMPARISON = 0, // constant comparison (e.g. =C, >C, >=C, Deserialize(Deserializer &deserializer); + +public: + template + TARGET &Cast() { + if (filter_type != TARGET::TYPE) { + throw InternalException("Failed to cast table to type - table filter type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (filter_type != TARGET::TYPE) { + throw InternalException("Failed to cast table to type - table filter type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +class TableFilterSet { +public: + unordered_map> filters; + +public: + void PushFilter(idx_t table_index, unique_ptr filter); + + bool Equals(TableFilterSet &other) { + if (filters.size() != other.filters.size()) { + return false; + } + for (auto &entry : filters) { + auto other_entry = other.filters.find(entry.first); + if (other_entry == other.filters.end()) { + return false; + } + if (!entry.second->Equals(*other_entry->second)) { + return false; + } + } + return true; + } + static bool Equals(TableFilterSet *left, TableFilterSet *right) { + if (left == right) { + return true; + } + if (!left || !right) { + return false; + } + return left->Equals(*right); + } + + void Serialize(Serializer &serializer) const; + static TableFilterSet Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp new file mode 100644 index 00000000..b1f7f6f4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_basetableref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { +class TableCatalogEntry; + +//! Represents a TableReference to a base table in the schema +class BoundBaseTableRef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::BASE_TABLE; + +public: + BoundBaseTableRef(TableCatalogEntry &table, unique_ptr get) + : BoundTableRef(TableReferenceType::BASE_TABLE), table(table), get(std::move(get)) { + } + + TableCatalogEntry &table; + unique_ptr get; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp new file mode 100644 index 00000000..2e9db3ba --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_cteref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_tableref.hpp" + +namespace duckdb { + +class BoundCTERef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::CTE; + +public: + BoundCTERef(idx_t bind_index, idx_t cte_index, CTEMaterialize materialized_cte) + : BoundTableRef(TableReferenceType::CTE), bind_index(bind_index), cte_index(cte_index), + materialized_cte(materialized_cte) { + } + + //! The set of columns bound to this base table reference + vector bound_columns; + //! The types of the values list + vector types; + //! The index in the bind context + idx_t bind_index; + //! The index of the cte + idx_t cte_index; + //! Is this a reference to a materialized CTE? + CTEMaterialize materialized_cte; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp new file mode 100644 index 00000000..debeb6e2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_dummytableref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_tableref.hpp" + +namespace duckdb { + +//! Represents a cross product +class BoundEmptyTableRef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::EMPTY; + +public: + explicit BoundEmptyTableRef(idx_t bind_index) : BoundTableRef(TableReferenceType::EMPTY), bind_index(bind_index) { + } + idx_t bind_index; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp new file mode 100644 index 00000000..7fc563dd --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_expressionlistref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { +//! Represents a TableReference to a base table in the schema +class BoundExpressionListRef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::EXPRESSION_LIST; + +public: + BoundExpressionListRef() : BoundTableRef(TableReferenceType::EXPRESSION_LIST) { + } + + //! The bound VALUES list + vector>> values; + //! The generated names of the values list + vector names; + //! The types of the values list + vector types; + //! The index in the bind context + idx_t bind_index; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp new file mode 100644 index 00000000..9500a348 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_joinref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enums/joinref_type.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +//! Represents a join +class BoundJoinRef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::JOIN; + +public: + explicit BoundJoinRef(JoinRefType ref_type) + : BoundTableRef(TableReferenceType::JOIN), type(JoinType::INNER), ref_type(ref_type), lateral(false) { + } + + //! The binder used to bind the LHS of the join + shared_ptr left_binder; + //! The binder used to bind the RHS of the join + shared_ptr right_binder; + //! The left hand side of the join + unique_ptr left; + //! The right hand side of the join + unique_ptr right; + //! The join condition + unique_ptr condition; + //! The join type + JoinType type; + //! Join condition type + JoinRefType ref_type; + //! Whether or not this is a lateral join + bool lateral; + //! The correlated columns of the right-side with the left-side + vector correlated_columns; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp new file mode 100644 index 00000000..3219f630 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_pivotref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/function/aggregate_function.hpp" + +namespace duckdb { + +struct BoundPivotInfo { + //! The number of group columns + idx_t group_count; + //! The set of types + vector types; + //! The set of values to pivot on + vector pivot_values; + //! The set of aggregate functions that is being executed + vector> aggregates; + + void Serialize(Serializer &serializer) const; + static BoundPivotInfo Deserialize(Deserializer &deserializer); +}; + +class BoundPivotRef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::PIVOT; + +public: + explicit BoundPivotRef() : BoundTableRef(TableReferenceType::PIVOT) { + } + + idx_t bind_index; + //! The binder used to bind the child of the pivot + shared_ptr child_binder; + //! The child node of the pivot + unique_ptr child; + //! The bound pivot info + BoundPivotInfo bound_pivot; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp new file mode 100644 index 00000000..4cb057e4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_pos_join_ref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_tableref.hpp" + +namespace duckdb { + +//! Represents a positional join +class BoundPositionalJoinRef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::POSITIONAL_JOIN; + +public: + BoundPositionalJoinRef() : BoundTableRef(TableReferenceType::POSITIONAL_JOIN), lateral(false) { + } + + //! The binder used to bind the LHS of the positional join + shared_ptr left_binder; + //! The binder used to bind the RHS of the positional join + shared_ptr right_binder; + //! The left hand side of the positional join + unique_ptr left; + //! The right hand side of the positional join + unique_ptr right; + //! Whether or not this is a lateral positional join + bool lateral; + //! The correlated columns of the right-side with the left-side + vector correlated_columns; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp new file mode 100644 index 00000000..fccd21b0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_subqueryref.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/bound_tableref.hpp" + +namespace duckdb { + +//! Represents a cross product +class BoundSubqueryRef : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::SUBQUERY; + +public: + BoundSubqueryRef(shared_ptr binder_p, unique_ptr subquery) + : BoundTableRef(TableReferenceType::SUBQUERY), binder(std::move(binder_p)), subquery(std::move(subquery)) { + } + + //! The binder used to bind the subquery + shared_ptr binder; + //! The bound subquery node + unique_ptr subquery; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp new file mode 100644 index 00000000..58d61b19 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/tableref/bound_table_function.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! Represents a reference to a table-producing function call +class BoundTableFunction : public BoundTableRef { +public: + static constexpr const TableReferenceType TYPE = TableReferenceType::TABLE_FUNCTION; + +public: + explicit BoundTableFunction(unique_ptr get) + : BoundTableRef(TableReferenceType::TABLE_FUNCTION), get(std::move(get)) { + } + + unique_ptr get; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/list.hpp b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp new file mode 100644 index 00000000..1452daaf --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp @@ -0,0 +1,8 @@ +#include "duckdb/planner/tableref/bound_basetableref.hpp" +#include "duckdb/planner/tableref/bound_cteref.hpp" +#include "duckdb/planner/tableref/bound_dummytableref.hpp" +#include "duckdb/planner/tableref/bound_expressionlistref.hpp" +#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/planner/tableref/bound_subqueryref.hpp" +#include "duckdb/planner/tableref/bound_table_function.hpp" +#include "duckdb/planner/tableref/bound_pivotref.hpp" diff --git a/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp b/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp new file mode 100644 index 00000000..ef66702d --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/arena_allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/common.hpp" + +namespace duckdb { + +struct ArenaChunk { + ArenaChunk(Allocator &allocator, idx_t size); + ~ArenaChunk(); + + AllocatedData data; + idx_t current_position; + idx_t maximum_size; + unsafe_unique_ptr next; + ArenaChunk *prev; +}; + +class ArenaAllocator { + static constexpr const idx_t ARENA_ALLOCATOR_INITIAL_CAPACITY = 2048; + +public: + DUCKDB_API ArenaAllocator(Allocator &allocator, idx_t initial_capacity = ARENA_ALLOCATOR_INITIAL_CAPACITY); + DUCKDB_API ~ArenaAllocator(); + + DUCKDB_API data_ptr_t Allocate(idx_t size); + DUCKDB_API data_ptr_t Reallocate(data_ptr_t pointer, idx_t old_size, idx_t size); + + DUCKDB_API data_ptr_t AllocateAligned(idx_t size); + DUCKDB_API data_ptr_t ReallocateAligned(data_ptr_t pointer, idx_t old_size, idx_t size); + + //! Resets the current head and destroys all previous arena chunks + DUCKDB_API void Reset(); + DUCKDB_API void Destroy(); + DUCKDB_API void Move(ArenaAllocator &allocator); + + DUCKDB_API ArenaChunk *GetHead(); + DUCKDB_API ArenaChunk *GetTail(); + + DUCKDB_API bool IsEmpty() const; + DUCKDB_API idx_t SizeInBytes() const; + + //! Returns an "Allocator" wrapper for this arena allocator + Allocator &GetAllocator() { + return arena_allocator; + } + +private: + //! Internal allocator that is used by the arena allocator + Allocator &allocator; + idx_t current_capacity; + unsafe_unique_ptr head; + ArenaChunk *tail; + //! An allocator wrapper using this arena allocator + Allocator arena_allocator; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/block.hpp b/src/duckdb/src/include/duckdb/storage/block.hpp new file mode 100644 index 00000000..b12bab29 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/block.hpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/block.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/file_buffer.hpp" + +namespace duckdb { + +class Serializer; +class Deserializer; + +class Block : public FileBuffer { +public: + Block(Allocator &allocator, block_id_t id); + Block(Allocator &allocator, block_id_t id, uint32_t internal_size); + Block(FileBuffer &source, block_id_t id); + + block_id_t id; +}; + +struct BlockPointer { + BlockPointer(block_id_t block_id_p, uint32_t offset_p) : block_id(block_id_p), offset(offset_p) { + } + BlockPointer() : block_id(INVALID_BLOCK), offset(0) { + } + + block_id_t block_id; + uint32_t offset; + + bool IsValid() const { + return block_id != INVALID_BLOCK; + } + + void Serialize(Serializer &serializer) const; + static BlockPointer Deserialize(Deserializer &source); +}; + +struct MetaBlockPointer { + MetaBlockPointer(idx_t block_pointer, uint32_t offset_p) : block_pointer(block_pointer), offset(offset_p) { + } + MetaBlockPointer() : block_pointer(DConstants::INVALID_INDEX), offset(0) { + } + + idx_t block_pointer; + uint32_t offset; + + bool IsValid() const { + return block_pointer != DConstants::INVALID_INDEX; + } + block_id_t GetBlockId() const; + uint32_t GetBlockIndex() const; + + void Serialize(Serializer &serializer) const; + static MetaBlockPointer Deserialize(Deserializer &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/block_manager.hpp b/src/duckdb/src/include/duckdb/storage/block_manager.hpp new file mode 100644 index 00000000..dca511b5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/block_manager.hpp @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/block_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/storage/block.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { +class BlockHandle; +class BufferManager; +class ClientContext; +class DatabaseInstance; +class MetadataManager; + +//! BlockManager is an abstract representation to manage blocks on DuckDB. When writing or reading blocks, the +//! BlockManager creates and accesses blocks. The concrete types implements how blocks are stored. +class BlockManager { +public: + explicit BlockManager(BufferManager &buffer_manager); + virtual ~BlockManager() = default; + + //! The buffer manager + BufferManager &buffer_manager; + +public: + //! Creates a new block inside the block manager + virtual unique_ptr ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) = 0; + virtual unique_ptr CreateBlock(block_id_t block_id, FileBuffer *source_buffer) = 0; + //! Return the next free block id + virtual block_id_t GetFreeBlockId() = 0; + //! Returns whether or not a specified block is the root block + virtual bool IsRootBlock(MetaBlockPointer root) = 0; + //! Mark a block as "free"; free blocks are immediately added to the free list and can be immediately overwritten + virtual void MarkBlockAsFree(block_id_t block_id) = 0; + //! Mark a block as "modified"; modified blocks are added to the free list after a checkpoint (i.e. their data is + //! assumed to be rewritten) + virtual void MarkBlockAsModified(block_id_t block_id) = 0; + //! Increase the reference count of a block. The block should hold at least one reference before this method is + //! called. + virtual void IncreaseBlockReferenceCount(block_id_t block_id) = 0; + //! Get the first meta block id + virtual idx_t GetMetaBlock() = 0; + //! Read the content of the block from disk + virtual void Read(Block &block) = 0; + //! Writes the block to disk + virtual void Write(FileBuffer &block, block_id_t block_id) = 0; + //! Writes the block to disk + void Write(Block &block) { + Write(block, block.id); + } + //! Write the header; should be the final step of a checkpoint + virtual void WriteHeader(DatabaseHeader header) = 0; + + //! Returns the number of total blocks + virtual idx_t TotalBlocks() = 0; + //! Returns the number of free blocks + virtual idx_t FreeBlocks() = 0; + + //! Truncate the underlying database file after a checkpoint + virtual void Truncate(); + + //! Register a block with the given block id in the base file + shared_ptr RegisterBlock(block_id_t block_id); + //! Convert an existing in-memory buffer into a persistent disk-backed block + shared_ptr ConvertToPersistent(block_id_t block_id, shared_ptr old_block); + + void UnregisterBlock(block_id_t block_id, bool can_destroy); + + MetadataManager &GetMetadataManager(); + +private: + //! The lock for the set of blocks + mutex blocks_lock; + //! A mapping of block id -> BlockHandle + unordered_map> blocks; + //! The metadata manager + unique_ptr metadata_manager; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/buffer/block_handle.hpp b/src/duckdb/src/include/duckdb/storage/buffer/block_handle.hpp new file mode 100644 index 00000000..e797003b --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/buffer/block_handle.hpp @@ -0,0 +1,133 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/buffer/block_handle.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/file_buffer.hpp" + +namespace duckdb { +class BlockManager; +class BufferHandle; +class BufferPool; +class DatabaseInstance; + +enum class BlockState : uint8_t { BLOCK_UNLOADED = 0, BLOCK_LOADED = 1 }; + +struct BufferPoolReservation { + idx_t size {0}; + BufferPool &pool; + + BufferPoolReservation(BufferPool &pool); + BufferPoolReservation(const BufferPoolReservation &) = delete; + BufferPoolReservation &operator=(const BufferPoolReservation &) = delete; + + BufferPoolReservation(BufferPoolReservation &&) noexcept; + BufferPoolReservation &operator=(BufferPoolReservation &&) noexcept; + + ~BufferPoolReservation(); + + void Resize(idx_t new_size); + void Merge(BufferPoolReservation &&src); +}; + +struct TempBufferPoolReservation : BufferPoolReservation { + TempBufferPoolReservation(BufferPool &pool, idx_t size) : BufferPoolReservation(pool) { + Resize(size); + } + TempBufferPoolReservation(TempBufferPoolReservation &&) = default; + ~TempBufferPoolReservation() { + Resize(0); + } +}; + +class BlockHandle { + friend class BlockManager; + friend struct BufferEvictionNode; + friend class BufferHandle; + friend class BufferManager; + friend class StandardBufferManager; + friend class BufferPool; + +public: + BlockHandle(BlockManager &block_manager, block_id_t block_id); + BlockHandle(BlockManager &block_manager, block_id_t block_id, unique_ptr buffer, bool can_destroy, + idx_t block_size, BufferPoolReservation &&reservation); + ~BlockHandle(); + + BlockManager &block_manager; + +public: + block_id_t BlockId() { + return block_id; + } + + void ResizeBuffer(idx_t block_size, int64_t memory_delta) { + D_ASSERT(buffer); + // resize and adjust current memory + buffer->Resize(block_size); + memory_usage += memory_delta; + D_ASSERT(memory_usage == buffer->AllocSize()); + } + + int32_t Readers() const { + return readers; + } + + inline bool IsSwizzled() const { + return !unswizzled; + } + + inline void SetSwizzling(const char *unswizzler) { + unswizzled = unswizzler; + } + + inline void SetCanDestroy(bool can_destroy_p) { + can_destroy = can_destroy_p; + } + + inline const idx_t &GetMemoryUsage() const { + return memory_usage; + } + bool IsUnloaded() { + return state == BlockState::BLOCK_UNLOADED; + } + +private: + static BufferHandle Load(shared_ptr &handle, unique_ptr buffer = nullptr); + unique_ptr UnloadAndTakeBlock(); + void Unload(); + bool CanUnload(); + + //! The block-level lock + mutex lock; + //! Whether or not the block is loaded/unloaded + atomic state; + //! Amount of concurrent readers + atomic readers; + //! The block id of the block + const block_id_t block_id; + //! Pointer to loaded data (if any) + unique_ptr buffer; + //! Internal eviction timestamp + atomic eviction_timestamp; + //! Whether or not the buffer can be destroyed (only used for temporary buffers) + bool can_destroy; + //! The memory usage of the block (when loaded). If we are pinning/loading + //! an unloaded block, this tells us how much memory to reserve. + idx_t memory_usage; + //! Current memory reservation / usage + BufferPoolReservation memory_charge; + //! Does the block contain any memory pointers? + const char *unswizzled; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/buffer/buffer_handle.hpp b/src/duckdb/src/include/duckdb/storage/buffer/buffer_handle.hpp new file mode 100644 index 00000000..ed15fce8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/buffer/buffer_handle.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/buffer/buffer_handle.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/file_buffer.hpp" + +namespace duckdb { +class BlockHandle; +class FileBuffer; + +class BufferHandle { +public: + DUCKDB_API BufferHandle(); + DUCKDB_API BufferHandle(shared_ptr handle, FileBuffer *node); + DUCKDB_API ~BufferHandle(); + // disable copy constructors + BufferHandle(const BufferHandle &other) = delete; + BufferHandle &operator=(const BufferHandle &) = delete; + //! enable move constructors + DUCKDB_API BufferHandle(BufferHandle &&other) noexcept; + DUCKDB_API BufferHandle &operator=(BufferHandle &&) noexcept; + +public: + //! Returns whether or not the BufferHandle is valid. + DUCKDB_API bool IsValid() const; + //! Returns a pointer to the buffer data. Handle must be valid. + inline data_ptr_t Ptr() const { + D_ASSERT(IsValid()); + return node->buffer; + } + //! Returns a pointer to the buffer data. Handle must be valid. + inline data_ptr_t Ptr() { + D_ASSERT(IsValid()); + return node->buffer; + } + //! Gets the underlying file buffer. Handle must be valid. + DUCKDB_API FileBuffer &GetFileBuffer(); + //! Destroys the buffer handle + DUCKDB_API void Destroy(); + + const shared_ptr &GetBlockHandle() const { + return handle; + } + +private: + //! The block handle + shared_ptr handle; + //! The managed buffer node + FileBuffer *node; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp new file mode 100644 index 00000000..28dd45c8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/file_buffer.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" + +namespace duckdb { + +struct EvictionQueue; + +struct BufferEvictionNode { + BufferEvictionNode() { + } + BufferEvictionNode(weak_ptr handle_p, idx_t timestamp_p) + : handle(std::move(handle_p)), timestamp(timestamp_p) { + D_ASSERT(!handle.expired()); + } + + weak_ptr handle; + idx_t timestamp; + + bool CanUnload(BlockHandle &handle_p); + + shared_ptr TryGetBlockHandle(); +}; + +//! The BufferPool is in charge of handling memory management for one or more databases. It defines memory limits +//! and implements priority eviction among all users of the pool. +class BufferPool { + friend class BlockHandle; + friend class BlockManager; + friend class BufferManager; + friend class StandardBufferManager; + +public: + explicit BufferPool(idx_t maximum_memory); + virtual ~BufferPool(); + + //! Set a new memory limit to the buffer pool, throws an exception if the new limit is too low and not enough + //! blocks can be evicted + void SetLimit(idx_t limit, const char *exception_postscript); + + void IncreaseUsedMemory(idx_t size); + + idx_t GetUsedMemory(); + + idx_t GetMaxMemory(); + +protected: + //! Evict blocks until the currently used memory + extra_memory fit, returns false if this was not possible + //! (i.e. not enough blocks could be evicted) + //! If the "buffer" argument is specified AND the system can find a buffer to re-use for the given allocation size + //! "buffer" will be made to point to the re-usable memory. Note that this is not guaranteed. + //! Returns a pair. result.first indicates if eviction was successful. result.second contains the + //! reservation handle, which can be moved to the BlockHandle that will own the reservation. + struct EvictionResult { + bool success; + TempBufferPoolReservation reservation; + }; + virtual EvictionResult EvictBlocks(idx_t extra_memory, idx_t memory_limit, + unique_ptr *buffer = nullptr); + + //! Garbage collect eviction queue + void PurgeQueue(); + void AddToEvictionQueue(shared_ptr &handle); + +private: + //! The lock for changing the memory limit + mutex limit_lock; + //! The current amount of memory that is occupied by the buffer manager (in bytes) + atomic current_memory; + //! The maximum amount of memory that the buffer manager can keep (in bytes) + atomic maximum_memory; + //! Eviction queue + unique_ptr queue; + //! Total number of insertions into the eviction queue. This guides the schedule for calling PurgeQueue. + atomic queue_insertions; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/buffer/temporary_file_information.hpp b/src/duckdb/src/include/duckdb/storage/buffer/temporary_file_information.hpp new file mode 100644 index 00000000..642078aa --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/buffer/temporary_file_information.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +struct TemporaryFileInformation { + string path; + idx_t size; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp new file mode 100644 index 00000000..614d4615 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/buffer_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/storage/buffer/temporary_file_information.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +class Allocator; +class BufferPool; + +class BufferManager { + friend class BufferHandle; + friend class BlockHandle; + friend class BlockManager; + +public: + BufferManager() { + } + virtual ~BufferManager() { + } + +public: + static unique_ptr CreateStandardBufferManager(DatabaseInstance &db, DBConfig &config); + virtual BufferHandle Allocate(idx_t block_size, bool can_destroy = true, + shared_ptr *block = nullptr) = 0; + //! Reallocate an in-memory buffer that is pinned. + virtual void ReAllocate(shared_ptr &handle, idx_t block_size) = 0; + virtual BufferHandle Pin(shared_ptr &handle) = 0; + virtual void Unpin(shared_ptr &handle) = 0; + //! Returns the currently allocated memory + virtual idx_t GetUsedMemory() const = 0; + //! Returns the maximum available memory + virtual idx_t GetMaxMemory() const = 0; + virtual shared_ptr RegisterSmallMemory(idx_t block_size); + virtual DUCKDB_API Allocator &GetBufferAllocator(); + virtual DUCKDB_API void ReserveMemory(idx_t size); + virtual DUCKDB_API void FreeReservedMemory(idx_t size); + //! Set a new memory limit to the buffer manager, throws an exception if the new limit is too low and not enough + //! blocks can be evicted + virtual void SetLimit(idx_t limit = (idx_t)-1); + virtual vector GetTemporaryFiles(); + virtual const string &GetTemporaryDirectory(); + virtual void SetTemporaryDirectory(const string &new_dir); + virtual DatabaseInstance &GetDatabase(); + virtual bool HasTemporaryDirectory() const; + //! Construct a managed buffer. + virtual unique_ptr ConstructManagedBuffer(idx_t size, unique_ptr &&source, + FileBufferType type = FileBufferType::MANAGED_BUFFER); + //! Get the underlying buffer pool responsible for managing the buffers + virtual BufferPool &GetBufferPool(); + + // Static methods + DUCKDB_API static BufferManager &GetBufferManager(DatabaseInstance &db); + DUCKDB_API static BufferManager &GetBufferManager(ClientContext &context); + DUCKDB_API static BufferManager &GetBufferManager(AttachedDatabase &db); + + static idx_t GetAllocSize(idx_t block_size) { + return AlignValue(block_size + Storage::BLOCK_HEADER_SIZE); + } + +protected: + virtual void PurgeQueue() = 0; + virtual void AddToEvictionQueue(shared_ptr &handle); + virtual void WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer); + virtual unique_ptr ReadTemporaryBuffer(block_id_t id, unique_ptr buffer); + virtual void DeleteTemporaryFile(block_id_t id); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp new file mode 100644 index 00000000..85a62b33 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/checkpoint/row_group_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/checkpoint_manager.hpp" + +namespace duckdb { +struct ColumnCheckpointState; +class CheckpointWriter; +class ColumnData; +class ColumnSegment; +class RowGroup; +class BaseStatistics; +class SegmentStatistics; + +// Writes data for an entire row group. +class RowGroupWriter { +public: + RowGroupWriter(TableCatalogEntry &table, PartialBlockManager &partial_block_manager) + : table(table), partial_block_manager(partial_block_manager) { + } + virtual ~RowGroupWriter() { + } + + CompressionType GetColumnCompressionType(idx_t i); + + virtual void WriteColumnDataPointers(ColumnCheckpointState &column_checkpoint_state, Serializer &serializer) = 0; + + virtual MetadataWriter &GetPayloadWriter() = 0; + + void RegisterPartialBlock(PartialBlockAllocation &&allocation); + PartialBlockAllocation GetBlockAllocation(uint32_t segment_size); + + PartialBlockManager &GetPartialBlockManager() { + return partial_block_manager; + } + +protected: + TableCatalogEntry &table; + PartialBlockManager &partial_block_manager; +}; + +// Writes data for an entire row group. +class SingleFileRowGroupWriter : public RowGroupWriter { +public: + SingleFileRowGroupWriter(TableCatalogEntry &table, PartialBlockManager &partial_block_manager, + MetadataWriter &table_data_writer) + : RowGroupWriter(table, partial_block_manager), table_data_writer(table_data_writer) { + } + + //! MetadataWriter is a cursor on a given BlockManager. This returns the + //! cursor against which we should write payload data for the specified RowGroup. + MetadataWriter &table_data_writer; + +public: + virtual void WriteColumnDataPointers(ColumnCheckpointState &column_checkpoint_state, + Serializer &serializer) override; + + virtual MetadataWriter &GetPayloadWriter() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/string_checkpoint_state.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/string_checkpoint_state.hpp new file mode 100644 index 00000000..f9de6163 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/string_checkpoint_state.hpp @@ -0,0 +1,80 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/checkpoint/string_checkpoint_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/function/compression_function.hpp" + +namespace duckdb { +struct UncompressedStringSegmentState; + +class OverflowStringWriter { +public: + virtual ~OverflowStringWriter() { + } + + virtual void WriteString(UncompressedStringSegmentState &state, string_t string, block_id_t &result_block, + int32_t &result_offset) = 0; + virtual void Flush() = 0; +}; + +struct StringBlock { + shared_ptr block; + idx_t offset; + idx_t size; + unique_ptr next; +}; + +struct string_location_t { + string_location_t(block_id_t block_id, int32_t offset) : block_id(block_id), offset(offset) { + } + string_location_t() { + } + bool IsValid() { + return offset < Storage::BLOCK_SIZE && (block_id == INVALID_BLOCK || block_id >= MAXIMUM_BLOCK); + } + block_id_t block_id; + int32_t offset; +}; + +struct UncompressedStringSegmentState : public CompressedSegmentState { + ~UncompressedStringSegmentState() override; + + //! The string block holding strings that do not fit in the main block + //! FIXME: this should be replaced by a heap that also allows freeing of unused strings + unique_ptr head; + //! Map of block id to string block + unordered_map> overflow_blocks; + //! Overflow string writer (if any), if not set overflow strings will be written to memory blocks + unique_ptr overflow_writer; + //! The set of overflow blocks written to disk (if any) + vector on_disk_blocks; + +public: + shared_ptr GetHandle(BlockManager &manager, block_id_t block_id); + + void RegisterBlock(BlockManager &manager, block_id_t block_id); + + string GetSegmentInfo() const override { + if (on_disk_blocks.empty()) { + return ""; + } + string result = StringUtil::Join(on_disk_blocks, on_disk_blocks.size(), ", ", + [&](block_id_t block) { return to_string(block); }); + return "Overflow String Block Ids: " + result; + } + +private: + mutex block_lock; + unordered_map> handles; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_reader.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_reader.hpp new file mode 100644 index 00000000..20ca6d4d --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_reader.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/checkpoint/table_data_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/checkpoint_manager.hpp" + +namespace duckdb { +struct BoundCreateTableInfo; + +//! The table data reader is responsible for reading the data of a table from the block manager +class TableDataReader { +public: + TableDataReader(MetadataReader &reader, BoundCreateTableInfo &info); + + void ReadTableData(); + +private: + MetadataReader &reader; + BoundCreateTableInfo &info; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp new file mode 100644 index 00000000..ca049b74 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/checkpoint/table_data_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/checkpoint/row_group_writer.hpp" + +namespace duckdb { +class DuckTableEntry; +class TableStatistics; + +//! The table data writer is responsible for writing the data of a table to +//! storage. +// +//! This is meant to encapsulate and abstract: +//! - Storage/encoding of table metadata (block pointers) +//! - Mapping management of data block locations +//! Abstraction will support, for example: tiering, versioning, or splitting into multiple block managers. +class TableDataWriter { +public: + explicit TableDataWriter(TableCatalogEntry &table); + virtual ~TableDataWriter(); + +public: + void WriteTableData(Serializer &metadata_serializer); + + CompressionType GetColumnCompressionType(idx_t i); + + virtual void FinalizeTable(TableStatistics &&global_stats, DataTableInfo *info, + Serializer &metadata_serializer) = 0; + virtual unique_ptr GetRowGroupWriter(RowGroup &row_group) = 0; + + virtual void AddRowGroup(RowGroupPointer &&row_group_pointer, unique_ptr &&writer); + +protected: + DuckTableEntry &table; + // Pointers to the start of each row group. + vector row_group_pointers; +}; + +class SingleFileTableDataWriter : public TableDataWriter { +public: + SingleFileTableDataWriter(SingleFileCheckpointWriter &checkpoint_manager, TableCatalogEntry &table, + MetadataWriter &table_data_writer); + +public: + virtual void FinalizeTable(TableStatistics &&global_stats, DataTableInfo *info, + Serializer &metadata_serializer) override; + virtual unique_ptr GetRowGroupWriter(RowGroup &row_group) override; + +private: + SingleFileCheckpointWriter &checkpoint_manager; + // Writes the actual table data + MetadataWriter &table_data_writer; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp new file mode 100644 index 00000000..c1949a5a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/checkpoint/string_checkpoint_state.hpp" + +namespace duckdb { + +class WriteOverflowStringsToDisk : public OverflowStringWriter { +public: + explicit WriteOverflowStringsToDisk(BlockManager &block_manager); + ~WriteOverflowStringsToDisk() override; + + //! The block manager + BlockManager &block_manager; + + //! Temporary buffer + BufferHandle handle; + //! The block on-disk to which we are writing + block_id_t block_id; + //! The offset within the current block + idx_t offset; + + static constexpr idx_t STRING_SPACE = Storage::BLOCK_SIZE - sizeof(block_id_t); + +public: + void WriteString(UncompressedStringSegmentState &state, string_t string, block_id_t &result_block, + int32_t &result_offset) override; + void Flush() override; + +private: + void AllocateNewBlock(UncompressedStringSegmentState &state, block_id_t new_block_id); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp new file mode 100644 index 00000000..e41a7792 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp @@ -0,0 +1,121 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/checkpoint_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/partial_block_manager.hpp" +#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { +class DatabaseInstance; +class ClientContext; +class ColumnSegment; +class MetadataReader; +class SchemaCatalogEntry; +class SequenceCatalogEntry; +class TableCatalogEntry; +class ViewCatalogEntry; +class TypeCatalogEntry; + +class CheckpointWriter { +public: + explicit CheckpointWriter(AttachedDatabase &db) : db(db) { + } + virtual ~CheckpointWriter() { + } + + //! The database + AttachedDatabase &db; + + virtual MetadataManager &GetMetadataManager() = 0; + virtual MetadataWriter &GetMetadataWriter() = 0; + virtual unique_ptr GetTableDataWriter(TableCatalogEntry &table) = 0; + +protected: + virtual void WriteEntry(CatalogEntry &entry, Serializer &serializer); + virtual void WriteSchema(SchemaCatalogEntry &schema, Serializer &serializer); + virtual void WriteTable(TableCatalogEntry &table, Serializer &serializer); + virtual void WriteView(ViewCatalogEntry &table, Serializer &serializer); + virtual void WriteSequence(SequenceCatalogEntry &table, Serializer &serializer); + virtual void WriteMacro(ScalarMacroCatalogEntry &table, Serializer &serializer); + virtual void WriteTableMacro(TableMacroCatalogEntry &table, Serializer &serializer); + virtual void WriteIndex(IndexCatalogEntry &index_catalog, Serializer &serializer); + virtual void WriteType(TypeCatalogEntry &type, Serializer &serializer); +}; + +class CheckpointReader { +public: + CheckpointReader(Catalog &catalog) : catalog(catalog) { + } + virtual ~CheckpointReader() { + } + +protected: + Catalog &catalog; + +protected: + virtual void LoadCheckpoint(ClientContext &context, MetadataReader &reader); + virtual void ReadEntry(ClientContext &context, Deserializer &deserializer); + virtual void ReadSchema(ClientContext &context, Deserializer &deserializer); + virtual void ReadTable(ClientContext &context, Deserializer &deserializer); + virtual void ReadView(ClientContext &context, Deserializer &deserializer); + virtual void ReadSequence(ClientContext &context, Deserializer &deserializer); + virtual void ReadMacro(ClientContext &context, Deserializer &deserializer); + virtual void ReadTableMacro(ClientContext &context, Deserializer &deserializer); + virtual void ReadIndex(ClientContext &context, Deserializer &deserializer); + virtual void ReadType(ClientContext &context, Deserializer &deserializer); + + virtual void ReadTableData(ClientContext &context, Deserializer &deserializer, BoundCreateTableInfo &bound_info); +}; + +class SingleFileCheckpointReader final : public CheckpointReader { +public: + explicit SingleFileCheckpointReader(SingleFileStorageManager &storage) + : CheckpointReader(Catalog::GetCatalog(storage.GetAttached())), storage(storage) { + } + + void LoadFromStorage(); + MetadataManager &GetMetadataManager(); + + //! The database + SingleFileStorageManager &storage; +}; + +//! CheckpointWriter is responsible for checkpointing the database +class SingleFileRowGroupWriter; +class SingleFileTableDataWriter; + +class SingleFileCheckpointWriter final : public CheckpointWriter { + friend class SingleFileRowGroupWriter; + friend class SingleFileTableDataWriter; + +public: + SingleFileCheckpointWriter(AttachedDatabase &db, BlockManager &block_manager); + + //! Checkpoint the current state of the WAL and flush it to the main storage. This should be called BEFORE any + //! connection is available because right now the checkpointing cannot be done online. (TODO) + void CreateCheckpoint(); + + virtual MetadataWriter &GetMetadataWriter() override; + virtual MetadataManager &GetMetadataManager() override; + virtual unique_ptr GetTableDataWriter(TableCatalogEntry &table) override; + + BlockManager &GetBlockManager(); + +private: + //! The metadata writer is responsible for writing schema information + unique_ptr metadata_writer; + //! The table data writer is responsible for writing the DataPointers used by the table chunks + unique_ptr table_metadata_writer; + //! Because this is single-file storage, we can share partial blocks across + //! an entire checkpoint. + PartialBlockManager partial_block_manager; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/bitpacking.hpp b/src/duckdb/src/include/duckdb/storage/compression/bitpacking.hpp new file mode 100644 index 00000000..6b87e5bb --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/bitpacking.hpp @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/bitpacking.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +enum class BitpackingMode : uint8_t { INVALID, AUTO, CONSTANT, CONSTANT_DELTA, DELTA_FOR, FOR }; + +BitpackingMode BitpackingModeFromString(const string &str); +string BitpackingModeToString(const BitpackingMode &mode); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/bit_reader.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/bit_reader.hpp new file mode 100644 index 00000000..65220f29 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/bit_reader.hpp @@ -0,0 +1,171 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/algorithm/chimp/bit_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" + +namespace duckdb { + +//! Every byte read touches at most 2 bytes (1 if it's perfectly aligned) +//! Within a byte we need to mask off the bits that we're interested in + +struct BitReader { +private: + //! Align the masks to the right + static constexpr uint8_t MASKS[] = { + 0, // 0b00000000, + 128, // 0b10000000, + 192, // 0b11000000, + 224, // 0b11100000, + 240, // 0b11110000, + 248, // 0b11111000, + 252, // 0b11111100, + 254, // 0b11111110, + 255, // 0b11111111, + // These later masks are for the cases where index + SIZE exceeds 8 + 254, // 0b11111110, + 252, // 0b11111100, + 248, // 0b11111000, + 240, // 0b11110000, + 224, // 0b11100000, + 192, // 0b11000000, + 128, // 0b10000000, + }; + + static constexpr uint8_t REMAINDER_MASKS[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 128, // 0b10000000, + 192, // 0b11000000, + 224, // 0b11100000, + 240, // 0b11110000, + 248, // 0b11111000, + 252, // 0b11111100, + 254, // 0b11111110, + 255, // 0b11111111, + }; + +public: +public: + BitReader() : input(nullptr), index(0) { + } + uint8_t *input; + uint32_t index; + +public: + void SetStream(uint8_t *input) { + this->input = input; + index = 0; + } + + inline uint8_t BitIndex() const { + return (index & 7); + } + inline uint64_t ByteIndex() const { + return (index >> 3); + } + + inline uint8_t InnerReadByte(const uint8_t &offset) { + uint8_t result = input[ByteIndex() + offset] << BitIndex() | + ((input[ByteIndex() + offset + 1] & REMAINDER_MASKS[8 + BitIndex()]) >> (8 - BitIndex())); + return result; + } + + //! index: 4 + //! size: 7 + //! input: [12345678][12345678] + //! result: [-AAAA BBB] + //! + //! Result contains 4 bits from the first byte (making up the most significant bits) + //! And 3 bits from the second byte (the least significant bits) + inline uint8_t InnerRead(const uint8_t &size, const uint8_t &offset) { + const uint8_t right_shift = 8 - size; + const uint8_t bit_remainder = (8 - ((size + BitIndex()) - 8)) & 7; + // The least significant bits are positioned at the far right of the byte + + // Create a mask given the size and index + // Take the first byte + // Left-shift it by index, to line up the bits we're interested in with the mask + // Get the mask for the given size + // Bit-wise AND the byte and the mask together + // Right-shift this result (the most significant bits) + + // Sometimes we will need to read from the second byte + // But to make this branchless, we will perform what is basically a no-op if this condition is not true + // SPILL = (index + size >= 8) + // + // If SPILL is true: + // The REMAINDER_MASKS gives us the mask for the bits we're interested in + // We bit-wise AND these together (no need to shift anything because the index is essentially zero for this new + // byte) And we then right-shift these bits in place (to the right of the previous bits) + const bool spill_to_next_byte = (size + BitIndex() >= 8); + uint8_t result = + ((input[ByteIndex() + offset] << BitIndex()) & MASKS[size]) >> right_shift | + ((input[ByteIndex() + offset + spill_to_next_byte] & REMAINDER_MASKS[size + BitIndex()]) >> bit_remainder); + return result; + } + + template + inline T ReadBytes(const uint8_t &remainder) { + T result = 0; + if (BYTES > 0) { + result = result << 8 | InnerReadByte(0); + } + if (BYTES > 1) { + result = result << 8 | InnerReadByte(1); + } + if (BYTES > 2) { + result = result << 8 | InnerReadByte(2); + } + if (BYTES > 3) { + result = result << 8 | InnerReadByte(3); + } + if (BYTES > 4) { + result = result << 8 | InnerReadByte(4); + } + if (BYTES > 5) { + result = result << 8 | InnerReadByte(5); + } + if (BYTES > 6) { + result = result << 8 | InnerReadByte(6); + } + if (BYTES > 7) { + result = result << 8 | InnerReadByte(7); + } + result = result << remainder | InnerRead(remainder, BYTES); + index += (BYTES << 3) + remainder; + return result; + } + + template + inline T ReadBytes(const uint8_t &bytes, const uint8_t &remainder) { + T result = 0; + for (uint8_t i = 0; i < bytes; i++) { + result = result << 8 | InnerReadByte(i); + } + result = result << remainder | InnerRead(remainder, bytes); + index += (bytes << 3) + remainder; + return result; + } + + template + inline T ReadValue() { + constexpr uint8_t BYTES = (SIZE >> 3); + constexpr uint8_t REMAINDER = (SIZE & 7); + return ReadBytes(REMAINDER); + } + + template + inline T ReadValue(const uint8_t &size) { + const uint8_t bytes = size >> 3; // divide by 8; + const uint8_t remainder = size & 7; + return ReadBytes(bytes, remainder); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/bit_utils.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/bit_utils.hpp new file mode 100644 index 00000000..3f65918b --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/bit_utils.hpp @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/algorithm/bit_utils.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +template +struct BitUtils { + static constexpr R Mask(unsigned int const bits) { + return (((uint64_t)(bits < (sizeof(R) * 8))) << (bits & ((sizeof(R) * 8) - 1))) - 1U; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/byte_reader.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/byte_reader.hpp new file mode 100644 index 00000000..fae2fb03 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/byte_reader.hpp @@ -0,0 +1,126 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/algorithm/byte_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +class ByteReader { +public: + ByteReader() : buffer(nullptr), index(0) { + } + +public: + void SetStream(const uint8_t *buffer) { + this->buffer = buffer; + index = 0; + } + + size_t Index() const { + return index; + } + + template + T ReadValue() { + auto result = Load(buffer + index); + index += sizeof(T); + return result; + } + + template + T ReadValue() { + return ReadValue(SIZE); + } + + template + inline T ReadValue(uint8_t bytes, uint8_t trailing_zero) { + T result = 0; + switch (bytes) { + // LCOV_EXCL_START + case 1: + result = Load(buffer + index); + index++; + return result; + case 2: + result = Load(buffer + index); + index += 2; + return result; + case 3: + memcpy(&result, (void *)(buffer + index), 3); + index += 3; + return result; + case 4: + result = Load(buffer + index); + index += 4; + return result; + case 5: + memcpy(&result, (void *)(buffer + index), 5); + index += 5; + return result; + case 6: + memcpy(&result, (void *)(buffer + index), 6); + index += 6; + return result; + case 7: + memcpy(&result, (void *)(buffer + index), 7); + index += 7; + return result; + // LCOV_EXCL_STOP + default: + if (trailing_zero < 8) { + result = Load(buffer + index); + index += sizeof(T); + return result; + } + return result; + } + } + +private: + const uint8_t *buffer; + uint32_t index; +}; + +template <> +inline uint32_t ByteReader::ReadValue(uint8_t bytes, uint8_t trailing_zero) { + uint32_t result = 0; + switch (bytes) { + case 0: + // LCOV_EXCL_START + if (trailing_zero < 8) { + result = Load(buffer + index); + index += sizeof(uint32_t); + return result; + } + return result; + case 1: + result = Load(buffer + index); + index++; + return result; + case 2: + result = Load(buffer + index); + index += 2; + return result; + case 3: + memcpy(&result, (void *)(buffer + index), 3); + index += 3; + return result; + case 4: + result = Load(buffer + index); + index += 4; + return result; + // LCOV_EXCL_STOP + default: + throw InternalException("Write of %llu bytes attempted into address pointing to 4 byte value", bytes); + } +} +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/byte_writer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/byte_writer.hpp new file mode 100644 index 00000000..76e1e315 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/byte_writer.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/algorithm/byte_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +template +class ByteWriter { +public: + ByteWriter() : buffer(nullptr), index(0) { + } + +public: + idx_t BytesWritten() const { + return index; + } + + void Flush() { + } + + void ByteAlign() { + } + + void SetStream(uint8_t *buffer) { + this->buffer = buffer; + this->index = 0; + } + + template + void WriteValue(const T &value) { + const uint8_t bytes = (SIZE >> 3) + ((SIZE & 7) != 0); + if (!EMPTY) { + memcpy((void *)(buffer + index), &value, bytes); + } + index += bytes; + } + + template + void WriteValue(const T &value, const uint8_t &size) { + const uint8_t bytes = (size >> 3) + ((size & 7) != 0); + if (!EMPTY) { + memcpy((void *)(buffer + index), &value, bytes); + } + index += bytes; + } + +private: +private: + uint8_t *buffer; + idx_t index; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp new file mode 100644 index 00000000..50c3a123 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp @@ -0,0 +1,293 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/algorithm/chimp128.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" +#include "duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp" +#include "duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp" +#include "duckdb/storage/compression/chimp/algorithm/ring_buffer.hpp" +#include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/storage/compression/chimp/algorithm/packed_data.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/bit_utils.hpp" + +#include "duckdb/storage/compression/chimp/algorithm/bit_reader.hpp" +#include "duckdb/storage/compression/chimp/algorithm/output_bit_stream.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Compression +//===--------------------------------------------------------------------===// + +template +struct Chimp128CompressionState { + + Chimp128CompressionState() : ring_buffer(), previous_leading_zeros(NumericLimits::Maximum()) { + previous_value = 0; + } + + inline void SetLeadingZeros(int32_t value = NumericLimits::Maximum()) { + this->previous_leading_zeros = value; + } + + void Flush() { + leading_zero_buffer.Flush(); + } + + // Reset the state + void Reset() { + first = true; + ring_buffer.Reset(); + SetLeadingZeros(); + leading_zero_buffer.Reset(); + flag_buffer.Reset(); + packed_data_buffer.Reset(); + previous_value = 0; + } + + CHIMP_TYPE BitsWritten() const { + return output.BitsWritten() + leading_zero_buffer.BitsWritten() + flag_buffer.BitsWritten() + + (packed_data_buffer.index * 16); + } + + OutputBitStream output; // The stream to write to + LeadingZeroBuffer leading_zero_buffer; + FlagBuffer flag_buffer; + PackedDataBuffer packed_data_buffer; + RingBuffer ring_buffer; //! The ring buffer that holds the previous values + uint8_t previous_leading_zeros; //! The leading zeros of the reference value + CHIMP_TYPE previous_value = 0; + bool first = true; +}; + +template +class Chimp128Compression { +public: + using State = Chimp128CompressionState; + + //! The amount of bits needed to store an index between 0-127 + static constexpr uint8_t INDEX_BITS_SIZE = 7; + static constexpr uint8_t BIT_SIZE = sizeof(CHIMP_TYPE) * 8; + + static constexpr uint8_t TRAILING_ZERO_THRESHOLD = SignificantBits::size + INDEX_BITS_SIZE; + + static void Store(CHIMP_TYPE in, State &state) { + if (state.first) { + WriteFirst(in, state); + } else { + CompressValue(in, state); + } + } + + //! Write the content of the bit buffer to the stream + static void Flush(State &state) { + if (!EMPTY) { + state.output.Flush(); + } + } + + static void WriteFirst(CHIMP_TYPE in, State &state) { + state.ring_buffer.template Insert(in); + state.output.template WriteValue(in); + state.previous_value = in; + state.first = false; + } + + static void CompressValue(CHIMP_TYPE in, State &state) { + + auto key = state.ring_buffer.Key(in); + CHIMP_TYPE xor_result; + uint8_t previous_index; + uint32_t trailing_zeros = 0; + bool trailing_zeros_exceed_threshold = false; + const CHIMP_TYPE reference_index = state.ring_buffer.IndexOf(key); + + // Find the reference value to use when compressing the current value + if (((int64_t)state.ring_buffer.Size() - (int64_t)reference_index) < (int64_t)ChimpConstants::BUFFER_SIZE) { + // The reference index is within 128 values, we can use it + auto current_index = state.ring_buffer.IndexOf(key); + if (current_index > state.ring_buffer.Size()) { + current_index = 0; + } + auto reference_value = state.ring_buffer.Value(current_index % ChimpConstants::BUFFER_SIZE); + CHIMP_TYPE tempxor_result = (CHIMP_TYPE)in ^ reference_value; + trailing_zeros = CountZeros::Trailing(tempxor_result); + trailing_zeros_exceed_threshold = trailing_zeros > TRAILING_ZERO_THRESHOLD; + if (trailing_zeros_exceed_threshold) { + previous_index = current_index % ChimpConstants::BUFFER_SIZE; + xor_result = tempxor_result; + } else { + previous_index = state.ring_buffer.Size() % ChimpConstants::BUFFER_SIZE; + xor_result = (CHIMP_TYPE)in ^ state.ring_buffer.Value(previous_index); + } + } else { + // Reference index is not in range, use the directly previous value + previous_index = state.ring_buffer.Size() % ChimpConstants::BUFFER_SIZE; + xor_result = (CHIMP_TYPE)in ^ state.ring_buffer.Value(previous_index); + } + + // Compress the value + if (xor_result == 0) { + state.flag_buffer.Insert(ChimpConstants::Flags::VALUE_IDENTICAL); + state.output.template WriteValue(previous_index); + state.SetLeadingZeros(); + } else { + // Values are not identical + auto leading_zeros_raw = CountZeros::Leading(xor_result); + uint8_t leading_zeros = ChimpConstants::Compression::LEADING_ROUND[leading_zeros_raw]; + + if (trailing_zeros_exceed_threshold) { + state.flag_buffer.Insert(ChimpConstants::Flags::TRAILING_EXCEEDS_THRESHOLD); + uint32_t significant_bits = BIT_SIZE - leading_zeros - trailing_zeros; + auto result = PackedDataUtils::Pack( + reference_index, ChimpConstants::Compression::LEADING_REPRESENTATION[leading_zeros], + significant_bits); + state.packed_data_buffer.Insert(result & 0xFFFF); + state.output.template WriteValue(xor_result >> trailing_zeros, significant_bits); + state.SetLeadingZeros(); + } else if (leading_zeros == state.previous_leading_zeros) { + state.flag_buffer.Insert(ChimpConstants::Flags::LEADING_ZERO_EQUALITY); + int32_t significant_bits = BIT_SIZE - leading_zeros; + state.output.template WriteValue(xor_result, significant_bits); + } else { + state.flag_buffer.Insert(ChimpConstants::Flags::LEADING_ZERO_LOAD); + const int32_t significant_bits = BIT_SIZE - leading_zeros; + state.leading_zero_buffer.Insert(ChimpConstants::Compression::LEADING_REPRESENTATION[leading_zeros]); + state.output.template WriteValue(xor_result, significant_bits); + state.SetLeadingZeros(leading_zeros); + } + } + state.previous_value = in; + state.ring_buffer.Insert(in); + } +}; + +//===--------------------------------------------------------------------===// +// Decompression +//===--------------------------------------------------------------------===// + +template +struct Chimp128DecompressionState { +public: + Chimp128DecompressionState() : reference_value(0), first(true) { + ResetZeros(); + } + + void Reset() { + ResetZeros(); + reference_value = 0; + ring_buffer.Reset(); + first = true; + } + + inline void ResetZeros() { + leading_zeros = NumericLimits::Maximum(); + trailing_zeros = 0; + } + + inline void SetLeadingZeros(uint8_t value) { + leading_zeros = value; + } + + inline void SetTrailingZeros(uint8_t value) { + D_ASSERT(value <= sizeof(CHIMP_TYPE) * 8); + trailing_zeros = value; + } + + uint8_t LeadingZeros() const { + return leading_zeros; + } + uint8_t TrailingZeros() const { + return trailing_zeros; + } + + BitReader input; + uint8_t leading_zeros; + uint8_t trailing_zeros; + CHIMP_TYPE reference_value = 0; + RingBuffer ring_buffer; + + bool first; +}; + +template +struct Chimp128Decompression { +public: + using DecompressState = Chimp128DecompressionState; + + static constexpr uint8_t INDEX_BITS_SIZE = 7; + static constexpr uint8_t BIT_SIZE = sizeof(CHIMP_TYPE) * 8; + + static inline void UnpackPackedData(uint16_t packed_data, UnpackedData &dest) { + return PackedDataUtils::Unpack(packed_data, dest); + } + + static inline CHIMP_TYPE Load(ChimpConstants::Flags flag, uint8_t leading_zeros[], uint32_t &leading_zero_index, + UnpackedData unpacked_data[], uint32_t &unpacked_index, DecompressState &state) { + if (DUCKDB_UNLIKELY(state.first)) { + return LoadFirst(state); + } else { + return DecompressValue(flag, leading_zeros, leading_zero_index, unpacked_data, unpacked_index, state); + } + } + + static inline CHIMP_TYPE LoadFirst(DecompressState &state) { + CHIMP_TYPE result = state.input.template ReadValue(); + state.ring_buffer.template InsertScan(result); + state.first = false; + state.reference_value = result; + return result; + } + + static inline CHIMP_TYPE DecompressValue(ChimpConstants::Flags flag, uint8_t leading_zeros[], + uint32_t &leading_zero_index, UnpackedData unpacked_data[], + uint32_t &unpacked_index, DecompressState &state) { + CHIMP_TYPE result; + switch (flag) { + case ChimpConstants::Flags::VALUE_IDENTICAL: { + //! Value is identical to previous value + auto index = state.input.template ReadValue(); + result = state.ring_buffer.Value(index); + break; + } + case ChimpConstants::Flags::TRAILING_EXCEEDS_THRESHOLD: { + const UnpackedData &unpacked = unpacked_data[unpacked_index++]; + state.leading_zeros = unpacked.leading_zero; + state.trailing_zeros = BIT_SIZE - unpacked.significant_bits - state.leading_zeros; + result = state.input.template ReadValue(unpacked.significant_bits); + result <<= state.trailing_zeros; + result ^= state.ring_buffer.Value(unpacked.index); + break; + } + case ChimpConstants::Flags::LEADING_ZERO_EQUALITY: { + result = state.input.template ReadValue(BIT_SIZE - state.leading_zeros); + result ^= state.reference_value; + break; + } + case ChimpConstants::Flags::LEADING_ZERO_LOAD: { + state.leading_zeros = leading_zeros[leading_zero_index++]; + D_ASSERT(state.leading_zeros <= BIT_SIZE); + result = state.input.template ReadValue(BIT_SIZE - state.leading_zeros); + result ^= state.reference_value; + break; + } + default: + throw InternalException("Chimp compression flag with value %d not recognized", flag); + } + state.reference_value = result; + state.ring_buffer.InsertScan(result); + return result; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp new file mode 100644 index 00000000..9318c484 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" + +namespace duckdb { + +template +struct SignificantBits {}; + +template <> +struct SignificantBits { + static constexpr uint8_t size = 6; + static constexpr uint8_t mask = ((uint8_t)1 << size) - 1; +}; + +template <> +struct SignificantBits { + static constexpr uint8_t size = 5; + static constexpr uint8_t mask = ((uint8_t)1 << size) - 1; +}; + +struct ChimpConstants { + struct Compression { + static constexpr uint8_t LEADING_ROUND[] = {0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 12, 12, 12, 12, + 16, 16, 18, 18, 20, 20, 22, 22, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24}; + static constexpr uint8_t LEADING_REPRESENTATION[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7}; + }; + struct Decompression { + static constexpr uint8_t LEADING_REPRESENTATION[] = {0, 8, 12, 16, 18, 20, 22, 24}; + }; + static constexpr uint8_t BUFFER_SIZE = 128; + enum class Flags : uint8_t { + VALUE_IDENTICAL = 0, + TRAILING_EXCEEDS_THRESHOLD = 1, + LEADING_ZERO_EQUALITY = 2, + LEADING_ZERO_LOAD = 3 + }; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp new file mode 100644 index 00000000..5336b334 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp @@ -0,0 +1,108 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/flag_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" +#ifdef DEBUG +#include "duckdb/common/vector.hpp" +#include "duckdb/common/assert.hpp" +#endif + +namespace duckdb { + +struct FlagBufferConstants { + static constexpr uint8_t MASKS[4] = { + 192, // 0b1100 0000, + 48, // 0b0011 0000, + 12, // 0b0000 1100, + 3, // 0b0000 0011, + }; + + static constexpr uint8_t SHIFTS[4] = {6, 4, 2, 0}; +}; + +// This class is responsible for writing and reading the flag bits +// Only the last group is potentially not 1024 (GROUP_SIZE) values in size +// But we can determine from the count of the segment whether this is the case or not +// So we can just read/write from left to right +template +class FlagBuffer { + +public: + FlagBuffer() : counter(0), buffer(nullptr) { + } + +public: + void SetBuffer(uint8_t *buffer) { + this->buffer = buffer; + this->counter = 0; + } + void Reset() { + this->counter = 0; +#ifdef DEBUG + this->flags.clear(); +#endif + } + +#ifdef DEBUG + uint8_t ExtractValue(uint32_t value, uint8_t index) { + return (value & FlagBufferConstants::MASKS[index]) >> FlagBufferConstants::SHIFTS[index]; + } +#endif + + uint64_t BitsWritten() const { + return counter * 2; + } + + void Insert(ChimpConstants::Flags value) { + if (!EMPTY) { + if ((counter & 3) == 0) { + // Start the new byte fresh + buffer[counter >> 2] = 0; +#ifdef DEBUG + flags.clear(); +#endif + } +#ifdef DEBUG + flags.push_back((uint8_t)value); +#endif + buffer[counter >> 2] |= (((uint8_t)value & 3) << FlagBufferConstants::SHIFTS[counter & 3]); +#ifdef DEBUG + // Verify that the bits are serialized correctly + D_ASSERT(flags[counter & 3] == ExtractValue(buffer[counter >> 2], counter & 3)); +#endif + } + counter++; + } + inline uint8_t Extract() { + const uint8_t result = (buffer[counter >> 2] & FlagBufferConstants::MASKS[counter & 3]) >> + FlagBufferConstants::SHIFTS[counter & 3]; + counter++; + return result; + } + + uint32_t BytesUsed() const { + return (counter >> 2) + ((counter & 3) != 0); + } + + uint32_t FlagCount() const { + return counter; + } + +private: +private: + uint32_t counter = 0; + uint8_t *buffer; +#ifdef DEBUG + vector flags; +#endif +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp new file mode 100644 index 00000000..436de19a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp @@ -0,0 +1,165 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/leading_zero_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/common/helper.hpp" +#ifdef DEBUG +#include "duckdb/common/vector.hpp" +#include "duckdb/common/assert.hpp" +#endif + +namespace duckdb { + +//! This class is in charge of storing the leading_zero_bits, which are of a fixed size +//! These are packed together so that the rest of the data can be byte-aligned +//! The leading zero bit data is read from left to right + +struct LeadingZeroBufferConstants { + static constexpr uint32_t MASKS[8] = { + 7, // 0b 00000000 00000000 00000000 00000111, + 56, // 0b 00000000 00000000 00000000 00111000, + 448, // 0b 00000000 00000000 00000001 11000000, + 3584, // 0b 00000000 00000000 00001110 00000000, + 28672, // 0b 00000000 00000000 01110000 00000000, + 229376, // 0b 00000000 00000011 10000000 00000000, + 1835008, // 0b 00000000 00011100 00000000 00000000, + 14680064, // 0b 00000000 11100000 00000000 00000000, + }; + + // We're not using the last byte (the most significant) of the 4 bytes we're accessing + static constexpr uint8_t SHIFTS[8] = {0, 3, 6, 9, 12, 15, 18, 21}; +}; + +template +class LeadingZeroBuffer { + +public: + static constexpr uint32_t CHIMP_GROUP_SIZE = 1024; + static constexpr uint32_t LEADING_ZERO_BITS_SIZE = 3; + static constexpr uint32_t LEADING_ZERO_BLOCK_SIZE = 8; + static constexpr uint32_t LEADING_ZERO_BLOCK_BIT_SIZE = LEADING_ZERO_BLOCK_SIZE * LEADING_ZERO_BITS_SIZE; + static constexpr uint32_t MAX_LEADING_ZERO_BLOCKS = CHIMP_GROUP_SIZE / LEADING_ZERO_BLOCK_SIZE; + static constexpr uint32_t MAX_BITS_USED_BY_ZERO_BLOCKS = MAX_LEADING_ZERO_BLOCKS * LEADING_ZERO_BLOCK_BIT_SIZE; + static constexpr uint32_t MAX_BYTES_USED_BY_ZERO_BLOCKS = MAX_BITS_USED_BY_ZERO_BLOCKS / 8; + + // Add an extra byte to prevent heap buffer overflow on the last group, because we'll be addressing 4 bytes each + static constexpr uint32_t BUFFER_SIZE = + MAX_BYTES_USED_BY_ZERO_BLOCKS + (sizeof(uint32_t) - (LEADING_ZERO_BLOCK_BIT_SIZE / 8)); + + template + const T Load(const uint8_t *ptr) { + T ret; + memcpy(&ret, ptr, sizeof(ret)); + return ret; + } + +public: + LeadingZeroBuffer() : current(0), counter(0), buffer(nullptr) { + } + void SetBuffer(uint8_t *buffer) { + // Set the internal buffer, when inserting this should be BUFFER_SIZE bytes in length + // This buffer does not need to be zero-initialized for inserting + this->buffer = buffer; + this->counter = 0; + } + void Flush() { + if ((counter & 7) != 0) { + FlushBuffer(); + } + } + + uint64_t BitsWritten() const { + return counter * 3; + } + + // Reset the counter, but don't replace the buffer + void Reset() { + this->counter = 0; + current = 0; +#ifdef DEBUG + flags.clear(); +#endif + } + +public: +#ifdef DEBUG + uint8_t ExtractValue(uint32_t value, uint8_t index) { + return (value & LeadingZeroBufferConstants::MASKS[index]) >> LeadingZeroBufferConstants::SHIFTS[index]; + } +#endif + + inline uint64_t BlockIndex() const { + return ((counter >> 3) * (LEADING_ZERO_BLOCK_BIT_SIZE / 8)); + } + + void FlushBuffer() { + if (EMPTY) { + return; + } + const auto buffer_idx = BlockIndex(); + memcpy((void *)(buffer + buffer_idx), (uint8_t *)¤t, 3); +#ifdef DEBUG + // Verify that the bits are copied correctly + + uint32_t temp_value = 0; + memcpy((uint8_t *)&temp_value, (void *)(buffer + buffer_idx), 3); + for (idx_t i = 0; i < flags.size(); i++) { + D_ASSERT(flags[i] == ExtractValue(temp_value, i)); + } + flags.clear(); +#endif + } + + void Insert(const uint8_t &value) { + if (!EMPTY) { +#ifdef DEBUG + flags.push_back(value); +#endif + current |= (value & 7) << LeadingZeroBufferConstants::SHIFTS[counter & 7]; +#ifdef DEBUG + // Verify that the bits are serialized correctly + D_ASSERT(flags[counter & 7] == ExtractValue(current, counter & 7)); +#endif + + if ((counter & (LEADING_ZERO_BLOCK_SIZE - 1)) == 7) { + FlushBuffer(); + current = 0; + } + } + counter++; + } + + inline uint8_t Extract() { + const auto buffer_idx = BlockIndex(); + auto const temp = Load(buffer + buffer_idx); + + const uint8_t result = + (temp & LeadingZeroBufferConstants::MASKS[counter & 7]) >> LeadingZeroBufferConstants::SHIFTS[counter & 7]; + counter++; + return result; + } + idx_t GetCount() const { + return counter; + } + idx_t BlockCount() const { + return (counter >> 3) + ((counter & 7) != 0); + } + +private: +private: + uint32_t current; + uint32_t counter = 0; // block_index * 8 + uint8_t *buffer; +#ifdef DEBUG + vector flags; +#endif +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/output_bit_stream.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/output_bit_stream.hpp new file mode 100644 index 00000000..3364e322 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/output_bit_stream.hpp @@ -0,0 +1,216 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/output_bit_stream.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/common/assert.hpp" + +#include "duckdb/storage/compression/chimp/algorithm/bit_utils.hpp" + +namespace duckdb { + +// This class writes arbitrary amounts of bits to a stream +// The way these bits are written is most-significant bit first +// For example if 6 bits are given as: 0b0011 1111 +// The bits are written to the stream as: 0b1111 1100 +template +class OutputBitStream { + using INTERNAL_TYPE = uint8_t; + +public: + friend class BitStreamWriter; + friend class EmptyWriter; + OutputBitStream() + : stream(nullptr), current(0), free_bits(INTERNAL_TYPE_BITSIZE), stream_index(0), bits_written(0) { + } + +public: + static constexpr uint8_t INTERNAL_TYPE_BITSIZE = sizeof(INTERNAL_TYPE) * 8; + + idx_t BytesWritten() const { + return (bits_written >> 3) + ((bits_written & 7) != 0); + } + + idx_t BitsWritten() const { + return bits_written; + } + + void Flush() { + if (free_bits == INTERNAL_TYPE_BITSIZE) { + // the bit buffer is empty, nothing to write + return; + } + WriteToStream(); + } + + void SetStream(uint8_t *output_stream) { + stream = output_stream; + stream_index = 0; + bits_written = 0; + free_bits = INTERNAL_TYPE_BITSIZE; + current = 0; + } + + uint64_t *Stream() { + return (uint64_t *)stream; + } + + idx_t BitSize() const { + return (stream_index * INTERNAL_TYPE_BITSIZE) + (INTERNAL_TYPE_BITSIZE - free_bits); + } + + template + void WriteRemainder(T value, uint8_t i) { + if (sizeof(T) * 8 > 32) { + if (i == 64) { + WriteToStream(((uint64_t)value >> 56) & 0xFF); + } + if (i > 55) { + WriteToStream(((uint64_t)value >> 48) & 0xFF); + } + if (i > 47) { + WriteToStream(((uint64_t)value >> 40) & 0xFF); + } + if (i > 39) { + WriteToStream(((uint64_t)value >> 32) & 0xFF); + } + } + if (i > 31) { + WriteToStream((value >> 24) & 0xFF); + } + if (i > 23) { + WriteToStream((value >> 16) & 0xFF); + } + if (i > 15) { + WriteToStream((value >> 8) & 0xFF); + } + if (i > 7) { + WriteToStream(value); + } + } + + template + void WriteValue(T value) { + bits_written += VALUE_SIZE; + if (EMPTY) { + return; + } + if (FitsInCurrent(VALUE_SIZE)) { + //! If we can write the entire value in one go + WriteInCurrent((INTERNAL_TYPE)value); + return; + } + auto i = VALUE_SIZE - free_bits; + const uint8_t queue = i & 7; + + if (free_bits != 0) { + // Reset the number of free bits + WriteInCurrent(value >> i, free_bits); + } + if (queue != 0) { + // We dont fill the entire 'current' buffer, + // so we can write these to 'current' first without flushing to the stream + // And then write the remaining bytes directly to the stream + i -= queue; + WriteInCurrent((INTERNAL_TYPE)value, queue); + value >>= queue; + } + WriteRemainder(value, i); + } + + template + void WriteValue(T value, const uint8_t &value_size) { + bits_written += value_size; + if (EMPTY) { + return; + } + if (FitsInCurrent(value_size)) { + //! If we can write the entire value in one go + WriteInCurrent((INTERNAL_TYPE)value, value_size); + return; + } + auto i = value_size - free_bits; + const uint8_t queue = i & 7; + + if (free_bits != 0) { + // Reset the number of free bits + WriteInCurrent(value >> i, free_bits); + } + if (queue != 0) { + // We dont fill the entire 'current' buffer, + // so we can write these to 'current' first without flushing to the stream + // And then write the remaining bytes directly to the stream + i -= queue; + WriteInCurrent((INTERNAL_TYPE)value, queue); + value >>= queue; + } + WriteRemainder(value, i); + } + +private: + void WriteBit(bool value) { + auto &byte = GetCurrentByte(); + if (value) { + byte = byte | GetMask(); + } + DecreaseFreeBits(); + } + + bool FitsInCurrent(uint8_t bits) { + return free_bits >= bits; + } + INTERNAL_TYPE GetMask() const { + return (INTERNAL_TYPE)1 << free_bits; + } + + INTERNAL_TYPE &GetCurrentByte() { + return current; + } + //! Write a value of type INTERNAL_TYPE directly to the stream + void WriteToStream(INTERNAL_TYPE value) { + stream[stream_index++] = value; + } + void WriteToStream() { + stream[stream_index++] = current; + current = 0; + free_bits = INTERNAL_TYPE_BITSIZE; + } + void DecreaseFreeBits(uint8_t value = 1) { + D_ASSERT(free_bits >= value); + free_bits -= value; + if (free_bits == 0) { + WriteToStream(); + } + } + void WriteInCurrent(INTERNAL_TYPE value, uint8_t value_size) { + D_ASSERT(INTERNAL_TYPE_BITSIZE >= value_size); + const auto shift_amount = free_bits - value_size; + current |= (value & BitUtils::Mask(value_size)) << shift_amount; + DecreaseFreeBits(value_size); + } + + template + void WriteInCurrent(INTERNAL_TYPE value) { + D_ASSERT(INTERNAL_TYPE_BITSIZE >= VALUE_SIZE); + const auto shift_amount = free_bits - VALUE_SIZE; + current |= (value & BitUtils::Mask(VALUE_SIZE)) << shift_amount; + DecreaseFreeBits(VALUE_SIZE); + } + +private: + uint8_t *stream; //! The stream we're writing our output to + + INTERNAL_TYPE current; //! The current value we're writing into (zero-initialized) + uint8_t free_bits; //! How many bits are still unwritten in 'current' + idx_t stream_index; //! Index used to keep track of which index we're at in the stream + + idx_t bits_written; //! The total amount of bits written to this stream +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/packed_data.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/packed_data.hpp new file mode 100644 index 00000000..932719de --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/packed_data.hpp @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/packed_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" +#include "duckdb.h" + +namespace duckdb { + +struct UnpackedData { + uint8_t leading_zero; + uint8_t significant_bits; + uint8_t index; +}; + +template +struct PackedDataUtils { +private: + static constexpr uint8_t INDEX_BITS_SIZE = 7; + static constexpr uint8_t LEADING_BITS_SIZE = 3; + + static constexpr uint8_t INDEX_MASK = ((uint8_t)1 << INDEX_BITS_SIZE) - 1; + static constexpr uint8_t LEADING_MASK = ((uint8_t)1 << LEADING_BITS_SIZE) - 1; + + static constexpr uint8_t INDEX_SHIFT_AMOUNT = (sizeof(uint16_t) * 8) - INDEX_BITS_SIZE; + static constexpr uint8_t LEADING_SHIFT_AMOUNT = INDEX_SHIFT_AMOUNT - LEADING_BITS_SIZE; + +public: + //|----------------| //! packed_data(16) bits + // IIIIIII //! Index (7 bits, shifted by 9) + // LLL //! LeadingZeros (3 bits, shifted by 6) + // SSSSSS //! SignificantBits (6 bits) + static inline void Unpack(uint16_t packed_data, UnpackedData &dest) { + dest.index = packed_data >> INDEX_SHIFT_AMOUNT & INDEX_MASK; + dest.leading_zero = packed_data >> LEADING_SHIFT_AMOUNT & LEADING_MASK; + dest.significant_bits = packed_data & SignificantBits::mask; + // Verify that combined, this is not bigger than the full size of the type + D_ASSERT(dest.significant_bits + dest.leading_zero <= (sizeof(CHIMP_TYPE) * 8)); + } + + static inline uint16_t Pack(uint8_t index, uint8_t leading_zero, uint8_t significant_bits) { + static constexpr uint8_t BIT_SIZE = (sizeof(CHIMP_TYPE) * 8); + + uint16_t result = 0; + result += ((uint32_t)BIT_SIZE << 3) * (ChimpConstants::BUFFER_SIZE + index); + result += BIT_SIZE * (leading_zero & 7); + if (BIT_SIZE == 32) { + // Shift the result by 1 to occupy the 16th bit + result <<= 1; + } + result += (significant_bits & 63); + + return result; + } +}; + +template +struct PackedDataBuffer { +public: + PackedDataBuffer() : index(0), buffer(nullptr) { + } + +public: + void SetBuffer(uint16_t *buffer) { + this->buffer = buffer; + this->index = 0; + } + + void Reset() { + this->index = 0; + } + + inline void Insert(uint16_t packed_data) { + if (!EMPTY) { + buffer[index] = packed_data; + } + index++; + } + + idx_t index; + uint16_t *buffer; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/ring_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/ring_buffer.hpp new file mode 100644 index 00000000..04980e92 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/ring_buffer.hpp @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/ring_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" + +namespace duckdb { + +template +class RingBuffer { +public: + static constexpr uint8_t RING_SIZE = ChimpConstants::BUFFER_SIZE; + static constexpr uint64_t LEAST_SIGNIFICANT_BIT_COUNT = SignificantBits::size + 7 + 1; + static constexpr uint64_t LEAST_SIGNIFICANT_BIT_MASK = (1 << LEAST_SIGNIFICANT_BIT_COUNT) - 1; + static constexpr uint16_t INDICES_SIZE = 1 << LEAST_SIGNIFICANT_BIT_COUNT; // 16384 + +public: + void Reset() { + index = 0; + } + + RingBuffer() : index(0) { + } + template + void Insert(uint64_t value) { + if (!FIRST) { + index++; + } + buffer[index % RING_SIZE] = value; + indices[Key(value)] = index; + } + template + void InsertScan(uint64_t value) { + if (!FIRST) { + index++; + } + buffer[index % RING_SIZE] = value; + } + inline const uint64_t &Top() const { + return buffer[index % RING_SIZE]; + } + //! Get the index where values that produce this 'key' are stored + inline const uint64_t &IndexOf(const uint64_t &key) const { + return indices[key]; + } + //! Get the value at position 'index' of the buffer + inline const uint64_t &Value(const uint8_t &index_p) const { + return buffer[index_p]; + } + //! Get the amount of values that are inserted + inline const uint64_t &Size() const { + return index; + } + inline uint64_t Key(const uint64_t &value) const { + return value & LEAST_SIGNIFICANT_BIT_MASK; + } + +private: + uint64_t buffer[RING_SIZE] = {}; //! Stores the corresponding values + uint64_t index = 0; //! Keeps track of the index of the current value + uint64_t indices[INDICES_SIZE] = {}; //! Stores the corresponding indices +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp.hpp new file mode 100644 index 00000000..1d5e14bb --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp.hpp @@ -0,0 +1,77 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/chimp.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/algorithm/chimp128.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/function/compression_function.hpp" + +namespace duckdb { + +using byte_index_t = uint32_t; + +template +struct ChimpType {}; + +template <> +struct ChimpType { + typedef uint64_t type; +}; + +template <> +struct ChimpType { + typedef uint32_t type; +}; + +class ChimpPrimitives { +public: + static constexpr uint32_t CHIMP_SEQUENCE_SIZE = 1024; + static constexpr uint8_t MAX_BYTES_PER_VALUE = sizeof(double) + 1; // extra wiggle room + static constexpr uint8_t HEADER_SIZE = sizeof(uint32_t); + static constexpr uint8_t FLAG_BIT_SIZE = 2; + static constexpr uint32_t LEADING_ZERO_BLOCK_BUFFERSIZE = 1 + (CHIMP_SEQUENCE_SIZE / 8) * 3; +}; + +//! Where all the magic happens +template +struct ChimpState { +public: + using CHIMP_TYPE = typename ChimpType::type; + + ChimpState() : chimp() { + } + Chimp128CompressionState chimp; + +public: + void AssignDataBuffer(uint8_t *data_out) { + chimp.output.SetStream(data_out); + } + + void AssignFlagBuffer(uint8_t *flag_out) { + chimp.flag_buffer.SetBuffer(flag_out); + } + + void AssignPackedDataBuffer(uint16_t *packed_data_out) { + chimp.packed_data_buffer.SetBuffer(packed_data_out); + } + + void AssignLeadingZeroBuffer(uint8_t *leading_zero_out) { + chimp.leading_zero_buffer.SetBuffer(leading_zero_out); + } + + void Flush() { + chimp.output.Flush(); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_analyze.hpp new file mode 100644 index 00000000..1c2fc330 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_analyze.hpp @@ -0,0 +1,135 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/chimp_analyze.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/chimp.hpp" +#include "duckdb/function/compression_function.hpp" + +namespace duckdb { + +struct EmptyChimpWriter; + +template +struct ChimpAnalyzeState : public AnalyzeState { +public: + using CHIMP_TYPE = typename ChimpType::type; + + ChimpAnalyzeState() : state() { + state.AssignDataBuffer(nullptr); + } + ChimpState state; + idx_t group_idx = 0; + idx_t data_byte_size = 0; + idx_t metadata_byte_size = 0; + +public: + void WriteValue(CHIMP_TYPE value, bool is_valid) { + if (!is_valid) { + return; + } + //! Keep track of when a segment would end, to accurately simulate Reset()s in compress step + if (!HasEnoughSpace()) { + StartNewSegment(); + } + Chimp128Compression::Store(value, state.chimp); + group_idx++; + if (group_idx == ChimpPrimitives::CHIMP_SEQUENCE_SIZE) { + StartNewGroup(); + } + } + + void StartNewSegment() { + state.Flush(); + StartNewGroup(); + data_byte_size += UsedSpace(); + metadata_byte_size += ChimpPrimitives::HEADER_SIZE; + state.chimp.output.SetStream(nullptr); + } + + idx_t CurrentGroupMetadataSize() const { + idx_t metadata_size = 0; + + metadata_size += 3 * state.chimp.leading_zero_buffer.BlockCount(); + metadata_size += state.chimp.flag_buffer.BytesUsed(); + metadata_size += 2 * state.chimp.packed_data_buffer.index; + return metadata_size; + } + + idx_t RequiredSpace() const { + idx_t required_space = ChimpPrimitives::MAX_BYTES_PER_VALUE; + // Any value could be the last, + // so the cost of flushing metadata should be factored into the cost + // byte offset of data + required_space += sizeof(byte_index_t); + // amount of leading zero blocks + required_space += sizeof(uint8_t); + // first leading zero block + required_space += 3; + // amount of flag bytes + required_space += sizeof(uint8_t); + // first flag byte + required_space += 1; + return required_space; + } + + void StartNewGroup() { + metadata_byte_size += CurrentGroupMetadataSize(); + group_idx = 0; + state.chimp.Reset(); + } + + idx_t UsedSpace() const { + return state.chimp.output.BytesWritten(); + } + + bool HasEnoughSpace() { + idx_t total_bytes_used = 0; + total_bytes_used += AlignValue(ChimpPrimitives::HEADER_SIZE + UsedSpace() + RequiredSpace()); + total_bytes_used += CurrentGroupMetadataSize(); + total_bytes_used += metadata_byte_size; + return total_bytes_used <= Storage::BLOCK_SIZE; + } + + idx_t TotalUsedBytes() const { + return metadata_byte_size + AlignValue(data_byte_size + UsedSpace()); + } +}; + +template +unique_ptr ChimpInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq>(); +} + +template +bool ChimpAnalyze(AnalyzeState &state, Vector &input, idx_t count) { + using CHIMP_TYPE = typename ChimpType::type; + auto &analyze_state = (ChimpAnalyzeState &)state; + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + analyze_state.WriteValue(data[idx], vdata.validity.RowIsValid(idx)); + } + return true; +} + +template +idx_t ChimpFinalAnalyze(AnalyzeState &state) { + auto &chimp = (ChimpAnalyzeState &)state; + // Finish the last "segment" + chimp.StartNewSegment(); + // Multiply the final size to factor in the extra cost of decompression time + const auto multiplier = 2.0; + const auto final_analyze_size = chimp.TotalUsedBytes(); + return final_analyze_size * multiplier; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_compress.hpp new file mode 100644 index 00000000..a514505b --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_compress.hpp @@ -0,0 +1,281 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/chimp_compress.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/chimp.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/compression/chimp/chimp_analyze.hpp" + +#include "duckdb/common/helper.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" + +#include + +namespace duckdb { + +template +struct ChimpCompressionState : public CompressionState { +public: + using CHIMP_TYPE = typename ChimpType::type; + + explicit ChimpCompressionState(ColumnDataCheckpointer &checkpointer, ChimpAnalyzeState *analyze_state) + : checkpointer(checkpointer), + function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_CHIMP)) { + CreateEmptySegment(checkpointer.GetRowGroup().start); + + // These buffers are recycled for every group, so they only have to be set once + state.AssignLeadingZeroBuffer((uint8_t *)leading_zero_blocks); + state.AssignFlagBuffer((uint8_t *)flags); + state.AssignPackedDataBuffer((uint16_t *)packed_data_blocks); + } + + ColumnDataCheckpointer &checkpointer; + CompressionFunction &function; + unique_ptr current_segment; + BufferHandle handle; + idx_t group_idx = 0; + uint8_t flags[ChimpPrimitives::CHIMP_SEQUENCE_SIZE / 4]; + uint8_t leading_zero_blocks[ChimpPrimitives::LEADING_ZERO_BLOCK_BUFFERSIZE]; + uint16_t packed_data_blocks[ChimpPrimitives::CHIMP_SEQUENCE_SIZE]; + + // Ptr to next free spot in segment; + data_ptr_t segment_data; + data_ptr_t metadata_ptr; + uint32_t next_group_byte_index_start = ChimpPrimitives::HEADER_SIZE; + // The total size of metadata in the current segment + idx_t metadata_byte_size = 0; + + ChimpState state; + +public: + idx_t RequiredSpace() const { + idx_t required_space = ChimpPrimitives::MAX_BYTES_PER_VALUE; + // Any value could be the last, + // so the cost of flushing metadata should be factored into the cost + + // byte offset of data + required_space += sizeof(byte_index_t); + // amount of leading zero blocks + required_space += sizeof(uint8_t); + // first leading zero block + required_space += 3; + // amount of flag bytes + required_space += sizeof(uint8_t); + // first flag byte + required_space += 1; + return required_space; + } + + // How many bytes the data occupies for the current segment + idx_t UsedSpace() const { + return state.chimp.output.BytesWritten(); + } + + idx_t RemainingSpace() const { + return metadata_ptr - (handle.Ptr() + UsedSpace()); + } + + idx_t CurrentGroupMetadataSize() const { + idx_t metadata_size = 0; + + metadata_size += 3 * state.chimp.leading_zero_buffer.BlockCount(); + metadata_size += state.chimp.flag_buffer.BytesUsed(); + metadata_size += 2 * state.chimp.packed_data_buffer.index; + return metadata_size; + } + + // The current segment has enough space to fit this new value + bool HasEnoughSpace() { + if (handle.Ptr() + AlignValue(ChimpPrimitives::HEADER_SIZE + UsedSpace() + RequiredSpace()) >= + (metadata_ptr - CurrentGroupMetadataSize())) { + return false; + } + return true; + } + + void CreateEmptySegment(idx_t row_start) { + group_idx = 0; + metadata_byte_size = 0; + auto &db = checkpointer.GetDatabase(); + auto &type = checkpointer.GetType(); + auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); + compressed_segment->function = function; + current_segment = std::move(compressed_segment); + next_group_byte_index_start = ChimpPrimitives::HEADER_SIZE; + + auto &buffer_manager = BufferManager::GetBufferManager(db); + handle = buffer_manager.Pin(current_segment->block); + + segment_data = handle.Ptr() + current_segment->GetBlockOffset() + ChimpPrimitives::HEADER_SIZE; + metadata_ptr = handle.Ptr() + current_segment->GetBlockOffset() + Storage::BLOCK_SIZE; + state.AssignDataBuffer(segment_data); + state.chimp.Reset(); + } + + void Append(UnifiedVectorFormat &vdata, idx_t count) { + auto data = UnifiedVectorFormat::GetData(vdata); + + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + WriteValue(data[idx], vdata.validity.RowIsValid(idx)); + } + } + + void WriteValue(CHIMP_TYPE value, bool is_valid) { + if (!HasEnoughSpace()) { + // Segment is full + auto row_start = current_segment->start + current_segment->count; + FlushSegment(); + CreateEmptySegment(row_start); + } + current_segment->count++; + + if (is_valid) { + T floating_point_value = Load(const_data_ptr_cast(&value)); + NumericStats::Update(current_segment->stats.statistics, floating_point_value); + } else { + //! FIXME: find a cheaper alternative to storing a NULL + // store this as "value_identical", only using 9 bits for a NULL + value = state.chimp.previous_value; + } + + Chimp128Compression::Store(value, state.chimp); + group_idx++; + if (group_idx == ChimpPrimitives::CHIMP_SEQUENCE_SIZE) { + FlushGroup(); + } + } + + void FlushGroup() { + // Has to be called first to flush the last values in the LeadingZeroBuffer + state.chimp.Flush(); + + metadata_ptr -= sizeof(byte_index_t); + metadata_byte_size += sizeof(byte_index_t); + // Store where this groups data starts, relative to the start of the segment + Store(next_group_byte_index_start, metadata_ptr); + next_group_byte_index_start = UsedSpace(); + + const uint8_t leading_zero_block_count = state.chimp.leading_zero_buffer.BlockCount(); + // Every 8 values are packed in one block + D_ASSERT(leading_zero_block_count <= ChimpPrimitives::CHIMP_SEQUENCE_SIZE / 8); + metadata_ptr -= sizeof(uint8_t); + metadata_byte_size += sizeof(uint8_t); + // Store how many leading zero blocks there are + Store(leading_zero_block_count, metadata_ptr); + + const uint64_t bytes_used_by_leading_zero_blocks = 3 * leading_zero_block_count; + metadata_ptr -= bytes_used_by_leading_zero_blocks; + metadata_byte_size += bytes_used_by_leading_zero_blocks; + // Store the leading zeros (8 per 3 bytes) for this group + memcpy((void *)metadata_ptr, (void *)leading_zero_blocks, bytes_used_by_leading_zero_blocks); + + //! This is max 1024, because it's the amount of flags there are, not the amount of bytes that takes up + const uint16_t flag_bytes = state.chimp.flag_buffer.BytesUsed(); +#ifdef DEBUG + const idx_t padding = (current_segment->count % ChimpPrimitives::CHIMP_SEQUENCE_SIZE) == 0 + ? ChimpPrimitives::CHIMP_SEQUENCE_SIZE + : 0; + const idx_t size_of_group = padding + current_segment->count % ChimpPrimitives::CHIMP_SEQUENCE_SIZE; + D_ASSERT((AlignValue(size_of_group - 1) / 4) == flag_bytes); +#endif + + metadata_ptr -= flag_bytes; + metadata_byte_size += flag_bytes; + // Store the flags (4 per byte) for this group + memcpy((void *)metadata_ptr, (void *)flags, flag_bytes); + + // Store the packed data blocks (2 bytes each) + // We dont need to store an extra count for this, + // as the count can be derived from unpacking the flags and counting the '1' flags + + // FIXME: this does stop us from skipping groups with point queries, + // because the metadata has a variable size, and we have to extract all flags + iterate them to know this size + const uint16_t packed_data_blocks_count = state.chimp.packed_data_buffer.index; + metadata_ptr -= packed_data_blocks_count * 2; + metadata_byte_size += packed_data_blocks_count * 2; + if ((uint64_t)metadata_ptr & 1) { + // Align on a two-byte boundary + metadata_ptr--; + metadata_byte_size++; + } + memcpy((void *)metadata_ptr, (void *)packed_data_blocks, packed_data_blocks_count * sizeof(uint16_t)); + + state.chimp.Reset(); + group_idx = 0; + } + + // FIXME: only do this if the wasted space meets a certain threshold (>= 20%) + void FlushSegment() { + if (group_idx) { + // Only call this when the group actually has data that needs to be flushed + FlushGroup(); + } + state.chimp.output.Flush(); + auto &checkpoint_state = checkpointer.GetCheckpointState(); + auto dataptr = handle.Ptr(); + + // Compact the segment by moving the metadata next to the data. + idx_t bytes_used_by_data = ChimpPrimitives::HEADER_SIZE + UsedSpace(); + idx_t metadata_offset = AlignValue(bytes_used_by_data); + // Verify that the metadata_ptr does not cross this threshold + D_ASSERT(dataptr + metadata_offset <= metadata_ptr); + idx_t metadata_size = dataptr + Storage::BLOCK_SIZE - metadata_ptr; + idx_t total_segment_size = metadata_offset + metadata_size; +#ifdef DEBUG + uint32_t verify_bytes; + memcpy((void *)&verify_bytes, metadata_ptr, 4); +#endif + memmove(dataptr + metadata_offset, metadata_ptr, metadata_size); +#ifdef DEBUG + D_ASSERT(verify_bytes == *(uint32_t *)(dataptr + metadata_offset)); +#endif + // Store the offset of the metadata of the first group (which is at the highest address). + Store(metadata_offset + metadata_size, dataptr); + handle.Destroy(); + checkpoint_state.FlushSegment(std::move(current_segment), total_segment_size); + } + + void Finalize() { + FlushSegment(); + current_segment.reset(); + } +}; + +// Compression Functions + +template +unique_ptr ChimpInitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr state) { + return make_uniq>(checkpointer, (ChimpAnalyzeState *)state.get()); +} + +template +void ChimpCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { + auto &state = (ChimpCompressionState &)state_p; + UnifiedVectorFormat vdata; + scan_vector.ToUnifiedFormat(count, vdata); + state.Append(vdata, count); +} + +template +void ChimpFinalizeCompress(CompressionState &state_p) { + auto &state = (ChimpCompressionState &)state_p; + state.Finalize(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_fetch.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_fetch.hpp new file mode 100644 index 00000000..52b19c9e --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_fetch.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/chimp_fetch.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/chimp.hpp" +#include "duckdb/storage/compression/chimp/chimp_scan.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" + +namespace duckdb { + +template +void ChimpFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { + using INTERNAL_TYPE = typename ChimpType::type; + + ChimpScanState scan_state(segment); + scan_state.Skip(segment, row_id); + auto result_data = FlatVector::GetData(result); + + if (scan_state.GroupFinished() && scan_state.total_value_count < scan_state.segment_count) { + scan_state.LoadGroup(scan_state.group_state.values); + } + scan_state.group_state.Scan(&result_data[result_idx], 1); + + scan_state.total_value_count++; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp new file mode 100644 index 00000000..559b5832 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp @@ -0,0 +1,292 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/chimp/chimp_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/chimp.hpp" +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" + +#include "duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp" +#include "duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +template +struct ChimpGroupState { +public: + void Init(uint8_t *data) { + chimp_state.input.SetStream(data); + Reset(); + } + + void Reset() { + chimp_state.Reset(); + index = 0; + } + + bool Started() const { + return !!index; + } + + // Assuming the group is completely full + idx_t RemainingInGroup() const { + return ChimpPrimitives::CHIMP_SEQUENCE_SIZE - index; + } + + void Scan(CHIMP_TYPE *dest, idx_t count) { + memcpy(dest, (void *)(values + index), count * sizeof(CHIMP_TYPE)); + index += count; + } + + void LoadFlags(uint8_t *packed_data, idx_t group_size) { + FlagBuffer flag_buffer; + flag_buffer.SetBuffer(packed_data); + flags[0] = ChimpConstants::Flags::VALUE_IDENTICAL; // First value doesn't require a flag + for (idx_t i = 0; i < group_size; i++) { + flags[1 + i] = (ChimpConstants::Flags)flag_buffer.Extract(); + } + max_flags_to_read = group_size; + index = 0; + } + + void LoadLeadingZeros(uint8_t *packed_data, idx_t leading_zero_block_size) { +#ifdef DEBUG + idx_t flag_one_count = 0; + for (idx_t i = 0; i < max_flags_to_read; i++) { + flag_one_count += flags[1 + i] == ChimpConstants::Flags::LEADING_ZERO_LOAD; + } + // There are 8 leading zero values packed in one block, the block could be partially filled + flag_one_count = AlignValue(flag_one_count); + D_ASSERT(flag_one_count == leading_zero_block_size); +#endif + LeadingZeroBuffer leading_zero_buffer; + leading_zero_buffer.SetBuffer(packed_data); + for (idx_t i = 0; i < leading_zero_block_size; i++) { + leading_zeros[i] = ChimpConstants::Decompression::LEADING_REPRESENTATION[leading_zero_buffer.Extract()]; + } + max_leading_zeros_to_read = leading_zero_block_size; + leading_zero_index = 0; + } + + idx_t CalculatePackedDataCount() const { + idx_t count = 0; + for (idx_t i = 0; i < max_flags_to_read; i++) { + count += flags[1 + i] == ChimpConstants::Flags::TRAILING_EXCEEDS_THRESHOLD; + } + return count; + } + + void LoadPackedData(uint16_t *packed_data, idx_t packed_data_block_count) { + for (idx_t i = 0; i < packed_data_block_count; i++) { + PackedDataUtils::Unpack(packed_data[i], unpacked_data_blocks[i]); + if (unpacked_data_blocks[i].significant_bits == 0) { + unpacked_data_blocks[i].significant_bits = 64; + } + unpacked_data_blocks[i].leading_zero = + ChimpConstants::Decompression::LEADING_REPRESENTATION[unpacked_data_blocks[i].leading_zero]; + } + unpacked_index = 0; + max_packed_data_to_read = packed_data_block_count; + } + + void LoadValues(CHIMP_TYPE *result, idx_t count) { + for (idx_t i = 0; i < count; i++) { + result[i] = Chimp128Decompression::Load(flags[i], leading_zeros, leading_zero_index, + unpacked_data_blocks, unpacked_index, chimp_state); + } + } + +public: + uint32_t leading_zero_index; + uint32_t unpacked_index; + + ChimpConstants::Flags flags[ChimpPrimitives::CHIMP_SEQUENCE_SIZE + 1]; + uint8_t leading_zeros[ChimpPrimitives::CHIMP_SEQUENCE_SIZE + 1]; + UnpackedData unpacked_data_blocks[ChimpPrimitives::CHIMP_SEQUENCE_SIZE]; + + CHIMP_TYPE values[ChimpPrimitives::CHIMP_SEQUENCE_SIZE]; + +private: + idx_t index; + idx_t max_leading_zeros_to_read; + idx_t max_flags_to_read; + idx_t max_packed_data_to_read; + Chimp128DecompressionState chimp_state; +}; + +template +struct ChimpScanState : public SegmentScanState { +public: + using CHIMP_TYPE = typename ChimpType::type; + + explicit ChimpScanState(ColumnSegment &segment) : segment(segment), segment_count(segment.count) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + + handle = buffer_manager.Pin(segment.block); + auto dataptr = handle.Ptr(); + // ScanStates never exceed the boundaries of a Segment, + // but are not guaranteed to start at the beginning of the Block + auto start_of_data_segment = dataptr + segment.GetBlockOffset() + ChimpPrimitives::HEADER_SIZE; + group_state.Init(start_of_data_segment); + auto metadata_offset = Load(dataptr + segment.GetBlockOffset()); + metadata_ptr = dataptr + segment.GetBlockOffset() + metadata_offset; + } + + BufferHandle handle; + data_ptr_t metadata_ptr; + idx_t total_value_count = 0; + ChimpGroupState group_state; + + ColumnSegment &segment; + idx_t segment_count; + + idx_t LeftInGroup() const { + return ChimpPrimitives::CHIMP_SEQUENCE_SIZE - (total_value_count % ChimpPrimitives::CHIMP_SEQUENCE_SIZE); + } + + bool GroupFinished() const { + return (total_value_count % ChimpPrimitives::CHIMP_SEQUENCE_SIZE) == 0; + } + + template + void ScanGroup(CHIMP_TYPE *values, idx_t group_size) { + D_ASSERT(group_size <= ChimpPrimitives::CHIMP_SEQUENCE_SIZE); + D_ASSERT(group_size <= LeftInGroup()); + + if (GroupFinished() && total_value_count < segment_count) { + if (group_size == ChimpPrimitives::CHIMP_SEQUENCE_SIZE) { + LoadGroup(values); + total_value_count += group_size; + return; + } else { + LoadGroup(group_state.values); + } + } + group_state.Scan(values, group_size); + total_value_count += group_size; + } + + void LoadGroup(CHIMP_TYPE *value_buffer) { + + //! FIXME: If we change the order of this to flag -> leading_zero_blocks -> packed_data + //! We can leave out the leading zero block count as well, because it can be derived from + //! Extracting all the flags and counting the 3's + + // Load the offset indicating where a groups data starts + metadata_ptr -= sizeof(uint32_t); + auto data_byte_offset = Load(metadata_ptr); + D_ASSERT(data_byte_offset < Storage::BLOCK_SIZE); + // Only used for point queries + (void)data_byte_offset; + + // Load how many blocks of leading zero bits we have + metadata_ptr -= sizeof(uint8_t); + auto leading_zero_block_count = Load(metadata_ptr); + D_ASSERT(leading_zero_block_count <= ChimpPrimitives::CHIMP_SEQUENCE_SIZE / 8); + + // Load the leading zero block count + metadata_ptr -= 3 * leading_zero_block_count; + const auto leading_zero_block_ptr = metadata_ptr; + + // Figure out how many flags there are + D_ASSERT(segment_count >= total_value_count); + auto group_size = MinValue(segment_count - total_value_count, ChimpPrimitives::CHIMP_SEQUENCE_SIZE); + // Reduce by one, because the first value of a group does not have a flag + auto flag_count = group_size - 1; + uint16_t flag_byte_count = (AlignValue(flag_count) / 4); + + // Load the flags + metadata_ptr -= flag_byte_count; + auto flags = metadata_ptr; + group_state.LoadFlags(flags, flag_count); + + // Load the leading zero blocks + group_state.LoadLeadingZeros(leading_zero_block_ptr, (uint32_t)leading_zero_block_count * 8); + + // Load packed data blocks + auto packed_data_block_count = group_state.CalculatePackedDataCount(); + metadata_ptr -= packed_data_block_count * 2; + if ((uint64_t)metadata_ptr & 1) { + // Align on a two-byte boundary + metadata_ptr--; + } + group_state.LoadPackedData((uint16_t *)metadata_ptr, packed_data_block_count); + + group_state.Reset(); + + // Load all values for the group + group_state.LoadValues(value_buffer, group_size); + } + +public: + //! Skip the next 'skip_count' values, we don't store the values + // TODO: use the metadata to determine if we can skip a group + void Skip(ColumnSegment &segment, idx_t skip_count) { + using INTERNAL_TYPE = typename ChimpType::type; + INTERNAL_TYPE buffer[ChimpPrimitives::CHIMP_SEQUENCE_SIZE]; + + while (skip_count) { + auto skip_size = MinValue(skip_count, LeftInGroup()); + ScanGroup(buffer, skip_size); + skip_count -= skip_size; + } + } +}; + +template +unique_ptr ChimpInitScan(ColumnSegment &segment) { + auto result = make_uniq_base>(segment); + return result; +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +template +void ChimpScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + using INTERNAL_TYPE = typename ChimpType::type; + auto &scan_state = (ChimpScanState &)*state.scan_state; + + T *result_data = FlatVector::GetData(result); + result.SetVectorType(VectorType::FLAT_VECTOR); + + auto current_result_ptr = (INTERNAL_TYPE *)(result_data + result_offset); + + idx_t scanned = 0; + while (scanned < scan_count) { + idx_t to_scan = MinValue(scan_count - scanned, scan_state.LeftInGroup()); + scan_state.template ScanGroup(current_result_ptr + scanned, to_scan); + scanned += to_scan; + } +} + +template +void ChimpSkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { + auto &scan_state = (ChimpScanState &)*state.scan_state; + scan_state.Skip(segment, skip_count); +} + +template +void ChimpScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + ChimpScanPartial(segment, state, scan_count, result, 0); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/algorithm/patas.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/algorithm/patas.hpp new file mode 100644 index 00000000..6dcddcb2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/algorithm/patas.hpp @@ -0,0 +1,129 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/patas/algorithm/patas.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/algorithm/byte_writer.hpp" +#include "duckdb/storage/compression/chimp/algorithm/ring_buffer.hpp" +#include "duckdb/storage/compression/chimp/algorithm/byte_reader.hpp" +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" +#include "duckdb/storage/compression/chimp/algorithm/packed_data.hpp" +#include "duckdb/storage/compression/patas/shared.hpp" +#include "duckdb/common/bit_utils.hpp" + +namespace duckdb { + +namespace patas { + +template +class PatasCompressionState { +public: + PatasCompressionState() : index(0), first(true) { + } + +public: + void Reset() { + index = 0; + first = true; + ring_buffer.Reset(); + packed_data_buffer.Reset(); + } + void SetOutputBuffer(uint8_t *output) { + byte_writer.SetStream(output); + Reset(); + } + idx_t Index() const { + return index; + } + +public: + void UpdateMetadata(uint8_t trailing_zero, uint8_t byte_count, uint8_t index_diff) { + if (!EMPTY) { + packed_data_buffer.Insert(PackedDataUtils::Pack(index_diff, byte_count, trailing_zero)); + } + index++; + } + +public: + ByteWriter byte_writer; + PackedDataBuffer packed_data_buffer; + idx_t index; + RingBuffer ring_buffer; + bool first; +}; + +template +struct PatasCompression { + using State = PatasCompressionState; + static constexpr uint8_t EXACT_TYPE_BITSIZE = sizeof(EXACT_TYPE) * 8; + + static void Store(EXACT_TYPE value, State &state) { + if (state.first) { + StoreFirst(value, state); + } else { + StoreCompressed(value, state); + } + } + + static void StoreFirst(EXACT_TYPE value, State &state) { + // write first value, uncompressed + state.ring_buffer.template Insert(value); + state.byte_writer.template WriteValue(value); + state.first = false; + state.UpdateMetadata(0, sizeof(EXACT_TYPE), 0); + } + + static void StoreCompressed(EXACT_TYPE value, State &state) { + auto key = state.ring_buffer.Key(value); + uint64_t reference_index = state.ring_buffer.IndexOf(key); + + // Find the reference value to use when compressing the current value + const bool exceeds_highest_index = reference_index > state.ring_buffer.Size(); + const bool difference_too_big = + ((state.ring_buffer.Size() + 1) - reference_index) >= ChimpConstants::BUFFER_SIZE; + if (exceeds_highest_index || difference_too_big) { + // Reference index is not in range, use the directly previous value + reference_index = state.ring_buffer.Size(); + } + const auto reference_value = state.ring_buffer.Value(reference_index % ChimpConstants::BUFFER_SIZE); + + // XOR with previous value + EXACT_TYPE xor_result = value ^ reference_value; + + // Figure out the trailing zeros (max 6 bits) + const uint8_t trailing_zero = CountZeros::Trailing(xor_result); + const uint8_t leading_zero = CountZeros::Leading(xor_result); + + const bool is_equal = xor_result == 0; + + // Figure out the significant bytes (max 3 bits) + const uint8_t significant_bits = !is_equal * (EXACT_TYPE_BITSIZE - trailing_zero - leading_zero); + const uint8_t significant_bytes = (significant_bits >> 3) + ((significant_bits & 7) != 0); + + // Avoid an invalid shift error when xor_result is 0 + state.byte_writer.template WriteValue(xor_result >> (trailing_zero - is_equal), significant_bits); + + state.ring_buffer.Insert(value); + const uint8_t index_difference = state.ring_buffer.Size() - reference_index; + state.UpdateMetadata(trailing_zero - is_equal, significant_bytes, index_difference); + } +}; + +// Decompression + +template +struct PatasDecompression { + static inline EXACT_TYPE DecompressValue(ByteReader &byte_reader, uint8_t byte_count, uint8_t trailing_zero, + EXACT_TYPE previous) { + return (byte_reader.ReadValue(byte_count, trailing_zero) << trailing_zero) ^ previous; + } +}; + +} // namespace patas + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas.hpp new file mode 100644 index 00000000..a668b4e3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/patas/patas.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/patas/algorithm/patas.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/function/compression_function.hpp" + +namespace duckdb { + +using byte_index_t = uint32_t; + +//! FIXME: replace ChimpType with this +template +struct FloatingToExact {}; + +template <> +struct FloatingToExact { + typedef uint64_t type; +}; + +template <> +struct FloatingToExact { + typedef uint32_t type; +}; + +template +struct PatasState { +public: + using EXACT_TYPE = typename FloatingToExact::type; + + PatasState(void *state_p = nullptr) : data_ptr(state_p), patas_state() { + } + //! The Compress/Analyze State + void *data_ptr; + patas::PatasCompressionState patas_state; + +public: + void AssignDataBuffer(uint8_t *data_out) { + patas_state.SetOutputBuffer(data_out); + } + + template + bool Update(T uncompressed_value, bool is_valid) { + OP::template Operation(uncompressed_value, is_valid, data_ptr); + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_analyze.hpp new file mode 100644 index 00000000..6a6bca99 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_analyze.hpp @@ -0,0 +1,139 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/patas/patas_analyze.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/patas/patas.hpp" +#include "duckdb/function/compression_function.hpp" + +namespace duckdb { + +struct EmptyPatasWriter; + +template +struct PatasAnalyzeState : public AnalyzeState { +public: + using EXACT_TYPE = typename FloatingToExact::type; + + PatasAnalyzeState() : state((void *)this) { + state.AssignDataBuffer(nullptr); + } + PatasState state; + idx_t group_idx = 0; + idx_t data_byte_size = 0; + idx_t metadata_byte_size = 0; + //! To optimally store NULL, we keep track of the directly previous value + EXACT_TYPE previous_value; + +public: + void WriteValue(EXACT_TYPE value, bool is_valid) { + if (!is_valid) { + value = previous_value; + } + //! Keep track of when a segment would end, to accurately simulate Reset()s in compress step + if (!HasEnoughSpace()) { + StartNewSegment(); + } + patas::PatasCompression::Store(value, state.patas_state); + previous_value = value; + group_idx++; + if (group_idx == PatasPrimitives::PATAS_GROUP_SIZE) { + StartNewGroup(); + } + } + + idx_t CurrentGroupMetadataSize() const { + idx_t metadata_size = 0; + + // Offset to the data of the group + metadata_size += sizeof(uint32_t); + // Packed Trailing zeros + significant bytes + index_offsets for group + metadata_size += 2 * group_idx; + return metadata_size; + } + + void StartNewSegment() { + StartNewGroup(); + data_byte_size += UsedSpace(); + metadata_byte_size += PatasPrimitives::HEADER_SIZE; + state.patas_state.byte_writer.SetStream(nullptr); + } + + idx_t RequiredSpace() const { + idx_t required_space = 0; + required_space += sizeof(EXACT_TYPE); + required_space += sizeof(uint16_t); + return required_space; + } + + void StartNewGroup() { + previous_value = 0; + metadata_byte_size += CurrentGroupMetadataSize(); + group_idx = 0; + state.patas_state.Reset(); + } + + idx_t UsedSpace() const { + return state.patas_state.byte_writer.BytesWritten(); + } + + bool HasEnoughSpace() { + idx_t total_bytes_used = 0; + total_bytes_used += AlignValue(PatasPrimitives::HEADER_SIZE + UsedSpace() + RequiredSpace()); + total_bytes_used += CurrentGroupMetadataSize(); + total_bytes_used += metadata_byte_size; + return total_bytes_used <= Storage::BLOCK_SIZE; + } + + idx_t TotalUsedBytes() const { + return metadata_byte_size + AlignValue(data_byte_size + UsedSpace()); + } +}; + +struct EmptyPatasWriter { + + template + static void Operation(VALUE_TYPE uncompressed_value, bool is_valid, void *state_p) { + using EXACT_TYPE = typename FloatingToExact::type; + + auto state_wrapper = (PatasAnalyzeState *)state_p; + state_wrapper->WriteValue(Load(const_data_ptr_cast(&uncompressed_value)), is_valid); + } +}; + +template +unique_ptr PatasInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq>(); +} + +template +bool PatasAnalyze(AnalyzeState &state, Vector &input, idx_t count) { + auto &analyze_state = (PatasAnalyzeState &)state; + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + analyze_state.state.template Update(data[idx], vdata.validity.RowIsValid(idx)); + } + return true; +} + +template +idx_t PatasFinalAnalyze(AnalyzeState &state) { + auto &patas_state = (PatasAnalyzeState &)state; + // Finish the last "segment" + patas_state.StartNewSegment(); + const auto final_analyze_size = patas_state.TotalUsedBytes(); + // Multiply the final size to factor in the extra cost of decompression time + const auto multiplier = 1.2; + return final_analyze_size * multiplier; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_compress.hpp new file mode 100644 index 00000000..28579dbb --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_compress.hpp @@ -0,0 +1,233 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/patas/patas_compress.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/bitpacking.hpp" +#include "duckdb/storage/compression/patas/patas.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/compression/patas/patas_analyze.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" + +#include + +namespace duckdb { + +// State + +template +struct PatasCompressionState : public CompressionState { +public: + using EXACT_TYPE = typename FloatingToExact::type; + + struct PatasWriter { + + template + static void Operation(VALUE_TYPE value, bool is_valid, void *state_p) { + //! Need access to the CompressionState to be able to flush the segment + auto state_wrapper = (PatasCompressionState *)state_p; + + if (!state_wrapper->HasEnoughSpace()) { + // Segment is full + auto row_start = state_wrapper->current_segment->start + state_wrapper->current_segment->count; + state_wrapper->FlushSegment(); + state_wrapper->CreateEmptySegment(row_start); + } + + if (is_valid) { + NumericStats::Update(state_wrapper->current_segment->stats.statistics, value); + } + + state_wrapper->WriteValue(Load(const_data_ptr_cast(&value))); + } + }; + + explicit PatasCompressionState(ColumnDataCheckpointer &checkpointer, PatasAnalyzeState *analyze_state) + : checkpointer(checkpointer), + function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_PATAS)) { + CreateEmptySegment(checkpointer.GetRowGroup().start); + + state.data_ptr = (void *)this; + state.patas_state.packed_data_buffer.SetBuffer(packed_data); + state.patas_state.Reset(); + } + + ColumnDataCheckpointer &checkpointer; + CompressionFunction &function; + unique_ptr current_segment; + BufferHandle handle; + idx_t group_idx = 0; + uint16_t packed_data[PatasPrimitives::PATAS_GROUP_SIZE]; + + // Ptr to next free spot in segment; + data_ptr_t segment_data; + data_ptr_t metadata_ptr; + uint32_t next_group_byte_index_start = PatasPrimitives::HEADER_SIZE; + // The total size of metadata in the current segment + idx_t metadata_byte_size = 0; + + PatasState state; + +public: + idx_t RequiredSpace() const { + idx_t required_space = sizeof(EXACT_TYPE); + // byte offset of data + required_space += sizeof(byte_index_t); + // byte size of the packed_data_block + required_space += sizeof(uint16_t); + return required_space; + } + + // How many bytes the data occupies for the current segment + idx_t UsedSpace() const { + return state.patas_state.byte_writer.BytesWritten(); + } + + idx_t RemainingSpace() const { + return metadata_ptr - (handle.Ptr() + UsedSpace()); + } + + idx_t CurrentGroupMetadataSize() const { + idx_t metadata_size = 0; + + metadata_size += sizeof(byte_index_t); + metadata_size += sizeof(uint16_t) * group_idx; + return metadata_size; + } + + // The current segment has enough space to fit this new value + bool HasEnoughSpace() { + if (handle.Ptr() + AlignValue(PatasPrimitives::HEADER_SIZE + UsedSpace() + RequiredSpace()) >= + (metadata_ptr - CurrentGroupMetadataSize())) { + return false; + } + return true; + } + + void CreateEmptySegment(idx_t row_start) { + next_group_byte_index_start = PatasPrimitives::HEADER_SIZE; + group_idx = 0; + metadata_byte_size = 0; + auto &db = checkpointer.GetDatabase(); + auto &type = checkpointer.GetType(); + auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); + compressed_segment->function = function; + current_segment = std::move(compressed_segment); + + auto &buffer_manager = BufferManager::GetBufferManager(db); + handle = buffer_manager.Pin(current_segment->block); + + segment_data = handle.Ptr() + PatasPrimitives::HEADER_SIZE; + metadata_ptr = handle.Ptr() + Storage::BLOCK_SIZE; + state.AssignDataBuffer(segment_data); + state.patas_state.Reset(); + } + + void Append(UnifiedVectorFormat &vdata, idx_t count) { + auto data = UnifiedVectorFormat::GetData(vdata); + + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + state.template Update(data[idx], vdata.validity.RowIsValid(idx)); + } + } + + void WriteValue(EXACT_TYPE value) { + current_segment->count++; + patas::PatasCompression::Store(value, state.patas_state); + group_idx++; + if (group_idx == PatasPrimitives::PATAS_GROUP_SIZE) { + FlushGroup(); + } + } + + void FlushGroup() { + metadata_ptr -= sizeof(byte_index_t); + metadata_byte_size += sizeof(byte_index_t); + // Store where this groups data starts, relative to the start of the segment + Store(next_group_byte_index_start, metadata_ptr); + next_group_byte_index_start = PatasPrimitives::HEADER_SIZE + UsedSpace(); + + // Store the packed data blocks (7 + 6 + 3 bits) + metadata_ptr -= group_idx * sizeof(uint16_t); + metadata_byte_size += group_idx * sizeof(uint16_t); + memcpy(metadata_ptr, packed_data, sizeof(uint16_t) * group_idx); + + state.patas_state.Reset(); + group_idx = 0; + } + + //! FIXME: only compact if the unused space meets a certain threshold (20%) + void FlushSegment() { + if (group_idx != 0) { + FlushGroup(); + } + auto &checkpoint_state = checkpointer.GetCheckpointState(); + auto dataptr = handle.Ptr(); + + // Compact the segment by moving the metadata next to the data. + idx_t bytes_used_by_data = PatasPrimitives::HEADER_SIZE + UsedSpace(); + idx_t metadata_offset = AlignValue(bytes_used_by_data); + // Verify that the metadata_ptr does not cross this threshold + D_ASSERT(dataptr + metadata_offset <= metadata_ptr); + idx_t metadata_size = dataptr + Storage::BLOCK_SIZE - metadata_ptr; + idx_t total_segment_size = metadata_offset + metadata_size; +#ifdef DEBUG + //! Copy the first 4 bytes of the metadata + uint32_t verify_bytes; + std::memcpy((void *)&verify_bytes, metadata_ptr, 4); +#endif + memmove(dataptr + metadata_offset, metadata_ptr, metadata_size); +#ifdef DEBUG + //! Now assert that the memmove was correct + D_ASSERT(verify_bytes == *(uint32_t *)(dataptr + metadata_offset)); +#endif + // Store the offset to the metadata + Store(metadata_offset + metadata_size, dataptr); + handle.Destroy(); + checkpoint_state.FlushSegment(std::move(current_segment), total_segment_size); + } + + void Finalize() { + FlushSegment(); + current_segment.reset(); + } +}; + +// Compression Functions + +template +unique_ptr PatasInitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr state) { + return make_uniq>(checkpointer, (PatasAnalyzeState *)state.get()); +} + +template +void PatasCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { + auto &state = (PatasCompressionState &)state_p; + UnifiedVectorFormat vdata; + scan_vector.ToUnifiedFormat(count, vdata); + state.Append(vdata, count); +} + +template +void PatasFinalizeCompress(CompressionState &state_p) { + auto &state = (PatasCompressionState &)state_p; + state.Finalize(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_fetch.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_fetch.hpp new file mode 100644 index 00000000..9a6640c2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_fetch.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/patas/patas_fetch.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/patas/patas.hpp" +#include "duckdb/storage/compression/patas/patas_scan.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" + +namespace duckdb { + +template +void PatasFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { + using EXACT_TYPE = typename FloatingToExact::type; + + PatasScanState scan_state(segment); + scan_state.Skip(segment, row_id); + auto result_data = FlatVector::GetData(result); + result_data[result_idx] = (EXACT_TYPE)0; + + if (scan_state.GroupFinished() && scan_state.total_value_count < scan_state.count) { + scan_state.LoadGroup(scan_state.group_state.values); + } + scan_state.group_state.Scan((uint8_t *)(result_data + result_idx), 1); + scan_state.total_value_count++; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp new file mode 100644 index 00000000..edce133a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp @@ -0,0 +1,242 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/storage/compression/chimp/chimp_scan.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/compression/chimp/chimp.hpp" +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +//! Do not change order of these variables +struct PatasUnpackedValueStats { + uint8_t significant_bytes; + uint8_t trailing_zeros; + uint8_t index_diff; +}; + +template +struct PatasGroupState { +public: + void Init(uint8_t *data) { + byte_reader.SetStream(data); + } + + idx_t BytesRead() const { + return byte_reader.Index(); + } + + void Reset() { + index = 0; + } + + void LoadPackedData(uint16_t *packed_data, idx_t count) { + for (idx_t i = 0; i < count; i++) { + auto &unpacked = unpacked_data[i]; + PackedDataUtils::Unpack(packed_data[i], (UnpackedData &)unpacked); + } + } + + template + void Scan(uint8_t *dest, idx_t count) { + if (!SKIP) { + memcpy(dest, (void *)(values + index), sizeof(EXACT_TYPE) * count); + } + index += count; + } + + template + void LoadValues(EXACT_TYPE *value_buffer, idx_t count) { + if (SKIP) { + return; + } + value_buffer[0] = (EXACT_TYPE)0; + for (idx_t i = 0; i < count; i++) { + value_buffer[i] = patas::PatasDecompression::DecompressValue( + byte_reader, unpacked_data[i].significant_bytes, unpacked_data[i].trailing_zeros, + value_buffer[i - unpacked_data[i].index_diff]); + } + } + +public: + idx_t index; + PatasUnpackedValueStats unpacked_data[PatasPrimitives::PATAS_GROUP_SIZE]; + EXACT_TYPE values[PatasPrimitives::PATAS_GROUP_SIZE]; + +private: + ByteReader byte_reader; +}; + +template +struct PatasScanState : public SegmentScanState { +public: + using EXACT_TYPE = typename FloatingToExact::type; + + explicit PatasScanState(ColumnSegment &segment) : segment(segment), count(segment.count) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + + handle = buffer_manager.Pin(segment.block); + // ScanStates never exceed the boundaries of a Segment, + // but are not guaranteed to start at the beginning of the Block + segment_data = handle.Ptr() + segment.GetBlockOffset(); + auto metadata_offset = Load(segment_data); + metadata_ptr = segment_data + metadata_offset; + } + + BufferHandle handle; + data_ptr_t metadata_ptr; + data_ptr_t segment_data; + idx_t total_value_count = 0; + PatasGroupState group_state; + + ColumnSegment &segment; + idx_t count; + + idx_t LeftInGroup() const { + return PatasPrimitives::PATAS_GROUP_SIZE - (total_value_count % PatasPrimitives::PATAS_GROUP_SIZE); + } + + inline bool GroupFinished() const { + return (total_value_count % PatasPrimitives::PATAS_GROUP_SIZE) == 0; + } + + // Scan up to a group boundary + template + void ScanGroup(EXACT_TYPE *values, idx_t group_size) { + D_ASSERT(group_size <= PatasPrimitives::PATAS_GROUP_SIZE); + D_ASSERT(group_size <= LeftInGroup()); + + if (GroupFinished() && total_value_count < count) { + if (group_size == PatasPrimitives::PATAS_GROUP_SIZE) { + LoadGroup(values); + total_value_count += group_size; + return; + } else { + // Even if SKIP is given, group size is not big enough to be able to fully skip the entire group + LoadGroup(group_state.values); + } + } + group_state.template Scan((uint8_t *)values, group_size); + + total_value_count += group_size; + } + + // Using the metadata, we can avoid loading any of the data if we don't care about the group at all + void SkipGroup() { + // Skip the offset indicating where the data starts + metadata_ptr -= sizeof(uint32_t); + idx_t group_size = MinValue((idx_t)PatasPrimitives::PATAS_GROUP_SIZE, count - total_value_count); + // Skip the blocks of packed data + metadata_ptr -= sizeof(uint16_t) * group_size; + + total_value_count += group_size; + } + + template + void LoadGroup(EXACT_TYPE *value_buffer) { + group_state.Reset(); + + // Load the offset indicating where a groups data starts + metadata_ptr -= sizeof(uint32_t); + auto data_byte_offset = Load(metadata_ptr); + D_ASSERT(data_byte_offset < Storage::BLOCK_SIZE); + + // Initialize the byte_reader with the data values for the group + group_state.Init(segment_data + data_byte_offset); + + idx_t group_size = MinValue((idx_t)PatasPrimitives::PATAS_GROUP_SIZE, (count - total_value_count)); + + // Read the compacted blocks of (7 + 6 + 3 bits) value stats + metadata_ptr -= sizeof(uint16_t) * group_size; + group_state.LoadPackedData((uint16_t *)metadata_ptr, group_size); + + // Read all the values to the specified 'value_buffer' + group_state.template LoadValues(value_buffer, group_size); + } + +public: + //! Skip the next 'skip_count' values, we don't store the values + void Skip(ColumnSegment &segment, idx_t skip_count) { + using EXACT_TYPE = typename FloatingToExact::type; + + if (total_value_count != 0 && !GroupFinished()) { + // Finish skipping the current group + idx_t to_skip = LeftInGroup(); + skip_count -= to_skip; + ScanGroup(nullptr, to_skip); + } + // Figure out how many entire groups we can skip + // For these groups, we don't even need to process the metadata or values + idx_t groups_to_skip = skip_count / PatasPrimitives::PATAS_GROUP_SIZE; + for (idx_t i = 0; i < groups_to_skip; i++) { + SkipGroup(); + } + skip_count -= PatasPrimitives::PATAS_GROUP_SIZE * groups_to_skip; + if (skip_count == 0) { + return; + } + // For the last group that this skip (partially) touches, we do need to + // load the metadata and values into the group_state + ScanGroup(nullptr, skip_count); + } +}; + +template +unique_ptr PatasInitScan(ColumnSegment &segment) { + auto result = make_uniq_base>(segment); + return result; +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +template +void PatasScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + using EXACT_TYPE = typename FloatingToExact::type; + auto &scan_state = (PatasScanState &)*state.scan_state; + + // Get the pointer to the result values + auto current_result_ptr = FlatVector::GetData(result); + result.SetVectorType(VectorType::FLAT_VECTOR); + current_result_ptr += result_offset; + + idx_t scanned = 0; + while (scanned < scan_count) { + const auto remaining = scan_count - scanned; + const idx_t to_scan = MinValue(remaining, scan_state.LeftInGroup()); + + scan_state.template ScanGroup(current_result_ptr + scanned, to_scan); + scanned += to_scan; + } +} + +template +void PatasSkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { + auto &scan_state = (PatasScanState &)*state.scan_state; + scan_state.Skip(segment, skip_count); +} + +template +void PatasScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + PatasScanPartial(segment, state, scan_count, result, 0); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/shared.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/shared.hpp new file mode 100644 index 00000000..a26143ae --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/shared.hpp @@ -0,0 +1,13 @@ +#pragma once + +namespace duckdb { + +class PatasPrimitives { +public: + static constexpr uint32_t PATAS_GROUP_SIZE = 1024; + static constexpr uint8_t HEADER_SIZE = sizeof(uint32_t); + static constexpr uint8_t BYTECOUNT_BITSIZE = 3; + static constexpr uint8_t INDEX_BITSIZE = 7; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp new file mode 100644 index 00000000..c9f0bf38 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/data_pointer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/storage/block.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/common/enums/compression_type.hpp" + +namespace duckdb { + +class Serializer; +class Deserializer; + +struct ColumnSegmentState { + virtual ~ColumnSegmentState() { + } + + virtual void Serialize(Serializer &serializer) const = 0; + static unique_ptr Deserialize(Deserializer &deserializer); + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct DataPointer { + explicit DataPointer(BaseStatistics stats) : statistics(std::move(stats)) { + } + + uint64_t row_start; + uint64_t tuple_count; + BlockPointer block_pointer; + CompressionType compression_type; + //! Type-specific statistics of the segment + BaseStatistics statistics; + //! Serialized segment state + unique_ptr segment_state; + + void Serialize(Serializer &serializer) const; + static DataPointer Deserialize(Deserializer &source); +}; + +struct RowGroupPointer { + uint64_t row_start; + uint64_t tuple_count; + //! The data pointers of the column segments stored in the row group + vector data_pointers; + //! Data pointers to the delete information of the row group (if any) + vector deletes_pointers; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/data_table.hpp b/src/duckdb/src/include/duckdb/storage/data_table.hpp new file mode 100644 index 00000000..b65ace9c --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/data_table.hpp @@ -0,0 +1,228 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/data_table.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/index_type.hpp" +#include "duckdb/common/enums/scan_options.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/storage/index.hpp" +#include "duckdb/storage/table/table_statistics.hpp" +#include "duckdb/storage/block.hpp" +#include "duckdb/storage/statistics/column_statistics.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/table/persistent_table_data.hpp" +#include "duckdb/storage/table/row_group_collection.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/transaction/local_storage.hpp" +#include "duckdb/storage/table/data_table_info.hpp" +#include "duckdb/common/unique_ptr.hpp" + +namespace duckdb { +class BoundForeignKeyConstraint; +class ClientContext; +class ColumnDataCollection; +class ColumnDefinition; +class DataTable; +class DuckTransaction; +class OptimisticDataWriter; +class RowGroup; +class StorageManager; +class TableCatalogEntry; +class TableIOManager; +class Transaction; +class WriteAheadLog; +class TableDataWriter; +class ConflictManager; +class TableScanState; +enum class VerifyExistenceType : uint8_t; + +//! DataTable represents a physical table on disk +class DataTable { +public: + //! Constructs a new data table from an (optional) set of persistent segments + DataTable(AttachedDatabase &db, shared_ptr table_io_manager, const string &schema, + const string &table, vector column_definitions_p, + unique_ptr data = nullptr); + //! Constructs a DataTable as a delta on an existing data table with a newly added column + DataTable(ClientContext &context, DataTable &parent, ColumnDefinition &new_column, Expression &default_value); + //! Constructs a DataTable as a delta on an existing data table but with one column removed + DataTable(ClientContext &context, DataTable &parent, idx_t removed_column); + //! Constructs a DataTable as a delta on an existing data table but with one column changed type + DataTable(ClientContext &context, DataTable &parent, idx_t changed_idx, const LogicalType &target_type, + const vector &bound_columns, Expression &cast_expr); + //! Constructs a DataTable as a delta on an existing data table but with one column added new constraint + explicit DataTable(ClientContext &context, DataTable &parent, unique_ptr constraint); + + //! The table info + shared_ptr info; + //! The set of physical columns stored by this DataTable + vector column_definitions; + //! A reference to the database instance + AttachedDatabase &db; + +public: + //! Returns a list of types of the table + vector GetTypes(); + + void InitializeScan(TableScanState &state, const vector &column_ids, + TableFilterSet *table_filter = nullptr); + void InitializeScan(DuckTransaction &transaction, TableScanState &state, const vector &column_ids, + TableFilterSet *table_filters = nullptr); + + //! Returns the maximum amount of threads that should be assigned to scan this data table + idx_t MaxThreads(ClientContext &context); + void InitializeParallelScan(ClientContext &context, ParallelTableScanState &state); + bool NextParallelScan(ClientContext &context, ParallelTableScanState &state, TableScanState &scan_state); + + //! Scans up to STANDARD_VECTOR_SIZE elements from the table starting + //! from offset and store them in result. Offset is incremented with how many + //! elements were returned. + //! Returns true if all pushed down filters were executed during data fetching + void Scan(DuckTransaction &transaction, DataChunk &result, TableScanState &state); + + //! Fetch data from the specific row identifiers from the base table + void Fetch(DuckTransaction &transaction, DataChunk &result, const vector &column_ids, + const Vector &row_ids, idx_t fetch_count, ColumnFetchState &state); + + //! Initializes an append to transaction-local storage + void InitializeLocalAppend(LocalAppendState &state, ClientContext &context); + //! Append a DataChunk to the transaction-local storage of the table. + void LocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, + bool unsafe = false); + //! Finalizes a transaction-local append + void FinalizeLocalAppend(LocalAppendState &state); + //! Append a chunk to the transaction-local storage of this table + void LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk); + //! Append a column data collection to the transaction-local storage of this table + void LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection); + //! Merge a row group collection into the transaction-local storage + void LocalMerge(ClientContext &context, RowGroupCollection &collection); + //! Creates an optimistic writer for this table - used for optimistically writing parallel appends + OptimisticDataWriter &CreateOptimisticWriter(ClientContext &context); + void FinalizeOptimisticWriter(ClientContext &context, OptimisticDataWriter &writer); + + //! Delete the entries with the specified row identifier from the table + idx_t Delete(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, idx_t count); + //! Update the entries with the specified row identifier from the table + void Update(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, + const vector &column_ids, DataChunk &data); + //! Update a single (sub-)column along a column path + //! The column_path vector is a *path* towards a column within the table + //! i.e. if we have a table with a single column S STRUCT(A INT, B INT) + //! and we update the validity mask of "S.B" + //! the column path is: + //! 0 (first column of table) + //! -> 1 (second subcolumn of struct) + //! -> 0 (first subcolumn of INT) + //! This method should only be used from the WAL replay. It does not verify update constraints. + void UpdateColumn(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, + const vector &column_path, DataChunk &updates); + + //! Add an index to the DataTable. NOTE: for CREATE (UNIQUE) INDEX statements, we use the PhysicalCreateARTIndex + //! operator. This function is only used during the WAL replay, and is a much less performant index creation + //! approach. + void WALAddIndex(ClientContext &context, unique_ptr index, + const vector> &expressions); + + //! Fetches an append lock + void AppendLock(TableAppendState &state); + //! Begin appending structs to this table, obtaining necessary locks, etc + void InitializeAppend(DuckTransaction &transaction, TableAppendState &state, idx_t append_count); + //! Append a chunk to the table using the AppendState obtained from InitializeAppend + void Append(DataChunk &chunk, TableAppendState &state); + //! Commit the append + void CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count); + //! Write a segment of the table to the WAL + void WriteToLog(WriteAheadLog &log, idx_t row_start, idx_t count); + //! Revert a set of appends made by the given AppendState, used to revert appends in the event of an error during + //! commit (e.g. because of an I/O exception) + void RevertAppend(idx_t start_row, idx_t count); + void RevertAppendInternal(idx_t start_row); + + void ScanTableSegment(idx_t start_row, idx_t count, const std::function &function); + + //! Merge a row group collection directly into this table - appending it to the end of the table without copying + void MergeStorage(RowGroupCollection &data, TableIndexList &indexes); + + //! Append a chunk with the row ids [row_start, ..., row_start + chunk.size()] to all indexes of the table, returns + //! whether or not the append succeeded + PreservedError AppendToIndexes(DataChunk &chunk, row_t row_start); + static PreservedError AppendToIndexes(TableIndexList &indexes, DataChunk &chunk, row_t row_start); + //! Remove a chunk with the row ids [row_start, ..., row_start + chunk.size()] from all indexes of the table + void RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, row_t row_start); + //! Remove the chunk with the specified set of row identifiers from all indexes of the table + void RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers); + //! Remove the row identifiers from all the indexes of the table + void RemoveFromIndexes(Vector &row_identifiers, idx_t count); + + void SetAsRoot() { + this->is_root = true; + } + bool IsRoot() { + return this->is_root; + } + + //! Get statistics of a physical column within the table + unique_ptr GetStatistics(ClientContext &context, column_t column_id); + //! Sets statistics of a physical column within the table + void SetDistinct(column_t column_id, unique_ptr distinct_stats); + + //! Checkpoint the table to the specified table data writer + void Checkpoint(TableDataWriter &writer, Serializer &metadata_serializer); + void CommitDropTable(); + void CommitDropColumn(idx_t index); + + idx_t GetTotalRows(); + + vector GetColumnSegmentInfo(); + static bool IsForeignKeyIndex(const vector &fk_keys, Index &index, ForeignKeyType fk_type); + + //! Initializes a special scan that is used to create an index on the table, it keeps locks on the table + void InitializeWALCreateIndexScan(CreateIndexScanState &state, const vector &column_ids); + //! Scans the next chunk for the CREATE INDEX operator + bool CreateIndexScan(TableScanState &state, DataChunk &result, TableScanType type); + + //! Verify constraints with a chunk from the Append containing all columns of the table + void VerifyAppendConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, + ConflictManager *conflict_manager = nullptr); + +public: + static void VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &context, DataChunk &chunk, + ConflictManager *conflict_manager); + +private: + //! Verify the new added constraints against current persistent&local data + void VerifyNewConstraint(ClientContext &context, DataTable &parent, const BoundConstraint *constraint); + //! Verify constraints with a chunk from the Update containing only the specified column_ids + void VerifyUpdateConstraints(ClientContext &context, TableCatalogEntry &table, DataChunk &chunk, + const vector &column_ids); + //! Verify constraints with a chunk from the Delete containing all columns of the table + void VerifyDeleteConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk); + + void InitializeScanWithOffset(TableScanState &state, const vector &column_ids, idx_t start_row, + idx_t end_row); + + void VerifyForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, DataChunk &chunk, + VerifyExistenceType verify_type); + void VerifyAppendForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, + DataChunk &chunk); + void VerifyDeleteForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, + DataChunk &chunk); + +private: + //! Lock for appending entries to the table + mutex append_lock; + //! The row groups of the table + shared_ptr row_groups; + //! Whether or not the data table is the root DataTable for this table; the root DataTable is the newest version + //! that can be appended to + atomic is_root; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/database_size.hpp b/src/duckdb/src/include/duckdb/storage/database_size.hpp new file mode 100644 index 00000000..fa2353e5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/database_size.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/database_size.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +struct DatabaseSize { + idx_t total_blocks = 0; + idx_t block_size = 0; + idx_t free_blocks = 0; + idx_t used_blocks = 0; + idx_t bytes = 0; + idx_t wal_size = 0; +}; + +struct MetadataBlockInfo { + block_id_t block_id; + idx_t total_blocks; + vector free_list; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/in_memory_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/in_memory_block_manager.hpp new file mode 100644 index 00000000..0f3985cb --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/in_memory_block_manager.hpp @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/in_memory_block_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/storage/block_manager.hpp" + +namespace duckdb { + +//! InMemoryBlockManager is an implementation for a BlockManager +class InMemoryBlockManager : public BlockManager { +public: + using BlockManager::BlockManager; + + // LCOV_EXCL_START + unique_ptr ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) override { + throw InternalException("Cannot perform IO in in-memory database - ConvertBlock!"); + } + unique_ptr CreateBlock(block_id_t block_id, FileBuffer *source_buffer) override { + throw InternalException("Cannot perform IO in in-memory database - CreateBlock!"); + } + block_id_t GetFreeBlockId() override { + throw InternalException("Cannot perform IO in in-memory database - GetFreeBlockId!"); + } + bool IsRootBlock(MetaBlockPointer root) override { + throw InternalException("Cannot perform IO in in-memory database - IsRootBlock!"); + } + void MarkBlockAsFree(block_id_t block_id) override { + throw InternalException("Cannot perform IO in in-memory database - MarkBlockAsFree!"); + } + void MarkBlockAsModified(block_id_t block_id) override { + throw InternalException("Cannot perform IO in in-memory database - MarkBlockAsModified!"); + } + void IncreaseBlockReferenceCount(block_id_t block_id) override { + throw InternalException("Cannot perform IO in in-memory database - IncreaseBlockReferenceCount!"); + } + idx_t GetMetaBlock() override { + throw InternalException("Cannot perform IO in in-memory database - GetMetaBlock!"); + } + void Read(Block &block) override { + throw InternalException("Cannot perform IO in in-memory database - Read!"); + } + void Write(FileBuffer &block, block_id_t block_id) override { + throw InternalException("Cannot perform IO in in-memory database - Write!"); + } + void WriteHeader(DatabaseHeader header) override { + throw InternalException("Cannot perform IO in in-memory database - WriteHeader!"); + } + idx_t TotalBlocks() override { + throw InternalException("Cannot perform IO in in-memory database - TotalBlocks!"); + } + idx_t FreeBlocks() override { + throw InternalException("Cannot perform IO in in-memory database - FreeBlocks!"); + } + // LCOV_EXCL_STOP +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/index.hpp b/src/duckdb/src/include/duckdb/storage/index.hpp new file mode 100644 index 00000000..474421d4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/index.hpp @@ -0,0 +1,170 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/index.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/enums/index_type.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/storage/metadata/metadata_writer.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/types/constraint_conflict_info.hpp" + +namespace duckdb { + +class ClientContext; +class TableIOManager; +class Transaction; +class ConflictManager; + +struct IndexLock; +struct IndexScanState; + +//! The index is an abstract base class that serves as the basis for indexes +class Index { +public: + Index(AttachedDatabase &db, IndexType type, TableIOManager &table_io_manager, const vector &column_ids, + const vector> &unbound_expressions, IndexConstraintType constraint_type); + virtual ~Index() = default; + + //! The type of the index + IndexType type; + //! Associated table io manager + TableIOManager &table_io_manager; + //! Column identifiers to extract key columns from the base table + vector column_ids; + //! Unordered set of column_ids used by the index + unordered_set column_id_set; + //! Unbound expressions used by the index during optimizations + vector> unbound_expressions; + //! The physical types stored in the index + vector types; + //! The logical types of the expressions + vector logical_types; + //! Index constraint type (primary key, foreign key, ...) + IndexConstraintType constraint_type; + + //! Attached database instance + AttachedDatabase &db; + +public: + //! Initialize a single predicate scan on the index with the given expression and column IDs + virtual unique_ptr InitializeScanSinglePredicate(const Transaction &transaction, const Value &value, + const ExpressionType expression_type) = 0; + //! Initialize a two predicate scan on the index with the given expression and column IDs + virtual unique_ptr InitializeScanTwoPredicates(const Transaction &transaction, + const Value &low_value, + const ExpressionType low_expression_type, + const Value &high_value, + const ExpressionType high_expression_type) = 0; + //! Performs a lookup on the index, fetching up to max_count result IDs. Returns true if all row IDs were fetched, + //! and false otherwise + virtual bool Scan(const Transaction &transaction, const DataTable &table, IndexScanState &state, + const idx_t max_count, vector &result_ids) = 0; + + //! Obtain a lock on the index + virtual void InitializeLock(IndexLock &state); + //! Called when data is appended to the index. The lock obtained from InitializeLock must be held + virtual PreservedError Append(IndexLock &state, DataChunk &entries, Vector &row_identifiers) = 0; + //! Obtains a lock and calls Append while holding that lock + PreservedError Append(DataChunk &entries, Vector &row_identifiers); + //! Verify that data can be appended to the index without a constraint violation + virtual void VerifyAppend(DataChunk &chunk) = 0; + //! Verify that data can be appended to the index without a constraint violation using the conflict manager + virtual void VerifyAppend(DataChunk &chunk, ConflictManager &conflict_manager) = 0; + //! Performs constraint checking for a chunk of input data + virtual void CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_manager) = 0; + + //! Deletes all data from the index. The lock obtained from InitializeLock must be held + virtual void CommitDrop(IndexLock &index_lock) = 0; + //! Deletes all data from the index + void CommitDrop(); + //! Delete a chunk of entries from the index. The lock obtained from InitializeLock must be held + virtual void Delete(IndexLock &state, DataChunk &entries, Vector &row_identifiers) = 0; + //! Obtains a lock and calls Delete while holding that lock + void Delete(DataChunk &entries, Vector &row_identifiers); + + //! Insert a chunk of entries into the index + virtual PreservedError Insert(IndexLock &lock, DataChunk &input, Vector &row_identifiers) = 0; + + //! Merge another index into this index. The lock obtained from InitializeLock must be held, and the other + //! index must also be locked during the merge + virtual bool MergeIndexes(IndexLock &state, Index &other_index) = 0; + //! Obtains a lock and calls MergeIndexes while holding that lock + bool MergeIndexes(Index &other_index); + + //! Traverses an ART and vacuums the qualifying nodes. The lock obtained from InitializeLock must be held + virtual void Vacuum(IndexLock &state) = 0; + //! Obtains a lock and calls Vacuum while holding that lock + void Vacuum(); + + //! Returns the string representation of an index, or only traverses and verifies the index + virtual string VerifyAndToString(IndexLock &state, const bool only_verify) = 0; + //! Obtains a lock and calls VerifyAndToString while holding that lock + string VerifyAndToString(const bool only_verify); + + //! Returns true if the index is affected by updates on the specified column IDs, and false otherwise + bool IndexIsUpdated(const vector &column_ids) const; + + //! Returns unique flag + bool IsUnique() { + return (constraint_type == IndexConstraintType::UNIQUE || constraint_type == IndexConstraintType::PRIMARY); + } + //! Returns primary key flag + bool IsPrimary() { + return (constraint_type == IndexConstraintType::PRIMARY); + } + //! Returns foreign key flag + bool IsForeign() { + return (constraint_type == IndexConstraintType::FOREIGN); + } + + //! Serializes the index to disk + virtual BlockPointer Serialize(MetadataWriter &writer); + //! Returns the serialized root block pointer + BlockPointer GetRootBlockPointer() const { + return root_block_pointer; + } + + //! Execute the index expressions on an input chunk + void ExecuteExpressions(DataChunk &input, DataChunk &result); + static string AppendRowError(DataChunk &input, idx_t index); + +protected: + //! Lock used for any changes to the index + mutex lock; + //! Pointer to the index on disk + BlockPointer root_block_pointer; + +private: + //! Bound expressions used during expression execution + vector> bound_expressions; + //! Expression executor to execute the index expressions + ExpressionExecutor executor; + + //! Bind the unbound expressions of the index + unique_ptr BindExpression(unique_ptr expr); + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/magic_bytes.hpp b/src/duckdb/src/include/duckdb/storage/magic_bytes.hpp new file mode 100644 index 00000000..2fda8588 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/magic_bytes.hpp @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/magic_bytes.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { +class FileSystem; + +enum class DataFileType : uint8_t { + FILE_DOES_NOT_EXIST, // file does not exist + DUCKDB_FILE, // duckdb database file + SQLITE_FILE, // sqlite database file + PARQUET_FILE // parquet file +}; + +class MagicBytes { +public: + static DataFileType CheckMagicBytes(FileSystem *fs, const string &path); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp new file mode 100644 index 00000000..abe94735 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp @@ -0,0 +1,91 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/metadata/metadata_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/block.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" + +namespace duckdb { +class DatabaseInstance; +struct MetadataBlockInfo; + +struct MetadataBlock { + shared_ptr block; + block_id_t block_id; + vector free_blocks; + + void Write(WriteStream &sink); + static MetadataBlock Read(ReadStream &source); + + idx_t FreeBlocksToInteger(); + void FreeBlocksFromInteger(idx_t blocks); +}; + +struct MetadataPointer { + idx_t block_index : 56; + uint8_t index : 8; +}; + +struct MetadataHandle { + MetadataPointer pointer; + BufferHandle handle; +}; + +class MetadataManager { +public: + //! The size of metadata blocks + static constexpr const idx_t METADATA_BLOCK_SIZE = 4088; + //! The amount of metadata blocks per storage block + static constexpr const idx_t METADATA_BLOCK_COUNT = 64; + +public: + MetadataManager(BlockManager &block_manager, BufferManager &buffer_manager); + ~MetadataManager(); + + MetadataHandle AllocateHandle(); + MetadataHandle Pin(MetadataPointer pointer); + + MetaBlockPointer GetDiskPointer(MetadataPointer pointer, uint32_t offset = 0); + MetadataPointer FromDiskPointer(MetaBlockPointer pointer); + MetadataPointer RegisterDiskPointer(MetaBlockPointer pointer); + + static BlockPointer ToBlockPointer(MetaBlockPointer meta_pointer); + static MetaBlockPointer FromBlockPointer(BlockPointer block_pointer); + + //! Flush all blocks to disk + void Flush(); + + void MarkBlocksAsModified(); + void ClearModifiedBlocks(const vector &pointers); + + vector GetMetadataInfo() const; + idx_t BlockCount(); + + void Write(WriteStream &sink); + void Read(ReadStream &source); + +protected: + BlockManager &block_manager; + BufferManager &buffer_manager; + unordered_map blocks; + unordered_map modified_blocks; + +protected: + block_id_t AllocateNewBlock(); + block_id_t GetNextBlockId(); + + void AddBlock(MetadataBlock new_block, bool if_exists = false); + void AddAndRegisterBlock(MetadataBlock block); + void ConvertToTransient(MetadataBlock &block); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp new file mode 100644 index 00000000..1acb60fe --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/metadata/metadata_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/metadata/metadata_manager.hpp" +#include "duckdb/common/serializer/read_stream.hpp" + +namespace duckdb { + +enum class BlockReaderType { EXISTING_BLOCKS, REGISTER_BLOCKS }; + +class MetadataReader : public ReadStream { +public: + MetadataReader(MetadataManager &manager, MetaBlockPointer pointer, + optional_ptr> read_pointers = nullptr, + BlockReaderType type = BlockReaderType::EXISTING_BLOCKS); + MetadataReader(MetadataManager &manager, BlockPointer pointer); + ~MetadataReader() override; + +public: + //! Read content of size read_size into the buffer + void ReadData(data_ptr_t buffer, idx_t read_size) override; + + MetaBlockPointer GetMetaBlockPointer(); + + MetadataManager &GetMetadataManager() { + return manager; + } + +private: + data_ptr_t BasePtr(); + data_ptr_t Ptr(); + + void ReadNextBlock(); + + MetadataPointer FromDiskPointer(MetaBlockPointer pointer); + +private: + MetadataManager &manager; + BlockReaderType type; + MetadataHandle block; + MetadataPointer next_pointer; + bool has_next_block; + optional_ptr> read_pointers; + idx_t index; + idx_t offset; + idx_t next_offset; + idx_t capacity; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_writer.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_writer.hpp new file mode 100644 index 00000000..206451b7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_writer.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/metadata/metadata_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/metadata/metadata_manager.hpp" +#include "duckdb/common/serializer/write_stream.hpp" + +namespace duckdb { + +class MetadataWriter : public WriteStream { +public: + explicit MetadataWriter(MetadataManager &manager, + optional_ptr> written_pointers = nullptr); + MetadataWriter(const MetadataWriter &) = delete; + MetadataWriter &operator=(const MetadataWriter &) = delete; + ~MetadataWriter() override; + +public: + void WriteData(const_data_ptr_t buffer, idx_t write_size) override; + void Flush(); + + BlockPointer GetBlockPointer(); + MetaBlockPointer GetMetaBlockPointer(); + MetadataManager &GetManager() { + return manager; + } + +protected: + virtual MetadataHandle NextHandle(); + +private: + data_ptr_t BasePtr(); + data_ptr_t Ptr(); + + void NextBlock(); + +private: + MetadataManager &manager; + MetadataHandle block; + MetadataPointer current_pointer; + optional_ptr> written_pointers; + idx_t capacity; + idx_t offset; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/object_cache.hpp b/src/duckdb/src/include/duckdb/storage/object_cache.hpp new file mode 100644 index 00000000..170c0041 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/object_cache.hpp @@ -0,0 +1,86 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/object_cache.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" + +namespace duckdb { +class ClientContext; + +//! ObjectCache is the base class for objects caches in DuckDB +class ObjectCacheEntry { +public: + virtual ~ObjectCacheEntry() { + } + + virtual string GetObjectType() = 0; +}; + +class ObjectCache { +public: + shared_ptr GetObject(const string &key) { + lock_guard glock(lock); + auto entry = cache.find(key); + if (entry == cache.end()) { + return nullptr; + } + return entry->second; + } + + template + shared_ptr Get(const string &key) { + shared_ptr object = GetObject(key); + if (!object || object->GetObjectType() != T::ObjectType()) { + return nullptr; + } + return std::static_pointer_cast(object); + } + + template + shared_ptr GetOrCreate(const string &key, Args &&...args) { + lock_guard glock(lock); + + auto entry = cache.find(key); + if (entry == cache.end()) { + auto value = make_shared(args...); + cache[key] = value; + return value; + } + auto object = entry->second; + if (!object || object->GetObjectType() != T::ObjectType()) { + return nullptr; + } + return std::static_pointer_cast(object); + } + + void Put(string key, shared_ptr value) { + lock_guard glock(lock); + cache[key] = std::move(value); + } + + void Delete(const string &key) { + lock_guard glock(lock); + cache.erase(key); + } + + DUCKDB_API static ObjectCache &GetObjectCache(ClientContext &context); + DUCKDB_API static bool ObjectCacheEnabled(ClientContext &context); + +private: + //! Object Cache + unordered_map> cache; + mutex lock; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp new file mode 100644 index 00000000..59e09452 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/optimistic_data_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/row_group_collection.hpp" + +namespace duckdb { +class PartialBlockManager; + +class OptimisticDataWriter { +public: + OptimisticDataWriter(DataTable &table); + OptimisticDataWriter(DataTable &table, OptimisticDataWriter &parent); + ~OptimisticDataWriter(); + + //! Write a new row group to disk (if possible) + void WriteNewRowGroup(RowGroupCollection &row_groups); + //! Write the last row group of a collection to disk + void WriteLastRowGroup(RowGroupCollection &row_groups); + //! Final flush of the optimistic writer - fully flushes the partial block manager + void FinalFlush(); + //! Flushes a specific row group to disk + void FlushToDisk(RowGroup *row_group); + //! Merge the partially written blocks from one optimistic writer into another + void Merge(OptimisticDataWriter &other); + //! Rollback + void Rollback(); + +private: + //! Prepare a write to disk + bool PrepareWrite(); + +private: + //! The table + DataTable &table; + //! The partial block manager (if we created one yet) + unique_ptr partial_manager; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp new file mode 100644 index 00000000..c689bcae --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp @@ -0,0 +1,151 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/partial_block_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/storage/metadata/metadata_writer.hpp" +#include "duckdb/storage/data_pointer.hpp" + +namespace duckdb { +class DatabaseInstance; +class ClientContext; +class ColumnSegment; +class MetadataReader; +class SchemaCatalogEntry; +class SequenceCatalogEntry; +class TableCatalogEntry; +class ViewCatalogEntry; +class TypeCatalogEntry; + +//! Regions that require zero-initialization to avoid leaking memory +struct UninitializedRegion { + idx_t start; + idx_t end; +}; + +//! The current state of a partial block +struct PartialBlockState { + //! The block id of the partial block + block_id_t block_id; + //! The total bytes that we can assign to this block + uint32_t block_size; + //! Next allocation offset, and also the current allocation size + uint32_t offset; + //! The number of times that this block has been used for partial allocations + uint32_t block_use_count; +}; + +struct PartialBlock { + PartialBlock(PartialBlockState state, BlockManager &block_manager, const shared_ptr &block_handle); + virtual ~PartialBlock() { + } + + //! The current state of a partial block + PartialBlockState state; + //! All uninitialized regions on this block, we need to zero-initialize them when flushing + vector uninitialized_regions; + //! The block manager of the partial block manager + BlockManager &block_manager; + //! The block handle of the underlying block that this partial block writes to + shared_ptr block_handle; + +public: + //! Add regions that need zero-initialization to avoid leaking memory + void AddUninitializedRegion(const idx_t start, const idx_t end); + //! Flush the block to disk and zero-initialize any free space and uninitialized regions + virtual void Flush(const idx_t free_space_left) = 0; + void FlushInternal(const idx_t free_space_left); + virtual void Merge(PartialBlock &other, idx_t offset, idx_t other_size) = 0; + virtual void Clear() = 0; + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct PartialBlockAllocation { + //! The BlockManager owning the block_id + BlockManager *block_manager {nullptr}; + //! The number of assigned bytes to the caller + uint32_t allocation_size; + //! The current state of the partial block + PartialBlockState state; + //! Arbitrary state related to the partial block storage + unique_ptr partial_block; +}; + +enum class CheckpointType { FULL_CHECKPOINT, APPEND_TO_TABLE }; + +//! Enables sharing blocks across some scope. Scope is whatever we want to share +//! blocks across. It may be an entire checkpoint or just a single row group. +//! In any case, they must share a block manager. +class PartialBlockManager { +public: + //! 20% free / 80% utilization + static constexpr const idx_t DEFAULT_MAX_PARTIAL_BLOCK_SIZE = Storage::BLOCK_SIZE / 5 * 4; + //! Max number of shared references to a block. No effective limit by default. + static constexpr const idx_t DEFAULT_MAX_USE_COUNT = 1u << 20; + //! No point letting map size grow unbounded. We'll drop blocks with the + //! least free space first. + static constexpr const idx_t MAX_BLOCK_MAP_SIZE = 1u << 31; + +public: + PartialBlockManager(BlockManager &block_manager, CheckpointType checkpoint_type, + uint32_t max_partial_block_size = DEFAULT_MAX_PARTIAL_BLOCK_SIZE, + uint32_t max_use_count = DEFAULT_MAX_USE_COUNT); + virtual ~PartialBlockManager(); + +public: + //! Flush any remaining partial blocks to disk + void FlushPartialBlocks(); + + PartialBlockAllocation GetBlockAllocation(uint32_t segment_size); + + virtual void AllocateBlock(PartialBlockState &state, uint32_t segment_size); + + void Merge(PartialBlockManager &other); + //! Register a partially filled block that is filled with "segment_size" entries + void RegisterPartialBlock(PartialBlockAllocation &&allocation); + + //! Clear remaining blocks without writing them to disk + void ClearBlocks(); + + //! Rollback all data written by this partial block manager + void Rollback(); + +protected: + BlockManager &block_manager; + CheckpointType checkpoint_type; + //! A map of (available space -> PartialBlock) for partially filled blocks + //! This is a multimap because there might be outstanding partial blocks with + //! the same amount of left-over space + multimap> partially_filled_blocks; + //! The set of written blocks + unordered_set written_blocks; + + //! The maximum size (in bytes) at which a partial block will be considered a partial block + uint32_t max_partial_block_size; + uint32_t max_use_count; + +protected: + //! Try to obtain a partially filled block that can fit "segment_size" bytes + //! If successful, returns true and returns the block_id and offset_in_block to write to + //! Otherwise, returns false + bool GetPartialBlock(idx_t segment_size, unique_ptr &state); + + bool HasBlockAllocation(uint32_t segment_size); + void AddWrittenBlock(block_id_t block); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/segment/uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/segment/uncompressed.hpp new file mode 100644 index 00000000..661d3f0f --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/segment/uncompressed.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/segment/uncompressed.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_segment.hpp" + +namespace duckdb { +class DatabaseInstance; + +struct UncompressedFunctions { + static unique_ptr InitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr state); + static void Compress(CompressionState &state_p, Vector &data, idx_t count); + static void FinalizeCompress(CompressionState &state_p); + static void EmptySkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { + } +}; + +struct FixedSizeUncompressed { + static CompressionFunction GetFunction(PhysicalType data_type); +}; + +struct ValidityUncompressed { +public: + static CompressionFunction GetFunction(PhysicalType data_type); + +public: + static const validity_t LOWER_MASKS[65]; + static const validity_t UPPER_MASKS[65]; +}; + +struct StringUncompressed { +public: + static CompressionFunction GetFunction(PhysicalType data_type); + +public: + //! The max string size that is allowed within a block. Strings bigger than this will be labeled as a BIG STRING and + //! offloaded to the overflow blocks. + static constexpr uint16_t STRING_BLOCK_LIMIT = 4096; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp new file mode 100644 index 00000000..809b5f7d --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp @@ -0,0 +1,115 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/single_file_block_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/block.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +class DatabaseInstance; +struct MetadataHandle; + +struct StorageManagerOptions { + bool read_only = false; + bool use_direct_io = false; + DebugInitialize debug_initialize = DebugInitialize::NO_INITIALIZE; +}; + +//! SingleFileBlockManager is an implementation for a BlockManager which manages blocks in a single file +class SingleFileBlockManager : public BlockManager { + //! The location in the file where the block writing starts + static constexpr uint64_t BLOCK_START = Storage::FILE_HEADER_SIZE * 3; + +public: + SingleFileBlockManager(AttachedDatabase &db, string path, StorageManagerOptions options); + + void GetFileFlags(uint8_t &flags, FileLockType &lock, bool create_new); + void CreateNewDatabase(); + void LoadExistingDatabase(); + + //! Creates a new Block using the specified block_id and returns a pointer + unique_ptr ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) override; + unique_ptr CreateBlock(block_id_t block_id, FileBuffer *source_buffer) override; + //! Return the next free block id + block_id_t GetFreeBlockId() override; + //! Returns whether or not a specified block is the root block + bool IsRootBlock(MetaBlockPointer root) override; + //! Mark a block as free (immediately re-writeable) + void MarkBlockAsFree(block_id_t block_id) override; + //! Mark a block as modified (re-writeable after a checkpoint) + void MarkBlockAsModified(block_id_t block_id) override; + //! Increase the reference count of a block. The block should hold at least one reference + void IncreaseBlockReferenceCount(block_id_t block_id) override; + //! Return the meta block id + idx_t GetMetaBlock() override; + //! Read the content of the block from disk + void Read(Block &block) override; + //! Write the given block to disk + void Write(FileBuffer &block, block_id_t block_id) override; + //! Write the header to disk, this is the final step of the checkpointing process + void WriteHeader(DatabaseHeader header) override; + //! Truncate the underlying database file after a checkpoint + void Truncate() override; + + //! Returns the number of total blocks + idx_t TotalBlocks() override; + //! Returns the number of free blocks + idx_t FreeBlocks() override; + +private: + //! Load the free list from the file + void LoadFreeList(); + + void Initialize(DatabaseHeader &header); + + void ReadAndChecksum(FileBuffer &handle, uint64_t location) const; + void ChecksumAndWrite(FileBuffer &handle, uint64_t location) const; + + //! Return the blocks to which we will write the free list and modified blocks + vector GetFreeListBlocks(); + +private: + AttachedDatabase &db; + //! The active DatabaseHeader, either 0 (h1) or 1 (h2) + uint8_t active_header; + //! The path where the file is stored + string path; + //! The file handle + unique_ptr handle; + //! The buffer used to read/write to the headers + FileBuffer header_buffer; + //! The list of free blocks that can be written to currently + set free_list; + //! The list of multi-use blocks (i.e. blocks that have >1 reference in the file) + //! When a multi-use block is marked as modified, the reference count is decreased by 1 instead of directly + //! Appending the block to the modified_blocks list + unordered_map multi_use_blocks; + //! The list of blocks that will be added to the free list + unordered_set modified_blocks; + //! The current meta block id + idx_t meta_block; + //! The current maximum block id, this id will be given away first after the free_list runs out + block_id_t max_block; + //! The block id where the free list can be found + idx_t free_list_id; + //! The current header iteration count + uint64_t iteration_count; + //! The storage manager options + StorageManagerOptions options; + //! Lock for performing various operations in the single file block manager + mutex block_lock; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp new file mode 100644 index 00000000..ff867ef5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -0,0 +1,157 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/standard_buffer_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/storage/block_manager.hpp" + +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/buffer/buffer_pool.hpp" + +namespace duckdb { +class BlockManager; +class DatabaseInstance; +class TemporaryDirectoryHandle; +struct EvictionQueue; + +//! The BufferManager is in charge of handling memory management for a single database. It cooperatively shares a +//! BufferPool with other BufferManagers, belonging to different databases. It hands out memory buffers that can +//! be used by the database internally, and offers configuration options specific to a database, which need not be +//! shared by the BufferPool, including whether to support swapping temp buffers to disk, and where to swap them to. +class StandardBufferManager : public BufferManager { + friend class BufferHandle; + friend class BlockHandle; + friend class BlockManager; + +public: + StandardBufferManager(DatabaseInstance &db, string temp_directory); + virtual ~StandardBufferManager(); + +public: + static unique_ptr CreateBufferManager(DatabaseInstance &db, string temp_directory); + //! Registers an in-memory buffer that cannot be unloaded until it is destroyed + //! This buffer can be small (smaller than BLOCK_SIZE) + //! Unpin and pin are nops on this block of memory + shared_ptr RegisterSmallMemory(idx_t block_size) final override; + + idx_t GetUsedMemory() const final override; + idx_t GetMaxMemory() const final override; + + //! Allocate an in-memory buffer with a single pin. + //! The allocated memory is released when the buffer handle is destroyed. + DUCKDB_API BufferHandle Allocate(idx_t block_size, bool can_destroy = true, + shared_ptr *block = nullptr) final override; + + //! Reallocate an in-memory buffer that is pinned. + void ReAllocate(shared_ptr &handle, idx_t block_size) final override; + + BufferHandle Pin(shared_ptr &handle) final override; + void Unpin(shared_ptr &handle) final override; + + //! Set a new memory limit to the buffer manager, throws an exception if the new limit is too low and not enough + //! blocks can be evicted + void SetLimit(idx_t limit = (idx_t)-1) final override; + + //! Returns a list of all temporary files + vector GetTemporaryFiles() final override; + + const string &GetTemporaryDirectory() final override { + return temp_directory; + } + + void SetTemporaryDirectory(const string &new_dir) final override; + + DUCKDB_API Allocator &GetBufferAllocator() final override; + + DatabaseInstance &GetDatabase() final override { + return db; + } + + //! Construct a managed buffer. + unique_ptr ConstructManagedBuffer(idx_t size, unique_ptr &&source, + FileBufferType type = FileBufferType::MANAGED_BUFFER) override; + + DUCKDB_API void ReserveMemory(idx_t size) final override; + DUCKDB_API void FreeReservedMemory(idx_t size) final override; + bool HasTemporaryDirectory() const final override; + +protected: + //! Helper + template + TempBufferPoolReservation EvictBlocksOrThrow(idx_t memory_delta, unique_ptr *buffer, ARGS...); + + //! Register an in-memory buffer of arbitrary size, as long as it is >= BLOCK_SIZE. can_destroy signifies whether or + //! not the buffer can be destroyed when unpinned, or whether or not it needs to be written to a temporary file so + //! it can be reloaded. The resulting buffer will already be allocated, but needs to be pinned in order to be used. + //! This needs to be private to prevent creating blocks without ever pinning them: + //! blocks that are never pinned are never added to the eviction queue + shared_ptr RegisterMemory(idx_t block_size, bool can_destroy); + + //! Evict blocks until the currently used memory + extra_memory fit, returns false if this was not possible + //! (i.e. not enough blocks could be evicted) + //! If the "buffer" argument is specified AND the system can find a buffer to re-use for the given allocation size + //! "buffer" will be made to point to the re-usable memory. Note that this is not guaranteed. + //! Returns a pair. result.first indicates if eviction was successful. result.second contains the + //! reservation handle, which can be moved to the BlockHandle that will own the reservation. + BufferPool::EvictionResult EvictBlocks(idx_t extra_memory, idx_t memory_limit, + unique_ptr *buffer = nullptr); + + //! Garbage collect eviction queue + void PurgeQueue() final override; + + BufferPool &GetBufferPool() final override; + + //! Write a temporary buffer to disk + void WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) final override; + //! Read a temporary buffer from disk + unique_ptr ReadTemporaryBuffer(block_id_t id, unique_ptr buffer = nullptr) final override; + //! Get the path of the temporary buffer + string GetTemporaryPath(block_id_t id); + + void DeleteTemporaryFile(block_id_t id) final override; + + void RequireTemporaryDirectory(); + + void AddToEvictionQueue(shared_ptr &handle) final override; + + const char *InMemoryWarning(); + + static data_ptr_t BufferAllocatorAllocate(PrivateAllocatorData *private_data, idx_t size); + static void BufferAllocatorFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size); + static data_ptr_t BufferAllocatorRealloc(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, + idx_t size); + + //! When the BlockHandle reaches 0 readers, this creates a new FileBuffer for this BlockHandle and + //! overwrites the data within with garbage. Any readers that do not hold the pin will notice + void VerifyZeroReaders(shared_ptr &handle); + +protected: + //! The database instance + DatabaseInstance &db; + //! The buffer pool + BufferPool &buffer_pool; + //! The directory name where temporary files are stored + string temp_directory; + //! Lock for creating the temp handle + mutex temp_handle_lock; + //! Handle for the temporary directory + unique_ptr temp_directory_handle; + //! The temporary id used for managed buffers + atomic temporary_id; + //! Allocator associated with the buffer manager, that passes all allocations through this buffer manager + Allocator buffer_allocator; + //! Block manager for temp data + unique_ptr temp_block_manager; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp new file mode 100644 index 00000000..790644cc --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp @@ -0,0 +1,139 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/base_statistics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" +#include "duckdb/storage/statistics/string_stats.hpp" + +namespace duckdb { +struct SelectionVector; + +class Serializer; +class Deserializer; + +class Vector; +struct UnifiedVectorFormat; + +enum class StatsInfo : uint8_t { + CAN_HAVE_NULL_VALUES = 0, + CANNOT_HAVE_NULL_VALUES = 1, + CAN_HAVE_VALID_VALUES = 2, + CANNOT_HAVE_VALID_VALUES = 3, + CAN_HAVE_NULL_AND_VALID_VALUES = 4 +}; + +enum class StatisticsType : uint8_t { NUMERIC_STATS, STRING_STATS, LIST_STATS, STRUCT_STATS, BASE_STATS }; + +class BaseStatistics { + friend struct NumericStats; + friend struct StringStats; + friend struct StructStats; + friend struct ListStats; + +public: + DUCKDB_API ~BaseStatistics(); + // disable copy constructors + BaseStatistics(const BaseStatistics &other) = delete; + BaseStatistics &operator=(const BaseStatistics &) = delete; + //! enable move constructors + DUCKDB_API BaseStatistics(BaseStatistics &&other) noexcept; + DUCKDB_API BaseStatistics &operator=(BaseStatistics &&) noexcept; + +public: + //! Creates a set of statistics for data that is unknown, i.e. "has_null" is true, "has_no_null" is true, etc + //! This can be used in case nothing is known about the data - or can be used as a baseline when only a few things + //! are known + static BaseStatistics CreateUnknown(LogicalType type); + //! Creates statistics for an empty database, i.e. "has_null" is false, "has_no_null" is false, etc + //! This is used when incrementally constructing statistics by constantly adding new values + static BaseStatistics CreateEmpty(LogicalType type); + + DUCKDB_API StatisticsType GetStatsType() const; + DUCKDB_API static StatisticsType GetStatsType(const LogicalType &type); + + DUCKDB_API bool CanHaveNull() const; + DUCKDB_API bool CanHaveNoNull() const; + + void SetDistinctCount(idx_t distinct_count); + + bool IsConstant() const; + + const LogicalType &GetType() const { + return type; + } + + void Set(StatsInfo info); + void CombineValidity(BaseStatistics &left, BaseStatistics &right); + void CopyValidity(BaseStatistics &stats); + inline void SetHasNull() { + has_null = true; + } + inline void SetHasNoNull() { + has_no_null = true; + } + + void Merge(const BaseStatistics &other); + + void Copy(const BaseStatistics &other); + + BaseStatistics Copy() const; + unique_ptr ToUnique() const; + void CopyBase(const BaseStatistics &orig); + + void Serialize(Serializer &serializer) const; + static BaseStatistics Deserialize(Deserializer &deserializer); + + //! Verify that a vector does not violate the statistics + void Verify(Vector &vector, const SelectionVector &sel, idx_t count) const; + void Verify(Vector &vector, idx_t count) const; + + string ToString() const; + + idx_t GetDistinctCount(); + static BaseStatistics FromConstant(const Value &input); + +private: + BaseStatistics(); + explicit BaseStatistics(LogicalType type); + + static void Construct(BaseStatistics &stats, LogicalType type); + + void InitializeUnknown(); + void InitializeEmpty(); + + static BaseStatistics CreateUnknownType(LogicalType type); + static BaseStatistics CreateEmptyType(LogicalType type); + static BaseStatistics FromConstantType(const Value &input); + +private: + //! The type of the logical segment + LogicalType type; + //! Whether or not the segment can contain NULL values + bool has_null; + //! Whether or not the segment can contain values that are not null + bool has_no_null; + // estimate that one may have even if distinct_stats==nullptr + idx_t distinct_count; + //! Numeric and String stats + union { + //! Numeric stats data, for numeric stats + NumericStatsData numeric_data; + //! String stats data, for string stats + StringStatsData string_data; + } stats_union; + //! Child stats (for LIST and STRUCT) + unsafe_unique_array child_stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/column_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/column_statistics.hpp new file mode 100644 index 00000000..0cc4c69a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/column_statistics.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/column_statistics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/distinct_statistics.hpp" + +namespace duckdb { +class Serializer; + +class ColumnStatistics { +public: + explicit ColumnStatistics(BaseStatistics stats_p); + ColumnStatistics(BaseStatistics stats_p, unique_ptr distinct_stats_p); + +public: + static shared_ptr CreateEmptyStats(const LogicalType &type); + + void Merge(ColumnStatistics &other); + + void UpdateDistinctStatistics(Vector &v, idx_t count); + + BaseStatistics &Statistics(); + + bool HasDistinctStats(); + DistinctStatistics &DistinctStats(); + void SetDistinct(unique_ptr distinct_stats); + + shared_ptr Copy() const; + + void Serialize(Serializer &serializer) const; + static shared_ptr Deserialize(Deserializer &source); + +private: + BaseStatistics stats; + //! The approximate count distinct stats of the column + unique_ptr distinct_stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/distinct_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/distinct_statistics.hpp new file mode 100644 index 00000000..329ce879 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/distinct_statistics.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/distinct_statistics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/types/hyperloglog.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { +class Vector; +class Serializer; +class Deserializer; + +class DistinctStatistics { +public: + DistinctStatistics(); + explicit DistinctStatistics(unique_ptr log, idx_t sample_count, idx_t total_count); + + //! The HLL of the table + unique_ptr log; + //! How many values have been sampled into the HLL + atomic sample_count; + //! How many values have been inserted (before sampling) + atomic total_count; + +public: + void Merge(const DistinctStatistics &other); + + unique_ptr Copy() const; + + void Update(Vector &update, idx_t count, bool sample = true); + void Update(UnifiedVectorFormat &update_data, const LogicalType &ptype, idx_t count, bool sample = true); + + string ToString() const; + idx_t GetCount() const; + + static bool TypeIsSupported(const LogicalType &type); + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! For distinct statistics we sample the input to speed up insertions + static constexpr const double SAMPLE_RATE = 0.1; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/list_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/list_stats.hpp new file mode 100644 index 00000000..8dfb2f43 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/list_stats.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/list_stats.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hugeint.hpp" + +namespace duckdb { +class BaseStatistics; +struct SelectionVector; +class Vector; + +struct ListStats { + DUCKDB_API static void Construct(BaseStatistics &stats); + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + + DUCKDB_API static const BaseStatistics &GetChildStats(const BaseStatistics &stats); + DUCKDB_API static BaseStatistics &GetChildStats(BaseStatistics &stats); + DUCKDB_API static void SetChildStats(BaseStatistics &stats, unique_ptr new_stats); + + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &base); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/node_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/node_statistics.hpp new file mode 100644 index 00000000..31b5c3dc --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/node_statistics.hpp @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/node_statistics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +class NodeStatistics { +public: + NodeStatistics() : has_estimated_cardinality(false), has_max_cardinality(false) { + } + explicit NodeStatistics(idx_t estimated_cardinality) + : has_estimated_cardinality(true), estimated_cardinality(estimated_cardinality), has_max_cardinality(false) { + } + NodeStatistics(idx_t estimated_cardinality, idx_t max_cardinality) + : has_estimated_cardinality(true), estimated_cardinality(estimated_cardinality), has_max_cardinality(true), + max_cardinality(max_cardinality) { + } + + //! Whether or not the node has an estimated cardinality specified + bool has_estimated_cardinality; + //! The estimated cardinality at the specified node + idx_t estimated_cardinality; + //! Whether or not the node has a maximum cardinality specified + bool has_max_cardinality; + //! The max possible cardinality at the specified node + idx_t max_cardinality; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats.hpp new file mode 100644 index 00000000..3d6f4490 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats.hpp @@ -0,0 +1,112 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/numeric_stats.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/statistics/numeric_stats_union.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { +class BaseStatistics; +struct SelectionVector; +class Vector; + +struct NumericStatsData { + //! Whether or not the value has a max value + bool has_min; + //! Whether or not the segment has a min value + bool has_max; + //! The minimum value of the segment + NumericValueUnion min; + //! The maximum value of the segment + NumericValueUnion max; +}; + +struct NumericStats { + //! Unknown statistics - i.e. "has_min" is false, "has_max" is false + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + //! Empty statistics - i.e. "min = MaxValue, max = MinValue" + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + + //! Returns true if the stats has a constant value + DUCKDB_API static bool IsConstant(const BaseStatistics &stats); + //! Returns true if the stats has both a min and max value defined + DUCKDB_API static bool HasMinMax(const BaseStatistics &stats); + //! Returns true if the stats has a min value defined + DUCKDB_API static bool HasMin(const BaseStatistics &stats); + //! Returns true if the stats has a max value defined + DUCKDB_API static bool HasMax(const BaseStatistics &stats); + //! Returns the min value - throws an exception if there is no min value + DUCKDB_API static Value Min(const BaseStatistics &stats); + //! Returns the max value - throws an exception if there is no max value + DUCKDB_API static Value Max(const BaseStatistics &stats); + //! Sets the min value of the statistics + DUCKDB_API static void SetMin(BaseStatistics &stats, const Value &val); + //! Sets the max value of the statistics + DUCKDB_API static void SetMax(BaseStatistics &stats, const Value &val); + + //! Check whether or not a given comparison with a constant could possibly be satisfied by rows given the statistics + DUCKDB_API static FilterPropagateResult CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, + const Value &constant); + + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other_p); + + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &stats); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + template + static inline void UpdateValue(T new_value, T &min, T &max) { + if (LessThan::Operation(new_value, min)) { + min = new_value; + } + if (GreaterThan::Operation(new_value, max)) { + max = new_value; + } + } + + template + static inline void Update(BaseStatistics &stats, T new_value) { + auto &nstats = NumericStats::GetDataUnsafe(stats); + UpdateValue(new_value, nstats.min.GetReferenceUnsafe(), nstats.max.GetReferenceUnsafe()); + } + + static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + + template + static T GetMin(const BaseStatistics &stats) { + return NumericStats::Min(stats).GetValueUnsafe(); + } + template + static T GetMax(const BaseStatistics &stats) { + return NumericStats::Max(stats).GetValueUnsafe(); + } + template + static T GetMinUnsafe(const BaseStatistics &stats); + template + static T GetMaxUnsafe(const BaseStatistics &stats); + +private: + static NumericStatsData &GetDataUnsafe(BaseStatistics &stats); + static const NumericStatsData &GetDataUnsafe(const BaseStatistics &stats); + static Value MinOrNull(const BaseStatistics &stats); + static Value MaxOrNull(const BaseStatistics &stats); + template + static void TemplatedVerify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); +}; + +template <> +void NumericStats::Update(BaseStatistics &stats, interval_t new_value); +template <> +void NumericStats::Update(BaseStatistics &stats, list_entry_t new_value); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats_union.hpp b/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats_union.hpp new file mode 100644 index 00000000..8006c206 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats_union.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/numeric_stats_union.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hugeint.hpp" + +namespace duckdb { + +struct NumericValueUnion { + union Val { + bool boolean; + int8_t tinyint; + int16_t smallint; + int32_t integer; + int64_t bigint; + uint8_t utinyint; + uint16_t usmallint; + uint32_t uinteger; + uint64_t ubigint; + hugeint_t hugeint; + float float_; + double double_; + } value_; + + template + T &GetReferenceUnsafe(); +}; + +template <> +DUCKDB_API bool &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API int8_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API int16_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API int32_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API int64_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API hugeint_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API uint8_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API uint16_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API uint32_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API uint64_t &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API float &NumericValueUnion::GetReferenceUnsafe(); +template <> +DUCKDB_API double &NumericValueUnion::GetReferenceUnsafe(); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/segment_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/segment_statistics.hpp new file mode 100644 index 00000000..4f1e0d63 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/segment_statistics.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/segment_statistics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +class SegmentStatistics { +public: + SegmentStatistics(LogicalType type); + SegmentStatistics(BaseStatistics statistics); + + //! Type-specific statistics of the segment + BaseStatistics statistics; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp new file mode 100644 index 00000000..ddef43f5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/string_stats.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" + +namespace duckdb { +class BaseStatistics; +struct SelectionVector; +class Vector; + +struct StringStatsData { + constexpr static uint32_t MAX_STRING_MINMAX_SIZE = 8; + + //! The minimum value of the segment, potentially truncated + data_t min[MAX_STRING_MINMAX_SIZE]; + //! The maximum value of the segment, potentially truncated + data_t max[MAX_STRING_MINMAX_SIZE]; + //! Whether or not the column can contain unicode characters + bool has_unicode; + //! Whether or not the maximum string length is known + bool has_max_string_length; + //! The maximum string length in bytes + uint32_t max_string_length; +}; + +struct StringStats { + //! Unknown statistics - i.e. "has_unicode" is true, "max_string_length" is unknown, "min" is \0, max is \xFF + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + //! Empty statistics - i.e. "has_unicode" is false, "max_string_length" is 0, "min" is \xFF, max is \x00 + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + //! Whether or not the statistics have a maximum string length defined + DUCKDB_API static bool HasMaxStringLength(const BaseStatistics &stats); + //! Returns the maximum string length, or throws an exception if !HasMaxStringLength() + DUCKDB_API static uint32_t MaxStringLength(const BaseStatistics &stats); + //! Whether or not the strings can contain unicode + DUCKDB_API static bool CanContainUnicode(const BaseStatistics &stats); + //! Returns the min value (up to a length of StringStatsData::MAX_STRING_MINMAX_SIZE) + DUCKDB_API static string Min(const BaseStatistics &stats); + //! Returns the max value (up to a length of StringStatsData::MAX_STRING_MINMAX_SIZE) + DUCKDB_API static string Max(const BaseStatistics &stats); + + //! Resets the max string length so HasMaxStringLength() is false + DUCKDB_API static void ResetMaxStringLength(BaseStatistics &stats); + //! FIXME: make this part of Set on statistics + DUCKDB_API static void SetContainsUnicode(BaseStatistics &stats); + + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &base); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + DUCKDB_API static FilterPropagateResult CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, + const string &value); + + DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + +private: + static StringStatsData &GetDataUnsafe(BaseStatistics &stats); + static const StringStatsData &GetDataUnsafe(const BaseStatistics &stats); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/struct_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/struct_stats.hpp new file mode 100644 index 00000000..38b992f7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/struct_stats.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/struct_stats.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { +class BaseStatistics; +struct SelectionVector; +class Vector; + +struct StructStats { + DUCKDB_API static void Construct(BaseStatistics &stats); + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + + DUCKDB_API static const BaseStatistics *GetChildStats(const BaseStatistics &stats); + DUCKDB_API static const BaseStatistics &GetChildStats(const BaseStatistics &stats, idx_t i); + DUCKDB_API static BaseStatistics &GetChildStats(BaseStatistics &stats, idx_t i); + DUCKDB_API static void SetChildStats(BaseStatistics &stats, idx_t i, const BaseStatistics &new_stats); + DUCKDB_API static void SetChildStats(BaseStatistics &stats, idx_t i, unique_ptr new_stats); + + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &base); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/storage_extension.hpp b/src/duckdb/src/include/duckdb/storage/storage_extension.hpp new file mode 100644 index 00000000..0eeeceec --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/storage_extension.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/storage_extension.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/access_mode.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" + +namespace duckdb { +class AttachedDatabase; +struct AttachInfo; +class Catalog; +class TransactionManager; + +//! The StorageExtensionInfo holds static information relevant to the storage extension +struct StorageExtensionInfo { + virtual ~StorageExtensionInfo() { + } +}; + +typedef unique_ptr (*attach_function_t)(StorageExtensionInfo *storage_info, AttachedDatabase &db, + const string &name, AttachInfo &info, AccessMode access_mode); +typedef unique_ptr (*create_transaction_manager_t)(StorageExtensionInfo *storage_info, + AttachedDatabase &db, Catalog &catalog); + +class StorageExtension { +public: + attach_function_t attach; + create_transaction_manager_t create_transaction_manager; + + //! Additional info passed to the various storage functions + shared_ptr storage_info; + + virtual ~StorageExtension() { + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/storage_info.hpp b/src/duckdb/src/include/duckdb/storage/storage_info.hpp new file mode 100644 index 00000000..77e02336 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/storage_info.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/storage_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector_size.hpp" + +namespace duckdb { +struct FileHandle; + +#define STANDARD_ROW_GROUPS_SIZE 122880 +#if STANDARD_ROW_GROUPS_SIZE < STANDARD_VECTOR_SIZE +#error Row groups should be able to hold at least one vector +#endif + +#if ((STANDARD_ROW_GROUPS_SIZE % STANDARD_VECTOR_SIZE) != 0) +#error Row group size should be cleanly divisible by vector size +#endif + +struct Storage { + //! The size of a hard disk sector, only really needed for Direct IO + constexpr static int SECTOR_SIZE = 4096; + //! Block header size for blocks written to the storage + constexpr static int BLOCK_HEADER_SIZE = sizeof(uint64_t); + // Size of a memory slot managed by the StorageManager. This is the quantum of allocation for Blocks on DuckDB. We + // default to 256KB. (1 << 18) + constexpr static int BLOCK_ALLOC_SIZE = 262144; + //! The actual memory space that is available within the blocks + constexpr static int BLOCK_SIZE = BLOCK_ALLOC_SIZE - BLOCK_HEADER_SIZE; + //! The size of the headers. This should be small and written more or less atomically by the hard disk. We default + //! to the page size, which is 4KB. (1 << 12) + constexpr static int FILE_HEADER_SIZE = 4096; + //! The number of rows per row group (must be a multiple of the vector size) + constexpr static const idx_t ROW_GROUP_SIZE = STANDARD_ROW_GROUPS_SIZE; + //! The number of vectors per row group + constexpr static const idx_t ROW_GROUP_VECTOR_COUNT = ROW_GROUP_SIZE / STANDARD_VECTOR_SIZE; +}; + +//! The version number of the database storage format +extern const uint64_t VERSION_NUMBER; + +const char *GetDuckDBVersion(idx_t version_number); + +using block_id_t = int64_t; + +#define INVALID_BLOCK (-1) + +// maximum block id, 2^62 +#define MAXIMUM_BLOCK 4611686018427388000LL + +//! The MainHeader is the first header in the storage file. The MainHeader is typically written only once for a database +//! file. +struct MainHeader { + static constexpr idx_t MAGIC_BYTE_SIZE = 4; + static constexpr idx_t MAGIC_BYTE_OFFSET = Storage::BLOCK_HEADER_SIZE; + static constexpr idx_t FLAG_COUNT = 4; + // the magic bytes in front of the file + // should be "DUCK" + static const char MAGIC_BYTES[]; + //! The version of the database + uint64_t version_number; + //! The set of flags used by the database + uint64_t flags[FLAG_COUNT]; + + static void CheckMagicBytes(FileHandle &handle); + + void Write(WriteStream &ser); + static MainHeader Read(ReadStream &source); +}; + +//! The DatabaseHeader contains information about the current state of the database. Every storage file has two +//! DatabaseHeaders. On startup, the DatabaseHeader with the highest iteration count is used as the active header. When +//! a checkpoint is performed, the active DatabaseHeader is switched by increasing the iteration count of the +//! DatabaseHeader. +struct DatabaseHeader { + //! The iteration count, increases by 1 every time the storage is checkpointed. + uint64_t iteration; + //! A pointer to the initial meta block + idx_t meta_block; + //! A pointer to the block containing the free list + idx_t free_list; + //! The number of blocks that is in the file as of this database header. If the file is larger than BLOCK_SIZE * + //! block_count any blocks appearing AFTER block_count are implicitly part of the free_list. + uint64_t block_count; + + void Write(WriteStream &ser); + static DatabaseHeader Read(ReadStream &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/storage_lock.hpp b/src/duckdb/src/include/duckdb/storage/storage_lock.hpp new file mode 100644 index 00000000..b8a6ceb8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/storage_lock.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/storage_lock.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/mutex.hpp" + +namespace duckdb { +class StorageLock; + +enum class StorageLockType { SHARED = 0, EXCLUSIVE = 1 }; + +class StorageLockKey { +public: + StorageLockKey(StorageLock &lock, StorageLockType type); + ~StorageLockKey(); + +private: + StorageLock &lock; + StorageLockType type; +}; + +class StorageLock { + friend class StorageLockKey; + +public: + StorageLock(); + + //! Get an exclusive lock + unique_ptr GetExclusiveLock(); + //! Get a shared lock + unique_ptr GetSharedLock(); + +private: + mutex exclusive_lock; + atomic read_count; + +private: + //! Release an exclusive lock + void ReleaseExclusiveLock(); + //! Release a shared lock + void ReleaseSharedLock(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp new file mode 100644 index 00000000..7bd92c29 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp @@ -0,0 +1,122 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/storage_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/helper.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table_io_manager.hpp" +#include "duckdb/storage/write_ahead_log.hpp" +#include "duckdb/storage/database_size.hpp" + +namespace duckdb { +class BlockManager; +class Catalog; +class CheckpointWriter; +class DatabaseInstance; +class TransactionManager; +class TableCatalogEntry; + +class StorageCommitState { +public: + // Destruction of this object, without prior call to FlushCommit, + // will roll back the committed changes. + virtual ~StorageCommitState() { + } + + // Make the commit persistent + virtual void FlushCommit() = 0; +}; + +//! StorageManager is responsible for managing the physical storage of the +//! database on disk +class StorageManager { +public: + StorageManager(AttachedDatabase &db, string path, bool read_only); + virtual ~StorageManager(); + +public: + static StorageManager &Get(AttachedDatabase &db); + static StorageManager &Get(Catalog &catalog); + + //! Initialize a database or load an existing database from the given path + void Initialize(); + + DatabaseInstance &GetDatabase(); + AttachedDatabase &GetAttached() { + return db; + } + + //! Get the WAL of the StorageManager, returns nullptr if in-memory + optional_ptr GetWriteAheadLog() { + return wal.get(); + } + + string GetDBPath() { + return path; + } + bool InMemory(); + + virtual bool AutomaticCheckpoint(idx_t estimated_wal_bytes) = 0; + virtual unique_ptr GenStorageCommitState(Transaction &transaction, bool checkpoint) = 0; + virtual bool IsCheckpointClean(MetaBlockPointer checkpoint_id) = 0; + virtual void CreateCheckpoint(bool delete_wal = false, bool force_checkpoint = false) = 0; + virtual DatabaseSize GetDatabaseSize() = 0; + virtual vector GetMetadataInfo() = 0; + virtual shared_ptr GetTableIOManager(BoundCreateTableInfo *info) = 0; + +protected: + virtual void LoadDatabase() = 0; + +protected: + //! The database this storagemanager belongs to + AttachedDatabase &db; + //! The path of the database + string path; + //! The WriteAheadLog of the storage manager + unique_ptr wal; + //! Whether or not the database is opened in read-only mode + bool read_only; + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +//! Stores database in a single file. +class SingleFileStorageManager : public StorageManager { +public: + SingleFileStorageManager(AttachedDatabase &db, string path, bool read_only); + + //! The BlockManager to read/store meta information and data in blocks + unique_ptr block_manager; + //! TableIoManager + unique_ptr table_io_manager; + +public: + bool AutomaticCheckpoint(idx_t estimated_wal_bytes) override; + unique_ptr GenStorageCommitState(Transaction &transaction, bool checkpoint) override; + bool IsCheckpointClean(MetaBlockPointer checkpoint_id) override; + void CreateCheckpoint(bool delete_wal, bool force_checkpoint) override; + DatabaseSize GetDatabaseSize() override; + vector GetMetadataInfo() override; + shared_ptr GetTableIOManager(BoundCreateTableInfo *info) override; + +protected: + void LoadDatabase() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp new file mode 100644 index 00000000..f5eab9b4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp @@ -0,0 +1,203 @@ +#pragma once + +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_size.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/checkpoint/string_checkpoint_state.hpp" +#include "duckdb/storage/segment/uncompressed.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/string_uncompressed.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/likely.hpp" + +namespace duckdb { +struct StringDictionaryContainer { + //! The size of the dictionary + uint32_t size; + //! The end of the dictionary (typically Storage::BLOCK_SIZE) + uint32_t end; + + void Verify() { + D_ASSERT(size <= Storage::BLOCK_SIZE); + D_ASSERT(end <= Storage::BLOCK_SIZE); + D_ASSERT(size <= end); + } +}; + +struct StringScanState : public SegmentScanState { + BufferHandle handle; +}; + +struct UncompressedStringStorage { +public: + //! Dictionary header size at the beginning of the string segment (offset + length) + static constexpr uint16_t DICTIONARY_HEADER_SIZE = sizeof(uint32_t) + sizeof(uint32_t); + //! Marker used in length field to indicate the presence of a big string + static constexpr uint16_t BIG_STRING_MARKER = (uint16_t)-1; + //! Base size of big string marker (block id + offset) + static constexpr idx_t BIG_STRING_MARKER_BASE_SIZE = sizeof(block_id_t) + sizeof(int32_t); + //! The marker size of the big string + static constexpr idx_t BIG_STRING_MARKER_SIZE = BIG_STRING_MARKER_BASE_SIZE; + //! The size below which the segment is compacted on flushing + static constexpr size_t COMPACTION_FLUSH_LIMIT = (size_t)Storage::BLOCK_SIZE / 5 * 4; + +public: + static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); + static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static idx_t StringFinalAnalyze(AnalyzeState &state_p); + static unique_ptr StringInitScan(ColumnSegment &segment); + static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset); + static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); + static void StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx); + static unique_ptr StringInitSegment(ColumnSegment &segment, block_id_t block_id, + optional_ptr segment_state); + + static unique_ptr StringInitAppend(ColumnSegment &segment) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + return make_uniq(std::move(handle)); + } + + static idx_t StringAppend(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, + UnifiedVectorFormat &data, idx_t offset, idx_t count) { + return StringAppendBase(append_state.handle, segment, stats, data, offset, count); + } + + static idx_t StringAppendBase(ColumnSegment &segment, SegmentStatistics &stats, UnifiedVectorFormat &data, + idx_t offset, idx_t count) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + return StringAppendBase(handle, segment, stats, data, offset, count); + } + + static idx_t StringAppendBase(BufferHandle &handle, ColumnSegment &segment, SegmentStatistics &stats, + UnifiedVectorFormat &data, idx_t offset, idx_t count) { + D_ASSERT(segment.GetBlockOffset() == 0); + auto handle_ptr = handle.Ptr(); + auto source_data = UnifiedVectorFormat::GetData(data); + auto result_data = (int32_t *)(handle_ptr + DICTIONARY_HEADER_SIZE); + uint32_t *dictionary_size = (uint32_t *)handle_ptr; + uint32_t *dictionary_end = (uint32_t *)(handle_ptr + sizeof(uint32_t)); + + idx_t remaining_space = RemainingSpace(segment, handle); + auto base_count = segment.count.load(); + for (idx_t i = 0; i < count; i++) { + auto source_idx = data.sel->get_index(offset + i); + auto target_idx = base_count + i; + if (remaining_space < sizeof(int32_t)) { + // string index does not fit in the block at all + segment.count += i; + return i; + } + remaining_space -= sizeof(int32_t); + if (!data.validity.RowIsValid(source_idx)) { + // null value is stored as a copy of the last value, this is done to be able to efficiently do the + // string_length calculation + if (target_idx > 0) { + result_data[target_idx] = result_data[target_idx - 1]; + } else { + result_data[target_idx] = 0; + } + continue; + } + auto end = handle.Ptr() + *dictionary_end; + +#ifdef DEBUG + GetDictionary(segment, handle).Verify(); +#endif + // Unknown string, continue + // non-null value, check if we can fit it within the block + idx_t string_length = source_data[source_idx].GetSize(); + + // determine whether or not we have space in the block for this string + bool use_overflow_block = false; + idx_t required_space = string_length; + if (DUCKDB_UNLIKELY(required_space >= StringUncompressed::STRING_BLOCK_LIMIT)) { + // string exceeds block limit, store in overflow block and only write a marker here + required_space = BIG_STRING_MARKER_SIZE; + use_overflow_block = true; + } + if (DUCKDB_UNLIKELY(required_space > remaining_space)) { + // no space remaining: return how many tuples we ended up writing + segment.count += i; + return i; + } + + // we have space: write the string + UpdateStringStats(stats, source_data[source_idx]); + + if (DUCKDB_UNLIKELY(use_overflow_block)) { + // write to overflow blocks + block_id_t block; + int32_t offset; + // write the string into the current string block + WriteString(segment, source_data[source_idx], block, offset); + *dictionary_size += BIG_STRING_MARKER_SIZE; + remaining_space -= BIG_STRING_MARKER_SIZE; + auto dict_pos = end - *dictionary_size; + + // write a big string marker into the dictionary + WriteStringMarker(dict_pos, block, offset); + + // place the dictionary offset into the set of vectors + // note: for overflow strings we write negative value + result_data[target_idx] = -(*dictionary_size); + } else { + // string fits in block, append to dictionary and increment dictionary position + D_ASSERT(string_length < NumericLimits::Maximum()); + *dictionary_size += required_space; + remaining_space -= required_space; + auto dict_pos = end - *dictionary_size; + // now write the actual string data into the dictionary + memcpy(dict_pos, source_data[source_idx].GetData(), string_length); + + // place the dictionary offset into the set of vectors + result_data[target_idx] = *dictionary_size; + } + D_ASSERT(RemainingSpace(segment, handle) <= Storage::BLOCK_SIZE); +#ifdef DEBUG + GetDictionary(segment, handle).Verify(); +#endif + } + segment.count += count; + return count; + } + + static idx_t FinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats); + +public: + static inline void UpdateStringStats(SegmentStatistics &stats, const string_t &new_value) { + StringStats::Update(stats.statistics, new_value); + } + + static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer dict); + static StringDictionaryContainer GetDictionary(ColumnSegment &segment, BufferHandle &handle); + static idx_t RemainingSpace(ColumnSegment &segment, BufferHandle &handle); + static void WriteString(ColumnSegment &segment, string_t string, block_id_t &result_block, int32_t &result_offset); + static void WriteStringMemory(ColumnSegment &segment, string_t string, block_id_t &result_block, + int32_t &result_offset); + static string_t ReadOverflowString(ColumnSegment &segment, Vector &result, block_id_t block, int32_t offset); + static string_t ReadString(data_ptr_t target, int32_t offset, uint32_t string_length); + static string_t ReadStringWithLength(data_ptr_t target, int32_t offset); + static void WriteStringMarker(data_ptr_t target, block_id_t block_id, int32_t offset); + static void ReadStringMarker(data_ptr_t target, block_id_t &block_id, int32_t &offset); + + static string_location_t FetchStringLocation(StringDictionaryContainer dict, data_ptr_t baseptr, + int32_t dict_offset); + static string_t FetchStringFromDict(ColumnSegment &segment, StringDictionaryContainer dict, Vector &result, + data_ptr_t baseptr, int32_t dict_offset, uint32_t string_length); + static string_t FetchString(ColumnSegment &segment, StringDictionaryContainer dict, Vector &result, + data_ptr_t baseptr, string_location_t location, uint32_t string_length); + + static unique_ptr SerializeState(ColumnSegment &segment); + static unique_ptr DeserializeState(Deserializer &deserializer); + static void CleanupState(ColumnSegment &segment); +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/append_state.hpp b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp new file mode 100644 index 00000000..382f86da --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp @@ -0,0 +1,79 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/append_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/storage_lock.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/transaction/transaction_data.hpp" + +namespace duckdb { +class ColumnSegment; +class DataTable; +class LocalTableStorage; +class RowGroup; +class UpdateSegment; + +struct TableAppendState; + +struct ColumnAppendState { + //! The current segment of the append + ColumnSegment *current; + //! Child append states + vector child_appends; + //! The write lock that is held by the append + unique_ptr lock; + //! The compression append state + unique_ptr append_state; +}; + +struct RowGroupAppendState { + RowGroupAppendState(TableAppendState &parent_p) : parent(parent_p) { + } + + //! The parent append state + TableAppendState &parent; + //! The current row_group we are appending to + RowGroup *row_group; + //! The column append states + unsafe_unique_array states; + //! Offset within the row_group + idx_t offset_in_row_group; +}; + +struct IndexLock { + unique_lock index_lock; +}; + +struct TableAppendState { + TableAppendState(); + ~TableAppendState(); + + RowGroupAppendState row_group_append_state; + unique_lock append_lock; + row_t row_start; + row_t current_row; + //! The total number of rows appended by the append operation + idx_t total_append_count; + //! The first row-group that has been appended to + RowGroup *start_row_group; + //! The transaction data + TransactionData transaction; + //! The remaining append count, only if the append count is known beforehand + idx_t remaining; +}; + +struct LocalAppendState { + TableAppendState append_state; + LocalTableStorage *storage; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp b/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp new file mode 100644 index 00000000..14ed981e --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp @@ -0,0 +1,148 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/chunk_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/vector_size.hpp" +#include "duckdb/common/atomic.hpp" + +namespace duckdb { +class RowGroup; +struct SelectionVector; +class Transaction; +struct TransactionData; + +class Serializer; +class Deserializer; + +enum class ChunkInfoType : uint8_t { CONSTANT_INFO, VECTOR_INFO, EMPTY_INFO }; + +class ChunkInfo { +public: + ChunkInfo(idx_t start, ChunkInfoType type) : start(start), type(type) { + } + virtual ~ChunkInfo() { + } + + //! The row index of the first row + idx_t start; + //! The ChunkInfo type + ChunkInfoType type; + +public: + //! Gets up to max_count entries from the chunk info. If the ret is 0>ret>max_count, the selection vector is filled + //! with the tuples + virtual idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) = 0; + virtual idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, + SelectionVector &sel_vector, idx_t max_count) = 0; + //! Returns whether or not a single row in the ChunkInfo should be used or not for the given transaction + virtual bool Fetch(TransactionData transaction, row_t row) = 0; + virtual void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) = 0; + virtual idx_t GetCommittedDeletedCount(idx_t max_count) = 0; + + virtual bool HasDeletes() const = 0; + + virtual void Write(WriteStream &writer) const; + static unique_ptr Read(ReadStream &reader); + +public: + template + TARGET &Cast() { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast chunk info to type - query result type mismatch"); + } + return reinterpret_cast(*this); + } + + template + const TARGET &Cast() const { + if (type != TARGET::TYPE) { + throw InternalException("Failed to cast chunk info to type - query result type mismatch"); + } + return reinterpret_cast(*this); + } +}; + +class ChunkConstantInfo : public ChunkInfo { +public: + static constexpr const ChunkInfoType TYPE = ChunkInfoType::CONSTANT_INFO; + +public: + explicit ChunkConstantInfo(idx_t start); + + transaction_t insert_id; + transaction_t delete_id; + +public: + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, + SelectionVector &sel_vector, idx_t max_count) override; + bool Fetch(TransactionData transaction, row_t row) override; + void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; + idx_t GetCommittedDeletedCount(idx_t max_count) override; + + bool HasDeletes() const override; + + void Write(WriteStream &writer) const override; + static unique_ptr Read(ReadStream &reader); + +private: + template + idx_t TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, + idx_t max_count) const; +}; + +class ChunkVectorInfo : public ChunkInfo { +public: + static constexpr const ChunkInfoType TYPE = ChunkInfoType::VECTOR_INFO; + +public: + explicit ChunkVectorInfo(idx_t start); + + //! The transaction ids of the transactions that inserted the tuples (if any) + transaction_t inserted[STANDARD_VECTOR_SIZE]; + transaction_t insert_id; + bool same_inserted_id; + + //! The transaction ids of the transactions that deleted the tuples (if any) + transaction_t deleted[STANDARD_VECTOR_SIZE]; + bool any_deleted; + +public: + idx_t GetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, + idx_t max_count) const; + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, + SelectionVector &sel_vector, idx_t max_count) override; + bool Fetch(TransactionData transaction, row_t row) override; + void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; + idx_t GetCommittedDeletedCount(idx_t max_count) override; + + void Append(idx_t start, idx_t end, transaction_t commit_id); + + //! Performs a delete in the ChunkVectorInfo - returns how many tuples were actually deleted + //! The number of rows that were actually deleted might be lower than the input count + //! In case we delete rows that were already deleted + //! Note that "rows" is written to to reflect the row ids that were actually deleted + //! i.e. after calling this function, rows will hold [0..actual_delete_count] row ids of the actually deleted tuples + idx_t Delete(transaction_t transaction_id, row_t rows[], idx_t count); + void CommitDelete(transaction_t commit_id, row_t rows[], idx_t count); + + bool HasDeletes() const override; + + void Write(WriteStream &writer) const override; + static unique_ptr Read(ReadStream &reader); + +private: + template + idx_t TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, + idx_t max_count) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp b/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp new file mode 100644 index 00000000..463e2fd9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/column_checkpoint_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/data_pointer.hpp" +#include "duckdb/storage/statistics/segment_statistics.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/storage/partial_block_manager.hpp" + +namespace duckdb { +class ColumnData; +class DatabaseInstance; +class RowGroup; +class PartialBlockManager; +class TableDataWriter; + +struct ColumnCheckpointState { + ColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager); + virtual ~ColumnCheckpointState(); + + RowGroup &row_group; + ColumnData &column_data; + ColumnSegmentTree new_tree; + vector data_pointers; + unique_ptr global_stats; + +protected: + PartialBlockManager &partial_block_manager; + +public: + virtual unique_ptr GetStatistics(); + + virtual void FlushSegment(unique_ptr segment, idx_t segment_size); + virtual void WriteDataPointers(RowGroupWriter &writer, Serializer &serializer); + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct PartialBlockForCheckpoint : public PartialBlock { + struct PartialColumnSegment { + PartialColumnSegment(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block) + : data(data), segment(segment), offset_in_block(offset_in_block) { + } + + ColumnData &data; + ColumnSegment &segment; + uint32_t offset_in_block; + }; + +public: + PartialBlockForCheckpoint(ColumnData &data, ColumnSegment &segment, PartialBlockState state, + BlockManager &block_manager); + ~PartialBlockForCheckpoint() override; + + // We will copy all segment data into the memory of the shared block. + // Once the block is full (or checkpoint is complete) we'll invoke Flush(). + // This will cause the block to get written to storage (via BlockManger::ConvertToPersistent), + // and all segments to have their references updated (via ColumnSegment::ConvertToPersistent) + vector segments; + +public: + bool IsFlushed(); + void Flush(const idx_t free_space_left) override; + void Merge(PartialBlock &other, idx_t offset, idx_t other_size) override; + void AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block); + void Clear() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp new file mode 100644 index 00000000..c278b3a5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp @@ -0,0 +1,173 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/column_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/data_pointer.hpp" +#include "duckdb/storage/table/persistent_table_data.hpp" +#include "duckdb/storage/statistics/segment_statistics.hpp" +#include "duckdb/storage/table/segment_tree.hpp" +#include "duckdb/storage/table/column_segment_tree.hpp" +#include "duckdb/common/mutex.hpp" + +namespace duckdb { +class ColumnData; +class ColumnSegment; +class DatabaseInstance; +class RowGroup; +class RowGroupWriter; +class TableDataWriter; +class TableStorageInfo; +struct TransactionData; + +struct DataTableInfo; + +struct ColumnCheckpointInfo { + explicit ColumnCheckpointInfo(CompressionType compression_type_p) : compression_type(compression_type_p) {}; + CompressionType compression_type; +}; + +class ColumnData { + friend class ColumnDataCheckpointer; + +public: + ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, LogicalType type, + optional_ptr parent); + virtual ~ColumnData(); + + //! The start row + idx_t start; + //! The count of the column data + idx_t count; + //! The block manager + BlockManager &block_manager; + //! Table info for the column + DataTableInfo &info; + //! The column index of the column, either within the parent table or within the parent + idx_t column_index; + //! The type of the column + LogicalType type; + //! The parent column (if any) + optional_ptr parent; + +public: + virtual bool CheckZonemap(ColumnScanState &state, TableFilter &filter) = 0; + + BlockManager &GetBlockManager() { + return block_manager; + } + DatabaseInstance &GetDatabase() const; + DataTableInfo &GetTableInfo() const; + virtual idx_t GetMaxEntry(); + + void IncrementVersion(); + + virtual void SetStart(idx_t new_start); + //! The root type of the column + const LogicalType &RootType() const; + + //! Initialize a scan of the column + virtual void InitializeScan(ColumnScanState &state); + //! Initialize a scan starting at the specified offset + virtual void InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx); + //! Scan the next vector from the column + virtual idx_t Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result); + virtual idx_t ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates); + virtual void ScanCommittedRange(idx_t row_group_start, idx_t offset_in_row_group, idx_t count, Vector &result); + virtual idx_t ScanCount(ColumnScanState &state, Vector &result, idx_t count); + //! Select + virtual void Select(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, + SelectionVector &sel, idx_t &count, const TableFilter &filter); + virtual void FilterScan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, + SelectionVector &sel, idx_t count); + virtual void FilterScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, SelectionVector &sel, + idx_t count, bool allow_updates); + + //! Skip the scan forward by "count" rows + virtual void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE); + + //! Initialize an appending phase for this column + virtual void InitializeAppend(ColumnAppendState &state); + //! Append a vector of type [type] to the end of the column + virtual void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count); + //! Append a vector of type [type] to the end of the column + void Append(ColumnAppendState &state, Vector &vector, idx_t count); + virtual void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count); + //! Revert a set of appends to the ColumnData + virtual void RevertAppend(row_t start_row); + + //! Fetch the vector from the column data that belongs to this specific row + virtual idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result); + //! Fetch a specific row id and append it to the vector + virtual void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx); + + virtual void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count); + virtual void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t depth); + virtual unique_ptr GetUpdateStatistics(); + + virtual void CommitDropColumn(); + + virtual unique_ptr CreateCheckpointState(RowGroup &row_group, + PartialBlockManager &partial_block_manager); + virtual unique_ptr + Checkpoint(RowGroup &row_group, PartialBlockManager &partial_block_manager, ColumnCheckpointInfo &checkpoint_info); + + virtual void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, + Vector &scan_vector); + + virtual void DeserializeColumn(Deserializer &deserializer); + static shared_ptr Deserialize(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + idx_t start_row, ReadStream &source, const LogicalType &type, + optional_ptr parent); + + virtual void GetColumnSegmentInfo(idx_t row_group_index, vector col_path, vector &result); + virtual void Verify(RowGroup &parent); + + bool CheckZonemap(TableFilter &filter); + + static shared_ptr CreateColumn(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + idx_t start_row, const LogicalType &type, + optional_ptr parent = nullptr); + static unique_ptr CreateColumnUnique(BlockManager &block_manager, DataTableInfo &info, + idx_t column_index, idx_t start_row, const LogicalType &type, + optional_ptr parent = nullptr); + + void MergeStatistics(const BaseStatistics &other); + void MergeIntoStatistics(BaseStatistics &other); + unique_ptr GetStatistics(); + +protected: + //! Append a transient segment + void AppendTransientSegment(SegmentLock &l, idx_t start_row); + + //! Scans a base vector from the column + idx_t ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, bool has_updates); + //! Scans a vector from the column merged with any potential updates + //! If ALLOW_UPDATES is set to false, the function will instead throw an exception if any updates are found + template + idx_t ScanVector(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result); + +protected: + //! The segments holding the data of this column segment + ColumnSegmentTree data; + //! The lock for the updates + mutex update_lock; + //! The updates for this column segment + unique_ptr updates; + //! The internal version of the column data + idx_t version; + //! The stats of the root segment + unique_ptr stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp new file mode 100644 index 00000000..aedb2ee8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/column_data_checkpointer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" + +namespace duckdb { + +class ColumnDataCheckpointer { +public: + ColumnDataCheckpointer(ColumnData &col_data_p, RowGroup &row_group_p, ColumnCheckpointState &state_p, + ColumnCheckpointInfo &checkpoint_info); + +public: + DatabaseInstance &GetDatabase(); + const LogicalType &GetType() const; + ColumnData &GetColumnData(); + RowGroup &GetRowGroup(); + ColumnCheckpointState &GetCheckpointState(); + + void Checkpoint(vector> nodes); + CompressionFunction &GetCompressionFunction(CompressionType type); + +private: + void ScanSegments(const std::function &callback); + unique_ptr DetectBestCompressionMethod(idx_t &compression_idx); + void WriteToDisk(); + bool HasChanges(); + void WritePersistentSegments(); + +private: + ColumnData &col_data; + RowGroup &row_group; + ColumnCheckpointState &state; + bool is_validity; + Vector intermediate; + vector> nodes; + vector> compression_functions; + ColumnCheckpointInfo &checkpoint_info; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp new file mode 100644 index 00000000..89a9e2b5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp @@ -0,0 +1,150 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/column_segment.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/block.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/statistics/segment_statistics.hpp" +#include "duckdb/storage/storage_lock.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/table/segment_base.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" + +namespace duckdb { +class ColumnSegment; +class BlockManager; +class ColumnSegment; +class ColumnData; +class DatabaseInstance; +class Transaction; +class BaseStatistics; +class UpdateSegment; +class TableFilter; +struct ColumnFetchState; +struct ColumnScanState; +struct ColumnAppendState; + +enum class ColumnSegmentType : uint8_t { TRANSIENT, PERSISTENT }; +//! TableFilter represents a filter pushed down into the table scan. + +class ColumnSegment : public SegmentBase { +public: + ~ColumnSegment(); + + //! The database instance + DatabaseInstance &db; + //! The type stored in the column + LogicalType type; + //! The size of the type + idx_t type_size; + //! The column segment type (transient or persistent) + ColumnSegmentType segment_type; + //! The compression function + reference function; + //! The statistics for the segment + SegmentStatistics stats; + //! The block that this segment relates to + shared_ptr block; + + static unique_ptr CreatePersistentSegment(DatabaseInstance &db, BlockManager &block_manager, + block_id_t id, idx_t offset, const LogicalType &type_p, + idx_t start, idx_t count, CompressionType compression_type, + BaseStatistics statistics, + unique_ptr segment_state); + static unique_ptr CreateTransientSegment(DatabaseInstance &db, const LogicalType &type, idx_t start, + idx_t segment_size = Storage::BLOCK_SIZE); + static unique_ptr CreateSegment(ColumnSegment &other, idx_t start); + +public: + void InitializeScan(ColumnScanState &state); + //! Scan one vector from this segment + void Scan(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset, bool entire_vector); + //! Fetch a value of the specific row id and append it to the result + void FetchRow(ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx); + + static idx_t FilterSelection(SelectionVector &sel, Vector &result, const TableFilter &filter, + idx_t &approved_tuple_count, ValidityMask &mask); + + //! Skip a scan forward to the row_index specified in the scan state + void Skip(ColumnScanState &state); + + // The maximum size of the buffer (in bytes) + idx_t SegmentSize() const; + //! Resize the block + void Resize(idx_t segment_size); + + //! Initialize an append of this segment. Appends are only supported on transient segments. + void InitializeAppend(ColumnAppendState &state); + //! Appends a (part of) vector to the segment, returns the amount of entries successfully appended + idx_t Append(ColumnAppendState &state, UnifiedVectorFormat &data, idx_t offset, idx_t count); + //! Finalize the segment for appending - no more appends can follow on this segment + //! The segment should be compacted as much as possible + //! Returns the number of bytes occupied within the segment + idx_t FinalizeAppend(ColumnAppendState &state); + //! Revert an append made to this segment + void RevertAppend(idx_t start_row); + + //! Convert a transient in-memory segment into a persistent segment blocked by an on-disk block. + //! Only used during checkpointing. + void ConvertToPersistent(optional_ptr block_manager, block_id_t block_id); + //! Updates pointers to refer to the given block and offset. This is only used + //! when sharing a block among segments. This is invoked only AFTER the block is written. + void MarkAsPersistent(shared_ptr block, uint32_t offset_in_block); + + block_id_t GetBlockId() { + D_ASSERT(segment_type == ColumnSegmentType::PERSISTENT); + return block_id; + } + + BlockManager &GetBlockManager() const { + return block->block_manager; + } + + idx_t GetBlockOffset() { + D_ASSERT(segment_type == ColumnSegmentType::PERSISTENT || offset == 0); + return offset; + } + + idx_t GetRelativeIndex(idx_t row_index) { + D_ASSERT(row_index >= this->start); + D_ASSERT(row_index <= this->start + this->count); + return row_index - this->start; + } + + optional_ptr GetSegmentState() { + return segment_state.get(); + } + + void CommitDropSegment(); + +public: + ColumnSegment(DatabaseInstance &db, shared_ptr block, LogicalType type, ColumnSegmentType segment_type, + idx_t start, idx_t count, CompressionFunction &function, BaseStatistics statistics, + block_id_t block_id, idx_t offset, idx_t segment_size, + unique_ptr segment_state = nullptr); + ColumnSegment(ColumnSegment &other, idx_t start); + +private: + void Scan(ColumnScanState &state, idx_t scan_count, Vector &result); + void ScanPartial(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); + +private: + //! The block id that this segment relates to (persistent segment only) + block_id_t block_id; + //! The offset into the block (persistent segment only) + idx_t offset; + //! The allocated segment size + idx_t segment_size; + //! Storage associated with the compressed segment + unique_ptr segment_state; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/column_segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/column_segment_tree.hpp new file mode 100644 index 00000000..d1fa796b --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/column_segment_tree.hpp @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/column_segment_tree.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/segment_tree.hpp" +#include "duckdb/storage/table/column_segment.hpp" + +namespace duckdb { + +class ColumnSegmentTree : public SegmentTree {}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp new file mode 100644 index 00000000..f215cfcb --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/data_table_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/storage/table/table_index_list.hpp" + +namespace duckdb { +class DatabaseInstance; +class TableIOManager; + +struct DataTableInfo { + DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, string schema, string table); + + //! The database instance of the table + AttachedDatabase &db; + //! The table IO manager + shared_ptr table_io_manager; + //! The amount of elements in the table. Note that this number signifies the amount of COMMITTED entries in the + //! table. It can be inaccurate inside of transactions. More work is needed to properly support that. + atomic cardinality; + // schema of the table + string schema; + // name of the table + string table; + + TableIndexList indexes; + + bool IsTemporary() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp new file mode 100644 index 00000000..6001e703 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/list_column_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/validity_column_data.hpp" + +namespace duckdb { + +//! List column data represents a list +class ListColumnData : public ColumnData { +public: + ListColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, + LogicalType type, optional_ptr parent = nullptr); + + //! The child-column of the list + unique_ptr child_column; + //! The validity column data of the struct + ValidityColumnData validity; + +public: + void SetStart(idx_t new_start) override; + bool CheckZonemap(ColumnScanState &state, TableFilter &filter) override; + + void InitializeScan(ColumnScanState &state) override; + void InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) override; + + idx_t Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) override; + idx_t ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) override; + idx_t ScanCount(ColumnScanState &state, Vector &result, idx_t count) override; + + void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; + + void InitializeAppend(ColumnAppendState &state) override; + void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void RevertAppend(row_t start_row) override; + idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; + void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) override; + void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count) override; + void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t depth) override; + unique_ptr GetUpdateStatistics() override; + + void CommitDropColumn() override; + + unique_ptr CreateCheckpointState(RowGroup &row_group, + PartialBlockManager &partial_block_manager) override; + unique_ptr Checkpoint(RowGroup &row_group, PartialBlockManager &partial_block_manager, + ColumnCheckpointInfo &checkpoint_info) override; + + void DeserializeColumn(Deserializer &deserializer) override; + + void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, + vector &result) override; + +private: + uint64_t FetchListOffset(idx_t row_idx); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp b/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp new file mode 100644 index 00000000..f5eb6d9a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/persistent_table_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/storage/data_pointer.hpp" +#include "duckdb/storage/table/table_statistics.hpp" +#include "duckdb/storage/metadata/metadata_manager.hpp" + +namespace duckdb { +class BaseStatistics; + +class PersistentTableData { +public: + explicit PersistentTableData(idx_t column_count); + ~PersistentTableData(); + + TableStatistics table_stats; + idx_t total_rows; + idx_t row_group_count; + MetaBlockPointer block_pointer; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp new file mode 100644 index 00000000..5b416ba1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp @@ -0,0 +1,175 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/row_group.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/vector_size.hpp" +#include "duckdb/storage/table/chunk_info.hpp" +#include "duckdb/storage/statistics/segment_statistics.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/enums/scan_options.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/parser/column_list.hpp" +#include "duckdb/storage/table/segment_base.hpp" +#include "duckdb/storage/block.hpp" + +namespace duckdb { +class AttachedDatabase; +class BlockManager; +class ColumnData; +class DatabaseInstance; +class DataTable; +class PartialBlockManager; +struct DataTableInfo; +class ExpressionExecutor; +class RowGroupCollection; +class RowGroupWriter; +class UpdateSegment; +class TableStatistics; +struct ColumnSegmentInfo; +class Vector; +struct ColumnCheckpointState; +struct RowGroupPointer; +struct TransactionData; +class CollectionScanState; +class TableFilterSet; +struct ColumnFetchState; +struct RowGroupAppendState; +class MetadataManager; +class RowVersionManager; + +struct RowGroupWriteData { + vector> states; + vector statistics; +}; + +class RowGroup : public SegmentBase { +public: + friend class ColumnData; + +public: + RowGroup(RowGroupCollection &collection, idx_t start, idx_t count); + RowGroup(RowGroupCollection &collection, RowGroupPointer &&pointer); + ~RowGroup(); + +private: + //! The RowGroupCollection this row-group is a part of + reference collection; + //! The version info of the row_group (inserted and deleted tuple info) + shared_ptr version_info; + //! The column data of the row_group + vector> columns; + +public: + void MoveToCollection(RowGroupCollection &collection, idx_t new_start); + RowGroupCollection &GetCollection() { + return collection.get(); + } + BlockManager &GetBlockManager(); + DataTableInfo &GetTableInfo(); + + unique_ptr AlterType(RowGroupCollection &collection, const LogicalType &target_type, idx_t changed_idx, + ExpressionExecutor &executor, CollectionScanState &scan_state, + DataChunk &scan_chunk); + unique_ptr AddColumn(RowGroupCollection &collection, ColumnDefinition &new_column, + ExpressionExecutor &executor, Expression &default_value, Vector &intermediate); + unique_ptr RemoveColumn(RowGroupCollection &collection, idx_t removed_column); + + void CommitDrop(); + void CommitDropColumn(idx_t index); + + void InitializeEmpty(const vector &types); + + //! Initialize a scan over this row_group + bool InitializeScan(CollectionScanState &state); + bool InitializeScanWithOffset(CollectionScanState &state, idx_t vector_offset); + //! Checks the given set of table filters against the row-group statistics. Returns false if the entire row group + //! can be skipped. + bool CheckZonemap(TableFilterSet &filters, const vector &column_ids); + //! Checks the given set of table filters against the per-segment statistics. Returns false if any segments were + //! skipped. + bool CheckZonemapSegments(CollectionScanState &state); + void Scan(TransactionData transaction, CollectionScanState &state, DataChunk &result); + void ScanCommitted(CollectionScanState &state, DataChunk &result, TableScanType type); + + idx_t GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); + idx_t GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, + SelectionVector &sel_vector, idx_t max_count); + + //! For a specific row, returns true if it should be used for the transaction and false otherwise. + bool Fetch(TransactionData transaction, idx_t row); + //! Fetch a specific row from the row_group and insert it into the result at the specified index + void FetchRow(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, + row_t row_id, DataChunk &result, idx_t result_idx); + + //! Append count rows to the version info + void AppendVersionInfo(TransactionData transaction, idx_t count); + //! Commit a previous append made by RowGroup::AppendVersionInfo + void CommitAppend(transaction_t commit_id, idx_t start, idx_t count); + //! Revert a previous append made by RowGroup::AppendVersionInfo + void RevertAppend(idx_t start); + + //! Delete the given set of rows in the version manager + idx_t Delete(TransactionData transaction, DataTable &table, row_t *row_ids, idx_t count); + + RowGroupWriteData WriteToDisk(PartialBlockManager &manager, const vector &compression_types); + bool AllDeleted(); + RowGroupPointer Checkpoint(RowGroupWriter &writer, TableStatistics &global_stats); + + void InitializeAppend(RowGroupAppendState &append_state); + void Append(RowGroupAppendState &append_state, DataChunk &chunk, idx_t append_count); + + void Update(TransactionData transaction, DataChunk &updates, row_t *ids, idx_t offset, idx_t count, + const vector &column_ids); + //! Update a single column; corresponds to DataTable::UpdateColumn + //! This method should only be called from the WAL + void UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, + const vector &column_path); + + void MergeStatistics(idx_t column_idx, const BaseStatistics &other); + void MergeIntoStatistics(idx_t column_idx, BaseStatistics &other); + unique_ptr GetStatistics(idx_t column_idx); + + void GetColumnSegmentInfo(idx_t row_group_index, vector &result); + + void Verify(); + + void NextVector(CollectionScanState &state); + + idx_t DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count); + RowVersionManager &GetOrCreateVersionInfo(); + + // Serialization + static void Serialize(RowGroupPointer &pointer, Serializer &serializer); + static RowGroupPointer Deserialize(Deserializer &deserializer); + +private: + shared_ptr &GetVersionInfo(); + shared_ptr &GetOrCreateVersionInfoPtr(); + + ColumnData &GetColumn(storage_t c); + idx_t GetColumnCount() const; + vector> &GetColumns(); + + template + void TemplatedScan(TransactionData transaction, CollectionScanState &state, DataChunk &result); + + vector CheckpointDeletes(MetadataManager &manager); + + bool HasUnloadedDeletes() const; + +private: + mutex row_group_lock; + mutex stats_lock; + vector column_pointers; + unique_ptr[]> is_loaded; + vector deletes_pointers; + atomic deletes_is_loaded; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp new file mode 100644 index 00000000..f5bdc86b --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp @@ -0,0 +1,134 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/row_group_collection.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/storage/table/segment_tree.hpp" +#include "duckdb/storage/statistics/column_statistics.hpp" +#include "duckdb/storage/table/table_statistics.hpp" + +namespace duckdb { +struct ParallelTableScanState; +struct ParallelCollectionScanState; +class CreateIndexScanState; +class CollectionScanState; +class PersistentTableData; +class TableDataWriter; +class TableIndexList; +class TableStatistics; +struct TableAppendState; +class DuckTransaction; +class BoundConstraint; +class RowGroupSegmentTree; +struct ColumnSegmentInfo; +class MetadataManager; + +class RowGroupCollection { +public: + RowGroupCollection(shared_ptr info, BlockManager &block_manager, vector types, + idx_t row_start, idx_t total_rows = 0); + +public: + idx_t GetTotalRows() const; + Allocator &GetAllocator() const; + + void Initialize(PersistentTableData &data); + void InitializeEmpty(); + + bool IsEmpty() const; + + void AppendRowGroup(SegmentLock &l, idx_t start_row); + //! Get the nth row-group, negative numbers start from the back (so -1 is the last row group, etc) + RowGroup *GetRowGroup(int64_t index); + void Verify(); + + void InitializeScan(CollectionScanState &state, const vector &column_ids, TableFilterSet *table_filters); + void InitializeCreateIndexScan(CreateIndexScanState &state); + void InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, idx_t start_row, + idx_t end_row); + static bool InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, + RowGroup &row_group, idx_t vector_index, idx_t max_row); + void InitializeParallelScan(ParallelCollectionScanState &state); + bool NextParallelScan(ClientContext &context, ParallelCollectionScanState &state, CollectionScanState &scan_state); + + bool Scan(DuckTransaction &transaction, const vector &column_ids, + const std::function &fun); + bool Scan(DuckTransaction &transaction, const std::function &fun); + + void Fetch(TransactionData transaction, DataChunk &result, const vector &column_ids, + const Vector &row_identifiers, idx_t fetch_count, ColumnFetchState &state); + + //! Initialize an append of a variable number of rows. FinalizeAppend must be called after appending is done. + void InitializeAppend(TableAppendState &state); + //! Initialize an append with a known number of rows. FinalizeAppend should not be called after appending is done. + void InitializeAppend(TransactionData transaction, TableAppendState &state, idx_t append_count); + //! Appends to the row group collection. Returns true if a new row group has been created to append to + bool Append(DataChunk &chunk, TableAppendState &state); + //! FinalizeAppend flushes an append with a variable number of rows. + void FinalizeAppend(TransactionData transaction, TableAppendState &state); + void CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count); + void RevertAppendInternal(idx_t start_row); + + void MergeStorage(RowGroupCollection &data); + + void RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count); + + idx_t Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count); + void Update(TransactionData transaction, row_t *ids, const vector &column_ids, DataChunk &updates); + void UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, + DataChunk &updates); + + void Checkpoint(TableDataWriter &writer, TableStatistics &global_stats); + + void CommitDropColumn(idx_t index); + void CommitDropTable(); + + vector GetColumnSegmentInfo(); + const vector &GetTypes() const; + + shared_ptr AddColumn(ClientContext &context, ColumnDefinition &new_column, + Expression &default_value); + shared_ptr RemoveColumn(idx_t col_idx); + shared_ptr AlterType(ClientContext &context, idx_t changed_idx, const LogicalType &target_type, + vector bound_columns, Expression &cast_expr); + void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + + void CopyStats(TableStatistics &stats); + unique_ptr CopyStats(column_t column_id); + void SetDistinct(column_t column_id, unique_ptr distinct_stats); + + AttachedDatabase &GetAttached(); + BlockManager &GetBlockManager() { + return block_manager; + } + MetadataManager &GetMetadataManager(); + DataTableInfo &GetTableInfo() { + return *info; + } + +private: + bool IsEmpty(SegmentLock &) const; + +private: + //! BlockManager + BlockManager &block_manager; + //! The number of rows in the table + atomic total_rows; + //! The data table info + shared_ptr info; + //! The column types of the row group collection + vector types; + idx_t row_start; + //! The segment trees holding the various row_groups of the table + shared_ptr row_groups; + //! Table statistics + TableStatistics stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp new file mode 100644 index 00000000..e715c1a5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/row_group_segment_tree.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/segment_tree.hpp" +#include "duckdb/storage/table/row_group.hpp" + +namespace duckdb { +struct DataTableInfo; +class PersistentTableData; +class MetadataReader; + +class RowGroupSegmentTree : public SegmentTree { +public: + RowGroupSegmentTree(RowGroupCollection &collection); + ~RowGroupSegmentTree() override; + + void Initialize(PersistentTableData &data); + +protected: + unique_ptr LoadSegment() override; + + RowGroupCollection &collection; + idx_t current_row_group; + idx_t max_row_group; + unique_ptr reader; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp new file mode 100644 index 00000000..0763513b --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/row_version_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/vector_size.hpp" +#include "duckdb/storage/table/chunk_info.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/mutex.hpp" + +namespace duckdb { + +class MetadataManager; +struct MetaBlockPointer; + +class RowVersionManager { +public: + explicit RowVersionManager(idx_t start); + + idx_t GetStart() { + return start; + } + void SetStart(idx_t start); + idx_t GetCommittedDeletedCount(idx_t count); + + idx_t GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); + idx_t GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, + SelectionVector &sel_vector, idx_t max_count); + bool Fetch(TransactionData transaction, idx_t row); + + void AppendVersionInfo(TransactionData transaction, idx_t count, idx_t row_group_start, idx_t row_group_end); + void CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_t count); + void RevertAppend(idx_t start_row); + + idx_t DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count); + void CommitDelete(idx_t vector_idx, transaction_t commit_id, row_t rows[], idx_t count); + + vector Checkpoint(MetadataManager &manager); + static shared_ptr Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager, + idx_t start); + +private: + mutex version_lock; + idx_t start; + unique_ptr vector_info[Storage::ROW_GROUP_VECTOR_COUNT]; + bool has_changes; + vector storage_pointers; + +private: + optional_ptr GetChunkInfo(idx_t vector_idx); + ChunkVectorInfo &GetVectorInfo(idx_t vector_idx); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp new file mode 100644 index 00000000..cefa13a2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp @@ -0,0 +1,197 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/scan_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" +#include "duckdb/storage/storage_lock.hpp" +#include "duckdb/common/enums/scan_options.hpp" +#include "duckdb/execution/adaptive_filter.hpp" +#include "duckdb/storage/table/segment_lock.hpp" + +namespace duckdb { +class ColumnSegment; +class LocalTableStorage; +class CollectionScanState; +class Index; +class RowGroup; +class RowGroupCollection; +class UpdateSegment; +class TableScanState; +class ColumnSegment; +class ColumnSegmentTree; +class ValiditySegment; +class TableFilterSet; +class ColumnData; +class DuckTransaction; +class RowGroupSegmentTree; + +struct SegmentScanState { + virtual ~SegmentScanState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +struct IndexScanState { + virtual ~IndexScanState() { + } + + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +typedef unordered_map buffer_handle_set_t; + +struct ColumnScanState { + //! The column segment that is currently being scanned + ColumnSegment *current = nullptr; + //! Column segment tree + ColumnSegmentTree *segment_tree = nullptr; + //! The current row index of the scan + idx_t row_index = 0; + //! The internal row index (i.e. the position of the SegmentScanState) + idx_t internal_index = 0; + //! Segment scan state + unique_ptr scan_state; + //! Child states of the vector + vector child_states; + //! Whether or not InitializeState has been called for this segment + bool initialized = false; + //! If this segment has already been checked for skipping purposes + bool segment_checked = false; + //! The version of the column data that we are scanning. + //! This is used to detect if the ColumnData has been changed out from under us during a scan + //! If this is the case, we re-initialize the scan + idx_t version = 0; + //! We initialize one SegmentScanState per segment, however, if scanning a DataChunk requires us to scan over more + //! than one Segment, we need to keep the scan states of the previous segments around + vector> previous_states; + //! The last read offset in the child state (used for LIST columns only) + idx_t last_offset = 0; + +public: + void Initialize(const LogicalType &type); + //! Move the scan state forward by "count" rows (including all child states) + void Next(idx_t count); + //! Move ONLY this state forward by "count" rows (i.e. not the child states) + void NextInternal(idx_t count); +}; + +struct ColumnFetchState { + //! The set of pinned block handles for this set of fetches + buffer_handle_set_t handles; + //! Any child states of the fetch + vector> child_states; + + BufferHandle &GetOrInsertHandle(ColumnSegment &segment); +}; + +class CollectionScanState { +public: + CollectionScanState(TableScanState &parent_p); + + //! The current row_group we are scanning + RowGroup *row_group; + //! The vector index within the row_group + idx_t vector_index; + //! The maximum row within the row group + idx_t max_row_group_row; + //! Child column scans + unsafe_unique_array column_scans; + //! Row group segment tree + RowGroupSegmentTree *row_groups; + //! The total maximum row index + idx_t max_row; + //! The current batch index + idx_t batch_index; + +public: + void Initialize(const vector &types); + const vector &GetColumnIds(); + TableFilterSet *GetFilters(); + AdaptiveFilter *GetAdaptiveFilter(); + bool Scan(DuckTransaction &transaction, DataChunk &result); + bool ScanCommitted(DataChunk &result, TableScanType type); + bool ScanCommitted(DataChunk &result, SegmentLock &l, TableScanType type); + +private: + TableScanState &parent; +}; + +class TableScanState { +public: + TableScanState() : table_state(*this), local_state(*this), table_filters(nullptr) {}; + + //! The underlying table scan state + CollectionScanState table_state; + //! Transaction-local scan state + CollectionScanState local_state; + +public: + void Initialize(vector column_ids, TableFilterSet *table_filters = nullptr); + + const vector &GetColumnIds(); + TableFilterSet *GetFilters(); + AdaptiveFilter *GetAdaptiveFilter(); + +private: + //! The column identifiers of the scan + vector column_ids; + //! The table filters (if any) + TableFilterSet *table_filters; + //! Adaptive filter info (if any) + unique_ptr adaptive_filter; +}; + +struct ParallelCollectionScanState { + ParallelCollectionScanState(); + + //! The row group collection we are scanning + RowGroupCollection *collection; + RowGroup *current_row_group; + idx_t vector_index; + idx_t max_row; + idx_t batch_index; + atomic processed_rows; + mutex lock; +}; + +struct ParallelTableScanState { + //! Parallel scan state for the table + ParallelCollectionScanState scan_state; + //! Parallel scan state for the transaction-local state + ParallelCollectionScanState local_state; +}; + +class CreateIndexScanState : public TableScanState { +public: + vector> locks; + unique_lock append_lock; + SegmentLock segment_lock; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp new file mode 100644 index 00000000..b71587bf --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/segment_base.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/atomic.hpp" + +namespace duckdb { + +template +class SegmentBase { +public: + SegmentBase(idx_t start, idx_t count) : start(start), count(count), next(nullptr) { + } + T *Next() { +#ifndef DUCKDB_R_BUILD + return next.load(); +#else + return next; +#endif + } + + //! The start row id of this chunk + idx_t start; + //! The amount of entries in this storage chunk + atomic count; + //! The next segment after this one +#ifndef DUCKDB_R_BUILD + atomic next; +#else + T *next; +#endif + //! The index within the segment tree + idx_t index; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_lock.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_lock.hpp new file mode 100644 index 00000000..88840437 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/segment_lock.hpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/segment_lock.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/mutex.hpp" + +namespace duckdb { + +struct SegmentLock { +public: + SegmentLock() { + } + SegmentLock(mutex &lock) : lock(lock) { + } + // disable copy constructors + SegmentLock(const SegmentLock &other) = delete; + SegmentLock &operator=(const SegmentLock &) = delete; + //! enable move constructors + SegmentLock(SegmentLock &&other) noexcept { + std::swap(lock, other.lock); + } + SegmentLock &operator=(SegmentLock &&other) noexcept { + std::swap(lock, other.lock); + return *this; + } + +private: + unique_lock lock; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp new file mode 100644 index 00000000..207c1481 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp @@ -0,0 +1,357 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/segment_tree.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/storage/storage_lock.hpp" +#include "duckdb/storage/table/segment_lock.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +template +struct SegmentNode { + idx_t row_start; + unique_ptr node; +}; + +//! The SegmentTree maintains a list of all segments of a specific column in a table, and allows searching for a segment +//! by row number +template +class SegmentTree { +private: + class SegmentIterationHelper; + +public: + explicit SegmentTree() : finished_loading(true) { + } + virtual ~SegmentTree() { + } + + //! Locks the segment tree. All methods to the segment tree either lock the segment tree, or take an already + //! obtained lock. + SegmentLock Lock() { + return SegmentLock(node_lock); + } + + bool IsEmpty(SegmentLock &l) { + return GetRootSegment(l) == nullptr; + } + + //! Gets a pointer to the first segment. Useful for scans. + T *GetRootSegment() { + auto l = Lock(); + return GetRootSegment(l); + } + + T *GetRootSegment(SegmentLock &l) { + if (nodes.empty()) { + LoadNextSegment(l); + } + return GetRootSegmentInternal(); + } + //! Obtains ownership of the data of the segment tree + vector> MoveSegments(SegmentLock &l) { + LoadAllSegments(l); + return std::move(nodes); + } + vector> MoveSegments() { + auto l = Lock(); + return MoveSegments(l); + } + idx_t GetSegmentCount() { + auto l = Lock(); + return nodes.size(); + } + //! Gets a pointer to the nth segment. Negative numbers start from the back. + T *GetSegmentByIndex(int64_t index) { + auto l = Lock(); + return GetSegmentByIndex(l, index); + } + T *GetSegmentByIndex(SegmentLock &l, int64_t index) { + if (index < 0) { + // load all segments + LoadAllSegments(l); + index = nodes.size() + index; + if (index < 0) { + return nullptr; + } + return nodes[index].node.get(); + } else { + // lazily load segments until we reach the specific segment + while (idx_t(index) >= nodes.size() && LoadNextSegment(l)) { + } + if (idx_t(index) >= nodes.size()) { + return nullptr; + } + return nodes[index].node.get(); + } + } + //! Gets the next segment + T *GetNextSegment(T *segment) { + if (!SUPPORTS_LAZY_LOADING) { + return segment->Next(); + } + if (finished_loading) { + return segment->Next(); + } + auto l = Lock(); + return GetNextSegment(l, segment); + } + T *GetNextSegment(SegmentLock &l, T *segment) { + if (!segment) { + return nullptr; + } +#ifdef DEBUG + D_ASSERT(nodes[segment->index].node.get() == segment); +#endif + return GetSegmentByIndex(l, segment->index + 1); + } + + //! Gets a pointer to the last segment. Useful for appends. + T *GetLastSegment(SegmentLock &l) { + LoadAllSegments(l); + if (nodes.empty()) { + return nullptr; + } + return nodes.back().node.get(); + } + //! Gets a pointer to a specific column segment for the given row + T *GetSegment(idx_t row_number) { + auto l = Lock(); + return GetSegment(l, row_number); + } + T *GetSegment(SegmentLock &l, idx_t row_number) { + return nodes[GetSegmentIndex(l, row_number)].node.get(); + } + + //! Append a column segment to the tree + void AppendSegmentInternal(SegmentLock &l, unique_ptr segment) { + D_ASSERT(segment); + // add the node to the list of nodes + if (!nodes.empty()) { + nodes.back().node->next = segment.get(); + } + SegmentNode node; + segment->index = nodes.size(); + node.row_start = segment->start; + node.node = std::move(segment); + nodes.push_back(std::move(node)); + } + void AppendSegment(unique_ptr segment) { + auto l = Lock(); + AppendSegment(l, std::move(segment)); + } + void AppendSegment(SegmentLock &l, unique_ptr segment) { + LoadAllSegments(l); + AppendSegmentInternal(l, std::move(segment)); + } + //! Debug method, check whether the segment is in the segment tree + bool HasSegment(T *segment) { + auto l = Lock(); + return HasSegment(l, segment); + } + bool HasSegment(SegmentLock &, T *segment) { + return segment->index < nodes.size() && nodes[segment->index].node.get() == segment; + } + + //! Replace this tree with another tree, taking over its nodes in-place + void Replace(SegmentTree &other) { + auto l = Lock(); + Replace(l, other); + } + void Replace(SegmentLock &l, SegmentTree &other) { + other.LoadAllSegments(l); + nodes = std::move(other.nodes); + } + + //! Erase all segments after a specific segment + void EraseSegments(SegmentLock &l, idx_t segment_start) { + LoadAllSegments(l); + if (segment_start >= nodes.size() - 1) { + return; + } + nodes.erase(nodes.begin() + segment_start + 1, nodes.end()); + } + + //! Get the segment index of the column segment for the given row + idx_t GetSegmentIndex(SegmentLock &l, idx_t row_number) { + idx_t segment_index; + if (TryGetSegmentIndex(l, row_number, segment_index)) { + return segment_index; + } + string error; + error = StringUtil::Format("Attempting to find row number \"%lld\" in %lld nodes\n", row_number, nodes.size()); + for (idx_t i = 0; i < nodes.size(); i++) { + error += StringUtil::Format("Node %lld: Start %lld, Count %lld", i, nodes[i].row_start, + nodes[i].node->count.load()); + } + throw InternalException("Could not find node in column segment tree!\n%s%s", error, Exception::GetStackTrace()); + } + + bool TryGetSegmentIndex(SegmentLock &l, idx_t row_number, idx_t &result) { + // load segments until the row number is within bounds + while (nodes.empty() || (row_number >= (nodes.back().row_start + nodes.back().node->count))) { + if (!LoadNextSegment(l)) { + break; + } + } + if (nodes.empty()) { + return false; + } + D_ASSERT(!nodes.empty()); + D_ASSERT(row_number >= nodes[0].row_start); + D_ASSERT(row_number < nodes.back().row_start + nodes.back().node->count); + idx_t lower = 0; + idx_t upper = nodes.size() - 1; + // binary search to find the node + while (lower <= upper) { + idx_t index = (lower + upper) / 2; + D_ASSERT(index < nodes.size()); + auto &entry = nodes[index]; + D_ASSERT(entry.row_start == entry.node->start); + if (row_number < entry.row_start) { + upper = index - 1; + } else if (row_number >= entry.row_start + entry.node->count) { + lower = index + 1; + } else { + result = index; + return true; + } + } + return false; + } + + void Verify(SegmentLock &) { +#ifdef DEBUG + idx_t base_start = nodes.empty() ? 0 : nodes[0].node->start; + for (idx_t i = 0; i < nodes.size(); i++) { + D_ASSERT(nodes[i].row_start == nodes[i].node->start); + D_ASSERT(nodes[i].node->start == base_start); + base_start += nodes[i].node->count; + } +#endif + } + void Verify() { +#ifdef DEBUG + auto l = Lock(); + Verify(l); +#endif + } + + SegmentIterationHelper Segments() { + return SegmentIterationHelper(*this); + } + + void Reinitialize() { + if (nodes.empty()) { + return; + } + idx_t offset = nodes[0].node->start; + for (auto &entry : nodes) { + if (entry.node->start != offset) { + throw InternalException("In SegmentTree::Reinitialize - gap found between nodes!"); + } + entry.row_start = offset; + offset += entry.node->count; + } + } + +protected: + atomic finished_loading; + + //! Load the next segment - only used when lazily loading + virtual unique_ptr LoadSegment() { + return nullptr; + } + +private: + //! The nodes in the tree, can be binary searched + vector> nodes; + //! Lock to access or modify the nodes + mutex node_lock; + +private: + T *GetRootSegmentInternal() { + return nodes.empty() ? nullptr : nodes[0].node.get(); + } + + class SegmentIterationHelper { + public: + explicit SegmentIterationHelper(SegmentTree &tree) : tree(tree) { + } + + private: + SegmentTree &tree; + + private: + class SegmentIterator { + public: + SegmentIterator(SegmentTree &tree_p, T *current_p) : tree(tree_p), current(current_p) { + } + + SegmentTree &tree; + T *current; + + public: + void Next() { + current = tree.GetNextSegment(current); + } + + SegmentIterator &operator++() { + Next(); + return *this; + } + bool operator!=(const SegmentIterator &other) const { + return current != other.current; + } + T &operator*() const { + D_ASSERT(current); + return *current; + } + }; + + public: + SegmentIterator begin() { + return SegmentIterator(tree, tree.GetRootSegment()); + } + SegmentIterator end() { + return SegmentIterator(tree, nullptr); + } + }; + + //! Load the next segment, if there are any left to load + bool LoadNextSegment(SegmentLock &l) { + if (!SUPPORTS_LAZY_LOADING) { + return false; + } + if (finished_loading) { + return false; + } + auto result = LoadSegment(); + if (result) { + AppendSegmentInternal(l, std::move(result)); + return true; + } + return false; + } + + //! Load all segments, if there are any left to load + void LoadAllSegments(SegmentLock &l) { + if (!SUPPORTS_LAZY_LOADING) { + return; + } + while (LoadNextSegment(l)) + ; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp new file mode 100644 index 00000000..17bd5eeb --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/standard_column_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/validity_column_data.hpp" + +namespace duckdb { + +//! Standard column data represents a regular flat column (e.g. a column of type INTEGER or STRING) +class StandardColumnData : public ColumnData { +public: + StandardColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, + LogicalType type, optional_ptr parent = nullptr); + + //! The validity column data + ValidityColumnData validity; + +public: + void SetStart(idx_t new_start) override; + bool CheckZonemap(ColumnScanState &state, TableFilter &filter) override; + + void InitializeScan(ColumnScanState &state) override; + void InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) override; + + idx_t Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) override; + idx_t ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) override; + idx_t ScanCount(ColumnScanState &state, Vector &result, idx_t count) override; + + void InitializeAppend(ColumnAppendState &state) override; + void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; + void RevertAppend(row_t start_row) override; + idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; + void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) override; + void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count) override; + void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t depth) override; + unique_ptr GetUpdateStatistics() override; + + void CommitDropColumn() override; + + unique_ptr CreateCheckpointState(RowGroup &row_group, + PartialBlockManager &partial_block_manager) override; + unique_ptr Checkpoint(RowGroup &row_group, PartialBlockManager &partial_block_manager, + ColumnCheckpointInfo &checkpoint_info) override; + void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, + Vector &scan_vector) override; + + void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, + vector &result) override; + + void DeserializeColumn(Deserializer &deserializer) override; + + void Verify(RowGroup &parent) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp new file mode 100644 index 00000000..61956b1f --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/struct_column_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/validity_column_data.hpp" + +namespace duckdb { + +//! Struct column data represents a struct +class StructColumnData : public ColumnData { +public: + StructColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, + LogicalType type, optional_ptr parent = nullptr); + + //! The sub-columns of the struct + vector> sub_columns; + //! The validity column data of the struct + ValidityColumnData validity; + +public: + void SetStart(idx_t new_start) override; + bool CheckZonemap(ColumnScanState &state, TableFilter &filter) override; + idx_t GetMaxEntry() override; + + void InitializeScan(ColumnScanState &state) override; + void InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) override; + + idx_t Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) override; + idx_t ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) override; + idx_t ScanCount(ColumnScanState &state, Vector &result, idx_t count) override; + + void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; + + void InitializeAppend(ColumnAppendState &state) override; + void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void RevertAppend(row_t start_row) override; + idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; + void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) override; + void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count) override; + void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t depth) override; + unique_ptr GetUpdateStatistics() override; + + void CommitDropColumn() override; + + unique_ptr CreateCheckpointState(RowGroup &row_group, + PartialBlockManager &partial_block_manager) override; + unique_ptr Checkpoint(RowGroup &row_group, PartialBlockManager &partial_block_manager, + ColumnCheckpointInfo &checkpoint_info) override; + + void DeserializeColumn(Deserializer &source) override; + + void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, + vector &result) override; + + void Verify(RowGroup &parent) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp new file mode 100644 index 00000000..ce94b2f3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table_index_list.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/mutex.hpp" +#include "duckdb/storage/index.hpp" + +namespace duckdb { + +class ConflictManager; + +class TableIndexList { +public: + //! Scan the catalog set, invoking the callback method for every entry + template + void Scan(T &&callback) { + // lock the catalog set + lock_guard lock(indexes_lock); + for (auto &index : indexes) { + if (callback(*index)) { + break; + } + } + } + + const vector> &Indexes() const { + return indexes; + } + + void AddIndex(unique_ptr index); + + void RemoveIndex(Index &index); + + bool Empty(); + + idx_t Count(); + + void Move(TableIndexList &other); + + Index *FindForeignKeyIndex(const vector &fk_keys, ForeignKeyType fk_type); + void VerifyForeignKey(const vector &fk_keys, DataChunk &chunk, ConflictManager &conflict_manager); + + //! Serialize all indexes owned by this table, returns a vector of block info of all indexes + vector SerializeIndexes(duckdb::MetadataWriter &writer); + + vector GetRequiredColumns(); + +private: + //! Indexes associated with the current table + mutex indexes_lock; + vector> indexes; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp b/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp new file mode 100644 index 00000000..52b0ec9a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/table_statistics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/storage/statistics/column_statistics.hpp" + +namespace duckdb { +class ColumnList; +class PersistentTableData; +class Serializer; +class Deserializer; + +class TableStatisticsLock { +public: + TableStatisticsLock(mutex &l) : guard(l) { + } + + lock_guard guard; +}; + +class TableStatistics { +public: + void Initialize(const vector &types, PersistentTableData &data); + void InitializeEmpty(const vector &types); + + void InitializeAddColumn(TableStatistics &parent, const LogicalType &new_column_type); + void InitializeRemoveColumn(TableStatistics &parent, idx_t removed_column); + void InitializeAlterType(TableStatistics &parent, idx_t changed_idx, const LogicalType &new_type); + void InitializeAddConstraint(TableStatistics &parent); + + void MergeStats(TableStatistics &other); + void MergeStats(idx_t i, BaseStatistics &stats); + void MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats); + + void CopyStats(TableStatistics &other); + unique_ptr CopyStats(idx_t i); + ColumnStatistics &GetStats(idx_t i); + + bool Empty(); + + unique_ptr GetLock(); + + void Serialize(Serializer &serializer) const; + void Deserialize(Deserializer &deserializer, ColumnList &columns); + +private: + //! The statistics lock + mutex stats_lock; + //! Column statistics + vector> column_stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp new file mode 100644 index 00000000..cefa5d09 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/update_segment.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/storage/storage_lock.hpp" +#include "duckdb/storage/statistics/segment_statistics.hpp" +#include "duckdb/common/types/string_heap.hpp" + +namespace duckdb { +class ColumnData; +class DataTable; +class Vector; +struct UpdateInfo; +struct UpdateNode; + +class UpdateSegment { +public: + UpdateSegment(ColumnData &column_data); + ~UpdateSegment(); + + ColumnData &column_data; + +public: + bool HasUpdates() const; + bool HasUncommittedUpdates(idx_t vector_index); + bool HasUpdates(idx_t vector_index) const; + bool HasUpdates(idx_t start_row_idx, idx_t end_row_idx); + + void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result); + void FetchCommitted(idx_t vector_index, Vector &result); + void FetchCommittedRange(idx_t start_row, idx_t count, Vector &result); + void Update(TransactionData transaction, idx_t column_index, Vector &update, row_t *ids, idx_t count, + Vector &base_data); + void FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx); + + void RollbackUpdate(UpdateInfo &info); + void CleanupUpdateInternal(const StorageLockKey &lock, UpdateInfo &info); + void CleanupUpdate(UpdateInfo &info); + + unique_ptr GetStatistics(); + StringHeap &GetStringHeap() { + return heap; + } + +private: + //! The lock for the update segment + StorageLock lock; + //! The root node (if any) + unique_ptr root; + //! Update statistics + SegmentStatistics stats; + //! Stats lock + mutex stats_lock; + //! Internal type size + idx_t type_size; + //! String heap, only used for strings + StringHeap heap; + +public: + typedef void (*initialize_update_function_t)(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, + Vector &update, const SelectionVector &sel); + typedef void (*merge_update_function_t)(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, + Vector &update, row_t *ids, idx_t count, const SelectionVector &sel); + typedef void (*fetch_update_function_t)(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, + Vector &result); + typedef void (*fetch_committed_function_t)(UpdateInfo *info, Vector &result); + typedef void (*fetch_committed_range_function_t)(UpdateInfo *info, idx_t start, idx_t end, idx_t result_offset, + Vector &result); + typedef void (*fetch_row_function_t)(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, + idx_t row_idx, Vector &result, idx_t result_idx); + typedef void (*rollback_update_function_t)(UpdateInfo &base_info, UpdateInfo &rollback_info); + typedef idx_t (*statistics_update_function_t)(UpdateSegment *segment, SegmentStatistics &stats, Vector &update, + idx_t count, SelectionVector &sel); + +private: + initialize_update_function_t initialize_update_function; + merge_update_function_t merge_update_function; + fetch_update_function_t fetch_update_function; + fetch_committed_function_t fetch_committed_function; + fetch_committed_range_function_t fetch_committed_range; + fetch_row_function_t fetch_row_function; + rollback_update_function_t rollback_update_function; + statistics_update_function_t statistics_update_function; + +private: + void InitializeUpdateInfo(UpdateInfo &info, row_t *ids, const SelectionVector &sel, idx_t count, idx_t vector_index, + idx_t vector_offset); +}; + +struct UpdateNodeData { + unique_ptr info; + unsafe_unique_array tuples; + unsafe_unique_array tuple_data; +}; + +struct UpdateNode { + unique_ptr info[Storage::ROW_GROUP_VECTOR_COUNT]; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp new file mode 100644 index 00000000..ba203d35 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/validity_column_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_data.hpp" + +namespace duckdb { + +//! Validity column data represents the validity data (i.e. which values are null) +class ValidityColumnData : public ColumnData { +public: + ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, + ColumnData &parent); + +public: + bool CheckZonemap(ColumnScanState &state, TableFilter &filter) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table_io_manager.hpp b/src/duckdb/src/include/duckdb/storage/table_io_manager.hpp new file mode 100644 index 00000000..eff63da0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table_io_manager.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table_io_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { +class BlockManager; +class DataTable; +class MetadataManager; + +class TableIOManager { +public: + virtual ~TableIOManager() { + } + + //! Obtains a reference to the TableIOManager of a specific table + static TableIOManager &Get(DataTable &table); + + //! The block manager used for managing index data + virtual BlockManager &GetIndexBlockManager() = 0; + + //! The block manager used for storing row group data + virtual BlockManager &GetBlockManagerForRowData() = 0; + + virtual MetadataManager &GetMetadataManager() = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table_storage_info.hpp b/src/duckdb/src/include/duckdb/storage/table_storage_info.hpp new file mode 100644 index 00000000..37282d36 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table_storage_info.hpp @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table_storage_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/unordered_set.hpp" + +namespace duckdb { + +struct ColumnSegmentInfo { + idx_t row_group_index; + idx_t column_id; + string column_path; + idx_t segment_idx; + string segment_type; + idx_t segment_start; + idx_t segment_count; + string compression_type; + string segment_stats; + bool has_updates; + bool persistent; + block_id_t block_id; + idx_t block_offset; + string segment_info; +}; + +struct IndexInfo { + bool is_unique; + bool is_primary; + bool is_foreign; + unordered_set column_set; +}; + +class TableStorageInfo { +public: + //! The (estimated) cardinality of the table + idx_t cardinality = DConstants::INVALID_INDEX; + //! Info of the indexes of a table + vector index_info; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp new file mode 100644 index 00000000..3844dff8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp @@ -0,0 +1,165 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/write_ahead_log.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/enums/wal_type.hpp" +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/storage/storage_info.hpp" + +namespace duckdb { + +struct AlterInfo; + +class AttachedDatabase; +class Catalog; +class DatabaseInstance; +class SchemaCatalogEntry; +class SequenceCatalogEntry; +class ScalarMacroCatalogEntry; +class ViewCatalogEntry; +class TypeCatalogEntry; +class TableCatalogEntry; +class Transaction; +class TransactionManager; + +class ReplayState { +public: + ReplayState(AttachedDatabase &db, ClientContext &context) + : db(db), context(context), catalog(db.GetCatalog()), deserialize_only(false) { + } + + AttachedDatabase &db; + ClientContext &context; + Catalog &catalog; + optional_ptr current_table; + bool deserialize_only; + MetaBlockPointer checkpoint_id; + +public: + void ReplayEntry(WALType entry_type, BinaryDeserializer &deserializer); + +protected: + virtual void ReplayCreateTable(BinaryDeserializer &deserializer); + void ReplayDropTable(BinaryDeserializer &deserializer); + void ReplayAlter(BinaryDeserializer &deserializer); + + void ReplayCreateView(BinaryDeserializer &deserializer); + void ReplayDropView(BinaryDeserializer &deserializer); + + void ReplayCreateSchema(BinaryDeserializer &deserializer); + void ReplayDropSchema(BinaryDeserializer &deserializer); + + void ReplayCreateType(BinaryDeserializer &deserializer); + void ReplayDropType(BinaryDeserializer &deserializer); + + void ReplayCreateSequence(BinaryDeserializer &deserializer); + void ReplayDropSequence(BinaryDeserializer &deserializer); + void ReplaySequenceValue(BinaryDeserializer &deserializer); + + void ReplayCreateMacro(BinaryDeserializer &deserializer); + void ReplayDropMacro(BinaryDeserializer &deserializer); + + void ReplayCreateTableMacro(BinaryDeserializer &deserializer); + void ReplayDropTableMacro(BinaryDeserializer &deserializer); + + void ReplayCreateIndex(BinaryDeserializer &deserializer); + void ReplayDropIndex(BinaryDeserializer &deserializer); + + void ReplayUseTable(BinaryDeserializer &deserializer); + void ReplayInsert(BinaryDeserializer &deserializer); + void ReplayDelete(BinaryDeserializer &deserializer); + void ReplayUpdate(BinaryDeserializer &deserializer); + void ReplayCheckpoint(BinaryDeserializer &deserializer); +}; + +//! The WriteAheadLog (WAL) is a log that is used to provide durability. Prior +//! to committing a transaction it writes the changes the transaction made to +//! the database to the log, which can then be replayed upon startup in case the +//! server crashes or is shut down. +class WriteAheadLog { +public: + //! Initialize the WAL in the specified directory + explicit WriteAheadLog(AttachedDatabase &database, const string &path); + virtual ~WriteAheadLog(); + + //! Skip writing to the WAL + bool skip_writing; + +public: + //! Replay the WAL + static bool Replay(AttachedDatabase &database, string &path); + + //! Returns the current size of the WAL in bytes + int64_t GetWALSize(); + //! Gets the total bytes written to the WAL since startup + idx_t GetTotalWritten(); + + virtual void WriteCreateTable(const TableCatalogEntry &entry); + void WriteDropTable(const TableCatalogEntry &entry); + + void WriteCreateSchema(const SchemaCatalogEntry &entry); + void WriteDropSchema(const SchemaCatalogEntry &entry); + + void WriteCreateView(const ViewCatalogEntry &entry); + void WriteDropView(const ViewCatalogEntry &entry); + + void WriteCreateSequence(const SequenceCatalogEntry &entry); + void WriteDropSequence(const SequenceCatalogEntry &entry); + void WriteSequenceValue(const SequenceCatalogEntry &entry, SequenceValue val); + + void WriteCreateMacro(const ScalarMacroCatalogEntry &entry); + void WriteDropMacro(const ScalarMacroCatalogEntry &entry); + + void WriteCreateTableMacro(const TableMacroCatalogEntry &entry); + void WriteDropTableMacro(const TableMacroCatalogEntry &entry); + + void WriteCreateIndex(const IndexCatalogEntry &entry); + void WriteDropIndex(const IndexCatalogEntry &entry); + + void WriteCreateType(const TypeCatalogEntry &entry); + void WriteDropType(const TypeCatalogEntry &entry); + //! Sets the table used for subsequent insert/delete/update commands + void WriteSetTable(string &schema, string &table); + + void WriteAlter(const AlterInfo &info); + + void WriteInsert(DataChunk &chunk); + void WriteDelete(DataChunk &chunk); + //! Write a single (sub-) column update to the WAL. Chunk must be a pair of (COL, ROW_ID). + //! The column_path vector is a *path* towards a column within the table + //! i.e. if we have a table with a single column S STRUCT(A INT, B INT) + //! and we update the validity mask of "S.B" + //! the column path is: + //! 0 (first column of table) + //! -> 1 (second subcolumn of struct) + //! -> 0 (first subcolumn of INT) + void WriteUpdate(DataChunk &chunk, const vector &column_path); + + //! Truncate the WAL to a previous size, and clear anything currently set in the writer + void Truncate(int64_t size); + //! Delete the WAL file on disk. The WAL should not be used after this point. + void Delete(); + void Flush(); + + void WriteCheckpoint(MetaBlockPointer meta_block); + +protected: + AttachedDatabase &database; + unique_ptr writer; + string wal_path; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/append_info.hpp b/src/duckdb/src/include/duckdb/transaction/append_info.hpp new file mode 100644 index 00000000..8d06fe25 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/append_info.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/append_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { +class DataTable; + +struct AppendInfo { + DataTable *table; + idx_t start_row; + idx_t count; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp new file mode 100644 index 00000000..2956f1c7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/cleanup_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/unordered_map.hpp" + +namespace duckdb { + +class DataTable; + +struct DeleteInfo; +struct UpdateInfo; + +class CleanupState { +public: + CleanupState(); + ~CleanupState(); + + // all tables with indexes that possibly need a vacuum (after e.g. a delete) + unordered_map> indexed_tables; + +public: + void CleanupEntry(UndoFlags type, data_ptr_t data); + +private: + // data for index cleanup + optional_ptr current_table; + DataChunk chunk; + row_t row_numbers[STANDARD_VECTOR_SIZE]; + idx_t count; + +private: + void CleanupDelete(DeleteInfo &info); + void CleanupUpdate(UpdateInfo &info); + + void Flush(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/commit_state.hpp b/src/duckdb/src/include/duckdb/transaction/commit_state.hpp new file mode 100644 index 00000000..3b005c4d --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/commit_state.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/commit_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/common/vector_size.hpp" + +namespace duckdb { +class CatalogEntry; +class DataChunk; +class WriteAheadLog; +class ClientContext; + +struct DataTableInfo; +struct DeleteInfo; +struct UpdateInfo; + +class CommitState { +public: + explicit CommitState(transaction_t commit_id, optional_ptr log = nullptr); + + optional_ptr log; + transaction_t commit_id; + UndoFlags current_op; + + optional_ptr current_table_info; + idx_t row_identifiers[STANDARD_VECTOR_SIZE]; + + unique_ptr delete_chunk; + unique_ptr update_chunk; + +public: + template + void CommitEntry(UndoFlags type, data_ptr_t data); + void RevertCommit(UndoFlags type, data_ptr_t data); + +private: + void SwitchTable(DataTableInfo *table, UndoFlags new_op); + + void WriteCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data); + void WriteDelete(DeleteInfo &info); + void WriteUpdate(UpdateInfo &info); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/delete_info.hpp b/src/duckdb/src/include/duckdb/transaction/delete_info.hpp new file mode 100644 index 00000000..569d12f1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/delete_info.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/delete_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { +class DataTable; +class RowVersionManager; + +struct DeleteInfo { + DataTable *table; + RowVersionManager *version_info; + idx_t vector_idx; + idx_t count; + idx_t base_row; + row_t rows[1]; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp new file mode 100644 index 00000000..8dc116a4 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/duck_transaction.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/transaction/transaction.hpp" + +namespace duckdb { +class RowVersionManager; + +class DuckTransaction : public Transaction { +public: + DuckTransaction(TransactionManager &manager, ClientContext &context, transaction_t start_time, + transaction_t transaction_id); + ~DuckTransaction() override; + + //! The start timestamp of this transaction + transaction_t start_time; + //! The transaction id of this transaction + transaction_t transaction_id; + //! The commit id of this transaction, if it has successfully been committed + transaction_t commit_id; + //! Map of all sequences that were used during the transaction and the value they had in this transaction + unordered_map sequence_usage; + //! Highest active query when the transaction finished, used for cleaning up + transaction_t highest_active_query; + +public: + static DuckTransaction &Get(ClientContext &context, AttachedDatabase &db); + static DuckTransaction &Get(ClientContext &context, Catalog &catalog); + LocalStorage &GetLocalStorage(); + + void PushCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data = nullptr, idx_t extra_data_size = 0); + + //! Commit the current transaction with the given commit identifier. Returns an error message if the transaction + //! commit failed, or an empty string if the commit was sucessful + string Commit(AttachedDatabase &db, transaction_t commit_id, bool checkpoint) noexcept; + //! Returns whether or not a commit of this transaction should trigger an automatic checkpoint + bool AutomaticCheckpoint(AttachedDatabase &db); + + //! Rollback + void Rollback() noexcept; + //! Cleanup the undo buffer + void Cleanup(); + + bool ChangesMade(); + + void PushDelete(DataTable &table, RowVersionManager &info, idx_t vector_idx, row_t rows[], idx_t count, + idx_t base_row); + void PushAppend(DataTable &table, idx_t row_start, idx_t row_count); + UpdateInfo *CreateUpdateInfo(idx_t type_size, idx_t entries); + + bool IsDuckTransaction() const override { + return true; + } + +private: + //! The undo buffer is used to store old versions of rows that are updated + //! or deleted + UndoBuffer undo_buffer; + //! The set of uncommitted appends for the transaction + unique_ptr storage; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp new file mode 100644 index 00000000..12ed08e0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/duck_transaction_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/transaction/transaction_manager.hpp" + +namespace duckdb { +class DuckTransaction; + +//! The Transaction Manager is responsible for creating and managing +//! transactions +class DuckTransactionManager : public TransactionManager { + friend struct CheckpointLock; + +public: + explicit DuckTransactionManager(AttachedDatabase &db); + ~DuckTransactionManager(); + +public: + static DuckTransactionManager &Get(AttachedDatabase &db); + + //! Start a new transaction + Transaction *StartTransaction(ClientContext &context) override; + //! Commit the given transaction + string CommitTransaction(ClientContext &context, Transaction *transaction) override; + //! Rollback the given transaction + void RollbackTransaction(Transaction *transaction) override; + + void Checkpoint(ClientContext &context, bool force = false) override; + + transaction_t LowestActiveId() { + return lowest_active_id; + } + transaction_t LowestActiveStart() { + return lowest_active_start; + } + + bool IsDuckTransactionManager() override { + return true; + } + +private: + bool CanCheckpoint(optional_ptr current = nullptr); + //! Remove the given transaction from the list of active transactions + void RemoveTransaction(DuckTransaction &transaction) noexcept; + void LockClients(vector &client_locks, ClientContext &context); + +private: + //! The current start timestamp used by transactions + transaction_t current_start_timestamp; + //! The current transaction ID used by transactions + transaction_t current_transaction_id; + //! The lowest active transaction id + atomic lowest_active_id; + //! The lowest active transaction timestamp + atomic lowest_active_start; + //! Set of currently running transactions + vector> active_transactions; + //! Set of recently committed transactions + vector> recently_committed_transactions; + //! Transactions awaiting GC + vector> old_transactions; + //! The lock used for transaction operations + mutex transaction_lock; + + bool thread_is_checkpointing; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp new file mode 100644 index 00000000..099507b5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp @@ -0,0 +1,166 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/local_storage.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/row_group_collection.hpp" +#include "duckdb/storage/table/table_index_list.hpp" +#include "duckdb/storage/table/table_statistics.hpp" +#include "duckdb/storage/optimistic_data_writer.hpp" + +namespace duckdb { +class AttachedDatabase; +class DataTable; +class Transaction; +class WriteAheadLog; +struct LocalAppendState; +struct TableAppendState; + +class LocalTableStorage : public std::enable_shared_from_this { +public: + // Create a new LocalTableStorage + explicit LocalTableStorage(DataTable &table); + // Create a LocalTableStorage from an ALTER TYPE + LocalTableStorage(ClientContext &context, DataTable &table, LocalTableStorage &parent, idx_t changed_idx, + const LogicalType &target_type, const vector &bound_columns, Expression &cast_expr); + // Create a LocalTableStorage from a DROP COLUMN + LocalTableStorage(DataTable &table, LocalTableStorage &parent, idx_t drop_idx); + // Create a LocalTableStorage from an ADD COLUMN + LocalTableStorage(ClientContext &context, DataTable &table, LocalTableStorage &parent, ColumnDefinition &new_column, + Expression &default_value); + ~LocalTableStorage(); + + reference table_ref; + + Allocator &allocator; + //! The main chunk collection holding the data + shared_ptr row_groups; + //! The set of unique indexes + TableIndexList indexes; + //! The number of deleted rows + idx_t deleted_rows; + //! The main optimistic data writer + OptimisticDataWriter optimistic_writer; + //! The set of all optimistic data writers associated with this table + vector> optimistic_writers; + //! Whether or not storage was merged + bool merged_storage = false; + +public: + void InitializeScan(CollectionScanState &state, optional_ptr table_filters = nullptr); + //! Write a new row group to disk (if possible) + void WriteNewRowGroup(); + void FlushBlocks(); + void Rollback(); + idx_t EstimatedSize(); + + void AppendToIndexes(DuckTransaction &transaction, TableAppendState &append_state, idx_t append_count, + bool append_to_table); + PreservedError AppendToIndexes(DuckTransaction &transaction, RowGroupCollection &source, TableIndexList &index_list, + const vector &table_types, row_t &start_row); + + //! Creates an optimistic writer for this table + OptimisticDataWriter &CreateOptimisticWriter(); + void FinalizeOptimisticWriter(OptimisticDataWriter &writer); +}; + +class LocalTableManager { +public: + shared_ptr MoveEntry(DataTable &table); + reference_map_t> MoveEntries(); + optional_ptr GetStorage(DataTable &table); + LocalTableStorage &GetOrCreateStorage(DataTable &table); + idx_t EstimatedSize(); + bool IsEmpty(); + void InsertEntry(DataTable &table, shared_ptr entry); + +private: + mutex table_storage_lock; + reference_map_t> table_storage; +}; + +//! The LocalStorage class holds appends that have not been committed yet +class LocalStorage { +public: + // Threshold to merge row groups instead of appending + static constexpr const idx_t MERGE_THRESHOLD = Storage::ROW_GROUP_SIZE; + +public: + struct CommitState { + CommitState(); + ~CommitState(); + + reference_map_t> append_states; + }; + +public: + explicit LocalStorage(ClientContext &context, DuckTransaction &transaction); + + static LocalStorage &Get(DuckTransaction &transaction); + static LocalStorage &Get(ClientContext &context, AttachedDatabase &db); + static LocalStorage &Get(ClientContext &context, Catalog &catalog); + + //! Initialize a scan of the local storage + void InitializeScan(DataTable &table, CollectionScanState &state, optional_ptr table_filters); + //! Scan + void Scan(CollectionScanState &state, const vector &column_ids, DataChunk &result); + + void InitializeParallelScan(DataTable &table, ParallelCollectionScanState &state); + bool NextParallelScan(ClientContext &context, DataTable &table, ParallelCollectionScanState &state, + CollectionScanState &scan_state); + + //! Begin appending to the local storage + void InitializeAppend(LocalAppendState &state, DataTable &table); + //! Append a chunk to the local storage + static void Append(LocalAppendState &state, DataChunk &chunk); + //! Finish appending to the local storage + static void FinalizeAppend(LocalAppendState &state); + //! Merge a row group collection into the transaction-local storage + void LocalMerge(DataTable &table, RowGroupCollection &collection); + //! Create an optimistic writer for the specified table + OptimisticDataWriter &CreateOptimisticWriter(DataTable &table); + void FinalizeOptimisticWriter(DataTable &table, OptimisticDataWriter &writer); + + //! Delete a set of rows from the local storage + idx_t Delete(DataTable &table, Vector &row_ids, idx_t count); + //! Update a set of rows in the local storage + void Update(DataTable &table, Vector &row_ids, const vector &column_ids, DataChunk &data); + + //! Commits the local storage, writing it to the WAL and completing the commit + void Commit(LocalStorage::CommitState &commit_state, DuckTransaction &transaction); + //! Rollback the local storage + void Rollback(); + + bool ChangesMade() noexcept; + idx_t EstimatedSize(); + + bool Find(DataTable &table); + + idx_t AddedRows(DataTable &table); + + void AddColumn(DataTable &old_dt, DataTable &new_dt, ColumnDefinition &new_column, Expression &default_value); + void DropColumn(DataTable &old_dt, DataTable &new_dt, idx_t removed_column); + void ChangeType(DataTable &old_dt, DataTable &new_dt, idx_t changed_idx, const LogicalType &target_type, + const vector &bound_columns, Expression &cast_expr); + + void MoveStorage(DataTable &old_dt, DataTable &new_dt); + void FetchChunk(DataTable &table, Vector &row_ids, idx_t count, const vector &col_ids, DataChunk &chunk, + ColumnFetchState &fetch_state); + TableIndexList &GetIndexes(DataTable &table); + + void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + +private: + ClientContext &context; + DuckTransaction &transaction; + LocalTableManager table_manager; + + void Flush(DataTable &table, LocalTableStorage &storage); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp new file mode 100644 index 00000000..16d973d0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/meta_transaction.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/main/valid_checker.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class AttachedDatabase; +class ClientContext; +class Transaction; + +//! The MetaTransaction manages multiple transactions for different attached databases +class MetaTransaction { +public: + DUCKDB_API MetaTransaction(ClientContext &context, timestamp_t start_timestamp, idx_t catalog_version); + + ClientContext &context; + //! The timestamp when the transaction started + timestamp_t start_timestamp; + //! The catalog version when the transaction was started + idx_t catalog_version; + //! The validity checker of the transaction + ValidChecker transaction_validity; + //! Whether or not any transaction have made modifications + bool read_only; + //! The active query number + transaction_t active_query; + +public: + DUCKDB_API static MetaTransaction &Get(ClientContext &context); + timestamp_t GetCurrentTransactionStartTimestamp() { + return start_timestamp; + } + + Transaction &GetTransaction(AttachedDatabase &db); + + string Commit(); + void Rollback(); + + idx_t GetActiveQuery(); + void SetActiveQuery(transaction_t query_number); + + void ModifyDatabase(AttachedDatabase &db); + optional_ptr ModifiedDatabase() { + return modified_database; + } + +private: + //! The set of active transactions for each database + unordered_map transactions; + //! The set of transactions in order of when they were started + vector> all_transactions; + //! The database we are modifying - we can only modify one database per transaction + optional_ptr modified_database; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/rollback_state.hpp b/src/duckdb/src/include/duckdb/transaction/rollback_state.hpp new file mode 100644 index 00000000..19d6cdad --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/rollback_state.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/rollback_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/transaction/undo_buffer.hpp" + +namespace duckdb { +class DataChunk; +class DataTable; +class WriteAheadLog; + +class RollbackState { +public: + RollbackState() { + } + +public: + void RollbackEntry(UndoFlags type, data_ptr_t data); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/transaction.hpp b/src/duckdb/src/include/duckdb/transaction/transaction.hpp new file mode 100644 index 00000000..0caa342a --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/transaction.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/transaction.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/transaction/transaction_data.hpp" + +namespace duckdb { +class SequenceCatalogEntry; +class SchemaCatalogEntry; + +class AttachedDatabase; +class ColumnData; +class ClientContext; +class CatalogEntry; +class DataTable; +class DatabaseInstance; +class LocalStorage; +class MetaTransaction; +class TransactionManager; +class WriteAheadLog; + +class ChunkVectorInfo; + +struct DeleteInfo; +struct UpdateInfo; + +//! The transaction object holds information about a currently running or past +//! transaction +class Transaction { +public: + DUCKDB_API Transaction(TransactionManager &manager, ClientContext &context); + DUCKDB_API virtual ~Transaction(); + + TransactionManager &manager; + weak_ptr context; + //! The current active query for the transaction. Set to MAXIMUM_QUERY_ID if + //! no query is active. + atomic active_query; + +public: + DUCKDB_API static Transaction &Get(ClientContext &context, AttachedDatabase &db); + DUCKDB_API static Transaction &Get(ClientContext &context, Catalog &catalog); + + //! Whether or not the transaction has made any modifications to the database so far + DUCKDB_API bool IsReadOnly(); + + virtual bool IsDuckTransaction() const { + return false; + } + +public: + template + TARGET &Cast() { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp b/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp new file mode 100644 index 00000000..1334ab09 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/transaction_context.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +class ClientContext; +class MetaTransaction; +class Transaction; +class TransactionManager; + +//! The transaction context keeps track of all the information relating to the +//! current transaction +class TransactionContext { +public: + TransactionContext(ClientContext &context); + ~TransactionContext(); + + MetaTransaction &ActiveTransaction() { + if (!current_transaction) { + throw InternalException("TransactionContext::ActiveTransaction called without active transaction"); + } + return *current_transaction; + } + + bool HasActiveTransaction() { + return !!current_transaction; + } + + void BeginTransaction(); + void Commit(); + void Rollback(); + void ClearTransaction(); + + void SetAutoCommit(bool value); + bool IsAutoCommit() { + return auto_commit; + } + + idx_t GetActiveQuery(); + void ResetActiveQuery(); + void SetActiveQuery(transaction_t query_number); + +private: + ClientContext &context; + bool auto_commit; + + unique_ptr current_transaction; + + TransactionContext(const TransactionContext &) = delete; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/transaction_data.hpp b/src/duckdb/src/include/duckdb/transaction/transaction_data.hpp new file mode 100644 index 00000000..b8e5d23e --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/transaction_data.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/transaction_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { +class DuckTransaction; +class Transaction; + +struct TransactionData { + TransactionData(DuckTransaction &transaction_p); + TransactionData(transaction_t transaction_id_p, transaction_t start_time_p); + + optional_ptr transaction; + transaction_t transaction_id; + transaction_t start_time; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/transaction_manager.hpp b/src/duckdb/src/include/duckdb/transaction/transaction_manager.hpp new file mode 100644 index 00000000..cd663833 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/transaction_manager.hpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/transaction_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/vector.hpp" + +#include "duckdb/common/atomic.hpp" + +namespace duckdb { + +class AttachedDatabase; +class ClientContext; +class Catalog; +struct ClientLockWrapper; +class DatabaseInstance; +class Transaction; + +//! The Transaction Manager is responsible for creating and managing +//! transactions +class TransactionManager { +public: + explicit TransactionManager(AttachedDatabase &db); + virtual ~TransactionManager(); + + //! Start a new transaction + virtual Transaction *StartTransaction(ClientContext &context) = 0; + //! Commit the given transaction. Returns a non-empty error message on failure. + virtual string CommitTransaction(ClientContext &context, Transaction *transaction) = 0; + //! Rollback the given transaction + virtual void RollbackTransaction(Transaction *transaction) = 0; + + virtual void Checkpoint(ClientContext &context, bool force = false) = 0; + + static TransactionManager &Get(AttachedDatabase &db); + + virtual bool IsDuckTransactionManager() { + return false; + } + + AttachedDatabase &GetDB() { + return db; + } + +protected: + //! The attached database + AttachedDatabase &db; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/undo_buffer.hpp b/src/duckdb/src/include/duckdb/transaction/undo_buffer.hpp new file mode 100644 index 00000000..d0ae7096 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/undo_buffer.hpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/undo_buffer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/undo_flags.hpp" +#include "duckdb/storage/arena_allocator.hpp" + +namespace duckdb { + +class WriteAheadLog; + +//! The undo buffer of a transaction is used to hold previous versions of tuples +//! that might be required in the future (because of rollbacks or previous +//! transactions accessing them) +class UndoBuffer { +public: + struct IteratorState { + ArenaChunk *current; + data_ptr_t start; + data_ptr_t end; + }; + +public: + UndoBuffer(ClientContext &context); + + //! Reserve space for an entry of the specified type and length in the undo + //! buffer + data_ptr_t CreateEntry(UndoFlags type, idx_t len); + + bool ChangesMade(); + idx_t EstimatedSize(); + + //! Cleanup the undo buffer + void Cleanup(); + //! Commit the changes made in the UndoBuffer: should be called on commit + void Commit(UndoBuffer::IteratorState &iterator_state, optional_ptr log, transaction_t commit_id); + //! Revert committed changes made in the UndoBuffer up until the currently committed state + void RevertCommit(UndoBuffer::IteratorState &iterator_state, transaction_t transaction_id); + //! Rollback the changes made in this UndoBuffer: should be called on + //! rollback + void Rollback() noexcept; + +private: + ArenaAllocator allocator; + +private: + template + void IterateEntries(UndoBuffer::IteratorState &state, T &&callback); + template + void IterateEntries(UndoBuffer::IteratorState &state, UndoBuffer::IteratorState &end_state, T &&callback); + template + void ReverseIterateEntries(T &&callback); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/update_info.hpp b/src/duckdb/src/include/duckdb/transaction/update_info.hpp new file mode 100644 index 00000000..8dc313a3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/transaction/update_info.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/transaction/update_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/common/atomic.hpp" + +namespace duckdb { +class UpdateSegment; +struct DataTableInfo; + +struct UpdateInfo { + //! The update segment that this update info affects + UpdateSegment *segment; + //! The column index of which column we are updating + idx_t column_index; + //! The version number + atomic version_number; + //! The vector index within the uncompressed segment + idx_t vector_index; + //! The amount of updated tuples + sel_t N; + //! The maximum amount of tuples that can fit into this UpdateInfo + sel_t max; + //! The row ids of the tuples that have been updated. This should always be kept sorted! + sel_t *tuples; + //! The data of the tuples + data_ptr_t tuple_data; + //! The previous update info (or nullptr if it is the base) + UpdateInfo *prev; + //! The next update info in the chain (or nullptr if it is the last) + UpdateInfo *next; + + //! Loop over the update chain and execute the specified callback on all UpdateInfo's that are relevant for that + //! transaction in-order of newest to oldest + template + static void UpdatesForTransaction(UpdateInfo *current, transaction_t start_time, transaction_t transaction_id, + T &&callback) { + while (current) { + if (current->version_number > start_time && current->version_number != transaction_id) { + // these tuples were either committed AFTER this transaction started or are not committed yet, use + // tuples stored in this version + callback(current); + } + current = current->next; + } + } + + Value GetValue(idx_t index); + string ToString(); + void Print(); + void Verify(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/copied_statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/copied_statement_verifier.hpp new file mode 100644 index 00000000..1df929c6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/copied_statement_verifier.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/copied_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class CopiedStatementVerifier : public StatementVerifier { +public: + explicit CopiedStatementVerifier(unique_ptr statement_p); + static unique_ptr Create(const SQLStatement &statement_p); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/deserialized_statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/deserialized_statement_verifier.hpp new file mode 100644 index 00000000..78b2ff1e --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/deserialized_statement_verifier.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/deserialized_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class DeserializedStatementVerifier : public StatementVerifier { +public: + explicit DeserializedStatementVerifier(unique_ptr statement_p); + static unique_ptr Create(const SQLStatement &statement); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/external_statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/external_statement_verifier.hpp new file mode 100644 index 00000000..91d551f0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/external_statement_verifier.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/external_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class ExternalStatementVerifier : public StatementVerifier { +public: + explicit ExternalStatementVerifier(unique_ptr statement_p); + static unique_ptr Create(const SQLStatement &statement); + + bool ForceExternal() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/no_operator_caching_verifier.hpp b/src/duckdb/src/include/duckdb/verification/no_operator_caching_verifier.hpp new file mode 100644 index 00000000..51a97d35 --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/no_operator_caching_verifier.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/unoptimized_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class NoOperatorCachingVerifier : public StatementVerifier { +public: + explicit NoOperatorCachingVerifier(unique_ptr statement_p); + static unique_ptr Create(const SQLStatement &statement_p); + + bool DisableOperatorCaching() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/parsed_statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/parsed_statement_verifier.hpp new file mode 100644 index 00000000..5448d5f8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/parsed_statement_verifier.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/parsed_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class ParsedStatementVerifier : public StatementVerifier { +public: + explicit ParsedStatementVerifier(unique_ptr statement_p); + static unique_ptr Create(const SQLStatement &statement); + + bool RequireEquality() const override { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/prepared_statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/prepared_statement_verifier.hpp new file mode 100644 index 00000000..23c7593a --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/prepared_statement_verifier.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/prepared_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class PreparedStatementVerifier : public StatementVerifier { +public: + explicit PreparedStatementVerifier(unique_ptr statement_p); + static unique_ptr Create(const SQLStatement &statement_p); + + bool Run(ClientContext &context, const string &query, + const std::function(const string &, unique_ptr)> &run) override; + +private: + case_insensitive_map_t> values; + unique_ptr prepare_statement; + unique_ptr execute_statement; + unique_ptr dealloc_statement; + +private: + void Extract(); + void ConvertConstants(unique_ptr &child); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp new file mode 100644 index 00000000..d14f1ffa --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/parser/statement/select_statement.hpp" + +namespace duckdb { + +enum class VerificationType : uint8_t { + ORIGINAL, + COPIED, + DESERIALIZED, + PARSED, + UNOPTIMIZED, + NO_OPERATOR_CACHING, + PREPARED, + EXTERNAL, + + INVALID +}; + +class StatementVerifier { +public: + StatementVerifier(VerificationType type, string name, unique_ptr statement_p); + explicit StatementVerifier(unique_ptr statement_p); + static unique_ptr Create(VerificationType type, const SQLStatement &statement_p); + virtual ~StatementVerifier() noexcept; + + //! Check whether expressions in this verifier and the other verifier match + void CheckExpressions(const StatementVerifier &other) const; + //! Check whether expressions within this verifier match + void CheckExpressions() const; + + //! Run the select statement and store the result + virtual bool Run(ClientContext &context, const string &query, + const std::function(const string &, unique_ptr)> &run); + //! Compare this verifier's results with another verifier + string CompareResults(const StatementVerifier &other); + +public: + const VerificationType type; + const string name; + unique_ptr statement; + const vector> &select_list; + unique_ptr materialized_result; + + virtual bool RequireEquality() const { + return true; + } + + virtual bool DisableOptimizer() const { + return false; + } + + virtual bool DisableOperatorCaching() const { + return false; + } + + virtual bool ForceExternal() const { + return false; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/unoptimized_statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/unoptimized_statement_verifier.hpp new file mode 100644 index 00000000..4d71b2e7 --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/unoptimized_statement_verifier.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/unoptimized_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class UnoptimizedStatementVerifier : public StatementVerifier { +public: + explicit UnoptimizedStatementVerifier(unique_ptr statement_p); + static unique_ptr Create(const SQLStatement &statement_p); + + bool DisableOptimizer() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/main/appender.cpp b/src/duckdb/src/main/appender.cpp new file mode 100644 index 00000000..57822c38 --- /dev/null +++ b/src/duckdb/src/main/appender.cpp @@ -0,0 +1,372 @@ +#include "duckdb/main/appender.hpp" + +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/decimal_cast_operators.hpp" +#include "duckdb/common/operator/string_cast.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/connection.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/storage/data_table.hpp" + +namespace duckdb { + +BaseAppender::BaseAppender(Allocator &allocator, AppenderType type_p) + : allocator(allocator), column(0), appender_type(type_p) { +} + +BaseAppender::BaseAppender(Allocator &allocator_p, vector types_p, AppenderType type_p) + : allocator(allocator_p), types(std::move(types_p)), collection(make_uniq(allocator, types)), + column(0), appender_type(type_p) { + InitializeChunk(); +} + +BaseAppender::~BaseAppender() { +} + +void BaseAppender::Destructor() { + if (Exception::UncaughtException()) { + return; + } + // flush any remaining chunks, but only if we are not cleaning up the appender as part of an exception stack unwind + // wrapped in a try/catch because Close() can throw if the table was dropped in the meantime + try { + Close(); + } catch (...) { + } +} + +InternalAppender::InternalAppender(ClientContext &context_p, TableCatalogEntry &table_p) + : BaseAppender(Allocator::DefaultAllocator(), table_p.GetTypes(), AppenderType::PHYSICAL), context(context_p), + table(table_p) { +} + +InternalAppender::~InternalAppender() { + Destructor(); +} + +Appender::Appender(Connection &con, const string &schema_name, const string &table_name) + : BaseAppender(Allocator::DefaultAllocator(), AppenderType::LOGICAL), context(con.context) { + description = con.TableInfo(schema_name, table_name); + if (!description) { + // table could not be found + throw CatalogException(StringUtil::Format("Table \"%s.%s\" could not be found", schema_name, table_name)); + } + for (auto &column : description->columns) { + types.push_back(column.Type()); + } + InitializeChunk(); + collection = make_uniq(allocator, types); +} + +Appender::Appender(Connection &con, const string &table_name) : Appender(con, DEFAULT_SCHEMA, table_name) { +} + +Appender::~Appender() { + Destructor(); +} + +void BaseAppender::InitializeChunk() { + chunk.Initialize(allocator, types); +} + +void BaseAppender::BeginRow() { +} + +void BaseAppender::EndRow() { + // check that all rows have been appended to + if (column != chunk.ColumnCount()) { + throw InvalidInputException("Call to EndRow before all rows have been appended to!"); + } + column = 0; + chunk.SetCardinality(chunk.size() + 1); + if (chunk.size() >= STANDARD_VECTOR_SIZE) { + FlushChunk(); + } +} + +template +void BaseAppender::AppendValueInternal(Vector &col, SRC input) { + FlatVector::GetData(col)[chunk.size()] = Cast::Operation(input); +} + +template +void BaseAppender::AppendDecimalValueInternal(Vector &col, SRC input) { + switch (appender_type) { + case AppenderType::LOGICAL: { + auto &type = col.GetType(); + D_ASSERT(type.id() == LogicalTypeId::DECIMAL); + auto width = DecimalType::GetWidth(type); + auto scale = DecimalType::GetScale(type); + TryCastToDecimal::Operation(input, FlatVector::GetData(col)[chunk.size()], nullptr, width, + scale); + return; + } + case AppenderType::PHYSICAL: { + AppendValueInternal(col, input); + return; + } + default: + throw InternalException("Type not implemented for AppenderType"); + } +} + +template +void BaseAppender::AppendValueInternal(T input) { + if (column >= types.size()) { + throw InvalidInputException("Too many appends for chunk!"); + } + auto &col = chunk.data[column]; + switch (col.GetType().id()) { + case LogicalTypeId::BOOLEAN: + AppendValueInternal(col, input); + break; + case LogicalTypeId::UTINYINT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::TINYINT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::USMALLINT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::SMALLINT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::UINTEGER: + AppendValueInternal(col, input); + break; + case LogicalTypeId::INTEGER: + AppendValueInternal(col, input); + break; + case LogicalTypeId::UBIGINT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::BIGINT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::HUGEINT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::FLOAT: + AppendValueInternal(col, input); + break; + case LogicalTypeId::DOUBLE: + AppendValueInternal(col, input); + break; + case LogicalTypeId::DECIMAL: + switch (col.GetType().InternalType()) { + case PhysicalType::INT16: + AppendDecimalValueInternal(col, input); + break; + case PhysicalType::INT32: + AppendDecimalValueInternal(col, input); + break; + case PhysicalType::INT64: + AppendDecimalValueInternal(col, input); + break; + case PhysicalType::INT128: + AppendDecimalValueInternal(col, input); + break; + default: + throw InternalException("Internal type not recognized for Decimal"); + } + break; + case LogicalTypeId::DATE: + AppendValueInternal(col, input); + break; + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + AppendValueInternal(col, input); + break; + case LogicalTypeId::TIME: + AppendValueInternal(col, input); + break; + case LogicalTypeId::TIME_TZ: + AppendValueInternal(col, input); + break; + case LogicalTypeId::INTERVAL: + AppendValueInternal(col, input); + break; + case LogicalTypeId::VARCHAR: + FlatVector::GetData(col)[chunk.size()] = StringCast::Operation(input, col); + break; + default: + AppendValue(Value::CreateValue(input)); + return; + } + column++; +} + +template <> +void BaseAppender::Append(bool value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(int8_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(int16_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(int32_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(int64_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(hugeint_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(uint8_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(uint16_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(uint32_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(uint64_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(const char *value) { + AppendValueInternal(string_t(value)); +} + +void BaseAppender::Append(const char *value, uint32_t length) { + AppendValueInternal(string_t(value, length)); +} + +template <> +void BaseAppender::Append(string_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(float value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(double value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(date_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(dtime_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(timestamp_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(interval_t value) { + AppendValueInternal(value); +} + +template <> +void BaseAppender::Append(Value value) { // NOLINT: template shtuff + if (column >= chunk.ColumnCount()) { + throw InvalidInputException("Too many appends for chunk!"); + } + AppendValue(value); +} + +template <> +void BaseAppender::Append(std::nullptr_t value) { + if (column >= chunk.ColumnCount()) { + throw InvalidInputException("Too many appends for chunk!"); + } + auto &col = chunk.data[column++]; + FlatVector::SetNull(col, chunk.size(), true); +} + +void BaseAppender::AppendValue(const Value &value) { + chunk.SetValue(column, chunk.size(), value); + column++; +} + +void BaseAppender::AppendDataChunk(DataChunk &chunk) { + if (chunk.GetTypes() != types) { + throw InvalidInputException("Type mismatch in Append DataChunk and the types required for appender"); + } + collection->Append(chunk); + if (collection->Count() >= FLUSH_COUNT) { + Flush(); + } +} + +void BaseAppender::FlushChunk() { + if (chunk.size() == 0) { + return; + } + collection->Append(chunk); + chunk.Reset(); + if (collection->Count() >= FLUSH_COUNT) { + Flush(); + } +} + +void BaseAppender::Flush() { + // check that all vectors have the same length before appending + if (column != 0) { + throw InvalidInputException("Failed to Flush appender: incomplete append to row!"); + } + + FlushChunk(); + if (collection->Count() == 0) { + return; + } + FlushInternal(*collection); + + collection->Reset(); + column = 0; +} + +void Appender::FlushInternal(ColumnDataCollection &collection) { + context->Append(*description, collection); +} + +void InternalAppender::FlushInternal(ColumnDataCollection &collection) { + table.GetStorage().LocalAppend(table, context, collection); +} + +void BaseAppender::Close() { + if (column == 0 || column == types.size()) { + Flush(); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/attached_database.cpp b/src/duckdb/src/main/attached_database.cpp new file mode 100644 index 00000000..24605378 --- /dev/null +++ b/src/duckdb/src/main/attached_database.cpp @@ -0,0 +1,135 @@ +#include "duckdb/main/attached_database.hpp" + +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/parser/parsed_data/attach_info.hpp" +#include "duckdb/storage/storage_extension.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/transaction/duck_transaction_manager.hpp" + +namespace duckdb { + +AttachedDatabase::AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType type) + : CatalogEntry(CatalogType::DATABASE_ENTRY, + type == AttachedDatabaseType::SYSTEM_DATABASE ? SYSTEM_CATALOG : TEMP_CATALOG, 0), + db(db), type(type) { + D_ASSERT(type == AttachedDatabaseType::TEMP_DATABASE || type == AttachedDatabaseType::SYSTEM_DATABASE); + if (type == AttachedDatabaseType::TEMP_DATABASE) { + storage = make_uniq(*this, ":memory:", false); + } + catalog = make_uniq(*this); + transaction_manager = make_uniq(*this); + internal = true; +} + +AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, string name_p, string file_path_p, + AccessMode access_mode) + : CatalogEntry(CatalogType::DATABASE_ENTRY, catalog_p, std::move(name_p)), db(db), + type(access_mode == AccessMode::READ_ONLY ? AttachedDatabaseType::READ_ONLY_DATABASE + : AttachedDatabaseType::READ_WRITE_DATABASE), + parent_catalog(&catalog_p) { + storage = make_uniq(*this, std::move(file_path_p), access_mode == AccessMode::READ_ONLY); + catalog = make_uniq(*this); + transaction_manager = make_uniq(*this); + internal = true; +} + +AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, StorageExtension &storage_extension, + string name_p, AttachInfo &info, AccessMode access_mode) + : CatalogEntry(CatalogType::DATABASE_ENTRY, catalog_p, std::move(name_p)), db(db), + type(access_mode == AccessMode::READ_ONLY ? AttachedDatabaseType::READ_ONLY_DATABASE + : AttachedDatabaseType::READ_WRITE_DATABASE), + parent_catalog(&catalog_p) { + catalog = storage_extension.attach(storage_extension.storage_info.get(), *this, name, info, access_mode); + if (!catalog) { + throw InternalException("AttachedDatabase - attach function did not return a catalog"); + } + transaction_manager = + storage_extension.create_transaction_manager(storage_extension.storage_info.get(), *this, *catalog); + if (!transaction_manager) { + throw InternalException( + "AttachedDatabase - create_transaction_manager function did not return a transaction manager"); + } + internal = true; +} + +AttachedDatabase::~AttachedDatabase() { + if (Exception::UncaughtException()) { + return; + } + if (!storage) { + return; + } + + // shutting down: attempt to checkpoint the database + // but only if we are not cleaning up as part of an exception unwind + try { + if (!storage->InMemory()) { + auto &config = DBConfig::GetConfig(db); + if (!config.options.checkpoint_on_shutdown) { + return; + } + storage->CreateCheckpoint(true); + } + } catch (...) { + } +} + +bool AttachedDatabase::IsSystem() const { + D_ASSERT(!storage || type != AttachedDatabaseType::SYSTEM_DATABASE); + return type == AttachedDatabaseType::SYSTEM_DATABASE; +} + +bool AttachedDatabase::IsTemporary() const { + return type == AttachedDatabaseType::TEMP_DATABASE; +} +bool AttachedDatabase::IsReadOnly() const { + return type == AttachedDatabaseType::READ_ONLY_DATABASE; +} + +string AttachedDatabase::ExtractDatabaseName(const string &dbpath, FileSystem &fs) { + if (dbpath.empty() || dbpath == ":memory:") { + return "memory"; + } + return fs.ExtractBaseName(dbpath); +} + +void AttachedDatabase::Initialize() { + if (IsSystem()) { + catalog->Initialize(true); + } else { + catalog->Initialize(false); + } + if (storage) { + storage->Initialize(); + } +} + +StorageManager &AttachedDatabase::GetStorageManager() { + if (!storage) { + throw InternalException("Internal system catalog does not have storage"); + } + return *storage; +} + +Catalog &AttachedDatabase::GetCatalog() { + return *catalog; +} + +TransactionManager &AttachedDatabase::GetTransactionManager() { + return *transaction_manager; +} + +Catalog &AttachedDatabase::ParentCatalog() { + return *parent_catalog; +} + +bool AttachedDatabase::IsInitialDatabase() const { + return is_initial_database; +} + +void AttachedDatabase::SetInitialDatabase() { + is_initial_database = true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/capi/appender-c.cpp b/src/duckdb/src/main/capi/appender-c.cpp new file mode 100644 index 00000000..8efa299b --- /dev/null +++ b/src/duckdb/src/main/capi/appender-c.cpp @@ -0,0 +1,208 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +using duckdb::Appender; +using duckdb::AppenderWrapper; +using duckdb::Connection; +using duckdb::date_t; +using duckdb::dtime_t; +using duckdb::hugeint_t; +using duckdb::interval_t; +using duckdb::string_t; +using duckdb::timestamp_t; + +duckdb_state duckdb_appender_create(duckdb_connection connection, const char *schema, const char *table, + duckdb_appender *out_appender) { + Connection *conn = reinterpret_cast(connection); + + if (!connection || !table || !out_appender) { + return DuckDBError; + } + if (schema == nullptr) { + schema = DEFAULT_SCHEMA; + } + auto wrapper = new AppenderWrapper(); + *out_appender = (duckdb_appender)wrapper; + try { + wrapper->appender = duckdb::make_uniq(*conn, schema, table); + } catch (std::exception &ex) { + wrapper->error = ex.what(); + return DuckDBError; + } catch (...) { // LCOV_EXCL_START + wrapper->error = "Unknown create appender error"; + return DuckDBError; + } // LCOV_EXCL_STOP + return DuckDBSuccess; +} + +duckdb_state duckdb_appender_destroy(duckdb_appender *appender) { + if (!appender || !*appender) { + return DuckDBError; + } + duckdb_appender_close(*appender); + auto wrapper = reinterpret_cast(*appender); + if (wrapper) { + delete wrapper; + } + *appender = nullptr; + return DuckDBSuccess; +} + +template +duckdb_state duckdb_appender_run_function(duckdb_appender appender, FUN &&function) { + if (!appender) { + return DuckDBError; + } + auto wrapper = reinterpret_cast(appender); + if (!wrapper->appender) { + return DuckDBError; + } + try { + function(*wrapper->appender); + } catch (std::exception &ex) { + wrapper->error = ex.what(); + return DuckDBError; + } catch (...) { // LCOV_EXCL_START + wrapper->error = "Unknown error"; + return DuckDBError; + } // LCOV_EXCL_STOP + return DuckDBSuccess; +} + +const char *duckdb_appender_error(duckdb_appender appender) { + if (!appender) { + return nullptr; + } + auto wrapper = reinterpret_cast(appender); + if (wrapper->error.empty()) { + return nullptr; + } + return wrapper->error.c_str(); +} + +duckdb_state duckdb_appender_begin_row(duckdb_appender appender) { + return DuckDBSuccess; +} + +duckdb_state duckdb_appender_end_row(duckdb_appender appender) { + return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.EndRow(); }); +} + +template +duckdb_state duckdb_append_internal(duckdb_appender appender, T value) { + if (!appender) { + return DuckDBError; + } + auto *appender_instance = reinterpret_cast(appender); + try { + appender_instance->appender->Append(value); + } catch (std::exception &ex) { + appender_instance->error = ex.what(); + return DuckDBError; + } catch (...) { + return DuckDBError; + } + return DuckDBSuccess; +} + +duckdb_state duckdb_append_bool(duckdb_appender appender, bool value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_int8(duckdb_appender appender, int8_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_int16(duckdb_appender appender, int16_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_int32(duckdb_appender appender, int32_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_int64(duckdb_appender appender, int64_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_hugeint(duckdb_appender appender, duckdb_hugeint value) { + hugeint_t internal; + internal.lower = value.lower; + internal.upper = value.upper; + return duckdb_append_internal(appender, internal); +} + +duckdb_state duckdb_append_uint8(duckdb_appender appender, uint8_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_uint16(duckdb_appender appender, uint16_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_uint32(duckdb_appender appender, uint32_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_uint64(duckdb_appender appender, uint64_t value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_float(duckdb_appender appender, float value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_double(duckdb_appender appender, double value) { + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_append_date(duckdb_appender appender, duckdb_date value) { + return duckdb_append_internal(appender, date_t(value.days)); +} + +duckdb_state duckdb_append_time(duckdb_appender appender, duckdb_time value) { + return duckdb_append_internal(appender, dtime_t(value.micros)); +} + +duckdb_state duckdb_append_timestamp(duckdb_appender appender, duckdb_timestamp value) { + return duckdb_append_internal(appender, timestamp_t(value.micros)); +} + +duckdb_state duckdb_append_interval(duckdb_appender appender, duckdb_interval value) { + interval_t interval; + interval.months = value.months; + interval.days = value.days; + interval.micros = value.micros; + return duckdb_append_internal(appender, interval); +} + +duckdb_state duckdb_append_null(duckdb_appender appender) { + return duckdb_append_internal(appender, nullptr); +} + +duckdb_state duckdb_append_varchar(duckdb_appender appender, const char *val) { + return duckdb_append_internal(appender, val); +} + +duckdb_state duckdb_append_varchar_length(duckdb_appender appender, const char *val, idx_t length) { + return duckdb_append_internal(appender, string_t(val, length)); +} +duckdb_state duckdb_append_blob(duckdb_appender appender, const void *data, idx_t length) { + auto value = duckdb::Value::BLOB((duckdb::const_data_ptr_t)data, length); + return duckdb_append_internal(appender, value); +} + +duckdb_state duckdb_appender_flush(duckdb_appender appender) { + return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.Flush(); }); +} + +duckdb_state duckdb_appender_close(duckdb_appender appender) { + return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.Close(); }); +} + +duckdb_state duckdb_append_data_chunk(duckdb_appender appender, duckdb_data_chunk chunk) { + if (!chunk) { + return DuckDBError; + } + auto data_chunk = (duckdb::DataChunk *)chunk; + return duckdb_appender_run_function(appender, [&](Appender &appender) { appender.AppendDataChunk(*data_chunk); }); +} diff --git a/src/duckdb/src/main/capi/arrow-c.cpp b/src/duckdb/src/main/capi/arrow-c.cpp new file mode 100644 index 00000000..9dff8ed1 --- /dev/null +++ b/src/duckdb/src/main/capi/arrow-c.cpp @@ -0,0 +1,298 @@ +#include "duckdb/common/arrow/arrow_converter.hpp" +#include "duckdb/function/table/arrow.hpp" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/common/arrow/arrow.hpp" + +using duckdb::ArrowConverter; +using duckdb::ArrowResultWrapper; +using duckdb::Connection; +using duckdb::DataChunk; +using duckdb::LogicalType; +using duckdb::MaterializedQueryResult; +using duckdb::PreparedStatementWrapper; +using duckdb::QueryResult; +using duckdb::QueryResultType; + +duckdb_state duckdb_query_arrow(duckdb_connection connection, const char *query, duckdb_arrow *out_result) { + Connection *conn = (Connection *)connection; + auto wrapper = new ArrowResultWrapper(); + wrapper->result = conn->Query(query); + *out_result = (duckdb_arrow)wrapper; + return !wrapper->result->HasError() ? DuckDBSuccess : DuckDBError; +} + +duckdb_state duckdb_query_arrow_schema(duckdb_arrow result, duckdb_arrow_schema *out_schema) { + if (!out_schema) { + return DuckDBSuccess; + } + auto wrapper = reinterpret_cast(result); + ArrowConverter::ToArrowSchema((ArrowSchema *)*out_schema, wrapper->result->types, wrapper->result->names, + wrapper->options); + return DuckDBSuccess; +} + +duckdb_state duckdb_prepared_arrow_schema(duckdb_prepared_statement prepared, duckdb_arrow_schema *out_schema) { + if (!out_schema) { + return DuckDBSuccess; + } + auto wrapper = reinterpret_cast(prepared); + if (!wrapper || !wrapper->statement || !wrapper->statement->data) { + return DuckDBError; + } + auto properties = wrapper->statement->context->GetClientProperties(); + duckdb::vector prepared_types; + duckdb::vector prepared_names; + + auto count = wrapper->statement->data->properties.parameter_count; + for (idx_t i = 0; i < count; i++) { + // Every prepared parameter type is UNKNOWN, which we need to map to NULL according to the spec of + // 'AdbcStatementGetParameterSchema' + auto type = LogicalType::SQLNULL; + + // FIXME: we don't support named parameters yet, but when we do, this needs to be updated + auto name = std::to_string(i); + prepared_types.push_back(std::move(type)); + prepared_names.push_back(name); + } + + auto result_schema = (ArrowSchema *)*out_schema; + if (!result_schema) { + return DuckDBError; + } + + if (result_schema->release) { + // Need to release the existing schema before we overwrite it + result_schema->release(result_schema); + result_schema->release = nullptr; + } + + ArrowConverter::ToArrowSchema(result_schema, prepared_types, prepared_names, properties); + return DuckDBSuccess; +} + +duckdb_state duckdb_query_arrow_array(duckdb_arrow result, duckdb_arrow_array *out_array) { + if (!out_array) { + return DuckDBSuccess; + } + auto wrapper = reinterpret_cast(result); + auto success = wrapper->result->TryFetch(wrapper->current_chunk, wrapper->result->GetErrorObject()); + if (!success) { // LCOV_EXCL_START + return DuckDBError; + } // LCOV_EXCL_STOP + if (!wrapper->current_chunk || wrapper->current_chunk->size() == 0) { + return DuckDBSuccess; + } + ArrowConverter::ToArrowArray(*wrapper->current_chunk, reinterpret_cast(*out_array), wrapper->options); + return DuckDBSuccess; +} + +idx_t duckdb_arrow_row_count(duckdb_arrow result) { + auto wrapper = reinterpret_cast(result); + if (wrapper->result->HasError()) { + return 0; + } + return wrapper->result->RowCount(); +} + +idx_t duckdb_arrow_column_count(duckdb_arrow result) { + auto wrapper = reinterpret_cast(result); + return wrapper->result->ColumnCount(); +} + +idx_t duckdb_arrow_rows_changed(duckdb_arrow result) { + auto wrapper = reinterpret_cast(result); + if (wrapper->result->HasError()) { + return 0; + } + idx_t rows_changed = 0; + auto &collection = wrapper->result->Collection(); + idx_t row_count = collection.Count(); + if (row_count > 0 && wrapper->result->properties.return_type == duckdb::StatementReturnType::CHANGED_ROWS) { + auto rows = collection.GetRows(); + D_ASSERT(row_count == 1); + D_ASSERT(rows.size() == 1); + rows_changed = rows[0].GetValue(0).GetValue(); + } + return rows_changed; +} + +const char *duckdb_query_arrow_error(duckdb_arrow result) { + auto wrapper = reinterpret_cast(result); + return wrapper->result->GetError().c_str(); +} + +void duckdb_destroy_arrow(duckdb_arrow *result) { + if (*result) { + auto wrapper = reinterpret_cast(*result); + delete wrapper; + *result = nullptr; + } +} + +duckdb_state duckdb_execute_prepared_arrow(duckdb_prepared_statement prepared_statement, duckdb_arrow *out_result) { + auto wrapper = reinterpret_cast(prepared_statement); + if (!wrapper || !wrapper->statement || wrapper->statement->HasError() || !out_result) { + return DuckDBError; + } + auto arrow_wrapper = new ArrowResultWrapper(); + arrow_wrapper->options = wrapper->statement->context->GetClientProperties(); + + auto result = wrapper->statement->Execute(wrapper->values, false); + D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); + arrow_wrapper->result = duckdb::unique_ptr_cast(std::move(result)); + *out_result = reinterpret_cast(arrow_wrapper); + return !arrow_wrapper->result->HasError() ? DuckDBSuccess : DuckDBError; +} + +namespace arrow_array_stream_wrapper { +namespace { +struct PrivateData { + ArrowSchema *schema; + ArrowArray *array; + bool done = false; +}; + +// LCOV_EXCL_START +// This function is never called, but used to set ArrowSchema's release functions to a non-null NOOP. +void EmptySchemaRelease(ArrowSchema *) { +} +// LCOV_EXCL_STOP + +void EmptyArrayRelease(ArrowArray *) { +} + +void EmptyStreamRelease(ArrowArrayStream *) { +} + +void FactoryGetSchema(uintptr_t stream_factory_ptr, duckdb::ArrowSchemaWrapper &schema) { + auto stream = reinterpret_cast(stream_factory_ptr); + stream->get_schema(stream, &schema.arrow_schema); + + // Need to nullify the root schema's release function here, because streams don't allow us to set the release + // function. For the schema's children, we nullify the release functions in `duckdb_arrow_scan`, so we don't need to + // handle them again here. We set this to nullptr and not EmptySchemaRelease to prevent ArrowSchemaWrapper's + // destructor from destroying the schema (it's the caller's responsibility). + schema.arrow_schema.release = nullptr; +} + +int GetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { + auto private_data = static_cast((stream->private_data)); + if (private_data->schema == nullptr) { + return DuckDBError; + } + + *out = *private_data->schema; + out->release = EmptySchemaRelease; + return DuckDBSuccess; +} + +int GetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) { + auto private_data = static_cast((stream->private_data)); + *out = *private_data->array; + if (private_data->done) { + out->release = nullptr; + } else { + out->release = EmptyArrayRelease; + } + + private_data->done = true; + return DuckDBSuccess; +} + +duckdb::unique_ptr FactoryGetNext(uintptr_t stream_factory_ptr, + duckdb::ArrowStreamParameters ¶meters) { + auto stream = reinterpret_cast(stream_factory_ptr); + auto ret = duckdb::make_uniq(); + ret->arrow_array_stream = *stream; + ret->arrow_array_stream.release = EmptyStreamRelease; + return ret; +} + +// LCOV_EXCL_START +// This function is never be called, because it's used to construct a stream wrapping around a caller-supplied +// ArrowArray. Thus, the stream itself cannot produce an error. +const char *GetLastError(struct ArrowArrayStream *stream) { + return nullptr; +} +// LCOV_EXCL_STOP + +void Release(struct ArrowArrayStream *stream) { + if (stream->private_data != nullptr) { + delete reinterpret_cast(stream->private_data); + } + + stream->private_data = nullptr; + stream->release = nullptr; +} + +duckdb_state Ingest(duckdb_connection connection, const char *table_name, struct ArrowArrayStream *input) { + try { + auto cconn = reinterpret_cast(connection); + cconn + ->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), + duckdb::Value::POINTER((uintptr_t)FactoryGetNext), + duckdb::Value::POINTER((uintptr_t)FactoryGetSchema)}) + ->CreateView(table_name, true, false); + } catch (...) { // LCOV_EXCL_START + // Tried covering this in tests, but it proved harder than expected. At the time of writing: + // - Passing any name to `CreateView` worked without throwing an exception + // - Passing a null Arrow array worked without throwing an exception + // - Passing an invalid schema (without any columns) led to an InternalException with SIGABRT, which is meant to + // be un-catchable. This case likely needs to be handled gracefully within `arrow_scan`. + // Ref: https://discord.com/channels/909674491309850675/921100573732909107/1115230468699336785 + return DuckDBError; + } // LCOV_EXCL_STOP + + return DuckDBSuccess; +} +} // namespace +} // namespace arrow_array_stream_wrapper + +duckdb_state duckdb_arrow_scan(duckdb_connection connection, const char *table_name, duckdb_arrow_stream arrow) { + auto stream = reinterpret_cast(arrow); + + // Backup release functions - we nullify children schema release functions because we don't want to release on + // behalf of the caller, downstream in our code. Note that Arrow releases target immediate children, but aren't + // recursive. So we only back up immediate children here and restore their functions. + ArrowSchema schema; + if (stream->get_schema(stream, &schema) == DuckDBError) { + return DuckDBError; + } + + typedef void (*release_fn_t)(ArrowSchema *); + std::vector release_fns(schema.n_children); + for (int64_t i = 0; i < schema.n_children; i++) { + auto child = schema.children[i]; + release_fns[i] = child->release; + child->release = arrow_array_stream_wrapper::EmptySchemaRelease; + } + + auto ret = arrow_array_stream_wrapper::Ingest(connection, table_name, stream); + + // Restore release functions. + for (int64_t i = 0; i < schema.n_children; i++) { + schema.children[i]->release = release_fns[i]; + } + + return ret; +} + +duckdb_state duckdb_arrow_array_scan(duckdb_connection connection, const char *table_name, + duckdb_arrow_schema arrow_schema, duckdb_arrow_array arrow_array, + duckdb_arrow_stream *out_stream) { + auto private_data = new arrow_array_stream_wrapper::PrivateData; + private_data->schema = reinterpret_cast(arrow_schema); + private_data->array = reinterpret_cast(arrow_array); + private_data->done = false; + + ArrowArrayStream *stream = new ArrowArrayStream; + *out_stream = reinterpret_cast(stream); + stream->get_schema = arrow_array_stream_wrapper::GetSchema; + stream->get_next = arrow_array_stream_wrapper::GetNext; + stream->get_last_error = arrow_array_stream_wrapper::GetLastError; + stream->release = arrow_array_stream_wrapper::Release; + stream->private_data = private_data; + + return duckdb_arrow_scan(connection, table_name, reinterpret_cast(stream)); +} diff --git a/src/duckdb/src/main/capi/cast/from_decimal-c.cpp b/src/duckdb/src/main/capi/cast/from_decimal-c.cpp new file mode 100644 index 00000000..e6bc6f98 --- /dev/null +++ b/src/duckdb/src/main/capi/cast/from_decimal-c.cpp @@ -0,0 +1,120 @@ +#include "duckdb/main/capi/cast/from_decimal.hpp" +#include "duckdb/common/types/decimal.hpp" + +namespace duckdb { + +//! DECIMAL -> VARCHAR +template <> +bool CastDecimalCInternal(duckdb_result *source, duckdb_string &result, idx_t col, idx_t row) { + auto result_data = (duckdb::DuckDBResultData *)source->internal_data; + auto &query_result = result_data->result; + auto &source_type = query_result->types[col]; + auto width = duckdb::DecimalType::GetWidth(source_type); + auto scale = duckdb::DecimalType::GetScale(source_type); + duckdb::Vector result_vec(duckdb::LogicalType::VARCHAR, false, false); + duckdb::string_t result_string; + void *source_address = UnsafeFetchPtr(source, col, row); + switch (source_type.InternalType()) { + case duckdb::PhysicalType::INT16: + result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), + width, scale, result_vec); + break; + case duckdb::PhysicalType::INT32: + result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), + width, scale, result_vec); + break; + case duckdb::PhysicalType::INT64: + result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), + width, scale, result_vec); + break; + case duckdb::PhysicalType::INT128: + result_string = duckdb::StringCastFromDecimal::Operation( + UnsafeFetchFromPtr(source_address), width, scale, result_vec); + break; + default: + throw duckdb::InternalException("Unimplemented internal type for decimal"); + } + result.data = reinterpret_cast(duckdb_malloc(sizeof(char) * (result_string.GetSize() + 1))); + memcpy(result.data, result_string.GetData(), result_string.GetSize()); + result.data[result_string.GetSize()] = '\0'; + result.size = result_string.GetSize(); + return true; +} + +template +duckdb_hugeint FetchInternals(void *source_address) { + throw duckdb::NotImplementedException("FetchInternals not implemented for internal type"); +} + +template <> +duckdb_hugeint FetchInternals(void *source_address) { + duckdb_hugeint result; + int16_t intermediate_result; + + if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + intermediate_result = FetchDefaultValue::Operation(); + } + hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); + result.lower = hugeint_result.lower; + result.upper = hugeint_result.upper; + return result; +} +template <> +duckdb_hugeint FetchInternals(void *source_address) { + duckdb_hugeint result; + int32_t intermediate_result; + + if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + intermediate_result = FetchDefaultValue::Operation(); + } + hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); + result.lower = hugeint_result.lower; + result.upper = hugeint_result.upper; + return result; +} +template <> +duckdb_hugeint FetchInternals(void *source_address) { + duckdb_hugeint result; + int64_t intermediate_result; + + if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + intermediate_result = FetchDefaultValue::Operation(); + } + hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); + result.lower = hugeint_result.lower; + result.upper = hugeint_result.upper; + return result; +} +template <> +duckdb_hugeint FetchInternals(void *source_address) { + duckdb_hugeint result; + hugeint_t intermediate_result; + + if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + intermediate_result = FetchDefaultValue::Operation(); + } + result.lower = intermediate_result.lower; + result.upper = intermediate_result.upper; + return result; +} + +//! DECIMAL -> DECIMAL (internal fetch) +template <> +bool CastDecimalCInternal(duckdb_result *source, duckdb_decimal &result, idx_t col, idx_t row) { + auto result_data = (duckdb::DuckDBResultData *)source->internal_data; + result_data->result->types[col].GetDecimalProperties(result.width, result.scale); + auto source_address = UnsafeFetchPtr(source, col, row); + + if (result.width > duckdb::Decimal::MAX_WIDTH_INT64) { + result.value = FetchInternals(source_address); + } else if (result.width > duckdb::Decimal::MAX_WIDTH_INT32) { + result.value = FetchInternals(source_address); + } else if (result.width > duckdb::Decimal::MAX_WIDTH_INT16) { + result.value = FetchInternals(source_address); + } else { + result.value = FetchInternals(source_address); + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/capi/cast/utils-c.cpp b/src/duckdb/src/main/capi/cast/utils-c.cpp new file mode 100644 index 00000000..779d9600 --- /dev/null +++ b/src/duckdb/src/main/capi/cast/utils-c.cpp @@ -0,0 +1,98 @@ +#include "duckdb/main/capi/cast/utils.hpp" + +namespace duckdb { + +template <> +duckdb_decimal FetchDefaultValue::Operation() { + duckdb_decimal result; + result.scale = 0; + result.width = 0; + result.value = {0, 0}; + return result; +} + +template <> +date_t FetchDefaultValue::Operation() { + date_t result; + result.days = 0; + return result; +} + +template <> +dtime_t FetchDefaultValue::Operation() { + dtime_t result; + result.micros = 0; + return result; +} + +template <> +timestamp_t FetchDefaultValue::Operation() { + timestamp_t result; + result.value = 0; + return result; +} + +template <> +interval_t FetchDefaultValue::Operation() { + interval_t result; + result.months = 0; + result.days = 0; + result.micros = 0; + return result; +} + +template <> +char *FetchDefaultValue::Operation() { + return nullptr; +} + +template <> +duckdb_string FetchDefaultValue::Operation() { + duckdb_string result; + result.data = nullptr; + result.size = 0; + return result; +} + +template <> +duckdb_blob FetchDefaultValue::Operation() { + duckdb_blob result; + result.data = nullptr; + result.size = 0; + return result; +} + +//===--------------------------------------------------------------------===// +// Blob Casts +//===--------------------------------------------------------------------===// + +template <> +bool FromCBlobCastWrapper::Operation(duckdb_blob input, duckdb_string &result) { + string_t input_str(const_char_ptr_cast(input.data), input.size); + return ToCStringCastWrapper::template Operation(input_str, result); +} + +} // namespace duckdb + +bool CanUseDeprecatedFetch(duckdb_result *result, idx_t col, idx_t row) { + if (!result) { + return false; + } + if (!duckdb::deprecated_materialize_result(result)) { + return false; + } + if (col >= result->__deprecated_column_count || row >= result->__deprecated_row_count) { + return false; + } + return true; +} + +bool CanFetchValue(duckdb_result *result, idx_t col, idx_t row) { + if (!CanUseDeprecatedFetch(result, col, row)) { + return false; + } + if (result->__deprecated_columns[col].__deprecated_nullmask[row]) { + return false; + } + return true; +} diff --git a/src/duckdb/src/main/capi/config-c.cpp b/src/duckdb/src/main/capi/config-c.cpp new file mode 100644 index 00000000..67b4b46b --- /dev/null +++ b/src/duckdb/src/main/capi/config-c.cpp @@ -0,0 +1,64 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/common/types/value.hpp" + +using duckdb::DBConfig; +using duckdb::Value; + +// config +duckdb_state duckdb_create_config(duckdb_config *out_config) { + if (!out_config) { + return DuckDBError; + } + DBConfig *config; + try { + config = new DBConfig(); + } catch (...) { // LCOV_EXCL_START + return DuckDBError; + } // LCOV_EXCL_STOP + *out_config = reinterpret_cast(config); + return DuckDBSuccess; +} + +size_t duckdb_config_count() { + return DBConfig::GetOptionCount(); +} + +duckdb_state duckdb_get_config_flag(size_t index, const char **out_name, const char **out_description) { + auto option = DBConfig::GetOptionByIndex(index); + if (!option) { + return DuckDBError; + } + if (out_name) { + *out_name = option->name; + } + if (out_description) { + *out_description = option->description; + } + return DuckDBSuccess; +} + +duckdb_state duckdb_set_config(duckdb_config config, const char *name, const char *option) { + if (!config || !name || !option) { + return DuckDBError; + } + + try { + auto db_config = (DBConfig *)config; + db_config->SetOptionByName(name, Value(option)); + } catch (...) { + return DuckDBError; + } + return DuckDBSuccess; +} + +void duckdb_destroy_config(duckdb_config *config) { + if (!config) { + return; + } + if (*config) { + auto db_config = (DBConfig *)*config; + delete db_config; + *config = nullptr; + } +} diff --git a/src/duckdb/src/main/capi/data_chunk-c.cpp b/src/duckdb/src/main/capi/data_chunk-c.cpp new file mode 100644 index 00000000..b73a7a9d --- /dev/null +++ b/src/duckdb/src/main/capi/data_chunk-c.cpp @@ -0,0 +1,191 @@ +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/main/capi/capi_internal.hpp" + +#include + +duckdb_data_chunk duckdb_create_data_chunk(duckdb_logical_type *ctypes, idx_t column_count) { + if (!ctypes) { + return nullptr; + } + duckdb::vector types; + for (idx_t i = 0; i < column_count; i++) { + auto ltype = reinterpret_cast(ctypes[i]); + types.push_back(*ltype); + } + + auto result = new duckdb::DataChunk(); + result->Initialize(duckdb::Allocator::DefaultAllocator(), types); + return reinterpret_cast(result); +} + +void duckdb_destroy_data_chunk(duckdb_data_chunk *chunk) { + if (chunk && *chunk) { + auto dchunk = reinterpret_cast(*chunk); + delete dchunk; + *chunk = nullptr; + } +} + +void duckdb_data_chunk_reset(duckdb_data_chunk chunk) { + if (!chunk) { + return; + } + auto dchunk = reinterpret_cast(chunk); + dchunk->Reset(); +} + +idx_t duckdb_data_chunk_get_column_count(duckdb_data_chunk chunk) { + if (!chunk) { + return 0; + } + auto dchunk = reinterpret_cast(chunk); + return dchunk->ColumnCount(); +} + +duckdb_vector duckdb_data_chunk_get_vector(duckdb_data_chunk chunk, idx_t col_idx) { + if (!chunk || col_idx >= duckdb_data_chunk_get_column_count(chunk)) { + return nullptr; + } + auto dchunk = reinterpret_cast(chunk); + return reinterpret_cast(&dchunk->data[col_idx]); +} + +idx_t duckdb_data_chunk_get_size(duckdb_data_chunk chunk) { + if (!chunk) { + return 0; + } + auto dchunk = reinterpret_cast(chunk); + return dchunk->size(); +} + +void duckdb_data_chunk_set_size(duckdb_data_chunk chunk, idx_t size) { + if (!chunk) { + return; + } + auto dchunk = reinterpret_cast(chunk); + dchunk->SetCardinality(size); +} + +duckdb_logical_type duckdb_vector_get_column_type(duckdb_vector vector) { + if (!vector) { + return nullptr; + } + auto v = reinterpret_cast(vector); + return reinterpret_cast(new duckdb::LogicalType(v->GetType())); +} + +void *duckdb_vector_get_data(duckdb_vector vector) { + if (!vector) { + return nullptr; + } + auto v = reinterpret_cast(vector); + return duckdb::FlatVector::GetData(*v); +} + +uint64_t *duckdb_vector_get_validity(duckdb_vector vector) { + if (!vector) { + return nullptr; + } + auto v = reinterpret_cast(vector); + return duckdb::FlatVector::Validity(*v).GetData(); +} + +void duckdb_vector_ensure_validity_writable(duckdb_vector vector) { + if (!vector) { + return; + } + auto v = reinterpret_cast(vector); + auto &validity = duckdb::FlatVector::Validity(*v); + validity.EnsureWritable(); +} + +void duckdb_vector_assign_string_element(duckdb_vector vector, idx_t index, const char *str) { + duckdb_vector_assign_string_element_len(vector, index, str, strlen(str)); +} + +void duckdb_vector_assign_string_element_len(duckdb_vector vector, idx_t index, const char *str, idx_t str_len) { + if (!vector) { + return; + } + auto v = reinterpret_cast(vector); + auto data = duckdb::FlatVector::GetData(*v); + data[index] = duckdb::StringVector::AddString(*v, str, str_len); +} + +duckdb_vector duckdb_list_vector_get_child(duckdb_vector vector) { + if (!vector) { + return nullptr; + } + auto v = reinterpret_cast(vector); + return reinterpret_cast(&duckdb::ListVector::GetEntry(*v)); +} + +idx_t duckdb_list_vector_get_size(duckdb_vector vector) { + if (!vector) { + return 0; + } + auto v = reinterpret_cast(vector); + return duckdb::ListVector::GetListSize(*v); +} + +duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size) { + if (!vector) { + return duckdb_state::DuckDBError; + } + auto v = reinterpret_cast(vector); + duckdb::ListVector::SetListSize(*v, size); + return duckdb_state::DuckDBSuccess; +} + +duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity) { + if (!vector) { + return duckdb_state::DuckDBError; + } + auto v = reinterpret_cast(vector); + duckdb::ListVector::Reserve(*v, required_capacity); + return duckdb_state::DuckDBSuccess; +} + +duckdb_vector duckdb_struct_vector_get_child(duckdb_vector vector, idx_t index) { + if (!vector) { + return nullptr; + } + auto v = reinterpret_cast(vector); + return reinterpret_cast(duckdb::StructVector::GetEntries(*v)[index].get()); +} + +bool duckdb_validity_row_is_valid(uint64_t *validity, idx_t row) { + if (!validity) { + return true; + } + idx_t entry_idx = row / 64; + idx_t idx_in_entry = row % 64; + return validity[entry_idx] & ((idx_t)1 << idx_in_entry); +} + +void duckdb_validity_set_row_validity(uint64_t *validity, idx_t row, bool valid) { + if (valid) { + duckdb_validity_set_row_valid(validity, row); + } else { + duckdb_validity_set_row_invalid(validity, row); + } +} + +void duckdb_validity_set_row_invalid(uint64_t *validity, idx_t row) { + if (!validity) { + return; + } + idx_t entry_idx = row / 64; + idx_t idx_in_entry = row % 64; + validity[entry_idx] &= ~((uint64_t)1 << idx_in_entry); +} + +void duckdb_validity_set_row_valid(uint64_t *validity, idx_t row) { + if (!validity) { + return; + } + idx_t entry_idx = row / 64; + idx_t idx_in_entry = row % 64; + validity[entry_idx] |= (uint64_t)1 << idx_in_entry; +} diff --git a/src/duckdb/src/main/capi/datetime-c.cpp b/src/duckdb/src/main/capi/datetime-c.cpp new file mode 100644 index 00000000..1a3390eb --- /dev/null +++ b/src/duckdb/src/main/capi/datetime-c.cpp @@ -0,0 +1,73 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" + +using duckdb::Date; +using duckdb::Time; +using duckdb::Timestamp; + +using duckdb::date_t; +using duckdb::dtime_t; +using duckdb::timestamp_t; + +duckdb_date_struct duckdb_from_date(duckdb_date date) { + int32_t year, month, day; + Date::Convert(date_t(date.days), year, month, day); + + duckdb_date_struct result; + result.year = year; + result.month = month; + result.day = day; + return result; +} + +duckdb_date duckdb_to_date(duckdb_date_struct date) { + duckdb_date result; + result.days = Date::FromDate(date.year, date.month, date.day).days; + return result; +} + +duckdb_time_struct duckdb_from_time(duckdb_time time) { + int32_t hour, minute, second, micros; + Time::Convert(dtime_t(time.micros), hour, minute, second, micros); + + duckdb_time_struct result; + result.hour = hour; + result.min = minute; + result.sec = second; + result.micros = micros; + return result; +} + +duckdb_time duckdb_to_time(duckdb_time_struct time) { + duckdb_time result; + result.micros = Time::FromTime(time.hour, time.min, time.sec, time.micros).micros; + return result; +} + +duckdb_timestamp_struct duckdb_from_timestamp(duckdb_timestamp ts) { + date_t date; + dtime_t time; + Timestamp::Convert(timestamp_t(ts.micros), date, time); + + duckdb_date ddate; + ddate.days = date.days; + + duckdb_time dtime; + dtime.micros = time.micros; + + duckdb_timestamp_struct result; + result.date = duckdb_from_date(ddate); + result.time = duckdb_from_time(dtime); + return result; +} + +duckdb_timestamp duckdb_to_timestamp(duckdb_timestamp_struct ts) { + date_t date = date_t(duckdb_to_date(ts.date).days); + dtime_t time = dtime_t(duckdb_to_time(ts.time).micros); + + duckdb_timestamp result; + result.micros = Timestamp::FromDatetime(date, time).value; + return result; +} diff --git a/src/duckdb/src/main/capi/duckdb-c.cpp b/src/duckdb/src/main/capi/duckdb-c.cpp new file mode 100644 index 00000000..5377df46 --- /dev/null +++ b/src/duckdb/src/main/capi/duckdb-c.cpp @@ -0,0 +1,89 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +using duckdb::Connection; +using duckdb::DatabaseData; +using duckdb::DBConfig; +using duckdb::DuckDB; + +duckdb_state duckdb_open_ext(const char *path, duckdb_database *out, duckdb_config config, char **error) { + auto wrapper = new DatabaseData(); + try { + auto db_config = (DBConfig *)config; + wrapper->database = duckdb::make_uniq(path, db_config); + } catch (std::exception &ex) { + if (error) { + *error = strdup(ex.what()); + } + delete wrapper; + return DuckDBError; + } catch (...) { // LCOV_EXCL_START + if (error) { + *error = strdup("Unknown error"); + } + delete wrapper; + return DuckDBError; + } // LCOV_EXCL_STOP + *out = (duckdb_database)wrapper; + return DuckDBSuccess; +} + +duckdb_state duckdb_open(const char *path, duckdb_database *out) { + return duckdb_open_ext(path, out, nullptr, nullptr); +} + +void duckdb_close(duckdb_database *database) { + if (database && *database) { + auto wrapper = reinterpret_cast(*database); + delete wrapper; + *database = nullptr; + } +} + +duckdb_state duckdb_connect(duckdb_database database, duckdb_connection *out) { + if (!database || !out) { + return DuckDBError; + } + auto wrapper = reinterpret_cast(database); + Connection *connection; + try { + connection = new Connection(*wrapper->database); + } catch (...) { // LCOV_EXCL_START + return DuckDBError; + } // LCOV_EXCL_STOP + *out = (duckdb_connection)connection; + return DuckDBSuccess; +} + +void duckdb_interrupt(duckdb_connection connection) { + if (!connection) { + return; + } + Connection *conn = reinterpret_cast(connection); + conn->Interrupt(); +} + +double duckdb_query_progress(duckdb_connection connection) { + if (!connection) { + return -1; + } + Connection *conn = reinterpret_cast(connection); + return conn->context->GetProgress(); +} + +void duckdb_disconnect(duckdb_connection *connection) { + if (connection && *connection) { + Connection *conn = reinterpret_cast(*connection); + delete conn; + *connection = nullptr; + } +} + +duckdb_state duckdb_query(duckdb_connection connection, const char *query, duckdb_result *out) { + Connection *conn = reinterpret_cast(connection); + auto result = conn->Query(query); + return duckdb_translate_result(std::move(result), out); +} + +const char *duckdb_library_version() { + return DuckDB::LibraryVersion(); +} diff --git a/src/duckdb/src/main/capi/duckdb_value-c.cpp b/src/duckdb/src/main/capi/duckdb_value-c.cpp new file mode 100644 index 00000000..8a63a75e --- /dev/null +++ b/src/duckdb/src/main/capi/duckdb_value-c.cpp @@ -0,0 +1,41 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +void duckdb_destroy_value(duckdb_value *value) { + if (value && *value) { + auto val = reinterpret_cast(*value); + delete val; + *value = nullptr; + } +} + +duckdb_value duckdb_create_varchar_length(const char *text, idx_t length) { + return reinterpret_cast(new duckdb::Value(std::string(text, length))); +} + +duckdb_value duckdb_create_varchar(const char *text) { + return duckdb_create_varchar_length(text, strlen(text)); +} + +duckdb_value duckdb_create_int64(int64_t input) { + auto val = duckdb::Value::BIGINT(input); + return reinterpret_cast(new duckdb::Value(val)); +} + +char *duckdb_get_varchar(duckdb_value value) { + auto val = reinterpret_cast(value); + auto str_val = val->DefaultCastAs(duckdb::LogicalType::VARCHAR); + auto &str = duckdb::StringValue::Get(str_val); + + auto result = reinterpret_cast(malloc(sizeof(char) * (str.size() + 1))); + memcpy(result, str.c_str(), str.size()); + result[str.size()] = '\0'; + return result; +} + +int64_t duckdb_get_int64(duckdb_value value) { + auto val = reinterpret_cast(value); + if (!val->DefaultTryCastAs(duckdb::LogicalType::BIGINT)) { + return 0; + } + return duckdb::BigIntValue::Get(*val); +} diff --git a/src/duckdb/src/main/capi/helper-c.cpp b/src/duckdb/src/main/capi/helper-c.cpp new file mode 100644 index 00000000..cf8001e1 --- /dev/null +++ b/src/duckdb/src/main/capi/helper-c.cpp @@ -0,0 +1,195 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +namespace duckdb { + +LogicalTypeId ConvertCTypeToCPP(duckdb_type c_type) { + switch (c_type) { + case DUCKDB_TYPE_BOOLEAN: + return LogicalTypeId::BOOLEAN; + case DUCKDB_TYPE_TINYINT: + return LogicalTypeId::TINYINT; + case DUCKDB_TYPE_SMALLINT: + return LogicalTypeId::SMALLINT; + case DUCKDB_TYPE_INTEGER: + return LogicalTypeId::INTEGER; + case DUCKDB_TYPE_BIGINT: + return LogicalTypeId::BIGINT; + case DUCKDB_TYPE_UTINYINT: + return LogicalTypeId::UTINYINT; + case DUCKDB_TYPE_USMALLINT: + return LogicalTypeId::USMALLINT; + case DUCKDB_TYPE_UINTEGER: + return LogicalTypeId::UINTEGER; + case DUCKDB_TYPE_UBIGINT: + return LogicalTypeId::UBIGINT; + case DUCKDB_TYPE_HUGEINT: + return LogicalTypeId::HUGEINT; + case DUCKDB_TYPE_FLOAT: + return LogicalTypeId::FLOAT; + case DUCKDB_TYPE_DOUBLE: + return LogicalTypeId::DOUBLE; + case DUCKDB_TYPE_TIMESTAMP: + return LogicalTypeId::TIMESTAMP; + case DUCKDB_TYPE_DATE: + return LogicalTypeId::DATE; + case DUCKDB_TYPE_TIME: + return LogicalTypeId::TIME; + case DUCKDB_TYPE_VARCHAR: + return LogicalTypeId::VARCHAR; + case DUCKDB_TYPE_BLOB: + return LogicalTypeId::BLOB; + case DUCKDB_TYPE_INTERVAL: + return LogicalTypeId::INTERVAL; + case DUCKDB_TYPE_TIMESTAMP_S: + return LogicalTypeId::TIMESTAMP_SEC; + case DUCKDB_TYPE_TIMESTAMP_MS: + return LogicalTypeId::TIMESTAMP_MS; + case DUCKDB_TYPE_TIMESTAMP_NS: + return LogicalTypeId::TIMESTAMP_NS; + case DUCKDB_TYPE_UUID: + return LogicalTypeId::UUID; + default: // LCOV_EXCL_START + D_ASSERT(0); + return LogicalTypeId::INVALID; + } // LCOV_EXCL_STOP +} + +duckdb_type ConvertCPPTypeToC(const LogicalType &sql_type) { + switch (sql_type.id()) { + case LogicalTypeId::BOOLEAN: + return DUCKDB_TYPE_BOOLEAN; + case LogicalTypeId::TINYINT: + return DUCKDB_TYPE_TINYINT; + case LogicalTypeId::SMALLINT: + return DUCKDB_TYPE_SMALLINT; + case LogicalTypeId::INTEGER: + return DUCKDB_TYPE_INTEGER; + case LogicalTypeId::BIGINT: + return DUCKDB_TYPE_BIGINT; + case LogicalTypeId::UTINYINT: + return DUCKDB_TYPE_UTINYINT; + case LogicalTypeId::USMALLINT: + return DUCKDB_TYPE_USMALLINT; + case LogicalTypeId::UINTEGER: + return DUCKDB_TYPE_UINTEGER; + case LogicalTypeId::UBIGINT: + return DUCKDB_TYPE_UBIGINT; + case LogicalTypeId::HUGEINT: + return DUCKDB_TYPE_HUGEINT; + case LogicalTypeId::FLOAT: + return DUCKDB_TYPE_FLOAT; + case LogicalTypeId::DOUBLE: + return DUCKDB_TYPE_DOUBLE; + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return DUCKDB_TYPE_TIMESTAMP; + case LogicalTypeId::TIMESTAMP_SEC: + return DUCKDB_TYPE_TIMESTAMP_S; + case LogicalTypeId::TIMESTAMP_MS: + return DUCKDB_TYPE_TIMESTAMP_MS; + case LogicalTypeId::TIMESTAMP_NS: + return DUCKDB_TYPE_TIMESTAMP_NS; + case LogicalTypeId::DATE: + return DUCKDB_TYPE_DATE; + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return DUCKDB_TYPE_TIME; + case LogicalTypeId::VARCHAR: + return DUCKDB_TYPE_VARCHAR; + case LogicalTypeId::BLOB: + return DUCKDB_TYPE_BLOB; + case LogicalTypeId::BIT: + return DUCKDB_TYPE_BIT; + case LogicalTypeId::INTERVAL: + return DUCKDB_TYPE_INTERVAL; + case LogicalTypeId::DECIMAL: + return DUCKDB_TYPE_DECIMAL; + case LogicalTypeId::ENUM: + return DUCKDB_TYPE_ENUM; + case LogicalTypeId::LIST: + return DUCKDB_TYPE_LIST; + case LogicalTypeId::STRUCT: + return DUCKDB_TYPE_STRUCT; + case LogicalTypeId::MAP: + return DUCKDB_TYPE_MAP; + case LogicalTypeId::UNION: + return DUCKDB_TYPE_UNION; + case LogicalTypeId::UUID: + return DUCKDB_TYPE_UUID; + default: // LCOV_EXCL_START + D_ASSERT(0); + return DUCKDB_TYPE_INVALID; + } // LCOV_EXCL_STOP +} + +idx_t GetCTypeSize(duckdb_type type) { + switch (type) { + case DUCKDB_TYPE_BOOLEAN: + return sizeof(bool); + case DUCKDB_TYPE_TINYINT: + return sizeof(int8_t); + case DUCKDB_TYPE_SMALLINT: + return sizeof(int16_t); + case DUCKDB_TYPE_INTEGER: + return sizeof(int32_t); + case DUCKDB_TYPE_BIGINT: + return sizeof(int64_t); + case DUCKDB_TYPE_UTINYINT: + return sizeof(uint8_t); + case DUCKDB_TYPE_USMALLINT: + return sizeof(uint16_t); + case DUCKDB_TYPE_UINTEGER: + return sizeof(uint32_t); + case DUCKDB_TYPE_UBIGINT: + return sizeof(uint64_t); + case DUCKDB_TYPE_HUGEINT: + case DUCKDB_TYPE_UUID: + return sizeof(duckdb_hugeint); + case DUCKDB_TYPE_FLOAT: + return sizeof(float); + case DUCKDB_TYPE_DOUBLE: + return sizeof(double); + case DUCKDB_TYPE_DATE: + return sizeof(duckdb_date); + case DUCKDB_TYPE_TIME: + return sizeof(duckdb_time); + case DUCKDB_TYPE_TIMESTAMP: + case DUCKDB_TYPE_TIMESTAMP_S: + case DUCKDB_TYPE_TIMESTAMP_MS: + case DUCKDB_TYPE_TIMESTAMP_NS: + return sizeof(duckdb_timestamp); + case DUCKDB_TYPE_VARCHAR: + return sizeof(const char *); + case DUCKDB_TYPE_BLOB: + return sizeof(duckdb_blob); + case DUCKDB_TYPE_INTERVAL: + return sizeof(duckdb_interval); + case DUCKDB_TYPE_DECIMAL: + return sizeof(duckdb_hugeint); + default: // LCOV_EXCL_START + // unsupported type + D_ASSERT(0); + return sizeof(const char *); + } // LCOV_EXCL_STOP +} + +} // namespace duckdb + +void *duckdb_malloc(size_t size) { + return malloc(size); +} + +void duckdb_free(void *ptr) { + free(ptr); +} + +idx_t duckdb_vector_size() { + return STANDARD_VECTOR_SIZE; +} + +bool duckdb_string_is_inlined(duckdb_string_t string_p) { + static_assert(sizeof(duckdb_string_t) == sizeof(duckdb::string_t), + "duckdb_string_t should have the same memory layout as duckdb::string_t"); + auto &string = *(duckdb::string_t *)(&string_p); + return string.IsInlined(); +} diff --git a/src/duckdb/src/main/capi/hugeint-c.cpp b/src/duckdb/src/main/capi/hugeint-c.cpp new file mode 100644 index 00000000..337ec498 --- /dev/null +++ b/src/duckdb/src/main/capi/hugeint-c.cpp @@ -0,0 +1,59 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/operator/decimal_cast_operators.hpp" +#include "duckdb/main/capi/cast/utils.hpp" +#include "duckdb/main/capi/cast/to_decimal.hpp" + +using duckdb::Hugeint; +using duckdb::hugeint_t; +using duckdb::Value; + +double duckdb_hugeint_to_double(duckdb_hugeint val) { + hugeint_t internal; + internal.lower = val.lower; + internal.upper = val.upper; + return Hugeint::Cast(internal); +} + +static duckdb_decimal to_decimal_cast(double val, uint8_t width, uint8_t scale) { + if (width > duckdb::Decimal::MAX_WIDTH_INT64) { + return duckdb::TryCastToDecimalCInternal>(val, width, scale); + } + if (width > duckdb::Decimal::MAX_WIDTH_INT32) { + return duckdb::TryCastToDecimalCInternal>(val, width, scale); + } + if (width > duckdb::Decimal::MAX_WIDTH_INT16) { + return duckdb::TryCastToDecimalCInternal>(val, width, scale); + } + return duckdb::TryCastToDecimalCInternal>(val, width, scale); +} + +duckdb_decimal duckdb_double_to_decimal(double val, uint8_t width, uint8_t scale) { + if (scale > width || width > duckdb::Decimal::MAX_WIDTH_INT128) { + return duckdb::FetchDefaultValue::Operation(); + } + return to_decimal_cast(val, width, scale); +} + +duckdb_hugeint duckdb_double_to_hugeint(double val) { + hugeint_t internal_result; + if (!Value::DoubleIsFinite(val) || !Hugeint::TryConvert(val, internal_result)) { + internal_result.lower = 0; + internal_result.upper = 0; + } + + duckdb_hugeint result; + result.lower = internal_result.lower; + result.upper = internal_result.upper; + return result; +} + +double duckdb_decimal_to_double(duckdb_decimal val) { + double result; + hugeint_t value; + value.lower = val.value.lower; + value.upper = val.value.upper; + duckdb::TryCastFromDecimal::Operation(value, result, nullptr, val.width, val.scale); + return result; +} diff --git a/src/duckdb/src/main/capi/logical_types-c.cpp b/src/duckdb/src/main/capi/logical_types-c.cpp new file mode 100644 index 00000000..83906f68 --- /dev/null +++ b/src/duckdb/src/main/capi/logical_types-c.cpp @@ -0,0 +1,270 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +static bool AssertLogicalTypeId(duckdb_logical_type type, duckdb::LogicalTypeId type_id) { + if (!type) { + return false; + } + auto <ype = *(reinterpret_cast(type)); + if (ltype.id() != type_id) { + return false; + } + return true; +} + +static bool AssertInternalType(duckdb_logical_type type, duckdb::PhysicalType physical_type) { + if (!type) { + return false; + } + auto <ype = *(reinterpret_cast(type)); + if (ltype.InternalType() != physical_type) { + return false; + } + return true; +} + +duckdb_logical_type duckdb_create_logical_type(duckdb_type type) { + return reinterpret_cast(new duckdb::LogicalType(duckdb::ConvertCTypeToCPP(type))); +} + +duckdb_logical_type duckdb_create_list_type(duckdb_logical_type type) { + if (!type) { + return nullptr; + } + duckdb::LogicalType *ltype = new duckdb::LogicalType; + *ltype = duckdb::LogicalType::LIST(*reinterpret_cast(type)); + return reinterpret_cast(ltype); +} + +duckdb_logical_type duckdb_create_union_type(duckdb_logical_type member_types_p, const char **member_names, + idx_t member_count) { + if (!member_types_p || !member_names) { + return nullptr; + } + duckdb::LogicalType *member_types = reinterpret_cast(member_types_p); + duckdb::LogicalType *mtype = new duckdb::LogicalType; + duckdb::child_list_t members; + + for (idx_t i = 0; i < member_count; i++) { + members.push_back(make_pair(member_names[i], member_types[i])); + } + *mtype = duckdb::LogicalType::UNION(members); + return reinterpret_cast(mtype); +} + +duckdb_logical_type duckdb_create_struct_type(duckdb_logical_type *member_types_p, const char **member_names, + idx_t member_count) { + if (!member_types_p || !member_names) { + return nullptr; + } + duckdb::LogicalType **member_types = (duckdb::LogicalType **)member_types_p; + for (idx_t i = 0; i < member_count; i++) { + if (!member_names[i] || !member_types[i]) { + return nullptr; + } + } + + duckdb::LogicalType *mtype = new duckdb::LogicalType; + duckdb::child_list_t members; + + for (idx_t i = 0; i < member_count; i++) { + members.push_back(make_pair(member_names[i], *member_types[i])); + } + *mtype = duckdb::LogicalType::STRUCT(members); + return reinterpret_cast(mtype); +} + +duckdb_logical_type duckdb_create_map_type(duckdb_logical_type key_type, duckdb_logical_type value_type) { + if (!key_type || !value_type) { + return nullptr; + } + duckdb::LogicalType *mtype = new duckdb::LogicalType; + *mtype = duckdb::LogicalType::MAP(*reinterpret_cast(key_type), + *reinterpret_cast(value_type)); + return reinterpret_cast(mtype); +} + +duckdb_logical_type duckdb_create_decimal_type(uint8_t width, uint8_t scale) { + return reinterpret_cast(new duckdb::LogicalType(duckdb::LogicalType::DECIMAL(width, scale))); +} + +duckdb_type duckdb_get_type_id(duckdb_logical_type type) { + if (!type) { + return DUCKDB_TYPE_INVALID; + } + auto ltype = reinterpret_cast(type); + return duckdb::ConvertCPPTypeToC(*ltype); +} + +void duckdb_destroy_logical_type(duckdb_logical_type *type) { + if (type && *type) { + auto ltype = reinterpret_cast(*type); + delete ltype; + *type = nullptr; + } +} + +uint8_t duckdb_decimal_width(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::DECIMAL)) { + return 0; + } + auto <ype = *(reinterpret_cast(type)); + return duckdb::DecimalType::GetWidth(ltype); +} + +uint8_t duckdb_decimal_scale(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::DECIMAL)) { + return 0; + } + auto <ype = *(reinterpret_cast(type)); + return duckdb::DecimalType::GetScale(ltype); +} + +duckdb_type duckdb_decimal_internal_type(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::DECIMAL)) { + return DUCKDB_TYPE_INVALID; + } + auto <ype = *(reinterpret_cast(type)); + switch (ltype.InternalType()) { + case duckdb::PhysicalType::INT16: + return DUCKDB_TYPE_SMALLINT; + case duckdb::PhysicalType::INT32: + return DUCKDB_TYPE_INTEGER; + case duckdb::PhysicalType::INT64: + return DUCKDB_TYPE_BIGINT; + case duckdb::PhysicalType::INT128: + return DUCKDB_TYPE_HUGEINT; + default: + return DUCKDB_TYPE_INVALID; + } +} + +duckdb_type duckdb_enum_internal_type(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::ENUM)) { + return DUCKDB_TYPE_INVALID; + } + auto <ype = *(reinterpret_cast(type)); + switch (ltype.InternalType()) { + case duckdb::PhysicalType::UINT8: + return DUCKDB_TYPE_UTINYINT; + case duckdb::PhysicalType::UINT16: + return DUCKDB_TYPE_USMALLINT; + case duckdb::PhysicalType::UINT32: + return DUCKDB_TYPE_UINTEGER; + default: + return DUCKDB_TYPE_INVALID; + } +} + +uint32_t duckdb_enum_dictionary_size(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::ENUM)) { + return 0; + } + auto <ype = *(reinterpret_cast(type)); + return duckdb::EnumType::GetSize(ltype); +} + +char *duckdb_enum_dictionary_value(duckdb_logical_type type, idx_t index) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::ENUM)) { + return nullptr; + } + auto <ype = *(reinterpret_cast(type)); + auto &vector = duckdb::EnumType::GetValuesInsertOrder(ltype); + auto value = vector.GetValue(index); + return strdup(duckdb::StringValue::Get(value).c_str()); +} + +duckdb_logical_type duckdb_list_type_child_type(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::LIST) && + !AssertLogicalTypeId(type, duckdb::LogicalTypeId::MAP)) { + return nullptr; + } + auto <ype = *(reinterpret_cast(type)); + if (ltype.id() != duckdb::LogicalTypeId::LIST && ltype.id() != duckdb::LogicalTypeId::MAP) { + return nullptr; + } + return reinterpret_cast(new duckdb::LogicalType(duckdb::ListType::GetChildType(ltype))); +} + +duckdb_logical_type duckdb_map_type_key_type(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::MAP)) { + return nullptr; + } + auto &mtype = *(reinterpret_cast(type)); + if (mtype.id() != duckdb::LogicalTypeId::MAP) { + return nullptr; + } + return reinterpret_cast(new duckdb::LogicalType(duckdb::MapType::KeyType(mtype))); +} + +duckdb_logical_type duckdb_map_type_value_type(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::MAP)) { + return nullptr; + } + auto &mtype = *(reinterpret_cast(type)); + if (mtype.id() != duckdb::LogicalTypeId::MAP) { + return nullptr; + } + return reinterpret_cast(new duckdb::LogicalType(duckdb::MapType::ValueType(mtype))); +} + +idx_t duckdb_struct_type_child_count(duckdb_logical_type type) { + if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { + return 0; + } + auto <ype = *(reinterpret_cast(type)); + return duckdb::StructType::GetChildCount(ltype); +} + +idx_t duckdb_union_type_member_count(duckdb_logical_type type) { + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::UNION)) { + return 0; + } + idx_t member_count = duckdb_struct_type_child_count(type); + if (member_count != 0) { + member_count--; + } + return member_count; +} + +char *duckdb_union_type_member_name(duckdb_logical_type type, idx_t index) { + if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { + return nullptr; + } + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::UNION)) { + return nullptr; + } + auto <ype = *(reinterpret_cast(type)); + return strdup(duckdb::UnionType::GetMemberName(ltype, index).c_str()); +} + +duckdb_logical_type duckdb_union_type_member_type(duckdb_logical_type type, idx_t index) { + if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { + return nullptr; + } + if (!AssertLogicalTypeId(type, duckdb::LogicalTypeId::UNION)) { + return nullptr; + } + auto <ype = *(reinterpret_cast(type)); + return reinterpret_cast( + new duckdb::LogicalType(duckdb::UnionType::GetMemberType(ltype, index))); +} + +char *duckdb_struct_type_child_name(duckdb_logical_type type, idx_t index) { + if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { + return nullptr; + } + auto <ype = *(reinterpret_cast(type)); + return strdup(duckdb::StructType::GetChildName(ltype, index).c_str()); +} + +duckdb_logical_type duckdb_struct_type_child_type(duckdb_logical_type type, idx_t index) { + if (!AssertInternalType(type, duckdb::PhysicalType::STRUCT)) { + return nullptr; + } + auto <ype = *(reinterpret_cast(type)); + if (ltype.InternalType() != duckdb::PhysicalType::STRUCT) { + return nullptr; + } + return reinterpret_cast( + new duckdb::LogicalType(duckdb::StructType::GetChildType(ltype, index))); +} diff --git a/src/duckdb/src/main/capi/pending-c.cpp b/src/duckdb/src/main/capi/pending-c.cpp new file mode 100644 index 00000000..30051390 --- /dev/null +++ b/src/duckdb/src/main/capi/pending-c.cpp @@ -0,0 +1,132 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/pending_query_result.hpp" +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/optional_ptr.hpp" + +using duckdb::case_insensitive_map_t; +using duckdb::make_uniq; +using duckdb::optional_ptr; +using duckdb::PendingExecutionResult; +using duckdb::PendingQueryResult; +using duckdb::PendingStatementWrapper; +using duckdb::PreparedStatementWrapper; +using duckdb::Value; + +duckdb_state duckdb_pending_prepared_internal(duckdb_prepared_statement prepared_statement, + duckdb_pending_result *out_result, bool allow_streaming) { + if (!prepared_statement || !out_result) { + return DuckDBError; + } + auto wrapper = reinterpret_cast(prepared_statement); + auto result = new PendingStatementWrapper(); + result->allow_streaming = allow_streaming; + + try { + result->statement = wrapper->statement->PendingQuery(wrapper->values, allow_streaming); + } catch (const duckdb::Exception &ex) { + result->statement = make_uniq(duckdb::PreservedError(ex)); + } catch (std::exception &ex) { + result->statement = make_uniq(duckdb::PreservedError(ex)); + } + duckdb_state return_value = !result->statement->HasError() ? DuckDBSuccess : DuckDBError; + *out_result = reinterpret_cast(result); + + return return_value; +} + +duckdb_state duckdb_pending_prepared(duckdb_prepared_statement prepared_statement, duckdb_pending_result *out_result) { + return duckdb_pending_prepared_internal(prepared_statement, out_result, false); +} + +duckdb_state duckdb_pending_prepared_streaming(duckdb_prepared_statement prepared_statement, + duckdb_pending_result *out_result) { + return duckdb_pending_prepared_internal(prepared_statement, out_result, true); +} + +void duckdb_destroy_pending(duckdb_pending_result *pending_result) { + if (!pending_result || !*pending_result) { + return; + } + auto wrapper = reinterpret_cast(*pending_result); + if (wrapper->statement) { + wrapper->statement->Close(); + } + delete wrapper; + *pending_result = nullptr; +} + +const char *duckdb_pending_error(duckdb_pending_result pending_result) { + if (!pending_result) { + return nullptr; + } + auto wrapper = reinterpret_cast(pending_result); + if (!wrapper->statement) { + return nullptr; + } + return wrapper->statement->GetError().c_str(); +} + +duckdb_pending_state duckdb_pending_execute_task(duckdb_pending_result pending_result) { + if (!pending_result) { + return DUCKDB_PENDING_ERROR; + } + auto wrapper = reinterpret_cast(pending_result); + if (!wrapper->statement) { + return DUCKDB_PENDING_ERROR; + } + if (wrapper->statement->HasError()) { + return DUCKDB_PENDING_ERROR; + } + PendingExecutionResult return_value; + try { + return_value = wrapper->statement->ExecuteTask(); + } catch (const duckdb::Exception &ex) { + wrapper->statement->SetError(duckdb::PreservedError(ex)); + return DUCKDB_PENDING_ERROR; + } catch (std::exception &ex) { + wrapper->statement->SetError(duckdb::PreservedError(ex)); + return DUCKDB_PENDING_ERROR; + } + switch (return_value) { + case PendingExecutionResult::RESULT_READY: + return DUCKDB_PENDING_RESULT_READY; + case PendingExecutionResult::NO_TASKS_AVAILABLE: + return DUCKDB_PENDING_NO_TASKS_AVAILABLE; + case PendingExecutionResult::RESULT_NOT_READY: + return DUCKDB_PENDING_RESULT_NOT_READY; + default: + return DUCKDB_PENDING_ERROR; + } +} + +bool duckdb_pending_execution_is_finished(duckdb_pending_state pending_state) { + switch (pending_state) { + case DUCKDB_PENDING_RESULT_READY: + return PendingQueryResult::IsFinished(PendingExecutionResult::RESULT_READY); + case DUCKDB_PENDING_NO_TASKS_AVAILABLE: + return PendingQueryResult::IsFinished(PendingExecutionResult::NO_TASKS_AVAILABLE); + case DUCKDB_PENDING_RESULT_NOT_READY: + return PendingQueryResult::IsFinished(PendingExecutionResult::RESULT_NOT_READY); + case DUCKDB_PENDING_ERROR: + return PendingQueryResult::IsFinished(PendingExecutionResult::EXECUTION_ERROR); + default: + return PendingQueryResult::IsFinished(PendingExecutionResult::EXECUTION_ERROR); + } +} + +duckdb_state duckdb_execute_pending(duckdb_pending_result pending_result, duckdb_result *out_result) { + if (!pending_result || !out_result) { + return DuckDBError; + } + auto wrapper = reinterpret_cast(pending_result); + if (!wrapper->statement) { + return DuckDBError; + } + + duckdb::unique_ptr result; + result = wrapper->statement->Execute(); + wrapper->statement.reset(); + return duckdb_translate_result(std::move(result), out_result); +} diff --git a/src/duckdb/src/main/capi/prepared-c.cpp b/src/duckdb/src/main/capi/prepared-c.cpp new file mode 100644 index 00000000..82b488b7 --- /dev/null +++ b/src/duckdb/src/main/capi/prepared-c.cpp @@ -0,0 +1,342 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +using duckdb::case_insensitive_map_t; +using duckdb::Connection; +using duckdb::date_t; +using duckdb::dtime_t; +using duckdb::ExtractStatementsWrapper; +using duckdb::hugeint_t; +using duckdb::LogicalType; +using duckdb::MaterializedQueryResult; +using duckdb::optional_ptr; +using duckdb::PreparedStatementWrapper; +using duckdb::QueryResultType; +using duckdb::StringUtil; +using duckdb::timestamp_t; +using duckdb::Value; + +idx_t duckdb_extract_statements(duckdb_connection connection, const char *query, + duckdb_extracted_statements *out_extracted_statements) { + if (!connection || !query || !out_extracted_statements) { + return 0; + } + auto wrapper = new ExtractStatementsWrapper(); + Connection *conn = reinterpret_cast(connection); + try { + wrapper->statements = conn->ExtractStatements(query); + } catch (const duckdb::ParserException &e) { + wrapper->error = e.what(); + } + + *out_extracted_statements = (duckdb_extracted_statements)wrapper; + return wrapper->statements.size(); +} + +duckdb_state duckdb_prepare_extracted_statement(duckdb_connection connection, + duckdb_extracted_statements extracted_statements, idx_t index, + duckdb_prepared_statement *out_prepared_statement) { + Connection *conn = reinterpret_cast(connection); + auto source_wrapper = (ExtractStatementsWrapper *)extracted_statements; + + if (!connection || !out_prepared_statement || index >= source_wrapper->statements.size()) { + return DuckDBError; + } + auto wrapper = new PreparedStatementWrapper(); + wrapper->statement = conn->Prepare(std::move(source_wrapper->statements[index])); + + *out_prepared_statement = (duckdb_prepared_statement)wrapper; + return wrapper->statement->HasError() ? DuckDBError : DuckDBSuccess; +} + +const char *duckdb_extract_statements_error(duckdb_extracted_statements extracted_statements) { + auto wrapper = (ExtractStatementsWrapper *)extracted_statements; + if (!wrapper || wrapper->error.empty()) { + return nullptr; + } + return wrapper->error.c_str(); +} + +duckdb_state duckdb_prepare(duckdb_connection connection, const char *query, + duckdb_prepared_statement *out_prepared_statement) { + if (!connection || !query || !out_prepared_statement) { + return DuckDBError; + } + auto wrapper = new PreparedStatementWrapper(); + Connection *conn = reinterpret_cast(connection); + wrapper->statement = conn->Prepare(query); + *out_prepared_statement = (duckdb_prepared_statement)wrapper; + return !wrapper->statement->HasError() ? DuckDBSuccess : DuckDBError; +} + +const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement) { + auto wrapper = reinterpret_cast(prepared_statement); + if (!wrapper || !wrapper->statement || !wrapper->statement->HasError()) { + return nullptr; + } + return wrapper->statement->error.Message().c_str(); +} + +idx_t duckdb_nparams(duckdb_prepared_statement prepared_statement) { + auto wrapper = reinterpret_cast(prepared_statement); + if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { + return 0; + } + return wrapper->statement->n_param; +} + +static duckdb::string duckdb_parameter_name_internal(duckdb_prepared_statement prepared_statement, idx_t index) { + auto wrapper = (PreparedStatementWrapper *)prepared_statement; + if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { + return duckdb::string(); + } + if (index > wrapper->statement->n_param) { + return duckdb::string(); + } + for (auto &item : wrapper->statement->named_param_map) { + auto &identifier = item.first; + auto ¶m_idx = item.second; + if (param_idx == index) { + // Found the matching parameter + return identifier; + } + } + // No parameter was found with this index + return duckdb::string(); +} + +const char *duckdb_parameter_name(duckdb_prepared_statement prepared_statement, idx_t index) { + auto identifier = duckdb_parameter_name_internal(prepared_statement, index); + if (identifier == duckdb::string()) { + return NULL; + } + return strdup(identifier.c_str()); +} + +duckdb_type duckdb_param_type(duckdb_prepared_statement prepared_statement, idx_t param_idx) { + auto wrapper = reinterpret_cast(prepared_statement); + if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { + return DUCKDB_TYPE_INVALID; + } + LogicalType param_type; + auto identifier = std::to_string(param_idx); + if (wrapper->statement->data->TryGetType(identifier, param_type)) { + return ConvertCPPTypeToC(param_type); + } + // The value_map is gone after executing the prepared statement + // See if this is the case and we still have a value registered for it + auto it = wrapper->values.find(identifier); + if (it != wrapper->values.end()) { + return ConvertCPPTypeToC(it->second.type()); + } + return DUCKDB_TYPE_INVALID; +} + +duckdb_state duckdb_clear_bindings(duckdb_prepared_statement prepared_statement) { + auto wrapper = reinterpret_cast(prepared_statement); + if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { + return DuckDBError; + } + wrapper->values.clear(); + return DuckDBSuccess; +} + +duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_value val) { + auto value = reinterpret_cast(val); + auto wrapper = reinterpret_cast(prepared_statement); + if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { + return DuckDBError; + } + if (param_idx <= 0 || param_idx > wrapper->statement->n_param) { + wrapper->statement->error = + duckdb::InvalidInputException("Can not bind to parameter number %d, statement only has %d parameter(s)", + param_idx, wrapper->statement->n_param); + return DuckDBError; + } + auto identifier = duckdb_parameter_name_internal(prepared_statement, param_idx); + wrapper->values[identifier] = *value; + return DuckDBSuccess; +} + +duckdb_state duckdb_bind_parameter_index(duckdb_prepared_statement prepared_statement, idx_t *param_idx_out, + const char *name_p) { + auto wrapper = (PreparedStatementWrapper *)prepared_statement; + if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { + return DuckDBError; + } + if (!name_p || !param_idx_out) { + return DuckDBError; + } + auto name = std::string(name_p); + for (auto &pair : wrapper->statement->named_param_map) { + if (duckdb::StringUtil::CIEquals(pair.first, name)) { + *param_idx_out = pair.second; + return DuckDBSuccess; + } + } + return DuckDBError; +} + +duckdb_state duckdb_bind_boolean(duckdb_prepared_statement prepared_statement, idx_t param_idx, bool val) { + auto value = Value::BOOLEAN(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_int8(duckdb_prepared_statement prepared_statement, idx_t param_idx, int8_t val) { + auto value = Value::TINYINT(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_int16(duckdb_prepared_statement prepared_statement, idx_t param_idx, int16_t val) { + auto value = Value::SMALLINT(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_int32(duckdb_prepared_statement prepared_statement, idx_t param_idx, int32_t val) { + auto value = Value::INTEGER(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_int64(duckdb_prepared_statement prepared_statement, idx_t param_idx, int64_t val) { + auto value = Value::BIGINT(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +static hugeint_t duckdb_internal_hugeint(duckdb_hugeint val) { + hugeint_t internal; + internal.lower = val.lower; + internal.upper = val.upper; + return internal; +} + +duckdb_state duckdb_bind_hugeint(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_hugeint val) { + auto value = Value::HUGEINT(duckdb_internal_hugeint(val)); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_uint8(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint8_t val) { + auto value = Value::UTINYINT(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_uint16(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint16_t val) { + auto value = Value::USMALLINT(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_uint32(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint32_t val) { + auto value = Value::UINTEGER(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_uint64(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint64_t val) { + auto value = Value::UBIGINT(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_float(duckdb_prepared_statement prepared_statement, idx_t param_idx, float val) { + auto value = Value::FLOAT(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_double(duckdb_prepared_statement prepared_statement, idx_t param_idx, double val) { + auto value = Value::DOUBLE(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_date(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_date val) { + auto value = Value::DATE(date_t(val.days)); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_time(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_time val) { + auto value = Value::TIME(dtime_t(val.micros)); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_timestamp(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_timestamp val) { + auto value = Value::TIMESTAMP(timestamp_t(val.micros)); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_interval(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_interval val) { + auto value = Value::INTERVAL(val.months, val.days, val.micros); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_varchar(duckdb_prepared_statement prepared_statement, idx_t param_idx, const char *val) { + try { + auto value = Value(val); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); + } catch (...) { + return DuckDBError; + } +} + +duckdb_state duckdb_bind_varchar_length(duckdb_prepared_statement prepared_statement, idx_t param_idx, const char *val, + idx_t length) { + try { + auto value = Value(std::string(val, length)); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); + } catch (...) { + return DuckDBError; + } +} + +duckdb_state duckdb_bind_decimal(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_decimal val) { + auto hugeint_val = duckdb_internal_hugeint(val.value); + if (val.width > duckdb::Decimal::MAX_WIDTH_INT64) { + auto value = Value::DECIMAL(hugeint_val, val.width, val.scale); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); + } + auto value = hugeint_val.lower; + auto duck_val = Value::DECIMAL((int64_t)value, val.width, val.scale); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&duck_val); +} + +duckdb_state duckdb_bind_blob(duckdb_prepared_statement prepared_statement, idx_t param_idx, const void *data, + idx_t length) { + auto value = Value::BLOB(duckdb::const_data_ptr_cast(data), length); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_bind_null(duckdb_prepared_statement prepared_statement, idx_t param_idx) { + auto value = Value(); + return duckdb_bind_value(prepared_statement, param_idx, (duckdb_value)&value); +} + +duckdb_state duckdb_execute_prepared(duckdb_prepared_statement prepared_statement, duckdb_result *out_result) { + auto wrapper = reinterpret_cast(prepared_statement); + if (!wrapper || !wrapper->statement || wrapper->statement->HasError()) { + return DuckDBError; + } + + auto result = wrapper->statement->Execute(wrapper->values, false); + return duckdb_translate_result(std::move(result), out_result); +} + +template +void duckdb_destroy(void **wrapper) { + if (!wrapper) { + return; + } + + auto casted = (T *)*wrapper; + if (casted) { + delete casted; + } + *wrapper = nullptr; +} + +void duckdb_destroy_extracted(duckdb_extracted_statements *extracted_statements) { + duckdb_destroy(reinterpret_cast(extracted_statements)); +} + +void duckdb_destroy_prepare(duckdb_prepared_statement *prepared_statement) { + duckdb_destroy(reinterpret_cast(prepared_statement)); +} diff --git a/src/duckdb/src/main/capi/replacement_scan-c.cpp b/src/duckdb/src/main/capi/replacement_scan-c.cpp new file mode 100644 index 00000000..2941b25a --- /dev/null +++ b/src/duckdb/src/main/capi/replacement_scan-c.cpp @@ -0,0 +1,94 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" + +namespace duckdb { + +struct CAPIReplacementScanData : public ReplacementScanData { + ~CAPIReplacementScanData() { + if (delete_callback) { + delete_callback(extra_data); + } + } + + duckdb_replacement_callback_t callback; + void *extra_data; + duckdb_delete_callback_t delete_callback; +}; + +struct CAPIReplacementScanInfo { + CAPIReplacementScanInfo(CAPIReplacementScanData *data) : data(data) { + } + + CAPIReplacementScanData *data; + string function_name; + vector parameters; + string error; +}; + +unique_ptr duckdb_capi_replacement_callback(ClientContext &context, const string &table_name, + ReplacementScanData *data) { + auto &scan_data = reinterpret_cast(*data); + + CAPIReplacementScanInfo info(&scan_data); + scan_data.callback((duckdb_replacement_scan_info)&info, table_name.c_str(), scan_data.extra_data); + if (!info.error.empty()) { + throw BinderException("Error in replacement scan: %s\n", info.error); + } + if (info.function_name.empty()) { + // no function provided: bail-out + return nullptr; + } + auto table_function = make_uniq(); + vector> children; + for (auto ¶m : info.parameters) { + children.push_back(make_uniq(std::move(param))); + } + table_function->function = make_uniq(info.function_name, std::move(children)); + return std::move(table_function); +} + +} // namespace duckdb + +void duckdb_add_replacement_scan(duckdb_database db, duckdb_replacement_callback_t replacement, void *extra_data, + duckdb_delete_callback_t delete_callback) { + if (!db || !replacement) { + return; + } + auto wrapper = reinterpret_cast(db); + auto scan_info = duckdb::make_uniq(); + scan_info->callback = replacement; + scan_info->extra_data = extra_data; + scan_info->delete_callback = delete_callback; + + auto &config = duckdb::DBConfig::GetConfig(*wrapper->database->instance); + config.replacement_scans.push_back( + duckdb::ReplacementScan(duckdb::duckdb_capi_replacement_callback, std::move(scan_info))); +} + +void duckdb_replacement_scan_set_function_name(duckdb_replacement_scan_info info_p, const char *function_name) { + if (!info_p || !function_name) { + return; + } + auto info = reinterpret_cast(info_p); + info->function_name = function_name; +} + +void duckdb_replacement_scan_add_parameter(duckdb_replacement_scan_info info_p, duckdb_value parameter) { + if (!info_p || !parameter) { + return; + } + auto info = reinterpret_cast(info_p); + auto val = reinterpret_cast(parameter); + info->parameters.push_back(*val); +} + +void duckdb_replacement_scan_set_error(duckdb_replacement_scan_info info_p, const char *error) { + if (!info_p || !error) { + return; + } + auto info = reinterpret_cast(info_p); + info->error = error; +} diff --git a/src/duckdb/src/main/capi/result-c.cpp b/src/duckdb/src/main/capi/result-c.cpp new file mode 100644 index 00000000..bab21bea --- /dev/null +++ b/src/duckdb/src/main/capi/result-c.cpp @@ -0,0 +1,517 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/allocator.hpp" + +namespace duckdb { + +struct CBaseConverter { + template + static void NullConvert(DST &target) { + } +}; +struct CStandardConverter : public CBaseConverter { + template + static DST Convert(SRC input) { + return input; + } +}; + +struct CStringConverter { + template + static DST Convert(SRC input) { + auto result = char_ptr_cast(duckdb_malloc(input.GetSize() + 1)); + assert(result); + memcpy((void *)result, input.GetData(), input.GetSize()); + auto write_arr = char_ptr_cast(result); + write_arr[input.GetSize()] = '\0'; + return result; + } + + template + static void NullConvert(DST &target) { + target = nullptr; + } +}; + +struct CBlobConverter { + template + static DST Convert(SRC input) { + duckdb_blob result; + result.data = char_ptr_cast(duckdb_malloc(input.GetSize())); + result.size = input.GetSize(); + assert(result.data); + memcpy(result.data, input.GetData(), input.GetSize()); + return result; + } + + template + static void NullConvert(DST &target) { + target.data = nullptr; + target.size = 0; + } +}; + +struct CTimestampMsConverter : public CBaseConverter { + template + static DST Convert(SRC input) { + return Timestamp::FromEpochMs(input.value); + } +}; + +struct CTimestampNsConverter : public CBaseConverter { + template + static DST Convert(SRC input) { + return Timestamp::FromEpochNanoSeconds(input.value); + } +}; + +struct CTimestampSecConverter : public CBaseConverter { + template + static DST Convert(SRC input) { + return Timestamp::FromEpochSeconds(input.value); + } +}; + +struct CHugeintConverter : public CBaseConverter { + template + static DST Convert(SRC input) { + duckdb_hugeint result; + result.lower = input.lower; + result.upper = input.upper; + return result; + } +}; + +struct CIntervalConverter : public CBaseConverter { + template + static DST Convert(SRC input) { + duckdb_interval result; + result.days = input.days; + result.months = input.months; + result.micros = input.micros; + return result; + } +}; + +template +struct CDecimalConverter : public CBaseConverter { + template + static DST Convert(SRC input) { + duckdb_hugeint result; + result.lower = input; + result.upper = 0; + return result; + } +}; + +template +void WriteData(duckdb_column *column, ColumnDataCollection &source, const vector &column_ids) { + idx_t row = 0; + auto target = (DST *)column->__deprecated_data; + for (auto &input : source.Chunks(column_ids)) { + auto source = FlatVector::GetData(input.data[0]); + auto &mask = FlatVector::Validity(input.data[0]); + + for (idx_t k = 0; k < input.size(); k++, row++) { + if (!mask.RowIsValid(k)) { + OP::template NullConvert(target[row]); + } else { + target[row] = OP::template Convert(source[k]); + } + } + } +} + +duckdb_state deprecated_duckdb_translate_column(MaterializedQueryResult &result, duckdb_column *column, idx_t col) { + D_ASSERT(!result.HasError()); + auto &collection = result.Collection(); + idx_t row_count = collection.Count(); + column->__deprecated_nullmask = (bool *)duckdb_malloc(sizeof(bool) * collection.Count()); + column->__deprecated_data = duckdb_malloc(GetCTypeSize(column->__deprecated_type) * row_count); + if (!column->__deprecated_nullmask || !column->__deprecated_data) { // LCOV_EXCL_START + // malloc failure + return DuckDBError; + } // LCOV_EXCL_STOP + + vector column_ids {col}; + // first convert the nullmask + { + idx_t row = 0; + for (auto &input : collection.Chunks(column_ids)) { + for (idx_t k = 0; k < input.size(); k++) { + column->__deprecated_nullmask[row++] = FlatVector::IsNull(input.data[0], k); + } + } + } + // then write the data + switch (result.types[col].id()) { + case LogicalTypeId::BOOLEAN: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::TINYINT: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::SMALLINT: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::INTEGER: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::BIGINT: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::UTINYINT: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::USMALLINT: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::UINTEGER: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::UBIGINT: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::FLOAT: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::DOUBLE: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::DATE: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::TIME: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::TIME_TZ: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + WriteData(column, collection, column_ids); + break; + case LogicalTypeId::VARCHAR: { + WriteData(column, collection, column_ids); + break; + } + case LogicalTypeId::BLOB: { + WriteData(column, collection, column_ids); + break; + } + case LogicalTypeId::TIMESTAMP_NS: { + WriteData(column, collection, column_ids); + break; + } + case LogicalTypeId::TIMESTAMP_MS: { + WriteData(column, collection, column_ids); + break; + } + case LogicalTypeId::TIMESTAMP_SEC: { + WriteData(column, collection, column_ids); + break; + } + case LogicalTypeId::HUGEINT: { + WriteData(column, collection, column_ids); + break; + } + case LogicalTypeId::INTERVAL: { + WriteData(column, collection, column_ids); + break; + } + case LogicalTypeId::DECIMAL: { + // get data + switch (result.types[col].InternalType()) { + case PhysicalType::INT16: { + WriteData>(column, collection, column_ids); + break; + } + case PhysicalType::INT32: { + WriteData>(column, collection, column_ids); + break; + } + case PhysicalType::INT64: { + WriteData>(column, collection, column_ids); + break; + } + case PhysicalType::INT128: { + WriteData(column, collection, column_ids); + break; + } + default: + throw std::runtime_error("Unsupported physical type for Decimal" + + TypeIdToString(result.types[col].InternalType())); + } + break; + } + default: // LCOV_EXCL_START + return DuckDBError; + } // LCOV_EXCL_STOP + return DuckDBSuccess; +} + +duckdb_state duckdb_translate_result(unique_ptr result_p, duckdb_result *out) { + auto &result = *result_p; + D_ASSERT(result_p); + if (!out) { + // no result to write to, only return the status + return !result.HasError() ? DuckDBSuccess : DuckDBError; + } + + memset(out, 0, sizeof(duckdb_result)); + + // initialize the result_data object + auto result_data = new DuckDBResultData(); + result_data->result = std::move(result_p); + result_data->result_set_type = CAPIResultSetType::CAPI_RESULT_TYPE_NONE; + out->internal_data = result_data; + + if (result.HasError()) { + // write the error message + out->__deprecated_error_message = (char *)result.GetError().c_str(); // NOLINT + return DuckDBError; + } + // copy the data + // first write the meta data + out->__deprecated_column_count = result.ColumnCount(); + out->__deprecated_rows_changed = 0; + return DuckDBSuccess; +} + +bool deprecated_materialize_result(duckdb_result *result) { + if (!result) { + return false; + } + auto result_data = reinterpret_cast(result->internal_data); + if (result_data->result->HasError()) { + return false; + } + if (result_data->result_set_type == CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { + // already materialized into deprecated result format + return true; + } + if (result_data->result_set_type == CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED) { + // already used as a new result set + return false; + } + if (result_data->result_set_type == CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING) { + // already used as a streaming result + return false; + } + // materialize as deprecated result set + result_data->result_set_type = CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED; + auto column_count = result_data->result->ColumnCount(); + result->__deprecated_columns = (duckdb_column *)duckdb_malloc(sizeof(duckdb_column) * column_count); + if (!result->__deprecated_columns) { // LCOV_EXCL_START + // malloc failure + return DuckDBError; + } // LCOV_EXCL_STOP + if (result_data->result->type == QueryResultType::STREAM_RESULT) { + // if we are dealing with a stream result, convert it to a materialized result first + auto &stream_result = (StreamQueryResult &)*result_data->result; + result_data->result = stream_result.Materialize(); + } + D_ASSERT(result_data->result->type == QueryResultType::MATERIALIZED_RESULT); + auto &materialized = reinterpret_cast(*result_data->result); + + // convert the result to a materialized result + // zero initialize the columns (so we can cleanly delete it in case a malloc fails) + memset(result->__deprecated_columns, 0, sizeof(duckdb_column) * column_count); + for (idx_t i = 0; i < column_count; i++) { + result->__deprecated_columns[i].__deprecated_type = ConvertCPPTypeToC(result_data->result->types[i]); + result->__deprecated_columns[i].__deprecated_name = (char *)result_data->result->names[i].c_str(); // NOLINT + } + result->__deprecated_row_count = materialized.RowCount(); + if (result->__deprecated_row_count > 0 && + materialized.properties.return_type == StatementReturnType::CHANGED_ROWS) { + // update total changes + auto row_changes = materialized.GetValue(0, 0); + if (!row_changes.IsNull() && row_changes.DefaultTryCastAs(LogicalType::BIGINT)) { + result->__deprecated_rows_changed = row_changes.GetValue(); + } + } + // now write the data + for (idx_t col = 0; col < column_count; col++) { + auto state = deprecated_duckdb_translate_column(materialized, &result->__deprecated_columns[col], col); + if (state != DuckDBSuccess) { + return false; + } + } + return true; +} + +} // namespace duckdb + +static void DuckdbDestroyColumn(duckdb_column column, idx_t count) { + if (column.__deprecated_data) { + if (column.__deprecated_type == DUCKDB_TYPE_VARCHAR) { + // varchar, delete individual strings + auto data = reinterpret_cast(column.__deprecated_data); + for (idx_t i = 0; i < count; i++) { + if (data[i]) { + duckdb_free(data[i]); + } + } + } else if (column.__deprecated_type == DUCKDB_TYPE_BLOB) { + // blob, delete individual blobs + auto data = reinterpret_cast(column.__deprecated_data); + for (idx_t i = 0; i < count; i++) { + if (data[i].data) { + duckdb_free((void *)data[i].data); + } + } + } + duckdb_free(column.__deprecated_data); + } + if (column.__deprecated_nullmask) { + duckdb_free(column.__deprecated_nullmask); + } +} + +void duckdb_destroy_result(duckdb_result *result) { + if (result->__deprecated_columns) { + for (idx_t i = 0; i < result->__deprecated_column_count; i++) { + DuckdbDestroyColumn(result->__deprecated_columns[i], result->__deprecated_row_count); + } + duckdb_free(result->__deprecated_columns); + } + if (result->internal_data) { + auto result_data = reinterpret_cast(result->internal_data); + delete result_data; + } + memset(result, 0, sizeof(duckdb_result)); +} + +const char *duckdb_column_name(duckdb_result *result, idx_t col) { + if (!result || col >= duckdb_column_count(result)) { + return nullptr; + } + auto &result_data = *(reinterpret_cast(result->internal_data)); + return result_data.result->names[col].c_str(); +} + +duckdb_type duckdb_column_type(duckdb_result *result, idx_t col) { + if (!result || col >= duckdb_column_count(result)) { + return DUCKDB_TYPE_INVALID; + } + auto &result_data = *(reinterpret_cast(result->internal_data)); + return duckdb::ConvertCPPTypeToC(result_data.result->types[col]); +} + +duckdb_logical_type duckdb_column_logical_type(duckdb_result *result, idx_t col) { + if (!result || col >= duckdb_column_count(result)) { + return nullptr; + } + auto &result_data = *(reinterpret_cast(result->internal_data)); + return reinterpret_cast(new duckdb::LogicalType(result_data.result->types[col])); +} + +idx_t duckdb_column_count(duckdb_result *result) { + if (!result) { + return 0; + } + auto &result_data = *(reinterpret_cast(result->internal_data)); + return result_data.result->ColumnCount(); +} + +idx_t duckdb_row_count(duckdb_result *result) { + if (!result) { + return 0; + } + auto &result_data = *(reinterpret_cast(result->internal_data)); + if (result_data.result->type == duckdb::QueryResultType::STREAM_RESULT) { + // We can't know the row count beforehand + return 0; + } + auto &materialized = reinterpret_cast(*result_data.result); + return materialized.RowCount(); +} + +idx_t duckdb_rows_changed(duckdb_result *result) { + if (!result) { + return 0; + } + if (!duckdb::deprecated_materialize_result(result)) { + return 0; + } + return result->__deprecated_rows_changed; +} + +void *duckdb_column_data(duckdb_result *result, idx_t col) { + if (!result || col >= result->__deprecated_column_count) { + return nullptr; + } + if (!duckdb::deprecated_materialize_result(result)) { + return nullptr; + } + return result->__deprecated_columns[col].__deprecated_data; +} + +bool *duckdb_nullmask_data(duckdb_result *result, idx_t col) { + if (!result || col >= result->__deprecated_column_count) { + return nullptr; + } + if (!duckdb::deprecated_materialize_result(result)) { + return nullptr; + } + return result->__deprecated_columns[col].__deprecated_nullmask; +} + +const char *duckdb_result_error(duckdb_result *result) { + if (!result) { + return nullptr; + } + auto &result_data = *(reinterpret_cast(result->internal_data)); + return !result_data.result->HasError() ? nullptr : result_data.result->GetError().c_str(); +} + +idx_t duckdb_result_chunk_count(duckdb_result result) { + if (!result.internal_data) { + return 0; + } + auto &result_data = *(reinterpret_cast(result.internal_data)); + if (result_data.result_set_type == duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { + return 0; + } + if (result_data.result->type != duckdb::QueryResultType::MATERIALIZED_RESULT) { + // Can't know beforehand how many chunks are returned. + return 0; + } + auto &materialized = reinterpret_cast(*result_data.result); + return materialized.Collection().ChunkCount(); +} + +duckdb_data_chunk duckdb_result_get_chunk(duckdb_result result, idx_t chunk_idx) { + if (!result.internal_data) { + return nullptr; + } + auto &result_data = *(reinterpret_cast(result.internal_data)); + if (result_data.result_set_type == duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { + return nullptr; + } + if (result_data.result->type != duckdb::QueryResultType::MATERIALIZED_RESULT) { + // This API is only supported for materialized query results + return nullptr; + } + result_data.result_set_type = duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED; + auto &materialized = reinterpret_cast(*result_data.result); + auto &collection = materialized.Collection(); + if (chunk_idx >= collection.ChunkCount()) { + return nullptr; + } + auto chunk = duckdb::make_uniq(); + chunk->Initialize(duckdb::Allocator::DefaultAllocator(), collection.Types()); + collection.FetchChunk(chunk_idx, *chunk); + return reinterpret_cast(chunk.release()); +} + +bool duckdb_result_is_streaming(duckdb_result result) { + if (!result.internal_data) { + return false; + } + if (duckdb_result_error(&result) != nullptr) { + return false; + } + auto &result_data = *(reinterpret_cast(result.internal_data)); + return result_data.result->type == duckdb::QueryResultType::STREAM_RESULT; +} diff --git a/src/duckdb/src/main/capi/stream-c.cpp b/src/duckdb/src/main/capi/stream-c.cpp new file mode 100644 index 00000000..8fcc42fb --- /dev/null +++ b/src/duckdb/src/main/capi/stream-c.cpp @@ -0,0 +1,25 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/allocator.hpp" + +duckdb_data_chunk duckdb_stream_fetch_chunk(duckdb_result result) { + if (!result.internal_data) { + return nullptr; + } + auto &result_data = *((duckdb::DuckDBResultData *)result.internal_data); + if (result_data.result_set_type == duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED) { + return nullptr; + } + if (result_data.result->type != duckdb::QueryResultType::STREAM_RESULT) { + // We can only fetch from a StreamQueryResult + return nullptr; + } + result_data.result_set_type = duckdb::CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING; + auto &streaming = (duckdb::StreamQueryResult &)*result_data.result; + if (!streaming.IsOpen()) { + return nullptr; + } + // FetchRaw ? Do we care about flattening them? + auto chunk = streaming.Fetch(); + return reinterpret_cast(chunk.release()); +} diff --git a/src/duckdb/src/main/capi/table_function-c.cpp b/src/duckdb/src/main/capi/table_function-c.cpp new file mode 100644 index 00000000..fe1556bf --- /dev/null +++ b/src/duckdb/src/main/capi/table_function-c.cpp @@ -0,0 +1,482 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/storage/statistics/node_statistics.hpp" + +namespace duckdb { + +struct CTableFunctionInfo : public TableFunctionInfo { + ~CTableFunctionInfo() { + if (extra_info && delete_callback) { + delete_callback(extra_info); + } + extra_info = nullptr; + delete_callback = nullptr; + } + + duckdb_table_function_bind_t bind = nullptr; + duckdb_table_function_init_t init = nullptr; + duckdb_table_function_init_t local_init = nullptr; + duckdb_table_function_t function = nullptr; + void *extra_info = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +struct CTableBindData : public TableFunctionData { + CTableBindData(CTableFunctionInfo &info) : info(info) { + } + ~CTableBindData() { + if (bind_data && delete_callback) { + delete_callback(bind_data); + } + bind_data = nullptr; + delete_callback = nullptr; + } + + CTableFunctionInfo &info; + void *bind_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; + unique_ptr stats; +}; + +struct CTableInternalBindInfo { + CTableInternalBindInfo(ClientContext &context, TableFunctionBindInput &input, vector &return_types, + vector &names, CTableBindData &bind_data, CTableFunctionInfo &function_info) + : context(context), input(input), return_types(return_types), names(names), bind_data(bind_data), + function_info(function_info), success(true) { + } + + ClientContext &context; + TableFunctionBindInput &input; + vector &return_types; + vector &names; + CTableBindData &bind_data; + CTableFunctionInfo &function_info; + bool success; + string error; +}; + +struct CTableInitData { + ~CTableInitData() { + if (init_data && delete_callback) { + delete_callback(init_data); + } + init_data = nullptr; + delete_callback = nullptr; + } + + void *init_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; + idx_t max_threads = 1; +}; + +struct CTableGlobalInitData : public GlobalTableFunctionState { + CTableInitData init_data; + + idx_t MaxThreads() const override { + return init_data.max_threads; + } +}; + +struct CTableLocalInitData : public LocalTableFunctionState { + CTableInitData init_data; +}; + +struct CTableInternalInitInfo { + CTableInternalInitInfo(const CTableBindData &bind_data, CTableInitData &init_data, + const vector &column_ids, optional_ptr filters) + : bind_data(bind_data), init_data(init_data), column_ids(column_ids), filters(filters), success(true) { + } + + const CTableBindData &bind_data; + CTableInitData &init_data; + const vector &column_ids; + optional_ptr filters; + bool success; + string error; +}; + +struct CTableInternalFunctionInfo { + CTableInternalFunctionInfo(const CTableBindData &bind_data, CTableInitData &init_data, CTableInitData &local_data) + : bind_data(bind_data), init_data(init_data), local_data(local_data), success(true) { + } + + const CTableBindData &bind_data; + CTableInitData &init_data; + CTableInitData &local_data; + bool success; + string error; +}; + +unique_ptr CTableFunctionBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto &info = input.info->Cast(); + D_ASSERT(info.bind && info.function && info.init); + auto result = make_uniq(info); + CTableInternalBindInfo bind_info(context, input, return_types, names, *result, info); + info.bind(&bind_info); + if (!bind_info.success) { + throw Exception(bind_info.error); + } + + return std::move(result); +} + +unique_ptr CTableFunctionInit(ClientContext &context, TableFunctionInitInput &data_p) { + auto &bind_data = data_p.bind_data->Cast(); + auto result = make_uniq(); + + CTableInternalInitInfo init_info(bind_data, result->init_data, data_p.column_ids, data_p.filters); + bind_data.info.init(&init_info); + if (!init_info.success) { + throw Exception(init_info.error); + } + return std::move(result); +} + +unique_ptr CTableFunctionLocalInit(ExecutionContext &context, TableFunctionInitInput &data_p, + GlobalTableFunctionState *gstate) { + auto &bind_data = data_p.bind_data->Cast(); + auto result = make_uniq(); + if (!bind_data.info.local_init) { + return std::move(result); + } + + CTableInternalInitInfo init_info(bind_data, result->init_data, data_p.column_ids, data_p.filters); + bind_data.info.local_init(&init_info); + if (!init_info.success) { + throw Exception(init_info.error); + } + return std::move(result); +} + +unique_ptr CTableFunctionCardinality(ClientContext &context, const FunctionData *bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + if (!bind_data.stats) { + return nullptr; + } + return make_uniq(*bind_data.stats); +} + +void CTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &global_data = (CTableGlobalInitData &)*data_p.global_state; + auto &local_data = (CTableLocalInitData &)*data_p.local_state; + CTableInternalFunctionInfo function_info(bind_data, global_data.init_data, local_data.init_data); + bind_data.info.function(&function_info, reinterpret_cast(&output)); + if (!function_info.success) { + throw Exception(function_info.error); + } +} + +} // namespace duckdb + +//===--------------------------------------------------------------------===// +// Table Function +//===--------------------------------------------------------------------===// +duckdb_table_function duckdb_create_table_function() { + auto function = new duckdb::TableFunction("", {}, duckdb::CTableFunction, duckdb::CTableFunctionBind, + duckdb::CTableFunctionInit, duckdb::CTableFunctionLocalInit); + function->function_info = duckdb::make_shared(); + function->cardinality = duckdb::CTableFunctionCardinality; + return function; +} + +void duckdb_destroy_table_function(duckdb_table_function *function) { + if (function && *function) { + auto tf = (duckdb::TableFunction *)*function; + delete tf; + *function = nullptr; + } +} + +void duckdb_table_function_set_name(duckdb_table_function function, const char *name) { + if (!function || !name) { + return; + } + auto tf = (duckdb::TableFunction *)function; + tf->name = name; +} + +void duckdb_table_function_add_parameter(duckdb_table_function function, duckdb_logical_type type) { + if (!function || !type) { + return; + } + auto tf = (duckdb::TableFunction *)function; + auto logical_type = (duckdb::LogicalType *)type; + tf->arguments.push_back(*logical_type); +} + +void duckdb_table_function_add_named_parameter(duckdb_table_function function, const char *name, + duckdb_logical_type type) { + if (!function || !type) { + return; + } + auto tf = (duckdb::TableFunction *)function; + auto logical_type = (duckdb::LogicalType *)type; + tf->named_parameters.insert({name, *logical_type}); +} + +void duckdb_table_function_set_extra_info(duckdb_table_function function, void *extra_info, + duckdb_delete_callback_t destroy) { + if (!function) { + return; + } + auto tf = (duckdb::TableFunction *)function; + auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); + info->extra_info = extra_info; + info->delete_callback = destroy; +} + +void duckdb_table_function_set_bind(duckdb_table_function function, duckdb_table_function_bind_t bind) { + if (!function || !bind) { + return; + } + auto tf = (duckdb::TableFunction *)function; + auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); + info->bind = bind; +} + +void duckdb_table_function_set_init(duckdb_table_function function, duckdb_table_function_init_t init) { + if (!function || !init) { + return; + } + auto tf = (duckdb::TableFunction *)function; + auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); + info->init = init; +} + +void duckdb_table_function_set_local_init(duckdb_table_function function, duckdb_table_function_init_t init) { + if (!function || !init) { + return; + } + auto tf = (duckdb::TableFunction *)function; + auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); + info->local_init = init; +} + +void duckdb_table_function_set_function(duckdb_table_function table_function, duckdb_table_function_t function) { + if (!table_function || !function) { + return; + } + auto tf = (duckdb::TableFunction *)table_function; + auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); + info->function = function; +} + +void duckdb_table_function_supports_projection_pushdown(duckdb_table_function table_function, bool pushdown) { + if (!table_function) { + return; + } + auto tf = (duckdb::TableFunction *)table_function; + tf->projection_pushdown = pushdown; +} + +duckdb_state duckdb_register_table_function(duckdb_connection connection, duckdb_table_function function) { + if (!connection || !function) { + return DuckDBError; + } + auto con = (duckdb::Connection *)connection; + auto tf = (duckdb::TableFunction *)function; + auto info = (duckdb::CTableFunctionInfo *)tf->function_info.get(); + if (tf->name.empty() || !info->bind || !info->init || !info->function) { + return DuckDBError; + } + con->context->RunFunctionInTransaction([&]() { + auto &catalog = duckdb::Catalog::GetSystemCatalog(*con->context); + duckdb::CreateTableFunctionInfo tf_info(*tf); + + // create the function in the catalog + catalog.CreateTableFunction(*con->context, tf_info); + }); + return DuckDBSuccess; +} + +//===--------------------------------------------------------------------===// +// Bind Interface +//===--------------------------------------------------------------------===// +void *duckdb_bind_get_extra_info(duckdb_bind_info info) { + if (!info) { + return nullptr; + } + auto bind_info = (duckdb::CTableInternalBindInfo *)info; + return bind_info->function_info.extra_info; +} + +void duckdb_bind_add_result_column(duckdb_bind_info info, const char *name, duckdb_logical_type type) { + if (!info || !name || !type) { + return; + } + auto bind_info = (duckdb::CTableInternalBindInfo *)info; + bind_info->names.push_back(name); + bind_info->return_types.push_back(*(reinterpret_cast(type))); +} + +idx_t duckdb_bind_get_parameter_count(duckdb_bind_info info) { + if (!info) { + return 0; + } + auto bind_info = (duckdb::CTableInternalBindInfo *)info; + return bind_info->input.inputs.size(); +} + +duckdb_value duckdb_bind_get_parameter(duckdb_bind_info info, idx_t index) { + if (!info || index >= duckdb_bind_get_parameter_count(info)) { + return nullptr; + } + auto bind_info = (duckdb::CTableInternalBindInfo *)info; + return reinterpret_cast(new duckdb::Value(bind_info->input.inputs[index])); +} + +duckdb_value duckdb_bind_get_named_parameter(duckdb_bind_info info, const char *name) { + if (!info || !name) { + return nullptr; + } + auto bind_info = (duckdb::CTableInternalBindInfo *)info; + auto t = bind_info->input.named_parameters.find(name); + if (t == bind_info->input.named_parameters.end()) { + return nullptr; + } else { + return reinterpret_cast(new duckdb::Value(t->second)); + } +} + +void duckdb_bind_set_bind_data(duckdb_bind_info info, void *bind_data, duckdb_delete_callback_t destroy) { + if (!info) { + return; + } + auto bind_info = (duckdb::CTableInternalBindInfo *)info; + bind_info->bind_data.bind_data = bind_data; + bind_info->bind_data.delete_callback = destroy; +} + +void duckdb_bind_set_cardinality(duckdb_bind_info info, idx_t cardinality, bool is_exact) { + if (!info) { + return; + } + auto bind_info = (duckdb::CTableInternalBindInfo *)info; + if (is_exact) { + bind_info->bind_data.stats = duckdb::make_uniq(cardinality); + } else { + bind_info->bind_data.stats = duckdb::make_uniq(cardinality, cardinality); + } +} + +void duckdb_bind_set_error(duckdb_bind_info info, const char *error) { + if (!info || !error) { + return; + } + auto function_info = (duckdb::CTableInternalBindInfo *)info; + function_info->error = error; + function_info->success = false; +} + +//===--------------------------------------------------------------------===// +// Init Interface +//===--------------------------------------------------------------------===// +void *duckdb_init_get_extra_info(duckdb_init_info info) { + if (!info) { + return nullptr; + } + auto init_info = (duckdb::CTableInternalInitInfo *)info; + return init_info->bind_data.info.extra_info; +} + +void *duckdb_init_get_bind_data(duckdb_init_info info) { + if (!info) { + return nullptr; + } + auto init_info = (duckdb::CTableInternalInitInfo *)info; + return init_info->bind_data.bind_data; +} + +void duckdb_init_set_init_data(duckdb_init_info info, void *init_data, duckdb_delete_callback_t destroy) { + if (!info) { + return; + } + auto init_info = (duckdb::CTableInternalInitInfo *)info; + init_info->init_data.init_data = init_data; + init_info->init_data.delete_callback = destroy; +} + +void duckdb_init_set_error(duckdb_init_info info, const char *error) { + if (!info || !error) { + return; + } + auto function_info = (duckdb::CTableInternalInitInfo *)info; + function_info->error = error; + function_info->success = false; +} + +idx_t duckdb_init_get_column_count(duckdb_init_info info) { + if (!info) { + return 0; + } + auto function_info = (duckdb::CTableInternalInitInfo *)info; + return function_info->column_ids.size(); +} + +idx_t duckdb_init_get_column_index(duckdb_init_info info, idx_t column_index) { + if (!info) { + return 0; + } + auto function_info = (duckdb::CTableInternalInitInfo *)info; + if (column_index >= function_info->column_ids.size()) { + return 0; + } + return function_info->column_ids[column_index]; +} + +void duckdb_init_set_max_threads(duckdb_init_info info, idx_t max_threads) { + if (!info) { + return; + } + auto function_info = (duckdb::CTableInternalInitInfo *)info; + function_info->init_data.max_threads = max_threads; +} + +//===--------------------------------------------------------------------===// +// Function Interface +//===--------------------------------------------------------------------===// +void *duckdb_function_get_extra_info(duckdb_function_info info) { + if (!info) { + return nullptr; + } + auto function_info = (duckdb::CTableInternalFunctionInfo *)info; + return function_info->bind_data.info.extra_info; +} + +void *duckdb_function_get_bind_data(duckdb_function_info info) { + if (!info) { + return nullptr; + } + auto function_info = (duckdb::CTableInternalFunctionInfo *)info; + return function_info->bind_data.bind_data; +} + +void *duckdb_function_get_init_data(duckdb_function_info info) { + if (!info) { + return nullptr; + } + auto function_info = (duckdb::CTableInternalFunctionInfo *)info; + return function_info->init_data.init_data; +} + +void *duckdb_function_get_local_init_data(duckdb_function_info info) { + if (!info) { + return nullptr; + } + auto function_info = (duckdb::CTableInternalFunctionInfo *)info; + return function_info->local_data.init_data; +} + +void duckdb_function_set_error(duckdb_function_info info, const char *error) { + if (!info || !error) { + return; + } + auto function_info = (duckdb::CTableInternalFunctionInfo *)info; + function_info->error = error; + function_info->success = false; +} diff --git a/src/duckdb/src/main/capi/threading-c.cpp b/src/duckdb/src/main/capi/threading-c.cpp new file mode 100644 index 00000000..b0a002f1 --- /dev/null +++ b/src/duckdb/src/main/capi/threading-c.cpp @@ -0,0 +1,88 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/parallel/task_scheduler.hpp" + +using duckdb::DatabaseData; + +struct CAPITaskState { + CAPITaskState(duckdb::DatabaseInstance &db) + : db(db), marker(duckdb::make_uniq>(true)), execute_count(0) { + } + + duckdb::DatabaseInstance &db; + duckdb::unique_ptr> marker; + duckdb::atomic execute_count; +}; + +void duckdb_execute_tasks(duckdb_database database, idx_t max_tasks) { + if (!database) { + return; + } + auto wrapper = (DatabaseData *)database; + auto &scheduler = duckdb::TaskScheduler::GetScheduler(*wrapper->database->instance); + scheduler.ExecuteTasks(max_tasks); +} + +duckdb_task_state duckdb_create_task_state(duckdb_database database) { + if (!database) { + return nullptr; + } + auto wrapper = (DatabaseData *)database; + auto state = new CAPITaskState(*wrapper->database->instance); + return state; +} + +void duckdb_execute_tasks_state(duckdb_task_state state_p) { + if (!state_p) { + return; + } + auto state = (CAPITaskState *)state_p; + auto &scheduler = duckdb::TaskScheduler::GetScheduler(state->db); + state->execute_count++; + scheduler.ExecuteForever(state->marker.get()); +} + +idx_t duckdb_execute_n_tasks_state(duckdb_task_state state_p, idx_t max_tasks) { + if (!state_p) { + return 0; + } + auto state = (CAPITaskState *)state_p; + auto &scheduler = duckdb::TaskScheduler::GetScheduler(state->db); + return scheduler.ExecuteTasks(state->marker.get(), max_tasks); +} + +void duckdb_finish_execution(duckdb_task_state state_p) { + if (!state_p) { + return; + } + auto state = (CAPITaskState *)state_p; + *state->marker = false; + if (state->execute_count > 0) { + // signal to the threads to wake up + auto &scheduler = duckdb::TaskScheduler::GetScheduler(state->db); + scheduler.Signal(state->execute_count); + } +} + +bool duckdb_task_state_is_finished(duckdb_task_state state_p) { + if (!state_p) { + return false; + } + auto state = (CAPITaskState *)state_p; + return !(*state->marker); +} + +void duckdb_destroy_task_state(duckdb_task_state state_p) { + if (!state_p) { + return; + } + auto state = (CAPITaskState *)state_p; + delete state; +} + +bool duckdb_execution_is_finished(duckdb_connection con) { + if (!con) { + return false; + } + duckdb::Connection *conn = (duckdb::Connection *)con; + return conn->context->ExecutionIsFinished(); +} diff --git a/src/duckdb/src/main/capi/value-c.cpp b/src/duckdb/src/main/capi/value-c.cpp new file mode 100644 index 00000000..a4a6040b --- /dev/null +++ b/src/duckdb/src/main/capi/value-c.cpp @@ -0,0 +1,168 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types.hpp" + +#include "duckdb/main/capi/cast/generic.hpp" + +#include + +using duckdb::date_t; +using duckdb::dtime_t; +using duckdb::FetchDefaultValue; +using duckdb::GetInternalCValue; +using duckdb::hugeint_t; +using duckdb::interval_t; +using duckdb::StringCast; +using duckdb::timestamp_t; +using duckdb::ToCStringCastWrapper; +using duckdb::UnsafeFetch; + +bool duckdb_value_boolean(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +int8_t duckdb_value_int8(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +int16_t duckdb_value_int16(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +int32_t duckdb_value_int32(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +int64_t duckdb_value_int64(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +static bool ResultIsDecimal(duckdb_result *result, idx_t col) { + if (!result) { + return false; + } + if (!result->internal_data) { + return false; + } + auto result_data = (duckdb::DuckDBResultData *)result->internal_data; + auto &query_result = result_data->result; + auto &source_type = query_result->types[col]; + return source_type.id() == duckdb::LogicalTypeId::DECIMAL; +} + +duckdb_decimal duckdb_value_decimal(duckdb_result *result, idx_t col, idx_t row) { + if (!CanFetchValue(result, col, row) || !ResultIsDecimal(result, col)) { + return FetchDefaultValue::Operation(); + } + + return GetInternalCValue(result, col, row); +} + +duckdb_hugeint duckdb_value_hugeint(duckdb_result *result, idx_t col, idx_t row) { + duckdb_hugeint result_value; + auto internal_value = GetInternalCValue(result, col, row); + result_value.lower = internal_value.lower; + result_value.upper = internal_value.upper; + return result_value; +} + +uint8_t duckdb_value_uint8(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +uint16_t duckdb_value_uint16(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +uint32_t duckdb_value_uint32(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +uint64_t duckdb_value_uint64(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +float duckdb_value_float(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +double duckdb_value_double(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue(result, col, row); +} + +duckdb_date duckdb_value_date(duckdb_result *result, idx_t col, idx_t row) { + duckdb_date result_value; + result_value.days = GetInternalCValue(result, col, row).days; + return result_value; +} + +duckdb_time duckdb_value_time(duckdb_result *result, idx_t col, idx_t row) { + duckdb_time result_value; + result_value.micros = GetInternalCValue(result, col, row).micros; + return result_value; +} + +duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_t col, idx_t row) { + duckdb_timestamp result_value; + result_value.micros = GetInternalCValue(result, col, row).value; + return result_value; +} + +duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row) { + duckdb_interval result_value; + auto ival = GetInternalCValue(result, col, row); + result_value.months = ival.months; + result_value.days = ival.days; + result_value.micros = ival.micros; + return result_value; +} + +char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t row) { + return duckdb_value_string(result, col, row).data; +} + +duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row) { + return GetInternalCValue>(result, col, row); +} + +char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row) { + return duckdb_value_string_internal(result, col, row).data; +} + +duckdb_string duckdb_value_string_internal(duckdb_result *result, idx_t col, idx_t row) { + if (!CanFetchValue(result, col, row)) { + return FetchDefaultValue::Operation(); + } + if (duckdb_column_type(result, col) != DUCKDB_TYPE_VARCHAR) { + return FetchDefaultValue::Operation(); + } + // FIXME: this obviously does not work when there are null bytes in the string + // we need to remove the deprecated C result materialization to get that to work correctly + // since the deprecated C result materialization stores strings as null-terminated + duckdb_string res; + res.data = UnsafeFetch(result, col, row); + res.size = strlen(res.data); + return res; +} + +duckdb_blob duckdb_value_blob(duckdb_result *result, idx_t col, idx_t row) { + if (CanFetchValue(result, col, row) && result->__deprecated_columns[col].__deprecated_type == DUCKDB_TYPE_BLOB) { + auto internal_result = UnsafeFetch(result, col, row); + + duckdb_blob result_blob; + result_blob.data = malloc(internal_result.size); + result_blob.size = internal_result.size; + memcpy(result_blob.data, internal_result.data, internal_result.size); + return result_blob; + } + return FetchDefaultValue::Operation(); +} + +bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t row) { + if (!CanUseDeprecatedFetch(result, col, row)) { + return false; + } + return result->__deprecated_columns[col].__deprecated_nullmask[row]; +} diff --git a/src/duckdb/src/main/chunk_scan_state.cpp b/src/duckdb/src/main/chunk_scan_state.cpp new file mode 100644 index 00000000..aab054c9 --- /dev/null +++ b/src/duckdb/src/main/chunk_scan_state.cpp @@ -0,0 +1,48 @@ +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/main/chunk_scan_state.hpp" + +namespace duckdb { + +ChunkScanState::ChunkScanState() { +} + +ChunkScanState::~ChunkScanState() { +} + +idx_t ChunkScanState::CurrentOffset() const { + return offset; +} + +void ChunkScanState::IncreaseOffset(idx_t increment, bool unsafe) { + D_ASSERT(unsafe || increment <= RemainingInChunk()); + offset += increment; +} + +bool ChunkScanState::ChunkIsEmpty() const { + return !current_chunk || current_chunk->size() == 0; +} + +bool ChunkScanState::Finished() const { + return finished; +} + +bool ChunkScanState::ScanStarted() const { + return !ChunkIsEmpty(); +} + +DataChunk &ChunkScanState::CurrentChunk() { + // Scan must already be started + D_ASSERT(current_chunk); + return *current_chunk; +} + +idx_t ChunkScanState::RemainingInChunk() const { + if (ChunkIsEmpty()) { + return 0; + } + D_ASSERT(current_chunk); + D_ASSERT(offset <= current_chunk->size()); + return current_chunk->size() - offset; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/chunk_scan_state/query_result.cpp b/src/duckdb/src/main/chunk_scan_state/query_result.cpp new file mode 100644 index 00000000..84e45e36 --- /dev/null +++ b/src/duckdb/src/main/chunk_scan_state/query_result.cpp @@ -0,0 +1,53 @@ +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/stream_query_result.hpp" +#include "duckdb/main/chunk_scan_state/query_result.hpp" + +namespace duckdb { + +QueryResultChunkScanState::QueryResultChunkScanState(QueryResult &result) : ChunkScanState(), result(result) { +} + +QueryResultChunkScanState::~QueryResultChunkScanState() { +} + +bool QueryResultChunkScanState::InternalLoad(PreservedError &error) { + D_ASSERT(!finished); + if (result.type == QueryResultType::STREAM_RESULT) { + auto &stream_result = result.Cast(); + if (!stream_result.IsOpen()) { + return true; + } + } + return result.TryFetch(current_chunk, error); +} + +bool QueryResultChunkScanState::HasError() const { + return result.HasError(); +} + +PreservedError &QueryResultChunkScanState::GetError() { + D_ASSERT(result.HasError()); + return result.GetErrorObject(); +} + +const vector &QueryResultChunkScanState::Types() const { + return result.types; +} + +const vector &QueryResultChunkScanState::Names() const { + return result.names; +} + +bool QueryResultChunkScanState::LoadNextChunk(PreservedError &error) { + if (finished) { + return !finished; + } + auto load_result = InternalLoad(error); + if (!load_result) { + finished = true; + } + offset = 0; + return !finished; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp new file mode 100644 index 00000000..157f0be7 --- /dev/null +++ b/src/duckdb/src/main/client_context.cpp @@ -0,0 +1,1186 @@ +#include "duckdb/main/client_context.hpp" + +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/common/progress_bar/progress_bar.hpp" +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/column_binding_resolver.hpp" +#include "duckdb/execution/operator/helper/physical_result_collector.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/main/appender.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/client_context_file_opener.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/error_manager.hpp" +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/relation.hpp" +#include "duckdb/main/stream_query_result.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/parsed_data/create_function_info.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/parser/statement/execute_statement.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/parser/statement/prepare_statement.hpp" +#include "duckdb/parser/statement/relation_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/planner/operator/logical_execute.hpp" +#include "duckdb/planner/planner.hpp" +#include "duckdb/planner/pragma_handler.hpp" +#include "duckdb/transaction/meta_transaction.hpp" +#include "duckdb/transaction/transaction_manager.hpp" + +namespace duckdb { + +struct ActiveQueryContext { + //! The query that is currently being executed + string query; + //! The currently open result + BaseQueryResult *open_result = nullptr; + //! Prepared statement data + shared_ptr prepared; + //! The query executor + unique_ptr executor; + //! The progress bar + unique_ptr progress_bar; +}; + +ClientContext::ClientContext(shared_ptr database) + : db(std::move(database)), interrupted(false), client_data(make_uniq(*this)), transaction(*this) { +} + +ClientContext::~ClientContext() { + if (Exception::UncaughtException()) { + return; + } + // destroy the client context and rollback if there is an active transaction + // but only if we are not destroying this client context as part of an exception stack unwind + Destroy(); +} + +unique_ptr ClientContext::LockContext() { + return make_uniq(context_lock); +} + +void ClientContext::Destroy() { + auto lock = LockContext(); + if (transaction.HasActiveTransaction()) { + transaction.ResetActiveQuery(); + if (!transaction.IsAutoCommit()) { + transaction.Rollback(); + } + } + CleanupInternal(*lock); +} + +unique_ptr ClientContext::Fetch(ClientContextLock &lock, StreamQueryResult &result) { + D_ASSERT(IsActiveResult(lock, &result)); + D_ASSERT(active_query->executor); + return FetchInternal(lock, *active_query->executor, result); +} + +unique_ptr ClientContext::FetchInternal(ClientContextLock &lock, Executor &executor, + BaseQueryResult &result) { + bool invalidate_query = true; + try { + // fetch the chunk and return it + auto chunk = executor.FetchChunk(); + if (!chunk || chunk->size() == 0) { + CleanupInternal(lock, &result); + } + return chunk; + } catch (StandardException &ex) { + // standard exceptions do not invalidate the current transaction + result.SetError(PreservedError(ex)); + invalidate_query = false; + } catch (FatalException &ex) { + // fatal exceptions invalidate the entire database + result.SetError(PreservedError(ex)); + auto &db = DatabaseInstance::GetDatabase(*this); + ValidChecker::Invalidate(db, ex.what()); + } catch (const Exception &ex) { + result.SetError(PreservedError(ex)); + } catch (std::exception &ex) { + result.SetError(PreservedError(ex)); + } catch (...) { // LCOV_EXCL_START + result.SetError(PreservedError("Unhandled exception in FetchInternal")); + } // LCOV_EXCL_STOP + CleanupInternal(lock, &result, invalidate_query); + return nullptr; +} + +void ClientContext::BeginTransactionInternal(ClientContextLock &lock, bool requires_valid_transaction) { + // check if we are on AutoCommit. In this case we should start a transaction + D_ASSERT(!active_query); + auto &db = DatabaseInstance::GetDatabase(*this); + if (ValidChecker::IsInvalidated(db)) { + throw FatalException(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_DATABASE, + ValidChecker::InvalidatedMessage(db))); + } + if (requires_valid_transaction && transaction.HasActiveTransaction() && + ValidChecker::IsInvalidated(transaction.ActiveTransaction())) { + throw Exception(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_TRANSACTION)); + } + active_query = make_uniq(); + if (transaction.IsAutoCommit()) { + transaction.BeginTransaction(); + } +} + +void ClientContext::BeginQueryInternal(ClientContextLock &lock, const string &query) { + BeginTransactionInternal(lock, false); + LogQueryInternal(lock, query); + active_query->query = query; + query_progress = -1; + transaction.SetActiveQuery(db->GetDatabaseManager().GetNewQueryNumber()); +} + +PreservedError ClientContext::EndQueryInternal(ClientContextLock &lock, bool success, bool invalidate_transaction) { + client_data->profiler->EndQuery(); + + if (client_data->http_state) { + client_data->http_state->Reset(); + } + + // Notify any registered state of query end + for (auto const &s : registered_state) { + s.second->QueryEnd(); + } + + D_ASSERT(active_query.get()); + active_query.reset(); + query_progress = -1; + PreservedError error; + try { + if (transaction.HasActiveTransaction()) { + // Move the query profiler into the history + auto &prev_profilers = client_data->query_profiler_history->GetPrevProfilers(); + prev_profilers.emplace_back(transaction.GetActiveQuery(), std::move(client_data->profiler)); + // Reinitialize the query profiler + client_data->profiler = make_shared(*this); + // Propagate settings of the saved query into the new profiler. + client_data->profiler->Propagate(*prev_profilers.back().second); + if (prev_profilers.size() >= client_data->query_profiler_history->GetPrevProfilersSize()) { + prev_profilers.pop_front(); + } + + transaction.ResetActiveQuery(); + if (transaction.IsAutoCommit()) { + if (success) { + transaction.Commit(); + } else { + transaction.Rollback(); + } + } else if (invalidate_transaction) { + D_ASSERT(!success); + ValidChecker::Invalidate(ActiveTransaction(), "Failed to commit"); + } + } + } catch (FatalException &ex) { + auto &db = DatabaseInstance::GetDatabase(*this); + ValidChecker::Invalidate(db, ex.what()); + error = PreservedError(ex); + } catch (const Exception &ex) { + error = PreservedError(ex); + } catch (std::exception &ex) { + error = PreservedError(ex); + } catch (...) { // LCOV_EXCL_START + error = PreservedError("Unhandled exception!"); + } // LCOV_EXCL_STOP + return error; +} + +void ClientContext::CleanupInternal(ClientContextLock &lock, BaseQueryResult *result, bool invalidate_transaction) { + client_data->http_state = make_shared(); + if (!active_query) { + // no query currently active + return; + } + if (active_query->executor) { + active_query->executor->CancelTasks(); + } + active_query->progress_bar.reset(); + + auto error = EndQueryInternal(lock, result ? !result->HasError() : false, invalidate_transaction); + if (result && !result->HasError()) { + // if an error occurred while committing report it in the result + result->SetError(error); + } + D_ASSERT(!active_query); +} + +Executor &ClientContext::GetExecutor() { + D_ASSERT(active_query); + D_ASSERT(active_query->executor); + return *active_query->executor; +} + +const string &ClientContext::GetCurrentQuery() { + D_ASSERT(active_query); + return active_query->query; +} + +unique_ptr ClientContext::FetchResultInternal(ClientContextLock &lock, PendingQueryResult &pending) { + D_ASSERT(active_query); + D_ASSERT(active_query->open_result == &pending); + D_ASSERT(active_query->prepared); + auto &executor = GetExecutor(); + auto &prepared = *active_query->prepared; + bool create_stream_result = prepared.properties.allow_stream_result && pending.allow_stream_result; + if (create_stream_result) { + D_ASSERT(!executor.HasResultCollector()); + active_query->progress_bar.reset(); + query_progress = -1; + + // successfully compiled SELECT clause, and it is the last statement + // return a StreamQueryResult so the client can call Fetch() on it and stream the result + auto stream_result = make_uniq(pending.statement_type, pending.properties, + shared_from_this(), pending.types, pending.names); + active_query->open_result = stream_result.get(); + return std::move(stream_result); + } + unique_ptr result; + if (executor.HasResultCollector()) { + // we have a result collector - fetch the result directly from the result collector + result = executor.GetResult(); + CleanupInternal(lock, result.get(), false); + } else { + // no result collector - create a materialized result by continuously fetching + auto result_collection = make_uniq(Allocator::DefaultAllocator(), pending.types); + D_ASSERT(!result_collection->Types().empty()); + auto materialized_result = + make_uniq(pending.statement_type, pending.properties, pending.names, + std::move(result_collection), GetClientProperties()); + + auto &collection = materialized_result->Collection(); + D_ASSERT(!collection.Types().empty()); + ColumnDataAppendState append_state; + collection.InitializeAppend(append_state); + while (true) { + auto chunk = FetchInternal(lock, GetExecutor(), *materialized_result); + if (!chunk || chunk->size() == 0) { + break; + } +#ifdef DEBUG + for (idx_t i = 0; i < chunk->ColumnCount(); i++) { + if (pending.types[i].id() == LogicalTypeId::VARCHAR) { + chunk->data[i].UTFVerify(chunk->size()); + } + } +#endif + collection.Append(append_state, *chunk); + } + result = std::move(materialized_result); + } + return result; +} + +static bool IsExplainAnalyze(SQLStatement *statement) { + if (!statement) { + return false; + } + if (statement->type != StatementType::EXPLAIN_STATEMENT) { + return false; + } + auto &explain = statement->Cast(); + return explain.explain_type == ExplainType::EXPLAIN_ANALYZE; +} + +shared_ptr +ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr statement, + optional_ptr> values) { + StatementType statement_type = statement->type; + auto result = make_shared(statement_type); + + auto &profiler = QueryProfiler::Get(*this); + profiler.StartQuery(query, IsExplainAnalyze(statement.get()), true); + profiler.StartPhase("planner"); + Planner planner(*this); + if (values) { + auto ¶meter_values = *values; + for (auto &value : parameter_values) { + planner.parameter_data.emplace(value.first, BoundParameterData(value.second)); + } + } + + client_data->http_state = make_shared(); + planner.CreatePlan(std::move(statement)); + D_ASSERT(planner.plan || !planner.properties.bound_all_parameters); + profiler.EndPhase(); + + auto plan = std::move(planner.plan); + // extract the result column names from the plan + result->properties = planner.properties; + result->names = planner.names; + result->types = planner.types; + result->value_map = std::move(planner.value_map); + result->catalog_version = MetaTransaction::Get(*this).catalog_version; + + if (!planner.properties.bound_all_parameters) { + return result; + } +#ifdef DEBUG + plan->Verify(*this); +#endif + if (config.enable_optimizer && plan->RequireOptimizer()) { + profiler.StartPhase("optimizer"); + Optimizer optimizer(*planner.binder, *this); + plan = optimizer.Optimize(std::move(plan)); + D_ASSERT(plan); + profiler.EndPhase(); + +#ifdef DEBUG + plan->Verify(*this); +#endif + } + + profiler.StartPhase("physical_planner"); + // now convert logical query plan into a physical query plan + PhysicalPlanGenerator physical_planner(*this); + auto physical_plan = physical_planner.CreatePlan(std::move(plan)); + profiler.EndPhase(); + +#ifdef DEBUG + D_ASSERT(!physical_plan->ToString().empty()); +#endif + result->plan = std::move(physical_plan); + return result; +} + +double ClientContext::GetProgress() { + return query_progress.load(); +} + +unique_ptr ClientContext::PendingPreparedStatement(ClientContextLock &lock, + shared_ptr statement_p, + const PendingQueryParameters ¶meters) { + D_ASSERT(active_query); + auto &statement = *statement_p; + if (ValidChecker::IsInvalidated(ActiveTransaction()) && statement.properties.requires_valid_transaction) { + throw Exception(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_TRANSACTION)); + } + auto &transaction = MetaTransaction::Get(*this); + auto &manager = DatabaseManager::Get(*this); + for (auto &modified_database : statement.properties.modified_databases) { + auto entry = manager.GetDatabase(*this, modified_database); + if (!entry) { + throw InternalException("Database \"%s\" not found", modified_database); + } + if (entry->IsReadOnly()) { + throw Exception(StringUtil::Format( + "Cannot execute statement of type \"%s\" on database \"%s\" which is attached in read-only mode!", + StatementTypeToString(statement.statement_type), modified_database)); + } + transaction.ModifyDatabase(*entry); + } + + // bind the bound values before execution + case_insensitive_map_t owned_values; + if (parameters.parameters) { + auto ¶ms = *parameters.parameters; + for (auto &val : params) { + owned_values.emplace(val); + } + } + statement.Bind(std::move(owned_values)); + + active_query->executor = make_uniq(*this); + auto &executor = *active_query->executor; + if (config.enable_progress_bar) { + progress_bar_display_create_func_t display_create_func = nullptr; + if (config.print_progress_bar) { + // If a custom display is set, use that, otherwise just use the default + display_create_func = + config.display_create_func ? config.display_create_func : ProgressBar::DefaultProgressBarDisplay; + } + active_query->progress_bar = make_uniq(executor, config.wait_time, display_create_func); + active_query->progress_bar->Start(); + query_progress = 0; + } + auto stream_result = parameters.allow_stream_result && statement.properties.allow_stream_result; + if (!stream_result && statement.properties.return_type == StatementReturnType::QUERY_RESULT) { + unique_ptr collector; + auto &config = ClientConfig::GetConfig(*this); + auto get_method = + config.result_collector ? config.result_collector : PhysicalResultCollector::GetResultCollector; + collector = get_method(*this, statement); + D_ASSERT(collector->type == PhysicalOperatorType::RESULT_COLLECTOR); + executor.Initialize(std::move(collector)); + } else { + executor.Initialize(*statement.plan); + } + auto types = executor.GetTypes(); + D_ASSERT(types == statement.types); + D_ASSERT(!active_query->open_result); + + auto pending_result = + make_uniq(shared_from_this(), *statement_p, std::move(types), stream_result); + active_query->prepared = std::move(statement_p); + active_query->open_result = pending_result.get(); + return pending_result; +} + +PendingExecutionResult ClientContext::ExecuteTaskInternal(ClientContextLock &lock, PendingQueryResult &result) { + D_ASSERT(active_query); + D_ASSERT(active_query->open_result == &result); + try { + auto result = active_query->executor->ExecuteTask(); + if (active_query->progress_bar) { + active_query->progress_bar->Update(result == PendingExecutionResult::RESULT_READY); + query_progress = active_query->progress_bar->GetCurrentPercentage(); + } + return result; + } catch (FatalException &ex) { + // fatal exceptions invalidate the entire database + result.SetError(PreservedError(ex)); + auto &db = DatabaseInstance::GetDatabase(*this); + ValidChecker::Invalidate(db, ex.what()); + } catch (const Exception &ex) { + result.SetError(PreservedError(ex)); + } catch (std::exception &ex) { + result.SetError(PreservedError(ex)); + } catch (...) { // LCOV_EXCL_START + result.SetError(PreservedError("Unhandled exception in ExecuteTaskInternal")); + } // LCOV_EXCL_STOP + EndQueryInternal(lock, false, true); + return PendingExecutionResult::EXECUTION_ERROR; +} + +void ClientContext::InitialCleanup(ClientContextLock &lock) { + //! Cleanup any open results and reset the interrupted flag + CleanupInternal(lock); + interrupted = false; +} + +vector> ClientContext::ParseStatements(const string &query) { + auto lock = LockContext(); + return ParseStatementsInternal(*lock, query); +} + +vector> ClientContext::ParseStatementsInternal(ClientContextLock &lock, const string &query) { + Parser parser(GetParserOptions()); + parser.ParseQuery(query); + + PragmaHandler handler(*this); + handler.HandlePragmaStatements(lock, parser.statements); + + return std::move(parser.statements); +} + +void ClientContext::HandlePragmaStatements(vector> &statements) { + auto lock = LockContext(); + + PragmaHandler handler(*this); + handler.HandlePragmaStatements(*lock, statements); +} + +unique_ptr ClientContext::ExtractPlan(const string &query) { + auto lock = LockContext(); + + auto statements = ParseStatementsInternal(*lock, query); + if (statements.size() != 1) { + throw Exception("ExtractPlan can only prepare a single statement"); + } + + unique_ptr plan; + client_data->http_state = make_shared(); + RunFunctionInTransactionInternal(*lock, [&]() { + Planner planner(*this); + planner.CreatePlan(std::move(statements[0])); + D_ASSERT(planner.plan); + + plan = std::move(planner.plan); + + if (config.enable_optimizer) { + Optimizer optimizer(*planner.binder, *this); + plan = optimizer.Optimize(std::move(plan)); + } + + ColumnBindingResolver resolver; + resolver.Verify(*plan); + resolver.VisitOperator(*plan); + + plan->ResolveOperatorTypes(); + }); + return plan; +} + +unique_ptr ClientContext::PrepareInternal(ClientContextLock &lock, + unique_ptr statement) { + auto n_param = statement->n_param; + auto named_param_map = std::move(statement->named_param_map); + auto statement_query = statement->query; + shared_ptr prepared_data; + auto unbound_statement = statement->Copy(); + RunFunctionInTransactionInternal( + lock, [&]() { prepared_data = CreatePreparedStatement(lock, statement_query, std::move(statement)); }, false); + prepared_data->unbound_statement = std::move(unbound_statement); + return make_uniq(shared_from_this(), std::move(prepared_data), std::move(statement_query), + n_param, std::move(named_param_map)); +} + +unique_ptr ClientContext::Prepare(unique_ptr statement) { + auto lock = LockContext(); + // prepare the query + try { + InitialCleanup(*lock); + return PrepareInternal(*lock, std::move(statement)); + } catch (const Exception &ex) { + return make_uniq(PreservedError(ex)); + } catch (std::exception &ex) { + return make_uniq(PreservedError(ex)); + } +} + +unique_ptr ClientContext::Prepare(const string &query) { + auto lock = LockContext(); + // prepare the query + try { + InitialCleanup(*lock); + + // first parse the query + auto statements = ParseStatementsInternal(*lock, query); + if (statements.empty()) { + throw Exception("No statement to prepare!"); + } + if (statements.size() > 1) { + throw Exception("Cannot prepare multiple statements at once!"); + } + return PrepareInternal(*lock, std::move(statements[0])); + } catch (const Exception &ex) { + return make_uniq(PreservedError(ex)); + } catch (std::exception &ex) { + return make_uniq(PreservedError(ex)); + } +} + +unique_ptr ClientContext::PendingQueryPreparedInternal(ClientContextLock &lock, const string &query, + shared_ptr &prepared, + const PendingQueryParameters ¶meters) { + try { + InitialCleanup(lock); + } catch (const Exception &ex) { + return make_uniq(PreservedError(ex)); + } catch (std::exception &ex) { + return make_uniq(PreservedError(ex)); + } + return PendingStatementOrPreparedStatementInternal(lock, query, nullptr, prepared, parameters); +} + +unique_ptr ClientContext::PendingQuery(const string &query, + shared_ptr &prepared, + const PendingQueryParameters ¶meters) { + auto lock = LockContext(); + return PendingQueryPreparedInternal(*lock, query, prepared, parameters); +} + +unique_ptr ClientContext::Execute(const string &query, shared_ptr &prepared, + const PendingQueryParameters ¶meters) { + auto lock = LockContext(); + auto pending = PendingQueryPreparedInternal(*lock, query, prepared, parameters); + if (pending->HasError()) { + return make_uniq(pending->GetErrorObject()); + } + return pending->ExecuteInternal(*lock); +} + +unique_ptr ClientContext::Execute(const string &query, shared_ptr &prepared, + case_insensitive_map_t &values, bool allow_stream_result) { + PendingQueryParameters parameters; + parameters.parameters = &values; + parameters.allow_stream_result = allow_stream_result; + return Execute(query, prepared, parameters); +} + +unique_ptr ClientContext::PendingStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, + const PendingQueryParameters ¶meters) { + // prepare the query for execution + auto prepared = CreatePreparedStatement(lock, query, std::move(statement), parameters.parameters); + idx_t parameter_count = !parameters.parameters ? 0 : parameters.parameters->size(); + if (prepared->properties.parameter_count > 0 && parameter_count == 0) { + string error_message = StringUtil::Format("Expected %lld parameters, but none were supplied", + prepared->properties.parameter_count); + return make_uniq(PreservedError(error_message)); + } + if (!prepared->properties.bound_all_parameters) { + return make_uniq(PreservedError("Not all parameters were bound")); + } + // execute the prepared statement + return PendingPreparedStatement(lock, std::move(prepared), parameters); +} + +unique_ptr ClientContext::RunStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, + bool allow_stream_result, bool verify) { + PendingQueryParameters parameters; + parameters.allow_stream_result = allow_stream_result; + auto pending = PendingQueryInternal(lock, std::move(statement), parameters, verify); + if (pending->HasError()) { + return make_uniq(pending->GetErrorObject()); + } + return ExecutePendingQueryInternal(lock, *pending); +} + +bool ClientContext::IsActiveResult(ClientContextLock &lock, BaseQueryResult *result) { + if (!active_query) { + return false; + } + return active_query->open_result == result; +} + +unique_ptr ClientContext::PendingStatementOrPreparedStatementInternal( + ClientContextLock &lock, const string &query, unique_ptr statement, + shared_ptr &prepared, const PendingQueryParameters ¶meters) { + // check if we are on AutoCommit. In this case we should start a transaction. + if (statement && config.AnyVerification()) { + // query verification is enabled + // create a copy of the statement, and use the copy + // this way we verify that the copy correctly copies all properties + auto copied_statement = statement->Copy(); + switch (statement->type) { + case StatementType::SELECT_STATEMENT: { + // in case this is a select query, we verify the original statement + PreservedError error; + try { + error = VerifyQuery(lock, query, std::move(statement)); + } catch (const Exception &ex) { + error = PreservedError(ex); + } catch (std::exception &ex) { + error = PreservedError(ex); + } + if (error) { + // error in verifying query + return make_uniq(error); + } + statement = std::move(copied_statement); + break; + } +#ifndef DUCKDB_ALTERNATIVE_VERIFY + case StatementType::COPY_STATEMENT: + case StatementType::INSERT_STATEMENT: + case StatementType::DELETE_STATEMENT: + case StatementType::UPDATE_STATEMENT: { + Parser parser; + PreservedError error; + try { + parser.ParseQuery(statement->ToString()); + } catch (const Exception &ex) { + error = PreservedError(ex); + } catch (std::exception &ex) { + error = PreservedError(ex); + } + if (error) { + // error in verifying query + return make_uniq(error); + } + statement = std::move(parser.statements[0]); + break; + } +#endif + default: + statement = std::move(copied_statement); + break; + } + } + return PendingStatementOrPreparedStatement(lock, query, std::move(statement), prepared, parameters); +} + +unique_ptr ClientContext::PendingStatementOrPreparedStatement( + ClientContextLock &lock, const string &query, unique_ptr statement, + shared_ptr &prepared, const PendingQueryParameters ¶meters) { + unique_ptr result; + + try { + BeginQueryInternal(lock, query); + } catch (FatalException &ex) { + // fatal exceptions invalidate the entire database + auto &db = DatabaseInstance::GetDatabase(*this); + ValidChecker::Invalidate(db, ex.what()); + result = make_uniq(PreservedError(ex)); + return result; + } catch (const Exception &ex) { + return make_uniq(PreservedError(ex)); + } catch (std::exception &ex) { + return make_uniq(PreservedError(ex)); + } + // start the profiler + auto &profiler = QueryProfiler::Get(*this); + profiler.StartQuery(query, IsExplainAnalyze(statement ? statement.get() : prepared->unbound_statement.get())); + + bool invalidate_query = true; + try { + if (statement) { + result = PendingStatementInternal(lock, query, std::move(statement), parameters); + } else { + if (prepared->RequireRebind(*this, parameters.parameters)) { + // catalog was modified: rebind the statement before execution + auto new_prepared = + CreatePreparedStatement(lock, query, prepared->unbound_statement->Copy(), parameters.parameters); + D_ASSERT(new_prepared->properties.bound_all_parameters); + new_prepared->unbound_statement = std::move(prepared->unbound_statement); + prepared = std::move(new_prepared); + prepared->properties.bound_all_parameters = false; + } + result = PendingPreparedStatement(lock, prepared, parameters); + } + } catch (StandardException &ex) { + // standard exceptions do not invalidate the current transaction + result = make_uniq(PreservedError(ex)); + invalidate_query = false; + } catch (FatalException &ex) { + // fatal exceptions invalidate the entire database + if (!config.query_verification_enabled) { + auto &db = DatabaseInstance::GetDatabase(*this); + ValidChecker::Invalidate(db, ex.what()); + } + result = make_uniq(PreservedError(ex)); + } catch (const Exception &ex) { + // other types of exceptions do invalidate the current transaction + result = make_uniq(PreservedError(ex)); + } catch (std::exception &ex) { + // other types of exceptions do invalidate the current transaction + result = make_uniq(PreservedError(ex)); + } + if (result->HasError()) { + // query failed: abort now + EndQueryInternal(lock, false, invalidate_query); + return result; + } + D_ASSERT(active_query->open_result == result.get()); + return result; +} + +void ClientContext::LogQueryInternal(ClientContextLock &, const string &query) { + if (!client_data->log_query_writer) { +#ifdef DUCKDB_FORCE_QUERY_LOG + try { + string log_path(DUCKDB_FORCE_QUERY_LOG); + client_data->log_query_writer = + make_uniq(FileSystem::GetFileSystem(*this), log_path, + BufferedFileWriter::DEFAULT_OPEN_FLAGS, client_data->file_opener.get()); + } catch (...) { + return; + } +#else + return; +#endif + } + // log query path is set: log the query + client_data->log_query_writer->WriteData(const_data_ptr_cast(query.c_str()), query.size()); + client_data->log_query_writer->WriteData(const_data_ptr_cast("\n"), 1); + client_data->log_query_writer->Flush(); + client_data->log_query_writer->Sync(); +} + +unique_ptr ClientContext::Query(unique_ptr statement, bool allow_stream_result) { + auto pending_query = PendingQuery(std::move(statement), allow_stream_result); + if (pending_query->HasError()) { + return make_uniq(pending_query->GetErrorObject()); + } + return pending_query->Execute(); +} + +unique_ptr ClientContext::Query(const string &query, bool allow_stream_result) { + auto lock = LockContext(); + + PreservedError error; + vector> statements; + if (!ParseStatements(*lock, query, statements, error)) { + return make_uniq(std::move(error)); + } + if (statements.empty()) { + // no statements, return empty successful result + StatementProperties properties; + vector names; + auto collection = make_uniq(Allocator::DefaultAllocator()); + return make_uniq(StatementType::INVALID_STATEMENT, properties, std::move(names), + std::move(collection), GetClientProperties()); + } + + unique_ptr result; + QueryResult *last_result = nullptr; + bool last_had_result = false; + for (idx_t i = 0; i < statements.size(); i++) { + auto &statement = statements[i]; + bool is_last_statement = i + 1 == statements.size(); + PendingQueryParameters parameters; + parameters.allow_stream_result = allow_stream_result && is_last_statement; + auto pending_query = PendingQueryInternal(*lock, std::move(statement), parameters); + auto has_result = pending_query->properties.return_type == StatementReturnType::QUERY_RESULT; + unique_ptr current_result; + if (pending_query->HasError()) { + current_result = make_uniq(pending_query->GetErrorObject()); + } else { + current_result = ExecutePendingQueryInternal(*lock, *pending_query); + } + // now append the result to the list of results + if (!last_result || !last_had_result) { + // first result of the query + result = std::move(current_result); + last_result = result.get(); + last_had_result = has_result; + } else { + // later results; attach to the result chain + // but only if there is a result + if (!has_result) { + continue; + } + last_result->next = std::move(current_result); + last_result = last_result->next.get(); + } + } + return result; +} + +bool ClientContext::ParseStatements(ClientContextLock &lock, const string &query, + vector> &result, PreservedError &error) { + try { + InitialCleanup(lock); + // parse the query and transform it into a set of statements + result = ParseStatementsInternal(lock, query); + return true; + } catch (const Exception &ex) { + error = PreservedError(ex); + return false; + } catch (std::exception &ex) { + error = PreservedError(ex); + return false; + } +} + +unique_ptr ClientContext::PendingQuery(const string &query, bool allow_stream_result) { + auto lock = LockContext(); + + PreservedError error; + vector> statements; + if (!ParseStatements(*lock, query, statements, error)) { + return make_uniq(std::move(error)); + } + if (statements.size() != 1) { + return make_uniq(PreservedError("PendingQuery can only take a single statement")); + } + PendingQueryParameters parameters; + parameters.allow_stream_result = allow_stream_result; + return PendingQueryInternal(*lock, std::move(statements[0]), parameters); +} + +unique_ptr ClientContext::PendingQuery(unique_ptr statement, + bool allow_stream_result) { + auto lock = LockContext(); + PendingQueryParameters parameters; + parameters.allow_stream_result = allow_stream_result; + return PendingQueryInternal(*lock, std::move(statement), parameters); +} + +unique_ptr ClientContext::PendingQueryInternal(ClientContextLock &lock, + unique_ptr statement, + const PendingQueryParameters ¶meters, + bool verify) { + auto query = statement->query; + shared_ptr prepared; + if (verify) { + return PendingStatementOrPreparedStatementInternal(lock, query, std::move(statement), prepared, parameters); + } else { + return PendingStatementOrPreparedStatement(lock, query, std::move(statement), prepared, parameters); + } +} + +unique_ptr ClientContext::ExecutePendingQueryInternal(ClientContextLock &lock, PendingQueryResult &query) { + return query.ExecuteInternal(lock); +} + +void ClientContext::Interrupt() { + interrupted = true; +} + +void ClientContext::EnableProfiling() { + auto lock = LockContext(); + auto &config = ClientConfig::GetConfig(*this); + config.enable_profiler = true; + config.emit_profiler_output = true; +} + +void ClientContext::DisableProfiling() { + auto lock = LockContext(); + auto &config = ClientConfig::GetConfig(*this); + config.enable_profiler = false; +} + +void ClientContext::RegisterFunction(CreateFunctionInfo &info) { + RunFunctionInTransaction([&]() { + auto existing_function = Catalog::GetEntry(*this, INVALID_CATALOG, info.schema, + info.name, OnEntryNotFound::RETURN_NULL); + if (existing_function) { + auto &new_info = info.Cast(); + if (new_info.functions.MergeFunctionSet(existing_function->functions)) { + // function info was updated from catalog entry, rewrite is needed + info.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; + } + } + // create function + auto &catalog = Catalog::GetSystemCatalog(*this); + catalog.CreateFunction(*this, info); + }); +} + +void ClientContext::RunFunctionInTransactionInternal(ClientContextLock &lock, const std::function &fun, + bool requires_valid_transaction) { + if (requires_valid_transaction && transaction.HasActiveTransaction() && + ValidChecker::IsInvalidated(ActiveTransaction())) { + throw TransactionException(ErrorManager::FormatException(*this, ErrorType::INVALIDATED_TRANSACTION)); + } + // check if we are on AutoCommit. In this case we should start a transaction + bool require_new_transaction = transaction.IsAutoCommit() && !transaction.HasActiveTransaction(); + if (require_new_transaction) { + D_ASSERT(!active_query); + transaction.BeginTransaction(); + } + try { + fun(); + } catch (StandardException &ex) { + if (require_new_transaction) { + transaction.Rollback(); + } + throw; + } catch (FatalException &ex) { + auto &db = DatabaseInstance::GetDatabase(*this); + ValidChecker::Invalidate(db, ex.what()); + throw; + } catch (std::exception &ex) { + if (require_new_transaction) { + transaction.Rollback(); + } else { + ValidChecker::Invalidate(ActiveTransaction(), ex.what()); + } + throw; + } + if (require_new_transaction) { + transaction.Commit(); + } +} + +void ClientContext::RunFunctionInTransaction(const std::function &fun, bool requires_valid_transaction) { + auto lock = LockContext(); + RunFunctionInTransactionInternal(*lock, fun, requires_valid_transaction); +} + +unique_ptr ClientContext::TableInfo(const string &schema_name, const string &table_name) { + unique_ptr result; + RunFunctionInTransaction([&]() { + // obtain the table info + auto table = Catalog::GetEntry(*this, INVALID_CATALOG, schema_name, table_name, + OnEntryNotFound::RETURN_NULL); + if (!table) { + return; + } + // write the table info to the result + result = make_uniq(); + result->schema = schema_name; + result->table = table_name; + for (auto &column : table->GetColumns().Logical()) { + result->columns.emplace_back(column.Name(), column.Type()); + } + }); + return result; +} + +void ClientContext::Append(TableDescription &description, ColumnDataCollection &collection) { + RunFunctionInTransaction([&]() { + auto &table_entry = + Catalog::GetEntry(*this, INVALID_CATALOG, description.schema, description.table); + // verify that the table columns and types match up + if (description.columns.size() != table_entry.GetColumns().PhysicalColumnCount()) { + throw Exception("Failed to append: table entry has different number of columns!"); + } + for (idx_t i = 0; i < description.columns.size(); i++) { + if (description.columns[i].Type() != table_entry.GetColumns().GetColumn(PhysicalIndex(i)).Type()) { + throw Exception("Failed to append: table entry has different number of columns!"); + } + } + table_entry.GetStorage().LocalAppend(table_entry, *this, collection); + }); +} + +void ClientContext::TryBindRelation(Relation &relation, vector &result_columns) { +#ifdef DEBUG + D_ASSERT(!relation.GetAlias().empty()); + D_ASSERT(!relation.ToString().empty()); +#endif + client_data->http_state = make_shared(); + RunFunctionInTransaction([&]() { + // bind the expressions + auto binder = Binder::CreateBinder(*this); + auto result = relation.Bind(*binder); + D_ASSERT(result.names.size() == result.types.size()); + + result_columns.reserve(result_columns.size() + result.names.size()); + for (idx_t i = 0; i < result.names.size(); i++) { + result_columns.emplace_back(result.names[i], result.types[i]); + } + }); +} + +unordered_set ClientContext::GetTableNames(const string &query) { + auto lock = LockContext(); + + auto statements = ParseStatementsInternal(*lock, query); + if (statements.size() != 1) { + throw InvalidInputException("Expected a single statement"); + } + + unordered_set result; + RunFunctionInTransactionInternal(*lock, [&]() { + // bind the expressions + auto binder = Binder::CreateBinder(*this); + binder->SetBindingMode(BindingMode::EXTRACT_NAMES); + binder->Bind(*statements[0]); + result = binder->GetTableNames(); + }); + return result; +} + +unique_ptr ClientContext::PendingQueryInternal(ClientContextLock &lock, + const shared_ptr &relation, + bool allow_stream_result) { + InitialCleanup(lock); + + string query; + if (config.query_verification_enabled) { + // run the ToString method of any relation we run, mostly to ensure it doesn't crash + relation->ToString(); + relation->GetAlias(); + if (relation->IsReadOnly()) { + // verify read only statements by running a select statement + auto select = make_uniq(); + select->node = relation->GetQueryNode(); + RunStatementInternal(lock, query, std::move(select), false); + } + } + + auto relation_stmt = make_uniq(relation); + PendingQueryParameters parameters; + parameters.allow_stream_result = allow_stream_result; + return PendingQueryInternal(lock, std::move(relation_stmt), parameters); +} + +unique_ptr ClientContext::PendingQuery(const shared_ptr &relation, + bool allow_stream_result) { + auto lock = LockContext(); + return PendingQueryInternal(*lock, relation, allow_stream_result); +} + +unique_ptr ClientContext::Execute(const shared_ptr &relation) { + auto lock = LockContext(); + auto &expected_columns = relation->Columns(); + auto pending = PendingQueryInternal(*lock, relation, false); + if (!pending->success) { + return make_uniq(pending->GetErrorObject()); + } + + unique_ptr result; + result = ExecutePendingQueryInternal(*lock, *pending); + if (result->HasError()) { + return result; + } + // verify that the result types and result names of the query match the expected result types/names + if (result->types.size() == expected_columns.size()) { + bool mismatch = false; + for (idx_t i = 0; i < result->types.size(); i++) { + if (result->types[i] != expected_columns[i].Type() || result->names[i] != expected_columns[i].Name()) { + mismatch = true; + break; + } + } + if (!mismatch) { + // all is as expected: return the result + return result; + } + } + // result mismatch + string err_str = "Result mismatch in query!\nExpected the following columns: ["; + for (idx_t i = 0; i < expected_columns.size(); i++) { + if (i > 0) { + err_str += ", "; + } + err_str += expected_columns[i].Name() + " " + expected_columns[i].Type().ToString(); + } + err_str += "]\nBut result contained the following: "; + for (idx_t i = 0; i < result->types.size(); i++) { + err_str += i == 0 ? "[" : ", "; + err_str += result->names[i] + " " + result->types[i].ToString(); + } + err_str += "]"; + return make_uniq(PreservedError(err_str)); +} + +bool ClientContext::TryGetCurrentSetting(const std::string &key, Value &result) { + // first check the built-in settings + auto &db_config = DBConfig::GetConfig(*this); + auto option = db_config.GetOptionByName(key); + if (option) { + result = option->get_setting(*this); + return true; + } + + // check the client session values + const auto &session_config_map = config.set_variables; + + auto session_value = session_config_map.find(key); + bool found_session_value = session_value != session_config_map.end(); + if (found_session_value) { + result = session_value->second; + return true; + } + // finally check the global session values + return db->TryGetCurrentSetting(key, result); +} + +ParserOptions ClientContext::GetParserOptions() const { + auto &client_config = ClientConfig::GetConfig(*this); + ParserOptions options; + options.preserve_identifier_case = client_config.preserve_identifier_case; + options.integer_division = client_config.integer_division; + options.max_expression_depth = client_config.max_expression_depth; + options.extensions = &DBConfig::GetConfig(*this).parser_extensions; + return options; +} + +ClientProperties ClientContext::GetClientProperties() const { + string timezone = "UTC"; + Value result; + // 1) Check Set Variable + auto &client_config = ClientConfig::GetConfig(*this); + auto tz_config = client_config.set_variables.find("timezone"); + if (tz_config == client_config.set_variables.end()) { + // 2) Check for Default Value + auto default_value = db->config.extension_parameters.find("timezone"); + if (default_value != db->config.extension_parameters.end()) { + timezone = default_value->second.default_value.GetValue(); + } + } else { + timezone = tz_config->second.GetValue(); + } + return {timezone, db->config.options.arrow_offset_size}; +} + +bool ClientContext::ExecutionIsFinished() { + if (!active_query || !active_query->executor) { + return false; + } + return active_query->executor->ExecutionIsFinished(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/client_context_file_opener.cpp b/src/duckdb/src/main/client_context_file_opener.cpp new file mode 100644 index 00000000..20de7959 --- /dev/null +++ b/src/duckdb/src/main/client_context_file_opener.cpp @@ -0,0 +1,42 @@ +#include "duckdb/main/client_context_file_opener.hpp" + +#include "duckdb/common/file_opener.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +bool ClientContextFileOpener::TryGetCurrentSetting(const string &key, Value &result) { + return context.TryGetCurrentSetting(key, result); +} + +// LCOV_EXCL_START +bool ClientContextFileOpener::TryGetCurrentSetting(const string &key, Value &result, FileOpenerInfo &) { + return context.TryGetCurrentSetting(key, result); +} + +ClientContext *FileOpener::TryGetClientContext(FileOpener *opener) { + if (!opener) { + return nullptr; + } + return opener->TryGetClientContext(); +} + +bool FileOpener::TryGetCurrentSetting(FileOpener *opener, const string &key, Value &result) { + if (!opener) { + return false; + } + return opener->TryGetCurrentSetting(key, result); +} + +bool FileOpener::TryGetCurrentSetting(FileOpener *opener, const string &key, Value &result, FileOpenerInfo &info) { + if (!opener) { + return false; + } + return opener->TryGetCurrentSetting(key, result, info); +} + +bool FileOpener::TryGetCurrentSetting(const string &key, Value &result, FileOpenerInfo &info) { + return this->TryGetCurrentSetting(key, result); +} +// LCOV_EXCL_STOP +} // namespace duckdb diff --git a/src/duckdb/src/main/client_data.cpp b/src/duckdb/src/main/client_data.cpp new file mode 100644 index 00000000..8553a65e --- /dev/null +++ b/src/duckdb/src/main/client_data.cpp @@ -0,0 +1,57 @@ +#include "duckdb/main/client_data.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/random_engine.hpp" +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_context_file_opener.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/common/opener_file_system.hpp" + +namespace duckdb { + +class ClientFileSystem : public OpenerFileSystem { +public: + explicit ClientFileSystem(ClientContext &context_p) : context(context_p) { + } + + FileSystem &GetFileSystem() const override { + auto &config = DBConfig::GetConfig(context); + return *config.file_system; + } + optional_ptr GetOpener() const override { + return ClientData::Get(context).file_opener.get(); + } + +private: + ClientContext &context; +}; + +ClientData::ClientData(ClientContext &context) : catalog_search_path(make_uniq(context)) { + auto &db = DatabaseInstance::GetDatabase(context); + profiler = make_shared(context); + query_profiler_history = make_uniq(); + temporary_objects = make_shared(db, AttachedDatabaseType::TEMP_DATABASE); + temporary_objects->oid = DatabaseManager::Get(db).ModifyCatalog(); + random_engine = make_uniq(); + file_opener = make_uniq(context); + client_file_system = make_uniq(context); + temporary_objects->Initialize(); +} +ClientData::~ClientData() { +} + +ClientData &ClientData::Get(ClientContext &context) { + return *context.client_data; +} + +RandomEngine &RandomEngine::Get(ClientContext &context) { + return *ClientData::Get(context).random_engine; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/client_verify.cpp b/src/duckdb/src/main/client_verify.cpp new file mode 100644 index 00000000..8244eb79 --- /dev/null +++ b/src/duckdb/src/main/client_verify.cpp @@ -0,0 +1,152 @@ +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/verification/statement_verifier.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/common/box_renderer.hpp" + +namespace duckdb { + +static void ThrowIfExceptionIsInternal(StatementVerifier &verifier) { + if (!verifier.materialized_result) { + return; + } + auto &result = *verifier.materialized_result; + if (!result.HasError()) { + return; + } + auto &error = result.GetErrorObject(); + if (error.Type() == ExceptionType::INTERNAL) { + error.Throw(); + } +} + +PreservedError ClientContext::VerifyQuery(ClientContextLock &lock, const string &query, + unique_ptr statement) { + D_ASSERT(statement->type == StatementType::SELECT_STATEMENT); + // Aggressive query verification + + // The purpose of this function is to test correctness of otherwise hard to test features: + // Copy() of statements and expressions + // Serialize()/Deserialize() of expressions + // Hash() of expressions + // Equality() of statements and expressions + // ToString() of statements and expressions + // Correctness of plans both with and without optimizers + + const auto &stmt = *statement; + vector> statement_verifiers; + unique_ptr prepared_statement_verifier; + if (config.query_verification_enabled) { + statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::COPIED, stmt)); + statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::DESERIALIZED, stmt)); + statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::UNOPTIMIZED, stmt)); + prepared_statement_verifier = StatementVerifier::Create(VerificationType::PREPARED, stmt); +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + // This verification is quite slow, so we only run it for the async sink/source debug mode + statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::NO_OPERATOR_CACHING, stmt)); +#endif + } + if (config.verify_external) { + statement_verifiers.emplace_back(StatementVerifier::Create(VerificationType::EXTERNAL, stmt)); + } + + auto original = make_uniq(std::move(statement)); + for (auto &verifier : statement_verifiers) { + original->CheckExpressions(*verifier); + } + original->CheckExpressions(); + + // See below + auto statement_copy_for_explain = stmt.Copy(); + + // Save settings + bool optimizer_enabled = config.enable_optimizer; + bool profiling_is_enabled = config.enable_profiler; + bool force_external = config.force_external; + + // Disable profiling if it is enabled + if (profiling_is_enabled) { + config.enable_profiler = false; + } + + // Execute the original statement + bool any_failed = original->Run(*this, query, [&](const string &q, unique_ptr s) { + return RunStatementInternal(lock, q, std::move(s), false, false); + }); + if (!any_failed) { + statement_verifiers.emplace_back( + StatementVerifier::Create(VerificationType::PARSED, *statement_copy_for_explain)); + } + // Execute the verifiers + for (auto &verifier : statement_verifiers) { + bool failed = verifier->Run(*this, query, [&](const string &q, unique_ptr s) { + return RunStatementInternal(lock, q, std::move(s), false, false); + }); + any_failed = any_failed || failed; + } + + if (!any_failed && prepared_statement_verifier) { + // If none failed, we execute the prepared statement verifier + bool failed = prepared_statement_verifier->Run(*this, query, [&](const string &q, unique_ptr s) { + return RunStatementInternal(lock, q, std::move(s), false, false); + }); + if (!failed) { + // PreparedStatementVerifier fails if it runs into a ParameterNotAllowedException, which is OK + statement_verifiers.push_back(std::move(prepared_statement_verifier)); + } else { + // If it does fail, let's make sure it's not an internal exception + ThrowIfExceptionIsInternal(*prepared_statement_verifier); + } + } else { + if (ValidChecker::IsInvalidated(*db)) { + return original->materialized_result->GetErrorObject(); + } + } + + // Restore config setting + config.enable_optimizer = optimizer_enabled; + config.force_external = force_external; + + // Check explain, only if q does not already contain EXPLAIN + if (original->materialized_result->success) { + auto explain_q = "EXPLAIN " + query; + auto explain_stmt = make_uniq(std::move(statement_copy_for_explain)); + try { + RunStatementInternal(lock, explain_q, std::move(explain_stmt), false, false); + } catch (std::exception &ex) { // LCOV_EXCL_START + interrupted = false; + return PreservedError("EXPLAIN failed but query did not (" + string(ex.what()) + ")"); + } // LCOV_EXCL_STOP + +#ifdef DUCKDB_VERIFY_BOX_RENDERER + // this is pretty slow, so disabled by default + // test the box renderer on the result + // we mostly care that this does not crash + RandomEngine random; + BoxRendererConfig config; + // test with a random width + config.max_width = random.NextRandomInteger() % 500; + BoxRenderer renderer(config); + renderer.ToString(*this, original->materialized_result->names, original->materialized_result->Collection()); +#endif + } + + // Restore profiler setting + if (profiling_is_enabled) { + config.enable_profiler = true; + } + + // Now compare the results + // The results of all runs should be identical + for (auto &verifier : statement_verifiers) { + auto result = original->CompareResults(*verifier); + if (!result.empty()) { + return PreservedError(result); + } + } + + return PreservedError(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp new file mode 100644 index 00000000..5e934993 --- /dev/null +++ b/src/duckdb/src/main/config.cpp @@ -0,0 +1,400 @@ +#include "duckdb/main/config.hpp" + +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/settings.hpp" +#include "duckdb/storage/storage_extension.hpp" + +#ifndef DUCKDB_NO_THREADS +#include "duckdb/common/thread.hpp" +#endif + +#include +#include + +namespace duckdb { + +#ifdef DEBUG +bool DBConfigOptions::debug_print_bindings = false; +#endif + +#define DUCKDB_GLOBAL(_PARAM) \ + { \ + _PARAM::Name, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, nullptr, _PARAM::ResetGlobal, \ + nullptr, _PARAM::GetSetting \ + } +#define DUCKDB_GLOBAL_ALIAS(_ALIAS, _PARAM) \ + { \ + _ALIAS, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, nullptr, _PARAM::ResetGlobal, nullptr, \ + _PARAM::GetSetting \ + } + +#define DUCKDB_LOCAL(_PARAM) \ + { \ + _PARAM::Name, _PARAM::Description, _PARAM::InputType, nullptr, _PARAM::SetLocal, nullptr, _PARAM::ResetLocal, \ + _PARAM::GetSetting \ + } +#define DUCKDB_LOCAL_ALIAS(_ALIAS, _PARAM) \ + { \ + _ALIAS, _PARAM::Description, _PARAM::InputType, nullptr, _PARAM::SetLocal, nullptr, _PARAM::ResetLocal, \ + _PARAM::GetSetting \ + } + +#define DUCKDB_GLOBAL_LOCAL(_PARAM) \ + { \ + _PARAM::Name, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, _PARAM::SetLocal, \ + _PARAM::ResetGlobal, _PARAM::ResetLocal, _PARAM::GetSetting \ + } +#define DUCKDB_GLOBAL_LOCAL_ALIAS(_ALIAS, _PARAM) \ + { \ + _ALIAS, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, _PARAM::SetLocal, _PARAM::ResetGlobal, \ + _PARAM::ResetLocal, _PARAM::GetSetting \ + } +#define FINAL_SETTING \ + { nullptr, nullptr, LogicalTypeId::INVALID, nullptr, nullptr, nullptr, nullptr, nullptr } + +static ConfigurationOption internal_options[] = {DUCKDB_GLOBAL(AccessModeSetting), + DUCKDB_GLOBAL(CheckpointThresholdSetting), + DUCKDB_GLOBAL(DebugCheckpointAbort), + DUCKDB_LOCAL(DebugForceExternal), + DUCKDB_LOCAL(DebugForceNoCrossProduct), + DUCKDB_LOCAL(DebugAsOfIEJoin), + DUCKDB_LOCAL(PreferRangeJoins), + DUCKDB_GLOBAL(DebugWindowMode), + DUCKDB_GLOBAL_LOCAL(DefaultCollationSetting), + DUCKDB_GLOBAL(DefaultOrderSetting), + DUCKDB_GLOBAL(DefaultNullOrderSetting), + DUCKDB_GLOBAL(DisabledFileSystemsSetting), + DUCKDB_GLOBAL(DisabledOptimizersSetting), + DUCKDB_GLOBAL(EnableExternalAccessSetting), + DUCKDB_GLOBAL(EnableFSSTVectors), + DUCKDB_GLOBAL(AllowUnsignedExtensionsSetting), + DUCKDB_LOCAL(CustomExtensionRepository), + DUCKDB_LOCAL(AutoloadExtensionRepository), + DUCKDB_GLOBAL(AutoinstallKnownExtensions), + DUCKDB_GLOBAL(AutoloadKnownExtensions), + DUCKDB_GLOBAL(EnableObjectCacheSetting), + DUCKDB_GLOBAL(EnableHTTPMetadataCacheSetting), + DUCKDB_LOCAL(EnableProfilingSetting), + DUCKDB_LOCAL(EnableProgressBarSetting), + DUCKDB_LOCAL(EnableProgressBarPrintSetting), + DUCKDB_LOCAL(ExplainOutputSetting), + DUCKDB_GLOBAL(ExtensionDirectorySetting), + DUCKDB_GLOBAL(ExternalThreadsSetting), + DUCKDB_LOCAL(FileSearchPathSetting), + DUCKDB_GLOBAL(ForceCompressionSetting), + DUCKDB_GLOBAL(ForceBitpackingModeSetting), + DUCKDB_LOCAL(HomeDirectorySetting), + DUCKDB_LOCAL(LogQueryPathSetting), + DUCKDB_GLOBAL(LockConfigurationSetting), + DUCKDB_GLOBAL(ImmediateTransactionModeSetting), + DUCKDB_LOCAL(IntegerDivisionSetting), + DUCKDB_LOCAL(MaximumExpressionDepthSetting), + DUCKDB_GLOBAL(MaximumMemorySetting), + DUCKDB_GLOBAL_ALIAS("memory_limit", MaximumMemorySetting), + DUCKDB_GLOBAL_ALIAS("null_order", DefaultNullOrderSetting), + DUCKDB_LOCAL(OrderedAggregateThreshold), + DUCKDB_GLOBAL(PasswordSetting), + DUCKDB_LOCAL(PerfectHashThresholdSetting), + DUCKDB_LOCAL(PivotFilterThreshold), + DUCKDB_LOCAL(PivotLimitSetting), + DUCKDB_LOCAL(PreserveIdentifierCase), + DUCKDB_GLOBAL(PreserveInsertionOrder), + DUCKDB_LOCAL(ProfilerHistorySize), + DUCKDB_LOCAL(ProfileOutputSetting), + DUCKDB_LOCAL(ProfilingModeSetting), + DUCKDB_LOCAL_ALIAS("profiling_output", ProfileOutputSetting), + DUCKDB_LOCAL(ProgressBarTimeSetting), + DUCKDB_LOCAL(SchemaSetting), + DUCKDB_LOCAL(SearchPathSetting), + DUCKDB_GLOBAL(TempDirectorySetting), + DUCKDB_GLOBAL(ThreadsSetting), + DUCKDB_GLOBAL(UsernameSetting), + DUCKDB_GLOBAL(ExportLargeBufferArrow), + DUCKDB_GLOBAL_ALIAS("user", UsernameSetting), + DUCKDB_GLOBAL_ALIAS("wal_autocheckpoint", CheckpointThresholdSetting), + DUCKDB_GLOBAL_ALIAS("worker_threads", ThreadsSetting), + DUCKDB_GLOBAL(FlushAllocatorSetting), + FINAL_SETTING}; + +vector DBConfig::GetOptions() { + vector options; + for (idx_t index = 0; internal_options[index].name; index++) { + options.push_back(internal_options[index]); + } + return options; +} + +idx_t DBConfig::GetOptionCount() { + idx_t count = 0; + for (idx_t index = 0; internal_options[index].name; index++) { + count++; + } + return count; +} + +vector DBConfig::GetOptionNames() { + vector names; + for (idx_t i = 0, option_count = DBConfig::GetOptionCount(); i < option_count; i++) { + names.emplace_back(DBConfig::GetOptionByIndex(i)->name); + } + return names; +} + +ConfigurationOption *DBConfig::GetOptionByIndex(idx_t target_index) { + for (idx_t index = 0; internal_options[index].name; index++) { + if (index == target_index) { + return internal_options + index; + } + } + return nullptr; +} + +ConfigurationOption *DBConfig::GetOptionByName(const string &name) { + auto lname = StringUtil::Lower(name); + for (idx_t index = 0; internal_options[index].name; index++) { + D_ASSERT(StringUtil::Lower(internal_options[index].name) == string(internal_options[index].name)); + if (internal_options[index].name == lname) { + return internal_options + index; + } + } + return nullptr; +} + +void DBConfig::SetOption(const ConfigurationOption &option, const Value &value) { + SetOption(nullptr, option, value); +} + +void DBConfig::SetOptionByName(const string &name, const Value &value) { + auto option = DBConfig::GetOptionByName(name); + if (option) { + SetOption(*option, value); + } else { + options.unrecognized_options[name] = value; + } +} + +void DBConfig::SetOption(DatabaseInstance *db, const ConfigurationOption &option, const Value &value) { + lock_guard l(config_lock); + if (!option.set_global) { + throw InvalidInputException("Could not set option \"%s\" as a global option", option.name); + } + D_ASSERT(option.reset_global); + Value input = value.DefaultCastAs(option.parameter_type); + option.set_global(db, *this, input); +} + +void DBConfig::ResetOption(DatabaseInstance *db, const ConfigurationOption &option) { + lock_guard l(config_lock); + if (!option.reset_global) { + throw InternalException("Could not reset option \"%s\" as a global option", option.name); + } + D_ASSERT(option.set_global); + option.reset_global(db, *this); +} + +void DBConfig::SetOption(const string &name, Value value) { + lock_guard l(config_lock); + options.set_variables[name] = std::move(value); +} + +void DBConfig::ResetOption(const string &name) { + lock_guard l(config_lock); + auto extension_option = extension_parameters.find(name); + D_ASSERT(extension_option != extension_parameters.end()); + auto &default_value = extension_option->second.default_value; + if (!default_value.IsNull()) { + // Default is not NULL, override the setting + options.set_variables[name] = default_value; + } else { + // Otherwise just remove it from the 'set_variables' map + options.set_variables.erase(name); + } +} + +void DBConfig::AddExtensionOption(const string &name, string description, LogicalType parameter, + const Value &default_value, set_option_callback_t function) { + extension_parameters.insert( + make_pair(name, ExtensionOption(std::move(description), std::move(parameter), function, default_value))); + if (!default_value.IsNull()) { + // Default value is set, insert it into the 'set_variables' list + options.set_variables[name] = default_value; + } +} + +CastFunctionSet &DBConfig::GetCastFunctions() { + return *cast_functions; +} + +void DBConfig::SetDefaultMaxMemory() { + auto memory = FileSystem::GetAvailableMemory(); + if (memory != DConstants::INVALID_INDEX) { + options.maximum_memory = memory * 8 / 10; + } +} + +idx_t CGroupBandwidthQuota(idx_t physical_cores, FileSystem &fs) { + static constexpr const char *CPU_MAX = "/sys/fs/cgroup/cpu.max"; + static constexpr const char *CFS_QUOTA = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us"; + static constexpr const char *CFS_PERIOD = "/sys/fs/cgroup/cpu/cpu.cfs_period_us"; + + int64_t quota, period; + char byte_buffer[1000]; + unique_ptr handle; + int64_t read_bytes; + + if (fs.FileExists(CPU_MAX)) { + // cgroup v2 + // https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html + handle = + fs.OpenFile(CPU_MAX, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, FileSystem::DEFAULT_COMPRESSION); + read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); + byte_buffer[read_bytes] = '\0'; + if (std::sscanf(byte_buffer, "%" SCNd64 " %" SCNd64 "", "a, &period) != 2) { + return physical_cores; + } + } else if (fs.FileExists(CFS_QUOTA) && fs.FileExists(CFS_PERIOD)) { + // cgroup v1 + // https://www.kernel.org/doc/html/latest/scheduler/sched-bwc.html#management + + // Read the quota, this indicates how many microseconds the CPU can be utilized by this cgroup per period + handle = fs.OpenFile(CFS_QUOTA, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, + FileSystem::DEFAULT_COMPRESSION); + read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); + byte_buffer[read_bytes] = '\0'; + if (std::sscanf(byte_buffer, "%" SCNd64 "", "a) != 1) { + return physical_cores; + } + + // Read the time period, a cgroup can utilize the CPU up to quota microseconds every period + handle = fs.OpenFile(CFS_PERIOD, FileFlags::FILE_FLAGS_READ, FileSystem::DEFAULT_LOCK, + FileSystem::DEFAULT_COMPRESSION); + read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); + byte_buffer[read_bytes] = '\0'; + if (std::sscanf(byte_buffer, "%" SCNd64 "", &period) != 1) { + return physical_cores; + } + } else { + // No cgroup quota + return physical_cores; + } + if (quota > 0 && period > 0) { + return idx_t(std::ceil((double)quota / (double)period)); + } else { + return physical_cores; + } +} + +idx_t DBConfig::GetSystemMaxThreads(FileSystem &fs) { +#ifndef DUCKDB_NO_THREADS + idx_t physical_cores = std::thread::hardware_concurrency(); +#ifdef __linux__ + auto cores_available_per_period = CGroupBandwidthQuota(physical_cores, fs); + return MaxValue(cores_available_per_period, 1); +#else + return physical_cores; +#endif +#else + return 1; +#endif +} + +void DBConfig::SetDefaultMaxThreads() { +#ifndef DUCKDB_NO_THREADS + options.maximum_threads = GetSystemMaxThreads(*file_system); +#else + options.maximum_threads = 1; +#endif +} + +idx_t DBConfig::ParseMemoryLimit(const string &arg) { + if (arg[0] == '-' || arg == "null" || arg == "none") { + return DConstants::INVALID_INDEX; + } + // split based on the number/non-number + idx_t idx = 0; + while (StringUtil::CharacterIsSpace(arg[idx])) { + idx++; + } + idx_t num_start = idx; + while ((arg[idx] >= '0' && arg[idx] <= '9') || arg[idx] == '.' || arg[idx] == 'e' || arg[idx] == 'E' || + arg[idx] == '-') { + idx++; + } + if (idx == num_start) { + throw ParserException("Memory limit must have a number (e.g. SET memory_limit=1GB"); + } + string number = arg.substr(num_start, idx - num_start); + + // try to parse the number + double limit = Cast::Operation(string_t(number)); + + // now parse the memory limit unit (e.g. bytes, gb, etc) + while (StringUtil::CharacterIsSpace(arg[idx])) { + idx++; + } + idx_t start = idx; + while (idx < arg.size() && !StringUtil::CharacterIsSpace(arg[idx])) { + idx++; + } + if (limit < 0) { + // limit < 0, set limit to infinite + return (idx_t)-1; + } + string unit = StringUtil::Lower(arg.substr(start, idx - start)); + idx_t multiplier; + if (unit == "byte" || unit == "bytes" || unit == "b") { + multiplier = 1; + } else if (unit == "kilobyte" || unit == "kilobytes" || unit == "kb" || unit == "k") { + multiplier = 1000LL; + } else if (unit == "megabyte" || unit == "megabytes" || unit == "mb" || unit == "m") { + multiplier = 1000LL * 1000LL; + } else if (unit == "gigabyte" || unit == "gigabytes" || unit == "gb" || unit == "g") { + multiplier = 1000LL * 1000LL * 1000LL; + } else if (unit == "terabyte" || unit == "terabytes" || unit == "tb" || unit == "t") { + multiplier = 1000LL * 1000LL * 1000LL * 1000LL; + } else { + throw ParserException("Unknown unit for memory_limit: %s (expected: b, mb, gb or tb)", unit); + } + return (idx_t)multiplier * limit; +} + +// Right now we only really care about access mode when comparing DBConfigs +bool DBConfigOptions::operator==(const DBConfigOptions &other) const { + return other.access_mode == access_mode; +} + +bool DBConfig::operator==(const DBConfig &other) { + return other.options == options; +} + +bool DBConfig::operator!=(const DBConfig &other) { + return !(other.options == options); +} + +OrderType DBConfig::ResolveOrder(OrderType order_type) const { + if (order_type != OrderType::ORDER_DEFAULT) { + return order_type; + } + return options.default_order_type; +} + +OrderByNullType DBConfig::ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const { + if (null_type != OrderByNullType::ORDER_DEFAULT) { + return null_type; + } + switch (options.default_null_order) { + case DefaultOrderByNullType::NULLS_FIRST: + return OrderByNullType::NULLS_FIRST; + case DefaultOrderByNullType::NULLS_LAST: + return OrderByNullType::NULLS_LAST; + case DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC: + return order_type == OrderType::ASCENDING ? OrderByNullType::NULLS_FIRST : OrderByNullType::NULLS_LAST; + case DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC: + return order_type == OrderType::ASCENDING ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; + default: + throw InternalException("Unknown null order setting"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/connection.cpp b/src/duckdb/src/main/connection.cpp new file mode 100644 index 00000000..e225c643 --- /dev/null +++ b/src/duckdb/src/main/connection.cpp @@ -0,0 +1,294 @@ +#include "duckdb/main/connection.hpp" + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp" +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/main/appender.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/connection_manager.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/main/relation/query_relation.hpp" +#include "duckdb/main/relation/read_csv_relation.hpp" +#include "duckdb/main/relation/table_function_relation.hpp" +#include "duckdb/main/relation/table_relation.hpp" +#include "duckdb/main/relation/value_relation.hpp" +#include "duckdb/main/relation/view_relation.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +Connection::Connection(DatabaseInstance &database) : context(make_shared(database.shared_from_this())) { + ConnectionManager::Get(database).AddConnection(*context); +#ifdef DEBUG + EnableProfiling(); + context->config.emit_profiler_output = false; +#endif +} + +Connection::Connection(DuckDB &database) : Connection(*database.instance) { +} + +Connection::~Connection() { + ConnectionManager::Get(*context->db).RemoveConnection(*context); +} + +string Connection::GetProfilingInformation(ProfilerPrintFormat format) { + auto &profiler = QueryProfiler::Get(*context); + if (format == ProfilerPrintFormat::JSON) { + return profiler.ToJSON(); + } else { + return profiler.QueryTreeToString(); + } +} + +void Connection::Interrupt() { + context->Interrupt(); +} + +void Connection::EnableProfiling() { + context->EnableProfiling(); +} + +void Connection::DisableProfiling() { + context->DisableProfiling(); +} + +void Connection::EnableQueryVerification() { + ClientConfig::GetConfig(*context).query_verification_enabled = true; +} + +void Connection::DisableQueryVerification() { + ClientConfig::GetConfig(*context).query_verification_enabled = false; +} + +void Connection::ForceParallelism() { + ClientConfig::GetConfig(*context).verify_parallelism = true; +} + +unique_ptr Connection::SendQuery(const string &query) { + return context->Query(query, true); +} + +unique_ptr Connection::Query(const string &query) { + auto result = context->Query(query, false); + D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); + return unique_ptr_cast(std::move(result)); +} + +DUCKDB_API string Connection::GetSubstrait(const string &query) { + vector params; + params.emplace_back(query); + auto result = TableFunction("get_substrait", params)->Execute(); + auto protobuf = result->FetchRaw()->GetValue(0, 0); + return protobuf.GetValueUnsafe().GetString(); +} + +DUCKDB_API unique_ptr Connection::FromSubstrait(const string &proto) { + vector params; + params.emplace_back(Value::BLOB_RAW(proto)); + return TableFunction("from_substrait", params)->Execute(); +} + +DUCKDB_API string Connection::GetSubstraitJSON(const string &query) { + vector params; + params.emplace_back(query); + auto result = TableFunction("get_substrait_json", params)->Execute(); + auto protobuf = result->FetchRaw()->GetValue(0, 0); + return protobuf.GetValueUnsafe().GetString(); +} + +DUCKDB_API unique_ptr Connection::FromSubstraitJSON(const string &json) { + vector params; + params.emplace_back(json); + return TableFunction("from_substrait_json", params)->Execute(); +} + +unique_ptr Connection::Query(unique_ptr statement) { + auto result = context->Query(std::move(statement), false); + D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); + return unique_ptr_cast(std::move(result)); +} + +unique_ptr Connection::PendingQuery(const string &query, bool allow_stream_result) { + return context->PendingQuery(query, allow_stream_result); +} + +unique_ptr Connection::PendingQuery(unique_ptr statement, bool allow_stream_result) { + return context->PendingQuery(std::move(statement), allow_stream_result); +} + +unique_ptr Connection::Prepare(const string &query) { + return context->Prepare(query); +} + +unique_ptr Connection::Prepare(unique_ptr statement) { + return context->Prepare(std::move(statement)); +} + +unique_ptr Connection::QueryParamsRecursive(const string &query, vector &values) { + auto statement = Prepare(query); + if (statement->HasError()) { + return make_uniq(statement->error); + } + return statement->Execute(values, false); +} + +unique_ptr Connection::TableInfo(const string &table_name) { + return TableInfo(INVALID_SCHEMA, table_name); +} + +unique_ptr Connection::TableInfo(const string &schema_name, const string &table_name) { + return context->TableInfo(schema_name, table_name); +} + +vector> Connection::ExtractStatements(const string &query) { + return context->ParseStatements(query); +} + +unique_ptr Connection::ExtractPlan(const string &query) { + return context->ExtractPlan(query); +} + +void Connection::Append(TableDescription &description, DataChunk &chunk) { + if (chunk.size() == 0) { + return; + } + ColumnDataCollection collection(Allocator::Get(*context), chunk.GetTypes()); + collection.Append(chunk); + Append(description, collection); +} + +void Connection::Append(TableDescription &description, ColumnDataCollection &collection) { + context->Append(description, collection); +} + +shared_ptr Connection::Table(const string &table_name) { + return Table(DEFAULT_SCHEMA, table_name); +} + +shared_ptr Connection::Table(const string &schema_name, const string &table_name) { + auto table_info = TableInfo(schema_name, table_name); + if (!table_info) { + throw CatalogException("Table '%s' does not exist!", table_name); + } + return make_shared(context, std::move(table_info)); +} + +shared_ptr Connection::View(const string &tname) { + return View(DEFAULT_SCHEMA, tname); +} + +shared_ptr Connection::View(const string &schema_name, const string &table_name) { + return make_shared(context, schema_name, table_name); +} + +shared_ptr Connection::TableFunction(const string &fname) { + vector values; + named_parameter_map_t named_parameters; + return TableFunction(fname, values, named_parameters); +} + +shared_ptr Connection::TableFunction(const string &fname, const vector &values, + const named_parameter_map_t &named_parameters) { + return make_shared(context, fname, values, named_parameters); +} + +shared_ptr Connection::TableFunction(const string &fname, const vector &values) { + return make_shared(context, fname, values); +} + +shared_ptr Connection::Values(const vector> &values) { + vector column_names; + return Values(values, column_names); +} + +shared_ptr Connection::Values(const vector> &values, const vector &column_names, + const string &alias) { + return make_shared(context, values, column_names, alias); +} + +shared_ptr Connection::Values(const string &values) { + vector column_names; + return Values(values, column_names); +} + +shared_ptr Connection::Values(const string &values, const vector &column_names, const string &alias) { + return make_shared(context, values, column_names, alias); +} + +shared_ptr Connection::ReadCSV(const string &csv_file) { + named_parameter_map_t options; + return ReadCSV(csv_file, std::move(options)); +} + +shared_ptr Connection::ReadCSV(const string &csv_file, named_parameter_map_t &&options) { + return make_shared(context, csv_file, std::move(options)); +} + +shared_ptr Connection::ReadCSV(const string &csv_file, const vector &columns) { + // parse columns + vector column_list; + for (auto &column : columns) { + auto col_list = Parser::ParseColumnList(column, context->GetParserOptions()); + if (col_list.LogicalColumnCount() != 1) { + throw ParserException("Expected a single column definition"); + } + column_list.push_back(std::move(col_list.GetColumnMutable(LogicalIndex(0)))); + } + return make_shared(context, csv_file, std::move(column_list)); +} + +shared_ptr Connection::ReadParquet(const string &parquet_file, bool binary_as_string) { + vector params; + params.emplace_back(parquet_file); + named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(binary_as_string)}}); + return TableFunction("parquet_scan", params, named_parameters)->Alias(parquet_file); +} + +unordered_set Connection::GetTableNames(const string &query) { + return context->GetTableNames(query); +} + +shared_ptr Connection::RelationFromQuery(const string &query, const string &alias, const string &error) { + return RelationFromQuery(QueryRelation::ParseStatement(*context, query, error), alias); +} + +shared_ptr Connection::RelationFromQuery(unique_ptr select_stmt, const string &alias) { + return make_shared(context, std::move(select_stmt), alias); +} + +void Connection::BeginTransaction() { + auto result = Query("BEGIN TRANSACTION"); + if (result->HasError()) { + result->ThrowError(); + } +} + +void Connection::Commit() { + auto result = Query("COMMIT"); + if (result->HasError()) { + result->ThrowError(); + } +} + +void Connection::Rollback() { + auto result = Query("ROLLBACK"); + if (result->HasError()) { + result->ThrowError(); + } +} + +void Connection::SetAutoCommit(bool auto_commit) { + context->transaction.SetAutoCommit(auto_commit); +} + +bool Connection::IsAutoCommit() { + return context->transaction.IsAutoCommit(); +} +bool Connection::HasActiveTransaction() { + return context->transaction.HasActiveTransaction(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/database.cpp b/src/duckdb/src/main/database.cpp new file mode 100644 index 00000000..31103869 --- /dev/null +++ b/src/duckdb/src/main/database.cpp @@ -0,0 +1,414 @@ +#include "duckdb/main/database.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/virtual_file_system.hpp" +#include "duckdb/execution/operator/helper/physical_set.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/connection_manager.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/error_manager.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/parser/parsed_data/attach_info.hpp" +#include "duckdb/storage/object_cache.hpp" +#include "duckdb/storage/standard_buffer_manager.hpp" +#include "duckdb/main/database_path_and_type.hpp" +#include "duckdb/storage/storage_extension.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/transaction/transaction_manager.hpp" +#include "duckdb/planner/extension_callback.hpp" + +#ifndef DUCKDB_NO_THREADS +#include "duckdb/common/thread.hpp" +#endif + +namespace duckdb { + +DBConfig::DBConfig() { + compression_functions = make_uniq(); + cast_functions = make_uniq(); + error_manager = make_uniq(); +} + +DBConfig::DBConfig(std::unordered_map &config_dict, bool read_only) : DBConfig::DBConfig() { + if (read_only) { + options.access_mode = AccessMode::READ_ONLY; + } + for (auto &kv : config_dict) { + string key = kv.first; + string val = kv.second; + auto opt_val = Value(val); + DBConfig::SetOptionByName(key, opt_val); + } +} + +DBConfig::~DBConfig() { +} + +DatabaseInstance::DatabaseInstance() { +} + +DatabaseInstance::~DatabaseInstance() { +} + +BufferManager &BufferManager::GetBufferManager(DatabaseInstance &db) { + return db.GetBufferManager(); +} + +BufferManager &BufferManager::GetBufferManager(AttachedDatabase &db) { + return BufferManager::GetBufferManager(db.GetDatabase()); +} + +DatabaseInstance &DatabaseInstance::GetDatabase(ClientContext &context) { + return *context.db; +} + +DatabaseManager &DatabaseInstance::GetDatabaseManager() { + if (!db_manager) { + throw InternalException("Missing DB manager"); + } + return *db_manager; +} + +Catalog &Catalog::GetSystemCatalog(DatabaseInstance &db) { + return db.GetDatabaseManager().GetSystemCatalog(); +} + +Catalog &Catalog::GetCatalog(AttachedDatabase &db) { + return db.GetCatalog(); +} + +FileSystem &FileSystem::GetFileSystem(DatabaseInstance &db) { + return db.GetFileSystem(); +} + +FileSystem &FileSystem::Get(AttachedDatabase &db) { + return FileSystem::GetFileSystem(db.GetDatabase()); +} + +DBConfig &DBConfig::GetConfig(DatabaseInstance &db) { + return db.config; +} + +ClientConfig &ClientConfig::GetConfig(ClientContext &context) { + return context.config; +} + +DBConfig &DBConfig::Get(AttachedDatabase &db) { + return DBConfig::GetConfig(db.GetDatabase()); +} + +const DBConfig &DBConfig::GetConfig(const DatabaseInstance &db) { + return db.config; +} + +const ClientConfig &ClientConfig::GetConfig(const ClientContext &context) { + return context.config; +} + +TransactionManager &TransactionManager::Get(AttachedDatabase &db) { + return db.GetTransactionManager(); +} + +ConnectionManager &ConnectionManager::Get(DatabaseInstance &db) { + return db.GetConnectionManager(); +} + +ClientContext *ConnectionManager::GetConnection(DatabaseInstance *db) { + for (auto &conn : connections) { + if (conn.first->db.get() == db) { + return conn.first; + } + } + return nullptr; +} + +ConnectionManager &ConnectionManager::Get(ClientContext &context) { + return ConnectionManager::Get(DatabaseInstance::GetDatabase(context)); +} + +duckdb::unique_ptr DatabaseInstance::CreateAttachedDatabase(AttachInfo &info, const string &type, + AccessMode access_mode) { + duckdb::unique_ptr attached_database; + if (!type.empty()) { + // find the storage extension + auto extension_name = ExtensionHelper::ApplyExtensionAlias(type); + auto entry = config.storage_extensions.find(extension_name); + if (entry == config.storage_extensions.end()) { + throw BinderException("Unrecognized storage type \"%s\"", type); + } + + if (entry->second->attach != nullptr && entry->second->create_transaction_manager != nullptr) { + // use storage extension to create the initial database + attached_database = make_uniq(*this, Catalog::GetSystemCatalog(*this), *entry->second, + info.name, info, access_mode); + } else { + attached_database = + make_uniq(*this, Catalog::GetSystemCatalog(*this), info.name, info.path, access_mode); + } + } else { + // check if this is an in-memory database or not + attached_database = + make_uniq(*this, Catalog::GetSystemCatalog(*this), info.name, info.path, access_mode); + } + return attached_database; +} + +void DatabaseInstance::CreateMainDatabase() { + AttachInfo info; + info.name = AttachedDatabase::ExtractDatabaseName(config.options.database_path, GetFileSystem()); + info.path = config.options.database_path; + + auto attached_database = CreateAttachedDatabase(info, config.options.database_type, config.options.access_mode); + auto initial_database = attached_database.get(); + { + Connection con(*this); + con.BeginTransaction(); + db_manager->AddDatabase(*con.context, std::move(attached_database)); + con.Commit(); + } + + // initialize the database + initial_database->SetInitialDatabase(); + initial_database->Initialize(); +} + +void ThrowExtensionSetUnrecognizedOptions(const unordered_map &unrecognized_options) { + auto unrecognized_options_iter = unrecognized_options.begin(); + string unrecognized_option_keys = unrecognized_options_iter->first; + while (++unrecognized_options_iter != unrecognized_options.end()) { + unrecognized_option_keys = "," + unrecognized_options_iter->first; + } + throw InvalidInputException("Unrecognized configuration property \"%s\"", unrecognized_option_keys); +} + +void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_config) { + DBConfig default_config; + DBConfig *config_ptr = &default_config; + if (user_config) { + config_ptr = user_config; + } + + if (config_ptr->options.temporary_directory.empty() && database_path) { + // no directory specified: use default temp path + config_ptr->options.temporary_directory = string(database_path) + ".tmp"; + + // special treatment for in-memory mode + if (strcmp(database_path, ":memory:") == 0) { + config_ptr->options.temporary_directory = ".tmp"; + } + } + + if (database_path) { + config_ptr->options.database_path = database_path; + } else { + config_ptr->options.database_path.clear(); + } + Configure(*config_ptr); + + if (user_config && !user_config->options.use_temporary_directory) { + // temporary directories explicitly disabled + config.options.temporary_directory = string(); + } + + db_manager = make_uniq(*this); + buffer_manager = make_uniq(*this, config.options.temporary_directory); + scheduler = make_uniq(*this); + object_cache = make_uniq(); + connection_manager = make_uniq(); + + // check if we are opening a standard DuckDB database or an extension database + if (config.options.database_type.empty()) { + auto path_and_type = DBPathAndType::Parse(config.options.database_path, config); + config.options.database_type = path_and_type.type; + config.options.database_path = path_and_type.path; + } + + // initialize the system catalog + db_manager->InitializeSystemCatalog(); + + if (!config.options.database_type.empty()) { + // if we are opening an extension database - load the extension + if (!config.file_system) { + throw InternalException("No file system!?"); + } + ExtensionHelper::LoadExternalExtension(*this, *config.file_system, config.options.database_type, nullptr); + } + + if (!config.options.unrecognized_options.empty()) { + ThrowExtensionSetUnrecognizedOptions(config.options.unrecognized_options); + } + + if (!db_manager->HasDefaultDatabase()) { + CreateMainDatabase(); + } + + // only increase thread count after storage init because we get races on catalog otherwise + scheduler->SetThreads(config.options.maximum_threads); +} + +DuckDB::DuckDB(const char *path, DBConfig *new_config) : instance(make_shared()) { + instance->Initialize(path, new_config); + if (instance->config.options.load_extensions) { + ExtensionHelper::LoadAllExtensions(*this); + } +} + +DuckDB::DuckDB(const string &path, DBConfig *config) : DuckDB(path.c_str(), config) { +} + +DuckDB::DuckDB(DatabaseInstance &instance_p) : instance(instance_p.shared_from_this()) { +} + +DuckDB::~DuckDB() { +} + +BufferManager &DatabaseInstance::GetBufferManager() { + return *buffer_manager; +} + +BufferPool &DatabaseInstance::GetBufferPool() { + return *config.buffer_pool; +} + +DatabaseManager &DatabaseManager::Get(DatabaseInstance &db) { + return db.GetDatabaseManager(); +} + +DatabaseManager &DatabaseManager::Get(ClientContext &db) { + return DatabaseManager::Get(*db.db); +} + +TaskScheduler &DatabaseInstance::GetScheduler() { + return *scheduler; +} + +ObjectCache &DatabaseInstance::GetObjectCache() { + return *object_cache; +} + +FileSystem &DatabaseInstance::GetFileSystem() { + return *config.file_system; +} + +ConnectionManager &DatabaseInstance::GetConnectionManager() { + return *connection_manager; +} + +FileSystem &DuckDB::GetFileSystem() { + return instance->GetFileSystem(); +} + +Allocator &Allocator::Get(ClientContext &context) { + return Allocator::Get(*context.db); +} + +Allocator &Allocator::Get(DatabaseInstance &db) { + return *db.config.allocator; +} + +Allocator &Allocator::Get(AttachedDatabase &db) { + return Allocator::Get(db.GetDatabase()); +} + +void DatabaseInstance::Configure(DBConfig &new_config) { + config.options = new_config.options; + if (config.options.access_mode == AccessMode::UNDEFINED) { + config.options.access_mode = AccessMode::READ_WRITE; + } + if (new_config.file_system) { + config.file_system = std::move(new_config.file_system); + } else { + config.file_system = make_uniq(); + } + if (config.options.maximum_memory == (idx_t)-1) { + config.SetDefaultMaxMemory(); + } + if (new_config.options.maximum_threads == (idx_t)-1) { + config.SetDefaultMaxThreads(); + } + config.allocator = std::move(new_config.allocator); + if (!config.allocator) { + config.allocator = make_uniq(); + } + config.replacement_scans = std::move(new_config.replacement_scans); + config.parser_extensions = std::move(new_config.parser_extensions); + config.error_manager = std::move(new_config.error_manager); + if (!config.error_manager) { + config.error_manager = make_uniq(); + } + if (!config.default_allocator) { + config.default_allocator = Allocator::DefaultAllocatorReference(); + } + if (new_config.buffer_pool) { + config.buffer_pool = std::move(new_config.buffer_pool); + } else { + config.buffer_pool = make_shared(config.options.maximum_memory); + } +} + +DBConfig &DBConfig::GetConfig(ClientContext &context) { + return context.db->config; +} + +const DBConfig &DBConfig::GetConfig(const ClientContext &context) { + return context.db->config; +} + +idx_t DatabaseInstance::NumberOfThreads() { + return scheduler->NumberOfThreads(); +} + +const unordered_set &DatabaseInstance::LoadedExtensions() { + return loaded_extensions; +} + +idx_t DuckDB::NumberOfThreads() { + return instance->NumberOfThreads(); +} + +bool DatabaseInstance::ExtensionIsLoaded(const std::string &name) { + auto extension_name = ExtensionHelper::GetExtensionName(name); + return loaded_extensions.find(extension_name) != loaded_extensions.end(); +} + +bool DuckDB::ExtensionIsLoaded(const std::string &name) { + return instance->ExtensionIsLoaded(name); +} + +void DatabaseInstance::SetExtensionLoaded(const std::string &name) { + auto extension_name = ExtensionHelper::GetExtensionName(name); + loaded_extensions.insert(extension_name); + + auto &callbacks = DBConfig::GetConfig(*this).extension_callbacks; + for (auto &callback : callbacks) { + callback->OnExtensionLoaded(*this, name); + } +} + +bool DatabaseInstance::TryGetCurrentSetting(const std::string &key, Value &result) { + // check the session values + auto &db_config = DBConfig::GetConfig(*this); + const auto &global_config_map = db_config.options.set_variables; + + auto global_value = global_config_map.find(key); + bool found_global_value = global_value != global_config_map.end(); + if (!found_global_value) { + return false; + } + result = global_value->second; + return true; +} + +ValidChecker &DatabaseInstance::GetValidChecker() { + return db_validity; +} + +ValidChecker &ValidChecker::Get(DatabaseInstance &db) { + return db.GetValidChecker(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/database_manager.cpp b/src/duckdb/src/main/database_manager.cpp new file mode 100644 index 00000000..486e6b11 --- /dev/null +++ b/src/duckdb/src/main/database_manager.cpp @@ -0,0 +1,119 @@ +#include "duckdb/main/database_manager.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" + +namespace duckdb { + +DatabaseManager::DatabaseManager(DatabaseInstance &db) : catalog_version(0), current_query_number(1) { + system = make_uniq(db); + databases = make_uniq(system->GetCatalog()); +} + +DatabaseManager::~DatabaseManager() { +} + +DatabaseManager &DatabaseManager::Get(AttachedDatabase &db) { + return DatabaseManager::Get(db.GetDatabase()); +} + +void DatabaseManager::InitializeSystemCatalog() { + system->Initialize(); +} + +optional_ptr DatabaseManager::GetDatabase(ClientContext &context, const string &name) { + if (StringUtil::Lower(name) == TEMP_CATALOG) { + return context.client_data->temporary_objects.get(); + } + return reinterpret_cast(databases->GetEntry(context, name).get()); +} + +void DatabaseManager::AddDatabase(ClientContext &context, unique_ptr db_instance) { + auto name = db_instance->GetName(); + db_instance->oid = ModifyCatalog(); + DependencyList dependencies; + if (default_database.empty()) { + default_database = name; + } + if (!databases->CreateEntry(context, name, std::move(db_instance), dependencies)) { + throw BinderException("Failed to attach database: database with name \"%s\" already exists", name); + } +} + +void DatabaseManager::DetachDatabase(ClientContext &context, const string &name, OnEntryNotFound if_not_found) { + if (GetDefaultDatabase(context) == name) { + throw BinderException("Cannot detach database \"%s\" because it is the default database. Select a different " + "database using `USE` to allow detaching this database", + name); + } + if (!databases->DropEntry(context, name, false, true)) { + if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { + throw BinderException("Failed to detach database with name \"%s\": database not found", name); + } + } +} + +optional_ptr DatabaseManager::GetDatabaseFromPath(ClientContext &context, const string &path) { + auto databases = GetDatabases(context); + for (auto &db_ref : databases) { + auto &db = db_ref.get(); + if (db.IsSystem()) { + continue; + } + auto &catalog = Catalog::GetCatalog(db); + if (catalog.InMemory()) { + continue; + } + auto db_path = catalog.GetDBPath(); + if (StringUtil::CIEquals(path, db_path)) { + return &db; + } + } + return nullptr; +} + +const string &DatabaseManager::GetDefaultDatabase(ClientContext &context) { + auto &config = ClientData::Get(context); + auto &default_entry = config.catalog_search_path->GetDefault(); + if (IsInvalidCatalog(default_entry.catalog)) { + auto &result = DatabaseManager::Get(context).default_database; + if (result.empty()) { + throw InternalException("Calling DatabaseManager::GetDefaultDatabase with no default database set"); + } + return result; + } + return default_entry.catalog; +} + +// LCOV_EXCL_START +void DatabaseManager::SetDefaultDatabase(ClientContext &context, const string &new_value) { + auto db_entry = GetDatabase(context, new_value); + + if (!db_entry) { + throw InternalException("Database \"%s\" not found", new_value); + } else if (db_entry->IsTemporary()) { + throw InternalException("Cannot set the default database to a temporary database"); + } else if (db_entry->IsSystem()) { + throw InternalException("Cannot set the default database to a system database"); + } + + default_database = new_value; +} +// LCOV_EXCL_STOP + +vector> DatabaseManager::GetDatabases(ClientContext &context) { + vector> result; + databases->Scan(context, [&](CatalogEntry &entry) { result.push_back(entry.Cast()); }); + result.push_back(*system); + result.push_back(*context.client_data->temporary_objects); + return result; +} + +Catalog &DatabaseManager::GetSystemCatalog() { + D_ASSERT(system); + return system->GetCatalog(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/database_path_and_type.cpp b/src/duckdb/src/main/database_path_and_type.cpp new file mode 100644 index 00000000..64f1f9ce --- /dev/null +++ b/src/duckdb/src/main/database_path_and_type.cpp @@ -0,0 +1,23 @@ +#include "duckdb/main/database_path_and_type.hpp" + +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/storage/magic_bytes.hpp" + +namespace duckdb { + +DBPathAndType DBPathAndType::Parse(const string &combined_path, const DBConfig &config) { + auto extension = ExtensionHelper::ExtractExtensionPrefixFromPath(combined_path); + if (!extension.empty()) { + // path is prefixed with an extension - remove it + auto path = StringUtil::Replace(combined_path, extension + ":", ""); + auto type = ExtensionHelper::ApplyExtensionAlias(extension); + return {path, type}; + } + // if there isn't - check the magic bytes of the file (if any) + auto file_type = MagicBytes::CheckMagicBytes(config.file_system.get(), combined_path); + if (file_type == DataFileType::SQLITE_FILE) { + return {combined_path, "sqlite"}; + } + return {combined_path, string()}; +} +} // namespace duckdb diff --git a/src/duckdb/src/main/db_instance_cache.cpp b/src/duckdb/src/main/db_instance_cache.cpp new file mode 100644 index 00000000..41012f55 --- /dev/null +++ b/src/duckdb/src/main/db_instance_cache.cpp @@ -0,0 +1,93 @@ +#include "duckdb/main/db_instance_cache.hpp" +#include "duckdb/main/extension_helper.hpp" + +namespace duckdb { + +string GetDBAbsolutePath(const string &database_p, FileSystem &fs) { + auto database = FileSystem::ExpandPath(database_p, nullptr); + if (database.empty()) { + return ":memory:"; + } + if (database.rfind(":memory:", 0) == 0) { + // this is a memory db, just return it. + return database; + } + if (!ExtensionHelper::ExtractExtensionPrefixFromPath(database).empty()) { + // this database path is handled by a replacement open and is not a file path + return database; + } + if (fs.IsPathAbsolute(database)) { + return fs.NormalizeAbsolutePath(database); + } + return fs.NormalizeAbsolutePath(fs.JoinPath(FileSystem::GetWorkingDirectory(), database)); +} + +shared_ptr DBInstanceCache::GetInstanceInternal(const string &database, const DBConfig &config) { + shared_ptr db_instance; + + auto local_fs = FileSystem::CreateLocal(); + auto abs_database_path = GetDBAbsolutePath(database, *local_fs); + if (db_instances.find(abs_database_path) != db_instances.end()) { + db_instance = db_instances[abs_database_path].lock(); + if (db_instance) { + if (db_instance->instance->config != config) { + throw duckdb::ConnectionException( + "Can't open a connection to same database file with a different configuration " + "than existing connections"); + } + } else { + // clean-up + db_instances.erase(abs_database_path); + } + } + return db_instance; +} + +shared_ptr DBInstanceCache::GetInstance(const string &database, const DBConfig &config) { + lock_guard l(cache_lock); + return GetInstanceInternal(database, config); +} + +shared_ptr DBInstanceCache::CreateInstanceInternal(const string &database, DBConfig &config, + bool cache_instance) { + string abs_database_path; + if (config.file_system) { + abs_database_path = GetDBAbsolutePath(database, *config.file_system); + } else { + auto tmp_fs = FileSystem::CreateLocal(); + abs_database_path = GetDBAbsolutePath(database, *tmp_fs); + } + if (db_instances.find(abs_database_path) != db_instances.end()) { + throw duckdb::Exception(ExceptionType::CONNECTION, + "Instance with path: " + abs_database_path + " already exists."); + } + // Creates new instance + string instance_path = abs_database_path; + if (abs_database_path.rfind(":memory:", 0) == 0) { + instance_path = ":memory:"; + } + auto db_instance = make_shared(instance_path, &config); + if (cache_instance) { + db_instances[abs_database_path] = db_instance; + } + return db_instance; +} + +shared_ptr DBInstanceCache::CreateInstance(const string &database, DBConfig &config, bool cache_instance) { + lock_guard l(cache_lock); + return CreateInstanceInternal(database, config, cache_instance); +} + +shared_ptr DBInstanceCache::GetOrCreateInstance(const string &database, DBConfig &config_dict, + bool cache_instance) { + lock_guard l(cache_lock); + if (cache_instance) { + auto instance = GetInstanceInternal(database, config_dict); + if (instance) { + return instance; + } + } + return CreateInstanceInternal(database, config_dict, cache_instance); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/error_manager.cpp b/src/duckdb/src/main/error_manager.cpp new file mode 100644 index 00000000..14038198 --- /dev/null +++ b/src/duckdb/src/main/error_manager.cpp @@ -0,0 +1,70 @@ +#include "duckdb/main/error_manager.hpp" +#include "duckdb/main/config.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +struct DefaultError { + ErrorType type; + const char *error; +}; + +static DefaultError internal_errors[] = { + {ErrorType::UNSIGNED_EXTENSION, + "Extension \"%s\" could not be loaded because its signature is either missing or invalid and unsigned extensions " + "are disabled by configuration (allow_unsigned_extensions)"}, + {ErrorType::INVALIDATED_TRANSACTION, "Current transaction is aborted (please ROLLBACK)"}, + {ErrorType::INVALIDATED_DATABASE, "Failed: database has been invalidated because of a previous fatal error. The " + "database must be restarted prior to being used again.\nOriginal error: \"%s\""}, + {ErrorType::INVALID, nullptr}}; + +string ErrorManager::FormatExceptionRecursive(ErrorType error_type, vector &values) { + if (error_type >= ErrorType::ERROR_COUNT) { + throw InternalException("Invalid error type passed to ErrorManager::FormatError"); + } + auto entry = custom_errors.find(error_type); + string error; + if (entry == custom_errors.end()) { + // error was not overwritten + error = internal_errors[int(error_type)].error; + } else { + // error was overwritten + error = entry->second; + } + return ExceptionFormatValue::Format(error, values); +} + +string ErrorManager::InvalidUnicodeError(const string &input, const string &context) { + UnicodeInvalidReason reason; + size_t pos; + auto unicode = Utf8Proc::Analyze(const_char_ptr_cast(input.c_str()), input.size(), &reason, &pos); + if (unicode != UnicodeType::INVALID) { + return "Invalid unicode error thrown but no invalid unicode detected in " + context; + } + string base_message; + switch (reason) { + case UnicodeInvalidReason::BYTE_MISMATCH: + base_message = "Invalid unicode (byte sequence mismatch)"; + break; + case UnicodeInvalidReason::INVALID_UNICODE: + base_message = "Invalid unicode"; + break; + default: + break; + } + return base_message + " detected in " + context; +} + +void ErrorManager::AddCustomError(ErrorType type, string new_error) { + custom_errors.insert(make_pair(type, std::move(new_error))); +} + +ErrorManager &ErrorManager::Get(ClientContext &context) { + return *DBConfig::GetConfig(context).error_manager; +} + +ErrorManager &ErrorManager::Get(DatabaseInstance &context) { + return *DBConfig::GetConfig(context).error_manager; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/extension.cpp b/src/duckdb/src/main/extension.cpp new file mode 100644 index 00000000..4bfc5c87 --- /dev/null +++ b/src/duckdb/src/main/extension.cpp @@ -0,0 +1,8 @@ +#include "duckdb/main/extension.hpp" + +namespace duckdb { + +Extension::~Extension() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/extension/extension_alias.cpp b/src/duckdb/src/main/extension/extension_alias.cpp new file mode 100644 index 00000000..cce16429 --- /dev/null +++ b/src/duckdb/src/main/extension/extension_alias.cpp @@ -0,0 +1,36 @@ +#include "duckdb/main/extension_helper.hpp" + +namespace duckdb { + +static ExtensionAlias internal_aliases[] = {{"http", "httpfs"}, // httpfs + {"https", "httpfs"}, + {"md", "motherduck"}, // motherduck + {"s3", "httpfs"}, + {"postgres", "postgres_scanner"}, // postgres + {"sqlite", "sqlite_scanner"}, // sqlite + {"sqlite3", "sqlite_scanner"}, + {nullptr, nullptr}}; + +idx_t ExtensionHelper::ExtensionAliasCount() { + idx_t index; + for (index = 0; internal_aliases[index].alias != nullptr; index++) { + } + return index; +} + +ExtensionAlias ExtensionHelper::GetExtensionAlias(idx_t index) { + D_ASSERT(index < ExtensionAliasCount()); + return internal_aliases[index]; +} + +string ExtensionHelper::ApplyExtensionAlias(string extension_name) { + auto lname = StringUtil::Lower(extension_name); + for (idx_t index = 0; internal_aliases[index].alias; index++) { + if (lname == internal_aliases[index].alias) { + return internal_aliases[index].extension; + } + } + return extension_name; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/extension/extension_helper.cpp b/src/duckdb/src/main/extension/extension_helper.cpp new file mode 100644 index 00000000..f47e853a --- /dev/null +++ b/src/duckdb/src/main/extension/extension_helper.cpp @@ -0,0 +1,617 @@ +#include "duckdb/main/extension_helper.hpp" + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/windows.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" + +// Note that c++ preprocessor doesn't have a nice way to clean this up so we need to set the defines we use to false +// explicitly when they are undefined +#ifndef DUCKDB_EXTENSION_ICU_LINKED +#define DUCKDB_EXTENSION_ICU_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_EXCEL_LINKED +#define DUCKDB_EXTENSION_EXCEL_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_PARQUET_LINKED +#define DUCKDB_EXTENSION_PARQUET_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_TPCH_LINKED +#define DUCKDB_EXTENSION_TPCH_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_TPCDS_LINKED +#define DUCKDB_EXTENSION_TPCDS_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_FTS_LINKED +#define DUCKDB_EXTENSION_FTS_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_HTTPFS_LINKED +#define DUCKDB_EXTENSION_HTTPFS_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_JSON_LINKED +#define DUCKDB_EXTENSION_JSON_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_JEMALLOC_LINKED +#define DUCKDB_EXTENSION_JEMALLOC_LINKED false +#endif + +#ifndef DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED +#define DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED false +#endif + +// Load the generated header file containing our list of extension headers +#if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS && !defined(DUCKDB_AMALGAMATION) +#include "duckdb/main/extension/generated_extension_loader.hpp" +#else +// TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds, after that +// these can be removed +#if DUCKDB_EXTENSION_ICU_LINKED +#include "icu_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_EXCEL_LINKED +#include "excel_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_PARQUET_LINKED +#include "parquet_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_TPCH_LINKED +#include "tpch_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_TPCDS_LINKED +#include "tpcds_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_FTS_LINKED +#include "fts_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_HTTPFS_LINKED +#include "httpfs_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_JSON_LINKED +#include "json_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_JEMALLOC_LINKED +#include "jemalloc_extension.hpp" +#endif + +#if DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED +#include "autocomplete_extension.hpp" +#endif +#endif + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Default Extensions +//===--------------------------------------------------------------------===// +static DefaultExtension internal_extensions[] = { + {"icu", "Adds support for time zones and collations using the ICU library", DUCKDB_EXTENSION_ICU_LINKED}, + {"excel", "Adds support for Excel-like format strings", DUCKDB_EXTENSION_EXCEL_LINKED}, + {"parquet", "Adds support for reading and writing parquet files", DUCKDB_EXTENSION_PARQUET_LINKED}, + {"tpch", "Adds TPC-H data generation and query support", DUCKDB_EXTENSION_TPCH_LINKED}, + {"tpcds", "Adds TPC-DS data generation and query support", DUCKDB_EXTENSION_TPCDS_LINKED}, + {"fts", "Adds support for Full-Text Search Indexes", DUCKDB_EXTENSION_FTS_LINKED}, + {"httpfs", "Adds support for reading and writing files over a HTTP(S) connection", DUCKDB_EXTENSION_HTTPFS_LINKED}, + {"json", "Adds support for JSON operations", DUCKDB_EXTENSION_JSON_LINKED}, + {"jemalloc", "Overwrites system allocator with JEMalloc", DUCKDB_EXTENSION_JEMALLOC_LINKED}, + {"autocomplete", "Adds support for autocomplete in the shell", DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED}, + {"motherduck", "Enables motherduck integration with the system", false}, + {"sqlite_scanner", "Adds support for reading SQLite database files", false}, + {"postgres_scanner", "Adds support for reading from a Postgres database", false}, + {"inet", "Adds support for IP-related data types and functions", false}, + {"spatial", "Geospatial extension that adds support for working with spatial data and functions", false}, + {"substrait", "Adds support for the Substrait integration", false}, + {"aws", "Provides features that depend on the AWS SDK", false}, + {"arrow", "A zero-copy data integration between Apache Arrow and DuckDB", false}, + {"azure", "Adds a filesystem abstraction for Azure blob storage to DuckDB", false}, + {"iceberg", "Adds support for Apache Iceberg", false}, + {"visualizer", "Creates an HTML-based visualization of the query plan", false}, + {nullptr, nullptr, false}}; + +idx_t ExtensionHelper::DefaultExtensionCount() { + idx_t index; + for (index = 0; internal_extensions[index].name != nullptr; index++) { + } + return index; +} + +DefaultExtension ExtensionHelper::GetDefaultExtension(idx_t index) { + D_ASSERT(index < DefaultExtensionCount()); + return internal_extensions[index]; +} + +//===--------------------------------------------------------------------===// +// Allow Auto-Install Extensions +//===--------------------------------------------------------------------===// +static const char *auto_install[] = {"motherduck", "postgres_scanner", "sqlite_scanner", nullptr}; + +// TODO: unify with new autoload mechanism +bool ExtensionHelper::AllowAutoInstall(const string &extension) { + auto lcase = StringUtil::Lower(extension); + for (idx_t i = 0; auto_install[i]; i++) { + if (lcase == auto_install[i]) { + return true; + } + } + return false; +} + +bool ExtensionHelper::CanAutoloadExtension(const string &ext_name) { +#ifdef DUCKDB_DISABLE_EXTENSION_LOAD + return false; +#endif + + if (ext_name.empty()) { + return false; + } + for (const auto &ext : AUTOLOADABLE_EXTENSIONS) { + if (ext_name == ext) { + return true; + } + } + return false; +} + +string ExtensionHelper::AddExtensionInstallHintToErrorMsg(ClientContext &context, const string &base_error, + const string &extension_name) { + auto &dbconfig = DBConfig::GetConfig(context); + string install_hint; + + if (!ExtensionHelper::CanAutoloadExtension(extension_name)) { + install_hint = "Please try installing and loading the " + extension_name + " extension:\nINSTALL " + + extension_name + ";\nLOAD " + extension_name + ";\n\n"; + } else if (!dbconfig.options.autoload_known_extensions) { + install_hint = + "Please try installing and loading the " + extension_name + " extension by running:\nINSTALL " + + extension_name + ";\nLOAD " + extension_name + + ";\n\nAlternatively, consider enabling auto-install " + "and auto-load by running:\nSET autoinstall_known_extensions=1;\nSET autoload_known_extensions=1;"; + } else if (!dbconfig.options.autoinstall_known_extensions) { + install_hint = + "Please try installing the " + extension_name + " extension by running:\nINSTALL " + extension_name + + ";\n\nAlternatively, consider enabling autoinstall by running:\nSET autoinstall_known_extensions=1;"; + } + + if (!install_hint.empty()) { + return base_error + "\n\n" + install_hint; + } + + return base_error; +} + +bool ExtensionHelper::TryAutoLoadExtension(ClientContext &context, const string &extension_name) noexcept { + auto &dbconfig = DBConfig::GetConfig(context); + try { + if (dbconfig.options.autoinstall_known_extensions) { + ExtensionHelper::InstallExtension(context, extension_name, false, + context.config.autoinstall_extension_repo); + } + ExtensionHelper::LoadExternalExtension(context, extension_name); + return true; + } catch (...) { + return false; + } + return false; +} + +void ExtensionHelper::AutoLoadExtension(ClientContext &context, const string &extension_name) { + auto &dbconfig = DBConfig::GetConfig(context); + try { +#ifndef DUCKDB_WASM + if (dbconfig.options.autoinstall_known_extensions) { + ExtensionHelper::InstallExtension(context, extension_name, false, + context.config.autoinstall_extension_repo); + } +#endif + ExtensionHelper::LoadExternalExtension(context, extension_name); + } catch (Exception &e) { + throw AutoloadException(extension_name, e); + } +} + +//===--------------------------------------------------------------------===// +// Load Statically Compiled Extension +//===--------------------------------------------------------------------===// +void ExtensionHelper::LoadAllExtensions(DuckDB &db) { + // The in-tree extensions that we check. Non-cmake builds are currently limited to these for static linking + // TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds, after that + // these can be removed + unordered_set extensions {"parquet", "icu", "tpch", "tpcds", "fts", "httpfs", "visualizer", + "json", "excel", "sqlsmith", "inet", "jemalloc", "autocomplete"}; + for (auto &ext : extensions) { + LoadExtensionInternal(db, ext, true); + } + +#if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS + for (auto &ext : linked_extensions) { + LoadExtensionInternal(db, ext, true); + } +#endif +} + +ExtensionLoadResult ExtensionHelper::LoadExtension(DuckDB &db, const std::string &extension) { + return LoadExtensionInternal(db, extension, false); +} + +ExtensionLoadResult ExtensionHelper::LoadExtensionInternal(DuckDB &db, const std::string &extension, + bool initial_load) { +#ifdef DUCKDB_TEST_REMOTE_INSTALL + if (!initial_load && StringUtil::Contains(DUCKDB_TEST_REMOTE_INSTALL, extension)) { + Connection con(db); + auto result = con.Query("INSTALL " + extension); + if (result->HasError()) { + result->Print(); + return ExtensionLoadResult::EXTENSION_UNKNOWN; + } + result = con.Query("LOAD " + extension); + if (result->HasError()) { + result->Print(); + return ExtensionLoadResult::EXTENSION_UNKNOWN; + } + return ExtensionLoadResult::LOADED_EXTENSION; + } +#endif + +#ifdef DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE + // Note: weird comma's are on purpose to do easy string contains on a list of extension names + if (!initial_load && StringUtil::Contains(DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE, "," + extension + ",")) { + Connection con(db); + auto result = con.Query((string) "LOAD '" + DUCKDB_EXTENSIONS_BUILD_PATH + "/" + extension + "/" + extension + + ".duckdb_extension'"); + if (result->HasError()) { + result->Print(); + return ExtensionLoadResult::EXTENSION_UNKNOWN; + } + return ExtensionLoadResult::LOADED_EXTENSION; + } +#endif + + // This is the main extension loading mechanism that loads the extension that are statically linked. +#if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS + if (TryLoadLinkedExtension(db, extension)) { + return ExtensionLoadResult::LOADED_EXTENSION; + } else { + return ExtensionLoadResult::NOT_LOADED; + } +#endif + + // This is the fallback to the "old" extension loading mechanism for non-cmake builds + // TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds + if (extension == "parquet") { +#if DUCKDB_EXTENSION_PARQUET_LINKED + db.LoadExtension(); +#else + // parquet extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "icu") { +#if DUCKDB_EXTENSION_ICU_LINKED + db.LoadExtension(); +#else + // icu extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "tpch") { +#if DUCKDB_EXTENSION_TPCH_LINKED + db.LoadExtension(); +#else + // icu extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "tpcds") { +#if DUCKDB_EXTENSION_TPCDS_LINKED + db.LoadExtension(); +#else + // icu extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "fts") { +#if DUCKDB_EXTENSION_FTS_LINKED +// db.LoadExtension(); +#else + // fts extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "httpfs") { +#if DUCKDB_EXTENSION_HTTPFS_LINKED + db.LoadExtension(); +#else + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "visualizer") { +#if DUCKDB_EXTENSION_VISUALIZER_LINKED + db.LoadExtension(); +#else + // visualizer extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "json") { +#if DUCKDB_EXTENSION_JSON_LINKED + db.LoadExtension(); +#else + // json extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "excel") { +#if DUCKDB_EXTENSION_EXCEL_LINKED + db.LoadExtension(); +#else + // excel extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "sqlsmith") { +#if DUCKDB_EXTENSION_SQLSMITH_LINKED + db.LoadExtension(); +#else + // excel extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "jemalloc") { +#if DUCKDB_EXTENSION_JEMALLOC_LINKED + db.LoadExtension(); +#else + // jemalloc extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "autocomplete") { +#if DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED + db.LoadExtension(); +#else + // autocomplete extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } else if (extension == "inet") { +#if DUCKDB_EXTENSION_INET_LINKED + db.LoadExtension(); +#else + // inet extension required but not build: skip this test + return ExtensionLoadResult::NOT_LOADED; +#endif + } + + return ExtensionLoadResult::LOADED_EXTENSION; +} + +static vector public_keys = { + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA6aZuHUa1cLR9YDDYaEfi +UDbWY8m2t7b71S+k1ZkXfHqu+5drAxm+dIDzdOHOKZSIdwnJbT3sSqwFoG6PlXF3 +g3dsJjax5qESIhbVvf98nyipwNINxoyHCkcCIPkX17QP2xpnT7V59+CqcfDJXLqB +ymjqoFSlaH8dUCHybM4OXlWnAtVHW/nmw0khF8CetcWn4LxaTUHptByaBz8CasSs +gWpXgSfaHc3R9eArsYhtsVFGyL/DEWgkEHWolxY3Llenhgm/zOf3s7PsAMe7EJX4 +qlSgiXE6OVBXnqd85z4k20lCw/LAOe5hoTMmRWXIj74MudWe2U91J6GrrGEZa7zT +7QIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAq8Gg1S/LI6ApMAYsFc9m +PrkFIY+nc0LXSpxm77twU8D5M0Xkz/Av4f88DQmj1OE3164bEtR7sl7xDPZojFHj +YYyucJxEI97l5OU1d3Pc1BdKXL4+mnW5FlUGj218u8qD+G1hrkySXQkrUzIjPPNw +o6knF3G/xqQF+KI+tc7ajnTni8CAlnUSxfnstycqbVS86m238PLASVPK9/SmIRgO +XCEV+ZNMlerq8EwsW4cJPHH0oNVMcaG+QT4z79roW1rbJghn9ubAVdQU6VLUAikI +b8keUyY+D0XdY9DpDBeiorb1qPYt8BPLOAQrIUAw1CgpMM9KFp9TNvW47KcG4bcB +dQIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyYATA9KOQ0Azf97QAPfY +Jc/WeZyE4E1qlRgKWKqNtYSXZqk5At0V7w2ntAWtYSpczFrVepCJ0oPMDpZTigEr +NgOgfo5LEhPx5XmtCf62xY/xL3kgtfz9Mm5TBkuQy4KwY4z1npGr4NYYDXtF7kkf +LQE+FnD8Yr4E0wHBib7ey7aeeKWmwqvUjzDqG+TzaqwzO/RCUsSctqSS0t1oo2hv +4q1ofanUXsV8MXk/ujtgxu7WkVvfiSpK1zRazgeZjcrQFO9qL/pla0vBUxa1U8He +GMLnL0oRfcMg7yKrbIMrvlEl2ZmiR9im44dXJWfY42quObwr1PuEkEoCMcMisSWl +jwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4RvbWx3zLblDHH/lGUF5 +Q512MT+v3YPriuibROMllv8WiCLAMeJ0QXbVaIzBOeHDeLx8yvoZZN+TENKxtT6u +IfMMneUzxHBqy0AQNfIsSsOnG5nqoeE/AwbS6VqCdH1aLfoCoPffacHYa0XvTcsi +aVlZfr+UzJS+ty8pRmFVi1UKSOADDdK8XfIovJl/zMP2TxYX2Y3fnjeLtl8Sqs2e +P+eHDoy7Wi4EPTyY7tNTCfxwKNHn1HQ5yrv5dgvMxFWIWXGz24yikFvtwLGHe8uJ +Wi+fBX+0PF0diZ6pIthZ149VU8qCqYAXjgpxZ0EZdrsiF6Ewz0cfg20SYApFcmW4 +pwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyhd5AfwrUohG3O4DE0K9 +O3FmgB7zE4aDnkL8UUfGCh5kdP8q7ewMjekY+c6LwWOmpdJpSwqhfV1q5ZU1l6rk +3hlt03LO3sgs28kcfOVH15hqfxts6Sg5KcRjxStE50ORmXGwXDcS9vqkJ60J1EHA +lcZqbCRSO73ZPLhdepfd0/C6tM0L7Ge6cAE62/MTmYNGv8fDzwQr/kYIJMdoS8Zp +thRpctFZJtPs3b0fffZA/TCLVKMvEVgTWs48751qKid7N/Lm/iEGx/tOf4o23Nec +Pz1IQaGLP+UOLVQbqQBHJWNOqigm7kWhDgs3N4YagWgxPEQ0WVLtFji/ZjlKZc7h +dwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAnFDg3LhyV6BVE2Z3zQvN +6urrKvPhygTa5+wIPGwYTzJ8DfGALqlsX3VOXMvcJTca6SbuwwkoXHuSU5wQxfcs +bt4jTXD3NIoRwQPl+D9IbgIMuX0ACl27rJmr/f9zkY7qui4k1X82pQkxBe+/qJ4r +TBwVNONVx1fekTMnSCEhwg5yU3TNbkObu0qlQeJfuMWLDQbW/8v/qfr/Nz0JqHDN +yYKfKvFMlORxyJYiOyeOsbzNGEhkGQGOmKhRUhS35kD+oA0jqwPwMCM9O4kFg/L8 +iZbpBBX2By1K3msejWMRAewTOyPas6YMQOYq9BMmWQqzVtG5xcaSJwN/YnMpJyqb +sQIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1z0RU8vGrfEkrscEoZKA +GiOcGh2EMcKwjQpl4nKuR9H4o/dg+CZregVSHg7MP2f8mhLZZyoFev49oWOV4Rmi +qs99UNxm7DyKW1fF1ovowsUW5lsDoKYLvpuzHo0s4laiV4AnIYP7tHGLdzsnK2Os +Cp5dSuMwKHPZ9N25hXxFB/dRrAdIiXHvbSqr4N29XzfQloQpL3bGHLKY6guFHluH +X5dJ9eirVakWWou7BR2rnD0k9vER6oRdVnJ6YKb5uhWEOQ3NmV961oyr+uiDTcep +qqtGHWuFhENixtiWGjFJJcACwqxEAW3bz9lyrfnPDsHSW/rlQVDIAkik+fOp+R7L +kQIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAxwO27e1vnbNcpiDg7Wwx +K/w5aEGukXotu3529ieq+O39H0+Bak4vIbzGhDUh3/ElmxaFMAs4PYrWe/hc2WFD +H4JCOoFIn4y9gQeE855DGGFgeIVd1BnSs5S+5wUEMxLNyHdHSmINN6FsoZ535iUg +KdYjRh1iZevezg7ln8o/O36uthu925ehFBXSy6jLJgQlwmq0KxZJE0OAZhuDBM60 +MtIunNa/e5y+Gw3GknFwtRLmn/nEckZx1nEtepYvvUa7UGy+8KuGuhOerCZTutbG +k8liCVgGenRve8unA2LrBbpL+AUf3CrZU/uAxxTqWmw6Z/S6TeW5ozeeyOCh8ii6 +TwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsGIFOfIQ4RI5qu4klOxf +ge6eXwBMAkuTXyhyIIJDtE8CurnwQvUXVlt+Kf0SfuIFW6MY5ErcWE/vMFbc81IR +9wByOAAV2CTyiLGZT63uE8pN6FSHd6yGYCLjXd3P3cnP3Qj5pBncpLuAUDfHG4wP +bs9jIADw3HysD+eCNja8p7ZC7CzWxTcO7HsEu9deAAU19YywdpagXvQ0pJ9zV5qU +jrHxBygl31t6TmmX+3d+azjGu9Hu36E+5wcSOOhuwAFXDejb40Ixv53ItJ3fZzzH +PF2nj9sQvQ8c5ptjyOvQCBRdqkEWXIVHClxqWb+o59pDIh1G0UGcmiDN7K9Gz5HA +ZQIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAt9uUnlW/CoUXT68yaZh9 +SeXHzGRCPNEI98Tara+dgYxDX1z7nfOh8o15liT0QsAzx34EewZOxcKCNiV/dZX5 +z4clCkD8uUbZut6IVx8Eu+7Qcd5jZthRc6hQrN9Ltv7ZQEh7KGXOHa53kT2K01ws +4jbVmd/7Nx7y0Yyqhja01pIu/CUaTkODfQxBXwriLdIzp7y/iJeF/TLqCwZWHKQx +QOZnsPEveB1F00Va9MeAtTlXFUJ/TQXquqTjeLj4HuIRtbyuNgWoc0JyF+mcafAl +bnrNEBIfxZhAT81aUCIAzRJp6AqfdeZxnZ/WwohtZQZLXAxFQPTWCcP+Z9M7OIQL +WwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA56NhfACkeCyZM07l2wmd +iTp24E2tLLKU3iByKlIRWRAvXsOejRMJTHTNHWa3cQ7uLP++Tf2St7ksNsyPMNZy +9QRTLNCYr9rN9loLwdb2sMWxFBwwzCaAOTahGI7GJQy30UB7FEND0X/5U2rZvQij +Q6K+O4aa+K9M5qyOHNMmXywmTnAgWKNaNxQHPRtD2+dSj60T6zXdtIuCrPfcNGg5 +gj07qWGEXX83V/L7nSqCiIVYg/wqds1x52Yjk1nhXYNBTqlnhmOd8LynGxz/sXC7 +h2Q9XsHjXIChW4FHyLIOl6b4zPMBSxzCigYm3QZJWfAkZv5PBRtnq7vhYOLHzLQj +CwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmfPLe0IWGYC0MZC6YiM3 +QGfhT6zSKB0I2DW44nlBlWUcF+32jW2bFJtgE76qGGKFeU4kJBWYr99ufHoAodNg +M1Ehl/JfQ5KmbC1WIqnFTrgbmqJde79jeCvCpbFLuqnzidwO1PbXDbfRFQcgWaXT +mDVLNNVmLxA0GkCv+kydE2gtcOD9BDceg7F/56TDvclyI5QqAnjE2XIRMPZlXQP4 +oF2kgz4Cn7LxLHYmkU2sS9NYLzHoyUqFplWlxkQjA4eQ0neutV1Ydmc1IX8W7R38 +A7nFtaT8iI8w6Vkv7ijYN6xf5cVBPKZ3Dv7AdwPet86JD5mf5v+r7iwg5xl3r77Z +iwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoB1kWsX8YmCcFOD9ilBY +xK076HmUAN026uJ8JpmU9Hz+QT1FNXOsnj1h2G6U6btYVIdHUTHy/BvAumrDKqRz +qcEAzCuhxUjPjss54a/Zqu6nQcoIPHuG/Er39oZHIVkPR1WCvWj8wmyYv6T//dPH +unO6tW29sXXxS+J1Gah6vpbtJw1pI/liah1DZzb13KWPDI6ZzviTNnW4S05r6js/ +30He+Yud6aywrdaP/7G90qcrteEFcjFy4Xf+5vG960oKoGoDplwX5poay1oCP9tb +g8AC8VSRAGi3oviTeSWZcrLXS8AtJhGvF48cXQj2q+8YeVKVDpH6fPQxJ9Sh9aeU +awIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4NTMAIYIlCMID00ufy/I +AZXc8pocDx9N1Q5x5/cL3aIpLmx02AKo9BvTJaJuHiTjlwYhPtlhIrHV4HUVTkOX +sISp8B8v9i2I1RIvCTAcvy3gcH6rdRWZ0cdTUiMEqnnxBX9zdzl8oMzZcyauv19D +BeqJvzflIT96b8g8K3mvgJHs9a1j9f0gN8FuTA0c52DouKnrh8UwH7mlrumYerJw +6goJGQuK1HEOt6bcQuvogkbgJWOoEYwjNrPwQvIcP4wyrgSnOHg1yXOFE84oVynJ +czQEOz9ke42I3h8wrnQxilEYBVo2uX8MenqTyfGnE32lPRt3Wv1iEVQls8Cxiuy2 +CQIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3bUtfp66OtRyvIF/oucn +id8mo7gvbNEH04QMLO3Ok43dlWgWI3hekJAqOYc0mvoI5anqr98h8FI7aCYZm/bY +vpz0I1aXBaEPh3aWh8f/w9HME7ykBvmhMe3J+VFGWWL4eswfRl//GCtnSMBzDFhM +SaQOTvADWHkC0njeI5yXjf/lNm6fMACP1cnhuvCtnx7VP/DAtvUk9usDKG56MJnZ +UoVM3HHjbJeRwxCdlSWe12ilCdwMRKSDY92Hk38/zBLenH04C3HRQLjBGewACUmx +uvNInehZ4kSYFGa+7UxBxFtzJhlKzGR73qUjpWzZivCe1K0WfRVP5IWsKNCCESJ/ +nQIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyV2dE/CRUAUE8ybq/DoS +Lc7QlYXh04K+McbhN724TbHahLTuDk5mR5TAunA8Nea4euRzknKdMFAz1eh9gyy3 +5x4UfXQW1fIZqNo6WNrGxYJgWAXU+pov+OvxsMQWzqS4jrTHDHbblCCLKp1akwJk +aFNyqgjAL373PcqXC+XAn8vHx4xHFoFP5lq4lLcJCOW5ee9v9El3w0USLwS+t1cF +RY3kuV6Njlr4zsRH9iM6/zaSuCALYWJ/JrPEurSJXzFZnWsvn6aQdeNeAn08+z0F +k2NwaauEo0xmLqzqTRGzjHqKKmeefN3/+M/FN2FrApDlxWQfhD2Y3USdAiN547Nj +1wIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvm2+kTrEQWZXuxhWzBdl +PCbQGqbrukbeS6JKSlQLJDC8ayZIxFxatqg1Q8UPyv89MVRsHOGlG1OqFaOEtPjQ +Oo6j/moFwB4GPyJhJHOGpCKa4CLB5clhfDCLJw6ty7PcDU3T6yW4X4Qc5k4LRRWy +yzC8lVHfBdarN+1iEe0ALMOGoeiJjVn6i/AFxktRwgd8njqv/oWQyfjJZXkNMsb6 +7ZDxNVAUrp/WXpE4Kq694bB9xa/pWsqv7FjQJUgTnEzvbN+qXnVPtA7dHcOYYJ8Z +SbrJUfHrf8TS5B54AiopFpWG+hIbjqqdigqabBqFpmjiRDZgDy4zJJj52xJZMnrp +rwIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwEAcVmY3589O02pLA22f +MlarLyJUgy0BeJDG5AUsi17ct8sHZzRiv9zKQVCBk1CtZY//jyqnrM7iCBLWsyby +TiTOtGYHHApaLnNjjtaHdQ6zplhbc3g2XLy+4ab8GNKG3zc8iXpsQM6r+JO5n9pm +V9vollz9dkFxS9l+1P17lZdIgCh9O3EIFJv5QCd5c9l2ezHAan2OhkWhiDtldnH/ +MfRXbz7X5sqlwWLa/jhPtvY45x7dZaCHGqNzbupQZs0vHnAVdDu3vAWDmT/3sXHG +vmGxswKA9tPU0prSvQWLz4LUCnGi/cC5R+fiu+fovFM/BwvaGtqBFIF/1oWVq7bZ +4wIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA25qGwNO1+qHygC8mjm8L +3I66mV/IzslgBDHC91mE8YcI5Fq0sdrtsbUhK3z89wIN/zOhbHX0NEiXm2GxUnsI +vb5tDZXAh7AbTnXTMVbxO/e/8sPLUiObGjDvjVzyzrxOeG87yK/oIiilwk9wTsIb +wMn2Grj4ht9gVKx3oGHYV7STNdWBlzSaJj4Ou7+5M1InjPDRFZG1K31D2d3IHByX +lmcRPZtPFTa5C1uVJw00fI4F4uEFlPclZQlR5yA0G9v+0uDgLcjIUB4eqwMthUWc +dHhlmrPp04LI19eksWHCtG30RzmUaxDiIC7J2Ut0zHDqUe7aXn8tOVI7dE9tTKQD +KQIDAQAB +-----END PUBLIC KEY----- +)", + R"( +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7EC2bx7aRnf3TcRg5gmw +QOKNCUheCelK8hoXLMsKSJqmufyJ+IHUejpXGOpvyYRbACiJ5GiNcww20MVpTBU7 +YESWB2QSU2eEJJXMq84qsZSO8WGmAuKpUckI+hNHKQYJBEDOougV6/vVVEm5c5bc +SLWQo0+/ciQ21Zwz5SwimX8ep1YpqYirO04gcyGZzAfGboXRvdUwA+1bZvuUXdKC +4zsCw2QALlcVpzPwjB5mqA/3a+SPgdLAiLOwWXFDRMnQw44UjsnPJFoXgEZiUpZm +EMS5gLv50CzQqJXK9mNzPuYXNUIc4Pw4ssVWe0OfN3Od90gl5uFUwk/G9lWSYnBN +3wIDAQAB +-----END PUBLIC KEY----- +)"}; + +const vector ExtensionHelper::GetPublicKeys() { + return public_keys; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/extension/extension_install.cpp b/src/duckdb/src/main/extension/extension_install.cpp new file mode 100644 index 00000000..231c38ee --- /dev/null +++ b/src/duckdb/src/main/extension/extension_install.cpp @@ -0,0 +1,306 @@ +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/common/gzip_file_system.hpp" +#include "duckdb/common/types/uuid.hpp" +#include "duckdb/common/string_util.hpp" + +#ifndef DISABLE_DUCKDB_REMOTE_INSTALL +#ifndef DUCKDB_DISABLE_EXTENSION_LOAD +#include "httplib.hpp" +#endif +#endif +#include "duckdb/common/windows_undefs.hpp" + +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Install Extension +//===--------------------------------------------------------------------===// +const string ExtensionHelper::NormalizeVersionTag(const string &version_tag) { + if (version_tag.length() > 0 && version_tag[0] != 'v') { + return "v" + version_tag; + } + return version_tag; +} + +bool ExtensionHelper::IsRelease(const string &version_tag) { + return !StringUtil::Contains(version_tag, "-dev"); +} + +const string ExtensionHelper::GetVersionDirectoryName() { +#ifdef DUCKDB_WASM_VERSION + return DUCKDB_QUOTE_DEFINE(DUCKDB_WASM_VERSION); +#endif + if (IsRelease(DuckDB::LibraryVersion())) { + return NormalizeVersionTag(DuckDB::LibraryVersion()); + } else { + return DuckDB::SourceID(); + } +} + +const vector ExtensionHelper::PathComponents() { + return vector {".duckdb", "extensions", GetVersionDirectoryName(), DuckDB::Platform()}; +} + +string ExtensionHelper::ExtensionDirectory(DBConfig &config, FileSystem &fs) { +#ifdef WASM_LOADABLE_EXTENSIONS + throw PermissionException("ExtensionDirectory functionality is not supported in duckdb-wasm"); +#endif + string extension_directory; + if (!config.options.extension_directory.empty()) { // create the extension directory if not present + extension_directory = config.options.extension_directory; + // TODO this should probably live in the FileSystem + // convert random separators to platform-canonic + extension_directory = fs.ConvertSeparators(extension_directory); + // expand ~ in extension directory + extension_directory = fs.ExpandPath(extension_directory); + if (!fs.DirectoryExists(extension_directory)) { + auto sep = fs.PathSeparator(extension_directory); + auto splits = StringUtil::Split(extension_directory, sep); + D_ASSERT(!splits.empty()); + string extension_directory_prefix; + if (StringUtil::StartsWith(extension_directory, sep)) { + extension_directory_prefix = sep; // this is swallowed by Split otherwise + } + for (auto &split : splits) { + extension_directory_prefix = extension_directory_prefix + split + sep; + if (!fs.DirectoryExists(extension_directory_prefix)) { + fs.CreateDirectory(extension_directory_prefix); + } + } + } + } else { // otherwise default to home + string home_directory = fs.GetHomeDirectory(); + // exception if the home directory does not exist, don't create whatever we think is home + if (!fs.DirectoryExists(home_directory)) { + throw IOException("Can't find the home directory at '%s'\nSpecify a home directory using the SET " + "home_directory='/path/to/dir' option.", + home_directory); + } + extension_directory = home_directory; + } + D_ASSERT(fs.DirectoryExists(extension_directory)); + + auto path_components = PathComponents(); + for (auto &path_ele : path_components) { + extension_directory = fs.JoinPath(extension_directory, path_ele); + if (!fs.DirectoryExists(extension_directory)) { + fs.CreateDirectory(extension_directory); + } + } + return extension_directory; +} + +string ExtensionHelper::ExtensionDirectory(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + auto &fs = FileSystem::GetFileSystem(context); + return ExtensionDirectory(config, fs); +} + +bool ExtensionHelper::CreateSuggestions(const string &extension_name, string &message) { + vector candidates; + for (idx_t ext_count = ExtensionHelper::DefaultExtensionCount(), i = 0; i < ext_count; i++) { + candidates.emplace_back(ExtensionHelper::GetDefaultExtension(i).name); + } + for (idx_t ext_count = ExtensionHelper::ExtensionAliasCount(), i = 0; i < ext_count; i++) { + candidates.emplace_back(ExtensionHelper::GetExtensionAlias(i).alias); + } + auto closest_extensions = StringUtil::TopNLevenshtein(candidates, extension_name); + message = StringUtil::CandidatesMessage(closest_extensions, "Candidate extensions"); + for (auto &closest : closest_extensions) { + if (closest == extension_name) { + message = "Extension \"" + extension_name + "\" is an existing extension.\n"; + return true; + } + } + return false; +} + +void ExtensionHelper::InstallExtension(DBConfig &config, FileSystem &fs, const string &extension, bool force_install, + const string &repository) { +#ifdef WASM_LOADABLE_EXTENSIONS + // Install is currently a no-op + return; +#endif + string local_path = ExtensionDirectory(config, fs); + InstallExtensionInternal(config, nullptr, fs, local_path, extension, force_install, repository); +} + +void ExtensionHelper::InstallExtension(ClientContext &context, const string &extension, bool force_install, + const string &repository) { +#ifdef WASM_LOADABLE_EXTENSIONS + // Install is currently a no-op + return; +#endif + auto &config = DBConfig::GetConfig(context); + auto &fs = FileSystem::GetFileSystem(context); + string local_path = ExtensionDirectory(context); + auto &client_config = ClientConfig::GetConfig(context); + InstallExtensionInternal(config, &client_config, fs, local_path, extension, force_install, repository); +} + +unsafe_unique_array ReadExtensionFileFromDisk(FileSystem &fs, const string &path, idx_t &file_size) { + auto source_file = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ); + file_size = source_file->GetFileSize(); + auto in_buffer = make_unsafe_uniq_array(file_size); + source_file->Read(in_buffer.get(), file_size); + source_file->Close(); + return in_buffer; +} + +void WriteExtensionFileToDisk(FileSystem &fs, const string &path, void *data, idx_t data_size) { + auto target_file = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_APPEND | + FileFlags::FILE_FLAGS_FILE_CREATE_NEW); + target_file->Write(data, data_size); + target_file->Close(); + target_file.reset(); +} + +string ExtensionHelper::ExtensionUrlTemplate(optional_ptr client_config, const string &repository) { + string versioned_path = "/${REVISION}/${PLATFORM}/${NAME}.duckdb_extension"; +#ifdef WASM_LOADABLE_EXTENSIONS + string default_endpoint = "https://extensions.duckdb.org"; + versioned_path = "/duckdb-wasm" + versioned_path + ".wasm"; +#else + string default_endpoint = "http://extensions.duckdb.org"; + versioned_path = versioned_path + ".gz"; +#endif + string custom_endpoint = client_config ? client_config->custom_extension_repo : string(); + string endpoint; + if (!repository.empty()) { + endpoint = repository; + } else if (!custom_endpoint.empty()) { + endpoint = custom_endpoint; + } else { + endpoint = default_endpoint; + } + string url_template = endpoint + versioned_path; + return url_template; +} + +string ExtensionHelper::ExtensionFinalizeUrlTemplate(const string &url_template, const string &extension_name) { + auto url = StringUtil::Replace(url_template, "${REVISION}", GetVersionDirectoryName()); + url = StringUtil::Replace(url, "${PLATFORM}", DuckDB::Platform()); + url = StringUtil::Replace(url, "${NAME}", extension_name); + return url; +} + +void ExtensionHelper::InstallExtensionInternal(DBConfig &config, ClientConfig *client_config, FileSystem &fs, + const string &local_path, const string &extension, bool force_install, + const string &repository) { +#ifdef DUCKDB_DISABLE_EXTENSION_LOAD + throw PermissionException("Installing external extensions is disabled through a compile time flag"); +#else + if (!config.options.enable_external_access) { + throw PermissionException("Installing extensions is disabled through configuration"); + } + auto extension_name = ApplyExtensionAlias(fs.ExtractBaseName(extension)); + + string local_extension_path = fs.JoinPath(local_path, extension_name + ".duckdb_extension"); + if (fs.FileExists(local_extension_path) && !force_install) { + return; + } + + auto uuid = UUID::ToString(UUID::GenerateRandomUUID()); + string temp_path = local_extension_path + ".tmp-" + uuid; + if (fs.FileExists(temp_path)) { + fs.RemoveFile(temp_path); + } + auto is_http_url = StringUtil::Contains(extension, "http://"); + if (fs.FileExists(extension)) { + idx_t file_size; + auto in_buffer = ReadExtensionFileFromDisk(fs, extension, file_size); + WriteExtensionFileToDisk(fs, temp_path, in_buffer.get(), file_size); + + if (fs.FileExists(local_extension_path) && force_install) { + fs.RemoveFile(local_extension_path); + } + fs.MoveFile(temp_path, local_extension_path); + return; + } else if (StringUtil::Contains(extension, "/") && !is_http_url) { + throw IOException("Failed to read extension from \"%s\": no such file", extension); + } + +#ifdef DISABLE_DUCKDB_REMOTE_INSTALL + throw BinderException("Remote extension installation is disabled through configuration"); +#else + + string url_template = ExtensionUrlTemplate(client_config, repository); + + if (is_http_url) { + url_template = extension; + extension_name = ""; + } + + string url = ExtensionFinalizeUrlTemplate(url_template, extension_name); + + string no_http = StringUtil::Replace(url, "http://", ""); + + idx_t next = no_http.find('/', 0); + if (next == string::npos) { + throw IOException("No slash in URL template"); + } + + // Special case to install extension from a local file, useful for testing + if (!StringUtil::Contains(url_template, "http://")) { + string file = fs.ConvertSeparators(url); + if (!fs.FileExists(file)) { + // check for non-gzipped variant + file = file.substr(0, file.size() - 3); + if (!fs.FileExists(file)) { + throw IOException("Failed to copy local extension \"%s\" at PATH \"%s\"\n", extension_name, file); + } + } + auto read_handle = fs.OpenFile(file, FileFlags::FILE_FLAGS_READ); + auto test_data = std::unique_ptr {new unsigned char[read_handle->GetFileSize()]}; + read_handle->Read(test_data.get(), read_handle->GetFileSize()); + WriteExtensionFileToDisk(fs, temp_path, (void *)test_data.get(), read_handle->GetFileSize()); + + if (fs.FileExists(local_extension_path) && force_install) { + fs.RemoveFile(local_extension_path); + } + fs.MoveFile(temp_path, local_extension_path); + return; + } + + // Push the substring [last, next) on to splits + auto hostname_without_http = no_http.substr(0, next); + auto url_local_part = no_http.substr(next); + + auto url_base = "http://" + hostname_without_http; + duckdb_httplib::Client cli(url_base.c_str()); + + duckdb_httplib::Headers headers = {{"User-Agent", StringUtil::Format("DuckDB %s %s %s", DuckDB::LibraryVersion(), + DuckDB::SourceID(), DuckDB::Platform())}}; + + auto res = cli.Get(url_local_part.c_str(), headers); + + if (!res || res->status != 200) { + // create suggestions + string message; + auto exact_match = ExtensionHelper::CreateSuggestions(extension_name, message); + if (exact_match) { + message += "\nAre you using a development build? In this case, extensions might not (yet) be uploaded."; + } + if (res.error() == duckdb_httplib::Error::Success) { + throw HTTPException(res.value(), "Failed to download extension \"%s\" at URL \"%s%s\"\n%s", extension_name, + url_base, url_local_part, message); + } else { + throw IOException("Failed to download extension \"%s\" at URL \"%s%s\"\n%s (ERROR %s)", extension_name, + url_base, url_local_part, message, to_string(res.error())); + } + } + auto decompressed_body = GZipFileSystem::UncompressGZIPString(res->body); + + WriteExtensionFileToDisk(fs, temp_path, (void *)decompressed_body.data(), decompressed_body.size()); + + if (fs.FileExists(local_extension_path) && force_install) { + fs.RemoveFile(local_extension_path); + } + fs.MoveFile(temp_path, local_extension_path); +#endif +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/extension/extension_load.cpp b/src/duckdb/src/main/extension/extension_load.cpp new file mode 100644 index 00000000..80d24c29 --- /dev/null +++ b/src/duckdb/src/main/extension/extension_load.cpp @@ -0,0 +1,336 @@ +#include "duckdb/common/dl.hpp" +#include "duckdb/common/virtual_file_system.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/main/error_manager.hpp" +#include "mbedtls_wrapper.hpp" + +#ifndef DUCKDB_NO_THREADS +#include +#endif // DUCKDB_NO_THREADS + +#ifdef WASM_LOADABLE_EXTENSIONS +#include +#endif + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Load External Extension +//===--------------------------------------------------------------------===// +#ifndef DUCKDB_DISABLE_EXTENSION_LOAD +typedef void (*ext_init_fun_t)(DatabaseInstance &); +typedef const char *(*ext_version_fun_t)(void); +typedef bool (*ext_is_storage_t)(void); + +template +static T LoadFunctionFromDLL(void *dll, const string &function_name, const string &filename) { + auto function = dlsym(dll, function_name.c_str()); + if (!function) { + throw IOException("File \"%s\" did not contain function \"%s\": %s", filename, function_name, GetDLError()); + } + return (T)function; +} + +static void ComputeSHA256String(const std::string &to_hash, std::string *res) { + // Invoke MbedTls function to actually compute sha256 + *res = duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(to_hash); +} + +static void ComputeSHA256FileSegment(FileHandle *handle, const idx_t start, const idx_t end, std::string *res) { + idx_t iter = start; + const idx_t segment_size = 1024 * 8; + + duckdb_mbedtls::MbedTlsWrapper::SHA256State state; + + std::string to_hash; + while (iter < end) { + idx_t len = std::min(end - iter, segment_size); + to_hash.resize(len); + handle->Read((void *)to_hash.data(), len, iter); + + state.AddString(to_hash); + + iter += segment_size; + } + + *res = state.Finalize(); +} +#endif + +bool ExtensionHelper::TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension, + ExtensionInitResult &result, string &error, + optional_ptr client_config) { +#ifdef DUCKDB_DISABLE_EXTENSION_LOAD + throw PermissionException("Loading external extensions is disabled through a compile time flag"); +#else + if (!config.options.enable_external_access) { + throw PermissionException("Loading external extensions is disabled through configuration"); + } + auto filename = fs.ConvertSeparators(extension); + + // shorthand case + if (!ExtensionHelper::IsFullPath(extension)) { + string extension_name = ApplyExtensionAlias(extension); +#ifdef WASM_LOADABLE_EXTENSIONS + string url_template = ExtensionUrlTemplate(client_config, ""); + string url = ExtensionFinalizeUrlTemplate(url_template, extension_name); + + char *str = (char *)EM_ASM_PTR( + { + var jsString = ((typeof runtime == 'object') && runtime && (typeof runtime.whereToLoad == 'function') && + runtime.whereToLoad) + ? runtime.whereToLoad(UTF8ToString($0)) + : (UTF8ToString($1)); + var lengthBytes = lengthBytesUTF8(jsString) + 1; + // 'jsString.length' would return the length of the string as UTF-16 + // units, but Emscripten C strings operate as UTF-8. + var stringOnWasmHeap = _malloc(lengthBytes); + stringToUTF8(jsString, stringOnWasmHeap, lengthBytes); + return stringOnWasmHeap; + }, + filename.c_str(), url.c_str()); + std::string address(str); + free(str); + + filename = address; +#else + + string local_path = + !config.options.extension_directory.empty() ? config.options.extension_directory : fs.GetHomeDirectory(); + + // convert random separators to platform-canonic + local_path = fs.ConvertSeparators(local_path); + // expand ~ in extension directory + local_path = fs.ExpandPath(local_path); + auto path_components = PathComponents(); + for (auto &path_ele : path_components) { + local_path = fs.JoinPath(local_path, path_ele); + } + filename = fs.JoinPath(local_path, extension_name + ".duckdb_extension"); +#endif + } + if (!fs.FileExists(filename)) { + string message; + bool exact_match = ExtensionHelper::CreateSuggestions(extension, message); + if (exact_match) { + message += "\nInstall it first using \"INSTALL " + extension + "\"."; + } + error = StringUtil::Format("Extension \"%s\" not found.\n%s", filename, message); + return false; + } + if (!config.options.allow_unsigned_extensions) { + auto handle = fs.OpenFile(filename, FileFlags::FILE_FLAGS_READ); + + // signature is the last 256 bytes of the file + + string signature; + signature.resize(256); + + auto signature_offset = handle->GetFileSize() - signature.size(); + + const idx_t maxLenChunks = 1024ULL * 1024ULL; + const idx_t numChunks = (signature_offset + maxLenChunks - 1) / maxLenChunks; + std::vector hash_chunks(numChunks); + std::vector splits(numChunks + 1); + + for (idx_t i = 0; i < numChunks; i++) { + splits[i] = maxLenChunks * i; + } + splits.back() = signature_offset; + +#ifndef DUCKDB_NO_THREADS + std::vector threads; + threads.reserve(numChunks); + for (idx_t i = 0; i < numChunks; i++) { + threads.emplace_back(ComputeSHA256FileSegment, handle.get(), splits[i], splits[i + 1], &hash_chunks[i]); + } + + for (auto &thread : threads) { + thread.join(); + } +#else + for (idx_t i = 0; i < numChunks; i++) { + ComputeSHA256FileSegment(handle.get(), splits[i], splits[i + 1], &hash_chunks[i]); + } +#endif // DUCKDB_NO_THREADS + + string hash_concatenation; + hash_concatenation.reserve(32 * numChunks); // 256 bits -> 32 bytes per chunk + + for (auto &hash_chunk : hash_chunks) { + hash_concatenation += hash_chunk; + } + + string two_level_hash; + ComputeSHA256String(hash_concatenation, &two_level_hash); + + // TODO maybe we should do a stream read / hash update here + handle->Read((void *)signature.data(), signature.size(), signature_offset); + + bool any_valid = false; + for (auto &key : ExtensionHelper::GetPublicKeys()) { + if (duckdb_mbedtls::MbedTlsWrapper::IsValidSha256Signature(key, signature, two_level_hash)) { + any_valid = true; + break; + } + } + if (!any_valid) { + throw IOException(config.error_manager->FormatException(ErrorType::UNSIGNED_EXTENSION, filename)); + } + } + auto basename = fs.ExtractBaseName(filename); + +#ifdef WASM_LOADABLE_EXTENSIONS + EM_ASM( + { + // Next few lines should argubly in separate JavaScript-land function call + // TODO: move them out / have them configurable + const xhr = new XMLHttpRequest(); + xhr.open("GET", UTF8ToString($0), false); + xhr.responseType = "arraybuffer"; + xhr.send(null); + var uInt8Array = xhr.response; + WebAssembly.validate(uInt8Array); + console.log('Loading extension ', UTF8ToString($1)); + + // Here we add the uInt8Array to Emscripten's filesystem, for it to be found by dlopen + FS.writeFile(UTF8ToString($1), new Uint8Array(uInt8Array)); + }, + filename.c_str(), basename.c_str()); + auto dopen_from = basename; +#else + auto dopen_from = filename; +#endif + + auto lib_hdl = dlopen(dopen_from.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!lib_hdl) { + throw IOException("Extension \"%s\" could not be loaded: %s", filename, GetDLError()); + } + + ext_version_fun_t version_fun; + auto version_fun_name = basename + "_version"; + + version_fun = LoadFunctionFromDLL(lib_hdl, version_fun_name, filename); + + std::string engine_version = std::string(DuckDB::LibraryVersion()); + + auto version_fun_result = (*version_fun)(); + if (version_fun_result == nullptr) { + throw InvalidInputException("Extension \"%s\" returned a nullptr", filename); + } + std::string extension_version = std::string(version_fun_result); + + // Trim v's if necessary + std::string extension_version_trimmed = extension_version; + std::string engine_version_trimmed = engine_version; + if (extension_version.length() > 0 && extension_version[0] == 'v') { + extension_version_trimmed = extension_version.substr(1); + } + if (engine_version.length() > 0 && engine_version[0] == 'v') { + engine_version_trimmed = engine_version.substr(1); + } + + if (extension_version_trimmed != engine_version_trimmed) { + throw InvalidInputException("Extension \"%s\" version (%s) does not match DuckDB version (%s)", filename, + extension_version, engine_version); + } + + result.basename = basename; + result.filename = filename; + result.lib_hdl = lib_hdl; + return true; +#endif +} + +ExtensionInitResult ExtensionHelper::InitialLoad(DBConfig &config, FileSystem &fs, const string &extension, + optional_ptr client_config) { + string error; + ExtensionInitResult result; + if (!TryInitialLoad(config, fs, extension, result, error, client_config)) { + if (!ExtensionHelper::AllowAutoInstall(extension)) { + throw IOException(error); + } + // the extension load failed - try installing the extension + ExtensionHelper::InstallExtension(config, fs, extension, false); + // try loading again + if (!TryInitialLoad(config, fs, extension, result, error, client_config)) { + throw IOException(error); + } + } + return result; +} + +bool ExtensionHelper::IsFullPath(const string &extension) { + return StringUtil::Contains(extension, ".") || StringUtil::Contains(extension, "/") || + StringUtil::Contains(extension, "\\"); +} + +string ExtensionHelper::GetExtensionName(const string &original_name) { + auto extension = StringUtil::Lower(original_name); + if (!IsFullPath(extension)) { + return ExtensionHelper::ApplyExtensionAlias(extension); + } + auto splits = StringUtil::Split(StringUtil::Replace(extension, "\\", "/"), '/'); + if (splits.empty()) { + return ExtensionHelper::ApplyExtensionAlias(extension); + } + splits = StringUtil::Split(splits.back(), '.'); + if (splits.empty()) { + return ExtensionHelper::ApplyExtensionAlias(extension); + } + return ExtensionHelper::ApplyExtensionAlias(splits.front()); +} + +void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension, + optional_ptr client_config) { + if (db.ExtensionIsLoaded(extension)) { + return; + } +#ifdef DUCKDB_DISABLE_EXTENSION_LOAD + throw PermissionException("Loading external extensions is disabled through a compile time flag"); +#else + auto res = InitialLoad(DBConfig::GetConfig(db), fs, extension, client_config); + auto init_fun_name = res.basename + "_init"; + + ext_init_fun_t init_fun; + init_fun = LoadFunctionFromDLL(res.lib_hdl, init_fun_name, res.filename); + + try { + (*init_fun)(db); + } catch (std::exception &e) { + throw InvalidInputException("Initialization function \"%s\" from file \"%s\" threw an exception: \"%s\"", + init_fun_name, res.filename, e.what()); + } + + db.SetExtensionLoaded(extension); +#endif +} + +void ExtensionHelper::LoadExternalExtension(ClientContext &context, const string &extension) { + LoadExternalExtension(DatabaseInstance::GetDatabase(context), FileSystem::GetFileSystem(context), extension, + &ClientConfig::GetConfig(context)); +} + +string ExtensionHelper::ExtractExtensionPrefixFromPath(const string &path) { + auto first_colon = path.find(':'); + if (first_colon == string::npos || first_colon < 2) { // needs to be at least two characters because windows c: ... + return ""; + } + auto extension = path.substr(0, first_colon); + + if (path.substr(first_colon, 3) == "://") { + // these are not extensions + return ""; + } + + D_ASSERT(extension.size() > 1); + // needs to be alphanumeric + for (auto &ch : extension) { + if (!isalnum(ch) && ch != '_') { + return ""; + } + } + return extension; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/extension/extension_util.cpp b/src/duckdb/src/main/extension/extension_util.cpp new file mode 100644 index 00000000..383e8bbe --- /dev/null +++ b/src/duckdb/src/main/extension/extension_util.cpp @@ -0,0 +1,162 @@ +#include "duckdb/main/extension_util.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" +#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/create_collation_info.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, ScalarFunctionSet set) { + D_ASSERT(!set.name.empty()); + CreateScalarFunctionInfo info(std::move(set)); + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + system_catalog.CreateFunction(data, info); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, ScalarFunction function) { + D_ASSERT(!function.name.empty()); + ScalarFunctionSet set(function.name); + set.AddFunction(std::move(function)); + RegisterFunction(db, std::move(set)); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, AggregateFunction function) { + D_ASSERT(!function.name.empty()); + AggregateFunctionSet set(function.name); + set.AddFunction(std::move(function)); + RegisterFunction(db, std::move(set)); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, AggregateFunctionSet set) { + D_ASSERT(!set.name.empty()); + CreateAggregateFunctionInfo info(std::move(set)); + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + system_catalog.CreateFunction(data, info); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, TableFunction function) { + D_ASSERT(!function.name.empty()); + TableFunctionSet set(function.name); + set.AddFunction(std::move(function)); + RegisterFunction(db, std::move(set)); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, TableFunctionSet function) { + D_ASSERT(!function.name.empty()); + CreateTableFunctionInfo info(std::move(function)); + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + system_catalog.CreateFunction(data, info); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, PragmaFunction function) { + D_ASSERT(!function.name.empty()); + PragmaFunctionSet set(function.name); + set.AddFunction(std::move(function)); + RegisterFunction(db, std::move(set)); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, PragmaFunctionSet function) { + D_ASSERT(!function.name.empty()); + auto function_name = function.name; + CreatePragmaFunctionInfo info(std::move(function_name), std::move(function)); + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + system_catalog.CreatePragmaFunction(data, info); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, CopyFunction function) { + CreateCopyFunctionInfo info(std::move(function)); + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + system_catalog.CreateCopyFunction(data, info); +} + +void ExtensionUtil::RegisterFunction(DatabaseInstance &db, CreateMacroInfo &info) { + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + system_catalog.CreateFunction(data, info); +} + +void ExtensionUtil::RegisterCollation(DatabaseInstance &db, CreateCollationInfo &info) { + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; + system_catalog.CreateCollation(data, info); +} + +void ExtensionUtil::AddFunctionOverload(DatabaseInstance &db, ScalarFunction function) { + auto &scalar_function = ExtensionUtil::GetFunction(db, function.name); + scalar_function.functions.AddFunction(std::move(function)); +} + +void ExtensionUtil::AddFunctionOverload(DatabaseInstance &db, ScalarFunctionSet functions) { // NOLINT + D_ASSERT(!functions.name.empty()); + auto &scalar_function = ExtensionUtil::GetFunction(db, functions.name); + for (auto &function : functions.functions) { + function.name = functions.name; + scalar_function.functions.AddFunction(std::move(function)); + } +} + +void ExtensionUtil::AddFunctionOverload(DatabaseInstance &db, TableFunctionSet functions) { // NOLINT + auto &table_function = ExtensionUtil::GetTableFunction(db, functions.name); + for (auto &function : functions.functions) { + function.name = functions.name; + table_function.functions.AddFunction(std::move(function)); + } +} + +ScalarFunctionCatalogEntry &ExtensionUtil::GetFunction(DatabaseInstance &db, const string &name) { + D_ASSERT(!name.empty()); + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + auto &schema = system_catalog.GetSchema(data, DEFAULT_SCHEMA); + auto catalog_entry = schema.GetEntry(data, CatalogType::SCALAR_FUNCTION_ENTRY, name); + if (!catalog_entry) { + throw InvalidInputException("Function with name \"%s\" not found in ExtensionUtil::GetFunction", name); + } + return catalog_entry->Cast(); +} + +TableFunctionCatalogEntry &ExtensionUtil::GetTableFunction(DatabaseInstance &db, const string &name) { + D_ASSERT(!name.empty()); + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + auto &schema = system_catalog.GetSchema(data, DEFAULT_SCHEMA); + auto catalog_entry = schema.GetEntry(data, CatalogType::TABLE_FUNCTION_ENTRY, name); + if (!catalog_entry) { + throw InvalidInputException("Function with name \"%s\" not found in ExtensionUtil::GetTableFunction", name); + } + return catalog_entry->Cast(); +} + +void ExtensionUtil::RegisterType(DatabaseInstance &db, string type_name, LogicalType type) { + D_ASSERT(!type_name.empty()); + CreateTypeInfo info(std::move(type_name), std::move(type)); + info.temporary = true; + info.internal = true; + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + system_catalog.CreateType(data, info); +} + +void ExtensionUtil::RegisterCastFunction(DatabaseInstance &db, const LogicalType &source, const LogicalType &target, + BoundCastInfo function, int64_t implicit_cast_cost) { + auto &config = DBConfig::GetConfig(db); + auto &casts = config.GetCastFunctions(); + casts.RegisterCastFunction(source, target, std::move(function), implicit_cast_cost); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/materialized_query_result.cpp b/src/duckdb/src/main/materialized_query_result.cpp new file mode 100644 index 00000000..a3a5d9b3 --- /dev/null +++ b/src/duckdb/src/main/materialized_query_result.cpp @@ -0,0 +1,98 @@ +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/box_renderer.hpp" + +namespace duckdb { + +MaterializedQueryResult::MaterializedQueryResult(StatementType statement_type, StatementProperties properties, + vector names_p, unique_ptr collection_p, + ClientProperties client_properties) + : QueryResult(QueryResultType::MATERIALIZED_RESULT, statement_type, std::move(properties), collection_p->Types(), + std::move(names_p), std::move(client_properties)), + collection(std::move(collection_p)), scan_initialized(false) { +} + +MaterializedQueryResult::MaterializedQueryResult(PreservedError error) + : QueryResult(QueryResultType::MATERIALIZED_RESULT, std::move(error)), scan_initialized(false) { +} + +string MaterializedQueryResult::ToString() { + string result; + if (success) { + result = HeaderToString(); + result += "[ Rows: " + to_string(collection->Count()) + "]\n"; + auto &coll = Collection(); + for (auto &row : coll.Rows()) { + for (idx_t col_idx = 0; col_idx < coll.ColumnCount(); col_idx++) { + if (col_idx > 0) { + result += "\t"; + } + auto val = row.GetValue(col_idx); + result += val.IsNull() ? "NULL" : StringUtil::Replace(val.ToString(), string("\0", 1), "\\0"); + } + result += "\n"; + } + result += "\n"; + } else { + result = GetError() + "\n"; + } + return result; +} + +string MaterializedQueryResult::ToBox(ClientContext &context, const BoxRendererConfig &config) { + if (!success) { + return GetError() + "\n"; + } + if (!collection) { + return "Internal error - result was successful but there was no collection"; + } + BoxRenderer renderer(config); + return renderer.ToString(context, names, Collection()); +} + +Value MaterializedQueryResult::GetValue(idx_t column, idx_t index) { + if (!row_collection) { + row_collection = make_uniq(collection->GetRows()); + } + return row_collection->GetValue(column, index); +} + +idx_t MaterializedQueryResult::RowCount() const { + return collection ? collection->Count() : 0; +} + +ColumnDataCollection &MaterializedQueryResult::Collection() { + if (HasError()) { + throw InvalidInputException("Attempting to get collection from an unsuccessful query result\n: Error %s", + GetError()); + } + if (!collection) { + throw InternalException("Missing collection from materialized query result"); + } + return *collection; +} + +unique_ptr MaterializedQueryResult::Fetch() { + return FetchRaw(); +} + +unique_ptr MaterializedQueryResult::FetchRaw() { + if (HasError()) { + throw InvalidInputException("Attempting to fetch from an unsuccessful query result\nError: %s", GetError()); + } + auto result = make_uniq(); + collection->InitializeScanChunk(*result); + if (!scan_initialized) { + // we disallow zero copy so the chunk is independently usable even after the result is destroyed + collection->InitializeScan(scan_state, ColumnDataScanProperties::DISALLOW_ZERO_COPY); + scan_initialized = true; + } + collection->Scan(scan_state, *result); + if (result->size() == 0) { + return nullptr; + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/pending_query_result.cpp b/src/duckdb/src/main/pending_query_result.cpp new file mode 100644 index 00000000..02501ca4 --- /dev/null +++ b/src/duckdb/src/main/pending_query_result.cpp @@ -0,0 +1,85 @@ +#include "duckdb/main/pending_query_result.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/prepared_statement_data.hpp" + +namespace duckdb { + +PendingQueryResult::PendingQueryResult(shared_ptr context_p, PreparedStatementData &statement, + vector types_p, bool allow_stream_result) + : BaseQueryResult(QueryResultType::PENDING_RESULT, statement.statement_type, statement.properties, + std::move(types_p), statement.names), + context(std::move(context_p)), allow_stream_result(allow_stream_result) { +} + +PendingQueryResult::PendingQueryResult(PreservedError error) + : BaseQueryResult(QueryResultType::PENDING_RESULT, std::move(error)) { +} + +PendingQueryResult::~PendingQueryResult() { +} + +unique_ptr PendingQueryResult::LockContext() { + if (!context) { + if (HasError()) { + throw InvalidInputException( + "Attempting to execute an unsuccessful or closed pending query result\nError: %s", GetError()); + } + throw InvalidInputException("Attempting to execute an unsuccessful or closed pending query result"); + } + return context->LockContext(); +} + +void PendingQueryResult::CheckExecutableInternal(ClientContextLock &lock) { + bool invalidated = HasError() || !context; + if (!invalidated) { + invalidated = !context->IsActiveResult(lock, this); + } + if (invalidated) { + if (HasError()) { + throw InvalidInputException( + "Attempting to execute an unsuccessful or closed pending query result\nError: %s", GetError()); + } + throw InvalidInputException("Attempting to execute an unsuccessful or closed pending query result"); + } +} + +PendingExecutionResult PendingQueryResult::ExecuteTask() { + auto lock = LockContext(); + return ExecuteTaskInternal(*lock); +} + +PendingExecutionResult PendingQueryResult::ExecuteTaskInternal(ClientContextLock &lock) { + CheckExecutableInternal(lock); + return context->ExecuteTaskInternal(lock, *this); +} + +unique_ptr PendingQueryResult::ExecuteInternal(ClientContextLock &lock) { + CheckExecutableInternal(lock); + // Busy wait while execution is not finished + while (!IsFinished(ExecuteTaskInternal(lock))) { + } + if (HasError()) { + return make_uniq(error); + } + auto result = context->FetchResultInternal(lock, *this); + Close(); + return result; +} + +unique_ptr PendingQueryResult::Execute() { + auto lock = LockContext(); + return ExecuteInternal(*lock); +} + +void PendingQueryResult::Close() { + context.reset(); +} + +bool PendingQueryResult::IsFinished(PendingExecutionResult result) { + if (result == PendingExecutionResult::RESULT_READY || result == PendingExecutionResult::EXECUTION_ERROR) { + return true; + } + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/prepared_statement.cpp b/src/duckdb/src/main/prepared_statement.cpp new file mode 100644 index 00000000..a50997c2 --- /dev/null +++ b/src/duckdb/src/main/prepared_statement.cpp @@ -0,0 +1,119 @@ +#include "duckdb/main/prepared_statement.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/prepared_statement_data.hpp" + +namespace duckdb { + +PreparedStatement::PreparedStatement(shared_ptr context, shared_ptr data_p, + string query, idx_t n_param, case_insensitive_map_t named_param_map_p) + : context(std::move(context)), data(std::move(data_p)), query(std::move(query)), success(true), n_param(n_param), + named_param_map(std::move(named_param_map_p)) { + D_ASSERT(data || !success); +} + +PreparedStatement::PreparedStatement(PreservedError error) : context(nullptr), success(false), error(std::move(error)) { +} + +PreparedStatement::~PreparedStatement() { +} + +const string &PreparedStatement::GetError() { + D_ASSERT(HasError()); + return error.Message(); +} + +PreservedError &PreparedStatement::GetErrorObject() { + return error; +} + +bool PreparedStatement::HasError() const { + return !success; +} + +idx_t PreparedStatement::ColumnCount() { + D_ASSERT(data); + return data->types.size(); +} + +StatementType PreparedStatement::GetStatementType() { + D_ASSERT(data); + return data->statement_type; +} + +StatementProperties PreparedStatement::GetStatementProperties() { + D_ASSERT(data); + return data->properties; +} + +const vector &PreparedStatement::GetTypes() { + D_ASSERT(data); + return data->types; +} + +const vector &PreparedStatement::GetNames() { + D_ASSERT(data); + return data->names; +} + +case_insensitive_map_t PreparedStatement::GetExpectedParameterTypes() const { + D_ASSERT(data); + case_insensitive_map_t expected_types(data->value_map.size()); + for (auto &it : data->value_map) { + auto &identifier = it.first; + D_ASSERT(data->value_map.count(identifier)); + D_ASSERT(it.second); + expected_types[identifier] = it.second->GetValue().type(); + } + return expected_types; +} + +unique_ptr PreparedStatement::Execute(case_insensitive_map_t &named_values, + bool allow_stream_result) { + auto pending = PendingQuery(named_values, allow_stream_result); + if (pending->HasError()) { + return make_uniq(pending->GetErrorObject()); + } + return pending->Execute(); +} + +unique_ptr PreparedStatement::Execute(vector &values, bool allow_stream_result) { + auto pending = PendingQuery(values, allow_stream_result); + if (pending->HasError()) { + return make_uniq(pending->GetErrorObject()); + } + return pending->Execute(); +} + +unique_ptr PreparedStatement::PendingQuery(vector &values, bool allow_stream_result) { + case_insensitive_map_t named_values; + for (idx_t i = 0; i < values.size(); i++) { + auto &val = values[i]; + named_values[std::to_string(i + 1)] = val; + } + return PendingQuery(named_values, allow_stream_result); +} + +unique_ptr PreparedStatement::PendingQuery(case_insensitive_map_t &named_values, + bool allow_stream_result) { + if (!success) { + auto exception = InvalidInputException("Attempting to execute an unsuccessfully prepared statement!"); + return make_uniq(PreservedError(exception)); + } + PendingQueryParameters parameters; + parameters.parameters = &named_values; + + try { + VerifyParameters(named_values, named_param_map); + } catch (const Exception &ex) { + return make_uniq(PreservedError(ex)); + } + + D_ASSERT(data); + parameters.allow_stream_result = allow_stream_result && data->properties.allow_stream_result; + auto result = context->PendingQuery(query, data, parameters); + // The result should not contain any reference to the 'vector parameters.parameters' + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/prepared_statement_data.cpp b/src/duckdb/src/main/prepared_statement_data.cpp new file mode 100644 index 00000000..4bc356ce --- /dev/null +++ b/src/duckdb/src/main/prepared_statement_data.cpp @@ -0,0 +1,91 @@ +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +PreparedStatementData::PreparedStatementData(StatementType type) : statement_type(type) { +} + +PreparedStatementData::~PreparedStatementData() { +} + +void PreparedStatementData::CheckParameterCount(idx_t parameter_count) { + const auto required = properties.parameter_count; + if (parameter_count != required) { + throw BinderException("Parameter/argument count mismatch for prepared statement. Expected %llu, got %llu", + required, parameter_count); + } +} + +bool PreparedStatementData::RequireRebind(ClientContext &context, optional_ptr> values) { + idx_t count = values ? values->size() : 0; + CheckParameterCount(count); + if (!unbound_statement) { + // no unbound statement!? cannot rebind? + return false; + } + if (!properties.bound_all_parameters) { + // parameters not yet bound: query always requires a rebind + return true; + } + if (Catalog::GetSystemCatalog(context).GetCatalogVersion() != catalog_version) { + //! context is out of bounds + return true; + } + for (auto &it : value_map) { + auto &identifier = it.first; + auto lookup = values->find(identifier); + D_ASSERT(lookup != values->end()); + if (lookup->second.type() != it.second->return_type) { + return true; + } + } + return false; +} + +void PreparedStatementData::Bind(case_insensitive_map_t values) { + // set parameters + D_ASSERT(!unbound_statement || unbound_statement->n_param == properties.parameter_count); + CheckParameterCount(values.size()); + + // bind the required values + for (auto &it : value_map) { + const string &identifier = it.first; + auto lookup = values.find(identifier); + if (lookup == values.end()) { + throw BinderException("Could not find parameter with identifier %s", identifier); + } + D_ASSERT(it.second); + auto &value = lookup->second; + if (!value.DefaultTryCastAs(it.second->return_type)) { + throw BinderException( + "Type mismatch for binding parameter with identifier %s, expected type %s but got type %s", identifier, + it.second->return_type.ToString().c_str(), value.type().ToString().c_str()); + } + it.second->SetValue(value); + } +} + +bool PreparedStatementData::TryGetType(const string &identifier, LogicalType &result) { + auto it = value_map.find(identifier); + if (it == value_map.end()) { + return false; + } + if (it->second->return_type.id() != LogicalTypeId::INVALID) { + result = it->second->return_type; + } else { + result = it->second->GetValue().type(); + } + return true; +} + +LogicalType PreparedStatementData::GetType(const string &identifier) { + LogicalType result; + if (!TryGetType(identifier, result)) { + throw BinderException("Could not find parameter identified with: %s", identifier); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/query_profiler.cpp b/src/duckdb/src/main/query_profiler.cpp new file mode 100644 index 00000000..26b17f04 --- /dev/null +++ b/src/duckdb/src/main/query_profiler.cpp @@ -0,0 +1,710 @@ +#include "duckdb/main/query_profiler.hpp" + +#include "duckdb/common/fstream.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/helper/physical_execute.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +#include +#include + +namespace duckdb { + +QueryProfiler::QueryProfiler(ClientContext &context_p) + : context(context_p), running(false), query_requires_profiling(false), is_explain_analyze(false) { +} + +bool QueryProfiler::IsEnabled() const { + return is_explain_analyze ? true : ClientConfig::GetConfig(context).enable_profiler; +} + +bool QueryProfiler::IsDetailedEnabled() const { + return is_explain_analyze ? false : ClientConfig::GetConfig(context).enable_detailed_profiling; +} + +ProfilerPrintFormat QueryProfiler::GetPrintFormat() const { + return ClientConfig::GetConfig(context).profiler_print_format; +} + +bool QueryProfiler::PrintOptimizerOutput() const { + return GetPrintFormat() == ProfilerPrintFormat::QUERY_TREE_OPTIMIZER || IsDetailedEnabled(); +} + +string QueryProfiler::GetSaveLocation() const { + return is_explain_analyze ? string() : ClientConfig::GetConfig(context).profiler_save_location; +} + +QueryProfiler &QueryProfiler::Get(ClientContext &context) { + return *ClientData::Get(context).profiler; +} + +void QueryProfiler::StartQuery(string query, bool is_explain_analyze, bool start_at_optimizer) { + if (is_explain_analyze) { + StartExplainAnalyze(); + } + if (!IsEnabled()) { + return; + } + if (start_at_optimizer && !PrintOptimizerOutput()) { + // This is the StartQuery call before the optimizer, but we don't have to print optimizer output + return; + } + if (running) { + // Called while already running: this should only happen when we print optimizer output + D_ASSERT(PrintOptimizerOutput()); + return; + } + this->running = true; + this->query = std::move(query); + tree_map.clear(); + root = nullptr; + phase_timings.clear(); + phase_stack.clear(); + + main_query.Start(); +} + +bool QueryProfiler::OperatorRequiresProfiling(PhysicalOperatorType op_type) { + switch (op_type) { + case PhysicalOperatorType::ORDER_BY: + case PhysicalOperatorType::RESERVOIR_SAMPLE: + case PhysicalOperatorType::STREAMING_SAMPLE: + case PhysicalOperatorType::LIMIT: + case PhysicalOperatorType::LIMIT_PERCENT: + case PhysicalOperatorType::STREAMING_LIMIT: + case PhysicalOperatorType::TOP_N: + case PhysicalOperatorType::WINDOW: + case PhysicalOperatorType::UNNEST: + case PhysicalOperatorType::UNGROUPED_AGGREGATE: + case PhysicalOperatorType::HASH_GROUP_BY: + case PhysicalOperatorType::FILTER: + case PhysicalOperatorType::PROJECTION: + case PhysicalOperatorType::COPY_TO_FILE: + case PhysicalOperatorType::TABLE_SCAN: + case PhysicalOperatorType::CHUNK_SCAN: + case PhysicalOperatorType::DELIM_SCAN: + case PhysicalOperatorType::EXPRESSION_SCAN: + case PhysicalOperatorType::BLOCKWISE_NL_JOIN: + case PhysicalOperatorType::NESTED_LOOP_JOIN: + case PhysicalOperatorType::HASH_JOIN: + case PhysicalOperatorType::CROSS_PRODUCT: + case PhysicalOperatorType::PIECEWISE_MERGE_JOIN: + case PhysicalOperatorType::IE_JOIN: + case PhysicalOperatorType::DELIM_JOIN: + case PhysicalOperatorType::UNION: + case PhysicalOperatorType::RECURSIVE_CTE: + case PhysicalOperatorType::EMPTY_RESULT: + return true; + default: + return false; + } +} + +void QueryProfiler::Finalize(TreeNode &node) { + for (auto &child : node.children) { + Finalize(*child); + if (node.type == PhysicalOperatorType::UNION) { + node.info.elements += child->info.elements; + } + } +} + +void QueryProfiler::StartExplainAnalyze() { + this->is_explain_analyze = true; +} + +void QueryProfiler::EndQuery() { + lock_guard guard(flush_lock); + if (!IsEnabled() || !running) { + return; + } + + main_query.End(); + if (root) { + Finalize(*root); + } + this->running = false; + // print or output the query profiling after termination + // EXPLAIN ANALYSE should not be outputted by the profiler + if (IsEnabled() && !is_explain_analyze) { + string query_info = ToString(); + auto save_location = GetSaveLocation(); + if (!ClientConfig::GetConfig(context).emit_profiler_output) { + // disable output + } else if (save_location.empty()) { + Printer::Print(query_info); + Printer::Print("\n"); + } else { + WriteToFile(save_location.c_str(), query_info); + } + } + this->is_explain_analyze = false; +} +string QueryProfiler::ToString() const { + const auto format = GetPrintFormat(); + switch (format) { + case ProfilerPrintFormat::QUERY_TREE: + case ProfilerPrintFormat::QUERY_TREE_OPTIMIZER: + return QueryTreeToString(); + case ProfilerPrintFormat::JSON: + return ToJSON(); + default: + throw InternalException("Unknown ProfilerPrintFormat \"%s\"", format); + } +} + +void QueryProfiler::StartPhase(string new_phase) { + if (!IsEnabled() || !running) { + return; + } + + if (!phase_stack.empty()) { + // there are active phases + phase_profiler.End(); + // add the timing to all phases prior to this one + string prefix = ""; + for (auto &phase : phase_stack) { + phase_timings[phase] += phase_profiler.Elapsed(); + prefix += phase + " > "; + } + // when there are previous phases, we prefix the current phase with those phases + new_phase = prefix + new_phase; + } + + // start a new phase + phase_stack.push_back(new_phase); + // restart the timer + phase_profiler.Start(); +} + +void QueryProfiler::EndPhase() { + if (!IsEnabled() || !running) { + return; + } + D_ASSERT(phase_stack.size() > 0); + + // end the timer + phase_profiler.End(); + // add the timing to all currently active phases + for (auto &phase : phase_stack) { + phase_timings[phase] += phase_profiler.Elapsed(); + } + // now remove the last added phase + phase_stack.pop_back(); + + if (!phase_stack.empty()) { + phase_profiler.Start(); + } +} + +void QueryProfiler::Initialize(const PhysicalOperator &root_op) { + if (!IsEnabled() || !running) { + return; + } + this->query_requires_profiling = false; + this->root = CreateTree(root_op); + if (!query_requires_profiling) { + // query does not require profiling: disable profiling for this query + this->running = false; + tree_map.clear(); + root = nullptr; + phase_timings.clear(); + phase_stack.clear(); + } +} + +OperatorProfiler::OperatorProfiler(bool enabled_p) : enabled(enabled_p), active_operator(nullptr) { +} + +void OperatorProfiler::StartOperator(optional_ptr phys_op) { + if (!enabled) { + return; + } + + if (active_operator) { + throw InternalException("OperatorProfiler: Attempting to call StartOperator while another operator is active"); + } + + active_operator = phys_op; + + // start timing for current element + op.Start(); +} + +void OperatorProfiler::EndOperator(optional_ptr chunk) { + if (!enabled) { + return; + } + + if (!active_operator) { + throw InternalException("OperatorProfiler: Attempting to call EndOperator while another operator is active"); + } + + // finish timing for the current element + op.End(); + + AddTiming(*active_operator, op.Elapsed(), chunk ? chunk->size() : 0); + active_operator = nullptr; +} + +void OperatorProfiler::AddTiming(const PhysicalOperator &op, double time, idx_t elements) { + if (!enabled) { + return; + } + if (!Value::DoubleIsFinite(time)) { + return; + } + auto entry = timings.find(op); + if (entry == timings.end()) { + // add new entry + timings[op] = OperatorInformation(time, elements); + } else { + // add to existing entry + entry->second.time += time; + entry->second.elements += elements; + } +} +void OperatorProfiler::Flush(const PhysicalOperator &phys_op, ExpressionExecutor &expression_executor, + const string &name, int id) { + auto entry = timings.find(phys_op); + if (entry == timings.end()) { + return; + } + auto &operator_timing = timings.find(phys_op)->second; + if (int(operator_timing.executors_info.size()) <= id) { + operator_timing.executors_info.resize(id + 1); + } + operator_timing.executors_info[id] = make_uniq(expression_executor, name, id); + operator_timing.name = phys_op.GetName(); +} + +void QueryProfiler::Flush(OperatorProfiler &profiler) { + lock_guard guard(flush_lock); + if (!IsEnabled() || !running) { + return; + } + for (auto &node : profiler.timings) { + auto &op = node.first.get(); + auto entry = tree_map.find(op); + D_ASSERT(entry != tree_map.end()); + auto &tree_node = entry->second.get(); + + tree_node.info.time += node.second.time; + tree_node.info.elements += node.second.elements; + if (!IsDetailedEnabled()) { + continue; + } + for (auto &info : node.second.executors_info) { + if (!info) { + continue; + } + auto info_id = info->id; + if (int32_t(tree_node.info.executors_info.size()) <= info_id) { + tree_node.info.executors_info.resize(info_id + 1); + } + tree_node.info.executors_info[info_id] = std::move(info); + } + } + profiler.timings.clear(); +} + +static string DrawPadded(const string &str, idx_t width) { + if (str.size() > width) { + return str.substr(0, width); + } else { + width -= str.size(); + int half_spaces = width / 2; + int extra_left_space = width % 2 != 0 ? 1 : 0; + return string(half_spaces + extra_left_space, ' ') + str + string(half_spaces, ' '); + } +} + +static string RenderTitleCase(string str) { + str = StringUtil::Lower(str); + str[0] = toupper(str[0]); + for (idx_t i = 0; i < str.size(); i++) { + if (str[i] == '_') { + str[i] = ' '; + if (i + 1 < str.size()) { + str[i + 1] = toupper(str[i + 1]); + } + } + } + return str; +} + +static string RenderTiming(double timing) { + string timing_s; + if (timing >= 1) { + timing_s = StringUtil::Format("%.2f", timing); + } else if (timing >= 0.1) { + timing_s = StringUtil::Format("%.3f", timing); + } else { + timing_s = StringUtil::Format("%.4f", timing); + } + return timing_s + "s"; +} + +string QueryProfiler::QueryTreeToString() const { + std::stringstream str; + QueryTreeToStream(str); + return str.str(); +} + +void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { + if (!IsEnabled()) { + ss << "Query profiling is disabled. Call " + "Connection::EnableProfiling() to enable profiling!"; + return; + } + ss << "┌─────────────────────────────────────┐\n"; + ss << "│┌───────────────────────────────────┐│\n"; + ss << "││ Query Profiling Information ││\n"; + ss << "│└───────────────────────────────────┘│\n"; + ss << "└─────────────────────────────────────┘\n"; + ss << StringUtil::Replace(query, "\n", " ") + "\n"; + + // checking the tree to ensure the query is really empty + // the query string is empty when a logical plan is deserialized + if (query.empty() && !root) { + return; + } + + if (context.client_data->http_state && !context.client_data->http_state->IsEmpty()) { + string read = + "in: " + StringUtil::BytesToHumanReadableString(context.client_data->http_state->total_bytes_received); + string written = + "out: " + StringUtil::BytesToHumanReadableString(context.client_data->http_state->total_bytes_sent); + string head = "#HEAD: " + to_string(context.client_data->http_state->head_count); + string get = "#GET: " + to_string(context.client_data->http_state->get_count); + string put = "#PUT: " + to_string(context.client_data->http_state->put_count); + string post = "#POST: " + to_string(context.client_data->http_state->post_count); + + constexpr idx_t TOTAL_BOX_WIDTH = 39; + ss << "┌─────────────────────────────────────┐\n"; + ss << "│┌───────────────────────────────────┐│\n"; + ss << "││ HTTP Stats: ││\n"; + ss << "││ ││\n"; + ss << "││" + DrawPadded(read, TOTAL_BOX_WIDTH - 4) + "││\n"; + ss << "││" + DrawPadded(written, TOTAL_BOX_WIDTH - 4) + "││\n"; + ss << "││" + DrawPadded(head, TOTAL_BOX_WIDTH - 4) + "││\n"; + ss << "││" + DrawPadded(get, TOTAL_BOX_WIDTH - 4) + "││\n"; + ss << "││" + DrawPadded(put, TOTAL_BOX_WIDTH - 4) + "││\n"; + ss << "││" + DrawPadded(post, TOTAL_BOX_WIDTH - 4) + "││\n"; + ss << "│└───────────────────────────────────┘│\n"; + ss << "└─────────────────────────────────────┘\n"; + } + + constexpr idx_t TOTAL_BOX_WIDTH = 39; + ss << "┌─────────────────────────────────────┐\n"; + ss << "│┌───────────────────────────────────┐│\n"; + string total_time = "Total Time: " + RenderTiming(main_query.Elapsed()); + ss << "││" + DrawPadded(total_time, TOTAL_BOX_WIDTH - 4) + "││\n"; + ss << "│└───────────────────────────────────┘│\n"; + ss << "└─────────────────────────────────────┘\n"; + // print phase timings + if (PrintOptimizerOutput()) { + bool has_previous_phase = false; + for (const auto &entry : GetOrderedPhaseTimings()) { + if (!StringUtil::Contains(entry.first, " > ")) { + // primary phase! + if (has_previous_phase) { + ss << "│└───────────────────────────────────┘│\n"; + ss << "└─────────────────────────────────────┘\n"; + } + ss << "┌─────────────────────────────────────┐\n"; + ss << "│" + + DrawPadded(RenderTitleCase(entry.first) + ": " + RenderTiming(entry.second), + TOTAL_BOX_WIDTH - 2) + + "│\n"; + ss << "│┌───────────────────────────────────┐│\n"; + has_previous_phase = true; + } else { + string entry_name = StringUtil::Split(entry.first, " > ")[1]; + ss << "││" + + DrawPadded(RenderTitleCase(entry_name) + ": " + RenderTiming(entry.second), + TOTAL_BOX_WIDTH - 4) + + "││\n"; + } + } + if (has_previous_phase) { + ss << "│└───────────────────────────────────┘│\n"; + ss << "└─────────────────────────────────────┘\n"; + } + } + // render the main operator tree + if (root) { + Render(*root, ss); + } +} + +static string JSONSanitize(const string &text) { + string result; + result.reserve(text.size()); + for (idx_t i = 0; i < text.size(); i++) { + switch (text[i]) { + case '\b': + result += "\\b"; + break; + case '\f': + result += "\\f"; + break; + case '\n': + result += "\\n"; + break; + case '\r': + result += "\\r"; + break; + case '\t': + result += "\\t"; + break; + case '"': + result += "\\\""; + break; + case '\\': + result += "\\\\"; + break; + default: + result += text[i]; + break; + } + } + return result; +} + +// Print a row +static void PrintRow(std::ostream &ss, const string &annotation, int id, const string &name, double time, + int sample_counter, int tuple_counter, const string &extra_info, int depth) { + ss << string(depth * 3, ' ') << " {\n"; + ss << string(depth * 3, ' ') << " \"annotation\": \"" + JSONSanitize(annotation) + "\",\n"; + ss << string(depth * 3, ' ') << " \"id\": " + to_string(id) + ",\n"; + ss << string(depth * 3, ' ') << " \"name\": \"" + JSONSanitize(name) + "\",\n"; +#if defined(RDTSC) + ss << string(depth * 3, ' ') << " \"timing\": \"NULL\" ,\n"; + ss << string(depth * 3, ' ') << " \"cycles_per_tuple\": " + StringUtil::Format("%.4f", time) + ",\n"; +#else + ss << string(depth * 3, ' ') << " \"timing\":" + to_string(time) + ",\n"; + ss << string(depth * 3, ' ') << " \"cycles_per_tuple\": \"NULL\" ,\n"; +#endif + ss << string(depth * 3, ' ') << " \"sample_size\": " << to_string(sample_counter) + ",\n"; + ss << string(depth * 3, ' ') << " \"input_size\": " << to_string(tuple_counter) + ",\n"; + ss << string(depth * 3, ' ') << " \"extra_info\": \"" << JSONSanitize(extra_info) + "\"\n"; + ss << string(depth * 3, ' ') << " },\n"; +} + +static void ExtractFunctions(std::ostream &ss, ExpressionInfo &info, int &fun_id, int depth) { + if (info.hasfunction) { + double time = info.sample_tuples_count == 0 ? 0 : int(info.function_time) / double(info.sample_tuples_count); + PrintRow(ss, "Function", fun_id++, info.function_name, time, info.sample_tuples_count, info.tuples_count, "", + depth); + } + if (info.children.empty()) { + return; + } + // extract the children of this node + for (auto &child : info.children) { + ExtractFunctions(ss, *child, fun_id, depth); + } +} + +static void ToJSONRecursive(QueryProfiler::TreeNode &node, std::ostream &ss, int depth = 1) { + ss << string(depth * 3, ' ') << " {\n"; + ss << string(depth * 3, ' ') << " \"name\": \"" + JSONSanitize(node.name) + "\",\n"; + ss << string(depth * 3, ' ') << " \"timing\":" + to_string(node.info.time) + ",\n"; + ss << string(depth * 3, ' ') << " \"cardinality\":" + to_string(node.info.elements) + ",\n"; + ss << string(depth * 3, ' ') << " \"extra_info\": \"" + JSONSanitize(node.extra_info) + "\",\n"; + ss << string(depth * 3, ' ') << " \"timings\": ["; + int32_t function_counter = 1; + int32_t expression_counter = 1; + ss << "\n "; + for (auto &expr_executor : node.info.executors_info) { + // For each Expression tree + if (!expr_executor) { + continue; + } + for (auto &expr_timer : expr_executor->roots) { + double time = expr_timer->sample_tuples_count == 0 + ? 0 + : double(expr_timer->time) / double(expr_timer->sample_tuples_count); + PrintRow(ss, "ExpressionRoot", expression_counter++, expr_timer->name, time, + expr_timer->sample_tuples_count, expr_timer->tuples_count, expr_timer->extra_info, depth + 1); + // Extract all functions inside the tree + ExtractFunctions(ss, *expr_timer->root, function_counter, depth + 1); + } + } + ss.seekp(-2, ss.cur); + ss << "\n"; + ss << string(depth * 3, ' ') << " ],\n"; + ss << string(depth * 3, ' ') << " \"children\": [\n"; + if (node.children.empty()) { + ss << string(depth * 3, ' ') << " ]\n"; + } else { + for (idx_t i = 0; i < node.children.size(); i++) { + if (i > 0) { + ss << ",\n"; + } + ToJSONRecursive(*node.children[i], ss, depth + 1); + } + ss << string(depth * 3, ' ') << " ]\n"; + } + ss << string(depth * 3, ' ') << " }\n"; +} + +string QueryProfiler::ToJSON() const { + if (!IsEnabled()) { + return "{ \"result\": \"disabled\" }\n"; + } + if (query.empty() && !root) { + return "{ \"result\": \"empty\" }\n"; + } + if (!root) { + return "{ \"result\": \"error\" }\n"; + } + std::stringstream ss; + ss << "{\n"; + ss << " \"name\": \"Query\", \n"; + ss << " \"result\": " + to_string(main_query.Elapsed()) + ",\n"; + ss << " \"timing\": " + to_string(main_query.Elapsed()) + ",\n"; + ss << " \"cardinality\": " + to_string(root->info.elements) + ",\n"; + // JSON cannot have literal control characters in string literals + string extra_info = JSONSanitize(query); + ss << " \"extra-info\": \"" + extra_info + "\", \n"; + // print the phase timings + ss << " \"timings\": [\n"; + const auto &ordered_phase_timings = GetOrderedPhaseTimings(); + for (idx_t i = 0; i < ordered_phase_timings.size(); i++) { + if (i > 0) { + ss << ",\n"; + } + ss << " {\n"; + ss << " \"annotation\": \"" + ordered_phase_timings[i].first + "\", \n"; + ss << " \"timing\": " + to_string(ordered_phase_timings[i].second) + "\n"; + ss << " }"; + } + ss << "\n"; + ss << " ],\n"; + // recursively print the physical operator tree + ss << " \"children\": [\n"; + ToJSONRecursive(*root, ss); + ss << " ]\n"; + ss << "}"; + return ss.str(); +} + +void QueryProfiler::WriteToFile(const char *path, string &info) const { + ofstream out(path); + out << info; + out.close(); + // throw an IO exception if it fails to write the file + if (out.fail()) { + throw IOException(strerror(errno)); + } +} + +unique_ptr QueryProfiler::CreateTree(const PhysicalOperator &root, idx_t depth) { + if (OperatorRequiresProfiling(root.type)) { + this->query_requires_profiling = true; + } + auto node = make_uniq(); + node->type = root.type; + node->name = root.GetName(); + node->extra_info = root.ParamsToString(); + node->depth = depth; + tree_map.insert(make_pair(reference(root), reference(*node))); + auto children = root.GetChildren(); + for (auto &child : children) { + auto child_node = CreateTree(child.get(), depth + 1); + node->children.push_back(std::move(child_node)); + } + return node; +} + +void QueryProfiler::Render(const QueryProfiler::TreeNode &node, std::ostream &ss) const { + TreeRenderer renderer; + if (IsDetailedEnabled()) { + renderer.EnableDetailed(); + } else { + renderer.EnableStandard(); + } + renderer.Render(node, ss); +} + +void QueryProfiler::Print() { + Printer::Print(QueryTreeToString()); +} + +vector QueryProfiler::GetOrderedPhaseTimings() const { + vector result; + // first sort the phases alphabetically + vector phases; + for (auto &entry : phase_timings) { + phases.push_back(entry.first); + } + std::sort(phases.begin(), phases.end()); + for (const auto &phase : phases) { + auto entry = phase_timings.find(phase); + D_ASSERT(entry != phase_timings.end()); + result.emplace_back(entry->first, entry->second); + } + return result; +} +void QueryProfiler::Propagate(QueryProfiler &qp) { +} + +void ExpressionInfo::ExtractExpressionsRecursive(unique_ptr &state) { + if (state->child_states.empty()) { + return; + } + // extract the children of this node + for (auto &child : state->child_states) { + auto expr_info = make_uniq(); + if (child->expr.expression_class == ExpressionClass::BOUND_FUNCTION) { + expr_info->hasfunction = true; + expr_info->function_name = child->expr.Cast().function.ToString(); + expr_info->function_time = child->profiler.time; + expr_info->sample_tuples_count = child->profiler.sample_tuples_count; + expr_info->tuples_count = child->profiler.tuples_count; + } + expr_info->ExtractExpressionsRecursive(child); + children.push_back(std::move(expr_info)); + } + return; +} + +ExpressionExecutorInfo::ExpressionExecutorInfo(ExpressionExecutor &executor, const string &name, int id) : id(id) { + // Extract Expression Root Information from ExpressionExecutorStats + for (auto &state : executor.GetStates()) { + roots.push_back(make_uniq(*state, name)); + } +} + +ExpressionRootInfo::ExpressionRootInfo(ExpressionExecutorState &state, string name) + : current_count(state.profiler.current_count), sample_count(state.profiler.sample_count), + sample_tuples_count(state.profiler.sample_tuples_count), tuples_count(state.profiler.tuples_count), + name("expression"), time(state.profiler.time) { + // Use the name of expression-tree as extra-info + extra_info = std::move(name); + auto expression_info_p = make_uniq(); + // Maybe root has a function + if (state.root_state->expr.expression_class == ExpressionClass::BOUND_FUNCTION) { + expression_info_p->hasfunction = true; + expression_info_p->function_name = (state.root_state->expr.Cast()).function.name; + expression_info_p->function_time = state.root_state->profiler.time; + expression_info_p->sample_tuples_count = state.root_state->profiler.sample_tuples_count; + expression_info_p->tuples_count = state.root_state->profiler.tuples_count; + } + expression_info_p->ExtractExpressionsRecursive(state.root_state); + root = std::move(expression_info_p); +} +} // namespace duckdb diff --git a/src/duckdb/src/main/query_result.cpp b/src/duckdb/src/main/query_result.cpp new file mode 100644 index 00000000..d5edc4b9 --- /dev/null +++ b/src/duckdb/src/main/query_result.cpp @@ -0,0 +1,160 @@ +#include "duckdb/main/query_result.hpp" + +#include "duckdb/common/box_renderer.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/main/client_context.hpp" +namespace duckdb { + +BaseQueryResult::BaseQueryResult(QueryResultType type, StatementType statement_type, StatementProperties properties_p, + vector types_p, vector names_p) + : type(type), statement_type(statement_type), properties(std::move(properties_p)), types(std::move(types_p)), + names(std::move(names_p)), success(true) { + D_ASSERT(types.size() == names.size()); +} + +BaseQueryResult::BaseQueryResult(QueryResultType type, PreservedError error) + : type(type), success(false), error(std::move(error)) { +} + +BaseQueryResult::~BaseQueryResult() { +} + +void BaseQueryResult::ThrowError(const string &prepended_message) const { + D_ASSERT(HasError()); + error.Throw(prepended_message); +} + +void BaseQueryResult::SetError(PreservedError error) { + success = !error; + this->error = std::move(error); +} + +bool BaseQueryResult::HasError() const { + D_ASSERT((bool)error == !success); + return !success; +} + +const ExceptionType &BaseQueryResult::GetErrorType() const { + return error.Type(); +} + +const std::string &BaseQueryResult::GetError() { + D_ASSERT(HasError()); + return error.Message(); +} + +PreservedError &BaseQueryResult::GetErrorObject() { + return error; +} + +idx_t BaseQueryResult::ColumnCount() { + return types.size(); +} + +QueryResult::QueryResult(QueryResultType type, StatementType statement_type, StatementProperties properties, + vector types_p, vector names_p, ClientProperties client_properties_p) + : BaseQueryResult(type, statement_type, std::move(properties), std::move(types_p), std::move(names_p)), + client_properties(std::move(client_properties_p)) { +} + +QueryResult::QueryResult(QueryResultType type, PreservedError error) + : BaseQueryResult(type, std::move(error)), client_properties("UTC", ArrowOffsetSize::REGULAR) { +} + +QueryResult::~QueryResult() { +} + +const string &QueryResult::ColumnName(idx_t index) const { + D_ASSERT(index < names.size()); + return names[index]; +} + +string QueryResult::ToBox(ClientContext &context, const BoxRendererConfig &config) { + return ToString(); +} + +unique_ptr QueryResult::Fetch() { + auto chunk = FetchRaw(); + if (!chunk) { + return nullptr; + } + chunk->Flatten(); + return chunk; +} + +bool QueryResult::Equals(QueryResult &other) { // LCOV_EXCL_START + // first compare the success state of the results + if (success != other.success) { + return false; + } + if (!success) { + return error == other.error; + } + // compare names + if (names != other.names) { + return false; + } + // compare types + if (types != other.types) { + return false; + } + // now compare the actual values + // fetch chunks + unique_ptr lchunk, rchunk; + idx_t lindex = 0, rindex = 0; + while (true) { + if (!lchunk || lindex == lchunk->size()) { + lchunk = Fetch(); + lindex = 0; + } + if (!rchunk || rindex == rchunk->size()) { + rchunk = other.Fetch(); + rindex = 0; + } + if (!lchunk && !rchunk) { + return true; + } + if (!lchunk || !rchunk) { + return false; + } + if (lchunk->size() == 0 && rchunk->size() == 0) { + return true; + } + D_ASSERT(lchunk->ColumnCount() == rchunk->ColumnCount()); + for (; lindex < lchunk->size() && rindex < rchunk->size(); lindex++, rindex++) { + for (idx_t col = 0; col < rchunk->ColumnCount(); col++) { + auto lvalue = lchunk->GetValue(col, lindex); + auto rvalue = rchunk->GetValue(col, rindex); + if (lvalue.IsNull() && rvalue.IsNull()) { + continue; + } + if (lvalue.IsNull() != rvalue.IsNull()) { + return false; + } + if (lvalue != rvalue) { + return false; + } + } + } + } +} // LCOV_EXCL_STOP + +void QueryResult::Print() { + Printer::Print(ToString()); +} + +string QueryResult::HeaderToString() { + string result; + for (auto &name : names) { + result += name + "\t"; + } + result += "\n"; + for (auto &type : types) { + result += type.ToString() + "\t"; + } + result += "\n"; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation.cpp b/src/duckdb/src/main/relation.cpp new file mode 100644 index 00000000..b1d14566 --- /dev/null +++ b/src/duckdb/src/main/relation.cpp @@ -0,0 +1,402 @@ +#include "duckdb/main/relation.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/main/relation/aggregate_relation.hpp" +#include "duckdb/main/relation/cross_product_relation.hpp" +#include "duckdb/main/relation/distinct_relation.hpp" +#include "duckdb/main/relation/explain_relation.hpp" +#include "duckdb/main/relation/filter_relation.hpp" +#include "duckdb/main/relation/insert_relation.hpp" +#include "duckdb/main/relation/limit_relation.hpp" +#include "duckdb/main/relation/order_relation.hpp" +#include "duckdb/main/relation/projection_relation.hpp" +#include "duckdb/main/relation/setop_relation.hpp" +#include "duckdb/main/relation/subquery_relation.hpp" +#include "duckdb/main/relation/table_function_relation.hpp" +#include "duckdb/main/relation/create_table_relation.hpp" +#include "duckdb/main/relation/create_view_relation.hpp" +#include "duckdb/main/relation/write_csv_relation.hpp" +#include "duckdb/main/relation/write_parquet_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/main/relation/join_relation.hpp" +#include "duckdb/main/relation/value_relation.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" + +namespace duckdb { + +shared_ptr Relation::Project(const string &select_list) { + return Project(select_list, vector()); +} + +shared_ptr Relation::Project(const string &expression, const string &alias) { + return Project(expression, vector({alias})); +} + +shared_ptr Relation::Project(const string &select_list, const vector &aliases) { + auto expressions = Parser::ParseExpressionList(select_list, context.GetContext()->GetParserOptions()); + return make_shared(shared_from_this(), std::move(expressions), aliases); +} + +shared_ptr Relation::Project(const vector &expressions) { + vector aliases; + return Project(expressions, aliases); +} + +shared_ptr Relation::Project(vector> expressions, + const vector &aliases) { + return make_shared(shared_from_this(), std::move(expressions), aliases); +} + +static vector> StringListToExpressionList(ClientContext &context, + const vector &expressions) { + if (expressions.empty()) { + throw ParserException("Zero expressions provided"); + } + vector> result_list; + for (auto &expr : expressions) { + auto expression_list = Parser::ParseExpressionList(expr, context.GetParserOptions()); + if (expression_list.size() != 1) { + throw ParserException("Expected a single expression in the expression list"); + } + result_list.push_back(std::move(expression_list[0])); + } + return result_list; +} + +shared_ptr Relation::Project(const vector &expressions, const vector &aliases) { + auto result_list = StringListToExpressionList(*context.GetContext(), expressions); + return make_shared(shared_from_this(), std::move(result_list), aliases); +} + +shared_ptr Relation::Filter(const string &expression) { + auto expression_list = Parser::ParseExpressionList(expression, context.GetContext()->GetParserOptions()); + if (expression_list.size() != 1) { + throw ParserException("Expected a single expression as filter condition"); + } + return Filter(std::move(expression_list[0])); +} + +shared_ptr Relation::Filter(unique_ptr expression) { + return make_shared(shared_from_this(), std::move(expression)); +} + +shared_ptr Relation::Filter(const vector &expressions) { + // if there are multiple expressions, we AND them together + auto expression_list = StringListToExpressionList(*context.GetContext(), expressions); + D_ASSERT(!expression_list.empty()); + + auto expr = std::move(expression_list[0]); + for (idx_t i = 1; i < expression_list.size(); i++) { + expr = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr), + std::move(expression_list[i])); + } + return make_shared(shared_from_this(), std::move(expr)); +} + +shared_ptr Relation::Limit(int64_t limit, int64_t offset) { + return make_shared(shared_from_this(), limit, offset); +} + +shared_ptr Relation::Order(const string &expression) { + auto order_list = Parser::ParseOrderList(expression, context.GetContext()->GetParserOptions()); + return Order(std::move(order_list)); +} + +shared_ptr Relation::Order(vector order_list) { + return make_shared(shared_from_this(), std::move(order_list)); +} + +shared_ptr Relation::Order(const vector &expressions) { + if (expressions.empty()) { + throw ParserException("Zero ORDER BY expressions provided"); + } + vector order_list; + for (auto &expression : expressions) { + auto inner_list = Parser::ParseOrderList(expression, context.GetContext()->GetParserOptions()); + if (inner_list.size() != 1) { + throw ParserException("Expected a single ORDER BY expression in the expression list"); + } + order_list.push_back(std::move(inner_list[0])); + } + return Order(std::move(order_list)); +} + +shared_ptr Relation::Join(const shared_ptr &other, const string &condition, JoinType type, + JoinRefType ref_type) { + auto expression_list = Parser::ParseExpressionList(condition, context.GetContext()->GetParserOptions()); + D_ASSERT(!expression_list.empty()); + return Join(other, std::move(expression_list), type, ref_type); +} + +shared_ptr Relation::Join(const shared_ptr &other, + vector> expression_list, JoinType type, + JoinRefType ref_type) { + if (expression_list.size() > 1 || expression_list[0]->type == ExpressionType::COLUMN_REF) { + // multiple columns or single column ref: the condition is a USING list + vector using_columns; + for (auto &expr : expression_list) { + if (expr->type != ExpressionType::COLUMN_REF) { + throw ParserException("Expected a single expression as join condition"); + } + auto &colref = expr->Cast(); + if (colref.IsQualified()) { + throw ParserException("Expected unqualified column for column in USING clause"); + } + using_columns.push_back(colref.column_names[0]); + } + return make_shared(shared_from_this(), other, std::move(using_columns), type, ref_type); + } else { + // single expression that is not a column reference: use the expression as a join condition + return make_shared(shared_from_this(), other, std::move(expression_list[0]), type, ref_type); + } +} + +shared_ptr Relation::CrossProduct(const shared_ptr &other, JoinRefType join_ref_type) { + return make_shared(shared_from_this(), other, join_ref_type); +} + +shared_ptr Relation::Union(const shared_ptr &other) { + return make_shared(shared_from_this(), other, SetOperationType::UNION); +} + +shared_ptr Relation::Except(const shared_ptr &other) { + return make_shared(shared_from_this(), other, SetOperationType::EXCEPT); +} + +shared_ptr Relation::Intersect(const shared_ptr &other) { + return make_shared(shared_from_this(), other, SetOperationType::INTERSECT); +} + +shared_ptr Relation::Distinct() { + return make_shared(shared_from_this()); +} + +shared_ptr Relation::Alias(const string &alias) { + return make_shared(shared_from_this(), alias); +} + +shared_ptr Relation::Aggregate(const string &aggregate_list) { + auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions()); + return make_shared(shared_from_this(), std::move(expression_list)); +} + +shared_ptr Relation::Aggregate(const string &aggregate_list, const string &group_list) { + auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions()); + auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions()); + return make_shared(shared_from_this(), std::move(expression_list), std::move(groups)); +} + +shared_ptr Relation::Aggregate(const vector &aggregates) { + auto aggregate_list = StringListToExpressionList(*context.GetContext(), aggregates); + return make_shared(shared_from_this(), std::move(aggregate_list)); +} + +shared_ptr Relation::Aggregate(const vector &aggregates, const vector &groups) { + auto aggregate_list = StringUtil::Join(aggregates, ", "); + auto group_list = StringUtil::Join(groups, ", "); + return this->Aggregate(aggregate_list, group_list); +} + +shared_ptr Relation::Aggregate(vector> expressions, const string &group_list) { + auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions()); + return make_shared(shared_from_this(), std::move(expressions), std::move(groups)); +} + +string Relation::GetAlias() { + return "relation"; +} + +unique_ptr Relation::GetTableRef() { + auto select = make_uniq(); + select->node = GetQueryNode(); + return make_uniq(std::move(select), GetAlias()); +} + +unique_ptr Relation::Execute() { + return context.GetContext()->Execute(shared_from_this()); +} + +unique_ptr Relation::ExecuteOrThrow() { + auto res = Execute(); + D_ASSERT(res); + if (res->HasError()) { + res->ThrowError(); + } + return res; +} + +BoundStatement Relation::Bind(Binder &binder) { + SelectStatement stmt; + stmt.node = GetQueryNode(); + return binder.Bind(stmt.Cast()); +} + +shared_ptr Relation::InsertRel(const string &schema_name, const string &table_name) { + return make_shared(shared_from_this(), schema_name, table_name); +} + +void Relation::Insert(const string &table_name) { + Insert(INVALID_SCHEMA, table_name); +} + +void Relation::Insert(const string &schema_name, const string &table_name) { + auto insert = InsertRel(schema_name, table_name); + auto res = insert->Execute(); + if (res->HasError()) { + const string prepended_message = "Failed to insert into table '" + table_name + "': "; + res->ThrowError(prepended_message); + } +} + +void Relation::Insert(const vector> &values) { + vector column_names; + auto rel = make_shared(context.GetContext(), values, std::move(column_names), "values"); + rel->Insert(GetAlias()); +} + +shared_ptr Relation::CreateRel(const string &schema_name, const string &table_name) { + return make_shared(shared_from_this(), schema_name, table_name); +} + +void Relation::Create(const string &table_name) { + Create(INVALID_SCHEMA, table_name); +} + +void Relation::Create(const string &schema_name, const string &table_name) { + auto create = CreateRel(schema_name, table_name); + auto res = create->Execute(); + if (res->HasError()) { + const string prepended_message = "Failed to create table '" + table_name + "': "; + res->ThrowError(prepended_message); + } +} + +shared_ptr Relation::WriteCSVRel(const string &csv_file, case_insensitive_map_t> options) { + return std::make_shared(shared_from_this(), csv_file, std::move(options)); +} + +void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t> options) { + auto write_csv = WriteCSVRel(csv_file, std::move(options)); + auto res = write_csv->Execute(); + if (res->HasError()) { + const string prepended_message = "Failed to write '" + csv_file + "': "; + res->ThrowError(prepended_message); + } +} + +shared_ptr Relation::WriteParquetRel(const string &parquet_file, + case_insensitive_map_t> options) { + auto write_parquet = + std::make_shared(shared_from_this(), parquet_file, std::move(options)); + return std::move(write_parquet); +} + +void Relation::WriteParquet(const string &parquet_file, case_insensitive_map_t> options) { + auto write_parquet = WriteParquetRel(parquet_file, std::move(options)); + auto res = write_parquet->Execute(); + if (res->HasError()) { + const string prepended_message = "Failed to write '" + parquet_file + "': "; + res->ThrowError(prepended_message); + } +} + +shared_ptr Relation::CreateView(const string &name, bool replace, bool temporary) { + return CreateView(INVALID_SCHEMA, name, replace, temporary); +} + +shared_ptr Relation::CreateView(const string &schema_name, const string &name, bool replace, bool temporary) { + auto view = make_shared(shared_from_this(), schema_name, name, replace, temporary); + auto res = view->Execute(); + if (res->HasError()) { + const string prepended_message = "Failed to create view '" + name + "': "; + res->ThrowError(prepended_message); + } + return shared_from_this(); +} + +unique_ptr Relation::Query(const string &sql) { + return context.GetContext()->Query(sql, false); +} + +unique_ptr Relation::Query(const string &name, const string &sql) { + CreateView(name); + return Query(sql); +} + +unique_ptr Relation::Explain(ExplainType type) { + auto explain = make_shared(shared_from_this(), type); + return explain->Execute(); +} + +void Relation::Update(const string &update, const string &condition) { + throw Exception("UPDATE can only be used on base tables!"); +} + +void Relation::Delete(const string &condition) { + throw Exception("DELETE can only be used on base tables!"); +} + +shared_ptr Relation::TableFunction(const std::string &fname, const vector &values, + const named_parameter_map_t &named_parameters) { + return make_shared(context.GetContext(), fname, values, named_parameters, + shared_from_this()); +} + +shared_ptr Relation::TableFunction(const std::string &fname, const vector &values) { + return make_shared(context.GetContext(), fname, values, shared_from_this()); +} + +string Relation::ToString() { + string str; + str += "---------------------\n"; + str += "--- Relation Tree ---\n"; + str += "---------------------\n"; + str += ToString(0); + str += "\n\n"; + str += "---------------------\n"; + str += "-- Result Columns --\n"; + str += "---------------------\n"; + auto &cols = Columns(); + for (idx_t i = 0; i < cols.size(); i++) { + str += "- " + cols[i].Name() + " (" + cols[i].Type().ToString() + ")\n"; + } + return str; +} + +// LCOV_EXCL_START +unique_ptr Relation::GetQueryNode() { + throw InternalException("Cannot create a query node from this node type"); +} + +void Relation::Head(idx_t limit) { + auto limit_node = Limit(limit); + limit_node->Execute()->Print(); +} +// LCOV_EXCL_STOP + +void Relation::Print() { + Printer::Print(ToString()); +} + +string Relation::RenderWhitespace(idx_t depth) { + return string(depth * 2, ' '); +} + +vector> Relation::GetAllDependencies() { + vector> all_dependencies; + Relation *cur = this; + while (cur) { + if (cur->extra_dependencies) { + all_dependencies.push_back(cur->extra_dependencies); + } + cur = cur->ChildRelation(); + } + return all_dependencies; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/aggregate_relation.cpp b/src/duckdb/src/main/relation/aggregate_relation.cpp new file mode 100644 index 00000000..c84036ad --- /dev/null +++ b/src/duckdb/src/main/relation/aggregate_relation.cpp @@ -0,0 +1,93 @@ +#include "duckdb/main/relation/aggregate_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" + +namespace duckdb { + +AggregateRelation::AggregateRelation(shared_ptr child_p, + vector> parsed_expressions) + : Relation(child_p->context, RelationType::AGGREGATE_RELATION), expressions(std::move(parsed_expressions)), + child(std::move(child_p)) { + // bind the expressions + context.GetContext()->TryBindRelation(*this, this->columns); +} + +AggregateRelation::AggregateRelation(shared_ptr child_p, + vector> parsed_expressions, GroupByNode groups_p) + : Relation(child_p->context, RelationType::AGGREGATE_RELATION), expressions(std::move(parsed_expressions)), + groups(std::move(groups_p)), child(std::move(child_p)) { + // bind the expressions + context.GetContext()->TryBindRelation(*this, this->columns); +} + +AggregateRelation::AggregateRelation(shared_ptr child_p, + vector> parsed_expressions, + vector> groups_p) + : Relation(child_p->context, RelationType::AGGREGATE_RELATION), expressions(std::move(parsed_expressions)), + child(std::move(child_p)) { + if (!groups_p.empty()) { + // explicit groups provided: use standard handling + GroupingSet grouping_set; + for (idx_t i = 0; i < groups_p.size(); i++) { + groups.group_expressions.push_back(std::move(groups_p[i])); + grouping_set.insert(i); + } + groups.grouping_sets.push_back(std::move(grouping_set)); + } + // bind the expressions + context.GetContext()->TryBindRelation(*this, this->columns); +} + +unique_ptr AggregateRelation::GetQueryNode() { + auto child_ptr = child.get(); + while (child_ptr->InheritsColumnBindings()) { + child_ptr = child_ptr->ChildRelation(); + } + unique_ptr result; + if (child_ptr->type == RelationType::JOIN_RELATION) { + // child node is a join: push projection into the child query node + result = child->GetQueryNode(); + } else { + // child node is not a join: create a new select node and push the child as a table reference + auto select = make_uniq(); + select->from_table = child->GetTableRef(); + result = std::move(select); + } + D_ASSERT(result->type == QueryNodeType::SELECT_NODE); + auto &select_node = result->Cast(); + if (!groups.group_expressions.empty()) { + select_node.aggregate_handling = AggregateHandling::STANDARD_HANDLING; + select_node.groups = groups.Copy(); + } else { + // no groups provided: automatically figure out groups (if any) + select_node.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; + } + select_node.select_list.clear(); + for (auto &expr : expressions) { + select_node.select_list.push_back(expr->Copy()); + } + return result; +} + +string AggregateRelation::GetAlias() { + return child->GetAlias(); +} + +const vector &AggregateRelation::Columns() { + return columns; +} + +string AggregateRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Aggregate ["; + for (idx_t i = 0; i < expressions.size(); i++) { + if (i != 0) { + str += ", "; + } + str += expressions[i]->ToString(); + } + str += "]\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/create_table_relation.cpp b/src/duckdb/src/main/relation/create_table_relation.cpp new file mode 100644 index 00000000..e8ac45ed --- /dev/null +++ b/src/duckdb/src/main/relation/create_table_relation.cpp @@ -0,0 +1,38 @@ +#include "duckdb/main/relation/create_table_relation.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +CreateTableRelation::CreateTableRelation(shared_ptr child_p, string schema_name, string table_name) + : Relation(child_p->context, RelationType::CREATE_TABLE_RELATION), child(std::move(child_p)), + schema_name(std::move(schema_name)), table_name(std::move(table_name)) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement CreateTableRelation::Bind(Binder &binder) { + auto select = make_uniq(); + select->node = child->GetQueryNode(); + + CreateStatement stmt; + auto info = make_uniq(); + info->schema = schema_name; + info->table = table_name; + info->query = std::move(select); + info->on_conflict = OnCreateConflict::ERROR_ON_CONFLICT; + stmt.info = std::move(info); + return binder.Bind(stmt.Cast()); +} + +const vector &CreateTableRelation::Columns() { + return columns; +} + +string CreateTableRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Create Table\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/create_view_relation.cpp b/src/duckdb/src/main/relation/create_view_relation.cpp new file mode 100644 index 00000000..b9c80c07 --- /dev/null +++ b/src/duckdb/src/main/relation/create_view_relation.cpp @@ -0,0 +1,47 @@ +#include "duckdb/main/relation/create_view_relation.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +CreateViewRelation::CreateViewRelation(shared_ptr child_p, string view_name_p, bool replace_p, + bool temporary_p) + : Relation(child_p->context, RelationType::CREATE_VIEW_RELATION), child(std::move(child_p)), + view_name(std::move(view_name_p)), replace(replace_p), temporary(temporary_p) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +CreateViewRelation::CreateViewRelation(shared_ptr child_p, string schema_name_p, string view_name_p, + bool replace_p, bool temporary_p) + : Relation(child_p->context, RelationType::CREATE_VIEW_RELATION), child(std::move(child_p)), + schema_name(std::move(schema_name_p)), view_name(std::move(view_name_p)), replace(replace_p), + temporary(temporary_p) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement CreateViewRelation::Bind(Binder &binder) { + auto select = make_uniq(); + select->node = child->GetQueryNode(); + + CreateStatement stmt; + auto info = make_uniq(); + info->query = std::move(select); + info->view_name = view_name; + info->temporary = temporary; + info->schema = schema_name; + info->on_conflict = replace ? OnCreateConflict::REPLACE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; + stmt.info = std::move(info); + return binder.Bind(stmt.Cast()); +} + +const vector &CreateViewRelation::Columns() { + return columns; +} + +string CreateViewRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Create View\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/cross_product_relation.cpp b/src/duckdb/src/main/relation/cross_product_relation.cpp new file mode 100644 index 00000000..f4bc12a3 --- /dev/null +++ b/src/duckdb/src/main/relation/cross_product_relation.cpp @@ -0,0 +1,43 @@ +#include "duckdb/main/relation/cross_product_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/joinref.hpp" + +namespace duckdb { + +CrossProductRelation::CrossProductRelation(shared_ptr left_p, shared_ptr right_p, + JoinRefType ref_type) + : Relation(left_p->context, RelationType::CROSS_PRODUCT_RELATION), left(std::move(left_p)), + right(std::move(right_p)), ref_type(ref_type) { + if (left->context.GetContext() != right->context.GetContext()) { + throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); + } + context.GetContext()->TryBindRelation(*this, this->columns); +} + +unique_ptr CrossProductRelation::GetQueryNode() { + auto result = make_uniq(); + result->select_list.push_back(make_uniq()); + result->from_table = GetTableRef(); + return std::move(result); +} + +unique_ptr CrossProductRelation::GetTableRef() { + auto cross_product_ref = make_uniq(ref_type); + cross_product_ref->left = left->GetTableRef(); + cross_product_ref->right = right->GetTableRef(); + return std::move(cross_product_ref); +} + +const vector &CrossProductRelation::Columns() { + return this->columns; +} + +string CrossProductRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth); + str = "Cross Product"; + return str + "\n" + left->ToString(depth + 1) + right->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/delete_relation.cpp b/src/duckdb/src/main/relation/delete_relation.cpp new file mode 100644 index 00000000..8afc8226 --- /dev/null +++ b/src/duckdb/src/main/relation/delete_relation.cpp @@ -0,0 +1,39 @@ +#include "duckdb/main/relation/delete_relation.hpp" +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" + +namespace duckdb { + +DeleteRelation::DeleteRelation(ClientContextWrapper &context, unique_ptr condition_p, + string schema_name_p, string table_name_p) + : Relation(context, RelationType::DELETE_RELATION), condition(std::move(condition_p)), + schema_name(std::move(schema_name_p)), table_name(std::move(table_name_p)) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement DeleteRelation::Bind(Binder &binder) { + auto basetable = make_uniq(); + basetable->schema_name = schema_name; + basetable->table_name = table_name; + + DeleteStatement stmt; + stmt.condition = condition ? condition->Copy() : nullptr; + stmt.table = std::move(basetable); + return binder.Bind(stmt.Cast()); +} + +const vector &DeleteRelation::Columns() { + return columns; +} + +string DeleteRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "DELETE FROM " + table_name; + if (condition) { + str += " WHERE " + condition->ToString(); + } + return str; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/distinct_relation.cpp b/src/duckdb/src/main/relation/distinct_relation.cpp new file mode 100644 index 00000000..0f96d458 --- /dev/null +++ b/src/duckdb/src/main/relation/distinct_relation.cpp @@ -0,0 +1,34 @@ +#include "duckdb/main/relation/distinct_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +DistinctRelation::DistinctRelation(shared_ptr child_p) + : Relation(child_p->context, RelationType::DISTINCT_RELATION), child(std::move(child_p)) { + D_ASSERT(child.get() != this); + vector dummy_columns; + context.GetContext()->TryBindRelation(*this, dummy_columns); +} + +unique_ptr DistinctRelation::GetQueryNode() { + auto child_node = child->GetQueryNode(); + child_node->AddDistinct(); + return child_node; +} + +string DistinctRelation::GetAlias() { + return child->GetAlias(); +} + +const vector &DistinctRelation::Columns() { + return child->Columns(); +} + +string DistinctRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Distinct\n"; + return str + child->ToString(depth + 1); + ; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/explain_relation.cpp b/src/duckdb/src/main/relation/explain_relation.cpp new file mode 100644 index 00000000..0a3f8125 --- /dev/null +++ b/src/duckdb/src/main/relation/explain_relation.cpp @@ -0,0 +1,31 @@ +#include "duckdb/main/relation/explain_relation.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +ExplainRelation::ExplainRelation(shared_ptr child_p, ExplainType type) + : Relation(child_p->context, RelationType::EXPLAIN_RELATION), child(std::move(child_p)), type(type) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement ExplainRelation::Bind(Binder &binder) { + auto select = make_uniq(); + select->node = child->GetQueryNode(); + ExplainStatement explain(std::move(select), type); + return binder.Bind(explain.Cast()); +} + +const vector &ExplainRelation::Columns() { + return columns; +} + +string ExplainRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Explain\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/filter_relation.cpp b/src/duckdb/src/main/relation/filter_relation.cpp new file mode 100644 index 00000000..2abaa41a --- /dev/null +++ b/src/duckdb/src/main/relation/filter_relation.cpp @@ -0,0 +1,57 @@ +#include "duckdb/main/relation/filter_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" + +namespace duckdb { + +FilterRelation::FilterRelation(shared_ptr child_p, unique_ptr condition_p) + : Relation(child_p->context, RelationType::FILTER_RELATION), condition(std::move(condition_p)), + child(std::move(child_p)) { + D_ASSERT(child.get() != this); + vector dummy_columns; + context.GetContext()->TryBindRelation(*this, dummy_columns); +} + +unique_ptr FilterRelation::GetQueryNode() { + auto child_ptr = child.get(); + while (child_ptr->InheritsColumnBindings()) { + child_ptr = child_ptr->ChildRelation(); + } + if (child_ptr->type == RelationType::JOIN_RELATION) { + // child node is a join: push filter into WHERE clause of select node + auto child_node = child->GetQueryNode(); + D_ASSERT(child_node->type == QueryNodeType::SELECT_NODE); + auto &select_node = child_node->Cast(); + if (!select_node.where_clause) { + select_node.where_clause = condition->Copy(); + } else { + select_node.where_clause = make_uniq( + ExpressionType::CONJUNCTION_AND, std::move(select_node.where_clause), condition->Copy()); + } + return child_node; + } else { + auto result = make_uniq(); + result->select_list.push_back(make_uniq()); + result->from_table = child->GetTableRef(); + result->where_clause = condition->Copy(); + return std::move(result); + } +} + +string FilterRelation::GetAlias() { + return child->GetAlias(); +} + +const vector &FilterRelation::Columns() { + return child->Columns(); +} + +string FilterRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Filter [" + condition->ToString() + "]\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/insert_relation.cpp b/src/duckdb/src/main/relation/insert_relation.cpp new file mode 100644 index 00000000..c6738d48 --- /dev/null +++ b/src/duckdb/src/main/relation/insert_relation.cpp @@ -0,0 +1,36 @@ +#include "duckdb/main/relation/insert_relation.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +InsertRelation::InsertRelation(shared_ptr child_p, string schema_name, string table_name) + : Relation(child_p->context, RelationType::INSERT_RELATION), child(std::move(child_p)), + schema_name(std::move(schema_name)), table_name(std::move(table_name)) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement InsertRelation::Bind(Binder &binder) { + InsertStatement stmt; + auto select = make_uniq(); + select->node = child->GetQueryNode(); + + stmt.schema = schema_name; + stmt.table = table_name; + stmt.select_statement = std::move(select); + return binder.Bind(stmt.Cast()); +} + +const vector &InsertRelation::Columns() { + return columns; +} + +string InsertRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Insert\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/join_relation.cpp b/src/duckdb/src/main/relation/join_relation.cpp new file mode 100644 index 00000000..22691cd8 --- /dev/null +++ b/src/duckdb/src/main/relation/join_relation.cpp @@ -0,0 +1,63 @@ +#include "duckdb/main/relation/join_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +JoinRelation::JoinRelation(shared_ptr left_p, shared_ptr right_p, + unique_ptr condition_p, JoinType type, JoinRefType join_ref_type) + : Relation(left_p->context, RelationType::JOIN_RELATION), left(std::move(left_p)), right(std::move(right_p)), + condition(std::move(condition_p)), join_type(type), join_ref_type(join_ref_type) { + if (left->context.GetContext() != right->context.GetContext()) { + throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); + } + context.GetContext()->TryBindRelation(*this, this->columns); +} + +JoinRelation::JoinRelation(shared_ptr left_p, shared_ptr right_p, vector using_columns_p, + JoinType type, JoinRefType join_ref_type) + : Relation(left_p->context, RelationType::JOIN_RELATION), left(std::move(left_p)), right(std::move(right_p)), + using_columns(std::move(using_columns_p)), join_type(type), join_ref_type(join_ref_type) { + if (left->context.GetContext() != right->context.GetContext()) { + throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); + } + context.GetContext()->TryBindRelation(*this, this->columns); +} + +unique_ptr JoinRelation::GetQueryNode() { + auto result = make_uniq(); + result->select_list.push_back(make_uniq()); + result->from_table = GetTableRef(); + return std::move(result); +} + +unique_ptr JoinRelation::GetTableRef() { + auto join_ref = make_uniq(join_ref_type); + join_ref->left = left->GetTableRef(); + join_ref->right = right->GetTableRef(); + if (condition) { + join_ref->condition = condition->Copy(); + } + join_ref->using_columns = using_columns; + join_ref->type = join_type; + return std::move(join_ref); +} + +const vector &JoinRelation::Columns() { + return this->columns; +} + +string JoinRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth); + str += "Join " + EnumUtil::ToString(join_ref_type) + " " + EnumUtil::ToString(join_type); + if (condition) { + str += " " + condition->GetName(); + } + + return str + "\n" + left->ToString(depth + 1) + "\n" + right->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/limit_relation.cpp b/src/duckdb/src/main/relation/limit_relation.cpp new file mode 100644 index 00000000..96491bce --- /dev/null +++ b/src/duckdb/src/main/relation/limit_relation.cpp @@ -0,0 +1,46 @@ +#include "duckdb/main/relation/limit_relation.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +LimitRelation::LimitRelation(shared_ptr child_p, int64_t limit, int64_t offset) + : Relation(child_p->context, RelationType::PROJECTION_RELATION), limit(limit), offset(offset), + child(std::move(child_p)) { + D_ASSERT(child.get() != this); +} + +unique_ptr LimitRelation::GetQueryNode() { + auto child_node = child->GetQueryNode(); + auto limit_node = make_uniq(); + if (limit >= 0) { + limit_node->limit = make_uniq(Value::BIGINT(limit)); + } + if (offset > 0) { + limit_node->offset = make_uniq(Value::BIGINT(offset)); + } + + child_node->modifiers.push_back(std::move(limit_node)); + return child_node; +} + +string LimitRelation::GetAlias() { + return child->GetAlias(); +} + +const vector &LimitRelation::Columns() { + return child->Columns(); +} + +string LimitRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Limit " + to_string(limit); + if (offset > 0) { + str += " Offset " + to_string(offset); + } + str += "\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/order_relation.cpp b/src/duckdb/src/main/relation/order_relation.cpp new file mode 100644 index 00000000..7a9c16dc --- /dev/null +++ b/src/duckdb/src/main/relation/order_relation.cpp @@ -0,0 +1,48 @@ +#include "duckdb/main/relation/order_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" + +namespace duckdb { + +OrderRelation::OrderRelation(shared_ptr child_p, vector orders) + : Relation(child_p->context, RelationType::ORDER_RELATION), orders(std::move(orders)), child(std::move(child_p)) { + D_ASSERT(child.get() != this); + // bind the expressions + context.GetContext()->TryBindRelation(*this, this->columns); +} + +unique_ptr OrderRelation::GetQueryNode() { + auto select = make_uniq(); + select->from_table = child->GetTableRef(); + select->select_list.push_back(make_uniq()); + auto order_node = make_uniq(); + for (idx_t i = 0; i < orders.size(); i++) { + order_node->orders.emplace_back(orders[i].type, orders[i].null_order, orders[i].expression->Copy()); + } + select->modifiers.push_back(std::move(order_node)); + return std::move(select); +} + +string OrderRelation::GetAlias() { + return child->GetAlias(); +} + +const vector &OrderRelation::Columns() { + return columns; +} + +string OrderRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Order ["; + for (idx_t i = 0; i < orders.size(); i++) { + if (i != 0) { + str += ", "; + } + str += orders[i].expression->ToString() + (orders[i].type == OrderType::ASCENDING ? " ASC" : " DESC"); + } + str += "]\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/projection_relation.cpp b/src/duckdb/src/main/relation/projection_relation.cpp new file mode 100644 index 00000000..eb1f57fd --- /dev/null +++ b/src/duckdb/src/main/relation/projection_relation.cpp @@ -0,0 +1,69 @@ +#include "duckdb/main/relation/projection_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" + +namespace duckdb { + +ProjectionRelation::ProjectionRelation(shared_ptr child_p, + vector> parsed_expressions, vector aliases) + : Relation(child_p->context, RelationType::PROJECTION_RELATION), expressions(std::move(parsed_expressions)), + child(std::move(child_p)) { + if (!aliases.empty()) { + if (aliases.size() != expressions.size()) { + throw ParserException("Aliases list length must match expression list length!"); + } + for (idx_t i = 0; i < aliases.size(); i++) { + expressions[i]->alias = aliases[i]; + } + } + // bind the expressions + context.GetContext()->TryBindRelation(*this, this->columns); +} + +unique_ptr ProjectionRelation::GetQueryNode() { + auto child_ptr = child.get(); + while (child_ptr->InheritsColumnBindings()) { + child_ptr = child_ptr->ChildRelation(); + } + unique_ptr result; + if (child_ptr->type == RelationType::JOIN_RELATION) { + // child node is a join: push projection into the child query node + result = child->GetQueryNode(); + } else { + // child node is not a join: create a new select node and push the child as a table reference + auto select = make_uniq(); + select->from_table = child->GetTableRef(); + result = std::move(select); + } + D_ASSERT(result->type == QueryNodeType::SELECT_NODE); + auto &select_node = result->Cast(); + select_node.aggregate_handling = AggregateHandling::NO_AGGREGATES_ALLOWED; + select_node.select_list.clear(); + for (auto &expr : expressions) { + select_node.select_list.push_back(expr->Copy()); + } + return result; +} + +string ProjectionRelation::GetAlias() { + return child->GetAlias(); +} + +const vector &ProjectionRelation::Columns() { + return columns; +} + +string ProjectionRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Projection ["; + for (idx_t i = 0; i < expressions.size(); i++) { + if (i != 0) { + str += ", "; + } + str += expressions[i]->ToString() + " as " + expressions[i]->alias; + } + str += "]\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/query_relation.cpp b/src/duckdb/src/main/relation/query_relation.cpp new file mode 100644 index 00000000..f6421601 --- /dev/null +++ b/src/duckdb/src/main/relation/query_relation.cpp @@ -0,0 +1,58 @@ +#include "duckdb/main/relation/query_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/parser/parser.hpp" + +namespace duckdb { + +QueryRelation::QueryRelation(const std::shared_ptr &context, unique_ptr select_stmt_p, + string alias_p) + : Relation(context, RelationType::QUERY_RELATION), select_stmt(std::move(select_stmt_p)), + alias(std::move(alias_p)) { + context->TryBindRelation(*this, this->columns); +} + +QueryRelation::~QueryRelation() { +} + +unique_ptr QueryRelation::ParseStatement(ClientContext &context, const string &query, + const string &error) { + Parser parser(context.GetParserOptions()); + parser.ParseQuery(query); + if (parser.statements.size() != 1) { + throw ParserException(error); + } + if (parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw ParserException(error); + } + return unique_ptr_cast(std::move(parser.statements[0])); +} + +unique_ptr QueryRelation::GetSelectStatement() { + return unique_ptr_cast(select_stmt->Copy()); +} + +unique_ptr QueryRelation::GetQueryNode() { + auto select = GetSelectStatement(); + return std::move(select->node); +} + +unique_ptr QueryRelation::GetTableRef() { + auto subquery_ref = make_uniq(GetSelectStatement(), GetAlias()); + return std::move(subquery_ref); +} + +string QueryRelation::GetAlias() { + return alias; +} + +const vector &QueryRelation::Columns() { + return columns; +} + +string QueryRelation::ToString(idx_t depth) { + return RenderWhitespace(depth) + "Subquery"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/read_csv_relation.cpp b/src/duckdb/src/main/relation/read_csv_relation.cpp new file mode 100644 index 00000000..fd9ecf55 --- /dev/null +++ b/src/duckdb/src/main/relation/read_csv_relation.cpp @@ -0,0 +1,85 @@ +#include "duckdb/main/relation/read_csv_relation.hpp" + +#include "duckdb/execution/operator/scan/csv/buffered_csv_reader.hpp" +#include "duckdb/execution/operator/scan/csv/csv_buffer_manager.hpp" +#include "duckdb/execution/operator/scan/csv/csv_sniffer.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" + +namespace duckdb { + +ReadCSVRelation::ReadCSVRelation(const shared_ptr &context, const string &csv_file, + vector columns_p, string alias_p) + : TableFunctionRelation(context, "read_csv", {Value(csv_file)}, nullptr, false), alias(std::move(alias_p)), + auto_detect(false) { + + if (alias.empty()) { + alias = StringUtil::Split(csv_file, ".")[0]; + } + + columns = std::move(columns_p); + + child_list_t column_names; + for (idx_t i = 0; i < columns.size(); i++) { + column_names.push_back(make_pair(columns[i].Name(), Value(columns[i].Type().ToString()))); + } + + AddNamedParameter("columns", Value::STRUCT(std::move(column_names))); +} + +ReadCSVRelation::ReadCSVRelation(const std::shared_ptr &context, const string &csv_file, + named_parameter_map_t &&options, string alias_p) + : TableFunctionRelation(context, "read_csv_auto", {Value(csv_file)}, nullptr, false), alias(std::move(alias_p)), + auto_detect(true) { + + if (alias.empty()) { + alias = StringUtil::Split(csv_file, ".")[0]; + } + + auto files = MultiFileReader::GetFileList(*context, csv_file, "CSV"); + D_ASSERT(!files.empty()); + + auto &file_name = files[0]; + options["auto_detect"] = Value::BOOLEAN(true); + CSVReaderOptions csv_options; + csv_options.file_path = file_name; + vector empty; + + vector unused_types; + vector unused_names; + csv_options.FromNamedParameters(options, *context, unused_types, unused_names); + // Run the auto-detect, populating the options with the detected settings + + auto bm_file_handle = BaseCSVReader::OpenCSV(*context, csv_options); + auto buffer_manager = make_shared(*context, std::move(bm_file_handle), csv_options); + CSVStateMachineCache state_machine_cache; + CSVSniffer sniffer(csv_options, buffer_manager, state_machine_cache); + auto sniffer_result = sniffer.SniffCSV(); + auto &types = sniffer_result.return_types; + auto &names = sniffer_result.names; + for (idx_t i = 0; i < types.size(); i++) { + columns.emplace_back(names[i], types[i]); + } + + //! Capture the options potentially set/altered by the auto detection phase + csv_options.ToNamedParameters(options); + + // No need to auto-detect again + options["auto_detect"] = Value::BOOLEAN(false); + SetNamedParameters(std::move(options)); +} + +string ReadCSVRelation::GetAlias() { + return alias; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/read_json_relation.cpp b/src/duckdb/src/main/relation/read_json_relation.cpp new file mode 100644 index 00000000..c2cf3315 --- /dev/null +++ b/src/duckdb/src/main/relation/read_json_relation.cpp @@ -0,0 +1,20 @@ +#include "duckdb/main/relation/read_json_relation.hpp" +#include "duckdb/parser/column_definition.hpp" +namespace duckdb { + +ReadJSONRelation::ReadJSONRelation(const shared_ptr &context, string json_file_p, + named_parameter_map_t options, bool auto_detect, string alias_p) + : TableFunctionRelation(context, auto_detect ? "read_json_auto" : "read_json", {Value(json_file_p)}, + std::move(options)), + json_file(std::move(json_file_p)), alias(std::move(alias_p)) { + + if (alias.empty()) { + alias = StringUtil::Split(json_file, ".")[0]; + } +} + +string ReadJSONRelation::GetAlias() { + return alias; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/setop_relation.cpp b/src/duckdb/src/main/relation/setop_relation.cpp new file mode 100644 index 00000000..c59b6c78 --- /dev/null +++ b/src/duckdb/src/main/relation/setop_relation.cpp @@ -0,0 +1,54 @@ +#include "duckdb/main/relation/setop_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/result_modifier.hpp" + +namespace duckdb { + +SetOpRelation::SetOpRelation(shared_ptr left_p, shared_ptr right_p, SetOperationType setop_type_p) + : Relation(left_p->context, RelationType::SET_OPERATION_RELATION), left(std::move(left_p)), + right(std::move(right_p)), setop_type(setop_type_p) { + if (left->context.GetContext() != right->context.GetContext()) { + throw Exception("Cannot combine LEFT and RIGHT relations of different connections!"); + } + context.GetContext()->TryBindRelation(*this, this->columns); +} + +unique_ptr SetOpRelation::GetQueryNode() { + auto result = make_uniq(); + if (setop_type == SetOperationType::EXCEPT || setop_type == SetOperationType::INTERSECT) { + result->modifiers.push_back(make_uniq()); + } + result->left = left->GetQueryNode(); + result->right = right->GetQueryNode(); + result->setop_type = setop_type; + return std::move(result); +} + +string SetOpRelation::GetAlias() { + return left->GetAlias(); +} + +const vector &SetOpRelation::Columns() { + return this->columns; +} + +string SetOpRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth); + switch (setop_type) { + case SetOperationType::UNION: + str += "Union"; + break; + case SetOperationType::EXCEPT: + str += "Except"; + break; + case SetOperationType::INTERSECT: + str += "Intersect"; + break; + default: + throw InternalException("Unknown setop type"); + } + return str + "\n" + left->ToString(depth + 1) + right->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/subquery_relation.cpp b/src/duckdb/src/main/relation/subquery_relation.cpp new file mode 100644 index 00000000..5ee1e032 --- /dev/null +++ b/src/duckdb/src/main/relation/subquery_relation.cpp @@ -0,0 +1,31 @@ +#include "duckdb/main/relation/subquery_relation.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/query_node.hpp" + +namespace duckdb { + +SubqueryRelation::SubqueryRelation(shared_ptr child_p, string alias_p) + : Relation(child_p->context, RelationType::SUBQUERY_RELATION), child(std::move(child_p)), + alias(std::move(alias_p)) { + D_ASSERT(child.get() != this); + vector dummy_columns; + context.GetContext()->TryBindRelation(*this, dummy_columns); +} + +unique_ptr SubqueryRelation::GetQueryNode() { + return child->GetQueryNode(); +} + +string SubqueryRelation::GetAlias() { + return alias; +} + +const vector &SubqueryRelation::Columns() { + return child->Columns(); +} + +string SubqueryRelation::ToString(idx_t depth) { + return child->ToString(depth); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/table_function_relation.cpp b/src/duckdb/src/main/relation/table_function_relation.cpp new file mode 100644 index 00000000..6933a36a --- /dev/null +++ b/src/duckdb/src/main/relation/table_function_relation.cpp @@ -0,0 +1,106 @@ +#include "duckdb/main/relation/table_function_relation.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/common/shared_ptr.hpp" + +namespace duckdb { + +void TableFunctionRelation::AddNamedParameter(const string &name, Value argument) { + named_parameters[name] = std::move(argument); +} + +void TableFunctionRelation::SetNamedParameters(named_parameter_map_t &&options) { + D_ASSERT(named_parameters.empty()); + named_parameters = std::move(options); +} + +TableFunctionRelation::TableFunctionRelation(const shared_ptr &context, string name_p, + vector parameters_p, named_parameter_map_t named_parameters, + shared_ptr input_relation_p, bool auto_init) + : Relation(context, RelationType::TABLE_FUNCTION_RELATION), name(std::move(name_p)), + parameters(std::move(parameters_p)), named_parameters(std::move(named_parameters)), + input_relation(std::move(input_relation_p)), auto_initialize(auto_init) { + InitializeColumns(); +} + +TableFunctionRelation::TableFunctionRelation(const shared_ptr &context, string name_p, + vector parameters_p, shared_ptr input_relation_p, + bool auto_init) + : Relation(context, RelationType::TABLE_FUNCTION_RELATION), name(std::move(name_p)), + parameters(std::move(parameters_p)), input_relation(std::move(input_relation_p)), auto_initialize(auto_init) { + InitializeColumns(); +} + +void TableFunctionRelation::InitializeColumns() { + if (!auto_initialize) { + return; + } + context.GetContext()->TryBindRelation(*this, this->columns); +} + +unique_ptr TableFunctionRelation::GetQueryNode() { + auto result = make_uniq(); + result->select_list.push_back(make_uniq()); + result->from_table = GetTableRef(); + return std::move(result); +} + +unique_ptr TableFunctionRelation::GetTableRef() { + vector> children; + if (input_relation) { // input relation becomes first parameter if present, always + auto subquery = make_uniq(); + subquery->subquery = make_uniq(); + subquery->subquery->node = input_relation->GetQueryNode(); + subquery->subquery_type = SubqueryType::SCALAR; + children.push_back(std::move(subquery)); + } + for (auto ¶meter : parameters) { + children.push_back(make_uniq(parameter)); + } + + for (auto ¶meter : named_parameters) { + // Hackity-hack some comparisons with column refs + // This is all but pretty, basically the named parameter is the column, the table is empty because that's what + // the function binder likes + auto column_ref = make_uniq(parameter.first); + auto constant_value = make_uniq(parameter.second); + auto comparison = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(column_ref), + std::move(constant_value)); + children.push_back(std::move(comparison)); + } + + auto table_function = make_uniq(); + auto function = make_uniq(name, std::move(children)); + table_function->function = std::move(function); + return std::move(table_function); +} + +string TableFunctionRelation::GetAlias() { + return name; +} + +const vector &TableFunctionRelation::Columns() { + return columns; +} + +string TableFunctionRelation::ToString(idx_t depth) { + string function_call = name + "("; + for (idx_t i = 0; i < parameters.size(); i++) { + if (i > 0) { + function_call += ", "; + } + function_call += parameters[i].ToString(); + } + function_call += ")"; + return RenderWhitespace(depth) + function_call; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/table_relation.cpp b/src/duckdb/src/main/relation/table_relation.cpp new file mode 100644 index 00000000..b4954e3d --- /dev/null +++ b/src/duckdb/src/main/relation/table_relation.cpp @@ -0,0 +1,70 @@ +#include "duckdb/main/relation/table_relation.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/main/relation/delete_relation.hpp" +#include "duckdb/main/relation/update_relation.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +TableRelation::TableRelation(const std::shared_ptr &context, unique_ptr description) + : Relation(context, RelationType::TABLE_RELATION), description(std::move(description)) { +} + +unique_ptr TableRelation::GetQueryNode() { + auto result = make_uniq(); + result->select_list.push_back(make_uniq()); + result->from_table = GetTableRef(); + return std::move(result); +} + +unique_ptr TableRelation::GetTableRef() { + auto table_ref = make_uniq(); + table_ref->schema_name = description->schema; + table_ref->table_name = description->table; + return std::move(table_ref); +} + +string TableRelation::GetAlias() { + return description->table; +} + +const vector &TableRelation::Columns() { + return description->columns; +} + +string TableRelation::ToString(idx_t depth) { + return RenderWhitespace(depth) + "Scan Table [" + description->table + "]"; +} + +static unique_ptr ParseCondition(ClientContext &context, const string &condition) { + if (!condition.empty()) { + auto expression_list = Parser::ParseExpressionList(condition, context.GetParserOptions()); + if (expression_list.size() != 1) { + throw ParserException("Expected a single expression as filter condition"); + } + return std::move(expression_list[0]); + } else { + return nullptr; + } +} + +void TableRelation::Update(const string &update_list, const string &condition) { + vector update_columns; + vector> expressions; + auto cond = ParseCondition(*context.GetContext(), condition); + Parser::ParseUpdateList(update_list, update_columns, expressions, context.GetContext()->GetParserOptions()); + auto update = make_shared(context, std::move(cond), description->schema, description->table, + std::move(update_columns), std::move(expressions)); + update->Execute(); +} + +void TableRelation::Delete(const string &condition) { + auto cond = ParseCondition(*context.GetContext(), condition); + auto del = make_shared(context, std::move(cond), description->schema, description->table); + del->Execute(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/update_relation.cpp b/src/duckdb/src/main/relation/update_relation.cpp new file mode 100644 index 00000000..152d04af --- /dev/null +++ b/src/duckdb/src/main/relation/update_relation.cpp @@ -0,0 +1,51 @@ +#include "duckdb/main/relation/update_relation.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" + +namespace duckdb { + +UpdateRelation::UpdateRelation(ClientContextWrapper &context, unique_ptr condition_p, + string schema_name_p, string table_name_p, vector update_columns_p, + vector> expressions_p) + : Relation(context, RelationType::UPDATE_RELATION), condition(std::move(condition_p)), + schema_name(std::move(schema_name_p)), table_name(std::move(table_name_p)), + update_columns(std::move(update_columns_p)), expressions(std::move(expressions_p)) { + D_ASSERT(update_columns.size() == expressions.size()); + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement UpdateRelation::Bind(Binder &binder) { + auto basetable = make_uniq(); + basetable->schema_name = schema_name; + basetable->table_name = table_name; + + UpdateStatement stmt; + stmt.set_info = make_uniq(); + + stmt.set_info->condition = condition ? condition->Copy() : nullptr; + stmt.table = std::move(basetable); + stmt.set_info->columns = update_columns; + for (auto &expr : expressions) { + stmt.set_info->expressions.push_back(expr->Copy()); + } + return binder.Bind(stmt.Cast()); +} + +const vector &UpdateRelation::Columns() { + return columns; +} + +string UpdateRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "UPDATE " + table_name + " SET\n"; + for (idx_t i = 0; i < expressions.size(); i++) { + str += update_columns[i] + " = " + expressions[i]->ToString() + "\n"; + } + if (condition) { + str += "WHERE " + condition->ToString() + "\n"; + } + return str; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/value_relation.cpp b/src/duckdb/src/main/relation/value_relation.cpp new file mode 100644 index 00000000..67dd9432 --- /dev/null +++ b/src/duckdb/src/main/relation/value_relation.cpp @@ -0,0 +1,91 @@ +#include "duckdb/main/relation/value_relation.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/parser.hpp" + +namespace duckdb { + +ValueRelation::ValueRelation(const std::shared_ptr &context, const vector> &values, + vector names_p, string alias_p) + : Relation(context, RelationType::VALUE_LIST_RELATION), names(std::move(names_p)), alias(std::move(alias_p)) { + // create constant expressions for the values + for (idx_t row_idx = 0; row_idx < values.size(); row_idx++) { + auto &list = values[row_idx]; + vector> expressions; + for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { + expressions.push_back(make_uniq(list[col_idx])); + } + this->expressions.push_back(std::move(expressions)); + } + context->TryBindRelation(*this, this->columns); +} + +ValueRelation::ValueRelation(const std::shared_ptr &context, const string &values_list, + vector names_p, string alias_p) + : Relation(context, RelationType::VALUE_LIST_RELATION), names(std::move(names_p)), alias(std::move(alias_p)) { + this->expressions = Parser::ParseValuesList(values_list, context->GetParserOptions()); + context->TryBindRelation(*this, this->columns); +} + +unique_ptr ValueRelation::GetQueryNode() { + auto result = make_uniq(); + result->select_list.push_back(make_uniq()); + result->from_table = GetTableRef(); + return std::move(result); +} + +unique_ptr ValueRelation::GetTableRef() { + auto table_ref = make_uniq(); + // set the expected types/names + if (columns.empty()) { + // no columns yet: only set up names + for (idx_t i = 0; i < names.size(); i++) { + table_ref->expected_names.push_back(names[i]); + } + } else { + for (idx_t i = 0; i < columns.size(); i++) { + table_ref->expected_names.push_back(columns[i].Name()); + table_ref->expected_types.push_back(columns[i].Type()); + D_ASSERT(names.size() == 0 || columns[i].Name() == names[i]); + } + } + // copy the expressions + for (auto &expr_list : expressions) { + vector> copied_list; + copied_list.reserve(expr_list.size()); + for (auto &expr : expr_list) { + copied_list.push_back(expr->Copy()); + } + table_ref->values.push_back(std::move(copied_list)); + } + table_ref->alias = GetAlias(); + return std::move(table_ref); +} + +string ValueRelation::GetAlias() { + return alias; +} + +const vector &ValueRelation::Columns() { + return columns; +} + +string ValueRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Values "; + for (idx_t row_idx = 0; row_idx < expressions.size(); row_idx++) { + auto &list = expressions[row_idx]; + str += row_idx > 0 ? ", (" : "("; + for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { + str += col_idx > 0 ? ", " : ""; + str += list[col_idx]->ToString(); + } + str += ")"; + } + str += "\n"; + return str; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/view_relation.cpp b/src/duckdb/src/main/relation/view_relation.cpp new file mode 100644 index 00000000..b432f5d3 --- /dev/null +++ b/src/duckdb/src/main/relation/view_relation.cpp @@ -0,0 +1,42 @@ +#include "duckdb/main/relation/view_relation.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/parser.hpp" + +namespace duckdb { + +ViewRelation::ViewRelation(const std::shared_ptr &context, string schema_name_p, string view_name_p) + : Relation(context, RelationType::VIEW_RELATION), schema_name(std::move(schema_name_p)), + view_name(std::move(view_name_p)) { + context->TryBindRelation(*this, this->columns); +} + +unique_ptr ViewRelation::GetQueryNode() { + auto result = make_uniq(); + result->select_list.push_back(make_uniq()); + result->from_table = GetTableRef(); + return std::move(result); +} + +unique_ptr ViewRelation::GetTableRef() { + auto table_ref = make_uniq(); + table_ref->schema_name = schema_name; + table_ref->table_name = view_name; + return std::move(table_ref); +} + +string ViewRelation::GetAlias() { + return view_name; +} + +const vector &ViewRelation::Columns() { + return columns; +} + +string ViewRelation::ToString(idx_t depth) { + return RenderWhitespace(depth) + "View [" + view_name + "]"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/write_csv_relation.cpp b/src/duckdb/src/main/relation/write_csv_relation.cpp new file mode 100644 index 00000000..016ea018 --- /dev/null +++ b/src/duckdb/src/main/relation/write_csv_relation.cpp @@ -0,0 +1,37 @@ +#include "duckdb/main/relation/write_csv_relation.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +WriteCSVRelation::WriteCSVRelation(shared_ptr child_p, string csv_file_p, + case_insensitive_map_t> options_p) + : Relation(child_p->context, RelationType::WRITE_CSV_RELATION), child(std::move(child_p)), + csv_file(std::move(csv_file_p)), options(std::move(options_p)) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement WriteCSVRelation::Bind(Binder &binder) { + CopyStatement copy; + copy.select_statement = child->GetQueryNode(); + auto info = make_uniq(); + info->is_from = false; + info->file_path = csv_file; + info->format = "csv"; + info->options = options; + copy.info = std::move(info); + return binder.Bind(copy.Cast()); +} + +const vector &WriteCSVRelation::Columns() { + return columns; +} + +string WriteCSVRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Write To CSV [" + csv_file + "]\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/relation/write_parquet_relation.cpp b/src/duckdb/src/main/relation/write_parquet_relation.cpp new file mode 100644 index 00000000..6cedc29c --- /dev/null +++ b/src/duckdb/src/main/relation/write_parquet_relation.cpp @@ -0,0 +1,37 @@ +#include "duckdb/main/relation/write_parquet_relation.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +WriteParquetRelation::WriteParquetRelation(shared_ptr child_p, string parquet_file_p, + case_insensitive_map_t> options_p) + : Relation(child_p->context, RelationType::WRITE_PARQUET_RELATION), child(std::move(child_p)), + parquet_file(std::move(parquet_file_p)), options(std::move(options_p)) { + context.GetContext()->TryBindRelation(*this, this->columns); +} + +BoundStatement WriteParquetRelation::Bind(Binder &binder) { + CopyStatement copy; + copy.select_statement = child->GetQueryNode(); + auto info = make_uniq(); + info->is_from = false; + info->file_path = parquet_file; + info->format = "parquet"; + info->options = options; + copy.info = std::move(info); + return binder.Bind(copy.Cast()); +} + +const vector &WriteParquetRelation::Columns() { + return columns; +} + +string WriteParquetRelation::ToString(idx_t depth) { + string str = RenderWhitespace(depth) + "Write To Parquet [" + parquet_file + "]\n"; + return str + child->ToString(depth + 1); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/settings/settings.cpp b/src/duckdb/src/main/settings/settings.cpp new file mode 100644 index 00000000..1fe56fed --- /dev/null +++ b/src/duckdb/src/main/settings/settings.cpp @@ -0,0 +1,1206 @@ +#include "duckdb/main/settings.hpp" + +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/storage_manager.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Access Mode +//===--------------------------------------------------------------------===// +void AccessModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + if (db) { + throw InvalidInputException("Cannot change access_mode setting while database is running - it must be set when " + "opening or attaching the database"); + } + auto parameter = StringUtil::Lower(input.ToString()); + if (parameter == "automatic") { + config.options.access_mode = AccessMode::AUTOMATIC; + } else if (parameter == "read_only") { + config.options.access_mode = AccessMode::READ_ONLY; + } else if (parameter == "read_write") { + config.options.access_mode = AccessMode::READ_WRITE; + } else { + throw InvalidInputException( + "Unrecognized parameter for option ACCESS_MODE \"%s\". Expected READ_ONLY or READ_WRITE.", parameter); + } +} + +void AccessModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.access_mode = DBConfig().options.access_mode; +} + +Value AccessModeSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + switch (config.options.access_mode) { + case AccessMode::AUTOMATIC: + return "automatic"; + case AccessMode::READ_ONLY: + return "read_only"; + case AccessMode::READ_WRITE: + return "read_write"; + default: + throw InternalException("Unknown access mode setting"); + } +} + +//===--------------------------------------------------------------------===// +// Checkpoint Threshold +//===--------------------------------------------------------------------===// +void CheckpointThresholdSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + idx_t new_limit = DBConfig::ParseMemoryLimit(input.ToString()); + config.options.checkpoint_wal_size = new_limit; +} + +void CheckpointThresholdSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.checkpoint_wal_size = DBConfig().options.checkpoint_wal_size; +} + +Value CheckpointThresholdSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(StringUtil::BytesToHumanReadableString(config.options.checkpoint_wal_size)); +} + +//===--------------------------------------------------------------------===// +// Debug Checkpoint Abort +//===--------------------------------------------------------------------===// +void DebugCheckpointAbort::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto checkpoint_abort = StringUtil::Lower(input.ToString()); + if (checkpoint_abort == "none") { + config.options.checkpoint_abort = CheckpointAbort::NO_ABORT; + } else if (checkpoint_abort == "before_truncate") { + config.options.checkpoint_abort = CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE; + } else if (checkpoint_abort == "before_header") { + config.options.checkpoint_abort = CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER; + } else if (checkpoint_abort == "after_free_list_write") { + config.options.checkpoint_abort = CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE; + } else { + throw ParserException( + "Unrecognized option for PRAGMA debug_checkpoint_abort, expected none, before_truncate or before_header"); + } +} + +void DebugCheckpointAbort::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.checkpoint_abort = DBConfig().options.checkpoint_abort; +} + +Value DebugCheckpointAbort::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(*context.db); + auto setting = config.options.checkpoint_abort; + switch (setting) { + case CheckpointAbort::NO_ABORT: + return "none"; + case CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE: + return "before_truncate"; + case CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER: + return "before_header"; + case CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE: + return "after_free_list_write"; + default: + throw InternalException("Type not implemented for CheckpointAbort"); + } +} + +//===--------------------------------------------------------------------===// +// Debug Force External +//===--------------------------------------------------------------------===// +void DebugForceExternal::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).force_external = ClientConfig().force_external; +} + +void DebugForceExternal::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).force_external = input.GetValue(); +} + +Value DebugForceExternal::GetSetting(ClientContext &context) { + return Value::BOOLEAN(ClientConfig::GetConfig(context).force_external); +} + +//===--------------------------------------------------------------------===// +// Debug Force NoCrossProduct +//===--------------------------------------------------------------------===// +void DebugForceNoCrossProduct::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).force_no_cross_product = ClientConfig().force_no_cross_product; +} + +void DebugForceNoCrossProduct::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).force_no_cross_product = input.GetValue(); +} + +Value DebugForceNoCrossProduct::GetSetting(ClientContext &context) { + return Value::BOOLEAN(ClientConfig::GetConfig(context).force_no_cross_product); +} + +//===--------------------------------------------------------------------===// +// Ordered Aggregate Threshold +//===--------------------------------------------------------------------===// +void OrderedAggregateThreshold::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).ordered_aggregate_threshold = ClientConfig().ordered_aggregate_threshold; +} + +void OrderedAggregateThreshold::SetLocal(ClientContext &context, const Value &input) { + const auto param = input.GetValue(); + if (param <= 0) { + throw ParserException("Invalid option for PRAGMA ordered_aggregate_threshold, value must be positive"); + } + ClientConfig::GetConfig(context).ordered_aggregate_threshold = param; +} + +Value OrderedAggregateThreshold::GetSetting(ClientContext &context) { + return Value::UBIGINT(ClientConfig::GetConfig(context).ordered_aggregate_threshold); +} + +//===--------------------------------------------------------------------===// +// Debug Window Mode +//===--------------------------------------------------------------------===// +void DebugWindowMode::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto param = StringUtil::Lower(input.ToString()); + if (param == "window") { + config.options.window_mode = WindowAggregationMode::WINDOW; + } else if (param == "combine") { + config.options.window_mode = WindowAggregationMode::COMBINE; + } else if (param == "separate") { + config.options.window_mode = WindowAggregationMode::SEPARATE; + } else { + throw ParserException("Unrecognized option for PRAGMA debug_window_mode, expected window, combine or separate"); + } +} + +void DebugWindowMode::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.window_mode = DBConfig().options.window_mode; +} + +Value DebugWindowMode::GetSetting(ClientContext &context) { + return Value(); +} + +//===--------------------------------------------------------------------===// +// Debug AsOf Join +//===--------------------------------------------------------------------===// +void DebugAsOfIEJoin::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).force_asof_iejoin = ClientConfig().force_asof_iejoin; +} + +void DebugAsOfIEJoin::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).force_asof_iejoin = input.GetValue(); +} + +Value DebugAsOfIEJoin::GetSetting(ClientContext &context) { + return Value::BOOLEAN(ClientConfig::GetConfig(context).force_asof_iejoin); +} + +//===--------------------------------------------------------------------===// +// Prefer Range Joins +//===--------------------------------------------------------------------===// +void PreferRangeJoins::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).prefer_range_joins = ClientConfig().prefer_range_joins; +} + +void PreferRangeJoins::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).prefer_range_joins = input.GetValue(); +} + +Value PreferRangeJoins::GetSetting(ClientContext &context) { + return Value::BOOLEAN(ClientConfig::GetConfig(context).prefer_range_joins); +} + +//===--------------------------------------------------------------------===// +// Default Collation +//===--------------------------------------------------------------------===// +void DefaultCollationSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto parameter = StringUtil::Lower(input.ToString()); + config.options.collation = parameter; +} + +void DefaultCollationSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.collation = DBConfig().options.collation; +} + +void DefaultCollationSetting::ResetLocal(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + config.options.collation = DBConfig().options.collation; +} + +void DefaultCollationSetting::SetLocal(ClientContext &context, const Value &input) { + auto parameter = input.ToString(); + // bind the collation to verify that it exists + ExpressionBinder::TestCollation(context, parameter); + auto &config = DBConfig::GetConfig(context); + config.options.collation = parameter; +} + +Value DefaultCollationSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(config.options.collation); +} + +//===--------------------------------------------------------------------===// +// Default Order +//===--------------------------------------------------------------------===// +void DefaultOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto parameter = StringUtil::Lower(input.ToString()); + if (parameter == "ascending" || parameter == "asc") { + config.options.default_order_type = OrderType::ASCENDING; + } else if (parameter == "descending" || parameter == "desc") { + config.options.default_order_type = OrderType::DESCENDING; + } else { + throw InvalidInputException("Unrecognized parameter for option DEFAULT_ORDER \"%s\". Expected ASC or DESC.", + parameter); + } +} + +void DefaultOrderSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.default_order_type = DBConfig().options.default_order_type; +} + +Value DefaultOrderSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + switch (config.options.default_order_type) { + case OrderType::ASCENDING: + return "asc"; + case OrderType::DESCENDING: + return "desc"; + default: + throw InternalException("Unknown order type setting"); + } +} + +//===--------------------------------------------------------------------===// +// Default Null Order +//===--------------------------------------------------------------------===// +void DefaultNullOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto parameter = StringUtil::Lower(input.ToString()); + + if (parameter == "nulls_first" || parameter == "nulls first" || parameter == "null first" || parameter == "first") { + config.options.default_null_order = DefaultOrderByNullType::NULLS_FIRST; + } else if (parameter == "nulls_last" || parameter == "nulls last" || parameter == "null last" || + parameter == "last") { + config.options.default_null_order = DefaultOrderByNullType::NULLS_LAST; + } else if (parameter == "nulls_first_on_asc_last_on_desc" || parameter == "sqlite" || parameter == "mysql") { + config.options.default_null_order = DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC; + } else if (parameter == "nulls_last_on_asc_first_on_desc" || parameter == "postgres") { + config.options.default_null_order = DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC; + } else { + throw ParserException("Unrecognized parameter for option NULL_ORDER \"%s\", expected either NULLS FIRST, NULLS " + "LAST, SQLite, MySQL or Postgres", + parameter); + } +} + +void DefaultNullOrderSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.default_null_order = DBConfig().options.default_null_order; +} + +Value DefaultNullOrderSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + switch (config.options.default_null_order) { + case DefaultOrderByNullType::NULLS_FIRST: + return "nulls_first"; + case DefaultOrderByNullType::NULLS_LAST: + return "nulls_last"; + case DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC: + return "nulls_first_on_asc_last_on_desc"; + case DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC: + return "nulls_last_on_asc_first_on_desc"; + default: + throw InternalException("Unknown null order setting"); + } +} + +//===--------------------------------------------------------------------===// +// Disabled File Systems +//===--------------------------------------------------------------------===// +void DisabledFileSystemsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + if (!db) { + throw InternalException("disabled_filesystems can only be set in an active database"); + } + auto &fs = FileSystem::GetFileSystem(*db); + auto list = StringUtil::Split(input.ToString(), ","); + fs.SetDisabledFileSystems(list); +} + +void DisabledFileSystemsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + if (!db) { + throw InternalException("disabled_filesystems can only be set in an active database"); + } + auto &fs = FileSystem::GetFileSystem(*db); + fs.SetDisabledFileSystems(vector()); +} + +Value DisabledFileSystemsSetting::GetSetting(ClientContext &context) { + return Value(""); +} + +//===--------------------------------------------------------------------===// +// Disabled Optimizer +//===--------------------------------------------------------------------===// +void DisabledOptimizersSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto list = StringUtil::Split(input.ToString(), ","); + set disabled_optimizers; + for (auto &entry : list) { + auto param = StringUtil::Lower(entry); + StringUtil::Trim(param); + if (param.empty()) { + continue; + } + disabled_optimizers.insert(OptimizerTypeFromString(param)); + } + config.options.disabled_optimizers = std::move(disabled_optimizers); +} + +void DisabledOptimizersSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.disabled_optimizers = DBConfig().options.disabled_optimizers; +} + +Value DisabledOptimizersSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + string result; + for (auto &optimizer : config.options.disabled_optimizers) { + if (!result.empty()) { + result += ","; + } + result += OptimizerTypeToString(optimizer); + } + return Value(result); +} + +//===--------------------------------------------------------------------===// +// Enable External Access +//===--------------------------------------------------------------------===// +void EnableExternalAccessSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_value = input.GetValue(); + if (db && new_value) { + throw InvalidInputException("Cannot change enable_external_access setting while database is running"); + } + config.options.enable_external_access = new_value; +} + +void EnableExternalAccessSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + if (db) { + throw InvalidInputException("Cannot change enable_external_access setting while database is running"); + } + config.options.enable_external_access = DBConfig().options.enable_external_access; +} + +Value EnableExternalAccessSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.enable_external_access); +} + +//===--------------------------------------------------------------------===// +// Enable FSST Vectors +//===--------------------------------------------------------------------===// +void EnableFSSTVectors::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.enable_fsst_vectors = input.GetValue(); +} + +void EnableFSSTVectors::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.enable_fsst_vectors = DBConfig().options.enable_fsst_vectors; +} + +Value EnableFSSTVectors::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.enable_fsst_vectors); +} + +//===--------------------------------------------------------------------===// +// Allow Unsigned Extensions +//===--------------------------------------------------------------------===// +void AllowUnsignedExtensionsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_value = input.GetValue(); + if (db && new_value) { + throw InvalidInputException("Cannot change allow_unsigned_extensions setting while database is running"); + } + config.options.allow_unsigned_extensions = new_value; +} + +void AllowUnsignedExtensionsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + if (db) { + throw InvalidInputException("Cannot change allow_unsigned_extensions setting while database is running"); + } + config.options.allow_unsigned_extensions = DBConfig().options.allow_unsigned_extensions; +} + +Value AllowUnsignedExtensionsSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.allow_unsigned_extensions); +} + +//===--------------------------------------------------------------------===// +// Enable Object Cache +//===--------------------------------------------------------------------===// +void EnableObjectCacheSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.object_cache_enable = input.GetValue(); +} + +void EnableObjectCacheSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.object_cache_enable = DBConfig().options.object_cache_enable; +} + +Value EnableObjectCacheSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.object_cache_enable); +} + +//===--------------------------------------------------------------------===// +// Enable HTTP Metadata Cache +//===--------------------------------------------------------------------===// +void EnableHTTPMetadataCacheSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.http_metadata_cache_enable = input.GetValue(); +} + +void EnableHTTPMetadataCacheSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.http_metadata_cache_enable = DBConfig().options.http_metadata_cache_enable; +} + +Value EnableHTTPMetadataCacheSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.http_metadata_cache_enable); +} + +//===--------------------------------------------------------------------===// +// Enable Profiling +//===--------------------------------------------------------------------===// +void EnableProfilingSetting::ResetLocal(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + config.profiler_print_format = ClientConfig().profiler_print_format; + config.enable_profiler = ClientConfig().enable_profiler; + config.emit_profiler_output = ClientConfig().emit_profiler_output; +} + +void EnableProfilingSetting::SetLocal(ClientContext &context, const Value &input) { + auto parameter = StringUtil::Lower(input.ToString()); + + auto &config = ClientConfig::GetConfig(context); + if (parameter == "json") { + config.profiler_print_format = ProfilerPrintFormat::JSON; + } else if (parameter == "query_tree") { + config.profiler_print_format = ProfilerPrintFormat::QUERY_TREE; + } else if (parameter == "query_tree_optimizer") { + config.profiler_print_format = ProfilerPrintFormat::QUERY_TREE_OPTIMIZER; + } else { + throw ParserException( + "Unrecognized print format %s, supported formats: [json, query_tree, query_tree_optimizer]", parameter); + } + config.enable_profiler = true; + config.emit_profiler_output = true; +} + +Value EnableProfilingSetting::GetSetting(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + if (!config.enable_profiler) { + return Value(); + } + switch (config.profiler_print_format) { + case ProfilerPrintFormat::JSON: + return Value("json"); + case ProfilerPrintFormat::QUERY_TREE: + return Value("query_tree"); + case ProfilerPrintFormat::QUERY_TREE_OPTIMIZER: + return Value("query_tree_optimizer"); + default: + throw InternalException("Unsupported profiler print format"); + } +} + +//===--------------------------------------------------------------------===// +// Custom Extension Repository +//===--------------------------------------------------------------------===// +void CustomExtensionRepository::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).custom_extension_repo = ClientConfig().custom_extension_repo; +} + +void CustomExtensionRepository::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).custom_extension_repo = StringUtil::Lower(input.ToString()); +} + +Value CustomExtensionRepository::GetSetting(ClientContext &context) { + return Value(ClientConfig::GetConfig(context).custom_extension_repo); +} + +//===--------------------------------------------------------------------===// +// Autoload Extension Repository +//===--------------------------------------------------------------------===// +void AutoloadExtensionRepository::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).autoinstall_extension_repo = ClientConfig().autoinstall_extension_repo; +} + +void AutoloadExtensionRepository::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).autoinstall_extension_repo = StringUtil::Lower(input.ToString()); +} + +Value AutoloadExtensionRepository::GetSetting(ClientContext &context) { + return Value(ClientConfig::GetConfig(context).autoinstall_extension_repo); +} + +//===--------------------------------------------------------------------===// +// Autoinstall Known Extensions +//===--------------------------------------------------------------------===// +void AutoinstallKnownExtensions::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.autoinstall_known_extensions = input.GetValue(); +} + +void AutoinstallKnownExtensions::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.autoinstall_known_extensions = DBConfig().options.autoinstall_known_extensions; +} + +Value AutoinstallKnownExtensions::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.autoinstall_known_extensions); +} +//===--------------------------------------------------------------------===// +// Autoload Known Extensions +//===--------------------------------------------------------------------===// +void AutoloadKnownExtensions::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.autoload_known_extensions = input.GetValue(); +} + +void AutoloadKnownExtensions::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.autoload_known_extensions = DBConfig().options.autoload_known_extensions; +} + +Value AutoloadKnownExtensions::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.autoload_known_extensions); +} + +//===--------------------------------------------------------------------===// +// Enable Progress Bar +//===--------------------------------------------------------------------===// +void EnableProgressBarSetting::ResetLocal(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + ProgressBar::SystemOverrideCheck(config); + config.enable_progress_bar = ClientConfig().enable_progress_bar; +} + +void EnableProgressBarSetting::SetLocal(ClientContext &context, const Value &input) { + auto &config = ClientConfig::GetConfig(context); + ProgressBar::SystemOverrideCheck(config); + config.enable_progress_bar = input.GetValue(); +} + +Value EnableProgressBarSetting::GetSetting(ClientContext &context) { + return Value::BOOLEAN(ClientConfig::GetConfig(context).enable_progress_bar); +} + +//===--------------------------------------------------------------------===// +// Enable Progress Bar Print +//===--------------------------------------------------------------------===// +void EnableProgressBarPrintSetting::SetLocal(ClientContext &context, const Value &input) { + auto &config = ClientConfig::GetConfig(context); + ProgressBar::SystemOverrideCheck(config); + config.print_progress_bar = input.GetValue(); +} + +void EnableProgressBarPrintSetting::ResetLocal(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + ProgressBar::SystemOverrideCheck(config); + config.print_progress_bar = ClientConfig().print_progress_bar; +} + +Value EnableProgressBarPrintSetting::GetSetting(ClientContext &context) { + return Value::BOOLEAN(ClientConfig::GetConfig(context).print_progress_bar); +} + +//===--------------------------------------------------------------------===// +// Explain Output +//===--------------------------------------------------------------------===// +void ExplainOutputSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).explain_output_type = ClientConfig().explain_output_type; +} + +void ExplainOutputSetting::SetLocal(ClientContext &context, const Value &input) { + auto parameter = StringUtil::Lower(input.ToString()); + if (parameter == "all") { + ClientConfig::GetConfig(context).explain_output_type = ExplainOutputType::ALL; + } else if (parameter == "optimized_only") { + ClientConfig::GetConfig(context).explain_output_type = ExplainOutputType::OPTIMIZED_ONLY; + } else if (parameter == "physical_only") { + ClientConfig::GetConfig(context).explain_output_type = ExplainOutputType::PHYSICAL_ONLY; + } else { + throw ParserException("Unrecognized output type \"%s\", expected either ALL, OPTIMIZED_ONLY or PHYSICAL_ONLY", + parameter); + } +} + +Value ExplainOutputSetting::GetSetting(ClientContext &context) { + switch (ClientConfig::GetConfig(context).explain_output_type) { + case ExplainOutputType::ALL: + return "all"; + case ExplainOutputType::OPTIMIZED_ONLY: + return "optimized_only"; + case ExplainOutputType::PHYSICAL_ONLY: + return "physical_only"; + default: + throw InternalException("Unrecognized explain output type"); + } +} + +//===--------------------------------------------------------------------===// +// Extension Directory Setting +//===--------------------------------------------------------------------===// +void ExtensionDirectorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_directory = input.ToString(); + config.options.extension_directory = input.ToString(); +} + +void ExtensionDirectorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.extension_directory = DBConfig().options.extension_directory; +} + +Value ExtensionDirectorySetting::GetSetting(ClientContext &context) { + return Value(DBConfig::GetConfig(context).options.extension_directory); +} + +//===--------------------------------------------------------------------===// +// External Threads Setting +//===--------------------------------------------------------------------===// +void ExternalThreadsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.external_threads = input.GetValue(); +} + +void ExternalThreadsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.external_threads = DBConfig().options.external_threads; +} + +Value ExternalThreadsSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BIGINT(config.options.external_threads); +} + +//===--------------------------------------------------------------------===// +// File Search Path +//===--------------------------------------------------------------------===// +void FileSearchPathSetting::ResetLocal(ClientContext &context) { + auto &client_data = ClientData::Get(context); + client_data.file_search_path.clear(); +} + +void FileSearchPathSetting::SetLocal(ClientContext &context, const Value &input) { + auto parameter = input.ToString(); + auto &client_data = ClientData::Get(context); + client_data.file_search_path = parameter; +} + +Value FileSearchPathSetting::GetSetting(ClientContext &context) { + auto &client_data = ClientData::Get(context); + return Value(client_data.file_search_path); +} + +//===--------------------------------------------------------------------===// +// Force Compression +//===--------------------------------------------------------------------===// +void ForceCompressionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto compression = StringUtil::Lower(input.ToString()); + if (compression == "none" || compression == "auto") { + config.options.force_compression = CompressionType::COMPRESSION_AUTO; + } else { + auto compression_type = CompressionTypeFromString(compression); + if (compression_type == CompressionType::COMPRESSION_AUTO) { + auto compression_types = StringUtil::Join(ListCompressionTypes(), ", "); + throw ParserException("Unrecognized option for PRAGMA force_compression, expected %s", compression_types); + } + config.options.force_compression = compression_type; + } +} + +void ForceCompressionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.force_compression = DBConfig().options.force_compression; +} + +Value ForceCompressionSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(*context.db); + return CompressionTypeToString(config.options.force_compression); +} + +//===--------------------------------------------------------------------===// +// Force Bitpacking mode +//===--------------------------------------------------------------------===// +void ForceBitpackingModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto mode_str = StringUtil::Lower(input.ToString()); + auto mode = BitpackingModeFromString(mode_str); + if (mode == BitpackingMode::INVALID) { + throw ParserException("Unrecognized option for force_bitpacking_mode, expected none, constant, constant_delta, " + "delta_for, or for"); + } + config.options.force_bitpacking_mode = mode; +} + +void ForceBitpackingModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.force_bitpacking_mode = DBConfig().options.force_bitpacking_mode; +} + +Value ForceBitpackingModeSetting::GetSetting(ClientContext &context) { + return Value(BitpackingModeToString(context.db->config.options.force_bitpacking_mode)); +} + +//===--------------------------------------------------------------------===// +// Home Directory +//===--------------------------------------------------------------------===// +void HomeDirectorySetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).home_directory = ClientConfig().home_directory; +} + +void HomeDirectorySetting::SetLocal(ClientContext &context, const Value &input) { + auto &config = ClientConfig::GetConfig(context); + config.home_directory = input.IsNull() ? string() : input.ToString(); +} + +Value HomeDirectorySetting::GetSetting(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + return Value(config.home_directory); +} + +//===--------------------------------------------------------------------===// +// Integer Division +//===--------------------------------------------------------------------===// +void IntegerDivisionSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).integer_division = ClientConfig().integer_division; +} + +void IntegerDivisionSetting::SetLocal(ClientContext &context, const Value &input) { + auto &config = ClientConfig::GetConfig(context); + config.integer_division = input.GetValue(); +} + +Value IntegerDivisionSetting::GetSetting(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + return Value(config.integer_division); +} + +//===--------------------------------------------------------------------===// +// Log Query Path +//===--------------------------------------------------------------------===// +void LogQueryPathSetting::ResetLocal(ClientContext &context) { + auto &client_data = ClientData::Get(context); + // TODO: verify that this does the right thing + client_data.log_query_writer = std::move(ClientData(context).log_query_writer); +} + +void LogQueryPathSetting::SetLocal(ClientContext &context, const Value &input) { + auto &client_data = ClientData::Get(context); + auto path = input.ToString(); + if (path.empty()) { + // empty path: clean up query writer + client_data.log_query_writer = nullptr; + } else { + client_data.log_query_writer = make_uniq(FileSystem::GetFileSystem(context), path, + BufferedFileWriter::DEFAULT_OPEN_FLAGS); + } +} + +Value LogQueryPathSetting::GetSetting(ClientContext &context) { + auto &client_data = ClientData::Get(context); + return client_data.log_query_writer ? Value(client_data.log_query_writer->path) : Value(); +} + +//===--------------------------------------------------------------------===// +// Lock Configuration +//===--------------------------------------------------------------------===// +void LockConfigurationSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_value = input.GetValue(); + config.options.lock_configuration = new_value; +} + +void LockConfigurationSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.lock_configuration = DBConfig().options.lock_configuration; +} + +Value LockConfigurationSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.lock_configuration); +} + +//===--------------------------------------------------------------------===// +// Immediate Transaction Mode +//===--------------------------------------------------------------------===// +void ImmediateTransactionModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.immediate_transaction_mode = BooleanValue::Get(input); +} + +void ImmediateTransactionModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.immediate_transaction_mode = DBConfig().options.immediate_transaction_mode; +} + +Value ImmediateTransactionModeSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.immediate_transaction_mode); +} + +//===--------------------------------------------------------------------===// +// Maximum Expression Depth +//===--------------------------------------------------------------------===// +void MaximumExpressionDepthSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).max_expression_depth = ClientConfig().max_expression_depth; +} + +void MaximumExpressionDepthSetting::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).max_expression_depth = input.GetValue(); +} + +Value MaximumExpressionDepthSetting::GetSetting(ClientContext &context) { + return Value::UBIGINT(ClientConfig::GetConfig(context).max_expression_depth); +} + +//===--------------------------------------------------------------------===// +// Maximum Memory +//===--------------------------------------------------------------------===// +void MaximumMemorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.maximum_memory = DBConfig::ParseMemoryLimit(input.ToString()); + if (db) { + BufferManager::GetBufferManager(*db).SetLimit(config.options.maximum_memory); + } +} + +void MaximumMemorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.SetDefaultMaxMemory(); +} + +Value MaximumMemorySetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(StringUtil::BytesToHumanReadableString(config.options.maximum_memory)); +} + +//===--------------------------------------------------------------------===// +// Password Setting +//===--------------------------------------------------------------------===// +void PasswordSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + // nop +} + +void PasswordSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + // nop +} + +Value PasswordSetting::GetSetting(ClientContext &context) { + return Value(); +} + +//===--------------------------------------------------------------------===// +// Perfect Hash Threshold +//===--------------------------------------------------------------------===// +void PerfectHashThresholdSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).perfect_ht_threshold = ClientConfig().perfect_ht_threshold; +} + +void PerfectHashThresholdSetting::SetLocal(ClientContext &context, const Value &input) { + auto bits = input.GetValue(); + if (bits < 0 || bits > 32) { + throw ParserException("Perfect HT threshold out of range: should be within range 0 - 32"); + } + ClientConfig::GetConfig(context).perfect_ht_threshold = bits; +} + +Value PerfectHashThresholdSetting::GetSetting(ClientContext &context) { + return Value::BIGINT(ClientConfig::GetConfig(context).perfect_ht_threshold); +} + +//===--------------------------------------------------------------------===// +// Pivot Filter Threshold +//===--------------------------------------------------------------------===// +void PivotFilterThreshold::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).pivot_filter_threshold = ClientConfig().pivot_filter_threshold; +} + +void PivotFilterThreshold::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).pivot_filter_threshold = input.GetValue(); +} + +Value PivotFilterThreshold::GetSetting(ClientContext &context) { + return Value::BIGINT(ClientConfig::GetConfig(context).pivot_filter_threshold); +} + +//===--------------------------------------------------------------------===// +// Pivot Limit +//===--------------------------------------------------------------------===// +void PivotLimitSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).pivot_limit = ClientConfig().pivot_limit; +} + +void PivotLimitSetting::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).pivot_limit = input.GetValue(); +} + +Value PivotLimitSetting::GetSetting(ClientContext &context) { + return Value::BIGINT(ClientConfig::GetConfig(context).pivot_limit); +} + +//===--------------------------------------------------------------------===// +// PreserveIdentifierCase +//===--------------------------------------------------------------------===// +void PreserveIdentifierCase::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).preserve_identifier_case = ClientConfig().preserve_identifier_case; +} + +void PreserveIdentifierCase::SetLocal(ClientContext &context, const Value &input) { + ClientConfig::GetConfig(context).preserve_identifier_case = input.GetValue(); +} + +Value PreserveIdentifierCase::GetSetting(ClientContext &context) { + return Value::BOOLEAN(ClientConfig::GetConfig(context).preserve_identifier_case); +} + +//===--------------------------------------------------------------------===// +// PreserveInsertionOrder +//===--------------------------------------------------------------------===// +void PreserveInsertionOrder::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.preserve_insertion_order = input.GetValue(); +} + +void PreserveInsertionOrder::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.preserve_insertion_order = DBConfig().options.preserve_insertion_order; +} + +Value PreserveInsertionOrder::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BOOLEAN(config.options.preserve_insertion_order); +} + +//===--------------------------------------------------------------------===// +// ExportLargeBufferArrow +//===--------------------------------------------------------------------===// +void ExportLargeBufferArrow::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto export_large_buffers_arrow = input.GetValue(); + + config.options.arrow_offset_size = export_large_buffers_arrow ? ArrowOffsetSize::LARGE : ArrowOffsetSize::REGULAR; +} + +void ExportLargeBufferArrow::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.arrow_offset_size = DBConfig().options.arrow_offset_size; +} + +Value ExportLargeBufferArrow::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + bool export_large_buffers_arrow = config.options.arrow_offset_size == ArrowOffsetSize::LARGE; + return Value::BOOLEAN(export_large_buffers_arrow); +} + +//===--------------------------------------------------------------------===// +// Profiler History Size +//===--------------------------------------------------------------------===// +void ProfilerHistorySize::ResetLocal(ClientContext &context) { + auto &client_data = ClientData::Get(context); + client_data.query_profiler_history->ResetProfilerHistorySize(); +} + +void ProfilerHistorySize::SetLocal(ClientContext &context, const Value &input) { + auto size = input.GetValue(); + if (size <= 0) { + throw ParserException("Size should be >= 0"); + } + auto &client_data = ClientData::Get(context); + client_data.query_profiler_history->SetProfilerHistorySize(size); +} + +Value ProfilerHistorySize::GetSetting(ClientContext &context) { + return Value(); +} + +//===--------------------------------------------------------------------===// +// Profile Output +//===--------------------------------------------------------------------===// +void ProfileOutputSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).profiler_save_location = ClientConfig().profiler_save_location; +} + +void ProfileOutputSetting::SetLocal(ClientContext &context, const Value &input) { + auto &config = ClientConfig::GetConfig(context); + auto parameter = input.ToString(); + config.profiler_save_location = parameter; +} + +Value ProfileOutputSetting::GetSetting(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + return Value(config.profiler_save_location); +} + +//===--------------------------------------------------------------------===// +// Profiling Mode +//===--------------------------------------------------------------------===// +void ProfilingModeSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).enable_profiler = ClientConfig().enable_profiler; + ClientConfig::GetConfig(context).enable_detailed_profiling = ClientConfig().enable_detailed_profiling; + ClientConfig::GetConfig(context).emit_profiler_output = ClientConfig().emit_profiler_output; +} + +void ProfilingModeSetting::SetLocal(ClientContext &context, const Value &input) { + auto parameter = StringUtil::Lower(input.ToString()); + auto &config = ClientConfig::GetConfig(context); + if (parameter == "standard") { + config.enable_profiler = true; + config.enable_detailed_profiling = false; + config.emit_profiler_output = true; + } else if (parameter == "detailed") { + config.enable_profiler = true; + config.enable_detailed_profiling = true; + config.emit_profiler_output = true; + } else { + throw ParserException("Unrecognized profiling mode \"%s\", supported formats: [standard, detailed]", parameter); + } +} + +Value ProfilingModeSetting::GetSetting(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + if (!config.enable_profiler) { + return Value(); + } + return Value(config.enable_detailed_profiling ? "detailed" : "standard"); +} + +//===--------------------------------------------------------------------===// +// Progress Bar Time +//===--------------------------------------------------------------------===// +void ProgressBarTimeSetting::ResetLocal(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + ProgressBar::SystemOverrideCheck(config); + config.wait_time = ClientConfig().wait_time; + config.enable_progress_bar = ClientConfig().enable_progress_bar; +} + +void ProgressBarTimeSetting::SetLocal(ClientContext &context, const Value &input) { + auto &config = ClientConfig::GetConfig(context); + ProgressBar::SystemOverrideCheck(config); + config.wait_time = input.GetValue(); + config.enable_progress_bar = true; +} + +Value ProgressBarTimeSetting::GetSetting(ClientContext &context) { + return Value::BIGINT(ClientConfig::GetConfig(context).wait_time); +} + +//===--------------------------------------------------------------------===// +// Schema +//===--------------------------------------------------------------------===// +void SchemaSetting::ResetLocal(ClientContext &context) { + // FIXME: catalog_search_path is controlled by both SchemaSetting and SearchPathSetting + auto &client_data = ClientData::Get(context); + client_data.catalog_search_path->Reset(); +} + +void SchemaSetting::SetLocal(ClientContext &context, const Value &input) { + auto parameter = input.ToString(); + auto &client_data = ClientData::Get(context); + client_data.catalog_search_path->Set(CatalogSearchEntry::Parse(parameter), CatalogSetPathType::SET_SCHEMA); +} + +Value SchemaSetting::GetSetting(ClientContext &context) { + auto &client_data = ClientData::Get(context); + return client_data.catalog_search_path->GetDefault().schema; +} + +//===--------------------------------------------------------------------===// +// Search Path +//===--------------------------------------------------------------------===// +void SearchPathSetting::ResetLocal(ClientContext &context) { + // FIXME: catalog_search_path is controlled by both SchemaSetting and SearchPathSetting + auto &client_data = ClientData::Get(context); + client_data.catalog_search_path->Reset(); +} + +void SearchPathSetting::SetLocal(ClientContext &context, const Value &input) { + auto parameter = input.ToString(); + auto &client_data = ClientData::Get(context); + client_data.catalog_search_path->Set(CatalogSearchEntry::ParseList(parameter), CatalogSetPathType::SET_SCHEMAS); +} + +Value SearchPathSetting::GetSetting(ClientContext &context) { + auto &client_data = ClientData::Get(context); + auto &set_paths = client_data.catalog_search_path->GetSetPaths(); + return Value(CatalogSearchEntry::ListToString(set_paths)); +} + +//===--------------------------------------------------------------------===// +// Temp Directory +//===--------------------------------------------------------------------===// +void TempDirectorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.temporary_directory = input.ToString(); + config.options.use_temporary_directory = !config.options.temporary_directory.empty(); + if (db) { + auto &buffer_manager = BufferManager::GetBufferManager(*db); + buffer_manager.SetTemporaryDirectory(config.options.temporary_directory); + } +} + +void TempDirectorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.temporary_directory = DBConfig().options.temporary_directory; + config.options.use_temporary_directory = DBConfig().options.use_temporary_directory; + if (db) { + auto &buffer_manager = BufferManager::GetBufferManager(*db); + buffer_manager.SetTemporaryDirectory(config.options.temporary_directory); + } +} + +Value TempDirectorySetting::GetSetting(ClientContext &context) { + auto &buffer_manager = BufferManager::GetBufferManager(context); + return Value(buffer_manager.GetTemporaryDirectory()); +} + +//===--------------------------------------------------------------------===// +// Threads Setting +//===--------------------------------------------------------------------===// +void ThreadsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.maximum_threads = input.GetValue(); + if (db) { + TaskScheduler::GetScheduler(*db).SetThreads(config.options.maximum_threads); + } +} + +void ThreadsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.SetDefaultMaxThreads(); +} + +Value ThreadsSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BIGINT(config.options.maximum_threads); +} + +//===--------------------------------------------------------------------===// +// Username Setting +//===--------------------------------------------------------------------===// +void UsernameSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + // nop +} + +void UsernameSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + // nop +} + +Value UsernameSetting::GetSetting(ClientContext &context) { + return Value(); +} + +//===--------------------------------------------------------------------===// +// Allocator Flush Threshold +//===--------------------------------------------------------------------===// +void FlushAllocatorSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.allocator_flush_threshold = DBConfig::ParseMemoryLimit(input.ToString()); + if (db) { + TaskScheduler::GetScheduler(*db).SetAllocatorFlushTreshold(config.options.allocator_flush_threshold); + } +} + +void FlushAllocatorSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.allocator_flush_threshold = DBConfig().options.allocator_flush_threshold; + if (db) { + TaskScheduler::GetScheduler(*db).SetAllocatorFlushTreshold(config.options.allocator_flush_threshold); + } +} + +Value FlushAllocatorSetting::GetSetting(ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(StringUtil::BytesToHumanReadableString(config.options.allocator_flush_threshold)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/stream_query_result.cpp b/src/duckdb/src/main/stream_query_result.cpp new file mode 100644 index 00000000..003a8786 --- /dev/null +++ b/src/duckdb/src/main/stream_query_result.cpp @@ -0,0 +1,110 @@ +#include "duckdb/main/stream_query_result.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/common/box_renderer.hpp" + +namespace duckdb { + +StreamQueryResult::StreamQueryResult(StatementType statement_type, StatementProperties properties, + shared_ptr context_p, vector types, + vector names) + : QueryResult(QueryResultType::STREAM_RESULT, statement_type, std::move(properties), std::move(types), + std::move(names), context_p->GetClientProperties()), + context(std::move(context_p)) { + D_ASSERT(context); +} + +StreamQueryResult::~StreamQueryResult() { +} + +string StreamQueryResult::ToString() { + string result; + if (success) { + result = HeaderToString(); + result += "[[STREAM RESULT]]"; + } else { + result = GetError() + "\n"; + } + return result; +} + +unique_ptr StreamQueryResult::LockContext() { + if (!context) { + string error_str = "Attempting to execute an unsuccessful or closed pending query result"; + if (HasError()) { + error_str += StringUtil::Format("\nError: %s", GetError()); + } + throw InvalidInputException(error_str); + } + return context->LockContext(); +} + +void StreamQueryResult::CheckExecutableInternal(ClientContextLock &lock) { + if (!IsOpenInternal(lock)) { + string error_str = "Attempting to execute an unsuccessful or closed pending query result"; + if (HasError()) { + error_str += StringUtil::Format("\nError: %s", GetError()); + } + throw InvalidInputException(error_str); + } +} + +unique_ptr StreamQueryResult::FetchRaw() { + unique_ptr chunk; + { + auto lock = LockContext(); + CheckExecutableInternal(*lock); + chunk = context->Fetch(*lock, *this); + } + if (!chunk || chunk->ColumnCount() == 0 || chunk->size() == 0) { + Close(); + return nullptr; + } + return chunk; +} + +unique_ptr StreamQueryResult::Materialize() { + if (HasError() || !context) { + return make_uniq(GetErrorObject()); + } + auto collection = make_uniq(Allocator::DefaultAllocator(), types); + + ColumnDataAppendState append_state; + collection->InitializeAppend(append_state); + while (true) { + auto chunk = Fetch(); + if (!chunk || chunk->size() == 0) { + break; + } + collection->Append(append_state, *chunk); + } + auto result = + make_uniq(statement_type, properties, names, std::move(collection), client_properties); + if (HasError()) { + return make_uniq(GetErrorObject()); + } + return result; +} + +bool StreamQueryResult::IsOpenInternal(ClientContextLock &lock) { + bool invalidated = !success || !context; + if (!invalidated) { + invalidated = !context->IsActiveResult(lock, this); + } + return !invalidated; +} + +bool StreamQueryResult::IsOpen() { + if (!success || !context) { + return false; + } + auto lock = LockContext(); + return IsOpenInternal(*lock); +} + +void StreamQueryResult::Close() { + context.reset(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/valid_checker.cpp b/src/duckdb/src/main/valid_checker.cpp new file mode 100644 index 00000000..a4373ef0 --- /dev/null +++ b/src/duckdb/src/main/valid_checker.cpp @@ -0,0 +1,22 @@ +#include "duckdb/main/valid_checker.hpp" + +namespace duckdb { + +ValidChecker::ValidChecker() : is_invalidated(false) { +} + +void ValidChecker::Invalidate(string error) { + lock_guard l(invalidate_lock); + this->is_invalidated = true; + this->invalidated_msg = std::move(error); +} + +bool ValidChecker::IsInvalidated() { + return this->is_invalidated; +} + +string ValidChecker::InvalidatedMessage() { + lock_guard l(invalidate_lock); + return invalidated_msg; +} +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/column_binding_replacer.cpp b/src/duckdb/src/optimizer/column_binding_replacer.cpp new file mode 100644 index 00000000..2450b926 --- /dev/null +++ b/src/duckdb/src/optimizer/column_binding_replacer.cpp @@ -0,0 +1,43 @@ +#include "duckdb/optimizer/column_binding_replacer.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +namespace duckdb { + +ReplacementBinding::ReplacementBinding(ColumnBinding old_binding, ColumnBinding new_binding) + : old_binding(old_binding), new_binding(new_binding), replace_type(false) { +} + +ReplacementBinding::ReplacementBinding(ColumnBinding old_binding, ColumnBinding new_binding, LogicalType new_type) + : old_binding(old_binding), new_binding(new_binding), replace_type(true), new_type(std::move(new_type)) { +} + +ColumnBindingReplacer::ColumnBindingReplacer() { +} + +void ColumnBindingReplacer::VisitOperator(LogicalOperator &op) { + if (stop_operator && stop_operator.get() == &op) { + return; + } + VisitOperatorChildren(op); + VisitOperatorExpressions(op); +} + +void ColumnBindingReplacer::VisitExpression(unique_ptr *expression) { + auto &expr = *expression; + if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) { + auto &bound_column_ref = expr->Cast(); + for (const auto &replace_binding : replacement_bindings) { + if (bound_column_ref.binding == replace_binding.old_binding) { + bound_column_ref.binding = replace_binding.new_binding; + if (replace_binding.replace_type) { + bound_column_ref.return_type = replace_binding.new_type; + } + } + } + } + + VisitExpressionChildren(**expression); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp new file mode 100644 index 00000000..23f99f75 --- /dev/null +++ b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp @@ -0,0 +1,156 @@ +#include "duckdb/optimizer/column_lifetime_optimizer.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" + +namespace duckdb { + +void ColumnLifetimeAnalyzer::ExtractUnusedColumnBindings(vector bindings, + column_binding_set_t &unused_bindings) { + for (idx_t i = 0; i < bindings.size(); i++) { + if (column_references.find(bindings[i]) == column_references.end()) { + unused_bindings.insert(bindings[i]); + } + } +} + +void ColumnLifetimeAnalyzer::GenerateProjectionMap(vector bindings, + column_binding_set_t &unused_bindings, + vector &projection_map) { + projection_map.clear(); + if (unused_bindings.empty()) { + return; + } + // now iterate over the result bindings of the child + for (idx_t i = 0; i < bindings.size(); i++) { + // if this binding does not belong to the unused bindings, add it to the projection map + if (unused_bindings.find(bindings[i]) == unused_bindings.end()) { + projection_map.push_back(i); + } + } + if (projection_map.size() == bindings.size()) { + projection_map.clear(); + } +} + +void ColumnLifetimeAnalyzer::StandardVisitOperator(LogicalOperator &op) { + LogicalOperatorVisitor::VisitOperatorExpressions(op); + if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + // visit the duplicate eliminated columns on the LHS, if any + auto &delim_join = op.Cast(); + for (auto &expr : delim_join.duplicate_eliminated_columns) { + VisitExpression(&expr); + } + } + LogicalOperatorVisitor::VisitOperatorChildren(op); +} + +void ColumnLifetimeAnalyzer::VisitOperator(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + // FIXME: groups that are not referenced can be removed from projection + // recurse into the children of the aggregate + ColumnLifetimeAnalyzer analyzer; + analyzer.VisitOperatorExpressions(op); + analyzer.VisitOperator(*op.children[0]); + return; + } + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + if (everything_referenced) { + break; + } + auto &comp_join = op.Cast(); + if (comp_join.join_type == JoinType::MARK || comp_join.join_type == JoinType::SEMI || + comp_join.join_type == JoinType::ANTI) { + break; + } + // FIXME for now, we only push into the projection map for equality (hash) joins + // FIXME: add projection to LHS as well + bool has_equality = false; + for (auto &cond : comp_join.conditions) { + if (cond.comparison == ExpressionType::COMPARE_EQUAL) { + has_equality = true; + break; + } + } + if (!has_equality) { + break; + } + // visit current operator expressions so they are added to the referenced_columns + LogicalOperatorVisitor::VisitOperatorExpressions(op); + + column_binding_set_t unused_bindings; + auto old_op_bindings = op.GetColumnBindings(); + ExtractUnusedColumnBindings(op.children[1]->GetColumnBindings(), unused_bindings); + + // now recurse into the filter and its children + LogicalOperatorVisitor::VisitOperatorChildren(op); + + // then generate the projection map + GenerateProjectionMap(op.children[1]->GetColumnBindings(), unused_bindings, comp_join.right_projection_map); + return; + } + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + // for set operations we don't remove anything, just recursively visit the children + // FIXME: for UNION we can remove unreferenced columns as long as everything_referenced is false (i.e. we + // encounter a UNION node that is not preceded by a DISTINCT) + for (auto &child : op.children) { + ColumnLifetimeAnalyzer analyzer(true); + analyzer.VisitOperator(*child); + } + return; + case LogicalOperatorType::LOGICAL_PROJECTION: { + // then recurse into the children of this projection + ColumnLifetimeAnalyzer analyzer; + analyzer.VisitOperatorExpressions(op); + analyzer.VisitOperator(*op.children[0]); + return; + } + case LogicalOperatorType::LOGICAL_DISTINCT: { + // distinct, all projected columns are used for the DISTINCT computation + // mark all columns as used and continue to the children + // FIXME: DISTINCT with expression list does not implicitly reference everything + everything_referenced = true; + break; + } + case LogicalOperatorType::LOGICAL_FILTER: { + auto &filter = op.Cast(); + if (everything_referenced) { + break; + } + // first visit operator expressions to populate referenced columns + LogicalOperatorVisitor::VisitOperatorExpressions(op); + // filter, figure out which columns are not needed after the filter + column_binding_set_t unused_bindings; + ExtractUnusedColumnBindings(op.children[0]->GetColumnBindings(), unused_bindings); + + // now recurse into the filter and its children + LogicalOperatorVisitor::VisitOperatorChildren(op); + + // then generate the projection map + GenerateProjectionMap(op.children[0]->GetColumnBindings(), unused_bindings, filter.projection_map); + return; + } + default: + break; + } + StandardVisitOperator(op); +} + +unique_ptr ColumnLifetimeAnalyzer::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + column_references.insert(expr.binding); + return nullptr; +} + +unique_ptr ColumnLifetimeAnalyzer::VisitReplace(BoundReferenceExpression &expr, + unique_ptr *expr_ptr) { + // BoundReferenceExpression should not be used here yet, they only belong in the physical plan + throw InternalException("BoundReferenceExpression should not be used here yet!"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp b/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp new file mode 100644 index 00000000..9ce0394d --- /dev/null +++ b/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp @@ -0,0 +1,60 @@ +#include "duckdb/optimizer/common_aggregate_optimizer.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/column_binding_map.hpp" + +namespace duckdb { + +void CommonAggregateOptimizer::VisitOperator(LogicalOperator &op) { + LogicalOperatorVisitor::VisitOperator(op); + switch (op.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + ExtractCommonAggregates(op.Cast()); + break; + default: + break; + } +} + +unique_ptr CommonAggregateOptimizer::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + // check if this column ref points to an aggregate that was remapped; if it does we remap it + auto entry = aggregate_map.find(expr.binding); + if (entry != aggregate_map.end()) { + expr.binding = entry->second; + } + return nullptr; +} + +void CommonAggregateOptimizer::ExtractCommonAggregates(LogicalAggregate &aggr) { + expression_map_t aggregate_remap; + idx_t total_erased = 0; + for (idx_t i = 0; i < aggr.expressions.size(); i++) { + idx_t original_index = i + total_erased; + auto entry = aggregate_remap.find(*aggr.expressions[i]); + if (entry == aggregate_remap.end()) { + // aggregate does not exist yet: add it to the map + aggregate_remap[*aggr.expressions[i]] = i; + if (i != original_index) { + // this aggregate is not erased, however an agregate BEFORE it has been erased + // so we need to remap this aggregaet + ColumnBinding original_binding(aggr.aggregate_index, original_index); + ColumnBinding new_binding(aggr.aggregate_index, i); + aggregate_map[original_binding] = new_binding; + } + } else { + // aggregate already exists! we can remove this entry + total_erased++; + aggr.expressions.erase(aggr.expressions.begin() + i); + i--; + // we need to remap any references to this aggregate so they point to the other aggregate + ColumnBinding original_binding(aggr.aggregate_index, original_index); + ColumnBinding new_binding(aggr.aggregate_index, entry->second); + aggregate_map[original_binding] = new_binding; + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/compressed_materialization.cpp b/src/duckdb/src/optimizer/compressed_materialization.cpp new file mode 100644 index 00000000..a1fc9912 --- /dev/null +++ b/src/duckdb/src/optimizer/compressed_materialization.cpp @@ -0,0 +1,477 @@ +#include "duckdb/optimizer/compressed_materialization.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/operators.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/optimizer/topn_optimizer.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" + +namespace duckdb { + +CMChildInfo::CMChildInfo(LogicalOperator &op, const column_binding_set_t &referenced_bindings) + : bindings_before(op.GetColumnBindings()), types(op.types), can_compress(bindings_before.size(), true) { + for (const auto &binding : referenced_bindings) { + for (idx_t binding_idx = 0; binding_idx < bindings_before.size(); binding_idx++) { + if (binding == bindings_before[binding_idx]) { + can_compress[binding_idx] = false; + } + } + } +} + +CMBindingInfo::CMBindingInfo(ColumnBinding binding_p, const LogicalType &type_p) + : binding(binding_p), type(type_p), needs_decompression(false) { +} + +CompressedMaterializationInfo::CompressedMaterializationInfo(LogicalOperator &op, vector &&child_idxs_p, + const column_binding_set_t &referenced_bindings) + : child_idxs(child_idxs_p) { + child_info.reserve(child_idxs.size()); + for (const auto &child_idx : child_idxs) { + child_info.emplace_back(*op.children[child_idx], referenced_bindings); + } +} + +CompressExpression::CompressExpression(unique_ptr expression_p, unique_ptr stats_p) + : expression(std::move(expression_p)), stats(std::move(stats_p)) { +} + +CompressedMaterialization::CompressedMaterialization(ClientContext &context_p, Binder &binder_p, + statistics_map_t &&statistics_map_p) + : context(context_p), binder(binder_p), statistics_map(std::move(statistics_map_p)) { +} + +void CompressedMaterialization::GetReferencedBindings(const Expression &expression, + column_binding_set_t &referenced_bindings) { + if (expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + const auto &col_ref = expression.Cast(); + referenced_bindings.insert(col_ref.binding); + } else { + ExpressionIterator::EnumerateChildren( + expression, [&](const Expression &child) { GetReferencedBindings(child, referenced_bindings); }); + } +} + +void CompressedMaterialization::UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, + bool needs_decompression) { + auto &binding_map = info.binding_map; + auto binding_it = binding_map.find(binding); + if (binding_it == binding_map.end()) { + return; + } + + auto &binding_info = binding_it->second; + binding_info.needs_decompression = needs_decompression; + auto stats_it = statistics_map.find(binding); + if (stats_it != statistics_map.end()) { + binding_info.stats = statistics_map[binding]->ToUnique(); + } +} + +void CompressedMaterialization::Compress(unique_ptr &op) { + root = op.get(); + root->ResolveOperatorTypes(); + + CompressInternal(op); +} + +void CompressedMaterialization::CompressInternal(unique_ptr &op) { + if (TopN::CanOptimize(*op)) { // Let's not mess with the TopN optimizer + CompressInternal(op->children[0]->children[0]); + return; + } + + for (auto &child : op->children) { + CompressInternal(child); + } + + switch (op->type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + CompressAggregate(op); + break; + case LogicalOperatorType::LOGICAL_DISTINCT: + CompressDistinct(op); + break; + case LogicalOperatorType::LOGICAL_ORDER_BY: + CompressOrder(op); + break; + default: + return; + } +} + +void CompressedMaterialization::CreateProjections(unique_ptr &op, + CompressedMaterializationInfo &info) { + auto &materializing_op = *op; + + bool compressed_anything = false; + for (idx_t i = 0; i < info.child_idxs.size(); i++) { + auto &child_info = info.child_info[i]; + vector> compress_exprs; + if (TryCompressChild(info, child_info, compress_exprs)) { + // We can compress: Create a projection on top of the child operator + const auto child_idx = info.child_idxs[i]; + CreateCompressProjection(materializing_op.children[child_idx], std::move(compress_exprs), info, child_info); + compressed_anything = true; + } + } + + if (compressed_anything) { + CreateDecompressProjection(op, info); + } +} + +bool CompressedMaterialization::TryCompressChild(CompressedMaterializationInfo &info, const CMChildInfo &child_info, + vector> &compress_exprs) { + // Try to compress each of the column bindings of the child + bool compressed_anything = false; + for (idx_t child_i = 0; child_i < child_info.bindings_before.size(); child_i++) { + const auto child_binding = child_info.bindings_before[child_i]; + const auto &child_type = child_info.types[child_i]; + const auto &can_compress = child_info.can_compress[child_i]; + auto compress_expr = GetCompressExpression(child_binding, child_type, can_compress); + bool compressed = false; + if (compress_expr) { // We compressed, mark the outgoing binding in need of decompression + compress_exprs.emplace_back(std::move(compress_expr)); + compressed = true; + } else { // We did not compress, just push a colref + auto colref_expr = make_uniq(child_type, child_binding); + auto it = statistics_map.find(colref_expr->binding); + unique_ptr colref_stats = it != statistics_map.end() ? it->second->ToUnique() : nullptr; + compress_exprs.emplace_back(make_uniq(std::move(colref_expr), std::move(colref_stats))); + } + UpdateBindingInfo(info, child_binding, compressed); + compressed_anything = compressed_anything || compressed; + } + if (!compressed_anything) { + // If we compressed anything non-generically, we still need to decompress + for (const auto &entry : info.binding_map) { + compressed_anything = compressed_anything || entry.second.needs_decompression; + } + } + return compressed_anything; +} + +void CompressedMaterialization::CreateCompressProjection(unique_ptr &child_op, + vector> &&compress_exprs, + CompressedMaterializationInfo &info, CMChildInfo &child_info) { + // Replace child op with a projection + vector> projections; + projections.reserve(compress_exprs.size()); + for (auto &compress_expr : compress_exprs) { + projections.emplace_back(std::move(compress_expr->expression)); + } + const auto table_index = binder.GenerateTableIndex(); + auto compress_projection = make_uniq(table_index, std::move(projections)); + compression_table_indices.insert(table_index); + compress_projection->ResolveOperatorTypes(); + + compress_projection->children.emplace_back(std::move(child_op)); + child_op = std::move(compress_projection); + + // Get the new bindings and types + child_info.bindings_after = child_op->GetColumnBindings(); + const auto &new_types = child_op->types; + + // Initialize a ColumnBindingReplacer with the new bindings and types + ColumnBindingReplacer replacer; + auto &replacement_bindings = replacer.replacement_bindings; + for (idx_t col_idx = 0; col_idx < child_info.bindings_before.size(); col_idx++) { + const auto &old_binding = child_info.bindings_before[col_idx]; + const auto &new_binding = child_info.bindings_after[col_idx]; + const auto &new_type = new_types[col_idx]; + replacement_bindings.emplace_back(old_binding, new_binding, new_type); + + // Remove the old binding from the statistics map + statistics_map.erase(old_binding); + } + + // Make sure we skip the compress operator when replacing bindings + replacer.stop_operator = child_op.get(); + + // Make the plan consistent again + replacer.VisitOperator(*root); + + // Replace in/out exprs in the binding map too + auto &binding_map = info.binding_map; + for (auto &replacement_binding : replacement_bindings) { + auto it = binding_map.find(replacement_binding.old_binding); + if (it == binding_map.end()) { + continue; + } + auto &binding_info = it->second; + if (binding_info.binding == replacement_binding.old_binding) { + binding_info.binding = replacement_binding.new_binding; + } + + if (it->first == replacement_binding.old_binding) { + auto binding_info_local = std::move(binding_info); + binding_map.erase(it); + binding_map.emplace(replacement_binding.new_binding, std::move(binding_info_local)); + } + } + + // Add projection stats to statistics map + for (idx_t col_idx = 0; col_idx < child_info.bindings_after.size(); col_idx++) { + const auto &binding = child_info.bindings_after[col_idx]; + auto &stats = compress_exprs[col_idx]->stats; + statistics_map.emplace(binding, std::move(stats)); + } +} + +void CompressedMaterialization::CreateDecompressProjection(unique_ptr &op, + CompressedMaterializationInfo &info) { + const auto bindings = op->GetColumnBindings(); + op->ResolveOperatorTypes(); + const auto &types = op->types; + + // Create decompress expressions for everything we compressed + auto &binding_map = info.binding_map; + vector> decompress_exprs; + vector> statistics; + for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { + const auto &binding = bindings[col_idx]; + auto decompress_expr = make_uniq_base(types[col_idx], binding); + optional_ptr stats; + for (auto &entry : binding_map) { + auto &binding_info = entry.second; + if (binding_info.binding != binding) { + continue; + } + stats = binding_info.stats.get(); + if (binding_info.needs_decompression) { + decompress_expr = GetDecompressExpression(std::move(decompress_expr), binding_info.type, *stats); + } + } + statistics.push_back(stats); + decompress_exprs.emplace_back(std::move(decompress_expr)); + } + + // Replace op with a projection + const auto table_index = binder.GenerateTableIndex(); + auto decompress_projection = make_uniq(table_index, std::move(decompress_exprs)); + decompression_table_indices.insert(table_index); + + decompress_projection->children.emplace_back(std::move(op)); + op = std::move(decompress_projection); + + // Check if we're placing a projection on top of the root + if (op->children[0].get() == root.get()) { + root = op.get(); + return; + } + + // Get the new bindings and types + auto new_bindings = op->GetColumnBindings(); + op->ResolveOperatorTypes(); + auto &new_types = op->types; + + // Initialize a ColumnBindingReplacer with the new bindings and types + ColumnBindingReplacer replacer; + auto &replacement_bindings = replacer.replacement_bindings; + for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { + const auto &old_binding = bindings[col_idx]; + const auto &new_binding = new_bindings[col_idx]; + const auto &new_type = new_types[col_idx]; + replacement_bindings.emplace_back(old_binding, new_binding, new_type); + + if (statistics[col_idx]) { + statistics_map[new_binding] = statistics[col_idx]->ToUnique(); + } + } + + // Make sure we skip the decompress operator when replacing bindings + replacer.stop_operator = op.get(); + + // Make the plan consistent again + replacer.VisitOperator(*root); +} + +unique_ptr CompressedMaterialization::GetCompressExpression(const ColumnBinding &binding, + const LogicalType &type, + const bool &can_compress) { + auto it = statistics_map.find(binding); + if (can_compress && it != statistics_map.end() && it->second) { + auto input = make_uniq(type, binding); + const auto &stats = *it->second; + return GetCompressExpression(std::move(input), stats); + } + return nullptr; +} + +unique_ptr CompressedMaterialization::GetCompressExpression(unique_ptr input, + const BaseStatistics &stats) { + const auto &type = input->return_type; + if (type != stats.GetType()) { // LCOV_EXCL_START + return nullptr; + } // LCOV_EXCL_STOP + if (type.IsIntegral()) { + return GetIntegralCompress(std::move(input), stats); + } else if (type.id() == LogicalTypeId::VARCHAR) { + return GetStringCompress(std::move(input), stats); + } + return nullptr; +} + +static Value GetIntegralRangeValue(ClientContext &context, const LogicalType &type, const BaseStatistics &stats) { + auto min = NumericStats::Min(stats); + auto max = NumericStats::Max(stats); + + vector> arguments; + arguments.emplace_back(make_uniq(max)); + arguments.emplace_back(make_uniq(min)); + BoundFunctionExpression sub(type, SubtractFun::GetFunction(type, type), std::move(arguments), nullptr); + + Value result; + if (ExpressionExecutor::TryEvaluateScalar(context, sub, result)) { + return result; + } else { + // Couldn't evaluate: Return max hugeint as range so GetIntegralCompress will return nullptr + return Value::HUGEINT(NumericLimits::Maximum()); + } +} + +unique_ptr CompressedMaterialization::GetIntegralCompress(unique_ptr input, + const BaseStatistics &stats) { + const auto &type = input->return_type; + if (GetTypeIdSize(type.InternalType()) == 1 || !NumericStats::HasMinMax(stats)) { + return nullptr; + } + + // Get range and cast to UBIGINT (might fail for HUGEINT, in which case we just return) + Value range_value = GetIntegralRangeValue(context, type, stats); + if (!range_value.DefaultTryCastAs(LogicalType::UBIGINT)) { + return nullptr; + } + + // Get the smallest type that the range can fit into + const auto range = UBigIntValue::Get(range_value); + LogicalType cast_type; + if (range <= NumericLimits().Maximum()) { + cast_type = LogicalType::UTINYINT; + } else if (range <= NumericLimits().Maximum()) { + cast_type = LogicalType::USMALLINT; + } else if (range <= NumericLimits().Maximum()) { + cast_type = LogicalType::UINTEGER; + } else { + D_ASSERT(range <= NumericLimits().Maximum()); + cast_type = LogicalType::UBIGINT; + } + + // Check if type that fits the range is smaller than the input type + if (GetTypeIdSize(cast_type.InternalType()) == GetTypeIdSize(type.InternalType())) { + return nullptr; + } + D_ASSERT(GetTypeIdSize(cast_type.InternalType()) < GetTypeIdSize(type.InternalType())); + + // Compressing will yield a benefit + auto compress_function = CMIntegralCompressFun::GetFunction(type, cast_type); + vector> arguments; + arguments.emplace_back(std::move(input)); + arguments.emplace_back(make_uniq(NumericStats::Min(stats))); + auto compress_expr = + make_uniq(cast_type, compress_function, std::move(arguments), nullptr); + + auto compress_stats = BaseStatistics::CreateEmpty(cast_type); + compress_stats.CopyBase(stats); + NumericStats::SetMin(compress_stats, Value(0).DefaultCastAs(cast_type)); + NumericStats::SetMax(compress_stats, range_value.DefaultCastAs(cast_type)); + + return make_uniq(std::move(compress_expr), compress_stats.ToUnique()); +} + +unique_ptr CompressedMaterialization::GetStringCompress(unique_ptr input, + const BaseStatistics &stats) { + if (!StringStats::HasMaxStringLength(stats)) { + return nullptr; + } + + const auto max_string_length = StringStats::MaxStringLength(stats); + LogicalType cast_type = LogicalType::INVALID; + for (const auto &compressed_type : CompressedMaterializationFunctions::StringTypes()) { + if (max_string_length < GetTypeIdSize(compressed_type.InternalType())) { + cast_type = compressed_type; + break; + } + } + if (cast_type == LogicalType::INVALID) { + return nullptr; + } + + auto compress_stats = BaseStatistics::CreateEmpty(cast_type); + compress_stats.CopyBase(stats); + if (cast_type.id() == LogicalTypeId::USMALLINT) { + auto min_string = StringStats::Min(stats); + auto max_string = StringStats::Max(stats); + + uint8_t min_numeric = 0; + if (max_string_length != 0 && min_string.length() != 0) { + min_numeric = *reinterpret_cast(min_string.c_str()); + } + uint8_t max_numeric = 0; + if (max_string_length != 0 && max_string.length() != 0) { + max_numeric = *reinterpret_cast(max_string.c_str()); + } + + Value min_val = Value::USMALLINT(min_numeric); + Value max_val = Value::USMALLINT(max_numeric + 1); + if (max_numeric < NumericLimits::Maximum()) { + cast_type = LogicalType::UTINYINT; + compress_stats = BaseStatistics::CreateEmpty(cast_type); + compress_stats.CopyBase(stats); + min_val = Value::UTINYINT(min_numeric); + max_val = Value::UTINYINT(max_numeric + 1); + } + + NumericStats::SetMin(compress_stats, min_val); + NumericStats::SetMax(compress_stats, max_val); + } + + auto compress_function = CMStringCompressFun::GetFunction(cast_type); + vector> arguments; + arguments.emplace_back(std::move(input)); + auto compress_expr = + make_uniq(cast_type, compress_function, std::move(arguments), nullptr); + return make_uniq(std::move(compress_expr), compress_stats.ToUnique()); +} + +unique_ptr CompressedMaterialization::GetDecompressExpression(unique_ptr input, + const LogicalType &result_type, + const BaseStatistics &stats) { + const auto &type = result_type; + if (TypeIsIntegral(type.InternalType())) { + return GetIntegralDecompress(std::move(input), result_type, stats); + } else if (type.id() == LogicalTypeId::VARCHAR) { + return GetStringDecompress(std::move(input), stats); + } else { + throw InternalException("Type other than integral/string marked for decompression!"); + } +} + +unique_ptr CompressedMaterialization::GetIntegralDecompress(unique_ptr input, + const LogicalType &result_type, + const BaseStatistics &stats) { + D_ASSERT(NumericStats::HasMinMax(stats)); + auto decompress_function = CMIntegralDecompressFun::GetFunction(input->return_type, result_type); + vector> arguments; + arguments.emplace_back(std::move(input)); + arguments.emplace_back(make_uniq(NumericStats::Min(stats))); + return make_uniq(result_type, decompress_function, std::move(arguments), nullptr); +} + +unique_ptr CompressedMaterialization::GetStringDecompress(unique_ptr input, + const BaseStatistics &stats) { + D_ASSERT(StringStats::HasMaxStringLength(stats)); + auto decompress_function = CMStringDecompressFun::GetFunction(input->return_type); + vector> arguments; + arguments.emplace_back(std::move(input)); + return make_uniq(decompress_function.return_type, decompress_function, + std::move(arguments), nullptr); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp new file mode 100644 index 00000000..a61787e7 --- /dev/null +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp @@ -0,0 +1,140 @@ +#include "duckdb/optimizer/compressed_materialization.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" + +namespace duckdb { + +void CompressedMaterialization::CompressAggregate(unique_ptr &op) { + auto &aggregate = op->Cast(); + auto &groups = aggregate.groups; + column_binding_set_t group_binding_set; + for (const auto &group : groups) { + if (group->type != ExpressionType::BOUND_COLUMN_REF) { + continue; + } + auto &colref = group->Cast(); + if (group_binding_set.find(colref.binding) != group_binding_set.end()) { + return; // Duplicate group - don't compress + } + group_binding_set.insert(colref.binding); + } + auto &group_stats = aggregate.group_stats; + + // No need to compress if there are no groups/stats + if (groups.empty() || group_stats.empty()) { + return; + } + D_ASSERT(groups.size() == group_stats.size()); + + // Find all bindings referenced by non-colref expressions in the groups + // These are excluded from compression by projection + // But we can try to compress the expression directly + column_binding_set_t referenced_bindings; + vector group_bindings(groups.size(), ColumnBinding()); + vector needs_decompression(groups.size(), false); + vector> stored_group_stats; + stored_group_stats.resize(groups.size()); + for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { + auto &group_expr = *groups[group_idx]; + if (group_expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = group_expr.Cast(); + group_bindings[group_idx] = colref.binding; + continue; // Will be compressed generically + } + + // Mark the bindings referenced by the non-colref expression so they won't be modified + GetReferencedBindings(group_expr, referenced_bindings); + + // The non-colref expression won't be compressed generically, so try to compress it here + if (!group_stats[group_idx]) { + continue; // Can't compress without stats + } + + // Try to compress, if successful, replace the expression + auto compress_expr = GetCompressExpression(group_expr.Copy(), *group_stats[group_idx]); + if (compress_expr) { + needs_decompression[group_idx] = true; + stored_group_stats[group_idx] = std::move(group_stats[group_idx]); + groups[group_idx] = std::move(compress_expr->expression); + group_stats[group_idx] = std::move(compress_expr->stats); + } + } + + // Anything referenced in the aggregate functions is also excluded + for (idx_t expr_idx = 0; expr_idx < aggregate.expressions.size(); expr_idx++) { + const auto &expr = *aggregate.expressions[expr_idx]; + D_ASSERT(expr.type == ExpressionType::BOUND_AGGREGATE); + const auto &aggr_expr = expr.Cast(); + for (const auto &child : aggr_expr.children) { + GetReferencedBindings(*child, referenced_bindings); + } + if (aggr_expr.filter) { + GetReferencedBindings(*aggr_expr.filter, referenced_bindings); + } + if (aggr_expr.order_bys) { + for (const auto &order : aggr_expr.order_bys->orders) { + const auto &order_expr = *order.expression; + if (order_expr.type != ExpressionType::BOUND_COLUMN_REF) { + GetReferencedBindings(order_expr, referenced_bindings); + } + } + } + } + + // Create info for compression + CompressedMaterializationInfo info(*op, {0}, referenced_bindings); + + // Create binding mapping + const auto bindings_out = aggregate.GetColumnBindings(); + const auto &types = aggregate.types; + for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { + // Aggregate changes bindings as it has a table idx + CMBindingInfo binding_info(bindings_out[group_idx], types[group_idx]); + binding_info.needs_decompression = needs_decompression[group_idx]; + if (needs_decompression[group_idx]) { + // Compressed non-generically + auto entry = info.binding_map.emplace(bindings_out[group_idx], std::move(binding_info)); + entry.first->second.stats = std::move(stored_group_stats[group_idx]); + } else if (group_bindings[group_idx] != ColumnBinding()) { + info.binding_map.emplace(group_bindings[group_idx], std::move(binding_info)); + } + } + + // Now try to compress + CreateProjections(op, info); + + // Update aggregate statistics + UpdateAggregateStats(op); +} + +void CompressedMaterialization::UpdateAggregateStats(unique_ptr &op) { + if (op->type != LogicalOperatorType::LOGICAL_PROJECTION) { + return; + } + + // Update aggregate group stats if compressed + auto &compressed_aggregate = op->children[0]->Cast(); + auto &groups = compressed_aggregate.groups; + auto &group_stats = compressed_aggregate.group_stats; + + for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { + auto &group_expr = *groups[group_idx]; + if (group_expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + continue; + } + auto &colref = group_expr.Cast(); + if (!group_stats[group_idx]) { + continue; + } + if (colref.return_type == group_stats[group_idx]->GetType()) { + continue; + } + auto it = statistics_map.find(colref.binding); + if (it != statistics_map.end() && it->second) { + group_stats[group_idx] = it->second->ToUnique(); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/compressed_materialization/compress_distinct.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_distinct.cpp new file mode 100644 index 00000000..3dd1f536 --- /dev/null +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_distinct.cpp @@ -0,0 +1,42 @@ +#include "duckdb/optimizer/compressed_materialization.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" + +namespace duckdb { + +void CompressedMaterialization::CompressDistinct(unique_ptr &op) { + auto &distinct = op->Cast(); + auto &distinct_targets = distinct.distinct_targets; + + column_binding_set_t referenced_bindings; + for (auto &target : distinct_targets) { + if (target->type != ExpressionType::BOUND_COLUMN_REF) { // LCOV_EXCL_START + GetReferencedBindings(*target, referenced_bindings); + } // LCOV_EXCL_STOP + } + + if (distinct.order_by) { + for (auto &order : distinct.order_by->orders) { + if (order.expression->type != ExpressionType::BOUND_COLUMN_REF) { // LCOV_EXCL_START + GetReferencedBindings(*order.expression, referenced_bindings); + } // LCOV_EXCL_STOP + } + } + + // Create info for compression + CompressedMaterializationInfo info(*op, {0}, referenced_bindings); + + // Create binding mapping + const auto bindings = distinct.GetColumnBindings(); + const auto &types = distinct.types; + D_ASSERT(bindings.size() == types.size()); + for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { + // Distinct does not change bindings, input binding is output binding + info.binding_map.emplace(bindings[col_idx], CMBindingInfo(bindings[col_idx], types[col_idx])); + } + + // Now try to compress + CreateProjections(op, info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/compressed_materialization/compress_order.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_order.cpp new file mode 100644 index 00000000..2b098fda --- /dev/null +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_order.cpp @@ -0,0 +1,65 @@ +#include "duckdb/optimizer/compressed_materialization.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_order.hpp" + +namespace duckdb { + +void CompressedMaterialization::CompressOrder(unique_ptr &op) { + auto &order = op->Cast(); + + // Find all bindings referenced by non-colref expressions in the order nodes + // These are excluded from compression by projection + // But we can try to compress the expression directly + column_binding_set_t referenced_bindings; + for (idx_t order_node_idx = 0; order_node_idx < order.orders.size(); order_node_idx++) { + auto &bound_order = order.orders[order_node_idx]; + auto &order_expression = *bound_order.expression; + if (order_expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + continue; // Will be compressed generically + } + + // Mark the bindings referenced by the non-colref expression so they won't be modified + GetReferencedBindings(order_expression, referenced_bindings); + } + + // Create info for compression + CompressedMaterializationInfo info(*op, {0}, referenced_bindings); + + // Create binding mapping + const auto bindings = order.GetColumnBindings(); + const auto &types = order.types; + D_ASSERT(bindings.size() == types.size()); + for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { + // Order does not change bindings, input binding is output binding + info.binding_map.emplace(bindings[col_idx], CMBindingInfo(bindings[col_idx], types[col_idx])); + } + + // Now try to compress + CreateProjections(op, info); + + // Update order statistics + UpdateOrderStats(op); +} + +void CompressedMaterialization::UpdateOrderStats(unique_ptr &op) { + if (op->type != LogicalOperatorType::LOGICAL_PROJECTION) { + return; + } + + // Update order stats if compressed + auto &compressed_order = op->children[0]->Cast(); + for (idx_t order_node_idx = 0; order_node_idx < compressed_order.orders.size(); order_node_idx++) { + auto &bound_order = compressed_order.orders[order_node_idx]; + auto &order_expression = *bound_order.expression; + if (order_expression.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + continue; + } + auto &colref = order_expression.Cast(); + auto it = statistics_map.find(colref.binding); + if (it != statistics_map.end() && it->second) { + bound_order.stats = it->second->ToUnique(); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/cse_optimizer.cpp b/src/duckdb/src/optimizer/cse_optimizer.cpp new file mode 100644 index 00000000..3a17c942 --- /dev/null +++ b/src/duckdb/src/optimizer/cse_optimizer.cpp @@ -0,0 +1,156 @@ +#include "duckdb/optimizer/cse_optimizer.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +//! The CSENode contains information about a common subexpression; how many times it occurs, and the column index in the +//! underlying projection +struct CSENode { + idx_t count; + idx_t column_index; + + CSENode() : count(1), column_index(DConstants::INVALID_INDEX) { + } +}; + +//! The CSEReplacementState +struct CSEReplacementState { + //! The projection index of the new projection + idx_t projection_index; + //! Map of expression -> CSENode + expression_map_t expression_count; + //! Map of column bindings to column indexes in the projection expression list + column_binding_map_t column_map; + //! The set of expressions of the resulting projection + vector> expressions; + //! Cached expressions that are kept around so the expression_map always contains valid expressions + vector> cached_expressions; +}; + +void CommonSubExpressionOptimizer::VisitOperator(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + ExtractCommonSubExpresions(op); + break; + default: + break; + } + LogicalOperatorVisitor::VisitOperator(op); +} + +void CommonSubExpressionOptimizer::CountExpressions(Expression &expr, CSEReplacementState &state) { + // we only consider expressions with children for CSE elimination + switch (expr.expression_class) { + case ExpressionClass::BOUND_COLUMN_REF: + case ExpressionClass::BOUND_CONSTANT: + case ExpressionClass::BOUND_PARAMETER: + // skip conjunctions and case, since short-circuiting might be incorrectly disabled otherwise + case ExpressionClass::BOUND_CONJUNCTION: + case ExpressionClass::BOUND_CASE: + return; + default: + break; + } + if (expr.expression_class != ExpressionClass::BOUND_AGGREGATE && !expr.HasSideEffects()) { + // we can't move aggregates to a projection, so we only consider the children of the aggregate + auto node = state.expression_count.find(expr); + if (node == state.expression_count.end()) { + // first time we encounter this expression, insert this node with [count = 1] + state.expression_count[expr] = CSENode(); + } else { + // we encountered this expression before, increment the occurrence count + node->second.count++; + } + } + // recursively count the children + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { CountExpressions(child, state); }); +} + +void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr &expr_ptr, CSEReplacementState &state) { + Expression &expr = *expr_ptr; + if (expr.expression_class == ExpressionClass::BOUND_COLUMN_REF) { + auto &bound_column_ref = expr.Cast(); + // bound column ref, check if this one has already been recorded in the expression list + auto column_entry = state.column_map.find(bound_column_ref.binding); + if (column_entry == state.column_map.end()) { + // not there yet: push the expression + idx_t new_column_index = state.expressions.size(); + state.column_map[bound_column_ref.binding] = new_column_index; + state.expressions.push_back(make_uniq( + bound_column_ref.alias, bound_column_ref.return_type, bound_column_ref.binding)); + bound_column_ref.binding = ColumnBinding(state.projection_index, new_column_index); + } else { + // else: just update the column binding! + bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second); + } + return; + } + // check if this child is eligible for CSE elimination + bool can_cse = expr.expression_class != ExpressionClass::BOUND_CONJUNCTION && + expr.expression_class != ExpressionClass::BOUND_CASE; + if (can_cse && state.expression_count.find(expr) != state.expression_count.end()) { + auto &node = state.expression_count[expr]; + if (node.count > 1) { + // this expression occurs more than once! push it into the projection + // check if it has already been pushed into the projection + auto alias = expr.alias; + auto type = expr.return_type; + if (node.column_index == DConstants::INVALID_INDEX) { + // has not been pushed yet: push it + node.column_index = state.expressions.size(); + state.expressions.push_back(std::move(expr_ptr)); + } else { + state.cached_expressions.push_back(std::move(expr_ptr)); + } + // replace the original expression with a bound column ref + expr_ptr = make_uniq(alias, type, + ColumnBinding(state.projection_index, node.column_index)); + return; + } + } + // this expression only occurs once, we can't perform CSE elimination + // look into the children to see if we can replace them + ExpressionIterator::EnumerateChildren(expr, + [&](unique_ptr &child) { PerformCSEReplacement(child, state); }); +} + +void CommonSubExpressionOptimizer::ExtractCommonSubExpresions(LogicalOperator &op) { + D_ASSERT(op.children.size() == 1); + + // first we count for each expression with children how many types it occurs + CSEReplacementState state; + LogicalOperatorVisitor::EnumerateExpressions( + op, [&](unique_ptr *child) { CountExpressions(**child, state); }); + // check if there are any expressions to extract + bool perform_replacement = false; + for (auto &expr : state.expression_count) { + if (expr.second.count > 1) { + perform_replacement = true; + break; + } + } + if (!perform_replacement) { + // no CSEs to extract + return; + } + state.projection_index = binder.GenerateTableIndex(); + // we found common subexpressions to extract + // now we iterate over all the expressions and perform the actual CSE elimination + + LogicalOperatorVisitor::EnumerateExpressions( + op, [&](unique_ptr *child) { PerformCSEReplacement(*child, state); }); + D_ASSERT(state.expressions.size() > 0); + // create a projection node as the child of this node + auto projection = make_uniq(state.projection_index, std::move(state.expressions)); + projection->children.push_back(std::move(op.children[0])); + op.children[0] = std::move(projection); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/deliminator.cpp b/src/duckdb/src/optimizer/deliminator.cpp new file mode 100644 index 00000000..0e45c4d8 --- /dev/null +++ b/src/duckdb/src/optimizer/deliminator.cpp @@ -0,0 +1,297 @@ +#include "duckdb/optimizer/deliminator.hpp" + +#include "duckdb/optimizer/join_order/join_order_optimizer.hpp" +#include "duckdb/optimizer/remove_duplicate_groups.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_delim_get.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" + +namespace duckdb { + +struct DelimCandidate { +public: + explicit DelimCandidate(unique_ptr &op, LogicalComparisonJoin &delim_join) + : op(op), delim_join(delim_join), delim_get_count(0) { + } + +public: + unique_ptr &op; + LogicalComparisonJoin &delim_join; + vector>> joins; + idx_t delim_get_count; +}; + +static bool IsEqualityJoinCondition(const JoinCondition &cond) { + switch (cond.comparison) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return true; + default: + return false; + } +} + +unique_ptr Deliminator::Optimize(unique_ptr op) { + root = op; + + vector candidates; + FindCandidates(op, candidates); + + for (auto &candidate : candidates) { + auto &delim_join = candidate.delim_join; + + bool all_removed = true; + bool all_equality_conditions = true; + for (auto &join : candidate.joins) { + all_removed = + RemoveJoinWithDelimGet(delim_join, candidate.delim_get_count, join, all_equality_conditions) && + all_removed; + } + + // Change type if there are no more duplicate-eliminated columns + if (candidate.joins.size() == candidate.delim_get_count && all_removed) { + delim_join.type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; + delim_join.duplicate_eliminated_columns.clear(); + if (all_equality_conditions) { + for (auto &cond : delim_join.conditions) { + if (IsEqualityJoinCondition(cond)) { + cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + } + } + } + } + } + + return op; +} + +void Deliminator::FindCandidates(unique_ptr &op, vector &candidates) { + // Search children before adding, so the deepest candidates get added first + for (auto &child : op->children) { + FindCandidates(child, candidates); + } + + if (op->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { + return; + } + + candidates.emplace_back(op, op->Cast()); + auto &candidate = candidates.back(); + + // DelimGets are in the RHS + FindJoinWithDelimGet(op->children[1], candidate); +} + +static bool OperatorIsDelimGet(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) { + return true; + } + if (op.type == LogicalOperatorType::LOGICAL_FILTER && + op.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { + return true; + } + return false; +} + +void Deliminator::FindJoinWithDelimGet(unique_ptr &op, DelimCandidate &candidate) { + if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + FindJoinWithDelimGet(op->children[0], candidate); + } else if (op->type == LogicalOperatorType::LOGICAL_DELIM_GET) { + candidate.delim_get_count++; + } else { + for (auto &child : op->children) { + FindJoinWithDelimGet(child, candidate); + } + } + + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN && + (OperatorIsDelimGet(*op->children[0]) || OperatorIsDelimGet(*op->children[1]))) { + candidate.joins.emplace_back(op); + } +} + +static bool ChildJoinTypeCanBeDeliminated(JoinType &join_type) { + switch (join_type) { + case JoinType::INNER: + case JoinType::SEMI: + return true; + default: + return false; + } +} + +bool Deliminator::RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, + unique_ptr &join, bool &all_equality_conditions) { + auto &comparison_join = join->Cast(); + if (!ChildJoinTypeCanBeDeliminated(comparison_join.join_type)) { + return false; + } + + // Get the index (left or right) of the DelimGet side of the join + const idx_t delim_idx = OperatorIsDelimGet(*join->children[0]) ? 0 : 1; + + // Get the filter (if any) + optional_ptr filter; + vector> filter_expressions; + if (join->children[delim_idx]->type == LogicalOperatorType::LOGICAL_FILTER) { + filter = &join->children[delim_idx]->Cast(); + for (auto &expr : filter->expressions) { + filter_expressions.emplace_back(expr->Copy()); + } + } + + auto &delim_get = (filter ? filter->children[0] : join->children[delim_idx])->Cast(); + if (comparison_join.conditions.size() != delim_get.chunk_types.size()) { + return false; // Joining with DelimGet adds new information + } + + // Check if joining with the DelimGet is redundant, and collect relevant column information + ColumnBindingReplacer replacer; + auto &replacement_bindings = replacer.replacement_bindings; + for (auto &cond : comparison_join.conditions) { + all_equality_conditions = all_equality_conditions && IsEqualityJoinCondition(cond); + auto &delim_side = delim_idx == 0 ? *cond.left : *cond.right; + auto &other_side = delim_idx == 0 ? *cond.right : *cond.left; + if (delim_side.type != ExpressionType::BOUND_COLUMN_REF || + other_side.type != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + auto &delim_colref = delim_side.Cast(); + auto &other_colref = other_side.Cast(); + replacement_bindings.emplace_back(delim_colref.binding, other_colref.binding); + + if (cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + auto is_not_null_expr = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); + is_not_null_expr->children.push_back(other_side.Copy()); + filter_expressions.push_back(std::move(is_not_null_expr)); + } + } + + if (!all_equality_conditions && + !RemoveInequalityJoinWithDelimGet(delim_join, delim_get_count, join, replacement_bindings)) { + return false; + } + + unique_ptr replacement_op = std::move(comparison_join.children[1 - delim_idx]); + if (!filter_expressions.empty()) { // Create filter if necessary + auto new_filter = make_uniq(); + new_filter->expressions = std::move(filter_expressions); + new_filter->children.emplace_back(std::move(replacement_op)); + replacement_op = std::move(new_filter); + } + + join = std::move(replacement_op); + + // TODO: Maybe go from delim join instead to save work + replacer.VisitOperator(*root); + return true; +} + +static bool InequalityDelimJoinCanBeEliminated(JoinType &join_type) { + return join_type == JoinType::ANTI || join_type == JoinType::MARK || join_type == JoinType::SEMI || + join_type == JoinType::SINGLE; +} + +bool FindAndReplaceBindings(vector &traced_bindings, const vector> &expressions, + const vector ¤t_bindings) { + for (auto &binding : traced_bindings) { + idx_t current_idx; + for (current_idx = 0; current_idx < expressions.size(); current_idx++) { + if (binding == current_bindings[current_idx]) { + break; + } + } + + if (current_idx == expressions.size() || expressions[current_idx]->type != ExpressionType::BOUND_COLUMN_REF) { + return false; // Didn't find / can't deal with non-colref + } + + auto &colref = expressions[current_idx]->Cast(); + binding = colref.binding; + } + return true; +} + +bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, + unique_ptr &join, + const vector &replacement_bindings) { + auto &comparison_join = join->Cast(); + auto &delim_conditions = delim_join.conditions; + const auto &join_conditions = comparison_join.conditions; + if (delim_get_count != 1 || !InequalityDelimJoinCanBeEliminated(delim_join.join_type) || + delim_conditions.size() != join_conditions.size()) { + return false; + } + + // TODO: we cannot perform the optimization here because our pure inequality joins don't implement + // JoinType::SINGLE yet + if (delim_join.join_type == JoinType::SINGLE) { + bool has_one_equality = false; + for (auto &cond : join_conditions) { + has_one_equality = has_one_equality || IsEqualityJoinCondition(cond); + } + if (!has_one_equality) { + return false; + } + } + + // We only support colref's + vector traced_bindings; + for (const auto &cond : delim_conditions) { + if (cond.right->type != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + auto &colref = cond.right->Cast(); + traced_bindings.emplace_back(colref.binding); + } + + // Now we trace down the bindings to the join (for now, we only trace it through a few operators) + reference current_op = *delim_join.children[1]; + while (¤t_op.get() != join.get()) { + if (current_op.get().children.size() != 1) { + return false; + } + + switch (current_op.get().type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + FindAndReplaceBindings(traced_bindings, current_op.get().expressions, current_op.get().GetColumnBindings()); + break; + case LogicalOperatorType::LOGICAL_FILTER: + break; // Doesn't change bindings + default: + return false; + } + current_op = *current_op.get().children[0]; + } + + // Get the index (left or right) of the DelimGet side of the join + const idx_t delim_idx = OperatorIsDelimGet(*join->children[0]) ? 0 : 1; + + bool found_all = true; + for (idx_t cond_idx = 0; cond_idx < delim_conditions.size(); cond_idx++) { + auto &delim_condition = delim_conditions[cond_idx]; + const auto &traced_binding = traced_bindings[cond_idx]; + + bool found = false; + for (auto &join_condition : join_conditions) { + auto &delim_side = delim_idx == 0 ? *join_condition.left : *join_condition.right; + auto &colref = delim_side.Cast(); + if (colref.binding == traced_binding) { + delim_condition.comparison = FlipComparisonExpression(join_condition.comparison); + found = true; + break; + } + } + found_all = found_all && found; + } + + return found_all; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/expression_heuristics.cpp b/src/duckdb/src/optimizer/expression_heuristics.cpp new file mode 100644 index 00000000..e334db4b --- /dev/null +++ b/src/duckdb/src/optimizer/expression_heuristics.cpp @@ -0,0 +1,208 @@ +#include "duckdb/optimizer/expression_heuristics.hpp" +#include "duckdb/planner/expression/list.hpp" + +namespace duckdb { + +unique_ptr ExpressionHeuristics::Rewrite(unique_ptr op) { + VisitOperator(*op); + return op; +} + +void ExpressionHeuristics::VisitOperator(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_FILTER) { + // reorder all filter expressions + if (op.expressions.size() > 1) { + ReorderExpressions(op.expressions); + } + } + + // traverse recursively through the operator tree + VisitOperatorChildren(op); + VisitOperatorExpressions(op); +} + +unique_ptr ExpressionHeuristics::VisitReplace(BoundConjunctionExpression &expr, + unique_ptr *expr_ptr) { + ReorderExpressions(expr.children); + return nullptr; +} + +void ExpressionHeuristics::ReorderExpressions(vector> &expressions) { + + struct ExpressionCosts { + unique_ptr expr; + idx_t cost; + + bool operator==(const ExpressionCosts &p) const { + return cost == p.cost; + } + bool operator<(const ExpressionCosts &p) const { + return cost < p.cost; + } + }; + + vector expression_costs; + expression_costs.reserve(expressions.size()); + // iterate expressions, get cost for each one + for (idx_t i = 0; i < expressions.size(); i++) { + idx_t cost = Cost(*expressions[i]); + expression_costs.push_back({std::move(expressions[i]), cost}); + } + + // sort by cost and put back in place + sort(expression_costs.begin(), expression_costs.end()); + for (idx_t i = 0; i < expression_costs.size(); i++) { + expressions[i] = std::move(expression_costs[i].expr); + } +} + +idx_t ExpressionHeuristics::ExpressionCost(BoundBetweenExpression &expr) { + return Cost(*expr.input) + Cost(*expr.lower) + Cost(*expr.upper) + 10; +} + +idx_t ExpressionHeuristics::ExpressionCost(BoundCaseExpression &expr) { + // CASE WHEN check THEN result_if_true ELSE result_if_false END + idx_t case_cost = 0; + for (auto &case_check : expr.case_checks) { + case_cost += Cost(*case_check.then_expr); + case_cost += Cost(*case_check.when_expr); + } + case_cost += Cost(*expr.else_expr); + return case_cost; +} + +idx_t ExpressionHeuristics::ExpressionCost(BoundCastExpression &expr) { + // OPERATOR_CAST + // determine cast cost by comparing cast_expr.source_type and cast_expr_target_type + idx_t cast_cost = 0; + if (expr.return_type != expr.source_type()) { + // if cast from or to varchar + // TODO: we might want to add more cases + if (expr.return_type.id() == LogicalTypeId::VARCHAR || expr.source_type().id() == LogicalTypeId::VARCHAR || + expr.return_type.id() == LogicalTypeId::BLOB || expr.source_type().id() == LogicalTypeId::BLOB) { + cast_cost = 200; + } else { + cast_cost = 5; + } + } + return Cost(*expr.child) + cast_cost; +} + +idx_t ExpressionHeuristics::ExpressionCost(BoundComparisonExpression &expr) { + // COMPARE_EQUAL, COMPARE_NOTEQUAL, COMPARE_GREATERTHAN, COMPARE_GREATERTHANOREQUALTO, COMPARE_LESSTHAN, + // COMPARE_LESSTHANOREQUALTO + return Cost(*expr.left) + 5 + Cost(*expr.right); +} + +idx_t ExpressionHeuristics::ExpressionCost(BoundConjunctionExpression &expr) { + // CONJUNCTION_AND, CONJUNCTION_OR + idx_t cost = 5; + for (auto &child : expr.children) { + cost += Cost(*child); + } + return cost; +} + +idx_t ExpressionHeuristics::ExpressionCost(BoundFunctionExpression &expr) { + idx_t cost_children = 0; + for (auto &child : expr.children) { + cost_children += Cost(*child); + } + + auto cost_function = function_costs.find(expr.function.name); + if (cost_function != function_costs.end()) { + return cost_children + cost_function->second; + } else { + return cost_children + 1000; + } +} + +idx_t ExpressionHeuristics::ExpressionCost(BoundOperatorExpression &expr, ExpressionType &expr_type) { + idx_t sum = 0; + for (auto &child : expr.children) { + sum += Cost(*child); + } + + // OPERATOR_IS_NULL, OPERATOR_IS_NOT_NULL + if (expr_type == ExpressionType::OPERATOR_IS_NULL || expr_type == ExpressionType::OPERATOR_IS_NOT_NULL) { + return sum + 5; + } else if (expr_type == ExpressionType::COMPARE_IN || expr_type == ExpressionType::COMPARE_NOT_IN) { + // COMPARE_IN, COMPARE_NOT_IN + return sum + (expr.children.size() - 1) * 100; + } else if (expr_type == ExpressionType::OPERATOR_NOT) { + // OPERATOR_NOT + return sum + 10; // TODO: evaluate via measured runtimes + } else { + return sum + 1000; + } +} + +idx_t ExpressionHeuristics::ExpressionCost(PhysicalType return_type, idx_t multiplier) { + // TODO: ajust values according to benchmark results + switch (return_type) { + case PhysicalType::VARCHAR: + return 5 * multiplier; + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return 2 * multiplier; + default: + return 1 * multiplier; + } +} + +idx_t ExpressionHeuristics::Cost(Expression &expr) { + switch (expr.expression_class) { + case ExpressionClass::BOUND_CASE: { + auto &case_expr = expr.Cast(); + return ExpressionCost(case_expr); + } + case ExpressionClass::BOUND_BETWEEN: { + auto &between_expr = expr.Cast(); + return ExpressionCost(between_expr); + } + case ExpressionClass::BOUND_CAST: { + auto &cast_expr = expr.Cast(); + return ExpressionCost(cast_expr); + } + case ExpressionClass::BOUND_COMPARISON: { + auto &comp_expr = expr.Cast(); + return ExpressionCost(comp_expr); + } + case ExpressionClass::BOUND_CONJUNCTION: { + auto &conj_expr = expr.Cast(); + return ExpressionCost(conj_expr); + } + case ExpressionClass::BOUND_FUNCTION: { + auto &func_expr = expr.Cast(); + return ExpressionCost(func_expr); + } + case ExpressionClass::BOUND_OPERATOR: { + auto &op_expr = expr.Cast(); + return ExpressionCost(op_expr, expr.type); + } + case ExpressionClass::BOUND_COLUMN_REF: { + auto &col_expr = expr.Cast(); + return ExpressionCost(col_expr.return_type.InternalType(), 8); + } + case ExpressionClass::BOUND_CONSTANT: { + auto &const_expr = expr.Cast(); + return ExpressionCost(const_expr.return_type.InternalType(), 1); + } + case ExpressionClass::BOUND_PARAMETER: { + auto &const_expr = expr.Cast(); + return ExpressionCost(const_expr.return_type.InternalType(), 1); + } + case ExpressionClass::BOUND_REF: { + auto &col_expr = expr.Cast(); + return ExpressionCost(col_expr.return_type.InternalType(), 8); + } + default: { + break; + } + } + + // return a very high value if nothing matches + return 1000; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/expression_rewriter.cpp b/src/duckdb/src/optimizer/expression_rewriter.cpp new file mode 100644 index 00000000..2e3a19f2 --- /dev/null +++ b/src/duckdb/src/optimizer/expression_rewriter.cpp @@ -0,0 +1,94 @@ +#include "duckdb/optimizer/expression_rewriter.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +unique_ptr ExpressionRewriter::ApplyRules(LogicalOperator &op, const vector> &rules, + unique_ptr expr, bool &changes_made, bool is_root) { + for (auto &rule : rules) { + vector> bindings; + if (rule.get().root->Match(*expr, bindings)) { + // the rule matches! try to apply it + bool rule_made_change = false; + auto result = rule.get().Apply(op, bindings, rule_made_change, is_root); + if (result) { + changes_made = true; + // the base node changed: the rule applied changes + // rerun on the new node + return ExpressionRewriter::ApplyRules(op, rules, std::move(result), changes_made); + } else if (rule_made_change) { + changes_made = true; + // the base node didn't change, but changes were made, rerun + return expr; + } + // else nothing changed, continue to the next rule + continue; + } + } + // no changes could be made to this node + // recursively run on the children of this node + ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { + child = ExpressionRewriter::ApplyRules(op, rules, std::move(child), changes_made); + }); + return expr; +} + +unique_ptr ExpressionRewriter::ConstantOrNull(unique_ptr child, Value value) { + vector> children; + children.push_back(make_uniq(value)); + children.push_back(std::move(child)); + return ConstantOrNull(std::move(children), std::move(value)); +} + +unique_ptr ExpressionRewriter::ConstantOrNull(vector> children, Value value) { + auto type = value.type(); + children.insert(children.begin(), make_uniq(value)); + return make_uniq(type, ConstantOrNull::GetFunction(type), std::move(children), + ConstantOrNull::Bind(std::move(value))); +} + +void ExpressionRewriter::VisitOperator(LogicalOperator &op) { + VisitOperatorChildren(op); + this->op = &op; + + to_apply_rules.clear(); + for (auto &rule : rules) { + if (rule->logical_root && !rule->logical_root->Match(op.type)) { + // this rule does not apply to this type of LogicalOperator + continue; + } + to_apply_rules.push_back(*rule); + } + if (to_apply_rules.empty()) { + // no rules to apply on this node + return; + } + + VisitOperatorExpressions(op); + + // if it is a LogicalFilter, we split up filter conjunctions again + if (op.type == LogicalOperatorType::LOGICAL_FILTER) { + auto &filter = op.Cast(); + filter.SplitPredicates(); + } +} + +void ExpressionRewriter::VisitExpression(unique_ptr *expression) { + bool changes_made; + do { + changes_made = false; + *expression = ExpressionRewriter::ApplyRules(*op, to_apply_rules, std::move(*expression), changes_made, true); + } while (changes_made); +} + +ClientContext &Rule::GetContext() const { + return rewriter.context; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/filter_combiner.cpp b/src/duckdb/src/optimizer/filter_combiner.cpp new file mode 100644 index 00000000..34162724 --- /dev/null +++ b/src/duckdb/src/optimizer/filter_combiner.cpp @@ -0,0 +1,1221 @@ +#include "duckdb/optimizer/filter_combiner.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/null_filter.hpp" +#include "duckdb/optimizer/optimizer.hpp" + +namespace duckdb { + +using ExpressionValueInformation = FilterCombiner::ExpressionValueInformation; + +ValueComparisonResult CompareValueInformation(ExpressionValueInformation &left, ExpressionValueInformation &right); + +FilterCombiner::FilterCombiner(ClientContext &context) : context(context) { +} + +FilterCombiner::FilterCombiner(Optimizer &optimizer) : FilterCombiner(optimizer.context) { +} + +Expression &FilterCombiner::GetNode(Expression &expr) { + auto entry = stored_expressions.find(expr); + if (entry != stored_expressions.end()) { + // expression already exists: return a reference to the stored expression + return *entry->second; + } + // expression does not exist yet: create a copy and store it + auto copy = expr.Copy(); + auto ©_ref = *copy; + D_ASSERT(stored_expressions.find(copy_ref) == stored_expressions.end()); + stored_expressions[copy_ref] = std::move(copy); + return copy_ref; +} + +idx_t FilterCombiner::GetEquivalenceSet(Expression &expr) { + D_ASSERT(stored_expressions.find(expr) != stored_expressions.end()); + D_ASSERT(stored_expressions.find(expr)->second.get() == &expr); + auto entry = equivalence_set_map.find(expr); + if (entry == equivalence_set_map.end()) { + idx_t index = set_index++; + equivalence_set_map[expr] = index; + equivalence_map[index].push_back(expr); + constant_values.insert(make_pair(index, vector())); + return index; + } else { + return entry->second; + } +} + +FilterResult FilterCombiner::AddConstantComparison(vector &info_list, + ExpressionValueInformation info) { + if (info.constant.IsNull()) { + return FilterResult::UNSATISFIABLE; + } + for (idx_t i = 0; i < info_list.size(); i++) { + auto comparison = CompareValueInformation(info_list[i], info); + switch (comparison) { + case ValueComparisonResult::PRUNE_LEFT: + // prune the entry from the info list + info_list.erase(info_list.begin() + i); + i--; + break; + case ValueComparisonResult::PRUNE_RIGHT: + // prune the current info + return FilterResult::SUCCESS; + case ValueComparisonResult::UNSATISFIABLE_CONDITION: + // combination of filters is unsatisfiable: prune the entire branch + return FilterResult::UNSATISFIABLE; + default: + // prune nothing, move to the next condition + break; + } + } + // finally add the entry to the list + info_list.push_back(info); + return FilterResult::SUCCESS; +} + +FilterResult FilterCombiner::AddFilter(unique_ptr expr) { + // LookUpConjunctions(expr.get()); + // try to push the filter into the combiner + auto result = AddFilter(*expr); + if (result == FilterResult::UNSUPPORTED) { + // unsupported filter, push into remaining filters + remaining_filters.push_back(std::move(expr)); + return FilterResult::SUCCESS; + } + return result; +} + +void FilterCombiner::GenerateFilters(const std::function filter)> &callback) { + // first loop over the remaining filters + for (auto &filter : remaining_filters) { + callback(std::move(filter)); + } + remaining_filters.clear(); + // now loop over the equivalence sets + for (auto &entry : equivalence_map) { + auto equivalence_set = entry.first; + auto &entries = entry.second; + auto &constant_list = constant_values.find(equivalence_set)->second; + // for each entry generate an equality expression comparing to each other + for (idx_t i = 0; i < entries.size(); i++) { + for (idx_t k = i + 1; k < entries.size(); k++) { + auto comparison = make_uniq( + ExpressionType::COMPARE_EQUAL, entries[i].get().Copy(), entries[k].get().Copy()); + callback(std::move(comparison)); + } + // for each entry also create a comparison with each constant + int lower_index = -1; + int upper_index = -1; + bool lower_inclusive = false; + bool upper_inclusive = false; + for (idx_t k = 0; k < constant_list.size(); k++) { + auto &info = constant_list[k]; + if (info.comparison_type == ExpressionType::COMPARE_GREATERTHAN || + info.comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO) { + lower_index = k; + lower_inclusive = info.comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO; + } else if (info.comparison_type == ExpressionType::COMPARE_LESSTHAN || + info.comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO) { + upper_index = k; + upper_inclusive = info.comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO; + } else { + auto constant = make_uniq(info.constant); + auto comparison = make_uniq( + info.comparison_type, entries[i].get().Copy(), std::move(constant)); + callback(std::move(comparison)); + } + } + if (lower_index >= 0 && upper_index >= 0) { + // found both lower and upper index, create a BETWEEN expression + auto lower_constant = make_uniq(constant_list[lower_index].constant); + auto upper_constant = make_uniq(constant_list[upper_index].constant); + auto between = + make_uniq(entries[i].get().Copy(), std::move(lower_constant), + std::move(upper_constant), lower_inclusive, upper_inclusive); + callback(std::move(between)); + } else if (lower_index >= 0) { + // only lower index found, create simple comparison expression + auto constant = make_uniq(constant_list[lower_index].constant); + auto comparison = make_uniq(constant_list[lower_index].comparison_type, + entries[i].get().Copy(), std::move(constant)); + callback(std::move(comparison)); + } else if (upper_index >= 0) { + // only upper index found, create simple comparison expression + auto constant = make_uniq(constant_list[upper_index].constant); + auto comparison = make_uniq(constant_list[upper_index].comparison_type, + entries[i].get().Copy(), std::move(constant)); + callback(std::move(comparison)); + } + } + } + stored_expressions.clear(); + equivalence_set_map.clear(); + constant_values.clear(); + equivalence_map.clear(); +} + +bool FilterCombiner::HasFilters() { + bool has_filters = false; + GenerateFilters([&](unique_ptr child) { has_filters = true; }); + return has_filters; +} + +// unordered_map> MergeAnd(unordered_map> &f_1, +// unordered_map> &f_2) { +// unordered_map> result; +// for (auto &f : f_1) { +// auto it = f_2.find(f.first); +// if (it == f_2.end()) { +// result[f.first] = f.second; +// } else { +// Value *min = nullptr, *max = nullptr; +// if (it->second.first && f.second.first) { +// if (*f.second.first > *it->second.first) { +// min = f.second.first; +// } else { +// min = it->second.first; +// } + +// } else if (it->second.first) { +// min = it->second.first; +// } else if (f.second.first) { +// min = f.second.first; +// } else { +// min = nullptr; +// } +// if (it->second.second && f.second.second) { +// if (*f.second.second < *it->second.second) { +// max = f.second.second; +// } else { +// max = it->second.second; +// } +// } else if (it->second.second) { +// max = it->second.second; +// } else if (f.second.second) { +// max = f.second.second; +// } else { +// max = nullptr; +// } +// result[f.first] = {min, max}; +// f_2.erase(f.first); +// } +// } +// for (auto &f : f_2) { +// result[f.first] = f.second; +// } +// return result; +// } + +// unordered_map> MergeOr(unordered_map> &f_1, +// unordered_map> &f_2) { +// unordered_map> result; +// for (auto &f : f_1) { +// auto it = f_2.find(f.first); +// if (it != f_2.end()) { +// Value *min = nullptr, *max = nullptr; +// if (it->second.first && f.second.first) { +// if (*f.second.first < *it->second.first) { +// min = f.second.first; +// } else { +// min = it->second.first; +// } +// } +// if (it->second.second && f.second.second) { +// if (*f.second.second > *it->second.second) { +// max = f.second.second; +// } else { +// max = it->second.second; +// } +// } +// result[f.first] = {min, max}; +// f_2.erase(f.first); +// } +// } +// return result; +// } + +// unordered_map> +// FilterCombiner::FindZonemapChecks(vector &column_ids, unordered_set ¬_constants, Expression *filter) +// { unordered_map> checks; switch (filter->type) { case +// ExpressionType::CONJUNCTION_OR: { +// //! For a filter to +// auto &or_exp = filter->Cast(); +// checks = FindZonemapChecks(column_ids, not_constants, or_exp.children[0].get()); +// for (size_t i = 1; i < or_exp.children.size(); ++i) { +// auto child_check = FindZonemapChecks(column_ids, not_constants, or_exp.children[i].get()); +// checks = MergeOr(checks, child_check); +// } +// return checks; +// } +// case ExpressionType::CONJUNCTION_AND: { +// auto &and_exp = filter->Cast(); +// checks = FindZonemapChecks(column_ids, not_constants, and_exp.children[0].get()); +// for (size_t i = 1; i < and_exp.children.size(); ++i) { +// auto child_check = FindZonemapChecks(column_ids, not_constants, and_exp.children[i].get()); +// checks = MergeAnd(checks, child_check); +// } +// return checks; +// } +// case ExpressionType::COMPARE_IN: { +// auto &comp_in_exp = filter->Cast(); +// if (comp_in_exp.children[0]->type == ExpressionType::BOUND_COLUMN_REF) { +// Value *min = nullptr, *max = nullptr; +// auto &column_ref = comp_in_exp.children[0]->Cast(); +// for (size_t i {1}; i < comp_in_exp.children.size(); i++) { +// if (comp_in_exp.children[i]->type != ExpressionType::VALUE_CONSTANT) { +// //! This indicates the column has a comparison that is not with a constant +// not_constants.insert(column_ids[column_ref.binding.column_index]); +// break; +// } else { +// auto &const_value_expr = comp_in_exp.children[i]->Cast(); +// if (const_value_expr.value.IsNull()) { +// return checks; +// } +// if (!min && !max) { +// min = &const_value_expr.value; +// max = min; +// } else { +// if (*min > const_value_expr.value) { +// min = &const_value_expr.value; +// } +// if (*max < const_value_expr.value) { +// max = &const_value_expr.value; +// } +// } +// } +// } +// checks[column_ids[column_ref.binding.column_index]] = {min, max}; +// } +// return checks; +// } +// case ExpressionType::COMPARE_EQUAL: { +// auto &comp_exp = filter->Cast(); +// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && +// comp_exp.right->expression_class == ExpressionClass::BOUND_CONSTANT)) { +// auto &column_ref = comp_exp.left->Cast(); +// auto &constant_value_expr = comp_exp.right->Cast(); +// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, +// &constant_value_expr.value}; +// } +// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_CONSTANT && +// comp_exp.right->expression_class == ExpressionClass::BOUND_COLUMN_REF)) { +// auto &column_ref = comp_exp.right->Cast(); +// auto &constant_value_expr = comp_exp.left->Cast(); +// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, +// &constant_value_expr.value}; +// } +// return checks; +// } +// case ExpressionType::COMPARE_LESSTHAN: +// case ExpressionType::COMPARE_LESSTHANOREQUALTO: { +// auto &comp_exp = filter->Cast(); +// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && +// comp_exp.right->expression_class == ExpressionClass::BOUND_CONSTANT)) { +// auto &column_ref = comp_exp.left->Cast(); +// auto &constant_value_expr = comp_exp.right->Cast(); +// checks[column_ids[column_ref.binding.column_index]] = {nullptr, &constant_value_expr.value}; +// } +// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_CONSTANT && +// comp_exp.right->expression_class == ExpressionClass::BOUND_COLUMN_REF)) { +// auto &column_ref = comp_exp.right->Cast(); +// auto &constant_value_expr = comp_exp.left->Cast(); +// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, nullptr}; +// } +// return checks; +// } +// case ExpressionType::COMPARE_GREATERTHANOREQUALTO: +// case ExpressionType::COMPARE_GREATERTHAN: { +// auto &comp_exp = filter->Cast(); +// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && +// comp_exp.right->expression_class == ExpressionClass::BOUND_CONSTANT)) { +// auto &column_ref = comp_exp.left->Cast(); +// auto &constant_value_expr = comp_exp.right->Cast(); +// checks[column_ids[column_ref.binding.column_index]] = {&constant_value_expr.value, nullptr}; +// } +// if ((comp_exp.left->expression_class == ExpressionClass::BOUND_CONSTANT && +// comp_exp.right->expression_class == ExpressionClass::BOUND_COLUMN_REF)) { +// auto &column_ref = comp_exp.right->Cast(); +// auto &constant_value_expr = comp_exp.left->Cast(); +// checks[column_ids[column_ref.binding.column_index]] = {nullptr, &constant_value_expr.value}; +// } +// return checks; +// } +// default: +// return checks; +// } +// } + +// vector FilterCombiner::GenerateZonemapChecks(vector &column_ids, +// vector &pushed_filters) { +// vector zonemap_checks; +// unordered_set not_constants; +// //! We go through the remaining filters and capture their min max +// if (remaining_filters.empty()) { +// return zonemap_checks; +// } + +// auto checks = FindZonemapChecks(column_ids, not_constants, remaining_filters[0].get()); +// for (size_t i = 1; i < remaining_filters.size(); ++i) { +// auto child_check = FindZonemapChecks(column_ids, not_constants, remaining_filters[i].get()); +// checks = MergeAnd(checks, child_check); +// } +// //! We construct the equivalent filters +// for (auto not_constant : not_constants) { +// checks.erase(not_constant); +// } +// for (const auto &pushed_filter : pushed_filters) { +// checks.erase(column_ids[pushed_filter.column_index]); +// } +// for (const auto &check : checks) { +// if (check.second.first) { +// zonemap_checks.emplace_back(check.second.first->Copy(), ExpressionType::COMPARE_GREATERTHANOREQUALTO, +// check.first); +// } +// if (check.second.second) { +// zonemap_checks.emplace_back(check.second.second->Copy(), ExpressionType::COMPARE_LESSTHANOREQUALTO, +// check.first); +// } +// } +// return zonemap_checks; +// } + +TableFilterSet FilterCombiner::GenerateTableScanFilters(vector &column_ids) { + TableFilterSet table_filters; + //! First, we figure the filters that have constant expressions that we can push down to the table scan + for (auto &constant_value : constant_values) { + if (!constant_value.second.empty()) { + auto filter_exp = equivalence_map.end(); + if ((constant_value.second[0].comparison_type == ExpressionType::COMPARE_EQUAL || + constant_value.second[0].comparison_type == ExpressionType::COMPARE_GREATERTHAN || + constant_value.second[0].comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || + constant_value.second[0].comparison_type == ExpressionType::COMPARE_LESSTHAN || + constant_value.second[0].comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO) && + (TypeIsNumeric(constant_value.second[0].constant.type().InternalType()) || + constant_value.second[0].constant.type().InternalType() == PhysicalType::VARCHAR || + constant_value.second[0].constant.type().InternalType() == PhysicalType::BOOL)) { + //! Here we check if these filters are column references + filter_exp = equivalence_map.find(constant_value.first); + if (filter_exp->second.size() == 1 && + filter_exp->second[0].get().type == ExpressionType::BOUND_COLUMN_REF) { + auto &filter_col_exp = filter_exp->second[0].get().Cast(); + auto column_index = column_ids[filter_col_exp.binding.column_index]; + if (column_index == COLUMN_IDENTIFIER_ROW_ID) { + break; + } + auto equivalence_set = filter_exp->first; + auto &entries = filter_exp->second; + auto &constant_list = constant_values.find(equivalence_set)->second; + // for each entry generate an equality expression comparing to each other + for (idx_t i = 0; i < entries.size(); i++) { + // for each entry also create a comparison with each constant + for (idx_t k = 0; k < constant_list.size(); k++) { + auto constant_filter = make_uniq(constant_value.second[k].comparison_type, + constant_value.second[k].constant); + table_filters.PushFilter(column_index, std::move(constant_filter)); + } + table_filters.PushFilter(column_index, make_uniq()); + } + equivalence_map.erase(filter_exp); + } + } + } + } + //! Here we look for LIKE or IN filters + for (idx_t rem_fil_idx = 0; rem_fil_idx < remaining_filters.size(); rem_fil_idx++) { + auto &remaining_filter = remaining_filters[rem_fil_idx]; + if (remaining_filter->expression_class == ExpressionClass::BOUND_FUNCTION) { + auto &func = remaining_filter->Cast(); + if (func.function.name == "prefix" && + func.children[0]->expression_class == ExpressionClass::BOUND_COLUMN_REF && + func.children[1]->type == ExpressionType::VALUE_CONSTANT) { + //! This is a like function. + auto &column_ref = func.children[0]->Cast(); + auto &constant_value_expr = func.children[1]->Cast(); + auto like_string = StringValue::Get(constant_value_expr.value); + if (like_string.empty()) { + continue; + } + auto column_index = column_ids[column_ref.binding.column_index]; + //! Here the like must be transformed to a BOUND COMPARISON geq le + auto lower_bound = + make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(like_string)); + like_string[like_string.size() - 1]++; + auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHAN, Value(like_string)); + table_filters.PushFilter(column_index, std::move(lower_bound)); + table_filters.PushFilter(column_index, std::move(upper_bound)); + table_filters.PushFilter(column_index, make_uniq()); + } + if (func.function.name == "~~" && func.children[0]->expression_class == ExpressionClass::BOUND_COLUMN_REF && + func.children[1]->type == ExpressionType::VALUE_CONSTANT) { + //! This is a like function. + auto &column_ref = func.children[0]->Cast(); + auto &constant_value_expr = func.children[1]->Cast(); + auto &like_string = StringValue::Get(constant_value_expr.value); + if (like_string[0] == '%' || like_string[0] == '_') { + //! We have no prefix so nothing to pushdown + break; + } + string prefix; + bool equality = true; + for (char const &c : like_string) { + if (c == '%' || c == '_') { + equality = false; + break; + } + prefix += c; + } + auto column_index = column_ids[column_ref.binding.column_index]; + if (equality) { + //! Here the like can be transformed to an equality query + auto equal_filter = make_uniq(ExpressionType::COMPARE_EQUAL, Value(prefix)); + table_filters.PushFilter(column_index, std::move(equal_filter)); + table_filters.PushFilter(column_index, make_uniq()); + } else { + //! Here the like must be transformed to a BOUND COMPARISON geq le + auto lower_bound = + make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(prefix)); + prefix[prefix.size() - 1]++; + auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHAN, Value(prefix)); + table_filters.PushFilter(column_index, std::move(lower_bound)); + table_filters.PushFilter(column_index, std::move(upper_bound)); + table_filters.PushFilter(column_index, make_uniq()); + } + } + } else if (remaining_filter->type == ExpressionType::COMPARE_IN) { + auto &func = remaining_filter->Cast(); + vector in_values; + D_ASSERT(func.children.size() > 1); + if (func.children[0]->expression_class != ExpressionClass::BOUND_COLUMN_REF) { + continue; + } + auto &column_ref = func.children[0]->Cast(); + auto column_index = column_ids[column_ref.binding.column_index]; + if (column_index == COLUMN_IDENTIFIER_ROW_ID) { + break; + } + //! check if all children are const expr + bool children_constant = true; + for (size_t i {1}; i < func.children.size(); i++) { + if (func.children[i]->type != ExpressionType::VALUE_CONSTANT) { + children_constant = false; + } + } + if (!children_constant) { + continue; + } + auto &fst_const_value_expr = func.children[1]->Cast(); + auto &type = fst_const_value_expr.value.type(); + + //! Check if values are consecutive, if yes transform them to >= <= (only for integers) + // e.g. if we have x IN (1, 2, 3, 4, 5) we transform this into x >= 1 AND x <= 5 + if (!type.IsIntegral()) { + continue; + } + + bool can_simplify_in_clause = true; + for (idx_t i = 1; i < func.children.size(); i++) { + auto &const_value_expr = func.children[i]->Cast(); + if (const_value_expr.value.IsNull()) { + can_simplify_in_clause = false; + break; + } + in_values.push_back(const_value_expr.value.GetValue()); + } + if (!can_simplify_in_clause || in_values.empty()) { + continue; + } + + sort(in_values.begin(), in_values.end()); + + for (idx_t in_val_idx = 1; in_val_idx < in_values.size(); in_val_idx++) { + if (in_values[in_val_idx] - in_values[in_val_idx - 1] > 1) { + can_simplify_in_clause = false; + break; + } + } + if (!can_simplify_in_clause) { + continue; + } + auto lower_bound = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, + Value::Numeric(type, in_values.front())); + auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, + Value::Numeric(type, in_values.back())); + table_filters.PushFilter(column_index, std::move(lower_bound)); + table_filters.PushFilter(column_index, std::move(upper_bound)); + table_filters.PushFilter(column_index, make_uniq()); + + remaining_filters.erase(remaining_filters.begin() + rem_fil_idx); + } + } + + // GenerateORFilters(table_filters, column_ids); + + return table_filters; +} + +static bool IsGreaterThan(ExpressionType type) { + return type == ExpressionType::COMPARE_GREATERTHAN || type == ExpressionType::COMPARE_GREATERTHANOREQUALTO; +} + +static bool IsLessThan(ExpressionType type) { + return type == ExpressionType::COMPARE_LESSTHAN || type == ExpressionType::COMPARE_LESSTHANOREQUALTO; +} + +FilterResult FilterCombiner::AddBoundComparisonFilter(Expression &expr) { + auto &comparison = expr.Cast(); + if (comparison.type != ExpressionType::COMPARE_LESSTHAN && + comparison.type != ExpressionType::COMPARE_LESSTHANOREQUALTO && + comparison.type != ExpressionType::COMPARE_GREATERTHAN && + comparison.type != ExpressionType::COMPARE_GREATERTHANOREQUALTO && + comparison.type != ExpressionType::COMPARE_EQUAL && comparison.type != ExpressionType::COMPARE_NOTEQUAL) { + // only support [>, >=, <, <=, ==, !=] expressions + return FilterResult::UNSUPPORTED; + } + // check if one of the sides is a scalar value + bool left_is_scalar = comparison.left->IsFoldable(); + bool right_is_scalar = comparison.right->IsFoldable(); + if (left_is_scalar || right_is_scalar) { + // comparison with scalar + auto &node = GetNode(left_is_scalar ? *comparison.right : *comparison.left); + idx_t equivalence_set = GetEquivalenceSet(node); + auto &scalar = left_is_scalar ? comparison.left : comparison.right; + Value constant_value; + if (!ExpressionExecutor::TryEvaluateScalar(context, *scalar, constant_value)) { + return FilterResult::UNSATISFIABLE; + } + if (constant_value.IsNull()) { + // comparisons with null are always null (i.e. will never result in rows) + return FilterResult::UNSATISFIABLE; + } + + // create the ExpressionValueInformation + ExpressionValueInformation info; + info.comparison_type = left_is_scalar ? FlipComparisonExpression(comparison.type) : comparison.type; + info.constant = constant_value; + + // get the current bucket of constant values + D_ASSERT(constant_values.find(equivalence_set) != constant_values.end()); + auto &info_list = constant_values.find(equivalence_set)->second; + D_ASSERT(node.return_type == info.constant.type()); + // check the existing constant comparisons to see if we can do any pruning + auto ret = AddConstantComparison(info_list, info); + + auto &non_scalar = left_is_scalar ? *comparison.right : *comparison.left; + auto transitive_filter = FindTransitiveFilter(non_scalar); + if (transitive_filter != nullptr) { + // try to add transitive filters + if (AddTransitiveFilters(transitive_filter->Cast()) == + FilterResult::UNSUPPORTED) { + // in case of unsuccessful re-add filter into remaining ones + remaining_filters.push_back(std::move(transitive_filter)); + } + } + return ret; + } else { + // comparison between two non-scalars + // only handle comparisons for now + if (expr.type != ExpressionType::COMPARE_EQUAL) { + if (IsGreaterThan(expr.type) || IsLessThan(expr.type)) { + return AddTransitiveFilters(comparison); + } + return FilterResult::UNSUPPORTED; + } + // get the LHS and RHS nodes + auto &left_node = GetNode(*comparison.left); + auto &right_node = GetNode(*comparison.right); + if (left_node.Equals(right_node)) { + return FilterResult::UNSUPPORTED; + } + // get the equivalence sets of the LHS and RHS + auto left_equivalence_set = GetEquivalenceSet(left_node); + auto right_equivalence_set = GetEquivalenceSet(right_node); + if (left_equivalence_set == right_equivalence_set) { + // this equality filter already exists, prune it + return FilterResult::SUCCESS; + } + // add the right bucket into the left bucket + D_ASSERT(equivalence_map.find(left_equivalence_set) != equivalence_map.end()); + D_ASSERT(equivalence_map.find(right_equivalence_set) != equivalence_map.end()); + + auto &left_bucket = equivalence_map.find(left_equivalence_set)->second; + auto &right_bucket = equivalence_map.find(right_equivalence_set)->second; + for (auto &right_expr : right_bucket) { + // rewrite the equivalence set mapping for this node + equivalence_set_map[right_expr] = left_equivalence_set; + // add the node to the left bucket + left_bucket.push_back(right_expr); + } + // now add all constant values from the right bucket to the left bucket + D_ASSERT(constant_values.find(left_equivalence_set) != constant_values.end()); + D_ASSERT(constant_values.find(right_equivalence_set) != constant_values.end()); + auto &left_constant_bucket = constant_values.find(left_equivalence_set)->second; + auto &right_constant_bucket = constant_values.find(right_equivalence_set)->second; + for (auto &right_constant : right_constant_bucket) { + if (AddConstantComparison(left_constant_bucket, right_constant) == FilterResult::UNSATISFIABLE) { + return FilterResult::UNSATISFIABLE; + } + } + } + return FilterResult::SUCCESS; +} + +FilterResult FilterCombiner::AddFilter(Expression &expr) { + if (expr.HasParameter()) { + return FilterResult::UNSUPPORTED; + } + if (expr.IsFoldable()) { + // scalar condition, evaluate it + Value result; + if (!ExpressionExecutor::TryEvaluateScalar(context, expr, result)) { + return FilterResult::UNSUPPORTED; + } + result = result.DefaultCastAs(LogicalType::BOOLEAN); + // check if the filter passes + if (result.IsNull() || !BooleanValue::Get(result)) { + // the filter does not pass the scalar test, create an empty result + return FilterResult::UNSATISFIABLE; + } else { + // the filter passes the scalar test, just remove the condition + return FilterResult::SUCCESS; + } + } + D_ASSERT(!expr.IsFoldable()); + if (expr.GetExpressionClass() == ExpressionClass::BOUND_BETWEEN) { + auto &comparison = expr.Cast(); + //! check if one of the sides is a scalar value + bool lower_is_scalar = comparison.lower->IsFoldable(); + bool upper_is_scalar = comparison.upper->IsFoldable(); + if (lower_is_scalar || upper_is_scalar) { + //! comparison with scalar - break apart + auto &node = GetNode(*comparison.input); + idx_t equivalence_set = GetEquivalenceSet(node); + auto result = FilterResult::UNSATISFIABLE; + + if (lower_is_scalar) { + auto scalar = comparison.lower.get(); + Value constant_value; + if (!ExpressionExecutor::TryEvaluateScalar(context, *scalar, constant_value)) { + return FilterResult::UNSUPPORTED; + } + + // create the ExpressionValueInformation + ExpressionValueInformation info; + if (comparison.lower_inclusive) { + info.comparison_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + } else { + info.comparison_type = ExpressionType::COMPARE_GREATERTHAN; + } + info.constant = constant_value; + + // get the current bucket of constant values + D_ASSERT(constant_values.find(equivalence_set) != constant_values.end()); + auto &info_list = constant_values.find(equivalence_set)->second; + // check the existing constant comparisons to see if we can do any pruning + result = AddConstantComparison(info_list, info); + } else { + D_ASSERT(upper_is_scalar); + const auto type = comparison.upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO + : ExpressionType::COMPARE_LESSTHAN; + auto left = comparison.lower->Copy(); + auto right = comparison.input->Copy(); + auto lower_comp = make_uniq(type, std::move(left), std::move(right)); + result = AddBoundComparisonFilter(*lower_comp); + } + + // Stop if we failed + if (result != FilterResult::SUCCESS) { + return result; + } + + if (upper_is_scalar) { + auto scalar = comparison.upper.get(); + Value constant_value; + if (!ExpressionExecutor::TryEvaluateScalar(context, *scalar, constant_value)) { + return FilterResult::UNSUPPORTED; + } + + // create the ExpressionValueInformation + ExpressionValueInformation info; + if (comparison.upper_inclusive) { + info.comparison_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; + } else { + info.comparison_type = ExpressionType::COMPARE_LESSTHAN; + } + info.constant = constant_value; + + // get the current bucket of constant values + D_ASSERT(constant_values.find(equivalence_set) != constant_values.end()); + // check the existing constant comparisons to see if we can do any pruning + result = AddConstantComparison(constant_values.find(equivalence_set)->second, info); + } else { + D_ASSERT(lower_is_scalar); + const auto type = comparison.upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO + : ExpressionType::COMPARE_LESSTHAN; + auto left = comparison.input->Copy(); + auto right = comparison.upper->Copy(); + auto upper_comp = make_uniq(type, std::move(left), std::move(right)); + result = AddBoundComparisonFilter(*upper_comp); + } + + return result; + } + } else if (expr.GetExpressionClass() == ExpressionClass::BOUND_COMPARISON) { + return AddBoundComparisonFilter(expr); + } + // only comparisons supported for now + return FilterResult::UNSUPPORTED; +} + +/* + * Create and add new transitive filters from a two non-scalar filter such as j > i, j >= i, j < i, and j <= i + * It's missing to create another method to add transitive filters from scalar filters, e.g, i > 10 + */ +FilterResult FilterCombiner::AddTransitiveFilters(BoundComparisonExpression &comparison) { + D_ASSERT(IsGreaterThan(comparison.type) || IsLessThan(comparison.type)); + // get the LHS and RHS nodes + auto &left_node = GetNode(*comparison.left); + reference right_node = GetNode(*comparison.right); + // In case with filters like CAST(i) = j and i = 5 we replace the COLUMN_REF i with the constant 5 + if (right_node.get().type == ExpressionType::OPERATOR_CAST) { + auto &bound_cast_expr = right_node.get().Cast(); + if (bound_cast_expr.child->type == ExpressionType::BOUND_COLUMN_REF) { + auto &col_ref = bound_cast_expr.child->Cast(); + for (auto &stored_exp : stored_expressions) { + if (stored_exp.first.get().type == ExpressionType::BOUND_COLUMN_REF) { + auto &st_col_ref = stored_exp.second->Cast(); + if (st_col_ref.binding == col_ref.binding && + bound_cast_expr.return_type == stored_exp.second->return_type) { + bound_cast_expr.child = stored_exp.second->Copy(); + right_node = GetNode(*bound_cast_expr.child); + break; + } + } + } + } + } + + if (left_node.Equals(right_node)) { + return FilterResult::UNSUPPORTED; + } + // get the equivalence sets of the LHS and RHS + idx_t left_equivalence_set = GetEquivalenceSet(left_node); + idx_t right_equivalence_set = GetEquivalenceSet(right_node); + if (left_equivalence_set == right_equivalence_set) { + // this equality filter already exists, prune it + return FilterResult::SUCCESS; + } + + vector &left_constants = constant_values.find(left_equivalence_set)->second; + vector &right_constants = constant_values.find(right_equivalence_set)->second; + bool is_successful = false; + bool is_inserted = false; + // read every constant filters already inserted for the right scalar variable + // and see if we can create new transitive filters, e.g., there is already a filter i > 10, + // suppose that we have now the j >= i, then we can infer a new filter j > 10 + for (const auto &right_constant : right_constants) { + ExpressionValueInformation info; + info.constant = right_constant.constant; + // there is already an equality filter, e.g., i = 10 + if (right_constant.comparison_type == ExpressionType::COMPARE_EQUAL) { + // create filter j [>, >=, <, <=] 10 + // suppose the new comparison is j >= i and we have already a filter i = 10, + // then we create a new filter j >= 10 + // and the filter j >= i can be pruned by not adding it into the remaining filters + info.comparison_type = comparison.type; + } else if ((comparison.type == ExpressionType::COMPARE_GREATERTHANOREQUALTO && + IsGreaterThan(right_constant.comparison_type)) || + (comparison.type == ExpressionType::COMPARE_LESSTHANOREQUALTO && + IsLessThan(right_constant.comparison_type))) { + // filters (j >= i AND i [>, >=] 10) OR (j <= i AND i [<, <=] 10) + // create filter j [>, >=] 10 and add the filter j [>=, <=] i into the remaining filters + info.comparison_type = right_constant.comparison_type; // create filter j [>, >=, <, <=] 10 + if (!is_inserted) { + // Add the filter j >= i in the remaing filters + auto filter = make_uniq(comparison.type, comparison.left->Copy(), + comparison.right->Copy()); + remaining_filters.push_back(std::move(filter)); + is_inserted = true; + } + } else if ((comparison.type == ExpressionType::COMPARE_GREATERTHAN && + IsGreaterThan(right_constant.comparison_type)) || + (comparison.type == ExpressionType::COMPARE_LESSTHAN && + IsLessThan(right_constant.comparison_type))) { + // filters (j > i AND i [>, >=] 10) OR j < i AND i [<, <=] 10 + // create filter j [>, <] 10 and add the filter j [>, <] i into the remaining filters + // the comparisons j > i and j < i are more restrictive + info.comparison_type = comparison.type; + if (!is_inserted) { + // Add the filter j [>, <] i + auto filter = make_uniq(comparison.type, comparison.left->Copy(), + comparison.right->Copy()); + remaining_filters.push_back(std::move(filter)); + is_inserted = true; + } + } else { + // we cannot add a new filter + continue; + } + // Add the new filer into the left set + if (AddConstantComparison(left_constants, info) == FilterResult::UNSATISFIABLE) { + return FilterResult::UNSATISFIABLE; + } + is_successful = true; + } + if (is_successful) { + // now check for remaining trasitive filters from the left column + auto transitive_filter = FindTransitiveFilter(*comparison.left); + if (transitive_filter != nullptr) { + // try to add transitive filters + if (AddTransitiveFilters(transitive_filter->Cast()) == + FilterResult::UNSUPPORTED) { + // in case of unsuccessful re-add filter into remaining ones + remaining_filters.push_back(std::move(transitive_filter)); + } + } + return FilterResult::SUCCESS; + } + + return FilterResult::UNSUPPORTED; +} + +/* + * Find a transitive filter already inserted into the remaining filters + * Check for a match between the right column of bound comparisons and the expression, + * then removes the bound comparison from the remaining filters and returns it + */ +unique_ptr FilterCombiner::FindTransitiveFilter(Expression &expr) { + // We only check for bound column ref + if (expr.type != ExpressionType::BOUND_COLUMN_REF) { + return nullptr; + } + for (idx_t i = 0; i < remaining_filters.size(); i++) { + if (remaining_filters[i]->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON) { + auto &comparison = remaining_filters[i]->Cast(); + if (expr.Equals(*comparison.right) && comparison.type != ExpressionType::COMPARE_NOTEQUAL) { + auto filter = std::move(remaining_filters[i]); + remaining_filters.erase(remaining_filters.begin() + i); + return filter; + } + } + } + return nullptr; +} + +ValueComparisonResult InvertValueComparisonResult(ValueComparisonResult result) { + if (result == ValueComparisonResult::PRUNE_RIGHT) { + return ValueComparisonResult::PRUNE_LEFT; + } + if (result == ValueComparisonResult::PRUNE_LEFT) { + return ValueComparisonResult::PRUNE_RIGHT; + } + return result; +} + +ValueComparisonResult CompareValueInformation(ExpressionValueInformation &left, ExpressionValueInformation &right) { + if (left.comparison_type == ExpressionType::COMPARE_EQUAL) { + // left is COMPARE_EQUAL, we can either + // (1) prune the right side or + // (2) return UNSATISFIABLE + bool prune_right_side = false; + switch (right.comparison_type) { + case ExpressionType::COMPARE_LESSTHAN: + prune_right_side = left.constant < right.constant; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + prune_right_side = left.constant <= right.constant; + break; + case ExpressionType::COMPARE_GREATERTHAN: + prune_right_side = left.constant > right.constant; + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + prune_right_side = left.constant >= right.constant; + break; + case ExpressionType::COMPARE_NOTEQUAL: + prune_right_side = left.constant != right.constant; + break; + default: + D_ASSERT(right.comparison_type == ExpressionType::COMPARE_EQUAL); + prune_right_side = left.constant == right.constant; + break; + } + if (prune_right_side) { + return ValueComparisonResult::PRUNE_RIGHT; + } else { + return ValueComparisonResult::UNSATISFIABLE_CONDITION; + } + } else if (right.comparison_type == ExpressionType::COMPARE_EQUAL) { + // right is COMPARE_EQUAL + return InvertValueComparisonResult(CompareValueInformation(right, left)); + } else if (left.comparison_type == ExpressionType::COMPARE_NOTEQUAL) { + // left is COMPARE_NOTEQUAL, we can either + // (1) prune the left side or + // (2) not prune anything + bool prune_left_side = false; + switch (right.comparison_type) { + case ExpressionType::COMPARE_LESSTHAN: + prune_left_side = left.constant >= right.constant; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + prune_left_side = left.constant > right.constant; + break; + case ExpressionType::COMPARE_GREATERTHAN: + prune_left_side = left.constant <= right.constant; + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + prune_left_side = left.constant < right.constant; + break; + default: + D_ASSERT(right.comparison_type == ExpressionType::COMPARE_NOTEQUAL); + prune_left_side = left.constant == right.constant; + break; + } + if (prune_left_side) { + return ValueComparisonResult::PRUNE_LEFT; + } else { + return ValueComparisonResult::PRUNE_NOTHING; + } + } else if (right.comparison_type == ExpressionType::COMPARE_NOTEQUAL) { + return InvertValueComparisonResult(CompareValueInformation(right, left)); + } else if (IsGreaterThan(left.comparison_type) && IsGreaterThan(right.comparison_type)) { + // both comparisons are [>], we can either + // (1) prune the left side or + // (2) prune the right side + if (left.constant > right.constant) { + // left constant is more selective, prune right + return ValueComparisonResult::PRUNE_RIGHT; + } else if (left.constant < right.constant) { + // right constant is more selective, prune left + return ValueComparisonResult::PRUNE_LEFT; + } else { + // constants are equivalent + // however we can still have the scenario where one is [>=] and the other is [>] + // we want to prune the [>=] because [>] is more selective + // if left is [>=] we prune the left, else we prune the right + if (left.comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO) { + return ValueComparisonResult::PRUNE_LEFT; + } else { + return ValueComparisonResult::PRUNE_RIGHT; + } + } + } else if (IsLessThan(left.comparison_type) && IsLessThan(right.comparison_type)) { + // both comparisons are [<], we can either + // (1) prune the left side or + // (2) prune the right side + if (left.constant < right.constant) { + // left constant is more selective, prune right + return ValueComparisonResult::PRUNE_RIGHT; + } else if (left.constant > right.constant) { + // right constant is more selective, prune left + return ValueComparisonResult::PRUNE_LEFT; + } else { + // constants are equivalent + // however we can still have the scenario where one is [<=] and the other is [<] + // we want to prune the [<=] because [<] is more selective + // if left is [<=] we prune the left, else we prune the right + if (left.comparison_type == ExpressionType::COMPARE_LESSTHANOREQUALTO) { + return ValueComparisonResult::PRUNE_LEFT; + } else { + return ValueComparisonResult::PRUNE_RIGHT; + } + } + } else if (IsLessThan(left.comparison_type)) { + D_ASSERT(IsGreaterThan(right.comparison_type)); + // left is [<] and right is [>], in this case we can either + // (1) prune nothing or + // (2) return UNSATISFIABLE + // the SMALLER THAN constant has to be greater than the BIGGER THAN constant + if (left.constant >= right.constant) { + return ValueComparisonResult::PRUNE_NOTHING; + } else { + return ValueComparisonResult::UNSATISFIABLE_CONDITION; + } + } else { + // left is [>] and right is [<] or [!=] + D_ASSERT(IsLessThan(right.comparison_type) && IsGreaterThan(left.comparison_type)); + return InvertValueComparisonResult(CompareValueInformation(right, left)); + } +} +// +// void FilterCombiner::LookUpConjunctions(Expression *expr) { +// if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_OR) { +// auto root_or_expr = (BoundConjunctionExpression *)expr; +// for (const auto &entry : map_col_conjunctions) { +// for (const auto &conjs_to_push : entry.second) { +// if (conjs_to_push->root_or->Equals(root_or_expr)) { +// return; +// } +// } +// } +// +// cur_root_or = root_or_expr; +// cur_conjunction = root_or_expr; +// cur_colref_to_push = nullptr; +// if (!BFSLookUpConjunctions(cur_root_or)) { +// if (cur_colref_to_push) { +// auto entry = map_col_conjunctions.find(cur_colref_to_push); +// auto &vec_conjs_to_push = entry->second; +// if (vec_conjs_to_push.size() == 1) { +// map_col_conjunctions.erase(entry); +// return; +// } +// vec_conjs_to_push.pop_back(); +// } +// } +// return; +// } +// +// // Verify if the expression has a column already pushed down by other OR expression +// VerifyOrsToPush(*expr); +//} +// +// bool FilterCombiner::BFSLookUpConjunctions(BoundConjunctionExpression *conjunction) { +// vector conjunctions_to_visit; +// +// for (auto &child : conjunction->children) { +// switch (child->GetExpressionClass()) { +// case ExpressionClass::BOUND_CONJUNCTION: { +// auto child_conjunction = (BoundConjunctionExpression *)child.get(); +// conjunctions_to_visit.emplace_back(child_conjunction); +// break; +// } +// case ExpressionClass::BOUND_COMPARISON: { +// if (!UpdateConjunctionFilter((BoundComparisonExpression *)child.get())) { +// return false; +// } +// break; +// } +// default: { +// return false; +// } +// } +// } +// +// for (auto child_conjunction : conjunctions_to_visit) { +// cur_conjunction = child_conjunction; +// // traverse child conjuction +// if (!BFSLookUpConjunctions(child_conjunction)) { +// return false; +// } +// } +// return true; +//} +// +// void FilterCombiner::VerifyOrsToPush(Expression &expr) { +// if (expr.type == ExpressionType::BOUND_COLUMN_REF) { +// auto colref = (BoundColumnRefExpression *)&expr; +// auto entry = map_col_conjunctions.find(colref); +// if (entry == map_col_conjunctions.end()) { +// return; +// } +// map_col_conjunctions.erase(entry); +// } +// ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { VerifyOrsToPush(child); }); +//} +// +// bool FilterCombiner::UpdateConjunctionFilter(BoundComparisonExpression *comparison_expr) { +// bool left_is_scalar = comparison_expr->left->IsFoldable(); +// bool right_is_scalar = comparison_expr->right->IsFoldable(); +// +// Expression *non_scalar_expr; +// if (left_is_scalar || right_is_scalar) { +// // only support comparison with scalar +// non_scalar_expr = left_is_scalar ? comparison_expr->right.get() : comparison_expr->left.get(); +// +// if (non_scalar_expr->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { +// return UpdateFilterByColumn((BoundColumnRefExpression *)non_scalar_expr, comparison_expr); +// } +// } +// +// return false; +//} +// +// bool FilterCombiner::UpdateFilterByColumn(BoundColumnRefExpression *column_ref, +// BoundComparisonExpression *comparison_expr) { +// if (cur_colref_to_push == nullptr) { +// cur_colref_to_push = column_ref; +// +// auto or_conjunction = make_uniq(ExpressionType::CONJUNCTION_OR); +// or_conjunction->children.emplace_back(comparison_expr->Copy()); +// +// unique_ptr conjs_to_push = make_uniq(); +// conjs_to_push->conjunctions.emplace_back(std::move(or_conjunction)); +// conjs_to_push->root_or = cur_root_or; +// +// auto &&vec_col_conjs = map_col_conjunctions[column_ref]; +// vec_col_conjs.emplace_back(std::move(conjs_to_push)); +// vec_colref_insertion_order.emplace_back(column_ref); +// return true; +// } +// +// auto entry = map_col_conjunctions.find(cur_colref_to_push); +// D_ASSERT(entry != map_col_conjunctions.end()); +// auto &conjunctions_to_push = entry->second.back(); +// +// if (!cur_colref_to_push->Equals(column_ref)) { +// // check for multiple colunms in the same root OR node +// if (cur_root_or == cur_conjunction) { +// return false; +// } +// // found an AND using a different column, we should stop the look up +// if (cur_conjunction->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { +// return false; +// } +// +// // found a different column, AND conditions cannot be preserved anymore +// conjunctions_to_push->preserve_and = false; +// return true; +// } +// +// auto &last_conjunction = conjunctions_to_push->conjunctions.back(); +// if (cur_conjunction->GetExpressionType() == last_conjunction->GetExpressionType()) { +// last_conjunction->children.emplace_back(comparison_expr->Copy()); +// } else { +// auto new_conjunction = make_uniq(cur_conjunction->GetExpressionType()); +// new_conjunction->children.emplace_back(comparison_expr->Copy()); +// conjunctions_to_push->conjunctions.emplace_back(std::move(new_conjunction)); +// } +// return true; +//} +// +// void FilterCombiner::GenerateORFilters(TableFilterSet &table_filter, vector &column_ids) { +// for (const auto colref : vec_colref_insertion_order) { +// auto column_index = column_ids[colref->binding.column_index]; +// if (column_index == COLUMN_IDENTIFIER_ROW_ID) { +// break; +// } +// +// for (const auto &conjunctions_to_push : map_col_conjunctions[colref]) { +// // root OR filter to push into the TableFilter +// auto root_or_filter = make_uniq(); +// // variable to hold the last conjuntion filter pointer +// // the next filter will be added into it, i.e., we create a chain of conjunction filters +// ConjunctionFilter *last_conj_filter = root_or_filter.get(); +// +// for (auto &conjunction : conjunctions_to_push->conjunctions) { +// if (conjunction->GetExpressionType() == ExpressionType::CONJUNCTION_AND && +// conjunctions_to_push->preserve_and) { +// GenerateConjunctionFilter(conjunction.get(), last_conj_filter); +// } else { +// GenerateConjunctionFilter(conjunction.get(), last_conj_filter); +// } +// } +// table_filter.PushFilter(column_index, std::move(root_or_filter)); +// } +// } +// map_col_conjunctions.clear(); +// vec_colref_insertion_order.clear(); +//} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/filter_pullup.cpp b/src/duckdb/src/optimizer/filter_pullup.cpp new file mode 100644 index 00000000..04986f19 --- /dev/null +++ b/src/duckdb/src/optimizer/filter_pullup.cpp @@ -0,0 +1,91 @@ +#include "duckdb/optimizer/filter_pullup.hpp" +#include "duckdb/planner/operator/logical_join.hpp" + +namespace duckdb { + +unique_ptr FilterPullup::Rewrite(unique_ptr op) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_FILTER: + return PullupFilter(std::move(op)); + case LogicalOperatorType::LOGICAL_PROJECTION: + return PullupProjection(std::move(op)); + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + return PullupCrossProduct(std::move(op)); + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + return PullupJoin(std::move(op)); + case LogicalOperatorType::LOGICAL_INTERSECT: + case LogicalOperatorType::LOGICAL_EXCEPT: + return PullupSetOperation(std::move(op)); + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_ORDER_BY: { + // we can just pull directly through these operations without any rewriting + op->children[0] = Rewrite(std::move(op->children[0])); + return op; + } + default: + return FinishPullup(std::move(op)); + } +} + +unique_ptr FilterPullup::PullupJoin(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN || op->type == LogicalOperatorType::LOGICAL_ANY_JOIN || + op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN); + auto &join = op->Cast(); + + switch (join.join_type) { + case JoinType::INNER: + return PullupInnerJoin(std::move(op)); + case JoinType::LEFT: + case JoinType::ANTI: + case JoinType::SEMI: { + return PullupFromLeft(std::move(op)); + } + default: + // unsupported join type: call children pull up + return FinishPullup(std::move(op)); + } +} + +unique_ptr FilterPullup::PullupInnerJoin(unique_ptr op) { + D_ASSERT(op->Cast().join_type == JoinType::INNER); + if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + return op; + } + return PullupBothSide(std::move(op)); +} + +unique_ptr FilterPullup::PullupCrossProduct(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT); + return PullupBothSide(std::move(op)); +} + +unique_ptr FilterPullup::GeneratePullupFilter(unique_ptr child, + vector> &expressions) { + unique_ptr filter = make_uniq(); + for (idx_t i = 0; i < expressions.size(); ++i) { + filter->expressions.push_back(std::move(expressions[i])); + } + expressions.clear(); + filter->children.push_back(std::move(child)); + return std::move(filter); +} + +unique_ptr FilterPullup::FinishPullup(unique_ptr op) { + // unhandled type, first perform filter pushdown in its children + for (idx_t i = 0; i < op->children.size(); i++) { + FilterPullup pullup; + op->children[i] = pullup.Rewrite(std::move(op->children[i])); + } + // now pull up any existing filters + if (filters_expr_pullup.empty()) { + // no filters to pull up + return op; + } + return GeneratePullupFilter(std::move(op), filters_expr_pullup); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/filter_pushdown.cpp b/src/duckdb/src/optimizer/filter_pushdown.cpp new file mode 100644 index 00000000..a8ca617c --- /dev/null +++ b/src/duckdb/src/optimizer/filter_pushdown.cpp @@ -0,0 +1,155 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" + +#include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_join.hpp" +#include "duckdb/optimizer/optimizer.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +FilterPushdown::FilterPushdown(Optimizer &optimizer) : optimizer(optimizer), combiner(optimizer.context) { +} + +unique_ptr FilterPushdown::Rewrite(unique_ptr op) { + D_ASSERT(!combiner.HasFilters()); + switch (op->type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + return PushdownAggregate(std::move(op)); + case LogicalOperatorType::LOGICAL_FILTER: + return PushdownFilter(std::move(op)); + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + return PushdownCrossProduct(std::move(op)); + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + return PushdownJoin(std::move(op)); + case LogicalOperatorType::LOGICAL_PROJECTION: + return PushdownProjection(std::move(op)); + case LogicalOperatorType::LOGICAL_INTERSECT: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_UNION: + return PushdownSetOperation(std::move(op)); + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_ORDER_BY: { + // we can just push directly through these operations without any rewriting + op->children[0] = Rewrite(std::move(op->children[0])); + return op; + } + case LogicalOperatorType::LOGICAL_GET: + return PushdownGet(std::move(op)); + case LogicalOperatorType::LOGICAL_LIMIT: + return PushdownLimit(std::move(op)); + default: + return FinishPushdown(std::move(op)); + } +} + +ClientContext &FilterPushdown::GetContext() { + return optimizer.GetContext(); +} + +unique_ptr FilterPushdown::PushdownJoin(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN || op->type == LogicalOperatorType::LOGICAL_ANY_JOIN || + op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN); + auto &join = op->Cast(); + if (!join.left_projection_map.empty() || !join.right_projection_map.empty()) { + // cannot push down further otherwise the projection maps won't be preserved + return FinishPushdown(std::move(op)); + } + + unordered_set left_bindings, right_bindings; + LogicalJoin::GetTableReferences(*op->children[0], left_bindings); + LogicalJoin::GetTableReferences(*op->children[1], right_bindings); + + switch (join.join_type) { + case JoinType::INNER: + return PushdownInnerJoin(std::move(op), left_bindings, right_bindings); + case JoinType::LEFT: + return PushdownLeftJoin(std::move(op), left_bindings, right_bindings); + case JoinType::MARK: + return PushdownMarkJoin(std::move(op), left_bindings, right_bindings); + case JoinType::SINGLE: + return PushdownSingleJoin(std::move(op), left_bindings, right_bindings); + default: + // unsupported join type: stop pushing down + return FinishPushdown(std::move(op)); + } +} +void FilterPushdown::PushFilters() { + for (auto &f : filters) { + auto result = combiner.AddFilter(std::move(f->filter)); + D_ASSERT(result != FilterResult::UNSUPPORTED); + (void)result; + } + filters.clear(); +} + +FilterResult FilterPushdown::AddFilter(unique_ptr expr) { + PushFilters(); + // split up the filters by AND predicate + vector> expressions; + expressions.push_back(std::move(expr)); + LogicalFilter::SplitPredicates(expressions); + // push the filters into the combiner + for (auto &child_expr : expressions) { + if (combiner.AddFilter(std::move(child_expr)) == FilterResult::UNSATISFIABLE) { + return FilterResult::UNSATISFIABLE; + } + } + return FilterResult::SUCCESS; +} + +void FilterPushdown::GenerateFilters() { + if (!filters.empty()) { + D_ASSERT(!combiner.HasFilters()); + return; + } + combiner.GenerateFilters([&](unique_ptr filter) { + auto f = make_uniq(); + f->filter = std::move(filter); + f->ExtractBindings(); + filters.push_back(std::move(f)); + }); +} + +unique_ptr FilterPushdown::AddLogicalFilter(unique_ptr op, + vector> expressions) { + if (expressions.empty()) { + // No left expressions, so needn't to add an extra filter operator. + return op; + } + auto filter = make_uniq(); + filter->expressions = std::move(expressions); + filter->children.push_back(std::move(op)); + return std::move(filter); +} + +unique_ptr FilterPushdown::PushFinalFilters(unique_ptr op) { + vector> expressions; + for (auto &f : filters) { + expressions.push_back(std::move(f->filter)); + } + + return AddLogicalFilter(std::move(op), std::move(expressions)); +} + +unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { + // unhandled type, first perform filter pushdown in its children + for (auto &child : op->children) { + FilterPushdown pushdown(optimizer); + child = pushdown.Rewrite(std::move(child)); + } + // now push any existing filters + return PushFinalFilters(std::move(op)); +} + +void FilterPushdown::Filter::ExtractBindings() { + bindings.clear(); + LogicalJoin::GetExpressionBindings(*filter, bindings); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/in_clause_rewriter.cpp b/src/duckdb/src/optimizer/in_clause_rewriter.cpp new file mode 100644 index 00000000..3670a572 --- /dev/null +++ b/src/duckdb/src/optimizer/in_clause_rewriter.cpp @@ -0,0 +1,115 @@ +#include "duckdb/optimizer/in_clause_rewriter.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/operator/logical_column_data_get.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +unique_ptr InClauseRewriter::Rewrite(unique_ptr op) { + if (op->children.size() == 1) { + root = std::move(op->children[0]); + VisitOperatorExpressions(*op); + op->children[0] = std::move(root); + } + + for (auto &child : op->children) { + child = Rewrite(std::move(child)); + } + return op; +} + +unique_ptr InClauseRewriter::VisitReplace(BoundOperatorExpression &expr, unique_ptr *expr_ptr) { + if (expr.type != ExpressionType::COMPARE_IN && expr.type != ExpressionType::COMPARE_NOT_IN) { + return nullptr; + } + D_ASSERT(root); + auto in_type = expr.children[0]->return_type; + bool is_regular_in = expr.type == ExpressionType::COMPARE_IN; + bool all_scalar = true; + // IN clause with many children: try to generate a mark join that replaces this IN expression + // we can only do this if the expressions in the expression list are scalar + for (idx_t i = 1; i < expr.children.size(); i++) { + if (!expr.children[i]->IsFoldable()) { + // non-scalar expression + all_scalar = false; + } + } + if (expr.children.size() == 2) { + // only one child + // IN: turn into X = 1 + // NOT IN: turn into X <> 1 + return make_uniq(is_regular_in ? ExpressionType::COMPARE_EQUAL + : ExpressionType::COMPARE_NOTEQUAL, + std::move(expr.children[0]), std::move(expr.children[1])); + } + if (expr.children.size() < 6 || !all_scalar) { + // low amount of children or not all scalar + // IN: turn into (X = 1 OR X = 2 OR X = 3...) + // NOT IN: turn into (X <> 1 AND X <> 2 AND X <> 3 ...) + auto conjunction = make_uniq(is_regular_in ? ExpressionType::CONJUNCTION_OR + : ExpressionType::CONJUNCTION_AND); + for (idx_t i = 1; i < expr.children.size(); i++) { + conjunction->children.push_back(make_uniq( + is_regular_in ? ExpressionType::COMPARE_EQUAL : ExpressionType::COMPARE_NOTEQUAL, + expr.children[0]->Copy(), std::move(expr.children[i]))); + } + return std::move(conjunction); + } + // IN clause with many constant children + // generate a mark join that replaces this IN expression + // first generate a ColumnDataCollection from the set of expressions + vector types = {in_type}; + auto collection = make_uniq(context, types); + ColumnDataAppendState append_state; + collection->InitializeAppend(append_state); + + DataChunk chunk; + chunk.Initialize(context, types); + for (idx_t i = 1; i < expr.children.size(); i++) { + // resolve this expression to a constant + auto value = ExpressionExecutor::EvaluateScalar(context, *expr.children[i]); + idx_t index = chunk.size(); + chunk.SetCardinality(chunk.size() + 1); + chunk.SetValue(0, index, value); + if (chunk.size() == STANDARD_VECTOR_SIZE || i + 1 == expr.children.size()) { + // chunk full: append to chunk collection + collection->Append(append_state, chunk); + chunk.Reset(); + } + } + // now generate a ChunkGet that scans this collection + auto chunk_index = optimizer.binder.GenerateTableIndex(); + auto chunk_scan = make_uniq(chunk_index, types, std::move(collection)); + + // then we generate the MARK join with the chunk scan on the RHS + auto join = make_uniq(JoinType::MARK); + join->mark_index = chunk_index; + join->AddChild(std::move(root)); + join->AddChild(std::move(chunk_scan)); + // create the JOIN condition + JoinCondition cond; + cond.left = std::move(expr.children[0]); + + cond.right = make_uniq(in_type, ColumnBinding(chunk_index, 0)); + cond.comparison = ExpressionType::COMPARE_EQUAL; + join->conditions.push_back(std::move(cond)); + root = std::move(join); + + // we replace the original subquery with a BoundColumnRefExpression referring to the mark column + unique_ptr result = + make_uniq("IN (...)", LogicalType::BOOLEAN, ColumnBinding(chunk_index, 0)); + if (!is_regular_in) { + // NOT IN: invert + auto invert = make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); + invert->children.push_back(std::move(result)); + result = std::move(invert); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp new file mode 100644 index 00000000..a929255c --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp @@ -0,0 +1,354 @@ +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/optimizer/join_order/join_order_optimizer.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/limits.hpp" + +namespace duckdb { + +// The filter was made on top of a logical sample or other projection, +// but no specific columns are referenced. See issue 4978 number 4. +bool CardinalityEstimator::EmptyFilter(FilterInfo &filter_info) { + if (!filter_info.left_set && !filter_info.right_set) { + return true; + } + return false; +} + +void CardinalityEstimator::AddRelationTdom(FilterInfo &filter_info) { + D_ASSERT(filter_info.set.count >= 1); + for (const RelationsToTDom &r2tdom : relations_to_tdoms) { + auto &i_set = r2tdom.equivalent_relations; + if (i_set.find(filter_info.left_binding) != i_set.end()) { + // found an equivalent filter + return; + } + } + + auto key = ColumnBinding(filter_info.left_binding.table_index, filter_info.left_binding.column_index); + RelationsToTDom new_r2tdom(column_binding_set_t({key})); + + relations_to_tdoms.emplace_back(new_r2tdom); +} + +bool CardinalityEstimator::SingleColumnFilter(FilterInfo &filter_info) { + if (filter_info.left_set && filter_info.right_set) { + // Both set + return false; + } + if (EmptyFilter(filter_info)) { + return false; + } + return true; +} + +vector CardinalityEstimator::DetermineMatchingEquivalentSets(FilterInfo *filter_info) { + vector matching_equivalent_sets; + auto equivalent_relation_index = 0; + + for (const RelationsToTDom &r2tdom : relations_to_tdoms) { + auto &i_set = r2tdom.equivalent_relations; + if (i_set.find(filter_info->left_binding) != i_set.end()) { + matching_equivalent_sets.push_back(equivalent_relation_index); + } else if (i_set.find(filter_info->right_binding) != i_set.end()) { + // don't add both left and right to the matching_equivalent_sets + // since both left and right get added to that index anyway. + matching_equivalent_sets.push_back(equivalent_relation_index); + } + equivalent_relation_index++; + } + return matching_equivalent_sets; +} + +void CardinalityEstimator::AddToEquivalenceSets(FilterInfo *filter_info, vector matching_equivalent_sets) { + D_ASSERT(matching_equivalent_sets.size() <= 2); + if (matching_equivalent_sets.size() > 1) { + // an equivalence relation is connecting two sets of equivalence relations + // so push all relations from the second set into the first. Later we will delete + // the second set. + for (ColumnBinding i : relations_to_tdoms.at(matching_equivalent_sets[1]).equivalent_relations) { + relations_to_tdoms.at(matching_equivalent_sets[0]).equivalent_relations.insert(i); + } + for (auto &column_name : relations_to_tdoms.at(matching_equivalent_sets[1]).column_names) { + relations_to_tdoms.at(matching_equivalent_sets[0]).column_names.push_back(column_name); + } + relations_to_tdoms.at(matching_equivalent_sets[1]).equivalent_relations.clear(); + relations_to_tdoms.at(matching_equivalent_sets[1]).column_names.clear(); + relations_to_tdoms.at(matching_equivalent_sets[0]).filters.push_back(filter_info); + // add all values of one set to the other, delete the empty one + } else if (matching_equivalent_sets.size() == 1) { + auto &tdom_i = relations_to_tdoms.at(matching_equivalent_sets.at(0)); + tdom_i.equivalent_relations.insert(filter_info->left_binding); + tdom_i.equivalent_relations.insert(filter_info->right_binding); + tdom_i.filters.push_back(filter_info); + } else if (matching_equivalent_sets.empty()) { + column_binding_set_t tmp; + tmp.insert(filter_info->left_binding); + tmp.insert(filter_info->right_binding); + relations_to_tdoms.emplace_back(tmp); + relations_to_tdoms.back().filters.push_back(filter_info); + } +} + +void CardinalityEstimator::InitEquivalentRelations(const vector> &filter_infos) { + // For each filter, we fill keep track of the index of the equivalent relation set + // the left and right relation needs to be added to. + for (auto &filter : filter_infos) { + if (SingleColumnFilter(*filter)) { + // Filter on one relation, (i.e string or range filter on a column). + // Grab the first relation and add it to the equivalence_relations + AddRelationTdom(*filter); + continue; + } else if (EmptyFilter(*filter)) { + continue; + } + D_ASSERT(filter->left_set->count >= 1); + D_ASSERT(filter->right_set->count >= 1); + + auto matching_equivalent_sets = DetermineMatchingEquivalentSets(filter.get()); + AddToEquivalenceSets(filter.get(), matching_equivalent_sets); + } + RemoveEmptyTotalDomains(); +} + +void CardinalityEstimator::RemoveEmptyTotalDomains() { + auto remove_start = std::remove_if(relations_to_tdoms.begin(), relations_to_tdoms.end(), + [](RelationsToTDom &r_2_tdom) { return r_2_tdom.equivalent_relations.empty(); }); + relations_to_tdoms.erase(remove_start, relations_to_tdoms.end()); +} + +void UpdateDenom(Subgraph2Denominator &relation_2_denom, RelationsToTDom &relation_to_tdom) { + relation_2_denom.denom *= relation_to_tdom.has_tdom_hll ? relation_to_tdom.tdom_hll : relation_to_tdom.tdom_no_hll; +} + +void FindSubgraphMatchAndMerge(Subgraph2Denominator &merge_to, idx_t find_me, + vector::iterator subgraph, + vector::iterator end) { + for (; subgraph != end; subgraph++) { + if (subgraph->relations.count(find_me) >= 1) { + for (auto &relation : subgraph->relations) { + merge_to.relations.insert(relation); + } + subgraph->relations.clear(); + merge_to.denom *= subgraph->denom; + return; + } + } +} + +template <> +double CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) { + + if (relation_set_2_cardinality.find(new_set.ToString()) != relation_set_2_cardinality.end()) { + return relation_set_2_cardinality[new_set.ToString()].cardinality_before_filters; + } + double numerator = 1; + unordered_set actual_set; + + for (idx_t i = 0; i < new_set.count; i++) { + auto &single_node_set = set_manager.GetJoinRelation(new_set.relations[i]); + auto card_helper = relation_set_2_cardinality[single_node_set.ToString()]; + numerator *= card_helper.cardinality_before_filters; + actual_set.insert(new_set.relations[i]); + } + + vector subgraphs; + bool done = false; + bool found_match = false; + + // Finding the denominator is tricky. You need to go through the tdoms in decreasing order + // Then loop through all filters in the equivalence set of the tdom to see if both the + // left and right relations are in the new set, if so you can use that filter. + // You must also make sure that the filters all relations in the given set, so we use subgraphs + // that should eventually merge into one connected graph that joins all the relations + // TODO: Implement a method to cache subgraphs so you don't have to build them up every + // time the cardinality of a new set is requested + + // relations_to_tdoms has already been sorted. + for (auto &relation_2_tdom : relations_to_tdoms) { + // loop through each filter in the tdom. + if (done) { + break; + } + for (auto &filter : relation_2_tdom.filters) { + if (actual_set.count(filter->left_binding.table_index) == 0 || + actual_set.count(filter->right_binding.table_index) == 0) { + continue; + } + // the join filter is on relations in the new set. + found_match = false; + vector::iterator it; + for (it = subgraphs.begin(); it != subgraphs.end(); it++) { + auto left_in = it->relations.count(filter->left_binding.table_index); + auto right_in = it->relations.count(filter->right_binding.table_index); + if (left_in && right_in) { + // if both left and right bindings are in the subgraph, continue. + // This means another filter is connecting relations already in the + // subgraph it, but it has a tdom that is less, and we don't care. + found_match = true; + continue; + } + if (!left_in && !right_in) { + // if both left and right bindings are *not* in the subgraph, continue + // without finding a match. This will trigger the process to add a new + // subgraph + continue; + } + idx_t find_table; + if (left_in) { + find_table = filter->right_binding.table_index; + } else { + D_ASSERT(right_in); + find_table = filter->left_binding.table_index; + } + auto next_subgraph = it + 1; + // iterate through other subgraphs and merge. + FindSubgraphMatchAndMerge(*it, find_table, next_subgraph, subgraphs.end()); + // Now insert the right binding and update denominator with the + // tdom of the filter + it->relations.insert(find_table); + UpdateDenom(*it, relation_2_tdom); + found_match = true; + break; + } + // means that the filter joins relations in the given set, but there is no + // connection to any subgraph in subgraphs. Add a new subgraph, and maybe later there will be + // a connection. + if (!found_match) { + subgraphs.emplace_back(); + auto &subgraph = subgraphs.back(); + subgraph.relations.insert(filter->left_binding.table_index); + subgraph.relations.insert(filter->right_binding.table_index); + UpdateDenom(subgraph, relation_2_tdom); + } + auto remove_start = std::remove_if(subgraphs.begin(), subgraphs.end(), + [](Subgraph2Denominator &s) { return s.relations.empty(); }); + subgraphs.erase(remove_start, subgraphs.end()); + + if (subgraphs.size() == 1 && subgraphs.at(0).relations.size() == new_set.count) { + // You have found enough filters to connect the relations. These are guaranteed + // to be the filters with the highest Tdoms. + done = true; + break; + } + } + } + double denom = 1; + // TODO: It's possible cross-products were added and are not present in the filters in the relation_2_tdom + // structures. When that's the case, multiply the denom structures that have no intersection + for (auto &match : subgraphs) { + denom *= match.denom; + } + // can happen if a table has cardinality 0, or a tdom is set to 0 + if (denom == 0) { + denom = 1; + } + auto result = numerator / denom; + auto new_entry = CardinalityHelper((double)result, 1); + relation_set_2_cardinality[new_set.ToString()] = new_entry; + return result; +} + +template <> +idx_t CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) { + auto cardinality_as_double = EstimateCardinalityWithSet(new_set); + auto max = NumericLimits::Maximum(); + if (cardinality_as_double > max) { + return max; + } + return (idx_t)cardinality_as_double; +} + +bool SortTdoms(const RelationsToTDom &a, const RelationsToTDom &b) { + if (a.has_tdom_hll && b.has_tdom_hll) { + return a.tdom_hll > b.tdom_hll; + } + if (a.has_tdom_hll) { + return a.tdom_hll > b.tdom_no_hll; + } + if (b.has_tdom_hll) { + return a.tdom_no_hll > b.tdom_hll; + } + return a.tdom_no_hll > b.tdom_no_hll; +} + +void CardinalityEstimator::InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats) { + // Get the join relation set + D_ASSERT(stats.stats_initialized); + auto relation_cardinality = stats.cardinality; + auto relation_filter = stats.filter_strength; + + auto card_helper = CardinalityHelper(relation_cardinality, relation_filter); + relation_set_2_cardinality[set->ToString()] = card_helper; + + UpdateTotalDomains(set, stats); + + // sort relations from greatest tdom to lowest tdom. + std::sort(relations_to_tdoms.begin(), relations_to_tdoms.end(), SortTdoms); +} + +void CardinalityEstimator::UpdateTotalDomains(optional_ptr set, RelationStats &stats) { + D_ASSERT(set->count == 1); + auto relation_id = set->relations[0]; + //! Initialize the distinct count for all columns used in joins with the current relation. + // D_ASSERT(stats.column_distinct_count.size() >= 1); + + for (idx_t i = 0; i < stats.column_distinct_count.size(); i++) { + //! for every column used in a filter in the relation, get the distinct count via HLL, or assume it to be + //! the cardinality + // Update the relation_to_tdom set with the estimated distinct count (or tdom) calculated above + auto key = ColumnBinding(relation_id, i); + for (auto &relation_to_tdom : relations_to_tdoms) { + column_binding_set_t i_set = relation_to_tdom.equivalent_relations; + if (i_set.find(key) == i_set.end()) { + continue; + } + auto distinct_count = stats.column_distinct_count.at(i); + if (distinct_count.from_hll && relation_to_tdom.has_tdom_hll) { + relation_to_tdom.tdom_hll = MaxValue(relation_to_tdom.tdom_hll, distinct_count.distinct_count); + } else if (distinct_count.from_hll && !relation_to_tdom.has_tdom_hll) { + relation_to_tdom.has_tdom_hll = true; + relation_to_tdom.tdom_hll = distinct_count.distinct_count; + } else { + relation_to_tdom.tdom_no_hll = MinValue(distinct_count.distinct_count, relation_to_tdom.tdom_no_hll); + } + break; + } + } +} + +// LCOV_EXCL_START + +void CardinalityEstimator::AddRelationNamesToTdoms(vector &stats) { +#ifdef DEBUG + for (auto &total_domain : relations_to_tdoms) { + for (auto &binding : total_domain.equivalent_relations) { + D_ASSERT(binding.table_index < stats.size()); + D_ASSERT(binding.column_index < stats.at(binding.table_index).column_names.size()); + string column_name = stats.at(binding.table_index).column_names.at(binding.column_index); + total_domain.column_names.push_back(column_name); + } + } +#endif +} + +void CardinalityEstimator::PrintRelationToTdomInfo() { + for (auto &total_domain : relations_to_tdoms) { + string domain = "Following columns have the same distinct count: "; + for (auto &column_name : total_domain.column_names) { + domain += column_name + ", "; + } + bool have_hll = total_domain.has_tdom_hll; + domain += "\n TOTAL DOMAIN = " + to_string(have_hll ? total_domain.tdom_hll : total_domain.tdom_no_hll); + Printer::Print(domain); + } +} + +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/cost_model.cpp b/src/duckdb/src/optimizer/join_order/cost_model.cpp new file mode 100644 index 00000000..bb2a0746 --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/cost_model.cpp @@ -0,0 +1,19 @@ +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/optimizer/join_order/join_order_optimizer.hpp" +#include "duckdb/optimizer/join_order/cost_model.hpp" +#include + +namespace duckdb { + +CostModel::CostModel(QueryGraphManager &query_graph_manager) + : query_graph_manager(query_graph_manager), cardinality_estimator() { +} + +double CostModel::ComputeCost(JoinNode &left, JoinNode &right) { + auto &combination = query_graph_manager.set_manager.Union(left.set, right.set); + auto join_card = cardinality_estimator.EstimateCardinalityWithSet(combination); + auto join_cost = join_card; + return join_cost + left.cost + right.cost; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/estimated_properties.cpp b/src/duckdb/src/optimizer/join_order/estimated_properties.cpp new file mode 100644 index 00000000..d3841a1b --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/estimated_properties.cpp @@ -0,0 +1,36 @@ + +#include "duckdb/optimizer/join_order/estimated_properties.hpp" + +namespace duckdb { + +template <> +double EstimatedProperties::GetCardinality() const { + return cardinality; +} + +template <> +idx_t EstimatedProperties::GetCardinality() const { + auto max_idx_t = NumericLimits::Maximum() - 10000; + return MinValue(cardinality, max_idx_t); +} + +template <> +double EstimatedProperties::GetCost() const { + return cost; +} + +template <> +idx_t EstimatedProperties::GetCost() const { + auto max_idx_t = NumericLimits::Maximum() - 10000; + return MinValue(cost, max_idx_t); +} + +void EstimatedProperties::SetCardinality(double new_card) { + cardinality = new_card; +} + +void EstimatedProperties::SetCost(double new_cost) { + cost = new_cost; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/join_node.cpp b/src/duckdb/src/optimizer/join_order/join_node.cpp new file mode 100644 index 00000000..1786bd49 --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/join_node.cpp @@ -0,0 +1,36 @@ +#include "duckdb/optimizer/join_order/join_node.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/operator/list.hpp" + +namespace duckdb { + +JoinNode::JoinNode(JoinRelationSet &set) : set(set) { +} + +JoinNode::JoinNode(JoinRelationSet &set, optional_ptr info, JoinNode &left, JoinNode &right, double cost) + : set(set), info(info), left(&left), right(&right), cost(cost) { +} + +unique_ptr EstimatedProperties::Copy() { + auto result = make_uniq(cardinality, cost); + return result; +} + +string JoinNode::ToString() { + string result = "-------------------------------\n"; + result += set.ToString() + "\n"; + result += "cost = " + to_string(cost) + "\n"; + result += "left = \n"; + if (left) { + result += left->ToString(); + } + result += "right = \n"; + if (right) { + result += right->ToString(); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp new file mode 100644 index 00000000..14c6ac1b --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp @@ -0,0 +1,86 @@ +#include "duckdb/optimizer/join_order/join_order_optimizer.hpp" +#include "duckdb/optimizer/join_order/cost_model.hpp" +#include "duckdb/optimizer/join_order/plan_enumerator.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/list.hpp" + +#include +#include + +namespace duckdb { + +static bool HasJoin(LogicalOperator *op) { + while (!op->children.empty()) { + if (op->children.size() == 1) { + op = op->children[0].get(); + } + if (op->children.size() == 2) { + return true; + } + } + return false; +} + +unique_ptr JoinOrderOptimizer::Optimize(unique_ptr plan, + optional_ptr stats) { + + // make sure query graph manager has not extracted a relation graph already + LogicalOperator *op = plan.get(); + + // extract the relations that go into the hyper graph. + // We optimize the children of any non-reorderable operations we come across. + bool reorderable = query_graph_manager.Build(*op); + + // get relation_stats here since the reconstruction process will move all of the relations. + auto relation_stats = query_graph_manager.relation_manager.GetRelationStats(); + unique_ptr new_logical_plan = nullptr; + + if (reorderable) { + // query graph now has filters and relations + auto cost_model = CostModel(query_graph_manager); + + // Initialize a plan enumerator. + auto plan_enumerator = + PlanEnumerator(query_graph_manager, cost_model, query_graph_manager.GetQueryGraphEdges()); + + // Initialize the leaf/single node plans + plan_enumerator.InitLeafPlans(); + + // Ask the plan enumerator to enumerate a number of join orders + auto final_plan = plan_enumerator.SolveJoinOrder(); + // TODO: add in the check that if no plan exists, you have to add a cross product. + + // now reconstruct a logical plan from the query graph plan + new_logical_plan = query_graph_manager.Reconstruct(std::move(plan), *final_plan); + } else { + new_logical_plan = std::move(plan); + if (relation_stats.size() == 1) { + new_logical_plan->estimated_cardinality = relation_stats.at(0).cardinality; + } + } + + // only perform left right optimizations when stats is null (means we have the top level optimize call) + // Don't check reorderability because non-reorderable joins will result in 1 relation, but we can + // still switch the children. + // TODO: put this in a different optimizer maybe? + if (stats == nullptr && HasJoin(new_logical_plan.get())) { + new_logical_plan = query_graph_manager.LeftRightOptimizations(std::move(new_logical_plan)); + } + + // Propagate up a stats object from the top of the new_logical_plan if stats exist. + if (stats) { + auto cardinality = new_logical_plan->EstimateCardinality(context); + auto bindings = new_logical_plan->GetColumnBindings(); + auto new_stats = RelationStatisticsHelper::CombineStatsOfReorderableOperator(bindings, relation_stats); + new_stats.cardinality = MaxValue(cardinality, new_stats.cardinality); + RelationStatisticsHelper::CopyRelationStats(*stats, new_stats); + } + + return new_logical_plan; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/join_relation_set.cpp b/src/duckdb/src/optimizer/join_order/join_relation_set.cpp new file mode 100644 index 00000000..576d5eab --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/join_relation_set.cpp @@ -0,0 +1,142 @@ +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" + +#include + +namespace duckdb { + +using JoinRelationTreeNode = JoinRelationSetManager::JoinRelationTreeNode; + +// LCOV_EXCL_START +string JoinRelationSet::ToString() const { + string result = "["; + result += StringUtil::Join(relations, count, ", ", [](const idx_t &relation) { return to_string(relation); }); + result += "]"; + return result; +} +// LCOV_EXCL_STOP + +//! Returns true if sub is a subset of super +bool JoinRelationSet::IsSubset(JoinRelationSet &super, JoinRelationSet &sub) { + D_ASSERT(sub.count > 0); + if (sub.count > super.count) { + return false; + } + idx_t j = 0; + for (idx_t i = 0; i < super.count; i++) { + if (sub.relations[j] == super.relations[i]) { + j++; + if (j == sub.count) { + return true; + } + } + } + return false; +} + +JoinRelationSet &JoinRelationSetManager::GetJoinRelation(unsafe_unique_array relations, idx_t count) { + // now look it up in the tree + reference info(root); + for (idx_t i = 0; i < count; i++) { + auto entry = info.get().children.find(relations[i]); + if (entry == info.get().children.end()) { + // node not found, create it + auto insert_it = info.get().children.insert(make_pair(relations[i], make_uniq())); + entry = insert_it.first; + } + // move to the next node + info = *entry->second; + } + // now check if the JoinRelationSet has already been created + if (!info.get().relation) { + // if it hasn't we need to create it + info.get().relation = make_uniq(std::move(relations), count); + } + return *info.get().relation; +} + +//! Create or get a JoinRelationSet from a single node with the given index +JoinRelationSet &JoinRelationSetManager::GetJoinRelation(idx_t index) { + // create a sorted vector of the relations + auto relations = make_unsafe_uniq_array(1); + relations[0] = index; + idx_t count = 1; + return GetJoinRelation(std::move(relations), count); +} + +JoinRelationSet &JoinRelationSetManager::GetJoinRelation(const unordered_set &bindings) { + // create a sorted vector of the relations + unsafe_unique_array relations = bindings.empty() ? nullptr : make_unsafe_uniq_array(bindings.size()); + idx_t count = 0; + for (auto &entry : bindings) { + relations[count++] = entry; + } + std::sort(relations.get(), relations.get() + count); + return GetJoinRelation(std::move(relations), count); +} + +JoinRelationSet &JoinRelationSetManager::Union(JoinRelationSet &left, JoinRelationSet &right) { + auto relations = make_unsafe_uniq_array(left.count + right.count); + idx_t count = 0; + // move through the left and right relations, eliminating duplicates + idx_t i = 0, j = 0; + while (true) { + if (i == left.count) { + // exhausted left relation, add remaining of right relation + for (; j < right.count; j++) { + relations[count++] = right.relations[j]; + } + break; + } else if (j == right.count) { + // exhausted right relation, add remaining of left + for (; i < left.count; i++) { + relations[count++] = left.relations[i]; + } + break; + } else if (left.relations[i] < right.relations[j]) { + // left is smaller, progress left and add it to the set + relations[count++] = left.relations[i]; + i++; + } else { + D_ASSERT(left.relations[i] > right.relations[j]); + // right is smaller, progress right and add it to the set + relations[count++] = right.relations[j]; + j++; + } + } + return GetJoinRelation(std::move(relations), count); +} + +// JoinRelationSet *JoinRelationSetManager::Difference(JoinRelationSet *left, JoinRelationSet *right) { +// auto relations = unsafe_unique_array(new idx_t[left->count]); +// idx_t count = 0; +// // move through the left and right relations +// idx_t i = 0, j = 0; +// while (true) { +// if (i == left->count) { +// // exhausted left relation, we are done +// break; +// } else if (j == right->count) { +// // exhausted right relation, add remaining of left +// for (; i < left->count; i++) { +// relations[count++] = left->relations[i]; +// } +// break; +// } else if (left->relations[i] == right->relations[j]) { +// // equivalent, add nothing +// i++; +// j++; +// } else if (left->relations[i] < right->relations[j]) { +// // left is smaller, progress left and add it to the set +// relations[count++] = left->relations[i]; +// i++; +// } else { +// // right is smaller, progress right +// j++; +// } +// } +// return GetJoinRelation(std::move(relations), count); +// } + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp new file mode 100644 index 00000000..37c5616c --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp @@ -0,0 +1,552 @@ +#include "duckdb/optimizer/join_order/join_node.hpp" +#include "duckdb/optimizer/join_order/plan_enumerator.hpp" +#include "duckdb/optimizer/join_order/query_graph_manager.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +bool PlanEnumerator::NodeInFullPlan(JoinNode &node) { + return join_nodes_in_full_plan.find(node.set.ToString()) != join_nodes_in_full_plan.end(); +} + +void PlanEnumerator::UpdateJoinNodesInFullPlan(JoinNode &node) { + if (node.set.count == query_graph_manager.relation_manager.NumRelations()) { + join_nodes_in_full_plan.clear(); + } + if (node.set.count < query_graph_manager.relation_manager.NumRelations()) { + join_nodes_in_full_plan.insert(node.set.ToString()); + } + if (node.left) { + UpdateJoinNodesInFullPlan(*node.left); + } + if (node.right) { + UpdateJoinNodesInFullPlan(*node.right); + } +} + +static vector> AddSuperSets(const vector> ¤t, + const vector &all_neighbors) { + vector> ret; + + for (const auto &neighbor_set : current) { + auto max_val = std::max_element(neighbor_set.begin(), neighbor_set.end()); + for (const auto &neighbor : all_neighbors) { + if (*max_val >= neighbor) { + continue; + } + if (neighbor_set.count(neighbor) == 0) { + unordered_set new_set; + for (auto &n : neighbor_set) { + new_set.insert(n); + } + new_set.insert(neighbor); + ret.push_back(new_set); + } + } + } + + return ret; +} + +//! Update the exclusion set with all entries in the subgraph +static void UpdateExclusionSet(optional_ptr node, unordered_set &exclusion_set) { + for (idx_t i = 0; i < node->count; i++) { + exclusion_set.insert(node->relations[i]); + } +} + +// works by first creating all sets with cardinality 1 +// then iterates over each previously created group of subsets and will only add a neighbor if the neighbor +// is greater than all relations in the set. +static vector> GetAllNeighborSets(vector neighbors) { + vector> ret; + sort(neighbors.begin(), neighbors.end()); + vector> added; + for (auto &neighbor : neighbors) { + added.push_back(unordered_set({neighbor})); + ret.push_back(unordered_set({neighbor})); + } + do { + added = AddSuperSets(added, neighbors); + for (auto &d : added) { + ret.push_back(d); + } + } while (!added.empty()); +#if DEBUG + // drive by test to make sure we have an accurate amount of + // subsets, and that each neighbor is in a correct amount + // of those subsets. + D_ASSERT(ret.size() == pow(2, neighbors.size()) - 1); + for (auto &n : neighbors) { + idx_t count = 0; + for (auto &set : ret) { + if (set.count(n) >= 1) { + count += 1; + } + } + D_ASSERT(count == pow(2, neighbors.size() - 1)); + } +#endif + return ret; +} + +void PlanEnumerator::GenerateCrossProducts() { + // generate a set of cross products to combine the currently available plans into a full join plan + // we create edges between every relation with a high cost + for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { + auto &left = query_graph_manager.set_manager.GetJoinRelation(i); + for (idx_t j = 0; j < query_graph_manager.relation_manager.NumRelations(); j++) { + if (i != j) { + auto &right = query_graph_manager.set_manager.GetJoinRelation(j); + query_graph_manager.CreateQueryGraphCrossProduct(left, right); + } + } + } + // Now that the query graph has new edges, we need to re-initialize our query graph. + // TODO: do we need to initialize our qyery graph again? + // query_graph = query_graph_manager.GetQueryGraph(); +} + +//! Create a new JoinTree node by joining together two previous JoinTree nodes +unique_ptr PlanEnumerator::CreateJoinTree(JoinRelationSet &set, + const vector> &possible_connections, + JoinNode &left, JoinNode &right) { + // for the hash join we want the right side (build side) to have the smallest cardinality + // also just a heuristic but for now... + // FIXME: we should probably actually benchmark that as well + // FIXME: should consider different join algorithms, should we pick a join algorithm here as well? (probably) + optional_ptr best_connection = nullptr; + + // cross products are techincally still connections, but the filter expression is a null_ptr + if (!possible_connections.empty()) { + best_connection = &possible_connections.back().get(); + } + + auto cost = cost_model.ComputeCost(left, right); + auto result = make_uniq(set, best_connection, left, right, cost); + result->cardinality = cost_model.cardinality_estimator.EstimateCardinalityWithSet(set); + return result; +} + +JoinNode &PlanEnumerator::EmitPair(JoinRelationSet &left, JoinRelationSet &right, + const vector> &info) { + // get the left and right join plans + auto left_plan = plans.find(left); + auto right_plan = plans.find(right); + if (left_plan == plans.end() || right_plan == plans.end()) { + throw InternalException("No left or right plan: internal error in join order optimizer"); + } + auto &new_set = query_graph_manager.set_manager.Union(left, right); + // create the join tree based on combining the two plans + auto new_plan = CreateJoinTree(new_set, info, *left_plan->second, *right_plan->second); + // check if this plan is the optimal plan we found for this set of relations + auto entry = plans.find(new_set); + auto new_cost = new_plan->cost; + double old_cost = NumericLimits::Maximum(); + if (entry != plans.end()) { + old_cost = entry->second->cost; + } + if (entry == plans.end() || new_cost < old_cost) { + // the new plan costs less than the old plan. Update our DP tree and cost tree + auto &result = *new_plan; + + if (full_plan_found && + join_nodes_in_full_plan.find(new_plan->set.ToString()) != join_nodes_in_full_plan.end()) { + must_update_full_plan = true; + } + if (new_set.count == query_graph_manager.relation_manager.NumRelations()) { + full_plan_found = true; + // If we find a full plan, we need to keep track of which nodes are in the full plan. + // It's possible the DP algorithm updates a node in the current full plan, then moves on + // to the SolveApproximately. SolveApproximately may find a full plan with a higher cost than + // what SolveExactly found. In this case, we revert to the SolveExactly plan, but it is + // possible to get use-after-free errors if the SolveApproximately algorithm updated some (but not all) + // nodes in the SolveExactly plan + // If we know a node in the full plan is updated, we can prevent ourselves from exiting the + // DP algorithm until the last plan updated is a full plan + UpdateJoinNodesInFullPlan(result); + if (must_update_full_plan) { + must_update_full_plan = false; + } + } + + D_ASSERT(new_plan); + plans[new_set] = std::move(new_plan); + return result; + } + return *entry->second; +} + +bool PlanEnumerator::TryEmitPair(JoinRelationSet &left, JoinRelationSet &right, + const vector> &info) { + pairs++; + // If a full plan is created, it's possible a node in the plan gets updated. When this happens, make sure you keep + // emitting pairs until you emit another final plan. Another final plan is guaranteed to be produced because of + // our symmetry guarantees. + if (pairs >= 10000 && !must_update_full_plan) { + // when the amount of pairs gets too large we exit the dynamic programming and resort to a greedy algorithm + // FIXME: simple heuristic currently + // at 10K pairs stop searching exactly and switch to heuristic + return false; + } + EmitPair(left, right, info); + return true; +} + +bool PlanEnumerator::EmitCSG(JoinRelationSet &node) { + if (node.count == query_graph_manager.relation_manager.NumRelations()) { + return true; + } + // create the exclusion set as everything inside the subgraph AND anything with members BELOW it + unordered_set exclusion_set; + for (idx_t i = 0; i < node.relations[0]; i++) { + exclusion_set.insert(i); + } + UpdateExclusionSet(&node, exclusion_set); + // find the neighbors given this exclusion set + auto neighbors = query_graph.GetNeighbors(node, exclusion_set); + if (neighbors.empty()) { + return true; + } + + //! Neighbors should be reversed when iterating over them. + std::sort(neighbors.begin(), neighbors.end(), std::greater_equal()); + for (idx_t i = 0; i < neighbors.size() - 1; i++) { + D_ASSERT(neighbors[i] > neighbors[i + 1]); + } + + // Dphyp paper missiing this. + // Because we are traversing in reverse order, we need to add neighbors whose number is smaller than the current + // node to exclusion_set + // This avoids duplicated enumeration + unordered_set new_exclusion_set = exclusion_set; + for (idx_t i = 0; i < neighbors.size(); ++i) { + D_ASSERT(new_exclusion_set.find(neighbors[i]) == new_exclusion_set.end()); + new_exclusion_set.insert(neighbors[i]); + } + + for (auto neighbor : neighbors) { + // since the GetNeighbors only returns the smallest element in a list, the entry might not be connected to + // (only!) this neighbor, hence we have to do a connectedness check before we can emit it + auto &neighbor_relation = query_graph_manager.set_manager.GetJoinRelation(neighbor); + auto connections = query_graph.GetConnections(node, neighbor_relation); + if (!connections.empty()) { + if (!TryEmitPair(node, neighbor_relation, connections)) { + return false; + } + } + + if (!EnumerateCmpRecursive(node, neighbor_relation, new_exclusion_set)) { + return false; + } + + new_exclusion_set.erase(neighbor); + } + return true; +} + +bool PlanEnumerator::EnumerateCmpRecursive(JoinRelationSet &left, JoinRelationSet &right, + unordered_set &exclusion_set) { + // get the neighbors of the second relation under the exclusion set + auto neighbors = query_graph.GetNeighbors(right, exclusion_set); + if (neighbors.empty()) { + return true; + } + + auto all_subset = GetAllNeighborSets(neighbors); + vector> union_sets; + union_sets.reserve(all_subset.size()); + for (const auto &rel_set : all_subset) { + auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); + // emit the combinations of this node and its neighbors + auto &combined_set = query_graph_manager.set_manager.Union(right, neighbor); + // If combined_set.count == right.count, This means we found a neighbor that has been present before + // This means we didn't set exclusion_set correctly. + D_ASSERT(combined_set.count > right.count); + if (plans.find(combined_set) != plans.end()) { + auto connections = query_graph.GetConnections(left, combined_set); + if (!connections.empty()) { + if (!TryEmitPair(left, combined_set, connections)) { + return false; + } + } + } + union_sets.push_back(combined_set); + } + + unordered_set new_exclusion_set = exclusion_set; + for (const auto &neighbor : neighbors) { + new_exclusion_set.insert(neighbor); + } + + // recursively enumerate the sets + for (idx_t i = 0; i < union_sets.size(); i++) { + // updated the set of excluded entries with this neighbor + if (!EnumerateCmpRecursive(left, union_sets[i], new_exclusion_set)) { + return false; + } + } + return true; +} + +bool PlanEnumerator::EnumerateCSGRecursive(JoinRelationSet &node, unordered_set &exclusion_set) { + // find neighbors of S under the exclusion set + auto neighbors = query_graph.GetNeighbors(node, exclusion_set); + if (neighbors.empty()) { + return true; + } + + auto all_subset = GetAllNeighborSets(neighbors); + vector> union_sets; + union_sets.reserve(all_subset.size()); + for (const auto &rel_set : all_subset) { + auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); + // emit the combinations of this node and its neighbors + auto &new_set = query_graph_manager.set_manager.Union(node, neighbor); + D_ASSERT(new_set.count > node.count); + if (plans.find(new_set) != plans.end()) { + if (!EmitCSG(new_set)) { + return false; + } + } + union_sets.push_back(new_set); + } + + unordered_set new_exclusion_set = exclusion_set; + for (const auto &neighbor : neighbors) { + new_exclusion_set.insert(neighbor); + } + + // recursively enumerate the sets + for (idx_t i = 0; i < union_sets.size(); i++) { + // updated the set of excluded entries with this neighbor + if (!EnumerateCSGRecursive(union_sets[i], new_exclusion_set)) { + return false; + } + } + return true; +} + +bool PlanEnumerator::SolveJoinOrderExactly() { + // now we perform the actual dynamic programming to compute the final result + // we enumerate over all the possible pairs in the neighborhood + for (idx_t i = query_graph_manager.relation_manager.NumRelations(); i > 0; i--) { + // for every node in the set, we consider it as the start node once + auto &start_node = query_graph_manager.set_manager.GetJoinRelation(i - 1); + // emit the start node + if (!EmitCSG(start_node)) { + return false; + } + // initialize the set of exclusion_set as all the nodes with a number below this + unordered_set exclusion_set; + for (idx_t j = 0; j < i; j++) { + exclusion_set.insert(j); + } + // then we recursively search for neighbors that do not belong to the banned entries + if (!EnumerateCSGRecursive(start_node, exclusion_set)) { + return false; + } + } + return true; +} + +void PlanEnumerator::UpdateDPTree(JoinNode &new_plan) { + if (!NodeInFullPlan(new_plan)) { + // if the new node is not in the full plan, feel free to return + // because you won't be updating the full plan. + return; + } + auto &new_set = new_plan.set; + // now update every plan that uses this plan + unordered_set exclusion_set; + for (idx_t i = 0; i < new_set.count; i++) { + exclusion_set.insert(new_set.relations[i]); + } + auto neighbors = query_graph.GetNeighbors(new_set, exclusion_set); + auto all_neighbors = GetAllNeighborSets(neighbors); + for (const auto &neighbor : all_neighbors) { + auto &neighbor_relation = query_graph_manager.set_manager.GetJoinRelation(neighbor); + auto &combined_set = query_graph_manager.set_manager.Union(new_set, neighbor_relation); + + auto combined_set_plan = plans.find(combined_set); + if (combined_set_plan == plans.end()) { + continue; + } + + double combined_set_plan_cost = combined_set_plan->second->cost; // combined_set_plan->second->GetCost(); + auto connections = query_graph.GetConnections(new_set, neighbor_relation); + // recurse and update up the tree if the combined set produces a plan with a lower cost + // only recurse on neighbor relations that have plans. + auto right_plan = plans.find(neighbor_relation); + if (right_plan == plans.end()) { + continue; + } + auto &updated_plan = EmitPair(new_set, neighbor_relation, connections); + // <= because the child node has already been replaced. You need to + // replace the parent node as well in this case + if (updated_plan.cost < combined_set_plan_cost) { + UpdateDPTree(updated_plan); + } + } +} + +void PlanEnumerator::SolveJoinOrderApproximately() { + // at this point, we exited the dynamic programming but did not compute the final join order because it took too + // long instead, we use a greedy heuristic to obtain a join ordering now we use Greedy Operator Ordering to + // construct the result tree first we start out with all the base relations (the to-be-joined relations) + vector> join_relations; // T in the paper + for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { + join_relations.push_back(query_graph_manager.set_manager.GetJoinRelation(i)); + } + while (join_relations.size() > 1) { + // now in every step of the algorithm, we greedily pick the join between the to-be-joined relations that has the + // smallest cost. This is O(r^2) per step, and every step will reduce the total amount of relations to-be-joined + // by 1, so the total cost is O(r^3) in the amount of relations + idx_t best_left = 0, best_right = 0; + optional_ptr best_connection; + for (idx_t i = 0; i < join_relations.size(); i++) { + auto left = join_relations[i]; + for (idx_t j = i + 1; j < join_relations.size(); j++) { + auto right = join_relations[j]; + // check if we can connect these two relations + auto connection = query_graph.GetConnections(left, right); + if (!connection.empty()) { + // we can check the cost of this connection + auto &node = EmitPair(left, right, connection); + + // update the DP tree in case a plan created by the DP algorithm uses the node + // that was potentially just updated by EmitPair. You will get a use-after-free + // error if future plans rely on the old node that was just replaced. + // if node in FullPath, then updateDP tree. + UpdateDPTree(node); + + if (!best_connection || node.cost < best_connection->cost) { + // best pair found so far + best_connection = &node; + best_left = i; + best_right = j; + } + } + } + } + if (!best_connection) { + // could not find a connection, but we were not done with finding a completed plan + // we have to add a cross product; we add it between the two smallest relations + optional_ptr smallest_plans[2]; + idx_t smallest_index[2]; + D_ASSERT(join_relations.size() >= 2); + + // first just add the first two join relations. It doesn't matter the cost as the JOO + // will swap them on estimated cardinality anyway. + for (idx_t i = 0; i < 2; i++) { + auto current_plan = plans[join_relations[i]].get(); + smallest_plans[i] = current_plan; + smallest_index[i] = i; + } + + // if there are any other join relations that don't have connections + // add them if they have lower estimated cardinality. + for (idx_t i = 2; i < join_relations.size(); i++) { + // get the plan for this relation + auto current_plan = plans[join_relations[i].get()].get(); + // check if the cardinality is smaller than the smallest two found so far + for (idx_t j = 0; j < 2; j++) { + if (!smallest_plans[j] || smallest_plans[j]->cost > current_plan->cost) { + smallest_plans[j] = current_plan; + smallest_index[j] = i; + break; + } + } + } + if (!smallest_plans[0] || !smallest_plans[1]) { + throw InternalException("Internal error in join order optimizer"); + } + D_ASSERT(smallest_plans[0] && smallest_plans[1]); + D_ASSERT(smallest_index[0] != smallest_index[1]); + auto &left = smallest_plans[0]->set; + auto &right = smallest_plans[1]->set; + // create a cross product edge (i.e. edge with empty filter) between these two sets in the query graph + query_graph_manager.CreateQueryGraphCrossProduct(left, right); + // now emit the pair and continue with the algorithm + auto connections = query_graph.GetConnections(left, right); + D_ASSERT(!connections.empty()); + + best_connection = &EmitPair(left, right, connections); + best_left = smallest_index[0]; + best_right = smallest_index[1]; + + UpdateDPTree(*best_connection); + // the code below assumes best_right > best_left + if (best_left > best_right) { + std::swap(best_left, best_right); + } + } + // now update the to-be-checked pairs + // remove left and right, and add the combination + + // important to erase the biggest element first + // if we erase the smallest element first the index of the biggest element changes + D_ASSERT(best_right > best_left); + join_relations.erase(join_relations.begin() + best_right); + join_relations.erase(join_relations.begin() + best_left); + join_relations.push_back(best_connection->set); + } +} + +void PlanEnumerator::InitLeafPlans() { + // First we initialize each of the single-node plans with themselves and with their cardinalities these are the leaf + // nodes of the join tree NOTE: we can just use pointers to JoinRelationSet* here because the GetJoinRelation + // function ensures that a unique combination of relations will have a unique JoinRelationSet object. + // first initialize equivalent relations based on the filters + auto relation_stats = query_graph_manager.relation_manager.GetRelationStats(); + + cost_model.cardinality_estimator.InitEquivalentRelations(query_graph_manager.GetFilterBindings()); + cost_model.cardinality_estimator.AddRelationNamesToTdoms(relation_stats); + + // then update the total domains based on the cardinalities of each relation. + for (idx_t i = 0; i < relation_stats.size(); i++) { + auto stats = relation_stats.at(i); + auto &relation_set = query_graph_manager.set_manager.GetJoinRelation(i); + auto join_node = make_uniq(relation_set); + join_node->cost = 0; + join_node->cardinality = stats.cardinality; + plans[relation_set] = std::move(join_node); + cost_model.cardinality_estimator.InitCardinalityEstimatorProps(&relation_set, stats); + } +} + +// the plan enumeration is a straight implementation of the paper "Dynamic Programming Strikes Back" by Guido +// Moerkotte and Thomas Neumannn, see that paper for additional info/documentation bonus slides: +// https://db.in.tum.de/teaching/ws1415/queryopt/chapter3.pdf?lang=de +unique_ptr PlanEnumerator::SolveJoinOrder() { + bool force_no_cross_product = query_graph_manager.context.config.force_no_cross_product; + // first try to solve the join order exactly + if (!SolveJoinOrderExactly()) { + // otherwise, if that times out we resort to a greedy algorithm + SolveJoinOrderApproximately(); + } + + // now the optimal join path should have been found + // get it from the node + unordered_set bindings; + for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { + bindings.insert(i); + } + auto &total_relation = query_graph_manager.set_manager.GetJoinRelation(bindings); + auto final_plan = plans.find(total_relation); + if (final_plan == plans.end()) { + // could not find the final plan + // this should only happen in case the sets are actually disjunct + // in this case we need to generate cross product to connect the disjoint sets + if (force_no_cross_product) { + throw InvalidInputException( + "Query requires a cross-product, but 'force_no_cross_product' PRAGMA is enabled"); + } + GenerateCrossProducts(); + //! solve the join order again, returning the final plan + return SolveJoinOrder(); + } + return std::move(final_plan->second); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/query_graph.cpp b/src/duckdb/src/optimizer/join_order/query_graph.cpp new file mode 100644 index 00000000..beb9e152 --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/query_graph.cpp @@ -0,0 +1,140 @@ +#include "duckdb/optimizer/join_order/query_graph.hpp" + +#include "duckdb/common/printer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/assert.hpp" + +namespace duckdb { + +using QueryEdge = QueryGraphEdges::QueryEdge; + +// LCOV_EXCL_START +static string QueryEdgeToString(const QueryEdge *info, vector prefix) { + string result = ""; + string source = "["; + for (idx_t i = 0; i < prefix.size(); i++) { + source += to_string(prefix[i]) + (i < prefix.size() - 1 ? ", " : ""); + } + source += "]"; + for (auto &entry : info->neighbors) { + result += StringUtil::Format("%s -> %s\n", source.c_str(), entry->neighbor->ToString().c_str()); + } + for (auto &entry : info->children) { + vector new_prefix = prefix; + new_prefix.push_back(entry.first); + result += QueryEdgeToString(entry.second.get(), new_prefix); + } + return result; +} + +string QueryGraphEdges::ToString() const { + return QueryEdgeToString(&root, {}); +} + +void QueryGraphEdges::Print() { + Printer::Print(ToString()); +} +// LCOV_EXCL_STOP + +optional_ptr QueryGraphEdges::GetQueryEdge(JoinRelationSet &left) { + D_ASSERT(left.count > 0); + // find the EdgeInfo corresponding to the left set + optional_ptr info(&root); + for (idx_t i = 0; i < left.count; i++) { + auto entry = info.get()->children.find(left.relations[i]); + if (entry == info.get()->children.end()) { + // node not found, create it + auto insert_it = info.get()->children.insert(make_pair(left.relations[i], make_uniq())); + entry = insert_it.first; + } + // move to the next node + info = entry->second; + } + return info; +} + +void QueryGraphEdges::CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr filter_info) { + D_ASSERT(left.count > 0 && right.count > 0); + // find the EdgeInfo corresponding to the left set + auto info = GetQueryEdge(left); + // now insert the edge to the right relation, if it does not exist + for (idx_t i = 0; i < info->neighbors.size(); i++) { + if (info->neighbors[i]->neighbor == &right) { + if (filter_info) { + // neighbor already exists just add the filter, if we have any + info->neighbors[i]->filters.push_back(filter_info); + } + return; + } + } + // neighbor does not exist, create it + auto n = make_uniq(&right); + // if the edge represents a cross product, filter_info is null. The easiest way then to determine + // if an edge is for a cross product is if the filters are empty + if (info && filter_info) { + n->filters.push_back(filter_info); + } + info->neighbors.push_back(std::move(n)); +} + +void QueryGraphEdges::EnumerateNeighborsDFS(JoinRelationSet &node, reference info, idx_t index, + const std::function &callback) const { + + for (auto &neighbor : info.get().neighbors) { + if (callback(*neighbor)) { + return; + } + } + + for (idx_t node_index = index; node_index < node.count; ++node_index) { + auto iter = info.get().children.find(node.relations[node_index]); + if (iter != info.get().children.end()) { + reference new_info = *iter->second; + EnumerateNeighborsDFS(node, new_info, node_index + 1, callback); + } + } +} + +void QueryGraphEdges::EnumerateNeighbors(JoinRelationSet &node, + const std::function &callback) const { + for (idx_t j = 0; j < node.count; j++) { + auto iter = root.children.find(node.relations[j]); + if (iter != root.children.end()) { + reference new_info = *iter->second; + EnumerateNeighborsDFS(node, new_info, j + 1, callback); + } + } +} + +//! Returns true if a JoinRelationSet is banned by the list of exclusion_set, false otherwise +static bool JoinRelationSetIsExcluded(optional_ptr node, unordered_set &exclusion_set) { + return exclusion_set.find(node->relations[0]) != exclusion_set.end(); +} + +const vector QueryGraphEdges::GetNeighbors(JoinRelationSet &node, unordered_set &exclusion_set) const { + unordered_set result; + EnumerateNeighbors(node, [&](NeighborInfo &info) -> bool { + if (!JoinRelationSetIsExcluded(info.neighbor, exclusion_set)) { + // add the smallest node of the neighbor to the set + result.insert(info.neighbor->relations[0]); + } + return false; + }); + vector neighbors; + neighbors.insert(neighbors.end(), result.begin(), result.end()); + return neighbors; +} + +const vector> QueryGraphEdges::GetConnections(JoinRelationSet &node, + JoinRelationSet &other) const { + vector> connections; + EnumerateNeighbors(node, [&](NeighborInfo &info) -> bool { + if (JoinRelationSet::IsSubset(other, *info.neighbor)) { + connections.push_back(info); + } + return false; + }); + return connections; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp new file mode 100644 index 00000000..1d5e815f --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp @@ -0,0 +1,409 @@ +#include "duckdb/optimizer/join_order/query_graph_manager.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/assert.hpp" + +namespace duckdb { + +//! Returns true if A and B are disjoint, false otherwise +template +static bool Disjoint(const unordered_set &a, const unordered_set &b) { + return std::all_of(a.begin(), a.end(), [&b](typename std::unordered_set::const_reference entry) { + return b.find(entry) == b.end(); + }); +} + +bool QueryGraphManager::Build(LogicalOperator &op) { + vector> filter_operators; + // have the relation manager extract the join relations and create a reference list of all the + // filter operators. + auto can_reorder = relation_manager.ExtractJoinRelations(op, filter_operators); + auto num_relations = relation_manager.NumRelations(); + if (num_relations <= 1 || !can_reorder) { + // nothing to optimize/reorder + return false; + } + // extract the edges of the hypergraph, creating a list of filters and their associated bindings. + filters_and_bindings = relation_manager.ExtractEdges(op, filter_operators, set_manager); + // Create the query_graph hyper edges + CreateHyperGraphEdges(); + return true; +} + +void QueryGraphManager::GetColumnBinding(Expression &expression, ColumnBinding &binding) { + if (expression.type == ExpressionType::BOUND_COLUMN_REF) { + // Here you have a filter on a single column in a table. Return a binding for the column + // being filtered on so the filter estimator knows what HLL count to pull + auto &colref = expression.Cast(); + D_ASSERT(colref.depth == 0); + D_ASSERT(colref.binding.table_index != DConstants::INVALID_INDEX); + // map the base table index to the relation index used by the JoinOrderOptimizer + D_ASSERT(relation_manager.relation_mapping.find(colref.binding.table_index) != + relation_manager.relation_mapping.end()); + binding = + ColumnBinding(relation_manager.relation_mapping[colref.binding.table_index], colref.binding.column_index); + } + // TODO: handle inequality filters with functions. + ExpressionIterator::EnumerateChildren(expression, [&](Expression &expr) { GetColumnBinding(expr, binding); }); +} + +const vector> &QueryGraphManager::GetFilterBindings() const { + return filters_and_bindings; +} + +static unique_ptr PushFilter(unique_ptr node, unique_ptr expr) { + // push an expression into a filter + // first check if we have any filter to push it into + if (node->type != LogicalOperatorType::LOGICAL_FILTER) { + // we don't, we need to create one + auto filter = make_uniq(); + filter->children.push_back(std::move(node)); + node = std::move(filter); + } + // push the filter into the LogicalFilter + D_ASSERT(node->type == LogicalOperatorType::LOGICAL_FILTER); + auto &filter = node->Cast(); + filter.expressions.push_back(std::move(expr)); + return node; +} + +void QueryGraphManager::CreateHyperGraphEdges() { + // create potential edges from the comparisons + for (auto &filter_info : filters_and_bindings) { + auto &filter = filter_info->filter; + // now check if it can be used as a join predicate + if (filter->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON) { + auto &comparison = filter->Cast(); + // extract the bindings that are required for the left and right side of the comparison + unordered_set left_bindings, right_bindings; + relation_manager.ExtractBindings(*comparison.left, left_bindings); + relation_manager.ExtractBindings(*comparison.right, right_bindings); + GetColumnBinding(*comparison.left, filter_info->left_binding); + GetColumnBinding(*comparison.right, filter_info->right_binding); + if (!left_bindings.empty() && !right_bindings.empty()) { + // both the left and the right side have bindings + // first create the relation sets, if they do not exist + filter_info->left_set = &set_manager.GetJoinRelation(left_bindings); + filter_info->right_set = &set_manager.GetJoinRelation(right_bindings); + // we can only create a meaningful edge if the sets are not exactly the same + if (filter_info->left_set != filter_info->right_set) { + // check if the sets are disjoint + if (Disjoint(left_bindings, right_bindings)) { + // they are disjoint, we only need to create one set of edges in the join graph + query_graph.CreateEdge(*filter_info->left_set, *filter_info->right_set, filter_info); + query_graph.CreateEdge(*filter_info->right_set, *filter_info->left_set, filter_info); + } else { + continue; + } + continue; + } + } + } + } +} + +static unique_ptr ExtractJoinRelation(unique_ptr &rel) { + auto &children = rel->parent->children; + for (idx_t i = 0; i < children.size(); i++) { + if (children[i].get() == &rel->op) { + // found it! take ownership o/**/f it from the parent + auto result = std::move(children[i]); + children.erase(children.begin() + i); + return result; + } + } + throw Exception("Could not find relation in parent node (?)"); +} + +unique_ptr QueryGraphManager::Reconstruct(unique_ptr plan, JoinNode &node) { + return RewritePlan(std::move(plan), node); +} + +GenerateJoinRelation QueryGraphManager::GenerateJoins(vector> &extracted_relations, + JoinNode &node) { + optional_ptr left_node; + optional_ptr right_node; + optional_ptr result_relation; + unique_ptr result_operator; + if (node.left && node.right && node.info) { + // generate the left and right children + auto left = GenerateJoins(extracted_relations, *node.left); + auto right = GenerateJoins(extracted_relations, *node.right); + + if (node.info->filters.empty()) { + // no filters, create a cross product + result_operator = LogicalCrossProduct::Create(std::move(left.op), std::move(right.op)); + } else { + // we have filters, create a join node + auto join = make_uniq(JoinType::INNER); + // Here we optimize build side probe side. Our build side is the right side + // So the right plans should have lower cardinalities. + join->children.push_back(std::move(left.op)); + join->children.push_back(std::move(right.op)); + + // set the join conditions from the join node + for (auto &filter_ref : node.info->filters) { + auto f = filter_ref.get(); + // extract the filter from the operator it originally belonged to + D_ASSERT(filters_and_bindings[f->filter_index]->filter); + auto &filter_and_binding = filters_and_bindings.at(f->filter_index); + auto condition = std::move(filter_and_binding->filter); + // now create the actual join condition + D_ASSERT((JoinRelationSet::IsSubset(*left.set, *f->left_set) && + JoinRelationSet::IsSubset(*right.set, *f->right_set)) || + (JoinRelationSet::IsSubset(*left.set, *f->right_set) && + JoinRelationSet::IsSubset(*right.set, *f->left_set))); + JoinCondition cond; + D_ASSERT(condition->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON); + auto &comparison = condition->Cast(); + + // we need to figure out which side is which by looking at the relations available to us + bool invert = !JoinRelationSet::IsSubset(*left.set, *f->left_set); + cond.left = !invert ? std::move(comparison.left) : std::move(comparison.right); + cond.right = !invert ? std::move(comparison.right) : std::move(comparison.left); + cond.comparison = condition->type; + + if (invert) { + // reverse comparison expression if we reverse the order of the children + cond.comparison = FlipComparisonExpression(cond.comparison); + } + join->conditions.push_back(std::move(cond)); + } + D_ASSERT(!join->conditions.empty()); + result_operator = std::move(join); + } + left_node = left.set; + right_node = right.set; + result_relation = &set_manager.Union(*left.set, *right.set); + } else { + // base node, get the entry from the list of extracted relations + D_ASSERT(node.set.count == 1); + D_ASSERT(extracted_relations[node.set.relations[0]]); + result_relation = &node.set; + result_operator = std::move(extracted_relations[node.set.relations[0]]); + } + // TODO: this is where estimated properties start coming into play. + // when creating the result operator, we should ask the cost model and cardinality estimator what + // the cost and cardinality are + // result_operator->estimated_props = node.estimated_props->Copy(); + result_operator->estimated_cardinality = node.cardinality; + result_operator->has_estimated_cardinality = true; + if (result_operator->type == LogicalOperatorType::LOGICAL_FILTER && + result_operator->children[0]->type == LogicalOperatorType::LOGICAL_GET) { + // FILTER on top of GET, add estimated properties to both + // auto &filter_props = *result_operator->estimated_props; + auto &child_operator = *result_operator->children[0]; + child_operator.estimated_cardinality = node.cardinality; + child_operator.has_estimated_cardinality = true; + } + // check if we should do a pushdown on this node + // basically, any remaining filter that is a subset of the current relation will no longer be used in joins + // hence we should push it here + for (auto &filter_info : filters_and_bindings) { + // check if the filter has already been extracted + auto &info = *filter_info; + if (filters_and_bindings[info.filter_index]->filter) { + // now check if the filter is a subset of the current relation + // note that infos with an empty relation set are a special case and we do not push them down + if (info.set.count > 0 && JoinRelationSet::IsSubset(*result_relation, info.set)) { + auto &filter_and_binding = filters_and_bindings[info.filter_index]; + auto filter = std::move(filter_and_binding->filter); + // if it is, we can push the filter + // we can push it either into a join or as a filter + // check if we are in a join or in a base table + if (!left_node || !info.left_set) { + // base table or non-comparison expression, push it as a filter + result_operator = PushFilter(std::move(result_operator), std::move(filter)); + continue; + } + // the node below us is a join or cross product and the expression is a comparison + // check if the nodes can be split up into left/right + bool found_subset = false; + bool invert = false; + if (JoinRelationSet::IsSubset(*left_node, *info.left_set) && + JoinRelationSet::IsSubset(*right_node, *info.right_set)) { + found_subset = true; + } else if (JoinRelationSet::IsSubset(*right_node, *info.left_set) && + JoinRelationSet::IsSubset(*left_node, *info.right_set)) { + invert = true; + found_subset = true; + } + if (!found_subset) { + // could not be split up into left/right + result_operator = PushFilter(std::move(result_operator), std::move(filter)); + continue; + } + // create the join condition + JoinCondition cond; + D_ASSERT(filter->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON); + auto &comparison = filter->Cast(); + // we need to figure out which side is which by looking at the relations available to us + cond.left = !invert ? std::move(comparison.left) : std::move(comparison.right); + cond.right = !invert ? std::move(comparison.right) : std::move(comparison.left); + cond.comparison = comparison.type; + if (invert) { + // reverse comparison expression if we reverse the order of the children + cond.comparison = FlipComparisonExpression(comparison.type); + } + // now find the join to push it into + auto node = result_operator.get(); + if (node->type == LogicalOperatorType::LOGICAL_FILTER) { + node = node->children[0].get(); + } + if (node->type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT) { + // turn into comparison join + auto comp_join = make_uniq(JoinType::INNER); + comp_join->children.push_back(std::move(node->children[0])); + comp_join->children.push_back(std::move(node->children[1])); + comp_join->conditions.push_back(std::move(cond)); + if (node == result_operator.get()) { + result_operator = std::move(comp_join); + } else { + D_ASSERT(result_operator->type == LogicalOperatorType::LOGICAL_FILTER); + result_operator->children[0] = std::move(comp_join); + } + } else { + D_ASSERT(node->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + node->type == LogicalOperatorType::LOGICAL_ASOF_JOIN); + auto &comp_join = node->Cast(); + comp_join.conditions.push_back(std::move(cond)); + } + } + } + } + auto result = GenerateJoinRelation(result_relation, std::move(result_operator)); + return result; +} + +const QueryGraphEdges &QueryGraphManager::GetQueryGraphEdges() const { + return query_graph; +} + +void QueryGraphManager::CreateQueryGraphCrossProduct(JoinRelationSet &left, JoinRelationSet &right) { + query_graph.CreateEdge(left, right, nullptr); + query_graph.CreateEdge(right, left, nullptr); +} + +unique_ptr QueryGraphManager::RewritePlan(unique_ptr plan, JoinNode &node) { + // now we have to rewrite the plan + bool root_is_join = plan->children.size() > 1; + + // first we will extract all relations from the main plan + vector> extracted_relations; + extracted_relations.reserve(relation_manager.NumRelations()); + for (auto &relation : relation_manager.GetRelations()) { + extracted_relations.push_back(ExtractJoinRelation(relation)); + } + + // now we generate the actual joins + auto join_tree = GenerateJoins(extracted_relations, node); + // perform the final pushdown of remaining filters + for (auto &filter : filters_and_bindings) { + // check if the filter has already been extracted + if (filter->filter) { + // if not we need to push it + join_tree.op = PushFilter(std::move(join_tree.op), std::move(filter->filter)); + } + } + + // find the first join in the relation to know where to place this node + if (root_is_join) { + // first node is the join, return it immediately + return std::move(join_tree.op); + } + D_ASSERT(plan->children.size() == 1); + // have to move up through the relations + auto op = plan.get(); + auto parent = plan.get(); + while (op->type != LogicalOperatorType::LOGICAL_CROSS_PRODUCT && + op->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN && + op->type != LogicalOperatorType::LOGICAL_ASOF_JOIN) { + D_ASSERT(op->children.size() == 1); + parent = op; + op = op->children[0].get(); + } + // have to replace at this node + parent->children[0] = std::move(join_tree.op); + return plan; +} + +bool QueryGraphManager::LeftCardLessThanRight(LogicalOperator &op) { + D_ASSERT(op.children.size() == 2); + if (op.children[0]->has_estimated_cardinality && op.children[1]->has_estimated_cardinality) { + return op.children[0]->estimated_cardinality < op.children[1]->estimated_cardinality; + } + return op.children[0]->EstimateCardinality(context) < op.children[1]->EstimateCardinality(context); +} + +unique_ptr QueryGraphManager::LeftRightOptimizations(unique_ptr input_op) { + auto op = input_op.get(); + // pass through single child operators + while (!op->children.empty()) { + if (op->children.size() == 2) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + auto &join = op->Cast(); + if (join.join_type == JoinType::INNER) { + if (LeftCardLessThanRight(*op)) { + std::swap(op->children[0], op->children[1]); + for (auto &cond : join.conditions) { + std::swap(cond.left, cond.right); + cond.comparison = FlipComparisonExpression(cond.comparison); + } + } + } else if (join.join_type == JoinType::LEFT && join.right_projection_map.empty()) { + auto lhs_cardinality = join.children[0]->EstimateCardinality(context); + auto rhs_cardinality = join.children[1]->EstimateCardinality(context); + if (rhs_cardinality > lhs_cardinality * 2) { + join.join_type = JoinType::RIGHT; + std::swap(join.children[0], join.children[1]); + for (auto &cond : join.conditions) { + std::swap(cond.left, cond.right); + cond.comparison = FlipComparisonExpression(cond.comparison); + } + } + } + break; + } + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { + if (LeftCardLessThanRight(*op)) { + std::swap(op->children[0], op->children[1]); + } + break; + } + case LogicalOperatorType::LOGICAL_ANY_JOIN: { + auto &join = op->Cast(); + if (join.join_type == JoinType::LEFT && join.right_projection_map.empty()) { + auto lhs_cardinality = join.children[0]->EstimateCardinality(context); + auto rhs_cardinality = join.children[1]->EstimateCardinality(context); + if (rhs_cardinality > lhs_cardinality * 2) { + join.join_type = JoinType::RIGHT; + std::swap(join.children[0], join.children[1]); + } + } else if (join.join_type == JoinType::INNER && LeftCardLessThanRight(*op)) { + std::swap(join.children[0], join.children[1]); + } + break; + } + default: + break; + } + op->children[0] = LeftRightOptimizations(std::move(op->children[0])); + op->children[1] = LeftRightOptimizations(std::move(op->children[1])); + // break from while loop + break; + } + if (op->children.size() == 1) { + op = op->children[0].get(); + } + } + return input_op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/relation_manager.cpp b/src/duckdb/src/optimizer/join_order/relation_manager.cpp new file mode 100644 index 00000000..9294cd07 --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/relation_manager.cpp @@ -0,0 +1,356 @@ +#include "duckdb/optimizer/join_order/relation_manager.hpp" +#include "duckdb/optimizer/join_order/join_order_optimizer.hpp" +#include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/operator/list.hpp" + +#include + +namespace duckdb { + +const vector RelationManager::GetRelationStats() { + vector ret; + for (idx_t i = 0; i < relations.size(); i++) { + ret.push_back(relations[i]->stats); + } + return ret; +} + +vector> RelationManager::GetRelations() { + return std::move(relations); +} + +idx_t RelationManager::NumRelations() { + return relations.size(); +} + +void RelationManager::AddAggregateRelation(LogicalOperator &op, optional_ptr parent, + const RelationStats &stats) { + auto relation = make_uniq(op, parent, stats); + auto relation_id = relations.size(); + + auto table_indexes = op.GetTableIndex(); + for (auto &index : table_indexes) { + D_ASSERT(relation_mapping.find(index) == relation_mapping.end()); + relation_mapping[index] = relation_id; + } + relations.push_back(std::move(relation)); +} + +void RelationManager::AddRelation(LogicalOperator &op, optional_ptr parent, + const RelationStats &stats) { + + // if parent is null, then this is a root relation + // if parent is not null, it should have multiple children + D_ASSERT(!parent || parent->children.size() >= 2); + auto relation = make_uniq(op, parent, stats); + auto relation_id = relations.size(); + + auto table_indexes = op.GetTableIndex(); + if (table_indexes.empty()) { + // relation represents a non-reorderable relation, most likely a join relation + // Get the tables referenced in the non-reorderable relation and add them to the relation mapping + // This should all table references, even if there are nested non-reorderable joins. + unordered_set table_references; + LogicalJoin::GetTableReferences(op, table_references); + D_ASSERT(table_references.size() > 0); + for (auto &reference : table_references) { + D_ASSERT(relation_mapping.find(reference) == relation_mapping.end()); + relation_mapping[reference] = relation_id; + } + } else { + // Relations should never return more than 1 table index + D_ASSERT(table_indexes.size() == 1); + idx_t table_index = table_indexes.at(0); + D_ASSERT(relation_mapping.find(table_index) == relation_mapping.end()); + relation_mapping[table_index] = relation_id; + } + relations.push_back(std::move(relation)); +} + +static bool OperatorNeedsRelation(LogicalOperatorType op_type) { + switch (op_type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + case LogicalOperatorType::LOGICAL_GET: + case LogicalOperatorType::LOGICAL_DELIM_GET: + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_WINDOW: + return true; + default: + return false; + } +} + +static bool OperatorIsNonReorderable(LogicalOperatorType op_type) { + switch (op_type) { + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + return true; + default: + return false; + } +} + +static bool HasNonReorderableChild(LogicalOperator &op) { + LogicalOperator *tmp = &op; + while (tmp->children.size() == 1) { + if (OperatorNeedsRelation(tmp->type) || OperatorIsNonReorderable(tmp->type)) { + return true; + } + tmp = tmp->children[0].get(); + if (tmp->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + auto &join = tmp->Cast(); + if (join.join_type != JoinType::INNER) { + return true; + } + } + } + return tmp->children.empty(); +} + +bool RelationManager::ExtractJoinRelations(LogicalOperator &input_op, + vector> &filter_operators, + optional_ptr parent) { + LogicalOperator *op = &input_op; + vector> datasource_filters; + // pass through single child operators + while (op->children.size() == 1 && !OperatorNeedsRelation(op->type)) { + if (op->type == LogicalOperatorType::LOGICAL_FILTER) { + if (HasNonReorderableChild(*op)) { + datasource_filters.push_back(*op); + } + filter_operators.push_back(*op); + } + if (op->type == LogicalOperatorType::LOGICAL_SHOW) { + return false; + } + op = op->children[0].get(); + } + bool non_reorderable_operation = false; + if (OperatorIsNonReorderable(op->type)) { + // set operation, optimize separately in children + non_reorderable_operation = true; + } + + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + auto &join = op->Cast(); + if (join.join_type == JoinType::INNER) { + // extract join conditions from inner join + filter_operators.push_back(*op); + } else { + non_reorderable_operation = true; + } + } + if (non_reorderable_operation) { + // we encountered a non-reordable operation (setop or non-inner join) + // we do not reorder non-inner joins yet, however we do want to expand the potential join graph around them + // non-inner joins are also tricky because we can't freely make conditions through them + // e.g. suppose we have (left LEFT OUTER JOIN right WHERE right IS NOT NULL), the join can generate + // new NULL values in the right side, so pushing this condition through the join leads to incorrect results + // for this reason, we just start a new JoinOptimizer pass in each of the children of the join + // stats.cardinality will be initiated to highest cardinality of the children. + vector children_stats; + for (auto &child : op->children) { + auto stats = RelationStats(); + JoinOrderOptimizer optimizer(context); + child = optimizer.Optimize(std::move(child), &stats); + children_stats.push_back(stats); + } + + auto combined_stats = RelationStatisticsHelper::CombineStatsOfNonReorderableOperator(*op, children_stats); + if (!datasource_filters.empty()) { + combined_stats.cardinality = + (idx_t)MaxValue(combined_stats.cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY, (double)1); + } + AddRelation(input_op, parent, combined_stats); + return true; + } + + switch (op->type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + // optimize children + RelationStats child_stats; + JoinOrderOptimizer optimizer(context); + op->children[0] = optimizer.Optimize(std::move(op->children[0]), &child_stats); + auto &aggr = op->Cast(); + auto operator_stats = RelationStatisticsHelper::ExtractAggregationStats(aggr, child_stats); + AddAggregateRelation(input_op, parent, operator_stats); + return true; + } + case LogicalOperatorType::LOGICAL_WINDOW: { + // optimize children + RelationStats child_stats; + JoinOrderOptimizer optimizer(context); + op->children[0] = optimizer.Optimize(std::move(op->children[0]), &child_stats); + auto &window = op->Cast(); + auto operator_stats = RelationStatisticsHelper::ExtractWindowStats(window, child_stats); + AddAggregateRelation(input_op, parent, operator_stats); + return true; + } + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { + // Adding relations to the current join order optimizer + bool can_reorder_left = ExtractJoinRelations(*op->children[0], filter_operators, op); + bool can_reorder_right = ExtractJoinRelations(*op->children[1], filter_operators, op); + return can_reorder_left && can_reorder_right; + } + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: { + auto &dummy_scan = op->Cast(); + auto stats = RelationStatisticsHelper::ExtractDummyScanStats(dummy_scan, context); + AddRelation(input_op, parent, stats); + return true; + } + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { + // base table scan, add to set of relations. + // create empty stats for dummy scan or logical expression get + auto &expression_get = op->Cast(); + auto stats = RelationStatisticsHelper::ExtractExpressionGetStats(expression_get, context); + AddRelation(input_op, parent, stats); + return true; + } + case LogicalOperatorType::LOGICAL_GET: { + // TODO: Get stats from a logical GET + auto &get = op->Cast(); + auto stats = RelationStatisticsHelper::ExtractGetStats(get, context); + // if there is another logical filter that could not be pushed down into the + // table scan, apply another selectivity. + if (!datasource_filters.empty()) { + stats.cardinality = + (idx_t)MaxValue(stats.cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY, (double)1); + } + AddRelation(input_op, parent, stats); + return true; + } + case LogicalOperatorType::LOGICAL_DELIM_GET: { + auto &delim_get = op->Cast(); + auto stats = RelationStatisticsHelper::ExtractDelimGetStats(delim_get, context); + AddRelation(input_op, parent, stats); + return true; + } + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto child_stats = RelationStats(); + // optimize the child and copy the stats + JoinOrderOptimizer optimizer(context); + op->children[0] = optimizer.Optimize(std::move(op->children[0]), &child_stats); + auto &proj = op->Cast(); + // Projection can create columns so we need to add them here + auto proj_stats = RelationStatisticsHelper::ExtractProjectionStats(proj, child_stats); + AddRelation(input_op, parent, proj_stats); + return true; + } + default: + return false; + } +} + +bool RelationManager::ExtractBindings(Expression &expression, unordered_set &bindings) { + if (expression.type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expression.Cast(); + D_ASSERT(colref.depth == 0); + D_ASSERT(colref.binding.table_index != DConstants::INVALID_INDEX); + // map the base table index to the relation index used by the JoinOrderOptimizer + if (expression.alias == "SUBQUERY" && + relation_mapping.find(colref.binding.table_index) == relation_mapping.end()) { + // most likely a BoundSubqueryExpression that was created from an uncorrelated subquery + // Here we return true and don't fill the bindings, the expression can be reordered. + // A filter will be created using this expression, and pushed back on top of the parent + // operator during plan reconstruction + return true; + } + D_ASSERT(relation_mapping.find(colref.binding.table_index) != relation_mapping.end()); + bindings.insert(relation_mapping[colref.binding.table_index]); + } + if (expression.type == ExpressionType::BOUND_REF) { + // bound expression + bindings.clear(); + return false; + } + D_ASSERT(expression.type != ExpressionType::SUBQUERY); + bool can_reorder = true; + ExpressionIterator::EnumerateChildren(expression, [&](Expression &expr) { + if (!ExtractBindings(expr, bindings)) { + can_reorder = false; + return; + } + }); + return can_reorder; +} + +vector> RelationManager::ExtractEdges(LogicalOperator &op, + vector> &filter_operators, + JoinRelationSetManager &set_manager) { + // now that we know we are going to perform join ordering we actually extract the filters, eliminating duplicate + // filters in the process + vector> filters_and_bindings; + expression_set_t filter_set; + for (auto &filter_op : filter_operators) { + auto &f_op = filter_op.get(); + if (f_op.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + f_op.type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { + auto &join = f_op.Cast(); + D_ASSERT(join.join_type == JoinType::INNER); + D_ASSERT(join.expressions.empty()); + for (auto &cond : join.conditions) { + auto comparison = + make_uniq(cond.comparison, std::move(cond.left), std::move(cond.right)); + if (filter_set.find(*comparison) == filter_set.end()) { + filter_set.insert(*comparison); + unordered_set bindings; + ExtractBindings(*comparison, bindings); + auto &set = set_manager.GetJoinRelation(bindings); + auto filter_info = make_uniq(std::move(comparison), set, filters_and_bindings.size()); + filters_and_bindings.push_back(std::move(filter_info)); + } + } + join.conditions.clear(); + } else { + for (auto &expression : f_op.expressions) { + if (filter_set.find(*expression) == filter_set.end()) { + filter_set.insert(*expression); + unordered_set bindings; + ExtractBindings(*expression, bindings); + auto &set = set_manager.GetJoinRelation(bindings); + auto filter_info = make_uniq(std::move(expression), set, filters_and_bindings.size()); + filters_and_bindings.push_back(std::move(filter_info)); + } + } + f_op.expressions.clear(); + } + } + + return filters_and_bindings; +} + +// LCOV_EXCL_START + +void RelationManager::PrintRelationStats() { +#ifdef DEBUG + string to_print; + for (idx_t i = 0; i < relations.size(); i++) { + auto &relation = relations.at(i); + auto &stats = relation->stats; + D_ASSERT(stats.column_names.size() == stats.column_distinct_count.size()); + for (idx_t i = 0; i < stats.column_names.size(); i++) { + to_print = stats.column_names.at(i) + " has estimated distinct count " + + to_string(stats.column_distinct_count.at(i).distinct_count); + Printer::Print(to_print); + } + to_print = stats.table_name + " has estimated cardinality " + to_string(stats.cardinality); + to_print += " and relation id " + to_string(i) + "\n"; + Printer::Print(to_print); + } +#endif +} + +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp new file mode 100644 index 00000000..0dce879b --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp @@ -0,0 +1,351 @@ +#include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" + +namespace duckdb { + +static ExpressionBinding GetChildColumnBinding(Expression &expr) { + auto ret = ExpressionBinding(); + switch (expr.expression_class) { + case ExpressionClass::BOUND_FUNCTION: { + // TODO: Other expression classes that can have 0 children? + auto &func = expr.Cast(); + // no children some sort of gen_random_uuid() or equivalent. + if (func.children.empty()) { + ret.found_expression = true; + ret.expression_is_constant = true; + return ret; + } + break; + } + case ExpressionClass::BOUND_COLUMN_REF: { + ret.found_expression = true; + auto &new_col_ref = expr.Cast(); + ret.child_binding = ColumnBinding(new_col_ref.binding.table_index, new_col_ref.binding.column_index); + return ret; + } + case ExpressionClass::BOUND_LAMBDA_REF: + case ExpressionClass::BOUND_CONSTANT: + case ExpressionClass::BOUND_DEFAULT: + case ExpressionClass::BOUND_PARAMETER: + case ExpressionClass::BOUND_REF: + ret.found_expression = true; + ret.expression_is_constant = true; + return ret; + default: + break; + } + ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &child) { + auto recursive_result = GetChildColumnBinding(*child); + if (recursive_result.found_expression) { + ret = recursive_result; + } + }); + // we didn't find a Bound Column Ref + return ret; +} + +RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientContext &context) { + auto return_stats = RelationStats(); + + auto base_table_cardinality = get.EstimateCardinality(context); + auto cardinality_after_filters = base_table_cardinality; + unique_ptr column_statistics; + + auto table_thing = get.GetTable(); + auto name = string("some table"); + if (table_thing) { + name = table_thing->name; + return_stats.table_name = name; + } + + // if we can get the catalog table, then our column statistics will be accurate + // parquet readers etc. will still return statistics, but they initialize distinct column + // counts to 0. + // TODO: fix this, some file formats can encode distinct counts, we don't want to rely on + // getting a catalog table to know that we can use statistics. + bool have_catalog_table_statistics = false; + if (get.GetTable()) { + have_catalog_table_statistics = true; + } + + // first push back basic distinct counts for each column (if we have them). + for (idx_t i = 0; i < get.column_ids.size(); i++) { + bool have_distinct_count_stats = false; + if (get.function.statistics) { + column_statistics = get.function.statistics(context, get.bind_data.get(), get.column_ids[i]); + if (column_statistics && have_catalog_table_statistics) { + auto column_distinct_count = DistinctCount({column_statistics->GetDistinctCount(), true}); + return_stats.column_distinct_count.push_back(column_distinct_count); + return_stats.column_names.push_back(name + "." + get.names.at(get.column_ids.at(i))); + have_distinct_count_stats = true; + } + } + if (!have_distinct_count_stats) { + // currently treating the cardinality as the distinct count. + // the cardinality estimator will update these distinct counts based + // on the extra columns that are joined on. + auto column_distinct_count = DistinctCount({cardinality_after_filters, false}); + return_stats.column_distinct_count.push_back(column_distinct_count); + auto column_name = string("column"); + if (get.column_ids.at(i) < get.names.size()) { + column_name = get.names.at(get.column_ids.at(i)); + } + return_stats.column_names.push_back(get.GetName() + "." + column_name); + } + } + + if (!get.table_filters.filters.empty()) { + column_statistics = nullptr; + for (auto &it : get.table_filters.filters) { + if (get.bind_data && get.function.name.compare("seq_scan") == 0) { + auto &table_scan_bind_data = get.bind_data->Cast(); + column_statistics = get.function.statistics(context, &table_scan_bind_data, it.first); + } + + if (column_statistics && it.second->filter_type == TableFilterType::CONJUNCTION_AND) { + auto &filter = it.second->Cast(); + idx_t cardinality_with_and_filter = RelationStatisticsHelper::InspectConjunctionAND( + base_table_cardinality, it.first, filter, *column_statistics); + cardinality_after_filters = MinValue(cardinality_after_filters, cardinality_with_and_filter); + } + } + // if the above code didn't find an equality filter (i.e country_code = "[us]") + // and there are other table filters (i.e cost > 50), use default selectivity. + bool has_equality_filter = (cardinality_after_filters != base_table_cardinality); + if (!has_equality_filter && !get.table_filters.filters.empty()) { + cardinality_after_filters = + MaxValue(base_table_cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY, 1); + } + if (base_table_cardinality == 0) { + cardinality_after_filters = 0; + } + } + return_stats.cardinality = cardinality_after_filters; + // update the estimated cardinality of the get as well. + // This is not updated during plan reconstruction. + get.estimated_cardinality = cardinality_after_filters; + get.has_estimated_cardinality = true; + D_ASSERT(base_table_cardinality >= cardinality_after_filters); + return_stats.stats_initialized = true; + return return_stats; +} + +RelationStats RelationStatisticsHelper::ExtractDelimGetStats(LogicalDelimGet &delim_get, ClientContext &context) { + RelationStats stats; + stats.table_name = delim_get.GetName(); + idx_t card = delim_get.EstimateCardinality(context); + stats.cardinality = card; + stats.stats_initialized = true; + for (auto &binding : delim_get.GetColumnBindings()) { + stats.column_distinct_count.push_back(DistinctCount({1, false})); + stats.column_names.push_back("column" + to_string(binding.column_index)); + } + return stats; +} + +RelationStats RelationStatisticsHelper::ExtractProjectionStats(LogicalProjection &proj, RelationStats &child_stats) { + auto proj_stats = RelationStats(); + proj_stats.cardinality = child_stats.cardinality; + proj_stats.table_name = proj.GetName(); + for (auto &expr : proj.expressions) { + proj_stats.column_names.push_back(expr->GetName()); + auto res = GetChildColumnBinding(*expr); + D_ASSERT(res.found_expression); + if (res.expression_is_constant) { + proj_stats.column_distinct_count.push_back(DistinctCount({1, true})); + } else { + auto column_index = res.child_binding.column_index; + if (column_index >= child_stats.column_distinct_count.size() && expr->ToString() == "count_star()") { + // only one value for a count star + proj_stats.column_distinct_count.push_back(DistinctCount({1, true})); + } else { + // TODO: add this back in + // D_ASSERT(column_index < stats.column_distinct_count.size()); + if (column_index < child_stats.column_distinct_count.size()) { + proj_stats.column_distinct_count.push_back(child_stats.column_distinct_count.at(column_index)); + } else { + proj_stats.column_distinct_count.push_back(DistinctCount({proj_stats.cardinality, false})); + } + } + } + } + proj_stats.stats_initialized = true; + return proj_stats; +} + +RelationStats RelationStatisticsHelper::ExtractDummyScanStats(LogicalDummyScan &dummy_scan, ClientContext &context) { + auto stats = RelationStats(); + idx_t card = dummy_scan.EstimateCardinality(context); + stats.cardinality = card; + for (idx_t i = 0; i < dummy_scan.GetColumnBindings().size(); i++) { + stats.column_distinct_count.push_back(DistinctCount({card, false})); + stats.column_names.push_back("dummy_scan_column"); + } + stats.stats_initialized = true; + stats.table_name = "dummy scan"; + return stats; +} + +void RelationStatisticsHelper::CopyRelationStats(RelationStats &to, const RelationStats &from) { + to.column_distinct_count = from.column_distinct_count; + to.column_names = from.column_names; + to.cardinality = from.cardinality; + to.table_name = from.table_name; + to.stats_initialized = from.stats_initialized; +} + +RelationStats RelationStatisticsHelper::CombineStatsOfReorderableOperator(vector &bindings, + vector relation_stats) { + RelationStats stats; + idx_t max_card = 0; + for (auto &child_stats : relation_stats) { + for (idx_t i = 0; i < child_stats.column_distinct_count.size(); i++) { + stats.column_distinct_count.push_back(child_stats.column_distinct_count.at(i)); + stats.column_names.push_back(child_stats.column_names.at(i)); + } + stats.table_name += "joined with " + child_stats.table_name; + max_card = MaxValue(max_card, child_stats.cardinality); + } + stats.stats_initialized = true; + stats.cardinality = max_card; + return stats; +} + +RelationStats RelationStatisticsHelper::CombineStatsOfNonReorderableOperator(LogicalOperator &op, + vector child_stats) { + D_ASSERT(child_stats.size() == 2); + RelationStats ret; + idx_t child_1_card = child_stats[0].stats_initialized ? child_stats[0].cardinality : 0; + idx_t child_2_card = child_stats[1].stats_initialized ? child_stats[1].cardinality : 0; + ret.cardinality = MaxValue(child_1_card, child_2_card); + ret.stats_initialized = true; + ret.filter_strength = 1; + ret.table_name = child_stats[0].table_name + " joined with " + child_stats[1].table_name; + for (auto &stats : child_stats) { + // MARK joins are nonreorderable. They won't return initialized stats + // continue in this case. + if (!stats.stats_initialized) { + continue; + } + for (auto &distinct_count : stats.column_distinct_count) { + ret.column_distinct_count.push_back(distinct_count); + } + for (auto &column_name : stats.column_names) { + ret.column_names.push_back(column_name); + } + } + return ret; +} + +RelationStats RelationStatisticsHelper::ExtractExpressionGetStats(LogicalExpressionGet &expression_get, + ClientContext &context) { + auto stats = RelationStats(); + idx_t card = expression_get.EstimateCardinality(context); + stats.cardinality = card; + for (idx_t i = 0; i < expression_get.GetColumnBindings().size(); i++) { + stats.column_distinct_count.push_back(DistinctCount({card, false})); + stats.column_names.push_back("expression_get_column"); + } + stats.stats_initialized = true; + stats.table_name = "expression_get"; + return stats; +} + +RelationStats RelationStatisticsHelper::ExtractWindowStats(LogicalWindow &window, RelationStats &child_stats) { + RelationStats stats; + stats.cardinality = child_stats.cardinality; + stats.column_distinct_count = child_stats.column_distinct_count; + stats.column_names = child_stats.column_names; + stats.stats_initialized = true; + auto num_child_columns = window.GetColumnBindings().size(); + + for (idx_t column_index = child_stats.column_distinct_count.size(); column_index < num_child_columns; + column_index++) { + stats.column_distinct_count.push_back(DistinctCount({child_stats.cardinality, false})); + stats.column_names.push_back("window"); + } + return stats; +} + +RelationStats RelationStatisticsHelper::ExtractAggregationStats(LogicalAggregate &aggr, RelationStats &child_stats) { + RelationStats stats; + // TODO: look at child distinct count to better estimate cardinality. + stats.cardinality = child_stats.cardinality; + stats.column_distinct_count = child_stats.column_distinct_count; + stats.column_names = child_stats.column_names; + stats.stats_initialized = true; + auto num_child_columns = aggr.GetColumnBindings().size(); + + for (idx_t column_index = child_stats.column_distinct_count.size(); column_index < num_child_columns; + column_index++) { + stats.column_distinct_count.push_back(DistinctCount({child_stats.cardinality, false})); + stats.column_names.push_back("aggregate"); + } + return stats; +} + +idx_t RelationStatisticsHelper::InspectConjunctionAND(idx_t cardinality, idx_t column_index, + ConjunctionAndFilter &filter, BaseStatistics &base_stats) { + auto cardinality_after_filters = cardinality; + for (auto &child_filter : filter.child_filters) { + if (child_filter->filter_type != TableFilterType::CONSTANT_COMPARISON) { + continue; + } + auto &comparison_filter = child_filter->Cast(); + if (comparison_filter.comparison_type != ExpressionType::COMPARE_EQUAL) { + continue; + } + auto column_count = base_stats.GetDistinctCount(); + auto filtered_card = cardinality; + // column_count = 0 when there is no column count (i.e parquet scans) + if (column_count > 0) { + // we want the ceil of cardinality/column_count. We also want to avoid compiler errors + filtered_card = (cardinality + column_count - 1) / column_count; + cardinality_after_filters = filtered_card; + } + } + return cardinality_after_filters; +} + +// TODO: Currently only simple AND filters are pushed into table scans. +// When OR filters are pushed this function can be added +// idx_t RelationStatisticsHelper::InspectConjunctionOR(idx_t cardinality, idx_t column_index, ConjunctionOrFilter +// &filter, +// BaseStatistics &base_stats) { +// auto has_equality_filter = false; +// auto cardinality_after_filters = cardinality; +// for (auto &child_filter : filter.child_filters) { +// if (child_filter->filter_type != TableFilterType::CONSTANT_COMPARISON) { +// continue; +// } +// auto &comparison_filter = child_filter->Cast(); +// if (comparison_filter.comparison_type == ExpressionType::COMPARE_EQUAL) { +// auto column_count = base_stats.GetDistinctCount(); +// auto increment = MaxValue(((cardinality + column_count - 1) / column_count), 1); +// if (has_equality_filter) { +// cardinality_after_filters += increment; +// } else { +// cardinality_after_filters = increment; +// } +// has_equality_filter = true; +// } +// if (child_filter->filter_type == TableFilterType::CONJUNCTION_AND) { +// auto &and_filter = child_filter->Cast(); +// cardinality_after_filters = RelationStatisticsHelper::InspectConjunctionAND( +// cardinality_after_filters, column_index, and_filter, base_stats); +// continue; +// } +// } +// D_ASSERT(cardinality_after_filters > 0); +// return cardinality_after_filters; +//} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/matcher/expression_matcher.cpp b/src/duckdb/src/optimizer/matcher/expression_matcher.cpp new file mode 100644 index 00000000..1d8ca93c --- /dev/null +++ b/src/duckdb/src/optimizer/matcher/expression_matcher.cpp @@ -0,0 +1,103 @@ +#include "duckdb/optimizer/matcher/expression_matcher.hpp" + +#include "duckdb/planner/expression/list.hpp" + +namespace duckdb { + +bool ExpressionMatcher::Match(Expression &expr, vector> &bindings) { + if (type && !type->Match(expr.return_type)) { + return false; + } + if (expr_type && !expr_type->Match(expr.type)) { + return false; + } + if (expr_class != ExpressionClass::INVALID && expr_class != expr.GetExpressionClass()) { + return false; + } + bindings.push_back(expr); + return true; +} + +bool ExpressionEqualityMatcher::Match(Expression &expr, vector> &bindings) { + if (!expr.Equals(expression)) { + return false; + } + bindings.push_back(expr); + return true; +} + +bool CaseExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + return true; +} + +bool ComparisonExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + auto &expr = expr_p.Cast(); + vector> expressions; + expressions.push_back(*expr.left); + expressions.push_back(*expr.right); + return SetMatcher::Match(matchers, expressions, bindings, policy); +} + +bool CastExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + if (!matcher) { + return true; + } + auto &expr = expr_p.Cast(); + return matcher->Match(*expr.child, bindings); +} + +bool InClauseExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + auto &expr = expr_p.Cast(); + if (expr.type != ExpressionType::COMPARE_IN || expr.type == ExpressionType::COMPARE_NOT_IN) { + return false; + } + return SetMatcher::Match(matchers, expr.children, bindings, policy); +} + +bool ConjunctionExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + auto &expr = expr_p.Cast(); + if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { + return false; + } + return true; +} + +bool FunctionExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + auto &expr = expr_p.Cast(); + if (!FunctionMatcher::Match(function, expr.function.name)) { + return false; + } + if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { + return false; + } + return true; +} + +bool FoldableConstantMatcher::Match(Expression &expr, vector> &bindings) { + // we match on ANY expression that is a scalar expression + if (!expr.IsFoldable()) { + return false; + } + bindings.push_back(expr); + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp new file mode 100644 index 00000000..0a66f000 --- /dev/null +++ b/src/duckdb/src/optimizer/optimizer.cpp @@ -0,0 +1,208 @@ +#include "duckdb/optimizer/optimizer.hpp" + +#include "duckdb/execution/column_binding_resolver.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/optimizer/column_lifetime_optimizer.hpp" +#include "duckdb/optimizer/common_aggregate_optimizer.hpp" +#include "duckdb/optimizer/compressed_materialization.hpp" +#include "duckdb/optimizer/cse_optimizer.hpp" +#include "duckdb/optimizer/deliminator.hpp" +#include "duckdb/optimizer/expression_heuristics.hpp" +#include "duckdb/optimizer/filter_pullup.hpp" +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/in_clause_rewriter.hpp" +#include "duckdb/optimizer/join_order/join_order_optimizer.hpp" +#include "duckdb/optimizer/regex_range_filter.hpp" +#include "duckdb/optimizer/remove_duplicate_groups.hpp" +#include "duckdb/optimizer/remove_unused_columns.hpp" +#include "duckdb/optimizer/rule/equal_or_null_simplification.hpp" +#include "duckdb/optimizer/rule/in_clause_simplification.hpp" +#include "duckdb/optimizer/rule/list.hpp" +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/optimizer/topn_optimizer.hpp" +#include "duckdb/optimizer/unnest_rewriter.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/planner.hpp" + +namespace duckdb { + +Optimizer::Optimizer(Binder &binder, ClientContext &context) : context(context), binder(binder), rewriter(context) { + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); + +#ifdef DEBUG + for (auto &rule : rewriter.rules) { + // root not defined in rule + D_ASSERT(rule->root); + } +#endif +} + +ClientContext &Optimizer::GetContext() { + return context; +} + +void Optimizer::RunOptimizer(OptimizerType type, const std::function &callback) { + auto &config = DBConfig::GetConfig(context); + if (config.options.disabled_optimizers.find(type) != config.options.disabled_optimizers.end()) { + // optimizer is marked as disabled: skip + return; + } + auto &profiler = QueryProfiler::Get(context); + profiler.StartPhase(OptimizerTypeToString(type)); + callback(); + profiler.EndPhase(); + if (plan) { + Verify(*plan); + } +} + +void Optimizer::Verify(LogicalOperator &op) { + ColumnBindingResolver::Verify(op); +} + +unique_ptr Optimizer::Optimize(unique_ptr plan_p) { + Verify(*plan_p); + + switch (plan_p->type) { + case LogicalOperatorType::LOGICAL_TRANSACTION: + return plan_p; // skip optimizing simple & often-occurring plans unaffected by rewrites + default: + break; + } + + this->plan = std::move(plan_p); + // first we perform expression rewrites using the ExpressionRewriter + // this does not change the logical plan structure, but only simplifies the expression trees + RunOptimizer(OptimizerType::EXPRESSION_REWRITER, [&]() { rewriter.VisitOperator(*plan); }); + + // perform filter pullup + RunOptimizer(OptimizerType::FILTER_PULLUP, [&]() { + FilterPullup filter_pullup; + plan = filter_pullup.Rewrite(std::move(plan)); + }); + + // perform filter pushdown + RunOptimizer(OptimizerType::FILTER_PUSHDOWN, [&]() { + FilterPushdown filter_pushdown(*this); + plan = filter_pushdown.Rewrite(std::move(plan)); + }); + + RunOptimizer(OptimizerType::REGEX_RANGE, [&]() { + RegexRangeFilter regex_opt; + plan = regex_opt.Rewrite(std::move(plan)); + }); + + RunOptimizer(OptimizerType::IN_CLAUSE, [&]() { + InClauseRewriter ic_rewriter(context, *this); + plan = ic_rewriter.Rewrite(std::move(plan)); + }); + + // removes any redundant DelimGets/DelimJoins + RunOptimizer(OptimizerType::DELIMINATOR, [&]() { + Deliminator deliminator; + plan = deliminator.Optimize(std::move(plan)); + }); + + // then we perform the join ordering optimization + // this also rewrites cross products + filters into joins and performs filter pushdowns + RunOptimizer(OptimizerType::JOIN_ORDER, [&]() { + JoinOrderOptimizer optimizer(context); + plan = optimizer.Optimize(std::move(plan)); + }); + + // rewrites UNNESTs in DelimJoins by moving them to the projection + RunOptimizer(OptimizerType::UNNEST_REWRITER, [&]() { + UnnestRewriter unnest_rewriter; + plan = unnest_rewriter.Optimize(std::move(plan)); + }); + + // removes unused columns + RunOptimizer(OptimizerType::UNUSED_COLUMNS, [&]() { + RemoveUnusedColumns unused(binder, context, true); + unused.VisitOperator(*plan); + }); + + // Remove duplicate groups from aggregates + RunOptimizer(OptimizerType::DUPLICATE_GROUPS, [&]() { + RemoveDuplicateGroups remove; + remove.VisitOperator(*plan); + }); + + // then we extract common subexpressions inside the different operators + RunOptimizer(OptimizerType::COMMON_SUBEXPRESSIONS, [&]() { + CommonSubExpressionOptimizer cse_optimizer(binder); + cse_optimizer.VisitOperator(*plan); + }); + + // creates projection maps so unused columns are projected out early + RunOptimizer(OptimizerType::COLUMN_LIFETIME, [&]() { + ColumnLifetimeAnalyzer column_lifetime(true); + column_lifetime.VisitOperator(*plan); + }); + + // perform statistics propagation + column_binding_map_t> statistics_map; + RunOptimizer(OptimizerType::STATISTICS_PROPAGATION, [&]() { + StatisticsPropagator propagator(*this); + propagator.PropagateStatistics(plan); + statistics_map = propagator.GetStatisticsMap(); + }); + + // remove duplicate aggregates + RunOptimizer(OptimizerType::COMMON_AGGREGATE, [&]() { + CommonAggregateOptimizer common_aggregate; + common_aggregate.VisitOperator(*plan); + }); + + // creates projection maps so unused columns are projected out early + RunOptimizer(OptimizerType::COLUMN_LIFETIME, [&]() { + ColumnLifetimeAnalyzer column_lifetime(true); + column_lifetime.VisitOperator(*plan); + }); + + // compress data based on statistics for materializing operators + RunOptimizer(OptimizerType::COMPRESSED_MATERIALIZATION, [&]() { + CompressedMaterialization compressed_materialization(context, binder, std::move(statistics_map)); + compressed_materialization.Compress(plan); + }); + + // transform ORDER BY + LIMIT to TopN + RunOptimizer(OptimizerType::TOP_N, [&]() { + TopN topn; + plan = topn.Optimize(std::move(plan)); + }); + + // apply simple expression heuristics to get an initial reordering + RunOptimizer(OptimizerType::REORDER_FILTER, [&]() { + ExpressionHeuristics expression_heuristics(*this); + plan = expression_heuristics.Rewrite(std::move(plan)); + }); + + for (auto &optimizer_extension : DBConfig::GetConfig(context).optimizer_extensions) { + RunOptimizer(OptimizerType::EXTENSION, [&]() { + optimizer_extension.optimize_function(context, optimizer_extension.optimizer_info.get(), plan); + }); + } + + Planner::VerifyPlan(context, plan); + + return std::move(plan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pullup/pullup_both_side.cpp b/src/duckdb/src/optimizer/pullup/pullup_both_side.cpp new file mode 100644 index 00000000..04bbb3f1 --- /dev/null +++ b/src/duckdb/src/optimizer/pullup/pullup_both_side.cpp @@ -0,0 +1,24 @@ +#include "duckdb/optimizer/filter_pullup.hpp" + +namespace duckdb { + +unique_ptr FilterPullup::PullupBothSide(unique_ptr op) { + FilterPullup left_pullup(true, can_add_column); + FilterPullup right_pullup(true, can_add_column); + op->children[0] = left_pullup.Rewrite(std::move(op->children[0])); + op->children[1] = right_pullup.Rewrite(std::move(op->children[1])); + D_ASSERT(left_pullup.can_add_column == can_add_column); + D_ASSERT(right_pullup.can_add_column == can_add_column); + + // merging filter expressions + for (idx_t i = 0; i < right_pullup.filters_expr_pullup.size(); ++i) { + left_pullup.filters_expr_pullup.push_back(std::move(right_pullup.filters_expr_pullup[i])); + } + + if (!left_pullup.filters_expr_pullup.empty()) { + return GeneratePullupFilter(std::move(op), left_pullup.filters_expr_pullup); + } + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pullup/pullup_filter.cpp b/src/duckdb/src/optimizer/pullup/pullup_filter.cpp new file mode 100644 index 00000000..96395a97 --- /dev/null +++ b/src/duckdb/src/optimizer/pullup/pullup_filter.cpp @@ -0,0 +1,26 @@ +#include "duckdb/optimizer/filter_pullup.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" + +namespace duckdb { + +unique_ptr FilterPullup::PullupFilter(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_FILTER); + + auto &filter = op->Cast(); + if (can_pullup && filter.projection_map.empty()) { + unique_ptr child = std::move(op->children[0]); + child = Rewrite(std::move(child)); + // moving filter's expressions + for (idx_t i = 0; i < op->expressions.size(); ++i) { + filters_expr_pullup.push_back(std::move(op->expressions[i])); + } + return child; + } + op->children[0] = Rewrite(std::move(op->children[0])); + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pullup/pullup_from_left.cpp b/src/duckdb/src/optimizer/pullup/pullup_from_left.cpp new file mode 100644 index 00000000..c2014de7 --- /dev/null +++ b/src/duckdb/src/optimizer/pullup/pullup_from_left.cpp @@ -0,0 +1,25 @@ +#include "duckdb/optimizer/filter_pullup.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_join.hpp" + +namespace duckdb { + +unique_ptr FilterPullup::PullupFromLeft(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN || op->type == LogicalOperatorType::LOGICAL_ANY_JOIN || + op->type == LogicalOperatorType::LOGICAL_EXCEPT || op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN); + + FilterPullup left_pullup(true, can_add_column); + FilterPullup right_pullup(false, can_add_column); + + op->children[0] = left_pullup.Rewrite(std::move(op->children[0])); + op->children[1] = right_pullup.Rewrite(std::move(op->children[1])); + + // check only for filters from the LHS + if (!left_pullup.filters_expr_pullup.empty() && right_pullup.filters_expr_pullup.empty()) { + return GeneratePullupFilter(std::move(op), left_pullup.filters_expr_pullup); + } + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pullup/pullup_projection.cpp b/src/duckdb/src/optimizer/pullup/pullup_projection.cpp new file mode 100644 index 00000000..f2fd71b7 --- /dev/null +++ b/src/duckdb/src/optimizer/pullup/pullup_projection.cpp @@ -0,0 +1,98 @@ +#include "duckdb/optimizer/filter_pullup.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" + +namespace duckdb { + +static void RevertFilterPullup(LogicalProjection &proj, vector> &expressions) { + unique_ptr filter = make_uniq(); + for (idx_t i = 0; i < expressions.size(); ++i) { + filter->expressions.push_back(std::move(expressions[i])); + } + expressions.clear(); + filter->children.push_back(std::move(proj.children[0])); + proj.children[0] = std::move(filter); +} + +static void ReplaceExpressionBinding(vector> &proj_expressions, Expression &expr, + idx_t proj_table_idx) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + bool found_proj_col = false; + BoundColumnRefExpression &colref = expr.Cast(); + // find the corresponding column index in the projection expressions + for (idx_t proj_idx = 0; proj_idx < proj_expressions.size(); proj_idx++) { + auto &proj_expr = *proj_expressions[proj_idx]; + if (proj_expr.type == ExpressionType::BOUND_COLUMN_REF) { + if (colref.Equals(proj_expr)) { + colref.binding.table_index = proj_table_idx; + colref.binding.column_index = proj_idx; + found_proj_col = true; + break; + } + } + } + if (!found_proj_col) { + // Project a new column + auto new_colref = colref.Copy(); + colref.binding.table_index = proj_table_idx; + colref.binding.column_index = proj_expressions.size(); + proj_expressions.push_back(std::move(new_colref)); + } + } + ExpressionIterator::EnumerateChildren( + expr, [&](Expression &child) { return ReplaceExpressionBinding(proj_expressions, child, proj_table_idx); }); +} + +void FilterPullup::ProjectSetOperation(LogicalProjection &proj) { + vector> copy_proj_expressions; + // copying the project expressions, it's useful whether we should revert the filter pullup + for (idx_t i = 0; i < proj.expressions.size(); ++i) { + copy_proj_expressions.push_back(proj.expressions[i]->Copy()); + } + + // Replace filter expression bindings, when need we add new columns into the copied projection expression + vector> changed_filter_expressions; + for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) { + auto copy_filter_expr = filters_expr_pullup[i]->Copy(); + ReplaceExpressionBinding(copy_proj_expressions, (Expression &)*copy_filter_expr, proj.table_index); + changed_filter_expressions.push_back(std::move(copy_filter_expr)); + } + + /// Case new columns were added into the projection + // we must skip filter pullup because adding new columns to these operators will change the result + if (copy_proj_expressions.size() > proj.expressions.size()) { + RevertFilterPullup(proj, filters_expr_pullup); + return; + } + + // now we must replace the filter bindings + D_ASSERT(filters_expr_pullup.size() == changed_filter_expressions.size()); + for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) { + filters_expr_pullup[i] = std::move(changed_filter_expressions[i]); + } +} + +unique_ptr FilterPullup::PullupProjection(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_PROJECTION); + op->children[0] = Rewrite(std::move(op->children[0])); + if (!filters_expr_pullup.empty()) { + auto &proj = op->Cast(); + // INTERSECT, EXCEPT, and DISTINCT + if (!can_add_column) { + // special treatment for operators that cannot add columns, e.g., INTERSECT, EXCEPT, and DISTINCT + ProjectSetOperation(proj); + return op; + } + + for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) { + auto &expr = (Expression &)*filters_expr_pullup[i]; + ReplaceExpressionBinding(proj.expressions, expr, proj.table_index); + } + } + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pullup/pullup_set_operation.cpp b/src/duckdb/src/optimizer/pullup/pullup_set_operation.cpp new file mode 100644 index 00000000..8d7b3c1d --- /dev/null +++ b/src/duckdb/src/optimizer/pullup/pullup_set_operation.cpp @@ -0,0 +1,39 @@ +#include "duckdb/optimizer/filter_pullup.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" + +namespace duckdb { + +static void ReplaceFilterTableIndex(Expression &expr, LogicalSetOperation &setop) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expr.Cast(); + D_ASSERT(colref.depth == 0); + + colref.binding.table_index = setop.table_index; + return; + } + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { ReplaceFilterTableIndex(child, setop); }); +} + +unique_ptr FilterPullup::PullupSetOperation(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_INTERSECT || op->type == LogicalOperatorType::LOGICAL_EXCEPT); + can_add_column = false; + can_pullup = true; + if (op->type == LogicalOperatorType::LOGICAL_INTERSECT) { + op = PullupBothSide(std::move(op)); + } else { + // EXCEPT only pull ups from LHS + op = PullupFromLeft(std::move(op)); + } + if (op->type == LogicalOperatorType::LOGICAL_FILTER) { + auto &filter = op->Cast(); + auto &setop = filter.children[0]->Cast(); + for (idx_t i = 0; i < filter.expressions.size(); ++i) { + ReplaceFilterTableIndex(*filter.expressions[i], setop); + } + } + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_aggregate.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_aggregate.cpp new file mode 100644 index 00000000..396980d5 --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_aggregate.cpp @@ -0,0 +1,99 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_join.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +static void ExtractFilterBindings(Expression &expr, vector &bindings) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expr.Cast(); + bindings.push_back(colref.binding); + } + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { ExtractFilterBindings(child, bindings); }); +} + +static unique_ptr ReplaceGroupBindings(LogicalAggregate &proj, unique_ptr expr) { + if (expr->type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expr->Cast(); + D_ASSERT(colref.binding.table_index == proj.group_index); + D_ASSERT(colref.binding.column_index < proj.groups.size()); + D_ASSERT(colref.depth == 0); + // replace the binding with a copy to the expression at the referenced index + return proj.groups[colref.binding.column_index]->Copy(); + } + ExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { child = ReplaceGroupBindings(proj, std::move(child)); }); + return expr; +} + +unique_ptr FilterPushdown::PushdownAggregate(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY); + auto &aggr = op->Cast(); + + // pushdown into AGGREGATE and GROUP BY + // we cannot push expressions that refer to the aggregate + FilterPushdown child_pushdown(optimizer); + for (idx_t i = 0; i < filters.size(); i++) { + auto &f = *filters[i]; + if (f.bindings.find(aggr.aggregate_index) != f.bindings.end()) { + // filter on aggregate: cannot pushdown + continue; + } + if (f.bindings.find(aggr.groupings_index) != f.bindings.end()) { + // filter on GROUPINGS function: cannot pushdown + continue; + } + // no aggregate! we are filtering on a group + // we can only push this down if the filter is in all grouping sets + vector bindings; + ExtractFilterBindings(*f.filter, bindings); + + bool can_pushdown_filter = true; + if (aggr.grouping_sets.empty()) { + // empty grouping set - we cannot pushdown the filter + can_pushdown_filter = false; + } + for (auto &grp : aggr.grouping_sets) { + // check for each of the grouping sets if they contain all groups + if (bindings.empty()) { + // we can never push down empty grouping sets + can_pushdown_filter = false; + break; + } + for (auto &binding : bindings) { + if (grp.find(binding.column_index) == grp.end()) { + can_pushdown_filter = false; + break; + } + } + if (!can_pushdown_filter) { + break; + } + } + if (!can_pushdown_filter) { + continue; + } + // no aggregate! we can push this down + // rewrite any group bindings within the filter + f.filter = ReplaceGroupBindings(aggr, std::move(f.filter)); + // add the filter to the child node + if (child_pushdown.AddFilter(std::move(f.filter)) == FilterResult::UNSATISFIABLE) { + // filter statically evaluates to false, strip tree + return make_uniq(std::move(op)); + } + // erase the filter from here + filters.erase(filters.begin() + i); + i--; + } + child_pushdown.GenerateFilters(); + + op->children[0] = child_pushdown.Rewrite(std::move(op->children[0])); + return FinishPushdown(std::move(op)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_cross_product.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_cross_product.cpp new file mode 100644 index 00000000..efa3fa4e --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_cross_product.cpp @@ -0,0 +1,58 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +unique_ptr FilterPushdown::PushdownCrossProduct(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT); + FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); + vector> join_expressions; + unordered_set left_bindings, right_bindings; + if (!filters.empty()) { + // check to see into which side we should push the filters + // first get the LHS and RHS bindings + LogicalJoin::GetTableReferences(*op->children[0], left_bindings); + LogicalJoin::GetTableReferences(*op->children[1], right_bindings); + // now check the set of filters + for (auto &f : filters) { + auto side = JoinSide::GetJoinSide(f->bindings, left_bindings, right_bindings); + if (side == JoinSide::LEFT) { + // bindings match left side: push into left + left_pushdown.filters.push_back(std::move(f)); + } else if (side == JoinSide::RIGHT) { + // bindings match right side: push into right + right_pushdown.filters.push_back(std::move(f)); + } else { + D_ASSERT(side == JoinSide::BOTH || side == JoinSide::NONE); + // bindings match both: turn into join condition + join_expressions.push_back(std::move(f->filter)); + } + } + } + + op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); + op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); + + if (!join_expressions.empty()) { + // join conditions found: turn into inner join + // extract join conditions + vector conditions; + vector> arbitrary_expressions; + auto join_type = JoinType::INNER; + LogicalComparisonJoin::ExtractJoinConditions(GetContext(), join_type, op->children[0], op->children[1], + left_bindings, right_bindings, join_expressions, conditions, + arbitrary_expressions); + // create the join from the join conditions + return LogicalComparisonJoin::CreateJoin(GetContext(), JoinType::INNER, JoinRefType::REGULAR, + std::move(op->children[0]), std::move(op->children[1]), + std::move(conditions), std::move(arbitrary_expressions)); + } else { + // no join conditions found: keep as cross product + return op; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_filter.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_filter.cpp new file mode 100644 index 00000000..9f3a6b5a --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_filter.cpp @@ -0,0 +1,26 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +unique_ptr FilterPushdown::PushdownFilter(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_FILTER); + auto &filter = op->Cast(); + if (!filter.projection_map.empty()) { + return FinishPushdown(std::move(op)); + } + // filter: gather the filters and remove the filter from the set of operations + for (auto &expression : filter.expressions) { + if (AddFilter(std::move(expression)) == FilterResult::UNSATISFIABLE) { + // filter statically evaluates to false, strip tree + return make_uniq(std::move(op)); + } + } + GenerateFilters(); + return Rewrite(std::move(filter.children[0])); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp new file mode 100644 index 00000000..865003fc --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp @@ -0,0 +1,80 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +namespace duckdb { + +unique_ptr FilterPushdown::PushdownGet(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_GET); + auto &get = op->Cast(); + + if (get.function.pushdown_complex_filter || get.function.filter_pushdown) { + // this scan supports some form of filter push-down + // check if there are any parameters + // if there are, invalidate them to force a re-bind on execution + for (auto &filter : filters) { + if (filter->filter->HasParameter()) { + // there is a parameter in the filters! invalidate it + BoundParameterExpression::InvalidateRecursive(*filter->filter); + } + } + } + if (get.function.pushdown_complex_filter) { + // for the remaining filters, check if we can push any of them into the scan as well + vector> expressions; + expressions.reserve(filters.size()); + for (auto &filter : filters) { + expressions.push_back(std::move(filter->filter)); + } + filters.clear(); + + get.function.pushdown_complex_filter(optimizer.context, get, get.bind_data.get(), expressions); + + if (expressions.empty()) { + return op; + } + // re-generate the filters + for (auto &expr : expressions) { + auto f = make_uniq(); + f->filter = std::move(expr); + f->ExtractBindings(); + filters.push_back(std::move(f)); + } + } + + if (!get.table_filters.filters.empty() || !get.function.filter_pushdown) { + // the table function does not support filter pushdown: push a LogicalFilter on top + return FinishPushdown(std::move(op)); + } + PushFilters(); + + //! We generate the table filters that will be executed during the table scan + //! Right now this only executes simple AND filters + get.table_filters = combiner.GenerateTableScanFilters(get.column_ids); + + // //! For more complex filters if all filters to a column are constants we generate a min max boundary used to + // check + // //! the zonemaps. + // auto zonemap_checks = combiner.GenerateZonemapChecks(get.column_ids, get.table_filters); + + // for (auto &f : get.table_filters) { + // f.column_index = get.column_ids[f.column_index]; + // } + + // //! Use zonemap checks as table filters for pre-processing + // for (auto &zonemap_check : zonemap_checks) { + // if (zonemap_check.column_index != COLUMN_IDENTIFIER_ROW_ID) { + // get.table_filters.push_back(zonemap_check); + // } + // } + + GenerateFilters(); + + //! Now we try to pushdown the remaining filters to perform zonemap checking + return FinishPushdown(std::move(op)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp new file mode 100644 index 00000000..dd94d3f3 --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp @@ -0,0 +1,51 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/planner/operator/logical_any_join.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +unique_ptr FilterPushdown::PushdownInnerJoin(unique_ptr op, + unordered_set &left_bindings, + unordered_set &right_bindings) { + auto &join = op->Cast(); + D_ASSERT(join.join_type == JoinType::INNER); + if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + return FinishPushdown(std::move(op)); + } + // inner join: gather all the conditions of the inner join and add to the filter list + if (op->type == LogicalOperatorType::LOGICAL_ANY_JOIN) { + auto &any_join = join.Cast(); + // any join: only one filter to add + if (AddFilter(std::move(any_join.condition)) == FilterResult::UNSATISFIABLE) { + // filter statically evaluates to false, strip tree + return make_uniq(std::move(op)); + } + } else if (op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { + // Don't mess with non-standard condition interpretations + return FinishPushdown(std::move(op)); + } else { + // comparison join + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN); + auto &comp_join = join.Cast(); + // turn the conditions into filters + for (auto &i : comp_join.conditions) { + auto condition = JoinCondition::CreateExpression(std::move(i)); + if (AddFilter(std::move(condition)) == FilterResult::UNSATISFIABLE) { + // filter statically evaluates to false, strip tree + return make_uniq(std::move(op)); + } + } + } + GenerateFilters(); + + // turn the inner join into a cross product + auto cross_product = make_uniq(std::move(op->children[0]), std::move(op->children[1])); + // then push down cross product + return PushdownCrossProduct(std::move(cross_product)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp new file mode 100644 index 00000000..47cfdfd6 --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp @@ -0,0 +1,130 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +static unique_ptr ReplaceColRefWithNull(unique_ptr expr, unordered_set &right_bindings) { + if (expr->type == ExpressionType::BOUND_COLUMN_REF) { + auto &bound_colref = expr->Cast(); + if (right_bindings.find(bound_colref.binding.table_index) != right_bindings.end()) { + // bound colref belongs to RHS + // replace it with a constant NULL + return make_uniq(Value(expr->return_type)); + } + return expr; + } + ExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { child = ReplaceColRefWithNull(std::move(child), right_bindings); }); + return expr; +} + +static bool FilterRemovesNull(ClientContext &context, ExpressionRewriter &rewriter, Expression *expr, + unordered_set &right_bindings) { + // make a copy of the expression + auto copy = expr->Copy(); + // replace all BoundColumnRef expressions frmo the RHS with NULL constants in the copied expression + copy = ReplaceColRefWithNull(std::move(copy), right_bindings); + + // attempt to flatten the expression by running the expression rewriter on it + auto filter = make_uniq(); + filter->expressions.push_back(std::move(copy)); + rewriter.VisitOperator(*filter); + + // check if all expressions are foldable + for (idx_t i = 0; i < filter->expressions.size(); i++) { + if (!filter->expressions[i]->IsFoldable()) { + return false; + } + // we flattened the result into a scalar, check if it is FALSE or NULL + auto val = + ExpressionExecutor::EvaluateScalar(context, *filter->expressions[i]).DefaultCastAs(LogicalType::BOOLEAN); + // if the result of the expression with all expressions replaced with NULL is "NULL" or "false" + // then any extra entries generated by the LEFT OUTER JOIN will be filtered out! + // hence the LEFT OUTER JOIN is equivalent to an inner join + if (val.IsNull() || !BooleanValue::Get(val)) { + return true; + } + } + return false; +} + +unique_ptr FilterPushdown::PushdownLeftJoin(unique_ptr op, + unordered_set &left_bindings, + unordered_set &right_bindings) { + auto &join = op->Cast(); + if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + return FinishPushdown(std::move(op)); + } + FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); + // for a comparison join we create a FilterCombiner that checks if we can push conditions on LHS join conditions + // into the RHS of the join + FilterCombiner filter_combiner(optimizer); + const auto isComparison = (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN); + if (isComparison) { + // add all comparison conditions + auto &comparison_join = op->Cast(); + for (auto &cond : comparison_join.conditions) { + filter_combiner.AddFilter( + make_uniq(cond.comparison, cond.left->Copy(), cond.right->Copy())); + } + } + // now check the set of filters + for (idx_t i = 0; i < filters.size(); i++) { + auto side = JoinSide::GetJoinSide(filters[i]->bindings, left_bindings, right_bindings); + if (side == JoinSide::LEFT) { + // bindings match left side + // we can push the filter into the left side + if (isComparison) { + // we MIGHT be able to push it down the RHS as well, but only if it is a comparison that matches the + // join predicates we use the FilterCombiner to figure this out add the expression to the FilterCombiner + filter_combiner.AddFilter(filters[i]->filter->Copy()); + } + left_pushdown.filters.push_back(std::move(filters[i])); + // erase the filter from the list of filters + filters.erase(filters.begin() + i); + i--; + } else { + // bindings match right side or both sides: we cannot directly push it into the right + // however, if the filter removes rows with null values from the RHS we can turn the left outer join + // in an inner join, and then push down as we would push down an inner join + if (FilterRemovesNull(optimizer.context, optimizer.rewriter, filters[i]->filter.get(), right_bindings)) { + // the filter removes NULL values, turn it into an inner join + join.join_type = JoinType::INNER; + // now we can do more pushdown + // move all filters we added to the left_pushdown back into the filter list + for (auto &left_filter : left_pushdown.filters) { + filters.push_back(std::move(left_filter)); + } + // now push down the inner join + return PushdownInnerJoin(std::move(op), left_bindings, right_bindings); + } + } + } + // finally we check the FilterCombiner to see if there are any predicates we can push into the RHS + // we only added (1) predicates that have JoinSide::BOTH from the conditions, and + // (2) predicates that have JoinSide::LEFT from the filters + // we check now if this combination generated any new filters that are only on JoinSide::RIGHT + // this happens if, e.g. a join condition is (i=a) and there is a filter (i=500), we can then push the filter + // (a=500) into the RHS + filter_combiner.GenerateFilters([&](unique_ptr filter) { + if (JoinSide::GetJoinSide(*filter, left_bindings, right_bindings) == JoinSide::RIGHT) { + right_pushdown.AddFilter(std::move(filter)); + } + }); + right_pushdown.GenerateFilters(); + op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); + op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); + return PushFinalFilters(std::move(op)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_limit.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_limit.cpp new file mode 100644 index 00000000..c2c2d0b0 --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_limit.cpp @@ -0,0 +1,19 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" + +namespace duckdb { + +unique_ptr FilterPushdown::PushdownLimit(unique_ptr op) { + auto &limit = op->Cast(); + + if (!limit.limit && limit.limit_val == 0) { + return make_uniq(std::move(op)); + } + + return FinishPushdown(std::move(op)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp new file mode 100644 index 00000000..b336cd3e --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp @@ -0,0 +1,83 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptr op, + unordered_set &left_bindings, + unordered_set &right_bindings) { + auto &join = op->Cast(); + auto &comp_join = op->Cast(); + D_ASSERT(join.join_type == JoinType::MARK); + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN || op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN); + + right_bindings.insert(comp_join.mark_index); + FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); +#ifdef DEBUG + bool simplified_mark_join = false; +#endif + // now check the set of filters + for (idx_t i = 0; i < filters.size(); i++) { + auto side = JoinSide::GetJoinSide(filters[i]->bindings, left_bindings, right_bindings); + if (side == JoinSide::LEFT) { + // bindings match left side: push into left + left_pushdown.filters.push_back(std::move(filters[i])); + // erase the filter from the list of filters + filters.erase(filters.begin() + i); + i--; + } else if (side == JoinSide::RIGHT) { +#ifdef DEBUG + D_ASSERT(!simplified_mark_join); +#endif + // this filter references the marker + // we can turn this into a SEMI join if the filter is on only the marker + if (filters[i]->filter->type == ExpressionType::BOUND_COLUMN_REF) { + // filter just references the marker: turn into semi join +#ifdef DEBUG + simplified_mark_join = true; +#endif + join.join_type = JoinType::SEMI; + filters.erase(filters.begin() + i); + i--; + continue; + } + // if the filter is on NOT(marker) AND the join conditions are all set to "null_values_are_equal" we can + // turn this into an ANTI join if all join conditions have null_values_are_equal=true, then the result of + // the MARK join is always TRUE or FALSE, and never NULL this happens in the case of a correlated EXISTS + // clause + if (filters[i]->filter->type == ExpressionType::OPERATOR_NOT) { + auto &op_expr = filters[i]->filter->Cast(); + if (op_expr.children[0]->type == ExpressionType::BOUND_COLUMN_REF) { + // the filter is NOT(marker), check the join conditions + bool all_null_values_are_equal = true; + for (auto &cond : comp_join.conditions) { + if (cond.comparison != ExpressionType::COMPARE_DISTINCT_FROM && + cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + all_null_values_are_equal = false; + break; + } + } + if (all_null_values_are_equal) { +#ifdef DEBUG + simplified_mark_join = true; +#endif + // all null values are equal, convert to ANTI join + join.join_type = JoinType::ANTI; + filters.erase(filters.begin() + i); + i--; + continue; + } + } + } + } + } + op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); + op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); + return PushFinalFilters(std::move(op)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_projection.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_projection.cpp new file mode 100644 index 00000000..a0d7d814 --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_projection.cpp @@ -0,0 +1,78 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" + +namespace duckdb { + +static bool HasSideEffects(LogicalProjection &proj, const unique_ptr &expr) { + if (expr->type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expr->Cast(); + D_ASSERT(colref.binding.table_index == proj.table_index); + D_ASSERT(colref.binding.column_index < proj.expressions.size()); + D_ASSERT(colref.depth == 0); + if (proj.expressions[colref.binding.column_index]->HasSideEffects()) { + return true; + } + return false; + } + bool has_side_effects = false; + ExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { has_side_effects |= HasSideEffects(proj, child); }); + return has_side_effects; +} + +static unique_ptr ReplaceProjectionBindings(LogicalProjection &proj, unique_ptr expr) { + if (expr->type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expr->Cast(); + D_ASSERT(colref.binding.table_index == proj.table_index); + D_ASSERT(colref.binding.column_index < proj.expressions.size()); + D_ASSERT(colref.depth == 0); + // replace the binding with a copy to the expression at the referenced index + return proj.expressions[colref.binding.column_index]->Copy(); + } + ExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { child = ReplaceProjectionBindings(proj, std::move(child)); }); + return expr; +} + +unique_ptr FilterPushdown::PushdownProjection(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_PROJECTION); + auto &proj = op->Cast(); + // push filter through logical projection + // all the BoundColumnRefExpressions in the filter should refer to the LogicalProjection + // we can rewrite them by replacing those references with the expression of the LogicalProjection node + FilterPushdown child_pushdown(optimizer); + // There are some expressions can not be pushed down. We should keep them + // and add an extra filter operator. + vector> remain_expressions; + for (auto &filter : filters) { + auto &f = *filter; + D_ASSERT(f.bindings.size() <= 1); + bool has_side_effects = HasSideEffects(proj, f.filter); + if (has_side_effects) { + // We can't push down related expressions if the column in the + // expression is generated by the functions which have side effects + remain_expressions.push_back(std::move(f.filter)); + } else { + // rewrite the bindings within this subquery + f.filter = ReplaceProjectionBindings(proj, std::move(f.filter)); + // add the filter to the child pushdown + if (child_pushdown.AddFilter(std::move(f.filter)) == FilterResult::UNSATISFIABLE) { + // filter statically evaluates to false, strip tree + return make_uniq(std::move(op)); + } + } + } + child_pushdown.GenerateFilters(); + // now push into children + op->children[0] = child_pushdown.Rewrite(std::move(op->children[0])); + if (op->children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { + // child returns an empty result: generate an empty result here too + return make_uniq(std::move(op)); + } + return AddLogicalFilter(std::move(op), std::move(remain_expressions)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_set_operation.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_set_operation.cpp new file mode 100644 index 00000000..e96eba2e --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_set_operation.cpp @@ -0,0 +1,113 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +static void ReplaceSetOpBindings(vector &bindings, Filter &filter, Expression &expr, + LogicalSetOperation &setop) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expr.Cast(); + D_ASSERT(colref.binding.table_index == setop.table_index); + D_ASSERT(colref.depth == 0); + + // rewrite the binding by looking into the bound_tables list of the subquery + colref.binding = bindings[colref.binding.column_index]; + filter.bindings.insert(colref.binding.table_index); + return; + } + ExpressionIterator::EnumerateChildren( + expr, [&](Expression &child) { ReplaceSetOpBindings(bindings, filter, child, setop); }); +} + +unique_ptr FilterPushdown::PushdownSetOperation(unique_ptr op) { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_UNION || op->type == LogicalOperatorType::LOGICAL_EXCEPT || + op->type == LogicalOperatorType::LOGICAL_INTERSECT); + auto &setop = op->Cast(); + + D_ASSERT(op->children.size() == 2); + auto left_bindings = op->children[0]->GetColumnBindings(); + auto right_bindings = op->children[1]->GetColumnBindings(); + if (left_bindings.size() != right_bindings.size()) { + throw InternalException("Filter pushdown - set operation LHS and RHS have incompatible counts"); + } + + // pushdown into set operation, we can duplicate the condition and pushdown the expressions into both sides + FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); + for (idx_t i = 0; i < filters.size(); i++) { + // first create a copy of the filter + auto right_filter = make_uniq(); + right_filter->filter = filters[i]->filter->Copy(); + + // in the original filter, rewrite references to the result of the union into references to the left_index + ReplaceSetOpBindings(left_bindings, *filters[i], *filters[i]->filter, setop); + // in the copied filter, rewrite references to the result of the union into references to the right_index + ReplaceSetOpBindings(right_bindings, *right_filter, *right_filter->filter, setop); + + // extract bindings again + filters[i]->ExtractBindings(); + right_filter->ExtractBindings(); + + // move the filters into the child pushdown nodes + left_pushdown.filters.push_back(std::move(filters[i])); + right_pushdown.filters.push_back(std::move(right_filter)); + } + + op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); + op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); + + bool left_empty = op->children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT; + bool right_empty = op->children[1]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT; + if (left_empty && right_empty) { + // both empty: return empty result + return make_uniq(std::move(op)); + } + if (left_empty) { + // left child is empty result + switch (op->type) { + case LogicalOperatorType::LOGICAL_UNION: + if (op->children[1]->type == LogicalOperatorType::LOGICAL_PROJECTION) { + // union with empty left side: return right child + auto &projection = op->children[1]->Cast(); + projection.table_index = setop.table_index; + return std::move(op->children[1]); + } + break; + case LogicalOperatorType::LOGICAL_EXCEPT: + // except: if left child is empty, return empty result + case LogicalOperatorType::LOGICAL_INTERSECT: + // intersect: if any child is empty, return empty result itself + return make_uniq(std::move(op)); + default: + throw InternalException("Unsupported set operation"); + } + } else if (right_empty) { + // right child is empty result + switch (op->type) { + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + if (op->children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION) { + // union or except with empty right child: return left child + auto &projection = op->children[0]->Cast(); + projection.table_index = setop.table_index; + return std::move(op->children[0]); + } + break; + case LogicalOperatorType::LOGICAL_INTERSECT: + // intersect: if any child is empty, return empty result itself + return make_uniq(std::move(op)); + default: + throw InternalException("Unsupported set operation"); + } + } + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_single_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_single_join.cpp new file mode 100644 index 00000000..028dc092 --- /dev/null +++ b/src/duckdb/src/optimizer/pushdown/pushdown_single_join.cpp @@ -0,0 +1,29 @@ +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" + +namespace duckdb { + +using Filter = FilterPushdown::Filter; + +unique_ptr FilterPushdown::PushdownSingleJoin(unique_ptr op, + unordered_set &left_bindings, + unordered_set &right_bindings) { + D_ASSERT(op->Cast().join_type == JoinType::SINGLE); + FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); + // now check the set of filters + for (idx_t i = 0; i < filters.size(); i++) { + auto side = JoinSide::GetJoinSide(filters[i]->bindings, left_bindings, right_bindings); + if (side == JoinSide::LEFT) { + // bindings match left side: push into left + left_pushdown.filters.push_back(std::move(filters[i])); + // erase the filter from the list of filters + filters.erase(filters.begin() + i); + i--; + } + } + op->children[0] = left_pushdown.Rewrite(std::move(op->children[0])); + op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); + return PushFinalFilters(std::move(op)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/regex_range_filter.cpp b/src/duckdb/src/optimizer/regex_range_filter.cpp new file mode 100644 index 00000000..19d0e059 --- /dev/null +++ b/src/duckdb/src/optimizer/regex_range_filter.cpp @@ -0,0 +1,62 @@ +#include "duckdb/optimizer/regex_range_filter.hpp" + +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" + +#include "duckdb/function/scalar/string_functions.hpp" + +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" + +#include "duckdb/function/scalar/regexp.hpp" + +namespace duckdb { + +unique_ptr RegexRangeFilter::Rewrite(unique_ptr op) { + + for (idx_t child_idx = 0; child_idx < op->children.size(); child_idx++) { + op->children[child_idx] = Rewrite(std::move(op->children[child_idx])); + } + + if (op->type != LogicalOperatorType::LOGICAL_FILTER) { + return op; + } + + auto new_filter = make_uniq(); + + for (auto &expr : op->expressions) { + if (expr->type == ExpressionType::BOUND_FUNCTION) { + auto &func = expr->Cast(); + if (func.function.name != "regexp_full_match" || func.children.size() != 2) { + continue; + } + auto &info = func.bind_info->Cast(); + if (!info.range_success) { + continue; + } + auto filter_left = make_uniq( + ExpressionType::COMPARE_GREATERTHANOREQUALTO, func.children[0]->Copy(), + make_uniq(Value::BLOB_RAW(info.range_min))); + auto filter_right = make_uniq( + ExpressionType::COMPARE_LESSTHANOREQUALTO, func.children[0]->Copy(), + make_uniq(Value::BLOB_RAW(info.range_max))); + auto filter_expr = make_uniq(ExpressionType::CONJUNCTION_AND, + std::move(filter_left), std::move(filter_right)); + + new_filter->expressions.push_back(std::move(filter_expr)); + } + } + + if (!new_filter->expressions.empty()) { + new_filter->children = std::move(op->children); + op->children.clear(); + op->children.push_back(std::move(new_filter)); + } + + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/remove_duplicate_groups.cpp b/src/duckdb/src/optimizer/remove_duplicate_groups.cpp new file mode 100644 index 00000000..5ca5d2ad --- /dev/null +++ b/src/duckdb/src/optimizer/remove_duplicate_groups.cpp @@ -0,0 +1,127 @@ +#include "duckdb/optimizer/remove_duplicate_groups.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" + +namespace duckdb { + +void RemoveDuplicateGroups::VisitOperator(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + VisitAggregate(op.Cast()); + break; + default: + break; + } + LogicalOperatorVisitor::VisitOperatorExpressions(op); + LogicalOperatorVisitor::VisitOperatorChildren(op); +} + +void RemoveDuplicateGroups::VisitAggregate(LogicalAggregate &aggr) { + if (!aggr.grouping_functions.empty()) { + return; + } + + auto &groups = aggr.groups; + + column_binding_map_t duplicate_map; + vector> duplicates; + for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { + const auto &group = groups[group_idx]; + if (group->type != ExpressionType::BOUND_COLUMN_REF) { + continue; + } + const auto &colref = group->Cast(); + const auto &binding = colref.binding; + const auto it = duplicate_map.find(binding); + if (it == duplicate_map.end()) { + duplicate_map.emplace(binding, group_idx); + } else { + duplicates.emplace_back(it->second, group_idx); + } + } + + if (duplicates.empty()) { + return; + } + + // Sort duplicates by max duplicate group idx, because we want to remove groups from the back + sort(duplicates.begin(), duplicates.end(), + [](const pair &lhs, const pair &rhs) { return lhs.second > rhs.second; }); + + // Now we want to remove the duplicates, but this alters the column bindings coming out of the aggregate, + // so we keep track of how they shift and do another round of column binding replacements + column_binding_map_t group_binding_map; + for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { + group_binding_map.emplace(ColumnBinding(aggr.group_index, group_idx), + ColumnBinding(aggr.group_index, group_idx)); + } + + for (idx_t duplicate_idx = 0; duplicate_idx < duplicates.size(); duplicate_idx++) { + const auto &duplicate = duplicates[duplicate_idx]; + const auto &remaining_idx = duplicate.first; + const auto &removed_idx = duplicate.second; + + // Store expression and remove it from groups + stored_expressions.emplace_back(std::move(groups[removed_idx])); + groups.erase(groups.begin() + removed_idx); + + // This optimizer should run before statistics propagation, so this should be empty + // If it runs after, then group_stats should be updated too + D_ASSERT(aggr.group_stats.empty()); + + // Remove from grouping sets too + for (auto &grouping_set : aggr.grouping_sets) { + // Replace removed group with duplicate remaining group + if (grouping_set.erase(removed_idx) != 0) { + grouping_set.insert(remaining_idx); + } + + // Indices shifted: Reinsert groups in the set with group_idx - 1 + vector group_indices_to_reinsert; + for (auto &entry : grouping_set) { + if (entry > removed_idx) { + group_indices_to_reinsert.emplace_back(entry); + } + } + for (const auto group_idx : group_indices_to_reinsert) { + grouping_set.erase(group_idx); + } + for (const auto group_idx : group_indices_to_reinsert) { + grouping_set.insert(group_idx - 1); + } + } + + // Update mapping + auto it = group_binding_map.find(ColumnBinding(aggr.group_index, removed_idx)); + D_ASSERT(it != group_binding_map.end()); + it->second.column_index = remaining_idx; + + for (auto &map_entry : group_binding_map) { + auto &new_binding = map_entry.second; + if (new_binding.column_index > removed_idx) { + new_binding.column_index--; + } + } + } + + // Replace all references to the old group binding with the new group binding + for (const auto &map_entry : group_binding_map) { + auto it = column_references.find(map_entry.first); + if (it != column_references.end()) { + for (auto expr : it->second) { + expr.get().binding = map_entry.second; + } + } + } +} + +unique_ptr RemoveDuplicateGroups::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + // add a column reference + column_references[expr.binding].push_back(expr); + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/remove_unused_columns.cpp b/src/duckdb/src/optimizer/remove_unused_columns.cpp new file mode 100644 index 00000000..4291e174 --- /dev/null +++ b/src/duckdb/src/optimizer/remove_unused_columns.cpp @@ -0,0 +1,337 @@ +#include "duckdb/optimizer/remove_unused_columns.hpp" + +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_order.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" + +namespace duckdb { + +void RemoveUnusedColumns::ReplaceBinding(ColumnBinding current_binding, ColumnBinding new_binding) { + auto colrefs = column_references.find(current_binding); + if (colrefs != column_references.end()) { + for (auto &colref : colrefs->second) { + D_ASSERT(colref->binding == current_binding); + colref->binding = new_binding; + } + } +} + +template +void RemoveUnusedColumns::ClearUnusedExpressions(vector &list, idx_t table_idx, bool replace) { + idx_t offset = 0; + for (idx_t col_idx = 0; col_idx < list.size(); col_idx++) { + auto current_binding = ColumnBinding(table_idx, col_idx + offset); + auto entry = column_references.find(current_binding); + if (entry == column_references.end()) { + // this entry is not referred to, erase it from the set of expressions + list.erase(list.begin() + col_idx); + offset++; + col_idx--; + } else if (offset > 0 && replace) { + // column is used but the ColumnBinding has changed because of removed columns + ReplaceBinding(current_binding, ColumnBinding(table_idx, col_idx)); + } + } +} + +void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + // aggregate + if (!everything_referenced) { + // FIXME: groups that are not referenced need to stay -> but they don't need to be scanned and output! + auto &aggr = op.Cast(); + ClearUnusedExpressions(aggr.expressions, aggr.aggregate_index); + if (aggr.expressions.empty() && aggr.groups.empty()) { + // removed all expressions from the aggregate: push a COUNT(*) + auto count_star_fun = CountStarFun::GetFunction(); + FunctionBinder function_binder(context); + aggr.expressions.push_back( + function_binder.BindAggregateFunction(count_star_fun, {}, nullptr, AggregateType::NON_DISTINCT)); + } + } + + // then recurse into the children of the aggregate + RemoveUnusedColumns remove(binder, context); + remove.VisitOperatorExpressions(op); + remove.VisitOperator(*op.children[0]); + return; + } + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + if (!everything_referenced) { + auto &comp_join = op.Cast(); + + if (comp_join.join_type != JoinType::INNER) { + break; + } + // for inner joins with equality predicates in the form of (X=Y) + // we can replace any references to the RHS (Y) to references to the LHS (X) + // this reduces the amount of columns we need to extract from the join hash table + for (auto &cond : comp_join.conditions) { + if (cond.comparison == ExpressionType::COMPARE_EQUAL) { + if (cond.left->expression_class == ExpressionClass::BOUND_COLUMN_REF && + cond.right->expression_class == ExpressionClass::BOUND_COLUMN_REF) { + // comparison join between two bound column refs + // we can replace any reference to the RHS (build-side) with a reference to the LHS (probe-side) + auto &lhs_col = cond.left->Cast(); + auto &rhs_col = cond.right->Cast(); + // if there are any columns that refer to the RHS, + auto colrefs = column_references.find(rhs_col.binding); + if (colrefs != column_references.end()) { + for (auto &entry : colrefs->second) { + entry->binding = lhs_col.binding; + column_references[lhs_col.binding].push_back(entry); + } + column_references.erase(rhs_col.binding); + } + } + } + } + } + break; + } + case LogicalOperatorType::LOGICAL_ANY_JOIN: + break; + case LogicalOperatorType::LOGICAL_UNION: + if (!everything_referenced) { + // for UNION we can remove unreferenced columns as long as everything_referenced is false (i.e. we + // encounter a UNION node that is not preceded by a DISTINCT) + // this happens when UNION ALL is used + auto &setop = op.Cast(); + vector entries; + for (idx_t i = 0; i < setop.column_count; i++) { + entries.push_back(i); + } + ClearUnusedExpressions(entries, setop.table_index); + if (entries.size() < setop.column_count) { + if (entries.empty()) { + // no columns referenced: this happens in the case of a COUNT(*) + // extract the first column + entries.push_back(0); + } + // columns were cleared + setop.column_count = entries.size(); + + for (idx_t child_idx = 0; child_idx < op.children.size(); child_idx++) { + RemoveUnusedColumns remove(binder, context, true); + auto &child = op.children[child_idx]; + + // we push a projection under this child that references the required columns of the union + child->ResolveOperatorTypes(); + auto bindings = child->GetColumnBindings(); + vector> expressions; + expressions.reserve(entries.size()); + for (auto &column_idx : entries) { + expressions.push_back( + make_uniq(child->types[column_idx], bindings[column_idx])); + } + auto new_projection = + make_uniq(binder.GenerateTableIndex(), std::move(expressions)); + new_projection->children.push_back(std::move(child)); + op.children[child_idx] = std::move(new_projection); + + remove.VisitOperator(*op.children[child_idx]); + } + return; + } + } + for (auto &child : op.children) { + RemoveUnusedColumns remove(binder, context, true); + remove.VisitOperator(*child); + } + return; + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + // for INTERSECT/EXCEPT operations we can't remove anything, just recursively visit the children + for (auto &child : op.children) { + RemoveUnusedColumns remove(binder, context, true); + remove.VisitOperator(*child); + } + return; + case LogicalOperatorType::LOGICAL_ORDER_BY: + if (!everything_referenced) { + auto &order = op.Cast(); + D_ASSERT(order.projections.empty()); // should not yet be set + const auto all_bindings = order.GetColumnBindings(); + + for (idx_t col_idx = 0; col_idx < all_bindings.size(); col_idx++) { + if (column_references.find(all_bindings[col_idx]) != column_references.end()) { + order.projections.push_back(col_idx); + } + } + } + for (auto &child : op.children) { + RemoveUnusedColumns remove(binder, context, true); + remove.VisitOperator(*child); + } + return; + case LogicalOperatorType::LOGICAL_PROJECTION: { + if (!everything_referenced) { + auto &proj = op.Cast(); + ClearUnusedExpressions(proj.expressions, proj.table_index); + + if (proj.expressions.empty()) { + // nothing references the projected expressions + // this happens in the case of e.g. EXISTS(SELECT * FROM ...) + // in this case we only need to project a single constant + proj.expressions.push_back(make_uniq(Value::INTEGER(42))); + } + } + // then recurse into the children of this projection + RemoveUnusedColumns remove(binder, context); + remove.VisitOperatorExpressions(op); + remove.VisitOperator(*op.children[0]); + return; + } + case LogicalOperatorType::LOGICAL_INSERT: + case LogicalOperatorType::LOGICAL_UPDATE: + case LogicalOperatorType::LOGICAL_DELETE: { + //! When RETURNING is used, a PROJECTION is the top level operator for INSERTS, UPDATES, and DELETES + //! We still need to project all values from these operators so the projection + //! on top of them can select from only the table values being inserted. + //! TODO: Push down the projections from the returning statement + //! TODO: Be careful because you might be adding expressions when a user returns * + RemoveUnusedColumns remove(binder, context, true); + remove.VisitOperatorExpressions(op); + remove.VisitOperator(*op.children[0]); + return; + } + case LogicalOperatorType::LOGICAL_GET: + LogicalOperatorVisitor::VisitOperatorExpressions(op); + if (!everything_referenced) { + auto &get = op.Cast(); + if (!get.function.projection_pushdown) { + return; + } + + // Create "selection vector" of all column ids + vector proj_sel; + for (idx_t col_idx = 0; col_idx < get.column_ids.size(); col_idx++) { + proj_sel.push_back(col_idx); + } + // Create a copy that we can use to match ids later + auto col_sel = proj_sel; + // Clear unused ids, exclude filter columns that are projected out immediately + ClearUnusedExpressions(proj_sel, get.table_index, false); + + // for every table filter, push a column binding into the column references map to prevent the column from + // being projected out + for (auto &filter : get.table_filters.filters) { + idx_t index = DConstants::INVALID_INDEX; + for (idx_t i = 0; i < get.column_ids.size(); i++) { + if (get.column_ids[i] == filter.first) { + index = i; + break; + } + } + if (index == DConstants::INVALID_INDEX) { + throw InternalException("Could not find column index for table filter"); + } + ColumnBinding filter_binding(get.table_index, index); + if (column_references.find(filter_binding) == column_references.end()) { + column_references.insert(make_pair(filter_binding, vector())); + } + } + + // Clear unused ids, include filter columns that are projected out immediately + ClearUnusedExpressions(col_sel, get.table_index); + + // Now set the column ids in the LogicalGet using the "selection vector" + vector column_ids; + column_ids.reserve(col_sel.size()); + for (auto col_sel_idx : col_sel) { + column_ids.push_back(get.column_ids[col_sel_idx]); + } + get.column_ids = std::move(column_ids); + + if (get.function.filter_prune) { + // Now set the projection cols by matching the "selection vector" that excludes filter columns + // with the "selection vector" that includes filter columns + idx_t col_idx = 0; + for (auto proj_sel_idx : proj_sel) { + for (; col_idx < col_sel.size(); col_idx++) { + if (proj_sel_idx == col_sel[col_idx]) { + get.projection_ids.push_back(col_idx); + break; + } + } + } + } + + if (get.column_ids.empty()) { + // this generally means we are only interested in whether or not anything exists in the table (e.g. + // EXISTS(SELECT * FROM tbl)) in this case, we just scan the row identifier column as it means we do not + // need to read any of the columns + get.column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); + } + } + return; + case LogicalOperatorType::LOGICAL_FILTER: { + auto &filter = op.Cast(); + if (!filter.projection_map.empty()) { + // if we have any entries in the filter projection map don't prune any columns + // FIXME: we can do something more clever here + everything_referenced = true; + } + break; + } + case LogicalOperatorType::LOGICAL_DISTINCT: { + // distinct, all projected columns are used for the DISTINCT computation + // mark all columns as used and continue to the children + // FIXME: DISTINCT with expression list does not implicitly reference everything + everything_referenced = true; + break; + } + case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: { + everything_referenced = true; + break; + } + case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: { + everything_referenced = true; + break; + } + case LogicalOperatorType::LOGICAL_CTE_REF: { + everything_referenced = true; + break; + } + case LogicalOperatorType::LOGICAL_PIVOT: { + everything_referenced = true; + break; + } + default: + break; + } + LogicalOperatorVisitor::VisitOperatorExpressions(op); + LogicalOperatorVisitor::VisitOperatorChildren(op); +} + +unique_ptr RemoveUnusedColumns::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + // add a column reference + column_references[expr.binding].push_back(&expr); + return nullptr; +} + +unique_ptr RemoveUnusedColumns::VisitReplace(BoundReferenceExpression &expr, + unique_ptr *expr_ptr) { + // BoundReferenceExpression should not be used here yet, they only belong in the physical plan + throw InternalException("BoundReferenceExpression should not be used here yet!"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/arithmetic_simplification.cpp b/src/duckdb/src/optimizer/rule/arithmetic_simplification.cpp new file mode 100644 index 00000000..b319c60d --- /dev/null +++ b/src/duckdb/src/optimizer/rule/arithmetic_simplification.cpp @@ -0,0 +1,73 @@ +#include "duckdb/optimizer/rule/arithmetic_simplification.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on an OperatorExpression that has a ConstantExpression as child + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->matchers.push_back(make_uniq()); + op->policy = SetMatcher::Policy::SOME; + // we only match on simple arithmetic expressions (+, -, *, /) + op->function = make_uniq(unordered_set {"+", "-", "*", "//"}); + // and only with numeric results + op->type = make_uniq(); + op->matchers[0]->type = make_uniq(); + op->matchers[1]->type = make_uniq(); + root = std::move(op); +} + +unique_ptr ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get().Cast(); + auto &constant = bindings[1].get().Cast(); + int constant_child = root.children[0].get() == &constant ? 0 : 1; + D_ASSERT(root.children.size() == 2); + (void)root; + // any arithmetic operator involving NULL is always NULL + if (constant.value.IsNull()) { + return make_uniq(Value(root.return_type)); + } + auto &func_name = root.function.name; + if (func_name == "+") { + if (constant.value == 0) { + // addition with 0 + // we can remove the entire operator and replace it with the non-constant child + return std::move(root.children[1 - constant_child]); + } + } else if (func_name == "-") { + if (constant_child == 1 && constant.value == 0) { + // subtraction by 0 + // we can remove the entire operator and replace it with the non-constant child + return std::move(root.children[1 - constant_child]); + } + } else if (func_name == "*") { + if (constant.value == 1) { + // multiply with 1, replace with non-constant child + return std::move(root.children[1 - constant_child]); + } else if (constant.value == 0) { + // multiply by zero: replace with constant or null + return ExpressionRewriter::ConstantOrNull(std::move(root.children[1 - constant_child]), + Value::Numeric(root.return_type, 0)); + } + } else if (func_name == "//") { + if (constant_child == 1) { + if (constant.value == 1) { + // divide by 1, replace with non-constant child + return std::move(root.children[1 - constant_child]); + } else if (constant.value == 0) { + // divide by 0, replace with NULL + return make_uniq(Value(root.return_type)); + } + } + } else { + throw InternalException("Unrecognized function name in ArithmeticSimplificationRule"); + } + return nullptr; +} +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/case_simplification.cpp b/src/duckdb/src/optimizer/rule/case_simplification.cpp new file mode 100644 index 00000000..61c6ed35 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/case_simplification.cpp @@ -0,0 +1,47 @@ +#include "duckdb/optimizer/rule/case_simplification.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" + +namespace duckdb { + +CaseSimplificationRule::CaseSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a CaseExpression that has a ConstantExpression as a check + auto op = make_uniq(); + root = std::move(op); +} + +unique_ptr CaseSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get().Cast(); + for (idx_t i = 0; i < root.case_checks.size(); i++) { + auto &case_check = root.case_checks[i]; + if (case_check.when_expr->IsFoldable()) { + // the WHEN check is a foldable expression + // use an ExpressionExecutor to execute the expression + auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *case_check.when_expr); + + // fold based on the constant condition + auto condition = constant_value.DefaultCastAs(LogicalType::BOOLEAN); + if (condition.IsNull() || !BooleanValue::Get(condition)) { + // the condition is always false: remove this case check + root.case_checks.erase(root.case_checks.begin() + i); + i--; + } else { + // the condition is always true + // move the THEN clause to the ELSE of the case + root.else_expr = std::move(case_check.then_expr); + // remove this case check and any case checks after this one + root.case_checks.erase(root.case_checks.begin() + i, root.case_checks.end()); + break; + } + } + } + if (root.case_checks.empty()) { + // no case checks left: return the ELSE expression + return std::move(root.else_expr); + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/comparison_simplification.cpp b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp new file mode 100644 index 00000000..c01aa982 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp @@ -0,0 +1,78 @@ +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/optimizer/rule/comparison_simplification.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +ComparisonSimplificationRule::ComparisonSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a ComparisonExpression that has a ConstantExpression as a check + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->policy = SetMatcher::Policy::SOME; + root = std::move(op); +} + +unique_ptr ComparisonSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &expr = bindings[0].get().Cast(); + auto &constant_expr = bindings[1].get(); + bool column_ref_left = expr.left.get() != &constant_expr; + auto column_ref_expr = !column_ref_left ? expr.right.get() : expr.left.get(); + // the constant_expr is a scalar expression that we have to fold + // use an ExpressionExecutor to execute the expression + D_ASSERT(constant_expr.IsFoldable()); + Value constant_value; + if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), constant_expr, constant_value)) { + return nullptr; + } + if (constant_value.IsNull() && !(expr.type == ExpressionType::COMPARE_NOT_DISTINCT_FROM || + expr.type == ExpressionType::COMPARE_DISTINCT_FROM)) { + // comparison with constant NULL, return NULL + return make_uniq(Value(LogicalType::BOOLEAN)); + } + if (column_ref_expr->expression_class == ExpressionClass::BOUND_CAST) { + //! Here we check if we can apply the expression on the constant side + //! We can do this if the cast itself is invertible and casting the constant is + //! invertible in practice. + auto &cast_expression = column_ref_expr->Cast(); + auto target_type = cast_expression.source_type(); + if (!BoundCastExpression::CastIsInvertible(target_type, cast_expression.return_type)) { + return nullptr; + } + + // Can we cast the constant at all? + string error_message; + Value cast_constant; + auto new_constant = constant_value.DefaultTryCastAs(target_type, cast_constant, &error_message, true); + if (!new_constant) { + return nullptr; + } + + // Is the constant cast invertible? + if (!cast_constant.IsNull() && + !BoundCastExpression::CastIsInvertible(cast_expression.return_type, target_type)) { + // Is it actually invertible? + Value uncast_constant; + if (!cast_constant.DefaultTryCastAs(constant_value.type(), uncast_constant, &error_message, true) || + uncast_constant != constant_value) { + return nullptr; + } + } + + //! We can cast, now we change our column_ref_expression from an operator cast to a column reference + auto child_expression = std::move(cast_expression.child); + auto new_constant_expr = make_uniq(cast_constant); + if (column_ref_left) { + expr.left = std::move(child_expression); + expr.right = std::move(new_constant_expr); + } else { + expr.left = std::move(new_constant_expr); + expr.right = std::move(child_expression); + } + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/conjunction_simplification.cpp b/src/duckdb/src/optimizer/rule/conjunction_simplification.cpp new file mode 100644 index 00000000..070237cb --- /dev/null +++ b/src/duckdb/src/optimizer/rule/conjunction_simplification.cpp @@ -0,0 +1,70 @@ +#include "duckdb/optimizer/rule/conjunction_simplification.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +ConjunctionSimplificationRule::ConjunctionSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a ComparisonExpression that has a ConstantExpression as a check + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->policy = SetMatcher::Policy::SOME; + root = std::move(op); +} + +unique_ptr ConjunctionSimplificationRule::RemoveExpression(BoundConjunctionExpression &conj, + const Expression &expr) { + for (idx_t i = 0; i < conj.children.size(); i++) { + if (conj.children[i].get() == &expr) { + // erase the expression + conj.children.erase(conj.children.begin() + i); + break; + } + } + if (conj.children.size() == 1) { + // one expression remaining: simply return that expression and erase the conjunction + return std::move(conj.children[0]); + } + return nullptr; +} + +unique_ptr ConjunctionSimplificationRule::Apply(LogicalOperator &op, + vector> &bindings, bool &changes_made, + bool is_root) { + auto &conjunction = bindings[0].get().Cast(); + auto &constant_expr = bindings[1].get(); + // the constant_expr is a scalar expression that we have to fold + // use an ExpressionExecutor to execute the expression + D_ASSERT(constant_expr.IsFoldable()); + Value constant_value; + if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), constant_expr, constant_value)) { + return nullptr; + } + constant_value = constant_value.DefaultCastAs(LogicalType::BOOLEAN); + if (constant_value.IsNull()) { + // we can't simplify conjunctions with a constant NULL + return nullptr; + } + if (conjunction.type == ExpressionType::CONJUNCTION_AND) { + if (!BooleanValue::Get(constant_value)) { + // FALSE in AND, result of expression is false + return make_uniq(Value::BOOLEAN(false)); + } else { + // TRUE in AND, remove the expression from the set + return RemoveExpression(conjunction, constant_expr); + } + } else { + D_ASSERT(conjunction.type == ExpressionType::CONJUNCTION_OR); + if (!BooleanValue::Get(constant_value)) { + // FALSE in OR, remove the expression from the set + return RemoveExpression(conjunction, constant_expr); + } else { + // TRUE in OR, result of expression is true + return make_uniq(Value::BOOLEAN(true)); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/constant_folding.cpp b/src/duckdb/src/optimizer/rule/constant_folding.cpp new file mode 100644 index 00000000..6b7d20c4 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/constant_folding.cpp @@ -0,0 +1,43 @@ +#include "duckdb/optimizer/rule/constant_folding.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +//! The ConstantFoldingExpressionMatcher matches on any scalar expression (i.e. Expression::IsFoldable is true) +class ConstantFoldingExpressionMatcher : public FoldableConstantMatcher { +public: + bool Match(Expression &expr, vector> &bindings) override { + // we also do not match on ConstantExpressions, because we cannot fold those any further + if (expr.type == ExpressionType::VALUE_CONSTANT) { + return false; + } + return FoldableConstantMatcher::Match(expr, bindings); + } +}; + +ConstantFoldingRule::ConstantFoldingRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto op = make_uniq(); + root = std::move(op); +} + +unique_ptr ConstantFoldingRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get(); + // the root is a scalar expression that we have to fold + D_ASSERT(root.IsFoldable() && root.type != ExpressionType::VALUE_CONSTANT); + + // use an ExpressionExecutor to execute the expression + Value result_value; + if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), root, result_value)) { + return nullptr; + } + D_ASSERT(result_value.type().InternalType() == root.return_type.InternalType()); + // now get the value from the result vector and insert it back into the plan as a constant expression + return make_uniq(result_value); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/date_part_simplification.cpp b/src/duckdb/src/optimizer/rule/date_part_simplification.cpp new file mode 100644 index 00000000..037e7a63 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/date_part_simplification.cpp @@ -0,0 +1,104 @@ +#include "duckdb/optimizer/rule/date_part_simplification.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +DatePartSimplificationRule::DatePartSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto func = make_uniq(); + func->function = make_uniq("date_part"); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::ORDERED; + root = std::move(func); +} + +unique_ptr DatePartSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &date_part = bindings[0].get().Cast(); + auto &constant_expr = bindings[1].get().Cast(); + auto &constant = constant_expr.value; + + if (constant.IsNull()) { + // NULL specifier: return constant NULL + return make_uniq(Value(date_part.return_type)); + } + // otherwise check the specifier + auto specifier = GetDatePartSpecifier(StringValue::Get(constant)); + string new_function_name; + switch (specifier) { + case DatePartSpecifier::YEAR: + new_function_name = "year"; + break; + case DatePartSpecifier::MONTH: + new_function_name = "month"; + break; + case DatePartSpecifier::DAY: + new_function_name = "day"; + break; + case DatePartSpecifier::DECADE: + new_function_name = "decade"; + break; + case DatePartSpecifier::CENTURY: + new_function_name = "century"; + break; + case DatePartSpecifier::MILLENNIUM: + new_function_name = "millennium"; + break; + case DatePartSpecifier::QUARTER: + new_function_name = "quarter"; + break; + case DatePartSpecifier::WEEK: + new_function_name = "week"; + break; + case DatePartSpecifier::YEARWEEK: + new_function_name = "yearweek"; + break; + case DatePartSpecifier::DOW: + new_function_name = "dayofweek"; + break; + case DatePartSpecifier::ISODOW: + new_function_name = "isodow"; + break; + case DatePartSpecifier::DOY: + new_function_name = "dayofyear"; + break; + case DatePartSpecifier::MICROSECONDS: + new_function_name = "microsecond"; + break; + case DatePartSpecifier::MILLISECONDS: + new_function_name = "millisecond"; + break; + case DatePartSpecifier::SECOND: + new_function_name = "second"; + break; + case DatePartSpecifier::MINUTE: + new_function_name = "minute"; + break; + case DatePartSpecifier::HOUR: + new_function_name = "hour"; + break; + default: + return nullptr; + } + // found a replacement function: bind it + vector> children; + children.push_back(std::move(date_part.children[1])); + + string error; + FunctionBinder binder(rewriter.context); + auto function = binder.BindScalarFunction(DEFAULT_SCHEMA, new_function_name, std::move(children), error, false); + if (!function) { + throw BinderException(error); + } + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/distributivity.cpp b/src/duckdb/src/optimizer/rule/distributivity.cpp new file mode 100644 index 00000000..509960c0 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/distributivity.cpp @@ -0,0 +1,135 @@ +#include "duckdb/optimizer/rule/distributivity.hpp" + +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" + +namespace duckdb { + +DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // we match on an OR expression within a LogicalFilter node + root = make_uniq(); + root->expr_type = make_uniq(ExpressionType::CONJUNCTION_OR); +} + +void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { + if (expr.type == ExpressionType::CONJUNCTION_AND) { + auto &and_expr = expr.Cast(); + for (auto &child : and_expr.children) { + set.insert(*child); + } + } else { + set.insert(expr); + } +} + +unique_ptr DistributivityRule::ExtractExpression(BoundConjunctionExpression &conj, idx_t idx, + Expression &expr) { + auto &child = conj.children[idx]; + unique_ptr result; + if (child->type == ExpressionType::CONJUNCTION_AND) { + // AND, remove expression from the list + auto &and_expr = child->Cast(); + for (idx_t i = 0; i < and_expr.children.size(); i++) { + if (and_expr.children[i]->Equals(expr)) { + result = std::move(and_expr.children[i]); + and_expr.children.erase(and_expr.children.begin() + i); + break; + } + } + if (and_expr.children.size() == 1) { + conj.children[idx] = std::move(and_expr.children[0]); + } + } else { + // not an AND node! remove the entire expression + // this happens in the case of e.g. (X AND B) OR X + D_ASSERT(child->Equals(expr)); + result = std::move(child); + conj.children[idx] = nullptr; + } + D_ASSERT(result); + return result; +} + +unique_ptr DistributivityRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &initial_or = bindings[0].get().Cast(); + + // we want to find expressions that occur in each of the children of the OR + // i.e. (X AND A) OR (X AND B) => X occurs in all branches + // first, for the initial child, we create an expression set of which expressions occur + // this is our initial candidate set (in the example: [X, A]) + expression_set_t candidate_set; + AddExpressionSet(*initial_or.children[0], candidate_set); + // now for each of the remaining children, we create a set again and intersect them + // in our example: the second set would be [X, B] + // the intersection would leave [X] + for (idx_t i = 1; i < initial_or.children.size(); i++) { + expression_set_t next_set; + AddExpressionSet(*initial_or.children[i], next_set); + expression_set_t intersect_result; + for (auto &expr : candidate_set) { + if (next_set.find(expr) != next_set.end()) { + intersect_result.insert(expr); + } + } + candidate_set = intersect_result; + } + if (candidate_set.empty()) { + // nothing found: abort + return nullptr; + } + // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of + // the OR + auto new_root = make_uniq(ExpressionType::CONJUNCTION_AND); + for (auto &expr : candidate_set) { + D_ASSERT(initial_or.children.size() > 0); + + // extract the expression from the first child of the OR + auto result = ExtractExpression(initial_or, 0, expr.get()); + // now for the subsequent expressions, simply remove the expression + for (idx_t i = 1; i < initial_or.children.size(); i++) { + ExtractExpression(initial_or, i, *result); + } + // now we add the expression to the new root + new_root->children.push_back(std::move(result)); + } + + // check if we completely erased one of the children of the OR + // this happens if we have an OR in the form of "X OR (X AND A)" + // the left child will be completely empty, as it only contains common expressions + // in this case, any other children are not useful: + // X OR (X AND A) is the same as "X" + // since (1) only tuples that do not qualify "X" will not pass this predicate + // and (2) all tuples that qualify "X" will pass this predicate + for (idx_t i = 0; i < initial_or.children.size(); i++) { + if (!initial_or.children[i]) { + if (new_root->children.size() <= 1) { + return std::move(new_root->children[0]); + } else { + return std::move(new_root); + } + } + } + // finally we need to add the remaining expressions in the OR to the new root + if (initial_or.children.size() == 1) { + // one child: skip the OR entirely and only add the single child + new_root->children.push_back(std::move(initial_or.children[0])); + } else if (initial_or.children.size() > 1) { + // multiple children still remain: push them into a new OR and add that to the new root + auto new_or = make_uniq(ExpressionType::CONJUNCTION_OR); + for (auto &child : initial_or.children) { + new_or->children.push_back(std::move(child)); + } + new_root->children.push_back(std::move(new_or)); + } + // finally return the new root + if (new_root->children.size() == 1) { + return std::move(new_root->children[0]); + } + return std::move(new_root); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/empty_needle_removal.cpp b/src/duckdb/src/optimizer/rule/empty_needle_removal.cpp new file mode 100644 index 00000000..500d639a --- /dev/null +++ b/src/duckdb/src/optimizer/rule/empty_needle_removal.cpp @@ -0,0 +1,54 @@ +#include "duckdb/optimizer/rule/empty_needle_removal.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +EmptyNeedleRemovalRule::EmptyNeedleRemovalRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a FunctionExpression that has a foldable ConstantExpression + auto func = make_uniq(); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::SOME; + + unordered_set functions = {"prefix", "contains", "suffix"}; + func->function = make_uniq(functions); + root = std::move(func); +} + +unique_ptr EmptyNeedleRemovalRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get().Cast(); + D_ASSERT(root.children.size() == 2); + auto &prefix_expr = bindings[2].get(); + + // the constant_expr is a scalar expression that we have to fold + if (!prefix_expr.IsFoldable()) { + return nullptr; + } + D_ASSERT(root.return_type.id() == LogicalTypeId::BOOLEAN); + + auto prefix_value = ExpressionExecutor::EvaluateScalar(GetContext(), prefix_expr); + + if (prefix_value.IsNull()) { + return make_uniq(Value(LogicalType::BOOLEAN)); + } + + D_ASSERT(prefix_value.type() == prefix_expr.return_type); + auto &needle_string = StringValue::Get(prefix_value); + + // PREFIX('xyz', '') is TRUE + // PREFIX(NULL, '') is NULL + // so rewrite PREFIX(x, '') to TRUE_OR_NULL(x) + if (needle_string.empty()) { + return ExpressionRewriter::ConstantOrNull(std::move(root.children[0]), Value::BOOLEAN(true)); + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/enum_comparison.cpp b/src/duckdb/src/optimizer/rule/enum_comparison.cpp new file mode 100644 index 00000000..8b052578 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/enum_comparison.cpp @@ -0,0 +1,70 @@ +#include "duckdb/optimizer/rule/enum_comparison.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/optimizer/matcher/type_matcher_id.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +EnumComparisonRule::EnumComparisonRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a ComparisonExpression that is an Equality and has a VARCHAR and ENUM as its children + auto op = make_uniq(); + // Enum requires expression to be root + op->expr_type = make_uniq(ExpressionType::COMPARE_EQUAL); + for (idx_t i = 0; i < 2; i++) { + auto child = make_uniq(); + child->type = make_uniq(LogicalTypeId::VARCHAR); + child->matcher = make_uniq(); + child->matcher->type = make_uniq(LogicalTypeId::ENUM); + op->matchers.push_back(std::move(child)); + } + root = std::move(op); +} + +bool AreMatchesPossible(LogicalType &left, LogicalType &right) { + LogicalType *small_enum, *big_enum; + if (EnumType::GetSize(left) < EnumType::GetSize(right)) { + small_enum = &left; + big_enum = &right; + } else { + small_enum = &right; + big_enum = &left; + } + auto &string_vec = EnumType::GetValuesInsertOrder(*small_enum); + auto string_vec_ptr = FlatVector::GetData(string_vec); + auto size = EnumType::GetSize(*small_enum); + for (idx_t i = 0; i < size; i++) { + auto key = string_vec_ptr[i].GetString(); + if (EnumType::GetPos(*big_enum, key) != -1) { + return true; + } + } + return false; +} +unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + + auto &root = bindings[0].get().Cast(); + auto &left_child = bindings[1].get().Cast(); + auto &right_child = bindings[3].get().Cast(); + + if (!AreMatchesPossible(left_child.child->return_type, right_child.child->return_type)) { + vector> children; + children.push_back(std::move(root.left)); + children.push_back(std::move(root.right)); + return ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); + } + + if (!is_root || op.type != LogicalOperatorType::LOGICAL_FILTER) { + return nullptr; + } + + auto cast_left_to_right = + BoundCastExpression::AddDefaultCastToType(std::move(left_child.child), right_child.child->return_type, true); + return make_uniq(root.type, std::move(cast_left_to_right), std::move(right_child.child)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/equal_or_null_simplification.cpp b/src/duckdb/src/optimizer/rule/equal_or_null_simplification.cpp new file mode 100644 index 00000000..d53776b5 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/equal_or_null_simplification.cpp @@ -0,0 +1,109 @@ +#include "duckdb/optimizer/rule/equal_or_null_simplification.hpp" + +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" + +namespace duckdb { + +EqualOrNullSimplification::EqualOrNullSimplification(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on OR conjunction + auto op = make_uniq(); + op->expr_type = make_uniq(ExpressionType::CONJUNCTION_OR); + op->policy = SetMatcher::Policy::SOME; + + // equi comparison on one side + auto equal_child = make_uniq(); + equal_child->expr_type = make_uniq(ExpressionType::COMPARE_EQUAL); + equal_child->policy = SetMatcher::Policy::SOME; + op->matchers.push_back(std::move(equal_child)); + + // AND conjuction on the other + auto and_child = make_uniq(); + and_child->expr_type = make_uniq(ExpressionType::CONJUNCTION_AND); + and_child->policy = SetMatcher::Policy::SOME; + + // IS NULL tests inside AND + auto isnull_child = make_uniq(); + isnull_child->expr_type = make_uniq(ExpressionType::OPERATOR_IS_NULL); + // I could try to use std::make_uniq for a copy, but it's available from C++14 only + auto isnull_child2 = make_uniq(); + isnull_child2->expr_type = make_uniq(ExpressionType::OPERATOR_IS_NULL); + and_child->matchers.push_back(std::move(isnull_child)); + and_child->matchers.push_back(std::move(isnull_child2)); + + op->matchers.push_back(std::move(and_child)); + root = std::move(op); +} + +// a=b OR (a IS NULL AND b IS NULL) to a IS NOT DISTINCT FROM b +static unique_ptr TryRewriteEqualOrIsNull(Expression &equal_expr, Expression &and_expr) { + if (equal_expr.type != ExpressionType::COMPARE_EQUAL || and_expr.type != ExpressionType::CONJUNCTION_AND) { + return nullptr; + } + + auto &equal_cast = equal_expr.Cast(); + auto &and_cast = and_expr.Cast(); + + if (and_cast.children.size() != 2) { + return nullptr; + } + + // Make sure on the AND conjuction the relevant conditions appear + auto &a_exp = *equal_cast.left; + auto &b_exp = *equal_cast.right; + bool a_is_null_found = false; + bool b_is_null_found = false; + + for (const auto &item : and_cast.children) { + auto &next_exp = *item; + + if (next_exp.type == ExpressionType::OPERATOR_IS_NULL) { + auto &next_exp_cast = next_exp.Cast(); + auto &child = *next_exp_cast.children[0]; + + // Test for equality on both 'a' and 'b' expressions + if (Expression::Equals(child, a_exp)) { + a_is_null_found = true; + } else if (Expression::Equals(child, b_exp)) { + b_is_null_found = true; + } else { + return nullptr; + } + } else { + return nullptr; + } + } + if (a_is_null_found && b_is_null_found) { + return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, + std::move(equal_cast.left), std::move(equal_cast.right)); + } + return nullptr; +} + +unique_ptr EqualOrNullSimplification::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + const Expression &or_exp = bindings[0].get(); + + if (or_exp.type != ExpressionType::CONJUNCTION_OR) { + return nullptr; + } + + const auto &or_exp_cast = or_exp.Cast(); + + if (or_exp_cast.children.size() != 2) { + return nullptr; + } + + auto &left_exp = *or_exp_cast.children[0]; + auto &right_exp = *or_exp_cast.children[1]; + // Test for: a=b OR (a IS NULL AND b IS NULL) + auto first_try = TryRewriteEqualOrIsNull(left_exp, right_exp); + if (first_try) { + return first_try; + } + // Test for: (a IS NULL AND b IS NULL) OR a=b + return TryRewriteEqualOrIsNull(right_exp, left_exp); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/in_clause_simplification_rule.cpp b/src/duckdb/src/optimizer/rule/in_clause_simplification_rule.cpp new file mode 100644 index 00000000..e1ad4fd9 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/in_clause_simplification_rule.cpp @@ -0,0 +1,57 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/rule/in_clause_simplification.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" + +namespace duckdb { + +InClauseSimplificationRule::InClauseSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on InClauseExpression that has a ConstantExpression as a check + auto op = make_uniq(); + op->policy = SetMatcher::Policy::SOME; + root = std::move(op); +} + +unique_ptr InClauseSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &expr = bindings[0].get().Cast(); + if (expr.children[0]->expression_class != ExpressionClass::BOUND_CAST) { + return nullptr; + } + auto &cast_expression = expr.children[0]->Cast(); + if (cast_expression.child->expression_class != ExpressionClass::BOUND_COLUMN_REF) { + return nullptr; + } + //! Here we check if we can apply the expression on the constant side + auto target_type = cast_expression.source_type(); + if (!BoundCastExpression::CastIsInvertible(cast_expression.return_type, target_type)) { + return nullptr; + } + vector> cast_list; + //! First check if we can cast all children + for (size_t i = 1; i < expr.children.size(); i++) { + if (expr.children[i]->expression_class != ExpressionClass::BOUND_CONSTANT) { + return nullptr; + } + D_ASSERT(expr.children[i]->IsFoldable()); + auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *expr.children[i]); + auto new_constant = constant_value.DefaultTryCastAs(target_type); + if (!new_constant) { + return nullptr; + } else { + auto new_constant_expr = make_uniq(constant_value); + cast_list.push_back(std::move(new_constant_expr)); + } + } + //! We can cast, so we move the new constant + for (size_t i = 1; i < expr.children.size(); i++) { + expr.children[i] = std::move(cast_list[i - 1]); + + // expr->children[i] = std::move(new_constant_expr); + } + //! We can cast the full list, so we move the column + expr.children[0] = std::move(cast_expression.child); + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/like_optimizations.cpp b/src/duckdb/src/optimizer/rule/like_optimizations.cpp new file mode 100644 index 00000000..96f7b150 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/like_optimizations.cpp @@ -0,0 +1,162 @@ +#include "duckdb/optimizer/rule/like_optimizations.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" + +namespace duckdb { + +LikeOptimizationRule::LikeOptimizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a FunctionExpression that has a foldable ConstantExpression + auto func = make_uniq(); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::ORDERED; + // we match on LIKE ("~~") and NOT LIKE ("!~~") + func->function = make_uniq(unordered_set {"!~~", "~~"}); + root = std::move(func); +} + +static bool PatternIsConstant(const string &pattern) { + for (idx_t i = 0; i < pattern.size(); i++) { + if (pattern[i] == '%' || pattern[i] == '_') { + return false; + } + } + return true; +} + +static bool PatternIsPrefix(const string &pattern) { + idx_t i; + for (i = pattern.size(); i > 0; i--) { + if (pattern[i - 1] != '%') { + break; + } + } + if (i == pattern.size()) { + // no trailing % + // cannot be a prefix + return false; + } + // continue to look in the string + // if there is a % or _ in the string (besides at the very end) this is not a prefix match + for (; i > 0; i--) { + if (pattern[i - 1] == '%' || pattern[i - 1] == '_') { + return false; + } + } + return true; +} + +static bool PatternIsSuffix(const string &pattern) { + idx_t i; + for (i = 0; i < pattern.size(); i++) { + if (pattern[i] != '%') { + break; + } + } + if (i == 0) { + // no leading % + // cannot be a suffix + return false; + } + // continue to look in the string + // if there is a % or _ in the string (besides at the beginning) this is not a suffix match + for (; i < pattern.size(); i++) { + if (pattern[i] == '%' || pattern[i] == '_') { + return false; + } + } + return true; +} + +static bool PatternIsContains(const string &pattern) { + idx_t start; + idx_t end; + for (start = 0; start < pattern.size(); start++) { + if (pattern[start] != '%') { + break; + } + } + for (end = pattern.size(); end > 0; end--) { + if (pattern[end - 1] != '%') { + break; + } + } + if (start == 0 || end == pattern.size()) { + // contains requires both a leading AND a trailing % + return false; + } + // check if there are any other special characters in the string + // if there is a % or _ in the string (besides at the beginning/end) this is not a contains match + for (idx_t i = start; i < end; i++) { + if (pattern[i] == '%' || pattern[i] == '_') { + return false; + } + } + return true; +} + +unique_ptr LikeOptimizationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get().Cast(); + auto &constant_expr = bindings[2].get().Cast(); + D_ASSERT(root.children.size() == 2); + + if (constant_expr.value.IsNull()) { + return make_uniq(Value(root.return_type)); + } + + // the constant_expr is a scalar expression that we have to fold + if (!constant_expr.IsFoldable()) { + return nullptr; + } + + auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), constant_expr); + D_ASSERT(constant_value.type() == constant_expr.return_type); + auto &patt_str = StringValue::Get(constant_value); + + bool is_not_like = root.function.name == "!~~"; + if (PatternIsConstant(patt_str)) { + // Pattern is constant + return make_uniq(is_not_like ? ExpressionType::COMPARE_NOTEQUAL + : ExpressionType::COMPARE_EQUAL, + std::move(root.children[0]), std::move(root.children[1])); + } else if (PatternIsPrefix(patt_str)) { + // Prefix LIKE pattern : [^%_]*[%]+, ignoring underscore + return ApplyRule(root, PrefixFun::GetFunction(), patt_str, is_not_like); + } else if (PatternIsSuffix(patt_str)) { + // Suffix LIKE pattern: [%]+[^%_]*, ignoring underscore + return ApplyRule(root, SuffixFun::GetFunction(), patt_str, is_not_like); + } else if (PatternIsContains(patt_str)) { + // Contains LIKE pattern: [%]+[^%_]*[%]+, ignoring underscore + return ApplyRule(root, ContainsFun::GetFunction(), patt_str, is_not_like); + } + return nullptr; +} + +unique_ptr LikeOptimizationRule::ApplyRule(BoundFunctionExpression &expr, ScalarFunction function, + string pattern, bool is_not_like) { + // replace LIKE by an optimized function + unique_ptr result; + auto new_function = + make_uniq(expr.return_type, std::move(function), std::move(expr.children), nullptr); + + // removing "%" from the pattern + pattern.erase(std::remove(pattern.begin(), pattern.end(), '%'), pattern.end()); + + new_function->children[1] = make_uniq(Value(std::move(pattern))); + + result = std::move(new_function); + if (is_not_like) { + auto negation = make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); + negation->children.push_back(std::move(result)); + result = std::move(negation); + } + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/move_constants.cpp b/src/duckdb/src/optimizer/rule/move_constants.cpp new file mode 100644 index 00000000..9704d665 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/move_constants.cpp @@ -0,0 +1,160 @@ +#include "duckdb/optimizer/rule/move_constants.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->policy = SetMatcher::Policy::UNORDERED; + + auto arithmetic = make_uniq(); + // we handle multiplication, addition and subtraction because those are "easy" + // integer division makes the division case difficult + // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules + arithmetic->function = make_uniq(unordered_set {"+", "-", "*"}); + // we match only on integral numeric types + arithmetic->type = make_uniq(); + arithmetic->matchers.push_back(make_uniq()); + arithmetic->matchers.push_back(make_uniq()); + arithmetic->policy = SetMatcher::Policy::SOME; + op->matchers.push_back(std::move(arithmetic)); + root = std::move(op); +} + +unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &comparison = bindings[0].get().Cast(); + auto &outer_constant = bindings[1].get().Cast(); + auto &arithmetic = bindings[2].get().Cast(); + auto &inner_constant = bindings[3].get().Cast(); + if (!TypeIsIntegral(arithmetic.return_type.InternalType())) { + return nullptr; + } + if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { + return make_uniq(Value(comparison.return_type)); + } + auto &constant_type = outer_constant.return_type; + hugeint_t outer_value = IntegralValue::Get(outer_constant.value); + hugeint_t inner_value = IntegralValue::Get(inner_constant.value); + + idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; + auto &op_type = arithmetic.function.name; + if (op_type == "+") { + // [x + 1 COMP 10] OR [1 + x COMP 10] + // order does not matter in addition: + // simply change right side to 10-1 (outer_constant - inner_constant) + if (!Hugeint::SubtractInPlace(outer_value, inner_value)) { + return nullptr; + } + auto result_value = Value::HUGEINT(outer_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + if (comparison.type != ExpressionType::COMPARE_EQUAL) { + return nullptr; + } + // if the cast is not possible then the comparison is not possible + // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 + // since this is not possible we can remove the entire branch here + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + } else if (op_type == "-") { + // [x - 1 COMP 10] O R [1 - x COMP 10] + // order matters in subtraction: + if (arithmetic_child_index == 0) { + // [x - 1 COMP 10] + // change right side to 10+1 (outer_constant + inner_constant) + if (!Hugeint::AddInPlace(outer_value, inner_value)) { + return nullptr; + } + auto result_value = Value::HUGEINT(outer_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + // if the cast is not possible then an equality comparison is not possible + if (comparison.type != ExpressionType::COMPARE_EQUAL) { + return nullptr; + } + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + } else { + // [1 - x COMP 10] + // change right side to 1-10=-9 + if (!Hugeint::SubtractInPlace(inner_value, outer_value)) { + return nullptr; + } + auto result_value = Value::HUGEINT(inner_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + // if the cast is not possible then an equality comparison is not possible + if (comparison.type != ExpressionType::COMPARE_EQUAL) { + return nullptr; + } + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + // in this case, we should also flip the comparison + // e.g. if we have [4 - x < 2] then we should have [x > 2] + comparison.type = FlipComparisonExpression(comparison.type); + } + } else { + D_ASSERT(op_type == "*"); + // [x * 2 COMP 10] OR [2 * x COMP 10] + // order does not matter in multiplication: + // change right side to 10/2 (outer_constant / inner_constant) + // but ONLY if outer_constant is cleanly divisible by the inner_constant + if (inner_value == 0) { + // x * 0, the result is either 0 or NULL + // we let the arithmetic_simplification rule take care of simplifying this first + return nullptr; + } + if (outer_value % inner_value != 0) { + // not cleanly divisible + bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; + bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; + if (is_equality || is_inequality) { + // we know the values are not equal + // the result will be either FALSE or NULL (if COMPARE_EQUAL) + // or TRUE or NULL (if COMPARE_NOTEQUAL) + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(is_inequality)); + } else { + // not cleanly divisible and we are doing > >= < <=, skip the simplification for now + return nullptr; + } + } + if (inner_value < 0) { + // multiply by negative value, need to flip expression + comparison.type = FlipComparisonExpression(comparison.type); + } + // else divide the RHS by the LHS + // we need to do a range check on the cast even though we do a division + // because e.g. -128 / -1 = 128, which is out of range + auto result_value = Value::HUGEINT(outer_value / inner_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + } + // replace left side with x + // first extract x from the arithmetic expression + auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); + // then place in the comparison + if (comparison.left.get() == &outer_constant) { + comparison.right = std::move(arithmetic_child); + } else { + comparison.left = std::move(arithmetic_child); + } + changes_made = true; + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/ordered_aggregate_optimizer.cpp b/src/duckdb/src/optimizer/rule/ordered_aggregate_optimizer.cpp new file mode 100644 index 00000000..ed4a5ae3 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/ordered_aggregate_optimizer.cpp @@ -0,0 +1,30 @@ +#include "duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp" + +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +OrderedAggregateOptimizer::OrderedAggregateOptimizer(ExpressionRewriter &rewriter) : Rule(rewriter) { + // we match on an OR expression within a LogicalFilter node + root = make_uniq(); + root->expr_class = ExpressionClass::BOUND_AGGREGATE; +} + +unique_ptr OrderedAggregateOptimizer::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &aggr = bindings[0].get().Cast(); + if (!aggr.order_bys) { + // no ORDER BYs defined + return nullptr; + } + if (aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT) { + // not an order dependent aggregate but we have an ORDER BY clause - remove it + aggr.order_bys.reset(); + changes_made = true; + return nullptr; + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp new file mode 100644 index 00000000..d1f49121 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp @@ -0,0 +1,203 @@ +#include "duckdb/optimizer/rule/regex_optimizations.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/function/scalar/regexp.hpp" + +#include "re2/re2.h" +#include "re2/regexp.h" + +namespace duckdb { + +RegexOptimizationRule::RegexOptimizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto func = make_uniq(); + func->function = make_uniq("regexp_matches"); + func->policy = SetMatcher::Policy::SOME_ORDERED; + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + + root = std::move(func); +} + +struct LikeString { + bool exists = true; + bool escaped = false; + string like_string = ""; +}; + +static void AddCharacter(char chr, LikeString &ret, bool contains) { + // if we are not converting into a contains, and the string has LIKE special characters + // then don't return a possible LIKE match + // same if the character is a control character + if (iscntrl(chr) || (!contains && (chr == '%' || chr == '_'))) { + ret.exists = false; + return; + } + auto run_as_str {chr}; + ret.like_string += run_as_str; +} + +static LikeString GetLikeStringEscaped(duckdb_re2::Regexp *regexp, bool contains = false) { + D_ASSERT(regexp->op() == duckdb_re2::kRegexpLiteralString || regexp->op() == duckdb_re2::kRegexpLiteral); + LikeString ret; + + if (regexp->parse_flags() & duckdb_re2::Regexp::FoldCase || + !(regexp->parse_flags() & duckdb_re2::Regexp::OneLine)) { + // parse flags can turn on and off within a regex match, return no optimization + // For now, we just don't optimize if these every turn on. + // TODO: logic to attempt the optimization, then if the parse flags change, then abort + ret.exists = false; + return ret; + } + + // case insensitivity may be on now, but it can also turn off. + if (regexp->op() == duckdb_re2::kRegexpLiteralString) { + auto nrunes = (idx_t)regexp->nrunes(); + auto runes = regexp->runes(); + for (idx_t i = 0; i < nrunes; i++) { + char chr = toascii(runes[i]); + AddCharacter(chr, ret, contains); + if (!ret.exists) { + return ret; + } + } + } else { + auto rune = regexp->rune(); + char chr = toascii(rune); + AddCharacter(chr, ret, contains); + } + D_ASSERT(ret.like_string.size() >= 1 || !ret.exists); + return ret; +} + +static LikeString LikeMatchFromRegex(duckdb_re2::RE2 &pattern) { + LikeString ret = LikeString(); + auto num_subs = pattern.Regexp()->nsub(); + auto subs = pattern.Regexp()->sub(); + auto cur_sub_index = 0; + while (cur_sub_index < num_subs) { + switch (subs[cur_sub_index]->op()) { + case duckdb_re2::kRegexpAnyChar: + if (cur_sub_index == 0) { + ret.like_string += "%"; + } + ret.like_string += "_"; + if (cur_sub_index + 1 == num_subs) { + ret.like_string += "%"; + } + break; + case duckdb_re2::kRegexpStar: + // .* is a Star operator is a anyChar operator as a child. + // any other child operator would represent a pattern LIKE cannot match. + if (subs[cur_sub_index]->nsub() == 1 && subs[cur_sub_index]->sub()[0]->op() == duckdb_re2::kRegexpAnyChar) { + ret.like_string += "%"; + break; + } + ret.exists = false; + return ret; + case duckdb_re2::kRegexpLiteralString: + case duckdb_re2::kRegexpLiteral: { + // if this is the only matching op, we should have directly called + // GetEscapedLikeString + D_ASSERT(!(cur_sub_index == 0 && cur_sub_index + 1 == num_subs)); + if (cur_sub_index == 0) { + ret.like_string += "%"; + } + // if the kRegexpLiteral or kRegexpLiteralString is the only op to match + // the string can directly be converted into a contains + LikeString escaped_like_string = GetLikeStringEscaped(subs[cur_sub_index], false); + if (!escaped_like_string.exists) { + return escaped_like_string; + } + ret.like_string += escaped_like_string.like_string; + ret.escaped = escaped_like_string.escaped; + if (cur_sub_index + 1 == num_subs) { + ret.like_string += "%"; + } + break; + } + case duckdb_re2::kRegexpEndText: + case duckdb_re2::kRegexpEmptyMatch: + case duckdb_re2::kRegexpBeginText: { + break; + } + default: + // some other regexp op that doesn't have an equivalent to a like string + // return false; + ret.exists = false; + return ret; + } + cur_sub_index += 1; + } + return ret; +} + +unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get().Cast(); + auto &constant_expr = bindings[2].get().Cast(); + D_ASSERT(root.children.size() == 2 || root.children.size() == 3); + auto regexp_bind_data = root.bind_info.get()->Cast(); + + auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), constant_expr); + D_ASSERT(constant_value.type() == constant_expr.return_type); + auto patt_str = StringValue::Get(constant_value); + + duckdb_re2::RE2::Options parsed_options = regexp_bind_data.options; + + if (constant_expr.value.IsNull()) { + return make_uniq(Value(root.return_type)); + } + + // the constant_expr is a scalar expression that we have to fold + if (!constant_expr.IsFoldable()) { + return nullptr; + }; + + duckdb_re2::RE2 pattern(patt_str, parsed_options); + if (!pattern.ok()) { + return nullptr; // this should fail somewhere else + } + + LikeString like_string; + // check for a like string. If we can convert it to a like string, the like string + // optimizer will further optimize suffix and prefix things. + if (pattern.Regexp()->op() == duckdb_re2::kRegexpLiteralString || + pattern.Regexp()->op() == duckdb_re2::kRegexpLiteral) { + // convert to contains. + LikeString escaped_like_string = GetLikeStringEscaped(pattern.Regexp(), true); + if (!escaped_like_string.exists) { + return nullptr; + } + auto parameter = make_uniq(Value(std::move(escaped_like_string.like_string))); + auto contains = make_uniq(root.return_type, ContainsFun::GetFunction(), + std::move(root.children), nullptr); + contains->children[1] = std::move(parameter); + + return std::move(contains); + } else if (pattern.Regexp()->op() == duckdb_re2::kRegexpConcat) { + like_string = LikeMatchFromRegex(pattern); + } else { + like_string.exists = false; + } + + if (!like_string.exists) { + return nullptr; + } + + // if regexp had options, remove them so the new Like Expression can be matched for other optimizers. + if (root.children.size() == 3) { + root.children.pop_back(); + D_ASSERT(root.children.size() == 2); + } + + auto like_expression = make_uniq(root.return_type, LikeFun::GetLikeFunction(), + std::move(root.children), nullptr); + auto parameter = make_uniq(Value(std::move(like_string.like_string))); + like_expression->children[1] = std::move(parameter); + return std::move(like_expression); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp new file mode 100644 index 00000000..b3dc53b8 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp @@ -0,0 +1,25 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateExpression(BoundAggregateExpression &aggr, + unique_ptr *expr_ptr) { + vector stats; + stats.reserve(aggr.children.size()); + for (auto &child : aggr.children) { + auto stat = PropagateExpression(child); + if (!stat) { + stats.push_back(BaseStatistics::CreateUnknown(child->return_type)); + } else { + stats.push_back(stat->Copy()); + } + } + if (!aggr.function.statistics) { + return nullptr; + } + AggregateStatisticsInput input(aggr.bind_info.get(), stats, node_stats.get()); + return aggr.function.statistics(context, aggr, input); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_between.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_between.cpp new file mode 100644 index 00000000..fd4865e7 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_between.cpp @@ -0,0 +1,65 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateExpression(BoundBetweenExpression &between, + unique_ptr *expr_ptr) { + // propagate in all the children + auto input_stats = PropagateExpression(between.input); + auto lower_stats = PropagateExpression(between.lower); + auto upper_stats = PropagateExpression(between.upper); + if (!input_stats) { + return nullptr; + } + auto lower_comparison = between.LowerComparisonType(); + auto upper_comparison = between.UpperComparisonType(); + // propagate the comparisons + auto lower_prune = FilterPropagateResult::NO_PRUNING_POSSIBLE; + auto upper_prune = FilterPropagateResult::NO_PRUNING_POSSIBLE; + if (lower_stats) { + lower_prune = PropagateComparison(*input_stats, *lower_stats, lower_comparison); + } + if (upper_stats) { + upper_prune = PropagateComparison(*input_stats, *upper_stats, upper_comparison); + } + if (lower_prune == FilterPropagateResult::FILTER_ALWAYS_TRUE && + upper_prune == FilterPropagateResult::FILTER_ALWAYS_TRUE) { + // both filters are always true: replace the between expression with a constant true + *expr_ptr = make_uniq(Value::BOOLEAN(true)); + } else if (lower_prune == FilterPropagateResult::FILTER_ALWAYS_FALSE || + upper_prune == FilterPropagateResult::FILTER_ALWAYS_FALSE) { + // either one of the filters is always false: replace the between expression with a constant false + *expr_ptr = make_uniq(Value::BOOLEAN(false)); + } else if (lower_prune == FilterPropagateResult::FILTER_FALSE_OR_NULL || + upper_prune == FilterPropagateResult::FILTER_FALSE_OR_NULL) { + // either one of the filters is false or null: replace with a constant or null (false) + vector> children; + children.push_back(std::move(between.input)); + children.push_back(std::move(between.lower)); + children.push_back(std::move(between.upper)); + *expr_ptr = ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); + } else if (lower_prune == FilterPropagateResult::FILTER_TRUE_OR_NULL && + upper_prune == FilterPropagateResult::FILTER_TRUE_OR_NULL) { + // both filters are true or null: replace with a true or null + vector> children; + children.push_back(std::move(between.input)); + children.push_back(std::move(between.lower)); + children.push_back(std::move(between.upper)); + *expr_ptr = ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(true)); + } else if (lower_prune == FilterPropagateResult::FILTER_ALWAYS_TRUE) { + // lower filter is always true: replace with upper comparison + *expr_ptr = + make_uniq(upper_comparison, std::move(between.input), std::move(between.upper)); + } else if (upper_prune == FilterPropagateResult::FILTER_ALWAYS_TRUE) { + // upper filter is always true: replace with lower comparison + *expr_ptr = + make_uniq(lower_comparison, std::move(between.input), std::move(between.lower)); + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_case.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_case.cpp new file mode 100644 index 00000000..7586b9a4 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_case.cpp @@ -0,0 +1,22 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateExpression(BoundCaseExpression &bound_case, + unique_ptr *expr_ptr) { + // propagate in all the children + auto result_stats = PropagateExpression(bound_case.else_expr); + for (auto &case_check : bound_case.case_checks) { + PropagateExpression(case_check.when_expr); + auto then_stats = PropagateExpression(case_check.then_expr); + if (!then_stats) { + result_stats.reset(); + } else if (result_stats) { + result_stats->Merge(*then_stats); + } + } + return result_stats; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp new file mode 100644 index 00000000..e9a2d811 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp @@ -0,0 +1,79 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +static unique_ptr StatisticsOperationsNumericNumericCast(const BaseStatistics &input, + const LogicalType &target) { + if (!NumericStats::HasMinMax(input)) { + return nullptr; + } + Value min = NumericStats::Min(input); + Value max = NumericStats::Max(input); + if (!min.DefaultTryCastAs(target) || !max.DefaultTryCastAs(target)) { + // overflow in cast: bailout + return nullptr; + } + auto result = NumericStats::CreateEmpty(target); + result.CopyBase(input); + NumericStats::SetMin(result, min); + NumericStats::SetMax(result, max); + return result.ToUnique(); +} + +static unique_ptr StatisticsNumericCastSwitch(const BaseStatistics &input, const LogicalType &target) { + // Downcasting timestamps to times is not a truncation operation + switch (target.id()) { + case LogicalTypeId::TIME: + switch (input.GetType().id()) { + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return nullptr; + default: + break; + } + default: + break; + } + + switch (target.InternalType()) { + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return StatisticsOperationsNumericNumericCast(input, target); + default: + return nullptr; + } +} + +unique_ptr StatisticsPropagator::PropagateExpression(BoundCastExpression &cast, + unique_ptr *expr_ptr) { + auto child_stats = PropagateExpression(cast.child); + if (!child_stats) { + return nullptr; + } + unique_ptr result_stats; + switch (cast.child->return_type.InternalType()) { + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + result_stats = StatisticsNumericCastSwitch(*child_stats, cast.return_type); + break; + default: + return nullptr; + } + if (cast.try_cast && result_stats) { + result_stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + return result_stats; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_columnref.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_columnref.cpp new file mode 100644 index 00000000..8a8d7db1 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_columnref.cpp @@ -0,0 +1,15 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateExpression(BoundColumnRefExpression &colref, + unique_ptr *expr_ptr) { + auto stats = statistics_map.find(colref.binding); + if (stats == statistics_map.end()) { + return nullptr; + } + return stats->second->ToUnique(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_comparison.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_comparison.cpp new file mode 100644 index 00000000..0a073c54 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_comparison.cpp @@ -0,0 +1,129 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +FilterPropagateResult StatisticsPropagator::PropagateComparison(BaseStatistics &lstats, BaseStatistics &rstats, + ExpressionType comparison) { + // only handle numerics for now + switch (lstats.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + break; + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + if (!NumericStats::HasMinMax(lstats) || !NumericStats::HasMinMax(rstats)) { + // no stats available: nothing to prune + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + // the result of the propagation depend on whether or not either side has null values + // if there are null values present, we cannot say whether or not + bool has_null = lstats.CanHaveNull() || rstats.CanHaveNull(); + switch (comparison) { + case ExpressionType::COMPARE_EQUAL: + // l = r, if l.min > r.max or r.min > l.max equality is not possible + if (NumericStats::Min(lstats) > NumericStats::Max(rstats) || + NumericStats::Min(rstats) > NumericStats::Max(lstats)) { + return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; + } else { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + case ExpressionType::COMPARE_GREATERTHAN: + // l > r + if (NumericStats::Min(lstats) > NumericStats::Max(rstats)) { + // if l.min > r.max, it is always true ONLY if neither side contains nulls + return has_null ? FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + // if r.min is bigger or equal to l.max, the filter is always false + if (NumericStats::Min(rstats) >= NumericStats::Max(lstats)) { + return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + // l >= r + if (NumericStats::Min(lstats) >= NumericStats::Max(rstats)) { + // if l.min >= r.max, it is always true ONLY if neither side contains nulls + return has_null ? FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + // if r.min > l.max, the filter is always false + if (NumericStats::Min(rstats) > NumericStats::Max(lstats)) { + return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + case ExpressionType::COMPARE_LESSTHAN: + // l < r + if (NumericStats::Max(lstats) < NumericStats::Min(rstats)) { + // if l.max < r.min, it is always true ONLY if neither side contains nulls + return has_null ? FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + // if l.min >= rstats.max, the filter is always false + if (NumericStats::Min(lstats) >= NumericStats::Max(rstats)) { + return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // l <= r + if (NumericStats::Max(lstats) <= NumericStats::Min(rstats)) { + // if l.max <= r.min, it is always true ONLY if neither side contains nulls + return has_null ? FilterPropagateResult::FILTER_TRUE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + // if l.min > rstats.max, the filter is always false + if (NumericStats::Min(lstats) > NumericStats::Max(rstats)) { + return has_null ? FilterPropagateResult::FILTER_FALSE_OR_NULL : FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } +} + +unique_ptr StatisticsPropagator::PropagateExpression(BoundComparisonExpression &expr, + unique_ptr *expr_ptr) { + auto left_stats = PropagateExpression(expr.left); + auto right_stats = PropagateExpression(expr.right); + if (!left_stats || !right_stats) { + return nullptr; + } + // propagate the statistics of the comparison operator + auto propagate_result = PropagateComparison(*left_stats, *right_stats, expr.type); + switch (propagate_result) { + case FilterPropagateResult::FILTER_ALWAYS_TRUE: + *expr_ptr = make_uniq(Value::BOOLEAN(true)); + return PropagateExpression(*expr_ptr); + case FilterPropagateResult::FILTER_ALWAYS_FALSE: + *expr_ptr = make_uniq(Value::BOOLEAN(false)); + return PropagateExpression(*expr_ptr); + case FilterPropagateResult::FILTER_TRUE_OR_NULL: { + vector> children; + children.push_back(std::move(expr.left)); + children.push_back(std::move(expr.right)); + *expr_ptr = ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(true)); + return nullptr; + } + case FilterPropagateResult::FILTER_FALSE_OR_NULL: { + vector> children; + children.push_back(std::move(expr.left)); + children.push_back(std::move(expr.right)); + *expr_ptr = ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); + return nullptr; + } + default: + // FIXME: we can propagate nulls here, i.e. this expression will have nulls only if left and right has nulls + return nullptr; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_conjunction.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_conjunction.cpp new file mode 100644 index 00000000..1fce16c8 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_conjunction.cpp @@ -0,0 +1,67 @@ + +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateExpression(BoundConjunctionExpression &expr, + unique_ptr *expr_ptr) { + auto is_and = expr.type == ExpressionType::CONJUNCTION_AND; + for (idx_t expr_idx = 0; expr_idx < expr.children.size(); expr_idx++) { + auto &child = expr.children[expr_idx]; + auto stats = PropagateExpression(child); + if (!child->IsFoldable()) { + continue; + } + // we have a constant in a conjunction + // we (1) either prune the child + // or (2) replace the entire conjunction with a constant + auto constant = ExpressionExecutor::EvaluateScalar(context, *child); + if (constant.IsNull()) { + continue; + } + auto b = BooleanValue::Get(constant); + bool prune_child = false; + bool constant_value = true; + if (b) { + // true + if (is_and) { + // true in and: prune child + prune_child = true; + } else { + // true in OR: replace with TRUE + constant_value = true; + } + } else { + // false + if (is_and) { + // false in AND: replace with FALSE + constant_value = false; + } else { + // false in OR: prune child + prune_child = true; + } + } + if (prune_child) { + expr.children.erase(expr.children.begin() + expr_idx); + expr_idx--; + continue; + } + *expr_ptr = make_uniq(Value::BOOLEAN(constant_value)); + return PropagateExpression(*expr_ptr); + } + if (expr.children.empty()) { + // if there are no children left, replace the conjunction with TRUE (for AND) or FALSE (for OR) + *expr_ptr = make_uniq(Value::BOOLEAN(is_and)); + return PropagateExpression(*expr_ptr); + } else if (expr.children.size() == 1) { + // if there is one child left, replace the conjunction with that one child + *expr_ptr = std::move(expr.children[0]); + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_constant.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_constant.cpp new file mode 100644 index 00000000..c531311d --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_constant.cpp @@ -0,0 +1,18 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/storage/statistics/distinct_statistics.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::StatisticsFromValue(const Value &input) { + return BaseStatistics::FromConstant(input).ToUnique(); +} + +unique_ptr StatisticsPropagator::PropagateExpression(BoundConstantExpression &constant, + unique_ptr *expr_ptr) { + return StatisticsFromValue(constant.value); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp new file mode 100644 index 00000000..3d7a2d7a --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp @@ -0,0 +1,25 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateExpression(BoundFunctionExpression &func, + unique_ptr *expr_ptr) { + vector stats; + stats.reserve(func.children.size()); + for (idx_t i = 0; i < func.children.size(); i++) { + auto stat = PropagateExpression(func.children[i]); + if (!stat) { + stats.push_back(BaseStatistics::CreateUnknown(func.children[i]->return_type)); + } else { + stats.push_back(stat->Copy()); + } + } + if (!func.function.statistics) { + return nullptr; + } + FunctionStatisticsInput input(func, func.bind_info.get(), stats, expr_ptr); + return func.function.statistics(context, input); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_operator.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_operator.cpp new file mode 100644 index 00000000..4bb1f8fd --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_operator.cpp @@ -0,0 +1,88 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateExpression(BoundOperatorExpression &expr, + unique_ptr *expr_ptr) { + bool all_have_stats = true; + vector> child_stats; + child_stats.reserve(expr.children.size()); + for (auto &child : expr.children) { + auto stats = PropagateExpression(child); + if (!stats) { + all_have_stats = false; + } + child_stats.push_back(std::move(stats)); + } + if (!all_have_stats) { + return nullptr; + } + switch (expr.type) { + case ExpressionType::OPERATOR_COALESCE: + // COALESCE, merge stats of all children + for (idx_t i = 0; i < expr.children.size(); i++) { + D_ASSERT(child_stats[i]); + if (!child_stats[i]->CanHaveNoNull()) { + // this child is always NULL, we can remove it from the coalesce + // UNLESS there is only one node remaining + if (expr.children.size() > 1) { + expr.children.erase(expr.children.begin() + i); + child_stats.erase(child_stats.begin() + i); + i--; + } + } else if (!child_stats[i]->CanHaveNull()) { + // coalesce child cannot have NULL entries + // this is the last coalesce node that influences the result + // we can erase any children after this node + if (i + 1 < expr.children.size()) { + expr.children.erase(expr.children.begin() + i + 1, expr.children.end()); + child_stats.erase(child_stats.begin() + i + 1, child_stats.end()); + } + break; + } + } + D_ASSERT(!expr.children.empty()); + D_ASSERT(expr.children.size() == child_stats.size()); + if (expr.children.size() == 1) { + // coalesce of one entry: simply return that entry + *expr_ptr = std::move(expr.children[0]); + } else { + // coalesce of multiple entries + // merge the stats + for (idx_t i = 1; i < expr.children.size(); i++) { + child_stats[0]->Merge(*child_stats[i]); + } + } + return std::move(child_stats[0]); + case ExpressionType::OPERATOR_IS_NULL: + if (!child_stats[0]->CanHaveNull()) { + // child has no null values: x IS NULL will always be false + *expr_ptr = make_uniq(Value::BOOLEAN(false)); + return PropagateExpression(*expr_ptr); + } + if (!child_stats[0]->CanHaveNoNull()) { + // child has no valid values: x IS NULL will always be true + *expr_ptr = make_uniq(Value::BOOLEAN(true)); + return PropagateExpression(*expr_ptr); + } + return nullptr; + case ExpressionType::OPERATOR_IS_NOT_NULL: + if (!child_stats[0]->CanHaveNull()) { + // child has no null values: x IS NOT NULL will always be true + *expr_ptr = make_uniq(Value::BOOLEAN(true)); + return PropagateExpression(*expr_ptr); + } + if (!child_stats[0]->CanHaveNoNull()) { + // child has no valid values: x IS NOT NULL will always be false + *expr_ptr = make_uniq(Value::BOOLEAN(false)); + return PropagateExpression(*expr_ptr); + } + return nullptr; + default: + return nullptr; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp new file mode 100644 index 00000000..1824a4f1 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp @@ -0,0 +1,41 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalAggregate &aggr, + unique_ptr *node_ptr) { + // first propagate statistics in the child node + node_stats = PropagateStatistics(aggr.children[0]); + + // handle the groups: simply propagate statistics and assign the stats to the group binding + aggr.group_stats.resize(aggr.groups.size()); + for (idx_t group_idx = 0; group_idx < aggr.groups.size(); group_idx++) { + auto stats = PropagateExpression(aggr.groups[group_idx]); + aggr.group_stats[group_idx] = stats ? stats->ToUnique() : nullptr; + if (!stats) { + continue; + } + if (aggr.grouping_sets.size() > 1) { + // aggregates with multiple grouping sets can introduce NULL values to certain groups + // FIXME: actually figure out WHICH groups can have null values introduced + stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + continue; + } + ColumnBinding group_binding(aggr.group_index, group_idx); + statistics_map[group_binding] = std::move(stats); + } + // propagate statistics in the aggregates + for (idx_t aggregate_idx = 0; aggregate_idx < aggr.expressions.size(); aggregate_idx++) { + auto stats = PropagateExpression(aggr.expressions[aggregate_idx]); + if (!stats) { + continue; + } + ColumnBinding aggregate_binding(aggr.aggregate_index, aggregate_idx); + statistics_map[aggregate_binding] = std::move(stats); + } + // the max cardinality of an aggregate is the max cardinality of the input (i.e. when every row is a unique group) + return std::move(node_stats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_cross_product.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_cross_product.cpp new file mode 100644 index 00000000..c69c9b88 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_cross_product.cpp @@ -0,0 +1,18 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalCrossProduct &cp, + unique_ptr *node_ptr) { + // first propagate statistics in the child node + auto left_stats = PropagateStatistics(cp.children[0]); + auto right_stats = PropagateStatistics(cp.children[1]); + if (!left_stats || !right_stats) { + return nullptr; + } + MultiplyCardinalities(left_stats, *right_stats); + return left_stats; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_filter.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_filter.cpp new file mode 100644 index 00000000..a9bfe91c --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_filter.cpp @@ -0,0 +1,255 @@ +#include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +static bool IsCompareDistinct(ExpressionType type) { + return type == ExpressionType::COMPARE_DISTINCT_FROM || type == ExpressionType::COMPARE_NOT_DISTINCT_FROM; +} + +bool StatisticsPropagator::ExpressionIsConstant(Expression &expr, const Value &val) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { + return false; + } + auto &bound_constant = expr.Cast(); + D_ASSERT(bound_constant.value.type() == val.type()); + return Value::NotDistinctFrom(bound_constant.value, val); +} + +bool StatisticsPropagator::ExpressionIsConstantOrNull(Expression &expr, const Value &val) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + return false; + } + auto &bound_function = expr.Cast(); + return ConstantOrNull::IsConstantOrNull(bound_function, val); +} + +void StatisticsPropagator::SetStatisticsNotNull(ColumnBinding binding) { + auto entry = statistics_map.find(binding); + if (entry == statistics_map.end()) { + return; + } + entry->second->Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); +} + +void StatisticsPropagator::UpdateFilterStatistics(BaseStatistics &stats, ExpressionType comparison_type, + const Value &constant) { + // regular comparisons removes all null values + if (!IsCompareDistinct(comparison_type)) { + stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + } + if (!stats.GetType().IsNumeric()) { + // don't handle non-numeric columns here (yet) + return; + } + if (!NumericStats::HasMinMax(stats)) { + // no stats available: skip this + return; + } + switch (comparison_type) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // X < constant OR X <= constant + // max becomes the constant + NumericStats::SetMax(stats, constant); + break; + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + // X > constant OR X >= constant + // min becomes the constant + NumericStats::SetMin(stats, constant); + break; + case ExpressionType::COMPARE_EQUAL: + // X = constant + // both min and max become the constant + NumericStats::SetMin(stats, constant); + NumericStats::SetMax(stats, constant); + break; + default: + break; + } +} + +void StatisticsPropagator::UpdateFilterStatistics(BaseStatistics &lstats, BaseStatistics &rstats, + ExpressionType comparison_type) { + // regular comparisons removes all null values + if (!IsCompareDistinct(comparison_type)) { + lstats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + rstats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + } + D_ASSERT(lstats.GetType() == rstats.GetType()); + if (!lstats.GetType().IsNumeric()) { + // don't handle non-numeric columns here (yet) + return; + } + if (!NumericStats::HasMinMax(lstats) || !NumericStats::HasMinMax(rstats)) { + // no stats available: skip this + return; + } + switch (comparison_type) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // LEFT < RIGHT OR LEFT <= RIGHT + // we know that every value of left is smaller (or equal to) every value in right + // i.e. if we have left = [-50, 250] and right = [-100, 100] + + // we know that left.max is AT MOST equal to right.max + // because any value in left that is BIGGER than right.max will not pass the filter + if (NumericStats::Max(lstats) > NumericStats::Max(rstats)) { + NumericStats::SetMax(lstats, NumericStats::Max(rstats)); + } + + // we also know that right.min is AT MOST equal to left.min + // because any value in right that is SMALLER than left.min will not pass the filter + if (NumericStats::Min(rstats) < NumericStats::Min(lstats)) { + NumericStats::SetMin(rstats, NumericStats::Min(lstats)); + } + // so in our example, the bounds get updated as follows: + // left: [-50, 100], right: [-50, 100] + break; + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + // LEFT > RIGHT OR LEFT >= RIGHT + // we know that every value of left is bigger (or equal to) every value in right + // this is essentially the inverse of the less than (or equal to) scenario + if (NumericStats::Max(rstats) > NumericStats::Max(lstats)) { + NumericStats::SetMax(rstats, NumericStats::Max(lstats)); + } + if (NumericStats::Min(lstats) < NumericStats::Min(rstats)) { + NumericStats::SetMin(lstats, NumericStats::Min(rstats)); + } + break; + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + // LEFT = RIGHT + // only the tightest bounds pass + // so if we have e.g. left = [-50, 250] and right = [-100, 100] + // the tighest bounds are [-50, 100] + // select the highest min + if (NumericStats::Min(lstats) > NumericStats::Min(rstats)) { + NumericStats::SetMin(rstats, NumericStats::Min(lstats)); + } else { + NumericStats::SetMin(lstats, NumericStats::Min(rstats)); + } + // select the lowest max + if (NumericStats::Max(lstats) < NumericStats::Max(rstats)) { + NumericStats::SetMax(rstats, NumericStats::Max(lstats)); + } else { + NumericStats::SetMax(lstats, NumericStats::Max(rstats)); + } + break; + default: + break; + } +} + +void StatisticsPropagator::UpdateFilterStatistics(Expression &left, Expression &right, ExpressionType comparison_type) { + // first check if either side is a bound column ref + // any column ref involved in a comparison will not be null after the comparison + bool compare_distinct = IsCompareDistinct(comparison_type); + if (!compare_distinct && left.type == ExpressionType::BOUND_COLUMN_REF) { + SetStatisticsNotNull((left.Cast()).binding); + } + if (!compare_distinct && right.type == ExpressionType::BOUND_COLUMN_REF) { + SetStatisticsNotNull((right.Cast()).binding); + } + // check if this is a comparison between a constant and a column ref + optional_ptr constant; + optional_ptr columnref; + if (left.type == ExpressionType::VALUE_CONSTANT && right.type == ExpressionType::BOUND_COLUMN_REF) { + constant = &left.Cast(); + columnref = &right.Cast(); + comparison_type = FlipComparisonExpression(comparison_type); + } else if (left.type == ExpressionType::BOUND_COLUMN_REF && right.type == ExpressionType::VALUE_CONSTANT) { + columnref = &left.Cast(); + constant = &right.Cast(); + } else if (left.type == ExpressionType::BOUND_COLUMN_REF && right.type == ExpressionType::BOUND_COLUMN_REF) { + // comparison between two column refs + auto &left_column_ref = left.Cast(); + auto &right_column_ref = right.Cast(); + auto lentry = statistics_map.find(left_column_ref.binding); + auto rentry = statistics_map.find(right_column_ref.binding); + if (lentry == statistics_map.end() || rentry == statistics_map.end()) { + return; + } + UpdateFilterStatistics(*lentry->second, *rentry->second, comparison_type); + } else { + // unsupported filter + return; + } + if (constant && columnref) { + // comparison between columnref + auto entry = statistics_map.find(columnref->binding); + if (entry == statistics_map.end()) { + return; + } + UpdateFilterStatistics(*entry->second, comparison_type, constant->value); + } +} + +void StatisticsPropagator::UpdateFilterStatistics(Expression &condition) { + // in filters, we check for constant comparisons with bound columns + // if we find a comparison in the form of e.g. "i=3", we can update our statistics for that column + switch (condition.GetExpressionClass()) { + case ExpressionClass::BOUND_BETWEEN: { + auto &between = condition.Cast(); + UpdateFilterStatistics(*between.input, *between.lower, between.LowerComparisonType()); + UpdateFilterStatistics(*between.input, *between.upper, between.UpperComparisonType()); + break; + } + case ExpressionClass::BOUND_COMPARISON: { + auto &comparison = condition.Cast(); + UpdateFilterStatistics(*comparison.left, *comparison.right, comparison.type); + break; + } + default: + break; + } +} + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalFilter &filter, + unique_ptr *node_ptr) { + // first propagate to the child + node_stats = PropagateStatistics(filter.children[0]); + if (filter.children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { + ReplaceWithEmptyResult(*node_ptr); + return make_uniq(0, 0); + } + + // then propagate to each of the expressions + for (idx_t i = 0; i < filter.expressions.size(); i++) { + auto &condition = filter.expressions[i]; + PropagateExpression(condition); + + if (ExpressionIsConstant(*condition, Value::BOOLEAN(true))) { + // filter is always true; it is useless to execute it + // erase this condition + filter.expressions.erase(filter.expressions.begin() + i); + i--; + if (filter.expressions.empty()) { + // all conditions have been erased: remove the entire filter + *node_ptr = std::move(filter.children[0]); + break; + } + } else if (ExpressionIsConstant(*condition, Value::BOOLEAN(false)) || + ExpressionIsConstantOrNull(*condition, Value::BOOLEAN(false))) { + // filter is always false or null; this entire filter should be replaced by an empty result block + ReplaceWithEmptyResult(*node_ptr); + return make_uniq(0, 0); + } else { + // cannot prune this filter: propagate statistics from the filter + UpdateFilterStatistics(*condition); + } + } + // the max cardinality of a filter is the cardinality of the input (i.e. no tuples get filtered) + return std::move(node_stats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp new file mode 100644 index 00000000..22979f19 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp @@ -0,0 +1,99 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/table_filter.hpp" + +namespace duckdb { + +FilterPropagateResult StatisticsPropagator::PropagateTableFilter(BaseStatistics &stats, TableFilter &filter) { + return filter.CheckStatistics(stats); +} + +void StatisticsPropagator::UpdateFilterStatistics(BaseStatistics &input, TableFilter &filter) { + // FIXME: update stats... + switch (filter.filter_type) { + case TableFilterType::CONJUNCTION_AND: { + auto &conjunction_and = filter.Cast(); + for (auto &child_filter : conjunction_and.child_filters) { + UpdateFilterStatistics(input, *child_filter); + } + break; + } + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + UpdateFilterStatistics(input, constant_filter.comparison_type, constant_filter.constant); + break; + } + default: + break; + } +} + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalGet &get, + unique_ptr *node_ptr) { + if (get.function.cardinality) { + node_stats = get.function.cardinality(context, get.bind_data.get()); + } + if (!get.function.statistics) { + // no column statistics to get + return std::move(node_stats); + } + for (idx_t i = 0; i < get.column_ids.size(); i++) { + auto stats = get.function.statistics(context, get.bind_data.get(), get.column_ids[i]); + if (stats) { + ColumnBinding binding(get.table_index, i); + statistics_map.insert(make_pair(binding, std::move(stats))); + } + } + // push table filters into the statistics + vector column_indexes; + column_indexes.reserve(get.table_filters.filters.size()); + for (auto &kv : get.table_filters.filters) { + column_indexes.push_back(kv.first); + } + + for (auto &table_filter_column : column_indexes) { + idx_t column_index; + for (column_index = 0; column_index < get.column_ids.size(); column_index++) { + if (get.column_ids[column_index] == table_filter_column) { + break; + } + } + D_ASSERT(column_index < get.column_ids.size()); + D_ASSERT(get.column_ids[column_index] == table_filter_column); + + // find the stats + ColumnBinding stats_binding(get.table_index, column_index); + auto entry = statistics_map.find(stats_binding); + if (entry == statistics_map.end()) { + // no stats for this entry + continue; + } + auto &stats = *entry->second; + + // fetch the table filter + D_ASSERT(get.table_filters.filters.count(table_filter_column) > 0); + auto &filter = get.table_filters.filters[table_filter_column]; + auto propagate_result = PropagateTableFilter(stats, *filter); + switch (propagate_result) { + case FilterPropagateResult::FILTER_ALWAYS_TRUE: + // filter is always true; it is useless to execute it + // erase this condition + get.table_filters.filters.erase(table_filter_column); + break; + case FilterPropagateResult::FILTER_FALSE_OR_NULL: + case FilterPropagateResult::FILTER_ALWAYS_FALSE: + // filter is always false; this entire filter should be replaced by an empty result block + ReplaceWithEmptyResult(*node_ptr); + return make_uniq(0, 0); + default: + // general case: filter can be true or false, update this columns' statistics + UpdateFilterStatistics(stats, *filter); + break; + } + } + return std::move(node_stats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_join.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_join.cpp new file mode 100644 index 00000000..38dcc278 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_join.cpp @@ -0,0 +1,341 @@ +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/operator/logical_any_join.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_join.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" +#include "duckdb/planner/operator/logical_positional_join.hpp" + +namespace duckdb { + +void StatisticsPropagator::PropagateStatistics(LogicalComparisonJoin &join, unique_ptr *node_ptr) { + for (idx_t i = 0; i < join.conditions.size(); i++) { + auto &condition = join.conditions[i]; + const auto stats_left = PropagateExpression(condition.left); + const auto stats_right = PropagateExpression(condition.right); + if (stats_left && stats_right) { + if ((condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || + condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) && + stats_left->CanHaveNull() && stats_right->CanHaveNull()) { + // null values are equal in this join, and both sides can have null values + // nothing to do here + continue; + } + auto prune_result = PropagateComparison(*stats_left, *stats_right, condition.comparison); + // Add stats to logical_join for perfect hash join + join.join_stats.push_back(stats_left->ToUnique()); + join.join_stats.push_back(stats_right->ToUnique()); + switch (prune_result) { + case FilterPropagateResult::FILTER_FALSE_OR_NULL: + case FilterPropagateResult::FILTER_ALWAYS_FALSE: + // filter is always false or null, none of the join conditions matter + switch (join.join_type) { + case JoinType::SEMI: + case JoinType::INNER: + // semi or inner join on false; entire node can be pruned + ReplaceWithEmptyResult(*node_ptr); + return; + case JoinType::ANTI: { + // when the right child has data, return the left child + // when the right child has no data, return an empty set + auto limit = make_uniq(1, 0, nullptr, nullptr); + limit->AddChild(std::move(join.children[1])); + auto cross_product = LogicalCrossProduct::Create(std::move(join.children[0]), std::move(limit)); + *node_ptr = std::move(cross_product); + return; + } + case JoinType::LEFT: + // anti/left outer join: replace right side with empty node + ReplaceWithEmptyResult(join.children[1]); + return; + case JoinType::RIGHT: + // right outer join: replace left side with empty node + ReplaceWithEmptyResult(join.children[0]); + return; + default: + // other join types: can't do much meaningful with this information + // full outer join requires both sides anyway; we can skip the execution of the actual join, but eh + // mark/single join requires knowing if the rhs has null values or not + break; + } + break; + case FilterPropagateResult::FILTER_ALWAYS_TRUE: + // filter is always true + if (join.conditions.size() > 1) { + // there are multiple conditions: erase this condition + join.conditions.erase(join.conditions.begin() + i); + // remove the corresponding statistics + join.join_stats.clear(); + i--; + continue; + } else { + // this is the only condition and it is always true: all conditions are true + switch (join.join_type) { + case JoinType::SEMI: { + // when the right child has data, return the left child + // when the right child has no data, return an empty set + auto limit = make_uniq(1, 0, nullptr, nullptr); + limit->AddChild(std::move(join.children[1])); + auto cross_product = LogicalCrossProduct::Create(std::move(join.children[0]), std::move(limit)); + *node_ptr = std::move(cross_product); + return; + } + case JoinType::INNER: { + // inner, replace with cross product + auto cross_product = + LogicalCrossProduct::Create(std::move(join.children[0]), std::move(join.children[1])); + *node_ptr = std::move(cross_product); + return; + } + case JoinType::ANTI: + // anti join on true: empty result + ReplaceWithEmptyResult(*node_ptr); + return; + default: + // we don't handle mark/single join here yet + break; + } + } + break; + default: + break; + } + } + // after we have propagated, we can update the statistics on both sides + // note that it is fine to do this now, even if the same column is used again later + // e.g. if we have i=j AND i=k, and the stats for j and k are disjoint, we know there are no results + // so if we have e.g. i: [0, 100], j: [0, 25], k: [75, 100] + // we can set i: [0, 25] after the first comparison, and statically determine that the second comparison is fals + + // note that we can't update statistics the same for all join types + // mark and single joins don't filter any tuples -> so there is no propagation possible + // anti joins have inverse statistics propagation + // (i.e. if we have an anti join on i: [0, 100] and j: [0, 25], the resulting stats are i:[25,100]) + // for now we don't handle anti joins + if (condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || + condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + // skip update when null values are equal (for now?) + continue; + } + switch (join.join_type) { + case JoinType::INNER: + case JoinType::SEMI: { + UpdateFilterStatistics(*condition.left, *condition.right, condition.comparison); + auto updated_stats_left = PropagateExpression(condition.left); + auto updated_stats_right = PropagateExpression(condition.right); + + // Try to push lhs stats down rhs and vice versa + if (!context.config.force_index_join && stats_left && stats_right && updated_stats_left && + updated_stats_right && condition.left->type == ExpressionType::BOUND_COLUMN_REF && + condition.right->type == ExpressionType::BOUND_COLUMN_REF) { + CreateFilterFromJoinStats(join.children[0], condition.left, *stats_left, *updated_stats_left); + CreateFilterFromJoinStats(join.children[1], condition.right, *stats_right, *updated_stats_right); + } + + // Update join_stats when is already part of the join + if (join.join_stats.size() == 2) { + join.join_stats[0] = std::move(updated_stats_left); + join.join_stats[1] = std::move(updated_stats_right); + } + break; + } + default: + break; + } + } +} + +void StatisticsPropagator::PropagateStatistics(LogicalAnyJoin &join, unique_ptr *node_ptr) { + // propagate the expression into the join condition + PropagateExpression(join.condition); +} + +void StatisticsPropagator::MultiplyCardinalities(unique_ptr &stats, NodeStatistics &new_stats) { + if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || + !new_stats.has_max_cardinality) { + stats = nullptr; + return; + } + stats->estimated_cardinality = MaxValue(stats->estimated_cardinality, new_stats.estimated_cardinality); + auto new_max = Hugeint::Multiply(stats->max_cardinality, new_stats.max_cardinality); + if (new_max < NumericLimits::Maximum()) { + int64_t result; + if (!Hugeint::TryCast(new_max, result)) { + throw InternalException("Overflow in cast in statistics propagation"); + } + D_ASSERT(result >= 0); + stats->max_cardinality = idx_t(result); + } else { + stats = nullptr; + } +} + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalJoin &join, + unique_ptr *node_ptr) { + // first propagate through the children of the join + node_stats = PropagateStatistics(join.children[0]); + for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) { + auto child_stats = PropagateStatistics(join.children[child_idx]); + if (!child_stats) { + node_stats = nullptr; + } else if (node_stats) { + MultiplyCardinalities(node_stats, *child_stats); + } + } + + auto join_type = join.join_type; + // depending on the join type, we might need to alter the statistics + // LEFT, FULL, RIGHT OUTER and SINGLE joins can introduce null values + // this requires us to alter the statistics after this point in the query plan + bool adds_null_on_left = IsRightOuterJoin(join_type); + bool adds_null_on_right = IsLeftOuterJoin(join_type) || join_type == JoinType::SINGLE; + + vector left_bindings, right_bindings; + if (adds_null_on_left) { + left_bindings = join.children[0]->GetColumnBindings(); + } + if (adds_null_on_right) { + right_bindings = join.children[1]->GetColumnBindings(); + } + + // then propagate into the join conditions + switch (join.type) { + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + PropagateStatistics(join.Cast(), node_ptr); + break; + case LogicalOperatorType::LOGICAL_ANY_JOIN: + PropagateStatistics(join.Cast(), node_ptr); + break; + default: + break; + } + + if (adds_null_on_right) { + // left or full outer join: set IsNull() to true for all rhs statistics + for (auto &binding : right_bindings) { + auto stats = statistics_map.find(binding); + if (stats != statistics_map.end()) { + stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + } + } + if (adds_null_on_left) { + // right or full outer join: set IsNull() to true for all lhs statistics + for (auto &binding : left_bindings) { + auto stats = statistics_map.find(binding); + if (stats != statistics_map.end()) { + stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + } + } + return std::move(node_stats); +} + +static void MaxCardinalities(unique_ptr &stats, NodeStatistics &new_stats) { + if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || + !new_stats.has_max_cardinality) { + stats = nullptr; + return; + } + stats->estimated_cardinality = MaxValue(stats->estimated_cardinality, new_stats.estimated_cardinality); + stats->max_cardinality = MaxValue(stats->max_cardinality, new_stats.max_cardinality); +} + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalPositionalJoin &join, + unique_ptr *node_ptr) { + D_ASSERT(join.type == LogicalOperatorType::LOGICAL_POSITIONAL_JOIN); + + // first propagate through the children of the join + node_stats = PropagateStatistics(join.children[0]); + for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) { + auto child_stats = PropagateStatistics(join.children[child_idx]); + if (!child_stats) { + node_stats = nullptr; + } else if (node_stats) { + if (!node_stats->has_estimated_cardinality || !child_stats->has_estimated_cardinality || + !node_stats->has_max_cardinality || !child_stats->has_max_cardinality) { + node_stats = nullptr; + } else { + MaxCardinalities(node_stats, *child_stats); + } + } + } + + // No conditions. + + // Positional Joins are always FULL OUTER + + // set IsNull() to true for all lhs statistics + auto left_bindings = join.children[0]->GetColumnBindings(); + for (auto &binding : left_bindings) { + auto stats = statistics_map.find(binding); + if (stats != statistics_map.end()) { + stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + } + + // set IsNull() to true for all rhs statistics + auto right_bindings = join.children[1]->GetColumnBindings(); + for (auto &binding : right_bindings) { + auto stats = statistics_map.find(binding); + if (stats != statistics_map.end()) { + stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + } + + return std::move(node_stats); +} + +void StatisticsPropagator::CreateFilterFromJoinStats(unique_ptr &child, unique_ptr &expr, + const BaseStatistics &stats_before, + const BaseStatistics &stats_after) { + // Only do this for integral colref's that have stats + if (expr->type != ExpressionType::BOUND_COLUMN_REF || !expr->return_type.IsIntegral() || + !NumericStats::HasMinMax(stats_before) || !NumericStats::HasMinMax(stats_after)) { + return; + } + + // Retrieve min/max + auto min_before = NumericStats::Min(stats_before); + auto max_before = NumericStats::Max(stats_before); + auto min_after = NumericStats::Min(stats_after); + auto max_after = NumericStats::Max(stats_after); + + vector> filter_exprs; + if (min_after > min_before) { + filter_exprs.emplace_back( + make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, expr->Copy(), + make_uniq(std::move(min_after)))); + } + if (max_after < max_before) { + filter_exprs.emplace_back( + make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, expr->Copy(), + make_uniq(std::move(max_after)))); + } + + if (filter_exprs.empty()) { + return; + } + + auto filter = make_uniq(); + filter->children.emplace_back(std::move(child)); + child = std::move(filter); + + for (auto &filter_expr : filter_exprs) { + child->expressions.emplace_back(std::move(filter_expr)); + } + + FilterPushdown filter_pushdown(optimizer); + child = filter_pushdown.Rewrite(std::move(child)); + PropagateExpression(expr); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_limit.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_limit.cpp new file mode 100644 index 00000000..17855bfc --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_limit.cpp @@ -0,0 +1,14 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalLimit &limit, + unique_ptr *node_ptr) { + // propagate statistics in the child node + PropagateStatistics(limit.children[0]); + // return the node stats, with as expected cardinality the amount specified in the limit + return make_uniq(limit.limit_val, limit.limit_val); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_order.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_order.cpp new file mode 100644 index 00000000..5770fcc8 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_order.cpp @@ -0,0 +1,19 @@ + +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/operator/logical_order.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalOrder &order, + unique_ptr *node_ptr) { + // first propagate to the child + node_stats = PropagateStatistics(order.children[0]); + + // then propagate to each of the order expressions + for (auto &bound_order : order.orders) { + bound_order.stats = PropagateExpression(bound_order.expression); + } + return std::move(node_stats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_projection.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_projection.cpp new file mode 100644 index 00000000..83c83095 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_projection.cpp @@ -0,0 +1,25 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalProjection &proj, + unique_ptr *node_ptr) { + // first propagate to the child + node_stats = PropagateStatistics(proj.children[0]); + if (proj.children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { + ReplaceWithEmptyResult(*node_ptr); + return std::move(node_stats); + } + // then propagate to each of the expressions + for (idx_t i = 0; i < proj.expressions.size(); i++) { + auto stats = PropagateExpression(proj.expressions[i]); + if (stats) { + ColumnBinding binding(proj.table_index, i); + statistics_map.insert(make_pair(binding, std::move(stats))); + } + } + return std::move(node_stats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_set_operation.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_set_operation.cpp new file mode 100644 index 00000000..8b90aeb6 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_set_operation.cpp @@ -0,0 +1,78 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" + +namespace duckdb { + +void StatisticsPropagator::AddCardinalities(unique_ptr &stats, NodeStatistics &new_stats) { + if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || + !new_stats.has_max_cardinality) { + stats = nullptr; + return; + } + stats->estimated_cardinality += new_stats.estimated_cardinality; + auto new_max = Hugeint::Add(stats->max_cardinality, new_stats.max_cardinality); + if (new_max < NumericLimits::Maximum()) { + int64_t result; + if (!Hugeint::TryCast(new_max, result)) { + throw InternalException("Overflow in cast in statistics propagation"); + } + D_ASSERT(result >= 0); + stats->max_cardinality = idx_t(result); + } else { + stats = nullptr; + } +} + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalSetOperation &setop, + unique_ptr *node_ptr) { + // first propagate statistics in the child nodes + auto left_stats = PropagateStatistics(setop.children[0]); + auto right_stats = PropagateStatistics(setop.children[1]); + + // now fetch the column bindings on both sides + auto left_bindings = setop.children[0]->GetColumnBindings(); + auto right_bindings = setop.children[1]->GetColumnBindings(); + + D_ASSERT(left_bindings.size() == right_bindings.size()); + D_ASSERT(left_bindings.size() == setop.column_count); + for (idx_t i = 0; i < setop.column_count; i++) { + // for each column binding, we fetch the statistics from both the lhs and the rhs + auto left_entry = statistics_map.find(left_bindings[i]); + auto right_entry = statistics_map.find(right_bindings[i]); + if (left_entry == statistics_map.end() || right_entry == statistics_map.end()) { + // no statistics on one of the sides: can't propagate stats + continue; + } + unique_ptr new_stats; + switch (setop.type) { + case LogicalOperatorType::LOGICAL_UNION: + // union: merge the stats of the LHS and RHS together + new_stats = left_entry->second->ToUnique(); + new_stats->Merge(*right_entry->second); + break; + case LogicalOperatorType::LOGICAL_EXCEPT: + // except: use the stats of the LHS + new_stats = left_entry->second->ToUnique(); + break; + case LogicalOperatorType::LOGICAL_INTERSECT: + // intersect: intersect the two stats + // FIXME: for now we just use the stats of the LHS, as this is correct + // however, the stats can be further refined to the minimal subset of the LHS and RHS + new_stats = left_entry->second->ToUnique(); + break; + default: + throw InternalException("Unsupported setop type"); + } + ColumnBinding binding(setop.table_index, i); + statistics_map[binding] = std::move(new_stats); + } + if (!left_stats || !right_stats) { + return nullptr; + } + if (setop.type == LogicalOperatorType::LOGICAL_UNION) { + AddCardinalities(left_stats, *right_stats); + } + return left_stats; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_window.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_window.cpp new file mode 100644 index 00000000..1bfd3fc0 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_window.cpp @@ -0,0 +1,25 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/operator/logical_window.hpp" + +namespace duckdb { + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalWindow &window, + unique_ptr *node_ptr) { + // first propagate to the child + node_stats = PropagateStatistics(window.children[0]); + + // then propagate to each of the order expressions + for (auto &window_expr : window.expressions) { + auto over_expr = reinterpret_cast(window_expr.get()); + for (auto &expr : over_expr->partitions) { + over_expr->partitions_stats.push_back(PropagateExpression(expr)); + } + for (auto &bound_order : over_expr->orders) { + bound_order.stats = PropagateExpression(bound_order.expression); + } + } + return std::move(node_stats); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics_propagator.cpp b/src/duckdb/src/optimizer/statistics_propagator.cpp new file mode 100644 index 00000000..70169723 --- /dev/null +++ b/src/duckdb/src/optimizer/statistics_propagator.cpp @@ -0,0 +1,114 @@ +#include "duckdb/optimizer/statistics_propagator.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_join.hpp" +#include "duckdb/planner/operator/logical_order.hpp" +#include "duckdb/planner/operator/logical_positional_join.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/operator/logical_window.hpp" + +namespace duckdb { + +StatisticsPropagator::StatisticsPropagator(Optimizer &optimizer_p) + : optimizer(optimizer_p), context(optimizer.context) { +} + +void StatisticsPropagator::ReplaceWithEmptyResult(unique_ptr &node) { + node = make_uniq(std::move(node)); +} + +unique_ptr StatisticsPropagator::PropagateChildren(LogicalOperator &node, + unique_ptr *node_ptr) { + for (idx_t child_idx = 0; child_idx < node.children.size(); child_idx++) { + PropagateStatistics(node.children[child_idx]); + } + return nullptr; +} + +unique_ptr StatisticsPropagator::PropagateStatistics(LogicalOperator &node, + unique_ptr *node_ptr) { + switch (node.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_FILTER: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_GET: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_PROJECTION: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_ORDER_BY: + return PropagateStatistics(node.Cast(), node_ptr); + case LogicalOperatorType::LOGICAL_WINDOW: + return PropagateStatistics(node.Cast(), node_ptr); + default: + return PropagateChildren(node, node_ptr); + } +} + +unique_ptr StatisticsPropagator::PropagateStatistics(unique_ptr &node_ptr) { + return PropagateStatistics(*node_ptr, &node_ptr); +} + +unique_ptr StatisticsPropagator::PropagateExpression(Expression &expr, + unique_ptr *expr_ptr) { + switch (expr.GetExpressionClass()) { + case ExpressionClass::BOUND_AGGREGATE: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_BETWEEN: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_CASE: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_CONJUNCTION: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_FUNCTION: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_CAST: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_COMPARISON: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_CONSTANT: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_COLUMN_REF: + return PropagateExpression(expr.Cast(), expr_ptr); + case ExpressionClass::BOUND_OPERATOR: + return PropagateExpression(expr.Cast(), expr_ptr); + default: + break; + } + ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &child) { PropagateExpression(child); }); + return nullptr; +} + +unique_ptr StatisticsPropagator::PropagateExpression(unique_ptr &expr) { + auto stats = PropagateExpression(*expr, &expr); + if (ClientConfig::GetConfig(context).query_verification_enabled && stats) { + expr->verification_stats = stats->ToUnique(); + } + return stats; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/topn_optimizer.cpp b/src/duckdb/src/optimizer/topn_optimizer.cpp new file mode 100644 index 00000000..4e7b59ad --- /dev/null +++ b/src/duckdb/src/optimizer/topn_optimizer.cpp @@ -0,0 +1,47 @@ +#include "duckdb/optimizer/topn_optimizer.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" +#include "duckdb/planner/operator/logical_order.hpp" +#include "duckdb/planner/operator/logical_top_n.hpp" + +namespace duckdb { + +bool TopN::CanOptimize(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_LIMIT && + op.children[0]->type == LogicalOperatorType::LOGICAL_ORDER_BY) { + auto &limit = op.Cast(); + + // When there are some expressions in the limit operator, + // we shouldn't use this optimizations. Because of the expressions + // will be lost when it convert to TopN operator. + if (limit.limit || limit.offset) { + return false; + } + + // This optimization doesn't apply when OFFSET is present without LIMIT + // Or if offset is not constant + if (limit.limit_val != NumericLimits::Maximum() || limit.offset) { + return true; + } + } + return false; +} + +unique_ptr TopN::Optimize(unique_ptr op) { + if (CanOptimize(*op)) { + auto &limit = op->Cast(); + auto &order_by = (op->children[0])->Cast(); + + auto topn = make_uniq(std::move(order_by.orders), limit.limit_val, limit.offset_val); + topn->AddChild(std::move(order_by.children[0])); + op = std::move(topn); + } else { + for (auto &child : op->children) { + child = Optimize(std::move(child)); + } + } + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/unnest_rewriter.cpp b/src/duckdb/src/optimizer/unnest_rewriter.cpp new file mode 100644 index 00000000..4c935058 --- /dev/null +++ b/src/duckdb/src/optimizer/unnest_rewriter.cpp @@ -0,0 +1,329 @@ +#include "duckdb/optimizer/unnest_rewriter.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/planner/operator/logical_delim_get.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_unnest.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +namespace duckdb { + +void UnnestRewriterPlanUpdater::VisitOperator(LogicalOperator &op) { + VisitOperatorChildren(op); + VisitOperatorExpressions(op); +} + +void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr *expression) { + auto &expr = *expression; + + if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) { + auto &bound_column_ref = expr->Cast(); + for (idx_t i = 0; i < replace_bindings.size(); i++) { + if (bound_column_ref.binding == replace_bindings[i].old_binding) { + bound_column_ref.binding = replace_bindings[i].new_binding; + break; + } + } + } + + VisitExpressionChildren(**expression); +} + +unique_ptr UnnestRewriter::Optimize(unique_ptr op) { + + UnnestRewriterPlanUpdater updater; + vector *> candidates; + FindCandidates(&op, candidates); + + // rewrite the plan and update the bindings + for (auto &candidate : candidates) { + + // rearrange the logical operators + if (RewriteCandidate(candidate)) { + updater.overwritten_tbl_idx = overwritten_tbl_idx; + // update the bindings of the BOUND_UNNEST expression + UpdateBoundUnnestBindings(updater, candidate); + // update the sequence of LOGICAL_PROJECTION(s) + UpdateRHSBindings(&op, candidate, updater); + // reset + delim_columns.clear(); + lhs_bindings.clear(); + } + } + + return op; +} + +void UnnestRewriter::FindCandidates(unique_ptr *op_ptr, + vector *> &candidates) { + auto op = op_ptr->get(); + // search children before adding, so that we add candidates bottom-up + for (auto &child : op->children) { + FindCandidates(&child, candidates); + } + + // search for operator that has a LOGICAL_DELIM_JOIN as its child + if (op->children.size() != 1) { + return; + } + if (op->children[0]->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { + return; + } + + // found a delim join + auto &delim_join = op->children[0]->Cast(); + // only support INNER delim joins + if (delim_join.join_type != JoinType::INNER) { + return; + } + // INNER delim join must have exactly one condition + if (delim_join.conditions.size() != 1) { + return; + } + + // LHS child is a window + if (delim_join.children[0]->type != LogicalOperatorType::LOGICAL_WINDOW) { + return; + } + + // RHS child must be projection(s) followed by an UNNEST + auto curr_op = &delim_join.children[1]; + while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { + if (curr_op->get()->children.size() != 1) { + break; + } + curr_op = &curr_op->get()->children[0]; + } + + if (curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST) { + candidates.push_back(op_ptr); + } +} + +bool UnnestRewriter::RewriteCandidate(unique_ptr *candidate) { + + auto &topmost_op = (LogicalOperator &)**candidate; + if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION && + topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW && + topmost_op.type != LogicalOperatorType::LOGICAL_FILTER && + topmost_op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY && + topmost_op.type != LogicalOperatorType::LOGICAL_UNNEST) { + return false; + } + + // get the LOGICAL_DELIM_JOIN, which is a child of the candidate + D_ASSERT(topmost_op.children.size() == 1); + auto &delim_join = *(topmost_op.children[0]); + D_ASSERT(delim_join.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); + GetDelimColumns(delim_join); + + // LHS of the LOGICAL_DELIM_JOIN is a LOGICAL_WINDOW that contains a LOGICAL_PROJECTION + // this lhs_proj later becomes the child of the UNNEST + auto &window = *delim_join.children[0]; + auto &lhs_op = window.children[0]; + GetLHSExpressions(*lhs_op); + + // find the LOGICAL_UNNEST + // and get the path down to the LOGICAL_UNNEST + vector *> path_to_unnest; + auto curr_op = &(delim_join.children[1]); + while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { + path_to_unnest.push_back(curr_op); + curr_op = &curr_op->get()->children[0]; + } + + // store the table index of the child of the LOGICAL_UNNEST + // then update the plan by making the lhs_proj the child of the LOGICAL_UNNEST + D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); + auto &unnest = curr_op->get()->Cast(); + D_ASSERT(unnest.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET); + overwritten_tbl_idx = unnest.children[0]->Cast().table_index; + + D_ASSERT(!unnest.children.empty()); + auto &delim_get = unnest.children[0]->Cast(); + D_ASSERT(delim_get.chunk_types.size() > 1); + distinct_unnest_count = delim_get.chunk_types.size(); + unnest.children[0] = std::move(lhs_op); + + // replace the LOGICAL_DELIM_JOIN with its RHS child operator + topmost_op.children[0] = std::move(*path_to_unnest.front()); + return true; +} + +void UnnestRewriter::UpdateRHSBindings(unique_ptr *plan_ptr, unique_ptr *candidate, + UnnestRewriterPlanUpdater &updater) { + + auto &topmost_op = (LogicalOperator &)**candidate; + idx_t shift = lhs_bindings.size(); + + vector *> path_to_unnest; + auto curr_op = &(topmost_op.children[0]); + while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { + + path_to_unnest.push_back(curr_op); + D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); + auto &proj = curr_op->get()->Cast(); + + // pop the unnest columns and the delim index + D_ASSERT(proj.expressions.size() > distinct_unnest_count); + for (idx_t i = 0; i < distinct_unnest_count; i++) { + proj.expressions.pop_back(); + } + + // store all shifted current bindings + idx_t tbl_idx = proj.table_index; + for (idx_t i = 0; i < proj.expressions.size(); i++) { + ReplaceBinding replace_binding(ColumnBinding(tbl_idx, i), ColumnBinding(tbl_idx, i + shift)); + updater.replace_bindings.push_back(replace_binding); + } + + curr_op = &curr_op->get()->children[0]; + } + + // update all bindings by shifting them + updater.VisitOperator(*plan_ptr->get()); + updater.replace_bindings.clear(); + + // update all bindings coming from the LHS to RHS bindings + D_ASSERT(topmost_op.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); + auto &top_proj = topmost_op.children[0]->Cast(); + for (idx_t i = 0; i < lhs_bindings.size(); i++) { + ReplaceBinding replace_binding(lhs_bindings[i].binding, ColumnBinding(top_proj.table_index, i)); + updater.replace_bindings.push_back(replace_binding); + } + + // temporarily remove the BOUND_UNNESTs and the child of the LOGICAL_UNNEST from the plan + D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); + auto &unnest = curr_op->get()->Cast(); + vector> temp_bound_unnests; + for (auto &temp_bound_unnest : unnest.expressions) { + temp_bound_unnests.push_back(std::move(temp_bound_unnest)); + } + D_ASSERT(unnest.children.size() == 1); + auto temp_unnest_child = std::move(unnest.children[0]); + unnest.expressions.clear(); + unnest.children.clear(); + // update the bindings of the plan + updater.VisitOperator(*plan_ptr->get()); + updater.replace_bindings.clear(); + // add the children again + for (auto &temp_bound_unnest : temp_bound_unnests) { + unnest.expressions.push_back(std::move(temp_bound_unnest)); + } + unnest.children.push_back(std::move(temp_unnest_child)); + + // add the LHS expressions to each LOGICAL_PROJECTION + for (idx_t i = path_to_unnest.size(); i > 0; i--) { + + D_ASSERT(path_to_unnest[i - 1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); + auto &proj = path_to_unnest[i - 1]->get()->Cast(); + + // temporarily store the existing expressions + vector> existing_expressions; + for (idx_t expr_idx = 0; expr_idx < proj.expressions.size(); expr_idx++) { + existing_expressions.push_back(std::move(proj.expressions[expr_idx])); + } + + proj.expressions.clear(); + + // add the new expressions + for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) { + auto new_expr = make_uniq( + lhs_bindings[expr_idx].alias, lhs_bindings[expr_idx].type, lhs_bindings[expr_idx].binding); + proj.expressions.push_back(std::move(new_expr)); + + // update the table index + lhs_bindings[expr_idx].binding.table_index = proj.table_index; + lhs_bindings[expr_idx].binding.column_index = expr_idx; + } + + // add the existing expressions again + for (idx_t expr_idx = 0; expr_idx < existing_expressions.size(); expr_idx++) { + proj.expressions.push_back(std::move(existing_expressions[expr_idx])); + } + } +} + +void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, + unique_ptr *candidate) { + + auto &topmost_op = (LogicalOperator &)**candidate; + + // traverse LOGICAL_PROJECTION(s) + auto curr_op = &(topmost_op.children[0]); + while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { + curr_op = &curr_op->get()->children[0]; + } + + // found the LOGICAL_UNNEST + D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); + auto &unnest = curr_op->get()->Cast(); + + D_ASSERT(unnest.children.size() == 1); + auto unnest_cols = unnest.children[0]->GetColumnBindings(); + + for (idx_t i = 0; i < delim_columns.size(); i++) { + auto delim_binding = delim_columns[i]; + + auto unnest_it = unnest_cols.begin(); + while (unnest_it != unnest_cols.end()) { + auto unnest_binding = *unnest_it; + + if (delim_binding.table_index == unnest_binding.table_index) { + unnest_binding.table_index = overwritten_tbl_idx; + unnest_binding.column_index++; + updater.replace_bindings.emplace_back(unnest_binding, delim_binding); + unnest_cols.erase(unnest_it); + break; + } + unnest_it++; + } + } + + // update bindings + for (auto &unnest_expr : unnest.expressions) { + updater.VisitExpression(&unnest_expr); + } + updater.replace_bindings.clear(); +} + +void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { + + D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); + auto &delim_join = op.Cast(); + for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { + auto &expr = *delim_join.duplicate_eliminated_columns[i]; + D_ASSERT(expr.type == ExpressionType::BOUND_COLUMN_REF); + auto &bound_colref_expr = expr.Cast(); + delim_columns.push_back(bound_colref_expr.binding); + } +} + +void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) { + + op.ResolveOperatorTypes(); + auto col_bindings = op.GetColumnBindings(); + D_ASSERT(op.types.size() == col_bindings.size()); + + bool set_alias = false; + // we can easily extract the alias for LOGICAL_PROJECTION(s) + if (op.type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &proj = op.Cast(); + if (proj.expressions.size() == op.types.size()) { + set_alias = true; + } + } + + for (idx_t i = 0; i < op.types.size(); i++) { + lhs_bindings.emplace_back(col_bindings[i], op.types[i]); + if (set_alias) { + auto &proj = op.Cast(); + lhs_bindings.back().alias = proj.expressions[i]->alias; + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/base_pipeline_event.cpp b/src/duckdb/src/parallel/base_pipeline_event.cpp new file mode 100644 index 00000000..347f114f --- /dev/null +++ b/src/duckdb/src/parallel/base_pipeline_event.cpp @@ -0,0 +1,13 @@ +#include "duckdb/parallel/base_pipeline_event.hpp" + +namespace duckdb { + +BasePipelineEvent::BasePipelineEvent(shared_ptr pipeline_p) + : Event(pipeline_p->executor), pipeline(std::move(pipeline_p)) { +} + +BasePipelineEvent::BasePipelineEvent(Pipeline &pipeline_p) + : Event(pipeline_p.executor), pipeline(pipeline_p.shared_from_this()) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/event.cpp b/src/duckdb/src/parallel/event.cpp new file mode 100644 index 00000000..0f51b41f --- /dev/null +++ b/src/duckdb/src/parallel/event.cpp @@ -0,0 +1,85 @@ +#include "duckdb/parallel/event.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/execution/executor.hpp" + +namespace duckdb { + +Event::Event(Executor &executor_p) + : executor(executor_p), finished_tasks(0), total_tasks(0), finished_dependencies(0), total_dependencies(0), + finished(false) { +} + +void Event::CompleteDependency() { + idx_t current_finished = ++finished_dependencies; + D_ASSERT(current_finished <= total_dependencies); + if (current_finished == total_dependencies) { + // all dependencies have been completed: schedule the event + D_ASSERT(total_tasks == 0); + Schedule(); + if (total_tasks == 0) { + Finish(); + } + } +} + +void Event::Finish() { + D_ASSERT(!finished); + FinishEvent(); + finished = true; + // finished processing the pipeline, now we can schedule pipelines that depend on this pipeline + for (auto &parent_entry : parents) { + auto parent = parent_entry.lock(); + if (!parent) { // LCOV_EXCL_START + continue; + } // LCOV_EXCL_STOP + // mark a dependency as completed for each of the parents + parent->CompleteDependency(); + } + FinalizeFinish(); +} + +void Event::AddDependency(Event &event) { + total_dependencies++; + event.parents.push_back(weak_ptr(shared_from_this())); +#ifdef DEBUG + event.parents_raw.push_back(this); +#endif +} + +const vector &Event::GetParentsVerification() const { + D_ASSERT(parents.size() == parents_raw.size()); + return parents_raw; +} + +void Event::FinishTask() { + D_ASSERT(finished_tasks.load() < total_tasks.load()); + idx_t current_tasks = total_tasks; + idx_t current_finished = ++finished_tasks; + D_ASSERT(current_finished <= current_tasks); + if (current_finished == current_tasks) { + Finish(); + } +} + +void Event::InsertEvent(shared_ptr replacement_event) { + replacement_event->parents = std::move(parents); +#ifdef DEBUG + replacement_event->parents_raw = std::move(parents_raw); +#endif + replacement_event->AddDependency(*this); + executor.AddEvent(std::move(replacement_event)); +} + +void Event::SetTasks(vector> tasks) { + auto &ts = TaskScheduler::GetScheduler(executor.context); + D_ASSERT(total_tasks == 0); + D_ASSERT(!tasks.empty()); + this->total_tasks = tasks.size(); + for (auto &task : tasks) { + ts.ScheduleTask(executor.GetToken(), std::move(task)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/executor.cpp b/src/duckdb/src/parallel/executor.cpp new file mode 100644 index 00000000..4a55c87c --- /dev/null +++ b/src/duckdb/src/parallel/executor.cpp @@ -0,0 +1,638 @@ +#include "duckdb/execution/executor.hpp" + +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/execution/operator/helper/physical_result_collector.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" +#include "duckdb/execution/operator/set/physical_cte.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline_complete_event.hpp" +#include "duckdb/parallel/pipeline_event.hpp" +#include "duckdb/parallel/pipeline_executor.hpp" +#include "duckdb/parallel/pipeline_finish_event.hpp" +#include "duckdb/parallel/pipeline_initialize_event.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/parallel/thread_context.hpp" + +#include + +namespace duckdb { + +Executor::Executor(ClientContext &context) : context(context) { +} + +Executor::~Executor() { +} + +Executor &Executor::Get(ClientContext &context) { + return context.GetExecutor(); +} + +void Executor::AddEvent(shared_ptr event) { + lock_guard elock(executor_lock); + if (cancelled) { + return; + } + events.push_back(std::move(event)); +} + +struct PipelineEventStack { + PipelineEventStack(Event &pipeline_initialize_event, Event &pipeline_event, Event &pipeline_finish_event, + Event &pipeline_complete_event) + : pipeline_initialize_event(pipeline_initialize_event), pipeline_event(pipeline_event), + pipeline_finish_event(pipeline_finish_event), pipeline_complete_event(pipeline_complete_event) { + } + + Event &pipeline_initialize_event; + Event &pipeline_event; + Event &pipeline_finish_event; + Event &pipeline_complete_event; +}; + +using event_map_t = reference_map_t; + +struct ScheduleEventData { + ScheduleEventData(const vector> &meta_pipelines, vector> &events, + bool initial_schedule) + : meta_pipelines(meta_pipelines), events(events), initial_schedule(initial_schedule) { + } + + const vector> &meta_pipelines; + vector> &events; + bool initial_schedule; + event_map_t event_map; +}; + +void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, ScheduleEventData &event_data) { + D_ASSERT(meta_pipeline); + auto &events = event_data.events; + auto &event_map = event_data.event_map; + + // create events/stack for the base pipeline + auto base_pipeline = meta_pipeline->GetBasePipeline(); + auto base_initialize_event = make_shared(base_pipeline); + auto base_event = make_shared(base_pipeline); + auto base_finish_event = make_shared(base_pipeline); + auto base_complete_event = make_shared(base_pipeline->executor, event_data.initial_schedule); + PipelineEventStack base_stack(*base_initialize_event, *base_event, *base_finish_event, *base_complete_event); + events.push_back(std::move(base_initialize_event)); + events.push_back(std::move(base_event)); + events.push_back(std::move(base_finish_event)); + events.push_back(std::move(base_complete_event)); + + // dependencies: initialize -> event -> finish -> complete + base_stack.pipeline_event.AddDependency(base_stack.pipeline_initialize_event); + base_stack.pipeline_finish_event.AddDependency(base_stack.pipeline_event); + base_stack.pipeline_complete_event.AddDependency(base_stack.pipeline_finish_event); + + // create an event and stack for all pipelines in the MetaPipeline + vector> pipelines; + meta_pipeline->GetPipelines(pipelines, false); + for (idx_t i = 1; i < pipelines.size(); i++) { // loop starts at 1 because 0 is the base pipeline + auto &pipeline = pipelines[i]; + D_ASSERT(pipeline); + + // create events/stack for this pipeline + auto pipeline_event = make_shared(pipeline); + + auto finish_group = meta_pipeline->GetFinishGroup(pipeline.get()); + if (finish_group) { + // this pipeline is part of a finish group + const auto group_entry = event_map.find(*finish_group.get()); + D_ASSERT(group_entry != event_map.end()); + auto &group_stack = group_entry->second; + PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, + group_stack.pipeline_finish_event, base_stack.pipeline_complete_event); + + // dependencies: base_finish -> pipeline_event -> group_finish + pipeline_stack.pipeline_event.AddDependency(base_stack.pipeline_finish_event); + group_stack.pipeline_finish_event.AddDependency(pipeline_stack.pipeline_event); + + // add pipeline stack to event map + event_map.insert(make_pair(reference(*pipeline), pipeline_stack)); + } else if (meta_pipeline->HasFinishEvent(pipeline.get())) { + // this pipeline has its own finish event (despite going into the same sink - Finalize twice!) + auto pipeline_finish_event = make_shared(pipeline); + PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, + *pipeline_finish_event, base_stack.pipeline_complete_event); + events.push_back(std::move(pipeline_finish_event)); + + // dependencies: base_finish -> pipeline_event -> pipeline_finish -> base_complete + pipeline_stack.pipeline_event.AddDependency(base_stack.pipeline_finish_event); + pipeline_stack.pipeline_finish_event.AddDependency(pipeline_stack.pipeline_event); + base_stack.pipeline_complete_event.AddDependency(pipeline_stack.pipeline_finish_event); + + // add pipeline stack to event map + event_map.insert(make_pair(reference(*pipeline), pipeline_stack)); + + } else { + // no additional finish event + PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, + base_stack.pipeline_finish_event, base_stack.pipeline_complete_event); + + // dependencies: base_initialize -> pipeline_event -> base_finish + pipeline_stack.pipeline_event.AddDependency(base_stack.pipeline_initialize_event); + base_stack.pipeline_finish_event.AddDependency(pipeline_stack.pipeline_event); + + // add pipeline stack to event map + event_map.insert(make_pair(reference(*pipeline), pipeline_stack)); + } + events.push_back(std::move(pipeline_event)); + } + + // add base stack to the event data too + event_map.insert(make_pair(reference(*base_pipeline), base_stack)); + + // set up the dependencies within this MetaPipeline + for (auto &pipeline : pipelines) { + auto source = pipeline->GetSource(); + if (source->type == PhysicalOperatorType::TABLE_SCAN) { + // we have to reset the source here (in the main thread), because some of our clients (looking at you, R) + // do not like it when threads other than the main thread call into R, for e.g., arrow scans + pipeline->ResetSource(true); + } + + auto dependencies = meta_pipeline->GetDependencies(pipeline.get()); + if (!dependencies) { + continue; + } + auto root_entry = event_map.find(*pipeline); + D_ASSERT(root_entry != event_map.end()); + auto &pipeline_stack = root_entry->second; + for (auto &dependency : *dependencies) { + auto event_entry = event_map.find(*dependency); + D_ASSERT(event_entry != event_map.end()); + auto &dependency_stack = event_entry->second; + pipeline_stack.pipeline_event.AddDependency(dependency_stack.pipeline_event); + } + } +} + +void Executor::ScheduleEventsInternal(ScheduleEventData &event_data) { + auto &events = event_data.events; + D_ASSERT(events.empty()); + + // create all the required pipeline events + for (auto &pipeline : event_data.meta_pipelines) { + SchedulePipeline(pipeline, event_data); + } + + // set up the dependencies across MetaPipelines + auto &event_map = event_data.event_map; + for (auto &entry : event_map) { + auto &pipeline = entry.first.get(); + for (auto &dependency : pipeline.dependencies) { + auto dep = dependency.lock(); + D_ASSERT(dep); + auto event_map_entry = event_map.find(*dep); + D_ASSERT(event_map_entry != event_map.end()); + auto &dep_entry = event_map_entry->second; + entry.second.pipeline_event.AddDependency(dep_entry.pipeline_complete_event); + } + } + + // verify that we have no cyclic dependencies + VerifyScheduledEvents(event_data); + + // schedule the pipelines that do not have dependencies + for (auto &event : events) { + if (!event->HasDependencies()) { + event->Schedule(); + } + } +} + +void Executor::ScheduleEvents(const vector> &meta_pipelines) { + ScheduleEventData event_data(meta_pipelines, events, true); + ScheduleEventsInternal(event_data); +} + +void Executor::VerifyScheduledEvents(const ScheduleEventData &event_data) { +#ifdef DEBUG + const idx_t count = event_data.events.size(); + vector vertices; + vertices.reserve(count); + for (const auto &event : event_data.events) { + vertices.push_back(event.get()); + } + vector visited(count, false); + vector recursion_stack(count, false); + for (idx_t i = 0; i < count; i++) { + VerifyScheduledEventsInternal(i, vertices, visited, recursion_stack); + } +#endif +} + +void Executor::VerifyScheduledEventsInternal(const idx_t vertex, const vector &vertices, vector &visited, + vector &recursion_stack) { + D_ASSERT(!recursion_stack[vertex]); // this vertex is in the recursion stack: circular dependency! + if (visited[vertex]) { + return; // early out: we already visited this vertex + } + + auto &parents = vertices[vertex]->GetParentsVerification(); + if (parents.empty()) { + return; // early out: outgoing edges + } + + // create a vector the indices of the adjacent events + vector adjacent; + const idx_t count = vertices.size(); + for (auto parent : parents) { + idx_t i; + for (i = 0; i < count; i++) { + if (vertices[i] == parent) { + adjacent.push_back(i); + break; + } + } + D_ASSERT(i != count); // dependency must be in there somewhere + } + + // mark vertex as visited and add to recursion stack + visited[vertex] = true; + recursion_stack[vertex] = true; + + // recurse into adjacent vertices + for (const auto &i : adjacent) { + VerifyScheduledEventsInternal(i, vertices, visited, recursion_stack); + } + + // remove vertex from recursion stack + recursion_stack[vertex] = false; +} + +void Executor::AddRecursiveCTE(PhysicalOperator &rec_cte) { + recursive_ctes.push_back(rec_cte); +} + +void Executor::AddMaterializedCTE(PhysicalOperator &mat_cte) { + materialized_ctes.push_back(mat_cte); +} + +void Executor::ReschedulePipelines(const vector> &pipelines_p, + vector> &events_p) { + ScheduleEventData event_data(pipelines_p, events_p, false); + ScheduleEventsInternal(event_data); +} + +bool Executor::NextExecutor() { + if (root_pipeline_idx >= root_pipelines.size()) { + return false; + } + root_pipelines[root_pipeline_idx]->Reset(); + root_executor = make_uniq(context, *root_pipelines[root_pipeline_idx]); + root_pipeline_idx++; + return true; +} + +void Executor::VerifyPipeline(Pipeline &pipeline) { + D_ASSERT(!pipeline.ToString().empty()); + auto operators = pipeline.GetOperators(); + for (auto &other_pipeline : pipelines) { + auto other_operators = other_pipeline->GetOperators(); + for (idx_t op_idx = 0; op_idx < operators.size(); op_idx++) { + for (idx_t other_idx = 0; other_idx < other_operators.size(); other_idx++) { + auto &left = operators[op_idx].get(); + auto &right = other_operators[other_idx].get(); + if (left.Equals(right)) { + D_ASSERT(right.Equals(left)); + } else { + D_ASSERT(!right.Equals(left)); + } + } + } + } +} + +void Executor::VerifyPipelines() { +#ifdef DEBUG + for (auto &pipeline : pipelines) { + VerifyPipeline(*pipeline); + } +#endif +} + +void Executor::Initialize(unique_ptr physical_plan) { + Reset(); + owned_plan = std::move(physical_plan); + InitializeInternal(*owned_plan); +} + +void Executor::Initialize(PhysicalOperator &plan) { + Reset(); + InitializeInternal(plan); +} + +void Executor::InitializeInternal(PhysicalOperator &plan) { + + auto &scheduler = TaskScheduler::GetScheduler(context); + { + lock_guard elock(executor_lock); + physical_plan = &plan; + + this->profiler = ClientData::Get(context).profiler; + profiler->Initialize(plan); + this->producer = scheduler.CreateProducer(); + + // build and ready the pipelines + PipelineBuildState state; + auto root_pipeline = make_shared(*this, state, nullptr); + root_pipeline->Build(*physical_plan); + root_pipeline->Ready(); + + // ready recursive cte pipelines too + for (auto &rec_cte_ref : recursive_ctes) { + auto &rec_cte = rec_cte_ref.get().Cast(); + rec_cte.recursive_meta_pipeline->Ready(); + } + + // ready materialized cte pipelines too + for (auto &mat_cte_ref : materialized_ctes) { + auto &mat_cte = mat_cte_ref.get().Cast(); + mat_cte.recursive_meta_pipeline->Ready(); + } + + // set root pipelines, i.e., all pipelines that end in the final sink + root_pipeline->GetPipelines(root_pipelines, false); + root_pipeline_idx = 0; + + // collect all meta-pipelines from the root pipeline + vector> to_schedule; + root_pipeline->GetMetaPipelines(to_schedule, true, true); + + // number of 'PipelineCompleteEvent's is equal to the number of meta pipelines, so we have to set it here + total_pipelines = to_schedule.size(); + + // collect all pipelines from the root pipelines (recursively) for the progress bar and verify them + root_pipeline->GetPipelines(pipelines, true); + + // finally, verify and schedule + VerifyPipelines(); + ScheduleEvents(to_schedule); + } +} + +void Executor::CancelTasks() { + task.reset(); + // we do this by creating weak pointers to all pipelines + // then clearing our references to the pipelines + // and waiting until all pipelines have been destroyed + vector> weak_references; + { + lock_guard elock(executor_lock); + weak_references.reserve(pipelines.size()); + cancelled = true; + for (auto &pipeline : pipelines) { + weak_references.push_back(weak_ptr(pipeline)); + } + for (auto &rec_cte_ref : recursive_ctes) { + auto &rec_cte = rec_cte_ref.get().Cast(); + rec_cte.recursive_meta_pipeline.reset(); + } + for (auto &mat_cte_ref : materialized_ctes) { + auto &mat_cte = mat_cte_ref.get().Cast(); + mat_cte.recursive_meta_pipeline.reset(); + } + pipelines.clear(); + root_pipelines.clear(); + to_be_rescheduled_tasks.clear(); + events.clear(); + } + WorkOnTasks(); + for (auto &weak_ref : weak_references) { + while (true) { + auto weak = weak_ref.lock(); + if (!weak) { + break; + } + } + } +} + +void Executor::WorkOnTasks() { + auto &scheduler = TaskScheduler::GetScheduler(context); + + shared_ptr task; + while (scheduler.GetTaskFromProducer(*producer, task)) { + auto res = task->Execute(TaskExecutionMode::PROCESS_ALL); + if (res == TaskExecutionResult::TASK_BLOCKED) { + task->Deschedule(); + } + task.reset(); + } +} + +void Executor::RescheduleTask(shared_ptr &task) { + // This function will spin lock until the task provided is added to the to_be_rescheduled_tasks + while (true) { + lock_guard l(executor_lock); + if (cancelled) { + return; + } + auto entry = to_be_rescheduled_tasks.find(task.get()); + if (entry != to_be_rescheduled_tasks.end()) { + auto &scheduler = TaskScheduler::GetScheduler(context); + to_be_rescheduled_tasks.erase(task.get()); + scheduler.ScheduleTask(GetToken(), task); + break; + } + } +} + +void Executor::AddToBeRescheduled(shared_ptr &task) { + lock_guard l(executor_lock); + if (cancelled) { + return; + } + if (to_be_rescheduled_tasks.find(task.get()) != to_be_rescheduled_tasks.end()) { + return; + } + to_be_rescheduled_tasks[task.get()] = std::move(task); +} + +bool Executor::ExecutionIsFinished() { + return completed_pipelines >= total_pipelines || HasError(); +} + +PendingExecutionResult Executor::ExecuteTask() { + // Only executor should return NO_TASKS_AVAILABLE + D_ASSERT(execution_result != PendingExecutionResult::NO_TASKS_AVAILABLE); + if (execution_result != PendingExecutionResult::RESULT_NOT_READY) { + return execution_result; + } + // check if there are any incomplete pipelines + auto &scheduler = TaskScheduler::GetScheduler(context); + while (completed_pipelines < total_pipelines) { + // there are! if we don't already have a task, fetch one + if (!task) { + scheduler.GetTaskFromProducer(*producer, task); + } + if (!task && !HasError()) { + // there are no tasks to be scheduled and there are tasks blocked + return PendingExecutionResult::NO_TASKS_AVAILABLE; + } + if (task) { + // if we have a task, partially process it + auto result = task->Execute(TaskExecutionMode::PROCESS_PARTIAL); + if (result == TaskExecutionResult::TASK_BLOCKED) { + task->Deschedule(); + task.reset(); + } else if (result == TaskExecutionResult::TASK_FINISHED) { + // if the task is finished, clean it up + task.reset(); + } + } + if (!HasError()) { + // we (partially) processed a task and no exceptions were thrown + // give back control to the caller + return PendingExecutionResult::RESULT_NOT_READY; + } + execution_result = PendingExecutionResult::EXECUTION_ERROR; + + // an exception has occurred executing one of the pipelines + // we need to cancel all tasks associated with this executor + CancelTasks(); + ThrowException(); + } + D_ASSERT(!task); + + lock_guard elock(executor_lock); + pipelines.clear(); + NextExecutor(); + if (HasError()) { // LCOV_EXCL_START + // an exception has occurred executing one of the pipelines + execution_result = PendingExecutionResult::EXECUTION_ERROR; + ThrowException(); + } // LCOV_EXCL_STOP + execution_result = PendingExecutionResult::RESULT_READY; + return execution_result; +} + +void Executor::Reset() { + lock_guard elock(executor_lock); + physical_plan = nullptr; + cancelled = false; + owned_plan.reset(); + root_executor.reset(); + root_pipelines.clear(); + root_pipeline_idx = 0; + completed_pipelines = 0; + total_pipelines = 0; + exceptions.clear(); + pipelines.clear(); + events.clear(); + to_be_rescheduled_tasks.clear(); + execution_result = PendingExecutionResult::RESULT_NOT_READY; +} + +shared_ptr Executor::CreateChildPipeline(Pipeline ¤t, PhysicalOperator &op) { + D_ASSERT(!current.operators.empty()); + D_ASSERT(op.IsSource()); + // found another operator that is a source, schedule a child pipeline + // 'op' is the source, and the sink is the same + auto child_pipeline = make_shared(*this); + child_pipeline->sink = current.sink; + child_pipeline->source = &op; + + // the child pipeline has the same operators up until 'op' + for (auto current_op : current.operators) { + if (¤t_op.get() == &op) { + break; + } + child_pipeline->operators.push_back(current_op); + } + + return child_pipeline; +} + +vector Executor::GetTypes() { + D_ASSERT(physical_plan); + return physical_plan->GetTypes(); +} + +void Executor::PushError(PreservedError exception) { + lock_guard elock(error_lock); + // interrupt execution of any other pipelines that belong to this executor + context.interrupted = true; + // push the exception onto the stack + exceptions.push_back(std::move(exception)); +} + +bool Executor::HasError() { + lock_guard elock(error_lock); + return !exceptions.empty(); +} + +void Executor::ThrowException() { + lock_guard elock(error_lock); + D_ASSERT(!exceptions.empty()); + auto &entry = exceptions[0]; + entry.Throw(); +} + +void Executor::Flush(ThreadContext &tcontext) { + profiler->Flush(tcontext.profiler); +} + +bool Executor::GetPipelinesProgress(double ¤t_progress) { // LCOV_EXCL_START + lock_guard elock(executor_lock); + + vector progress; + vector cardinality; + idx_t total_cardinality = 0; + for (auto &pipeline : pipelines) { + double child_percentage; + idx_t child_cardinality; + + if (!pipeline->GetProgress(child_percentage, child_cardinality)) { + return false; + } + progress.push_back(child_percentage); + cardinality.push_back(child_cardinality); + total_cardinality += child_cardinality; + } + current_progress = 0; + if (total_cardinality == 0) { + return true; + } + for (size_t i = 0; i < progress.size(); i++) { + current_progress += progress[i] * double(cardinality[i]) / double(total_cardinality); + } + return true; +} // LCOV_EXCL_STOP + +bool Executor::HasResultCollector() { + return physical_plan->type == PhysicalOperatorType::RESULT_COLLECTOR; +} + +unique_ptr Executor::GetResult() { + D_ASSERT(HasResultCollector()); + auto &result_collector = physical_plan->Cast(); + D_ASSERT(result_collector.sink_state); + return result_collector.GetResult(*result_collector.sink_state); +} + +unique_ptr Executor::FetchChunk() { + D_ASSERT(physical_plan); + + auto chunk = make_uniq(); + root_executor->InitializeChunk(*chunk); + while (true) { + root_executor->ExecutePull(*chunk); + if (chunk->size() == 0) { + root_executor->PullFinalize(); + if (NextExecutor()) { + continue; + } + break; + } else { + break; + } + } + return chunk; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/executor_task.cpp b/src/duckdb/src/parallel/executor_task.cpp new file mode 100644 index 00000000..dc1f0448 --- /dev/null +++ b/src/duckdb/src/parallel/executor_task.cpp @@ -0,0 +1,39 @@ +#include "duckdb/parallel/task.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +ExecutorTask::ExecutorTask(Executor &executor_p) : executor(executor_p) { +} + +ExecutorTask::ExecutorTask(ClientContext &context) : ExecutorTask(Executor::Get(context)) { +} + +ExecutorTask::~ExecutorTask() { +} + +void ExecutorTask::Deschedule() { + auto this_ptr = shared_from_this(); + executor.AddToBeRescheduled(this_ptr); +} + +void ExecutorTask::Reschedule() { + auto this_ptr = shared_from_this(); + executor.RescheduleTask(this_ptr); +} + +TaskExecutionResult ExecutorTask::Execute(TaskExecutionMode mode) { + try { + return ExecuteTask(mode); + } catch (Exception &ex) { + executor.PushError(PreservedError(ex)); + } catch (std::exception &ex) { + executor.PushError(PreservedError(ex)); + } catch (...) { // LCOV_EXCL_START + executor.PushError(PreservedError("Unknown exception in Finalize!")); + } // LCOV_EXCL_STOP + return TaskExecutionResult::TASK_ERROR; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/interrupt.cpp b/src/duckdb/src/parallel/interrupt.cpp new file mode 100644 index 00000000..45948c41 --- /dev/null +++ b/src/duckdb/src/parallel/interrupt.cpp @@ -0,0 +1,57 @@ +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/mutex.hpp" +#include + +namespace duckdb { + +InterruptState::InterruptState() : mode(InterruptMode::NO_INTERRUPTS) { +} +InterruptState::InterruptState(weak_ptr task) : mode(InterruptMode::TASK), current_task(std::move(task)) { +} +InterruptState::InterruptState(weak_ptr signal_state_p) + : mode(InterruptMode::BLOCKING), signal_state(std::move(signal_state_p)) { +} + +void InterruptState::Callback() const { + if (mode == InterruptMode::TASK) { + auto task = current_task.lock(); + + if (!task) { + return; + } + + task->Reschedule(); + } else if (mode == InterruptMode::BLOCKING) { + auto signal_state_l = signal_state.lock(); + + if (!signal_state_l) { + return; + } + + // Signal the caller, who is currently blocked + signal_state_l->Signal(); + } else { + throw InternalException("Callback made on InterruptState without valid interrupt mode specified"); + } +} + +void InterruptDoneSignalState::Signal() { + { + unique_lock lck {lock}; + done = true; + } + cv.notify_all(); +} + +void InterruptDoneSignalState::Await() { + std::unique_lock lck(lock); + cv.wait(lck, [&]() { return done; }); + + // Reset after signal received + done = false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/meta_pipeline.cpp b/src/duckdb/src/parallel/meta_pipeline.cpp new file mode 100644 index 00000000..191f36c7 --- /dev/null +++ b/src/duckdb/src/parallel/meta_pipeline.cpp @@ -0,0 +1,186 @@ +#include "duckdb/parallel/meta_pipeline.hpp" + +#include "duckdb/execution/executor.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" + +namespace duckdb { + +MetaPipeline::MetaPipeline(Executor &executor_p, PipelineBuildState &state_p, PhysicalOperator *sink_p) + : executor(executor_p), state(state_p), sink(sink_p), recursive_cte(false), next_batch_index(0) { + CreatePipeline(); +} + +Executor &MetaPipeline::GetExecutor() const { + return executor; +} + +PipelineBuildState &MetaPipeline::GetState() const { + return state; +} + +optional_ptr MetaPipeline::GetSink() const { + return sink; +} + +shared_ptr &MetaPipeline::GetBasePipeline() { + return pipelines[0]; +} + +void MetaPipeline::GetPipelines(vector> &result, bool recursive) { + result.insert(result.end(), pipelines.begin(), pipelines.end()); + if (recursive) { + for (auto &child : children) { + child->GetPipelines(result, true); + } + } +} + +void MetaPipeline::GetMetaPipelines(vector> &result, bool recursive, bool skip) { + if (!skip) { + result.push_back(shared_from_this()); + } + if (recursive) { + for (auto &child : children) { + child->GetMetaPipelines(result, true, false); + } + } +} + +const vector *MetaPipeline::GetDependencies(Pipeline *dependant) const { + auto it = dependencies.find(dependant); + if (it == dependencies.end()) { + return nullptr; + } else { + return &it->second; + } +} + +bool MetaPipeline::HasRecursiveCTE() const { + return recursive_cte; +} + +void MetaPipeline::SetRecursiveCTE() { + recursive_cte = true; +} + +void MetaPipeline::AssignNextBatchIndex(Pipeline *pipeline) { + pipeline->base_batch_index = next_batch_index++ * PipelineBuildState::BATCH_INCREMENT; +} + +void MetaPipeline::Build(PhysicalOperator &op) { + D_ASSERT(pipelines.size() == 1); + D_ASSERT(children.empty()); + op.BuildPipelines(*pipelines.back(), *this); +} + +void MetaPipeline::Ready() { + for (auto &pipeline : pipelines) { + pipeline->Ready(); + } + for (auto &child : children) { + child->Ready(); + } +} + +MetaPipeline &MetaPipeline::CreateChildMetaPipeline(Pipeline ¤t, PhysicalOperator &op) { + children.push_back(make_shared(executor, state, &op)); + auto child_meta_pipeline = children.back().get(); + // child MetaPipeline must finish completely before this MetaPipeline can start + current.AddDependency(child_meta_pipeline->GetBasePipeline()); + // child meta pipeline is part of the recursive CTE too + child_meta_pipeline->recursive_cte = recursive_cte; + return *child_meta_pipeline; +} + +Pipeline *MetaPipeline::CreatePipeline() { + pipelines.emplace_back(make_shared(executor)); + state.SetPipelineSink(*pipelines.back(), sink, next_batch_index++); + return pipelines.back().get(); +} + +void MetaPipeline::AddDependenciesFrom(Pipeline *dependant, Pipeline *start, bool including) { + // find 'start' + auto it = pipelines.begin(); + for (; it->get() != start; it++) { + } + + if (!including) { + it++; + } + + // collect pipelines that were created from then + vector created_pipelines; + for (; it != pipelines.end(); it++) { + if (it->get() == dependant) { + // cannot depend on itself + continue; + } + created_pipelines.push_back(it->get()); + } + + // add them to the dependencies + auto &deps = dependencies[dependant]; + deps.insert(deps.begin(), created_pipelines.begin(), created_pipelines.end()); +} + +void MetaPipeline::AddFinishEvent(Pipeline *pipeline) { + D_ASSERT(finish_pipelines.find(pipeline) == finish_pipelines.end()); + finish_pipelines.insert(pipeline); + + // add all pipelines that were added since 'pipeline' was added (including 'pipeline') to the finish group + auto it = pipelines.begin(); + for (; it->get() != pipeline; it++) { + } + it++; + for (; it != pipelines.end(); it++) { + finish_map.emplace(it->get(), pipeline); + } +} + +bool MetaPipeline::HasFinishEvent(Pipeline *pipeline) const { + return finish_pipelines.find(pipeline) != finish_pipelines.end(); +} + +optional_ptr MetaPipeline::GetFinishGroup(Pipeline *pipeline) const { + auto it = finish_map.find(pipeline); + return it == finish_map.end() ? nullptr : it->second; +} + +Pipeline *MetaPipeline::CreateUnionPipeline(Pipeline ¤t, bool order_matters) { + // create the union pipeline (batch index 0, should be set correctly afterwards) + auto union_pipeline = CreatePipeline(); + state.SetPipelineOperators(*union_pipeline, state.GetPipelineOperators(current)); + state.SetPipelineSink(*union_pipeline, sink, 0); + + // 'union_pipeline' inherits ALL dependencies of 'current' (within this MetaPipeline, and across MetaPipelines) + union_pipeline->dependencies = current.dependencies; + auto current_deps = GetDependencies(¤t); + if (current_deps) { + dependencies[union_pipeline] = *current_deps; + } + + if (order_matters) { + // if we need to preserve order, or if the sink is not parallel, we set a dependency + dependencies[union_pipeline].push_back(¤t); + } + + return union_pipeline; +} + +void MetaPipeline::CreateChildPipeline(Pipeline ¤t, PhysicalOperator &op, Pipeline *last_pipeline) { + // rule 2: 'current' must be fully built (down to the source) before creating the child pipeline + D_ASSERT(current.source); + + // create the child pipeline (same batch index) + pipelines.emplace_back(state.CreateChildPipeline(executor, current, op)); + auto child_pipeline = pipelines.back().get(); + child_pipeline->base_batch_index = current.base_batch_index; + + // child pipeline has a dependency (within this MetaPipeline on all pipelines that were scheduled + // between 'current' and now (including 'current') - set them up + dependencies[child_pipeline].push_back(¤t); + AddDependenciesFrom(child_pipeline, last_pipeline, false); + D_ASSERT(!GetDependencies(child_pipeline)->empty()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/pipeline.cpp b/src/duckdb/src/parallel/pipeline.cpp new file mode 100644 index 00000000..1448b100 --- /dev/null +++ b/src/duckdb/src/parallel/pipeline.cpp @@ -0,0 +1,333 @@ +#include "duckdb/parallel/pipeline.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp" +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parallel/pipeline_event.hpp" +#include "duckdb/parallel/pipeline_executor.hpp" +#include "duckdb/parallel/task_scheduler.hpp" + +namespace duckdb { + +class PipelineTask : public ExecutorTask { + static constexpr const idx_t PARTIAL_CHUNK_COUNT = 50; + +public: + explicit PipelineTask(Pipeline &pipeline_p, shared_ptr event_p) + : ExecutorTask(pipeline_p.executor), pipeline(pipeline_p), event(std::move(event_p)) { + } + + Pipeline &pipeline; + shared_ptr event; + unique_ptr pipeline_executor; + +public: + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + if (!pipeline_executor) { + pipeline_executor = make_uniq(pipeline.GetClientContext(), pipeline); + } + + pipeline_executor->SetTaskForInterrupts(shared_from_this()); + + if (mode == TaskExecutionMode::PROCESS_PARTIAL) { + auto res = pipeline_executor->Execute(PARTIAL_CHUNK_COUNT); + + switch (res) { + case PipelineExecuteResult::NOT_FINISHED: + return TaskExecutionResult::TASK_NOT_FINISHED; + case PipelineExecuteResult::INTERRUPTED: + return TaskExecutionResult::TASK_BLOCKED; + case PipelineExecuteResult::FINISHED: + break; + } + } else { + auto res = pipeline_executor->Execute(); + switch (res) { + case PipelineExecuteResult::NOT_FINISHED: + throw InternalException("Execute without limit should not return NOT_FINISHED"); + case PipelineExecuteResult::INTERRUPTED: + return TaskExecutionResult::TASK_BLOCKED; + case PipelineExecuteResult::FINISHED: + break; + } + } + + event->FinishTask(); + pipeline_executor.reset(); + return TaskExecutionResult::TASK_FINISHED; + } +}; + +Pipeline::Pipeline(Executor &executor_p) + : executor(executor_p), ready(false), initialized(false), source(nullptr), sink(nullptr) { +} + +ClientContext &Pipeline::GetClientContext() { + return executor.context; +} + +bool Pipeline::GetProgress(double ¤t_percentage, idx_t &source_cardinality) { + D_ASSERT(source); + source_cardinality = source->estimated_cardinality; + if (!initialized) { + current_percentage = 0; + return true; + } + auto &client = executor.context; + current_percentage = source->GetProgress(client, *source_state); + return current_percentage >= 0; +} + +void Pipeline::ScheduleSequentialTask(shared_ptr &event) { + vector> tasks; + tasks.push_back(make_uniq(*this, event)); + event->SetTasks(std::move(tasks)); +} + +bool Pipeline::ScheduleParallel(shared_ptr &event) { + // check if the sink, source and all intermediate operators support parallelism + if (!sink->ParallelSink()) { + return false; + } + if (!source->ParallelSource()) { + return false; + } + for (auto &op_ref : operators) { + auto &op = op_ref.get(); + if (!op.ParallelOperator()) { + return false; + } + } + if (sink->RequiresBatchIndex()) { + if (!source->SupportsBatchIndex()) { + throw InternalException( + "Attempting to schedule a pipeline where the sink requires batch index but source does not support it"); + } + } + idx_t max_threads = source_state->MaxThreads(); + return LaunchScanTasks(event, max_threads); +} + +bool Pipeline::IsOrderDependent() const { + auto &config = DBConfig::GetConfig(executor.context); + if (source) { + auto source_order = source->SourceOrder(); + if (source_order == OrderPreservationType::FIXED_ORDER) { + return true; + } + if (source_order == OrderPreservationType::NO_ORDER) { + return false; + } + } + for (auto &op_ref : operators) { + auto &op = op_ref.get(); + if (op.OperatorOrder() == OrderPreservationType::NO_ORDER) { + return false; + } + if (op.OperatorOrder() == OrderPreservationType::FIXED_ORDER) { + return true; + } + } + if (!config.options.preserve_insertion_order) { + return false; + } + if (sink && sink->SinkOrderDependent()) { + return true; + } + return false; +} + +void Pipeline::Schedule(shared_ptr &event) { + D_ASSERT(ready); + D_ASSERT(sink); + Reset(); + if (!ScheduleParallel(event)) { + // could not parallelize this pipeline: push a sequential task instead + ScheduleSequentialTask(event); + } +} + +bool Pipeline::LaunchScanTasks(shared_ptr &event, idx_t max_threads) { + // split the scan up into parts and schedule the parts + auto &scheduler = TaskScheduler::GetScheduler(executor.context); + idx_t active_threads = scheduler.NumberOfThreads(); + if (max_threads > active_threads) { + max_threads = active_threads; + } + if (max_threads <= 1) { + // too small to parallelize + return false; + } + + // launch a task for every thread + vector> tasks; + for (idx_t i = 0; i < max_threads; i++) { + tasks.push_back(make_uniq(*this, event)); + } + event->SetTasks(std::move(tasks)); + return true; +} + +void Pipeline::ResetSink() { + if (sink) { + if (!sink->IsSink()) { + throw InternalException("Sink of pipeline does not have IsSink set"); + } + lock_guard guard(sink->lock); + if (!sink->sink_state) { + sink->sink_state = sink->GetGlobalSinkState(GetClientContext()); + } + } +} + +void Pipeline::Reset() { + ResetSink(); + for (auto &op_ref : operators) { + auto &op = op_ref.get(); + lock_guard guard(op.lock); + if (!op.op_state) { + op.op_state = op.GetGlobalOperatorState(GetClientContext()); + } + } + ResetSource(false); + // we no longer reset source here because this function is no longer guaranteed to be called by the main thread + // source reset needs to be called by the main thread because resetting a source may call into clients like R + initialized = true; +} + +void Pipeline::ResetSource(bool force) { + if (source && !source->IsSource()) { + throw InternalException("Source of pipeline does not have IsSource set"); + } + if (force || !source_state) { + source_state = source->GetGlobalSourceState(GetClientContext()); + } +} + +void Pipeline::Ready() { + if (ready) { + return; + } + ready = true; + std::reverse(operators.begin(), operators.end()); +} + +void Pipeline::AddDependency(shared_ptr &pipeline) { + D_ASSERT(pipeline); + dependencies.push_back(weak_ptr(pipeline)); + pipeline->parents.push_back(weak_ptr(shared_from_this())); +} + +string Pipeline::ToString() const { + TreeRenderer renderer; + return renderer.ToString(*this); +} + +void Pipeline::Print() const { + Printer::Print(ToString()); +} + +void Pipeline::PrintDependencies() const { + for (auto &dep : dependencies) { + shared_ptr(dep)->Print(); + } +} + +vector> Pipeline::GetOperators() { + vector> result; + D_ASSERT(source); + result.push_back(*source); + for (auto &op : operators) { + result.push_back(op.get()); + } + if (sink) { + result.push_back(*sink); + } + return result; +} + +vector> Pipeline::GetOperators() const { + vector> result; + D_ASSERT(source); + result.push_back(*source); + for (auto &op : operators) { + result.push_back(op.get()); + } + if (sink) { + result.push_back(*sink); + } + return result; +} + +void Pipeline::ClearSource() { + source_state.reset(); + batch_indexes.clear(); +} + +idx_t Pipeline::RegisterNewBatchIndex() { + lock_guard l(batch_lock); + idx_t minimum = batch_indexes.empty() ? base_batch_index : *batch_indexes.begin(); + batch_indexes.insert(minimum); + return minimum; +} + +idx_t Pipeline::UpdateBatchIndex(idx_t old_index, idx_t new_index) { + lock_guard l(batch_lock); + if (new_index < *batch_indexes.begin()) { + throw InternalException("Processing batch index %llu, but previous min batch index was %llu", new_index, + *batch_indexes.begin()); + } + auto entry = batch_indexes.find(old_index); + if (entry == batch_indexes.end()) { + throw InternalException("Batch index %llu was not found in set of active batch indexes", old_index); + } + batch_indexes.erase(entry); + batch_indexes.insert(new_index); + return *batch_indexes.begin(); +} +//===--------------------------------------------------------------------===// +// Pipeline Build State +//===--------------------------------------------------------------------===// +void PipelineBuildState::SetPipelineSource(Pipeline &pipeline, PhysicalOperator &op) { + pipeline.source = &op; +} + +void PipelineBuildState::SetPipelineSink(Pipeline &pipeline, optional_ptr op, + idx_t sink_pipeline_count) { + pipeline.sink = op; + // set the base batch index of this pipeline based on how many other pipelines have this node as their sink + pipeline.base_batch_index = BATCH_INCREMENT * sink_pipeline_count; +} + +void PipelineBuildState::AddPipelineOperator(Pipeline &pipeline, PhysicalOperator &op) { + pipeline.operators.push_back(op); +} + +optional_ptr PipelineBuildState::GetPipelineSource(Pipeline &pipeline) { + return pipeline.source; +} + +optional_ptr PipelineBuildState::GetPipelineSink(Pipeline &pipeline) { + return pipeline.sink; +} + +void PipelineBuildState::SetPipelineOperators(Pipeline &pipeline, vector> operators) { + pipeline.operators = std::move(operators); +} + +shared_ptr PipelineBuildState::CreateChildPipeline(Executor &executor, Pipeline &pipeline, + PhysicalOperator &op) { + return executor.CreateChildPipeline(pipeline, op); +} + +vector> PipelineBuildState::GetPipelineOperators(Pipeline &pipeline) { + return pipeline.operators; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/pipeline_complete_event.cpp b/src/duckdb/src/parallel/pipeline_complete_event.cpp new file mode 100644 index 00000000..85e5ac63 --- /dev/null +++ b/src/duckdb/src/parallel/pipeline_complete_event.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parallel/pipeline_complete_event.hpp" +#include "duckdb/execution/executor.hpp" + +namespace duckdb { + +PipelineCompleteEvent::PipelineCompleteEvent(Executor &executor, bool complete_pipeline_p) + : Event(executor), complete_pipeline(complete_pipeline_p) { +} + +void PipelineCompleteEvent::Schedule() { +} + +void PipelineCompleteEvent::FinalizeFinish() { + if (complete_pipeline) { + executor.CompletePipeline(); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/pipeline_event.cpp b/src/duckdb/src/parallel/pipeline_event.cpp new file mode 100644 index 00000000..735f9fad --- /dev/null +++ b/src/duckdb/src/parallel/pipeline_event.cpp @@ -0,0 +1,27 @@ +#include "duckdb/parallel/pipeline_event.hpp" +#include "duckdb/execution/executor.hpp" + +namespace duckdb { + +PipelineEvent::PipelineEvent(shared_ptr pipeline_p) : BasePipelineEvent(std::move(pipeline_p)) { +} + +void PipelineEvent::Schedule() { + auto event = shared_from_this(); + auto &executor = pipeline->executor; + try { + pipeline->Schedule(event); + D_ASSERT(total_tasks > 0); + } catch (Exception &ex) { + executor.PushError(PreservedError(ex)); + } catch (std::exception &ex) { + executor.PushError(PreservedError(ex)); + } catch (...) { // LCOV_EXCL_START + executor.PushError(PreservedError("Unknown exception in Finalize!")); + } // LCOV_EXCL_STOP +} + +void PipelineEvent::FinishEvent() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/pipeline_executor.cpp b/src/duckdb/src/parallel/pipeline_executor.cpp new file mode 100644 index 00000000..80fc57ca --- /dev/null +++ b/src/duckdb/src/parallel/pipeline_executor.cpp @@ -0,0 +1,543 @@ +#include "duckdb/parallel/pipeline_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/limits.hpp" + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE +#include +#include +#endif + +namespace duckdb { + +PipelineExecutor::PipelineExecutor(ClientContext &context_p, Pipeline &pipeline_p) + : pipeline(pipeline_p), thread(context_p), context(context_p, thread, &pipeline_p) { + D_ASSERT(pipeline.source_state); + if (pipeline.sink) { + local_sink_state = pipeline.sink->GetLocalSinkState(context); + requires_batch_index = pipeline.sink->RequiresBatchIndex() && pipeline.source->SupportsBatchIndex(); + if (requires_batch_index) { + auto &partition_info = local_sink_state->partition_info; + D_ASSERT(!partition_info.batch_index.IsValid()); + // batch index is not set yet - initialize before fetching anything + partition_info.batch_index = pipeline.RegisterNewBatchIndex(); + partition_info.min_batch_index = partition_info.batch_index; + } + } + local_source_state = pipeline.source->GetLocalSourceState(context, *pipeline.source_state); + + intermediate_chunks.reserve(pipeline.operators.size()); + intermediate_states.reserve(pipeline.operators.size()); + for (idx_t i = 0; i < pipeline.operators.size(); i++) { + auto &prev_operator = i == 0 ? *pipeline.source : pipeline.operators[i - 1].get(); + auto ¤t_operator = pipeline.operators[i].get(); + + auto chunk = make_uniq(); + chunk->Initialize(Allocator::Get(context.client), prev_operator.GetTypes()); + intermediate_chunks.push_back(std::move(chunk)); + + auto op_state = current_operator.GetOperatorState(context); + intermediate_states.push_back(std::move(op_state)); + + if (current_operator.IsSink() && current_operator.sink_state->state == SinkFinalizeType::NO_OUTPUT_POSSIBLE) { + // one of the operators has already figured out no output is possible + // we can skip executing the pipeline + FinishProcessing(); + } + } + InitializeChunk(final_chunk); +} + +bool PipelineExecutor::TryFlushCachingOperators() { + if (!started_flushing) { + // Remainder of this method assumes any in process operators are from flushing + D_ASSERT(in_process_operators.empty()); + started_flushing = true; + flushing_idx = IsFinished() ? idx_t(finished_processing_idx) : 0; + } + + // Go over each operator and keep flushing them using `FinalExecute` until empty + while (flushing_idx < pipeline.operators.size()) { + if (!pipeline.operators[flushing_idx].get().RequiresFinalExecute()) { + flushing_idx++; + continue; + } + + // This slightly awkward way of increasing the flushing idx is to make the code re-entrant: We need to call this + // method again in the case of a Sink returning BLOCKED. + if (!should_flush_current_idx && in_process_operators.empty()) { + should_flush_current_idx = true; + flushing_idx++; + continue; + } + + auto &curr_chunk = + flushing_idx + 1 >= intermediate_chunks.size() ? final_chunk : *intermediate_chunks[flushing_idx + 1]; + auto ¤t_operator = pipeline.operators[flushing_idx].get(); + + OperatorFinalizeResultType finalize_result; + OperatorResultType push_result; + + if (in_process_operators.empty()) { + curr_chunk.Reset(); + StartOperator(current_operator); + finalize_result = current_operator.FinalExecute(context, curr_chunk, *current_operator.op_state, + *intermediate_states[flushing_idx]); + EndOperator(current_operator, &curr_chunk); + } else { + // Reset flag and reflush the last chunk we were flushing. + finalize_result = OperatorFinalizeResultType::HAVE_MORE_OUTPUT; + } + + push_result = ExecutePushInternal(curr_chunk, flushing_idx + 1); + + if (finalize_result == OperatorFinalizeResultType::HAVE_MORE_OUTPUT) { + should_flush_current_idx = true; + } else { + should_flush_current_idx = false; + } + + if (push_result == OperatorResultType::BLOCKED) { + remaining_sink_chunk = true; + return false; + } else if (push_result == OperatorResultType::FINISHED) { + break; + } + } + return true; +} + +PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { + D_ASSERT(pipeline.sink); + auto &source_chunk = pipeline.operators.empty() ? final_chunk : *intermediate_chunks[0]; + for (idx_t i = 0; i < max_chunks; i++) { + if (context.client.interrupted) { + throw InterruptException(); + } + + OperatorResultType result; + if (exhausted_source && done_flushing && !remaining_sink_chunk && in_process_operators.empty()) { + break; + } else if (remaining_sink_chunk) { + // The pipeline was interrupted by the Sink. We should retry sinking the final chunk. + result = ExecutePushInternal(final_chunk); + remaining_sink_chunk = false; + } else if (!in_process_operators.empty() && !started_flushing) { + // The pipeline was interrupted by the Sink when pushing a source chunk through the pipeline. We need to + // re-push the same source chunk through the pipeline because there are in_process operators, meaning that + // the result for the pipeline + D_ASSERT(source_chunk.size() > 0); + result = ExecutePushInternal(source_chunk); + } else if (exhausted_source && !done_flushing) { + // The source was exhausted, try flushing all operators + auto flush_completed = TryFlushCachingOperators(); + if (flush_completed) { + done_flushing = true; + break; + } else { + return PipelineExecuteResult::INTERRUPTED; + } + } else if (!exhausted_source) { + // "Regular" path: fetch a chunk from the source and push it through the pipeline + source_chunk.Reset(); + SourceResultType source_result = FetchFromSource(source_chunk); + + if (source_result == SourceResultType::BLOCKED) { + return PipelineExecuteResult::INTERRUPTED; + } + + if (source_result == SourceResultType::FINISHED) { + exhausted_source = true; + if (source_chunk.size() == 0) { + continue; + } + } + result = ExecutePushInternal(source_chunk); + } else { + throw InternalException("Unexpected state reached in pipeline executor"); + } + + // SINK INTERRUPT + if (result == OperatorResultType::BLOCKED) { + remaining_sink_chunk = true; + return PipelineExecuteResult::INTERRUPTED; + } + + if (result == OperatorResultType::FINISHED) { + break; + } + } + + if ((!exhausted_source || !done_flushing) && !IsFinished()) { + return PipelineExecuteResult::NOT_FINISHED; + } + + return PushFinalize(); +} + +PipelineExecuteResult PipelineExecutor::Execute() { + return Execute(NumericLimits::Maximum()); +} + +OperatorResultType PipelineExecutor::ExecutePush(DataChunk &input) { // LCOV_EXCL_START + return ExecutePushInternal(input); +} // LCOV_EXCL_STOP + +void PipelineExecutor::FinishProcessing(int32_t operator_idx) { + finished_processing_idx = operator_idx < 0 ? NumericLimits::Maximum() : operator_idx; + in_process_operators = stack(); +} + +bool PipelineExecutor::IsFinished() { + return finished_processing_idx >= 0; +} + +OperatorResultType PipelineExecutor::ExecutePushInternal(DataChunk &input, idx_t initial_idx) { + D_ASSERT(pipeline.sink); + if (input.size() == 0) { // LCOV_EXCL_START + return OperatorResultType::NEED_MORE_INPUT; + } // LCOV_EXCL_STOP + + // this loop will continuously push the input chunk through the pipeline as long as: + // - the OperatorResultType for the Execute is HAVE_MORE_OUTPUT + // - the Sink doesn't block + while (true) { + OperatorResultType result; + // Note: if input is the final_chunk, we don't do any executing, the chunk just needs to be sinked + if (&input != &final_chunk) { + final_chunk.Reset(); + result = Execute(input, final_chunk, initial_idx); + if (result == OperatorResultType::FINISHED) { + return OperatorResultType::FINISHED; + } + } else { + result = OperatorResultType::NEED_MORE_INPUT; + } + auto &sink_chunk = final_chunk; + if (sink_chunk.size() > 0) { + StartOperator(*pipeline.sink); + D_ASSERT(pipeline.sink); + D_ASSERT(pipeline.sink->sink_state); + OperatorSinkInput sink_input {*pipeline.sink->sink_state, *local_sink_state, interrupt_state}; + + auto sink_result = Sink(sink_chunk, sink_input); + + EndOperator(*pipeline.sink, nullptr); + + if (sink_result == SinkResultType::BLOCKED) { + return OperatorResultType::BLOCKED; + } else if (sink_result == SinkResultType::FINISHED) { + FinishProcessing(); + return OperatorResultType::FINISHED; + } + } + if (result == OperatorResultType::NEED_MORE_INPUT) { + return OperatorResultType::NEED_MORE_INPUT; + } + } +} + +PipelineExecuteResult PipelineExecutor::PushFinalize() { + if (finalized) { + throw InternalException("Calling PushFinalize on a pipeline that has been finalized already"); + } + + D_ASSERT(local_sink_state); + + // Run the combine for the sink + OperatorSinkCombineInput combine_input {*pipeline.sink->sink_state, *local_sink_state, interrupt_state}; + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + if (debug_blocked_combine_count < debug_blocked_target_count) { + debug_blocked_combine_count++; + + auto &callback_state = combine_input.interrupt_state; + std::thread rewake_thread([callback_state] { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + callback_state.Callback(); + }); + rewake_thread.detach(); + + return PipelineExecuteResult::INTERRUPTED; + } +#endif + auto result = pipeline.sink->Combine(context, combine_input); + + if (result == SinkCombineResultType::BLOCKED) { + return PipelineExecuteResult::INTERRUPTED; + } + + finalized = true; + // flush all query profiler info + for (idx_t i = 0; i < intermediate_states.size(); i++) { + intermediate_states[i]->Finalize(pipeline.operators[i].get(), context); + } + pipeline.executor.Flush(thread); + local_sink_state.reset(); + + return PipelineExecuteResult::FINISHED; +} + +// TODO: Refactoring the StreamingQueryResult to use Push-based execution should eliminate the need for this code +void PipelineExecutor::ExecutePull(DataChunk &result) { + if (IsFinished()) { + return; + } + auto &executor = pipeline.executor; + try { + D_ASSERT(!pipeline.sink); + auto &source_chunk = pipeline.operators.empty() ? result : *intermediate_chunks[0]; + while (result.size() == 0 && !exhausted_source) { + if (in_process_operators.empty()) { + source_chunk.Reset(); + + auto done_signal = make_shared(); + interrupt_state = InterruptState(done_signal); + SourceResultType source_result; + + // Repeatedly try to fetch from the source until it doesn't block. Note that it may block multiple times + while (true) { + source_result = FetchFromSource(source_chunk); + + // No interrupt happened, all good. + if (source_result != SourceResultType::BLOCKED) { + break; + } + + // Busy wait for async callback from source operator + done_signal->Await(); + } + + if (source_result == SourceResultType::FINISHED) { + exhausted_source = true; + if (source_chunk.size() == 0) { + break; + } + } + } + if (!pipeline.operators.empty()) { + auto state = Execute(source_chunk, result); + if (state == OperatorResultType::FINISHED) { + break; + } + } + } + } catch (const Exception &ex) { // LCOV_EXCL_START + if (executor.HasError()) { + executor.ThrowException(); + } + throw; + } catch (std::exception &ex) { + if (executor.HasError()) { + executor.ThrowException(); + } + throw; + } catch (...) { + if (executor.HasError()) { + executor.ThrowException(); + } + throw; + } // LCOV_EXCL_STOP +} + +void PipelineExecutor::PullFinalize() { + if (finalized) { + throw InternalException("Calling PullFinalize on a pipeline that has been finalized already"); + } + finalized = true; + pipeline.executor.Flush(thread); +} + +void PipelineExecutor::GoToSource(idx_t ¤t_idx, idx_t initial_idx) { + // we go back to the first operator (the source) + current_idx = initial_idx; + if (!in_process_operators.empty()) { + // ... UNLESS there is an in process operator + // if there is an in-process operator, we start executing at the latest one + // for example, if we have a join operator that has tuples left, we first need to emit those tuples + current_idx = in_process_operators.top(); + in_process_operators.pop(); + } + D_ASSERT(current_idx >= initial_idx); +} + +OperatorResultType PipelineExecutor::Execute(DataChunk &input, DataChunk &result, idx_t initial_idx) { + if (input.size() == 0) { // LCOV_EXCL_START + return OperatorResultType::NEED_MORE_INPUT; + } // LCOV_EXCL_STOP + D_ASSERT(!pipeline.operators.empty()); + + idx_t current_idx; + GoToSource(current_idx, initial_idx); + if (current_idx == initial_idx) { + current_idx++; + } + if (current_idx > pipeline.operators.size()) { + result.Reference(input); + return OperatorResultType::NEED_MORE_INPUT; + } + while (true) { + if (context.client.interrupted) { + throw InterruptException(); + } + // now figure out where to put the chunk + // if current_idx is the last possible index (>= operators.size()) we write to the result + // otherwise we write to an intermediate chunk + auto current_intermediate = current_idx; + auto ¤t_chunk = + current_intermediate >= intermediate_chunks.size() ? result : *intermediate_chunks[current_intermediate]; + current_chunk.Reset(); + if (current_idx == initial_idx) { + // we went back to the source: we need more input + return OperatorResultType::NEED_MORE_INPUT; + } else { + auto &prev_chunk = + current_intermediate == initial_idx + 1 ? input : *intermediate_chunks[current_intermediate - 1]; + auto operator_idx = current_idx - 1; + auto ¤t_operator = pipeline.operators[operator_idx].get(); + + // if current_idx > source_idx, we pass the previous operators' output through the Execute of the current + // operator + StartOperator(current_operator); + auto result = current_operator.Execute(context, prev_chunk, current_chunk, *current_operator.op_state, + *intermediate_states[current_intermediate - 1]); + EndOperator(current_operator, ¤t_chunk); + if (result == OperatorResultType::HAVE_MORE_OUTPUT) { + // more data remains in this operator + // push in-process marker + in_process_operators.push(current_idx); + } else if (result == OperatorResultType::FINISHED) { + D_ASSERT(current_chunk.size() == 0); + FinishProcessing(current_idx); + return OperatorResultType::FINISHED; + } + current_chunk.Verify(); + } + + if (current_chunk.size() == 0) { + // no output from this operator! + if (current_idx == initial_idx) { + // if we got no output from the scan, we are done + break; + } else { + // if we got no output from an intermediate op + // we go back and try to pull data from the source again + GoToSource(current_idx, initial_idx); + continue; + } + } else { + // we got output! continue to the next operator + current_idx++; + if (current_idx > pipeline.operators.size()) { + // if we got output and are at the last operator, we are finished executing for this output chunk + // return the data and push it into the chunk + break; + } + } + } + return in_process_operators.empty() ? OperatorResultType::NEED_MORE_INPUT : OperatorResultType::HAVE_MORE_OUTPUT; +} + +void PipelineExecutor::SetTaskForInterrupts(weak_ptr current_task) { + interrupt_state = InterruptState(std::move(current_task)); +} + +SourceResultType PipelineExecutor::GetData(DataChunk &chunk, OperatorSourceInput &input) { + //! Testing feature to enable async source on every operator +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + if (debug_blocked_source_count < debug_blocked_target_count) { + debug_blocked_source_count++; + + auto &callback_state = input.interrupt_state; + std::thread rewake_thread([callback_state] { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + callback_state.Callback(); + }); + rewake_thread.detach(); + + return SourceResultType::BLOCKED; + } +#endif + + return pipeline.source->GetData(context, chunk, input); +} + +SinkResultType PipelineExecutor::Sink(DataChunk &chunk, OperatorSinkInput &input) { + //! Testing feature to enable async sink on every operator +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + if (debug_blocked_sink_count < debug_blocked_target_count) { + debug_blocked_sink_count++; + + auto &callback_state = input.interrupt_state; + std::thread rewake_thread([callback_state] { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + callback_state.Callback(); + }); + rewake_thread.detach(); + + return SinkResultType::BLOCKED; + } +#endif + return pipeline.sink->Sink(context, chunk, input); +} + +SourceResultType PipelineExecutor::FetchFromSource(DataChunk &result) { + StartOperator(*pipeline.source); + + OperatorSourceInput source_input = {*pipeline.source_state, *local_source_state, interrupt_state}; + auto res = GetData(result, source_input); + + // Ensures Sinks only return empty results when Blocking or Finished + D_ASSERT(res != SourceResultType::BLOCKED || result.size() == 0); + + if (requires_batch_index && res != SourceResultType::BLOCKED) { + idx_t next_batch_index; + if (result.size() == 0) { + next_batch_index = NumericLimits::Maximum(); + } else { + next_batch_index = + pipeline.source->GetBatchIndex(context, result, *pipeline.source_state, *local_source_state); + // we start with the base_batch_index as a valid starting value. Make sure that next batch is called below + next_batch_index += pipeline.base_batch_index + 1; + } + auto &partition_info = local_sink_state->partition_info; + if (next_batch_index != partition_info.batch_index.GetIndex()) { + // batch index has changed - update it + if (partition_info.batch_index.GetIndex() > next_batch_index) { + throw InternalException( + "Pipeline batch index - gotten lower batch index %llu (down from previous batch index of %llu)", + next_batch_index, partition_info.batch_index.GetIndex()); + } + auto current_batch = partition_info.batch_index.GetIndex(); + partition_info.batch_index = next_batch_index; + // call NextBatch before updating min_batch_index to provide the opportunity to flush the previous batch + pipeline.sink->NextBatch(context, *pipeline.sink->sink_state, *local_sink_state); + partition_info.min_batch_index = pipeline.UpdateBatchIndex(current_batch, next_batch_index); + } + } + + EndOperator(*pipeline.source, &result); + + return res; +} + +void PipelineExecutor::InitializeChunk(DataChunk &chunk) { + auto &last_op = pipeline.operators.empty() ? *pipeline.source : pipeline.operators.back().get(); + chunk.Initialize(Allocator::DefaultAllocator(), last_op.GetTypes()); +} + +void PipelineExecutor::StartOperator(PhysicalOperator &op) { + if (context.client.interrupted) { + throw InterruptException(); + } + context.thread.profiler.StartOperator(&op); +} + +void PipelineExecutor::EndOperator(PhysicalOperator &op, optional_ptr chunk) { + context.thread.profiler.EndOperator(chunk); + + if (chunk) { + chunk->Verify(); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/pipeline_finish_event.cpp b/src/duckdb/src/parallel/pipeline_finish_event.cpp new file mode 100644 index 00000000..a3858a2d --- /dev/null +++ b/src/duckdb/src/parallel/pipeline_finish_event.cpp @@ -0,0 +1,70 @@ +#include "duckdb/parallel/pipeline_finish_event.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/parallel/interrupt.hpp" + +namespace duckdb { + +//! The PipelineFinishTask calls Finalize on the sink. Note that this is a single-threaded operation, but is executed +//! in a task to allow the Finalize call to block (e.g. for async I/O) +class PipelineFinishTask : public ExecutorTask { +public: + explicit PipelineFinishTask(Pipeline &pipeline_p, shared_ptr event_p) + : ExecutorTask(pipeline_p.executor), pipeline(pipeline_p), event(std::move(event_p)) { + } + + Pipeline &pipeline; + shared_ptr event; + +public: + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + auto sink = pipeline.GetSink(); + InterruptState interrupt_state(shared_from_this()); + OperatorSinkFinalizeInput finalize_input {*sink->sink_state, interrupt_state}; + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + if (debug_blocked_count < debug_blocked_target_count) { + debug_blocked_count++; + + auto &callback_state = interrupt_state; + std::thread rewake_thread([callback_state] { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + callback_state.Callback(); + }); + rewake_thread.detach(); + + return TaskExecutionResult::TASK_BLOCKED; + } +#endif + auto sink_state = sink->Finalize(pipeline, *event, executor.context, finalize_input); + + if (sink_state == SinkFinalizeType::BLOCKED) { + return TaskExecutionResult::TASK_BLOCKED; + } + + sink->sink_state->state = sink_state; + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; + } + +private: +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + //! Debugging state: number of times blocked + int debug_blocked_count = 0; + //! Number of times the Finalize will block before actually returning data + int debug_blocked_target_count = 10; +#endif +}; + +PipelineFinishEvent::PipelineFinishEvent(shared_ptr pipeline_p) : BasePipelineEvent(std::move(pipeline_p)) { +} + +void PipelineFinishEvent::Schedule() { + vector> tasks; + tasks.push_back(make_uniq(*pipeline, shared_from_this())); + SetTasks(std::move(tasks)); +} + +void PipelineFinishEvent::FinishEvent() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/pipeline_initialize_event.cpp b/src/duckdb/src/parallel/pipeline_initialize_event.cpp new file mode 100644 index 00000000..f7cb8b40 --- /dev/null +++ b/src/duckdb/src/parallel/pipeline_initialize_event.cpp @@ -0,0 +1,38 @@ +#include "duckdb/parallel/pipeline_initialize_event.hpp" + +#include "duckdb/execution/executor.hpp" + +namespace duckdb { + +PipelineInitializeEvent::PipelineInitializeEvent(shared_ptr pipeline_p) + : BasePipelineEvent(std::move(pipeline_p)) { +} + +class PipelineInitializeTask : public ExecutorTask { +public: + explicit PipelineInitializeTask(Pipeline &pipeline_p, shared_ptr event_p) + : ExecutorTask(pipeline_p.executor), pipeline(pipeline_p), event(std::move(event_p)) { + } + + Pipeline &pipeline; + shared_ptr event; + +public: + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + pipeline.ResetSink(); + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; + } +}; + +void PipelineInitializeEvent::Schedule() { + // needs to spawn a task to get the chain of tasks for the query plan going + vector> tasks; + tasks.push_back(make_uniq(*pipeline, shared_from_this())); + SetTasks(std::move(tasks)); +} + +void PipelineInitializeEvent::FinishEvent() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/task_scheduler.cpp b/src/duckdb/src/parallel/task_scheduler.cpp new file mode 100644 index 00000000..d0d3bed5 --- /dev/null +++ b/src/duckdb/src/parallel/task_scheduler.cpp @@ -0,0 +1,302 @@ +#include "duckdb/parallel/task_scheduler.hpp" + +#include "duckdb/common/chrono.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" + +#ifndef DUCKDB_NO_THREADS +#include "concurrentqueue.h" +#include "duckdb/common/thread.hpp" +#include "lightweightsemaphore.h" +#include +#else +#include +#endif + +namespace duckdb { + +struct SchedulerThread { +#ifndef DUCKDB_NO_THREADS + explicit SchedulerThread(unique_ptr thread_p) : internal_thread(std::move(thread_p)) { + } + + unique_ptr internal_thread; +#endif +}; + +#ifndef DUCKDB_NO_THREADS +typedef duckdb_moodycamel::ConcurrentQueue> concurrent_queue_t; +typedef duckdb_moodycamel::LightweightSemaphore lightweight_semaphore_t; + +struct ConcurrentQueue { + concurrent_queue_t q; + lightweight_semaphore_t semaphore; + + void Enqueue(ProducerToken &token, shared_ptr task); + bool DequeueFromProducer(ProducerToken &token, shared_ptr &task); +}; + +struct QueueProducerToken { + explicit QueueProducerToken(ConcurrentQueue &queue) : queue_token(queue.q) { + } + + duckdb_moodycamel::ProducerToken queue_token; +}; + +void ConcurrentQueue::Enqueue(ProducerToken &token, shared_ptr task) { + lock_guard producer_lock(token.producer_lock); + if (q.enqueue(token.token->queue_token, std::move(task))) { + semaphore.signal(); + } else { + throw InternalException("Could not schedule task!"); + } +} + +bool ConcurrentQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { + lock_guard producer_lock(token.producer_lock); + return q.try_dequeue_from_producer(token.token->queue_token, task); +} + +#else +struct ConcurrentQueue { + std::queue> q; + mutex qlock; + + void Enqueue(ProducerToken &token, shared_ptr task); + bool DequeueFromProducer(ProducerToken &token, shared_ptr &task); +}; + +void ConcurrentQueue::Enqueue(ProducerToken &token, shared_ptr task) { + lock_guard lock(qlock); + q.push(std::move(task)); +} + +bool ConcurrentQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { + lock_guard lock(qlock); + if (q.empty()) { + return false; + } + task = std::move(q.front()); + q.pop(); + return true; +} + +struct QueueProducerToken { + QueueProducerToken(ConcurrentQueue &queue) { + } +}; +#endif + +ProducerToken::ProducerToken(TaskScheduler &scheduler, unique_ptr token) + : scheduler(scheduler), token(std::move(token)) { +} + +ProducerToken::~ProducerToken() { +} + +TaskScheduler::TaskScheduler(DatabaseInstance &db) + : db(db), queue(make_uniq()), + allocator_flush_threshold(db.config.options.allocator_flush_threshold) { +} + +TaskScheduler::~TaskScheduler() { +#ifndef DUCKDB_NO_THREADS + SetThreadsInternal(1); +#endif +} + +TaskScheduler &TaskScheduler::GetScheduler(ClientContext &context) { + return TaskScheduler::GetScheduler(DatabaseInstance::GetDatabase(context)); +} + +TaskScheduler &TaskScheduler::GetScheduler(DatabaseInstance &db) { + return db.GetScheduler(); +} + +unique_ptr TaskScheduler::CreateProducer() { + auto token = make_uniq(*queue); + return make_uniq(*this, std::move(token)); +} + +void TaskScheduler::ScheduleTask(ProducerToken &token, shared_ptr task) { + // Enqueue a task for the given producer token and signal any sleeping threads + queue->Enqueue(token, std::move(task)); +} + +bool TaskScheduler::GetTaskFromProducer(ProducerToken &token, shared_ptr &task) { + return queue->DequeueFromProducer(token, task); +} + +void TaskScheduler::ExecuteForever(atomic *marker) { +#ifndef DUCKDB_NO_THREADS + shared_ptr task; + // loop until the marker is set to false + while (*marker) { + // wait for a signal with a timeout + queue->semaphore.wait(); + if (queue->q.try_dequeue(task)) { + auto execute_result = task->Execute(TaskExecutionMode::PROCESS_ALL); + + switch (execute_result) { + case TaskExecutionResult::TASK_FINISHED: + case TaskExecutionResult::TASK_ERROR: + task.reset(); + break; + case TaskExecutionResult::TASK_NOT_FINISHED: + throw InternalException("Task should not return TASK_NOT_FINISHED in PROCESS_ALL mode"); + case TaskExecutionResult::TASK_BLOCKED: + task->Deschedule(); + task.reset(); + break; + } + + // Flushes the outstanding allocator's outstanding allocations + Allocator::ThreadFlush(allocator_flush_threshold); + } + } +#else + throw NotImplementedException("DuckDB was compiled without threads! Background thread loop is not allowed."); +#endif +} + +idx_t TaskScheduler::ExecuteTasks(atomic *marker, idx_t max_tasks) { +#ifndef DUCKDB_NO_THREADS + idx_t completed_tasks = 0; + // loop until the marker is set to false + while (*marker && completed_tasks < max_tasks) { + shared_ptr task; + if (!queue->q.try_dequeue(task)) { + return completed_tasks; + } + auto execute_result = task->Execute(TaskExecutionMode::PROCESS_ALL); + + switch (execute_result) { + case TaskExecutionResult::TASK_FINISHED: + case TaskExecutionResult::TASK_ERROR: + task.reset(); + completed_tasks++; + break; + case TaskExecutionResult::TASK_NOT_FINISHED: + throw InternalException("Task should not return TASK_NOT_FINISHED in PROCESS_ALL mode"); + case TaskExecutionResult::TASK_BLOCKED: + task->Deschedule(); + task.reset(); + break; + } + } + return completed_tasks; +#else + throw NotImplementedException("DuckDB was compiled without threads! Background thread loop is not allowed."); +#endif +} + +void TaskScheduler::ExecuteTasks(idx_t max_tasks) { +#ifndef DUCKDB_NO_THREADS + shared_ptr task; + for (idx_t i = 0; i < max_tasks; i++) { + queue->semaphore.wait(TASK_TIMEOUT_USECS); + if (!queue->q.try_dequeue(task)) { + return; + } + try { + auto execute_result = task->Execute(TaskExecutionMode::PROCESS_ALL); + switch (execute_result) { + case TaskExecutionResult::TASK_FINISHED: + case TaskExecutionResult::TASK_ERROR: + task.reset(); + break; + case TaskExecutionResult::TASK_NOT_FINISHED: + throw InternalException("Task should not return TASK_NOT_FINISHED in PROCESS_ALL mode"); + case TaskExecutionResult::TASK_BLOCKED: + task->Deschedule(); + task.reset(); + break; + } + } catch (...) { + return; + } + } +#else + throw NotImplementedException("DuckDB was compiled without threads! Background thread loop is not allowed."); +#endif +} + +#ifndef DUCKDB_NO_THREADS +static void ThreadExecuteTasks(TaskScheduler *scheduler, atomic *marker) { + scheduler->ExecuteForever(marker); +} +#endif + +int32_t TaskScheduler::NumberOfThreads() { + lock_guard t(thread_lock); + auto &config = DBConfig::GetConfig(db); + return threads.size() + config.options.external_threads + 1; +} + +void TaskScheduler::SetThreads(int32_t n) { +#ifndef DUCKDB_NO_THREADS + lock_guard t(thread_lock); + if (n < 1) { + throw SyntaxException("Must have at least 1 thread!"); + } + SetThreadsInternal(n); +#else + if (n != 1) { + throw NotImplementedException("DuckDB was compiled without threads! Setting threads > 1 is not allowed."); + } +#endif +} + +void TaskScheduler::SetAllocatorFlushTreshold(idx_t threshold) { +} + +void TaskScheduler::Signal(idx_t n) { +#ifndef DUCKDB_NO_THREADS + queue->semaphore.signal(n); +#endif +} + +void TaskScheduler::YieldThread() { +#ifndef DUCKDB_NO_THREADS + std::this_thread::yield(); +#endif +} + +void TaskScheduler::SetThreadsInternal(int32_t n) { +#ifndef DUCKDB_NO_THREADS + if (threads.size() == idx_t(n - 1)) { + return; + } + idx_t new_thread_count = n - 1; + if (threads.size() > new_thread_count) { + // we are reducing the number of threads: clear all threads first + for (idx_t i = 0; i < threads.size(); i++) { + *markers[i] = false; + } + Signal(threads.size()); + // now join the threads to ensure they are fully stopped before erasing them + for (idx_t i = 0; i < threads.size(); i++) { + threads[i]->internal_thread->join(); + } + // erase the threads/markers + threads.clear(); + markers.clear(); + } + if (threads.size() < new_thread_count) { + // we are increasing the number of threads: launch them and run tasks on them + idx_t create_new_threads = new_thread_count - threads.size(); + for (idx_t i = 0; i < create_new_threads; i++) { + // launch a thread and assign it a cancellation marker + auto marker = unique_ptr>(new atomic(true)); + auto worker_thread = make_uniq(ThreadExecuteTasks, this, marker.get()); + auto thread_wrapper = make_uniq(std::move(worker_thread)); + + threads.push_back(std::move(thread_wrapper)); + markers.push_back(std::move(marker)); + } + } +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/thread_context.cpp b/src/duckdb/src/parallel/thread_context.cpp new file mode 100644 index 00000000..049a8fcf --- /dev/null +++ b/src/duckdb/src/parallel/thread_context.cpp @@ -0,0 +1,10 @@ +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/execution/execution_context.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +ThreadContext::ThreadContext(ClientContext &context) : profiler(QueryProfiler::Get(context).IsEnabled()) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/base_expression.cpp b/src/duckdb/src/parser/base_expression.cpp new file mode 100644 index 00000000..92fa254a --- /dev/null +++ b/src/duckdb/src/parser/base_expression.cpp @@ -0,0 +1,31 @@ +#include "duckdb/parser/base_expression.hpp" + +#include "duckdb/main/config.hpp" +#include "duckdb/common/printer.hpp" + +namespace duckdb { + +void BaseExpression::Print() const { + Printer::Print(ToString()); +} + +string BaseExpression::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return ToString(); + } +#endif + return !alias.empty() ? alias : ToString(); +} + +bool BaseExpression::Equals(const BaseExpression &other) const { + if (expression_class != other.expression_class || type != other.type) { + return false; + } + return true; +} + +void BaseExpression::Verify() const { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/column_definition.cpp b/src/duckdb/src/parser/column_definition.cpp new file mode 100644 index 00000000..93dfee0b --- /dev/null +++ b/src/duckdb/src/parser/column_definition.cpp @@ -0,0 +1,184 @@ +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" + +namespace duckdb { + +ColumnDefinition::ColumnDefinition(string name_p, LogicalType type_p) + : name(std::move(name_p)), type(std::move(type_p)) { +} + +ColumnDefinition::ColumnDefinition(string name_p, LogicalType type_p, unique_ptr expression, + TableColumnType category) + : name(std::move(name_p)), type(std::move(type_p)), category(category), expression(std::move(expression)) { +} + +ColumnDefinition ColumnDefinition::Copy() const { + ColumnDefinition copy(name, type); + copy.oid = oid; + copy.storage_oid = storage_oid; + copy.expression = expression ? expression->Copy() : nullptr; + copy.compression_type = compression_type; + copy.category = category; + return copy; +} + +const unique_ptr &ColumnDefinition::DefaultValue() const { + if (Generated()) { + throw InternalException("Calling DefaultValue() on a generated column"); + } + return expression; +} + +void ColumnDefinition::SetDefaultValue(unique_ptr default_value) { + if (Generated()) { + throw InternalException("Calling SetDefaultValue() on a generated column"); + } + this->expression = std::move(default_value); +} + +const LogicalType &ColumnDefinition::Type() const { + return type; +} + +LogicalType &ColumnDefinition::TypeMutable() { + return type; +} + +void ColumnDefinition::SetType(const LogicalType &type) { + this->type = type; +} + +const string &ColumnDefinition::Name() const { + return name; +} + +void ColumnDefinition::SetName(const string &name) { + this->name = name; +} + +const duckdb::CompressionType &ColumnDefinition::CompressionType() const { + return compression_type; +} + +void ColumnDefinition::SetCompressionType(duckdb::CompressionType compression_type) { + this->compression_type = compression_type; +} + +const storage_t &ColumnDefinition::StorageOid() const { + return storage_oid; +} + +LogicalIndex ColumnDefinition::Logical() const { + return LogicalIndex(oid); +} + +PhysicalIndex ColumnDefinition::Physical() const { + return PhysicalIndex(storage_oid); +} + +void ColumnDefinition::SetStorageOid(storage_t storage_oid) { + this->storage_oid = storage_oid; +} + +const column_t &ColumnDefinition::Oid() const { + return oid; +} + +void ColumnDefinition::SetOid(column_t oid) { + this->oid = oid; +} + +const TableColumnType &ColumnDefinition::Category() const { + return category; +} + +bool ColumnDefinition::Generated() const { + return category == TableColumnType::GENERATED; +} + +//===--------------------------------------------------------------------===// +// Generated Columns (VIRTUAL) +//===--------------------------------------------------------------------===// + +static void VerifyColumnRefs(ParsedExpression &expr) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto &column_ref = expr.Cast(); + if (column_ref.IsQualified()) { + throw ParserException( + "Qualified (tbl.name) column references are not allowed inside of generated column expressions"); + } + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](const ParsedExpression &child) { VerifyColumnRefs((ParsedExpression &)child); }); +} + +static void InnerGetListOfDependencies(ParsedExpression &expr, vector &dependencies) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto columnref = expr.Cast(); + auto &name = columnref.GetColumnName(); + dependencies.push_back(name); + } + ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) { + if (expr.type == ExpressionType::LAMBDA) { + throw NotImplementedException("Lambda functions are currently not supported in generated columns."); + } + InnerGetListOfDependencies((ParsedExpression &)child, dependencies); + }); +} + +void ColumnDefinition::GetListOfDependencies(vector &dependencies) const { + D_ASSERT(Generated()); + InnerGetListOfDependencies(*expression, dependencies); +} + +string ColumnDefinition::GetName() const { + return name; +} + +LogicalType ColumnDefinition::GetType() const { + return type; +} + +void ColumnDefinition::SetGeneratedExpression(unique_ptr new_expr) { + category = TableColumnType::GENERATED; + + if (new_expr->HasSubquery()) { + throw ParserException("Expression of generated column \"%s\" contains a subquery, which isn't allowed", name); + } + + VerifyColumnRefs(*new_expr); + if (type.id() == LogicalTypeId::ANY) { + expression = std::move(new_expr); + return; + } + // Always wrap the expression in a cast, that way we can always update the cast when we change the type + // Except if the type is LogicalType::ANY (no type specified) + expression = make_uniq_base(type, std::move(new_expr)); +} + +void ColumnDefinition::ChangeGeneratedExpressionType(const LogicalType &type) { + D_ASSERT(Generated()); + // First time the type is set, add a cast around the expression + D_ASSERT(this->type.id() == LogicalTypeId::ANY); + expression = make_uniq_base(type, std::move(expression)); + // Every generated expression should be wrapped in a cast on creation + // D_ASSERT(generated_expression->type == ExpressionType::OPERATOR_CAST); + // auto &cast_expr = generated_expression->Cast(); + // auto base_expr = std::move(cast_expr.child); + // generated_expression = make_uniq_base(type, std::move(base_expr)); +} + +const ParsedExpression &ColumnDefinition::GeneratedExpression() const { + D_ASSERT(Generated()); + return *expression; +} + +ParsedExpression &ColumnDefinition::GeneratedExpressionMutable() { + D_ASSERT(Generated()); + return *expression; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/column_list.cpp b/src/duckdb/src/parser/column_list.cpp new file mode 100644 index 00000000..3fe951fe --- /dev/null +++ b/src/duckdb/src/parser/column_list.cpp @@ -0,0 +1,169 @@ +#include "duckdb/parser/column_list.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +ColumnList::ColumnList(bool allow_duplicate_names) : allow_duplicate_names(allow_duplicate_names) { +} + +ColumnList::ColumnList(vector columns, bool allow_duplicate_names) + : allow_duplicate_names(allow_duplicate_names) { + for (auto &col : columns) { + AddColumn(std::move(col)); + } +} + +void ColumnList::AddColumn(ColumnDefinition column) { + auto oid = columns.size(); + if (!column.Generated()) { + column.SetStorageOid(physical_columns.size()); + physical_columns.push_back(oid); + } else { + column.SetStorageOid(DConstants::INVALID_INDEX); + } + column.SetOid(columns.size()); + AddToNameMap(column); + columns.push_back(std::move(column)); +} + +void ColumnList::Finalize() { + // add the "rowid" alias, if there is no rowid column specified in the table + if (name_map.find("rowid") == name_map.end()) { + name_map["rowid"] = COLUMN_IDENTIFIER_ROW_ID; + } +} + +void ColumnList::AddToNameMap(ColumnDefinition &col) { + if (allow_duplicate_names) { + idx_t index = 1; + string base_name = col.Name(); + while (name_map.find(col.Name()) != name_map.end()) { + col.SetName(base_name + ":" + to_string(index++)); + } + } else { + if (name_map.find(col.Name()) != name_map.end()) { + throw CatalogException("Column with name %s already exists!", col.Name()); + } + } + name_map[col.Name()] = col.Oid(); +} + +ColumnDefinition &ColumnList::GetColumnMutable(LogicalIndex logical) { + if (logical.index >= columns.size()) { + throw InternalException("Logical column index %lld out of range", logical.index); + } + return columns[logical.index]; +} + +ColumnDefinition &ColumnList::GetColumnMutable(PhysicalIndex physical) { + if (physical.index >= physical_columns.size()) { + throw InternalException("Physical column index %lld out of range", physical.index); + } + auto logical_index = physical_columns[physical.index]; + D_ASSERT(logical_index < columns.size()); + return columns[logical_index]; +} + +ColumnDefinition &ColumnList::GetColumnMutable(const string &name) { + auto entry = name_map.find(name); + if (entry == name_map.end()) { + throw InternalException("Column with name \"%s\" does not exist", name); + } + auto logical_index = entry->second; + D_ASSERT(logical_index < columns.size()); + return columns[logical_index]; +} + +const ColumnDefinition &ColumnList::GetColumn(LogicalIndex logical) const { + if (logical.index >= columns.size()) { + throw InternalException("Logical column index %lld out of range", logical.index); + } + return columns[logical.index]; +} + +const ColumnDefinition &ColumnList::GetColumn(PhysicalIndex physical) const { + if (physical.index >= physical_columns.size()) { + throw InternalException("Physical column index %lld out of range", physical.index); + } + auto logical_index = physical_columns[physical.index]; + D_ASSERT(logical_index < columns.size()); + return columns[logical_index]; +} + +const ColumnDefinition &ColumnList::GetColumn(const string &name) const { + auto entry = name_map.find(name); + if (entry == name_map.end()) { + throw InternalException("Column with name \"%s\" does not exist", name); + } + auto logical_index = entry->second; + D_ASSERT(logical_index < columns.size()); + return columns[logical_index]; +} + +vector ColumnList::GetColumnNames() const { + vector names; + names.reserve(columns.size()); + for (auto &column : columns) { + names.push_back(column.Name()); + } + return names; +} + +vector ColumnList::GetColumnTypes() const { + vector types; + types.reserve(columns.size()); + for (auto &column : columns) { + types.push_back(column.Type()); + } + return types; +} + +bool ColumnList::ColumnExists(const string &name) const { + auto entry = name_map.find(name); + return entry != name_map.end(); +} + +PhysicalIndex ColumnList::LogicalToPhysical(LogicalIndex logical) const { + auto &column = GetColumn(logical); + if (column.Generated()) { + throw InternalException("Column at position %d is not a physical column", logical.index); + } + return column.Physical(); +} + +LogicalIndex ColumnList::PhysicalToLogical(PhysicalIndex index) const { + auto &column = GetColumn(index); + return column.Logical(); +} + +LogicalIndex ColumnList::GetColumnIndex(string &column_name) const { + auto entry = name_map.find(column_name); + if (entry == name_map.end()) { + return LogicalIndex(DConstants::INVALID_INDEX); + } + if (entry->second == COLUMN_IDENTIFIER_ROW_ID) { + column_name = "rowid"; + return LogicalIndex(COLUMN_IDENTIFIER_ROW_ID); + } + column_name = columns[entry->second].Name(); + return LogicalIndex(entry->second); +} + +ColumnList ColumnList::Copy() const { + ColumnList result(allow_duplicate_names); + for (auto &col : columns) { + result.AddColumn(col.Copy()); + } + return result; +} + +ColumnList::ColumnListIterator ColumnList::Logical() const { + return ColumnListIterator(*this, false); +} + +ColumnList::ColumnListIterator ColumnList::Physical() const { + return ColumnListIterator(*this, true); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/constraint.cpp b/src/duckdb/src/parser/constraint.cpp new file mode 100644 index 00000000..c06a5c80 --- /dev/null +++ b/src/duckdb/src/parser/constraint.cpp @@ -0,0 +1,18 @@ +#include "duckdb/parser/constraint.hpp" + +#include "duckdb/common/printer.hpp" +#include "duckdb/parser/constraints/list.hpp" + +namespace duckdb { + +Constraint::Constraint(ConstraintType type) : type(type) { +} + +Constraint::~Constraint() { +} + +void Constraint::Print() const { + Printer::Print(ToString()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/constraints/check_constraint.cpp b/src/duckdb/src/parser/constraints/check_constraint.cpp new file mode 100644 index 00000000..9e6c27bd --- /dev/null +++ b/src/duckdb/src/parser/constraints/check_constraint.cpp @@ -0,0 +1,17 @@ +#include "duckdb/parser/constraints/check_constraint.hpp" + +namespace duckdb { + +CheckConstraint::CheckConstraint(unique_ptr expression) + : Constraint(ConstraintType::CHECK), expression(std::move(expression)) { +} + +string CheckConstraint::ToString() const { + return "CHECK(" + expression->ToString() + ")"; +} + +unique_ptr CheckConstraint::Copy() const { + return make_uniq(expression->Copy()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/constraints/foreign_key_constraint.cpp b/src/duckdb/src/parser/constraints/foreign_key_constraint.cpp new file mode 100644 index 00000000..66faf519 --- /dev/null +++ b/src/duckdb/src/parser/constraints/foreign_key_constraint.cpp @@ -0,0 +1,52 @@ +#include "duckdb/parser/constraints/foreign_key_constraint.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/parser/keyword_helper.hpp" + +namespace duckdb { + +ForeignKeyConstraint::ForeignKeyConstraint() : Constraint(ConstraintType::FOREIGN_KEY) { +} + +ForeignKeyConstraint::ForeignKeyConstraint(vector pk_columns, vector fk_columns, ForeignKeyInfo info) + : Constraint(ConstraintType::FOREIGN_KEY), pk_columns(std::move(pk_columns)), fk_columns(std::move(fk_columns)), + info(std::move(info)) { +} + +string ForeignKeyConstraint::ToString() const { + if (info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { + string base = "FOREIGN KEY ("; + + for (idx_t i = 0; i < fk_columns.size(); i++) { + if (i > 0) { + base += ", "; + } + base += KeywordHelper::WriteOptionallyQuoted(fk_columns[i]); + } + base += ") REFERENCES "; + if (!info.schema.empty()) { + base += info.schema; + base += "."; + } + base += info.table; + base += "("; + + for (idx_t i = 0; i < pk_columns.size(); i++) { + if (i > 0) { + base += ", "; + } + base += KeywordHelper::WriteOptionallyQuoted(pk_columns[i]); + } + base += ")"; + + return base; + } + + return ""; +} + +unique_ptr ForeignKeyConstraint::Copy() const { + return make_uniq(pk_columns, fk_columns, info); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/constraints/not_null_constraint.cpp b/src/duckdb/src/parser/constraints/not_null_constraint.cpp new file mode 100644 index 00000000..fb406e5a --- /dev/null +++ b/src/duckdb/src/parser/constraints/not_null_constraint.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parser/constraints/not_null_constraint.hpp" + +namespace duckdb { + +NotNullConstraint::NotNullConstraint(LogicalIndex index) : Constraint(ConstraintType::NOT_NULL), index(index) { +} + +NotNullConstraint::~NotNullConstraint() { +} + +string NotNullConstraint::ToString() const { + return "NOT NULL"; +} + +unique_ptr NotNullConstraint::Copy() const { + return make_uniq(index); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/constraints/unique_constraint.cpp b/src/duckdb/src/parser/constraints/unique_constraint.cpp new file mode 100644 index 00000000..29fd536e --- /dev/null +++ b/src/duckdb/src/parser/constraints/unique_constraint.cpp @@ -0,0 +1,40 @@ +#include "duckdb/parser/constraints/unique_constraint.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/parser/keyword_helper.hpp" + +namespace duckdb { + +UniqueConstraint::UniqueConstraint() : Constraint(ConstraintType::UNIQUE), index(DConstants::INVALID_INDEX) { +} + +UniqueConstraint::UniqueConstraint(LogicalIndex index, bool is_primary_key) + : Constraint(ConstraintType::UNIQUE), index(index), is_primary_key(is_primary_key) { +} +UniqueConstraint::UniqueConstraint(vector columns, bool is_primary_key) + : Constraint(ConstraintType::UNIQUE), index(DConstants::INVALID_INDEX), columns(std::move(columns)), + is_primary_key(is_primary_key) { +} + +string UniqueConstraint::ToString() const { + string base = is_primary_key ? "PRIMARY KEY(" : "UNIQUE("; + for (idx_t i = 0; i < columns.size(); i++) { + if (i > 0) { + base += ", "; + } + base += KeywordHelper::WriteOptionallyQuoted(columns[i]); + } + return base + ")"; +} + +unique_ptr UniqueConstraint::Copy() const { + if (index.index == DConstants::INVALID_INDEX) { + return make_uniq(columns, is_primary_key); + } else { + auto result = make_uniq(index, is_primary_key); + result->columns = columns; + return std::move(result); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/between_expression.cpp b/src/duckdb/src/parser/expression/between_expression.cpp new file mode 100644 index 00000000..414e5048 --- /dev/null +++ b/src/duckdb/src/parser/expression/between_expression.cpp @@ -0,0 +1,39 @@ +#include "duckdb/parser/expression/between_expression.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +BetweenExpression::BetweenExpression(unique_ptr input_p, unique_ptr lower_p, + unique_ptr upper_p) + : ParsedExpression(ExpressionType::COMPARE_BETWEEN, ExpressionClass::BETWEEN), input(std::move(input_p)), + lower(std::move(lower_p)), upper(std::move(upper_p)) { +} + +BetweenExpression::BetweenExpression() : BetweenExpression(nullptr, nullptr, nullptr) { +} + +string BetweenExpression::ToString() const { + return ToString(*this); +} + +bool BetweenExpression::Equal(const BetweenExpression &a, const BetweenExpression &b) { + if (!a.input->Equals(*b.input)) { + return false; + } + if (!a.lower->Equals(*b.lower)) { + return false; + } + if (!a.upper->Equals(*b.upper)) { + return false; + } + return true; +} + +unique_ptr BetweenExpression::Copy() const { + auto copy = make_uniq(input->Copy(), lower->Copy(), upper->Copy()); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/case_expression.cpp b/src/duckdb/src/parser/expression/case_expression.cpp new file mode 100644 index 00000000..fd76081a --- /dev/null +++ b/src/duckdb/src/parser/expression/case_expression.cpp @@ -0,0 +1,48 @@ +#include "duckdb/parser/expression/case_expression.hpp" + +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +CaseExpression::CaseExpression() : ParsedExpression(ExpressionType::CASE_EXPR, ExpressionClass::CASE) { +} + +string CaseExpression::ToString() const { + return ToString(*this); +} + +bool CaseExpression::Equal(const CaseExpression &a, const CaseExpression &b) { + if (a.case_checks.size() != b.case_checks.size()) { + return false; + } + for (idx_t i = 0; i < a.case_checks.size(); i++) { + if (!a.case_checks[i].when_expr->Equals(*b.case_checks[i].when_expr)) { + return false; + } + if (!a.case_checks[i].then_expr->Equals(*b.case_checks[i].then_expr)) { + return false; + } + } + if (!a.else_expr->Equals(*b.else_expr)) { + return false; + } + return true; +} + +unique_ptr CaseExpression::Copy() const { + auto copy = make_uniq(); + copy->CopyProperties(*this); + for (auto &check : case_checks) { + CaseCheck new_check; + new_check.when_expr = check.when_expr->Copy(); + new_check.then_expr = check.then_expr->Copy(); + copy->case_checks.push_back(std::move(new_check)); + } + copy->else_expr = else_expr->Copy(); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/cast_expression.cpp b/src/duckdb/src/parser/expression/cast_expression.cpp new file mode 100644 index 00000000..758cc29c --- /dev/null +++ b/src/duckdb/src/parser/expression/cast_expression.cpp @@ -0,0 +1,43 @@ +#include "duckdb/parser/expression/cast_expression.hpp" + +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +CastExpression::CastExpression(LogicalType target, unique_ptr child, bool try_cast_p) + : ParsedExpression(ExpressionType::OPERATOR_CAST, ExpressionClass::CAST), cast_type(std::move(target)), + try_cast(try_cast_p) { + D_ASSERT(child); + this->child = std::move(child); +} + +CastExpression::CastExpression() : ParsedExpression(ExpressionType::OPERATOR_CAST, ExpressionClass::CAST) { +} + +string CastExpression::ToString() const { + return ToString(*this); +} + +bool CastExpression::Equal(const CastExpression &a, const CastExpression &b) { + if (!a.child->Equals(*b.child)) { + return false; + } + if (a.cast_type != b.cast_type) { + return false; + } + if (a.try_cast != b.try_cast) { + return false; + } + return true; +} + +unique_ptr CastExpression::Copy() const { + auto copy = make_uniq(cast_type, child->Copy(), try_cast); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/collate_expression.cpp b/src/duckdb/src/parser/expression/collate_expression.cpp new file mode 100644 index 00000000..70b754f9 --- /dev/null +++ b/src/duckdb/src/parser/expression/collate_expression.cpp @@ -0,0 +1,39 @@ +#include "duckdb/parser/expression/collate_expression.hpp" + +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +CollateExpression::CollateExpression(string collation_p, unique_ptr child) + : ParsedExpression(ExpressionType::COLLATE, ExpressionClass::COLLATE), collation(std::move(collation_p)) { + D_ASSERT(child); + this->child = std::move(child); +} + +CollateExpression::CollateExpression() : ParsedExpression(ExpressionType::COLLATE, ExpressionClass::COLLATE) { +} + +string CollateExpression::ToString() const { + return StringUtil::Format("%s COLLATE %s", child->ToString(), SQLIdentifier(collation)); +} + +bool CollateExpression::Equal(const CollateExpression &a, const CollateExpression &b) { + if (!a.child->Equals(*b.child)) { + return false; + } + if (a.collation != b.collation) { + return false; + } + return true; +} + +unique_ptr CollateExpression::Copy() const { + auto copy = make_uniq(collation, child->Copy()); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/columnref_expression.cpp b/src/duckdb/src/parser/expression/columnref_expression.cpp new file mode 100644 index 00000000..70720f97 --- /dev/null +++ b/src/duckdb/src/parser/expression/columnref_expression.cpp @@ -0,0 +1,95 @@ +#include "duckdb/parser/expression/columnref_expression.hpp" + +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/qualified_name.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +ColumnRefExpression::ColumnRefExpression() : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF) { +} + +ColumnRefExpression::ColumnRefExpression(string column_name, string table_name) + : ColumnRefExpression(table_name.empty() ? vector {std::move(column_name)} + : vector {std::move(table_name), std::move(column_name)}) { +} + +ColumnRefExpression::ColumnRefExpression(string column_name) + : ColumnRefExpression(vector {std::move(column_name)}) { +} + +ColumnRefExpression::ColumnRefExpression(vector column_names_p) + : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF), + column_names(std::move(column_names_p)) { +#ifdef DEBUG + for (auto &col_name : column_names) { + D_ASSERT(!col_name.empty()); + } +#endif +} + +bool ColumnRefExpression::IsQualified() const { + return column_names.size() > 1; +} + +const string &ColumnRefExpression::GetColumnName() const { + D_ASSERT(column_names.size() <= 4); + return column_names.back(); +} + +const string &ColumnRefExpression::GetTableName() const { + D_ASSERT(column_names.size() >= 2 && column_names.size() <= 4); + if (column_names.size() == 4) { + return column_names[2]; + } + if (column_names.size() == 3) { + return column_names[1]; + } + return column_names[0]; +} + +string ColumnRefExpression::GetName() const { + return !alias.empty() ? alias : column_names.back(); +} + +string ColumnRefExpression::ToString() const { + string result; + for (idx_t i = 0; i < column_names.size(); i++) { + if (i > 0) { + result += "."; + } + result += KeywordHelper::WriteOptionallyQuoted(column_names[i]); + } + return result; +} + +bool ColumnRefExpression::Equal(const ColumnRefExpression &a, const ColumnRefExpression &b) { + if (a.column_names.size() != b.column_names.size()) { + return false; + } + for (idx_t i = 0; i < a.column_names.size(); i++) { + if (!StringUtil::CIEquals(a.column_names[i], b.column_names[i])) { + return false; + } + } + return true; +} + +hash_t ColumnRefExpression::Hash() const { + hash_t result = ParsedExpression::Hash(); + for (auto &column_name : column_names) { + result = CombineHash(result, StringUtil::CIHash(column_name)); + } + return result; +} + +unique_ptr ColumnRefExpression::Copy() const { + auto copy = make_uniq(column_names); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/comparison_expression.cpp b/src/duckdb/src/parser/expression/comparison_expression.cpp new file mode 100644 index 00000000..d0649b35 --- /dev/null +++ b/src/duckdb/src/parser/expression/comparison_expression.cpp @@ -0,0 +1,39 @@ +#include "duckdb/parser/expression/comparison_expression.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +ComparisonExpression::ComparisonExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::COMPARISON) { +} + +ComparisonExpression::ComparisonExpression(ExpressionType type, unique_ptr left, + unique_ptr right) + : ParsedExpression(type, ExpressionClass::COMPARISON), left(std::move(left)), right(std::move(right)) { +} + +string ComparisonExpression::ToString() const { + return ToString(*this); +} + +bool ComparisonExpression::Equal(const ComparisonExpression &a, const ComparisonExpression &b) { + if (!a.left->Equals(*b.left)) { + return false; + } + if (!a.right->Equals(*b.right)) { + return false; + } + return true; +} + +unique_ptr ComparisonExpression::Copy() const { + auto copy = make_uniq(type, left->Copy(), right->Copy()); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/conjunction_expression.cpp b/src/duckdb/src/parser/expression/conjunction_expression.cpp new file mode 100644 index 00000000..78a076bf --- /dev/null +++ b/src/duckdb/src/parser/expression/conjunction_expression.cpp @@ -0,0 +1,60 @@ +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/expression_util.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +ConjunctionExpression::ConjunctionExpression(ExpressionType type) + : ParsedExpression(type, ExpressionClass::CONJUNCTION) { +} + +ConjunctionExpression::ConjunctionExpression(ExpressionType type, vector> children) + : ParsedExpression(type, ExpressionClass::CONJUNCTION) { + for (auto &child : children) { + AddExpression(std::move(child)); + } +} + +ConjunctionExpression::ConjunctionExpression(ExpressionType type, unique_ptr left, + unique_ptr right) + : ParsedExpression(type, ExpressionClass::CONJUNCTION) { + AddExpression(std::move(left)); + AddExpression(std::move(right)); +} + +void ConjunctionExpression::AddExpression(unique_ptr expr) { + if (expr->type == type) { + // expr is a conjunction of the same type: merge the expression lists together + auto &other = expr->Cast(); + for (auto &child : other.children) { + children.push_back(std::move(child)); + } + } else { + children.push_back(std::move(expr)); + } +} + +string ConjunctionExpression::ToString() const { + return ToString(*this); +} + +bool ConjunctionExpression::Equal(const ConjunctionExpression &a, const ConjunctionExpression &b) { + return ExpressionUtil::SetEquals(a.children, b.children); +} + +unique_ptr ConjunctionExpression::Copy() const { + vector> copy_children; + copy_children.reserve(children.size()); + for (auto &expr : children) { + copy_children.push_back(expr->Copy()); + } + + auto copy = make_uniq(type, std::move(copy_children)); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/constant_expression.cpp b/src/duckdb/src/parser/expression/constant_expression.cpp new file mode 100644 index 00000000..437687b1 --- /dev/null +++ b/src/duckdb/src/parser/expression/constant_expression.cpp @@ -0,0 +1,37 @@ +#include "duckdb/parser/expression/constant_expression.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +ConstantExpression::ConstantExpression() : ParsedExpression(ExpressionType::VALUE_CONSTANT, ExpressionClass::CONSTANT) { +} + +ConstantExpression::ConstantExpression(Value val) + : ParsedExpression(ExpressionType::VALUE_CONSTANT, ExpressionClass::CONSTANT), value(std::move(val)) { +} + +string ConstantExpression::ToString() const { + return value.ToSQLString(); +} + +bool ConstantExpression::Equal(const ConstantExpression &a, const ConstantExpression &b) { + return a.value.type() == b.value.type() && !ValueOperations::DistinctFrom(a.value, b.value); +} + +hash_t ConstantExpression::Hash() const { + return value.Hash(); +} + +unique_ptr ConstantExpression::Copy() const { + auto copy = make_uniq(value); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/default_expression.cpp b/src/duckdb/src/parser/expression/default_expression.cpp new file mode 100644 index 00000000..7618fd21 --- /dev/null +++ b/src/duckdb/src/parser/expression/default_expression.cpp @@ -0,0 +1,23 @@ +#include "duckdb/parser/expression/default_expression.hpp" + +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +DefaultExpression::DefaultExpression() : ParsedExpression(ExpressionType::VALUE_DEFAULT, ExpressionClass::DEFAULT) { +} + +string DefaultExpression::ToString() const { + return "DEFAULT"; +} + +unique_ptr DefaultExpression::Copy() const { + auto copy = make_uniq(); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/function_expression.cpp b/src/duckdb/src/parser/expression/function_expression.cpp new file mode 100644 index 00000000..bd1804b4 --- /dev/null +++ b/src/duckdb/src/parser/expression/function_expression.cpp @@ -0,0 +1,98 @@ +#include "duckdb/parser/expression/function_expression.hpp" + +#include +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hash.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +FunctionExpression::FunctionExpression() : ParsedExpression(ExpressionType::FUNCTION, ExpressionClass::FUNCTION) { +} + +FunctionExpression::FunctionExpression(string catalog, string schema, const string &function_name, + vector> children_p, + unique_ptr filter, unique_ptr order_bys_p, + bool distinct, bool is_operator, bool export_state_p) + : ParsedExpression(ExpressionType::FUNCTION, ExpressionClass::FUNCTION), catalog(std::move(catalog)), + schema(std::move(schema)), function_name(StringUtil::Lower(function_name)), is_operator(is_operator), + children(std::move(children_p)), distinct(distinct), filter(std::move(filter)), order_bys(std::move(order_bys_p)), + export_state(export_state_p) { + D_ASSERT(!function_name.empty()); + if (!order_bys) { + order_bys = make_uniq(); + } +} + +FunctionExpression::FunctionExpression(const string &function_name, vector> children_p, + unique_ptr filter, unique_ptr order_bys, + bool distinct, bool is_operator, bool export_state_p) + : FunctionExpression(INVALID_CATALOG, INVALID_SCHEMA, function_name, std::move(children_p), std::move(filter), + std::move(order_bys), distinct, is_operator, export_state_p) { +} + +string FunctionExpression::ToString() const { + return ToString(*this, schema, function_name, is_operator, distinct, + filter.get(), order_bys.get(), export_state, true); +} + +bool FunctionExpression::Equal(const FunctionExpression &a, const FunctionExpression &b) { + if (a.catalog != b.catalog || a.schema != b.schema || a.function_name != b.function_name || + b.distinct != a.distinct) { + return false; + } + if (b.children.size() != a.children.size()) { + return false; + } + for (idx_t i = 0; i < a.children.size(); i++) { + if (!a.children[i]->Equals(*b.children[i])) { + return false; + } + } + if (!ParsedExpression::Equals(a.filter, b.filter)) { + return false; + } + if (!OrderModifier::Equals(a.order_bys, b.order_bys)) { + return false; + } + if (a.export_state != b.export_state) { + return false; + } + return true; +} + +hash_t FunctionExpression::Hash() const { + hash_t result = ParsedExpression::Hash(); + result = CombineHash(result, duckdb::Hash(schema.c_str())); + result = CombineHash(result, duckdb::Hash(function_name.c_str())); + result = CombineHash(result, duckdb::Hash(distinct)); + result = CombineHash(result, duckdb::Hash(export_state)); + return result; +} + +unique_ptr FunctionExpression::Copy() const { + vector> copy_children; + unique_ptr filter_copy; + copy_children.reserve(children.size()); + for (auto &child : children) { + copy_children.push_back(child->Copy()); + } + if (filter) { + filter_copy = filter->Copy(); + } + auto order_copy = order_bys ? unique_ptr_cast(order_bys->Copy()) : nullptr; + auto copy = + make_uniq(catalog, schema, function_name, std::move(copy_children), std::move(filter_copy), + std::move(order_copy), distinct, is_operator, export_state); + copy->CopyProperties(*this); + return std::move(copy); +} + +void FunctionExpression::Verify() const { + D_ASSERT(!function_name.empty()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/lambda_expression.cpp b/src/duckdb/src/parser/expression/lambda_expression.cpp new file mode 100644 index 00000000..3b9b1398 --- /dev/null +++ b/src/duckdb/src/parser/expression/lambda_expression.cpp @@ -0,0 +1,38 @@ +#include "duckdb/parser/expression/lambda_expression.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/string_util.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +LambdaExpression::LambdaExpression() : ParsedExpression(ExpressionType::LAMBDA, ExpressionClass::LAMBDA) { +} + +LambdaExpression::LambdaExpression(unique_ptr lhs, unique_ptr expr) + : ParsedExpression(ExpressionType::LAMBDA, ExpressionClass::LAMBDA), lhs(std::move(lhs)), expr(std::move(expr)) { +} + +string LambdaExpression::ToString() const { + return "(" + lhs->ToString() + " -> " + expr->ToString() + ")"; +} + +bool LambdaExpression::Equal(const LambdaExpression &a, const LambdaExpression &b) { + return a.lhs->Equals(*b.lhs) && a.expr->Equals(*b.expr); +} + +hash_t LambdaExpression::Hash() const { + hash_t result = lhs->Hash(); + ParsedExpression::Hash(); + result = CombineHash(result, expr->Hash()); + return result; +} + +unique_ptr LambdaExpression::Copy() const { + auto copy = make_uniq(lhs->Copy(), expr->Copy()); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/operator_expression.cpp b/src/duckdb/src/parser/expression/operator_expression.cpp new file mode 100644 index 00000000..79d90f67 --- /dev/null +++ b/src/duckdb/src/parser/expression/operator_expression.cpp @@ -0,0 +1,50 @@ +#include "duckdb/parser/expression/operator_expression.hpp" + +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +OperatorExpression::OperatorExpression(ExpressionType type, unique_ptr left, + unique_ptr right) + : ParsedExpression(type, ExpressionClass::OPERATOR) { + if (left) { + children.push_back(std::move(left)); + } + if (right) { + children.push_back(std::move(right)); + } +} + +OperatorExpression::OperatorExpression(ExpressionType type, vector> children) + : ParsedExpression(type, ExpressionClass::OPERATOR), children(std::move(children)) { +} + +string OperatorExpression::ToString() const { + return ToString(*this); +} + +bool OperatorExpression::Equal(const OperatorExpression &a, const OperatorExpression &b) { + if (a.children.size() != b.children.size()) { + return false; + } + for (idx_t i = 0; i < a.children.size(); i++) { + if (!a.children[i]->Equals(*b.children[i])) { + return false; + } + } + return true; +} + +unique_ptr OperatorExpression::Copy() const { + auto copy = make_uniq(type); + copy->CopyProperties(*this); + for (auto &it : children) { + copy->children.push_back(it->Copy()); + } + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/parameter_expression.cpp b/src/duckdb/src/parser/expression/parameter_expression.cpp new file mode 100644 index 00000000..034c4eae --- /dev/null +++ b/src/duckdb/src/parser/expression/parameter_expression.cpp @@ -0,0 +1,36 @@ +#include "duckdb/parser/expression/parameter_expression.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/to_string.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +ParameterExpression::ParameterExpression() + : ParsedExpression(ExpressionType::VALUE_PARAMETER, ExpressionClass::PARAMETER) { +} + +string ParameterExpression::ToString() const { + return "$" + identifier; +} + +unique_ptr ParameterExpression::Copy() const { + auto copy = make_uniq(); + copy->identifier = identifier; + copy->CopyProperties(*this); + return std::move(copy); +} + +bool ParameterExpression::Equal(const ParameterExpression &a, const ParameterExpression &b) { + return StringUtil::CIEquals(a.identifier, b.identifier); +} + +hash_t ParameterExpression::Hash() const { + hash_t result = ParsedExpression::Hash(); + return CombineHash(duckdb::Hash(identifier.c_str(), identifier.size()), result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/positional_reference_expression.cpp b/src/duckdb/src/parser/expression/positional_reference_expression.cpp new file mode 100644 index 00000000..e73bbedd --- /dev/null +++ b/src/duckdb/src/parser/expression/positional_reference_expression.cpp @@ -0,0 +1,40 @@ +#include "duckdb/parser/expression/positional_reference_expression.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/to_string.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +PositionalReferenceExpression::PositionalReferenceExpression() + : ParsedExpression(ExpressionType::POSITIONAL_REFERENCE, ExpressionClass::POSITIONAL_REFERENCE) { +} + +PositionalReferenceExpression::PositionalReferenceExpression(idx_t index) + : ParsedExpression(ExpressionType::POSITIONAL_REFERENCE, ExpressionClass::POSITIONAL_REFERENCE), index(index) { +} + +string PositionalReferenceExpression::ToString() const { + return "#" + to_string(index); +} + +bool PositionalReferenceExpression::Equal(const PositionalReferenceExpression &a, + const PositionalReferenceExpression &b) { + return a.index == b.index; +} + +unique_ptr PositionalReferenceExpression::Copy() const { + auto copy = make_uniq(index); + copy->CopyProperties(*this); + return std::move(copy); +} + +hash_t PositionalReferenceExpression::Hash() const { + hash_t result = ParsedExpression::Hash(); + return CombineHash(duckdb::Hash(index), result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/star_expression.cpp b/src/duckdb/src/parser/expression/star_expression.cpp new file mode 100644 index 00000000..c632603e --- /dev/null +++ b/src/duckdb/src/parser/expression/star_expression.cpp @@ -0,0 +1,93 @@ +#include "duckdb/parser/expression/star_expression.hpp" + +#include "duckdb/common/exception.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +StarExpression::StarExpression(string relation_name_p) + : ParsedExpression(ExpressionType::STAR, ExpressionClass::STAR), relation_name(std::move(relation_name_p)) { +} + +string StarExpression::ToString() const { + if (expr) { + D_ASSERT(columns); + return "COLUMNS(" + expr->ToString() + ")"; + } + string result; + if (columns) { + result += "COLUMNS("; + } + result += relation_name.empty() ? "*" : relation_name + ".*"; + if (!exclude_list.empty()) { + result += " EXCLUDE ("; + bool first_entry = true; + for (auto &entry : exclude_list) { + if (!first_entry) { + result += ", "; + } + result += entry; + first_entry = false; + } + result += ")"; + } + if (!replace_list.empty()) { + result += " REPLACE ("; + bool first_entry = true; + for (auto &entry : replace_list) { + if (!first_entry) { + result += ", "; + } + result += entry.second->ToString(); + result += " AS "; + result += entry.first; + first_entry = false; + } + result += ")"; + } + if (columns) { + result += ")"; + } + return result; +} + +bool StarExpression::Equal(const StarExpression &a, const StarExpression &b) { + if (a.relation_name != b.relation_name || a.exclude_list != b.exclude_list) { + return false; + } + if (a.columns != b.columns) { + return false; + } + if (a.replace_list.size() != b.replace_list.size()) { + return false; + } + for (auto &entry : a.replace_list) { + auto other_entry = b.replace_list.find(entry.first); + if (other_entry == b.replace_list.end()) { + return false; + } + if (!entry.second->Equals(*other_entry->second)) { + return false; + } + } + if (!ParsedExpression::Equals(a.expr, b.expr)) { + return false; + } + return true; +} + +unique_ptr StarExpression::Copy() const { + auto copy = make_uniq(relation_name); + copy->exclude_list = exclude_list; + for (auto &entry : replace_list) { + copy->replace_list[entry.first] = entry.second->Copy(); + } + copy->columns = columns; + copy->expr = expr ? expr->Copy() : nullptr; + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/subquery_expression.cpp b/src/duckdb/src/parser/expression/subquery_expression.cpp new file mode 100644 index 00000000..2cc1e454 --- /dev/null +++ b/src/duckdb/src/parser/expression/subquery_expression.cpp @@ -0,0 +1,51 @@ +#include "duckdb/parser/expression/subquery_expression.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" + +namespace duckdb { + +SubqueryExpression::SubqueryExpression() + : ParsedExpression(ExpressionType::SUBQUERY, ExpressionClass::SUBQUERY), subquery_type(SubqueryType::INVALID), + comparison_type(ExpressionType::INVALID) { +} + +string SubqueryExpression::ToString() const { + switch (subquery_type) { + case SubqueryType::ANY: + return "(" + child->ToString() + " " + ExpressionTypeToOperator(comparison_type) + " ANY(" + + subquery->ToString() + "))"; + case SubqueryType::EXISTS: + return "EXISTS(" + subquery->ToString() + ")"; + case SubqueryType::NOT_EXISTS: + return "NOT EXISTS(" + subquery->ToString() + ")"; + case SubqueryType::SCALAR: + return "(" + subquery->ToString() + ")"; + default: + throw InternalException("Unrecognized type for subquery"); + } +} + +bool SubqueryExpression::Equal(const SubqueryExpression &a, const SubqueryExpression &b) { + if (!a.subquery || !b.subquery) { + return false; + } + if (!ParsedExpression::Equals(a.child, b.child)) { + return false; + } + return a.comparison_type == b.comparison_type && a.subquery_type == b.subquery_type && + a.subquery->Equals(*b.subquery); +} + +unique_ptr SubqueryExpression::Copy() const { + auto copy = make_uniq(); + copy->CopyProperties(*this); + copy->subquery = unique_ptr_cast(subquery->Copy()); + copy->subquery_type = subquery_type; + copy->child = child ? child->Copy() : nullptr; + copy->comparison_type = comparison_type; + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression/window_expression.cpp b/src/duckdb/src/parser/expression/window_expression.cpp new file mode 100644 index 00000000..c36228e2 --- /dev/null +++ b/src/duckdb/src/parser/expression/window_expression.cpp @@ -0,0 +1,139 @@ +#include "duckdb/parser/expression/window_expression.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/string_util.hpp" + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +WindowExpression::WindowExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::WINDOW) { +} + +WindowExpression::WindowExpression(ExpressionType type, string catalog_name, string schema, const string &function_name) + : ParsedExpression(type, ExpressionClass::WINDOW), catalog(std::move(catalog_name)), schema(std::move(schema)), + function_name(StringUtil::Lower(function_name)), ignore_nulls(false) { + switch (type) { + case ExpressionType::WINDOW_AGGREGATE: + case ExpressionType::WINDOW_ROW_NUMBER: + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_LAST_VALUE: + case ExpressionType::WINDOW_NTH_VALUE: + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_CUME_DIST: + case ExpressionType::WINDOW_LEAD: + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_NTILE: + break; + default: + throw NotImplementedException("Window aggregate type %s not supported", ExpressionTypeToString(type).c_str()); + } +} + +ExpressionType WindowExpression::WindowToExpressionType(string &fun_name) { + if (fun_name == "rank") { + return ExpressionType::WINDOW_RANK; + } else if (fun_name == "rank_dense" || fun_name == "dense_rank") { + return ExpressionType::WINDOW_RANK_DENSE; + } else if (fun_name == "percent_rank") { + return ExpressionType::WINDOW_PERCENT_RANK; + } else if (fun_name == "row_number") { + return ExpressionType::WINDOW_ROW_NUMBER; + } else if (fun_name == "first_value" || fun_name == "first") { + return ExpressionType::WINDOW_FIRST_VALUE; + } else if (fun_name == "last_value" || fun_name == "last") { + return ExpressionType::WINDOW_LAST_VALUE; + } else if (fun_name == "nth_value") { + return ExpressionType::WINDOW_NTH_VALUE; + } else if (fun_name == "cume_dist") { + return ExpressionType::WINDOW_CUME_DIST; + } else if (fun_name == "lead") { + return ExpressionType::WINDOW_LEAD; + } else if (fun_name == "lag") { + return ExpressionType::WINDOW_LAG; + } else if (fun_name == "ntile") { + return ExpressionType::WINDOW_NTILE; + } + return ExpressionType::WINDOW_AGGREGATE; +} + +string WindowExpression::ToString() const { + return ToString(*this, schema, function_name); +} + +bool WindowExpression::Equal(const WindowExpression &a, const WindowExpression &b) { + // check if the child expressions are equivalent + if (a.ignore_nulls != b.ignore_nulls) { + return false; + } + if (!ParsedExpression::ListEquals(a.children, b.children)) { + return false; + } + if (a.start != b.start || a.end != b.end) { + return false; + } + // check if the framing expressions are equivalentbind_ + if (!ParsedExpression::Equals(a.start_expr, b.start_expr) || !ParsedExpression::Equals(a.end_expr, b.end_expr) || + !ParsedExpression::Equals(a.offset_expr, b.offset_expr) || + !ParsedExpression::Equals(a.default_expr, b.default_expr)) { + return false; + } + + // check if the partitions are equivalent + if (!ParsedExpression::ListEquals(a.partitions, b.partitions)) { + return false; + } + // check if the orderings are equivalent + if (a.orders.size() != b.orders.size()) { + return false; + } + for (idx_t i = 0; i < a.orders.size(); i++) { + if (a.orders[i].type != b.orders[i].type) { + return false; + } + if (!a.orders[i].expression->Equals(*b.orders[i].expression)) { + return false; + } + } + // check if the filter clauses are equivalent + if (!ParsedExpression::Equals(a.filter_expr, b.filter_expr)) { + return false; + } + + return true; +} + +unique_ptr WindowExpression::Copy() const { + auto new_window = make_uniq(type, catalog, schema, function_name); + new_window->CopyProperties(*this); + + for (auto &child : children) { + new_window->children.push_back(child->Copy()); + } + + for (auto &e : partitions) { + new_window->partitions.push_back(e->Copy()); + } + + for (auto &o : orders) { + new_window->orders.emplace_back(o.type, o.null_order, o.expression->Copy()); + } + + new_window->filter_expr = filter_expr ? filter_expr->Copy() : nullptr; + + new_window->start = start; + new_window->end = end; + new_window->start_expr = start_expr ? start_expr->Copy() : nullptr; + new_window->end_expr = end_expr ? end_expr->Copy() : nullptr; + new_window->offset_expr = offset_expr ? offset_expr->Copy() : nullptr; + new_window->default_expr = default_expr ? default_expr->Copy() : nullptr; + new_window->ignore_nulls = ignore_nulls; + + return std::move(new_window); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/expression_util.cpp b/src/duckdb/src/parser/expression_util.cpp new file mode 100644 index 00000000..76ffddc5 --- /dev/null +++ b/src/duckdb/src/parser/expression_util.cpp @@ -0,0 +1,72 @@ +#include "duckdb/parser/expression_util.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/expression_map.hpp" + +namespace duckdb { + +template +bool ExpressionUtil::ExpressionListEquals(const vector> &a, const vector> &b) { + if (a.size() != b.size()) { + return false; + } + for (idx_t i = 0; i < a.size(); i++) { + if (!(*a[i] == *b[i])) { + return false; + } + } + return true; +} + +template +bool ExpressionUtil::ExpressionSetEquals(const vector> &a, const vector> &b) { + if (a.size() != b.size()) { + return false; + } + // we create a map of expression -> count for the left side + // we keep the count because the same expression can occur multiple times (e.g. "1 AND 1" is legal) + // in this case we track the following value: map["Constant(1)"] = 2 + EXPRESSION_MAP map; + for (idx_t i = 0; i < a.size(); i++) { + map[*a[i]]++; + } + // now on the right side we reduce the counts again + // if the conjunctions are identical, all the counts will be 0 after the + for (auto &expr : b) { + auto entry = map.find(*expr); + // first we check if we can find the expression in the map at all + if (entry == map.end()) { + return false; + } + // if we found it we check the count; if the count is already 0 we return false + // this happens if e.g. the left side contains "1 AND X", and the right side contains "1 AND 1" + // "1" is contained in the map, however, the right side contains the expression twice + // hence we know the children are not identical in this case because the LHS and RHS have a different count for + // the Constant(1) expression + if (entry->second == 0) { + return false; + } + entry->second--; + } + return true; +} + +bool ExpressionUtil::ListEquals(const vector> &a, + const vector> &b) { + return ExpressionListEquals(a, b); +} + +bool ExpressionUtil::ListEquals(const vector> &a, const vector> &b) { + return ExpressionListEquals(a, b); +} + +bool ExpressionUtil::SetEquals(const vector> &a, + const vector> &b) { + return ExpressionSetEquals>(a, b); +} + +bool ExpressionUtil::SetEquals(const vector> &a, const vector> &b) { + return ExpressionSetEquals>(a, b); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/keyword_helper.cpp b/src/duckdb/src/parser/keyword_helper.cpp new file mode 100644 index 00000000..ff6f573f --- /dev/null +++ b/src/duckdb/src/parser/keyword_helper.cpp @@ -0,0 +1,49 @@ +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +bool KeywordHelper::IsKeyword(const string &text) { + return Parser::IsKeyword(text); +} + +bool KeywordHelper::RequiresQuotes(const string &text, bool allow_caps) { + for (size_t i = 0; i < text.size(); i++) { + if (i > 0 && (text[i] >= '0' && text[i] <= '9')) { + continue; + } + if (text[i] >= 'a' && text[i] <= 'z') { + continue; + } + if (allow_caps) { + if (text[i] >= 'A' && text[i] <= 'Z') { + continue; + } + } + if (text[i] == '_') { + continue; + } + return true; + } + return IsKeyword(text); +} + +string KeywordHelper::EscapeQuotes(const string &text, char quote) { + return StringUtil::Replace(text, string(1, quote), string(2, quote)); +} + +string KeywordHelper::WriteQuoted(const string &text, char quote) { + // 1. Escapes all occurences of 'quote' by doubling them (escape in SQL) + // 2. Adds quotes around the string + return string(1, quote) + EscapeQuotes(text, quote) + string(1, quote); +} + +string KeywordHelper::WriteOptionallyQuoted(const string &text, char quote, bool allow_caps) { + if (!RequiresQuotes(text, allow_caps)) { + return text; + } + return WriteQuoted(text, quote); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/alter_info.cpp b/src/duckdb/src/parser/parsed_data/alter_info.cpp new file mode 100644 index 00000000..e4a78ba9 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/alter_info.cpp @@ -0,0 +1,28 @@ +#include "duckdb/parser/parsed_data/alter_info.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" +#include "duckdb/parser/parsed_data/alter_table_function_info.hpp" + +namespace duckdb { + +AlterInfo::AlterInfo(AlterType type, string catalog_p, string schema_p, string name_p, OnEntryNotFound if_not_found) + : ParseInfo(TYPE), type(type), if_not_found(if_not_found), catalog(std::move(catalog_p)), + schema(std::move(schema_p)), name(std::move(name_p)), allow_internal(false) { +} + +AlterInfo::AlterInfo(AlterType type) : ParseInfo(TYPE), type(type) { +} + +AlterInfo::~AlterInfo() { +} + +AlterEntryData AlterInfo::GetAlterEntryData() const { + AlterEntryData data; + data.catalog = catalog; + data.schema = schema; + data.name = name; + data.if_not_found = if_not_found; + return data; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/alter_scalar_function_info.cpp b/src/duckdb/src/parser/parsed_data/alter_scalar_function_info.cpp new file mode 100644 index 00000000..363019cf --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/alter_scalar_function_info.cpp @@ -0,0 +1,38 @@ +#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" + +#include "duckdb/parser/constraint.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// AlterScalarFunctionInfo +//===--------------------------------------------------------------------===// +AlterScalarFunctionInfo::AlterScalarFunctionInfo(AlterScalarFunctionType type, AlterEntryData data) + : AlterInfo(AlterType::ALTER_SCALAR_FUNCTION, std::move(data.catalog), std::move(data.schema), std::move(data.name), + data.if_not_found), + alter_scalar_function_type(type) { +} +AlterScalarFunctionInfo::~AlterScalarFunctionInfo() { +} + +CatalogType AlterScalarFunctionInfo::GetCatalogType() const { + return CatalogType::SCALAR_FUNCTION_ENTRY; +} + +//===--------------------------------------------------------------------===// +// AddScalarFunctionOverloadInfo +//===--------------------------------------------------------------------===// +AddScalarFunctionOverloadInfo::AddScalarFunctionOverloadInfo(AlterEntryData data, ScalarFunctionSet new_overloads_p) + : AlterScalarFunctionInfo(AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS, std::move(data)), + new_overloads(std::move(new_overloads_p)) { + this->allow_internal = true; +} + +AddScalarFunctionOverloadInfo::~AddScalarFunctionOverloadInfo() { +} + +unique_ptr AddScalarFunctionOverloadInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), new_overloads); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/alter_table_function_info.cpp b/src/duckdb/src/parser/parsed_data/alter_table_function_info.cpp new file mode 100644 index 00000000..347eb3fd --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/alter_table_function_info.cpp @@ -0,0 +1,38 @@ +#include "duckdb/parser/parsed_data/alter_table_function_info.hpp" + +#include "duckdb/parser/constraint.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// AlterTableFunctionInfo +//===--------------------------------------------------------------------===// +AlterTableFunctionInfo::AlterTableFunctionInfo(AlterTableFunctionType type, AlterEntryData data) + : AlterInfo(AlterType::ALTER_TABLE_FUNCTION, std::move(data.catalog), std::move(data.schema), std::move(data.name), + data.if_not_found), + alter_table_function_type(type) { +} +AlterTableFunctionInfo::~AlterTableFunctionInfo() { +} + +CatalogType AlterTableFunctionInfo::GetCatalogType() const { + return CatalogType::TABLE_FUNCTION_ENTRY; +} + +//===--------------------------------------------------------------------===// +// AddTableFunctionOverloadInfo +//===--------------------------------------------------------------------===// +AddTableFunctionOverloadInfo::AddTableFunctionOverloadInfo(AlterEntryData data, TableFunctionSet new_overloads_p) + : AlterTableFunctionInfo(AlterTableFunctionType::ADD_FUNCTION_OVERLOADS, std::move(data)), + new_overloads(std::move(new_overloads_p)) { + this->allow_internal = true; +} + +AddTableFunctionOverloadInfo::~AddTableFunctionOverloadInfo() { +} + +unique_ptr AddTableFunctionOverloadInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), new_overloads); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/alter_table_info.cpp b/src/duckdb/src/parser/parsed_data/alter_table_info.cpp new file mode 100644 index 00000000..29ef6e90 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/alter_table_info.cpp @@ -0,0 +1,239 @@ +#include "duckdb/parser/parsed_data/alter_table_info.hpp" + +#include "duckdb/parser/constraint.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// ChangeOwnershipInfo +//===--------------------------------------------------------------------===// +ChangeOwnershipInfo::ChangeOwnershipInfo(CatalogType entry_catalog_type, string entry_catalog_p, string entry_schema_p, + string entry_name_p, string owner_schema_p, string owner_name_p, + OnEntryNotFound if_not_found) + : AlterInfo(AlterType::CHANGE_OWNERSHIP, std::move(entry_catalog_p), std::move(entry_schema_p), + std::move(entry_name_p), if_not_found), + entry_catalog_type(entry_catalog_type), owner_schema(std::move(owner_schema_p)), + owner_name(std::move(owner_name_p)) { +} + +CatalogType ChangeOwnershipInfo::GetCatalogType() const { + return entry_catalog_type; +} + +unique_ptr ChangeOwnershipInfo::Copy() const { + return make_uniq_base(entry_catalog_type, catalog, schema, name, owner_schema, + owner_name, if_not_found); +} + +//===--------------------------------------------------------------------===// +// AlterTableInfo +//===--------------------------------------------------------------------===// +AlterTableInfo::AlterTableInfo(AlterTableType type) : AlterInfo(AlterType::ALTER_TABLE), alter_table_type(type) { +} + +AlterTableInfo::AlterTableInfo(AlterTableType type, AlterEntryData data) + : AlterInfo(AlterType::ALTER_TABLE, std::move(data.catalog), std::move(data.schema), std::move(data.name), + data.if_not_found), + alter_table_type(type) { +} +AlterTableInfo::~AlterTableInfo() { +} + +CatalogType AlterTableInfo::GetCatalogType() const { + return CatalogType::TABLE_ENTRY; +} +//===--------------------------------------------------------------------===// +// RenameColumnInfo +//===--------------------------------------------------------------------===// +RenameColumnInfo::RenameColumnInfo(AlterEntryData data, string old_name_p, string new_name_p) + : AlterTableInfo(AlterTableType::RENAME_COLUMN, std::move(data)), old_name(std::move(old_name_p)), + new_name(std::move(new_name_p)) { +} + +RenameColumnInfo::RenameColumnInfo() : AlterTableInfo(AlterTableType::RENAME_COLUMN) { +} + +RenameColumnInfo::~RenameColumnInfo() { +} + +unique_ptr RenameColumnInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), old_name, new_name); +} + +//===--------------------------------------------------------------------===// +// RenameTableInfo +//===--------------------------------------------------------------------===// +RenameTableInfo::RenameTableInfo() : AlterTableInfo(AlterTableType::RENAME_TABLE) { +} + +RenameTableInfo::RenameTableInfo(AlterEntryData data, string new_name_p) + : AlterTableInfo(AlterTableType::RENAME_TABLE, std::move(data)), new_table_name(std::move(new_name_p)) { +} + +RenameTableInfo::~RenameTableInfo() { +} + +unique_ptr RenameTableInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), new_table_name); +} + +//===--------------------------------------------------------------------===// +// AddColumnInfo +//===--------------------------------------------------------------------===// +AddColumnInfo::AddColumnInfo(ColumnDefinition new_column_p) + : AlterTableInfo(AlterTableType::ADD_COLUMN), new_column(std::move(new_column_p)) { +} + +AddColumnInfo::AddColumnInfo(AlterEntryData data, ColumnDefinition new_column, bool if_column_not_exists) + : AlterTableInfo(AlterTableType::ADD_COLUMN, std::move(data)), new_column(std::move(new_column)), + if_column_not_exists(if_column_not_exists) { +} + +AddColumnInfo::~AddColumnInfo() { +} + +unique_ptr AddColumnInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), new_column.Copy(), if_column_not_exists); +} + +//===--------------------------------------------------------------------===// +// RemoveColumnInfo +//===--------------------------------------------------------------------===// +RemoveColumnInfo::RemoveColumnInfo() : AlterTableInfo(AlterTableType::REMOVE_COLUMN) { +} + +RemoveColumnInfo::RemoveColumnInfo(AlterEntryData data, string removed_column, bool if_column_exists, bool cascade) + : AlterTableInfo(AlterTableType::REMOVE_COLUMN, std::move(data)), removed_column(std::move(removed_column)), + if_column_exists(if_column_exists), cascade(cascade) { +} +RemoveColumnInfo::~RemoveColumnInfo() { +} + +unique_ptr RemoveColumnInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), removed_column, if_column_exists, cascade); +} + +//===--------------------------------------------------------------------===// +// ChangeColumnTypeInfo +//===--------------------------------------------------------------------===// +ChangeColumnTypeInfo::ChangeColumnTypeInfo() : AlterTableInfo(AlterTableType::ALTER_COLUMN_TYPE) { +} + +ChangeColumnTypeInfo::ChangeColumnTypeInfo(AlterEntryData data, string column_name, LogicalType target_type, + unique_ptr expression) + : AlterTableInfo(AlterTableType::ALTER_COLUMN_TYPE, std::move(data)), column_name(std::move(column_name)), + target_type(std::move(target_type)), expression(std::move(expression)) { +} +ChangeColumnTypeInfo::~ChangeColumnTypeInfo() { +} + +unique_ptr ChangeColumnTypeInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), column_name, target_type, + expression->Copy()); +} + +//===--------------------------------------------------------------------===// +// SetDefaultInfo +//===--------------------------------------------------------------------===// +SetDefaultInfo::SetDefaultInfo() : AlterTableInfo(AlterTableType::SET_DEFAULT) { +} + +SetDefaultInfo::SetDefaultInfo(AlterEntryData data, string column_name_p, unique_ptr new_default) + : AlterTableInfo(AlterTableType::SET_DEFAULT, std::move(data)), column_name(std::move(column_name_p)), + expression(std::move(new_default)) { +} +SetDefaultInfo::~SetDefaultInfo() { +} + +unique_ptr SetDefaultInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), column_name, + expression ? expression->Copy() : nullptr); +} + +//===--------------------------------------------------------------------===// +// SetNotNullInfo +//===--------------------------------------------------------------------===// +SetNotNullInfo::SetNotNullInfo() : AlterTableInfo(AlterTableType::SET_NOT_NULL) { +} + +SetNotNullInfo::SetNotNullInfo(AlterEntryData data, string column_name_p) + : AlterTableInfo(AlterTableType::SET_NOT_NULL, std::move(data)), column_name(std::move(column_name_p)) { +} +SetNotNullInfo::~SetNotNullInfo() { +} + +unique_ptr SetNotNullInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), column_name); +} + +//===--------------------------------------------------------------------===// +// DropNotNullInfo +//===--------------------------------------------------------------------===// +DropNotNullInfo::DropNotNullInfo() : AlterTableInfo(AlterTableType::DROP_NOT_NULL) { +} + +DropNotNullInfo::DropNotNullInfo(AlterEntryData data, string column_name_p) + : AlterTableInfo(AlterTableType::DROP_NOT_NULL, std::move(data)), column_name(std::move(column_name_p)) { +} +DropNotNullInfo::~DropNotNullInfo() { +} + +unique_ptr DropNotNullInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), column_name); +} + +//===--------------------------------------------------------------------===// +// AlterForeignKeyInfo +//===--------------------------------------------------------------------===// +AlterForeignKeyInfo::AlterForeignKeyInfo() : AlterTableInfo(AlterTableType::FOREIGN_KEY_CONSTRAINT) { +} + +AlterForeignKeyInfo::AlterForeignKeyInfo(AlterEntryData data, string fk_table, vector pk_columns, + vector fk_columns, vector pk_keys, + vector fk_keys, AlterForeignKeyType type_p) + : AlterTableInfo(AlterTableType::FOREIGN_KEY_CONSTRAINT, std::move(data)), fk_table(std::move(fk_table)), + pk_columns(std::move(pk_columns)), fk_columns(std::move(fk_columns)), pk_keys(std::move(pk_keys)), + fk_keys(std::move(fk_keys)), type(type_p) { +} +AlterForeignKeyInfo::~AlterForeignKeyInfo() { +} + +unique_ptr AlterForeignKeyInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), fk_table, pk_columns, fk_columns, + pk_keys, fk_keys, type); +} + +//===--------------------------------------------------------------------===// +// Alter View +//===--------------------------------------------------------------------===// +AlterViewInfo::AlterViewInfo(AlterViewType type) : AlterInfo(AlterType::ALTER_VIEW), alter_view_type(type) { +} + +AlterViewInfo::AlterViewInfo(AlterViewType type, AlterEntryData data) + : AlterInfo(AlterType::ALTER_VIEW, std::move(data.catalog), std::move(data.schema), std::move(data.name), + data.if_not_found), + alter_view_type(type) { +} +AlterViewInfo::~AlterViewInfo() { +} + +CatalogType AlterViewInfo::GetCatalogType() const { + return CatalogType::VIEW_ENTRY; +} + +//===--------------------------------------------------------------------===// +// RenameViewInfo +//===--------------------------------------------------------------------===// +RenameViewInfo::RenameViewInfo() : AlterViewInfo(AlterViewType::RENAME_VIEW) { +} +RenameViewInfo::RenameViewInfo(AlterEntryData data, string new_name_p) + : AlterViewInfo(AlterViewType::RENAME_VIEW, std::move(data)), new_view_name(std::move(new_name_p)) { +} +RenameViewInfo::~RenameViewInfo() { +} + +unique_ptr RenameViewInfo::Copy() const { + return make_uniq_base(GetAlterEntryData(), new_view_name); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/attach_info.cpp b/src/duckdb/src/parser/parsed_data/attach_info.cpp new file mode 100644 index 00000000..d326ef73 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/attach_info.cpp @@ -0,0 +1,13 @@ +#include "duckdb/parser/parsed_data/attach_info.hpp" + +namespace duckdb { + +unique_ptr AttachInfo::Copy() const { + auto result = make_uniq(); + result->name = name; + result->path = path; + result->options = options; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_aggregate_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_aggregate_function_info.cpp new file mode 100644 index 00000000..8e4f31e2 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_aggregate_function_info.cpp @@ -0,0 +1,27 @@ +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" + +namespace duckdb { + +CreateAggregateFunctionInfo::CreateAggregateFunctionInfo(AggregateFunction function) + : CreateFunctionInfo(CatalogType::AGGREGATE_FUNCTION_ENTRY), functions(function.name) { + name = function.name; + functions.AddFunction(std::move(function)); + internal = true; +} + +CreateAggregateFunctionInfo::CreateAggregateFunctionInfo(AggregateFunctionSet set) + : CreateFunctionInfo(CatalogType::AGGREGATE_FUNCTION_ENTRY), functions(std::move(set)) { + name = functions.name; + for (auto &func : functions.functions) { + func.name = functions.name; + } + internal = true; +} + +unique_ptr CreateAggregateFunctionInfo::Copy() const { + auto result = make_uniq(functions); + CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_collation_info.cpp b/src/duckdb/src/parser/parsed_data/create_collation_info.cpp new file mode 100644 index 00000000..9ae2e610 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_collation_info.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parser/parsed_data/create_collation_info.hpp" + +namespace duckdb { + +CreateCollationInfo::CreateCollationInfo(string name_p, ScalarFunction function_p, bool combinable_p, + bool not_required_for_equality_p) + : CreateInfo(CatalogType::COLLATION_ENTRY), function(std::move(function_p)), combinable(combinable_p), + not_required_for_equality(not_required_for_equality_p) { + this->name = std::move(name_p); + internal = true; +} + +unique_ptr CreateCollationInfo::Copy() const { + auto result = make_uniq(name, function, combinable, not_required_for_equality); + CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_copy_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_copy_function_info.cpp new file mode 100644 index 00000000..333f2752 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_copy_function_info.cpp @@ -0,0 +1,17 @@ +#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" + +namespace duckdb { + +CreateCopyFunctionInfo::CreateCopyFunctionInfo(CopyFunction function_p) + : CreateInfo(CatalogType::COPY_FUNCTION_ENTRY), function(std::move(function_p)) { + this->name = function.name; + internal = true; +} + +unique_ptr CreateCopyFunctionInfo::Copy() const { + auto result = make_uniq(function); + CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_index_info.cpp b/src/duckdb/src/parser/parsed_data/create_index_info.cpp new file mode 100644 index 00000000..e5c9c2a5 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_index_info.cpp @@ -0,0 +1,27 @@ +#include "duckdb/parser/parsed_data/create_index_info.hpp" + +namespace duckdb { + +unique_ptr CreateIndexInfo::Copy() const { + auto result = make_uniq(); + CopyProperties(*result); + + result->index_type = index_type; + result->index_name = index_name; + result->constraint_type = constraint_type; + result->table = table; + for (auto &expr : expressions) { + result->expressions.push_back(expr->Copy()); + } + for (auto &expr : parsed_expressions) { + result->parsed_expressions.push_back(expr->Copy()); + } + + result->scan_types = scan_types; + result->names = names; + result->column_ids = column_ids; + result->options = options; + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_info.cpp b/src/duckdb/src/parser/parsed_data/create_info.cpp new file mode 100644 index 00000000..cd5ad2dc --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_info.cpp @@ -0,0 +1,28 @@ +#include "duckdb/parser/parsed_data/create_info.hpp" + +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/parser/parsed_data/alter_info.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" + +namespace duckdb { + +void CreateInfo::CopyProperties(CreateInfo &other) const { + other.type = type; + other.catalog = catalog; + other.schema = schema; + other.on_conflict = on_conflict; + other.temporary = temporary; + other.internal = internal; + other.sql = sql; +} + +unique_ptr CreateInfo::GetAlterInfo() const { + throw NotImplementedException("GetAlterInfo not implemented for this type"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_macro_info.cpp b/src/duckdb/src/parser/parsed_data/create_macro_info.cpp new file mode 100644 index 00000000..db616b8f --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_macro_info.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +CreateMacroInfo::CreateMacroInfo(CatalogType type) : CreateFunctionInfo(type, INVALID_SCHEMA) { +} + +unique_ptr CreateMacroInfo::Copy() const { + auto result = make_uniq(type); + result->function = function->Copy(); + result->name = name; + CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_pragma_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_pragma_function_info.cpp new file mode 100644 index 00000000..6d2c8a15 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_pragma_function_info.cpp @@ -0,0 +1,23 @@ +#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" + +namespace duckdb { + +CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(PragmaFunction function) + : CreateFunctionInfo(CatalogType::PRAGMA_FUNCTION_ENTRY), functions(function.name) { + name = function.name; + functions.AddFunction(std::move(function)); + internal = true; +} +CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(string name, PragmaFunctionSet functions_p) + : CreateFunctionInfo(CatalogType::PRAGMA_FUNCTION_ENTRY), functions(std::move(functions_p)) { + this->name = std::move(name); + internal = true; +} + +unique_ptr CreatePragmaFunctionInfo::Copy() const { + auto result = make_uniq(functions.name, functions); + CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_scalar_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_scalar_function_info.cpp new file mode 100644 index 00000000..6d01bcfb --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_scalar_function_info.cpp @@ -0,0 +1,34 @@ +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" + +namespace duckdb { + +CreateScalarFunctionInfo::CreateScalarFunctionInfo(ScalarFunction function) + : CreateFunctionInfo(CatalogType::SCALAR_FUNCTION_ENTRY), functions(function.name) { + name = function.name; + functions.AddFunction(std::move(function)); + internal = true; +} +CreateScalarFunctionInfo::CreateScalarFunctionInfo(ScalarFunctionSet set) + : CreateFunctionInfo(CatalogType::SCALAR_FUNCTION_ENTRY), functions(std::move(set)) { + name = functions.name; + for (auto &func : functions.functions) { + func.name = functions.name; + } + internal = true; +} + +unique_ptr CreateScalarFunctionInfo::Copy() const { + ScalarFunctionSet set(name); + set.functions = functions.functions; + auto result = make_uniq(std::move(set)); + CopyProperties(*result); + return std::move(result); +} + +unique_ptr CreateScalarFunctionInfo::GetAlterInfo() const { + return make_uniq_base( + AlterEntryData(catalog, schema, name, OnEntryNotFound::RETURN_NULL), functions); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp b/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp new file mode 100644 index 00000000..d7caac02 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp @@ -0,0 +1,27 @@ +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +CreateSequenceInfo::CreateSequenceInfo() + : CreateInfo(CatalogType::SEQUENCE_ENTRY, INVALID_SCHEMA), name(string()), usage_count(0), increment(1), + min_value(1), max_value(NumericLimits::Maximum()), start_value(1), cycle(false) { +} + +unique_ptr CreateSequenceInfo::Copy() const { + auto result = make_uniq(); + CopyProperties(*result); + result->name = name; + result->schema = schema; + result->usage_count = usage_count; + result->increment = increment; + result->min_value = min_value; + result->max_value = max_value; + result->start_value = start_value; + result->cycle = cycle; + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_table_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_table_function_info.cpp new file mode 100644 index 00000000..c2d297b5 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_table_function_info.cpp @@ -0,0 +1,34 @@ +#include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "duckdb/parser/parsed_data/alter_table_function_info.hpp" + +namespace duckdb { + +CreateTableFunctionInfo::CreateTableFunctionInfo(TableFunction function) + : CreateFunctionInfo(CatalogType::TABLE_FUNCTION_ENTRY), functions(function.name) { + name = function.name; + functions.AddFunction(std::move(function)); + internal = true; +} +CreateTableFunctionInfo::CreateTableFunctionInfo(TableFunctionSet set) + : CreateFunctionInfo(CatalogType::TABLE_FUNCTION_ENTRY), functions(std::move(set)) { + name = functions.name; + for (auto &func : functions.functions) { + func.name = functions.name; + } + internal = true; +} + +unique_ptr CreateTableFunctionInfo::Copy() const { + TableFunctionSet set(name); + set.functions = functions.functions; + auto result = make_uniq(std::move(set)); + CopyProperties(*result); + return std::move(result); +} + +unique_ptr CreateTableFunctionInfo::GetAlterInfo() const { + return make_uniq_base( + AlterEntryData(catalog, schema, name, OnEntryNotFound::RETURN_NULL), functions); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_table_info.cpp b/src/duckdb/src/parser/parsed_data/create_table_info.cpp new file mode 100644 index 00000000..e97e277d --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_table_info.cpp @@ -0,0 +1,32 @@ +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +CreateTableInfo::CreateTableInfo() : CreateInfo(CatalogType::TABLE_ENTRY, INVALID_SCHEMA) { +} + +CreateTableInfo::CreateTableInfo(string catalog_p, string schema_p, string name_p) + : CreateInfo(CatalogType::TABLE_ENTRY, std::move(schema_p), std::move(catalog_p)), table(std::move(name_p)) { +} + +CreateTableInfo::CreateTableInfo(SchemaCatalogEntry &schema, string name_p) + : CreateTableInfo(schema.catalog.GetName(), schema.name, std::move(name_p)) { +} + +unique_ptr CreateTableInfo::Copy() const { + auto result = make_uniq(catalog, schema, table); + CopyProperties(*result); + result->columns = columns.Copy(); + for (auto &constraint : constraints) { + result->constraints.push_back(constraint->Copy()); + } + if (query) { + result->query = unique_ptr_cast(query->Copy()); + } + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_type_info.cpp b/src/duckdb/src/parser/parsed_data/create_type_info.cpp new file mode 100644 index 00000000..3d4bbb7a --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_type_info.cpp @@ -0,0 +1,25 @@ +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" + +namespace duckdb { + +CreateTypeInfo::CreateTypeInfo() : CreateInfo(CatalogType::TYPE_ENTRY) { +} +CreateTypeInfo::CreateTypeInfo(string name_p, LogicalType type_p) + : CreateInfo(CatalogType::TYPE_ENTRY), name(std::move(name_p)), type(std::move(type_p)) { +} + +unique_ptr CreateTypeInfo::Copy() const { + auto result = make_uniq(); + CopyProperties(*result); + result->name = name; + result->type = type; + if (query) { + result->query = query->Copy(); + } + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/create_view_info.cpp b/src/duckdb/src/parser/parsed_data/create_view_info.cpp new file mode 100644 index 00000000..fde4a2c8 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/create_view_info.cpp @@ -0,0 +1,79 @@ +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/statement/create_statement.hpp" + +namespace duckdb { + +CreateViewInfo::CreateViewInfo() : CreateInfo(CatalogType::VIEW_ENTRY, INVALID_SCHEMA) { +} +CreateViewInfo::CreateViewInfo(string catalog_p, string schema_p, string view_name_p) + : CreateInfo(CatalogType::VIEW_ENTRY, std::move(schema_p), std::move(catalog_p)), + view_name(std::move(view_name_p)) { +} + +CreateViewInfo::CreateViewInfo(SchemaCatalogEntry &schema, string view_name) + : CreateViewInfo(schema.catalog.GetName(), schema.name, std::move(view_name)) { +} + +unique_ptr CreateViewInfo::Copy() const { + auto result = make_uniq(catalog, schema, view_name); + CopyProperties(*result); + result->aliases = aliases; + result->types = types; + result->query = unique_ptr_cast(query->Copy()); + return std::move(result); +} + +unique_ptr CreateViewInfo::FromSelect(ClientContext &context, unique_ptr info) { + D_ASSERT(info); + D_ASSERT(!info->view_name.empty()); + D_ASSERT(!info->sql.empty()); + D_ASSERT(!info->query); + + Parser parser; + parser.ParseQuery(info->sql); + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw BinderException( + "Failed to create view from SQL string - \"%s\" - statement did not contain a single SELECT statement", + info->sql); + } + D_ASSERT(parser.statements.size() == 1 && parser.statements[0]->type == StatementType::SELECT_STATEMENT); + info->query = unique_ptr_cast(std::move(parser.statements[0])); + + auto binder = Binder::CreateBinder(context); + binder->BindCreateViewInfo(*info); + + return info; +} + +unique_ptr CreateViewInfo::FromCreateView(ClientContext &context, const string &sql) { + D_ASSERT(!sql.empty()); + + // parse the SQL statement + Parser parser; + parser.ParseQuery(sql); + + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::CREATE_STATEMENT) { + throw BinderException( + "Failed to create view from SQL string - \"%s\" - statement did not contain a single CREATE VIEW statement", + sql); + } + auto &create_statement = parser.statements[0]->Cast(); + if (create_statement.info->type != CatalogType::VIEW_ENTRY) { + throw BinderException( + "Failed to create view from SQL string - \"%s\" - view did not contain a CREATE VIEW statement", sql); + } + + auto result = unique_ptr_cast(std::move(create_statement.info)); + + auto binder = Binder::CreateBinder(context); + binder->BindCreateViewInfo(*result); + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/detach_info.cpp b/src/duckdb/src/parser/parsed_data/detach_info.cpp new file mode 100644 index 00000000..04a8a71a --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/detach_info.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/parsed_data/detach_info.hpp" + +namespace duckdb { + +DetachInfo::DetachInfo() : ParseInfo(TYPE) { +} + +unique_ptr DetachInfo::Copy() const { + auto result = make_uniq(); + result->name = name; + result->if_not_found = if_not_found; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/drop_info.cpp b/src/duckdb/src/parser/parsed_data/drop_info.cpp new file mode 100644 index 00000000..9c0b57c4 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/drop_info.cpp @@ -0,0 +1,20 @@ +#include "duckdb/parser/parsed_data/drop_info.hpp" + +namespace duckdb { + +DropInfo::DropInfo() : ParseInfo(TYPE), catalog(INVALID_CATALOG), schema(INVALID_SCHEMA), cascade(false) { +} + +unique_ptr DropInfo::Copy() const { + auto result = make_uniq(); + result->type = type; + result->catalog = catalog; + result->schema = schema; + result->name = name; + result->if_not_found = if_not_found; + result->cascade = cascade; + result->allow_drop_internal = allow_drop_internal; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/sample_options.cpp b/src/duckdb/src/parser/parsed_data/sample_options.cpp new file mode 100644 index 00000000..03c8b322 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/sample_options.cpp @@ -0,0 +1,35 @@ +#include "duckdb/parser/parsed_data/sample_options.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +// **DEPRECATED**: Use EnumUtil directly instead. +string SampleMethodToString(SampleMethod method) { + return EnumUtil::ToString(method); +} + +unique_ptr SampleOptions::Copy() { + auto result = make_uniq(); + result->sample_size = sample_size; + result->is_percentage = is_percentage; + result->method = method; + result->seed = seed; + return result; +} + +bool SampleOptions::Equals(SampleOptions *a, SampleOptions *b) { + if (a == b) { + return true; + } + if (!a || !b) { + return false; + } + if (a->sample_size != b->sample_size || a->is_percentage != b->is_percentage || a->method != b->method || + a->seed != b->seed) { + return false; + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/transaction_info.cpp b/src/duckdb/src/parser/parsed_data/transaction_info.cpp new file mode 100644 index 00000000..d0cc8112 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/transaction_info.cpp @@ -0,0 +1,11 @@ +#include "duckdb/parser/parsed_data/transaction_info.hpp" + +namespace duckdb { + +TransactionInfo::TransactionInfo() : ParseInfo(TYPE) { +} + +TransactionInfo::TransactionInfo(TransactionType type) : ParseInfo(TYPE), type(type) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/vacuum_info.cpp b/src/duckdb/src/parser/parsed_data/vacuum_info.cpp new file mode 100644 index 00000000..bbbd9fb5 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/vacuum_info.cpp @@ -0,0 +1,17 @@ +#include "duckdb/parser/parsed_data/vacuum_info.hpp" + +namespace duckdb { + +VacuumInfo::VacuumInfo(VacuumOptions options) : ParseInfo(TYPE), options(options), has_table(false) { +} + +unique_ptr VacuumInfo::Copy() { + auto result = make_uniq(options); + result->has_table = has_table; + if (has_table) { + result->ref = ref->Copy(); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_expression.cpp b/src/duckdb/src/parser/parsed_expression.cpp new file mode 100644 index 00000000..617d60a3 --- /dev/null +++ b/src/duckdb/src/parser/parsed_expression.cpp @@ -0,0 +1,118 @@ +#include "duckdb/main/client_context.hpp" + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/expression_util.hpp" + +namespace duckdb { + +bool ParsedExpression::IsAggregate() const { + bool is_aggregate = false; + ParsedExpressionIterator::EnumerateChildren( + *this, [&](const ParsedExpression &child) { is_aggregate |= child.IsAggregate(); }); + return is_aggregate; +} + +bool ParsedExpression::IsWindow() const { + bool is_window = false; + ParsedExpressionIterator::EnumerateChildren(*this, + [&](const ParsedExpression &child) { is_window |= child.IsWindow(); }); + return is_window; +} + +bool ParsedExpression::IsScalar() const { + bool is_scalar = true; + ParsedExpressionIterator::EnumerateChildren(*this, [&](const ParsedExpression &child) { + if (!child.IsScalar()) { + is_scalar = false; + } + }); + return is_scalar; +} + +bool ParsedExpression::HasParameter() const { + bool has_parameter = false; + ParsedExpressionIterator::EnumerateChildren( + *this, [&](const ParsedExpression &child) { has_parameter |= child.HasParameter(); }); + return has_parameter; +} + +bool ParsedExpression::HasSubquery() const { + bool has_subquery = false; + ParsedExpressionIterator::EnumerateChildren( + *this, [&](const ParsedExpression &child) { has_subquery |= child.HasSubquery(); }); + return has_subquery; +} + +bool ParsedExpression::Equals(const BaseExpression &other) const { + if (!BaseExpression::Equals(other)) { + return false; + } + switch (expression_class) { + case ExpressionClass::BETWEEN: + return BetweenExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::CASE: + return CaseExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::CAST: + return CastExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::COLLATE: + return CollateExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::COLUMN_REF: + return ColumnRefExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::COMPARISON: + return ComparisonExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::CONJUNCTION: + return ConjunctionExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::CONSTANT: + return ConstantExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::DEFAULT: + return true; + case ExpressionClass::FUNCTION: + return FunctionExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::LAMBDA: + return LambdaExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::OPERATOR: + return OperatorExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::PARAMETER: + return ParameterExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::POSITIONAL_REFERENCE: + return PositionalReferenceExpression::Equal(Cast(), + other.Cast()); + case ExpressionClass::STAR: + return StarExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::SUBQUERY: + return SubqueryExpression::Equal(Cast(), other.Cast()); + case ExpressionClass::WINDOW: + return WindowExpression::Equal(Cast(), other.Cast()); + default: + throw SerializationException("Unsupported type for expression comparison!"); + } +} + +hash_t ParsedExpression::Hash() const { + hash_t hash = duckdb::Hash((uint32_t)type); + ParsedExpressionIterator::EnumerateChildren( + *this, [&](const ParsedExpression &child) { hash = CombineHash(child.Hash(), hash); }); + return hash; +} + +bool ParsedExpression::Equals(const unique_ptr &left, const unique_ptr &right) { + if (left.get() == right.get()) { + return true; + } + if (!left || !right) { + return false; + } + return left->Equals(*right); +} + +bool ParsedExpression::ListEquals(const vector> &left, + const vector> &right) { + return ExpressionUtil::ListEquals(left, right); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_expression_iterator.cpp b/src/duckdb/src/parser/parsed_expression_iterator.cpp new file mode 100644 index 00000000..be997f03 --- /dev/null +++ b/src/duckdb/src/parser/parsed_expression_iterator.cpp @@ -0,0 +1,306 @@ +#include "duckdb/parser/parsed_expression_iterator.hpp" + +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/query_node/recursive_cte_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/tableref/list.hpp" + +namespace duckdb { + +void ParsedExpressionIterator::EnumerateChildren(const ParsedExpression &expression, + const std::function &callback) { + EnumerateChildren((ParsedExpression &)expression, [&](unique_ptr &child) { + D_ASSERT(child); + callback(*child); + }); +} + +void ParsedExpressionIterator::EnumerateChildren(ParsedExpression &expr, + const std::function &callback) { + EnumerateChildren(expr, [&](unique_ptr &child) { + D_ASSERT(child); + callback(*child); + }); +} + +void ParsedExpressionIterator::EnumerateChildren( + ParsedExpression &expr, const std::function &child)> &callback) { + switch (expr.expression_class) { + case ExpressionClass::BETWEEN: { + auto &cast_expr = expr.Cast(); + callback(cast_expr.input); + callback(cast_expr.lower); + callback(cast_expr.upper); + break; + } + case ExpressionClass::CASE: { + auto &case_expr = expr.Cast(); + for (auto &check : case_expr.case_checks) { + callback(check.when_expr); + callback(check.then_expr); + } + callback(case_expr.else_expr); + break; + } + case ExpressionClass::CAST: { + auto &cast_expr = expr.Cast(); + callback(cast_expr.child); + break; + } + case ExpressionClass::COLLATE: { + auto &cast_expr = expr.Cast(); + callback(cast_expr.child); + break; + } + case ExpressionClass::COMPARISON: { + auto &comp_expr = expr.Cast(); + callback(comp_expr.left); + callback(comp_expr.right); + break; + } + case ExpressionClass::CONJUNCTION: { + auto &conj_expr = expr.Cast(); + for (auto &child : conj_expr.children) { + callback(child); + } + break; + } + + case ExpressionClass::FUNCTION: { + auto &func_expr = expr.Cast(); + for (auto &child : func_expr.children) { + callback(child); + } + if (func_expr.filter) { + callback(func_expr.filter); + } + if (func_expr.order_bys) { + for (auto &order : func_expr.order_bys->orders) { + callback(order.expression); + } + } + break; + } + case ExpressionClass::LAMBDA: { + auto &lambda_expr = expr.Cast(); + callback(lambda_expr.lhs); + callback(lambda_expr.expr); + break; + } + case ExpressionClass::OPERATOR: { + auto &op_expr = expr.Cast(); + for (auto &child : op_expr.children) { + callback(child); + } + break; + } + case ExpressionClass::STAR: { + auto &star_expr = expr.Cast(); + if (star_expr.expr) { + callback(star_expr.expr); + } + break; + } + case ExpressionClass::SUBQUERY: { + auto &subquery_expr = expr.Cast(); + if (subquery_expr.child) { + callback(subquery_expr.child); + } + break; + } + case ExpressionClass::WINDOW: { + auto &window_expr = expr.Cast(); + for (auto &partition : window_expr.partitions) { + callback(partition); + } + for (auto &order : window_expr.orders) { + callback(order.expression); + } + for (auto &child : window_expr.children) { + callback(child); + } + if (window_expr.filter_expr) { + callback(window_expr.filter_expr); + } + if (window_expr.start_expr) { + callback(window_expr.start_expr); + } + if (window_expr.end_expr) { + callback(window_expr.end_expr); + } + if (window_expr.offset_expr) { + callback(window_expr.offset_expr); + } + if (window_expr.default_expr) { + callback(window_expr.default_expr); + } + break; + } + case ExpressionClass::BOUND_EXPRESSION: + case ExpressionClass::COLUMN_REF: + case ExpressionClass::CONSTANT: + case ExpressionClass::DEFAULT: + case ExpressionClass::PARAMETER: + case ExpressionClass::POSITIONAL_REFERENCE: + // these node types have no children + break; + default: + // called on non ParsedExpression type! + throw NotImplementedException("Unimplemented expression class"); + } +} + +void ParsedExpressionIterator::EnumerateQueryNodeModifiers( + QueryNode &node, const std::function &child)> &callback) { + + for (auto &modifier : node.modifiers) { + switch (modifier->type) { + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit_modifier = modifier->Cast(); + if (limit_modifier.limit) { + callback(limit_modifier.limit); + } + if (limit_modifier.offset) { + callback(limit_modifier.offset); + } + } break; + + case ResultModifierType::LIMIT_PERCENT_MODIFIER: { + auto &limit_modifier = modifier->Cast(); + if (limit_modifier.limit) { + callback(limit_modifier.limit); + } + if (limit_modifier.offset) { + callback(limit_modifier.offset); + } + } break; + + case ResultModifierType::ORDER_MODIFIER: { + auto &order_modifier = modifier->Cast(); + for (auto &order : order_modifier.orders) { + callback(order.expression); + } + } break; + + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct_modifier = modifier->Cast(); + for (auto &target : distinct_modifier.distinct_on_targets) { + callback(target); + } + } break; + + // do nothing + default: + break; + } + } +} + +void ParsedExpressionIterator::EnumerateTableRefChildren( + TableRef &ref, const std::function &child)> &callback) { + switch (ref.type) { + case TableReferenceType::EXPRESSION_LIST: { + auto &el_ref = ref.Cast(); + for (idx_t i = 0; i < el_ref.values.size(); i++) { + for (idx_t j = 0; j < el_ref.values[i].size(); j++) { + callback(el_ref.values[i][j]); + } + } + break; + } + case TableReferenceType::JOIN: { + auto &j_ref = ref.Cast(); + EnumerateTableRefChildren(*j_ref.left, callback); + EnumerateTableRefChildren(*j_ref.right, callback); + if (j_ref.condition) { + callback(j_ref.condition); + } + break; + } + case TableReferenceType::PIVOT: { + auto &p_ref = ref.Cast(); + EnumerateTableRefChildren(*p_ref.source, callback); + for (auto &aggr : p_ref.aggregates) { + callback(aggr); + } + break; + } + case TableReferenceType::SUBQUERY: { + auto &sq_ref = ref.Cast(); + EnumerateQueryNodeChildren(*sq_ref.subquery->node, callback); + break; + } + case TableReferenceType::TABLE_FUNCTION: { + auto &tf_ref = ref.Cast(); + callback(tf_ref.function); + break; + } + case TableReferenceType::BASE_TABLE: + case TableReferenceType::EMPTY: + // these TableRefs do not need to be unfolded + break; + case TableReferenceType::INVALID: + case TableReferenceType::CTE: + throw NotImplementedException("TableRef type not implemented for traversal"); + } +} + +void ParsedExpressionIterator::EnumerateQueryNodeChildren( + QueryNode &node, const std::function &child)> &callback) { + switch (node.type) { + case QueryNodeType::RECURSIVE_CTE_NODE: { + auto &rcte_node = node.Cast(); + EnumerateQueryNodeChildren(*rcte_node.left, callback); + EnumerateQueryNodeChildren(*rcte_node.right, callback); + break; + } + case QueryNodeType::CTE_NODE: { + auto &cte_node = node.Cast(); + EnumerateQueryNodeChildren(*cte_node.query, callback); + EnumerateQueryNodeChildren(*cte_node.child, callback); + break; + } + case QueryNodeType::SELECT_NODE: { + auto &sel_node = node.Cast(); + for (idx_t i = 0; i < sel_node.select_list.size(); i++) { + callback(sel_node.select_list[i]); + } + for (idx_t i = 0; i < sel_node.groups.group_expressions.size(); i++) { + callback(sel_node.groups.group_expressions[i]); + } + if (sel_node.where_clause) { + callback(sel_node.where_clause); + } + if (sel_node.having) { + callback(sel_node.having); + } + if (sel_node.qualify) { + callback(sel_node.qualify); + } + + EnumerateTableRefChildren(*sel_node.from_table.get(), callback); + break; + } + case QueryNodeType::SET_OPERATION_NODE: { + auto &setop_node = node.Cast(); + EnumerateQueryNodeChildren(*setop_node.left, callback); + EnumerateQueryNodeChildren(*setop_node.right, callback); + break; + } + default: + throw NotImplementedException("QueryNode type not implemented for traversal"); + } + + if (!node.modifiers.empty()) { + EnumerateQueryNodeModifiers(node, callback); + } + + for (auto &kv : node.cte_map.map) { + EnumerateQueryNodeChildren(*kv.second->query->node, callback); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parser.cpp b/src/duckdb/src/parser/parser.cpp new file mode 100644 index 00000000..8b5af5c7 --- /dev/null +++ b/src/duckdb/src/parser/parser.cpp @@ -0,0 +1,434 @@ +#include "duckdb/parser/parser.hpp" + +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/parser/parser_extension.hpp" +#include "duckdb/parser/query_error_context.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/statement/extension_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/group_by_node.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/parser/transformer.hpp" +#include "parser/parser.hpp" +#include "postgres_parser.hpp" + +namespace duckdb { + +Parser::Parser(ParserOptions options_p) : options(options_p) { +} + +struct UnicodeSpace { + UnicodeSpace(idx_t pos, idx_t bytes) : pos(pos), bytes(bytes) { + } + + idx_t pos; + idx_t bytes; +}; + +static bool ReplaceUnicodeSpaces(const string &query, string &new_query, vector &unicode_spaces) { + if (unicode_spaces.empty()) { + // no unicode spaces found + return false; + } + idx_t prev = 0; + for (auto &usp : unicode_spaces) { + new_query += query.substr(prev, usp.pos - prev); + new_query += " "; + prev = usp.pos + usp.bytes; + } + new_query += query.substr(prev, query.size() - prev); + return true; +} + +// This function strips unicode space characters from the query and replaces them with regular spaces +// It returns true if any unicode space characters were found and stripped +// See here for a list of unicode space characters - https://jkorpela.fi/chars/spaces.html +bool Parser::StripUnicodeSpaces(const string &query_str, string &new_query) { + const idx_t NBSP_LEN = 2; + const idx_t USP_LEN = 3; + idx_t pos = 0; + unsigned char quote; + vector unicode_spaces; + auto query = const_uchar_ptr_cast(query_str.c_str()); + auto qsize = query_str.size(); + +regular: + for (; pos + 2 < qsize; pos++) { + if (query[pos] == 0xC2) { + if (query[pos + 1] == 0xA0) { + // U+00A0 - C2A0 + unicode_spaces.emplace_back(pos, NBSP_LEN); + } + } + if (query[pos] == 0xE2) { + if (query[pos + 1] == 0x80) { + if (query[pos + 2] >= 0x80 && query[pos + 2] <= 0x8B) { + // U+2000 to U+200B + // E28080 - E2808B + unicode_spaces.emplace_back(pos, USP_LEN); + } else if (query[pos + 2] == 0xAF) { + // U+202F - E280AF + unicode_spaces.emplace_back(pos, USP_LEN); + } + } else if (query[pos + 1] == 0x81) { + if (query[pos + 2] == 0x9F) { + // U+205F - E2819f + unicode_spaces.emplace_back(pos, USP_LEN); + } else if (query[pos + 2] == 0xA0) { + // U+2060 - E281A0 + unicode_spaces.emplace_back(pos, USP_LEN); + } + } + } else if (query[pos] == 0xE3) { + if (query[pos + 1] == 0x80 && query[pos + 2] == 0x80) { + // U+3000 - E38080 + unicode_spaces.emplace_back(pos, USP_LEN); + } + } else if (query[pos] == 0xEF) { + if (query[pos + 1] == 0xBB && query[pos + 2] == 0xBF) { + // U+FEFF - EFBBBF + unicode_spaces.emplace_back(pos, USP_LEN); + } + } else if (query[pos] == '"' || query[pos] == '\'') { + quote = query[pos]; + pos++; + goto in_quotes; + } else if (query[pos] == '-' && query[pos + 1] == '-') { + goto in_comment; + } + } + goto end; +in_quotes: + for (; pos + 1 < qsize; pos++) { + if (query[pos] == quote) { + if (query[pos + 1] == quote) { + // escaped quote + pos++; + continue; + } + pos++; + goto regular; + } + } + goto end; +in_comment: + for (; pos < qsize; pos++) { + if (query[pos] == '\n' || query[pos] == '\r') { + goto regular; + } + } + goto end; +end: + return ReplaceUnicodeSpaces(query_str, new_query, unicode_spaces); +} + +vector SplitQueryStringIntoStatements(const string &query) { + // Break sql string down into sql statements using the tokenizer + vector query_statements; + auto tokens = Parser::Tokenize(query); + auto next_statement_start = 0; + for (idx_t i = 1; i < tokens.size(); ++i) { + auto &t_prev = tokens[i - 1]; + auto &t = tokens[i]; + if (t_prev.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR) { + // LCOV_EXCL_START + for (idx_t c = t_prev.start; c <= t.start; ++c) { + if (query.c_str()[c] == ';') { + query_statements.emplace_back(query.substr(next_statement_start, t.start - next_statement_start)); + next_statement_start = tokens[i].start; + } + } + // LCOV_EXCL_STOP + } + } + query_statements.emplace_back(query.substr(next_statement_start, query.size() - next_statement_start)); + return query_statements; +} + +void Parser::ParseQuery(const string &query) { + Transformer transformer(options); + string parser_error; + { + // check if there are any unicode spaces in the string + string new_query; + if (StripUnicodeSpaces(query, new_query)) { + // there are - strip the unicode spaces and re-run the query + ParseQuery(new_query); + return; + } + } + { + PostgresParser::SetPreserveIdentifierCase(options.preserve_identifier_case); + bool parsing_succeed = false; + // Creating a new scope to prevent multiple PostgresParser destructors being called + // which led to some memory issues + { + PostgresParser parser; + parser.Parse(query); + if (parser.success) { + if (!parser.parse_tree) { + // empty statement + return; + } + + // if it succeeded, we transform the Postgres parse tree into a list of + // SQLStatements + transformer.TransformParseTree(parser.parse_tree, statements); + parsing_succeed = true; + } else { + parser_error = QueryErrorContext::Format(query, parser.error_message, parser.error_location - 1); + } + } + // If DuckDB fails to parse the entire sql string, break the string down into individual statements + // using ';' as the delimiter so that parser extensions can parse the statement + if (parsing_succeed) { + // no-op + // return here would require refactoring into another function. o.w. will just no-op in order to run wrap up + // code at the end of this function + } else if (!options.extensions || options.extensions->empty()) { + throw ParserException(parser_error); + } else { + // split sql string into statements and re-parse using extension + auto query_statements = SplitQueryStringIntoStatements(query); + auto stmt_loc = 0; + for (auto const &query_statement : query_statements) { + string another_parser_error; + // Creating a new scope to allow extensions to use PostgresParser, which is not reentrant + { + PostgresParser another_parser; + another_parser.Parse(query_statement); + // LCOV_EXCL_START + // first see if DuckDB can parse this individual query statement + if (another_parser.success) { + if (!another_parser.parse_tree) { + // empty statement + continue; + } + transformer.TransformParseTree(another_parser.parse_tree, statements); + // important to set in the case of a mixture of DDB and parser ext statements + statements.back()->stmt_length = query_statement.size() - 1; + statements.back()->stmt_location = stmt_loc; + stmt_loc += query_statement.size(); + continue; + } else { + another_parser_error = QueryErrorContext::Format(query, another_parser.error_message, + another_parser.error_location - 1); + } + } // LCOV_EXCL_STOP + // LCOV_EXCL_START + // let extensions parse the statement which DuckDB failed to parse + bool parsed_single_statement = false; + for (auto &ext : *options.extensions) { + D_ASSERT(!parsed_single_statement); + D_ASSERT(ext.parse_function); + auto result = ext.parse_function(ext.parser_info.get(), query_statement); + if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { + auto statement = make_uniq(ext, std::move(result.parse_data)); + statement->stmt_length = query_statement.size() - 1; + statement->stmt_location = stmt_loc; + stmt_loc += query_statement.size(); + statements.push_back(std::move(statement)); + parsed_single_statement = true; + break; + } else if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { + throw ParserException(result.error); + } else { + // We move to the next one! + } + } + if (!parsed_single_statement) { + throw ParserException(parser_error); + } // LCOV_EXCL_STOP + } + } + } + if (!statements.empty()) { + auto &last_statement = statements.back(); + last_statement->stmt_length = query.size() - last_statement->stmt_location; + for (auto &statement : statements) { + statement->query = query; + if (statement->type == StatementType::CREATE_STATEMENT) { + auto &create = statement->Cast(); + create.info->sql = query.substr(statement->stmt_location, statement->stmt_length); + } + } + } +} + +vector Parser::Tokenize(const string &query) { + auto pg_tokens = PostgresParser::Tokenize(query); + vector result; + result.reserve(pg_tokens.size()); + for (auto &pg_token : pg_tokens) { + SimplifiedToken token; + switch (pg_token.type) { + case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_IDENTIFIER: + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER; + break; + case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_NUMERIC_CONSTANT: + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_NUMERIC_CONSTANT; + break; + case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_STRING_CONSTANT: + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT; + break; + case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_OPERATOR: + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR; + break; + case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_KEYWORD: + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD; + break; + // comments are not supported by our tokenizer right now + case duckdb_libpgquery::PGSimplifiedTokenType::PG_SIMPLIFIED_TOKEN_COMMENT: // LCOV_EXCL_START + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT; + break; + default: + throw InternalException("Unrecognized token category"); + } // LCOV_EXCL_STOP + token.start = pg_token.start; + result.push_back(token); + } + return result; +} + +bool Parser::IsKeyword(const string &text) { + return PostgresParser::IsKeyword(text); +} + +vector Parser::KeywordList() { + auto keywords = PostgresParser::KeywordList(); + vector result; + for (auto &kw : keywords) { + ParserKeyword res; + res.name = kw.text; + switch (kw.category) { + case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_RESERVED: + res.category = KeywordCategory::KEYWORD_RESERVED; + break; + case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_UNRESERVED: + res.category = KeywordCategory::KEYWORD_UNRESERVED; + break; + case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_TYPE_FUNC: + res.category = KeywordCategory::KEYWORD_TYPE_FUNC; + break; + case duckdb_libpgquery::PGKeywordCategory::PG_KEYWORD_COL_NAME: + res.category = KeywordCategory::KEYWORD_COL_NAME; + break; + default: + throw InternalException("Unrecognized keyword category"); + } + result.push_back(res); + } + return result; +} + +vector> Parser::ParseExpressionList(const string &select_list, ParserOptions options) { + // construct a mock query prefixed with SELECT + string mock_query = "SELECT " + select_list; + // parse the query + Parser parser(options); + parser.ParseQuery(mock_query); + // check the statements + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw ParserException("Expected a single SELECT statement"); + } + auto &select = parser.statements[0]->Cast(); + if (select.node->type != QueryNodeType::SELECT_NODE) { + throw ParserException("Expected a single SELECT node"); + } + auto &select_node = select.node->Cast(); + return std::move(select_node.select_list); +} + +GroupByNode Parser::ParseGroupByList(const string &group_by, ParserOptions options) { + // construct a mock SELECT query with our group_by expressions + string mock_query = StringUtil::Format("SELECT 42 GROUP BY %s", group_by); + // parse the query + Parser parser(options); + parser.ParseQuery(mock_query); + // check the result + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw ParserException("Expected a single SELECT statement"); + } + auto &select = parser.statements[0]->Cast(); + D_ASSERT(select.node->type == QueryNodeType::SELECT_NODE); + auto &select_node = select.node->Cast(); + return std::move(select_node.groups); +} + +vector Parser::ParseOrderList(const string &select_list, ParserOptions options) { + // construct a mock query + string mock_query = "SELECT * FROM tbl ORDER BY " + select_list; + // parse the query + Parser parser(options); + parser.ParseQuery(mock_query); + // check the statements + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw ParserException("Expected a single SELECT statement"); + } + auto &select = parser.statements[0]->Cast(); + D_ASSERT(select.node->type == QueryNodeType::SELECT_NODE); + auto &select_node = select.node->Cast(); + if (select_node.modifiers.empty() || select_node.modifiers[0]->type != ResultModifierType::ORDER_MODIFIER || + select_node.modifiers.size() != 1) { + throw ParserException("Expected a single ORDER clause"); + } + auto &order = select_node.modifiers[0]->Cast(); + return std::move(order.orders); +} + +void Parser::ParseUpdateList(const string &update_list, vector &update_columns, + vector> &expressions, ParserOptions options) { + // construct a mock query + string mock_query = "UPDATE tbl SET " + update_list; + // parse the query + Parser parser(options); + parser.ParseQuery(mock_query); + // check the statements + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::UPDATE_STATEMENT) { + throw ParserException("Expected a single UPDATE statement"); + } + auto &update = parser.statements[0]->Cast(); + update_columns = std::move(update.set_info->columns); + expressions = std::move(update.set_info->expressions); +} + +vector>> Parser::ParseValuesList(const string &value_list, ParserOptions options) { + // construct a mock query + string mock_query = "VALUES " + value_list; + // parse the query + Parser parser(options); + parser.ParseQuery(mock_query); + // check the statements + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw ParserException("Expected a single SELECT statement"); + } + auto &select = parser.statements[0]->Cast(); + if (select.node->type != QueryNodeType::SELECT_NODE) { + throw ParserException("Expected a single SELECT node"); + } + auto &select_node = select.node->Cast(); + if (!select_node.from_table || select_node.from_table->type != TableReferenceType::EXPRESSION_LIST) { + throw ParserException("Expected a single VALUES statement"); + } + auto &values_list = select_node.from_table->Cast(); + return std::move(values_list.values); +} + +ColumnList Parser::ParseColumnList(const string &column_list, ParserOptions options) { + string mock_query = "CREATE TABLE blabla (" + column_list + ")"; + Parser parser(options); + parser.ParseQuery(mock_query); + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::CREATE_STATEMENT) { + throw ParserException("Expected a single CREATE statement"); + } + auto &create = parser.statements[0]->Cast(); + if (create.info->type != CatalogType::TABLE_ENTRY) { + throw InternalException("Expected a single CREATE TABLE statement"); + } + auto &info = create.info->Cast(); + return std::move(info.columns); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/query_error_context.cpp b/src/duckdb/src/parser/query_error_context.cpp new file mode 100644 index 00000000..58f3418b --- /dev/null +++ b/src/duckdb/src/parser/query_error_context.cpp @@ -0,0 +1,121 @@ +#include "duckdb/parser/query_error_context.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" + +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +string QueryErrorContext::Format(const string &query, const string &error_message, int error_loc) { + if (error_loc < 0 || size_t(error_loc) >= query.size()) { + // no location in query provided + return error_message; + } + idx_t error_location = idx_t(error_loc); + // count the line numbers until the error location + // and set the start position as the first character of that line + idx_t start_pos = 0; + idx_t line_number = 1; + for (idx_t i = 0; i < error_location; i++) { + if (StringUtil::CharacterIsNewline(query[i])) { + line_number++; + start_pos = i + 1; + } + } + // now find either the next newline token after the query, or find the end of string + // this is the initial end position + idx_t end_pos = query.size(); + for (idx_t i = error_location; i < query.size(); i++) { + if (StringUtil::CharacterIsNewline(query[i])) { + end_pos = i; + break; + } + } + // now start scanning from the start pos + // we want to figure out the start and end pos of what we are going to render + // we want to render at most 80 characters in total, with the error_location located in the middle + const char *buf = query.c_str() + start_pos; + idx_t len = end_pos - start_pos; + vector render_widths; + vector positions; + if (Utf8Proc::IsValid(buf, len)) { + // for unicode awareness, we traverse the graphemes of the current line and keep track of their render widths + // and of their position in the string + for (idx_t cpos = 0; cpos < len;) { + auto char_render_width = Utf8Proc::RenderWidth(buf, len, cpos); + positions.push_back(cpos); + render_widths.push_back(char_render_width); + cpos = Utf8Proc::NextGraphemeCluster(buf, len, cpos); + } + } else { // LCOV_EXCL_START + // invalid utf-8, we can't do much at this point + // we just assume every character is a character, and every character has a render width of 1 + for (idx_t cpos = 0; cpos < len; cpos++) { + positions.push_back(cpos); + render_widths.push_back(1); + } + } // LCOV_EXCL_STOP + // now we want to find the (unicode aware) start and end position + idx_t epos = 0; + // start by finding the error location inside the array + for (idx_t i = 0; i < positions.size(); i++) { + if (positions[i] >= (error_location - start_pos)) { + epos = i; + break; + } + } + bool truncate_beginning = false; + bool truncate_end = false; + idx_t spos = 0; + // now we iterate backwards from the error location + // we show max 40 render width before the error location + idx_t current_render_width = 0; + for (idx_t i = epos; i > 0; i--) { + current_render_width += render_widths[i]; + if (current_render_width >= 40) { + truncate_beginning = true; + start_pos = positions[i]; + spos = i; + break; + } + } + // now do the same, but going forward + current_render_width = 0; + for (idx_t i = epos; i < positions.size(); i++) { + current_render_width += render_widths[i]; + if (current_render_width >= 40) { + truncate_end = true; + end_pos = positions[i]; + break; + } + } + string line_indicator = "LINE " + to_string(line_number) + ": "; + string begin_trunc = truncate_beginning ? "..." : ""; + string end_trunc = truncate_end ? "..." : ""; + + // get the render width of the error indicator (i.e. how many spaces we need to insert before the ^) + idx_t error_render_width = 0; + for (idx_t i = spos; i < epos; i++) { + error_render_width += render_widths[i]; + } + error_render_width += line_indicator.size() + begin_trunc.size(); + + // now first print the error message plus the current line (or a subset of the line) + string result = error_message; + result += "\n" + line_indicator + begin_trunc + query.substr(start_pos, end_pos - start_pos) + end_trunc; + // print an arrow pointing at the error location + result += "\n" + string(error_render_width, ' ') + "^"; + return result; +} + +string QueryErrorContext::FormatErrorRecursive(const string &msg, vector &values) { + string error_message = values.empty() ? msg : ExceptionFormatValue::Format(msg, values); + if (!statement || query_location >= statement->query.size()) { + // no statement provided or query location out of range + return error_message; + } + return Format(statement->query, error_message, query_location); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/query_node.cpp b/src/duckdb/src/parser/query_node.cpp new file mode 100644 index 00000000..dab1375f --- /dev/null +++ b/src/duckdb/src/parser/query_node.cpp @@ -0,0 +1,184 @@ +#include "duckdb/parser/query_node.hpp" + +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/recursive_cte_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +CommonTableExpressionMap::CommonTableExpressionMap() { +} + +CommonTableExpressionMap CommonTableExpressionMap::Copy() const { + CommonTableExpressionMap res; + for (auto &kv : this->map) { + auto kv_info = make_uniq(); + for (auto &al : kv.second->aliases) { + kv_info->aliases.push_back(al); + } + kv_info->query = unique_ptr_cast(kv.second->query->Copy()); + kv_info->materialized = kv.second->materialized; + res.map[kv.first] = std::move(kv_info); + } + return res; +} + +string CommonTableExpressionMap::ToString() const { + if (map.empty()) { + return string(); + } + // check if there are any recursive CTEs + bool has_recursive = false; + for (auto &kv : map) { + if (kv.second->query->node->type == QueryNodeType::RECURSIVE_CTE_NODE) { + has_recursive = true; + break; + } + } + string result = "WITH "; + if (has_recursive) { + result += "RECURSIVE "; + } + bool first_cte = true; + for (auto &kv : map) { + if (!first_cte) { + result += ", "; + } + auto &cte = *kv.second; + result += KeywordHelper::WriteOptionallyQuoted(kv.first); + if (!cte.aliases.empty()) { + result += " ("; + for (idx_t k = 0; k < cte.aliases.size(); k++) { + if (k > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(cte.aliases[k]); + } + result += ")"; + } + if (kv.second->materialized == CTEMaterialize::CTE_MATERIALIZE_ALWAYS) { + result += " AS MATERIALIZED ("; + } else if (kv.second->materialized == CTEMaterialize::CTE_MATERIALIZE_NEVER) { + result += " AS NOT MATERIALIZED ("; + } else { + result += " AS ("; + } + result += cte.query->ToString(); + result += ")"; + first_cte = false; + } + return result; +} + +string QueryNode::ResultModifiersToString() const { + string result; + for (idx_t modifier_idx = 0; modifier_idx < modifiers.size(); modifier_idx++) { + auto &modifier = *modifiers[modifier_idx]; + if (modifier.type == ResultModifierType::ORDER_MODIFIER) { + auto &order_modifier = modifier.Cast(); + result += " ORDER BY "; + for (idx_t k = 0; k < order_modifier.orders.size(); k++) { + if (k > 0) { + result += ", "; + } + result += order_modifier.orders[k].ToString(); + } + } else if (modifier.type == ResultModifierType::LIMIT_MODIFIER) { + auto &limit_modifier = modifier.Cast(); + if (limit_modifier.limit) { + result += " LIMIT " + limit_modifier.limit->ToString(); + } + if (limit_modifier.offset) { + result += " OFFSET " + limit_modifier.offset->ToString(); + } + } else if (modifier.type == ResultModifierType::LIMIT_PERCENT_MODIFIER) { + auto &limit_p_modifier = modifier.Cast(); + if (limit_p_modifier.limit) { + result += " LIMIT (" + limit_p_modifier.limit->ToString() + ") %"; + } + if (limit_p_modifier.offset) { + result += " OFFSET " + limit_p_modifier.offset->ToString(); + } + } + } + return result; +} + +bool QueryNode::Equals(const QueryNode *other) const { + if (!other) { + return false; + } + if (this == other) { + return true; + } + if (other->type != this->type) { + return false; + } + + if (modifiers.size() != other->modifiers.size()) { + return false; + } + for (idx_t i = 0; i < modifiers.size(); i++) { + if (!modifiers[i]->Equals(*other->modifiers[i])) { + return false; + } + } + // WITH clauses (CTEs) + if (cte_map.map.size() != other->cte_map.map.size()) { + return false; + } + for (auto &entry : cte_map.map) { + auto other_entry = other->cte_map.map.find(entry.first); + if (other_entry == other->cte_map.map.end()) { + return false; + } + if (entry.second->aliases != other_entry->second->aliases) { + return false; + } + if (!entry.second->query->Equals(*other_entry->second->query)) { + return false; + } + } + return other->type == type; +} + +void QueryNode::CopyProperties(QueryNode &other) const { + for (auto &modifier : modifiers) { + other.modifiers.push_back(modifier->Copy()); + } + for (auto &kv : cte_map.map) { + auto kv_info = make_uniq(); + for (auto &al : kv.second->aliases) { + kv_info->aliases.push_back(al); + } + kv_info->query = unique_ptr_cast(kv.second->query->Copy()); + kv_info->materialized = kv.second->materialized; + other.cte_map.map[kv.first] = std::move(kv_info); + } +} + +void QueryNode::AddDistinct() { + // check if we already have a DISTINCT modifier + for (idx_t modifier_idx = modifiers.size(); modifier_idx > 0; modifier_idx--) { + auto &modifier = *modifiers[modifier_idx - 1]; + if (modifier.type == ResultModifierType::DISTINCT_MODIFIER) { + auto &distinct_modifier = modifier.Cast(); + if (distinct_modifier.distinct_on_targets.empty()) { + // we have a DISTINCT without an ON clause - this distinct does not need to be added + return; + } + } else if (modifier.type == ResultModifierType::LIMIT_MODIFIER || + modifier.type == ResultModifierType::LIMIT_PERCENT_MODIFIER) { + // we encountered a LIMIT or LIMIT PERCENT - these change the result of DISTINCT, so we do need to push a + // DISTINCT relation + break; + } + } + modifiers.push_back(make_uniq()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/cte_node.cpp b/src/duckdb/src/parser/query_node/cte_node.cpp new file mode 100644 index 00000000..4504c985 --- /dev/null +++ b/src/duckdb/src/parser/query_node/cte_node.cpp @@ -0,0 +1,41 @@ +#include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +string CTENode::ToString() const { + string result; + result += child->ToString(); + return result; +} + +bool CTENode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + + if (!query->Equals(other.query.get())) { + return false; + } + if (!child->Equals(other.child.get())) { + return false; + } + return true; +} + +unique_ptr CTENode::Copy() const { + auto result = make_uniq(); + result->ctename = ctename; + result->query = query->Copy(); + result->child = child->Copy(); + result->aliases = aliases; + this->CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/recursive_cte_node.cpp b/src/duckdb/src/parser/query_node/recursive_cte_node.cpp new file mode 100644 index 00000000..29cee0e9 --- /dev/null +++ b/src/duckdb/src/parser/query_node/recursive_cte_node.cpp @@ -0,0 +1,50 @@ +#include "duckdb/parser/query_node/recursive_cte_node.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +string RecursiveCTENode::ToString() const { + string result; + result += "(" + left->ToString() + ")"; + result += " UNION "; + if (union_all) { + result += " ALL "; + } + result += "(" + right->ToString() + ")"; + return result; +} + +bool RecursiveCTENode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + + if (other.union_all != union_all) { + return false; + } + if (!left->Equals(other.left.get())) { + return false; + } + if (!right->Equals(other.right.get())) { + return false; + } + return true; +} + +unique_ptr RecursiveCTENode::Copy() const { + auto result = make_uniq(); + result->ctename = ctename; + result->union_all = union_all; + result->left = left->Copy(); + result->right = right->Copy(); + result->aliases = aliases; + this->CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/select_node.cpp b/src/duckdb/src/parser/query_node/select_node.cpp new file mode 100644 index 00000000..e228bc39 --- /dev/null +++ b/src/duckdb/src/parser/query_node/select_node.cpp @@ -0,0 +1,170 @@ +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression_util.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +SelectNode::SelectNode() + : QueryNode(QueryNodeType::SELECT_NODE), aggregate_handling(AggregateHandling::STANDARD_HANDLING) { +} + +string SelectNode::ToString() const { + string result; + result = cte_map.ToString(); + result += "SELECT "; + + // search for a distinct modifier + for (idx_t modifier_idx = 0; modifier_idx < modifiers.size(); modifier_idx++) { + if (modifiers[modifier_idx]->type == ResultModifierType::DISTINCT_MODIFIER) { + auto &distinct_modifier = modifiers[modifier_idx]->Cast(); + result += "DISTINCT "; + if (!distinct_modifier.distinct_on_targets.empty()) { + result += "ON ("; + for (idx_t k = 0; k < distinct_modifier.distinct_on_targets.size(); k++) { + if (k > 0) { + result += ", "; + } + result += distinct_modifier.distinct_on_targets[k]->ToString(); + } + result += ") "; + } + } + } + for (idx_t i = 0; i < select_list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += select_list[i]->ToString(); + if (!select_list[i]->alias.empty()) { + result += StringUtil::Format(" AS %s", SQLIdentifier(select_list[i]->alias)); + } + } + if (from_table && from_table->type != TableReferenceType::EMPTY) { + result += " FROM " + from_table->ToString(); + } + if (where_clause) { + result += " WHERE " + where_clause->ToString(); + } + if (!groups.grouping_sets.empty()) { + result += " GROUP BY "; + // if we are dealing with multiple grouping sets, we have to add a few additional brackets + bool grouping_sets = groups.grouping_sets.size() > 1; + if (grouping_sets) { + result += "GROUPING SETS ("; + } + for (idx_t i = 0; i < groups.grouping_sets.size(); i++) { + auto &grouping_set = groups.grouping_sets[i]; + if (i > 0) { + result += ","; + } + if (grouping_set.empty()) { + result += "()"; + continue; + } + if (grouping_sets) { + result += "("; + } + bool first = true; + for (auto &grp : grouping_set) { + if (!first) { + result += ", "; + } + result += groups.group_expressions[grp]->ToString(); + first = false; + } + if (grouping_sets) { + result += ")"; + } + } + if (grouping_sets) { + result += ")"; + } + } else if (aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { + result += " GROUP BY ALL"; + } + if (having) { + result += " HAVING " + having->ToString(); + } + if (qualify) { + result += " QUALIFY " + qualify->ToString(); + } + if (sample) { + result += " USING SAMPLE "; + result += sample->sample_size.ToString(); + if (sample->is_percentage) { + result += "%"; + } + result += " (" + EnumUtil::ToString(sample->method); + if (sample->seed >= 0) { + result += ", " + std::to_string(sample->seed); + } + result += ")"; + } + return result + ResultModifiersToString(); +} + +bool SelectNode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + + // SELECT + if (!ExpressionUtil::ListEquals(select_list, other.select_list)) { + return false; + } + // FROM + if (!TableRef::Equals(from_table, other.from_table)) { + return false; + } + // WHERE + if (!ParsedExpression::Equals(where_clause, other.where_clause)) { + return false; + } + // GROUP BY + if (!ParsedExpression::ListEquals(groups.group_expressions, other.groups.group_expressions)) { + return false; + } + if (groups.grouping_sets != other.groups.grouping_sets) { + return false; + } + if (!SampleOptions::Equals(sample.get(), other.sample.get())) { + return false; + } + // HAVING + if (!ParsedExpression::Equals(having, other.having)) { + return false; + } + // QUALIFY + if (!ParsedExpression::Equals(qualify, other.qualify)) { + return false; + } + return true; +} + +unique_ptr SelectNode::Copy() const { + auto result = make_uniq(); + for (auto &child : select_list) { + result->select_list.push_back(child->Copy()); + } + result->from_table = from_table ? from_table->Copy() : nullptr; + result->where_clause = where_clause ? where_clause->Copy() : nullptr; + // groups + for (auto &group : groups.group_expressions) { + result->groups.group_expressions.push_back(group->Copy()); + } + result->groups.grouping_sets = groups.grouping_sets; + result->aggregate_handling = aggregate_handling; + result->having = having ? having->Copy() : nullptr; + result->qualify = qualify ? qualify->Copy() : nullptr; + result->sample = sample ? sample->Copy() : nullptr; + this->CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/set_operation_node.cpp b/src/duckdb/src/parser/query_node/set_operation_node.cpp new file mode 100644 index 00000000..1e82a6b0 --- /dev/null +++ b/src/duckdb/src/parser/query_node/set_operation_node.cpp @@ -0,0 +1,71 @@ +#include "duckdb/parser/query_node/set_operation_node.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +string SetOperationNode::ToString() const { + string result; + result = cte_map.ToString(); + result += "(" + left->ToString() + ") "; + bool is_distinct = false; + for (idx_t modifier_idx = 0; modifier_idx < modifiers.size(); modifier_idx++) { + if (modifiers[modifier_idx]->type == ResultModifierType::DISTINCT_MODIFIER) { + is_distinct = true; + break; + } + } + + switch (setop_type) { + case SetOperationType::UNION: + result += is_distinct ? "UNION" : "UNION ALL"; + break; + case SetOperationType::UNION_BY_NAME: + result += is_distinct ? "UNION BY NAME" : "UNION ALL BY NAME"; + break; + case SetOperationType::EXCEPT: + D_ASSERT(is_distinct); + result += "EXCEPT"; + break; + case SetOperationType::INTERSECT: + D_ASSERT(is_distinct); + result += "INTERSECT"; + break; + default: + throw InternalException("Unsupported set operation type"); + } + result += " (" + right->ToString() + ")"; + return result + ResultModifiersToString(); +} + +bool SetOperationNode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + if (setop_type != other.setop_type) { + return false; + } + if (!left->Equals(other.left.get())) { + return false; + } + if (!right->Equals(other.right.get())) { + return false; + } + return true; +} + +unique_ptr SetOperationNode::Copy() const { + auto result = make_uniq(); + result->setop_type = setop_type; + result->left = left->Copy(); + result->right = right->Copy(); + this->CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/result_modifier.cpp b/src/duckdb/src/parser/result_modifier.cpp new file mode 100644 index 00000000..eae317e4 --- /dev/null +++ b/src/duckdb/src/parser/result_modifier.cpp @@ -0,0 +1,144 @@ +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/expression_util.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +bool ResultModifier::Equals(const ResultModifier &other) const { + return type == other.type; +} + +bool LimitModifier::Equals(const ResultModifier &other_p) const { + if (!ResultModifier::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!ParsedExpression::Equals(limit, other.limit)) { + return false; + } + if (!ParsedExpression::Equals(offset, other.offset)) { + return false; + } + return true; +} + +unique_ptr LimitModifier::Copy() const { + auto copy = make_uniq(); + if (limit) { + copy->limit = limit->Copy(); + } + if (offset) { + copy->offset = offset->Copy(); + } + return std::move(copy); +} + +bool DistinctModifier::Equals(const ResultModifier &other_p) const { + if (!ResultModifier::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!ExpressionUtil::ListEquals(distinct_on_targets, other.distinct_on_targets)) { + return false; + } + return true; +} + +unique_ptr DistinctModifier::Copy() const { + auto copy = make_uniq(); + for (auto &expr : distinct_on_targets) { + copy->distinct_on_targets.push_back(expr->Copy()); + } + return std::move(copy); +} + +bool OrderModifier::Equals(const ResultModifier &other_p) const { + if (!ResultModifier::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (orders.size() != other.orders.size()) { + return false; + } + for (idx_t i = 0; i < orders.size(); i++) { + if (orders[i].type != other.orders[i].type) { + return false; + } + if (!BaseExpression::Equals(*orders[i].expression, *other.orders[i].expression)) { + return false; + } + } + return true; +} + +bool OrderModifier::Equals(const unique_ptr &left, const unique_ptr &right) { + if (left.get() == right.get()) { + return true; + } + if (!left || !right) { + return false; + } + return left->Equals(*right); +} + +unique_ptr OrderModifier::Copy() const { + auto copy = make_uniq(); + for (auto &order : orders) { + copy->orders.emplace_back(order.type, order.null_order, order.expression->Copy()); + } + return std::move(copy); +} + +string OrderByNode::ToString() const { + auto str = expression->ToString(); + switch (type) { + case OrderType::ASCENDING: + str += " ASC"; + break; + case OrderType::DESCENDING: + str += " DESC"; + break; + default: + break; + } + + switch (null_order) { + case OrderByNullType::NULLS_FIRST: + str += " NULLS FIRST"; + break; + case OrderByNullType::NULLS_LAST: + str += " NULLS LAST"; + break; + default: + break; + } + return str; +} + +bool LimitPercentModifier::Equals(const ResultModifier &other_p) const { + if (!ResultModifier::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!ParsedExpression::Equals(limit, other.limit)) { + return false; + } + if (!ParsedExpression::Equals(offset, other.offset)) { + return false; + } + return true; +} + +unique_ptr LimitPercentModifier::Copy() const { + auto copy = make_uniq(); + if (limit) { + copy->limit = limit->Copy(); + } + if (offset) { + copy->offset = offset->Copy(); + } + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/alter_statement.cpp b/src/duckdb/src/parser/statement/alter_statement.cpp new file mode 100644 index 00000000..f9a05f08 --- /dev/null +++ b/src/duckdb/src/parser/statement/alter_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/alter_statement.hpp" + +namespace duckdb { + +AlterStatement::AlterStatement() : SQLStatement(StatementType::ALTER_STATEMENT) { +} + +AlterStatement::AlterStatement(const AlterStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr AlterStatement::Copy() const { + return unique_ptr(new AlterStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/attach_statement.cpp b/src/duckdb/src/parser/statement/attach_statement.cpp new file mode 100644 index 00000000..0bae08cd --- /dev/null +++ b/src/duckdb/src/parser/statement/attach_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/attach_statement.hpp" + +namespace duckdb { + +AttachStatement::AttachStatement() : SQLStatement(StatementType::ATTACH_STATEMENT) { +} + +AttachStatement::AttachStatement(const AttachStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr AttachStatement::Copy() const { + return unique_ptr(new AttachStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/call_statement.cpp b/src/duckdb/src/parser/statement/call_statement.cpp new file mode 100644 index 00000000..cd7a2218 --- /dev/null +++ b/src/duckdb/src/parser/statement/call_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/call_statement.hpp" + +namespace duckdb { + +CallStatement::CallStatement() : SQLStatement(StatementType::CALL_STATEMENT) { +} + +CallStatement::CallStatement(const CallStatement &other) : SQLStatement(other), function(other.function->Copy()) { +} + +unique_ptr CallStatement::Copy() const { + return unique_ptr(new CallStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/copy_statement.cpp b/src/duckdb/src/parser/statement/copy_statement.cpp new file mode 100644 index 00000000..9031de28 --- /dev/null +++ b/src/duckdb/src/parser/statement/copy_statement.cpp @@ -0,0 +1,109 @@ +#include "duckdb/parser/statement/copy_statement.hpp" + +namespace duckdb { + +CopyStatement::CopyStatement() : SQLStatement(StatementType::COPY_STATEMENT), info(make_uniq()) { +} + +CopyStatement::CopyStatement(const CopyStatement &other) : SQLStatement(other), info(other.info->Copy()) { + if (other.select_statement) { + select_statement = other.select_statement->Copy(); + } +} + +string CopyStatement::CopyOptionsToString(const string &format, + const case_insensitive_map_t> &options) const { + if (format.empty() && options.empty()) { + return string(); + } + string result; + + result += " ("; + if (!format.empty()) { + result += " FORMAT "; + result += format; + } + for (auto it = options.begin(); it != options.end(); it++) { + if (!format.empty() || it != options.begin()) { + result += ", "; + } + auto &name = it->first; + auto &values = it->second; + + result += name + " "; + if (values.empty()) { + // Options like HEADER don't need an explicit value + // just providing the name already sets it to true + } else if (values.size() == 1) { + result += values[0].ToSQLString(); + } else { + result += "( "; + for (idx_t i = 0; i < values.size(); i++) { + if (i) { + result += ", "; + } + result += values[i].ToSQLString(); + } + result += " )"; + } + } + result += " )"; + return result; +} + +// COPY table-name (c1, c2, ..) +string TablePart(const CopyInfo &info) { + string result; + + if (!info.catalog.empty()) { + result += KeywordHelper::WriteOptionallyQuoted(info.catalog) + "."; + } + if (!info.schema.empty()) { + result += KeywordHelper::WriteOptionallyQuoted(info.schema) + "."; + } + D_ASSERT(!info.table.empty()); + result += KeywordHelper::WriteOptionallyQuoted(info.table); + + // (c1, c2, ..) + if (!info.select_list.empty()) { + result += " ("; + for (idx_t i = 0; i < info.select_list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(info.select_list[i]); + } + result += " )"; + } + return result; +} + +string CopyStatement::ToString() const { + string result; + + result += "COPY "; + if (info->is_from) { + D_ASSERT(!select_statement); + result += TablePart(*info); + result += " FROM"; + result += StringUtil::Format(" %s", SQLString(info->file_path)); + result += CopyOptionsToString(info->format, info->options); + } else { + if (select_statement) { + // COPY (select-node) TO ... + result += "(" + select_statement->ToString() + ")"; + } else { + result += TablePart(*info); + } + result += " TO "; + result += StringUtil::Format("%s", SQLString(info->file_path)); + result += CopyOptionsToString(info->format, info->options); + } + return result; +} + +unique_ptr CopyStatement::Copy() const { + return unique_ptr(new CopyStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/create_statement.cpp b/src/duckdb/src/parser/statement/create_statement.cpp new file mode 100644 index 00000000..514807f0 --- /dev/null +++ b/src/duckdb/src/parser/statement/create_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/create_statement.hpp" + +namespace duckdb { + +CreateStatement::CreateStatement() : SQLStatement(StatementType::CREATE_STATEMENT) { +} + +CreateStatement::CreateStatement(const CreateStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr CreateStatement::Copy() const { + return unique_ptr(new CreateStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/delete_statement.cpp b/src/duckdb/src/parser/statement/delete_statement.cpp new file mode 100644 index 00000000..1de2767c --- /dev/null +++ b/src/duckdb/src/parser/statement/delete_statement.cpp @@ -0,0 +1,56 @@ +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" + +namespace duckdb { + +DeleteStatement::DeleteStatement() : SQLStatement(StatementType::DELETE_STATEMENT) { +} + +DeleteStatement::DeleteStatement(const DeleteStatement &other) : SQLStatement(other), table(other.table->Copy()) { + if (other.condition) { + condition = other.condition->Copy(); + } + for (const auto &using_clause : other.using_clauses) { + using_clauses.push_back(using_clause->Copy()); + } + for (auto &expr : other.returning_list) { + returning_list.emplace_back(expr->Copy()); + } + cte_map = other.cte_map.Copy(); +} + +string DeleteStatement::ToString() const { + string result; + result = cte_map.ToString(); + result += "DELETE FROM "; + result += table->ToString(); + if (!using_clauses.empty()) { + result += " USING "; + for (idx_t i = 0; i < using_clauses.size(); i++) { + if (i > 0) { + result += ", "; + } + result += using_clauses[i]->ToString(); + } + } + if (condition) { + result += " WHERE " + condition->ToString(); + } + + if (!returning_list.empty()) { + result += " RETURNING "; + for (idx_t i = 0; i < returning_list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += returning_list[i]->ToString(); + } + } + return result; +} + +unique_ptr DeleteStatement::Copy() const { + return unique_ptr(new DeleteStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/detach_statement.cpp b/src/duckdb/src/parser/statement/detach_statement.cpp new file mode 100644 index 00000000..1ca52f71 --- /dev/null +++ b/src/duckdb/src/parser/statement/detach_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/detach_statement.hpp" + +namespace duckdb { + +DetachStatement::DetachStatement() : SQLStatement(StatementType::DETACH_STATEMENT) { +} + +DetachStatement::DetachStatement(const DetachStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr DetachStatement::Copy() const { + return unique_ptr(new DetachStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/drop_statement.cpp b/src/duckdb/src/parser/statement/drop_statement.cpp new file mode 100644 index 00000000..7bf363f3 --- /dev/null +++ b/src/duckdb/src/parser/statement/drop_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/drop_statement.hpp" + +namespace duckdb { + +DropStatement::DropStatement() : SQLStatement(StatementType::DROP_STATEMENT), info(make_uniq()) { +} + +DropStatement::DropStatement(const DropStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr DropStatement::Copy() const { + return unique_ptr(new DropStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/execute_statement.cpp b/src/duckdb/src/parser/statement/execute_statement.cpp new file mode 100644 index 00000000..2a2fa131 --- /dev/null +++ b/src/duckdb/src/parser/statement/execute_statement.cpp @@ -0,0 +1,18 @@ +#include "duckdb/parser/statement/execute_statement.hpp" + +namespace duckdb { + +ExecuteStatement::ExecuteStatement() : SQLStatement(StatementType::EXECUTE_STATEMENT) { +} + +ExecuteStatement::ExecuteStatement(const ExecuteStatement &other) : SQLStatement(other), name(other.name) { + for (const auto &item : other.named_values) { + named_values.emplace(std::make_pair(item.first, item.second->Copy())); + } +} + +unique_ptr ExecuteStatement::Copy() const { + return unique_ptr(new ExecuteStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/explain_statement.cpp b/src/duckdb/src/parser/statement/explain_statement.cpp new file mode 100644 index 00000000..2ad08fec --- /dev/null +++ b/src/duckdb/src/parser/statement/explain_statement.cpp @@ -0,0 +1,17 @@ +#include "duckdb/parser/statement/explain_statement.hpp" + +namespace duckdb { + +ExplainStatement::ExplainStatement(unique_ptr stmt, ExplainType explain_type) + : SQLStatement(StatementType::EXPLAIN_STATEMENT), stmt(std::move(stmt)), explain_type(explain_type) { +} + +ExplainStatement::ExplainStatement(const ExplainStatement &other) + : SQLStatement(other), stmt(other.stmt->Copy()), explain_type(other.explain_type) { +} + +unique_ptr ExplainStatement::Copy() const { + return unique_ptr(new ExplainStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/export_statement.cpp b/src/duckdb/src/parser/statement/export_statement.cpp new file mode 100644 index 00000000..e9358e31 --- /dev/null +++ b/src/duckdb/src/parser/statement/export_statement.cpp @@ -0,0 +1,16 @@ +#include "duckdb/parser/statement/export_statement.hpp" + +namespace duckdb { + +ExportStatement::ExportStatement(unique_ptr info) + : SQLStatement(StatementType::EXPORT_STATEMENT), info(std::move(info)) { +} + +ExportStatement::ExportStatement(const ExportStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr ExportStatement::Copy() const { + return unique_ptr(new ExportStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/extension_statement.cpp b/src/duckdb/src/parser/statement/extension_statement.cpp new file mode 100644 index 00000000..22d6dec1 --- /dev/null +++ b/src/duckdb/src/parser/statement/extension_statement.cpp @@ -0,0 +1,14 @@ +#include "duckdb/parser/statement/extension_statement.hpp" + +namespace duckdb { + +ExtensionStatement::ExtensionStatement(ParserExtension extension_p, unique_ptr parse_data_p) + : SQLStatement(StatementType::EXTENSION_STATEMENT), extension(std::move(extension_p)), + parse_data(std::move(parse_data_p)) { +} + +unique_ptr ExtensionStatement::Copy() const { + return make_uniq(extension, parse_data->Copy()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/insert_statement.cpp b/src/duckdb/src/parser/statement/insert_statement.cpp new file mode 100644 index 00000000..4255ebb9 --- /dev/null +++ b/src/duckdb/src/parser/statement/insert_statement.cpp @@ -0,0 +1,196 @@ +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/parser/statement/update_statement.hpp" + +namespace duckdb { + +OnConflictInfo::OnConflictInfo() : action_type(OnConflictAction::THROW) { +} + +OnConflictInfo::OnConflictInfo(const OnConflictInfo &other) + : action_type(other.action_type), indexed_columns(other.indexed_columns) { + if (other.set_info) { + set_info = other.set_info->Copy(); + } + if (other.condition) { + condition = other.condition->Copy(); + } +} + +unique_ptr OnConflictInfo::Copy() const { + return unique_ptr(new OnConflictInfo(*this)); +} + +InsertStatement::InsertStatement() + : SQLStatement(StatementType::INSERT_STATEMENT), schema(DEFAULT_SCHEMA), catalog(INVALID_CATALOG) { +} + +InsertStatement::InsertStatement(const InsertStatement &other) + : SQLStatement(other), select_statement(unique_ptr_cast( + other.select_statement ? other.select_statement->Copy() : nullptr)), + columns(other.columns), table(other.table), schema(other.schema), catalog(other.catalog), + default_values(other.default_values), column_order(other.column_order) { + cte_map = other.cte_map.Copy(); + for (auto &expr : other.returning_list) { + returning_list.emplace_back(expr->Copy()); + } + if (other.table_ref) { + table_ref = other.table_ref->Copy(); + } + if (other.on_conflict_info) { + on_conflict_info = other.on_conflict_info->Copy(); + } +} + +string InsertStatement::OnConflictActionToString(OnConflictAction action) { + switch (action) { + case OnConflictAction::NOTHING: + return "DO NOTHING"; + case OnConflictAction::REPLACE: + case OnConflictAction::UPDATE: + return "DO UPDATE"; + case OnConflictAction::THROW: + // Explicitly left empty, for ToString purposes + return ""; + default: { + throw NotImplementedException("type not implemented for OnConflictActionType"); + } + } +} + +string InsertStatement::ToString() const { + bool or_replace_shorthand_set = false; + string result; + + result = cte_map.ToString(); + result += "INSERT"; + if (on_conflict_info && on_conflict_info->action_type == OnConflictAction::REPLACE) { + or_replace_shorthand_set = true; + result += " OR REPLACE"; + } + result += " INTO "; + if (!catalog.empty()) { + result += KeywordHelper::WriteOptionallyQuoted(catalog) + "."; + } + if (!schema.empty()) { + result += KeywordHelper::WriteOptionallyQuoted(schema) + "."; + } + result += KeywordHelper::WriteOptionallyQuoted(table); + // Write the (optional) alias of the insert target + if (table_ref && !table_ref->alias.empty()) { + result += StringUtil::Format(" AS %s", KeywordHelper::WriteOptionallyQuoted(table_ref->alias)); + } + if (column_order == InsertColumnOrder::INSERT_BY_NAME) { + result += " BY NAME"; + } + if (!columns.empty()) { + result += " ("; + for (idx_t i = 0; i < columns.size(); i++) { + if (i > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(columns[i]); + } + result += " )"; + } + result += " "; + auto values_list = GetValuesList(); + if (values_list) { + D_ASSERT(!default_values); + values_list->alias = string(); + result += values_list->ToString(); + } else if (select_statement) { + D_ASSERT(!default_values); + result += select_statement->ToString(); + } else { + D_ASSERT(default_values); + result += "DEFAULT VALUES"; + } + if (!or_replace_shorthand_set && on_conflict_info) { + auto &conflict_info = *on_conflict_info; + result += " ON CONFLICT "; + // (optional) conflict target + if (!conflict_info.indexed_columns.empty()) { + result += "("; + auto &columns = conflict_info.indexed_columns; + for (auto it = columns.begin(); it != columns.end();) { + result += StringUtil::Lower(*it); + if (++it != columns.end()) { + result += ", "; + } + } + result += " )"; + } + + // (optional) where clause + if (conflict_info.condition) { + result += " WHERE " + conflict_info.condition->ToString(); + } + result += " " + OnConflictActionToString(conflict_info.action_type); + if (conflict_info.set_info) { + D_ASSERT(conflict_info.action_type == OnConflictAction::UPDATE); + result += " SET "; + auto &set_info = *conflict_info.set_info; + D_ASSERT(set_info.columns.size() == set_info.expressions.size()); + // SET = + for (idx_t i = 0; i < set_info.columns.size(); i++) { + auto &column = set_info.columns[i]; + auto &expr = set_info.expressions[i]; + if (i) { + result += ", "; + } + result += StringUtil::Lower(column) + " = " + expr->ToString(); + } + // (optional) where clause + if (set_info.condition) { + result += " WHERE " + set_info.condition->ToString(); + } + } + } + if (!returning_list.empty()) { + result += " RETURNING "; + for (idx_t i = 0; i < returning_list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += returning_list[i]->ToString(); + } + } + return result; +} + +unique_ptr InsertStatement::Copy() const { + return unique_ptr(new InsertStatement(*this)); +} + +optional_ptr InsertStatement::GetValuesList() const { + if (!select_statement) { + return nullptr; + } + if (select_statement->node->type != QueryNodeType::SELECT_NODE) { + return nullptr; + } + auto &node = select_statement->node->Cast(); + if (node.where_clause || node.qualify || node.having) { + return nullptr; + } + if (!node.cte_map.map.empty()) { + return nullptr; + } + if (!node.groups.grouping_sets.empty()) { + return nullptr; + } + if (node.aggregate_handling != AggregateHandling::STANDARD_HANDLING) { + return nullptr; + } + if (node.select_list.size() != 1 || node.select_list[0]->type != ExpressionType::STAR) { + return nullptr; + } + if (!node.from_table || node.from_table->type != TableReferenceType::EXPRESSION_LIST) { + return nullptr; + } + return &node.from_table->Cast(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/load_statement.cpp b/src/duckdb/src/parser/statement/load_statement.cpp new file mode 100644 index 00000000..da59355c --- /dev/null +++ b/src/duckdb/src/parser/statement/load_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/load_statement.hpp" + +namespace duckdb { + +LoadStatement::LoadStatement() : SQLStatement(StatementType::LOAD_STATEMENT) { +} + +LoadStatement::LoadStatement(const LoadStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr LoadStatement::Copy() const { + return unique_ptr(new LoadStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/multi_statement.cpp b/src/duckdb/src/parser/statement/multi_statement.cpp new file mode 100644 index 00000000..3daea30f --- /dev/null +++ b/src/duckdb/src/parser/statement/multi_statement.cpp @@ -0,0 +1,18 @@ +#include "duckdb/parser/statement/multi_statement.hpp" + +namespace duckdb { + +MultiStatement::MultiStatement() : SQLStatement(StatementType::MULTI_STATEMENT) { +} + +MultiStatement::MultiStatement(const MultiStatement &other) : SQLStatement(other) { + for (auto &stmt : other.statements) { + statements.push_back(stmt->Copy()); + } +} + +unique_ptr MultiStatement::Copy() const { + return unique_ptr(new MultiStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/pragma_statement.cpp b/src/duckdb/src/parser/statement/pragma_statement.cpp new file mode 100644 index 00000000..7c083b29 --- /dev/null +++ b/src/duckdb/src/parser/statement/pragma_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/pragma_statement.hpp" + +namespace duckdb { + +PragmaStatement::PragmaStatement() : SQLStatement(StatementType::PRAGMA_STATEMENT), info(make_uniq()) { +} + +PragmaStatement::PragmaStatement(const PragmaStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr PragmaStatement::Copy() const { + return unique_ptr(new PragmaStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/prepare_statement.cpp b/src/duckdb/src/parser/statement/prepare_statement.cpp new file mode 100644 index 00000000..5e50dfc3 --- /dev/null +++ b/src/duckdb/src/parser/statement/prepare_statement.cpp @@ -0,0 +1,16 @@ +#include "duckdb/parser/statement/prepare_statement.hpp" + +namespace duckdb { + +PrepareStatement::PrepareStatement() : SQLStatement(StatementType::PREPARE_STATEMENT), statement(nullptr), name("") { +} + +PrepareStatement::PrepareStatement(const PrepareStatement &other) + : SQLStatement(other), statement(other.statement->Copy()), name(other.name) { +} + +unique_ptr PrepareStatement::Copy() const { + return unique_ptr(new PrepareStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/relation_statement.cpp b/src/duckdb/src/parser/statement/relation_statement.cpp new file mode 100644 index 00000000..f9d4d038 --- /dev/null +++ b/src/duckdb/src/parser/statement/relation_statement.cpp @@ -0,0 +1,13 @@ +#include "duckdb/parser/statement/relation_statement.hpp" + +namespace duckdb { + +RelationStatement::RelationStatement(shared_ptr relation) + : SQLStatement(StatementType::RELATION_STATEMENT), relation(std::move(relation)) { +} + +unique_ptr RelationStatement::Copy() const { + return unique_ptr(new RelationStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/select_statement.cpp b/src/duckdb/src/parser/statement/select_statement.cpp new file mode 100644 index 00000000..1c686c67 --- /dev/null +++ b/src/duckdb/src/parser/statement/select_statement.cpp @@ -0,0 +1,27 @@ +#include "duckdb/parser/statement/select_statement.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +SelectStatement::SelectStatement(const SelectStatement &other) : SQLStatement(other), node(other.node->Copy()) { +} + +unique_ptr SelectStatement::Copy() const { + return unique_ptr(new SelectStatement(*this)); +} + +bool SelectStatement::Equals(const SQLStatement &other_p) const { + if (type != other_p.type) { + return false; + } + auto &other = other_p.Cast(); + return node->Equals(other.node.get()); +} + +string SelectStatement::ToString() const { + return node->ToString(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/set_statement.cpp b/src/duckdb/src/parser/statement/set_statement.cpp new file mode 100644 index 00000000..e9308423 --- /dev/null +++ b/src/duckdb/src/parser/statement/set_statement.cpp @@ -0,0 +1,29 @@ +#include "duckdb/parser/statement/set_statement.hpp" + +namespace duckdb { + +SetStatement::SetStatement(std::string name_p, SetScope scope_p, SetType type_p) + : SQLStatement(StatementType::SET_STATEMENT), name(std::move(name_p)), scope(scope_p), set_type(type_p) { +} + +unique_ptr SetStatement::Copy() const { + return unique_ptr(new SetStatement(*this)); +} + +// Set Variable + +SetVariableStatement::SetVariableStatement(std::string name_p, Value value_p, SetScope scope_p) + : SetStatement(std::move(name_p), scope_p, SetType::SET), value(std::move(value_p)) { +} + +unique_ptr SetVariableStatement::Copy() const { + return unique_ptr(new SetVariableStatement(*this)); +} + +// Reset Variable + +ResetVariableStatement::ResetVariableStatement(std::string name_p, SetScope scope_p) + : SetStatement(std::move(name_p), scope_p, SetType::RESET) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/show_statement.cpp b/src/duckdb/src/parser/statement/show_statement.cpp new file mode 100644 index 00000000..e7abb4fd --- /dev/null +++ b/src/duckdb/src/parser/statement/show_statement.cpp @@ -0,0 +1,15 @@ +#include "duckdb/parser/statement/show_statement.hpp" + +namespace duckdb { + +ShowStatement::ShowStatement() : SQLStatement(StatementType::SHOW_STATEMENT), info(make_uniq()) { +} + +ShowStatement::ShowStatement(const ShowStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr ShowStatement::Copy() const { + return unique_ptr(new ShowStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/transaction_statement.cpp b/src/duckdb/src/parser/statement/transaction_statement.cpp new file mode 100644 index 00000000..6903ab84 --- /dev/null +++ b/src/duckdb/src/parser/statement/transaction_statement.cpp @@ -0,0 +1,17 @@ +#include "duckdb/parser/statement/transaction_statement.hpp" + +namespace duckdb { + +TransactionStatement::TransactionStatement(TransactionType type) + : SQLStatement(StatementType::TRANSACTION_STATEMENT), info(make_uniq(type)) { +} + +TransactionStatement::TransactionStatement(const TransactionStatement &other) + : SQLStatement(other), info(make_uniq(other.info->type)) { +} + +unique_ptr TransactionStatement::Copy() const { + return unique_ptr(new TransactionStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/update_statement.cpp b/src/duckdb/src/parser/statement/update_statement.cpp new file mode 100644 index 00000000..6c0c7991 --- /dev/null +++ b/src/duckdb/src/parser/statement/update_statement.cpp @@ -0,0 +1,78 @@ +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" + +namespace duckdb { + +UpdateSetInfo::UpdateSetInfo() { +} + +UpdateSetInfo::UpdateSetInfo(const UpdateSetInfo &other) : columns(other.columns) { + if (other.condition) { + condition = other.condition->Copy(); + } + for (auto &expr : other.expressions) { + expressions.emplace_back(expr->Copy()); + } +} + +unique_ptr UpdateSetInfo::Copy() const { + return unique_ptr(new UpdateSetInfo(*this)); +} + +UpdateStatement::UpdateStatement() : SQLStatement(StatementType::UPDATE_STATEMENT) { +} + +UpdateStatement::UpdateStatement(const UpdateStatement &other) + : SQLStatement(other), table(other.table->Copy()), set_info(other.set_info->Copy()) { + if (other.from_table) { + from_table = other.from_table->Copy(); + } + for (auto &expr : other.returning_list) { + returning_list.emplace_back(expr->Copy()); + } + cte_map = other.cte_map.Copy(); +} + +string UpdateStatement::ToString() const { + D_ASSERT(set_info); + auto &condition = set_info->condition; + auto &columns = set_info->columns; + auto &expressions = set_info->expressions; + + string result; + result = cte_map.ToString(); + result += "UPDATE "; + result += table->ToString(); + result += " SET "; + D_ASSERT(columns.size() == expressions.size()); + for (idx_t i = 0; i < columns.size(); i++) { + if (i > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(columns[i]); + result += " = "; + result += expressions[i]->ToString(); + } + if (from_table) { + result += " FROM " + from_table->ToString(); + } + if (condition) { + result += " WHERE " + condition->ToString(); + } + if (!returning_list.empty()) { + result += " RETURNING "; + for (idx_t i = 0; i < returning_list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += returning_list[i]->ToString(); + } + } + return result; +} + +unique_ptr UpdateStatement::Copy() const { + return unique_ptr(new UpdateStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/vacuum_statement.cpp b/src/duckdb/src/parser/statement/vacuum_statement.cpp new file mode 100644 index 00000000..5596acf8 --- /dev/null +++ b/src/duckdb/src/parser/statement/vacuum_statement.cpp @@ -0,0 +1,16 @@ +#include "duckdb/parser/statement/vacuum_statement.hpp" + +namespace duckdb { + +VacuumStatement::VacuumStatement(const VacuumOptions &options) + : SQLStatement(StatementType::VACUUM_STATEMENT), info(make_uniq(options)) { +} + +VacuumStatement::VacuumStatement(const VacuumStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr VacuumStatement::Copy() const { + return unique_ptr(new VacuumStatement(*this)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref.cpp b/src/duckdb/src/parser/tableref.cpp new file mode 100644 index 00000000..3d72020f --- /dev/null +++ b/src/duckdb/src/parser/tableref.cpp @@ -0,0 +1,67 @@ +#include "duckdb/parser/tableref.hpp" + +#include "duckdb/common/printer.hpp" +#include "duckdb/parser/tableref/list.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +string TableRef::BaseToString(string result) const { + vector column_name_alias; + return BaseToString(std::move(result), column_name_alias); +} + +string TableRef::BaseToString(string result, const vector &column_name_alias) const { + if (!alias.empty()) { + result += StringUtil::Format(" AS %s", SQLIdentifier(alias)); + } + if (!column_name_alias.empty()) { + D_ASSERT(!alias.empty()); + result += "("; + for (idx_t i = 0; i < column_name_alias.size(); i++) { + if (i > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(column_name_alias[i]); + } + result += ")"; + } + if (sample) { + result += " TABLESAMPLE " + EnumUtil::ToString(sample->method); + result += "(" + sample->sample_size.ToString() + " " + string(sample->is_percentage ? "PERCENT" : "ROWS") + ")"; + if (sample->seed >= 0) { + result += "REPEATABLE (" + to_string(sample->seed) + ")"; + } + } + + return result; +} + +bool TableRef::Equals(const TableRef &other) const { + return type == other.type && alias == other.alias && SampleOptions::Equals(sample.get(), other.sample.get()); +} + +void TableRef::CopyProperties(TableRef &target) const { + D_ASSERT(type == target.type); + target.alias = alias; + target.query_location = query_location; + target.sample = sample ? sample->Copy() : nullptr; +} + +void TableRef::Print() { + Printer::Print(ToString()); +} + +bool TableRef::Equals(const unique_ptr &left, const unique_ptr &right) { + if (left.get() == right.get()) { + return true; + } + if (!left || !right) { + return false; + } + return left->Equals(*right); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref/basetableref.cpp b/src/duckdb/src/parser/tableref/basetableref.cpp new file mode 100644 index 00000000..a90a6f38 --- /dev/null +++ b/src/duckdb/src/parser/tableref/basetableref.cpp @@ -0,0 +1,37 @@ +#include "duckdb/parser/tableref/basetableref.hpp" + +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +string BaseTableRef::ToString() const { + string result; + result += catalog_name.empty() ? "" : (KeywordHelper::WriteOptionallyQuoted(catalog_name) + "."); + result += schema_name.empty() ? "" : (KeywordHelper::WriteOptionallyQuoted(schema_name) + "."); + result += KeywordHelper::WriteOptionallyQuoted(table_name); + return BaseToString(result, column_name_alias); +} + +bool BaseTableRef::Equals(const TableRef &other_p) const { + if (!TableRef::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return other.catalog_name == catalog_name && other.schema_name == schema_name && other.table_name == table_name && + column_name_alias == other.column_name_alias; +} + +unique_ptr BaseTableRef::Copy() { + auto copy = make_uniq(); + + copy->catalog_name = catalog_name; + copy->schema_name = schema_name; + copy->table_name = table_name; + copy->column_name_alias = column_name_alias; + CopyProperties(*copy); + + return std::move(copy); +} +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref/emptytableref.cpp b/src/duckdb/src/parser/tableref/emptytableref.cpp new file mode 100644 index 00000000..0107aac9 --- /dev/null +++ b/src/duckdb/src/parser/tableref/emptytableref.cpp @@ -0,0 +1,17 @@ +#include "duckdb/parser/tableref/emptytableref.hpp" + +namespace duckdb { + +string EmptyTableRef::ToString() const { + return ""; +} + +bool EmptyTableRef::Equals(const TableRef &other) const { + return TableRef::Equals(other); +} + +unique_ptr EmptyTableRef::Copy() { + return make_uniq(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref/expressionlistref.cpp b/src/duckdb/src/parser/tableref/expressionlistref.cpp new file mode 100644 index 00000000..c14483a1 --- /dev/null +++ b/src/duckdb/src/parser/tableref/expressionlistref.cpp @@ -0,0 +1,67 @@ +#include "duckdb/parser/tableref/expressionlistref.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +string ExpressionListRef::ToString() const { + D_ASSERT(!values.empty()); + string result = "(VALUES "; + for (idx_t row_idx = 0; row_idx < values.size(); row_idx++) { + if (row_idx > 0) { + result += ", "; + } + auto &row = values[row_idx]; + result += "("; + for (idx_t col_idx = 0; col_idx < row.size(); col_idx++) { + if (col_idx > 0) { + result += ", "; + } + result += row[col_idx]->ToString(); + } + result += ")"; + } + result += ")"; + return BaseToString(result, expected_names); +} + +bool ExpressionListRef::Equals(const TableRef &other_p) const { + if (!TableRef::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (values.size() != other.values.size()) { + return false; + } + for (idx_t i = 0; i < values.size(); i++) { + if (values[i].size() != other.values[i].size()) { + return false; + } + for (idx_t j = 0; j < values[i].size(); j++) { + if (!values[i][j]->Equals(*other.values[i][j])) { + return false; + } + } + } + return true; +} + +unique_ptr ExpressionListRef::Copy() { + // value list + auto result = make_uniq(); + for (auto &val_list : values) { + vector> new_val_list; + new_val_list.reserve(val_list.size()); + for (auto &val : val_list) { + new_val_list.push_back(val->Copy()); + } + result->values.push_back(std::move(new_val_list)); + } + result->expected_names = expected_names; + result->expected_types = expected_types; + CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref/joinref.cpp b/src/duckdb/src/parser/tableref/joinref.cpp new file mode 100644 index 00000000..8ad69fa7 --- /dev/null +++ b/src/duckdb/src/parser/tableref/joinref.cpp @@ -0,0 +1,84 @@ +#include "duckdb/parser/tableref/joinref.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +string JoinRef::ToString() const { + string result; + result = left->ToString() + " "; + switch (ref_type) { + case JoinRefType::REGULAR: + result += EnumUtil::ToString(type) + " JOIN "; + break; + case JoinRefType::NATURAL: + result += "NATURAL "; + result += EnumUtil::ToString(type) + " JOIN "; + break; + case JoinRefType::ASOF: + result += "ASOF "; + result += EnumUtil::ToString(type) + " JOIN "; + break; + case JoinRefType::CROSS: + result += ", "; + break; + case JoinRefType::POSITIONAL: + result += "POSITIONAL JOIN "; + break; + case JoinRefType::DEPENDENT: + result += "DEPENDENT JOIN "; + break; + } + result += right->ToString(); + if (condition) { + D_ASSERT(using_columns.empty()); + result += " ON ("; + result += condition->ToString(); + result += ")"; + } else if (!using_columns.empty()) { + result += " USING ("; + for (idx_t i = 0; i < using_columns.size(); i++) { + if (i > 0) { + result += ", "; + } + result += using_columns[i]; + } + result += ")"; + } + return result; +} + +bool JoinRef::Equals(const TableRef &other_p) const { + if (!TableRef::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (using_columns.size() != other.using_columns.size()) { + return false; + } + for (idx_t i = 0; i < using_columns.size(); i++) { + if (using_columns[i] != other.using_columns[i]) { + return false; + } + } + return left->Equals(*other.left) && right->Equals(*other.right) && + ParsedExpression::Equals(condition, other.condition) && type == other.type; +} + +unique_ptr JoinRef::Copy() { + auto copy = make_uniq(ref_type); + copy->left = left->Copy(); + copy->right = right->Copy(); + if (condition) { + copy->condition = condition->Copy(); + } + copy->type = type; + copy->ref_type = ref_type; + copy->alias = alias; + copy->using_columns = using_columns; + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref/pivotref.cpp b/src/duckdb/src/parser/tableref/pivotref.cpp new file mode 100644 index 00000000..4b182277 --- /dev/null +++ b/src/duckdb/src/parser/tableref/pivotref.cpp @@ -0,0 +1,252 @@ +#include "duckdb/parser/tableref/pivotref.hpp" + +#include "duckdb/common/limits.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// PivotColumn +//===--------------------------------------------------------------------===// +string PivotColumn::ToString() const { + string result; + if (!unpivot_names.empty()) { + D_ASSERT(pivot_expressions.empty()); + // unpivot + if (unpivot_names.size() == 1) { + result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[0]); + } else { + result += "("; + for (idx_t n = 0; n < unpivot_names.size(); n++) { + if (n > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[n]); + } + result += ")"; + } + } else if (!pivot_expressions.empty()) { + // pivot + result += "("; + for (idx_t n = 0; n < pivot_expressions.size(); n++) { + if (n > 0) { + result += ", "; + } + result += pivot_expressions[n]->ToString(); + } + result += ")"; + } + result += " IN "; + if (pivot_enum.empty()) { + result += "("; + for (idx_t e = 0; e < entries.size(); e++) { + auto &entry = entries[e]; + if (e > 0) { + result += ", "; + } + if (entry.star_expr) { + D_ASSERT(entry.values.empty()); + result += entry.star_expr->ToString(); + } else if (entry.values.size() == 1) { + result += entry.values[0].ToSQLString(); + } else { + result += "("; + for (idx_t v = 0; v < entry.values.size(); v++) { + if (v > 0) { + result += ", "; + } + result += entry.values[v].ToSQLString(); + } + result += ")"; + } + if (!entry.alias.empty()) { + result += " AS " + KeywordHelper::WriteOptionallyQuoted(entry.alias); + } + } + result += ")"; + } else { + result += KeywordHelper::WriteOptionallyQuoted(pivot_enum); + } + return result; +} + +bool PivotColumnEntry::Equals(const PivotColumnEntry &other) const { + if (alias != other.alias) { + return false; + } + if (values.size() != other.values.size()) { + return false; + } + for (idx_t i = 0; i < values.size(); i++) { + if (!Value::NotDistinctFrom(values[i], other.values[i])) { + return false; + } + } + return true; +} + +bool PivotColumn::Equals(const PivotColumn &other) const { + if (!ExpressionUtil::ListEquals(pivot_expressions, other.pivot_expressions)) { + return false; + } + if (other.unpivot_names != unpivot_names) { + return false; + } + if (other.pivot_enum != pivot_enum) { + return false; + } + if (other.entries.size() != entries.size()) { + return false; + } + for (idx_t i = 0; i < entries.size(); i++) { + if (!entries[i].Equals(other.entries[i])) { + return false; + } + } + return true; +} + +PivotColumn PivotColumn::Copy() const { + PivotColumn result; + for (auto &expr : pivot_expressions) { + result.pivot_expressions.push_back(expr->Copy()); + } + result.unpivot_names = unpivot_names; + for (auto &entry : entries) { + result.entries.push_back(entry.Copy()); + } + result.pivot_enum = pivot_enum; + return result; +} + +//===--------------------------------------------------------------------===// +// PivotColumnEntry +//===--------------------------------------------------------------------===// +PivotColumnEntry PivotColumnEntry::Copy() const { + PivotColumnEntry result; + result.values = values; + result.star_expr = star_expr ? star_expr->Copy() : nullptr; + result.alias = alias; + return result; +} + +//===--------------------------------------------------------------------===// +// PivotRef +//===--------------------------------------------------------------------===// +string PivotRef::ToString() const { + string result; + result = source->ToString(); + if (!aggregates.empty()) { + // pivot + result += " PIVOT ("; + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + if (aggr_idx > 0) { + result += ", "; + } + result += aggregates[aggr_idx]->ToString(); + if (!aggregates[aggr_idx]->alias.empty()) { + result += " AS " + KeywordHelper::WriteOptionallyQuoted(aggregates[aggr_idx]->alias); + } + } + } else { + // unpivot + result += " UNPIVOT "; + if (include_nulls) { + result += "INCLUDE NULLS "; + } + result += "("; + if (unpivot_names.size() == 1) { + result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[0]); + } else { + result += "("; + for (idx_t n = 0; n < unpivot_names.size(); n++) { + if (n > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(unpivot_names[n]); + } + result += ")"; + } + } + result += " FOR"; + for (auto &pivot : pivots) { + result += " "; + result += pivot.ToString(); + } + if (!groups.empty()) { + result += " GROUP BY "; + for (idx_t i = 0; i < groups.size(); i++) { + if (i > 0) { + result += ", "; + } + result += groups[i]; + } + } + result += ")"; + if (!alias.empty()) { + result += " AS " + KeywordHelper::WriteOptionallyQuoted(alias); + if (!column_name_alias.empty()) { + result += "("; + for (idx_t i = 0; i < column_name_alias.size(); i++) { + if (i > 0) { + result += ", "; + } + result += KeywordHelper::WriteOptionallyQuoted(column_name_alias[i]); + } + result += ")"; + } + } + return result; +} + +bool PivotRef::Equals(const TableRef &other_p) const { + if (!TableRef::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!source->Equals(*other.source)) { + return false; + } + if (!ParsedExpression::ListEquals(aggregates, other.aggregates)) { + return false; + } + if (pivots.size() != other.pivots.size()) { + return false; + } + for (idx_t i = 0; i < pivots.size(); i++) { + if (!pivots[i].Equals(other.pivots[i])) { + return false; + } + } + if (unpivot_names != other.unpivot_names) { + return false; + } + if (alias != other.alias) { + return false; + } + if (groups != other.groups) { + return false; + } + if (include_nulls != other.include_nulls) { + return false; + } + return true; +} + +unique_ptr PivotRef::Copy() { + auto copy = make_uniq(); + copy->source = source->Copy(); + for (auto &aggr : aggregates) { + copy->aggregates.push_back(aggr->Copy()); + } + copy->unpivot_names = unpivot_names; + for (auto &entry : pivots) { + copy->pivots.push_back(entry.Copy()); + } + copy->groups = groups; + copy->column_name_alias = column_name_alias; + copy->include_nulls = include_nulls; + copy->alias = alias; + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref/subqueryref.cpp b/src/duckdb/src/parser/tableref/subqueryref.cpp new file mode 100644 index 00000000..2d3214bd --- /dev/null +++ b/src/duckdb/src/parser/tableref/subqueryref.cpp @@ -0,0 +1,37 @@ +#include "duckdb/parser/tableref/subqueryref.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +string SubqueryRef::ToString() const { + string result = "(" + subquery->ToString() + ")"; + return BaseToString(result, column_name_alias); +} + +SubqueryRef::SubqueryRef() : TableRef(TableReferenceType::SUBQUERY) { +} + +SubqueryRef::SubqueryRef(unique_ptr subquery_p, string alias_p) + : TableRef(TableReferenceType::SUBQUERY), subquery(std::move(subquery_p)) { + this->alias = std::move(alias_p); +} + +bool SubqueryRef::Equals(const TableRef &other_p) const { + if (!TableRef::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return subquery->Equals(*other.subquery); +} + +unique_ptr SubqueryRef::Copy() { + auto copy = make_uniq(unique_ptr_cast(subquery->Copy()), alias); + copy->column_name_alias = column_name_alias; + CopyProperties(*copy); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/tableref/table_function.cpp b/src/duckdb/src/parser/tableref/table_function.cpp new file mode 100644 index 00000000..29a6da9b --- /dev/null +++ b/src/duckdb/src/parser/tableref/table_function.cpp @@ -0,0 +1,33 @@ +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +TableFunctionRef::TableFunctionRef() : TableRef(TableReferenceType::TABLE_FUNCTION) { +} + +string TableFunctionRef::ToString() const { + return BaseToString(function->ToString(), column_name_alias); +} + +bool TableFunctionRef::Equals(const TableRef &other_p) const { + if (!TableRef::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return function->Equals(*other.function); +} + +unique_ptr TableFunctionRef::Copy() { + auto copy = make_uniq(); + + copy->function = function->Copy(); + copy->column_name_alias = column_name_alias; + CopyProperties(*copy); + + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/constraint/transform_constraint.cpp b/src/duckdb/src/parser/transform/constraint/transform_constraint.cpp new file mode 100644 index 00000000..c77780e4 --- /dev/null +++ b/src/duckdb/src/parser/transform/constraint/transform_constraint.cpp @@ -0,0 +1,138 @@ +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/constraints/list.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +static void ParseSchemaTableNameFK(duckdb_libpgquery::PGRangeVar *input, ForeignKeyInfo &fk_info) { + if (input->catalogname) { + throw ParserException("FOREIGN KEY constraints cannot be defined cross-database"); + } + if (input->schemaname) { + fk_info.schema = input->schemaname; + } else { + fk_info.schema = ""; + }; + fk_info.table = input->relname; +} + +static bool ForeignKeyActionSupported(char action) { + switch (action) { + case PG_FKCONSTR_ACTION_NOACTION: + case PG_FKCONSTR_ACTION_RESTRICT: + return true; + case PG_FKCONSTR_ACTION_CASCADE: + case PG_FKCONSTR_ACTION_SETDEFAULT: + case PG_FKCONSTR_ACTION_SETNULL: + return false; + default: + D_ASSERT(false); + } + return false; +} + +static unique_ptr +TransformForeignKeyConstraint(duckdb_libpgquery::PGConstraint *constraint, + optional_ptr override_fk_column = nullptr) { + D_ASSERT(constraint); + if (!ForeignKeyActionSupported(constraint->fk_upd_action) || + !ForeignKeyActionSupported(constraint->fk_del_action)) { + throw ParserException("FOREIGN KEY constraints cannot use CASCADE, SET NULL or SET DEFAULT"); + } + ForeignKeyInfo fk_info; + fk_info.type = ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; + ParseSchemaTableNameFK(constraint->pktable, fk_info); + vector pk_columns, fk_columns; + if (override_fk_column) { + D_ASSERT(!constraint->fk_attrs); + fk_columns.emplace_back(*override_fk_column); + } else if (constraint->fk_attrs) { + for (auto kc = constraint->fk_attrs->head; kc; kc = kc->next) { + fk_columns.emplace_back(reinterpret_cast(kc->data.ptr_value)->val.str); + } + } + if (constraint->pk_attrs) { + for (auto kc = constraint->pk_attrs->head; kc; kc = kc->next) { + pk_columns.emplace_back(reinterpret_cast(kc->data.ptr_value)->val.str); + } + } + if (!pk_columns.empty() && pk_columns.size() != fk_columns.size()) { + throw ParserException("The number of referencing and referenced columns for foreign keys must be the same"); + } + if (fk_columns.empty()) { + throw ParserException("The set of referencing and referenced columns for foreign keys must be not empty"); + } + return make_uniq(pk_columns, fk_columns, std::move(fk_info)); +} + +unique_ptr Transformer::TransformConstraint(duckdb_libpgquery::PGListCell *cell) { + auto constraint = reinterpret_cast(cell->data.ptr_value); + D_ASSERT(constraint); + switch (constraint->contype) { + case duckdb_libpgquery::PG_CONSTR_UNIQUE: + case duckdb_libpgquery::PG_CONSTR_PRIMARY: { + bool is_primary_key = constraint->contype == duckdb_libpgquery::PG_CONSTR_PRIMARY; + vector columns; + for (auto kc = constraint->keys->head; kc; kc = kc->next) { + columns.emplace_back(reinterpret_cast(kc->data.ptr_value)->val.str); + } + return make_uniq(columns, is_primary_key); + } + case duckdb_libpgquery::PG_CONSTR_CHECK: { + auto expression = TransformExpression(constraint->raw_expr); + if (expression->HasSubquery()) { + throw ParserException("subqueries prohibited in CHECK constraints"); + } + return make_uniq(TransformExpression(constraint->raw_expr)); + } + case duckdb_libpgquery::PG_CONSTR_FOREIGN: + return TransformForeignKeyConstraint(constraint); + + default: + throw NotImplementedException("Constraint type not handled yet!"); + } +} + +unique_ptr Transformer::TransformConstraint(duckdb_libpgquery::PGListCell *cell, ColumnDefinition &column, + idx_t index) { + auto constraint = reinterpret_cast(cell->data.ptr_value); + D_ASSERT(constraint); + switch (constraint->contype) { + case duckdb_libpgquery::PG_CONSTR_NOTNULL: + return make_uniq(LogicalIndex(index)); + case duckdb_libpgquery::PG_CONSTR_CHECK: + return TransformConstraint(cell); + case duckdb_libpgquery::PG_CONSTR_PRIMARY: + return make_uniq(LogicalIndex(index), true); + case duckdb_libpgquery::PG_CONSTR_UNIQUE: + return make_uniq(LogicalIndex(index), false); + case duckdb_libpgquery::PG_CONSTR_NULL: + return nullptr; + case duckdb_libpgquery::PG_CONSTR_GENERATED_VIRTUAL: { + if (column.DefaultValue()) { + throw InvalidInputException("DEFAULT constraint on GENERATED column \"%s\" is not allowed", column.Name()); + } + column.SetGeneratedExpression(TransformExpression(constraint->raw_expr)); + return nullptr; + } + case duckdb_libpgquery::PG_CONSTR_GENERATED_STORED: + throw InvalidInputException("Can not create a STORED generated column!"); + case duckdb_libpgquery::PG_CONSTR_DEFAULT: + column.SetDefaultValue(TransformExpression(constraint->raw_expr)); + return nullptr; + case duckdb_libpgquery::PG_CONSTR_COMPRESSION: + column.SetCompressionType(CompressionTypeFromString(constraint->compression_name)); + if (column.CompressionType() == CompressionType::COMPRESSION_AUTO) { + throw ParserException("Unrecognized option for column compression, expected none, uncompressed, rle, " + "dictionary, pfor, bitpacking or fsst"); + } + return nullptr; + case duckdb_libpgquery::PG_CONSTR_FOREIGN: + return TransformForeignKeyConstraint(constraint, &column.Name()); + default: + throw NotImplementedException("Constraint not implemented!"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_array_access.cpp b/src/duckdb/src/parser/transform/expression/transform_array_access.cpp new file mode 100644 index 00000000..88571c04 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_array_access.cpp @@ -0,0 +1,81 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformArrayAccess(duckdb_libpgquery::PGAIndirection &indirection_node) { + // transform the source expression + unique_ptr result; + result = TransformExpression(indirection_node.arg); + + // now go over the indices + // note that a single indirection node can contain multiple indices + // this happens for e.g. more complex accesses (e.g. (foo).field1[42]) + idx_t list_size = 0; + for (auto node = indirection_node.indirection->head; node != nullptr; node = node->next) { + auto target = reinterpret_cast(node->data.ptr_value); + D_ASSERT(target); + + switch (target->type) { + case duckdb_libpgquery::T_PGAIndices: { + // index access (either slice or extract) + auto index = PGPointerCast(target); + vector> children; + children.push_back(std::move(result)); + if (index->is_slice) { + // slice + // if either the lower or upper bound is not specified, we use an empty const list so that we can + // handle it in the execution + unique_ptr lower = + index->lidx ? TransformExpression(index->lidx) + : make_uniq(Value::LIST(LogicalType::INTEGER, vector())); + children.push_back(std::move(lower)); + unique_ptr upper = + index->uidx ? TransformExpression(index->uidx) + : make_uniq(Value::LIST(LogicalType::INTEGER, vector())); + children.push_back(std::move(upper)); + if (index->step) { + children.push_back(TransformExpression(index->step)); + } + result = make_uniq(ExpressionType::ARRAY_SLICE, std::move(children)); + } else { + // array access + D_ASSERT(!index->lidx); + D_ASSERT(index->uidx); + children.push_back(TransformExpression(index->uidx)); + result = make_uniq(ExpressionType::ARRAY_EXTRACT, std::move(children)); + } + break; + } + case duckdb_libpgquery::T_PGString: { + auto val = PGPointerCast(target); + vector> children; + children.push_back(std::move(result)); + children.push_back(TransformValue(*val)); + result = make_uniq(ExpressionType::STRUCT_EXTRACT, std::move(children)); + break; + } + case duckdb_libpgquery::T_PGFuncCall: { + auto func = PGPointerCast(target); + auto function = TransformFuncCall(*func); + if (function->type != ExpressionType::FUNCTION) { + throw ParserException("%s.%s() call must be a function", result->ToString(), function->ToString()); + } + auto &f = function->Cast(); + f.children.insert(f.children.begin(), std::move(result)); + result = std::move(function); + break; + } + default: + throw NotImplementedException("Unimplemented subscript type"); + } + list_size++; + StackCheck(list_size); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_bool_expr.cpp b/src/duckdb/src/parser/transform/expression/transform_bool_expr.cpp new file mode 100644 index 00000000..b7de69fa --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_bool_expr.cpp @@ -0,0 +1,52 @@ +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformBoolExpr(duckdb_libpgquery::PGBoolExpr &root) { + unique_ptr result; + for (auto node = root.args->head; node != nullptr; node = node->next) { + auto next = TransformExpression(PGPointerCast(node->data.ptr_value)); + + switch (root.boolop) { + case duckdb_libpgquery::PG_AND_EXPR: { + if (!result) { + result = std::move(next); + } else { + result = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(result), + std::move(next)); + } + break; + } + case duckdb_libpgquery::PG_OR_EXPR: { + if (!result) { + result = std::move(next); + } else { + result = make_uniq(ExpressionType::CONJUNCTION_OR, std::move(result), + std::move(next)); + } + break; + } + case duckdb_libpgquery::PG_NOT_EXPR: { + if (next->type == ExpressionType::COMPARE_IN) { + // convert COMPARE_IN to COMPARE_NOT_IN + next->type = ExpressionType::COMPARE_NOT_IN; + result = std::move(next); + } else if (next->type >= ExpressionType::COMPARE_EQUAL && + next->type <= ExpressionType::COMPARE_GREATERTHANOREQUALTO) { + // NOT on a comparison: we can negate the comparison + // e.g. NOT(x > y) is equivalent to x <= y + next->type = NegateComparisonExpression(next->type); + result = std::move(next); + } else { + result = make_uniq(ExpressionType::OPERATOR_NOT, std::move(next)); + } + break; + } + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_boolean_test.cpp b/src/duckdb/src/parser/transform/expression/transform_boolean_test.cpp new file mode 100644 index 00000000..78ee75ab --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_boolean_test.cpp @@ -0,0 +1,39 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformBooleanTest(duckdb_libpgquery::PGBooleanTest &node) { + auto argument = TransformExpression(PGPointerCast(node.arg)); + + auto expr_true = make_uniq(Value::BOOLEAN(true)); + auto expr_false = make_uniq(Value::BOOLEAN(false)); + // we cast the argument to bool to remove ambiguity wrt function binding on the comparision + auto cast_argument = make_uniq(LogicalType::BOOLEAN, argument->Copy()); + + switch (node.booltesttype) { + case duckdb_libpgquery::PGBoolTestType::PG_IS_TRUE: + return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, std::move(cast_argument), + std::move(expr_true)); + case duckdb_libpgquery::PGBoolTestType::IS_NOT_TRUE: + return make_uniq(ExpressionType::COMPARE_DISTINCT_FROM, std::move(cast_argument), + std::move(expr_true)); + case duckdb_libpgquery::PGBoolTestType::IS_FALSE: + return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, std::move(cast_argument), + std::move(expr_false)); + case duckdb_libpgquery::PGBoolTestType::IS_NOT_FALSE: + return make_uniq(ExpressionType::COMPARE_DISTINCT_FROM, std::move(cast_argument), + std::move(expr_false)); + case duckdb_libpgquery::PGBoolTestType::IS_UNKNOWN: // IS NULL + return make_uniq(ExpressionType::OPERATOR_IS_NULL, std::move(argument)); + case duckdb_libpgquery::PGBoolTestType::IS_NOT_UNKNOWN: // IS NOT NULL + return make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, std::move(argument)); + default: + throw NotImplementedException("Unknown boolean test type %d", node.booltesttype); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_case.cpp b/src/duckdb/src/parser/transform/expression/transform_case.cpp new file mode 100644 index 00000000..2902fa43 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_case.cpp @@ -0,0 +1,35 @@ +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCase(duckdb_libpgquery::PGCaseExpr &root) { + auto case_node = make_uniq(); + auto root_arg = TransformExpression(PGPointerCast(root.arg)); + for (auto cell = root.args->head; cell != nullptr; cell = cell->next) { + CaseCheck case_check; + + auto w = PGPointerCast(cell->data.ptr_value); + auto test_raw = TransformExpression(PGPointerCast(w->expr)); + unique_ptr test; + if (root_arg) { + case_check.when_expr = + make_uniq(ExpressionType::COMPARE_EQUAL, root_arg->Copy(), std::move(test_raw)); + } else { + case_check.when_expr = std::move(test_raw); + } + case_check.then_expr = TransformExpression(PGPointerCast(w->result)); + case_node->case_checks.push_back(std::move(case_check)); + } + + if (root.defresult) { + case_node->else_expr = TransformExpression(PGPointerCast(root.defresult)); + } else { + case_node->else_expr = make_uniq(Value(LogicalType::SQLNULL)); + } + return std::move(case_node); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_cast.cpp b/src/duckdb/src/parser/transform/expression/transform_cast.cpp new file mode 100644 index 00000000..d884290d --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_cast.cpp @@ -0,0 +1,29 @@ +#include "duckdb/common/limits.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/operator/cast_operators.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformTypeCast(duckdb_libpgquery::PGTypeCast &root) { + // get the type to cast to + auto type_name = root.typeName; + LogicalType target_type = TransformTypeName(*type_name); + + // check for a constant BLOB value, then return ConstantExpression with BLOB + if (!root.tryCast && target_type == LogicalType::BLOB && root.arg->type == duckdb_libpgquery::T_PGAConst) { + auto c = PGPointerCast(root.arg); + if (c->val.type == duckdb_libpgquery::T_PGString) { + return make_uniq(Value::BLOB(string(c->val.val.str))); + } + } + // transform the expression node + auto expression = TransformExpression(root.arg); + bool try_cast = root.tryCast; + + // now create a cast operation + return make_uniq(target_type, std::move(expression), try_cast); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_coalesce.cpp b/src/duckdb/src/parser/transform/expression/transform_coalesce.cpp new file mode 100644 index 00000000..bc42a2b2 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_coalesce.cpp @@ -0,0 +1,21 @@ +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +// COALESCE(a,b,c) returns the first argument that is NOT NULL, so +// rewrite into CASE(a IS NOT NULL, a, CASE(b IS NOT NULL, b, c)) +unique_ptr Transformer::TransformCoalesce(duckdb_libpgquery::PGAExpr &root) { + auto coalesce_args = PGPointerCast(root.lexpr); + D_ASSERT(coalesce_args->length > 0); // parser ensures this already + + auto coalesce_op = make_uniq(ExpressionType::OPERATOR_COALESCE); + for (auto cell = coalesce_args->head; cell; cell = cell->next) { + // get the value of the COALESCE + auto value_expr = TransformExpression(PGPointerCast(cell->data.ptr_value)); + coalesce_op->children.push_back(std::move(value_expr)); + } + return std::move(coalesce_op); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_columnref.cpp b/src/duckdb/src/parser/transform/expression/transform_columnref.cpp new file mode 100644 index 00000000..628c2130 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_columnref.cpp @@ -0,0 +1,88 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformStarExpression(duckdb_libpgquery::PGAStar &star) { + auto result = make_uniq(star.relation ? star.relation : string()); + if (star.except_list) { + for (auto head = star.except_list->head; head; head = head->next) { + auto value = PGPointerCast(head->data.ptr_value); + D_ASSERT(value->type == duckdb_libpgquery::T_PGString); + string exclude_entry = value->val.str; + if (result->exclude_list.find(exclude_entry) != result->exclude_list.end()) { + throw ParserException("Duplicate entry \"%s\" in EXCLUDE list", exclude_entry); + } + result->exclude_list.insert(std::move(exclude_entry)); + } + } + if (star.replace_list) { + for (auto head = star.replace_list->head; head; head = head->next) { + auto list = PGPointerCast(head->data.ptr_value); + D_ASSERT(list->length == 2); + auto replace_expression = + TransformExpression(PGPointerCast(list->head->data.ptr_value)); + auto value = PGPointerCast(list->tail->data.ptr_value); + D_ASSERT(value->type == duckdb_libpgquery::T_PGString); + string exclude_entry = value->val.str; + if (result->replace_list.find(exclude_entry) != result->replace_list.end()) { + throw ParserException("Duplicate entry \"%s\" in REPLACE list", exclude_entry); + } + if (result->exclude_list.find(exclude_entry) != result->exclude_list.end()) { + throw ParserException("Column \"%s\" cannot occur in both EXCEPT and REPLACE list", exclude_entry); + } + result->replace_list.insert(make_pair(std::move(exclude_entry), std::move(replace_expression))); + } + } + if (star.expr) { + D_ASSERT(star.columns); + D_ASSERT(result->relation_name.empty()); + D_ASSERT(result->exclude_list.empty()); + D_ASSERT(result->replace_list.empty()); + result->expr = TransformExpression(star.expr); + if (result->expr->type == ExpressionType::STAR) { + auto &child_star = result->expr->Cast(); + result->exclude_list = std::move(child_star.exclude_list); + result->replace_list = std::move(child_star.replace_list); + result->expr.reset(); + } else if (result->expr->type == ExpressionType::LAMBDA) { + vector> children; + children.push_back(make_uniq()); + children.push_back(std::move(result->expr)); + auto list_filter = make_uniq("list_filter", std::move(children)); + result->expr = std::move(list_filter); + } + } + result->columns = star.columns; + result->query_location = star.location; + return std::move(result); +} + +unique_ptr Transformer::TransformColumnRef(duckdb_libpgquery::PGColumnRef &root) { + auto fields = root.fields; + auto head_node = PGPointerCast(fields->head->data.ptr_value); + switch (head_node->type) { + case duckdb_libpgquery::T_PGString: { + if (fields->length < 1) { + throw InternalException("Unexpected field length"); + } + vector column_names; + for (auto node = fields->head; node; node = node->next) { + column_names.emplace_back(PGPointerCast(node->data.ptr_value)->val.str); + } + auto colref = make_uniq(std::move(column_names)); + colref->query_location = root.location; + return std::move(colref); + } + case duckdb_libpgquery::T_PGAStar: { + return TransformStarExpression(PGCast(*head_node)); + } + default: + throw NotImplementedException("ColumnRef not implemented!"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_constant.cpp b/src/duckdb/src/parser/transform/expression/transform_constant.cpp new file mode 100644 index 00000000..587a4a06 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_constant.cpp @@ -0,0 +1,131 @@ +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformValue(duckdb_libpgquery::PGValue val) { + switch (val.type) { + case duckdb_libpgquery::T_PGInteger: + D_ASSERT(val.val.ival <= NumericLimits::Maximum()); + return make_uniq(Value::INTEGER((int32_t)val.val.ival)); + case duckdb_libpgquery::T_PGBitString: // FIXME: this should actually convert to BLOB + case duckdb_libpgquery::T_PGString: + return make_uniq(Value(string(val.val.str))); + case duckdb_libpgquery::T_PGFloat: { + string_t str_val(val.val.str); + bool try_cast_as_integer = true; + bool try_cast_as_decimal = true; + int decimal_position = -1; + for (idx_t i = 0; i < str_val.GetSize(); i++) { + if (val.val.str[i] == '.') { + // decimal point: cast as either decimal or double + try_cast_as_integer = false; + decimal_position = i; + } + if (val.val.str[i] == 'e' || val.val.str[i] == 'E') { + // found exponent, cast as double + try_cast_as_integer = false; + try_cast_as_decimal = false; + } + } + if (try_cast_as_integer) { + int64_t bigint_value; + // try to cast as bigint first + if (TryCast::Operation(str_val, bigint_value)) { + // successfully cast to bigint: bigint value + return make_uniq(Value::BIGINT(bigint_value)); + } + hugeint_t hugeint_value; + // if that is not successful; try to cast as hugeint + if (TryCast::Operation(str_val, hugeint_value)) { + // successfully cast to bigint: bigint value + return make_uniq(Value::HUGEINT(hugeint_value)); + } + } + idx_t decimal_offset = val.val.str[0] == '-' ? 3 : 2; + if (try_cast_as_decimal && decimal_position >= 0 && + str_val.GetSize() < Decimal::MAX_WIDTH_DECIMAL + decimal_offset) { + // figure out the width/scale based on the decimal position + auto width = uint8_t(str_val.GetSize() - 1); + auto scale = uint8_t(width - decimal_position); + if (val.val.str[0] == '-') { + width--; + } + if (width <= Decimal::MAX_WIDTH_DECIMAL) { + // we can cast the value as a decimal + Value val = Value(str_val); + val = val.DefaultCastAs(LogicalType::DECIMAL(width, scale)); + return make_uniq(std::move(val)); + } + } + // if there is a decimal or the value is too big to cast as either hugeint or bigint + double dbl_value = Cast::Operation(str_val); + return make_uniq(Value::DOUBLE(dbl_value)); + } + case duckdb_libpgquery::T_PGNull: + return make_uniq(Value(LogicalType::SQLNULL)); + default: + throw NotImplementedException("Value not implemented!"); + } +} + +unique_ptr Transformer::TransformConstant(duckdb_libpgquery::PGAConst &c) { + return TransformValue(c.val); +} + +bool Transformer::ConstructConstantFromExpression(const ParsedExpression &expr, Value &value) { + // We have to construct it like this because we don't have the ClientContext for binding/executing the expr here + switch (expr.type) { + case ExpressionType::FUNCTION: { + auto &function = expr.Cast(); + if (function.function_name == "struct_pack") { + unordered_set unique_names; + child_list_t values; + values.reserve(function.children.size()); + for (const auto &child : function.children) { + if (!unique_names.insert(child->alias).second) { + throw BinderException("Duplicate struct entry name \"%s\"", child->alias); + } + Value child_value; + if (!ConstructConstantFromExpression(*child, child_value)) { + return false; + } + values.emplace_back(child->alias, std::move(child_value)); + } + value = Value::STRUCT(std::move(values)); + return true; + } else { + return false; + } + } + case ExpressionType::VALUE_CONSTANT: { + auto &constant = expr.Cast(); + value = constant.value; + return true; + } + case ExpressionType::OPERATOR_CAST: { + auto &cast = expr.Cast(); + Value dummy_value; + if (!ConstructConstantFromExpression(*cast.child, dummy_value)) { + return false; + } + + string error_message; + if (!dummy_value.DefaultTryCastAs(cast.cast_type, value, &error_message)) { + throw ConversionException("Unable to cast %s to %s", dummy_value.ToString(), + EnumUtil::ToString(cast.cast_type.id())); + } + return true; + } + default: + return false; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_expression.cpp b/src/duckdb/src/parser/transform/expression/transform_expression.cpp new file mode 100644 index 00000000..9f7e7a89 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_expression.cpp @@ -0,0 +1,103 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/expression/default_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformResTarget(duckdb_libpgquery::PGResTarget &root) { + auto expr = TransformExpression(root.val); + if (!expr) { + return nullptr; + } + if (root.name) { + expr->alias = string(root.name); + } + return expr; +} + +unique_ptr Transformer::TransformNamedArg(duckdb_libpgquery::PGNamedArgExpr &root) { + + auto expr = TransformExpression(PGPointerCast(root.arg)); + if (root.name) { + expr->alias = string(root.name); + } + return expr; +} + +unique_ptr Transformer::TransformExpression(duckdb_libpgquery::PGNode &node) { + + auto stack_checker = StackCheck(); + + switch (node.type) { + case duckdb_libpgquery::T_PGColumnRef: + return TransformColumnRef(PGCast(node)); + case duckdb_libpgquery::T_PGAConst: + return TransformConstant(PGCast(node)); + case duckdb_libpgquery::T_PGAExpr: + return TransformAExpr(PGCast(node)); + case duckdb_libpgquery::T_PGFuncCall: + return TransformFuncCall(PGCast(node)); + case duckdb_libpgquery::T_PGBoolExpr: + return TransformBoolExpr(PGCast(node)); + case duckdb_libpgquery::T_PGTypeCast: + return TransformTypeCast(PGCast(node)); + case duckdb_libpgquery::T_PGCaseExpr: + return TransformCase(PGCast(node)); + case duckdb_libpgquery::T_PGSubLink: + return TransformSubquery(PGCast(node)); + case duckdb_libpgquery::T_PGCoalesceExpr: + return TransformCoalesce(PGCast(node)); + case duckdb_libpgquery::T_PGNullTest: + return TransformNullTest(PGCast(node)); + case duckdb_libpgquery::T_PGResTarget: + return TransformResTarget(PGCast(node)); + case duckdb_libpgquery::T_PGParamRef: + return TransformParamRef(PGCast(node)); + case duckdb_libpgquery::T_PGNamedArgExpr: + return TransformNamedArg(PGCast(node)); + case duckdb_libpgquery::T_PGSQLValueFunction: + return TransformSQLValueFunction(PGCast(node)); + case duckdb_libpgquery::T_PGSetToDefault: + return make_uniq(); + case duckdb_libpgquery::T_PGCollateClause: + return TransformCollateExpr(PGCast(node)); + case duckdb_libpgquery::T_PGIntervalConstant: + return TransformInterval(PGCast(node)); + case duckdb_libpgquery::T_PGLambdaFunction: + return TransformLambda(PGCast(node)); + case duckdb_libpgquery::T_PGAIndirection: + return TransformArrayAccess(PGCast(node)); + case duckdb_libpgquery::T_PGPositionalReference: + return TransformPositionalReference(PGCast(node)); + case duckdb_libpgquery::T_PGGroupingFunc: + return TransformGroupingFunction(PGCast(node)); + case duckdb_libpgquery::T_PGAStar: + return TransformStarExpression(PGCast(node)); + case duckdb_libpgquery::T_PGBooleanTest: + return TransformBooleanTest(PGCast(node)); + case duckdb_libpgquery::T_PGMultiAssignRef: + return TransformMultiAssignRef(PGCast(node)); + + default: + throw NotImplementedException("Expression type %s (%d)", NodetypeToString(node.type), (int)node.type); + } +} + +unique_ptr Transformer::TransformExpression(optional_ptr node) { + if (!node) { + return nullptr; + } + return TransformExpression(*node); +} + +void Transformer::TransformExpressionList(duckdb_libpgquery::PGList &list, + vector> &result) { + for (auto node = list.head; node != nullptr; node = node->next) { + auto target = PGPointerCast(node->data.ptr_value); + + auto expr = TransformExpression(*target); + result.push_back(std::move(expr)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_function.cpp b/src/duckdb/src/parser/transform/expression/transform_function.cpp new file mode 100644 index 00000000..bca232b4 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_function.cpp @@ -0,0 +1,334 @@ +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" + +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +void Transformer::TransformWindowDef(duckdb_libpgquery::PGWindowDef &window_spec, WindowExpression &expr, + const char *window_name) { + // next: partitioning/ordering expressions + if (window_spec.partitionClause) { + if (window_name && !expr.partitions.empty()) { + throw ParserException("Cannot override PARTITION BY clause of window \"%s\"", window_name); + } + TransformExpressionList(*window_spec.partitionClause, expr.partitions); + } + if (window_spec.orderClause) { + if (window_name && !expr.orders.empty()) { + throw ParserException("Cannot override ORDER BY clause of window \"%s\"", window_name); + } + TransformOrderBy(window_spec.orderClause, expr.orders); + } +} + +void Transformer::TransformWindowFrame(duckdb_libpgquery::PGWindowDef &window_spec, WindowExpression &expr) { + // finally: specifics of bounds + expr.start_expr = TransformExpression(window_spec.startOffset); + expr.end_expr = TransformExpression(window_spec.endOffset); + + if ((window_spec.frameOptions & FRAMEOPTION_END_UNBOUNDED_PRECEDING) || + (window_spec.frameOptions & FRAMEOPTION_START_UNBOUNDED_FOLLOWING)) { + throw InternalException( + "Window frames starting with unbounded following or ending in unbounded preceding make no sense"); + } + + const bool rangeMode = (window_spec.frameOptions & FRAMEOPTION_RANGE) != 0; + if (window_spec.frameOptions & FRAMEOPTION_START_UNBOUNDED_PRECEDING) { + expr.start = WindowBoundary::UNBOUNDED_PRECEDING; + } else if (window_spec.frameOptions & FRAMEOPTION_START_VALUE_PRECEDING) { + expr.start = rangeMode ? WindowBoundary::EXPR_PRECEDING_RANGE : WindowBoundary::EXPR_PRECEDING_ROWS; + } else if (window_spec.frameOptions & FRAMEOPTION_START_VALUE_FOLLOWING) { + expr.start = rangeMode ? WindowBoundary::EXPR_FOLLOWING_RANGE : WindowBoundary::EXPR_FOLLOWING_ROWS; + } else if (window_spec.frameOptions & FRAMEOPTION_START_CURRENT_ROW) { + expr.start = rangeMode ? WindowBoundary::CURRENT_ROW_RANGE : WindowBoundary::CURRENT_ROW_ROWS; + } + + if (window_spec.frameOptions & FRAMEOPTION_END_UNBOUNDED_FOLLOWING) { + expr.end = WindowBoundary::UNBOUNDED_FOLLOWING; + } else if (window_spec.frameOptions & FRAMEOPTION_END_VALUE_PRECEDING) { + expr.end = rangeMode ? WindowBoundary::EXPR_PRECEDING_RANGE : WindowBoundary::EXPR_PRECEDING_ROWS; + } else if (window_spec.frameOptions & FRAMEOPTION_END_VALUE_FOLLOWING) { + expr.end = rangeMode ? WindowBoundary::EXPR_FOLLOWING_RANGE : WindowBoundary::EXPR_FOLLOWING_ROWS; + } else if (window_spec.frameOptions & FRAMEOPTION_END_CURRENT_ROW) { + expr.end = rangeMode ? WindowBoundary::CURRENT_ROW_RANGE : WindowBoundary::CURRENT_ROW_ROWS; + } + + D_ASSERT(expr.start != WindowBoundary::INVALID && expr.end != WindowBoundary::INVALID); + if (((window_spec.frameOptions & (FRAMEOPTION_START_VALUE_PRECEDING | FRAMEOPTION_START_VALUE_FOLLOWING)) && + !expr.start_expr) || + ((window_spec.frameOptions & (FRAMEOPTION_END_VALUE_PRECEDING | FRAMEOPTION_END_VALUE_FOLLOWING)) && + !expr.end_expr)) { + throw InternalException("Failed to transform window boundary expression"); + } +} + +bool Transformer::ExpressionIsEmptyStar(ParsedExpression &expr) { + if (expr.expression_class != ExpressionClass::STAR) { + return false; + } + auto &star = expr.Cast(); + if (!star.columns && star.exclude_list.empty() && star.replace_list.empty()) { + return true; + } + return false; +} + +bool Transformer::InWindowDefinition() { + if (in_window_definition) { + return true; + } + if (parent) { + return parent->InWindowDefinition(); + } + return false; +} + +unique_ptr Transformer::TransformFuncCall(duckdb_libpgquery::PGFuncCall &root) { + auto name = root.funcname; + string catalog, schema, function_name; + if (name->length == 3) { + // catalog + schema + name + catalog = PGPointerCast(name->head->data.ptr_value)->val.str; + schema = PGPointerCast(name->head->next->data.ptr_value)->val.str; + function_name = PGPointerCast(name->head->next->next->data.ptr_value)->val.str; + } else if (name->length == 2) { + // schema + name + catalog = INVALID_CATALOG; + schema = PGPointerCast(name->head->data.ptr_value)->val.str; + function_name = PGPointerCast(name->head->next->data.ptr_value)->val.str; + } else if (name->length == 1) { + // unqualified name + catalog = INVALID_CATALOG; + schema = INVALID_SCHEMA; + function_name = PGPointerCast(name->head->data.ptr_value)->val.str; + } else { + throw ParserException("TransformFuncCall - Expected 1, 2 or 3 qualifications"); + } + + // transform children + vector> children; + if (root.args) { + TransformExpressionList(*root.args, children); + } + if (children.size() == 1 && ExpressionIsEmptyStar(*children[0]) && !root.agg_distinct && !root.agg_order) { + // COUNT(*) gets translated into COUNT() + children.clear(); + } + + auto lowercase_name = StringUtil::Lower(function_name); + if (root.over) { + if (InWindowDefinition()) { + throw ParserException("window functions are not allowed in window definitions"); + } + + const auto win_fun_type = WindowExpression::WindowToExpressionType(lowercase_name); + if (win_fun_type == ExpressionType::INVALID) { + throw InternalException("Unknown/unsupported window function"); + } + + if (root.agg_distinct) { + throw ParserException("DISTINCT is not implemented for window functions!"); + } + + if (root.agg_order) { + throw ParserException("ORDER BY is not implemented for window functions!"); + } + + if (win_fun_type != ExpressionType::WINDOW_AGGREGATE && root.agg_filter) { + throw ParserException("FILTER is not implemented for non-aggregate window functions!"); + } + if (root.export_state) { + throw ParserException("EXPORT_STATE is not supported for window functions!"); + } + + if (win_fun_type == ExpressionType::WINDOW_AGGREGATE && root.agg_ignore_nulls) { + throw ParserException("IGNORE NULLS is not supported for windowed aggregates"); + } + + auto expr = make_uniq(win_fun_type, std::move(catalog), std::move(schema), lowercase_name); + expr->ignore_nulls = root.agg_ignore_nulls; + + if (root.agg_filter) { + auto filter_expr = TransformExpression(root.agg_filter); + expr->filter_expr = std::move(filter_expr); + } + + if (win_fun_type == ExpressionType::WINDOW_AGGREGATE) { + expr->children = std::move(children); + } else { + if (!children.empty()) { + expr->children.push_back(std::move(children[0])); + } + if (win_fun_type == ExpressionType::WINDOW_LEAD || win_fun_type == ExpressionType::WINDOW_LAG) { + if (children.size() > 1) { + expr->offset_expr = std::move(children[1]); + } + if (children.size() > 2) { + expr->default_expr = std::move(children[2]); + } + if (children.size() > 3) { + throw ParserException("Incorrect number of parameters for function %s", lowercase_name); + } + } else if (win_fun_type == ExpressionType::WINDOW_NTH_VALUE) { + if (children.size() > 1) { + expr->children.push_back(std::move(children[1])); + } + if (children.size() > 2) { + throw ParserException("Incorrect number of parameters for function %s", lowercase_name); + } + } else { + if (children.size() > 1) { + throw ParserException("Incorrect number of parameters for function %s", lowercase_name); + } + } + } + auto window_spec = PGPointerCast(root.over); + if (window_spec->name) { + auto it = window_clauses.find(StringUtil::Lower(string(window_spec->name))); + if (it == window_clauses.end()) { + throw ParserException("window \"%s\" does not exist", window_spec->name); + } + window_spec = it->second; + D_ASSERT(window_spec); + } + auto window_ref = window_spec; + auto window_name = window_ref->refname; + if (window_ref->refname) { + auto it = window_clauses.find(StringUtil::Lower(string(window_spec->refname))); + if (it == window_clauses.end()) { + throw ParserException("window \"%s\" does not exist", window_spec->refname); + } + window_ref = it->second; + D_ASSERT(window_ref); + if (window_ref->startOffset || window_ref->endOffset || window_ref->frameOptions != FRAMEOPTION_DEFAULTS) { + throw ParserException("cannot copy window \"%s\" because it has a frame clause", window_spec->refname); + } + } + in_window_definition = true; + TransformWindowDef(*window_ref, *expr); + if (window_ref != window_spec) { + TransformWindowDef(*window_spec, *expr, window_name); + } + TransformWindowFrame(*window_spec, *expr); + in_window_definition = false; + expr->query_location = root.location; + return std::move(expr); + } + + if (root.agg_ignore_nulls) { + throw ParserException("IGNORE NULLS is not supported for non-window functions"); + } + + unique_ptr filter_expr; + if (root.agg_filter) { + filter_expr = TransformExpression(root.agg_filter); + } + + auto order_bys = make_uniq(); + TransformOrderBy(root.agg_order, order_bys->orders); + + // Ordered aggregates can be either WITHIN GROUP or after the function arguments + if (root.agg_within_group) { + // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE + // Since we implement "ordered aggregates" without sorting, + // we map all the ones we support to the corresponding aggregate function. + if (order_bys->orders.size() != 1) { + throw ParserException("Cannot use multiple ORDER BY clauses with WITHIN GROUP"); + } + if (lowercase_name == "percentile_cont") { + if (children.size() != 1) { + throw ParserException("Wrong number of arguments for PERCENTILE_CONT"); + } + lowercase_name = "quantile_cont"; + } else if (lowercase_name == "percentile_disc") { + if (children.size() != 1) { + throw ParserException("Wrong number of arguments for PERCENTILE_DISC"); + } + lowercase_name = "quantile_disc"; + } else if (lowercase_name == "mode") { + if (!children.empty()) { + throw ParserException("Wrong number of arguments for MODE"); + } + lowercase_name = "mode"; + } else { + throw ParserException("Unknown ordered aggregate \"%s\".", function_name); + } + } + + // star gets eaten in the parser + if (lowercase_name == "count" && children.empty()) { + lowercase_name = "count_star"; + } + + if (lowercase_name == "if") { + if (children.size() != 3) { + throw ParserException("Wrong number of arguments to IF."); + } + auto expr = make_uniq(); + CaseCheck check; + check.when_expr = std::move(children[0]); + check.then_expr = std::move(children[1]); + expr->case_checks.push_back(std::move(check)); + expr->else_expr = std::move(children[2]); + return std::move(expr); + } else if (lowercase_name == "construct_array") { + auto construct_array = make_uniq(ExpressionType::ARRAY_CONSTRUCTOR); + construct_array->children = std::move(children); + return std::move(construct_array); + } else if (lowercase_name == "ifnull") { + if (children.size() != 2) { + throw ParserException("Wrong number of arguments to IFNULL."); + } + + // Two-argument COALESCE + auto coalesce_op = make_uniq(ExpressionType::OPERATOR_COALESCE); + coalesce_op->children.push_back(std::move(children[0])); + coalesce_op->children.push_back(std::move(children[1])); + return std::move(coalesce_op); + } else if (lowercase_name == "list" && order_bys->orders.size() == 1) { + // list(expr ORDER BY expr ) => list_sort(list(expr), , ) + if (children.size() != 1) { + throw ParserException("Wrong number of arguments to LIST."); + } + auto arg_expr = children[0].get(); + auto &order_by = order_bys->orders[0]; + if (arg_expr->Equals(*order_by.expression)) { + auto sense = make_uniq(EnumUtil::ToChars(order_by.type)); + auto nulls = make_uniq(EnumUtil::ToChars(order_by.null_order)); + order_bys = nullptr; + auto unordered = make_uniq(catalog, schema, lowercase_name.c_str(), std::move(children), + std::move(filter_expr), std::move(order_bys), + root.agg_distinct, false, root.export_state); + lowercase_name = "list_sort"; + order_bys.reset(); // NOLINT + filter_expr.reset(); // NOLINT + children.clear(); // NOLINT + root.agg_distinct = false; + children.emplace_back(std::move(unordered)); + children.emplace_back(std::move(sense)); + children.emplace_back(std::move(nulls)); + } + } + + auto function = make_uniq(std::move(catalog), std::move(schema), lowercase_name.c_str(), + std::move(children), std::move(filter_expr), std::move(order_bys), + root.agg_distinct, false, root.export_state); + function->query_location = root.location; + + return std::move(function); +} + +unique_ptr Transformer::TransformSQLValueFunction(duckdb_libpgquery::PGSQLValueFunction &node) { + throw InternalException("SQL value functions should not be emitted by the parser"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_grouping_function.cpp b/src/duckdb/src/parser/transform/expression/transform_grouping_function.cpp new file mode 100644 index 00000000..36751a9e --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_grouping_function.cpp @@ -0,0 +1,16 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformGroupingFunction(duckdb_libpgquery::PGGroupingFunc &grouping) { + auto op = make_uniq(ExpressionType::GROUPING_FUNCTION); + for (auto node = grouping.args->head; node; node = node->next) { + auto n = PGPointerCast(node->data.ptr_value); + op->children.push_back(TransformExpression(n)); + } + op->query_location = grouping.location; + return std::move(op); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_interval.cpp b/src/duckdb/src/parser/transform/expression/transform_interval.cpp new file mode 100644 index 00000000..1a08c1f0 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_interval.cpp @@ -0,0 +1,119 @@ +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/operator/cast_operators.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformInterval(duckdb_libpgquery::PGIntervalConstant &node) { + // handle post-fix notation of INTERVAL + + // three scenarios + // interval (expr) year + // interval 'string' year + // interval int year + unique_ptr expr; + switch (node.val_type) { + case duckdb_libpgquery::T_PGAExpr: + expr = TransformExpression(node.eval); + break; + case duckdb_libpgquery::T_PGString: + expr = make_uniq(Value(node.sval)); + break; + case duckdb_libpgquery::T_PGInteger: + expr = make_uniq(Value(node.ival)); + break; + default: + throw InternalException("Unsupported interval transformation"); + } + + if (!node.typmods) { + return make_uniq(LogicalType::INTERVAL, std::move(expr)); + } + + int32_t mask = PGPointerCast(node.typmods->head->data.ptr_value)->val.val.ival; + // these seemingly random constants are from datetime.hpp + // they are copied here to avoid having to include this header + // the bitshift is from the function INTERVAL_MASK in the parser + constexpr int32_t MONTH_MASK = 1 << 1; + constexpr int32_t YEAR_MASK = 1 << 2; + constexpr int32_t DAY_MASK = 1 << 3; + constexpr int32_t HOUR_MASK = 1 << 10; + constexpr int32_t MINUTE_MASK = 1 << 11; + constexpr int32_t SECOND_MASK = 1 << 12; + constexpr int32_t MILLISECOND_MASK = 1 << 13; + constexpr int32_t MICROSECOND_MASK = 1 << 14; + + // we need to check certain combinations + // because certain interval masks (e.g. INTERVAL '10' HOURS TO DAYS) set multiple bits + // for now we don't support all of the combined ones + // (we might add support if someone complains about it) + + string fname; + LogicalType target_type; + if (mask & YEAR_MASK && mask & MONTH_MASK) { + // DAY TO HOUR + throw ParserException("YEAR TO MONTH is not supported"); + } else if (mask & DAY_MASK && mask & HOUR_MASK) { + // DAY TO HOUR + throw ParserException("DAY TO HOUR is not supported"); + } else if (mask & DAY_MASK && mask & MINUTE_MASK) { + // DAY TO MINUTE + throw ParserException("DAY TO MINUTE is not supported"); + } else if (mask & DAY_MASK && mask & SECOND_MASK) { + // DAY TO SECOND + throw ParserException("DAY TO SECOND is not supported"); + } else if (mask & HOUR_MASK && mask & MINUTE_MASK) { + // DAY TO SECOND + throw ParserException("HOUR TO MINUTE is not supported"); + } else if (mask & HOUR_MASK && mask & SECOND_MASK) { + // DAY TO SECOND + throw ParserException("HOUR TO SECOND is not supported"); + } else if (mask & MINUTE_MASK && mask & SECOND_MASK) { + // DAY TO SECOND + throw ParserException("MINUTE TO SECOND is not supported"); + } else if (mask & YEAR_MASK) { + // YEAR + fname = "to_years"; + target_type = LogicalType::INTEGER; + } else if (mask & MONTH_MASK) { + // MONTH + fname = "to_months"; + target_type = LogicalType::INTEGER; + } else if (mask & DAY_MASK) { + // DAY + fname = "to_days"; + target_type = LogicalType::INTEGER; + } else if (mask & HOUR_MASK) { + // HOUR + fname = "to_hours"; + target_type = LogicalType::BIGINT; + } else if (mask & MINUTE_MASK) { + // MINUTE + fname = "to_minutes"; + target_type = LogicalType::BIGINT; + } else if (mask & SECOND_MASK) { + // SECOND + fname = "to_seconds"; + target_type = LogicalType::BIGINT; + } else if (mask & MILLISECOND_MASK) { + // MILLISECOND + fname = "to_milliseconds"; + target_type = LogicalType::BIGINT; + } else if (mask & MICROSECOND_MASK) { + // SECOND + fname = "to_microseconds"; + target_type = LogicalType::BIGINT; + } else { + throw InternalException("Unsupported interval post-fix"); + } + // first push a cast to the target type + expr = make_uniq(target_type, std::move(expr)); + // now push the operation + vector> children; + children.push_back(std::move(expr)); + return make_uniq(fname, std::move(children)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_is_null.cpp b/src/duckdb/src/parser/transform/expression/transform_is_null.cpp new file mode 100644 index 00000000..a6bd014c --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_is_null.cpp @@ -0,0 +1,19 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformNullTest(duckdb_libpgquery::PGNullTest &root) { + auto arg = TransformExpression(PGPointerCast(root.arg)); + if (root.argisrow) { + throw NotImplementedException("IS NULL argisrow"); + } + ExpressionType expr_type = (root.nulltesttype == duckdb_libpgquery::PG_IS_NULL) + ? ExpressionType::OPERATOR_IS_NULL + : ExpressionType::OPERATOR_IS_NOT_NULL; + + return unique_ptr(new OperatorExpression(expr_type, std::move(arg))); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_lambda.cpp b/src/duckdb/src/parser/transform/expression/transform_lambda.cpp new file mode 100644 index 00000000..4f980a12 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_lambda.cpp @@ -0,0 +1,18 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/expression/lambda_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformLambda(duckdb_libpgquery::PGLambdaFunction &node) { + D_ASSERT(node.lhs); + D_ASSERT(node.rhs); + + auto lhs = TransformExpression(node.lhs); + auto rhs = TransformExpression(node.rhs); + D_ASSERT(lhs); + D_ASSERT(rhs); + return make_uniq(std::move(lhs), std::move(rhs)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp b/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp new file mode 100644 index 00000000..e2c4937f --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp @@ -0,0 +1,44 @@ +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformMultiAssignRef(duckdb_libpgquery::PGMultiAssignRef &root) { + // Multi assignment for the ROW function + if (root.source->type == duckdb_libpgquery::T_PGFuncCall) { + auto func = PGCast(*root.source); + + // Explicitly only allow ROW function + char const *function_name = + PGPointerCast(func.funcname->tail->data.ptr_value)->val.str; + if (function_name == nullptr || strlen(function_name) != 3 || strncmp(function_name, "row", 3) != 0) { + return TransformExpression(root.source); + } + + // Too many columns (ie. (x, y) = (1, 2, 3) ) + if (root.ncolumns < func.args->length) { + throw ParserException( + "Could not perform multiple assignment, target only expects %d values, %d were provided", root.ncolumns, + func.args->length); + } + + // Get the expression corresponding with the current column + idx_t idx = 1; + auto list = func.args->head; + while (list && idx < static_cast(root.colno)) { + list = list->next; + ++idx; + } + + // Not enough columns (ie. (x, y, z) = (1, 2) ) + if (!list) { + throw ParserException( + "Could not perform multiple assignment, target expects %d values, only %d were provided", root.ncolumns, + func.args->length); + } + return TransformExpression(reinterpret_cast(list->data.ptr_value)); + } + return TransformExpression(root.source); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_operator.cpp b/src/duckdb/src/parser/transform/expression/transform_operator.cpp new file mode 100644 index 00000000..826a3558 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_operator.cpp @@ -0,0 +1,218 @@ +#include "duckdb/parser/expression/between_expression.hpp" +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/emptytableref.hpp" +#include "duckdb/parser/parser_options.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformUnaryOperator(const string &op, unique_ptr child) { + vector> children; + children.push_back(std::move(child)); + + // built-in operator function + auto result = make_uniq(op, std::move(children)); + result->is_operator = true; + return std::move(result); +} + +unique_ptr Transformer::TransformBinaryOperator(string op, unique_ptr left, + unique_ptr right) { + vector> children; + children.push_back(std::move(left)); + children.push_back(std::move(right)); + + if (options.integer_division && op == "/") { + op = "//"; + } + if (op == "~" || op == "!~") { + // rewrite 'asdf' SIMILAR TO '.*sd.*' into regexp_full_match('asdf', '.*sd.*') + bool invert_similar = op == "!~"; + + auto result = make_uniq("regexp_full_match", std::move(children)); + if (invert_similar) { + return make_uniq(ExpressionType::OPERATOR_NOT, std::move(result)); + } else { + return std::move(result); + } + } else { + auto target_type = OperatorToExpressionType(op); + if (target_type != ExpressionType::INVALID) { + // built-in comparison operator + return make_uniq(target_type, std::move(children[0]), std::move(children[1])); + } + // not a special operator: convert to a function expression + auto result = make_uniq(std::move(op), std::move(children)); + result->is_operator = true; + return std::move(result); + } +} + +unique_ptr Transformer::TransformAExprInternal(duckdb_libpgquery::PGAExpr &root) { + auto name = string(PGPointerCast(root.name->head->data.ptr_value)->val.str); + + switch (root.kind) { + case duckdb_libpgquery::PG_AEXPR_OP_ALL: + case duckdb_libpgquery::PG_AEXPR_OP_ANY: { + // left=ANY(right) + // we turn this into left=ANY((SELECT UNNEST(right))) + auto left_expr = TransformExpression(root.lexpr); + auto right_expr = TransformExpression(root.rexpr); + + auto subquery_expr = make_uniq(); + auto select_statement = make_uniq(); + auto select_node = make_uniq(); + vector> children; + children.push_back(std::move(right_expr)); + + select_node->select_list.push_back(make_uniq("UNNEST", std::move(children))); + select_node->from_table = make_uniq(); + select_statement->node = std::move(select_node); + subquery_expr->subquery = std::move(select_statement); + subquery_expr->subquery_type = SubqueryType::ANY; + subquery_expr->child = std::move(left_expr); + subquery_expr->comparison_type = OperatorToExpressionType(name); + subquery_expr->query_location = root.location; + if (subquery_expr->comparison_type == ExpressionType::INVALID) { + throw ParserException("Unsupported comparison \"%s\" for ANY/ALL subquery", name); + } + + if (root.kind == duckdb_libpgquery::PG_AEXPR_OP_ALL) { + // ALL sublink is equivalent to NOT(ANY) with inverted comparison + // e.g. [= ALL()] is equivalent to [NOT(<> ANY())] + // first invert the comparison type + subquery_expr->comparison_type = NegateComparisonExpression(subquery_expr->comparison_type); + return make_uniq(ExpressionType::OPERATOR_NOT, std::move(subquery_expr)); + } + return std::move(subquery_expr); + } + case duckdb_libpgquery::PG_AEXPR_IN: { + auto left_expr = TransformExpression(root.lexpr); + ExpressionType operator_type; + // this looks very odd, but seems to be the way to find out its NOT IN + if (name == "<>") { + // NOT IN + operator_type = ExpressionType::COMPARE_NOT_IN; + } else { + // IN + operator_type = ExpressionType::COMPARE_IN; + } + auto result = make_uniq(operator_type, std::move(left_expr)); + result->query_location = root.location; + TransformExpressionList(*PGPointerCast(root.rexpr), result->children); + return std::move(result); + } + // rewrite NULLIF(a, b) into CASE WHEN a=b THEN NULL ELSE a END + case duckdb_libpgquery::PG_AEXPR_NULLIF: { + vector> children; + children.push_back(TransformExpression(root.lexpr)); + children.push_back(TransformExpression(root.rexpr)); + return make_uniq("nullif", std::move(children)); + } + // rewrite (NOT) X BETWEEN A AND B into (NOT) AND(GREATERTHANOREQUALTO(X, + // A), LESSTHANOREQUALTO(X, B)) + case duckdb_libpgquery::PG_AEXPR_BETWEEN: + case duckdb_libpgquery::PG_AEXPR_NOT_BETWEEN: { + auto between_args = PGPointerCast(root.rexpr); + if (between_args->length != 2 || !between_args->head->data.ptr_value || !between_args->tail->data.ptr_value) { + throw InternalException("(NOT) BETWEEN needs two args"); + } + + auto input = TransformExpression(root.lexpr); + auto between_left = + TransformExpression(PGPointerCast(between_args->head->data.ptr_value)); + auto between_right = + TransformExpression(PGPointerCast(between_args->tail->data.ptr_value)); + + auto compare_between = + make_uniq(std::move(input), std::move(between_left), std::move(between_right)); + if (root.kind == duckdb_libpgquery::PG_AEXPR_BETWEEN) { + return std::move(compare_between); + } else { + return make_uniq(ExpressionType::OPERATOR_NOT, std::move(compare_between)); + } + } + // rewrite SIMILAR TO into regexp_full_match('asdf', '.*sd.*') + case duckdb_libpgquery::PG_AEXPR_SIMILAR: { + auto left_expr = TransformExpression(root.lexpr); + auto right_expr = TransformExpression(root.rexpr); + + vector> children; + children.push_back(std::move(left_expr)); + + auto &similar_func = right_expr->Cast(); + D_ASSERT(similar_func.function_name == "similar_escape"); + D_ASSERT(similar_func.children.size() == 2); + if (similar_func.children[1]->type != ExpressionType::VALUE_CONSTANT) { + throw NotImplementedException("Custom escape in SIMILAR TO"); + } + auto &constant = similar_func.children[1]->Cast(); + if (!constant.value.IsNull()) { + throw NotImplementedException("Custom escape in SIMILAR TO"); + } + // take the child of the similar_func + children.push_back(std::move(similar_func.children[0])); + + // this looks very odd, but seems to be the way to find out its NOT IN + bool invert_similar = false; + if (name == "!~") { + // NOT SIMILAR TO + invert_similar = true; + } + const auto regex_function = "regexp_full_match"; + auto result = make_uniq(regex_function, std::move(children)); + + if (invert_similar) { + return make_uniq(ExpressionType::OPERATOR_NOT, std::move(result)); + } else { + return std::move(result); + } + } + case duckdb_libpgquery::PG_AEXPR_NOT_DISTINCT: { + auto left_expr = TransformExpression(root.lexpr); + auto right_expr = TransformExpression(root.rexpr); + return make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, std::move(left_expr), + std::move(right_expr)); + } + case duckdb_libpgquery::PG_AEXPR_DISTINCT: { + auto left_expr = TransformExpression(root.lexpr); + auto right_expr = TransformExpression(root.rexpr); + return make_uniq(ExpressionType::COMPARE_DISTINCT_FROM, std::move(left_expr), + std::move(right_expr)); + } + + default: + break; + } + auto left_expr = TransformExpression(root.lexpr); + auto right_expr = TransformExpression(root.rexpr); + + if (!left_expr) { + // prefix operator + return TransformUnaryOperator(name, std::move(right_expr)); + } else if (!right_expr) { + // postfix operator, only ! is currently supported + return TransformUnaryOperator(name + "__postfix", std::move(left_expr)); + } else { + return TransformBinaryOperator(std::move(name), std::move(left_expr), std::move(right_expr)); + } +} + +unique_ptr Transformer::TransformAExpr(duckdb_libpgquery::PGAExpr &root) { + auto result = TransformAExprInternal(root); + if (result) { + result->query_location = root.location; + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_param_ref.cpp b/src/duckdb/src/parser/transform/expression/transform_param_ref.cpp new file mode 100644 index 00000000..d5d7931f --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_param_ref.cpp @@ -0,0 +1,63 @@ +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +namespace { + +struct PreparedParam { + PreparedParamType type; + string identifier; +}; + +} // namespace + +static PreparedParam GetParameterIdentifier(duckdb_libpgquery::PGParamRef &node) { + PreparedParam param; + if (node.name) { + param.type = PreparedParamType::NAMED; + param.identifier = node.name; + return param; + } + if (node.number < 0) { + throw ParserException("Parameter numbers cannot be negative"); + } + param.identifier = StringUtil::Format("%d", node.number); + param.type = node.number == 0 ? PreparedParamType::AUTO_INCREMENT : PreparedParamType::POSITIONAL; + return param; +} + +unique_ptr Transformer::TransformParamRef(duckdb_libpgquery::PGParamRef &node) { + auto expr = make_uniq(); + + auto param = GetParameterIdentifier(node); + idx_t known_param_index = DConstants::INVALID_INDEX; + // This is a named parameter, try to find an entry for it + GetParam(param.identifier, known_param_index, param.type); + + if (known_param_index == DConstants::INVALID_INDEX) { + // We have not seen this parameter before + if (node.number != 0) { + // Preserve the parameter number + known_param_index = node.number; + } else { + known_param_index = ParamCount() + 1; + if (!node.name) { + param.identifier = StringUtil::Format("%d", known_param_index); + } + } + + if (!named_param_map.count(param.identifier)) { + // Add it to the named parameter map so we can find it next time it's referenced + SetParam(param.identifier, known_param_index, param.type); + } + } + + expr->identifier = param.identifier; + idx_t new_param_count = MaxValue(ParamCount(), known_param_index); + SetParamCount(new_param_count); + return std::move(expr); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_positional_reference.cpp b/src/duckdb/src/parser/transform/expression/transform_positional_reference.cpp new file mode 100644 index 00000000..55cb6472 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_positional_reference.cpp @@ -0,0 +1,16 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformPositionalReference(duckdb_libpgquery::PGPositionalReference &node) { + if (node.position <= 0) { + throw ParserException("Positional reference node needs to be >= 1"); + } + auto result = make_uniq(node.position); + result->query_location = node.location; + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/expression/transform_subquery.cpp b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp new file mode 100644 index 00000000..c4ab20a8 --- /dev/null +++ b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp @@ -0,0 +1,104 @@ +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformSubquery(duckdb_libpgquery::PGSubLink &root) { + auto subquery_expr = make_uniq(); + + subquery_expr->subquery = TransformSelect(root.subselect); + D_ASSERT(subquery_expr->subquery); + D_ASSERT(subquery_expr->subquery->node->GetSelectList().size() > 0); + + switch (root.subLinkType) { + case duckdb_libpgquery::PG_EXISTS_SUBLINK: { + subquery_expr->subquery_type = SubqueryType::EXISTS; + break; + } + case duckdb_libpgquery::PG_ANY_SUBLINK: + case duckdb_libpgquery::PG_ALL_SUBLINK: { + // comparison with ANY() or ALL() + subquery_expr->subquery_type = SubqueryType::ANY; + subquery_expr->child = TransformExpression(root.testexpr); + // get the operator name + if (!root.operName) { + // simple IN + subquery_expr->comparison_type = ExpressionType::COMPARE_EQUAL; + } else { + auto operator_name = + string((PGPointerCast(root.operName->head->data.ptr_value))->val.str); + subquery_expr->comparison_type = OperatorToExpressionType(operator_name); + } + if (subquery_expr->comparison_type != ExpressionType::COMPARE_EQUAL && + subquery_expr->comparison_type != ExpressionType::COMPARE_NOTEQUAL && + subquery_expr->comparison_type != ExpressionType::COMPARE_GREATERTHAN && + subquery_expr->comparison_type != ExpressionType::COMPARE_GREATERTHANOREQUALTO && + subquery_expr->comparison_type != ExpressionType::COMPARE_LESSTHAN && + subquery_expr->comparison_type != ExpressionType::COMPARE_LESSTHANOREQUALTO) { + throw ParserException("ANY and ALL operators require one of =,<>,>,<,>=,<= comparisons!"); + } + if (root.subLinkType == duckdb_libpgquery::PG_ALL_SUBLINK) { + // ALL sublink is equivalent to NOT(ANY) with inverted comparison + // e.g. [= ALL()] is equivalent to [NOT(<> ANY())] + // first invert the comparison type + subquery_expr->comparison_type = NegateComparisonExpression(subquery_expr->comparison_type); + return make_uniq(ExpressionType::OPERATOR_NOT, std::move(subquery_expr)); + } + break; + } + case duckdb_libpgquery::PG_EXPR_SUBLINK: { + // return a single scalar value from the subquery + // no child expression to compare to + subquery_expr->subquery_type = SubqueryType::SCALAR; + break; + } + case duckdb_libpgquery::PG_ARRAY_SUBLINK: { + auto subquery_table_alias = "__subquery"; + auto subquery_column_alias = "__arr_element"; + + // ARRAY expression + // wrap subquery into "SELECT CASE WHEN ARRAY_AGG(i) IS NULL THEN [] ELSE ARRAY_AGG(i) END FROM (...) tbl(i)" + auto select_node = make_uniq(); + + // ARRAY_AGG(i) + vector> children; + children.push_back( + make_uniq_base(subquery_column_alias, subquery_table_alias)); + auto aggr = make_uniq("array_agg", std::move(children)); + // ARRAY_AGG(i) IS NULL + auto agg_is_null = make_uniq(ExpressionType::OPERATOR_IS_NULL, aggr->Copy()); + // empty list + vector> list_children; + auto empty_list = make_uniq("list_value", std::move(list_children)); + // CASE + auto case_expr = make_uniq(); + CaseCheck check; + check.when_expr = std::move(agg_is_null); + check.then_expr = std::move(empty_list); + case_expr->case_checks.push_back(std::move(check)); + case_expr->else_expr = std::move(aggr); + + select_node->select_list.push_back(std::move(case_expr)); + + // FROM (...) tbl(i) + auto child_subquery = make_uniq(std::move(subquery_expr->subquery), subquery_table_alias); + child_subquery->column_name_alias.emplace_back(subquery_column_alias); + select_node->from_table = std::move(child_subquery); + + auto new_subquery = make_uniq(); + new_subquery->node = std::move(select_node); + subquery_expr->subquery = std::move(new_subquery); + + subquery_expr->subquery_type = SubqueryType::SCALAR; + break; + } + default: + throw NotImplementedException("Subquery of type %d not implemented\n", (int)root.subLinkType); + } + subquery_expr->query_location = root.location; + return std::move(subquery_expr); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/nodetype_to_string.cpp b/src/duckdb/src/parser/transform/helpers/nodetype_to_string.cpp new file mode 100644 index 00000000..99081e2e --- /dev/null +++ b/src/duckdb/src/parser/transform/helpers/nodetype_to_string.cpp @@ -0,0 +1,824 @@ +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +std::string Transformer::NodetypeToString(duckdb_libpgquery::PGNodeTag type) { // LCOV_EXCL_START + switch (type) { + case duckdb_libpgquery::T_PGInvalid: + return "T_Invalid"; + case duckdb_libpgquery::T_PGIndexInfo: + return "T_IndexInfo"; + case duckdb_libpgquery::T_PGExprContext: + return "T_ExprContext"; + case duckdb_libpgquery::T_PGProjectionInfo: + return "T_ProjectionInfo"; + case duckdb_libpgquery::T_PGJunkFilter: + return "T_JunkFilter"; + case duckdb_libpgquery::T_PGResultRelInfo: + return "T_ResultRelInfo"; + case duckdb_libpgquery::T_PGEState: + return "T_EState"; + case duckdb_libpgquery::T_PGTupleTableSlot: + return "T_TupleTableSlot"; + case duckdb_libpgquery::T_PGPlan: + return "T_Plan"; + case duckdb_libpgquery::T_PGResult: + return "T_Result"; + case duckdb_libpgquery::T_PGProjectSet: + return "T_ProjectSet"; + case duckdb_libpgquery::T_PGModifyTable: + return "T_ModifyTable"; + case duckdb_libpgquery::T_PGAppend: + return "T_Append"; + case duckdb_libpgquery::T_PGMergeAppend: + return "T_MergeAppend"; + case duckdb_libpgquery::T_PGRecursiveUnion: + return "T_RecursiveUnion"; + case duckdb_libpgquery::T_PGBitmapAnd: + return "T_BitmapAnd"; + case duckdb_libpgquery::T_PGBitmapOr: + return "T_BitmapOr"; + case duckdb_libpgquery::T_PGScan: + return "T_Scan"; + case duckdb_libpgquery::T_PGSeqScan: + return "T_SeqScan"; + case duckdb_libpgquery::T_PGSampleScan: + return "T_SampleScan"; + case duckdb_libpgquery::T_PGIndexScan: + return "T_IndexScan"; + case duckdb_libpgquery::T_PGIndexOnlyScan: + return "T_IndexOnlyScan"; + case duckdb_libpgquery::T_PGBitmapIndexScan: + return "T_BitmapIndexScan"; + case duckdb_libpgquery::T_PGBitmapHeapScan: + return "T_BitmapHeapScan"; + case duckdb_libpgquery::T_PGTidScan: + return "T_TidScan"; + case duckdb_libpgquery::T_PGSubqueryScan: + return "T_SubqueryScan"; + case duckdb_libpgquery::T_PGFunctionScan: + return "T_FunctionScan"; + case duckdb_libpgquery::T_PGValuesScan: + return "T_ValuesScan"; + case duckdb_libpgquery::T_PGTableFuncScan: + return "T_TableFuncScan"; + case duckdb_libpgquery::T_PGCteScan: + return "T_CteScan"; + case duckdb_libpgquery::T_PGNamedTuplestoreScan: + return "T_NamedTuplestoreScan"; + case duckdb_libpgquery::T_PGWorkTableScan: + return "T_WorkTableScan"; + case duckdb_libpgquery::T_PGForeignScan: + return "T_ForeignScan"; + case duckdb_libpgquery::T_PGCustomScan: + return "T_CustomScan"; + case duckdb_libpgquery::T_PGJoin: + return "T_Join"; + case duckdb_libpgquery::T_PGNestLoop: + return "T_NestLoop"; + case duckdb_libpgquery::T_PGMergeJoin: + return "T_MergeJoin"; + case duckdb_libpgquery::T_PGHashJoin: + return "T_HashJoin"; + case duckdb_libpgquery::T_PGMaterial: + return "T_Material"; + case duckdb_libpgquery::T_PGSort: + return "T_Sort"; + case duckdb_libpgquery::T_PGGroup: + return "T_Group"; + case duckdb_libpgquery::T_PGAgg: + return "T_Agg"; + case duckdb_libpgquery::T_PGWindowAgg: + return "T_WindowAgg"; + case duckdb_libpgquery::T_PGUnique: + return "T_Unique"; + case duckdb_libpgquery::T_PGGather: + return "T_Gather"; + case duckdb_libpgquery::T_PGGatherMerge: + return "T_GatherMerge"; + case duckdb_libpgquery::T_PGHash: + return "T_Hash"; + case duckdb_libpgquery::T_PGSetOp: + return "T_SetOp"; + case duckdb_libpgquery::T_PGLockRows: + return "T_LockRows"; + case duckdb_libpgquery::T_PGLimit: + return "T_Limit"; + case duckdb_libpgquery::T_PGNestLoopParam: + return "T_NestLoopParam"; + case duckdb_libpgquery::T_PGPlanRowMark: + return "T_PlanRowMark"; + case duckdb_libpgquery::T_PGPlanInvalItem: + return "T_PlanInvalItem"; + case duckdb_libpgquery::T_PGPlanState: + return "T_PlanState"; + case duckdb_libpgquery::T_PGResultState: + return "T_ResultState"; + case duckdb_libpgquery::T_PGProjectSetState: + return "T_ProjectSetState"; + case duckdb_libpgquery::T_PGModifyTableState: + return "T_ModifyTableState"; + case duckdb_libpgquery::T_PGAppendState: + return "T_AppendState"; + case duckdb_libpgquery::T_PGMergeAppendState: + return "T_MergeAppendState"; + case duckdb_libpgquery::T_PGRecursiveUnionState: + return "T_RecursiveUnionState"; + case duckdb_libpgquery::T_PGBitmapAndState: + return "T_BitmapAndState"; + case duckdb_libpgquery::T_PGBitmapOrState: + return "T_BitmapOrState"; + case duckdb_libpgquery::T_PGScanState: + return "T_ScanState"; + case duckdb_libpgquery::T_PGSeqScanState: + return "T_SeqScanState"; + case duckdb_libpgquery::T_PGSampleScanState: + return "T_SampleScanState"; + case duckdb_libpgquery::T_PGIndexScanState: + return "T_IndexScanState"; + case duckdb_libpgquery::T_PGIndexOnlyScanState: + return "T_IndexOnlyScanState"; + case duckdb_libpgquery::T_PGBitmapIndexScanState: + return "T_BitmapIndexScanState"; + case duckdb_libpgquery::T_PGBitmapHeapScanState: + return "T_BitmapHeapScanState"; + case duckdb_libpgquery::T_PGTidScanState: + return "T_TidScanState"; + case duckdb_libpgquery::T_PGSubqueryScanState: + return "T_SubqueryScanState"; + case duckdb_libpgquery::T_PGFunctionScanState: + return "T_FunctionScanState"; + case duckdb_libpgquery::T_PGTableFuncScanState: + return "T_TableFuncScanState"; + case duckdb_libpgquery::T_PGValuesScanState: + return "T_ValuesScanState"; + case duckdb_libpgquery::T_PGCteScanState: + return "T_CteScanState"; + case duckdb_libpgquery::T_PGNamedTuplestoreScanState: + return "T_NamedTuplestoreScanState"; + case duckdb_libpgquery::T_PGWorkTableScanState: + return "T_WorkTableScanState"; + case duckdb_libpgquery::T_PGForeignScanState: + return "T_ForeignScanState"; + case duckdb_libpgquery::T_PGCustomScanState: + return "T_CustomScanState"; + case duckdb_libpgquery::T_PGJoinState: + return "T_JoinState"; + case duckdb_libpgquery::T_PGNestLoopState: + return "T_NestLoopState"; + case duckdb_libpgquery::T_PGMergeJoinState: + return "T_MergeJoinState"; + case duckdb_libpgquery::T_PGHashJoinState: + return "T_HashJoinState"; + case duckdb_libpgquery::T_PGMaterialState: + return "T_MaterialState"; + case duckdb_libpgquery::T_PGSortState: + return "T_SortState"; + case duckdb_libpgquery::T_PGGroupState: + return "T_GroupState"; + case duckdb_libpgquery::T_PGAggState: + return "T_AggState"; + case duckdb_libpgquery::T_PGWindowAggState: + return "T_WindowAggState"; + case duckdb_libpgquery::T_PGUniqueState: + return "T_UniqueState"; + case duckdb_libpgquery::T_PGGatherState: + return "T_GatherState"; + case duckdb_libpgquery::T_PGGatherMergeState: + return "T_GatherMergeState"; + case duckdb_libpgquery::T_PGHashState: + return "T_HashState"; + case duckdb_libpgquery::T_PGSetOpState: + return "T_SetOpState"; + case duckdb_libpgquery::T_PGLockRowsState: + return "T_LockRowsState"; + case duckdb_libpgquery::T_PGLimitState: + return "T_LimitState"; + case duckdb_libpgquery::T_PGAlias: + return "T_Alias"; + case duckdb_libpgquery::T_PGRangeVar: + return "T_RangeVar"; + case duckdb_libpgquery::T_PGTableFunc: + return "T_TableFunc"; + case duckdb_libpgquery::T_PGExpr: + return "T_Expr"; + case duckdb_libpgquery::T_PGVar: + return "T_Var"; + case duckdb_libpgquery::T_PGConst: + return "T_Const"; + case duckdb_libpgquery::T_PGParam: + return "T_Param"; + case duckdb_libpgquery::T_PGAggref: + return "T_Aggref"; + case duckdb_libpgquery::T_PGGroupingFunc: + return "T_GroupingFunc"; + case duckdb_libpgquery::T_PGWindowFunc: + return "T_WindowFunc"; + case duckdb_libpgquery::T_PGArrayRef: + return "T_ArrayRef"; + case duckdb_libpgquery::T_PGFuncExpr: + return "T_FuncExpr"; + case duckdb_libpgquery::T_PGNamedArgExpr: + return "T_NamedArgExpr"; + case duckdb_libpgquery::T_PGOpExpr: + return "T_OpExpr"; + case duckdb_libpgquery::T_PGDistinctExpr: + return "T_DistinctExpr"; + case duckdb_libpgquery::T_PGNullIfExpr: + return "T_NullIfExpr"; + case duckdb_libpgquery::T_PGScalarArrayOpExpr: + return "T_ScalarArrayOpExpr"; + case duckdb_libpgquery::T_PGBoolExpr: + return "T_BoolExpr"; + case duckdb_libpgquery::T_PGSubLink: + return "T_SubLink"; + case duckdb_libpgquery::T_PGSubPlan: + return "T_SubPlan"; + case duckdb_libpgquery::T_PGAlternativeSubPlan: + return "T_AlternativeSubPlan"; + case duckdb_libpgquery::T_PGFieldSelect: + return "T_FieldSelect"; + case duckdb_libpgquery::T_PGFieldStore: + return "T_FieldStore"; + case duckdb_libpgquery::T_PGRelabelType: + return "T_RelabelType"; + case duckdb_libpgquery::T_PGCoerceViaIO: + return "T_CoerceViaIO"; + case duckdb_libpgquery::T_PGArrayCoerceExpr: + return "T_ArrayCoerceExpr"; + case duckdb_libpgquery::T_PGConvertRowtypeExpr: + return "T_ConvertRowtypeExpr"; + case duckdb_libpgquery::T_PGCollateExpr: + return "T_CollateExpr"; + case duckdb_libpgquery::T_PGCaseExpr: + return "T_CaseExpr"; + case duckdb_libpgquery::T_PGCaseWhen: + return "T_CaseWhen"; + case duckdb_libpgquery::T_PGCaseTestExpr: + return "T_CaseTestExpr"; + case duckdb_libpgquery::T_PGArrayExpr: + return "T_ArrayExpr"; + case duckdb_libpgquery::T_PGRowExpr: + return "T_RowExpr"; + case duckdb_libpgquery::T_PGRowCompareExpr: + return "T_RowCompareExpr"; + case duckdb_libpgquery::T_PGCoalesceExpr: + return "T_CoalesceExpr"; + case duckdb_libpgquery::T_PGMinMaxExpr: + return "T_MinMaxExpr"; + case duckdb_libpgquery::T_PGSQLValueFunction: + return "T_SQLValueFunction"; + case duckdb_libpgquery::T_PGXmlExpr: + return "T_XmlExpr"; + case duckdb_libpgquery::T_PGNullTest: + return "T_NullTest"; + case duckdb_libpgquery::T_PGBooleanTest: + return "T_BooleanTest"; + case duckdb_libpgquery::T_PGCoerceToDomain: + return "T_CoerceToDomain"; + case duckdb_libpgquery::T_PGCoerceToDomainValue: + return "T_CoerceToDomainValue"; + case duckdb_libpgquery::T_PGSetToDefault: + return "T_SetToDefault"; + case duckdb_libpgquery::T_PGCurrentOfExpr: + return "T_CurrentOfExpr"; + case duckdb_libpgquery::T_PGNextValueExpr: + return "T_NextValueExpr"; + case duckdb_libpgquery::T_PGInferenceElem: + return "T_InferenceElem"; + case duckdb_libpgquery::T_PGTargetEntry: + return "T_TargetEntry"; + case duckdb_libpgquery::T_PGRangeTblRef: + return "T_RangeTblRef"; + case duckdb_libpgquery::T_PGJoinExpr: + return "T_JoinExpr"; + case duckdb_libpgquery::T_PGFromExpr: + return "T_FromExpr"; + case duckdb_libpgquery::T_PGOnConflictExpr: + return "T_OnConflictExpr"; + case duckdb_libpgquery::T_PGIntoClause: + return "T_IntoClause"; + case duckdb_libpgquery::T_PGExprState: + return "T_ExprState"; + case duckdb_libpgquery::T_PGAggrefExprState: + return "T_AggrefExprState"; + case duckdb_libpgquery::T_PGWindowFuncExprState: + return "T_WindowFuncExprState"; + case duckdb_libpgquery::T_PGSetExprState: + return "T_SetExprState"; + case duckdb_libpgquery::T_PGSubPlanState: + return "T_SubPlanState"; + case duckdb_libpgquery::T_PGAlternativeSubPlanState: + return "T_AlternativeSubPlanState"; + case duckdb_libpgquery::T_PGDomainConstraintState: + return "T_DomainConstraintState"; + case duckdb_libpgquery::T_PGPlannerInfo: + return "T_PlannerInfo"; + case duckdb_libpgquery::T_PGPlannerGlobal: + return "T_PlannerGlobal"; + case duckdb_libpgquery::T_PGRelOptInfo: + return "T_RelOptInfo"; + case duckdb_libpgquery::T_PGIndexOptInfo: + return "T_IndexOptInfo"; + case duckdb_libpgquery::T_PGForeignKeyOptInfo: + return "T_ForeignKeyOptInfo"; + case duckdb_libpgquery::T_PGParamPathInfo: + return "T_ParamPathInfo"; + case duckdb_libpgquery::T_PGPath: + return "T_Path"; + case duckdb_libpgquery::T_PGIndexPath: + return "T_IndexPath"; + case duckdb_libpgquery::T_PGBitmapHeapPath: + return "T_BitmapHeapPath"; + case duckdb_libpgquery::T_PGBitmapAndPath: + return "T_BitmapAndPath"; + case duckdb_libpgquery::T_PGBitmapOrPath: + return "T_BitmapOrPath"; + case duckdb_libpgquery::T_PGTidPath: + return "T_TidPath"; + case duckdb_libpgquery::T_PGSubqueryScanPath: + return "T_SubqueryScanPath"; + case duckdb_libpgquery::T_PGForeignPath: + return "T_ForeignPath"; + case duckdb_libpgquery::T_PGCustomPath: + return "T_CustomPath"; + case duckdb_libpgquery::T_PGNestPath: + return "T_NestPath"; + case duckdb_libpgquery::T_PGMergePath: + return "T_MergePath"; + case duckdb_libpgquery::T_PGHashPath: + return "T_HashPath"; + case duckdb_libpgquery::T_PGAppendPath: + return "T_AppendPath"; + case duckdb_libpgquery::T_PGMergeAppendPath: + return "T_MergeAppendPath"; + case duckdb_libpgquery::T_PGResultPath: + return "T_ResultPath"; + case duckdb_libpgquery::T_PGMaterialPath: + return "T_MaterialPath"; + case duckdb_libpgquery::T_PGUniquePath: + return "T_UniquePath"; + case duckdb_libpgquery::T_PGGatherPath: + return "T_GatherPath"; + case duckdb_libpgquery::T_PGGatherMergePath: + return "T_GatherMergePath"; + case duckdb_libpgquery::T_PGProjectionPath: + return "T_ProjectionPath"; + case duckdb_libpgquery::T_PGProjectSetPath: + return "T_ProjectSetPath"; + case duckdb_libpgquery::T_PGSortPath: + return "T_SortPath"; + case duckdb_libpgquery::T_PGGroupPath: + return "T_GroupPath"; + case duckdb_libpgquery::T_PGUpperUniquePath: + return "T_UpperUniquePath"; + case duckdb_libpgquery::T_PGAggPath: + return "T_AggPath"; + case duckdb_libpgquery::T_PGGroupingSetsPath: + return "T_GroupingSetsPath"; + case duckdb_libpgquery::T_PGMinMaxAggPath: + return "T_MinMaxAggPath"; + case duckdb_libpgquery::T_PGWindowAggPath: + return "T_WindowAggPath"; + case duckdb_libpgquery::T_PGSetOpPath: + return "T_SetOpPath"; + case duckdb_libpgquery::T_PGRecursiveUnionPath: + return "T_RecursiveUnionPath"; + case duckdb_libpgquery::T_PGLockRowsPath: + return "T_LockRowsPath"; + case duckdb_libpgquery::T_PGModifyTablePath: + return "T_ModifyTablePath"; + case duckdb_libpgquery::T_PGLimitPath: + return "T_LimitPath"; + case duckdb_libpgquery::T_PGEquivalenceClass: + return "T_EquivalenceClass"; + case duckdb_libpgquery::T_PGEquivalenceMember: + return "T_EquivalenceMember"; + case duckdb_libpgquery::T_PGPathKey: + return "T_PathKey"; + case duckdb_libpgquery::T_PGPathTarget: + return "T_PathTarget"; + case duckdb_libpgquery::T_PGRestrictInfo: + return "T_RestrictInfo"; + case duckdb_libpgquery::T_PGPlaceHolderVar: + return "T_PlaceHolderVar"; + case duckdb_libpgquery::T_PGSpecialJoinInfo: + return "T_SpecialJoinInfo"; + case duckdb_libpgquery::T_PGAppendRelInfo: + return "T_AppendRelInfo"; + case duckdb_libpgquery::T_PGPartitionedChildRelInfo: + return "T_PartitionedChildRelInfo"; + case duckdb_libpgquery::T_PGPlaceHolderInfo: + return "T_PlaceHolderInfo"; + case duckdb_libpgquery::T_PGMinMaxAggInfo: + return "T_MinMaxAggInfo"; + case duckdb_libpgquery::T_PGPlannerParamItem: + return "T_PlannerParamItem"; + case duckdb_libpgquery::T_PGRollupData: + return "T_RollupData"; + case duckdb_libpgquery::T_PGGroupingSetData: + return "T_GroupingSetData"; + case duckdb_libpgquery::T_PGStatisticExtInfo: + return "T_StatisticExtInfo"; + case duckdb_libpgquery::T_PGMemoryContext: + return "T_MemoryContext"; + case duckdb_libpgquery::T_PGAllocSetContext: + return "T_AllocSetContext"; + case duckdb_libpgquery::T_PGSlabContext: + return "T_SlabContext"; + case duckdb_libpgquery::T_PGValue: + return "T_Value"; + case duckdb_libpgquery::T_PGInteger: + return "T_Integer"; + case duckdb_libpgquery::T_PGFloat: + return "T_Float"; + case duckdb_libpgquery::T_PGString: + return "T_String"; + case duckdb_libpgquery::T_PGBitString: + return "T_BitString"; + case duckdb_libpgquery::T_PGNull: + return "T_Null"; + case duckdb_libpgquery::T_PGList: + return "T_List"; + case duckdb_libpgquery::T_PGIntList: + return "T_IntList"; + case duckdb_libpgquery::T_PGOidList: + return "T_OidList"; + case duckdb_libpgquery::T_PGExtensibleNode: + return "T_ExtensibleNode"; + case duckdb_libpgquery::T_PGRawStmt: + return "T_RawStmt"; + case duckdb_libpgquery::T_PGQuery: + return "T_Query"; + case duckdb_libpgquery::T_PGPlannedStmt: + return "T_PlannedStmt"; + case duckdb_libpgquery::T_PGInsertStmt: + return "T_InsertStmt"; + case duckdb_libpgquery::T_PGDeleteStmt: + return "T_DeleteStmt"; + case duckdb_libpgquery::T_PGUpdateStmt: + return "T_UpdateStmt"; + case duckdb_libpgquery::T_PGSelectStmt: + return "T_SelectStmt"; + case duckdb_libpgquery::T_PGAlterTableStmt: + return "T_AlterTableStmt"; + case duckdb_libpgquery::T_PGAlterTableCmd: + return "T_AlterTableCmd"; + case duckdb_libpgquery::T_PGAlterDomainStmt: + return "T_AlterDomainStmt"; + case duckdb_libpgquery::T_PGSetOperationStmt: + return "T_SetOperationStmt"; + case duckdb_libpgquery::T_PGGrantStmt: + return "T_GrantStmt"; + case duckdb_libpgquery::T_PGGrantRoleStmt: + return "T_GrantRoleStmt"; + case duckdb_libpgquery::T_PGAlterDefaultPrivilegesStmt: + return "T_AlterDefaultPrivilegesStmt"; + case duckdb_libpgquery::T_PGClosePortalStmt: + return "T_ClosePortalStmt"; + case duckdb_libpgquery::T_PGClusterStmt: + return "T_ClusterStmt"; + case duckdb_libpgquery::T_PGCopyStmt: + return "T_CopyStmt"; + case duckdb_libpgquery::T_PGCreateStmt: + return "T_CreateStmt"; + case duckdb_libpgquery::T_PGDefineStmt: + return "T_DefineStmt"; + case duckdb_libpgquery::T_PGDropStmt: + return "T_DropStmt"; + case duckdb_libpgquery::T_PGTruncateStmt: + return "T_TruncateStmt"; + case duckdb_libpgquery::T_PGCommentStmt: + return "T_CommentStmt"; + case duckdb_libpgquery::T_PGFetchStmt: + return "T_FetchStmt"; + case duckdb_libpgquery::T_PGIndexStmt: + return "T_IndexStmt"; + case duckdb_libpgquery::T_PGCreateFunctionStmt: + return "T_CreateFunctionStmt"; + case duckdb_libpgquery::T_PGAlterFunctionStmt: + return "T_AlterFunctionStmt"; + case duckdb_libpgquery::T_PGDoStmt: + return "T_DoStmt"; + case duckdb_libpgquery::T_PGRenameStmt: + return "T_RenameStmt"; + case duckdb_libpgquery::T_PGRuleStmt: + return "T_RuleStmt"; + case duckdb_libpgquery::T_PGNotifyStmt: + return "T_NotifyStmt"; + case duckdb_libpgquery::T_PGListenStmt: + return "T_ListenStmt"; + case duckdb_libpgquery::T_PGUnlistenStmt: + return "T_UnlistenStmt"; + case duckdb_libpgquery::T_PGTransactionStmt: + return "T_TransactionStmt"; + case duckdb_libpgquery::T_PGViewStmt: + return "T_ViewStmt"; + case duckdb_libpgquery::T_PGLoadStmt: + return "T_LoadStmt"; + case duckdb_libpgquery::T_PGCreateDomainStmt: + return "T_CreateDomainStmt"; + case duckdb_libpgquery::T_PGCreatedbStmt: + return "T_CreatedbStmt"; + case duckdb_libpgquery::T_PGDropdbStmt: + return "T_DropdbStmt"; + case duckdb_libpgquery::T_PGVacuumStmt: + return "T_VacuumStmt"; + case duckdb_libpgquery::T_PGExplainStmt: + return "T_ExplainStmt"; + case duckdb_libpgquery::T_PGCreateTableAsStmt: + return "T_CreateTableAsStmt"; + case duckdb_libpgquery::T_PGCreateSeqStmt: + return "T_CreateSeqStmt"; + case duckdb_libpgquery::T_PGAlterSeqStmt: + return "T_AlterSeqStmt"; + case duckdb_libpgquery::T_PGVariableSetStmt: + return "T_VariableSetStmt"; + case duckdb_libpgquery::T_PGVariableShowStmt: + return "T_VariableShowStmt"; + case duckdb_libpgquery::T_PGVariableShowSelectStmt: + return "T_VariableShowSelectStmt"; + case duckdb_libpgquery::T_PGDiscardStmt: + return "T_DiscardStmt"; + case duckdb_libpgquery::T_PGCreateTrigStmt: + return "T_CreateTrigStmt"; + case duckdb_libpgquery::T_PGCreatePLangStmt: + return "T_CreatePLangStmt"; + case duckdb_libpgquery::T_PGCreateRoleStmt: + return "T_CreateRoleStmt"; + case duckdb_libpgquery::T_PGAlterRoleStmt: + return "T_AlterRoleStmt"; + case duckdb_libpgquery::T_PGDropRoleStmt: + return "T_DropRoleStmt"; + case duckdb_libpgquery::T_PGLockStmt: + return "T_LockStmt"; + case duckdb_libpgquery::T_PGConstraintsSetStmt: + return "T_ConstraintsSetStmt"; + case duckdb_libpgquery::T_PGReindexStmt: + return "T_ReindexStmt"; + case duckdb_libpgquery::T_PGCheckPointStmt: + return "T_CheckPointStmt"; + case duckdb_libpgquery::T_PGCreateSchemaStmt: + return "T_CreateSchemaStmt"; + case duckdb_libpgquery::T_PGAlterDatabaseStmt: + return "T_AlterDatabaseStmt"; + case duckdb_libpgquery::T_PGAlterDatabaseSetStmt: + return "T_AlterDatabaseSetStmt"; + case duckdb_libpgquery::T_PGAlterRoleSetStmt: + return "T_AlterRoleSetStmt"; + case duckdb_libpgquery::T_PGCreateConversionStmt: + return "T_CreateConversionStmt"; + case duckdb_libpgquery::T_PGCreateCastStmt: + return "T_CreateCastStmt"; + case duckdb_libpgquery::T_PGCreateOpClassStmt: + return "T_CreateOpClassStmt"; + case duckdb_libpgquery::T_PGCreateOpFamilyStmt: + return "T_CreateOpFamilyStmt"; + case duckdb_libpgquery::T_PGAlterOpFamilyStmt: + return "T_AlterOpFamilyStmt"; + case duckdb_libpgquery::T_PGPrepareStmt: + return "T_PrepareStmt"; + case duckdb_libpgquery::T_PGExecuteStmt: + return "T_ExecuteStmt"; + case duckdb_libpgquery::T_PGCallStmt: + return "T_CallStmt"; + case duckdb_libpgquery::T_PGDeallocateStmt: + return "T_DeallocateStmt"; + case duckdb_libpgquery::T_PGDeclareCursorStmt: + return "T_DeclareCursorStmt"; + case duckdb_libpgquery::T_PGCreateTableSpaceStmt: + return "T_CreateTableSpaceStmt"; + case duckdb_libpgquery::T_PGDropTableSpaceStmt: + return "T_DropTableSpaceStmt"; + case duckdb_libpgquery::T_PGAlterObjectDependsStmt: + return "T_AlterObjectDependsStmt"; + case duckdb_libpgquery::T_PGAlterObjectSchemaStmt: + return "T_AlterObjectSchemaStmt"; + case duckdb_libpgquery::T_PGAlterOwnerStmt: + return "T_AlterOwnerStmt"; + case duckdb_libpgquery::T_PGAlterOperatorStmt: + return "T_AlterOperatorStmt"; + case duckdb_libpgquery::T_PGDropOwnedStmt: + return "T_DropOwnedStmt"; + case duckdb_libpgquery::T_PGReassignOwnedStmt: + return "T_ReassignOwnedStmt"; + case duckdb_libpgquery::T_PGCompositeTypeStmt: + return "T_CompositeTypeStmt"; + case duckdb_libpgquery::T_PGCreateTypeStmt: + return "T_CreateTypeStmt"; + case duckdb_libpgquery::T_PGCreateRangeStmt: + return "T_CreateRangeStmt"; + case duckdb_libpgquery::T_PGAlterEnumStmt: + return "T_AlterEnumStmt"; + case duckdb_libpgquery::T_PGAlterTSDictionaryStmt: + return "T_AlterTSDictionaryStmt"; + case duckdb_libpgquery::T_PGAlterTSConfigurationStmt: + return "T_AlterTSConfigurationStmt"; + case duckdb_libpgquery::T_PGCreateFdwStmt: + return "T_CreateFdwStmt"; + case duckdb_libpgquery::T_PGAlterFdwStmt: + return "T_AlterFdwStmt"; + case duckdb_libpgquery::T_PGCreateForeignServerStmt: + return "T_CreateForeignServerStmt"; + case duckdb_libpgquery::T_PGAlterForeignServerStmt: + return "T_AlterForeignServerStmt"; + case duckdb_libpgquery::T_PGCreateUserMappingStmt: + return "T_CreateUserMappingStmt"; + case duckdb_libpgquery::T_PGAlterUserMappingStmt: + return "T_AlterUserMappingStmt"; + case duckdb_libpgquery::T_PGDropUserMappingStmt: + return "T_DropUserMappingStmt"; + case duckdb_libpgquery::T_PGAlterTableSpaceOptionsStmt: + return "T_AlterTableSpaceOptionsStmt"; + case duckdb_libpgquery::T_PGAlterTableMoveAllStmt: + return "T_AlterTableMoveAllStmt"; + case duckdb_libpgquery::T_PGSecLabelStmt: + return "T_SecLabelStmt"; + case duckdb_libpgquery::T_PGCreateForeignTableStmt: + return "T_CreateForeignTableStmt"; + case duckdb_libpgquery::T_PGImportForeignSchemaStmt: + return "T_ImportForeignSchemaStmt"; + case duckdb_libpgquery::T_PGCreateExtensionStmt: + return "T_CreateExtensionStmt"; + case duckdb_libpgquery::T_PGAlterExtensionStmt: + return "T_AlterExtensionStmt"; + case duckdb_libpgquery::T_PGAlterExtensionContentsStmt: + return "T_AlterExtensionContentsStmt"; + case duckdb_libpgquery::T_PGCreateEventTrigStmt: + return "T_CreateEventTrigStmt"; + case duckdb_libpgquery::T_PGAlterEventTrigStmt: + return "T_AlterEventTrigStmt"; + case duckdb_libpgquery::T_PGRefreshMatViewStmt: + return "T_RefreshMatViewStmt"; + case duckdb_libpgquery::T_PGReplicaIdentityStmt: + return "T_ReplicaIdentityStmt"; + case duckdb_libpgquery::T_PGAlterSystemStmt: + return "T_AlterSystemStmt"; + case duckdb_libpgquery::T_PGCreatePolicyStmt: + return "T_CreatePolicyStmt"; + case duckdb_libpgquery::T_PGAlterPolicyStmt: + return "T_AlterPolicyStmt"; + case duckdb_libpgquery::T_PGCreateTransformStmt: + return "T_CreateTransformStmt"; + case duckdb_libpgquery::T_PGCreateAmStmt: + return "T_CreateAmStmt"; + case duckdb_libpgquery::T_PGCreatePublicationStmt: + return "T_CreatePublicationStmt"; + case duckdb_libpgquery::T_PGAlterPublicationStmt: + return "T_AlterPublicationStmt"; + case duckdb_libpgquery::T_PGCreateSubscriptionStmt: + return "T_CreateSubscriptionStmt"; + case duckdb_libpgquery::T_PGAlterSubscriptionStmt: + return "T_AlterSubscriptionStmt"; + case duckdb_libpgquery::T_PGDropSubscriptionStmt: + return "T_DropSubscriptionStmt"; + case duckdb_libpgquery::T_PGCreateStatsStmt: + return "T_CreateStatsStmt"; + case duckdb_libpgquery::T_PGAlterCollationStmt: + return "T_AlterCollationStmt"; + case duckdb_libpgquery::T_PGAExpr: + return "TAExpr"; + case duckdb_libpgquery::T_PGColumnRef: + return "T_ColumnRef"; + case duckdb_libpgquery::T_PGParamRef: + return "T_ParamRef"; + case duckdb_libpgquery::T_PGAConst: + return "TAConst"; + case duckdb_libpgquery::T_PGFuncCall: + return "T_FuncCall"; + case duckdb_libpgquery::T_PGAStar: + return "TAStar"; + case duckdb_libpgquery::T_PGAIndices: + return "TAIndices"; + case duckdb_libpgquery::T_PGAIndirection: + return "TAIndirection"; + case duckdb_libpgquery::T_PGAArrayExpr: + return "TAArrayExpr"; + case duckdb_libpgquery::T_PGResTarget: + return "T_ResTarget"; + case duckdb_libpgquery::T_PGMultiAssignRef: + return "T_MultiAssignRef"; + case duckdb_libpgquery::T_PGTypeCast: + return "T_TypeCast"; + case duckdb_libpgquery::T_PGCollateClause: + return "T_CollateClause"; + case duckdb_libpgquery::T_PGSortBy: + return "T_SortBy"; + case duckdb_libpgquery::T_PGWindowDef: + return "T_WindowDef"; + case duckdb_libpgquery::T_PGRangeSubselect: + return "T_RangeSubselect"; + case duckdb_libpgquery::T_PGRangeFunction: + return "T_RangeFunction"; + case duckdb_libpgquery::T_PGRangeTableSample: + return "T_RangeTableSample"; + case duckdb_libpgquery::T_PGRangeTableFunc: + return "T_RangeTableFunc"; + case duckdb_libpgquery::T_PGRangeTableFuncCol: + return "T_RangeTableFuncCol"; + case duckdb_libpgquery::T_PGTypeName: + return "T_TypeName"; + case duckdb_libpgquery::T_PGColumnDef: + return "T_ColumnDef"; + case duckdb_libpgquery::T_PGIndexElem: + return "T_IndexElem"; + case duckdb_libpgquery::T_PGConstraint: + return "T_Constraint"; + case duckdb_libpgquery::T_PGDefElem: + return "T_DefElem"; + case duckdb_libpgquery::T_PGRangeTblEntry: + return "T_RangeTblEntry"; + case duckdb_libpgquery::T_PGRangeTblFunction: + return "T_RangeTblFunction"; + case duckdb_libpgquery::T_PGTableSampleClause: + return "T_TableSampleClause"; + case duckdb_libpgquery::T_PGWithCheckOption: + return "T_WithCheckOption"; + case duckdb_libpgquery::T_PGSortGroupClause: + return "T_SortGroupClause"; + case duckdb_libpgquery::T_PGGroupingSet: + return "T_GroupingSet"; + case duckdb_libpgquery::T_PGWindowClause: + return "T_WindowClause"; + case duckdb_libpgquery::T_PGObjectWithArgs: + return "T_ObjectWithArgs"; + case duckdb_libpgquery::T_PGAccessPriv: + return "T_AccessPriv"; + case duckdb_libpgquery::T_PGCreateOpClassItem: + return "T_CreateOpClassItem"; + case duckdb_libpgquery::T_PGTableLikeClause: + return "T_TableLikeClause"; + case duckdb_libpgquery::T_PGFunctionParameter: + return "T_FunctionParameter"; + case duckdb_libpgquery::T_PGLockingClause: + return "T_LockingClause"; + case duckdb_libpgquery::T_PGRowMarkClause: + return "T_RowMarkClause"; + case duckdb_libpgquery::T_PGXmlSerialize: + return "T_XmlSerialize"; + case duckdb_libpgquery::T_PGWithClause: + return "T_WithClause"; + case duckdb_libpgquery::T_PGInferClause: + return "T_InferClause"; + case duckdb_libpgquery::T_PGOnConflictClause: + return "T_OnConflictClause"; + case duckdb_libpgquery::T_PGCommonTableExpr: + return "T_CommonTableExpr"; + case duckdb_libpgquery::T_PGRoleSpec: + return "T_RoleSpec"; + case duckdb_libpgquery::T_PGTriggerTransition: + return "T_TriggerTransition"; + case duckdb_libpgquery::T_PGPartitionElem: + return "T_PartitionElem"; + case duckdb_libpgquery::T_PGPartitionSpec: + return "T_PartitionSpec"; + case duckdb_libpgquery::T_PGPartitionBoundSpec: + return "T_PartitionBoundSpec"; + case duckdb_libpgquery::T_PGPartitionRangeDatum: + return "T_PartitionRangeDatum"; + case duckdb_libpgquery::T_PGPartitionCmd: + return "T_PartitionCmd"; + case duckdb_libpgquery::T_PGIdentifySystemCmd: + return "T_IdentifySystemCmd"; + case duckdb_libpgquery::T_PGBaseBackupCmd: + return "T_BaseBackupCmd"; + case duckdb_libpgquery::T_PGCreateReplicationSlotCmd: + return "T_CreateReplicationSlotCmd"; + case duckdb_libpgquery::T_PGDropReplicationSlotCmd: + return "T_DropReplicationSlotCmd"; + case duckdb_libpgquery::T_PGStartReplicationCmd: + return "T_StartReplicationCmd"; + case duckdb_libpgquery::T_PGTimeLineHistoryCmd: + return "T_TimeLineHistoryCmd"; + case duckdb_libpgquery::T_PGSQLCmd: + return "T_SQLCmd"; + case duckdb_libpgquery::T_PGTriggerData: + return "T_TriggerData"; + case duckdb_libpgquery::T_PGEventTriggerData: + return "T_EventTriggerData"; + case duckdb_libpgquery::T_PGReturnSetInfo: + return "T_ReturnSetInfo"; + case duckdb_libpgquery::T_PGWindowObjectData: + return "T_WindowObjectData"; + case duckdb_libpgquery::T_PGTIDBitmap: + return "T_TIDBitmap"; + case duckdb_libpgquery::T_PGInlineCodeBlock: + return "T_InlineCodeBlock"; + case duckdb_libpgquery::T_PGFdwRoutine: + return "T_FdwRoutine"; + case duckdb_libpgquery::T_PGIndexAmRoutine: + return "T_IndexAmRoutine"; + case duckdb_libpgquery::T_PGTsmRoutine: + return "T_TsmRoutine"; + case duckdb_libpgquery::T_PGForeignKeyCacheInfo: + return "T_ForeignKeyCacheInfo"; + case duckdb_libpgquery::T_PGAttachStmt: + return "T_PGAttachStmt"; + case duckdb_libpgquery::T_PGUseStmt: + return "T_PGUseStmt"; + default: + return "(UNKNOWN)"; + } +} // LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/transform_alias.cpp b/src/duckdb/src/parser/transform/helpers/transform_alias.cpp new file mode 100644 index 00000000..637e94a5 --- /dev/null +++ b/src/duckdb/src/parser/transform/helpers/transform_alias.cpp @@ -0,0 +1,24 @@ +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +vector Transformer::TransformStringList(duckdb_libpgquery::PGList *list) { + vector result; + if (!list) { + return result; + } + for (auto node = list->head; node != nullptr; node = node->next) { + result.emplace_back(reinterpret_cast(node->data.ptr_value)->val.str); + } + return result; +} + +string Transformer::TransformAlias(duckdb_libpgquery::PGAlias *root, vector &column_name_alias) { + if (!root) { + return ""; + } + column_name_alias = TransformStringList(root->colnames); + return root->aliasname; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp new file mode 100644 index 00000000..dded0d7d --- /dev/null +++ b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp @@ -0,0 +1,140 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/query_node/recursive_cte_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" + +namespace duckdb { + +unique_ptr CommonTableExpressionInfo::Copy() { + auto result = make_uniq(); + result->aliases = aliases; + result->query = unique_ptr_cast(query->Copy()); + result->materialized = materialized; + return result; +} + +void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { + for (auto &cte_entry : stored_cte_map) { + for (auto &entry : cte_entry->map) { + auto found_entry = cte_map.map.find(entry.first); + if (found_entry != cte_map.map.end()) { + // entry already present - use top-most entry + continue; + } + cte_map.map[entry.first] = entry.second->Copy(); + } + } + if (parent) { + parent->ExtractCTEsRecursive(cte_map); + } +} + +void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map, + vector> &materialized_ctes) { + // TODO: might need to update in case of future lawsuit + stored_cte_map.push_back(&cte_map); + + D_ASSERT(de_with_clause.ctes); + for (auto cte_ele = de_with_clause.ctes->head; cte_ele != nullptr; cte_ele = cte_ele->next) { + auto info = make_uniq(); + + auto &cte = *PGPointerCast(cte_ele->data.ptr_value); + if (cte.aliascolnames) { + for (auto node = cte.aliascolnames->head; node != nullptr; node = node->next) { + info->aliases.emplace_back( + reinterpret_cast(node->data.ptr_value)->val.str); + } + } + // lets throw some errors on unsupported features early + if (cte.ctecolnames) { + throw NotImplementedException("Column name setting not supported in CTEs"); + } + if (cte.ctecoltypes) { + throw NotImplementedException("Column type setting not supported in CTEs"); + } + if (cte.ctecoltypmods) { + throw NotImplementedException("Column type modification not supported in CTEs"); + } + if (cte.ctecolcollations) { + throw NotImplementedException("CTE collations not supported"); + } + // we need a query + if (!cte.ctequery || cte.ctequery->type != duckdb_libpgquery::T_PGSelectStmt) { + throw NotImplementedException("A CTE needs a SELECT"); + } + + // CTE transformation can either result in inlining for non recursive CTEs, or in recursive CTE bindings + // otherwise. + if (cte.cterecursive || de_with_clause.recursive) { + info->query = TransformRecursiveCTE(cte, *info); + } else { + Transformer cte_transformer(*this); + info->query = + cte_transformer.TransformSelect(*PGPointerCast(cte.ctequery)); + } + D_ASSERT(info->query); + auto cte_name = string(cte.ctename); + + auto it = cte_map.map.find(cte_name); + if (it != cte_map.map.end()) { + // can't have two CTEs with same name + throw ParserException("Duplicate CTE name \"%s\"", cte_name); + } + +#ifdef DUCKDB_ALTERNATIVE_VERIFY + if (cte.ctematerialized == duckdb_libpgquery::PGCTEMaterializeDefault) { +#else + if (cte.ctematerialized == duckdb_libpgquery::PGCTEMaterializeAlways) { +#endif + auto materialize = make_uniq(); + materialize->query = info->query->node->Copy(); + materialize->ctename = cte_name; + materialize->aliases = info->aliases; + materialized_ctes.push_back(std::move(materialize)); + + info->materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS; + } + + cte_map.map[cte_name] = std::move(info); + } +} + +unique_ptr Transformer::TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &cte, + CommonTableExpressionInfo &info) { + auto &stmt = *PGPointerCast(cte.ctequery); + + unique_ptr select; + switch (stmt.op) { + case duckdb_libpgquery::PG_SETOP_UNION: + case duckdb_libpgquery::PG_SETOP_EXCEPT: + case duckdb_libpgquery::PG_SETOP_INTERSECT: { + select = make_uniq(); + select->node = make_uniq_base(); + auto &result = select->node->Cast(); + result.ctename = string(cte.ctename); + result.union_all = stmt.all; + result.left = TransformSelectNode(*PGPointerCast(stmt.larg)); + result.right = TransformSelectNode(*PGPointerCast(stmt.rarg)); + result.aliases = info.aliases; + if (stmt.op != duckdb_libpgquery::PG_SETOP_UNION) { + throw ParserException("Unsupported setop type for recursive CTE: only UNION or UNION ALL are supported"); + } + break; + } + default: + // This CTE is not recursive. Fallback to regular query transformation. + return TransformSelect(*PGPointerCast(cte.ctequery)); + } + + if (stmt.limitCount || stmt.limitOffset) { + throw ParserException("LIMIT or OFFSET in a recursive query is not allowed"); + } + if (stmt.sortClause) { + throw ParserException("ORDER BY in a recursive query is not allowed"); + } + return select; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/transform_groupby.cpp b/src/duckdb/src/parser/transform/helpers/transform_groupby.cpp new file mode 100644 index 00000000..c61071c4 --- /dev/null +++ b/src/duckdb/src/parser/transform/helpers/transform_groupby.cpp @@ -0,0 +1,189 @@ +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/parser/expression/function_expression.hpp" + +namespace duckdb { + +static void CheckGroupingSetMax(idx_t count) { + static constexpr const idx_t MAX_GROUPING_SETS = 65535; + if (count > MAX_GROUPING_SETS) { + throw ParserException("Maximum grouping set count of %d exceeded", MAX_GROUPING_SETS); + } +} + +static void CheckGroupingSetCubes(idx_t current_count, idx_t cube_count) { + idx_t combinations = 1; + for (idx_t i = 0; i < cube_count; i++) { + combinations *= 2; + CheckGroupingSetMax(current_count + combinations); + } +} + +struct GroupingExpressionMap { + parsed_expression_map_t map; +}; + +static GroupingSet VectorToGroupingSet(vector &indexes) { + GroupingSet result; + for (idx_t i = 0; i < indexes.size(); i++) { + result.insert(indexes[i]); + } + return result; +} + +static void MergeGroupingSet(GroupingSet &result, GroupingSet &other) { + CheckGroupingSetMax(result.size() + other.size()); + result.insert(other.begin(), other.end()); +} + +void Transformer::AddGroupByExpression(unique_ptr expression, GroupingExpressionMap &map, + GroupByNode &result, vector &result_set) { + if (expression->type == ExpressionType::FUNCTION) { + auto &func = expression->Cast(); + if (func.function_name == "row") { + for (auto &child : func.children) { + AddGroupByExpression(std::move(child), map, result, result_set); + } + return; + } + } + auto entry = map.map.find(*expression); + idx_t result_idx; + if (entry == map.map.end()) { + result_idx = result.group_expressions.size(); + map.map[*expression] = result_idx; + result.group_expressions.push_back(std::move(expression)); + } else { + result_idx = entry->second; + } + result_set.push_back(result_idx); +} + +static void AddCubeSets(const GroupingSet ¤t_set, vector &result_set, + vector &result_sets, idx_t start_idx = 0) { + CheckGroupingSetMax(result_sets.size()); + result_sets.push_back(current_set); + for (idx_t k = start_idx; k < result_set.size(); k++) { + auto child_set = current_set; + MergeGroupingSet(child_set, result_set[k]); + AddCubeSets(child_set, result_set, result_sets, k + 1); + } +} + +void Transformer::TransformGroupByExpression(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, + GroupByNode &result, vector &indexes) { + auto expression = TransformExpression(n); + AddGroupByExpression(std::move(expression), map, result, indexes); +} + +// If one GROUPING SETS clause is nested inside another, +// the effect is the same as if all the elements of the inner clause had been written directly in the outer clause. +void Transformer::TransformGroupByNode(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, SelectNode &result, + vector &result_sets) { + if (n.type == duckdb_libpgquery::T_PGGroupingSet) { + auto &grouping_set = PGCast(n); + switch (grouping_set.kind) { + case duckdb_libpgquery::GROUPING_SET_EMPTY: + result_sets.emplace_back(); + break; + case duckdb_libpgquery::GROUPING_SET_ALL: { + result.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; + break; + } + case duckdb_libpgquery::GROUPING_SET_SETS: { + for (auto node = grouping_set.content->head; node; node = node->next) { + auto pg_node = PGPointerCast(node->data.ptr_value); + TransformGroupByNode(*pg_node, map, result, result_sets); + } + break; + } + case duckdb_libpgquery::GROUPING_SET_ROLLUP: { + vector rollup_sets; + for (auto node = grouping_set.content->head; node; node = node->next) { + auto pg_node = PGPointerCast(node->data.ptr_value); + vector rollup_set; + TransformGroupByExpression(*pg_node, map, result.groups, rollup_set); + rollup_sets.push_back(VectorToGroupingSet(rollup_set)); + } + // generate the subsets of the rollup set and add them to the grouping sets + GroupingSet current_set; + result_sets.push_back(current_set); + for (idx_t i = 0; i < rollup_sets.size(); i++) { + MergeGroupingSet(current_set, rollup_sets[i]); + result_sets.push_back(current_set); + } + break; + } + case duckdb_libpgquery::GROUPING_SET_CUBE: { + vector cube_sets; + for (auto node = grouping_set.content->head; node; node = node->next) { + auto pg_node = PGPointerCast(node->data.ptr_value); + vector cube_set; + TransformGroupByExpression(*pg_node, map, result.groups, cube_set); + cube_sets.push_back(VectorToGroupingSet(cube_set)); + } + // generate the subsets of the rollup set and add them to the grouping sets + CheckGroupingSetCubes(result_sets.size(), cube_sets.size()); + + GroupingSet current_set; + AddCubeSets(current_set, cube_sets, result_sets, 0); + break; + } + default: + throw InternalException("Unsupported GROUPING SET type %d", grouping_set.kind); + } + } else { + vector indexes; + TransformGroupByExpression(n, map, result.groups, indexes); + result_sets.push_back(VectorToGroupingSet(indexes)); + } +} + +// If multiple grouping items are specified in a single GROUP BY clause, +// then the final list of grouping sets is the cross product of the individual items. +bool Transformer::TransformGroupBy(optional_ptr group, SelectNode &select_node) { + if (!group) { + return false; + } + auto &result = select_node.groups; + GroupingExpressionMap map; + for (auto node = group->head; node != nullptr; node = node->next) { + auto n = PGPointerCast(node->data.ptr_value); + vector result_sets; + TransformGroupByNode(*n, map, select_node, result_sets); + CheckGroupingSetMax(result_sets.size()); + if (result.grouping_sets.empty()) { + // no grouping sets yet: use the current set of grouping sets + result.grouping_sets = std::move(result_sets); + } else { + // compute the cross product + vector new_sets; + idx_t grouping_set_count = result.grouping_sets.size() * result_sets.size(); + CheckGroupingSetMax(grouping_set_count); + new_sets.reserve(grouping_set_count); + for (idx_t current_idx = 0; current_idx < result.grouping_sets.size(); current_idx++) { + auto ¤t_set = result.grouping_sets[current_idx]; + for (idx_t new_idx = 0; new_idx < result_sets.size(); new_idx++) { + auto &new_set = result_sets[new_idx]; + GroupingSet set; + set.insert(current_set.begin(), current_set.end()); + set.insert(new_set.begin(), new_set.end()); + new_sets.push_back(std::move(set)); + } + } + result.grouping_sets = std::move(new_sets); + } + } + if (result.group_expressions.size() == 1 && result.grouping_sets.size() == 1 && + ExpressionIsEmptyStar(*result.group_expressions[0])) { + // GROUP BY * + result.group_expressions.clear(); + result.grouping_sets.clear(); + select_node.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/transform_orderby.cpp b/src/duckdb/src/parser/transform/helpers/transform_orderby.cpp new file mode 100644 index 00000000..d448e114 --- /dev/null +++ b/src/duckdb/src/parser/transform/helpers/transform_orderby.cpp @@ -0,0 +1,47 @@ +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/star_expression.hpp" + +namespace duckdb { + +bool Transformer::TransformOrderBy(duckdb_libpgquery::PGList *order, vector &result) { + if (!order) { + return false; + } + + for (auto node = order->head; node != nullptr; node = node->next) { + auto temp = reinterpret_cast(node->data.ptr_value); + if (temp->type == duckdb_libpgquery::T_PGSortBy) { + OrderType type; + OrderByNullType null_order; + auto sort = reinterpret_cast(temp); + auto target = sort->node; + if (sort->sortby_dir == duckdb_libpgquery::PG_SORTBY_DEFAULT) { + type = OrderType::ORDER_DEFAULT; + } else if (sort->sortby_dir == duckdb_libpgquery::PG_SORTBY_ASC) { + type = OrderType::ASCENDING; + } else if (sort->sortby_dir == duckdb_libpgquery::PG_SORTBY_DESC) { + type = OrderType::DESCENDING; + } else { + throw NotImplementedException("Unimplemented order by type"); + } + if (sort->sortby_nulls == duckdb_libpgquery::PG_SORTBY_NULLS_DEFAULT) { + null_order = OrderByNullType::ORDER_DEFAULT; + } else if (sort->sortby_nulls == duckdb_libpgquery::PG_SORTBY_NULLS_FIRST) { + null_order = OrderByNullType::NULLS_FIRST; + } else if (sort->sortby_nulls == duckdb_libpgquery::PG_SORTBY_NULLS_LAST) { + null_order = OrderByNullType::NULLS_LAST; + } else { + throw NotImplementedException("Unimplemented order by type"); + } + auto order_expression = TransformExpression(target); + result.emplace_back(type, null_order, std::move(order_expression)); + } else { + throw NotImplementedException("ORDER BY list member type %d\n", temp->type); + } + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/transform_sample.cpp b/src/duckdb/src/parser/transform/helpers/transform_sample.cpp new file mode 100644 index 00000000..0cffebfe --- /dev/null +++ b/src/duckdb/src/parser/transform/helpers/transform_sample.cpp @@ -0,0 +1,56 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +static SampleMethod GetSampleMethod(const string &method) { + auto lmethod = StringUtil::Lower(method); + if (lmethod == "system") { + return SampleMethod::SYSTEM_SAMPLE; + } else if (lmethod == "bernoulli") { + return SampleMethod::BERNOULLI_SAMPLE; + } else if (lmethod == "reservoir") { + return SampleMethod::RESERVOIR_SAMPLE; + } else { + throw ParserException("Unrecognized sampling method %s, expected system, bernoulli or reservoir", method); + } +} + +unique_ptr Transformer::TransformSampleOptions(optional_ptr options) { + if (!options) { + return nullptr; + } + auto result = make_uniq(); + auto &sample_options = PGCast(*options); + auto &sample_size = *PGPointerCast(sample_options.sample_size); + auto sample_value = TransformValue(sample_size.sample_size)->value; + result->is_percentage = sample_size.is_percentage; + if (sample_size.is_percentage) { + // sample size is given in sample_size: use system sampling + auto percentage = sample_value.GetValue(); + if (percentage < 0 || percentage > 100) { + throw ParserException("Sample sample_size %llf out of range, must be between 0 and 100", percentage); + } + result->sample_size = Value::DOUBLE(percentage); + result->method = SampleMethod::SYSTEM_SAMPLE; + } else { + // sample size is given in rows: use reservoir sampling + auto rows = sample_value.GetValue(); + if (rows < 0) { + throw ParserException("Sample rows %lld out of range, must be bigger than or equal to 0", rows); + } + result->sample_size = Value::BIGINT(rows); + result->method = SampleMethod::RESERVOIR_SAMPLE; + } + if (sample_options.method) { + result->method = GetSampleMethod(sample_options.method); + } + if (sample_options.has_seed) { + result->seed = sample_options.seed; + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/transform_typename.cpp b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp new file mode 100644 index 00000000..2527b55f --- /dev/null +++ b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp @@ -0,0 +1,234 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_name) { + if (type_name.type != duckdb_libpgquery::T_PGTypeName) { + throw ParserException("Expected a type"); + } + auto stack_checker = StackCheck(); + + auto name = PGPointerCast(type_name.names->tail->data.ptr_value)->val.str; + // transform it to the SQL type + LogicalTypeId base_type = TransformStringToLogicalTypeId(name); + + LogicalType result_type; + if (base_type == LogicalTypeId::LIST) { + throw ParserException("LIST is not valid as a stand-alone type"); + } else if (base_type == LogicalTypeId::ENUM) { + if (!type_name.typmods || type_name.typmods->length == 0) { + throw ParserException("Enum needs a set of entries"); + } + Vector enum_vector(LogicalType::VARCHAR, type_name.typmods->length); + auto string_data = FlatVector::GetData(enum_vector); + idx_t pos = 0; + for (auto node = type_name.typmods->head; node; node = node->next) { + auto constant_value = PGPointerCast(node->data.ptr_value); + if (constant_value->type != duckdb_libpgquery::T_PGAConst || + constant_value->val.type != duckdb_libpgquery::T_PGString) { + throw ParserException("Enum type requires a set of strings as type modifiers"); + } + string_data[pos++] = StringVector::AddString(enum_vector, constant_value->val.val.str); + } + return LogicalType::ENUM(enum_vector, type_name.typmods->length); + } else if (base_type == LogicalTypeId::STRUCT) { + if (!type_name.typmods || type_name.typmods->length == 0) { + throw ParserException("Struct needs a name and entries"); + } + child_list_t children; + case_insensitive_set_t name_collision_set; + + for (auto node = type_name.typmods->head; node; node = node->next) { + auto &type_val = *PGPointerCast(node->data.ptr_value); + if (type_val.length != 2) { + throw ParserException("Struct entry needs an entry name and a type name"); + } + + auto entry_name_node = PGPointerCast(type_val.head->data.ptr_value); + D_ASSERT(entry_name_node->type == duckdb_libpgquery::T_PGString); + auto entry_type_node = PGPointerCast(type_val.tail->data.ptr_value); + D_ASSERT(entry_type_node->type == duckdb_libpgquery::T_PGTypeName); + + auto entry_name = string(entry_name_node->val.str); + D_ASSERT(!entry_name.empty()); + + if (name_collision_set.find(entry_name) != name_collision_set.end()) { + throw ParserException("Duplicate struct entry name \"%s\"", entry_name); + } + name_collision_set.insert(entry_name); + auto entry_type = TransformTypeName(*entry_type_node); + + children.push_back(make_pair(entry_name, entry_type)); + } + D_ASSERT(!children.empty()); + result_type = LogicalType::STRUCT(children); + + } else if (base_type == LogicalTypeId::MAP) { + if (!type_name.typmods || type_name.typmods->length != 2) { + throw ParserException("Map type needs exactly two entries, key and value type"); + } + auto key_type = + TransformTypeName(*PGPointerCast(type_name.typmods->head->data.ptr_value)); + auto value_type = + TransformTypeName(*PGPointerCast(type_name.typmods->tail->data.ptr_value)); + + result_type = LogicalType::MAP(std::move(key_type), std::move(value_type)); + } else if (base_type == LogicalTypeId::UNION) { + if (!type_name.typmods || type_name.typmods->length == 0) { + throw ParserException("Union type needs at least one member"); + } + if (type_name.typmods->length > (int)UnionType::MAX_UNION_MEMBERS) { + throw ParserException("Union types can have at most %d members", UnionType::MAX_UNION_MEMBERS); + } + + child_list_t children; + case_insensitive_set_t name_collision_set; + + for (auto node = type_name.typmods->head; node; node = node->next) { + auto &type_val = *PGPointerCast(node->data.ptr_value); + if (type_val.length != 2) { + throw ParserException("Union type member needs a tag name and a type name"); + } + + auto entry_name_node = PGPointerCast(type_val.head->data.ptr_value); + D_ASSERT(entry_name_node->type == duckdb_libpgquery::T_PGString); + auto entry_type_node = PGPointerCast(type_val.tail->data.ptr_value); + D_ASSERT(entry_type_node->type == duckdb_libpgquery::T_PGTypeName); + + auto entry_name = string(entry_name_node->val.str); + D_ASSERT(!entry_name.empty()); + + if (name_collision_set.find(entry_name) != name_collision_set.end()) { + throw ParserException("Duplicate union type tag name \"%s\"", entry_name); + } + + name_collision_set.insert(entry_name); + + auto entry_type = TransformTypeName(*entry_type_node); + children.push_back(make_pair(entry_name, entry_type)); + } + D_ASSERT(!children.empty()); + result_type = LogicalType::UNION(std::move(children)); + } else { + int64_t width, scale; + if (base_type == LogicalTypeId::DECIMAL) { + // default decimal width/scale + width = 18; + scale = 3; + } else { + width = 0; + scale = 0; + } + // check any modifiers + int modifier_idx = 0; + if (type_name.typmods) { + for (auto node = type_name.typmods->head; node; node = node->next) { + auto &const_val = *PGPointerCast(node->data.ptr_value); + if (const_val.type != duckdb_libpgquery::T_PGAConst || + const_val.val.type != duckdb_libpgquery::T_PGInteger) { + throw ParserException("Expected an integer constant as type modifier"); + } + if (const_val.val.val.ival < 0) { + throw ParserException("Negative modifier not supported"); + } + if (modifier_idx == 0) { + width = const_val.val.val.ival; + if (base_type == LogicalTypeId::BIT && const_val.location != -1) { + width = 0; + } + } else if (modifier_idx == 1) { + scale = const_val.val.val.ival; + } else { + throw ParserException("A maximum of two modifiers is supported"); + } + modifier_idx++; + } + } + switch (base_type) { + case LogicalTypeId::VARCHAR: + if (modifier_idx > 1) { + throw ParserException("VARCHAR only supports a single modifier"); + } + // FIXME: create CHECK constraint based on varchar width + width = 0; + result_type = LogicalType::VARCHAR; + break; + case LogicalTypeId::DECIMAL: + if (modifier_idx == 1) { + // only width is provided: set scale to 0 + scale = 0; + } + if (width <= 0 || width > Decimal::MAX_WIDTH_DECIMAL) { + throw ParserException("Width must be between 1 and %d!", (int)Decimal::MAX_WIDTH_DECIMAL); + } + if (scale > width) { + throw ParserException("Scale cannot be bigger than width"); + } + result_type = LogicalType::DECIMAL(width, scale); + break; + case LogicalTypeId::INTERVAL: + if (modifier_idx > 1) { + throw ParserException("INTERVAL only supports a single modifier"); + } + width = 0; + result_type = LogicalType::INTERVAL; + break; + case LogicalTypeId::USER: { + string user_type_name {name}; + result_type = LogicalType::USER(user_type_name); + break; + } + case LogicalTypeId::BIT: { + if (!width && type_name.typmods) { + throw ParserException("Type %s does not support any modifiers!", LogicalType(base_type).ToString()); + } + result_type = LogicalType(base_type); + break; + } + case LogicalTypeId::TIMESTAMP: + if (modifier_idx == 0) { + result_type = LogicalType::TIMESTAMP; + } else { + if (modifier_idx > 1) { + throw ParserException("TIMESTAMP only supports a single modifier"); + } + if (width > 10) { + throw ParserException("TIMESTAMP only supports until nano-second precision (9)"); + } + if (width == 0) { + result_type = LogicalType::TIMESTAMP_S; + } else if (width <= 3) { + result_type = LogicalType::TIMESTAMP_MS; + } else if (width <= 6) { + result_type = LogicalType::TIMESTAMP; + } else { + result_type = LogicalType::TIMESTAMP_NS; + } + } + break; + default: + if (modifier_idx > 0) { + throw ParserException("Type %s does not support any modifiers!", LogicalType(base_type).ToString()); + } + result_type = LogicalType(base_type); + break; + } + } + if (type_name.arrayBounds) { + // array bounds: turn the type into a list + idx_t extra_stack = 0; + for (auto cell = type_name.arrayBounds->head; cell != nullptr; cell = cell->next) { + result_type = LogicalType::LIST(result_type); + StackCheck(extra_stack++); + } + } + return result_type; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_alter_sequence.cpp b/src/duckdb/src/parser/transform/statement/transform_alter_sequence.cpp new file mode 100644 index 00000000..b80c5a92 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_alter_sequence.cpp @@ -0,0 +1,71 @@ +#include "duckdb/common/enum_class_hash.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/parser/statement/alter_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformAlterSequence(duckdb_libpgquery::PGAlterSeqStmt &stmt) { + auto result = make_uniq(); + + auto qname = TransformQualifiedName(*stmt.sequence); + auto sequence_catalog = qname.catalog; + auto sequence_schema = qname.schema; + auto sequence_name = qname.name; + + if (!stmt.options) { + throw InternalException("Expected an argument for ALTER SEQUENCE."); + } + + unordered_set used; + duckdb_libpgquery::PGListCell *cell; + for_each_cell(cell, stmt.options->head) { + auto def_elem = PGPointerCast(cell->data.ptr_value); + string opt_name = string(def_elem->defname); + + if (opt_name == "owned_by") { + if (used.find(SequenceInfo::SEQ_OWN) != used.end()) { + throw ParserException("Owned by value should be passed as most once"); + } + used.insert(SequenceInfo::SEQ_OWN); + + auto val = PGPointerCast(def_elem->arg); + if (!val) { + throw InternalException("Expected an argument for option %s", opt_name); + } + D_ASSERT(val); + if (val->type != duckdb_libpgquery::T_PGList) { + throw InternalException("Expected a string argument for option %s", opt_name); + } + auto opt_values = vector(); + + for (auto c = val->head; c != nullptr; c = lnext(c)) { + auto target = PGPointerCast(c->data.ptr_value); + opt_values.emplace_back(target->name); + } + D_ASSERT(!opt_values.empty()); + string owner_schema = INVALID_SCHEMA; + string owner_name; + if (opt_values.size() == 2) { + owner_schema = opt_values[0]; + owner_name = opt_values[1]; + } else if (opt_values.size() == 1) { + owner_schema = DEFAULT_SCHEMA; + owner_name = opt_values[0]; + } else { + throw InternalException("Wrong argument for %s. Expected either . or ", opt_name); + } + auto info = make_uniq(CatalogType::SEQUENCE_ENTRY, sequence_catalog, sequence_schema, + sequence_name, owner_schema, owner_name, + TransformOnEntryNotFound(stmt.missing_ok)); + result->info = std::move(info); + } else { + throw NotImplementedException("ALTER SEQUENCE option not supported yet!"); + } + } + result->info->if_not_found = TransformOnEntryNotFound(stmt.missing_ok); + return result; +} +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp b/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp new file mode 100644 index 00000000..ba899ba0 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp @@ -0,0 +1,105 @@ +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/statement/alter_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +OnEntryNotFound Transformer::TransformOnEntryNotFound(bool missing_ok) { + return missing_ok ? OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; +} + +unique_ptr Transformer::TransformAlter(duckdb_libpgquery::PGAlterTableStmt &stmt) { + D_ASSERT(stmt.relation); + + if (stmt.cmds->length != 1) { + throw ParserException("Only one ALTER command per statement is supported"); + } + + auto result = make_uniq(); + auto qname = TransformQualifiedName(*stmt.relation); + + // first we check the type of ALTER + for (auto c = stmt.cmds->head; c != nullptr; c = c->next) { + auto command = reinterpret_cast(lfirst(c)); + AlterEntryData data(qname.catalog, qname.schema, qname.name, TransformOnEntryNotFound(stmt.missing_ok)); + // TODO: Include more options for command->subtype + switch (command->subtype) { + case duckdb_libpgquery::PG_AT_AddColumn: { + auto cdef = PGPointerCast(command->def); + + if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { + throw ParserException("Adding columns is only supported for tables"); + } + if (cdef->category == duckdb_libpgquery::COL_GENERATED) { + throw ParserException("Adding generated columns after table creation is not supported yet"); + } + auto centry = TransformColumnDefinition(*cdef); + + if (cdef->constraints) { + for (auto constr = cdef->constraints->head; constr != nullptr; constr = constr->next) { + auto constraint = TransformConstraint(constr, centry, 0); + if (!constraint) { + continue; + } + throw ParserException("Adding columns with constraints not yet supported"); + } + } + result->info = make_uniq(std::move(data), std::move(centry), command->missing_ok); + break; + } + case duckdb_libpgquery::PG_AT_DropColumn: { + bool cascade = command->behavior == duckdb_libpgquery::PG_DROP_CASCADE; + + if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { + throw ParserException("Dropping columns is only supported for tables"); + } + result->info = make_uniq(std::move(data), command->name, command->missing_ok, cascade); + break; + } + case duckdb_libpgquery::PG_AT_ColumnDefault: { + auto expr = TransformExpression(command->def); + + if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { + throw ParserException("Alter column's default is only supported for tables"); + } + result->info = make_uniq(std::move(data), command->name, std::move(expr)); + break; + } + case duckdb_libpgquery::PG_AT_AlterColumnType: { + auto cdef = PGPointerCast(command->def); + auto column_definition = TransformColumnDefinition(*cdef); + unique_ptr expr; + + if (stmt.relkind != duckdb_libpgquery::PG_OBJECT_TABLE) { + throw ParserException("Alter column's type is only supported for tables"); + } + if (cdef->raw_default) { + expr = TransformExpression(cdef->raw_default); + } else { + auto colref = make_uniq(command->name); + expr = make_uniq(column_definition.Type(), std::move(colref)); + } + result->info = make_uniq(std::move(data), command->name, column_definition.Type(), + std::move(expr)); + break; + } + case duckdb_libpgquery::PG_AT_SetNotNull: { + result->info = make_uniq(std::move(data), command->name); + break; + } + case duckdb_libpgquery::PG_AT_DropNotNull: { + result->info = make_uniq(std::move(data), command->name); + break; + } + case duckdb_libpgquery::PG_AT_DropConstraint: + default: + throw NotImplementedException("ALTER TABLE option not supported yet!"); + } + } + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_attach.cpp b/src/duckdb/src/parser/transform/statement/transform_attach.cpp new file mode 100644 index 00000000..b3ecde30 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_attach.cpp @@ -0,0 +1,31 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/statement/attach_statement.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformAttach(duckdb_libpgquery::PGAttachStmt &stmt) { + auto result = make_uniq(); + auto info = make_uniq(); + info->name = stmt.name ? stmt.name : string(); + info->path = stmt.path; + + if (stmt.options) { + duckdb_libpgquery::PGListCell *cell; + for_each_cell(cell, stmt.options->head) { + auto def_elem = PGPointerCast(cell->data.ptr_value); + Value val; + if (def_elem->arg) { + val = TransformValue(*PGPointerCast(def_elem->arg))->value; + } else { + val = Value::BOOLEAN(true); + } + info->options[StringUtil::Lower(def_elem->defname)] = std::move(val); + } + } + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_call.cpp b/src/duckdb/src/parser/transform/statement/transform_call.cpp new file mode 100644 index 00000000..54d1a4f3 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_call.cpp @@ -0,0 +1,12 @@ +#include "duckdb/parser/statement/call_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCall(duckdb_libpgquery::PGCallStmt &stmt) { + auto result = make_uniq(); + result->function = TransformFuncCall(*PGPointerCast(stmt.func)); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_checkpoint.cpp b/src/duckdb/src/parser/transform/statement/transform_checkpoint.cpp new file mode 100644 index 00000000..e52d1df2 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_checkpoint.cpp @@ -0,0 +1,21 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/statement/call_statement.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCheckpoint(duckdb_libpgquery::PGCheckPointStmt &stmt) { + vector> children; + // transform into "CALL checkpoint()" or "CALL force_checkpoint()" + auto checkpoint_name = stmt.force ? "force_checkpoint" : "checkpoint"; + auto result = make_uniq(); + auto function = make_uniq(checkpoint_name, std::move(children)); + if (stmt.name) { + function->children.push_back(make_uniq(Value(stmt.name))); + } + result->function = std::move(function); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_copy.cpp b/src/duckdb/src/parser/transform/statement/transform_copy.cpp new file mode 100644 index 00000000..41c8bea8 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_copy.cpp @@ -0,0 +1,119 @@ +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/core_functions/scalar/struct_functions.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/transformer.hpp" + +#include + +namespace duckdb { + +void Transformer::TransformCopyOptions(CopyInfo &info, optional_ptr options) { + if (!options) { + return; + } + + // iterate over each option + duckdb_libpgquery::PGListCell *cell; + for_each_cell(cell, options->head) { + auto def_elem = PGPointerCast(cell->data.ptr_value); + if (StringUtil::Lower(def_elem->defname) == "format") { + // format specifier: interpret this option + auto format_val = PGPointerCast(def_elem->arg); + if (!format_val || format_val->type != duckdb_libpgquery::T_PGString) { + throw ParserException("Unsupported parameter type for FORMAT: expected e.g. FORMAT 'csv', 'parquet'"); + } + info.format = StringUtil::Lower(format_val->val.str); + continue; + } + // otherwise + if (info.options.find(def_elem->defname) != info.options.end()) { + throw ParserException("Unexpected duplicate option \"%s\"", def_elem->defname); + } + if (!def_elem->arg) { + info.options[def_elem->defname] = vector(); + continue; + } + switch (def_elem->arg->type) { + case duckdb_libpgquery::T_PGList: { + auto column_list = PGPointerCast(def_elem->arg); + for (auto c = column_list->head; c != nullptr; c = lnext(c)) { + auto target = PGPointerCast(c->data.ptr_value); + info.options[def_elem->defname].push_back(Value(target->name)); + } + break; + } + case duckdb_libpgquery::T_PGAStar: + info.options[def_elem->defname].push_back(Value("*")); + break; + case duckdb_libpgquery::T_PGFuncCall: { + auto func_call = PGPointerCast(def_elem->arg); + auto func_expr = TransformFuncCall(*func_call); + + Value value; + if (!Transformer::ConstructConstantFromExpression(*func_expr, value)) { + throw ParserException("Unsupported expression in COPY options: %s", func_expr->ToString()); + } + info.options[def_elem->defname].push_back(std::move(value)); + break; + } + default: { + auto val = PGPointerCast(def_elem->arg); + info.options[def_elem->defname].push_back(TransformValue(*val)->value); + break; + } + } + } +} + +unique_ptr Transformer::TransformCopy(duckdb_libpgquery::PGCopyStmt &stmt) { + auto result = make_uniq(); + auto &info = *result->info; + + // get file_path and is_from + info.is_from = stmt.is_from; + if (!stmt.filename) { + // stdin/stdout + info.file_path = info.is_from ? "/dev/stdin" : "/dev/stdout"; + } else { + // copy to a file + info.file_path = stmt.filename; + } + if (StringUtil::EndsWith(info.file_path, ".parquet")) { + info.format = "parquet"; + } else if (StringUtil::EndsWith(info.file_path, ".json") || StringUtil::EndsWith(info.file_path, ".ndjson")) { + info.format = "json"; + } else { + info.format = "csv"; + } + + // get select_list + if (stmt.attlist) { + for (auto n = stmt.attlist->head; n != nullptr; n = n->next) { + auto target = PGPointerCast(n->data.ptr_value); + if (target->name) { + info.select_list.emplace_back(target->name); + } + } + } + + if (stmt.relation) { + auto ref = TransformRangeVar(*stmt.relation); + auto &table = ref->Cast(); + info.table = table.table_name; + info.schema = table.schema_name; + info.catalog = table.catalog_name; + } else { + result->select_statement = TransformSelectNode(*PGPointerCast(stmt.query)); + } + + // handle the different options of the COPY statement + TransformCopyOptions(info, stmt.options); + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_function.cpp b/src/duckdb/src/parser/transform/statement/transform_create_function.cpp new file mode 100644 index 00000000..4bdc21d2 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_function.cpp @@ -0,0 +1,83 @@ +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/transformer.hpp" + +#include "duckdb/function/scalar_macro_function.hpp" +#include "duckdb/function/table_macro_function.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCreateFunction(duckdb_libpgquery::PGCreateFunctionStmt &stmt) { + D_ASSERT(stmt.type == duckdb_libpgquery::T_PGCreateFunctionStmt); + D_ASSERT(stmt.function || stmt.query); + + auto result = make_uniq(); + auto qname = TransformQualifiedName(*stmt.name); + + unique_ptr macro_func; + + // function can be null here + if (stmt.function) { + auto expression = TransformExpression(stmt.function); + macro_func = make_uniq(std::move(expression)); + } else if (stmt.query) { + auto query_node = + TransformSelect(*PGPointerCast(stmt.query), true)->node->Copy(); + macro_func = make_uniq(std::move(query_node)); + } + PivotEntryCheck("macro"); + + auto info = make_uniq(stmt.function ? CatalogType::MACRO_ENTRY : CatalogType::TABLE_MACRO_ENTRY); + info->catalog = qname.catalog; + info->schema = qname.schema; + info->name = qname.name; + + // temporary macro + switch (stmt.name->relpersistence) { + case duckdb_libpgquery::PG_RELPERSISTENCE_TEMP: + info->temporary = true; + break; + case duckdb_libpgquery::PG_RELPERSISTENCE_UNLOGGED: + throw ParserException("Unlogged flag not supported for macros: '%s'", qname.name); + break; + case duckdb_libpgquery::RELPERSISTENCE_PERMANENT: + info->temporary = false; + break; + } + + // what to do on conflict + info->on_conflict = TransformOnConflict(stmt.onconflict); + + if (stmt.params) { + vector> parameters; + TransformExpressionList(*stmt.params, parameters); + for (auto ¶m : parameters) { + if (param->type == ExpressionType::VALUE_CONSTANT) { + // parameters with default value (must have an alias) + if (param->alias.empty()) { + throw ParserException("Invalid parameter: '%s'", param->ToString()); + } + if (macro_func->default_parameters.find(param->alias) != macro_func->default_parameters.end()) { + throw ParserException("Duplicate default parameter: '%s'", param->alias); + } + macro_func->default_parameters[param->alias] = std::move(param); + } else if (param->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + // positional parameters + if (!macro_func->default_parameters.empty()) { + throw ParserException("Positional parameters cannot come after parameters with a default value!"); + } + macro_func->parameters.push_back(std::move(param)); + } else { + throw ParserException("Invalid parameter: '%s'", param->ToString()); + } + } + } + + info->function = std::move(macro_func); + result->info = std::move(info); + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_index.cpp b/src/duckdb/src/parser/transform/statement/transform_create_index.cpp new file mode 100644 index 00000000..b2267518 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_index.cpp @@ -0,0 +1,93 @@ +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +vector> Transformer::TransformIndexParameters(duckdb_libpgquery::PGList &list, + const string &relation_name) { + vector> expressions; + for (auto cell = list.head; cell != nullptr; cell = cell->next) { + auto index_element = PGPointerCast(cell->data.ptr_value); + if (index_element->collation) { + throw NotImplementedException("Index with collation not supported yet!"); + } + if (index_element->opclass) { + throw NotImplementedException("Index with opclass not supported yet!"); + } + + if (index_element->name) { + // create a column reference expression + expressions.push_back(make_uniq(index_element->name, relation_name)); + } else { + // parse the index expression + D_ASSERT(index_element->expr); + expressions.push_back(TransformExpression(index_element->expr)); + } + } + return expressions; +} + +unique_ptr Transformer::TransformCreateIndex(duckdb_libpgquery::PGIndexStmt &stmt) { + auto result = make_uniq(); + auto info = make_uniq(); + if (stmt.unique) { + info->constraint_type = IndexConstraintType::UNIQUE; + } else { + info->constraint_type = IndexConstraintType::NONE; + } + + info->on_conflict = TransformOnConflict(stmt.onconflict); + + info->expressions = TransformIndexParameters(*stmt.indexParams, stmt.relation->relname); + + auto index_type_name = StringUtil::Upper(string(stmt.accessMethod)); + + if (index_type_name == "ART") { + info->index_type = IndexType::ART; + } else { + info->index_type = IndexType::EXTENSION; + } + + info->index_type_name = index_type_name; + + if (stmt.relation->schemaname) { + info->schema = stmt.relation->schemaname; + } + if (stmt.relation->catalogname) { + info->catalog = stmt.relation->catalogname; + } + info->table = stmt.relation->relname; + if (stmt.idxname) { + info->index_name = stmt.idxname; + } else { + throw NotImplementedException("Index without a name not supported yet!"); + } + + // Parse the options list + if (stmt.options) { + duckdb_libpgquery::PGListCell *cell; + for_each_cell(cell, stmt.options->head) { + auto def_elem = PGPointerCast(cell->data.ptr_value); + Value val; + if (def_elem->arg) { + val = TransformValue(*PGPointerCast(def_elem->arg))->value; + } else { + val = Value::BOOLEAN(true); + } + info->options[StringUtil::Lower(def_elem->defname)] = std::move(val); + } + } + + for (auto &expr : info->expressions) { + info->parsed_expressions.emplace_back(expr->Copy()); + } + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_schema.cpp b/src/duckdb/src/parser/transform/statement/transform_create_schema.cpp new file mode 100644 index 00000000..0dc343ab --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_schema.cpp @@ -0,0 +1,32 @@ +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCreateSchema(duckdb_libpgquery::PGCreateSchemaStmt &stmt) { + auto result = make_uniq(); + auto info = make_uniq(); + + D_ASSERT(stmt.schemaname); + info->catalog = stmt.catalogname ? stmt.catalogname : INVALID_CATALOG; + info->schema = stmt.schemaname; + info->on_conflict = TransformOnConflict(stmt.onconflict); + + if (stmt.schemaElts) { + // schema elements + for (auto cell = stmt.schemaElts->head; cell != nullptr; cell = cell->next) { + auto node = PGPointerCast(cell->data.ptr_value); + switch (node->type) { + case duckdb_libpgquery::T_PGCreateStmt: + case duckdb_libpgquery::T_PGViewStmt: + default: + throw NotImplementedException("Schema element not supported yet!"); + } + } + } + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_sequence.cpp b/src/duckdb/src/parser/transform/statement/transform_create_sequence.cpp new file mode 100644 index 00000000..cfc671f6 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_sequence.cpp @@ -0,0 +1,128 @@ +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/enum_class_hash.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/operator/cast_operators.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCreateSequence(duckdb_libpgquery::PGCreateSeqStmt &stmt) { + auto result = make_uniq(); + auto info = make_uniq(); + + auto qname = TransformQualifiedName(*stmt.sequence); + info->catalog = qname.catalog; + info->schema = qname.schema; + info->name = qname.name; + + if (stmt.options) { + unordered_set used; + duckdb_libpgquery::PGListCell *cell = nullptr; + for_each_cell(cell, stmt.options->head) { + auto def_elem = PGPointerCast(cell->data.ptr_value); + string opt_name = string(def_elem->defname); + auto val = PGPointerCast(def_elem->arg); + bool nodef = def_elem->defaction == duckdb_libpgquery::PG_DEFELEM_UNSPEC && !val; // e.g. NO MINVALUE + int64_t opt_value = 0; + + if (val) { + if (val->type == duckdb_libpgquery::T_PGInteger) { + opt_value = val->val.ival; + } else if (val->type == duckdb_libpgquery::T_PGFloat) { + if (!TryCast::Operation(string_t(val->val.str), opt_value, true)) { + throw ParserException("Expected an integer argument for option %s", opt_name); + } + } else { + throw ParserException("Expected an integer argument for option %s", opt_name); + } + } + if (opt_name == "increment") { + if (used.find(SequenceInfo::SEQ_INC) != used.end()) { + throw ParserException("Increment value should be passed as most once"); + } + used.insert(SequenceInfo::SEQ_INC); + if (nodef) { + continue; + } + + info->increment = opt_value; + if (info->increment == 0) { + throw ParserException("Increment must not be zero"); + } + if (info->increment < 0) { + info->start_value = info->max_value = -1; + info->min_value = NumericLimits::Minimum(); + } else { + info->start_value = info->min_value = 1; + info->max_value = NumericLimits::Maximum(); + } + } else if (opt_name == "minvalue") { + if (used.find(SequenceInfo::SEQ_MIN) != used.end()) { + throw ParserException("Minvalue should be passed as most once"); + } + used.insert(SequenceInfo::SEQ_MIN); + if (nodef) { + continue; + } + + info->min_value = opt_value; + if (info->increment > 0) { + info->start_value = info->min_value; + } + } else if (opt_name == "maxvalue") { + if (used.find(SequenceInfo::SEQ_MAX) != used.end()) { + throw ParserException("Maxvalue should be passed as most once"); + } + used.insert(SequenceInfo::SEQ_MAX); + if (nodef) { + continue; + } + + info->max_value = opt_value; + if (info->increment < 0) { + info->start_value = info->max_value; + } + } else if (opt_name == "start") { + if (used.find(SequenceInfo::SEQ_START) != used.end()) { + throw ParserException("Start value should be passed as most once"); + } + used.insert(SequenceInfo::SEQ_START); + if (nodef) { + continue; + } + + info->start_value = opt_value; + } else if (opt_name == "cycle") { + if (used.find(SequenceInfo::SEQ_CYCLE) != used.end()) { + throw ParserException("Cycle value should be passed as most once"); + } + used.insert(SequenceInfo::SEQ_CYCLE); + if (nodef) { + continue; + } + + info->cycle = opt_value > 0; + } else { + throw ParserException("Unrecognized option \"%s\" for CREATE SEQUENCE", opt_name); + } + } + } + info->temporary = !stmt.sequence->relpersistence; + info->on_conflict = TransformOnConflict(stmt.onconflict); + if (info->max_value <= info->min_value) { + throw ParserException("MINVALUE (%lld) must be less than MAXVALUE (%lld)", info->min_value, info->max_value); + } + if (info->start_value < info->min_value) { + throw ParserException("START value (%lld) cannot be less than MINVALUE (%lld)", info->start_value, + info->min_value); + } + if (info->start_value > info->max_value) { + throw ParserException("START value (%lld) cannot be greater than MAXVALUE (%lld)", info->start_value, + info->max_value); + } + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_table.cpp b/src/duckdb/src/parser/transform/statement/transform_create_table.cpp new file mode 100644 index 00000000..8c707060 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_table.cpp @@ -0,0 +1,132 @@ +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/expression/collate_expression.hpp" +#include "duckdb/catalog/catalog_entry/table_column_type.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +string Transformer::TransformCollation(optional_ptr collate) { + if (!collate) { + return string(); + } + string collation; + for (auto c = collate->collname->head; c != nullptr; c = lnext(c)) { + auto pgvalue = PGPointerCast(c->data.ptr_value); + if (pgvalue->type != duckdb_libpgquery::T_PGString) { + throw ParserException("Expected a string as collation type!"); + } + auto collation_argument = string(pgvalue->val.str); + if (collation.empty()) { + collation = collation_argument; + } else { + collation += "." + collation_argument; + } + } + return collation; +} + +OnCreateConflict Transformer::TransformOnConflict(duckdb_libpgquery::PGOnCreateConflict conflict) { + switch (conflict) { + case duckdb_libpgquery::PG_ERROR_ON_CONFLICT: + return OnCreateConflict::ERROR_ON_CONFLICT; + case duckdb_libpgquery::PG_IGNORE_ON_CONFLICT: + return OnCreateConflict::IGNORE_ON_CONFLICT; + case duckdb_libpgquery::PG_REPLACE_ON_CONFLICT: + return OnCreateConflict::REPLACE_ON_CONFLICT; + default: + throw InternalException("Unrecognized OnConflict type"); + } +} + +unique_ptr Transformer::TransformCollateExpr(duckdb_libpgquery::PGCollateClause &collate) { + auto child = TransformExpression(collate.arg); + auto collation = TransformCollation(&collate); + return make_uniq(collation, std::move(child)); +} + +ColumnDefinition Transformer::TransformColumnDefinition(duckdb_libpgquery::PGColumnDef &cdef) { + string colname; + if (cdef.colname) { + colname = cdef.colname; + } + bool optional_type = cdef.category == duckdb_libpgquery::COL_GENERATED; + LogicalType target_type = (optional_type && !cdef.typeName) ? LogicalType::ANY : TransformTypeName(*cdef.typeName); + if (cdef.collClause) { + if (cdef.category == duckdb_libpgquery::COL_GENERATED) { + throw ParserException("Collations are not supported on generated columns"); + } + if (target_type.id() != LogicalTypeId::VARCHAR) { + throw ParserException("Only VARCHAR columns can have collations!"); + } + target_type = LogicalType::VARCHAR_COLLATION(TransformCollation(cdef.collClause)); + } + + return ColumnDefinition(colname, target_type); +} + +unique_ptr Transformer::TransformCreateTable(duckdb_libpgquery::PGCreateStmt &stmt) { + auto result = make_uniq(); + auto info = make_uniq(); + + if (stmt.inhRelations) { + throw NotImplementedException("inherited relations not implemented"); + } + D_ASSERT(stmt.relation); + + info->catalog = INVALID_CATALOG; + auto qname = TransformQualifiedName(*stmt.relation); + info->catalog = qname.catalog; + info->schema = qname.schema; + info->table = qname.name; + info->on_conflict = TransformOnConflict(stmt.onconflict); + info->temporary = + stmt.relation->relpersistence == duckdb_libpgquery::PGPostgresRelPersistence::PG_RELPERSISTENCE_TEMP; + + if (info->temporary && stmt.oncommit != duckdb_libpgquery::PGOnCommitAction::PG_ONCOMMIT_PRESERVE_ROWS && + stmt.oncommit != duckdb_libpgquery::PGOnCommitAction::PG_ONCOMMIT_NOOP) { + throw NotImplementedException("Only ON COMMIT PRESERVE ROWS is supported"); + } + if (!stmt.tableElts) { + throw ParserException("Table must have at least one column!"); + } + + idx_t column_count = 0; + for (auto c = stmt.tableElts->head; c != nullptr; c = lnext(c)) { + auto node = PGPointerCast(c->data.ptr_value); + switch (node->type) { + case duckdb_libpgquery::T_PGColumnDef: { + auto cdef = PGPointerCast(c->data.ptr_value); + auto centry = TransformColumnDefinition(*cdef); + if (cdef->constraints) { + for (auto constr = cdef->constraints->head; constr != nullptr; constr = constr->next) { + auto constraint = TransformConstraint(constr, centry, info->columns.LogicalColumnCount()); + if (constraint) { + info->constraints.push_back(std::move(constraint)); + } + } + } + info->columns.AddColumn(std::move(centry)); + column_count++; + break; + } + case duckdb_libpgquery::T_PGConstraint: { + info->constraints.push_back(TransformConstraint(c)); + break; + } + default: + throw NotImplementedException("ColumnDef type not handled yet"); + } + } + + if (!column_count) { + throw ParserException("Table must have at least one column!"); + } + + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_table_as.cpp b/src/duckdb/src/parser/transform/statement/transform_create_table_as.cpp new file mode 100644 index 00000000..0c3406c2 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_table_as.cpp @@ -0,0 +1,33 @@ +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCreateTableAs(duckdb_libpgquery::PGCreateTableAsStmt &stmt) { + if (stmt.relkind == duckdb_libpgquery::PG_OBJECT_MATVIEW) { + throw NotImplementedException("Materialized view not implemented"); + } + if (stmt.is_select_into || stmt.into->colNames || stmt.into->options) { + throw NotImplementedException("Unimplemented features for CREATE TABLE as"); + } + auto qname = TransformQualifiedName(*stmt.into->rel); + if (stmt.query->type != duckdb_libpgquery::T_PGSelectStmt) { + throw ParserException("CREATE TABLE AS requires a SELECT clause"); + } + auto query = TransformSelect(stmt.query, false); + + auto result = make_uniq(); + auto info = make_uniq(); + info->catalog = qname.catalog; + info->schema = qname.schema; + info->table = qname.name; + info->on_conflict = TransformOnConflict(stmt.onconflict); + info->temporary = + stmt.into->rel->relpersistence == duckdb_libpgquery::PGPostgresRelPersistence::PG_RELPERSISTENCE_TEMP; + info->query = std::move(query); + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_type.cpp b/src/duckdb/src/parser/transform/statement/transform_create_type.cpp new file mode 100644 index 00000000..3235ed3b --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_type.cpp @@ -0,0 +1,74 @@ +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +Vector Transformer::PGListToVector(optional_ptr column_list, idx_t &size) { + if (!column_list) { + Vector result(LogicalType::VARCHAR); + return result; + } + // First we discover the size of this list + for (auto c = column_list->head; c != nullptr; c = lnext(c)) { + size++; + } + + Vector result(LogicalType::VARCHAR, size); + auto result_ptr = FlatVector::GetData(result); + + size = 0; + for (auto c = column_list->head; c != nullptr; c = lnext(c)) { + auto &type_val = *PGPointerCast(c->data.ptr_value); + auto &entry_value_node = type_val.val; + if (entry_value_node.type != duckdb_libpgquery::T_PGString) { + throw ParserException("Expected a string constant as value"); + } + + auto entry_value = string(entry_value_node.val.str); + D_ASSERT(!entry_value.empty()); + result_ptr[size++] = StringVector::AddStringOrBlob(result, entry_value); + } + return result; +} + +unique_ptr Transformer::TransformCreateType(duckdb_libpgquery::PGCreateTypeStmt &stmt) { + auto result = make_uniq(); + auto info = make_uniq(); + + auto qualified_name = TransformQualifiedName(*stmt.typeName); + info->catalog = qualified_name.catalog; + info->schema = qualified_name.schema; + info->name = qualified_name.name; + + switch (stmt.kind) { + case duckdb_libpgquery::PG_NEWTYPE_ENUM: { + info->internal = false; + if (stmt.query) { + // CREATE TYPE mood AS ENUM (SELECT ...) + D_ASSERT(stmt.vals == nullptr); + auto query = TransformSelect(stmt.query, false); + info->query = std::move(query); + info->type = LogicalType::INVALID; + } else { + D_ASSERT(stmt.query == nullptr); + idx_t size = 0; + auto ordered_array = PGListToVector(stmt.vals, size); + info->type = LogicalType::ENUM(ordered_array, size); + } + } break; + + case duckdb_libpgquery::PG_NEWTYPE_ALIAS: { + LogicalType target_type = TransformTypeName(*stmt.ofType); + info->type = target_type; + } break; + + default: + throw InternalException("Unknown kind of new type"); + } + result->info = std::move(info); + return result; +} +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_create_view.cpp b/src/duckdb/src/parser/transform/statement/transform_create_view.cpp new file mode 100644 index 00000000..0f3cac00 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_create_view.cpp @@ -0,0 +1,56 @@ +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformCreateView(duckdb_libpgquery::PGViewStmt &stmt) { + D_ASSERT(stmt.type == duckdb_libpgquery::T_PGViewStmt); + D_ASSERT(stmt.view); + + auto result = make_uniq(); + auto info = make_uniq(); + + auto qname = TransformQualifiedName(*stmt.view); + info->catalog = qname.catalog; + info->schema = qname.schema; + info->view_name = qname.name; + info->temporary = !stmt.view->relpersistence; + if (info->temporary && IsInvalidCatalog(info->catalog)) { + info->catalog = TEMP_CATALOG; + } + info->on_conflict = TransformOnConflict(stmt.onconflict); + + info->query = TransformSelect(*PGPointerCast(stmt.query), false); + + PivotEntryCheck("view"); + + if (stmt.aliases && stmt.aliases->length > 0) { + for (auto c = stmt.aliases->head; c != nullptr; c = lnext(c)) { + auto val = PGPointerCast(c->data.ptr_value); + switch (val->type) { + case duckdb_libpgquery::T_PGString: { + info->aliases.emplace_back(val->val.str); + break; + } + default: + throw NotImplementedException("View projection type"); + } + } + if (info->aliases.empty()) { + throw ParserException("Need at least one column name in CREATE VIEW projection list"); + } + } + + if (stmt.options && stmt.options->length > 0) { + throw NotImplementedException("VIEW options"); + } + + if (stmt.withCheckOption != duckdb_libpgquery::PGViewCheckOption::PG_NO_CHECK_OPTION) { + throw NotImplementedException("VIEW CHECK options"); + } + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_delete.cpp b/src/duckdb/src/parser/transform/statement/transform_delete.cpp new file mode 100644 index 00000000..48cde7b4 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_delete.cpp @@ -0,0 +1,36 @@ +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformDelete(duckdb_libpgquery::PGDeleteStmt &stmt) { + auto result = make_uniq(); + vector> materialized_ctes; + if (stmt.withClause) { + TransformCTE(*PGPointerCast(stmt.withClause), result->cte_map, + materialized_ctes); + if (!materialized_ctes.empty()) { + throw NotImplementedException("Materialized CTEs are not implemented for delete."); + } + } + + result->condition = TransformExpression(stmt.whereClause); + result->table = TransformRangeVar(*stmt.relation); + if (result->table->type != TableReferenceType::BASE_TABLE) { + throw Exception("Can only delete from base tables!"); + } + if (stmt.usingClause) { + for (auto n = stmt.usingClause->head; n != nullptr; n = n->next) { + auto target = PGPointerCast(n->data.ptr_value); + auto using_entry = TransformTableRefNode(*target); + result->using_clauses.push_back(std::move(using_entry)); + } + } + + if (stmt.returningList) { + TransformExpressionList(*stmt.returningList, result->returning_list); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_detach.cpp b/src/duckdb/src/parser/transform/statement/transform_detach.cpp new file mode 100644 index 00000000..8a3e00bd --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_detach.cpp @@ -0,0 +1,18 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/statement/detach_statement.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformDetach(duckdb_libpgquery::PGDetachStmt &stmt) { + auto result = make_uniq(); + auto info = make_uniq(); + info->name = stmt.db_name; + info->if_not_found = TransformOnEntryNotFound(stmt.missing_ok); + + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_drop.cpp b/src/duckdb/src/parser/transform/statement/transform_drop.cpp new file mode 100644 index 00000000..d93f822e --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_drop.cpp @@ -0,0 +1,82 @@ +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformDrop(duckdb_libpgquery::PGDropStmt &stmt) { + auto result = make_uniq(); + auto &info = *result->info.get(); + if (stmt.objects->length != 1) { + throw NotImplementedException("Can only drop one object at a time"); + } + switch (stmt.removeType) { + case duckdb_libpgquery::PG_OBJECT_TABLE: + info.type = CatalogType::TABLE_ENTRY; + break; + case duckdb_libpgquery::PG_OBJECT_SCHEMA: + info.type = CatalogType::SCHEMA_ENTRY; + break; + case duckdb_libpgquery::PG_OBJECT_INDEX: + info.type = CatalogType::INDEX_ENTRY; + break; + case duckdb_libpgquery::PG_OBJECT_VIEW: + info.type = CatalogType::VIEW_ENTRY; + break; + case duckdb_libpgquery::PG_OBJECT_SEQUENCE: + info.type = CatalogType::SEQUENCE_ENTRY; + break; + case duckdb_libpgquery::PG_OBJECT_FUNCTION: + info.type = CatalogType::MACRO_ENTRY; + break; + case duckdb_libpgquery::PG_OBJECT_TABLE_MACRO: + info.type = CatalogType::TABLE_MACRO_ENTRY; + break; + case duckdb_libpgquery::PG_OBJECT_TYPE: + info.type = CatalogType::TYPE_ENTRY; + break; + default: + throw NotImplementedException("Cannot drop this type yet"); + } + + switch (stmt.removeType) { + case duckdb_libpgquery::PG_OBJECT_TYPE: { + auto view_list = PGPointerCast(stmt.objects); + auto target = PGPointerCast(view_list->head->data.ptr_value); + info.name = PGPointerCast(target->names->tail->data.ptr_value)->val.str; + break; + } + case duckdb_libpgquery::PG_OBJECT_SCHEMA: { + auto view_list = PGPointerCast(stmt.objects->head->data.ptr_value); + if (view_list->length == 2) { + info.catalog = PGPointerCast(view_list->head->data.ptr_value)->val.str; + info.name = PGPointerCast(view_list->head->next->data.ptr_value)->val.str; + } else if (view_list->length == 1) { + info.name = PGPointerCast(view_list->head->data.ptr_value)->val.str; + } else { + throw ParserException("Expected \"catalog.schema\" or \"schema\""); + } + break; + } + default: { + auto view_list = PGPointerCast(stmt.objects->head->data.ptr_value); + if (view_list->length == 3) { + info.catalog = PGPointerCast(view_list->head->data.ptr_value)->val.str; + info.schema = PGPointerCast(view_list->head->next->data.ptr_value)->val.str; + info.name = PGPointerCast(view_list->head->next->next->data.ptr_value)->val.str; + } else if (view_list->length == 2) { + info.schema = PGPointerCast(view_list->head->data.ptr_value)->val.str; + info.name = PGPointerCast(view_list->head->next->data.ptr_value)->val.str; + } else if (view_list->length == 1) { + info.name = PGPointerCast(view_list->head->data.ptr_value)->val.str; + } else { + throw ParserException("Expected \"catalog.schema.name\", \"schema.name\"or \"name\""); + } + break; + } + } + info.cascade = stmt.behavior == duckdb_libpgquery::PGDropBehavior::PG_DROP_CASCADE; + info.if_not_found = TransformOnEntryNotFound(stmt.missing_ok); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_explain.cpp b/src/duckdb/src/parser/transform/statement/transform_explain.cpp new file mode 100644 index 00000000..c68c69e7 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_explain.cpp @@ -0,0 +1,22 @@ +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformExplain(duckdb_libpgquery::PGExplainStmt &stmt) { + auto explain_type = ExplainType::EXPLAIN_STANDARD; + if (stmt.options) { + for (auto n = stmt.options->head; n; n = n->next) { + auto def_elem = PGPointerCast(n->data.ptr_value)->defname; + string elem(def_elem); + if (elem == "analyze") { + explain_type = ExplainType::EXPLAIN_ANALYZE; + } else { + throw NotImplementedException("Unimplemented explain type: %s", elem); + } + } + } + return make_uniq(TransformStatement(*stmt.query), explain_type); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_export.cpp b/src/duckdb/src/parser/transform/statement/transform_export.cpp new file mode 100644 index 00000000..24c819da --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_export.cpp @@ -0,0 +1,21 @@ +#include "duckdb/parser/statement/export_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformExport(duckdb_libpgquery::PGExportStmt &stmt) { + auto info = make_uniq(); + info->file_path = stmt.filename; + info->format = "csv"; + info->is_from = false; + // handle export options + TransformCopyOptions(*info, stmt.options); + + auto result = make_uniq(std::move(info)); + if (stmt.database) { + result->database = stmt.database; + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_import.cpp b/src/duckdb/src/parser/transform/statement/transform_import.cpp new file mode 100644 index 00000000..29092b2e --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_import.cpp @@ -0,0 +1,13 @@ +#include "duckdb/parser/statement/pragma_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformImport(duckdb_libpgquery::PGImportStmt &stmt) { + auto result = make_uniq(); + result->info->name = "import_database"; + result->info->parameters.emplace_back(stmt.filename); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_insert.cpp b/src/duckdb/src/parser/transform/statement/transform_insert.cpp new file mode 100644 index 00000000..30d7aeb6 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_insert.cpp @@ -0,0 +1,86 @@ +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformValuesList(duckdb_libpgquery::PGList *list) { + auto result = make_uniq(); + for (auto value_list = list->head; value_list != nullptr; value_list = value_list->next) { + auto target = PGPointerCast(value_list->data.ptr_value); + + vector> insert_values; + TransformExpressionList(*target, insert_values); + if (!result->values.empty()) { + if (result->values[0].size() != insert_values.size()) { + throw ParserException("VALUES lists must all be the same length"); + } + } + result->values.push_back(std::move(insert_values)); + } + result->alias = "valueslist"; + return std::move(result); +} + +unique_ptr Transformer::TransformInsert(duckdb_libpgquery::PGInsertStmt &stmt) { + auto result = make_uniq(); + vector> materialized_ctes; + if (stmt.withClause) { + TransformCTE(*PGPointerCast(stmt.withClause), result->cte_map, + materialized_ctes); + if (!materialized_ctes.empty()) { + throw NotImplementedException("Materialized CTEs are not implemented for insert."); + } + } + + // first check if there are any columns specified + if (stmt.cols) { + for (auto c = stmt.cols->head; c != nullptr; c = lnext(c)) { + auto target = PGPointerCast(c->data.ptr_value); + result->columns.emplace_back(target->name); + } + } + + // Grab and transform the returning columns from the parser. + if (stmt.returningList) { + TransformExpressionList(*stmt.returningList, result->returning_list); + } + if (stmt.selectStmt) { + result->select_statement = TransformSelect(stmt.selectStmt, false); + } else { + result->default_values = true; + } + + auto qname = TransformQualifiedName(*stmt.relation); + result->table = qname.name; + result->schema = qname.schema; + + if (stmt.onConflictClause) { + if (stmt.onConflictAlias != duckdb_libpgquery::PG_ONCONFLICT_ALIAS_NONE) { + // OR REPLACE | OR IGNORE are shorthands for the ON CONFLICT clause + throw ParserException("You can not provide both OR REPLACE|IGNORE and an ON CONFLICT clause, please remove " + "the first if you want to have more granual control"); + } + result->on_conflict_info = TransformOnConflictClause(stmt.onConflictClause, result->schema); + result->table_ref = TransformRangeVar(*stmt.relation); + } + if (stmt.onConflictAlias != duckdb_libpgquery::PG_ONCONFLICT_ALIAS_NONE) { + D_ASSERT(!stmt.onConflictClause); + result->on_conflict_info = DummyOnConflictClause(stmt.onConflictAlias, result->schema); + result->table_ref = TransformRangeVar(*stmt.relation); + } + switch (stmt.insert_column_order) { + case duckdb_libpgquery::PG_INSERT_BY_POSITION: + result->column_order = InsertColumnOrder::INSERT_BY_POSITION; + break; + case duckdb_libpgquery::PG_INSERT_BY_NAME: + result->column_order = InsertColumnOrder::INSERT_BY_NAME; + break; + default: + throw InternalException("Unrecognized insert column order in TransformInsert"); + } + result->catalog = qname.catalog; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_load.cpp b/src/duckdb/src/parser/transform/statement/transform_load.cpp new file mode 100644 index 00000000..4e8ec950 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_load.cpp @@ -0,0 +1,28 @@ +#include "duckdb/parser/statement/load_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformLoad(duckdb_libpgquery::PGLoadStmt &stmt) { + D_ASSERT(stmt.type == duckdb_libpgquery::T_PGLoadStmt); + + auto load_stmt = make_uniq(); + auto load_info = make_uniq(); + load_info->filename = std::string(stmt.filename); + load_info->repository = std::string(stmt.repository); + switch (stmt.load_type) { + case duckdb_libpgquery::PG_LOAD_TYPE_LOAD: + load_info->load_type = LoadType::LOAD; + break; + case duckdb_libpgquery::PG_LOAD_TYPE_INSTALL: + load_info->load_type = LoadType::INSTALL; + break; + case duckdb_libpgquery::PG_LOAD_TYPE_FORCE_INSTALL: + load_info->load_type = LoadType::FORCE_INSTALL; + break; + } + load_stmt->info = std::move(load_info); + return load_stmt; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp new file mode 100644 index 00000000..849af1a3 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp @@ -0,0 +1,208 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/statement/multi_statement.hpp" +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" + +namespace duckdb { + +void Transformer::AddPivotEntry(string enum_name, unique_ptr base, unique_ptr column, + unique_ptr subquery) { + if (parent) { + parent->AddPivotEntry(std::move(enum_name), std::move(base), std::move(column), std::move(subquery)); + return; + } + auto result = make_uniq(); + result->enum_name = std::move(enum_name); + result->base = std::move(base); + result->column = std::move(column); + result->subquery = std::move(subquery); + + pivot_entries.push_back(std::move(result)); +} + +bool Transformer::HasPivotEntries() { + return !GetPivotEntries().empty(); +} + +idx_t Transformer::PivotEntryCount() { + return GetPivotEntries().size(); +} + +vector> &Transformer::GetPivotEntries() { + if (parent) { + return parent->GetPivotEntries(); + } + return pivot_entries; +} + +void Transformer::PivotEntryCheck(const string &type) { + auto &entries = GetPivotEntries(); + if (!entries.empty()) { + throw ParserException( + "PIVOT statements with pivot elements extracted from the data cannot be used in %ss.\nIn order to use " + "PIVOT in a %s the PIVOT values must be manually specified, e.g.:\nPIVOT ... ON %s IN (val1, val2, ...)", + type, type, entries[0]->column->ToString()); + } +} +unique_ptr Transformer::GenerateCreateEnumStmt(unique_ptr entry) { + auto result = make_uniq(); + auto info = make_uniq(); + + info->temporary = true; + info->internal = false; + info->catalog = INVALID_CATALOG; + info->schema = INVALID_SCHEMA; + info->name = std::move(entry->enum_name); + info->on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; + + // generate the query that will result in the enum creation + unique_ptr subselect; + if (!entry->subquery) { + auto select_node = std::move(entry->base); + auto columnref = entry->column->Copy(); + auto cast = make_uniq(LogicalType::VARCHAR, std::move(columnref)); + select_node->select_list.push_back(std::move(cast)); + + auto is_not_null = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, std::move(entry->column)); + select_node->where_clause = std::move(is_not_null); + + // order by the column + select_node->modifiers.push_back(make_uniq()); + auto modifier = make_uniq(); + modifier->orders.emplace_back(OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, + make_uniq(Value::INTEGER(1))); + select_node->modifiers.push_back(std::move(modifier)); + subselect = std::move(select_node); + } else { + subselect = std::move(entry->subquery); + } + + auto select = make_uniq(); + select->node = std::move(subselect); + info->query = std::move(select); + info->type = LogicalType::INVALID; + + result->info = std::move(info); + return std::move(result); +} + +// unique_ptr GenerateDropEnumStmt(string enum_name) { +// auto result = make_uniq(); +// result->info->if_exists = true; +// result->info->schema = INVALID_SCHEMA; +// result->info->catalog = INVALID_CATALOG; +// result->info->name = std::move(enum_name); +// result->info->type = CatalogType::TYPE_ENTRY; +// return std::move(result); +//} + +unique_ptr Transformer::CreatePivotStatement(unique_ptr statement) { + auto result = make_uniq(); + for (auto &pivot : pivot_entries) { + result->statements.push_back(GenerateCreateEnumStmt(std::move(pivot))); + } + result->statements.push_back(std::move(statement)); + // FIXME: drop the types again!? + // for(auto &pivot : pivot_entries) { + // result->statements.push_back(GenerateDropEnumStmt(std::move(pivot->enum_name))); + // } + return std::move(result); +} + +unique_ptr Transformer::TransformPivotStatement(duckdb_libpgquery::PGSelectStmt &select) { + auto pivot = select.pivot; + auto source = TransformTableRefNode(*pivot->source); + + auto select_node = make_uniq(); + vector> materialized_ctes; + // handle the CTEs + if (select.withClause) { + TransformCTE(*PGPointerCast(select.withClause), select_node->cte_map, + materialized_ctes); + } + if (!pivot->columns) { + // no pivot columns - not actually a pivot + select_node->from_table = std::move(source); + if (pivot->groups) { + auto groups = TransformStringList(pivot->groups); + GroupingSet set; + for (idx_t gr = 0; gr < groups.size(); gr++) { + auto &group = groups[gr]; + auto colref = make_uniq(group); + select_node->select_list.push_back(colref->Copy()); + select_node->groups.group_expressions.push_back(std::move(colref)); + set.insert(gr); + } + select_node->groups.grouping_sets.push_back(std::move(set)); + } + if (pivot->aggrs) { + TransformExpressionList(*pivot->aggrs, select_node->select_list); + } + return std::move(select_node); + } + + // generate CREATE TYPE statements for each of the columns that do not have an IN list + auto columns = TransformPivotList(*pivot->columns); + auto pivot_idx = PivotEntryCount(); + for (idx_t c = 0; c < columns.size(); c++) { + auto &col = columns[c]; + if (!col.pivot_enum.empty() || !col.entries.empty()) { + continue; + } + if (col.pivot_expressions.size() != 1) { + throw InternalException("PIVOT statement with multiple names in pivot entry!?"); + } + auto enum_name = "__pivot_enum_" + std::to_string(pivot_idx) + "_" + std::to_string(c); + + auto new_select = make_uniq(); + ExtractCTEsRecursive(new_select->cte_map); + new_select->from_table = source->Copy(); + AddPivotEntry(enum_name, std::move(new_select), col.pivot_expressions[0]->Copy(), std::move(col.subquery)); + col.pivot_enum = enum_name; + } + + // generate the actual query, including the pivot + select_node->select_list.push_back(make_uniq()); + + auto pivot_ref = make_uniq(); + pivot_ref->source = std::move(source); + if (pivot->unpivots) { + pivot_ref->unpivot_names = TransformStringList(pivot->unpivots); + } else { + if (pivot->aggrs) { + TransformExpressionList(*pivot->aggrs, pivot_ref->aggregates); + } else { + // pivot but no aggregates specified - push a count star + vector> children; + auto function = make_uniq("count_star", std::move(children)); + pivot_ref->aggregates.push_back(std::move(function)); + } + } + if (pivot->groups) { + pivot_ref->groups = TransformStringList(pivot->groups); + } + pivot_ref->pivots = std::move(columns); + select_node->from_table = std::move(pivot_ref); + // transform order by/limit modifiers + TransformModifiers(select, *select_node); + + auto node = Transformer::TransformMaterializedCTE(std::move(select_node), materialized_ctes); + + return node; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_pragma.cpp b/src/duckdb/src/parser/transform/statement/transform_pragma.cpp new file mode 100644 index 00000000..6c9d25e2 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_pragma.cpp @@ -0,0 +1,86 @@ +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/statement/pragma_statement.hpp" +#include "duckdb/parser/statement/set_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformPragma(duckdb_libpgquery::PGPragmaStmt &stmt) { + auto result = make_uniq(); + auto &info = *result->info; + + info.name = stmt.name; + // parse the arguments, if any + if (stmt.args) { + for (auto cell = stmt.args->head; cell != nullptr; cell = cell->next) { + auto node = PGPointerCast(cell->data.ptr_value); + auto expr = TransformExpression(node); + + if (expr->type == ExpressionType::COMPARE_EQUAL) { + auto &comp = expr->Cast(); + if (comp.left->type != ExpressionType::COLUMN_REF) { + throw ParserException("Named parameter requires a column reference on the LHS"); + } + auto &columnref = comp.left->Cast(); + + Value rhs_value; + if (!Transformer::ConstructConstantFromExpression(*comp.right, rhs_value)) { + throw ParserException("Named parameter requires a constant on the RHS"); + } + + info.named_parameters[columnref.GetName()] = rhs_value; + } else if (node->type == duckdb_libpgquery::T_PGAConst) { + auto constant = TransformConstant(*PGPointerCast(node.get())); + info.parameters.push_back((constant->Cast()).value); + } else if (expr->type == ExpressionType::COLUMN_REF) { + auto &colref = expr->Cast(); + if (!colref.IsQualified()) { + info.parameters.emplace_back(colref.GetColumnName()); + } else { + info.parameters.emplace_back(expr->ToString()); + } + } else { + info.parameters.emplace_back(expr->ToString()); + } + } + } + // now parse the pragma type + switch (stmt.kind) { + case duckdb_libpgquery::PG_PRAGMA_TYPE_NOTHING: { + if (!info.parameters.empty() || !info.named_parameters.empty()) { + throw InternalException("PRAGMA statement that is not a call or assignment cannot contain parameters"); + } + break; + case duckdb_libpgquery::PG_PRAGMA_TYPE_ASSIGNMENT: + if (info.parameters.size() != 1) { + throw InternalException("PRAGMA statement with assignment should contain exactly one parameter"); + } + if (!info.named_parameters.empty()) { + throw InternalException("PRAGMA statement with assignment cannot have named parameters"); + } + // SQLite does not distinguish between: + // "PRAGMA table_info='integers'" + // "PRAGMA table_info('integers')" + // for compatibility, any pragmas that match the SQLite ones are parsed as calls + case_insensitive_set_t sqlite_compat_pragmas {"table_info"}; + if (sqlite_compat_pragmas.find(info.name) != sqlite_compat_pragmas.end()) { + break; + } + auto set_statement = make_uniq(info.name, info.parameters[0], SetScope::AUTOMATIC); + return std::move(set_statement); + } + case duckdb_libpgquery::PG_PRAGMA_TYPE_CALL: + break; + default: + throw InternalException("Unknown pragma type"); + } + + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_prepare.cpp b/src/duckdb/src/parser/transform/statement/transform_prepare.cpp new file mode 100644 index 00000000..19799876 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_prepare.cpp @@ -0,0 +1,72 @@ +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/parser/statement/execute_statement.hpp" +#include "duckdb/parser/statement/prepare_statement.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformPrepare(duckdb_libpgquery::PGPrepareStmt &stmt) { + if (stmt.argtypes && stmt.argtypes->length > 0) { + throw NotImplementedException("Prepared statement argument types are not supported, use CAST"); + } + + auto result = make_uniq(); + result->name = string(stmt.name); + result->statement = TransformStatement(*stmt.query); + SetParamCount(0); + + return result; +} + +static string NotAcceptedExpressionException() { + return "Only scalar parameters, named parameters or NULL supported for EXECUTE"; +} + +unique_ptr Transformer::TransformExecute(duckdb_libpgquery::PGExecuteStmt &stmt) { + auto result = make_uniq(); + result->name = string(stmt.name); + + vector> intermediate_values; + if (stmt.params) { + TransformExpressionList(*stmt.params, intermediate_values); + } + + idx_t param_idx = 0; + for (idx_t i = 0; i < intermediate_values.size(); i++) { + auto &expr = intermediate_values[i]; + if (!expr->IsScalar()) { + throw InvalidInputException(NotAcceptedExpressionException()); + } + if (!expr->alias.empty() && param_idx != 0) { + // Found unnamed parameters mixed with named parameters + throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); + } + auto param_name = expr->alias; + if (expr->alias.empty()) { + param_name = std::to_string(param_idx + 1); + if (param_idx != i) { + throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); + } + param_idx++; + } + expr->alias.clear(); + result->named_values[param_name] = std::move(expr); + } + intermediate_values.clear(); + return result; +} + +unique_ptr Transformer::TransformDeallocate(duckdb_libpgquery::PGDeallocateStmt &stmt) { + if (!stmt.name) { + throw ParserException("DEALLOCATE requires a name"); + } + + auto result = make_uniq(); + result->info->type = CatalogType::PREPARED_STATEMENT; + result->info->name = string(stmt.name); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_rename.cpp b/src/duckdb/src/parser/transform/statement/transform_rename.cpp new file mode 100644 index 00000000..25b72d85 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_rename.cpp @@ -0,0 +1,54 @@ +#include "duckdb/parser/statement/alter_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformRename(duckdb_libpgquery::PGRenameStmt &stmt) { + if (!stmt.relation) { + throw NotImplementedException("Altering schemas is not yet supported"); + } + + unique_ptr info; + + AlterEntryData data; + data.if_not_found = TransformOnEntryNotFound(stmt.missing_ok); + data.catalog = stmt.relation->catalogname ? stmt.relation->catalogname : INVALID_CATALOG; + data.schema = stmt.relation->schemaname ? stmt.relation->schemaname : INVALID_SCHEMA; + if (stmt.relation->relname) { + data.name = stmt.relation->relname; + } + // first we check the type of ALTER + switch (stmt.renameType) { + case duckdb_libpgquery::PG_OBJECT_COLUMN: { + // change column name + + // get the old name and the new name + string old_name = stmt.subname; + string new_name = stmt.newname; + info = make_uniq(std::move(data), old_name, new_name); + break; + } + case duckdb_libpgquery::PG_OBJECT_TABLE: { + // change table name + string new_name = stmt.newname; + info = make_uniq(std::move(data), new_name); + break; + } + case duckdb_libpgquery::PG_OBJECT_VIEW: { + // change view name + string new_name = stmt.newname; + info = make_uniq(std::move(data), new_name); + break; + } + case duckdb_libpgquery::PG_OBJECT_DATABASE: + default: + throw NotImplementedException("Schema element not supported yet!"); + } + D_ASSERT(info); + + auto result = make_uniq(); + result->info = std::move(info); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_select.cpp b/src/duckdb/src/parser/transform/statement/transform_select.cpp new file mode 100644 index 00000000..5473c9c1 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_select.cpp @@ -0,0 +1,36 @@ +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformSelectNode(duckdb_libpgquery::PGSelectStmt &select) { + if (select.pivot) { + return TransformPivotStatement(select); + } else { + return TransformSelectInternal(select); + } +} + +unique_ptr Transformer::TransformSelect(duckdb_libpgquery::PGSelectStmt &select, bool is_select) { + auto result = make_uniq(); + + // Both Insert/Create Table As uses this. + if (is_select) { + if (select.intoClause) { + throw ParserException("SELECT INTO not supported!"); + } + if (select.lockingClause) { + throw ParserException("SELECT locking clause is not supported!"); + } + } + + result->node = TransformSelectNode(select); + return result; +} + +unique_ptr Transformer::TransformSelect(optional_ptr node, bool is_select) { + return TransformSelect(PGCast(*node), is_select); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_select_node.cpp b/src/duckdb/src/parser/transform/statement/transform_select_node.cpp new file mode 100644 index 00000000..0e4994fa --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_select_node.cpp @@ -0,0 +1,167 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" + +namespace duckdb { + +void Transformer::TransformModifiers(duckdb_libpgquery::PGSelectStmt &stmt, QueryNode &node) { + // transform the common properties + // both the set operations and the regular select can have an ORDER BY/LIMIT attached to them + vector orders; + TransformOrderBy(stmt.sortClause, orders); + if (!orders.empty()) { + auto order_modifier = make_uniq(); + order_modifier->orders = std::move(orders); + node.modifiers.push_back(std::move(order_modifier)); + } + if (stmt.limitCount || stmt.limitOffset) { + if (stmt.limitCount && stmt.limitCount->type == duckdb_libpgquery::T_PGLimitPercent) { + auto limit_percent_modifier = make_uniq(); + auto expr_node = PGPointerCast(stmt.limitCount)->limit_percent; + limit_percent_modifier->limit = TransformExpression(expr_node); + if (stmt.limitOffset) { + limit_percent_modifier->offset = TransformExpression(stmt.limitOffset); + } + node.modifiers.push_back(std::move(limit_percent_modifier)); + } else { + auto limit_modifier = make_uniq(); + if (stmt.limitCount) { + limit_modifier->limit = TransformExpression(stmt.limitCount); + } + if (stmt.limitOffset) { + limit_modifier->offset = TransformExpression(stmt.limitOffset); + } + node.modifiers.push_back(std::move(limit_modifier)); + } + } +} + +unique_ptr Transformer::TransformSelectInternal(duckdb_libpgquery::PGSelectStmt &stmt) { + D_ASSERT(stmt.type == duckdb_libpgquery::T_PGSelectStmt); + auto stack_checker = StackCheck(); + + unique_ptr node; + vector> materialized_ctes; + + switch (stmt.op) { + case duckdb_libpgquery::PG_SETOP_NONE: { + node = make_uniq(); + auto &result = node->Cast(); + if (stmt.withClause) { + TransformCTE(*PGPointerCast(stmt.withClause), node->cte_map, + materialized_ctes); + } + if (stmt.windowClause) { + for (auto window_ele = stmt.windowClause->head; window_ele != nullptr; window_ele = window_ele->next) { + auto window_def = PGPointerCast(window_ele->data.ptr_value); + D_ASSERT(window_def); + D_ASSERT(window_def->name); + string window_name(window_def->name); + auto it = window_clauses.find(window_name); + if (it != window_clauses.end()) { + throw ParserException("window \"%s\" is already defined", window_name); + } + window_clauses[window_name] = window_def.get(); + } + } + + // checks distinct clause + if (stmt.distinctClause != nullptr) { + auto modifier = make_uniq(); + // checks distinct on clause + auto target = PGPointerCast(stmt.distinctClause->head->data.ptr_value); + if (target) { + // add the columns defined in the ON clause to the select list + TransformExpressionList(*stmt.distinctClause, modifier->distinct_on_targets); + } + result.modifiers.push_back(std::move(modifier)); + } + + // do this early so the value lists also have a `FROM` + if (stmt.valuesLists) { + // VALUES list, create an ExpressionList + D_ASSERT(!stmt.fromClause); + result.from_table = TransformValuesList(stmt.valuesLists); + result.select_list.push_back(make_uniq()); + } else { + if (!stmt.targetList) { + throw ParserException("SELECT clause without selection list"); + } + // select list + TransformExpressionList(*stmt.targetList, result.select_list); + result.from_table = TransformFrom(stmt.fromClause); + } + + // where + result.where_clause = TransformExpression(stmt.whereClause); + // group by + TransformGroupBy(stmt.groupClause, result); + // having + result.having = TransformExpression(stmt.havingClause); + // qualify + result.qualify = TransformExpression(stmt.qualifyClause); + // sample + result.sample = TransformSampleOptions(stmt.sampleOptions); + break; + } + case duckdb_libpgquery::PG_SETOP_UNION: + case duckdb_libpgquery::PG_SETOP_EXCEPT: + case duckdb_libpgquery::PG_SETOP_INTERSECT: + case duckdb_libpgquery::PG_SETOP_UNION_BY_NAME: { + node = make_uniq(); + auto &result = node->Cast(); + if (stmt.withClause) { + TransformCTE(*PGPointerCast(stmt.withClause), node->cte_map, + materialized_ctes); + } + result.left = TransformSelectNode(*stmt.larg); + result.right = TransformSelectNode(*stmt.rarg); + if (!result.left || !result.right) { + throw Exception("Failed to transform setop children."); + } + + bool select_distinct = true; + switch (stmt.op) { + case duckdb_libpgquery::PG_SETOP_UNION: + select_distinct = !stmt.all; + result.setop_type = SetOperationType::UNION; + break; + case duckdb_libpgquery::PG_SETOP_EXCEPT: + result.setop_type = SetOperationType::EXCEPT; + break; + case duckdb_libpgquery::PG_SETOP_INTERSECT: + result.setop_type = SetOperationType::INTERSECT; + break; + case duckdb_libpgquery::PG_SETOP_UNION_BY_NAME: + select_distinct = !stmt.all; + result.setop_type = SetOperationType::UNION_BY_NAME; + break; + default: + throw Exception("Unexpected setop type"); + } + if (select_distinct) { + result.modifiers.push_back(make_uniq()); + } + if (stmt.sampleOptions) { + throw ParserException("SAMPLE clause is only allowed in regular SELECT statements"); + } + break; + } + default: + throw NotImplementedException("Statement type %d not implemented!", stmt.op); + } + + TransformModifiers(stmt, *node); + + // Handle materialized CTEs + node = Transformer::TransformMaterializedCTE(std::move(node), materialized_ctes); + + return node; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_set.cpp b/src/duckdb/src/parser/transform/statement/transform_set.cpp new file mode 100644 index 00000000..a1389bd6 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_set.cpp @@ -0,0 +1,86 @@ +#include "duckdb/parser/statement/set_statement.hpp" + +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" + +namespace duckdb { + +namespace { + +SetScope ToSetScope(duckdb_libpgquery::VariableSetScope pg_scope) { + switch (pg_scope) { + case duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_LOCAL: + return SetScope::LOCAL; + case duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_SESSION: + return SetScope::SESSION; + case duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_GLOBAL: + return SetScope::GLOBAL; + case duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_DEFAULT: + return SetScope::AUTOMATIC; + default: + throw InternalException("Unexpected pg_scope: %d", pg_scope); + } +} + +SetType ToSetType(duckdb_libpgquery::VariableSetKind pg_kind) { + switch (pg_kind) { + case duckdb_libpgquery::VariableSetKind::VAR_SET_VALUE: + return SetType::SET; + case duckdb_libpgquery::VariableSetKind::VAR_RESET: + return SetType::RESET; + default: + throw NotImplementedException("Can only SET or RESET a variable"); + } +} + +} // namespace + +unique_ptr Transformer::TransformSetVariable(duckdb_libpgquery::PGVariableSetStmt &stmt) { + D_ASSERT(stmt.kind == duckdb_libpgquery::VariableSetKind::VAR_SET_VALUE); + + if (stmt.scope == duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_LOCAL) { + throw NotImplementedException("SET LOCAL is not implemented."); + } + + auto name = std::string(stmt.name); + D_ASSERT(!name.empty()); // parser protect us! + if (stmt.args->length != 1) { + throw ParserException("SET needs a single scalar value parameter"); + } + D_ASSERT(stmt.args->head && stmt.args->head->data.ptr_value); + auto const_val = PGPointerCast(stmt.args->head->data.ptr_value); + D_ASSERT(const_val->type == duckdb_libpgquery::T_PGAConst); + + auto value = TransformValue(const_val->val)->value; + return make_uniq(name, value, ToSetScope(stmt.scope)); +} + +unique_ptr Transformer::TransformResetVariable(duckdb_libpgquery::PGVariableSetStmt &stmt) { + D_ASSERT(stmt.kind == duckdb_libpgquery::VariableSetKind::VAR_RESET); + + if (stmt.scope == duckdb_libpgquery::VariableSetScope::VAR_SET_SCOPE_LOCAL) { + throw NotImplementedException("RESET LOCAL is not implemented."); + } + + auto name = std::string(stmt.name); + D_ASSERT(!name.empty()); // parser protect us! + + return make_uniq(name, ToSetScope(stmt.scope)); +} + +unique_ptr Transformer::TransformSet(duckdb_libpgquery::PGVariableSetStmt &stmt) { + D_ASSERT(stmt.type == duckdb_libpgquery::T_PGVariableSetStmt); + + SetType set_type = ToSetType(stmt.kind); + + switch (set_type) { + case SetType::SET: + return TransformSetVariable(stmt); + case SetType::RESET: + return TransformResetVariable(stmt); + default: + throw NotImplementedException("Type not implemented for SetType"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_show.cpp b/src/duckdb/src/parser/transform/statement/transform_show.cpp new file mode 100644 index 00000000..59ede816 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_show.cpp @@ -0,0 +1,54 @@ +#include "duckdb/parser/statement/pragma_statement.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/statement/show_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" + +namespace duckdb { + +static void TransformShowName(unique_ptr &result, const string &name) { + auto &info = *result->info; + auto lname = StringUtil::Lower(name); + + if (lname == "\"databases\"") { + info.name = "show_databases"; + } else if (lname == "\"tables\"") { + // show all tables + info.name = "show_tables"; + } else if (lname == "__show_tables_expanded") { + info.name = "show_tables_expanded"; + } else { + // show one specific table + info.name = "show"; + info.parameters.emplace_back(name); + } +} + +unique_ptr Transformer::TransformShow(duckdb_libpgquery::PGVariableShowStmt &stmt) { + // we transform SHOW x into PRAGMA SHOW('x') + if (stmt.is_summary) { + auto result = make_uniq(); + auto &info = *result->info; + info.is_summary = stmt.is_summary; + + auto select = make_uniq(); + select->select_list.push_back(make_uniq()); + auto basetable = make_uniq(); + auto qualified_name = QualifiedName::Parse(stmt.name); + basetable->schema_name = qualified_name.schema; + basetable->table_name = qualified_name.name; + select->from_table = std::move(basetable); + + info.query = std::move(select); + return std::move(result); + } + + auto result = make_uniq(); + + auto show_name = stmt.name; + TransformShowName(result, show_name); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_show_select.cpp b/src/duckdb/src/parser/transform/statement/transform_show_select.cpp new file mode 100644 index 00000000..5fee8a7a --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_show_select.cpp @@ -0,0 +1,20 @@ +#include "duckdb/parser/statement/show_statement.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformShowSelect(duckdb_libpgquery::PGVariableShowSelectStmt &stmt) { + // we capture the select statement of SHOW + auto select_stmt = PGPointerCast(stmt.stmt); + + auto result = make_uniq(); + auto &info = *result->info; + info.is_summary = stmt.is_summary; + + info.query = TransformSelectNode(*select_stmt); + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_transaction.cpp b/src/duckdb/src/parser/transform/statement/transform_transaction.cpp new file mode 100644 index 00000000..e35bf9ac --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_transaction.cpp @@ -0,0 +1,20 @@ +#include "duckdb/parser/statement/transaction_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformTransaction(duckdb_libpgquery::PGTransactionStmt &stmt) { + switch (stmt.kind) { + case duckdb_libpgquery::PG_TRANS_STMT_BEGIN: + case duckdb_libpgquery::PG_TRANS_STMT_START: + return make_uniq(TransactionType::BEGIN_TRANSACTION); + case duckdb_libpgquery::PG_TRANS_STMT_COMMIT: + return make_uniq(TransactionType::COMMIT); + case duckdb_libpgquery::PG_TRANS_STMT_ROLLBACK: + return make_uniq(TransactionType::ROLLBACK); + default: + throw NotImplementedException("Transaction type %d not implemented yet", stmt.kind); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_update.cpp b/src/duckdb/src/parser/transform/statement/transform_update.cpp new file mode 100644 index 00000000..c283af83 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_update.cpp @@ -0,0 +1,46 @@ +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformUpdateSetInfo(duckdb_libpgquery::PGList *target_list, + duckdb_libpgquery::PGNode *where_clause) { + auto result = make_uniq(); + + auto root = target_list; + for (auto cell = root->head; cell != nullptr; cell = cell->next) { + auto target = PGPointerCast(cell->data.ptr_value); + result->columns.emplace_back(target->name); + result->expressions.push_back(TransformExpression(target->val)); + } + result->condition = TransformExpression(where_clause); + return result; +} + +unique_ptr Transformer::TransformUpdate(duckdb_libpgquery::PGUpdateStmt &stmt) { + auto result = make_uniq(); + vector> materialized_ctes; + if (stmt.withClause) { + TransformCTE(*PGPointerCast(stmt.withClause), result->cte_map, + materialized_ctes); + if (!materialized_ctes.empty()) { + throw NotImplementedException("Materialized CTEs are not implemented for update."); + } + } + + result->table = TransformRangeVar(*stmt.relation); + if (stmt.fromClause) { + result->from_table = TransformFrom(stmt.fromClause); + } + + result->set_info = TransformUpdateSetInfo(stmt.targetList, stmt.whereClause); + + // Grab and transform the returning columns from the parser. + if (stmt.returningList) { + TransformExpressionList(*stmt.returningList, result->returning_list); + } + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_upsert.cpp b/src/duckdb/src/parser/transform/statement/transform_upsert.cpp new file mode 100644 index 00000000..3c550cb0 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_upsert.cpp @@ -0,0 +1,95 @@ +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +OnConflictAction TransformOnConflictAction(duckdb_libpgquery::PGOnConflictClause *on_conflict) { + if (!on_conflict) { + return OnConflictAction::THROW; + } + switch (on_conflict->action) { + case duckdb_libpgquery::PG_ONCONFLICT_NONE: + return OnConflictAction::THROW; + case duckdb_libpgquery::PG_ONCONFLICT_NOTHING: + return OnConflictAction::NOTHING; + case duckdb_libpgquery::PG_ONCONFLICT_UPDATE: + return OnConflictAction::UPDATE; + default: + throw InternalException("Type not implemented for OnConflictAction"); + } +} + +vector Transformer::TransformConflictTarget(duckdb_libpgquery::PGList &list) { + vector columns; + for (auto cell = list.head; cell != nullptr; cell = cell->next) { + auto index_element = PGPointerCast(cell->data.ptr_value); + if (index_element->collation) { + throw NotImplementedException("Index with collation not supported yet!"); + } + if (index_element->opclass) { + throw NotImplementedException("Index with opclass not supported yet!"); + } + if (!index_element->name) { + throw NotImplementedException("Non-column index element not supported yet!"); + } + if (index_element->nulls_ordering) { + throw NotImplementedException("Index with null_ordering not supported yet!"); + } + if (index_element->ordering) { + throw NotImplementedException("Index with ordering not supported yet!"); + } + columns.emplace_back(index_element->name); + } + return columns; +} + +unique_ptr Transformer::DummyOnConflictClause(duckdb_libpgquery::PGOnConflictActionAlias type, + const string &relname) { + switch (type) { + case duckdb_libpgquery::PGOnConflictActionAlias::PG_ONCONFLICT_ALIAS_REPLACE: { + // This can not be fully resolved yet until the bind stage + auto result = make_uniq(); + result->action_type = OnConflictAction::REPLACE; + return result; + } + case duckdb_libpgquery::PGOnConflictActionAlias::PG_ONCONFLICT_ALIAS_IGNORE: { + // We can just fully replace this with DO NOTHING, and be done with it + auto result = make_uniq(); + result->action_type = OnConflictAction::NOTHING; + return result; + } + default: { + throw InternalException("Type not implemented for PGOnConflictActionAlias"); + } + } +} + +unique_ptr Transformer::TransformOnConflictClause(duckdb_libpgquery::PGOnConflictClause *node, + const string &relname) { + auto stmt = reinterpret_cast(node); + D_ASSERT(stmt); + + auto result = make_uniq(); + result->action_type = TransformOnConflictAction(stmt); + if (stmt->infer) { + // A filter for the ON CONFLICT ... is specified + if (stmt->infer->indexElems) { + // Columns are specified + result->indexed_columns = TransformConflictTarget(*stmt->infer->indexElems); + if (stmt->infer->whereClause) { + result->condition = TransformExpression(stmt->infer->whereClause); + } + } else { + throw NotImplementedException("ON CONSTRAINT conflict target is not supported yet"); + } + } + + if (result->action_type == OnConflictAction::UPDATE) { + result->set_info = TransformUpdateSetInfo(stmt->targetList, stmt->whereClause); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_use.cpp b/src/duckdb/src/parser/transform/statement/transform_use.cpp new file mode 100644 index 00000000..5efd5db0 --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_use.cpp @@ -0,0 +1,20 @@ +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/statement/set_statement.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformUse(duckdb_libpgquery::PGUseStmt &stmt) { + auto qualified_name = TransformQualifiedName(*stmt.name); + if (!IsInvalidCatalog(qualified_name.catalog)) { + throw ParserException("Expected \"USE database\" or \"USE database.schema\""); + } + string name; + if (IsInvalidSchema(qualified_name.schema)) { + name = qualified_name.name; + } else { + name = qualified_name.schema + "." + qualified_name.name; + } + return make_uniq("schema", std::move(name), SetScope::AUTOMATIC); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/statement/transform_vacuum.cpp b/src/duckdb/src/parser/transform/statement/transform_vacuum.cpp new file mode 100644 index 00000000..c61934cc --- /dev/null +++ b/src/duckdb/src/parser/transform/statement/transform_vacuum.cpp @@ -0,0 +1,53 @@ +#include "duckdb/parser/statement/vacuum_statement.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +VacuumOptions ParseOptions(int options) { + VacuumOptions result; + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_VACUUM) { + result.vacuum = true; + } + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_ANALYZE) { + result.analyze = true; + } + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_VERBOSE) { + throw NotImplementedException("Verbose vacuum option"); + } + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_FREEZE) { + throw NotImplementedException("Freeze vacuum option"); + } + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_FULL) { + throw NotImplementedException("Full vacuum option"); + } + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_NOWAIT) { + throw NotImplementedException("No Wait vacuum option"); + } + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_SKIPTOAST) { + throw NotImplementedException("Skip Toast vacuum option"); + } + if (options & duckdb_libpgquery::PGVacuumOption::PG_VACOPT_DISABLE_PAGE_SKIPPING) { + throw NotImplementedException("Disable Page Skipping vacuum option"); + } + return result; +} + +unique_ptr Transformer::TransformVacuum(duckdb_libpgquery::PGVacuumStmt &stmt) { + auto result = make_uniq(ParseOptions(stmt.options)); + + if (stmt.relation) { + result->info->ref = TransformRangeVar(*stmt.relation); + result->info->has_table = true; + } + + if (stmt.va_cols) { + D_ASSERT(result->info->has_table); + for (auto col_node = stmt.va_cols->head; col_node != nullptr; col_node = col_node->next) { + result->info->columns.emplace_back( + reinterpret_cast(col_node->data.ptr_value)->val.str); + } + } + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/tableref/transform_base_tableref.cpp b/src/duckdb/src/parser/transform/tableref/transform_base_tableref.cpp new file mode 100644 index 00000000..6e6cbf1a --- /dev/null +++ b/src/duckdb/src/parser/transform/tableref/transform_base_tableref.cpp @@ -0,0 +1,46 @@ +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformRangeVar(duckdb_libpgquery::PGRangeVar &root) { + auto result = make_uniq(); + + result->alias = TransformAlias(root.alias, result->column_name_alias); + if (root.relname) { + result->table_name = root.relname; + } + if (root.catalogname) { + result->catalog_name = root.catalogname; + } + if (root.schemaname) { + result->schema_name = root.schemaname; + } + if (root.sample) { + result->sample = TransformSampleOptions(root.sample); + } + result->query_location = root.location; + return std::move(result); +} + +QualifiedName Transformer::TransformQualifiedName(duckdb_libpgquery::PGRangeVar &root) { + QualifiedName qname; + if (root.catalogname) { + qname.catalog = root.catalogname; + } else { + qname.catalog = INVALID_CATALOG; + } + if (root.schemaname) { + qname.schema = root.schemaname; + } else { + qname.schema = INVALID_SCHEMA; + } + if (root.relname) { + qname.name = root.relname; + } else { + qname.name = string(); + } + return qname; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/tableref/transform_from.cpp b/src/duckdb/src/parser/transform/tableref/transform_from.cpp new file mode 100644 index 00000000..84ec8ea5 --- /dev/null +++ b/src/duckdb/src/parser/transform/tableref/transform_from.cpp @@ -0,0 +1,41 @@ +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/parser/tableref/emptytableref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformFrom(optional_ptr root) { + if (!root) { + return make_uniq(); + } + + if (root->length > 1) { + // Cross Product + auto result = make_uniq(JoinRefType::CROSS); + JoinRef *cur_root = result.get(); + idx_t list_size = 0; + for (auto node = root->head; node != nullptr; node = node->next) { + auto n = PGPointerCast(node->data.ptr_value); + unique_ptr next = TransformTableRefNode(*n); + if (!cur_root->left) { + cur_root->left = std::move(next); + } else if (!cur_root->right) { + cur_root->right = std::move(next); + } else { + auto old_res = std::move(result); + result = make_uniq(JoinRefType::CROSS); + result->left = std::move(old_res); + result->right = std::move(next); + cur_root = result.get(); + } + list_size++; + StackCheck(list_size); + } + return std::move(result); + } + + auto n = PGPointerCast(root->head->data.ptr_value); + return TransformTableRefNode(*n); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/tableref/transform_join.cpp b/src/duckdb/src/parser/transform/tableref/transform_join.cpp new file mode 100644 index 00000000..bd792410 --- /dev/null +++ b/src/duckdb/src/parser/transform/tableref/transform_join.cpp @@ -0,0 +1,77 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformJoin(duckdb_libpgquery::PGJoinExpr &root) { + auto result = make_uniq(JoinRefType::REGULAR); + switch (root.jointype) { + case duckdb_libpgquery::PG_JOIN_INNER: { + result->type = JoinType::INNER; + break; + } + case duckdb_libpgquery::PG_JOIN_LEFT: { + result->type = JoinType::LEFT; + break; + } + case duckdb_libpgquery::PG_JOIN_FULL: { + result->type = JoinType::OUTER; + break; + } + case duckdb_libpgquery::PG_JOIN_RIGHT: { + result->type = JoinType::RIGHT; + break; + } + case duckdb_libpgquery::PG_JOIN_SEMI: { + result->type = JoinType::SEMI; + break; + } + case duckdb_libpgquery::PG_JOIN_ANTI: { + result->type = JoinType::ANTI; + break; + } + case duckdb_libpgquery::PG_JOIN_POSITION: { + result->ref_type = JoinRefType::POSITIONAL; + break; + } + default: { + throw NotImplementedException("Join type %d not supported\n", root.jointype); + } + } + + // Check the type of left arg and right arg before transform + result->left = TransformTableRefNode(*root.larg); + result->right = TransformTableRefNode(*root.rarg); + switch (root.joinreftype) { + case duckdb_libpgquery::PG_JOIN_NATURAL: + result->ref_type = JoinRefType::NATURAL; + break; + case duckdb_libpgquery::PG_JOIN_ASOF: + result->ref_type = JoinRefType::ASOF; + break; + default: + break; + } + result->query_location = root.location; + + if (root.usingClause && root.usingClause->length > 0) { + // usingClause is a list of strings + for (auto node = root.usingClause->head; node != nullptr; node = node->next) { + auto target = reinterpret_cast(node->data.ptr_value); + D_ASSERT(target->type == duckdb_libpgquery::T_PGString); + auto column_name = string(reinterpret_cast(target)->val.str); + result->using_columns.push_back(column_name); + } + return std::move(result); + } + + if (!root.quals && result->using_columns.empty() && result->ref_type == JoinRefType::REGULAR) { // CROSS PRODUCT + result->ref_type = JoinRefType::CROSS; + } + result->condition = TransformExpression(root.quals); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/tableref/transform_pivot.cpp b/src/duckdb/src/parser/transform/tableref/transform_pivot.cpp new file mode 100644 index 00000000..a7e5afb2 --- /dev/null +++ b/src/duckdb/src/parser/transform/tableref/transform_pivot.cpp @@ -0,0 +1,127 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/parser/transformer.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" + +namespace duckdb { + +static void TransformPivotInList(unique_ptr &expr, PivotColumnEntry &entry, bool root_entry = true) { + if (expr->type == ExpressionType::COLUMN_REF) { + auto &colref = expr->Cast(); + if (colref.IsQualified()) { + throw ParserException("PIVOT IN list cannot contain qualified column references"); + } + entry.values.emplace_back(colref.GetColumnName()); + } else if (expr->type == ExpressionType::VALUE_CONSTANT) { + auto &constant_expr = expr->Cast(); + entry.values.push_back(std::move(constant_expr.value)); + } else if (root_entry && expr->type == ExpressionType::FUNCTION) { + auto &function = expr->Cast(); + if (function.function_name != "row") { + throw ParserException("PIVOT IN list must contain columns or lists of columns"); + } + for (auto &child : function.children) { + TransformPivotInList(child, entry, false); + } + } else if (root_entry && expr->type == ExpressionType::STAR) { + entry.star_expr = std::move(expr); + } else { + throw ParserException("PIVOT IN list must contain columns or lists of columns"); + } +} + +PivotColumn Transformer::TransformPivotColumn(duckdb_libpgquery::PGPivot &pivot) { + PivotColumn col; + if (pivot.pivot_columns) { + TransformExpressionList(*pivot.pivot_columns, col.pivot_expressions); + for (auto &expr : col.pivot_expressions) { + if (expr->IsScalar()) { + throw ParserException("Cannot pivot on constant value \"%s\"", expr->ToString()); + } + if (expr->HasSubquery()) { + throw ParserException("Cannot pivot on subquery \"%s\"", expr->ToString()); + } + } + } else if (pivot.unpivot_columns) { + col.unpivot_names = TransformStringList(pivot.unpivot_columns); + } else { + throw InternalException("Either pivot_columns or unpivot_columns must be defined"); + } + if (pivot.pivot_value) { + for (auto node = pivot.pivot_value->head; node != nullptr; node = node->next) { + auto n = PGPointerCast(node->data.ptr_value); + auto expr = TransformExpression(n); + PivotColumnEntry entry; + entry.alias = expr->alias; + TransformPivotInList(expr, entry); + col.entries.push_back(std::move(entry)); + } + } + if (pivot.subquery) { + col.subquery = TransformSelectNode(*PGPointerCast(pivot.subquery)); + } + if (pivot.pivot_enum) { + col.pivot_enum = pivot.pivot_enum; + } + return col; +} + +vector Transformer::TransformPivotList(duckdb_libpgquery::PGList &list) { + vector result; + for (auto node = list.head; node != nullptr; node = node->next) { + auto pivot = PGPointerCast(node->data.ptr_value); + result.push_back(TransformPivotColumn(*pivot)); + } + return result; +} + +unique_ptr Transformer::TransformPivot(duckdb_libpgquery::PGPivotExpr &root) { + auto result = make_uniq(); + result->source = TransformTableRefNode(*root.source); + if (root.aggrs) { + TransformExpressionList(*root.aggrs, result->aggregates); + } + if (root.unpivots) { + result->unpivot_names = TransformStringList(root.unpivots); + } + result->pivots = TransformPivotList(*root.pivots); + if (!result->unpivot_names.empty() && result->pivots.size() > 1) { + throw ParserException("UNPIVOT requires a single pivot element"); + } + if (root.groups) { + result->groups = TransformStringList(root.groups); + } + for (auto &pivot : result->pivots) { + idx_t expected_size; + bool is_pivot = result->unpivot_names.empty(); + if (!result->unpivot_names.empty()) { + // unpivot + if (pivot.unpivot_names.size() != 1) { + throw ParserException("UNPIVOT requires a single column name for the PIVOT IN clause"); + } + D_ASSERT(pivot.pivot_expressions.empty()); + expected_size = pivot.entries[0].values.size(); + } else { + // pivot + expected_size = pivot.pivot_expressions.size(); + D_ASSERT(pivot.unpivot_names.empty()); + } + for (auto &entry : pivot.entries) { + if (entry.star_expr && is_pivot) { + throw ParserException("PIVOT IN list must contain columns or lists of columns - star expressions are " + "only supported for UNPIVOT"); + } + if (entry.values.size() != expected_size) { + throw ParserException("PIVOT IN list - inconsistent amount of rows - expected %d but got %d", + expected_size, entry.values.size()); + } + } + } + result->include_nulls = root.include_nulls; + result->alias = TransformAlias(root.alias, result->column_name_alias); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/tableref/transform_subquery.cpp b/src/duckdb/src/parser/transform/tableref/transform_subquery.cpp new file mode 100644 index 00000000..1d5105ad --- /dev/null +++ b/src/duckdb/src/parser/transform/tableref/transform_subquery.cpp @@ -0,0 +1,20 @@ +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformRangeSubselect(duckdb_libpgquery::PGRangeSubselect &root) { + Transformer subquery_transformer(*this); + auto subquery = subquery_transformer.TransformSelect(root.subquery); + if (!subquery) { + return nullptr; + } + auto result = make_uniq(std::move(subquery)); + result->alias = TransformAlias(root.alias, result->column_name_alias); + if (root.sample) { + result->sample = TransformSampleOptions(root.sample); + } + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/tableref/transform_table_function.cpp b/src/duckdb/src/parser/transform/tableref/transform_table_function.cpp new file mode 100644 index 00000000..5d8b05c2 --- /dev/null +++ b/src/duckdb/src/parser/transform/tableref/transform_table_function.cpp @@ -0,0 +1,48 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformRangeFunction(duckdb_libpgquery::PGRangeFunction &root) { + if (root.ordinality) { + throw NotImplementedException("WITH ORDINALITY not implemented"); + } + if (root.is_rowsfrom) { + throw NotImplementedException("ROWS FROM() not implemented"); + } + if (root.functions->length != 1) { + throw NotImplementedException("Need exactly one function"); + } + auto function_sublist = PGPointerCast(root.functions->head->data.ptr_value); + D_ASSERT(function_sublist->length == 2); + auto call_tree = PGPointerCast(function_sublist->head->data.ptr_value); + auto coldef = function_sublist->head->next->data.ptr_value; + + if (coldef) { + throw NotImplementedException("Explicit column definition not supported yet"); + } + // transform the function call + auto result = make_uniq(); + switch (call_tree->type) { + case duckdb_libpgquery::T_PGFuncCall: { + auto func_call = PGPointerCast(call_tree.get()); + result->function = TransformFuncCall(*func_call); + result->query_location = func_call->location; + break; + } + case duckdb_libpgquery::T_PGSQLValueFunction: + result->function = + TransformSQLValueFunction(*PGPointerCast(call_tree.get())); + break; + default: + throw ParserException("Not a function call or value function"); + } + result->alias = TransformAlias(root.alias, result->column_name_alias); + if (root.sample) { + result->sample = TransformSampleOptions(root.sample); + } + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transform/tableref/transform_tableref.cpp b/src/duckdb/src/parser/transform/tableref/transform_tableref.cpp new file mode 100644 index 00000000..ee4de700 --- /dev/null +++ b/src/duckdb/src/parser/transform/tableref/transform_tableref.cpp @@ -0,0 +1,26 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/transformer.hpp" + +namespace duckdb { + +unique_ptr Transformer::TransformTableRefNode(duckdb_libpgquery::PGNode &n) { + auto stack_checker = StackCheck(); + + switch (n.type) { + case duckdb_libpgquery::T_PGRangeVar: + return TransformRangeVar(PGCast(n)); + case duckdb_libpgquery::T_PGJoinExpr: + return TransformJoin(PGCast(n)); + case duckdb_libpgquery::T_PGRangeSubselect: + return TransformRangeSubselect(PGCast(n)); + case duckdb_libpgquery::T_PGRangeFunction: + return TransformRangeFunction(PGCast(n)); + case duckdb_libpgquery::T_PGPivotExpr: + return TransformPivot(PGCast(n)); + default: + throw NotImplementedException("From Type %d not supported", n.type); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/transformer.cpp b/src/duckdb/src/parser/transformer.cpp new file mode 100644 index 00000000..3cc1288b --- /dev/null +++ b/src/duckdb/src/parser/transformer.cpp @@ -0,0 +1,231 @@ +#include "duckdb/parser/transformer.hpp" + +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/parser/statement/list.hpp" +#include "duckdb/parser/tableref/emptytableref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/parser/parser_options.hpp" + +namespace duckdb { + +Transformer::Transformer(ParserOptions &options) + : parent(nullptr), options(options), stack_depth(DConstants::INVALID_INDEX) { +} + +Transformer::Transformer(Transformer &parent) + : parent(&parent), options(parent.options), stack_depth(DConstants::INVALID_INDEX) { +} + +Transformer::~Transformer() { +} + +void Transformer::Clear() { + SetParamCount(0); + pivot_entries.clear(); +} + +bool Transformer::TransformParseTree(duckdb_libpgquery::PGList *tree, vector> &statements) { + InitializeStackCheck(); + for (auto entry = tree->head; entry != nullptr; entry = entry->next) { + Clear(); + auto n = PGPointerCast(entry->data.ptr_value); + auto stmt = TransformStatement(*n); + D_ASSERT(stmt); + if (HasPivotEntries()) { + stmt = CreatePivotStatement(std::move(stmt)); + } + stmt->n_param = ParamCount(); + statements.push_back(std::move(stmt)); + } + return true; +} + +void Transformer::InitializeStackCheck() { + stack_depth = 0; +} + +StackChecker Transformer::StackCheck(idx_t extra_stack) { + auto &root = RootTransformer(); + D_ASSERT(root.stack_depth != DConstants::INVALID_INDEX); + if (root.stack_depth + extra_stack >= options.max_expression_depth) { + throw ParserException("Max expression depth limit of %lld exceeded. Use \"SET max_expression_depth TO x\" to " + "increase the maximum expression depth.", + options.max_expression_depth); + } + return StackChecker(root, extra_stack); +} + +unique_ptr Transformer::TransformStatement(duckdb_libpgquery::PGNode &stmt) { + auto result = TransformStatementInternal(stmt); + result->n_param = ParamCount(); + if (!named_param_map.empty()) { + // Avoid overriding a previous move with nothing + result->named_param_map = std::move(named_param_map); + } + return result; +} + +Transformer &Transformer::RootTransformer() { + reference node = *this; + while (node.get().parent) { + node = *node.get().parent; + } + return node.get(); +} + +const Transformer &Transformer::RootTransformer() const { + reference node = *this; + while (node.get().parent) { + node = *node.get().parent; + } + return node.get(); +} + +idx_t Transformer::ParamCount() const { + auto &root = RootTransformer(); + return root.prepared_statement_parameter_index; +} + +void Transformer::SetParamCount(idx_t new_count) { + auto &root = RootTransformer(); + root.prepared_statement_parameter_index = new_count; +} + +static void ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type) { + // Mixing positional/auto-increment and named parameters is not supported + if (last_type == PreparedParamType::INVALID) { + return; + } + if (last_type == PreparedParamType::NAMED) { + if (new_type != PreparedParamType::NAMED) { + throw NotImplementedException("Mixing named and positional parameters is not supported yet"); + } + } + if (last_type != PreparedParamType::NAMED) { + if (new_type == PreparedParamType::NAMED) { + throw NotImplementedException("Mixing named and positional parameters is not supported yet"); + } + } +} + +void Transformer::SetParam(const string &identifier, idx_t index, PreparedParamType type) { + auto &root = RootTransformer(); + ParamTypeCheck(root.last_param_type, type); + root.last_param_type = type; + D_ASSERT(!root.named_param_map.count(identifier)); + root.named_param_map[identifier] = index; +} + +bool Transformer::GetParam(const string &identifier, idx_t &index, PreparedParamType type) { + auto &root = RootTransformer(); + ParamTypeCheck(root.last_param_type, type); + auto entry = root.named_param_map.find(identifier); + if (entry == root.named_param_map.end()) { + return false; + } + index = entry->second; + return true; +} + +unique_ptr Transformer::TransformStatementInternal(duckdb_libpgquery::PGNode &stmt) { + switch (stmt.type) { + case duckdb_libpgquery::T_PGRawStmt: { + auto &raw_stmt = PGCast(stmt); + auto result = TransformStatement(*raw_stmt.stmt); + if (result) { + result->stmt_location = raw_stmt.stmt_location; + result->stmt_length = raw_stmt.stmt_len; + } + return result; + } + case duckdb_libpgquery::T_PGSelectStmt: + return TransformSelect(PGCast(stmt)); + case duckdb_libpgquery::T_PGCreateStmt: + return TransformCreateTable(PGCast(stmt)); + case duckdb_libpgquery::T_PGCreateSchemaStmt: + return TransformCreateSchema(PGCast(stmt)); + case duckdb_libpgquery::T_PGViewStmt: + return TransformCreateView(PGCast(stmt)); + case duckdb_libpgquery::T_PGCreateSeqStmt: + return TransformCreateSequence(PGCast(stmt)); + case duckdb_libpgquery::T_PGCreateFunctionStmt: + return TransformCreateFunction(PGCast(stmt)); + case duckdb_libpgquery::T_PGDropStmt: + return TransformDrop(PGCast(stmt)); + case duckdb_libpgquery::T_PGInsertStmt: + return TransformInsert(PGCast(stmt)); + case duckdb_libpgquery::T_PGCopyStmt: + return TransformCopy(PGCast(stmt)); + case duckdb_libpgquery::T_PGTransactionStmt: + return TransformTransaction(PGCast(stmt)); + case duckdb_libpgquery::T_PGDeleteStmt: + return TransformDelete(PGCast(stmt)); + case duckdb_libpgquery::T_PGUpdateStmt: + return TransformUpdate(PGCast(stmt)); + case duckdb_libpgquery::T_PGIndexStmt: + return TransformCreateIndex(PGCast(stmt)); + case duckdb_libpgquery::T_PGAlterTableStmt: + return TransformAlter(PGCast(stmt)); + case duckdb_libpgquery::T_PGRenameStmt: + return TransformRename(PGCast(stmt)); + case duckdb_libpgquery::T_PGPrepareStmt: + return TransformPrepare(PGCast(stmt)); + case duckdb_libpgquery::T_PGExecuteStmt: + return TransformExecute(PGCast(stmt)); + case duckdb_libpgquery::T_PGDeallocateStmt: + return TransformDeallocate(PGCast(stmt)); + case duckdb_libpgquery::T_PGCreateTableAsStmt: + return TransformCreateTableAs(PGCast(stmt)); + case duckdb_libpgquery::T_PGPragmaStmt: + return TransformPragma(PGCast(stmt)); + case duckdb_libpgquery::T_PGExportStmt: + return TransformExport(PGCast(stmt)); + case duckdb_libpgquery::T_PGImportStmt: + return TransformImport(PGCast(stmt)); + case duckdb_libpgquery::T_PGExplainStmt: + return TransformExplain(PGCast(stmt)); + case duckdb_libpgquery::T_PGVacuumStmt: + return TransformVacuum(PGCast(stmt)); + case duckdb_libpgquery::T_PGVariableShowStmt: + return TransformShow(PGCast(stmt)); + case duckdb_libpgquery::T_PGVariableShowSelectStmt: + return TransformShowSelect(PGCast(stmt)); + case duckdb_libpgquery::T_PGCallStmt: + return TransformCall(PGCast(stmt)); + case duckdb_libpgquery::T_PGVariableSetStmt: + return TransformSet(PGCast(stmt)); + case duckdb_libpgquery::T_PGCheckPointStmt: + return TransformCheckpoint(PGCast(stmt)); + case duckdb_libpgquery::T_PGLoadStmt: + return TransformLoad(PGCast(stmt)); + case duckdb_libpgquery::T_PGCreateTypeStmt: + return TransformCreateType(PGCast(stmt)); + case duckdb_libpgquery::T_PGAlterSeqStmt: + return TransformAlterSequence(PGCast(stmt)); + case duckdb_libpgquery::T_PGAttachStmt: + return TransformAttach(PGCast(stmt)); + case duckdb_libpgquery::T_PGDetachStmt: + return TransformDetach(PGCast(stmt)); + case duckdb_libpgquery::T_PGUseStmt: + return TransformUse(PGCast(stmt)); + default: + throw NotImplementedException(NodetypeToString(stmt.type)); + } +} + +unique_ptr Transformer::TransformMaterializedCTE(unique_ptr root, + vector> &materialized_ctes) { + while (!materialized_ctes.empty()) { + unique_ptr node_result; + node_result = std::move(materialized_ctes.back()); + node_result->cte_map = root->cte_map.Copy(); + node_result->child = std::move(root); + root = std::move(node_result); + materialized_ctes.pop_back(); + } + + return root; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/bind_context.cpp b/src/duckdb/src/planner/bind_context.cpp new file mode 100644 index 00000000..3b86a15b --- /dev/null +++ b/src/duckdb/src/planner/bind_context.cpp @@ -0,0 +1,557 @@ +#include "duckdb/planner/bind_context.hpp" + +#include "duckdb/catalog/catalog_entry/table_column_type.hpp" +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/planner/expression_binder/constant_binder.hpp" + +#include + +namespace duckdb { + +string BindContext::GetMatchingBinding(const string &column_name) { + string result; + for (auto &kv : bindings) { + auto binding = kv.second.get(); + auto is_using_binding = GetUsingBinding(column_name, kv.first); + if (is_using_binding) { + continue; + } + if (binding->HasMatchingBinding(column_name)) { + if (!result.empty() || is_using_binding) { + throw BinderException("Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " + "or \"%s.%s\")", + column_name, result, column_name, kv.first, column_name); + } + result = kv.first; + } + } + return result; +} + +vector BindContext::GetSimilarBindings(const string &column_name) { + vector> scores; + for (auto &kv : bindings) { + auto binding = kv.second.get(); + for (auto &name : binding->names) { + idx_t distance = StringUtil::SimilarityScore(name, column_name); + scores.emplace_back(binding->alias + "." + name, distance); + } + } + return StringUtil::TopNStrings(scores); +} + +void BindContext::AddUsingBinding(const string &column_name, UsingColumnSet &set) { + using_columns[column_name].insert(set); +} + +void BindContext::AddUsingBindingSet(unique_ptr set) { + using_column_sets.push_back(std::move(set)); +} + +optional_ptr BindContext::GetUsingBinding(const string &column_name) { + auto entry = using_columns.find(column_name); + if (entry == using_columns.end()) { + return nullptr; + } + auto &using_bindings = entry->second; + if (using_bindings.size() > 1) { + string error = "Ambiguous column reference: column \"" + column_name + "\" can refer to either:\n"; + for (auto &using_set_ref : using_bindings) { + auto &using_set = using_set_ref.get(); + string result_bindings; + for (auto &binding : using_set.bindings) { + if (result_bindings.empty()) { + result_bindings = "["; + } else { + result_bindings += ", "; + } + result_bindings += binding; + result_bindings += "."; + result_bindings += GetActualColumnName(binding, column_name); + } + error += result_bindings + "]"; + } + throw BinderException(error); + } + for (auto &using_set : using_bindings) { + return &using_set.get(); + } + throw InternalException("Using binding found but no entries"); +} + +optional_ptr BindContext::GetUsingBinding(const string &column_name, const string &binding_name) { + if (binding_name.empty()) { + throw InternalException("GetUsingBinding: expected non-empty binding_name"); + } + auto entry = using_columns.find(column_name); + if (entry == using_columns.end()) { + return nullptr; + } + auto &using_bindings = entry->second; + for (auto &using_set_ref : using_bindings) { + auto &using_set = using_set_ref.get(); + auto &bindings = using_set.bindings; + if (bindings.find(binding_name) != bindings.end()) { + return &using_set; + } + } + return nullptr; +} + +void BindContext::RemoveUsingBinding(const string &column_name, UsingColumnSet &set) { + auto entry = using_columns.find(column_name); + if (entry == using_columns.end()) { + throw InternalException("Attempting to remove using binding that is not there"); + } + auto &bindings = entry->second; + if (bindings.find(set) != bindings.end()) { + bindings.erase(set); + } + if (bindings.empty()) { + using_columns.erase(column_name); + } +} + +void BindContext::TransferUsingBinding(BindContext ¤t_context, optional_ptr current_set, + UsingColumnSet &new_set, const string &binding, const string &using_column) { + AddUsingBinding(using_column, new_set); + if (current_set) { + current_context.RemoveUsingBinding(using_column, *current_set); + } +} + +string BindContext::GetActualColumnName(const string &binding_name, const string &column_name) { + string error; + auto binding = GetBinding(binding_name, error); + if (!binding) { + throw InternalException("No binding with name \"%s\"", binding_name); + } + column_t binding_index; + if (!binding->TryGetBindingIndex(column_name, binding_index)) { // LCOV_EXCL_START + throw InternalException("Binding with name \"%s\" does not have a column named \"%s\"", binding_name, + column_name); + } // LCOV_EXCL_STOP + return binding->names[binding_index]; +} + +unordered_set BindContext::GetMatchingBindings(const string &column_name) { + unordered_set result; + for (auto &kv : bindings) { + auto binding = kv.second.get(); + if (binding->HasMatchingBinding(column_name)) { + result.insert(kv.first); + } + } + return result; +} + +unique_ptr BindContext::ExpandGeneratedColumn(const string &table_name, const string &column_name) { + string error_message; + + auto binding = GetBinding(table_name, error_message); + D_ASSERT(binding); + auto &table_binding = binding->Cast(); + auto result = table_binding.ExpandGeneratedColumn(column_name); + result->alias = column_name; + return result; +} + +unique_ptr BindContext::CreateColumnReference(const string &table_name, const string &column_name) { + string schema_name; + return CreateColumnReference(schema_name, table_name, column_name); +} + +static bool ColumnIsGenerated(Binding &binding, column_t index) { + if (binding.binding_type != BindingType::TABLE) { + return false; + } + auto &table_binding = binding.Cast(); + auto catalog_entry = table_binding.GetStandardEntry(); + if (!catalog_entry) { + return false; + } + if (index == COLUMN_IDENTIFIER_ROW_ID) { + return false; + } + D_ASSERT(catalog_entry->type == CatalogType::TABLE_ENTRY); + auto &table_entry = catalog_entry->Cast(); + return table_entry.GetColumn(LogicalIndex(index)).Generated(); +} + +unique_ptr BindContext::CreateColumnReference(const string &catalog_name, const string &schema_name, + const string &table_name, const string &column_name) { + string error_message; + vector names; + if (!catalog_name.empty()) { + names.push_back(catalog_name); + } + if (!schema_name.empty()) { + names.push_back(schema_name); + } + names.push_back(table_name); + names.push_back(column_name); + + auto result = make_uniq(std::move(names)); + auto binding = GetBinding(table_name, error_message); + if (!binding) { + return std::move(result); + } + auto column_index = binding->GetBindingIndex(column_name); + if (ColumnIsGenerated(*binding, column_index)) { + return ExpandGeneratedColumn(table_name, column_name); + } else if (column_index < binding->names.size() && binding->names[column_index] != column_name) { + // because of case insensitivity in the binder we rename the column to the original name + // as it appears in the binding itself + result->alias = binding->names[column_index]; + } + return std::move(result); +} + +unique_ptr BindContext::CreateColumnReference(const string &schema_name, const string &table_name, + const string &column_name) { + string catalog_name; + return CreateColumnReference(catalog_name, schema_name, table_name, column_name); +} + +optional_ptr BindContext::GetCTEBinding(const string &ctename) { + auto match = cte_bindings.find(ctename); + if (match == cte_bindings.end()) { + return nullptr; + } + return match->second.get(); +} + +optional_ptr BindContext::GetBinding(const string &name, string &out_error) { + auto match = bindings.find(name); + if (match == bindings.end()) { + // alias not found in this BindContext + vector candidates; + for (auto &kv : bindings) { + candidates.push_back(kv.first); + } + string candidate_str = + StringUtil::CandidatesMessage(StringUtil::TopNLevenshtein(candidates, name), "Candidate tables"); + out_error = StringUtil::Format("Referenced table \"%s\" not found!%s", name, candidate_str); + return nullptr; + } + return match->second.get(); +} + +BindResult BindContext::BindColumn(ColumnRefExpression &colref, idx_t depth) { + if (!colref.IsQualified()) { + throw InternalException("Could not bind alias \"%s\"!", colref.GetColumnName()); + } + + string error; + auto binding = GetBinding(colref.GetTableName(), error); + if (!binding) { + return BindResult(error); + } + return binding->Bind(colref, depth); +} + +string BindContext::BindColumn(PositionalReferenceExpression &ref, string &table_name, string &column_name) { + idx_t total_columns = 0; + idx_t current_position = ref.index - 1; + for (auto &entry : bindings_list) { + auto &binding = entry.get(); + idx_t entry_column_count = binding.names.size(); + if (ref.index == 0) { + // this is a row id + table_name = binding.alias; + column_name = "rowid"; + return string(); + } + if (current_position < entry_column_count) { + table_name = binding.alias; + column_name = binding.names[current_position]; + return string(); + } else { + total_columns += entry_column_count; + current_position -= entry_column_count; + } + } + return StringUtil::Format("Positional reference %d out of range (total %d columns)", ref.index, total_columns); +} + +unique_ptr BindContext::PositionToColumn(PositionalReferenceExpression &ref) { + string table_name, column_name; + + string error = BindColumn(ref, table_name, column_name); + if (!error.empty()) { + throw BinderException(error); + } + return make_uniq(column_name, table_name); +} + +bool BindContext::CheckExclusionList(StarExpression &expr, const string &column_name, + vector> &new_select_list, + case_insensitive_set_t &excluded_columns) { + if (expr.exclude_list.find(column_name) != expr.exclude_list.end()) { + excluded_columns.insert(column_name); + return true; + } + auto entry = expr.replace_list.find(column_name); + if (entry != expr.replace_list.end()) { + auto new_entry = entry->second->Copy(); + new_entry->alias = entry->first; + excluded_columns.insert(entry->first); + new_select_list.push_back(std::move(new_entry)); + return true; + } + return false; +} + +void BindContext::GenerateAllColumnExpressions(StarExpression &expr, + vector> &new_select_list) { + if (bindings_list.empty()) { + throw BinderException("* expression without FROM clause!"); + } + case_insensitive_set_t excluded_columns; + if (expr.relation_name.empty()) { + // SELECT * case + // bind all expressions of each table in-order + reference_set_t handled_using_columns; + for (auto &entry : bindings_list) { + auto &binding = entry.get(); + for (auto &column_name : binding.names) { + if (CheckExclusionList(expr, column_name, new_select_list, excluded_columns)) { + continue; + } + // check if this column is a USING column + auto using_binding_ptr = GetUsingBinding(column_name, binding.alias); + if (using_binding_ptr) { + auto &using_binding = *using_binding_ptr; + // it is! + // check if we have already emitted the using column + if (handled_using_columns.find(using_binding) != handled_using_columns.end()) { + // we have! bail out + continue; + } + // we have not! output the using column + if (using_binding.primary_binding.empty()) { + // no primary binding: output a coalesce + auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE); + for (auto &child_binding : using_binding.bindings) { + coalesce->children.push_back(make_uniq(column_name, child_binding)); + } + coalesce->alias = column_name; + new_select_list.push_back(std::move(coalesce)); + } else { + // primary binding: output the qualified column ref + new_select_list.push_back( + make_uniq(column_name, using_binding.primary_binding)); + } + handled_using_columns.insert(using_binding); + continue; + } + new_select_list.push_back(make_uniq(column_name, binding.alias)); + } + } + } else { + // SELECT tbl.* case + // SELECT struct.* case + string error; + auto binding = GetBinding(expr.relation_name, error); + bool is_struct_ref = false; + if (!binding) { + auto binding_name = GetMatchingBinding(expr.relation_name); + if (binding_name.empty()) { + throw BinderException(error); + } + binding = bindings[binding_name].get(); + is_struct_ref = true; + } + + if (is_struct_ref) { + auto col_idx = binding->GetBindingIndex(expr.relation_name); + auto col_type = binding->types[col_idx]; + if (col_type.id() != LogicalTypeId::STRUCT) { + throw BinderException(StringUtil::Format( + "Cannot extract field from expression \"%s\" because it is not a struct", expr.ToString())); + } + auto &struct_children = StructType::GetChildTypes(col_type); + vector column_names(3); + column_names[0] = binding->alias; + column_names[1] = expr.relation_name; + for (auto &child : struct_children) { + if (CheckExclusionList(expr, child.first, new_select_list, excluded_columns)) { + continue; + } + column_names[2] = child.first; + new_select_list.push_back(make_uniq(column_names)); + } + } else { + for (auto &column_name : binding->names) { + if (CheckExclusionList(expr, column_name, new_select_list, excluded_columns)) { + continue; + } + + new_select_list.push_back(make_uniq(column_name, binding->alias)); + } + } + } + for (auto &excluded : expr.exclude_list) { + if (excluded_columns.find(excluded) == excluded_columns.end()) { + throw BinderException("Column \"%s\" in EXCLUDE list not found in %s", excluded, + expr.relation_name.empty() ? "FROM clause" : expr.relation_name.c_str()); + } + } + for (auto &entry : expr.replace_list) { + if (excluded_columns.find(entry.first) == excluded_columns.end()) { + throw BinderException("Column \"%s\" in REPLACE list not found in %s", entry.first, + expr.relation_name.empty() ? "FROM clause" : expr.relation_name.c_str()); + } + } +} + +void BindContext::GetTypesAndNames(vector &result_names, vector &result_types) { + for (auto &binding_entry : bindings_list) { + auto &binding = binding_entry.get(); + D_ASSERT(binding.names.size() == binding.types.size()); + for (idx_t i = 0; i < binding.names.size(); i++) { + result_names.push_back(binding.names[i]); + result_types.push_back(binding.types[i]); + } + } +} + +void BindContext::AddBinding(const string &alias, unique_ptr binding) { + if (bindings.find(alias) != bindings.end()) { + throw BinderException("Duplicate alias \"%s\" in query!", alias); + } + bindings_list.push_back(*binding); + bindings[alias] = std::move(binding); +} + +void BindContext::AddBaseTable(idx_t index, const string &alias, const vector &names, + const vector &types, vector &bound_column_ids, + StandardEntry *entry, bool add_row_id) { + AddBinding(alias, make_uniq(alias, types, names, bound_column_ids, entry, index, add_row_id)); +} + +void BindContext::AddTableFunction(idx_t index, const string &alias, const vector &names, + const vector &types, vector &bound_column_ids, + StandardEntry *entry) { + AddBinding(alias, make_uniq(alias, types, names, bound_column_ids, entry, index)); +} + +static string AddColumnNameToBinding(const string &base_name, case_insensitive_set_t ¤t_names) { + idx_t index = 1; + string name = base_name; + while (current_names.find(name) != current_names.end()) { + name = base_name + ":" + std::to_string(index++); + } + current_names.insert(name); + return name; +} + +vector BindContext::AliasColumnNames(const string &table_name, const vector &names, + const vector &column_aliases) { + vector result; + if (column_aliases.size() > names.size()) { + throw BinderException("table \"%s\" has %lld columns available but %lld columns specified", table_name, + names.size(), column_aliases.size()); + } + case_insensitive_set_t current_names; + // use any provided column aliases first + for (idx_t i = 0; i < column_aliases.size(); i++) { + result.push_back(AddColumnNameToBinding(column_aliases[i], current_names)); + } + // if not enough aliases were provided, use the default names for remaining columns + for (idx_t i = column_aliases.size(); i < names.size(); i++) { + result.push_back(AddColumnNameToBinding(names[i], current_names)); + } + return result; +} + +void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery) { + auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); + AddGenericBinding(index, alias, names, subquery.types); +} + +void BindContext::AddEntryBinding(idx_t index, const string &alias, const vector &names, + const vector &types, StandardEntry &entry) { + AddBinding(alias, make_uniq(alias, types, names, index, entry)); +} + +void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, + ViewCatalogEntry *view) { + auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); + AddEntryBinding(index, alias, names, subquery.types, view->Cast()); +} + +void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery) { + auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); + AddGenericBinding(index, alias, names, subquery.types); +} + +void BindContext::AddGenericBinding(idx_t index, const string &alias, const vector &names, + const vector &types) { + AddBinding(alias, make_uniq(BindingType::BASE, alias, types, names, index)); +} + +void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, + const vector &types) { + auto binding = make_shared(BindingType::BASE, alias, types, names, index); + + if (cte_bindings.find(alias) != cte_bindings.end()) { + throw BinderException("Duplicate alias \"%s\" in query!", alias); + } + cte_bindings[alias] = std::move(binding); + cte_references[alias] = std::make_shared(0); +} + +void BindContext::AddContext(BindContext other) { + for (auto &binding : other.bindings) { + if (bindings.find(binding.first) != bindings.end()) { + throw BinderException("Duplicate alias \"%s\" in query!", binding.first); + } + bindings[binding.first] = std::move(binding.second); + } + for (auto &binding : other.bindings_list) { + bindings_list.push_back(binding); + } + for (auto &entry : other.using_columns) { + for (auto &alias : entry.second) { +#ifdef DEBUG + for (auto &other_alias : using_columns[entry.first]) { + for (auto &col : alias.get().bindings) { + D_ASSERT(other_alias.get().bindings.find(col) == other_alias.get().bindings.end()); + } + } +#endif + using_columns[entry.first].insert(alias); + } + } +} + +void BindContext::RemoveContext(vector> &other_bindings_list) { + for (auto &other_binding : other_bindings_list) { + auto it = std::remove_if(bindings_list.begin(), bindings_list.end(), [other_binding](reference x) { + return x.get().alias == other_binding.get().alias; + }); + bindings_list.erase(it, bindings_list.end()); + } + + for (auto &other_binding : other_bindings_list) { + auto &alias = other_binding.get().alias; + if (bindings.find(alias) != bindings.end()) { + bindings.erase(alias); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp new file mode 100644 index 00000000..7cd20ed2 --- /dev/null +++ b/src/duckdb/src/planner/binder.cpp @@ -0,0 +1,503 @@ +#include "duckdb/planner/binder.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/statement/list.hpp" +#include "duckdb/parser/tableref/list.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/tableref/list.hpp" +#include "duckdb/planner/query_node/list.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/expression_binder/returning_binder.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_sample.hpp" +#include "duckdb/parser/query_node/list.hpp" + +#include + +namespace duckdb { + +Binder *Binder::GetRootBinder() { + Binder *root = this; + while (root->parent) { + root = root->parent.get(); + } + return root; +} + +idx_t Binder::GetBinderDepth() const { + const Binder *root = this; + idx_t depth = 1; + while (root->parent) { + depth++; + root = root->parent.get(); + } + return depth; +} + +shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr parent, bool inherit_ctes) { + auto depth = parent ? parent->GetBinderDepth() : 0; + if (depth > context.config.max_expression_depth) { + throw BinderException("Max expression depth limit of %lld exceeded. Use \"SET max_expression_depth TO x\" to " + "increase the maximum expression depth.", + context.config.max_expression_depth); + } + return make_shared(true, context, parent ? parent->shared_from_this() : nullptr, inherit_ctes); +} + +Binder::Binder(bool, ClientContext &context, shared_ptr parent_p, bool inherit_ctes_p) + : context(context), parent(std::move(parent_p)), bound_tables(0), inherit_ctes(inherit_ctes_p) { + if (parent) { + + // We have to inherit macro and lambda parameter bindings and from the parent binder, if there is a parent. + macro_binding = parent->macro_binding; + lambda_bindings = parent->lambda_bindings; + + if (inherit_ctes) { + // We have to inherit CTE bindings from the parent bind_context, if there is a parent. + bind_context.SetCTEBindings(parent->bind_context.GetCTEBindings()); + bind_context.cte_references = parent->bind_context.cte_references; + parameters = parent->parameters; + } + } +} + +BoundStatement Binder::Bind(SQLStatement &statement) { + root_statement = &statement; + switch (statement.type) { + case StatementType::SELECT_STATEMENT: + return Bind(statement.Cast()); + case StatementType::INSERT_STATEMENT: + return Bind(statement.Cast()); + case StatementType::COPY_STATEMENT: + return Bind(statement.Cast()); + case StatementType::DELETE_STATEMENT: + return Bind(statement.Cast()); + case StatementType::UPDATE_STATEMENT: + return Bind(statement.Cast()); + case StatementType::RELATION_STATEMENT: + return Bind(statement.Cast()); + case StatementType::CREATE_STATEMENT: + return Bind(statement.Cast()); + case StatementType::DROP_STATEMENT: + return Bind(statement.Cast()); + case StatementType::ALTER_STATEMENT: + return Bind(statement.Cast()); + case StatementType::TRANSACTION_STATEMENT: + return Bind(statement.Cast()); + case StatementType::PRAGMA_STATEMENT: + return Bind(statement.Cast()); + case StatementType::EXPLAIN_STATEMENT: + return Bind(statement.Cast()); + case StatementType::VACUUM_STATEMENT: + return Bind(statement.Cast()); + case StatementType::SHOW_STATEMENT: + return Bind(statement.Cast()); + case StatementType::CALL_STATEMENT: + return Bind(statement.Cast()); + case StatementType::EXPORT_STATEMENT: + return Bind(statement.Cast()); + case StatementType::SET_STATEMENT: + return Bind(statement.Cast()); + case StatementType::LOAD_STATEMENT: + return Bind(statement.Cast()); + case StatementType::EXTENSION_STATEMENT: + return Bind(statement.Cast()); + case StatementType::PREPARE_STATEMENT: + return Bind(statement.Cast()); + case StatementType::EXECUTE_STATEMENT: + return Bind(statement.Cast()); + case StatementType::LOGICAL_PLAN_STATEMENT: + return Bind(statement.Cast()); + case StatementType::ATTACH_STATEMENT: + return Bind(statement.Cast()); + case StatementType::DETACH_STATEMENT: + return Bind(statement.Cast()); + default: // LCOV_EXCL_START + throw NotImplementedException("Unimplemented statement type \"%s\" for Bind", + StatementTypeToString(statement.type)); + } // LCOV_EXCL_STOP +} + +void Binder::AddCTEMap(CommonTableExpressionMap &cte_map) { + for (auto &cte_it : cte_map.map) { + AddCTE(cte_it.first, *cte_it.second); + } +} + +unique_ptr Binder::BindNode(QueryNode &node) { + // first we visit the set of CTEs and add them to the bind context + AddCTEMap(node.cte_map); + // now we bind the node + unique_ptr result; + switch (node.type) { + case QueryNodeType::SELECT_NODE: + result = BindNode(node.Cast()); + break; + case QueryNodeType::RECURSIVE_CTE_NODE: + result = BindNode(node.Cast()); + break; + case QueryNodeType::CTE_NODE: + result = BindNode(node.Cast()); + break; + default: + D_ASSERT(node.type == QueryNodeType::SET_OPERATION_NODE); + result = BindNode(node.Cast()); + break; + } + return result; +} + +BoundStatement Binder::Bind(QueryNode &node) { + auto bound_node = BindNode(node); + + BoundStatement result; + result.names = bound_node->names; + result.types = bound_node->types; + + // and plan it + result.plan = CreatePlan(*bound_node); + return result; +} + +unique_ptr Binder::CreatePlan(BoundQueryNode &node) { + switch (node.type) { + case QueryNodeType::SELECT_NODE: + return CreatePlan(node.Cast()); + case QueryNodeType::SET_OPERATION_NODE: + return CreatePlan(node.Cast()); + case QueryNodeType::RECURSIVE_CTE_NODE: + return CreatePlan(node.Cast()); + case QueryNodeType::CTE_NODE: + return CreatePlan(node.Cast()); + default: + throw InternalException("Unsupported bound query node type"); + } +} + +unique_ptr Binder::Bind(TableRef &ref) { + unique_ptr result; + switch (ref.type) { + case TableReferenceType::BASE_TABLE: + result = Bind(ref.Cast()); + break; + case TableReferenceType::JOIN: + result = Bind(ref.Cast()); + break; + case TableReferenceType::SUBQUERY: + result = Bind(ref.Cast()); + break; + case TableReferenceType::EMPTY: + result = Bind(ref.Cast()); + break; + case TableReferenceType::TABLE_FUNCTION: + result = Bind(ref.Cast()); + break; + case TableReferenceType::EXPRESSION_LIST: + result = Bind(ref.Cast()); + break; + case TableReferenceType::PIVOT: + result = Bind(ref.Cast()); + break; + case TableReferenceType::CTE: + case TableReferenceType::INVALID: + default: + throw InternalException("Unknown table ref type"); + } + result->sample = std::move(ref.sample); + return result; +} + +unique_ptr Binder::CreatePlan(BoundTableRef &ref) { + unique_ptr root; + switch (ref.type) { + case TableReferenceType::BASE_TABLE: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::SUBQUERY: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::JOIN: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::TABLE_FUNCTION: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::EMPTY: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::EXPRESSION_LIST: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::CTE: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::PIVOT: + root = CreatePlan(ref.Cast()); + break; + case TableReferenceType::INVALID: + default: + throw InternalException("Unsupported bound table ref type"); + } + // plan the sample clause + if (ref.sample) { + root = make_uniq(std::move(ref.sample), std::move(root)); + } + return root; +} + +void Binder::AddCTE(const string &name, CommonTableExpressionInfo &info) { + D_ASSERT(!name.empty()); + auto entry = CTE_bindings.find(name); + if (entry != CTE_bindings.end()) { + throw InternalException("Duplicate CTE \"%s\" in query!", name); + } + CTE_bindings.insert(make_pair(name, reference(info))); +} + +optional_ptr Binder::FindCTE(const string &name, bool skip) { + auto entry = CTE_bindings.find(name); + if (entry != CTE_bindings.end()) { + if (!skip || entry->second.get().query->node->type == QueryNodeType::RECURSIVE_CTE_NODE) { + return &entry->second.get(); + } + } + if (parent && inherit_ctes) { + return parent->FindCTE(name, name == alias); + } + return nullptr; +} + +bool Binder::CTEIsAlreadyBound(CommonTableExpressionInfo &cte) { + if (bound_ctes.find(cte) != bound_ctes.end()) { + return true; + } + if (parent && inherit_ctes) { + return parent->CTEIsAlreadyBound(cte); + } + return false; +} + +void Binder::AddBoundView(ViewCatalogEntry &view) { + // check if the view is already bound + auto current = this; + while (current) { + if (current->bound_views.find(view) != current->bound_views.end()) { + throw BinderException("infinite recursion detected: attempting to recursively bind view \"%s\"", view.name); + } + current = current->parent.get(); + } + bound_views.insert(view); +} + +idx_t Binder::GenerateTableIndex() { + auto root_binder = GetRootBinder(); + return root_binder->bound_tables++; +} + +void Binder::PushExpressionBinder(ExpressionBinder &binder) { + GetActiveBinders().push_back(binder); +} + +void Binder::PopExpressionBinder() { + D_ASSERT(HasActiveBinder()); + GetActiveBinders().pop_back(); +} + +void Binder::SetActiveBinder(ExpressionBinder &binder) { + D_ASSERT(HasActiveBinder()); + GetActiveBinders().back() = binder; +} + +ExpressionBinder &Binder::GetActiveBinder() { + return GetActiveBinders().back(); +} + +bool Binder::HasActiveBinder() { + return !GetActiveBinders().empty(); +} + +vector> &Binder::GetActiveBinders() { + auto root_binder = GetRootBinder(); + return root_binder->active_binders; +} + +void Binder::AddUsingBindingSet(unique_ptr set) { + auto root_binder = GetRootBinder(); + root_binder->bind_context.AddUsingBindingSet(std::move(set)); +} + +void Binder::MoveCorrelatedExpressions(Binder &other) { + MergeCorrelatedColumns(other.correlated_columns); + other.correlated_columns.clear(); +} + +void Binder::MergeCorrelatedColumns(vector &other) { + for (idx_t i = 0; i < other.size(); i++) { + AddCorrelatedColumn(other[i]); + } +} + +void Binder::AddCorrelatedColumn(const CorrelatedColumnInfo &info) { + // we only add correlated columns to the list if they are not already there + if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { + correlated_columns.push_back(info); + } +} + +bool Binder::HasMatchingBinding(const string &table_name, const string &column_name, string &error_message) { + string empty_schema; + return HasMatchingBinding(empty_schema, table_name, column_name, error_message); +} + +bool Binder::HasMatchingBinding(const string &schema_name, const string &table_name, const string &column_name, + string &error_message) { + string empty_catalog; + return HasMatchingBinding(empty_catalog, schema_name, table_name, column_name, error_message); +} + +bool Binder::HasMatchingBinding(const string &catalog_name, const string &schema_name, const string &table_name, + const string &column_name, string &error_message) { + optional_ptr binding; + D_ASSERT(!lambda_bindings); + if (macro_binding && table_name == macro_binding->alias) { + binding = optional_ptr(macro_binding.get()); + } else { + binding = bind_context.GetBinding(table_name, error_message); + } + + if (!binding) { + return false; + } + if (!catalog_name.empty() || !schema_name.empty()) { + auto catalog_entry = binding->GetStandardEntry(); + if (!catalog_entry) { + return false; + } + if (!catalog_name.empty() && catalog_entry->catalog.GetName() != catalog_name) { + return false; + } + if (!schema_name.empty() && catalog_entry->schema.name != schema_name) { + return false; + } + if (catalog_entry->name != table_name) { + return false; + } + } + bool binding_found; + binding_found = binding->HasMatchingBinding(column_name); + if (!binding_found) { + error_message = binding->ColumnNotFoundError(column_name); + } + return binding_found; +} + +void Binder::SetBindingMode(BindingMode mode) { + auto root_binder = GetRootBinder(); + // FIXME: this used to also set the 'mode' for the current binder, was that necessary? + root_binder->mode = mode; +} + +BindingMode Binder::GetBindingMode() { + auto root_binder = GetRootBinder(); + return root_binder->mode; +} + +void Binder::SetCanContainNulls(bool can_contain_nulls_p) { + can_contain_nulls = can_contain_nulls_p; +} + +void Binder::AddTableName(string table_name) { + auto root_binder = GetRootBinder(); + root_binder->table_names.insert(std::move(table_name)); +} + +const unordered_set &Binder::GetTableNames() { + auto root_binder = GetRootBinder(); + return root_binder->table_names; +} + +string Binder::FormatError(ParsedExpression &expr_context, const string &message) { + return FormatError(expr_context.query_location, message); +} + +string Binder::FormatError(TableRef &ref_context, const string &message) { + return FormatError(ref_context.query_location, message); +} + +string Binder::FormatErrorRecursive(idx_t query_location, const string &message, vector &values) { + QueryErrorContext context(root_statement, query_location); + return context.FormatErrorRecursive(message, values); +} + +// FIXME: this is extremely naive +void VerifyNotExcluded(ParsedExpression &expr) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto &column_ref = expr.Cast(); + if (!column_ref.IsQualified()) { + return; + } + auto &table_name = column_ref.GetTableName(); + if (table_name == "excluded") { + throw NotImplementedException("'excluded' qualified columns are not supported in the RETURNING clause yet"); + } + return; + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](const ParsedExpression &child) { VerifyNotExcluded((ParsedExpression &)child); }); +} + +BoundStatement Binder::BindReturning(vector> returning_list, TableCatalogEntry &table, + const string &alias, idx_t update_table_index, + unique_ptr child_operator, BoundStatement result) { + + vector types; + vector names; + + auto binder = Binder::CreateBinder(context); + + vector bound_columns; + idx_t column_count = 0; + for (auto &col : table.GetColumns().Logical()) { + names.push_back(col.Name()); + types.push_back(col.Type()); + if (!col.Generated()) { + bound_columns.push_back(column_count); + } + column_count++; + } + + binder->bind_context.AddBaseTable(update_table_index, alias.empty() ? table.name : alias, names, types, + bound_columns, &table, false); + ReturningBinder returning_binder(*binder, context); + + vector> projection_expressions; + LogicalType result_type; + vector> new_returning_list; + binder->ExpandStarExpressions(returning_list, new_returning_list); + for (auto &returning_expr : new_returning_list) { + VerifyNotExcluded(*returning_expr); + auto expr = returning_binder.Bind(returning_expr, &result_type); + result.names.push_back(expr->GetName()); + result.types.push_back(result_type); + projection_expressions.push_back(std::move(expr)); + } + + auto projection = make_uniq(GenerateTableIndex(), std::move(projection_expressions)); + projection->AddChild(std::move(child_operator)); + D_ASSERT(result.types.size() == result.names.size()); + result.plan = std::move(projection); + // If an insert/delete/update statement returns data, there are sometimes issues with streaming results + // where the data modification doesn't take place until the streamed result is exhausted. Once a row is + // returned, it should be guaranteed that the row has been inserted. + // see https://github.com/duckdb/duckdb/issues/8310 + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::QUERY_RESULT; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp new file mode 100644 index 00000000..4bd47426 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp @@ -0,0 +1,256 @@ +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression_binder/aggregate_binder.hpp" +#include "duckdb/planner/expression_binder/base_select_binder.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +static Value NegatePercentileValue(const Value &v, const bool desc) { + if (v.IsNull()) { + return v; + } + + const auto frac = v.GetValue(); + if (frac < 0 || frac > 1) { + throw BinderException("PERCENTILEs can only take parameters in the range [0, 1]"); + } + + if (!desc) { + return v; + } + + const auto &type = v.type(); + switch (type.id()) { + case LogicalTypeId::DECIMAL: { + // Negate DECIMALs as DECIMAL. + const auto integral = IntegralValue::Get(v); + const auto width = DecimalType::GetWidth(type); + const auto scale = DecimalType::GetScale(type); + switch (type.InternalType()) { + case PhysicalType::INT16: + return Value::DECIMAL(Cast::Operation(-integral), width, scale); + case PhysicalType::INT32: + return Value::DECIMAL(Cast::Operation(-integral), width, scale); + case PhysicalType::INT64: + return Value::DECIMAL(Cast::Operation(-integral), width, scale); + case PhysicalType::INT128: + return Value::DECIMAL(-integral, width, scale); + default: + throw InternalException("Unknown DECIMAL type"); + } + } + default: + // Everything else can just be a DOUBLE + return Value::DOUBLE(-v.GetValue()); + } +} + +static void NegatePercentileFractions(ClientContext &context, unique_ptr &fractions, bool desc) { + D_ASSERT(fractions.get()); + D_ASSERT(fractions->expression_class == ExpressionClass::BOUND_EXPRESSION); + auto &bound = BoundExpression::GetExpression(*fractions); + + if (!bound->IsFoldable()) { + return; + } + + Value value = ExpressionExecutor::EvaluateScalar(context, *bound); + if (value.type().id() == LogicalTypeId::LIST) { + vector values; + for (const auto &element_val : ListValue::GetChildren(value)) { + values.push_back(NegatePercentileValue(element_val, desc)); + } + if (values.empty()) { + throw BinderException("Empty list in percentile not allowed"); + } + bound = make_uniq(Value::LIST(values)); + } else { + bound = make_uniq(NegatePercentileValue(value, desc)); + } +} + +BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFunctionCatalogEntry &func, idx_t depth) { + // first bind the child of the aggregate expression (if any) + this->bound_aggregate = true; + unique_ptr bound_filter; + AggregateBinder aggregate_binder(binder, context); + string error; + + // Now we bind the filter (if any) + if (aggr.filter) { + aggregate_binder.BindChild(aggr.filter, 0, error); + } + + // Handle ordered-set aggregates by moving the single ORDER BY expression to the front of the children. + // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE + bool ordered_set_agg = false; + bool negate_fractions = false; + if (aggr.order_bys && aggr.order_bys->orders.size() == 1) { + const auto &func_name = aggr.function_name; + ordered_set_agg = (func_name == "quantile_cont" || func_name == "quantile_disc" || + (func_name == "mode" && aggr.children.empty())); + + if (ordered_set_agg) { + auto &config = DBConfig::GetConfig(context); + const auto &order = aggr.order_bys->orders[0]; + const auto sense = + (order.type == OrderType::ORDER_DEFAULT) ? config.options.default_order_type : order.type; + negate_fractions = (sense == OrderType::DESCENDING); + } + } + + for (auto &child : aggr.children) { + aggregate_binder.BindChild(child, 0, error); + // We have to negate the fractions for PERCENTILE_XXXX DESC + if (error.empty() && ordered_set_agg) { + NegatePercentileFractions(context, child, negate_fractions); + } + } + + // Bind the ORDER BYs, if any + if (aggr.order_bys && !aggr.order_bys->orders.empty()) { + for (auto &order : aggr.order_bys->orders) { + aggregate_binder.BindChild(order.expression, 0, error); + } + } + + if (!error.empty()) { + // failed to bind child + if (aggregate_binder.HasBoundColumns()) { + for (idx_t i = 0; i < aggr.children.size(); i++) { + // however, we bound columns! + // that means this aggregation belongs to this node + // check if we have to resolve any errors by binding with parent binders + bool success = aggregate_binder.BindCorrelatedColumns(aggr.children[i]); + // if there is still an error after this, we could not successfully bind the aggregate + if (!success) { + throw BinderException(error); + } + auto &bound_expr = BoundExpression::GetExpression(*aggr.children[i]); + ExtractCorrelatedExpressions(binder, *bound_expr); + } + if (aggr.filter) { + bool success = aggregate_binder.BindCorrelatedColumns(aggr.filter); + // if there is still an error after this, we could not successfully bind the aggregate + if (!success) { + throw BinderException(error); + } + auto &bound_expr = BoundExpression::GetExpression(*aggr.filter); + ExtractCorrelatedExpressions(binder, *bound_expr); + } + if (aggr.order_bys && !aggr.order_bys->orders.empty()) { + for (auto &order : aggr.order_bys->orders) { + bool success = aggregate_binder.BindCorrelatedColumns(order.expression); + if (!success) { + throw BinderException(error); + } + auto &bound_expr = BoundExpression::GetExpression(*order.expression); + ExtractCorrelatedExpressions(binder, *bound_expr); + } + } + } else { + // we didn't bind columns, try again in children + return BindResult(error); + } + } else if (depth > 0 && !aggregate_binder.HasBoundColumns()) { + return BindResult("Aggregate with only constant parameters has to be bound in the root subquery"); + } + + if (aggr.filter) { + auto &child = BoundExpression::GetExpression(*aggr.filter); + bound_filter = BoundCastExpression::AddCastToType(context, std::move(child), LogicalType::BOOLEAN); + } + + // all children bound successfully + // extract the children and types + vector types; + vector arguments; + vector> children; + + if (ordered_set_agg) { + const bool order_sensitive = (aggr.function_name == "mode"); + for (auto &order : aggr.order_bys->orders) { + auto &child = BoundExpression::GetExpression(*order.expression); + types.push_back(child->return_type); + arguments.push_back(child->return_type); + if (order_sensitive) { + children.push_back(child->Copy()); + } else { + children.push_back(std::move(child)); + } + } + if (!order_sensitive) { + aggr.order_bys->orders.clear(); + } + } + + for (idx_t i = 0; i < aggr.children.size(); i++) { + auto &child = BoundExpression::GetExpression(*aggr.children[i]); + types.push_back(child->return_type); + arguments.push_back(child->return_type); + children.push_back(std::move(child)); + } + + // bind the aggregate + FunctionBinder function_binder(context); + idx_t best_function = function_binder.BindFunction(func.name, func.functions, types, error); + if (best_function == DConstants::INVALID_INDEX) { + throw BinderException(binder.FormatError(aggr, error)); + } + // found a matching function! + auto bound_function = func.functions.GetFunctionByOffset(best_function); + + // Bind any sort columns, unless the aggregate is order-insensitive + unique_ptr order_bys; + if (!aggr.order_bys->orders.empty()) { + order_bys = make_uniq(); + auto &config = DBConfig::GetConfig(context); + for (auto &order : aggr.order_bys->orders) { + auto &order_expr = BoundExpression::GetExpression(*order.expression); + const auto sense = config.ResolveOrder(order.type); + const auto null_order = config.ResolveNullOrder(sense, order.null_order); + order_bys->orders.emplace_back(sense, null_order, std::move(order_expr)); + } + } + + auto aggregate = + function_binder.BindAggregateFunction(bound_function, std::move(children), std::move(bound_filter), + aggr.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT); + if (aggr.export_state) { + aggregate = ExportAggregateFunction::Bind(std::move(aggregate)); + } + aggregate->order_bys = std::move(order_bys); + + // check for all the aggregates if this aggregate already exists + idx_t aggr_index; + auto entry = node.aggregate_map.find(*aggregate); + if (entry == node.aggregate_map.end()) { + // new aggregate: insert into aggregate list + aggr_index = node.aggregates.size(); + node.aggregate_map[*aggregate] = aggr_index; + node.aggregates.push_back(std::move(aggregate)); + } else { + // duplicate aggregate: simplify refer to this aggregate + aggr_index = entry->second; + } + + // now create a column reference referring to the aggregate + auto colref = make_uniq( + aggr.alias.empty() ? node.aggregates[aggr_index]->ToString() : aggr.alias, + node.aggregates[aggr_index]->return_type, ColumnBinding(node.aggregate_index, aggr_index), depth); + // move the aggregate expression into the set of bound aggregates + return BindResult(std::move(colref)); +} +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp new file mode 100644 index 00000000..1c82daef --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp @@ -0,0 +1,62 @@ +#include "duckdb/parser/expression/between_expression.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(BetweenExpression &expr, idx_t depth) { + // first try to bind the children of the case expression + string error; + BindChild(expr.input, depth, error); + BindChild(expr.lower, depth, error); + BindChild(expr.upper, depth, error); + if (!error.empty()) { + return BindResult(error); + } + // the children have been successfully resolved + auto &input = BoundExpression::GetExpression(*expr.input); + auto &lower = BoundExpression::GetExpression(*expr.lower); + auto &upper = BoundExpression::GetExpression(*expr.upper); + + auto input_sql_type = input->return_type; + auto lower_sql_type = lower->return_type; + auto upper_sql_type = upper->return_type; + + // cast the input types to the same type + // now obtain the result type of the input types + auto input_type = BoundComparisonExpression::BindComparison(input_sql_type, lower_sql_type); + input_type = BoundComparisonExpression::BindComparison(input_type, upper_sql_type); + // add casts (if necessary) + input = BoundCastExpression::AddCastToType(context, std::move(input), input_type); + lower = BoundCastExpression::AddCastToType(context, std::move(lower), input_type); + upper = BoundCastExpression::AddCastToType(context, std::move(upper), input_type); + if (input_type.id() == LogicalTypeId::VARCHAR) { + // handle collation + auto collation = StringType::GetCollation(input_type); + input = PushCollation(context, std::move(input), collation, false); + lower = PushCollation(context, std::move(lower), collation, false); + upper = PushCollation(context, std::move(upper), collation, false); + } + if (!input->HasSideEffects() && !input->HasParameter() && !input->HasSubquery()) { + // the expression does not have side effects and can be copied: create two comparisons + // the reason we do this is that individual comparisons are easier to handle in optimizers + // if both comparisons remain they will be folded together again into a single BETWEEN in the optimizer + auto left_compare = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, + input->Copy(), std::move(lower)); + auto right_compare = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, + std::move(input), std::move(upper)); + return BindResult(make_uniq(ExpressionType::CONJUNCTION_AND, + std::move(left_compare), std::move(right_compare))); + } else { + // expression has side effects: we cannot duplicate it + // create a bound_between directly + return BindResult( + make_uniq(std::move(input), std::move(lower), std::move(upper), true, true)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_case_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_case_expression.cpp new file mode 100644 index 00000000..6ba9bd8f --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_case_expression.cpp @@ -0,0 +1,43 @@ +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(CaseExpression &expr, idx_t depth) { + // first try to bind the children of the case expression + string error; + for (auto &check : expr.case_checks) { + BindChild(check.when_expr, depth, error); + BindChild(check.then_expr, depth, error); + } + BindChild(expr.else_expr, depth, error); + if (!error.empty()) { + return BindResult(error); + } + // the children have been successfully resolved + // figure out the result type of the CASE expression + auto &else_expr = BoundExpression::GetExpression(*expr.else_expr); + auto return_type = else_expr->return_type; + for (auto &check : expr.case_checks) { + auto &then_expr = BoundExpression::GetExpression(*check.then_expr); + return_type = LogicalType::MaxLogicalType(return_type, then_expr->return_type); + } + + // bind all the individual components of the CASE statement + auto result = make_uniq(return_type); + for (idx_t i = 0; i < expr.case_checks.size(); i++) { + auto &check = expr.case_checks[i]; + auto &when_expr = BoundExpression::GetExpression(*check.when_expr); + auto &then_expr = BoundExpression::GetExpression(*check.then_expr); + BoundCaseCheck result_check; + result_check.when_expr = + BoundCastExpression::AddCastToType(context, std::move(when_expr), LogicalType::BOOLEAN); + result_check.then_expr = BoundCastExpression::AddCastToType(context, std::move(then_expr), return_type); + result->case_checks.push_back(std::move(result_check)); + } + result->else_expr = BoundCastExpression::AddCastToType(context, std::move(else_expr), return_type); + return BindResult(std::move(result)); +} +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_cast_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_cast_expression.cpp new file mode 100644 index 00000000..a62dd008 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_cast_expression.cpp @@ -0,0 +1,32 @@ +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(CastExpression &expr, idx_t depth) { + // first try to bind the child of the cast expression + string error = Bind(expr.child, depth); + if (!error.empty()) { + return BindResult(error); + } + // FIXME: We can also implement 'hello'::schema.custom_type; and pass by the schema down here. + // Right now just considering its DEFAULT_SCHEMA always + Binder::BindLogicalType(context, expr.cast_type); + // the children have been successfully resolved + auto &child = BoundExpression::GetExpression(*expr.child); + if (expr.try_cast) { + if (child->return_type == expr.cast_type) { + // no cast required: type matches + return BindResult(std::move(child)); + } + child = BoundCastExpression::AddCastToType(context, std::move(child), expr.cast_type, true); + } else { + // otherwise add a cast to the target type + child = BoundCastExpression::AddCastToType(context, std::move(child), expr.cast_type); + } + return BindResult(std::move(child)); +} +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_collate_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_collate_expression.cpp new file mode 100644 index 00000000..cbe51e7d --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_collate_expression.cpp @@ -0,0 +1,26 @@ +#include "duckdb/parser/expression/collate_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(CollateExpression &expr, idx_t depth) { + // first try to bind the child of the cast expression + string error = Bind(expr.child, depth); + if (!error.empty()) { + return BindResult(error); + } + auto &child = BoundExpression::GetExpression(*expr.child); + if (child->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (child->return_type.id() != LogicalTypeId::VARCHAR) { + throw BinderException("collations are only supported for type varchar"); + } + // Validate the collation, but don't use it + PushCollation(context, child->Copy(), expr.collation, false); + child->return_type = LogicalType::VARCHAR_COLLATION(expr.collation); + return BindResult(std::move(child)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp new file mode 100644 index 00000000..47f94962 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp @@ -0,0 +1,397 @@ +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_lambdaref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/expression_binder/where_binder.hpp" + +namespace duckdb { + +string GetSQLValueFunctionName(const string &column_name) { + auto lcase = StringUtil::Lower(column_name); + if (lcase == "current_catalog") { + return "current_catalog"; + } else if (lcase == "current_date") { + return "current_date"; + } else if (lcase == "current_schema") { + return "current_schema"; + } else if (lcase == "current_role") { + return "current_role"; + } else if (lcase == "current_time") { + return "get_current_time"; + } else if (lcase == "current_timestamp") { + return "get_current_timestamp"; + } else if (lcase == "current_user") { + return "current_user"; + } else if (lcase == "localtime") { + return "current_localtime"; + } else if (lcase == "localtimestamp") { + return "current_localtimestamp"; + } else if (lcase == "session_user") { + return "session_user"; + } else if (lcase == "user") { + return "user"; + } + return string(); +} + +unique_ptr ExpressionBinder::GetSQLValueFunction(const string &column_name) { + auto value_function = GetSQLValueFunctionName(column_name); + if (value_function.empty()) { + return nullptr; + } + + vector> children; + return make_uniq(value_function, std::move(children)); +} + +unique_ptr ExpressionBinder::QualifyColumnName(const string &column_name, string &error_message) { + auto using_binding = binder.bind_context.GetUsingBinding(column_name); + if (using_binding) { + // we are referencing a USING column + // check if we can refer to one of the base columns directly + unique_ptr expression; + if (!using_binding->primary_binding.empty()) { + // we can! just assign the table name and re-bind + return binder.bind_context.CreateColumnReference(using_binding->primary_binding, column_name); + } else { + // // we cannot! we need to bind this as a coalesce between all the relevant columns + auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE); + coalesce->children.reserve(using_binding->bindings.size()); + for (auto &entry : using_binding->bindings) { + coalesce->children.push_back(make_uniq(column_name, entry)); + } + return std::move(coalesce); + } + } + + // find a binding that contains this + string table_name = binder.bind_context.GetMatchingBinding(column_name); + + // throw an error if a macro conflicts with a column name + auto is_macro_column = false; + if (binder.macro_binding != nullptr && binder.macro_binding->HasMatchingBinding(column_name)) { + is_macro_column = true; + if (!table_name.empty()) { + throw BinderException("Conflicting column names for column " + column_name + "!"); + } + } + + if (lambda_bindings) { + for (idx_t i = 0; i < lambda_bindings->size(); i++) { + if ((*lambda_bindings)[i].HasMatchingBinding(column_name)) { + + // throw an error if a lambda conflicts with a column name or a macro + if (!table_name.empty() || is_macro_column) { + throw BinderException("Conflicting column names for column " + column_name + "!"); + } + + D_ASSERT(!(*lambda_bindings)[i].alias.empty()); + return make_uniq(column_name, (*lambda_bindings)[i].alias); + } + } + } + + if (is_macro_column) { + D_ASSERT(!binder.macro_binding->alias.empty()); + return make_uniq(column_name, binder.macro_binding->alias); + } + // see if it's a column + if (table_name.empty()) { + // column was not found - check if it is a SQL value function + auto value_function = GetSQLValueFunction(column_name); + if (value_function) { + return value_function; + } + // it's not, find candidates and error + auto similar_bindings = binder.bind_context.GetSimilarBindings(column_name); + string candidate_str = StringUtil::CandidatesMessage(similar_bindings, "Candidate bindings"); + error_message = + StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", column_name, candidate_str); + return nullptr; + } + return binder.bind_context.CreateColumnReference(table_name, column_name); +} + +void ExpressionBinder::QualifyColumnNames(unique_ptr &expr) { + switch (expr->type) { + case ExpressionType::COLUMN_REF: { + auto &colref = expr->Cast(); + string error_message; + auto new_expr = QualifyColumnName(colref, error_message); + if (new_expr) { + if (!expr->alias.empty()) { + new_expr->alias = expr->alias; + } + new_expr->query_location = colref.query_location; + expr = std::move(new_expr); + } + break; + } + case ExpressionType::POSITIONAL_REFERENCE: { + auto &ref = expr->Cast(); + if (ref.alias.empty()) { + string table_name, column_name; + auto error = binder.bind_context.BindColumn(ref, table_name, column_name); + if (error.empty()) { + ref.alias = column_name; + } + } + break; + } + default: + break; + } + ParsedExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { QualifyColumnNames(child); }); +} + +void ExpressionBinder::QualifyColumnNames(Binder &binder, unique_ptr &expr) { + WhereBinder where_binder(binder, binder.context); + where_binder.QualifyColumnNames(expr); +} + +unique_ptr ExpressionBinder::CreateStructExtract(unique_ptr base, + string field_name) { + + // we need to transform the struct extract if it is inside a lambda expression + // because we cannot bind to an existing table, so we remove the dummy table also + if (lambda_bindings && base->type == ExpressionType::COLUMN_REF) { + auto &lambda_column_ref = base->Cast(); + D_ASSERT(!lambda_column_ref.column_names.empty()); + + if (lambda_column_ref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos) { + D_ASSERT(lambda_column_ref.column_names.size() == 2); + auto lambda_param_name = lambda_column_ref.column_names.back(); + lambda_column_ref.column_names.clear(); + lambda_column_ref.column_names.push_back(lambda_param_name); + } + } + + vector> children; + children.push_back(std::move(base)); + children.push_back(make_uniq_base(Value(std::move(field_name)))); + auto extract_fun = make_uniq(ExpressionType::STRUCT_EXTRACT, std::move(children)); + return std::move(extract_fun); +} + +unique_ptr ExpressionBinder::CreateStructPack(ColumnRefExpression &colref) { + D_ASSERT(colref.column_names.size() <= 3); + string error_message; + auto &table_name = colref.column_names.back(); + auto binding = binder.bind_context.GetBinding(table_name, error_message); + if (!binding) { + return nullptr; + } + if (colref.column_names.size() >= 2) { + // "schema_name.table_name" + auto catalog_entry = binding->GetStandardEntry(); + if (!catalog_entry) { + return nullptr; + } + if (catalog_entry->name != table_name) { + return nullptr; + } + if (colref.column_names.size() == 2) { + auto &qualifier = colref.column_names[0]; + if (catalog_entry->catalog.GetName() != qualifier && catalog_entry->schema.name != qualifier) { + return nullptr; + } + } else if (colref.column_names.size() == 3) { + auto &catalog_name = colref.column_names[0]; + auto &schema_name = colref.column_names[1]; + if (catalog_entry->catalog.GetName() != catalog_name || catalog_entry->schema.name != schema_name) { + return nullptr; + } + } else { + throw InternalException("Expected 2 or 3 column names for CreateStructPack"); + } + } + // We found the table, now create the struct_pack expression + vector> child_expressions; + child_expressions.reserve(binding->names.size()); + for (const auto &column_name : binding->names) { + child_expressions.push_back(make_uniq(column_name, table_name)); + } + return make_uniq("struct_pack", std::move(child_expressions)); +} + +unique_ptr ExpressionBinder::QualifyColumnName(ColumnRefExpression &colref, string &error_message) { + idx_t column_parts = colref.column_names.size(); + // column names can have an arbitrary amount of dots + // here is how the resolution works: + if (column_parts == 1) { + // no dots (i.e. "part1") + // -> part1 refers to a column + // check if we can qualify the column name with the table name + auto qualified_colref = QualifyColumnName(colref.GetColumnName(), error_message); + if (qualified_colref) { + // we could: return it + return qualified_colref; + } + // we could not! Try creating an implicit struct_pack + return CreateStructPack(colref); + } else if (column_parts == 2) { + // one dot (i.e. "part1.part2") + // EITHER: + // -> part1 is a table, part2 is a column + // -> part1 is a column, part2 is a property of that column (i.e. struct_extract) + + // first check if part1 is a table, and part2 is a standard column + if (binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], error_message)) { + // it is! return the colref directly + return binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1]); + } else { + // otherwise check if we can turn this into a struct extract + auto new_colref = make_uniq(colref.column_names[0]); + string other_error; + auto qualified_colref = QualifyColumnName(colref.column_names[0], other_error); + if (qualified_colref) { + // we could: create a struct extract + return CreateStructExtract(std::move(qualified_colref), colref.column_names[1]); + } + // we could not! Try creating an implicit struct_pack + return CreateStructPack(colref); + } + } else { + // two or more dots (i.e. "part1.part2.part3.part4...") + // -> part1 is a catalog, part2 is a schema, part3 is a table, part4 is a column name, part 5 and beyond are + // struct fields + // -> part1 is a catalog, part2 is a table, part3 is a column name, part4 and beyond are struct fields + // -> part1 is a schema, part2 is a table, part3 is a column name, part4 and beyond are struct fields + // -> part1 is a table, part2 is a column name, part3 and beyond are struct fields + // -> part1 is a column, part2 and beyond are struct fields + + // we always prefer the most top-level view + // i.e. in case of multiple resolution options, we resolve in order: + // -> 1. resolve "part1" as a catalog + // -> 2. resolve "part1" as a schema + // -> 3. resolve "part1" as a table + // -> 4. resolve "part1" as a column + + unique_ptr result_expr; + idx_t struct_extract_start; + // first check if part1 is a catalog + if (colref.column_names.size() > 3 && + binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], colref.column_names[2], + colref.column_names[3], error_message)) { + // part1 is a catalog - the column reference is "catalog.schema.table.column" + result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1], + colref.column_names[2], colref.column_names[3]); + struct_extract_start = 4; + } else if (binder.HasMatchingBinding(colref.column_names[0], INVALID_SCHEMA, colref.column_names[1], + colref.column_names[2], error_message)) { + // part1 is a catalog - the column reference is "catalog.table.column" + result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], INVALID_SCHEMA, + colref.column_names[1], colref.column_names[2]); + struct_extract_start = 3; + } else if (binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], colref.column_names[2], + error_message)) { + // part1 is a schema - the column reference is "schema.table.column" + // any additional fields are turned into struct_extract calls + result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1], + colref.column_names[2]); + struct_extract_start = 3; + } else if (binder.HasMatchingBinding(colref.column_names[0], colref.column_names[1], error_message)) { + // part1 is a table + // the column reference is "table.column" + // any additional fields are turned into struct_extract calls + result_expr = binder.bind_context.CreateColumnReference(colref.column_names[0], colref.column_names[1]); + struct_extract_start = 2; + } else { + // part1 could be a column + string col_error; + result_expr = QualifyColumnName(colref.column_names[0], col_error); + if (!result_expr) { + // it is not! Try creating an implicit struct_pack + return CreateStructPack(colref); + } + // it is! add the struct extract calls + struct_extract_start = 1; + } + for (idx_t i = struct_extract_start; i < colref.column_names.size(); i++) { + result_expr = CreateStructExtract(std::move(result_expr), colref.column_names[i]); + } + return result_expr; + } +} + +BindResult ExpressionBinder::BindExpression(ColumnRefExpression &colref_p, idx_t depth) { + if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { + return BindResult(make_uniq(Value(LogicalType::SQLNULL))); + } + string error_message; + auto expr = QualifyColumnName(colref_p, error_message); + if (!expr) { + return BindResult(binder.FormatError(colref_p, error_message)); + } + expr->query_location = colref_p.query_location; + + // a generated column returns a generated expression, a struct on a column returns a struct extract + if (expr->type != ExpressionType::COLUMN_REF) { + auto alias = expr->alias; + auto result = BindExpression(expr, depth); + if (result.expression) { + result.expression->alias = std::move(alias); + } + return result; + } + + auto &colref = expr->Cast(); + D_ASSERT(colref.IsQualified()); + auto &table_name = colref.GetTableName(); + + // individual column reference + // resolve to either a base table or a subquery expression + // if it was a macro parameter, let macro_binding bind it to the argument + // if it was a lambda parameter, let lambda_bindings bind it to the argument + + BindResult result; + + auto found_lambda_binding = false; + if (lambda_bindings) { + for (idx_t i = 0; i < lambda_bindings->size(); i++) { + if (table_name == (*lambda_bindings)[i].alias) { + result = (*lambda_bindings)[i].Bind(colref, i, depth); + found_lambda_binding = true; + break; + } + } + } + + if (!found_lambda_binding) { + if (binder.macro_binding && table_name == binder.macro_binding->alias) { + result = binder.macro_binding->Bind(colref, depth); + } else { + result = binder.bind_context.BindColumn(colref, depth); + } + } + + if (!result.HasError()) { + BoundColumnReferenceInfo ref; + ref.name = colref.column_names.back(); + ref.query_location = colref.query_location; + bound_columns.push_back(std::move(ref)); + } else { + result.error = binder.FormatError(colref_p, result.error); + } + return result; +} + +bool ExpressionBinder::QualifyColumnAlias(const ColumnRefExpression &colref) { + // Only BaseSelectBinder will have a valid col alias map, + // otherwise just return false + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_comparison_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_comparison_expression.cpp new file mode 100644 index 00000000..0a452f71 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_comparison_expression.cpp @@ -0,0 +1,147 @@ +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" + +#include "duckdb/function/scalar/string_functions.hpp" + +#include "duckdb/common/types/decimal.hpp" + +#include "duckdb/main/config.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +unique_ptr ExpressionBinder::PushCollation(ClientContext &context, unique_ptr source, + const string &collation_p, bool equality_only) { + // replace default collation with system collation + string collation; + if (collation_p.empty()) { + collation = DBConfig::GetConfig(context).options.collation; + } else { + collation = collation_p; + } + collation = StringUtil::Lower(collation); + // bind the collation + if (collation.empty() || collation == "binary" || collation == "c" || collation == "posix") { + // binary collation: just skip + return source; + } + auto &catalog = Catalog::GetSystemCatalog(context); + auto splits = StringUtil::Split(StringUtil::Lower(collation), "."); + vector> entries; + for (auto &collation_argument : splits) { + auto &collation_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, collation_argument); + if (collation_entry.combinable) { + entries.insert(entries.begin(), collation_entry); + } else { + if (!entries.empty() && !entries.back().get().combinable) { + throw BinderException("Cannot combine collation types \"%s\" and \"%s\"", entries.back().get().name, + collation_entry.name); + } + entries.push_back(collation_entry); + } + } + for (auto &entry : entries) { + auto &collation_entry = entry.get(); + if (equality_only && collation_entry.not_required_for_equality) { + continue; + } + vector> children; + children.push_back(std::move(source)); + + FunctionBinder function_binder(context); + auto function = function_binder.BindScalarFunction(collation_entry.function, std::move(children)); + source = std::move(function); + } + return source; +} + +void ExpressionBinder::TestCollation(ClientContext &context, const string &collation) { + PushCollation(context, make_uniq(Value("")), collation); +} + +LogicalType BoundComparisonExpression::BindComparison(LogicalType left_type, LogicalType right_type) { + auto result_type = LogicalType::MaxLogicalType(left_type, right_type); + switch (result_type.id()) { + case LogicalTypeId::DECIMAL: { + // result is a decimal: we need the maximum width and the maximum scale over width + vector argument_types = {left_type, right_type}; + uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; + for (idx_t i = 0; i < argument_types.size(); i++) { + uint8_t width, scale; + auto can_convert = argument_types[i].GetDecimalProperties(width, scale); + if (!can_convert) { + return result_type; + } + max_width = MaxValue(width, max_width); + max_scale = MaxValue(scale, max_scale); + max_width_over_scale = MaxValue(width - scale, max_width_over_scale); + } + max_width = MaxValue(max_scale + max_width_over_scale, max_width); + if (max_width > Decimal::MAX_WIDTH_DECIMAL) { + // target width does not fit in decimal: truncate the scale (if possible) to try and make it fit + max_width = Decimal::MAX_WIDTH_DECIMAL; + } + return LogicalType::DECIMAL(max_width, max_scale); + } + case LogicalTypeId::VARCHAR: + // for comparison with strings, we prefer to bind to the numeric types + if (left_type.IsNumeric() || left_type.id() == LogicalTypeId::BOOLEAN) { + return left_type; + } else if (right_type.IsNumeric() || right_type.id() == LogicalTypeId::BOOLEAN) { + return right_type; + } else { + // else: check if collations are compatible + auto left_collation = StringType::GetCollation(left_type); + auto right_collation = StringType::GetCollation(right_type); + if (!left_collation.empty() && !right_collation.empty() && left_collation != right_collation) { + throw BinderException("Cannot combine types with different collation!"); + } + } + return result_type; + default: + return result_type; + } +} + +BindResult ExpressionBinder::BindExpression(ComparisonExpression &expr, idx_t depth) { + // first try to bind the children of the case expression + string error; + BindChild(expr.left, depth, error); + BindChild(expr.right, depth, error); + if (!error.empty()) { + return BindResult(error); + } + + // the children have been successfully resolved + auto &left = BoundExpression::GetExpression(*expr.left); + auto &right = BoundExpression::GetExpression(*expr.right); + auto left_sql_type = left->return_type; + auto right_sql_type = right->return_type; + // cast the input types to the same type + // now obtain the result type of the input types + auto input_type = BoundComparisonExpression::BindComparison(left_sql_type, right_sql_type); + // add casts (if necessary) + left = BoundCastExpression::AddCastToType(context, std::move(left), input_type, + input_type.id() == LogicalTypeId::ENUM); + right = BoundCastExpression::AddCastToType(context, std::move(right), input_type, + input_type.id() == LogicalTypeId::ENUM); + + if (input_type.id() == LogicalTypeId::VARCHAR) { + // handle collation + auto collation = StringType::GetCollation(input_type); + left = PushCollation(context, std::move(left), collation, expr.type == ExpressionType::COMPARE_EQUAL); + right = PushCollation(context, std::move(right), collation, expr.type == ExpressionType::COMPARE_EQUAL); + } + // now create the bound comparison expression + return BindResult(make_uniq(expr.type, std::move(left), std::move(right))); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_conjunction_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_conjunction_expression.cpp new file mode 100644 index 00000000..70b990a0 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_conjunction_expression.cpp @@ -0,0 +1,29 @@ +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(ConjunctionExpression &expr, idx_t depth) { + // first try to bind the children of the case expression + string error; + for (idx_t i = 0; i < expr.children.size(); i++) { + BindChild(expr.children[i], depth, error); + } + if (!error.empty()) { + return BindResult(error); + } + // the children have been successfully resolved + // cast the input types to boolean (if necessary) + // and construct the bound conjunction expression + auto result = make_uniq(expr.type); + for (auto &child_expr : expr.children) { + auto &child = BoundExpression::GetExpression(*child_expr); + result->children.push_back(BoundCastExpression::AddCastToType(context, std::move(child), LogicalType::BOOLEAN)); + } + // now create the bound conjunction expression + return BindResult(std::move(result)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_constant_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_constant_expression.cpp new file mode 100644 index 00000000..1222cf46 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_constant_expression.cpp @@ -0,0 +1,11 @@ +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(ConstantExpression &expr, idx_t depth) { + return BindResult(make_uniq(expr.value)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp new file mode 100644 index 00000000..c2a6179e --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp @@ -0,0 +1,261 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/lambda_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_lambda_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(FunctionExpression &function, idx_t depth, + unique_ptr &expr_ptr) { + // lookup the function in the catalog + QueryErrorContext error_context(binder.root_statement, function.query_location); + auto func = Catalog::GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, function.catalog, function.schema, + function.function_name, OnEntryNotFound::RETURN_NULL, error_context); + if (!func) { + // function was not found - check if we this is a table function + auto table_func = + Catalog::GetEntry(context, CatalogType::TABLE_FUNCTION_ENTRY, function.catalog, function.schema, + function.function_name, OnEntryNotFound::RETURN_NULL, error_context); + if (table_func) { + throw BinderException(binder.FormatError( + function, + StringUtil::Format("Function \"%s\" is a table function but it was used as a scalar function. This " + "function has to be called in a FROM clause (similar to a table).", + function.function_name))); + } + // not a table function - check if the schema is set + if (!function.schema.empty()) { + // the schema is set - check if we can turn this the schema into a column ref + string error; + unique_ptr colref; + if (function.catalog.empty()) { + colref = make_uniq(function.schema); + } else { + colref = make_uniq(function.schema, function.catalog); + } + auto new_colref = QualifyColumnName(*colref, error); + bool is_col = error.empty() ? true : false; + bool is_col_alias = QualifyColumnAlias(*colref); + + if (is_col || is_col_alias) { + // we can! transform this into a function call on the column + // i.e. "x.lower()" becomes "lower(x)" + function.children.insert(function.children.begin(), std::move(colref)); + function.catalog = INVALID_CATALOG; + function.schema = INVALID_SCHEMA; + } + } + // rebind the function + func = Catalog::GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, function.catalog, function.schema, + function.function_name, OnEntryNotFound::THROW_EXCEPTION, error_context); + } + + if (func->type != CatalogType::AGGREGATE_FUNCTION_ENTRY && + (function.distinct || function.filter || !function.order_bys->orders.empty())) { + throw InvalidInputException("Function \"%s\" is a %s. \"DISTINCT\", \"FILTER\", and \"ORDER BY\" are only " + "applicable to aggregate functions.", + function.function_name, CatalogTypeToString(func->type)); + } + + switch (func->type) { + case CatalogType::SCALAR_FUNCTION_ENTRY: { + // scalar function + + // check for lambda parameters, ignore ->> operator (JSON extension) + bool try_bind_lambda = false; + if (function.function_name != "->>") { + for (auto &child : function.children) { + if (child->expression_class == ExpressionClass::LAMBDA) { + try_bind_lambda = true; + } + } + } + + if (try_bind_lambda) { + auto result = BindLambdaFunction(function, func->Cast(), depth); + if (!result.HasError()) { + // Lambda bind successful + return result; + } + } + + // other scalar function + return BindFunction(function, func->Cast(), depth); + } + case CatalogType::MACRO_ENTRY: + // macro function + return BindMacro(function, func->Cast(), depth, expr_ptr); + default: + // aggregate function + return BindAggregate(function, func->Cast(), depth); + } +} + +BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, idx_t depth) { + + // bind the children of the function expression + string error; + + // bind of each child + for (idx_t i = 0; i < function.children.size(); i++) { + BindChild(function.children[i], depth, error); + } + + if (!error.empty()) { + return BindResult(error); + } + if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { + return BindResult(make_uniq(Value(LogicalType::SQLNULL))); + } + + // all children bound successfully + // extract the children and types + vector> children; + for (idx_t i = 0; i < function.children.size(); i++) { + auto &child = BoundExpression::GetExpression(*function.children[i]); + children.push_back(std::move(child)); + } + + FunctionBinder function_binder(context); + unique_ptr result = + function_binder.BindScalarFunction(func, std::move(children), error, function.is_operator, &binder); + if (!result) { + throw BinderException(binder.FormatError(function, error)); + } + return BindResult(std::move(result)); +} + +BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, + idx_t depth) { + + // bind the children of the function expression + string error; + + if (function.children.size() != 2) { + return BindResult("Invalid function arguments!"); + } + D_ASSERT(function.children[1]->GetExpressionClass() == ExpressionClass::LAMBDA); + + // bind the list parameter + BindChild(function.children[0], depth, error); + if (!error.empty()) { + return BindResult(error); + } + + // get the logical type of the children of the list + auto &list_child = BoundExpression::GetExpression(*function.children[0]); + if (list_child->return_type.id() != LogicalTypeId::LIST && list_child->return_type.id() != LogicalTypeId::SQLNULL && + list_child->return_type.id() != LogicalTypeId::UNKNOWN) { + return BindResult(" Invalid LIST argument to " + function.function_name + "!"); + } + + LogicalType list_child_type = list_child->return_type.id(); + if (list_child->return_type.id() != LogicalTypeId::SQLNULL && + list_child->return_type.id() != LogicalTypeId::UNKNOWN) { + list_child_type = ListType::GetChildType(list_child->return_type); + } + + // bind the lambda parameter + auto &lambda_expr = function.children[1]->Cast(); + BindResult bind_lambda_result = BindExpression(lambda_expr, depth, true, list_child_type); + + if (bind_lambda_result.HasError()) { + error = bind_lambda_result.error; + } else { + // successfully bound: replace the node with a BoundExpression + auto alias = function.children[1]->alias; + bind_lambda_result.expression->alias = alias; + if (!alias.empty()) { + bind_lambda_result.expression->alias = alias; + } + function.children[1] = make_uniq(std::move(bind_lambda_result.expression)); + } + + if (!error.empty()) { + return BindResult(error); + } + if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { + return BindResult(make_uniq(Value(LogicalType::SQLNULL))); + } + + // all children bound successfully + // extract the children and types + vector> children; + for (idx_t i = 0; i < function.children.size(); i++) { + auto &child = BoundExpression::GetExpression(*function.children[i]); + children.push_back(std::move(child)); + } + + // capture the (lambda) columns + auto &bound_lambda_expr = children.back()->Cast(); + CaptureLambdaColumns(bound_lambda_expr.captures, list_child_type, bound_lambda_expr.lambda_expr); + + FunctionBinder function_binder(context); + unique_ptr result = + function_binder.BindScalarFunction(func, std::move(children), error, function.is_operator, &binder); + if (!result) { + throw BinderException(binder.FormatError(function, error)); + } + + auto &bound_function_expr = result->Cast(); + D_ASSERT(bound_function_expr.children.size() == 2); + + // remove the lambda expression from the children + auto lambda = std::move(bound_function_expr.children.back()); + bound_function_expr.children.pop_back(); + auto &bound_lambda = lambda->Cast(); + + // push back (in reverse order) any nested lambda parameters so that we can later use them in the lambda expression + // (rhs) + if (lambda_bindings) { + for (idx_t i = lambda_bindings->size(); i > 0; i--) { + + idx_t lambda_index = lambda_bindings->size() - i + 1; + auto &binding = (*lambda_bindings)[i - 1]; + + D_ASSERT(binding.names.size() == 1); + D_ASSERT(binding.types.size() == 1); + + auto bound_lambda_param = + make_uniq(binding.names[0], binding.types[0], lambda_index); + bound_function_expr.children.push_back(std::move(bound_lambda_param)); + } + } + + // push back the captures into the children vector and the correct return types into the bound_function arguments + for (auto &capture : bound_lambda.captures) { + bound_function_expr.children.push_back(std::move(capture)); + } + + return BindResult(std::move(result)); +} + +BindResult ExpressionBinder::BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, + idx_t depth) { + return BindResult(binder.FormatError(expr, UnsupportedAggregateMessage())); +} + +BindResult ExpressionBinder::BindUnnest(FunctionExpression &expr, idx_t depth, bool root_expression) { + return BindResult(binder.FormatError(expr, UnsupportedUnnestMessage())); +} + +string ExpressionBinder::UnsupportedAggregateMessage() { + return "Aggregate functions are not supported here"; +} + +string ExpressionBinder::UnsupportedUnnestMessage() { + return "UNNEST not supported here"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_lambda.cpp b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp new file mode 100644 index 00000000..56bb87ed --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp @@ -0,0 +1,178 @@ +#include "duckdb/parser/expression/lambda_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/bind_context.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_lambdaref_expression.hpp" +#include "duckdb/planner/expression/bound_lambda_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(LambdaExpression &expr, idx_t depth, const bool is_lambda, + const LogicalType &list_child_type) { + + if (!is_lambda) { + // this is for binding JSON + auto lhs_expr = expr.lhs->Copy(); + OperatorExpression arrow_expr(ExpressionType::ARROW, std::move(lhs_expr), expr.expr->Copy()); + return BindExpression(arrow_expr, depth); + } + + // binding the lambda expression + D_ASSERT(expr.lhs); + if (expr.lhs->expression_class != ExpressionClass::FUNCTION && + expr.lhs->expression_class != ExpressionClass::COLUMN_REF) { + throw BinderException( + "Invalid parameter list! Parameters must be comma-separated column names, e.g. x or (x, y)."); + } + + // move the lambda parameters to the params vector + if (expr.lhs->expression_class == ExpressionClass::COLUMN_REF) { + expr.params.push_back(std::move(expr.lhs)); + } else { + auto &func_expr = expr.lhs->Cast(); + for (idx_t i = 0; i < func_expr.children.size(); i++) { + expr.params.push_back(std::move(func_expr.children[i])); + } + } + D_ASSERT(!expr.params.empty()); + + // create dummy columns for the lambda parameters (lhs) + vector column_types; + vector column_names; + vector params_strings; + + // positional parameters as column references + for (idx_t i = 0; i < expr.params.size(); i++) { + if (expr.params[i]->GetExpressionClass() != ExpressionClass::COLUMN_REF) { + throw BinderException("Parameter must be a column name."); + } + + auto column_ref = expr.params[i]->Cast(); + if (column_ref.IsQualified()) { + throw BinderException("Invalid parameter name '%s': must be unqualified", column_ref.ToString()); + } + + column_types.emplace_back(list_child_type); + column_names.push_back(column_ref.GetColumnName()); + params_strings.push_back(expr.params[i]->ToString()); + } + + // base table alias + auto params_alias = StringUtil::Join(params_strings, ", "); + if (params_strings.size() > 1) { + params_alias = "(" + params_alias + ")"; + } + + // create a lambda binding and push it to the lambda bindings vector + vector local_bindings; + if (!lambda_bindings) { + lambda_bindings = &local_bindings; + } + DummyBinding new_lambda_binding(column_types, column_names, params_alias); + lambda_bindings->push_back(new_lambda_binding); + + // bind the parameter expressions + for (idx_t i = 0; i < expr.params.size(); i++) { + auto result = BindExpression(expr.params[i], depth, false); + if (result.HasError()) { + throw InternalException("Error during lambda binding: %s", result.error); + } + } + + auto result = BindExpression(expr.expr, depth, false); + lambda_bindings->pop_back(); + + // successfully bound a subtree of nested lambdas, set this to nullptr in case other parts of the + // query also contain lambdas + if (lambda_bindings->empty()) { + lambda_bindings = nullptr; + } + + if (result.HasError()) { + throw BinderException(result.error); + } + + return BindResult(make_uniq(ExpressionType::LAMBDA, LogicalType::LAMBDA, + std::move(result.expression), params_strings.size())); +} + +void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &original, + unique_ptr &replacement, + vector> &captures, + LogicalType &list_child_type) { + + // check if the original expression is a lambda parameter + if (original->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { + + // determine if this is the lambda parameter + auto &bound_lambda_ref = original->Cast(); + auto alias = bound_lambda_ref.alias; + + if (lambda_bindings && bound_lambda_ref.lambda_index != lambda_bindings->size()) { + + D_ASSERT(bound_lambda_ref.lambda_index < lambda_bindings->size()); + auto &lambda_binding = (*lambda_bindings)[bound_lambda_ref.lambda_index]; + + D_ASSERT(lambda_binding.names.size() == 1); + D_ASSERT(lambda_binding.types.size() == 1); + // refers to a lambda parameter outside of the current lambda function + replacement = + make_uniq(lambda_binding.names[0], lambda_binding.types[0], + lambda_bindings->size() - bound_lambda_ref.lambda_index + 1); + + } else { + // refers to current lambda parameter + replacement = make_uniq(alias, list_child_type, 0); + } + + } else { + // always at least the current lambda parameter + idx_t index_offset = 1; + if (lambda_bindings) { + index_offset += lambda_bindings->size(); + } + + // this is not a lambda parameter, so we need to create a new argument for the arguments vector + replacement = make_uniq(original->alias, original->return_type, + captures.size() + index_offset + 1); + captures.push_back(std::move(original)); + } +} + +void ExpressionBinder::CaptureLambdaColumns(vector> &captures, LogicalType &list_child_type, + unique_ptr &expr) { + + if (expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { + throw InvalidInputException("Subqueries are not supported in lambda expressions!"); + } + + // these expression classes do not have children, transform them + if (expr->expression_class == ExpressionClass::BOUND_CONSTANT || + expr->expression_class == ExpressionClass::BOUND_COLUMN_REF || + expr->expression_class == ExpressionClass::BOUND_PARAMETER || + expr->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { + + // move the expr because we are going to replace it + auto original = std::move(expr); + unique_ptr replacement; + + TransformCapturedLambdaColumn(original, replacement, captures, list_child_type); + + // replace the expression + expr = std::move(replacement); + + } else { + // recursively enumerate the children of the expression + ExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { CaptureLambdaColumns(captures, list_child_type, child); }); + } + + expr->Verify(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp new file mode 100644 index 00000000..cce36d49 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp @@ -0,0 +1,92 @@ +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/scalar_macro_function.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +void ExpressionBinder::ReplaceMacroParametersRecursive(unique_ptr &expr) { + switch (expr->GetExpressionClass()) { + case ExpressionClass::COLUMN_REF: { + // if expr is a parameter, replace it with its argument + auto &colref = expr->Cast(); + bool bind_macro_parameter = false; + if (colref.IsQualified()) { + bind_macro_parameter = false; + if (colref.GetTableName().find(DummyBinding::DUMMY_NAME) != string::npos) { + bind_macro_parameter = true; + } + } else { + bind_macro_parameter = macro_binding->HasMatchingBinding(colref.GetColumnName()); + } + if (bind_macro_parameter) { + D_ASSERT(macro_binding->HasMatchingBinding(colref.GetColumnName())); + expr = macro_binding->ParamToArg(colref); + } + return; + } + case ExpressionClass::SUBQUERY: { + // replacing parameters within a subquery is slightly different + auto &sq = (expr->Cast()).subquery; + ParsedExpressionIterator::EnumerateQueryNodeChildren( + *sq->node, [&](unique_ptr &child) { ReplaceMacroParametersRecursive(child); }); + break; + } + default: // fall through + break; + } + // unfold child expressions + ParsedExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { ReplaceMacroParametersRecursive(child); }); +} + +BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacroCatalogEntry ¯o_func, idx_t depth, + unique_ptr &expr) { + // recast function so we can access the scalar member function->expression + auto ¯o_def = macro_func.function->Cast(); + + // validate the arguments and separate positional and default arguments + vector> positionals; + unordered_map> defaults; + + string error = + MacroFunction::ValidateArguments(*macro_func.function, macro_func.name, function, positionals, defaults); + if (!error.empty()) { + throw BinderException(binder.FormatError(*expr, error)); + } + + // create a MacroBinding to bind this macro's parameters to its arguments + vector types; + vector names; + // positional parameters + for (idx_t i = 0; i < macro_def.parameters.size(); i++) { + types.emplace_back(LogicalType::SQLNULL); + auto ¶m = macro_def.parameters[i]->Cast(); + names.push_back(param.GetColumnName()); + } + // default parameters + for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { + types.emplace_back(LogicalType::SQLNULL); + names.push_back(it->first); + // now push the defaults into the positionals + positionals.push_back(std::move(defaults[it->first])); + } + auto new_macro_binding = make_uniq(types, names, macro_func.name); + new_macro_binding->arguments = &positionals; + macro_binding = new_macro_binding.get(); + + // replace current expression with stored macro expression + expr = macro_def.expression->Copy(); + + // now replace the parameters + ReplaceMacroParametersRecursive(expr); + + // bind the unfolded macro + return BindExpression(expr, depth); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp new file mode 100644 index 00000000..c19ee655 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp @@ -0,0 +1,157 @@ +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +static LogicalType ResolveNotType(OperatorExpression &op, vector> &children) { + // NOT expression, cast child to BOOLEAN + D_ASSERT(children.size() == 1); + children[0] = BoundCastExpression::AddDefaultCastToType(std::move(children[0]), LogicalType::BOOLEAN); + return LogicalType(LogicalTypeId::BOOLEAN); +} + +static LogicalType ResolveInType(OperatorExpression &op, vector> &children) { + if (children.empty()) { + throw InternalException("IN requires at least a single child node"); + } + // get the maximum type from the children + LogicalType max_type = children[0]->return_type; + bool any_varchar = children[0]->return_type == LogicalType::VARCHAR; + bool any_enum = children[0]->return_type.id() == LogicalTypeId::ENUM; + for (idx_t i = 1; i < children.size(); i++) { + max_type = LogicalType::MaxLogicalType(max_type, children[i]->return_type); + if (children[i]->return_type == LogicalType::VARCHAR) { + any_varchar = true; + } + if (children[i]->return_type.id() == LogicalTypeId::ENUM) { + any_enum = true; + } + } + if (any_varchar && any_enum) { + // For the coalesce function, we must be sure we always upcast the parameters to VARCHAR, if there are at least + // one enum and one varchar + max_type = LogicalType::VARCHAR; + } + + // cast all children to the same type + for (idx_t i = 0; i < children.size(); i++) { + children[i] = BoundCastExpression::AddDefaultCastToType(std::move(children[i]), max_type); + } + // (NOT) IN always returns a boolean + return LogicalType::BOOLEAN; +} + +static LogicalType ResolveOperatorType(OperatorExpression &op, vector> &children) { + switch (op.type) { + case ExpressionType::OPERATOR_IS_NULL: + case ExpressionType::OPERATOR_IS_NOT_NULL: + // IS (NOT) NULL always returns a boolean, and does not cast its children + if (!children[0]->return_type.IsValid()) { + throw ParameterNotResolvedException(); + } + return LogicalType::BOOLEAN; + case ExpressionType::COMPARE_IN: + case ExpressionType::COMPARE_NOT_IN: + return ResolveInType(op, children); + case ExpressionType::OPERATOR_COALESCE: { + ResolveInType(op, children); + return children[0]->return_type; + } + case ExpressionType::OPERATOR_NOT: + return ResolveNotType(op, children); + default: + throw InternalException("Unrecognized expression type for ResolveOperatorType"); + } +} + +BindResult ExpressionBinder::BindGroupingFunction(OperatorExpression &op, idx_t depth) { + return BindResult("GROUPING function is not supported here"); +} + +BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) { + if (op.type == ExpressionType::GROUPING_FUNCTION) { + return BindGroupingFunction(op, depth); + } + // bind the children of the operator expression + string error; + for (idx_t i = 0; i < op.children.size(); i++) { + BindChild(op.children[i], depth, error); + } + if (!error.empty()) { + return BindResult(error); + } + // all children bound successfully + string function_name; + switch (op.type) { + case ExpressionType::ARRAY_EXTRACT: { + D_ASSERT(op.children[0]->expression_class == ExpressionClass::BOUND_EXPRESSION); + auto &b_exp = BoundExpression::GetExpression(*op.children[0]); + if (b_exp->return_type.id() == LogicalTypeId::MAP) { + function_name = "map_extract"; + } else { + function_name = "array_extract"; + } + break; + } + case ExpressionType::ARRAY_SLICE: + function_name = "array_slice"; + break; + case ExpressionType::STRUCT_EXTRACT: { + D_ASSERT(op.children.size() == 2); + D_ASSERT(op.children[0]->expression_class == ExpressionClass::BOUND_EXPRESSION); + D_ASSERT(op.children[1]->expression_class == ExpressionClass::BOUND_EXPRESSION); + auto &extract_exp = BoundExpression::GetExpression(*op.children[0]); + auto &name_exp = BoundExpression::GetExpression(*op.children[1]); + auto extract_expr_type = extract_exp->return_type.id(); + if (extract_expr_type != LogicalTypeId::STRUCT && extract_expr_type != LogicalTypeId::UNION && + extract_expr_type != LogicalTypeId::SQLNULL) { + return BindResult(StringUtil::Format( + "Cannot extract field %s from expression \"%s\" because it is not a struct or a union", + name_exp->ToString(), extract_exp->ToString())); + } + function_name = extract_expr_type == LogicalTypeId::UNION ? "union_extract" : "struct_extract"; + break; + } + case ExpressionType::ARRAY_CONSTRUCTOR: + function_name = "list_value"; + break; + case ExpressionType::ARROW: + function_name = "json_extract"; + break; + default: + break; + } + if (!function_name.empty()) { + auto function = make_uniq_base(function_name, std::move(op.children)); + return BindExpression(function, depth, false); + } + + vector> children; + for (idx_t i = 0; i < op.children.size(); i++) { + D_ASSERT(op.children[i]->expression_class == ExpressionClass::BOUND_EXPRESSION); + children.push_back(std::move(BoundExpression::GetExpression(*op.children[i]))); + } + // now resolve the types + LogicalType result_type = ResolveOperatorType(op, children); + if (op.type == ExpressionType::OPERATOR_COALESCE) { + if (children.empty()) { + throw BinderException("COALESCE needs at least one child"); + } + if (children.size() == 1) { + return BindResult(std::move(children[0])); + } + } + + auto result = make_uniq(op.type, result_type); + for (auto &child : children) { + result->children.push_back(std::move(child)); + } + return BindResult(std::move(result)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp new file mode 100644 index 00000000..d46cf023 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp @@ -0,0 +1,32 @@ +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t depth) { + if (!binder.parameters) { + throw BinderException("Unexpected prepared parameter. This type of statement can't be prepared!"); + } + auto parameter_id = expr.identifier; + + D_ASSERT(binder.parameters); + // Check if a parameter value has already been supplied + auto ¶meter_data = binder.parameters->GetParameterData(); + auto param_data_it = parameter_data.find(parameter_id); + if (param_data_it != parameter_data.end()) { + // it has! emit a constant directly + auto &data = param_data_it->second; + auto constant = make_uniq(data.GetValue()); + constant->alias = expr.alias; + constant->return_type = binder.parameters->GetReturnType(parameter_id); + return BindResult(std::move(constant)); + } + + auto bound_parameter = binder.parameters->BindParameterExpression(expr); + return BindResult(std::move(bound_parameter)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_positional_reference_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_positional_reference_expression.cpp new file mode 100644 index 00000000..fbb27206 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_positional_reference_expression.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +BindResult ExpressionBinder::BindPositionalReference(unique_ptr &expr, idx_t depth, + bool root_expression) { + auto &ref = expr->Cast(); + if (depth != 0) { + throw InternalException("Positional reference expression could not be bound"); + } + // replace the positional reference with a column + auto column = binder.bind_context.PositionToColumn(ref); + expr = std::move(column); + return BindExpression(expr, depth, root_expression); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp new file mode 100644 index 00000000..c0a25d93 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp @@ -0,0 +1,194 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/planner/expression_binder/table_function_binder.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "re2/re2.h" + +namespace duckdb { + +string GetColumnsStringValue(ParsedExpression &expr) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto &colref = expr.Cast(); + return colref.GetColumnName(); + } else { + return expr.ToString(); + } +} + +bool Binder::FindStarExpression(unique_ptr &expr, StarExpression **star, bool is_root, + bool in_columns) { + bool has_star = false; + if (expr->GetExpressionClass() == ExpressionClass::STAR) { + auto ¤t_star = expr->Cast(); + if (!current_star.columns) { + if (is_root) { + *star = ¤t_star; + return true; + } + if (!in_columns) { + throw BinderException( + "STAR expression is only allowed as the root element of an expression. Use COLUMNS(*) instead."); + } + // star expression inside a COLUMNS - convert to a constant list + if (!current_star.replace_list.empty()) { + throw BinderException( + "STAR expression with REPLACE list is only allowed as the root element of COLUMNS"); + } + vector> star_list; + bind_context.GenerateAllColumnExpressions(current_star, star_list); + + vector values; + values.reserve(star_list.size()); + for (auto &expr : star_list) { + values.emplace_back(GetColumnsStringValue(*expr)); + } + D_ASSERT(!values.empty()); + + expr = make_uniq(Value::LIST(LogicalType::VARCHAR, values)); + return true; + } + if (in_columns) { + throw BinderException("COLUMNS expression is not allowed inside another COLUMNS expression"); + } + in_columns = true; + if (*star) { + // we can have multiple + if (!(*star)->Equals(current_star)) { + throw BinderException( + FormatError(*expr, "Multiple different STAR/COLUMNS in the same expression are not supported")); + } + return true; + } + *star = ¤t_star; + has_star = true; + } + ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child_expr) { + if (FindStarExpression(child_expr, star, false, in_columns)) { + has_star = true; + } + }); + return has_star; +} + +void Binder::ReplaceStarExpression(unique_ptr &expr, unique_ptr &replacement) { + D_ASSERT(expr); + if (expr->GetExpressionClass() == ExpressionClass::STAR) { + D_ASSERT(replacement); + expr = replacement->Copy(); + return; + } + ParsedExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child_expr) { ReplaceStarExpression(child_expr, replacement); }); +} + +void Binder::ExpandStarExpression(unique_ptr expr, + vector> &new_select_list) { + StarExpression *star = nullptr; + if (!FindStarExpression(expr, &star, true, false)) { + // no star expression: add it as-is + D_ASSERT(!star); + new_select_list.push_back(std::move(expr)); + return; + } + D_ASSERT(star); + vector> star_list; + // we have star expressions! expand the list of star expressions + bind_context.GenerateAllColumnExpressions(*star, star_list); + + if (star->expr) { + // COLUMNS with an expression + // two options: + // VARCHAR parameter <- this is a regular expression + // LIST of VARCHAR parameters <- this is a set of columns + TableFunctionBinder binder(*this, context); + auto child = star->expr->Copy(); + auto result = binder.Bind(child); + if (!result->IsFoldable()) { + // cannot resolve parameters here + if (star->expr->HasParameter()) { + throw ParameterNotResolvedException(); + } else { + throw BinderException("Unsupported expression in COLUMNS"); + } + } + auto val = ExpressionExecutor::EvaluateScalar(context, *result); + if (val.type().id() == LogicalTypeId::VARCHAR) { + // regex + if (val.IsNull()) { + throw BinderException("COLUMNS does not support NULL as regex argument"); + } + auto ®ex_str = StringValue::Get(val); + duckdb_re2::RE2 regex(regex_str); + if (!regex.error().empty()) { + auto err = StringUtil::Format("Failed to compile regex \"%s\": %s", regex_str, regex.error()); + throw BinderException(FormatError(*star, err)); + } + vector> new_list; + for (idx_t i = 0; i < star_list.size(); i++) { + auto &colref = star_list[i]->Cast(); + if (!RE2::PartialMatch(colref.GetColumnName(), regex)) { + continue; + } + new_list.push_back(std::move(star_list[i])); + } + if (new_list.empty()) { + auto err = StringUtil::Format("No matching columns found that match regex \"%s\"", regex_str); + throw BinderException(FormatError(*star, err)); + } + star_list = std::move(new_list); + } else if (val.type().id() == LogicalTypeId::LIST && + ListType::GetChildType(val.type()).id() == LogicalTypeId::VARCHAR) { + // list of varchar columns + if (val.IsNull() || ListValue::GetChildren(val).empty()) { + auto err = + StringUtil::Format("Star expression \"%s\" resulted in an empty set of columns", star->ToString()); + throw BinderException(FormatError(*star, err)); + } + auto &children = ListValue::GetChildren(val); + vector> new_list; + // scan the list for all selected columns and construct a lookup table + case_insensitive_map_t selected_set; + for (auto &child : children) { + selected_set.insert(make_pair(StringValue::Get(child), false)); + } + // now check the list of all possible expressions and select which ones make it in + for (auto &expr : star_list) { + auto str = GetColumnsStringValue(*expr); + auto entry = selected_set.find(str); + if (entry != selected_set.end()) { + new_list.push_back(std::move(expr)); + entry->second = true; + } + } + // check if all expressions found a match + for (auto &entry : selected_set) { + if (!entry.second) { + throw BinderException("Column \"%s\" was selected but was not found in the FROM clause", + entry.first); + } + } + star_list = std::move(new_list); + } else { + throw BinderException(FormatError( + *star, "COLUMNS expects either a VARCHAR argument (regex) or a LIST of VARCHAR (list of columns)")); + } + } + + // now perform the replacement + for (idx_t i = 0; i < star_list.size(); i++) { + auto new_expr = expr->Copy(); + ReplaceStarExpression(new_expr, star_list[i]); + new_select_list.push_back(std::move(new_expr)); + } +} + +void Binder::ExpandStarExpressions(vector> &select_list, + vector> &new_select_list) { + for (auto &select_element : select_list) { + ExpandStarExpression(std::move(select_element), new_select_list); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp new file mode 100644 index 00000000..c8802e5e --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp @@ -0,0 +1,106 @@ +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +class BoundSubqueryNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::BOUND_SUBQUERY_NODE; + +public: + BoundSubqueryNode(shared_ptr subquery_binder, unique_ptr bound_node, + unique_ptr subquery) + : QueryNode(QueryNodeType::BOUND_SUBQUERY_NODE), subquery_binder(std::move(subquery_binder)), + bound_node(std::move(bound_node)), subquery(std::move(subquery)) { + } + + shared_ptr subquery_binder; + unique_ptr bound_node; + unique_ptr subquery; + + const vector> &GetSelectList() const override { + throw InternalException("Cannot get select list of bound subquery node"); + } + + string ToString() const override { + throw InternalException("Cannot ToString bound subquery node"); + } + unique_ptr Copy() const override { + throw InternalException("Cannot copy bound subquery node"); + } + + void Serialize(Serializer &serializer) const override { + throw InternalException("Cannot serialize bound subquery node"); + } +}; + +BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t depth) { + if (expr.subquery->node->type != QueryNodeType::BOUND_SUBQUERY_NODE) { + D_ASSERT(depth == 0); + // first bind the actual subquery in a new binder + auto subquery_binder = Binder::CreateBinder(context, &binder); + subquery_binder->can_contain_nulls = true; + auto bound_node = subquery_binder->BindNode(*expr.subquery->node); + // check the correlated columns of the subquery for correlated columns with depth > 1 + for (idx_t i = 0; i < subquery_binder->correlated_columns.size(); i++) { + CorrelatedColumnInfo corr = subquery_binder->correlated_columns[i]; + if (corr.depth > 1) { + // depth > 1, the column references the query ABOVE the current one + // add to the set of correlated columns for THIS query + corr.depth -= 1; + binder.AddCorrelatedColumn(corr); + } + } + if (expr.subquery_type != SubqueryType::EXISTS && bound_node->types.size() > 1) { + throw BinderException(binder.FormatError( + expr, StringUtil::Format("Subquery returns %zu columns - expected 1", bound_node->types.size()))); + } + auto prior_subquery = std::move(expr.subquery); + expr.subquery = make_uniq(); + expr.subquery->node = + make_uniq(std::move(subquery_binder), std::move(bound_node), std::move(prior_subquery)); + } + // now bind the child node of the subquery + if (expr.child) { + // first bind the children of the subquery, if any + string error = Bind(expr.child, depth); + if (!error.empty()) { + return BindResult(error); + } + } + // both binding the child and binding the subquery was successful + D_ASSERT(expr.subquery->node->type == QueryNodeType::BOUND_SUBQUERY_NODE); + auto &bound_subquery = expr.subquery->node->Cast(); + auto subquery_binder = std::move(bound_subquery.subquery_binder); + auto bound_node = std::move(bound_subquery.bound_node); + LogicalType return_type = + expr.subquery_type == SubqueryType::SCALAR ? bound_node->types[0] : LogicalType(LogicalTypeId::BOOLEAN); + if (return_type.id() == LogicalTypeId::UNKNOWN) { + return_type = LogicalType::SQLNULL; + } + + auto result = make_uniq(return_type); + if (expr.subquery_type == SubqueryType::ANY) { + // ANY comparison + // cast child and subquery child to equivalent types + D_ASSERT(bound_node->types.size() == 1); + auto &child = BoundExpression::GetExpression(*expr.child); + auto compare_type = LogicalType::MaxLogicalType(child->return_type, bound_node->types[0]); + child = BoundCastExpression::AddCastToType(context, std::move(child), compare_type); + result->child_type = bound_node->types[0]; + result->child_target = compare_type; + result->child = std::move(child); + } + result->binder = std::move(subquery_binder); + result->subquery = std::move(bound_node); + result->subquery_type = expr.subquery_type; + result->comparison_type = expr.comparison_type; + + return BindResult(std::move(result)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp new file mode 100644 index 00000000..b2ded1c9 --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp @@ -0,0 +1,205 @@ +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder/aggregate_binder.hpp" +#include "duckdb/planner/expression_binder/select_binder.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +unique_ptr CreateBoundStructExtract(ClientContext &context, unique_ptr expr, string key) { + vector> arguments; + arguments.push_back(std::move(expr)); + arguments.push_back(make_uniq(Value(key))); + auto extract_function = StructExtractFun::GetFunction(); + auto bind_info = extract_function.bind(context, extract_function, arguments); + auto return_type = extract_function.return_type; + auto result = make_uniq(return_type, std::move(extract_function), std::move(arguments), + std::move(bind_info)); + result->alias = std::move(key); + return std::move(result); +} + +BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, bool root_expression) { + // bind the children of the function expression + if (depth > 0) { + return BindResult(binder.FormatError(function, "UNNEST() for correlated expressions is not supported yet")); + } + string error; + if (function.children.empty()) { + return BindResult(binder.FormatError(function, "UNNEST() requires a single argument")); + } + idx_t max_depth = 1; + if (function.children.size() != 1) { + bool has_parameter = false; + bool supported_argument = false; + for (idx_t i = 1; i < function.children.size(); i++) { + if (has_parameter) { + return BindResult(binder.FormatError(function, "UNNEST() only supports a single additional argument")); + } + if (function.children[i]->HasParameter()) { + throw ParameterNotAllowedException("Parameter not allowed in unnest parameter"); + } + if (!function.children[i]->IsScalar()) { + break; + } + auto alias = function.children[i]->alias; + BindChild(function.children[i], depth, error); + if (!error.empty()) { + return BindResult(error); + } + auto &const_child = BoundExpression::GetExpression(*function.children[i]); + auto value = ExpressionExecutor::EvaluateScalar(context, *const_child, true); + if (alias == "recursive") { + auto recursive = value.GetValue(); + if (recursive) { + max_depth = NumericLimits::Maximum(); + } + } else if (alias == "max_depth") { + max_depth = value.GetValue(); + if (max_depth == 0) { + throw BinderException("UNNEST cannot have a max depth of 0"); + } + } else if (!alias.empty()) { + throw BinderException("Unsupported parameter \"%s\" for unnest", alias); + } else { + break; + } + has_parameter = true; + supported_argument = true; + } + if (!supported_argument) { + return BindResult(binder.FormatError(function, "UNNEST - unsupported extra argument, unnest only supports " + "recursive := [true/false] or max_depth := #")); + } + } + unnest_level++; + BindChild(function.children[0], depth, error); + if (!error.empty()) { + // failed to bind + // try to bind correlated columns manually + if (!BindCorrelatedColumns(function.children[0])) { + return BindResult(error); + } + auto &bound_expr = BoundExpression::GetExpression(*function.children[0]); + ExtractCorrelatedExpressions(binder, *bound_expr); + } + auto &child = BoundExpression::GetExpression(*function.children[0]); + auto &child_type = child->return_type; + unnest_level--; + + if (unnest_level > 0) { + throw BinderException( + "Nested UNNEST calls are not supported - use UNNEST(x, recursive := true) to unnest multiple levels"); + } + + switch (child_type.id()) { + case LogicalTypeId::UNKNOWN: + throw ParameterNotResolvedException(); + case LogicalTypeId::LIST: + case LogicalTypeId::STRUCT: + case LogicalTypeId::SQLNULL: + break; + default: + return BindResult(binder.FormatError(function, "UNNEST() can only be applied to lists, structs and NULL")); + } + + idx_t list_unnests; + idx_t struct_unnests = 0; + + auto unnest_expr = std::move(child); + if (child_type.id() == LogicalTypeId::SQLNULL) { + list_unnests = 1; + } else { + // first do all of the list unnests + auto type = child_type; + list_unnests = 0; + while (type.id() == LogicalTypeId::LIST) { + type = ListType::GetChildType(type); + list_unnests++; + if (list_unnests >= max_depth) { + break; + } + } + // unnest structs all the way afterwards, if there are any + if (type.id() == LogicalTypeId::STRUCT) { + struct_unnests = max_depth - list_unnests; + } + } + if (struct_unnests > 0 && !root_expression) { + return BindResult(binder.FormatError( + function, "UNNEST() on a struct column can only be applied as the root element of a SELECT expression")); + } + // perform all of the list unnests first + auto return_type = child_type; + for (idx_t current_depth = 0; current_depth < list_unnests; current_depth++) { + if (return_type.id() == LogicalTypeId::LIST) { + return_type = ListType::GetChildType(return_type); + } + auto result = make_uniq(return_type); + result->child = std::move(unnest_expr); + auto alias = function.alias.empty() ? result->ToString() : function.alias; + + auto current_level = unnest_level + list_unnests - current_depth - 1; + auto entry = node.unnests.find(current_level); + idx_t unnest_table_index; + idx_t unnest_column_index; + if (entry == node.unnests.end()) { + BoundUnnestNode unnest_node; + unnest_node.index = binder.GenerateTableIndex(); + unnest_node.expressions.push_back(std::move(result)); + unnest_table_index = unnest_node.index; + unnest_column_index = 0; + node.unnests.insert(make_pair(current_level, std::move(unnest_node))); + } else { + unnest_table_index = entry->second.index; + unnest_column_index = entry->second.expressions.size(); + entry->second.expressions.push_back(std::move(result)); + } + // now create a column reference referring to the unnest + unnest_expr = make_uniq( + std::move(alias), return_type, ColumnBinding(unnest_table_index, unnest_column_index), depth); + } + // now perform struct unnests, if any + if (struct_unnests > 0) { + vector> struct_expressions; + struct_expressions.push_back(std::move(unnest_expr)); + + for (idx_t i = 0; i < struct_unnests; i++) { + vector> new_expressions; + // check if there are any structs left + bool has_structs = false; + for (auto &expr : struct_expressions) { + if (expr->return_type.id() == LogicalTypeId::STRUCT) { + // struct! push a struct_extract + auto &child_types = StructType::GetChildTypes(expr->return_type); + for (auto &entry : child_types) { + new_expressions.push_back(CreateBoundStructExtract(context, expr->Copy(), entry.first)); + } + has_structs = true; + } else { + // not a struct - push as-is + new_expressions.push_back(std::move(expr)); + } + } + struct_expressions = std::move(new_expressions); + if (!has_structs) { + break; + } + } + expanded_expressions = std::move(struct_expressions); + unnest_expr = make_uniq(Value(42)); + } + return BindResult(std::move(unnest_expr)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp new file mode 100644 index 00000000..644c96ed --- /dev/null +++ b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp @@ -0,0 +1,294 @@ +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/expression_binder/select_binder.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/function_binder.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" + +namespace duckdb { + +static LogicalType ResolveWindowExpressionType(ExpressionType window_type, const vector &child_types) { + + idx_t param_count; + switch (window_type) { + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: + case ExpressionType::WINDOW_ROW_NUMBER: + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_CUME_DIST: + param_count = 0; + break; + case ExpressionType::WINDOW_NTILE: + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_LAST_VALUE: + case ExpressionType::WINDOW_LEAD: + case ExpressionType::WINDOW_LAG: + param_count = 1; + break; + case ExpressionType::WINDOW_NTH_VALUE: + param_count = 2; + break; + default: + throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(window_type)); + } + if (child_types.size() != param_count) { + throw BinderException("%s needs %d parameter%s, got %d", ExpressionTypeToString(window_type), param_count, + param_count == 1 ? "" : "s", child_types.size()); + } + switch (window_type) { + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_CUME_DIST: + return LogicalType(LogicalTypeId::DOUBLE); + case ExpressionType::WINDOW_ROW_NUMBER: + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: + case ExpressionType::WINDOW_NTILE: + return LogicalType::BIGINT; + case ExpressionType::WINDOW_NTH_VALUE: + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_LAST_VALUE: + case ExpressionType::WINDOW_LEAD: + case ExpressionType::WINDOW_LAG: + return child_types[0]; + default: + throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(window_type)); + } +} + +static unique_ptr GetExpression(unique_ptr &expr) { + if (!expr) { + return nullptr; + } + D_ASSERT(expr.get()); + D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); + return std::move(BoundExpression::GetExpression(*expr)); +} + +static unique_ptr CastWindowExpression(unique_ptr &expr, const LogicalType &type) { + if (!expr) { + return nullptr; + } + D_ASSERT(expr.get()); + D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); + + auto &bound = BoundExpression::GetExpression(*expr); + bound = BoundCastExpression::AddDefaultCastToType(std::move(bound), type); + + return std::move(bound); +} + +static LogicalType BindRangeExpression(ClientContext &context, const string &name, unique_ptr &expr, + unique_ptr &order_expr) { + + vector> children; + + D_ASSERT(order_expr.get()); + D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); + auto &bound_order = BoundExpression::GetExpression(*order_expr); + children.emplace_back(bound_order->Copy()); + + D_ASSERT(expr.get()); + D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); + auto &bound = BoundExpression::GetExpression(*expr); + children.emplace_back(std::move(bound)); + + string error; + FunctionBinder function_binder(context); + auto function = function_binder.BindScalarFunction(DEFAULT_SCHEMA, name, std::move(children), error, true); + if (!function) { + throw BinderException(error); + } + bound = std::move(function); + return bound->return_type; +} + +BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) { + auto name = window.GetName(); + + QueryErrorContext error_context(binder.GetRootStatement(), window.query_location); + if (inside_window) { + throw BinderException(error_context.FormatError("window function calls cannot be nested")); + } + if (depth > 0) { + throw BinderException(error_context.FormatError("correlated columns in window functions not supported")); + } + // If we have range expressions, then only one order by clause is allowed. + if ((window.start == WindowBoundary::EXPR_PRECEDING_RANGE || window.start == WindowBoundary::EXPR_FOLLOWING_RANGE || + window.end == WindowBoundary::EXPR_PRECEDING_RANGE || window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) && + window.orders.size() != 1) { + throw BinderException(error_context.FormatError("RANGE frames must have only one ORDER BY expression")); + } + // bind inside the children of the window function + // we set the inside_window flag to true to prevent binding nested window functions + this->inside_window = true; + string error; + for (auto &child : window.children) { + BindChild(child, depth, error); + } + for (auto &child : window.partitions) { + BindChild(child, depth, error); + } + for (auto &order : window.orders) { + BindChild(order.expression, depth, error); + } + BindChild(window.filter_expr, depth, error); + BindChild(window.start_expr, depth, error); + BindChild(window.end_expr, depth, error); + BindChild(window.offset_expr, depth, error); + BindChild(window.default_expr, depth, error); + + this->inside_window = false; + if (!error.empty()) { + // failed to bind children of window function + return BindResult(error); + } + // successfully bound all children: create bound window function + vector types; + vector> children; + for (auto &child : window.children) { + D_ASSERT(child.get()); + D_ASSERT(child->expression_class == ExpressionClass::BOUND_EXPRESSION); + auto &bound = BoundExpression::GetExpression(*child); + // Add casts for positional arguments + const auto argno = children.size(); + switch (window.type) { + case ExpressionType::WINDOW_NTILE: + // ntile(bigint) + if (argno == 0) { + bound = BoundCastExpression::AddCastToType(context, std::move(bound), LogicalType::BIGINT); + } + break; + case ExpressionType::WINDOW_NTH_VALUE: + // nth_value(, index) + if (argno == 1) { + bound = BoundCastExpression::AddCastToType(context, std::move(bound), LogicalType::BIGINT); + } + default: + break; + } + types.push_back(bound->return_type); + children.push_back(std::move(bound)); + } + // Determine the function type. + LogicalType sql_type; + unique_ptr aggregate; + unique_ptr bind_info; + if (window.type == ExpressionType::WINDOW_AGGREGATE) { + // Look up the aggregate function in the catalog + auto &func = Catalog::GetEntry(context, window.catalog, window.schema, + window.function_name, error_context); + D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); + + // bind the aggregate + string error; + FunctionBinder function_binder(context); + auto best_function = function_binder.BindFunction(func.name, func.functions, types, error); + if (best_function == DConstants::INVALID_INDEX) { + throw BinderException(binder.FormatError(window, error)); + } + // found a matching function! bind it as an aggregate + auto bound_function = func.functions.GetFunctionByOffset(best_function); + auto bound_aggregate = function_binder.BindAggregateFunction(bound_function, std::move(children)); + // create the aggregate + aggregate = make_uniq(bound_aggregate->function); + bind_info = std::move(bound_aggregate->bind_info); + children = std::move(bound_aggregate->children); + sql_type = bound_aggregate->return_type; + } else { + // fetch the child of the non-aggregate window function (if any) + sql_type = ResolveWindowExpressionType(window.type, types); + } + auto result = make_uniq(window.type, sql_type, std::move(aggregate), std::move(bind_info)); + result->children = std::move(children); + for (auto &child : window.partitions) { + result->partitions.push_back(GetExpression(child)); + } + result->ignore_nulls = window.ignore_nulls; + + // Convert RANGE boundary expressions to ORDER +/- expressions. + // Note that PRECEEDING and FOLLOWING refer to the sequential order in the frame, + // not the natural ordering of the type. This means that the offset arithmetic must be reversed + // for ORDER BY DESC. + auto &config = DBConfig::GetConfig(context); + auto range_sense = OrderType::INVALID; + LogicalType start_type = LogicalType::BIGINT; + if (window.start == WindowBoundary::EXPR_PRECEDING_RANGE) { + D_ASSERT(window.orders.size() == 1); + range_sense = config.ResolveOrder(window.orders[0].type); + const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+"; + start_type = BindRangeExpression(context, name, window.start_expr, window.orders[0].expression); + } else if (window.start == WindowBoundary::EXPR_FOLLOWING_RANGE) { + D_ASSERT(window.orders.size() == 1); + range_sense = config.ResolveOrder(window.orders[0].type); + const auto name = (range_sense == OrderType::ASCENDING) ? "+" : "-"; + start_type = BindRangeExpression(context, name, window.start_expr, window.orders[0].expression); + } + + LogicalType end_type = LogicalType::BIGINT; + if (window.end == WindowBoundary::EXPR_PRECEDING_RANGE) { + D_ASSERT(window.orders.size() == 1); + range_sense = config.ResolveOrder(window.orders[0].type); + const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+"; + end_type = BindRangeExpression(context, name, window.end_expr, window.orders[0].expression); + } else if (window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) { + D_ASSERT(window.orders.size() == 1); + range_sense = config.ResolveOrder(window.orders[0].type); + const auto name = (range_sense == OrderType::ASCENDING) ? "+" : "-"; + end_type = BindRangeExpression(context, name, window.end_expr, window.orders[0].expression); + } + + // Cast ORDER and boundary expressions to the same type + if (range_sense != OrderType::INVALID) { + D_ASSERT(window.orders.size() == 1); + + auto &order_expr = window.orders[0].expression; + D_ASSERT(order_expr.get()); + D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); + auto &bound_order = BoundExpression::GetExpression(*order_expr); + auto order_type = bound_order->return_type; + if (window.start_expr) { + order_type = LogicalType::MaxLogicalType(order_type, start_type); + } + if (window.end_expr) { + order_type = LogicalType::MaxLogicalType(order_type, end_type); + } + + // Cast all three to match + bound_order = BoundCastExpression::AddCastToType(context, std::move(bound_order), order_type); + start_type = end_type = order_type; + } + + for (auto &order : window.orders) { + auto type = config.ResolveOrder(order.type); + auto null_order = config.ResolveNullOrder(type, order.null_order); + auto expression = GetExpression(order.expression); + result->orders.emplace_back(type, null_order, std::move(expression)); + } + + result->filter_expr = CastWindowExpression(window.filter_expr, LogicalType::BOOLEAN); + + result->start_expr = CastWindowExpression(window.start_expr, start_type); + result->end_expr = CastWindowExpression(window.end_expr, end_type); + result->offset_expr = CastWindowExpression(window.offset_expr, LogicalType::BIGINT); + result->default_expr = CastWindowExpression(window.default_expr, result->return_type); + result->start = window.start; + result->end = window.end; + + // create a BoundColumnRef that references this entry + auto colref = make_uniq(std::move(name), result->return_type, + ColumnBinding(node.window_index, node.windows.size()), depth); + // move the WINDOW expression into the set of bound windows + node.windows.push_back(std::move(result)); + return BindResult(std::move(colref)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp new file mode 100644 index 00000000..42f14b96 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -0,0 +1,64 @@ +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/query_node/bound_cte_node.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" + +namespace duckdb { + +unique_ptr Binder::BindNode(CTENode &statement) { + auto result = make_uniq(); + + // first recursively visit the materialized CTE operations + // the left side is visited first and is added to the BindContext of the right side + D_ASSERT(statement.query); + D_ASSERT(statement.child); + + result->ctename = statement.ctename; + result->setop_index = GenerateTableIndex(); + + result->query_binder = Binder::CreateBinder(context, this); + result->query = result->query_binder->BindNode(*statement.query); + + // the result types of the CTE are the types of the LHS + result->types = result->query->types; + // names are picked from the LHS, unless aliases are explicitly specified + result->names = result->query->names; + for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { + result->names[i] = statement.aliases[i]; + } + + // This allows the right side to reference the CTE + bind_context.AddGenericBinding(result->setop_index, statement.ctename, result->names, result->types); + + result->child_binder = Binder::CreateBinder(context, this); + + // Move all modifiers to the child node. + for (auto &modifier : statement.modifiers) { + statement.child->modifiers.push_back(std::move(modifier)); + } + + statement.modifiers.clear(); + + // Add bindings of left side to temporary CTE bindings context + result->child_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, result->names, + result->types); + result->child = result->child_binder->BindNode(*statement.child); + + // the result types of the CTE are the types of the LHS + result->types = result->child->types; + // names are picked from the LHS, unless aliases are explicitly specified + result->names = result->child->names; + for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { + result->names[i] = statement.aliases[i]; + } + + MoveCorrelatedExpressions(*result->query_binder); + MoveCorrelatedExpressions(*result->child_binder); + + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp new file mode 100644 index 00000000..cf9ebfac --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -0,0 +1,61 @@ +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/recursive_cte_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" + +namespace duckdb { + +unique_ptr Binder::BindNode(RecursiveCTENode &statement) { + auto result = make_uniq(); + + // first recursively visit the recursive CTE operations + // the left side is visited first and is added to the BindContext of the right side + D_ASSERT(statement.left); + D_ASSERT(statement.right); + + result->ctename = statement.ctename; + result->union_all = statement.union_all; + result->setop_index = GenerateTableIndex(); + + result->left_binder = Binder::CreateBinder(context, this); + result->left = result->left_binder->BindNode(*statement.left); + + // the result types of the CTE are the types of the LHS + result->types = result->left->types; + // names are picked from the LHS, unless aliases are explicitly specified + result->names = result->left->names; + for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { + result->names[i] = statement.aliases[i]; + } + + // This allows the right side to reference the CTE recursively + bind_context.AddGenericBinding(result->setop_index, statement.ctename, result->names, result->types); + + result->right_binder = Binder::CreateBinder(context, this); + + // Add bindings of left side to temporary CTE bindings context + result->right_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, result->names, + result->types); + result->right = result->right_binder->BindNode(*statement.right); + + // move the correlated expressions from the child binders to this binder + MoveCorrelatedExpressions(*result->left_binder); + MoveCorrelatedExpressions(*result->right_binder); + + // now both sides have been bound we can resolve types + if (result->left->types.size() != result->right->types.size()) { + throw BinderException("Set operations can only apply to expressions with the " + "same number of result columns"); + } + + if (!statement.modifiers.empty()) { + throw NotImplementedException("FIXME: bind modifiers in recursive CTE"); + } + + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp new file mode 100644 index 00000000..6a052a30 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -0,0 +1,555 @@ +#include "duckdb/common/limits.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression_binder/column_alias_binder.hpp" +#include "duckdb/planner/expression_binder/constant_binder.hpp" +#include "duckdb/planner/expression_binder/group_binder.hpp" +#include "duckdb/planner/expression_binder/having_binder.hpp" +#include "duckdb/planner/expression_binder/order_binder.hpp" +#include "duckdb/planner/expression_binder/qualify_binder.hpp" +#include "duckdb/planner/expression_binder/select_binder.hpp" +#include "duckdb/planner/expression_binder/where_binder.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" + +namespace duckdb { + +unique_ptr Binder::BindOrderExpression(OrderBinder &order_binder, unique_ptr expr) { + // we treat the Distinct list as a order by + auto bound_expr = order_binder.Bind(std::move(expr)); + if (!bound_expr) { + // DISTINCT ON non-integer constant + // remove the expression from the DISTINCT ON list + return nullptr; + } + D_ASSERT(bound_expr->type == ExpressionType::BOUND_COLUMN_REF); + return bound_expr; +} + +unique_ptr Binder::BindDelimiter(ClientContext &context, OrderBinder &order_binder, + unique_ptr delimiter, const LogicalType &type, + Value &delimiter_value) { + auto new_binder = Binder::CreateBinder(context, this, true); + if (delimiter->HasSubquery()) { + if (!order_binder.HasExtraList()) { + throw BinderException("Subquery in LIMIT/OFFSET not supported in set operation"); + } + return order_binder.CreateExtraReference(std::move(delimiter)); + } + ExpressionBinder expr_binder(*new_binder, context); + expr_binder.target_type = type; + auto expr = expr_binder.Bind(delimiter); + if (expr->IsFoldable()) { + //! this is a constant + delimiter_value = ExpressionExecutor::EvaluateScalar(context, *expr).CastAs(context, type); + return nullptr; + } + if (!new_binder->correlated_columns.empty()) { + throw BinderException("Correlated columns not supported in LIMIT/OFFSET"); + } + // move any correlated columns to this binder + MoveCorrelatedExpressions(*new_binder); + return expr; +} + +duckdb::unique_ptr Binder::BindLimit(OrderBinder &order_binder, LimitModifier &limit_mod) { + auto result = make_uniq(); + if (limit_mod.limit) { + Value val; + result->limit = BindDelimiter(context, order_binder, std::move(limit_mod.limit), LogicalType::BIGINT, val); + if (!result->limit) { + result->limit_val = val.IsNull() ? NumericLimits::Maximum() : val.GetValue(); + if (result->limit_val < 0) { + throw BinderException("LIMIT cannot be negative"); + } + } + } + if (limit_mod.offset) { + Value val; + result->offset = BindDelimiter(context, order_binder, std::move(limit_mod.offset), LogicalType::BIGINT, val); + if (!result->offset) { + result->offset_val = val.IsNull() ? 0 : val.GetValue(); + if (result->offset_val < 0) { + throw BinderException("OFFSET cannot be negative"); + } + } + } + return std::move(result); +} + +unique_ptr Binder::BindLimitPercent(OrderBinder &order_binder, LimitPercentModifier &limit_mod) { + auto result = make_uniq(); + if (limit_mod.limit) { + Value val; + result->limit = BindDelimiter(context, order_binder, std::move(limit_mod.limit), LogicalType::DOUBLE, val); + if (!result->limit) { + result->limit_percent = val.IsNull() ? 100 : val.GetValue(); + if (result->limit_percent < 0.0) { + throw Exception("Limit percentage can't be negative value"); + } + } + } + if (limit_mod.offset) { + Value val; + result->offset = BindDelimiter(context, order_binder, std::move(limit_mod.offset), LogicalType::BIGINT, val); + if (!result->offset) { + result->offset_val = val.IsNull() ? 0 : val.GetValue(); + } + } + return std::move(result); +} + +void Binder::BindModifiers(OrderBinder &order_binder, QueryNode &statement, BoundQueryNode &result) { + for (auto &mod : statement.modifiers) { + unique_ptr bound_modifier; + switch (mod->type) { + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct = mod->Cast(); + auto bound_distinct = make_uniq(); + bound_distinct->distinct_type = + distinct.distinct_on_targets.empty() ? DistinctType::DISTINCT : DistinctType::DISTINCT_ON; + if (distinct.distinct_on_targets.empty()) { + for (idx_t i = 0; i < result.names.size(); i++) { + distinct.distinct_on_targets.push_back(make_uniq(Value::INTEGER(1 + i))); + } + } + for (auto &distinct_on_target : distinct.distinct_on_targets) { + auto expr = BindOrderExpression(order_binder, std::move(distinct_on_target)); + if (!expr) { + continue; + } + bound_distinct->target_distincts.push_back(std::move(expr)); + } + bound_modifier = std::move(bound_distinct); + break; + } + case ResultModifierType::ORDER_MODIFIER: { + auto &order = mod->Cast(); + auto bound_order = make_uniq(); + auto &config = DBConfig::GetConfig(context); + D_ASSERT(!order.orders.empty()); + auto &order_binders = order_binder.GetBinders(); + if (order.orders.size() == 1 && order.orders[0].expression->type == ExpressionType::STAR) { + auto &star = order.orders[0].expression->Cast(); + if (star.exclude_list.empty() && star.replace_list.empty() && !star.expr) { + // ORDER BY ALL + // replace the order list with the all elements in the SELECT list + auto order_type = order.orders[0].type; + auto null_order = order.orders[0].null_order; + + vector new_orders; + for (idx_t i = 0; i < order_binder.MaxCount(); i++) { + new_orders.emplace_back(order_type, null_order, + make_uniq(Value::INTEGER(i + 1))); + } + order.orders = std::move(new_orders); + } + } + for (auto &order_node : order.orders) { + vector> order_list; + order_binders[0]->ExpandStarExpression(std::move(order_node.expression), order_list); + + auto type = config.ResolveOrder(order_node.type); + auto null_order = config.ResolveNullOrder(type, order_node.null_order); + for (auto &order_expr : order_list) { + auto bound_expr = BindOrderExpression(order_binder, std::move(order_expr)); + if (!bound_expr) { + continue; + } + bound_order->orders.emplace_back(type, null_order, std::move(bound_expr)); + } + } + if (!bound_order->orders.empty()) { + bound_modifier = std::move(bound_order); + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: + bound_modifier = BindLimit(order_binder, mod->Cast()); + break; + case ResultModifierType::LIMIT_PERCENT_MODIFIER: + bound_modifier = BindLimitPercent(order_binder, mod->Cast()); + break; + default: + throw Exception("Unsupported result modifier"); + } + if (bound_modifier) { + result.modifiers.push_back(std::move(bound_modifier)); + } + } +} + +static void AssignReturnType(unique_ptr &expr, const vector &sql_types) { + if (!expr) { + return; + } + if (expr->type != ExpressionType::BOUND_COLUMN_REF) { + return; + } + auto &bound_colref = expr->Cast(); + bound_colref.return_type = sql_types[bound_colref.binding.column_index]; +} + +void Binder::BindModifierTypes(BoundQueryNode &result, const vector &sql_types, idx_t projection_index) { + for (auto &bound_mod : result.modifiers) { + switch (bound_mod->type) { + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct = bound_mod->Cast(); + D_ASSERT(!distinct.target_distincts.empty()); + // set types of distinct targets + for (auto &expr : distinct.target_distincts) { + D_ASSERT(expr->type == ExpressionType::BOUND_COLUMN_REF); + auto &bound_colref = expr->Cast(); + if (bound_colref.binding.column_index == DConstants::INVALID_INDEX) { + throw BinderException("Ambiguous name in DISTINCT ON!"); + } + D_ASSERT(bound_colref.binding.column_index < sql_types.size()); + bound_colref.return_type = sql_types[bound_colref.binding.column_index]; + } + for (auto &target_distinct : distinct.target_distincts) { + auto &bound_colref = target_distinct->Cast(); + const auto &sql_type = sql_types[bound_colref.binding.column_index]; + if (sql_type.id() == LogicalTypeId::VARCHAR) { + target_distinct = ExpressionBinder::PushCollation(context, std::move(target_distinct), + StringType::GetCollation(sql_type), true); + } + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit = bound_mod->Cast(); + AssignReturnType(limit.limit, sql_types); + AssignReturnType(limit.offset, sql_types); + break; + } + case ResultModifierType::LIMIT_PERCENT_MODIFIER: { + auto &limit = bound_mod->Cast(); + AssignReturnType(limit.limit, sql_types); + AssignReturnType(limit.offset, sql_types); + break; + } + case ResultModifierType::ORDER_MODIFIER: { + auto &order = bound_mod->Cast(); + for (auto &order_node : order.orders) { + auto &expr = order_node.expression; + D_ASSERT(expr->type == ExpressionType::BOUND_COLUMN_REF); + auto &bound_colref = expr->Cast(); + if (bound_colref.binding.column_index == DConstants::INVALID_INDEX) { + throw BinderException("Ambiguous name in ORDER BY!"); + } + D_ASSERT(bound_colref.binding.column_index < sql_types.size()); + const auto &sql_type = sql_types[bound_colref.binding.column_index]; + bound_colref.return_type = sql_types[bound_colref.binding.column_index]; + if (sql_type.id() == LogicalTypeId::VARCHAR) { + order_node.expression = ExpressionBinder::PushCollation(context, std::move(order_node.expression), + StringType::GetCollation(sql_type)); + } + } + break; + } + default: + break; + } + } +} + +unique_ptr Binder::BindNode(SelectNode &statement) { + D_ASSERT(statement.from_table); + // first bind the FROM table statement + auto from = std::move(statement.from_table); + auto from_table = Bind(*from); + return BindSelectNode(statement, std::move(from_table)); +} + +void Binder::BindWhereStarExpression(unique_ptr &expr) { + // expand any expressions in the upper AND recursively + if (expr->type == ExpressionType::CONJUNCTION_AND) { + auto &conj = expr->Cast(); + for (auto &child : conj.children) { + BindWhereStarExpression(child); + } + return; + } + if (expr->type == ExpressionType::STAR) { + auto &star = expr->Cast(); + if (!star.columns) { + throw ParserException("STAR expression is not allowed in the WHERE clause. Use COLUMNS(*) instead."); + } + } + // expand the stars for this expression + vector> new_conditions; + ExpandStarExpression(std::move(expr), new_conditions); + if (new_conditions.empty()) { + throw ParserException("COLUMNS expansion resulted in empty set of columns"); + } + + // set up an AND conjunction between the expanded conditions + expr = std::move(new_conditions[0]); + for (idx_t i = 1; i < new_conditions.size(); i++) { + auto and_conj = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr), + std::move(new_conditions[i])); + expr = std::move(and_conj); + } +} + +unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ptr from_table) { + D_ASSERT(from_table); + D_ASSERT(!statement.from_table); + auto result = make_uniq(); + result->projection_index = GenerateTableIndex(); + result->group_index = GenerateTableIndex(); + result->aggregate_index = GenerateTableIndex(); + result->groupings_index = GenerateTableIndex(); + result->window_index = GenerateTableIndex(); + result->prune_index = GenerateTableIndex(); + + result->from_table = std::move(from_table); + // bind the sample clause + if (statement.sample) { + result->sample_options = std::move(statement.sample); + } + + // visit the select list and expand any "*" statements + vector> new_select_list; + ExpandStarExpressions(statement.select_list, new_select_list); + + if (new_select_list.empty()) { + throw BinderException("SELECT list is empty after resolving * expressions!"); + } + statement.select_list = std::move(new_select_list); + + // create a mapping of (alias -> index) and a mapping of (Expression -> index) for the SELECT list + case_insensitive_map_t alias_map; + parsed_expression_map_t projection_map; + for (idx_t i = 0; i < statement.select_list.size(); i++) { + auto &expr = statement.select_list[i]; + result->names.push_back(expr->GetName()); + ExpressionBinder::QualifyColumnNames(*this, expr); + if (!expr->alias.empty()) { + alias_map[expr->alias] = i; + result->names[i] = expr->alias; + } + projection_map[*expr] = i; + result->original_expressions.push_back(expr->Copy()); + } + result->column_count = statement.select_list.size(); + + // first visit the WHERE clause + // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses + if (statement.where_clause) { + // bind any star expressions in the WHERE clause + BindWhereStarExpression(statement.where_clause); + + ColumnAliasBinder alias_binder(*result, alias_map); + WhereBinder where_binder(*this, context, &alias_binder); + unique_ptr condition = std::move(statement.where_clause); + result->where_clause = where_binder.Bind(condition); + } + + // now bind all the result modifiers; including DISTINCT and ORDER BY targets + OrderBinder order_binder({this}, result->projection_index, statement, alias_map, projection_map); + BindModifiers(order_binder, statement, *result); + + vector> unbound_groups; + BoundGroupInformation info; + auto &group_expressions = statement.groups.group_expressions; + if (!group_expressions.empty()) { + // the statement has a GROUP BY clause, bind it + unbound_groups.resize(group_expressions.size()); + GroupBinder group_binder(*this, context, statement, result->group_index, alias_map, info.alias_map); + for (idx_t i = 0; i < group_expressions.size(); i++) { + + // we keep a copy of the unbound expression; + // we keep the unbound copy around to check for group references in the SELECT and HAVING clause + // the reason we want the unbound copy is because we want to figure out whether an expression + // is a group reference BEFORE binding in the SELECT/HAVING binder + group_binder.unbound_expression = group_expressions[i]->Copy(); + group_binder.bind_index = i; + + // bind the groups + LogicalType group_type; + auto bound_expr = group_binder.Bind(group_expressions[i], &group_type); + D_ASSERT(bound_expr->return_type.id() != LogicalTypeId::INVALID); + + // find out whether the expression contains a subquery, it can't be copied if so + auto &bound_expr_ref = *bound_expr; + bool contains_subquery = bound_expr_ref.HasSubquery(); + + // push a potential collation, if necessary + auto collated_expr = ExpressionBinder::PushCollation(context, std::move(bound_expr), + StringType::GetCollation(group_type), true); + if (!contains_subquery && !collated_expr->Equals(bound_expr_ref)) { + // if there is a collation on a group x, we should group by the collated expr, + // but also push a first(x) aggregate in case x is selected (uncollated) + info.collated_groups[i] = result->aggregates.size(); + + auto first_fun = FirstFun::GetFunction(LogicalType::VARCHAR); + vector> first_children; + // FIXME: would be better to just refer to this expression, but for now we copy + first_children.push_back(bound_expr_ref.Copy()); + + FunctionBinder function_binder(context); + auto function = function_binder.BindAggregateFunction(first_fun, std::move(first_children)); + result->aggregates.push_back(std::move(function)); + } + result->groups.group_expressions.push_back(std::move(collated_expr)); + + // in the unbound expression we DO bind the table names of any ColumnRefs + // we do this to make sure that "table.a" and "a" are treated the same + // if we wouldn't do this then (SELECT test.a FROM test GROUP BY a) would not work because "test.a" <> "a" + // hence we convert "a" -> "test.a" in the unbound expression + unbound_groups[i] = std::move(group_binder.unbound_expression); + ExpressionBinder::QualifyColumnNames(*this, unbound_groups[i]); + info.map[*unbound_groups[i]] = i; + } + } + result->groups.grouping_sets = std::move(statement.groups.grouping_sets); + + // bind the HAVING clause, if any + if (statement.having) { + HavingBinder having_binder(*this, context, *result, info, alias_map, statement.aggregate_handling); + ExpressionBinder::QualifyColumnNames(*this, statement.having); + result->having = having_binder.Bind(statement.having); + } + + // bind the QUALIFY clause, if any + unique_ptr qualify_binder; + if (statement.qualify) { + if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { + throw BinderException("Combining QUALIFY with GROUP BY ALL is not supported yet"); + } + qualify_binder = make_uniq(*this, context, *result, info, alias_map); + ExpressionBinder::QualifyColumnNames(*this, statement.qualify); + result->qualify = qualify_binder->Bind(statement.qualify); + if (qualify_binder->HasBoundColumns() && qualify_binder->BoundAggregates()) { + throw BinderException("Cannot mix aggregates with non-aggregated columns!"); + } + } + + // after that, we bind to the SELECT list + SelectBinder select_binder(*this, context, *result, info, alias_map); + vector internal_sql_types; + vector group_by_all_indexes; + vector new_names; + for (idx_t i = 0; i < statement.select_list.size(); i++) { + bool is_window = statement.select_list[i]->IsWindow(); + idx_t unnest_count = result->unnests.size(); + LogicalType result_type; + auto expr = select_binder.Bind(statement.select_list[i], &result_type, true); + bool is_original_column = i < result->column_count; + bool can_group_by_all = + statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES && is_original_column; + if (select_binder.HasExpandedExpressions()) { + if (!is_original_column) { + throw InternalException("Only original columns can have expanded expressions"); + } + if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { + throw BinderException("UNNEST of struct cannot be combined with GROUP BY ALL"); + } + auto &struct_expressions = select_binder.ExpandedExpressions(); + D_ASSERT(!struct_expressions.empty()); + for (auto &struct_expr : struct_expressions) { + new_names.push_back(struct_expr->GetName()); + result->types.push_back(struct_expr->return_type); + result->select_list.push_back(std::move(struct_expr)); + } + struct_expressions.clear(); + continue; + } + if (can_group_by_all && select_binder.HasBoundColumns()) { + if (select_binder.BoundAggregates()) { + throw BinderException("Cannot mix aggregates with non-aggregated columns!"); + } + if (is_window) { + throw BinderException("Cannot group on a window clause"); + } + if (result->unnests.size() > unnest_count) { + throw BinderException("Cannot group on an UNNEST or UNLIST clause"); + } + // we are forcing aggregates, and the node has columns bound + // this entry becomes a group + group_by_all_indexes.push_back(i); + } + result->select_list.push_back(std::move(expr)); + if (is_original_column) { + new_names.push_back(std::move(result->names[i])); + result->types.push_back(result_type); + } + internal_sql_types.push_back(result_type); + if (can_group_by_all) { + select_binder.ResetBindings(); + } + } + // push the GROUP BY ALL expressions into the group set + for (auto &group_by_all_index : group_by_all_indexes) { + auto &expr = result->select_list[group_by_all_index]; + auto group_ref = make_uniq( + expr->return_type, ColumnBinding(result->group_index, result->groups.group_expressions.size())); + result->groups.group_expressions.push_back(std::move(expr)); + expr = std::move(group_ref); + } + result->column_count = new_names.size(); + result->names = std::move(new_names); + result->need_prune = result->select_list.size() > result->column_count; + + // in the normal select binder, we bind columns as if there is no aggregation + // i.e. in the query [SELECT i, SUM(i) FROM integers;] the "i" will be bound as a normal column + // since we have an aggregation, we need to either (1) throw an error, or (2) wrap the column in a FIRST() aggregate + // we choose the former one [CONTROVERSIAL: this is the PostgreSQL behavior] + if (!result->groups.group_expressions.empty() || !result->aggregates.empty() || statement.having || + !result->groups.grouping_sets.empty()) { + if (statement.aggregate_handling == AggregateHandling::NO_AGGREGATES_ALLOWED) { + throw BinderException("Aggregates cannot be present in a Project relation!"); + } else { + vector> to_check_binders; + to_check_binders.push_back(select_binder); + if (qualify_binder) { + to_check_binders.push_back(*qualify_binder); + } + for (auto &binder : to_check_binders) { + auto &sel_binder = binder.get(); + if (!sel_binder.HasBoundColumns()) { + continue; + } + auto &bound_columns = sel_binder.GetBoundColumns(); + string error; + error = "column \"%s\" must appear in the GROUP BY clause or must be part of an aggregate function."; + if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { + error += "\nGROUP BY ALL will only group entries in the SELECT list. Add it to the SELECT list or " + "GROUP BY this entry explicitly."; + } else { + error += + "\nEither add it to the GROUP BY list, or use \"ANY_VALUE(%s)\" if the exact value of \"%s\" " + "is not important."; + } + throw BinderException(FormatError(bound_columns[0].query_location, error, bound_columns[0].name, + bound_columns[0].name, bound_columns[0].name)); + } + } + } + + // QUALIFY clause requires at least one window function to be specified in at least one of the SELECT column list or + // the filter predicate of the QUALIFY clause + if (statement.qualify && result->windows.empty()) { + throw BinderException("at least one window function must appear in the SELECT column or QUALIFY clause"); + } + + // now that the SELECT list is bound, we set the types of DISTINCT/ORDER BY expressions + BindModifierTypes(*result, internal_sql_types, result->projection_index); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp new file mode 100644 index 00000000..7b2c3375 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp @@ -0,0 +1,266 @@ +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression_map.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression_binder/order_binder.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/query_node/bound_set_operation_node.hpp" + +namespace duckdb { + +static void GatherAliases(BoundQueryNode &node, case_insensitive_map_t &aliases, + parsed_expression_map_t &expressions, const vector &reorder_idx) { + if (node.type == QueryNodeType::SET_OPERATION_NODE) { + // setop, recurse + auto &setop = node.Cast(); + + // create new reorder index + if (setop.setop_type == SetOperationType::UNION_BY_NAME) { + vector new_left_reorder_idx(setop.left_reorder_idx.size()); + vector new_right_reorder_idx(setop.right_reorder_idx.size()); + for (idx_t i = 0; i < setop.left_reorder_idx.size(); ++i) { + new_left_reorder_idx[i] = reorder_idx[setop.left_reorder_idx[i]]; + } + + for (idx_t i = 0; i < setop.right_reorder_idx.size(); ++i) { + new_right_reorder_idx[i] = reorder_idx[setop.right_reorder_idx[i]]; + } + + // use new reorder index + GatherAliases(*setop.left, aliases, expressions, new_left_reorder_idx); + GatherAliases(*setop.right, aliases, expressions, new_right_reorder_idx); + return; + } + + GatherAliases(*setop.left, aliases, expressions, reorder_idx); + GatherAliases(*setop.right, aliases, expressions, reorder_idx); + } else { + // query node + D_ASSERT(node.type == QueryNodeType::SELECT_NODE); + auto &select = node.Cast(); + // fill the alias lists + for (idx_t i = 0; i < select.names.size(); i++) { + auto &name = select.names[i]; + auto &expr = select.original_expressions[i]; + // first check if the alias is already in there + auto entry = aliases.find(name); + + idx_t index = reorder_idx[i]; + + if (entry != aliases.end()) { + // the alias already exists + // check if there is a conflict + + if (entry->second != index) { + // there is a conflict + // we place "-1" in the aliases map at this location + // "-1" signifies that there is an ambiguous reference + aliases[name] = DConstants::INVALID_INDEX; + } + } else { + // the alias is not in there yet, just assign it + aliases[name] = index; + } + // now check if the node is already in the set of expressions + auto expr_entry = expressions.find(*expr); + if (expr_entry != expressions.end()) { + // the node is in there + // repeat the same as with the alias: if there is an ambiguity we insert "-1" + if (expr_entry->second != index) { + expressions[*expr] = DConstants::INVALID_INDEX; + } + } else { + // not in there yet, just place it in there + expressions[*expr] = index; + } + } + } +} + +static void BuildUnionByNameInfo(BoundSetOperationNode &result, bool can_contain_nulls) { + D_ASSERT(result.setop_type == SetOperationType::UNION_BY_NAME); + case_insensitive_map_t left_names_map; + case_insensitive_map_t right_names_map; + + BoundQueryNode *left_node = result.left.get(); + BoundQueryNode *right_node = result.right.get(); + + // Build a name_map to use to check if a name exists + // We throw a binder exception if two same name in the SELECT list + for (idx_t i = 0; i < left_node->names.size(); ++i) { + if (left_names_map.find(left_node->names[i]) != left_names_map.end()) { + throw BinderException("UNION(ALL) BY NAME operation doesn't support same name in SELECT list"); + } + left_names_map[left_node->names[i]] = i; + } + + for (idx_t i = 0; i < right_node->names.size(); ++i) { + if (right_names_map.find(right_node->names[i]) != right_names_map.end()) { + throw BinderException("UNION(ALL) BY NAME operation doesn't support same name in SELECT list"); + } + if (left_names_map.find(right_node->names[i]) == left_names_map.end()) { + result.names.push_back(right_node->names[i]); + } + right_names_map[right_node->names[i]] = i; + } + + idx_t new_size = result.names.size(); + bool need_reorder = false; + vector left_reorder_idx(left_node->names.size()); + vector right_reorder_idx(right_node->names.size()); + + // Construct return type and reorder_idxs + // reorder_idxs is used to gather correct alias_map + // and expression_map in GatherAlias(...) + for (idx_t i = 0; i < new_size; ++i) { + auto left_index = left_names_map.find(result.names[i]); + auto right_index = right_names_map.find(result.names[i]); + bool left_exist = left_index != left_names_map.end(); + bool right_exist = right_index != right_names_map.end(); + LogicalType result_type; + if (left_exist && right_exist) { + result_type = LogicalType::MaxLogicalType(left_node->types[left_index->second], + right_node->types[right_index->second]); + if (left_index->second != i || right_index->second != i) { + need_reorder = true; + } + left_reorder_idx[left_index->second] = i; + right_reorder_idx[right_index->second] = i; + } else if (left_exist) { + result_type = left_node->types[left_index->second]; + need_reorder = true; + left_reorder_idx[left_index->second] = i; + } else { + D_ASSERT(right_exist); + result_type = right_node->types[right_index->second]; + need_reorder = true; + right_reorder_idx[right_index->second] = i; + } + + if (!can_contain_nulls) { + if (ExpressionBinder::ContainsNullType(result_type)) { + result_type = ExpressionBinder::ExchangeNullType(result_type); + } + } + + result.types.push_back(result_type); + } + + result.left_reorder_idx = std::move(left_reorder_idx); + result.right_reorder_idx = std::move(right_reorder_idx); + + // If reorder is required, collect reorder expressions for push projection + // into the two child nodes of union node + if (need_reorder) { + for (idx_t i = 0; i < new_size; ++i) { + auto left_index = left_names_map.find(result.names[i]); + auto right_index = right_names_map.find(result.names[i]); + bool left_exist = left_index != left_names_map.end(); + bool right_exist = right_index != right_names_map.end(); + unique_ptr left_reorder_expr; + unique_ptr right_reorder_expr; + if (left_exist && right_exist) { + left_reorder_expr = make_uniq( + left_node->types[left_index->second], ColumnBinding(left_node->GetRootIndex(), left_index->second)); + right_reorder_expr = + make_uniq(right_node->types[right_index->second], + ColumnBinding(right_node->GetRootIndex(), right_index->second)); + } else if (left_exist) { + left_reorder_expr = make_uniq( + left_node->types[left_index->second], ColumnBinding(left_node->GetRootIndex(), left_index->second)); + // create null value here + right_reorder_expr = make_uniq(Value(result.types[i])); + } else { + D_ASSERT(right_exist); + left_reorder_expr = make_uniq(Value(result.types[i])); + right_reorder_expr = + make_uniq(right_node->types[right_index->second], + ColumnBinding(right_node->GetRootIndex(), right_index->second)); + } + result.left_reorder_exprs.push_back(std::move(left_reorder_expr)); + result.right_reorder_exprs.push_back(std::move(right_reorder_expr)); + } + } +} + +unique_ptr Binder::BindNode(SetOperationNode &statement) { + auto result = make_uniq(); + result->setop_type = statement.setop_type; + + // first recursively visit the set operations + // both the left and right sides have an independent BindContext and Binder + D_ASSERT(statement.left); + D_ASSERT(statement.right); + + result->setop_index = GenerateTableIndex(); + + result->left_binder = Binder::CreateBinder(context, this); + result->left_binder->can_contain_nulls = true; + result->left = result->left_binder->BindNode(*statement.left); + result->right_binder = Binder::CreateBinder(context, this); + result->right_binder->can_contain_nulls = true; + result->right = result->right_binder->BindNode(*statement.right); + + result->names = result->left->names; + + // move the correlated expressions from the child binders to this binder + MoveCorrelatedExpressions(*result->left_binder); + MoveCorrelatedExpressions(*result->right_binder); + + // now both sides have been bound we can resolve types + if (result->setop_type != SetOperationType::UNION_BY_NAME && + result->left->types.size() != result->right->types.size()) { + throw BinderException("Set operations can only apply to expressions with the " + "same number of result columns"); + } + + if (result->setop_type == SetOperationType::UNION_BY_NAME) { + BuildUnionByNameInfo(*result, can_contain_nulls); + + } else { + // figure out the types of the setop result by picking the max of both + for (idx_t i = 0; i < result->left->types.size(); i++) { + auto result_type = LogicalType::MaxLogicalType(result->left->types[i], result->right->types[i]); + if (!can_contain_nulls) { + if (ExpressionBinder::ContainsNullType(result_type)) { + result_type = ExpressionBinder::ExchangeNullType(result_type); + } + } + result->types.push_back(result_type); + } + } + + if (!statement.modifiers.empty()) { + // handle the ORDER BY/DISTINCT clauses + + // we recursively visit the children of this node to extract aliases and expressions that can be referenced + // in the ORDER BY + case_insensitive_map_t alias_map; + parsed_expression_map_t expression_map; + + if (result->setop_type == SetOperationType::UNION_BY_NAME) { + GatherAliases(*result->left, alias_map, expression_map, result->left_reorder_idx); + GatherAliases(*result->right, alias_map, expression_map, result->right_reorder_idx); + } else { + vector reorder_idx; + for (idx_t i = 0; i < result->names.size(); i++) { + reorder_idx.push_back(i); + } + GatherAliases(*result, alias_map, expression_map, reorder_idx); + } + // now we perform the actual resolution of the ORDER BY/DISTINCT expressions + OrderBinder order_binder({result->left_binder.get(), result->right_binder.get()}, result->setop_index, + alias_map, expression_map, result->names.size()); + BindModifiers(order_binder, statement, *result); + } + + // finally bind the types of the ORDER/DISTINCT clause expressions + BindModifierTypes(*result, result->types, result->setop_index); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_table_macro_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_table_macro_node.cpp new file mode 100644 index 00000000..a3b7f58c --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_table_macro_node.cpp @@ -0,0 +1,69 @@ +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/function/table_macro_function.hpp" + +namespace duckdb { + +unique_ptr Binder::BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, + idx_t depth) { + + auto ¯o_def = macro_func.function->Cast(); + auto node = macro_def.query_node->Copy(); + + // auto ¯o_def = *macro_func->function; + + // validate the arguments and separate positional and default arguments + vector> positionals; + unordered_map> defaults; + string error = + MacroFunction::ValidateArguments(*macro_func.function, macro_func.name, function, positionals, defaults); + if (!error.empty()) { + // cannot use error below as binder rnot in scope + // return BindResult(binder. FormatError(*expr->get(), error)); + throw BinderException(FormatError(function, error)); + } + + // create a MacroBinding to bind this macro's parameters to its arguments + vector types; + vector names; + // positional parameters + for (idx_t i = 0; i < macro_def.parameters.size(); i++) { + types.emplace_back(LogicalType::SQLNULL); + auto ¶m = macro_def.parameters[i]->Cast(); + names.push_back(param.GetColumnName()); + } + // default parameters + for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { + types.emplace_back(LogicalType::SQLNULL); + names.push_back(it->first); + // now push the defaults into the positionals + positionals.push_back(std::move(defaults[it->first])); + } + auto new_macro_binding = make_uniq(types, names, macro_func.name); + new_macro_binding->arguments = &positionals; + + // We need an ExpressionBinder so that we can call ExpressionBinder::ReplaceMacroParametersRecursive() + auto eb = ExpressionBinder(*this, this->context); + + eb.macro_binding = new_macro_binding.get(); + + /* Does it all goes throu every expression in a selectstmt */ + ParsedExpressionIterator::EnumerateQueryNodeChildren( + *node, [&](unique_ptr &child) { eb.ReplaceMacroParametersRecursive(child); }); + + return node; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp new file mode 100644 index 00000000..cfc07915 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp @@ -0,0 +1,26 @@ +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/query_node/bound_cte_node.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundCTENode &node) { + // Generate the logical plan for the cte_query and child. + auto cte_query = CreatePlan(*node.query); + auto cte_child = CreatePlan(*node.child); + + auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), + std::move(cte_query), std::move(cte_child)); + + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = + node.child_binder->has_unplanned_dependent_joins || node.query_binder->has_unplanned_dependent_joins; + + return VisitQueryNode(node, std::move(root)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_query_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_query_node.cpp new file mode 100644 index 00000000..ff29a548 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/plan_query_node.cpp @@ -0,0 +1,62 @@ +#include "duckdb/parser/query_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" +#include "duckdb/planner/operator/logical_limit_percent.hpp" +#include "duckdb/planner/operator/logical_order.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" + +namespace duckdb { + +unique_ptr Binder::VisitQueryNode(BoundQueryNode &node, unique_ptr root) { + D_ASSERT(root); + for (auto &mod : node.modifiers) { + switch (mod->type) { + case ResultModifierType::DISTINCT_MODIFIER: { + auto &bound = mod->Cast(); + auto distinct = make_uniq(std::move(bound.target_distincts), bound.distinct_type); + distinct->AddChild(std::move(root)); + root = std::move(distinct); + break; + } + case ResultModifierType::ORDER_MODIFIER: { + auto &bound = mod->Cast(); + if (root->type == LogicalOperatorType::LOGICAL_DISTINCT) { + auto &distinct = root->Cast(); + if (distinct.distinct_type == DistinctType::DISTINCT_ON) { + auto order_by = make_uniq(); + for (auto &order_node : bound.orders) { + order_by->orders.push_back(order_node.Copy()); + } + distinct.order_by = std::move(order_by); + } + } + auto order = make_uniq(std::move(bound.orders)); + order->AddChild(std::move(root)); + root = std::move(order); + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &bound = mod->Cast(); + auto limit = make_uniq(bound.limit_val, bound.offset_val, std::move(bound.limit), + std::move(bound.offset)); + limit->AddChild(std::move(root)); + root = std::move(limit); + break; + } + case ResultModifierType::LIMIT_PERCENT_MODIFIER: { + auto &bound = mod->Cast(); + auto limit = make_uniq(bound.limit_percent, bound.offset_val, std::move(bound.limit), + std::move(bound.offset)); + limit->AddChild(std::move(root)); + root = std::move(limit); + break; + } + default: + throw BinderException("Unimplemented modifier type!"); + } + } + return root; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp new file mode 100644 index 00000000..d41345b1 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp @@ -0,0 +1,39 @@ +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_recursive_cte.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundRecursiveCTENode &node) { + // Generate the logical plan for the left and right sides of the set operation + node.left_binder->is_outside_flattened = is_outside_flattened; + node.right_binder->is_outside_flattened = is_outside_flattened; + + auto left_node = node.left_binder->CreatePlan(*node.left); + auto right_node = node.right_binder->CreatePlan(*node.right); + + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = + node.left_binder->has_unplanned_dependent_joins || node.right_binder->has_unplanned_dependent_joins; + + // for both the left and right sides, cast them to the same types + left_node = CastLogicalOperatorToTypes(node.left->types, node.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(node.right->types, node.types, std::move(right_node)); + + if (!node.right_binder->bind_context.cte_references[node.ctename] || + *node.right_binder->bind_context.cte_references[node.ctename] == 0) { + auto root = make_uniq(node.setop_index, node.types.size(), std::move(left_node), + std::move(right_node), LogicalOperatorType::LOGICAL_UNION); + return VisitQueryNode(node, std::move(root)); + } + auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), node.union_all, + std::move(left_node), std::move(right_node)); + + return VisitQueryNode(node, std::move(root)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp new file mode 100644 index 00000000..46e5d2e1 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp @@ -0,0 +1,134 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" + +namespace duckdb { + +unique_ptr Binder::PlanFilter(unique_ptr condition, unique_ptr root) { + PlanSubqueries(condition, root); + auto filter = make_uniq(std::move(condition)); + filter->AddChild(std::move(root)); + return std::move(filter); +} + +unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { + unique_ptr root; + D_ASSERT(statement.from_table); + root = CreatePlan(*statement.from_table); + D_ASSERT(root); + + // plan the sample clause + if (statement.sample_options) { + root = make_uniq(std::move(statement.sample_options), std::move(root)); + } + + if (statement.where_clause) { + root = PlanFilter(std::move(statement.where_clause), std::move(root)); + } + + if (!statement.aggregates.empty() || !statement.groups.group_expressions.empty()) { + if (!statement.groups.group_expressions.empty()) { + // visit the groups + for (auto &group : statement.groups.group_expressions) { + PlanSubqueries(group, root); + } + } + // now visit all aggregate expressions + for (auto &expr : statement.aggregates) { + PlanSubqueries(expr, root); + } + // finally create the aggregate node with the group_index and aggregate_index as obtained from the binder + auto aggregate = make_uniq(statement.group_index, statement.aggregate_index, + std::move(statement.aggregates)); + aggregate->groups = std::move(statement.groups.group_expressions); + aggregate->groupings_index = statement.groupings_index; + aggregate->grouping_sets = std::move(statement.groups.grouping_sets); + aggregate->grouping_functions = std::move(statement.grouping_functions); + + aggregate->AddChild(std::move(root)); + root = std::move(aggregate); + } else if (!statement.groups.grouping_sets.empty()) { + // edge case: we have grouping sets but no groups or aggregates + // this can only happen if we have e.g. select 1 from tbl group by (); + // just output a dummy scan + root = make_uniq_base(statement.group_index); + } + + if (statement.having) { + PlanSubqueries(statement.having, root); + auto having = make_uniq(std::move(statement.having)); + + having->AddChild(std::move(root)); + root = std::move(having); + } + + if (!statement.windows.empty()) { + auto win = make_uniq(statement.window_index); + win->expressions = std::move(statement.windows); + // visit the window expressions + for (auto &expr : win->expressions) { + PlanSubqueries(expr, root); + } + D_ASSERT(!win->expressions.empty()); + win->AddChild(std::move(root)); + root = std::move(win); + } + + if (statement.qualify) { + PlanSubqueries(statement.qualify, root); + auto qualify = make_uniq(std::move(statement.qualify)); + + qualify->AddChild(std::move(root)); + root = std::move(qualify); + } + + for (idx_t i = statement.unnests.size(); i > 0; i--) { + auto unnest_level = i - 1; + auto entry = statement.unnests.find(unnest_level); + if (entry == statement.unnests.end()) { + throw InternalException("unnests specified at level %d but none were found", unnest_level); + } + auto &unnest_node = entry->second; + auto unnest = make_uniq(unnest_node.index); + unnest->expressions = std::move(unnest_node.expressions); + // visit the unnest expressions + for (auto &expr : unnest->expressions) { + PlanSubqueries(expr, root); + } + D_ASSERT(!unnest->expressions.empty()); + unnest->AddChild(std::move(root)); + root = std::move(unnest); + } + + for (auto &expr : statement.select_list) { + PlanSubqueries(expr, root); + } + + auto proj = make_uniq(statement.projection_index, std::move(statement.select_list)); + auto &projection = *proj; + proj->AddChild(std::move(root)); + root = std::move(proj); + + // finish the plan by handling the elements of the QueryNode + root = VisitQueryNode(statement, std::move(root)); + + // add a prune node if necessary + if (statement.need_prune) { + D_ASSERT(root); + vector> prune_expressions; + for (idx_t i = 0; i < statement.column_count; i++) { + prune_expressions.push_back(make_uniq( + projection.expressions[i]->return_type, ColumnBinding(statement.projection_index, i))); + } + auto prune = make_uniq(statement.prune_index, std::move(prune_expressions)); + prune->AddChild(std::move(root)); + root = std::move(prune); + } + return root; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp new file mode 100644 index 00000000..3d9b7794 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp @@ -0,0 +1,123 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/query_node/bound_set_operation_node.hpp" + +namespace duckdb { + +// Optionally push a PROJECTION operator +unique_ptr Binder::CastLogicalOperatorToTypes(vector &source_types, + vector &target_types, + unique_ptr op) { + D_ASSERT(op); + // first check if we even need to cast + D_ASSERT(source_types.size() == target_types.size()); + if (source_types == target_types) { + // source and target types are equal: don't need to cast + return op; + } + // otherwise add casts + auto node = op.get(); + if (node->type == LogicalOperatorType::LOGICAL_PROJECTION) { + // "node" is a projection; we can just do the casts in there + D_ASSERT(node->expressions.size() == source_types.size()); + // add the casts to the selection list + for (idx_t i = 0; i < target_types.size(); i++) { + if (source_types[i] != target_types[i]) { + // differing types, have to add a cast + string alias = node->expressions[i]->alias; + node->expressions[i] = + BoundCastExpression::AddCastToType(context, std::move(node->expressions[i]), target_types[i]); + node->expressions[i]->alias = alias; + } + } + return op; + } else { + // found a non-projection operator + // push a new projection containing the casts + + // fetch the set of column bindings + auto setop_columns = op->GetColumnBindings(); + D_ASSERT(setop_columns.size() == source_types.size()); + + // now generate the expression list + vector> select_list; + for (idx_t i = 0; i < target_types.size(); i++) { + unique_ptr result = make_uniq(source_types[i], setop_columns[i]); + if (source_types[i] != target_types[i]) { + // add a cast only if the source and target types are not equivalent + result = BoundCastExpression::AddCastToType(context, std::move(result), target_types[i]); + } + select_list.push_back(std::move(result)); + } + auto projection = make_uniq(GenerateTableIndex(), std::move(select_list)); + projection->children.push_back(std::move(op)); + return std::move(projection); + } +} + +unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { + // Generate the logical plan for the left and right sides of the set operation + node.left_binder->is_outside_flattened = is_outside_flattened; + node.right_binder->is_outside_flattened = is_outside_flattened; + + auto left_node = node.left_binder->CreatePlan(*node.left); + auto right_node = node.right_binder->CreatePlan(*node.right); + + // Add a new projection to child node + D_ASSERT(node.left_reorder_exprs.size() == node.right_reorder_exprs.size()); + if (!node.left_reorder_exprs.empty()) { + D_ASSERT(node.setop_type == SetOperationType::UNION_BY_NAME); + vector left_types; + vector right_types; + // We are going to add a new projection operator, so collect the type + // of reorder exprs in order to call CastLogicalOperatorToTypes() + for (idx_t i = 0; i < node.left_reorder_exprs.size(); ++i) { + left_types.push_back(node.left_reorder_exprs[i]->return_type); + right_types.push_back(node.right_reorder_exprs[i]->return_type); + } + + auto left_projection = make_uniq(GenerateTableIndex(), std::move(node.left_reorder_exprs)); + left_projection->children.push_back(std::move(left_node)); + left_node = std::move(left_projection); + + auto right_projection = make_uniq(GenerateTableIndex(), std::move(node.right_reorder_exprs)); + right_projection->children.push_back(std::move(right_node)); + right_node = std::move(right_projection); + + left_node = CastLogicalOperatorToTypes(left_types, node.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(right_types, node.types, std::move(right_node)); + } else { + left_node = CastLogicalOperatorToTypes(node.left->types, node.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(node.right->types, node.types, std::move(right_node)); + } + + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = + node.left_binder->has_unplanned_dependent_joins || node.right_binder->has_unplanned_dependent_joins; + + // create actual logical ops for setops + LogicalOperatorType logical_type; + switch (node.setop_type) { + case SetOperationType::UNION: + case SetOperationType::UNION_BY_NAME: + logical_type = LogicalOperatorType::LOGICAL_UNION; + break; + case SetOperationType::EXCEPT: + logical_type = LogicalOperatorType::LOGICAL_EXCEPT; + break; + default: + D_ASSERT(node.setop_type == SetOperationType::INTERSECT); + logical_type = LogicalOperatorType::LOGICAL_INTERSECT; + break; + } + + auto root = make_uniq(node.setop_index, node.types.size(), std::move(left_node), + std::move(right_node), logical_type); + + return VisitQueryNode(node, std::move(root)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp new file mode 100644 index 00000000..8d019e14 --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp @@ -0,0 +1,459 @@ +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/planner/subquery/flatten_dependent_join.hpp" +#include "duckdb/common/enums/logical_operator_type.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" +#include "duckdb/planner/expression_binder/lateral_binder.hpp" +#include "duckdb/planner/subquery/recursive_dependent_join_planner.hpp" + +namespace duckdb { + +static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubqueryExpression &expr, + unique_ptr &root, + unique_ptr plan) { + D_ASSERT(!expr.IsCorrelated()); + switch (expr.subquery_type) { + case SubqueryType::EXISTS: { + // uncorrelated EXISTS + // we only care about existence, hence we push a LIMIT 1 operator + auto limit = make_uniq(1, 0, nullptr, nullptr); + limit->AddChild(std::move(plan)); + plan = std::move(limit); + + // now we push a COUNT(*) aggregate onto the limit, this will be either 0 or 1 (EXISTS or NOT EXISTS) + auto count_star_fun = CountStarFun::GetFunction(); + + FunctionBinder function_binder(binder.context); + auto count_star = + function_binder.BindAggregateFunction(count_star_fun, {}, nullptr, AggregateType::NON_DISTINCT); + auto idx_type = count_star->return_type; + vector> aggregate_list; + aggregate_list.push_back(std::move(count_star)); + auto aggregate_index = binder.GenerateTableIndex(); + auto aggregate = + make_uniq(binder.GenerateTableIndex(), aggregate_index, std::move(aggregate_list)); + aggregate->AddChild(std::move(plan)); + plan = std::move(aggregate); + + // now we push a projection with a comparison to 1 + auto left_child = make_uniq(idx_type, ColumnBinding(aggregate_index, 0)); + auto right_child = make_uniq(Value::Numeric(idx_type, 1)); + auto comparison = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(left_child), + std::move(right_child)); + + vector> projection_list; + projection_list.push_back(std::move(comparison)); + auto projection_index = binder.GenerateTableIndex(); + auto projection = make_uniq(projection_index, std::move(projection_list)); + projection->AddChild(std::move(plan)); + plan = std::move(projection); + + // we add it to the main query by adding a cross product + // FIXME: should use something else besides cross product as we always add only one scalar constant + root = LogicalCrossProduct::Create(std::move(root), std::move(plan)); + + // we replace the original subquery with a ColumnRefExpression referring to the result of the projection (either + // TRUE or FALSE) + return make_uniq(expr.GetName(), LogicalType::BOOLEAN, + ColumnBinding(projection_index, 0)); + } + case SubqueryType::SCALAR: { + // uncorrelated scalar, we want to return the first entry + // figure out the table index of the bound table of the entry which we want to return + auto bindings = plan->GetColumnBindings(); + D_ASSERT(bindings.size() == 1); + idx_t table_idx = bindings[0].table_index; + + // in the uncorrelated case we are only interested in the first result of the query + // hence we simply push a LIMIT 1 to get the first row of the subquery + auto limit = make_uniq(1, 0, nullptr, nullptr); + limit->AddChild(std::move(plan)); + plan = std::move(limit); + + // we push an aggregate that returns the FIRST element + vector> expressions; + auto bound = make_uniq(expr.return_type, ColumnBinding(table_idx, 0)); + vector> first_children; + first_children.push_back(std::move(bound)); + + FunctionBinder function_binder(binder.context); + auto first_agg = function_binder.BindAggregateFunction( + FirstFun::GetFunction(expr.return_type), std::move(first_children), nullptr, AggregateType::NON_DISTINCT); + + expressions.push_back(std::move(first_agg)); + auto aggr_index = binder.GenerateTableIndex(); + auto aggr = make_uniq(binder.GenerateTableIndex(), aggr_index, std::move(expressions)); + aggr->AddChild(std::move(plan)); + plan = std::move(aggr); + + // in the uncorrelated case, we add the value to the main query through a cross product + // FIXME: should use something else besides cross product as we always add only one scalar constant and cross + // product is not optimized for this. + D_ASSERT(root); + root = LogicalCrossProduct::Create(std::move(root), std::move(plan)); + + // we replace the original subquery with a BoundColumnRefExpression referring to the first result of the + // aggregation + return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(aggr_index, 0)); + } + default: { + D_ASSERT(expr.subquery_type == SubqueryType::ANY); + // we generate a MARK join that results in either (TRUE, FALSE or NULL) + // subquery has NULL values -> result is (TRUE or NULL) + // subquery has no NULL values -> result is (TRUE, FALSE or NULL [if input is NULL]) + // fetch the column bindings + auto plan_columns = plan->GetColumnBindings(); + + // then we generate the MARK join with the subquery + idx_t mark_index = binder.GenerateTableIndex(); + auto join = make_uniq(JoinType::MARK); + join->mark_index = mark_index; + join->AddChild(std::move(root)); + join->AddChild(std::move(plan)); + // create the JOIN condition + JoinCondition cond; + cond.left = std::move(expr.child); + cond.right = BoundCastExpression::AddDefaultCastToType( + make_uniq(expr.child_type, plan_columns[0]), expr.child_target); + cond.comparison = expr.comparison_type; + join->conditions.push_back(std::move(cond)); + root = std::move(join); + + // we replace the original subquery with a BoundColumnRefExpression referring to the mark column + return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(mark_index, 0)); + } + } +} + +static unique_ptr +CreateDuplicateEliminatedJoin(const vector &correlated_columns, JoinType join_type, + unique_ptr original_plan, bool perform_delim) { + auto delim_join = make_uniq(join_type, LogicalOperatorType::LOGICAL_DELIM_JOIN); + if (!perform_delim) { + // if we are not performing a delim join, we push a row_number() OVER() window operator on the LHS + // and perform all duplicate elimination on that row number instead + D_ASSERT(correlated_columns[0].type.id() == LogicalTypeId::BIGINT); + auto window = make_uniq(correlated_columns[0].binding.table_index); + auto row_number = + make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); + row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; + row_number->end = WindowBoundary::CURRENT_ROW_ROWS; + row_number->alias = "delim_index"; + window->expressions.push_back(std::move(row_number)); + window->AddChild(std::move(original_plan)); + original_plan = std::move(window); + } + delim_join->AddChild(std::move(original_plan)); + for (idx_t i = 0; i < correlated_columns.size(); i++) { + auto &col = correlated_columns[i]; + delim_join->duplicate_eliminated_columns.push_back(make_uniq(col.type, col.binding)); + delim_join->mark_types.push_back(col.type); + } + return delim_join; +} + +static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, + const vector &correlated_columns, + vector bindings, idx_t base_offset, bool perform_delim) { + auto col_count = perform_delim ? correlated_columns.size() : 1; + for (idx_t i = 0; i < col_count; i++) { + auto &col = correlated_columns[i]; + auto binding_idx = base_offset + i; + if (binding_idx >= bindings.size()) { + throw InternalException("Delim join - binding index out of range"); + } + JoinCondition cond; + cond.left = make_uniq(col.name, col.type, col.binding); + cond.right = make_uniq(col.name, col.type, bindings[binding_idx]); + cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + delim_join.conditions.push_back(std::move(cond)); + } +} + +static bool PerformDelimOnType(const LogicalType &type) { + if (type.InternalType() == PhysicalType::LIST) { + return false; + } + if (type.InternalType() == PhysicalType::STRUCT) { + for (auto &entry : StructType::GetChildTypes(type)) { + if (!PerformDelimOnType(entry.second)) { + return false; + } + } + } + return true; +} + +static bool PerformDuplicateElimination(Binder &binder, vector &correlated_columns) { + if (!ClientConfig::GetConfig(binder.context).enable_optimizer) { + // if optimizations are disabled we always do a delim join + return true; + } + bool perform_delim = true; + for (auto &col : correlated_columns) { + if (!PerformDelimOnType(col.type)) { + perform_delim = false; + break; + } + } + if (perform_delim) { + return true; + } + auto binding = ColumnBinding(binder.GenerateTableIndex(), 0); + auto type = LogicalType::BIGINT; + auto name = "delim_index"; + CorrelatedColumnInfo info(binding, type, name, 0); + correlated_columns.insert(correlated_columns.begin(), std::move(info)); + return false; +} + +static unique_ptr PlanCorrelatedSubquery(Binder &binder, BoundSubqueryExpression &expr, + unique_ptr &root, + unique_ptr plan) { + auto &correlated_columns = expr.binder->correlated_columns; + // FIXME: there should be a way of disabling decorrelation for ANY queries as well, but not for now... + bool perform_delim = + expr.subquery_type == SubqueryType::ANY ? true : PerformDuplicateElimination(binder, correlated_columns); + D_ASSERT(expr.IsCorrelated()); + // correlated subquery + // for a more in-depth explanation of this code, read the paper "Unnesting Arbitrary Subqueries" + // we handle three types of correlated subqueries: Scalar, EXISTS and ANY + // all three cases are very similar with some minor changes (mainly the type of join performed at the end) + switch (expr.subquery_type) { + case SubqueryType::SCALAR: { + // correlated SCALAR query + // first push a DUPLICATE ELIMINATED join + // a duplicate eliminated join creates a duplicate eliminated copy of the LHS + // and pushes it into any DUPLICATE_ELIMINATED SCAN operators on the RHS + + // in the SCALAR case, we create a SINGLE join (because we are only interested in obtaining the value) + // NULL values are equal in this join because we join on the correlated columns ONLY + // and e.g. in the query: SELECT (SELECT 42 FROM integers WHERE i1.i IS NULL LIMIT 1) FROM integers i1; + // the input value NULL will generate the value 42, and we need to join NULL on the LHS with NULL on the RHS + // the left side is the original plan + // this is the side that will be duplicate eliminated and pushed into the RHS + auto delim_join = + CreateDuplicateEliminatedJoin(correlated_columns, JoinType::SINGLE, std::move(root), perform_delim); + + // the right side initially is a DEPENDENT join between the duplicate eliminated scan and the subquery + // HOWEVER: we do not explicitly create the dependent join + // instead, we eliminate the dependent join by pushing it down into the right side of the plan + FlattenDependentJoins flatten(binder, correlated_columns, perform_delim); + + // first we check which logical operators have correlated expressions in the first place + flatten.DetectCorrelatedExpressions(plan.get()); + // now we push the dependent join down + auto dependent_join = flatten.PushDownDependentJoin(std::move(plan)); + + // now the dependent join is fully eliminated + // we only need to create the join conditions between the LHS and the RHS + // fetch the set of columns + auto plan_columns = dependent_join->GetColumnBindings(); + + // now create the join conditions + CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); + delim_join->AddChild(std::move(dependent_join)); + root = std::move(delim_join); + // finally push the BoundColumnRefExpression referring to the data element returned by the join + return make_uniq(expr.GetName(), expr.return_type, plan_columns[flatten.data_offset]); + } + case SubqueryType::EXISTS: { + // correlated EXISTS query + // this query is similar to the correlated SCALAR query, except we use a MARK join here + idx_t mark_index = binder.GenerateTableIndex(); + auto delim_join = + CreateDuplicateEliminatedJoin(correlated_columns, JoinType::MARK, std::move(root), perform_delim); + delim_join->mark_index = mark_index; + // RHS + FlattenDependentJoins flatten(binder, correlated_columns, perform_delim, true); + flatten.DetectCorrelatedExpressions(plan.get()); + auto dependent_join = flatten.PushDownDependentJoin(std::move(plan)); + + // fetch the set of columns + auto plan_columns = dependent_join->GetColumnBindings(); + + // now we create the join conditions between the dependent join and the original table + CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); + delim_join->AddChild(std::move(dependent_join)); + root = std::move(delim_join); + // finally push the BoundColumnRefExpression referring to the marker + return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(mark_index, 0)); + } + default: { + D_ASSERT(expr.subquery_type == SubqueryType::ANY); + // correlated ANY query + // this query is similar to the correlated SCALAR query + // however, in this case we push a correlated MARK join + // note that in this join null values are NOT equal for ALL columns, but ONLY for the correlated columns + // the correlated mark join handles this case by itself + // as the MARK join has one extra join condition (the original condition, of the ANY expression, e.g. + // [i=ANY(...)]) + idx_t mark_index = binder.GenerateTableIndex(); + auto delim_join = + CreateDuplicateEliminatedJoin(correlated_columns, JoinType::MARK, std::move(root), perform_delim); + delim_join->mark_index = mark_index; + // RHS + FlattenDependentJoins flatten(binder, correlated_columns, true, true); + flatten.DetectCorrelatedExpressions(plan.get()); + auto dependent_join = flatten.PushDownDependentJoin(std::move(plan)); + + // fetch the columns + auto plan_columns = dependent_join->GetColumnBindings(); + + // now we create the join conditions between the dependent join and the original table + CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); + // add the actual condition based on the ANY/ALL predicate + JoinCondition compare_cond; + compare_cond.left = std::move(expr.child); + compare_cond.right = BoundCastExpression::AddDefaultCastToType( + make_uniq(expr.child_type, plan_columns[0]), expr.child_target); + compare_cond.comparison = expr.comparison_type; + delim_join->conditions.push_back(std::move(compare_cond)); + + delim_join->AddChild(std::move(dependent_join)); + root = std::move(delim_join); + // finally push the BoundColumnRefExpression referring to the marker + return make_uniq(expr.GetName(), expr.return_type, ColumnBinding(mark_index, 0)); + } + } +} + +void RecursiveDependentJoinPlanner::VisitOperator(LogicalOperator &op) { + if (!op.children.empty()) { + root = std::move(op.children[0]); + D_ASSERT(root); + if (root->type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { + // Found a dependent join, flatten it + auto &new_root = root->Cast(); + root = binder.PlanLateralJoin(std::move(new_root.children[0]), std::move(new_root.children[1]), + new_root.correlated_columns, new_root.join_type, + std::move(new_root.join_condition)); + } + VisitOperatorExpressions(op); + op.children[0] = std::move(root); + for (idx_t i = 0; i < op.children.size(); i++) { + D_ASSERT(op.children[i]); + VisitOperator(*op.children[i]); + } + } +} + +unique_ptr RecursiveDependentJoinPlanner::VisitReplace(BoundSubqueryExpression &expr, + unique_ptr *expr_ptr) { + return binder.PlanSubquery(expr, root); +} + +unique_ptr Binder::PlanSubquery(BoundSubqueryExpression &expr, unique_ptr &root) { + D_ASSERT(root); + // first we translate the QueryNode of the subquery into a logical plan + // note that we do not plan nested subqueries yet + auto sub_binder = Binder::CreateBinder(context, this); + sub_binder->is_outside_flattened = false; + auto subquery_root = sub_binder->CreatePlan(*expr.subquery); + D_ASSERT(subquery_root); + + // now we actually flatten the subquery + auto plan = std::move(subquery_root); + + unique_ptr result_expression; + if (!expr.IsCorrelated()) { + result_expression = PlanUncorrelatedSubquery(*this, expr, root, std::move(plan)); + } else { + result_expression = PlanCorrelatedSubquery(*this, expr, root, std::move(plan)); + } + // finally, we recursively plan the nested subqueries (if there are any) + if (sub_binder->has_unplanned_dependent_joins) { + RecursiveDependentJoinPlanner plan(*this); + plan.VisitOperator(*root); + } + return result_expression; +} + +void Binder::PlanSubqueries(unique_ptr &expr_ptr, unique_ptr &root) { + if (!expr_ptr) { + return; + } + auto &expr = *expr_ptr; + // first visit the children of the node, if any + ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &expr) { PlanSubqueries(expr, root); }); + + // check if this is a subquery node + if (expr.expression_class == ExpressionClass::BOUND_SUBQUERY) { + auto &subquery = expr.Cast(); + // subquery node! plan it + if (subquery.IsCorrelated() && !is_outside_flattened) { + // detected a nested correlated subquery + // we don't plan it yet here, we are currently planning a subquery + // nested subqueries will only be planned AFTER the current subquery has been flattened entirely + has_unplanned_dependent_joins = true; + return; + } + expr_ptr = PlanSubquery(subquery, root); + } +} + +unique_ptr Binder::PlanLateralJoin(unique_ptr left, unique_ptr right, + vector &correlated_columns, + JoinType join_type, unique_ptr condition) { + // scan the right operator for correlated columns + // correlated LATERAL JOIN + vector conditions; + vector> arbitrary_expressions; + if (condition) { + // extract join conditions, if there are any + LogicalComparisonJoin::ExtractJoinConditions(context, join_type, left, right, std::move(condition), conditions, + arbitrary_expressions); + } + + auto perform_delim = PerformDuplicateElimination(*this, correlated_columns); + auto delim_join = CreateDuplicateEliminatedJoin(correlated_columns, join_type, std::move(left), perform_delim); + + FlattenDependentJoins flatten(*this, correlated_columns, perform_delim); + + // first we check which logical operators have correlated expressions in the first place + flatten.DetectCorrelatedExpressions(right.get(), true); + // now we push the dependent join down + auto dependent_join = flatten.PushDownDependentJoin(std::move(right)); + + // now the dependent join is fully eliminated + // we only need to create the join conditions between the LHS and the RHS + // fetch the set of columns + auto plan_columns = dependent_join->GetColumnBindings(); + + // now create the join conditions + // start off with the conditions that were passed in (if any) + D_ASSERT(delim_join->conditions.empty()); + delim_join->conditions = std::move(conditions); + // then add the delim join conditions + CreateDelimJoinConditions(*delim_join, correlated_columns, plan_columns, flatten.delim_offset, perform_delim); + delim_join->AddChild(std::move(dependent_join)); + + // check if there are any arbitrary expressions left + if (!arbitrary_expressions.empty()) { + // we can only evaluate scalar arbitrary expressions for inner joins + if (join_type != JoinType::INNER) { + throw BinderException( + "Join condition for non-inner LATERAL JOIN must be a comparison between the left and right side"); + } + auto filter = make_uniq(); + filter->expressions = std::move(arbitrary_expressions); + filter->AddChild(std::move(delim_join)); + return std::move(filter); + } + return std::move(delim_join); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_attach.cpp b/src/duckdb/src/planner/binder/statement/bind_attach.cpp new file mode 100644 index 00000000..63921911 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_attach.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/attach_statement.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/planner/tableref/bound_table_function.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(AttachStatement &stmt) { + BoundStatement result; + result.types = {LogicalType::BOOLEAN}; + result.names = {"Success"}; + + result.plan = make_uniq(LogicalOperatorType::LOGICAL_ATTACH, std::move(stmt.info)); + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_call.cpp b/src/duckdb/src/planner/binder/statement/bind_call.cpp new file mode 100644 index 00000000..8f910bb9 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_call.cpp @@ -0,0 +1,31 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/call_statement.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/planner/tableref/bound_table_function.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(CallStatement &stmt) { + BoundStatement result; + + TableFunctionRef ref; + ref.function = std::move(stmt.function); + + auto bound_func = Bind(ref); + auto &bound_table_func = bound_func->Cast(); + ; + auto &get = bound_table_func.get->Cast(); + D_ASSERT(get.returned_types.size() > 0); + for (idx_t i = 0; i < get.returned_types.size(); i++) { + get.column_ids.push_back(i); + } + + result.types = get.returned_types; + result.names = get.names; + result.plan = CreatePlan(*bound_func); + properties.return_type = StatementReturnType::QUERY_RESULT; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_copy.cpp b/src/duckdb/src/planner/binder/statement/bind_copy.cpp new file mode 100644 index 00000000..35cc798e --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_copy.cpp @@ -0,0 +1,260 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/common/bind_helpers.hpp" +#include "duckdb/common/filename_pattern.hpp" +#include "duckdb/common/local_file_system.hpp" +#include "duckdb/execution/operator/scan/csv/parallel_csv_reader.hpp" +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_copy_to_file.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_insert.hpp" + +#include + +namespace duckdb { + +vector GetUniqueNames(const vector &original_names) { + unordered_set name_set; + vector unique_names; + unique_names.reserve(original_names.size()); + + for (auto &name : original_names) { + auto insert_result = name_set.insert(name); + if (insert_result.second == false) { + // Could not be inserted, name already exists + idx_t index = 1; + string postfixed_name; + while (true) { + postfixed_name = StringUtil::Format("%s:%d", name, index); + auto res = name_set.insert(postfixed_name); + if (!res.second) { + index++; + continue; + } + break; + } + unique_names.push_back(postfixed_name); + } else { + unique_names.push_back(name); + } + } + return unique_names; +} + +BoundStatement Binder::BindCopyTo(CopyStatement &stmt) { + // COPY TO a file + auto &config = DBConfig::GetConfig(context); + if (!config.options.enable_external_access) { + throw PermissionException("COPY TO is disabled by configuration"); + } + BoundStatement result; + result.types = {LogicalType::BIGINT}; + result.names = {"Count"}; + + // lookup the format in the catalog + auto ©_function = + Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->format); + if (copy_function.function.plan) { + // plan rewrite COPY TO + return copy_function.function.plan(*this, stmt); + } + + // bind the select statement + auto select_node = Bind(*stmt.select_statement); + + if (!copy_function.function.copy_to_bind) { + throw NotImplementedException("COPY TO is not supported for FORMAT \"%s\"", stmt.info->format); + } + bool use_tmp_file = true; + bool overwrite_or_ignore = false; + FilenamePattern filename_pattern; + bool user_set_use_tmp_file = false; + bool per_thread_output = false; + vector partition_cols; + + auto original_options = stmt.info->options; + stmt.info->options.clear(); + + for (auto &option : original_options) { + auto loption = StringUtil::Lower(option.first); + if (loption == "use_tmp_file") { + use_tmp_file = + option.second.empty() || option.second[0].CastAs(context, LogicalType::BOOLEAN).GetValue(); + user_set_use_tmp_file = true; + continue; + } + if (loption == "overwrite_or_ignore") { + overwrite_or_ignore = + option.second.empty() || option.second[0].CastAs(context, LogicalType::BOOLEAN).GetValue(); + continue; + } + if (loption == "filename_pattern") { + if (option.second.empty()) { + throw IOException("FILENAME_PATTERN cannot be empty"); + } + filename_pattern.SetFilenamePattern( + option.second[0].CastAs(context, LogicalType::VARCHAR).GetValue()); + continue; + } + + if (loption == "per_thread_output") { + per_thread_output = + option.second.empty() || option.second[0].CastAs(context, LogicalType::BOOLEAN).GetValue(); + continue; + } + if (loption == "partition_by") { + auto converted = ConvertVectorToValue(std::move(option.second)); + partition_cols = ParseColumnsOrdered(converted, select_node.names, loption); + continue; + } + stmt.info->options[option.first] = option.second; + } + if (user_set_use_tmp_file && per_thread_output) { + throw NotImplementedException("Can't combine USE_TMP_FILE and PER_THREAD_OUTPUT for COPY"); + } + if (user_set_use_tmp_file && !partition_cols.empty()) { + throw NotImplementedException("Can't combine USE_TMP_FILE and PARTITION_BY for COPY"); + } + if (per_thread_output && !partition_cols.empty()) { + throw NotImplementedException("Can't combine PER_THREAD_OUTPUT and PARTITION_BY for COPY"); + } + bool is_remote_file = config.file_system->IsRemoteFile(stmt.info->file_path); + if (is_remote_file) { + use_tmp_file = false; + } else { + bool is_file_and_exists = config.file_system->FileExists(stmt.info->file_path); + bool is_stdout = stmt.info->file_path == "/dev/stdout"; + if (!user_set_use_tmp_file) { + use_tmp_file = is_file_and_exists && !per_thread_output && partition_cols.empty() && !is_stdout; + } + } + + auto unique_column_names = GetUniqueNames(select_node.names); + + auto function_data = + copy_function.function.copy_to_bind(context, *stmt.info, unique_column_names, select_node.types); + // now create the copy information + auto copy = make_uniq(copy_function.function, std::move(function_data)); + copy->file_path = stmt.info->file_path; + copy->use_tmp_file = use_tmp_file; + copy->overwrite_or_ignore = overwrite_or_ignore; + copy->filename_pattern = filename_pattern; + copy->per_thread_output = per_thread_output; + copy->partition_output = !partition_cols.empty(); + copy->partition_columns = std::move(partition_cols); + + copy->names = unique_column_names; + copy->expected_types = select_node.types; + + copy->AddChild(std::move(select_node.plan)); + + result.plan = std::move(copy); + + return result; +} + +BoundStatement Binder::BindCopyFrom(CopyStatement &stmt) { + auto &config = DBConfig::GetConfig(context); + if (!config.options.enable_external_access) { + throw PermissionException("COPY FROM is disabled by configuration"); + } + BoundStatement result; + result.types = {LogicalType::BIGINT}; + result.names = {"Count"}; + + if (stmt.info->table.empty()) { + throw ParserException("COPY FROM requires a table name to be specified"); + } + // COPY FROM a file + // generate an insert statement for the the to-be-inserted table + InsertStatement insert; + insert.table = stmt.info->table; + insert.schema = stmt.info->schema; + insert.catalog = stmt.info->catalog; + insert.columns = stmt.info->select_list; + + // bind the insert statement to the base table + auto insert_statement = Bind(insert); + D_ASSERT(insert_statement.plan->type == LogicalOperatorType::LOGICAL_INSERT); + + auto &bound_insert = insert_statement.plan->Cast(); + + // lookup the format in the catalog + auto &catalog = Catalog::GetSystemCatalog(context); + auto ©_function = catalog.GetEntry(context, DEFAULT_SCHEMA, stmt.info->format); + if (!copy_function.function.copy_from_bind) { + throw NotImplementedException("COPY FROM is not supported for FORMAT \"%s\"", stmt.info->format); + } + // lookup the table to copy into + BindSchemaOrCatalog(stmt.info->catalog, stmt.info->schema); + auto &table = + Catalog::GetEntry(context, stmt.info->catalog, stmt.info->schema, stmt.info->table); + vector expected_names; + if (!bound_insert.column_index_map.empty()) { + expected_names.resize(bound_insert.expected_types.size()); + for (auto &col : table.GetColumns().Physical()) { + auto i = col.Physical(); + if (bound_insert.column_index_map[i] != DConstants::INVALID_INDEX) { + expected_names[bound_insert.column_index_map[i]] = col.Name(); + } + } + } else { + expected_names.reserve(bound_insert.expected_types.size()); + for (auto &col : table.GetColumns().Physical()) { + expected_names.push_back(col.Name()); + } + } + + auto function_data = + copy_function.function.copy_from_bind(context, *stmt.info, expected_names, bound_insert.expected_types); + auto get = make_uniq(GenerateTableIndex(), copy_function.function.copy_from_function, + std::move(function_data), bound_insert.expected_types, expected_names); + for (idx_t i = 0; i < bound_insert.expected_types.size(); i++) { + get->column_ids.push_back(i); + } + insert_statement.plan->children.push_back(std::move(get)); + result.plan = std::move(insert_statement.plan); + return result; +} + +BoundStatement Binder::Bind(CopyStatement &stmt) { + if (!stmt.info->is_from && !stmt.select_statement) { + // copy table into file without a query + // generate SELECT * FROM table; + auto ref = make_uniq(); + ref->catalog_name = stmt.info->catalog; + ref->schema_name = stmt.info->schema; + ref->table_name = stmt.info->table; + + auto statement = make_uniq(); + statement->from_table = std::move(ref); + if (!stmt.info->select_list.empty()) { + for (auto &name : stmt.info->select_list) { + statement->select_list.push_back(make_uniq(name)); + } + } else { + statement->select_list.push_back(make_uniq()); + } + stmt.select_statement = std::move(statement); + } + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::CHANGED_ROWS; + if (stmt.info->is_from) { + return BindCopyFrom(stmt); + } else { + return BindCopyTo(stmt); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp new file mode 100644 index 00000000..0916da30 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -0,0 +1,628 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/expression_binder/index_binder.hpp" +#include "duckdb/planner/expression_binder/select_binder.hpp" +#include "duckdb/planner/operator/logical_create.hpp" +#include "duckdb/planner/operator/logical_create_index.hpp" +#include "duckdb/planner/operator/logical_create_table.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/tableref/bound_basetableref.hpp" +#include "duckdb/parser/constraints/foreign_key_constraint.hpp" +#include "duckdb/function/scalar_macro_function.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/storage_extension.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/parser/constraints/unique_constraint.hpp" +#include "duckdb/parser/constraints/list.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/function/table/table_scan.hpp" + +namespace duckdb { + +void Binder::BindSchemaOrCatalog(ClientContext &context, string &catalog, string &schema) { + if (catalog.empty() && !schema.empty()) { + // schema is specified - but catalog is not + // try searching for the catalog instead + auto &db_manager = DatabaseManager::Get(context); + auto database = db_manager.GetDatabase(context, schema); + if (database) { + // we have a database with this name + // check if there is a schema + auto schema_obj = Catalog::GetSchema(context, INVALID_CATALOG, schema, OnEntryNotFound::RETURN_NULL); + if (schema_obj) { + auto &attached = schema_obj->catalog.GetAttached(); + throw BinderException( + "Ambiguous reference to catalog or schema \"%s\" - use a fully qualified path like \"%s.%s\"", + schema, attached.GetName(), schema); + } + catalog = schema; + schema = string(); + } + } +} + +void Binder::BindSchemaOrCatalog(string &catalog, string &schema) { + BindSchemaOrCatalog(context, catalog, schema); +} + +SchemaCatalogEntry &Binder::BindSchema(CreateInfo &info) { + BindSchemaOrCatalog(info.catalog, info.schema); + if (IsInvalidCatalog(info.catalog) && info.temporary) { + info.catalog = TEMP_CATALOG; + } + auto &search_path = ClientData::Get(context).catalog_search_path; + if (IsInvalidCatalog(info.catalog) && IsInvalidSchema(info.schema)) { + auto &default_entry = search_path->GetDefault(); + info.catalog = default_entry.catalog; + info.schema = default_entry.schema; + } else if (IsInvalidSchema(info.schema)) { + info.schema = search_path->GetDefaultSchema(info.catalog); + } else if (IsInvalidCatalog(info.catalog)) { + info.catalog = search_path->GetDefaultCatalog(info.schema); + } + if (IsInvalidCatalog(info.catalog)) { + info.catalog = DatabaseManager::GetDefaultDatabase(context); + } + if (!info.temporary) { + // non-temporary create: not read only + if (info.catalog == TEMP_CATALOG) { + throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", TEMP_CATALOG); + } + } else { + if (info.catalog != TEMP_CATALOG) { + throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", TEMP_CATALOG); + } + } + // fetch the schema in which we want to create the object + auto &schema_obj = Catalog::GetSchema(context, info.catalog, info.schema); + D_ASSERT(schema_obj.type == CatalogType::SCHEMA_ENTRY); + info.schema = schema_obj.name; + if (!info.temporary) { + properties.modified_databases.insert(schema_obj.catalog.GetName()); + } + return schema_obj; +} + +SchemaCatalogEntry &Binder::BindCreateSchema(CreateInfo &info) { + auto &schema = BindSchema(info); + if (schema.catalog.IsSystemCatalog()) { + throw BinderException("Cannot create entry in system catalog"); + } + return schema; +} + +void Binder::BindCreateViewInfo(CreateViewInfo &base) { + // bind the view as if it were a query so we can catch errors + // note that we bind the original, and replace the original with a copy + auto view_binder = Binder::CreateBinder(context); + view_binder->can_contain_nulls = true; + + auto copy = base.query->Copy(); + auto query_node = view_binder->Bind(*base.query); + base.query = unique_ptr_cast(std::move(copy)); + if (base.aliases.size() > query_node.names.size()) { + throw BinderException("More VIEW aliases than columns in query result"); + } + // fill up the aliases with the remaining names of the bound query + base.aliases.reserve(query_node.names.size()); + for (idx_t i = base.aliases.size(); i < query_node.names.size(); i++) { + base.aliases.push_back(query_node.names[i]); + } + base.types = query_node.types; +} + +SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { + auto &base = info.Cast(); + auto &scalar_function = base.function->Cast(); + + if (scalar_function.expression->HasParameter()) { + throw BinderException("Parameter expressions within macro's are not supported!"); + } + + // create macro binding in order to bind the function + vector dummy_types; + vector dummy_names; + // positional parameters + for (idx_t i = 0; i < base.function->parameters.size(); i++) { + auto param = base.function->parameters[i]->Cast(); + if (param.IsQualified()) { + throw BinderException("Invalid parameter name '%s': must be unqualified", param.ToString()); + } + dummy_types.emplace_back(LogicalType::SQLNULL); + dummy_names.push_back(param.GetColumnName()); + } + // default parameters + for (auto it = base.function->default_parameters.begin(); it != base.function->default_parameters.end(); it++) { + auto &val = it->second->Cast(); + dummy_types.push_back(val.value.type()); + dummy_names.push_back(it->first); + } + auto this_macro_binding = make_uniq(dummy_types, dummy_names, base.name); + macro_binding = this_macro_binding.get(); + ExpressionBinder::QualifyColumnNames(*this, scalar_function.expression); + + // create a copy of the expression because we do not want to alter the original + auto expression = scalar_function.expression->Copy(); + + // bind it to verify the function was defined correctly + string error; + auto sel_node = make_uniq(); + auto group_info = make_uniq(); + SelectBinder binder(*this, context, *sel_node, *group_info); + error = binder.Bind(expression, 0, false); + + if (!error.empty()) { + throw BinderException(error); + } + + return BindCreateSchema(info); +} + +void Binder::BindLogicalType(ClientContext &context, LogicalType &type, optional_ptr catalog, + const string &schema) { + if (type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::MAP) { + auto child_type = ListType::GetChildType(type); + BindLogicalType(context, child_type, catalog, schema); + auto alias = type.GetAlias(); + if (type.id() == LogicalTypeId::LIST) { + type = LogicalType::LIST(child_type); + } else { + D_ASSERT(child_type.id() == LogicalTypeId::STRUCT); // map must be list of structs + type = LogicalType::MAP(child_type); + } + + type.SetAlias(alias); + } else if (type.id() == LogicalTypeId::STRUCT) { + auto child_types = StructType::GetChildTypes(type); + for (auto &child_type : child_types) { + BindLogicalType(context, child_type.second, catalog, schema); + } + // Generate new Struct Type + auto alias = type.GetAlias(); + type = LogicalType::STRUCT(child_types); + type.SetAlias(alias); + } else if (type.id() == LogicalTypeId::UNION) { + auto member_types = UnionType::CopyMemberTypes(type); + for (auto &member_type : member_types) { + BindLogicalType(context, member_type.second, catalog, schema); + } + // Generate new Union Type + auto alias = type.GetAlias(); + type = LogicalType::UNION(member_types); + type.SetAlias(alias); + } else if (type.id() == LogicalTypeId::USER) { + auto user_type_name = UserType::GetTypeName(type); + if (catalog) { + // The search order is: + // 1) In the same schema as the table + // 2) In the same catalog + // 3) System catalog + type = catalog->GetType(context, schema, user_type_name, OnEntryNotFound::RETURN_NULL); + + if (type.id() == LogicalTypeId::INVALID) { + type = catalog->GetType(context, INVALID_SCHEMA, user_type_name, OnEntryNotFound::RETURN_NULL); + } + + if (type.id() == LogicalTypeId::INVALID) { + type = Catalog::GetType(context, INVALID_CATALOG, schema, user_type_name); + } + } else { + type = Catalog::GetType(context, INVALID_CATALOG, schema, user_type_name); + } + BindLogicalType(context, type, catalog, schema); + } +} + +static void FindMatchingPrimaryKeyColumns(const ColumnList &columns, const vector> &constraints, + ForeignKeyConstraint &fk) { + // find the matching primary key constraint + bool found_constraint = false; + // if no columns are defined, we will automatically try to bind to the primary key + bool find_primary_key = fk.pk_columns.empty(); + for (auto &constr : constraints) { + if (constr->type != ConstraintType::UNIQUE) { + continue; + } + auto &unique = constr->Cast(); + if (find_primary_key && !unique.is_primary_key) { + continue; + } + found_constraint = true; + + vector pk_names; + if (unique.index.index != DConstants::INVALID_INDEX) { + pk_names.push_back(columns.GetColumn(LogicalIndex(unique.index)).Name()); + } else { + pk_names = unique.columns; + } + if (find_primary_key) { + // found matching primary key + if (pk_names.size() != fk.fk_columns.size()) { + auto pk_name_str = StringUtil::Join(pk_names, ","); + auto fk_name_str = StringUtil::Join(fk.fk_columns, ","); + throw BinderException( + "Failed to create foreign key: number of referencing (%s) and referenced columns (%s) differ", + fk_name_str, pk_name_str); + } + fk.pk_columns = pk_names; + return; + } + if (pk_names.size() != fk.fk_columns.size()) { + // the number of referencing and referenced columns for foreign keys must be the same + continue; + } + bool equals = true; + for (idx_t i = 0; i < fk.pk_columns.size(); i++) { + if (!StringUtil::CIEquals(fk.pk_columns[i], pk_names[i])) { + equals = false; + break; + } + } + if (!equals) { + continue; + } + // found match + return; + } + // no match found! examine why + if (!found_constraint) { + // no unique constraint or primary key + string search_term = find_primary_key ? "primary key" : "primary key or unique constraint"; + throw BinderException("Failed to create foreign key: there is no %s for referenced table \"%s\"", search_term, + fk.info.table); + } + // check if all the columns exist + for (auto &name : fk.pk_columns) { + bool found = columns.ColumnExists(name); + if (!found) { + throw BinderException( + "Failed to create foreign key: referenced table \"%s\" does not have a column named \"%s\"", + fk.info.table, name); + } + } + auto fk_names = StringUtil::Join(fk.pk_columns, ","); + throw BinderException("Failed to create foreign key: referenced table \"%s\" does not have a primary key or unique " + "constraint on the columns %s", + fk.info.table, fk_names); +} + +static void FindForeignKeyIndexes(const ColumnList &columns, const vector &names, + vector &indexes) { + D_ASSERT(indexes.empty()); + D_ASSERT(!names.empty()); + for (auto &name : names) { + if (!columns.ColumnExists(name)) { + throw BinderException("column \"%s\" named in key does not exist", name); + } + auto &column = columns.GetColumn(name); + if (column.Generated()) { + throw BinderException("Failed to create foreign key: referenced column \"%s\" is a generated column", + column.Name()); + } + indexes.push_back(column.Physical()); + } +} + +static void CheckForeignKeyTypes(const ColumnList &pk_columns, const ColumnList &fk_columns, ForeignKeyConstraint &fk) { + D_ASSERT(fk.info.pk_keys.size() == fk.info.fk_keys.size()); + for (idx_t c_idx = 0; c_idx < fk.info.pk_keys.size(); c_idx++) { + auto &pk_col = pk_columns.GetColumn(fk.info.pk_keys[c_idx]); + auto &fk_col = fk_columns.GetColumn(fk.info.fk_keys[c_idx]); + if (pk_col.Type() != fk_col.Type()) { + throw BinderException("Failed to create foreign key: incompatible types between column \"%s\" (\"%s\") and " + "column \"%s\" (\"%s\")", + pk_col.Name(), pk_col.Type().ToString(), fk_col.Name(), fk_col.Type().ToString()); + } + } +} + +void ExpressionContainsGeneratedColumn(const ParsedExpression &expr, const unordered_set &gcols, + bool &contains_gcol) { + if (contains_gcol) { + return; + } + if (expr.type == ExpressionType::COLUMN_REF) { + auto &column_ref = expr.Cast(); + auto &name = column_ref.GetColumnName(); + if (gcols.count(name)) { + contains_gcol = true; + return; + } + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](const ParsedExpression &child) { ExpressionContainsGeneratedColumn(child, gcols, contains_gcol); }); +} + +static bool AnyConstraintReferencesGeneratedColumn(CreateTableInfo &table_info) { + unordered_set generated_columns; + for (auto &col : table_info.columns.Logical()) { + if (!col.Generated()) { + continue; + } + generated_columns.insert(col.Name()); + } + if (generated_columns.empty()) { + return false; + } + + for (auto &constr : table_info.constraints) { + switch (constr->type) { + case ConstraintType::CHECK: { + auto &constraint = constr->Cast(); + auto &expr = constraint.expression; + bool contains_generated_column = false; + ExpressionContainsGeneratedColumn(*expr, generated_columns, contains_generated_column); + if (contains_generated_column) { + return true; + } + break; + } + case ConstraintType::NOT_NULL: { + auto &constraint = constr->Cast(); + if (table_info.columns.GetColumn(constraint.index).Generated()) { + return true; + } + break; + } + case ConstraintType::UNIQUE: { + auto &constraint = constr->Cast(); + auto index = constraint.index; + if (index.index == DConstants::INVALID_INDEX) { + for (auto &col : constraint.columns) { + if (generated_columns.count(col)) { + return true; + } + } + } else { + if (table_info.columns.GetColumn(index).Generated()) { + return true; + } + } + break; + } + case ConstraintType::FOREIGN_KEY: { + // If it contained a generated column, an exception would have been thrown inside AddDataTableIndex earlier + break; + } + default: { + throw NotImplementedException("ConstraintType not implemented"); + } + } + } + return false; +} + +unique_ptr DuckCatalog::BindCreateIndex(Binder &binder, CreateStatement &stmt, + TableCatalogEntry &table, unique_ptr plan) { + D_ASSERT(plan->type == LogicalOperatorType::LOGICAL_GET); + auto &base = stmt.info->Cast(); + + auto &get = plan->Cast(); + // bind the index expressions + IndexBinder index_binder(binder, binder.context); + vector> expressions; + expressions.reserve(base.expressions.size()); + for (auto &expr : base.expressions) { + expressions.push_back(index_binder.Bind(expr)); + } + + auto create_index_info = unique_ptr_cast(std::move(stmt.info)); + for (auto &column_id : get.column_ids) { + if (column_id == COLUMN_IDENTIFIER_ROW_ID) { + throw BinderException("Cannot create an index on the rowid!"); + } + create_index_info->scan_types.push_back(get.returned_types[column_id]); + } + create_index_info->scan_types.emplace_back(LogicalType::ROW_TYPE); + create_index_info->names = get.names; + create_index_info->column_ids = get.column_ids; + auto &bind_data = get.bind_data->Cast(); + bind_data.is_create_index = true; + get.column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); + + // the logical CREATE INDEX also needs all fields to scan the referenced table + auto result = make_uniq(std::move(create_index_info), std::move(expressions), table); + result->children.push_back(std::move(plan)); + return std::move(result); +} + +BoundStatement Binder::Bind(CreateStatement &stmt) { + BoundStatement result; + result.names = {"Count"}; + result.types = {LogicalType::BIGINT}; + + auto catalog_type = stmt.info->type; + switch (catalog_type) { + case CatalogType::SCHEMA_ENTRY: + result.plan = make_uniq(LogicalOperatorType::LOGICAL_CREATE_SCHEMA, std::move(stmt.info)); + break; + case CatalogType::VIEW_ENTRY: { + auto &base = stmt.info->Cast(); + // bind the schema + auto &schema = BindCreateSchema(*stmt.info); + BindCreateViewInfo(base); + result.plan = make_uniq(LogicalOperatorType::LOGICAL_CREATE_VIEW, std::move(stmt.info), &schema); + break; + } + case CatalogType::SEQUENCE_ENTRY: { + auto &schema = BindCreateSchema(*stmt.info); + result.plan = + make_uniq(LogicalOperatorType::LOGICAL_CREATE_SEQUENCE, std::move(stmt.info), &schema); + break; + } + case CatalogType::TABLE_MACRO_ENTRY: { + auto &schema = BindCreateSchema(*stmt.info); + result.plan = + make_uniq(LogicalOperatorType::LOGICAL_CREATE_MACRO, std::move(stmt.info), &schema); + break; + } + case CatalogType::MACRO_ENTRY: { + auto &schema = BindCreateFunctionInfo(*stmt.info); + result.plan = + make_uniq(LogicalOperatorType::LOGICAL_CREATE_MACRO, std::move(stmt.info), &schema); + break; + } + case CatalogType::INDEX_ENTRY: { + auto &base = stmt.info->Cast(); + + // visit the table reference + auto table_ref = make_uniq(); + table_ref->catalog_name = base.catalog; + table_ref->schema_name = base.schema; + table_ref->table_name = base.table; + + auto bound_table = Bind(*table_ref); + if (bound_table->type != TableReferenceType::BASE_TABLE) { + throw BinderException("Can only create an index over a base table!"); + } + auto &table_binding = bound_table->Cast(); + auto &table = table_binding.table; + if (table.temporary) { + stmt.info->temporary = true; + } + // create a plan over the bound table + auto plan = CreatePlan(*bound_table); + if (plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Cannot create index on a view!"); + } + result.plan = table.catalog.BindCreateIndex(*this, stmt, table, std::move(plan)); + break; + } + case CatalogType::TABLE_ENTRY: { + auto &create_info = stmt.info->Cast(); + // If there is a foreign key constraint, resolve primary key column's index from primary key column's name + reference_set_t fk_schemas; + for (idx_t i = 0; i < create_info.constraints.size(); i++) { + auto &cond = create_info.constraints[i]; + if (cond->type != ConstraintType::FOREIGN_KEY) { + continue; + } + auto &fk = cond->Cast(); + if (fk.info.type != ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { + continue; + } + D_ASSERT(fk.info.pk_keys.empty()); + D_ASSERT(fk.info.fk_keys.empty()); + FindForeignKeyIndexes(create_info.columns, fk.fk_columns, fk.info.fk_keys); + if (StringUtil::CIEquals(create_info.table, fk.info.table)) { + // self-referential foreign key constraint + fk.info.type = ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE; + FindMatchingPrimaryKeyColumns(create_info.columns, create_info.constraints, fk); + FindForeignKeyIndexes(create_info.columns, fk.pk_columns, fk.info.pk_keys); + CheckForeignKeyTypes(create_info.columns, create_info.columns, fk); + } else { + // have to resolve referenced table + auto &pk_table_entry_ptr = + Catalog::GetEntry(context, INVALID_CATALOG, fk.info.schema, fk.info.table); + fk_schemas.insert(pk_table_entry_ptr.schema); + FindMatchingPrimaryKeyColumns(pk_table_entry_ptr.GetColumns(), pk_table_entry_ptr.GetConstraints(), fk); + FindForeignKeyIndexes(pk_table_entry_ptr.GetColumns(), fk.pk_columns, fk.info.pk_keys); + CheckForeignKeyTypes(pk_table_entry_ptr.GetColumns(), create_info.columns, fk); + auto &storage = pk_table_entry_ptr.GetStorage(); + auto index = storage.info->indexes.FindForeignKeyIndex(fk.info.pk_keys, + ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE); + if (!index) { + auto fk_column_names = StringUtil::Join(fk.pk_columns, ","); + throw BinderException("Failed to create foreign key on %s(%s): no UNIQUE or PRIMARY KEY constraint " + "present on these columns", + pk_table_entry_ptr.name, fk_column_names); + } + } + D_ASSERT(fk.info.pk_keys.size() == fk.info.fk_keys.size()); + D_ASSERT(fk.info.pk_keys.size() == fk.pk_columns.size()); + D_ASSERT(fk.info.fk_keys.size() == fk.fk_columns.size()); + } + if (AnyConstraintReferencesGeneratedColumn(create_info)) { + throw BinderException("Constraints on generated columns are not supported yet"); + } + auto bound_info = BindCreateTableInfo(std::move(stmt.info)); + auto root = std::move(bound_info->query); + for (auto &fk_schema : fk_schemas) { + if (&fk_schema.get() != &bound_info->schema) { + throw BinderException("Creating foreign keys across different schemas or catalogs is not supported"); + } + } + + // create the logical operator + auto &schema = bound_info->schema; + auto create_table = make_uniq(schema, std::move(bound_info)); + if (root) { + // CREATE TABLE AS + properties.return_type = StatementReturnType::CHANGED_ROWS; + create_table->children.push_back(std::move(root)); + } + result.plan = std::move(create_table); + break; + } + case CatalogType::TYPE_ENTRY: { + auto &schema = BindCreateSchema(*stmt.info); + auto &create_type_info = stmt.info->Cast(); + result.plan = make_uniq(LogicalOperatorType::LOGICAL_CREATE_TYPE, std::move(stmt.info), &schema); + if (create_type_info.query) { + // CREATE TYPE mood AS ENUM (SELECT 'happy') + auto query_obj = Bind(*create_type_info.query); + auto query = std::move(query_obj.plan); + create_type_info.query.reset(); + + auto &sql_types = query_obj.types; + if (sql_types.size() != 1) { + // add cast expression? + throw BinderException("The query must return a single column"); + } + if (sql_types[0].id() != LogicalType::VARCHAR) { + // push a projection casting to varchar + vector> select_list; + auto ref = make_uniq(sql_types[0], query->GetColumnBindings()[0]); + auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(ref), LogicalType::VARCHAR); + select_list.push_back(std::move(cast_expr)); + auto proj = make_uniq(GenerateTableIndex(), std::move(select_list)); + proj->AddChild(std::move(query)); + query = std::move(proj); + } + + result.plan->AddChild(std::move(query)); + } else if (create_type_info.type.id() == LogicalTypeId::USER) { + // two cases: + // 1: create a type with a non-existant type as source, catalog.GetType(...) will throw exception. + // 2: create a type alias with a custom type. + // eg. CREATE TYPE a AS INT; CREATE TYPE b AS a; + // We set b to be an alias for the underlying type of a + auto inner_type = Catalog::GetType(context, schema.catalog.GetName(), schema.name, + UserType::GetTypeName(create_type_info.type)); + inner_type.SetAlias(create_type_info.name); + create_type_info.type = inner_type; + } + break; + } + default: + throw Exception("Unrecognized type!"); + } + properties.return_type = StatementReturnType::NOTHING; + properties.allow_stream_result = false; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp new file mode 100644 index 00000000..eff11c73 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp @@ -0,0 +1,318 @@ +#include "duckdb/parser/constraints/list.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/constraints/list.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression_binder/check_binder.hpp" +#include "duckdb/planner/expression_binder/constant_binder.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/common/index_map.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression_binder/index_binder.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" + +#include + +namespace duckdb { + +static void CreateColumnDependencyManager(BoundCreateTableInfo &info) { + auto &base = info.base->Cast(); + for (auto &col : base.columns.Logical()) { + if (!col.Generated()) { + continue; + } + info.column_dependency_manager.AddGeneratedColumn(col, base.columns); + } +} + +static void BindCheckConstraint(Binder &binder, BoundCreateTableInfo &info, const unique_ptr &cond) { + auto &base = info.base->Cast(); + + auto bound_constraint = make_uniq(); + // check constraint: bind the expression + CheckBinder check_binder(binder, binder.context, base.table, base.columns, bound_constraint->bound_columns); + auto &check = cond->Cast(); + // create a copy of the unbound expression because the binding destroys the constraint + auto unbound_expression = check.expression->Copy(); + // now bind the constraint and create a new BoundCheckConstraint + bound_constraint->expression = check_binder.Bind(check.expression); + info.bound_constraints.push_back(std::move(bound_constraint)); + // move the unbound constraint back into the original check expression + check.expression = std::move(unbound_expression); +} + +static void BindConstraints(Binder &binder, BoundCreateTableInfo &info) { + auto &base = info.base->Cast(); + + bool has_primary_key = false; + logical_index_set_t not_null_columns; + vector primary_keys; + for (idx_t i = 0; i < base.constraints.size(); i++) { + auto &cond = base.constraints[i]; + switch (cond->type) { + case ConstraintType::CHECK: { + BindCheckConstraint(binder, info, cond); + break; + } + case ConstraintType::NOT_NULL: { + auto ¬_null = cond->Cast(); + auto &col = base.columns.GetColumn(LogicalIndex(not_null.index)); + info.bound_constraints.push_back(make_uniq(PhysicalIndex(col.StorageOid()))); + not_null_columns.insert(not_null.index); + break; + } + case ConstraintType::UNIQUE: { + auto &unique = cond->Cast(); + // have to resolve columns of the unique constraint + vector keys; + logical_index_set_t key_set; + if (unique.index.index != DConstants::INVALID_INDEX) { + D_ASSERT(unique.index.index < base.columns.LogicalColumnCount()); + // unique constraint is given by single index + unique.columns.push_back(base.columns.GetColumn(unique.index).Name()); + keys.push_back(unique.index); + key_set.insert(unique.index); + } else { + // unique constraint is given by list of names + // have to resolve names + D_ASSERT(!unique.columns.empty()); + for (auto &keyname : unique.columns) { + if (!base.columns.ColumnExists(keyname)) { + throw ParserException("column \"%s\" named in key does not exist", keyname); + } + auto &column = base.columns.GetColumn(keyname); + auto column_index = column.Logical(); + if (key_set.find(column_index) != key_set.end()) { + throw ParserException("column \"%s\" appears twice in " + "primary key constraint", + keyname); + } + keys.push_back(column_index); + key_set.insert(column_index); + } + } + + if (unique.is_primary_key) { + // we can only have one primary key per table + if (has_primary_key) { + throw ParserException("table \"%s\" has more than one primary key", base.table); + } + has_primary_key = true; + primary_keys = keys; + } + info.bound_constraints.push_back( + make_uniq(std::move(keys), std::move(key_set), unique.is_primary_key)); + break; + } + case ConstraintType::FOREIGN_KEY: { + auto &fk = cond->Cast(); + D_ASSERT((fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE && !fk.info.pk_keys.empty()) || + (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && !fk.info.pk_keys.empty()) || + fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE); + physical_index_set_t fk_key_set, pk_key_set; + for (idx_t i = 0; i < fk.info.pk_keys.size(); i++) { + if (pk_key_set.find(fk.info.pk_keys[i]) != pk_key_set.end()) { + throw BinderException("Duplicate primary key referenced in FOREIGN KEY constraint"); + } + pk_key_set.insert(fk.info.pk_keys[i]); + } + for (idx_t i = 0; i < fk.info.fk_keys.size(); i++) { + if (fk_key_set.find(fk.info.fk_keys[i]) != fk_key_set.end()) { + throw BinderException("Duplicate key specified in FOREIGN KEY constraint"); + } + fk_key_set.insert(fk.info.fk_keys[i]); + } + info.bound_constraints.push_back( + make_uniq(fk.info, std::move(pk_key_set), std::move(fk_key_set))); + break; + } + default: + throw NotImplementedException("unrecognized constraint type in bind"); + } + } + if (has_primary_key) { + // if there is a primary key index, also create a NOT NULL constraint for each of the columns + for (auto &column_index : primary_keys) { + if (not_null_columns.count(column_index)) { + //! No need to create a NotNullConstraint, it's already present + continue; + } + auto physical_index = base.columns.LogicalToPhysical(column_index); + base.constraints.push_back(make_uniq(column_index)); + info.bound_constraints.push_back(make_uniq(physical_index)); + } + } +} + +void Binder::BindGeneratedColumns(BoundCreateTableInfo &info) { + auto &base = info.base->Cast(); + + vector names; + vector types; + + D_ASSERT(base.type == CatalogType::TABLE_ENTRY); + for (auto &col : base.columns.Logical()) { + names.push_back(col.Name()); + types.push_back(col.Type()); + } + auto table_index = GenerateTableIndex(); + + // Create a new binder because we dont need (or want) these bindings in this scope + auto binder = Binder::CreateBinder(context); + binder->bind_context.AddGenericBinding(table_index, base.table, names, types); + auto expr_binder = ExpressionBinder(*binder, context); + string ignore; + auto table_binding = binder->bind_context.GetBinding(base.table, ignore); + D_ASSERT(table_binding && ignore.empty()); + + auto bind_order = info.column_dependency_manager.GetBindOrder(base.columns); + logical_index_set_t bound_indices; + + while (!bind_order.empty()) { + auto i = bind_order.top(); + bind_order.pop(); + auto &col = base.columns.GetColumnMutable(i); + + //! Already bound this previously + //! This can not be optimized out of the GetBindOrder function + //! These occurrences happen because we need to make sure that ALL dependencies of a column are resolved before + //! it gets resolved + if (bound_indices.count(i)) { + continue; + } + D_ASSERT(col.Generated()); + auto expression = col.GeneratedExpression().Copy(); + + auto bound_expression = expr_binder.Bind(expression); + D_ASSERT(bound_expression); + D_ASSERT(!bound_expression->HasSubquery()); + if (col.Type().id() == LogicalTypeId::ANY) { + // Do this before changing the type, so we know it's the first time the type is set + col.ChangeGeneratedExpressionType(bound_expression->return_type); + col.SetType(bound_expression->return_type); + + // Update the type in the binding, for future expansions + string ignore; + table_binding->types[i.index] = col.Type(); + } + bound_indices.insert(i); + } +} + +void Binder::BindDefaultValues(const ColumnList &columns, vector> &bound_defaults) { + for (auto &column : columns.Physical()) { + unique_ptr bound_default; + if (column.DefaultValue()) { + // we bind a copy of the DEFAULT value because binding is destructive + // and we want to keep the original expression around for serialization + auto default_copy = column.DefaultValue()->Copy(); + ConstantBinder default_binder(*this, context, "DEFAULT value"); + default_binder.target_type = column.Type(); + bound_default = default_binder.Bind(default_copy); + } else { + // no default value specified: push a default value of constant null + bound_default = make_uniq(Value(column.Type())); + } + bound_defaults.push_back(std::move(bound_default)); + } +} + +static void ExtractExpressionDependencies(Expression &expr, DependencyList &dependencies) { + if (expr.type == ExpressionType::BOUND_FUNCTION) { + auto &function = expr.Cast(); + if (function.function.dependency) { + function.function.dependency(function, dependencies); + } + } + ExpressionIterator::EnumerateChildren( + expr, [&](Expression &child) { ExtractExpressionDependencies(child, dependencies); }); +} + +static void ExtractDependencies(BoundCreateTableInfo &info) { + for (auto &default_value : info.bound_defaults) { + if (default_value) { + ExtractExpressionDependencies(*default_value, info.dependencies); + } + } + for (auto &constraint : info.bound_constraints) { + if (constraint->type == ConstraintType::CHECK) { + auto &bound_check = constraint->Cast(); + ExtractExpressionDependencies(*bound_check.expression, info.dependencies); + } + } +} +unique_ptr Binder::BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema) { + auto &base = info->Cast(); + auto result = make_uniq(schema, std::move(info)); + if (base.query) { + // construct the result object + auto query_obj = Bind(*base.query); + base.query.reset(); + result->query = std::move(query_obj.plan); + + // construct the set of columns based on the names and types of the query + auto &names = query_obj.names; + auto &sql_types = query_obj.types; + D_ASSERT(names.size() == sql_types.size()); + base.columns.SetAllowDuplicates(true); + for (idx_t i = 0; i < names.size(); i++) { + base.columns.AddColumn(ColumnDefinition(names[i], sql_types[i])); + } + CreateColumnDependencyManager(*result); + // bind the generated column expressions + BindGeneratedColumns(*result); + } else { + CreateColumnDependencyManager(*result); + // bind the generated column expressions + BindGeneratedColumns(*result); + // bind any constraints + BindConstraints(*this, *result); + // bind the default values + BindDefaultValues(base.columns, result->bound_defaults); + } + // extract dependencies from any default values or CHECK constraints + ExtractDependencies(*result); + + if (base.columns.PhysicalColumnCount() == 0) { + throw BinderException("Creating a table without physical (non-generated) columns is not supported"); + } + // bind collations to detect any unsupported collation errors + for (idx_t i = 0; i < base.columns.PhysicalColumnCount(); i++) { + auto &column = base.columns.GetColumnMutable(PhysicalIndex(i)); + if (column.Type().id() == LogicalTypeId::VARCHAR) { + ExpressionBinder::TestCollation(context, StringType::GetCollation(column.Type())); + } + BindLogicalType(context, column.TypeMutable(), &result->schema.catalog); + } + result->dependencies.VerifyDependencies(schema.catalog, result->Base().table); + properties.allow_stream_result = false; + return result; +} + +unique_ptr Binder::BindCreateTableInfo(unique_ptr info) { + auto &base = info->Cast(); + auto &schema = BindCreateSchema(base); + return BindCreateTableInfo(std::move(info), schema); +} + +vector> Binder::BindCreateIndexExpressions(TableCatalogEntry &table, CreateIndexInfo &info) { + auto index_binder = IndexBinder(*this, this->context, &table, &info); + vector> expressions; + expressions.reserve(info.expressions.size()); + for (auto &expr : info.expressions) { + expressions.push_back(index_binder.Bind(expr)); + } + + return expressions; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_delete.cpp b/src/duckdb/src/planner/binder/statement/bind_delete.cpp new file mode 100644 index 00000000..7ae86228 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_delete.cpp @@ -0,0 +1,98 @@ +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder/where_binder.hpp" +#include "duckdb/planner/expression_binder/returning_binder.hpp" +#include "duckdb/planner/operator/logical_delete.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/tableref/bound_basetableref.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(DeleteStatement &stmt) { + BoundStatement result; + + // visit the table reference + auto bound_table = Bind(*stmt.table); + if (bound_table->type != TableReferenceType::BASE_TABLE) { + throw BinderException("Can only delete from base table!"); + } + auto &table_binding = bound_table->Cast(); + auto &table = table_binding.table; + + auto root = CreatePlan(*bound_table); + auto &get = root->Cast(); + D_ASSERT(root->type == LogicalOperatorType::LOGICAL_GET); + + if (!table.temporary) { + // delete from persistent table: not read only! + properties.modified_databases.insert(table.catalog.GetName()); + } + + // Add CTEs as bindable + AddCTEMap(stmt.cte_map); + + // plan any tables from the various using clauses + if (!stmt.using_clauses.empty()) { + unique_ptr child_operator; + for (auto &using_clause : stmt.using_clauses) { + // bind the using clause + auto using_binder = Binder::CreateBinder(context, this); + auto bound_node = using_binder->Bind(*using_clause); + auto op = CreatePlan(*bound_node); + if (child_operator) { + // already bound a child: create a cross product to unify the two + child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op)); + } else { + child_operator = std::move(op); + } + bind_context.AddContext(std::move(using_binder->bind_context)); + } + if (child_operator) { + root = LogicalCrossProduct::Create(std::move(root), std::move(child_operator)); + } + } + + // project any additional columns required for the condition + unique_ptr condition; + if (stmt.condition) { + WhereBinder binder(*this, context); + condition = binder.Bind(stmt.condition); + + PlanSubqueries(condition, root); + auto filter = make_uniq(std::move(condition)); + filter->AddChild(std::move(root)); + root = std::move(filter); + } + // create the delete node + auto del = make_uniq(table, GenerateTableIndex()); + del->AddChild(std::move(root)); + + // set up the delete expression + del->expressions.push_back(make_uniq( + LogicalType::ROW_TYPE, ColumnBinding(get.table_index, get.column_ids.size()))); + get.column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); + + if (!stmt.returning_list.empty()) { + del->return_chunk = true; + + auto update_table_index = GenerateTableIndex(); + del->table_index = update_table_index; + + unique_ptr del_as_logicaloperator = std::move(del); + return BindReturning(std::move(stmt.returning_list), table, stmt.table->alias, update_table_index, + std::move(del_as_logicaloperator), std::move(result)); + } + result.plan = std::move(del); + result.names = {"Count"}; + result.types = {LogicalType::BIGINT}; + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::CHANGED_ROWS; + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_detach.cpp b/src/duckdb/src/planner/binder/statement/bind_detach.cpp new file mode 100644 index 00000000..14c99e9e --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_detach.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parser/statement/detach_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(DetachStatement &stmt) { + BoundStatement result; + + result.plan = make_uniq(LogicalOperatorType::LOGICAL_DETACH, std::move(stmt.info)); + result.names = {"Success"}; + result.types = {LogicalType::BOOLEAN}; + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_drop.cpp b/src/duckdb/src/planner/binder/statement/bind_drop.cpp new file mode 100644 index 00000000..700aa167 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_drop.cpp @@ -0,0 +1,64 @@ +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/storage_extension.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(DropStatement &stmt) { + BoundStatement result; + + switch (stmt.info->type) { + case CatalogType::PREPARED_STATEMENT: + // dropping prepared statements is always possible + // it also does not require a valid transaction + properties.requires_valid_transaction = false; + break; + case CatalogType::SCHEMA_ENTRY: { + // dropping a schema is never read-only because there are no temporary schemas + auto &catalog = Catalog::GetCatalog(context, stmt.info->catalog); + properties.modified_databases.insert(catalog.GetName()); + break; + } + case CatalogType::VIEW_ENTRY: + case CatalogType::SEQUENCE_ENTRY: + case CatalogType::MACRO_ENTRY: + case CatalogType::TABLE_MACRO_ENTRY: + case CatalogType::INDEX_ENTRY: + case CatalogType::TABLE_ENTRY: + case CatalogType::TYPE_ENTRY: { + BindSchemaOrCatalog(stmt.info->catalog, stmt.info->schema); + auto entry = Catalog::GetEntry(context, stmt.info->type, stmt.info->catalog, stmt.info->schema, stmt.info->name, + OnEntryNotFound::RETURN_NULL); + if (!entry) { + break; + } + if (entry->internal) { + throw CatalogException("Cannot drop internal catalog entry \"%s\"!", entry->name); + } + stmt.info->catalog = entry->ParentCatalog().GetName(); + if (!entry->temporary) { + // we can only drop temporary tables in read-only mode + properties.modified_databases.insert(stmt.info->catalog); + } + stmt.info->schema = entry->ParentSchema().name; + break; + } + default: + throw BinderException("Unknown catalog type for drop statement!"); + } + result.plan = make_uniq(LogicalOperatorType::LOGICAL_DROP, std::move(stmt.info)); + result.names = {"Success"}; + result.types = {LogicalType::BOOLEAN}; + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_execute.cpp b/src/duckdb/src/planner/binder/statement/bind_execute.cpp new file mode 100644 index 00000000..f235c144 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_execute.cpp @@ -0,0 +1,74 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/execute_statement.hpp" +#include "duckdb/planner/planner.hpp" +#include "duckdb/planner/operator/logical_execute.hpp" +#include "duckdb/planner/expression_binder/constant_binder.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(ExecuteStatement &stmt) { + auto parameter_count = stmt.n_param; + + // bind the prepared statement + auto &client_data = ClientData::Get(context); + + auto entry = client_data.prepared_statements.find(stmt.name); + if (entry == client_data.prepared_statements.end()) { + throw BinderException("Prepared statement \"%s\" does not exist", stmt.name); + } + + // check if we need to rebind the prepared statement + // this happens if the catalog changes, since in this case e.g. tables we relied on may have been deleted + auto prepared = entry->second; + auto &named_param_map = prepared->unbound_statement->named_param_map; + + PreparedStatement::VerifyParameters(stmt.named_values, named_param_map); + + auto &mapped_named_values = stmt.named_values; + // bind any supplied parameters + case_insensitive_map_t bind_values; + auto constant_binder = Binder::CreateBinder(context); + constant_binder->SetCanContainNulls(true); + for (auto &pair : mapped_named_values) { + ConstantBinder cbinder(*constant_binder, context, "EXECUTE statement"); + auto bound_expr = cbinder.Bind(pair.second); + + Value value = ExpressionExecutor::EvaluateScalar(context, *bound_expr, true); + bind_values[pair.first] = std::move(value); + } + unique_ptr rebound_plan; + + if (prepared->RequireRebind(context, &bind_values)) { + // catalog was modified or statement does not have clear types: rebind the statement before running the execute + Planner prepared_planner(context); + for (auto &pair : bind_values) { + prepared_planner.parameter_data.emplace(std::make_pair(pair.first, BoundParameterData(pair.second))); + } + prepared = prepared_planner.PrepareSQLStatement(entry->second->unbound_statement->Copy()); + rebound_plan = std::move(prepared_planner.plan); + D_ASSERT(prepared->properties.bound_all_parameters); + this->bound_tables = prepared_planner.binder->bound_tables; + } + // copy the properties of the prepared statement into the planner + this->properties = prepared->properties; + this->properties.parameter_count = parameter_count; + BoundStatement result; + result.names = prepared->names; + result.types = prepared->types; + + prepared->Bind(std::move(bind_values)); + if (rebound_plan) { + auto execute_plan = make_uniq(std::move(prepared)); + execute_plan->children.push_back(std::move(rebound_plan)); + result.plan = std::move(execute_plan); + } else { + result.plan = make_uniq(std::move(prepared)); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_explain.cpp b/src/duckdb/src/planner/binder/statement/bind_explain.cpp new file mode 100644 index 00000000..2e1fd3ac --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_explain.cpp @@ -0,0 +1,24 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/planner/operator/logical_explain.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(ExplainStatement &stmt) { + BoundStatement result; + + // bind the underlying statement + auto plan = Bind(*stmt.stmt); + // get the unoptimized logical plan, and create the explain statement + auto logical_plan_unopt = plan.plan->ToString(); + auto explain = make_uniq(std::move(plan.plan), stmt.explain_type); + explain->logical_plan_unopt = logical_plan_unopt; + + result.plan = std::move(explain); + result.names = {"explain_key", "explain_value"}; + result.types = {LogicalType::VARCHAR, LogicalType::VARCHAR}; + properties.return_type = StatementReturnType::QUERY_RESULT; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_export.cpp b/src/duckdb/src/planner/binder/statement/bind_export.cpp new file mode 100644 index 00000000..be0367bd --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_export.cpp @@ -0,0 +1,355 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/parser/statement/export_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_export.hpp" +#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/parser/parsed_data/exported_table_data.hpp" +#include "duckdb/parser/constraints/foreign_key_constraint.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" + +#include "duckdb/common/string_util.hpp" +#include + +namespace duckdb { + +//! Sanitizes a string to have only low case chars and underscores +string SanitizeExportIdentifier(const string &str) { + // Copy the original string to result + string result(str); + + for (idx_t i = 0; i < str.length(); ++i) { + auto c = str[i]; + if (c >= 'a' && c <= 'z') { + // If it is lower case just continue + continue; + } + + if (c >= 'A' && c <= 'Z') { + // To lowercase + result[i] = tolower(c); + } else { + // Substitute to underscore + result[i] = '_'; + } + } + + return result; +} + +bool ReferencedTableIsOrdered(string &referenced_table, catalog_entry_vector_t &ordered) { + for (auto &entry : ordered) { + auto &table_entry = entry.get().Cast(); + if (StringUtil::CIEquals(table_entry.name, referenced_table)) { + // The referenced table is already ordered + return true; + } + } + return false; +} + +void ScanForeignKeyTable(catalog_entry_vector_t &ordered, catalog_entry_vector_t &unordered, + bool move_primary_keys_only) { + catalog_entry_vector_t remaining; + + for (auto &entry : unordered) { + auto &table_entry = entry.get().Cast(); + bool move_to_ordered = true; + auto &constraints = table_entry.GetConstraints(); + + for (auto &cond : constraints) { + if (cond->type != ConstraintType::FOREIGN_KEY) { + continue; + } + auto &fk = cond->Cast(); + if (fk.info.type != ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { + continue; + } + + if (move_primary_keys_only) { + // This table references a table, don't move it yet + move_to_ordered = false; + break; + } else if (!ReferencedTableIsOrdered(fk.info.table, ordered)) { + // The table that it references isn't ordered yet + move_to_ordered = false; + break; + } + } + + if (move_to_ordered) { + ordered.push_back(table_entry); + } else { + remaining.push_back(table_entry); + } + } + unordered = remaining; +} + +void ReorderTableEntries(catalog_entry_vector_t &tables) { + catalog_entry_vector_t ordered; + catalog_entry_vector_t unordered = tables; + // First only move the tables that don't have any dependencies + ScanForeignKeyTable(ordered, unordered, true); + while (!unordered.empty()) { + // Now we will start moving tables that have foreign key constraints + // if the tables they reference are already moved + ScanForeignKeyTable(ordered, unordered, false); + } + tables = ordered; +} + +string CreateFileName(const string &id_suffix, TableCatalogEntry &table, const string &extension) { + auto name = SanitizeExportIdentifier(table.name); + if (table.schema.name == DEFAULT_SCHEMA) { + return StringUtil::Format("%s%s.%s", name, id_suffix, extension); + } + auto schema = SanitizeExportIdentifier(table.schema.name); + return StringUtil::Format("%s_%s%s.%s", schema, name, id_suffix, extension); +} + +static bool IsSupported(CopyTypeSupport support_level) { + // For export purposes we don't want to lose information, so we only accept fully supported types + return support_level == CopyTypeSupport::SUPPORTED; +} + +static LogicalType AlterLogicalType(const LogicalType &original, copy_supports_type_t type_check) { + D_ASSERT(type_check); + auto id = original.id(); + switch (id) { + case LogicalTypeId::LIST: { + auto child = AlterLogicalType(ListType::GetChildType(original), type_check); + return LogicalType::LIST(child); + } + case LogicalTypeId::STRUCT: { + auto &original_children = StructType::GetChildTypes(original); + child_list_t new_children; + for (auto &child : original_children) { + auto &child_name = child.first; + auto &child_type = child.second; + + LogicalType new_type; + if (!IsSupported(type_check(child_type))) { + new_type = AlterLogicalType(child_type, type_check); + } else { + new_type = child_type; + } + new_children.push_back(std::make_pair(child_name, new_type)); + } + return LogicalType::STRUCT(std::move(new_children)); + } + case LogicalTypeId::UNION: { + auto member_count = UnionType::GetMemberCount(original); + child_list_t new_children; + for (idx_t i = 0; i < member_count; i++) { + auto &child_name = UnionType::GetMemberName(original, i); + auto &child_type = UnionType::GetMemberType(original, i); + + LogicalType new_type; + if (!IsSupported(type_check(child_type))) { + new_type = AlterLogicalType(child_type, type_check); + } else { + new_type = child_type; + } + + new_children.push_back(std::make_pair(child_name, new_type)); + } + return LogicalType::UNION(std::move(new_children)); + } + case LogicalTypeId::MAP: { + auto &key_type = MapType::KeyType(original); + auto &value_type = MapType::ValueType(original); + + LogicalType new_key_type; + LogicalType new_value_type; + if (!IsSupported(type_check(key_type))) { + new_key_type = AlterLogicalType(key_type, type_check); + } else { + new_key_type = key_type; + } + + if (!IsSupported(type_check(value_type))) { + new_value_type = AlterLogicalType(value_type, type_check); + } else { + new_value_type = value_type; + } + return LogicalType::MAP(new_key_type, new_value_type); + } + default: { + D_ASSERT(!IsSupported(type_check(original))); + return LogicalType::VARCHAR; + } + } +} + +static bool NeedsCast(LogicalType &type, copy_supports_type_t type_check) { + if (!type_check) { + return false; + } + if (IsSupported(type_check(type))) { + // The type is supported in it's entirety, no cast is required + return false; + } + // Change the type to something that is supported + type = AlterLogicalType(type, type_check); + return true; +} + +static unique_ptr CreateSelectStatement(CopyStatement &stmt, child_list_t &select_list, + copy_supports_type_t type_check) { + auto ref = make_uniq(); + ref->catalog_name = stmt.info->catalog; + ref->schema_name = stmt.info->schema; + ref->table_name = stmt.info->table; + + auto statement = make_uniq(); + statement->from_table = std::move(ref); + + vector> expressions; + for (auto &col : select_list) { + auto &name = col.first; + auto &type = col.second; + + auto expression = make_uniq_base(name); + if (NeedsCast(type, type_check)) { + // Add a cast to a type supported by the copy function + expression = make_uniq_base(type, std::move(expression)); + } + expressions.push_back(std::move(expression)); + } + + statement->select_list = std::move(expressions); + return std::move(statement); +} + +BoundStatement Binder::Bind(ExportStatement &stmt) { + // COPY TO a file + auto &config = DBConfig::GetConfig(context); + if (!config.options.enable_external_access) { + throw PermissionException("COPY TO is disabled through configuration"); + } + BoundStatement result; + result.types = {LogicalType::BOOLEAN}; + result.names = {"Success"}; + + // lookup the format in the catalog + auto ©_function = + Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->format); + if (!copy_function.function.copy_to_bind && !copy_function.function.plan) { + throw NotImplementedException("COPY TO is not supported for FORMAT \"%s\"", stmt.info->format); + } + + // gather a list of all the tables + string catalog = stmt.database.empty() ? INVALID_CATALOG : stmt.database; + catalog_entry_vector_t tables; + auto schemas = Catalog::GetSchemas(context, catalog); + for (auto &schema : schemas) { + schema.get().Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { + if (entry.type == CatalogType::TABLE_ENTRY) { + tables.push_back(entry.Cast()); + } + }); + } + + // reorder tables because of foreign key constraint + ReorderTableEntries(tables); + + // now generate the COPY statements for each of the tables + auto &fs = FileSystem::GetFileSystem(context); + unique_ptr child_operator; + + BoundExportData exported_tables; + + unordered_set table_name_index; + for (auto &t : tables) { + auto &table = t.get().Cast(); + auto info = make_uniq(); + // we copy the options supplied to the EXPORT + info->format = stmt.info->format; + info->options = stmt.info->options; + // set up the file name for the COPY TO + + idx_t id = 0; + while (true) { + string id_suffix = id == 0 ? string() : "_" + to_string(id); + auto name = CreateFileName(id_suffix, table, copy_function.function.extension); + auto directory = stmt.info->file_path; + auto full_path = fs.JoinPath(directory, name); + info->file_path = full_path; + auto insert_result = table_name_index.insert(info->file_path); + if (insert_result.second == true) { + // this name was not yet taken: take it + break; + } + id++; + } + info->is_from = false; + info->catalog = catalog; + info->schema = table.schema.name; + info->table = table.name; + + // We can not export generated columns + child_list_t select_list; + + for (auto &col : table.GetColumns().Physical()) { + select_list.push_back(std::make_pair(col.Name(), col.Type())); + } + + ExportedTableData exported_data; + exported_data.database_name = catalog; + exported_data.table_name = info->table; + exported_data.schema_name = info->schema; + + exported_data.file_path = info->file_path; + + ExportedTableInfo table_info(table, std::move(exported_data)); + exported_tables.data.push_back(table_info); + id++; + + // generate the copy statement and bind it + CopyStatement copy_stmt; + copy_stmt.info = std::move(info); + copy_stmt.select_statement = + CreateSelectStatement(copy_stmt, select_list, copy_function.function.supports_type); + + auto copy_binder = Binder::CreateBinder(context, this); + auto bound_statement = copy_binder->Bind(copy_stmt); + auto plan = std::move(bound_statement.plan); + + if (child_operator) { + // use UNION ALL to combine the individual copy statements into a single node + auto copy_union = make_uniq(GenerateTableIndex(), 1, std::move(child_operator), + std::move(plan), LogicalOperatorType::LOGICAL_UNION); + child_operator = std::move(copy_union); + } else { + child_operator = std::move(plan); + } + } + + // try to create the directory, if it doesn't exist yet + // a bit hacky to do it here, but we need to create the directory BEFORE the copy statements run + if (!fs.DirectoryExists(stmt.info->file_path)) { + fs.CreateDirectory(stmt.info->file_path); + } + + // create the export node + auto export_node = make_uniq(copy_function.function, std::move(stmt.info), exported_tables); + + if (child_operator) { + export_node->children.push_back(std::move(child_operator)); + } + + result.plan = std::move(export_node); + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_extension.cpp b/src/duckdb/src/planner/binder/statement/bind_extension.cpp new file mode 100644 index 00000000..f0884cdc --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_extension.cpp @@ -0,0 +1,32 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/extension_statement.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(ExtensionStatement &stmt) { + BoundStatement result; + + // perform the planning of the function + D_ASSERT(stmt.extension.plan_function); + auto parse_result = + stmt.extension.plan_function(stmt.extension.parser_info.get(), context, std::move(stmt.parse_data)); + + properties.modified_databases = parse_result.modified_databases; + properties.requires_valid_transaction = parse_result.requires_valid_transaction; + properties.return_type = parse_result.return_type; + + // create the plan as a scan of the given table function + result.plan = BindTableFunction(parse_result.function, std::move(parse_result.parameters)); + D_ASSERT(result.plan->type == LogicalOperatorType::LOGICAL_GET); + auto &get = result.plan->Cast(); + result.names = get.names; + result.types = get.returned_types; + get.column_ids.clear(); + for (idx_t i = 0; i < get.returned_types.size(); i++) { + get.column_ids.push_back(i); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp new file mode 100644 index 00000000..599b23aa --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -0,0 +1,549 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder/insert_binder.hpp" +#include "duckdb/planner/operator/logical_insert.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression_binder/returning_binder.hpp" +#include "duckdb/planner/expression_binder/where_binder.hpp" +#include "duckdb/planner/expression_binder/update_binder.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/planner/expression/bound_default_expression.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/tableref/bound_basetableref.hpp" +#include "duckdb/planner/tableref/bound_dummytableref.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/storage/table_storage_info.hpp" + +namespace duckdb { + +static void CheckInsertColumnCountMismatch(int64_t expected_columns, int64_t result_columns, bool columns_provided, + const char *tname) { + if (result_columns != expected_columns) { + string msg = StringUtil::Format(!columns_provided ? "table %s has %lld columns but %lld values were supplied" + : "Column name/value mismatch for insert on %s: " + "expected %lld columns but %lld values were supplied", + tname, expected_columns, result_columns); + throw BinderException(msg); + } +} + +unique_ptr ExpandDefaultExpression(const ColumnDefinition &column) { + if (column.DefaultValue()) { + return column.DefaultValue()->Copy(); + } else { + return make_uniq(Value(column.Type())); + } +} + +void ReplaceDefaultExpression(unique_ptr &expr, const ColumnDefinition &column) { + D_ASSERT(expr->type == ExpressionType::VALUE_DEFAULT); + expr = ExpandDefaultExpression(column); +} + +void QualifyColumnReferences(unique_ptr &expr, const string &table_name) { + // To avoid ambiguity with 'excluded', we explicitly qualify all column references + if (expr->type == ExpressionType::COLUMN_REF) { + auto &column_ref = expr->Cast(); + if (column_ref.IsQualified()) { + return; + } + auto column_name = column_ref.GetColumnName(); + expr = make_uniq(column_name, table_name); + } + ParsedExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child) { QualifyColumnReferences(child, table_name); }); +} + +// Replace binding.table_index with 'dest' if it's 'source' +void ReplaceColumnBindings(Expression &expr, idx_t source, idx_t dest) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &bound_columnref = expr.Cast(); + if (bound_columnref.binding.table_index == source) { + bound_columnref.binding.table_index = dest; + } + } + ExpressionIterator::EnumerateChildren( + expr, [&](unique_ptr &child) { ReplaceColumnBindings(*child, source, dest); }); +} + +void Binder::BindDoUpdateSetExpressions(const string &table_alias, LogicalInsert &insert, UpdateSetInfo &set_info, + TableCatalogEntry &table, TableStorageInfo &storage_info) { + D_ASSERT(insert.children.size() == 1); + D_ASSERT(insert.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); + + vector logical_column_ids; + vector column_names; + D_ASSERT(set_info.columns.size() == set_info.expressions.size()); + + for (idx_t i = 0; i < set_info.columns.size(); i++) { + auto &colname = set_info.columns[i]; + auto &expr = set_info.expressions[i]; + if (!table.ColumnExists(colname)) { + throw BinderException("Referenced update column %s not found in table!", colname); + } + auto &column = table.GetColumn(colname); + if (column.Generated()) { + throw BinderException("Cant update column \"%s\" because it is a generated column!", column.Name()); + } + if (std::find(insert.set_columns.begin(), insert.set_columns.end(), column.Physical()) != + insert.set_columns.end()) { + throw BinderException("Multiple assignments to same column \"%s\"", colname); + } + insert.set_columns.push_back(column.Physical()); + logical_column_ids.push_back(column.Oid()); + insert.set_types.push_back(column.Type()); + column_names.push_back(colname); + if (expr->type == ExpressionType::VALUE_DEFAULT) { + expr = ExpandDefaultExpression(column); + } + UpdateBinder binder(*this, context); + binder.target_type = column.Type(); + + // Avoid ambiguity issues + QualifyColumnReferences(expr, table_alias); + + auto bound_expr = binder.Bind(expr); + D_ASSERT(bound_expr); + if (bound_expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { + throw BinderException("Expression in the DO UPDATE SET clause can not be a subquery"); + } + + insert.expressions.push_back(std::move(bound_expr)); + } + + // Figure out which columns are indexed on + unordered_set indexed_columns; + for (auto &index : storage_info.index_info) { + for (auto &column_id : index.column_set) { + indexed_columns.insert(column_id); + } + } + + // Verify that none of the columns that are targeted with a SET expression are indexed on + for (idx_t i = 0; i < logical_column_ids.size(); i++) { + auto &column = logical_column_ids[i]; + if (indexed_columns.count(column)) { + throw BinderException("Can not assign to column '%s' because it has a UNIQUE/PRIMARY KEY constraint", + column_names[i]); + } + } +} + +unique_ptr CreateSetInfoForReplace(TableCatalogEntry &table, InsertStatement &insert, + TableStorageInfo &storage_info) { + auto set_info = make_uniq(); + + auto &columns = set_info->columns; + // Figure out which columns are indexed on + + unordered_set indexed_columns; + for (auto &index : storage_info.index_info) { + for (auto &column_id : index.column_set) { + indexed_columns.insert(column_id); + } + } + + auto &column_list = table.GetColumns(); + if (insert.columns.empty()) { + for (auto &column : column_list.Physical()) { + auto &name = column.Name(); + // FIXME: can these column names be aliased somehow? + if (indexed_columns.count(column.Oid())) { + continue; + } + columns.push_back(name); + } + } else { + // a list of columns was explicitly supplied, only update those + for (auto &name : insert.columns) { + auto &column = column_list.GetColumn(name); + if (indexed_columns.count(column.Oid())) { + continue; + } + columns.push_back(name); + } + } + + // Create 'excluded' qualified column references of these columns + for (auto &column : columns) { + set_info->expressions.push_back(make_uniq(column, "excluded")); + } + + return set_info; +} + +void Binder::BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &table, InsertStatement &stmt) { + if (!stmt.on_conflict_info) { + insert.action_type = OnConflictAction::THROW; + return; + } + D_ASSERT(stmt.table_ref->type == TableReferenceType::BASE_TABLE); + + // visit the table reference + auto bound_table = Bind(*stmt.table_ref); + if (bound_table->type != TableReferenceType::BASE_TABLE) { + throw BinderException("Can only update base table!"); + } + + auto &table_ref = stmt.table_ref->Cast(); + const string &table_alias = !table_ref.alias.empty() ? table_ref.alias : table_ref.table_name; + + auto &on_conflict = *stmt.on_conflict_info; + D_ASSERT(on_conflict.action_type != OnConflictAction::THROW); + insert.action_type = on_conflict.action_type; + + // obtain the table storage info + auto storage_info = table.GetStorageInfo(context); + + auto &columns = table.GetColumns(); + if (!on_conflict.indexed_columns.empty()) { + // Bind the ON CONFLICT () + + // create a mapping of (list index) -> (column index) + case_insensitive_map_t specified_columns; + for (idx_t i = 0; i < on_conflict.indexed_columns.size(); i++) { + specified_columns[on_conflict.indexed_columns[i]] = i; + auto column_index = table.GetColumnIndex(on_conflict.indexed_columns[i]); + if (column_index.index == COLUMN_IDENTIFIER_ROW_ID) { + throw BinderException("Cannot specify ROWID as ON CONFLICT target"); + } + auto &col = columns.GetColumn(column_index); + if (col.Generated()) { + throw BinderException("Cannot specify a generated column as ON CONFLICT target"); + } + } + for (auto &col : columns.Physical()) { + auto entry = specified_columns.find(col.Name()); + if (entry != specified_columns.end()) { + // column was specified, set to the index + insert.on_conflict_filter.insert(col.Oid()); + } + } + bool index_references_columns = false; + for (auto &index : storage_info.index_info) { + if (!index.is_unique) { + continue; + } + bool index_matches = insert.on_conflict_filter == index.column_set; + if (index_matches) { + index_references_columns = true; + break; + } + } + if (!index_references_columns) { + // Same as before, this is essentially a no-op, turning this into a DO THROW instead + // But since this makes no logical sense, it's probably better to throw an error + throw BinderException( + "The specified columns as conflict target are not referenced by a UNIQUE/PRIMARY KEY CONSTRAINT"); + } + } else { + // When omitting the conflict target, the ON CONFLICT applies to every UNIQUE/PRIMARY KEY on the table + + // We check if there are any constraints on the table, if there aren't we throw an error. + idx_t found_matching_indexes = 0; + for (auto &index : storage_info.index_info) { + if (!index.is_unique) { + continue; + } + // does this work with multi-column indexes? + auto &indexed_columns = index.column_set; + for (auto &column : table.GetColumns().Physical()) { + if (indexed_columns.count(column.Physical().index)) { + found_matching_indexes++; + } + } + } + if (!found_matching_indexes) { + throw BinderException( + "There are no UNIQUE/PRIMARY KEY Indexes that refer to this table, ON CONFLICT is a no-op"); + } + if (insert.action_type != OnConflictAction::NOTHING && found_matching_indexes != 1) { + // When no conflict target is provided, and the action type is UPDATE, + // we only allow the operation when only a single Index exists + throw BinderException("Conflict target has to be provided for a DO UPDATE operation when the table has " + "multiple UNIQUE/PRIMARY KEY constraints"); + } + } + + // add the 'excluded' dummy table binding + AddTableName("excluded"); + // add a bind context entry for it + auto excluded_index = GenerateTableIndex(); + insert.excluded_table_index = excluded_index; + auto table_column_names = columns.GetColumnNames(); + auto table_column_types = columns.GetColumnTypes(); + bind_context.AddGenericBinding(excluded_index, "excluded", table_column_names, table_column_types); + + if (on_conflict.condition) { + // Avoid ambiguity between binding and 'excluded' + QualifyColumnReferences(on_conflict.condition, table_alias); + // Bind the ON CONFLICT ... WHERE clause + WhereBinder where_binder(*this, context); + auto condition = where_binder.Bind(on_conflict.condition); + if (condition && condition->expression_class == ExpressionClass::BOUND_SUBQUERY) { + throw BinderException("conflict_target WHERE clause can not be a subquery"); + } + insert.on_conflict_condition = std::move(condition); + } + + auto bindings = insert.children[0]->GetColumnBindings(); + idx_t projection_index = DConstants::INVALID_INDEX; + vector> *insert_child_operators; + insert_child_operators = &insert.children; + while (projection_index == DConstants::INVALID_INDEX) { + if (insert_child_operators->empty()) { + // No further children to visit + break; + } + D_ASSERT(insert_child_operators->size() >= 1); + auto ¤t_child = (*insert_child_operators)[0]; + auto table_indices = current_child->GetTableIndex(); + if (table_indices.empty()) { + // This operator does not have a table index to refer to, we have to visit its children + insert_child_operators = ¤t_child->children; + continue; + } + projection_index = table_indices[0]; + } + if (projection_index == DConstants::INVALID_INDEX) { + throw InternalException("Could not locate a table_index from the children of the insert"); + } + + string unused; + auto original_binding = bind_context.GetBinding(table_alias, unused); + D_ASSERT(original_binding); + + auto table_index = original_binding->index; + + // Replace any column bindings to refer to the projection table_index, rather than the source table + if (insert.on_conflict_condition) { + ReplaceColumnBindings(*insert.on_conflict_condition, table_index, projection_index); + } + + if (insert.action_type == OnConflictAction::REPLACE) { + D_ASSERT(on_conflict.set_info == nullptr); + on_conflict.set_info = CreateSetInfoForReplace(table, stmt, storage_info); + insert.action_type = OnConflictAction::UPDATE; + } + if (on_conflict.set_info && on_conflict.set_info->columns.empty()) { + // if we are doing INSERT OR REPLACE on a table with no columns outside of the primary key column + // convert to INSERT OR IGNORE + insert.action_type = OnConflictAction::NOTHING; + } + if (insert.action_type == OnConflictAction::NOTHING) { + if (!insert.on_conflict_condition) { + return; + } + // Get the column_ids we need to fetch later on from the conflicting tuples + // of the original table, to execute the expressions + D_ASSERT(original_binding->binding_type == BindingType::TABLE); + auto &table_binding = original_binding->Cast(); + insert.columns_to_fetch = table_binding.GetBoundColumnIds(); + return; + } + + D_ASSERT(on_conflict.set_info); + auto &set_info = *on_conflict.set_info; + D_ASSERT(set_info.columns.size() == set_info.expressions.size()); + + if (set_info.condition) { + // Avoid ambiguity between binding and 'excluded' + QualifyColumnReferences(set_info.condition, table_alias); + // Bind the SET ... WHERE clause + WhereBinder where_binder(*this, context); + auto condition = where_binder.Bind(set_info.condition); + if (condition && condition->expression_class == ExpressionClass::BOUND_SUBQUERY) { + throw BinderException("conflict_target WHERE clause can not be a subquery"); + } + insert.do_update_condition = std::move(condition); + } + + BindDoUpdateSetExpressions(table_alias, insert, set_info, table, storage_info); + + // Get the column_ids we need to fetch later on from the conflicting tuples + // of the original table, to execute the expressions + D_ASSERT(original_binding->binding_type == BindingType::TABLE); + auto &table_binding = original_binding->Cast(); + insert.columns_to_fetch = table_binding.GetBoundColumnIds(); + + // Replace the column bindings to refer to the child operator + for (auto &expr : insert.expressions) { + // Change the non-excluded column references to refer to the projection index + ReplaceColumnBindings(*expr, table_index, projection_index); + } + // Do the same for the (optional) DO UPDATE condition + if (insert.do_update_condition) { + ReplaceColumnBindings(*insert.do_update_condition, table_index, projection_index); + } +} + +BoundStatement Binder::Bind(InsertStatement &stmt) { + BoundStatement result; + result.names = {"Count"}; + result.types = {LogicalType::BIGINT}; + + BindSchemaOrCatalog(stmt.catalog, stmt.schema); + auto &table = Catalog::GetEntry(context, stmt.catalog, stmt.schema, stmt.table); + if (!table.temporary) { + // inserting into a non-temporary table: alters underlying database + properties.modified_databases.insert(table.catalog.GetName()); + } + + auto insert = make_uniq(table, GenerateTableIndex()); + // Add CTEs as bindable + AddCTEMap(stmt.cte_map); + + auto values_list = stmt.GetValuesList(); + + // bind the root select node (if any) + BoundStatement root_select; + if (stmt.column_order == InsertColumnOrder::INSERT_BY_NAME) { + if (values_list) { + throw BinderException("INSERT BY NAME can only be used when inserting from a SELECT statement"); + } + if (!stmt.columns.empty()) { + throw BinderException("INSERT BY NAME cannot be combined with an explicit column list"); + } + D_ASSERT(stmt.select_statement); + // INSERT BY NAME - generate the columns from the names of the SELECT statement + auto select_binder = Binder::CreateBinder(context, this); + root_select = select_binder->Bind(*stmt.select_statement); + MoveCorrelatedExpressions(*select_binder); + + stmt.columns = root_select.names; + } + + vector named_column_map; + if (!stmt.columns.empty() || stmt.default_values) { + // insertion statement specifies column list + + // create a mapping of (list index) -> (column index) + case_insensitive_map_t column_name_map; + for (idx_t i = 0; i < stmt.columns.size(); i++) { + auto entry = column_name_map.insert(make_pair(stmt.columns[i], i)); + if (!entry.second) { + throw BinderException("Duplicate column name \"%s\" in INSERT", stmt.columns[i]); + } + column_name_map[stmt.columns[i]] = i; + auto column_index = table.GetColumnIndex(stmt.columns[i]); + if (column_index.index == COLUMN_IDENTIFIER_ROW_ID) { + throw BinderException("Cannot explicitly insert values into rowid column"); + } + auto &col = table.GetColumn(column_index); + if (col.Generated()) { + throw BinderException("Cannot insert into a generated column"); + } + insert->expected_types.push_back(col.Type()); + named_column_map.push_back(column_index); + } + for (auto &col : table.GetColumns().Physical()) { + auto entry = column_name_map.find(col.Name()); + if (entry == column_name_map.end()) { + // column not specified, set index to DConstants::INVALID_INDEX + insert->column_index_map.push_back(DConstants::INVALID_INDEX); + } else { + // column was specified, set to the index + insert->column_index_map.push_back(entry->second); + } + } + } else { + // insert by position and no columns specified - insertion into all columns of the table + // intentionally don't populate 'column_index_map' as an indication of this + for (auto &col : table.GetColumns().Physical()) { + named_column_map.push_back(col.Logical()); + insert->expected_types.push_back(col.Type()); + } + } + + // bind the default values + BindDefaultValues(table.GetColumns(), insert->bound_defaults); + if (!stmt.select_statement && !stmt.default_values) { + result.plan = std::move(insert); + return result; + } + // Exclude the generated columns from this amount + idx_t expected_columns = stmt.columns.empty() ? table.GetColumns().PhysicalColumnCount() : stmt.columns.size(); + + // special case: check if we are inserting from a VALUES statement + if (values_list) { + auto &expr_list = values_list->Cast(); + expr_list.expected_types.resize(expected_columns); + expr_list.expected_names.resize(expected_columns); + + D_ASSERT(expr_list.values.size() > 0); + CheckInsertColumnCountMismatch(expected_columns, expr_list.values[0].size(), !stmt.columns.empty(), + table.name.c_str()); + + // VALUES list! + for (idx_t col_idx = 0; col_idx < expected_columns; col_idx++) { + D_ASSERT(named_column_map.size() >= col_idx); + auto &table_col_idx = named_column_map[col_idx]; + + // set the expected types as the types for the INSERT statement + auto &column = table.GetColumn(table_col_idx); + expr_list.expected_types[col_idx] = column.Type(); + expr_list.expected_names[col_idx] = column.Name(); + + // now replace any DEFAULT values with the corresponding default expression + for (idx_t list_idx = 0; list_idx < expr_list.values.size(); list_idx++) { + if (expr_list.values[list_idx][col_idx]->type == ExpressionType::VALUE_DEFAULT) { + // DEFAULT value! replace the entry + ReplaceDefaultExpression(expr_list.values[list_idx][col_idx], column); + } + } + } + } + + // parse select statement and add to logical plan + unique_ptr root; + if (stmt.select_statement) { + if (stmt.column_order == InsertColumnOrder::INSERT_BY_POSITION) { + auto select_binder = Binder::CreateBinder(context, this); + root_select = select_binder->Bind(*stmt.select_statement); + MoveCorrelatedExpressions(*select_binder); + } + // inserting from a select - check if the column count matches + CheckInsertColumnCountMismatch(expected_columns, root_select.types.size(), !stmt.columns.empty(), + table.name.c_str()); + + root = CastLogicalOperatorToTypes(root_select.types, insert->expected_types, std::move(root_select.plan)); + } else { + root = make_uniq(GenerateTableIndex()); + } + insert->AddChild(std::move(root)); + + BindOnConflictClause(*insert, table, stmt); + + if (!stmt.returning_list.empty()) { + insert->return_chunk = true; + result.types.clear(); + result.names.clear(); + auto insert_table_index = GenerateTableIndex(); + insert->table_index = insert_table_index; + unique_ptr index_as_logicaloperator = std::move(insert); + + return BindReturning(std::move(stmt.returning_list), table, stmt.table_ref ? stmt.table_ref->alias : string(), + insert_table_index, std::move(index_as_logicaloperator), std::move(result)); + } + + D_ASSERT(result.types.size() == result.names.size()); + result.plan = std::move(insert); + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::CHANGED_ROWS; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_load.cpp b/src/duckdb/src/planner/binder/statement/bind_load.cpp new file mode 100644 index 00000000..a179ba2b --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_load.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parser/statement/load_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" +#include + +namespace duckdb { + +BoundStatement Binder::Bind(LoadStatement &stmt) { + BoundStatement result; + result.types = {LogicalType::BOOLEAN}; + result.names = {"Success"}; + + result.plan = make_uniq(LogicalOperatorType::LOGICAL_LOAD, std::move(stmt.info)); + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp new file mode 100644 index 00000000..9a7ae93a --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp @@ -0,0 +1,37 @@ +#include "duckdb/parser/statement/logical_plan_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include + +namespace duckdb { + +idx_t GetMaxTableIndex(LogicalOperator &op) { + idx_t result = 0; + for (auto &child : op.children) { + auto max_child_index = GetMaxTableIndex(*child); + result = MaxValue(result, max_child_index); + } + auto indexes = op.GetTableIndex(); + for (auto &index : indexes) { + result = MaxValue(result, index); + } + return result; +} + +BoundStatement Binder::Bind(LogicalPlanStatement &stmt) { + BoundStatement result; + result.types = stmt.plan->types; + for (idx_t i = 0; i < result.types.size(); i++) { + result.names.push_back(StringUtil::Format("col%d", i)); + } + result.plan = std::move(stmt.plan); + properties.allow_stream_result = true; + properties.return_type = StatementReturnType::QUERY_RESULT; // TODO could also be something else + + if (parent) { + throw InternalException("LogicalPlanStatement should be bound in root binder"); + } + bound_tables = GetMaxTableIndex(*result.plan) + 1; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp new file mode 100644 index 00000000..5e7bb420 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp @@ -0,0 +1,38 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/pragma_statement.hpp" +#include "duckdb/planner/operator/logical_pragma.hpp" +#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(PragmaStatement &stmt) { + // bind the pragma function + auto &entry = + Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->name); + string error; + FunctionBinder function_binder(context); + idx_t bound_idx = function_binder.BindFunction(entry.name, entry.functions, *stmt.info, error); + if (bound_idx == DConstants::INVALID_INDEX) { + throw BinderException(FormatError(stmt.stmt_location, error)); + } + auto bound_function = entry.functions.GetFunctionByOffset(bound_idx); + if (!bound_function.function) { + throw BinderException("PRAGMA function does not have a function specified"); + } + + // bind and check named params + QueryErrorContext error_context(root_statement, stmt.stmt_location); + BindNamedParameters(bound_function.named_parameters, stmt.info->named_parameters, error_context, + bound_function.name); + + BoundStatement result; + result.names = {"Success"}; + result.types = {LogicalType::BOOLEAN}; + result.plan = make_uniq(bound_function, *stmt.info); + properties.return_type = StatementReturnType::QUERY_RESULT; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_prepare.cpp b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp new file mode 100644 index 00000000..700e9175 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp @@ -0,0 +1,29 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/prepare_statement.hpp" +#include "duckdb/planner/planner.hpp" +#include "duckdb/planner/operator/logical_prepare.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(PrepareStatement &stmt) { + Planner prepared_planner(context); + auto prepared_data = prepared_planner.PrepareSQLStatement(std::move(stmt.statement)); + this->bound_tables = prepared_planner.binder->bound_tables; + + auto prepare = make_uniq(stmt.name, std::move(prepared_data), std::move(prepared_planner.plan)); + // we can always prepare, even if the transaction has been invalidated + // this is required because most clients ALWAYS invoke prepared statements + properties.requires_valid_transaction = false; + properties.allow_stream_result = false; + properties.bound_all_parameters = true; + properties.parameter_count = 0; + properties.return_type = StatementReturnType::NOTHING; + + BoundStatement result; + result.names = {"Success"}; + result.types = {LogicalType::BOOLEAN}; + result.plan = std::move(prepare); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_relation.cpp b/src/duckdb/src/planner/binder/statement/bind_relation.cpp new file mode 100644 index 00000000..1e0565a1 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_relation.cpp @@ -0,0 +1,14 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/relation_statement.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(RelationStatement &stmt) { + return stmt.relation->Bind(*this); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_select.cpp b/src/duckdb/src/planner/binder/statement/bind_select.cpp new file mode 100644 index 00000000..711a07c5 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_select.cpp @@ -0,0 +1,13 @@ +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_query_node.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(SelectStatement &stmt) { + properties.allow_stream_result = true; + properties.return_type = StatementReturnType::QUERY_RESULT; + return Bind(*stmt.node); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_set.cpp b/src/duckdb/src/planner/binder/statement/bind_set.cpp new file mode 100644 index 00000000..c9984530 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_set.cpp @@ -0,0 +1,44 @@ +#include "duckdb/parser/statement/set_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_set.hpp" +#include "duckdb/planner/operator/logical_reset.hpp" +#include + +namespace duckdb { + +BoundStatement Binder::Bind(SetVariableStatement &stmt) { + BoundStatement result; + result.types = {LogicalType::BOOLEAN}; + result.names = {"Success"}; + + result.plan = make_uniq(stmt.name, stmt.value, stmt.scope); + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +BoundStatement Binder::Bind(ResetVariableStatement &stmt) { + BoundStatement result; + result.types = {LogicalType::BOOLEAN}; + result.names = {"Success"}; + + result.plan = make_uniq(stmt.name, stmt.scope); + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +BoundStatement Binder::Bind(SetStatement &stmt) { + switch (stmt.set_type) { + case SetType::SET: { + auto &set_stmt = stmt.Cast(); + return Bind(set_stmt); + } + case SetType::RESET: { + auto &set_stmt = stmt.Cast(); + return Bind(set_stmt); + } + default: + throw NotImplementedException("Type not implemented for SetType"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_show.cpp b/src/duckdb/src/planner/binder/statement/bind_show.cpp new file mode 100644 index 00000000..a6a50e2b --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_show.cpp @@ -0,0 +1,30 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/show_statement.hpp" +#include "duckdb/planner/operator/logical_show.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(ShowStatement &stmt) { + BoundStatement result; + + if (stmt.info->is_summary) { + return BindSummarize(stmt); + } + auto plan = Bind(*stmt.info->query); + stmt.info->types = plan.types; + stmt.info->aliases = plan.names; + + auto show = make_uniq(std::move(plan.plan)); + show->types_select = plan.types; + show->aliases = plan.names; + + result.plan = std::move(show); + + result.names = {"column_name", "column_type", "null", "key", "default", "extra"}; + result.types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, + LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; + properties.return_type = StatementReturnType::QUERY_RESULT; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_simple.cpp b/src/duckdb/src/planner/binder/statement/bind_simple.cpp new file mode 100644 index 00000000..272abd10 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_simple.cpp @@ -0,0 +1,48 @@ +#include "duckdb/parser/statement/alter_statement.hpp" +#include "duckdb/parser/statement/transaction_statement.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/planner/binder.hpp" + +//! This file contains the binder definitions for statements that do not need to be bound at all and only require a +//! straightforward conversion + +namespace duckdb { + +BoundStatement Binder::Bind(AlterStatement &stmt) { + BoundStatement result; + result.names = {"Success"}; + result.types = {LogicalType::BOOLEAN}; + BindSchemaOrCatalog(stmt.info->catalog, stmt.info->schema); + auto entry = Catalog::GetEntry(context, stmt.info->GetCatalogType(), stmt.info->catalog, stmt.info->schema, + stmt.info->name, stmt.info->if_not_found); + if (entry) { + auto &catalog = entry->ParentCatalog(); + if (!entry->temporary) { + // we can only alter temporary tables/views in read-only mode + properties.modified_databases.insert(catalog.GetName()); + } + stmt.info->catalog = catalog.GetName(); + stmt.info->schema = entry->ParentSchema().name; + } + result.plan = make_uniq(LogicalOperatorType::LOGICAL_ALTER, std::move(stmt.info)); + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +BoundStatement Binder::Bind(TransactionStatement &stmt) { + // transaction statements do not require a valid transaction + properties.requires_valid_transaction = stmt.info->type == TransactionType::BEGIN_TRANSACTION; + + BoundStatement result; + result.names = {"Success"}; + result.types = {LogicalType::BOOLEAN}; + result.plan = make_uniq(LogicalOperatorType::LOGICAL_TRANSACTION, std::move(stmt.info)); + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp new file mode 100644 index 00000000..c46cd9f7 --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp @@ -0,0 +1,137 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/statement/show_statement.hpp" +#include "duckdb/planner/operator/logical_show.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" + +namespace duckdb { + +static unique_ptr SummarizeWrapUnnest(vector> &children, + const string &alias) { + auto list_function = make_uniq("list_value", std::move(children)); + vector> unnest_children; + unnest_children.push_back(std::move(list_function)); + auto unnest_function = make_uniq("unnest", std::move(unnest_children)); + unnest_function->alias = alias; + return std::move(unnest_function); +} + +static unique_ptr SummarizeCreateAggregate(const string &aggregate, string column_name) { + vector> children; + children.push_back(make_uniq(std::move(column_name))); + auto aggregate_function = make_uniq(aggregate, std::move(children)); + auto cast_function = make_uniq(LogicalType::VARCHAR, std::move(aggregate_function)); + return std::move(cast_function); +} + +static unique_ptr SummarizeCreateAggregate(const string &aggregate, string column_name, + const Value &modifier) { + vector> children; + children.push_back(make_uniq(std::move(column_name))); + children.push_back(make_uniq(modifier)); + auto aggregate_function = make_uniq(aggregate, std::move(children)); + auto cast_function = make_uniq(LogicalType::VARCHAR, std::move(aggregate_function)); + return std::move(cast_function); +} + +static unique_ptr SummarizeCreateCountStar() { + vector> children; + auto aggregate_function = make_uniq("count_star", std::move(children)); + return std::move(aggregate_function); +} + +static unique_ptr SummarizeCreateBinaryFunction(const string &op, unique_ptr left, + unique_ptr right) { + vector> children; + children.push_back(std::move(left)); + children.push_back(std::move(right)); + auto binary_function = make_uniq(op, std::move(children)); + return std::move(binary_function); +} + +static unique_ptr SummarizeCreateNullPercentage(string column_name) { + auto count_star = make_uniq(LogicalType::DOUBLE, SummarizeCreateCountStar()); + auto count = + make_uniq(LogicalType::DOUBLE, SummarizeCreateAggregate("count", std::move(column_name))); + auto null_percentage = SummarizeCreateBinaryFunction("/", std::move(count), std::move(count_star)); + auto negate_x = + SummarizeCreateBinaryFunction("-", make_uniq(Value::DOUBLE(1)), std::move(null_percentage)); + auto percentage_x = + SummarizeCreateBinaryFunction("*", std::move(negate_x), make_uniq(Value::DOUBLE(100))); + auto round_x = SummarizeCreateBinaryFunction("round", std::move(percentage_x), + make_uniq(Value::INTEGER(2))); + auto concat_x = + SummarizeCreateBinaryFunction("concat", std::move(round_x), make_uniq(Value("%"))); + + return concat_x; +} + +BoundStatement Binder::BindSummarize(ShowStatement &stmt) { + auto query_copy = stmt.info->query->Copy(); + + // we bind the plan once in a child-node to figure out the column names and column types + auto child_binder = Binder::CreateBinder(context); + auto plan = child_binder->Bind(*stmt.info->query); + D_ASSERT(plan.types.size() == plan.names.size()); + vector> name_children; + vector> type_children; + vector> min_children; + vector> max_children; + vector> unique_children; + vector> avg_children; + vector> std_children; + vector> q25_children; + vector> q50_children; + vector> q75_children; + vector> count_children; + vector> null_percentage_children; + auto select = make_uniq(); + select->node = std::move(query_copy); + for (idx_t i = 0; i < plan.names.size(); i++) { + name_children.push_back(make_uniq(Value(plan.names[i]))); + type_children.push_back(make_uniq(Value(plan.types[i].ToString()))); + min_children.push_back(SummarizeCreateAggregate("min", plan.names[i])); + max_children.push_back(SummarizeCreateAggregate("max", plan.names[i])); + unique_children.push_back(SummarizeCreateAggregate("approx_count_distinct", plan.names[i])); + if (plan.types[i].IsNumeric()) { + avg_children.push_back(SummarizeCreateAggregate("avg", plan.names[i])); + std_children.push_back(SummarizeCreateAggregate("stddev", plan.names[i])); + q25_children.push_back(SummarizeCreateAggregate("approx_quantile", plan.names[i], Value::FLOAT(0.25))); + q50_children.push_back(SummarizeCreateAggregate("approx_quantile", plan.names[i], Value::FLOAT(0.50))); + q75_children.push_back(SummarizeCreateAggregate("approx_quantile", plan.names[i], Value::FLOAT(0.75))); + } else { + avg_children.push_back(make_uniq(Value())); + std_children.push_back(make_uniq(Value())); + q25_children.push_back(make_uniq(Value())); + q50_children.push_back(make_uniq(Value())); + q75_children.push_back(make_uniq(Value())); + } + count_children.push_back(SummarizeCreateCountStar()); + null_percentage_children.push_back(SummarizeCreateNullPercentage(plan.names[i])); + } + auto subquery_ref = make_uniq(std::move(select), "summarize_tbl"); + subquery_ref->column_name_alias = plan.names; + + auto select_node = make_uniq(); + select_node->select_list.push_back(SummarizeWrapUnnest(name_children, "column_name")); + select_node->select_list.push_back(SummarizeWrapUnnest(type_children, "column_type")); + select_node->select_list.push_back(SummarizeWrapUnnest(min_children, "min")); + select_node->select_list.push_back(SummarizeWrapUnnest(max_children, "max")); + select_node->select_list.push_back(SummarizeWrapUnnest(unique_children, "approx_unique")); + select_node->select_list.push_back(SummarizeWrapUnnest(avg_children, "avg")); + select_node->select_list.push_back(SummarizeWrapUnnest(std_children, "std")); + select_node->select_list.push_back(SummarizeWrapUnnest(q25_children, "q25")); + select_node->select_list.push_back(SummarizeWrapUnnest(q50_children, "q50")); + select_node->select_list.push_back(SummarizeWrapUnnest(q75_children, "q75")); + select_node->select_list.push_back(SummarizeWrapUnnest(count_children, "count")); + select_node->select_list.push_back(SummarizeWrapUnnest(null_percentage_children, "null_percentage")); + select_node->from_table = std::move(subquery_ref); + + properties.return_type = StatementReturnType::QUERY_RESULT; + return Bind(*select_node); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_update.cpp b/src/duckdb/src/planner/binder/statement/bind_update.cpp new file mode 100644 index 00000000..ce7c6c5f --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_update.cpp @@ -0,0 +1,156 @@ +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/constraints/bound_check_constraint.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_default_expression.hpp" +#include "duckdb/planner/expression_binder/update_binder.hpp" +#include "duckdb/planner/expression_binder/where_binder.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_update.hpp" +#include "duckdb/planner/tableref/bound_basetableref.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/storage/data_table.hpp" + +#include + +namespace duckdb { + +// This creates a LogicalProjection and moves 'root' into it as a child +// unless there are no expressions to project, in which case it just returns 'root' +unique_ptr Binder::BindUpdateSet(LogicalOperator &op, unique_ptr root, + UpdateSetInfo &set_info, TableCatalogEntry &table, + vector &columns) { + auto proj_index = GenerateTableIndex(); + + vector> projection_expressions; + D_ASSERT(set_info.columns.size() == set_info.expressions.size()); + for (idx_t i = 0; i < set_info.columns.size(); i++) { + auto &colname = set_info.columns[i]; + auto &expr = set_info.expressions[i]; + if (!table.ColumnExists(colname)) { + throw BinderException("Referenced update column %s not found in table!", colname); + } + auto &column = table.GetColumn(colname); + if (column.Generated()) { + throw BinderException("Cant update column \"%s\" because it is a generated column!", column.Name()); + } + if (std::find(columns.begin(), columns.end(), column.Physical()) != columns.end()) { + throw BinderException("Multiple assignments to same column \"%s\"", colname); + } + columns.push_back(column.Physical()); + if (expr->type == ExpressionType::VALUE_DEFAULT) { + op.expressions.push_back(make_uniq(column.Type())); + } else { + UpdateBinder binder(*this, context); + binder.target_type = column.Type(); + auto bound_expr = binder.Bind(expr); + PlanSubqueries(bound_expr, root); + + op.expressions.push_back(make_uniq( + bound_expr->return_type, ColumnBinding(proj_index, projection_expressions.size()))); + projection_expressions.push_back(std::move(bound_expr)); + } + } + if (op.type != LogicalOperatorType::LOGICAL_UPDATE && projection_expressions.empty()) { + return root; + } + // now create the projection + auto proj = make_uniq(proj_index, std::move(projection_expressions)); + proj->AddChild(std::move(root)); + return unique_ptr_cast(std::move(proj)); +} + +BoundStatement Binder::Bind(UpdateStatement &stmt) { + BoundStatement result; + unique_ptr root; + + // visit the table reference + auto bound_table = Bind(*stmt.table); + if (bound_table->type != TableReferenceType::BASE_TABLE) { + throw BinderException("Can only update base table!"); + } + auto &table_binding = bound_table->Cast(); + auto &table = table_binding.table; + + // Add CTEs as bindable + AddCTEMap(stmt.cte_map); + + optional_ptr get; + if (stmt.from_table) { + auto from_binder = Binder::CreateBinder(context, this); + BoundJoinRef bound_crossproduct(JoinRefType::CROSS); + bound_crossproduct.left = std::move(bound_table); + bound_crossproduct.right = from_binder->Bind(*stmt.from_table); + root = CreatePlan(bound_crossproduct); + get = &root->children[0]->Cast(); + bind_context.AddContext(std::move(from_binder->bind_context)); + } else { + root = CreatePlan(*bound_table); + get = &root->Cast(); + } + + if (!table.temporary) { + // update of persistent table: not read only! + properties.modified_databases.insert(table.catalog.GetName()); + } + auto update = make_uniq(table); + + // set return_chunk boolean early because it needs uses update_is_del_and_insert logic + if (!stmt.returning_list.empty()) { + update->return_chunk = true; + } + // bind the default values + BindDefaultValues(table.GetColumns(), update->bound_defaults); + + // project any additional columns required for the condition/expressions + if (stmt.set_info->condition) { + WhereBinder binder(*this, context); + auto condition = binder.Bind(stmt.set_info->condition); + + PlanSubqueries(condition, root); + auto filter = make_uniq(std::move(condition)); + filter->AddChild(std::move(root)); + root = std::move(filter); + } + + D_ASSERT(stmt.set_info); + D_ASSERT(stmt.set_info->columns.size() == stmt.set_info->expressions.size()); + + auto proj_tmp = BindUpdateSet(*update, std::move(root), *stmt.set_info, table, update->columns); + D_ASSERT(proj_tmp->type == LogicalOperatorType::LOGICAL_PROJECTION); + auto proj = unique_ptr_cast(std::move(proj_tmp)); + + // bind any extra columns necessary for CHECK constraints or indexes + table.BindUpdateConstraints(*get, *proj, *update, context); + + // finally add the row id column to the projection list + proj->expressions.push_back(make_uniq( + LogicalType::ROW_TYPE, ColumnBinding(get->table_index, get->column_ids.size()))); + get->column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); + + // set the projection as child of the update node and finalize the result + update->AddChild(std::move(proj)); + + auto update_table_index = GenerateTableIndex(); + update->table_index = update_table_index; + if (!stmt.returning_list.empty()) { + unique_ptr update_as_logicaloperator = std::move(update); + + return BindReturning(std::move(stmt.returning_list), table, stmt.table->alias, update_table_index, + std::move(update_as_logicaloperator), std::move(result)); + } + + result.names = {"Count"}; + result.types = {LogicalType::BIGINT}; + result.plan = std::move(update); + properties.allow_stream_result = false; + properties.return_type = StatementReturnType::CHANGED_ROWS; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp new file mode 100644 index 00000000..fc36d55a --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp @@ -0,0 +1,95 @@ +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/statement/vacuum_statement.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(VacuumStatement &stmt) { + BoundStatement result; + + unique_ptr root; + + if (stmt.info->has_table) { + D_ASSERT(!stmt.info->table); + D_ASSERT(stmt.info->column_id_map.empty()); + auto bound_table = Bind(*stmt.info->ref); + if (bound_table->type != TableReferenceType::BASE_TABLE) { + throw InvalidInputException("Can only vacuum/analyze base tables!"); + } + auto ref = unique_ptr_cast(std::move(bound_table)); + auto &table = ref->table; + stmt.info->table = &table; + + auto &columns = stmt.info->columns; + vector> select_list; + if (columns.empty()) { + // Empty means ALL columns should be vacuumed/analyzed + auto &get = ref->get->Cast(); + columns.insert(columns.end(), get.names.begin(), get.names.end()); + } + + case_insensitive_set_t column_name_set; + vector non_generated_column_names; + for (auto &col_name : columns) { + if (column_name_set.count(col_name) > 0) { + throw BinderException("Vacuum the same column twice(same name in column name list)"); + } + column_name_set.insert(col_name); + if (!table.ColumnExists(col_name)) { + throw BinderException("Column with name \"%s\" does not exist", col_name); + } + auto &col = table.GetColumn(col_name); + // ignore generated column + if (col.Generated()) { + continue; + } + non_generated_column_names.push_back(col_name); + ColumnRefExpression colref(col_name, table.name); + auto result = bind_context.BindColumn(colref, 0); + if (result.HasError()) { + throw BinderException(result.error); + } + select_list.push_back(std::move(result.expression)); + } + stmt.info->columns = std::move(non_generated_column_names); + if (!select_list.empty()) { + auto table_scan = CreatePlan(*ref); + D_ASSERT(table_scan->type == LogicalOperatorType::LOGICAL_GET); + + auto &get = table_scan->Cast(); + + D_ASSERT(select_list.size() == get.column_ids.size()); + D_ASSERT(stmt.info->columns.size() == get.column_ids.size()); + for (idx_t i = 0; i < get.column_ids.size(); i++) { + stmt.info->column_id_map[i] = + table.GetColumns().LogicalToPhysical(LogicalIndex(get.column_ids[i])).index; + } + + auto projection = make_uniq(GenerateTableIndex(), std::move(select_list)); + projection->children.push_back(std::move(table_scan)); + + root = std::move(projection); + } else { + // eg. CREATE TABLE test (x AS (1)); + // ANALYZE test; + // Make it not a SINK so it doesn't have to do anything + stmt.info->has_table = false; + } + } + auto vacuum = make_uniq(LogicalOperatorType::LOGICAL_VACUUM, std::move(stmt.info)); + if (root) { + vacuum->children.push_back(std::move(root)); + } + + result.names = {"Success"}; + result.types = {LogicalType::BOOLEAN}; + result.plan = std::move(vacuum); + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp new file mode 100644 index 00000000..3a4a0d6d --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp @@ -0,0 +1,252 @@ +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_basetableref.hpp" +#include "duckdb/planner/tableref/bound_subqueryref.hpp" +#include "duckdb/planner/tableref/bound_cteref.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/planner/tableref/bound_dummytableref.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +static bool TryLoadExtensionForReplacementScan(ClientContext &context, const string &table_name) { + auto lower_name = StringUtil::Lower(table_name); + auto &dbconfig = DBConfig::GetConfig(context); + + if (!dbconfig.options.autoload_known_extensions) { + return false; + } + + for (const auto &entry : EXTENSION_FILE_POSTFIXES) { + if (StringUtil::EndsWith(lower_name, entry.name)) { + ExtensionHelper::AutoLoadExtension(context, entry.extension); + return true; + } + } + + for (const auto &entry : EXTENSION_FILE_CONTAINS) { + if (StringUtil::Contains(lower_name, entry.name)) { + ExtensionHelper::AutoLoadExtension(context, entry.extension); + return true; + } + } + + return false; +} + +unique_ptr Binder::BindWithReplacementScan(ClientContext &context, const string &table_name, + BaseTableRef &ref) { + auto &config = DBConfig::GetConfig(context); + if (context.config.use_replacement_scans) { + for (auto &scan : config.replacement_scans) { + auto replacement_function = scan.function(context, table_name, scan.data.get()); + if (replacement_function) { + if (!ref.alias.empty()) { + // user-provided alias overrides the default alias + replacement_function->alias = ref.alias; + } else if (replacement_function->alias.empty()) { + // if the replacement scan itself did not provide an alias we use the table name + replacement_function->alias = ref.table_name; + } + if (replacement_function->type == TableReferenceType::TABLE_FUNCTION) { + auto &table_function = replacement_function->Cast(); + table_function.column_name_alias = ref.column_name_alias; + } else if (replacement_function->type == TableReferenceType::SUBQUERY) { + auto &subquery = replacement_function->Cast(); + subquery.column_name_alias = ref.column_name_alias; + } else { + throw InternalException("Replacement scan should return either a table function or a subquery"); + } + return Bind(*replacement_function); + } + } + } + + return nullptr; +} + +unique_ptr Binder::Bind(BaseTableRef &ref) { + QueryErrorContext error_context(root_statement, ref.query_location); + // CTEs and views are also referred to using BaseTableRefs, hence need to distinguish here + // check if the table name refers to a CTE + + // CTE name should never be qualified (i.e. schema_name should be empty) + optional_ptr found_cte = nullptr; + if (ref.schema_name.empty()) { + found_cte = FindCTE(ref.table_name, ref.table_name == alias); + } + + if (found_cte) { + // Check if there is a CTE binding in the BindContext + auto &cte = *found_cte; + auto ctebinding = bind_context.GetCTEBinding(ref.table_name); + if (!ctebinding) { + if (CTEIsAlreadyBound(cte)) { + throw BinderException( + "Circular reference to CTE \"%s\", There are two possible solutions. \n1. use WITH RECURSIVE to " + "use recursive CTEs. \n2. If " + "you want to use the TABLE name \"%s\" the same as the CTE name, please explicitly add " + "\"SCHEMA\" before table name. You can try \"main.%s\" (main is the duckdb default schema)", + ref.table_name, ref.table_name, ref.table_name); + } + // Move CTE to subquery and bind recursively + SubqueryRef subquery(unique_ptr_cast(cte.query->Copy())); + subquery.alias = ref.alias.empty() ? ref.table_name : ref.alias; + subquery.column_name_alias = cte.aliases; + for (idx_t i = 0; i < ref.column_name_alias.size(); i++) { + if (i < subquery.column_name_alias.size()) { + subquery.column_name_alias[i] = ref.column_name_alias[i]; + } else { + subquery.column_name_alias.push_back(ref.column_name_alias[i]); + } + } + return Bind(subquery, found_cte); + } else { + // There is a CTE binding in the BindContext. + // This can only be the case if there is a recursive CTE, + // or a materialized CTE present. + auto index = GenerateTableIndex(); + auto materialized = cte.materialized; + if (materialized == CTEMaterialize::CTE_MATERIALIZE_DEFAULT) { +#ifdef DUCKDB_ALTERNATIVE_VERIFY + materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS; +#else + materialized = CTEMaterialize::CTE_MATERIALIZE_NEVER; +#endif + } + auto result = make_uniq(index, ctebinding->index, materialized); + auto b = ctebinding; + auto alias = ref.alias.empty() ? ref.table_name : ref.alias; + auto names = BindContext::AliasColumnNames(alias, b->names, ref.column_name_alias); + + bind_context.AddGenericBinding(index, alias, names, b->types); + // Update references to CTE + auto cteref = bind_context.cte_references[ref.table_name]; + (*cteref)++; + + result->types = b->types; + result->bound_columns = std::move(names); + return std::move(result); + } + } + // not a CTE + // extract a table or view from the catalog + BindSchemaOrCatalog(ref.catalog_name, ref.schema_name); + auto table_or_view = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, ref.catalog_name, ref.schema_name, + ref.table_name, OnEntryNotFound::RETURN_NULL, error_context); + // we still didn't find the table + if (GetBindingMode() == BindingMode::EXTRACT_NAMES) { + if (!table_or_view || table_or_view->type == CatalogType::TABLE_ENTRY) { + // if we are in EXTRACT_NAMES, we create a dummy table ref + AddTableName(ref.table_name); + + // add a bind context entry + auto table_index = GenerateTableIndex(); + auto alias = ref.alias.empty() ? ref.table_name : ref.alias; + vector types {LogicalType::INTEGER}; + vector names {"__dummy_col" + to_string(table_index)}; + bind_context.AddGenericBinding(table_index, alias, names, types); + return make_uniq_base(table_index); + } + } + if (!table_or_view) { + string table_name = ref.catalog_name; + if (!ref.schema_name.empty()) { + table_name += (!table_name.empty() ? "." : "") + ref.schema_name; + } + table_name += (!table_name.empty() ? "." : "") + ref.table_name; + // table could not be found: try to bind a replacement scan + // Try replacement scan bind + auto replacement_scan_bind_result = BindWithReplacementScan(context, table_name, ref); + if (replacement_scan_bind_result) { + return replacement_scan_bind_result; + } + + // Try autoloading an extension, then retry the replacement scan bind + auto extension_loaded = TryLoadExtensionForReplacementScan(context, table_name); + if (extension_loaded) { + replacement_scan_bind_result = BindWithReplacementScan(context, table_name, ref); + if (replacement_scan_bind_result) { + return replacement_scan_bind_result; + } + } + + // could not find an alternative: bind again to get the error + table_or_view = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, ref.catalog_name, ref.schema_name, + ref.table_name, OnEntryNotFound::THROW_EXCEPTION, error_context); + } + switch (table_or_view->type) { + case CatalogType::TABLE_ENTRY: { + // base table: create the BoundBaseTableRef node + auto table_index = GenerateTableIndex(); + auto &table = table_or_view->Cast(); + + unique_ptr bind_data; + auto scan_function = table.GetScanFunction(context, bind_data); + auto alias = ref.alias.empty() ? ref.table_name : ref.alias; + // TODO: bundle the type and name vector in a struct (e.g PackedColumnMetadata) + vector table_types; + vector table_names; + vector table_categories; + + vector return_types; + vector return_names; + for (auto &col : table.GetColumns().Logical()) { + table_types.push_back(col.Type()); + table_names.push_back(col.Name()); + return_types.push_back(col.Type()); + return_names.push_back(col.Name()); + } + table_names = BindContext::AliasColumnNames(alias, table_names, ref.column_name_alias); + + auto logical_get = make_uniq(table_index, scan_function, std::move(bind_data), + std::move(return_types), std::move(return_names)); + bind_context.AddBaseTable(table_index, alias, table_names, table_types, logical_get->column_ids, + logical_get->GetTable().get()); + return make_uniq_base(table, std::move(logical_get)); + } + case CatalogType::VIEW_ENTRY: { + // the node is a view: get the query that the view represents + auto &view_catalog_entry = table_or_view->Cast(); + // We need to use a new binder for the view that doesn't reference any CTEs + // defined for this binder so there are no collisions between the CTEs defined + // for the view and for the current query + bool inherit_ctes = false; + auto view_binder = Binder::CreateBinder(context, this, inherit_ctes); + view_binder->can_contain_nulls = true; + SubqueryRef subquery(unique_ptr_cast(view_catalog_entry.query->Copy())); + subquery.alias = ref.alias.empty() ? ref.table_name : ref.alias; + subquery.column_name_alias = + BindContext::AliasColumnNames(subquery.alias, view_catalog_entry.aliases, ref.column_name_alias); + // bind the child subquery + view_binder->AddBoundView(view_catalog_entry); + auto bound_child = view_binder->Bind(subquery); + if (!view_binder->correlated_columns.empty()) { + throw BinderException("Contents of view were altered - view bound correlated columns"); + } + + D_ASSERT(bound_child->type == TableReferenceType::SUBQUERY); + // verify that the types and names match up with the expected types and names + auto &bound_subquery = bound_child->Cast(); + if (GetBindingMode() != BindingMode::EXTRACT_NAMES && + bound_subquery.subquery->types != view_catalog_entry.types) { + throw BinderException("Contents of view were altered: types don't match!"); + } + bind_context.AddView(bound_subquery.subquery->GetRootIndex(), subquery.alias, subquery, + *bound_subquery.subquery, &view_catalog_entry); + return bound_child; + } + default: + throw InternalException("Catalog entry type"); + } +} +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp new file mode 100644 index 00000000..fe0e96f3 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp @@ -0,0 +1,11 @@ +#include "duckdb/parser/tableref/emptytableref.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_dummytableref.hpp" + +namespace duckdb { + +unique_ptr Binder::Bind(EmptyTableRef &ref) { + return make_uniq(GenerateTableIndex()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp new file mode 100644 index 00000000..a506764b --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp @@ -0,0 +1,65 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_expressionlistref.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/planner/expression_binder/insert_binder.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +unique_ptr Binder::Bind(ExpressionListRef &expr) { + auto result = make_uniq(); + result->types = expr.expected_types; + result->names = expr.expected_names; + // bind value list + InsertBinder binder(*this, context); + binder.target_type = LogicalType(LogicalTypeId::INVALID); + for (idx_t list_idx = 0; list_idx < expr.values.size(); list_idx++) { + auto &expression_list = expr.values[list_idx]; + if (result->names.empty()) { + // no names provided, generate them + for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { + result->names.push_back("col" + to_string(val_idx)); + } + } + + vector> list; + for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { + if (!result->types.empty()) { + D_ASSERT(result->types.size() == expression_list.size()); + binder.target_type = result->types[val_idx]; + } + auto expr = binder.Bind(expression_list[val_idx]); + list.push_back(std::move(expr)); + } + result->values.push_back(std::move(list)); + } + if (result->types.empty() && !expr.values.empty()) { + // there are no types specified + // we have to figure out the result types + // for each column, we iterate over all of the expressions and select the max logical type + // we initialize all types to SQLNULL + result->types.resize(expr.values[0].size(), LogicalType::SQLNULL); + // now loop over the lists and select the max logical type + for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { + auto &list = result->values[list_idx]; + for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { + result->types[val_idx] = + LogicalType::MaxLogicalType(result->types[val_idx], list[val_idx]->return_type); + } + } + // finally do another loop over the expressions and add casts where required + for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { + auto &list = result->values[list_idx]; + for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { + list[val_idx] = + BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result->types[val_idx]); + } + } + } + result->bind_index = GenerateTableIndex(); + bind_context.AddGenericBinding(result->bind_index, expr.alias, result->names, result->types); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp new file mode 100644 index 00000000..a8a0c9ed --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp @@ -0,0 +1,320 @@ +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder/where_binder.hpp" +#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/planner/expression_binder/lateral_binder.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" + +namespace duckdb { + +static unique_ptr BindColumn(Binder &binder, ClientContext &context, const string &alias, + const string &column_name) { + auto expr = make_uniq_base(column_name, alias); + ExpressionBinder expr_binder(binder, context); + auto result = expr_binder.Bind(expr); + return make_uniq(std::move(result)); +} + +static unique_ptr AddCondition(ClientContext &context, Binder &left_binder, Binder &right_binder, + const string &left_alias, const string &right_alias, + const string &column_name, ExpressionType type) { + ExpressionBinder expr_binder(left_binder, context); + auto left = BindColumn(left_binder, context, left_alias, column_name); + auto right = BindColumn(right_binder, context, right_alias, column_name); + return make_uniq(type, std::move(left), std::move(right)); +} + +bool Binder::TryFindBinding(const string &using_column, const string &join_side, string &result) { + // for each using column, get the matching binding + auto bindings = bind_context.GetMatchingBindings(using_column); + if (bindings.empty()) { + return false; + } + // find the join binding + for (auto &binding : bindings) { + if (!result.empty()) { + string error = "Column name \""; + error += using_column; + error += "\" is ambiguous: it exists more than once on "; + error += join_side; + error += " side of join.\nCandidates:"; + for (auto &binding : bindings) { + error += "\n\t"; + error += binding; + error += "."; + error += bind_context.GetActualColumnName(binding, using_column); + } + throw BinderException(error); + } else { + result = binding; + } + } + return true; +} + +string Binder::FindBinding(const string &using_column, const string &join_side) { + string result; + if (!TryFindBinding(using_column, join_side, result)) { + throw BinderException("Column \"%s\" does not exist on %s side of join!", using_column, join_side); + } + return result; +} + +static void AddUsingBindings(UsingColumnSet &set, optional_ptr input_set, const string &input_binding) { + if (input_set) { + for (auto &entry : input_set->bindings) { + set.bindings.insert(entry); + } + } else { + set.bindings.insert(input_binding); + } +} + +static void SetPrimaryBinding(UsingColumnSet &set, JoinType join_type, const string &left_binding, + const string &right_binding) { + switch (join_type) { + case JoinType::LEFT: + case JoinType::INNER: + case JoinType::SEMI: + case JoinType::ANTI: + set.primary_binding = left_binding; + break; + case JoinType::RIGHT: + set.primary_binding = right_binding; + break; + default: + break; + } +} + +string Binder::RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, + const string &using_column, const string &join_side) { + string binding; + if (!current_set) { + binding = current_binder.FindBinding(using_column, join_side); + } else { + binding = current_set->primary_binding; + } + return binding; +} + +static vector RemoveDuplicateUsingColumns(const vector &using_columns) { + vector result; + case_insensitive_set_t handled_columns; + for (auto &using_column : using_columns) { + if (handled_columns.find(using_column) == handled_columns.end()) { + handled_columns.insert(using_column); + result.push_back(using_column); + } + } + return result; +} + +unique_ptr Binder::Bind(JoinRef &ref) { + auto result = make_uniq(ref.ref_type); + result->left_binder = Binder::CreateBinder(context, this); + result->right_binder = Binder::CreateBinder(context, this); + auto &left_binder = *result->left_binder; + auto &right_binder = *result->right_binder; + + result->type = ref.type; + result->left = left_binder.Bind(*ref.left); + { + LateralBinder binder(left_binder, context); + result->right = right_binder.Bind(*ref.right); + bool is_lateral = false; + // Store the correlated columns in the right binder in bound ref for planning of LATERALs + // Ignore the correlated columns in the left binder, flattening handles those correlations + result->correlated_columns = right_binder.correlated_columns; + // Find correlations for the current join + for (auto &cor_col : result->correlated_columns) { + if (cor_col.depth == 1) { + // Depth 1 indicates columns binding from the left indicating a lateral join + is_lateral = true; + break; + } + } + result->lateral = is_lateral; + if (result->lateral) { + // lateral join: can only be an INNER or LEFT join + if (ref.type != JoinType::INNER && ref.type != JoinType::LEFT) { + throw BinderException("The combining JOIN type must be INNER or LEFT for a LATERAL reference"); + } + } + } + + vector> extra_conditions; + vector extra_using_columns; + switch (ref.ref_type) { + case JoinRefType::NATURAL: { + // natural join, figure out which column names are present in both sides of the join + // first bind the left hand side and get a list of all the tables and column names + case_insensitive_set_t lhs_columns; + auto &lhs_binding_list = left_binder.bind_context.GetBindingsList(); + for (auto &binding : lhs_binding_list) { + for (auto &column_name : binding.get().names) { + lhs_columns.insert(column_name); + } + } + // now bind the rhs + for (auto &column_name : lhs_columns) { + auto right_using_binding = right_binder.bind_context.GetUsingBinding(column_name); + + string right_binding; + // loop over the set of lhs columns, and figure out if there is a table in the rhs with the same name + if (!right_using_binding) { + if (!right_binder.TryFindBinding(column_name, "right", right_binding)) { + // no match found for this column on the rhs: skip + continue; + } + } + extra_using_columns.push_back(column_name); + } + if (extra_using_columns.empty()) { + // no matching bindings found in natural join: throw an exception + string error_msg = "No columns found to join on in NATURAL JOIN.\n"; + error_msg += "Use CROSS JOIN if you intended for this to be a cross-product."; + // gather all left/right candidates + string left_candidates, right_candidates; + auto &rhs_binding_list = right_binder.bind_context.GetBindingsList(); + for (auto &binding_ref : lhs_binding_list) { + auto &binding = binding_ref.get(); + for (auto &column_name : binding.names) { + if (!left_candidates.empty()) { + left_candidates += ", "; + } + left_candidates += binding.alias + "." + column_name; + } + } + for (auto &binding_ref : rhs_binding_list) { + auto &binding = binding_ref.get(); + for (auto &column_name : binding.names) { + if (!right_candidates.empty()) { + right_candidates += ", "; + } + right_candidates += binding.alias + "." + column_name; + } + } + error_msg += "\n Left candidates: " + left_candidates; + error_msg += "\n Right candidates: " + right_candidates; + throw BinderException(FormatError(ref, error_msg)); + } + break; + } + case JoinRefType::REGULAR: + case JoinRefType::ASOF: + if (!ref.using_columns.empty()) { + // USING columns + D_ASSERT(!result->condition); + extra_using_columns = ref.using_columns; + } + break; + + case JoinRefType::CROSS: + case JoinRefType::POSITIONAL: + case JoinRefType::DEPENDENT: + break; + } + extra_using_columns = RemoveDuplicateUsingColumns(extra_using_columns); + + if (!extra_using_columns.empty()) { + vector> left_using_bindings; + vector> right_using_bindings; + for (idx_t i = 0; i < extra_using_columns.size(); i++) { + auto &using_column = extra_using_columns[i]; + // we check if there is ALREADY a using column of the same name in the left and right set + // this can happen if we chain USING clauses + // e.g. x JOIN y USING (c) JOIN z USING (c) + auto left_using_binding = left_binder.bind_context.GetUsingBinding(using_column); + auto right_using_binding = right_binder.bind_context.GetUsingBinding(using_column); + if (!left_using_binding) { + left_binder.bind_context.GetMatchingBinding(using_column); + } + if (!right_using_binding) { + right_binder.bind_context.GetMatchingBinding(using_column); + } + left_using_bindings.push_back(left_using_binding); + right_using_bindings.push_back(right_using_binding); + } + + for (idx_t i = 0; i < extra_using_columns.size(); i++) { + auto &using_column = extra_using_columns[i]; + string left_binding; + string right_binding; + + auto set = make_uniq(); + auto &left_using_binding = left_using_bindings[i]; + auto &right_using_binding = right_using_bindings[i]; + left_binding = RetrieveUsingBinding(left_binder, left_using_binding, using_column, "left"); + right_binding = RetrieveUsingBinding(right_binder, right_using_binding, using_column, "right"); + + // Last column of ASOF JOIN ... USING is >= + const auto type = (ref.ref_type == JoinRefType::ASOF && i == extra_using_columns.size() - 1) + ? ExpressionType::COMPARE_GREATERTHANOREQUALTO + : ExpressionType::COMPARE_EQUAL; + + extra_conditions.push_back( + AddCondition(context, left_binder, right_binder, left_binding, right_binding, using_column, type)); + + AddUsingBindings(*set, left_using_binding, left_binding); + AddUsingBindings(*set, right_using_binding, right_binding); + SetPrimaryBinding(*set, ref.type, left_binding, right_binding); + bind_context.TransferUsingBinding(left_binder.bind_context, left_using_binding, *set, left_binding, + using_column); + bind_context.TransferUsingBinding(right_binder.bind_context, right_using_binding, *set, right_binding, + using_column); + AddUsingBindingSet(std::move(set)); + } + } + + auto right_bindings_list_copy = right_binder.bind_context.GetBindingsList(); + + bind_context.AddContext(std::move(left_binder.bind_context)); + bind_context.AddContext(std::move(right_binder.bind_context)); + + // Update the correlated columns for the parent binder + // For the left binder, depth >= 1 indicates correlations from the parent binder + for (const auto &col : left_binder.correlated_columns) { + if (col.depth >= 1) { + AddCorrelatedColumn(col); + } + } + // For the right binder, depth > 1 indicates correlations from the parent binder + // (depth = 1 indicates correlations from the left side of the join) + for (auto col : right_binder.correlated_columns) { + if (col.depth > 1) { + // Decrement the depth to account for the effect of the lateral binder + col.depth--; + AddCorrelatedColumn(col); + } + } + + for (auto &condition : extra_conditions) { + if (ref.condition) { + ref.condition = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(ref.condition), + std::move(condition)); + } else { + ref.condition = std::move(condition); + } + } + if (ref.condition) { + WhereBinder binder(*this, context); + result->condition = binder.Bind(ref.condition); + } + + if (result->type == JoinType::SEMI || result->type == JoinType::ANTI) { + bind_context.RemoveContext(right_bindings_list_copy); + } + + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_named_parameters.cpp b/src/duckdb/src/planner/binder/tableref/bind_named_parameters.cpp new file mode 100644 index 00000000..e7d9fa1e --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_named_parameters.cpp @@ -0,0 +1,34 @@ +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +void Binder::BindNamedParameters(named_parameter_type_map_t &types, named_parameter_map_t &values, + QueryErrorContext &error_context, string &func_name) { + for (auto &kv : values) { + auto entry = types.find(kv.first); + if (entry == types.end()) { + // create a list of named parameters for the error + string named_params; + for (auto &kv : types) { + named_params += " "; + named_params += kv.first; + named_params += " "; + named_params += kv.second.ToString(); + named_params += "\n"; + } + string error_msg; + if (named_params.empty()) { + error_msg = "Function does not accept any named parameters."; + } else { + error_msg = "Candidates:\n" + named_params; + } + throw BinderException(error_context.FormatError("Invalid named parameter \"%s\" for function %s\n%s", + kv.first, func_name, error_msg)); + } + if (entry->second.id() != LogicalTypeId::ANY) { + kv.second = kv.second.DefaultCastAs(entry->second); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp new file mode 100644 index 00000000..ed474f24 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -0,0 +1,662 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/planner/tableref/bound_subqueryref.hpp" +#include "duckdb/planner/tableref/bound_pivotref.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" + +namespace duckdb { + +static void ConstructPivots(PivotRef &ref, vector &pivot_values, idx_t pivot_idx = 0, + const PivotValueElement ¤t_value = PivotValueElement()) { + auto &pivot = ref.pivots[pivot_idx]; + bool last_pivot = pivot_idx + 1 == ref.pivots.size(); + for (auto &entry : pivot.entries) { + PivotValueElement new_value = current_value; + string name = entry.alias; + D_ASSERT(entry.values.size() == pivot.pivot_expressions.size()); + for (idx_t v = 0; v < entry.values.size(); v++) { + auto &value = entry.values[v]; + new_value.values.push_back(value); + if (entry.alias.empty()) { + if (name.empty()) { + name = value.ToString(); + } else { + name += "_" + value.ToString(); + } + } + } + if (!current_value.name.empty()) { + new_value.name = current_value.name + "_" + name; + } else { + new_value.name = std::move(name); + } + if (last_pivot) { + pivot_values.push_back(std::move(new_value)); + } else { + // need to recurse + ConstructPivots(ref, pivot_values, pivot_idx + 1, new_value); + } + } +} + +static void ExtractPivotExpressions(ParsedExpression &expr, case_insensitive_set_t &handled_columns) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto &child_colref = expr.Cast(); + if (child_colref.IsQualified()) { + throw BinderException("PIVOT expression cannot contain qualified columns"); + } + handled_columns.insert(child_colref.GetColumnName()); + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](ParsedExpression &child) { ExtractPivotExpressions(child, handled_columns); }); +} + +static unique_ptr ConstructInitialGrouping(PivotRef &ref, vector> all_columns, + const case_insensitive_set_t &handled_columns) { + auto subquery = make_uniq(); + subquery->from_table = std::move(ref.source); + if (ref.groups.empty()) { + // if rows are not specified any columns that are not pivoted/aggregated on are added to the GROUP BY clause + for (auto &entry : all_columns) { + if (entry->type != ExpressionType::COLUMN_REF) { + throw InternalException("Unexpected child of pivot source - not a ColumnRef"); + } + auto &columnref = entry->Cast(); + if (handled_columns.find(columnref.GetColumnName()) == handled_columns.end()) { + // not handled - add to grouping set + subquery->groups.group_expressions.push_back( + make_uniq(Value::INTEGER(subquery->select_list.size() + 1))); + subquery->select_list.push_back(make_uniq(columnref.GetColumnName())); + } + } + } else { + // if rows are specified only the columns mentioned in rows are added as groups + for (auto &row : ref.groups) { + subquery->groups.group_expressions.push_back( + make_uniq(Value::INTEGER(subquery->select_list.size() + 1))); + subquery->select_list.push_back(make_uniq(row)); + } + } + return subquery; +} + +static unique_ptr PivotFilteredAggregate(PivotRef &ref, vector> all_columns, + const case_insensitive_set_t &handled_columns, + vector pivot_values) { + auto subquery = ConstructInitialGrouping(ref, std::move(all_columns), handled_columns); + + // push the filtered aggregates + for (auto &pivot_value : pivot_values) { + unique_ptr filter; + idx_t pivot_value_idx = 0; + for (auto &pivot_column : ref.pivots) { + for (auto &pivot_expr : pivot_column.pivot_expressions) { + auto column_ref = make_uniq(LogicalType::VARCHAR, pivot_expr->Copy()); + auto constant_value = make_uniq(pivot_value.values[pivot_value_idx++]); + auto comp_expr = make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, + std::move(column_ref), std::move(constant_value)); + if (filter) { + filter = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(filter), + std::move(comp_expr)); + } else { + filter = std::move(comp_expr); + } + } + } + for (auto &aggregate : ref.aggregates) { + auto copied_aggr = aggregate->Copy(); + auto &aggr = copied_aggr->Cast(); + aggr.filter = filter->Copy(); + auto &aggr_name = aggregate->alias; + auto name = pivot_value.name; + if (ref.aggregates.size() > 1 || !aggr_name.empty()) { + // if there are multiple aggregates specified we add the name of the aggregate as well + name += "_" + (aggr_name.empty() ? aggregate->GetName() : aggr_name); + } + aggr.alias = name; + subquery->select_list.push_back(std::move(copied_aggr)); + } + } + return subquery; +} + +struct PivotBindState { + vector internal_group_names; + vector group_names; + vector aggregate_names; + vector internal_aggregate_names; +}; + +static unique_ptr PivotInitialAggregate(PivotBindState &bind_state, PivotRef &ref, + vector> all_columns, + const case_insensitive_set_t &handled_columns) { + auto subquery_stage1 = ConstructInitialGrouping(ref, std::move(all_columns), handled_columns); + + idx_t group_count = 0; + for (auto &expr : subquery_stage1->select_list) { + bind_state.group_names.push_back(expr->GetName()); + if (expr->alias.empty()) { + expr->alias = "__internal_pivot_group" + std::to_string(++group_count); + } + bind_state.internal_group_names.push_back(expr->alias); + } + // group by all of the pivot values + idx_t pivot_count = 0; + for (auto &pivot_column : ref.pivots) { + for (auto &pivot_expr : pivot_column.pivot_expressions) { + if (pivot_expr->alias.empty()) { + pivot_expr->alias = "__internal_pivot_ref" + std::to_string(++pivot_count); + } + auto pivot_alias = pivot_expr->alias; + subquery_stage1->groups.group_expressions.push_back( + make_uniq(Value::INTEGER(subquery_stage1->select_list.size() + 1))); + subquery_stage1->select_list.push_back(std::move(pivot_expr)); + pivot_expr = make_uniq(std::move(pivot_alias)); + } + } + idx_t aggregate_count = 0; + // finally add the aggregates + for (auto &aggregate : ref.aggregates) { + auto aggregate_alias = "__internal_pivot_aggregate" + std::to_string(++aggregate_count); + bind_state.aggregate_names.push_back(aggregate->alias); + bind_state.internal_aggregate_names.push_back(aggregate_alias); + aggregate->alias = std::move(aggregate_alias); + subquery_stage1->select_list.push_back(std::move(aggregate)); + } + return subquery_stage1; +} + +unique_ptr ConstructPivotExpression(unique_ptr pivot_expr) { + auto cast = make_uniq(LogicalType::VARCHAR, std::move(pivot_expr)); + vector> coalesce_children; + coalesce_children.push_back(std::move(cast)); + coalesce_children.push_back(make_uniq(Value("NULL"))); + auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE, std::move(coalesce_children)); + return std::move(coalesce); +} + +static unique_ptr PivotListAggregate(PivotBindState &bind_state, PivotRef &ref, + unique_ptr subquery_stage1) { + auto subquery_stage2 = make_uniq(); + // wrap the subquery of stage 1 + auto subquery_select = make_uniq(); + subquery_select->node = std::move(subquery_stage1); + auto subquery_ref = make_uniq(std::move(subquery_select)); + + // add all of the groups + for (idx_t gr = 0; gr < bind_state.internal_group_names.size(); gr++) { + subquery_stage2->groups.group_expressions.push_back( + make_uniq(Value::INTEGER(subquery_stage2->select_list.size() + 1))); + auto group_reference = make_uniq(bind_state.internal_group_names[gr]); + group_reference->alias = bind_state.internal_group_names[gr]; + subquery_stage2->select_list.push_back(std::move(group_reference)); + } + + // construct the list aggregates + for (idx_t aggr = 0; aggr < bind_state.internal_aggregate_names.size(); aggr++) { + auto colref = make_uniq(bind_state.internal_aggregate_names[aggr]); + vector> list_children; + list_children.push_back(std::move(colref)); + auto aggregate = make_uniq("list", std::move(list_children)); + aggregate->alias = bind_state.internal_aggregate_names[aggr]; + subquery_stage2->select_list.push_back(std::move(aggregate)); + } + // construct the pivot list + auto pivot_name = "__internal_pivot_name"; + unique_ptr expr; + for (auto &pivot : ref.pivots) { + for (auto &pivot_expr : pivot.pivot_expressions) { + // coalesce(pivot::VARCHAR, 'NULL') + auto coalesce = ConstructPivotExpression(std::move(pivot_expr)); + if (!expr) { + expr = std::move(coalesce); + } else { + // string concat + vector> concat_children; + concat_children.push_back(std::move(expr)); + concat_children.push_back(make_uniq(Value("_"))); + concat_children.push_back(std::move(coalesce)); + auto concat = make_uniq("concat", std::move(concat_children)); + expr = std::move(concat); + } + } + } + // list(coalesce) + vector> list_children; + list_children.push_back(std::move(expr)); + auto aggregate = make_uniq("list", std::move(list_children)); + + aggregate->alias = pivot_name; + subquery_stage2->select_list.push_back(std::move(aggregate)); + + subquery_stage2->from_table = std::move(subquery_ref); + return subquery_stage2; +} + +static unique_ptr PivotFinalOperator(PivotBindState &bind_state, PivotRef &ref, + unique_ptr subquery, + vector pivot_values) { + auto final_pivot_operator = make_uniq(); + // wrap the subquery of stage 1 + auto subquery_select = make_uniq(); + subquery_select->node = std::move(subquery); + auto subquery_ref = make_uniq(std::move(subquery_select)); + + auto bound_pivot = make_uniq(); + bound_pivot->bound_pivot_values = std::move(pivot_values); + bound_pivot->bound_group_names = std::move(bind_state.group_names); + bound_pivot->bound_aggregate_names = std::move(bind_state.aggregate_names); + bound_pivot->source = std::move(subquery_ref); + + final_pivot_operator->select_list.push_back(make_uniq()); + final_pivot_operator->from_table = std::move(bound_pivot); + return final_pivot_operator; +} + +void ExtractPivotAggregates(BoundTableRef &node, vector> &aggregates) { + if (node.type != TableReferenceType::SUBQUERY) { + throw InternalException("Pivot - Expected a subquery"); + } + auto &subq = node.Cast(); + if (subq.subquery->type != QueryNodeType::SELECT_NODE) { + throw InternalException("Pivot - Expected a select node"); + } + auto &select = subq.subquery->Cast(); + if (select.from_table->type != TableReferenceType::SUBQUERY) { + throw InternalException("Pivot - Expected another subquery"); + } + auto &subq2 = select.from_table->Cast(); + if (subq2.subquery->type != QueryNodeType::SELECT_NODE) { + throw InternalException("Pivot - Expected another select node"); + } + auto &select2 = subq2.subquery->Cast(); + for (auto &aggr : select2.aggregates) { + aggregates.push_back(aggr->Copy()); + } +} + +unique_ptr Binder::BindBoundPivot(PivotRef &ref) { + // bind the child table in a child binder + auto result = make_uniq(); + result->bind_index = GenerateTableIndex(); + result->child_binder = Binder::CreateBinder(context, this); + result->child = result->child_binder->Bind(*ref.source); + + auto &aggregates = result->bound_pivot.aggregates; + ExtractPivotAggregates(*result->child, aggregates); + if (aggregates.size() != ref.bound_aggregate_names.size()) { + throw InternalException("Pivot aggregate count mismatch (expected %llu, found %llu)", + ref.bound_aggregate_names.size(), aggregates.size()); + } + + vector child_names; + vector child_types; + result->child_binder->bind_context.GetTypesAndNames(child_names, child_types); + + vector names; + vector types; + // emit the groups + for (idx_t i = 0; i < ref.bound_group_names.size(); i++) { + names.push_back(ref.bound_group_names[i]); + types.push_back(child_types[i]); + } + // emit the pivot columns + for (auto &pivot_value : ref.bound_pivot_values) { + for (idx_t aggr_idx = 0; aggr_idx < ref.bound_aggregate_names.size(); aggr_idx++) { + auto &aggr = aggregates[aggr_idx]; + auto &aggr_name = ref.bound_aggregate_names[aggr_idx]; + auto name = pivot_value.name; + if (aggregates.size() > 1 || !aggr_name.empty()) { + // if there are multiple aggregates specified we add the name of the aggregate as well + name += "_" + (aggr_name.empty() ? aggr->GetName() : aggr_name); + } + string pivot_str; + for (auto &value : pivot_value.values) { + auto str = value.ToString(); + if (pivot_str.empty()) { + pivot_str = std::move(str); + } else { + pivot_str += "_" + str; + } + } + result->bound_pivot.pivot_values.push_back(std::move(pivot_str)); + names.push_back(std::move(name)); + types.push_back(aggr->return_type); + } + } + result->bound_pivot.group_count = ref.bound_group_names.size(); + result->bound_pivot.types = types; + auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; + bind_context.AddGenericBinding(result->bind_index, subquery_alias, names, types); + MoveCorrelatedExpressions(*result->child_binder); + return std::move(result); +} + +unique_ptr Binder::BindPivot(PivotRef &ref, vector> all_columns) { + // keep track of the columns by which we pivot/aggregate + // any columns which are not pivoted/aggregated on are added to the GROUP BY clause + case_insensitive_set_t handled_columns; + // parse the aggregate, and extract the referenced columns from the aggregate + for (auto &aggr : ref.aggregates) { + if (aggr->type != ExpressionType::FUNCTION) { + throw BinderException(FormatError(*aggr, "Pivot expression must be an aggregate")); + } + if (aggr->HasSubquery()) { + throw BinderException(FormatError(*aggr, "Pivot expression cannot contain subqueries")); + } + if (aggr->IsWindow()) { + throw BinderException(FormatError(*aggr, "Pivot expression cannot contain window functions")); + } + // bind the function as an aggregate to ensure it is an aggregate and not a scalar function + auto &aggr_function = aggr->Cast(); + (void)Catalog::GetEntry(context, aggr_function.catalog, aggr_function.schema, + aggr_function.function_name); + ExtractPivotExpressions(*aggr, handled_columns); + } + + // first add all pivots to the set of handled columns, and check for duplicates + idx_t total_pivots = 1; + for (auto &pivot : ref.pivots) { + if (!pivot.pivot_enum.empty()) { + auto type = Catalog::GetType(context, INVALID_CATALOG, INVALID_SCHEMA, pivot.pivot_enum); + if (type.id() != LogicalTypeId::ENUM) { + throw BinderException( + FormatError(ref, StringUtil::Format("Pivot must reference an ENUM type: \"%s\" is of type \"%s\"", + pivot.pivot_enum, type.ToString()))); + } + auto enum_size = EnumType::GetSize(type); + for (idx_t i = 0; i < enum_size; i++) { + auto enum_value = EnumType::GetValue(Value::ENUM(i, type)); + PivotColumnEntry entry; + entry.values.emplace_back(enum_value); + entry.alias = std::move(enum_value); + pivot.entries.push_back(std::move(entry)); + } + } + total_pivots *= pivot.entries.size(); + // add the pivoted column to the columns that have been handled + for (auto &pivot_name : pivot.pivot_expressions) { + ExtractPivotExpressions(*pivot_name, handled_columns); + } + value_set_t pivots; + for (auto &entry : pivot.entries) { + D_ASSERT(!entry.star_expr); + Value val; + if (entry.values.size() == 1) { + val = entry.values[0]; + } else { + val = Value::LIST(LogicalType::VARCHAR, entry.values); + } + if (pivots.find(val) != pivots.end()) { + throw BinderException(FormatError( + ref, StringUtil::Format("The value \"%s\" was specified multiple times in the IN clause", + val.ToString()))); + } + if (entry.values.size() != pivot.pivot_expressions.size()) { + throw ParserException("PIVOT IN list - inconsistent amount of rows - expected %d but got %d", + pivot.pivot_expressions.size(), entry.values.size()); + } + pivots.insert(val); + } + } + auto &client_config = ClientConfig::GetConfig(context); + auto pivot_limit = client_config.pivot_limit; + if (total_pivots >= pivot_limit) { + throw BinderException("Pivot column limit of %llu exceeded. Use SET pivot_limit=X to increase the limit.", + client_config.pivot_limit); + } + + // construct the required pivot values recursively + vector pivot_values; + ConstructPivots(ref, pivot_values); + + unique_ptr pivot_node; + // pivots have three components + // - the pivots (i.e. future column names) + // - the groups (i.e. the future row names + // - the aggregates (i.e. the values of the pivot columns) + + // we have two ways of executing a pivot statement + // (1) the straightforward manner of filtered aggregates SUM(..) FILTER (pivot_value=X) + // (2) computing the aggregates once, then using LIST to group the aggregates together with the PIVOT operator + // -> filtered aggregates are faster when there are FEW pivot values + // -> LIST is faster when there are MANY pivot values + // we switch dynamically based on the number of pivots to compute + if (pivot_values.size() <= client_config.pivot_filter_threshold) { + // use a set of filtered aggregates + pivot_node = PivotFilteredAggregate(ref, std::move(all_columns), handled_columns, std::move(pivot_values)); + } else { + // executing a pivot statement happens in three stages + // 1) execute the query "SELECT {groups}, {pivots}, {aggregates} FROM {from_clause} GROUP BY {groups}, {pivots} + // this computes all values that are required in the final result, but not yet in the correct orientation + // 2) execute the query "SELECT {groups}, LIST({pivots}), LIST({aggregates}) FROM [Q1] GROUP BY {groups} + // this pushes all pivots and aggregates that belong to a specific group together in an aligned manner + // 3) push a PIVOT operator, that performs the actual pivoting of the values into the different columns + + PivotBindState bind_state; + // Pivot Stage 1 + // SELECT {groups}, {pivots}, {aggregates} FROM {from_clause} GROUP BY {groups}, {pivots} + auto subquery_stage1 = PivotInitialAggregate(bind_state, ref, std::move(all_columns), handled_columns); + + // Pivot stage 2 + // SELECT {groups}, LIST({pivots}), LIST({aggregates}) FROM [Q1] GROUP BY {groups} + auto subquery_stage2 = PivotListAggregate(bind_state, ref, std::move(subquery_stage1)); + + // Pivot stage 3 + // construct the final pivot operator + pivot_node = PivotFinalOperator(bind_state, ref, std::move(subquery_stage2), std::move(pivot_values)); + } + return pivot_node; +} + +unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, + vector> all_columns, + unique_ptr &where_clause) { + D_ASSERT(ref.groups.empty()); + D_ASSERT(ref.pivots.size() == 1); + + unique_ptr expr; + auto select_node = make_uniq(); + select_node->from_table = std::move(ref.source); + + // handle the pivot + auto &unpivot = ref.pivots[0]; + + // handle star expressions in any entries + vector new_entries; + for (auto &entry : unpivot.entries) { + if (entry.star_expr) { + D_ASSERT(entry.values.empty()); + vector> star_columns; + child_binder.ExpandStarExpression(std::move(entry.star_expr), star_columns); + + for (auto &col : star_columns) { + if (col->type != ExpressionType::COLUMN_REF) { + throw InternalException("Unexpected child of unpivot star - not a ColumnRef"); + } + auto &columnref = col->Cast(); + PivotColumnEntry new_entry; + new_entry.values.emplace_back(columnref.GetColumnName()); + new_entry.alias = columnref.GetColumnName(); + new_entries.push_back(std::move(new_entry)); + } + } else { + new_entries.push_back(std::move(entry)); + } + } + unpivot.entries = std::move(new_entries); + + case_insensitive_set_t handled_columns; + case_insensitive_map_t name_map; + for (auto &entry : unpivot.entries) { + for (auto &value : entry.values) { + handled_columns.insert(value.ToString()); + } + } + + for (auto &col_expr : all_columns) { + if (col_expr->type != ExpressionType::COLUMN_REF) { + throw InternalException("Unexpected child of pivot source - not a ColumnRef"); + } + auto &columnref = col_expr->Cast(); + auto &column_name = columnref.GetColumnName(); + auto entry = handled_columns.find(column_name); + if (entry == handled_columns.end()) { + // not handled - add to the set of regularly selected columns + select_node->select_list.push_back(std::move(col_expr)); + } else { + name_map[column_name] = column_name; + handled_columns.erase(entry); + } + } + if (!handled_columns.empty()) { + for (auto &entry : handled_columns) { + throw BinderException("Column \"%s\" referenced in UNPIVOT but no matching entry was found in the table", + entry); + } + } + vector unpivot_names; + for (auto &entry : unpivot.entries) { + string generated_name; + for (auto &val : entry.values) { + auto name_entry = name_map.find(val.ToString()); + if (name_entry == name_map.end()) { + throw InternalException("Unpivot - could not find column name in name map"); + } + if (!generated_name.empty()) { + generated_name += "_"; + } + generated_name += name_entry->second; + } + unpivot_names.emplace_back(!entry.alias.empty() ? entry.alias : generated_name); + } + vector>> unpivot_expressions; + for (idx_t v_idx = 1; v_idx < unpivot.entries.size(); v_idx++) { + if (unpivot.entries[v_idx].values.size() != unpivot.entries[0].values.size()) { + throw BinderException( + "UNPIVOT value count mismatch - entry has %llu values, but expected all entries to have %llu values", + unpivot.entries[v_idx].values.size(), unpivot.entries[0].values.size()); + } + } + + for (idx_t v_idx = 0; v_idx < unpivot.entries[0].values.size(); v_idx++) { + vector> expressions; + expressions.reserve(unpivot.entries.size()); + for (auto &entry : unpivot.entries) { + expressions.push_back(make_uniq(entry.values[v_idx].ToString())); + } + unpivot_expressions.push_back(std::move(expressions)); + } + + // construct the UNNEST expression for the set of names (constant) + auto unpivot_list = Value::LIST(LogicalType::VARCHAR, std::move(unpivot_names)); + auto unpivot_name_expr = make_uniq(std::move(unpivot_list)); + vector> unnest_name_children; + unnest_name_children.push_back(std::move(unpivot_name_expr)); + auto unnest_name_expr = make_uniq("unnest", std::move(unnest_name_children)); + unnest_name_expr->alias = unpivot.unpivot_names[0]; + select_node->select_list.push_back(std::move(unnest_name_expr)); + + // construct the UNNEST expression for the set of unpivoted columns + if (ref.unpivot_names.size() != unpivot_expressions.size()) { + throw BinderException("UNPIVOT name count mismatch - got %d names but %d expressions", ref.unpivot_names.size(), + unpivot_expressions.size()); + } + for (idx_t i = 0; i < unpivot_expressions.size(); i++) { + auto list_expr = make_uniq("list_value", std::move(unpivot_expressions[i])); + vector> unnest_val_children; + unnest_val_children.push_back(std::move(list_expr)); + auto unnest_val_expr = make_uniq("unnest", std::move(unnest_val_children)); + auto unnest_name = i < ref.column_name_alias.size() ? ref.column_name_alias[i] : ref.unpivot_names[i]; + unnest_val_expr->alias = unnest_name; + select_node->select_list.push_back(std::move(unnest_val_expr)); + if (!ref.include_nulls) { + // if we are running with EXCLUDE NULLS we need to add an IS NOT NULL filter + auto colref = make_uniq(unnest_name); + auto filter = make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, std::move(colref)); + if (where_clause) { + where_clause = make_uniq(ExpressionType::CONJUNCTION_AND, + std::move(where_clause), std::move(filter)); + } else { + where_clause = std::move(filter); + } + } + } + return select_node; +} + +unique_ptr Binder::Bind(PivotRef &ref) { + if (!ref.source) { + throw InternalException("Pivot without a source!?"); + } + if (!ref.bound_pivot_values.empty() || !ref.bound_group_names.empty() || !ref.bound_aggregate_names.empty()) { + // bound pivot + return BindBoundPivot(ref); + } + + // bind the source of the pivot + // we need to do this to be able to expand star expressions + if (ref.source->type == TableReferenceType::SUBQUERY && ref.source->alias.empty()) { + ref.source->alias = "__internal_pivot_alias_" + to_string(GenerateTableIndex()); + } + auto copied_source = ref.source->Copy(); + auto star_binder = Binder::CreateBinder(context, this); + star_binder->Bind(*copied_source); + + // figure out the set of column names that are in the source of the pivot + vector> all_columns; + star_binder->ExpandStarExpression(make_uniq(), all_columns); + + unique_ptr select_node; + unique_ptr where_clause; + if (!ref.aggregates.empty()) { + select_node = BindPivot(ref, std::move(all_columns)); + } else { + select_node = BindUnpivot(*star_binder, ref, std::move(all_columns), where_clause); + } + // bind the generated select node + auto child_binder = Binder::CreateBinder(context, this); + auto bound_select_node = child_binder->BindNode(*select_node); + auto root_index = bound_select_node->GetRootIndex(); + BoundQueryNode *bound_select_ptr = bound_select_node.get(); + + unique_ptr result; + MoveCorrelatedExpressions(*child_binder); + result = make_uniq(std::move(child_binder), std::move(bound_select_node)); + auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; + SubqueryRef subquery_ref(nullptr, subquery_alias); + subquery_ref.column_name_alias = std::move(ref.column_name_alias); + if (where_clause) { + // if a WHERE clause was provided - bind a subquery holding the WHERE clause + // we need to bind a new subquery here because the WHERE clause has to be applied AFTER the unnest + child_binder = Binder::CreateBinder(context, this); + child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); + auto where_query = make_uniq(); + where_query->select_list.push_back(make_uniq()); + where_query->where_clause = std::move(where_clause); + bound_select_node = child_binder->BindSelectNode(*where_query, std::move(result)); + bound_select_ptr = bound_select_node.get(); + root_index = bound_select_node->GetRootIndex(); + result = make_uniq(std::move(child_binder), std::move(bound_select_node)); + } + bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp new file mode 100644 index 00000000..e7755bc4 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp @@ -0,0 +1,28 @@ +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_subqueryref.hpp" + +namespace duckdb { + +unique_ptr Binder::Bind(SubqueryRef &ref, optional_ptr cte) { + auto binder = Binder::CreateBinder(context, this); + binder->can_contain_nulls = true; + if (cte) { + binder->bound_ctes.insert(*cte); + } + binder->alias = ref.alias.empty() ? "unnamed_subquery" : ref.alias; + auto subquery = binder->BindNode(*ref.subquery->node); + idx_t bind_index = subquery->GetRootIndex(); + string subquery_alias; + if (ref.alias.empty()) { + subquery_alias = "unnamed_subquery" + to_string(bind_index); + } else { + subquery_alias = ref.alias; + } + auto result = make_uniq(std::move(binder), std::move(subquery)); + bind_context.AddSubquery(bind_index, subquery_alias, ref, *result->subquery); + MoveCorrelatedExpressions(*result->binder); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp new file mode 100644 index 00000000..a3eddeac --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp @@ -0,0 +1,298 @@ +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/emptytableref.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder/table_function_binder.hpp" +#include "duckdb/planner/expression_binder/select_binder.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/tableref/bound_subqueryref.hpp" +#include "duckdb/planner/tableref/bound_table_function.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/function/table/read_csv.hpp" + +namespace duckdb { + +static bool IsTableInTableOutFunction(TableFunctionCatalogEntry &table_function) { + auto fun = table_function.functions.GetFunctionByOffset(0); + return table_function.functions.Size() == 1 && fun.arguments.size() == 1 && + fun.arguments[0].id() == LogicalTypeId::TABLE; +} + +bool Binder::BindTableInTableOutFunction(vector> &expressions, + unique_ptr &subquery, string &error) { + auto binder = Binder::CreateBinder(this->context, this, true); + unique_ptr subquery_node; + if (expressions.size() == 1 && expressions[0]->type == ExpressionType::SUBQUERY) { + // general case: argument is a subquery, bind it as part of the node + auto &se = expressions[0]->Cast(); + subquery_node = std::move(se.subquery->node); + } else { + // special case: non-subquery parameter to table-in table-out function + // generate a subquery and bind that (i.e. UNNEST([1,2,3]) becomes UNNEST((SELECT [1,2,3])) + auto select_node = make_uniq(); + select_node->select_list = std::move(expressions); + select_node->from_table = make_uniq(); + subquery_node = std::move(select_node); + } + auto node = binder->BindNode(*subquery_node); + subquery = make_uniq(std::move(binder), std::move(node)); + MoveCorrelatedExpressions(*subquery->binder); + return true; +} + +bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, + vector> &expressions, + vector &arguments, vector ¶meters, + named_parameter_map_t &named_parameters, + unique_ptr &subquery, string &error) { + if (IsTableInTableOutFunction(table_function)) { + // special case binding for table-in table-out function + arguments.emplace_back(LogicalTypeId::TABLE); + return BindTableInTableOutFunction(expressions, subquery, error); + } + bool seen_subquery = false; + for (auto &child : expressions) { + string parameter_name; + + // hack to make named parameters work + if (child->type == ExpressionType::COMPARE_EQUAL) { + // comparison, check if the LHS is a columnref + auto &comp = child->Cast(); + if (comp.left->type == ExpressionType::COLUMN_REF) { + auto &colref = comp.left->Cast(); + if (!colref.IsQualified()) { + parameter_name = colref.GetColumnName(); + child = std::move(comp.right); + } + } + } + if (child->type == ExpressionType::SUBQUERY) { + auto fun = table_function.functions.GetFunctionByOffset(0); + if (table_function.functions.Size() != 1 || fun.arguments.empty() || + fun.arguments[0].id() != LogicalTypeId::TABLE) { + throw BinderException( + "Only table-in-out functions can have subquery parameters - %s only accepts constant parameters", + fun.name); + } + // this separate subquery binding path is only used by python_map + // FIXME: this should be unified with `BindTableInTableOutFunction` above + if (seen_subquery) { + error = "Table function can have at most one subquery parameter "; + return false; + } + auto binder = Binder::CreateBinder(this->context, this, true); + auto &se = child->Cast(); + auto node = binder->BindNode(*se.subquery->node); + subquery = make_uniq(std::move(binder), std::move(node)); + seen_subquery = true; + arguments.emplace_back(LogicalTypeId::TABLE); + parameters.emplace_back( + Value(LogicalType::INVALID)); // this is a dummy value so the lengths of arguments and parameter match + continue; + } + + TableFunctionBinder binder(*this, context); + LogicalType sql_type; + auto expr = binder.Bind(child, &sql_type); + if (expr->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!expr->IsScalar()) { + // should have been eliminated before + throw InternalException("Table function requires a constant parameter"); + } + auto constant = ExpressionExecutor::EvaluateScalar(context, *expr, true); + if (parameter_name.empty()) { + // unnamed parameter + if (!named_parameters.empty()) { + error = "Unnamed parameters cannot come after named parameters"; + return false; + } + arguments.emplace_back(sql_type); + parameters.emplace_back(std::move(constant)); + } else { + named_parameters[parameter_name] = std::move(constant); + } + } + return true; +} + +unique_ptr +Binder::BindTableFunctionInternal(TableFunction &table_function, const string &function_name, vector parameters, + named_parameter_map_t named_parameters, vector input_table_types, + vector input_table_names, const vector &column_name_alias, + unique_ptr external_dependency) { + auto bind_index = GenerateTableIndex(); + // perform the binding + unique_ptr bind_data; + vector return_types; + vector return_names; + if (table_function.bind || table_function.bind_replace) { + TableFunctionBindInput bind_input(parameters, named_parameters, input_table_types, input_table_names, + table_function.function_info.get()); + if (table_function.bind_replace) { + auto new_plan = table_function.bind_replace(context, bind_input); + if (new_plan != nullptr) { + return CreatePlan(*Bind(*new_plan)); + } else if (!table_function.bind) { + throw BinderException("Failed to bind \"%s\": nullptr returned from bind_replace without bind function", + table_function.name); + } + } + bind_data = table_function.bind(context, bind_input, return_types, return_names); + if (table_function.name == "pandas_scan" || table_function.name == "arrow_scan") { + auto &arrow_bind = bind_data->Cast(); + arrow_bind.external_dependency = std::move(external_dependency); + } + if (table_function.name == "read_csv" || table_function.name == "read_csv_auto") { + auto &csv_bind = bind_data->Cast(); + if (csv_bind.single_threaded) { + table_function.extra_info = "(Single-Threaded)"; + } else { + table_function.extra_info = "(Multi-Threaded)"; + } + } + } else { + throw InvalidInputException("Cannot call function \"%s\" directly - it has no bind function", + table_function.name); + } + if (return_types.size() != return_names.size()) { + throw InternalException("Failed to bind \"%s\": return_types/names must have same size", table_function.name); + } + if (return_types.empty()) { + throw InternalException("Failed to bind \"%s\": Table function must return at least one column", + table_function.name); + } + // overwrite the names with any supplied aliases + for (idx_t i = 0; i < column_name_alias.size() && i < return_names.size(); i++) { + return_names[i] = column_name_alias[i]; + } + for (idx_t i = 0; i < return_names.size(); i++) { + if (return_names[i].empty()) { + return_names[i] = "C" + to_string(i); + } + } + + auto get = make_uniq(bind_index, table_function, std::move(bind_data), return_types, return_names); + get->parameters = parameters; + get->named_parameters = named_parameters; + get->input_table_types = input_table_types; + get->input_table_names = input_table_names; + if (table_function.in_out_function && !table_function.projection_pushdown) { + get->column_ids.reserve(return_types.size()); + for (idx_t i = 0; i < return_types.size(); i++) { + get->column_ids.push_back(i); + } + } + // now add the table function to the bind context so its columns can be bound + bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->column_ids, + get->GetTable().get()); + return std::move(get); +} + +unique_ptr Binder::BindTableFunction(TableFunction &function, vector parameters) { + named_parameter_map_t named_parameters; + vector input_table_types; + vector input_table_names; + vector column_name_aliases; + return BindTableFunctionInternal(function, function.name, std::move(parameters), std::move(named_parameters), + std::move(input_table_types), std::move(input_table_names), column_name_aliases, + nullptr); +} + +unique_ptr Binder::Bind(TableFunctionRef &ref) { + QueryErrorContext error_context(root_statement, ref.query_location); + + D_ASSERT(ref.function->type == ExpressionType::FUNCTION); + auto &fexpr = ref.function->Cast(); + + // fetch the function from the catalog + auto &func_catalog = Catalog::GetEntry(context, CatalogType::TABLE_FUNCTION_ENTRY, fexpr.catalog, fexpr.schema, + fexpr.function_name, error_context); + + if (func_catalog.type == CatalogType::TABLE_MACRO_ENTRY) { + auto ¯o_func = func_catalog.Cast(); + auto query_node = BindTableMacro(fexpr, macro_func, 0); + D_ASSERT(query_node); + + auto binder = Binder::CreateBinder(context, this); + binder->can_contain_nulls = true; + + binder->alias = ref.alias.empty() ? "unnamed_query" : ref.alias; + auto query = binder->BindNode(*query_node); + + idx_t bind_index = query->GetRootIndex(); + // string alias; + string alias = (ref.alias.empty() ? "unnamed_query" + to_string(bind_index) : ref.alias); + + auto result = make_uniq(std::move(binder), std::move(query)); + // remember ref here is TableFunctionRef and NOT base class + bind_context.AddSubquery(bind_index, alias, ref, *result->subquery); + MoveCorrelatedExpressions(*result->binder); + return std::move(result); + } + D_ASSERT(func_catalog.type == CatalogType::TABLE_FUNCTION_ENTRY); + auto &function = func_catalog.Cast(); + + // evaluate the input parameters to the function + vector arguments; + vector parameters; + named_parameter_map_t named_parameters; + unique_ptr subquery; + string error; + if (!BindTableFunctionParameters(function, fexpr.children, arguments, parameters, named_parameters, subquery, + error)) { + throw BinderException(FormatError(ref, error)); + } + + // select the function based on the input parameters + FunctionBinder function_binder(context); + idx_t best_function_idx = function_binder.BindFunction(function.name, function.functions, arguments, error); + if (best_function_idx == DConstants::INVALID_INDEX) { + throw BinderException(FormatError(ref, error)); + } + auto table_function = function.functions.GetFunctionByOffset(best_function_idx); + + // now check the named parameters + BindNamedParameters(table_function.named_parameters, named_parameters, error_context, table_function.name); + + // cast the parameters to the type of the function + for (idx_t i = 0; i < arguments.size(); i++) { + auto target_type = i < table_function.arguments.size() ? table_function.arguments[i] : table_function.varargs; + + if (target_type != LogicalType::ANY && target_type != LogicalType::TABLE && + target_type != LogicalType::POINTER && target_type.id() != LogicalTypeId::LIST) { + parameters[i] = parameters[i].CastAs(context, target_type); + } + } + + vector input_table_types; + vector input_table_names; + + if (subquery) { + input_table_types = subquery->subquery->types; + input_table_names = subquery->subquery->names; + } + auto get = BindTableFunctionInternal(table_function, ref.alias.empty() ? fexpr.function_name : ref.alias, + std::move(parameters), std::move(named_parameters), + std::move(input_table_types), std::move(input_table_names), + ref.column_name_alias, std::move(ref.external_dependency)); + if (subquery) { + get->children.push_back(Binder::CreatePlan(*subquery)); + } + + return make_uniq_base(std::move(get)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp new file mode 100644 index 00000000..085498fb --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp @@ -0,0 +1,11 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/tableref/bound_basetableref.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundBaseTableRef &ref) { + return std::move(ref.get); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp b/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp new file mode 100644 index 00000000..6f5ba901 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp @@ -0,0 +1,19 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" +#include "duckdb/planner/tableref/bound_cteref.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundCTERef &ref) { + auto index = ref.bind_index; + + vector types; + types.reserve(ref.types.size()); + for (auto &type : ref.types) { + types.push_back(type); + } + + return make_uniq(index, ref.cte_index, types, ref.bound_columns, ref.materialized_cte); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp new file mode 100644 index 00000000..f31fc929 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp @@ -0,0 +1,11 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" +#include "duckdb/planner/tableref/bound_dummytableref.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundEmptyTableRef &ref) { + return make_uniq(ref.bind_index); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp new file mode 100644 index 00000000..ba6253bc --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp @@ -0,0 +1,27 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_expressionlistref.hpp" +#include "duckdb/planner/operator/logical_expression_get.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundExpressionListRef &ref) { + auto root = make_uniq_base(GenerateTableIndex()); + // values list, first plan any subqueries in the list + for (auto &expr_list : ref.values) { + for (auto &expr : expr_list) { + PlanSubqueries(expr, root); + } + } + // now create a LogicalExpressionGet from the set of expressions + // fetch the types + vector types; + for (auto &expr : ref.values[0]) { + types.push_back(expr->return_type); + } + auto expr_get = make_uniq(ref.bind_index, types, std::move(ref.values)); + expr_get->AddChild(std::move(root)); + return std::move(expr_get); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp new file mode 100644 index 00000000..9bc0b130 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp @@ -0,0 +1,368 @@ +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_any_join.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_cross_product.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_positional_join.hpp" +#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression_binder/lateral_binder.hpp" +#include "duckdb/planner/subquery/recursive_dependent_join_planner.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +//! Create a JoinCondition from a comparison +static bool CreateJoinCondition(Expression &expr, const unordered_set &left_bindings, + const unordered_set &right_bindings, vector &conditions) { + // comparison + auto &comparison = expr.Cast(); + auto left_side = JoinSide::GetJoinSide(*comparison.left, left_bindings, right_bindings); + auto right_side = JoinSide::GetJoinSide(*comparison.right, left_bindings, right_bindings); + if (left_side != JoinSide::BOTH && right_side != JoinSide::BOTH) { + // join condition can be divided in a left/right side + JoinCondition condition; + condition.comparison = expr.type; + auto left = std::move(comparison.left); + auto right = std::move(comparison.right); + if (left_side == JoinSide::RIGHT) { + // left = right, right = left, flip the comparison symbol and reverse sides + swap(left, right); + condition.comparison = FlipComparisonExpression(expr.type); + } + condition.left = std::move(left); + condition.right = std::move(right); + conditions.push_back(std::move(condition)); + return true; + } + return false; +} + +void LogicalComparisonJoin::ExtractJoinConditions( + ClientContext &context, JoinType type, unique_ptr &left_child, + unique_ptr &right_child, const unordered_set &left_bindings, + const unordered_set &right_bindings, vector> &expressions, + vector &conditions, vector> &arbitrary_expressions) { + + for (auto &expr : expressions) { + auto total_side = JoinSide::GetJoinSide(*expr, left_bindings, right_bindings); + if (total_side != JoinSide::BOTH) { + // join condition does not reference both sides, add it as filter under the join + if (type == JoinType::LEFT && total_side == JoinSide::RIGHT) { + // filter is on RHS and the join is a LEFT OUTER join, we can push it in the right child + if (right_child->type != LogicalOperatorType::LOGICAL_FILTER) { + // not a filter yet, push a new empty filter + auto filter = make_uniq(); + filter->AddChild(std::move(right_child)); + right_child = std::move(filter); + } + // push the expression into the filter + auto &filter = right_child->Cast(); + filter.expressions.push_back(std::move(expr)); + continue; + } + // if the join is a LEFT JOIN and the join expression constantly evaluates to TRUE, + // then we do not add it to the arbitrary expressions + if (type == JoinType::LEFT && expr->IsFoldable()) { + Value result; + ExpressionExecutor::TryEvaluateScalar(context, *expr, result); + if (!result.IsNull() && result == Value(true)) { + continue; + } + } + } else if (expr->type == ExpressionType::COMPARE_EQUAL || expr->type == ExpressionType::COMPARE_NOTEQUAL || + expr->type == ExpressionType::COMPARE_BOUNDARY_START || + expr->type == ExpressionType::COMPARE_LESSTHAN || + expr->type == ExpressionType::COMPARE_GREATERTHAN || + expr->type == ExpressionType::COMPARE_LESSTHANOREQUALTO || + expr->type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || + expr->type == ExpressionType::COMPARE_BOUNDARY_START || + expr->type == ExpressionType::COMPARE_NOT_DISTINCT_FROM || + expr->type == ExpressionType::COMPARE_DISTINCT_FROM) + + { + // comparison, check if we can create a comparison JoinCondition + if (CreateJoinCondition(*expr, left_bindings, right_bindings, conditions)) { + // successfully created the join condition + continue; + } + } + arbitrary_expressions.push_back(std::move(expr)); + } +} + +void LogicalComparisonJoin::ExtractJoinConditions(ClientContext &context, JoinType type, + unique_ptr &left_child, + unique_ptr &right_child, + vector> &expressions, + vector &conditions, + vector> &arbitrary_expressions) { + unordered_set left_bindings, right_bindings; + LogicalJoin::GetTableReferences(*left_child, left_bindings); + LogicalJoin::GetTableReferences(*right_child, right_bindings); + return ExtractJoinConditions(context, type, left_child, right_child, left_bindings, right_bindings, expressions, + conditions, arbitrary_expressions); +} + +void LogicalComparisonJoin::ExtractJoinConditions(ClientContext &context, JoinType type, + unique_ptr &left_child, + unique_ptr &right_child, + unique_ptr condition, vector &conditions, + vector> &arbitrary_expressions) { + // split the expressions by the AND clause + vector> expressions; + expressions.push_back(std::move(condition)); + LogicalFilter::SplitPredicates(expressions); + return ExtractJoinConditions(context, type, left_child, right_child, expressions, conditions, + arbitrary_expressions); +} + +unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &context, JoinType type, + JoinRefType reftype, + unique_ptr left_child, + unique_ptr right_child, + vector conditions, + vector> arbitrary_expressions) { + // Validate the conditions + bool need_to_consider_arbitrary_expressions = true; + switch (reftype) { + case JoinRefType::ASOF: { + need_to_consider_arbitrary_expressions = false; + auto asof_idx = conditions.size(); + for (size_t c = 0; c < conditions.size(); ++c) { + auto &cond = conditions[c]; + switch (cond.comparison) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_LESSTHAN: + if (asof_idx < conditions.size()) { + throw BinderException("Multiple ASOF JOIN inequalities"); + } + asof_idx = c; + break; + default: + throw BinderException("Invalid ASOF JOIN comparison"); + } + } + if (asof_idx == conditions.size()) { + throw BinderException("Missing ASOF JOIN inequality"); + } + break; + } + default: + break; + } + + if (type == JoinType::INNER && reftype == JoinRefType::REGULAR) { + // for inner joins we can push arbitrary expressions as a filter + // here we prefer to create a comparison join if possible + // that way we can use the much faster hash join to process the main join + // rather than doing a nested loop join to handle arbitrary expressions + + // for left and full outer joins we HAVE to process all join conditions + // because pushing a filter will lead to an incorrect result, as non-matching tuples cannot be filtered out + need_to_consider_arbitrary_expressions = false; + } + if ((need_to_consider_arbitrary_expressions && !arbitrary_expressions.empty()) || conditions.empty()) { + if (arbitrary_expressions.empty()) { + // all conditions were pushed down, add TRUE predicate + arbitrary_expressions.push_back(make_uniq(Value::BOOLEAN(true))); + } + for (auto &condition : conditions) { + arbitrary_expressions.push_back(JoinCondition::CreateExpression(std::move(condition))); + } + // if we get here we could not create any JoinConditions + // turn this into an arbitrary expression join + auto any_join = make_uniq(type); + // create the condition + any_join->children.push_back(std::move(left_child)); + any_join->children.push_back(std::move(right_child)); + // AND all the arbitrary expressions together + // do the same with any remaining conditions + any_join->condition = std::move(arbitrary_expressions[0]); + for (idx_t i = 1; i < arbitrary_expressions.size(); i++) { + any_join->condition = make_uniq( + ExpressionType::CONJUNCTION_AND, std::move(any_join->condition), std::move(arbitrary_expressions[i])); + } + return std::move(any_join); + } else { + // we successfully converted expressions into JoinConditions + // create a LogicalComparisonJoin + auto logical_type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; + if (reftype == JoinRefType::ASOF) { + logical_type = LogicalOperatorType::LOGICAL_ASOF_JOIN; + } + auto comp_join = make_uniq(type, logical_type); + comp_join->conditions = std::move(conditions); + comp_join->children.push_back(std::move(left_child)); + comp_join->children.push_back(std::move(right_child)); + if (!arbitrary_expressions.empty()) { + // we have some arbitrary expressions as well + // add them to a filter + auto filter = make_uniq(); + for (auto &expr : arbitrary_expressions) { + filter->expressions.push_back(std::move(expr)); + } + LogicalFilter::SplitPredicates(filter->expressions); + filter->children.push_back(std::move(comp_join)); + return std::move(filter); + } + return std::move(comp_join); + } +} + +static bool HasCorrelatedColumns(Expression &expression) { + if (expression.type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expression.Cast(); + if (colref.depth > 0) { + return true; + } + } + bool has_correlated_columns = false; + ExpressionIterator::EnumerateChildren(expression, [&](Expression &child) { + if (HasCorrelatedColumns(child)) { + has_correlated_columns = true; + } + }); + return has_correlated_columns; +} + +unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &context, JoinType type, + JoinRefType reftype, + unique_ptr left_child, + unique_ptr right_child, + unique_ptr condition) { + vector conditions; + vector> arbitrary_expressions; + LogicalComparisonJoin::ExtractJoinConditions(context, type, left_child, right_child, std::move(condition), + conditions, arbitrary_expressions); + return LogicalComparisonJoin::CreateJoin(context, type, reftype, std::move(left_child), std::move(right_child), + std::move(conditions), std::move(arbitrary_expressions)); +} + +unique_ptr Binder::CreatePlan(BoundJoinRef &ref) { + auto old_is_outside_flattened = is_outside_flattened; + // Plan laterals from outermost to innermost + if (ref.lateral) { + // Set the flag to ensure that children do not flatten before the root + is_outside_flattened = false; + } + auto left = CreatePlan(*ref.left); + auto right = CreatePlan(*ref.right); + is_outside_flattened = old_is_outside_flattened; + + // For joins, depth of the bindings will be one higher on the right because of the lateral binder + // If the current join does not have correlations between left and right, then the right bindings + // have depth 1 too high and can be reduced by 1 throughout + if (!ref.lateral && !ref.correlated_columns.empty()) { + LateralBinder::ReduceExpressionDepth(*right, ref.correlated_columns); + } + + if (ref.type == JoinType::RIGHT && ref.ref_type != JoinRefType::ASOF && + ClientConfig::GetConfig(context).enable_optimizer) { + // we turn any right outer joins into left outer joins for optimization purposes + // they are the same but with sides flipped, so treating them the same simplifies life + ref.type = JoinType::LEFT; + std::swap(left, right); + } + if (ref.lateral) { + if (!is_outside_flattened) { + // If outer dependent joins is yet to be flattened, only plan the lateral + has_unplanned_dependent_joins = true; + return LogicalDependentJoin::Create(std::move(left), std::move(right), ref.correlated_columns, ref.type, + std::move(ref.condition)); + } else { + // All outer dependent joins have been planned and flattened, so plan and flatten lateral and recursively + // plan the children + auto new_plan = PlanLateralJoin(std::move(left), std::move(right), ref.correlated_columns, ref.type, + std::move(ref.condition)); + if (has_unplanned_dependent_joins) { + RecursiveDependentJoinPlanner plan(*this); + plan.VisitOperator(*new_plan); + } + return new_plan; + } + } + switch (ref.ref_type) { + case JoinRefType::CROSS: + return LogicalCrossProduct::Create(std::move(left), std::move(right)); + case JoinRefType::POSITIONAL: + return LogicalPositionalJoin::Create(std::move(left), std::move(right)); + default: + break; + } + if (ref.type == JoinType::INNER && (ref.condition->HasSubquery() || HasCorrelatedColumns(*ref.condition)) && + ref.ref_type == JoinRefType::REGULAR) { + // inner join, generate a cross product + filter + // this will be later turned into a proper join by the join order optimizer + auto root = LogicalCrossProduct::Create(std::move(left), std::move(right)); + + auto filter = make_uniq(std::move(ref.condition)); + // visit the expressions in the filter + for (auto &expression : filter->expressions) { + PlanSubqueries(expression, root); + } + filter->AddChild(std::move(root)); + return std::move(filter); + } + + // now create the join operator from the join condition + auto result = LogicalComparisonJoin::CreateJoin(context, ref.type, ref.ref_type, std::move(left), std::move(right), + std::move(ref.condition)); + + optional_ptr join; + if (result->type == LogicalOperatorType::LOGICAL_FILTER) { + join = result->children[0].get(); + } else { + join = result.get(); + } + for (auto &child : join->children) { + if (child->type == LogicalOperatorType::LOGICAL_FILTER) { + auto &filter = child->Cast(); + for (auto &expr : filter.expressions) { + PlanSubqueries(expr, filter.children[0]); + } + } + } + + // we visit the expressions depending on the type of join + switch (join->type) { + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + // comparison join + // in this join we visit the expressions on the LHS with the LHS as root node + // and the expressions on the RHS with the RHS as root node + auto &comp_join = join->Cast(); + for (idx_t i = 0; i < comp_join.conditions.size(); i++) { + PlanSubqueries(comp_join.conditions[i].left, comp_join.children[0]); + PlanSubqueries(comp_join.conditions[i].right, comp_join.children[1]); + } + break; + } + case LogicalOperatorType::LOGICAL_ANY_JOIN: { + auto &any_join = join->Cast(); + // for the any join we just visit the condition + if (any_join.condition->HasSubquery()) { + throw NotImplementedException("Cannot perform non-inner join on subquery!"); + } + break; + } + default: + break; + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp b/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp new file mode 100644 index 00000000..4d9482e5 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp @@ -0,0 +1,13 @@ +#include "duckdb/planner/tableref/bound_pivotref.hpp" +#include "duckdb/planner/operator/logical_pivot.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundPivotRef &ref) { + auto subquery = ref.child_binder->CreatePlan(*ref.child); + + auto result = make_uniq(ref.bind_index, std::move(subquery), std::move(ref.bound_pivot)); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp new file mode 100644 index 00000000..82165446 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp @@ -0,0 +1,17 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_subqueryref.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundSubqueryRef &ref) { + // generate the logical plan for the subquery + // this happens separately from the current LogicalPlan generation + ref.binder->is_outside_flattened = is_outside_flattened; + auto subquery = ref.binder->CreatePlan(*ref.subquery); + if (ref.binder->has_unplanned_dependent_joins) { + has_unplanned_dependent_joins = true; + } + return subquery; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp b/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp new file mode 100644 index 00000000..dffbd320 --- /dev/null +++ b/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp @@ -0,0 +1,10 @@ +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/tableref/bound_table_function.hpp" + +namespace duckdb { + +unique_ptr Binder::CreatePlan(BoundTableFunction &ref) { + return std::move(ref.get); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/bound_parameter_map.cpp b/src/duckdb/src/planner/bound_parameter_map.cpp new file mode 100644 index 00000000..cd55830a --- /dev/null +++ b/src/duckdb/src/planner/bound_parameter_map.cpp @@ -0,0 +1,67 @@ +#include "duckdb/planner/bound_parameter_map.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" + +namespace duckdb { + +BoundParameterMap::BoundParameterMap(case_insensitive_map_t ¶meter_data) + : parameter_data(parameter_data) { +} + +LogicalType BoundParameterMap::GetReturnType(const string &identifier) { + D_ASSERT(!identifier.empty()); + auto it = parameter_data.find(identifier); + if (it == parameter_data.end()) { + return LogicalTypeId::UNKNOWN; + } + return it->second.return_type; +} + +bound_parameter_map_t *BoundParameterMap::GetParametersPtr() { + return ¶meters; +} + +const bound_parameter_map_t &BoundParameterMap::GetParameters() { + return parameters; +} + +const case_insensitive_map_t &BoundParameterMap::GetParameterData() { + return parameter_data; +} + +shared_ptr BoundParameterMap::CreateOrGetData(const string &identifier) { + auto entry = parameters.find(identifier); + if (entry == parameters.end()) { + // no entry yet: create a new one + auto data = make_shared(); + data->return_type = GetReturnType(identifier); + + CreateNewParameter(identifier, data); + return data; + } + return entry->second; +} + +unique_ptr BoundParameterMap::BindParameterExpression(ParameterExpression &expr) { + auto &identifier = expr.identifier; + auto return_type = GetReturnType(identifier); + + D_ASSERT(!parameter_data.count(identifier)); + + // No value has been supplied yet, + // We return a shared pointer to an object that will get populated wtih a Value later + // When the BoundParameterExpression get executed, this will be used to get the corresponding value + auto param_data = CreateOrGetData(identifier); + auto bound_expr = make_uniq(identifier); + bound_expr->parameter_data = param_data; + bound_expr->return_type = return_type; + bound_expr->alias = expr.alias; + return bound_expr; +} + +void BoundParameterMap::CreateNewParameter(const string &id, const shared_ptr ¶m_data) { + D_ASSERT(!parameters.count(id)); + parameters.emplace(std::make_pair(id, param_data)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/bound_result_modifier.cpp b/src/duckdb/src/planner/bound_result_modifier.cpp new file mode 100644 index 00000000..0518eb7b --- /dev/null +++ b/src/duckdb/src/planner/bound_result_modifier.cpp @@ -0,0 +1,108 @@ +#include "duckdb/planner/bound_result_modifier.hpp" + +namespace duckdb { + +BoundResultModifier::BoundResultModifier(ResultModifierType type) : type(type) { +} + +BoundResultModifier::~BoundResultModifier() { +} + +BoundOrderByNode::BoundOrderByNode(OrderType type, OrderByNullType null_order, unique_ptr expression) + : type(type), null_order(null_order), expression(std::move(expression)) { +} +BoundOrderByNode::BoundOrderByNode(OrderType type, OrderByNullType null_order, unique_ptr expression, + unique_ptr stats) + : type(type), null_order(null_order), expression(std::move(expression)), stats(std::move(stats)) { +} + +BoundOrderByNode BoundOrderByNode::Copy() const { + if (stats) { + return BoundOrderByNode(type, null_order, expression->Copy(), stats->ToUnique()); + } else { + return BoundOrderByNode(type, null_order, expression->Copy()); + } +} + +bool BoundOrderByNode::Equals(const BoundOrderByNode &other) const { + if (type != other.type || null_order != other.null_order) { + return false; + } + if (!expression->Equals(*other.expression)) { + return false; + } + + return true; +} + +string BoundOrderByNode::ToString() const { + auto str = expression->ToString(); + switch (type) { + case OrderType::ASCENDING: + str += " ASC"; + break; + case OrderType::DESCENDING: + str += " DESC"; + break; + default: + break; + } + + switch (null_order) { + case OrderByNullType::NULLS_FIRST: + str += " NULLS FIRST"; + break; + case OrderByNullType::NULLS_LAST: + str += " NULLS LAST"; + break; + default: + break; + } + return str; +} + +unique_ptr BoundOrderModifier::Copy() const { + auto result = make_uniq(); + for (auto &order : orders) { + result->orders.push_back(order.Copy()); + } + return result; +} + +bool BoundOrderModifier::Equals(const BoundOrderModifier &left, const BoundOrderModifier &right) { + if (left.orders.size() != right.orders.size()) { + return false; + } + for (idx_t i = 0; i < left.orders.size(); i++) { + if (!left.orders[i].Equals(right.orders[i])) { + return false; + } + } + return true; +} + +bool BoundOrderModifier::Equals(const unique_ptr &left, + const unique_ptr &right) { + if (left.get() == right.get()) { + return true; + } + if (!left || !right) { + return false; + } + return BoundOrderModifier::Equals(*left, *right); +} + +BoundLimitModifier::BoundLimitModifier() : BoundResultModifier(ResultModifierType::LIMIT_MODIFIER) { +} + +BoundOrderModifier::BoundOrderModifier() : BoundResultModifier(ResultModifierType::ORDER_MODIFIER) { +} + +BoundDistinctModifier::BoundDistinctModifier() : BoundResultModifier(ResultModifierType::DISTINCT_MODIFIER) { +} + +BoundLimitPercentModifier::BoundLimitPercentModifier() + : BoundResultModifier(ResultModifierType::LIMIT_PERCENT_MODIFIER) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression.cpp b/src/duckdb/src/planner/expression.cpp new file mode 100644 index 00000000..2d648771 --- /dev/null +++ b/src/duckdb/src/planner/expression.cpp @@ -0,0 +1,112 @@ +#include "duckdb/planner/expression.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/parser/expression_util.hpp" + +namespace duckdb { + +Expression::Expression(ExpressionType type, ExpressionClass expression_class, LogicalType return_type) + : BaseExpression(type, expression_class), return_type(std::move(return_type)) { +} + +Expression::~Expression() { +} + +bool Expression::IsAggregate() const { + bool is_aggregate = false; + ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { is_aggregate |= child.IsAggregate(); }); + return is_aggregate; +} + +bool Expression::IsWindow() const { + bool is_window = false; + ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { is_window |= child.IsWindow(); }); + return is_window; +} + +bool Expression::IsScalar() const { + bool is_scalar = true; + ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { + if (!child.IsScalar()) { + is_scalar = false; + } + }); + return is_scalar; +} + +bool Expression::HasSideEffects() const { + bool has_side_effects = false; + ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { + if (child.HasSideEffects()) { + has_side_effects = true; + } + }); + return has_side_effects; +} + +bool Expression::PropagatesNullValues() const { + if (type == ExpressionType::OPERATOR_IS_NULL || type == ExpressionType::OPERATOR_IS_NOT_NULL || + type == ExpressionType::COMPARE_NOT_DISTINCT_FROM || type == ExpressionType::COMPARE_DISTINCT_FROM || + type == ExpressionType::CONJUNCTION_OR || type == ExpressionType::CONJUNCTION_AND || + type == ExpressionType::OPERATOR_COALESCE) { + return false; + } + bool propagate_null_values = true; + ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { + if (!child.PropagatesNullValues()) { + propagate_null_values = false; + } + }); + return propagate_null_values; +} + +bool Expression::IsFoldable() const { + bool is_foldable = true; + ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { + if (!child.IsFoldable()) { + is_foldable = false; + } + }); + return is_foldable; +} + +bool Expression::HasParameter() const { + bool has_parameter = false; + ExpressionIterator::EnumerateChildren(*this, + [&](const Expression &child) { has_parameter |= child.HasParameter(); }); + return has_parameter; +} + +bool Expression::HasSubquery() const { + bool has_subquery = false; + ExpressionIterator::EnumerateChildren(*this, [&](const Expression &child) { has_subquery |= child.HasSubquery(); }); + return has_subquery; +} + +hash_t Expression::Hash() const { + hash_t hash = duckdb::Hash((uint32_t)type); + hash = CombineHash(hash, return_type.Hash()); + ExpressionIterator::EnumerateChildren(*this, + [&](const Expression &child) { hash = CombineHash(child.Hash(), hash); }); + return hash; +} + +bool Expression::Equals(const unique_ptr &left, const unique_ptr &right) { + if (left.get() == right.get()) { + return true; + } + if (!left || !right) { + return false; + } + return left->Equals(*right); +} + +bool Expression::ListEquals(const vector> &left, const vector> &right) { + return ExpressionUtil::ListEquals(left, right); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp new file mode 100644 index 00000000..a0e56c4f --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp @@ -0,0 +1,106 @@ +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/function/function_serialization.hpp" + +namespace duckdb { + +BoundAggregateExpression::BoundAggregateExpression(AggregateFunction function, vector> children, + unique_ptr filter, unique_ptr bind_info, + AggregateType aggr_type) + : Expression(ExpressionType::BOUND_AGGREGATE, ExpressionClass::BOUND_AGGREGATE, function.return_type), + function(std::move(function)), children(std::move(children)), bind_info(std::move(bind_info)), + aggr_type(aggr_type), filter(std::move(filter)) { + D_ASSERT(!this->function.name.empty()); +} + +string BoundAggregateExpression::ToString() const { + return FunctionExpression::ToString( + *this, string(), function.name, false, IsDistinct(), filter.get(), order_bys.get()); +} + +hash_t BoundAggregateExpression::Hash() const { + hash_t result = Expression::Hash(); + result = CombineHash(result, function.Hash()); + result = CombineHash(result, duckdb::Hash(IsDistinct())); + return result; +} + +bool BoundAggregateExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (other.aggr_type != aggr_type) { + return false; + } + if (other.function != function) { + return false; + } + if (children.size() != other.children.size()) { + return false; + } + if (!Expression::Equals(other.filter, filter)) { + return false; + } + for (idx_t i = 0; i < children.size(); i++) { + if (!Expression::Equals(*children[i], *other.children[i])) { + return false; + } + } + if (!FunctionData::Equals(bind_info.get(), other.bind_info.get())) { + return false; + } + if (!BoundOrderModifier::Equals(order_bys, other.order_bys)) { + return false; + } + return true; +} + +bool BoundAggregateExpression::PropagatesNullValues() const { + return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? false + : Expression::PropagatesNullValues(); +} + +unique_ptr BoundAggregateExpression::Copy() { + vector> new_children; + new_children.reserve(children.size()); + for (auto &child : children) { + new_children.push_back(child->Copy()); + } + auto new_bind_info = bind_info ? bind_info->Copy() : nullptr; + auto new_filter = filter ? filter->Copy() : nullptr; + auto copy = make_uniq(function, std::move(new_children), std::move(new_filter), + std::move(new_bind_info), aggr_type); + copy->CopyProperties(*this); + copy->order_bys = order_bys ? order_bys->Copy() : nullptr; + return std::move(copy); +} + +void BoundAggregateExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WriteProperty(201, "children", children); + FunctionSerializer::Serialize(serializer, function, bind_info.get()); + serializer.WriteProperty(203, "aggregate_type", aggr_type); + serializer.WritePropertyWithDefault(204, "filter", filter, unique_ptr()); + serializer.WritePropertyWithDefault(205, "order_bys", order_bys, unique_ptr()); +} + +unique_ptr BoundAggregateExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto children = deserializer.ReadProperty>>(201, "children"); + auto entry = FunctionSerializer::Deserialize( + deserializer, CatalogType::AGGREGATE_FUNCTION_ENTRY, children, std::move(return_type)); + auto aggregate_type = deserializer.ReadProperty(203, "aggregate_type"); + auto filter = deserializer.ReadPropertyWithDefault>(204, "filter", unique_ptr()); + auto result = make_uniq(std::move(entry.first), std::move(children), std::move(filter), + std::move(entry.second), aggregate_type); + deserializer.ReadPropertyWithDefault(205, "order_bys", result->order_bys, unique_ptr()); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_between_expression.cpp b/src/duckdb/src/planner/expression/bound_between_expression.cpp new file mode 100644 index 00000000..7b308279 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_between_expression.cpp @@ -0,0 +1,45 @@ +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/parser/expression/between_expression.hpp" + +namespace duckdb { + +BoundBetweenExpression::BoundBetweenExpression() + : Expression(ExpressionType::COMPARE_BETWEEN, ExpressionClass::BOUND_BETWEEN, LogicalType::BOOLEAN) { +} + +BoundBetweenExpression::BoundBetweenExpression(unique_ptr input, unique_ptr lower, + unique_ptr upper, bool lower_inclusive, bool upper_inclusive) + : Expression(ExpressionType::COMPARE_BETWEEN, ExpressionClass::BOUND_BETWEEN, LogicalType::BOOLEAN), + input(std::move(input)), lower(std::move(lower)), upper(std::move(upper)), lower_inclusive(lower_inclusive), + upper_inclusive(upper_inclusive) { +} + +string BoundBetweenExpression::ToString() const { + return BetweenExpression::ToString(*this); +} + +bool BoundBetweenExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!Expression::Equals(*input, *other.input)) { + return false; + } + if (!Expression::Equals(*lower, *other.lower)) { + return false; + } + if (!Expression::Equals(*upper, *other.upper)) { + return false; + } + return lower_inclusive == other.lower_inclusive && upper_inclusive == other.upper_inclusive; +} + +unique_ptr BoundBetweenExpression::Copy() { + auto copy = make_uniq(input->Copy(), lower->Copy(), upper->Copy(), lower_inclusive, + upper_inclusive); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_case_expression.cpp b/src/duckdb/src/planner/expression/bound_case_expression.cpp new file mode 100644 index 00000000..badfee06 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_case_expression.cpp @@ -0,0 +1,60 @@ +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/parser/expression/case_expression.hpp" + +namespace duckdb { + +BoundCaseExpression::BoundCaseExpression(LogicalType type) + : Expression(ExpressionType::CASE_EXPR, ExpressionClass::BOUND_CASE, std::move(type)) { +} + +BoundCaseExpression::BoundCaseExpression(unique_ptr when_expr, unique_ptr then_expr, + unique_ptr else_expr_p) + : Expression(ExpressionType::CASE_EXPR, ExpressionClass::BOUND_CASE, then_expr->return_type), + else_expr(std::move(else_expr_p)) { + BoundCaseCheck check; + check.when_expr = std::move(when_expr); + check.then_expr = std::move(then_expr); + case_checks.push_back(std::move(check)); +} + +string BoundCaseExpression::ToString() const { + return CaseExpression::ToString(*this); +} + +bool BoundCaseExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (case_checks.size() != other.case_checks.size()) { + return false; + } + for (idx_t i = 0; i < case_checks.size(); i++) { + if (!Expression::Equals(*case_checks[i].when_expr, *other.case_checks[i].when_expr)) { + return false; + } + if (!Expression::Equals(*case_checks[i].then_expr, *other.case_checks[i].then_expr)) { + return false; + } + } + if (!Expression::Equals(*else_expr, *other.else_expr)) { + return false; + } + return true; +} + +unique_ptr BoundCaseExpression::Copy() { + auto new_case = make_uniq(return_type); + for (auto &check : case_checks) { + BoundCaseCheck new_check; + new_check.when_expr = check.when_expr->Copy(); + new_check.then_expr = check.then_expr->Copy(); + new_case->case_checks.push_back(std::move(new_check)); + } + new_case->else_expr = else_expr->Copy(); + + new_case->CopyProperties(*this); + return std::move(new_case); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_cast_expression.cpp b/src/duckdb/src/planner/expression/bound_cast_expression.cpp new file mode 100644 index 00000000..55057787 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_cast_expression.cpp @@ -0,0 +1,196 @@ +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_default_expression.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/function/cast_rules.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +static BoundCastInfo BindCastFunction(ClientContext &context, const LogicalType &source, const LogicalType &target) { + auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); + GetCastFunctionInput input(context); + return cast_functions.GetCastFunction(source, target, input); +} + +BoundCastExpression::BoundCastExpression(unique_ptr child_p, LogicalType target_type_p, + BoundCastInfo bound_cast_p, bool try_cast_p) + : Expression(ExpressionType::OPERATOR_CAST, ExpressionClass::BOUND_CAST, std::move(target_type_p)), + child(std::move(child_p)), try_cast(try_cast_p), bound_cast(std::move(bound_cast_p)) { +} + +BoundCastExpression::BoundCastExpression(ClientContext &context, unique_ptr child_p, + LogicalType target_type_p) + : Expression(ExpressionType::OPERATOR_CAST, ExpressionClass::BOUND_CAST, std::move(target_type_p)), + child(std::move(child_p)), try_cast(false), + bound_cast(BindCastFunction(context, child->return_type, return_type)) { +} + +unique_ptr AddCastExpressionInternal(unique_ptr expr, const LogicalType &target_type, + BoundCastInfo bound_cast, bool try_cast) { + if (expr->return_type == target_type) { + return expr; + } + auto &expr_type = expr->return_type; + if (target_type.id() == LogicalTypeId::LIST && expr_type.id() == LogicalTypeId::LIST) { + auto &target_list = ListType::GetChildType(target_type); + auto &expr_list = ListType::GetChildType(expr_type); + if (target_list.id() == LogicalTypeId::ANY || expr_list == target_list) { + return expr; + } + } + return make_uniq(std::move(expr), target_type, std::move(bound_cast), try_cast); +} + +unique_ptr AddCastToTypeInternal(unique_ptr expr, const LogicalType &target_type, + CastFunctionSet &cast_functions, GetCastFunctionInput &get_input, + bool try_cast) { + D_ASSERT(expr); + if (expr->expression_class == ExpressionClass::BOUND_PARAMETER) { + auto ¶meter = expr->Cast(); + if (!target_type.IsValid()) { + // invalidate the parameter + parameter.parameter_data->return_type = LogicalType::INVALID; + parameter.return_type = target_type; + return expr; + } + if (parameter.parameter_data->return_type.id() == LogicalTypeId::INVALID) { + // we don't know the type of this parameter + parameter.return_type = target_type; + return expr; + } + if (parameter.parameter_data->return_type.id() == LogicalTypeId::UNKNOWN) { + // prepared statement parameter cast - but there is no type, convert the type + parameter.parameter_data->return_type = target_type; + parameter.return_type = target_type; + return expr; + } + // prepared statement parameter already has a type + if (parameter.parameter_data->return_type == target_type) { + // this type! we are done + parameter.return_type = parameter.parameter_data->return_type; + return expr; + } + // invalidate the type + parameter.parameter_data->return_type = LogicalType::INVALID; + parameter.return_type = target_type; + return expr; + } else if (expr->expression_class == ExpressionClass::BOUND_DEFAULT) { + D_ASSERT(target_type.IsValid()); + auto &def = expr->Cast(); + def.return_type = target_type; + } + if (!target_type.IsValid()) { + return expr; + } + + auto cast_function = cast_functions.GetCastFunction(expr->return_type, target_type, get_input); + return AddCastExpressionInternal(std::move(expr), target_type, std::move(cast_function), try_cast); +} + +unique_ptr BoundCastExpression::AddDefaultCastToType(unique_ptr expr, + const LogicalType &target_type, bool try_cast) { + CastFunctionSet default_set; + GetCastFunctionInput get_input; + return AddCastToTypeInternal(std::move(expr), target_type, default_set, get_input, try_cast); +} + +unique_ptr BoundCastExpression::AddCastToType(ClientContext &context, unique_ptr expr, + const LogicalType &target_type, bool try_cast) { + auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); + GetCastFunctionInput get_input(context); + return AddCastToTypeInternal(std::move(expr), target_type, cast_functions, get_input, try_cast); +} + +bool BoundCastExpression::CastIsInvertible(const LogicalType &source_type, const LogicalType &target_type) { + D_ASSERT(source_type.IsValid() && target_type.IsValid()); + if (source_type.id() == LogicalTypeId::BOOLEAN || target_type.id() == LogicalTypeId::BOOLEAN) { + return false; + } + if (source_type.id() == LogicalTypeId::FLOAT || target_type.id() == LogicalTypeId::FLOAT) { + return false; + } + if (source_type.id() == LogicalTypeId::DOUBLE || target_type.id() == LogicalTypeId::DOUBLE) { + return false; + } + if (source_type.id() == LogicalTypeId::DECIMAL || target_type.id() == LogicalTypeId::DECIMAL) { + uint8_t source_width, target_width; + uint8_t source_scale, target_scale; + // cast to or from decimal + // cast is only invertible if the cast is strictly widening + if (!source_type.GetDecimalProperties(source_width, source_scale)) { + return false; + } + if (!target_type.GetDecimalProperties(target_width, target_scale)) { + return false; + } + if (target_scale < source_scale) { + return false; + } + return true; + } + if (source_type.id() == LogicalTypeId::TIMESTAMP || source_type.id() == LogicalTypeId::TIMESTAMP_TZ) { + switch (target_type.id()) { + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return false; + default: + break; + } + } + if (source_type.id() == LogicalTypeId::VARCHAR) { + switch (target_type.id()) { + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_TZ: + return true; + default: + return false; + } + } + if (target_type.id() == LogicalTypeId::VARCHAR) { + switch (source_type.id()) { + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + return true; + default: + return false; + } + } + return true; +} + +string BoundCastExpression::ToString() const { + return (try_cast ? "TRY_CAST(" : "CAST(") + child->GetName() + " AS " + return_type.ToString() + ")"; +} + +bool BoundCastExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!Expression::Equals(*child, *other.child)) { + return false; + } + if (try_cast != other.try_cast) { + return false; + } + return true; +} + +unique_ptr BoundCastExpression::Copy() { + auto copy = make_uniq(child->Copy(), return_type, bound_cast.Copy(), try_cast); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_columnref_expression.cpp b/src/duckdb/src/planner/expression/bound_columnref_expression.cpp new file mode 100644 index 00000000..a3cfd158 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_columnref_expression.cpp @@ -0,0 +1,58 @@ +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +#include "duckdb/common/types/hash.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +BoundColumnRefExpression::BoundColumnRefExpression(string alias_p, LogicalType type, ColumnBinding binding, idx_t depth) + : Expression(ExpressionType::BOUND_COLUMN_REF, ExpressionClass::BOUND_COLUMN_REF, std::move(type)), + binding(binding), depth(depth) { + this->alias = std::move(alias_p); +} + +BoundColumnRefExpression::BoundColumnRefExpression(LogicalType type, ColumnBinding binding, idx_t depth) + : BoundColumnRefExpression(string(), std::move(type), binding, depth) { +} + +unique_ptr BoundColumnRefExpression::Copy() { + return make_uniq(alias, return_type, binding, depth); +} + +hash_t BoundColumnRefExpression::Hash() const { + auto result = Expression::Hash(); + result = CombineHash(result, duckdb::Hash(binding.column_index)); + result = CombineHash(result, duckdb::Hash(binding.table_index)); + return CombineHash(result, duckdb::Hash(depth)); +} + +bool BoundColumnRefExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return other.binding == binding && other.depth == depth; +} + +string BoundColumnRefExpression::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return binding.ToString(); + } +#endif + return Expression::GetName(); +} + +string BoundColumnRefExpression::ToString() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return binding.ToString(); + } +#endif + if (!alias.empty()) { + return alias; + } + return binding.ToString(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_comparison_expression.cpp b/src/duckdb/src/planner/expression/bound_comparison_expression.cpp new file mode 100644 index 00000000..b4b6f7e0 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_comparison_expression.cpp @@ -0,0 +1,36 @@ +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" + +namespace duckdb { + +BoundComparisonExpression::BoundComparisonExpression(ExpressionType type, unique_ptr left, + unique_ptr right) + : Expression(type, ExpressionClass::BOUND_COMPARISON, LogicalType::BOOLEAN), left(std::move(left)), + right(std::move(right)) { +} + +string BoundComparisonExpression::ToString() const { + return ComparisonExpression::ToString(*this); +} + +bool BoundComparisonExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!Expression::Equals(*left, *other.left)) { + return false; + } + if (!Expression::Equals(*right, *other.right)) { + return false; + } + return true; +} + +unique_ptr BoundComparisonExpression::Copy() { + auto copy = make_uniq(type, left->Copy(), right->Copy()); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_conjunction_expression.cpp b/src/duckdb/src/planner/expression/bound_conjunction_expression.cpp new file mode 100644 index 00000000..6063d780 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_conjunction_expression.cpp @@ -0,0 +1,43 @@ +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression_util.hpp" + +namespace duckdb { + +BoundConjunctionExpression::BoundConjunctionExpression(ExpressionType type) + : Expression(type, ExpressionClass::BOUND_CONJUNCTION, LogicalType::BOOLEAN) { +} + +BoundConjunctionExpression::BoundConjunctionExpression(ExpressionType type, unique_ptr left, + unique_ptr right) + : BoundConjunctionExpression(type) { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +string BoundConjunctionExpression::ToString() const { + return ConjunctionExpression::ToString(*this); +} + +bool BoundConjunctionExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return ExpressionUtil::SetEquals(children, other.children); +} + +bool BoundConjunctionExpression::PropagatesNullValues() const { + return false; +} + +unique_ptr BoundConjunctionExpression::Copy() { + auto copy = make_uniq(type); + for (auto &expr : children) { + copy->children.push_back(expr->Copy()); + } + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_constant_expression.cpp b/src/duckdb/src/planner/expression/bound_constant_expression.cpp new file mode 100644 index 00000000..3fe2d75b --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_constant_expression.cpp @@ -0,0 +1,35 @@ +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" + +namespace duckdb { + +BoundConstantExpression::BoundConstantExpression(Value value_p) + : Expression(ExpressionType::VALUE_CONSTANT, ExpressionClass::BOUND_CONSTANT, value_p.type()), + value(std::move(value_p)) { +} + +string BoundConstantExpression::ToString() const { + return value.ToSQLString(); +} + +bool BoundConstantExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return value.type() == other.value.type() && !ValueOperations::DistinctFrom(value, other.value); +} + +hash_t BoundConstantExpression::Hash() const { + hash_t result = Expression::Hash(); + return CombineHash(value.Hash(), result); +} + +unique_ptr BoundConstantExpression::Copy() { + auto copy = make_uniq(value); + copy->CopyProperties(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_expression.cpp b/src/duckdb/src/planner/expression/bound_expression.cpp new file mode 100644 index 00000000..e5cdfb1c --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_expression.cpp @@ -0,0 +1,40 @@ +#include "duckdb/parser/expression/bound_expression.hpp" + +namespace duckdb { + +BoundExpression::BoundExpression(unique_ptr expr_p) + : ParsedExpression(ExpressionType::INVALID, ExpressionClass::BOUND_EXPRESSION), expr(std::move(expr_p)) { + this->alias = expr->alias; +} + +unique_ptr &BoundExpression::GetExpression(ParsedExpression &expr) { + auto &bound_expr = expr.Cast(); + if (!bound_expr.expr) { + throw InternalException("BoundExpression::GetExpression called on empty bound expression"); + } + return bound_expr.expr; +} + +string BoundExpression::ToString() const { + if (!expr) { + throw InternalException("ToString(): BoundExpression does not have a child"); + } + return expr->ToString(); +} + +bool BoundExpression::Equals(const BaseExpression &other) const { + return false; +} +hash_t BoundExpression::Hash() const { + return 0; +} + +unique_ptr BoundExpression::Copy() const { + throw SerializationException("Cannot copy or serialize bound expression"); +} + +void BoundExpression::Serialize(Serializer &serializer) const { + throw SerializationException("Cannot copy or serialize bound expression"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_function_expression.cpp b/src/duckdb/src/planner/expression/bound_function_expression.cpp new file mode 100644 index 00000000..40ba170c --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_function_expression.cpp @@ -0,0 +1,97 @@ +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/function/function_serialization.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +BoundFunctionExpression::BoundFunctionExpression(LogicalType return_type, ScalarFunction bound_function, + vector> arguments, + unique_ptr bind_info, bool is_operator) + : Expression(ExpressionType::BOUND_FUNCTION, ExpressionClass::BOUND_FUNCTION, std::move(return_type)), + function(std::move(bound_function)), children(std::move(arguments)), bind_info(std::move(bind_info)), + is_operator(is_operator) { + D_ASSERT(!function.name.empty()); +} + +bool BoundFunctionExpression::HasSideEffects() const { + return function.side_effects == FunctionSideEffects::HAS_SIDE_EFFECTS ? true : Expression::HasSideEffects(); +} + +bool BoundFunctionExpression::IsFoldable() const { + // functions with side effects cannot be folded: they have to be executed once for every row + return function.side_effects == FunctionSideEffects::HAS_SIDE_EFFECTS ? false : Expression::IsFoldable(); +} + +string BoundFunctionExpression::ToString() const { + return FunctionExpression::ToString(*this, string(), function.name, + is_operator); +} +bool BoundFunctionExpression::PropagatesNullValues() const { + return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? false + : Expression::PropagatesNullValues(); +} + +hash_t BoundFunctionExpression::Hash() const { + hash_t result = Expression::Hash(); + return CombineHash(result, function.Hash()); +} + +bool BoundFunctionExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (other.function != function) { + return false; + } + if (!Expression::ListEquals(children, other.children)) { + return false; + } + if (!FunctionData::Equals(bind_info.get(), other.bind_info.get())) { + return false; + } + return true; +} + +unique_ptr BoundFunctionExpression::Copy() { + vector> new_children; + new_children.reserve(children.size()); + for (auto &child : children) { + new_children.push_back(child->Copy()); + } + unique_ptr new_bind_info = bind_info ? bind_info->Copy() : nullptr; + + auto copy = make_uniq(return_type, function, std::move(new_children), + std::move(new_bind_info), is_operator); + copy->CopyProperties(*this); + return std::move(copy); +} + +void BoundFunctionExpression::Verify() const { + D_ASSERT(!function.name.empty()); +} + +void BoundFunctionExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WriteProperty(201, "children", children); + FunctionSerializer::Serialize(serializer, function, bind_info.get()); + serializer.WriteProperty(202, "is_operator", is_operator); +} + +unique_ptr BoundFunctionExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto children = deserializer.ReadProperty>>(201, "children"); + auto entry = FunctionSerializer::Deserialize( + deserializer, CatalogType::SCALAR_FUNCTION_ENTRY, children, return_type); + auto result = make_uniq(std::move(return_type), std::move(entry.first), + std::move(children), std::move(entry.second)); + deserializer.ReadProperty(202, "is_operator", result->is_operator); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_lambda_expression.cpp b/src/duckdb/src/planner/expression/bound_lambda_expression.cpp new file mode 100644 index 00000000..51778cea --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_lambda_expression.cpp @@ -0,0 +1,41 @@ +#include "duckdb/planner/expression/bound_lambda_expression.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +BoundLambdaExpression::BoundLambdaExpression(ExpressionType type_p, LogicalType return_type_p, + unique_ptr lambda_expr_p, idx_t parameter_count_p) + : Expression(type_p, ExpressionClass::BOUND_LAMBDA, std::move(return_type_p)), + lambda_expr(std::move(lambda_expr_p)), parameter_count(parameter_count_p) { +} + +string BoundLambdaExpression::ToString() const { + return lambda_expr->ToString(); +} + +bool BoundLambdaExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!Expression::Equals(*lambda_expr, *other.lambda_expr)) { + return false; + } + if (!Expression::ListEquals(captures, other.captures)) { + return false; + } + if (parameter_count != other.parameter_count) { + return false; + } + return true; +} + +unique_ptr BoundLambdaExpression::Copy() { + auto copy = make_uniq(type, return_type, lambda_expr->Copy(), parameter_count); + for (auto &capture : captures) { + copy->captures.push_back(capture->Copy()); + } + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_lambdaref_expression.cpp b/src/duckdb/src/planner/expression/bound_lambdaref_expression.cpp new file mode 100644 index 00000000..388404cf --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_lambdaref_expression.cpp @@ -0,0 +1,48 @@ +#include "duckdb/planner/expression/bound_lambdaref_expression.hpp" + +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +BoundLambdaRefExpression::BoundLambdaRefExpression(string alias_p, LogicalType type, ColumnBinding binding, + idx_t lambda_index, idx_t depth) + : Expression(ExpressionType::BOUND_LAMBDA_REF, ExpressionClass::BOUND_LAMBDA_REF, std::move(type)), + binding(binding), lambda_index(lambda_index), depth(depth) { + this->alias = std::move(alias_p); +} + +BoundLambdaRefExpression::BoundLambdaRefExpression(LogicalType type, ColumnBinding binding, idx_t lambda_index, + idx_t depth) + : BoundLambdaRefExpression(string(), std::move(type), binding, lambda_index, depth) { +} + +unique_ptr BoundLambdaRefExpression::Copy() { + return make_uniq(alias, return_type, binding, lambda_index, depth); +} + +hash_t BoundLambdaRefExpression::Hash() const { + auto result = Expression::Hash(); + result = CombineHash(result, duckdb::Hash(lambda_index)); + result = CombineHash(result, duckdb::Hash(binding.column_index)); + result = CombineHash(result, duckdb::Hash(binding.table_index)); + return CombineHash(result, duckdb::Hash(depth)); +} + +bool BoundLambdaRefExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return other.binding == binding && other.lambda_index == lambda_index && other.depth == depth; +} + +string BoundLambdaRefExpression::ToString() const { + if (!alias.empty()) { + return alias; + } + return "#[" + to_string(binding.table_index) + "." + to_string(binding.column_index) + "." + + to_string(lambda_index) + "]"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_operator_expression.cpp b/src/duckdb/src/planner/expression/bound_operator_expression.cpp new file mode 100644 index 00000000..dce8708a --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_operator_expression.cpp @@ -0,0 +1,35 @@ +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" + +namespace duckdb { + +BoundOperatorExpression::BoundOperatorExpression(ExpressionType type, LogicalType return_type) + : Expression(type, ExpressionClass::BOUND_OPERATOR, std::move(return_type)) { +} + +string BoundOperatorExpression::ToString() const { + return OperatorExpression::ToString(*this); +} + +bool BoundOperatorExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!Expression::ListEquals(children, other.children)) { + return false; + } + return true; +} + +unique_ptr BoundOperatorExpression::Copy() { + auto copy = make_uniq(type, return_type); + copy->CopyProperties(*this); + for (auto &child : children) { + copy->children.push_back(child->Copy()); + } + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_parameter_expression.cpp b/src/duckdb/src/planner/expression/bound_parameter_expression.cpp new file mode 100644 index 00000000..329b7191 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_parameter_expression.cpp @@ -0,0 +1,84 @@ +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/planner/expression_iterator.hpp" + +namespace duckdb { + +BoundParameterExpression::BoundParameterExpression(const string &identifier) + : Expression(ExpressionType::VALUE_PARAMETER, ExpressionClass::BOUND_PARAMETER, + LogicalType(LogicalTypeId::UNKNOWN)), + identifier(identifier) { +} + +BoundParameterExpression::BoundParameterExpression(bound_parameter_map_t &global_parameter_set, string identifier, + LogicalType return_type, + shared_ptr parameter_data) + : Expression(ExpressionType::VALUE_PARAMETER, ExpressionClass::BOUND_PARAMETER, std::move(return_type)), + identifier(std::move(identifier)) { + // check if we have already deserialized a parameter with this number + auto entry = global_parameter_set.find(this->identifier); + if (entry == global_parameter_set.end()) { + // we have not - store the entry we deserialized from this parameter expression + global_parameter_set[this->identifier] = parameter_data; + } else { + // we have! use the previously deserialized entry + parameter_data = entry->second; + } + this->parameter_data = std::move(parameter_data); +} + +void BoundParameterExpression::Invalidate(Expression &expr) { + if (expr.type != ExpressionType::VALUE_PARAMETER) { + throw InternalException("BoundParameterExpression::Invalidate requires a parameter as input"); + } + auto &bound_parameter = expr.Cast(); + bound_parameter.return_type = LogicalTypeId::SQLNULL; + bound_parameter.parameter_data->return_type = LogicalTypeId::INVALID; +} + +void BoundParameterExpression::InvalidateRecursive(Expression &expr) { + if (expr.type == ExpressionType::VALUE_PARAMETER) { + Invalidate(expr); + return; + } + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { InvalidateRecursive(child); }); +} + +bool BoundParameterExpression::IsScalar() const { + return true; +} +bool BoundParameterExpression::HasParameter() const { + return true; +} +bool BoundParameterExpression::IsFoldable() const { + return false; +} + +string BoundParameterExpression::ToString() const { + return "$" + identifier; +} + +bool BoundParameterExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return StringUtil::CIEquals(identifier, other.identifier); +} + +hash_t BoundParameterExpression::Hash() const { + hash_t result = Expression::Hash(); + result = CombineHash(duckdb::Hash(identifier.c_str(), identifier.size()), result); + return result; +} + +unique_ptr BoundParameterExpression::Copy() { + auto result = make_uniq(identifier); + result->parameter_data = parameter_data; + result->return_type = return_type; + result->CopyProperties(*this); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_reference_expression.cpp b/src/duckdb/src/planner/expression/bound_reference_expression.cpp new file mode 100644 index 00000000..b309b498 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_reference_expression.cpp @@ -0,0 +1,45 @@ +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +#include "duckdb/common/to_string.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +BoundReferenceExpression::BoundReferenceExpression(string alias, LogicalType type, idx_t index) + : Expression(ExpressionType::BOUND_REF, ExpressionClass::BOUND_REF, std::move(type)), index(index) { + this->alias = std::move(alias); +} +BoundReferenceExpression::BoundReferenceExpression(LogicalType type, idx_t index) + : BoundReferenceExpression(string(), std::move(type), index) { +} + +string BoundReferenceExpression::ToString() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return "#" + to_string(index); + } +#endif + if (!alias.empty()) { + return alias; + } + return "#" + to_string(index); +} + +bool BoundReferenceExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return other.index == index; +} + +hash_t BoundReferenceExpression::Hash() const { + return CombineHash(Expression::Hash(), duckdb::Hash(index)); +} + +unique_ptr BoundReferenceExpression::Copy() { + return make_uniq(alias, return_type, index); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_subquery_expression.cpp b/src/duckdb/src/planner/expression/bound_subquery_expression.cpp new file mode 100644 index 00000000..02a7770d --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_subquery_expression.cpp @@ -0,0 +1,29 @@ +#include "duckdb/planner/expression/bound_subquery_expression.hpp" + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +BoundSubqueryExpression::BoundSubqueryExpression(LogicalType return_type) + : Expression(ExpressionType::SUBQUERY, ExpressionClass::BOUND_SUBQUERY, std::move(return_type)) { +} + +string BoundSubqueryExpression::ToString() const { + return "SUBQUERY"; +} + +bool BoundSubqueryExpression::Equals(const BaseExpression &other_p) const { + // equality between bound subqueries not implemented currently + return false; +} + +unique_ptr BoundSubqueryExpression::Copy() { + throw SerializationException("Cannot copy BoundSubqueryExpression"); +} + +bool BoundSubqueryExpression::PropagatesNullValues() const { + // TODO this can be optimized further by checking the actual subquery node + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_unnest_expression.cpp b/src/duckdb/src/planner/expression/bound_unnest_expression.cpp new file mode 100644 index 00000000..0b86c783 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_unnest_expression.cpp @@ -0,0 +1,42 @@ +#include "duckdb/planner/expression/bound_unnest_expression.hpp" + +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +BoundUnnestExpression::BoundUnnestExpression(LogicalType return_type) + : Expression(ExpressionType::BOUND_UNNEST, ExpressionClass::BOUND_UNNEST, std::move(return_type)) { +} + +bool BoundUnnestExpression::IsFoldable() const { + return false; +} + +string BoundUnnestExpression::ToString() const { + return "UNNEST(" + child->ToString() + ")"; +} + +hash_t BoundUnnestExpression::Hash() const { + hash_t result = Expression::Hash(); + return CombineHash(result, duckdb::Hash("unnest")); +} + +bool BoundUnnestExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (!Expression::Equals(*child, *other.child)) { + return false; + } + return true; +} + +unique_ptr BoundUnnestExpression::Copy() { + auto copy = make_uniq(return_type); + copy->child = child->Copy(); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_window_expression.cpp b/src/duckdb/src/planner/expression/bound_window_expression.cpp new file mode 100644 index 00000000..684c7ef7 --- /dev/null +++ b/src/duckdb/src/planner/expression/bound_window_expression.cpp @@ -0,0 +1,159 @@ +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/aggregate_function.hpp" +#include "duckdb/function/function_serialization.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" + +namespace duckdb { + +BoundWindowExpression::BoundWindowExpression(ExpressionType type, LogicalType return_type, + unique_ptr aggregate, + unique_ptr bind_info) + : Expression(type, ExpressionClass::BOUND_WINDOW, std::move(return_type)), aggregate(std::move(aggregate)), + bind_info(std::move(bind_info)), ignore_nulls(false) { +} + +string BoundWindowExpression::ToString() const { + string function_name = aggregate.get() ? aggregate->name : ExpressionTypeToString(type); + return WindowExpression::ToString(*this, string(), + function_name); +} + +bool BoundWindowExpression::Equals(const BaseExpression &other_p) const { + if (!Expression::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + + if (ignore_nulls != other.ignore_nulls) { + return false; + } + if (start != other.start || end != other.end) { + return false; + } + // check if the child expressions are equivalent + if (!Expression::ListEquals(children, other.children)) { + return false; + } + // check if the filter expressions are equivalent + if (!Expression::Equals(filter_expr, other.filter_expr)) { + return false; + } + + // check if the framing expressions are equivalent + if (!Expression::Equals(start_expr, other.start_expr) || !Expression::Equals(end_expr, other.end_expr) || + !Expression::Equals(offset_expr, other.offset_expr) || !Expression::Equals(default_expr, other.default_expr)) { + return false; + } + + return KeysAreCompatible(other); +} + +bool BoundWindowExpression::KeysAreCompatible(const BoundWindowExpression &other) const { + // check if the partitions are equivalent + if (!Expression::ListEquals(partitions, other.partitions)) { + return false; + } + // check if the orderings are equivalent + if (orders.size() != other.orders.size()) { + return false; + } + for (idx_t i = 0; i < orders.size(); i++) { + if (!orders[i].Equals(other.orders[i])) { + return false; + } + } + return true; +} + +unique_ptr BoundWindowExpression::Copy() { + auto new_window = make_uniq(type, return_type, nullptr, nullptr); + new_window->CopyProperties(*this); + + if (aggregate) { + new_window->aggregate = make_uniq(*aggregate); + } + if (bind_info) { + new_window->bind_info = bind_info->Copy(); + } + for (auto &child : children) { + new_window->children.push_back(child->Copy()); + } + for (auto &e : partitions) { + new_window->partitions.push_back(e->Copy()); + } + for (auto &ps : partitions_stats) { + if (ps) { + new_window->partitions_stats.push_back(ps->ToUnique()); + } else { + new_window->partitions_stats.push_back(nullptr); + } + } + for (auto &o : orders) { + new_window->orders.emplace_back(o.type, o.null_order, o.expression->Copy()); + } + + new_window->filter_expr = filter_expr ? filter_expr->Copy() : nullptr; + + new_window->start = start; + new_window->end = end; + new_window->start_expr = start_expr ? start_expr->Copy() : nullptr; + new_window->end_expr = end_expr ? end_expr->Copy() : nullptr; + new_window->offset_expr = offset_expr ? offset_expr->Copy() : nullptr; + new_window->default_expr = default_expr ? default_expr->Copy() : nullptr; + new_window->ignore_nulls = ignore_nulls; + + return std::move(new_window); +} + +void BoundWindowExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WriteProperty(201, "children", children); + if (type == ExpressionType::WINDOW_AGGREGATE) { + D_ASSERT(aggregate); + FunctionSerializer::Serialize(serializer, *aggregate, bind_info.get()); + } + serializer.WriteProperty(202, "partitions", partitions); + serializer.WriteProperty(203, "orders", orders); + serializer.WritePropertyWithDefault(204, "filters", filter_expr, unique_ptr()); + serializer.WriteProperty(205, "ignore_nulls", ignore_nulls); + serializer.WriteProperty(206, "start", start); + serializer.WriteProperty(207, "end", end); + serializer.WritePropertyWithDefault(208, "start_expr", start_expr, unique_ptr()); + serializer.WritePropertyWithDefault(209, "end_expr", end_expr, unique_ptr()); + serializer.WritePropertyWithDefault(210, "offset_expr", offset_expr, unique_ptr()); + serializer.WritePropertyWithDefault(211, "default_expr", default_expr, unique_ptr()); +} + +unique_ptr BoundWindowExpression::Deserialize(Deserializer &deserializer) { + auto expression_type = deserializer.Get(); + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto children = deserializer.ReadProperty>>(201, "children"); + unique_ptr aggregate; + unique_ptr bind_info; + if (expression_type == ExpressionType::WINDOW_AGGREGATE) { + auto entry = FunctionSerializer::Deserialize( + deserializer, CatalogType::AGGREGATE_FUNCTION_ENTRY, children, return_type); + aggregate = make_uniq(std::move(entry.first)); + bind_info = std::move(entry.second); + } + auto result = + make_uniq(expression_type, return_type, std::move(aggregate), std::move(bind_info)); + result->children = std::move(children); + deserializer.ReadProperty(202, "partitions", result->partitions); + deserializer.ReadProperty(203, "orders", result->orders); + deserializer.ReadPropertyWithDefault(204, "filters", result->filter_expr, unique_ptr()); + deserializer.ReadProperty(205, "ignore_nulls", result->ignore_nulls); + deserializer.ReadProperty(206, "start", result->start); + deserializer.ReadProperty(207, "end", result->end); + deserializer.ReadPropertyWithDefault(208, "start_expr", result->start_expr, unique_ptr()); + deserializer.ReadPropertyWithDefault(209, "end_expr", result->end_expr, unique_ptr()); + deserializer.ReadPropertyWithDefault(210, "offset_expr", result->offset_expr, unique_ptr()); + deserializer.ReadPropertyWithDefault(211, "default_expr", result->default_expr, unique_ptr()); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder.cpp b/src/duckdb/src/planner/expression_binder.cpp new file mode 100644 index 00000000..266b4102 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder.cpp @@ -0,0 +1,271 @@ +#include "duckdb/planner/expression_binder.hpp" + +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/expression_iterator.hpp" + +namespace duckdb { + +ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder) + : binder(binder), context(context) { + InitializeStackCheck(); + if (replace_binder) { + stored_binder = &binder.GetActiveBinder(); + binder.SetActiveBinder(*this); + } else { + binder.PushExpressionBinder(*this); + } +} + +ExpressionBinder::~ExpressionBinder() { + if (binder.HasActiveBinder()) { + if (stored_binder) { + binder.SetActiveBinder(*stored_binder); + } else { + binder.PopExpressionBinder(); + } + } +} + +void ExpressionBinder::InitializeStackCheck() { + if (binder.HasActiveBinder()) { + stack_depth = binder.GetActiveBinder().stack_depth; + } else { + stack_depth = 0; + } +} + +StackChecker ExpressionBinder::StackCheck(const ParsedExpression &expr, idx_t extra_stack) { + D_ASSERT(stack_depth != DConstants::INVALID_INDEX); + if (stack_depth + extra_stack >= MAXIMUM_STACK_DEPTH) { + throw BinderException("Maximum recursion depth exceeded (Maximum: %llu) while binding \"%s\"", + MAXIMUM_STACK_DEPTH, expr.ToString()); + } + return StackChecker(*this, extra_stack); +} + +BindResult ExpressionBinder::BindExpression(unique_ptr &expr, idx_t depth, bool root_expression) { + auto stack_checker = StackCheck(*expr); + + auto &expr_ref = *expr; + switch (expr_ref.expression_class) { + case ExpressionClass::BETWEEN: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::CASE: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::CAST: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::COLLATE: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::COLUMN_REF: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::COMPARISON: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::CONJUNCTION: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::CONSTANT: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::FUNCTION: { + auto &function = expr_ref.Cast(); + if (function.function_name == "unnest" || function.function_name == "unlist") { + // special case, not in catalog + return BindUnnest(function, depth, root_expression); + } + // binding function expression has extra parameter needed for macro's + return BindExpression(function, depth, expr); + } + case ExpressionClass::LAMBDA: + return BindExpression(expr_ref.Cast(), depth, false, LogicalTypeId::INVALID); + case ExpressionClass::OPERATOR: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::SUBQUERY: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::PARAMETER: + return BindExpression(expr_ref.Cast(), depth); + case ExpressionClass::POSITIONAL_REFERENCE: { + return BindPositionalReference(expr, depth, root_expression); + } + case ExpressionClass::STAR: + return BindResult(binder.FormatError(expr_ref, "STAR expression is not supported here")); + default: + throw NotImplementedException("Unimplemented expression class"); + } +} + +bool ExpressionBinder::BindCorrelatedColumns(unique_ptr &expr) { + // try to bind in one of the outer queries, if the binding error occurred in a subquery + auto &active_binders = binder.GetActiveBinders(); + // make a copy of the set of binders, so we can restore it later + auto binders = active_binders; + + // we already failed with the current binder + active_binders.pop_back(); + idx_t depth = 1; + bool success = false; + + while (!active_binders.empty()) { + auto &next_binder = active_binders.back().get(); + ExpressionBinder::QualifyColumnNames(next_binder.binder, expr); + auto bind_result = next_binder.Bind(expr, depth); + if (bind_result.empty()) { + success = true; + break; + } + depth++; + active_binders.pop_back(); + } + active_binders = binders; + return success; +} + +void ExpressionBinder::BindChild(unique_ptr &expr, idx_t depth, string &error) { + if (expr) { + string bind_error = Bind(expr, depth); + if (error.empty()) { + error = bind_error; + } + } +} + +void ExpressionBinder::ExtractCorrelatedExpressions(Binder &binder, Expression &expr) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &bound_colref = expr.Cast(); + if (bound_colref.depth > 0) { + binder.AddCorrelatedColumn(CorrelatedColumnInfo(bound_colref)); + } + } + ExpressionIterator::EnumerateChildren(expr, + [&](Expression &child) { ExtractCorrelatedExpressions(binder, child); }); +} + +bool ExpressionBinder::ContainsType(const LogicalType &type, LogicalTypeId target) { + if (type.id() == target) { + return true; + } + switch (type.id()) { + case LogicalTypeId::STRUCT: { + auto child_count = StructType::GetChildCount(type); + for (idx_t i = 0; i < child_count; i++) { + if (ContainsType(StructType::GetChildType(type, i), target)) { + return true; + } + } + return false; + } + case LogicalTypeId::UNION: { + auto member_count = UnionType::GetMemberCount(type); + for (idx_t i = 0; i < member_count; i++) { + if (ContainsType(UnionType::GetMemberType(type, i), target)) { + return true; + } + } + return false; + } + case LogicalTypeId::LIST: + case LogicalTypeId::MAP: + return ContainsType(ListType::GetChildType(type), target); + default: + return false; + } +} + +LogicalType ExpressionBinder::ExchangeType(const LogicalType &type, LogicalTypeId target, LogicalType new_type) { + if (type.id() == target) { + return new_type; + } + switch (type.id()) { + case LogicalTypeId::STRUCT: { + // we make a copy of the child types of the struct here + auto child_types = StructType::GetChildTypes(type); + for (auto &child_type : child_types) { + child_type.second = ExchangeType(child_type.second, target, new_type); + } + return LogicalType::STRUCT(child_types); + } + case LogicalTypeId::UNION: { + auto member_types = UnionType::CopyMemberTypes(type); + for (auto &member_type : member_types) { + member_type.second = ExchangeType(member_type.second, target, new_type); + } + return LogicalType::UNION(std::move(member_types)); + } + case LogicalTypeId::LIST: + return LogicalType::LIST(ExchangeType(ListType::GetChildType(type), target, new_type)); + case LogicalTypeId::MAP: + return LogicalType::MAP(ExchangeType(ListType::GetChildType(type), target, new_type)); + default: + return type; + } +} + +bool ExpressionBinder::ContainsNullType(const LogicalType &type) { + return ContainsType(type, LogicalTypeId::SQLNULL); +} + +LogicalType ExpressionBinder::ExchangeNullType(const LogicalType &type) { + return ExchangeType(type, LogicalTypeId::SQLNULL, LogicalType::INTEGER); +} + +unique_ptr ExpressionBinder::Bind(unique_ptr &expr, optional_ptr result_type, + bool root_expression) { + // bind the main expression + auto error_msg = Bind(expr, 0, root_expression); + if (!error_msg.empty()) { + // failed to bind: try to bind correlated columns in the expression (if any) + bool success = BindCorrelatedColumns(expr); + if (!success) { + throw BinderException(error_msg); + } + auto &bound_expr = expr->Cast(); + ExtractCorrelatedExpressions(binder, *bound_expr.expr); + } + auto &bound_expr = expr->Cast(); + unique_ptr result = std::move(bound_expr.expr); + if (target_type.id() != LogicalTypeId::INVALID) { + // the binder has a specific target type: add a cast to that type + result = BoundCastExpression::AddCastToType(context, std::move(result), target_type); + } else { + if (!binder.can_contain_nulls) { + // SQL NULL type is only used internally in the binder + // cast to INTEGER if we encounter it outside of the binder + if (ContainsNullType(result->return_type)) { + auto exchanged_type = ExchangeNullType(result->return_type); + result = BoundCastExpression::AddCastToType(context, std::move(result), exchanged_type); + } + } + if (result->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + if (result_type) { + *result_type = result->return_type; + } + return result; +} + +string ExpressionBinder::Bind(unique_ptr &expr, idx_t depth, bool root_expression) { + // bind the node, but only if it has not been bound yet + auto &expression = *expr; + auto alias = expression.alias; + if (expression.GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) { + // already bound, don't bind it again + return string(); + } + // bind the expression + BindResult result = BindExpression(expr, depth, root_expression); + if (result.HasError()) { + return result.error; + } + // successfully bound: replace the node with a BoundExpression + expr = make_uniq(std::move(result.expression)); + auto &be = expr->Cast(); + be.alias = alias; + if (!alias.empty()) { + be.expr->alias = alias; + } + return string(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/aggregate_binder.cpp b/src/duckdb/src/planner/expression_binder/aggregate_binder.cpp new file mode 100644 index 00000000..cf172934 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/aggregate_binder.cpp @@ -0,0 +1,23 @@ +#include "duckdb/planner/expression_binder/aggregate_binder.hpp" + +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +AggregateBinder::AggregateBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context, true) { +} + +BindResult AggregateBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.expression_class) { + case ExpressionClass::WINDOW: + throw ParserException("aggregate function calls cannot contain window function calls"); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string AggregateBinder::UnsupportedAggregateMessage() { + return "aggregate function calls cannot be nested"; +} +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/alter_binder.cpp b/src/duckdb/src/planner/expression_binder/alter_binder.cpp new file mode 100644 index 00000000..597ef0ca --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/alter_binder.cpp @@ -0,0 +1,49 @@ +#include "duckdb/planner/expression_binder/alter_binder.hpp" + +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { + +AlterBinder::AlterBinder(Binder &binder, ClientContext &context, TableCatalogEntry &table, + vector &bound_columns, LogicalType target_type) + : ExpressionBinder(binder, context), table(table), bound_columns(bound_columns) { + this->target_type = std::move(target_type); +} + +BindResult AlterBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::WINDOW: + return BindResult("window functions are not allowed in alter statement"); + case ExpressionClass::SUBQUERY: + return BindResult("cannot use subquery in alter statement"); + case ExpressionClass::COLUMN_REF: + return BindColumn(expr.Cast()); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string AlterBinder::UnsupportedAggregateMessage() { + return "aggregate functions are not allowed in alter statement"; +} + +BindResult AlterBinder::BindColumn(ColumnRefExpression &colref) { + if (colref.column_names.size() > 1) { + return BindQualifiedColumnName(colref, table.name); + } + auto idx = table.GetColumnIndex(colref.column_names[0], true); + if (!idx.IsValid()) { + throw BinderException("Table does not contain column %s referenced in alter statement!", + colref.column_names[0]); + } + if (table.GetColumn(idx).Generated()) { + throw BinderException("Using generated columns in alter statement not supported"); + } + bound_columns.push_back(idx); + return BindResult(make_uniq(table.GetColumn(idx).Type(), bound_columns.size() - 1)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/base_select_binder.cpp b/src/duckdb/src/planner/expression_binder/base_select_binder.cpp new file mode 100644 index 00000000..542bfa52 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/base_select_binder.cpp @@ -0,0 +1,161 @@ +#include "duckdb/planner/expression_binder/base_select_binder.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/expression_binder/aggregate_binder.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" + +namespace duckdb { + +BaseSelectBinder::BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, + BoundGroupInformation &info, case_insensitive_map_t alias_map) + : ExpressionBinder(binder, context), inside_window(false), node(node), info(info), alias_map(std::move(alias_map)) { +} + +BaseSelectBinder::BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, + BoundGroupInformation &info) + : BaseSelectBinder(binder, context, node, info, case_insensitive_map_t()) { +} + +BindResult BaseSelectBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + // check if the expression binds to one of the groups + auto group_index = TryBindGroup(expr, depth); + if (group_index != DConstants::INVALID_INDEX) { + return BindGroup(expr, depth, group_index); + } + switch (expr.expression_class) { + case ExpressionClass::COLUMN_REF: + return BindColumnRef(expr_ptr, depth); + case ExpressionClass::DEFAULT: + return BindResult("SELECT clause cannot contain DEFAULT clause"); + case ExpressionClass::WINDOW: + return BindWindow(expr.Cast(), depth); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth, root_expression); + } +} + +idx_t BaseSelectBinder::TryBindGroup(ParsedExpression &expr, idx_t depth) { + // first check the group alias map, if expr is a ColumnRefExpression + if (expr.type == ExpressionType::COLUMN_REF) { + auto &colref = expr.Cast(); + if (!colref.IsQualified()) { + auto alias_entry = info.alias_map.find(colref.column_names[0]); + if (alias_entry != info.alias_map.end()) { + // found entry! + return alias_entry->second; + } + } + } + // no alias reference found + // check the list of group columns for a match + auto entry = info.map.find(expr); + if (entry != info.map.end()) { + return entry->second; + } +#ifdef DEBUG + for (auto entry : info.map) { + D_ASSERT(!entry.first.get().Equals(expr)); + D_ASSERT(!expr.Equals(entry.first.get())); + } +#endif + return DConstants::INVALID_INDEX; +} + +BindResult BaseSelectBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth) { + // first try to bind the column reference regularly + auto result = ExpressionBinder::BindExpression(expr_ptr, depth); + if (!result.HasError()) { + return result; + } + // binding failed + // check in the alias map + auto &colref = (expr_ptr.get())->Cast(); + if (!colref.IsQualified()) { + auto alias_entry = alias_map.find(colref.column_names[0]); + if (alias_entry != alias_map.end()) { + // found entry! + auto index = alias_entry->second; + if (index >= node.select_list.size()) { + throw BinderException("Column \"%s\" referenced that exists in the SELECT clause - but this column " + "cannot be referenced before it is defined", + colref.column_names[0]); + } + if (node.select_list[index]->HasSideEffects()) { + throw BinderException("Alias \"%s\" referenced in a SELECT clause - but the expression has side " + "effects. This is not yet supported.", + colref.column_names[0]); + } + if (node.select_list[index]->HasSubquery()) { + throw BinderException("Alias \"%s\" referenced in a SELECT clause - but the expression has a subquery." + " This is not yet supported.", + colref.column_names[0]); + } + auto result = BindResult(node.select_list[index]->Copy()); + if (result.expression->type == ExpressionType::BOUND_COLUMN_REF) { + auto &result_expr = result.expression->Cast(); + result_expr.depth = depth; + } + return result; + } + } + // entry was not found in the alias map: return the original error + return result; +} + +BindResult BaseSelectBinder::BindGroupingFunction(OperatorExpression &op, idx_t depth) { + if (op.children.empty()) { + throw InternalException("GROUPING requires at least one child"); + } + if (node.groups.group_expressions.empty()) { + return BindResult(binder.FormatError(op, "GROUPING statement cannot be used without groups")); + } + if (op.children.size() >= 64) { + return BindResult(binder.FormatError(op, "GROUPING statement cannot have more than 64 groups")); + } + vector group_indexes; + group_indexes.reserve(op.children.size()); + for (auto &child : op.children) { + ExpressionBinder::QualifyColumnNames(binder, child); + auto idx = TryBindGroup(*child, depth); + if (idx == DConstants::INVALID_INDEX) { + return BindResult(binder.FormatError( + op, StringUtil::Format("GROUPING child \"%s\" must be a grouping column", child->GetName()))); + } + group_indexes.push_back(idx); + } + auto col_idx = node.grouping_functions.size(); + node.grouping_functions.push_back(std::move(group_indexes)); + return BindResult(make_uniq(op.GetName(), LogicalType::BIGINT, + ColumnBinding(node.groupings_index, col_idx), depth)); +} + +BindResult BaseSelectBinder::BindGroup(ParsedExpression &expr, idx_t depth, idx_t group_index) { + auto it = info.collated_groups.find(group_index); + if (it != info.collated_groups.end()) { + // This is an implicitly collated group, so we need to refer to the first() aggregate + const auto &aggr_index = it->second; + return BindResult(make_uniq(expr.GetName(), node.aggregates[aggr_index]->return_type, + ColumnBinding(node.aggregate_index, aggr_index), depth)); + } else { + auto &group = node.groups.group_expressions[group_index]; + return BindResult(make_uniq(expr.GetName(), group->return_type, + ColumnBinding(node.group_index, group_index), depth)); + } +} + +bool BaseSelectBinder::QualifyColumnAlias(const ColumnRefExpression &colref) { + if (!colref.IsQualified()) { + return alias_map.find(colref.column_names[0]) != alias_map.end() ? true : false; + } + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/check_binder.cpp b/src/duckdb/src/planner/expression_binder/check_binder.cpp new file mode 100644 index 00000000..698d9e06 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/check_binder.cpp @@ -0,0 +1,76 @@ +#include "duckdb/planner/expression_binder/check_binder.hpp" + +#include "duckdb/planner/table_binding.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +CheckBinder::CheckBinder(Binder &binder, ClientContext &context, string table_p, const ColumnList &columns, + physical_index_set_t &bound_columns) + : ExpressionBinder(binder, context), table(std::move(table_p)), columns(columns), bound_columns(bound_columns) { + target_type = LogicalType::INTEGER; +} + +BindResult CheckBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::WINDOW: + return BindResult("window functions are not allowed in check constraints"); + case ExpressionClass::SUBQUERY: + return BindResult("cannot use subquery in check constraint"); + case ExpressionClass::COLUMN_REF: + return BindCheckColumn(expr.Cast()); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string CheckBinder::UnsupportedAggregateMessage() { + return "aggregate functions are not allowed in check constraints"; +} + +BindResult ExpressionBinder::BindQualifiedColumnName(ColumnRefExpression &colref, const string &table_name) { + idx_t struct_start = 0; + if (colref.column_names[0] == table_name) { + struct_start++; + } + auto result = make_uniq_base(colref.column_names.back()); + for (idx_t i = struct_start; i + 1 < colref.column_names.size(); i++) { + result = CreateStructExtract(std::move(result), colref.column_names[i]); + } + return BindExpression(result, 0); +} + +BindResult CheckBinder::BindCheckColumn(ColumnRefExpression &colref) { + + // if this is a lambda parameters, then we temporarily add a BoundLambdaRef, + // which we capture and remove later + if (lambda_bindings) { + for (idx_t i = 0; i < lambda_bindings->size(); i++) { + if (colref.GetColumnName() == (*lambda_bindings)[i].dummy_name) { + // FIXME: support lambdas in CHECK constraints + // FIXME: like so: return (*lambda_bindings)[i].Bind(colref, i, depth); + throw NotImplementedException("Lambda functions are currently not supported in CHECK constraints."); + } + } + } + + if (colref.column_names.size() > 1) { + return BindQualifiedColumnName(colref, table); + } + if (!columns.ColumnExists(colref.column_names[0])) { + throw BinderException("Table does not contain column %s referenced in check constraint!", + colref.column_names[0]); + } + auto &col = columns.GetColumn(colref.column_names[0]); + if (col.Generated()) { + auto bound_expression = col.GeneratedExpression().Copy(); + return BindExpression(bound_expression, 0, false); + } + bound_columns.insert(col.Physical()); + D_ASSERT(col.StorageOid() != DConstants::INVALID_INDEX); + return BindResult(make_uniq(col.Type(), col.StorageOid())); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp new file mode 100644 index 00000000..1e0a0247 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp @@ -0,0 +1,41 @@ +#include "duckdb/planner/expression_binder/column_alias_binder.hpp" + +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +ColumnAliasBinder::ColumnAliasBinder(BoundSelectNode &node, const case_insensitive_map_t &alias_map) + : node(node), alias_map(alias_map), visited_select_indexes() { +} + +BindResult ColumnAliasBinder::BindAlias(ExpressionBinder &enclosing_binder, ColumnRefExpression &expr, idx_t depth, + bool root_expression) { + if (expr.IsQualified()) { + return BindResult(StringUtil::Format("Alias %s cannot be qualified.", expr.ToString())); + } + + auto alias_entry = alias_map.find(expr.column_names[0]); + if (alias_entry == alias_map.end()) { + return BindResult(StringUtil::Format("Alias %s is not found.", expr.ToString())); + } + + if (visited_select_indexes.find(alias_entry->second) != visited_select_indexes.end()) { + return BindResult("Cannot resolve self-referential alias"); + } + + // found an alias: bind the alias expression + auto expression = node.original_expressions[alias_entry->second]->Copy(); + visited_select_indexes.insert(alias_entry->second); + + // since the alias has been found, pass a depth of 0. See Issue 4978 (#16) + // ColumnAliasBinders are only in Having, Qualify and Where Binders + auto result = enclosing_binder.BindExpression(expression, 0, root_expression); + visited_select_indexes.erase(alias_entry->second); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/constant_binder.cpp b/src/duckdb/src/planner/expression_binder/constant_binder.cpp new file mode 100644 index 00000000..fb93f363 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/constant_binder.cpp @@ -0,0 +1,39 @@ +#include "duckdb/planner/expression_binder/constant_binder.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" + +namespace duckdb { + +ConstantBinder::ConstantBinder(Binder &binder, ClientContext &context, string clause) + : ExpressionBinder(binder, context), clause(std::move(clause)) { +} + +BindResult ConstantBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::COLUMN_REF: { + auto &colref = expr.Cast(); + if (!colref.IsQualified()) { + auto value_function = GetSQLValueFunction(colref.GetColumnName()); + if (value_function) { + expr_ptr = std::move(value_function); + return BindExpression(expr_ptr, depth, root_expression); + } + } + return BindResult(clause + " cannot contain column names"); + } + case ExpressionClass::SUBQUERY: + throw BinderException(clause + " cannot contain subqueries"); + case ExpressionClass::DEFAULT: + return BindResult(clause + " cannot contain DEFAULT clause"); + case ExpressionClass::WINDOW: + return BindResult(clause + " cannot contain window functions!"); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string ConstantBinder::UnsupportedAggregateMessage() { + return clause + " cannot contain aggregates!"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/group_binder.cpp b/src/duckdb/src/planner/expression_binder/group_binder.cpp new file mode 100644 index 00000000..f2fb36a5 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/group_binder.cpp @@ -0,0 +1,110 @@ +#include "duckdb/planner/expression_binder/group_binder.hpp" + +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/common/to_string.hpp" + +namespace duckdb { + +GroupBinder::GroupBinder(Binder &binder, ClientContext &context, SelectNode &node, idx_t group_index, + case_insensitive_map_t &alias_map, case_insensitive_map_t &group_alias_map) + : ExpressionBinder(binder, context), node(node), alias_map(alias_map), group_alias_map(group_alias_map), + group_index(group_index) { +} + +BindResult GroupBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + if (root_expression && depth == 0) { + switch (expr.expression_class) { + case ExpressionClass::COLUMN_REF: + return BindColumnRef(expr.Cast()); + case ExpressionClass::CONSTANT: + return BindConstant(expr.Cast()); + case ExpressionClass::PARAMETER: + throw ParameterNotAllowedException("Parameter not supported in GROUP BY clause"); + default: + break; + } + } + switch (expr.expression_class) { + case ExpressionClass::DEFAULT: + return BindResult("GROUP BY clause cannot contain DEFAULT clause"); + case ExpressionClass::WINDOW: + return BindResult("GROUP BY clause cannot contain window functions!"); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string GroupBinder::UnsupportedAggregateMessage() { + return "GROUP BY clause cannot contain aggregates!"; +} + +BindResult GroupBinder::BindSelectRef(idx_t entry) { + if (used_aliases.find(entry) != used_aliases.end()) { + // the alias has already been bound to before! + // this happens if we group on the same alias twice + // e.g. GROUP BY k, k or GROUP BY 1, 1 + // in this case, we can just replace the grouping with a constant since the second grouping has no effect + // (the constant grouping will be optimized out later) + return BindResult(make_uniq(Value::INTEGER(42))); + } + if (entry >= node.select_list.size()) { + throw BinderException("GROUP BY term out of range - should be between 1 and %d", (int)node.select_list.size()); + } + // we replace the root expression, also replace the unbound expression + unbound_expression = node.select_list[entry]->Copy(); + // move the expression that this refers to here and bind it + auto select_entry = std::move(node.select_list[entry]); + auto binding = Bind(select_entry, nullptr, false); + // now replace the original expression in the select list with a reference to this group + group_alias_map[to_string(entry)] = bind_index; + node.select_list[entry] = make_uniq(to_string(entry)); + // insert into the set of used aliases + used_aliases.insert(entry); + return BindResult(std::move(binding)); +} + +BindResult GroupBinder::BindConstant(ConstantExpression &constant) { + // constant as root expression + if (!constant.value.type().IsIntegral()) { + // non-integral expression, we just leave the constant here. + return ExpressionBinder::BindExpression(constant, 0); + } + // INTEGER constant: we use the integer as an index into the select list (e.g. GROUP BY 1) + auto index = (idx_t)constant.value.GetValue(); + return BindSelectRef(index - 1); +} + +BindResult GroupBinder::BindColumnRef(ColumnRefExpression &colref) { + // columns in GROUP BY clauses: + // FIRST refer to the original tables, and + // THEN if no match is found refer to aliases in the SELECT list + // THEN if no match is found, refer to outer queries + + // first try to bind to the base columns (original tables) + auto result = ExpressionBinder::BindExpression(colref, 0); + if (result.HasError()) { + if (colref.IsQualified()) { + // explicit table name: not an alias reference + return result; + } + // failed to bind the column and the node is the root expression with depth = 0 + // check if refers to an alias in the select clause + auto alias_name = colref.column_names[0]; + auto entry = alias_map.find(alias_name); + if (entry == alias_map.end()) { + // no matching alias found + return result; + } + result = BindResult(BindSelectRef(entry->second)); + if (!result.HasError()) { + group_alias_map[alias_name] = bind_index; + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/having_binder.cpp b/src/duckdb/src/planner/expression_binder/having_binder.cpp new file mode 100644 index 00000000..75e92476 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/having_binder.cpp @@ -0,0 +1,61 @@ +#include "duckdb/planner/expression_binder/having_binder.hpp" + +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder/aggregate_binder.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" + +namespace duckdb { + +HavingBinder::HavingBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, + case_insensitive_map_t &alias_map, AggregateHandling aggregate_handling) + : BaseSelectBinder(binder, context, node, info), column_alias_binder(node, alias_map), + aggregate_handling(aggregate_handling) { + target_type = LogicalType(LogicalTypeId::BOOLEAN); +} + +BindResult HavingBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = expr_ptr->Cast(); + auto alias_result = column_alias_binder.BindAlias(*this, expr, depth, root_expression); + if (!alias_result.HasError()) { + if (depth > 0) { + throw BinderException("Having clause cannot reference alias in correlated subquery"); + } + return alias_result; + } + if (aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { + if (depth > 0) { + throw BinderException("Having clause cannot reference column in correlated subquery and group by all"); + } + auto expr = duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); + if (expr.HasError()) { + return expr; + } + auto group_ref = make_uniq( + expr.expression->return_type, ColumnBinding(node.group_index, node.groups.group_expressions.size())); + node.groups.group_expressions.push_back(std::move(expr.expression)); + return BindResult(std::move(group_ref)); + } + return BindResult(StringUtil::Format( + "column %s must appear in the GROUP BY clause or be used in an aggregate function", expr.ToString())); +} + +BindResult HavingBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + // check if the expression binds to one of the groups + auto group_index = TryBindGroup(expr, depth); + if (group_index != DConstants::INVALID_INDEX) { + return BindGroup(expr, depth, group_index); + } + switch (expr.expression_class) { + case ExpressionClass::WINDOW: + return BindResult("HAVING clause cannot contain window functions!"); + case ExpressionClass::COLUMN_REF: + return BindColumnRef(expr_ptr, depth, root_expression); + default: + return duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/index_binder.cpp b/src/duckdb/src/planner/expression_binder/index_binder.cpp new file mode 100644 index 00000000..e7af2b42 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/index_binder.cpp @@ -0,0 +1,56 @@ +#include "duckdb/planner/expression_binder/index_binder.hpp" + +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/column_binding.hpp" + +namespace duckdb { + +IndexBinder::IndexBinder(Binder &binder, ClientContext &context, optional_ptr table, + optional_ptr info) + : ExpressionBinder(binder, context), table(table), info(info) { +} + +BindResult IndexBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.expression_class) { + case ExpressionClass::WINDOW: + return BindResult("window functions are not allowed in index expressions"); + case ExpressionClass::SUBQUERY: + return BindResult("cannot use subquery in index expressions"); + case ExpressionClass::COLUMN_REF: { + if (table) { + // WAL replay + // we assume that the parsed expressions have qualified column names + // and that the columns exist in the table + auto &col_ref = expr.Cast(); + auto col_idx = table->GetColumnIndex(col_ref.column_names.back()); + auto col_type = table->GetColumn(col_idx).GetType(); + + // find the col_idx in the index.column_ids + auto col_id_idx = DConstants::INVALID_INDEX; + for (idx_t i = 0; i < info->column_ids.size(); i++) { + if (col_idx.index == info->column_ids[i]) { + col_id_idx = i; + } + } + + if (col_id_idx == DConstants::INVALID_INDEX) { + throw InternalException("failed to replay CREATE INDEX statement - column id not found"); + } + return BindResult( + make_uniq(col_ref.GetColumnName(), col_type, ColumnBinding(0, col_id_idx))); + } + return ExpressionBinder::BindExpression(expr_ptr, depth); + } + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string IndexBinder::UnsupportedAggregateMessage() { + return "aggregate functions are not allowed in index expressions"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/insert_binder.cpp b/src/duckdb/src/planner/expression_binder/insert_binder.cpp new file mode 100644 index 00000000..84cd2a4f --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/insert_binder.cpp @@ -0,0 +1,26 @@ +#include "duckdb/planner/expression_binder/insert_binder.hpp" + +#include "duckdb/planner/expression/bound_default_expression.hpp" + +namespace duckdb { + +InsertBinder::InsertBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { +} + +BindResult InsertBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::DEFAULT: + return BindResult("DEFAULT is not allowed here!"); + case ExpressionClass::WINDOW: + return BindResult("INSERT statement cannot contain window functions!"); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string InsertBinder::UnsupportedAggregateMessage() { + return "INSERT statement cannot contain aggregates!"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp new file mode 100644 index 00000000..18a68ff0 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp @@ -0,0 +1,120 @@ +#include "duckdb/planner/expression_binder/lateral_binder.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" + +namespace duckdb { + +LateralBinder::LateralBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { +} + +void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &bound_colref = expr.Cast(); + if (bound_colref.depth > 0) { + // add the correlated column info + CorrelatedColumnInfo info(bound_colref); + if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { + correlated_columns.push_back(std::move(info)); + } + } + } + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { ExtractCorrelatedColumns(child); }); +} + +BindResult LateralBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + if (depth == 0) { + throw InternalException("Lateral binder can only bind correlated columns"); + } + auto result = ExpressionBinder::BindExpression(expr_ptr, depth); + if (result.HasError()) { + return result; + } + ExtractCorrelatedColumns(*result.expression); + return result; +} + +BindResult LateralBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::DEFAULT: + return BindResult("LATERAL join cannot contain DEFAULT clause"); + case ExpressionClass::WINDOW: + return BindResult("LATERAL join cannot contain window functions!"); + case ExpressionClass::COLUMN_REF: + return BindColumnRef(expr_ptr, depth, root_expression); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string LateralBinder::UnsupportedAggregateMessage() { + return "LATERAL join cannot contain aggregates!"; +} + +class ExpressionDepthReducer : public LogicalOperatorVisitor { +public: + explicit ExpressionDepthReducer(const vector &correlated) : correlated_columns(correlated) { + } + +protected: + void ReduceColumnRefDepth(BoundColumnRefExpression &expr) { + // don't need to reduce this + if (expr.depth == 0) { + return; + } + for (auto &correlated : correlated_columns) { + if (correlated.binding == expr.binding) { + D_ASSERT(expr.depth > 1); + expr.depth--; + break; + } + } + } + + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override { + ReduceColumnRefDepth(expr); + return nullptr; + } + + void ReduceExpressionSubquery(BoundSubqueryExpression &expr) { + for (auto &s_correlated : expr.binder->correlated_columns) { + for (auto &correlated : correlated_columns) { + if (correlated == s_correlated) { + s_correlated.depth--; + break; + } + } + } + } + + void ReduceExpressionDepth(Expression &expr) { + if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + ReduceColumnRefDepth(expr.Cast()); + } + if (expr.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) { + auto &subquery_ref = expr.Cast(); + ReduceExpressionSubquery(expr.Cast()); + // Recursively update the depth in the bindings of the children nodes + ExpressionIterator::EnumerateQueryNodeChildren( + *subquery_ref.subquery, [&](Expression &child_expr) { ReduceExpressionDepth(child_expr); }); + } + } + + unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override { + ReduceExpressionSubquery(expr); + ExpressionIterator::EnumerateQueryNodeChildren( + *expr.subquery, [&](Expression &child_expr) { ReduceExpressionDepth(child_expr); }); + return nullptr; + } + + const vector &correlated_columns; +}; + +void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector &correlated) { + ExpressionDepthReducer depth_reducer(correlated); + depth_reducer.VisitOperator(op); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/order_binder.cpp b/src/duckdb/src/planner/expression_binder/order_binder.cpp new file mode 100644 index 00000000..ee59c684 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/order_binder.cpp @@ -0,0 +1,134 @@ +#include "duckdb/planner/expression_binder/order_binder.hpp" + +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +OrderBinder::OrderBinder(vector binders, idx_t projection_index, case_insensitive_map_t &alias_map, + parsed_expression_map_t &projection_map, idx_t max_count) + : binders(std::move(binders)), projection_index(projection_index), max_count(max_count), extra_list(nullptr), + alias_map(alias_map), projection_map(projection_map) { +} +OrderBinder::OrderBinder(vector binders, idx_t projection_index, SelectNode &node, + case_insensitive_map_t &alias_map, parsed_expression_map_t &projection_map) + : binders(std::move(binders)), projection_index(projection_index), alias_map(alias_map), + projection_map(projection_map) { + this->max_count = node.select_list.size(); + this->extra_list = &node.select_list; +} + +unique_ptr OrderBinder::CreateProjectionReference(ParsedExpression &expr, idx_t index) { + string alias; + if (extra_list && index < extra_list->size()) { + alias = extra_list->at(index)->ToString(); + } else { + if (!expr.alias.empty()) { + alias = expr.alias; + } + } + return make_uniq(std::move(alias), LogicalType::INVALID, + ColumnBinding(projection_index, index)); +} + +unique_ptr OrderBinder::CreateExtraReference(unique_ptr expr) { + if (!extra_list) { + throw InternalException("CreateExtraReference called without extra_list"); + } + projection_map[*expr] = extra_list->size(); + auto result = CreateProjectionReference(*expr, extra_list->size()); + extra_list->push_back(std::move(expr)); + return result; +} + +unique_ptr OrderBinder::BindConstant(ParsedExpression &expr, const Value &val) { + // ORDER BY a constant + if (!val.type().IsIntegral()) { + // non-integral expression, we just leave the constant here. + // ORDER BY has no effect + // CONTROVERSIAL: maybe we should throw an error + return nullptr; + } + // INTEGER constant: we use the integer as an index into the select list (e.g. ORDER BY 1) + auto index = (idx_t)val.GetValue(); + if (index < 1 || index > max_count) { + throw BinderException("ORDER term out of range - should be between 1 and %lld", (idx_t)max_count); + } + return CreateProjectionReference(expr, index - 1); +} + +unique_ptr OrderBinder::Bind(unique_ptr expr) { + // in the ORDER BY clause we do not bind children + // we bind ONLY to the select list + // if there is no matching entry in the SELECT list already, we add the expression to the SELECT list and refer the + // new expression the new entry will then be bound later during the binding of the SELECT list we also don't do type + // resolution here: this only happens after the SELECT list has been bound + switch (expr->expression_class) { + case ExpressionClass::CONSTANT: { + // ORDER BY constant + // is the ORDER BY expression a constant integer? (e.g. ORDER BY 1) + auto &constant = expr->Cast(); + return BindConstant(*expr, constant.value); + } + case ExpressionClass::COLUMN_REF: { + // COLUMN REF expression + // check if we can bind it to an alias in the select list + auto &colref = expr->Cast(); + // if there is an explicit table name we can't bind to an alias + if (colref.IsQualified()) { + break; + } + // check the alias list + auto entry = alias_map.find(colref.column_names[0]); + if (entry != alias_map.end()) { + // it does! point it to that entry + return CreateProjectionReference(*expr, entry->second); + } + break; + } + case ExpressionClass::POSITIONAL_REFERENCE: { + auto &posref = expr->Cast(); + if (posref.index < 1 || posref.index > max_count) { + throw BinderException("ORDER term out of range - should be between 1 and %lld", (idx_t)max_count); + } + return CreateProjectionReference(*expr, posref.index - 1); + } + case ExpressionClass::PARAMETER: { + throw ParameterNotAllowedException("Parameter not supported in ORDER BY clause"); + } + default: + break; + } + // general case + // first bind the table names of this entry + for (auto &binder : binders) { + ExpressionBinder::QualifyColumnNames(*binder, expr); + } + // first check if the ORDER BY clause already points to an entry in the projection list + auto entry = projection_map.find(*expr); + if (entry != projection_map.end()) { + if (entry->second == DConstants::INVALID_INDEX) { + throw BinderException("Ambiguous reference to column"); + } + // there is a matching entry in the projection list + // just point to that entry + return CreateProjectionReference(*expr, entry->second); + } + if (!extra_list) { + // no extra list specified: we cannot push an extra ORDER BY clause + throw BinderException("Could not ORDER BY column \"%s\": add the expression/function to every SELECT, or move " + "the UNION into a FROM clause.", + expr->ToString()); + } + // otherwise we need to push the ORDER BY entry into the select list + return CreateExtraReference(std::move(expr)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/qualify_binder.cpp b/src/duckdb/src/planner/expression_binder/qualify_binder.cpp new file mode 100644 index 00000000..59a97e61 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/qualify_binder.cpp @@ -0,0 +1,51 @@ +#include "duckdb/planner/expression_binder/qualify_binder.hpp" + +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression_binder/aggregate_binder.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/parser/expression/window_expression.hpp" + +namespace duckdb { + +QualifyBinder::QualifyBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, + case_insensitive_map_t &alias_map) + : BaseSelectBinder(binder, context, node, info), column_alias_binder(node, alias_map) { + target_type = LogicalType(LogicalTypeId::BOOLEAN); +} + +BindResult QualifyBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = expr_ptr->Cast(); + auto result = duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); + if (!result.HasError()) { + return result; + } + + auto alias_result = column_alias_binder.BindAlias(*this, expr, depth, root_expression); + if (!alias_result.HasError()) { + return alias_result; + } + + return BindResult(StringUtil::Format("Referenced column %s not found in FROM clause and can't find in alias map.", + expr.ToString())); +} + +BindResult QualifyBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + // check if the expression binds to one of the groups + auto group_index = TryBindGroup(expr, depth); + if (group_index != DConstants::INVALID_INDEX) { + return BindGroup(expr, depth, group_index); + } + switch (expr.expression_class) { + case ExpressionClass::WINDOW: + return BindWindow(expr.Cast(), depth); + case ExpressionClass::COLUMN_REF: + return BindColumnRef(expr_ptr, depth, root_expression); + default: + return duckdb::BaseSelectBinder::BindExpression(expr_ptr, depth); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/relation_binder.cpp b/src/duckdb/src/planner/expression_binder/relation_binder.cpp new file mode 100644 index 00000000..4e041574 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/relation_binder.cpp @@ -0,0 +1,29 @@ +#include "duckdb/planner/expression_binder/relation_binder.hpp" + +namespace duckdb { + +RelationBinder::RelationBinder(Binder &binder, ClientContext &context, string op) + : ExpressionBinder(binder, context), op(std::move(op)) { +} + +BindResult RelationBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.expression_class) { + case ExpressionClass::AGGREGATE: + return BindResult("aggregate functions are not allowed in " + op); + case ExpressionClass::DEFAULT: + return BindResult(op + " cannot contain DEFAULT clause"); + case ExpressionClass::SUBQUERY: + return BindResult("subqueries are not allowed in " + op); + case ExpressionClass::WINDOW: + return BindResult("window functions are not allowed in " + op); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string RelationBinder::UnsupportedAggregateMessage() { + return "aggregate functions are not allowed in " + op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/returning_binder.cpp b/src/duckdb/src/planner/expression_binder/returning_binder.cpp new file mode 100644 index 00000000..8d731b21 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/returning_binder.cpp @@ -0,0 +1,24 @@ +#include "duckdb/planner/expression_binder/returning_binder.hpp" + +#include "duckdb/planner/expression/bound_default_expression.hpp" + +namespace duckdb { + +ReturningBinder::ReturningBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { +} + +BindResult ReturningBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::SUBQUERY: + return BindResult("SUBQUERY is not supported in returning statements"); + case ExpressionClass::BOUND_SUBQUERY: + return BindResult("BOUND SUBQUERY is not supported in returning statements"); + case ExpressionClass::COLUMN_REF: + return ExpressionBinder::BindExpression(expr_ptr, depth); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/select_binder.cpp b/src/duckdb/src/planner/expression_binder/select_binder.cpp new file mode 100644 index 00000000..d83b2e6c --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/select_binder.cpp @@ -0,0 +1,14 @@ +#include "duckdb/planner/expression_binder/select_binder.hpp" + +namespace duckdb { + +SelectBinder::SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, + case_insensitive_map_t alias_map) + : BaseSelectBinder(binder, context, node, info, std::move(alias_map)) { +} + +SelectBinder::SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info) + : SelectBinder(binder, context, node, info, case_insensitive_map_t()) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp new file mode 100644 index 00000000..30010b1b --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp @@ -0,0 +1,53 @@ +#include "duckdb/planner/expression_binder/table_function_binder.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/table_binding.hpp" + +namespace duckdb { + +TableFunctionBinder::TableFunctionBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { +} + +BindResult TableFunctionBinder::BindColumnReference(ColumnRefExpression &expr, idx_t depth, bool root_expression) { + + // if this is a lambda parameters, then we temporarily add a BoundLambdaRef, + // which we capture and remove later + if (lambda_bindings) { + auto &colref = expr.Cast(); + for (idx_t i = 0; i < lambda_bindings->size(); i++) { + if (colref.GetColumnName() == (*lambda_bindings)[i].dummy_name) { + return (*lambda_bindings)[i].Bind(colref, i, depth); + } + } + } + auto value_function = ExpressionBinder::GetSQLValueFunction(expr.GetColumnName()); + if (value_function) { + return BindExpression(value_function, depth, root_expression); + } + + auto result_name = StringUtil::Join(expr.column_names, "."); + return BindResult(make_uniq(Value(result_name))); +} + +BindResult TableFunctionBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, + bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::COLUMN_REF: + return BindColumnReference(expr.Cast(), depth, root_expression); + case ExpressionClass::SUBQUERY: + throw BinderException("Table function cannot contain subqueries"); + case ExpressionClass::DEFAULT: + return BindResult("Table function cannot contain DEFAULT clause"); + case ExpressionClass::WINDOW: + return BindResult("Table function cannot contain window functions!"); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string TableFunctionBinder::UnsupportedAggregateMessage() { + return "Table function cannot contain aggregates!"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/update_binder.cpp b/src/duckdb/src/planner/expression_binder/update_binder.cpp new file mode 100644 index 00000000..4537957e --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/update_binder.cpp @@ -0,0 +1,22 @@ +#include "duckdb/planner/expression_binder/update_binder.hpp" + +namespace duckdb { + +UpdateBinder::UpdateBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { +} + +BindResult UpdateBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.expression_class) { + case ExpressionClass::WINDOW: + return BindResult("window functions are not allowed in UPDATE"); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string UpdateBinder::UnsupportedAggregateMessage() { + return "aggregate functions are not allowed in UPDATE"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/where_binder.cpp b/src/duckdb/src/planner/expression_binder/where_binder.cpp new file mode 100644 index 00000000..182687f8 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/where_binder.cpp @@ -0,0 +1,46 @@ +#include "duckdb/planner/expression_binder/where_binder.hpp" +#include "duckdb/planner/expression_binder/column_alias_binder.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" + +namespace duckdb { + +WhereBinder::WhereBinder(Binder &binder, ClientContext &context, optional_ptr column_alias_binder) + : ExpressionBinder(binder, context), column_alias_binder(column_alias_binder) { + target_type = LogicalType(LogicalTypeId::BOOLEAN); +} + +BindResult WhereBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = expr_ptr->Cast(); + auto result = ExpressionBinder::BindExpression(expr_ptr, depth); + if (!result.HasError() || !column_alias_binder) { + return result; + } + + BindResult alias_result = column_alias_binder->BindAlias(*this, expr, depth, root_expression); + // This code path cannot be exercised at thispoint. #1547 might change that. + if (!alias_result.HasError()) { + return alias_result; + } + + return result; +} + +BindResult WhereBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + auto &expr = *expr_ptr; + switch (expr.GetExpressionClass()) { + case ExpressionClass::DEFAULT: + return BindResult("WHERE clause cannot contain DEFAULT clause"); + case ExpressionClass::WINDOW: + return BindResult("WHERE clause cannot contain window functions!"); + case ExpressionClass::COLUMN_REF: + return BindColumnRef(expr_ptr, depth, root_expression); + default: + return ExpressionBinder::BindExpression(expr_ptr, depth); + } +} + +string WhereBinder::UnsupportedAggregateMessage() { + return "WHERE clause cannot contain aggregates!"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_iterator.cpp b/src/duckdb/src/planner/expression_iterator.cpp new file mode 100644 index 00000000..34e6bc10 --- /dev/null +++ b/src/duckdb/src/planner/expression_iterator.cpp @@ -0,0 +1,254 @@ +#include "duckdb/planner/expression_iterator.hpp" + +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/query_node/bound_set_operation_node.hpp" +#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" +#include "duckdb/planner/query_node/bound_cte_node.hpp" +#include "duckdb/planner/tableref/list.hpp" + +namespace duckdb { + +void ExpressionIterator::EnumerateChildren(const Expression &expr, + const std::function &callback) { + EnumerateChildren((Expression &)expr, [&](unique_ptr &child) { callback(*child); }); +} + +void ExpressionIterator::EnumerateChildren(Expression &expr, const std::function &callback) { + EnumerateChildren(expr, [&](unique_ptr &child) { callback(*child); }); +} + +void ExpressionIterator::EnumerateChildren(Expression &expr, + const std::function &child)> &callback) { + switch (expr.expression_class) { + case ExpressionClass::BOUND_AGGREGATE: { + auto &aggr_expr = expr.Cast(); + for (auto &child : aggr_expr.children) { + callback(child); + } + if (aggr_expr.filter) { + callback(aggr_expr.filter); + } + if (aggr_expr.order_bys) { + for (auto &order : aggr_expr.order_bys->orders) { + callback(order.expression); + } + } + break; + } + case ExpressionClass::BOUND_BETWEEN: { + auto &between_expr = expr.Cast(); + callback(between_expr.input); + callback(between_expr.lower); + callback(between_expr.upper); + break; + } + case ExpressionClass::BOUND_CASE: { + auto &case_expr = expr.Cast(); + for (auto &case_check : case_expr.case_checks) { + callback(case_check.when_expr); + callback(case_check.then_expr); + } + callback(case_expr.else_expr); + break; + } + case ExpressionClass::BOUND_CAST: { + auto &cast_expr = expr.Cast(); + callback(cast_expr.child); + break; + } + case ExpressionClass::BOUND_COMPARISON: { + auto &comp_expr = expr.Cast(); + callback(comp_expr.left); + callback(comp_expr.right); + break; + } + case ExpressionClass::BOUND_CONJUNCTION: { + auto &conj_expr = expr.Cast(); + for (auto &child : conj_expr.children) { + callback(child); + } + break; + } + case ExpressionClass::BOUND_FUNCTION: { + auto &func_expr = expr.Cast(); + for (auto &child : func_expr.children) { + callback(child); + } + break; + } + case ExpressionClass::BOUND_OPERATOR: { + auto &op_expr = expr.Cast(); + for (auto &child : op_expr.children) { + callback(child); + } + break; + } + case ExpressionClass::BOUND_SUBQUERY: { + auto &subquery_expr = expr.Cast(); + if (subquery_expr.child) { + callback(subquery_expr.child); + } + break; + } + case ExpressionClass::BOUND_WINDOW: { + auto &window_expr = expr.Cast(); + for (auto &partition : window_expr.partitions) { + callback(partition); + } + for (auto &order : window_expr.orders) { + callback(order.expression); + } + for (auto &child : window_expr.children) { + callback(child); + } + if (window_expr.filter_expr) { + callback(window_expr.filter_expr); + } + if (window_expr.start_expr) { + callback(window_expr.start_expr); + } + if (window_expr.end_expr) { + callback(window_expr.end_expr); + } + if (window_expr.offset_expr) { + callback(window_expr.offset_expr); + } + if (window_expr.default_expr) { + callback(window_expr.default_expr); + } + break; + } + case ExpressionClass::BOUND_UNNEST: { + auto &unnest_expr = expr.Cast(); + callback(unnest_expr.child); + break; + } + case ExpressionClass::BOUND_COLUMN_REF: + case ExpressionClass::BOUND_LAMBDA_REF: + case ExpressionClass::BOUND_CONSTANT: + case ExpressionClass::BOUND_DEFAULT: + case ExpressionClass::BOUND_PARAMETER: + case ExpressionClass::BOUND_REF: + // these node types have no children + break; + default: + throw InternalException("ExpressionIterator used on unbound expression"); + } +} + +void ExpressionIterator::EnumerateExpression(unique_ptr &expr, + const std::function &callback) { + if (!expr) { + return; + } + callback(*expr); + ExpressionIterator::EnumerateChildren(*expr, + [&](unique_ptr &child) { EnumerateExpression(child, callback); }); +} + +void ExpressionIterator::EnumerateTableRefChildren(BoundTableRef &ref, + const std::function &callback) { + switch (ref.type) { + case TableReferenceType::EXPRESSION_LIST: { + auto &bound_expr_list = ref.Cast(); + for (auto &expr_list : bound_expr_list.values) { + for (auto &expr : expr_list) { + EnumerateExpression(expr, callback); + } + } + break; + } + case TableReferenceType::JOIN: { + auto &bound_join = ref.Cast(); + if (bound_join.condition) { + EnumerateExpression(bound_join.condition, callback); + } + EnumerateTableRefChildren(*bound_join.left, callback); + EnumerateTableRefChildren(*bound_join.right, callback); + break; + } + case TableReferenceType::SUBQUERY: { + auto &bound_subquery = ref.Cast(); + EnumerateQueryNodeChildren(*bound_subquery.subquery, callback); + break; + } + case TableReferenceType::TABLE_FUNCTION: + case TableReferenceType::EMPTY: + case TableReferenceType::BASE_TABLE: + case TableReferenceType::CTE: + break; + default: + throw NotImplementedException("Unimplemented table reference type in ExpressionIterator"); + } +} + +void ExpressionIterator::EnumerateQueryNodeChildren(BoundQueryNode &node, + const std::function &callback) { + switch (node.type) { + case QueryNodeType::SET_OPERATION_NODE: { + auto &bound_setop = node.Cast(); + EnumerateQueryNodeChildren(*bound_setop.left, callback); + EnumerateQueryNodeChildren(*bound_setop.right, callback); + break; + } + case QueryNodeType::RECURSIVE_CTE_NODE: { + auto &cte_node = node.Cast(); + EnumerateQueryNodeChildren(*cte_node.left, callback); + EnumerateQueryNodeChildren(*cte_node.right, callback); + break; + } + case QueryNodeType::CTE_NODE: { + auto &cte_node = node.Cast(); + EnumerateQueryNodeChildren(*cte_node.child, callback); + break; + } + case QueryNodeType::SELECT_NODE: { + auto &bound_select = node.Cast(); + for (auto &expr : bound_select.select_list) { + EnumerateExpression(expr, callback); + } + EnumerateExpression(bound_select.where_clause, callback); + for (auto &expr : bound_select.groups.group_expressions) { + EnumerateExpression(expr, callback); + } + EnumerateExpression(bound_select.having, callback); + for (auto &expr : bound_select.aggregates) { + EnumerateExpression(expr, callback); + } + for (auto &entry : bound_select.unnests) { + for (auto &expr : entry.second.expressions) { + EnumerateExpression(expr, callback); + } + } + for (auto &expr : bound_select.windows) { + EnumerateExpression(expr, callback); + } + if (bound_select.from_table) { + EnumerateTableRefChildren(*bound_select.from_table, callback); + } + break; + } + default: + throw NotImplementedException("Unimplemented query node in ExpressionIterator"); + } + for (idx_t i = 0; i < node.modifiers.size(); i++) { + switch (node.modifiers[i]->type) { + case ResultModifierType::DISTINCT_MODIFIER: + for (auto &expr : node.modifiers[i]->Cast().target_distincts) { + EnumerateExpression(expr, callback); + } + break; + case ResultModifierType::ORDER_MODIFIER: + for (auto &order : node.modifiers[i]->Cast().orders) { + EnumerateExpression(order.expression, callback); + } + break; + default: + break; + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/conjunction_filter.cpp b/src/duckdb/src/planner/filter/conjunction_filter.cpp new file mode 100644 index 00000000..b1b43d6d --- /dev/null +++ b/src/duckdb/src/planner/filter/conjunction_filter.cpp @@ -0,0 +1,94 @@ +#include "duckdb/planner/filter/conjunction_filter.hpp" + +namespace duckdb { + +ConjunctionOrFilter::ConjunctionOrFilter() : ConjunctionFilter(TableFilterType::CONJUNCTION_OR) { +} + +FilterPropagateResult ConjunctionOrFilter::CheckStatistics(BaseStatistics &stats) { + // the OR filter is true if ANY of the children is true + D_ASSERT(!child_filters.empty()); + for (auto &filter : child_filters) { + auto prune_result = filter->CheckStatistics(stats); + if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else if (prune_result == FilterPropagateResult::FILTER_ALWAYS_TRUE) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + } + return FilterPropagateResult::FILTER_ALWAYS_FALSE; +} + +string ConjunctionOrFilter::ToString(const string &column_name) { + string result; + for (idx_t i = 0; i < child_filters.size(); i++) { + if (i > 0) { + result += " OR "; + } + result += child_filters[i]->ToString(column_name); + } + return result; +} + +bool ConjunctionOrFilter::Equals(const TableFilter &other_p) const { + if (!ConjunctionFilter::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (other.child_filters.size() != child_filters.size()) { + return false; + } + for (idx_t i = 0; i < other.child_filters.size(); i++) { + if (!child_filters[i]->Equals(*other.child_filters[i])) { + return false; + } + } + return true; +} + +ConjunctionAndFilter::ConjunctionAndFilter() : ConjunctionFilter(TableFilterType::CONJUNCTION_AND) { +} + +FilterPropagateResult ConjunctionAndFilter::CheckStatistics(BaseStatistics &stats) { + // the AND filter is true if ALL of the children is true + D_ASSERT(!child_filters.empty()); + auto result = FilterPropagateResult::FILTER_ALWAYS_TRUE; + for (auto &filter : child_filters) { + auto prune_result = filter->CheckStatistics(stats); + if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } else if (prune_result != result) { + result = FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + } + return result; +} + +string ConjunctionAndFilter::ToString(const string &column_name) { + string result; + for (idx_t i = 0; i < child_filters.size(); i++) { + if (i > 0) { + result += " AND "; + } + result += child_filters[i]->ToString(column_name); + } + return result; +} + +bool ConjunctionAndFilter::Equals(const TableFilter &other_p) const { + if (!ConjunctionFilter::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + if (other.child_filters.size() != child_filters.size()) { + return false; + } + for (idx_t i = 0; i < other.child_filters.size(); i++) { + if (!child_filters[i]->Equals(*other.child_filters[i])) { + return false; + } + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/constant_filter.cpp b/src/duckdb/src/planner/filter/constant_filter.cpp new file mode 100644 index 00000000..3f5084d5 --- /dev/null +++ b/src/duckdb/src/planner/filter/constant_filter.cpp @@ -0,0 +1,45 @@ +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +ConstantFilter::ConstantFilter(ExpressionType comparison_type_p, Value constant_p) + : TableFilter(TableFilterType::CONSTANT_COMPARISON), comparison_type(comparison_type_p), + constant(std::move(constant_p)) { +} + +FilterPropagateResult ConstantFilter::CheckStatistics(BaseStatistics &stats) { + D_ASSERT(constant.type().id() == stats.GetType().id()); + switch (constant.type().InternalType()) { + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return NumericStats::CheckZonemap(stats, comparison_type, constant); + case PhysicalType::VARCHAR: + return StringStats::CheckZonemap(stats, comparison_type, StringValue::Get(constant)); + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } +} + +string ConstantFilter::ToString(const string &column_name) { + return column_name + ExpressionTypeToOperator(comparison_type) + constant.ToString(); +} + +bool ConstantFilter::Equals(const TableFilter &other_p) const { + if (!TableFilter::Equals(other_p)) { + return false; + } + auto &other = other_p.Cast(); + return other.comparison_type == comparison_type && other.constant == constant; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/null_filter.cpp b/src/duckdb/src/planner/filter/null_filter.cpp new file mode 100644 index 00000000..341198ac --- /dev/null +++ b/src/duckdb/src/planner/filter/null_filter.cpp @@ -0,0 +1,44 @@ +#include "duckdb/planner/filter/null_filter.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +IsNullFilter::IsNullFilter() : TableFilter(TableFilterType::IS_NULL) { +} + +FilterPropagateResult IsNullFilter::CheckStatistics(BaseStatistics &stats) { + if (!stats.CanHaveNull()) { + // no null values are possible: always false + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + if (!stats.CanHaveNoNull()) { + // no non-null values are possible: always true + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +string IsNullFilter::ToString(const string &column_name) { + return column_name + "IS NULL"; +} + +IsNotNullFilter::IsNotNullFilter() : TableFilter(TableFilterType::IS_NOT_NULL) { +} + +FilterPropagateResult IsNotNullFilter::CheckStatistics(BaseStatistics &stats) { + if (!stats.CanHaveNoNull()) { + // no non-null values are possible: always false + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + if (!stats.CanHaveNull()) { + // no null values are possible: always true + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +string IsNotNullFilter::ToString(const string &column_name) { + return column_name + " IS NOT NULL"; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/joinside.cpp b/src/duckdb/src/planner/joinside.cpp new file mode 100644 index 00000000..3b24f337 --- /dev/null +++ b/src/duckdb/src/planner/joinside.cpp @@ -0,0 +1,104 @@ +#include "duckdb/planner/joinside.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" + +namespace duckdb { + +unique_ptr JoinCondition::CreateExpression(JoinCondition cond) { + auto bound_comparison = + make_uniq(cond.comparison, std::move(cond.left), std::move(cond.right)); + return std::move(bound_comparison); +} + +unique_ptr JoinCondition::CreateExpression(vector conditions) { + unique_ptr result; + for (auto &cond : conditions) { + auto expr = CreateExpression(std::move(cond)); + if (!result) { + result = std::move(expr); + } else { + auto conj = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr), + std::move(result)); + result = std::move(conj); + } + } + return result; +} + +JoinSide JoinSide::CombineJoinSide(JoinSide left, JoinSide right) { + if (left == JoinSide::NONE) { + return right; + } + if (right == JoinSide::NONE) { + return left; + } + if (left != right) { + return JoinSide::BOTH; + } + return left; +} + +JoinSide JoinSide::GetJoinSide(idx_t table_binding, const unordered_set &left_bindings, + const unordered_set &right_bindings) { + if (left_bindings.find(table_binding) != left_bindings.end()) { + // column references table on left side + D_ASSERT(right_bindings.find(table_binding) == right_bindings.end()); + return JoinSide::LEFT; + } else { + // column references table on right side + D_ASSERT(right_bindings.find(table_binding) != right_bindings.end()); + return JoinSide::RIGHT; + } +} + +JoinSide JoinSide::GetJoinSide(Expression &expression, const unordered_set &left_bindings, + const unordered_set &right_bindings) { + if (expression.type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expression.Cast(); + if (colref.depth > 0) { + throw Exception("Non-inner join on correlated columns not supported"); + } + return GetJoinSide(colref.binding.table_index, left_bindings, right_bindings); + } + D_ASSERT(expression.type != ExpressionType::BOUND_REF); + if (expression.type == ExpressionType::SUBQUERY) { + D_ASSERT(expression.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); + auto &subquery = expression.Cast(); + JoinSide side = JoinSide::NONE; + if (subquery.child) { + side = GetJoinSide(*subquery.child, left_bindings, right_bindings); + } + // correlated subquery, check the side of each of correlated columns in the subquery + for (auto &corr : subquery.binder->correlated_columns) { + if (corr.depth > 1) { + // correlated column has depth > 1 + // it does not refer to any table in the current set of bindings + return JoinSide::BOTH; + } + auto correlated_side = GetJoinSide(corr.binding.table_index, left_bindings, right_bindings); + side = CombineJoinSide(side, correlated_side); + } + return side; + } + JoinSide join_side = JoinSide::NONE; + ExpressionIterator::EnumerateChildren(expression, [&](Expression &child) { + auto child_side = GetJoinSide(child, left_bindings, right_bindings); + join_side = CombineJoinSide(child_side, join_side); + }); + return join_side; +} + +JoinSide JoinSide::GetJoinSide(const unordered_set &bindings, const unordered_set &left_bindings, + const unordered_set &right_bindings) { + JoinSide side = JoinSide::NONE; + for (auto binding : bindings) { + side = CombineJoinSide(side, GetJoinSide(binding, left_bindings, right_bindings)); + } + return side; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/logical_operator.cpp b/src/duckdb/src/planner/logical_operator.cpp new file mode 100644 index 00000000..8c3c4c93 --- /dev/null +++ b/src/duckdb/src/planner/logical_operator.cpp @@ -0,0 +1,198 @@ +#include "duckdb/planner/logical_operator.hpp" + +#include "duckdb/common/printer.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" + +namespace duckdb { + +LogicalOperator::LogicalOperator(LogicalOperatorType type) + : type(type), estimated_cardinality(0), has_estimated_cardinality(false) { +} + +LogicalOperator::LogicalOperator(LogicalOperatorType type, vector> expressions) + : type(type), expressions(std::move(expressions)), estimated_cardinality(0), has_estimated_cardinality(false) { +} + +LogicalOperator::~LogicalOperator() { +} + +vector LogicalOperator::GetColumnBindings() { + return {ColumnBinding(0, 0)}; +} + +string LogicalOperator::GetName() const { + return LogicalOperatorToString(type); +} + +string LogicalOperator::ParamsToString() const { + string result; + for (idx_t i = 0; i < expressions.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += expressions[i]->GetName(); + } + return result; +} + +void LogicalOperator::ResolveOperatorTypes() { + + types.clear(); + // first resolve child types + for (auto &child : children) { + child->ResolveOperatorTypes(); + } + // now resolve the types for this operator + ResolveTypes(); + D_ASSERT(types.size() == GetColumnBindings().size()); +} + +vector LogicalOperator::GenerateColumnBindings(idx_t table_idx, idx_t column_count) { + vector result; + result.reserve(column_count); + for (idx_t i = 0; i < column_count; i++) { + result.emplace_back(table_idx, i); + } + return result; +} + +vector LogicalOperator::MapTypes(const vector &types, const vector &projection_map) { + if (projection_map.empty()) { + return types; + } else { + vector result_types; + result_types.reserve(projection_map.size()); + for (auto index : projection_map) { + result_types.push_back(types[index]); + } + return result_types; + } +} + +vector LogicalOperator::MapBindings(const vector &bindings, + const vector &projection_map) { + if (projection_map.empty()) { + return bindings; + } else { + vector result_bindings; + result_bindings.reserve(projection_map.size()); + for (auto index : projection_map) { + D_ASSERT(index < bindings.size()); + result_bindings.push_back(bindings[index]); + } + return result_bindings; + } +} + +string LogicalOperator::ToString() const { + TreeRenderer renderer; + return renderer.ToString(*this); +} + +void LogicalOperator::Verify(ClientContext &context) { +#ifdef DEBUG + // verify expressions + for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { + auto str = expressions[expr_idx]->ToString(); + // verify that we can (correctly) copy this expression + auto copy = expressions[expr_idx]->Copy(); + auto original_hash = expressions[expr_idx]->Hash(); + auto copy_hash = copy->Hash(); + // copy should be identical to original + D_ASSERT(expressions[expr_idx]->ToString() == copy->ToString()); + D_ASSERT(original_hash == copy_hash); + D_ASSERT(Expression::Equals(expressions[expr_idx], copy)); + + for (idx_t other_idx = 0; other_idx < expr_idx; other_idx++) { + // comparison with other expressions + auto other_hash = expressions[other_idx]->Hash(); + bool expr_equal = Expression::Equals(expressions[expr_idx], expressions[other_idx]); + if (original_hash != other_hash) { + // if the hashes are not equal the expressions should not be equal either + D_ASSERT(!expr_equal); + } + } + D_ASSERT(!str.empty()); + + // verify that serialization + deserialization round-trips correctly + if (expressions[expr_idx]->HasParameter()) { + continue; + } + MemoryStream stream; + // We are serializing a query plan + try { + BinarySerializer::Serialize(*expressions[expr_idx], stream); + } catch (NotImplementedException &ex) { + // ignore for now (FIXME) + continue; + } + // Rewind the stream + stream.Rewind(); + + bound_parameter_map_t parameters; + auto deserialized_expression = BinaryDeserializer::Deserialize(stream, context, parameters); + + // FIXME: expressions might not be equal yet because of statistics propagation + continue; + D_ASSERT(Expression::Equals(expressions[expr_idx], deserialized_expression)); + D_ASSERT(expressions[expr_idx]->Hash() == deserialized_expression->Hash()); + } + D_ASSERT(!ToString().empty()); + for (auto &child : children) { + child->Verify(context); + } +#endif +} + +void LogicalOperator::AddChild(unique_ptr child) { + D_ASSERT(child); + children.push_back(std::move(child)); +} + +idx_t LogicalOperator::EstimateCardinality(ClientContext &context) { + // simple estimator, just take the max of the children + if (has_estimated_cardinality) { + return estimated_cardinality; + } + idx_t max_cardinality = 0; + for (auto &child : children) { + max_cardinality = MaxValue(child->EstimateCardinality(context), max_cardinality); + } + has_estimated_cardinality = true; + estimated_cardinality = max_cardinality; + return estimated_cardinality; +} + +void LogicalOperator::Print() { + Printer::Print(ToString()); +} + +vector LogicalOperator::GetTableIndex() const { + return vector {}; +} + +unique_ptr LogicalOperator::Copy(ClientContext &context) const { + MemoryStream stream; + BinarySerializer serializer(stream); + try { + serializer.Begin(); + this->Serialize(serializer); + serializer.End(); + } catch (NotImplementedException &ex) { + throw NotImplementedException("Logical Operator Copy requires the logical operator and all of its children to " + "be serializable: " + + std::string(ex.what())); + } + stream.Rewind(); + bound_parameter_map_t parameters; + auto op_copy = BinaryDeserializer::Deserialize(stream, context, parameters); + return op_copy; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/logical_operator_visitor.cpp b/src/duckdb/src/planner/logical_operator_visitor.cpp new file mode 100644 index 00000000..e692473f --- /dev/null +++ b/src/duckdb/src/planner/logical_operator_visitor.cpp @@ -0,0 +1,279 @@ +#include "duckdb/planner/logical_operator_visitor.hpp" + +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/list.hpp" + +namespace duckdb { + +void LogicalOperatorVisitor::VisitOperator(LogicalOperator &op) { + VisitOperatorChildren(op); + VisitOperatorExpressions(op); +} + +void LogicalOperatorVisitor::VisitOperatorChildren(LogicalOperator &op) { + for (auto &child : op.children) { + VisitOperator(*child); + } +} + +void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator &op, + const std::function *child)> &callback) { + + switch (op.type) { + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { + auto &get = op.Cast(); + for (auto &expr_list : get.expressions) { + for (auto &expr : expr_list) { + callback(&expr); + } + } + break; + } + case LogicalOperatorType::LOGICAL_ORDER_BY: { + auto &order = op.Cast(); + for (auto &node : order.orders) { + callback(&node.expression); + } + break; + } + case LogicalOperatorType::LOGICAL_TOP_N: { + auto &order = op.Cast(); + for (auto &node : order.orders) { + callback(&node.expression); + } + break; + } + case LogicalOperatorType::LOGICAL_DISTINCT: { + auto &distinct = op.Cast(); + for (auto &target : distinct.distinct_targets) { + callback(&target); + } + if (distinct.order_by) { + for (auto &order : distinct.order_by->orders) { + callback(&order.expression); + } + } + break; + } + case LogicalOperatorType::LOGICAL_INSERT: { + auto &insert = op.Cast(); + if (insert.on_conflict_condition) { + callback(&insert.on_conflict_condition); + } + if (insert.do_update_condition) { + callback(&insert.do_update_condition); + } + break; + } + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + auto &join = op.Cast(); + for (auto &expr : join.duplicate_eliminated_columns) { + callback(&expr); + } + for (auto &cond : join.conditions) { + callback(&cond.left); + callback(&cond.right); + } + break; + } + case LogicalOperatorType::LOGICAL_ANY_JOIN: { + auto &join = op.Cast(); + callback(&join.condition); + break; + } + case LogicalOperatorType::LOGICAL_LIMIT: { + auto &limit = op.Cast(); + if (limit.limit) { + callback(&limit.limit); + } + if (limit.offset) { + callback(&limit.offset); + } + break; + } + case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: { + auto &limit = op.Cast(); + if (limit.limit) { + callback(&limit.limit); + } + if (limit.offset) { + callback(&limit.offset); + } + break; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + auto &aggr = op.Cast(); + for (auto &group : aggr.groups) { + callback(&group); + } + break; + } + default: + break; + } + for (auto &expression : op.expressions) { + callback(&expression); + } +} + +void LogicalOperatorVisitor::VisitOperatorExpressions(LogicalOperator &op) { + LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr *child) { VisitExpression(child); }); +} + +void LogicalOperatorVisitor::VisitExpression(unique_ptr *expression) { + auto &expr = **expression; + unique_ptr result; + switch (expr.GetExpressionClass()) { + case ExpressionClass::BOUND_AGGREGATE: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_BETWEEN: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_CASE: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_CAST: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_COLUMN_REF: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_COMPARISON: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_CONJUNCTION: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_CONSTANT: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_FUNCTION: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_SUBQUERY: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_OPERATOR: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_PARAMETER: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_REF: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_DEFAULT: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_WINDOW: + result = VisitReplace(expr.Cast(), expression); + break; + case ExpressionClass::BOUND_UNNEST: + result = VisitReplace(expr.Cast(), expression); + break; + default: + throw InternalException("Unrecognized expression type in logical operator visitor"); + } + if (result) { + *expression = std::move(result); + } else { + // visit the children of this node + VisitExpressionChildren(expr); + } +} + +void LogicalOperatorVisitor::VisitExpressionChildren(Expression &expr) { + ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &expr) { VisitExpression(&expr); }); +} + +// these are all default methods that can be overriden +// we don't care about coverage here +// LCOV_EXCL_START +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundAggregateExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundBetweenExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundCaseExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundCastExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundComparisonExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundConjunctionExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundConstantExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundDefaultExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundFunctionExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundOperatorExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundParameterExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundReferenceExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundSubqueryExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundWindowExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +unique_ptr LogicalOperatorVisitor::VisitReplace(BoundUnnestExpression &expr, + unique_ptr *expr_ptr) { + return nullptr; +} + +// LCOV_EXCL_STOP + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_aggregate.cpp b/src/duckdb/src/planner/operator/logical_aggregate.cpp new file mode 100644 index 00000000..c0ae7d5b --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_aggregate.cpp @@ -0,0 +1,86 @@ +#include "duckdb/planner/operator/logical_aggregate.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +LogicalAggregate::LogicalAggregate(idx_t group_index, idx_t aggregate_index, vector> select_list) + : LogicalOperator(LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY, std::move(select_list)), + group_index(group_index), aggregate_index(aggregate_index), groupings_index(DConstants::INVALID_INDEX) { +} + +void LogicalAggregate::ResolveTypes() { + D_ASSERT(groupings_index != DConstants::INVALID_INDEX || grouping_functions.empty()); + for (auto &expr : groups) { + types.push_back(expr->return_type); + } + // get the chunk types from the projection list + for (auto &expr : expressions) { + types.push_back(expr->return_type); + } + for (idx_t i = 0; i < grouping_functions.size(); i++) { + types.emplace_back(LogicalType::BIGINT); + } +} + +vector LogicalAggregate::GetColumnBindings() { + D_ASSERT(groupings_index != DConstants::INVALID_INDEX || grouping_functions.empty()); + vector result; + result.reserve(groups.size() + expressions.size() + grouping_functions.size()); + for (idx_t i = 0; i < groups.size(); i++) { + result.emplace_back(group_index, i); + } + for (idx_t i = 0; i < expressions.size(); i++) { + result.emplace_back(aggregate_index, i); + } + for (idx_t i = 0; i < grouping_functions.size(); i++) { + result.emplace_back(groupings_index, i); + } + return result; +} + +string LogicalAggregate::ParamsToString() const { + string result; + for (idx_t i = 0; i < groups.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += groups[i]->GetName(); + } + for (idx_t i = 0; i < expressions.size(); i++) { + if (i > 0 || !groups.empty()) { + result += "\n"; + } + result += expressions[i]->GetName(); + } + return result; +} + +idx_t LogicalAggregate::EstimateCardinality(ClientContext &context) { + if (groups.empty()) { + // ungrouped aggregate + return 1; + } + return LogicalOperator::EstimateCardinality(context); +} + +vector LogicalAggregate::GetTableIndex() const { + vector result {group_index, aggregate_index}; + if (groupings_index != DConstants::INVALID_INDEX) { + result.push_back(groupings_index); + } + return result; +} + +string LogicalAggregate::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + + StringUtil::Format(" #%llu, #%llu, #%llu", group_index, aggregate_index, groupings_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_any_join.cpp b/src/duckdb/src/planner/operator/logical_any_join.cpp new file mode 100644 index 00000000..13681f69 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_any_join.cpp @@ -0,0 +1,12 @@ +#include "duckdb/planner/operator/logical_any_join.hpp" + +namespace duckdb { + +LogicalAnyJoin::LogicalAnyJoin(JoinType type) : LogicalJoin(type, LogicalOperatorType::LOGICAL_ANY_JOIN) { +} + +string LogicalAnyJoin::ParamsToString() const { + return condition->ToString(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_column_data_get.cpp b/src/duckdb/src/planner/operator/logical_column_data_get.cpp new file mode 100644 index 00000000..b99c9121 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_column_data_get.cpp @@ -0,0 +1,33 @@ +#include "duckdb/planner/operator/logical_column_data_get.hpp" + +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +LogicalColumnDataGet::LogicalColumnDataGet(idx_t table_index, vector types, + unique_ptr collection) + : LogicalOperator(LogicalOperatorType::LOGICAL_CHUNK_GET), table_index(table_index), + collection(std::move(collection)) { + D_ASSERT(types.size() > 0); + chunk_types = std::move(types); +} + +vector LogicalColumnDataGet::GetColumnBindings() { + return GenerateColumnBindings(table_index, chunk_types.size()); +} + +vector LogicalColumnDataGet::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalColumnDataGet::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_comparison_join.cpp b/src/duckdb/src/planner/operator/logical_comparison_join.cpp new file mode 100644 index 00000000..8fd3e38c --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_comparison_join.cpp @@ -0,0 +1,23 @@ +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/common/enum_util.hpp" +namespace duckdb { + +LogicalComparisonJoin::LogicalComparisonJoin(JoinType join_type, LogicalOperatorType logical_type) + : LogicalJoin(join_type, logical_type) { +} + +string LogicalComparisonJoin::ParamsToString() const { + string result = EnumUtil::ToChars(join_type); + for (auto &condition : conditions) { + result += "\n"; + auto expr = + make_uniq(condition.comparison, condition.left->Copy(), condition.right->Copy()); + result += expr->ToString(); + } + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_copy_to_file.cpp b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp new file mode 100644 index 00000000..c3654b86 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_copy_to_file.hpp" + +#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" +#include "duckdb/function/copy_function.hpp" + +namespace duckdb { + +void LogicalCopyToFile::Serialize(Serializer &serializer) const { + throw SerializationException("LogicalCopyToFile not implemented yet"); +} + +unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deserializer) { + throw SerializationException("LogicalCopyToFile not implemented yet"); +} + +idx_t LogicalCopyToFile::EstimateCardinality(ClientContext &context) { + return 1; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_create.cpp b/src/duckdb/src/planner/operator/logical_create.cpp new file mode 100644 index 00000000..ff4aeb98 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_create.cpp @@ -0,0 +1,23 @@ +#include "duckdb/planner/operator/logical_create.hpp" + +namespace duckdb { + +LogicalCreate::LogicalCreate(LogicalOperatorType type, unique_ptr info, + optional_ptr schema) + : LogicalOperator(type), schema(schema), info(std::move(info)) { +} + +LogicalCreate::LogicalCreate(LogicalOperatorType type, ClientContext &context, unique_ptr info_p) + : LogicalOperator(type), info(std::move(info_p)) { + this->schema = Catalog::GetSchema(context, info->catalog, info->schema, OnEntryNotFound::RETURN_NULL); +} + +idx_t LogicalCreate::EstimateCardinality(ClientContext &context) { + return 1; +} + +void LogicalCreate::ResolveTypes() { + types.emplace_back(LogicalType::BIGINT); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_create_index.cpp b/src/duckdb/src/planner/operator/logical_create_index.cpp new file mode 100644 index 00000000..65e36069 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_create_index.cpp @@ -0,0 +1,44 @@ +#include "duckdb/planner/operator/logical_create_index.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/function/function_serialization.hpp" + +namespace duckdb { + +LogicalCreateIndex::LogicalCreateIndex(unique_ptr info_p, vector> expressions_p, + TableCatalogEntry &table_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), info(std::move(info_p)), table(table_p) { + + for (auto &expr : expressions_p) { + this->unbound_expressions.push_back(expr->Copy()); + } + this->expressions = std::move(expressions_p); + + if (info->column_ids.empty()) { + throw BinderException("CREATE INDEX does not refer to any columns in the base table!"); + } +} + +LogicalCreateIndex::LogicalCreateIndex(ClientContext &context, unique_ptr info_p, + vector> expressions_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), + info(unique_ptr_cast(std::move(info_p))), table(BindTable(context, *info)) { + for (auto &expr : expressions_p) { + this->unbound_expressions.push_back(expr->Copy()); + } + this->expressions = std::move(expressions_p); +} + +void LogicalCreateIndex::ResolveTypes() { + types.emplace_back(LogicalType::BIGINT); +} + +TableCatalogEntry &LogicalCreateIndex::BindTable(ClientContext &context, CreateIndexInfo &info) { + auto &catalog = info.catalog; + auto &schema = info.schema; + auto &table_name = info.table; + return Catalog::GetEntry(context, catalog, schema, table_name); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_create_table.cpp b/src/duckdb/src/planner/operator/logical_create_table.cpp new file mode 100644 index 00000000..906fe683 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_create_table.cpp @@ -0,0 +1,25 @@ +#include "duckdb/planner/operator/logical_create_table.hpp" + +namespace duckdb { + +LogicalCreateTable::LogicalCreateTable(SchemaCatalogEntry &schema, unique_ptr info) + : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_TABLE), schema(schema), info(std::move(info)) { +} + +LogicalCreateTable::LogicalCreateTable(ClientContext &context, unique_ptr unbound_info) + : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_TABLE), + schema(Catalog::GetSchema(context, unbound_info->catalog, unbound_info->schema)) { + D_ASSERT(unbound_info->type == CatalogType::TABLE_ENTRY); + auto binder = Binder::CreateBinder(context); + info = binder->BindCreateTableInfo(unique_ptr_cast(std::move(unbound_info))); +} + +idx_t LogicalCreateTable::EstimateCardinality(ClientContext &context) { + return 1; +} + +void LogicalCreateTable::ResolveTypes() { + types.emplace_back(LogicalType::BIGINT); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_cross_product.cpp b/src/duckdb/src/planner/operator/logical_cross_product.cpp new file mode 100644 index 00000000..263dcf49 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_cross_product.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_cross_product.hpp" + +namespace duckdb { + +LogicalCrossProduct::LogicalCrossProduct(unique_ptr left, unique_ptr right) + : LogicalUnconditionalJoin(LogicalOperatorType::LOGICAL_CROSS_PRODUCT, std::move(left), std::move(right)) { +} + +unique_ptr LogicalCrossProduct::Create(unique_ptr left, + unique_ptr right) { + if (left->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { + return right; + } + if (right->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { + return left; + } + return make_uniq(std::move(left), std::move(right)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_cteref.cpp b/src/duckdb/src/planner/operator/logical_cteref.cpp new file mode 100644 index 00000000..08b6aea4 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_cteref.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_cteref.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalCTERef::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalCTERef::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_delete.cpp b/src/duckdb/src/planner/operator/logical_delete.cpp new file mode 100644 index 00000000..a028a1ea --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_delete.cpp @@ -0,0 +1,52 @@ +#include "duckdb/planner/operator/logical_delete.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" + +namespace duckdb { + +LogicalDelete::LogicalDelete(TableCatalogEntry &table, idx_t table_index) + : LogicalOperator(LogicalOperatorType::LOGICAL_DELETE), table(table), table_index(table_index), + return_chunk(false) { +} + +LogicalDelete::LogicalDelete(ClientContext &context, const unique_ptr &table_info) + : LogicalOperator(LogicalOperatorType::LOGICAL_DELETE), + table(Catalog::GetEntry(context, table_info->catalog, table_info->schema, + dynamic_cast(*table_info).table)) { +} + +idx_t LogicalDelete::EstimateCardinality(ClientContext &context) { + return return_chunk ? LogicalOperator::EstimateCardinality(context) : 1; +} + +vector LogicalDelete::GetTableIndex() const { + return vector {table_index}; +} + +vector LogicalDelete::GetColumnBindings() { + if (return_chunk) { + return GenerateColumnBindings(table_index, table.GetTypes().size()); + } + return {ColumnBinding(0, 0)}; +} + +void LogicalDelete::ResolveTypes() { + if (return_chunk) { + types = table.GetTypes(); + } else { + types.emplace_back(LogicalType::BIGINT); + } +} + +string LogicalDelete::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_delim_get.cpp b/src/duckdb/src/planner/operator/logical_delim_get.cpp new file mode 100644 index 00000000..3cd63f3a --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_delim_get.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_delim_get.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalDelimGet::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalDelimGet::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_dependent_join.cpp b/src/duckdb/src/planner/operator/logical_dependent_join.cpp new file mode 100644 index 00000000..308aa15d --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_dependent_join.cpp @@ -0,0 +1,22 @@ +#include "duckdb/planner/operator/logical_dependent_join.hpp" + +namespace duckdb { + +LogicalDependentJoin::LogicalDependentJoin(unique_ptr left, unique_ptr right, + vector correlated_columns, JoinType type, + unique_ptr condition) + : LogicalComparisonJoin(type, LogicalOperatorType::LOGICAL_DEPENDENT_JOIN), join_condition(std::move(condition)), + correlated_columns(std::move(correlated_columns)) { + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +unique_ptr LogicalDependentJoin::Create(unique_ptr left, + unique_ptr right, + vector correlated_columns, JoinType type, + unique_ptr condition) { + return make_uniq(std::move(left), std::move(right), std::move(correlated_columns), type, + std::move(condition)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_distinct.cpp b/src/duckdb/src/planner/operator/logical_distinct.cpp new file mode 100644 index 00000000..3a35ba87 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_distinct.cpp @@ -0,0 +1,28 @@ +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" + +namespace duckdb { + +LogicalDistinct::LogicalDistinct(DistinctType distinct_type) + : LogicalOperator(LogicalOperatorType::LOGICAL_DISTINCT), distinct_type(distinct_type) { +} +LogicalDistinct::LogicalDistinct(vector> targets, DistinctType distinct_type) + : LogicalOperator(LogicalOperatorType::LOGICAL_DISTINCT), distinct_type(distinct_type), + distinct_targets(std::move(targets)) { +} + +string LogicalDistinct::ParamsToString() const { + string result = LogicalOperator::ParamsToString(); + if (!distinct_targets.empty()) { + result += StringUtil::Join(distinct_targets, distinct_targets.size(), "\n", + [](const unique_ptr &child) { return child->GetName(); }); + } + + return result; +} + +void LogicalDistinct::ResolveTypes() { + types = children[0]->types; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_dummy_scan.cpp b/src/duckdb/src/planner/operator/logical_dummy_scan.cpp new file mode 100644 index 00000000..13219fba --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_dummy_scan.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_dummy_scan.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalDummyScan::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalDummyScan::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_empty_result.cpp b/src/duckdb/src/planner/operator/logical_empty_result.cpp new file mode 100644 index 00000000..28b014ff --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_empty_result.cpp @@ -0,0 +1,17 @@ +#include "duckdb/planner/operator/logical_empty_result.hpp" + +namespace duckdb { + +LogicalEmptyResult::LogicalEmptyResult(unique_ptr op) + : LogicalOperator(LogicalOperatorType::LOGICAL_EMPTY_RESULT) { + + this->bindings = op->GetColumnBindings(); + + op->ResolveOperatorTypes(); + this->return_types = op->types; +} + +LogicalEmptyResult::LogicalEmptyResult() : LogicalOperator(LogicalOperatorType::LOGICAL_EMPTY_RESULT) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_expression_get.cpp b/src/duckdb/src/planner/operator/logical_expression_get.cpp new file mode 100644 index 00000000..6f9fe045 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_expression_get.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_expression_get.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalExpressionGet::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalExpressionGet::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_extension_operator.cpp b/src/duckdb/src/planner/operator/logical_extension_operator.cpp new file mode 100644 index 00000000..2e0cb6d1 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_extension_operator.cpp @@ -0,0 +1,44 @@ +#include "duckdb/planner/operator/logical_extension_operator.hpp" +#include "duckdb/execution/column_binding_resolver.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +void LogicalExtensionOperator::ResolveColumnBindings(ColumnBindingResolver &res, vector &bindings) { + // general case + // first visit the children of this operator + for (auto &child : children) { + res.VisitOperator(*child); + } + // now visit the expressions of this operator to resolve any bound column references + for (auto &expression : expressions) { + res.VisitExpression(&expression); + } + // finally update the current set of bindings to the current set of column bindings + bindings = GetColumnBindings(); +} + +void LogicalExtensionOperator::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WriteProperty(200, "extension_name", GetExtensionName()); +} + +unique_ptr LogicalExtensionOperator::Deserialize(Deserializer &deserializer) { + auto &config = DBConfig::GetConfig(deserializer.Get()); + auto extension_name = deserializer.ReadProperty(200, "extension_name"); + for (auto &extension : config.operator_extensions) { + if (extension->GetName() == extension_name) { + return extension->Deserialize(deserializer); + } + } + throw SerializationException("No deserialization method exists for extension: " + extension_name); +} + +string LogicalExtensionOperator::GetExtensionName() const { + throw SerializationException("LogicalExtensionOperator::GetExtensionName not implemented which is required for " + "serializing extension operators"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_filter.cpp b/src/duckdb/src/planner/operator/logical_filter.cpp new file mode 100644 index 00000000..4fc18c08 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_filter.cpp @@ -0,0 +1,45 @@ +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" + +namespace duckdb { + +LogicalFilter::LogicalFilter(unique_ptr expression) : LogicalOperator(LogicalOperatorType::LOGICAL_FILTER) { + expressions.push_back(std::move(expression)); + SplitPredicates(expressions); +} + +LogicalFilter::LogicalFilter() : LogicalOperator(LogicalOperatorType::LOGICAL_FILTER) { +} + +void LogicalFilter::ResolveTypes() { + types = MapTypes(children[0]->types, projection_map); +} + +vector LogicalFilter::GetColumnBindings() { + return MapBindings(children[0]->GetColumnBindings(), projection_map); +} + +// Split the predicates separated by AND statements +// These are the predicates that are safe to push down because all of them MUST +// be true +bool LogicalFilter::SplitPredicates(vector> &expressions) { + bool found_conjunction = false; + for (idx_t i = 0; i < expressions.size(); i++) { + if (expressions[i]->type == ExpressionType::CONJUNCTION_AND) { + auto &conjunction = expressions[i]->Cast(); + found_conjunction = true; + // AND expression, append the other children + for (idx_t k = 1; k < conjunction.children.size(); k++) { + expressions.push_back(std::move(conjunction.children[k])); + } + // replace this expression with the first child of the conjunction + expressions[i] = std::move(conjunction.children[0]); + // we move back by one so the right child is checked again + // in case it is an AND expression as well + i--; + } + } + return found_conjunction; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_get.cpp b/src/duckdb/src/planner/operator/logical_get.cpp new file mode 100644 index 00000000..25b4a6bc --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_get.cpp @@ -0,0 +1,203 @@ +#include "duckdb/planner/operator/logical_get.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/function_serialization.hpp" +#include "duckdb/function/table/table_scan.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +LogicalGet::LogicalGet() : LogicalOperator(LogicalOperatorType::LOGICAL_GET) { +} + +LogicalGet::LogicalGet(idx_t table_index, TableFunction function, unique_ptr bind_data, + vector returned_types, vector returned_names) + : LogicalOperator(LogicalOperatorType::LOGICAL_GET), table_index(table_index), function(std::move(function)), + bind_data(std::move(bind_data)), returned_types(std::move(returned_types)), names(std::move(returned_names)), + extra_info() { +} + +optional_ptr LogicalGet::GetTable() const { + return TableScanFunction::GetTableEntry(function, bind_data.get()); +} + +string LogicalGet::ParamsToString() const { + string result = ""; + for (auto &kv : table_filters.filters) { + auto &column_index = kv.first; + auto &filter = kv.second; + if (column_index < names.size()) { + result += filter->ToString(names[column_index]); + } + result += "\n"; + } + if (!extra_info.file_filters.empty()) { + result += "\n[INFOSEPARATOR]\n"; + result += "File Filters: " + extra_info.file_filters; + } + if (!function.to_string) { + return result; + } + return result + "\n" + function.to_string(bind_data.get()); +} + +vector LogicalGet::GetColumnBindings() { + if (column_ids.empty()) { + return {ColumnBinding(table_index, 0)}; + } + vector result; + if (projection_ids.empty()) { + for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { + result.emplace_back(table_index, col_idx); + } + } else { + for (auto proj_id : projection_ids) { + result.emplace_back(table_index, proj_id); + } + } + if (!projected_input.empty()) { + if (children.size() != 1) { + throw InternalException("LogicalGet::project_input can only be set for table-in-out functions"); + } + auto child_bindings = children[0]->GetColumnBindings(); + for (auto entry : projected_input) { + D_ASSERT(entry < child_bindings.size()); + result.emplace_back(child_bindings[entry]); + } + } + return result; +} + +void LogicalGet::ResolveTypes() { + if (column_ids.empty()) { + column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); + } + + if (projection_ids.empty()) { + for (auto &index : column_ids) { + if (index == COLUMN_IDENTIFIER_ROW_ID) { + types.emplace_back(LogicalType::ROW_TYPE); + } else { + types.push_back(returned_types[index]); + } + } + } else { + for (auto &proj_index : projection_ids) { + auto &index = column_ids[proj_index]; + if (index == COLUMN_IDENTIFIER_ROW_ID) { + types.emplace_back(LogicalType::ROW_TYPE); + } else { + types.push_back(returned_types[index]); + } + } + } + if (!projected_input.empty()) { + if (children.size() != 1) { + throw InternalException("LogicalGet::project_input can only be set for table-in-out functions"); + } + for (auto entry : projected_input) { + D_ASSERT(entry < children[0]->types.size()); + types.push_back(children[0]->types[entry]); + } + } +} + +idx_t LogicalGet::EstimateCardinality(ClientContext &context) { + // join order optimizer does better cardinality estimation. + if (has_estimated_cardinality) { + return estimated_cardinality; + } + if (function.cardinality) { + auto node_stats = function.cardinality(context, bind_data.get()); + if (node_stats && node_stats->has_estimated_cardinality) { + return node_stats->estimated_cardinality; + } + } + return 1; +} + +void LogicalGet::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WriteProperty(200, "table_index", table_index); + serializer.WriteProperty(201, "returned_types", returned_types); + serializer.WriteProperty(202, "names", names); + serializer.WriteProperty(203, "column_ids", column_ids); + serializer.WriteProperty(204, "projection_ids", projection_ids); + serializer.WriteProperty(205, "table_filters", table_filters); + FunctionSerializer::Serialize(serializer, function, bind_data.get()); + if (!function.serialize) { + D_ASSERT(!function.serialize); + // no serialize method: serialize input values and named_parameters for rebinding purposes + serializer.WriteProperty(206, "parameters", parameters); + serializer.WriteProperty(207, "named_parameters", named_parameters); + serializer.WriteProperty(208, "input_table_types", input_table_types); + serializer.WriteProperty(209, "input_table_names", input_table_names); + } + serializer.WriteProperty(210, "projected_input", projected_input); +} + +unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) { + auto result = unique_ptr(new LogicalGet()); + deserializer.ReadProperty(200, "table_index", result->table_index); + deserializer.ReadProperty(201, "returned_types", result->returned_types); + deserializer.ReadProperty(202, "names", result->names); + deserializer.ReadProperty(203, "column_ids", result->column_ids); + deserializer.ReadProperty(204, "projection_ids", result->projection_ids); + deserializer.ReadProperty(205, "table_filters", result->table_filters); + auto entry = FunctionSerializer::DeserializeBase( + deserializer, CatalogType::TABLE_FUNCTION_ENTRY); + result->function = entry.first; + auto &function = result->function; + auto has_serialize = entry.second; + + unique_ptr bind_data; + if (!has_serialize) { + deserializer.ReadProperty(206, "parameters", result->parameters); + deserializer.ReadProperty(207, "named_parameters", result->named_parameters); + deserializer.ReadProperty(208, "input_table_types", result->input_table_types); + deserializer.ReadProperty(209, "input_table_names", result->input_table_names); + TableFunctionBindInput input(result->parameters, result->named_parameters, result->input_table_types, + result->input_table_names, function.function_info.get()); + + vector bind_return_types; + vector bind_names; + if (!function.bind) { + throw InternalException("Table function \"%s\" has neither bind nor (de)serialize", function.name); + } + bind_data = function.bind(deserializer.Get(), input, bind_return_types, bind_names); + if (result->returned_types != bind_return_types) { + throw SerializationException( + "Table function deserialization failure - bind returned different return types than were serialized"); + } + // names can actually be different because of aliases - only the sizes cannot be different + if (result->names.size() != bind_names.size()) { + throw SerializationException( + "Table function deserialization failure - bind returned different returned names than were serialized"); + } + } else { + bind_data = FunctionSerializer::FunctionDeserialize(deserializer, function); + } + result->bind_data = std::move(bind_data); + deserializer.ReadProperty(210, "projected_input", result->projected_input); + return std::move(result); +} + +vector LogicalGet::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalGet::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return StringUtil::Upper(function.name) + StringUtil::Format(" #%llu", table_index); + } +#endif + return StringUtil::Upper(function.name); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_insert.cpp b/src/duckdb/src/planner/operator/logical_insert.cpp new file mode 100644 index 00000000..3846ed00 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_insert.cpp @@ -0,0 +1,52 @@ +#include "duckdb/planner/operator/logical_insert.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" + +namespace duckdb { + +LogicalInsert::LogicalInsert(TableCatalogEntry &table, idx_t table_index) + : LogicalOperator(LogicalOperatorType::LOGICAL_INSERT), table(table), table_index(table_index), return_chunk(false), + action_type(OnConflictAction::THROW) { +} + +LogicalInsert::LogicalInsert(ClientContext &context, const unique_ptr table_info) + : LogicalOperator(LogicalOperatorType::LOGICAL_INSERT), + table(Catalog::GetEntry(context, table_info->catalog, table_info->schema, + dynamic_cast(*table_info).table)) { +} + +idx_t LogicalInsert::EstimateCardinality(ClientContext &context) { + return return_chunk ? LogicalOperator::EstimateCardinality(context) : 1; +} + +vector LogicalInsert::GetTableIndex() const { + return vector {table_index}; +} + +vector LogicalInsert::GetColumnBindings() { + if (return_chunk) { + return GenerateColumnBindings(table_index, table.GetTypes().size()); + } + return {ColumnBinding(0, 0)}; +} + +void LogicalInsert::ResolveTypes() { + if (return_chunk) { + types = table.GetTypes(); + } else { + types.emplace_back(LogicalType::BIGINT); + } +} + +string LogicalInsert::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_join.cpp b/src/duckdb/src/planner/operator/logical_join.cpp new file mode 100644 index 00000000..3c78d81b --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_join.cpp @@ -0,0 +1,61 @@ +#include "duckdb/planner/operator/logical_join.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" + +namespace duckdb { + +LogicalJoin::LogicalJoin(JoinType join_type, LogicalOperatorType logical_type) + : LogicalOperator(logical_type), join_type(join_type) { +} + +vector LogicalJoin::GetColumnBindings() { + auto left_bindings = MapBindings(children[0]->GetColumnBindings(), left_projection_map); + if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { + // for SEMI and ANTI join we only project the left hand side + return left_bindings; + } + if (join_type == JoinType::MARK) { + // for MARK join we project the left hand side plus the MARK column + left_bindings.emplace_back(mark_index, 0); + return left_bindings; + } + // for other join types we project both the LHS and the RHS + auto right_bindings = MapBindings(children[1]->GetColumnBindings(), right_projection_map); + left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); + return left_bindings; +} + +void LogicalJoin::ResolveTypes() { + types = MapTypes(children[0]->types, left_projection_map); + if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { + // for SEMI and ANTI join we only project the left hand side + return; + } + if (join_type == JoinType::MARK) { + // for MARK join we project the left hand side, plus a BOOLEAN column indicating the MARK + types.emplace_back(LogicalType::BOOLEAN); + return; + } + // for any other join we project both sides + auto right_types = MapTypes(children[1]->types, right_projection_map); + types.insert(types.end(), right_types.begin(), right_types.end()); +} + +void LogicalJoin::GetTableReferences(LogicalOperator &op, unordered_set &bindings) { + auto column_bindings = op.GetColumnBindings(); + for (auto binding : column_bindings) { + bindings.insert(binding.table_index); + } +} + +void LogicalJoin::GetExpressionBindings(Expression &expr, unordered_set &bindings) { + if (expr.type == ExpressionType::BOUND_COLUMN_REF) { + auto &colref = expr.Cast(); + D_ASSERT(colref.depth == 0); + bindings.insert(colref.binding.table_index); + } + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { GetExpressionBindings(child, bindings); }); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_limit.cpp b/src/duckdb/src/planner/operator/logical_limit.cpp new file mode 100644 index 00000000..e413078a --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_limit.cpp @@ -0,0 +1,27 @@ +#include "duckdb/planner/operator/logical_limit.hpp" + +namespace duckdb { + +LogicalLimit::LogicalLimit(int64_t limit_val, int64_t offset_val, unique_ptr limit, + unique_ptr offset) + : LogicalOperator(LogicalOperatorType::LOGICAL_LIMIT), limit_val(limit_val), offset_val(offset_val), + limit(std::move(limit)), offset(std::move(offset)) { +} + +vector LogicalLimit::GetColumnBindings() { + return children[0]->GetColumnBindings(); +} + +idx_t LogicalLimit::EstimateCardinality(ClientContext &context) { + auto child_cardinality = children[0]->EstimateCardinality(context); + if (limit_val >= 0 && idx_t(limit_val) < child_cardinality) { + child_cardinality = limit_val; + } + return child_cardinality; +} + +void LogicalLimit::ResolveTypes() { + types = children[0]->types; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_limit_percent.cpp b/src/duckdb/src/planner/operator/logical_limit_percent.cpp new file mode 100644 index 00000000..c020ac4b --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_limit_percent.cpp @@ -0,0 +1,14 @@ +#include "duckdb/planner/operator/logical_limit_percent.hpp" +#include + +namespace duckdb { + +idx_t LogicalLimitPercent::EstimateCardinality(ClientContext &context) { + auto child_cardinality = LogicalOperator::EstimateCardinality(context); + if ((limit_percent < 0 || limit_percent > 100) || std::isnan(limit_percent)) { + return child_cardinality; + } + return idx_t(child_cardinality * (limit_percent / 100.0)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_materialized_cte.cpp b/src/duckdb/src/planner/operator/logical_materialized_cte.cpp new file mode 100644 index 00000000..5a178611 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_materialized_cte.cpp @@ -0,0 +1,9 @@ +#include "duckdb/planner/operator/logical_materialized_cte.hpp" + +namespace duckdb { + +vector LogicalMaterializedCTE::GetTableIndex() const { + return vector {table_index}; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_order.cpp b/src/duckdb/src/planner/operator/logical_order.cpp new file mode 100644 index 00000000..5def8714 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_order.cpp @@ -0,0 +1,44 @@ +#include "duckdb/planner/operator/logical_order.hpp" + +namespace duckdb { + +LogicalOrder::LogicalOrder(vector orders) + : LogicalOperator(LogicalOperatorType::LOGICAL_ORDER_BY), orders(std::move(orders)) { +} + +vector LogicalOrder::GetColumnBindings() { + auto child_bindings = children[0]->GetColumnBindings(); + if (projections.empty()) { + return child_bindings; + } + + vector result; + for (auto &col_idx : projections) { + result.push_back(child_bindings[col_idx]); + } + return result; +} + +string LogicalOrder::ParamsToString() const { + string result = "ORDERS:\n"; + for (idx_t i = 0; i < orders.size(); i++) { + if (i > 0) { + result += "\n"; + } + result += orders[i].expression->GetName(); + } + return result; +} + +void LogicalOrder::ResolveTypes() { + const auto child_types = children[0]->types; + if (projections.empty()) { + types = child_types; + } else { + for (auto &col_idx : projections) { + types.push_back(child_types[col_idx]); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_pivot.cpp b/src/duckdb/src/planner/operator/logical_pivot.cpp new file mode 100644 index 00000000..b5fbbd2e --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_pivot.cpp @@ -0,0 +1,41 @@ +#include "duckdb/planner/operator/logical_pivot.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +LogicalPivot::LogicalPivot() : LogicalOperator(LogicalOperatorType::LOGICAL_PIVOT) { +} + +LogicalPivot::LogicalPivot(idx_t pivot_idx, unique_ptr plan, BoundPivotInfo info_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_PIVOT), pivot_index(pivot_idx), bound_pivot(std::move(info_p)) { + D_ASSERT(plan); + children.push_back(std::move(plan)); +} + +vector LogicalPivot::GetColumnBindings() { + vector result; + for (idx_t i = 0; i < bound_pivot.types.size(); i++) { + result.emplace_back(pivot_index, i); + } + return result; +} + +vector LogicalPivot::GetTableIndex() const { + return vector {pivot_index}; +} + +void LogicalPivot::ResolveTypes() { + this->types = bound_pivot.types; +} + +string LogicalPivot::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", pivot_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_positional_join.cpp b/src/duckdb/src/planner/operator/logical_positional_join.cpp new file mode 100644 index 00000000..e863eaf5 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_positional_join.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_positional_join.hpp" + +namespace duckdb { + +LogicalPositionalJoin::LogicalPositionalJoin(unique_ptr left, unique_ptr right) + : LogicalUnconditionalJoin(LogicalOperatorType::LOGICAL_POSITIONAL_JOIN, std::move(left), std::move(right)) { +} + +unique_ptr LogicalPositionalJoin::Create(unique_ptr left, + unique_ptr right) { + if (left->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { + return right; + } + if (right->type == LogicalOperatorType::LOGICAL_DUMMY_SCAN) { + return left; + } + return make_uniq(std::move(left), std::move(right)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_pragma.cpp b/src/duckdb/src/planner/operator/logical_pragma.cpp new file mode 100644 index 00000000..e694c366 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_pragma.cpp @@ -0,0 +1,9 @@ +#include "duckdb/planner/operator/logical_pragma.hpp" + +namespace duckdb { + +idx_t LogicalPragma::EstimateCardinality(ClientContext &context) { + return 1; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_prepare.cpp b/src/duckdb/src/planner/operator/logical_prepare.cpp new file mode 100644 index 00000000..d1f067a5 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_prepare.cpp @@ -0,0 +1,9 @@ +#include "duckdb/planner/operator/logical_prepare.hpp" + +namespace duckdb { + +idx_t LogicalPrepare::EstimateCardinality(ClientContext &context) { + return 1; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_projection.cpp b/src/duckdb/src/planner/operator/logical_projection.cpp new file mode 100644 index 00000000..141773a6 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_projection.cpp @@ -0,0 +1,34 @@ +#include "duckdb/planner/operator/logical_projection.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +LogicalProjection::LogicalProjection(idx_t table_index, vector> select_list) + : LogicalOperator(LogicalOperatorType::LOGICAL_PROJECTION, std::move(select_list)), table_index(table_index) { +} + +vector LogicalProjection::GetColumnBindings() { + return GenerateColumnBindings(table_index, expressions.size()); +} + +void LogicalProjection::ResolveTypes() { + for (auto &expr : expressions) { + types.push_back(expr->return_type); + } +} + +vector LogicalProjection::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalProjection::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_recursive_cte.cpp b/src/duckdb/src/planner/operator/logical_recursive_cte.cpp new file mode 100644 index 00000000..b6867c0c --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_recursive_cte.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_recursive_cte.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalRecursiveCTE::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalRecursiveCTE::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_reset.cpp b/src/duckdb/src/planner/operator/logical_reset.cpp new file mode 100644 index 00000000..12d663f5 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_reset.cpp @@ -0,0 +1,9 @@ +#include "duckdb/planner/operator/logical_reset.hpp" + +namespace duckdb { + +idx_t LogicalReset::EstimateCardinality(ClientContext &context) { + return 1; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_sample.cpp b/src/duckdb/src/planner/operator/logical_sample.cpp new file mode 100644 index 00000000..59e14e6e --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_sample.cpp @@ -0,0 +1,39 @@ +#include "duckdb/planner/operator/logical_sample.hpp" + +namespace duckdb { + +LogicalSample::LogicalSample() : LogicalOperator(LogicalOperatorType::LOGICAL_SAMPLE) { +} + +LogicalSample::LogicalSample(unique_ptr sample_options_p, unique_ptr child) + : LogicalOperator(LogicalOperatorType::LOGICAL_SAMPLE), sample_options(std::move(sample_options_p)) { + children.push_back(std::move(child)); +} + +vector LogicalSample::GetColumnBindings() { + return children[0]->GetColumnBindings(); +} + +idx_t LogicalSample::EstimateCardinality(ClientContext &context) { + auto child_cardinality = children[0]->EstimateCardinality(context); + if (sample_options->is_percentage) { + double sample_cardinality = + double(child_cardinality) * (sample_options->sample_size.GetValue() / 100.0); + if (sample_cardinality > double(child_cardinality)) { + return child_cardinality; + } + return idx_t(sample_cardinality); + } else { + auto sample_size = sample_options->sample_size.GetValue(); + if (sample_size < child_cardinality) { + return sample_size; + } + } + return child_cardinality; +} + +void LogicalSample::ResolveTypes() { + types = children[0]->types; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_set.cpp b/src/duckdb/src/planner/operator/logical_set.cpp new file mode 100644 index 00000000..0f17cac5 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_set.cpp @@ -0,0 +1,9 @@ +#include "duckdb/planner/operator/logical_set.hpp" + +namespace duckdb { + +idx_t LogicalSet::EstimateCardinality(ClientContext &context) { + return 1; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_set_operation.cpp b/src/duckdb/src/planner/operator/logical_set_operation.cpp new file mode 100644 index 00000000..72da2e05 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_set_operation.cpp @@ -0,0 +1,20 @@ +#include "duckdb/planner/operator/logical_set_operation.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalSetOperation::GetTableIndex() const { + return vector {table_index}; +} + +string LogicalSetOperation::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_simple.cpp b/src/duckdb/src/planner/operator/logical_simple.cpp new file mode 100644 index 00000000..45982c44 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_simple.cpp @@ -0,0 +1,16 @@ +#include "duckdb/planner/operator/logical_simple.hpp" +#include "duckdb/parser/parsed_data/alter_info.hpp" +#include "duckdb/parser/parsed_data/attach_info.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/parser/parsed_data/load_info.hpp" +#include "duckdb/parser/parsed_data/transaction_info.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" +#include "duckdb/parser/parsed_data/detach_info.hpp" + +namespace duckdb { + +idx_t LogicalSimple::EstimateCardinality(ClientContext &context) { + return 1; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_top_n.cpp b/src/duckdb/src/planner/operator/logical_top_n.cpp new file mode 100644 index 00000000..da1fa493 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_top_n.cpp @@ -0,0 +1,13 @@ +#include "duckdb/planner/operator/logical_top_n.hpp" + +namespace duckdb { + +idx_t LogicalTopN::EstimateCardinality(ClientContext &context) { + auto child_cardinality = LogicalOperator::EstimateCardinality(context); + if (limit >= 0 && child_cardinality < idx_t(limit)) { + return limit; + } + return child_cardinality; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_unconditional_join.cpp b/src/duckdb/src/planner/operator/logical_unconditional_join.cpp new file mode 100644 index 00000000..3a14f5d8 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_unconditional_join.cpp @@ -0,0 +1,26 @@ +#include "duckdb/planner/operator/logical_unconditional_join.hpp" + +namespace duckdb { + +LogicalUnconditionalJoin::LogicalUnconditionalJoin(LogicalOperatorType logical_type, unique_ptr left, + unique_ptr right) + : LogicalOperator(logical_type) { + D_ASSERT(left); + D_ASSERT(right); + children.push_back(std::move(left)); + children.push_back(std::move(right)); +} + +vector LogicalUnconditionalJoin::GetColumnBindings() { + auto left_bindings = children[0]->GetColumnBindings(); + auto right_bindings = children[1]->GetColumnBindings(); + left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); + return left_bindings; +} + +void LogicalUnconditionalJoin::ResolveTypes() { + types.insert(types.end(), children[0]->types.begin(), children[0]->types.end()); + types.insert(types.end(), children[1]->types.begin(), children[1]->types.end()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_unnest.cpp b/src/duckdb/src/planner/operator/logical_unnest.cpp new file mode 100644 index 00000000..7d0932cf --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_unnest.cpp @@ -0,0 +1,35 @@ +#include "duckdb/planner/operator/logical_unnest.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalUnnest::GetColumnBindings() { + auto child_bindings = children[0]->GetColumnBindings(); + for (idx_t i = 0; i < expressions.size(); i++) { + child_bindings.emplace_back(unnest_index, i); + } + return child_bindings; +} + +void LogicalUnnest::ResolveTypes() { + types.insert(types.end(), children[0]->types.begin(), children[0]->types.end()); + for (auto &expr : expressions) { + types.push_back(expr->return_type); + } +} + +vector LogicalUnnest::GetTableIndex() const { + return vector {unnest_index}; +} + +string LogicalUnnest::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", unnest_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_update.cpp b/src/duckdb/src/planner/operator/logical_update.cpp new file mode 100644 index 00000000..e66dd36d --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_update.cpp @@ -0,0 +1,46 @@ +#include "duckdb/planner/operator/logical_update.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/main/config.hpp" + +namespace duckdb { + +LogicalUpdate::LogicalUpdate(TableCatalogEntry &table) + : LogicalOperator(LogicalOperatorType::LOGICAL_UPDATE), table(table), table_index(0), return_chunk(false) { +} + +LogicalUpdate::LogicalUpdate(ClientContext &context, const unique_ptr &table_info) + : LogicalOperator(LogicalOperatorType::LOGICAL_UPDATE), + table(Catalog::GetEntry(context, table_info->catalog, table_info->schema, + dynamic_cast(*table_info).table)) { +} + +idx_t LogicalUpdate::EstimateCardinality(ClientContext &context) { + return return_chunk ? LogicalOperator::EstimateCardinality(context) : 1; +} + +vector LogicalUpdate::GetColumnBindings() { + if (return_chunk) { + return GenerateColumnBindings(table_index, table.GetTypes().size()); + } + return {ColumnBinding(0, 0)}; +} + +void LogicalUpdate::ResolveTypes() { + if (return_chunk) { + types = table.GetTypes(); + } else { + types.emplace_back(LogicalType::BIGINT); + } +} + +string LogicalUpdate::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", table_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_window.cpp b/src/duckdb/src/planner/operator/logical_window.cpp new file mode 100644 index 00000000..4d066f41 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_window.cpp @@ -0,0 +1,35 @@ +#include "duckdb/planner/operator/logical_window.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +vector LogicalWindow::GetColumnBindings() { + auto child_bindings = children[0]->GetColumnBindings(); + for (idx_t i = 0; i < expressions.size(); i++) { + child_bindings.emplace_back(window_index, i); + } + return child_bindings; +} + +void LogicalWindow::ResolveTypes() { + types.insert(types.end(), children[0]->types.begin(), children[0]->types.end()); + for (auto &expr : expressions) { + types.push_back(expr->return_type); + } +} + +vector LogicalWindow::GetTableIndex() const { + return vector {window_index}; +} + +string LogicalWindow::GetName() const { +#ifdef DEBUG + if (DBConfigOptions::debug_print_bindings) { + return LogicalOperator::GetName() + StringUtil::Format(" #%llu", window_index); + } +#endif + return LogicalOperator::GetName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp new file mode 100644 index 00000000..369e400f --- /dev/null +++ b/src/duckdb/src/planner/planner.cpp @@ -0,0 +1,188 @@ +#include "duckdb/planner/planner.hpp" + +#include "duckdb/common/serializer/binary_deserializer.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_parameter_expression.hpp" +#include "duckdb/transaction/meta_transaction.hpp" + +namespace duckdb { + +Planner::Planner(ClientContext &context) : binder(Binder::CreateBinder(context)), context(context) { +} + +static void CheckTreeDepth(const LogicalOperator &op, idx_t max_depth, idx_t depth = 0) { + if (depth >= max_depth) { + throw ParserException("Maximum tree depth of %lld exceeded in logical planner", max_depth); + } + for (auto &child : op.children) { + CheckTreeDepth(*child, max_depth, depth + 1); + } +} + +void Planner::CreatePlan(SQLStatement &statement) { + auto &profiler = QueryProfiler::Get(context); + auto parameter_count = statement.n_param; + + BoundParameterMap bound_parameters(parameter_data); + + // first bind the tables and columns to the catalog + bool parameters_resolved = true; + try { + profiler.StartPhase("binder"); + binder->parameters = &bound_parameters; + auto bound_statement = binder->Bind(statement); + profiler.EndPhase(); + + this->names = bound_statement.names; + this->types = bound_statement.types; + this->plan = std::move(bound_statement.plan); + + auto max_tree_depth = ClientConfig::GetConfig(context).max_expression_depth; + CheckTreeDepth(*plan, max_tree_depth); + } catch (const ParameterNotResolvedException &ex) { + // parameter types could not be resolved + this->names = {"unknown"}; + this->types = {LogicalTypeId::UNKNOWN}; + this->plan = nullptr; + parameters_resolved = false; + } catch (const Exception &ex) { + auto &config = DBConfig::GetConfig(context); + + this->plan = nullptr; + for (auto &extension_op : config.operator_extensions) { + auto bound_statement = + extension_op->Bind(context, *this->binder, extension_op->operator_info.get(), statement); + if (bound_statement.plan != nullptr) { + this->names = bound_statement.names; + this->types = bound_statement.types; + this->plan = std::move(bound_statement.plan); + break; + } + } + + if (!this->plan) { + throw; + } + } catch (std::exception &ex) { + throw; + } + this->properties = binder->properties; + this->properties.parameter_count = parameter_count; + properties.bound_all_parameters = parameters_resolved; + + Planner::VerifyPlan(context, plan, bound_parameters.GetParametersPtr()); + + // set up a map of parameter number -> value entries + for (auto &kv : bound_parameters.GetParameters()) { + auto &identifier = kv.first; + auto ¶m = kv.second; + // check if the type of the parameter could be resolved + if (!param->return_type.IsValid()) { + properties.bound_all_parameters = false; + continue; + } + param->SetValue(Value(param->return_type)); + value_map[identifier] = param; + } +} + +shared_ptr Planner::PrepareSQLStatement(unique_ptr statement) { + auto copied_statement = statement->Copy(); + // create a plan of the underlying statement + CreatePlan(std::move(statement)); + // now create the logical prepare + auto prepared_data = make_shared(copied_statement->type); + prepared_data->unbound_statement = std::move(copied_statement); + prepared_data->names = names; + prepared_data->types = types; + prepared_data->value_map = std::move(value_map); + prepared_data->properties = properties; + prepared_data->catalog_version = MetaTransaction::Get(context).catalog_version; + return prepared_data; +} + +void Planner::CreatePlan(unique_ptr statement) { + D_ASSERT(statement); + switch (statement->type) { + case StatementType::SELECT_STATEMENT: + case StatementType::INSERT_STATEMENT: + case StatementType::COPY_STATEMENT: + case StatementType::DELETE_STATEMENT: + case StatementType::UPDATE_STATEMENT: + case StatementType::CREATE_STATEMENT: + case StatementType::DROP_STATEMENT: + case StatementType::ALTER_STATEMENT: + case StatementType::TRANSACTION_STATEMENT: + case StatementType::EXPLAIN_STATEMENT: + case StatementType::VACUUM_STATEMENT: + case StatementType::RELATION_STATEMENT: + case StatementType::CALL_STATEMENT: + case StatementType::EXPORT_STATEMENT: + case StatementType::PRAGMA_STATEMENT: + case StatementType::SHOW_STATEMENT: + case StatementType::SET_STATEMENT: + case StatementType::LOAD_STATEMENT: + case StatementType::EXTENSION_STATEMENT: + case StatementType::PREPARE_STATEMENT: + case StatementType::EXECUTE_STATEMENT: + case StatementType::LOGICAL_PLAN_STATEMENT: + case StatementType::ATTACH_STATEMENT: + case StatementType::DETACH_STATEMENT: + CreatePlan(*statement); + break; + default: + throw NotImplementedException("Cannot plan statement of type %s!", StatementTypeToString(statement->type)); + } +} + +static bool OperatorSupportsSerialization(LogicalOperator &op) { + for (auto &child : op.children) { + if (!OperatorSupportsSerialization(*child)) { + return false; + } + } + return op.SupportSerialization(); +} + +void Planner::VerifyPlan(ClientContext &context, unique_ptr &op, + optional_ptr map) { +#ifdef DUCKDB_ALTERNATIVE_VERIFY + // if alternate verification is enabled we run the original operator + return; +#endif + if (!op || !ClientConfig::GetConfig(context).verify_serializer) { + return; + } + //! SELECT only for now + if (!OperatorSupportsSerialization(*op)) { + return; + } + + // format (de)serialization of this operator + try { + MemoryStream stream; + BinarySerializer::Serialize(*op, stream, true); + stream.Rewind(); + bound_parameter_map_t parameters; + auto new_plan = BinaryDeserializer::Deserialize(stream, context, parameters); + + if (map) { + *map = std::move(parameters); + } + op = std::move(new_plan); + } catch (SerializationException &ex) { + // pass + } catch (NotImplementedException &ex) { + // pass + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/pragma_handler.cpp b/src/duckdb/src/planner/pragma_handler.cpp new file mode 100644 index 00000000..34ed4240 --- /dev/null +++ b/src/duckdb/src/planner/pragma_handler.cpp @@ -0,0 +1,93 @@ +#include "duckdb/planner/pragma_handler.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/parser.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/parser/statement/multi_statement.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/function/function.hpp" + +#include "duckdb/main/client_context.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +PragmaHandler::PragmaHandler(ClientContext &context) : context(context) { +} + +void PragmaHandler::HandlePragmaStatementsInternal(vector> &statements) { + vector> new_statements; + for (idx_t i = 0; i < statements.size(); i++) { + if (statements[i]->type == StatementType::MULTI_STATEMENT) { + auto &multi_statement = statements[i]->Cast(); + for (auto &stmt : multi_statement.statements) { + statements.push_back(std::move(stmt)); + } + continue; + } + if (statements[i]->type == StatementType::PRAGMA_STATEMENT) { + // PRAGMA statement: check if we need to replace it by a new set of statements + PragmaHandler handler(context); + string new_query; + bool expanded = handler.HandlePragma(statements[i].get(), new_query); + if (expanded) { + // this PRAGMA statement gets replaced by a new query string + // push the new query string through the parser again and add it to the transformer + Parser parser(context.GetParserOptions()); + parser.ParseQuery(new_query); + // insert the new statements and remove the old statement + for (idx_t j = 0; j < parser.statements.size(); j++) { + new_statements.push_back(std::move(parser.statements[j])); + } + continue; + } + } + new_statements.push_back(std::move(statements[i])); + } + statements = std::move(new_statements); +} + +void PragmaHandler::HandlePragmaStatements(ClientContextLock &lock, vector> &statements) { + // first check if there are any pragma statements + bool found_pragma = false; + for (idx_t i = 0; i < statements.size(); i++) { + if (statements[i]->type == StatementType::PRAGMA_STATEMENT || + statements[i]->type == StatementType::MULTI_STATEMENT) { + found_pragma = true; + break; + } + } + if (!found_pragma) { + // no pragmas: skip this step + return; + } + context.RunFunctionInTransactionInternal(lock, [&]() { HandlePragmaStatementsInternal(statements); }); +} + +bool PragmaHandler::HandlePragma(SQLStatement *statement, string &resulting_query) { // PragmaInfo &info + auto info = *(statement->Cast()).info; + auto &entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name); + string error; + + FunctionBinder function_binder(context); + idx_t bound_idx = function_binder.BindFunction(entry.name, entry.functions, info, error); + if (bound_idx == DConstants::INVALID_INDEX) { + throw BinderException(error); + } + auto bound_function = entry.functions.GetFunctionByOffset(bound_idx); + if (bound_function.query) { + QueryErrorContext error_context(statement, statement->stmt_location); + Binder::BindNamedParameters(bound_function.named_parameters, info.named_parameters, error_context, + bound_function.name); + FunctionParameters parameters {info.parameters, info.named_parameters}; + resulting_query = bound_function.query(context, parameters); + return true; + } + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp new file mode 100644 index 00000000..9c3632fe --- /dev/null +++ b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp @@ -0,0 +1,594 @@ +#include "duckdb/planner/subquery/flatten_dependent_join.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/planner/subquery/has_correlated_expressions.hpp" +#include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" + +namespace duckdb { + +FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated, + bool perform_delim, bool any_join) + : binder(binder), delim_offset(DConstants::INVALID_INDEX), correlated_columns(correlated), + perform_delim(perform_delim), any_join(any_join) { + for (idx_t i = 0; i < correlated_columns.size(); i++) { + auto &col = correlated_columns[i]; + correlated_map[col.binding] = i; + delim_types.push_back(col.type); + } +} + +bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator *op, bool lateral, idx_t lateral_depth) { + + bool is_lateral_join = false; + + D_ASSERT(op); + // check if this entry has correlated expressions + if (op->type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { + is_lateral_join = true; + } + HasCorrelatedExpressions visitor(correlated_columns, lateral, lateral_depth); + visitor.VisitOperator(*op); + bool has_correlation = visitor.has_correlated_expressions; + int child_idx = 0; + // now visit the children of this entry and check if they have correlated expressions + for (auto &child : op->children) { + auto new_lateral_depth = lateral_depth; + if (is_lateral_join && child_idx == 1) { + new_lateral_depth = lateral_depth + 1; + } + // we OR the property with its children such that has_correlation is true if either + // (1) this node has a correlated expression or + // (2) one of its children has a correlated expression + if (DetectCorrelatedExpressions(child.get(), lateral, new_lateral_depth)) { + has_correlation = true; + } + child_idx++; + } + // set the entry in the map + has_correlated_expressions[op] = has_correlation; + return has_correlation; +} + +unique_ptr FlattenDependentJoins::PushDownDependentJoin(unique_ptr plan) { + bool propagate_null_values = true; + auto result = PushDownDependentJoinInternal(std::move(plan), propagate_null_values, 0); + if (!replacement_map.empty()) { + // check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END" + RewriteCountAggregates aggr(replacement_map); + aggr.VisitOperator(*result); + } + return result; +} + +bool SubqueryDependentFilter(Expression *expr) { + if (expr->expression_class == ExpressionClass::BOUND_CONJUNCTION && + expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { + auto &bound_conjuction = expr->Cast(); + for (auto &child : bound_conjuction.children) { + if (SubqueryDependentFilter(child.get())) { + return true; + } + } + } + if (expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { + return true; + } + return false; +} + +unique_ptr FlattenDependentJoins::PushDownDependentJoinInternal(unique_ptr plan, + bool &parent_propagate_null_values, + idx_t lateral_depth) { + // first check if the logical operator has correlated expressions + auto entry = has_correlated_expressions.find(plan.get()); + D_ASSERT(entry != has_correlated_expressions.end()); + if (!entry->second) { + // we reached a node without correlated expressions + // we can eliminate the dependent join now and create a simple cross product + // now create the duplicate eliminated scan for this node + auto left_columns = plan->GetColumnBindings().size(); + auto delim_index = binder.GenerateTableIndex(); + this->base_binding = ColumnBinding(delim_index, 0); + this->delim_offset = left_columns; + this->data_offset = 0; + auto delim_scan = make_uniq(delim_index, delim_types); + return LogicalCrossProduct::Create(std::move(plan), std::move(delim_scan)); + } + switch (plan->type) { + case LogicalOperatorType::LOGICAL_UNNEST: + case LogicalOperatorType::LOGICAL_FILTER: { + // filter + // first we flatten the dependent join in the child of the filter + for (auto &expr : plan->expressions) { + any_join |= SubqueryDependentFilter(expr.get()); + } + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + + // then we replace any correlated expressions with the corresponding entry in the correlated_map + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + rewriter.VisitOperator(*plan); + return plan; + } + case LogicalOperatorType::LOGICAL_PROJECTION: { + // projection + // first we flatten the dependent join in the child of the projection + for (auto &expr : plan->expressions) { + parent_propagate_null_values &= expr->PropagatesNullValues(); + } + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + + // then we replace any correlated expressions with the corresponding entry in the correlated_map + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + rewriter.VisitOperator(*plan); + // now we add all the columns of the delim_scan to the projection list + auto &proj = plan->Cast(); + for (idx_t i = 0; i < correlated_columns.size(); i++) { + auto &col = correlated_columns[i]; + auto colref = make_uniq( + col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); + plan->expressions.push_back(std::move(colref)); + } + + base_binding.table_index = proj.table_index; + this->delim_offset = base_binding.column_index = plan->expressions.size() - correlated_columns.size(); + this->data_offset = 0; + return plan; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + auto &aggr = plan->Cast(); + // aggregate and group by + // first we flatten the dependent join in the child of the projection + for (auto &expr : plan->expressions) { + parent_propagate_null_values &= expr->PropagatesNullValues(); + } + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + // then we replace any correlated expressions with the corresponding entry in the correlated_map + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + rewriter.VisitOperator(*plan); + // now we add all the columns of the delim_scan to the grouping operators AND the projection list + idx_t delim_table_index; + idx_t delim_column_offset; + idx_t delim_data_offset; + auto new_group_count = perform_delim ? correlated_columns.size() : 1; + for (idx_t i = 0; i < new_group_count; i++) { + auto &col = correlated_columns[i]; + auto colref = make_uniq( + col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); + for (auto &set : aggr.grouping_sets) { + set.insert(aggr.groups.size()); + } + aggr.groups.push_back(std::move(colref)); + } + if (!perform_delim) { + // if we are not performing the duplicate elimination, we have only added the row_id column to the grouping + // operators in this case, we push a FIRST aggregate for each of the remaining expressions + delim_table_index = aggr.aggregate_index; + delim_column_offset = aggr.expressions.size(); + delim_data_offset = aggr.groups.size(); + for (idx_t i = 0; i < correlated_columns.size(); i++) { + auto &col = correlated_columns[i]; + auto first_aggregate = FirstFun::GetFunction(col.type); + auto colref = make_uniq( + col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); + vector> aggr_children; + aggr_children.push_back(std::move(colref)); + auto first_fun = + make_uniq(std::move(first_aggregate), std::move(aggr_children), nullptr, + nullptr, AggregateType::NON_DISTINCT); + aggr.expressions.push_back(std::move(first_fun)); + } + } else { + delim_table_index = aggr.group_index; + delim_column_offset = aggr.groups.size() - correlated_columns.size(); + delim_data_offset = aggr.groups.size(); + } + if (aggr.groups.size() == new_group_count) { + // we have to perform a LEFT OUTER JOIN between the result of this aggregate and the delim scan + // FIXME: this does not always have to be a LEFT OUTER JOIN, depending on whether aggr.expressions return + // NULL or a value + unique_ptr join = make_uniq(JoinType::INNER); + for (auto &aggr_exp : aggr.expressions) { + auto &b_aggr_exp = aggr_exp->Cast(); + if (!b_aggr_exp.PropagatesNullValues() || any_join || !parent_propagate_null_values) { + join = make_uniq(JoinType::LEFT); + break; + } + } + auto left_index = binder.GenerateTableIndex(); + auto delim_scan = make_uniq(left_index, delim_types); + join->children.push_back(std::move(delim_scan)); + join->children.push_back(std::move(plan)); + for (idx_t i = 0; i < new_group_count; i++) { + auto &col = correlated_columns[i]; + JoinCondition cond; + cond.left = make_uniq(col.name, col.type, ColumnBinding(left_index, i)); + cond.right = make_uniq( + correlated_columns[i].type, ColumnBinding(delim_table_index, delim_column_offset + i)); + cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + join->conditions.push_back(std::move(cond)); + } + // for any COUNT aggregate we replace references to the column with: CASE WHEN COUNT(*) IS NULL THEN 0 + // ELSE COUNT(*) END + for (idx_t i = 0; i < aggr.expressions.size(); i++) { + D_ASSERT(aggr.expressions[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); + auto &bound = aggr.expressions[i]->Cast(); + vector arguments; + if (bound.function == CountFun::GetFunction() || bound.function == CountStarFun::GetFunction()) { + // have to replace this ColumnBinding with the CASE expression + replacement_map[ColumnBinding(aggr.aggregate_index, i)] = i; + } + } + // now we update the delim_index + base_binding.table_index = left_index; + this->delim_offset = base_binding.column_index = 0; + this->data_offset = 0; + return std::move(join); + } else { + // update the delim_index + base_binding.table_index = delim_table_index; + this->delim_offset = base_binding.column_index = delim_column_offset; + this->data_offset = delim_data_offset; + return plan; + } + } + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { + // cross product + // push into both sides of the plan + bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; + bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; + if (!right_has_correlation) { + // only left has correlation: push into left + plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), + parent_propagate_null_values, lateral_depth); + return plan; + } + if (!left_has_correlation) { + // only right has correlation: push into right + plan->children[1] = PushDownDependentJoinInternal(std::move(plan->children[1]), + parent_propagate_null_values, lateral_depth); + return plan; + } + // both sides have correlation + // turn into an inner join + auto join = make_uniq(JoinType::INNER); + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + auto left_binding = this->base_binding; + plan->children[1] = + PushDownDependentJoinInternal(std::move(plan->children[1]), parent_propagate_null_values, lateral_depth); + // add the correlated columns to the join conditions + for (idx_t i = 0; i < correlated_columns.size(); i++) { + JoinCondition cond; + cond.left = make_uniq( + correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); + cond.right = make_uniq( + correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); + cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + join->conditions.push_back(std::move(cond)); + } + join->children.push_back(std::move(plan->children[0])); + join->children.push_back(std::move(plan->children[1])); + return std::move(join); + } + case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: { + auto &dependent_join = plan->Cast(); + if (!((dependent_join.join_type == JoinType::INNER) || (dependent_join.join_type == JoinType::LEFT))) { + throw Exception("Dependent join can only be INNER or LEFT type"); + } + D_ASSERT(plan->children.size() == 2); + // Push all the bindings down to the left side so the right side knows where to refer DELIM_GET from + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + + // Normal rewriter like in other joins + RewriteCorrelatedExpressions rewriter(this->base_binding, correlated_map, lateral_depth); + rewriter.VisitOperator(*plan); + + // Recursive rewriter to visit right side of lateral join and update bindings from left + RewriteCorrelatedExpressions recursive_rewriter(this->base_binding, correlated_map, lateral_depth + 1, true); + recursive_rewriter.VisitOperator(*plan->children[1]); + + return plan; + } + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + auto &join = plan->Cast(); + D_ASSERT(plan->children.size() == 2); + // check the correlated expressions in the children of the join + bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; + bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; + + if (join.join_type == JoinType::INNER) { + // inner join + if (!right_has_correlation) { + // only left has correlation: push into left + plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), + parent_propagate_null_values, lateral_depth); + // Remove the correlated columns coming from outside for current join node + return plan; + } + if (!left_has_correlation) { + // only right has correlation: push into right + plan->children[1] = PushDownDependentJoinInternal(std::move(plan->children[1]), + parent_propagate_null_values, lateral_depth); + // Remove the correlated columns coming from outside for current join node + return plan; + } + } else if (join.join_type == JoinType::LEFT) { + // left outer join + if (!right_has_correlation) { + // only left has correlation: push into left + plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), + parent_propagate_null_values, lateral_depth); + // Remove the correlated columns coming from outside for current join node + return plan; + } + } else if (join.join_type == JoinType::RIGHT) { + // left outer join + if (!left_has_correlation) { + // only right has correlation: push into right + plan->children[1] = PushDownDependentJoinInternal(std::move(plan->children[1]), + parent_propagate_null_values, lateral_depth); + return plan; + } + } else if (join.join_type == JoinType::MARK) { + if (right_has_correlation) { + throw Exception("MARK join with correlation in RHS not supported"); + } + // push the child into the LHS + plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), + parent_propagate_null_values, lateral_depth); + // rewrite expressions in the join conditions + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + rewriter.VisitOperator(*plan); + return plan; + } else { + throw Exception("Unsupported join type for flattening correlated subquery"); + } + // both sides have correlation + // push into both sides + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + auto left_binding = this->base_binding; + plan->children[1] = + PushDownDependentJoinInternal(std::move(plan->children[1]), parent_propagate_null_values, lateral_depth); + auto right_binding = this->base_binding; + // NOTE: for OUTER JOINS it matters what the BASE BINDING is after the join + // for the LEFT OUTER JOIN, we want the LEFT side to be the base binding after we push + // because the RIGHT binding might contain NULL values + if (join.join_type == JoinType::LEFT) { + this->base_binding = left_binding; + } else if (join.join_type == JoinType::RIGHT) { + this->base_binding = right_binding; + } + // add the correlated columns to the join conditions + for (idx_t i = 0; i < correlated_columns.size(); i++) { + auto left = make_uniq( + correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); + auto right = make_uniq( + correlated_columns[i].type, ColumnBinding(right_binding.table_index, right_binding.column_index + i)); + + if (join.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || + join.type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { + JoinCondition cond; + cond.left = std::move(left); + cond.right = std::move(right); + cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + + auto &comparison_join = join.Cast(); + comparison_join.conditions.push_back(std::move(cond)); + } else { + auto &any_join = join.Cast(); + auto comparison = make_uniq(ExpressionType::COMPARE_NOT_DISTINCT_FROM, + std::move(left), std::move(right)); + auto conjunction = make_uniq( + ExpressionType::CONJUNCTION_AND, std::move(comparison), std::move(any_join.condition)); + any_join.condition = std::move(conjunction); + } + } + // then we replace any correlated expressions with the corresponding entry in the correlated_map + RewriteCorrelatedExpressions rewriter(right_binding, correlated_map, lateral_depth); + rewriter.VisitOperator(*plan); + return plan; + } + case LogicalOperatorType::LOGICAL_LIMIT: { + auto &limit = plan->Cast(); + if (limit.limit || limit.offset) { + throw ParserException("Non-constant limit or offset not supported in correlated subquery"); + } + auto rownum_alias = "limit_rownum"; + unique_ptr child; + unique_ptr order_by; + + // check if the direct child of this LIMIT node is an ORDER BY node, if so, keep it separate + // this is done for an optimization to avoid having to compute the total order + if (plan->children[0]->type == LogicalOperatorType::LOGICAL_ORDER_BY) { + order_by = unique_ptr_cast(std::move(plan->children[0])); + child = PushDownDependentJoinInternal(std::move(order_by->children[0]), parent_propagate_null_values, + lateral_depth); + } else { + child = PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, + lateral_depth); + } + auto child_column_count = child->GetColumnBindings().size(); + // we push a row_number() OVER (PARTITION BY [correlated columns]) + auto window_index = binder.GenerateTableIndex(); + auto window = make_uniq(window_index); + auto row_number = + make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); + auto partition_count = perform_delim ? correlated_columns.size() : 1; + for (idx_t i = 0; i < partition_count; i++) { + auto &col = correlated_columns[i]; + auto colref = make_uniq( + col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); + row_number->partitions.push_back(std::move(colref)); + } + if (order_by) { + // optimization: if there is an ORDER BY node followed by a LIMIT + // rather than computing the entire order, we push the ORDER BY expressions into the row_num computation + // this way, the order only needs to be computed per partition + row_number->orders = std::move(order_by->orders); + } + row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; + row_number->end = WindowBoundary::CURRENT_ROW_ROWS; + window->expressions.push_back(std::move(row_number)); + window->children.push_back(std::move(child)); + + // add a filter based on the row_number + // the filter we add is "row_number > offset AND row_number <= offset + limit" + auto filter = make_uniq(); + unique_ptr condition; + auto row_num_ref = + make_uniq(rownum_alias, LogicalType::BIGINT, ColumnBinding(window_index, 0)); + + int64_t upper_bound_limit = NumericLimits::Maximum(); + TryAddOperator::Operation(limit.offset_val, limit.limit_val, upper_bound_limit); + auto upper_bound = make_uniq(Value::BIGINT(upper_bound_limit)); + condition = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, row_num_ref->Copy(), + std::move(upper_bound)); + // we only need to add "row_number >= offset + 1" if offset is bigger than 0 + if (limit.offset_val > 0) { + auto lower_bound = make_uniq(Value::BIGINT(limit.offset_val)); + auto lower_comp = make_uniq(ExpressionType::COMPARE_GREATERTHAN, + row_num_ref->Copy(), std::move(lower_bound)); + auto conj = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(lower_comp), + std::move(condition)); + condition = std::move(conj); + } + filter->expressions.push_back(std::move(condition)); + filter->children.push_back(std::move(window)); + // we prune away the row_number after the filter clause using the projection map + for (idx_t i = 0; i < child_column_count; i++) { + filter->projection_map.push_back(i); + } + return std::move(filter); + } + case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: { + // NOTE: limit percent could be supported in a manner similar to the LIMIT above + // but instead of filtering by an exact number of rows, the limit should be expressed as + // COUNT computed over the partition multiplied by the percentage + throw ParserException("Limit percent operator not supported in correlated subquery"); + } + case LogicalOperatorType::LOGICAL_WINDOW: { + auto &window = plan->Cast(); + // push into children + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + // add the correlated columns to the PARTITION BY clauses in the Window + for (auto &expr : window.expressions) { + D_ASSERT(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); + auto &w = expr->Cast(); + for (idx_t i = 0; i < correlated_columns.size(); i++) { + w.partitions.push_back(make_uniq( + correlated_columns[i].type, + ColumnBinding(base_binding.table_index, base_binding.column_index + i))); + } + } + return plan; + } + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + case LogicalOperatorType::LOGICAL_UNION: { + auto &setop = plan->Cast(); + // set operator, push into both children +#ifdef DEBUG + plan->children[0]->ResolveOperatorTypes(); + plan->children[1]->ResolveOperatorTypes(); + D_ASSERT(plan->children[0]->types == plan->children[1]->types); +#endif + plan->children[0] = PushDownDependentJoin(std::move(plan->children[0])); + plan->children[1] = PushDownDependentJoin(std::move(plan->children[1])); +#ifdef DEBUG + D_ASSERT(plan->children[0]->GetColumnBindings().size() == plan->children[1]->GetColumnBindings().size()); + plan->children[0]->ResolveOperatorTypes(); + plan->children[1]->ResolveOperatorTypes(); + D_ASSERT(plan->children[0]->types == plan->children[1]->types); +#endif + // we have to refer to the setop index now + base_binding.table_index = setop.table_index; + base_binding.column_index = setop.column_count; + setop.column_count += correlated_columns.size(); + return plan; + } + case LogicalOperatorType::LOGICAL_DISTINCT: { + auto &distinct = plan->Cast(); + // push down into child + distinct.children[0] = PushDownDependentJoin(std::move(distinct.children[0])); + // add all correlated columns to the distinct targets + for (idx_t i = 0; i < correlated_columns.size(); i++) { + distinct.distinct_targets.push_back(make_uniq( + correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i))); + } + return plan; + } + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { + // expression get + // first we flatten the dependent join in the child + plan->children[0] = + PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); + // then we replace any correlated expressions with the corresponding entry in the correlated_map + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + rewriter.VisitOperator(*plan); + // now we add all the correlated columns to each of the expressions of the expression scan + auto &expr_get = plan->Cast(); + for (idx_t i = 0; i < correlated_columns.size(); i++) { + for (auto &expr_list : expr_get.expressions) { + auto colref = make_uniq( + correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); + expr_list.push_back(std::move(colref)); + } + expr_get.expr_types.push_back(correlated_columns[i].type); + } + + base_binding.table_index = expr_get.table_index; + this->delim_offset = base_binding.column_index = expr_get.expr_types.size() - correlated_columns.size(); + this->data_offset = 0; + return plan; + } + case LogicalOperatorType::LOGICAL_PIVOT: + throw BinderException("PIVOT is not supported in correlated subqueries yet"); + case LogicalOperatorType::LOGICAL_ORDER_BY: + plan->children[0] = PushDownDependentJoin(std::move(plan->children[0])); + return plan; + case LogicalOperatorType::LOGICAL_GET: { + auto &get = plan->Cast(); + if (get.children.size() != 1) { + throw InternalException("Flatten dependent joins - logical get encountered without children"); + } + plan->children[0] = PushDownDependentJoin(std::move(plan->children[0])); + for (idx_t i = 0; i < correlated_columns.size(); i++) { + get.projected_input.push_back(this->delim_offset + i); + } + this->delim_offset = get.returned_types.size(); + this->data_offset = 0; + return plan; + } + case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: { + throw BinderException("Recursive CTEs not (yet) supported in correlated subquery"); + } + case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: { + throw BinderException("Materialized CTEs not (yet) supported in correlated subquery"); + } + case LogicalOperatorType::LOGICAL_DELIM_JOIN: { + throw BinderException("Nested lateral joins or lateral joins in correlated subqueries are not (yet) supported"); + } + case LogicalOperatorType::LOGICAL_SAMPLE: + throw BinderException("Sampling in correlated subqueries is not (yet) supported"); + default: + throw InternalException("Logical operator type \"%s\" for dependent join", LogicalOperatorToString(plan->type)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp new file mode 100644 index 00000000..d53cb984 --- /dev/null +++ b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp @@ -0,0 +1,66 @@ +#include "duckdb/planner/subquery/has_correlated_expressions.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" + +#include + +namespace duckdb { + +HasCorrelatedExpressions::HasCorrelatedExpressions(const vector &correlated, bool lateral, + idx_t lateral_depth) + : has_correlated_expressions(false), lateral(lateral), correlated_columns(correlated), + lateral_depth(lateral_depth) { +} + +void HasCorrelatedExpressions::VisitOperator(LogicalOperator &op) { + VisitOperatorExpressions(op); +} + +unique_ptr HasCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + // Indicates local correlations (all correlations within a child) for the root + if (expr.depth <= lateral_depth) { + return nullptr; + } + + // Should never happen + if (expr.depth > 1 + lateral_depth) { + if (lateral) { + throw BinderException("Invalid lateral depth encountered for an expression"); + } + throw InternalException("Expression with depth > 1 detected in non-lateral join"); + } + // Note: This is added, since we only want to set has_correlated_expressions to true when the + // BoundSubqueryExpression has the same bindings as one of the correlated_columns from the left hand side + // (correlated_columns is the correlated_columns from left hand side) + bool found_match = false; + for (idx_t i = 0; i < correlated_columns.size(); i++) { + if (correlated_columns[i].binding == expr.binding) { + found_match = true; + break; + } + } + // correlated column reference + D_ASSERT(expr.depth == lateral_depth + 1); + has_correlated_expressions = found_match; + return nullptr; +} + +unique_ptr HasCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr, + unique_ptr *expr_ptr) { + if (!expr.IsCorrelated()) { + return nullptr; + } + // check if the subquery contains any of the correlated expressions that we are concerned about in this node + for (idx_t i = 0; i < correlated_columns.size(); i++) { + if (std::find(expr.binder->correlated_columns.begin(), expr.binder->correlated_columns.end(), + correlated_columns[i]) != expr.binder->correlated_columns.end()) { + has_correlated_expressions = true; + break; + } + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp new file mode 100644 index 00000000..fd177ccf --- /dev/null +++ b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp @@ -0,0 +1,174 @@ +#include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" + +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_subquery_expression.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" + +namespace duckdb { + +RewriteCorrelatedExpressions::RewriteCorrelatedExpressions(ColumnBinding base_binding, + column_binding_map_t &correlated_map, + idx_t lateral_depth, bool recursive_rewrite) + : base_binding(base_binding), correlated_map(correlated_map), lateral_depth(lateral_depth), + recursive_rewrite(recursive_rewrite) { +} + +void RewriteCorrelatedExpressions::VisitOperator(LogicalOperator &op) { + if (recursive_rewrite) { + // Update column bindings from left child of lateral to right child + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { + D_ASSERT(op.children.size() == 2); + VisitOperator(*op.children[0]); + lateral_depth++; + VisitOperator(*op.children[1]); + lateral_depth--; + } else { + VisitOperatorChildren(op); + } + } + // update the bindings in the correlated columns of the dependendent join + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { + auto &plan = op.Cast(); + for (auto &corr : plan.correlated_columns) { + auto entry = correlated_map.find(corr.binding); + if (entry != correlated_map.end()) { + corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); + } + } + } + VisitOperatorExpressions(op); +} + +unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + if (expr.depth <= lateral_depth) { + // Indicates local correlations not relevant for the current the rewrite + return nullptr; + } + // correlated column reference + // replace with the entry referring to the duplicate eliminated scan + // if this assertion occurs it generally means the bindings are inappropriate set in the binder or + // we either missed to account for lateral binder or over-counted for the lateral binder + D_ASSERT(expr.depth == 1 + lateral_depth); + auto entry = correlated_map.find(expr.binding); + D_ASSERT(entry != correlated_map.end()); + + expr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); + if (recursive_rewrite) { + D_ASSERT(expr.depth > 1); + expr.depth--; + } else { + expr.depth = 0; + } + return nullptr; +} + +unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr, + unique_ptr *expr_ptr) { + if (!expr.IsCorrelated()) { + return nullptr; + } + // subquery detected within this subquery + // recursively rewrite it using the RewriteCorrelatedRecursive class + RewriteCorrelatedRecursive rewrite(expr, base_binding, correlated_map); + rewrite.RewriteCorrelatedSubquery(expr); + return nullptr; +} + +RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedRecursive( + BoundSubqueryExpression &parent, ColumnBinding base_binding, column_binding_map_t &correlated_map) + : parent(parent), base_binding(base_binding), correlated_map(correlated_map) { +} + +void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteJoinRefRecursive(BoundTableRef &ref) { + // recursively rewrite bindings in the correlated columns for the table ref and all the children + if (ref.type == TableReferenceType::JOIN) { + auto &bound_join = ref.Cast(); + for (auto &corr : bound_join.correlated_columns) { + auto entry = correlated_map.find(corr.binding); + if (entry != correlated_map.end()) { + corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); + } + } + RewriteJoinRefRecursive(*bound_join.left); + RewriteJoinRefRecursive(*bound_join.right); + } +} + +void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedSubquery( + BoundSubqueryExpression &expr) { + // rewrite the binding in the correlated list of the subquery) + for (auto &corr : expr.binder->correlated_columns) { + auto entry = correlated_map.find(corr.binding); + if (entry != correlated_map.end()) { + corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); + } + } + // TODO: Cleanup and find a better way to do this + auto &node = *expr.subquery; + if (node.type == QueryNodeType::SELECT_NODE) { + // Found an unplanned select node, need to update column bindings correlated columns in the from tables + auto &bound_select = node.Cast(); + if (bound_select.from_table) { + BoundTableRef &table_ref = *bound_select.from_table; + RewriteJoinRefRecursive(table_ref); + } + } + // now rewrite any correlated BoundColumnRef expressions inside the subquery + ExpressionIterator::EnumerateQueryNodeChildren(*expr.subquery, + [&](Expression &child) { RewriteCorrelatedExpressions(child); }); +} + +void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedExpressions(Expression &child) { + if (child.type == ExpressionType::BOUND_COLUMN_REF) { + // bound column reference + auto &bound_colref = child.Cast(); + if (bound_colref.depth == 0) { + // not a correlated column, ignore + return; + } + // correlated column + // check the correlated map + auto entry = correlated_map.find(bound_colref.binding); + if (entry != correlated_map.end()) { + // we found the column in the correlated map! + // update the binding and reduce the depth by 1 + bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); + bound_colref.depth--; + } + } else if (child.type == ExpressionType::SUBQUERY) { + // we encountered another subquery: rewrite recursively + D_ASSERT(child.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); + auto &bound_subquery = child.Cast(); + RewriteCorrelatedRecursive rewrite(bound_subquery, base_binding, correlated_map); + rewrite.RewriteCorrelatedSubquery(bound_subquery); + } +} + +RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t &replacement_map) + : replacement_map(replacement_map) { +} + +unique_ptr RewriteCountAggregates::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + auto entry = replacement_map.find(expr.binding); + if (entry != replacement_map.end()) { + // reference to a COUNT(*) aggregate + // replace this with CASE WHEN COUNT(*) IS NULL THEN 0 ELSE COUNT(*) END + auto is_null = make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); + is_null->children.push_back(expr.Copy()); + auto check = std::move(is_null); + auto result_if_true = make_uniq(Value::Numeric(expr.return_type, 0)); + auto result_if_false = std::move(*expr_ptr); + return make_uniq(std::move(check), std::move(result_if_true), std::move(result_if_false)); + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/table_binding.cpp b/src/duckdb/src/planner/table_binding.cpp new file mode 100644 index 00000000..aec069f3 --- /dev/null +++ b/src/duckdb/src/planner/table_binding.cpp @@ -0,0 +1,262 @@ +#include "duckdb/planner/table_binding.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/planner/bind_context.hpp" +#include "duckdb/planner/bound_query_node.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_lambdaref_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" + +#include + +namespace duckdb { + +Binding::Binding(BindingType binding_type, const string &alias, vector coltypes, vector colnames, + idx_t index) + : binding_type(binding_type), alias(alias), index(index), types(std::move(coltypes)), names(std::move(colnames)) { + D_ASSERT(types.size() == names.size()); + for (idx_t i = 0; i < names.size(); i++) { + auto &name = names[i]; + D_ASSERT(!name.empty()); + if (name_map.find(name) != name_map.end()) { + throw BinderException("table \"%s\" has duplicate column name \"%s\"", alias, name); + } + name_map[name] = i; + } +} + +bool Binding::TryGetBindingIndex(const string &column_name, column_t &result) { + auto entry = name_map.find(column_name); + if (entry == name_map.end()) { + return false; + } + auto column_info = entry->second; + result = column_info; + return true; +} + +column_t Binding::GetBindingIndex(const string &column_name) { + column_t result; + if (!TryGetBindingIndex(column_name, result)) { + throw InternalException("Binding index for column \"%s\" not found", column_name); + } + return result; +} + +bool Binding::HasMatchingBinding(const string &column_name) { + column_t result; + return TryGetBindingIndex(column_name, result); +} + +string Binding::ColumnNotFoundError(const string &column_name) const { + return StringUtil::Format("Values list \"%s\" does not have a column named \"%s\"", alias, column_name); +} + +BindResult Binding::Bind(ColumnRefExpression &colref, idx_t depth) { + column_t column_index; + bool success = false; + success = TryGetBindingIndex(colref.GetColumnName(), column_index); + if (!success) { + return BindResult(ColumnNotFoundError(colref.GetColumnName())); + } + ColumnBinding binding; + binding.table_index = index; + binding.column_index = column_index; + LogicalType sql_type = types[column_index]; + if (colref.alias.empty()) { + colref.alias = names[column_index]; + } + return BindResult(make_uniq(colref.GetName(), sql_type, binding, depth)); +} + +optional_ptr Binding::GetStandardEntry() { + return nullptr; +} + +EntryBinding::EntryBinding(const string &alias, vector types_p, vector names_p, idx_t index, + StandardEntry &entry) + : Binding(BindingType::CATALOG_ENTRY, alias, std::move(types_p), std::move(names_p), index), entry(entry) { +} + +optional_ptr EntryBinding::GetStandardEntry() { + return &entry; +} + +TableBinding::TableBinding(const string &alias, vector types_p, vector names_p, + vector &bound_column_ids, optional_ptr entry, idx_t index, + bool add_row_id) + : Binding(BindingType::TABLE, alias, std::move(types_p), std::move(names_p), index), + bound_column_ids(bound_column_ids), entry(entry) { + if (add_row_id) { + if (name_map.find("rowid") == name_map.end()) { + name_map["rowid"] = COLUMN_IDENTIFIER_ROW_ID; + } + } +} + +static void ReplaceAliases(ParsedExpression &expr, const ColumnList &list, + const unordered_map &alias_map) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto &colref = expr.Cast(); + D_ASSERT(!colref.IsQualified()); + auto &col_names = colref.column_names; + D_ASSERT(col_names.size() == 1); + auto idx_entry = list.GetColumnIndex(col_names[0]); + auto &alias = alias_map.at(idx_entry.index); + col_names = {alias}; + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](const ParsedExpression &child) { ReplaceAliases((ParsedExpression &)child, list, alias_map); }); +} + +static void BakeTableName(ParsedExpression &expr, const string &table_name) { + if (expr.type == ExpressionType::COLUMN_REF) { + auto &colref = expr.Cast(); + D_ASSERT(!colref.IsQualified()); + auto &col_names = colref.column_names; + col_names.insert(col_names.begin(), table_name); + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](const ParsedExpression &child) { BakeTableName((ParsedExpression &)child, table_name); }); +} + +unique_ptr TableBinding::ExpandGeneratedColumn(const string &column_name) { + auto catalog_entry = GetStandardEntry(); + D_ASSERT(catalog_entry); // Should only be called on a TableBinding + + D_ASSERT(catalog_entry->type == CatalogType::TABLE_ENTRY); + auto &table_entry = catalog_entry->Cast(); + + // Get the index of the generated column + auto column_index = GetBindingIndex(column_name); + D_ASSERT(table_entry.GetColumn(LogicalIndex(column_index)).Generated()); + // Get a copy of the generated column + auto expression = table_entry.GetColumn(LogicalIndex(column_index)).GeneratedExpression().Copy(); + unordered_map alias_map; + for (auto &entry : name_map) { + alias_map[entry.second] = entry.first; + } + ReplaceAliases(*expression, table_entry.GetColumns(), alias_map); + BakeTableName(*expression, alias); + return (expression); +} + +const vector &TableBinding::GetBoundColumnIds() const { +#ifdef DEBUG + unordered_set column_ids; + for (auto &id : bound_column_ids) { + auto result = column_ids.insert(id); + // assert that all entries in the bound_column_ids are unique + D_ASSERT(result.second); + auto it = std::find_if(name_map.begin(), name_map.end(), + [&](const std::pair &it) { return it.second == id; }); + // assert that every id appears in the name_map + D_ASSERT(it != name_map.end()); + // the order that they appear in is not guaranteed to be sequential + } +#endif + return bound_column_ids; +} + +ColumnBinding TableBinding::GetColumnBinding(column_t column_index) { + auto &column_ids = bound_column_ids; + ColumnBinding binding; + + // Locate the column_id that matches the 'column_index' + auto it = std::find_if(column_ids.begin(), column_ids.end(), + [&](const column_t &id) -> bool { return id == column_index; }); + // Get the index of it + binding.column_index = std::distance(column_ids.begin(), it); + // If it wasn't found, add it + if (it == column_ids.end()) { + column_ids.push_back(column_index); + } + + binding.table_index = index; + return binding; +} + +BindResult TableBinding::Bind(ColumnRefExpression &colref, idx_t depth) { + auto &column_name = colref.GetColumnName(); + column_t column_index; + bool success = false; + success = TryGetBindingIndex(column_name, column_index); + if (!success) { + return BindResult(ColumnNotFoundError(column_name)); + } + auto entry = GetStandardEntry(); + if (entry && column_index != COLUMN_IDENTIFIER_ROW_ID) { + D_ASSERT(entry->type == CatalogType::TABLE_ENTRY); + // Either there is no table, or the columns category has to be standard + auto &table_entry = entry->Cast(); + auto &column_entry = table_entry.GetColumn(LogicalIndex(column_index)); + (void)table_entry; + (void)column_entry; + D_ASSERT(column_entry.Category() == TableColumnType::STANDARD); + } + // fetch the type of the column + LogicalType col_type; + if (column_index == COLUMN_IDENTIFIER_ROW_ID) { + // row id: BIGINT type + col_type = LogicalType::BIGINT; + } else { + // normal column: fetch type from base column + col_type = types[column_index]; + if (colref.alias.empty()) { + colref.alias = names[column_index]; + } + } + ColumnBinding binding = GetColumnBinding(column_index); + return BindResult(make_uniq(colref.GetName(), col_type, binding, depth)); +} + +optional_ptr TableBinding::GetStandardEntry() { + return entry; +} + +string TableBinding::ColumnNotFoundError(const string &column_name) const { + return StringUtil::Format("Table \"%s\" does not have a column named \"%s\"", alias, column_name); +} + +DummyBinding::DummyBinding(vector types_p, vector names_p, string dummy_name_p) + : Binding(BindingType::DUMMY, DummyBinding::DUMMY_NAME + dummy_name_p, std::move(types_p), std::move(names_p), + DConstants::INVALID_INDEX), + dummy_name(std::move(dummy_name_p)) { +} + +BindResult DummyBinding::Bind(ColumnRefExpression &colref, idx_t depth) { + column_t column_index; + if (!TryGetBindingIndex(colref.GetColumnName(), column_index)) { + throw InternalException("Column %s not found in bindings", colref.GetColumnName()); + } + ColumnBinding binding(index, column_index); + + // we are binding a parameter to create the dummy binding, no arguments are supplied + return BindResult(make_uniq(colref.GetName(), types[column_index], binding, depth)); +} + +BindResult DummyBinding::Bind(ColumnRefExpression &colref, idx_t lambda_index, idx_t depth) { + column_t column_index; + if (!TryGetBindingIndex(colref.GetColumnName(), column_index)) { + throw InternalException("Column %s not found in bindings", colref.GetColumnName()); + } + ColumnBinding binding(index, column_index); + return BindResult( + make_uniq(colref.GetName(), types[column_index], binding, lambda_index, depth)); +} + +unique_ptr DummyBinding::ParamToArg(ColumnRefExpression &colref) { + column_t column_index; + if (!TryGetBindingIndex(colref.GetColumnName(), column_index)) { + throw InternalException("Column %s not found in macro", colref.GetColumnName()); + } + auto arg = (*arguments)[column_index]->Copy(); + arg->alias = colref.alias; + return arg; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/table_filter.cpp b/src/duckdb/src/planner/table_filter.cpp new file mode 100644 index 00000000..0f6f2b0b --- /dev/null +++ b/src/duckdb/src/planner/table_filter.cpp @@ -0,0 +1,27 @@ +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/null_filter.hpp" + +namespace duckdb { + +void TableFilterSet::PushFilter(idx_t column_index, unique_ptr filter) { + auto entry = filters.find(column_index); + if (entry == filters.end()) { + // no filter yet: push the filter directly + filters[column_index] = std::move(filter); + } else { + // there is already a filter: AND it together + if (entry->second->filter_type == TableFilterType::CONJUNCTION_AND) { + auto &and_filter = entry->second->Cast(); + and_filter.child_filters.push_back(std::move(filter)); + } else { + auto and_filter = make_uniq(); + and_filter->child_filters.push_back(std::move(entry->second)); + and_filter->child_filters.push_back(std::move(filter)); + filters[column_index] = std::move(and_filter); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/arena_allocator.cpp b/src/duckdb/src/storage/arena_allocator.cpp new file mode 100644 index 00000000..1382ba2e --- /dev/null +++ b/src/duckdb/src/storage/arena_allocator.cpp @@ -0,0 +1,166 @@ +#include "duckdb/storage/arena_allocator.hpp" + +#include "duckdb/common/assert.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Arena Chunk +//===--------------------------------------------------------------------===// +ArenaChunk::ArenaChunk(Allocator &allocator, idx_t size) : current_position(0), maximum_size(size), prev(nullptr) { + D_ASSERT(size > 0); + data = allocator.Allocate(size); +} +ArenaChunk::~ArenaChunk() { + if (next) { + auto current_next = std::move(next); + while (current_next) { + current_next = std::move(current_next->next); + } + } +} + +//===--------------------------------------------------------------------===// +// Allocator Wrapper +//===--------------------------------------------------------------------===// +struct ArenaAllocatorData : public PrivateAllocatorData { + explicit ArenaAllocatorData(ArenaAllocator &allocator) : allocator(allocator) { + } + + ArenaAllocator &allocator; +}; + +static data_ptr_t ArenaAllocatorAllocate(PrivateAllocatorData *private_data, idx_t size) { + auto &allocator_data = private_data->Cast(); + return allocator_data.allocator.Allocate(size); +} + +static void ArenaAllocatorFree(PrivateAllocatorData *, data_ptr_t, idx_t) { + // nop +} + +static data_ptr_t ArenaAllocateReallocate(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, + idx_t size) { + auto &allocator_data = private_data->Cast(); + return allocator_data.allocator.Reallocate(pointer, old_size, size); +} +//===--------------------------------------------------------------------===// +// Arena Allocator +//===--------------------------------------------------------------------===// +ArenaAllocator::ArenaAllocator(Allocator &allocator, idx_t initial_capacity) + : allocator(allocator), arena_allocator(ArenaAllocatorAllocate, ArenaAllocatorFree, ArenaAllocateReallocate, + make_uniq(*this)) { + head = nullptr; + tail = nullptr; + current_capacity = initial_capacity; +} + +ArenaAllocator::~ArenaAllocator() { +} + +data_ptr_t ArenaAllocator::Allocate(idx_t len) { + D_ASSERT(!head || head->current_position <= head->maximum_size); + if (!head || head->current_position + len > head->maximum_size) { + do { + current_capacity *= 2; + } while (current_capacity < len); + auto new_chunk = make_unsafe_uniq(allocator, current_capacity); + if (head) { + head->prev = new_chunk.get(); + new_chunk->next = std::move(head); + } else { + tail = new_chunk.get(); + } + head = std::move(new_chunk); + } + D_ASSERT(head->current_position + len <= head->maximum_size); + auto result = head->data.get() + head->current_position; + head->current_position += len; + return result; +} + +data_ptr_t ArenaAllocator::Reallocate(data_ptr_t pointer, idx_t old_size, idx_t size) { + D_ASSERT(head); + if (old_size == size) { + // nothing to do + return pointer; + } + + auto head_ptr = head->data.get() + head->current_position; + int64_t diff = size - old_size; + if (pointer == head_ptr && (size < old_size || head->current_position + diff <= head->maximum_size)) { + // passed pointer is the head pointer, and the diff fits on the current chunk + head->current_position += diff; + return pointer; + } else { + // allocate new memory + auto result = Allocate(size); + memcpy(result, pointer, old_size); + return result; + } +} + +data_ptr_t ArenaAllocator::AllocateAligned(idx_t size) { + return Allocate(AlignValue(size)); +} + +data_ptr_t ArenaAllocator::ReallocateAligned(data_ptr_t pointer, idx_t old_size, idx_t size) { + return Reallocate(pointer, old_size, AlignValue(size)); +} + +void ArenaAllocator::Reset() { + if (head) { + // destroy all chunks except the current one + if (head->next) { + auto current_next = std::move(head->next); + while (current_next) { + current_next = std::move(current_next->next); + } + } + tail = head.get(); + + // reset the head + head->current_position = 0; + head->prev = nullptr; + } +} + +void ArenaAllocator::Destroy() { + head = nullptr; + tail = nullptr; + current_capacity = ARENA_ALLOCATOR_INITIAL_CAPACITY; +} + +void ArenaAllocator::Move(ArenaAllocator &other) { + D_ASSERT(!other.head); + other.tail = tail; + other.head = std::move(head); + other.current_capacity = current_capacity; + Destroy(); +} + +ArenaChunk *ArenaAllocator::GetHead() { + return head.get(); +} + +ArenaChunk *ArenaAllocator::GetTail() { + return tail; +} + +bool ArenaAllocator::IsEmpty() const { + return head == nullptr; +} + +idx_t ArenaAllocator::SizeInBytes() const { + idx_t total_size = 0; + if (!IsEmpty()) { + auto current = head.get(); + while (current != nullptr) { + total_size += current->current_position; + current = current->next.get(); + } + } + return total_size; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/block.cpp b/src/duckdb/src/storage/block.cpp new file mode 100644 index 00000000..e90b621c --- /dev/null +++ b/src/duckdb/src/storage/block.cpp @@ -0,0 +1,19 @@ +#include "duckdb/storage/block.hpp" +#include "duckdb/common/assert.hpp" + +namespace duckdb { + +Block::Block(Allocator &allocator, block_id_t id) + : FileBuffer(allocator, FileBufferType::BLOCK, Storage::BLOCK_SIZE), id(id) { +} + +Block::Block(Allocator &allocator, block_id_t id, uint32_t internal_size) + : FileBuffer(allocator, FileBufferType::BLOCK, internal_size), id(id) { + D_ASSERT((AllocSize() & (Storage::SECTOR_SIZE - 1)) == 0); +} + +Block::Block(FileBuffer &source, block_id_t id) : FileBuffer(source, FileBufferType::BLOCK), id(id) { + D_ASSERT((AllocSize() & (Storage::SECTOR_SIZE - 1)) == 0); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/buffer/block_handle.cpp b/src/duckdb/src/storage/buffer/block_handle.cpp new file mode 100644 index 00000000..a9c9f967 --- /dev/null +++ b/src/duckdb/src/storage/buffer/block_handle.cpp @@ -0,0 +1,129 @@ +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/block.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/buffer/buffer_pool.hpp" +#include "duckdb/common/file_buffer.hpp" + +namespace duckdb { + +BlockHandle::BlockHandle(BlockManager &block_manager, block_id_t block_id_p) + : block_manager(block_manager), readers(0), block_id(block_id_p), buffer(nullptr), eviction_timestamp(0), + can_destroy(false), memory_charge(block_manager.buffer_manager.GetBufferPool()), unswizzled(nullptr) { + eviction_timestamp = 0; + state = BlockState::BLOCK_UNLOADED; + memory_usage = Storage::BLOCK_ALLOC_SIZE; +} + +BlockHandle::BlockHandle(BlockManager &block_manager, block_id_t block_id_p, unique_ptr buffer_p, + bool can_destroy_p, idx_t block_size, BufferPoolReservation &&reservation) + : block_manager(block_manager), readers(0), block_id(block_id_p), eviction_timestamp(0), can_destroy(can_destroy_p), + memory_charge(block_manager.buffer_manager.GetBufferPool()), unswizzled(nullptr) { + buffer = std::move(buffer_p); + state = BlockState::BLOCK_LOADED; + memory_usage = block_size; + memory_charge = std::move(reservation); +} + +BlockHandle::~BlockHandle() { // NOLINT: allow internal exceptions + // being destroyed, so any unswizzled pointers are just binary junk now. + unswizzled = nullptr; + auto &buffer_manager = block_manager.buffer_manager; + // no references remain to this block: erase + if (buffer && state == BlockState::BLOCK_LOADED) { + D_ASSERT(memory_charge.size > 0); + // the block is still loaded in memory: erase it + buffer.reset(); + memory_charge.Resize(0); + } else { + D_ASSERT(memory_charge.size == 0); + } + buffer_manager.GetBufferPool().PurgeQueue(); + block_manager.UnregisterBlock(block_id, can_destroy); +} + +unique_ptr AllocateBlock(BlockManager &block_manager, unique_ptr reusable_buffer, + block_id_t block_id) { + if (reusable_buffer) { + // re-usable buffer: re-use it + if (reusable_buffer->type == FileBufferType::BLOCK) { + // we can reuse the buffer entirely + auto &block = reinterpret_cast(*reusable_buffer); + block.id = block_id; + return unique_ptr_cast(std::move(reusable_buffer)); + } + auto block = block_manager.CreateBlock(block_id, reusable_buffer.get()); + reusable_buffer.reset(); + return block; + } else { + // no re-usable buffer: allocate a new block + return block_manager.CreateBlock(block_id, nullptr); + } +} + +BufferHandle BlockHandle::Load(shared_ptr &handle, unique_ptr reusable_buffer) { + if (handle->state == BlockState::BLOCK_LOADED) { + // already loaded + D_ASSERT(handle->buffer); + return BufferHandle(handle, handle->buffer.get()); + } + + auto &block_manager = handle->block_manager; + if (handle->block_id < MAXIMUM_BLOCK) { + auto block = AllocateBlock(block_manager, std::move(reusable_buffer), handle->block_id); + block_manager.Read(*block); + handle->buffer = std::move(block); + } else { + if (handle->can_destroy) { + return BufferHandle(); + } else { + handle->buffer = + block_manager.buffer_manager.ReadTemporaryBuffer(handle->block_id, std::move(reusable_buffer)); + } + } + handle->state = BlockState::BLOCK_LOADED; + return BufferHandle(handle, handle->buffer.get()); +} + +unique_ptr BlockHandle::UnloadAndTakeBlock() { + if (state == BlockState::BLOCK_UNLOADED) { + // already unloaded: nothing to do + return nullptr; + } + D_ASSERT(!unswizzled); + D_ASSERT(CanUnload()); + + if (block_id >= MAXIMUM_BLOCK && !can_destroy) { + // temporary block that cannot be destroyed: write to temporary file + block_manager.buffer_manager.WriteTemporaryBuffer(block_id, *buffer); + } + memory_charge.Resize(0); + state = BlockState::BLOCK_UNLOADED; + return std::move(buffer); +} + +void BlockHandle::Unload() { + auto block = UnloadAndTakeBlock(); + block.reset(); +} + +bool BlockHandle::CanUnload() { + if (state == BlockState::BLOCK_UNLOADED) { + // already unloaded + return false; + } + if (readers > 0) { + // there are active readers + return false; + } + if (block_id >= MAXIMUM_BLOCK && !can_destroy && !block_manager.buffer_manager.HasTemporaryDirectory()) { + // in order to unload this block we need to write it to a temporary buffer + // however, no temporary directory is specified! + // hence we cannot unload the block + return false; + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/buffer/block_manager.cpp b/src/duckdb/src/storage/buffer/block_manager.cpp new file mode 100644 index 00000000..20871d01 --- /dev/null +++ b/src/duckdb/src/storage/buffer/block_manager.cpp @@ -0,0 +1,86 @@ +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/buffer/buffer_pool.hpp" +#include "duckdb/storage/metadata/metadata_manager.hpp" + +namespace duckdb { + +BlockManager::BlockManager(BufferManager &buffer_manager) + : buffer_manager(buffer_manager), metadata_manager(make_uniq(*this, buffer_manager)) { +} + +shared_ptr BlockManager::RegisterBlock(block_id_t block_id) { + lock_guard lock(blocks_lock); + // check if the block already exists + auto entry = blocks.find(block_id); + if (entry != blocks.end()) { + // already exists: check if it hasn't expired yet + auto existing_ptr = entry->second.lock(); + if (existing_ptr) { + //! it hasn't! return it + return existing_ptr; + } + } + // create a new block pointer for this block + auto result = make_shared(*this, block_id); + // register the block pointer in the set of blocks as a weak pointer + blocks[block_id] = weak_ptr(result); + return result; +} + +shared_ptr BlockManager::ConvertToPersistent(block_id_t block_id, shared_ptr old_block) { + // pin the old block to ensure we have it loaded in memory + auto old_handle = buffer_manager.Pin(old_block); + D_ASSERT(old_block->state == BlockState::BLOCK_LOADED); + D_ASSERT(old_block->buffer); + + // Temp buffers can be larger than the storage block size. But persistent buffers + // cannot. + D_ASSERT(old_block->buffer->AllocSize() <= Storage::BLOCK_ALLOC_SIZE); + + // register a block with the new block id + auto new_block = RegisterBlock(block_id); + D_ASSERT(new_block->state == BlockState::BLOCK_UNLOADED); + D_ASSERT(new_block->readers == 0); + + // move the data from the old block into data for the new block + new_block->state = BlockState::BLOCK_LOADED; + new_block->buffer = ConvertBlock(block_id, *old_block->buffer); + new_block->memory_usage = old_block->memory_usage; + new_block->memory_charge = std::move(old_block->memory_charge); + + // clear the old buffer and unload it + old_block->buffer.reset(); + old_block->state = BlockState::BLOCK_UNLOADED; + old_block->memory_usage = 0; + old_handle.Destroy(); + old_block.reset(); + + // persist the new block to disk + Write(*new_block->buffer, block_id); + + buffer_manager.GetBufferPool().AddToEvictionQueue(new_block); + + return new_block; +} + +void BlockManager::UnregisterBlock(block_id_t block_id, bool can_destroy) { + if (block_id >= MAXIMUM_BLOCK) { + // in-memory buffer: buffer could have been offloaded to disk: remove the file + buffer_manager.DeleteTemporaryFile(block_id); + } else { + lock_guard lock(blocks_lock); + // on-disk block: erase from list of blocks in manager + blocks.erase(block_id); + } +} + +MetadataManager &BlockManager::GetMetadataManager() { + return *metadata_manager; +} + +void BlockManager::Truncate() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/buffer/buffer_handle.cpp b/src/duckdb/src/storage/buffer/buffer_handle.cpp new file mode 100644 index 00000000..dc3be3f2 --- /dev/null +++ b/src/duckdb/src/storage/buffer/buffer_handle.cpp @@ -0,0 +1,47 @@ +#include "duckdb/storage/buffer/buffer_handle.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" + +namespace duckdb { + +BufferHandle::BufferHandle() : handle(nullptr), node(nullptr) { +} + +BufferHandle::BufferHandle(shared_ptr handle_p, FileBuffer *node_p) + : handle(std::move(handle_p)), node(node_p) { +} + +BufferHandle::BufferHandle(BufferHandle &&other) noexcept { + std::swap(node, other.node); + std::swap(handle, other.handle); +} + +BufferHandle &BufferHandle::operator=(BufferHandle &&other) noexcept { + std::swap(node, other.node); + std::swap(handle, other.handle); + return *this; +} + +BufferHandle::~BufferHandle() { + Destroy(); +} + +bool BufferHandle::IsValid() const { + return node != nullptr; +} + +void BufferHandle::Destroy() { + if (!handle || !IsValid()) { + return; + } + handle->block_manager.buffer_manager.Unpin(handle); + handle.reset(); + node = nullptr; +} + +FileBuffer &BufferHandle::GetFileBuffer() { + D_ASSERT(node); + return *node; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/buffer/buffer_pool.cpp b/src/duckdb/src/storage/buffer/buffer_pool.cpp new file mode 100644 index 00000000..b1701db0 --- /dev/null +++ b/src/duckdb/src/storage/buffer/buffer_pool.cpp @@ -0,0 +1,136 @@ +#include "duckdb/storage/buffer/buffer_pool.hpp" +#include "duckdb/parallel/concurrentqueue.hpp" +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +typedef duckdb_moodycamel::ConcurrentQueue eviction_queue_t; + +struct EvictionQueue { + eviction_queue_t q; +}; + +bool BufferEvictionNode::CanUnload(BlockHandle &handle_p) { + if (timestamp != handle_p.eviction_timestamp) { + // handle was used in between + return false; + } + return handle_p.CanUnload(); +} + +shared_ptr BufferEvictionNode::TryGetBlockHandle() { + auto handle_p = handle.lock(); + if (!handle_p) { + // BlockHandle has been destroyed + return nullptr; + } + if (!CanUnload(*handle_p)) { + // handle was used in between + return nullptr; + } + // this is the latest node in the queue with this handle + return handle_p; +} + +BufferPool::BufferPool(idx_t maximum_memory) + : current_memory(0), maximum_memory(maximum_memory), queue(make_uniq()), queue_insertions(0) { +} +BufferPool::~BufferPool() { +} + +void BufferPool::AddToEvictionQueue(shared_ptr &handle) { + constexpr int INSERT_INTERVAL = 1024; + + D_ASSERT(handle->readers == 0); + handle->eviction_timestamp++; + // After each 1024 insertions, run through the queue and purge. + if ((++queue_insertions % INSERT_INTERVAL) == 0) { + PurgeQueue(); + } + queue->q.enqueue(BufferEvictionNode(weak_ptr(handle), handle->eviction_timestamp)); +} + +void BufferPool::IncreaseUsedMemory(idx_t size) { + current_memory += size; +} + +idx_t BufferPool::GetUsedMemory() { + return current_memory; +} +idx_t BufferPool::GetMaxMemory() { + return maximum_memory; +} + +BufferPool::EvictionResult BufferPool::EvictBlocks(idx_t extra_memory, idx_t memory_limit, + unique_ptr *buffer) { + BufferEvictionNode node; + TempBufferPoolReservation r(*this, extra_memory); + while (current_memory > memory_limit) { + // get a block to unpin from the queue + if (!queue->q.try_dequeue(node)) { + // Failed to reserve. Adjust size of temp reservation to 0. + r.Resize(0); + return {false, std::move(r)}; + } + // get a reference to the underlying block pointer + auto handle = node.TryGetBlockHandle(); + if (!handle) { + continue; + } + // we might be able to free this block: grab the mutex and check if we can free it + lock_guard lock(handle->lock); + if (!node.CanUnload(*handle)) { + // something changed in the mean-time, bail out + continue; + } + // hooray, we can unload the block + if (buffer && handle->buffer->AllocSize() == extra_memory) { + // we can actually re-use the memory directly! + *buffer = handle->UnloadAndTakeBlock(); + return {true, std::move(r)}; + } else { + // release the memory and mark the block as unloaded + handle->Unload(); + } + } + return {true, std::move(r)}; +} + +void BufferPool::PurgeQueue() { + BufferEvictionNode node; + while (true) { + if (!queue->q.try_dequeue(node)) { + break; + } + auto handle = node.TryGetBlockHandle(); + if (!handle) { + continue; + } else { + queue->q.enqueue(std::move(node)); + break; + } + } +} + +void BufferPool::SetLimit(idx_t limit, const char *exception_postscript) { + lock_guard l_lock(limit_lock); + // try to evict until the limit is reached + if (!EvictBlocks(0, limit).success) { + throw OutOfMemoryException( + "Failed to change memory limit to %lld: could not free up enough memory for the new limit%s", limit, + exception_postscript); + } + idx_t old_limit = maximum_memory; + // set the global maximum memory to the new limit if successful + maximum_memory = limit; + // evict again + if (!EvictBlocks(0, limit).success) { + // failed: go back to old limit + maximum_memory = old_limit; + throw OutOfMemoryException( + "Failed to change memory limit to %lld: could not free up enough memory for the new limit%s", limit, + exception_postscript); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/buffer/buffer_pool_reservation.cpp b/src/duckdb/src/storage/buffer/buffer_pool_reservation.cpp new file mode 100644 index 00000000..312d3bf1 --- /dev/null +++ b/src/duckdb/src/storage/buffer/buffer_pool_reservation.cpp @@ -0,0 +1,35 @@ +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/buffer/buffer_pool.hpp" + +namespace duckdb { + +BufferPoolReservation::BufferPoolReservation(BufferPool &pool) : pool(pool) { +} + +BufferPoolReservation::BufferPoolReservation(BufferPoolReservation &&src) noexcept : pool(src.pool) { + size = src.size; + src.size = 0; +} + +BufferPoolReservation &BufferPoolReservation::operator=(BufferPoolReservation &&src) noexcept { + size = src.size; + src.size = 0; + return *this; +} + +BufferPoolReservation::~BufferPoolReservation() { + D_ASSERT(size == 0); +} + +void BufferPoolReservation::Resize(idx_t new_size) { + int64_t delta = (int64_t)new_size - size; + pool.IncreaseUsedMemory(delta); + size = new_size; +} + +void BufferPoolReservation::Merge(BufferPoolReservation &&src) { + size += src.size; + src.size = 0; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/buffer_manager.cpp b/src/duckdb/src/storage/buffer_manager.cpp new file mode 100644 index 00000000..9cd61aec --- /dev/null +++ b/src/duckdb/src/storage/buffer_manager.cpp @@ -0,0 +1,79 @@ +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_buffer.hpp" +#include "duckdb/storage/standard_buffer_manager.hpp" + +namespace duckdb { + +unique_ptr BufferManager::CreateStandardBufferManager(DatabaseInstance &db, DBConfig &config) { + return make_uniq(db, config.options.temporary_directory); +} + +shared_ptr BufferManager::RegisterSmallMemory(idx_t block_size) { + throw NotImplementedException("This type of BufferManager can not create 'small-memory' blocks"); +} + +Allocator &BufferManager::GetBufferAllocator() { + throw NotImplementedException("This type of BufferManager does not have an Allocator"); +} + +void BufferManager::ReserveMemory(idx_t size) { + throw NotImplementedException("This type of BufferManager can not reserve memory"); +} +void BufferManager::FreeReservedMemory(idx_t size) { + throw NotImplementedException("This type of BufferManager can not free reserved memory"); +} + +void BufferManager::SetLimit(idx_t limit) { + throw NotImplementedException("This type of BufferManager can not set a limit"); +} + +vector BufferManager::GetTemporaryFiles() { + throw InternalException("This type of BufferManager does not allow temporary files"); +} + +const string &BufferManager::GetTemporaryDirectory() { + throw InternalException("This type of BufferManager does not allow a temporary directory"); +} + +BufferPool &BufferManager::GetBufferPool() { + throw InternalException("This type of BufferManager does not have a buffer pool"); +} + +void BufferManager::SetTemporaryDirectory(const string &new_dir) { + throw NotImplementedException("This type of BufferManager can not set a temporary directory"); +} + +DatabaseInstance &BufferManager::GetDatabase() { + throw NotImplementedException("This type of BufferManager is not linked to a DatabaseInstance"); +} + +bool BufferManager::HasTemporaryDirectory() const { + return false; +} + +unique_ptr BufferManager::ConstructManagedBuffer(idx_t size, unique_ptr &&source, + FileBufferType type) { + throw NotImplementedException("This type of BufferManager can not construct managed buffers"); +} + +// Protected methods + +void BufferManager::AddToEvictionQueue(shared_ptr &handle) { + throw NotImplementedException("This type of BufferManager does not support 'AddToEvictionQueue"); +} + +void BufferManager::WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) { + throw NotImplementedException("This type of BufferManager does not support 'WriteTemporaryBuffer"); +} + +unique_ptr BufferManager::ReadTemporaryBuffer(block_id_t id, unique_ptr buffer) { + throw NotImplementedException("This type of BufferManager does not support 'ReadTemporaryBuffer"); +} + +void BufferManager::DeleteTemporaryFile(block_id_t id) { + throw NotImplementedException("This type of BufferManager does not support 'DeleteTemporaryFile"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint/row_group_writer.cpp b/src/duckdb/src/storage/checkpoint/row_group_writer.cpp new file mode 100644 index 00000000..18e4d1dd --- /dev/null +++ b/src/duckdb/src/storage/checkpoint/row_group_writer.cpp @@ -0,0 +1,30 @@ +#include "duckdb/storage/checkpoint/table_data_writer.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" + +namespace duckdb { + +CompressionType RowGroupWriter::GetColumnCompressionType(idx_t i) { + return table.GetColumn(LogicalIndex(i)).CompressionType(); +} + +void RowGroupWriter::RegisterPartialBlock(PartialBlockAllocation &&allocation) { + partial_block_manager.RegisterPartialBlock(std::move(allocation)); +} + +PartialBlockAllocation RowGroupWriter::GetBlockAllocation(uint32_t segment_size) { + return partial_block_manager.GetBlockAllocation(segment_size); +} + +void SingleFileRowGroupWriter::WriteColumnDataPointers(ColumnCheckpointState &column_checkpoint_state, + Serializer &serializer) { + const auto &data_pointers = column_checkpoint_state.data_pointers; + serializer.WriteProperty(100, "data_pointers", data_pointers); +} + +MetadataWriter &SingleFileRowGroupWriter::GetPayloadWriter() { + return table_data_writer; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint/table_data_reader.cpp b/src/duckdb/src/storage/checkpoint/table_data_reader.cpp new file mode 100644 index 00000000..5f76227f --- /dev/null +++ b/src/duckdb/src/storage/checkpoint/table_data_reader.cpp @@ -0,0 +1,32 @@ +#include "duckdb/storage/checkpoint/table_data_reader.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" + +#include "duckdb/main/database.hpp" + +namespace duckdb { + +TableDataReader::TableDataReader(MetadataReader &reader, BoundCreateTableInfo &info) : reader(reader), info(info) { + info.data = make_uniq(info.Base().columns.LogicalColumnCount()); +} + +void TableDataReader::ReadTableData() { + auto &columns = info.Base().columns; + D_ASSERT(!columns.empty()); + + // We stored the table statistics as a unit in FinalizeTable. + BinaryDeserializer stats_deserializer(reader); + stats_deserializer.Begin(); + info.data->table_stats.Deserialize(stats_deserializer, columns); + stats_deserializer.End(); + + // Deserialize the row group pointers (lazily, just set the count and the pointer to them for now) + info.data->row_group_count = reader.Read(); + info.data->block_pointer = reader.GetMetaBlockPointer(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp new file mode 100644 index 00000000..4c7168be --- /dev/null +++ b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp @@ -0,0 +1,78 @@ +#include "duckdb/storage/checkpoint/table_data_writer.hpp" + +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table/table_statistics.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" + +namespace duckdb { + +TableDataWriter::TableDataWriter(TableCatalogEntry &table_p) : table(table_p.Cast()) { + D_ASSERT(table_p.IsDuckTable()); +} + +TableDataWriter::~TableDataWriter() { +} + +void TableDataWriter::WriteTableData(Serializer &metadata_serializer) { + // start scanning the table and append the data to the uncompressed segments + table.GetStorage().Checkpoint(*this, metadata_serializer); +} + +CompressionType TableDataWriter::GetColumnCompressionType(idx_t i) { + return table.GetColumn(LogicalIndex(i)).CompressionType(); +} + +void TableDataWriter::AddRowGroup(RowGroupPointer &&row_group_pointer, unique_ptr &&writer) { + row_group_pointers.push_back(std::move(row_group_pointer)); + writer.reset(); +} + +SingleFileTableDataWriter::SingleFileTableDataWriter(SingleFileCheckpointWriter &checkpoint_manager, + TableCatalogEntry &table, MetadataWriter &table_data_writer) + : TableDataWriter(table), checkpoint_manager(checkpoint_manager), table_data_writer(table_data_writer) { +} + +unique_ptr SingleFileTableDataWriter::GetRowGroupWriter(RowGroup &row_group) { + return make_uniq(table, checkpoint_manager.partial_block_manager, table_data_writer); +} + +void SingleFileTableDataWriter::FinalizeTable(TableStatistics &&global_stats, DataTableInfo *info, + Serializer &metadata_serializer) { + // store the current position in the metadata writer + // this is where the row groups for this table start + auto pointer = table_data_writer.GetMetaBlockPointer(); + + // Serialize statistics as a single unit + BinarySerializer stats_serializer(table_data_writer); + stats_serializer.Begin(); + global_stats.Serialize(stats_serializer); + stats_serializer.End(); + + // now start writing the row group pointers to disk + table_data_writer.Write(row_group_pointers.size()); + idx_t total_rows = 0; + for (auto &row_group_pointer : row_group_pointers) { + auto row_group_count = row_group_pointer.row_start + row_group_pointer.tuple_count; + if (row_group_count > total_rows) { + total_rows = row_group_count; + } + + // Each RowGroup is its own unit + BinarySerializer row_group_serializer(table_data_writer); + row_group_serializer.Begin(); + RowGroup::Serialize(row_group_pointer, row_group_serializer); + row_group_serializer.End(); + } + + auto index_pointers = info->indexes.SerializeIndexes(table_data_writer); + + // Now begin the metadata as a unit + // Pointer to the table itself goes to the metadata stream. + metadata_serializer.WriteProperty(101, "table_pointer", pointer); + metadata_serializer.WriteProperty(102, "total_rows", total_rows); + metadata_serializer.WriteProperty(103, "index_pointers", index_pointers); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint/write_overflow_strings_to_disk.cpp b/src/duckdb/src/storage/checkpoint/write_overflow_strings_to_disk.cpp new file mode 100644 index 00000000..3c0f08df --- /dev/null +++ b/src/duckdb/src/storage/checkpoint/write_overflow_strings_to_disk.cpp @@ -0,0 +1,104 @@ +#include "duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +WriteOverflowStringsToDisk::WriteOverflowStringsToDisk(BlockManager &block_manager) + : block_manager(block_manager), block_id(INVALID_BLOCK), offset(0) { +} + +WriteOverflowStringsToDisk::~WriteOverflowStringsToDisk() { + // verify that the overflow writer has been flushed + D_ASSERT(Exception::UncaughtException() || offset == 0); +} + +shared_ptr UncompressedStringSegmentState::GetHandle(BlockManager &manager, block_id_t block_id) { + lock_guard lock(block_lock); + auto entry = handles.find(block_id); + if (entry != handles.end()) { + return entry->second; + } + auto result = manager.RegisterBlock(block_id); + handles.insert(make_pair(block_id, result)); + return result; +} + +void UncompressedStringSegmentState::RegisterBlock(BlockManager &manager, block_id_t block_id) { + lock_guard lock(block_lock); + auto entry = handles.find(block_id); + if (entry != handles.end()) { + throw InternalException("UncompressedStringSegmentState::RegisterBlock - block id %llu already exists", + block_id); + } + auto result = manager.RegisterBlock(block_id); + handles.insert(make_pair(block_id, std::move(result))); + on_disk_blocks.push_back(block_id); +} + +void WriteOverflowStringsToDisk::WriteString(UncompressedStringSegmentState &state, string_t string, + block_id_t &result_block, int32_t &result_offset) { + auto &buffer_manager = block_manager.buffer_manager; + if (!handle.IsValid()) { + handle = buffer_manager.Allocate(Storage::BLOCK_SIZE); + } + // first write the length of the string + if (block_id == INVALID_BLOCK || offset + 2 * sizeof(uint32_t) >= STRING_SPACE) { + AllocateNewBlock(state, block_manager.GetFreeBlockId()); + } + result_block = block_id; + result_offset = offset; + + // write the length field + auto data_ptr = handle.Ptr(); + auto string_length = string.GetSize(); + Store(string_length, data_ptr + offset); + offset += sizeof(uint32_t); + + // now write the remainder of the string + auto strptr = string.GetData(); + uint32_t remaining = string_length; + while (remaining > 0) { + uint32_t to_write = MinValue(remaining, STRING_SPACE - offset); + if (to_write > 0) { + memcpy(data_ptr + offset, strptr, to_write); + + remaining -= to_write; + offset += to_write; + strptr += to_write; + } + if (remaining > 0) { + D_ASSERT(offset == WriteOverflowStringsToDisk::STRING_SPACE); + // there is still remaining stuff to write + // now write the current block to disk and allocate a new block + AllocateNewBlock(state, block_manager.GetFreeBlockId()); + } + } +} + +void WriteOverflowStringsToDisk::Flush() { + if (block_id != INVALID_BLOCK && offset > 0) { + // zero-initialize the empty part of the overflow string buffer (if any) + if (offset < STRING_SPACE) { + memset(handle.Ptr() + offset, 0, STRING_SPACE - offset); + } + // write to disk + block_manager.Write(handle.GetFileBuffer(), block_id); + } + block_id = INVALID_BLOCK; + offset = 0; +} + +void WriteOverflowStringsToDisk::AllocateNewBlock(UncompressedStringSegmentState &state, block_id_t new_block_id) { + if (block_id != INVALID_BLOCK) { + // there is an old block, write it first + // write the new block id at the end of the previous block + Store(new_block_id, handle.Ptr() + WriteOverflowStringsToDisk::STRING_SPACE); + Flush(); + } + offset = 0; + block_id = new_block_id; + state.RegisterBlock(block_manager, new_block_id); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint_manager.cpp b/src/duckdb/src/storage/checkpoint_manager.cpp new file mode 100644 index 00000000..ffe7515c --- /dev/null +++ b/src/duckdb/src/storage/checkpoint_manager.cpp @@ -0,0 +1,536 @@ +#include "duckdb/storage/checkpoint_manager.hpp" + +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/connection.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_tableref.hpp" +#include "duckdb/planner/expression_binder/index_binder.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/checkpoint/table_data_reader.hpp" +#include "duckdb/storage/checkpoint/table_data_writer.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/transaction/transaction_manager.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" + +namespace duckdb { + +void ReorderTableEntries(catalog_entry_vector_t &tables); + +SingleFileCheckpointWriter::SingleFileCheckpointWriter(AttachedDatabase &db, BlockManager &block_manager) + : CheckpointWriter(db), partial_block_manager(block_manager, CheckpointType::FULL_CHECKPOINT) { +} + +BlockManager &SingleFileCheckpointWriter::GetBlockManager() { + auto &storage_manager = db.GetStorageManager().Cast(); + return *storage_manager.block_manager; +} + +MetadataWriter &SingleFileCheckpointWriter::GetMetadataWriter() { + return *metadata_writer; +} + +MetadataManager &SingleFileCheckpointWriter::GetMetadataManager() { + return GetBlockManager().GetMetadataManager(); +} + +unique_ptr SingleFileCheckpointWriter::GetTableDataWriter(TableCatalogEntry &table) { + return make_uniq(*this, table, *table_metadata_writer); +} + +static catalog_entry_vector_t GetCatalogEntries(vector> &schemas) { + catalog_entry_vector_t entries; + for (auto &schema_p : schemas) { + auto &schema = schema_p.get(); + entries.push_back(schema); + schema.Scan(CatalogType::TYPE_ENTRY, [&](CatalogEntry &entry) { + if (entry.internal) { + return; + } + entries.push_back(entry); + }); + + schema.Scan(CatalogType::SEQUENCE_ENTRY, [&](CatalogEntry &entry) { + if (entry.internal) { + return; + } + entries.push_back(entry); + }); + + catalog_entry_vector_t tables; + vector> views; + schema.Scan(CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { + if (entry.internal) { + return; + } + if (entry.type == CatalogType::TABLE_ENTRY) { + tables.push_back(entry.Cast()); + } else if (entry.type == CatalogType::VIEW_ENTRY) { + views.push_back(entry.Cast()); + } else { + throw NotImplementedException("Catalog type for entries"); + } + }); + // Reorder tables because of foreign key constraint + ReorderTableEntries(tables); + for (auto &table : tables) { + entries.push_back(table.get()); + } + for (auto &view : views) { + entries.push_back(view.get()); + } + + schema.Scan(CatalogType::SCALAR_FUNCTION_ENTRY, [&](CatalogEntry &entry) { + if (entry.internal) { + return; + } + if (entry.type == CatalogType::MACRO_ENTRY) { + entries.push_back(entry); + } + }); + + schema.Scan(CatalogType::TABLE_FUNCTION_ENTRY, [&](CatalogEntry &entry) { + if (entry.internal) { + return; + } + if (entry.type == CatalogType::TABLE_MACRO_ENTRY) { + entries.push_back(entry); + } + }); + + schema.Scan(CatalogType::INDEX_ENTRY, [&](CatalogEntry &entry) { + D_ASSERT(!entry.internal); + entries.push_back(entry); + }); + } + return entries; +} + +void SingleFileCheckpointWriter::CreateCheckpoint() { + auto &config = DBConfig::Get(db); + auto &storage_manager = db.GetStorageManager().Cast(); + if (storage_manager.InMemory()) { + return; + } + // assert that the checkpoint manager hasn't been used before + D_ASSERT(!metadata_writer); + + auto &block_manager = GetBlockManager(); + auto &metadata_manager = GetMetadataManager(); + + //! Set up the writers for the checkpoints + metadata_writer = make_uniq(metadata_manager); + table_metadata_writer = make_uniq(metadata_manager); + + // get the id of the first meta block + auto meta_block = metadata_writer->GetMetaBlockPointer(); + + vector> schemas; + // we scan the set of committed schemas + auto &catalog = Catalog::GetCatalog(db).Cast(); + catalog.ScanSchemas([&](SchemaCatalogEntry &entry) { schemas.push_back(entry); }); + // write the actual data into the database + + // Create a serializer to write the checkpoint data + // The serialized format is roughly: + /* + { + schemas: [ + { + schema: , + custom_types: [ { type: }, ... ], + sequences: [ { sequence: }, ... ], + tables: [ { table: }, ... ], + views: [ { view: }, ... ], + macros: [ { macro: }, ... ], + table_macros: [ { table_macro: }, ... ], + indexes: [ { index: , root_offset }, ... ] + } + ] + } + */ + auto catalog_entries = GetCatalogEntries(schemas); + BinarySerializer serializer(*metadata_writer); + serializer.Begin(); + serializer.WriteList(100, "catalog_entries", catalog_entries.size(), [&](Serializer::List &list, idx_t i) { + auto &entry = catalog_entries[i]; + list.WriteObject([&](Serializer &obj) { WriteEntry(entry.get(), obj); }); + }); + serializer.End(); + + partial_block_manager.FlushPartialBlocks(); + metadata_writer->Flush(); + table_metadata_writer->Flush(); + + // write a checkpoint flag to the WAL + // this protects against the rare event that the database crashes AFTER writing the file, but BEFORE truncating the + // WAL we write an entry CHECKPOINT "meta_block_id" into the WAL upon loading, if we see there is an entry + // CHECKPOINT "meta_block_id", and the id MATCHES the head idin the file we know that the database was successfully + // checkpointed, so we know that we should avoid replaying the WAL to avoid duplicating data + auto wal = storage_manager.GetWriteAheadLog(); + wal->WriteCheckpoint(meta_block); + wal->Flush(); + + if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER) { + throw FatalException("Checkpoint aborted before header write because of PRAGMA checkpoint_abort flag"); + } + + // finally write the updated header + DatabaseHeader header; + header.meta_block = meta_block.block_pointer; + block_manager.WriteHeader(header); + + if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE) { + throw FatalException("Checkpoint aborted before truncate because of PRAGMA checkpoint_abort flag"); + } + + // truncate the file + block_manager.Truncate(); + + // truncate the WAL + wal->Truncate(0); +} + +void CheckpointReader::LoadCheckpoint(ClientContext &context, MetadataReader &reader) { + BinaryDeserializer deserializer(reader); + deserializer.Begin(); + deserializer.ReadList(100, "catalog_entries", [&](Deserializer::List &list, idx_t i) { + return list.ReadObject([&](Deserializer &obj) { ReadEntry(context, obj); }); + }); + deserializer.End(); +} + +MetadataManager &SingleFileCheckpointReader::GetMetadataManager() { + return storage.block_manager->GetMetadataManager(); +} + +void SingleFileCheckpointReader::LoadFromStorage() { + auto &block_manager = *storage.block_manager; + auto &metadata_manager = GetMetadataManager(); + MetaBlockPointer meta_block(block_manager.GetMetaBlock(), 0); + if (!meta_block.IsValid()) { + // storage is empty + return; + } + + Connection con(storage.GetDatabase()); + con.BeginTransaction(); + // create the MetadataReader to read from the storage + MetadataReader reader(metadata_manager, meta_block); + // reader.SetContext(*con.context); + LoadCheckpoint(*con.context, reader); + con.Commit(); +} + +void CheckpointWriter::WriteEntry(CatalogEntry &entry, Serializer &serializer) { + serializer.WriteProperty(99, "catalog_type", entry.type); + + switch (entry.type) { + case CatalogType::SCHEMA_ENTRY: { + auto &schema = entry.Cast(); + WriteSchema(schema, serializer); + break; + } + case CatalogType::TYPE_ENTRY: { + auto &custom_type = entry.Cast(); + WriteType(custom_type, serializer); + break; + } + case CatalogType::SEQUENCE_ENTRY: { + auto &seq = entry.Cast(); + WriteSequence(seq, serializer); + break; + } + case CatalogType::TABLE_ENTRY: { + auto &table = entry.Cast(); + WriteTable(table, serializer); + break; + } + case CatalogType::VIEW_ENTRY: { + auto &view = entry.Cast(); + WriteView(view, serializer); + break; + } + case CatalogType::MACRO_ENTRY: { + auto ¯o = entry.Cast(); + WriteMacro(macro, serializer); + break; + } + case CatalogType::TABLE_MACRO_ENTRY: { + auto ¯o = entry.Cast(); + WriteTableMacro(macro, serializer); + break; + } + case CatalogType::INDEX_ENTRY: { + auto &index = entry.Cast(); + WriteIndex(index, serializer); + break; + } + default: + throw InternalException("Unrecognized catalog type in CheckpointWriter::WriteEntry"); + } +} + +//===--------------------------------------------------------------------===// +// Schema +//===--------------------------------------------------------------------===// +void CheckpointWriter::WriteSchema(SchemaCatalogEntry &schema, Serializer &serializer) { + // write the schema data + serializer.WriteProperty(100, "schema", &schema); +} + +void CheckpointReader::ReadEntry(ClientContext &context, Deserializer &deserializer) { + auto type = deserializer.ReadProperty(99, "type"); + + switch (type) { + case CatalogType::SCHEMA_ENTRY: { + ReadSchema(context, deserializer); + break; + } + case CatalogType::TYPE_ENTRY: { + ReadType(context, deserializer); + break; + } + case CatalogType::SEQUENCE_ENTRY: { + ReadSequence(context, deserializer); + break; + } + case CatalogType::TABLE_ENTRY: { + ReadTable(context, deserializer); + break; + } + case CatalogType::VIEW_ENTRY: { + ReadView(context, deserializer); + break; + } + case CatalogType::MACRO_ENTRY: { + ReadMacro(context, deserializer); + break; + } + case CatalogType::TABLE_MACRO_ENTRY: { + ReadTableMacro(context, deserializer); + break; + } + case CatalogType::INDEX_ENTRY: { + ReadIndex(context, deserializer); + break; + } + default: + throw InternalException("Unrecognized catalog type in CheckpointWriter::WriteEntry"); + } +} + +void CheckpointReader::ReadSchema(ClientContext &context, Deserializer &deserializer) { + // Read the schema and create it in the catalog + auto info = deserializer.ReadProperty>(100, "schema"); + auto &schema_info = info->Cast(); + + // we set create conflict to IGNORE_ON_CONFLICT, so that we can ignore a failure when recreating the main schema + schema_info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; + catalog.CreateSchema(context, schema_info); +} + +//===--------------------------------------------------------------------===// +// Views +//===--------------------------------------------------------------------===// +void CheckpointWriter::WriteView(ViewCatalogEntry &view, Serializer &serializer) { + serializer.WriteProperty(100, "view", &view); +} + +void CheckpointReader::ReadView(ClientContext &context, Deserializer &deserializer) { + auto info = deserializer.ReadProperty>(100, "view"); + auto &view_info = info->Cast(); + catalog.CreateView(context, view_info); +} + +//===--------------------------------------------------------------------===// +// Sequences +//===--------------------------------------------------------------------===// +void CheckpointWriter::WriteSequence(SequenceCatalogEntry &seq, Serializer &serializer) { + serializer.WriteProperty(100, "sequence", &seq); +} + +void CheckpointReader::ReadSequence(ClientContext &context, Deserializer &deserializer) { + auto info = deserializer.ReadProperty>(100, "sequence"); + auto &sequence_info = info->Cast(); + catalog.CreateSequence(context, sequence_info); +} + +//===--------------------------------------------------------------------===// +// Indexes +//===--------------------------------------------------------------------===// +void CheckpointWriter::WriteIndex(IndexCatalogEntry &index_catalog, Serializer &serializer) { + // The index data is written as part of WriteTableData. + // Here, we need only serialize the pointer to that data. + auto root_block_pointer = index_catalog.index->GetRootBlockPointer(); + serializer.WriteProperty(100, "index", &index_catalog); + serializer.WriteProperty(101, "root_block_pointer", root_block_pointer); +} + +void CheckpointReader::ReadIndex(ClientContext &context, Deserializer &deserializer) { + + // deserialize the index create info + auto create_info = deserializer.ReadProperty>(100, "index"); + auto &info = create_info->Cast(); + + // create the index in the catalog + auto &schema = catalog.GetSchema(context, create_info->schema); + auto &table = + catalog.GetEntry(context, CatalogType::TABLE_ENTRY, create_info->schema, info.table).Cast(); + + auto &index = schema.CreateIndex(context, info, table)->Cast(); + + index.info = table.GetStorage().info; + // insert the parsed expressions into the stored index so that we correctly (de)serialize it during consecutive + // checkpoints + for (auto &parsed_expr : info.parsed_expressions) { + index.parsed_expressions.push_back(parsed_expr->Copy()); + } + + // we deserialize the index lazily, i.e., we do not need to load any node information + // except the root block pointer + auto root_block_pointer = deserializer.ReadProperty(101, "root_block_pointer"); + + // obtain the parsed expressions of the ART from the index metadata + vector> parsed_expressions; + for (auto &parsed_expr : info.parsed_expressions) { + parsed_expressions.push_back(parsed_expr->Copy()); + } + D_ASSERT(!parsed_expressions.empty()); + + // add the table to the bind context to bind the parsed expressions + auto binder = Binder::CreateBinder(context); + vector column_types; + vector column_names; + for (auto &col : table.GetColumns().Logical()) { + column_types.push_back(col.Type()); + column_names.push_back(col.Name()); + } + + // create a binder to bind the parsed expressions + vector column_ids; + binder->bind_context.AddBaseTable(0, info.table, column_names, column_types, column_ids, &table); + IndexBinder idx_binder(*binder, context); + + // bind the parsed expressions to create unbound expressions + vector> unbound_expressions; + unbound_expressions.reserve(parsed_expressions.size()); + for (auto &expr : parsed_expressions) { + unbound_expressions.push_back(idx_binder.Bind(expr)); + } + + // create the index and add it to the storage + switch (info.index_type) { + case IndexType::ART: { + auto &storage = table.GetStorage(); + auto art = make_uniq(info.column_ids, TableIOManager::Get(storage), std::move(unbound_expressions), + info.constraint_type, storage.db, nullptr, root_block_pointer); + + index.index = art.get(); + storage.info->indexes.AddIndex(std::move(art)); + } break; + default: + throw InternalException("Unknown index type for ReadIndex"); + } +} + +//===--------------------------------------------------------------------===// +// Custom Types +//===--------------------------------------------------------------------===// +void CheckpointWriter::WriteType(TypeCatalogEntry &type, Serializer &serializer) { + serializer.WriteProperty(100, "type", &type); +} + +void CheckpointReader::ReadType(ClientContext &context, Deserializer &deserializer) { + auto info = deserializer.ReadProperty>(100, "type"); + auto &type_info = info->Cast(); + catalog.CreateType(context, type_info); +} + +//===--------------------------------------------------------------------===// +// Macro's +//===--------------------------------------------------------------------===// +void CheckpointWriter::WriteMacro(ScalarMacroCatalogEntry ¯o, Serializer &serializer) { + serializer.WriteProperty(100, "macro", ¯o); +} + +void CheckpointReader::ReadMacro(ClientContext &context, Deserializer &deserializer) { + auto info = deserializer.ReadProperty>(100, "macro"); + auto ¯o_info = info->Cast(); + catalog.CreateFunction(context, macro_info); +} + +void CheckpointWriter::WriteTableMacro(TableMacroCatalogEntry ¯o, Serializer &serializer) { + serializer.WriteProperty(100, "table_macro", ¯o); +} + +void CheckpointReader::ReadTableMacro(ClientContext &context, Deserializer &deserializer) { + auto info = deserializer.ReadProperty>(100, "table_macro"); + auto ¯o_info = info->Cast(); + catalog.CreateFunction(context, macro_info); +} + +//===--------------------------------------------------------------------===// +// Table Metadata +//===--------------------------------------------------------------------===// +void CheckpointWriter::WriteTable(TableCatalogEntry &table, Serializer &serializer) { + // Write the table meta data + serializer.WriteProperty(100, "table", &table); + + // Write the table data + if (auto writer = GetTableDataWriter(table)) { + writer->WriteTableData(serializer); + } +} + +void CheckpointReader::ReadTable(ClientContext &context, Deserializer &deserializer) { + // deserialize the table meta data + auto info = deserializer.ReadProperty>(100, "table"); + auto binder = Binder::CreateBinder(context); + auto &schema = catalog.GetSchema(context, info->schema); + auto bound_info = binder->BindCreateTableInfo(std::move(info), schema); + + // now read the actual table data and place it into the create table info + ReadTableData(context, deserializer, *bound_info); + + // finally create the table in the catalog + catalog.CreateTable(context, *bound_info); +} + +void CheckpointReader::ReadTableData(ClientContext &context, Deserializer &deserializer, + BoundCreateTableInfo &bound_info) { + + // This is written in "SingleFileTableDataWriter::FinalizeTable" + auto table_pointer = deserializer.ReadProperty(101, "table_pointer"); + auto total_rows = deserializer.ReadProperty(102, "total_rows"); + auto index_pointers = deserializer.ReadProperty>(103, "index_pointers"); + + // FIXME: icky downcast to get the underlying MetadataReader + auto &binary_deserializer = dynamic_cast(deserializer); + auto &reader = dynamic_cast(binary_deserializer.GetStream()); + + MetadataReader table_data_reader(reader.GetMetadataManager(), table_pointer); + TableDataReader data_reader(table_data_reader, bound_info); + data_reader.ReadTableData(); + + bound_info.data->total_rows = total_rows; + bound_info.indexes = index_pointers; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/bitpacking.cpp b/src/duckdb/src/storage/compression/bitpacking.cpp new file mode 100644 index 00000000..ffc79432 --- /dev/null +++ b/src/duckdb/src/storage/compression/bitpacking.cpp @@ -0,0 +1,966 @@ +#include "duckdb/common/bitpacking.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/storage/compression/bitpacking.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/common/numeric_utils.hpp" + +#include + +namespace duckdb { + +static constexpr const idx_t BITPACKING_METADATA_GROUP_SIZE = STANDARD_VECTOR_SIZE > 512 ? STANDARD_VECTOR_SIZE : 2048; + +BitpackingMode BitpackingModeFromString(const string &str) { + auto mode = StringUtil::Lower(str); + if (mode == "auto" || mode == "none") { + return BitpackingMode::AUTO; + } else if (mode == "constant") { + return BitpackingMode::CONSTANT; + } else if (mode == "constant_delta") { + return BitpackingMode::CONSTANT_DELTA; + } else if (mode == "delta_for") { + return BitpackingMode::DELTA_FOR; + } else if (mode == "for") { + return BitpackingMode::FOR; + } else { + return BitpackingMode::INVALID; + } +} + +string BitpackingModeToString(const BitpackingMode &mode) { + switch (mode) { + case BitpackingMode::AUTO: + return "auto"; + case BitpackingMode::CONSTANT: + return "constant"; + case BitpackingMode::CONSTANT_DELTA: + return "constant_delta"; + case BitpackingMode::DELTA_FOR: + return "delta_for"; + case BitpackingMode::FOR: + return "for"; + default: + throw NotImplementedException("Unknown bitpacking mode: " + to_string((uint8_t)mode) + "\n"); + } +} + +typedef struct { + BitpackingMode mode; + uint32_t offset; +} bitpacking_metadata_t; + +typedef uint32_t bitpacking_metadata_encoded_t; + +static bitpacking_metadata_encoded_t EncodeMeta(bitpacking_metadata_t metadata) { + D_ASSERT(metadata.offset <= 16777215); // max uint24_t + bitpacking_metadata_encoded_t encoded_value = metadata.offset; + encoded_value |= (uint8_t)metadata.mode << 24; + return encoded_value; +} +static bitpacking_metadata_t DecodeMeta(bitpacking_metadata_encoded_t *metadata_encoded) { + bitpacking_metadata_t metadata; + metadata.mode = Load(data_ptr_cast(metadata_encoded) + 3); + metadata.offset = *metadata_encoded & 0x00FFFFFF; + return metadata; +} + +struct EmptyBitpackingWriter { + template + static void WriteConstant(T constant, idx_t count, void *data_ptr, bool all_invalid) { + } + template ::type> + static void WriteConstantDelta(T_S constant, T frame_of_reference, idx_t count, T *values, bool *validity, + void *data_ptr) { + } + template ::type> + static void WriteDeltaFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, + T_S delta_offset, T *original_values, idx_t count, void *data_ptr) { + } + template + static void WriteFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, idx_t count, + void *data_ptr) { + } +}; + +template ::type> +struct BitpackingState { +public: + BitpackingState() : compression_buffer_idx(0), total_size(0), data_ptr(nullptr) { + compression_buffer_internal[0] = T(0); + compression_buffer = &compression_buffer_internal[1]; + Reset(); + } + + // Extra val for delta encoding + T compression_buffer_internal[BITPACKING_METADATA_GROUP_SIZE + 1]; + T *compression_buffer; + T_S delta_buffer[BITPACKING_METADATA_GROUP_SIZE]; + bool compression_buffer_validity[BITPACKING_METADATA_GROUP_SIZE]; + idx_t compression_buffer_idx; + idx_t total_size; + + // Used to pass CompressionState ptr through the Bitpacking writer + void *data_ptr; + + // Stats on current compression buffer + T minimum; + T maximum; + T min_max_diff; + T_S minimum_delta; + T_S maximum_delta; + T_S min_max_delta_diff; + T_S delta_offset; + bool all_valid; + bool all_invalid; + + bool can_do_delta; + bool can_do_for; + + // Used to force a specific mode, useful in testing + BitpackingMode mode = BitpackingMode::AUTO; + +public: + void Reset() { + minimum = NumericLimits::Maximum(); + minimum_delta = NumericLimits::Maximum(); + maximum = NumericLimits::Minimum(); + maximum_delta = NumericLimits::Minimum(); + delta_offset = 0; + all_valid = true; + all_invalid = true; + can_do_delta = false; + can_do_for = false; + compression_buffer_idx = 0; + min_max_diff = 0; + min_max_delta_diff = 0; + } + + void CalculateFORStats() { + can_do_for = TrySubtractOperator::Operation(maximum, minimum, min_max_diff); + } + + void CalculateDeltaStats() { + // TODO: currently we dont support delta compression of values above NumericLimits::Maximum(), + // we could support this with some clever substract trickery? + if (maximum > static_cast(NumericLimits::Maximum())) { + return; + } + + // Don't delta encoding 1 value makes no sense + if (compression_buffer_idx < 2) { + return; + } + + // TODO: handle NULLS here? + // Currently we cannot handle nulls because we would need an additional step of patching for this. + // we could for example copy the last value on a null insert. This would help a bit, but not be optimal for + // large deltas since theres suddenly a zero then. Ideally we would insert a value that leads to a delta within + // the current domain of deltas however we dont know that domain here yet + if (!all_valid) { + return; + } + + // Note: since we dont allow any values over NumericLimits::Maximum(), all subtractions for unsigned types + // are guaranteed not to overflow + bool can_do_all = true; + if (NumericLimits::IsSigned()) { + T_S bogus; + can_do_all = TrySubtractOperator::Operation(static_cast(minimum), static_cast(maximum), bogus) && + TrySubtractOperator::Operation(static_cast(maximum), static_cast(minimum), bogus); + } + + // Calculate delta's + // compression_buffer pointer points one element ahead of the internal buffer making the use of signed index + // integer (-1) possible + D_ASSERT(compression_buffer_idx <= NumericLimits::Maximum()); + if (can_do_all) { + for (int64_t i = 0; i < static_cast(compression_buffer_idx); i++) { + delta_buffer[i] = static_cast(compression_buffer[i]) - static_cast(compression_buffer[i - 1]); + } + } else { + for (int64_t i = 0; i < static_cast(compression_buffer_idx); i++) { + auto success = + TrySubtractOperator::Operation(static_cast(compression_buffer[i]), + static_cast(compression_buffer[i - 1]), delta_buffer[i]); + if (!success) { + return; + } + } + } + + can_do_delta = true; + + for (idx_t i = 1; i < compression_buffer_idx; i++) { + maximum_delta = MaxValue(maximum_delta, delta_buffer[i]); + minimum_delta = MinValue(minimum_delta, delta_buffer[i]); + } + + // Since we can set the first value arbitrarily, we want to pick one from the current domain, note that + // we will store the original first value - this offset as the delta_offset to be able to decode this again. + delta_buffer[0] = minimum_delta; + + can_do_delta = can_do_delta && TrySubtractOperator::Operation(maximum_delta, minimum_delta, min_max_delta_diff); + can_do_delta = can_do_delta && TrySubtractOperator::Operation(static_cast(compression_buffer[0]), + minimum_delta, delta_offset); + } + + template + void SubtractFrameOfReference(T_INNER *buffer, T_INNER frame_of_reference) { + static_assert(IsIntegral::value, "Integral type required."); + for (idx_t i = 0; i < compression_buffer_idx; i++) { + buffer[i] -= static_cast::type>(frame_of_reference); + } + } + + template + bool Flush() { + if (compression_buffer_idx == 0) { + return true; + } + + if ((all_invalid || maximum == minimum) && (mode == BitpackingMode::AUTO || mode == BitpackingMode::CONSTANT)) { + OP::WriteConstant(maximum, compression_buffer_idx, data_ptr, all_invalid); + total_size += sizeof(T) + sizeof(bitpacking_metadata_encoded_t); + return true; + } + + CalculateFORStats(); + CalculateDeltaStats(); + + if (can_do_delta) { + if (maximum_delta == minimum_delta && mode != BitpackingMode::FOR && mode != BitpackingMode::DELTA_FOR) { + // FOR needs to be T (considering hugeint is bigger than idx_t) + T frame_of_reference = compression_buffer[0]; + + OP::WriteConstantDelta(maximum_delta, static_cast(frame_of_reference), compression_buffer_idx, + compression_buffer, compression_buffer_validity, data_ptr); + total_size += sizeof(T) + sizeof(T) + sizeof(bitpacking_metadata_encoded_t); + return true; + } + + // Check if delta has benefit + // bitwidth is calculated differently between signed and unsigned values, but considering we do not have + // an unsigned version of hugeint, we need to explicitly specify (through boolean) that we wish to calculate + // the unsigned minimum bit-width instead of relying on MakeUnsigned and IsSigned + auto delta_required_bitwidth = BitpackingPrimitives::MinimumBitWidth(min_max_delta_diff); + auto regular_required_bitwidth = BitpackingPrimitives::MinimumBitWidth(min_max_diff); + + if (delta_required_bitwidth < regular_required_bitwidth && mode != BitpackingMode::FOR) { + SubtractFrameOfReference(delta_buffer, minimum_delta); + + OP::WriteDeltaFor(reinterpret_cast(delta_buffer), compression_buffer_validity, + delta_required_bitwidth, static_cast(minimum_delta), delta_offset, + compression_buffer, compression_buffer_idx, data_ptr); + + total_size += BitpackingPrimitives::GetRequiredSize(compression_buffer_idx, delta_required_bitwidth); + total_size += sizeof(T); // FOR value + total_size += sizeof(T); // Delta offset value + total_size += AlignValue(sizeof(bitpacking_width_t)); // FOR value + + return true; + } + } + + if (can_do_for) { + auto width = BitpackingPrimitives::MinimumBitWidth(min_max_diff); + SubtractFrameOfReference(compression_buffer, minimum); + OP::WriteFor(compression_buffer, compression_buffer_validity, width, minimum, compression_buffer_idx, + data_ptr); + + total_size += BitpackingPrimitives::GetRequiredSize(compression_buffer_idx, width); + total_size += sizeof(T); // FOR value + total_size += AlignValue(sizeof(bitpacking_width_t)); + + return true; + } + + return false; + } + + template + bool Update(T value, bool is_valid) { + compression_buffer_validity[compression_buffer_idx] = is_valid; + all_valid = all_valid && is_valid; + all_invalid = all_invalid && !is_valid; + + if (is_valid) { + compression_buffer[compression_buffer_idx] = value; + minimum = MinValue(minimum, value); + maximum = MaxValue(maximum, value); + } + + compression_buffer_idx++; + + if (compression_buffer_idx == BITPACKING_METADATA_GROUP_SIZE) { + bool success = Flush(); + Reset(); + return success; + } + return true; + } +}; + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +template +struct BitpackingAnalyzeState : public AnalyzeState { + BitpackingState state; +}; + +template +unique_ptr BitpackingInitAnalyze(ColumnData &col_data, PhysicalType type) { + auto &config = DBConfig::GetConfig(col_data.GetDatabase()); + + auto state = make_uniq>(); + state->state.mode = config.options.force_bitpacking_mode; + + return std::move(state); +} + +template +bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { + auto &analyze_state = static_cast &>(state); + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + if (!analyze_state.state.template Update(data[idx], vdata.validity.RowIsValid(idx))) { + return false; + } + } + return true; +} + +template +idx_t BitpackingFinalAnalyze(AnalyzeState &state) { + auto &bitpacking_state = static_cast &>(state); + auto flush_result = bitpacking_state.state.template Flush(); + if (!flush_result) { + return DConstants::INVALID_INDEX; + } + return bitpacking_state.state.total_size; +} + +//===--------------------------------------------------------------------===// +// Compress +//===--------------------------------------------------------------------===// +template ::type> +struct BitpackingCompressState : public CompressionState { +public: + explicit BitpackingCompressState(ColumnDataCheckpointer &checkpointer) + : checkpointer(checkpointer), + function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_BITPACKING)) { + CreateEmptySegment(checkpointer.GetRowGroup().start); + + state.data_ptr = reinterpret_cast(this); + + auto &config = DBConfig::GetConfig(checkpointer.GetDatabase()); + state.mode = config.options.force_bitpacking_mode; + } + + ColumnDataCheckpointer &checkpointer; + CompressionFunction &function; + unique_ptr current_segment; + BufferHandle handle; + + // Ptr to next free spot in segment; + data_ptr_t data_ptr; + // Ptr to next free spot for storing bitwidths and frame-of-references (growing downwards). + data_ptr_t metadata_ptr; + + BitpackingState state; + +public: + struct BitpackingWriter { + static void WriteConstant(T constant, idx_t count, void *data_ptr, bool all_invalid) { + auto state = reinterpret_cast *>(data_ptr); + + ReserveSpace(state, sizeof(T)); + WriteMetaData(state, BitpackingMode::CONSTANT); + WriteData(state->data_ptr, constant); + + UpdateStats(state, count); + } + + static void WriteConstantDelta(T_S constant, T frame_of_reference, idx_t count, T *values, bool *validity, + void *data_ptr) { + auto state = reinterpret_cast *>(data_ptr); + + ReserveSpace(state, 2 * sizeof(T)); + WriteMetaData(state, BitpackingMode::CONSTANT_DELTA); + WriteData(state->data_ptr, frame_of_reference); + WriteData(state->data_ptr, constant); + + UpdateStats(state, count); + } + static void WriteDeltaFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, + T_S delta_offset, T *original_values, idx_t count, void *data_ptr) { + auto state = reinterpret_cast *>(data_ptr); + + auto bp_size = BitpackingPrimitives::GetRequiredSize(count, width); + ReserveSpace(state, bp_size + 3 * sizeof(T)); + + WriteMetaData(state, BitpackingMode::DELTA_FOR); + WriteData(state->data_ptr, frame_of_reference); + WriteData(state->data_ptr, static_cast(width)); + WriteData(state->data_ptr, delta_offset); + + BitpackingPrimitives::PackBuffer(state->data_ptr, values, count, width); + state->data_ptr += bp_size; + + UpdateStats(state, count); + } + + static void WriteFor(T *values, bool *validity, bitpacking_width_t width, T frame_of_reference, idx_t count, + void *data_ptr) { + auto state = reinterpret_cast *>(data_ptr); + + auto bp_size = BitpackingPrimitives::GetRequiredSize(count, width); + ReserveSpace(state, bp_size + 2 * sizeof(T)); + + WriteMetaData(state, BitpackingMode::FOR); + WriteData(state->data_ptr, frame_of_reference); + WriteData(state->data_ptr, (T)width); + + BitpackingPrimitives::PackBuffer(state->data_ptr, values, count, width); + state->data_ptr += bp_size; + + UpdateStats(state, count); + } + + template + static void WriteData(data_ptr_t &ptr, T_OUT val) { + *reinterpret_cast(ptr) = val; + ptr += sizeof(T_OUT); + } + + static void WriteMetaData(BitpackingCompressState *state, BitpackingMode mode) { + bitpacking_metadata_t metadata {mode, (uint32_t)(state->data_ptr - state->handle.Ptr())}; + state->metadata_ptr -= sizeof(bitpacking_metadata_encoded_t); + Store(EncodeMeta(metadata), state->metadata_ptr); + } + + static void ReserveSpace(BitpackingCompressState *state, idx_t data_bytes) { + idx_t meta_bytes = sizeof(bitpacking_metadata_encoded_t); + state->FlushAndCreateSegmentIfFull(data_bytes, meta_bytes); + D_ASSERT(state->CanStore(data_bytes, meta_bytes)); + } + + static void UpdateStats(BitpackingCompressState *state, idx_t count) { + state->current_segment->count += count; + + if (WRITE_STATISTICS && !state->state.all_invalid) { + NumericStats::Update(state->current_segment->stats.statistics, state->state.minimum); + NumericStats::Update(state->current_segment->stats.statistics, state->state.maximum); + } + } + }; + + bool CanStore(idx_t data_bytes, idx_t meta_bytes) { + auto required_data_bytes = AlignValue((data_ptr + data_bytes) - data_ptr); + auto required_meta_bytes = Storage::BLOCK_SIZE - (metadata_ptr - data_ptr) + meta_bytes; + + return required_data_bytes + required_meta_bytes <= + Storage::BLOCK_SIZE - BitpackingPrimitives::BITPACKING_HEADER_SIZE; + } + + void CreateEmptySegment(idx_t row_start) { + auto &db = checkpointer.GetDatabase(); + auto &type = checkpointer.GetType(); + auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); + compressed_segment->function = function; + current_segment = std::move(compressed_segment); + auto &buffer_manager = BufferManager::GetBufferManager(db); + handle = buffer_manager.Pin(current_segment->block); + + data_ptr = handle.Ptr() + BitpackingPrimitives::BITPACKING_HEADER_SIZE; + metadata_ptr = handle.Ptr() + Storage::BLOCK_SIZE; + } + + void Append(UnifiedVectorFormat &vdata, idx_t count) { + auto data = UnifiedVectorFormat::GetData(vdata); + + for (idx_t i = 0; i < count; i++) { + idx_t idx = vdata.sel->get_index(i); + state.template Update::BitpackingWriter>( + data[idx], vdata.validity.RowIsValid(idx)); + } + } + + void FlushAndCreateSegmentIfFull(idx_t required_data_bytes, idx_t required_meta_bytes) { + if (!CanStore(required_data_bytes, required_meta_bytes)) { + idx_t row_start = current_segment->start + current_segment->count; + FlushSegment(); + CreateEmptySegment(row_start); + } + } + + void FlushSegment() { + auto &state = checkpointer.GetCheckpointState(); + auto base_ptr = handle.Ptr(); + + // Compact the segment by moving the metadata next to the data. + idx_t metadata_offset = AlignValue(data_ptr - base_ptr); + idx_t metadata_size = base_ptr + Storage::BLOCK_SIZE - metadata_ptr; + idx_t total_segment_size = metadata_offset + metadata_size; + + // Asserting things are still sane here + if (!CanStore(0, 0)) { + throw InternalException("Error in bitpacking size calculation"); + } + + memmove(base_ptr + metadata_offset, metadata_ptr, metadata_size); + + // Store the offset of the metadata of the first group (which is at the highest address). + Store(metadata_offset + metadata_size, base_ptr); + handle.Destroy(); + + state.FlushSegment(std::move(current_segment), total_segment_size); + } + + void Finalize() { + state.template Flush::BitpackingWriter>(); + FlushSegment(); + current_segment.reset(); + } +}; + +template +unique_ptr BitpackingInitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr state) { + return make_uniq>(checkpointer); +} + +template +void BitpackingCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { + auto &state = static_cast &>(state_p); + UnifiedVectorFormat vdata; + scan_vector.ToUnifiedFormat(count, vdata); + state.Append(vdata, count); +} + +template +void BitpackingFinalizeCompress(CompressionState &state_p) { + auto &state = static_cast &>(state_p); + state.Finalize(); +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +template +static void ApplyFrameOfReference(T *dst, T frame_of_reference, idx_t size) { + if (!frame_of_reference) { + return; + } + for (idx_t i = 0; i < size; i++) { + dst[i] += frame_of_reference; + } +} + +// Based on https://github.com/lemire/FastPFor (Apache License 2.0) +template +static T DeltaDecode(T *data, T previous_value, const size_t size) { + D_ASSERT(size >= 1); + + data[0] += previous_value; + + const size_t UnrollQty = 4; + const size_t sz0 = (size / UnrollQty) * UnrollQty; // equal to 0, if size < UnrollQty + size_t i = 1; + if (sz0 >= UnrollQty) { + T a = data[0]; + for (; i < sz0 - UnrollQty; i += UnrollQty) { + a = data[i] += a; + a = data[i + 1] += a; + a = data[i + 2] += a; + a = data[i + 3] += a; + } + } + for (; i != size; ++i) { + data[i] += data[i - 1]; + } + + return data[size - 1]; +} + +template ::type> +struct BitpackingScanState : public SegmentScanState { +public: + explicit BitpackingScanState(ColumnSegment &segment) : current_segment(segment) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + handle = buffer_manager.Pin(segment.block); + auto dataptr = handle.Ptr(); + + // load offset to bitpacking widths pointer + auto bitpacking_metadata_offset = Load(dataptr + segment.GetBlockOffset()); + bitpacking_metadata_ptr = + dataptr + segment.GetBlockOffset() + bitpacking_metadata_offset - sizeof(bitpacking_metadata_encoded_t); + + // load the first group + LoadNextGroup(); + } + + BufferHandle handle; + ColumnSegment ¤t_segment; + + T decompression_buffer[BITPACKING_METADATA_GROUP_SIZE]; + + bitpacking_metadata_t current_group; + + bitpacking_width_t current_width; + T current_frame_of_reference; + T current_constant; + T current_delta_offset; + + idx_t current_group_offset = 0; + data_ptr_t current_group_ptr; + data_ptr_t bitpacking_metadata_ptr; + +public: + //! Loads the metadata for the current metadata group. This will set bitpacking_metadata_ptr to the next group. + //! this will also load any metadata that is at the start of a compressed buffer (e.g. the width, for, or constant + //! value) depending on the bitpacking mode for that group + void LoadNextGroup() { + D_ASSERT(bitpacking_metadata_ptr > handle.Ptr() && + bitpacking_metadata_ptr < handle.Ptr() + Storage::BLOCK_SIZE); + current_group_offset = 0; + current_group = DecodeMeta(reinterpret_cast(bitpacking_metadata_ptr)); + + bitpacking_metadata_ptr -= sizeof(bitpacking_metadata_encoded_t); + current_group_ptr = GetPtr(current_group); + + // Read first value + switch (current_group.mode) { + case BitpackingMode::CONSTANT: + current_constant = *reinterpret_cast(current_group_ptr); + current_group_ptr += sizeof(T); + break; + case BitpackingMode::FOR: + case BitpackingMode::CONSTANT_DELTA: + case BitpackingMode::DELTA_FOR: + current_frame_of_reference = *reinterpret_cast(current_group_ptr); + current_group_ptr += sizeof(T); + break; + default: + throw InternalException("Invalid bitpacking mode"); + } + + // Read second value + switch (current_group.mode) { + case BitpackingMode::CONSTANT_DELTA: + current_constant = *reinterpret_cast(current_group_ptr); + current_group_ptr += sizeof(T); + break; + case BitpackingMode::FOR: + case BitpackingMode::DELTA_FOR: + current_width = (bitpacking_width_t)(*reinterpret_cast(current_group_ptr)); + current_group_ptr += MaxValue(sizeof(T), sizeof(bitpacking_width_t)); + break; + case BitpackingMode::CONSTANT: + break; + default: + throw InternalException("Invalid bitpacking mode"); + } + + // Read third value + if (current_group.mode == BitpackingMode::DELTA_FOR) { + current_delta_offset = *reinterpret_cast(current_group_ptr); + current_group_ptr += sizeof(T); + } + } + + void Skip(ColumnSegment &segment, idx_t skip_count) { + bool skip_sign_extend = true; + + idx_t skipped = 0; + while (skipped < skip_count) { + // Exhausted this metadata group, move pointers to next group and load metadata for next group. + if (current_group_offset >= BITPACKING_METADATA_GROUP_SIZE) { + LoadNextGroup(); + } + + idx_t offset_in_compression_group = + current_group_offset % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; + + if (current_group.mode == BitpackingMode::CONSTANT) { + idx_t remaining = skip_count - skipped; + idx_t to_skip = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - current_group_offset); + skipped += to_skip; + current_group_offset += to_skip; + continue; + } + if (current_group.mode == BitpackingMode::CONSTANT_DELTA) { + idx_t remaining = skip_count - skipped; + idx_t to_skip = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - current_group_offset); + skipped += to_skip; + current_group_offset += to_skip; + continue; + } + D_ASSERT(current_group.mode == BitpackingMode::FOR || current_group.mode == BitpackingMode::DELTA_FOR); + + idx_t to_skip = + MinValue(skip_count - skipped, + BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - offset_in_compression_group); + // Calculate start of compression algorithm group + if (current_group.mode == BitpackingMode::DELTA_FOR) { + data_ptr_t current_position_ptr = current_group_ptr + current_group_offset * current_width / 8; + data_ptr_t decompression_group_start_pointer = + current_position_ptr - offset_in_compression_group * current_width / 8; + + BitpackingPrimitives::UnPackBlock(data_ptr_cast(decompression_buffer), + decompression_group_start_pointer, current_width, + skip_sign_extend); + + T *decompression_ptr = decompression_buffer + offset_in_compression_group; + ApplyFrameOfReference(reinterpret_cast(decompression_ptr), + static_cast(current_frame_of_reference), to_skip); + DeltaDecode(reinterpret_cast(decompression_ptr), static_cast(current_delta_offset), + to_skip); + current_delta_offset = decompression_ptr[to_skip - 1]; + } + + skipped += to_skip; + current_group_offset += to_skip; + } + } + + data_ptr_t GetPtr(bitpacking_metadata_t group) { + return handle.Ptr() + current_segment.GetBlockOffset() + group.offset; + } +}; + +template +unique_ptr BitpackingInitScan(ColumnSegment &segment) { + auto result = make_uniq>(segment); + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +template ::type> +void BitpackingScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + auto &scan_state = static_cast &>(*state.scan_state); + + T *result_data = FlatVector::GetData(result); + result.SetVectorType(VectorType::FLAT_VECTOR); + + //! Because FOR offsets all our values to be 0 or above, we can always skip sign extension here + bool skip_sign_extend = true; + + idx_t scanned = 0; + while (scanned < scan_count) { + // Exhausted this metadata group, move pointers to next group and load metadata for next group. + if (scan_state.current_group_offset >= BITPACKING_METADATA_GROUP_SIZE) { + scan_state.LoadNextGroup(); + } + + idx_t offset_in_compression_group = + scan_state.current_group_offset % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; + + if (scan_state.current_group.mode == BitpackingMode::CONSTANT) { + idx_t remaining = scan_count - scanned; + idx_t to_scan = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - scan_state.current_group_offset); + T *begin = result_data + result_offset + scanned; + T *end = begin + remaining; + std::fill(begin, end, scan_state.current_constant); + scanned += to_scan; + scan_state.current_group_offset += to_scan; + continue; + } + if (scan_state.current_group.mode == BitpackingMode::CONSTANT_DELTA) { + idx_t remaining = scan_count - scanned; + idx_t to_scan = MinValue(remaining, BITPACKING_METADATA_GROUP_SIZE - scan_state.current_group_offset); + T *target_ptr = result_data + result_offset + scanned; + + for (idx_t i = 0; i < to_scan; i++) { + target_ptr[i] = (static_cast(scan_state.current_group_offset + i) * scan_state.current_constant) + + scan_state.current_frame_of_reference; + } + + scanned += to_scan; + scan_state.current_group_offset += to_scan; + continue; + } + D_ASSERT(scan_state.current_group.mode == BitpackingMode::FOR || + scan_state.current_group.mode == BitpackingMode::DELTA_FOR); + + idx_t to_scan = MinValue(scan_count - scanned, BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - + offset_in_compression_group); + // Calculate start of compression algorithm group + data_ptr_t current_position_ptr = + scan_state.current_group_ptr + scan_state.current_group_offset * scan_state.current_width / 8; + data_ptr_t decompression_group_start_pointer = + current_position_ptr - offset_in_compression_group * scan_state.current_width / 8; + + T *current_result_ptr = result_data + result_offset + scanned; + + if (to_scan == BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE && offset_in_compression_group == 0) { + // Decompress directly into result vector + BitpackingPrimitives::UnPackBlock(data_ptr_cast(current_result_ptr), decompression_group_start_pointer, + scan_state.current_width, skip_sign_extend); + } else { + // Decompress compression algorithm to buffer + BitpackingPrimitives::UnPackBlock(data_ptr_cast(scan_state.decompression_buffer), + decompression_group_start_pointer, scan_state.current_width, + skip_sign_extend); + + memcpy(current_result_ptr, scan_state.decompression_buffer + offset_in_compression_group, + to_scan * sizeof(T)); + } + + if (scan_state.current_group.mode == BitpackingMode::DELTA_FOR) { + ApplyFrameOfReference(reinterpret_cast(current_result_ptr), + static_cast(scan_state.current_frame_of_reference), to_scan); + DeltaDecode(reinterpret_cast(current_result_ptr), + static_cast(scan_state.current_delta_offset), to_scan); + scan_state.current_delta_offset = current_result_ptr[to_scan - 1]; + } else { + ApplyFrameOfReference(current_result_ptr, scan_state.current_frame_of_reference, to_scan); + } + + scanned += to_scan; + scan_state.current_group_offset += to_scan; + } +} + +template +void BitpackingScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + BitpackingScanPartial(segment, state, scan_count, result, 0); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +template +void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + BitpackingScanState scan_state(segment); + scan_state.Skip(segment, row_id); + T *result_data = FlatVector::GetData(result); + T *current_result_ptr = result_data + result_idx; + + idx_t offset_in_compression_group = + scan_state.current_group_offset % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; + + data_ptr_t decompression_group_start_pointer = + scan_state.current_group_ptr + + (scan_state.current_group_offset - offset_in_compression_group) * scan_state.current_width / 8; + + //! Because FOR offsets all our values to be 0 or above, we can always skip sign extension here + bool skip_sign_extend = true; + + if (scan_state.current_group.mode == BitpackingMode::CONSTANT) { + *current_result_ptr = scan_state.current_constant; + return; + } + + if (scan_state.current_group.mode == BitpackingMode::CONSTANT_DELTA) { +#ifdef DEBUG + // overflow check + T result; + bool multiply = TryMultiplyOperator::Operation(static_cast(scan_state.current_group_offset), + scan_state.current_constant, result); + bool add = TryAddOperator::Operation(result, scan_state.current_frame_of_reference, result); + D_ASSERT(multiply && add); +#endif + *current_result_ptr = (static_cast(scan_state.current_group_offset) * scan_state.current_constant) + + scan_state.current_frame_of_reference; + return; + } + + D_ASSERT(scan_state.current_group.mode == BitpackingMode::FOR || + scan_state.current_group.mode == BitpackingMode::DELTA_FOR); + + BitpackingPrimitives::UnPackBlock(data_ptr_cast(scan_state.decompression_buffer), + decompression_group_start_pointer, scan_state.current_width, skip_sign_extend); + + *current_result_ptr = scan_state.decompression_buffer[offset_in_compression_group]; + *current_result_ptr += scan_state.current_frame_of_reference; + + if (scan_state.current_group.mode == BitpackingMode::DELTA_FOR) { + *current_result_ptr += scan_state.current_delta_offset; + } +} +template +void BitpackingSkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { + auto &scan_state = static_cast &>(*state.scan_state); + scan_state.Skip(segment, skip_count); +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +template +CompressionFunction GetBitpackingFunction(PhysicalType data_type) { + return CompressionFunction(CompressionType::COMPRESSION_BITPACKING, data_type, BitpackingInitAnalyze, + BitpackingAnalyze, BitpackingFinalAnalyze, + BitpackingInitCompression, BitpackingCompress, + BitpackingFinalizeCompress, BitpackingInitScan, + BitpackingScan, BitpackingScanPartial, BitpackingFetchRow, BitpackingSkip); +} + +CompressionFunction BitpackingFun::GetFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return GetBitpackingFunction(type); + case PhysicalType::INT16: + return GetBitpackingFunction(type); + case PhysicalType::INT32: + return GetBitpackingFunction(type); + case PhysicalType::INT64: + return GetBitpackingFunction(type); + case PhysicalType::UINT8: + return GetBitpackingFunction(type); + case PhysicalType::UINT16: + return GetBitpackingFunction(type); + case PhysicalType::UINT32: + return GetBitpackingFunction(type); + case PhysicalType::UINT64: + return GetBitpackingFunction(type); + case PhysicalType::INT128: + return GetBitpackingFunction(type); + case PhysicalType::LIST: + return GetBitpackingFunction(type); + default: + throw InternalException("Unsupported type for Bitpacking"); + } +} + +bool BitpackingFun::TypeIsSupported(PhysicalType type) { + switch (type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::LIST: + case PhysicalType::INT128: + return true; + default: + return false; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp b/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp new file mode 100644 index 00000000..6743ecd8 --- /dev/null +++ b/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp @@ -0,0 +1,295 @@ +#include "duckdb/common/bitpacking.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Unpacking +//===--------------------------------------------------------------------===// + +static void UnpackSingle(const uint32_t *__restrict &in, hugeint_t *__restrict out, uint16_t delta, uint16_t shr) { + if (delta + shr < 32) { + *out = ((static_cast(in[0])) >> shr) % (hugeint_t(1) << delta); + } + + else if (delta + shr >= 32 && delta + shr < 64) { + *out = static_cast(in[0]) >> shr; + ++in; + + if (delta + shr > 32) { + const uint16_t NEXT_SHR = shr + delta - 32; + *out |= static_cast((*in) % (1U << NEXT_SHR)) << (32 - shr); + } + } + + else if (delta + shr >= 64 && delta + shr < 96) { + *out = static_cast(in[0]) >> shr; + *out |= static_cast(in[1]) << (32 - shr); + in += 2; + + if (delta + shr > 64) { + const uint16_t NEXT_SHR = delta + shr - 64; + *out |= static_cast((*in) % (1U << NEXT_SHR)) << (64 - shr); + } + } + + else if (delta + shr >= 96 && delta + shr < 128) { + *out = static_cast(in[0]) >> shr; + *out |= static_cast(in[1]) << (32 - shr); + *out |= static_cast(in[2]) << (64 - shr); + in += 3; + + if (delta + shr > 96) { + const uint16_t NEXT_SHR = delta + shr - 96; + *out |= static_cast((*in) % (1U << NEXT_SHR)) << (96 - shr); + } + } + + else if (delta + shr >= 128) { + *out = static_cast(in[0]) >> shr; + *out |= static_cast(in[1]) << (32 - shr); + *out |= static_cast(in[2]) << (64 - shr); + *out |= static_cast(in[3]) << (96 - shr); + in += 4; + + if (delta + shr > 128) { + const uint16_t NEXT_SHR = delta + shr - 128; + *out |= static_cast((*in) % (1U << NEXT_SHR)) << (128 - shr); + } + } +} + +static void UnpackLast(const uint32_t *__restrict &in, hugeint_t *__restrict out, uint16_t delta) { + const uint8_t LAST_IDX = 31; + const uint16_t SHIFT = (delta * 31) % 32; + out[LAST_IDX] = in[0] >> SHIFT; + if (delta > 32) { + out[LAST_IDX] |= static_cast(in[1]) << (32 - SHIFT); + } + if (delta > 64) { + out[LAST_IDX] |= static_cast(in[2]) << (64 - SHIFT); + } + if (delta > 96) { + out[LAST_IDX] |= static_cast(in[3]) << (96 - SHIFT); + } +} + +// Unpacks for specific deltas +static void UnpackDelta0(const uint32_t *__restrict in, hugeint_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + out[i] = 0; + } +} + +static void UnpackDelta32(const uint32_t *__restrict in, hugeint_t *__restrict out) { + for (uint8_t k = 0; k < 32; ++k) { + out[k] = static_cast(in[k]); + } +} + +static void UnpackDelta64(const uint32_t *__restrict in, hugeint_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = i * 2; + out[i] = in[OFFSET]; + out[i] |= static_cast(in[OFFSET + 1]) << 32; + } +} + +static void UnpackDelta96(const uint32_t *__restrict in, hugeint_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = i * 3; + out[i] = in[OFFSET]; + out[i] |= static_cast(in[OFFSET + 1]) << 32; + out[i] |= static_cast(in[OFFSET + 2]) << 64; + } +} + +static void UnpackDelta128(const uint32_t *__restrict in, hugeint_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = i * 4; + out[i] = in[OFFSET]; + out[i] |= static_cast(in[OFFSET + 1]) << 32; + out[i] |= static_cast(in[OFFSET + 2]) << 64; + out[i] |= static_cast(in[OFFSET + 3]) << 96; + } +} + +//===--------------------------------------------------------------------===// +// Packing +//===--------------------------------------------------------------------===// + +static void PackSingle(const hugeint_t in, uint32_t *__restrict &out, uint16_t delta, uint16_t shl, hugeint_t mask) { + if (delta + shl < 32) { + + if (shl == 0) { + out[0] = static_cast(in & mask); + } else { + out[0] |= static_cast((in & mask) << shl); + } + + } else if (delta + shl >= 32 && delta + shl < 64) { + + if (shl == 0) { + out[0] = static_cast(in & mask); + } else { + out[0] |= static_cast((in & mask) << shl); + } + ++out; + + if (delta + shl > 32) { + *out = static_cast((in & mask) >> (32 - shl)); + } + } + + else if (delta + shl >= 64 && delta + shl < 96) { + + if (shl == 0) { + out[0] = static_cast(in & mask); + } else { + out[0] |= static_cast(in << shl); + } + + out[1] = static_cast((in & mask) >> (32 - shl)); + out += 2; + + if (delta + shl > 64) { + *out = static_cast((in & mask) >> (64 - shl)); + } + } + + else if (delta + shl >= 96 && delta + shl < 128) { + if (shl == 0) { + out[0] = static_cast(in & mask); + } else { + out[0] |= static_cast(in << shl); + } + + out[1] = static_cast((in & mask) >> (32 - shl)); + out[2] = static_cast((in & mask) >> (64 - shl)); + out += 3; + + if (delta + shl > 96) { + *out = static_cast((in & mask) >> (96 - shl)); + } + } + + else if (delta + shl >= 128) { + // shl == 0 won't ever happen here considering a delta of 128 calls PackDelta128 + out[0] |= static_cast(in << shl); + out[1] = static_cast((in & mask) >> (32 - shl)); + out[2] = static_cast((in & mask) >> (64 - shl)); + out[3] = static_cast((in & mask) >> (96 - shl)); + out += 4; + + if (delta + shl > 128) { + *out = static_cast((in & mask) >> (128 - shl)); + } + } +} + +static void PackLast(const hugeint_t *__restrict in, uint32_t *__restrict out, uint16_t delta) { + const uint8_t LAST_IDX = 31; + const uint16_t SHIFT = (delta * 31) % 32; + out[0] |= static_cast(in[LAST_IDX] << SHIFT); + if (delta > 32) { + out[1] = static_cast(in[LAST_IDX] >> (32 - SHIFT)); + } + if (delta > 64) { + out[2] = static_cast(in[LAST_IDX] >> (64 - SHIFT)); + } + if (delta > 96) { + out[3] = static_cast(in[LAST_IDX] >> (96 - SHIFT)); + } +} + +// Packs for specific deltas +static void PackDelta32(const hugeint_t *__restrict in, uint32_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + out[i] = static_cast(in[i]); + } +} + +static void PackDelta64(const hugeint_t *__restrict in, uint32_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = 2 * i; + out[OFFSET] = static_cast(in[i]); + out[OFFSET + 1] = static_cast(in[i] >> 32); + } +} + +static void PackDelta96(const hugeint_t *__restrict in, uint32_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = 3 * i; + out[OFFSET] = static_cast(in[i]); + out[OFFSET + 1] = static_cast(in[i] >> 32); + out[OFFSET + 2] = static_cast(in[i] >> 64); + } +} + +static void PackDelta128(const hugeint_t *__restrict in, uint32_t *__restrict out) { + for (uint8_t i = 0; i < 32; ++i) { + const uint8_t OFFSET = 4 * i; + out[OFFSET] = static_cast(in[i]); + out[OFFSET + 1] = static_cast(in[i] >> 32); + out[OFFSET + 2] = static_cast(in[i] >> 64); + out[OFFSET + 3] = static_cast(in[i] >> 96); + } +} + +//===--------------------------------------------------------------------===// +// HugeIntPacker +//===--------------------------------------------------------------------===// + +void HugeIntPacker::Pack(const hugeint_t *__restrict in, uint32_t *__restrict out, bitpacking_width_t width) { + D_ASSERT(width <= 128); + switch (width) { + case 0: + break; + case 32: + PackDelta32(in, out); + break; + case 64: + PackDelta64(in, out); + break; + case 96: + PackDelta96(in, out); + break; + case 128: + PackDelta128(in, out); + break; + default: + for (idx_t oindex = 0; oindex < BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - 1; ++oindex) { + PackSingle(in[oindex], out, width, (width * oindex) % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE, + (hugeint_t(1) << width) - 1); + } + PackLast(in, out, width); + } +} + +void HugeIntPacker::Unpack(const uint32_t *__restrict in, hugeint_t *__restrict out, bitpacking_width_t width) { + D_ASSERT(width <= 128); + switch (width) { + case 0: + UnpackDelta0(in, out); + break; + case 32: + UnpackDelta32(in, out); + break; + case 64: + UnpackDelta64(in, out); + break; + case 96: + UnpackDelta96(in, out); + break; + case 128: + UnpackDelta128(in, out); + break; + default: + for (idx_t oindex = 0; oindex < BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - 1; ++oindex) { + UnpackSingle(in, out + oindex, width, + (width * oindex) % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE); + } + UnpackLast(in, out, width); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/chimp/bit_reader.cpp b/src/duckdb/src/storage/compression/chimp/bit_reader.cpp new file mode 100644 index 00000000..ee24cb58 --- /dev/null +++ b/src/duckdb/src/storage/compression/chimp/bit_reader.cpp @@ -0,0 +1,8 @@ +#include "duckdb/storage/compression/chimp/algorithm/bit_reader.hpp" + +namespace duckdb { + +constexpr uint8_t BitReader::REMAINDER_MASKS[]; +constexpr uint8_t BitReader::MASKS[]; + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/chimp/chimp.cpp b/src/duckdb/src/storage/compression/chimp/chimp.cpp new file mode 100644 index 00000000..f697c9e7 --- /dev/null +++ b/src/duckdb/src/storage/compression/chimp/chimp.cpp @@ -0,0 +1,41 @@ +#include "duckdb/storage/compression/chimp/chimp.hpp" +#include "duckdb/storage/compression/chimp/chimp_compress.hpp" +#include "duckdb/storage/compression/chimp/chimp_scan.hpp" +#include "duckdb/storage/compression/chimp/chimp_fetch.hpp" +#include "duckdb/storage/compression/chimp/chimp_analyze.hpp" + +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" + +namespace duckdb { + +template +CompressionFunction GetChimpFunction(PhysicalType data_type) { + return CompressionFunction(CompressionType::COMPRESSION_CHIMP, data_type, ChimpInitAnalyze, ChimpAnalyze, + ChimpFinalAnalyze, ChimpInitCompression, ChimpCompress, + ChimpFinalizeCompress, ChimpInitScan, ChimpScan, ChimpScanPartial, + ChimpFetchRow, ChimpSkip); +} + +CompressionFunction ChimpCompressionFun::GetFunction(PhysicalType type) { + switch (type) { + case PhysicalType::FLOAT: + return GetChimpFunction(type); + case PhysicalType::DOUBLE: + return GetChimpFunction(type); + default: + throw InternalException("Unsupported type for Chimp"); + } +} + +bool ChimpCompressionFun::TypeIsSupported(PhysicalType type) { + switch (type) { + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return true; + default: + return false; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/chimp/chimp_constants.cpp b/src/duckdb/src/storage/compression/chimp/chimp_constants.cpp new file mode 100644 index 00000000..4577ba5f --- /dev/null +++ b/src/duckdb/src/storage/compression/chimp/chimp_constants.cpp @@ -0,0 +1,10 @@ +#include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" + +namespace duckdb { + +constexpr uint8_t ChimpConstants::Compression::LEADING_ROUND[]; +constexpr uint8_t ChimpConstants::Compression::LEADING_REPRESENTATION[]; + +constexpr uint8_t ChimpConstants::Decompression::LEADING_REPRESENTATION[]; + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/chimp/flag_buffer.cpp b/src/duckdb/src/storage/compression/chimp/flag_buffer.cpp new file mode 100644 index 00000000..06c84d83 --- /dev/null +++ b/src/duckdb/src/storage/compression/chimp/flag_buffer.cpp @@ -0,0 +1,8 @@ +#include "duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp" + +namespace duckdb { + +constexpr uint8_t FlagBufferConstants::MASKS[]; +constexpr uint8_t FlagBufferConstants::SHIFTS[]; + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/chimp/leading_zero_buffer.cpp b/src/duckdb/src/storage/compression/chimp/leading_zero_buffer.cpp new file mode 100644 index 00000000..ac4d9e43 --- /dev/null +++ b/src/duckdb/src/storage/compression/chimp/leading_zero_buffer.cpp @@ -0,0 +1,8 @@ +#include "duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp" + +namespace duckdb { + +constexpr uint32_t LeadingZeroBufferConstants::MASKS[]; +constexpr uint8_t LeadingZeroBufferConstants::SHIFTS[]; + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/dictionary_compression.cpp b/src/duckdb/src/storage/compression/dictionary_compression.cpp new file mode 100644 index 00000000..1c74c102 --- /dev/null +++ b/src/duckdb/src/storage/compression/dictionary_compression.cpp @@ -0,0 +1,652 @@ +#include "duckdb/common/bitpacking.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/common/types/vector_buffer.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/segment/uncompressed.hpp" +#include "duckdb/storage/string_uncompressed.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/column_data_checkpointer.hpp" + +namespace duckdb { + +// Abstract class for keeping compression state either for compression or size analysis +class DictionaryCompressionState : public CompressionState { +public: + bool UpdateState(Vector &scan_vector, idx_t count) { + UnifiedVectorFormat vdata; + scan_vector.ToUnifiedFormat(count, vdata); + auto data = UnifiedVectorFormat::GetData(vdata); + Verify(); + + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + size_t string_size = 0; + bool new_string = false; + auto row_is_valid = vdata.validity.RowIsValid(idx); + + if (row_is_valid) { + string_size = data[idx].GetSize(); + if (string_size >= StringUncompressed::STRING_BLOCK_LIMIT) { + // Big strings not implemented for dictionary compression + return false; + } + new_string = !LookupString(data[idx]); + } + + bool fits = CalculateSpaceRequirements(new_string, string_size); + if (!fits) { + Flush(); + new_string = true; + + fits = CalculateSpaceRequirements(new_string, string_size); + if (!fits) { + throw InternalException("Dictionary compression could not write to new segment"); + } + } + + if (!row_is_valid) { + AddNull(); + } else if (new_string) { + AddNewString(data[idx]); + } else { + AddLastLookup(); + } + + Verify(); + } + + return true; + } + +protected: + // Should verify the State + virtual void Verify() = 0; + // Performs a lookup of str, storing the result internally + virtual bool LookupString(string_t str) = 0; + // Add the most recently looked up str to compression state + virtual void AddLastLookup() = 0; + // Add string to the state that is known to not be seen yet + virtual void AddNewString(string_t str) = 0; + // Add a null value to the compression state + virtual void AddNull() = 0; + // Needs to be called before adding a value. Will return false if a flush is required first. + virtual bool CalculateSpaceRequirements(bool new_string, size_t string_size) = 0; + // Flush the segment to disk if compressing or reset the counters if analyzing + virtual void Flush(bool final = false) = 0; +}; + +typedef struct { + uint32_t dict_size; + uint32_t dict_end; + uint32_t index_buffer_offset; + uint32_t index_buffer_count; + uint32_t bitpacking_width; +} dictionary_compression_header_t; + +struct DictionaryCompressionStorage { + static constexpr float MINIMUM_COMPRESSION_RATIO = 1.2; + static constexpr uint16_t DICTIONARY_HEADER_SIZE = sizeof(dictionary_compression_header_t); + static constexpr size_t COMPACTION_FLUSH_LIMIT = (size_t)Storage::BLOCK_SIZE / 5 * 4; + + static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); + static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static idx_t StringFinalAnalyze(AnalyzeState &state_p); + + static unique_ptr InitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr state); + static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); + static void FinalizeCompress(CompressionState &state_p); + + static unique_ptr StringInitScan(ColumnSegment &segment); + template + static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset); + static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); + static void StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx); + + static bool HasEnoughSpace(idx_t current_count, idx_t index_count, idx_t dict_size, + bitpacking_width_t packing_width); + static idx_t RequiredSpace(idx_t current_count, idx_t index_count, idx_t dict_size, + bitpacking_width_t packing_width); + + static StringDictionaryContainer GetDictionary(ColumnSegment &segment, BufferHandle &handle); + static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer container); + static string_t FetchStringFromDict(ColumnSegment &segment, StringDictionaryContainer dict, data_ptr_t baseptr, + int32_t dict_offset, uint16_t string_len); + static uint16_t GetStringLength(uint32_t *index_buffer_ptr, sel_t index); +}; + +// Dictionary compression uses a combination of bitpacking and a dictionary to compress string segments. The data is +// stored across three buffers: the index buffer, the selection buffer and the dictionary. Firstly the Index buffer +// contains the offsets into the dictionary which are also used to determine the string lengths. Each value in the +// dictionary gets a single unique index in the index buffer. Secondly, the selection buffer maps the tuples to an index +// in the index buffer. The selection buffer is compressed with bitpacking. Finally, the dictionary contains simply all +// the unique strings without lenghts or null termination as we can deduce the lengths from the index buffer. The +// addition of the selection buffer is done for two reasons: firstly, to allow the scan to emit dictionary vectors by +// scanning the whole dictionary at once and then scanning the selection buffer for each emitted vector. Secondly, it +// allows for efficient bitpacking compression as the selection values should remain relatively small. +struct DictionaryCompressionCompressState : public DictionaryCompressionState { + explicit DictionaryCompressionCompressState(ColumnDataCheckpointer &checkpointer_p) + : checkpointer(checkpointer_p), + function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_DICTIONARY)), + heap(BufferAllocator::Get(checkpointer.GetDatabase())) { + CreateEmptySegment(checkpointer.GetRowGroup().start); + } + + ColumnDataCheckpointer &checkpointer; + CompressionFunction &function; + + // State regarding current segment + unique_ptr current_segment; + BufferHandle current_handle; + StringDictionaryContainer current_dictionary; + data_ptr_t current_end_ptr; + + // Buffers and map for current segment + StringHeap heap; + string_map_t current_string_map; + vector index_buffer; + vector selection_buffer; + + bitpacking_width_t current_width = 0; + bitpacking_width_t next_width = 0; + + // Result of latest LookupString call + uint32_t latest_lookup_result; + +public: + void CreateEmptySegment(idx_t row_start) { + auto &db = checkpointer.GetDatabase(); + auto &type = checkpointer.GetType(); + auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); + current_segment = std::move(compressed_segment); + + current_segment->function = function; + + // Reset the buffers and string map + current_string_map.clear(); + index_buffer.clear(); + index_buffer.push_back(0); // Reserve index 0 for null strings + selection_buffer.clear(); + + current_width = 0; + next_width = 0; + + // Reset the pointers into the current segment + auto &buffer_manager = BufferManager::GetBufferManager(checkpointer.GetDatabase()); + current_handle = buffer_manager.Pin(current_segment->block); + current_dictionary = DictionaryCompressionStorage::GetDictionary(*current_segment, current_handle); + current_end_ptr = current_handle.Ptr() + current_dictionary.end; + } + + void Verify() override { + current_dictionary.Verify(); + D_ASSERT(current_segment->count == selection_buffer.size()); + D_ASSERT(DictionaryCompressionStorage::HasEnoughSpace(current_segment->count.load(), index_buffer.size(), + current_dictionary.size, current_width)); + D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); + D_ASSERT(index_buffer.size() == current_string_map.size() + 1); // +1 is for null value + } + + bool LookupString(string_t str) override { + auto search = current_string_map.find(str); + auto has_result = search != current_string_map.end(); + + if (has_result) { + latest_lookup_result = search->second; + } + return has_result; + } + + void AddNewString(string_t str) override { + UncompressedStringStorage::UpdateStringStats(current_segment->stats, str); + + // Copy string to dict + current_dictionary.size += str.GetSize(); + auto dict_pos = current_end_ptr - current_dictionary.size; + memcpy(dict_pos, str.GetData(), str.GetSize()); + current_dictionary.Verify(); + D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); + + // Update buffers and map + index_buffer.push_back(current_dictionary.size); + selection_buffer.push_back(index_buffer.size() - 1); + if (str.IsInlined()) { + current_string_map.insert({str, index_buffer.size() - 1}); + } else { + current_string_map.insert({heap.AddBlob(str), index_buffer.size() - 1}); + } + DictionaryCompressionStorage::SetDictionary(*current_segment, current_handle, current_dictionary); + + current_width = next_width; + current_segment->count++; + } + + void AddNull() override { + selection_buffer.push_back(0); + current_segment->count++; + } + + void AddLastLookup() override { + selection_buffer.push_back(latest_lookup_result); + current_segment->count++; + } + + bool CalculateSpaceRequirements(bool new_string, size_t string_size) override { + if (new_string) { + next_width = BitpackingPrimitives::MinimumBitWidth(index_buffer.size() - 1 + new_string); + return DictionaryCompressionStorage::HasEnoughSpace(current_segment->count.load() + 1, + index_buffer.size() + 1, + current_dictionary.size + string_size, next_width); + } else { + return DictionaryCompressionStorage::HasEnoughSpace(current_segment->count.load() + 1, index_buffer.size(), + current_dictionary.size, current_width); + } + } + + void Flush(bool final = false) override { + auto next_start = current_segment->start + current_segment->count; + + auto segment_size = Finalize(); + auto &state = checkpointer.GetCheckpointState(); + state.FlushSegment(std::move(current_segment), segment_size); + + if (!final) { + CreateEmptySegment(next_start); + } + } + + idx_t Finalize() { + auto &buffer_manager = BufferManager::GetBufferManager(checkpointer.GetDatabase()); + auto handle = buffer_manager.Pin(current_segment->block); + D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); + + // calculate sizes + auto compressed_selection_buffer_size = + BitpackingPrimitives::GetRequiredSize(current_segment->count, current_width); + auto index_buffer_size = index_buffer.size() * sizeof(uint32_t); + auto total_size = DictionaryCompressionStorage::DICTIONARY_HEADER_SIZE + compressed_selection_buffer_size + + index_buffer_size + current_dictionary.size; + + // calculate ptr and offsets + auto base_ptr = handle.Ptr(); + auto header_ptr = reinterpret_cast(base_ptr); + auto compressed_selection_buffer_offset = DictionaryCompressionStorage::DICTIONARY_HEADER_SIZE; + auto index_buffer_offset = compressed_selection_buffer_offset + compressed_selection_buffer_size; + + // Write compressed selection buffer + BitpackingPrimitives::PackBuffer(base_ptr + compressed_selection_buffer_offset, + (sel_t *)(selection_buffer.data()), current_segment->count, + current_width); + + // Write the index buffer + memcpy(base_ptr + index_buffer_offset, index_buffer.data(), index_buffer_size); + + // Store sizes and offsets in segment header + Store(index_buffer_offset, data_ptr_cast(&header_ptr->index_buffer_offset)); + Store(index_buffer.size(), data_ptr_cast(&header_ptr->index_buffer_count)); + Store((uint32_t)current_width, data_ptr_cast(&header_ptr->bitpacking_width)); + + D_ASSERT(current_width == BitpackingPrimitives::MinimumBitWidth(index_buffer.size() - 1)); + D_ASSERT(DictionaryCompressionStorage::HasEnoughSpace(current_segment->count, index_buffer.size(), + current_dictionary.size, current_width)); + D_ASSERT((uint64_t)*max_element(std::begin(selection_buffer), std::end(selection_buffer)) == + index_buffer.size() - 1); + + if (total_size >= DictionaryCompressionStorage::COMPACTION_FLUSH_LIMIT) { + // the block is full enough, don't bother moving around the dictionary + return Storage::BLOCK_SIZE; + } + // the block has space left: figure out how much space we can save + auto move_amount = Storage::BLOCK_SIZE - total_size; + // move the dictionary so it lines up exactly with the offsets + auto new_dictionary_offset = index_buffer_offset + index_buffer_size; + memmove(base_ptr + new_dictionary_offset, base_ptr + current_dictionary.end - current_dictionary.size, + current_dictionary.size); + current_dictionary.end -= move_amount; + D_ASSERT(current_dictionary.end == total_size); + // write the new dictionary (with the updated "end") + DictionaryCompressionStorage::SetDictionary(*current_segment, handle, current_dictionary); + return total_size; + } +}; + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +struct DictionaryAnalyzeState : public DictionaryCompressionState { + DictionaryAnalyzeState() + : segment_count(0), current_tuple_count(0), current_unique_count(0), current_dict_size(0), current_width(0), + next_width(0) { + } + + size_t segment_count; + idx_t current_tuple_count; + idx_t current_unique_count; + size_t current_dict_size; + StringHeap heap; + string_set_t current_set; + bitpacking_width_t current_width; + bitpacking_width_t next_width; + + bool LookupString(string_t str) override { + return current_set.count(str); + } + + void AddNewString(string_t str) override { + current_tuple_count++; + current_unique_count++; + current_dict_size += str.GetSize(); + if (str.IsInlined()) { + current_set.insert(str); + } else { + current_set.insert(heap.AddBlob(str)); + } + current_width = next_width; + } + + void AddLastLookup() override { + current_tuple_count++; + } + + void AddNull() override { + current_tuple_count++; + } + + bool CalculateSpaceRequirements(bool new_string, size_t string_size) override { + if (new_string) { + next_width = + BitpackingPrimitives::MinimumBitWidth(current_unique_count + 2); // 1 for null, one for new string + return DictionaryCompressionStorage::HasEnoughSpace(current_tuple_count + 1, current_unique_count + 1, + current_dict_size + string_size, next_width); + } else { + return DictionaryCompressionStorage::HasEnoughSpace(current_tuple_count + 1, current_unique_count, + current_dict_size, current_width); + } + } + + void Flush(bool final = false) override { + segment_count++; + current_tuple_count = 0; + current_unique_count = 0; + current_dict_size = 0; + current_set.clear(); + } + void Verify() override {}; +}; + +struct DictionaryCompressionAnalyzeState : public AnalyzeState { + DictionaryCompressionAnalyzeState() : analyze_state(make_uniq()) { + } + + unique_ptr analyze_state; +}; + +unique_ptr DictionaryCompressionStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq(); +} + +bool DictionaryCompressionStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { + auto &state = state_p.Cast(); + return state.analyze_state->UpdateState(input, count); +} + +idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { + auto &analyze_state = state_p.Cast(); + auto &state = *analyze_state.analyze_state; + + auto width = BitpackingPrimitives::MinimumBitWidth(state.current_unique_count + 1); + auto req_space = + RequiredSpace(state.current_tuple_count, state.current_unique_count, state.current_dict_size, width); + + return MINIMUM_COMPRESSION_RATIO * (state.segment_count * Storage::BLOCK_SIZE + req_space); +} + +//===--------------------------------------------------------------------===// +// Compress +//===--------------------------------------------------------------------===// +unique_ptr DictionaryCompressionStorage::InitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr state) { + return make_uniq(checkpointer); +} + +void DictionaryCompressionStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { + auto &state = state_p.Cast(); + state.UpdateState(scan_vector, count); +} + +void DictionaryCompressionStorage::FinalizeCompress(CompressionState &state_p) { + auto &state = state_p.Cast(); + state.Flush(true); +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +struct CompressedStringScanState : public StringScanState { + BufferHandle handle; + buffer_ptr dictionary; + bitpacking_width_t current_width; + buffer_ptr sel_vec; + idx_t sel_vec_size = 0; +}; + +unique_ptr DictionaryCompressionStorage::StringInitScan(ColumnSegment &segment) { + auto state = make_uniq(); + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + state->handle = buffer_manager.Pin(segment.block); + + auto baseptr = state->handle.Ptr() + segment.GetBlockOffset(); + + // Load header values + auto dict = DictionaryCompressionStorage::GetDictionary(segment, state->handle); + auto header_ptr = reinterpret_cast(baseptr); + auto index_buffer_offset = Load(data_ptr_cast(&header_ptr->index_buffer_offset)); + auto index_buffer_count = Load(data_ptr_cast(&header_ptr->index_buffer_count)); + state->current_width = (bitpacking_width_t)(Load(data_ptr_cast(&header_ptr->bitpacking_width))); + + auto index_buffer_ptr = reinterpret_cast(baseptr + index_buffer_offset); + + state->dictionary = make_buffer(segment.type, index_buffer_count); + auto dict_child_data = FlatVector::GetData(*(state->dictionary)); + + for (uint32_t i = 0; i < index_buffer_count; i++) { + // NOTE: the passing of dict_child_vector, will not be used, its for big strings + uint16_t str_len = GetStringLength(index_buffer_ptr, i); + dict_child_data[i] = FetchStringFromDict(segment, dict, baseptr, index_buffer_ptr[i], str_len); + } + + return std::move(state); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +template +void DictionaryCompressionStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, + Vector &result, idx_t result_offset) { + // clear any previously locked buffers and get the primary buffer handle + auto &scan_state = state.scan_state->Cast(); + auto start = segment.GetRelativeIndex(state.row_index); + + auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); + auto dict = DictionaryCompressionStorage::GetDictionary(segment, scan_state.handle); + + auto header_ptr = reinterpret_cast(baseptr); + auto index_buffer_offset = Load(data_ptr_cast(&header_ptr->index_buffer_offset)); + auto index_buffer_ptr = reinterpret_cast(baseptr + index_buffer_offset); + + auto base_data = data_ptr_cast(baseptr + DICTIONARY_HEADER_SIZE); + auto result_data = FlatVector::GetData(result); + + if (!ALLOW_DICT_VECTORS || scan_count != STANDARD_VECTOR_SIZE || + start % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE != 0) { + // Emit regular vector + + // Handling non-bitpacking-group-aligned start values; + idx_t start_offset = start % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; + + // We will scan in blocks of BITPACKING_ALGORITHM_GROUP_SIZE, so we may scan some extra values. + idx_t decompress_count = BitpackingPrimitives::RoundUpToAlgorithmGroupSize(scan_count + start_offset); + + // Create a decompression buffer of sufficient size if we don't already have one. + if (!scan_state.sel_vec || scan_state.sel_vec_size < decompress_count) { + scan_state.sel_vec_size = decompress_count; + scan_state.sel_vec = make_buffer(decompress_count); + } + + data_ptr_t src = &base_data[((start - start_offset) * scan_state.current_width) / 8]; + sel_t *sel_vec_ptr = scan_state.sel_vec->data(); + + BitpackingPrimitives::UnPackBuffer(data_ptr_cast(sel_vec_ptr), src, decompress_count, + scan_state.current_width); + + for (idx_t i = 0; i < scan_count; i++) { + // Lookup dict offset in index buffer + auto string_number = scan_state.sel_vec->get_index(i + start_offset); + auto dict_offset = index_buffer_ptr[string_number]; + uint16_t str_len = GetStringLength(index_buffer_ptr, string_number); + result_data[result_offset + i] = FetchStringFromDict(segment, dict, baseptr, dict_offset, str_len); + } + + } else { + D_ASSERT(start % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0); + D_ASSERT(scan_count == STANDARD_VECTOR_SIZE); + D_ASSERT(result_offset == 0); + + idx_t decompress_count = BitpackingPrimitives::RoundUpToAlgorithmGroupSize(scan_count); + + // Create a selection vector of sufficient size if we don't already have one. + if (!scan_state.sel_vec || scan_state.sel_vec_size < decompress_count) { + scan_state.sel_vec_size = decompress_count; + scan_state.sel_vec = make_buffer(decompress_count); + } + + // Scanning 1024 values, emitting a dict vector + data_ptr_t dst = data_ptr_cast(scan_state.sel_vec->data()); + data_ptr_t src = data_ptr_cast(&base_data[(start * scan_state.current_width) / 8]); + + BitpackingPrimitives::UnPackBuffer(dst, src, scan_count, scan_state.current_width); + + result.Slice(*(scan_state.dictionary), *scan_state.sel_vec, scan_count); + } +} + +void DictionaryCompressionStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, + Vector &result) { + StringScanPartial(segment, state, scan_count, result, 0); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +void DictionaryCompressionStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, + Vector &result, idx_t result_idx) { + // fetch a single row from the string segment + // first pin the main buffer if it is not already pinned + auto &handle = state.GetOrInsertHandle(segment); + + auto baseptr = handle.Ptr() + segment.GetBlockOffset(); + auto header_ptr = reinterpret_cast(baseptr); + auto dict = DictionaryCompressionStorage::GetDictionary(segment, handle); + auto index_buffer_offset = Load(data_ptr_cast(&header_ptr->index_buffer_offset)); + auto width = (bitpacking_width_t)Load(data_ptr_cast(&header_ptr->bitpacking_width)); + auto index_buffer_ptr = reinterpret_cast(baseptr + index_buffer_offset); + auto base_data = data_ptr_cast(baseptr + DICTIONARY_HEADER_SIZE); + auto result_data = FlatVector::GetData(result); + + // Handling non-bitpacking-group-aligned start values; + idx_t start_offset = row_id % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; + + // Decompress part of selection buffer we need for this value. + sel_t decompression_buffer[BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE]; + data_ptr_t src = data_ptr_cast(&base_data[((row_id - start_offset) * width) / 8]); + BitpackingPrimitives::UnPackBuffer(data_ptr_cast(decompression_buffer), src, + BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE, width); + + auto selection_value = decompression_buffer[start_offset]; + auto dict_offset = index_buffer_ptr[selection_value]; + uint16_t str_len = GetStringLength(index_buffer_ptr, selection_value); + + result_data[result_idx] = FetchStringFromDict(segment, dict, baseptr, dict_offset, str_len); +} + +//===--------------------------------------------------------------------===// +// Helper Functions +//===--------------------------------------------------------------------===// +bool DictionaryCompressionStorage::HasEnoughSpace(idx_t current_count, idx_t index_count, idx_t dict_size, + bitpacking_width_t packing_width) { + return RequiredSpace(current_count, index_count, dict_size, packing_width) <= Storage::BLOCK_SIZE; +} + +idx_t DictionaryCompressionStorage::RequiredSpace(idx_t current_count, idx_t index_count, idx_t dict_size, + bitpacking_width_t packing_width) { + idx_t base_space = DICTIONARY_HEADER_SIZE + dict_size; + idx_t string_number_space = BitpackingPrimitives::GetRequiredSize(current_count, packing_width); + idx_t index_space = index_count * sizeof(uint32_t); + + idx_t used_space = base_space + index_space + string_number_space; + + return used_space; +} + +StringDictionaryContainer DictionaryCompressionStorage::GetDictionary(ColumnSegment &segment, BufferHandle &handle) { + auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); + StringDictionaryContainer container; + container.size = Load(data_ptr_cast(&header_ptr->dict_size)); + container.end = Load(data_ptr_cast(&header_ptr->dict_end)); + return container; +} + +void DictionaryCompressionStorage::SetDictionary(ColumnSegment &segment, BufferHandle &handle, + StringDictionaryContainer container) { + auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); + Store(container.size, data_ptr_cast(&header_ptr->dict_size)); + Store(container.end, data_ptr_cast(&header_ptr->dict_end)); +} + +string_t DictionaryCompressionStorage::FetchStringFromDict(ColumnSegment &segment, StringDictionaryContainer dict, + data_ptr_t baseptr, int32_t dict_offset, + uint16_t string_len) { + D_ASSERT(dict_offset >= 0 && dict_offset <= Storage::BLOCK_SIZE); + + if (dict_offset == 0) { + return string_t(nullptr, 0); + } + // normal string: read string from this block + auto dict_end = baseptr + dict.end; + auto dict_pos = dict_end - dict_offset; + + auto str_ptr = char_ptr_cast(dict_pos); + return string_t(str_ptr, string_len); +} + +uint16_t DictionaryCompressionStorage::GetStringLength(uint32_t *index_buffer_ptr, sel_t index) { + if (index == 0) { + return 0; + } else { + return index_buffer_ptr[index] - index_buffer_ptr[index - 1]; + } +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +CompressionFunction DictionaryCompressionFun::GetFunction(PhysicalType data_type) { + return CompressionFunction( + CompressionType::COMPRESSION_DICTIONARY, data_type, DictionaryCompressionStorage ::StringInitAnalyze, + DictionaryCompressionStorage::StringAnalyze, DictionaryCompressionStorage::StringFinalAnalyze, + DictionaryCompressionStorage::InitCompression, DictionaryCompressionStorage::Compress, + DictionaryCompressionStorage::FinalizeCompress, DictionaryCompressionStorage::StringInitScan, + DictionaryCompressionStorage::StringScan, DictionaryCompressionStorage::StringScanPartial, + DictionaryCompressionStorage::StringFetchRow, UncompressedFunctions::EmptySkip); +} + +bool DictionaryCompressionFun::TypeIsSupported(PhysicalType type) { + return type == PhysicalType::VARCHAR; +} +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp new file mode 100644 index 00000000..ea77f926 --- /dev/null +++ b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp @@ -0,0 +1,300 @@ +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp" +#include "duckdb/storage/segment/uncompressed.hpp" + +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +struct FixedSizeAnalyzeState : public AnalyzeState { + FixedSizeAnalyzeState() : count(0) { + } + + idx_t count; +}; + +unique_ptr FixedSizeInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq(); +} + +bool FixedSizeAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { + auto &state = state_p.Cast(); + state.count += count; + return true; +} + +template +idx_t FixedSizeFinalAnalyze(AnalyzeState &state_p) { + auto &state = state_p.template Cast(); + return sizeof(T) * state.count; +} + +//===--------------------------------------------------------------------===// +// Compress +//===--------------------------------------------------------------------===// +struct UncompressedCompressState : public CompressionState { + explicit UncompressedCompressState(ColumnDataCheckpointer &checkpointer); + + ColumnDataCheckpointer &checkpointer; + unique_ptr current_segment; + ColumnAppendState append_state; + + virtual void CreateEmptySegment(idx_t row_start); + void FlushSegment(idx_t segment_size); + void Finalize(idx_t segment_size); +}; + +UncompressedCompressState::UncompressedCompressState(ColumnDataCheckpointer &checkpointer) + : checkpointer(checkpointer) { + UncompressedCompressState::CreateEmptySegment(checkpointer.GetRowGroup().start); +} + +void UncompressedCompressState::CreateEmptySegment(idx_t row_start) { + auto &db = checkpointer.GetDatabase(); + auto &type = checkpointer.GetType(); + auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); + if (type.InternalType() == PhysicalType::VARCHAR) { + auto &state = compressed_segment->GetSegmentState()->Cast(); + state.overflow_writer = make_uniq(checkpointer.GetRowGroup().GetBlockManager()); + } + current_segment = std::move(compressed_segment); + current_segment->InitializeAppend(append_state); +} + +void UncompressedCompressState::FlushSegment(idx_t segment_size) { + auto &state = checkpointer.GetCheckpointState(); + if (current_segment->type.InternalType() == PhysicalType::VARCHAR) { + auto &segment_state = current_segment->GetSegmentState()->Cast(); + segment_state.overflow_writer->Flush(); + segment_state.overflow_writer.reset(); + } + state.FlushSegment(std::move(current_segment), segment_size); +} + +void UncompressedCompressState::Finalize(idx_t segment_size) { + FlushSegment(segment_size); + current_segment.reset(); +} + +unique_ptr UncompressedFunctions::InitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr state) { + return make_uniq(checkpointer); +} + +void UncompressedFunctions::Compress(CompressionState &state_p, Vector &data, idx_t count) { + auto &state = state_p.Cast(); + UnifiedVectorFormat vdata; + data.ToUnifiedFormat(count, vdata); + + idx_t offset = 0; + while (count > 0) { + idx_t appended = state.current_segment->Append(state.append_state, vdata, offset, count); + if (appended == count) { + // appended everything: finished + return; + } + auto next_start = state.current_segment->start + state.current_segment->count; + // the segment is full: flush it to disk + state.FlushSegment(state.current_segment->FinalizeAppend(state.append_state)); + + // now create a new segment and continue appending + state.CreateEmptySegment(next_start); + offset += appended; + count -= appended; + } +} + +void UncompressedFunctions::FinalizeCompress(CompressionState &state_p) { + auto &state = state_p.Cast(); + state.Finalize(state.current_segment->FinalizeAppend(state.append_state)); +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +struct FixedSizeScanState : public SegmentScanState { + BufferHandle handle; +}; + +unique_ptr FixedSizeInitScan(ColumnSegment &segment) { + auto result = make_uniq(); + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + result->handle = buffer_manager.Pin(segment.block); + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +template +void FixedSizeScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + auto &scan_state = state.scan_state->Cast(); + auto start = segment.GetRelativeIndex(state.row_index); + + auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); + auto source_data = data + start * sizeof(T); + + // copy the data from the base table + result.SetVectorType(VectorType::FLAT_VECTOR); + memcpy(FlatVector::GetData(result) + result_offset * sizeof(T), source_data, scan_count * sizeof(T)); +} + +template +void FixedSizeScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + auto &scan_state = state.scan_state->template Cast(); + auto start = segment.GetRelativeIndex(state.row_index); + + auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); + auto source_data = data + start * sizeof(T); + + result.SetVectorType(VectorType::FLAT_VECTOR); + FlatVector::SetData(result, source_data); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +template +void FixedSizeFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + + // first fetch the data from the base table + auto data_ptr = handle.Ptr() + segment.GetBlockOffset() + row_id * sizeof(T); + + memcpy(FlatVector::GetData(result) + result_idx * sizeof(T), data_ptr, sizeof(T)); +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +static unique_ptr FixedSizeInitAppend(ColumnSegment &segment) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + return make_uniq(std::move(handle)); +} + +struct StandardFixedSizeAppend { + template + static void Append(SegmentStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, + idx_t offset, idx_t count) { + auto sdata = UnifiedVectorFormat::GetData(adata); + auto tdata = reinterpret_cast(target); + if (!adata.validity.AllValid()) { + for (idx_t i = 0; i < count; i++) { + auto source_idx = adata.sel->get_index(offset + i); + auto target_idx = target_offset + i; + bool is_null = !adata.validity.RowIsValid(source_idx); + if (!is_null) { + NumericStats::Update(stats.statistics, sdata[source_idx]); + tdata[target_idx] = sdata[source_idx]; + } else { + // we insert a NullValue in the null gap for debuggability + // this value should never be used or read anywhere + tdata[target_idx] = NullValue(); + } + } + } else { + for (idx_t i = 0; i < count; i++) { + auto source_idx = adata.sel->get_index(offset + i); + auto target_idx = target_offset + i; + NumericStats::Update(stats.statistics, sdata[source_idx]); + tdata[target_idx] = sdata[source_idx]; + } + } + } +}; + +struct ListFixedSizeAppend { + template + static void Append(SegmentStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, + idx_t offset, idx_t count) { + auto sdata = UnifiedVectorFormat::GetData(adata); + auto tdata = reinterpret_cast(target); + for (idx_t i = 0; i < count; i++) { + auto source_idx = adata.sel->get_index(offset + i); + auto target_idx = target_offset + i; + tdata[target_idx] = sdata[source_idx]; + } + } +}; + +template +idx_t FixedSizeAppend(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, + UnifiedVectorFormat &data, idx_t offset, idx_t count) { + D_ASSERT(segment.GetBlockOffset() == 0); + + auto target_ptr = append_state.handle.Ptr(); + idx_t max_tuple_count = segment.SegmentSize() / sizeof(T); + idx_t copy_count = MinValue(count, max_tuple_count - segment.count); + + OP::template Append(stats, target_ptr, segment.count, data, offset, copy_count); + segment.count += copy_count; + return copy_count; +} + +template +idx_t FixedSizeFinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { + return segment.count * sizeof(T); +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +template +CompressionFunction FixedSizeGetFunction(PhysicalType data_type) { + return CompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, data_type, FixedSizeInitAnalyze, + FixedSizeAnalyze, FixedSizeFinalAnalyze, UncompressedFunctions::InitCompression, + UncompressedFunctions::Compress, UncompressedFunctions::FinalizeCompress, + FixedSizeInitScan, FixedSizeScan, FixedSizeScanPartial, FixedSizeFetchRow, + UncompressedFunctions::EmptySkip, nullptr, FixedSizeInitAppend, + FixedSizeAppend, FixedSizeFinalizeAppend, nullptr); +} + +CompressionFunction FixedSizeUncompressed::GetFunction(PhysicalType data_type) { + switch (data_type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return FixedSizeGetFunction(data_type); + case PhysicalType::INT16: + return FixedSizeGetFunction(data_type); + case PhysicalType::INT32: + return FixedSizeGetFunction(data_type); + case PhysicalType::INT64: + return FixedSizeGetFunction(data_type); + case PhysicalType::UINT8: + return FixedSizeGetFunction(data_type); + case PhysicalType::UINT16: + return FixedSizeGetFunction(data_type); + case PhysicalType::UINT32: + return FixedSizeGetFunction(data_type); + case PhysicalType::UINT64: + return FixedSizeGetFunction(data_type); + case PhysicalType::INT128: + return FixedSizeGetFunction(data_type); + case PhysicalType::FLOAT: + return FixedSizeGetFunction(data_type); + case PhysicalType::DOUBLE: + return FixedSizeGetFunction(data_type); + case PhysicalType::INTERVAL: + return FixedSizeGetFunction(data_type); + case PhysicalType::LIST: + return FixedSizeGetFunction(data_type); + default: + throw InternalException("Unsupported type for FixedSizeUncompressed::GetFunction"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/fsst.cpp b/src/duckdb/src/storage/compression/fsst.cpp new file mode 100644 index 00000000..9729af27 --- /dev/null +++ b/src/duckdb/src/storage/compression/fsst.cpp @@ -0,0 +1,752 @@ +#include "duckdb/common/bitpacking.hpp" +#include "duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp" +#include "duckdb/storage/string_uncompressed.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/random_engine.hpp" +#include "duckdb/common/fsst.hpp" +#include "miniz_wrapper.hpp" +#include "fsst.h" + +namespace duckdb { + +typedef struct { + uint32_t dict_size; + uint32_t dict_end; + uint32_t bitpacking_width; + uint32_t fsst_symbol_table_offset; +} fsst_compression_header_t; + +// Counts and offsets used during scanning/fetching +// | ColumnSegment to be scanned / fetched from | +// | untouched | bp align | unused d-values | to scan | bp align | untouched | +typedef struct BPDeltaDecodeOffsets { + idx_t delta_decode_start_row; // X + idx_t bitunpack_alignment_offset; // <---------> + idx_t bitunpack_start_row; // X + idx_t unused_delta_decoded_values; // <-----------------> + idx_t scan_offset; // <----------------------------> + idx_t total_delta_decode_count; // <--------------------------> + idx_t total_bitunpack_count; // <------------------------------------------------> +} bp_delta_offsets_t; + +struct FSSTStorage { + static constexpr size_t COMPACTION_FLUSH_LIMIT = (size_t)Storage::BLOCK_SIZE / 5 * 4; + static constexpr double MINIMUM_COMPRESSION_RATIO = 1.2; + static constexpr double ANALYSIS_SAMPLE_SIZE = 0.25; + + static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); + static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static idx_t StringFinalAnalyze(AnalyzeState &state_p); + + static unique_ptr InitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr analyze_state_p); + static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); + static void FinalizeCompress(CompressionState &state_p); + + static unique_ptr StringInitScan(ColumnSegment &segment); + template + static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset); + static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); + static void StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx); + + static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer container); + static StringDictionaryContainer GetDictionary(ColumnSegment &segment, BufferHandle &handle); + + static char *FetchStringPointer(StringDictionaryContainer dict, data_ptr_t baseptr, int32_t dict_offset); + static bp_delta_offsets_t CalculateBpDeltaOffsets(int64_t last_known_row, idx_t start, idx_t scan_count); + static bool ParseFSSTSegmentHeader(data_ptr_t base_ptr, duckdb_fsst_decoder_t *decoder_out, + bitpacking_width_t *width_out); +}; + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +struct FSSTAnalyzeState : public AnalyzeState { + FSSTAnalyzeState() : count(0), fsst_string_total_size(0), empty_strings(0) { + } + + ~FSSTAnalyzeState() override { + if (fsst_encoder) { + duckdb_fsst_destroy(fsst_encoder); + } + } + + duckdb_fsst_encoder_t *fsst_encoder = nullptr; + idx_t count; + + StringHeap fsst_string_heap; + vector fsst_strings; + size_t fsst_string_total_size; + + RandomEngine random_engine; + bool have_valid_row = false; + + idx_t empty_strings; +}; + +unique_ptr FSSTStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq(); +} + +bool FSSTStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { + auto &state = state_p.Cast(); + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + + state.count += count; + auto data = UnifiedVectorFormat::GetData(vdata); + + // Note that we ignore the sampling in case we have not found any valid strings yet, this solves the issue of + // not having seen any valid strings here leading to an empty fsst symbol table. + bool sample_selected = !state.have_valid_row || state.random_engine.NextRandom() < ANALYSIS_SAMPLE_SIZE; + + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + + if (!vdata.validity.RowIsValid(idx)) { + continue; + } + + // We need to check all strings for this, otherwise we run in to trouble during compression if we miss ones + auto string_size = data[idx].GetSize(); + if (string_size >= StringUncompressed::STRING_BLOCK_LIMIT) { + return false; + } + + if (!sample_selected) { + continue; + } + + if (string_size > 0) { + state.have_valid_row = true; + if (data[idx].IsInlined()) { + state.fsst_strings.push_back(data[idx]); + } else { + state.fsst_strings.emplace_back(state.fsst_string_heap.AddBlob(data[idx])); + } + state.fsst_string_total_size += string_size; + } else { + state.empty_strings++; + } + } + return true; +} + +idx_t FSSTStorage::StringFinalAnalyze(AnalyzeState &state_p) { + auto &state = state_p.Cast(); + + size_t compressed_dict_size = 0; + size_t max_compressed_string_length = 0; + + auto string_count = state.fsst_strings.size(); + + if (!string_count) { + return DConstants::INVALID_INDEX; + } + + size_t output_buffer_size = 7 + 2 * state.fsst_string_total_size; // size as specified in fsst.h + + vector fsst_string_sizes; + vector fsst_string_ptrs; + for (auto &str : state.fsst_strings) { + fsst_string_sizes.push_back(str.GetSize()); + fsst_string_ptrs.push_back((unsigned char *)str.GetData()); // NOLINT + } + + state.fsst_encoder = duckdb_fsst_create(string_count, &fsst_string_sizes[0], &fsst_string_ptrs[0], 0); + + // TODO: do we really need to encode to get a size estimate? + auto compressed_ptrs = vector(string_count, nullptr); + auto compressed_sizes = vector(string_count, 0); + unique_ptr compressed_buffer(new unsigned char[output_buffer_size]); + + auto res = + duckdb_fsst_compress(state.fsst_encoder, string_count, &fsst_string_sizes[0], &fsst_string_ptrs[0], + output_buffer_size, compressed_buffer.get(), &compressed_sizes[0], &compressed_ptrs[0]); + + if (string_count != res) { + throw std::runtime_error("FSST output buffer is too small unexpectedly"); + } + + // Sum and and Max compressed lengths + for (auto &size : compressed_sizes) { + compressed_dict_size += size; + max_compressed_string_length = MaxValue(max_compressed_string_length, size); + } + D_ASSERT(compressed_dict_size == (compressed_ptrs[res - 1] - compressed_ptrs[0]) + compressed_sizes[res - 1]); + + auto minimum_width = BitpackingPrimitives::MinimumBitWidth(max_compressed_string_length); + auto bitpacked_offsets_size = + BitpackingPrimitives::GetRequiredSize(string_count + state.empty_strings, minimum_width); + + auto estimated_base_size = (bitpacked_offsets_size + compressed_dict_size) * (1 / ANALYSIS_SAMPLE_SIZE); + auto num_blocks = estimated_base_size / (Storage::BLOCK_SIZE - sizeof(duckdb_fsst_decoder_t)); + auto symtable_size = num_blocks * sizeof(duckdb_fsst_decoder_t); + + auto estimated_size = estimated_base_size + symtable_size; + + return estimated_size * MINIMUM_COMPRESSION_RATIO; +} + +//===--------------------------------------------------------------------===// +// Compress +//===--------------------------------------------------------------------===// + +class FSSTCompressionState : public CompressionState { +public: + explicit FSSTCompressionState(ColumnDataCheckpointer &checkpointer) + : checkpointer(checkpointer), function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_FSST)) { + CreateEmptySegment(checkpointer.GetRowGroup().start); + } + + ~FSSTCompressionState() override { + if (fsst_encoder) { + duckdb_fsst_destroy(fsst_encoder); + } + } + + void Reset() { + index_buffer.clear(); + current_width = 0; + max_compressed_string_length = 0; + last_fitting_size = 0; + + // Reset the pointers into the current segment + auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); + current_handle = buffer_manager.Pin(current_segment->block); + current_dictionary = FSSTStorage::GetDictionary(*current_segment, current_handle); + current_end_ptr = current_handle.Ptr() + current_dictionary.end; + } + + void CreateEmptySegment(idx_t row_start) { + auto &db = checkpointer.GetDatabase(); + auto &type = checkpointer.GetType(); + auto compressed_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); + current_segment = std::move(compressed_segment); + current_segment->function = function; + Reset(); + } + + void UpdateState(string_t uncompressed_string, unsigned char *compressed_string, size_t compressed_string_len) { + if (!HasEnoughSpace(compressed_string_len)) { + Flush(); + if (!HasEnoughSpace(compressed_string_len)) { + throw InternalException("FSST string compression failed due to insufficient space in empty block"); + }; + } + + UncompressedStringStorage::UpdateStringStats(current_segment->stats, uncompressed_string); + + // Write string into dictionary + current_dictionary.size += compressed_string_len; + auto dict_pos = current_end_ptr - current_dictionary.size; + memcpy(dict_pos, compressed_string, compressed_string_len); + current_dictionary.Verify(); + + // We just push the string length to effectively delta encode the strings + index_buffer.push_back(compressed_string_len); + + max_compressed_string_length = MaxValue(max_compressed_string_length, compressed_string_len); + + current_width = BitpackingPrimitives::MinimumBitWidth(max_compressed_string_length); + current_segment->count++; + } + + void AddNull() { + if (!HasEnoughSpace(0)) { + Flush(); + if (!HasEnoughSpace(0)) { + throw InternalException("FSST string compression failed due to insufficient space in empty block"); + }; + } + index_buffer.push_back(0); + current_segment->count++; + } + + void AddEmptyString() { + AddNull(); + UncompressedStringStorage::UpdateStringStats(current_segment->stats, ""); + } + + size_t GetRequiredSize(size_t string_len) { + bitpacking_width_t required_minimum_width; + if (string_len > max_compressed_string_length) { + required_minimum_width = BitpackingPrimitives::MinimumBitWidth(string_len); + } else { + required_minimum_width = current_width; + } + + size_t current_dict_size = current_dictionary.size; + idx_t current_string_count = index_buffer.size(); + + size_t dict_offsets_size = + BitpackingPrimitives::GetRequiredSize(current_string_count + 1, required_minimum_width); + + // TODO switch to a symbol table per RowGroup, saves a bit of space + return sizeof(fsst_compression_header_t) + current_dict_size + dict_offsets_size + string_len + + fsst_serialized_symbol_table_size; + } + + // Checks if there is enough space, if there is, sets last_fitting_size + bool HasEnoughSpace(size_t string_len) { + auto required_size = GetRequiredSize(string_len); + + if (required_size <= Storage::BLOCK_SIZE) { + last_fitting_size = required_size; + return true; + } + return false; + } + + void Flush(bool final = false) { + auto next_start = current_segment->start + current_segment->count; + + auto segment_size = Finalize(); + auto &state = checkpointer.GetCheckpointState(); + state.FlushSegment(std::move(current_segment), segment_size); + + if (!final) { + CreateEmptySegment(next_start); + } + } + + idx_t Finalize() { + auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); + auto handle = buffer_manager.Pin(current_segment->block); + D_ASSERT(current_dictionary.end == Storage::BLOCK_SIZE); + + // calculate sizes + auto compressed_index_buffer_size = + BitpackingPrimitives::GetRequiredSize(current_segment->count, current_width); + auto total_size = sizeof(fsst_compression_header_t) + compressed_index_buffer_size + current_dictionary.size + + fsst_serialized_symbol_table_size; + + if (total_size != last_fitting_size) { + throw InternalException("FSST string compression failed due to incorrect size calculation"); + } + + // calculate ptr and offsets + auto base_ptr = handle.Ptr(); + auto header_ptr = reinterpret_cast(base_ptr); + auto compressed_index_buffer_offset = sizeof(fsst_compression_header_t); + auto symbol_table_offset = compressed_index_buffer_offset + compressed_index_buffer_size; + + D_ASSERT(current_segment->count == index_buffer.size()); + BitpackingPrimitives::PackBuffer(base_ptr + compressed_index_buffer_offset, + reinterpret_cast(index_buffer.data()), + current_segment->count, current_width); + + // Write the fsst symbol table or nothing + if (fsst_encoder != nullptr) { + memcpy(base_ptr + symbol_table_offset, &fsst_serialized_symbol_table[0], fsst_serialized_symbol_table_size); + } else { + memset(base_ptr + symbol_table_offset, 0, fsst_serialized_symbol_table_size); + } + + Store(symbol_table_offset, data_ptr_cast(&header_ptr->fsst_symbol_table_offset)); + Store((uint32_t)current_width, data_ptr_cast(&header_ptr->bitpacking_width)); + + if (total_size >= FSSTStorage::COMPACTION_FLUSH_LIMIT) { + // the block is full enough, don't bother moving around the dictionary + return Storage::BLOCK_SIZE; + } + // the block has space left: figure out how much space we can save + auto move_amount = Storage::BLOCK_SIZE - total_size; + // move the dictionary so it lines up exactly with the offsets + auto new_dictionary_offset = symbol_table_offset + fsst_serialized_symbol_table_size; + memmove(base_ptr + new_dictionary_offset, base_ptr + current_dictionary.end - current_dictionary.size, + current_dictionary.size); + current_dictionary.end -= move_amount; + D_ASSERT(current_dictionary.end == total_size); + // write the new dictionary (with the updated "end") + FSSTStorage::SetDictionary(*current_segment, handle, current_dictionary); + + return total_size; + } + + ColumnDataCheckpointer &checkpointer; + CompressionFunction &function; + + // State regarding current segment + unique_ptr current_segment; + BufferHandle current_handle; + StringDictionaryContainer current_dictionary; + data_ptr_t current_end_ptr; + + // Buffers and map for current segment + vector index_buffer; + + size_t max_compressed_string_length; + bitpacking_width_t current_width; + idx_t last_fitting_size; + + duckdb_fsst_encoder_t *fsst_encoder = nullptr; + unsigned char fsst_serialized_symbol_table[sizeof(duckdb_fsst_decoder_t)]; + size_t fsst_serialized_symbol_table_size = sizeof(duckdb_fsst_decoder_t); +}; + +unique_ptr FSSTStorage::InitCompression(ColumnDataCheckpointer &checkpointer, + unique_ptr analyze_state_p) { + auto analyze_state = static_cast(analyze_state_p.get()); + auto compression_state = make_uniq(checkpointer); + + if (analyze_state->fsst_encoder == nullptr) { + throw InternalException("No encoder found during FSST compression"); + } + + compression_state->fsst_encoder = analyze_state->fsst_encoder; + compression_state->fsst_serialized_symbol_table_size = + duckdb_fsst_export(compression_state->fsst_encoder, &compression_state->fsst_serialized_symbol_table[0]); + analyze_state->fsst_encoder = nullptr; + + return std::move(compression_state); +} + +void FSSTStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { + auto &state = state_p.Cast(); + + // Get vector data + UnifiedVectorFormat vdata; + scan_vector.ToUnifiedFormat(count, vdata); + auto data = UnifiedVectorFormat::GetData(vdata); + + // Collect pointers to strings to compress + vector sizes_in; + vector strings_in; + size_t total_size = 0; + idx_t total_count = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + + // Note: we treat nulls and empty strings the same + if (!vdata.validity.RowIsValid(idx) || data[idx].GetSize() == 0) { + continue; + } + + total_count++; + total_size += data[idx].GetSize(); + sizes_in.push_back(data[idx].GetSize()); + strings_in.push_back((unsigned char *)data[idx].GetData()); // NOLINT + } + + // Only Nulls or empty strings in this vector, nothing to compress + if (total_count == 0) { + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(idx)) { + state.AddNull(); + } else if (data[idx].GetSize() == 0) { + state.AddEmptyString(); + } else { + throw FatalException("FSST: no encoder found even though there are values to encode"); + } + } + return; + } + + // Compress buffers + size_t compress_buffer_size = MaxValue(total_size * 2 + 7, 1); + vector strings_out(total_count, nullptr); + vector sizes_out(total_count, 0); + vector compress_buffer(compress_buffer_size, 0); + + auto res = duckdb_fsst_compress( + state.fsst_encoder, /* IN: encoder obtained from duckdb_fsst_create(). */ + total_count, /* IN: number of strings in batch to compress. */ + &sizes_in[0], /* IN: byte-lengths of the inputs */ + &strings_in[0], /* IN: input string start pointers. */ + compress_buffer_size, /* IN: byte-length of output buffer. */ + &compress_buffer[0], /* OUT: memory buffer to put the compressed strings in (one after the other). */ + &sizes_out[0], /* OUT: byte-lengths of the compressed strings. */ + &strings_out[0] /* OUT: output string start pointers. Will all point into [output,output+size). */ + ); + + if (res != total_count) { + throw FatalException("FSST compression failed to compress all strings"); + } + + // Push the compressed strings to the compression state one by one + idx_t compressed_idx = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + if (!vdata.validity.RowIsValid(idx)) { + state.AddNull(); + } else if (data[idx].GetSize() == 0) { + state.AddEmptyString(); + } else { + state.UpdateState(data[idx], strings_out[compressed_idx], sizes_out[compressed_idx]); + compressed_idx++; + } + } +} + +void FSSTStorage::FinalizeCompress(CompressionState &state_p) { + auto &state = state_p.Cast(); + state.Flush(true); +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +struct FSSTScanState : public StringScanState { + FSSTScanState() { + ResetStoredDelta(); + } + + buffer_ptr duckdb_fsst_decoder; + bitpacking_width_t current_width; + + // To speed up delta decoding we store the last index + uint32_t last_known_index; + int64_t last_known_row; + + void StoreLastDelta(uint32_t value, int64_t row) { + last_known_index = value; + last_known_row = row; + } + void ResetStoredDelta() { + last_known_index = 0; + last_known_row = -1; + } +}; + +unique_ptr FSSTStorage::StringInitScan(ColumnSegment &segment) { + auto state = make_uniq(); + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + state->handle = buffer_manager.Pin(segment.block); + auto base_ptr = state->handle.Ptr() + segment.GetBlockOffset(); + + state->duckdb_fsst_decoder = make_buffer(); + auto retval = ParseFSSTSegmentHeader( + base_ptr, reinterpret_cast(state->duckdb_fsst_decoder.get()), &state->current_width); + if (!retval) { + state->duckdb_fsst_decoder = nullptr; + } + + return std::move(state); +} + +void DeltaDecodeIndices(uint32_t *buffer_in, uint32_t *buffer_out, idx_t decode_count, uint32_t last_known_value) { + buffer_out[0] = buffer_in[0]; + buffer_out[0] += last_known_value; + for (idx_t i = 1; i < decode_count; i++) { + buffer_out[i] = buffer_in[i] + buffer_out[i - 1]; + } +} + +void BitUnpackRange(data_ptr_t src_ptr, data_ptr_t dst_ptr, idx_t count, idx_t row, bitpacking_width_t width) { + auto bitunpack_src_ptr = &src_ptr[(row * width) / 8]; + BitpackingPrimitives::UnPackBuffer(dst_ptr, bitunpack_src_ptr, count, width); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +template +void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + + auto &scan_state = state.scan_state->Cast(); + auto start = segment.GetRelativeIndex(state.row_index); + + bool enable_fsst_vectors; + if (ALLOW_FSST_VECTORS) { + auto &config = DBConfig::GetConfig(segment.db); + enable_fsst_vectors = config.options.enable_fsst_vectors; + } else { + enable_fsst_vectors = false; + } + + auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); + auto dict = GetDictionary(segment, scan_state.handle); + auto base_data = data_ptr_cast(baseptr + sizeof(fsst_compression_header_t)); + string_t *result_data; + + if (scan_count == 0) { + return; + } + + if (enable_fsst_vectors) { + D_ASSERT(result_offset == 0); + if (scan_state.duckdb_fsst_decoder) { + D_ASSERT(result_offset == 0 || result.GetVectorType() == VectorType::FSST_VECTOR); + result.SetVectorType(VectorType::FSST_VECTOR); + FSSTVector::RegisterDecoder(result, scan_state.duckdb_fsst_decoder); + result_data = FSSTVector::GetCompressedData(result); + } else { + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + result_data = FlatVector::GetData(result); + } + } else { + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + result_data = FlatVector::GetData(result); + } + + if (start == 0 || scan_state.last_known_row >= (int64_t)start) { + scan_state.ResetStoredDelta(); + } + + auto offsets = CalculateBpDeltaOffsets(scan_state.last_known_row, start, scan_count); + + auto bitunpack_buffer = unique_ptr(new uint32_t[offsets.total_bitunpack_count]); + BitUnpackRange(base_data, data_ptr_cast(bitunpack_buffer.get()), offsets.total_bitunpack_count, + offsets.bitunpack_start_row, scan_state.current_width); + auto delta_decode_buffer = unique_ptr(new uint32_t[offsets.total_delta_decode_count]); + DeltaDecodeIndices(bitunpack_buffer.get() + offsets.bitunpack_alignment_offset, delta_decode_buffer.get(), + offsets.total_delta_decode_count, scan_state.last_known_index); + + if (enable_fsst_vectors) { + // Lookup decompressed offsets in dict + for (idx_t i = 0; i < scan_count; i++) { + uint32_t string_length = bitunpack_buffer[i + offsets.scan_offset]; + result_data[i] = UncompressedStringStorage::FetchStringFromDict( + segment, dict, result, baseptr, delta_decode_buffer[i + offsets.unused_delta_decoded_values], + string_length); + FSSTVector::SetCount(result, scan_count); + } + } else { + // Just decompress + for (idx_t i = 0; i < scan_count; i++) { + uint32_t str_len = bitunpack_buffer[i + offsets.scan_offset]; + auto str_ptr = FSSTStorage::FetchStringPointer( + dict, baseptr, delta_decode_buffer[i + offsets.unused_delta_decoded_values]); + + if (str_len > 0) { + result_data[i + result_offset] = + FSSTPrimitives::DecompressValue(scan_state.duckdb_fsst_decoder.get(), result, str_ptr, str_len); + } else { + result_data[i + result_offset] = string_t(nullptr, 0); + } + } + } + + scan_state.StoreLastDelta(delta_decode_buffer[scan_count + offsets.unused_delta_decoded_values - 1], + start + scan_count - 1); +} + +void FSSTStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + StringScanPartial(segment, state, scan_count, result, 0); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +void FSSTStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + auto base_ptr = handle.Ptr() + segment.GetBlockOffset(); + auto base_data = data_ptr_cast(base_ptr + sizeof(fsst_compression_header_t)); + auto dict = GetDictionary(segment, handle); + + duckdb_fsst_decoder_t decoder; + bitpacking_width_t width; + auto have_symbol_table = ParseFSSTSegmentHeader(base_ptr, &decoder, &width); + + auto result_data = FlatVector::GetData(result); + + if (have_symbol_table) { + // We basically just do a scan of 1 which is kinda expensive as we need to repeatedly delta decode until we + // reach the row we want, we could consider a more clever caching trick if this is slow + auto offsets = CalculateBpDeltaOffsets(-1, row_id, 1); + + auto bitunpack_buffer = unique_ptr(new uint32_t[offsets.total_bitunpack_count]); + BitUnpackRange(base_data, data_ptr_cast(bitunpack_buffer.get()), offsets.total_bitunpack_count, + offsets.bitunpack_start_row, width); + auto delta_decode_buffer = unique_ptr(new uint32_t[offsets.total_delta_decode_count]); + DeltaDecodeIndices(bitunpack_buffer.get() + offsets.bitunpack_alignment_offset, delta_decode_buffer.get(), + offsets.total_delta_decode_count, 0); + + uint32_t string_length = bitunpack_buffer[offsets.scan_offset]; + + string_t compressed_string = UncompressedStringStorage::FetchStringFromDict( + segment, dict, result, base_ptr, delta_decode_buffer[offsets.unused_delta_decoded_values], string_length); + + result_data[result_idx] = FSSTPrimitives::DecompressValue((void *)&decoder, result, compressed_string.GetData(), + compressed_string.GetSize()); + } else { + // There's no fsst symtable, this only happens for empty strings or nulls, we can just emit an empty string + result_data[result_idx] = string_t(nullptr, 0); + } +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +CompressionFunction FSSTFun::GetFunction(PhysicalType data_type) { + D_ASSERT(data_type == PhysicalType::VARCHAR); + return CompressionFunction( + CompressionType::COMPRESSION_FSST, data_type, FSSTStorage::StringInitAnalyze, FSSTStorage::StringAnalyze, + FSSTStorage::StringFinalAnalyze, FSSTStorage::InitCompression, FSSTStorage::Compress, + FSSTStorage::FinalizeCompress, FSSTStorage::StringInitScan, FSSTStorage::StringScan, + FSSTStorage::StringScanPartial, FSSTStorage::StringFetchRow, UncompressedFunctions::EmptySkip); +} + +bool FSSTFun::TypeIsSupported(PhysicalType type) { + return type == PhysicalType::VARCHAR; +} + +//===--------------------------------------------------------------------===// +// Helper Functions +//===--------------------------------------------------------------------===// +void FSSTStorage::SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer container) { + auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); + Store(container.size, data_ptr_cast(&header_ptr->dict_size)); + Store(container.end, data_ptr_cast(&header_ptr->dict_end)); +} + +StringDictionaryContainer FSSTStorage::GetDictionary(ColumnSegment &segment, BufferHandle &handle) { + auto header_ptr = reinterpret_cast(handle.Ptr() + segment.GetBlockOffset()); + StringDictionaryContainer container; + container.size = Load(data_ptr_cast(&header_ptr->dict_size)); + container.end = Load(data_ptr_cast(&header_ptr->dict_end)); + return container; +} + +char *FSSTStorage::FetchStringPointer(StringDictionaryContainer dict, data_ptr_t baseptr, int32_t dict_offset) { + if (dict_offset == 0) { + return nullptr; + } + + auto dict_end = baseptr + dict.end; + auto dict_pos = dict_end - dict_offset; + return char_ptr_cast(dict_pos); +} + +// Returns false if no symbol table was found. This means all strings are either empty or null +bool FSSTStorage::ParseFSSTSegmentHeader(data_ptr_t base_ptr, duckdb_fsst_decoder_t *decoder_out, + bitpacking_width_t *width_out) { + auto header_ptr = reinterpret_cast(base_ptr); + auto fsst_symbol_table_offset = Load(data_ptr_cast(&header_ptr->fsst_symbol_table_offset)); + *width_out = (bitpacking_width_t)(Load(data_ptr_cast(&header_ptr->bitpacking_width))); + return duckdb_fsst_import(decoder_out, base_ptr + fsst_symbol_table_offset); +} + +// The calculation of offsets and counts while scanning or fetching is a bit tricky, for two reasons: +// - bitunpacking needs to be aligned to BITPACKING_ALGORITHM_GROUP_SIZE +// - delta decoding needs to decode from the last known value. +bp_delta_offsets_t FSSTStorage::CalculateBpDeltaOffsets(int64_t last_known_row, idx_t start, idx_t scan_count) { + D_ASSERT((idx_t)(last_known_row + 1) <= start); + bp_delta_offsets_t result; + + result.delta_decode_start_row = (idx_t)(last_known_row + 1); + result.bitunpack_alignment_offset = + result.delta_decode_start_row % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; + result.bitunpack_start_row = result.delta_decode_start_row - result.bitunpack_alignment_offset; + result.unused_delta_decoded_values = start - result.delta_decode_start_row; + result.scan_offset = result.bitunpack_alignment_offset + result.unused_delta_decoded_values; + result.total_delta_decode_count = scan_count + result.unused_delta_decoded_values; + result.total_bitunpack_count = + BitpackingPrimitives::RoundUpToAlgorithmGroupSize(scan_count + result.scan_offset); + + D_ASSERT(result.total_delta_decode_count + result.bitunpack_alignment_offset <= result.total_bitunpack_count); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/numeric_constant.cpp b/src/duckdb/src/storage/compression/numeric_constant.cpp new file mode 100644 index 00000000..bb18e1b8 --- /dev/null +++ b/src/duckdb/src/storage/compression/numeric_constant.cpp @@ -0,0 +1,161 @@ +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/common/types/vector.hpp" + +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/segment/uncompressed.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +unique_ptr ConstantInitScan(ColumnSegment &segment) { + return nullptr; +} + +//===--------------------------------------------------------------------===// +// Scan Partial +//===--------------------------------------------------------------------===// +void ConstantFillFunctionValidity(ColumnSegment &segment, Vector &result, idx_t start_idx, idx_t count) { + auto &stats = segment.stats.statistics; + if (stats.CanHaveNull()) { + auto &mask = FlatVector::Validity(result); + for (idx_t i = 0; i < count; i++) { + mask.SetInvalid(start_idx + i); + } + } +} + +template +void ConstantFillFunction(ColumnSegment &segment, Vector &result, idx_t start_idx, idx_t count) { + auto &nstats = segment.stats.statistics; + + auto data = FlatVector::GetData(result); + auto constant_value = NumericStats::GetMin(nstats); + for (idx_t i = 0; i < count; i++) { + data[start_idx + i] = constant_value; + } +} + +void ConstantScanPartialValidity(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + ConstantFillFunctionValidity(segment, result, result_offset, scan_count); +} + +template +void ConstantScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + ConstantFillFunction(segment, result, result_offset, scan_count); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +void ConstantScanFunctionValidity(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + auto &stats = segment.stats.statistics; + if (stats.CanHaveNull()) { + if (result.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + result.Flatten(scan_count); + ConstantFillFunctionValidity(segment, result, 0, scan_count); + } + } +} + +template +void ConstantScanFunction(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + auto &nstats = segment.stats.statistics; + + auto data = FlatVector::GetData(result); + data[0] = NumericStats::GetMin(nstats); + result.SetVectorType(VectorType::CONSTANT_VECTOR); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +void ConstantFetchRowValidity(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + ConstantFillFunctionValidity(segment, result, result_idx, 1); +} + +template +void ConstantFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { + ConstantFillFunction(segment, result, result_idx, 1); +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +CompressionFunction ConstantGetFunctionValidity(PhysicalType data_type) { + D_ASSERT(data_type == PhysicalType::BIT); + return CompressionFunction(CompressionType::COMPRESSION_CONSTANT, data_type, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, ConstantInitScan, ConstantScanFunctionValidity, + ConstantScanPartialValidity, ConstantFetchRowValidity, UncompressedFunctions::EmptySkip); +} + +template +CompressionFunction ConstantGetFunction(PhysicalType data_type) { + return CompressionFunction(CompressionType::COMPRESSION_CONSTANT, data_type, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, ConstantInitScan, ConstantScanFunction, ConstantScanPartial, + ConstantFetchRow, UncompressedFunctions::EmptySkip); +} + +CompressionFunction ConstantFun::GetFunction(PhysicalType data_type) { + switch (data_type) { + case PhysicalType::BIT: + return ConstantGetFunctionValidity(data_type); + case PhysicalType::BOOL: + case PhysicalType::INT8: + return ConstantGetFunction(data_type); + case PhysicalType::INT16: + return ConstantGetFunction(data_type); + case PhysicalType::INT32: + return ConstantGetFunction(data_type); + case PhysicalType::INT64: + return ConstantGetFunction(data_type); + case PhysicalType::UINT8: + return ConstantGetFunction(data_type); + case PhysicalType::UINT16: + return ConstantGetFunction(data_type); + case PhysicalType::UINT32: + return ConstantGetFunction(data_type); + case PhysicalType::UINT64: + return ConstantGetFunction(data_type); + case PhysicalType::INT128: + return ConstantGetFunction(data_type); + case PhysicalType::FLOAT: + return ConstantGetFunction(data_type); + case PhysicalType::DOUBLE: + return ConstantGetFunction(data_type); + default: + throw InternalException("Unsupported type for ConstantUncompressed::GetFunction"); + } +} + +bool ConstantFun::TypeIsSupported(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + case PhysicalType::BOOL: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::INT128: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return true; + default: + throw InternalException("Unsupported type for constant function"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/patas.cpp b/src/duckdb/src/storage/compression/patas.cpp new file mode 100644 index 00000000..8df7f5c5 --- /dev/null +++ b/src/duckdb/src/storage/compression/patas.cpp @@ -0,0 +1,64 @@ +#include "duckdb/storage/compression/patas/patas.hpp" +#include "duckdb/storage/compression/patas/patas_compress.hpp" +#include "duckdb/storage/compression/patas/patas_scan.hpp" +#include "duckdb/storage/compression/patas/patas_fetch.hpp" +#include "duckdb/storage/compression/patas/patas_analyze.hpp" + +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/operator/subtract.hpp" + +#include + +namespace duckdb { + +template +CompressionFunction GetPatasFunction(PhysicalType data_type) { + throw NotImplementedException("GetPatasFunction not implemented for the given datatype"); +} + +template <> +CompressionFunction GetPatasFunction(PhysicalType data_type) { + return CompressionFunction(CompressionType::COMPRESSION_PATAS, data_type, PatasInitAnalyze, + PatasAnalyze, PatasFinalAnalyze, PatasInitCompression, + PatasCompress, PatasFinalizeCompress, PatasInitScan, + PatasScan, PatasScanPartial, PatasFetchRow, PatasSkip); +} + +template <> +CompressionFunction GetPatasFunction(PhysicalType data_type) { + return CompressionFunction(CompressionType::COMPRESSION_PATAS, data_type, PatasInitAnalyze, + PatasAnalyze, PatasFinalAnalyze, PatasInitCompression, + PatasCompress, PatasFinalizeCompress, PatasInitScan, + PatasScan, PatasScanPartial, PatasFetchRow, PatasSkip); +} + +CompressionFunction PatasCompressionFun::GetFunction(PhysicalType type) { + switch (type) { + case PhysicalType::FLOAT: + return GetPatasFunction(type); + case PhysicalType::DOUBLE: + return GetPatasFunction(type); + default: + throw InternalException("Unsupported type for Patas"); + } +} + +bool PatasCompressionFun::TypeIsSupported(PhysicalType type) { + switch (type) { + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return true; + default: + return false; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/rle.cpp b/src/duckdb/src/storage/compression/rle.cpp new file mode 100644 index 00000000..a7944091 --- /dev/null +++ b/src/duckdb/src/storage/compression/rle.cpp @@ -0,0 +1,454 @@ +#include "duckdb/function/compression/compression.hpp" + +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include + +namespace duckdb { + +using rle_count_t = uint16_t; + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +struct EmptyRLEWriter { + template + static void Operation(VALUE_TYPE value, rle_count_t count, void *dataptr, bool is_null) { + } +}; + +template +struct RLEState { + RLEState() : seen_count(0), last_value(NullValue()), last_seen_count(0), dataptr(nullptr) { + } + + idx_t seen_count; + T last_value; + rle_count_t last_seen_count; + void *dataptr; + bool all_null = true; + +public: + template + void Flush() { + OP::template Operation(last_value, last_seen_count, dataptr, all_null); + } + + template + void Update(const T *data, ValidityMask &validity, idx_t idx) { + if (validity.RowIsValid(idx)) { + if (all_null) { + // no value seen yet + // assign the current value, and increment the seen_count + // note that we increment last_seen_count rather than setting it to 1 + // this is intentional: this is the first VALID value we see + // but it might not be the first value in case of nulls! + last_value = data[idx]; + seen_count++; + last_seen_count++; + all_null = false; + } else if (last_value == data[idx]) { + // the last value is identical to this value: increment the last_seen_count + last_seen_count++; + } else { + // the values are different + // issue the callback on the last value + Flush(); + + // increment the seen_count and put the new value into the RLE slot + last_value = data[idx]; + seen_count++; + last_seen_count = 1; + } + } else { + // NULL value: we merely increment the last_seen_count + last_seen_count++; + } + if (last_seen_count == NumericLimits::Maximum()) { + // we have seen the same value so many times in a row we are at the limit of what fits in our count + // write away the value and move to the next value + Flush(); + last_seen_count = 0; + seen_count++; + } + } +}; + +template +struct RLEAnalyzeState : public AnalyzeState { + RLEAnalyzeState() { + } + + RLEState state; +}; + +template +unique_ptr RLEInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq>(); +} + +template +bool RLEAnalyze(AnalyzeState &state, Vector &input, idx_t count) { + auto &rle_state = state.template Cast>(); + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + rle_state.state.Update(data, vdata.validity, idx); + } + return true; +} + +template +idx_t RLEFinalAnalyze(AnalyzeState &state) { + auto &rle_state = state.template Cast>(); + return (sizeof(rle_count_t) + sizeof(T)) * rle_state.state.seen_count; +} + +//===--------------------------------------------------------------------===// +// Compress +//===--------------------------------------------------------------------===// +struct RLEConstants { + static constexpr const idx_t RLE_HEADER_SIZE = sizeof(uint64_t); +}; + +template +struct RLECompressState : public CompressionState { + struct RLEWriter { + template + static void Operation(VALUE_TYPE value, rle_count_t count, void *dataptr, bool is_null) { + auto state = reinterpret_cast *>(dataptr); + state->WriteValue(value, count, is_null); + } + }; + + static idx_t MaxRLECount() { + auto entry_size = sizeof(T) + sizeof(rle_count_t); + auto entry_count = (Storage::BLOCK_SIZE - RLEConstants::RLE_HEADER_SIZE) / entry_size; + auto max_vector_count = entry_count / STANDARD_VECTOR_SIZE; + return max_vector_count * STANDARD_VECTOR_SIZE; + } + + explicit RLECompressState(ColumnDataCheckpointer &checkpointer_p) + : checkpointer(checkpointer_p), + function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_RLE)) { + CreateEmptySegment(checkpointer.GetRowGroup().start); + + state.dataptr = (void *)this; + max_rle_count = MaxRLECount(); + } + + void CreateEmptySegment(idx_t row_start) { + auto &db = checkpointer.GetDatabase(); + auto &type = checkpointer.GetType(); + auto column_segment = ColumnSegment::CreateTransientSegment(db, type, row_start); + column_segment->function = function; + current_segment = std::move(column_segment); + auto &buffer_manager = BufferManager::GetBufferManager(db); + handle = buffer_manager.Pin(current_segment->block); + } + + void Append(UnifiedVectorFormat &vdata, idx_t count) { + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + state.template Update::RLEWriter>(data, vdata.validity, idx); + } + } + + void WriteValue(T value, rle_count_t count, bool is_null) { + // write the RLE entry + auto handle_ptr = handle.Ptr() + RLEConstants::RLE_HEADER_SIZE; + auto data_pointer = reinterpret_cast(handle_ptr); + auto index_pointer = reinterpret_cast(handle_ptr + max_rle_count * sizeof(T)); + data_pointer[entry_count] = value; + index_pointer[entry_count] = count; + entry_count++; + + // update meta data + if (WRITE_STATISTICS && !is_null) { + NumericStats::Update(current_segment->stats.statistics, value); + } + current_segment->count += count; + + if (entry_count == max_rle_count) { + // we have finished writing this segment: flush it and create a new segment + auto row_start = current_segment->start + current_segment->count; + FlushSegment(); + CreateEmptySegment(row_start); + entry_count = 0; + } + } + + void FlushSegment() { + // flush the segment + // we compact the segment by moving the counts so they are directly next to the values + idx_t counts_size = sizeof(rle_count_t) * entry_count; + idx_t original_rle_offset = RLEConstants::RLE_HEADER_SIZE + max_rle_count * sizeof(T); + idx_t minimal_rle_offset = AlignValue(RLEConstants::RLE_HEADER_SIZE + sizeof(T) * entry_count); + idx_t total_segment_size = minimal_rle_offset + counts_size; + auto data_ptr = handle.Ptr(); + memmove(data_ptr + minimal_rle_offset, data_ptr + original_rle_offset, counts_size); + // store the final RLE offset within the segment + Store(minimal_rle_offset, data_ptr); + handle.Destroy(); + + auto &state = checkpointer.GetCheckpointState(); + state.FlushSegment(std::move(current_segment), total_segment_size); + } + + void Finalize() { + state.template Flush::RLEWriter>(); + + FlushSegment(); + current_segment.reset(); + } + + ColumnDataCheckpointer &checkpointer; + CompressionFunction &function; + unique_ptr current_segment; + BufferHandle handle; + + RLEState state; + idx_t entry_count = 0; + idx_t max_rle_count; +}; + +template +unique_ptr RLEInitCompression(ColumnDataCheckpointer &checkpointer, unique_ptr state) { + return make_uniq>(checkpointer); +} + +template +void RLECompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { + auto &state = (RLECompressState &)state_p; + UnifiedVectorFormat vdata; + scan_vector.ToUnifiedFormat(count, vdata); + + state.Append(vdata, count); +} + +template +void RLEFinalizeCompress(CompressionState &state_p) { + auto &state = (RLECompressState &)state_p; + state.Finalize(); +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +template +struct RLEScanState : public SegmentScanState { + explicit RLEScanState(ColumnSegment &segment) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + handle = buffer_manager.Pin(segment.block); + entry_pos = 0; + position_in_entry = 0; + rle_count_offset = Load(handle.Ptr() + segment.GetBlockOffset()); + D_ASSERT(rle_count_offset <= Storage::BLOCK_SIZE); + } + + void Skip(ColumnSegment &segment, idx_t skip_count) { + auto data = handle.Ptr() + segment.GetBlockOffset(); + auto index_pointer = reinterpret_cast(data + rle_count_offset); + + for (idx_t i = 0; i < skip_count; i++) { + // assign the current value + position_in_entry++; + if (position_in_entry >= index_pointer[entry_pos]) { + // handled all entries in this RLE value + // move to the next entry + entry_pos++; + position_in_entry = 0; + } + } + } + + BufferHandle handle; + idx_t entry_pos; + idx_t position_in_entry; + uint32_t rle_count_offset; +}; + +template +unique_ptr RLEInitScan(ColumnSegment &segment) { + auto result = make_uniq>(segment); + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +template +void RLESkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { + auto &scan_state = state.scan_state->Cast>(); + scan_state.Skip(segment, skip_count); +} + +template +static bool CanEmitConstantVector(idx_t position, idx_t run_length, idx_t scan_count) { + if (!ENTIRE_VECTOR) { + return false; + } + if (scan_count != STANDARD_VECTOR_SIZE) { + // Only when we can fill an entire Vector can we emit a ConstantVector, because subsequent scans require the + // input Vector to be flat + return false; + } + D_ASSERT(position < run_length); + auto remaining_in_run = run_length - position; + // The amount of values left in this run are equal or greater than the amount of values we need to scan + return remaining_in_run >= scan_count; +} + +template +inline static void ForwardToNextRun(RLEScanState &scan_state) { + // handled all entries in this RLE value + // move to the next entry + scan_state.entry_pos++; + scan_state.position_in_entry = 0; +} + +template +inline static bool ExhaustedRun(RLEScanState &scan_state, rle_count_t *index_pointer) { + return scan_state.position_in_entry >= index_pointer[scan_state.entry_pos]; +} + +template +static void RLEScanConstant(RLEScanState &scan_state, rle_count_t *index_pointer, T *data_pointer, idx_t scan_count, + Vector &result) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto result_data = ConstantVector::GetData(result); + result_data[0] = data_pointer[scan_state.entry_pos]; + scan_state.position_in_entry += scan_count; + if (ExhaustedRun(scan_state, index_pointer)) { + ForwardToNextRun(scan_state); + } + return; +} + +template +void RLEScanPartialInternal(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + auto &scan_state = state.scan_state->Cast>(); + + auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); + auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); + auto index_pointer = reinterpret_cast(data + scan_state.rle_count_offset); + + // If we are scanning an entire Vector and it contains only a single run + if (CanEmitConstantVector(scan_state.position_in_entry, index_pointer[scan_state.entry_pos], + scan_count)) { + RLEScanConstant(scan_state, index_pointer, data_pointer, scan_count, result); + return; + } + + auto result_data = FlatVector::GetData(result); + result.SetVectorType(VectorType::FLAT_VECTOR); + for (idx_t i = 0; i < scan_count; i++) { + // assign the current value + result_data[result_offset + i] = data_pointer[scan_state.entry_pos]; + scan_state.position_in_entry++; + if (ExhaustedRun(scan_state, index_pointer)) { + ForwardToNextRun(scan_state); + } + } +} + +template +void RLEScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + return RLEScanPartialInternal(segment, state, scan_count, result, result_offset); +} + +template +void RLEScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + RLEScanPartialInternal(segment, state, scan_count, result, 0); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +template +void RLEFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { + RLEScanState scan_state(segment); + scan_state.Skip(segment, row_id); + + auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); + auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); + auto result_data = FlatVector::GetData(result); + result_data[result_idx] = data_pointer[scan_state.entry_pos]; +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +template +CompressionFunction GetRLEFunction(PhysicalType data_type) { + return CompressionFunction(CompressionType::COMPRESSION_RLE, data_type, RLEInitAnalyze, RLEAnalyze, + RLEFinalAnalyze, RLEInitCompression, + RLECompress, RLEFinalizeCompress, + RLEInitScan, RLEScan, RLEScanPartial, RLEFetchRow, RLESkip); +} + +CompressionFunction RLEFun::GetFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + return GetRLEFunction(type); + case PhysicalType::INT16: + return GetRLEFunction(type); + case PhysicalType::INT32: + return GetRLEFunction(type); + case PhysicalType::INT64: + return GetRLEFunction(type); + case PhysicalType::INT128: + return GetRLEFunction(type); + case PhysicalType::UINT8: + return GetRLEFunction(type); + case PhysicalType::UINT16: + return GetRLEFunction(type); + case PhysicalType::UINT32: + return GetRLEFunction(type); + case PhysicalType::UINT64: + return GetRLEFunction(type); + case PhysicalType::FLOAT: + return GetRLEFunction(type); + case PhysicalType::DOUBLE: + return GetRLEFunction(type); + case PhysicalType::LIST: + return GetRLEFunction(type); + default: + throw InternalException("Unsupported type for RLE"); + } +} + +bool RLEFun::TypeIsSupported(PhysicalType type) { + switch (type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + case PhysicalType::LIST: + return true; + default: + return false; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/string_uncompressed.cpp b/src/duckdb/src/storage/compression/string_uncompressed.cpp new file mode 100644 index 00000000..4f7c6da2 --- /dev/null +++ b/src/duckdb/src/storage/compression/string_uncompressed.cpp @@ -0,0 +1,444 @@ +#include "duckdb/storage/string_uncompressed.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/storage/table/column_data.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Storage Class +//===--------------------------------------------------------------------===// +UncompressedStringSegmentState::~UncompressedStringSegmentState() { + while (head) { + // prevent deep recursion here + head = std::move(head->next); + } +} + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +struct StringAnalyzeState : public AnalyzeState { + StringAnalyzeState() : count(0), total_string_size(0), overflow_strings(0) { + } + + idx_t count; + idx_t total_string_size; + idx_t overflow_strings; +}; + +unique_ptr UncompressedStringStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq(); +} + +bool UncompressedStringStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { + auto &state = state_p.Cast(); + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + + state.count += count; + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = vdata.sel->get_index(i); + if (vdata.validity.RowIsValid(idx)) { + auto string_size = data[idx].GetSize(); + state.total_string_size += string_size; + if (string_size >= StringUncompressed::STRING_BLOCK_LIMIT) { + state.overflow_strings++; + } + } + } + return true; +} + +idx_t UncompressedStringStorage::StringFinalAnalyze(AnalyzeState &state_p) { + auto &state = state_p.Cast(); + return state.count * sizeof(int32_t) + state.total_string_size + state.overflow_strings * BIG_STRING_MARKER_SIZE; +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +unique_ptr UncompressedStringStorage::StringInitScan(ColumnSegment &segment) { + auto result = make_uniq(); + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + result->handle = buffer_manager.Pin(segment.block); + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +void UncompressedStringStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, + Vector &result, idx_t result_offset) { + // clear any previously locked buffers and get the primary buffer handle + auto &scan_state = state.scan_state->Cast(); + auto start = segment.GetRelativeIndex(state.row_index); + + auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); + auto dict = GetDictionary(segment, scan_state.handle); + auto base_data = reinterpret_cast(baseptr + DICTIONARY_HEADER_SIZE); + auto result_data = FlatVector::GetData(result); + + int32_t previous_offset = start > 0 ? base_data[start - 1] : 0; + + for (idx_t i = 0; i < scan_count; i++) { + // std::abs used since offsets can be negative to indicate big strings + uint32_t string_length = std::abs(base_data[start + i]) - std::abs(previous_offset); + result_data[result_offset + i] = + FetchStringFromDict(segment, dict, result, baseptr, base_data[start + i], string_length); + previous_offset = base_data[start + i]; + } +} + +void UncompressedStringStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, + Vector &result) { + StringScanPartial(segment, state, scan_count, result, 0); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +BufferHandle &ColumnFetchState::GetOrInsertHandle(ColumnSegment &segment) { + auto primary_id = segment.block->BlockId(); + + auto entry = handles.find(primary_id); + if (entry == handles.end()) { + // not pinned yet: pin it + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + auto entry = handles.insert(make_pair(primary_id, std::move(handle))); + return entry.first->second; + } else { + // already pinned: use the pinned handle + return entry->second; + } +} + +void UncompressedStringStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, + Vector &result, idx_t result_idx) { + // fetch a single row from the string segment + // first pin the main buffer if it is not already pinned + auto &handle = state.GetOrInsertHandle(segment); + + auto baseptr = handle.Ptr() + segment.GetBlockOffset(); + auto dict = GetDictionary(segment, handle); + auto base_data = reinterpret_cast(baseptr + DICTIONARY_HEADER_SIZE); + auto result_data = FlatVector::GetData(result); + + auto dict_offset = base_data[row_id]; + uint32_t string_length; + if ((idx_t)row_id == 0) { + // edge case where this is the first string in the dict + string_length = std::abs(dict_offset); + } else { + string_length = std::abs(dict_offset) - std::abs(base_data[row_id - 1]); + } + result_data[result_idx] = FetchStringFromDict(segment, dict, result, baseptr, dict_offset, string_length); +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +struct SerializedStringSegmentState : public ColumnSegmentState { + SerializedStringSegmentState() { + } + explicit SerializedStringSegmentState(vector blocks_p) : blocks(std::move(blocks_p)) { + } + + vector blocks; + + void Serialize(Serializer &serializer) const override { + serializer.WriteProperty(1, "overflow_blocks", blocks); + } +}; + +unique_ptr +UncompressedStringStorage::StringInitSegment(ColumnSegment &segment, block_id_t block_id, + optional_ptr segment_state) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + if (block_id == INVALID_BLOCK) { + auto handle = buffer_manager.Pin(segment.block); + StringDictionaryContainer dictionary; + dictionary.size = 0; + dictionary.end = segment.SegmentSize(); + SetDictionary(segment, handle, dictionary); + } + auto result = make_uniq(); + if (segment_state) { + auto &serialized_state = segment_state->Cast(); + result->on_disk_blocks = std::move(serialized_state.blocks); + } + return std::move(result); +} + +idx_t UncompressedStringStorage::FinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + auto dict = GetDictionary(segment, handle); + D_ASSERT(dict.end == segment.SegmentSize()); + // compute the total size required to store this segment + auto offset_size = DICTIONARY_HEADER_SIZE + segment.count * sizeof(int32_t); + auto total_size = offset_size + dict.size; + if (total_size >= COMPACTION_FLUSH_LIMIT) { + // the block is full enough, don't bother moving around the dictionary + return segment.SegmentSize(); + } + // the block has space left: figure out how much space we can save + auto move_amount = segment.SegmentSize() - total_size; + // move the dictionary so it lines up exactly with the offsets + auto dataptr = handle.Ptr(); + memmove(dataptr + offset_size, dataptr + dict.end - dict.size, dict.size); + dict.end -= move_amount; + D_ASSERT(dict.end == total_size); + // write the new dictionary (with the updated "end") + SetDictionary(segment, handle, dict); + return total_size; +} + +//===--------------------------------------------------------------------===// +// Serialization & Cleanup +//===--------------------------------------------------------------------===// +unique_ptr UncompressedStringStorage::SerializeState(ColumnSegment &segment) { + auto &state = segment.GetSegmentState()->Cast(); + if (state.on_disk_blocks.empty()) { + // no on-disk blocks - nothing to write + return nullptr; + } + return make_uniq(state.on_disk_blocks); +} + +unique_ptr UncompressedStringStorage::DeserializeState(Deserializer &deserializer) { + auto result = make_uniq(); + deserializer.ReadProperty(1, "overflow_blocks", result->blocks); + return std::move(result); +} + +void UncompressedStringStorage::CleanupState(ColumnSegment &segment) { + auto &state = segment.GetSegmentState()->Cast(); + auto &block_manager = segment.GetBlockManager(); + for (auto &block_id : state.on_disk_blocks) { + block_manager.MarkBlockAsModified(block_id); + } +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +CompressionFunction StringUncompressed::GetFunction(PhysicalType data_type) { + D_ASSERT(data_type == PhysicalType::VARCHAR); + return CompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, data_type, + UncompressedStringStorage::StringInitAnalyze, UncompressedStringStorage::StringAnalyze, + UncompressedStringStorage::StringFinalAnalyze, UncompressedFunctions::InitCompression, + UncompressedFunctions::Compress, UncompressedFunctions::FinalizeCompress, + UncompressedStringStorage::StringInitScan, UncompressedStringStorage::StringScan, + UncompressedStringStorage::StringScanPartial, UncompressedStringStorage::StringFetchRow, + UncompressedFunctions::EmptySkip, UncompressedStringStorage::StringInitSegment, + UncompressedStringStorage::StringInitAppend, UncompressedStringStorage::StringAppend, + UncompressedStringStorage::FinalizeAppend, nullptr, + UncompressedStringStorage::SerializeState, UncompressedStringStorage::DeserializeState, + UncompressedStringStorage::CleanupState); +} + +//===--------------------------------------------------------------------===// +// Helper Functions +//===--------------------------------------------------------------------===// +void UncompressedStringStorage::SetDictionary(ColumnSegment &segment, BufferHandle &handle, + StringDictionaryContainer container) { + auto startptr = handle.Ptr() + segment.GetBlockOffset(); + Store(container.size, startptr); + Store(container.end, startptr + sizeof(uint32_t)); +} + +StringDictionaryContainer UncompressedStringStorage::GetDictionary(ColumnSegment &segment, BufferHandle &handle) { + auto startptr = handle.Ptr() + segment.GetBlockOffset(); + StringDictionaryContainer container; + container.size = Load(startptr); + container.end = Load(startptr + sizeof(uint32_t)); + return container; +} + +idx_t UncompressedStringStorage::RemainingSpace(ColumnSegment &segment, BufferHandle &handle) { + auto dictionary = GetDictionary(segment, handle); + D_ASSERT(dictionary.end == segment.SegmentSize()); + idx_t used_space = dictionary.size + segment.count * sizeof(int32_t) + DICTIONARY_HEADER_SIZE; + D_ASSERT(segment.SegmentSize() >= used_space); + return segment.SegmentSize() - used_space; +} + +void UncompressedStringStorage::WriteString(ColumnSegment &segment, string_t string, block_id_t &result_block, + int32_t &result_offset) { + auto &state = segment.GetSegmentState()->Cast(); + if (state.overflow_writer) { + // overflow writer is set: write string there + state.overflow_writer->WriteString(state, string, result_block, result_offset); + } else { + // default overflow behavior: use in-memory buffer to store the overflow string + WriteStringMemory(segment, string, result_block, result_offset); + } +} + +void UncompressedStringStorage::WriteStringMemory(ColumnSegment &segment, string_t string, block_id_t &result_block, + int32_t &result_offset) { + uint32_t total_length = string.GetSize() + sizeof(uint32_t); + shared_ptr block; + BufferHandle handle; + + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto &state = segment.GetSegmentState()->Cast(); + // check if the string fits in the current block + if (!state.head || state.head->offset + total_length >= state.head->size) { + // string does not fit, allocate space for it + // create a new string block + idx_t alloc_size = MaxValue(total_length, Storage::BLOCK_SIZE); + auto new_block = make_uniq(); + new_block->offset = 0; + new_block->size = alloc_size; + // allocate an in-memory buffer for it + handle = buffer_manager.Allocate(alloc_size, false, &block); + state.overflow_blocks.insert(make_pair(block->BlockId(), reference(*new_block))); + new_block->block = std::move(block); + new_block->next = std::move(state.head); + state.head = std::move(new_block); + } else { + // string fits, copy it into the current block + handle = buffer_manager.Pin(state.head->block); + } + + result_block = state.head->block->BlockId(); + result_offset = state.head->offset; + + // copy the string and the length there + auto ptr = handle.Ptr() + state.head->offset; + Store(string.GetSize(), ptr); + ptr += sizeof(uint32_t); + memcpy(ptr, string.GetData(), string.GetSize()); + state.head->offset += total_length; +} + +string_t UncompressedStringStorage::ReadOverflowString(ColumnSegment &segment, Vector &result, block_id_t block, + int32_t offset) { + D_ASSERT(block != INVALID_BLOCK); + D_ASSERT(offset < Storage::BLOCK_SIZE); + + auto &block_manager = segment.GetBlockManager(); + auto &buffer_manager = block_manager.buffer_manager; + auto &state = segment.GetSegmentState()->Cast(); + if (block < MAXIMUM_BLOCK) { + // read the overflow string from disk + // pin the initial handle and read the length + auto block_handle = state.GetHandle(block_manager, block); + auto handle = buffer_manager.Pin(block_handle); + + // read header + uint32_t length = Load(handle.Ptr() + offset); + uint32_t remaining = length; + offset += sizeof(uint32_t); + + // allocate a buffer to store the string + auto alloc_size = MaxValue(Storage::BLOCK_SIZE, length); + // allocate a buffer to store the compressed string + // TODO: profile this to check if we need to reuse buffer + auto target_handle = buffer_manager.Allocate(alloc_size); + auto target_ptr = target_handle.Ptr(); + + // now append the string to the single buffer + while (remaining > 0) { + idx_t to_write = MinValue(remaining, Storage::BLOCK_SIZE - sizeof(block_id_t) - offset); + memcpy(target_ptr, handle.Ptr() + offset, to_write); + remaining -= to_write; + offset += to_write; + target_ptr += to_write; + if (remaining > 0) { + // read the next block + block_id_t next_block = Load(handle.Ptr() + offset); + block_handle = state.GetHandle(block_manager, next_block); + handle = buffer_manager.Pin(block_handle); + offset = 0; + } + } + + auto final_buffer = target_handle.Ptr(); + StringVector::AddHandle(result, std::move(target_handle)); + return ReadString(final_buffer, 0, length); + } else { + // read the overflow string from memory + // first pin the handle, if it is not pinned yet + auto entry = state.overflow_blocks.find(block); + D_ASSERT(entry != state.overflow_blocks.end()); + auto handle = buffer_manager.Pin(entry->second.get().block); + auto final_buffer = handle.Ptr(); + StringVector::AddHandle(result, std::move(handle)); + return ReadStringWithLength(final_buffer, offset); + } +} + +string_t UncompressedStringStorage::ReadString(data_ptr_t target, int32_t offset, uint32_t string_length) { + auto ptr = target + offset; + auto str_ptr = char_ptr_cast(ptr); + return string_t(str_ptr, string_length); +} + +string_t UncompressedStringStorage::ReadStringWithLength(data_ptr_t target, int32_t offset) { + auto ptr = target + offset; + auto str_length = Load(ptr); + auto str_ptr = char_ptr_cast(ptr + sizeof(uint32_t)); + return string_t(str_ptr, str_length); +} + +void UncompressedStringStorage::WriteStringMarker(data_ptr_t target, block_id_t block_id, int32_t offset) { + memcpy(target, &block_id, sizeof(block_id_t)); + target += sizeof(block_id_t); + memcpy(target, &offset, sizeof(int32_t)); +} + +void UncompressedStringStorage::ReadStringMarker(data_ptr_t target, block_id_t &block_id, int32_t &offset) { + memcpy(&block_id, target, sizeof(block_id_t)); + target += sizeof(block_id_t); + memcpy(&offset, target, sizeof(int32_t)); +} + +string_location_t UncompressedStringStorage::FetchStringLocation(StringDictionaryContainer dict, data_ptr_t baseptr, + int32_t dict_offset) { + D_ASSERT(dict_offset >= -1 * Storage::BLOCK_SIZE && dict_offset <= Storage::BLOCK_SIZE); + if (dict_offset < 0) { + string_location_t result; + ReadStringMarker(baseptr + dict.end - (-1 * dict_offset), result.block_id, result.offset); + return result; + } else { + return string_location_t(INVALID_BLOCK, dict_offset); + } +} + +string_t UncompressedStringStorage::FetchStringFromDict(ColumnSegment &segment, StringDictionaryContainer dict, + Vector &result, data_ptr_t baseptr, int32_t dict_offset, + uint32_t string_length) { + // fetch base data + D_ASSERT(dict_offset <= Storage::BLOCK_SIZE); + string_location_t location = FetchStringLocation(dict, baseptr, dict_offset); + return FetchString(segment, dict, result, baseptr, location, string_length); +} + +string_t UncompressedStringStorage::FetchString(ColumnSegment &segment, StringDictionaryContainer dict, Vector &result, + data_ptr_t baseptr, string_location_t location, + uint32_t string_length) { + if (location.block_id != INVALID_BLOCK) { + // big string marker: read from separate block + return ReadOverflowString(segment, result, location.block_id, location.offset); + } else { + if (location.offset == 0) { + return string_t(nullptr, 0); + } + // normal string: read string from this block + auto dict_end = baseptr + dict.end; + auto dict_pos = dict_end - location.offset; + + auto str_ptr = char_ptr_cast(dict_pos); + return string_t(str_ptr, string_length); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/uncompressed.cpp b/src/duckdb/src/storage/compression/uncompressed.cpp new file mode 100644 index 00000000..9955bf8a --- /dev/null +++ b/src/duckdb/src/storage/compression/uncompressed.cpp @@ -0,0 +1,36 @@ +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/storage/segment/uncompressed.hpp" + +namespace duckdb { + +CompressionFunction UncompressedFun::GetFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + case PhysicalType::LIST: + case PhysicalType::INTERVAL: + return FixedSizeUncompressed::GetFunction(type); + case PhysicalType::BIT: + return ValidityUncompressed::GetFunction(type); + case PhysicalType::VARCHAR: + return StringUncompressed::GetFunction(type); + default: + throw InternalException("Unsupported type for Uncompressed"); + } +} + +bool UncompressedFun::TypeIsSupported(PhysicalType type) { + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/validity_uncompressed.cpp b/src/duckdb/src/storage/compression/validity_uncompressed.cpp new file mode 100644 index 00000000..66b8f4f8 --- /dev/null +++ b/src/duckdb/src/storage/compression/validity_uncompressed.cpp @@ -0,0 +1,478 @@ +#include "duckdb/storage/segment/uncompressed.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/storage/table/append_state.hpp" + +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Mask constants +//===--------------------------------------------------------------------===// +// LOWER_MASKS contains masks with all the lower bits set until a specific value +// LOWER_MASKS[0] has the 0 lowest bits set, i.e.: +// 0b0000000000000000000000000000000000000000000000000000000000000000, +// LOWER_MASKS[10] has the 10 lowest bits set, i.e.: +// 0b0000000000000000000000000000000000000000000000000000000111111111, +// etc... +// 0b0000000000000000000000000000000000000001111111111111111111111111, +// ... +// 0b0000000000000000000001111111111111111111111111111111111111111111, +// until LOWER_MASKS[64], which has all bits set: +// 0b1111111111111111111111111111111111111111111111111111111111111111 +// generated with this python snippet: +// for i in range(65): +// print(hex(int((64 - i) * '0' + i * '1', 2)) + ",") +const validity_t ValidityUncompressed::LOWER_MASKS[] = {0x0, + 0x1, + 0x3, + 0x7, + 0xf, + 0x1f, + 0x3f, + 0x7f, + 0xff, + 0x1ff, + 0x3ff, + 0x7ff, + 0xfff, + 0x1fff, + 0x3fff, + 0x7fff, + 0xffff, + 0x1ffff, + 0x3ffff, + 0x7ffff, + 0xfffff, + 0x1fffff, + 0x3fffff, + 0x7fffff, + 0xffffff, + 0x1ffffff, + 0x3ffffff, + 0x7ffffff, + 0xfffffff, + 0x1fffffff, + 0x3fffffff, + 0x7fffffff, + 0xffffffff, + 0x1ffffffff, + 0x3ffffffff, + 0x7ffffffff, + 0xfffffffff, + 0x1fffffffff, + 0x3fffffffff, + 0x7fffffffff, + 0xffffffffff, + 0x1ffffffffff, + 0x3ffffffffff, + 0x7ffffffffff, + 0xfffffffffff, + 0x1fffffffffff, + 0x3fffffffffff, + 0x7fffffffffff, + 0xffffffffffff, + 0x1ffffffffffff, + 0x3ffffffffffff, + 0x7ffffffffffff, + 0xfffffffffffff, + 0x1fffffffffffff, + 0x3fffffffffffff, + 0x7fffffffffffff, + 0xffffffffffffff, + 0x1ffffffffffffff, + 0x3ffffffffffffff, + 0x7ffffffffffffff, + 0xfffffffffffffff, + 0x1fffffffffffffff, + 0x3fffffffffffffff, + 0x7fffffffffffffff, + 0xffffffffffffffff}; + +// UPPER_MASKS contains masks with all the highest bits set until a specific value +// UPPER_MASKS[0] has the 0 highest bits set, i.e.: +// 0b0000000000000000000000000000000000000000000000000000000000000000, +// UPPER_MASKS[10] has the 10 highest bits set, i.e.: +// 0b1111111111110000000000000000000000000000000000000000000000000000, +// etc... +// 0b1111111111111111111111110000000000000000000000000000000000000000, +// ... +// 0b1111111111111111111111111111111111111110000000000000000000000000, +// until UPPER_MASKS[64], which has all bits set: +// 0b1111111111111111111111111111111111111111111111111111111111111111 +// generated with this python snippet: +// for i in range(65): +// print(hex(int(i * '1' + (64 - i) * '0', 2)) + ",") +const validity_t ValidityUncompressed::UPPER_MASKS[] = {0x0, + 0x8000000000000000, + 0xc000000000000000, + 0xe000000000000000, + 0xf000000000000000, + 0xf800000000000000, + 0xfc00000000000000, + 0xfe00000000000000, + 0xff00000000000000, + 0xff80000000000000, + 0xffc0000000000000, + 0xffe0000000000000, + 0xfff0000000000000, + 0xfff8000000000000, + 0xfffc000000000000, + 0xfffe000000000000, + 0xffff000000000000, + 0xffff800000000000, + 0xffffc00000000000, + 0xffffe00000000000, + 0xfffff00000000000, + 0xfffff80000000000, + 0xfffffc0000000000, + 0xfffffe0000000000, + 0xffffff0000000000, + 0xffffff8000000000, + 0xffffffc000000000, + 0xffffffe000000000, + 0xfffffff000000000, + 0xfffffff800000000, + 0xfffffffc00000000, + 0xfffffffe00000000, + 0xffffffff00000000, + 0xffffffff80000000, + 0xffffffffc0000000, + 0xffffffffe0000000, + 0xfffffffff0000000, + 0xfffffffff8000000, + 0xfffffffffc000000, + 0xfffffffffe000000, + 0xffffffffff000000, + 0xffffffffff800000, + 0xffffffffffc00000, + 0xffffffffffe00000, + 0xfffffffffff00000, + 0xfffffffffff80000, + 0xfffffffffffc0000, + 0xfffffffffffe0000, + 0xffffffffffff0000, + 0xffffffffffff8000, + 0xffffffffffffc000, + 0xffffffffffffe000, + 0xfffffffffffff000, + 0xfffffffffffff800, + 0xfffffffffffffc00, + 0xfffffffffffffe00, + 0xffffffffffffff00, + 0xffffffffffffff80, + 0xffffffffffffffc0, + 0xffffffffffffffe0, + 0xfffffffffffffff0, + 0xfffffffffffffff8, + 0xfffffffffffffffc, + 0xfffffffffffffffe, + 0xffffffffffffffff}; + +//===--------------------------------------------------------------------===// +// Analyze +//===--------------------------------------------------------------------===// +struct ValidityAnalyzeState : public AnalyzeState { + ValidityAnalyzeState() : count(0) { + } + + idx_t count; +}; + +unique_ptr ValidityInitAnalyze(ColumnData &col_data, PhysicalType type) { + return make_uniq(); +} + +bool ValidityAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { + auto &state = state_p.Cast(); + state.count += count; + return true; +} + +idx_t ValidityFinalAnalyze(AnalyzeState &state_p) { + auto &state = state_p.Cast(); + return (state.count + 7) / 8; +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +struct ValidityScanState : public SegmentScanState { + BufferHandle handle; + block_id_t block_id; +}; + +unique_ptr ValidityInitScan(ColumnSegment &segment) { + auto result = make_uniq(); + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + result->handle = buffer_manager.Pin(segment.block); + result->block_id = segment.block->BlockId(); + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Scan base data +//===--------------------------------------------------------------------===// +void ValidityScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + auto start = segment.GetRelativeIndex(state.row_index); + + static_assert(sizeof(validity_t) == sizeof(uint64_t), "validity_t should be 64-bit"); + auto &scan_state = state.scan_state->Cast(); + + auto &result_mask = FlatVector::Validity(result); + auto buffer_ptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); + D_ASSERT(scan_state.block_id == segment.block->BlockId()); + auto input_data = reinterpret_cast(buffer_ptr); + +#ifdef DEBUG + // this method relies on all the bits we are going to write to being set to valid + for (idx_t i = 0; i < scan_count; i++) { + D_ASSERT(result_mask.RowIsValid(result_offset + i)); + } +#endif +#if STANDARD_VECTOR_SIZE < 128 + // fallback for tiny vector sizes + // the bitwise ops we use below don't work if the vector size is too small + ValidityMask source_mask(input_data); + for (idx_t i = 0; i < scan_count; i++) { + if (!source_mask.RowIsValid(start + i)) { + if (result_mask.AllValid()) { + result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, result_offset + scan_count)); + } + result_mask.SetInvalid(result_offset + i); + } + } +#else + // the code below does what the fallback code above states, but using bitwise ops: + auto result_data = (validity_t *)result_mask.GetData(); + + // set up the initial positions + // we need to find the validity_entry to modify, together with the bit-index WITHIN the validity entry + idx_t result_entry = result_offset / ValidityMask::BITS_PER_VALUE; + idx_t result_idx = result_offset - result_entry * ValidityMask::BITS_PER_VALUE; + + // same for the input: find the validity_entry we are pulling from, together with the bit-index WITHIN that entry + idx_t input_entry = start / ValidityMask::BITS_PER_VALUE; + idx_t input_idx = start - input_entry * ValidityMask::BITS_PER_VALUE; + + // now start the bit games + idx_t pos = 0; + while (pos < scan_count) { + // these are the current validity entries we are dealing with + idx_t current_result_idx = result_entry; + idx_t offset; + validity_t input_mask = input_data[input_entry]; + + // construct the mask to AND together with the result + if (result_idx < input_idx) { + // we have to shift the input RIGHT if the result_idx is smaller than the input_idx + auto shift_amount = input_idx - result_idx; + D_ASSERT(shift_amount > 0 && shift_amount <= ValidityMask::BITS_PER_VALUE); + + input_mask = input_mask >> shift_amount; + + // now the upper "shift_amount" bits are set to 0 + // we need them to be set to 1 + // otherwise the subsequent bitwise & will modify values outside of the range of values we want to alter + input_mask |= ValidityUncompressed::UPPER_MASKS[shift_amount]; + + // after this, we move to the next input_entry + offset = ValidityMask::BITS_PER_VALUE - input_idx; + input_entry++; + input_idx = 0; + result_idx += offset; + } else if (result_idx > input_idx) { + // we have to shift the input LEFT if the result_idx is bigger than the input_idx + auto shift_amount = result_idx - input_idx; + D_ASSERT(shift_amount > 0 && shift_amount <= ValidityMask::BITS_PER_VALUE); + + // to avoid overflows, we set the upper "shift_amount" values to 0 first + input_mask = (input_mask & ~ValidityUncompressed::UPPER_MASKS[shift_amount]) << shift_amount; + + // now the lower "shift_amount" bits are set to 0 + // we need them to be set to 1 + // otherwise the subsequent bitwise & will modify values outside of the range of values we want to alter + input_mask |= ValidityUncompressed::LOWER_MASKS[shift_amount]; + + // after this, we move to the next result_entry + offset = ValidityMask::BITS_PER_VALUE - result_idx; + result_entry++; + result_idx = 0; + input_idx += offset; + } else { + // if the input_idx is equal to result_idx they are already aligned + // we just move to the next entry for both after this + offset = ValidityMask::BITS_PER_VALUE - result_idx; + input_entry++; + result_entry++; + result_idx = input_idx = 0; + } + // now we need to check if we should include the ENTIRE mask + // OR if we need to mask from the right side + pos += offset; + if (pos > scan_count) { + // we need to set any bits that are past the scan_count on the right-side to 1 + // this is required so we don't influence any bits that are not part of the scan + input_mask |= ValidityUncompressed::UPPER_MASKS[pos - scan_count]; + } + // now finally we can merge the input mask with the result mask + if (input_mask != ValidityMask::ValidityBuffer::MAX_ENTRY) { + if (!result_data) { + result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, result_offset + scan_count)); + result_data = (validity_t *)result_mask.GetData(); + } + result_data[current_result_idx] &= input_mask; + } + } +#endif + +#ifdef DEBUG + // verify that we actually accomplished the bitwise ops equivalent that we wanted to do + ValidityMask input_mask(input_data); + for (idx_t i = 0; i < scan_count; i++) { + D_ASSERT(result_mask.RowIsValid(result_offset + i) == input_mask.RowIsValid(start + i)); + } +#endif +} + +void ValidityScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + result.Flatten(scan_count); + + auto start = segment.GetRelativeIndex(state.row_index); + if (start % ValidityMask::BITS_PER_VALUE == 0) { + auto &scan_state = state.scan_state->Cast(); + + // aligned scan: no need to do anything fancy + // note: this is only an optimization which avoids having to do messy bitshifting in the common case + // it is not required for correctness + auto &result_mask = FlatVector::Validity(result); + auto buffer_ptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); + D_ASSERT(scan_state.block_id == segment.block->BlockId()); + auto input_data = reinterpret_cast(buffer_ptr); + auto result_data = result_mask.GetData(); + idx_t start_offset = start / ValidityMask::BITS_PER_VALUE; + idx_t entry_scan_count = (scan_count + ValidityMask::BITS_PER_VALUE - 1) / ValidityMask::BITS_PER_VALUE; + for (idx_t i = 0; i < entry_scan_count; i++) { + auto input_entry = input_data[start_offset + i]; + if (!result_data && input_entry == ValidityMask::ValidityBuffer::MAX_ENTRY) { + continue; + } + if (!result_data) { + result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, scan_count)); + result_data = result_mask.GetData(); + } + result_data[i] = input_entry; + } + } else { + // unaligned scan: fall back to scan_partial which does bitshift tricks + ValidityScanPartial(segment, state, scan_count, result, 0); + } +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +void ValidityFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { + D_ASSERT(row_id >= 0 && row_id < row_t(segment.count)); + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + auto dataptr = handle.Ptr() + segment.GetBlockOffset(); + ValidityMask mask(reinterpret_cast(dataptr)); + auto &result_mask = FlatVector::Validity(result); + if (!mask.RowIsValidUnsafe(row_id)) { + result_mask.SetInvalid(result_idx); + } +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +static unique_ptr ValidityInitAppend(ColumnSegment &segment) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + return make_uniq(std::move(handle)); +} + +unique_ptr ValidityInitSegment(ColumnSegment &segment, block_id_t block_id, + optional_ptr segment_state) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + if (block_id == INVALID_BLOCK) { + auto handle = buffer_manager.Pin(segment.block); + memset(handle.Ptr(), 0xFF, segment.SegmentSize()); + } + return nullptr; +} + +idx_t ValidityAppend(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, + UnifiedVectorFormat &data, idx_t offset, idx_t vcount) { + D_ASSERT(segment.GetBlockOffset() == 0); + auto &validity_stats = stats.statistics; + + auto max_tuples = segment.SegmentSize() / ValidityMask::STANDARD_MASK_SIZE * STANDARD_VECTOR_SIZE; + idx_t append_count = MinValue(vcount, max_tuples - segment.count); + if (data.validity.AllValid()) { + // no null values: skip append + segment.count += append_count; + validity_stats.SetHasNoNull(); + return append_count; + } + + ValidityMask mask(reinterpret_cast(append_state.handle.Ptr())); + for (idx_t i = 0; i < append_count; i++) { + auto idx = data.sel->get_index(offset + i); + if (!data.validity.RowIsValidUnsafe(idx)) { + mask.SetInvalidUnsafe(segment.count + i); + validity_stats.SetHasNull(); + } else { + validity_stats.SetHasNoNull(); + } + } + segment.count += append_count; + return append_count; +} + +idx_t ValidityFinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { + return ((segment.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE) * ValidityMask::STANDARD_MASK_SIZE; +} + +void ValidityRevertAppend(ColumnSegment &segment, idx_t start_row) { + idx_t start_bit = start_row - segment.start; + + auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto handle = buffer_manager.Pin(segment.block); + idx_t revert_start; + if (start_bit % 8 != 0) { + // handle sub-bit stuff (yay) + idx_t byte_pos = start_bit / 8; + idx_t bit_end = (byte_pos + 1) * 8; + ValidityMask mask(reinterpret_cast(handle.Ptr())); + for (idx_t i = start_bit; i < bit_end; i++) { + mask.SetValid(i); + } + revert_start = bit_end / 8; + } else { + revert_start = start_bit / 8; + } + // for the rest, we just memset + memset(handle.Ptr() + revert_start, 0xFF, segment.SegmentSize() - revert_start); +} + +//===--------------------------------------------------------------------===// +// Get Function +//===--------------------------------------------------------------------===// +CompressionFunction ValidityUncompressed::GetFunction(PhysicalType data_type) { + D_ASSERT(data_type == PhysicalType::BIT); + return CompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, data_type, ValidityInitAnalyze, + ValidityAnalyze, ValidityFinalAnalyze, UncompressedFunctions::InitCompression, + UncompressedFunctions::Compress, UncompressedFunctions::FinalizeCompress, + ValidityInitScan, ValidityScan, ValidityScanPartial, ValidityFetchRow, + UncompressedFunctions::EmptySkip, ValidityInitSegment, ValidityInitAppend, + ValidityAppend, ValidityFinalizeAppend, ValidityRevertAppend); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/data_pointer.cpp b/src/duckdb/src/storage/data_pointer.cpp new file mode 100644 index 00000000..29718e72 --- /dev/null +++ b/src/duckdb/src/storage/data_pointer.cpp @@ -0,0 +1,20 @@ +#include "duckdb/storage/data_pointer.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/function/compression_function.hpp" + +namespace duckdb { + +unique_ptr ColumnSegmentState::Deserialize(Deserializer &deserializer) { + auto compression_type = deserializer.Get(); + auto &db = deserializer.Get(); + auto &type = deserializer.Get(); + auto compression_function = DBConfig::GetConfig(db).GetCompressionFunction(compression_type, type.InternalType()); + if (!compression_function || !compression_function->deserialize_state) { + throw SerializationException("Deserializing a ColumnSegmentState but could not find deserialize method"); + } + return compression_function->deserialize_state(deserializer); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp new file mode 100644 index 00000000..45a0a379 --- /dev/null +++ b/src/duckdb/src/storage/data_table.cpp @@ -0,0 +1,1331 @@ +#include "duckdb/storage/data_table.hpp" + +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/chrono.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parser/constraints/list.hpp" +#include "duckdb/planner/constraints/list.hpp" +#include "duckdb/planner/expression_binder/check_binder.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/checkpoint/table_data_writer.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/storage/table/persistent_table_data.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/storage/table/standard_column_data.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/transaction/transaction_manager.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/common/types/conflict_manager.hpp" +#include "duckdb/common/types/constraint_conflict_info.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +DataTableInfo::DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, string schema, + string table) + : db(db), table_io_manager(std::move(table_io_manager_p)), cardinality(0), schema(std::move(schema)), + table(std::move(table)) { +} + +bool DataTableInfo::IsTemporary() const { + return db.IsTemporary(); +} + +DataTable::DataTable(AttachedDatabase &db, shared_ptr table_io_manager_p, const string &schema, + const string &table, vector column_definitions_p, + unique_ptr data) + : info(make_shared(db, std::move(table_io_manager_p), schema, table)), + column_definitions(std::move(column_definitions_p)), db(db), is_root(true) { + // initialize the table with the existing data from disk, if any + auto types = GetTypes(); + this->row_groups = + make_shared(info, TableIOManager::Get(*this).GetBlockManagerForRowData(), types, 0); + if (data && data->row_group_count > 0) { + this->row_groups->Initialize(*data); + } else { + this->row_groups->InitializeEmpty(); + D_ASSERT(row_groups->GetTotalRows() == 0); + } + row_groups->Verify(); +} + +DataTable::DataTable(ClientContext &context, DataTable &parent, ColumnDefinition &new_column, Expression &default_value) + : info(parent.info), db(parent.db), is_root(true) { + // add the column definitions from this DataTable + for (auto &column_def : parent.column_definitions) { + column_definitions.emplace_back(column_def.Copy()); + } + column_definitions.emplace_back(new_column.Copy()); + // prevent any new tuples from being added to the parent + lock_guard parent_lock(parent.append_lock); + + this->row_groups = parent.row_groups->AddColumn(context, new_column, default_value); + + // also add this column to client local storage + auto &local_storage = LocalStorage::Get(context, db); + local_storage.AddColumn(parent, *this, new_column, default_value); + + // this table replaces the previous table, hence the parent is no longer the root DataTable + parent.is_root = false; +} + +DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t removed_column) + : info(parent.info), db(parent.db), is_root(true) { + // prevent any new tuples from being added to the parent + lock_guard parent_lock(parent.append_lock); + + for (auto &column_def : parent.column_definitions) { + column_definitions.emplace_back(column_def.Copy()); + } + // first check if there are any indexes that exist that point to the removed column + info->indexes.Scan([&](Index &index) { + for (auto &column_id : index.column_ids) { + if (column_id == removed_column) { + throw CatalogException("Cannot drop this column: an index depends on it!"); + } else if (column_id > removed_column) { + throw CatalogException("Cannot drop this column: an index depends on a column after it!"); + } + } + return false; + }); + + // erase the column definitions from this DataTable + D_ASSERT(removed_column < column_definitions.size()); + column_definitions.erase(column_definitions.begin() + removed_column); + + storage_t storage_idx = 0; + for (idx_t i = 0; i < column_definitions.size(); i++) { + auto &col = column_definitions[i]; + col.SetOid(i); + if (col.Generated()) { + continue; + } + col.SetStorageOid(storage_idx++); + } + + // alter the row_groups and remove the column from each of them + this->row_groups = parent.row_groups->RemoveColumn(removed_column); + + // scan the original table, and fill the new column with the transformed value + auto &local_storage = LocalStorage::Get(context, db); + local_storage.DropColumn(parent, *this, removed_column); + + // this table replaces the previous table, hence the parent is no longer the root DataTable + parent.is_root = false; +} + +// Alter column to add new constraint +DataTable::DataTable(ClientContext &context, DataTable &parent, unique_ptr constraint) + : info(parent.info), db(parent.db), row_groups(parent.row_groups), is_root(true) { + + lock_guard parent_lock(parent.append_lock); + for (auto &column_def : parent.column_definitions) { + column_definitions.emplace_back(column_def.Copy()); + } + + // Verify the new constraint against current persistent/local data + VerifyNewConstraint(context, parent, constraint.get()); + + // Get the local data ownership from old dt + auto &local_storage = LocalStorage::Get(context, db); + local_storage.MoveStorage(parent, *this); + // this table replaces the previous table, hence the parent is no longer the root DataTable + parent.is_root = false; +} + +DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t changed_idx, const LogicalType &target_type, + const vector &bound_columns, Expression &cast_expr) + : info(parent.info), db(parent.db), is_root(true) { + // prevent any tuples from being added to the parent + lock_guard lock(append_lock); + for (auto &column_def : parent.column_definitions) { + column_definitions.emplace_back(column_def.Copy()); + } + // first check if there are any indexes that exist that point to the changed column + info->indexes.Scan([&](Index &index) { + for (auto &column_id : index.column_ids) { + if (column_id == changed_idx) { + throw CatalogException("Cannot change the type of this column: an index depends on it!"); + } + } + return false; + }); + + // change the type in this DataTable + column_definitions[changed_idx].SetType(target_type); + + // set up the statistics for the table + // the column that had its type changed will have the new statistics computed during conversion + this->row_groups = parent.row_groups->AlterType(context, changed_idx, target_type, bound_columns, cast_expr); + + // scan the original table, and fill the new column with the transformed value + auto &local_storage = LocalStorage::Get(context, db); + local_storage.ChangeType(parent, *this, changed_idx, target_type, bound_columns, cast_expr); + + // this table replaces the previous table, hence the parent is no longer the root DataTable + parent.is_root = false; +} + +vector DataTable::GetTypes() { + vector types; + for (auto &it : column_definitions) { + types.push_back(it.Type()); + } + return types; +} + +TableIOManager &TableIOManager::Get(DataTable &table) { + return *table.info->table_io_manager; +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +void DataTable::InitializeScan(TableScanState &state, const vector &column_ids, + TableFilterSet *table_filters) { + state.Initialize(column_ids, table_filters); + row_groups->InitializeScan(state.table_state, column_ids, table_filters); +} + +void DataTable::InitializeScan(DuckTransaction &transaction, TableScanState &state, const vector &column_ids, + TableFilterSet *table_filters) { + InitializeScan(state, column_ids, table_filters); + auto &local_storage = LocalStorage::Get(transaction); + local_storage.InitializeScan(*this, state.local_state, table_filters); +} + +void DataTable::InitializeScanWithOffset(TableScanState &state, const vector &column_ids, idx_t start_row, + idx_t end_row) { + state.Initialize(column_ids); + row_groups->InitializeScanWithOffset(state.table_state, column_ids, start_row, end_row); +} + +idx_t DataTable::MaxThreads(ClientContext &context) { + idx_t parallel_scan_vector_count = Storage::ROW_GROUP_VECTOR_COUNT; + if (ClientConfig::GetConfig(context).verify_parallelism) { + parallel_scan_vector_count = 1; + } + idx_t parallel_scan_tuple_count = STANDARD_VECTOR_SIZE * parallel_scan_vector_count; + return GetTotalRows() / parallel_scan_tuple_count + 1; +} + +void DataTable::InitializeParallelScan(ClientContext &context, ParallelTableScanState &state) { + row_groups->InitializeParallelScan(state.scan_state); + + auto &local_storage = LocalStorage::Get(context, db); + local_storage.InitializeParallelScan(*this, state.local_state); +} + +bool DataTable::NextParallelScan(ClientContext &context, ParallelTableScanState &state, TableScanState &scan_state) { + if (row_groups->NextParallelScan(context, state.scan_state, scan_state.table_state)) { + return true; + } + scan_state.table_state.batch_index = state.scan_state.batch_index; + auto &local_storage = LocalStorage::Get(context, db); + if (local_storage.NextParallelScan(context, *this, state.local_state, scan_state.local_state)) { + return true; + } else { + // finished all scans: no more scans remaining + return false; + } +} + +void DataTable::Scan(DuckTransaction &transaction, DataChunk &result, TableScanState &state) { + // scan the persistent segments + if (state.table_state.Scan(transaction, result)) { + D_ASSERT(result.size() > 0); + return; + } + + // scan the transaction-local segments + auto &local_storage = LocalStorage::Get(transaction); + local_storage.Scan(state.local_state, state.GetColumnIds(), result); +} + +bool DataTable::CreateIndexScan(TableScanState &state, DataChunk &result, TableScanType type) { + return state.table_state.ScanCommitted(result, type); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +void DataTable::Fetch(DuckTransaction &transaction, DataChunk &result, const vector &column_ids, + const Vector &row_identifiers, idx_t fetch_count, ColumnFetchState &state) { + row_groups->Fetch(transaction, result, column_ids, row_identifiers, fetch_count, state); +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +static void VerifyNotNullConstraint(TableCatalogEntry &table, Vector &vector, idx_t count, const string &col_name) { + if (!VectorOperations::HasNull(vector, count)) { + return; + } + + throw ConstraintException("NOT NULL constraint failed: %s.%s", table.name, col_name); +} + +// To avoid throwing an error at SELECT, instead this moves the error detection to INSERT +static void VerifyGeneratedExpressionSuccess(ClientContext &context, TableCatalogEntry &table, DataChunk &chunk, + Expression &expr, column_t index) { + auto &col = table.GetColumn(LogicalIndex(index)); + D_ASSERT(col.Generated()); + ExpressionExecutor executor(context, expr); + Vector result(col.Type()); + try { + executor.ExecuteExpression(chunk, result); + } catch (InternalException &ex) { + throw; + } catch (std::exception &ex) { + throw ConstraintException("Incorrect value for generated column \"%s %s AS (%s)\" : %s", col.Name(), + col.Type().ToString(), col.GeneratedExpression().ToString(), ex.what()); + } +} + +static void VerifyCheckConstraint(ClientContext &context, TableCatalogEntry &table, Expression &expr, + DataChunk &chunk) { + ExpressionExecutor executor(context, expr); + Vector result(LogicalType::INTEGER); + try { + executor.ExecuteExpression(chunk, result); + } catch (std::exception &ex) { + throw ConstraintException("CHECK constraint failed: %s (Error: %s)", table.name, ex.what()); + } catch (...) { // LCOV_EXCL_START + throw ConstraintException("CHECK constraint failed: %s (Unknown Error)", table.name); + } // LCOV_EXCL_STOP + UnifiedVectorFormat vdata; + result.ToUnifiedFormat(chunk.size(), vdata); + + auto dataptr = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < chunk.size(); i++) { + auto idx = vdata.sel->get_index(i); + if (vdata.validity.RowIsValid(idx) && dataptr[idx] == 0) { + throw ConstraintException("CHECK constraint failed: %s", table.name); + } + } +} + +bool DataTable::IsForeignKeyIndex(const vector &fk_keys, Index &index, ForeignKeyType fk_type) { + if (fk_type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE ? !index.IsUnique() : !index.IsForeign()) { + return false; + } + if (fk_keys.size() != index.column_ids.size()) { + return false; + } + for (auto &fk_key : fk_keys) { + bool is_found = false; + for (auto &index_key : index.column_ids) { + if (fk_key.index == index_key) { + is_found = true; + break; + } + } + if (!is_found) { + return false; + } + } + return true; +} + +// Find the first index that is not null, and did not find a match +static idx_t FirstMissingMatch(const ManagedSelection &matches) { + idx_t match_idx = 0; + + for (idx_t i = 0; i < matches.Size(); i++) { + auto match = matches.IndexMapsToLocation(match_idx, i); + match_idx += match; + if (!match) { + // This index is missing in the matches vector + return i; + } + } + return DConstants::INVALID_INDEX; +} + +idx_t LocateErrorIndex(bool is_append, const ManagedSelection &matches) { + idx_t failed_index = DConstants::INVALID_INDEX; + if (!is_append) { + // We expected to find nothing, so the first error is the first match + failed_index = matches[0]; + } else { + // We expected to find matches for all of them, so the first missing match is the first error + return FirstMissingMatch(matches); + } + return failed_index; +} + +[[noreturn]] static void ThrowForeignKeyConstraintError(idx_t failed_index, bool is_append, Index &index, + DataChunk &input) { + auto verify_type = is_append ? VerifyExistenceType::APPEND_FK : VerifyExistenceType::DELETE_FK; + + D_ASSERT(failed_index != DConstants::INVALID_INDEX); + D_ASSERT(index.type == IndexType::ART); + auto &art_index = index.Cast(); + auto key_name = art_index.GenerateErrorKeyName(input, failed_index); + auto exception_msg = art_index.GenerateConstraintErrorMessage(verify_type, key_name); + throw ConstraintException(exception_msg); +} + +bool IsForeignKeyConstraintError(bool is_append, idx_t input_count, const ManagedSelection &matches) { + if (is_append) { + // We need to find a match for all of the values + return matches.Count() != input_count; + } else { + // We should not find any matches + return matches.Count() != 0; + } +} + +static bool IsAppend(VerifyExistenceType verify_type) { + return verify_type == VerifyExistenceType::APPEND_FK; +} + +void DataTable::VerifyForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, + DataChunk &chunk, VerifyExistenceType verify_type) { + const vector *src_keys_ptr = &bfk.info.fk_keys; + const vector *dst_keys_ptr = &bfk.info.pk_keys; + + bool is_append = IsAppend(verify_type); + if (!is_append) { + src_keys_ptr = &bfk.info.pk_keys; + dst_keys_ptr = &bfk.info.fk_keys; + } + + auto &table_entry_ptr = + Catalog::GetEntry(context, INVALID_CATALOG, bfk.info.schema, bfk.info.table); + // make the data chunk to check + vector types; + for (auto &col : table_entry_ptr.GetColumns().Physical()) { + types.emplace_back(col.Type()); + } + DataChunk dst_chunk; + dst_chunk.InitializeEmpty(types); + for (idx_t i = 0; i < src_keys_ptr->size(); i++) { + dst_chunk.data[(*dst_keys_ptr)[i].index].Reference(chunk.data[(*src_keys_ptr)[i].index]); + } + dst_chunk.SetCardinality(chunk.size()); + auto &data_table = table_entry_ptr.GetStorage(); + + idx_t count = dst_chunk.size(); + if (count <= 0) { + return; + } + + // Set up a way to record conflicts, rather than directly throw on them + unordered_set empty_column_list; + ConflictInfo empty_conflict_info(empty_column_list, false); + ConflictManager regular_conflicts(verify_type, count, &empty_conflict_info); + ConflictManager transaction_conflicts(verify_type, count, &empty_conflict_info); + regular_conflicts.SetMode(ConflictManagerMode::SCAN); + transaction_conflicts.SetMode(ConflictManagerMode::SCAN); + + data_table.info->indexes.VerifyForeignKey(*dst_keys_ptr, dst_chunk, regular_conflicts); + regular_conflicts.Finalize(); + auto ®ular_matches = regular_conflicts.Conflicts(); + + // check if we can insert the chunk into the reference table's local storage + auto &local_storage = LocalStorage::Get(context, db); + bool error = IsForeignKeyConstraintError(is_append, count, regular_matches); + bool transaction_error = false; + bool transaction_check = local_storage.Find(data_table); + + if (transaction_check) { + auto &transact_index = local_storage.GetIndexes(data_table); + transact_index.VerifyForeignKey(*dst_keys_ptr, dst_chunk, transaction_conflicts); + transaction_conflicts.Finalize(); + auto &transaction_matches = transaction_conflicts.Conflicts(); + transaction_error = IsForeignKeyConstraintError(is_append, count, transaction_matches); + } + + if (!transaction_error && !error) { + // No error occurred; + return; + } + + // Some error occurred, and we likely want to throw + optional_ptr index; + optional_ptr transaction_index; + + auto fk_type = is_append ? ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE : ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; + // check whether or not the chunk can be inserted or deleted into the referenced table' storage + index = data_table.info->indexes.FindForeignKeyIndex(*dst_keys_ptr, fk_type); + if (transaction_check) { + auto &transact_index = local_storage.GetIndexes(data_table); + // check whether or not the chunk can be inserted or deleted into the referenced table' storage + transaction_index = transact_index.FindForeignKeyIndex(*dst_keys_ptr, fk_type); + } + + if (!transaction_check) { + // Only local state is checked, throw the error + D_ASSERT(error); + auto failed_index = LocateErrorIndex(is_append, regular_matches); + D_ASSERT(failed_index != DConstants::INVALID_INDEX); + ThrowForeignKeyConstraintError(failed_index, is_append, *index, dst_chunk); + } + if (transaction_error && error && is_append) { + // When we want to do an append, we only throw if the foreign key does not exist in both transaction and local + // storage + auto &transaction_matches = transaction_conflicts.Conflicts(); + idx_t failed_index = DConstants::INVALID_INDEX; + idx_t regular_idx = 0; + idx_t transaction_idx = 0; + for (idx_t i = 0; i < count; i++) { + bool in_regular = regular_matches.IndexMapsToLocation(regular_idx, i); + regular_idx += in_regular; + bool in_transaction = transaction_matches.IndexMapsToLocation(transaction_idx, i); + transaction_idx += in_transaction; + + if (!in_regular && !in_transaction) { + // We need to find a match for all of the input values + // The failed index is i, it does not show up in either regular or transaction storage + failed_index = i; + break; + } + } + if (failed_index == DConstants::INVALID_INDEX) { + // We don't throw, every value was present in either regular or transaction storage + return; + } + ThrowForeignKeyConstraintError(failed_index, true, *index, dst_chunk); + } + if (!is_append && transaction_check) { + auto &transaction_matches = transaction_conflicts.Conflicts(); + if (error) { + auto failed_index = LocateErrorIndex(false, regular_matches); + D_ASSERT(failed_index != DConstants::INVALID_INDEX); + ThrowForeignKeyConstraintError(failed_index, false, *index, dst_chunk); + } else { + D_ASSERT(transaction_error); + D_ASSERT(transaction_matches.Count() != DConstants::INVALID_INDEX); + auto failed_index = LocateErrorIndex(false, transaction_matches); + D_ASSERT(failed_index != DConstants::INVALID_INDEX); + ThrowForeignKeyConstraintError(failed_index, false, *transaction_index, dst_chunk); + } + } +} + +void DataTable::VerifyAppendForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, + DataChunk &chunk) { + VerifyForeignKeyConstraint(bfk, context, chunk, VerifyExistenceType::APPEND_FK); +} + +void DataTable::VerifyDeleteForeignKeyConstraint(const BoundForeignKeyConstraint &bfk, ClientContext &context, + DataChunk &chunk) { + VerifyForeignKeyConstraint(bfk, context, chunk, VerifyExistenceType::DELETE_FK); +} + +void DataTable::VerifyNewConstraint(ClientContext &context, DataTable &parent, const BoundConstraint *constraint) { + if (constraint->type != ConstraintType::NOT_NULL) { + throw NotImplementedException("FIXME: ALTER COLUMN with such constraint is not supported yet"); + } + + parent.row_groups->VerifyNewConstraint(parent, *constraint); + auto &local_storage = LocalStorage::Get(context, db); + local_storage.VerifyNewConstraint(parent, *constraint); +} + +bool HasUniqueIndexes(TableIndexList &list) { + bool has_unique_index = false; + list.Scan([&](Index &index) { + if (index.IsUnique()) { + return has_unique_index = true; + return true; + } + return false; + }); + return has_unique_index; +} + +void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &context, DataChunk &chunk, + ConflictManager *conflict_manager) { + //! check whether or not the chunk can be inserted into the indexes + if (!conflict_manager) { + // Only need to verify that no unique constraints are violated + indexes.Scan([&](Index &index) { + if (!index.IsUnique()) { + return false; + } + index.VerifyAppend(chunk); + return false; + }); + return; + } + + D_ASSERT(conflict_manager); + // The conflict manager is only provided when a ON CONFLICT clause was provided to the INSERT statement + + idx_t matching_indexes = 0; + auto &conflict_info = conflict_manager->GetConflictInfo(); + // First we figure out how many indexes match our conflict target + // So we can optimize accordingly + indexes.Scan([&](Index &index) { + matching_indexes += conflict_info.ConflictTargetMatches(index); + return false; + }); + conflict_manager->SetMode(ConflictManagerMode::SCAN); + conflict_manager->SetIndexCount(matching_indexes); + // First we verify only the indexes that match our conflict target + unordered_set checked_indexes; + indexes.Scan([&](Index &index) { + if (!index.IsUnique()) { + return false; + } + if (conflict_info.ConflictTargetMatches(index)) { + index.VerifyAppend(chunk, *conflict_manager); + checked_indexes.insert(&index); + } + return false; + }); + + conflict_manager->SetMode(ConflictManagerMode::THROW); + // Then we scan the other indexes, throwing if they cause conflicts on tuples that were not found during + // the scan + indexes.Scan([&](Index &index) { + if (!index.IsUnique()) { + return false; + } + if (checked_indexes.count(&index)) { + // Already checked this constraint + return false; + } + index.VerifyAppend(chunk, *conflict_manager); + return false; + }); +} + +void DataTable::VerifyAppendConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, + ConflictManager *conflict_manager) { + if (table.HasGeneratedColumns()) { + // Verify that the generated columns expression work with the inserted values + auto binder = Binder::CreateBinder(context); + physical_index_set_t bound_columns; + CheckBinder generated_check_binder(*binder, context, table.name, table.GetColumns(), bound_columns); + for (auto &col : table.GetColumns().Logical()) { + if (!col.Generated()) { + continue; + } + D_ASSERT(col.Type().id() != LogicalTypeId::ANY); + generated_check_binder.target_type = col.Type(); + auto to_be_bound_expression = col.GeneratedExpression().Copy(); + auto bound_expression = generated_check_binder.Bind(to_be_bound_expression); + VerifyGeneratedExpressionSuccess(context, table, chunk, *bound_expression, col.Oid()); + } + } + + if (HasUniqueIndexes(info->indexes)) { + VerifyUniqueIndexes(info->indexes, context, chunk, conflict_manager); + } + + auto &constraints = table.GetConstraints(); + auto &bound_constraints = table.GetBoundConstraints(); + for (idx_t i = 0; i < bound_constraints.size(); i++) { + auto &base_constraint = constraints[i]; + auto &constraint = bound_constraints[i]; + switch (base_constraint->type) { + case ConstraintType::NOT_NULL: { + auto &bound_not_null = *reinterpret_cast(constraint.get()); + auto ¬_null = *reinterpret_cast(base_constraint.get()); + auto &col = table.GetColumns().GetColumn(LogicalIndex(not_null.index)); + VerifyNotNullConstraint(table, chunk.data[bound_not_null.index.index], chunk.size(), col.Name()); + break; + } + case ConstraintType::CHECK: { + auto &check = *reinterpret_cast(constraint.get()); + VerifyCheckConstraint(context, table, *check.expression, chunk); + break; + } + case ConstraintType::UNIQUE: { + // These were handled earlier on + break; + } + case ConstraintType::FOREIGN_KEY: { + auto &bfk = *reinterpret_cast(constraint.get()); + if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || + bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + VerifyAppendForeignKeyConstraint(bfk, context, chunk); + } + break; + } + default: + throw NotImplementedException("Constraint type not implemented!"); + } + } +} + +void DataTable::InitializeLocalAppend(LocalAppendState &state, ClientContext &context) { + if (!is_root) { + throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); + } + auto &local_storage = LocalStorage::Get(context, db); + local_storage.InitializeAppend(state, *this); +} + +void DataTable::LocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, + bool unsafe) { + if (chunk.size() == 0) { + return; + } + D_ASSERT(chunk.ColumnCount() == table.GetColumns().PhysicalColumnCount()); + if (!is_root) { + throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); + } + + chunk.Verify(); + + // verify any constraints on the new chunk + if (!unsafe) { + VerifyAppendConstraints(table, context, chunk); + } + + // append to the transaction local data + LocalStorage::Append(state, chunk); +} + +void DataTable::FinalizeLocalAppend(LocalAppendState &state) { + LocalStorage::FinalizeAppend(state); +} + +OptimisticDataWriter &DataTable::CreateOptimisticWriter(ClientContext &context) { + auto &local_storage = LocalStorage::Get(context, db); + return local_storage.CreateOptimisticWriter(*this); +} + +void DataTable::FinalizeOptimisticWriter(ClientContext &context, OptimisticDataWriter &writer) { + auto &local_storage = LocalStorage::Get(context, db); + local_storage.FinalizeOptimisticWriter(*this, writer); +} + +void DataTable::LocalMerge(ClientContext &context, RowGroupCollection &collection) { + auto &local_storage = LocalStorage::Get(context, db); + local_storage.LocalMerge(*this, collection); +} + +void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk) { + LocalAppendState append_state; + auto &storage = table.GetStorage(); + storage.InitializeLocalAppend(append_state, context); + storage.LocalAppend(append_state, table, context, chunk); + storage.FinalizeLocalAppend(append_state); +} + +void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection) { + LocalAppendState append_state; + auto &storage = table.GetStorage(); + storage.InitializeLocalAppend(append_state, context); + for (auto &chunk : collection.Chunks()) { + storage.LocalAppend(append_state, table, context, chunk); + } + storage.FinalizeLocalAppend(append_state); +} + +void DataTable::AppendLock(TableAppendState &state) { + state.append_lock = unique_lock(append_lock); + if (!is_root) { + throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); + } + state.row_start = row_groups->GetTotalRows(); + state.current_row = state.row_start; +} + +void DataTable::InitializeAppend(DuckTransaction &transaction, TableAppendState &state, idx_t append_count) { + // obtain the append lock for this table + if (!state.append_lock) { + throw InternalException("DataTable::AppendLock should be called before DataTable::InitializeAppend"); + } + row_groups->InitializeAppend(transaction, state, append_count); +} + +void DataTable::Append(DataChunk &chunk, TableAppendState &state) { + D_ASSERT(is_root); + row_groups->Append(chunk, state); +} + +void DataTable::ScanTableSegment(idx_t row_start, idx_t count, const std::function &function) { + if (count == 0) { + return; + } + idx_t end = row_start + count; + + vector column_ids; + vector types; + for (idx_t i = 0; i < this->column_definitions.size(); i++) { + auto &col = this->column_definitions[i]; + column_ids.push_back(i); + types.push_back(col.Type()); + } + DataChunk chunk; + chunk.Initialize(Allocator::Get(db), types); + + CreateIndexScanState state; + + InitializeScanWithOffset(state, column_ids, row_start, row_start + count); + auto row_start_aligned = state.table_state.row_group->start + state.table_state.vector_index * STANDARD_VECTOR_SIZE; + + idx_t current_row = row_start_aligned; + while (current_row < end) { + state.table_state.ScanCommitted(chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); + if (chunk.size() == 0) { + break; + } + idx_t end_row = current_row + chunk.size(); + // start of chunk is current_row + // end of chunk is end_row + // figure out if we need to write the entire chunk or just part of it + idx_t chunk_start = MaxValue(current_row, row_start); + idx_t chunk_end = MinValue(end_row, end); + D_ASSERT(chunk_start < chunk_end); + idx_t chunk_count = chunk_end - chunk_start; + if (chunk_count != chunk.size()) { + D_ASSERT(chunk_count <= chunk.size()); + // need to slice the chunk before insert + idx_t start_in_chunk; + if (current_row >= row_start) { + start_in_chunk = 0; + } else { + start_in_chunk = row_start - current_row; + } + SelectionVector sel(start_in_chunk, chunk_count); + chunk.Slice(sel, chunk_count); + chunk.Verify(); + } + function(chunk); + chunk.Reset(); + current_row = end_row; + } +} + +void DataTable::MergeStorage(RowGroupCollection &data, TableIndexList &indexes) { + row_groups->MergeStorage(data); + row_groups->Verify(); +} + +void DataTable::WriteToLog(WriteAheadLog &log, idx_t row_start, idx_t count) { + if (log.skip_writing) { + return; + } + log.WriteSetTable(info->schema, info->table); + ScanTableSegment(row_start, count, [&](DataChunk &chunk) { log.WriteInsert(chunk); }); +} + +void DataTable::CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count) { + lock_guard lock(append_lock); + row_groups->CommitAppend(commit_id, row_start, count); + info->cardinality += count; +} + +void DataTable::RevertAppendInternal(idx_t start_row) { + // adjust the cardinality + info->cardinality = start_row; + D_ASSERT(is_root); + // revert appends made to row_groups + row_groups->RevertAppendInternal(start_row); +} + +void DataTable::RevertAppend(idx_t start_row, idx_t count) { + lock_guard lock(append_lock); + + // revert any appends to indexes + if (!info->indexes.Empty()) { + idx_t current_row_base = start_row; + row_t row_data[STANDARD_VECTOR_SIZE]; + Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_data)); + idx_t scan_count = MinValue(count, row_groups->GetTotalRows() - start_row); + ScanTableSegment(start_row, scan_count, [&](DataChunk &chunk) { + for (idx_t i = 0; i < chunk.size(); i++) { + row_data[i] = current_row_base + i; + } + info->indexes.Scan([&](Index &index) { + index.Delete(chunk, row_identifiers); + return false; + }); + current_row_base += chunk.size(); + }); + } + + // we need to vacuum the indexes to remove any buffers that are now empty + // due to reverting the appends + info->indexes.Scan([&](Index &index) { + index.Vacuum(); + return false; + }); + + // revert the data table append + RevertAppendInternal(start_row); +} + +//===--------------------------------------------------------------------===// +// Indexes +//===--------------------------------------------------------------------===// +PreservedError DataTable::AppendToIndexes(TableIndexList &indexes, DataChunk &chunk, row_t row_start) { + PreservedError error; + if (indexes.Empty()) { + return error; + } + // first generate the vector of row identifiers + Vector row_identifiers(LogicalType::ROW_TYPE); + VectorOperations::GenerateSequence(row_identifiers, chunk.size(), row_start, 1); + + vector already_appended; + bool append_failed = false; + // now append the entries to the indices + indexes.Scan([&](Index &index) { + try { + error = index.Append(chunk, row_identifiers); + } catch (Exception &ex) { + error = PreservedError(ex); + } catch (std::exception &ex) { + error = PreservedError(ex); + } + if (error) { + append_failed = true; + return true; + } + already_appended.push_back(&index); + return false; + }); + + if (append_failed) { + // constraint violation! + // remove any appended entries from previous indexes (if any) + for (auto *index : already_appended) { + index->Delete(chunk, row_identifiers); + } + } + return error; +} + +PreservedError DataTable::AppendToIndexes(DataChunk &chunk, row_t row_start) { + D_ASSERT(is_root); + return AppendToIndexes(info->indexes, chunk, row_start); +} + +void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, row_t row_start) { + D_ASSERT(is_root); + if (info->indexes.Empty()) { + return; + } + // first generate the vector of row identifiers + Vector row_identifiers(LogicalType::ROW_TYPE); + VectorOperations::GenerateSequence(row_identifiers, chunk.size(), row_start, 1); + + // now remove the entries from the indices + RemoveFromIndexes(state, chunk, row_identifiers); +} + +void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers) { + D_ASSERT(is_root); + info->indexes.Scan([&](Index &index) { + index.Delete(chunk, row_identifiers); + return false; + }); +} + +void DataTable::RemoveFromIndexes(Vector &row_identifiers, idx_t count) { + D_ASSERT(is_root); + row_groups->RemoveFromIndexes(info->indexes, row_identifiers, count); +} + +//===--------------------------------------------------------------------===// +// Delete +//===--------------------------------------------------------------------===// +static bool TableHasDeleteConstraints(TableCatalogEntry &table) { + auto &bound_constraints = table.GetBoundConstraints(); + for (auto &constraint : bound_constraints) { + switch (constraint->type) { + case ConstraintType::NOT_NULL: + case ConstraintType::CHECK: + case ConstraintType::UNIQUE: + break; + case ConstraintType::FOREIGN_KEY: { + auto &bfk = *reinterpret_cast(constraint.get()); + if (bfk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE || + bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + return true; + } + break; + } + default: + throw NotImplementedException("Constraint type not implemented!"); + } + } + return false; +} + +void DataTable::VerifyDeleteConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk) { + auto &bound_constraints = table.GetBoundConstraints(); + for (auto &constraint : bound_constraints) { + switch (constraint->type) { + case ConstraintType::NOT_NULL: + case ConstraintType::CHECK: + case ConstraintType::UNIQUE: + break; + case ConstraintType::FOREIGN_KEY: { + auto &bfk = *reinterpret_cast(constraint.get()); + if (bfk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE || + bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + VerifyDeleteForeignKeyConstraint(bfk, context, chunk); + } + break; + } + default: + throw NotImplementedException("Constraint type not implemented!"); + } + } +} + +idx_t DataTable::Delete(TableCatalogEntry &table, ClientContext &context, Vector &row_identifiers, idx_t count) { + D_ASSERT(row_identifiers.GetType().InternalType() == ROW_TYPE); + if (count == 0) { + return 0; + } + + auto &transaction = DuckTransaction::Get(context, db); + auto &local_storage = LocalStorage::Get(transaction); + bool has_delete_constraints = TableHasDeleteConstraints(table); + + row_identifiers.Flatten(count); + auto ids = FlatVector::GetData(row_identifiers); + + DataChunk verify_chunk; + vector col_ids; + vector types; + ColumnFetchState fetch_state; + if (has_delete_constraints) { + // initialize the chunk if there are any constraints to verify + for (idx_t i = 0; i < column_definitions.size(); i++) { + col_ids.push_back(column_definitions[i].StorageOid()); + types.emplace_back(column_definitions[i].Type()); + } + verify_chunk.Initialize(Allocator::Get(context), types); + } + idx_t pos = 0; + idx_t delete_count = 0; + while (pos < count) { + idx_t start = pos; + bool is_transaction_delete = ids[pos] >= MAX_ROW_ID; + // figure out which batch of rows to delete now + for (pos++; pos < count; pos++) { + bool row_is_transaction_delete = ids[pos] >= MAX_ROW_ID; + if (row_is_transaction_delete != is_transaction_delete) { + break; + } + } + idx_t current_offset = start; + idx_t current_count = pos - start; + + Vector offset_ids(row_identifiers, current_offset, pos); + if (is_transaction_delete) { + // transaction-local delete + if (has_delete_constraints) { + // perform the constraint verification + local_storage.FetchChunk(*this, offset_ids, current_count, col_ids, verify_chunk, fetch_state); + VerifyDeleteConstraints(table, context, verify_chunk); + } + delete_count += local_storage.Delete(*this, offset_ids, current_count); + } else { + // regular table delete + if (has_delete_constraints) { + // perform the constraint verification + Fetch(transaction, verify_chunk, col_ids, offset_ids, current_count, fetch_state); + VerifyDeleteConstraints(table, context, verify_chunk); + } + delete_count += row_groups->Delete(transaction, *this, ids + current_offset, current_count); + } + } + return delete_count; +} + +//===--------------------------------------------------------------------===// +// Update +//===--------------------------------------------------------------------===// +static void CreateMockChunk(vector &types, const vector &column_ids, DataChunk &chunk, + DataChunk &mock_chunk) { + // construct a mock DataChunk + mock_chunk.InitializeEmpty(types); + for (column_t i = 0; i < column_ids.size(); i++) { + mock_chunk.data[column_ids[i].index].Reference(chunk.data[i]); + } + mock_chunk.SetCardinality(chunk.size()); +} + +static bool CreateMockChunk(TableCatalogEntry &table, const vector &column_ids, + physical_index_set_t &desired_column_ids, DataChunk &chunk, DataChunk &mock_chunk) { + idx_t found_columns = 0; + // check whether the desired columns are present in the UPDATE clause + for (column_t i = 0; i < column_ids.size(); i++) { + if (desired_column_ids.find(column_ids[i]) != desired_column_ids.end()) { + found_columns++; + } + } + if (found_columns == 0) { + // no columns were found: no need to check the constraint again + return false; + } + if (found_columns != desired_column_ids.size()) { + // not all columns in UPDATE clause are present! + // this should not be triggered at all as the binder should add these columns + throw InternalException("Not all columns required for the CHECK constraint are present in the UPDATED chunk!"); + } + // construct a mock DataChunk + auto types = table.GetTypes(); + CreateMockChunk(types, column_ids, chunk, mock_chunk); + return true; +} + +void DataTable::VerifyUpdateConstraints(ClientContext &context, TableCatalogEntry &table, DataChunk &chunk, + const vector &column_ids) { + auto &constraints = table.GetConstraints(); + auto &bound_constraints = table.GetBoundConstraints(); + for (idx_t constr_idx = 0; constr_idx < bound_constraints.size(); constr_idx++) { + auto &base_constraint = constraints[constr_idx]; + auto &constraint = bound_constraints[constr_idx]; + switch (constraint->type) { + case ConstraintType::NOT_NULL: { + auto &bound_not_null = *reinterpret_cast(constraint.get()); + auto ¬_null = *reinterpret_cast(base_constraint.get()); + // check if the constraint is in the list of column_ids + for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { + if (column_ids[col_idx] == bound_not_null.index) { + // found the column id: check the data in + auto &col = table.GetColumn(LogicalIndex(not_null.index)); + VerifyNotNullConstraint(table, chunk.data[col_idx], chunk.size(), col.Name()); + break; + } + } + break; + } + case ConstraintType::CHECK: { + auto &check = *reinterpret_cast(constraint.get()); + + DataChunk mock_chunk; + if (CreateMockChunk(table, column_ids, check.bound_columns, chunk, mock_chunk)) { + VerifyCheckConstraint(context, table, *check.expression, mock_chunk); + } + break; + } + case ConstraintType::UNIQUE: + case ConstraintType::FOREIGN_KEY: + break; + default: + throw NotImplementedException("Constraint type not implemented!"); + } + } + // update should not be called for indexed columns! + // instead update should have been rewritten to delete + update on higher layer +#ifdef DEBUG + info->indexes.Scan([&](Index &index) { + D_ASSERT(!index.IndexIsUpdated(column_ids)); + return false; + }); + +#endif +} + +void DataTable::Update(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, + const vector &column_ids, DataChunk &updates) { + D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); + D_ASSERT(column_ids.size() == updates.ColumnCount()); + updates.Verify(); + + auto count = updates.size(); + if (count == 0) { + return; + } + + if (!is_root) { + throw TransactionException("Transaction conflict: cannot update a table that has been altered!"); + } + + // first verify that no constraints are violated + VerifyUpdateConstraints(context, table, updates, column_ids); + + // now perform the actual update + Vector max_row_id_vec(Value::BIGINT(MAX_ROW_ID)); + Vector row_ids_slice(LogicalType::BIGINT); + DataChunk updates_slice; + updates_slice.InitializeEmpty(updates.GetTypes()); + + SelectionVector sel_local_update(count), sel_global_update(count); + auto n_local_update = VectorOperations::GreaterThanEquals(row_ids, max_row_id_vec, nullptr, count, + &sel_local_update, &sel_global_update); + auto n_global_update = count - n_local_update; + + // row id > MAX_ROW_ID? transaction-local storage + if (n_local_update > 0) { + updates_slice.Slice(updates, sel_local_update, n_local_update); + updates_slice.Flatten(); + row_ids_slice.Slice(row_ids, sel_local_update, n_local_update); + row_ids_slice.Flatten(n_local_update); + + LocalStorage::Get(context, db).Update(*this, row_ids_slice, column_ids, updates_slice); + } + + // otherwise global storage + if (n_global_update > 0) { + updates_slice.Slice(updates, sel_global_update, n_global_update); + updates_slice.Flatten(); + row_ids_slice.Slice(row_ids, sel_global_update, n_global_update); + row_ids_slice.Flatten(n_global_update); + + row_groups->Update(DuckTransaction::Get(context, db), FlatVector::GetData(row_ids_slice), column_ids, + updates_slice); + } +} + +void DataTable::UpdateColumn(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, + const vector &column_path, DataChunk &updates) { + D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); + D_ASSERT(updates.ColumnCount() == 1); + updates.Verify(); + if (updates.size() == 0) { + return; + } + + if (!is_root) { + throw TransactionException("Transaction conflict: cannot update a table that has been altered!"); + } + + // now perform the actual update + auto &transaction = DuckTransaction::Get(context, db); + + updates.Flatten(); + row_ids.Flatten(updates.size()); + row_groups->UpdateColumn(transaction, row_ids, column_path, updates); +} + +//===--------------------------------------------------------------------===// +// Index Scan +//===--------------------------------------------------------------------===// +void DataTable::InitializeWALCreateIndexScan(CreateIndexScanState &state, const vector &column_ids) { + // we grab the append lock to make sure nothing is appended until AFTER we finish the index scan + state.append_lock = std::unique_lock(append_lock); + InitializeScan(state, column_ids); +} + +void DataTable::WALAddIndex(ClientContext &context, unique_ptr index, + const vector> &expressions) { + + // if the data table is empty + if (row_groups->IsEmpty()) { + info->indexes.AddIndex(std::move(index)); + return; + } + + auto &allocator = Allocator::Get(db); + + // intermediate holds scanned chunks of the underlying data to create the index + DataChunk intermediate; + vector intermediate_types; + vector column_ids; + for (auto &it : column_definitions) { + intermediate_types.push_back(it.Type()); + column_ids.push_back(it.Oid()); + } + column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); + intermediate_types.emplace_back(LogicalType::ROW_TYPE); + + intermediate.Initialize(allocator, intermediate_types); + + // holds the result of executing the index expression on the intermediate chunks + DataChunk result; + result.Initialize(allocator, index->logical_types); + + // initialize an index scan + CreateIndexScanState state; + InitializeWALCreateIndexScan(state, column_ids); + + if (!is_root) { + throw InternalException("Error during WAL replay. Cannot add an index to a table that has been altered."); + } + + // now start incrementally building the index + { + IndexLock lock; + index->InitializeLock(lock); + + while (true) { + intermediate.Reset(); + result.Reset(); + // scan a new chunk from the table to index + CreateIndexScan(state, intermediate, TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); + if (intermediate.size() == 0) { + // finished scanning for index creation + // release all locks + break; + } + // resolve the expressions for this chunk + index->ExecuteExpressions(intermediate, result); + + // insert into the index + auto error = index->Insert(lock, result, intermediate.data[intermediate.ColumnCount() - 1]); + if (error) { + throw InternalException("Error during WAL replay: %s", error.Message()); + } + } + } + + info->indexes.AddIndex(std::move(index)); +} + +//===--------------------------------------------------------------------===// +// Statistics +//===--------------------------------------------------------------------===// +unique_ptr DataTable::GetStatistics(ClientContext &context, column_t column_id) { + if (column_id == COLUMN_IDENTIFIER_ROW_ID) { + return nullptr; + } + return row_groups->CopyStats(column_id); +} + +void DataTable::SetDistinct(column_t column_id, unique_ptr distinct_stats) { + D_ASSERT(column_id != COLUMN_IDENTIFIER_ROW_ID); + row_groups->SetDistinct(column_id, std::move(distinct_stats)); +} + +//===--------------------------------------------------------------------===// +// Checkpoint +//===--------------------------------------------------------------------===// +void DataTable::Checkpoint(TableDataWriter &writer, Serializer &metadata_serializer) { + // checkpoint each individual row group + // FIXME: we might want to combine adjacent row groups in case they have had deletions... + TableStatistics global_stats; + row_groups->CopyStats(global_stats); + + row_groups->Checkpoint(writer, global_stats); + + // The rowgroup payload data has been written. Now write: + // column stats + // row-group pointers + // table pointer + // index data + writer.FinalizeTable(std::move(global_stats), info.get(), metadata_serializer); +} + +void DataTable::CommitDropColumn(idx_t index) { + row_groups->CommitDropColumn(index); +} + +idx_t DataTable::GetTotalRows() { + return row_groups->GetTotalRows(); +} + +void DataTable::CommitDropTable() { + // commit a drop of this table: mark all blocks as modified so they can be reclaimed later on + row_groups->CommitDropTable(); +} + +//===--------------------------------------------------------------------===// +// GetColumnSegmentInfo +//===--------------------------------------------------------------------===// +vector DataTable::GetColumnSegmentInfo() { + return row_groups->GetColumnSegmentInfo(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/index.cpp b/src/duckdb/src/storage/index.cpp new file mode 100644 index 00000000..74b19c94 --- /dev/null +++ b/src/duckdb/src/storage/index.cpp @@ -0,0 +1,110 @@ +#include "duckdb/storage/index.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/storage/table/append_state.hpp" + +namespace duckdb { + +Index::Index(AttachedDatabase &db, IndexType type, TableIOManager &table_io_manager, + const vector &column_ids_p, const vector> &unbound_expressions, + IndexConstraintType constraint_type_p) + + : type(type), table_io_manager(table_io_manager), column_ids(column_ids_p), constraint_type(constraint_type_p), + db(db) { + + for (auto &expr : unbound_expressions) { + types.push_back(expr->return_type.InternalType()); + logical_types.push_back(expr->return_type); + auto unbound_expression = expr->Copy(); + bound_expressions.push_back(BindExpression(unbound_expression->Copy())); + this->unbound_expressions.emplace_back(std::move(unbound_expression)); + } + for (auto &bound_expr : bound_expressions) { + executor.AddExpression(*bound_expr); + } + + // create the column id set + column_id_set.insert(column_ids.begin(), column_ids.end()); +} + +void Index::InitializeLock(IndexLock &state) { + state.index_lock = unique_lock(lock); +} + +PreservedError Index::Append(DataChunk &entries, Vector &row_identifiers) { + IndexLock state; + InitializeLock(state); + return Append(state, entries, row_identifiers); +} + +void Index::CommitDrop() { + IndexLock index_lock; + InitializeLock(index_lock); + CommitDrop(index_lock); +} + +void Index::Delete(DataChunk &entries, Vector &row_identifiers) { + IndexLock state; + InitializeLock(state); + Delete(state, entries, row_identifiers); +} + +bool Index::MergeIndexes(Index &other_index) { + IndexLock state; + InitializeLock(state); + return MergeIndexes(state, other_index); +} + +string Index::VerifyAndToString(const bool only_verify) { + IndexLock state; + InitializeLock(state); + return VerifyAndToString(state, only_verify); +} + +void Index::Vacuum() { + IndexLock state; + InitializeLock(state); + Vacuum(state); +} + +void Index::ExecuteExpressions(DataChunk &input, DataChunk &result) { + executor.Execute(input, result); +} + +unique_ptr Index::BindExpression(unique_ptr expr) { + if (expr->type == ExpressionType::BOUND_COLUMN_REF) { + auto &bound_colref = expr->Cast(); + return make_uniq(expr->return_type, column_ids[bound_colref.binding.column_index]); + } + ExpressionIterator::EnumerateChildren( + *expr, [this](unique_ptr &expr) { expr = BindExpression(std::move(expr)); }); + return expr; +} + +bool Index::IndexIsUpdated(const vector &column_ids) const { + for (auto &column : column_ids) { + if (column_id_set.find(column.index) != column_id_set.end()) { + return true; + } + } + return false; +} + +BlockPointer Index::Serialize(MetadataWriter &writer) { + throw NotImplementedException("The implementation of this index serialization does not exist."); +} + +string Index::AppendRowError(DataChunk &input, idx_t index) { + string error; + for (idx_t c = 0; c < input.ColumnCount(); c++) { + if (c > 0) { + error += ", "; + } + error += input.GetValue(c, index).ToString(); + } + return error; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/local_storage.cpp b/src/duckdb/src/storage/local_storage.cpp new file mode 100644 index 00000000..dcad346c --- /dev/null +++ b/src/duckdb/src/storage/local_storage.cpp @@ -0,0 +1,570 @@ +#include "duckdb/transaction/local_storage.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/write_ahead_log.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/partial_block_manager.hpp" + +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table_io_manager.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +LocalTableStorage::LocalTableStorage(DataTable &table) + : table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), optimistic_writer(table), + merged_storage(false) { + auto types = table.GetTypes(); + row_groups = make_shared(table.info, TableIOManager::Get(table).GetBlockManagerForRowData(), + types, MAX_ROW_ID, 0); + row_groups->InitializeEmpty(); + + table.info->indexes.Scan([&](Index &index) { + D_ASSERT(index.type == IndexType::ART); + auto &art = index.Cast(); + if (art.constraint_type != IndexConstraintType::NONE) { + // unique index: create a local ART index that maintains the same unique constraint + vector> unbound_expressions; + unbound_expressions.reserve(art.unbound_expressions.size()); + for (auto &expr : art.unbound_expressions) { + unbound_expressions.push_back(expr->Copy()); + } + indexes.AddIndex(make_uniq(art.column_ids, art.table_io_manager, std::move(unbound_expressions), + art.constraint_type, art.db)); + } + return false; + }); +} + +LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_dt, LocalTableStorage &parent, + idx_t changed_idx, const LogicalType &target_type, + const vector &bound_columns, Expression &cast_expr) + : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), + optimistic_writer(new_dt, parent.optimistic_writer), optimistic_writers(std::move(parent.optimistic_writers)), + merged_storage(parent.merged_storage) { + row_groups = parent.row_groups->AlterType(context, changed_idx, target_type, bound_columns, cast_expr); + parent.row_groups.reset(); + indexes.Move(parent.indexes); +} + +LocalTableStorage::LocalTableStorage(DataTable &new_dt, LocalTableStorage &parent, idx_t drop_idx) + : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), + optimistic_writer(new_dt, parent.optimistic_writer), optimistic_writers(std::move(parent.optimistic_writers)), + merged_storage(parent.merged_storage) { + row_groups = parent.row_groups->RemoveColumn(drop_idx); + parent.row_groups.reset(); + indexes.Move(parent.indexes); +} + +LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_dt, LocalTableStorage &parent, + ColumnDefinition &new_column, Expression &default_value) + : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), + optimistic_writer(new_dt, parent.optimistic_writer), optimistic_writers(std::move(parent.optimistic_writers)), + merged_storage(parent.merged_storage) { + row_groups = parent.row_groups->AddColumn(context, new_column, default_value); + parent.row_groups.reset(); + indexes.Move(parent.indexes); +} + +LocalTableStorage::~LocalTableStorage() { +} + +void LocalTableStorage::InitializeScan(CollectionScanState &state, optional_ptr table_filters) { + if (row_groups->GetTotalRows() == 0) { + throw InternalException("No rows in LocalTableStorage row group for scan"); + } + row_groups->InitializeScan(state, state.GetColumnIds(), table_filters.get()); +} + +idx_t LocalTableStorage::EstimatedSize() { + idx_t appended_rows = row_groups->GetTotalRows() - deleted_rows; + idx_t row_size = 0; + auto &types = row_groups->GetTypes(); + for (auto &type : types) { + row_size += GetTypeIdSize(type.InternalType()); + } + return appended_rows * row_size; +} + +void LocalTableStorage::WriteNewRowGroup() { + if (deleted_rows != 0) { + // we have deletes - we cannot merge row groups + return; + } + optimistic_writer.WriteNewRowGroup(*row_groups); +} + +void LocalTableStorage::FlushBlocks() { + if (!merged_storage && row_groups->GetTotalRows() > Storage::ROW_GROUP_SIZE) { + optimistic_writer.WriteLastRowGroup(*row_groups); + } + optimistic_writer.FinalFlush(); +} + +PreservedError LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, RowGroupCollection &source, + TableIndexList &index_list, const vector &table_types, + row_t &start_row) { + // only need to scan for index append + // figure out which columns we need to scan for the set of indexes + auto columns = index_list.GetRequiredColumns(); + // create an empty mock chunk that contains all the correct types for the table + DataChunk mock_chunk; + mock_chunk.InitializeEmpty(table_types); + PreservedError error; + source.Scan(transaction, columns, [&](DataChunk &chunk) -> bool { + // construct the mock chunk by referencing the required columns + for (idx_t i = 0; i < columns.size(); i++) { + mock_chunk.data[columns[i]].Reference(chunk.data[i]); + } + mock_chunk.SetCardinality(chunk); + // append this chunk to the indexes of the table + error = DataTable::AppendToIndexes(index_list, mock_chunk, start_row); + if (error) { + return false; + } + start_row += chunk.size(); + return true; + }); + return error; +} + +void LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, TableAppendState &append_state, + idx_t append_count, bool append_to_table) { + auto &table = table_ref.get(); + if (append_to_table) { + table.InitializeAppend(transaction, append_state, append_count); + } + PreservedError error; + if (append_to_table) { + // appending: need to scan entire + row_groups->Scan(transaction, [&](DataChunk &chunk) -> bool { + // append this chunk to the indexes of the table + error = table.AppendToIndexes(chunk, append_state.current_row); + if (error) { + return false; + } + // append to base table + table.Append(chunk, append_state); + return true; + }); + } else { + error = + AppendToIndexes(transaction, *row_groups, table.info->indexes, table.GetTypes(), append_state.current_row); + } + if (error) { + // need to revert all appended row ids + row_t current_row = append_state.row_start; + // remove the data from the indexes, if there are any indexes + row_groups->Scan(transaction, [&](DataChunk &chunk) -> bool { + // append this chunk to the indexes of the table + try { + table.RemoveFromIndexes(append_state, chunk, current_row); + } catch (Exception &ex) { + error = PreservedError(ex); + return false; + } catch (std::exception &ex) { // LCOV_EXCL_START + error = PreservedError(ex); + return false; + } // LCOV_EXCL_STOP + + current_row += chunk.size(); + if (current_row >= append_state.current_row) { + // finished deleting all rows from the index: abort now + return false; + } + return true; + }); + if (append_to_table) { + table.RevertAppendInternal(append_state.row_start); + } + + // we need to vacuum the indexes to remove any buffers that are now empty + // due to reverting the appends + table.info->indexes.Scan([&](Index &index) { + index.Vacuum(); + return false; + }); + error.Throw(); + } +} + +OptimisticDataWriter &LocalTableStorage::CreateOptimisticWriter() { + auto writer = make_uniq(table_ref.get()); + optimistic_writers.push_back(std::move(writer)); + return *optimistic_writers.back(); +} + +void LocalTableStorage::FinalizeOptimisticWriter(OptimisticDataWriter &writer) { + // remove the writer from the set of optimistic writers + unique_ptr owned_writer; + for (idx_t i = 0; i < optimistic_writers.size(); i++) { + if (optimistic_writers[i].get() == &writer) { + owned_writer = std::move(optimistic_writers[i]); + optimistic_writers.erase(optimistic_writers.begin() + i); + break; + } + } + if (!owned_writer) { + throw InternalException("Error in FinalizeOptimisticWriter - could not find writer"); + } + optimistic_writer.Merge(*owned_writer); +} + +void LocalTableStorage::Rollback() { + for (auto &writer : optimistic_writers) { + writer->Rollback(); + } + optimistic_writers.clear(); + optimistic_writer.Rollback(); +} + +//===--------------------------------------------------------------------===// +// LocalTableManager +//===--------------------------------------------------------------------===// +optional_ptr LocalTableManager::GetStorage(DataTable &table) { + lock_guard l(table_storage_lock); + auto entry = table_storage.find(table); + return entry == table_storage.end() ? nullptr : entry->second.get(); +} + +LocalTableStorage &LocalTableManager::GetOrCreateStorage(DataTable &table) { + lock_guard l(table_storage_lock); + auto entry = table_storage.find(table); + if (entry == table_storage.end()) { + auto new_storage = make_shared(table); + auto storage = new_storage.get(); + table_storage.insert(make_pair(reference(table), std::move(new_storage))); + return *storage; + } else { + return *entry->second.get(); + } +} + +bool LocalTableManager::IsEmpty() { + lock_guard l(table_storage_lock); + return table_storage.empty(); +} + +shared_ptr LocalTableManager::MoveEntry(DataTable &table) { + lock_guard l(table_storage_lock); + auto entry = table_storage.find(table); + if (entry == table_storage.end()) { + return nullptr; + } + auto storage_entry = std::move(entry->second); + table_storage.erase(entry); + return storage_entry; +} + +reference_map_t> LocalTableManager::MoveEntries() { + lock_guard l(table_storage_lock); + return std::move(table_storage); +} + +idx_t LocalTableManager::EstimatedSize() { + lock_guard l(table_storage_lock); + idx_t estimated_size = 0; + for (auto &storage : table_storage) { + estimated_size += storage.second->EstimatedSize(); + } + return estimated_size; +} + +void LocalTableManager::InsertEntry(DataTable &table, shared_ptr entry) { + lock_guard l(table_storage_lock); + D_ASSERT(table_storage.find(table) == table_storage.end()); + table_storage[table] = std::move(entry); +} + +//===--------------------------------------------------------------------===// +// LocalStorage +//===--------------------------------------------------------------------===// +LocalStorage::LocalStorage(ClientContext &context, DuckTransaction &transaction) + : context(context), transaction(transaction) { +} + +LocalStorage::CommitState::CommitState() { +} + +LocalStorage::CommitState::~CommitState() { +} + +LocalStorage &LocalStorage::Get(DuckTransaction &transaction) { + return transaction.GetLocalStorage(); +} + +LocalStorage &LocalStorage::Get(ClientContext &context, AttachedDatabase &db) { + return DuckTransaction::Get(context, db).GetLocalStorage(); +} + +LocalStorage &LocalStorage::Get(ClientContext &context, Catalog &catalog) { + return LocalStorage::Get(context, catalog.GetAttached()); +} + +void LocalStorage::InitializeScan(DataTable &table, CollectionScanState &state, + optional_ptr table_filters) { + auto storage = table_manager.GetStorage(table); + if (storage == nullptr) { + return; + } + storage->InitializeScan(state, table_filters); +} + +void LocalStorage::Scan(CollectionScanState &state, const vector &column_ids, DataChunk &result) { + state.Scan(transaction, result); +} + +void LocalStorage::InitializeParallelScan(DataTable &table, ParallelCollectionScanState &state) { + auto storage = table_manager.GetStorage(table); + if (!storage) { + state.max_row = 0; + state.vector_index = 0; + state.current_row_group = nullptr; + } else { + storage->row_groups->InitializeParallelScan(state); + } +} + +bool LocalStorage::NextParallelScan(ClientContext &context, DataTable &table, ParallelCollectionScanState &state, + CollectionScanState &scan_state) { + auto storage = table_manager.GetStorage(table); + if (!storage) { + return false; + } + return storage->row_groups->NextParallelScan(context, state, scan_state); +} + +void LocalStorage::InitializeAppend(LocalAppendState &state, DataTable &table) { + state.storage = &table_manager.GetOrCreateStorage(table); + state.storage->row_groups->InitializeAppend(TransactionData(transaction), state.append_state, 0); +} + +void LocalStorage::Append(LocalAppendState &state, DataChunk &chunk) { + // append to unique indices (if any) + auto storage = state.storage; + idx_t base_id = MAX_ROW_ID + storage->row_groups->GetTotalRows() + state.append_state.total_append_count; + auto error = DataTable::AppendToIndexes(storage->indexes, chunk, base_id); + if (error) { + error.Throw(); + } + + //! Append the chunk to the local storage + auto new_row_group = storage->row_groups->Append(chunk, state.append_state); + //! Check if we should pre-emptively flush blocks to disk + if (new_row_group) { + storage->WriteNewRowGroup(); + } +} + +void LocalStorage::FinalizeAppend(LocalAppendState &state) { + state.storage->row_groups->FinalizeAppend(state.append_state.transaction, state.append_state); +} + +void LocalStorage::LocalMerge(DataTable &table, RowGroupCollection &collection) { + auto &storage = table_manager.GetOrCreateStorage(table); + if (!storage.indexes.Empty()) { + // append data to indexes if required + row_t base_id = MAX_ROW_ID + storage.row_groups->GetTotalRows(); + auto error = storage.AppendToIndexes(transaction, collection, storage.indexes, table.GetTypes(), base_id); + if (error) { + error.Throw(); + } + } + storage.row_groups->MergeStorage(collection); + storage.merged_storage = true; +} + +OptimisticDataWriter &LocalStorage::CreateOptimisticWriter(DataTable &table) { + auto &storage = table_manager.GetOrCreateStorage(table); + return storage.CreateOptimisticWriter(); +} + +void LocalStorage::FinalizeOptimisticWriter(DataTable &table, OptimisticDataWriter &writer) { + auto &storage = table_manager.GetOrCreateStorage(table); + storage.FinalizeOptimisticWriter(writer); +} + +bool LocalStorage::ChangesMade() noexcept { + return !table_manager.IsEmpty(); +} + +bool LocalStorage::Find(DataTable &table) { + return table_manager.GetStorage(table) != nullptr; +} + +idx_t LocalStorage::EstimatedSize() { + return table_manager.EstimatedSize(); +} + +idx_t LocalStorage::Delete(DataTable &table, Vector &row_ids, idx_t count) { + auto storage = table_manager.GetStorage(table); + D_ASSERT(storage); + + // delete from unique indices (if any) + if (!storage->indexes.Empty()) { + storage->row_groups->RemoveFromIndexes(storage->indexes, row_ids, count); + } + + auto ids = FlatVector::GetData(row_ids); + idx_t delete_count = storage->row_groups->Delete(TransactionData(0, 0), table, ids, count); + storage->deleted_rows += delete_count; + return delete_count; +} + +void LocalStorage::Update(DataTable &table, Vector &row_ids, const vector &column_ids, + DataChunk &updates) { + auto storage = table_manager.GetStorage(table); + D_ASSERT(storage); + + auto ids = FlatVector::GetData(row_ids); + storage->row_groups->Update(TransactionData(0, 0), ids, column_ids, updates); +} + +void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage) { + if (storage.row_groups->GetTotalRows() <= storage.deleted_rows) { + return; + } + idx_t append_count = storage.row_groups->GetTotalRows() - storage.deleted_rows; + + TableAppendState append_state; + table.AppendLock(append_state); + transaction.PushAppend(table, append_state.row_start, append_count); + if ((append_state.row_start == 0 || storage.row_groups->GetTotalRows() >= MERGE_THRESHOLD) && + storage.deleted_rows == 0) { + // table is currently empty OR we are bulk appending: move over the storage directly + // first flush any outstanding blocks + storage.FlushBlocks(); + // now append to the indexes (if there are any) + // FIXME: we should be able to merge the transaction-local index directly into the main table index + // as long we just rewrite some row-ids + if (!table.info->indexes.Empty()) { + storage.AppendToIndexes(transaction, append_state, append_count, false); + } + // finally move over the row groups + table.MergeStorage(*storage.row_groups, storage.indexes); + } else { + // check if we have written data + // if we have, we cannot merge to disk after all + // so we need to revert the data we have already written + storage.Rollback(); + // append to the indexes and append to the base table + storage.AppendToIndexes(transaction, append_state, append_count, true); + } + + // possibly vacuum any excess index data + table.info->indexes.Scan([&](Index &index) { + index.Vacuum(); + return false; + }); +} + +void LocalStorage::Commit(LocalStorage::CommitState &commit_state, DuckTransaction &transaction) { + // commit local storage + // iterate over all entries in the table storage map and commit them + // after this, the local storage is no longer required and can be cleared + auto table_storage = table_manager.MoveEntries(); + for (auto &entry : table_storage) { + auto table = entry.first; + auto storage = entry.second.get(); + Flush(table, *storage); + entry.second.reset(); + } +} + +void LocalStorage::Rollback() { + // rollback local storage + // after this, the local storage is no longer required and can be cleared + auto table_storage = table_manager.MoveEntries(); + for (auto &entry : table_storage) { + auto storage = entry.second.get(); + if (!storage) { + continue; + } + storage->Rollback(); + + entry.second.reset(); + } +} + +idx_t LocalStorage::AddedRows(DataTable &table) { + auto storage = table_manager.GetStorage(table); + if (!storage) { + return 0; + } + return storage->row_groups->GetTotalRows() - storage->deleted_rows; +} + +void LocalStorage::MoveStorage(DataTable &old_dt, DataTable &new_dt) { + // check if there are any pending appends for the old version of the table + auto new_storage = table_manager.MoveEntry(old_dt); + if (!new_storage) { + return; + } + // take over the storage from the old entry + new_storage->table_ref = new_dt; + table_manager.InsertEntry(new_dt, std::move(new_storage)); +} + +void LocalStorage::AddColumn(DataTable &old_dt, DataTable &new_dt, ColumnDefinition &new_column, + Expression &default_value) { + // check if there are any pending appends for the old version of the table + auto storage = table_manager.MoveEntry(old_dt); + if (!storage) { + return; + } + auto new_storage = make_shared(context, new_dt, *storage, new_column, default_value); + table_manager.InsertEntry(new_dt, std::move(new_storage)); +} + +void LocalStorage::DropColumn(DataTable &old_dt, DataTable &new_dt, idx_t removed_column) { + // check if there are any pending appends for the old version of the table + auto storage = table_manager.MoveEntry(old_dt); + if (!storage) { + return; + } + auto new_storage = make_shared(new_dt, *storage, removed_column); + table_manager.InsertEntry(new_dt, std::move(new_storage)); +} + +void LocalStorage::ChangeType(DataTable &old_dt, DataTable &new_dt, idx_t changed_idx, const LogicalType &target_type, + const vector &bound_columns, Expression &cast_expr) { + // check if there are any pending appends for the old version of the table + auto storage = table_manager.MoveEntry(old_dt); + if (!storage) { + return; + } + auto new_storage = + make_shared(context, new_dt, *storage, changed_idx, target_type, bound_columns, cast_expr); + table_manager.InsertEntry(new_dt, std::move(new_storage)); +} + +void LocalStorage::FetchChunk(DataTable &table, Vector &row_ids, idx_t count, const vector &col_ids, + DataChunk &chunk, ColumnFetchState &fetch_state) { + auto storage = table_manager.GetStorage(table); + if (!storage) { + throw InternalException("LocalStorage::FetchChunk - local storage not found"); + } + + storage->row_groups->Fetch(transaction, chunk, col_ids, row_ids, count, fetch_state); +} + +TableIndexList &LocalStorage::GetIndexes(DataTable &table) { + auto storage = table_manager.GetStorage(table); + if (!storage) { + throw InternalException("LocalStorage::GetIndexes - local storage not found"); + } + return storage->indexes; +} + +void LocalStorage::VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint) { + auto storage = table_manager.GetStorage(parent); + if (!storage) { + return; + } + storage->row_groups->VerifyNewConstraint(parent, constraint); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/magic_bytes.cpp b/src/duckdb/src/storage/magic_bytes.cpp new file mode 100644 index 00000000..f004d754 --- /dev/null +++ b/src/duckdb/src/storage/magic_bytes.cpp @@ -0,0 +1,31 @@ +#include "duckdb/storage/magic_bytes.hpp" +#include "duckdb/common/local_file_system.hpp" +#include "duckdb/storage/storage_info.hpp" + +namespace duckdb { + +DataFileType MagicBytes::CheckMagicBytes(FileSystem *fs_p, const string &path) { + LocalFileSystem lfs; + FileSystem &fs = fs_p ? *fs_p : lfs; + if (!fs.FileExists(path)) { + return DataFileType::FILE_DOES_NOT_EXIST; + } + auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ); + + constexpr const idx_t MAGIC_BYTES_READ_SIZE = 16; + char buffer[MAGIC_BYTES_READ_SIZE]; + + handle->Read(buffer, MAGIC_BYTES_READ_SIZE); + if (memcmp(buffer, "SQLite format 3\0", 16) == 0) { + return DataFileType::SQLITE_FILE; + } + if (memcmp(buffer, "PAR1", 4) == 0) { + return DataFileType::PARQUET_FILE; + } + if (memcmp(buffer + MainHeader::MAGIC_BYTE_OFFSET, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) == 0) { + return DataFileType::DUCKDB_FILE; + } + return DataFileType::FILE_DOES_NOT_EXIST; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/metadata/metadata_manager.cpp b/src/duckdb/src/storage/metadata/metadata_manager.cpp new file mode 100644 index 00000000..0bd58d68 --- /dev/null +++ b/src/duckdb/src/storage/metadata/metadata_manager.cpp @@ -0,0 +1,322 @@ +#include "duckdb/storage/metadata/metadata_manager.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/common/serializer/write_stream.hpp" +#include "duckdb/common/serializer/read_stream.hpp" +#include "duckdb/storage/database_size.hpp" + +namespace duckdb { + +MetadataManager::MetadataManager(BlockManager &block_manager, BufferManager &buffer_manager) + : block_manager(block_manager), buffer_manager(buffer_manager) { +} + +MetadataManager::~MetadataManager() { +} + +MetadataHandle MetadataManager::AllocateHandle() { + // check if there is any free space left in an existing block + // if not allocate a new block + block_id_t free_block = INVALID_BLOCK; + for (auto &kv : blocks) { + auto &block = kv.second; + D_ASSERT(kv.first == block.block_id); + if (!block.free_blocks.empty()) { + free_block = kv.first; + break; + } + } + if (free_block == INVALID_BLOCK) { + free_block = AllocateNewBlock(); + } + D_ASSERT(free_block != INVALID_BLOCK); + + // select the first free metadata block we can find + MetadataPointer pointer; + pointer.block_index = free_block; + auto &block = blocks[free_block]; + if (block.block->BlockId() < MAXIMUM_BLOCK) { + // this block is a disk-backed block, yet we are planning to write to it + // we need to convert it into a transient block before we can write to it + ConvertToTransient(block); + D_ASSERT(block.block->BlockId() >= MAXIMUM_BLOCK); + } + D_ASSERT(!block.free_blocks.empty()); + pointer.index = block.free_blocks.back(); + // mark the block as used + block.free_blocks.pop_back(); + D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); + // pin the block + return Pin(pointer); +} + +MetadataHandle MetadataManager::Pin(MetadataPointer pointer) { + D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); + auto &block = blocks[pointer.block_index]; + + MetadataHandle handle; + handle.pointer.block_index = pointer.block_index; + handle.pointer.index = pointer.index; + handle.handle = buffer_manager.Pin(block.block); + return handle; +} + +void MetadataManager::ConvertToTransient(MetadataBlock &block) { + // pin the old block + auto old_buffer = buffer_manager.Pin(block.block); + + // allocate a new transient block to replace it + shared_ptr new_block; + auto new_buffer = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block); + + // copy the data to the transient block + memcpy(new_buffer.Ptr(), old_buffer.Ptr(), Storage::BLOCK_SIZE); + + block.block = std::move(new_block); + + // unregister the old block + block_manager.UnregisterBlock(block.block_id, false); +} + +block_id_t MetadataManager::AllocateNewBlock() { + auto new_block_id = GetNextBlockId(); + + MetadataBlock new_block; + auto handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block.block); + new_block.block_id = new_block_id; + for (idx_t i = 0; i < METADATA_BLOCK_COUNT; i++) { + new_block.free_blocks.push_back(METADATA_BLOCK_COUNT - i - 1); + } + // zero-initialize the handle + memset(handle.Ptr(), 0, Storage::BLOCK_SIZE); + AddBlock(std::move(new_block)); + return new_block_id; +} + +void MetadataManager::AddBlock(MetadataBlock new_block, bool if_exists) { + if (blocks.find(new_block.block_id) != blocks.end()) { + if (if_exists) { + return; + } + throw InternalException("Block id with id %llu already exists", new_block.block_id); + } + blocks[new_block.block_id] = std::move(new_block); +} + +void MetadataManager::AddAndRegisterBlock(MetadataBlock block) { + if (block.block) { + throw InternalException("Calling AddAndRegisterBlock on block that already exists"); + } + block.block = block_manager.RegisterBlock(block.block_id); + AddBlock(std::move(block), true); +} + +MetaBlockPointer MetadataManager::GetDiskPointer(MetadataPointer pointer, uint32_t offset) { + idx_t block_pointer = idx_t(pointer.block_index); + block_pointer |= idx_t(pointer.index) << 56ULL; + return MetaBlockPointer(block_pointer, offset); +} + +block_id_t MetaBlockPointer::GetBlockId() const { + return block_id_t(block_pointer & ~(idx_t(0xFF) << 56ULL)); +} + +uint32_t MetaBlockPointer::GetBlockIndex() const { + return block_pointer >> 56ULL; +} + +MetadataPointer MetadataManager::FromDiskPointer(MetaBlockPointer pointer) { + auto block_id = pointer.GetBlockId(); + auto index = pointer.GetBlockIndex(); + auto entry = blocks.find(block_id); + if (entry == blocks.end()) { // LCOV_EXCL_START + throw InternalException("Failed to load metadata pointer (id %llu, idx %llu, ptr %llu)\n", block_id, index, + pointer.block_pointer); + } // LCOV_EXCL_STOP + MetadataPointer result; + result.block_index = block_id; + result.index = index; + return result; +} + +MetadataPointer MetadataManager::RegisterDiskPointer(MetaBlockPointer pointer) { + auto block_id = pointer.GetBlockId(); + MetadataBlock block; + block.block_id = block_id; + AddAndRegisterBlock(block); + return FromDiskPointer(pointer); +} + +BlockPointer MetadataManager::ToBlockPointer(MetaBlockPointer meta_pointer) { + BlockPointer result; + result.block_id = meta_pointer.GetBlockId(); + result.offset = meta_pointer.GetBlockIndex() * MetadataManager::METADATA_BLOCK_SIZE + meta_pointer.offset; + D_ASSERT(result.offset < MetadataManager::METADATA_BLOCK_SIZE * MetadataManager::METADATA_BLOCK_COUNT); + return result; +} + +MetaBlockPointer MetadataManager::FromBlockPointer(BlockPointer block_pointer) { + if (!block_pointer.IsValid()) { + return MetaBlockPointer(); + } + idx_t index = block_pointer.offset / MetadataManager::METADATA_BLOCK_SIZE; + auto offset = block_pointer.offset % MetadataManager::METADATA_BLOCK_SIZE; + D_ASSERT(index < MetadataManager::METADATA_BLOCK_COUNT); + D_ASSERT(offset < MetadataManager::METADATA_BLOCK_SIZE); + MetaBlockPointer result; + result.block_pointer = idx_t(block_pointer.block_id) | index << 56ULL; + result.offset = offset; + return result; +} + +idx_t MetadataManager::BlockCount() { + return blocks.size(); +} + +void MetadataManager::Flush() { + const idx_t total_metadata_size = MetadataManager::METADATA_BLOCK_SIZE * MetadataManager::METADATA_BLOCK_COUNT; + // write the blocks of the metadata manager to disk + for (auto &kv : blocks) { + auto &block = kv.second; + auto handle = buffer_manager.Pin(block.block); + // there are a few bytes left-over at the end of the block, zero-initialize them + memset(handle.Ptr() + total_metadata_size, 0, Storage::BLOCK_SIZE - total_metadata_size); + D_ASSERT(kv.first == block.block_id); + if (block.block->BlockId() >= MAXIMUM_BLOCK) { + // temporary block - convert to persistent + block.block = block_manager.ConvertToPersistent(kv.first, std::move(block.block)); + } else { + // already a persistent block - only need to write it + D_ASSERT(block.block->BlockId() == block.block_id); + block_manager.Write(handle.GetFileBuffer(), block.block_id); + } + } +} + +void MetadataManager::Write(WriteStream &sink) { + sink.Write(blocks.size()); + for (auto &kv : blocks) { + kv.second.Write(sink); + } +} + +void MetadataManager::Read(ReadStream &source) { + auto block_count = source.Read(); + for (idx_t i = 0; i < block_count; i++) { + auto block = MetadataBlock::Read(source); + auto entry = blocks.find(block.block_id); + if (entry == blocks.end()) { + // block does not exist yet + AddAndRegisterBlock(std::move(block)); + } else { + // block was already created - only copy over the free list + entry->second.free_blocks = std::move(block.free_blocks); + } + } +} + +void MetadataBlock::Write(WriteStream &sink) { + sink.Write(block_id); + sink.Write(FreeBlocksToInteger()); +} + +MetadataBlock MetadataBlock::Read(ReadStream &source) { + MetadataBlock result; + result.block_id = source.Read(); + auto free_list = source.Read(); + result.FreeBlocksFromInteger(free_list); + return result; +} + +idx_t MetadataBlock::FreeBlocksToInteger() { + idx_t result = 0; + for (idx_t i = 0; i < free_blocks.size(); i++) { + D_ASSERT(free_blocks[i] < idx_t(64)); + idx_t mask = idx_t(1) << idx_t(free_blocks[i]); + result |= mask; + } + return result; +} + +void MetadataBlock::FreeBlocksFromInteger(idx_t free_list) { + free_blocks.clear(); + if (free_list == 0) { + return; + } + for (idx_t i = 64; i > 0; i--) { + auto index = i - 1; + idx_t mask = idx_t(1) << index; + if (free_list & mask) { + free_blocks.push_back(index); + } + } +} + +void MetadataManager::MarkBlocksAsModified() { + // for any blocks that were modified in the last checkpoint - set them to free blocks currently + for (auto &kv : modified_blocks) { + auto block_id = kv.first; + idx_t modified_list = kv.second; + auto entry = blocks.find(block_id); + D_ASSERT(entry != blocks.end()); + auto &block = entry->second; + idx_t current_free_blocks = block.FreeBlocksToInteger(); + // merge the current set of free blocks with the modified blocks + idx_t new_free_blocks = current_free_blocks | modified_list; + if (new_free_blocks == NumericLimits::Maximum()) { + // if new free_blocks is all blocks - mark entire block as modified + blocks.erase(entry); + block_manager.MarkBlockAsModified(block_id); + } else { + // set the new set of free blocks + block.FreeBlocksFromInteger(new_free_blocks); + } + } + + modified_blocks.clear(); + for (auto &kv : blocks) { + auto &block = kv.second; + idx_t free_list = block.FreeBlocksToInteger(); + idx_t occupied_list = ~free_list; + modified_blocks[block.block_id] = occupied_list; + } +} + +void MetadataManager::ClearModifiedBlocks(const vector &pointers) { + for (auto &pointer : pointers) { + auto block_id = pointer.GetBlockId(); + auto block_index = pointer.GetBlockIndex(); + auto entry = modified_blocks.find(block_id); + if (entry == modified_blocks.end()) { + throw InternalException("ClearModifiedBlocks - Block id %llu not found in modified_blocks", block_id); + } + auto &modified_list = entry->second; + // verify the block has been modified + D_ASSERT(modified_list && (1ULL << block_index)); + // unset the bit + modified_list &= ~(1ULL << block_index); + } +} + +vector MetadataManager::GetMetadataInfo() const { + vector result; + for (auto &block : blocks) { + MetadataBlockInfo block_info; + block_info.block_id = block.second.block_id; + block_info.total_blocks = MetadataManager::METADATA_BLOCK_COUNT; + for (auto free_block : block.second.free_blocks) { + block_info.free_list.push_back(free_block); + } + std::sort(block_info.free_list.begin(), block_info.free_list.end()); + result.push_back(std::move(block_info)); + } + std::sort(result.begin(), result.end(), + [](const MetadataBlockInfo &a, const MetadataBlockInfo &b) { return a.block_id < b.block_id; }); + return result; +} + +block_id_t MetadataManager::GetNextBlockId() { + return block_manager.GetFreeBlockId(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/metadata/metadata_reader.cpp b/src/duckdb/src/storage/metadata/metadata_reader.cpp new file mode 100644 index 00000000..fbc4dfb4 --- /dev/null +++ b/src/duckdb/src/storage/metadata/metadata_reader.cpp @@ -0,0 +1,89 @@ +#include "duckdb/storage/metadata/metadata_reader.hpp" + +namespace duckdb { + +MetadataReader::MetadataReader(MetadataManager &manager, MetaBlockPointer pointer, + optional_ptr> read_pointers_p, BlockReaderType type) + : manager(manager), type(type), next_pointer(FromDiskPointer(pointer)), has_next_block(true), + read_pointers(read_pointers_p), index(0), offset(0), next_offset(pointer.offset), capacity(0) { + if (read_pointers) { + D_ASSERT(read_pointers->empty()); + read_pointers->push_back(pointer); + } +} + +MetadataReader::MetadataReader(MetadataManager &manager, BlockPointer pointer) + : MetadataReader(manager, MetadataManager::FromBlockPointer(pointer)) { +} + +MetadataPointer MetadataReader::FromDiskPointer(MetaBlockPointer pointer) { + if (type == BlockReaderType::EXISTING_BLOCKS) { + return manager.FromDiskPointer(pointer); + } else { + return manager.RegisterDiskPointer(pointer); + } +} + +MetadataReader::~MetadataReader() { +} + +void MetadataReader::ReadData(data_ptr_t buffer, idx_t read_size) { + while (offset + read_size > capacity) { + // cannot read entire entry from block + // first read what we can from this block + idx_t to_read = capacity - offset; + if (to_read > 0) { + memcpy(buffer, Ptr(), to_read); + read_size -= to_read; + buffer += to_read; + offset += read_size; + } + // then move to the next block + ReadNextBlock(); + } + // we have enough left in this block to read from the buffer + memcpy(buffer, Ptr(), read_size); + offset += read_size; +} + +MetaBlockPointer MetadataReader::GetMetaBlockPointer() { + return manager.GetDiskPointer(block.pointer, offset); +} + +void MetadataReader::ReadNextBlock() { + if (!has_next_block) { + throw IOException("No more data remaining in MetadataReader"); + } + block = manager.Pin(next_pointer); + index = next_pointer.index; + + idx_t next_block = Load(BasePtr()); + if (next_block == idx_t(-1)) { + has_next_block = false; + } else { + next_pointer = FromDiskPointer(MetaBlockPointer(next_block, 0)); + MetaBlockPointer next_block_pointer(next_block, 0); + if (read_pointers) { + read_pointers->push_back(next_block_pointer); + } + } + if (next_offset < sizeof(block_id_t)) { + next_offset = sizeof(block_id_t); + } + if (next_offset > MetadataManager::METADATA_BLOCK_SIZE) { + throw InternalException("next_offset cannot be bigger than block size"); + } + offset = next_offset; + next_offset = sizeof(block_id_t); + capacity = MetadataManager::METADATA_BLOCK_SIZE; +} + +data_ptr_t MetadataReader::BasePtr() { + return block.handle.Ptr() + index * MetadataManager::METADATA_BLOCK_SIZE; +} + +data_ptr_t MetadataReader::Ptr() { + return BasePtr() + offset; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/metadata/metadata_writer.cpp b/src/duckdb/src/storage/metadata/metadata_writer.cpp new file mode 100644 index 00000000..47bb8d1f --- /dev/null +++ b/src/duckdb/src/storage/metadata/metadata_writer.cpp @@ -0,0 +1,92 @@ +#include "duckdb/storage/metadata/metadata_writer.hpp" +#include "duckdb/storage/block_manager.hpp" + +namespace duckdb { + +MetadataWriter::MetadataWriter(MetadataManager &manager, optional_ptr> written_pointers_p) + : manager(manager), written_pointers(written_pointers_p), capacity(0), offset(0) { + D_ASSERT(!written_pointers || written_pointers->empty()); +} + +MetadataWriter::~MetadataWriter() { + // If there's an exception during checkpoint, this can get destroyed without + // flushing the data...which is fine, because none of the unwritten data + // will be referenced. + // + // Otherwise, we should have explicitly flushed (and thereby nulled the block). + D_ASSERT(!block.handle.IsValid() || Exception::UncaughtException()); +} + +BlockPointer MetadataWriter::GetBlockPointer() { + return MetadataManager::ToBlockPointer(GetMetaBlockPointer()); +} + +MetaBlockPointer MetadataWriter::GetMetaBlockPointer() { + if (offset >= capacity) { + // at the end of the block - fetch the next block + NextBlock(); + D_ASSERT(capacity > 0); + } + return manager.GetDiskPointer(block.pointer, offset); +} + +MetadataHandle MetadataWriter::NextHandle() { + return manager.AllocateHandle(); +} + +void MetadataWriter::NextBlock() { + // now we need to get a new block id + auto new_handle = NextHandle(); + + // write the block id of the new block to the start of the current block + if (capacity > 0) { + auto disk_block = manager.GetDiskPointer(new_handle.pointer); + Store(disk_block.block_pointer, BasePtr()); + } + // now update the block id of the block + block = std::move(new_handle); + current_pointer = block.pointer; + offset = sizeof(idx_t); + capacity = MetadataManager::METADATA_BLOCK_SIZE; + Store(-1, BasePtr()); + if (written_pointers) { + written_pointers->push_back(manager.GetDiskPointer(current_pointer)); + } +} + +void MetadataWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) { + while (offset + write_size > capacity) { + // we need to make a new block + // first copy what we can + D_ASSERT(offset <= capacity); + idx_t copy_amount = capacity - offset; + if (copy_amount > 0) { + memcpy(Ptr(), buffer, copy_amount); + buffer += copy_amount; + offset += copy_amount; + write_size -= copy_amount; + } + // move forward to the next block + NextBlock(); + } + memcpy(Ptr(), buffer, write_size); + offset += write_size; +} + +void MetadataWriter::Flush() { + if (offset < capacity) { + // clear remaining bytes of block (if any) + memset(Ptr(), 0, capacity - offset); + } + block.handle.Destroy(); +} + +data_ptr_t MetadataWriter::BasePtr() { + return block.handle.Ptr() + current_pointer.index * MetadataManager::METADATA_BLOCK_SIZE; +} + +data_ptr_t MetadataWriter::Ptr() { + return BasePtr() + offset; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/optimistic_data_writer.cpp b/src/duckdb/src/storage/optimistic_data_writer.cpp new file mode 100644 index 00000000..df352698 --- /dev/null +++ b/src/duckdb/src/storage/optimistic_data_writer.cpp @@ -0,0 +1,96 @@ +#include "duckdb/storage/optimistic_data_writer.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/partial_block_manager.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" + +namespace duckdb { + +OptimisticDataWriter::OptimisticDataWriter(DataTable &table) : table(table) { +} + +OptimisticDataWriter::OptimisticDataWriter(DataTable &table, OptimisticDataWriter &parent) : table(table) { + if (parent.partial_manager) { + parent.partial_manager->ClearBlocks(); + } +} + +OptimisticDataWriter::~OptimisticDataWriter() { +} + +bool OptimisticDataWriter::PrepareWrite() { + // check if we should pre-emptively write the table to disk + if (table.info->IsTemporary() || StorageManager::Get(table.info->db).InMemory()) { + return false; + } + // we should! write the second-to-last row group to disk + // allocate the partial block-manager if none is allocated yet + if (!partial_manager) { + auto &block_manager = table.info->table_io_manager->GetBlockManagerForRowData(); + partial_manager = make_uniq(block_manager, CheckpointType::APPEND_TO_TABLE); + } + return true; +} + +void OptimisticDataWriter::WriteNewRowGroup(RowGroupCollection &row_groups) { + // we finished writing a complete row group + if (!PrepareWrite()) { + return; + } + // flush second-to-last row group + auto row_group = row_groups.GetRowGroup(-2); + FlushToDisk(row_group); +} + +void OptimisticDataWriter::WriteLastRowGroup(RowGroupCollection &row_groups) { + // we finished writing a complete row group + if (!PrepareWrite()) { + return; + } + // flush second-to-last row group + auto row_group = row_groups.GetRowGroup(-1); + if (!row_group) { + return; + } + FlushToDisk(row_group); +} + +void OptimisticDataWriter::FlushToDisk(RowGroup *row_group) { + if (!row_group) { + throw InternalException("FlushToDisk called without a RowGroup"); + } + //! The set of column compression types (if any) + vector compression_types; + D_ASSERT(compression_types.empty()); + for (auto &column : table.column_definitions) { + compression_types.push_back(column.CompressionType()); + } + row_group->WriteToDisk(*partial_manager, compression_types); +} + +void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { + if (!other.partial_manager) { + return; + } + if (!partial_manager) { + partial_manager = std::move(other.partial_manager); + return; + } + partial_manager->Merge(*other.partial_manager); + other.partial_manager.reset(); +} + +void OptimisticDataWriter::FinalFlush() { + if (partial_manager) { + partial_manager->FlushPartialBlocks(); + partial_manager.reset(); + } +} + +void OptimisticDataWriter::Rollback() { + if (partial_manager) { + partial_manager->Rollback(); + partial_manager.reset(); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/partial_block_manager.cpp b/src/duckdb/src/storage/partial_block_manager.cpp new file mode 100644 index 00000000..8128caf0 --- /dev/null +++ b/src/duckdb/src/storage/partial_block_manager.cpp @@ -0,0 +1,193 @@ +#include "duckdb/storage/partial_block_manager.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// PartialBlock +//===--------------------------------------------------------------------===// + +PartialBlock::PartialBlock(PartialBlockState state, BlockManager &block_manager, + const shared_ptr &block_handle) + : state(state), block_manager(block_manager), block_handle(block_handle) { +} + +void PartialBlock::AddUninitializedRegion(idx_t start, idx_t end) { + uninitialized_regions.push_back({start, end}); +} + +void PartialBlock::FlushInternal(const idx_t free_space_left) { + + // ensure that we do not leak any data + if (free_space_left > 0 || !uninitialized_regions.empty()) { + auto buffer_handle = block_manager.buffer_manager.Pin(block_handle); + + // memset any uninitialized regions + for (auto &uninitialized : uninitialized_regions) { + memset(buffer_handle.Ptr() + uninitialized.start, 0, uninitialized.end - uninitialized.start); + } + // memset any free space at the end of the block to 0 prior to writing to disk + memset(buffer_handle.Ptr() + Storage::BLOCK_SIZE - free_space_left, 0, free_space_left); + } +} + +//===--------------------------------------------------------------------===// +// PartialBlockManager +//===--------------------------------------------------------------------===// + +PartialBlockManager::PartialBlockManager(BlockManager &block_manager, CheckpointType checkpoint_type, + uint32_t max_partial_block_size, uint32_t max_use_count) + : block_manager(block_manager), checkpoint_type(checkpoint_type), max_partial_block_size(max_partial_block_size), + max_use_count(max_use_count) { +} +PartialBlockManager::~PartialBlockManager() { +} + +PartialBlockAllocation PartialBlockManager::GetBlockAllocation(uint32_t segment_size) { + PartialBlockAllocation allocation; + allocation.block_manager = &block_manager; + allocation.allocation_size = segment_size; + + // if the block is less than 80% full, we consider it a "partial block" + // which means we will try to fit it with other blocks + // check if there is a partial block available we can write to + if (segment_size <= max_partial_block_size && GetPartialBlock(segment_size, allocation.partial_block)) { + //! there is! increase the reference count of this block + allocation.partial_block->state.block_use_count += 1; + allocation.state = allocation.partial_block->state; + if (checkpoint_type == CheckpointType::FULL_CHECKPOINT) { + block_manager.IncreaseBlockReferenceCount(allocation.state.block_id); + } + } else { + // full block: get a free block to write to + AllocateBlock(allocation.state, segment_size); + } + return allocation; +} + +bool PartialBlockManager::HasBlockAllocation(uint32_t segment_size) { + return segment_size <= max_partial_block_size && + partially_filled_blocks.lower_bound(segment_size) != partially_filled_blocks.end(); +} + +void PartialBlockManager::AllocateBlock(PartialBlockState &state, uint32_t segment_size) { + D_ASSERT(segment_size <= Storage::BLOCK_SIZE); + if (checkpoint_type == CheckpointType::FULL_CHECKPOINT) { + state.block_id = block_manager.GetFreeBlockId(); + } else { + state.block_id = INVALID_BLOCK; + } + state.block_size = Storage::BLOCK_SIZE; + state.offset = 0; + state.block_use_count = 1; +} + +bool PartialBlockManager::GetPartialBlock(idx_t segment_size, unique_ptr &partial_block) { + auto entry = partially_filled_blocks.lower_bound(segment_size); + if (entry == partially_filled_blocks.end()) { + return false; + } + // found a partially filled block! fill in the info + partial_block = std::move(entry->second); + partially_filled_blocks.erase(entry); + + D_ASSERT(partial_block->state.offset > 0); + D_ASSERT(ValueIsAligned(partial_block->state.offset)); + return true; +} + +void PartialBlockManager::RegisterPartialBlock(PartialBlockAllocation &&allocation) { + auto &state = allocation.partial_block->state; + D_ASSERT(checkpoint_type != CheckpointType::FULL_CHECKPOINT || state.block_id >= 0); + if (state.block_use_count < max_use_count) { + auto unaligned_size = allocation.allocation_size + state.offset; + auto new_size = AlignValue(unaligned_size); + if (new_size != unaligned_size) { + // register the uninitialized region so we can correctly initialize it before writing to disk + allocation.partial_block->AddUninitializedRegion(unaligned_size, new_size); + } + state.offset = new_size; + auto new_space_left = state.block_size - new_size; + // check if the block is STILL partially filled after adding the segment_size + if (new_space_left >= Storage::BLOCK_SIZE - max_partial_block_size) { + // the block is still partially filled: add it to the partially_filled_blocks list + partially_filled_blocks.insert(make_pair(new_space_left, std::move(allocation.partial_block))); + } + } + idx_t free_space = state.block_size - state.offset; + auto block_to_free = std::move(allocation.partial_block); + if (!block_to_free && partially_filled_blocks.size() > MAX_BLOCK_MAP_SIZE) { + // Free the page with the least space free. + auto itr = partially_filled_blocks.begin(); + block_to_free = std::move(itr->second); + free_space = state.block_size - itr->first; + partially_filled_blocks.erase(itr); + } + // Flush any block that we're not going to reuse. + if (block_to_free) { + block_to_free->Flush(free_space); + AddWrittenBlock(block_to_free->state.block_id); + } +} + +void PartialBlockManager::Merge(PartialBlockManager &other) { + if (&other == this) { + throw InternalException("Cannot merge into itself"); + } + // for each partially filled block in the other manager, check if we can merge it into an existing block in this + // manager + for (auto &e : other.partially_filled_blocks) { + if (!e.second) { + throw InternalException("Empty partially filled block found"); + } + auto used_space = Storage::BLOCK_SIZE - e.first; + if (HasBlockAllocation(used_space)) { + // we can merge this block into an existing block - merge them + // merge blocks + auto allocation = GetBlockAllocation(used_space); + allocation.partial_block->Merge(*e.second, allocation.state.offset, used_space); + + // re-register the partial block + allocation.state.offset += used_space; + RegisterPartialBlock(std::move(allocation)); + } else { + // we cannot merge this block - append it directly to the current block manager + partially_filled_blocks.insert(make_pair(e.first, std::move(e.second))); + } + } + // copy over the written blocks + for (auto &block_id : other.written_blocks) { + AddWrittenBlock(block_id); + } + other.written_blocks.clear(); + other.partially_filled_blocks.clear(); +} + +void PartialBlockManager::AddWrittenBlock(block_id_t block) { + auto entry = written_blocks.insert(block); + if (!entry.second) { + throw InternalException("Written block already exists"); + } +} + +void PartialBlockManager::ClearBlocks() { + for (auto &e : partially_filled_blocks) { + e.second->Clear(); + } + partially_filled_blocks.clear(); +} + +void PartialBlockManager::FlushPartialBlocks() { + for (auto &e : partially_filled_blocks) { + e.second->Flush(e.first); + } + partially_filled_blocks.clear(); +} + +void PartialBlockManager::Rollback() { + ClearBlocks(); + for (auto &block_id : written_blocks) { + block_manager.MarkBlockAsFree(block_id); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_constraint.cpp b/src/duckdb/src/storage/serialization/serialize_constraint.cpp new file mode 100644 index 00000000..9993a59e --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_constraint.cpp @@ -0,0 +1,98 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/constraints/list.hpp" + +namespace duckdb { + +void Constraint::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); +} + +unique_ptr Constraint::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + unique_ptr result; + switch (type) { + case ConstraintType::CHECK: + result = CheckConstraint::Deserialize(deserializer); + break; + case ConstraintType::FOREIGN_KEY: + result = ForeignKeyConstraint::Deserialize(deserializer); + break; + case ConstraintType::NOT_NULL: + result = NotNullConstraint::Deserialize(deserializer); + break; + case ConstraintType::UNIQUE: + result = UniqueConstraint::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of Constraint!"); + } + return result; +} + +void CheckConstraint::Serialize(Serializer &serializer) const { + Constraint::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "expression", expression); +} + +unique_ptr CheckConstraint::Deserialize(Deserializer &deserializer) { + auto expression = deserializer.ReadPropertyWithDefault>(200, "expression"); + auto result = duckdb::unique_ptr(new CheckConstraint(std::move(expression))); + return std::move(result); +} + +void ForeignKeyConstraint::Serialize(Serializer &serializer) const { + Constraint::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "pk_columns", pk_columns); + serializer.WritePropertyWithDefault>(201, "fk_columns", fk_columns); + serializer.WriteProperty(202, "fk_type", info.type); + serializer.WritePropertyWithDefault(203, "schema", info.schema); + serializer.WritePropertyWithDefault(204, "table", info.table); + serializer.WritePropertyWithDefault>(205, "pk_keys", info.pk_keys); + serializer.WritePropertyWithDefault>(206, "fk_keys", info.fk_keys); +} + +unique_ptr ForeignKeyConstraint::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ForeignKeyConstraint()); + deserializer.ReadPropertyWithDefault>(200, "pk_columns", result->pk_columns); + deserializer.ReadPropertyWithDefault>(201, "fk_columns", result->fk_columns); + deserializer.ReadProperty(202, "fk_type", result->info.type); + deserializer.ReadPropertyWithDefault(203, "schema", result->info.schema); + deserializer.ReadPropertyWithDefault(204, "table", result->info.table); + deserializer.ReadPropertyWithDefault>(205, "pk_keys", result->info.pk_keys); + deserializer.ReadPropertyWithDefault>(206, "fk_keys", result->info.fk_keys); + return std::move(result); +} + +void NotNullConstraint::Serialize(Serializer &serializer) const { + Constraint::Serialize(serializer); + serializer.WriteProperty(200, "index", index); +} + +unique_ptr NotNullConstraint::Deserialize(Deserializer &deserializer) { + auto index = deserializer.ReadProperty(200, "index"); + auto result = duckdb::unique_ptr(new NotNullConstraint(index)); + return std::move(result); +} + +void UniqueConstraint::Serialize(Serializer &serializer) const { + Constraint::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "is_primary_key", is_primary_key); + serializer.WriteProperty(201, "index", index); + serializer.WritePropertyWithDefault>(202, "columns", columns); +} + +unique_ptr UniqueConstraint::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new UniqueConstraint()); + deserializer.ReadPropertyWithDefault(200, "is_primary_key", result->is_primary_key); + deserializer.ReadProperty(201, "index", result->index); + deserializer.ReadPropertyWithDefault>(202, "columns", result->columns); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_create_info.cpp b/src/duckdb/src/storage/serialization/serialize_create_info.cpp new file mode 100644 index 00000000..821c6cec --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_create_info.cpp @@ -0,0 +1,198 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" +#include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/parser/parsed_data/create_sequence_info.hpp" + +namespace duckdb { + +void CreateInfo::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WritePropertyWithDefault(101, "catalog", catalog); + serializer.WritePropertyWithDefault(102, "schema", schema); + serializer.WritePropertyWithDefault(103, "temporary", temporary); + serializer.WritePropertyWithDefault(104, "internal", internal); + serializer.WriteProperty(105, "on_conflict", on_conflict); + serializer.WritePropertyWithDefault(106, "sql", sql); +} + +unique_ptr CreateInfo::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto catalog = deserializer.ReadPropertyWithDefault(101, "catalog"); + auto schema = deserializer.ReadPropertyWithDefault(102, "schema"); + auto temporary = deserializer.ReadPropertyWithDefault(103, "temporary"); + auto internal = deserializer.ReadPropertyWithDefault(104, "internal"); + auto on_conflict = deserializer.ReadProperty(105, "on_conflict"); + auto sql = deserializer.ReadPropertyWithDefault(106, "sql"); + deserializer.Set(type); + unique_ptr result; + switch (type) { + case CatalogType::INDEX_ENTRY: + result = CreateIndexInfo::Deserialize(deserializer); + break; + case CatalogType::MACRO_ENTRY: + result = CreateMacroInfo::Deserialize(deserializer); + break; + case CatalogType::SCHEMA_ENTRY: + result = CreateSchemaInfo::Deserialize(deserializer); + break; + case CatalogType::SEQUENCE_ENTRY: + result = CreateSequenceInfo::Deserialize(deserializer); + break; + case CatalogType::TABLE_ENTRY: + result = CreateTableInfo::Deserialize(deserializer); + break; + case CatalogType::TABLE_MACRO_ENTRY: + result = CreateMacroInfo::Deserialize(deserializer); + break; + case CatalogType::TYPE_ENTRY: + result = CreateTypeInfo::Deserialize(deserializer); + break; + case CatalogType::VIEW_ENTRY: + result = CreateViewInfo::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of CreateInfo!"); + } + deserializer.Unset(); + result->catalog = std::move(catalog); + result->schema = std::move(schema); + result->temporary = temporary; + result->internal = internal; + result->on_conflict = on_conflict; + result->sql = std::move(sql); + return result; +} + +void CreateIndexInfo::Serialize(Serializer &serializer) const { + CreateInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", index_name); + serializer.WritePropertyWithDefault(201, "table", table); + serializer.WriteProperty(202, "index_type", index_type); + serializer.WriteProperty(203, "constraint_type", constraint_type); + serializer.WritePropertyWithDefault>>(204, "parsed_expressions", parsed_expressions); + serializer.WritePropertyWithDefault>(205, "scan_types", scan_types); + serializer.WritePropertyWithDefault>(206, "names", names); + serializer.WritePropertyWithDefault>(207, "column_ids", column_ids); + serializer.WritePropertyWithDefault>(208, "options", options); + serializer.WritePropertyWithDefault(209, "index_type_name", index_type_name); +} + +unique_ptr CreateIndexInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CreateIndexInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->index_name); + deserializer.ReadPropertyWithDefault(201, "table", result->table); + deserializer.ReadProperty(202, "index_type", result->index_type); + deserializer.ReadProperty(203, "constraint_type", result->constraint_type); + deserializer.ReadPropertyWithDefault>>(204, "parsed_expressions", result->parsed_expressions); + deserializer.ReadPropertyWithDefault>(205, "scan_types", result->scan_types); + deserializer.ReadPropertyWithDefault>(206, "names", result->names); + deserializer.ReadPropertyWithDefault>(207, "column_ids", result->column_ids); + deserializer.ReadPropertyWithDefault>(208, "options", result->options); + deserializer.ReadPropertyWithDefault(209, "index_type_name", result->index_type_name); + return std::move(result); +} + +void CreateMacroInfo::Serialize(Serializer &serializer) const { + CreateInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault>(201, "function", function); +} + +unique_ptr CreateMacroInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CreateMacroInfo(deserializer.Get())); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault>(201, "function", result->function); + return std::move(result); +} + +void CreateSchemaInfo::Serialize(Serializer &serializer) const { + CreateInfo::Serialize(serializer); +} + +unique_ptr CreateSchemaInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CreateSchemaInfo()); + return std::move(result); +} + +void CreateSequenceInfo::Serialize(Serializer &serializer) const { + CreateInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(201, "usage_count", usage_count); + serializer.WritePropertyWithDefault(202, "increment", increment); + serializer.WritePropertyWithDefault(203, "min_value", min_value); + serializer.WritePropertyWithDefault(204, "max_value", max_value); + serializer.WritePropertyWithDefault(205, "start_value", start_value); + serializer.WritePropertyWithDefault(206, "cycle", cycle); +} + +unique_ptr CreateSequenceInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CreateSequenceInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(201, "usage_count", result->usage_count); + deserializer.ReadPropertyWithDefault(202, "increment", result->increment); + deserializer.ReadPropertyWithDefault(203, "min_value", result->min_value); + deserializer.ReadPropertyWithDefault(204, "max_value", result->max_value); + deserializer.ReadPropertyWithDefault(205, "start_value", result->start_value); + deserializer.ReadPropertyWithDefault(206, "cycle", result->cycle); + return std::move(result); +} + +void CreateTableInfo::Serialize(Serializer &serializer) const { + CreateInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table", table); + serializer.WriteProperty(201, "columns", columns); + serializer.WritePropertyWithDefault>>(202, "constraints", constraints); + serializer.WritePropertyWithDefault>(203, "query", query); +} + +unique_ptr CreateTableInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CreateTableInfo()); + deserializer.ReadPropertyWithDefault(200, "table", result->table); + deserializer.ReadProperty(201, "columns", result->columns); + deserializer.ReadPropertyWithDefault>>(202, "constraints", result->constraints); + deserializer.ReadPropertyWithDefault>(203, "query", result->query); + return std::move(result); +} + +void CreateTypeInfo::Serialize(Serializer &serializer) const { + CreateInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WriteProperty(201, "logical_type", type); +} + +unique_ptr CreateTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CreateTypeInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadProperty(201, "logical_type", result->type); + return std::move(result); +} + +void CreateViewInfo::Serialize(Serializer &serializer) const { + CreateInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "view_name", view_name); + serializer.WritePropertyWithDefault>(201, "aliases", aliases); + serializer.WritePropertyWithDefault>(202, "types", types); + serializer.WritePropertyWithDefault>(203, "query", query); +} + +unique_ptr CreateViewInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CreateViewInfo()); + deserializer.ReadPropertyWithDefault(200, "view_name", result->view_name); + deserializer.ReadPropertyWithDefault>(201, "aliases", result->aliases); + deserializer.ReadPropertyWithDefault>(202, "types", result->types); + deserializer.ReadPropertyWithDefault>(203, "query", result->query); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_expression.cpp b/src/duckdb/src/storage/serialization/serialize_expression.cpp new file mode 100644 index 00000000..bc0b6e19 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_expression.cpp @@ -0,0 +1,283 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/expression/list.hpp" + +namespace duckdb { + +void Expression::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "expression_class", expression_class); + serializer.WriteProperty(101, "type", type); + serializer.WritePropertyWithDefault(102, "alias", alias); +} + +unique_ptr Expression::Deserialize(Deserializer &deserializer) { + auto expression_class = deserializer.ReadProperty(100, "expression_class"); + auto type = deserializer.ReadProperty(101, "type"); + auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); + deserializer.Set(type); + unique_ptr result; + switch (expression_class) { + case ExpressionClass::BOUND_AGGREGATE: + result = BoundAggregateExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_BETWEEN: + result = BoundBetweenExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_CASE: + result = BoundCaseExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_CAST: + result = BoundCastExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_COLUMN_REF: + result = BoundColumnRefExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_COMPARISON: + result = BoundComparisonExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_CONJUNCTION: + result = BoundConjunctionExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_CONSTANT: + result = BoundConstantExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_DEFAULT: + result = BoundDefaultExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_FUNCTION: + result = BoundFunctionExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_LAMBDA: + result = BoundLambdaExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_LAMBDA_REF: + result = BoundLambdaRefExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_OPERATOR: + result = BoundOperatorExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_PARAMETER: + result = BoundParameterExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_REF: + result = BoundReferenceExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_UNNEST: + result = BoundUnnestExpression::Deserialize(deserializer); + break; + case ExpressionClass::BOUND_WINDOW: + result = BoundWindowExpression::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of Expression!"); + } + deserializer.Unset(); + result->alias = std::move(alias); + return result; +} + +void BoundBetweenExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "input", input); + serializer.WritePropertyWithDefault>(201, "lower", lower); + serializer.WritePropertyWithDefault>(202, "upper", upper); + serializer.WritePropertyWithDefault(203, "lower_inclusive", lower_inclusive); + serializer.WritePropertyWithDefault(204, "upper_inclusive", upper_inclusive); +} + +unique_ptr BoundBetweenExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new BoundBetweenExpression()); + deserializer.ReadPropertyWithDefault>(200, "input", result->input); + deserializer.ReadPropertyWithDefault>(201, "lower", result->lower); + deserializer.ReadPropertyWithDefault>(202, "upper", result->upper); + deserializer.ReadPropertyWithDefault(203, "lower_inclusive", result->lower_inclusive); + deserializer.ReadPropertyWithDefault(204, "upper_inclusive", result->upper_inclusive); + return std::move(result); +} + +void BoundCaseExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WritePropertyWithDefault>(201, "case_checks", case_checks); + serializer.WritePropertyWithDefault>(202, "else_expr", else_expr); +} + +unique_ptr BoundCaseExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto result = duckdb::unique_ptr(new BoundCaseExpression(std::move(return_type))); + deserializer.ReadPropertyWithDefault>(201, "case_checks", result->case_checks); + deserializer.ReadPropertyWithDefault>(202, "else_expr", result->else_expr); + return std::move(result); +} + +void BoundCastExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "child", child); + serializer.WriteProperty(201, "return_type", return_type); + serializer.WritePropertyWithDefault(202, "try_cast", try_cast); +} + +unique_ptr BoundCastExpression::Deserialize(Deserializer &deserializer) { + auto child = deserializer.ReadPropertyWithDefault>(200, "child"); + auto return_type = deserializer.ReadProperty(201, "return_type"); + auto result = duckdb::unique_ptr(new BoundCastExpression(deserializer.Get(), std::move(child), std::move(return_type))); + deserializer.ReadPropertyWithDefault(202, "try_cast", result->try_cast); + return std::move(result); +} + +void BoundColumnRefExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WriteProperty(201, "binding", binding); + serializer.WritePropertyWithDefault(202, "depth", depth); +} + +unique_ptr BoundColumnRefExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto binding = deserializer.ReadProperty(201, "binding"); + auto depth = deserializer.ReadPropertyWithDefault(202, "depth"); + auto result = duckdb::unique_ptr(new BoundColumnRefExpression(std::move(return_type), binding, depth)); + return std::move(result); +} + +void BoundComparisonExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "left", left); + serializer.WritePropertyWithDefault>(201, "right", right); +} + +unique_ptr BoundComparisonExpression::Deserialize(Deserializer &deserializer) { + auto left = deserializer.ReadPropertyWithDefault>(200, "left"); + auto right = deserializer.ReadPropertyWithDefault>(201, "right"); + auto result = duckdb::unique_ptr(new BoundComparisonExpression(deserializer.Get(), std::move(left), std::move(right))); + return std::move(result); +} + +void BoundConjunctionExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "children", children); +} + +unique_ptr BoundConjunctionExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new BoundConjunctionExpression(deserializer.Get())); + deserializer.ReadPropertyWithDefault>>(200, "children", result->children); + return std::move(result); +} + +void BoundConstantExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "value", value); +} + +unique_ptr BoundConstantExpression::Deserialize(Deserializer &deserializer) { + auto value = deserializer.ReadProperty(200, "value"); + auto result = duckdb::unique_ptr(new BoundConstantExpression(value)); + return std::move(result); +} + +void BoundDefaultExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); +} + +unique_ptr BoundDefaultExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto result = duckdb::unique_ptr(new BoundDefaultExpression(std::move(return_type))); + return std::move(result); +} + +void BoundLambdaExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WritePropertyWithDefault>(201, "lambda_expr", lambda_expr); + serializer.WritePropertyWithDefault>>(202, "captures", captures); + serializer.WritePropertyWithDefault(203, "parameter_count", parameter_count); +} + +unique_ptr BoundLambdaExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto lambda_expr = deserializer.ReadPropertyWithDefault>(201, "lambda_expr"); + auto captures = deserializer.ReadPropertyWithDefault>>(202, "captures"); + auto parameter_count = deserializer.ReadPropertyWithDefault(203, "parameter_count"); + auto result = duckdb::unique_ptr(new BoundLambdaExpression(deserializer.Get(), std::move(return_type), std::move(lambda_expr), parameter_count)); + result->captures = std::move(captures); + return std::move(result); +} + +void BoundLambdaRefExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WriteProperty(201, "binding", binding); + serializer.WritePropertyWithDefault(202, "lambda_index", lambda_index); + serializer.WritePropertyWithDefault(203, "depth", depth); +} + +unique_ptr BoundLambdaRefExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto binding = deserializer.ReadProperty(201, "binding"); + auto lambda_index = deserializer.ReadPropertyWithDefault(202, "lambda_index"); + auto depth = deserializer.ReadPropertyWithDefault(203, "depth"); + auto result = duckdb::unique_ptr(new BoundLambdaRefExpression(std::move(return_type), binding, lambda_index, depth)); + return std::move(result); +} + +void BoundOperatorExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WritePropertyWithDefault>>(201, "children", children); +} + +unique_ptr BoundOperatorExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto result = duckdb::unique_ptr(new BoundOperatorExpression(deserializer.Get(), std::move(return_type))); + deserializer.ReadPropertyWithDefault>>(201, "children", result->children); + return std::move(result); +} + +void BoundParameterExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "identifier", identifier); + serializer.WriteProperty(201, "return_type", return_type); + serializer.WritePropertyWithDefault>(202, "parameter_data", parameter_data); +} + +unique_ptr BoundParameterExpression::Deserialize(Deserializer &deserializer) { + auto identifier = deserializer.ReadPropertyWithDefault(200, "identifier"); + auto return_type = deserializer.ReadProperty(201, "return_type"); + auto parameter_data = deserializer.ReadPropertyWithDefault>(202, "parameter_data"); + auto result = duckdb::unique_ptr(new BoundParameterExpression(deserializer.Get(), std::move(identifier), std::move(return_type), std::move(parameter_data))); + return std::move(result); +} + +void BoundReferenceExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WritePropertyWithDefault(201, "index", index); +} + +unique_ptr BoundReferenceExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto index = deserializer.ReadPropertyWithDefault(201, "index"); + auto result = duckdb::unique_ptr(new BoundReferenceExpression(std::move(return_type), index)); + return std::move(result); +} + +void BoundUnnestExpression::Serialize(Serializer &serializer) const { + Expression::Serialize(serializer); + serializer.WriteProperty(200, "return_type", return_type); + serializer.WritePropertyWithDefault>(201, "child", child); +} + +unique_ptr BoundUnnestExpression::Deserialize(Deserializer &deserializer) { + auto return_type = deserializer.ReadProperty(200, "return_type"); + auto result = duckdb::unique_ptr(new BoundUnnestExpression(std::move(return_type))); + deserializer.ReadPropertyWithDefault>(201, "child", result->child); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp new file mode 100644 index 00000000..f64dcab5 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp @@ -0,0 +1,748 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" + +namespace duckdb { + +void LogicalOperator::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WritePropertyWithDefault>>(101, "children", children); +} + +unique_ptr LogicalOperator::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto children = deserializer.ReadPropertyWithDefault>>(101, "children"); + deserializer.Set(type); + unique_ptr result; + switch (type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + result = LogicalAggregate::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_ALTER: + result = LogicalSimple::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_ANY_JOIN: + result = LogicalAnyJoin::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + result = LogicalComparisonJoin::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_ATTACH: + result = LogicalSimple::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CHUNK_GET: + result = LogicalColumnDataGet::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + result = LogicalComparisonJoin::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + result = LogicalCopyToFile::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_INDEX: + result = LogicalCreateIndex::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_MACRO: + result = LogicalCreate::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: + result = LogicalCreate::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: + result = LogicalCreate::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_TABLE: + result = LogicalCreateTable::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_TYPE: + result = LogicalCreate::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_VIEW: + result = LogicalCreate::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + result = LogicalCrossProduct::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CTE_REF: + result = LogicalCTERef::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_DELETE: + result = LogicalDelete::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_DELIM_GET: + result = LogicalDelimGet::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + result = LogicalComparisonJoin::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_DETACH: + result = LogicalSimple::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_DISTINCT: + result = LogicalDistinct::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_DROP: + result = LogicalSimple::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + result = LogicalDummyScan::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_EMPTY_RESULT: + result = LogicalEmptyResult::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_EXCEPT: + result = LogicalSetOperation::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_EXPLAIN: + result = LogicalExplain::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + result = LogicalExpressionGet::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: + result = LogicalExtensionOperator::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_FILTER: + result = LogicalFilter::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_GET: + result = LogicalGet::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_INSERT: + result = LogicalInsert::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_INTERSECT: + result = LogicalSetOperation::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_LIMIT: + result = LogicalLimit::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: + result = LogicalLimitPercent::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_LOAD: + result = LogicalSimple::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: + result = LogicalMaterializedCTE::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_ORDER_BY: + result = LogicalOrder::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_PIVOT: + result = LogicalPivot::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + result = LogicalPositionalJoin::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_PROJECTION: + result = LogicalProjection::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: + result = LogicalRecursiveCTE::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_RESET: + result = LogicalReset::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_SAMPLE: + result = LogicalSample::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_SET: + result = LogicalSet::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_SHOW: + result = LogicalShow::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_TOP_N: + result = LogicalTopN::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_TRANSACTION: + result = LogicalSimple::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_UNION: + result = LogicalSetOperation::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_UNNEST: + result = LogicalUnnest::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_UPDATE: + result = LogicalUpdate::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_VACUUM: + result = LogicalSimple::Deserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_WINDOW: + result = LogicalWindow::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of LogicalOperator!"); + } + deserializer.Unset(); + result->children = std::move(children); + return result; +} + +void LogicalAggregate::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "expressions", expressions); + serializer.WritePropertyWithDefault(201, "group_index", group_index); + serializer.WritePropertyWithDefault(202, "aggregate_index", aggregate_index); + serializer.WritePropertyWithDefault(203, "groupings_index", groupings_index); + serializer.WritePropertyWithDefault>>(204, "groups", groups); + serializer.WritePropertyWithDefault>(205, "grouping_sets", grouping_sets); + serializer.WritePropertyWithDefault>>(206, "grouping_functions", grouping_functions); +} + +unique_ptr LogicalAggregate::Deserialize(Deserializer &deserializer) { + auto expressions = deserializer.ReadPropertyWithDefault>>(200, "expressions"); + auto group_index = deserializer.ReadPropertyWithDefault(201, "group_index"); + auto aggregate_index = deserializer.ReadPropertyWithDefault(202, "aggregate_index"); + auto result = duckdb::unique_ptr(new LogicalAggregate(group_index, aggregate_index, std::move(expressions))); + deserializer.ReadPropertyWithDefault(203, "groupings_index", result->groupings_index); + deserializer.ReadPropertyWithDefault>>(204, "groups", result->groups); + deserializer.ReadPropertyWithDefault>(205, "grouping_sets", result->grouping_sets); + deserializer.ReadPropertyWithDefault>>(206, "grouping_functions", result->grouping_functions); + return std::move(result); +} + +void LogicalAnyJoin::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WriteProperty(200, "join_type", join_type); + serializer.WritePropertyWithDefault(201, "mark_index", mark_index); + serializer.WritePropertyWithDefault>(202, "left_projection_map", left_projection_map); + serializer.WritePropertyWithDefault>(203, "right_projection_map", right_projection_map); + serializer.WritePropertyWithDefault>(204, "condition", condition); +} + +unique_ptr LogicalAnyJoin::Deserialize(Deserializer &deserializer) { + auto join_type = deserializer.ReadProperty(200, "join_type"); + auto result = duckdb::unique_ptr(new LogicalAnyJoin(join_type)); + deserializer.ReadPropertyWithDefault(201, "mark_index", result->mark_index); + deserializer.ReadPropertyWithDefault>(202, "left_projection_map", result->left_projection_map); + deserializer.ReadPropertyWithDefault>(203, "right_projection_map", result->right_projection_map); + deserializer.ReadPropertyWithDefault>(204, "condition", result->condition); + return std::move(result); +} + +void LogicalCTERef::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); + serializer.WritePropertyWithDefault(201, "cte_index", cte_index); + serializer.WritePropertyWithDefault>(202, "chunk_types", chunk_types); + serializer.WritePropertyWithDefault>(203, "bound_columns", bound_columns); + serializer.WriteProperty(204, "materialized_cte", materialized_cte); +} + +unique_ptr LogicalCTERef::Deserialize(Deserializer &deserializer) { + auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); + auto cte_index = deserializer.ReadPropertyWithDefault(201, "cte_index"); + auto chunk_types = deserializer.ReadPropertyWithDefault>(202, "chunk_types"); + auto bound_columns = deserializer.ReadPropertyWithDefault>(203, "bound_columns"); + auto materialized_cte = deserializer.ReadProperty(204, "materialized_cte"); + auto result = duckdb::unique_ptr(new LogicalCTERef(table_index, cte_index, std::move(chunk_types), std::move(bound_columns), materialized_cte)); + return std::move(result); +} + +void LogicalColumnDataGet::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); + serializer.WritePropertyWithDefault>(201, "chunk_types", chunk_types); + serializer.WritePropertyWithDefault>(202, "collection", collection); +} + +unique_ptr LogicalColumnDataGet::Deserialize(Deserializer &deserializer) { + auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); + auto chunk_types = deserializer.ReadPropertyWithDefault>(201, "chunk_types"); + auto collection = deserializer.ReadPropertyWithDefault>(202, "collection"); + auto result = duckdb::unique_ptr(new LogicalColumnDataGet(table_index, std::move(chunk_types), std::move(collection))); + return std::move(result); +} + +void LogicalComparisonJoin::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WriteProperty(200, "join_type", join_type); + serializer.WritePropertyWithDefault(201, "mark_index", mark_index); + serializer.WritePropertyWithDefault>(202, "left_projection_map", left_projection_map); + serializer.WritePropertyWithDefault>(203, "right_projection_map", right_projection_map); + serializer.WritePropertyWithDefault>(204, "conditions", conditions); + serializer.WritePropertyWithDefault>(205, "mark_types", mark_types); + serializer.WritePropertyWithDefault>>(206, "duplicate_eliminated_columns", duplicate_eliminated_columns); +} + +unique_ptr LogicalComparisonJoin::Deserialize(Deserializer &deserializer) { + auto join_type = deserializer.ReadProperty(200, "join_type"); + auto result = duckdb::unique_ptr(new LogicalComparisonJoin(join_type, deserializer.Get())); + deserializer.ReadPropertyWithDefault(201, "mark_index", result->mark_index); + deserializer.ReadPropertyWithDefault>(202, "left_projection_map", result->left_projection_map); + deserializer.ReadPropertyWithDefault>(203, "right_projection_map", result->right_projection_map); + deserializer.ReadPropertyWithDefault>(204, "conditions", result->conditions); + deserializer.ReadPropertyWithDefault>(205, "mark_types", result->mark_types); + deserializer.ReadPropertyWithDefault>>(206, "duplicate_eliminated_columns", result->duplicate_eliminated_columns); + return std::move(result); +} + +void LogicalCreate::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "info", info); +} + +unique_ptr LogicalCreate::Deserialize(Deserializer &deserializer) { + auto info = deserializer.ReadPropertyWithDefault>(200, "info"); + auto result = duckdb::unique_ptr(new LogicalCreate(deserializer.Get(), deserializer.Get(), std::move(info))); + return std::move(result); +} + +void LogicalCreateIndex::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "info", info); + serializer.WritePropertyWithDefault>>(201, "unbound_expressions", unbound_expressions); +} + +unique_ptr LogicalCreateIndex::Deserialize(Deserializer &deserializer) { + auto info = deserializer.ReadPropertyWithDefault>(200, "info"); + auto unbound_expressions = deserializer.ReadPropertyWithDefault>>(201, "unbound_expressions"); + auto result = duckdb::unique_ptr(new LogicalCreateIndex(deserializer.Get(), std::move(info), std::move(unbound_expressions))); + return std::move(result); +} + +void LogicalCreateTable::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "info", info->base); +} + +unique_ptr LogicalCreateTable::Deserialize(Deserializer &deserializer) { + auto info = deserializer.ReadPropertyWithDefault>(200, "info"); + auto result = duckdb::unique_ptr(new LogicalCreateTable(deserializer.Get(), std::move(info))); + return std::move(result); +} + +void LogicalCrossProduct::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); +} + +unique_ptr LogicalCrossProduct::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalCrossProduct()); + return std::move(result); +} + +void LogicalDelete::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "table_info", table.GetInfo()); + serializer.WritePropertyWithDefault(201, "table_index", table_index); + serializer.WritePropertyWithDefault(202, "return_chunk", return_chunk); + serializer.WritePropertyWithDefault>>(203, "expressions", expressions); +} + +unique_ptr LogicalDelete::Deserialize(Deserializer &deserializer) { + auto table_info = deserializer.ReadPropertyWithDefault>(200, "table_info"); + auto result = duckdb::unique_ptr(new LogicalDelete(deserializer.Get(), table_info)); + deserializer.ReadPropertyWithDefault(201, "table_index", result->table_index); + deserializer.ReadPropertyWithDefault(202, "return_chunk", result->return_chunk); + deserializer.ReadPropertyWithDefault>>(203, "expressions", result->expressions); + return std::move(result); +} + +void LogicalDelimGet::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); + serializer.WritePropertyWithDefault>(201, "chunk_types", chunk_types); +} + +unique_ptr LogicalDelimGet::Deserialize(Deserializer &deserializer) { + auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); + auto chunk_types = deserializer.ReadPropertyWithDefault>(201, "chunk_types"); + auto result = duckdb::unique_ptr(new LogicalDelimGet(table_index, std::move(chunk_types))); + return std::move(result); +} + +void LogicalDistinct::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WriteProperty(200, "distinct_type", distinct_type); + serializer.WritePropertyWithDefault>>(201, "distinct_targets", distinct_targets); + serializer.WritePropertyWithDefault>(202, "order_by", order_by); +} + +unique_ptr LogicalDistinct::Deserialize(Deserializer &deserializer) { + auto distinct_type = deserializer.ReadProperty(200, "distinct_type"); + auto distinct_targets = deserializer.ReadPropertyWithDefault>>(201, "distinct_targets"); + auto result = duckdb::unique_ptr(new LogicalDistinct(std::move(distinct_targets), distinct_type)); + deserializer.ReadPropertyWithDefault>(202, "order_by", result->order_by); + return std::move(result); +} + +void LogicalDummyScan::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); +} + +unique_ptr LogicalDummyScan::Deserialize(Deserializer &deserializer) { + auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); + auto result = duckdb::unique_ptr(new LogicalDummyScan(table_index)); + return std::move(result); +} + +void LogicalEmptyResult::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "return_types", return_types); + serializer.WritePropertyWithDefault>(201, "bindings", bindings); +} + +unique_ptr LogicalEmptyResult::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalEmptyResult()); + deserializer.ReadPropertyWithDefault>(200, "return_types", result->return_types); + deserializer.ReadPropertyWithDefault>(201, "bindings", result->bindings); + return std::move(result); +} + +void LogicalExplain::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WriteProperty(200, "explain_type", explain_type); + serializer.WritePropertyWithDefault(201, "physical_plan", physical_plan); + serializer.WritePropertyWithDefault(202, "logical_plan_unopt", logical_plan_unopt); + serializer.WritePropertyWithDefault(203, "logical_plan_opt", logical_plan_opt); +} + +unique_ptr LogicalExplain::Deserialize(Deserializer &deserializer) { + auto explain_type = deserializer.ReadProperty(200, "explain_type"); + auto result = duckdb::unique_ptr(new LogicalExplain(explain_type)); + deserializer.ReadPropertyWithDefault(201, "physical_plan", result->physical_plan); + deserializer.ReadPropertyWithDefault(202, "logical_plan_unopt", result->logical_plan_unopt); + deserializer.ReadPropertyWithDefault(203, "logical_plan_opt", result->logical_plan_opt); + return std::move(result); +} + +void LogicalExpressionGet::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); + serializer.WritePropertyWithDefault>(201, "expr_types", expr_types); + serializer.WritePropertyWithDefault>>>(202, "expressions", expressions); +} + +unique_ptr LogicalExpressionGet::Deserialize(Deserializer &deserializer) { + auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); + auto expr_types = deserializer.ReadPropertyWithDefault>(201, "expr_types"); + auto expressions = deserializer.ReadPropertyWithDefault>>>(202, "expressions"); + auto result = duckdb::unique_ptr(new LogicalExpressionGet(table_index, std::move(expr_types), std::move(expressions))); + return std::move(result); +} + +void LogicalFilter::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "expressions", expressions); + serializer.WritePropertyWithDefault>(201, "projection_map", projection_map); +} + +unique_ptr LogicalFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalFilter()); + deserializer.ReadPropertyWithDefault>>(200, "expressions", result->expressions); + deserializer.ReadPropertyWithDefault>(201, "projection_map", result->projection_map); + return std::move(result); +} + +void LogicalInsert::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "table_info", table.GetInfo()); + serializer.WritePropertyWithDefault>>>(201, "insert_values", insert_values); + serializer.WriteProperty>(202, "column_index_map", column_index_map); + serializer.WritePropertyWithDefault>(203, "expected_types", expected_types); + serializer.WritePropertyWithDefault(204, "table_index", table_index); + serializer.WritePropertyWithDefault(205, "return_chunk", return_chunk); + serializer.WritePropertyWithDefault>>(206, "bound_defaults", bound_defaults); + serializer.WriteProperty(207, "action_type", action_type); + serializer.WritePropertyWithDefault>(208, "expected_set_types", expected_set_types); + serializer.WritePropertyWithDefault>(209, "on_conflict_filter", on_conflict_filter); + serializer.WritePropertyWithDefault>(210, "on_conflict_condition", on_conflict_condition); + serializer.WritePropertyWithDefault>(211, "do_update_condition", do_update_condition); + serializer.WritePropertyWithDefault>(212, "set_columns", set_columns); + serializer.WritePropertyWithDefault>(213, "set_types", set_types); + serializer.WritePropertyWithDefault(214, "excluded_table_index", excluded_table_index); + serializer.WritePropertyWithDefault>(215, "columns_to_fetch", columns_to_fetch); + serializer.WritePropertyWithDefault>(216, "source_columns", source_columns); + serializer.WritePropertyWithDefault>>(217, "expressions", expressions); +} + +unique_ptr LogicalInsert::Deserialize(Deserializer &deserializer) { + auto table_info = deserializer.ReadPropertyWithDefault>(200, "table_info"); + auto result = duckdb::unique_ptr(new LogicalInsert(deserializer.Get(), std::move(table_info))); + deserializer.ReadPropertyWithDefault>>>(201, "insert_values", result->insert_values); + deserializer.ReadProperty>(202, "column_index_map", result->column_index_map); + deserializer.ReadPropertyWithDefault>(203, "expected_types", result->expected_types); + deserializer.ReadPropertyWithDefault(204, "table_index", result->table_index); + deserializer.ReadPropertyWithDefault(205, "return_chunk", result->return_chunk); + deserializer.ReadPropertyWithDefault>>(206, "bound_defaults", result->bound_defaults); + deserializer.ReadProperty(207, "action_type", result->action_type); + deserializer.ReadPropertyWithDefault>(208, "expected_set_types", result->expected_set_types); + deserializer.ReadPropertyWithDefault>(209, "on_conflict_filter", result->on_conflict_filter); + deserializer.ReadPropertyWithDefault>(210, "on_conflict_condition", result->on_conflict_condition); + deserializer.ReadPropertyWithDefault>(211, "do_update_condition", result->do_update_condition); + deserializer.ReadPropertyWithDefault>(212, "set_columns", result->set_columns); + deserializer.ReadPropertyWithDefault>(213, "set_types", result->set_types); + deserializer.ReadPropertyWithDefault(214, "excluded_table_index", result->excluded_table_index); + deserializer.ReadPropertyWithDefault>(215, "columns_to_fetch", result->columns_to_fetch); + deserializer.ReadPropertyWithDefault>(216, "source_columns", result->source_columns); + deserializer.ReadPropertyWithDefault>>(217, "expressions", result->expressions); + return std::move(result); +} + +void LogicalLimit::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "limit_val", limit_val); + serializer.WritePropertyWithDefault(201, "offset_val", offset_val); + serializer.WritePropertyWithDefault>(202, "limit", limit); + serializer.WritePropertyWithDefault>(203, "offset", offset); +} + +unique_ptr LogicalLimit::Deserialize(Deserializer &deserializer) { + auto limit_val = deserializer.ReadPropertyWithDefault(200, "limit_val"); + auto offset_val = deserializer.ReadPropertyWithDefault(201, "offset_val"); + auto limit = deserializer.ReadPropertyWithDefault>(202, "limit"); + auto offset = deserializer.ReadPropertyWithDefault>(203, "offset"); + auto result = duckdb::unique_ptr(new LogicalLimit(limit_val, offset_val, std::move(limit), std::move(offset))); + return std::move(result); +} + +void LogicalLimitPercent::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WriteProperty(200, "limit_percent", limit_percent); + serializer.WritePropertyWithDefault(201, "offset_val", offset_val); + serializer.WritePropertyWithDefault>(202, "limit", limit); + serializer.WritePropertyWithDefault>(203, "offset", offset); +} + +unique_ptr LogicalLimitPercent::Deserialize(Deserializer &deserializer) { + auto limit_percent = deserializer.ReadProperty(200, "limit_percent"); + auto offset_val = deserializer.ReadPropertyWithDefault(201, "offset_val"); + auto limit = deserializer.ReadPropertyWithDefault>(202, "limit"); + auto offset = deserializer.ReadPropertyWithDefault>(203, "offset"); + auto result = duckdb::unique_ptr(new LogicalLimitPercent(limit_percent, offset_val, std::move(limit), std::move(offset))); + return std::move(result); +} + +void LogicalMaterializedCTE::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); + serializer.WritePropertyWithDefault(201, "column_count", column_count); + serializer.WritePropertyWithDefault(202, "ctename", ctename); +} + +unique_ptr LogicalMaterializedCTE::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalMaterializedCTE()); + deserializer.ReadPropertyWithDefault(200, "table_index", result->table_index); + deserializer.ReadPropertyWithDefault(201, "column_count", result->column_count); + deserializer.ReadPropertyWithDefault(202, "ctename", result->ctename); + return std::move(result); +} + +void LogicalOrder::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "orders", orders); + serializer.WritePropertyWithDefault>(201, "projections", projections); +} + +unique_ptr LogicalOrder::Deserialize(Deserializer &deserializer) { + auto orders = deserializer.ReadPropertyWithDefault>(200, "orders"); + auto result = duckdb::unique_ptr(new LogicalOrder(std::move(orders))); + deserializer.ReadPropertyWithDefault>(201, "projections", result->projections); + return std::move(result); +} + +void LogicalPivot::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "pivot_index", pivot_index); + serializer.WriteProperty(201, "bound_pivot", bound_pivot); +} + +unique_ptr LogicalPivot::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalPivot()); + deserializer.ReadPropertyWithDefault(200, "pivot_index", result->pivot_index); + deserializer.ReadProperty(201, "bound_pivot", result->bound_pivot); + return std::move(result); +} + +void LogicalPositionalJoin::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); +} + +unique_ptr LogicalPositionalJoin::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalPositionalJoin()); + return std::move(result); +} + +void LogicalProjection::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); + serializer.WritePropertyWithDefault>>(201, "expressions", expressions); +} + +unique_ptr LogicalProjection::Deserialize(Deserializer &deserializer) { + auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); + auto expressions = deserializer.ReadPropertyWithDefault>>(201, "expressions"); + auto result = duckdb::unique_ptr(new LogicalProjection(table_index, std::move(expressions))); + return std::move(result); +} + +void LogicalRecursiveCTE::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "union_all", union_all); + serializer.WritePropertyWithDefault(201, "ctename", ctename); + serializer.WritePropertyWithDefault(202, "table_index", table_index); + serializer.WritePropertyWithDefault(203, "column_count", column_count); +} + +unique_ptr LogicalRecursiveCTE::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalRecursiveCTE()); + deserializer.ReadPropertyWithDefault(200, "union_all", result->union_all); + deserializer.ReadPropertyWithDefault(201, "ctename", result->ctename); + deserializer.ReadPropertyWithDefault(202, "table_index", result->table_index); + deserializer.ReadPropertyWithDefault(203, "column_count", result->column_count); + return std::move(result); +} + +void LogicalReset::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WriteProperty(201, "scope", scope); +} + +unique_ptr LogicalReset::Deserialize(Deserializer &deserializer) { + auto name = deserializer.ReadPropertyWithDefault(200, "name"); + auto scope = deserializer.ReadProperty(201, "scope"); + auto result = duckdb::unique_ptr(new LogicalReset(std::move(name), scope)); + return std::move(result); +} + +void LogicalSample::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "sample_options", sample_options); +} + +unique_ptr LogicalSample::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalSample()); + deserializer.ReadPropertyWithDefault>(200, "sample_options", result->sample_options); + return std::move(result); +} + +void LogicalSet::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WriteProperty(201, "value", value); + serializer.WriteProperty(202, "scope", scope); +} + +unique_ptr LogicalSet::Deserialize(Deserializer &deserializer) { + auto name = deserializer.ReadPropertyWithDefault(200, "name"); + auto value = deserializer.ReadProperty(201, "value"); + auto scope = deserializer.ReadProperty(202, "scope"); + auto result = duckdb::unique_ptr(new LogicalSet(std::move(name), value, scope)); + return std::move(result); +} + +void LogicalSetOperation::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "table_index", table_index); + serializer.WritePropertyWithDefault(201, "column_count", column_count); +} + +unique_ptr LogicalSetOperation::Deserialize(Deserializer &deserializer) { + auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); + auto column_count = deserializer.ReadPropertyWithDefault(201, "column_count"); + auto result = duckdb::unique_ptr(new LogicalSetOperation(table_index, column_count, deserializer.Get())); + return std::move(result); +} + +void LogicalShow::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "types_select", types_select); + serializer.WritePropertyWithDefault>(201, "aliases", aliases); +} + +unique_ptr LogicalShow::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LogicalShow()); + deserializer.ReadPropertyWithDefault>(200, "types_select", result->types_select); + deserializer.ReadPropertyWithDefault>(201, "aliases", result->aliases); + return std::move(result); +} + +void LogicalSimple::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "info", info); +} + +unique_ptr LogicalSimple::Deserialize(Deserializer &deserializer) { + auto info = deserializer.ReadPropertyWithDefault>(200, "info"); + auto result = duckdb::unique_ptr(new LogicalSimple(deserializer.Get(), std::move(info))); + return std::move(result); +} + +void LogicalTopN::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "orders", orders); + serializer.WritePropertyWithDefault(201, "limit", limit); + serializer.WritePropertyWithDefault(202, "offset", offset); +} + +unique_ptr LogicalTopN::Deserialize(Deserializer &deserializer) { + auto orders = deserializer.ReadPropertyWithDefault>(200, "orders"); + auto limit = deserializer.ReadPropertyWithDefault(201, "limit"); + auto offset = deserializer.ReadPropertyWithDefault(202, "offset"); + auto result = duckdb::unique_ptr(new LogicalTopN(std::move(orders), limit, offset)); + return std::move(result); +} + +void LogicalUnnest::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "unnest_index", unnest_index); + serializer.WritePropertyWithDefault>>(201, "expressions", expressions); +} + +unique_ptr LogicalUnnest::Deserialize(Deserializer &deserializer) { + auto unnest_index = deserializer.ReadPropertyWithDefault(200, "unnest_index"); + auto result = duckdb::unique_ptr(new LogicalUnnest(unnest_index)); + deserializer.ReadPropertyWithDefault>>(201, "expressions", result->expressions); + return std::move(result); +} + +void LogicalUpdate::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "table_info", table.GetInfo()); + serializer.WritePropertyWithDefault(201, "table_index", table_index); + serializer.WritePropertyWithDefault(202, "return_chunk", return_chunk); + serializer.WritePropertyWithDefault>>(203, "expressions", expressions); + serializer.WritePropertyWithDefault>(204, "columns", columns); + serializer.WritePropertyWithDefault>>(205, "bound_defaults", bound_defaults); + serializer.WritePropertyWithDefault(206, "update_is_del_and_insert", update_is_del_and_insert); +} + +unique_ptr LogicalUpdate::Deserialize(Deserializer &deserializer) { + auto table_info = deserializer.ReadPropertyWithDefault>(200, "table_info"); + auto result = duckdb::unique_ptr(new LogicalUpdate(deserializer.Get(), table_info)); + deserializer.ReadPropertyWithDefault(201, "table_index", result->table_index); + deserializer.ReadPropertyWithDefault(202, "return_chunk", result->return_chunk); + deserializer.ReadPropertyWithDefault>>(203, "expressions", result->expressions); + deserializer.ReadPropertyWithDefault>(204, "columns", result->columns); + deserializer.ReadPropertyWithDefault>>(205, "bound_defaults", result->bound_defaults); + deserializer.ReadPropertyWithDefault(206, "update_is_del_and_insert", result->update_is_del_and_insert); + return std::move(result); +} + +void LogicalWindow::Serialize(Serializer &serializer) const { + LogicalOperator::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "window_index", window_index); + serializer.WritePropertyWithDefault>>(201, "expressions", expressions); +} + +unique_ptr LogicalWindow::Deserialize(Deserializer &deserializer) { + auto window_index = deserializer.ReadPropertyWithDefault(200, "window_index"); + auto result = duckdb::unique_ptr(new LogicalWindow(window_index)); + deserializer.ReadPropertyWithDefault>>(201, "expressions", result->expressions); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_macro_function.cpp b/src/duckdb/src/storage/serialization/serialize_macro_function.cpp new file mode 100644 index 00000000..79484d80 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_macro_function.cpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/function/macro_function.hpp" +#include "duckdb/function/scalar_macro_function.hpp" +#include "duckdb/function/table_macro_function.hpp" + +namespace duckdb { + +void MacroFunction::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WritePropertyWithDefault>>(101, "parameters", parameters); + serializer.WritePropertyWithDefault>>(102, "default_parameters", default_parameters); +} + +unique_ptr MacroFunction::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto parameters = deserializer.ReadPropertyWithDefault>>(101, "parameters"); + auto default_parameters = deserializer.ReadPropertyWithDefault>>(102, "default_parameters"); + unique_ptr result; + switch (type) { + case MacroType::SCALAR_MACRO: + result = ScalarMacroFunction::Deserialize(deserializer); + break; + case MacroType::TABLE_MACRO: + result = TableMacroFunction::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of MacroFunction!"); + } + result->parameters = std::move(parameters); + result->default_parameters = std::move(default_parameters); + return result; +} + +void ScalarMacroFunction::Serialize(Serializer &serializer) const { + MacroFunction::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "expression", expression); +} + +unique_ptr ScalarMacroFunction::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ScalarMacroFunction()); + deserializer.ReadPropertyWithDefault>(200, "expression", result->expression); + return std::move(result); +} + +void TableMacroFunction::Serialize(Serializer &serializer) const { + MacroFunction::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "query_node", query_node); +} + +unique_ptr TableMacroFunction::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new TableMacroFunction()); + deserializer.ReadPropertyWithDefault>(200, "query_node", result->query_node); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_nodes.cpp b/src/duckdb/src/storage/serialization/serialize_nodes.cpp new file mode 100644 index 00000000..1a8b50fc --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_nodes.cpp @@ -0,0 +1,459 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/extra_type_info.hpp" +#include "duckdb/parser/common_table_expression_info.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/parser/parsed_data/sample_options.hpp" +#include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/planner/tableref/bound_pivotref.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/parser/column_list.hpp" +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/expression/bound_parameter_data.hpp" +#include "duckdb/planner/joinside.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/multi_file_reader_options.hpp" +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/execution/operator/scan/csv/csv_reader_options.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/common/types/interval.hpp" + +namespace duckdb { + +void BoundCaseCheck::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "when_expr", when_expr); + serializer.WritePropertyWithDefault>(101, "then_expr", then_expr); +} + +BoundCaseCheck BoundCaseCheck::Deserialize(Deserializer &deserializer) { + BoundCaseCheck result; + deserializer.ReadPropertyWithDefault>(100, "when_expr", result.when_expr); + deserializer.ReadPropertyWithDefault>(101, "then_expr", result.then_expr); + return result; +} + +void BoundOrderByNode::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WriteProperty(101, "null_order", null_order); + serializer.WritePropertyWithDefault>(102, "expression", expression); +} + +BoundOrderByNode BoundOrderByNode::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto null_order = deserializer.ReadProperty(101, "null_order"); + auto expression = deserializer.ReadPropertyWithDefault>(102, "expression"); + BoundOrderByNode result(type, null_order, std::move(expression)); + return result; +} + +void BoundParameterData::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "value", value); + serializer.WriteProperty(101, "return_type", return_type); +} + +shared_ptr BoundParameterData::Deserialize(Deserializer &deserializer) { + auto value = deserializer.ReadProperty(100, "value"); + auto result = duckdb::shared_ptr(new BoundParameterData(value)); + deserializer.ReadProperty(101, "return_type", result->return_type); + return result; +} + +void BoundPivotInfo::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "group_count", group_count); + serializer.WritePropertyWithDefault>(101, "types", types); + serializer.WritePropertyWithDefault>(102, "pivot_values", pivot_values); + serializer.WritePropertyWithDefault>>(103, "aggregates", aggregates); +} + +BoundPivotInfo BoundPivotInfo::Deserialize(Deserializer &deserializer) { + BoundPivotInfo result; + deserializer.ReadPropertyWithDefault(100, "group_count", result.group_count); + deserializer.ReadPropertyWithDefault>(101, "types", result.types); + deserializer.ReadPropertyWithDefault>(102, "pivot_values", result.pivot_values); + deserializer.ReadPropertyWithDefault>>(103, "aggregates", result.aggregates); + return result; +} + +void CSVReaderOptions::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "has_delimiter", has_delimiter); + serializer.WritePropertyWithDefault(101, "has_quote", has_quote); + serializer.WritePropertyWithDefault(102, "has_escape", has_escape); + serializer.WritePropertyWithDefault(103, "has_header", has_header); + serializer.WritePropertyWithDefault(104, "ignore_errors", ignore_errors); + serializer.WritePropertyWithDefault(105, "buffer_sample_size", buffer_sample_size); + serializer.WritePropertyWithDefault(106, "null_str", null_str); + serializer.WriteProperty(107, "compression", compression); + serializer.WritePropertyWithDefault(108, "allow_quoted_nulls", allow_quoted_nulls); + serializer.WritePropertyWithDefault(109, "skip_rows_set", skip_rows_set); + serializer.WritePropertyWithDefault(110, "maximum_line_size", maximum_line_size); + serializer.WritePropertyWithDefault(111, "normalize_names", normalize_names); + serializer.WritePropertyWithDefault>(112, "force_not_null", force_not_null); + serializer.WritePropertyWithDefault(113, "all_varchar", all_varchar); + serializer.WritePropertyWithDefault(114, "sample_size_chunks", sample_size_chunks); + serializer.WritePropertyWithDefault(115, "auto_detect", auto_detect); + serializer.WritePropertyWithDefault(116, "file_path", file_path); + serializer.WritePropertyWithDefault(117, "decimal_separator", decimal_separator); + serializer.WritePropertyWithDefault(118, "null_padding", null_padding); + serializer.WritePropertyWithDefault(119, "buffer_size", buffer_size); + serializer.WriteProperty(120, "file_options", file_options); + serializer.WritePropertyWithDefault>(121, "force_quote", force_quote); + serializer.WritePropertyWithDefault(122, "rejects_table_name", rejects_table_name); + serializer.WritePropertyWithDefault(123, "rejects_limit", rejects_limit); + serializer.WritePropertyWithDefault>(124, "rejects_recovery_columns", rejects_recovery_columns); + serializer.WritePropertyWithDefault>(125, "rejects_recovery_column_ids", rejects_recovery_column_ids); + serializer.WriteProperty(126, "dialect_options.state_machine_options.delimiter", dialect_options.state_machine_options.delimiter); + serializer.WriteProperty(127, "dialect_options.state_machine_options.quote", dialect_options.state_machine_options.quote); + serializer.WriteProperty(128, "dialect_options.state_machine_options.escape", dialect_options.state_machine_options.escape); + serializer.WritePropertyWithDefault(129, "dialect_options.header", dialect_options.header); + serializer.WritePropertyWithDefault(130, "dialect_options.num_cols", dialect_options.num_cols); + serializer.WriteProperty(131, "dialect_options.new_line", dialect_options.new_line); + serializer.WritePropertyWithDefault(132, "dialect_options.skip_rows", dialect_options.skip_rows); + serializer.WritePropertyWithDefault>(133, "dialect_options.date_format", dialect_options.date_format); + serializer.WritePropertyWithDefault>(134, "dialect_options.has_format", dialect_options.has_format); +} + +CSVReaderOptions CSVReaderOptions::Deserialize(Deserializer &deserializer) { + CSVReaderOptions result; + deserializer.ReadPropertyWithDefault(100, "has_delimiter", result.has_delimiter); + deserializer.ReadPropertyWithDefault(101, "has_quote", result.has_quote); + deserializer.ReadPropertyWithDefault(102, "has_escape", result.has_escape); + deserializer.ReadPropertyWithDefault(103, "has_header", result.has_header); + deserializer.ReadPropertyWithDefault(104, "ignore_errors", result.ignore_errors); + deserializer.ReadPropertyWithDefault(105, "buffer_sample_size", result.buffer_sample_size); + deserializer.ReadPropertyWithDefault(106, "null_str", result.null_str); + deserializer.ReadProperty(107, "compression", result.compression); + deserializer.ReadPropertyWithDefault(108, "allow_quoted_nulls", result.allow_quoted_nulls); + deserializer.ReadPropertyWithDefault(109, "skip_rows_set", result.skip_rows_set); + deserializer.ReadPropertyWithDefault(110, "maximum_line_size", result.maximum_line_size); + deserializer.ReadPropertyWithDefault(111, "normalize_names", result.normalize_names); + deserializer.ReadPropertyWithDefault>(112, "force_not_null", result.force_not_null); + deserializer.ReadPropertyWithDefault(113, "all_varchar", result.all_varchar); + deserializer.ReadPropertyWithDefault(114, "sample_size_chunks", result.sample_size_chunks); + deserializer.ReadPropertyWithDefault(115, "auto_detect", result.auto_detect); + deserializer.ReadPropertyWithDefault(116, "file_path", result.file_path); + deserializer.ReadPropertyWithDefault(117, "decimal_separator", result.decimal_separator); + deserializer.ReadPropertyWithDefault(118, "null_padding", result.null_padding); + deserializer.ReadPropertyWithDefault(119, "buffer_size", result.buffer_size); + deserializer.ReadProperty(120, "file_options", result.file_options); + deserializer.ReadPropertyWithDefault>(121, "force_quote", result.force_quote); + deserializer.ReadPropertyWithDefault(122, "rejects_table_name", result.rejects_table_name); + deserializer.ReadPropertyWithDefault(123, "rejects_limit", result.rejects_limit); + deserializer.ReadPropertyWithDefault>(124, "rejects_recovery_columns", result.rejects_recovery_columns); + deserializer.ReadPropertyWithDefault>(125, "rejects_recovery_column_ids", result.rejects_recovery_column_ids); + deserializer.ReadProperty(126, "dialect_options.state_machine_options.delimiter", result.dialect_options.state_machine_options.delimiter); + deserializer.ReadProperty(127, "dialect_options.state_machine_options.quote", result.dialect_options.state_machine_options.quote); + deserializer.ReadProperty(128, "dialect_options.state_machine_options.escape", result.dialect_options.state_machine_options.escape); + deserializer.ReadPropertyWithDefault(129, "dialect_options.header", result.dialect_options.header); + deserializer.ReadPropertyWithDefault(130, "dialect_options.num_cols", result.dialect_options.num_cols); + deserializer.ReadProperty(131, "dialect_options.new_line", result.dialect_options.new_line); + deserializer.ReadPropertyWithDefault(132, "dialect_options.skip_rows", result.dialect_options.skip_rows); + deserializer.ReadPropertyWithDefault>(133, "dialect_options.date_format", result.dialect_options.date_format); + deserializer.ReadPropertyWithDefault>(134, "dialect_options.has_format", result.dialect_options.has_format); + return result; +} + +void CaseCheck::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "when_expr", when_expr); + serializer.WritePropertyWithDefault>(101, "then_expr", then_expr); +} + +CaseCheck CaseCheck::Deserialize(Deserializer &deserializer) { + CaseCheck result; + deserializer.ReadPropertyWithDefault>(100, "when_expr", result.when_expr); + deserializer.ReadPropertyWithDefault>(101, "then_expr", result.then_expr); + return result; +} + +void ColumnBinding::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "table_index", table_index); + serializer.WritePropertyWithDefault(101, "column_index", column_index); +} + +ColumnBinding ColumnBinding::Deserialize(Deserializer &deserializer) { + ColumnBinding result; + deserializer.ReadPropertyWithDefault(100, "table_index", result.table_index); + deserializer.ReadPropertyWithDefault(101, "column_index", result.column_index); + return result; +} + +void ColumnDefinition::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "name", name); + serializer.WriteProperty(101, "type", type); + serializer.WritePropertyWithDefault>(102, "expression", expression); + serializer.WriteProperty(103, "category", category); + serializer.WriteProperty(104, "compression_type", compression_type); +} + +ColumnDefinition ColumnDefinition::Deserialize(Deserializer &deserializer) { + auto name = deserializer.ReadPropertyWithDefault(100, "name"); + auto type = deserializer.ReadProperty(101, "type"); + auto expression = deserializer.ReadPropertyWithDefault>(102, "expression"); + auto category = deserializer.ReadProperty(103, "category"); + ColumnDefinition result(std::move(name), std::move(type), std::move(expression), category); + deserializer.ReadProperty(104, "compression_type", result.compression_type); + return result; +} + +void ColumnInfo::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "names", names); + serializer.WritePropertyWithDefault>(101, "types", types); +} + +ColumnInfo ColumnInfo::Deserialize(Deserializer &deserializer) { + ColumnInfo result; + deserializer.ReadPropertyWithDefault>(100, "names", result.names); + deserializer.ReadPropertyWithDefault>(101, "types", result.types); + return result; +} + +void ColumnList::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "columns", columns); +} + +ColumnList ColumnList::Deserialize(Deserializer &deserializer) { + auto columns = deserializer.ReadPropertyWithDefault>(100, "columns"); + ColumnList result(std::move(columns)); + return result; +} + +void CommonTableExpressionInfo::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "aliases", aliases); + serializer.WritePropertyWithDefault>(101, "query", query); + serializer.WriteProperty(102, "materialized", materialized); +} + +unique_ptr CommonTableExpressionInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CommonTableExpressionInfo()); + deserializer.ReadPropertyWithDefault>(100, "aliases", result->aliases); + deserializer.ReadPropertyWithDefault>(101, "query", result->query); + deserializer.ReadProperty(102, "materialized", result->materialized); + return result; +} + +void CommonTableExpressionMap::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>>(100, "map", map); +} + +CommonTableExpressionMap CommonTableExpressionMap::Deserialize(Deserializer &deserializer) { + CommonTableExpressionMap result; + deserializer.ReadPropertyWithDefault>>(100, "map", result.map); + return result; +} + +void HivePartitioningIndex::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "value", value); + serializer.WritePropertyWithDefault(101, "index", index); +} + +HivePartitioningIndex HivePartitioningIndex::Deserialize(Deserializer &deserializer) { + auto value = deserializer.ReadPropertyWithDefault(100, "value"); + auto index = deserializer.ReadPropertyWithDefault(101, "index"); + HivePartitioningIndex result(std::move(value), index); + return result; +} + +void JoinCondition::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "left", left); + serializer.WritePropertyWithDefault>(101, "right", right); + serializer.WriteProperty(102, "comparison", comparison); +} + +JoinCondition JoinCondition::Deserialize(Deserializer &deserializer) { + JoinCondition result; + deserializer.ReadPropertyWithDefault>(100, "left", result.left); + deserializer.ReadPropertyWithDefault>(101, "right", result.right); + deserializer.ReadProperty(102, "comparison", result.comparison); + return result; +} + +void LogicalType::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "id", id_); + serializer.WritePropertyWithDefault>(101, "type_info", type_info_); +} + +LogicalType LogicalType::Deserialize(Deserializer &deserializer) { + auto id = deserializer.ReadProperty(100, "id"); + auto type_info = deserializer.ReadPropertyWithDefault>(101, "type_info"); + LogicalType result(id, std::move(type_info)); + return result; +} + +void MultiFileReaderBindData::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "filename_idx", filename_idx); + serializer.WritePropertyWithDefault>(101, "hive_partitioning_indexes", hive_partitioning_indexes); +} + +MultiFileReaderBindData MultiFileReaderBindData::Deserialize(Deserializer &deserializer) { + MultiFileReaderBindData result; + deserializer.ReadPropertyWithDefault(100, "filename_idx", result.filename_idx); + deserializer.ReadPropertyWithDefault>(101, "hive_partitioning_indexes", result.hive_partitioning_indexes); + return result; +} + +void MultiFileReaderOptions::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "filename", filename); + serializer.WritePropertyWithDefault(101, "hive_partitioning", hive_partitioning); + serializer.WritePropertyWithDefault(102, "auto_detect_hive_partitioning", auto_detect_hive_partitioning); + serializer.WritePropertyWithDefault(103, "union_by_name", union_by_name); + serializer.WritePropertyWithDefault(104, "hive_types_autocast", hive_types_autocast); + serializer.WritePropertyWithDefault>(105, "hive_types_schema", hive_types_schema); +} + +MultiFileReaderOptions MultiFileReaderOptions::Deserialize(Deserializer &deserializer) { + MultiFileReaderOptions result; + deserializer.ReadPropertyWithDefault(100, "filename", result.filename); + deserializer.ReadPropertyWithDefault(101, "hive_partitioning", result.hive_partitioning); + deserializer.ReadPropertyWithDefault(102, "auto_detect_hive_partitioning", result.auto_detect_hive_partitioning); + deserializer.ReadPropertyWithDefault(103, "union_by_name", result.union_by_name); + deserializer.ReadPropertyWithDefault(104, "hive_types_autocast", result.hive_types_autocast); + deserializer.ReadPropertyWithDefault>(105, "hive_types_schema", result.hive_types_schema); + return result; +} + +void OrderByNode::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WriteProperty(101, "null_order", null_order); + serializer.WritePropertyWithDefault>(102, "expression", expression); +} + +OrderByNode OrderByNode::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto null_order = deserializer.ReadProperty(101, "null_order"); + auto expression = deserializer.ReadPropertyWithDefault>(102, "expression"); + OrderByNode result(type, null_order, std::move(expression)); + return result; +} + +void PivotColumn::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>>(100, "pivot_expressions", pivot_expressions); + serializer.WritePropertyWithDefault>(101, "unpivot_names", unpivot_names); + serializer.WritePropertyWithDefault>(102, "entries", entries); + serializer.WritePropertyWithDefault(103, "pivot_enum", pivot_enum); +} + +PivotColumn PivotColumn::Deserialize(Deserializer &deserializer) { + PivotColumn result; + deserializer.ReadPropertyWithDefault>>(100, "pivot_expressions", result.pivot_expressions); + deserializer.ReadPropertyWithDefault>(101, "unpivot_names", result.unpivot_names); + deserializer.ReadPropertyWithDefault>(102, "entries", result.entries); + deserializer.ReadPropertyWithDefault(103, "pivot_enum", result.pivot_enum); + return result; +} + +void PivotColumnEntry::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "values", values); + serializer.WritePropertyWithDefault>(101, "star_expr", star_expr); + serializer.WritePropertyWithDefault(102, "alias", alias); +} + +PivotColumnEntry PivotColumnEntry::Deserialize(Deserializer &deserializer) { + PivotColumnEntry result; + deserializer.ReadPropertyWithDefault>(100, "values", result.values); + deserializer.ReadPropertyWithDefault>(101, "star_expr", result.star_expr); + deserializer.ReadPropertyWithDefault(102, "alias", result.alias); + return result; +} + +void ReadCSVData::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "files", files); + serializer.WritePropertyWithDefault>(101, "csv_types", csv_types); + serializer.WritePropertyWithDefault>(102, "csv_names", csv_names); + serializer.WritePropertyWithDefault>(103, "return_types", return_types); + serializer.WritePropertyWithDefault>(104, "return_names", return_names); + serializer.WritePropertyWithDefault(105, "filename_col_idx", filename_col_idx); + serializer.WriteProperty(106, "options", options); + serializer.WritePropertyWithDefault(107, "single_threaded", single_threaded); + serializer.WriteProperty(108, "reader_bind", reader_bind); + serializer.WritePropertyWithDefault>(109, "column_info", column_info); +} + +unique_ptr ReadCSVData::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ReadCSVData()); + deserializer.ReadPropertyWithDefault>(100, "files", result->files); + deserializer.ReadPropertyWithDefault>(101, "csv_types", result->csv_types); + deserializer.ReadPropertyWithDefault>(102, "csv_names", result->csv_names); + deserializer.ReadPropertyWithDefault>(103, "return_types", result->return_types); + deserializer.ReadPropertyWithDefault>(104, "return_names", result->return_names); + deserializer.ReadPropertyWithDefault(105, "filename_col_idx", result->filename_col_idx); + deserializer.ReadProperty(106, "options", result->options); + deserializer.ReadPropertyWithDefault(107, "single_threaded", result->single_threaded); + deserializer.ReadProperty(108, "reader_bind", result->reader_bind); + deserializer.ReadPropertyWithDefault>(109, "column_info", result->column_info); + return result; +} + +void SampleOptions::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "sample_size", sample_size); + serializer.WritePropertyWithDefault(101, "is_percentage", is_percentage); + serializer.WriteProperty(102, "method", method); + serializer.WritePropertyWithDefault(103, "seed", seed); +} + +unique_ptr SampleOptions::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SampleOptions()); + deserializer.ReadProperty(100, "sample_size", result->sample_size); + deserializer.ReadPropertyWithDefault(101, "is_percentage", result->is_percentage); + deserializer.ReadProperty(102, "method", result->method); + deserializer.ReadPropertyWithDefault(103, "seed", result->seed); + return result; +} + +void StrpTimeFormat::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "format_specifier", format_specifier); +} + +StrpTimeFormat StrpTimeFormat::Deserialize(Deserializer &deserializer) { + auto format_specifier = deserializer.ReadPropertyWithDefault(100, "format_specifier"); + StrpTimeFormat result(format_specifier); + return result; +} + +void TableFilterSet::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>>(100, "filters", filters); +} + +TableFilterSet TableFilterSet::Deserialize(Deserializer &deserializer) { + TableFilterSet result; + deserializer.ReadPropertyWithDefault>>(100, "filters", result.filters); + return result; +} + +void VacuumOptions::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "vacuum", vacuum); + serializer.WritePropertyWithDefault(101, "analyze", analyze); +} + +VacuumOptions VacuumOptions::Deserialize(Deserializer &deserializer) { + VacuumOptions result; + deserializer.ReadPropertyWithDefault(100, "vacuum", result.vacuum); + deserializer.ReadPropertyWithDefault(101, "analyze", result.analyze); + return result; +} + +void interval_t::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(1, "months", months); + serializer.WritePropertyWithDefault(2, "days", days); + serializer.WritePropertyWithDefault(3, "micros", micros); +} + +interval_t interval_t::Deserialize(Deserializer &deserializer) { + interval_t result; + deserializer.ReadPropertyWithDefault(1, "months", result.months); + deserializer.ReadPropertyWithDefault(2, "days", result.days); + deserializer.ReadPropertyWithDefault(3, "micros", result.micros); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_parse_info.cpp b/src/duckdb/src/storage/serialization/serialize_parse_info.cpp new file mode 100644 index 00000000..ec86b5fe --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_parse_info.cpp @@ -0,0 +1,421 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/parser/parsed_data/alter_info.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/parsed_data/attach_info.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/parser/parsed_data/detach_info.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/parser/parsed_data/load_info.hpp" +#include "duckdb/parser/parsed_data/pragma_info.hpp" +#include "duckdb/parser/parsed_data/transaction_info.hpp" +#include "duckdb/parser/parsed_data/vacuum_info.hpp" + +namespace duckdb { + +void ParseInfo::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "info_type", info_type); +} + +unique_ptr ParseInfo::Deserialize(Deserializer &deserializer) { + auto info_type = deserializer.ReadProperty(100, "info_type"); + unique_ptr result; + switch (info_type) { + case ParseInfoType::ALTER_INFO: + result = AlterInfo::Deserialize(deserializer); + break; + case ParseInfoType::ATTACH_INFO: + result = AttachInfo::Deserialize(deserializer); + break; + case ParseInfoType::COPY_INFO: + result = CopyInfo::Deserialize(deserializer); + break; + case ParseInfoType::DETACH_INFO: + result = DetachInfo::Deserialize(deserializer); + break; + case ParseInfoType::DROP_INFO: + result = DropInfo::Deserialize(deserializer); + break; + case ParseInfoType::LOAD_INFO: + result = LoadInfo::Deserialize(deserializer); + break; + case ParseInfoType::PRAGMA_INFO: + result = PragmaInfo::Deserialize(deserializer); + break; + case ParseInfoType::TRANSACTION_INFO: + result = TransactionInfo::Deserialize(deserializer); + break; + case ParseInfoType::VACUUM_INFO: + result = VacuumInfo::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of ParseInfo!"); + } + return result; +} + +void AlterInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WriteProperty(200, "type", type); + serializer.WritePropertyWithDefault(201, "catalog", catalog); + serializer.WritePropertyWithDefault(202, "schema", schema); + serializer.WritePropertyWithDefault(203, "name", name); + serializer.WriteProperty(204, "if_not_found", if_not_found); + serializer.WritePropertyWithDefault(205, "allow_internal", allow_internal); +} + +unique_ptr AlterInfo::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(200, "type"); + auto catalog = deserializer.ReadPropertyWithDefault(201, "catalog"); + auto schema = deserializer.ReadPropertyWithDefault(202, "schema"); + auto name = deserializer.ReadPropertyWithDefault(203, "name"); + auto if_not_found = deserializer.ReadProperty(204, "if_not_found"); + auto allow_internal = deserializer.ReadPropertyWithDefault(205, "allow_internal"); + unique_ptr result; + switch (type) { + case AlterType::ALTER_TABLE: + result = AlterTableInfo::Deserialize(deserializer); + break; + case AlterType::ALTER_VIEW: + result = AlterViewInfo::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of AlterInfo!"); + } + result->catalog = std::move(catalog); + result->schema = std::move(schema); + result->name = std::move(name); + result->if_not_found = if_not_found; + result->allow_internal = allow_internal; + return std::move(result); +} + +void AlterTableInfo::Serialize(Serializer &serializer) const { + AlterInfo::Serialize(serializer); + serializer.WriteProperty(300, "alter_table_type", alter_table_type); +} + +unique_ptr AlterTableInfo::Deserialize(Deserializer &deserializer) { + auto alter_table_type = deserializer.ReadProperty(300, "alter_table_type"); + unique_ptr result; + switch (alter_table_type) { + case AlterTableType::ADD_COLUMN: + result = AddColumnInfo::Deserialize(deserializer); + break; + case AlterTableType::ALTER_COLUMN_TYPE: + result = ChangeColumnTypeInfo::Deserialize(deserializer); + break; + case AlterTableType::DROP_NOT_NULL: + result = DropNotNullInfo::Deserialize(deserializer); + break; + case AlterTableType::FOREIGN_KEY_CONSTRAINT: + result = AlterForeignKeyInfo::Deserialize(deserializer); + break; + case AlterTableType::REMOVE_COLUMN: + result = RemoveColumnInfo::Deserialize(deserializer); + break; + case AlterTableType::RENAME_COLUMN: + result = RenameColumnInfo::Deserialize(deserializer); + break; + case AlterTableType::RENAME_TABLE: + result = RenameTableInfo::Deserialize(deserializer); + break; + case AlterTableType::SET_DEFAULT: + result = SetDefaultInfo::Deserialize(deserializer); + break; + case AlterTableType::SET_NOT_NULL: + result = SetNotNullInfo::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of AlterTableInfo!"); + } + return std::move(result); +} + +void AlterViewInfo::Serialize(Serializer &serializer) const { + AlterInfo::Serialize(serializer); + serializer.WriteProperty(300, "alter_view_type", alter_view_type); +} + +unique_ptr AlterViewInfo::Deserialize(Deserializer &deserializer) { + auto alter_view_type = deserializer.ReadProperty(300, "alter_view_type"); + unique_ptr result; + switch (alter_view_type) { + case AlterViewType::RENAME_VIEW: + result = RenameViewInfo::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of AlterViewInfo!"); + } + return std::move(result); +} + +void AddColumnInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WriteProperty(400, "new_column", new_column); + serializer.WritePropertyWithDefault(401, "if_column_not_exists", if_column_not_exists); +} + +unique_ptr AddColumnInfo::Deserialize(Deserializer &deserializer) { + auto new_column = deserializer.ReadProperty(400, "new_column"); + auto result = duckdb::unique_ptr(new AddColumnInfo(std::move(new_column))); + deserializer.ReadPropertyWithDefault(401, "if_column_not_exists", result->if_column_not_exists); + return std::move(result); +} + +void AlterForeignKeyInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "fk_table", fk_table); + serializer.WritePropertyWithDefault>(401, "pk_columns", pk_columns); + serializer.WritePropertyWithDefault>(402, "fk_columns", fk_columns); + serializer.WritePropertyWithDefault>(403, "pk_keys", pk_keys); + serializer.WritePropertyWithDefault>(404, "fk_keys", fk_keys); + serializer.WriteProperty(405, "alter_fk_type", type); +} + +unique_ptr AlterForeignKeyInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new AlterForeignKeyInfo()); + deserializer.ReadPropertyWithDefault(400, "fk_table", result->fk_table); + deserializer.ReadPropertyWithDefault>(401, "pk_columns", result->pk_columns); + deserializer.ReadPropertyWithDefault>(402, "fk_columns", result->fk_columns); + deserializer.ReadPropertyWithDefault>(403, "pk_keys", result->pk_keys); + deserializer.ReadPropertyWithDefault>(404, "fk_keys", result->fk_keys); + deserializer.ReadProperty(405, "alter_fk_type", result->type); + return std::move(result); +} + +void AttachInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(201, "path", path); + serializer.WritePropertyWithDefault>(202, "options", options); +} + +unique_ptr AttachInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new AttachInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(201, "path", result->path); + deserializer.ReadPropertyWithDefault>(202, "options", result->options); + return std::move(result); +} + +void ChangeColumnTypeInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "column_name", column_name); + serializer.WriteProperty(401, "target_type", target_type); + serializer.WritePropertyWithDefault>(402, "expression", expression); +} + +unique_ptr ChangeColumnTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ChangeColumnTypeInfo()); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + deserializer.ReadProperty(401, "target_type", result->target_type); + deserializer.ReadPropertyWithDefault>(402, "expression", result->expression); + return std::move(result); +} + +void CopyInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "catalog", catalog); + serializer.WritePropertyWithDefault(201, "schema", schema); + serializer.WritePropertyWithDefault(202, "table", table); + serializer.WritePropertyWithDefault>(203, "select_list", select_list); + serializer.WritePropertyWithDefault(204, "is_from", is_from); + serializer.WritePropertyWithDefault(205, "format", format); + serializer.WritePropertyWithDefault(206, "file_path", file_path); + serializer.WritePropertyWithDefault>>(207, "options", options); +} + +unique_ptr CopyInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CopyInfo()); + deserializer.ReadPropertyWithDefault(200, "catalog", result->catalog); + deserializer.ReadPropertyWithDefault(201, "schema", result->schema); + deserializer.ReadPropertyWithDefault(202, "table", result->table); + deserializer.ReadPropertyWithDefault>(203, "select_list", result->select_list); + deserializer.ReadPropertyWithDefault(204, "is_from", result->is_from); + deserializer.ReadPropertyWithDefault(205, "format", result->format); + deserializer.ReadPropertyWithDefault(206, "file_path", result->file_path); + deserializer.ReadPropertyWithDefault>>(207, "options", result->options); + return std::move(result); +} + +void DetachInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WriteProperty(201, "if_not_found", if_not_found); +} + +unique_ptr DetachInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new DetachInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadProperty(201, "if_not_found", result->if_not_found); + return std::move(result); +} + +void DropInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WriteProperty(200, "type", type); + serializer.WritePropertyWithDefault(201, "catalog", catalog); + serializer.WritePropertyWithDefault(202, "schema", schema); + serializer.WritePropertyWithDefault(203, "name", name); + serializer.WriteProperty(204, "if_not_found", if_not_found); + serializer.WritePropertyWithDefault(205, "cascade", cascade); + serializer.WritePropertyWithDefault(206, "allow_drop_internal", allow_drop_internal); +} + +unique_ptr DropInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new DropInfo()); + deserializer.ReadProperty(200, "type", result->type); + deserializer.ReadPropertyWithDefault(201, "catalog", result->catalog); + deserializer.ReadPropertyWithDefault(202, "schema", result->schema); + deserializer.ReadPropertyWithDefault(203, "name", result->name); + deserializer.ReadProperty(204, "if_not_found", result->if_not_found); + deserializer.ReadPropertyWithDefault(205, "cascade", result->cascade); + deserializer.ReadPropertyWithDefault(206, "allow_drop_internal", result->allow_drop_internal); + return std::move(result); +} + +void DropNotNullInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "column_name", column_name); +} + +unique_ptr DropNotNullInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new DropNotNullInfo()); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + return std::move(result); +} + +void LoadInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "filename", filename); + serializer.WriteProperty(201, "load_type", load_type); + serializer.WritePropertyWithDefault(202, "repository", repository); +} + +unique_ptr LoadInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LoadInfo()); + deserializer.ReadPropertyWithDefault(200, "filename", result->filename); + deserializer.ReadProperty(201, "load_type", result->load_type); + deserializer.ReadPropertyWithDefault(202, "repository", result->repository); + return std::move(result); +} + +void PragmaInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault>(201, "parameters", parameters); + serializer.WriteProperty(202, "named_parameters", named_parameters); +} + +unique_ptr PragmaInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new PragmaInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault>(201, "parameters", result->parameters); + deserializer.ReadProperty(202, "named_parameters", result->named_parameters); + return std::move(result); +} + +void RemoveColumnInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "removed_column", removed_column); + serializer.WritePropertyWithDefault(401, "if_column_exists", if_column_exists); + serializer.WritePropertyWithDefault(402, "cascade", cascade); +} + +unique_ptr RemoveColumnInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new RemoveColumnInfo()); + deserializer.ReadPropertyWithDefault(400, "removed_column", result->removed_column); + deserializer.ReadPropertyWithDefault(401, "if_column_exists", result->if_column_exists); + deserializer.ReadPropertyWithDefault(402, "cascade", result->cascade); + return std::move(result); +} + +void RenameColumnInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "old_name", old_name); + serializer.WritePropertyWithDefault(401, "new_name", new_name); +} + +unique_ptr RenameColumnInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new RenameColumnInfo()); + deserializer.ReadPropertyWithDefault(400, "old_name", result->old_name); + deserializer.ReadPropertyWithDefault(401, "new_name", result->new_name); + return std::move(result); +} + +void RenameTableInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "new_table_name", new_table_name); +} + +unique_ptr RenameTableInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new RenameTableInfo()); + deserializer.ReadPropertyWithDefault(400, "new_table_name", result->new_table_name); + return std::move(result); +} + +void RenameViewInfo::Serialize(Serializer &serializer) const { + AlterViewInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "new_view_name", new_view_name); +} + +unique_ptr RenameViewInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new RenameViewInfo()); + deserializer.ReadPropertyWithDefault(400, "new_view_name", result->new_view_name); + return std::move(result); +} + +void SetDefaultInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "column_name", column_name); + serializer.WritePropertyWithDefault>(401, "expression", expression); +} + +unique_ptr SetDefaultInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SetDefaultInfo()); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + deserializer.ReadPropertyWithDefault>(401, "expression", result->expression); + return std::move(result); +} + +void SetNotNullInfo::Serialize(Serializer &serializer) const { + AlterTableInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(400, "column_name", column_name); +} + +unique_ptr SetNotNullInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SetNotNullInfo()); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + return std::move(result); +} + +void TransactionInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WriteProperty(200, "type", type); +} + +unique_ptr TransactionInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new TransactionInfo()); + deserializer.ReadProperty(200, "type", result->type); + return std::move(result); +} + +void VacuumInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WriteProperty(200, "options", options); +} + +unique_ptr VacuumInfo::Deserialize(Deserializer &deserializer) { + auto options = deserializer.ReadProperty(200, "options"); + auto result = duckdb::unique_ptr(new VacuumInfo(options)); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_parsed_expression.cpp b/src/duckdb/src/storage/serialization/serialize_parsed_expression.cpp new file mode 100644 index 00000000..39297311 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_parsed_expression.cpp @@ -0,0 +1,342 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/expression/list.hpp" + +namespace duckdb { + +void ParsedExpression::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "class", expression_class); + serializer.WriteProperty(101, "type", type); + serializer.WritePropertyWithDefault(102, "alias", alias); +} + +unique_ptr ParsedExpression::Deserialize(Deserializer &deserializer) { + auto expression_class = deserializer.ReadProperty(100, "class"); + auto type = deserializer.ReadProperty(101, "type"); + auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); + deserializer.Set(type); + unique_ptr result; + switch (expression_class) { + case ExpressionClass::BETWEEN: + result = BetweenExpression::Deserialize(deserializer); + break; + case ExpressionClass::CASE: + result = CaseExpression::Deserialize(deserializer); + break; + case ExpressionClass::CAST: + result = CastExpression::Deserialize(deserializer); + break; + case ExpressionClass::COLLATE: + result = CollateExpression::Deserialize(deserializer); + break; + case ExpressionClass::COLUMN_REF: + result = ColumnRefExpression::Deserialize(deserializer); + break; + case ExpressionClass::COMPARISON: + result = ComparisonExpression::Deserialize(deserializer); + break; + case ExpressionClass::CONJUNCTION: + result = ConjunctionExpression::Deserialize(deserializer); + break; + case ExpressionClass::CONSTANT: + result = ConstantExpression::Deserialize(deserializer); + break; + case ExpressionClass::DEFAULT: + result = DefaultExpression::Deserialize(deserializer); + break; + case ExpressionClass::FUNCTION: + result = FunctionExpression::Deserialize(deserializer); + break; + case ExpressionClass::LAMBDA: + result = LambdaExpression::Deserialize(deserializer); + break; + case ExpressionClass::OPERATOR: + result = OperatorExpression::Deserialize(deserializer); + break; + case ExpressionClass::PARAMETER: + result = ParameterExpression::Deserialize(deserializer); + break; + case ExpressionClass::POSITIONAL_REFERENCE: + result = PositionalReferenceExpression::Deserialize(deserializer); + break; + case ExpressionClass::STAR: + result = StarExpression::Deserialize(deserializer); + break; + case ExpressionClass::SUBQUERY: + result = SubqueryExpression::Deserialize(deserializer); + break; + case ExpressionClass::WINDOW: + result = WindowExpression::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of ParsedExpression!"); + } + deserializer.Unset(); + result->alias = std::move(alias); + return result; +} + +void BetweenExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "input", input); + serializer.WritePropertyWithDefault>(201, "lower", lower); + serializer.WritePropertyWithDefault>(202, "upper", upper); +} + +unique_ptr BetweenExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new BetweenExpression()); + deserializer.ReadPropertyWithDefault>(200, "input", result->input); + deserializer.ReadPropertyWithDefault>(201, "lower", result->lower); + deserializer.ReadPropertyWithDefault>(202, "upper", result->upper); + return std::move(result); +} + +void CaseExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "case_checks", case_checks); + serializer.WritePropertyWithDefault>(201, "else_expr", else_expr); +} + +unique_ptr CaseExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CaseExpression()); + deserializer.ReadPropertyWithDefault>(200, "case_checks", result->case_checks); + deserializer.ReadPropertyWithDefault>(201, "else_expr", result->else_expr); + return std::move(result); +} + +void CastExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "child", child); + serializer.WriteProperty(201, "cast_type", cast_type); + serializer.WritePropertyWithDefault(202, "try_cast", try_cast); +} + +unique_ptr CastExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CastExpression()); + deserializer.ReadPropertyWithDefault>(200, "child", result->child); + deserializer.ReadProperty(201, "cast_type", result->cast_type); + deserializer.ReadPropertyWithDefault(202, "try_cast", result->try_cast); + return std::move(result); +} + +void CollateExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "child", child); + serializer.WritePropertyWithDefault(201, "collation", collation); +} + +unique_ptr CollateExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CollateExpression()); + deserializer.ReadPropertyWithDefault>(200, "child", result->child); + deserializer.ReadPropertyWithDefault(201, "collation", result->collation); + return std::move(result); +} + +void ColumnRefExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "column_names", column_names); +} + +unique_ptr ColumnRefExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ColumnRefExpression()); + deserializer.ReadPropertyWithDefault>(200, "column_names", result->column_names); + return std::move(result); +} + +void ComparisonExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "left", left); + serializer.WritePropertyWithDefault>(201, "right", right); +} + +unique_ptr ComparisonExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ComparisonExpression(deserializer.Get())); + deserializer.ReadPropertyWithDefault>(200, "left", result->left); + deserializer.ReadPropertyWithDefault>(201, "right", result->right); + return std::move(result); +} + +void ConjunctionExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "children", children); +} + +unique_ptr ConjunctionExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ConjunctionExpression(deserializer.Get())); + deserializer.ReadPropertyWithDefault>>(200, "children", result->children); + return std::move(result); +} + +void ConstantExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WriteProperty(200, "value", value); +} + +unique_ptr ConstantExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ConstantExpression()); + deserializer.ReadProperty(200, "value", result->value); + return std::move(result); +} + +void DefaultExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); +} + +unique_ptr DefaultExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new DefaultExpression()); + return std::move(result); +} + +void FunctionExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "function_name", function_name); + serializer.WritePropertyWithDefault(201, "schema", schema); + serializer.WritePropertyWithDefault>>(202, "children", children); + serializer.WritePropertyWithDefault>(203, "filter", filter); + serializer.WritePropertyWithDefault>(204, "order_bys", order_bys); + serializer.WritePropertyWithDefault(205, "distinct", distinct); + serializer.WritePropertyWithDefault(206, "is_operator", is_operator); + serializer.WritePropertyWithDefault(207, "export_state", export_state); + serializer.WritePropertyWithDefault(208, "catalog", catalog); +} + +unique_ptr FunctionExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new FunctionExpression()); + deserializer.ReadPropertyWithDefault(200, "function_name", result->function_name); + deserializer.ReadPropertyWithDefault(201, "schema", result->schema); + deserializer.ReadPropertyWithDefault>>(202, "children", result->children); + deserializer.ReadPropertyWithDefault>(203, "filter", result->filter); + auto order_bys = deserializer.ReadPropertyWithDefault>(204, "order_bys"); + result->order_bys = unique_ptr_cast(std::move(order_bys)); + deserializer.ReadPropertyWithDefault(205, "distinct", result->distinct); + deserializer.ReadPropertyWithDefault(206, "is_operator", result->is_operator); + deserializer.ReadPropertyWithDefault(207, "export_state", result->export_state); + deserializer.ReadPropertyWithDefault(208, "catalog", result->catalog); + return std::move(result); +} + +void LambdaExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "lhs", lhs); + serializer.WritePropertyWithDefault>(201, "expr", expr); +} + +unique_ptr LambdaExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LambdaExpression()); + deserializer.ReadPropertyWithDefault>(200, "lhs", result->lhs); + deserializer.ReadPropertyWithDefault>(201, "expr", result->expr); + return std::move(result); +} + +void OperatorExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "children", children); +} + +unique_ptr OperatorExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new OperatorExpression(deserializer.Get())); + deserializer.ReadPropertyWithDefault>>(200, "children", result->children); + return std::move(result); +} + +void ParameterExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "identifier", identifier); +} + +unique_ptr ParameterExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ParameterExpression()); + deserializer.ReadPropertyWithDefault(200, "identifier", result->identifier); + return std::move(result); +} + +void PositionalReferenceExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "index", index); +} + +unique_ptr PositionalReferenceExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new PositionalReferenceExpression()); + deserializer.ReadPropertyWithDefault(200, "index", result->index); + return std::move(result); +} + +void StarExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "relation_name", relation_name); + serializer.WriteProperty(201, "exclude_list", exclude_list); + serializer.WritePropertyWithDefault>>(202, "replace_list", replace_list); + serializer.WritePropertyWithDefault(203, "columns", columns); + serializer.WritePropertyWithDefault>(204, "expr", expr); +} + +unique_ptr StarExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new StarExpression()); + deserializer.ReadPropertyWithDefault(200, "relation_name", result->relation_name); + deserializer.ReadProperty(201, "exclude_list", result->exclude_list); + deserializer.ReadPropertyWithDefault>>(202, "replace_list", result->replace_list); + deserializer.ReadPropertyWithDefault(203, "columns", result->columns); + deserializer.ReadPropertyWithDefault>(204, "expr", result->expr); + return std::move(result); +} + +void SubqueryExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WriteProperty(200, "subquery_type", subquery_type); + serializer.WritePropertyWithDefault>(201, "subquery", subquery); + serializer.WritePropertyWithDefault>(202, "child", child); + serializer.WriteProperty(203, "comparison_type", comparison_type); +} + +unique_ptr SubqueryExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SubqueryExpression()); + deserializer.ReadProperty(200, "subquery_type", result->subquery_type); + deserializer.ReadPropertyWithDefault>(201, "subquery", result->subquery); + deserializer.ReadPropertyWithDefault>(202, "child", result->child); + deserializer.ReadProperty(203, "comparison_type", result->comparison_type); + return std::move(result); +} + +void WindowExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "function_name", function_name); + serializer.WritePropertyWithDefault(201, "schema", schema); + serializer.WritePropertyWithDefault(202, "catalog", catalog); + serializer.WritePropertyWithDefault>>(203, "children", children); + serializer.WritePropertyWithDefault>>(204, "partitions", partitions); + serializer.WritePropertyWithDefault>(205, "orders", orders); + serializer.WriteProperty(206, "start", start); + serializer.WriteProperty(207, "end", end); + serializer.WritePropertyWithDefault>(208, "start_expr", start_expr); + serializer.WritePropertyWithDefault>(209, "end_expr", end_expr); + serializer.WritePropertyWithDefault>(210, "offset_expr", offset_expr); + serializer.WritePropertyWithDefault>(211, "default_expr", default_expr); + serializer.WritePropertyWithDefault(212, "ignore_nulls", ignore_nulls); + serializer.WritePropertyWithDefault>(213, "filter_expr", filter_expr); +} + +unique_ptr WindowExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new WindowExpression(deserializer.Get())); + deserializer.ReadPropertyWithDefault(200, "function_name", result->function_name); + deserializer.ReadPropertyWithDefault(201, "schema", result->schema); + deserializer.ReadPropertyWithDefault(202, "catalog", result->catalog); + deserializer.ReadPropertyWithDefault>>(203, "children", result->children); + deserializer.ReadPropertyWithDefault>>(204, "partitions", result->partitions); + deserializer.ReadPropertyWithDefault>(205, "orders", result->orders); + deserializer.ReadProperty(206, "start", result->start); + deserializer.ReadProperty(207, "end", result->end); + deserializer.ReadPropertyWithDefault>(208, "start_expr", result->start_expr); + deserializer.ReadPropertyWithDefault>(209, "end_expr", result->end_expr); + deserializer.ReadPropertyWithDefault>(210, "offset_expr", result->offset_expr); + deserializer.ReadPropertyWithDefault>(211, "default_expr", result->default_expr); + deserializer.ReadPropertyWithDefault(212, "ignore_nulls", result->ignore_nulls); + deserializer.ReadPropertyWithDefault>(213, "filter_expr", result->filter_expr); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_query_node.cpp b/src/duckdb/src/storage/serialization/serialize_query_node.cpp new file mode 100644 index 00000000..f271ebc4 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_query_node.cpp @@ -0,0 +1,122 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/query_node/list.hpp" + +namespace duckdb { + +void QueryNode::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WritePropertyWithDefault>>(101, "modifiers", modifiers); + serializer.WriteProperty(102, "cte_map", cte_map); +} + +unique_ptr QueryNode::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto modifiers = deserializer.ReadPropertyWithDefault>>(101, "modifiers"); + auto cte_map = deserializer.ReadProperty(102, "cte_map"); + unique_ptr result; + switch (type) { + case QueryNodeType::CTE_NODE: + result = CTENode::Deserialize(deserializer); + break; + case QueryNodeType::RECURSIVE_CTE_NODE: + result = RecursiveCTENode::Deserialize(deserializer); + break; + case QueryNodeType::SELECT_NODE: + result = SelectNode::Deserialize(deserializer); + break; + case QueryNodeType::SET_OPERATION_NODE: + result = SetOperationNode::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of QueryNode!"); + } + result->modifiers = std::move(modifiers); + result->cte_map = std::move(cte_map); + return result; +} + +void CTENode::Serialize(Serializer &serializer) const { + QueryNode::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "cte_name", ctename); + serializer.WritePropertyWithDefault>(201, "query", query); + serializer.WritePropertyWithDefault>(202, "child", child); + serializer.WritePropertyWithDefault>(203, "aliases", aliases); +} + +unique_ptr CTENode::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new CTENode()); + deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); + deserializer.ReadPropertyWithDefault>(201, "query", result->query); + deserializer.ReadPropertyWithDefault>(202, "child", result->child); + deserializer.ReadPropertyWithDefault>(203, "aliases", result->aliases); + return std::move(result); +} + +void RecursiveCTENode::Serialize(Serializer &serializer) const { + QueryNode::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "cte_name", ctename); + serializer.WritePropertyWithDefault(201, "union_all", union_all, false); + serializer.WritePropertyWithDefault>(202, "left", left); + serializer.WritePropertyWithDefault>(203, "right", right); + serializer.WritePropertyWithDefault>(204, "aliases", aliases); +} + +unique_ptr RecursiveCTENode::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new RecursiveCTENode()); + deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); + deserializer.ReadPropertyWithDefault(201, "union_all", result->union_all, false); + deserializer.ReadPropertyWithDefault>(202, "left", result->left); + deserializer.ReadPropertyWithDefault>(203, "right", result->right); + deserializer.ReadPropertyWithDefault>(204, "aliases", result->aliases); + return std::move(result); +} + +void SelectNode::Serialize(Serializer &serializer) const { + QueryNode::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "select_list", select_list); + serializer.WritePropertyWithDefault>(201, "from_table", from_table); + serializer.WritePropertyWithDefault>(202, "where_clause", where_clause); + serializer.WritePropertyWithDefault>>(203, "group_expressions", groups.group_expressions); + serializer.WritePropertyWithDefault>(204, "group_sets", groups.grouping_sets); + serializer.WriteProperty(205, "aggregate_handling", aggregate_handling); + serializer.WritePropertyWithDefault>(206, "having", having); + serializer.WritePropertyWithDefault>(207, "sample", sample); + serializer.WritePropertyWithDefault>(208, "qualify", qualify); +} + +unique_ptr SelectNode::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SelectNode()); + deserializer.ReadPropertyWithDefault>>(200, "select_list", result->select_list); + deserializer.ReadPropertyWithDefault>(201, "from_table", result->from_table); + deserializer.ReadPropertyWithDefault>(202, "where_clause", result->where_clause); + deserializer.ReadPropertyWithDefault>>(203, "group_expressions", result->groups.group_expressions); + deserializer.ReadPropertyWithDefault>(204, "group_sets", result->groups.grouping_sets); + deserializer.ReadProperty(205, "aggregate_handling", result->aggregate_handling); + deserializer.ReadPropertyWithDefault>(206, "having", result->having); + deserializer.ReadPropertyWithDefault>(207, "sample", result->sample); + deserializer.ReadPropertyWithDefault>(208, "qualify", result->qualify); + return std::move(result); +} + +void SetOperationNode::Serialize(Serializer &serializer) const { + QueryNode::Serialize(serializer); + serializer.WriteProperty(200, "setop_type", setop_type); + serializer.WritePropertyWithDefault>(201, "left", left); + serializer.WritePropertyWithDefault>(202, "right", right); +} + +unique_ptr SetOperationNode::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SetOperationNode()); + deserializer.ReadProperty(200, "setop_type", result->setop_type); + deserializer.ReadPropertyWithDefault>(201, "left", result->left); + deserializer.ReadPropertyWithDefault>(202, "right", result->right); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_result_modifier.cpp b/src/duckdb/src/storage/serialization/serialize_result_modifier.cpp new file mode 100644 index 00000000..261f106c --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_result_modifier.cpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" + +namespace duckdb { + +void ResultModifier::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); +} + +unique_ptr ResultModifier::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + unique_ptr result; + switch (type) { + case ResultModifierType::DISTINCT_MODIFIER: + result = DistinctModifier::Deserialize(deserializer); + break; + case ResultModifierType::LIMIT_MODIFIER: + result = LimitModifier::Deserialize(deserializer); + break; + case ResultModifierType::LIMIT_PERCENT_MODIFIER: + result = LimitPercentModifier::Deserialize(deserializer); + break; + case ResultModifierType::ORDER_MODIFIER: + result = OrderModifier::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of ResultModifier!"); + } + return result; +} + +void BoundOrderModifier::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "orders", orders); +} + +unique_ptr BoundOrderModifier::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new BoundOrderModifier()); + deserializer.ReadPropertyWithDefault>(100, "orders", result->orders); + return result; +} + +void DistinctModifier::Serialize(Serializer &serializer) const { + ResultModifier::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "distinct_on_targets", distinct_on_targets); +} + +unique_ptr DistinctModifier::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new DistinctModifier()); + deserializer.ReadPropertyWithDefault>>(200, "distinct_on_targets", result->distinct_on_targets); + return std::move(result); +} + +void LimitModifier::Serialize(Serializer &serializer) const { + ResultModifier::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "limit", limit); + serializer.WritePropertyWithDefault>(201, "offset", offset); +} + +unique_ptr LimitModifier::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LimitModifier()); + deserializer.ReadPropertyWithDefault>(200, "limit", result->limit); + deserializer.ReadPropertyWithDefault>(201, "offset", result->offset); + return std::move(result); +} + +void LimitPercentModifier::Serialize(Serializer &serializer) const { + ResultModifier::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "limit", limit); + serializer.WritePropertyWithDefault>(201, "offset", offset); +} + +unique_ptr LimitPercentModifier::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LimitPercentModifier()); + deserializer.ReadPropertyWithDefault>(200, "limit", result->limit); + deserializer.ReadPropertyWithDefault>(201, "offset", result->offset); + return std::move(result); +} + +void OrderModifier::Serialize(Serializer &serializer) const { + ResultModifier::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "orders", orders); +} + +unique_ptr OrderModifier::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new OrderModifier()); + deserializer.ReadPropertyWithDefault>(200, "orders", result->orders); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_statement.cpp b/src/duckdb/src/storage/serialization/serialize_statement.cpp new file mode 100644 index 00000000..235ffd1e --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_statement.cpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/statement/select_statement.hpp" + +namespace duckdb { + +void SelectStatement::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "node", node); +} + +unique_ptr SelectStatement::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SelectStatement()); + deserializer.ReadPropertyWithDefault>(100, "node", result->node); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_storage.cpp b/src/duckdb/src/storage/serialization/serialize_storage.cpp new file mode 100644 index 00000000..605d5a41 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_storage.cpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/storage/block.hpp" +#include "duckdb/storage/data_pointer.hpp" +#include "duckdb/storage/statistics/distinct_statistics.hpp" + +namespace duckdb { + +void BlockPointer::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "block_id", block_id); + serializer.WritePropertyWithDefault(101, "offset", offset); +} + +BlockPointer BlockPointer::Deserialize(Deserializer &deserializer) { + auto block_id = deserializer.ReadProperty(100, "block_id"); + auto offset = deserializer.ReadPropertyWithDefault(101, "offset"); + BlockPointer result(block_id, offset); + return result; +} + +void DataPointer::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "row_start", row_start); + serializer.WritePropertyWithDefault(101, "tuple_count", tuple_count); + serializer.WriteProperty(102, "block_pointer", block_pointer); + serializer.WriteProperty(103, "compression_type", compression_type); + serializer.WriteProperty(104, "statistics", statistics); + serializer.WritePropertyWithDefault>(105, "segment_state", segment_state); +} + +DataPointer DataPointer::Deserialize(Deserializer &deserializer) { + auto row_start = deserializer.ReadPropertyWithDefault(100, "row_start"); + auto tuple_count = deserializer.ReadPropertyWithDefault(101, "tuple_count"); + auto block_pointer = deserializer.ReadProperty(102, "block_pointer"); + auto compression_type = deserializer.ReadProperty(103, "compression_type"); + auto statistics = deserializer.ReadProperty(104, "statistics"); + DataPointer result(std::move(statistics)); + result.row_start = row_start; + result.tuple_count = tuple_count; + result.block_pointer = block_pointer; + result.compression_type = compression_type; + deserializer.Set(compression_type); + deserializer.ReadPropertyWithDefault>(105, "segment_state", result.segment_state); + deserializer.Unset(); + return result; +} + +void DistinctStatistics::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "sample_count", sample_count); + serializer.WritePropertyWithDefault(101, "total_count", total_count); + serializer.WritePropertyWithDefault>(102, "log", log); +} + +unique_ptr DistinctStatistics::Deserialize(Deserializer &deserializer) { + auto sample_count = deserializer.ReadPropertyWithDefault(100, "sample_count"); + auto total_count = deserializer.ReadPropertyWithDefault(101, "total_count"); + auto log = deserializer.ReadPropertyWithDefault>(102, "log"); + auto result = duckdb::unique_ptr(new DistinctStatistics(std::move(log), sample_count, total_count)); + return result; +} + +void MetaBlockPointer::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "block_pointer", block_pointer); + serializer.WritePropertyWithDefault(101, "offset", offset); +} + +MetaBlockPointer MetaBlockPointer::Deserialize(Deserializer &deserializer) { + auto block_pointer = deserializer.ReadPropertyWithDefault(100, "block_pointer"); + auto offset = deserializer.ReadPropertyWithDefault(101, "offset"); + MetaBlockPointer result(block_pointer, offset); + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_table_filter.cpp b/src/duckdb/src/storage/serialization/serialize_table_filter.cpp new file mode 100644 index 00000000..ae065cdd --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_table_filter.cpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/planner/filter/null_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" + +namespace duckdb { + +void TableFilter::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "filter_type", filter_type); +} + +unique_ptr TableFilter::Deserialize(Deserializer &deserializer) { + auto filter_type = deserializer.ReadProperty(100, "filter_type"); + unique_ptr result; + switch (filter_type) { + case TableFilterType::CONJUNCTION_AND: + result = ConjunctionAndFilter::Deserialize(deserializer); + break; + case TableFilterType::CONJUNCTION_OR: + result = ConjunctionOrFilter::Deserialize(deserializer); + break; + case TableFilterType::CONSTANT_COMPARISON: + result = ConstantFilter::Deserialize(deserializer); + break; + case TableFilterType::IS_NOT_NULL: + result = IsNotNullFilter::Deserialize(deserializer); + break; + case TableFilterType::IS_NULL: + result = IsNullFilter::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of TableFilter!"); + } + return result; +} + +void ConjunctionAndFilter::Serialize(Serializer &serializer) const { + TableFilter::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "child_filters", child_filters); +} + +unique_ptr ConjunctionAndFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ConjunctionAndFilter()); + deserializer.ReadPropertyWithDefault>>(200, "child_filters", result->child_filters); + return std::move(result); +} + +void ConjunctionOrFilter::Serialize(Serializer &serializer) const { + TableFilter::Serialize(serializer); + serializer.WritePropertyWithDefault>>(200, "child_filters", child_filters); +} + +unique_ptr ConjunctionOrFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ConjunctionOrFilter()); + deserializer.ReadPropertyWithDefault>>(200, "child_filters", result->child_filters); + return std::move(result); +} + +void ConstantFilter::Serialize(Serializer &serializer) const { + TableFilter::Serialize(serializer); + serializer.WriteProperty(200, "comparison_type", comparison_type); + serializer.WriteProperty(201, "constant", constant); +} + +unique_ptr ConstantFilter::Deserialize(Deserializer &deserializer) { + auto comparison_type = deserializer.ReadProperty(200, "comparison_type"); + auto constant = deserializer.ReadProperty(201, "constant"); + auto result = duckdb::unique_ptr(new ConstantFilter(comparison_type, constant)); + return std::move(result); +} + +void IsNotNullFilter::Serialize(Serializer &serializer) const { + TableFilter::Serialize(serializer); +} + +unique_ptr IsNotNullFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new IsNotNullFilter()); + return std::move(result); +} + +void IsNullFilter::Serialize(Serializer &serializer) const { + TableFilter::Serialize(serializer); +} + +unique_ptr IsNullFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new IsNullFilter()); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_tableref.cpp b/src/duckdb/src/storage/serialization/serialize_tableref.cpp new file mode 100644 index 00000000..754457e5 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_tableref.cpp @@ -0,0 +1,164 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/tableref/list.hpp" + +namespace duckdb { + +void TableRef::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WritePropertyWithDefault(101, "alias", alias); + serializer.WritePropertyWithDefault>(102, "sample", sample); +} + +unique_ptr TableRef::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto alias = deserializer.ReadPropertyWithDefault(101, "alias"); + auto sample = deserializer.ReadPropertyWithDefault>(102, "sample"); + unique_ptr result; + switch (type) { + case TableReferenceType::BASE_TABLE: + result = BaseTableRef::Deserialize(deserializer); + break; + case TableReferenceType::EMPTY: + result = EmptyTableRef::Deserialize(deserializer); + break; + case TableReferenceType::EXPRESSION_LIST: + result = ExpressionListRef::Deserialize(deserializer); + break; + case TableReferenceType::JOIN: + result = JoinRef::Deserialize(deserializer); + break; + case TableReferenceType::PIVOT: + result = PivotRef::Deserialize(deserializer); + break; + case TableReferenceType::SUBQUERY: + result = SubqueryRef::Deserialize(deserializer); + break; + case TableReferenceType::TABLE_FUNCTION: + result = TableFunctionRef::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of TableRef!"); + } + result->alias = std::move(alias); + result->sample = std::move(sample); + return result; +} + +void BaseTableRef::Serialize(Serializer &serializer) const { + TableRef::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "schema_name", schema_name); + serializer.WritePropertyWithDefault(201, "table_name", table_name); + serializer.WritePropertyWithDefault>(202, "column_name_alias", column_name_alias); + serializer.WritePropertyWithDefault(203, "catalog_name", catalog_name); +} + +unique_ptr BaseTableRef::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new BaseTableRef()); + deserializer.ReadPropertyWithDefault(200, "schema_name", result->schema_name); + deserializer.ReadPropertyWithDefault(201, "table_name", result->table_name); + deserializer.ReadPropertyWithDefault>(202, "column_name_alias", result->column_name_alias); + deserializer.ReadPropertyWithDefault(203, "catalog_name", result->catalog_name); + return std::move(result); +} + +void EmptyTableRef::Serialize(Serializer &serializer) const { + TableRef::Serialize(serializer); +} + +unique_ptr EmptyTableRef::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new EmptyTableRef()); + return std::move(result); +} + +void ExpressionListRef::Serialize(Serializer &serializer) const { + TableRef::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "expected_names", expected_names); + serializer.WritePropertyWithDefault>(201, "expected_types", expected_types); + serializer.WritePropertyWithDefault>>>(202, "values", values); +} + +unique_ptr ExpressionListRef::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ExpressionListRef()); + deserializer.ReadPropertyWithDefault>(200, "expected_names", result->expected_names); + deserializer.ReadPropertyWithDefault>(201, "expected_types", result->expected_types); + deserializer.ReadPropertyWithDefault>>>(202, "values", result->values); + return std::move(result); +} + +void JoinRef::Serialize(Serializer &serializer) const { + TableRef::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "left", left); + serializer.WritePropertyWithDefault>(201, "right", right); + serializer.WritePropertyWithDefault>(202, "condition", condition); + serializer.WriteProperty(203, "join_type", type); + serializer.WriteProperty(204, "ref_type", ref_type); + serializer.WritePropertyWithDefault>(205, "using_columns", using_columns); +} + +unique_ptr JoinRef::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new JoinRef()); + deserializer.ReadPropertyWithDefault>(200, "left", result->left); + deserializer.ReadPropertyWithDefault>(201, "right", result->right); + deserializer.ReadPropertyWithDefault>(202, "condition", result->condition); + deserializer.ReadProperty(203, "join_type", result->type); + deserializer.ReadProperty(204, "ref_type", result->ref_type); + deserializer.ReadPropertyWithDefault>(205, "using_columns", result->using_columns); + return std::move(result); +} + +void PivotRef::Serialize(Serializer &serializer) const { + TableRef::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "source", source); + serializer.WritePropertyWithDefault>>(201, "aggregates", aggregates); + serializer.WritePropertyWithDefault>(202, "unpivot_names", unpivot_names); + serializer.WritePropertyWithDefault>(203, "pivots", pivots); + serializer.WritePropertyWithDefault>(204, "groups", groups); + serializer.WritePropertyWithDefault>(205, "column_name_alias", column_name_alias); + serializer.WritePropertyWithDefault(206, "include_nulls", include_nulls); +} + +unique_ptr PivotRef::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new PivotRef()); + deserializer.ReadPropertyWithDefault>(200, "source", result->source); + deserializer.ReadPropertyWithDefault>>(201, "aggregates", result->aggregates); + deserializer.ReadPropertyWithDefault>(202, "unpivot_names", result->unpivot_names); + deserializer.ReadPropertyWithDefault>(203, "pivots", result->pivots); + deserializer.ReadPropertyWithDefault>(204, "groups", result->groups); + deserializer.ReadPropertyWithDefault>(205, "column_name_alias", result->column_name_alias); + deserializer.ReadPropertyWithDefault(206, "include_nulls", result->include_nulls); + return std::move(result); +} + +void SubqueryRef::Serialize(Serializer &serializer) const { + TableRef::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "subquery", subquery); + serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); +} + +unique_ptr SubqueryRef::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SubqueryRef()); + deserializer.ReadPropertyWithDefault>(200, "subquery", result->subquery); + deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); + return std::move(result); +} + +void TableFunctionRef::Serialize(Serializer &serializer) const { + TableRef::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "function", function); + serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); +} + +unique_ptr TableFunctionRef::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new TableFunctionRef()); + deserializer.ReadPropertyWithDefault>(200, "function", result->function); + deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_types.cpp b/src/duckdb/src/storage/serialization/serialize_types.cpp new file mode 100644 index 00000000..13b5b078 --- /dev/null +++ b/src/duckdb/src/storage/serialization/serialize_types.cpp @@ -0,0 +1,127 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/extra_type_info.hpp" + +namespace duckdb { + +void ExtraTypeInfo::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "type", type); + serializer.WritePropertyWithDefault(101, "alias", alias); +} + +shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) { + auto type = deserializer.ReadProperty(100, "type"); + auto alias = deserializer.ReadPropertyWithDefault(101, "alias"); + shared_ptr result; + switch (type) { + case ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO: + result = AggregateStateTypeInfo::Deserialize(deserializer); + break; + case ExtraTypeInfoType::DECIMAL_TYPE_INFO: + result = DecimalTypeInfo::Deserialize(deserializer); + break; + case ExtraTypeInfoType::ENUM_TYPE_INFO: + result = EnumTypeInfo::Deserialize(deserializer); + break; + case ExtraTypeInfoType::GENERIC_TYPE_INFO: + result = make_shared(type); + break; + case ExtraTypeInfoType::INVALID_TYPE_INFO: + return nullptr; + case ExtraTypeInfoType::LIST_TYPE_INFO: + result = ListTypeInfo::Deserialize(deserializer); + break; + case ExtraTypeInfoType::STRING_TYPE_INFO: + result = StringTypeInfo::Deserialize(deserializer); + break; + case ExtraTypeInfoType::STRUCT_TYPE_INFO: + result = StructTypeInfo::Deserialize(deserializer); + break; + case ExtraTypeInfoType::USER_TYPE_INFO: + result = UserTypeInfo::Deserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of ExtraTypeInfo!"); + } + result->alias = std::move(alias); + return result; +} + +void AggregateStateTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "function_name", state_type.function_name); + serializer.WriteProperty(201, "return_type", state_type.return_type); + serializer.WritePropertyWithDefault>(202, "bound_argument_types", state_type.bound_argument_types); +} + +shared_ptr AggregateStateTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new AggregateStateTypeInfo()); + deserializer.ReadPropertyWithDefault(200, "function_name", result->state_type.function_name); + deserializer.ReadProperty(201, "return_type", result->state_type.return_type); + deserializer.ReadPropertyWithDefault>(202, "bound_argument_types", result->state_type.bound_argument_types); + return std::move(result); +} + +void DecimalTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "width", width); + serializer.WritePropertyWithDefault(201, "scale", scale); +} + +shared_ptr DecimalTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new DecimalTypeInfo()); + deserializer.ReadPropertyWithDefault(200, "width", result->width); + deserializer.ReadPropertyWithDefault(201, "scale", result->scale); + return std::move(result); +} + +void ListTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + serializer.WriteProperty(200, "child_type", child_type); +} + +shared_ptr ListTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new ListTypeInfo()); + deserializer.ReadProperty(200, "child_type", result->child_type); + return std::move(result); +} + +void StringTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "collation", collation); +} + +shared_ptr StringTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new StringTypeInfo()); + deserializer.ReadPropertyWithDefault(200, "collation", result->collation); + return std::move(result); +} + +void StructTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "child_types", child_types); +} + +shared_ptr StructTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new StructTypeInfo()); + deserializer.ReadPropertyWithDefault>(200, "child_types", result->child_types); + return std::move(result); +} + +void UserTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "user_type_name", user_type_name); +} + +shared_ptr UserTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new UserTypeInfo()); + deserializer.ReadPropertyWithDefault(200, "user_type_name", result->user_type_name); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/single_file_block_manager.cpp b/src/duckdb/src/storage/single_file_block_manager.cpp new file mode 100644 index 00000000..2a0d72d7 --- /dev/null +++ b/src/duckdb/src/storage/single_file_block_manager.cpp @@ -0,0 +1,513 @@ +#include "duckdb/storage/single_file_block_manager.hpp" + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/checksum.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/storage/metadata/metadata_writer.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/database.hpp" + +#include +#include + +namespace duckdb { + +const char MainHeader::MAGIC_BYTES[] = "DUCK"; + +void SerializeVersionNumber(WriteStream &ser, const string &version_str) { + constexpr const idx_t MAX_VERSION_SIZE = 32; + data_t version[MAX_VERSION_SIZE]; + memset(version, 0, MAX_VERSION_SIZE); + memcpy(version, version_str.c_str(), MinValue(version_str.size(), MAX_VERSION_SIZE)); + ser.WriteData(version, MAX_VERSION_SIZE); +} + +void MainHeader::Write(WriteStream &ser) { + ser.WriteData(const_data_ptr_cast(MAGIC_BYTES), MAGIC_BYTE_SIZE); + ser.Write(version_number); + for (idx_t i = 0; i < FLAG_COUNT; i++) { + ser.Write(flags[i]); + } + SerializeVersionNumber(ser, DuckDB::LibraryVersion()); + SerializeVersionNumber(ser, DuckDB::SourceID()); +} + +void MainHeader::CheckMagicBytes(FileHandle &handle) { + data_t magic_bytes[MAGIC_BYTE_SIZE]; + if (handle.GetFileSize() < MainHeader::MAGIC_BYTE_SIZE + MainHeader::MAGIC_BYTE_OFFSET) { + throw IOException("The file \"%s\" exists, but it is not a valid DuckDB database file!", handle.path); + } + handle.Read(magic_bytes, MainHeader::MAGIC_BYTE_SIZE, MainHeader::MAGIC_BYTE_OFFSET); + if (memcmp(magic_bytes, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) != 0) { + throw IOException("The file \"%s\" exists, but it is not a valid DuckDB database file!", handle.path); + } +} + +MainHeader MainHeader::Read(ReadStream &source) { + data_t magic_bytes[MAGIC_BYTE_SIZE]; + MainHeader header; + source.ReadData(magic_bytes, MainHeader::MAGIC_BYTE_SIZE); + if (memcmp(magic_bytes, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) != 0) { + throw IOException("The file is not a valid DuckDB database file!"); + } + header.version_number = source.Read(); + // check the version number + if (header.version_number != VERSION_NUMBER) { + auto version = GetDuckDBVersion(header.version_number); + string version_text; + if (version) { + // known version + version_text = "DuckDB version " + string(version); + } else { + version_text = string("an ") + (VERSION_NUMBER > header.version_number ? "older development" : "newer") + + string(" version of DuckDB"); + } + throw IOException( + "Trying to read a database file with version number %lld, but we can only read version %lld.\n" + "The database file was created with %s.\n\n" + "The storage of DuckDB is not yet stable; newer versions of DuckDB cannot read old database files and " + "vice versa.\n" + "The storage will be stabilized when version 1.0 releases.\n\n" + "For now, we recommend that you load the database file in a supported version of DuckDB, and use the " + "EXPORT DATABASE command " + "followed by IMPORT DATABASE on the current version of DuckDB.\n\n" + "See the storage page for more information: https://duckdb.org/internals/storage", + header.version_number, VERSION_NUMBER, version_text); + } + // read the flags + for (idx_t i = 0; i < FLAG_COUNT; i++) { + header.flags[i] = source.Read(); + } + return header; +} + +void DatabaseHeader::Write(WriteStream &ser) { + ser.Write(iteration); + ser.Write(meta_block); + ser.Write(free_list); + ser.Write(block_count); +} + +DatabaseHeader DatabaseHeader::Read(ReadStream &source) { + DatabaseHeader header; + header.iteration = source.Read(); + header.meta_block = source.Read(); + header.free_list = source.Read(); + header.block_count = source.Read(); + return header; +} + +template +void SerializeHeaderStructure(T header, data_ptr_t ptr) { + MemoryStream ser(ptr, Storage::FILE_HEADER_SIZE); + header.Write(ser); +} + +template +T DeserializeHeaderStructure(data_ptr_t ptr) { + MemoryStream source(ptr, Storage::FILE_HEADER_SIZE); + return T::Read(source); +} + +SingleFileBlockManager::SingleFileBlockManager(AttachedDatabase &db, string path_p, StorageManagerOptions options) + : BlockManager(BufferManager::GetBufferManager(db)), db(db), path(std::move(path_p)), + header_buffer(Allocator::Get(db), FileBufferType::MANAGED_BUFFER, + Storage::FILE_HEADER_SIZE - Storage::BLOCK_HEADER_SIZE), + iteration_count(0), options(options) { +} + +void SingleFileBlockManager::GetFileFlags(uint8_t &flags, FileLockType &lock, bool create_new) { + if (options.read_only) { + D_ASSERT(!create_new); + flags = FileFlags::FILE_FLAGS_READ; + lock = FileLockType::READ_LOCK; + } else { + flags = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_READ; + lock = FileLockType::WRITE_LOCK; + if (create_new) { + flags |= FileFlags::FILE_FLAGS_FILE_CREATE; + } + } + if (options.use_direct_io) { + flags |= FileFlags::FILE_FLAGS_DIRECT_IO; + } +} + +void SingleFileBlockManager::CreateNewDatabase() { + uint8_t flags; + FileLockType lock; + GetFileFlags(flags, lock, true); + + // open the RDBMS handle + auto &fs = FileSystem::Get(db); + handle = fs.OpenFile(path, flags, lock); + + // if we create a new file, we fill the metadata of the file + // first fill in the new header + header_buffer.Clear(); + + MainHeader main_header; + main_header.version_number = VERSION_NUMBER; + memset(main_header.flags, 0, sizeof(uint64_t) * 4); + + SerializeHeaderStructure(main_header, header_buffer.buffer); + // now write the header to the file + ChecksumAndWrite(header_buffer, 0); + header_buffer.Clear(); + + // write the database headers + // initialize meta_block and free_list to INVALID_BLOCK because the database file does not contain any actual + // content yet + DatabaseHeader h1, h2; + // header 1 + h1.iteration = 0; + h1.meta_block = INVALID_BLOCK; + h1.free_list = INVALID_BLOCK; + h1.block_count = 0; + SerializeHeaderStructure(h1, header_buffer.buffer); + ChecksumAndWrite(header_buffer, Storage::FILE_HEADER_SIZE); + // header 2 + h2.iteration = 0; + h2.meta_block = INVALID_BLOCK; + h2.free_list = INVALID_BLOCK; + h2.block_count = 0; + SerializeHeaderStructure(h2, header_buffer.buffer); + ChecksumAndWrite(header_buffer, Storage::FILE_HEADER_SIZE * 2ULL); + // ensure that writing to disk is completed before returning + handle->Sync(); + // we start with h2 as active_header, this way our initial write will be in h1 + iteration_count = 0; + active_header = 1; + max_block = 0; +} + +void SingleFileBlockManager::LoadExistingDatabase() { + uint8_t flags; + FileLockType lock; + GetFileFlags(flags, lock, false); + + // open the RDBMS handle + auto &fs = FileSystem::Get(db); + handle = fs.OpenFile(path, flags, lock); + + MainHeader::CheckMagicBytes(*handle); + // otherwise, we check the metadata of the file + ReadAndChecksum(header_buffer, 0); + DeserializeHeaderStructure(header_buffer.buffer); + + // read the database headers from disk + DatabaseHeader h1, h2; + ReadAndChecksum(header_buffer, Storage::FILE_HEADER_SIZE); + h1 = DeserializeHeaderStructure(header_buffer.buffer); + ReadAndChecksum(header_buffer, Storage::FILE_HEADER_SIZE * 2ULL); + h2 = DeserializeHeaderStructure(header_buffer.buffer); + // check the header with the highest iteration count + if (h1.iteration > h2.iteration) { + // h1 is active header + active_header = 0; + Initialize(h1); + } else { + // h2 is active header + active_header = 1; + Initialize(h2); + } + LoadFreeList(); +} + +void SingleFileBlockManager::ReadAndChecksum(FileBuffer &block, uint64_t location) const { + // read the buffer from disk + block.Read(*handle, location); + // compute the checksum + auto stored_checksum = Load(block.InternalBuffer()); + uint64_t computed_checksum = Checksum(block.buffer, block.size); + // verify the checksum + if (stored_checksum != computed_checksum) { + throw IOException("Corrupt database file: computed checksum %llu does not match stored checksum %llu in block", + computed_checksum, stored_checksum); + } +} + +void SingleFileBlockManager::ChecksumAndWrite(FileBuffer &block, uint64_t location) const { + // compute the checksum and write it to the start of the buffer (if not temp buffer) + uint64_t checksum = Checksum(block.buffer, block.size); + Store(checksum, block.InternalBuffer()); + // now write the buffer + block.Write(*handle, location); +} + +void SingleFileBlockManager::Initialize(DatabaseHeader &header) { + free_list_id = header.free_list; + meta_block = header.meta_block; + iteration_count = header.iteration; + max_block = header.block_count; +} + +void SingleFileBlockManager::LoadFreeList() { + MetaBlockPointer free_pointer(free_list_id, 0); + if (!free_pointer.IsValid()) { + // no free list + return; + } + MetadataReader reader(GetMetadataManager(), free_pointer, nullptr, BlockReaderType::REGISTER_BLOCKS); + auto free_list_count = reader.Read(); + free_list.clear(); + for (idx_t i = 0; i < free_list_count; i++) { + free_list.insert(reader.Read()); + } + auto multi_use_blocks_count = reader.Read(); + multi_use_blocks.clear(); + for (idx_t i = 0; i < multi_use_blocks_count; i++) { + auto block_id = reader.Read(); + auto usage_count = reader.Read(); + multi_use_blocks[block_id] = usage_count; + } + GetMetadataManager().Read(reader); + GetMetadataManager().MarkBlocksAsModified(); +} + +bool SingleFileBlockManager::IsRootBlock(MetaBlockPointer root) { + return root.block_pointer == meta_block; +} + +block_id_t SingleFileBlockManager::GetFreeBlockId() { + lock_guard lock(block_lock); + block_id_t block; + if (!free_list.empty()) { + // free list is non empty + // take an entry from the free list + block = *free_list.begin(); + // erase the entry from the free list again + free_list.erase(free_list.begin()); + } else { + block = max_block++; + } + return block; +} + +void SingleFileBlockManager::MarkBlockAsFree(block_id_t block_id) { + lock_guard lock(block_lock); + D_ASSERT(block_id >= 0); + D_ASSERT(block_id < max_block); + if (free_list.find(block_id) != free_list.end()) { + throw InternalException("MarkBlockAsFree called but block %llu was already freed!", block_id); + } + multi_use_blocks.erase(block_id); + free_list.insert(block_id); +} + +void SingleFileBlockManager::MarkBlockAsModified(block_id_t block_id) { + lock_guard lock(block_lock); + D_ASSERT(block_id >= 0); + D_ASSERT(block_id < max_block); + + // check if the block is a multi-use block + auto entry = multi_use_blocks.find(block_id); + if (entry != multi_use_blocks.end()) { + // it is! reduce the reference count of the block + entry->second--; + // check the reference count: is the block still a multi-use block? + if (entry->second <= 1) { + // no longer a multi-use block! + multi_use_blocks.erase(entry); + } + return; + } + // Check for multi-free + // TODO: Fix the bug that causes this assert to fire, then uncomment it. + // D_ASSERT(modified_blocks.find(block_id) == modified_blocks.end()); + D_ASSERT(free_list.find(block_id) == free_list.end()); + modified_blocks.insert(block_id); +} + +void SingleFileBlockManager::IncreaseBlockReferenceCount(block_id_t block_id) { + lock_guard lock(block_lock); + D_ASSERT(block_id >= 0); + D_ASSERT(block_id < max_block); + D_ASSERT(free_list.find(block_id) == free_list.end()); + auto entry = multi_use_blocks.find(block_id); + if (entry != multi_use_blocks.end()) { + entry->second++; + } else { + multi_use_blocks[block_id] = 2; + } +} + +idx_t SingleFileBlockManager::GetMetaBlock() { + return meta_block; +} + +idx_t SingleFileBlockManager::TotalBlocks() { + lock_guard lock(block_lock); + return max_block; +} + +idx_t SingleFileBlockManager::FreeBlocks() { + lock_guard lock(block_lock); + return free_list.size(); +} + +unique_ptr SingleFileBlockManager::ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) { + D_ASSERT(source_buffer.AllocSize() == Storage::BLOCK_ALLOC_SIZE); + return make_uniq(source_buffer, block_id); +} + +unique_ptr SingleFileBlockManager::CreateBlock(block_id_t block_id, FileBuffer *source_buffer) { + unique_ptr result; + if (source_buffer) { + result = ConvertBlock(block_id, *source_buffer); + } else { + result = make_uniq(Allocator::Get(db), block_id); + } + result->Initialize(options.debug_initialize); + return result; +} + +void SingleFileBlockManager::Read(Block &block) { + D_ASSERT(block.id >= 0); + D_ASSERT(std::find(free_list.begin(), free_list.end(), block.id) == free_list.end()); + ReadAndChecksum(block, BLOCK_START + block.id * Storage::BLOCK_ALLOC_SIZE); +} + +void SingleFileBlockManager::Write(FileBuffer &buffer, block_id_t block_id) { + D_ASSERT(block_id >= 0); + ChecksumAndWrite(buffer, BLOCK_START + block_id * Storage::BLOCK_ALLOC_SIZE); +} + +void SingleFileBlockManager::Truncate() { + BlockManager::Truncate(); + idx_t blocks_to_truncate = 0; + // reverse iterate over the free-list + for (auto entry = free_list.rbegin(); entry != free_list.rend(); entry++) { + auto block_id = *entry; + if (block_id + 1 != max_block) { + break; + } + blocks_to_truncate++; + max_block--; + } + if (blocks_to_truncate == 0) { + // nothing to truncate + return; + } + // truncate the file + for (idx_t i = 0; i < blocks_to_truncate; i++) { + free_list.erase(max_block + i); + } + handle->Truncate(BLOCK_START + max_block * Storage::BLOCK_ALLOC_SIZE); +} + +vector SingleFileBlockManager::GetFreeListBlocks() { + vector free_list_blocks; + + // reserve all blocks that we are going to write the free list to + // since these blocks are no longer free we cannot just include them in the free list! + auto block_size = MetadataManager::METADATA_BLOCK_SIZE - sizeof(idx_t); + idx_t allocated_size = 0; + while (true) { + auto free_list_size = sizeof(uint64_t) + sizeof(block_id_t) * (free_list.size() + modified_blocks.size()); + auto multi_use_blocks_size = + sizeof(uint64_t) + (sizeof(block_id_t) + sizeof(uint32_t)) * multi_use_blocks.size(); + auto metadata_blocks = + sizeof(uint64_t) + (sizeof(block_id_t) + sizeof(idx_t)) * GetMetadataManager().BlockCount(); + auto total_size = free_list_size + multi_use_blocks_size + metadata_blocks; + if (total_size < allocated_size) { + break; + } + auto free_list_handle = GetMetadataManager().AllocateHandle(); + free_list_blocks.push_back(std::move(free_list_handle)); + allocated_size += block_size; + } + + return free_list_blocks; +} + +class FreeListBlockWriter : public MetadataWriter { +public: + FreeListBlockWriter(MetadataManager &manager, vector free_list_blocks_p) + : MetadataWriter(manager), free_list_blocks(std::move(free_list_blocks_p)), index(0) { + } + + vector free_list_blocks; + idx_t index; + +protected: + MetadataHandle NextHandle() override { + if (index >= free_list_blocks.size()) { + throw InternalException( + "Free List Block Writer ran out of blocks, this means not enough blocks were allocated up front"); + } + return std::move(free_list_blocks[index++]); + } +}; + +void SingleFileBlockManager::WriteHeader(DatabaseHeader header) { + // set the iteration count + header.iteration = ++iteration_count; + + auto free_list_blocks = GetFreeListBlocks(); + + // now handle the free list + auto &metadata_manager = GetMetadataManager(); + // add all modified blocks to the free list: they can now be written to again + metadata_manager.MarkBlocksAsModified(); + for (auto &block : modified_blocks) { + free_list.insert(block); + } + modified_blocks.clear(); + + if (!free_list_blocks.empty()) { + // there are blocks to write, either in the free_list or in the modified_blocks + // we write these blocks specifically to the free_list_blocks + // a normal MetadataWriter will fetch blocks to use from the free_list + // but since we are WRITING the free_list, this behavior is sub-optimal + FreeListBlockWriter writer(metadata_manager, std::move(free_list_blocks)); + + auto ptr = writer.GetMetaBlockPointer(); + header.free_list = ptr.block_pointer; + + writer.Write(free_list.size()); + for (auto &block_id : free_list) { + writer.Write(block_id); + } + writer.Write(multi_use_blocks.size()); + for (auto &entry : multi_use_blocks) { + writer.Write(entry.first); + writer.Write(entry.second); + } + GetMetadataManager().Write(writer); + writer.Flush(); + } else { + // no blocks in the free list + header.free_list = DConstants::INVALID_INDEX; + } + metadata_manager.Flush(); + header.block_count = max_block; + + auto &config = DBConfig::Get(db); + if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE) { + throw FatalException("Checkpoint aborted after free list write because of PRAGMA checkpoint_abort flag"); + } + + if (!options.use_direct_io) { + // if we are not using Direct IO we need to fsync BEFORE we write the header to ensure that all the previous + // blocks are written as well + handle->Sync(); + } + // set the header inside the buffer + header_buffer.Clear(); + MemoryStream serializer; + header.Write(serializer); + memcpy(header_buffer.buffer, serializer.GetData(), serializer.GetPosition()); + // now write the header to the file, active_header determines whether we write to h1 or h2 + // note that if active_header is h1 we write to h2, and vice versa + ChecksumAndWrite(header_buffer, active_header == 1 ? Storage::FILE_HEADER_SIZE : Storage::FILE_HEADER_SIZE * 2); + // switch active header to the other header + active_header = 1 - active_header; + //! Ensure the header write ends up on disk + handle->Sync(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/standard_buffer_manager.cpp b/src/duckdb/src/storage/standard_buffer_manager.cpp new file mode 100644 index 00000000..87505a80 --- /dev/null +++ b/src/duckdb/src/storage/standard_buffer_manager.cpp @@ -0,0 +1,794 @@ +#include "duckdb/storage/standard_buffer_manager.hpp" + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/storage/buffer/buffer_pool.hpp" +#include "duckdb/storage/in_memory_block_manager.hpp" +#include "duckdb/storage/storage_manager.hpp" + +namespace duckdb { + +struct BufferAllocatorData : PrivateAllocatorData { + explicit BufferAllocatorData(StandardBufferManager &manager) : manager(manager) { + } + + StandardBufferManager &manager; +}; + +unique_ptr StandardBufferManager::ConstructManagedBuffer(idx_t size, unique_ptr &&source, + FileBufferType type) { + unique_ptr result; + if (source) { + auto tmp = std::move(source); + D_ASSERT(tmp->AllocSize() == BufferManager::GetAllocSize(size)); + result = make_uniq(*tmp, type); + } else { + // no re-usable buffer: allocate a new buffer + result = make_uniq(Allocator::Get(db), type, size); + } + result->Initialize(DBConfig::GetConfig(db).options.debug_initialize); + return result; +} + +class TemporaryFileManager; + +class TemporaryDirectoryHandle { +public: + TemporaryDirectoryHandle(DatabaseInstance &db, string path_p); + ~TemporaryDirectoryHandle(); + + TemporaryFileManager &GetTempFile(); + +private: + DatabaseInstance &db; + string temp_directory; + bool created_directory = false; + unique_ptr temp_file; +}; + +void StandardBufferManager::SetTemporaryDirectory(const string &new_dir) { + if (temp_directory_handle) { + throw NotImplementedException("Cannot switch temporary directory after the current one has been used"); + } + this->temp_directory = new_dir; +} + +StandardBufferManager::StandardBufferManager(DatabaseInstance &db, string tmp) + : BufferManager(), db(db), buffer_pool(db.GetBufferPool()), temp_directory(std::move(tmp)), + temporary_id(MAXIMUM_BLOCK), buffer_allocator(BufferAllocatorAllocate, BufferAllocatorFree, + BufferAllocatorRealloc, make_uniq(*this)) { + temp_block_manager = make_uniq(*this); +} + +StandardBufferManager::~StandardBufferManager() { +} + +BufferPool &StandardBufferManager::GetBufferPool() { + return buffer_pool; +} + +idx_t StandardBufferManager::GetUsedMemory() const { + return buffer_pool.GetUsedMemory(); +} +idx_t StandardBufferManager::GetMaxMemory() const { + return buffer_pool.GetMaxMemory(); +} + +template +TempBufferPoolReservation StandardBufferManager::EvictBlocksOrThrow(idx_t memory_delta, unique_ptr *buffer, + ARGS... args) { + auto r = buffer_pool.EvictBlocks(memory_delta, buffer_pool.maximum_memory, buffer); + if (!r.success) { + string extra_text = StringUtil::Format(" (%s/%s used)", StringUtil::BytesToHumanReadableString(GetUsedMemory()), + StringUtil::BytesToHumanReadableString(GetMaxMemory())); + extra_text += InMemoryWarning(); + throw OutOfMemoryException(args..., extra_text); + } + return std::move(r.reservation); +} + +shared_ptr StandardBufferManager::RegisterSmallMemory(idx_t block_size) { + D_ASSERT(block_size < Storage::BLOCK_SIZE); + auto res = EvictBlocksOrThrow(block_size, nullptr, "could not allocate block of size %s%s", + StringUtil::BytesToHumanReadableString(block_size)); + + auto buffer = ConstructManagedBuffer(block_size, nullptr, FileBufferType::TINY_BUFFER); + + // create a new block pointer for this block + return make_shared(*temp_block_manager, ++temporary_id, std::move(buffer), false, block_size, + std::move(res)); +} + +shared_ptr StandardBufferManager::RegisterMemory(idx_t block_size, bool can_destroy) { + D_ASSERT(block_size >= Storage::BLOCK_SIZE); + auto alloc_size = GetAllocSize(block_size); + // first evict blocks until we have enough memory to store this buffer + unique_ptr reusable_buffer; + auto res = EvictBlocksOrThrow(alloc_size, &reusable_buffer, "could not allocate block of size %s%s", + StringUtil::BytesToHumanReadableString(alloc_size)); + + auto buffer = ConstructManagedBuffer(block_size, std::move(reusable_buffer)); + + // create a new block pointer for this block + return make_shared(*temp_block_manager, ++temporary_id, std::move(buffer), can_destroy, alloc_size, + std::move(res)); +} + +BufferHandle StandardBufferManager::Allocate(idx_t block_size, bool can_destroy, shared_ptr *block) { + shared_ptr local_block; + auto block_ptr = block ? block : &local_block; + *block_ptr = RegisterMemory(block_size, can_destroy); + return Pin(*block_ptr); +} + +void StandardBufferManager::ReAllocate(shared_ptr &handle, idx_t block_size) { + D_ASSERT(block_size >= Storage::BLOCK_SIZE); + lock_guard lock(handle->lock); + D_ASSERT(handle->state == BlockState::BLOCK_LOADED); + D_ASSERT(handle->memory_usage == handle->buffer->AllocSize()); + D_ASSERT(handle->memory_usage == handle->memory_charge.size); + + auto req = handle->buffer->CalculateMemory(block_size); + int64_t memory_delta = (int64_t)req.alloc_size - handle->memory_usage; + + if (memory_delta == 0) { + return; + } else if (memory_delta > 0) { + // evict blocks until we have space to resize this block + auto reservation = EvictBlocksOrThrow(memory_delta, nullptr, "failed to resize block from %s to %s%s", + StringUtil::BytesToHumanReadableString(handle->memory_usage), + StringUtil::BytesToHumanReadableString(req.alloc_size)); + // EvictBlocks decrements 'current_memory' for us. + handle->memory_charge.Merge(std::move(reservation)); + } else { + // no need to evict blocks, but we do need to decrement 'current_memory'. + handle->memory_charge.Resize(req.alloc_size); + } + + handle->ResizeBuffer(block_size, memory_delta); +} + +BufferHandle StandardBufferManager::Pin(shared_ptr &handle) { + idx_t required_memory; + { + // lock the block + lock_guard lock(handle->lock); + // check if the block is already loaded + if (handle->state == BlockState::BLOCK_LOADED) { + // the block is loaded, increment the reader count and return a pointer to the handle + handle->readers++; + return handle->Load(handle); + } + required_memory = handle->memory_usage; + } + // evict blocks until we have space for the current block + unique_ptr reusable_buffer; + auto reservation = EvictBlocksOrThrow(required_memory, &reusable_buffer, "failed to pin block of size %s%s", + StringUtil::BytesToHumanReadableString(required_memory)); + // lock the handle again and repeat the check (in case anybody loaded in the mean time) + lock_guard lock(handle->lock); + // check if the block is already loaded + if (handle->state == BlockState::BLOCK_LOADED) { + // the block is loaded, increment the reader count and return a pointer to the handle + handle->readers++; + reservation.Resize(0); + return handle->Load(handle); + } + // now we can actually load the current block + D_ASSERT(handle->readers == 0); + handle->readers = 1; + auto buf = handle->Load(handle, std::move(reusable_buffer)); + handle->memory_charge = std::move(reservation); + // In the case of a variable sized block, the buffer may be smaller than a full block. + int64_t delta = handle->buffer->AllocSize() - handle->memory_usage; + if (delta) { + D_ASSERT(delta < 0); + handle->memory_usage += delta; + handle->memory_charge.Resize(handle->memory_usage); + } + D_ASSERT(handle->memory_usage == handle->buffer->AllocSize()); + return buf; +} + +void StandardBufferManager::PurgeQueue() { + buffer_pool.PurgeQueue(); +} + +void StandardBufferManager::AddToEvictionQueue(shared_ptr &handle) { + buffer_pool.AddToEvictionQueue(handle); +} + +void StandardBufferManager::VerifyZeroReaders(shared_ptr &handle) { +#ifdef DUCKDB_DEBUG_DESTROY_BLOCKS + auto replacement_buffer = make_uniq(Allocator::Get(db), handle->buffer->type, + handle->memory_usage - Storage::BLOCK_HEADER_SIZE); + memcpy(replacement_buffer->buffer, handle->buffer->buffer, handle->buffer->size); + memset(handle->buffer->buffer, 0xa5, handle->buffer->size); // 0xa5 is default memory in debug mode + handle->buffer = std::move(replacement_buffer); +#endif +} + +void StandardBufferManager::Unpin(shared_ptr &handle) { + lock_guard lock(handle->lock); + if (!handle->buffer || handle->buffer->type == FileBufferType::TINY_BUFFER) { + return; + } + D_ASSERT(handle->readers > 0); + handle->readers--; + if (handle->readers == 0) { + VerifyZeroReaders(handle); + buffer_pool.AddToEvictionQueue(handle); + } +} + +void StandardBufferManager::SetLimit(idx_t limit) { + buffer_pool.SetLimit(limit, InMemoryWarning()); +} + +//===--------------------------------------------------------------------===// +// Temporary File Management +//===--------------------------------------------------------------------===// +unique_ptr ReadTemporaryBufferInternal(BufferManager &buffer_manager, FileHandle &handle, idx_t position, + idx_t size, block_id_t id, unique_ptr reusable_buffer) { + auto buffer = buffer_manager.ConstructManagedBuffer(size, std::move(reusable_buffer)); + buffer->Read(handle, position); + return buffer; +} + +struct TemporaryFileIndex { + explicit TemporaryFileIndex(idx_t file_index = DConstants::INVALID_INDEX, + idx_t block_index = DConstants::INVALID_INDEX) + : file_index(file_index), block_index(block_index) { + } + + idx_t file_index; + idx_t block_index; + +public: + bool IsValid() { + return block_index != DConstants::INVALID_INDEX; + } +}; + +struct BlockIndexManager { + BlockIndexManager() : max_index(0) { + } + +public: + //! Obtains a new block index from the index manager + idx_t GetNewBlockIndex() { + auto index = GetNewBlockIndexInternal(); + indexes_in_use.insert(index); + return index; + } + + //! Removes an index from the block manager + //! Returns true if the max_index has been altered + bool RemoveIndex(idx_t index) { + // remove this block from the set of blocks + auto entry = indexes_in_use.find(index); + if (entry == indexes_in_use.end()) { + throw InternalException("RemoveIndex - index %llu not found in indexes_in_use", index); + } + indexes_in_use.erase(entry); + free_indexes.insert(index); + // check if we can truncate the file + + // get the max_index in use right now + auto max_index_in_use = indexes_in_use.empty() ? 0 : *indexes_in_use.rbegin(); + if (max_index_in_use < max_index) { + // max index in use is lower than the max_index + // reduce the max_index + max_index = indexes_in_use.empty() ? 0 : max_index_in_use + 1; + // we can remove any free_indexes that are larger than the current max_index + while (!free_indexes.empty()) { + auto max_entry = *free_indexes.rbegin(); + if (max_entry < max_index) { + break; + } + free_indexes.erase(max_entry); + } + return true; + } + return false; + } + + idx_t GetMaxIndex() { + return max_index; + } + + bool HasFreeBlocks() { + return !free_indexes.empty(); + } + +private: + idx_t GetNewBlockIndexInternal() { + if (free_indexes.empty()) { + return max_index++; + } + auto entry = free_indexes.begin(); + auto index = *entry; + free_indexes.erase(entry); + return index; + } + + idx_t max_index; + set free_indexes; + set indexes_in_use; +}; + +class TemporaryFileHandle { + constexpr static idx_t MAX_ALLOWED_INDEX_BASE = 4000; + +public: + TemporaryFileHandle(idx_t temp_file_count, DatabaseInstance &db, const string &temp_directory, idx_t index) + : max_allowed_index((1 << temp_file_count) * MAX_ALLOWED_INDEX_BASE), db(db), file_index(index), + path(FileSystem::GetFileSystem(db).JoinPath(temp_directory, + "duckdb_temp_storage-" + to_string(index) + ".tmp")) { + } + +public: + struct TemporaryFileLock { + explicit TemporaryFileLock(mutex &mutex) : lock(mutex) { + } + + lock_guard lock; + }; + +public: + TemporaryFileIndex TryGetBlockIndex() { + TemporaryFileLock lock(file_lock); + if (index_manager.GetMaxIndex() >= max_allowed_index && index_manager.HasFreeBlocks()) { + // file is at capacity + return TemporaryFileIndex(); + } + // open the file handle if it does not yet exist + CreateFileIfNotExists(lock); + // fetch a new block index to write to + auto block_index = index_manager.GetNewBlockIndex(); + return TemporaryFileIndex(file_index, block_index); + } + + void WriteTemporaryFile(FileBuffer &buffer, TemporaryFileIndex index) { + D_ASSERT(buffer.size == Storage::BLOCK_SIZE); + buffer.Write(*handle, GetPositionInFile(index.block_index)); + } + + unique_ptr ReadTemporaryBuffer(block_id_t id, idx_t block_index, + unique_ptr reusable_buffer) { + return ReadTemporaryBufferInternal(BufferManager::GetBufferManager(db), *handle, GetPositionInFile(block_index), + Storage::BLOCK_SIZE, id, std::move(reusable_buffer)); + } + + void EraseBlockIndex(block_id_t block_index) { + // remove the block (and potentially truncate the temp file) + TemporaryFileLock lock(file_lock); + D_ASSERT(handle); + RemoveTempBlockIndex(lock, block_index); + } + + bool DeleteIfEmpty() { + TemporaryFileLock lock(file_lock); + if (index_manager.GetMaxIndex() > 0) { + // there are still blocks in this file + return false; + } + // the file is empty: delete it + handle.reset(); + auto &fs = FileSystem::GetFileSystem(db); + fs.RemoveFile(path); + return true; + } + + TemporaryFileInformation GetTemporaryFile() { + TemporaryFileLock lock(file_lock); + TemporaryFileInformation info; + info.path = path; + info.size = GetPositionInFile(index_manager.GetMaxIndex()); + return info; + } + +private: + void CreateFileIfNotExists(TemporaryFileLock &) { + if (handle) { + return; + } + auto &fs = FileSystem::GetFileSystem(db); + handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_WRITE | + FileFlags::FILE_FLAGS_FILE_CREATE); + } + + void RemoveTempBlockIndex(TemporaryFileLock &, idx_t index) { + // remove the block index from the index manager + if (index_manager.RemoveIndex(index)) { + // the max_index that is currently in use has decreased + // as a result we can truncate the file +#ifndef WIN32 // this ended up causing issues when sorting + auto max_index = index_manager.GetMaxIndex(); + auto &fs = FileSystem::GetFileSystem(db); + fs.Truncate(*handle, GetPositionInFile(max_index + 1)); +#endif + } + } + + idx_t GetPositionInFile(idx_t index) { + return index * Storage::BLOCK_ALLOC_SIZE; + } + +private: + const idx_t max_allowed_index; + DatabaseInstance &db; + unique_ptr handle; + idx_t file_index; + string path; + mutex file_lock; + BlockIndexManager index_manager; +}; + +class TemporaryFileManager { +public: + TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p) + : db(db), temp_directory(temp_directory_p) { + } + +public: + struct TemporaryManagerLock { + explicit TemporaryManagerLock(mutex &mutex) : lock(mutex) { + } + + lock_guard lock; + }; + + void WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) { + D_ASSERT(buffer.size == Storage::BLOCK_SIZE); + TemporaryFileIndex index; + TemporaryFileHandle *handle = nullptr; + + { + TemporaryManagerLock lock(manager_lock); + // first check if we can write to an open existing file + for (auto &entry : files) { + auto &temp_file = entry.second; + index = temp_file->TryGetBlockIndex(); + if (index.IsValid()) { + handle = entry.second.get(); + break; + } + } + if (!handle) { + // no existing handle to write to; we need to create & open a new file + auto new_file_index = index_manager.GetNewBlockIndex(); + auto new_file = make_uniq(files.size(), db, temp_directory, new_file_index); + handle = new_file.get(); + files[new_file_index] = std::move(new_file); + + index = handle->TryGetBlockIndex(); + } + D_ASSERT(used_blocks.find(block_id) == used_blocks.end()); + used_blocks[block_id] = index; + } + D_ASSERT(handle); + D_ASSERT(index.IsValid()); + handle->WriteTemporaryFile(buffer, index); + } + + bool HasTemporaryBuffer(block_id_t block_id) { + lock_guard lock(manager_lock); + return used_blocks.find(block_id) != used_blocks.end(); + } + + unique_ptr ReadTemporaryBuffer(block_id_t id, unique_ptr reusable_buffer) { + TemporaryFileIndex index; + TemporaryFileHandle *handle; + { + TemporaryManagerLock lock(manager_lock); + index = GetTempBlockIndex(lock, id); + handle = GetFileHandle(lock, index.file_index); + } + auto buffer = handle->ReadTemporaryBuffer(id, index.block_index, std::move(reusable_buffer)); + { + // remove the block (and potentially erase the temp file) + TemporaryManagerLock lock(manager_lock); + EraseUsedBlock(lock, id, handle, index); + } + return buffer; + } + + void DeleteTemporaryBuffer(block_id_t id) { + TemporaryManagerLock lock(manager_lock); + auto index = GetTempBlockIndex(lock, id); + auto handle = GetFileHandle(lock, index.file_index); + EraseUsedBlock(lock, id, handle, index); + } + + vector GetTemporaryFiles() { + lock_guard lock(manager_lock); + vector result; + for (auto &file : files) { + result.push_back(file.second->GetTemporaryFile()); + } + return result; + } + +private: + void EraseUsedBlock(TemporaryManagerLock &lock, block_id_t id, TemporaryFileHandle *handle, + TemporaryFileIndex index) { + auto entry = used_blocks.find(id); + if (entry == used_blocks.end()) { + throw InternalException("EraseUsedBlock - Block %llu not found in used blocks", id); + } + used_blocks.erase(entry); + handle->EraseBlockIndex(index.block_index); + if (handle->DeleteIfEmpty()) { + EraseFileHandle(lock, index.file_index); + } + } + + TemporaryFileHandle *GetFileHandle(TemporaryManagerLock &, idx_t index) { + return files[index].get(); + } + + TemporaryFileIndex GetTempBlockIndex(TemporaryManagerLock &, block_id_t id) { + D_ASSERT(used_blocks.find(id) != used_blocks.end()); + return used_blocks[id]; + } + + void EraseFileHandle(TemporaryManagerLock &, idx_t file_index) { + files.erase(file_index); + index_manager.RemoveIndex(file_index); + } + +private: + DatabaseInstance &db; + mutex manager_lock; + //! The temporary directory + string temp_directory; + //! The set of active temporary file handles + unordered_map> files; + //! map of block_id -> temporary file position + unordered_map used_blocks; + //! Manager of in-use temporary file indexes + BlockIndexManager index_manager; +}; + +TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string path_p) + : db(db), temp_directory(std::move(path_p)), temp_file(make_uniq(db, temp_directory)) { + auto &fs = FileSystem::GetFileSystem(db); + if (!temp_directory.empty()) { + if (!fs.DirectoryExists(temp_directory)) { + fs.CreateDirectory(temp_directory); + created_directory = true; + } + } +} +TemporaryDirectoryHandle::~TemporaryDirectoryHandle() { + // first release any temporary files + temp_file.reset(); + // then delete the temporary file directory + auto &fs = FileSystem::GetFileSystem(db); + if (!temp_directory.empty()) { + bool delete_directory = created_directory; + vector files_to_delete; + if (!created_directory) { + bool deleted_everything = true; + fs.ListFiles(temp_directory, [&](const string &path, bool isdir) { + if (isdir) { + deleted_everything = false; + return; + } + if (!StringUtil::StartsWith(path, "duckdb_temp_")) { + deleted_everything = false; + return; + } + files_to_delete.push_back(path); + }); + } + if (delete_directory) { + // we want to remove all files in the directory + fs.RemoveDirectory(temp_directory); + } else { + for (auto &file : files_to_delete) { + fs.RemoveFile(fs.JoinPath(temp_directory, file)); + } + } + } +} + +TemporaryFileManager &TemporaryDirectoryHandle::GetTempFile() { + return *temp_file; +} + +string StandardBufferManager::GetTemporaryPath(block_id_t id) { + auto &fs = FileSystem::GetFileSystem(db); + return fs.JoinPath(temp_directory, "duckdb_temp_block-" + to_string(id) + ".block"); +} + +void StandardBufferManager::RequireTemporaryDirectory() { + if (temp_directory.empty()) { + throw Exception( + "Out-of-memory: cannot write buffer because no temporary directory is specified!\nTo enable " + "temporary buffer eviction set a temporary directory using PRAGMA temp_directory='/path/to/tmp.tmp'"); + } + lock_guard temp_handle_guard(temp_handle_lock); + if (!temp_directory_handle) { + // temp directory has not been created yet: initialize it + temp_directory_handle = make_uniq(db, temp_directory); + } +} + +void StandardBufferManager::WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) { + RequireTemporaryDirectory(); + if (buffer.size == Storage::BLOCK_SIZE) { + temp_directory_handle->GetTempFile().WriteTemporaryBuffer(block_id, buffer); + return; + } + // get the path to write to + auto path = GetTemporaryPath(block_id); + D_ASSERT(buffer.size > Storage::BLOCK_SIZE); + // create the file and write the size followed by the buffer contents + auto &fs = FileSystem::GetFileSystem(db); + auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE); + handle->Write(&buffer.size, sizeof(idx_t), 0); + buffer.Write(*handle, sizeof(idx_t)); +} + +unique_ptr StandardBufferManager::ReadTemporaryBuffer(block_id_t id, + unique_ptr reusable_buffer) { + D_ASSERT(!temp_directory.empty()); + D_ASSERT(temp_directory_handle.get()); + if (temp_directory_handle->GetTempFile().HasTemporaryBuffer(id)) { + return temp_directory_handle->GetTempFile().ReadTemporaryBuffer(id, std::move(reusable_buffer)); + } + idx_t block_size; + // open the temporary file and read the size + auto path = GetTemporaryPath(id); + auto &fs = FileSystem::GetFileSystem(db); + auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ); + handle->Read(&block_size, sizeof(idx_t), 0); + + // now allocate a buffer of this size and read the data into that buffer + auto buffer = + ReadTemporaryBufferInternal(*this, *handle, sizeof(idx_t), block_size, id, std::move(reusable_buffer)); + + handle.reset(); + DeleteTemporaryFile(id); + return buffer; +} + +void StandardBufferManager::DeleteTemporaryFile(block_id_t id) { + if (temp_directory.empty()) { + // no temporary directory specified: nothing to delete + return; + } + { + lock_guard temp_handle_guard(temp_handle_lock); + if (!temp_directory_handle) { + // temporary directory was not initialized yet: nothing to delete + return; + } + } + // check if we should delete the file from the shared pool of files, or from the general file system + if (temp_directory_handle->GetTempFile().HasTemporaryBuffer(id)) { + temp_directory_handle->GetTempFile().DeleteTemporaryBuffer(id); + return; + } + auto &fs = FileSystem::GetFileSystem(db); + auto path = GetTemporaryPath(id); + if (fs.FileExists(path)) { + fs.RemoveFile(path); + } +} + +bool StandardBufferManager::HasTemporaryDirectory() const { + return !temp_directory.empty(); +} + +vector StandardBufferManager::GetTemporaryFiles() { + vector result; + if (temp_directory.empty()) { + return result; + } + { + lock_guard temp_handle_guard(temp_handle_lock); + if (temp_directory_handle) { + result = temp_directory_handle->GetTempFile().GetTemporaryFiles(); + } + } + auto &fs = FileSystem::GetFileSystem(db); + fs.ListFiles(temp_directory, [&](const string &name, bool is_dir) { + if (is_dir) { + return; + } + if (!StringUtil::EndsWith(name, ".block")) { + return; + } + TemporaryFileInformation info; + info.path = name; + auto handle = fs.OpenFile(name, FileFlags::FILE_FLAGS_READ); + info.size = fs.GetFileSize(*handle); + handle.reset(); + result.push_back(info); + }); + return result; +} + +const char *StandardBufferManager::InMemoryWarning() { + if (!temp_directory.empty()) { + return ""; + } + return "\nDatabase is launched in in-memory mode and no temporary directory is specified." + "\nUnused blocks cannot be offloaded to disk." + "\n\nLaunch the database with a persistent storage back-end" + "\nOr set PRAGMA temp_directory='/path/to/tmp.tmp'"; +} + +void StandardBufferManager::ReserveMemory(idx_t size) { + if (size == 0) { + return; + } + auto reservation = EvictBlocksOrThrow(size, nullptr, "failed to reserve memory data of size %s%s", + StringUtil::BytesToHumanReadableString(size)); + reservation.size = 0; +} + +void StandardBufferManager::FreeReservedMemory(idx_t size) { + if (size == 0) { + return; + } + buffer_pool.current_memory -= size; +} + +//===--------------------------------------------------------------------===// +// Buffer Allocator +//===--------------------------------------------------------------------===// +data_ptr_t StandardBufferManager::BufferAllocatorAllocate(PrivateAllocatorData *private_data, idx_t size) { + auto &data = private_data->Cast(); + auto reservation = data.manager.EvictBlocksOrThrow(size, nullptr, "failed to allocate data of size %s%s", + StringUtil::BytesToHumanReadableString(size)); + // We rely on manual tracking of this one. :( + reservation.size = 0; + return Allocator::Get(data.manager.db).AllocateData(size); +} + +void StandardBufferManager::BufferAllocatorFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { + auto &data = private_data->Cast(); + BufferPoolReservation r(data.manager.GetBufferPool()); + r.size = size; + r.Resize(0); + return Allocator::Get(data.manager.db).FreeData(pointer, size); +} + +data_ptr_t StandardBufferManager::BufferAllocatorRealloc(PrivateAllocatorData *private_data, data_ptr_t pointer, + idx_t old_size, idx_t size) { + if (old_size == size) { + return pointer; + } + auto &data = private_data->Cast(); + BufferPoolReservation r(data.manager.GetBufferPool()); + r.size = old_size; + r.Resize(size); + r.size = 0; + return Allocator::Get(data.manager.db).ReallocateData(pointer, old_size, size); +} + +Allocator &BufferAllocator::Get(ClientContext &context) { + auto &manager = StandardBufferManager::GetBufferManager(context); + return manager.GetBufferAllocator(); +} + +Allocator &BufferAllocator::Get(DatabaseInstance &db) { + return StandardBufferManager::GetBufferManager(db).GetBufferAllocator(); +} + +Allocator &BufferAllocator::Get(AttachedDatabase &db) { + return BufferAllocator::Get(db.GetDatabase()); +} + +Allocator &StandardBufferManager::GetBufferAllocator() { + return buffer_allocator; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/base_statistics.cpp b/src/duckdb/src/storage/statistics/base_statistics.cpp new file mode 100644 index 00000000..fda77cf9 --- /dev/null +++ b/src/duckdb/src/storage/statistics/base_statistics.cpp @@ -0,0 +1,469 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +BaseStatistics::BaseStatistics() : type(LogicalType::INVALID) { +} + +BaseStatistics::BaseStatistics(LogicalType type) { + Construct(*this, std::move(type)); +} + +void BaseStatistics::Construct(BaseStatistics &stats, LogicalType type) { + stats.distinct_count = 0; + stats.type = std::move(type); + switch (GetStatsType(stats.type)) { + case StatisticsType::LIST_STATS: + ListStats::Construct(stats); + break; + case StatisticsType::STRUCT_STATS: + StructStats::Construct(stats); + break; + default: + break; + } +} + +BaseStatistics::~BaseStatistics() { +} + +BaseStatistics::BaseStatistics(BaseStatistics &&other) noexcept { + std::swap(type, other.type); + has_null = other.has_null; + has_no_null = other.has_no_null; + distinct_count = other.distinct_count; + stats_union = other.stats_union; + std::swap(child_stats, other.child_stats); +} + +BaseStatistics &BaseStatistics::operator=(BaseStatistics &&other) noexcept { + std::swap(type, other.type); + has_null = other.has_null; + has_no_null = other.has_no_null; + distinct_count = other.distinct_count; + stats_union = other.stats_union; + std::swap(child_stats, other.child_stats); + return *this; +} + +StatisticsType BaseStatistics::GetStatsType(const LogicalType &type) { + if (type.id() == LogicalTypeId::SQLNULL) { + return StatisticsType::BASE_STATS; + } + switch (type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::INT128: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return StatisticsType::NUMERIC_STATS; + case PhysicalType::VARCHAR: + return StatisticsType::STRING_STATS; + case PhysicalType::STRUCT: + return StatisticsType::STRUCT_STATS; + case PhysicalType::LIST: + return StatisticsType::LIST_STATS; + case PhysicalType::BIT: + case PhysicalType::INTERVAL: + default: + return StatisticsType::BASE_STATS; + } +} + +StatisticsType BaseStatistics::GetStatsType() const { + return GetStatsType(GetType()); +} + +void BaseStatistics::InitializeUnknown() { + has_null = true; + has_no_null = true; +} + +void BaseStatistics::InitializeEmpty() { + has_null = false; + has_no_null = true; +} + +bool BaseStatistics::CanHaveNull() const { + return has_null; +} + +bool BaseStatistics::CanHaveNoNull() const { + return has_no_null; +} + +bool BaseStatistics::IsConstant() const { + if (type.id() == LogicalTypeId::VALIDITY) { + // validity mask + if (CanHaveNull() && !CanHaveNoNull()) { + return true; + } + if (!CanHaveNull() && CanHaveNoNull()) { + return true; + } + return false; + } + switch (GetStatsType()) { + case StatisticsType::NUMERIC_STATS: + return NumericStats::IsConstant(*this); + default: + break; + } + return false; +} + +void BaseStatistics::Merge(const BaseStatistics &other) { + has_null = has_null || other.has_null; + has_no_null = has_no_null || other.has_no_null; + switch (GetStatsType()) { + case StatisticsType::NUMERIC_STATS: + NumericStats::Merge(*this, other); + break; + case StatisticsType::STRING_STATS: + StringStats::Merge(*this, other); + break; + case StatisticsType::LIST_STATS: + ListStats::Merge(*this, other); + break; + case StatisticsType::STRUCT_STATS: + StructStats::Merge(*this, other); + break; + default: + break; + } +} + +idx_t BaseStatistics::GetDistinctCount() { + return distinct_count; +} + +BaseStatistics BaseStatistics::CreateUnknownType(LogicalType type) { + switch (GetStatsType(type)) { + case StatisticsType::NUMERIC_STATS: + return NumericStats::CreateUnknown(std::move(type)); + case StatisticsType::STRING_STATS: + return StringStats::CreateUnknown(std::move(type)); + case StatisticsType::LIST_STATS: + return ListStats::CreateUnknown(std::move(type)); + case StatisticsType::STRUCT_STATS: + return StructStats::CreateUnknown(std::move(type)); + default: + return BaseStatistics(std::move(type)); + } +} + +BaseStatistics BaseStatistics::CreateEmptyType(LogicalType type) { + switch (GetStatsType(type)) { + case StatisticsType::NUMERIC_STATS: + return NumericStats::CreateEmpty(std::move(type)); + case StatisticsType::STRING_STATS: + return StringStats::CreateEmpty(std::move(type)); + case StatisticsType::LIST_STATS: + return ListStats::CreateEmpty(std::move(type)); + case StatisticsType::STRUCT_STATS: + return StructStats::CreateEmpty(std::move(type)); + default: + return BaseStatistics(std::move(type)); + } +} + +BaseStatistics BaseStatistics::CreateUnknown(LogicalType type) { + auto result = CreateUnknownType(std::move(type)); + result.InitializeUnknown(); + return result; +} + +BaseStatistics BaseStatistics::CreateEmpty(LogicalType type) { + if (type.InternalType() == PhysicalType::BIT) { + // FIXME: this special case should not be necessary + // but currently InitializeEmpty sets StatsInfo::CAN_HAVE_VALID_VALUES + BaseStatistics result(std::move(type)); + result.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + result.Set(StatsInfo::CANNOT_HAVE_VALID_VALUES); + return result; + } + auto result = CreateEmptyType(std::move(type)); + result.InitializeEmpty(); + return result; +} + +void BaseStatistics::Copy(const BaseStatistics &other) { + D_ASSERT(GetType() == other.GetType()); + CopyBase(other); + stats_union = other.stats_union; + switch (GetStatsType()) { + case StatisticsType::LIST_STATS: + ListStats::Copy(*this, other); + break; + case StatisticsType::STRUCT_STATS: + StructStats::Copy(*this, other); + break; + default: + break; + } +} + +BaseStatistics BaseStatistics::Copy() const { + BaseStatistics result(type); + result.Copy(*this); + return result; +} + +unique_ptr BaseStatistics::ToUnique() const { + auto result = unique_ptr(new BaseStatistics(type)); + result->Copy(*this); + return result; +} + +void BaseStatistics::CopyBase(const BaseStatistics &other) { + has_null = other.has_null; + has_no_null = other.has_no_null; + distinct_count = other.distinct_count; +} + +void BaseStatistics::Set(StatsInfo info) { + switch (info) { + case StatsInfo::CAN_HAVE_NULL_VALUES: + has_null = true; + break; + case StatsInfo::CANNOT_HAVE_NULL_VALUES: + has_null = false; + break; + case StatsInfo::CAN_HAVE_VALID_VALUES: + has_no_null = true; + break; + case StatsInfo::CANNOT_HAVE_VALID_VALUES: + has_no_null = false; + break; + case StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES: + has_null = true; + has_no_null = true; + break; + default: + throw InternalException("Unrecognized StatsInfo for BaseStatistics::Set"); + } +} + +void BaseStatistics::CombineValidity(BaseStatistics &left, BaseStatistics &right) { + has_null = left.has_null || right.has_null; + has_no_null = left.has_no_null || right.has_no_null; +} + +void BaseStatistics::CopyValidity(BaseStatistics &stats) { + has_null = stats.has_null; + has_no_null = stats.has_no_null; +} + +void BaseStatistics::SetDistinctCount(idx_t count) { + this->distinct_count = count; +} + +void BaseStatistics::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "has_null", has_null); + serializer.WriteProperty(101, "has_no_null", has_no_null); + serializer.WriteProperty(102, "distinct_count", distinct_count); + serializer.WriteObject(103, "type_stats", [&](Serializer &serializer) { + switch (GetStatsType()) { + case StatisticsType::NUMERIC_STATS: + NumericStats::Serialize(*this, serializer); + break; + case StatisticsType::STRING_STATS: + StringStats::Serialize(*this, serializer); + break; + case StatisticsType::LIST_STATS: + ListStats::Serialize(*this, serializer); + break; + case StatisticsType::STRUCT_STATS: + StructStats::Serialize(*this, serializer); + break; + default: + break; + } + }); +} + +BaseStatistics BaseStatistics::Deserialize(Deserializer &deserializer) { + auto has_null = deserializer.ReadProperty(100, "has_null"); + auto has_no_null = deserializer.ReadProperty(101, "has_no_null"); + auto distinct_count = deserializer.ReadProperty(102, "distinct_count"); + + // Get the logical type from the deserializer context. + auto type = deserializer.Get(); + + auto stats_type = GetStatsType(type); + + BaseStatistics stats(std::move(type)); + + stats.has_null = has_null; + stats.has_no_null = has_no_null; + stats.distinct_count = distinct_count; + + deserializer.ReadObject(103, "type_stats", [&](Deserializer &obj) { + switch (stats_type) { + case StatisticsType::NUMERIC_STATS: + NumericStats::Deserialize(obj, stats); + break; + case StatisticsType::STRING_STATS: + StringStats::Deserialize(obj, stats); + break; + case StatisticsType::LIST_STATS: + ListStats::Deserialize(obj, stats); + break; + case StatisticsType::STRUCT_STATS: + StructStats::Deserialize(obj, stats); + break; + default: + break; + } + }); + + return stats; +} + +string BaseStatistics::ToString() const { + auto has_n = has_null ? "true" : "false"; + auto has_n_n = has_no_null ? "true" : "false"; + string result = + StringUtil::Format("%s%s", StringUtil::Format("[Has Null: %s, Has No Null: %s]", has_n, has_n_n), + distinct_count > 0 ? StringUtil::Format("[Approx Unique: %lld]", distinct_count) : ""); + switch (GetStatsType()) { + case StatisticsType::NUMERIC_STATS: + result = NumericStats::ToString(*this) + result; + break; + case StatisticsType::STRING_STATS: + result = StringStats::ToString(*this) + result; + break; + case StatisticsType::LIST_STATS: + result = ListStats::ToString(*this) + result; + break; + case StatisticsType::STRUCT_STATS: + result = StructStats::ToString(*this) + result; + break; + default: + break; + } + return result; +} + +void BaseStatistics::Verify(Vector &vector, const SelectionVector &sel, idx_t count) const { + D_ASSERT(vector.GetType() == this->type); + switch (GetStatsType()) { + case StatisticsType::NUMERIC_STATS: + NumericStats::Verify(*this, vector, sel, count); + break; + case StatisticsType::STRING_STATS: + StringStats::Verify(*this, vector, sel, count); + break; + case StatisticsType::LIST_STATS: + ListStats::Verify(*this, vector, sel, count); + break; + case StatisticsType::STRUCT_STATS: + StructStats::Verify(*this, vector, sel, count); + break; + default: + break; + } + if (has_null && has_no_null) { + // nothing to verify + return; + } + UnifiedVectorFormat vdata; + vector.ToUnifiedFormat(count, vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto index = vdata.sel->get_index(idx); + bool row_is_valid = vdata.validity.RowIsValid(index); + if (row_is_valid && !has_no_null) { + throw InternalException( + "Statistics mismatch: vector labeled as having only NULL values, but vector contains valid values: %s", + vector.ToString(count)); + } + if (!row_is_valid && !has_null) { + throw InternalException( + "Statistics mismatch: vector labeled as not having NULL values, but vector contains null values: %s", + vector.ToString(count)); + } + } +} + +void BaseStatistics::Verify(Vector &vector, idx_t count) const { + auto sel = FlatVector::IncrementalSelectionVector(); + Verify(vector, *sel, count); +} + +BaseStatistics BaseStatistics::FromConstantType(const Value &input) { + switch (GetStatsType(input.type())) { + case StatisticsType::NUMERIC_STATS: { + auto result = NumericStats::CreateEmpty(input.type()); + NumericStats::SetMin(result, input); + NumericStats::SetMax(result, input); + return result; + } + case StatisticsType::STRING_STATS: { + auto result = StringStats::CreateEmpty(input.type()); + if (!input.IsNull()) { + auto &string_value = StringValue::Get(input); + StringStats::Update(result, string_t(string_value)); + } + return result; + } + case StatisticsType::LIST_STATS: { + auto result = ListStats::CreateEmpty(input.type()); + auto &child_stats = ListStats::GetChildStats(result); + if (!input.IsNull()) { + auto &list_children = ListValue::GetChildren(input); + for (auto &child_element : list_children) { + child_stats.Merge(FromConstant(child_element)); + } + } + return result; + } + case StatisticsType::STRUCT_STATS: { + auto result = StructStats::CreateEmpty(input.type()); + auto &child_types = StructType::GetChildTypes(input.type()); + if (input.IsNull()) { + for (idx_t i = 0; i < child_types.size(); i++) { + StructStats::SetChildStats(result, i, FromConstant(Value(child_types[i].second))); + } + } else { + auto &struct_children = StructValue::GetChildren(input); + for (idx_t i = 0; i < child_types.size(); i++) { + StructStats::SetChildStats(result, i, FromConstant(struct_children[i])); + } + } + return result; + } + default: + return BaseStatistics(input.type()); + } +} + +BaseStatistics BaseStatistics::FromConstant(const Value &input) { + auto result = FromConstantType(input); + result.SetDistinctCount(1); + if (input.IsNull()) { + result.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + result.Set(StatsInfo::CANNOT_HAVE_VALID_VALUES); + } else { + result.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + result.Set(StatsInfo::CAN_HAVE_VALID_VALUES); + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/column_statistics.cpp b/src/duckdb/src/storage/statistics/column_statistics.cpp new file mode 100644 index 00000000..adf2c3b4 --- /dev/null +++ b/src/duckdb/src/storage/statistics/column_statistics.cpp @@ -0,0 +1,70 @@ +#include "duckdb/storage/statistics/column_statistics.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" + +namespace duckdb { + +ColumnStatistics::ColumnStatistics(BaseStatistics stats_p) : stats(std::move(stats_p)) { + if (DistinctStatistics::TypeIsSupported(stats.GetType())) { + distinct_stats = make_uniq(); + } +} +ColumnStatistics::ColumnStatistics(BaseStatistics stats_p, unique_ptr distinct_stats_p) + : stats(std::move(stats_p)), distinct_stats(std::move(distinct_stats_p)) { +} + +shared_ptr ColumnStatistics::CreateEmptyStats(const LogicalType &type) { + return make_shared(BaseStatistics::CreateEmpty(type)); +} + +void ColumnStatistics::Merge(ColumnStatistics &other) { + stats.Merge(other.stats); + if (distinct_stats) { + distinct_stats->Merge(*other.distinct_stats); + } +} + +BaseStatistics &ColumnStatistics::Statistics() { + return stats; +} + +bool ColumnStatistics::HasDistinctStats() { + return distinct_stats.get(); +} + +DistinctStatistics &ColumnStatistics::DistinctStats() { + if (!distinct_stats) { + throw InternalException("DistinctStats called without distinct_stats"); + } + return *distinct_stats; +} + +void ColumnStatistics::SetDistinct(unique_ptr distinct) { + this->distinct_stats = std::move(distinct); +} + +void ColumnStatistics::UpdateDistinctStatistics(Vector &v, idx_t count) { + if (!distinct_stats) { + return; + } + auto &d_stats = (DistinctStatistics &)*distinct_stats; + d_stats.Update(v, count); +} + +shared_ptr ColumnStatistics::Copy() const { + return make_shared(stats.Copy(), distinct_stats ? distinct_stats->Copy() : nullptr); +} + +void ColumnStatistics::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "statistics", stats); + serializer.WritePropertyWithDefault(101, "distinct", distinct_stats, unique_ptr()); +} + +shared_ptr ColumnStatistics::Deserialize(Deserializer &deserializer) { + auto stats = deserializer.ReadProperty(100, "statistics"); + auto distinct_stats = deserializer.ReadPropertyWithDefault>( + 101, "distinct", unique_ptr()); + return make_shared(std::move(stats), std::move(distinct_stats)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/distinct_statistics.cpp b/src/duckdb/src/storage/statistics/distinct_statistics.cpp new file mode 100644 index 00000000..e80cc140 --- /dev/null +++ b/src/duckdb/src/storage/statistics/distinct_statistics.cpp @@ -0,0 +1,75 @@ +#include "duckdb/storage/statistics/distinct_statistics.hpp" + +#include "duckdb/common/string_util.hpp" + +#include + +namespace duckdb { + +DistinctStatistics::DistinctStatistics() : log(make_uniq()), sample_count(0), total_count(0) { +} + +DistinctStatistics::DistinctStatistics(unique_ptr log, idx_t sample_count, idx_t total_count) + : log(std::move(log)), sample_count(sample_count), total_count(total_count) { +} + +unique_ptr DistinctStatistics::Copy() const { + return make_uniq(log->Copy(), sample_count, total_count); +} + +void DistinctStatistics::Merge(const DistinctStatistics &other) { + log = log->Merge(*other.log); + sample_count += other.sample_count; + total_count += other.total_count; +} + +void DistinctStatistics::Update(Vector &v, idx_t count, bool sample) { + UnifiedVectorFormat vdata; + v.ToUnifiedFormat(count, vdata); + Update(vdata, v.GetType(), count, sample); +} + +void DistinctStatistics::Update(UnifiedVectorFormat &vdata, const LogicalType &type, idx_t count, bool sample) { + if (count == 0) { + return; + } + + total_count += count; + if (sample) { + count = MinValue(idx_t(SAMPLE_RATE * MaxValue(STANDARD_VECTOR_SIZE, count)), count); + } + sample_count += count; + + uint64_t indices[STANDARD_VECTOR_SIZE]; + uint8_t counts[STANDARD_VECTOR_SIZE]; + + HyperLogLog::ProcessEntries(vdata, type, indices, counts, count); + log->AddToLog(vdata, count, indices, counts); +} + +string DistinctStatistics::ToString() const { + return StringUtil::Format("[Approx Unique: %s]", to_string(GetCount())); +} + +idx_t DistinctStatistics::GetCount() const { + if (sample_count == 0 || total_count == 0) { + return 0; + } + + double u = MinValue(log->Count(), sample_count); + double s = sample_count; + double n = total_count; + + // Assume this proportion of the the sampled values occurred only once + double u1 = pow(u / s, 2) * u; + + // Estimate total uniques using Good Turing Estimation + idx_t estimate = u + u1 / s * (n - s); + return MinValue(estimate, total_count); +} + +bool DistinctStatistics::TypeIsSupported(const LogicalType &type) { + return type.InternalType() != PhysicalType::LIST && type.InternalType() != PhysicalType::STRUCT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/list_stats.cpp b/src/duckdb/src/storage/statistics/list_stats.cpp new file mode 100644 index 00000000..0d4a0011 --- /dev/null +++ b/src/duckdb/src/storage/statistics/list_stats.cpp @@ -0,0 +1,126 @@ +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/vector.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +void ListStats::Construct(BaseStatistics &stats) { + stats.child_stats = unsafe_unique_array(new BaseStatistics[1]); + BaseStatistics::Construct(stats.child_stats[0], ListType::GetChildType(stats.GetType())); +} + +BaseStatistics ListStats::CreateUnknown(LogicalType type) { + auto &child_type = ListType::GetChildType(type); + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + result.child_stats[0].Copy(BaseStatistics::CreateUnknown(child_type)); + return result; +} + +BaseStatistics ListStats::CreateEmpty(LogicalType type) { + auto &child_type = ListType::GetChildType(type); + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + result.child_stats[0].Copy(BaseStatistics::CreateEmpty(child_type)); + return result; +} + +void ListStats::Copy(BaseStatistics &stats, const BaseStatistics &other) { + D_ASSERT(stats.child_stats); + D_ASSERT(other.child_stats); + stats.child_stats[0].Copy(other.child_stats[0]); +} + +const BaseStatistics &ListStats::GetChildStats(const BaseStatistics &stats) { + if (stats.GetStatsType() != StatisticsType::LIST_STATS) { + throw InternalException("ListStats::GetChildStats called on stats that is not a list"); + } + D_ASSERT(stats.child_stats); + return stats.child_stats[0]; +} +BaseStatistics &ListStats::GetChildStats(BaseStatistics &stats) { + if (stats.GetStatsType() != StatisticsType::LIST_STATS) { + throw InternalException("ListStats::GetChildStats called on stats that is not a list"); + } + D_ASSERT(stats.child_stats); + return stats.child_stats[0]; +} + +void ListStats::SetChildStats(BaseStatistics &stats, unique_ptr new_stats) { + if (!new_stats) { + stats.child_stats[0].Copy(BaseStatistics::CreateUnknown(ListType::GetChildType(stats.GetType()))); + } else { + stats.child_stats[0].Copy(*new_stats); + } +} + +void ListStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + + auto &child_stats = ListStats::GetChildStats(stats); + auto &other_child_stats = ListStats::GetChildStats(other); + child_stats.Merge(other_child_stats); +} + +void ListStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + auto &child_stats = ListStats::GetChildStats(stats); + serializer.WriteProperty(200, "child_stats", child_stats); +} + +void ListStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { + auto &type = base.GetType(); + D_ASSERT(type.InternalType() == PhysicalType::LIST); + auto &child_type = ListType::GetChildType(type); + + // Push the logical type of the child type to the deserialization context + deserializer.Set(const_cast(child_type)); + base.child_stats[0].Copy(deserializer.ReadProperty(200, "child_stats")); + deserializer.Unset(); +} + +string ListStats::ToString(const BaseStatistics &stats) { + auto &child_stats = ListStats::GetChildStats(stats); + return StringUtil::Format("[%s]", child_stats.ToString()); +} + +void ListStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + auto &child_stats = ListStats::GetChildStats(stats); + auto &child_entry = ListVector::GetEntry(vector); + UnifiedVectorFormat vdata; + vector.ToUnifiedFormat(count, vdata); + + auto list_data = UnifiedVectorFormat::GetData(vdata); + idx_t total_list_count = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto index = vdata.sel->get_index(idx); + auto list = list_data[index]; + if (vdata.validity.RowIsValid(index)) { + for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { + total_list_count++; + } + } + } + SelectionVector list_sel(total_list_count); + idx_t list_count = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto index = vdata.sel->get_index(idx); + auto list = list_data[index]; + if (vdata.validity.RowIsValid(index)) { + for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { + list_sel.set_index(list_count++, list.offset + list_idx); + } + } + } + + child_stats.Verify(child_entry, list_sel, list_count); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/numeric_stats.cpp b/src/duckdb/src/storage/statistics/numeric_stats.cpp new file mode 100644 index 00000000..79e3376b --- /dev/null +++ b/src/duckdb/src/storage/statistics/numeric_stats.cpp @@ -0,0 +1,601 @@ +#include "duckdb/storage/statistics/numeric_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +template <> +void NumericStats::Update(BaseStatistics &stats, interval_t new_value) { +} + +template <> +void NumericStats::Update(BaseStatistics &stats, list_entry_t new_value) { +} + +//===--------------------------------------------------------------------===// +// NumericStats +//===--------------------------------------------------------------------===// +BaseStatistics NumericStats::CreateUnknown(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + SetMin(result, Value(result.GetType())); + SetMax(result, Value(result.GetType())); + return result; +} + +BaseStatistics NumericStats::CreateEmpty(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + SetMin(result, Value::MaximumValue(result.GetType())); + SetMax(result, Value::MinimumValue(result.GetType())); + return result; +} + +NumericStatsData &NumericStats::GetDataUnsafe(BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::NUMERIC_STATS); + return stats.stats_union.numeric_data; +} + +const NumericStatsData &NumericStats::GetDataUnsafe(const BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::NUMERIC_STATS); + return stats.stats_union.numeric_data; +} + +void NumericStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + D_ASSERT(stats.GetType() == other.GetType()); + if (NumericStats::HasMin(other) && NumericStats::HasMin(stats)) { + auto other_min = NumericStats::Min(other); + if (other_min < NumericStats::Min(stats)) { + NumericStats::SetMin(stats, other_min); + } + } else { + NumericStats::SetMin(stats, Value()); + } + if (NumericStats::HasMax(other) && NumericStats::HasMax(stats)) { + auto other_max = NumericStats::Max(other); + if (other_max > NumericStats::Max(stats)) { + NumericStats::SetMax(stats, other_max); + } + } else { + NumericStats::SetMax(stats, Value()); + } +} + +struct GetNumericValueUnion { + template + static T Operation(const NumericValueUnion &v); +}; + +template <> +int8_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.tinyint; +} + +template <> +int16_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.smallint; +} + +template <> +int32_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.integer; +} + +template <> +int64_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.bigint; +} + +template <> +hugeint_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.hugeint; +} + +template <> +uint8_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.utinyint; +} + +template <> +uint16_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.usmallint; +} + +template <> +uint32_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.uinteger; +} + +template <> +uint64_t GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.ubigint; +} + +template <> +float GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.float_; +} + +template <> +double GetNumericValueUnion::Operation(const NumericValueUnion &v) { + return v.value_.double_; +} + +template +T NumericStats::GetMinUnsafe(const BaseStatistics &stats) { + return GetNumericValueUnion::Operation(NumericStats::GetDataUnsafe(stats).min); +} + +template +T NumericStats::GetMaxUnsafe(const BaseStatistics &stats) { + return GetNumericValueUnion::Operation(NumericStats::GetDataUnsafe(stats).max); +} + +template +bool ConstantExactRange(T min, T max, T constant) { + return Equals::Operation(constant, min) && Equals::Operation(constant, max); +} + +template +bool ConstantValueInRange(T min, T max, T constant) { + return !(LessThan::Operation(constant, min) || GreaterThan::Operation(constant, max)); +} + +template +FilterPropagateResult CheckZonemapTemplated(const BaseStatistics &stats, ExpressionType comparison_type, + const Value &constant_value) { + T min_value = NumericStats::GetMinUnsafe(stats); + T max_value = NumericStats::GetMaxUnsafe(stats); + T constant = constant_value.GetValueUnsafe(); + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: + if (ConstantExactRange(min_value, max_value, constant)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + if (ConstantValueInRange(min_value, max_value, constant)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + case ExpressionType::COMPARE_NOTEQUAL: + if (!ConstantValueInRange(min_value, max_value, constant)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } else if (ConstantExactRange(min_value, max_value, constant)) { + // corner case of a cluster with one numeric equal to the target constant + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + // GreaterThanEquals::Operation(X, C) + // this can be true only if max(X) >= C + // if min(X) >= C, then this is always true + if (GreaterThanEquals::Operation(min_value, constant)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } else if (GreaterThanEquals::Operation(max_value, constant)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + case ExpressionType::COMPARE_GREATERTHAN: + // GreaterThan::Operation(X, C) + // this can be true only if max(X) > C + // if min(X) > C, then this is always true + if (GreaterThan::Operation(min_value, constant)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } else if (GreaterThan::Operation(max_value, constant)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // LessThanEquals::Operation(X, C) + // this can be true only if min(X) <= C + // if max(X) <= C, then this is always true + if (LessThanEquals::Operation(max_value, constant)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } else if (LessThanEquals::Operation(min_value, constant)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + case ExpressionType::COMPARE_LESSTHAN: + // LessThan::Operation(X, C) + // this can be true only if min(X) < C + // if max(X) < C, then this is always true + if (LessThan::Operation(max_value, constant)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } else if (LessThan::Operation(min_value, constant)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + default: + throw InternalException("Expression type in zonemap check not implemented"); + } +} + +FilterPropagateResult NumericStats::CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, + const Value &constant) { + D_ASSERT(constant.type() == stats.GetType()); + if (constant.IsNull()) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + if (!NumericStats::HasMinMax(stats)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + switch (stats.GetType().InternalType()) { + case PhysicalType::INT8: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::INT16: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::INT32: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::INT64: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::UINT8: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::UINT16: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::UINT32: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::UINT64: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::INT128: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::FLOAT: + return CheckZonemapTemplated(stats, comparison_type, constant); + case PhysicalType::DOUBLE: + return CheckZonemapTemplated(stats, comparison_type, constant); + default: + throw InternalException("Unsupported type for NumericStats::CheckZonemap"); + } +} + +bool NumericStats::IsConstant(const BaseStatistics &stats) { + return NumericStats::Max(stats) <= NumericStats::Min(stats); +} + +void SetNumericValueInternal(const Value &input, const LogicalType &type, NumericValueUnion &val, bool &has_val) { + if (input.IsNull()) { + has_val = false; + return; + } + if (input.type().InternalType() != type.InternalType()) { + throw InternalException("SetMin or SetMax called with Value that does not match statistics' column value"); + } + has_val = true; + switch (type.InternalType()) { + case PhysicalType::BOOL: + val.value_.boolean = BooleanValue::Get(input); + break; + case PhysicalType::INT8: + val.value_.tinyint = TinyIntValue::Get(input); + break; + case PhysicalType::INT16: + val.value_.smallint = SmallIntValue::Get(input); + break; + case PhysicalType::INT32: + val.value_.integer = IntegerValue::Get(input); + break; + case PhysicalType::INT64: + val.value_.bigint = BigIntValue::Get(input); + break; + case PhysicalType::UINT8: + val.value_.utinyint = UTinyIntValue::Get(input); + break; + case PhysicalType::UINT16: + val.value_.usmallint = USmallIntValue::Get(input); + break; + case PhysicalType::UINT32: + val.value_.uinteger = UIntegerValue::Get(input); + break; + case PhysicalType::UINT64: + val.value_.ubigint = UBigIntValue::Get(input); + break; + case PhysicalType::INT128: + val.value_.hugeint = HugeIntValue::Get(input); + break; + case PhysicalType::FLOAT: + val.value_.float_ = FloatValue::Get(input); + break; + case PhysicalType::DOUBLE: + val.value_.double_ = DoubleValue::Get(input); + break; + default: + throw InternalException("Unsupported type for NumericStatistics::SetValueInternal"); + } +} + +void NumericStats::SetMin(BaseStatistics &stats, const Value &new_min) { + auto &data = NumericStats::GetDataUnsafe(stats); + SetNumericValueInternal(new_min, stats.GetType(), data.min, data.has_min); +} + +void NumericStats::SetMax(BaseStatistics &stats, const Value &new_max) { + auto &data = NumericStats::GetDataUnsafe(stats); + SetNumericValueInternal(new_max, stats.GetType(), data.max, data.has_max); +} + +Value NumericValueUnionToValueInternal(const LogicalType &type, const NumericValueUnion &val) { + switch (type.InternalType()) { + case PhysicalType::BOOL: + return Value::BOOLEAN(val.value_.boolean); + case PhysicalType::INT8: + return Value::TINYINT(val.value_.tinyint); + case PhysicalType::INT16: + return Value::SMALLINT(val.value_.smallint); + case PhysicalType::INT32: + return Value::INTEGER(val.value_.integer); + case PhysicalType::INT64: + return Value::BIGINT(val.value_.bigint); + case PhysicalType::UINT8: + return Value::UTINYINT(val.value_.utinyint); + case PhysicalType::UINT16: + return Value::USMALLINT(val.value_.usmallint); + case PhysicalType::UINT32: + return Value::UINTEGER(val.value_.uinteger); + case PhysicalType::UINT64: + return Value::UBIGINT(val.value_.ubigint); + case PhysicalType::INT128: + return Value::HUGEINT(val.value_.hugeint); + case PhysicalType::FLOAT: + return Value::FLOAT(val.value_.float_); + case PhysicalType::DOUBLE: + return Value::DOUBLE(val.value_.double_); + default: + throw InternalException("Unsupported type for NumericValueUnionToValue"); + } +} + +Value NumericValueUnionToValue(const LogicalType &type, const NumericValueUnion &val) { + Value result = NumericValueUnionToValueInternal(type, val); + result.GetTypeMutable() = type; + return result; +} + +bool NumericStats::HasMinMax(const BaseStatistics &stats) { + return NumericStats::HasMin(stats) && NumericStats::HasMax(stats); +} + +bool NumericStats::HasMin(const BaseStatistics &stats) { + if (stats.GetType().id() == LogicalTypeId::SQLNULL) { + return false; + } + return NumericStats::GetDataUnsafe(stats).has_min; +} + +bool NumericStats::HasMax(const BaseStatistics &stats) { + if (stats.GetType().id() == LogicalTypeId::SQLNULL) { + return false; + } + return NumericStats::GetDataUnsafe(stats).has_max; +} + +Value NumericStats::Min(const BaseStatistics &stats) { + if (!NumericStats::HasMin(stats)) { + throw InternalException("Min() called on statistics that does not have min"); + } + return NumericValueUnionToValue(stats.GetType(), NumericStats::GetDataUnsafe(stats).min); +} + +Value NumericStats::Max(const BaseStatistics &stats) { + if (!NumericStats::HasMax(stats)) { + throw InternalException("Max() called on statistics that does not have max"); + } + return NumericValueUnionToValue(stats.GetType(), NumericStats::GetDataUnsafe(stats).max); +} + +Value NumericStats::MinOrNull(const BaseStatistics &stats) { + if (!NumericStats::HasMin(stats)) { + return Value(stats.GetType()); + } + return NumericStats::Min(stats); +} + +Value NumericStats::MaxOrNull(const BaseStatistics &stats) { + if (!NumericStats::HasMax(stats)) { + return Value(stats.GetType()); + } + return NumericStats::Max(stats); +} + +static void SerializeNumericStatsValue(const LogicalType &type, NumericValueUnion val, bool has_value, + Serializer &serializer) { + serializer.WriteProperty(100, "has_value", has_value); + if (!has_value) { + return; + } + switch (type.InternalType()) { + case PhysicalType::BOOL: + serializer.WriteProperty(101, "value", val.value_.boolean); + break; + case PhysicalType::INT8: + serializer.WriteProperty(101, "value", val.value_.tinyint); + break; + case PhysicalType::INT16: + serializer.WriteProperty(101, "value", val.value_.smallint); + break; + case PhysicalType::INT32: + serializer.WriteProperty(101, "value", val.value_.integer); + break; + case PhysicalType::INT64: + serializer.WriteProperty(101, "value", val.value_.bigint); + break; + case PhysicalType::UINT8: + serializer.WriteProperty(101, "value", val.value_.utinyint); + break; + case PhysicalType::UINT16: + serializer.WriteProperty(101, "value", val.value_.usmallint); + break; + case PhysicalType::UINT32: + serializer.WriteProperty(101, "value", val.value_.uinteger); + break; + case PhysicalType::UINT64: + serializer.WriteProperty(101, "value", val.value_.ubigint); + break; + case PhysicalType::INT128: + serializer.WriteProperty(101, "value", val.value_.hugeint); + break; + case PhysicalType::FLOAT: + serializer.WriteProperty(101, "value", val.value_.float_); + break; + case PhysicalType::DOUBLE: + serializer.WriteProperty(101, "value", val.value_.double_); + break; + default: + throw InternalException("Unsupported type for serializing numeric statistics"); + } +} + +static void DeserializeNumericStatsValue(const LogicalType &type, NumericValueUnion &result, bool &has_stats, + Deserializer &deserializer) { + auto has_value = deserializer.ReadProperty(100, "has_value"); + if (!has_value) { + has_stats = false; + return; + } + has_stats = true; + switch (type.InternalType()) { + case PhysicalType::BOOL: + result.value_.boolean = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::INT8: + result.value_.tinyint = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::INT16: + result.value_.smallint = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::INT32: + result.value_.integer = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::INT64: + result.value_.bigint = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::UINT8: + result.value_.utinyint = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::UINT16: + result.value_.usmallint = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::UINT32: + result.value_.uinteger = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::UINT64: + result.value_.ubigint = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::INT128: + result.value_.hugeint = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::FLOAT: + result.value_.float_ = deserializer.ReadProperty(101, "value"); + break; + case PhysicalType::DOUBLE: + result.value_.double_ = deserializer.ReadProperty(101, "value"); + break; + default: + throw InternalException("Unsupported type for serializing numeric statistics"); + } +} + +void NumericStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + auto &numeric_stats = NumericStats::GetDataUnsafe(stats); + serializer.WriteObject(200, "max", [&](Serializer &object) { + SerializeNumericStatsValue(stats.GetType(), numeric_stats.min, numeric_stats.has_min, object); + }); + serializer.WriteObject(201, "min", [&](Serializer &object) { + SerializeNumericStatsValue(stats.GetType(), numeric_stats.max, numeric_stats.has_max, object); + }); +} + +void NumericStats::Deserialize(Deserializer &deserializer, BaseStatistics &result) { + auto &numeric_stats = NumericStats::GetDataUnsafe(result); + + deserializer.ReadObject(200, "max", [&](Deserializer &object) { + DeserializeNumericStatsValue(result.GetType(), numeric_stats.min, numeric_stats.has_min, object); + }); + deserializer.ReadObject(201, "min", [&](Deserializer &object) { + DeserializeNumericStatsValue(result.GetType(), numeric_stats.max, numeric_stats.has_max, object); + }); +} + +string NumericStats::ToString(const BaseStatistics &stats) { + return StringUtil::Format("[Min: %s, Max: %s]", NumericStats::MinOrNull(stats).ToString(), + NumericStats::MaxOrNull(stats).ToString()); +} + +template +void NumericStats::TemplatedVerify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, + idx_t count) { + UnifiedVectorFormat vdata; + vector.ToUnifiedFormat(count, vdata); + + auto data = UnifiedVectorFormat::GetData(vdata); + auto min_value = NumericStats::MinOrNull(stats); + auto max_value = NumericStats::MaxOrNull(stats); + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto index = vdata.sel->get_index(idx); + if (!vdata.validity.RowIsValid(index)) { + continue; + } + if (!min_value.IsNull() && LessThan::Operation(data[index], min_value.GetValueUnsafe())) { // LCOV_EXCL_START + throw InternalException("Statistics mismatch: value is smaller than min.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString(count)); + } // LCOV_EXCL_STOP + if (!max_value.IsNull() && GreaterThan::Operation(data[index], max_value.GetValueUnsafe())) { + throw InternalException("Statistics mismatch: value is bigger than max.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString(count)); + } + } +} + +void NumericStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + auto &type = stats.GetType(); + switch (type.InternalType()) { + case PhysicalType::BOOL: + break; + case PhysicalType::INT8: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::INT16: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::INT32: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::INT64: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::UINT8: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::UINT16: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::UINT32: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::UINT64: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::INT128: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::FLOAT: + TemplatedVerify(stats, vector, sel, count); + break; + case PhysicalType::DOUBLE: + TemplatedVerify(stats, vector, sel, count); + break; + default: + throw InternalException("Unsupported type %s for numeric statistics verify", type.ToString()); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/numeric_stats_union.cpp b/src/duckdb/src/storage/statistics/numeric_stats_union.cpp new file mode 100644 index 00000000..ab6bc673 --- /dev/null +++ b/src/duckdb/src/storage/statistics/numeric_stats_union.cpp @@ -0,0 +1,65 @@ +#include "duckdb/storage/statistics/numeric_stats_union.hpp" + +namespace duckdb { + +template <> +bool &NumericValueUnion::GetReferenceUnsafe() { + return value_.boolean; +} + +template <> +int8_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.tinyint; +} + +template <> +int16_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.smallint; +} + +template <> +int32_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.integer; +} + +template <> +int64_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.bigint; +} + +template <> +hugeint_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.hugeint; +} + +template <> +uint8_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.utinyint; +} + +template <> +uint16_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.usmallint; +} + +template <> +uint32_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.uinteger; +} + +template <> +uint64_t &NumericValueUnion::GetReferenceUnsafe() { + return value_.ubigint; +} + +template <> +float &NumericValueUnion::GetReferenceUnsafe() { + return value_.float_; +} + +template <> +double &NumericValueUnion::GetReferenceUnsafe() { + return value_.double_; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/segment_statistics.cpp b/src/duckdb/src/storage/statistics/segment_statistics.cpp new file mode 100644 index 00000000..69717544 --- /dev/null +++ b/src/duckdb/src/storage/statistics/segment_statistics.cpp @@ -0,0 +1,13 @@ +#include "duckdb/storage/statistics/segment_statistics.hpp" + +#include "duckdb/common/exception.hpp" + +namespace duckdb { + +SegmentStatistics::SegmentStatistics(LogicalType type) : statistics(BaseStatistics::CreateEmpty(std::move(type))) { +} + +SegmentStatistics::SegmentStatistics(BaseStatistics stats) : statistics(std::move(stats)) { +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/string_stats.cpp b/src/duckdb/src/storage/statistics/string_stats.cpp new file mode 100644 index 00000000..c9be1c38 --- /dev/null +++ b/src/duckdb/src/storage/statistics/string_stats.cpp @@ -0,0 +1,293 @@ +#include "duckdb/storage/statistics/string_stats.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/main/error_manager.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "utf8proc_wrapper.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +BaseStatistics StringStats::CreateUnknown(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + auto &string_data = StringStats::GetDataUnsafe(result); + for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { + string_data.min[i] = 0; + string_data.max[i] = 0xFF; + } + string_data.max_string_length = 0; + string_data.has_max_string_length = false; + string_data.has_unicode = true; + return result; +} + +BaseStatistics StringStats::CreateEmpty(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + auto &string_data = StringStats::GetDataUnsafe(result); + for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { + string_data.min[i] = 0xFF; + string_data.max[i] = 0; + } + string_data.max_string_length = 0; + string_data.has_max_string_length = true; + string_data.has_unicode = false; + return result; +} + +StringStatsData &StringStats::GetDataUnsafe(BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::STRING_STATS); + return stats.stats_union.string_data; +} + +const StringStatsData &StringStats::GetDataUnsafe(const BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::STRING_STATS); + return stats.stats_union.string_data; +} + +bool StringStats::HasMaxStringLength(const BaseStatistics &stats) { + if (stats.GetType().id() == LogicalTypeId::SQLNULL) { + return false; + } + return StringStats::GetDataUnsafe(stats).has_max_string_length; +} + +uint32_t StringStats::MaxStringLength(const BaseStatistics &stats) { + if (!HasMaxStringLength(stats)) { + throw InternalException("MaxStringLength called on statistics that does not have a max string length"); + } + return StringStats::GetDataUnsafe(stats).max_string_length; +} + +bool StringStats::CanContainUnicode(const BaseStatistics &stats) { + if (stats.GetType().id() == LogicalTypeId::SQLNULL) { + return true; + } + return StringStats::GetDataUnsafe(stats).has_unicode; +} + +string GetStringMinMaxValue(const data_t data[]) { + idx_t len; + for (len = 0; len < StringStatsData::MAX_STRING_MINMAX_SIZE; len++) { + if (!data[len]) { + break; + } + } + return string(const_char_ptr_cast(data), len); +} + +string StringStats::Min(const BaseStatistics &stats) { + return GetStringMinMaxValue(StringStats::GetDataUnsafe(stats).min); +} + +string StringStats::Max(const BaseStatistics &stats) { + return GetStringMinMaxValue(StringStats::GetDataUnsafe(stats).max); +} + +void StringStats::ResetMaxStringLength(BaseStatistics &stats) { + StringStats::GetDataUnsafe(stats).has_max_string_length = false; +} + +void StringStats::SetContainsUnicode(BaseStatistics &stats) { + StringStats::GetDataUnsafe(stats).has_unicode = true; +} + +void StringStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + auto &string_data = StringStats::GetDataUnsafe(stats); + serializer.WriteProperty(200, "min", string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); + serializer.WriteProperty(201, "max", string_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); + serializer.WriteProperty(202, "has_unicode", string_data.has_unicode); + serializer.WriteProperty(203, "has_max_string_length", string_data.has_max_string_length); + serializer.WriteProperty(204, "max_string_length", string_data.max_string_length); +} + +void StringStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { + auto &string_data = StringStats::GetDataUnsafe(base); + deserializer.ReadProperty(200, "min", string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); + deserializer.ReadProperty(201, "max", string_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); + deserializer.ReadProperty(202, "has_unicode", string_data.has_unicode); + deserializer.ReadProperty(203, "has_max_string_length", string_data.has_max_string_length); + deserializer.ReadProperty(204, "max_string_length", string_data.max_string_length); +} + +static int StringValueComparison(const_data_ptr_t data, idx_t len, const_data_ptr_t comparison) { + D_ASSERT(len <= StringStatsData::MAX_STRING_MINMAX_SIZE); + for (idx_t i = 0; i < len; i++) { + if (data[i] < comparison[i]) { + return -1; + } else if (data[i] > comparison[i]) { + return 1; + } + } + return 0; +} + +static void ConstructValue(const_data_ptr_t data, idx_t size, data_t target[]) { + idx_t value_size = size > StringStatsData::MAX_STRING_MINMAX_SIZE ? StringStatsData::MAX_STRING_MINMAX_SIZE : size; + memcpy(target, data, value_size); + for (idx_t i = value_size; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { + target[i] = '\0'; + } +} + +void StringStats::Update(BaseStatistics &stats, const string_t &value) { + auto data = const_data_ptr_cast(value.GetData()); + auto size = value.GetSize(); + + //! we can only fit 8 bytes, so we might need to trim our string + // construct the value + data_t target[StringStatsData::MAX_STRING_MINMAX_SIZE]; + ConstructValue(data, size, target); + + // update the min and max + auto &string_data = StringStats::GetDataUnsafe(stats); + if (StringValueComparison(target, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.min) < 0) { + memcpy(string_data.min, target, StringStatsData::MAX_STRING_MINMAX_SIZE); + } + if (StringValueComparison(target, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max) > 0) { + memcpy(string_data.max, target, StringStatsData::MAX_STRING_MINMAX_SIZE); + } + if (size > string_data.max_string_length) { + string_data.max_string_length = size; + } + if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) { + auto unicode = Utf8Proc::Analyze(const_char_ptr_cast(data), size); + if (unicode == UnicodeType::UNICODE) { + string_data.has_unicode = true; + } else if (unicode == UnicodeType::INVALID) { + throw InvalidInputException(ErrorManager::InvalidUnicodeError(string(const_char_ptr_cast(data), size), + "segment statistics update")); + } + } +} + +void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + auto &string_data = StringStats::GetDataUnsafe(stats); + auto &other_data = StringStats::GetDataUnsafe(other); + if (StringValueComparison(other_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.min) < 0) { + memcpy(string_data.min, other_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); + } + if (StringValueComparison(other_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max) > 0) { + memcpy(string_data.max, other_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); + } + string_data.has_unicode = string_data.has_unicode || other_data.has_unicode; + string_data.has_max_string_length = string_data.has_max_string_length && other_data.has_max_string_length; + string_data.max_string_length = MaxValue(string_data.max_string_length, other_data.max_string_length); +} + +FilterPropagateResult StringStats::CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, + const string &constant) { + auto &string_data = StringStats::GetDataUnsafe(stats); + auto data = const_data_ptr_cast(constant.c_str()); + auto size = constant.size(); + + idx_t value_size = size > StringStatsData::MAX_STRING_MINMAX_SIZE ? StringStatsData::MAX_STRING_MINMAX_SIZE : size; + int min_comp = StringValueComparison(data, value_size, string_data.min); + int max_comp = StringValueComparison(data, value_size, string_data.max); + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: + if (min_comp >= 0 && max_comp <= 0) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + case ExpressionType::COMPARE_NOTEQUAL: + if (min_comp < 0 || max_comp > 0) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + if (max_comp <= 0) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + if (min_comp >= 0) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } else { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + default: + throw InternalException("Expression type not implemented for string statistics zone map"); + } +} + +static idx_t GetValidMinMaxSubstring(const_data_ptr_t data) { + for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { + if (data[i] == '\0') { + return i; + } + if ((data[i] & 0x80) != 0) { + return i; + } + } + return StringStatsData::MAX_STRING_MINMAX_SIZE; +} + +string StringStats::ToString(const BaseStatistics &stats) { + auto &string_data = StringStats::GetDataUnsafe(stats); + idx_t min_len = GetValidMinMaxSubstring(string_data.min); + idx_t max_len = GetValidMinMaxSubstring(string_data.max); + return StringUtil::Format("[Min: %s, Max: %s, Has Unicode: %s, Max String Length: %s]", + string(const_char_ptr_cast(string_data.min), min_len), + string(const_char_ptr_cast(string_data.max), max_len), + string_data.has_unicode ? "true" : "false", + string_data.has_max_string_length ? to_string(string_data.max_string_length) : "?"); +} + +void StringStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + auto &string_data = StringStats::GetDataUnsafe(stats); + + UnifiedVectorFormat vdata; + vector.ToUnifiedFormat(count, vdata); + auto data = UnifiedVectorFormat::GetData(vdata); + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto index = vdata.sel->get_index(idx); + if (!vdata.validity.RowIsValid(index)) { + continue; + } + auto value = data[index]; + auto data = value.GetData(); + auto len = value.GetSize(); + // LCOV_EXCL_START + if (string_data.has_max_string_length && len > string_data.max_string_length) { + throw InternalException( + "Statistics mismatch: string value exceeds maximum string length.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString(count)); + } + if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) { + auto unicode = Utf8Proc::Analyze(data, len); + if (unicode == UnicodeType::UNICODE) { + throw InternalException("Statistics mismatch: string value contains unicode, but statistics says it " + "shouldn't.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString(count)); + } else if (unicode == UnicodeType::INVALID) { + throw InternalException("Invalid unicode detected in vector: %s", vector.ToString(count)); + } + } + if (StringValueComparison(const_data_ptr_cast(data), + MinValue(len, StringStatsData::MAX_STRING_MINMAX_SIZE), string_data.min) < 0) { + throw InternalException("Statistics mismatch: value is smaller than min.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString(count)); + } + if (StringValueComparison(const_data_ptr_cast(data), + MinValue(len, StringStatsData::MAX_STRING_MINMAX_SIZE), string_data.max) > 0) { + throw InternalException("Statistics mismatch: value is bigger than max.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString(count)); + } + // LCOV_EXCL_STOP + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/struct_stats.cpp b/src/duckdb/src/storage/statistics/struct_stats.cpp new file mode 100644 index 00000000..38096384 --- /dev/null +++ b/src/duckdb/src/storage/statistics/struct_stats.cpp @@ -0,0 +1,138 @@ +#include "duckdb/storage/statistics/struct_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/types/vector.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +void StructStats::Construct(BaseStatistics &stats) { + auto &child_types = StructType::GetChildTypes(stats.GetType()); + stats.child_stats = unsafe_unique_array(new BaseStatistics[child_types.size()]); + for (idx_t i = 0; i < child_types.size(); i++) { + BaseStatistics::Construct(stats.child_stats[i], child_types[i].second); + } +} + +BaseStatistics StructStats::CreateUnknown(LogicalType type) { + auto &child_types = StructType::GetChildTypes(type); + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + for (idx_t i = 0; i < child_types.size(); i++) { + result.child_stats[i].Copy(BaseStatistics::CreateUnknown(child_types[i].second)); + } + return result; +} + +BaseStatistics StructStats::CreateEmpty(LogicalType type) { + auto &child_types = StructType::GetChildTypes(type); + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + for (idx_t i = 0; i < child_types.size(); i++) { + result.child_stats[i].Copy(BaseStatistics::CreateEmpty(child_types[i].second)); + } + return result; +} + +const BaseStatistics *StructStats::GetChildStats(const BaseStatistics &stats) { + if (stats.GetStatsType() != StatisticsType::STRUCT_STATS) { + throw InternalException("Calling StructStats::GetChildStats on stats that is not a struct"); + } + return stats.child_stats.get(); +} + +const BaseStatistics &StructStats::GetChildStats(const BaseStatistics &stats, idx_t i) { + D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS); + if (i >= StructType::GetChildCount(stats.GetType())) { + throw InternalException("Calling StructStats::GetChildStats but there are no stats for this index"); + } + return stats.child_stats[i]; +} + +BaseStatistics &StructStats::GetChildStats(BaseStatistics &stats, idx_t i) { + D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS); + if (i >= StructType::GetChildCount(stats.GetType())) { + throw InternalException("Calling StructStats::GetChildStats but there are no stats for this index"); + } + return stats.child_stats[i]; +} + +void StructStats::SetChildStats(BaseStatistics &stats, idx_t i, const BaseStatistics &new_stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS); + D_ASSERT(i < StructType::GetChildCount(stats.GetType())); + stats.child_stats[i].Copy(new_stats); +} + +void StructStats::SetChildStats(BaseStatistics &stats, idx_t i, unique_ptr new_stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::STRUCT_STATS); + if (!new_stats) { + StructStats::SetChildStats(stats, i, + BaseStatistics::CreateUnknown(StructType::GetChildType(stats.GetType(), i))); + } else { + StructStats::SetChildStats(stats, i, *new_stats); + } +} + +void StructStats::Copy(BaseStatistics &stats, const BaseStatistics &other) { + auto count = StructType::GetChildCount(stats.GetType()); + for (idx_t i = 0; i < count; i++) { + stats.child_stats[i].Copy(other.child_stats[i]); + } +} + +void StructStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + D_ASSERT(stats.GetType() == other.GetType()); + auto child_count = StructType::GetChildCount(stats.GetType()); + for (idx_t i = 0; i < child_count; i++) { + stats.child_stats[i].Merge(other.child_stats[i]); + } +} + +void StructStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + auto child_stats = StructStats::GetChildStats(stats); + auto child_count = StructType::GetChildCount(stats.GetType()); + + serializer.WriteList(200, "child_stats", child_count, + [&](Serializer::List &list, idx_t i) { list.WriteElement(child_stats[i]); }); +} + +void StructStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { + auto &type = base.GetType(); + D_ASSERT(type.InternalType() == PhysicalType::STRUCT); + + auto &child_types = StructType::GetChildTypes(type); + + deserializer.ReadList(200, "child_stats", [&](Deserializer::List &list, idx_t i) { + deserializer.Set(const_cast(child_types[i].second)); + auto stat = list.ReadElement(); + base.child_stats[i].Copy(stat); + deserializer.Unset(); + }); +} + +string StructStats::ToString(const BaseStatistics &stats) { + string result; + result += " {"; + auto &child_types = StructType::GetChildTypes(stats.GetType()); + for (idx_t i = 0; i < child_types.size(); i++) { + if (i > 0) { + result += ", "; + } + result += child_types[i].first + ": " + stats.child_stats[i].ToString(); + } + result += "}"; + return result; +} + +void StructStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + auto &child_entries = StructVector::GetEntries(vector); + for (idx_t i = 0; i < child_entries.size(); i++) { + stats.child_stats[i].Verify(*child_entries[i], sel, count); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/storage_info.cpp b/src/duckdb/src/storage/storage_info.cpp new file mode 100644 index 00000000..e3f0fffc --- /dev/null +++ b/src/duckdb/src/storage/storage_info.cpp @@ -0,0 +1,40 @@ +#include "duckdb/storage/storage_info.hpp" + +namespace duckdb { + +const uint64_t VERSION_NUMBER = 64; + +struct StorageVersionInfo { + const char *version_name; + idx_t storage_version; +}; + +static StorageVersionInfo storage_version_info[] = {{"v0.8.0 or v0.8.1", 51}, + {"v0.7.0 or v0.7.1", 43}, + {"v0.6.0 or v0.6.1", 39}, + {"v0.5.0 or v0.5.1", 38}, + {"v0.3.3, v0.3.4 or v0.4.0", 33}, + {"v0.3.2", 31}, + {"v0.3.1", 27}, + {"v0.3.0", 25}, + {"v0.2.9", 21}, + {"v0.2.8", 18}, + {"v0.2.7", 17}, + {"v0.2.6", 15}, + {"v0.2.5", 13}, + {"v0.2.4", 11}, + {"v0.2.3", 6}, + {"v0.2.2", 4}, + {"v0.2.1 and prior", 1}, + {nullptr, 0}}; + +const char *GetDuckDBVersion(idx_t version_number) { + for (idx_t i = 0; storage_version_info[i].version_name; i++) { + if (version_number == storage_version_info[i].storage_version) { + return storage_version_info[i].version_name; + } + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/storage_lock.cpp b/src/duckdb/src/storage/storage_lock.cpp new file mode 100644 index 00000000..1b7b1577 --- /dev/null +++ b/src/duckdb/src/storage/storage_lock.cpp @@ -0,0 +1,44 @@ +#include "duckdb/storage/storage_lock.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/assert.hpp" + +namespace duckdb { + +StorageLockKey::StorageLockKey(StorageLock &lock, StorageLockType type) : lock(lock), type(type) { +} + +StorageLockKey::~StorageLockKey() { + if (type == StorageLockType::EXCLUSIVE) { + lock.ReleaseExclusiveLock(); + } else { + D_ASSERT(type == StorageLockType::SHARED); + lock.ReleaseSharedLock(); + } +} + +StorageLock::StorageLock() : read_count(0) { +} + +unique_ptr StorageLock::GetExclusiveLock() { + exclusive_lock.lock(); + while (read_count != 0) { + } + return make_uniq(*this, StorageLockType::EXCLUSIVE); +} + +unique_ptr StorageLock::GetSharedLock() { + exclusive_lock.lock(); + read_count++; + exclusive_lock.unlock(); + return make_uniq(*this, StorageLockType::SHARED); +} + +void StorageLock::ReleaseExclusiveLock() { + exclusive_lock.unlock(); +} + +void StorageLock::ReleaseSharedLock() { + read_count--; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/storage_manager.cpp b/src/duckdb/src/storage/storage_manager.cpp new file mode 100644 index 00000000..088f018c --- /dev/null +++ b/src/duckdb/src/storage/storage_manager.cpp @@ -0,0 +1,290 @@ +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/storage/checkpoint_manager.hpp" +#include "duckdb/storage/in_memory_block_manager.hpp" +#include "duckdb/storage/single_file_block_manager.hpp" +#include "duckdb/storage/object_cache.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/transaction/transaction_manager.hpp" +#include "duckdb/common/serializer/buffered_file_reader.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database_manager.hpp" + +namespace duckdb { + +StorageManager::StorageManager(AttachedDatabase &db, string path_p, bool read_only) + : db(db), path(std::move(path_p)), read_only(read_only) { + if (path.empty()) { + path = ":memory:"; + } else { + auto &fs = FileSystem::Get(db); + this->path = fs.ExpandPath(path); + } +} + +StorageManager::~StorageManager() { +} + +StorageManager &StorageManager::Get(AttachedDatabase &db) { + return db.GetStorageManager(); +} +StorageManager &StorageManager::Get(Catalog &catalog) { + return StorageManager::Get(catalog.GetAttached()); +} + +DatabaseInstance &StorageManager::GetDatabase() { + return db.GetDatabase(); +} + +BufferManager &BufferManager::GetBufferManager(ClientContext &context) { + return BufferManager::GetBufferManager(*context.db); +} + +ObjectCache &ObjectCache::GetObjectCache(ClientContext &context) { + return context.db->GetObjectCache(); +} + +bool ObjectCache::ObjectCacheEnabled(ClientContext &context) { + return context.db->config.options.object_cache_enable; +} + +bool StorageManager::InMemory() { + D_ASSERT(!path.empty()); + return path == ":memory:"; +} + +void StorageManager::Initialize() { + bool in_memory = InMemory(); + if (in_memory && read_only) { + throw CatalogException("Cannot launch in-memory database in read-only mode!"); + } + + // create or load the database from disk, if not in-memory mode + LoadDatabase(); +} + +/////////////////////////////////////////////////////////////////////////// +class SingleFileTableIOManager : public TableIOManager { +public: + explicit SingleFileTableIOManager(BlockManager &block_manager) : block_manager(block_manager) { + } + + BlockManager &block_manager; + +public: + BlockManager &GetIndexBlockManager() override { + return block_manager; + } + BlockManager &GetBlockManagerForRowData() override { + return block_manager; + } + MetadataManager &GetMetadataManager() override { + return block_manager.GetMetadataManager(); + } +}; + +SingleFileStorageManager::SingleFileStorageManager(AttachedDatabase &db, string path, bool read_only) + : StorageManager(db, std::move(path), read_only) { +} + +void SingleFileStorageManager::LoadDatabase() { + if (InMemory()) { + block_manager = make_uniq(BufferManager::GetBufferManager(db)); + table_io_manager = make_uniq(*block_manager); + return; + } + std::size_t question_mark_pos = path.find('?'); + auto wal_path = path; + if (question_mark_pos != std::string::npos) { + wal_path.insert(question_mark_pos, ".wal"); + } else { + wal_path += ".wal"; + } + auto &fs = FileSystem::Get(db); + auto &config = DBConfig::Get(db); + bool truncate_wal = false; + if (!config.options.enable_external_access) { + if (!db.IsInitialDatabase()) { + throw PermissionException("Attaching on-disk databases is disabled through configuration"); + } + } + + StorageManagerOptions options; + options.read_only = read_only; + options.use_direct_io = config.options.use_direct_io; + options.debug_initialize = config.options.debug_initialize; + // first check if the database exists + if (!fs.FileExists(path)) { + if (read_only) { + throw CatalogException("Cannot open database \"%s\" in read-only mode: database does not exist", path); + } + // check if the WAL exists + if (fs.FileExists(wal_path)) { + // WAL file exists but database file does not + // remove the WAL + fs.RemoveFile(wal_path); + } + // initialize the block manager while creating a new db file + auto sf_block_manager = make_uniq(db, path, options); + sf_block_manager->CreateNewDatabase(); + block_manager = std::move(sf_block_manager); + table_io_manager = make_uniq(*block_manager); + } else { + // initialize the block manager while loading the current db file + auto sf_block_manager = make_uniq(db, path, options); + sf_block_manager->LoadExistingDatabase(); + block_manager = std::move(sf_block_manager); + table_io_manager = make_uniq(*block_manager); + + //! Load from storage + auto checkpointer = SingleFileCheckpointReader(*this); + checkpointer.LoadFromStorage(); + // check if the WAL file exists + if (fs.FileExists(wal_path)) { + // replay the WAL + truncate_wal = WriteAheadLog::Replay(db, wal_path); + } + } + // initialize the WAL file + if (!read_only) { + wal = make_uniq(db, wal_path); + if (truncate_wal) { + wal->Truncate(0); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +class SingleFileStorageCommitState : public StorageCommitState { + idx_t initial_wal_size = 0; + idx_t initial_written = 0; + optional_ptr log; + bool checkpoint; + +public: + SingleFileStorageCommitState(StorageManager &storage_manager, bool checkpoint); + ~SingleFileStorageCommitState() override { + // If log is non-null, then commit threw an exception before flushing. + if (log) { + auto &wal = *log.get(); + wal.skip_writing = false; + if (wal.GetTotalWritten() > initial_written) { + // remove any entries written into the WAL by truncating it + wal.Truncate(initial_wal_size); + } + } + } + + // Make the commit persistent + void FlushCommit() override; +}; + +SingleFileStorageCommitState::SingleFileStorageCommitState(StorageManager &storage_manager, bool checkpoint) + : checkpoint(checkpoint) { + log = storage_manager.GetWriteAheadLog(); + if (log) { + auto initial_size = log->GetWALSize(); + initial_written = log->GetTotalWritten(); + initial_wal_size = initial_size < 0 ? 0 : idx_t(initial_size); + + if (checkpoint) { + // check if we are checkpointing after this commit + // if we are checkpointing, we don't need to write anything to the WAL + // this saves us a lot of unnecessary writes to disk in the case of large commits + log->skip_writing = true; + } + } else { + D_ASSERT(!checkpoint); + } +} + +// Make the commit persistent +void SingleFileStorageCommitState::FlushCommit() { + if (log) { + // flush the WAL if any changes were made + if (log->GetTotalWritten() > initial_written) { + (void)checkpoint; + D_ASSERT(!checkpoint); + D_ASSERT(!log->skip_writing); + log->Flush(); + } + log->skip_writing = false; + } + // Null so that the destructor will not truncate the log. + log = nullptr; +} + +unique_ptr SingleFileStorageManager::GenStorageCommitState(Transaction &transaction, + bool checkpoint) { + return make_uniq(*this, checkpoint); +} + +bool SingleFileStorageManager::IsCheckpointClean(MetaBlockPointer checkpoint_id) { + return block_manager->IsRootBlock(checkpoint_id); +} + +void SingleFileStorageManager::CreateCheckpoint(bool delete_wal, bool force_checkpoint) { + if (InMemory() || read_only || !wal) { + return; + } + auto &config = DBConfig::Get(db); + if (wal->GetWALSize() > 0 || config.options.force_checkpoint || force_checkpoint) { + // we only need to checkpoint if there is anything in the WAL + try { + SingleFileCheckpointWriter checkpointer(db, *block_manager); + checkpointer.CreateCheckpoint(); + } catch (std::exception &ex) { + throw FatalException("Failed to create checkpoint because of error: %s", ex.what()); + } + } + if (delete_wal) { + wal->Delete(); + wal.reset(); + } +} + +DatabaseSize SingleFileStorageManager::GetDatabaseSize() { + // All members default to zero + DatabaseSize ds; + if (!InMemory()) { + ds.total_blocks = block_manager->TotalBlocks(); + ds.block_size = Storage::BLOCK_ALLOC_SIZE; + ds.free_blocks = block_manager->FreeBlocks(); + ds.used_blocks = ds.total_blocks - ds.free_blocks; + ds.bytes = (ds.total_blocks * ds.block_size); + if (auto wal = GetWriteAheadLog()) { + ds.wal_size = wal->GetWALSize(); + } + } + return ds; +} + +vector SingleFileStorageManager::GetMetadataInfo() { + auto &metadata_manager = block_manager->GetMetadataManager(); + return metadata_manager.GetMetadataInfo(); +} + +bool SingleFileStorageManager::AutomaticCheckpoint(idx_t estimated_wal_bytes) { + auto log = GetWriteAheadLog(); + if (!log) { + return false; + } + + auto &config = DBConfig::Get(db); + auto initial_size = log->GetWALSize(); + idx_t expected_wal_size = initial_size + estimated_wal_bytes; + return expected_wal_size > config.options.checkpoint_wal_size; +} + +shared_ptr SingleFileStorageManager::GetTableIOManager(BoundCreateTableInfo *info /*info*/) { + // This is an unmanaged reference. No ref/deref overhead. Lifetime of the + // TableIoManager follows lifetime of the StorageManager (this). + return shared_ptr(shared_ptr(nullptr), table_io_manager.get()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/chunk_info.cpp b/src/duckdb/src/storage/table/chunk_info.cpp new file mode 100644 index 00000000..fa19b502 --- /dev/null +++ b/src/duckdb/src/storage/table/chunk_info.cpp @@ -0,0 +1,285 @@ +#include "duckdb/storage/table/chunk_info.hpp" +#include "duckdb/transaction/transaction.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" + +namespace duckdb { + +struct TransactionVersionOperator { + static bool UseInsertedVersion(transaction_t start_time, transaction_t transaction_id, transaction_t id) { + return id < start_time || id == transaction_id; + } + + static bool UseDeletedVersion(transaction_t start_time, transaction_t transaction_id, transaction_t id) { + return !UseInsertedVersion(start_time, transaction_id, id); + } +}; + +struct CommittedVersionOperator { + static bool UseInsertedVersion(transaction_t start_time, transaction_t transaction_id, transaction_t id) { + return true; + } + + static bool UseDeletedVersion(transaction_t min_start_time, transaction_t min_transaction_id, transaction_t id) { + return (id >= min_start_time && id < TRANSACTION_ID_START) || (id >= min_transaction_id); + } +}; + +static bool UseVersion(TransactionData transaction, transaction_t id) { + return TransactionVersionOperator::UseInsertedVersion(transaction.start_time, transaction.transaction_id, id); +} + +void ChunkInfo::Write(WriteStream &writer) const { + writer.Write(type); +} + +unique_ptr ChunkInfo::Read(ReadStream &reader) { + auto type = reader.Read(); + switch (type) { + case ChunkInfoType::EMPTY_INFO: + return nullptr; + case ChunkInfoType::CONSTANT_INFO: + return ChunkConstantInfo::Read(reader); + case ChunkInfoType::VECTOR_INFO: + return ChunkVectorInfo::Read(reader); + default: + throw SerializationException("Could not deserialize Chunk Info Type: unrecognized type"); + } +} + +//===--------------------------------------------------------------------===// +// Constant info +//===--------------------------------------------------------------------===// +ChunkConstantInfo::ChunkConstantInfo(idx_t start) + : ChunkInfo(start, ChunkInfoType::CONSTANT_INFO), insert_id(0), delete_id(NOT_DELETED_ID) { +} + +template +idx_t ChunkConstantInfo::TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, + SelectionVector &sel_vector, idx_t max_count) const { + if (OP::UseInsertedVersion(start_time, transaction_id, insert_id) && + OP::UseDeletedVersion(start_time, transaction_id, delete_id)) { + return max_count; + } + return 0; +} + +idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { + return TemplatedGetSelVector(transaction.start_time, transaction.transaction_id, + sel_vector, max_count); +} + +idx_t ChunkConstantInfo::GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, + SelectionVector &sel_vector, idx_t max_count) { + return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count); +} + +bool ChunkConstantInfo::Fetch(TransactionData transaction, row_t row) { + return UseVersion(transaction, insert_id) && !UseVersion(transaction, delete_id); +} + +void ChunkConstantInfo::CommitAppend(transaction_t commit_id, idx_t start, idx_t end) { + D_ASSERT(start == 0 && end == STANDARD_VECTOR_SIZE); + insert_id = commit_id; +} + +bool ChunkConstantInfo::HasDeletes() const { + bool is_deleted = insert_id >= TRANSACTION_ID_START || delete_id < TRANSACTION_ID_START; + return is_deleted; +} + +idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) { + return delete_id < TRANSACTION_ID_START ? max_count : 0; +} + +void ChunkConstantInfo::Write(WriteStream &writer) const { + D_ASSERT(HasDeletes()); + ChunkInfo::Write(writer); + writer.Write(start); +} + +unique_ptr ChunkConstantInfo::Read(ReadStream &reader) { + auto start = reader.Read(); + auto info = make_uniq(start); + info->insert_id = 0; + info->delete_id = 0; + return std::move(info); +} + +//===--------------------------------------------------------------------===// +// Vector info +//===--------------------------------------------------------------------===// +ChunkVectorInfo::ChunkVectorInfo(idx_t start) + : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), insert_id(0), same_inserted_id(true), any_deleted(false) { + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + inserted[i] = 0; + deleted[i] = NOT_DELETED_ID; + } +} + +template +idx_t ChunkVectorInfo::TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, + SelectionVector &sel_vector, idx_t max_count) const { + idx_t count = 0; + if (same_inserted_id && !any_deleted) { + // all tuples have the same inserted id: and no tuples were deleted + if (OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { + return max_count; + } else { + return 0; + } + } else if (same_inserted_id) { + if (!OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { + return 0; + } + // have to check deleted flag + for (idx_t i = 0; i < max_count; i++) { + if (OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { + sel_vector.set_index(count++, i); + } + } + } else if (!any_deleted) { + // have to check inserted flag + for (idx_t i = 0; i < max_count; i++) { + if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i])) { + sel_vector.set_index(count++, i); + } + } + } else { + // have to check both flags + for (idx_t i = 0; i < max_count; i++) { + if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) && + OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { + sel_vector.set_index(count++, i); + } + } + } + return count; +} + +idx_t ChunkVectorInfo::GetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, + idx_t max_count) const { + return TemplatedGetSelVector(start_time, transaction_id, sel_vector, max_count); +} + +idx_t ChunkVectorInfo::GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, + SelectionVector &sel_vector, idx_t max_count) { + return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count); +} + +idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { + return GetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count); +} + +bool ChunkVectorInfo::Fetch(TransactionData transaction, row_t row) { + return UseVersion(transaction, inserted[row]) && !UseVersion(transaction, deleted[row]); +} + +idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t count) { + any_deleted = true; + + idx_t deleted_tuples = 0; + for (idx_t i = 0; i < count; i++) { + if (deleted[rows[i]] == transaction_id) { + continue; + } + // first check the chunk for conflicts + if (deleted[rows[i]] != NOT_DELETED_ID) { + // tuple was already deleted by another transaction + throw TransactionException("Conflict on tuple deletion!"); + } + // after verifying that there are no conflicts we mark the tuple as deleted + deleted[rows[i]] = transaction_id; + rows[deleted_tuples] = rows[i]; + deleted_tuples++; + } + return deleted_tuples; +} + +void ChunkVectorInfo::CommitDelete(transaction_t commit_id, row_t rows[], idx_t count) { + for (idx_t i = 0; i < count; i++) { + deleted[rows[i]] = commit_id; + } +} + +void ChunkVectorInfo::Append(idx_t start, idx_t end, transaction_t commit_id) { + if (start == 0) { + insert_id = commit_id; + } else if (insert_id != commit_id) { + same_inserted_id = false; + insert_id = NOT_DELETED_ID; + } + for (idx_t i = start; i < end; i++) { + inserted[i] = commit_id; + } +} + +void ChunkVectorInfo::CommitAppend(transaction_t commit_id, idx_t start, idx_t end) { + if (same_inserted_id) { + insert_id = commit_id; + } + for (idx_t i = start; i < end; i++) { + inserted[i] = commit_id; + } +} + +bool ChunkVectorInfo::HasDeletes() const { + return any_deleted; +} + +idx_t ChunkVectorInfo::GetCommittedDeletedCount(idx_t max_count) { + if (!any_deleted) { + return 0; + } + idx_t delete_count = 0; + for (idx_t i = 0; i < max_count; i++) { + if (deleted[i] < TRANSACTION_ID_START) { + delete_count++; + } + } + return delete_count; +} + +void ChunkVectorInfo::Write(WriteStream &writer) const { + SelectionVector sel(STANDARD_VECTOR_SIZE); + transaction_t start_time = TRANSACTION_ID_START - 1; + transaction_t transaction_id = DConstants::INVALID_INDEX; + idx_t count = GetSelVector(start_time, transaction_id, sel, STANDARD_VECTOR_SIZE); + if (count == STANDARD_VECTOR_SIZE) { + // nothing is deleted: skip writing anything + writer.Write(ChunkInfoType::EMPTY_INFO); + return; + } + if (count == 0) { + // everything is deleted: write a constant vector + writer.Write(ChunkInfoType::CONSTANT_INFO); + writer.Write(start); + return; + } + // write a boolean vector + ChunkInfo::Write(writer); + writer.Write(start); + ValidityMask mask(STANDARD_VECTOR_SIZE); + mask.Initialize(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < count; i++) { + mask.SetInvalid(sel.get_index(i)); + } + mask.Write(writer, STANDARD_VECTOR_SIZE); +} + +unique_ptr ChunkVectorInfo::Read(ReadStream &reader) { + auto start = reader.Read(); + auto result = make_uniq(start); + result->any_deleted = true; + ValidityMask mask; + mask.Read(reader, STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + if (mask.RowIsValid(i)) { + result->deleted[i] = 0; + } + } + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/column_checkpoint_state.cpp b/src/duckdb/src/storage/table/column_checkpoint_state.cpp new file mode 100644 index 00000000..88d803b7 --- /dev/null +++ b/src/duckdb/src/storage/table/column_checkpoint_state.cpp @@ -0,0 +1,197 @@ +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/storage/checkpoint/table_data_writer.hpp" + +#include "duckdb/main/config.hpp" + +namespace duckdb { + +ColumnCheckpointState::ColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, + PartialBlockManager &partial_block_manager) + : row_group(row_group), column_data(column_data), partial_block_manager(partial_block_manager) { +} + +ColumnCheckpointState::~ColumnCheckpointState() { +} + +unique_ptr ColumnCheckpointState::GetStatistics() { + D_ASSERT(global_stats); + return std::move(global_stats); +} + +PartialBlockForCheckpoint::PartialBlockForCheckpoint(ColumnData &data, ColumnSegment &segment, PartialBlockState state, + BlockManager &block_manager) + : PartialBlock(state, block_manager, segment.block) { + AddSegmentToTail(data, segment, 0); +} + +PartialBlockForCheckpoint::~PartialBlockForCheckpoint() { + D_ASSERT(IsFlushed() || Exception::UncaughtException()); +} + +bool PartialBlockForCheckpoint::IsFlushed() { + // segments are cleared on Flush + return segments.empty(); +} + +void PartialBlockForCheckpoint::Flush(const idx_t free_space_left) { + + if (IsFlushed()) { + throw InternalException("Flush called on partial block that was already flushed"); + } + + // zero-initialize unused memory + FlushInternal(free_space_left); + + // At this point, we've already copied all data from tail_segments + // into the page owned by first_segment. We flush all segment data to + // disk with the following call. + // persist the first segment to disk and point the remaining segments to the same block + bool fetch_new_block = state.block_id == INVALID_BLOCK; + if (fetch_new_block) { + state.block_id = block_manager.GetFreeBlockId(); + } + + for (idx_t i = 0; i < segments.size(); i++) { + auto &segment = segments[i]; + segment.data.IncrementVersion(); + if (i == 0) { + // the first segment is converted to persistent - this writes the data for ALL segments to disk + D_ASSERT(segment.offset_in_block == 0); + segment.segment.ConvertToPersistent(&block_manager, state.block_id); + // update the block after it has been converted to a persistent segment + block_handle = segment.segment.block; + } else { + // subsequent segments are MARKED as persistent - they don't need to be rewritten + segment.segment.MarkAsPersistent(block_handle, segment.offset_in_block); + if (fetch_new_block) { + // if we fetched a new block we need to increase the reference count to the block + block_manager.IncreaseBlockReferenceCount(state.block_id); + } + } + } + + Clear(); +} + +void PartialBlockForCheckpoint::Merge(PartialBlock &other_p, idx_t offset, idx_t other_size) { + auto &other = other_p.Cast(); + + auto &buffer_manager = block_manager.buffer_manager; + // pin the source block + auto old_handle = buffer_manager.Pin(other.block_handle); + // pin the target block + auto new_handle = buffer_manager.Pin(block_handle); + // memcpy the contents of the old block to the new block + memcpy(new_handle.Ptr() + offset, old_handle.Ptr(), other_size); + + // now copy over all segments to the new block + // move over the uninitialized regions + for (auto ®ion : other.uninitialized_regions) { + region.start += offset; + region.end += offset; + uninitialized_regions.push_back(region); + } + + // move over the segments + for (auto &segment : other.segments) { + AddSegmentToTail(segment.data, segment.segment, segment.offset_in_block + offset); + } + + other.Clear(); +} + +void PartialBlockForCheckpoint::AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block) { + segments.emplace_back(data, segment, offset_in_block); +} + +void PartialBlockForCheckpoint::Clear() { + uninitialized_regions.clear(); + block_handle.reset(); + segments.clear(); +} + +void ColumnCheckpointState::FlushSegment(unique_ptr segment, idx_t segment_size) { + D_ASSERT(segment_size <= Storage::BLOCK_SIZE); + auto tuple_count = segment->count.load(); + if (tuple_count == 0) { // LCOV_EXCL_START + return; + } // LCOV_EXCL_STOP + + // merge the segment stats into the global stats + global_stats->Merge(segment->stats.statistics); + + // get the buffer of the segment and pin it + auto &db = column_data.GetDatabase(); + auto &buffer_manager = BufferManager::GetBufferManager(db); + block_id_t block_id = INVALID_BLOCK; + uint32_t offset_in_block = 0; + + if (!segment->stats.statistics.IsConstant()) { + // non-constant block + PartialBlockAllocation allocation = partial_block_manager.GetBlockAllocation(segment_size); + block_id = allocation.state.block_id; + offset_in_block = allocation.state.offset; + + if (allocation.partial_block) { + // Use an existing block. + D_ASSERT(offset_in_block > 0); + auto &pstate = allocation.partial_block->Cast(); + // pin the source block + auto old_handle = buffer_manager.Pin(segment->block); + // pin the target block + auto new_handle = buffer_manager.Pin(pstate.block_handle); + // memcpy the contents of the old block to the new block + memcpy(new_handle.Ptr() + offset_in_block, old_handle.Ptr(), segment_size); + pstate.AddSegmentToTail(column_data, *segment, offset_in_block); + } else { + // Create a new block for future reuse. + if (segment->SegmentSize() != Storage::BLOCK_SIZE) { + // the segment is smaller than the block size + // allocate a new block and copy the data over + D_ASSERT(segment->SegmentSize() < Storage::BLOCK_SIZE); + segment->Resize(Storage::BLOCK_SIZE); + } + D_ASSERT(offset_in_block == 0); + allocation.partial_block = make_uniq(column_data, *segment, allocation.state, + *allocation.block_manager); + } + // Writer will decide whether to reuse this block. + partial_block_manager.RegisterPartialBlock(std::move(allocation)); + } else { + // constant block: no need to write anything to disk besides the stats + // set up the compression function to constant + auto &config = DBConfig::GetConfig(db); + segment->function = + *config.GetCompressionFunction(CompressionType::COMPRESSION_CONSTANT, segment->type.InternalType()); + segment->ConvertToPersistent(nullptr, INVALID_BLOCK); + } + + // construct the data pointer + DataPointer data_pointer(segment->stats.statistics.Copy()); + data_pointer.block_pointer.block_id = block_id; + data_pointer.block_pointer.offset = offset_in_block; + data_pointer.row_start = row_group.start; + if (!data_pointers.empty()) { + auto &last_pointer = data_pointers.back(); + data_pointer.row_start = last_pointer.row_start + last_pointer.tuple_count; + } + data_pointer.tuple_count = tuple_count; + data_pointer.compression_type = segment->function.get().type; + if (segment->function.get().serialize_state) { + data_pointer.segment_state = segment->function.get().serialize_state(*segment); + } + + // append the segment to the new segment tree + new_tree.AppendSegment(std::move(segment)); + data_pointers.push_back(std::move(data_pointer)); +} + +void ColumnCheckpointState::WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) { + writer.WriteColumnDataPointers(*this, serializer); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/column_data.cpp b/src/duckdb/src/storage/table/column_data.cpp new file mode 100644 index 00000000..d6a1bd93 --- /dev/null +++ b/src/duckdb/src/storage/table/column_data.cpp @@ -0,0 +1,593 @@ +#include "duckdb/storage/table/column_data.hpp" + +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/data_pointer.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/statistics/distinct_statistics.hpp" +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/table/list_column_data.hpp" +#include "duckdb/storage/table/standard_column_data.hpp" + +#include "duckdb/storage/table/struct_column_data.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/common/serializer/read_stream.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" + +namespace duckdb { + +ColumnData::ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, + LogicalType type_p, optional_ptr parent) + : start(start_row), count(0), block_manager(block_manager), info(info), column_index(column_index), + type(std::move(type_p)), parent(parent), version(0) { + if (!parent) { + stats = make_uniq(type); + } +} + +ColumnData::~ColumnData() { +} + +void ColumnData::SetStart(idx_t new_start) { + this->start = new_start; + idx_t offset = 0; + for (auto &segment : data.Segments()) { + segment.start = start + offset; + offset += segment.count; + } + data.Reinitialize(); +} + +DatabaseInstance &ColumnData::GetDatabase() const { + return info.db.GetDatabase(); +} + +DataTableInfo &ColumnData::GetTableInfo() const { + return info; +} + +const LogicalType &ColumnData::RootType() const { + if (parent) { + return parent->RootType(); + } + return type; +} + +void ColumnData::IncrementVersion() { + version++; +} + +idx_t ColumnData::GetMaxEntry() { + return count; +} + +void ColumnData::InitializeScan(ColumnScanState &state) { + state.current = data.GetRootSegment(); + state.segment_tree = &data; + state.row_index = state.current ? state.current->start : 0; + state.internal_index = state.row_index; + state.initialized = false; + state.version = version; + state.scan_state.reset(); + state.last_offset = 0; +} + +void ColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { + state.current = data.GetSegment(row_idx); + state.segment_tree = &data; + state.row_index = row_idx; + state.internal_index = state.current->start; + state.initialized = false; + state.version = version; + state.scan_state.reset(); + state.last_offset = 0; +} + +idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, bool has_updates) { + state.previous_states.clear(); + if (state.version != version) { + InitializeScanWithOffset(state, state.row_index); + state.current->InitializeScan(state); + state.initialized = true; + } else if (!state.initialized) { + D_ASSERT(state.current); + state.current->InitializeScan(state); + state.internal_index = state.current->start; + state.initialized = true; + } + D_ASSERT(data.HasSegment(state.current)); + D_ASSERT(state.version == version); + D_ASSERT(state.internal_index <= state.row_index); + if (state.internal_index < state.row_index) { + state.current->Skip(state); + } + D_ASSERT(state.current->type == type); + idx_t initial_remaining = remaining; + while (remaining > 0) { + D_ASSERT(state.row_index >= state.current->start && + state.row_index <= state.current->start + state.current->count); + idx_t scan_count = MinValue(remaining, state.current->start + state.current->count - state.row_index); + idx_t result_offset = initial_remaining - remaining; + if (scan_count > 0) { + state.current->Scan(state, scan_count, result, result_offset, + !has_updates && scan_count == initial_remaining); + + state.row_index += scan_count; + remaining -= scan_count; + } + + if (remaining > 0) { + auto next = data.GetNextSegment(state.current); + if (!next) { + break; + } + state.previous_states.emplace_back(std::move(state.scan_state)); + state.current = next; + state.current->InitializeScan(state); + state.segment_checked = false; + D_ASSERT(state.row_index >= state.current->start && + state.row_index <= state.current->start + state.current->count); + } + } + state.internal_index = state.row_index; + return initial_remaining - remaining; +} + +template +idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) { + bool has_updates; + { + lock_guard update_guard(update_lock); + has_updates = updates ? true : false; + } + auto scan_count = ScanVector(state, result, STANDARD_VECTOR_SIZE, has_updates); + if (has_updates) { + lock_guard update_guard(update_lock); + if (!ALLOW_UPDATES && updates->HasUncommittedUpdates(vector_index)) { + throw TransactionException("Cannot create index with outstanding updates"); + } + result.Flatten(scan_count); + if (SCAN_COMMITTED) { + updates->FetchCommitted(vector_index, result); + } else { + updates->FetchUpdates(transaction, vector_index, result); + } + } + return scan_count; +} + +template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, + ColumnScanState &state, Vector &result); +template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, + ColumnScanState &state, Vector &result); +template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, + ColumnScanState &state, Vector &result); +template idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, + ColumnScanState &state, Vector &result); + +idx_t ColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) { + return ScanVector(transaction, vector_index, state, result); +} + +idx_t ColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) { + if (allow_updates) { + return ScanVector(TransactionData(0, 0), vector_index, state, result); + } else { + return ScanVector(TransactionData(0, 0), vector_index, state, result); + } +} + +void ColumnData::ScanCommittedRange(idx_t row_group_start, idx_t offset_in_row_group, idx_t count, Vector &result) { + ColumnScanState child_state; + InitializeScanWithOffset(child_state, row_group_start + offset_in_row_group); + auto scan_count = ScanVector(child_state, result, count, updates ? true : false); + if (updates) { + result.Flatten(scan_count); + updates->FetchCommittedRange(offset_in_row_group, count, result); + } +} + +idx_t ColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { + if (count == 0) { + return 0; + } + // ScanCount can only be used if there are no updates + D_ASSERT(!updates); + return ScanVector(state, result, count, false); +} + +void ColumnData::Select(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, + SelectionVector &sel, idx_t &count, const TableFilter &filter) { + idx_t scan_count = Scan(transaction, vector_index, state, result); + result.Flatten(scan_count); + ColumnSegment::FilterSelection(sel, result, filter, count, FlatVector::Validity(result)); +} + +void ColumnData::FilterScan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, + SelectionVector &sel, idx_t count) { + Scan(transaction, vector_index, state, result); + result.Slice(sel, count); +} + +void ColumnData::FilterScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, SelectionVector &sel, + idx_t count, bool allow_updates) { + ScanCommitted(vector_index, state, result, allow_updates); + result.Slice(sel, count); +} + +void ColumnData::Skip(ColumnScanState &state, idx_t count) { + state.Next(count); +} + +void ColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { + UnifiedVectorFormat vdata; + vector.ToUnifiedFormat(count, vdata); + AppendData(stats, state, vdata, count); +} + +void ColumnData::Append(ColumnAppendState &state, Vector &vector, idx_t count) { + if (parent || !stats) { + throw InternalException("ColumnData::Append called on a column with a parent or without stats"); + } + Append(stats->statistics, state, vector, count); +} + +bool ColumnData::CheckZonemap(TableFilter &filter) { + if (!stats) { + throw InternalException("ColumnData::CheckZonemap called on a column without stats"); + } + auto propagate_result = filter.CheckStatistics(stats->statistics); + if (propagate_result == FilterPropagateResult::FILTER_ALWAYS_FALSE || + propagate_result == FilterPropagateResult::FILTER_FALSE_OR_NULL) { + return false; + } + return true; +} + +unique_ptr ColumnData::GetStatistics() { + if (!stats) { + throw InternalException("ColumnData::GetStatistics called on a column without stats"); + } + return stats->statistics.ToUnique(); +} + +void ColumnData::MergeStatistics(const BaseStatistics &other) { + if (!stats) { + throw InternalException("ColumnData::MergeStatistics called on a column without stats"); + } + return stats->statistics.Merge(other); +} + +void ColumnData::MergeIntoStatistics(BaseStatistics &other) { + if (!stats) { + throw InternalException("ColumnData::MergeIntoStatistics called on a column without stats"); + } + return other.Merge(stats->statistics); +} + +void ColumnData::InitializeAppend(ColumnAppendState &state) { + auto l = data.Lock(); + if (data.IsEmpty(l)) { + // no segments yet, append an empty segment + AppendTransientSegment(l, start); + } + auto segment = data.GetLastSegment(l); + if (segment->segment_type == ColumnSegmentType::PERSISTENT || !segment->function.get().init_append) { + // we cannot append to this segment - append a new segment + auto total_rows = segment->start + segment->count; + AppendTransientSegment(l, total_rows); + state.current = data.GetLastSegment(l); + } else { + state.current = segment; + } + + D_ASSERT(state.current->segment_type == ColumnSegmentType::TRANSIENT); + state.current->InitializeAppend(state); + D_ASSERT(state.current->function.get().append); +} + +void ColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { + idx_t offset = 0; + this->count += count; + while (true) { + // append the data from the vector + idx_t copied_elements = state.current->Append(state, vdata, offset, count); + stats.Merge(state.current->stats.statistics); + if (copied_elements == count) { + // finished copying everything + break; + } + + // we couldn't fit everything we wanted in the current column segment, create a new one + { + auto l = data.Lock(); + AppendTransientSegment(l, state.current->start + state.current->count); + state.current = data.GetLastSegment(l); + state.current->InitializeAppend(state); + } + offset += copied_elements; + count -= copied_elements; + } +} + +void ColumnData::RevertAppend(row_t start_row) { + auto l = data.Lock(); + // check if this row is in the segment tree at all + auto last_segment = data.GetLastSegment(l); + if (idx_t(start_row) >= last_segment->start + last_segment->count) { + // the start row is equal to the final portion of the column data: nothing was ever appended here + D_ASSERT(idx_t(start_row) == last_segment->start + last_segment->count); + return; + } + // find the segment index that the current row belongs to + idx_t segment_index = data.GetSegmentIndex(l, start_row); + auto segment = data.GetSegmentByIndex(l, segment_index); + auto &transient = *segment; + D_ASSERT(transient.segment_type == ColumnSegmentType::TRANSIENT); + + // remove any segments AFTER this segment: they should be deleted entirely + data.EraseSegments(l, segment_index); + + this->count = start_row - this->start; + segment->next = nullptr; + transient.RevertAppend(start_row); +} + +idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { + D_ASSERT(row_id >= 0); + D_ASSERT(idx_t(row_id) >= start); + // perform the fetch within the segment + state.row_index = start + ((row_id - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); + state.current = data.GetSegment(state.row_index); + state.internal_index = state.current->start; + return ScanVector(state, result, STANDARD_VECTOR_SIZE, false); +} + +void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + auto segment = data.GetSegment(row_id); + + // now perform the fetch within the segment + segment->FetchRow(state, row_id, result, result_idx); + // merge any updates made to this row + lock_guard update_guard(update_lock); + if (updates) { + updates->FetchRow(transaction, row_id, result, result_idx); + } +} + +void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count) { + lock_guard update_guard(update_lock); + if (!updates) { + updates = make_uniq(*this); + } + Vector base_vector(type); + ColumnScanState state; + auto fetch_count = Fetch(state, row_ids[0], base_vector); + + base_vector.Flatten(fetch_count); + updates->Update(transaction, column_index, update_vector, row_ids, update_count, base_vector); +} + +void ColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t depth) { + // this method should only be called at the end of the path in the base column case + D_ASSERT(depth >= column_path.size()); + ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); +} + +unique_ptr ColumnData::GetUpdateStatistics() { + lock_guard update_guard(update_lock); + return updates ? updates->GetStatistics() : nullptr; +} + +void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { + idx_t segment_size = Storage::BLOCK_SIZE; + if (start_row == idx_t(MAX_ROW_ID)) { +#if STANDARD_VECTOR_SIZE < 1024 + segment_size = 1024 * GetTypeIdSize(type.InternalType()); +#else + segment_size = STANDARD_VECTOR_SIZE * GetTypeIdSize(type.InternalType()); +#endif + } + auto new_segment = ColumnSegment::CreateTransientSegment(GetDatabase(), type, start_row, segment_size); + data.AppendSegment(l, std::move(new_segment)); +} + +void ColumnData::CommitDropColumn() { + for (auto &segment_p : data.Segments()) { + auto &segment = segment_p; + segment.CommitDropSegment(); + } +} + +unique_ptr ColumnData::CreateCheckpointState(RowGroup &row_group, + PartialBlockManager &partial_block_manager) { + return make_uniq(row_group, *this, partial_block_manager); +} + +void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, + Vector &scan_vector) { + segment.Scan(state, count, scan_vector, 0, true); + if (updates) { + scan_vector.Flatten(count); + updates->FetchCommittedRange(state.row_index - row_group_start, count, scan_vector); + } +} + +unique_ptr ColumnData::Checkpoint(RowGroup &row_group, + PartialBlockManager &partial_block_manager, + ColumnCheckpointInfo &checkpoint_info) { + // scan the segments of the column data + // set up the checkpoint state + auto checkpoint_state = CreateCheckpointState(row_group, partial_block_manager); + checkpoint_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); + + auto l = data.Lock(); + auto nodes = data.MoveSegments(l); + if (nodes.empty()) { + // empty table: flush the empty list + return checkpoint_state; + } + lock_guard update_guard(update_lock); + + ColumnDataCheckpointer checkpointer(*this, row_group, *checkpoint_state, checkpoint_info); + checkpointer.Checkpoint(std::move(nodes)); + + // replace the old tree with the new one + data.Replace(l, checkpoint_state->new_tree); + version++; + + return checkpoint_state; +} + +void ColumnData::DeserializeColumn(Deserializer &deserializer) { + // load the data pointers for the column + deserializer.Set(info.db.GetDatabase()); + deserializer.Set(type); + + vector data_pointers; + deserializer.ReadProperty(100, "data_pointers", data_pointers); + + deserializer.Unset(); + deserializer.Unset(); + + // construct the segments based on the data pointers + this->count = 0; + for (auto &data_pointer : data_pointers) { + // Update the count and statistics + this->count += data_pointer.tuple_count; + if (stats) { + stats->statistics.Merge(data_pointer.statistics); + } + + // create a persistent segment + auto segment = ColumnSegment::CreatePersistentSegment( + GetDatabase(), block_manager, data_pointer.block_pointer.block_id, data_pointer.block_pointer.offset, type, + data_pointer.row_start, data_pointer.tuple_count, data_pointer.compression_type, + std::move(data_pointer.statistics), std::move(data_pointer.segment_state)); + + data.AppendSegment(std::move(segment)); + } +} + +shared_ptr ColumnData::Deserialize(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + idx_t start_row, ReadStream &source, const LogicalType &type, + optional_ptr parent) { + auto entry = ColumnData::CreateColumn(block_manager, info, column_index, start_row, type, parent); + BinaryDeserializer deserializer(source); + deserializer.Begin(); + entry->DeserializeColumn(deserializer); + deserializer.End(); + return entry; +} + +void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, + vector &result) { + D_ASSERT(!col_path.empty()); + + // convert the column path to a string + string col_path_str = "["; + for (idx_t i = 0; i < col_path.size(); i++) { + if (i > 0) { + col_path_str += ", "; + } + col_path_str += to_string(col_path[i]); + } + col_path_str += "]"; + + // iterate over the segments + idx_t segment_idx = 0; + auto segment = (ColumnSegment *)data.GetRootSegment(); + while (segment) { + ColumnSegmentInfo column_info; + column_info.row_group_index = row_group_index; + column_info.column_id = col_path[0]; + column_info.column_path = col_path_str; + column_info.segment_idx = segment_idx; + column_info.segment_type = type.ToString(); + column_info.segment_start = segment->start; + column_info.segment_count = segment->count; + column_info.compression_type = CompressionTypeToString(segment->function.get().type); + column_info.segment_stats = segment->stats.statistics.ToString(); + { + lock_guard ulock(update_lock); + column_info.has_updates = updates ? true : false; + } + // persistent + // block_id + // block_offset + if (segment->segment_type == ColumnSegmentType::PERSISTENT) { + column_info.persistent = true; + column_info.block_id = segment->GetBlockId(); + column_info.block_offset = segment->GetBlockOffset(); + } else { + column_info.persistent = false; + } + auto segment_state = segment->GetSegmentState(); + if (segment_state) { + column_info.segment_info = segment_state->GetSegmentInfo(); + } + result.emplace_back(column_info); + + segment_idx++; + segment = data.GetNextSegment(segment); + } +} + +void ColumnData::Verify(RowGroup &parent) { +#ifdef DEBUG + D_ASSERT(this->start == parent.start); + data.Verify(); + if (type.InternalType() == PhysicalType::STRUCT) { + // structs don't have segments + D_ASSERT(!data.GetRootSegment()); + return; + } + idx_t current_index = 0; + idx_t current_start = this->start; + idx_t total_count = 0; + for (auto &segment : data.Segments()) { + D_ASSERT(segment.index == current_index); + D_ASSERT(segment.start == current_start); + current_start += segment.count; + total_count += segment.count; + current_index++; + } + D_ASSERT(this->count == total_count); +#endif +} + +template +static RET CreateColumnInternal(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, + const LogicalType &type, optional_ptr parent) { + if (type.InternalType() == PhysicalType::STRUCT) { + return OP::template Create(block_manager, info, column_index, start_row, type, parent); + } else if (type.InternalType() == PhysicalType::LIST) { + return OP::template Create(block_manager, info, column_index, start_row, type, parent); + } else if (type.id() == LogicalTypeId::VALIDITY) { + return OP::template Create(block_manager, info, column_index, start_row, *parent); + } + return OP::template Create(block_manager, info, column_index, start_row, type, parent); +} + +shared_ptr ColumnData::CreateColumn(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + idx_t start_row, const LogicalType &type, + optional_ptr parent) { + return CreateColumnInternal, SharedConstructor>(block_manager, info, column_index, start_row, + type, parent); +} + +unique_ptr ColumnData::CreateColumnUnique(BlockManager &block_manager, DataTableInfo &info, + idx_t column_index, idx_t start_row, const LogicalType &type, + optional_ptr parent) { + return CreateColumnInternal, UniqueConstructor>(block_manager, info, column_index, start_row, + type, parent); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/column_data_checkpointer.cpp b/src/duckdb/src/storage/table/column_data_checkpointer.cpp new file mode 100644 index 00000000..271c67eb --- /dev/null +++ b/src/duckdb/src/storage/table/column_data_checkpointer.cpp @@ -0,0 +1,261 @@ +#include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +ColumnDataCheckpointer::ColumnDataCheckpointer(ColumnData &col_data_p, RowGroup &row_group_p, + ColumnCheckpointState &state_p, ColumnCheckpointInfo &checkpoint_info_p) + : col_data(col_data_p), row_group(row_group_p), state(state_p), + is_validity(GetType().id() == LogicalTypeId::VALIDITY), + intermediate(is_validity ? LogicalType::BOOLEAN : GetType(), true, is_validity), + checkpoint_info(checkpoint_info_p) { + auto &config = DBConfig::GetConfig(GetDatabase()); + auto functions = config.GetCompressionFunctions(GetType().InternalType()); + for (auto &func : functions) { + compression_functions.push_back(&func.get()); + } +} + +DatabaseInstance &ColumnDataCheckpointer::GetDatabase() { + return col_data.GetDatabase(); +} + +const LogicalType &ColumnDataCheckpointer::GetType() const { + return col_data.type; +} + +ColumnData &ColumnDataCheckpointer::GetColumnData() { + return col_data; +} + +RowGroup &ColumnDataCheckpointer::GetRowGroup() { + return row_group; +} + +ColumnCheckpointState &ColumnDataCheckpointer::GetCheckpointState() { + return state; +} + +void ColumnDataCheckpointer::ScanSegments(const std::function &callback) { + Vector scan_vector(intermediate.GetType(), nullptr); + for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { + auto &segment = *nodes[segment_idx].node; + ColumnScanState scan_state; + scan_state.current = &segment; + segment.InitializeScan(scan_state); + + for (idx_t base_row_index = 0; base_row_index < segment.count; base_row_index += STANDARD_VECTOR_SIZE) { + scan_vector.Reference(intermediate); + + idx_t count = MinValue(segment.count - base_row_index, STANDARD_VECTOR_SIZE); + scan_state.row_index = segment.start + base_row_index; + + col_data.CheckpointScan(segment, scan_state, row_group.start, count, scan_vector); + + callback(scan_vector, count); + } + } +} + +CompressionType ForceCompression(vector> &compression_functions, + CompressionType compression_type) { + // On of the force_compression flags has been set + // check if this compression method is available + bool found = false; + for (idx_t i = 0; i < compression_functions.size(); i++) { + auto &compression_function = *compression_functions[i]; + if (compression_function.type == compression_type) { + found = true; + break; + } + } + if (found) { + // the force_compression method is available + // clear all other compression methods + // except the uncompressed method, so we can fall back on that + for (idx_t i = 0; i < compression_functions.size(); i++) { + auto &compression_function = *compression_functions[i]; + if (compression_function.type == CompressionType::COMPRESSION_UNCOMPRESSED) { + continue; + } + if (compression_function.type != compression_type) { + compression_functions[i] = nullptr; + } + } + } + return found ? compression_type : CompressionType::COMPRESSION_AUTO; +} + +unique_ptr ColumnDataCheckpointer::DetectBestCompressionMethod(idx_t &compression_idx) { + D_ASSERT(!compression_functions.empty()); + auto &config = DBConfig::GetConfig(GetDatabase()); + CompressionType forced_method = CompressionType::COMPRESSION_AUTO; + + auto compression_type = checkpoint_info.compression_type; + if (compression_type != CompressionType::COMPRESSION_AUTO) { + forced_method = ForceCompression(compression_functions, compression_type); + } + if (compression_type == CompressionType::COMPRESSION_AUTO && + config.options.force_compression != CompressionType::COMPRESSION_AUTO) { + forced_method = ForceCompression(compression_functions, config.options.force_compression); + } + // set up the analyze states for each compression method + vector> analyze_states; + analyze_states.reserve(compression_functions.size()); + for (idx_t i = 0; i < compression_functions.size(); i++) { + if (!compression_functions[i]) { + analyze_states.push_back(nullptr); + continue; + } + analyze_states.push_back(compression_functions[i]->init_analyze(col_data, col_data.type.InternalType())); + } + + // scan over all the segments and run the analyze step + ScanSegments([&](Vector &scan_vector, idx_t count) { + for (idx_t i = 0; i < compression_functions.size(); i++) { + if (!compression_functions[i]) { + continue; + } + auto success = compression_functions[i]->analyze(*analyze_states[i], scan_vector, count); + if (!success) { + // could not use this compression function on this data set + // erase it + compression_functions[i] = nullptr; + analyze_states[i].reset(); + } + } + }); + + // now that we have passed over all the data, we need to figure out the best method + // we do this using the final_analyze method + unique_ptr state; + compression_idx = DConstants::INVALID_INDEX; + idx_t best_score = NumericLimits::Maximum(); + for (idx_t i = 0; i < compression_functions.size(); i++) { + if (!compression_functions[i]) { + continue; + } + //! Check if the method type is the forced method (if forced is used) + bool forced_method_found = compression_functions[i]->type == forced_method; + auto score = compression_functions[i]->final_analyze(*analyze_states[i]); + + //! The finalize method can return this value from final_analyze to indicate it should not be used. + if (score == DConstants::INVALID_INDEX) { + continue; + } + + if (score < best_score || forced_method_found) { + compression_idx = i; + best_score = score; + state = std::move(analyze_states[i]); + } + //! If we have found the forced method, we're done + if (forced_method_found) { + break; + } + } + return state; +} + +void ColumnDataCheckpointer::WriteToDisk() { + // there were changes or transient segments + // we need to rewrite the column segments to disk + + // first we check the current segments + // if there are any persistent segments, we will mark their old block ids as modified + // since the segments will be rewritten their old on disk data is no longer required + for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { + auto segment = nodes[segment_idx].node.get(); + segment->CommitDropSegment(); + } + + // now we need to write our segment + // we will first run an analyze step that determines which compression function to use + idx_t compression_idx; + auto analyze_state = DetectBestCompressionMethod(compression_idx); + + if (!analyze_state) { + throw FatalException("No suitable compression/storage method found to store column"); + } + + // now that we have analyzed the compression functions we can start writing to disk + auto best_function = compression_functions[compression_idx]; + auto compress_state = best_function->init_compression(*this, std::move(analyze_state)); + ScanSegments( + [&](Vector &scan_vector, idx_t count) { best_function->compress(*compress_state, scan_vector, count); }); + best_function->compress_finalize(*compress_state); + + nodes.clear(); +} + +bool ColumnDataCheckpointer::HasChanges() { + for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { + auto segment = nodes[segment_idx].node.get(); + if (segment->segment_type == ColumnSegmentType::TRANSIENT) { + // transient segment: always need to write to disk + return true; + } else { + // persistent segment; check if there were any updates or deletions in this segment + idx_t start_row_idx = segment->start - row_group.start; + idx_t end_row_idx = start_row_idx + segment->count; + if (col_data.updates && col_data.updates->HasUpdates(start_row_idx, end_row_idx)) { + return true; + } + } + } + return false; +} + +void ColumnDataCheckpointer::WritePersistentSegments() { + // all segments are persistent and there are no updates + // we only need to write the metadata + for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { + auto segment = nodes[segment_idx].node.get(); + D_ASSERT(segment->segment_type == ColumnSegmentType::PERSISTENT); + + // set up the data pointer directly using the data from the persistent segment + DataPointer pointer(segment->stats.statistics.Copy()); + pointer.block_pointer.block_id = segment->GetBlockId(); + pointer.block_pointer.offset = segment->GetBlockOffset(); + pointer.row_start = segment->start; + pointer.tuple_count = segment->count; + pointer.compression_type = segment->function.get().type; + if (segment->function.get().serialize_state) { + pointer.segment_state = segment->function.get().serialize_state(*segment); + } + + // merge the persistent stats into the global column stats + state.global_stats->Merge(segment->stats.statistics); + + // directly append the current segment to the new tree + state.new_tree.AppendSegment(std::move(nodes[segment_idx].node)); + + state.data_pointers.push_back(std::move(pointer)); + } +} + +void ColumnDataCheckpointer::Checkpoint(vector> nodes_p) { + D_ASSERT(!nodes_p.empty()); + this->nodes = std::move(nodes_p); + // first check if any of the segments have changes + if (!HasChanges()) { + // no changes: only need to write the metadata for this column + WritePersistentSegments(); + } else { + // there are changes: rewrite the set of columns); + WriteToDisk(); + } +} + +CompressionFunction &ColumnDataCheckpointer::GetCompressionFunction(CompressionType compression_type) { + auto &db = GetDatabase(); + auto &column_type = GetType(); + auto &config = DBConfig::GetConfig(db); + return *config.GetCompressionFunction(compression_type, column_type.InternalType()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/column_segment.cpp b/src/duckdb/src/storage/table/column_segment.cpp new file mode 100644 index 00000000..b97c1229 --- /dev/null +++ b/src/duckdb/src/storage/table/column_segment.cpp @@ -0,0 +1,484 @@ +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/data_pointer.hpp" + +#include + +namespace duckdb { + +unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstance &db, BlockManager &block_manager, + block_id_t block_id, idx_t offset, + const LogicalType &type, idx_t start, idx_t count, + CompressionType compression_type, + BaseStatistics statistics, + unique_ptr segment_state) { + auto &config = DBConfig::GetConfig(db); + optional_ptr function; + shared_ptr block; + if (block_id == INVALID_BLOCK) { + // constant segment, no need to allocate an actual block + function = config.GetCompressionFunction(CompressionType::COMPRESSION_CONSTANT, type.InternalType()); + } else { + function = config.GetCompressionFunction(compression_type, type.InternalType()); + block = block_manager.RegisterBlock(block_id); + } + auto segment_size = Storage::BLOCK_SIZE; + return make_uniq(db, std::move(block), type, ColumnSegmentType::PERSISTENT, start, count, *function, + std::move(statistics), block_id, offset, segment_size, std::move(segment_state)); +} + +unique_ptr ColumnSegment::CreateTransientSegment(DatabaseInstance &db, const LogicalType &type, + idx_t start, idx_t segment_size) { + auto &config = DBConfig::GetConfig(db); + auto function = config.GetCompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, type.InternalType()); + auto &buffer_manager = BufferManager::GetBufferManager(db); + shared_ptr block; + // transient: allocate a buffer for the uncompressed segment + if (segment_size < Storage::BLOCK_SIZE) { + block = buffer_manager.RegisterSmallMemory(segment_size); + } else { + buffer_manager.Allocate(segment_size, false, &block); + } + return make_uniq(db, std::move(block), type, ColumnSegmentType::TRANSIENT, start, 0, *function, + BaseStatistics::CreateEmpty(type), INVALID_BLOCK, 0, segment_size); +} + +unique_ptr ColumnSegment::CreateSegment(ColumnSegment &other, idx_t start) { + return make_uniq(other, start); +} + +ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block, LogicalType type_p, + ColumnSegmentType segment_type, idx_t start, idx_t count, CompressionFunction &function_p, + BaseStatistics statistics, block_id_t block_id_p, idx_t offset_p, idx_t segment_size_p, + unique_ptr segment_state) + : SegmentBase(start, count), db(db), type(std::move(type_p)), + type_size(GetTypeIdSize(type.InternalType())), segment_type(segment_type), function(function_p), + stats(std::move(statistics)), block(std::move(block)), block_id(block_id_p), offset(offset_p), + segment_size(segment_size_p) { + if (function.get().init_segment) { + this->segment_state = function.get().init_segment(*this, block_id, segment_state.get()); + } +} + +ColumnSegment::ColumnSegment(ColumnSegment &other, idx_t start) + : SegmentBase(start, other.count.load()), db(other.db), type(std::move(other.type)), + type_size(other.type_size), segment_type(other.segment_type), function(other.function), + stats(std::move(other.stats)), block(std::move(other.block)), block_id(other.block_id), offset(other.offset), + segment_size(other.segment_size), segment_state(std::move(other.segment_state)) { +} + +ColumnSegment::~ColumnSegment() { +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +void ColumnSegment::InitializeScan(ColumnScanState &state) { + state.scan_state = function.get().init_scan(*this); +} + +void ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset, + bool entire_vector) { + if (entire_vector) { + D_ASSERT(result_offset == 0); + Scan(state, scan_count, result); + } else { + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + ScanPartial(state, scan_count, result, result_offset); + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + } +} + +void ColumnSegment::Skip(ColumnScanState &state) { + function.get().skip(*this, state, state.row_index - state.internal_index); + state.internal_index = state.row_index; +} + +void ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result) { + function.get().scan_vector(*this, state, scan_count, result); +} + +void ColumnSegment::ScanPartial(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { + function.get().scan_partial(*this, state, scan_count, result, result_offset); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +void ColumnSegment::FetchRow(ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { + function.get().fetch_row(*this, state, row_id - this->start, result, result_idx); +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +idx_t ColumnSegment::SegmentSize() const { + return segment_size; +} + +void ColumnSegment::Resize(idx_t new_size) { + D_ASSERT(new_size > this->segment_size); + D_ASSERT(offset == 0); + auto &buffer_manager = BufferManager::GetBufferManager(db); + auto old_handle = buffer_manager.Pin(block); + shared_ptr new_block; + auto new_handle = buffer_manager.Allocate(Storage::BLOCK_SIZE, false, &new_block); + memcpy(new_handle.Ptr(), old_handle.Ptr(), segment_size); + this->block_id = new_block->BlockId(); + this->block = std::move(new_block); + this->segment_size = new_size; +} + +void ColumnSegment::InitializeAppend(ColumnAppendState &state) { + D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); + if (!function.get().init_append) { + throw InternalException("Attempting to init append to a segment without init_append method"); + } + state.append_state = function.get().init_append(*this); +} + +idx_t ColumnSegment::Append(ColumnAppendState &state, UnifiedVectorFormat &append_data, idx_t offset, idx_t count) { + D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); + if (!function.get().append) { + throw InternalException("Attempting to append to a segment without append method"); + } + return function.get().append(*state.append_state, *this, stats, append_data, offset, count); +} + +idx_t ColumnSegment::FinalizeAppend(ColumnAppendState &state) { + D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); + if (!function.get().finalize_append) { + throw InternalException("Attempting to call FinalizeAppend on a segment without a finalize_append method"); + } + auto result_count = function.get().finalize_append(*this, stats); + state.append_state.reset(); + return result_count; +} + +void ColumnSegment::RevertAppend(idx_t start_row) { + D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); + if (function.get().revert_append) { + function.get().revert_append(*this, start_row); + } + this->count = start_row - this->start; +} + +//===--------------------------------------------------------------------===// +// Convert To Persistent +//===--------------------------------------------------------------------===// +void ColumnSegment::ConvertToPersistent(optional_ptr block_manager, block_id_t block_id_p) { + D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); + segment_type = ColumnSegmentType::PERSISTENT; + + block_id = block_id_p; + offset = 0; + + if (block_id == INVALID_BLOCK) { + // constant block: reset the block buffer + D_ASSERT(stats.statistics.IsConstant()); + block.reset(); + } else { + D_ASSERT(!stats.statistics.IsConstant()); + // non-constant block: write the block to disk + // the data for the block already exists in-memory of our block + // instead of copying the data we alter some metadata so the buffer points to an on-disk block + block = block_manager->ConvertToPersistent(block_id, std::move(block)); + } +} + +void ColumnSegment::MarkAsPersistent(shared_ptr block_p, uint32_t offset_p) { + D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); + segment_type = ColumnSegmentType::PERSISTENT; + + block_id = block_p->BlockId(); + offset = offset_p; + block = std::move(block_p); +} + +//===--------------------------------------------------------------------===// +// Drop Segment +//===--------------------------------------------------------------------===// +void ColumnSegment::CommitDropSegment() { + if (segment_type != ColumnSegmentType::PERSISTENT) { + // not persistent + return; + } + if (block_id != INVALID_BLOCK) { + GetBlockManager().MarkBlockAsModified(block_id); + } + if (function.get().cleanup_state) { + function.get().cleanup_state(*this); + } +} + +//===--------------------------------------------------------------------===// +// Filter Selection +//===--------------------------------------------------------------------===// +template +static idx_t TemplatedFilterSelection(T *vec, T predicate, SelectionVector &sel, idx_t approved_tuple_count, + ValidityMask &mask, SelectionVector &result_sel) { + idx_t result_count = 0; + for (idx_t i = 0; i < approved_tuple_count; i++) { + auto idx = sel.get_index(i); + if ((!HAS_NULL || mask.RowIsValid(idx)) && OP::Operation(vec[idx], predicate)) { + result_sel.set_index(result_count++, idx); + } + } + return result_count; +} + +template +static void FilterSelectionSwitch(T *vec, T predicate, SelectionVector &sel, idx_t &approved_tuple_count, + ExpressionType comparison_type, ValidityMask &mask) { + SelectionVector new_sel(approved_tuple_count); + // the inplace loops take the result as the last parameter + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: { + if (mask.AllValid()) { + approved_tuple_count = + TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); + } else { + approved_tuple_count = + TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); + } + break; + } + case ExpressionType::COMPARE_NOTEQUAL: { + if (mask.AllValid()) { + approved_tuple_count = + TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); + } else { + approved_tuple_count = + TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); + } + break; + } + case ExpressionType::COMPARE_LESSTHAN: { + if (mask.AllValid()) { + approved_tuple_count = + TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); + } else { + approved_tuple_count = + TemplatedFilterSelection(vec, predicate, sel, approved_tuple_count, mask, new_sel); + } + break; + } + case ExpressionType::COMPARE_GREATERTHAN: { + if (mask.AllValid()) { + approved_tuple_count = TemplatedFilterSelection(vec, predicate, sel, + approved_tuple_count, mask, new_sel); + } else { + approved_tuple_count = TemplatedFilterSelection(vec, predicate, sel, + approved_tuple_count, mask, new_sel); + } + break; + } + case ExpressionType::COMPARE_LESSTHANOREQUALTO: { + if (mask.AllValid()) { + approved_tuple_count = TemplatedFilterSelection( + vec, predicate, sel, approved_tuple_count, mask, new_sel); + } else { + approved_tuple_count = TemplatedFilterSelection( + vec, predicate, sel, approved_tuple_count, mask, new_sel); + } + break; + } + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: { + if (mask.AllValid()) { + approved_tuple_count = TemplatedFilterSelection( + vec, predicate, sel, approved_tuple_count, mask, new_sel); + } else { + approved_tuple_count = TemplatedFilterSelection( + vec, predicate, sel, approved_tuple_count, mask, new_sel); + } + break; + } + default: + throw NotImplementedException("Unknown comparison type for filter pushed down to table!"); + } + sel.Initialize(new_sel); +} + +template +static idx_t TemplatedNullSelection(SelectionVector &sel, idx_t &approved_tuple_count, ValidityMask &mask) { + if (mask.AllValid()) { + // no NULL values + if (IS_NULL) { + approved_tuple_count = 0; + return 0; + } else { + return approved_tuple_count; + } + } else { + SelectionVector result_sel(approved_tuple_count); + idx_t result_count = 0; + for (idx_t i = 0; i < approved_tuple_count; i++) { + auto idx = sel.get_index(i); + if (mask.RowIsValid(idx) != IS_NULL) { + result_sel.set_index(result_count++, idx); + } + } + sel.Initialize(result_sel); + approved_tuple_count = result_count; + return result_count; + } +} + +idx_t ColumnSegment::FilterSelection(SelectionVector &sel, Vector &result, const TableFilter &filter, + idx_t &approved_tuple_count, ValidityMask &mask) { + switch (filter.filter_type) { + case TableFilterType::CONJUNCTION_OR: { + // similar to the CONJUNCTION_AND, but we need to take care of the SelectionVectors (OR all of them) + idx_t count_total = 0; + SelectionVector result_sel(approved_tuple_count); + auto &conjunction_or = filter.Cast(); + for (auto &child_filter : conjunction_or.child_filters) { + SelectionVector temp_sel; + temp_sel.Initialize(sel); + idx_t temp_tuple_count = approved_tuple_count; + idx_t temp_count = FilterSelection(temp_sel, result, *child_filter, temp_tuple_count, mask); + // tuples passed, move them into the actual result vector + for (idx_t i = 0; i < temp_count; i++) { + auto new_idx = temp_sel.get_index(i); + bool is_new_idx = true; + for (idx_t res_idx = 0; res_idx < count_total; res_idx++) { + if (result_sel.get_index(res_idx) == new_idx) { + is_new_idx = false; + break; + } + } + if (is_new_idx) { + result_sel.set_index(count_total++, new_idx); + } + } + } + sel.Initialize(result_sel); + approved_tuple_count = count_total; + return approved_tuple_count; + } + case TableFilterType::CONJUNCTION_AND: { + auto &conjunction_and = filter.Cast(); + for (auto &child_filter : conjunction_and.child_filters) { + FilterSelection(sel, result, *child_filter, approved_tuple_count, mask); + } + return approved_tuple_count; + } + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + // the inplace loops take the result as the last parameter + switch (result.GetType().InternalType()) { + case PhysicalType::UINT8: { + auto result_flat = FlatVector::GetData(result); + auto predicate = UTinyIntValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::UINT16: { + auto result_flat = FlatVector::GetData(result); + auto predicate = USmallIntValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::UINT32: { + auto result_flat = FlatVector::GetData(result); + auto predicate = UIntegerValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::UINT64: { + auto result_flat = FlatVector::GetData(result); + auto predicate = UBigIntValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::INT8: { + auto result_flat = FlatVector::GetData(result); + auto predicate = TinyIntValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::INT16: { + auto result_flat = FlatVector::GetData(result); + auto predicate = SmallIntValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::INT32: { + auto result_flat = FlatVector::GetData(result); + auto predicate = IntegerValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::INT64: { + auto result_flat = FlatVector::GetData(result); + auto predicate = BigIntValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::INT128: { + auto result_flat = FlatVector::GetData(result); + auto predicate = HugeIntValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::FLOAT: { + auto result_flat = FlatVector::GetData(result); + auto predicate = FloatValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::DOUBLE: { + auto result_flat = FlatVector::GetData(result); + auto predicate = DoubleValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::VARCHAR: { + auto result_flat = FlatVector::GetData(result); + auto predicate = string_t(StringValue::Get(constant_filter.constant)); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + case PhysicalType::BOOL: { + auto result_flat = FlatVector::GetData(result); + auto predicate = BooleanValue::Get(constant_filter.constant); + FilterSelectionSwitch(result_flat, predicate, sel, approved_tuple_count, + constant_filter.comparison_type, mask); + break; + } + default: + throw InvalidTypeException(result.GetType(), "Invalid type for filter pushed down to table comparison"); + } + return approved_tuple_count; + } + case TableFilterType::IS_NULL: + return TemplatedNullSelection(sel, approved_tuple_count, mask); + case TableFilterType::IS_NOT_NULL: + return TemplatedNullSelection(sel, approved_tuple_count, mask); + default: + throw InternalException("FIXME: unsupported type for filter selection"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/list_column_data.cpp b/src/duckdb/src/storage/table/list_column_data.cpp new file mode 100644 index 00000000..36f6da70 --- /dev/null +++ b/src/duckdb/src/storage/table/list_column_data.cpp @@ -0,0 +1,371 @@ +#include "duckdb/storage/table/list_column_data.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +ListColumnData::ListColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, + LogicalType type_p, optional_ptr parent) + : ColumnData(block_manager, info, column_index, start_row, std::move(type_p), parent), + validity(block_manager, info, 0, start_row, *this) { + D_ASSERT(type.InternalType() == PhysicalType::LIST); + auto &child_type = ListType::GetChildType(type); + // the child column, with column index 1 (0 is the validity mask) + child_column = ColumnData::CreateColumnUnique(block_manager, info, 1, start_row, child_type, this); +} + +void ListColumnData::SetStart(idx_t new_start) { + ColumnData::SetStart(new_start); + child_column->SetStart(new_start); + validity.SetStart(new_start); +} + +bool ListColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { + // table filters are not supported yet for list columns + return false; +} + +void ListColumnData::InitializeScan(ColumnScanState &state) { + ColumnData::InitializeScan(state); + + // initialize the validity segment + D_ASSERT(state.child_states.size() == 2); + validity.InitializeScan(state.child_states[0]); + + // initialize the child scan + child_column->InitializeScan(state.child_states[1]); +} + +uint64_t ListColumnData::FetchListOffset(idx_t row_idx) { + auto segment = data.GetSegment(row_idx); + ColumnFetchState fetch_state; + Vector result(type, 1); + segment->FetchRow(fetch_state, row_idx, result, 0); + + // initialize the child scan with the required offset + return FlatVector::GetData(result)[0]; +} + +void ListColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { + if (row_idx == 0) { + InitializeScan(state); + return; + } + ColumnData::InitializeScanWithOffset(state, row_idx); + + // initialize the validity segment + D_ASSERT(state.child_states.size() == 2); + validity.InitializeScanWithOffset(state.child_states[0], row_idx); + + // we need to read the list at position row_idx to get the correct row offset of the child + auto child_offset = row_idx == start ? 0 : FetchListOffset(row_idx - 1); + D_ASSERT(child_offset <= child_column->GetMaxEntry()); + if (child_offset < child_column->GetMaxEntry()) { + child_column->InitializeScanWithOffset(state.child_states[1], start + child_offset); + } + state.last_offset = child_offset; +} + +idx_t ListColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) { + return ScanCount(state, result, STANDARD_VECTOR_SIZE); +} + +idx_t ListColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) { + return ScanCount(state, result, STANDARD_VECTOR_SIZE); +} + +idx_t ListColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { + if (count == 0) { + return 0; + } + // updates not supported for lists + D_ASSERT(!updates); + + Vector offset_vector(LogicalType::UBIGINT, count); + idx_t scan_count = ScanVector(state, offset_vector, count, false); + D_ASSERT(scan_count > 0); + validity.ScanCount(state.child_states[0], result, count); + + UnifiedVectorFormat offsets; + offset_vector.ToUnifiedFormat(scan_count, offsets); + auto data = UnifiedVectorFormat::GetData(offsets); + auto last_entry = data[offsets.sel->get_index(scan_count - 1)]; + + // shift all offsets so they are 0 at the first entry + auto result_data = FlatVector::GetData(result); + auto base_offset = state.last_offset; + idx_t current_offset = 0; + for (idx_t i = 0; i < scan_count; i++) { + auto offset_index = offsets.sel->get_index(i); + result_data[i].offset = current_offset; + result_data[i].length = data[offset_index] - current_offset - base_offset; + current_offset += result_data[i].length; + } + + D_ASSERT(last_entry >= base_offset); + idx_t child_scan_count = last_entry - base_offset; + ListVector::Reserve(result, child_scan_count); + + if (child_scan_count > 0) { + auto &child_entry = ListVector::GetEntry(result); + if (child_entry.GetType().InternalType() != PhysicalType::STRUCT && + state.child_states[1].row_index + child_scan_count > child_column->start + child_column->GetMaxEntry()) { + throw InternalException("ListColumnData::ScanCount - internal list scan offset is out of range"); + } + child_column->ScanCount(state.child_states[1], child_entry, child_scan_count); + } + state.last_offset = last_entry; + + ListVector::SetListSize(result, child_scan_count); + return scan_count; +} + +void ListColumnData::Skip(ColumnScanState &state, idx_t count) { + // skip inside the validity segment + validity.Skip(state.child_states[0], count); + + // we need to read the list entries/offsets to figure out how much to skip + // note that we only need to read the first and last entry + // however, let's just read all "count" entries for now + Vector result(LogicalType::UBIGINT, count); + idx_t scan_count = ScanVector(state, result, count, false); + if (scan_count == 0) { + return; + } + + auto data = FlatVector::GetData(result); + auto last_entry = data[scan_count - 1]; + idx_t child_scan_count = last_entry - state.last_offset; + if (child_scan_count == 0) { + return; + } + state.last_offset = last_entry; + + // skip the child state forward by the child_scan_count + child_column->Skip(state.child_states[1], child_scan_count); +} + +void ListColumnData::InitializeAppend(ColumnAppendState &state) { + // initialize the list offset append + ColumnData::InitializeAppend(state); + + // initialize the validity append + ColumnAppendState validity_append_state; + validity.InitializeAppend(validity_append_state); + state.child_appends.push_back(std::move(validity_append_state)); + + // initialize the child column append + ColumnAppendState child_append_state; + child_column->InitializeAppend(child_append_state); + state.child_appends.push_back(std::move(child_append_state)); +} + +void ListColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { + D_ASSERT(count > 0); + UnifiedVectorFormat list_data; + vector.ToUnifiedFormat(count, list_data); + auto &list_validity = list_data.validity; + + // construct the list_entry_t entries to append to the column data + auto input_offsets = UnifiedVectorFormat::GetData(list_data); + auto start_offset = child_column->GetMaxEntry(); + idx_t child_count = 0; + + ValidityMask append_mask(count); + auto append_offsets = unique_ptr(new uint64_t[count]); + bool child_contiguous = true; + for (idx_t i = 0; i < count; i++) { + auto input_idx = list_data.sel->get_index(i); + if (list_validity.RowIsValid(input_idx)) { + auto &input_list = input_offsets[input_idx]; + if (input_list.offset != child_count) { + child_contiguous = false; + } + append_offsets[i] = start_offset + child_count + input_list.length; + child_count += input_list.length; + } else { + append_mask.SetInvalid(i); + append_offsets[i] = start_offset + child_count; + } + } + auto &list_child = ListVector::GetEntry(vector); + Vector child_vector(list_child); + if (!child_contiguous) { + // if the child of the list vector is a non-contiguous vector (i.e. list elements are repeating or have gaps) + // we first push a selection vector and flatten the child vector to turn it into a contiguous vector + SelectionVector child_sel(child_count); + idx_t current_count = 0; + for (idx_t i = 0; i < count; i++) { + auto input_idx = list_data.sel->get_index(i); + if (list_validity.RowIsValid(input_idx)) { + auto &input_list = input_offsets[input_idx]; + for (idx_t list_idx = 0; list_idx < input_list.length; list_idx++) { + child_sel.set_index(current_count++, input_list.offset + list_idx); + } + } + } + D_ASSERT(current_count == child_count); + child_vector.Slice(list_child, child_sel, child_count); + } + + UnifiedVectorFormat vdata; + vdata.sel = FlatVector::IncrementalSelectionVector(); + vdata.data = data_ptr_cast(append_offsets.get()); + + // append the list offsets + ColumnData::AppendData(stats, state, vdata, count); + // append the validity data + vdata.validity = append_mask; + validity.AppendData(stats, state.child_appends[0], vdata, count); + // append the child vector + if (child_count > 0) { + child_column->Append(ListStats::GetChildStats(stats), state.child_appends[1], child_vector, child_count); + } +} + +void ListColumnData::RevertAppend(row_t start_row) { + ColumnData::RevertAppend(start_row); + validity.RevertAppend(start_row); + auto column_count = GetMaxEntry(); + if (column_count > start) { + // revert append in the child column + auto list_offset = FetchListOffset(column_count - 1); + child_column->RevertAppend(list_offset); + } +} + +idx_t ListColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { + throw NotImplementedException("List Fetch"); +} + +void ListColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count) { + throw NotImplementedException("List Update is not supported."); +} + +void ListColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { + throw NotImplementedException("List Update Column is not supported"); +} + +unique_ptr ListColumnData::GetUpdateStatistics() { + return nullptr; +} + +void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + // insert any child states that are required + // we need two (validity & list child) + // note that we need a scan state for the child vector + // this is because we will (potentially) fetch more than one tuple from the list child + if (state.child_states.empty()) { + auto child_state = make_uniq(); + state.child_states.push_back(std::move(child_state)); + } + + // now perform the fetch within the segment + auto start_offset = idx_t(row_id) == this->start ? 0 : FetchListOffset(row_id - 1); + auto end_offset = FetchListOffset(row_id); + validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + + auto &validity = FlatVector::Validity(result); + auto list_data = FlatVector::GetData(result); + auto &list_entry = list_data[result_idx]; + // set the list entry offset to the size of the current list + list_entry.offset = ListVector::GetListSize(result); + list_entry.length = end_offset - start_offset; + if (!validity.RowIsValid(result_idx)) { + // the list is NULL! no need to fetch the child + D_ASSERT(list_entry.length == 0); + return; + } + + // now we need to read from the child all the elements between [offset...length] + auto child_scan_count = list_entry.length; + if (child_scan_count > 0) { + auto child_state = make_uniq(); + auto &child_type = ListType::GetChildType(result.GetType()); + Vector child_scan(child_type, child_scan_count); + // seek the scan towards the specified position and read [length] entries + child_state->Initialize(child_type); + child_column->InitializeScanWithOffset(*child_state, start + start_offset); + D_ASSERT(child_type.InternalType() == PhysicalType::STRUCT || + child_state->row_index + child_scan_count - this->start <= child_column->GetMaxEntry()); + child_column->ScanCount(*child_state, child_scan, child_scan_count); + + ListVector::Append(result, child_scan, child_scan_count); + } +} + +void ListColumnData::CommitDropColumn() { + validity.CommitDropColumn(); + child_column->CommitDropColumn(); +} + +struct ListColumnCheckpointState : public ColumnCheckpointState { + ListColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager) + : ColumnCheckpointState(row_group, column_data, partial_block_manager) { + global_stats = ListStats::CreateEmpty(column_data.type).ToUnique(); + } + + unique_ptr validity_state; + unique_ptr child_state; + +public: + unique_ptr GetStatistics() override { + auto stats = global_stats->Copy(); + ListStats::SetChildStats(stats, child_state->GetStatistics()); + return stats.ToUnique(); + } + + void WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) override { + ColumnCheckpointState::WriteDataPointers(writer, serializer); + serializer.WriteObject(101, "validity", + [&](Serializer &serializer) { validity_state->WriteDataPointers(writer, serializer); }); + serializer.WriteObject(102, "child_column", + [&](Serializer &serializer) { child_state->WriteDataPointers(writer, serializer); }); + } +}; + +unique_ptr ListColumnData::CreateCheckpointState(RowGroup &row_group, + PartialBlockManager &partial_block_manager) { + return make_uniq(row_group, *this, partial_block_manager); +} + +unique_ptr ListColumnData::Checkpoint(RowGroup &row_group, + PartialBlockManager &partial_block_manager, + ColumnCheckpointInfo &checkpoint_info) { + auto validity_state = validity.Checkpoint(row_group, partial_block_manager, checkpoint_info); + auto base_state = ColumnData::Checkpoint(row_group, partial_block_manager, checkpoint_info); + auto child_state = child_column->Checkpoint(row_group, partial_block_manager, checkpoint_info); + + auto &checkpoint_state = base_state->Cast(); + checkpoint_state.validity_state = std::move(validity_state); + checkpoint_state.child_state = std::move(child_state); + return base_state; +} + +void ListColumnData::DeserializeColumn(Deserializer &deserializer) { + ColumnData::DeserializeColumn(deserializer); + + deserializer.ReadObject(101, "validity", + [&](Deserializer &deserializer) { validity.DeserializeColumn(deserializer); }); + + deserializer.ReadObject(102, "child_column", + [&](Deserializer &deserializer) { child_column->DeserializeColumn(deserializer); }); +} + +void ListColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, + vector &result) { + ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); + col_path.push_back(0); + validity.GetColumnSegmentInfo(row_group_index, col_path, result); + col_path.back() = 1; + child_column->GetColumnSegmentInfo(row_group_index, col_path, result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/persistent_table_data.cpp b/src/duckdb/src/storage/table/persistent_table_data.cpp new file mode 100644 index 00000000..e08b8a52 --- /dev/null +++ b/src/duckdb/src/storage/table/persistent_table_data.cpp @@ -0,0 +1,12 @@ +#include "duckdb/storage/table/persistent_table_data.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" + +namespace duckdb { + +PersistentTableData::PersistentTableData(idx_t column_count) : total_rows(0), row_group_count(0) { +} + +PersistentTableData::~PersistentTableData() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp new file mode 100644 index 00000000..eafe9804 --- /dev/null +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -0,0 +1,944 @@ +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/common/chrono.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/checkpoint/table_data_writer.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/transaction/duck_transaction_manager.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/row_version_manager.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" + +namespace duckdb { + +RowGroup::RowGroup(RowGroupCollection &collection, idx_t start, idx_t count) + : SegmentBase(start, count), collection(collection) { + Verify(); +} + +RowGroup::RowGroup(RowGroupCollection &collection, RowGroupPointer &&pointer) + : SegmentBase(pointer.row_start, pointer.tuple_count), collection(collection) { + // deserialize the columns + if (pointer.data_pointers.size() != collection.GetTypes().size()) { + throw IOException("Row group column count is unaligned with table column count. Corrupt file?"); + } + this->column_pointers = std::move(pointer.data_pointers); + this->columns.resize(column_pointers.size()); + this->is_loaded = unique_ptr[]>(new atomic[columns.size()]); + for (idx_t c = 0; c < columns.size(); c++) { + this->is_loaded[c] = false; + } + this->deletes_pointers = std::move(pointer.deletes_pointers); + this->deletes_is_loaded = false; + + Verify(); +} + +void RowGroup::MoveToCollection(RowGroupCollection &collection, idx_t new_start) { + this->collection = collection; + this->start = new_start; + for (auto &column : GetColumns()) { + column->SetStart(new_start); + } + if (!HasUnloadedDeletes()) { + auto &vinfo = GetVersionInfo(); + if (vinfo) { + vinfo->SetStart(new_start); + } + } +} + +RowGroup::~RowGroup() { +} + +vector> &RowGroup::GetColumns() { + // ensure all columns are loaded + for (idx_t c = 0; c < GetColumnCount(); c++) { + GetColumn(c); + } + return columns; +} + +idx_t RowGroup::GetColumnCount() const { + return columns.size(); +} + +ColumnData &RowGroup::GetColumn(storage_t c) { + D_ASSERT(c < columns.size()); + if (!is_loaded) { + // not being lazy loaded + D_ASSERT(columns[c]); + return *columns[c]; + } + if (is_loaded[c]) { + D_ASSERT(columns[c]); + return *columns[c]; + } + lock_guard l(row_group_lock); + if (columns[c]) { + D_ASSERT(is_loaded[c]); + return *columns[c]; + } + if (column_pointers.size() != columns.size()) { + throw InternalException("Lazy loading a column but the pointer was not set"); + } + auto &metadata_manager = GetCollection().GetMetadataManager(); + auto &types = GetCollection().GetTypes(); + auto &block_pointer = column_pointers[c]; + MetadataReader column_data_reader(metadata_manager, block_pointer); + this->columns[c] = + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), c, start, column_data_reader, types[c], nullptr); + is_loaded[c] = true; + if (this->columns[c]->count != this->count) { + throw InternalException("Corrupted database - loaded column with index %llu at row start %llu, count %llu did " + "not match count of row group %llu", + c, start, this->columns[c]->count, this->count.load()); + } + return *columns[c]; +} + +BlockManager &RowGroup::GetBlockManager() { + return GetCollection().GetBlockManager(); +} +DataTableInfo &RowGroup::GetTableInfo() { + return GetCollection().GetTableInfo(); +} + +void RowGroup::InitializeEmpty(const vector &types) { + // set up the segment trees for the column segments + D_ASSERT(columns.empty()); + for (idx_t i = 0; i < types.size(); i++) { + auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), i, start, types[i]); + columns.push_back(std::move(column_data)); + } +} + +void ColumnScanState::Initialize(const LogicalType &type) { + if (type.id() == LogicalTypeId::VALIDITY) { + // validity - nothing to initialize + return; + } + if (type.InternalType() == PhysicalType::STRUCT) { + // validity + struct children + auto &struct_children = StructType::GetChildTypes(type); + child_states.resize(struct_children.size() + 1); + for (idx_t i = 0; i < struct_children.size(); i++) { + child_states[i + 1].Initialize(struct_children[i].second); + } + } else if (type.InternalType() == PhysicalType::LIST) { + // validity + list child + child_states.resize(2); + child_states[1].Initialize(ListType::GetChildType(type)); + } else { + // validity + child_states.resize(1); + } +} + +void CollectionScanState::Initialize(const vector &types) { + auto &column_ids = GetColumnIds(); + column_scans = make_unsafe_uniq_array(column_ids.size()); + for (idx_t i = 0; i < column_ids.size(); i++) { + if (column_ids[i] == COLUMN_IDENTIFIER_ROW_ID) { + continue; + } + column_scans[i].Initialize(types[column_ids[i]]); + } +} + +bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, idx_t vector_offset) { + auto &column_ids = state.GetColumnIds(); + auto filters = state.GetFilters(); + if (filters) { + if (!CheckZonemap(*filters, column_ids)) { + return false; + } + } + + state.row_group = this; + state.vector_index = vector_offset; + state.max_row_group_row = + this->start > state.max_row ? 0 : MinValue(this->count, state.max_row - this->start); + D_ASSERT(state.column_scans); + for (idx_t i = 0; i < column_ids.size(); i++) { + const auto &column = column_ids[i]; + if (column != COLUMN_IDENTIFIER_ROW_ID) { + auto &column_data = GetColumn(column); + column_data.InitializeScanWithOffset(state.column_scans[i], start + vector_offset * STANDARD_VECTOR_SIZE); + } else { + state.column_scans[i].current = nullptr; + } + } + return true; +} + +bool RowGroup::InitializeScan(CollectionScanState &state) { + auto &column_ids = state.GetColumnIds(); + auto filters = state.GetFilters(); + if (filters) { + if (!CheckZonemap(*filters, column_ids)) { + return false; + } + } + state.row_group = this; + state.vector_index = 0; + state.max_row_group_row = + this->start > state.max_row ? 0 : MinValue(this->count, state.max_row - this->start); + if (state.max_row_group_row == 0) { + return false; + } + D_ASSERT(state.column_scans); + for (idx_t i = 0; i < column_ids.size(); i++) { + auto column = column_ids[i]; + if (column != COLUMN_IDENTIFIER_ROW_ID) { + auto &column_data = GetColumn(column); + column_data.InitializeScan(state.column_scans[i]); + } else { + state.column_scans[i].current = nullptr; + } + } + return true; +} + +unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, const LogicalType &target_type, + idx_t changed_idx, ExpressionExecutor &executor, + CollectionScanState &scan_state, DataChunk &scan_chunk) { + Verify(); + + // construct a new column data for this type + auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), changed_idx, start, target_type); + + ColumnAppendState append_state; + column_data->InitializeAppend(append_state); + + // scan the original table, and fill the new column with the transformed value + scan_state.Initialize(GetCollection().GetTypes()); + InitializeScan(scan_state); + + DataChunk append_chunk; + vector append_types; + append_types.push_back(target_type); + append_chunk.Initialize(Allocator::DefaultAllocator(), append_types); + auto &append_vector = append_chunk.data[0]; + while (true) { + // scan the table + scan_chunk.Reset(); + ScanCommitted(scan_state, scan_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); + if (scan_chunk.size() == 0) { + break; + } + // execute the expression + append_chunk.Reset(); + executor.ExecuteExpression(scan_chunk, append_vector); + column_data->Append(append_state, append_vector, scan_chunk.size()); + } + + // set up the row_group based on this row_group + auto row_group = make_uniq(new_collection, this->start, this->count); + row_group->version_info = GetOrCreateVersionInfoPtr(); + auto &cols = GetColumns(); + for (idx_t i = 0; i < cols.size(); i++) { + if (i == changed_idx) { + // this is the altered column: use the new column + row_group->columns.push_back(std::move(column_data)); + } else { + // this column was not altered: use the data directly + row_group->columns.push_back(cols[i]); + } + } + row_group->Verify(); + return row_group; +} + +unique_ptr RowGroup::AddColumn(RowGroupCollection &new_collection, ColumnDefinition &new_column, + ExpressionExecutor &executor, Expression &default_value, Vector &result) { + Verify(); + + // construct a new column data for the new column + auto added_column = + ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), GetColumnCount(), start, new_column.Type()); + + idx_t rows_to_write = this->count; + if (rows_to_write > 0) { + DataChunk dummy_chunk; + + ColumnAppendState state; + added_column->InitializeAppend(state); + for (idx_t i = 0; i < rows_to_write; i += STANDARD_VECTOR_SIZE) { + idx_t rows_in_this_vector = MinValue(rows_to_write - i, STANDARD_VECTOR_SIZE); + dummy_chunk.SetCardinality(rows_in_this_vector); + executor.ExecuteExpression(dummy_chunk, result); + added_column->Append(state, result, rows_in_this_vector); + } + } + + // set up the row_group based on this row_group + auto row_group = make_uniq(new_collection, this->start, this->count); + row_group->version_info = GetOrCreateVersionInfoPtr(); + row_group->columns = GetColumns(); + // now add the new column + row_group->columns.push_back(std::move(added_column)); + + row_group->Verify(); + return row_group; +} + +unique_ptr RowGroup::RemoveColumn(RowGroupCollection &new_collection, idx_t removed_column) { + Verify(); + + D_ASSERT(removed_column < columns.size()); + + auto row_group = make_uniq(new_collection, this->start, this->count); + row_group->version_info = GetOrCreateVersionInfoPtr(); + // copy over all columns except for the removed one + auto &cols = GetColumns(); + for (idx_t i = 0; i < cols.size(); i++) { + if (i != removed_column) { + row_group->columns.push_back(cols[i]); + } + } + + row_group->Verify(); + return row_group; +} + +void RowGroup::CommitDrop() { + for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { + CommitDropColumn(column_idx); + } +} + +void RowGroup::CommitDropColumn(idx_t column_idx) { + GetColumn(column_idx).CommitDropColumn(); +} + +void RowGroup::NextVector(CollectionScanState &state) { + state.vector_index++; + const auto &column_ids = state.GetColumnIds(); + for (idx_t i = 0; i < column_ids.size(); i++) { + const auto &column = column_ids[i]; + if (column == COLUMN_IDENTIFIER_ROW_ID) { + continue; + } + D_ASSERT(column < columns.size()); + GetColumn(column).Skip(state.column_scans[i]); + } +} + +bool RowGroup::CheckZonemap(TableFilterSet &filters, const vector &column_ids) { + for (auto &entry : filters.filters) { + auto column_index = entry.first; + auto &filter = entry.second; + const auto &base_column_index = column_ids[column_index]; + if (!GetColumn(base_column_index).CheckZonemap(*filter)) { + return false; + } + } + return true; +} + +bool RowGroup::CheckZonemapSegments(CollectionScanState &state) { + auto &column_ids = state.GetColumnIds(); + auto filters = state.GetFilters(); + if (!filters) { + return true; + } + for (auto &entry : filters->filters) { + D_ASSERT(entry.first < column_ids.size()); + auto column_idx = entry.first; + const auto &base_column_idx = column_ids[column_idx]; + bool read_segment = GetColumn(base_column_idx).CheckZonemap(state.column_scans[column_idx], *entry.second); + if (!read_segment) { + idx_t target_row = + state.column_scans[column_idx].current->start + state.column_scans[column_idx].current->count; + D_ASSERT(target_row >= this->start); + D_ASSERT(target_row <= this->start + this->count); + idx_t target_vector_index = (target_row - this->start) / STANDARD_VECTOR_SIZE; + if (state.vector_index == target_vector_index) { + // we can't skip any full vectors because this segment contains less than a full vector + // for now we just bail-out + // FIXME: we could check if we can ALSO skip the next segments, in which case skipping a full vector + // might be possible + // we don't care that much though, since a single segment that fits less than a full vector is + // exceedingly rare + return true; + } + while (state.vector_index < target_vector_index) { + NextVector(state); + } + return false; + } + } + + return true; +} + +template +void RowGroup::TemplatedScan(TransactionData transaction, CollectionScanState &state, DataChunk &result) { + const bool ALLOW_UPDATES = TYPE != TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES && + TYPE != TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED; + auto table_filters = state.GetFilters(); + const auto &column_ids = state.GetColumnIds(); + auto adaptive_filter = state.GetAdaptiveFilter(); + while (true) { + if (state.vector_index * STANDARD_VECTOR_SIZE >= state.max_row_group_row) { + // exceeded the amount of rows to scan + return; + } + idx_t current_row = state.vector_index * STANDARD_VECTOR_SIZE; + auto max_count = MinValue(STANDARD_VECTOR_SIZE, state.max_row_group_row - current_row); + + //! first check the zonemap if we have to scan this partition + if (!CheckZonemapSegments(state)) { + continue; + } + // second, scan the version chunk manager to figure out which tuples to load for this transaction + idx_t count; + SelectionVector valid_sel(STANDARD_VECTOR_SIZE); + if (TYPE == TableScanType::TABLE_SCAN_REGULAR) { + count = state.row_group->GetSelVector(transaction, state.vector_index, valid_sel, max_count); + if (count == 0) { + // nothing to scan for this vector, skip the entire vector + NextVector(state); + continue; + } + } else if (TYPE == TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED) { + count = state.row_group->GetCommittedSelVector(transaction.start_time, transaction.transaction_id, + state.vector_index, valid_sel, max_count); + if (count == 0) { + // nothing to scan for this vector, skip the entire vector + NextVector(state); + continue; + } + } else { + count = max_count; + } + if (count == max_count && !table_filters) { + // scan all vectors completely: full scan without deletions or table filters + for (idx_t i = 0; i < column_ids.size(); i++) { + const auto &column = column_ids[i]; + if (column == COLUMN_IDENTIFIER_ROW_ID) { + // scan row id + D_ASSERT(result.data[i].GetType().InternalType() == ROW_TYPE); + result.data[i].Sequence(this->start + current_row, 1, count); + } else { + auto &col_data = GetColumn(column); + if (TYPE != TableScanType::TABLE_SCAN_REGULAR) { + col_data.ScanCommitted(state.vector_index, state.column_scans[i], result.data[i], + ALLOW_UPDATES); + } else { + col_data.Scan(transaction, state.vector_index, state.column_scans[i], result.data[i]); + } + } + } + } else { + // partial scan: we have deletions or table filters + idx_t approved_tuple_count = count; + SelectionVector sel; + if (count != max_count) { + sel.Initialize(valid_sel); + } else { + sel.Initialize(nullptr); + } + //! first, we scan the columns with filters, fetch their data and generate a selection vector. + //! get runtime statistics + auto start_time = high_resolution_clock::now(); + if (table_filters) { + D_ASSERT(adaptive_filter); + D_ASSERT(ALLOW_UPDATES); + for (idx_t i = 0; i < table_filters->filters.size(); i++) { + auto tf_idx = adaptive_filter->permutation[i]; + auto col_idx = column_ids[tf_idx]; + auto &col_data = GetColumn(col_idx); + col_data.Select(transaction, state.vector_index, state.column_scans[tf_idx], result.data[tf_idx], + sel, approved_tuple_count, *table_filters->filters[tf_idx]); + } + for (auto &table_filter : table_filters->filters) { + result.data[table_filter.first].Slice(sel, approved_tuple_count); + } + } + if (approved_tuple_count == 0) { + // all rows were filtered out by the table filters + // skip this vector in all the scans that were not scanned yet + D_ASSERT(table_filters); + result.Reset(); + for (idx_t i = 0; i < column_ids.size(); i++) { + auto col_idx = column_ids[i]; + if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { + continue; + } + if (table_filters->filters.find(i) == table_filters->filters.end()) { + auto &col_data = GetColumn(col_idx); + col_data.Skip(state.column_scans[i]); + } + } + state.vector_index++; + continue; + } + //! Now we use the selection vector to fetch data for the other columns. + for (idx_t i = 0; i < column_ids.size(); i++) { + if (!table_filters || table_filters->filters.find(i) == table_filters->filters.end()) { + auto column = column_ids[i]; + if (column == COLUMN_IDENTIFIER_ROW_ID) { + D_ASSERT(result.data[i].GetType().InternalType() == PhysicalType::INT64); + result.data[i].SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result.data[i]); + for (size_t sel_idx = 0; sel_idx < approved_tuple_count; sel_idx++) { + result_data[sel_idx] = this->start + current_row + sel.get_index(sel_idx); + } + } else { + auto &col_data = GetColumn(column); + if (TYPE == TableScanType::TABLE_SCAN_REGULAR) { + col_data.FilterScan(transaction, state.vector_index, state.column_scans[i], result.data[i], + sel, approved_tuple_count); + } else { + col_data.FilterScanCommitted(state.vector_index, state.column_scans[i], result.data[i], sel, + approved_tuple_count, ALLOW_UPDATES); + } + } + } + } + auto end_time = high_resolution_clock::now(); + if (adaptive_filter && table_filters->filters.size() > 1) { + adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); + } + D_ASSERT(approved_tuple_count > 0); + count = approved_tuple_count; + } + result.SetCardinality(count); + state.vector_index++; + break; + } +} + +void RowGroup::Scan(TransactionData transaction, CollectionScanState &state, DataChunk &result) { + TemplatedScan(transaction, state, result); +} + +void RowGroup::ScanCommitted(CollectionScanState &state, DataChunk &result, TableScanType type) { + auto &transaction_manager = DuckTransactionManager::Get(GetCollection().GetAttached()); + + auto lowest_active_start = transaction_manager.LowestActiveStart(); + auto lowest_active_id = transaction_manager.LowestActiveId(); + TransactionData data(lowest_active_id, lowest_active_start); + switch (type) { + case TableScanType::TABLE_SCAN_COMMITTED_ROWS: + TemplatedScan(data, state, result); + break; + case TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES: + TemplatedScan(data, state, result); + break; + case TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED: + TemplatedScan(data, state, result); + break; + default: + throw InternalException("Unrecognized table scan type"); + } +} + +shared_ptr &RowGroup::GetVersionInfo() { + if (!HasUnloadedDeletes()) { + // deletes are loaded - return the version info + return version_info; + } + lock_guard lock(row_group_lock); + // double-check after obtaining the lock whether or not deletes are still not loaded to avoid double load + if (HasUnloadedDeletes()) { + // deletes are not loaded - reload + auto root_delete = deletes_pointers[0]; + version_info = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager(), start); + deletes_is_loaded = true; + } + return version_info; +} + +shared_ptr &RowGroup::GetOrCreateVersionInfoPtr() { + auto vinfo = GetVersionInfo(); + if (!vinfo) { + lock_guard lock(row_group_lock); + if (!version_info) { + version_info = make_shared(start); + } + } + return version_info; +} + +RowVersionManager &RowGroup::GetOrCreateVersionInfo() { + return *GetOrCreateVersionInfoPtr(); +} + +idx_t RowGroup::GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, + idx_t max_count) { + auto &vinfo = GetVersionInfo(); + if (!vinfo) { + return max_count; + } + return vinfo->GetSelVector(transaction, vector_idx, sel_vector, max_count); +} + +idx_t RowGroup::GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, + SelectionVector &sel_vector, idx_t max_count) { + auto &vinfo = GetVersionInfo(); + if (!vinfo) { + return max_count; + } + return vinfo->GetCommittedSelVector(start_time, transaction_id, vector_idx, sel_vector, max_count); +} + +bool RowGroup::Fetch(TransactionData transaction, idx_t row) { + D_ASSERT(row < this->count); + auto &vinfo = GetVersionInfo(); + if (!vinfo) { + return true; + } + return vinfo->Fetch(transaction, row); +} + +void RowGroup::FetchRow(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, + row_t row_id, DataChunk &result, idx_t result_idx) { + for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { + auto column = column_ids[col_idx]; + if (column == COLUMN_IDENTIFIER_ROW_ID) { + // row id column: fill in the row ids + D_ASSERT(result.data[col_idx].GetType().InternalType() == PhysicalType::INT64); + result.data[col_idx].SetVectorType(VectorType::FLAT_VECTOR); + auto data = FlatVector::GetData(result.data[col_idx]); + data[result_idx] = row_id; + } else { + // regular column: fetch data from the base column + auto &col_data = GetColumn(column); + col_data.FetchRow(transaction, state, row_id, result.data[col_idx], result_idx); + } + } +} + +void RowGroup::AppendVersionInfo(TransactionData transaction, idx_t count) { + idx_t row_group_start = this->count.load(); + idx_t row_group_end = row_group_start + count; + if (row_group_end > Storage::ROW_GROUP_SIZE) { + row_group_end = Storage::ROW_GROUP_SIZE; + } + // create the version_info if it doesn't exist yet + auto &vinfo = GetOrCreateVersionInfo(); + vinfo.AppendVersionInfo(transaction, count, row_group_start, row_group_end); + this->count = row_group_end; +} + +void RowGroup::CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_t count) { + auto &vinfo = GetOrCreateVersionInfo(); + vinfo.CommitAppend(commit_id, row_group_start, count); +} + +void RowGroup::RevertAppend(idx_t row_group_start) { + auto &vinfo = GetOrCreateVersionInfo(); + vinfo.RevertAppend(row_group_start - this->start); + for (auto &column : columns) { + column->RevertAppend(row_group_start); + } + this->count = MinValue(row_group_start - this->start, this->count); + Verify(); +} + +void RowGroup::InitializeAppend(RowGroupAppendState &append_state) { + append_state.row_group = this; + append_state.offset_in_row_group = this->count; + // for each column, initialize the append state + append_state.states = make_unsafe_uniq_array(GetColumnCount()); + for (idx_t i = 0; i < GetColumnCount(); i++) { + auto &col_data = GetColumn(i); + col_data.InitializeAppend(append_state.states[i]); + } +} + +void RowGroup::Append(RowGroupAppendState &state, DataChunk &chunk, idx_t append_count) { + // append to the current row_group + for (idx_t i = 0; i < GetColumnCount(); i++) { + auto &col_data = GetColumn(i); + col_data.Append(state.states[i], chunk.data[i], append_count); + } + state.offset_in_row_group += append_count; +} + +void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_t *ids, idx_t offset, idx_t count, + const vector &column_ids) { +#ifdef DEBUG + for (size_t i = offset; i < offset + count; i++) { + D_ASSERT(ids[i] >= row_t(this->start) && ids[i] < row_t(this->start + this->count)); + } +#endif + for (idx_t i = 0; i < column_ids.size(); i++) { + auto column = column_ids[i]; + D_ASSERT(column.index != COLUMN_IDENTIFIER_ROW_ID); + auto &col_data = GetColumn(column.index); + D_ASSERT(col_data.type.id() == update_chunk.data[i].GetType().id()); + if (offset > 0) { + Vector sliced_vector(update_chunk.data[i], offset, offset + count); + sliced_vector.Flatten(count); + col_data.Update(transaction, column.index, sliced_vector, ids + offset, count); + } else { + col_data.Update(transaction, column.index, update_chunk.data[i], ids, count); + } + MergeStatistics(column.index, *col_data.GetUpdateStatistics()); + } +} + +void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, + const vector &column_path) { + D_ASSERT(updates.ColumnCount() == 1); + auto ids = FlatVector::GetData(row_ids); + + auto primary_column_idx = column_path[0]; + D_ASSERT(primary_column_idx != COLUMN_IDENTIFIER_ROW_ID); + D_ASSERT(primary_column_idx < columns.size()); + auto &col_data = GetColumn(primary_column_idx); + col_data.UpdateColumn(transaction, column_path, updates.data[0], ids, updates.size(), 1); + MergeStatistics(primary_column_idx, *col_data.GetUpdateStatistics()); +} + +unique_ptr RowGroup::GetStatistics(idx_t column_idx) { + auto &col_data = GetColumn(column_idx); + lock_guard slock(stats_lock); + return col_data.GetStatistics(); +} + +void RowGroup::MergeStatistics(idx_t column_idx, const BaseStatistics &other) { + auto &col_data = GetColumn(column_idx); + lock_guard slock(stats_lock); + col_data.MergeStatistics(other); +} + +void RowGroup::MergeIntoStatistics(idx_t column_idx, BaseStatistics &other) { + auto &col_data = GetColumn(column_idx); + lock_guard slock(stats_lock); + col_data.MergeIntoStatistics(other); +} + +RowGroupWriteData RowGroup::WriteToDisk(PartialBlockManager &manager, + const vector &compression_types) { + RowGroupWriteData result; + result.states.reserve(columns.size()); + result.statistics.reserve(columns.size()); + + // Checkpoint the individual columns of the row group + // Here we're iterating over columns. Each column can have multiple segments. + // (Some columns will be wider than others, and require different numbers + // of blocks to encode.) Segments cannot span blocks. + // + // Some of these columns are composite (list, struct). The data is written + // first sequentially, and the pointers are written later, so that the + // pointers all end up densely packed, and thus more cache-friendly. + for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { + auto &column = GetColumn(column_idx); + ColumnCheckpointInfo checkpoint_info {compression_types[column_idx]}; + auto checkpoint_state = column.Checkpoint(*this, manager, checkpoint_info); + D_ASSERT(checkpoint_state); + + auto stats = checkpoint_state->GetStatistics(); + D_ASSERT(stats); + + result.statistics.push_back(stats->Copy()); + result.states.push_back(std::move(checkpoint_state)); + } + D_ASSERT(result.states.size() == result.statistics.size()); + return result; +} + +bool RowGroup::AllDeleted() { + if (HasUnloadedDeletes()) { + // deletes aren't loaded yet - we know not everything is deleted + return false; + } + auto &vinfo = GetVersionInfo(); + if (!vinfo) { + return false; + } + return vinfo->GetCommittedDeletedCount(count) == count; +} + +bool RowGroup::HasUnloadedDeletes() const { + if (deletes_pointers.empty()) { + // no stored deletes at all + return false; + } + // return whether or not the deletes have been loaded + return !deletes_is_loaded; +} + +RowGroupPointer RowGroup::Checkpoint(RowGroupWriter &writer, TableStatistics &global_stats) { + RowGroupPointer row_group_pointer; + + vector compression_types; + compression_types.reserve(columns.size()); + for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { + auto &column = GetColumn(column_idx); + if (column.count != this->count) { + throw InternalException("Corrupted in-memory column - column with index %llu has misaligned count (row " + "group has %llu rows, column has %llu)", + column_idx, this->count.load(), column.count); + } + compression_types.push_back(writer.GetColumnCompressionType(column_idx)); + } + auto result = WriteToDisk(writer.GetPartialBlockManager(), compression_types); + for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { + global_stats.GetStats(column_idx).Statistics().Merge(result.statistics[column_idx]); + } + + // construct the row group pointer and write the column meta data to disk + D_ASSERT(result.states.size() == columns.size()); + row_group_pointer.row_start = start; + row_group_pointer.tuple_count = count; + for (auto &state : result.states) { + // get the current position of the table data writer + auto &data_writer = writer.GetPayloadWriter(); + auto pointer = data_writer.GetMetaBlockPointer(); + + // store the stats and the data pointers in the row group pointers + row_group_pointer.data_pointers.push_back(pointer); + + // Write pointers to the column segments. + // + // Just as above, the state can refer to many other states, so this + // can cascade recursively into more pointer writes. + BinarySerializer serializer(data_writer); + serializer.Begin(); + state->WriteDataPointers(writer, serializer); + serializer.End(); + } + row_group_pointer.deletes_pointers = CheckpointDeletes(writer.GetPayloadWriter().GetManager()); + Verify(); + return row_group_pointer; +} + +vector RowGroup::CheckpointDeletes(MetadataManager &manager) { + if (HasUnloadedDeletes()) { + // deletes were not loaded so they cannot be changed + // re-use them as-is + manager.ClearModifiedBlocks(deletes_pointers); + return deletes_pointers; + } + if (!version_info) { + // no version information: write nothing + return vector(); + } + return version_info->Checkpoint(manager); +} + +void RowGroup::Serialize(RowGroupPointer &pointer, Serializer &serializer) { + serializer.WriteProperty(100, "row_start", pointer.row_start); + serializer.WriteProperty(101, "tuple_count", pointer.tuple_count); + serializer.WriteProperty(102, "data_pointers", pointer.data_pointers); + serializer.WriteProperty(103, "delete_pointers", pointer.deletes_pointers); +} + +RowGroupPointer RowGroup::Deserialize(Deserializer &deserializer) { + RowGroupPointer result; + result.row_start = deserializer.ReadProperty(100, "row_start"); + result.tuple_count = deserializer.ReadProperty(101, "tuple_count"); + result.data_pointers = deserializer.ReadProperty>(102, "data_pointers"); + result.deletes_pointers = deserializer.ReadProperty>(103, "delete_pointers"); + return result; +} + +//===--------------------------------------------------------------------===// +// GetColumnSegmentInfo +//===--------------------------------------------------------------------===// +void RowGroup::GetColumnSegmentInfo(idx_t row_group_index, vector &result) { + for (idx_t col_idx = 0; col_idx < GetColumnCount(); col_idx++) { + auto &col_data = GetColumn(col_idx); + col_data.GetColumnSegmentInfo(row_group_index, {col_idx}, result); + } +} + +//===--------------------------------------------------------------------===// +// Version Delete Information +//===--------------------------------------------------------------------===// +class VersionDeleteState { +public: + VersionDeleteState(RowGroup &info, TransactionData transaction, DataTable &table, idx_t base_row) + : info(info), transaction(transaction), table(table), current_chunk(DConstants::INVALID_INDEX), count(0), + base_row(base_row), delete_count(0) { + } + + RowGroup &info; + TransactionData transaction; + DataTable &table; + idx_t current_chunk; + row_t rows[STANDARD_VECTOR_SIZE]; + idx_t count; + idx_t base_row; + idx_t chunk_row; + idx_t delete_count; + +public: + void Delete(row_t row_id); + void Flush(); +}; + +idx_t RowGroup::Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count) { + VersionDeleteState del_state(*this, transaction, table, this->start); + + // obtain a write lock + for (idx_t i = 0; i < count; i++) { + D_ASSERT(ids[i] >= 0); + D_ASSERT(idx_t(ids[i]) >= this->start && idx_t(ids[i]) < this->start + this->count); + del_state.Delete(ids[i] - this->start); + } + del_state.Flush(); + return del_state.delete_count; +} + +void RowGroup::Verify() { +#ifdef DEBUG + for (auto &column : GetColumns()) { + column->Verify(*this); + } +#endif +} + +idx_t RowGroup::DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count) { + return GetOrCreateVersionInfo().DeleteRows(vector_idx, transaction_id, rows, count); +} + +void VersionDeleteState::Delete(row_t row_id) { + D_ASSERT(row_id >= 0); + idx_t vector_idx = row_id / STANDARD_VECTOR_SIZE; + idx_t idx_in_vector = row_id - vector_idx * STANDARD_VECTOR_SIZE; + if (current_chunk != vector_idx) { + Flush(); + + current_chunk = vector_idx; + chunk_row = vector_idx * STANDARD_VECTOR_SIZE; + } + rows[count++] = idx_in_vector; +} + +void VersionDeleteState::Flush() { + if (count == 0) { + return; + } + // it is possible for delete statements to delete the same tuple multiple times when combined with a USING clause + // in the current_info->Delete, we check which tuples are actually deleted (excluding duplicate deletions) + // this is returned in the actual_delete_count + auto actual_delete_count = info.DeleteRows(current_chunk, transaction.transaction_id, rows, count); + delete_count += actual_delete_count; + if (transaction.transaction && actual_delete_count > 0) { + // now push the delete into the undo buffer, but only if any deletes were actually performed + transaction.transaction->PushDelete(table, info.GetOrCreateVersionInfo(), current_chunk, rows, + actual_delete_count, base_row + chunk_row); + } + count = 0; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp b/src/duckdb/src/storage/table/row_group_collection.cpp new file mode 100644 index 00000000..c8342dfd --- /dev/null +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -0,0 +1,789 @@ +#include "duckdb/storage/table/row_group_collection.hpp" +#include "duckdb/storage/table/persistent_table_data.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" +#include "duckdb/storage/checkpoint/table_data_writer.hpp" +#include "duckdb/storage/table/row_group_segment_tree.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Row Group Segment Tree +//===--------------------------------------------------------------------===// +RowGroupSegmentTree::RowGroupSegmentTree(RowGroupCollection &collection) + : SegmentTree(), collection(collection), current_row_group(0), max_row_group(0) { +} +RowGroupSegmentTree::~RowGroupSegmentTree() { +} + +void RowGroupSegmentTree::Initialize(PersistentTableData &data) { + D_ASSERT(data.row_group_count > 0); + current_row_group = 0; + max_row_group = data.row_group_count; + finished_loading = false; + reader = make_uniq(collection.GetMetadataManager(), data.block_pointer); +} + +unique_ptr RowGroupSegmentTree::LoadSegment() { + if (current_row_group >= max_row_group) { + reader.reset(); + finished_loading = true; + return nullptr; + } + BinaryDeserializer deserializer(*reader); + deserializer.Begin(); + auto row_group_pointer = RowGroup::Deserialize(deserializer); + deserializer.End(); + current_row_group++; + return make_uniq(collection, std::move(row_group_pointer)); +} + +//===--------------------------------------------------------------------===// +// Row Group Collection +//===--------------------------------------------------------------------===// +RowGroupCollection::RowGroupCollection(shared_ptr info_p, BlockManager &block_manager, + vector types_p, idx_t row_start_p, idx_t total_rows_p) + : block_manager(block_manager), total_rows(total_rows_p), info(std::move(info_p)), types(std::move(types_p)), + row_start(row_start_p) { + row_groups = make_shared(*this); +} + +idx_t RowGroupCollection::GetTotalRows() const { + return total_rows.load(); +} + +const vector &RowGroupCollection::GetTypes() const { + return types; +} + +Allocator &RowGroupCollection::GetAllocator() const { + return Allocator::Get(info->db); +} + +AttachedDatabase &RowGroupCollection::GetAttached() { + return GetTableInfo().db; +} + +MetadataManager &RowGroupCollection::GetMetadataManager() { + return GetBlockManager().GetMetadataManager(); +} + +//===--------------------------------------------------------------------===// +// Initialize +//===--------------------------------------------------------------------===// +void RowGroupCollection::Initialize(PersistentTableData &data) { + D_ASSERT(this->row_start == 0); + auto l = row_groups->Lock(); + this->total_rows = data.total_rows; + row_groups->Initialize(data); + stats.Initialize(types, data); +} + +void RowGroupCollection::InitializeEmpty() { + stats.InitializeEmpty(types); +} + +void RowGroupCollection::AppendRowGroup(SegmentLock &l, idx_t start_row) { + D_ASSERT(start_row >= row_start); + auto new_row_group = make_uniq(*this, start_row, 0); + new_row_group->InitializeEmpty(types); + row_groups->AppendSegment(l, std::move(new_row_group)); +} + +RowGroup *RowGroupCollection::GetRowGroup(int64_t index) { + return (RowGroup *)row_groups->GetSegmentByIndex(index); +} + +void RowGroupCollection::Verify() { +#ifdef DEBUG + idx_t current_total_rows = 0; + row_groups->Verify(); + for (auto &row_group : row_groups->Segments()) { + row_group.Verify(); + D_ASSERT(&row_group.GetCollection() == this); + D_ASSERT(row_group.start == this->row_start + current_total_rows); + current_total_rows += row_group.count; + } + D_ASSERT(current_total_rows == total_rows.load()); +#endif +} + +//===--------------------------------------------------------------------===// +// Scan +//===--------------------------------------------------------------------===// +void RowGroupCollection::InitializeScan(CollectionScanState &state, const vector &column_ids, + TableFilterSet *table_filters) { + auto row_group = row_groups->GetRootSegment(); + D_ASSERT(row_group); + state.row_groups = row_groups.get(); + state.max_row = row_start + total_rows; + state.Initialize(GetTypes()); + while (row_group && !row_group->InitializeScan(state)) { + row_group = row_groups->GetNextSegment(row_group); + } +} + +void RowGroupCollection::InitializeCreateIndexScan(CreateIndexScanState &state) { + state.segment_lock = row_groups->Lock(); +} + +void RowGroupCollection::InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, + idx_t start_row, idx_t end_row) { + auto row_group = row_groups->GetSegment(start_row); + D_ASSERT(row_group); + state.row_groups = row_groups.get(); + state.max_row = end_row; + state.Initialize(GetTypes()); + idx_t start_vector = (start_row - row_group->start) / STANDARD_VECTOR_SIZE; + if (!row_group->InitializeScanWithOffset(state, start_vector)) { + throw InternalException("Failed to initialize row group scan with offset"); + } +} + +bool RowGroupCollection::InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, + RowGroup &row_group, idx_t vector_index, idx_t max_row) { + state.max_row = max_row; + state.row_groups = collection.row_groups.get(); + if (!state.column_scans) { + // initialize the scan state + state.Initialize(collection.GetTypes()); + } + return row_group.InitializeScanWithOffset(state, vector_index); +} + +void RowGroupCollection::InitializeParallelScan(ParallelCollectionScanState &state) { + state.collection = this; + state.current_row_group = row_groups->GetRootSegment(); + state.vector_index = 0; + state.max_row = row_start + total_rows; + state.batch_index = 0; + state.processed_rows = 0; +} + +bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollectionScanState &state, + CollectionScanState &scan_state) { + while (true) { + idx_t vector_index; + idx_t max_row; + RowGroupCollection *collection; + RowGroup *row_group; + { + // select the next row group to scan from the parallel state + lock_guard l(state.lock); + if (!state.current_row_group || state.current_row_group->count == 0) { + // no more data left to scan + break; + } + collection = state.collection; + row_group = state.current_row_group; + if (ClientConfig::GetConfig(context).verify_parallelism) { + vector_index = state.vector_index; + max_row = state.current_row_group->start + + MinValue(state.current_row_group->count, + STANDARD_VECTOR_SIZE * state.vector_index + STANDARD_VECTOR_SIZE); + D_ASSERT(vector_index * STANDARD_VECTOR_SIZE < state.current_row_group->count); + state.vector_index++; + if (state.vector_index * STANDARD_VECTOR_SIZE >= state.current_row_group->count) { + state.current_row_group = row_groups->GetNextSegment(state.current_row_group); + state.vector_index = 0; + } + } else { + state.processed_rows += state.current_row_group->count; + vector_index = 0; + max_row = state.current_row_group->start + state.current_row_group->count; + state.current_row_group = row_groups->GetNextSegment(state.current_row_group); + } + max_row = MinValue(max_row, state.max_row); + scan_state.batch_index = ++state.batch_index; + } + D_ASSERT(collection); + D_ASSERT(row_group); + + // initialize the scan for this row group + bool need_to_scan = InitializeScanInRowGroup(scan_state, *collection, *row_group, vector_index, max_row); + if (!need_to_scan) { + // skip this row group + continue; + } + return true; + } + return false; +} + +bool RowGroupCollection::Scan(DuckTransaction &transaction, const vector &column_ids, + const std::function &fun) { + vector scan_types; + for (idx_t i = 0; i < column_ids.size(); i++) { + scan_types.push_back(types[column_ids[i]]); + } + DataChunk chunk; + chunk.Initialize(GetAllocator(), scan_types); + + // initialize the scan + TableScanState state; + state.Initialize(column_ids, nullptr); + InitializeScan(state.local_state, column_ids, nullptr); + + while (true) { + chunk.Reset(); + state.local_state.Scan(transaction, chunk); + if (chunk.size() == 0) { + return true; + } + if (!fun(chunk)) { + return false; + } + } +} + +bool RowGroupCollection::Scan(DuckTransaction &transaction, const std::function &fun) { + vector column_ids; + column_ids.reserve(types.size()); + for (idx_t i = 0; i < types.size(); i++) { + column_ids.push_back(i); + } + return Scan(transaction, column_ids, fun); +} + +//===--------------------------------------------------------------------===// +// Fetch +//===--------------------------------------------------------------------===// +void RowGroupCollection::Fetch(TransactionData transaction, DataChunk &result, const vector &column_ids, + const Vector &row_identifiers, idx_t fetch_count, ColumnFetchState &state) { + // figure out which row_group to fetch from + auto row_ids = FlatVector::GetData(row_identifiers); + idx_t count = 0; + for (idx_t i = 0; i < fetch_count; i++) { + auto row_id = row_ids[i]; + RowGroup *row_group; + { + idx_t segment_index; + auto l = row_groups->Lock(); + if (!row_groups->TryGetSegmentIndex(l, row_id, segment_index)) { + // in parallel append scenarios it is possible for the row_id + continue; + } + row_group = row_groups->GetSegmentByIndex(l, segment_index); + } + if (!row_group->Fetch(transaction, row_id - row_group->start)) { + continue; + } + row_group->FetchRow(transaction, state, column_ids, row_id, result, count); + count++; + } + result.SetCardinality(count); +} + +//===--------------------------------------------------------------------===// +// Append +//===--------------------------------------------------------------------===// +TableAppendState::TableAppendState() + : row_group_append_state(*this), total_append_count(0), start_row_group(nullptr), transaction(0, 0), remaining(0) { +} + +TableAppendState::~TableAppendState() { + D_ASSERT(Exception::UncaughtException() || remaining == 0); +} + +bool RowGroupCollection::IsEmpty() const { + auto l = row_groups->Lock(); + return IsEmpty(l); +} + +bool RowGroupCollection::IsEmpty(SegmentLock &l) const { + return row_groups->IsEmpty(l); +} + +void RowGroupCollection::InitializeAppend(TransactionData transaction, TableAppendState &state, idx_t append_count) { + state.row_start = total_rows; + state.current_row = state.row_start; + state.total_append_count = 0; + + // start writing to the row_groups + auto l = row_groups->Lock(); + if (IsEmpty(l)) { + // empty row group collection: empty first row group + AppendRowGroup(l, row_start); + } + state.start_row_group = row_groups->GetLastSegment(l); + D_ASSERT(this->row_start + total_rows == state.start_row_group->start + state.start_row_group->count); + state.start_row_group->InitializeAppend(state.row_group_append_state); + state.remaining = append_count; + state.transaction = transaction; + if (state.remaining > 0) { + state.start_row_group->AppendVersionInfo(transaction, state.remaining); + total_rows += state.remaining; + } +} + +void RowGroupCollection::InitializeAppend(TableAppendState &state) { + TransactionData tdata(0, 0); + InitializeAppend(tdata, state, 0); +} + +bool RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { + D_ASSERT(chunk.ColumnCount() == types.size()); + chunk.Verify(); + + bool new_row_group = false; + idx_t append_count = chunk.size(); + idx_t remaining = chunk.size(); + state.total_append_count += append_count; + while (true) { + auto current_row_group = state.row_group_append_state.row_group; + // check how much we can fit into the current row_group + idx_t append_count = + MinValue(remaining, Storage::ROW_GROUP_SIZE - state.row_group_append_state.offset_in_row_group); + if (append_count > 0) { + current_row_group->Append(state.row_group_append_state, chunk, append_count); + // merge the stats + auto stats_lock = stats.GetLock(); + for (idx_t i = 0; i < types.size(); i++) { + current_row_group->MergeIntoStatistics(i, stats.GetStats(i).Statistics()); + } + } + remaining -= append_count; + if (state.remaining > 0) { + state.remaining -= append_count; + } + if (remaining > 0) { + // we expect max 1 iteration of this loop (i.e. a single chunk should never overflow more than one + // row_group) + D_ASSERT(chunk.size() == remaining + append_count); + // slice the input chunk + if (remaining < chunk.size()) { + SelectionVector sel(remaining); + for (idx_t i = 0; i < remaining; i++) { + sel.set_index(i, append_count + i); + } + chunk.Slice(sel, remaining); + } + // append a new row_group + new_row_group = true; + auto next_start = current_row_group->start + state.row_group_append_state.offset_in_row_group; + + auto l = row_groups->Lock(); + AppendRowGroup(l, next_start); + // set up the append state for this row_group + auto last_row_group = row_groups->GetLastSegment(l); + last_row_group->InitializeAppend(state.row_group_append_state); + if (state.remaining > 0) { + last_row_group->AppendVersionInfo(state.transaction, state.remaining); + } + continue; + } else { + break; + } + } + state.current_row += append_count; + auto stats_lock = stats.GetLock(); + for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { + stats.GetStats(col_idx).UpdateDistinctStatistics(chunk.data[col_idx], chunk.size()); + } + return new_row_group; +} + +void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppendState &state) { + auto remaining = state.total_append_count; + auto row_group = state.start_row_group; + while (remaining > 0) { + auto append_count = MinValue(remaining, Storage::ROW_GROUP_SIZE - row_group->count); + row_group->AppendVersionInfo(transaction, append_count); + remaining -= append_count; + row_group = row_groups->GetNextSegment(row_group); + } + total_rows += state.total_append_count; + + state.total_append_count = 0; + state.start_row_group = nullptr; + + Verify(); +} + +void RowGroupCollection::CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count) { + auto row_group = row_groups->GetSegment(row_start); + D_ASSERT(row_group); + idx_t current_row = row_start; + idx_t remaining = count; + while (true) { + idx_t start_in_row_group = current_row - row_group->start; + idx_t append_count = MinValue(row_group->count - start_in_row_group, remaining); + + row_group->CommitAppend(commit_id, start_in_row_group, append_count); + + current_row += append_count; + remaining -= append_count; + if (remaining == 0) { + break; + } + row_group = row_groups->GetNextSegment(row_group); + } +} + +void RowGroupCollection::RevertAppendInternal(idx_t start_row) { + if (total_rows <= start_row) { + return; + } + total_rows = start_row; + + auto l = row_groups->Lock(); + // find the segment index that the current row belongs to + idx_t segment_index = row_groups->GetSegmentIndex(l, start_row); + auto segment = row_groups->GetSegmentByIndex(l, segment_index); + auto &info = *segment; + + // remove any segments AFTER this segment: they should be deleted entirely + row_groups->EraseSegments(l, segment_index); + + info.next = nullptr; + info.RevertAppend(start_row); +} + +void RowGroupCollection::MergeStorage(RowGroupCollection &data) { + D_ASSERT(data.types == types); + auto index = row_start + total_rows.load(); + auto segments = data.row_groups->MoveSegments(); + for (auto &entry : segments) { + auto &row_group = entry.node; + row_group->MoveToCollection(*this, index); + index += row_group->count; + row_groups->AppendSegment(std::move(row_group)); + } + stats.MergeStats(data.stats); + total_rows += data.total_rows.load(); +} + +//===--------------------------------------------------------------------===// +// Delete +//===--------------------------------------------------------------------===// +idx_t RowGroupCollection::Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count) { + idx_t delete_count = 0; + // delete is in the row groups + // we need to figure out for each id to which row group it belongs + // usually all (or many) ids belong to the same row group + // we iterate over the ids and check for every id if it belongs to the same row group as their predecessor + idx_t pos = 0; + do { + idx_t start = pos; + auto row_group = row_groups->GetSegment(ids[start]); + for (pos++; pos < count; pos++) { + D_ASSERT(ids[pos] >= 0); + // check if this id still belongs to this row group + if (idx_t(ids[pos]) < row_group->start) { + // id is before row_group start -> it does not + break; + } + if (idx_t(ids[pos]) >= row_group->start + row_group->count) { + // id is after row group end -> it does not + break; + } + } + delete_count += row_group->Delete(transaction, table, ids + start, pos - start); + } while (pos < count); + return delete_count; +} + +//===--------------------------------------------------------------------===// +// Update +//===--------------------------------------------------------------------===// +void RowGroupCollection::Update(TransactionData transaction, row_t *ids, const vector &column_ids, + DataChunk &updates) { + idx_t pos = 0; + do { + idx_t start = pos; + auto row_group = row_groups->GetSegment(ids[pos]); + row_t base_id = + row_group->start + ((ids[pos] - row_group->start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); + row_t max_id = MinValue(base_id + STANDARD_VECTOR_SIZE, row_group->start + row_group->count); + for (pos++; pos < updates.size(); pos++) { + D_ASSERT(ids[pos] >= 0); + // check if this id still belongs to this vector in this row group + if (ids[pos] < base_id) { + // id is before vector start -> it does not + break; + } + if (ids[pos] >= max_id) { + // id is after the maximum id in this vector -> it does not + break; + } + } + row_group->Update(transaction, updates, ids, start, pos - start, column_ids); + + auto l = stats.GetLock(); + for (idx_t i = 0; i < column_ids.size(); i++) { + auto column_id = column_ids[i]; + stats.MergeStats(*l, column_id.index, *row_group->GetStatistics(column_id.index)); + } + } while (pos < updates.size()); +} + +void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count) { + auto row_ids = FlatVector::GetData(row_identifiers); + + // initialize the fetch state + // FIXME: we do not need to fetch all columns, only the columns required by the indices! + TableScanState state; + vector column_ids; + column_ids.reserve(types.size()); + for (idx_t i = 0; i < types.size(); i++) { + column_ids.push_back(i); + } + state.Initialize(std::move(column_ids)); + state.table_state.max_row = row_start + total_rows; + + // initialize the fetch chunk + DataChunk result; + result.Initialize(GetAllocator(), types); + + SelectionVector sel(STANDARD_VECTOR_SIZE); + // now iterate over the row ids + for (idx_t r = 0; r < count;) { + result.Reset(); + // figure out which row_group to fetch from + auto row_id = row_ids[r]; + auto row_group = row_groups->GetSegment(row_id); + auto row_group_vector_idx = (row_id - row_group->start) / STANDARD_VECTOR_SIZE; + auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + row_group->start; + + // fetch the current vector + state.table_state.Initialize(GetTypes()); + row_group->InitializeScanWithOffset(state.table_state, row_group_vector_idx); + row_group->ScanCommitted(state.table_state, result, TableScanType::TABLE_SCAN_COMMITTED_ROWS); + result.Verify(); + + // check for any remaining row ids if they also fall into this vector + // we try to fetch handle as many rows as possible at the same time + idx_t sel_count = 0; + for (; r < count; r++) { + idx_t current_row = idx_t(row_ids[r]); + if (current_row < base_row_id || current_row >= base_row_id + result.size()) { + // this row-id does not fall into the current chunk - break + break; + } + auto row_in_vector = current_row - base_row_id; + D_ASSERT(row_in_vector < result.size()); + sel.set_index(sel_count++, row_in_vector); + } + D_ASSERT(sel_count > 0); + // slice the vector with all rows that are present in this vector and erase from the index + result.Slice(sel, sel_count); + + indexes.Scan([&](Index &index) { + index.Delete(result, row_identifiers); + return false; + }); + } +} + +void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, + DataChunk &updates) { + auto first_id = FlatVector::GetValue(row_ids, 0); + if (first_id >= MAX_ROW_ID) { + throw NotImplementedException("Cannot update a column-path on transaction local data"); + } + // find the row_group this id belongs to + auto primary_column_idx = column_path[0]; + auto row_group = row_groups->GetSegment(first_id); + row_group->UpdateColumn(transaction, updates, row_ids, column_path); + + row_group->MergeIntoStatistics(primary_column_idx, stats.GetStats(primary_column_idx).Statistics()); +} + +//===--------------------------------------------------------------------===// +// Checkpoint +//===--------------------------------------------------------------------===// +void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &global_stats) { + bool can_vacuum_deletes = info->indexes.Empty(); + idx_t start = this->row_start; + auto segments = row_groups->MoveSegments(); + auto l = row_groups->Lock(); + for (auto &entry : segments) { + auto &row_group = *entry.node; + if (can_vacuum_deletes && row_group.AllDeleted()) { + row_group.CommitDrop(); + continue; + } + row_group.MoveToCollection(*this, start); + auto row_group_writer = writer.GetRowGroupWriter(row_group); + auto pointer = row_group.Checkpoint(*row_group_writer, global_stats); + writer.AddRowGroup(std::move(pointer), std::move(row_group_writer)); + row_groups->AppendSegment(l, std::move(entry.node)); + start += row_group.count; + } + total_rows = start; +} + +//===--------------------------------------------------------------------===// +// CommitDrop +//===--------------------------------------------------------------------===// +void RowGroupCollection::CommitDropColumn(idx_t index) { + for (auto &row_group : row_groups->Segments()) { + row_group.CommitDropColumn(index); + } +} + +void RowGroupCollection::CommitDropTable() { + for (auto &row_group : row_groups->Segments()) { + row_group.CommitDrop(); + } +} + +//===--------------------------------------------------------------------===// +// GetColumnSegmentInfo +//===--------------------------------------------------------------------===// +vector RowGroupCollection::GetColumnSegmentInfo() { + vector result; + for (auto &row_group : row_groups->Segments()) { + row_group.GetColumnSegmentInfo(row_group.index, result); + } + return result; +} + +//===--------------------------------------------------------------------===// +// Alter +//===--------------------------------------------------------------------===// +shared_ptr RowGroupCollection::AddColumn(ClientContext &context, ColumnDefinition &new_column, + Expression &default_value) { + idx_t new_column_idx = types.size(); + auto new_types = types; + new_types.push_back(new_column.GetType()); + auto result = + make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); + + ExpressionExecutor executor(context); + DataChunk dummy_chunk; + Vector default_vector(new_column.GetType()); + executor.AddExpression(default_value); + + result->stats.InitializeAddColumn(stats, new_column.GetType()); + auto &new_column_stats = result->stats.GetStats(new_column_idx); + + // fill the column with its DEFAULT value, or NULL if none is specified + auto new_stats = make_uniq(new_column.GetType()); + for (auto ¤t_row_group : row_groups->Segments()) { + auto new_row_group = current_row_group.AddColumn(*result, new_column, executor, default_value, default_vector); + // merge in the statistics + new_row_group->MergeIntoStatistics(new_column_idx, new_column_stats.Statistics()); + + result->row_groups->AppendSegment(std::move(new_row_group)); + } + return result; +} + +shared_ptr RowGroupCollection::RemoveColumn(idx_t col_idx) { + D_ASSERT(col_idx < types.size()); + auto new_types = types; + new_types.erase(new_types.begin() + col_idx); + + auto result = + make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); + result->stats.InitializeRemoveColumn(stats, col_idx); + + for (auto ¤t_row_group : row_groups->Segments()) { + auto new_row_group = current_row_group.RemoveColumn(*result, col_idx); + result->row_groups->AppendSegment(std::move(new_row_group)); + } + return result; +} + +shared_ptr RowGroupCollection::AlterType(ClientContext &context, idx_t changed_idx, + const LogicalType &target_type, + vector bound_columns, Expression &cast_expr) { + D_ASSERT(changed_idx < types.size()); + auto new_types = types; + new_types[changed_idx] = target_type; + + auto result = + make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); + result->stats.InitializeAlterType(stats, changed_idx, target_type); + + vector scan_types; + for (idx_t i = 0; i < bound_columns.size(); i++) { + if (bound_columns[i] == COLUMN_IDENTIFIER_ROW_ID) { + scan_types.emplace_back(LogicalType::ROW_TYPE); + } else { + scan_types.push_back(types[bound_columns[i]]); + } + } + DataChunk scan_chunk; + scan_chunk.Initialize(GetAllocator(), scan_types); + + ExpressionExecutor executor(context); + executor.AddExpression(cast_expr); + + TableScanState scan_state; + scan_state.Initialize(bound_columns); + scan_state.table_state.max_row = row_start + total_rows; + + // now alter the type of the column within all of the row_groups individually + auto &changed_stats = result->stats.GetStats(changed_idx); + for (auto ¤t_row_group : row_groups->Segments()) { + auto new_row_group = current_row_group.AlterType(*result, target_type, changed_idx, executor, + scan_state.table_state, scan_chunk); + new_row_group->MergeIntoStatistics(changed_idx, changed_stats.Statistics()); + result->row_groups->AppendSegment(std::move(new_row_group)); + } + + return result; +} + +void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint) { + if (total_rows == 0) { + return; + } + // scan the original table, check if there's any null value + auto ¬_null_constraint = constraint.Cast(); + vector scan_types; + auto physical_index = not_null_constraint.index.index; + D_ASSERT(physical_index < types.size()); + scan_types.push_back(types[physical_index]); + DataChunk scan_chunk; + scan_chunk.Initialize(GetAllocator(), scan_types); + + CreateIndexScanState state; + vector cids; + cids.push_back(physical_index); + // Use ScanCommitted to scan the latest committed data + state.Initialize(cids, nullptr); + InitializeScan(state.table_state, cids, nullptr); + InitializeCreateIndexScan(state); + while (true) { + scan_chunk.Reset(); + state.table_state.ScanCommitted(scan_chunk, state.segment_lock, + TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); + if (scan_chunk.size() == 0) { + break; + } + // Check constraint + if (VectorOperations::HasNull(scan_chunk.data[0], scan_chunk.size())) { + throw ConstraintException("NOT NULL constraint failed: %s.%s", info->table, + parent.column_definitions[physical_index].GetName()); + } + } +} + +//===--------------------------------------------------------------------===// +// Statistics +//===--------------------------------------------------------------------===// +void RowGroupCollection::CopyStats(TableStatistics &other_stats) { + stats.CopyStats(other_stats); +} + +unique_ptr RowGroupCollection::CopyStats(column_t column_id) { + return stats.CopyStats(column_id); +} + +void RowGroupCollection::SetDistinct(column_t column_id, unique_ptr distinct_stats) { + D_ASSERT(column_id != COLUMN_IDENTIFIER_ROW_ID); + auto stats_guard = stats.GetLock(); + stats.GetStats(column_id).SetDistinct(std::move(distinct_stats)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/row_version_manager.cpp b/src/duckdb/src/storage/table/row_version_manager.cpp new file mode 100644 index 00000000..9fa7c46c --- /dev/null +++ b/src/duckdb/src/storage/table/row_version_manager.cpp @@ -0,0 +1,228 @@ +#include "duckdb/storage/table/row_version_manager.hpp" +#include "duckdb/transaction/transaction_data.hpp" +#include "duckdb/storage/metadata/metadata_manager.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/storage/metadata/metadata_writer.hpp" +#include "duckdb/common/pair.hpp" + +namespace duckdb { + +RowVersionManager::RowVersionManager(idx_t start) : start(start), has_changes(false) { +} + +void RowVersionManager::SetStart(idx_t new_start) { + lock_guard l(version_lock); + this->start = new_start; + idx_t current_start = start; + for (idx_t i = 0; i < Storage::ROW_GROUP_VECTOR_COUNT; i++) { + if (vector_info[i]) { + vector_info[i]->start = current_start; + } + current_start += STANDARD_VECTOR_SIZE; + } +} + +idx_t RowVersionManager::GetCommittedDeletedCount(idx_t count) { + lock_guard l(version_lock); + idx_t deleted_count = 0; + for (idx_t r = 0, i = 0; r < count; r += STANDARD_VECTOR_SIZE, i++) { + if (!vector_info[i]) { + continue; + } + idx_t max_count = MinValue(STANDARD_VECTOR_SIZE, count - r); + if (max_count == 0) { + break; + } + deleted_count += vector_info[i]->GetCommittedDeletedCount(max_count); + } + return deleted_count; +} + +optional_ptr RowVersionManager::GetChunkInfo(idx_t vector_idx) { + return vector_info[vector_idx].get(); +} + +idx_t RowVersionManager::GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, + idx_t max_count) { + lock_guard l(version_lock); + auto chunk_info = GetChunkInfo(vector_idx); + if (!chunk_info) { + return max_count; + } + return chunk_info->GetSelVector(transaction, sel_vector, max_count); +} + +idx_t RowVersionManager::GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, + SelectionVector &sel_vector, idx_t max_count) { + lock_guard l(version_lock); + auto info = GetChunkInfo(vector_idx); + if (!info) { + return max_count; + } + return info->GetCommittedSelVector(start_time, transaction_id, sel_vector, max_count); +} + +bool RowVersionManager::Fetch(TransactionData transaction, idx_t row) { + lock_guard lock(version_lock); + idx_t vector_index = row / STANDARD_VECTOR_SIZE; + auto info = GetChunkInfo(vector_index); + if (!info) { + return true; + } + return info->Fetch(transaction, row - vector_index * STANDARD_VECTOR_SIZE); +} + +void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t count, idx_t row_group_start, + idx_t row_group_end) { + lock_guard lock(version_lock); + has_changes = true; + idx_t start_vector_idx = row_group_start / STANDARD_VECTOR_SIZE; + idx_t end_vector_idx = (row_group_end - 1) / STANDARD_VECTOR_SIZE; + for (idx_t vector_idx = start_vector_idx; vector_idx <= end_vector_idx; vector_idx++) { + idx_t vector_start = + vector_idx == start_vector_idx ? row_group_start - start_vector_idx * STANDARD_VECTOR_SIZE : 0; + idx_t vector_end = + vector_idx == end_vector_idx ? row_group_end - end_vector_idx * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; + if (vector_start == 0 && vector_end == STANDARD_VECTOR_SIZE) { + // entire vector is encapsulated by append: append a single constant + auto constant_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + constant_info->insert_id = transaction.transaction_id; + constant_info->delete_id = NOT_DELETED_ID; + vector_info[vector_idx] = std::move(constant_info); + } else { + // part of a vector is encapsulated: append to that part + optional_ptr new_info; + if (!vector_info[vector_idx]) { + // first time appending to this vector: create new info + auto insert_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + new_info = insert_info.get(); + vector_info[vector_idx] = std::move(insert_info); + } else if (vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO) { + // use existing vector + new_info = &vector_info[vector_idx]->Cast(); + } else { + throw InternalException("Error in RowVersionManager::AppendVersionInfo - expected either a " + "ChunkVectorInfo or no version info"); + } + new_info->Append(vector_start, vector_end, transaction.transaction_id); + } + } +} + +void RowVersionManager::CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_t count) { + idx_t row_group_end = row_group_start + count; + + lock_guard lock(version_lock); + idx_t start_vector_idx = row_group_start / STANDARD_VECTOR_SIZE; + idx_t end_vector_idx = (row_group_end - 1) / STANDARD_VECTOR_SIZE; + for (idx_t vector_idx = start_vector_idx; vector_idx <= end_vector_idx; vector_idx++) { + idx_t vstart = vector_idx == start_vector_idx ? row_group_start - start_vector_idx * STANDARD_VECTOR_SIZE : 0; + idx_t vend = + vector_idx == end_vector_idx ? row_group_end - end_vector_idx * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; + + auto info = vector_info[vector_idx].get(); + info->CommitAppend(commit_id, vstart, vend); + } +} + +void RowVersionManager::RevertAppend(idx_t start_row) { + lock_guard lock(version_lock); + idx_t start_vector_idx = (start_row + (STANDARD_VECTOR_SIZE - 1)) / STANDARD_VECTOR_SIZE; + for (idx_t vector_idx = start_vector_idx; vector_idx < Storage::ROW_GROUP_VECTOR_COUNT; vector_idx++) { + vector_info[vector_idx].reset(); + } +} + +ChunkVectorInfo &RowVersionManager::GetVectorInfo(idx_t vector_idx) { + if (!vector_info[vector_idx]) { + // no info yet: create it + vector_info[vector_idx] = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + } else if (vector_info[vector_idx]->type == ChunkInfoType::CONSTANT_INFO) { + auto &constant = vector_info[vector_idx]->Cast(); + // info exists but it's a constant info: convert to a vector info + auto new_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + new_info->insert_id = constant.insert_id; + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + new_info->inserted[i] = constant.insert_id; + } + vector_info[vector_idx] = std::move(new_info); + } + D_ASSERT(vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO); + return vector_info[vector_idx]->Cast(); +} + +idx_t RowVersionManager::DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count) { + lock_guard lock(version_lock); + has_changes = true; + return GetVectorInfo(vector_idx).Delete(transaction_id, rows, count); +} + +void RowVersionManager::CommitDelete(idx_t vector_idx, transaction_t commit_id, row_t rows[], idx_t count) { + lock_guard lock(version_lock); + has_changes = true; + GetVectorInfo(vector_idx).CommitDelete(commit_id, rows, count); +} + +vector RowVersionManager::Checkpoint(MetadataManager &manager) { + if (!has_changes && !storage_pointers.empty()) { + // the row version manager already exists on disk and no changes were made + // we can write the current pointer as-is + // ensure the blocks we are pointing to are not marked as free + manager.ClearModifiedBlocks(storage_pointers); + // return the root pointer + return storage_pointers; + } + // first count how many ChunkInfo's we need to deserialize + vector>> to_serialize; + for (idx_t vector_idx = 0; vector_idx < Storage::ROW_GROUP_VECTOR_COUNT; vector_idx++) { + auto chunk_info = vector_info[vector_idx].get(); + if (!chunk_info) { + continue; + } + if (!chunk_info->HasDeletes()) { + continue; + } + to_serialize.emplace_back(vector_idx, *chunk_info); + } + if (to_serialize.empty()) { + return vector(); + } + + storage_pointers.clear(); + + MetadataWriter writer(manager, &storage_pointers); + // now serialize the actual version information + writer.Write(to_serialize.size()); + for (auto &entry : to_serialize) { + auto &vector_idx = entry.first; + auto &chunk_info = entry.second.get(); + writer.Write(vector_idx); + chunk_info.Write(writer); + } + writer.Flush(); + + has_changes = false; + return storage_pointers; +} + +shared_ptr RowVersionManager::Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager, + idx_t start) { + if (!delete_pointer.IsValid()) { + return nullptr; + } + auto version_info = make_shared(start); + MetadataReader source(manager, delete_pointer, &version_info->storage_pointers); + auto chunk_count = source.Read(); + D_ASSERT(chunk_count > 0); + for (idx_t i = 0; i < chunk_count; i++) { + idx_t vector_index = source.Read(); + if (vector_index >= Storage::ROW_GROUP_VECTOR_COUNT) { + throw Exception("In DeserializeDeletes, vector_index is out of range for the row group. Corrupted file?"); + } + version_info->vector_info[vector_index] = ChunkInfo::Read(source); + } + version_info->has_changes = false; + return version_info; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/scan_state.cpp b/src/duckdb/src/storage/table/scan_state.cpp new file mode 100644 index 00000000..a40a9b1f --- /dev/null +++ b/src/duckdb/src/storage/table/scan_state.cpp @@ -0,0 +1,137 @@ +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/row_group_collection.hpp" +#include "duckdb/storage/table/row_group_segment_tree.hpp" + +namespace duckdb { + +void TableScanState::Initialize(vector column_ids, TableFilterSet *table_filters) { + this->column_ids = std::move(column_ids); + this->table_filters = table_filters; + if (table_filters) { + D_ASSERT(table_filters->filters.size() > 0); + this->adaptive_filter = make_uniq(table_filters); + } +} + +const vector &TableScanState::GetColumnIds() { + D_ASSERT(!column_ids.empty()); + return column_ids; +} + +TableFilterSet *TableScanState::GetFilters() { + D_ASSERT(!table_filters || adaptive_filter.get()); + return table_filters; +} + +AdaptiveFilter *TableScanState::GetAdaptiveFilter() { + return adaptive_filter.get(); +} + +void ColumnScanState::NextInternal(idx_t count) { + if (!current) { + //! There is no column segment + return; + } + row_index += count; + while (row_index >= current->start + current->count) { + current = segment_tree->GetNextSegment(current); + initialized = false; + segment_checked = false; + if (!current) { + break; + } + } + D_ASSERT(!current || (row_index >= current->start && row_index < current->start + current->count)); +} + +void ColumnScanState::Next(idx_t count) { + NextInternal(count); + for (auto &child_state : child_states) { + child_state.Next(count); + } +} + +const vector &CollectionScanState::GetColumnIds() { + return parent.GetColumnIds(); +} + +TableFilterSet *CollectionScanState::GetFilters() { + return parent.GetFilters(); +} + +AdaptiveFilter *CollectionScanState::GetAdaptiveFilter() { + return parent.GetAdaptiveFilter(); +} + +ParallelCollectionScanState::ParallelCollectionScanState() + : collection(nullptr), current_row_group(nullptr), processed_rows(0) { +} + +CollectionScanState::CollectionScanState(TableScanState &parent_p) + : row_group(nullptr), vector_index(0), max_row_group_row(0), row_groups(nullptr), max_row(0), batch_index(0), + parent(parent_p) { +} + +bool CollectionScanState::Scan(DuckTransaction &transaction, DataChunk &result) { + while (row_group) { + row_group->Scan(transaction, *this, result); + if (result.size() > 0) { + return true; + } else if (max_row <= row_group->start + row_group->count) { + row_group = nullptr; + return false; + } else { + do { + row_group = row_groups->GetNextSegment(row_group); + if (row_group) { + if (row_group->start >= max_row) { + row_group = nullptr; + break; + } + bool scan_row_group = row_group->InitializeScan(*this); + if (scan_row_group) { + // scan this row group + break; + } + } + } while (row_group); + } + } + return false; +} + +bool CollectionScanState::ScanCommitted(DataChunk &result, SegmentLock &l, TableScanType type) { + while (row_group) { + row_group->ScanCommitted(*this, result, type); + if (result.size() > 0) { + return true; + } else { + row_group = row_groups->GetNextSegment(l, row_group); + if (row_group) { + row_group->InitializeScan(*this); + } + } + } + return false; +} + +bool CollectionScanState::ScanCommitted(DataChunk &result, TableScanType type) { + while (row_group) { + row_group->ScanCommitted(*this, result, type); + if (result.size() > 0) { + return true; + } else { + row_group = row_groups->GetNextSegment(row_group); + if (row_group) { + row_group->InitializeScan(*this); + } + } + } + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/standard_column_data.cpp b/src/duckdb/src/storage/table/standard_column_data.cpp new file mode 100644 index 00000000..e98ffb4e --- /dev/null +++ b/src/duckdb/src/storage/table/standard_column_data.cpp @@ -0,0 +1,227 @@ +#include "duckdb/storage/table/standard_column_data.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +StandardColumnData::StandardColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + idx_t start_row, LogicalType type, optional_ptr parent) + : ColumnData(block_manager, info, column_index, start_row, std::move(type), parent), + validity(block_manager, info, 0, start_row, *this) { +} + +void StandardColumnData::SetStart(idx_t new_start) { + ColumnData::SetStart(new_start); + validity.SetStart(new_start); +} + +bool StandardColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { + if (!state.segment_checked) { + if (!state.current) { + return true; + } + state.segment_checked = true; + auto prune_result = filter.CheckStatistics(state.current->stats.statistics); + if (prune_result != FilterPropagateResult::FILTER_ALWAYS_FALSE) { + return true; + } + if (updates) { + auto update_stats = updates->GetStatistics(); + prune_result = filter.CheckStatistics(*update_stats); + return prune_result != FilterPropagateResult::FILTER_ALWAYS_FALSE; + } else { + return false; + } + } else { + return true; + } +} + +void StandardColumnData::InitializeScan(ColumnScanState &state) { + ColumnData::InitializeScan(state); + + // initialize the validity segment + D_ASSERT(state.child_states.size() == 1); + validity.InitializeScan(state.child_states[0]); +} + +void StandardColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { + ColumnData::InitializeScanWithOffset(state, row_idx); + + // initialize the validity segment + D_ASSERT(state.child_states.size() == 1); + validity.InitializeScanWithOffset(state.child_states[0], row_idx); +} + +idx_t StandardColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, + Vector &result) { + D_ASSERT(state.row_index == state.child_states[0].row_index); + auto scan_count = ColumnData::Scan(transaction, vector_index, state, result); + validity.Scan(transaction, vector_index, state.child_states[0], result); + return scan_count; +} + +idx_t StandardColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, + bool allow_updates) { + D_ASSERT(state.row_index == state.child_states[0].row_index); + auto scan_count = ColumnData::ScanCommitted(vector_index, state, result, allow_updates); + validity.ScanCommitted(vector_index, state.child_states[0], result, allow_updates); + return scan_count; +} + +idx_t StandardColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { + auto scan_count = ColumnData::ScanCount(state, result, count); + validity.ScanCount(state.child_states[0], result, count); + return scan_count; +} + +void StandardColumnData::InitializeAppend(ColumnAppendState &state) { + ColumnData::InitializeAppend(state); + + ColumnAppendState child_append; + validity.InitializeAppend(child_append); + state.child_appends.push_back(std::move(child_append)); +} + +void StandardColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, + idx_t count) { + ColumnData::AppendData(stats, state, vdata, count); + validity.AppendData(stats, state.child_appends[0], vdata, count); +} + +void StandardColumnData::RevertAppend(row_t start_row) { + ColumnData::RevertAppend(start_row); + + validity.RevertAppend(start_row); +} + +idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { + // fetch validity mask + if (state.child_states.empty()) { + ColumnScanState child_state; + state.child_states.push_back(std::move(child_state)); + } + auto scan_count = ColumnData::Fetch(state, row_id, result); + validity.Fetch(state.child_states[0], row_id, result); + return scan_count; +} + +void StandardColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count) { + ColumnData::Update(transaction, column_index, update_vector, row_ids, update_count); + validity.Update(transaction, column_index, update_vector, row_ids, update_count); +} + +void StandardColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { + if (depth >= column_path.size()) { + // update this column + ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + } else { + // update the child column (i.e. the validity column) + validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); + } +} + +unique_ptr StandardColumnData::GetUpdateStatistics() { + auto stats = updates ? updates->GetStatistics() : nullptr; + auto validity_stats = validity.GetUpdateStatistics(); + if (!stats && !validity_stats) { + return nullptr; + } + if (!stats) { + stats = BaseStatistics::CreateEmpty(type).ToUnique(); + } + if (validity_stats) { + stats->Merge(*validity_stats); + } + return stats; +} + +void StandardColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + // find the segment the row belongs to + if (state.child_states.empty()) { + auto child_state = make_uniq(); + state.child_states.push_back(std::move(child_state)); + } + validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + ColumnData::FetchRow(transaction, state, row_id, result, result_idx); +} + +void StandardColumnData::CommitDropColumn() { + ColumnData::CommitDropColumn(); + validity.CommitDropColumn(); +} + +struct StandardColumnCheckpointState : public ColumnCheckpointState { + StandardColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, + PartialBlockManager &partial_block_manager) + : ColumnCheckpointState(row_group, column_data, partial_block_manager) { + } + + unique_ptr validity_state; + +public: + unique_ptr GetStatistics() override { + D_ASSERT(global_stats); + return std::move(global_stats); + } + + void WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) override { + ColumnCheckpointState::WriteDataPointers(writer, serializer); + serializer.WriteObject(101, "validity", + [&](Serializer &serializer) { validity_state->WriteDataPointers(writer, serializer); }); + } +}; + +unique_ptr +StandardColumnData::CreateCheckpointState(RowGroup &row_group, PartialBlockManager &partial_block_manager) { + return make_uniq(row_group, *this, partial_block_manager); +} + +unique_ptr StandardColumnData::Checkpoint(RowGroup &row_group, + PartialBlockManager &partial_block_manager, + ColumnCheckpointInfo &checkpoint_info) { + auto validity_state = validity.Checkpoint(row_group, partial_block_manager, checkpoint_info); + auto base_state = ColumnData::Checkpoint(row_group, partial_block_manager, checkpoint_info); + auto &checkpoint_state = base_state->Cast(); + checkpoint_state.validity_state = std::move(validity_state); + return base_state; +} + +void StandardColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, + idx_t count, Vector &scan_vector) { + ColumnData::CheckpointScan(segment, state, row_group_start, count, scan_vector); + + idx_t offset_in_row_group = state.row_index - row_group_start; + validity.ScanCommittedRange(row_group_start, offset_in_row_group, count, scan_vector); +} + +void StandardColumnData::DeserializeColumn(Deserializer &deserializer) { + ColumnData::DeserializeColumn(deserializer); + deserializer.ReadObject(101, "validity", + [&](Deserializer &deserializer) { validity.DeserializeColumn(deserializer); }); +} + +void StandardColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, + vector &result) { + ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); + col_path.push_back(0); + validity.GetColumnSegmentInfo(row_group_index, std::move(col_path), result); +} + +void StandardColumnData::Verify(RowGroup &parent) { +#ifdef DEBUG + ColumnData::Verify(parent); + validity.Verify(parent); +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp b/src/duckdb/src/storage/table/struct_column_data.cpp new file mode 100644 index 00000000..cc9b6432 --- /dev/null +++ b/src/duckdb/src/storage/table/struct_column_data.cpp @@ -0,0 +1,307 @@ +#include "duckdb/storage/table/struct_column_data.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { + +StructColumnData::StructColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + idx_t start_row, LogicalType type_p, optional_ptr parent) + : ColumnData(block_manager, info, column_index, start_row, std::move(type_p), parent), + validity(block_manager, info, 0, start_row, *this) { + D_ASSERT(type.InternalType() == PhysicalType::STRUCT); + auto &child_types = StructType::GetChildTypes(type); + D_ASSERT(child_types.size() > 0); + if (type.id() != LogicalTypeId::UNION && StructType::IsUnnamed(type)) { + throw InvalidInputException("A table cannot be created from an unnamed struct"); + } + // the sub column index, starting at 1 (0 is the validity mask) + idx_t sub_column_index = 1; + for (auto &child_type : child_types) { + sub_columns.push_back( + ColumnData::CreateColumnUnique(block_manager, info, sub_column_index, start_row, child_type.second, this)); + sub_column_index++; + } +} + +void StructColumnData::SetStart(idx_t new_start) { + this->start = new_start; + for (auto &sub_column : sub_columns) { + sub_column->SetStart(new_start); + } + validity.SetStart(new_start); +} + +bool StructColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { + // table filters are not supported yet for struct columns + return false; +} + +idx_t StructColumnData::GetMaxEntry() { + return sub_columns[0]->GetMaxEntry(); +} + +void StructColumnData::InitializeScan(ColumnScanState &state) { + D_ASSERT(state.child_states.size() == sub_columns.size() + 1); + state.row_index = 0; + state.current = nullptr; + + // initialize the validity segment + validity.InitializeScan(state.child_states[0]); + + // initialize the sub-columns + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->InitializeScan(state.child_states[i + 1]); + } +} + +void StructColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { + D_ASSERT(state.child_states.size() == sub_columns.size() + 1); + state.row_index = row_idx; + state.current = nullptr; + + // initialize the validity segment + validity.InitializeScanWithOffset(state.child_states[0], row_idx); + + // initialize the sub-columns + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->InitializeScanWithOffset(state.child_states[i + 1], row_idx); + } +} + +idx_t StructColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result) { + auto scan_count = validity.Scan(transaction, vector_index, state.child_states[0], result); + auto &child_entries = StructVector::GetEntries(result); + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->Scan(transaction, vector_index, state.child_states[i + 1], *child_entries[i]); + } + return scan_count; +} + +idx_t StructColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates) { + auto scan_count = validity.ScanCommitted(vector_index, state.child_states[0], result, allow_updates); + auto &child_entries = StructVector::GetEntries(result); + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->ScanCommitted(vector_index, state.child_states[i + 1], *child_entries[i], allow_updates); + } + return scan_count; +} + +idx_t StructColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count) { + auto scan_count = validity.ScanCount(state.child_states[0], result, count); + auto &child_entries = StructVector::GetEntries(result); + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->ScanCount(state.child_states[i + 1], *child_entries[i], count); + } + return scan_count; +} + +void StructColumnData::Skip(ColumnScanState &state, idx_t count) { + validity.Skip(state.child_states[0], count); + + // skip inside the sub-columns + for (idx_t child_idx = 0; child_idx < sub_columns.size(); child_idx++) { + sub_columns[child_idx]->Skip(state.child_states[child_idx + 1], count); + } +} + +void StructColumnData::InitializeAppend(ColumnAppendState &state) { + ColumnAppendState validity_append; + validity.InitializeAppend(validity_append); + state.child_appends.push_back(std::move(validity_append)); + + for (auto &sub_column : sub_columns) { + ColumnAppendState child_append; + sub_column->InitializeAppend(child_append); + state.child_appends.push_back(std::move(child_append)); + } +} + +void StructColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { + vector.Flatten(count); + + // append the null values + validity.Append(stats, state.child_appends[0], vector, count); + + auto &child_entries = StructVector::GetEntries(vector); + for (idx_t i = 0; i < child_entries.size(); i++) { + sub_columns[i]->Append(StructStats::GetChildStats(stats, i), state.child_appends[i + 1], *child_entries[i], + count); + } + this->count += count; +} + +void StructColumnData::RevertAppend(row_t start_row) { + validity.RevertAppend(start_row); + for (auto &sub_column : sub_columns) { + sub_column->RevertAppend(start_row); + } + this->count = start_row - this->start; +} + +idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { + // fetch validity mask + auto &child_entries = StructVector::GetEntries(result); + // insert any child states that are required + for (idx_t i = state.child_states.size(); i < child_entries.size() + 1; i++) { + ColumnScanState child_state; + state.child_states.push_back(std::move(child_state)); + } + // fetch the validity state + idx_t scan_count = validity.Fetch(state.child_states[0], row_id, result); + // fetch the sub-column states + for (idx_t i = 0; i < child_entries.size(); i++) { + sub_columns[i]->Fetch(state.child_states[i + 1], row_id, *child_entries[i]); + } + return scan_count; +} + +void StructColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, + idx_t update_count) { + validity.Update(transaction, column_index, update_vector, row_ids, update_count); + auto &child_entries = StructVector::GetEntries(update_vector); + for (idx_t i = 0; i < child_entries.size(); i++) { + sub_columns[i]->Update(transaction, column_index, *child_entries[i], row_ids, update_count); + } +} + +void StructColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { + // we can never DIRECTLY update a struct column + if (depth >= column_path.size()) { + throw InternalException("Attempting to directly update a struct column - this should not be possible"); + } + auto update_column = column_path[depth]; + if (update_column == 0) { + // update the validity column + validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); + } else { + if (update_column > sub_columns.size()) { + throw InternalException("Update column_path out of range"); + } + sub_columns[update_column - 1]->UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, + depth + 1); + } +} + +unique_ptr StructColumnData::GetUpdateStatistics() { + // check if any child column has updates + auto stats = BaseStatistics::CreateEmpty(type); + auto validity_stats = validity.GetUpdateStatistics(); + if (validity_stats) { + stats.Merge(*validity_stats); + } + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto child_stats = sub_columns[i]->GetUpdateStatistics(); + if (child_stats) { + StructStats::SetChildStats(stats, i, std::move(child_stats)); + } + } + return stats.ToUnique(); +} + +void StructColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + // fetch validity mask + auto &child_entries = StructVector::GetEntries(result); + // insert any child states that are required + for (idx_t i = state.child_states.size(); i < child_entries.size() + 1; i++) { + auto child_state = make_uniq(); + state.child_states.push_back(std::move(child_state)); + } + // fetch the validity state + validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + // fetch the sub-column states + for (idx_t i = 0; i < child_entries.size(); i++) { + sub_columns[i]->FetchRow(transaction, *state.child_states[i + 1], row_id, *child_entries[i], result_idx); + } +} + +void StructColumnData::CommitDropColumn() { + validity.CommitDropColumn(); + for (auto &sub_column : sub_columns) { + sub_column->CommitDropColumn(); + } +} + +struct StructColumnCheckpointState : public ColumnCheckpointState { + StructColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, + PartialBlockManager &partial_block_manager) + : ColumnCheckpointState(row_group, column_data, partial_block_manager) { + global_stats = StructStats::CreateEmpty(column_data.type).ToUnique(); + } + + unique_ptr validity_state; + vector> child_states; + +public: + unique_ptr GetStatistics() override { + auto stats = StructStats::CreateEmpty(column_data.type); + for (idx_t i = 0; i < child_states.size(); i++) { + StructStats::SetChildStats(stats, i, child_states[i]->GetStatistics()); + } + return stats.ToUnique(); + } + + void WriteDataPointers(RowGroupWriter &writer, Serializer &serializer) override { + serializer.WriteObject(101, "validity", + [&](Serializer &serializer) { validity_state->WriteDataPointers(writer, serializer); }); + serializer.WriteList(102, "sub_columns", child_states.size(), [&](Serializer::List &list, idx_t i) { + auto &state = child_states[i]; + list.WriteObject([&](Serializer &serializer) { state->WriteDataPointers(writer, serializer); }); + }); + } +}; + +unique_ptr StructColumnData::CreateCheckpointState(RowGroup &row_group, + PartialBlockManager &partial_block_manager) { + return make_uniq(row_group, *this, partial_block_manager); +} + +unique_ptr StructColumnData::Checkpoint(RowGroup &row_group, + PartialBlockManager &partial_block_manager, + ColumnCheckpointInfo &checkpoint_info) { + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); + checkpoint_state->validity_state = validity.Checkpoint(row_group, partial_block_manager, checkpoint_info); + for (auto &sub_column : sub_columns) { + checkpoint_state->child_states.push_back( + sub_column->Checkpoint(row_group, partial_block_manager, checkpoint_info)); + } + return std::move(checkpoint_state); +} + +void StructColumnData::DeserializeColumn(Deserializer &deserializer) { + deserializer.ReadObject(101, "validity", + [&](Deserializer &deserializer) { validity.DeserializeColumn(deserializer); }); + + deserializer.ReadList(102, "sub_columns", [&](Deserializer::List &list, idx_t i) { + list.ReadObject([&](Deserializer &item) { sub_columns[i]->DeserializeColumn(item); }); + }); + + this->count = validity.count; +} + +void StructColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, + vector &result) { + col_path.push_back(0); + validity.GetColumnSegmentInfo(row_group_index, col_path, result); + for (idx_t i = 0; i < sub_columns.size(); i++) { + col_path.back() = i + 1; + sub_columns[i]->GetColumnSegmentInfo(row_group_index, col_path, result); + } +} + +void StructColumnData::Verify(RowGroup &parent) { +#ifdef DEBUG + ColumnData::Verify(parent); + validity.Verify(parent); + for (auto &sub_column : sub_columns) { + sub_column->Verify(parent); + } +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/table_statistics.cpp b/src/duckdb/src/storage/table/table_statistics.cpp new file mode 100644 index 00000000..4df780f5 --- /dev/null +++ b/src/duckdb/src/storage/table/table_statistics.cpp @@ -0,0 +1,133 @@ +#include "duckdb/storage/table/table_statistics.hpp" +#include "duckdb/storage/table/persistent_table_data.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +void TableStatistics::Initialize(const vector &types, PersistentTableData &data) { + D_ASSERT(Empty()); + + column_stats = std::move(data.table_stats.column_stats); + if (column_stats.size() != types.size()) { // LCOV_EXCL_START + throw IOException("Table statistics column count is not aligned with table column count. Corrupt file?"); + } // LCOV_EXCL_STOP +} + +void TableStatistics::InitializeEmpty(const vector &types) { + D_ASSERT(Empty()); + + for (auto &type : types) { + column_stats.push_back(ColumnStatistics::CreateEmptyStats(type)); + } +} + +void TableStatistics::InitializeAddColumn(TableStatistics &parent, const LogicalType &new_column_type) { + D_ASSERT(Empty()); + + lock_guard stats_lock(parent.stats_lock); + for (idx_t i = 0; i < parent.column_stats.size(); i++) { + column_stats.push_back(parent.column_stats[i]); + } + column_stats.push_back(ColumnStatistics::CreateEmptyStats(new_column_type)); +} + +void TableStatistics::InitializeRemoveColumn(TableStatistics &parent, idx_t removed_column) { + D_ASSERT(Empty()); + + lock_guard stats_lock(parent.stats_lock); + for (idx_t i = 0; i < parent.column_stats.size(); i++) { + if (i != removed_column) { + column_stats.push_back(parent.column_stats[i]); + } + } +} + +void TableStatistics::InitializeAlterType(TableStatistics &parent, idx_t changed_idx, const LogicalType &new_type) { + D_ASSERT(Empty()); + + lock_guard stats_lock(parent.stats_lock); + for (idx_t i = 0; i < parent.column_stats.size(); i++) { + if (i == changed_idx) { + column_stats.push_back(ColumnStatistics::CreateEmptyStats(new_type)); + } else { + column_stats.push_back(parent.column_stats[i]); + } + } +} + +void TableStatistics::InitializeAddConstraint(TableStatistics &parent) { + D_ASSERT(Empty()); + + lock_guard stats_lock(parent.stats_lock); + for (idx_t i = 0; i < parent.column_stats.size(); i++) { + column_stats.push_back(parent.column_stats[i]); + } +} + +void TableStatistics::MergeStats(TableStatistics &other) { + auto l = GetLock(); + D_ASSERT(column_stats.size() == other.column_stats.size()); + for (idx_t i = 0; i < column_stats.size(); i++) { + column_stats[i]->Merge(*other.column_stats[i]); + } +} + +void TableStatistics::MergeStats(idx_t i, BaseStatistics &stats) { + auto l = GetLock(); + MergeStats(*l, i, stats); +} + +void TableStatistics::MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats) { + column_stats[i]->Statistics().Merge(stats); +} + +ColumnStatistics &TableStatistics::GetStats(idx_t i) { + return *column_stats[i]; +} + +unique_ptr TableStatistics::CopyStats(idx_t i) { + lock_guard l(stats_lock); + auto result = column_stats[i]->Statistics().Copy(); + if (column_stats[i]->HasDistinctStats()) { + result.SetDistinctCount(column_stats[i]->DistinctStats().GetCount()); + } + return result.ToUnique(); +} + +void TableStatistics::CopyStats(TableStatistics &other) { + for (auto &stats : column_stats) { + other.column_stats.push_back(stats->Copy()); + } +} + +void TableStatistics::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "column_stats", column_stats); +} + +void TableStatistics::Deserialize(Deserializer &deserializer, ColumnList &columns) { + auto physical_columns = columns.Physical(); + + auto iter = physical_columns.begin(); + deserializer.ReadList(100, "column_stats", [&](Deserializer::List &list, idx_t i) { + auto &col = *iter; + iter.operator++(); + + auto type = col.GetType(); + deserializer.Set(type); + + column_stats.push_back(list.ReadElement>()); + + deserializer.Unset(); + }); +} + +unique_ptr TableStatistics::GetLock() { + return make_uniq(stats_lock); +} + +bool TableStatistics::Empty() { + return column_stats.empty(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/update_segment.cpp b/src/duckdb/src/storage/table/update_segment.cpp new file mode 100644 index 00000000..3e1cab7d --- /dev/null +++ b/src/duckdb/src/storage/table/update_segment.cpp @@ -0,0 +1,1221 @@ +#include "duckdb/storage/table/update_segment.hpp" + +#include "duckdb/storage/statistics/distinct_statistics.hpp" + +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/transaction/update_info.hpp" +#include "duckdb/common/printer.hpp" + +#include + +namespace duckdb { + +static UpdateSegment::initialize_update_function_t GetInitializeUpdateFunction(PhysicalType type); +static UpdateSegment::fetch_update_function_t GetFetchUpdateFunction(PhysicalType type); +static UpdateSegment::fetch_committed_function_t GetFetchCommittedFunction(PhysicalType type); +static UpdateSegment::fetch_committed_range_function_t GetFetchCommittedRangeFunction(PhysicalType type); + +static UpdateSegment::merge_update_function_t GetMergeUpdateFunction(PhysicalType type); +static UpdateSegment::rollback_update_function_t GetRollbackUpdateFunction(PhysicalType type); +static UpdateSegment::statistics_update_function_t GetStatisticsUpdateFunction(PhysicalType type); +static UpdateSegment::fetch_row_function_t GetFetchRowFunction(PhysicalType type); + +UpdateSegment::UpdateSegment(ColumnData &column_data) + : column_data(column_data), stats(column_data.type), heap(BufferAllocator::Get(column_data.GetDatabase())) { + auto physical_type = column_data.type.InternalType(); + + this->type_size = GetTypeIdSize(physical_type); + + this->initialize_update_function = GetInitializeUpdateFunction(physical_type); + this->fetch_update_function = GetFetchUpdateFunction(physical_type); + this->fetch_committed_function = GetFetchCommittedFunction(physical_type); + this->fetch_committed_range = GetFetchCommittedRangeFunction(physical_type); + this->fetch_row_function = GetFetchRowFunction(physical_type); + this->merge_update_function = GetMergeUpdateFunction(physical_type); + this->rollback_update_function = GetRollbackUpdateFunction(physical_type); + this->statistics_update_function = GetStatisticsUpdateFunction(physical_type); +} + +UpdateSegment::~UpdateSegment() { +} + +//===--------------------------------------------------------------------===// +// Update Info Helpers +//===--------------------------------------------------------------------===// +Value UpdateInfo::GetValue(idx_t index) { + auto &type = segment->column_data.type; + + switch (type.id()) { + case LogicalTypeId::VALIDITY: + return Value::BOOLEAN(reinterpret_cast(tuple_data)[index]); + case LogicalTypeId::INTEGER: + return Value::INTEGER(reinterpret_cast(tuple_data)[index]); + default: + throw NotImplementedException("Unimplemented type for UpdateInfo::GetValue"); + } +} + +void UpdateInfo::Print() { + Printer::Print(ToString()); +} + +string UpdateInfo::ToString() { + auto &type = segment->column_data.type; + string result = "Update Info [" + type.ToString() + ", Count: " + to_string(N) + + ", Transaction Id: " + to_string(version_number) + "]\n"; + for (idx_t i = 0; i < N; i++) { + result += to_string(tuples[i]) + ": " + GetValue(i).ToString() + "\n"; + } + if (next) { + result += "\nChild Segment: " + next->ToString(); + } + return result; +} + +void UpdateInfo::Verify() { +#ifdef DEBUG + for (idx_t i = 1; i < N; i++) { + D_ASSERT(tuples[i] > tuples[i - 1] && tuples[i] < STANDARD_VECTOR_SIZE); + } +#endif +} + +//===--------------------------------------------------------------------===// +// Update Fetch +//===--------------------------------------------------------------------===// +static void MergeValidityInfo(UpdateInfo *current, ValidityMask &result_mask) { + auto info_data = reinterpret_cast(current->tuple_data); + for (idx_t i = 0; i < current->N; i++) { + result_mask.Set(current->tuples[i], info_data[i]); + } +} + +static void UpdateMergeValidity(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, + Vector &result) { + auto &result_mask = FlatVector::Validity(result); + UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, + [&](UpdateInfo *current) { MergeValidityInfo(current, result_mask); }); +} + +template +static void MergeUpdateInfo(UpdateInfo *current, T *result_data) { + auto info_data = reinterpret_cast(current->tuple_data); + if (current->N == STANDARD_VECTOR_SIZE) { + // special case: update touches ALL tuples of this vector + // in this case we can just memcpy the data + // since the layout of the update info is guaranteed to be [0, 1, 2, 3, ...] + memcpy(result_data, info_data, sizeof(T) * current->N); + } else { + for (idx_t i = 0; i < current->N; i++) { + result_data[current->tuples[i]] = info_data[i]; + } + } +} + +template +static void UpdateMergeFetch(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, Vector &result) { + auto result_data = FlatVector::GetData(result); + UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, + [&](UpdateInfo *current) { MergeUpdateInfo(current, result_data); }); +} + +static UpdateSegment::fetch_update_function_t GetFetchUpdateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return UpdateMergeValidity; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return UpdateMergeFetch; + case PhysicalType::INT16: + return UpdateMergeFetch; + case PhysicalType::INT32: + return UpdateMergeFetch; + case PhysicalType::INT64: + return UpdateMergeFetch; + case PhysicalType::UINT8: + return UpdateMergeFetch; + case PhysicalType::UINT16: + return UpdateMergeFetch; + case PhysicalType::UINT32: + return UpdateMergeFetch; + case PhysicalType::UINT64: + return UpdateMergeFetch; + case PhysicalType::INT128: + return UpdateMergeFetch; + case PhysicalType::FLOAT: + return UpdateMergeFetch; + case PhysicalType::DOUBLE: + return UpdateMergeFetch; + case PhysicalType::INTERVAL: + return UpdateMergeFetch; + case PhysicalType::VARCHAR: + return UpdateMergeFetch; + default: + throw NotImplementedException("Unimplemented type for update segment"); + } +} + +void UpdateSegment::FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result) { + auto lock_handle = lock.GetSharedLock(); + if (!root) { + return; + } + if (!root->info[vector_index]) { + return; + } + // FIXME: normalify if this is not the case... need to pass in count? + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + + fetch_update_function(transaction.start_time, transaction.transaction_id, root->info[vector_index]->info.get(), + result); +} + +//===--------------------------------------------------------------------===// +// Fetch Committed +//===--------------------------------------------------------------------===// +static void FetchCommittedValidity(UpdateInfo *info, Vector &result) { + auto &result_mask = FlatVector::Validity(result); + MergeValidityInfo(info, result_mask); +} + +template +static void TemplatedFetchCommitted(UpdateInfo *info, Vector &result) { + auto result_data = FlatVector::GetData(result); + MergeUpdateInfo(info, result_data); +} + +static UpdateSegment::fetch_committed_function_t GetFetchCommittedFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return FetchCommittedValidity; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TemplatedFetchCommitted; + case PhysicalType::INT16: + return TemplatedFetchCommitted; + case PhysicalType::INT32: + return TemplatedFetchCommitted; + case PhysicalType::INT64: + return TemplatedFetchCommitted; + case PhysicalType::UINT8: + return TemplatedFetchCommitted; + case PhysicalType::UINT16: + return TemplatedFetchCommitted; + case PhysicalType::UINT32: + return TemplatedFetchCommitted; + case PhysicalType::UINT64: + return TemplatedFetchCommitted; + case PhysicalType::INT128: + return TemplatedFetchCommitted; + case PhysicalType::FLOAT: + return TemplatedFetchCommitted; + case PhysicalType::DOUBLE: + return TemplatedFetchCommitted; + case PhysicalType::INTERVAL: + return TemplatedFetchCommitted; + case PhysicalType::VARCHAR: + return TemplatedFetchCommitted; + default: + throw NotImplementedException("Unimplemented type for update segment"); + } +} + +void UpdateSegment::FetchCommitted(idx_t vector_index, Vector &result) { + auto lock_handle = lock.GetSharedLock(); + + if (!root) { + return; + } + if (!root->info[vector_index]) { + return; + } + // FIXME: normalify if this is not the case... need to pass in count? + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + + fetch_committed_function(root->info[vector_index]->info.get(), result); +} + +//===--------------------------------------------------------------------===// +// Fetch Range +//===--------------------------------------------------------------------===// +static void MergeUpdateInfoRangeValidity(UpdateInfo *current, idx_t start, idx_t end, idx_t result_offset, + ValidityMask &result_mask) { + auto info_data = reinterpret_cast(current->tuple_data); + for (idx_t i = 0; i < current->N; i++) { + auto tuple_idx = current->tuples[i]; + if (tuple_idx < start) { + continue; + } else if (tuple_idx >= end) { + break; + } + auto result_idx = result_offset + tuple_idx - start; + result_mask.Set(result_idx, info_data[i]); + } +} + +static void FetchCommittedRangeValidity(UpdateInfo *info, idx_t start, idx_t end, idx_t result_offset, Vector &result) { + auto &result_mask = FlatVector::Validity(result); + MergeUpdateInfoRangeValidity(info, start, end, result_offset, result_mask); +} + +template +static void MergeUpdateInfoRange(UpdateInfo *current, idx_t start, idx_t end, idx_t result_offset, T *result_data) { + auto info_data = reinterpret_cast(current->tuple_data); + for (idx_t i = 0; i < current->N; i++) { + auto tuple_idx = current->tuples[i]; + if (tuple_idx < start) { + continue; + } else if (tuple_idx >= end) { + break; + } + auto result_idx = result_offset + tuple_idx - start; + result_data[result_idx] = info_data[i]; + } +} + +template +static void TemplatedFetchCommittedRange(UpdateInfo *info, idx_t start, idx_t end, idx_t result_offset, + Vector &result) { + auto result_data = FlatVector::GetData(result); + MergeUpdateInfoRange(info, start, end, result_offset, result_data); +} + +static UpdateSegment::fetch_committed_range_function_t GetFetchCommittedRangeFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return FetchCommittedRangeValidity; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TemplatedFetchCommittedRange; + case PhysicalType::INT16: + return TemplatedFetchCommittedRange; + case PhysicalType::INT32: + return TemplatedFetchCommittedRange; + case PhysicalType::INT64: + return TemplatedFetchCommittedRange; + case PhysicalType::UINT8: + return TemplatedFetchCommittedRange; + case PhysicalType::UINT16: + return TemplatedFetchCommittedRange; + case PhysicalType::UINT32: + return TemplatedFetchCommittedRange; + case PhysicalType::UINT64: + return TemplatedFetchCommittedRange; + case PhysicalType::INT128: + return TemplatedFetchCommittedRange; + case PhysicalType::FLOAT: + return TemplatedFetchCommittedRange; + case PhysicalType::DOUBLE: + return TemplatedFetchCommittedRange; + case PhysicalType::INTERVAL: + return TemplatedFetchCommittedRange; + case PhysicalType::VARCHAR: + return TemplatedFetchCommittedRange; + default: + throw NotImplementedException("Unimplemented type for update segment"); + } +} + +void UpdateSegment::FetchCommittedRange(idx_t start_row, idx_t count, Vector &result) { + D_ASSERT(count > 0); + if (!root) { + return; + } + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + + idx_t end_row = start_row + count; + idx_t start_vector = start_row / STANDARD_VECTOR_SIZE; + idx_t end_vector = (end_row - 1) / STANDARD_VECTOR_SIZE; + D_ASSERT(start_vector <= end_vector); + D_ASSERT(end_vector < Storage::ROW_GROUP_VECTOR_COUNT); + + for (idx_t vector_idx = start_vector; vector_idx <= end_vector; vector_idx++) { + if (!root->info[vector_idx]) { + continue; + } + idx_t start_in_vector = vector_idx == start_vector ? start_row - start_vector * STANDARD_VECTOR_SIZE : 0; + idx_t end_in_vector = + vector_idx == end_vector ? end_row - end_vector * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; + D_ASSERT(start_in_vector < end_in_vector); + D_ASSERT(end_in_vector > 0 && end_in_vector <= STANDARD_VECTOR_SIZE); + idx_t result_offset = ((vector_idx * STANDARD_VECTOR_SIZE) + start_in_vector) - start_row; + fetch_committed_range(root->info[vector_idx]->info.get(), start_in_vector, end_in_vector, result_offset, + result); + } +} + +//===--------------------------------------------------------------------===// +// Fetch Row +//===--------------------------------------------------------------------===// +static void FetchRowValidity(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, idx_t row_idx, + Vector &result, idx_t result_idx) { + auto &result_mask = FlatVector::Validity(result); + UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, [&](UpdateInfo *current) { + auto info_data = reinterpret_cast(current->tuple_data); + // FIXME: we could do a binary search in here + for (idx_t i = 0; i < current->N; i++) { + if (current->tuples[i] == row_idx) { + result_mask.Set(result_idx, info_data[i]); + break; + } else if (current->tuples[i] > row_idx) { + break; + } + } + }); +} + +template +static void TemplatedFetchRow(transaction_t start_time, transaction_t transaction_id, UpdateInfo *info, idx_t row_idx, + Vector &result, idx_t result_idx) { + auto result_data = FlatVector::GetData(result); + UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, [&](UpdateInfo *current) { + auto info_data = (T *)current->tuple_data; + // FIXME: we could do a binary search in here + for (idx_t i = 0; i < current->N; i++) { + if (current->tuples[i] == row_idx) { + result_data[result_idx] = info_data[i]; + break; + } else if (current->tuples[i] > row_idx) { + break; + } + } + }); +} + +static UpdateSegment::fetch_row_function_t GetFetchRowFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return FetchRowValidity; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TemplatedFetchRow; + case PhysicalType::INT16: + return TemplatedFetchRow; + case PhysicalType::INT32: + return TemplatedFetchRow; + case PhysicalType::INT64: + return TemplatedFetchRow; + case PhysicalType::UINT8: + return TemplatedFetchRow; + case PhysicalType::UINT16: + return TemplatedFetchRow; + case PhysicalType::UINT32: + return TemplatedFetchRow; + case PhysicalType::UINT64: + return TemplatedFetchRow; + case PhysicalType::INT128: + return TemplatedFetchRow; + case PhysicalType::FLOAT: + return TemplatedFetchRow; + case PhysicalType::DOUBLE: + return TemplatedFetchRow; + case PhysicalType::INTERVAL: + return TemplatedFetchRow; + case PhysicalType::VARCHAR: + return TemplatedFetchRow; + default: + throw NotImplementedException("Unimplemented type for update segment fetch row"); + } +} + +void UpdateSegment::FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx) { + if (!root) { + return; + } + idx_t vector_index = (row_id - column_data.start) / STANDARD_VECTOR_SIZE; + if (!root->info[vector_index]) { + return; + } + idx_t row_in_vector = (row_id - column_data.start) - vector_index * STANDARD_VECTOR_SIZE; + fetch_row_function(transaction.start_time, transaction.transaction_id, root->info[vector_index]->info.get(), + row_in_vector, result, result_idx); +} + +//===--------------------------------------------------------------------===// +// Rollback update +//===--------------------------------------------------------------------===// +template +static void RollbackUpdate(UpdateInfo &base_info, UpdateInfo &rollback_info) { + auto base_data = (T *)base_info.tuple_data; + auto rollback_data = (T *)rollback_info.tuple_data; + idx_t base_offset = 0; + for (idx_t i = 0; i < rollback_info.N; i++) { + auto id = rollback_info.tuples[i]; + while (base_info.tuples[base_offset] < id) { + base_offset++; + D_ASSERT(base_offset < base_info.N); + } + base_data[base_offset] = rollback_data[i]; + } +} + +static UpdateSegment::rollback_update_function_t GetRollbackUpdateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return RollbackUpdate; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return RollbackUpdate; + case PhysicalType::INT16: + return RollbackUpdate; + case PhysicalType::INT32: + return RollbackUpdate; + case PhysicalType::INT64: + return RollbackUpdate; + case PhysicalType::UINT8: + return RollbackUpdate; + case PhysicalType::UINT16: + return RollbackUpdate; + case PhysicalType::UINT32: + return RollbackUpdate; + case PhysicalType::UINT64: + return RollbackUpdate; + case PhysicalType::INT128: + return RollbackUpdate; + case PhysicalType::FLOAT: + return RollbackUpdate; + case PhysicalType::DOUBLE: + return RollbackUpdate; + case PhysicalType::INTERVAL: + return RollbackUpdate; + case PhysicalType::VARCHAR: + return RollbackUpdate; + default: + throw NotImplementedException("Unimplemented type for uncompressed segment"); + } +} + +void UpdateSegment::RollbackUpdate(UpdateInfo &info) { + // obtain an exclusive lock + auto lock_handle = lock.GetExclusiveLock(); + + // move the data from the UpdateInfo back into the base info + D_ASSERT(root->info[info.vector_index]); + rollback_update_function(*root->info[info.vector_index]->info, info); + + // clean up the update chain + CleanupUpdateInternal(*lock_handle, info); +} + +//===--------------------------------------------------------------------===// +// Cleanup Update +//===--------------------------------------------------------------------===// +void UpdateSegment::CleanupUpdateInternal(const StorageLockKey &lock, UpdateInfo &info) { + D_ASSERT(info.prev); + auto prev = info.prev; + prev->next = info.next; + if (prev->next) { + prev->next->prev = prev; + } +} + +void UpdateSegment::CleanupUpdate(UpdateInfo &info) { + // obtain an exclusive lock + auto lock_handle = lock.GetExclusiveLock(); + CleanupUpdateInternal(*lock_handle, info); +} + +//===--------------------------------------------------------------------===// +// Check for conflicts in update +//===--------------------------------------------------------------------===// +static void CheckForConflicts(UpdateInfo *info, TransactionData transaction, row_t *ids, const SelectionVector &sel, + idx_t count, row_t offset, UpdateInfo *&node) { + if (!info) { + return; + } + if (info->version_number == transaction.transaction_id) { + // this UpdateInfo belongs to the current transaction, set it in the node + node = info; + } else if (info->version_number > transaction.start_time) { + // potential conflict, check that tuple ids do not conflict + // as both ids and info->tuples are sorted, this is similar to a merge join + idx_t i = 0, j = 0; + while (true) { + auto id = ids[sel.get_index(i)] - offset; + if (id == info->tuples[j]) { + throw TransactionException("Conflict on update!"); + } else if (id < info->tuples[j]) { + // id < the current tuple in info, move to next id + i++; + if (i == count) { + break; + } + } else { + // id > the current tuple, move to next tuple in info + j++; + if (j == info->N) { + break; + } + } + } + } + CheckForConflicts(info->next, transaction, ids, sel, count, offset, node); +} + +//===--------------------------------------------------------------------===// +// Initialize update info +//===--------------------------------------------------------------------===// +void UpdateSegment::InitializeUpdateInfo(UpdateInfo &info, row_t *ids, const SelectionVector &sel, idx_t count, + idx_t vector_index, idx_t vector_offset) { + info.segment = this; + info.vector_index = vector_index; + info.prev = nullptr; + info.next = nullptr; + + // set up the tuple ids + info.N = count; + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + auto id = ids[idx]; + D_ASSERT(idx_t(id) >= vector_offset && idx_t(id) < vector_offset + STANDARD_VECTOR_SIZE); + info.tuples[i] = id - vector_offset; + }; +} + +static void InitializeUpdateValidity(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, + const SelectionVector &sel) { + auto &update_mask = FlatVector::Validity(update); + auto tuple_data = reinterpret_cast(update_info->tuple_data); + + if (!update_mask.AllValid()) { + for (idx_t i = 0; i < update_info->N; i++) { + auto idx = sel.get_index(i); + tuple_data[i] = update_mask.RowIsValidUnsafe(idx); + } + } else { + for (idx_t i = 0; i < update_info->N; i++) { + tuple_data[i] = true; + } + } + + auto &base_mask = FlatVector::Validity(base_data); + auto base_tuple_data = reinterpret_cast(base_info->tuple_data); + if (!base_mask.AllValid()) { + for (idx_t i = 0; i < base_info->N; i++) { + base_tuple_data[i] = base_mask.RowIsValidUnsafe(base_info->tuples[i]); + } + } else { + for (idx_t i = 0; i < base_info->N; i++) { + base_tuple_data[i] = true; + } + } +} + +struct UpdateSelectElement { + template + static T Operation(UpdateSegment *segment, T element) { + return element; + } +}; + +template <> +string_t UpdateSelectElement::Operation(UpdateSegment *segment, string_t element) { + return element.IsInlined() ? element : segment->GetStringHeap().AddBlob(element); +} + +template +static void InitializeUpdateData(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, + const SelectionVector &sel) { + auto update_data = FlatVector::GetData(update); + auto tuple_data = (T *)update_info->tuple_data; + + for (idx_t i = 0; i < update_info->N; i++) { + auto idx = sel.get_index(i); + tuple_data[i] = update_data[idx]; + } + + auto base_array_data = FlatVector::GetData(base_data); + auto &base_validity = FlatVector::Validity(base_data); + auto base_tuple_data = (T *)base_info->tuple_data; + for (idx_t i = 0; i < base_info->N; i++) { + auto base_idx = base_info->tuples[i]; + if (!base_validity.RowIsValid(base_idx)) { + continue; + } + base_tuple_data[i] = UpdateSelectElement::Operation(base_info->segment, base_array_data[base_idx]); + } +} + +static UpdateSegment::initialize_update_function_t GetInitializeUpdateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return InitializeUpdateValidity; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return InitializeUpdateData; + case PhysicalType::INT16: + return InitializeUpdateData; + case PhysicalType::INT32: + return InitializeUpdateData; + case PhysicalType::INT64: + return InitializeUpdateData; + case PhysicalType::UINT8: + return InitializeUpdateData; + case PhysicalType::UINT16: + return InitializeUpdateData; + case PhysicalType::UINT32: + return InitializeUpdateData; + case PhysicalType::UINT64: + return InitializeUpdateData; + case PhysicalType::INT128: + return InitializeUpdateData; + case PhysicalType::FLOAT: + return InitializeUpdateData; + case PhysicalType::DOUBLE: + return InitializeUpdateData; + case PhysicalType::INTERVAL: + return InitializeUpdateData; + case PhysicalType::VARCHAR: + return InitializeUpdateData; + default: + throw NotImplementedException("Unimplemented type for update segment"); + } +} + +//===--------------------------------------------------------------------===// +// Merge update info +//===--------------------------------------------------------------------===// +template +static idx_t MergeLoop(row_t a[], sel_t b[], idx_t acount, idx_t bcount, idx_t aoffset, F1 merge, F2 pick_a, F3 pick_b, + const SelectionVector &asel) { + idx_t aidx = 0, bidx = 0; + idx_t count = 0; + while (aidx < acount && bidx < bcount) { + auto a_index = asel.get_index(aidx); + auto a_id = a[a_index] - aoffset; + auto b_id = b[bidx]; + if (a_id == b_id) { + merge(a_id, a_index, bidx, count); + aidx++; + bidx++; + count++; + } else if (a_id < b_id) { + pick_a(a_id, a_index, count); + aidx++; + count++; + } else { + pick_b(b_id, bidx, count); + bidx++; + count++; + } + } + for (; aidx < acount; aidx++) { + auto a_index = asel.get_index(aidx); + pick_a(a[a_index] - aoffset, a_index, count); + count++; + } + for (; bidx < bcount; bidx++) { + pick_b(b[bidx], bidx, count); + count++; + } + return count; +} + +struct ExtractStandardEntry { + template + static T Extract(V *data, idx_t entry) { + return data[entry]; + } +}; + +struct ExtractValidityEntry { + template + static T Extract(V *data, idx_t entry) { + return data->RowIsValid(entry); + } +}; + +template +static void MergeUpdateLoopInternal(UpdateInfo *base_info, V *base_table_data, UpdateInfo *update_info, + V *update_vector_data, row_t *ids, idx_t count, const SelectionVector &sel) { + auto base_id = base_info->segment->column_data.start + base_info->vector_index * STANDARD_VECTOR_SIZE; +#ifdef DEBUG + // all of these should be sorted, otherwise the below algorithm does not work + for (idx_t i = 1; i < count; i++) { + auto prev_idx = sel.get_index(i - 1); + auto idx = sel.get_index(i); + D_ASSERT(ids[idx] > ids[prev_idx] && ids[idx] >= row_t(base_id) && + ids[idx] < row_t(base_id + STANDARD_VECTOR_SIZE)); + } +#endif + + // we have a new batch of updates (update, ids, count) + // we already have existing updates (base_info) + // and potentially, this transaction already has updates present (update_info) + // we need to merge these all together so that the latest updates get merged into base_info + // and the "old" values (fetched from EITHER base_info OR from base_data) get placed into update_info + auto base_info_data = (T *)base_info->tuple_data; + auto update_info_data = (T *)update_info->tuple_data; + + // we first do the merging of the old values + // what we are trying to do here is update the "update_info" of this transaction with all the old data we require + // this means we need to merge (1) any previously updated values (stored in update_info->tuples) + // together with (2) + // to simplify this, we create new arrays here + // we memcpy these over afterwards + T result_values[STANDARD_VECTOR_SIZE]; + sel_t result_ids[STANDARD_VECTOR_SIZE]; + + idx_t base_info_offset = 0; + idx_t update_info_offset = 0; + idx_t result_offset = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = sel.get_index(i); + // we have to merge the info for "ids[i]" + auto update_id = ids[idx] - base_id; + + while (update_info_offset < update_info->N && update_info->tuples[update_info_offset] < update_id) { + // old id comes before the current id: write it + result_values[result_offset] = update_info_data[update_info_offset]; + result_ids[result_offset++] = update_info->tuples[update_info_offset]; + update_info_offset++; + } + // write the new id + if (update_info_offset < update_info->N && update_info->tuples[update_info_offset] == update_id) { + // we have an id that is equivalent in the current update info: write the update info + result_values[result_offset] = update_info_data[update_info_offset]; + result_ids[result_offset++] = update_info->tuples[update_info_offset]; + update_info_offset++; + continue; + } + + /// now check if we have the current update_id in the base_info, or if we should fetch it from the base data + while (base_info_offset < base_info->N && base_info->tuples[base_info_offset] < update_id) { + base_info_offset++; + } + if (base_info_offset < base_info->N && base_info->tuples[base_info_offset] == update_id) { + // it is! we have to move the tuple from base_info->ids[base_info_offset] to update_info + result_values[result_offset] = base_info_data[base_info_offset]; + } else { + // it is not! we have to move base_table_data[update_id] to update_info + result_values[result_offset] = UpdateSelectElement::Operation( + base_info->segment, OP::template Extract(base_table_data, update_id)); + } + result_ids[result_offset++] = update_id; + } + // write any remaining entries from the old updates + while (update_info_offset < update_info->N) { + result_values[result_offset] = update_info_data[update_info_offset]; + result_ids[result_offset++] = update_info->tuples[update_info_offset]; + update_info_offset++; + } + // now copy them back + update_info->N = result_offset; + memcpy(update_info_data, result_values, result_offset * sizeof(T)); + memcpy(update_info->tuples, result_ids, result_offset * sizeof(sel_t)); + + // now we merge the new values into the base_info + result_offset = 0; + auto pick_new = [&](idx_t id, idx_t aidx, idx_t count) { + result_values[result_offset] = OP::template Extract(update_vector_data, aidx); + result_ids[result_offset] = id; + result_offset++; + }; + auto pick_old = [&](idx_t id, idx_t bidx, idx_t count) { + result_values[result_offset] = base_info_data[bidx]; + result_ids[result_offset] = id; + result_offset++; + }; + // now we perform a merge of the new ids with the old ids + auto merge = [&](idx_t id, idx_t aidx, idx_t bidx, idx_t count) { + pick_new(id, aidx, count); + }; + MergeLoop(ids, base_info->tuples, count, base_info->N, base_id, merge, pick_new, pick_old, sel); + + base_info->N = result_offset; + memcpy(base_info_data, result_values, result_offset * sizeof(T)); + memcpy(base_info->tuples, result_ids, result_offset * sizeof(sel_t)); +} + +static void MergeValidityLoop(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, + row_t *ids, idx_t count, const SelectionVector &sel) { + auto &base_validity = FlatVector::Validity(base_data); + auto &update_validity = FlatVector::Validity(update); + MergeUpdateLoopInternal(base_info, &base_validity, update_info, + &update_validity, ids, count, sel); +} + +template +static void MergeUpdateLoop(UpdateInfo *base_info, Vector &base_data, UpdateInfo *update_info, Vector &update, + row_t *ids, idx_t count, const SelectionVector &sel) { + auto base_table_data = FlatVector::GetData(base_data); + auto update_vector_data = FlatVector::GetData(update); + MergeUpdateLoopInternal(base_info, base_table_data, update_info, update_vector_data, ids, count, sel); +} + +static UpdateSegment::merge_update_function_t GetMergeUpdateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return MergeValidityLoop; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return MergeUpdateLoop; + case PhysicalType::INT16: + return MergeUpdateLoop; + case PhysicalType::INT32: + return MergeUpdateLoop; + case PhysicalType::INT64: + return MergeUpdateLoop; + case PhysicalType::UINT8: + return MergeUpdateLoop; + case PhysicalType::UINT16: + return MergeUpdateLoop; + case PhysicalType::UINT32: + return MergeUpdateLoop; + case PhysicalType::UINT64: + return MergeUpdateLoop; + case PhysicalType::INT128: + return MergeUpdateLoop; + case PhysicalType::FLOAT: + return MergeUpdateLoop; + case PhysicalType::DOUBLE: + return MergeUpdateLoop; + case PhysicalType::INTERVAL: + return MergeUpdateLoop; + case PhysicalType::VARCHAR: + return MergeUpdateLoop; + default: + throw NotImplementedException("Unimplemented type for uncompressed segment"); + } +} + +//===--------------------------------------------------------------------===// +// Update statistics +//===--------------------------------------------------------------------===// +unique_ptr UpdateSegment::GetStatistics() { + lock_guard stats_guard(stats_lock); + return stats.statistics.ToUnique(); +} + +idx_t UpdateValidityStatistics(UpdateSegment *segment, SegmentStatistics &stats, Vector &update, idx_t count, + SelectionVector &sel) { + auto &mask = FlatVector::Validity(update); + auto &validity = stats.statistics; + if (!mask.AllValid() && !validity.CanHaveNull()) { + for (idx_t i = 0; i < count; i++) { + if (!mask.RowIsValid(i)) { + validity.SetHasNull(); + break; + } + } + } + sel.Initialize(nullptr); + return count; +} + +template +idx_t TemplatedUpdateNumericStatistics(UpdateSegment *segment, SegmentStatistics &stats, Vector &update, idx_t count, + SelectionVector &sel) { + auto update_data = FlatVector::GetData(update); + auto &mask = FlatVector::Validity(update); + + if (mask.AllValid()) { + for (idx_t i = 0; i < count; i++) { + NumericStats::Update(stats.statistics, update_data[i]); + } + sel.Initialize(nullptr); + return count; + } else { + idx_t not_null_count = 0; + sel.Initialize(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < count; i++) { + if (mask.RowIsValid(i)) { + sel.set_index(not_null_count++, i); + NumericStats::Update(stats.statistics, update_data[i]); + } + } + return not_null_count; + } +} + +idx_t UpdateStringStatistics(UpdateSegment *segment, SegmentStatistics &stats, Vector &update, idx_t count, + SelectionVector &sel) { + auto update_data = FlatVector::GetData(update); + auto &mask = FlatVector::Validity(update); + if (mask.AllValid()) { + for (idx_t i = 0; i < count; i++) { + StringStats::Update(stats.statistics, update_data[i]); + if (!update_data[i].IsInlined()) { + update_data[i] = segment->GetStringHeap().AddBlob(update_data[i]); + } + } + sel.Initialize(nullptr); + return count; + } else { + idx_t not_null_count = 0; + sel.Initialize(STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < count; i++) { + if (mask.RowIsValid(i)) { + sel.set_index(not_null_count++, i); + StringStats::Update(stats.statistics, update_data[i]); + if (!update_data[i].IsInlined()) { + update_data[i] = segment->GetStringHeap().AddBlob(update_data[i]); + } + } + } + return not_null_count; + } +} + +UpdateSegment::statistics_update_function_t GetStatisticsUpdateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::BIT: + return UpdateValidityStatistics; + case PhysicalType::BOOL: + case PhysicalType::INT8: + return TemplatedUpdateNumericStatistics; + case PhysicalType::INT16: + return TemplatedUpdateNumericStatistics; + case PhysicalType::INT32: + return TemplatedUpdateNumericStatistics; + case PhysicalType::INT64: + return TemplatedUpdateNumericStatistics; + case PhysicalType::UINT8: + return TemplatedUpdateNumericStatistics; + case PhysicalType::UINT16: + return TemplatedUpdateNumericStatistics; + case PhysicalType::UINT32: + return TemplatedUpdateNumericStatistics; + case PhysicalType::UINT64: + return TemplatedUpdateNumericStatistics; + case PhysicalType::INT128: + return TemplatedUpdateNumericStatistics; + case PhysicalType::FLOAT: + return TemplatedUpdateNumericStatistics; + case PhysicalType::DOUBLE: + return TemplatedUpdateNumericStatistics; + case PhysicalType::INTERVAL: + return TemplatedUpdateNumericStatistics; + case PhysicalType::VARCHAR: + return UpdateStringStatistics; + default: + throw NotImplementedException("Unimplemented type for uncompressed segment"); + } +} + +//===--------------------------------------------------------------------===// +// Update +//===--------------------------------------------------------------------===// +static idx_t SortSelectionVector(SelectionVector &sel, idx_t count, row_t *ids) { + D_ASSERT(count > 0); + + bool is_sorted = true; + for (idx_t i = 1; i < count; i++) { + auto prev_idx = sel.get_index(i - 1); + auto idx = sel.get_index(i); + if (ids[idx] <= ids[prev_idx]) { + is_sorted = false; + break; + } + } + if (is_sorted) { + // already sorted: bailout + return count; + } + // not sorted: need to sort the selection vector + SelectionVector sorted_sel(count); + for (idx_t i = 0; i < count; i++) { + sorted_sel.set_index(i, sel.get_index(i)); + } + std::sort(sorted_sel.data(), sorted_sel.data() + count, [&](sel_t l, sel_t r) { return ids[l] < ids[r]; }); + // eliminate any duplicates + idx_t pos = 1; + for (idx_t i = 1; i < count; i++) { + auto prev_idx = sorted_sel.get_index(i - 1); + auto idx = sorted_sel.get_index(i); + D_ASSERT(ids[idx] >= ids[prev_idx]); + if (ids[prev_idx] != ids[idx]) { + sorted_sel.set_index(pos++, idx); + } + } +#ifdef DEBUG + for (idx_t i = 1; i < pos; i++) { + auto prev_idx = sorted_sel.get_index(i - 1); + auto idx = sorted_sel.get_index(i); + D_ASSERT(ids[idx] > ids[prev_idx]); + } +#endif + + sel.Initialize(sorted_sel); + D_ASSERT(pos > 0); + return pos; +} + +UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, idx_t type_size, idx_t count, + unsafe_unique_array &data) { + data = make_unsafe_uniq_array(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); + auto update_info = reinterpret_cast(data.get()); + update_info->max = STANDARD_VECTOR_SIZE; + update_info->tuples = reinterpret_cast((data_ptr_cast(update_info)) + sizeof(UpdateInfo)); + update_info->tuple_data = (data_ptr_cast(update_info)) + sizeof(UpdateInfo) + sizeof(sel_t) * update_info->max; + update_info->version_number = transaction.transaction_id; + return update_info; +} + +void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vector &update, row_t *ids, idx_t count, + Vector &base_data) { + // obtain an exclusive lock + auto write_lock = lock.GetExclusiveLock(); + + update.Flatten(count); + + // update statistics + SelectionVector sel; + { + lock_guard stats_guard(stats_lock); + count = statistics_update_function(this, stats, update, count, sel); + } + if (count == 0) { + return; + } + + // subsequent algorithms used by the update require row ids to be (1) sorted, and (2) unique + // this is usually the case for "standard" queries (e.g. UPDATE tbl SET x=bla WHERE cond) + // however, for more exotic queries involving e.g. cross products/joins this might not be the case + // hence we explicitly check here if the ids are sorted and, if not, sort + duplicate eliminate them + count = SortSelectionVector(sel, count, ids); + D_ASSERT(count > 0); + + // create the versions for this segment, if there are none yet + if (!root) { + root = make_uniq(); + } + + // get the vector index based on the first id + // we assert that all updates must be part of the same vector + auto first_id = ids[sel.get_index(0)]; + idx_t vector_index = (first_id - column_data.start) / STANDARD_VECTOR_SIZE; + idx_t vector_offset = column_data.start + vector_index * STANDARD_VECTOR_SIZE; + + D_ASSERT(idx_t(first_id) >= column_data.start); + D_ASSERT(vector_index < Storage::ROW_GROUP_VECTOR_COUNT); + + // first check the version chain + UpdateInfo *node = nullptr; + + if (root->info[vector_index]) { + // there is already a version here, check if there are any conflicts and search for the node that belongs to + // this transaction in the version chain + auto base_info = root->info[vector_index]->info.get(); + CheckForConflicts(base_info->next, transaction, ids, sel, count, vector_offset, node); + + // there are no conflicts + // first, check if this thread has already done any updates + auto node = base_info->next; + while (node) { + if (node->version_number == transaction.transaction_id) { + // it has! use this node + break; + } + node = node->next; + } + unsafe_unique_array update_info_data; + if (!node) { + // no updates made yet by this transaction: initially the update info to empty + if (transaction.transaction) { + auto &dtransaction = transaction.transaction->Cast(); + node = dtransaction.CreateUpdateInfo(type_size, count); + } else { + node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + } + node->segment = this; + node->vector_index = vector_index; + node->N = 0; + node->column_index = column_index; + + // insert the new node into the chain + node->next = base_info->next; + if (node->next) { + node->next->prev = node; + } + node->prev = base_info; + base_info->next = transaction.transaction ? node : nullptr; + } + base_info->Verify(); + node->Verify(); + + // now we are going to perform the merge + merge_update_function(base_info, base_data, node, update, ids, count, sel); + + base_info->Verify(); + node->Verify(); + } else { + // there is no version info yet: create the top level update info and fill it with the updates + auto result = make_uniq(); + + result->info = make_uniq(); + result->tuples = make_unsafe_uniq_array(STANDARD_VECTOR_SIZE); + result->tuple_data = make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * type_size); + result->info->tuples = result->tuples.get(); + result->info->tuple_data = result->tuple_data.get(); + result->info->version_number = TRANSACTION_ID_START - 1; + result->info->column_index = column_index; + InitializeUpdateInfo(*result->info, ids, sel, count, vector_index, vector_offset); + + // now create the transaction level update info in the undo log + unsafe_unique_array update_info_data; + UpdateInfo *transaction_node; + if (transaction.transaction) { + transaction_node = transaction.transaction->CreateUpdateInfo(type_size, count); + } else { + transaction_node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + } + + InitializeUpdateInfo(*transaction_node, ids, sel, count, vector_index, vector_offset); + + // we write the updates in the update node data, and write the updates in the info + initialize_update_function(transaction_node, base_data, result->info.get(), update, sel); + + result->info->next = transaction.transaction ? transaction_node : nullptr; + result->info->prev = nullptr; + transaction_node->next = nullptr; + transaction_node->prev = result->info.get(); + transaction_node->column_index = column_index; + + transaction_node->Verify(); + result->info->Verify(); + + root->info[vector_index] = std::move(result); + } +} + +bool UpdateSegment::HasUpdates() const { + return root.get() != nullptr; +} + +bool UpdateSegment::HasUpdates(idx_t vector_index) const { + if (!HasUpdates()) { + return false; + } + return root->info[vector_index].get(); +} + +bool UpdateSegment::HasUncommittedUpdates(idx_t vector_index) { + if (!HasUpdates(vector_index)) { + return false; + } + auto read_lock = lock.GetSharedLock(); + auto entry = root->info[vector_index].get(); + if (entry->info->next) { + return true; + } + return false; +} + +bool UpdateSegment::HasUpdates(idx_t start_row_index, idx_t end_row_index) { + if (!HasUpdates()) { + return false; + } + auto read_lock = lock.GetSharedLock(); + idx_t base_vector_index = start_row_index / STANDARD_VECTOR_SIZE; + idx_t end_vector_index = end_row_index / STANDARD_VECTOR_SIZE; + for (idx_t i = base_vector_index; i <= end_vector_index; i++) { + if (root->info[i]) { + return true; + } + } + return false; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/validity_column_data.cpp b/src/duckdb/src/storage/table/validity_column_data.cpp new file mode 100644 index 00000000..bd594fd1 --- /dev/null +++ b/src/duckdb/src/storage/table/validity_column_data.cpp @@ -0,0 +1,16 @@ +#include "duckdb/storage/table/validity_column_data.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/update_segment.hpp" + +namespace duckdb { + +ValidityColumnData::ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + idx_t start_row, ColumnData &parent) + : ColumnData(block_manager, info, column_index, start_row, LogicalType(LogicalTypeId::VALIDITY), &parent) { +} + +bool ValidityColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table_index_list.cpp b/src/duckdb/src/storage/table_index_list.cpp new file mode 100644 index 00000000..3223f5e3 --- /dev/null +++ b/src/duckdb/src/storage/table_index_list.cpp @@ -0,0 +1,90 @@ +#include "duckdb/storage/table/table_index_list.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/common/types/conflict_manager.hpp" +#include "duckdb/execution/index/art/art.hpp" + +namespace duckdb { +void TableIndexList::AddIndex(unique_ptr index) { + D_ASSERT(index); + lock_guard lock(indexes_lock); + indexes.push_back(std::move(index)); +} + +void TableIndexList::RemoveIndex(Index &index) { + lock_guard lock(indexes_lock); + + for (idx_t index_idx = 0; index_idx < indexes.size(); index_idx++) { + auto &index_entry = indexes[index_idx]; + if (index_entry.get() == &index) { + indexes.erase(indexes.begin() + index_idx); + break; + } + } +} + +bool TableIndexList::Empty() { + lock_guard lock(indexes_lock); + return indexes.empty(); +} + +idx_t TableIndexList::Count() { + lock_guard lock(indexes_lock); + return indexes.size(); +} + +void TableIndexList::Move(TableIndexList &other) { + D_ASSERT(indexes.empty()); + indexes = std::move(other.indexes); +} + +Index *TableIndexList::FindForeignKeyIndex(const vector &fk_keys, ForeignKeyType fk_type) { + Index *result = nullptr; + Scan([&](Index &index) { + if (DataTable::IsForeignKeyIndex(fk_keys, index, fk_type)) { + result = &index; + } + return false; + }); + return result; +} + +void TableIndexList::VerifyForeignKey(const vector &fk_keys, DataChunk &chunk, + ConflictManager &conflict_manager) { + auto fk_type = conflict_manager.LookupType() == VerifyExistenceType::APPEND_FK + ? ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE + : ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; + + // check whether the chunk can be inserted or deleted into the referenced table storage + auto index = FindForeignKeyIndex(fk_keys, fk_type); + if (!index) { + throw InternalException("Internal Foreign Key error: could not find index to verify..."); + } + conflict_manager.SetIndexCount(1); + index->CheckConstraintsForChunk(chunk, conflict_manager); +} + +vector TableIndexList::GetRequiredColumns() { + lock_guard lock(indexes_lock); + set unique_indexes; + for (auto &index : indexes) { + for (auto col_index : index->column_ids) { + unique_indexes.insert(col_index); + } + } + vector result; + result.reserve(unique_indexes.size()); + for (auto column_index : unique_indexes) { + result.emplace_back(column_index); + } + return result; +} + +vector TableIndexList::SerializeIndexes(duckdb::MetadataWriter &writer) { + vector blocks_info; + for (auto &index : indexes) { + blocks_info.emplace_back(index->Serialize(writer)); + } + return blocks_info; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/wal_replay.cpp b/src/duckdb/src/storage/wal_replay.cpp new file mode 100644 index 00000000..7e700e39 --- /dev/null +++ b/src/duckdb/src/storage/wal_replay.cpp @@ -0,0 +1,538 @@ +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/common/printer.hpp" +#include "duckdb/common/serializer/buffered_file_reader.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/connection.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/write_ahead_log.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" + +namespace duckdb { + +bool WriteAheadLog::Replay(AttachedDatabase &database, string &path) { + Connection con(database.GetDatabase()); + auto initial_source = make_uniq(FileSystem::Get(database), path.c_str()); + if (initial_source->Finished()) { + // WAL is empty + return false; + } + + con.BeginTransaction(); + + // first deserialize the WAL to look for a checkpoint flag + // if there is a checkpoint flag, we might have already flushed the contents of the WAL to disk + ReplayState checkpoint_state(database, *con.context); + checkpoint_state.deserialize_only = true; + try { + while (true) { + // read the current entry + BinaryDeserializer deserializer(*initial_source); + deserializer.Begin(); + auto entry_type = deserializer.ReadProperty(100, "wal_type"); + if (entry_type == WALType::WAL_FLUSH) { + deserializer.End(); + // check if the file is exhausted + if (initial_source->Finished()) { + // we finished reading the file: break + break; + } + } else { + // replay the entry + checkpoint_state.ReplayEntry(entry_type, deserializer); + deserializer.End(); + } + } + } catch (SerializationException &ex) { // LCOV_EXCL_START + // serialization exception - torn WAL + // continue reading + } catch (std::exception &ex) { + Printer::PrintF("Exception in WAL playback during initial read: %s\n", ex.what()); + return false; + } catch (...) { + Printer::Print("Unknown Exception in WAL playback during initial read"); + return false; + } // LCOV_EXCL_STOP + initial_source.reset(); + if (checkpoint_state.checkpoint_id.IsValid()) { + // there is a checkpoint flag: check if we need to deserialize the WAL + auto &manager = database.GetStorageManager(); + if (manager.IsCheckpointClean(checkpoint_state.checkpoint_id)) { + // the contents of the WAL have already been checkpointed + // we can safely truncate the WAL and ignore its contents + return true; + } + } + + // we need to recover from the WAL: actually set up the replay state + BufferedFileReader reader(FileSystem::Get(database), path.c_str()); + ReplayState state(database, *con.context); + + // replay the WAL + // note that everything is wrapped inside a try/catch block here + // there can be errors in WAL replay because of a corrupt WAL file + // in this case we should throw a warning but startup anyway + try { + while (true) { + // read the current entry + BinaryDeserializer deserializer(reader); + deserializer.Begin(); + auto entry_type = deserializer.ReadProperty(100, "wal_type"); + if (entry_type == WALType::WAL_FLUSH) { + deserializer.End(); + con.Commit(); + // check if the file is exhausted + if (reader.Finished()) { + // we finished reading the file: break + break; + } + con.BeginTransaction(); + } else { + // replay the entry + state.ReplayEntry(entry_type, deserializer); + deserializer.End(); + } + } + } catch (SerializationException &ex) { // LCOV_EXCL_START + // serialization error during WAL replay: rollback + con.Rollback(); + } catch (std::exception &ex) { + // FIXME: this should report a proper warning in the connection + Printer::PrintF("Exception in WAL playback: %s\n", ex.what()); + // exception thrown in WAL replay: rollback + con.Rollback(); + } catch (...) { + Printer::Print("Unknown Exception in WAL playback: %s\n"); + // exception thrown in WAL replay: rollback + con.Rollback(); + } // LCOV_EXCL_STOP + return false; +} + +//===--------------------------------------------------------------------===// +// Replay Entries +//===--------------------------------------------------------------------===// +void ReplayState::ReplayEntry(WALType entry_type, BinaryDeserializer &deserializer) { + switch (entry_type) { + case WALType::CREATE_TABLE: + ReplayCreateTable(deserializer); + break; + case WALType::DROP_TABLE: + ReplayDropTable(deserializer); + break; + case WALType::ALTER_INFO: + ReplayAlter(deserializer); + break; + case WALType::CREATE_VIEW: + ReplayCreateView(deserializer); + break; + case WALType::DROP_VIEW: + ReplayDropView(deserializer); + break; + case WALType::CREATE_SCHEMA: + ReplayCreateSchema(deserializer); + break; + case WALType::DROP_SCHEMA: + ReplayDropSchema(deserializer); + break; + case WALType::CREATE_SEQUENCE: + ReplayCreateSequence(deserializer); + break; + case WALType::DROP_SEQUENCE: + ReplayDropSequence(deserializer); + break; + case WALType::SEQUENCE_VALUE: + ReplaySequenceValue(deserializer); + break; + case WALType::CREATE_MACRO: + ReplayCreateMacro(deserializer); + break; + case WALType::DROP_MACRO: + ReplayDropMacro(deserializer); + break; + case WALType::CREATE_TABLE_MACRO: + ReplayCreateTableMacro(deserializer); + break; + case WALType::DROP_TABLE_MACRO: + ReplayDropTableMacro(deserializer); + break; + case WALType::CREATE_INDEX: + ReplayCreateIndex(deserializer); + break; + case WALType::DROP_INDEX: + ReplayDropIndex(deserializer); + break; + case WALType::USE_TABLE: + ReplayUseTable(deserializer); + break; + case WALType::INSERT_TUPLE: + ReplayInsert(deserializer); + break; + case WALType::DELETE_TUPLE: + ReplayDelete(deserializer); + break; + case WALType::UPDATE_TUPLE: + ReplayUpdate(deserializer); + break; + case WALType::CHECKPOINT: + ReplayCheckpoint(deserializer); + break; + case WALType::CREATE_TYPE: + ReplayCreateType(deserializer); + break; + case WALType::DROP_TYPE: + ReplayDropType(deserializer); + break; + default: + throw InternalException("Invalid WAL entry type!"); + } +} + +//===--------------------------------------------------------------------===// +// Replay Table +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateTable(BinaryDeserializer &deserializer) { + auto info = deserializer.ReadProperty>(101, "table"); + if (deserialize_only) { + return; + } + // bind the constraints to the table again + auto binder = Binder::CreateBinder(context); + auto &schema = catalog.GetSchema(context, info->schema); + auto bound_info = binder->BindCreateTableInfo(std::move(info), schema); + + catalog.CreateTable(context, *bound_info); +} + +void ReplayState::ReplayDropTable(BinaryDeserializer &deserializer) { + + DropInfo info; + + info.type = CatalogType::TABLE_ENTRY; + info.schema = deserializer.ReadProperty(101, "schema"); + info.name = deserializer.ReadProperty(102, "name"); + if (deserialize_only) { + return; + } + + catalog.DropEntry(context, info); +} + +void ReplayState::ReplayAlter(BinaryDeserializer &deserializer) { + + auto info = deserializer.ReadProperty>(101, "info"); + auto &alter_info = info->Cast(); + if (deserialize_only) { + return; + } + catalog.Alter(context, alter_info); +} + +//===--------------------------------------------------------------------===// +// Replay View +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateView(BinaryDeserializer &deserializer) { + auto entry = deserializer.ReadProperty>(101, "view"); + if (deserialize_only) { + return; + } + catalog.CreateView(context, entry->Cast()); +} + +void ReplayState::ReplayDropView(BinaryDeserializer &deserializer) { + DropInfo info; + info.type = CatalogType::VIEW_ENTRY; + info.schema = deserializer.ReadProperty(101, "schema"); + info.name = deserializer.ReadProperty(102, "name"); + if (deserialize_only) { + return; + } + catalog.DropEntry(context, info); +} + +//===--------------------------------------------------------------------===// +// Replay Schema +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateSchema(BinaryDeserializer &deserializer) { + CreateSchemaInfo info; + info.schema = deserializer.ReadProperty(101, "schema"); + if (deserialize_only) { + return; + } + + catalog.CreateSchema(context, info); +} + +void ReplayState::ReplayDropSchema(BinaryDeserializer &deserializer) { + DropInfo info; + + info.type = CatalogType::SCHEMA_ENTRY; + info.name = deserializer.ReadProperty(101, "schema"); + if (deserialize_only) { + return; + } + + catalog.DropEntry(context, info); +} + +//===--------------------------------------------------------------------===// +// Replay Custom Type +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateType(BinaryDeserializer &deserializer) { + auto info = deserializer.ReadProperty>(101, "type"); + info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; + catalog.CreateType(context, info->Cast()); +} + +void ReplayState::ReplayDropType(BinaryDeserializer &deserializer) { + DropInfo info; + + info.type = CatalogType::TYPE_ENTRY; + info.schema = deserializer.ReadProperty(101, "schema"); + info.name = deserializer.ReadProperty(102, "name"); + if (deserialize_only) { + return; + } + + catalog.DropEntry(context, info); +} + +//===--------------------------------------------------------------------===// +// Replay Sequence +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateSequence(BinaryDeserializer &deserializer) { + auto entry = deserializer.ReadProperty>(101, "sequence"); + if (deserialize_only) { + return; + } + + catalog.CreateSequence(context, entry->Cast()); +} + +void ReplayState::ReplayDropSequence(BinaryDeserializer &deserializer) { + DropInfo info; + info.type = CatalogType::SEQUENCE_ENTRY; + info.schema = deserializer.ReadProperty(101, "schema"); + info.name = deserializer.ReadProperty(102, "name"); + if (deserialize_only) { + return; + } + + catalog.DropEntry(context, info); +} + +void ReplayState::ReplaySequenceValue(BinaryDeserializer &deserializer) { + auto schema = deserializer.ReadProperty(101, "schema"); + auto name = deserializer.ReadProperty(102, "name"); + auto usage_count = deserializer.ReadProperty(103, "usage_count"); + auto counter = deserializer.ReadProperty(104, "counter"); + if (deserialize_only) { + return; + } + + // fetch the sequence from the catalog + auto &seq = catalog.GetEntry(context, schema, name); + if (usage_count > seq.usage_count) { + seq.usage_count = usage_count; + seq.counter = counter; + } +} + +//===--------------------------------------------------------------------===// +// Replay Macro +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateMacro(BinaryDeserializer &deserializer) { + auto entry = deserializer.ReadProperty>(101, "macro"); + if (deserialize_only) { + return; + } + + catalog.CreateFunction(context, entry->Cast()); +} + +void ReplayState::ReplayDropMacro(BinaryDeserializer &deserializer) { + DropInfo info; + info.type = CatalogType::MACRO_ENTRY; + info.schema = deserializer.ReadProperty(101, "schema"); + info.name = deserializer.ReadProperty(102, "name"); + if (deserialize_only) { + return; + } + + catalog.DropEntry(context, info); +} + +//===--------------------------------------------------------------------===// +// Replay Table Macro +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateTableMacro(BinaryDeserializer &deserializer) { + auto entry = deserializer.ReadProperty>(101, "table_macro"); + if (deserialize_only) { + return; + } + catalog.CreateFunction(context, entry->Cast()); +} + +void ReplayState::ReplayDropTableMacro(BinaryDeserializer &deserializer) { + DropInfo info; + info.type = CatalogType::TABLE_MACRO_ENTRY; + info.schema = deserializer.ReadProperty(101, "schema"); + info.name = deserializer.ReadProperty(102, "name"); + if (deserialize_only) { + return; + } + + catalog.DropEntry(context, info); +} + +//===--------------------------------------------------------------------===// +// Replay Index +//===--------------------------------------------------------------------===// +void ReplayState::ReplayCreateIndex(BinaryDeserializer &deserializer) { + auto info = deserializer.ReadProperty>(101, "index"); + if (deserialize_only) { + return; + } + auto &index_info = info->Cast(); + + // get the physical table to which we'll add the index + auto &table = catalog.GetEntry(context, info->schema, index_info.table); + auto &data_table = table.GetStorage(); + + // bind the parsed expressions + if (index_info.expressions.empty()) { + for (auto &parsed_expr : index_info.parsed_expressions) { + index_info.expressions.push_back(parsed_expr->Copy()); + } + } + auto binder = Binder::CreateBinder(context); + auto expressions = binder->BindCreateIndexExpressions(table, index_info); + + // create the empty index + unique_ptr index; + switch (index_info.index_type) { + case IndexType::ART: { + index = make_uniq(index_info.column_ids, TableIOManager::Get(data_table), expressions, + index_info.constraint_type, data_table.db); + break; + } + default: + throw InternalException("Unimplemented index type"); + } + + // add the index to the catalog + auto &index_entry = catalog.CreateIndex(context, index_info)->Cast(); + index_entry.index = index.get(); + index_entry.info = data_table.info; + for (auto &parsed_expr : index_info.parsed_expressions) { + index_entry.parsed_expressions.push_back(parsed_expr->Copy()); + } + + // physically add the index to the data table storage + data_table.WALAddIndex(context, std::move(index), expressions); +} + +void ReplayState::ReplayDropIndex(BinaryDeserializer &deserializer) { + DropInfo info; + info.type = CatalogType::INDEX_ENTRY; + info.schema = deserializer.ReadProperty(101, "schema"); + info.name = deserializer.ReadProperty(102, "name"); + if (deserialize_only) { + return; + } + + catalog.DropEntry(context, info); +} + +//===--------------------------------------------------------------------===// +// Replay Data +//===--------------------------------------------------------------------===// +void ReplayState::ReplayUseTable(BinaryDeserializer &deserializer) { + auto schema_name = deserializer.ReadProperty(101, "schema"); + auto table_name = deserializer.ReadProperty(102, "table"); + if (deserialize_only) { + return; + } + current_table = &catalog.GetEntry(context, schema_name, table_name); +} + +void ReplayState::ReplayInsert(BinaryDeserializer &deserializer) { + DataChunk chunk; + deserializer.ReadObject(101, "chunk", [&](Deserializer &object) { chunk.Deserialize(object); }); + if (deserialize_only) { + return; + } + if (!current_table) { + throw Exception("Corrupt WAL: insert without table"); + } + + // append to the current table + current_table->GetStorage().LocalAppend(*current_table, context, chunk); +} + +void ReplayState::ReplayDelete(BinaryDeserializer &deserializer) { + DataChunk chunk; + deserializer.ReadObject(101, "chunk", [&](Deserializer &object) { chunk.Deserialize(object); }); + if (deserialize_only) { + return; + } + if (!current_table) { + throw InternalException("Corrupt WAL: delete without table"); + } + + D_ASSERT(chunk.ColumnCount() == 1 && chunk.data[0].GetType() == LogicalType::ROW_TYPE); + row_t row_ids[1]; + Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_ids)); + + auto source_ids = FlatVector::GetData(chunk.data[0]); + // delete the tuples from the current table + for (idx_t i = 0; i < chunk.size(); i++) { + row_ids[0] = source_ids[i]; + current_table->GetStorage().Delete(*current_table, context, row_identifiers, 1); + } +} + +void ReplayState::ReplayUpdate(BinaryDeserializer &deserializer) { + auto column_path = deserializer.ReadProperty>(101, "column_indexes"); + + DataChunk chunk; + deserializer.ReadObject(102, "chunk", [&](Deserializer &object) { chunk.Deserialize(object); }); + + if (deserialize_only) { + return; + } + if (!current_table) { + throw InternalException("Corrupt WAL: update without table"); + } + + if (column_path[0] >= current_table->GetColumns().PhysicalColumnCount()) { + throw InternalException("Corrupt WAL: column index for update out of bounds"); + } + + // remove the row id vector from the chunk + auto row_ids = std::move(chunk.data.back()); + chunk.data.pop_back(); + + // now perform the update + current_table->GetStorage().UpdateColumn(*current_table, context, row_ids, column_path, chunk); +} + +void ReplayState::ReplayCheckpoint(BinaryDeserializer &deserializer) { + checkpoint_id = deserializer.ReadProperty(101, "meta_block"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/write_ahead_log.cpp b/src/duckdb/src/storage/write_ahead_log.cpp new file mode 100644 index 00000000..5a5657b7 --- /dev/null +++ b/src/duckdb/src/storage/write_ahead_log.cpp @@ -0,0 +1,377 @@ +#include "duckdb/storage/write_ahead_log.hpp" + +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/parsed_data/alter_table_info.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include + +namespace duckdb { + +WriteAheadLog::WriteAheadLog(AttachedDatabase &database, const string &path) : skip_writing(false), database(database) { + wal_path = path; + writer = make_uniq(FileSystem::Get(database), path.c_str(), + FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE | + FileFlags::FILE_FLAGS_APPEND); +} + +WriteAheadLog::~WriteAheadLog() { +} + +int64_t WriteAheadLog::GetWALSize() { + D_ASSERT(writer); + return writer->GetFileSize(); +} + +idx_t WriteAheadLog::GetTotalWritten() { + D_ASSERT(writer); + return writer->GetTotalWritten(); +} + +void WriteAheadLog::Truncate(int64_t size) { + writer->Truncate(size); +} + +void WriteAheadLog::Delete() { + if (!writer) { + return; + } + writer.reset(); + + auto &fs = FileSystem::Get(database); + fs.RemoveFile(wal_path); +} + +//===--------------------------------------------------------------------===// +// Write Entries +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCheckpoint(MetaBlockPointer meta_block) { + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CHECKPOINT); + serializer.WriteProperty(101, "meta_block", meta_block); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// CREATE TABLE +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCreateTable(const TableCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_TABLE); + serializer.WriteProperty(101, "table", &entry); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// DROP TABLE +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteDropTable(const TableCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_TABLE); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// CREATE SCHEMA +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCreateSchema(const SchemaCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_SCHEMA); + serializer.WriteProperty(101, "schema", entry.name); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// SEQUENCES +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCreateSequence(const SequenceCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_SEQUENCE); + serializer.WriteProperty(101, "sequence", &entry); + serializer.End(); +} + +void WriteAheadLog::WriteDropSequence(const SequenceCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_SEQUENCE); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.End(); +} + +void WriteAheadLog::WriteSequenceValue(const SequenceCatalogEntry &entry, SequenceValue val) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::SEQUENCE_VALUE); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.WriteProperty(103, "usage_count", val.usage_count); + serializer.WriteProperty(104, "counter", val.counter); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// MACROS +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCreateMacro(const ScalarMacroCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_MACRO); + serializer.WriteProperty(101, "macro", &entry); + serializer.End(); +} + +void WriteAheadLog::WriteDropMacro(const ScalarMacroCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_MACRO); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.End(); +} + +void WriteAheadLog::WriteCreateTableMacro(const TableMacroCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_TABLE_MACRO); + serializer.WriteProperty(101, "table", &entry); + serializer.End(); +} + +void WriteAheadLog::WriteDropTableMacro(const TableMacroCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_TABLE_MACRO); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// Indexes +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCreateIndex(const IndexCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_INDEX); + serializer.WriteProperty(101, "index", &entry); + serializer.End(); +} + +void WriteAheadLog::WriteDropIndex(const IndexCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_INDEX); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// Custom Types +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCreateType(const TypeCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_TYPE); + serializer.WriteProperty(101, "type", &entry); + serializer.End(); +} + +void WriteAheadLog::WriteDropType(const TypeCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_TYPE); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// VIEWS +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteCreateView(const ViewCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::CREATE_VIEW); + serializer.WriteProperty(101, "view", &entry); + serializer.End(); +} + +void WriteAheadLog::WriteDropView(const ViewCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_VIEW); + serializer.WriteProperty(101, "schema", entry.schema.name); + serializer.WriteProperty(102, "name", entry.name); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// DROP SCHEMA +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteDropSchema(const SchemaCatalogEntry &entry) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DROP_SCHEMA); + serializer.WriteProperty(101, "schema", entry.name); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// DATA +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteSetTable(string &schema, string &table) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::USE_TABLE); + serializer.WriteProperty(101, "schema", schema); + serializer.WriteProperty(102, "table", table); + serializer.End(); +} + +void WriteAheadLog::WriteInsert(DataChunk &chunk) { + if (skip_writing) { + return; + } + D_ASSERT(chunk.size() > 0); + chunk.Verify(); + + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::INSERT_TUPLE); + serializer.WriteProperty(101, "chunk", chunk); + serializer.End(); +} + +void WriteAheadLog::WriteDelete(DataChunk &chunk) { + if (skip_writing) { + return; + } + D_ASSERT(chunk.size() > 0); + D_ASSERT(chunk.ColumnCount() == 1 && chunk.data[0].GetType() == LogicalType::ROW_TYPE); + chunk.Verify(); + + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::DELETE_TUPLE); + serializer.WriteProperty(101, "chunk", chunk); + serializer.End(); +} + +void WriteAheadLog::WriteUpdate(DataChunk &chunk, const vector &column_indexes) { + if (skip_writing) { + return; + } + D_ASSERT(chunk.size() > 0); + D_ASSERT(chunk.ColumnCount() == 2); + D_ASSERT(chunk.data[1].GetType().id() == LogicalType::ROW_TYPE); + chunk.Verify(); + + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::UPDATE_TUPLE); + serializer.WriteProperty(101, "column_indexes", column_indexes); + serializer.WriteProperty(102, "chunk", chunk); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// Write ALTER Statement +//===--------------------------------------------------------------------===// +void WriteAheadLog::WriteAlter(const AlterInfo &info) { + if (skip_writing) { + return; + } + BinarySerializer serializer(*writer); + serializer.Begin(); + serializer.WriteProperty(100, "wal_type", WALType::ALTER_INFO); + serializer.WriteProperty(101, "info", &info); + serializer.End(); +} + +//===--------------------------------------------------------------------===// +// FLUSH +//===--------------------------------------------------------------------===// +void WriteAheadLog::Flush() { + if (skip_writing) { + return; + } + + BinarySerializer serializer(*writer); + serializer.Begin(); + // write an empty entry + serializer.WriteProperty(100, "wal_type", WALType::WAL_FLUSH); + serializer.End(); + + // flushes all changes made to the WAL to disk + writer->Sync(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/cleanup_state.cpp b/src/duckdb/src/transaction/cleanup_state.cpp new file mode 100644 index 00000000..cbdaa80a --- /dev/null +++ b/src/duckdb/src/transaction/cleanup_state.cpp @@ -0,0 +1,95 @@ +#include "duckdb/transaction/cleanup_state.hpp" +#include "duckdb/transaction/delete_info.hpp" +#include "duckdb/transaction/update_info.hpp" + +#include "duckdb/storage/data_table.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/storage/table/chunk_info.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/table/row_version_manager.hpp" + +namespace duckdb { + +CleanupState::CleanupState() : current_table(nullptr), count(0) { +} + +CleanupState::~CleanupState() { + Flush(); +} + +void CleanupState::CleanupEntry(UndoFlags type, data_ptr_t data) { + switch (type) { + case UndoFlags::CATALOG_ENTRY: { + auto catalog_entry = Load(data); + D_ASSERT(catalog_entry); + D_ASSERT(catalog_entry->set); + catalog_entry->set->CleanupEntry(*catalog_entry); + break; + } + case UndoFlags::DELETE_TUPLE: { + auto info = reinterpret_cast(data); + CleanupDelete(*info); + break; + } + case UndoFlags::UPDATE_TUPLE: { + auto info = reinterpret_cast(data); + CleanupUpdate(*info); + break; + } + default: + break; + } +} + +void CleanupState::CleanupUpdate(UpdateInfo &info) { + // remove the update info from the update chain + // first obtain an exclusive lock on the segment + info.segment->CleanupUpdate(info); +} + +void CleanupState::CleanupDelete(DeleteInfo &info) { + auto version_table = info.table; + D_ASSERT(version_table->info->cardinality >= info.count); + version_table->info->cardinality -= info.count; + + if (version_table->info->indexes.Empty()) { + // this table has no indexes: no cleanup to be done + return; + } + + if (current_table != version_table) { + // table for this entry differs from previous table: flush and switch to the new table + Flush(); + current_table = version_table; + } + + // possibly vacuum any indexes in this table later + indexed_tables[current_table->info->table] = current_table; + + count = 0; + for (idx_t i = 0; i < info.count; i++) { + row_numbers[count++] = info.base_row + info.rows[i]; + } + Flush(); +} + +void CleanupState::Flush() { + if (count == 0) { + return; + } + + // set up the row identifiers vector + Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_numbers)); + + // delete the tuples from all the indexes + try { + current_table->RemoveFromIndexes(row_identifiers, count); + } catch (...) { + } + + count = 0; +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/commit_state.cpp b/src/duckdb/src/transaction/commit_state.cpp new file mode 100644 index 00000000..6f844317 --- /dev/null +++ b/src/duckdb/src/transaction/commit_state.cpp @@ -0,0 +1,342 @@ +#include "duckdb/transaction/commit_state.hpp" + +#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table/chunk_info.hpp" +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/write_ahead_log.hpp" +#include "duckdb/transaction/append_info.hpp" +#include "duckdb/transaction/delete_info.hpp" +#include "duckdb/transaction/update_info.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" +#include "duckdb/storage/table/row_version_manager.hpp" +#include "duckdb/common/serializer/binary_deserializer.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" + +namespace duckdb { + +CommitState::CommitState(transaction_t commit_id, optional_ptr log) + : log(log), commit_id(commit_id), current_table_info(nullptr) { +} + +void CommitState::SwitchTable(DataTableInfo *table_info, UndoFlags new_op) { + if (current_table_info != table_info) { + // write the current table to the log + log->WriteSetTable(table_info->schema, table_info->table); + current_table_info = table_info; + } +} + +void CommitState::WriteCatalogEntry(CatalogEntry &entry, data_ptr_t dataptr) { + if (entry.temporary || entry.parent->temporary) { + return; + } + D_ASSERT(log); + // look at the type of the parent entry + auto parent = entry.parent; + switch (parent->type) { + case CatalogType::TABLE_ENTRY: + if (entry.type == CatalogType::TABLE_ENTRY) { + auto &table_entry = entry.Cast(); + D_ASSERT(table_entry.IsDuckTable()); + // ALTER TABLE statement, read the extra data after the entry + + auto extra_data_size = Load(dataptr); + auto extra_data = data_ptr_cast(dataptr + sizeof(idx_t)); + + MemoryStream source(extra_data, extra_data_size); + BinaryDeserializer deserializer(source); + deserializer.Begin(); + auto column_name = deserializer.ReadProperty(100, "column_name"); + auto parse_info = deserializer.ReadProperty>(101, "alter_info"); + deserializer.End(); + + if (!column_name.empty()) { + // write the alter table in the log + table_entry.CommitAlter(column_name); + } + auto &alter_info = parse_info->Cast(); + log->WriteAlter(alter_info); + } else { + // CREATE TABLE statement + log->WriteCreateTable(parent->Cast()); + } + break; + case CatalogType::SCHEMA_ENTRY: + if (entry.type == CatalogType::SCHEMA_ENTRY) { + // ALTER TABLE statement, skip it + return; + } + log->WriteCreateSchema(parent->Cast()); + break; + case CatalogType::VIEW_ENTRY: + if (entry.type == CatalogType::VIEW_ENTRY) { + // ALTER TABLE statement, read the extra data after the entry + auto extra_data_size = Load(dataptr); + auto extra_data = data_ptr_cast(dataptr + sizeof(idx_t)); + // deserialize it + MemoryStream source(extra_data, extra_data_size); + BinaryDeserializer deserializer(source); + deserializer.Begin(); + auto column_name = deserializer.ReadProperty(100, "column_name"); + auto parse_info = deserializer.ReadProperty>(101, "alter_info"); + deserializer.End(); + + (void)column_name; + + // write the alter table in the log + auto &alter_info = parse_info->Cast(); + log->WriteAlter(alter_info); + } else { + log->WriteCreateView(parent->Cast()); + } + break; + case CatalogType::SEQUENCE_ENTRY: + log->WriteCreateSequence(parent->Cast()); + break; + case CatalogType::MACRO_ENTRY: + log->WriteCreateMacro(parent->Cast()); + break; + case CatalogType::TABLE_MACRO_ENTRY: + log->WriteCreateTableMacro(parent->Cast()); + break; + case CatalogType::INDEX_ENTRY: + log->WriteCreateIndex(parent->Cast()); + break; + case CatalogType::TYPE_ENTRY: + log->WriteCreateType(parent->Cast()); + break; + case CatalogType::DELETED_ENTRY: + switch (entry.type) { + case CatalogType::TABLE_ENTRY: { + auto &table_entry = entry.Cast(); + D_ASSERT(table_entry.IsDuckTable()); + table_entry.CommitDrop(); + log->WriteDropTable(table_entry); + break; + } + case CatalogType::SCHEMA_ENTRY: + log->WriteDropSchema(entry.Cast()); + break; + case CatalogType::VIEW_ENTRY: + log->WriteDropView(entry.Cast()); + break; + case CatalogType::SEQUENCE_ENTRY: + log->WriteDropSequence(entry.Cast()); + break; + case CatalogType::MACRO_ENTRY: + log->WriteDropMacro(entry.Cast()); + break; + case CatalogType::TABLE_MACRO_ENTRY: + log->WriteDropTableMacro(entry.Cast()); + break; + case CatalogType::TYPE_ENTRY: + log->WriteDropType(entry.Cast()); + break; + case CatalogType::INDEX_ENTRY: { + auto &index_entry = entry.Cast(); + index_entry.CommitDrop(); + log->WriteDropIndex(entry.Cast()); + break; + } + case CatalogType::PREPARED_STATEMENT: + case CatalogType::SCALAR_FUNCTION_ENTRY: + // do nothing, indexes/prepared statements/functions aren't persisted to disk + break; + default: + throw InternalException("Don't know how to drop this type!"); + } + break; + case CatalogType::PREPARED_STATEMENT: + case CatalogType::AGGREGATE_FUNCTION_ENTRY: + case CatalogType::SCALAR_FUNCTION_ENTRY: + case CatalogType::TABLE_FUNCTION_ENTRY: + case CatalogType::COPY_FUNCTION_ENTRY: + case CatalogType::PRAGMA_FUNCTION_ENTRY: + case CatalogType::COLLATION_ENTRY: + // do nothing, these entries are not persisted to disk + break; + default: + throw InternalException("UndoBuffer - don't know how to write this entry to the WAL"); + } +} + +void CommitState::WriteDelete(DeleteInfo &info) { + D_ASSERT(log); + // switch to the current table, if necessary + SwitchTable(info.table->info.get(), UndoFlags::DELETE_TUPLE); + + if (!delete_chunk) { + delete_chunk = make_uniq(); + vector delete_types = {LogicalType::ROW_TYPE}; + delete_chunk->Initialize(Allocator::DefaultAllocator(), delete_types); + } + auto rows = FlatVector::GetData(delete_chunk->data[0]); + for (idx_t i = 0; i < info.count; i++) { + rows[i] = info.base_row + info.rows[i]; + } + delete_chunk->SetCardinality(info.count); + log->WriteDelete(*delete_chunk); +} + +void CommitState::WriteUpdate(UpdateInfo &info) { + D_ASSERT(log); + // switch to the current table, if necessary + auto &column_data = info.segment->column_data; + auto &table_info = column_data.GetTableInfo(); + + SwitchTable(&table_info, UndoFlags::UPDATE_TUPLE); + + // initialize the update chunk + vector update_types; + if (column_data.type.id() == LogicalTypeId::VALIDITY) { + update_types.emplace_back(LogicalType::BOOLEAN); + } else { + update_types.push_back(column_data.type); + } + update_types.emplace_back(LogicalType::ROW_TYPE); + + update_chunk = make_uniq(); + update_chunk->Initialize(Allocator::DefaultAllocator(), update_types); + + // fetch the updated values from the base segment + info.segment->FetchCommitted(info.vector_index, update_chunk->data[0]); + + // write the row ids into the chunk + auto row_ids = FlatVector::GetData(update_chunk->data[1]); + idx_t start = column_data.start + info.vector_index * STANDARD_VECTOR_SIZE; + for (idx_t i = 0; i < info.N; i++) { + row_ids[info.tuples[i]] = start + info.tuples[i]; + } + if (column_data.type.id() == LogicalTypeId::VALIDITY) { + // zero-initialize the booleans + // FIXME: this is only required because of NullValue in Vector::Serialize... + auto booleans = FlatVector::GetData(update_chunk->data[0]); + for (idx_t i = 0; i < info.N; i++) { + auto idx = info.tuples[i]; + booleans[idx] = false; + } + } + SelectionVector sel(info.tuples); + update_chunk->Slice(sel, info.N); + + // construct the column index path + vector column_indexes; + reference current_column_data = column_data; + while (current_column_data.get().parent) { + column_indexes.push_back(current_column_data.get().column_index); + current_column_data = *current_column_data.get().parent; + } + column_indexes.push_back(info.column_index); + std::reverse(column_indexes.begin(), column_indexes.end()); + + log->WriteUpdate(*update_chunk, column_indexes); +} + +template +void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { + switch (type) { + case UndoFlags::CATALOG_ENTRY: { + // set the commit timestamp of the catalog entry to the given id + auto catalog_entry = Load(data); + D_ASSERT(catalog_entry->parent); + + auto &catalog = catalog_entry->ParentCatalog(); + D_ASSERT(catalog.IsDuckCatalog()); + + // Grab a write lock on the catalog + auto &duck_catalog = catalog.Cast(); + lock_guard write_lock(duck_catalog.GetWriteLock()); + catalog_entry->set->UpdateTimestamp(*catalog_entry->parent, commit_id); + if (catalog_entry->name != catalog_entry->parent->name) { + catalog_entry->set->UpdateTimestamp(*catalog_entry, commit_id); + } + if (HAS_LOG) { + // push the catalog update to the WAL + WriteCatalogEntry(*catalog_entry, data + sizeof(CatalogEntry *)); + } + break; + } + case UndoFlags::INSERT_TUPLE: { + // append: + auto info = reinterpret_cast(data); + if (HAS_LOG && !info->table->info->IsTemporary()) { + info->table->WriteToLog(*log, info->start_row, info->count); + } + // mark the tuples as committed + info->table->CommitAppend(commit_id, info->start_row, info->count); + break; + } + case UndoFlags::DELETE_TUPLE: { + // deletion: + auto info = reinterpret_cast(data); + if (HAS_LOG && !info->table->info->IsTemporary()) { + WriteDelete(*info); + } + // mark the tuples as committed + info->version_info->CommitDelete(info->vector_idx, commit_id, info->rows, info->count); + break; + } + case UndoFlags::UPDATE_TUPLE: { + // update: + auto info = reinterpret_cast(data); + if (HAS_LOG && !info->segment->column_data.GetTableInfo().IsTemporary()) { + WriteUpdate(*info); + } + info->version_number = commit_id; + break; + } + default: + throw InternalException("UndoBuffer - don't know how to commit this type!"); + } +} + +void CommitState::RevertCommit(UndoFlags type, data_ptr_t data) { + transaction_t transaction_id = commit_id; + switch (type) { + case UndoFlags::CATALOG_ENTRY: { + // set the commit timestamp of the catalog entry to the given id + auto catalog_entry = Load(data); + D_ASSERT(catalog_entry->parent); + catalog_entry->set->UpdateTimestamp(*catalog_entry->parent, transaction_id); + if (catalog_entry->name != catalog_entry->parent->name) { + catalog_entry->set->UpdateTimestamp(*catalog_entry, transaction_id); + } + break; + } + case UndoFlags::INSERT_TUPLE: { + auto info = reinterpret_cast(data); + // revert this append + info->table->RevertAppend(info->start_row, info->count); + break; + } + case UndoFlags::DELETE_TUPLE: { + // deletion: + auto info = reinterpret_cast(data); + info->table->info->cardinality += info->count; + // revert the commit by writing the (uncommitted) transaction_id back into the version info + info->version_info->CommitDelete(info->vector_idx, transaction_id, info->rows, info->count); + break; + } + case UndoFlags::UPDATE_TUPLE: { + // update: + auto info = reinterpret_cast(data); + info->version_number = transaction_id; + break; + } + default: + throw InternalException("UndoBuffer - don't know how to revert commit of this type!"); + } +} + +template void CommitState::CommitEntry(UndoFlags type, data_ptr_t data); +template void CommitState::CommitEntry(UndoFlags type, data_ptr_t data); + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/duck_transaction.cpp b/src/duckdb/src/transaction/duck_transaction.cpp new file mode 100644 index 00000000..3448a0a8 --- /dev/null +++ b/src/duckdb/src/transaction/duck_transaction.cpp @@ -0,0 +1,159 @@ +#include "duckdb/transaction/duck_transaction.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/parser/column_definition.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/write_ahead_log.hpp" +#include "duckdb/storage/storage_manager.hpp" + +#include "duckdb/transaction/append_info.hpp" +#include "duckdb/transaction/delete_info.hpp" +#include "duckdb/transaction/update_info.hpp" +#include "duckdb/transaction/local_storage.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/attached_database.hpp" + +namespace duckdb { + +TransactionData::TransactionData(DuckTransaction &transaction_p) // NOLINT + : transaction(&transaction_p), transaction_id(transaction_p.transaction_id), start_time(transaction_p.start_time) { +} +TransactionData::TransactionData(transaction_t transaction_id_p, transaction_t start_time_p) + : transaction(nullptr), transaction_id(transaction_id_p), start_time(start_time_p) { +} + +DuckTransaction::DuckTransaction(TransactionManager &manager, ClientContext &context_p, transaction_t start_time, + transaction_t transaction_id) + : Transaction(manager, context_p), start_time(start_time), transaction_id(transaction_id), commit_id(0), + highest_active_query(0), undo_buffer(context_p), storage(make_uniq(context_p, *this)) { +} + +DuckTransaction::~DuckTransaction() { +} + +DuckTransaction &DuckTransaction::Get(ClientContext &context, AttachedDatabase &db) { + return DuckTransaction::Get(context, db.GetCatalog()); +} + +DuckTransaction &DuckTransaction::Get(ClientContext &context, Catalog &catalog) { + auto &transaction = Transaction::Get(context, catalog); + if (!transaction.IsDuckTransaction()) { + throw InternalException("DuckTransaction::Get called on non-DuckDB transaction"); + } + return transaction.Cast(); +} + +LocalStorage &DuckTransaction::GetLocalStorage() { + return *storage; +} + +void DuckTransaction::PushCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data, idx_t extra_data_size) { + idx_t alloc_size = sizeof(CatalogEntry *); + if (extra_data_size > 0) { + alloc_size += extra_data_size + sizeof(idx_t); + } + auto baseptr = undo_buffer.CreateEntry(UndoFlags::CATALOG_ENTRY, alloc_size); + // store the pointer to the catalog entry + Store(&entry, baseptr); + if (extra_data_size > 0) { + // copy the extra data behind the catalog entry pointer (if any) + baseptr += sizeof(CatalogEntry *); + // first store the extra data size + Store(extra_data_size, baseptr); + baseptr += sizeof(idx_t); + // then copy over the actual data + memcpy(baseptr, extra_data, extra_data_size); + } +} + +void DuckTransaction::PushDelete(DataTable &table, RowVersionManager &info, idx_t vector_idx, row_t rows[], idx_t count, + idx_t base_row) { + auto delete_info = reinterpret_cast( + undo_buffer.CreateEntry(UndoFlags::DELETE_TUPLE, sizeof(DeleteInfo) + sizeof(row_t) * count)); + delete_info->version_info = &info; + delete_info->vector_idx = vector_idx; + delete_info->table = &table; + delete_info->count = count; + delete_info->base_row = base_row; + memcpy(delete_info->rows, rows, sizeof(row_t) * count); +} + +void DuckTransaction::PushAppend(DataTable &table, idx_t start_row, idx_t row_count) { + auto append_info = + reinterpret_cast(undo_buffer.CreateEntry(UndoFlags::INSERT_TUPLE, sizeof(AppendInfo))); + append_info->table = &table; + append_info->start_row = start_row; + append_info->count = row_count; +} + +UpdateInfo *DuckTransaction::CreateUpdateInfo(idx_t type_size, idx_t entries) { + data_ptr_t base_info = undo_buffer.CreateEntry( + UndoFlags::UPDATE_TUPLE, sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); + auto update_info = reinterpret_cast(base_info); + update_info->max = STANDARD_VECTOR_SIZE; + update_info->tuples = reinterpret_cast(base_info + sizeof(UpdateInfo)); + update_info->tuple_data = base_info + sizeof(UpdateInfo) + sizeof(sel_t) * update_info->max; + update_info->version_number = transaction_id; + return update_info; +} + +bool DuckTransaction::ChangesMade() { + return undo_buffer.ChangesMade() || storage->ChangesMade(); +} + +bool DuckTransaction::AutomaticCheckpoint(AttachedDatabase &db) { + auto &storage_manager = db.GetStorageManager(); + return storage_manager.AutomaticCheckpoint(storage->EstimatedSize() + undo_buffer.EstimatedSize()); +} + +string DuckTransaction::Commit(AttachedDatabase &db, transaction_t commit_id, bool checkpoint) noexcept { + // "checkpoint" parameter indicates if the caller will checkpoint. If checkpoint == + // true: Then this function will NOT write to the WAL or flush/persist. + // This method only makes commit in memory, expecting caller to checkpoint/flush. + // false: Then this function WILL write to the WAL and Flush/Persist it. + this->commit_id = commit_id; + + UndoBuffer::IteratorState iterator_state; + LocalStorage::CommitState commit_state; + unique_ptr storage_commit_state; + optional_ptr log; + if (!db.IsSystem()) { + auto &storage_manager = db.GetStorageManager(); + log = storage_manager.GetWriteAheadLog(); + storage_commit_state = storage_manager.GenStorageCommitState(*this, checkpoint); + } else { + log = nullptr; + } + try { + storage->Commit(commit_state, *this); + undo_buffer.Commit(iterator_state, log, commit_id); + if (log) { + // commit any sequences that were used to the WAL + for (auto &entry : sequence_usage) { + log->WriteSequenceValue(*entry.first, entry.second); + } + } + if (storage_commit_state) { + storage_commit_state->FlushCommit(); + } + return string(); + } catch (std::exception &ex) { + undo_buffer.RevertCommit(iterator_state, this->transaction_id); + return ex.what(); + } +} + +void DuckTransaction::Rollback() noexcept { + storage->Rollback(); + undo_buffer.Rollback(); +} + +void DuckTransaction::Cleanup() { + undo_buffer.Cleanup(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/duck_transaction_manager.cpp b/src/duckdb/src/transaction/duck_transaction_manager.cpp new file mode 100644 index 00000000..b5dab24d --- /dev/null +++ b/src/duckdb/src/transaction/duck_transaction_manager.cpp @@ -0,0 +1,342 @@ +#include "duckdb/transaction/duck_transaction_manager.hpp" + +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/connection_manager.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database_manager.hpp" + +namespace duckdb { + +struct CheckpointLock { + explicit CheckpointLock(DuckTransactionManager &manager) : manager(manager), is_locked(false) { + } + ~CheckpointLock() { + Unlock(); + } + + DuckTransactionManager &manager; + bool is_locked; + + void Lock() { + D_ASSERT(!manager.thread_is_checkpointing); + manager.thread_is_checkpointing = true; + is_locked = true; + } + void Unlock() { + if (!is_locked) { + return; + } + D_ASSERT(manager.thread_is_checkpointing); + manager.thread_is_checkpointing = false; + is_locked = false; + } +}; + +DuckTransactionManager::DuckTransactionManager(AttachedDatabase &db) + : TransactionManager(db), thread_is_checkpointing(false) { + // start timestamp starts at two + current_start_timestamp = 2; + // transaction ID starts very high: + // it should be much higher than the current start timestamp + // if transaction_id < start_timestamp for any set of active transactions + // uncommited data could be read by + current_transaction_id = TRANSACTION_ID_START; + lowest_active_id = TRANSACTION_ID_START; + lowest_active_start = MAX_TRANSACTION_ID; +} + +DuckTransactionManager::~DuckTransactionManager() { +} + +DuckTransactionManager &DuckTransactionManager::Get(AttachedDatabase &db) { + auto &transaction_manager = TransactionManager::Get(db); + if (!transaction_manager.IsDuckTransactionManager()) { + throw InternalException("Calling DuckTransactionManager::Get on non-DuckDB transaction manager"); + } + return reinterpret_cast(transaction_manager); +} + +Transaction *DuckTransactionManager::StartTransaction(ClientContext &context) { + // obtain the transaction lock during this function + lock_guard lock(transaction_lock); + if (current_start_timestamp >= TRANSACTION_ID_START) { // LCOV_EXCL_START + throw InternalException("Cannot start more transactions, ran out of " + "transaction identifiers!"); + } // LCOV_EXCL_STOP + + // obtain the start time and transaction ID of this transaction + transaction_t start_time = current_start_timestamp++; + transaction_t transaction_id = current_transaction_id++; + if (active_transactions.empty()) { + lowest_active_start = start_time; + lowest_active_id = transaction_id; + } + + // create the actual transaction + auto transaction = make_uniq(*this, context, start_time, transaction_id); + auto transaction_ptr = transaction.get(); + + // store it in the set of active transactions + active_transactions.push_back(std::move(transaction)); + return transaction_ptr; +} + +struct ClientLockWrapper { + ClientLockWrapper(mutex &client_lock, shared_ptr connection) + : connection(std::move(connection)), connection_lock(make_uniq>(client_lock)) { + } + + shared_ptr connection; + unique_ptr> connection_lock; +}; + +void DuckTransactionManager::LockClients(vector &client_locks, ClientContext &context) { + auto &connection_manager = ConnectionManager::Get(context); + client_locks.emplace_back(connection_manager.connections_lock, nullptr); + auto connection_list = connection_manager.GetConnectionList(); + for (auto &con : connection_list) { + if (con.get() == &context) { + continue; + } + auto &context_lock = con->context_lock; + client_locks.emplace_back(context_lock, std::move(con)); + } +} + +void DuckTransactionManager::Checkpoint(ClientContext &context, bool force) { + auto &storage_manager = db.GetStorageManager(); + if (storage_manager.InMemory()) { + return; + } + + // first check if no other thread is checkpointing right now + auto lock = unique_lock(transaction_lock); + if (thread_is_checkpointing) { + throw TransactionException("Cannot CHECKPOINT: another thread is checkpointing right now"); + } + CheckpointLock checkpoint_lock(*this); + checkpoint_lock.Lock(); + lock.unlock(); + + // lock all the clients AND the connection manager now + // this ensures no new queries can be started, and no new connections to the database can be made + // to avoid deadlock we release the transaction lock while locking the clients + vector client_locks; + LockClients(client_locks, context); + + auto current = &DuckTransaction::Get(context, db); + lock.lock(); + if (current->ChangesMade()) { + throw TransactionException("Cannot CHECKPOINT: the current transaction has transaction local changes"); + } + if (!force) { + if (!CanCheckpoint(current)) { + throw TransactionException("Cannot CHECKPOINT: there are other transactions. Use FORCE CHECKPOINT to abort " + "the other transactions and force a checkpoint"); + } + } else { + if (!CanCheckpoint(current)) { + for (size_t i = 0; i < active_transactions.size(); i++) { + auto &transaction = active_transactions[i]; + // rollback the transaction + transaction->Rollback(); + auto transaction_context = transaction->context.lock(); + + // remove the transaction id from the list of active transactions + // potentially resulting in garbage collection + RemoveTransaction(*transaction); + if (transaction_context) { + transaction_context->transaction.ClearTransaction(); + } + i--; + } + D_ASSERT(CanCheckpoint(nullptr)); + } + } + storage_manager.CreateCheckpoint(); +} + +bool DuckTransactionManager::CanCheckpoint(optional_ptr current) { + if (db.IsSystem()) { + return false; + } + auto &storage_manager = db.GetStorageManager(); + if (storage_manager.InMemory()) { + return false; + } + if (!recently_committed_transactions.empty() || !old_transactions.empty()) { + return false; + } + for (auto &transaction : active_transactions) { + if (transaction.get() != current.get()) { + return false; + } + } + return true; +} + +string DuckTransactionManager::CommitTransaction(ClientContext &context, Transaction *transaction_p) { + auto &transaction = transaction_p->Cast(); + vector client_locks; + auto lock = make_uniq>(transaction_lock); + CheckpointLock checkpoint_lock(*this); + // check if we can checkpoint + bool checkpoint = thread_is_checkpointing ? false : CanCheckpoint(&transaction); + if (checkpoint) { + if (transaction.AutomaticCheckpoint(db)) { + checkpoint_lock.Lock(); + // we might be able to checkpoint: lock all clients + // to avoid deadlock we release the transaction lock while locking the clients + lock.reset(); + + LockClients(client_locks, context); + + lock = make_uniq>(transaction_lock); + checkpoint = CanCheckpoint(&transaction); + if (!checkpoint) { + checkpoint_lock.Unlock(); + client_locks.clear(); + } + } else { + checkpoint = false; + } + } + // obtain a commit id for the transaction + transaction_t commit_id = current_start_timestamp++; + // commit the UndoBuffer of the transaction + string error = transaction.Commit(db, commit_id, checkpoint); + if (!error.empty()) { + // commit unsuccessful: rollback the transaction instead + checkpoint = false; + transaction.commit_id = 0; + transaction.Rollback(); + } + if (!checkpoint) { + // we won't checkpoint after all: unlock the clients again + checkpoint_lock.Unlock(); + client_locks.clear(); + } + + // commit successful: remove the transaction id from the list of active transactions + // potentially resulting in garbage collection + RemoveTransaction(transaction); + // now perform a checkpoint if (1) we are able to checkpoint, and (2) the WAL has reached sufficient size to + // checkpoint + if (checkpoint) { + // checkpoint the database to disk + auto &storage_manager = db.GetStorageManager(); + storage_manager.CreateCheckpoint(false, true); + } + return error; +} + +void DuckTransactionManager::RollbackTransaction(Transaction *transaction_p) { + auto &transaction = transaction_p->Cast(); + // obtain the transaction lock during this function + lock_guard lock(transaction_lock); + + // rollback the transaction + transaction.Rollback(); + + // remove the transaction id from the list of active transactions + // potentially resulting in garbage collection + RemoveTransaction(transaction); +} + +void DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noexcept { + // remove the transaction from the list of active transactions + idx_t t_index = active_transactions.size(); + // check for the lowest and highest start time in the list of transactions + transaction_t lowest_start_time = TRANSACTION_ID_START; + transaction_t lowest_transaction_id = MAX_TRANSACTION_ID; + transaction_t lowest_active_query = MAXIMUM_QUERY_ID; + for (idx_t i = 0; i < active_transactions.size(); i++) { + if (active_transactions[i].get() == &transaction) { + t_index = i; + } else { + transaction_t active_query = active_transactions[i]->active_query; + lowest_start_time = MinValue(lowest_start_time, active_transactions[i]->start_time); + lowest_active_query = MinValue(lowest_active_query, active_query); + lowest_transaction_id = MinValue(lowest_transaction_id, active_transactions[i]->transaction_id); + } + } + lowest_active_start = lowest_start_time; + lowest_active_id = lowest_transaction_id; + + transaction_t lowest_stored_query = lowest_start_time; + D_ASSERT(t_index != active_transactions.size()); + auto current_transaction = std::move(active_transactions[t_index]); + auto current_query = DatabaseManager::Get(db).ActiveQueryNumber(); + if (transaction.commit_id != 0) { + // the transaction was committed, add it to the list of recently + // committed transactions + recently_committed_transactions.push_back(std::move(current_transaction)); + } else { + // the transaction was aborted, but we might still need its information + // add it to the set of transactions awaiting GC + current_transaction->highest_active_query = current_query; + old_transactions.push_back(std::move(current_transaction)); + } + // remove the transaction from the set of currently active transactions + active_transactions.erase(active_transactions.begin() + t_index); + // traverse the recently_committed transactions to see if we can remove any + idx_t i = 0; + for (; i < recently_committed_transactions.size(); i++) { + D_ASSERT(recently_committed_transactions[i]); + lowest_stored_query = MinValue(recently_committed_transactions[i]->start_time, lowest_stored_query); + if (recently_committed_transactions[i]->commit_id < lowest_start_time) { + // changes made BEFORE this transaction are no longer relevant + // we can cleanup the undo buffer + + // HOWEVER: any currently running QUERY can still be using + // the version information after the cleanup! + + // if we remove the UndoBuffer immediately, we have a race + // condition + + // we can only safely do the actual memory cleanup when all the + // currently active queries have finished running! (actually, + // when all the currently active scans have finished running...) + recently_committed_transactions[i]->Cleanup(); + // store the current highest active query + recently_committed_transactions[i]->highest_active_query = current_query; + // move it to the list of transactions awaiting GC + old_transactions.push_back(std::move(recently_committed_transactions[i])); + } else { + // recently_committed_transactions is ordered on commit_id + // implicitly thus if the current one is bigger than + // lowest_start_time any subsequent ones are also bigger + break; + } + } + if (i > 0) { + // we garbage collected transactions: remove them from the list + recently_committed_transactions.erase(recently_committed_transactions.begin(), + recently_committed_transactions.begin() + i); + } + // check if we can free the memory of any old transactions + i = active_transactions.empty() ? old_transactions.size() : 0; + for (; i < old_transactions.size(); i++) { + D_ASSERT(old_transactions[i]); + D_ASSERT(old_transactions[i]->highest_active_query > 0); + if (old_transactions[i]->highest_active_query >= lowest_active_query) { + // there is still a query running that could be using + // this transactions' data + break; + } + } + if (i > 0) { + // we garbage collected transactions: remove them from the list + old_transactions.erase(old_transactions.begin(), old_transactions.begin() + i); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/meta_transaction.cpp b/src/duckdb/src/transaction/meta_transaction.cpp new file mode 100644 index 00000000..d6dd6ddf --- /dev/null +++ b/src/duckdb/src/transaction/meta_transaction.cpp @@ -0,0 +1,109 @@ +#include "duckdb/transaction/meta_transaction.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/transaction/transaction_manager.hpp" + +namespace duckdb { + +MetaTransaction::MetaTransaction(ClientContext &context_p, timestamp_t start_timestamp_p, idx_t catalog_version_p) + : context(context_p), start_timestamp(start_timestamp_p), catalog_version(catalog_version_p), read_only(true), + active_query(MAXIMUM_QUERY_ID), modified_database(nullptr) { +} + +MetaTransaction &MetaTransaction::Get(ClientContext &context) { + return context.transaction.ActiveTransaction(); +} + +ValidChecker &ValidChecker::Get(MetaTransaction &transaction) { + return transaction.transaction_validity; +} + +Transaction &Transaction::Get(ClientContext &context, AttachedDatabase &db) { + auto &meta_transaction = MetaTransaction::Get(context); + return meta_transaction.GetTransaction(db); +} + +Transaction &MetaTransaction::GetTransaction(AttachedDatabase &db) { + auto entry = transactions.find(&db); + if (entry == transactions.end()) { + auto new_transaction = db.GetTransactionManager().StartTransaction(context); + if (!new_transaction) { + throw InternalException("StartTransaction did not return a valid transaction"); + } + new_transaction->active_query = active_query; + all_transactions.push_back(&db); + transactions[&db] = new_transaction; + return *new_transaction; + } else { + D_ASSERT(entry->second->active_query == active_query); + return *entry->second; + } +} + +Transaction &Transaction::Get(ClientContext &context, Catalog &catalog) { + return Transaction::Get(context, catalog.GetAttached()); +} + +string MetaTransaction::Commit() { + string error; + // commit transactions in reverse order + for (idx_t i = all_transactions.size(); i > 0; i--) { + auto db = all_transactions[i - 1]; + auto entry = transactions.find(db.get()); + if (entry == transactions.end()) { + throw InternalException("Could not find transaction corresponding to database in MetaTransaction"); + } + auto &transaction_manager = db->GetTransactionManager(); + auto transaction = entry->second; + if (error.empty()) { + // commit + error = transaction_manager.CommitTransaction(context, transaction); + } else { + // we have encountered an error previously - roll back subsequent entries + transaction_manager.RollbackTransaction(transaction); + } + } + return error; +} + +void MetaTransaction::Rollback() { + // rollback transactions in reverse order + for (idx_t i = all_transactions.size(); i > 0; i--) { + auto db = all_transactions[i - 1]; + auto &transaction_manager = db->GetTransactionManager(); + auto entry = transactions.find(db.get()); + D_ASSERT(entry != transactions.end()); + auto transaction = entry->second; + transaction_manager.RollbackTransaction(transaction); + } +} + +idx_t MetaTransaction::GetActiveQuery() { + return active_query; +} + +void MetaTransaction::SetActiveQuery(transaction_t query_number) { + active_query = query_number; + for (auto &entry : transactions) { + entry.second->active_query = query_number; + } +} + +void MetaTransaction::ModifyDatabase(AttachedDatabase &db) { + if (db.IsSystem() || db.IsTemporary()) { + // we can always modify the system and temp databases + return; + } + if (!modified_database) { + modified_database = &db; + return; + } + if (&db != modified_database.get()) { + throw TransactionException( + "Attempting to write to database \"%s\" in a transaction that has already modified database \"%s\" - a " + "single transaction can only write to a single attached database.", + db.GetName(), modified_database->GetName()); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/rollback_state.cpp b/src/duckdb/src/transaction/rollback_state.cpp new file mode 100644 index 00000000..b30124c1 --- /dev/null +++ b/src/duckdb/src/transaction/rollback_state.cpp @@ -0,0 +1,48 @@ +#include "duckdb/transaction/rollback_state.hpp" +#include "duckdb/transaction/append_info.hpp" +#include "duckdb/transaction/delete_info.hpp" +#include "duckdb/transaction/update_info.hpp" + +#include "duckdb/storage/table/chunk_info.hpp" + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/table/row_version_manager.hpp" + +namespace duckdb { + +void RollbackState::RollbackEntry(UndoFlags type, data_ptr_t data) { + switch (type) { + case UndoFlags::CATALOG_ENTRY: { + // undo this catalog entry + auto catalog_entry = Load(data); + D_ASSERT(catalog_entry->set); + catalog_entry->set->Undo(*catalog_entry); + break; + } + case UndoFlags::INSERT_TUPLE: { + auto info = reinterpret_cast(data); + // revert the append in the base table + info->table->RevertAppend(info->start_row, info->count); + break; + } + case UndoFlags::DELETE_TUPLE: { + auto info = reinterpret_cast(data); + // reset the deleted flag on rollback + info->version_info->CommitDelete(info->vector_idx, NOT_DELETED_ID, info->rows, info->count); + break; + } + case UndoFlags::UPDATE_TUPLE: { + auto info = reinterpret_cast(data); + info->segment->RollbackUpdate(*info); + break; + } + default: // LCOV_EXCL_START + D_ASSERT(type == UndoFlags::EMPTY_ENTRY); + break; + } // LCOV_EXCL_STOP +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/transaction.cpp b/src/duckdb/src/transaction/transaction.cpp new file mode 100644 index 00000000..5c37879d --- /dev/null +++ b/src/duckdb/src/transaction/transaction.cpp @@ -0,0 +1,24 @@ +#include "duckdb/transaction/transaction.hpp" +#include "duckdb/transaction/meta_transaction.hpp" +#include "duckdb/transaction/transaction_manager.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +Transaction::Transaction(TransactionManager &manager_p, ClientContext &context_p) + : manager(manager_p), context(context_p.shared_from_this()), active_query(MAXIMUM_QUERY_ID) { +} + +Transaction::~Transaction() { +} + +bool Transaction::IsReadOnly() { + auto ctxt = context.lock(); + if (!ctxt) { + throw InternalException("Transaction::IsReadOnly() called after client context has been destroyed"); + } + auto &db = manager.GetDB(); + return MetaTransaction::Get(*ctxt).ModifiedDatabase().get() != &db; +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/transaction_context.cpp b/src/duckdb/src/transaction/transaction_context.cpp new file mode 100644 index 00000000..4b65a721 --- /dev/null +++ b/src/duckdb/src/transaction/transaction_context.cpp @@ -0,0 +1,95 @@ +#include "duckdb/transaction/transaction_context.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/transaction/meta_transaction.hpp" +#include "duckdb/transaction/transaction_manager.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/database_manager.hpp" + +namespace duckdb { + +TransactionContext::TransactionContext(ClientContext &context) + : context(context), auto_commit(true), current_transaction(nullptr) { +} + +TransactionContext::~TransactionContext() { + if (current_transaction) { + try { + Rollback(); + } catch (...) { + } + } +} + +void TransactionContext::BeginTransaction() { + if (current_transaction) { + throw TransactionException("cannot start a transaction within a transaction"); + } + auto start_timestamp = Timestamp::GetCurrentTimestamp(); + auto catalog_version = Catalog::GetSystemCatalog(context).GetCatalogVersion(); + current_transaction = make_uniq(context, start_timestamp, catalog_version); + + auto &config = DBConfig::GetConfig(context); + if (config.options.immediate_transaction_mode) { + // if immediate transaction mode is enabled then start all transactions immediately + auto databases = DatabaseManager::Get(context).GetDatabases(context); + for (auto db : databases) { + current_transaction->GetTransaction(db.get()); + } + } +} + +void TransactionContext::Commit() { + if (!current_transaction) { + throw TransactionException("failed to commit: no transaction active"); + } + auto transaction = std::move(current_transaction); + ClearTransaction(); + string error = transaction->Commit(); + if (!error.empty()) { + throw TransactionException("Failed to commit: %s", error); + } +} + +void TransactionContext::SetAutoCommit(bool value) { + auto_commit = value; + if (!auto_commit && !current_transaction) { + BeginTransaction(); + } +} + +void TransactionContext::Rollback() { + if (!current_transaction) { + throw TransactionException("failed to rollback: no transaction active"); + } + auto transaction = std::move(current_transaction); + ClearTransaction(); + transaction->Rollback(); +} + +void TransactionContext::ClearTransaction() { + SetAutoCommit(true); + current_transaction = nullptr; +} + +idx_t TransactionContext::GetActiveQuery() { + if (!current_transaction) { + throw InternalException("GetActiveQuery called without active transaction"); + } + return current_transaction->GetActiveQuery(); +} + +void TransactionContext::ResetActiveQuery() { + if (current_transaction) { + SetActiveQuery(MAXIMUM_QUERY_ID); + } +} + +void TransactionContext::SetActiveQuery(transaction_t query_number) { + if (!current_transaction) { + throw InternalException("SetActiveQuery called without active transaction"); + } + current_transaction->SetActiveQuery(query_number); +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/transaction_manager.cpp b/src/duckdb/src/transaction/transaction_manager.cpp new file mode 100644 index 00000000..90ec3a42 --- /dev/null +++ b/src/duckdb/src/transaction/transaction_manager.cpp @@ -0,0 +1,11 @@ +#include "duckdb/transaction/transaction_manager.hpp" + +namespace duckdb { + +TransactionManager::TransactionManager(AttachedDatabase &db) : db(db) { +} + +TransactionManager::~TransactionManager() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/transaction/undo_buffer.cpp b/src/duckdb/src/transaction/undo_buffer.cpp new file mode 100644 index 00000000..bf1717b7 --- /dev/null +++ b/src/duckdb/src/transaction/undo_buffer.cpp @@ -0,0 +1,160 @@ +#include "duckdb/transaction/undo_buffer.hpp" + +#include "duckdb/catalog/catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/list.hpp" +#include "duckdb/catalog/catalog_set.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/write_ahead_log.hpp" +#include "duckdb/transaction/cleanup_state.hpp" +#include "duckdb/transaction/commit_state.hpp" +#include "duckdb/transaction/rollback_state.hpp" +#include "duckdb/common/pair.hpp" + +namespace duckdb { +constexpr uint32_t UNDO_ENTRY_HEADER_SIZE = sizeof(UndoFlags) + sizeof(uint32_t); + +UndoBuffer::UndoBuffer(ClientContext &context_p) : allocator(BufferAllocator::Get(context_p)) { +} + +data_ptr_t UndoBuffer::CreateEntry(UndoFlags type, idx_t len) { + D_ASSERT(len <= NumericLimits::Maximum()); + len = AlignValue(len); + idx_t needed_space = len + UNDO_ENTRY_HEADER_SIZE; + auto data = allocator.Allocate(needed_space); + Store(type, data); + data += sizeof(UndoFlags); + Store(len, data); + data += sizeof(uint32_t); + return data; +} + +template +void UndoBuffer::IterateEntries(UndoBuffer::IteratorState &state, T &&callback) { + // iterate in insertion order: start with the tail + state.current = allocator.GetTail(); + while (state.current) { + state.start = state.current->data.get(); + state.end = state.start + state.current->current_position; + while (state.start < state.end) { + UndoFlags type = Load(state.start); + state.start += sizeof(UndoFlags); + + uint32_t len = Load(state.start); + state.start += sizeof(uint32_t); + callback(type, state.start); + state.start += len; + } + state.current = state.current->prev; + } +} + +template +void UndoBuffer::IterateEntries(UndoBuffer::IteratorState &state, UndoBuffer::IteratorState &end_state, T &&callback) { + // iterate in insertion order: start with the tail + state.current = allocator.GetTail(); + while (state.current) { + state.start = state.current->data.get(); + state.end = + state.current == end_state.current ? end_state.start : state.start + state.current->current_position; + while (state.start < state.end) { + auto type = Load(state.start); + state.start += sizeof(UndoFlags); + auto len = Load(state.start); + state.start += sizeof(uint32_t); + callback(type, state.start); + state.start += len; + } + if (state.current == end_state.current) { + // finished executing until the current end state + return; + } + state.current = state.current->prev; + } +} + +template +void UndoBuffer::ReverseIterateEntries(T &&callback) { + // iterate in reverse insertion order: start with the head + auto current = allocator.GetHead(); + while (current) { + data_ptr_t start = current->data.get(); + data_ptr_t end = start + current->current_position; + // create a vector with all nodes in this chunk + vector> nodes; + while (start < end) { + auto type = Load(start); + start += sizeof(UndoFlags); + auto len = Load(start); + start += sizeof(uint32_t); + nodes.emplace_back(type, start); + start += len; + } + // iterate over it in reverse order + for (idx_t i = nodes.size(); i > 0; i--) { + callback(nodes[i - 1].first, nodes[i - 1].second); + } + current = current->next.get(); + } +} + +bool UndoBuffer::ChangesMade() { + return !allocator.IsEmpty(); +} + +idx_t UndoBuffer::EstimatedSize() { + idx_t estimated_size = 0; + auto node = allocator.GetHead(); + while (node) { + estimated_size += node->current_position; + node = node->next.get(); + } + return estimated_size; +} + +void UndoBuffer::Cleanup() { + // garbage collect everything in the Undo Chunk + // this should only happen if + // (1) the transaction this UndoBuffer belongs to has successfully + // committed + // (on Rollback the Rollback() function should be called, that clears + // the chunks) + // (2) there is no active transaction with start_id < commit_id of this + // transaction + CleanupState state; + UndoBuffer::IteratorState iterator_state; + IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CleanupEntry(type, data); }); + + // possibly vacuum indexes + for (const auto &table : state.indexed_tables) { + table.second->info->indexes.Scan([&](Index &index) { + index.Vacuum(); + return false; + }); + } +} + +void UndoBuffer::Commit(UndoBuffer::IteratorState &iterator_state, optional_ptr log, + transaction_t commit_id) { + CommitState state(commit_id, log); + if (log) { + // commit WITH write ahead log + IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CommitEntry(type, data); }); + } else { + // commit WITHOUT write ahead log + IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CommitEntry(type, data); }); + } +} + +void UndoBuffer::RevertCommit(UndoBuffer::IteratorState &end_state, transaction_t transaction_id) { + CommitState state(transaction_id, nullptr); + UndoBuffer::IteratorState start_state; + IterateEntries(start_state, end_state, [&](UndoFlags type, data_ptr_t data) { state.RevertCommit(type, data); }); +} + +void UndoBuffer::Rollback() noexcept { + // rollback needs to be performed in reverse + RollbackState state; + ReverseIterateEntries([&](UndoFlags type, data_ptr_t data) { state.RollbackEntry(type, data); }); +} +} // namespace duckdb diff --git a/src/duckdb/src/verification/copied_statement_verifier.cpp b/src/duckdb/src/verification/copied_statement_verifier.cpp new file mode 100644 index 00000000..6b603f3b --- /dev/null +++ b/src/duckdb/src/verification/copied_statement_verifier.cpp @@ -0,0 +1,13 @@ +#include "duckdb/verification/copied_statement_verifier.hpp" + +namespace duckdb { + +CopiedStatementVerifier::CopiedStatementVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::COPIED, "Copied", std::move(statement_p)) { +} + +unique_ptr CopiedStatementVerifier::Create(const SQLStatement &statement) { + return make_uniq(statement.Copy()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/verification/deserialized_statement_verifier.cpp b/src/duckdb/src/verification/deserialized_statement_verifier.cpp new file mode 100644 index 00000000..dcbc1780 --- /dev/null +++ b/src/duckdb/src/verification/deserialized_statement_verifier.cpp @@ -0,0 +1,24 @@ +#include "duckdb/verification/deserialized_statement_verifier.hpp" + +#include "duckdb/common/serializer/binary_deserializer.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +namespace duckdb { + +DeserializedStatementVerifier::DeserializedStatementVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::DESERIALIZED, "Deserialized", std::move(statement_p)) { +} + +unique_ptr DeserializedStatementVerifier::Create(const SQLStatement &statement) { + + auto &select_stmt = statement.Cast(); + + MemoryStream stream; + BinarySerializer::Serialize(select_stmt, stream); + stream.Rewind(); + auto result = BinaryDeserializer::Deserialize(stream); + + return make_uniq(std::move(result)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/verification/external_statement_verifier.cpp b/src/duckdb/src/verification/external_statement_verifier.cpp new file mode 100644 index 00000000..5e9655d5 --- /dev/null +++ b/src/duckdb/src/verification/external_statement_verifier.cpp @@ -0,0 +1,13 @@ +#include "duckdb/verification/external_statement_verifier.hpp" + +namespace duckdb { + +ExternalStatementVerifier::ExternalStatementVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::EXTERNAL, "External", std::move(statement_p)) { +} + +unique_ptr ExternalStatementVerifier::Create(const SQLStatement &statement) { + return make_uniq(statement.Copy()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/verification/no_operator_caching_verifier.cpp b/src/duckdb/src/verification/no_operator_caching_verifier.cpp new file mode 100644 index 00000000..10540931 --- /dev/null +++ b/src/duckdb/src/verification/no_operator_caching_verifier.cpp @@ -0,0 +1,13 @@ +#include "duckdb/verification/no_operator_caching_verifier.hpp" + +namespace duckdb { + +NoOperatorCachingVerifier::NoOperatorCachingVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::NO_OPERATOR_CACHING, "No operator caching", std::move(statement_p)) { +} + +unique_ptr NoOperatorCachingVerifier::Create(const SQLStatement &statement_p) { + return make_uniq(statement_p.Copy()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/verification/parsed_statement_verifier.cpp b/src/duckdb/src/verification/parsed_statement_verifier.cpp new file mode 100644 index 00000000..a47141f9 --- /dev/null +++ b/src/duckdb/src/verification/parsed_statement_verifier.cpp @@ -0,0 +1,24 @@ +#include "duckdb/verification/parsed_statement_verifier.hpp" + +#include "duckdb/parser/parser.hpp" + +namespace duckdb { + +ParsedStatementVerifier::ParsedStatementVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::PARSED, "Parsed", std::move(statement_p)) { +} + +unique_ptr ParsedStatementVerifier::Create(const SQLStatement &statement) { + auto query_str = statement.ToString(); + Parser parser; + try { + parser.ParseQuery(query_str); + } catch (std::exception &ex) { + throw InternalException("Parsed statement verification failed. Query:\n%s\n\nError: %s", query_str, ex.what()); + } + D_ASSERT(parser.statements.size() == 1); + D_ASSERT(parser.statements[0]->type == StatementType::SELECT_STATEMENT); + return make_uniq(std::move(parser.statements[0])); +} + +} // namespace duckdb diff --git a/src/duckdb/src/verification/prepared_statement_verifier.cpp b/src/duckdb/src/verification/prepared_statement_verifier.cpp new file mode 100644 index 00000000..31a1d9ff --- /dev/null +++ b/src/duckdb/src/verification/prepared_statement_verifier.cpp @@ -0,0 +1,111 @@ +#include "duckdb/verification/prepared_statement_verifier.hpp" + +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/parser/statement/execute_statement.hpp" +#include "duckdb/parser/statement/prepare_statement.hpp" + +namespace duckdb { + +PreparedStatementVerifier::PreparedStatementVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::PREPARED, "Prepared", std::move(statement_p)) { +} + +unique_ptr PreparedStatementVerifier::Create(const SQLStatement &statement) { + return make_uniq(statement.Copy()); +} + +void PreparedStatementVerifier::Extract() { + auto &select = *statement; + // replace all the constants from the select statement and replace them with parameter expressions + ParsedExpressionIterator::EnumerateQueryNodeChildren( + *select.node, [&](unique_ptr &child) { ConvertConstants(child); }); + statement->n_param = values.size(); + for (auto &kv : values) { + statement->named_param_map[kv.first] = 0; + } + // create the PREPARE and EXECUTE statements + string name = "__duckdb_verification_prepared_statement"; + auto prepare = make_uniq(); + prepare->name = name; + prepare->statement = std::move(statement); + + auto execute = make_uniq(); + execute->name = name; + execute->named_values = std::move(values); + + auto dealloc = make_uniq(); + dealloc->info->type = CatalogType::PREPARED_STATEMENT; + dealloc->info->name = string(name); + + prepare_statement = std::move(prepare); + execute_statement = std::move(execute); + dealloc_statement = std::move(dealloc); +} + +void PreparedStatementVerifier::ConvertConstants(unique_ptr &child) { + if (child->type == ExpressionType::VALUE_CONSTANT) { + // constant: extract the constant value + auto alias = child->alias; + child->alias = string(); + // check if the value already exists + idx_t index = values.size(); + auto identifier = std::to_string(index + 1); + const auto predicate = [&](const std::pair> &pair) { + return pair.second->Equals(*child.get()); + }; + auto result = std::find_if(values.begin(), values.end(), predicate); + if (result == values.end()) { + // If it doesn't exist yet, add it + values[identifier] = std::move(child); + } else { + identifier = result->first; + } + + // replace it with an expression + auto parameter = make_uniq(); + parameter->identifier = identifier; + parameter->alias = alias; + child = std::move(parameter); + return; + } + ParsedExpressionIterator::EnumerateChildren(*child, + [&](unique_ptr &child) { ConvertConstants(child); }); +} + +bool PreparedStatementVerifier::Run( + ClientContext &context, const string &query, + const std::function(const string &, unique_ptr)> &run) { + bool failed = false; + // verify that we can extract all constants from the query and run the query as a prepared statement + // create the PREPARE and EXECUTE statements + Extract(); + // execute the prepared statements + try { + auto prepare_result = run(string(), std::move(prepare_statement)); + if (prepare_result->HasError()) { + prepare_result->ThrowError("Failed prepare during verify: "); + } + auto execute_result = run(string(), std::move(execute_statement)); + if (execute_result->HasError()) { + execute_result->ThrowError("Failed execute during verify: "); + } + materialized_result = unique_ptr_cast(std::move(execute_result)); + } catch (const Exception &ex) { + if (ex.type != ExceptionType::PARAMETER_NOT_ALLOWED) { + materialized_result = make_uniq(PreservedError(ex)); + } + failed = true; + } catch (std::exception &ex) { + materialized_result = make_uniq(PreservedError(ex)); + failed = true; + } + run(string(), std::move(dealloc_statement)); + context.interrupted = false; + + return failed; +} + +} // namespace duckdb diff --git a/src/duckdb/src/verification/statement_verifier.cpp b/src/duckdb/src/verification/statement_verifier.cpp new file mode 100644 index 00000000..3d473302 --- /dev/null +++ b/src/duckdb/src/verification/statement_verifier.cpp @@ -0,0 +1,153 @@ +#include "duckdb/verification/statement_verifier.hpp" + +#include "duckdb/common/preserved_error.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/verification/copied_statement_verifier.hpp" +#include "duckdb/verification/deserialized_statement_verifier.hpp" +#include "duckdb/verification/external_statement_verifier.hpp" +#include "duckdb/verification/parsed_statement_verifier.hpp" +#include "duckdb/verification/prepared_statement_verifier.hpp" +#include "duckdb/verification/unoptimized_statement_verifier.hpp" +#include "duckdb/verification/no_operator_caching_verifier.hpp" + +namespace duckdb { + +StatementVerifier::StatementVerifier(VerificationType type, string name, unique_ptr statement_p) + : type(type), name(std::move(name)), + statement(unique_ptr_cast(std::move(statement_p))), + select_list(statement->node->GetSelectList()) { +} + +StatementVerifier::StatementVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::ORIGINAL, "Original", std::move(statement_p)) { +} + +StatementVerifier::~StatementVerifier() noexcept { +} + +unique_ptr StatementVerifier::Create(VerificationType type, const SQLStatement &statement_p) { + switch (type) { + case VerificationType::COPIED: + return CopiedStatementVerifier::Create(statement_p); + case VerificationType::DESERIALIZED: + return DeserializedStatementVerifier::Create(statement_p); + case VerificationType::PARSED: + return ParsedStatementVerifier::Create(statement_p); + case VerificationType::UNOPTIMIZED: + return UnoptimizedStatementVerifier::Create(statement_p); + case VerificationType::NO_OPERATOR_CACHING: + return NoOperatorCachingVerifier::Create(statement_p); + case VerificationType::PREPARED: + return PreparedStatementVerifier::Create(statement_p); + case VerificationType::EXTERNAL: + return ExternalStatementVerifier::Create(statement_p); + case VerificationType::INVALID: + default: + throw InternalException("Invalid statement verification type!"); + } +} + +void StatementVerifier::CheckExpressions(const StatementVerifier &other) const { + // Only the original statement should check other statements + D_ASSERT(type == VerificationType::ORIGINAL); + + // Check equality + if (other.RequireEquality()) { + D_ASSERT(statement->Equals(*other.statement)); + } + +#ifdef DEBUG + // Now perform checking on the expressions + D_ASSERT(select_list.size() == other.select_list.size()); + const auto expr_count = select_list.size(); + if (other.RequireEquality()) { + for (idx_t i = 0; i < expr_count; i++) { + // Run the ToString, to verify that it doesn't crash + select_list[i]->ToString(); + + if (select_list[i]->HasSubquery()) { + continue; + } + + // Check that the expressions are equivalent + D_ASSERT(select_list[i]->Equals(*other.select_list[i])); + // Check that the hashes are equivalent too + D_ASSERT(select_list[i]->Hash() == other.select_list[i]->Hash()); + + other.select_list[i]->Verify(); + } + } +#endif +} + +void StatementVerifier::CheckExpressions() const { +#ifdef DEBUG + D_ASSERT(type == VerificationType::ORIGINAL); + // Perform additional checking within the expressions + const auto expr_count = select_list.size(); + for (idx_t outer_idx = 0; outer_idx < expr_count; outer_idx++) { + auto hash = select_list[outer_idx]->Hash(); + for (idx_t inner_idx = 0; inner_idx < expr_count; inner_idx++) { + auto hash2 = select_list[inner_idx]->Hash(); + if (hash != hash2) { + // if the hashes are not equivalent, the expressions should not be equivalent + D_ASSERT(!select_list[outer_idx]->Equals(*select_list[inner_idx])); + } + } + } +#endif +} + +bool StatementVerifier::Run( + ClientContext &context, const string &query, + const std::function(const string &, unique_ptr)> &run) { + bool failed = false; + + context.interrupted = false; + context.config.enable_optimizer = !DisableOptimizer(); + context.config.enable_caching_operators = !DisableOperatorCaching(); + context.config.force_external = ForceExternal(); + try { + auto result = run(query, std::move(statement)); + if (result->HasError()) { + failed = true; + } + materialized_result = unique_ptr_cast(std::move(result)); + } catch (const Exception &ex) { + failed = true; + materialized_result = make_uniq(PreservedError(ex)); + } catch (std::exception &ex) { + failed = true; + materialized_result = make_uniq(PreservedError(ex)); + } + context.interrupted = false; + + return failed; +} + +string StatementVerifier::CompareResults(const StatementVerifier &other) { + D_ASSERT(type == VerificationType::ORIGINAL); + string error; + if (materialized_result->HasError() != other.materialized_result->HasError()) { // LCOV_EXCL_START + string result = other.name + " statement differs from original result!\n"; + result += "Original Result:\n" + materialized_result->ToString(); + result += other.name + ":\n" + other.materialized_result->ToString(); + return result; + } // LCOV_EXCL_STOP + if (materialized_result->HasError()) { + return ""; + } + if (!ColumnDataCollection::ResultEquals(materialized_result->Collection(), other.materialized_result->Collection(), + error)) { // LCOV_EXCL_START + string result = other.name + " statement differs from original result!\n"; + result += "Original Result:\n" + materialized_result->ToString(); + result += other.name + ":\n" + other.materialized_result->ToString(); + result += "\n\n---------------------------------\n" + error; + return result; + } // LCOV_EXCL_STOP + + return ""; +} + +} // namespace duckdb diff --git a/src/duckdb/src/verification/unoptimized_statement_verifier.cpp b/src/duckdb/src/verification/unoptimized_statement_verifier.cpp new file mode 100644 index 00000000..b0f27402 --- /dev/null +++ b/src/duckdb/src/verification/unoptimized_statement_verifier.cpp @@ -0,0 +1,13 @@ +#include "duckdb/verification/unoptimized_statement_verifier.hpp" + +namespace duckdb { + +UnoptimizedStatementVerifier::UnoptimizedStatementVerifier(unique_ptr statement_p) + : StatementVerifier(VerificationType::UNOPTIMIZED, "Unoptimized", std::move(statement_p)) { +} + +unique_ptr UnoptimizedStatementVerifier::Create(const SQLStatement &statement_p) { + return make_uniq(statement_p.Copy()); +} + +} // namespace duckdb diff --git a/src/duckdb/third_party/concurrentqueue/blockingconcurrentqueue.h b/src/duckdb/third_party/concurrentqueue/blockingconcurrentqueue.h new file mode 100644 index 00000000..916a5a27 --- /dev/null +++ b/src/duckdb/third_party/concurrentqueue/blockingconcurrentqueue.h @@ -0,0 +1,588 @@ +// Provides an efficient blocking version of moodycamel::ConcurrentQueue. +// ©2015-2016 Cameron Desrochers. Distributed under the terms of the simplified +// BSD license, available at the top of concurrentqueue.h. +// Uses Jeff Preshing's semaphore implementation (under the terms of its +// separate zlib license, embedded below). + +#pragma once + +#include "concurrentqueue.h" +#include "lightweightsemaphore.h" + +#include +#include +#include +#include +#include + +namespace duckdb_moodycamel +{ +// This is a blocking version of the queue. It has an almost identical interface to +// the normal non-blocking version, with the addition of various wait_dequeue() methods +// and the removal of producer-specific dequeue methods. +template +class BlockingConcurrentQueue +{ +private: + typedef ::duckdb_moodycamel::ConcurrentQueue ConcurrentQueue; + typedef ::duckdb_moodycamel::LightweightSemaphore LightweightSemaphore; + +public: + typedef typename ConcurrentQueue::producer_token_t producer_token_t; + typedef typename ConcurrentQueue::consumer_token_t consumer_token_t; + + typedef typename ConcurrentQueue::index_t index_t; + typedef typename ConcurrentQueue::size_t size_t; + typedef typename std::make_signed::type ssize_t; + + static const size_t BLOCK_SIZE = ConcurrentQueue::BLOCK_SIZE; + static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = ConcurrentQueue::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD; + static const size_t EXPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::EXPLICIT_INITIAL_INDEX_SIZE; + static const size_t IMPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::IMPLICIT_INITIAL_INDEX_SIZE; + static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = ConcurrentQueue::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; + static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = ConcurrentQueue::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE; + static const size_t MAX_SUBQUEUE_SIZE = ConcurrentQueue::MAX_SUBQUEUE_SIZE; + +public: + // Creates a queue with at least `capacity` element slots; note that the + // actual number of elements that can be inserted without additional memory + // allocation depends on the number of producers and the block size (e.g. if + // the block size is equal to `capacity`, only a single block will be allocated + // up-front, which means only a single producer will be able to enqueue elements + // without an extra allocation -- blocks aren't shared between producers). + // This method is not thread safe -- it is up to the user to ensure that the + // queue is fully constructed before it starts being used by other threads (this + // includes making the memory effects of construction visible, possibly with a + // memory barrier). + explicit BlockingConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) + : inner(capacity), sema(create(), &BlockingConcurrentQueue::template destroy) + { + assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); + if (!sema) { + MOODYCAMEL_THROW(std::bad_alloc()); + } + } + + BlockingConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, size_t maxImplicitProducers) + : inner(minCapacity, maxExplicitProducers, maxImplicitProducers), sema(create(), &BlockingConcurrentQueue::template destroy) + { + assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); + if (!sema) { + MOODYCAMEL_THROW(std::bad_alloc()); + } + } + + // Disable copying and copy assignment + BlockingConcurrentQueue(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; + BlockingConcurrentQueue& operator=(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; + + // Moving is supported, but note that it is *not* a thread-safe operation. + // Nobody can use the queue while it's being moved, and the memory effects + // of that move must be propagated to other threads before they can use it. + // Note: When a queue is moved, its tokens are still valid but can only be + // used with the destination queue (i.e. semantically they are moved along + // with the queue itself). + BlockingConcurrentQueue(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT + : inner(std::move(other.inner)), sema(std::move(other.sema)) + { } + + inline BlockingConcurrentQueue& operator=(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT + { + return swap_internal(other); + } + + // Swaps this queue's state with the other's. Not thread-safe. + // Swapping two queues does not invalidate their tokens, however + // the tokens that were created for one queue must be used with + // only the swapped queue (i.e. the tokens are tied to the + // queue's movable state, not the object itself). + inline void swap(BlockingConcurrentQueue& other) MOODYCAMEL_NOEXCEPT + { + swap_internal(other); + } + +private: + BlockingConcurrentQueue& swap_internal(BlockingConcurrentQueue& other) + { + if (this == &other) { + return *this; + } + + inner.swap(other.inner); + sema.swap(other.sema); + return *this; + } + +public: + // Enqueues a single item (by copying it). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T const& item) + { + if ((details::likely)(inner.enqueue(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T&& item) + { + if ((details::likely)(inner.enqueue(std::move(item)))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const& token, T const& item) + { + if ((details::likely)(inner.enqueue(token, item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const& token, T&& item) + { + if ((details::likely)(inner.enqueue(token, std::move(item)))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues several items. + // Allocates memory if required. Only fails if memory allocation fails (or + // implicit production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0, or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved instead of copied. + // Thread-safe. + template + inline bool enqueue_bulk(It itemFirst, size_t count) + { + if ((details::likely)(inner.enqueue_bulk(std::forward(itemFirst), count))) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues several items using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails + // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) + { + if ((details::likely)(inner.enqueue_bulk(token, std::forward(itemFirst), count))) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues a single item (by copying it). + // Does not allocate memory. Fails if not enough room to enqueue (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0). + // Thread-safe. + inline bool try_enqueue(T const& item) + { + if (inner.try_enqueue(item)) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible). + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Thread-safe. + inline bool try_enqueue(T&& item) + { + if (inner.try_enqueue(std::move(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const& token, T const& item) + { + if (inner.try_enqueue(token, item)) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const& token, T&& item) + { + if (inner.try_enqueue(token, std::move(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues several items. + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool try_enqueue_bulk(It itemFirst, size_t count) + { + if (inner.try_enqueue_bulk(std::forward(itemFirst), count)) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues several items using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool try_enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) + { + if (inner.try_enqueue_bulk(token, std::forward(itemFirst), count)) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + + // Attempts to dequeue from the queue. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue(U& item) + { + if (sema->tryWait()) { + while (!inner.try_dequeue(item)) { + continue; + } + return true; + } + return false; + } + + // Attempts to dequeue from the queue using an explicit consumer token. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue(consumer_token_t& token, U& item) + { + if (sema->tryWait()) { + while (!inner.try_dequeue(token, item)) { + continue; + } + return true; + } + return false; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline size_t try_dequeue_bulk(It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline size_t try_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + + + // Blocks the current thread until there's something to dequeue, then + // dequeues it. + // Never allocates. Thread-safe. + template + inline void wait_dequeue(U& item) + { + while (!sema->wait()) { + continue; + } + while (!inner.try_dequeue(item)) { + continue; + } + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout (specified in microseconds) expires. Returns false + // without setting `item` if the timeout expires, otherwise assigns + // to `item` and returns true. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(U& item, std::int64_t timeout_usecs) + { + if (!sema->wait(timeout_usecs)) { + return false; + } + while (!inner.try_dequeue(item)) { + continue; + } + return true; + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout expires. Returns false without setting `item` if the + // timeout expires, otherwise assigns to `item` and returns true. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(U& item, std::chrono::duration const& timeout) + { + return wait_dequeue_timed(item, std::chrono::duration_cast(timeout).count()); + } + + // Blocks the current thread until there's something to dequeue, then + // dequeues it using an explicit consumer token. + // Never allocates. Thread-safe. + template + inline void wait_dequeue(consumer_token_t& token, U& item) + { + while (!sema->wait()) { + continue; + } + while (!inner.try_dequeue(token, item)) { + continue; + } + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout (specified in microseconds) expires. Returns false + // without setting `item` if the timeout expires, otherwise assigns + // to `item` and returns true. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::int64_t timeout_usecs) + { + if (!sema->wait(timeout_usecs)) { + return false; + } + while (!inner.try_dequeue(token, item)) { + continue; + } + return true; + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout expires. Returns false without setting `item` if the + // timeout expires, otherwise assigns to `item` and returns true. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::chrono::duration const& timeout) + { + return wait_dequeue_timed(token, item, std::chrono::duration_cast(timeout).count()); + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which will + // always be at least one (this method blocks until the queue + // is non-empty) and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk(It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue_bulk. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::int64_t timeout_usecs) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::chrono::duration const& timeout) + { + return wait_dequeue_bulk_timed(itemFirst, max, std::chrono::duration_cast(timeout).count()); + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued, which will + // always be at least one (this method blocks until the queue + // is non-empty) and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue_bulk. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::int64_t timeout_usecs) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); + while (count != max) { + count += inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::chrono::duration const& timeout) + { + return wait_dequeue_bulk_timed(token, itemFirst, max, std::chrono::duration_cast(timeout).count()); + } + + + // Returns an estimate of the total number of elements currently in the queue. This + // estimate is only accurate if the queue has completely stabilized before it is called + // (i.e. all enqueue and dequeue operations have completed and their memory effects are + // visible on the calling thread, and no further operations start while this method is + // being called). + // Thread-safe. + inline size_t size_approx() const + { + return (size_t)sema->availableApprox(); + } + + + // Returns true if the underlying atomic variables used by + // the queue are lock-free (they should be on most platforms). + // Thread-safe. + static bool is_lock_free() + { + return ConcurrentQueue::is_lock_free(); + } + + +private: + template + static inline U* create() + { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new (p) U : nullptr; + } + + template + static inline U* create(A1&& a1) + { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new (p) U(std::forward(a1)) : nullptr; + } + + template + static inline void destroy(U* p) + { + if (p != nullptr) { + p->~U(); + } + (Traits::free)(p); + } + +private: + ConcurrentQueue inner; + std::unique_ptr sema; +}; + + +template +inline void swap(BlockingConcurrentQueue& a, BlockingConcurrentQueue& b) MOODYCAMEL_NOEXCEPT +{ + a.swap(b); +} + +} // end namespace moodycamel diff --git a/src/duckdb/third_party/concurrentqueue/concurrentqueue.h b/src/duckdb/third_party/concurrentqueue/concurrentqueue.h new file mode 100644 index 00000000..f3e2b100 --- /dev/null +++ b/src/duckdb/third_party/concurrentqueue/concurrentqueue.h @@ -0,0 +1,3667 @@ +// Provides a C++11 implementation of a multi-producer, multi-consumer lock-free queue. +// An overview, including benchmark results, is provided here: +// http://moodycamel.com/blog/2014/a-fast-general-purpose-lock-free-queue-for-c++ +// The full design is also described in excruciating detail at: +// http://moodycamel.com/blog/2014/detailed-design-of-a-lock-free-queue + +// Simplified BSD license: +// Copyright (c) 2013-2016, Cameron Desrochers. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// - Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, this list of +// conditions and the following disclaimer in the documentation and/or other materials +// provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +// OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +// TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#pragma once + +#if defined(__GNUC__) +// Disable -Wconversion warnings (spuriously triggered when Traits::size_t and +// Traits::index_t are set to < 32 bits, causing integer promotion, causing warnings +// upon assigning any computed values) + +#endif + +#if defined(__APPLE__) +#include +#endif + +#include // Requires C++11. Sorry VS2010. +#include +#include // for max_align_t +#include +#include +#include +#include +#include +#include +#include // for CHAR_BIT +#include +#include // partly for __WINPTHREADS_VERSION if on MinGW-w64 w/ POSIX threading + +// Platform-specific definitions of a numeric thread ID type and an invalid value +namespace duckdb_moodycamel { namespace details { + template struct thread_id_converter { + typedef thread_id_t thread_id_numeric_size_t; + typedef thread_id_t thread_id_hash_t; + static thread_id_hash_t prehash(thread_id_t const& x) { return x; } + }; +} } +#if defined(MCDBGQ_USE_RELACY) +namespace duckdb_moodycamel { namespace details { + typedef std::uint32_t thread_id_t; + static const thread_id_t invalid_thread_id = 0xFFFFFFFFU; + static const thread_id_t invalid_thread_id2 = 0xFFFFFFFEU; + static inline thread_id_t thread_id() { return rl::thread_index(); } +} } +#elif defined(_WIN32) || defined(__WINDOWS__) || defined(__WIN32__) +// No sense pulling in windows.h in a header, we'll manually declare the function +// we use and rely on backwards-compatibility for this not to break +extern "C" __declspec(dllimport) unsigned long __stdcall GetCurrentThreadId(void); +namespace duckdb_moodycamel { namespace details { + static_assert(sizeof(unsigned long) == sizeof(std::uint32_t), "Expected size of unsigned long to be 32 bits on Windows"); + typedef std::uint32_t thread_id_t; + static const thread_id_t invalid_thread_id = 0; // See http://blogs.msdn.com/b/oldnewthing/archive/2004/02/23/78395.aspx + static const thread_id_t invalid_thread_id2 = 0xFFFFFFFFU; // Not technically guaranteed to be invalid, but is never used in practice. Note that all Win32 thread IDs are presently multiples of 4. + static inline thread_id_t thread_id() { return static_cast(::GetCurrentThreadId()); } +} } +#elif defined(__arm__) || defined(_M_ARM) || defined(__aarch64__) || (defined(__APPLE__) && TARGET_OS_IPHONE) || defined(__MVS__) +namespace duckdb_moodycamel { namespace details { + static_assert(sizeof(std::thread::id) == 4 || sizeof(std::thread::id) == 8, "std::thread::id is expected to be either 4 or 8 bytes"); + + typedef std::thread::id thread_id_t; + static const thread_id_t invalid_thread_id; // Default ctor creates invalid ID + + // Note we don't define a invalid_thread_id2 since std::thread::id doesn't have one; it's + // only used if MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED is defined anyway, which it won't + // be. + static inline thread_id_t thread_id() { return std::this_thread::get_id(); } + + template struct thread_id_size { }; + template<> struct thread_id_size<4> { typedef std::uint32_t numeric_t; }; + template<> struct thread_id_size<8> { typedef std::uint64_t numeric_t; }; + + template<> struct thread_id_converter { + typedef thread_id_size::numeric_t thread_id_numeric_size_t; +#ifndef __APPLE__ + typedef std::size_t thread_id_hash_t; +#else + typedef thread_id_numeric_size_t thread_id_hash_t; +#endif + + static thread_id_hash_t prehash(thread_id_t const& x) + { +#ifndef __APPLE__ + return std::hash()(x); +#else + return *reinterpret_cast(&x); +#endif + } + }; +} } +#else +// Use a nice trick from this answer: http://stackoverflow.com/a/8438730/21475 +// In order to get a numeric thread ID in a platform-independent way, we use a thread-local +// static variable's address as a thread identifier :-) +#if defined(__GNUC__) || defined(__INTEL_COMPILER) +#define MOODYCAMEL_THREADLOCAL __thread +#elif defined(_MSC_VER) +#define MOODYCAMEL_THREADLOCAL __declspec(thread) +#else +// Assume C++11 compliant compiler +#define MOODYCAMEL_THREADLOCAL thread_local +#endif +namespace duckdb_moodycamel { namespace details { + typedef std::uintptr_t thread_id_t; + static const thread_id_t invalid_thread_id = 0; // Address can't be nullptr +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + static const thread_id_t invalid_thread_id2 = 1; // Member accesses off a null pointer are also generally invalid. Plus it's not aligned. +#endif + inline thread_id_t thread_id() { static MOODYCAMEL_THREADLOCAL int x; return reinterpret_cast(&x); } +} } +#endif + +// Constexpr if +#ifndef MOODYCAMEL_CONSTEXPR_IF +#if (defined(_MSC_VER) && defined(_HAS_CXX17) && _HAS_CXX17) || __cplusplus > 201402L +#define MOODYCAMEL_CONSTEXPR_IF if constexpr +#define MOODYCAMEL_MAYBE_UNUSED [[maybe_unused]] +#else +#define MOODYCAMEL_CONSTEXPR_IF if +#define MOODYCAMEL_MAYBE_UNUSED +#endif +#endif + +// Exceptions +#ifndef MOODYCAMEL_EXCEPTIONS_ENABLED +#if (defined(_MSC_VER) && defined(_CPPUNWIND)) || (defined(__GNUC__) && defined(__EXCEPTIONS)) || (!defined(_MSC_VER) && !defined(__GNUC__)) +#define MOODYCAMEL_EXCEPTIONS_ENABLED +#endif +#endif +#ifdef MOODYCAMEL_EXCEPTIONS_ENABLED +#define MOODYCAMEL_TRY try +#define MOODYCAMEL_CATCH(...) catch(__VA_ARGS__) +#define MOODYCAMEL_RETHROW throw +#define MOODYCAMEL_THROW(expr) throw (expr) +#else +#define MOODYCAMEL_TRY MOODYCAMEL_CONSTEXPR_IF (true) +#define MOODYCAMEL_CATCH(...) else MOODYCAMEL_CONSTEXPR_IF (false) +#define MOODYCAMEL_RETHROW +#define MOODYCAMEL_THROW(expr) +#endif + +#ifndef MOODYCAMEL_NOEXCEPT +#if !defined(MOODYCAMEL_EXCEPTIONS_ENABLED) +#define MOODYCAMEL_NOEXCEPT +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) true +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) true +#elif defined(_MSC_VER) && defined(_NOEXCEPT) && _MSC_VER < 1800 +// VS2012's std::is_nothrow_[move_]constructible is broken and returns true when it shouldn't :-( +// We have to assume *all* non-trivial constructors may throw on VS2012! +#define MOODYCAMEL_NOEXCEPT _NOEXCEPT +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) (std::is_rvalue_reference::value && std::is_move_constructible::value ? std::is_trivially_move_constructible::value : std::is_trivially_copy_constructible::value) +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) ((std::is_rvalue_reference::value && std::is_move_assignable::value ? std::is_trivially_move_assignable::value || std::is_nothrow_move_assignable::value : std::is_trivially_copy_assignable::value || std::is_nothrow_copy_assignable::value) && MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr)) +#elif defined(_MSC_VER) && defined(_NOEXCEPT) && _MSC_VER < 1900 +#define MOODYCAMEL_NOEXCEPT _NOEXCEPT +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) (std::is_rvalue_reference::value && std::is_move_constructible::value ? std::is_trivially_move_constructible::value || std::is_nothrow_move_constructible::value : std::is_trivially_copy_constructible::value || std::is_nothrow_copy_constructible::value) +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) ((std::is_rvalue_reference::value && std::is_move_assignable::value ? std::is_trivially_move_assignable::value || std::is_nothrow_move_assignable::value : std::is_trivially_copy_assignable::value || std::is_nothrow_copy_assignable::value) && MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr)) +#else +#define MOODYCAMEL_NOEXCEPT noexcept +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) noexcept(expr) +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) noexcept(expr) +#endif +#endif + +#ifndef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED +#ifdef MCDBGQ_USE_RELACY +#define MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED +#else +// VS2013 doesn't support `thread_local`, and MinGW-w64 w/ POSIX threading has a crippling bug: http://sourceforge.net/p/mingw-w64/bugs/445 +// g++ <=4.7 doesn't support thread_local either. +// Finally, iOS/ARM doesn't have support for it either, and g++/ARM allows it to compile but it's unconfirmed to actually work +#if (!defined(_MSC_VER) || _MSC_VER >= 1900) && (!defined(__MINGW32__) && !defined(__MINGW64__) || !defined(__WINPTHREADS_VERSION)) && (!defined(__GNUC__) || __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) && (!defined(__APPLE__) || !TARGET_OS_IPHONE) && !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(__MVS__) +// Assume `thread_local` is fully supported in all other C++11 compilers/platforms +//#define MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED // always disabled for now since several users report having problems with it on +#endif +#endif +#endif + +// VS2012 doesn't support deleted functions. +// In this case, we declare the function normally but don't define it. A link error will be generated if the function is called. +#ifndef MOODYCAMEL_DELETE_FUNCTION +#if defined(_MSC_VER) && _MSC_VER < 1800 +#define MOODYCAMEL_DELETE_FUNCTION +#else +#define MOODYCAMEL_DELETE_FUNCTION = delete +#endif +#endif + +#ifndef MOODYCAMEL_ALIGNAS +// VS2013 doesn't support alignas or alignof +#if defined(_MSC_VER) && _MSC_VER <= 1800 +#define MOODYCAMEL_ALIGNAS(alignment) __declspec(align(alignment)) +#define MOODYCAMEL_ALIGNOF(obj) __alignof(obj) +#else +#define MOODYCAMEL_ALIGNAS(alignment) alignas(alignment) +#define MOODYCAMEL_ALIGNOF(obj) alignof(obj) +#endif +#endif + + + +// Compiler-specific likely/unlikely hints +namespace duckdb_moodycamel { namespace details { + +#if defined(__GNUC__) + static inline bool (likely)(bool x) { return __builtin_expect((x), true); } +// static inline bool (unlikely)(bool x) { return __builtin_expect((x), false); } +#else + static inline bool (likely)(bool x) { return x; } +// static inline bool (unlikely)(bool x) { return x; } +#endif +} } + +namespace duckdb_moodycamel { +namespace details { + template + struct const_numeric_max { + static_assert(std::is_integral::value, "const_numeric_max can only be used with integers"); + static const T value = std::numeric_limits::is_signed + ? (static_cast(1) << (sizeof(T) * CHAR_BIT - 1)) - static_cast(1) + : static_cast(-1); + }; + +#if defined(__GLIBCXX__) + typedef ::max_align_t std_max_align_t; // libstdc++ forgot to add it to std:: for a while +#else + typedef std::max_align_t std_max_align_t; // Others (e.g. MSVC) insist it can *only* be accessed via std:: +#endif + + // Some platforms have incorrectly set max_align_t to a type with <8 bytes alignment even while supporting + // 8-byte aligned scalar values (*cough* 32-bit iOS). Work around this with our own union. See issue #64. + typedef union { + std_max_align_t x; + long long y; + void* z; + } max_align_t; +} + +// Default traits for the ConcurrentQueue. To change some of the +// traits without re-implementing all of them, inherit from this +// struct and shadow the declarations you wish to be different; +// since the traits are used as a template type parameter, the +// shadowed declarations will be used where defined, and the defaults +// otherwise. +struct ConcurrentQueueDefaultTraits +{ + // General-purpose size type. std::size_t is strongly recommended. + typedef std::size_t size_t; + + // The type used for the enqueue and dequeue indices. Must be at least as + // large as size_t. Should be significantly larger than the number of elements + // you expect to hold at once, especially if you have a high turnover rate; + // for example, on 32-bit x86, if you expect to have over a hundred million + // elements or pump several million elements through your queue in a very + // short space of time, using a 32-bit type *may* trigger a race condition. + // A 64-bit int type is recommended in that case, and in practice will + // prevent a race condition no matter the usage of the queue. Note that + // whether the queue is lock-free with a 64-int type depends on the whether + // std::atomic is lock-free, which is platform-specific. + typedef std::size_t index_t; + + // Internally, all elements are enqueued and dequeued from multi-element + // blocks; this is the smallest controllable unit. If you expect few elements + // but many producers, a smaller block size should be favoured. For few producers + // and/or many elements, a larger block size is preferred. A sane default + // is provided. Must be a power of 2. + static const size_t BLOCK_SIZE = 32; + + // For explicit producers (i.e. when using a producer token), the block is + // checked for being empty by iterating through a list of flags, one per element. + // For large block sizes, this is too inefficient, and switching to an atomic + // counter-based approach is faster. The switch is made for block sizes strictly + // larger than this threshold. + static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = 32; + + // How many full blocks can be expected for a single explicit producer? This should + // reflect that number's maximum for optimal performance. Must be a power of 2. + static const size_t EXPLICIT_INITIAL_INDEX_SIZE = 32; + + // How many full blocks can be expected for a single implicit producer? This should + // reflect that number's maximum for optimal performance. Must be a power of 2. + static const size_t IMPLICIT_INITIAL_INDEX_SIZE = 32; + + // The initial size of the hash table mapping thread IDs to implicit producers. + // Note that the hash is resized every time it becomes half full. + // Must be a power of two, and either 0 or at least 1. If 0, implicit production + // (using the enqueue methods without an explicit producer token) is disabled. + static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = 32; + + // Controls the number of items that an explicit consumer (i.e. one with a token) + // must consume before it causes all consumers to rotate and move on to the next + // internal queue. + static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = 256; + + // The maximum number of elements (inclusive) that can be enqueued to a sub-queue. + // Enqueue operations that would cause this limit to be surpassed will fail. Note + // that this limit is enforced at the block level (for performance reasons), i.e. + // it's rounded up to the nearest block size. + static const size_t MAX_SUBQUEUE_SIZE = details::const_numeric_max::value; + + +#ifndef MCDBGQ_USE_RELACY + // Memory allocation can be customized if needed. + // malloc should return nullptr on failure, and handle alignment like std::malloc. +#if defined(malloc) || defined(free) + // Gah, this is 2015, stop defining macros that break standard code already! + // Work around malloc/free being special macros: + static inline void* WORKAROUND_malloc(size_t size) { return malloc(size); } + static inline void WORKAROUND_free(void* ptr) { return free(ptr); } + static inline void* (malloc)(size_t size) { return WORKAROUND_malloc(size); } + static inline void (free)(void* ptr) { return WORKAROUND_free(ptr); } +#else + static inline void* malloc(size_t size) { return std::malloc(size); } + static inline void free(void* ptr) { return std::free(ptr); } +#endif +#else + // Debug versions when running under the Relacy race detector (ignore + // these in user code) + static inline void* malloc(size_t size) { return rl::rl_malloc(size, $); } + static inline void free(void* ptr) { return rl::rl_free(ptr, $); } +#endif +}; + + +// When producing or consuming many elements, the most efficient way is to: +// 1) Use one of the bulk-operation methods of the queue with a token +// 2) Failing that, use the bulk-operation methods without a token +// 3) Failing that, create a token and use that with the single-item methods +// 4) Failing that, use the single-parameter methods of the queue +// Having said that, don't create tokens willy-nilly -- ideally there should be +// a maximum of one token per thread (of each kind). +struct ProducerToken; +struct ConsumerToken; + +template class ConcurrentQueue; +template class BlockingConcurrentQueue; +class ConcurrentQueueTests; + + +namespace details +{ + struct ConcurrentQueueProducerTypelessBase + { + ConcurrentQueueProducerTypelessBase* next; + std::atomic inactive; + ProducerToken* token; + + ConcurrentQueueProducerTypelessBase() + : next(nullptr), inactive(false), token(nullptr) + { + } + }; + + template struct _hash_32_or_64 { + static inline std::uint32_t hash(std::uint32_t h) + { + // MurmurHash3 finalizer -- see https://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp + // Since the thread ID is already unique, all we really want to do is propagate that + // uniqueness evenly across all the bits, so that we can use a subset of the bits while + // reducing collisions significantly + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + return h ^ (h >> 16); + } + }; + template<> struct _hash_32_or_64<1> { + static inline std::uint64_t hash(std::uint64_t h) + { + h ^= h >> 33; + h *= 0xff51afd7ed558ccd; + h ^= h >> 33; + h *= 0xc4ceb9fe1a85ec53; + return h ^ (h >> 33); + } + }; + template struct hash_32_or_64 : public _hash_32_or_64<(size > 4)> { }; + + static inline size_t hash_thread_id(thread_id_t id) + { + static_assert(sizeof(thread_id_t) <= 8, "Expected a platform where thread IDs are at most 64-bit values"); + return static_cast(hash_32_or_64::thread_id_hash_t)>::hash( + thread_id_converter::prehash(id))); + } + + template + static inline bool circular_less_than(T a, T b) + { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable: 4554) +#endif + static_assert(std::is_integral::value && !std::numeric_limits::is_signed, "circular_less_than is intended to be used only with unsigned integer types"); + return static_cast(a - b) > static_cast(static_cast(1) << static_cast(sizeof(T) * CHAR_BIT - 1)); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + } + + template + static inline char* align_for(char* ptr) + { + const std::size_t alignment = std::alignment_of::value; + return ptr + (alignment - (reinterpret_cast(ptr) % alignment)) % alignment; + } + + template + static inline T ceil_to_pow_2(T x) + { + static_assert(std::is_integral::value && !std::numeric_limits::is_signed, "ceil_to_pow_2 is intended to be used only with unsigned integer types"); + + // Adapted from http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 + --x; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + for (std::size_t i = 1; i < sizeof(T); i <<= 1) { + x |= x >> (i << 3); + } + ++x; + return x; + } + + template + static inline void swap_relaxed(std::atomic& left, std::atomic& right) + { + T temp = std::move(left.load(std::memory_order_relaxed)); + left.store(std::move(right.load(std::memory_order_relaxed)), std::memory_order_relaxed); + right.store(std::move(temp), std::memory_order_relaxed); + } + + template + static inline T const& nomove(T const& x) + { + return x; + } + + template + struct nomove_if + { + template + static inline T const& eval(T const& x) + { + return x; + } + }; + + template<> + struct nomove_if + { + template + static inline auto eval(U&& x) + -> decltype(std::forward(x)) + { + return std::forward(x); + } + }; + + template + static inline auto deref_noexcept(It& it) MOODYCAMEL_NOEXCEPT -> decltype(*it) + { + return *it; + } + +#if defined(__clang__) || !defined(__GNUC__) || __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8) + template struct is_trivially_destructible : std::is_trivially_destructible { }; +#else + template struct is_trivially_destructible : std::has_trivial_destructor { }; +#endif + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED +#ifdef MCDBGQ_USE_RELACY + typedef RelacyThreadExitListener ThreadExitListener; + typedef RelacyThreadExitNotifier ThreadExitNotifier; +#else + struct ThreadExitListener + { + typedef void (*callback_t)(void*); + callback_t callback; + void* userData; + + ThreadExitListener* next; // reserved for use by the ThreadExitNotifier + }; + + + class ThreadExitNotifier + { + public: + static void subscribe(ThreadExitListener* listener) + { + auto& tlsInst = instance(); + listener->next = tlsInst.tail; + tlsInst.tail = listener; + } + + static void unsubscribe(ThreadExitListener* listener) + { + auto& tlsInst = instance(); + ThreadExitListener** prev = &tlsInst.tail; + for (auto ptr = tlsInst.tail; ptr != nullptr; ptr = ptr->next) { + if (ptr == listener) { + *prev = ptr->next; + break; + } + prev = &ptr->next; + } + } + + private: + ThreadExitNotifier() : tail(nullptr) { } + ThreadExitNotifier(ThreadExitNotifier const&) MOODYCAMEL_DELETE_FUNCTION; + ThreadExitNotifier& operator=(ThreadExitNotifier const&) MOODYCAMEL_DELETE_FUNCTION; + + ~ThreadExitNotifier() + { + // This thread is about to exit, let everyone know! + assert(this == &instance() && "If this assert fails, you likely have a buggy compiler! Change the preprocessor conditions such that MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED is no longer defined."); + for (auto ptr = tail; ptr != nullptr; ptr = ptr->next) { + ptr->callback(ptr->userData); + } + } + + // Thread-local + static inline ThreadExitNotifier& instance() + { + static thread_local ThreadExitNotifier notifier; + return notifier; + } + + private: + ThreadExitListener* tail; + }; +#endif +#endif + + template struct static_is_lock_free_num { enum { value = 0 }; }; + template<> struct static_is_lock_free_num { enum { value = ATOMIC_CHAR_LOCK_FREE }; }; + template<> struct static_is_lock_free_num { enum { value = ATOMIC_SHORT_LOCK_FREE }; }; + template<> struct static_is_lock_free_num { enum { value = ATOMIC_INT_LOCK_FREE }; }; + template<> struct static_is_lock_free_num { enum { value = ATOMIC_LONG_LOCK_FREE }; }; + template<> struct static_is_lock_free_num { enum { value = ATOMIC_LLONG_LOCK_FREE }; }; + template struct static_is_lock_free : static_is_lock_free_num::type> { }; + template<> struct static_is_lock_free { enum { value = ATOMIC_BOOL_LOCK_FREE }; }; + template struct static_is_lock_free { enum { value = ATOMIC_POINTER_LOCK_FREE }; }; +} + + +struct ProducerToken +{ + template + explicit ProducerToken(ConcurrentQueue& queue); + + template + explicit ProducerToken(BlockingConcurrentQueue& queue); + + ProducerToken(ProducerToken&& other) MOODYCAMEL_NOEXCEPT + : producer(other.producer) + { + other.producer = nullptr; + if (producer != nullptr) { + producer->token = this; + } + } + + inline ProducerToken& operator=(ProducerToken&& other) MOODYCAMEL_NOEXCEPT + { + swap(other); + return *this; + } + + void swap(ProducerToken& other) MOODYCAMEL_NOEXCEPT + { + std::swap(producer, other.producer); + if (producer != nullptr) { + producer->token = this; + } + if (other.producer != nullptr) { + other.producer->token = &other; + } + } + + // A token is always valid unless: + // 1) Memory allocation failed during construction + // 2) It was moved via the move constructor + // (Note: assignment does a swap, leaving both potentially valid) + // 3) The associated queue was destroyed + // Note that if valid() returns true, that only indicates + // that the token is valid for use with a specific queue, + // but not which one; that's up to the user to track. + inline bool valid() const { return producer != nullptr; } + + ~ProducerToken() + { + if (producer != nullptr) { + producer->token = nullptr; + producer->inactive.store(true, std::memory_order_release); + } + } + + // Disable copying and assignment + ProducerToken(ProducerToken const&) MOODYCAMEL_DELETE_FUNCTION; + ProducerToken& operator=(ProducerToken const&) MOODYCAMEL_DELETE_FUNCTION; + +private: + template friend class ConcurrentQueue; + friend class ConcurrentQueueTests; + +protected: + details::ConcurrentQueueProducerTypelessBase* producer; +}; + + +struct ConsumerToken +{ + template + explicit ConsumerToken(ConcurrentQueue& q); + + template + explicit ConsumerToken(BlockingConcurrentQueue& q); + + ConsumerToken(ConsumerToken&& other) MOODYCAMEL_NOEXCEPT + : initialOffset(other.initialOffset), lastKnownGlobalOffset(other.lastKnownGlobalOffset), itemsConsumedFromCurrent(other.itemsConsumedFromCurrent), currentProducer(other.currentProducer), desiredProducer(other.desiredProducer) + { + } + + inline ConsumerToken& operator=(ConsumerToken&& other) MOODYCAMEL_NOEXCEPT + { + swap(other); + return *this; + } + + void swap(ConsumerToken& other) MOODYCAMEL_NOEXCEPT + { + std::swap(initialOffset, other.initialOffset); + std::swap(lastKnownGlobalOffset, other.lastKnownGlobalOffset); + std::swap(itemsConsumedFromCurrent, other.itemsConsumedFromCurrent); + std::swap(currentProducer, other.currentProducer); + std::swap(desiredProducer, other.desiredProducer); + } + + // Disable copying and assignment + ConsumerToken(ConsumerToken const&) MOODYCAMEL_DELETE_FUNCTION; + ConsumerToken& operator=(ConsumerToken const&) MOODYCAMEL_DELETE_FUNCTION; + +private: + template friend class ConcurrentQueue; + friend class ConcurrentQueueTests; + +private: // but shared with ConcurrentQueue + std::uint32_t initialOffset; + std::uint32_t lastKnownGlobalOffset; + std::uint32_t itemsConsumedFromCurrent; + details::ConcurrentQueueProducerTypelessBase* currentProducer; + details::ConcurrentQueueProducerTypelessBase* desiredProducer; +}; + +// Need to forward-declare this swap because it's in a namespace. +// See http://stackoverflow.com/questions/4492062/why-does-a-c-friend-class-need-a-forward-declaration-only-in-other-namespaces +template +inline void swap(typename ConcurrentQueue::ImplicitProducerKVP& a, typename ConcurrentQueue::ImplicitProducerKVP& b) MOODYCAMEL_NOEXCEPT; + + +template +class ConcurrentQueue +{ +public: + typedef ::duckdb_moodycamel::ProducerToken producer_token_t; + typedef ::duckdb_moodycamel::ConsumerToken consumer_token_t; + + typedef typename Traits::index_t index_t; + typedef typename Traits::size_t size_t; + + static const size_t BLOCK_SIZE = static_cast(Traits::BLOCK_SIZE); + static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = static_cast(Traits::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD); + static const size_t EXPLICIT_INITIAL_INDEX_SIZE = static_cast(Traits::EXPLICIT_INITIAL_INDEX_SIZE); + static const size_t IMPLICIT_INITIAL_INDEX_SIZE = static_cast(Traits::IMPLICIT_INITIAL_INDEX_SIZE); + static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = static_cast(Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE); + static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = static_cast(Traits::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE); +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable: 4307) // + integral constant overflow (that's what the ternary expression is for!) +#pragma warning(disable: 4309) // static_cast: Truncation of constant value +#endif + static const size_t MAX_SUBQUEUE_SIZE = (details::const_numeric_max::value - static_cast(Traits::MAX_SUBQUEUE_SIZE) < BLOCK_SIZE) ? details::const_numeric_max::value : ((static_cast(Traits::MAX_SUBQUEUE_SIZE) + (BLOCK_SIZE - 1)) / BLOCK_SIZE * BLOCK_SIZE); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + static_assert(!std::numeric_limits::is_signed && std::is_integral::value, "Traits::size_t must be an unsigned integral type"); + static_assert(!std::numeric_limits::is_signed && std::is_integral::value, "Traits::index_t must be an unsigned integral type"); + static_assert(sizeof(index_t) >= sizeof(size_t), "Traits::index_t must be at least as wide as Traits::size_t"); + static_assert((BLOCK_SIZE > 1) && !(BLOCK_SIZE & (BLOCK_SIZE - 1)), "Traits::BLOCK_SIZE must be a power of 2 (and at least 2)"); + static_assert((EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD > 1) && !(EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD & (EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD - 1)), "Traits::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD must be a power of 2 (and greater than 1)"); + static_assert((EXPLICIT_INITIAL_INDEX_SIZE > 1) && !(EXPLICIT_INITIAL_INDEX_SIZE & (EXPLICIT_INITIAL_INDEX_SIZE - 1)), "Traits::EXPLICIT_INITIAL_INDEX_SIZE must be a power of 2 (and greater than 1)"); + static_assert((IMPLICIT_INITIAL_INDEX_SIZE > 1) && !(IMPLICIT_INITIAL_INDEX_SIZE & (IMPLICIT_INITIAL_INDEX_SIZE - 1)), "Traits::IMPLICIT_INITIAL_INDEX_SIZE must be a power of 2 (and greater than 1)"); + static_assert((INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) || !(INITIAL_IMPLICIT_PRODUCER_HASH_SIZE & (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - 1)), "Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE must be a power of 2"); + static_assert(INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0 || INITIAL_IMPLICIT_PRODUCER_HASH_SIZE >= 1, "Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE must be at least 1 (or 0 to disable implicit enqueueing)"); + +public: + // Creates a queue with at least `capacity` element slots; note that the + // actual number of elements that can be inserted without additional memory + // allocation depends on the number of producers and the block size (e.g. if + // the block size is equal to `capacity`, only a single block will be allocated + // up-front, which means only a single producer will be able to enqueue elements + // without an extra allocation -- blocks aren't shared between producers). + // This method is not thread safe -- it is up to the user to ensure that the + // queue is fully constructed before it starts being used by other threads (this + // includes making the memory effects of construction visible, possibly with a + // memory barrier). + explicit ConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) + : producerListTail(nullptr), + producerCount(0), + initialBlockPoolIndex(0), + nextExplicitConsumerId(0), + globalExplicitConsumerOffset(0) + { + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + populate_initial_implicit_producer_hash(); + populate_initial_block_list(capacity / BLOCK_SIZE + ((capacity & (BLOCK_SIZE - 1)) == 0 ? 0 : 1)); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + // Track all the producers using a fully-resolved typed list for + // each kind; this makes it possible to debug them starting from + // the root queue object (otherwise wacky casts are needed that + // don't compile in the debugger's expression evaluator). + explicitProducers.store(nullptr, std::memory_order_relaxed); + implicitProducers.store(nullptr, std::memory_order_relaxed); +#endif + } + + // Computes the correct amount of pre-allocated blocks for you based + // on the minimum number of elements you want available at any given + // time, and the maximum concurrent number of each type of producer. + ConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, size_t maxImplicitProducers) + : producerListTail(nullptr), + producerCount(0), + initialBlockPoolIndex(0), + nextExplicitConsumerId(0), + globalExplicitConsumerOffset(0) + { + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + populate_initial_implicit_producer_hash(); + size_t blocks = (((minCapacity + BLOCK_SIZE - 1) / BLOCK_SIZE) - 1) * (maxExplicitProducers + 1) + 2 * (maxExplicitProducers + maxImplicitProducers); + populate_initial_block_list(blocks); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + explicitProducers.store(nullptr, std::memory_order_relaxed); + implicitProducers.store(nullptr, std::memory_order_relaxed); +#endif + } + + // Note: The queue should not be accessed concurrently while it's + // being deleted. It's up to the user to synchronize this. + // This method is not thread safe. + ~ConcurrentQueue() + { + // Destroy producers + auto ptr = producerListTail.load(std::memory_order_relaxed); + while (ptr != nullptr) { + auto next = ptr->next_prod(); + if (ptr->token != nullptr) { + ptr->token->producer = nullptr; + } + destroy(ptr); + ptr = next; + } + + // Destroy implicit producer hash tables + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE != 0) { + auto hash = implicitProducerHash.load(std::memory_order_relaxed); + while (hash != nullptr) { + auto prev = hash->prev; + if (prev != nullptr) { // The last hash is part of this object and was not allocated dynamically + for (size_t i = 0; i != hash->capacity; ++i) { + hash->entries[i].~ImplicitProducerKVP(); + } + hash->~ImplicitProducerHash(); + (Traits::free)(hash); + } + hash = prev; + } + } + + // Destroy global free list + auto block = freeList.head_unsafe(); + while (block != nullptr) { + auto next = block->freeListNext.load(std::memory_order_relaxed); + if (block->dynamicallyAllocated) { + destroy(block); + } + block = next; + } + + // Destroy initial free list + destroy_array(initialBlockPool, initialBlockPoolSize); + } + + // Disable copying and copy assignment + ConcurrentQueue(ConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; + ConcurrentQueue& operator=(ConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; + + // Moving is supported, but note that it is *not* a thread-safe operation. + // Nobody can use the queue while it's being moved, and the memory effects + // of that move must be propagated to other threads before they can use it. + // Note: When a queue is moved, its tokens are still valid but can only be + // used with the destination queue (i.e. semantically they are moved along + // with the queue itself). + ConcurrentQueue(ConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT + : producerListTail(other.producerListTail.load(std::memory_order_relaxed)), + producerCount(other.producerCount.load(std::memory_order_relaxed)), + initialBlockPoolIndex(other.initialBlockPoolIndex.load(std::memory_order_relaxed)), + initialBlockPool(other.initialBlockPool), + initialBlockPoolSize(other.initialBlockPoolSize), + freeList(std::move(other.freeList)), + nextExplicitConsumerId(other.nextExplicitConsumerId.load(std::memory_order_relaxed)), + globalExplicitConsumerOffset(other.globalExplicitConsumerOffset.load(std::memory_order_relaxed)) + { + // Move the other one into this, and leave the other one as an empty queue + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + populate_initial_implicit_producer_hash(); + swap_implicit_producer_hashes(other); + + other.producerListTail.store(nullptr, std::memory_order_relaxed); + other.producerCount.store(0, std::memory_order_relaxed); + other.nextExplicitConsumerId.store(0, std::memory_order_relaxed); + other.globalExplicitConsumerOffset.store(0, std::memory_order_relaxed); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + explicitProducers.store(other.explicitProducers.load(std::memory_order_relaxed), std::memory_order_relaxed); + other.explicitProducers.store(nullptr, std::memory_order_relaxed); + implicitProducers.store(other.implicitProducers.load(std::memory_order_relaxed), std::memory_order_relaxed); + other.implicitProducers.store(nullptr, std::memory_order_relaxed); +#endif + + other.initialBlockPoolIndex.store(0, std::memory_order_relaxed); + other.initialBlockPoolSize = 0; + other.initialBlockPool = nullptr; + + reown_producers(); + } + + inline ConcurrentQueue& operator=(ConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT + { + return swap_internal(other); + } + + // Swaps this queue's state with the other's. Not thread-safe. + // Swapping two queues does not invalidate their tokens, however + // the tokens that were created for one queue must be used with + // only the swapped queue (i.e. the tokens are tied to the + // queue's movable state, not the object itself). + inline void swap(ConcurrentQueue& other) MOODYCAMEL_NOEXCEPT + { + swap_internal(other); + } + +private: + ConcurrentQueue& swap_internal(ConcurrentQueue& other) + { + if (this == &other) { + return *this; + } + + details::swap_relaxed(producerListTail, other.producerListTail); + details::swap_relaxed(producerCount, other.producerCount); + details::swap_relaxed(initialBlockPoolIndex, other.initialBlockPoolIndex); + std::swap(initialBlockPool, other.initialBlockPool); + std::swap(initialBlockPoolSize, other.initialBlockPoolSize); + freeList.swap(other.freeList); + details::swap_relaxed(nextExplicitConsumerId, other.nextExplicitConsumerId); + details::swap_relaxed(globalExplicitConsumerOffset, other.globalExplicitConsumerOffset); + + swap_implicit_producer_hashes(other); + + reown_producers(); + other.reown_producers(); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + details::swap_relaxed(explicitProducers, other.explicitProducers); + details::swap_relaxed(implicitProducers, other.implicitProducers); +#endif + + return *this; + } + +public: + // Enqueues a single item (by copying it). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T const& item) + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + else return inner_enqueue(item); + } + + // Enqueues a single item (by moving it, if possible). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T&& item) + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + else return inner_enqueue(std::move(item)); + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const& token, T const& item) + { + return inner_enqueue(token, item); + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const& token, T&& item) + { + return inner_enqueue(token, std::move(item)); + } + + // Enqueues several items. + // Allocates memory if required. Only fails if memory allocation fails (or + // implicit production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0, or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved instead of copied. + // Thread-safe. + template + bool enqueue_bulk(It itemFirst, size_t count) + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + else return inner_enqueue_bulk(itemFirst, count); + } + + // Enqueues several items using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails + // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + bool enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) + { + return inner_enqueue_bulk(token, itemFirst, count); + } + + // Enqueues a single item (by copying it). + // Does not allocate memory. Fails if not enough room to enqueue (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0). + // Thread-safe. + inline bool try_enqueue(T const& item) + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + else return inner_enqueue(item); + } + + // Enqueues a single item (by moving it, if possible). + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Thread-safe. + inline bool try_enqueue(T&& item) + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + else return inner_enqueue(std::move(item)); + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const& token, T const& item) + { + return inner_enqueue(token, item); + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const& token, T&& item) + { + return inner_enqueue(token, std::move(item)); + } + + // Enqueues several items. + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + bool try_enqueue_bulk(It itemFirst, size_t count) + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + else return inner_enqueue_bulk(itemFirst, count); + } + + // Enqueues several items using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + bool try_enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) + { + return inner_enqueue_bulk(token, itemFirst, count); + } + + + + // Attempts to dequeue from the queue. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + bool try_dequeue(U& item) + { + // Instead of simply trying each producer in turn (which could cause needless contention on the first + // producer), we score them heuristically. + size_t nonEmptyCount = 0; + ProducerBase* best = nullptr; + size_t bestSize = 0; + for (auto ptr = producerListTail.load(std::memory_order_acquire); nonEmptyCount < 3 && ptr != nullptr; ptr = ptr->next_prod()) { + auto size = ptr->size_approx(); + if (size > 0) { + if (size > bestSize) { + bestSize = size; + best = ptr; + } + ++nonEmptyCount; + } + } + + // If there was at least one non-empty queue but it appears empty at the time + // we try to dequeue from it, we need to make sure every queue's been tried + if (nonEmptyCount > 0) { + if ((details::likely)(best->dequeue(item))) { + return true; + } + for (auto ptr = producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { + if (ptr != best && ptr->dequeue(item)) { + return true; + } + } + } + return false; + } + + // Attempts to dequeue from the queue. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // This differs from the try_dequeue(item) method in that this one does + // not attempt to reduce contention by interleaving the order that producer + // streams are dequeued from. So, using this method can reduce overall throughput + // under contention, but will give more predictable results in single-threaded + // consumer scenarios. This is mostly only useful for internal unit tests. + // Never allocates. Thread-safe. + template + bool try_dequeue_non_interleaved(U& item) + { + for (auto ptr = producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { + if (ptr->dequeue(item)) { + return true; + } + } + return false; + } + + // Attempts to dequeue from the queue using an explicit consumer token. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + bool try_dequeue(consumer_token_t& token, U& item) + { + // The idea is roughly as follows: + // Every 256 items from one producer, make everyone rotate (increase the global offset) -> this means the highest efficiency consumer dictates the rotation speed of everyone else, more or less + // If you see that the global offset has changed, you must reset your consumption counter and move to your designated place + // If there's no items where you're supposed to be, keep moving until you find a producer with some items + // If the global offset has not changed but you've run out of items to consume, move over from your current position until you find an producer with something in it + + if (token.desiredProducer == nullptr || token.lastKnownGlobalOffset != globalExplicitConsumerOffset.load(std::memory_order_relaxed)) { + if (!update_current_producer_after_rotation(token)) { + return false; + } + } + + // If there was at least one non-empty queue but it appears empty at the time + // we try to dequeue from it, we need to make sure every queue's been tried + if (static_cast(token.currentProducer)->dequeue(item)) { + if (++token.itemsConsumedFromCurrent == EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE) { + globalExplicitConsumerOffset.fetch_add(1, std::memory_order_relaxed); + } + return true; + } + + auto tail = producerListTail.load(std::memory_order_acquire); + auto ptr = static_cast(token.currentProducer)->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + while (ptr != static_cast(token.currentProducer)) { + if (ptr->dequeue(item)) { + token.currentProducer = ptr; + token.itemsConsumedFromCurrent = 1; + return true; + } + ptr = ptr->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + } + return false; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + size_t try_dequeue_bulk(It itemFirst, size_t max) + { + size_t count = 0; + for (auto ptr = producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { + count += ptr->dequeue_bulk(itemFirst, max - count); + if (count == max) { + break; + } + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + size_t try_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) + { + if (token.desiredProducer == nullptr || token.lastKnownGlobalOffset != globalExplicitConsumerOffset.load(std::memory_order_relaxed)) { + if (!update_current_producer_after_rotation(token)) { + return 0; + } + } + + size_t count = static_cast(token.currentProducer)->dequeue_bulk(itemFirst, max); + if (count == max) { + if ((token.itemsConsumedFromCurrent += static_cast(max)) >= EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE) { + globalExplicitConsumerOffset.fetch_add(1, std::memory_order_relaxed); + } + return max; + } + token.itemsConsumedFromCurrent += static_cast(count); + max -= count; + + auto tail = producerListTail.load(std::memory_order_acquire); + auto ptr = static_cast(token.currentProducer)->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + while (ptr != static_cast(token.currentProducer)) { + auto dequeued = ptr->dequeue_bulk(itemFirst, max); + count += dequeued; + if (dequeued != 0) { + token.currentProducer = ptr; + token.itemsConsumedFromCurrent = static_cast(dequeued); + } + if (dequeued == max) { + break; + } + max -= dequeued; + ptr = ptr->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + } + return count; + } + + + + // Attempts to dequeue from a specific producer's inner queue. + // If you happen to know which producer you want to dequeue from, this + // is significantly faster than using the general-case try_dequeue methods. + // Returns false if the producer's queue appeared empty at the time it + // was checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue_from_producer(producer_token_t const& producer, U& item) + { + return static_cast(producer.producer)->dequeue(item); + } + + // Attempts to dequeue several elements from a specific producer's inner queue. + // Returns the number of items actually dequeued. + // If you happen to know which producer you want to dequeue from, this + // is significantly faster than using the general-case try_dequeue methods. + // Returns 0 if the producer's queue appeared empty at the time it + // was checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline size_t try_dequeue_bulk_from_producer(producer_token_t const& producer, It itemFirst, size_t max) + { + return static_cast(producer.producer)->dequeue_bulk(itemFirst, max); + } + + + // Returns an estimate of the total number of elements currently in the queue. This + // estimate is only accurate if the queue has completely stabilized before it is called + // (i.e. all enqueue and dequeue operations have completed and their memory effects are + // visible on the calling thread, and no further operations start while this method is + // being called). + // Thread-safe. + size_t size_approx() const + { + size_t size = 0; + for (auto ptr = producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { + size += ptr->size_approx(); + } + return size; + } + + + // Returns true if the underlying atomic variables used by + // the queue are lock-free (they should be on most platforms). + // Thread-safe. + static bool is_lock_free() + { + return + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::thread_id_numeric_size_t>::value == 2; + } + + +private: + friend struct ProducerToken; + friend struct ConsumerToken; + struct ExplicitProducer; + friend struct ExplicitProducer; + struct ImplicitProducer; + friend struct ImplicitProducer; + friend class ConcurrentQueueTests; + + enum AllocationMode { CanAlloc, CannotAlloc }; + + + /////////////////////////////// + // Queue methods + /////////////////////////////// + + template + inline bool inner_enqueue(producer_token_t const& token, U&& element) + { + return static_cast(token.producer)->ConcurrentQueue::ExplicitProducer::template enqueue(std::forward(element)); + } + + template + inline bool inner_enqueue(U&& element) + { + auto producer = get_or_add_implicit_producer(); + return producer == nullptr ? false : producer->ConcurrentQueue::ImplicitProducer::template enqueue(std::forward(element)); + } + + template + inline bool inner_enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) + { + return static_cast(token.producer)->ConcurrentQueue::ExplicitProducer::template enqueue_bulk(itemFirst, count); + } + + template + inline bool inner_enqueue_bulk(It itemFirst, size_t count) + { + auto producer = get_or_add_implicit_producer(); + return producer == nullptr ? false : producer->ConcurrentQueue::ImplicitProducer::template enqueue_bulk(itemFirst, count); + } + + inline bool update_current_producer_after_rotation(consumer_token_t& token) + { + // Ah, there's been a rotation, figure out where we should be! + auto tail = producerListTail.load(std::memory_order_acquire); + if (token.desiredProducer == nullptr && tail == nullptr) { + return false; + } + auto prodCount = producerCount.load(std::memory_order_relaxed); + auto globalOffset = globalExplicitConsumerOffset.load(std::memory_order_relaxed); + if (token.desiredProducer == nullptr) { + // Aha, first time we're dequeueing anything. + // Figure out our local position + // Note: offset is from start, not end, but we're traversing from end -- subtract from count first + std::uint32_t offset = prodCount - 1 - (token.initialOffset % prodCount); + token.desiredProducer = tail; + for (std::uint32_t i = 0; i != offset; ++i) { + token.desiredProducer = static_cast(token.desiredProducer)->next_prod(); + if (token.desiredProducer == nullptr) { + token.desiredProducer = tail; + } + } + } + + std::uint32_t delta = globalOffset - token.lastKnownGlobalOffset; + if (delta >= prodCount) { + delta = delta % prodCount; + } + for (std::uint32_t i = 0; i != delta; ++i) { + token.desiredProducer = static_cast(token.desiredProducer)->next_prod(); + if (token.desiredProducer == nullptr) { + token.desiredProducer = tail; + } + } + + token.lastKnownGlobalOffset = globalOffset; + token.currentProducer = token.desiredProducer; + token.itemsConsumedFromCurrent = 0; + return true; + } + + + /////////////////////////// + // Free list + /////////////////////////// + + template + struct FreeListNode + { + FreeListNode() : freeListRefs(0), freeListNext(nullptr) { } + + std::atomic freeListRefs; + std::atomic freeListNext; + }; + + // A simple CAS-based lock-free free list. Not the fastest thing in the world under heavy contention, but + // simple and correct (assuming nodes are never freed until after the free list is destroyed), and fairly + // speedy under low contention. + template // N must inherit FreeListNode or have the same fields (and initialization of them) + struct FreeList + { + FreeList() : freeListHead(nullptr) { } + FreeList(FreeList&& other) : freeListHead(other.freeListHead.load(std::memory_order_relaxed)) { other.freeListHead.store(nullptr, std::memory_order_relaxed); } + void swap(FreeList& other) { details::swap_relaxed(freeListHead, other.freeListHead); } + + FreeList(FreeList const&) MOODYCAMEL_DELETE_FUNCTION; + FreeList& operator=(FreeList const&) MOODYCAMEL_DELETE_FUNCTION; + + inline void add(N* node) + { +#ifdef MCDBGQ_NOLOCKFREE_FREELIST + debug::DebugLock lock(mutex); +#endif + // We know that the should-be-on-freelist bit is 0 at this point, so it's safe to + // set it using a fetch_add + if (node->freeListRefs.fetch_add(SHOULD_BE_ON_FREELIST, std::memory_order_acq_rel) == 0) { + // Oh look! We were the last ones referencing this node, and we know + // we want to add it to the free list, so let's do it! + add_knowing_refcount_is_zero(node); + } + } + + inline N* try_get() + { +#ifdef MCDBGQ_NOLOCKFREE_FREELIST + debug::DebugLock lock(mutex); +#endif + auto head = freeListHead.load(std::memory_order_acquire); + while (head != nullptr) { + auto prevHead = head; + auto refs = head->freeListRefs.load(std::memory_order_relaxed); + if ((refs & REFS_MASK) == 0 || !head->freeListRefs.compare_exchange_strong(refs, refs + 1, std::memory_order_acquire, std::memory_order_relaxed)) { + head = freeListHead.load(std::memory_order_acquire); + continue; + } + + // Good, reference count has been incremented (it wasn't at zero), which means we can read the + // next and not worry about it changing between now and the time we do the CAS + auto next = head->freeListNext.load(std::memory_order_relaxed); + if (freeListHead.compare_exchange_strong(head, next, std::memory_order_acquire, std::memory_order_relaxed)) { + // Yay, got the node. This means it was on the list, which means shouldBeOnFreeList must be false no + // matter the refcount (because nobody else knows it's been taken off yet, it can't have been put back on). + assert((head->freeListRefs.load(std::memory_order_relaxed) & SHOULD_BE_ON_FREELIST) == 0); + + // Decrease refcount twice, once for our ref, and once for the list's ref + head->freeListRefs.fetch_sub(2, std::memory_order_release); + return head; + } + + // OK, the head must have changed on us, but we still need to decrease the refcount we increased. + // Note that we don't need to release any memory effects, but we do need to ensure that the reference + // count decrement happens-after the CAS on the head. + refs = prevHead->freeListRefs.fetch_sub(1, std::memory_order_acq_rel); + if (refs == SHOULD_BE_ON_FREELIST + 1) { + add_knowing_refcount_is_zero(prevHead); + } + } + + return nullptr; + } + + // Useful for traversing the list when there's no contention (e.g. to destroy remaining nodes) + N* head_unsafe() const { return freeListHead.load(std::memory_order_relaxed); } + + private: + inline void add_knowing_refcount_is_zero(N* node) + { + // Since the refcount is zero, and nobody can increase it once it's zero (except us, and we run + // only one copy of this method per node at a time, i.e. the single thread case), then we know + // we can safely change the next pointer of the node; however, once the refcount is back above + // zero, then other threads could increase it (happens under heavy contention, when the refcount + // goes to zero in between a load and a refcount increment of a node in try_get, then back up to + // something non-zero, then the refcount increment is done by the other thread) -- so, if the CAS + // to add the node to the actual list fails, decrease the refcount and leave the add operation to + // the next thread who puts the refcount back at zero (which could be us, hence the loop). + auto head = freeListHead.load(std::memory_order_relaxed); + while (true) { + node->freeListNext.store(head, std::memory_order_relaxed); + node->freeListRefs.store(1, std::memory_order_release); + if (!freeListHead.compare_exchange_strong(head, node, std::memory_order_release, std::memory_order_relaxed)) { + // Hmm, the add failed, but we can only try again when the refcount goes back to zero + if (node->freeListRefs.fetch_add(SHOULD_BE_ON_FREELIST - 1, std::memory_order_release) == 1) { + continue; + } + } + return; + } + } + + private: + // Implemented like a stack, but where node order doesn't matter (nodes are inserted out of order under contention) + std::atomic freeListHead; + + static const std::uint32_t REFS_MASK = 0x7FFFFFFF; + static const std::uint32_t SHOULD_BE_ON_FREELIST = 0x80000000; + +#ifdef MCDBGQ_NOLOCKFREE_FREELIST + debug::DebugMutex mutex; +#endif + }; + + + /////////////////////////// + // Block + /////////////////////////// + + enum InnerQueueContext { implicit_context = 0, explicit_context = 1 }; + + struct Block + { + Block() + : next(nullptr), elementsCompletelyDequeued(0), freeListRefs(0), freeListNext(nullptr), shouldBeOnFreeList(false), dynamicallyAllocated(true) + { +#ifdef MCDBGQ_TRACKMEM + owner = nullptr; +#endif + } + + template + inline bool is_empty() const + { + MOODYCAMEL_CONSTEXPR_IF (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Check flags + for (size_t i = 0; i < BLOCK_SIZE; ++i) { + if (!emptyFlags[i].load(std::memory_order_relaxed)) { + return false; + } + } + + // Aha, empty; make sure we have all other memory effects that happened before the empty flags were set + std::atomic_thread_fence(std::memory_order_acquire); + return true; + } + else { + // Check counter + if (elementsCompletelyDequeued.load(std::memory_order_relaxed) == BLOCK_SIZE) { + std::atomic_thread_fence(std::memory_order_acquire); + return true; + } + assert(elementsCompletelyDequeued.load(std::memory_order_relaxed) <= BLOCK_SIZE); + return false; + } + } + + // Returns true if the block is now empty (does not apply in explicit context) + template + inline bool set_empty(MOODYCAMEL_MAYBE_UNUSED index_t i) + { + MOODYCAMEL_CONSTEXPR_IF (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Set flag + assert(!emptyFlags[BLOCK_SIZE - 1 - static_cast(i & static_cast(BLOCK_SIZE - 1))].load(std::memory_order_relaxed)); + emptyFlags[BLOCK_SIZE - 1 - static_cast(i & static_cast(BLOCK_SIZE - 1))].store(true, std::memory_order_release); + return false; + } + else { + // Increment counter + auto prevVal = elementsCompletelyDequeued.fetch_add(1, std::memory_order_release); + assert(prevVal < BLOCK_SIZE); + return prevVal == BLOCK_SIZE - 1; + } + } + + // Sets multiple contiguous item statuses to 'empty' (assumes no wrapping and count > 0). + // Returns true if the block is now empty (does not apply in explicit context). + template + inline bool set_many_empty(MOODYCAMEL_MAYBE_UNUSED index_t i, size_t count) + { + MOODYCAMEL_CONSTEXPR_IF (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Set flags + std::atomic_thread_fence(std::memory_order_release); + i = BLOCK_SIZE - 1 - static_cast(i & static_cast(BLOCK_SIZE - 1)) - count + 1; + for (size_t j = 0; j != count; ++j) { + assert(!emptyFlags[i + j].load(std::memory_order_relaxed)); + emptyFlags[i + j].store(true, std::memory_order_relaxed); + } + return false; + } + else { + // Increment counter + auto prevVal = elementsCompletelyDequeued.fetch_add(count, std::memory_order_release); + assert(prevVal + count <= BLOCK_SIZE); + return prevVal + count == BLOCK_SIZE; + } + } + + template + inline void set_all_empty() + { + MOODYCAMEL_CONSTEXPR_IF (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Set all flags + for (size_t i = 0; i != BLOCK_SIZE; ++i) { + emptyFlags[i].store(true, std::memory_order_relaxed); + } + } + else { + // Reset counter + elementsCompletelyDequeued.store(BLOCK_SIZE, std::memory_order_relaxed); + } + } + + template + inline void reset_empty() + { + MOODYCAMEL_CONSTEXPR_IF (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Reset flags + for (size_t i = 0; i != BLOCK_SIZE; ++i) { + emptyFlags[i].store(false, std::memory_order_relaxed); + } + } + else { + // Reset counter + elementsCompletelyDequeued.store(0, std::memory_order_relaxed); + } + } + + inline T* operator[](index_t idx) MOODYCAMEL_NOEXCEPT { return static_cast(static_cast(elements)) + static_cast(idx & static_cast(BLOCK_SIZE - 1)); } + inline T const* operator[](index_t idx) const MOODYCAMEL_NOEXCEPT { return static_cast(static_cast(elements)) + static_cast(idx & static_cast(BLOCK_SIZE - 1)); } + + private: + static_assert(std::alignment_of::value <= sizeof(T), "The queue does not support types with an alignment greater than their size at this time"); + MOODYCAMEL_ALIGNAS(MOODYCAMEL_ALIGNOF(T)) char elements[sizeof(T) * BLOCK_SIZE]; + public: + Block* next; + std::atomic elementsCompletelyDequeued; + std::atomic emptyFlags[BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD ? BLOCK_SIZE : 1]; + public: + std::atomic freeListRefs; + std::atomic freeListNext; + std::atomic shouldBeOnFreeList; + bool dynamicallyAllocated; // Perhaps a better name for this would be 'isNotPartOfInitialBlockPool' + +#ifdef MCDBGQ_TRACKMEM + void* owner; +#endif + }; + static_assert(std::alignment_of::value >= std::alignment_of::value, "Internal error: Blocks must be at least as aligned as the type they are wrapping"); + + +#ifdef MCDBGQ_TRACKMEM +public: + struct MemStats; +private: +#endif + + /////////////////////////// + // Producer base + /////////////////////////// + + struct ProducerBase : public details::ConcurrentQueueProducerTypelessBase + { + ProducerBase(ConcurrentQueue* parent_, bool isExplicit_) : + tailIndex(0), + headIndex(0), + dequeueOptimisticCount(0), + dequeueOvercommit(0), + tailBlock(nullptr), + isExplicit(isExplicit_), + parent(parent_) + { + } + + virtual ~ProducerBase() { }; + + template + inline bool dequeue(U& element) + { + if (isExplicit) { + return static_cast(this)->dequeue(element); + } + else { + return static_cast(this)->dequeue(element); + } + } + + template + inline size_t dequeue_bulk(It& itemFirst, size_t max) + { + if (isExplicit) { + return static_cast(this)->dequeue_bulk(itemFirst, max); + } + else { + return static_cast(this)->dequeue_bulk(itemFirst, max); + } + } + + inline ProducerBase* next_prod() const { return static_cast(next); } + + inline size_t size_approx() const + { + auto tail = tailIndex.load(std::memory_order_relaxed); + auto head = headIndex.load(std::memory_order_relaxed); + return details::circular_less_than(head, tail) ? static_cast(tail - head) : 0; + } + + inline index_t getTail() const { return tailIndex.load(std::memory_order_relaxed); } + protected: + std::atomic tailIndex; // Where to enqueue to next + std::atomic headIndex; // Where to dequeue from next + + std::atomic dequeueOptimisticCount; + std::atomic dequeueOvercommit; + + Block* tailBlock; + + public: + bool isExplicit; + ConcurrentQueue* parent; + + protected: +#ifdef MCDBGQ_TRACKMEM + friend struct MemStats; +#endif + }; + + + /////////////////////////// + // Explicit queue + /////////////////////////// + + struct ExplicitProducer : public ProducerBase + { + explicit ExplicitProducer(ConcurrentQueue* parent_) : + ProducerBase(parent_, true), + blockIndex(nullptr), + pr_blockIndexSlotsUsed(0), + pr_blockIndexSize(EXPLICIT_INITIAL_INDEX_SIZE >> 1), + pr_blockIndexFront(0), + pr_blockIndexEntries(nullptr), + pr_blockIndexRaw(nullptr) + { + size_t poolBasedIndexSize = details::ceil_to_pow_2(parent_->initialBlockPoolSize) >> 1; + if (poolBasedIndexSize > pr_blockIndexSize) { + pr_blockIndexSize = poolBasedIndexSize; + } + + new_block_index(0); // This creates an index with double the number of current entries, i.e. EXPLICIT_INITIAL_INDEX_SIZE + } + + ~ExplicitProducer() + { + // Destruct any elements not yet dequeued. + // Since we're in the destructor, we can assume all elements + // are either completely dequeued or completely not (no halfways). + if (this->tailBlock != nullptr) { // Note this means there must be a block index too + // First find the block that's partially dequeued, if any + Block* halfDequeuedBlock = nullptr; + if ((this->headIndex.load(std::memory_order_relaxed) & static_cast(BLOCK_SIZE - 1)) != 0) { + // The head's not on a block boundary, meaning a block somewhere is partially dequeued + // (or the head block is the tail block and was fully dequeued, but the head/tail are still not on a boundary) + size_t i = (pr_blockIndexFront - pr_blockIndexSlotsUsed) & (pr_blockIndexSize - 1); + while (details::circular_less_than(pr_blockIndexEntries[i].base + BLOCK_SIZE, this->headIndex.load(std::memory_order_relaxed))) { + i = (i + 1) & (pr_blockIndexSize - 1); + } + assert(details::circular_less_than(pr_blockIndexEntries[i].base, this->headIndex.load(std::memory_order_relaxed))); + halfDequeuedBlock = pr_blockIndexEntries[i].block; + } + + // Start at the head block (note the first line in the loop gives us the head from the tail on the first iteration) + auto block = this->tailBlock; + do { + block = block->next; + if (block->ConcurrentQueue::Block::template is_empty()) { + continue; + } + + size_t i = 0; // Offset into block + if (block == halfDequeuedBlock) { + i = static_cast(this->headIndex.load(std::memory_order_relaxed) & static_cast(BLOCK_SIZE - 1)); + } + + // Walk through all the items in the block; if this is the tail block, we need to stop when we reach the tail index + auto lastValidIndex = (this->tailIndex.load(std::memory_order_relaxed) & static_cast(BLOCK_SIZE - 1)) == 0 ? BLOCK_SIZE : static_cast(this->tailIndex.load(std::memory_order_relaxed) & static_cast(BLOCK_SIZE - 1)); + while (i != BLOCK_SIZE && (block != this->tailBlock || i != lastValidIndex)) { + (*block)[i++]->~T(); + } + } while (block != this->tailBlock); + } + + // Destroy all blocks that we own + if (this->tailBlock != nullptr) { + auto block = this->tailBlock; + do { + auto nextBlock = block->next; + if (block->dynamicallyAllocated) { + destroy(block); + } + else { + this->parent->add_block_to_free_list(block); + } + block = nextBlock; + } while (block != this->tailBlock); + } + + // Destroy the block indices + auto header = static_cast(pr_blockIndexRaw); + while (header != nullptr) { + auto prev = static_cast(header->prev); + header->~BlockIndexHeader(); + (Traits::free)(header); + header = prev; + } + } + + template + inline bool enqueue(U&& element) + { + index_t currentTailIndex = this->tailIndex.load(std::memory_order_relaxed); + index_t newTailIndex = 1 + currentTailIndex; + if ((currentTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + // We reached the end of a block, start a new one + auto startBlock = this->tailBlock; + auto originalBlockIndexSlotsUsed = pr_blockIndexSlotsUsed; + if (this->tailBlock != nullptr && this->tailBlock->next->ConcurrentQueue::Block::template is_empty()) { + // We can re-use the block ahead of us, it's empty! + this->tailBlock = this->tailBlock->next; + this->tailBlock->ConcurrentQueue::Block::template reset_empty(); + + // We'll put the block on the block index (guaranteed to be room since we're conceptually removing the + // last block from it first -- except instead of removing then adding, we can just overwrite). + // Note that there must be a valid block index here, since even if allocation failed in the ctor, + // it would have been re-attempted when adding the first block to the queue; since there is such + // a block, a block index must have been successfully allocated. + } + else { + // Whatever head value we see here is >= the last value we saw here (relatively), + // and <= its current value. Since we have the most recent tail, the head must be + // <= to it. + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + if (!details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) + || (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && (MAX_SUBQUEUE_SIZE == 0 || MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head))) { + // We can't enqueue in another block because there's not enough leeway -- the + // tail could surpass the head by the time the block fills up! (Or we'll exceed + // the size limit, if the second part of the condition was true.) + return false; + } + // We're going to need a new block; check that the block index has room + if (pr_blockIndexRaw == nullptr || pr_blockIndexSlotsUsed == pr_blockIndexSize) { + // Hmm, the circular block index is already full -- we'll need + // to allocate a new index. Note pr_blockIndexRaw can only be nullptr if + // the initial allocation failed in the constructor. + + MOODYCAMEL_CONSTEXPR_IF (allocMode == CannotAlloc) { + return false; + } + else if (!new_block_index(pr_blockIndexSlotsUsed)) { + return false; + } + } + + // Insert a new block in the circular linked list + auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); + if (newBlock == nullptr) { + return false; + } +#ifdef MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template reset_empty(); + if (this->tailBlock == nullptr) { + newBlock->next = newBlock; + } + else { + newBlock->next = this->tailBlock->next; + this->tailBlock->next = newBlock; + } + this->tailBlock = newBlock; + ++pr_blockIndexSlotsUsed; + } + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new ((T*)nullptr) T(std::forward(element)))) { + // The constructor may throw. We want the element not to appear in the queue in + // that case (without corrupting the queue): + MOODYCAMEL_TRY { + new ((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); + } + MOODYCAMEL_CATCH (...) { + // Revert change to the current block, but leave the new block available + // for next time + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? this->tailBlock : startBlock; + MOODYCAMEL_RETHROW; + } + } + else { + (void)startBlock; + (void)originalBlockIndexSlotsUsed; + } + + // Add block to block index + auto& entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; + entry.base = currentTailIndex; + entry.block = this->tailBlock; + blockIndex.load(std::memory_order_relaxed)->front.store(pr_blockIndexFront, std::memory_order_release); + pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new ((T*)nullptr) T(std::forward(element)))) { + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + } + + // Enqueue + new ((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); + + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + bool dequeue(U& element) + { + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + if (details::circular_less_than(this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit, tail)) { + // Might be something to dequeue, let's give it a try + + // Note that this if is purely for performance purposes in the common case when the queue is + // empty and the values are eventually consistent -- we may enter here spuriously. + + // Note that whatever the values of overcommit and tail are, they are not going to change (unless we + // change them) and must be the same value at this point (inside the if) as when the if condition was + // evaluated. + + // We insert an acquire fence here to synchronize-with the release upon incrementing dequeueOvercommit below. + // This ensures that whatever the value we got loaded into overcommit, the load of dequeueOptisticCount in + // the fetch_add below will result in a value at least as recent as that (and therefore at least as large). + // Note that I believe a compiler (signal) fence here would be sufficient due to the nature of fetch_add (all + // read-modify-write operations are guaranteed to work on the latest value in the modification order), but + // unfortunately that can't be shown to be correct using only the C++11 standard. + // See http://stackoverflow.com/questions/18223161/what-are-the-c11-memory-ordering-guarantees-in-this-corner-case + std::atomic_thread_fence(std::memory_order_acquire); + + // Increment optimistic counter, then check if it went over the boundary + auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(1, std::memory_order_relaxed); + + // Note that since dequeueOvercommit must be <= dequeueOptimisticCount (because dequeueOvercommit is only ever + // incremented after dequeueOptimisticCount -- this is enforced in the `else` block below), and since we now + // have a version of dequeueOptimisticCount that is at least as recent as overcommit (due to the release upon + // incrementing dequeueOvercommit and the acquire above that synchronizes with it), overcommit <= myDequeueCount. + // However, we can't assert this since both dequeueOptimisticCount and dequeueOvercommit may (independently) + // overflow; in such a case, though, the logic still holds since the difference between the two is maintained. + + // Note that we reload tail here in case it changed; it will be the same value as before or greater, since + // this load is sequenced after (happens after) the earlier load above. This is supported by read-read + // coherency (as defined in the standard), explained here: http://en.cppreference.com/w/cpp/atomic/memory_order + tail = this->tailIndex.load(std::memory_order_acquire); + if ((details::likely)(details::circular_less_than(myDequeueCount - overcommit, tail))) { + // Guaranteed to be at least one element to dequeue! + + // Get the index. Note that since there's guaranteed to be at least one element, this + // will never exceed tail. We need to do an acquire-release fence here since it's possible + // that whatever condition got us to this point was for an earlier enqueued element (that + // we already see the memory effects for), but that by the time we increment somebody else + // has incremented it, and we need to see the memory effects for *that* element, which is + // in such a case is necessarily visible on the thread that incremented it in the first + // place with the more current condition (they must have acquired a tail that is at least + // as recent). + auto index = this->headIndex.fetch_add(1, std::memory_order_acq_rel); + + + // Determine which block the element is in + + auto localBlockIndex = blockIndex.load(std::memory_order_acquire); + auto localBlockIndexHead = localBlockIndex->front.load(std::memory_order_acquire); + + // We need to be careful here about subtracting and dividing because of index wrap-around. + // When an index wraps, we need to preserve the sign of the offset when dividing it by the + // block size (in order to get a correct signed block count offset in all cases): + auto headBase = localBlockIndex->entries[localBlockIndexHead].base; + auto blockBaseIndex = index & ~static_cast(BLOCK_SIZE - 1); + auto offset = static_cast(static_cast::type>(blockBaseIndex - headBase) / BLOCK_SIZE); + auto block = localBlockIndex->entries[(localBlockIndexHead + offset) & (localBlockIndex->size - 1)].block; + + // Dequeue + auto& el = *((*block)[index]); + if (!MOODYCAMEL_NOEXCEPT_ASSIGN(T, T&&, element = std::move(el))) { + // Make sure the element is still fully dequeued and destroyed even if the assignment + // throws + struct Guard { + Block* block; + index_t index; + + ~Guard() + { + (*block)[index]->~T(); + block->ConcurrentQueue::Block::template set_empty(index); + } + } guard = { block, index }; + + element = std::move(el); // NOLINT + } + else { + element = std::move(el); // NOLINT + el.~T(); // NOLINT + block->ConcurrentQueue::Block::template set_empty(index); + } + + return true; + } + else { + // Wasn't anything to dequeue after all; make the effective dequeue count eventually consistent + this->dequeueOvercommit.fetch_add(1, std::memory_order_release); // Release so that the fetch_add on dequeueOptimisticCount is guaranteed to happen before this write + } + } + + return false; + } + + template + bool enqueue_bulk(It itemFirst, size_t count) + { + // First, we need to make sure we have enough room to enqueue all of the elements; + // this means pre-allocating blocks and putting them in the block index (but only if + // all the allocations succeeded). + index_t startTailIndex = this->tailIndex.load(std::memory_order_relaxed); + auto startBlock = this->tailBlock; + auto originalBlockIndexFront = pr_blockIndexFront; + auto originalBlockIndexSlotsUsed = pr_blockIndexSlotsUsed; + + Block* firstAllocatedBlock = nullptr; + + // Figure out how many blocks we'll need to allocate, and do so + size_t blockBaseDiff = ((startTailIndex + count - 1) & ~static_cast(BLOCK_SIZE - 1)) - ((startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1)); + index_t currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + if (blockBaseDiff > 0) { + // Allocate as many blocks as possible from ahead + while (blockBaseDiff > 0 && this->tailBlock != nullptr && this->tailBlock->next != firstAllocatedBlock && this->tailBlock->next->ConcurrentQueue::Block::template is_empty()) { + blockBaseDiff -= static_cast(BLOCK_SIZE); + currentTailIndex += static_cast(BLOCK_SIZE); + + this->tailBlock = this->tailBlock->next; + firstAllocatedBlock = firstAllocatedBlock == nullptr ? this->tailBlock : firstAllocatedBlock; + + auto& entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; + entry.base = currentTailIndex; + entry.block = this->tailBlock; + pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); + } + + // Now allocate as many blocks as necessary from the block pool + while (blockBaseDiff > 0) { + blockBaseDiff -= static_cast(BLOCK_SIZE); + currentTailIndex += static_cast(BLOCK_SIZE); + + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + bool full = !details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && (MAX_SUBQUEUE_SIZE == 0 || MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head)); + if (pr_blockIndexRaw == nullptr || pr_blockIndexSlotsUsed == pr_blockIndexSize || full) { + MOODYCAMEL_CONSTEXPR_IF (allocMode == CannotAlloc) { + // Failed to allocate, undo changes (but keep injected blocks) + pr_blockIndexFront = originalBlockIndexFront; + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; + return false; + } + else if (full || !new_block_index(originalBlockIndexSlotsUsed)) { + // Failed to allocate, undo changes (but keep injected blocks) + pr_blockIndexFront = originalBlockIndexFront; + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; + return false; + } + + // pr_blockIndexFront is updated inside new_block_index, so we need to + // update our fallback value too (since we keep the new index even if we + // later fail) + originalBlockIndexFront = originalBlockIndexSlotsUsed; + } + + // Insert a new block in the circular linked list + auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); + if (newBlock == nullptr) { + pr_blockIndexFront = originalBlockIndexFront; + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; + return false; + } + +#ifdef MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template set_all_empty(); + if (this->tailBlock == nullptr) { + newBlock->next = newBlock; + } + else { + newBlock->next = this->tailBlock->next; + this->tailBlock->next = newBlock; + } + this->tailBlock = newBlock; + firstAllocatedBlock = firstAllocatedBlock == nullptr ? this->tailBlock : firstAllocatedBlock; + + ++pr_blockIndexSlotsUsed; + + auto& entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; + entry.base = currentTailIndex; + entry.block = this->tailBlock; + pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); + } + + // Excellent, all allocations succeeded. Reset each block's emptiness before we fill them up, and + // publish the new block index front + auto block = firstAllocatedBlock; + while (true) { + block->ConcurrentQueue::Block::template reset_empty(); + if (block == this->tailBlock) { + break; + } + block = block->next; + } + + if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), new ((T*)nullptr) T(details::deref_noexcept(itemFirst)))) { + blockIndex.load(std::memory_order_relaxed)->front.store((pr_blockIndexFront - 1) & (pr_blockIndexSize - 1), std::memory_order_release); + } + } + + // Enqueue, one block at a time + index_t newTailIndex = startTailIndex + static_cast(count); + currentTailIndex = startTailIndex; + auto endBlock = this->tailBlock; + this->tailBlock = startBlock; + assert((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || firstAllocatedBlock != nullptr || count == 0); + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0 && firstAllocatedBlock != nullptr) { + this->tailBlock = firstAllocatedBlock; + } + while (true) { + auto stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + if (details::circular_less_than(newTailIndex, stopIndex)) { + stopIndex = newTailIndex; + } + if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), new ((T*)nullptr) T(details::deref_noexcept(itemFirst)))) { + while (currentTailIndex != stopIndex) { + new ((*this->tailBlock)[currentTailIndex++]) T(*itemFirst++); + } + } + else { + MOODYCAMEL_TRY { + while (currentTailIndex != stopIndex) { + // Must use copy constructor even if move constructor is available + // because we may have to revert if there's an exception. + // Sorry about the horrible templated next line, but it was the only way + // to disable moving *at compile time*, which is important because a type + // may only define a (noexcept) move constructor, and so calls to the + // cctor will not compile, even if they are in an if branch that will never + // be executed + new ((*this->tailBlock)[currentTailIndex]) T(details::nomove_if<(bool)!MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), new ((T*)nullptr) T(details::deref_noexcept(itemFirst)))>::eval(*itemFirst)); + ++currentTailIndex; + ++itemFirst; + } + } + MOODYCAMEL_CATCH (...) { + // Oh dear, an exception's been thrown -- destroy the elements that + // were enqueued so far and revert the entire bulk operation (we'll keep + // any allocated blocks in our linked list for later, though). + auto constructedStopIndex = currentTailIndex; + auto lastBlockEnqueued = this->tailBlock; + + pr_blockIndexFront = originalBlockIndexFront; + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; + + if (!details::is_trivially_destructible::value) { + auto block = startBlock; + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + block = firstAllocatedBlock; + } + currentTailIndex = startTailIndex; + while (true) { + stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + if (details::circular_less_than(constructedStopIndex, stopIndex)) { + stopIndex = constructedStopIndex; + } + while (currentTailIndex != stopIndex) { + (*block)[currentTailIndex++]->~T(); + } + if (block == lastBlockEnqueued) { + break; + } + block = block->next; + } + } + MOODYCAMEL_RETHROW; + } + } + + if (this->tailBlock == endBlock) { + assert(currentTailIndex == newTailIndex); + break; + } + this->tailBlock = this->tailBlock->next; + } + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), new ((T*)nullptr) T(details::deref_noexcept(itemFirst))) && firstAllocatedBlock != nullptr) { + blockIndex.load(std::memory_order_relaxed)->front.store((pr_blockIndexFront - 1) & (pr_blockIndexSize - 1), std::memory_order_release); + } + + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + size_t dequeue_bulk(It& itemFirst, size_t max) + { + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + auto desiredCount = static_cast(tail - (this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit)); + if (details::circular_less_than(0, desiredCount)) { + desiredCount = desiredCount < max ? desiredCount : max; + std::atomic_thread_fence(std::memory_order_acquire); + + auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(desiredCount, std::memory_order_relaxed);; + + tail = this->tailIndex.load(std::memory_order_acquire); + auto actualCount = static_cast(tail - (myDequeueCount - overcommit)); + if (details::circular_less_than(0, actualCount)) { + actualCount = desiredCount < actualCount ? desiredCount : actualCount; + if (actualCount < desiredCount) { + this->dequeueOvercommit.fetch_add(desiredCount - actualCount, std::memory_order_release); + } + + // Get the first index. Note that since there's guaranteed to be at least actualCount elements, this + // will never exceed tail. + auto firstIndex = this->headIndex.fetch_add(actualCount, std::memory_order_acq_rel); + + // Determine which block the first element is in + auto localBlockIndex = blockIndex.load(std::memory_order_acquire); + auto localBlockIndexHead = localBlockIndex->front.load(std::memory_order_acquire); + + auto headBase = localBlockIndex->entries[localBlockIndexHead].base; + auto firstBlockBaseIndex = firstIndex & ~static_cast(BLOCK_SIZE - 1); + auto offset = static_cast(static_cast::type>(firstBlockBaseIndex - headBase) / BLOCK_SIZE); + auto indexIndex = (localBlockIndexHead + offset) & (localBlockIndex->size - 1); + + // Iterate the blocks and dequeue + auto index = firstIndex; + do { + auto firstIndexInBlock = index; + auto endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than(firstIndex + static_cast(actualCount), endIndex) ? firstIndex + static_cast(actualCount) : endIndex; + auto block = localBlockIndex->entries[indexIndex].block; + if (MOODYCAMEL_NOEXCEPT_ASSIGN(T, T&&, details::deref_noexcept(itemFirst) = std::move((*(*block)[index])))) { + while (index != endIndex) { + auto& el = *((*block)[index]); + *itemFirst++ = std::move(el); + el.~T(); + ++index; + } + } + else { + MOODYCAMEL_TRY { + while (index != endIndex) { + auto& el = *((*block)[index]); + *itemFirst = std::move(el); + ++itemFirst; + el.~T(); + ++index; + } + } + MOODYCAMEL_CATCH (...) { + // It's too late to revert the dequeue, but we can make sure that all + // the dequeued objects are properly destroyed and the block index + // (and empty count) are properly updated before we propagate the exception + do { + block = localBlockIndex->entries[indexIndex].block; + while (index != endIndex) { + (*block)[index++]->~T(); + } + block->ConcurrentQueue::Block::template set_many_empty(firstIndexInBlock, static_cast(endIndex - firstIndexInBlock)); + indexIndex = (indexIndex + 1) & (localBlockIndex->size - 1); + + firstIndexInBlock = index; + endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than(firstIndex + static_cast(actualCount), endIndex) ? firstIndex + static_cast(actualCount) : endIndex; + } while (index != firstIndex + actualCount); + + MOODYCAMEL_RETHROW; + } + } + block->ConcurrentQueue::Block::template set_many_empty(firstIndexInBlock, static_cast(endIndex - firstIndexInBlock)); + indexIndex = (indexIndex + 1) & (localBlockIndex->size - 1); + } while (index != firstIndex + actualCount); + + return actualCount; + } + else { + // Wasn't anything to dequeue after all; make the effective dequeue count eventually consistent + this->dequeueOvercommit.fetch_add(desiredCount, std::memory_order_release); + } + } + + return 0; + } + + private: + struct BlockIndexEntry + { + index_t base; + Block* block; + }; + + struct BlockIndexHeader + { + size_t size; + std::atomic front; // Current slot (not next, like pr_blockIndexFront) + BlockIndexEntry* entries; + void* prev; + }; + + + bool new_block_index(size_t numberOfFilledSlotsToExpose) + { + auto prevBlockSizeMask = pr_blockIndexSize - 1; + + // Create the new block + pr_blockIndexSize <<= 1; + auto newRawPtr = static_cast((Traits::malloc)(sizeof(BlockIndexHeader) + std::alignment_of::value - 1 + sizeof(BlockIndexEntry) * pr_blockIndexSize)); + if (newRawPtr == nullptr) { + pr_blockIndexSize >>= 1; // Reset to allow graceful retry + return false; + } + + auto newBlockIndexEntries = reinterpret_cast(details::align_for(newRawPtr + sizeof(BlockIndexHeader))); + + // Copy in all the old indices, if any + size_t j = 0; + if (pr_blockIndexSlotsUsed != 0) { + auto i = (pr_blockIndexFront - pr_blockIndexSlotsUsed) & prevBlockSizeMask; + do { + newBlockIndexEntries[j++] = pr_blockIndexEntries[i]; + i = (i + 1) & prevBlockSizeMask; + } while (i != pr_blockIndexFront); + } + + // Update everything + auto header = new (newRawPtr) BlockIndexHeader; + header->size = pr_blockIndexSize; + header->front.store(numberOfFilledSlotsToExpose - 1, std::memory_order_relaxed); + header->entries = newBlockIndexEntries; + header->prev = pr_blockIndexRaw; // we link the new block to the old one so we can free it later + + pr_blockIndexFront = j; + pr_blockIndexEntries = newBlockIndexEntries; + pr_blockIndexRaw = newRawPtr; + blockIndex.store(header, std::memory_order_release); + + return true; + } + + private: + std::atomic blockIndex; + + // To be used by producer only -- consumer must use the ones in referenced by blockIndex + size_t pr_blockIndexSlotsUsed; + size_t pr_blockIndexSize; + size_t pr_blockIndexFront; // Next slot (not current) + BlockIndexEntry* pr_blockIndexEntries; + void* pr_blockIndexRaw; + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + public: + ExplicitProducer* nextExplicitProducer; + private: +#endif + +#ifdef MCDBGQ_TRACKMEM + friend struct MemStats; +#endif + }; + + + ////////////////////////////////// + // Implicit queue + ////////////////////////////////// + + struct ImplicitProducer : public ProducerBase + { + ImplicitProducer(ConcurrentQueue* parent_) : + ProducerBase(parent_, false), + nextBlockIndexCapacity(IMPLICIT_INITIAL_INDEX_SIZE), + blockIndex(nullptr) + { + new_block_index(); + } + + ~ImplicitProducer() + { + // Note that since we're in the destructor we can assume that all enqueue/dequeue operations + // completed already; this means that all undequeued elements are placed contiguously across + // contiguous blocks, and that only the first and last remaining blocks can be only partially + // empty (all other remaining blocks must be completely full). + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + // Unregister ourselves for thread termination notification + if (!this->inactive.load(std::memory_order_relaxed)) { + details::ThreadExitNotifier::unsubscribe(&threadExitListener); + } +#endif + + // Destroy all remaining elements! + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto index = this->headIndex.load(std::memory_order_relaxed); + Block* block = nullptr; + assert(index == tail || details::circular_less_than(index, tail)); + bool forceFreeLastBlock = index != tail; // If we enter the loop, then the last (tail) block will not be freed + while (index != tail) { + if ((index & static_cast(BLOCK_SIZE - 1)) == 0 || block == nullptr) { + if (block != nullptr) { + // Free the old block + this->parent->add_block_to_free_list(block); + } + + block = get_block_index_entry_for_index(index)->value.load(std::memory_order_relaxed); + } + + ((*block)[index])->~T(); + ++index; + } + // Even if the queue is empty, there's still one block that's not on the free list + // (unless the head index reached the end of it, in which case the tail will be poised + // to create a new block). + if (this->tailBlock != nullptr && (forceFreeLastBlock || (tail & static_cast(BLOCK_SIZE - 1)) != 0)) { + this->parent->add_block_to_free_list(this->tailBlock); + } + + // Destroy block index + auto localBlockIndex = blockIndex.load(std::memory_order_relaxed); + if (localBlockIndex != nullptr) { + for (size_t i = 0; i != localBlockIndex->capacity; ++i) { + localBlockIndex->index[i]->~BlockIndexEntry(); + } + do { + auto prev = localBlockIndex->prev; + localBlockIndex->~BlockIndexHeader(); + (Traits::free)(localBlockIndex); + localBlockIndex = prev; + } while (localBlockIndex != nullptr); + } + } + + template + inline bool enqueue(U&& element) + { + index_t currentTailIndex = this->tailIndex.load(std::memory_order_relaxed); + index_t newTailIndex = 1 + currentTailIndex; + if ((currentTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + // We reached the end of a block, start a new one + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + if (!details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && (MAX_SUBQUEUE_SIZE == 0 || MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head))) { + return false; + } +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + // Find out where we'll be inserting this block in the block index + BlockIndexEntry* idxEntry; + if (!insert_block_index_entry(idxEntry, currentTailIndex)) { + return false; + } + + // Get ahold of a new block + auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); + if (newBlock == nullptr) { + rewind_block_index_tail(); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + return false; + } +#ifdef MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template reset_empty(); + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new ((T*)nullptr) T(std::forward(element)))) { + // May throw, try to insert now before we publish the fact that we have this new block + MOODYCAMEL_TRY { + new ((*newBlock)[currentTailIndex]) T(std::forward(element)); + } + MOODYCAMEL_CATCH (...) { + rewind_block_index_tail(); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + this->parent->add_block_to_free_list(newBlock); + MOODYCAMEL_RETHROW; + } + } + + // Insert the new block into the index + idxEntry->value.store(newBlock, std::memory_order_relaxed); + + this->tailBlock = newBlock; + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new ((T*)nullptr) T(std::forward(element)))) { + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + } + + // Enqueue + new ((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); + + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + bool dequeue(U& element) + { + // See ExplicitProducer::dequeue for rationale and explanation + index_t tail = this->tailIndex.load(std::memory_order_relaxed); + index_t overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + if (details::circular_less_than(this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit, tail)) { + std::atomic_thread_fence(std::memory_order_acquire); + + index_t myDequeueCount = this->dequeueOptimisticCount.fetch_add(1, std::memory_order_relaxed); + tail = this->tailIndex.load(std::memory_order_acquire); + if ((details::likely)(details::circular_less_than(myDequeueCount - overcommit, tail))) { + index_t index = this->headIndex.fetch_add(1, std::memory_order_acq_rel); + + // Determine which block the element is in + auto entry = get_block_index_entry_for_index(index); + + // Dequeue + auto block = entry->value.load(std::memory_order_relaxed); + auto& el = *((*block)[index]); + + if (!MOODYCAMEL_NOEXCEPT_ASSIGN(T, T&&, element = std::move(el))) { +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + // Note: Acquiring the mutex with every dequeue instead of only when a block + // is released is very sub-optimal, but it is, after all, purely debug code. + debug::DebugLock lock(producer->mutex); +#endif + struct Guard { + Block* block; + index_t index; + BlockIndexEntry* entry; + ConcurrentQueue* parent; + + ~Guard() + { + (*block)[index]->~T(); + if (block->ConcurrentQueue::Block::template set_empty(index)) { + entry->value.store(nullptr, std::memory_order_relaxed); + parent->add_block_to_free_list(block); + } + } + } guard = { block, index, entry, this->parent }; + + element = std::move(el); // NOLINT + } + else { + element = std::move(el); // NOLINT + el.~T(); // NOLINT + + if (block->ConcurrentQueue::Block::template set_empty(index)) { + { +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + // Add the block back into the global free pool (and remove from block index) + entry->value.store(nullptr, std::memory_order_relaxed); + } + this->parent->add_block_to_free_list(block); // releases the above store + } + } + + return true; + } + else { + this->dequeueOvercommit.fetch_add(1, std::memory_order_release); + } + } + + return false; + } + + template + bool enqueue_bulk(It itemFirst, size_t count) + { + // First, we need to make sure we have enough room to enqueue all of the elements; + // this means pre-allocating blocks and putting them in the block index (but only if + // all the allocations succeeded). + + // Note that the tailBlock we start off with may not be owned by us any more; + // this happens if it was filled up exactly to the top (setting tailIndex to + // the first index of the next block which is not yet allocated), then dequeued + // completely (putting it on the free list) before we enqueue again. + + index_t startTailIndex = this->tailIndex.load(std::memory_order_relaxed); + auto startBlock = this->tailBlock; + Block* firstAllocatedBlock = nullptr; + auto endBlock = this->tailBlock; + + // Figure out how many blocks we'll need to allocate, and do so + size_t blockBaseDiff = ((startTailIndex + count - 1) & ~static_cast(BLOCK_SIZE - 1)) - ((startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1)); + index_t currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + if (blockBaseDiff > 0) { +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + do { + blockBaseDiff -= static_cast(BLOCK_SIZE); + currentTailIndex += static_cast(BLOCK_SIZE); + + // Find out where we'll be inserting this block in the block index + BlockIndexEntry* idxEntry = nullptr; // initialization here unnecessary but compiler can't always tell + Block* newBlock; + bool indexInserted = false; + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + bool full = !details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && (MAX_SUBQUEUE_SIZE == 0 || MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head)); + if (full || !(indexInserted = insert_block_index_entry(idxEntry, currentTailIndex)) || (newBlock = this->parent->ConcurrentQueue::template requisition_block()) == nullptr) { + // Index allocation or block allocation failed; revert any other allocations + // and index insertions done so far for this operation + if (indexInserted) { + rewind_block_index_tail(); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + } + currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + for (auto block = firstAllocatedBlock; block != nullptr; block = block->next) { + currentTailIndex += static_cast(BLOCK_SIZE); + idxEntry = get_block_index_entry_for_index(currentTailIndex); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + rewind_block_index_tail(); + } + this->parent->add_blocks_to_free_list(firstAllocatedBlock); + this->tailBlock = startBlock; + + return false; + } + +#ifdef MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template reset_empty(); + newBlock->next = nullptr; + + // Insert the new block into the index + idxEntry->value.store(newBlock, std::memory_order_relaxed); + + // Store the chain of blocks so that we can undo if later allocations fail, + // and so that we can find the blocks when we do the actual enqueueing + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || firstAllocatedBlock != nullptr) { + assert(this->tailBlock != nullptr); + this->tailBlock->next = newBlock; + } + this->tailBlock = newBlock; + endBlock = newBlock; + firstAllocatedBlock = firstAllocatedBlock == nullptr ? newBlock : firstAllocatedBlock; + } while (blockBaseDiff > 0); + } + + // Enqueue, one block at a time + index_t newTailIndex = startTailIndex + static_cast(count); + currentTailIndex = startTailIndex; + this->tailBlock = startBlock; + assert((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || firstAllocatedBlock != nullptr || count == 0); + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0 && firstAllocatedBlock != nullptr) { + this->tailBlock = firstAllocatedBlock; + } + while (true) { + auto stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + if (details::circular_less_than(newTailIndex, stopIndex)) { + stopIndex = newTailIndex; + } + if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), new ((T*)nullptr) T(details::deref_noexcept(itemFirst)))) { + while (currentTailIndex != stopIndex) { + new ((*this->tailBlock)[currentTailIndex++]) T(*itemFirst++); + } + } + else { + MOODYCAMEL_TRY { + while (currentTailIndex != stopIndex) { + new ((*this->tailBlock)[currentTailIndex]) T(details::nomove_if<(bool)!MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), new ((T*)nullptr) T(details::deref_noexcept(itemFirst)))>::eval(*itemFirst)); + ++currentTailIndex; + ++itemFirst; + } + } + MOODYCAMEL_CATCH (...) { + auto constructedStopIndex = currentTailIndex; + auto lastBlockEnqueued = this->tailBlock; + + if (!details::is_trivially_destructible::value) { + auto block = startBlock; + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + block = firstAllocatedBlock; + } + currentTailIndex = startTailIndex; + while (true) { + stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + if (details::circular_less_than(constructedStopIndex, stopIndex)) { + stopIndex = constructedStopIndex; + } + while (currentTailIndex != stopIndex) { + (*block)[currentTailIndex++]->~T(); + } + if (block == lastBlockEnqueued) { + break; + } + block = block->next; + } + } + + currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + for (auto block = firstAllocatedBlock; block != nullptr; block = block->next) { + currentTailIndex += static_cast(BLOCK_SIZE); + auto idxEntry = get_block_index_entry_for_index(currentTailIndex); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + rewind_block_index_tail(); + } + this->parent->add_blocks_to_free_list(firstAllocatedBlock); + this->tailBlock = startBlock; + MOODYCAMEL_RETHROW; + } + } + + if (this->tailBlock == endBlock) { + assert(currentTailIndex == newTailIndex); + break; + } + this->tailBlock = this->tailBlock->next; + } + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + size_t dequeue_bulk(It& itemFirst, size_t max) + { + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + auto desiredCount = static_cast(tail - (this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit)); + if (details::circular_less_than(0, desiredCount)) { + desiredCount = desiredCount < max ? desiredCount : max; + std::atomic_thread_fence(std::memory_order_acquire); + + auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(desiredCount, std::memory_order_relaxed); + + tail = this->tailIndex.load(std::memory_order_acquire); + auto actualCount = static_cast(tail - (myDequeueCount - overcommit)); + if (details::circular_less_than(0, actualCount)) { + actualCount = desiredCount < actualCount ? desiredCount : actualCount; + if (actualCount < desiredCount) { + this->dequeueOvercommit.fetch_add(desiredCount - actualCount, std::memory_order_release); + } + + // Get the first index. Note that since there's guaranteed to be at least actualCount elements, this + // will never exceed tail. + auto firstIndex = this->headIndex.fetch_add(actualCount, std::memory_order_acq_rel); + + // Iterate the blocks and dequeue + auto index = firstIndex; + BlockIndexHeader* localBlockIndex; + auto indexIndex = get_block_index_index_for_index(index, localBlockIndex); + do { + auto blockStartIndex = index; + auto endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than(firstIndex + static_cast(actualCount), endIndex) ? firstIndex + static_cast(actualCount) : endIndex; + + auto entry = localBlockIndex->index[indexIndex]; + auto block = entry->value.load(std::memory_order_relaxed); + if (MOODYCAMEL_NOEXCEPT_ASSIGN(T, T&&, details::deref_noexcept(itemFirst) = std::move((*(*block)[index])))) { + while (index != endIndex) { + auto& el = *((*block)[index]); + *itemFirst++ = std::move(el); + el.~T(); + ++index; + } + } + else { + MOODYCAMEL_TRY { + while (index != endIndex) { + auto& el = *((*block)[index]); + *itemFirst = std::move(el); + ++itemFirst; + el.~T(); + ++index; + } + } + MOODYCAMEL_CATCH (...) { + do { + entry = localBlockIndex->index[indexIndex]; + block = entry->value.load(std::memory_order_relaxed); + while (index != endIndex) { + (*block)[index++]->~T(); + } + + if (block->ConcurrentQueue::Block::template set_many_empty(blockStartIndex, static_cast(endIndex - blockStartIndex))) { +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + entry->value.store(nullptr, std::memory_order_relaxed); + this->parent->add_block_to_free_list(block); + } + indexIndex = (indexIndex + 1) & (localBlockIndex->capacity - 1); + + blockStartIndex = index; + endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than(firstIndex + static_cast(actualCount), endIndex) ? firstIndex + static_cast(actualCount) : endIndex; + } while (index != firstIndex + actualCount); + + MOODYCAMEL_RETHROW; + } + } + if (block->ConcurrentQueue::Block::template set_many_empty(blockStartIndex, static_cast(endIndex - blockStartIndex))) { + { +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + // Note that the set_many_empty above did a release, meaning that anybody who acquires the block + // we're about to free can use it safely since our writes (and reads!) will have happened-before then. + entry->value.store(nullptr, std::memory_order_relaxed); + } + this->parent->add_block_to_free_list(block); // releases the above store + } + indexIndex = (indexIndex + 1) & (localBlockIndex->capacity - 1); + } while (index != firstIndex + actualCount); + + return actualCount; + } + else { + this->dequeueOvercommit.fetch_add(desiredCount, std::memory_order_release); + } + } + + return 0; + } + + private: + // The block size must be > 1, so any number with the low bit set is an invalid block base index + static const index_t INVALID_BLOCK_BASE = 1; + + struct BlockIndexEntry + { + std::atomic key; + std::atomic value; + }; + + struct BlockIndexHeader + { + size_t capacity; + std::atomic tail; + BlockIndexEntry* entries; + BlockIndexEntry** index; + BlockIndexHeader* prev; + }; + + template + inline bool insert_block_index_entry(BlockIndexEntry*& idxEntry, index_t blockStartIndex) + { + auto localBlockIndex = blockIndex.load(std::memory_order_relaxed); // We're the only writer thread, relaxed is OK + if (localBlockIndex == nullptr) { + return false; // this can happen if new_block_index failed in the constructor + } + auto newTail = (localBlockIndex->tail.load(std::memory_order_relaxed) + 1) & (localBlockIndex->capacity - 1); + idxEntry = localBlockIndex->index[newTail]; + if (idxEntry->key.load(std::memory_order_relaxed) == INVALID_BLOCK_BASE || + idxEntry->value.load(std::memory_order_relaxed) == nullptr) { + + idxEntry->key.store(blockStartIndex, std::memory_order_relaxed); + localBlockIndex->tail.store(newTail, std::memory_order_release); + return true; + } + + // No room in the old block index, try to allocate another one! + MOODYCAMEL_CONSTEXPR_IF (allocMode == CannotAlloc) { + return false; + } + else if (!new_block_index()) { + return false; + } + localBlockIndex = blockIndex.load(std::memory_order_relaxed); + newTail = (localBlockIndex->tail.load(std::memory_order_relaxed) + 1) & (localBlockIndex->capacity - 1); + idxEntry = localBlockIndex->index[newTail]; + assert(idxEntry->key.load(std::memory_order_relaxed) == INVALID_BLOCK_BASE); + idxEntry->key.store(blockStartIndex, std::memory_order_relaxed); + localBlockIndex->tail.store(newTail, std::memory_order_release); + return true; + } + + inline void rewind_block_index_tail() + { + auto localBlockIndex = blockIndex.load(std::memory_order_relaxed); + localBlockIndex->tail.store((localBlockIndex->tail.load(std::memory_order_relaxed) - 1) & (localBlockIndex->capacity - 1), std::memory_order_relaxed); + } + + inline BlockIndexEntry* get_block_index_entry_for_index(index_t index) const + { + BlockIndexHeader* localBlockIndex; + auto idx = get_block_index_index_for_index(index, localBlockIndex); + return localBlockIndex->index[idx]; + } + + inline size_t get_block_index_index_for_index(index_t index, BlockIndexHeader*& localBlockIndex) const + { +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + index &= ~static_cast(BLOCK_SIZE - 1); + localBlockIndex = blockIndex.load(std::memory_order_acquire); + auto tail = localBlockIndex->tail.load(std::memory_order_acquire); + auto tailBase = localBlockIndex->index[tail]->key.load(std::memory_order_relaxed); + assert(tailBase != INVALID_BLOCK_BASE); + // Note: Must use division instead of shift because the index may wrap around, causing a negative + // offset, whose negativity we want to preserve + auto offset = static_cast(static_cast::type>(index - tailBase) / BLOCK_SIZE); + size_t idx = (tail + offset) & (localBlockIndex->capacity - 1); + assert(localBlockIndex->index[idx]->key.load(std::memory_order_relaxed) == index && localBlockIndex->index[idx]->value.load(std::memory_order_relaxed) != nullptr); + return idx; + } + + bool new_block_index() + { + auto prev = blockIndex.load(std::memory_order_relaxed); + size_t prevCapacity = prev == nullptr ? 0 : prev->capacity; + auto entryCount = prev == nullptr ? nextBlockIndexCapacity : prevCapacity; + auto raw = static_cast((Traits::malloc)( + sizeof(BlockIndexHeader) + + std::alignment_of::value - 1 + sizeof(BlockIndexEntry) * entryCount + + std::alignment_of::value - 1 + sizeof(BlockIndexEntry*) * nextBlockIndexCapacity)); + if (raw == nullptr) { + return false; + } + + auto header = new (raw) BlockIndexHeader; + auto entries = reinterpret_cast(details::align_for(raw + sizeof(BlockIndexHeader))); + auto index = reinterpret_cast(details::align_for(reinterpret_cast(entries) + sizeof(BlockIndexEntry) * entryCount)); + if (prev != nullptr) { + auto prevTail = prev->tail.load(std::memory_order_relaxed); + auto prevPos = prevTail; + size_t i = 0; + do { + prevPos = (prevPos + 1) & (prev->capacity - 1); + index[i++] = prev->index[prevPos]; + } while (prevPos != prevTail); + assert(i == prevCapacity); + } + for (size_t i = 0; i != entryCount; ++i) { + new (entries + i) BlockIndexEntry; + entries[i].key.store(INVALID_BLOCK_BASE, std::memory_order_relaxed); + index[prevCapacity + i] = entries + i; + } + header->prev = prev; + header->entries = entries; + header->index = index; + header->capacity = nextBlockIndexCapacity; + header->tail.store((prevCapacity - 1) & (nextBlockIndexCapacity - 1), std::memory_order_relaxed); + + blockIndex.store(header, std::memory_order_release); + + nextBlockIndexCapacity <<= 1; + + return true; + } + + private: + size_t nextBlockIndexCapacity; + std::atomic blockIndex; + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + public: + details::ThreadExitListener threadExitListener; + private: +#endif + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + public: + ImplicitProducer* nextImplicitProducer; + private: +#endif + +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + mutable debug::DebugMutex mutex; +#endif +#ifdef MCDBGQ_TRACKMEM + friend struct MemStats; +#endif + }; + + + ////////////////////////////////// + // Block pool manipulation + ////////////////////////////////// + + void populate_initial_block_list(size_t blockCount) + { + initialBlockPoolSize = blockCount; + if (initialBlockPoolSize == 0) { + initialBlockPool = nullptr; + return; + } + + initialBlockPool = create_array(blockCount); + if (initialBlockPool == nullptr) { + initialBlockPoolSize = 0; + } + for (size_t i = 0; i < initialBlockPoolSize; ++i) { + initialBlockPool[i].dynamicallyAllocated = false; + } + } + + inline Block* try_get_block_from_initial_pool() + { + if (initialBlockPoolIndex.load(std::memory_order_relaxed) >= initialBlockPoolSize) { + return nullptr; + } + + auto index = initialBlockPoolIndex.fetch_add(1, std::memory_order_relaxed); + + return index < initialBlockPoolSize ? (initialBlockPool + index) : nullptr; + } + + inline void add_block_to_free_list(Block* block) + { +#ifdef MCDBGQ_TRACKMEM + block->owner = nullptr; +#endif + freeList.add(block); + } + + inline void add_blocks_to_free_list(Block* block) + { + while (block != nullptr) { + auto next = block->next; + add_block_to_free_list(block); + block = next; + } + } + + inline Block* try_get_block_from_free_list() + { + return freeList.try_get(); + } + + // Gets a free block from one of the memory pools, or allocates a new one (if applicable) + template + Block* requisition_block() + { + auto block = try_get_block_from_initial_pool(); + if (block != nullptr) { + return block; + } + + block = try_get_block_from_free_list(); + if (block != nullptr) { + return block; + } + + MOODYCAMEL_CONSTEXPR_IF (canAlloc == CanAlloc) { + return create(); + } + else { + return nullptr; + } + } + + +#ifdef MCDBGQ_TRACKMEM + public: + struct MemStats { + size_t allocatedBlocks; + size_t usedBlocks; + size_t freeBlocks; + size_t ownedBlocksExplicit; + size_t ownedBlocksImplicit; + size_t implicitProducers; + size_t explicitProducers; + size_t elementsEnqueued; + size_t blockClassBytes; + size_t queueClassBytes; + size_t implicitBlockIndexBytes; + size_t explicitBlockIndexBytes; + + friend class ConcurrentQueue; + + private: + static MemStats getFor(ConcurrentQueue* q) + { + MemStats stats = { 0 }; + + stats.elementsEnqueued = q->size_approx(); + + auto block = q->freeList.head_unsafe(); + while (block != nullptr) { + ++stats.allocatedBlocks; + ++stats.freeBlocks; + block = block->freeListNext.load(std::memory_order_relaxed); + } + + for (auto ptr = q->producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { + bool implicit = dynamic_cast(ptr) != nullptr; + stats.implicitProducers += implicit ? 1 : 0; + stats.explicitProducers += implicit ? 0 : 1; + + if (implicit) { + auto prod = static_cast(ptr); + stats.queueClassBytes += sizeof(ImplicitProducer); + auto head = prod->headIndex.load(std::memory_order_relaxed); + auto tail = prod->tailIndex.load(std::memory_order_relaxed); + auto hash = prod->blockIndex.load(std::memory_order_relaxed); + if (hash != nullptr) { + for (size_t i = 0; i != hash->capacity; ++i) { + if (hash->index[i]->key.load(std::memory_order_relaxed) != ImplicitProducer::INVALID_BLOCK_BASE && hash->index[i]->value.load(std::memory_order_relaxed) != nullptr) { + ++stats.allocatedBlocks; + ++stats.ownedBlocksImplicit; + } + } + stats.implicitBlockIndexBytes += hash->capacity * sizeof(typename ImplicitProducer::BlockIndexEntry); + for (; hash != nullptr; hash = hash->prev) { + stats.implicitBlockIndexBytes += sizeof(typename ImplicitProducer::BlockIndexHeader) + hash->capacity * sizeof(typename ImplicitProducer::BlockIndexEntry*); + } + } + for (; details::circular_less_than(head, tail); head += BLOCK_SIZE) { + //auto block = prod->get_block_index_entry_for_index(head); + ++stats.usedBlocks; + } + } + else { + auto prod = static_cast(ptr); + stats.queueClassBytes += sizeof(ExplicitProducer); + auto tailBlock = prod->tailBlock; + bool wasNonEmpty = false; + if (tailBlock != nullptr) { + auto block = tailBlock; + do { + ++stats.allocatedBlocks; + if (!block->ConcurrentQueue::Block::template is_empty() || wasNonEmpty) { + ++stats.usedBlocks; + wasNonEmpty = wasNonEmpty || block != tailBlock; + } + ++stats.ownedBlocksExplicit; + block = block->next; + } while (block != tailBlock); + } + auto index = prod->blockIndex.load(std::memory_order_relaxed); + while (index != nullptr) { + stats.explicitBlockIndexBytes += sizeof(typename ExplicitProducer::BlockIndexHeader) + index->size * sizeof(typename ExplicitProducer::BlockIndexEntry); + index = static_cast(index->prev); + } + } + } + + auto freeOnInitialPool = q->initialBlockPoolIndex.load(std::memory_order_relaxed) >= q->initialBlockPoolSize ? 0 : q->initialBlockPoolSize - q->initialBlockPoolIndex.load(std::memory_order_relaxed); + stats.allocatedBlocks += freeOnInitialPool; + stats.freeBlocks += freeOnInitialPool; + + stats.blockClassBytes = sizeof(Block) * stats.allocatedBlocks; + stats.queueClassBytes += sizeof(ConcurrentQueue); + + return stats; + } + }; + + // For debugging only. Not thread-safe. + MemStats getMemStats() + { + return MemStats::getFor(this); + } + private: + friend struct MemStats; +#endif + + + ////////////////////////////////// + // Producer list manipulation + ////////////////////////////////// + + ProducerBase* recycle_or_create_producer(bool isExplicit) + { + bool recycled; + return recycle_or_create_producer(isExplicit, recycled); + } + + ProducerBase* recycle_or_create_producer(bool isExplicit, bool& recycled) + { +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugLock lock(implicitProdMutex); +#endif + // Try to re-use one first + for (auto ptr = producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { + if (ptr->inactive.load(std::memory_order_relaxed) && ptr->isExplicit == isExplicit) { + bool expected = true; + if (ptr->inactive.compare_exchange_strong(expected, /* desired */ false, std::memory_order_acquire, std::memory_order_relaxed)) { + // We caught one! It's been marked as activated, the caller can have it + recycled = true; + return ptr; + } + } + } + + recycled = false; + return add_producer(isExplicit ? static_cast(create(this)) : create(this)); + } + + ProducerBase* add_producer(ProducerBase* producer) + { + // Handle failed memory allocation + if (producer == nullptr) { + return nullptr; + } + + producerCount.fetch_add(1, std::memory_order_relaxed); + + // Add it to the lock-free list + auto prevTail = producerListTail.load(std::memory_order_relaxed); + do { + producer->next = prevTail; + } while (!producerListTail.compare_exchange_weak(prevTail, producer, std::memory_order_release, std::memory_order_relaxed)); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + if (producer->isExplicit) { + auto prevTailExplicit = explicitProducers.load(std::memory_order_relaxed); + do { + static_cast(producer)->nextExplicitProducer = prevTailExplicit; + } while (!explicitProducers.compare_exchange_weak(prevTailExplicit, static_cast(producer), std::memory_order_release, std::memory_order_relaxed)); + } + else { + auto prevTailImplicit = implicitProducers.load(std::memory_order_relaxed); + do { + static_cast(producer)->nextImplicitProducer = prevTailImplicit; + } while (!implicitProducers.compare_exchange_weak(prevTailImplicit, static_cast(producer), std::memory_order_release, std::memory_order_relaxed)); + } +#endif + + return producer; + } + + void reown_producers() + { + // After another instance is moved-into/swapped-with this one, all the + // producers we stole still think their parents are the other queue. + // So fix them up! + for (auto ptr = producerListTail.load(std::memory_order_relaxed); ptr != nullptr; ptr = ptr->next_prod()) { + ptr->parent = this; + } + } + + + ////////////////////////////////// + // Implicit producer hash + ////////////////////////////////// + + struct ImplicitProducerKVP + { + std::atomic key; + ImplicitProducer* value; // No need for atomicity since it's only read by the thread that sets it in the first place + + ImplicitProducerKVP() : value(nullptr) { } + + ImplicitProducerKVP(ImplicitProducerKVP&& other) MOODYCAMEL_NOEXCEPT + { + key.store(other.key.load(std::memory_order_relaxed), std::memory_order_relaxed); + value = other.value; + } + + inline ImplicitProducerKVP& operator=(ImplicitProducerKVP&& other) MOODYCAMEL_NOEXCEPT + { + swap(other); + return *this; + } + + inline void swap(ImplicitProducerKVP& other) MOODYCAMEL_NOEXCEPT + { + if (this != &other) { + details::swap_relaxed(key, other.key); + std::swap(value, other.value); + } + } + }; + + template + friend void duckdb_moodycamel::swap(typename ConcurrentQueue::ImplicitProducerKVP&, typename ConcurrentQueue::ImplicitProducerKVP&) MOODYCAMEL_NOEXCEPT; + + struct ImplicitProducerHash + { + size_t capacity; + ImplicitProducerKVP* entries; + ImplicitProducerHash* prev; + }; + + inline void populate_initial_implicit_producer_hash() + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) { + return; + } + else { + implicitProducerHashCount.store(0, std::memory_order_relaxed); + auto hash = &initialImplicitProducerHash; + hash->capacity = INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; + hash->entries = &initialImplicitProducerHashEntries[0]; + for (size_t i = 0; i != INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; ++i) { + initialImplicitProducerHashEntries[i].key.store(details::invalid_thread_id, std::memory_order_relaxed); + } + hash->prev = nullptr; + implicitProducerHash.store(hash, std::memory_order_relaxed); + } + } + + void swap_implicit_producer_hashes(ConcurrentQueue& other) + { + MOODYCAMEL_CONSTEXPR_IF (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) { + return; + } + else { + // Swap (assumes our implicit producer hash is initialized) + initialImplicitProducerHashEntries.swap(other.initialImplicitProducerHashEntries); + initialImplicitProducerHash.entries = &initialImplicitProducerHashEntries[0]; + other.initialImplicitProducerHash.entries = &other.initialImplicitProducerHashEntries[0]; + + details::swap_relaxed(implicitProducerHashCount, other.implicitProducerHashCount); + + details::swap_relaxed(implicitProducerHash, other.implicitProducerHash); + if (implicitProducerHash.load(std::memory_order_relaxed) == &other.initialImplicitProducerHash) { + implicitProducerHash.store(&initialImplicitProducerHash, std::memory_order_relaxed); + } + else { + ImplicitProducerHash* hash; + for (hash = implicitProducerHash.load(std::memory_order_relaxed); hash->prev != &other.initialImplicitProducerHash; hash = hash->prev) { + continue; + } + hash->prev = &initialImplicitProducerHash; + } + if (other.implicitProducerHash.load(std::memory_order_relaxed) == &initialImplicitProducerHash) { + other.implicitProducerHash.store(&other.initialImplicitProducerHash, std::memory_order_relaxed); + } + else { + ImplicitProducerHash* hash; + for (hash = other.implicitProducerHash.load(std::memory_order_relaxed); hash->prev != &initialImplicitProducerHash; hash = hash->prev) { + continue; + } + hash->prev = &other.initialImplicitProducerHash; + } + } + } + + // Only fails (returns nullptr) if memory allocation fails + ImplicitProducer* get_or_add_implicit_producer() + { + // Note that since the data is essentially thread-local (key is thread ID), + // there's a reduced need for fences (memory ordering is already consistent + // for any individual thread), except for the current table itself. + + // Start by looking for the thread ID in the current and all previous hash tables. + // If it's not found, it must not be in there yet, since this same thread would + // have added it previously to one of the tables that we traversed. + + // Code and algorithm adapted from http://preshing.com/20130605/the-worlds-simplest-lock-free-hash-table + +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugLock lock(implicitProdMutex); +#endif + + auto id = details::thread_id(); + auto hashedId = details::hash_thread_id(id); + + auto mainHash = implicitProducerHash.load(std::memory_order_acquire); + assert(mainHash != nullptr); // silence clang-tidy and MSVC warnings (hash cannot be null) + for (auto hash = mainHash; hash != nullptr; hash = hash->prev) { + // Look for the id in this hash + auto index = hashedId; + while (true) { // Not an infinite loop because at least one slot is free in the hash table + index &= hash->capacity - 1; + + auto probedKey = hash->entries[index].key.load(std::memory_order_relaxed); + if (probedKey == id) { + // Found it! If we had to search several hashes deep, though, we should lazily add it + // to the current main hash table to avoid the extended search next time. + // Note there's guaranteed to be room in the current hash table since every subsequent + // table implicitly reserves space for all previous tables (there's only one + // implicitProducerHashCount). + auto value = hash->entries[index].value; + if (hash != mainHash) { + index = hashedId; + while (true) { + index &= mainHash->capacity - 1; + probedKey = mainHash->entries[index].key.load(std::memory_order_relaxed); + auto empty = details::invalid_thread_id; +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + auto reusable = details::invalid_thread_id2; + if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed)) || + (probedKey == reusable && mainHash->entries[index].key.compare_exchange_strong(reusable, id, std::memory_order_acquire, std::memory_order_acquire))) { +#else + if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed))) { +#endif + mainHash->entries[index].value = value; + break; + } + ++index; + } + } + + return value; + } + if (probedKey == details::invalid_thread_id) { + break; // Not in this hash table + } + ++index; + } + } + + // Insert! + auto newCount = 1 + implicitProducerHashCount.fetch_add(1, std::memory_order_relaxed); + while (true) { + // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) + if (newCount >= (mainHash->capacity >> 1) && !implicitProducerHashResizeInProgress.test_and_set(std::memory_order_acquire)) { + // We've acquired the resize lock, try to allocate a bigger hash table. + // Note the acquire fence synchronizes with the release fence at the end of this block, and hence when + // we reload implicitProducerHash it must be the most recent version (it only gets changed within this + // locked block). + mainHash = implicitProducerHash.load(std::memory_order_acquire); + if (newCount >= (mainHash->capacity >> 1)) { + auto newCapacity = mainHash->capacity << 1; + while (newCount >= (newCapacity >> 1)) { + newCapacity <<= 1; + } + auto raw = static_cast((Traits::malloc)(sizeof(ImplicitProducerHash) + std::alignment_of::value - 1 + sizeof(ImplicitProducerKVP) * newCapacity)); + if (raw == nullptr) { + // Allocation failed + implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + return nullptr; + } + + auto newHash = new (raw) ImplicitProducerHash; + newHash->capacity = newCapacity; + newHash->entries = reinterpret_cast(details::align_for(raw + sizeof(ImplicitProducerHash))); + for (size_t i = 0; i != newCapacity; ++i) { + new (newHash->entries + i) ImplicitProducerKVP; + newHash->entries[i].key.store(details::invalid_thread_id, std::memory_order_relaxed); + } + newHash->prev = mainHash; + implicitProducerHash.store(newHash, std::memory_order_release); + implicitProducerHashResizeInProgress.clear(std::memory_order_release); + mainHash = newHash; + } + else { + implicitProducerHashResizeInProgress.clear(std::memory_order_release); + } + } + + // If it's < three-quarters full, add to the old one anyway so that we don't have to wait for the next table + // to finish being allocated by another thread (and if we just finished allocating above, the condition will + // always be true) + if (newCount < (mainHash->capacity >> 1) + (mainHash->capacity >> 2)) { + bool recycled; + auto producer = static_cast(recycle_or_create_producer(false, recycled)); + if (producer == nullptr) { + implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); + return nullptr; + } + if (recycled) { + implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); + } + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + producer->threadExitListener.callback = &ConcurrentQueue::implicit_producer_thread_exited_callback; + producer->threadExitListener.userData = producer; + details::ThreadExitNotifier::subscribe(&producer->threadExitListener); +#endif + + auto index = hashedId; + while (true) { + index &= mainHash->capacity - 1; + auto probedKey = mainHash->entries[index].key.load(std::memory_order_relaxed); + + auto empty = details::invalid_thread_id; +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + auto reusable = details::invalid_thread_id2; + if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed)) || + (probedKey == reusable && mainHash->entries[index].key.compare_exchange_strong(reusable, id, std::memory_order_acquire, std::memory_order_acquire))) { +#else + if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed))) { +#endif + mainHash->entries[index].value = producer; + break; + } + ++index; + } + return producer; + } + + // Hmm, the old hash is quite full and somebody else is busy allocating a new one. + // We need to wait for the allocating thread to finish (if it succeeds, we add, if not, + // we try to allocate ourselves). + mainHash = implicitProducerHash.load(std::memory_order_acquire); + } + } + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + void implicit_producer_thread_exited(ImplicitProducer* producer) + { + // Remove from thread exit listeners + details::ThreadExitNotifier::unsubscribe(&producer->threadExitListener); + + // Remove from hash +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugLock lock(implicitProdMutex); +#endif + auto hash = implicitProducerHash.load(std::memory_order_acquire); + assert(hash != nullptr); // The thread exit listener is only registered if we were added to a hash in the first place + auto id = details::thread_id(); + auto hashedId = details::hash_thread_id(id); + details::thread_id_t probedKey; + + // We need to traverse all the hashes just in case other threads aren't on the current one yet and are + // trying to add an entry thinking there's a free slot (because they reused a producer) + for (; hash != nullptr; hash = hash->prev) { + auto index = hashedId; + do { + index &= hash->capacity - 1; + probedKey = hash->entries[index].key.load(std::memory_order_relaxed); + if (probedKey == id) { + hash->entries[index].key.store(details::invalid_thread_id2, std::memory_order_release); + break; + } + ++index; + } while (probedKey != details::invalid_thread_id); // Can happen if the hash has changed but we weren't put back in it yet, or if we weren't added to this hash in the first place + } + + // Mark the queue as being recyclable + producer->inactive.store(true, std::memory_order_release); + } + + static void implicit_producer_thread_exited_callback(void* userData) + { + auto producer = static_cast(userData); + auto queue = producer->parent; + queue->implicit_producer_thread_exited(producer); + } +#endif + + ////////////////////////////////// + // Utility functions + ////////////////////////////////// + + template + static inline void* aligned_malloc(size_t size) + { + if (std::alignment_of::value <= std::alignment_of::value) + return (Traits::malloc)(size); + size_t alignment = std::alignment_of::value; + void* raw = (Traits::malloc)(size + alignment - 1 + sizeof(void*)); + if (!raw) + return nullptr; + char* ptr = details::align_for(reinterpret_cast(raw) + sizeof(void*)); + *(reinterpret_cast(ptr) - 1) = raw; + return ptr; + } + + template + static inline void aligned_free(void* ptr) + { + if (std::alignment_of::value <= std::alignment_of::value) + return (Traits::free)(ptr); + (Traits::free)(ptr ? *(reinterpret_cast(ptr) - 1) : nullptr); + } + + template + static inline U* create_array(size_t count) + { + assert(count > 0); + U* p = static_cast(aligned_malloc(sizeof(U) * count)); + if (p == nullptr) + return nullptr; + + for (size_t i = 0; i != count; ++i) + new (p + i) U(); + return p; + } + + template + static inline void destroy_array(U* p, size_t count) + { + if (p != nullptr) { + assert(count > 0); + for (size_t i = count; i != 0; ) + (p + --i)->~U(); + } + aligned_free(p); + } + + template + static inline U* create() + { + void* p = aligned_malloc(sizeof(U)); + return p != nullptr ? new (p) U : nullptr; + } + + template + static inline U* create(A1&& a1) + { + void* p = aligned_malloc(sizeof(U)); + return p != nullptr ? new (p) U(std::forward(a1)) : nullptr; + } + + template + static inline void destroy(U* p) + { + if (p != nullptr) + p->~U(); + aligned_free(p); + } + +private: + std::atomic producerListTail; + std::atomic producerCount; + + std::atomic initialBlockPoolIndex; + Block* initialBlockPool; + size_t initialBlockPoolSize; + +#ifndef MCDBGQ_USEDEBUGFREELIST + FreeList freeList; +#else + debug::DebugFreeList freeList; +#endif + + std::atomic implicitProducerHash; + std::atomic implicitProducerHashCount; // Number of slots logically used + ImplicitProducerHash initialImplicitProducerHash; + std::array initialImplicitProducerHashEntries; + std::atomic_flag implicitProducerHashResizeInProgress; + + std::atomic nextExplicitConsumerId; + std::atomic globalExplicitConsumerOffset; + +#ifdef MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugMutex implicitProdMutex; +#endif + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + std::atomic explicitProducers; + std::atomic implicitProducers; +#endif +}; + + +template +ProducerToken::ProducerToken(ConcurrentQueue& queue) + : producer(queue.recycle_or_create_producer(true)) +{ + if (producer != nullptr) { + producer->token = this; + } +} + +template +ProducerToken::ProducerToken(BlockingConcurrentQueue& queue) + : producer(reinterpret_cast*>(&queue)->recycle_or_create_producer(true)) +{ + if (producer != nullptr) { + producer->token = this; + } +} + +template +ConsumerToken::ConsumerToken(ConcurrentQueue& queue) + : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) +{ + initialOffset = queue.nextExplicitConsumerId.fetch_add(1, std::memory_order_release); + lastKnownGlobalOffset = -1; +} + +template +ConsumerToken::ConsumerToken(BlockingConcurrentQueue& queue) + : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) +{ + initialOffset = reinterpret_cast*>(&queue)->nextExplicitConsumerId.fetch_add(1, std::memory_order_release); + lastKnownGlobalOffset = -1; +} + +template +inline void swap(ConcurrentQueue& a, ConcurrentQueue& b) MOODYCAMEL_NOEXCEPT +{ + a.swap(b); +} + +inline void swap(ProducerToken& a, ProducerToken& b) MOODYCAMEL_NOEXCEPT +{ + a.swap(b); +} + +inline void swap(ConsumerToken& a, ConsumerToken& b) MOODYCAMEL_NOEXCEPT +{ + a.swap(b); +} + +template +inline void swap(typename ConcurrentQueue::ImplicitProducerKVP& a, typename ConcurrentQueue::ImplicitProducerKVP& b) MOODYCAMEL_NOEXCEPT +{ + a.swap(b); +} + +} + diff --git a/src/duckdb/third_party/concurrentqueue/lightweightsemaphore.h b/src/duckdb/third_party/concurrentqueue/lightweightsemaphore.h new file mode 100644 index 00000000..904275df --- /dev/null +++ b/src/duckdb/third_party/concurrentqueue/lightweightsemaphore.h @@ -0,0 +1,432 @@ +// Provides an efficient implementation of a semaphore (LightweightSemaphore). +// This is an extension of Jeff Preshing's sempahore implementation (licensed +// under the terms of its separate zlib license) that has been adapted and +// extended by Cameron Desrochers. + +#pragma once + +#include // For std::size_t +#include +#include // For std::make_signed + +#if defined(_WIN32) +// Avoid including windows.h in a header; we only need a handful of +// items, so we'll redeclare them here (this is relatively safe since +// the API generally has to remain stable between Windows versions). +// I know this is an ugly hack but it still beats polluting the global +// namespace with thousands of generic names or adding a .cpp for nothing. +extern "C" { + struct _SECURITY_ATTRIBUTES; + __declspec(dllimport) void* __stdcall CreateSemaphoreW(_SECURITY_ATTRIBUTES* lpSemaphoreAttributes, long lInitialCount, long lMaximumCount, const wchar_t* lpName); + __declspec(dllimport) int __stdcall CloseHandle(void* hObject); + __declspec(dllimport) unsigned long __stdcall WaitForSingleObject(void* hHandle, unsigned long dwMilliseconds); + __declspec(dllimport) int __stdcall ReleaseSemaphore(void* hSemaphore, long lReleaseCount, long* lpPreviousCount); +} +#elif defined(__MACH__) +#include +#elif defined(__unix__) +#include +#include +#elif defined(__MVS__) +#include +#include +#endif + +namespace duckdb_moodycamel +{ +namespace details +{ + +// Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's +// portable + lightweight semaphore implementations, originally from +// https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h +// LICENSE: +// Copyright (c) 2015 Jeff Preshing +// +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. +// +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: +// +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgement in the product documentation would be +// appreciated but is not required. +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. +// 3. This notice may not be removed or altered from any source distribution. +#if defined(_WIN32) +class Semaphore +{ +private: + void* m_hSema; + + Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + +public: + Semaphore(int initialCount = 0) + { + assert(initialCount >= 0); + const long maxLong = 0x7fffffff; + m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr); + assert(m_hSema); + } + + ~Semaphore() + { + CloseHandle(m_hSema); + } + + bool wait() + { + const unsigned long infinite = 0xffffffff; + return WaitForSingleObject(m_hSema, infinite) == 0; + } + + bool try_wait() + { + return WaitForSingleObject(m_hSema, 0) == 0; + } + + bool timed_wait(std::uint64_t usecs) + { + return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) == 0; + } + + void signal(int count = 1) + { + while (!ReleaseSemaphore(m_hSema, count, nullptr)); + } +}; +#elif defined(__MACH__) +//--------------------------------------------------------- +// Semaphore (Apple iOS and OSX) +// Can't use POSIX semaphores due to http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html +//--------------------------------------------------------- +class Semaphore +{ +private: + semaphore_t m_sema; + + Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + +public: + Semaphore(int initialCount = 0) + { + assert(initialCount >= 0); + kern_return_t rc = semaphore_create(mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount); + assert(rc == KERN_SUCCESS); + (void)rc; + } + + ~Semaphore() + { + semaphore_destroy(mach_task_self(), m_sema); + } + + bool wait() + { + return semaphore_wait(m_sema) == KERN_SUCCESS; + } + + bool try_wait() + { + return timed_wait(0); + } + + bool timed_wait(std::uint64_t timeout_usecs) + { + mach_timespec_t ts; + ts.tv_sec = static_cast(timeout_usecs / 1000000); + ts.tv_nsec = (timeout_usecs % 1000000) * 1000; + + // added in OSX 10.10: https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html + kern_return_t rc = semaphore_timedwait(m_sema, ts); + return rc == KERN_SUCCESS; + } + + void signal() + { + while (semaphore_signal(m_sema) != KERN_SUCCESS); + } + + void signal(int count) + { + while (count-- > 0) + { + while (semaphore_signal(m_sema) != KERN_SUCCESS); + } + } +}; +#elif defined(__unix__) || defined(__MVS__) +//--------------------------------------------------------- +// Semaphore (POSIX, Linux, zOS aka MVS) +//--------------------------------------------------------- +class Semaphore +{ +private: + sem_t m_sema; + + Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + +public: + Semaphore(int initialCount = 0) + { + assert(initialCount >= 0); + int rc = sem_init(&m_sema, 0, initialCount); + assert(rc == 0); + (void)rc; + } + + ~Semaphore() + { + sem_destroy(&m_sema); + } + + bool wait() + { + // http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error + int rc; + do { + rc = sem_wait(&m_sema); + } while (rc == -1 && errno == EINTR); + return rc == 0; + } + + bool try_wait() + { + int rc; + do { + rc = sem_trywait(&m_sema); + } while (rc == -1 && errno == EINTR); + return rc == 0; + } + + bool timed_wait(std::uint64_t usecs) + { + struct timespec ts; + const int usecs_in_1_sec = 1000000; + const int nsecs_in_1_sec = 1000000000; + + // sem_timedwait needs an absolute time + // hence we need to first obtain the current time + // and then add the maximum time we want to wait + // we want to avoid clock_gettime because of linking issues + // chrono -> timespec conversion from here: https://embeddedartistry.com/blog/2019/01/31/converting-between-timespec-stdchrono/ + auto current_time = std::chrono::system_clock::now(); + auto secs = std::chrono::time_point_cast(current_time); + auto ns = std::chrono::time_point_cast(current_time) - std::chrono::time_point_cast(secs); + + ts.tv_sec = secs.time_since_epoch().count(); + ts.tv_nsec = ns.count(); + + // now add the time we want to wait + ts.tv_sec += usecs / usecs_in_1_sec; + ts.tv_nsec += (usecs % usecs_in_1_sec) * 1000; + + // sem_timedwait bombs if you have more than 1e9 in tv_nsec + // so we have to clean things up before passing it in + if (ts.tv_nsec >= nsecs_in_1_sec) { + ts.tv_nsec -= nsecs_in_1_sec; + ++ts.tv_sec; + } + + int rc; + do { + rc = sem_timedwait(&m_sema, &ts); + } while (rc == -1 && errno == EINTR); + return rc == 0; + } + + void signal() + { + while (sem_post(&m_sema) == -1); + } + + void signal(int count) + { + while (count-- > 0) + { + while (sem_post(&m_sema) == -1); + } + } +}; +#else +#error Unsupported platform! (No semaphore wrapper available) +#endif + +} // end namespace details + + +//--------------------------------------------------------- +// LightweightSemaphore +//--------------------------------------------------------- +class LightweightSemaphore +{ +public: + typedef std::make_signed::type ssize_t; + +private: + std::atomic m_count; + details::Semaphore m_sema; + + bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) + { + ssize_t oldCount; + // Is there a better way to set the initial spin count? + // If we lower it to 1000, testBenaphore becomes 15x slower on my Core i7-5930K Windows PC, + // as threads start hitting the kernel semaphore. + int spin = 10000; + while (--spin >= 0) + { + oldCount = m_count.load(std::memory_order_relaxed); + if ((oldCount > 0) && m_count.compare_exchange_strong(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) + return true; + std::atomic_signal_fence(std::memory_order_acquire); // Prevent the compiler from collapsing the loop. + } + oldCount = m_count.fetch_sub(1, std::memory_order_acquire); + if (oldCount > 0) + return true; + if (timeout_usecs < 0) + return m_sema.wait(); + if (m_sema.timed_wait((std::uint64_t)timeout_usecs)) + return true; + // At this point, we've timed out waiting for the semaphore, but the + // count is still decremented indicating we may still be waiting on + // it. So we have to re-adjust the count, but only if the semaphore + // wasn't signaled enough times for us too since then. If it was, we + // need to release the semaphore too. + while (true) + { + oldCount = m_count.load(std::memory_order_acquire); + if (oldCount >= 0 && m_sema.try_wait()) + return true; + if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) + return false; + } + } + + ssize_t waitManyWithPartialSpinning(ssize_t max, std::int64_t timeout_usecs = -1) + { + assert(max > 0); + ssize_t oldCount; + int spin = 10000; + while (--spin >= 0) + { + oldCount = m_count.load(std::memory_order_relaxed); + if (oldCount > 0) + { + ssize_t newCount = oldCount > max ? oldCount - max : 0; + if (m_count.compare_exchange_strong(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) + return oldCount - newCount; + } + std::atomic_signal_fence(std::memory_order_acquire); + } + oldCount = m_count.fetch_sub(1, std::memory_order_acquire); + if (oldCount <= 0) + { + if (timeout_usecs < 0) + { + if (!m_sema.wait()) + return 0; + } + else if (!m_sema.timed_wait((std::uint64_t)timeout_usecs)) + { + while (true) + { + oldCount = m_count.load(std::memory_order_acquire); + if (oldCount >= 0 && m_sema.try_wait()) + break; + if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) + return 0; + } + } + } + if (max > 1) + return 1 + tryWaitMany(max - 1); + return 1; + } + +public: + LightweightSemaphore(ssize_t initialCount = 0) : m_count(initialCount) + { + assert(initialCount >= 0); + } + + bool tryWait() + { + ssize_t oldCount = m_count.load(std::memory_order_relaxed); + while (oldCount > 0) + { + if (m_count.compare_exchange_weak(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) + return true; + } + return false; + } + + bool wait() + { + return tryWait() || waitWithPartialSpinning(); + } + + bool wait(std::int64_t timeout_usecs) + { + return tryWait() || waitWithPartialSpinning(timeout_usecs); + } + + // Acquires between 0 and (greedily) max, inclusive + ssize_t tryWaitMany(ssize_t max) + { + assert(max >= 0); + ssize_t oldCount = m_count.load(std::memory_order_relaxed); + while (oldCount > 0) + { + ssize_t newCount = oldCount > max ? oldCount - max : 0; + if (m_count.compare_exchange_weak(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) + return oldCount - newCount; + } + return 0; + } + + // Acquires at least one, and (greedily) at most max + ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) + { + assert(max >= 0); + ssize_t result = tryWaitMany(max); + if (result == 0 && max > 0) + result = waitManyWithPartialSpinning(max, timeout_usecs); + return result; + } + + ssize_t waitMany(ssize_t max) + { + ssize_t result = waitMany(max, -1); + assert(result > 0); + return result; + } + + void signal(ssize_t count = 1) + { + assert(count >= 0); + ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release); + ssize_t toRelease = -oldCount < count ? -oldCount : count; + if (toRelease > 0) + { + m_sema.signal((int)toRelease); + } + } + + ssize_t availableApprox() const + { + ssize_t count = m_count.load(std::memory_order_relaxed); + return count > 0 ? count : 0; + } +}; + +} // end namespace duckdb_moodycamel diff --git a/src/duckdb/third_party/fast_float/fast_float/fast_float.h b/src/duckdb/third_party/fast_float/fast_float/fast_float.h new file mode 100644 index 00000000..d072fc75 --- /dev/null +++ b/src/duckdb/third_party/fast_float/fast_float/fast_float.h @@ -0,0 +1,2418 @@ +// duckdb_fast_float by Daniel Lemire +// duckdb_fast_float by João Paulo Magalhaes + + +// with contributions from Eugene Golushkov +// with contributions from Maksim Kita +// with contributions from Marcin Wojdyr +// with contributions from Neal Richardson +// with contributions from Tim Paine +// with contributions from Fabio Pellacini + + +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + + +#ifndef FASTFLOAT_FAST_FLOAT_H +#define FASTFLOAT_FAST_FLOAT_H + +#include + +namespace duckdb_fast_float { +enum chars_format { + scientific = 1<<0, + fixed = 1<<2, + hex = 1<<3, + general = fixed | scientific +}; + + +struct from_chars_result { + const char *ptr; + std::errc ec; +}; + +/** + * This function parses the character sequence [first,last) for a number. It parses floating-point numbers expecting + * a locale-indepent format equivalent to what is used by std::strtod in the default ("C") locale. + * The resulting floating-point value is the closest floating-point values (using either float or double), + * using the "round to even" convention for values that would otherwise fall right in-between two values. + * That is, we provide exact parsing according to the IEEE standard. + * + * Given a successful parse, the pointer (`ptr`) in the returned value is set to point right after the + * parsed number, and the `value` referenced is set to the parsed value. In case of error, the returned + * `ec` contains a representative error, otherwise the default (`std::errc()`) value is stored. + * + * The implementation does not throw and does not allocate memory (e.g., with `new` or `malloc`). + * + * Like the C++17 standard, the `duckdb_fast_float::from_chars` functions take an optional last argument of + * the type `duckdb_fast_float::chars_format`. It is a bitset value: we check whether + * `fmt & duckdb_fast_float::chars_format::fixed` and `fmt & duckdb_fast_float::chars_format::scientific` are set + * to determine whether we allowe the fixed point and scientific notation respectively. + * The default is `duckdb_fast_float::chars_format::general` which allows both `fixed` and `scientific`. + */ +template +from_chars_result from_chars(const char *first, const char *last, + T &value, + const char decimal_separator = '.', + chars_format fmt = chars_format::general) noexcept; + +} +#endif // FASTFLOAT_FAST_FLOAT_H + +#ifndef FASTFLOAT_FLOAT_COMMON_H +#define FASTFLOAT_FLOAT_COMMON_H + +#include +#include +#include + +#if (defined(__x86_64) || defined(__x86_64__) || defined(_M_X64) \ + || defined(__amd64) || defined(__aarch64__) || defined(_M_ARM64) \ + || defined(__MINGW64__) \ + || defined(__s390x__) \ + || (defined(__ppc64__) || defined(__PPC64__) || defined(__ppc64le__) || defined(__PPC64LE__)) \ + || defined(__EMSCRIPTEN__)) +#define FASTFLOAT_64BIT +#elif (defined(__i386) || defined(__i386__) || defined(_M_IX86) \ + || defined(__arm__) || defined(_M_ARM) \ + || defined(__MINGW32__)) +#define FASTFLOAT_32BIT +#else + // Need to check incrementally, since SIZE_MAX is a size_t, avoid overflow. + // We can never tell the register width, but the SIZE_MAX is a good approximation. + // UINTPTR_MAX and INTPTR_MAX are optional, so avoid them for max portability. + #if SIZE_MAX == 0xffff + #error Unknown platform (16-bit, unsupported) + #elif SIZE_MAX == 0xffffffff + #define FASTFLOAT_32BIT + #elif SIZE_MAX == 0xffffffffffffffff + #define FASTFLOAT_64BIT + #else + #error Unknown platform (not 32-bit, not 64-bit?) + #endif +#endif + +#if ((defined(_WIN32) || defined(_WIN64)) && !defined(__clang__)) +#include +#endif + +#if defined(_MSC_VER) && !defined(__clang__) +#define FASTFLOAT_VISUAL_STUDIO 1 +#endif + +#ifdef _WIN32 +#define FASTFLOAT_IS_BIG_ENDIAN 0 +#else +#if defined(__APPLE__) || defined(__FreeBSD__) +#include +#elif defined(sun) || defined(__sun) +#include +#elif defined(__MVS__) +#include +#else +#include +#endif +# +#ifndef __BYTE_ORDER__ +// safe choice +#define FASTFLOAT_IS_BIG_ENDIAN 0 +#endif +# +#ifndef __ORDER_LITTLE_ENDIAN__ +// safe choice +#define FASTFLOAT_IS_BIG_ENDIAN 0 +#endif +# +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define FASTFLOAT_IS_BIG_ENDIAN 0 +#else +#define FASTFLOAT_IS_BIG_ENDIAN 1 +#endif +#endif + +#ifdef FASTFLOAT_VISUAL_STUDIO +#define fastfloat_really_inline __forceinline +#else +#define fastfloat_really_inline inline __attribute__((always_inline)) +#endif + +namespace duckdb_fast_float { + +// Compares two ASCII strings in a case insensitive manner. +inline bool fastfloat_strncasecmp(const char *input1, const char *input2, + size_t length) { + char running_diff{0}; + for (size_t i = 0; i < length; i++) { + running_diff |= (input1[i] ^ input2[i]); + } + return (running_diff == 0) || (running_diff == 32); +} + +#ifndef FLT_EVAL_METHOD +#error "FLT_EVAL_METHOD should be defined, please include cfloat." +#endif + +namespace { +constexpr uint32_t max_digits = 768; +constexpr uint32_t max_digit_without_overflow = 19; +constexpr int32_t decimal_point_range = 2047; +} // namespace + +struct value128 { + uint64_t low; + uint64_t high; + value128(uint64_t _low, uint64_t _high) : low(_low), high(_high) {} + value128() : low(0), high(0) {} +}; + +/* result might be undefined when input_num is zero */ +fastfloat_really_inline int leading_zeroes(uint64_t input_num) { + assert(input_num > 0); +#ifdef FASTFLOAT_VISUAL_STUDIO + #if defined(_M_X64) || defined(_M_ARM64) + unsigned long leading_zero = 0; + // Search the mask data from most significant bit (MSB) + // to least significant bit (LSB) for a set bit (1). + _BitScanReverse64(&leading_zero, input_num); + return (int)(63 - leading_zero); + #else + int last_bit = 0; + if(input_num & uint64_t(0xffffffff00000000)) input_num >>= 32, last_bit |= 32; + if(input_num & uint64_t( 0xffff0000)) input_num >>= 16, last_bit |= 16; + if(input_num & uint64_t( 0xff00)) input_num >>= 8, last_bit |= 8; + if(input_num & uint64_t( 0xf0)) input_num >>= 4, last_bit |= 4; + if(input_num & uint64_t( 0xc)) input_num >>= 2, last_bit |= 2; + if(input_num & uint64_t( 0x2)) input_num >>= 1, last_bit |= 1; + return 63 - last_bit; + #endif +#else + return __builtin_clzll(input_num); +#endif +} + +#ifdef FASTFLOAT_32BIT + +// slow emulation routine for 32-bit +fastfloat_really_inline uint64_t emulu(uint32_t x, uint32_t y) { + return x * (uint64_t)y; +} + +// slow emulation routine for 32-bit +#if !defined(__MINGW64__) +fastfloat_really_inline uint64_t _umul128(uint64_t ab, uint64_t cd, + uint64_t *hi) { + uint64_t ad = emulu((uint32_t)(ab >> 32), (uint32_t)cd); + uint64_t bd = emulu((uint32_t)ab, (uint32_t)cd); + uint64_t adbc = ad + emulu((uint32_t)ab, (uint32_t)(cd >> 32)); + uint64_t adbc_carry = !!(adbc < ad); + uint64_t lo = bd + (adbc << 32); + *hi = emulu((uint32_t)(ab >> 32), (uint32_t)(cd >> 32)) + (adbc >> 32) + + (adbc_carry << 32) + !!(lo < bd); + return lo; +} +#endif // !__MINGW64__ + +#endif // FASTFLOAT_32BIT + + +// compute 64-bit a*b +fastfloat_really_inline value128 full_multiplication(uint64_t a, + uint64_t b) { + value128 answer; +#ifdef _M_ARM64 + // ARM64 has native support for 64-bit multiplications, no need to emulate + answer.high = __umulh(a, b); + answer.low = a * b; +#elif defined(FASTFLOAT_32BIT) || (defined(_WIN64) && !defined(__clang__)) + answer.low = _umul128(a, b, &answer.high); // _umul128 not available on ARM64 +#elif defined(FASTFLOAT_64BIT) + __uint128_t r = ((__uint128_t)a) * b; + answer.low = uint64_t(r); + answer.high = uint64_t(r >> 64); +#else + #error Not implemented +#endif + return answer; +} + + +struct adjusted_mantissa { + uint64_t mantissa{0}; + int power2{0}; // a negative value indicates an invalid result + adjusted_mantissa() = default; + bool operator==(const adjusted_mantissa &o) const { + return mantissa == o.mantissa && power2 == o.power2; + } + bool operator!=(const adjusted_mantissa &o) const { + return mantissa != o.mantissa || power2 != o.power2; + } +}; + +struct decimal { + uint32_t num_digits{0}; + int32_t decimal_point{0}; + bool negative{false}; + bool truncated{false}; + uint8_t digits[max_digits]; + decimal() = default; + // Copies are not allowed since this is a fat object. + decimal(const decimal &) = delete; + // Copies are not allowed since this is a fat object. + decimal &operator=(const decimal &) = delete; + // Moves are allowed: + decimal(decimal &&) = default; + decimal &operator=(decimal &&other) = default; +}; + +constexpr static double powers_of_ten_double[] = { + 1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, + 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, 1e20, 1e21, 1e22}; +constexpr static float powers_of_ten_float[] = {1e0, 1e1, 1e2, 1e3, 1e4, 1e5, + 1e6, 1e7, 1e8, 1e9, 1e10}; + +template struct binary_format { + static inline constexpr int mantissa_explicit_bits(); + static inline constexpr int minimum_exponent(); + static inline constexpr int infinite_power(); + static inline constexpr int sign_index(); + static inline constexpr int min_exponent_fast_path(); + static inline constexpr int max_exponent_fast_path(); + static inline constexpr int max_exponent_round_to_even(); + static inline constexpr int min_exponent_round_to_even(); + static inline constexpr uint64_t max_mantissa_fast_path(); + static inline constexpr int largest_power_of_ten(); + static inline constexpr int smallest_power_of_ten(); + static inline constexpr T exact_power_of_ten(int64_t power); +}; + +template <> inline constexpr int binary_format::mantissa_explicit_bits() { + return 52; +} +template <> inline constexpr int binary_format::mantissa_explicit_bits() { + return 23; +} + +template <> inline constexpr int binary_format::max_exponent_round_to_even() { + return 23; +} + +template <> inline constexpr int binary_format::max_exponent_round_to_even() { + return 10; +} + +template <> inline constexpr int binary_format::min_exponent_round_to_even() { + return -4; +} + +template <> inline constexpr int binary_format::min_exponent_round_to_even() { + return -17; +} + +template <> inline constexpr int binary_format::minimum_exponent() { + return -1023; +} +template <> inline constexpr int binary_format::minimum_exponent() { + return -127; +} + +template <> inline constexpr int binary_format::infinite_power() { + return 0x7FF; +} +template <> inline constexpr int binary_format::infinite_power() { + return 0xFF; +} + +template <> inline constexpr int binary_format::sign_index() { return 63; } +template <> inline constexpr int binary_format::sign_index() { return 31; } + +template <> inline constexpr int binary_format::min_exponent_fast_path() { +#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0) + return 0; +#else + return -22; +#endif +} +template <> inline constexpr int binary_format::min_exponent_fast_path() { +#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0) + return 0; +#else + return -10; +#endif +} + +template <> inline constexpr int binary_format::max_exponent_fast_path() { + return 22; +} +template <> inline constexpr int binary_format::max_exponent_fast_path() { + return 10; +} + +template <> inline constexpr uint64_t binary_format::max_mantissa_fast_path() { + return uint64_t(2) << mantissa_explicit_bits(); +} +template <> inline constexpr uint64_t binary_format::max_mantissa_fast_path() { + return uint64_t(2) << mantissa_explicit_bits(); +} + +template <> +inline constexpr double binary_format::exact_power_of_ten(int64_t power) { + return powers_of_ten_double[power]; +} +template <> +inline constexpr float binary_format::exact_power_of_ten(int64_t power) { + + return powers_of_ten_float[power]; +} + + +template <> +inline constexpr int binary_format::largest_power_of_ten() { + return 308; +} +template <> +inline constexpr int binary_format::largest_power_of_ten() { + return 38; +} + +template <> +inline constexpr int binary_format::smallest_power_of_ten() { + return -342; +} +template <> +inline constexpr int binary_format::smallest_power_of_ten() { + return -65; +} + +} // namespace duckdb_fast_float + +// for convenience: +template +inline OStream& operator<<(OStream &out, const duckdb_fast_float::decimal &d) { + out << "0."; + for (size_t i = 0; i < d.num_digits; i++) { + out << int32_t(d.digits[i]); + } + out << " * 10 ** " << d.decimal_point; + return out; +} + +#endif + + +#ifndef FASTFLOAT_ASCII_NUMBER_H +#define FASTFLOAT_ASCII_NUMBER_H + +#include +#include +#include +#include + + +namespace duckdb_fast_float { + +// Next function can be micro-optimized, but compilers are entirely +// able to optimize it well. +fastfloat_really_inline bool is_integer(char c) noexcept { return c >= '0' && c <= '9'; } + +fastfloat_really_inline uint64_t byteswap(uint64_t val) { + return (val & 0xFF00000000000000) >> 56 + | (val & 0x00FF000000000000) >> 40 + | (val & 0x0000FF0000000000) >> 24 + | (val & 0x000000FF00000000) >> 8 + | (val & 0x00000000FF000000) << 8 + | (val & 0x0000000000FF0000) << 24 + | (val & 0x000000000000FF00) << 40 + | (val & 0x00000000000000FF) << 56; +} + +fastfloat_really_inline uint64_t read_u64(const char *chars) { + uint64_t val; + ::memcpy(&val, chars, sizeof(uint64_t)); +#if FASTFLOAT_IS_BIG_ENDIAN == 1 + // Need to read as-if the number was in little-endian order. + val = byteswap(val); +#endif + return val; +} + +fastfloat_really_inline void write_u64(uint8_t *chars, uint64_t val) { +#if FASTFLOAT_IS_BIG_ENDIAN == 1 + // Need to read as-if the number was in little-endian order. + val = byteswap(val); +#endif + ::memcpy(chars, &val, sizeof(uint64_t)); +} + +// credit @aqrit +fastfloat_really_inline uint32_t parse_eight_digits_unrolled(uint64_t val) { + const uint64_t mask = 0x000000FF000000FF; + const uint64_t mul1 = 0x000F424000000064; // 100 + (1000000ULL << 32) + const uint64_t mul2 = 0x0000271000000001; // 1 + (10000ULL << 32) + val -= 0x3030303030303030; + val = (val * 10) + (val >> 8); // val = (val * 2561) >> 8; + val = (((val & mask) * mul1) + (((val >> 16) & mask) * mul2)) >> 32; + return uint32_t(val); +} + +fastfloat_really_inline uint32_t parse_eight_digits_unrolled(const char *chars) noexcept { + return parse_eight_digits_unrolled(read_u64(chars)); +} + +// credit @aqrit +fastfloat_really_inline bool is_made_of_eight_digits_fast(uint64_t val) noexcept { + return !((((val + 0x4646464646464646) | (val - 0x3030303030303030)) & + 0x8080808080808080)); +} + +fastfloat_really_inline bool is_made_of_eight_digits_fast(const char *chars) noexcept { + return is_made_of_eight_digits_fast(read_u64(chars)); +} + +struct parsed_number_string { + int64_t exponent; + uint64_t mantissa; + const char *lastmatch; + bool negative; + bool valid; + bool too_many_digits; +}; + + +// Assuming that you use no more than 19 digits, this will +// parse an ASCII string. +fastfloat_really_inline +parsed_number_string parse_number_string(const char *p, const char *pend, const char decimal_separator, chars_format fmt) noexcept { + parsed_number_string answer; + answer.valid = false; + answer.too_many_digits = false; + answer.negative = (*p == '-'); + if (*p == '-') { // C++17 20.19.3.(7.1) explicitly forbids '+' sign here + ++p; + if (p == pend) { + return answer; + } + if (!is_integer(*p) && (*p != decimal_separator)) { // a sign must be followed by an integer or the dot + return answer; + } + } + const char *const start_digits = p; + + uint64_t i = 0; // an unsigned int avoids signed overflows (which are bad) + + while ((p != pend) && is_integer(*p)) { + // a multiplication by 10 is cheaper than an arbitrary integer + // multiplication + i = 10 * i + + uint64_t(*p - '0'); // might overflow, we will handle the overflow later + ++p; + } + const char *const end_of_integer_part = p; + int64_t digit_count = int64_t(end_of_integer_part - start_digits); + int64_t exponent = 0; + if ((p != pend) && (*p == decimal_separator)) { + ++p; + // Fast approach only tested under little endian systems + if ((p + 8 <= pend) && is_made_of_eight_digits_fast(p)) { + i = i * 100000000 + parse_eight_digits_unrolled(p); // in rare cases, this will overflow, but that's ok + p += 8; + if ((p + 8 <= pend) && is_made_of_eight_digits_fast(p)) { + i = i * 100000000 + parse_eight_digits_unrolled(p); // in rare cases, this will overflow, but that's ok + p += 8; + } + } + while ((p != pend) && is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + ++p; + i = i * 10 + digit; // in rare cases, this will overflow, but that's ok + } + exponent = end_of_integer_part + 1 - p; + digit_count -= exponent; + } + // we must have encountered at least one integer! + if (digit_count == 0) { + return answer; + } + int64_t exp_number = 0; // explicit exponential part + if ((fmt & chars_format::scientific) && (p != pend) && (('e' == *p) || ('E' == *p))) { + const char * location_of_e = p; + ++p; + bool neg_exp = false; + if ((p != pend) && ('-' == *p)) { + neg_exp = true; + ++p; + } else if ((p != pend) && ('+' == *p)) { // '+' on exponent is allowed by C++17 20.19.3.(7.1) + ++p; + } + if ((p == pend) || !is_integer(*p)) { + if(!(fmt & chars_format::fixed)) { + // We are in error. + return answer; + } + // Otherwise, we will be ignoring the 'e'. + p = location_of_e; + } else { + while ((p != pend) && is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + if (exp_number < 0x10000) { + exp_number = 10 * exp_number + digit; + } + ++p; + } + if(neg_exp) { exp_number = - exp_number; } + exponent += exp_number; + } + } else { + // If it scientific and not fixed, we have to bail out. + if((fmt & chars_format::scientific) && !(fmt & chars_format::fixed)) { return answer; } + } + answer.lastmatch = p; + answer.valid = true; + + // If we frequently had to deal with long strings of digits, + // we could extend our code by using a 128-bit integer instead + // of a 64-bit integer. However, this is uncommon. + // + // We can deal with up to 19 digits. + if (digit_count > 19) { // this is uncommon + // It is possible that the integer had an overflow. + // We have to handle the case where we have 0.0000somenumber. + // We need to be mindful of the case where we only have zeroes... + // E.g., 0.000000000...000. + const char *start = start_digits; + while ((start != pend) && (*start == '0' || *start == decimal_separator)) { + if(*start == '0') { digit_count --; } + start++; + } + if (digit_count > 19) { + answer.too_many_digits = true; + // Let us start again, this time, avoiding overflows. + i = 0; + p = start_digits; + const uint64_t minimal_nineteen_digit_integer{1000000000000000000}; + while((i < minimal_nineteen_digit_integer) && (p != pend) && is_integer(*p)) { + i = i * 10 + uint64_t(*p - '0'); + ++p; + } + if (i >= minimal_nineteen_digit_integer) { // We have a big integers + exponent = end_of_integer_part - p + exp_number; + } else { // We have a value with a fractional component. + p++; // skip the decimal_separator + const char *first_after_period = p; + while((i < minimal_nineteen_digit_integer) && (p != pend) && is_integer(*p)) { + i = i * 10 + uint64_t(*p - '0'); + ++p; + } + exponent = first_after_period - p + exp_number; + } + // We have now corrected both exponent and i, to a truncated value + } + } + answer.exponent = exponent; + answer.mantissa = i; + return answer; +} + + +// This should always succeed since it follows a call to parse_number_string +// This function could be optimized. In particular, we could stop after 19 digits +// and try to bail out. Furthermore, we should be able to recover the computed +// exponent from the pass in parse_number_string. +fastfloat_really_inline decimal parse_decimal(const char *p, const char *pend, const char decimal_separator = '.') noexcept { + decimal answer; + answer.num_digits = 0; + answer.decimal_point = 0; + answer.truncated = false; + answer.negative = (*p == '-'); + if (*p == '-') { // C++17 20.19.3.(7.1) explicitly forbids '+' sign here + ++p; + } + // skip leading zeroes + while ((p != pend) && (*p == '0')) { + ++p; + } + while ((p != pend) && is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + if ((p != pend) && (*p == decimal_separator)) { + ++p; + const char *first_after_period = p; + // if we have not yet encountered a zero, we have to skip it as well + if(answer.num_digits == 0) { + // skip zeros + while ((p != pend) && (*p == '0')) { + ++p; + } + } + // We expect that this loop will often take the bulk of the running time + // because when a value has lots of digits, these digits often + while ((p + 8 <= pend) && (answer.num_digits + 8 < max_digits)) { + uint64_t val = read_u64(p); + if(! is_made_of_eight_digits_fast(val)) { break; } + // We have eight digits, process them in one go! + val -= 0x3030303030303030; + write_u64(answer.digits + answer.num_digits, val); + answer.num_digits += 8; + p += 8; + } + while ((p != pend) && is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + answer.decimal_point = int32_t(first_after_period - p); + } + // We want num_digits to be the number of significant digits, excluding + // leading *and* trailing zeros! Otherwise the truncated flag later is + // going to be misleading. + if(answer.num_digits > 0) { + // We potentially need the answer.num_digits > 0 guard because we + // prune leading zeros. So with answer.num_digits > 0, we know that + // we have at least one non-zero digit. + const char *preverse = p - 1; + int32_t trailing_zeros = 0; + while ((*preverse == '0') || (*preverse == decimal_separator)) { + if(*preverse == '0') { trailing_zeros++; }; + --preverse; + } + answer.decimal_point += int32_t(answer.num_digits); + answer.num_digits -= uint32_t(trailing_zeros); + } + if(answer.num_digits > max_digits) { + answer.truncated = true; + answer.num_digits = max_digits; + } + if ((p != pend) && (('e' == *p) || ('E' == *p))) { + ++p; + bool neg_exp = false; + if ((p != pend) && ('-' == *p)) { + neg_exp = true; + ++p; + } else if ((p != pend) && ('+' == *p)) { // '+' on exponent is allowed by C++17 20.19.3.(7.1) + ++p; + } + int32_t exp_number = 0; // exponential part + while ((p != pend) && is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + if (exp_number < 0x10000) { + exp_number = 10 * exp_number + digit; + } + ++p; + } + answer.decimal_point += (neg_exp ? -exp_number : exp_number); + } + // In very rare cases, we may have fewer than 19 digits, we want to be able to reliably + // assume that all digits up to max_digit_without_overflow have been initialized. + for(uint32_t i = answer.num_digits; i < max_digit_without_overflow; i++) { answer.digits[i] = 0; } + + return answer; +} +} // namespace duckdb_fast_float + +#endif + + +#ifndef FASTFLOAT_FAST_TABLE_H +#define FASTFLOAT_FAST_TABLE_H +#include + +namespace duckdb_fast_float { + +/** + * When mapping numbers from decimal to binary, + * we go from w * 10^q to m * 2^p but we have + * 10^q = 5^q * 2^q, so effectively + * we are trying to match + * w * 2^q * 5^q to m * 2^p. Thus the powers of two + * are not a concern since they can be represented + * exactly using the binary notation, only the powers of five + * affect the binary significand. + */ + +/** + * The smallest non-zero float (binary64) is 2^−1074. + * We take as input numbers of the form w x 10^q where w < 2^64. + * We have that w * 10^-343 < 2^(64-344) 5^-343 < 2^-1076. + * However, we have that + * (2^64-1) * 10^-342 = (2^64-1) * 2^-342 * 5^-342 > 2^−1074. + * Thus it is possible for a number of the form w * 10^-342 where + * w is a 64-bit value to be a non-zero floating-point number. + ********* + * Any number of form w * 10^309 where w>= 1 is going to be + * infinite in binary64 so we never need to worry about powers + * of 5 greater than 308. + */ +template +struct powers_template { + +constexpr static int smallest_power_of_five = binary_format::smallest_power_of_ten(); +constexpr static int largest_power_of_five = binary_format::largest_power_of_ten(); +constexpr static int number_of_entries = 2 * (largest_power_of_five - smallest_power_of_five + 1); +// Powers of five from 5^-342 all the way to 5^308 rounded toward one. +static const uint64_t power_of_five_128[number_of_entries]; +}; + +template +const uint64_t powers_template::power_of_five_128[number_of_entries] = { + 0xeef453d6923bd65a,0x113faa2906a13b3f, + 0x9558b4661b6565f8,0x4ac7ca59a424c507, + 0xbaaee17fa23ebf76,0x5d79bcf00d2df649, + 0xe95a99df8ace6f53,0xf4d82c2c107973dc, + 0x91d8a02bb6c10594,0x79071b9b8a4be869, + 0xb64ec836a47146f9,0x9748e2826cdee284, + 0xe3e27a444d8d98b7,0xfd1b1b2308169b25, + 0x8e6d8c6ab0787f72,0xfe30f0f5e50e20f7, + 0xb208ef855c969f4f,0xbdbd2d335e51a935, + 0xde8b2b66b3bc4723,0xad2c788035e61382, + 0x8b16fb203055ac76,0x4c3bcb5021afcc31, + 0xaddcb9e83c6b1793,0xdf4abe242a1bbf3d, + 0xd953e8624b85dd78,0xd71d6dad34a2af0d, + 0x87d4713d6f33aa6b,0x8672648c40e5ad68, + 0xa9c98d8ccb009506,0x680efdaf511f18c2, + 0xd43bf0effdc0ba48,0x212bd1b2566def2, + 0x84a57695fe98746d,0x14bb630f7604b57, + 0xa5ced43b7e3e9188,0x419ea3bd35385e2d, + 0xcf42894a5dce35ea,0x52064cac828675b9, + 0x818995ce7aa0e1b2,0x7343efebd1940993, + 0xa1ebfb4219491a1f,0x1014ebe6c5f90bf8, + 0xca66fa129f9b60a6,0xd41a26e077774ef6, + 0xfd00b897478238d0,0x8920b098955522b4, + 0x9e20735e8cb16382,0x55b46e5f5d5535b0, + 0xc5a890362fddbc62,0xeb2189f734aa831d, + 0xf712b443bbd52b7b,0xa5e9ec7501d523e4, + 0x9a6bb0aa55653b2d,0x47b233c92125366e, + 0xc1069cd4eabe89f8,0x999ec0bb696e840a, + 0xf148440a256e2c76,0xc00670ea43ca250d, + 0x96cd2a865764dbca,0x380406926a5e5728, + 0xbc807527ed3e12bc,0xc605083704f5ecf2, + 0xeba09271e88d976b,0xf7864a44c633682e, + 0x93445b8731587ea3,0x7ab3ee6afbe0211d, + 0xb8157268fdae9e4c,0x5960ea05bad82964, + 0xe61acf033d1a45df,0x6fb92487298e33bd, + 0x8fd0c16206306bab,0xa5d3b6d479f8e056, + 0xb3c4f1ba87bc8696,0x8f48a4899877186c, + 0xe0b62e2929aba83c,0x331acdabfe94de87, + 0x8c71dcd9ba0b4925,0x9ff0c08b7f1d0b14, + 0xaf8e5410288e1b6f,0x7ecf0ae5ee44dd9, + 0xdb71e91432b1a24a,0xc9e82cd9f69d6150, + 0x892731ac9faf056e,0xbe311c083a225cd2, + 0xab70fe17c79ac6ca,0x6dbd630a48aaf406, + 0xd64d3d9db981787d,0x92cbbccdad5b108, + 0x85f0468293f0eb4e,0x25bbf56008c58ea5, + 0xa76c582338ed2621,0xaf2af2b80af6f24e, + 0xd1476e2c07286faa,0x1af5af660db4aee1, + 0x82cca4db847945ca,0x50d98d9fc890ed4d, + 0xa37fce126597973c,0xe50ff107bab528a0, + 0xcc5fc196fefd7d0c,0x1e53ed49a96272c8, + 0xff77b1fcbebcdc4f,0x25e8e89c13bb0f7a, + 0x9faacf3df73609b1,0x77b191618c54e9ac, + 0xc795830d75038c1d,0xd59df5b9ef6a2417, + 0xf97ae3d0d2446f25,0x4b0573286b44ad1d, + 0x9becce62836ac577,0x4ee367f9430aec32, + 0xc2e801fb244576d5,0x229c41f793cda73f, + 0xf3a20279ed56d48a,0x6b43527578c1110f, + 0x9845418c345644d6,0x830a13896b78aaa9, + 0xbe5691ef416bd60c,0x23cc986bc656d553, + 0xedec366b11c6cb8f,0x2cbfbe86b7ec8aa8, + 0x94b3a202eb1c3f39,0x7bf7d71432f3d6a9, + 0xb9e08a83a5e34f07,0xdaf5ccd93fb0cc53, + 0xe858ad248f5c22c9,0xd1b3400f8f9cff68, + 0x91376c36d99995be,0x23100809b9c21fa1, + 0xb58547448ffffb2d,0xabd40a0c2832a78a, + 0xe2e69915b3fff9f9,0x16c90c8f323f516c, + 0x8dd01fad907ffc3b,0xae3da7d97f6792e3, + 0xb1442798f49ffb4a,0x99cd11cfdf41779c, + 0xdd95317f31c7fa1d,0x40405643d711d583, + 0x8a7d3eef7f1cfc52,0x482835ea666b2572, + 0xad1c8eab5ee43b66,0xda3243650005eecf, + 0xd863b256369d4a40,0x90bed43e40076a82, + 0x873e4f75e2224e68,0x5a7744a6e804a291, + 0xa90de3535aaae202,0x711515d0a205cb36, + 0xd3515c2831559a83,0xd5a5b44ca873e03, + 0x8412d9991ed58091,0xe858790afe9486c2, + 0xa5178fff668ae0b6,0x626e974dbe39a872, + 0xce5d73ff402d98e3,0xfb0a3d212dc8128f, + 0x80fa687f881c7f8e,0x7ce66634bc9d0b99, + 0xa139029f6a239f72,0x1c1fffc1ebc44e80, + 0xc987434744ac874e,0xa327ffb266b56220, + 0xfbe9141915d7a922,0x4bf1ff9f0062baa8, + 0x9d71ac8fada6c9b5,0x6f773fc3603db4a9, + 0xc4ce17b399107c22,0xcb550fb4384d21d3, + 0xf6019da07f549b2b,0x7e2a53a146606a48, + 0x99c102844f94e0fb,0x2eda7444cbfc426d, + 0xc0314325637a1939,0xfa911155fefb5308, + 0xf03d93eebc589f88,0x793555ab7eba27ca, + 0x96267c7535b763b5,0x4bc1558b2f3458de, + 0xbbb01b9283253ca2,0x9eb1aaedfb016f16, + 0xea9c227723ee8bcb,0x465e15a979c1cadc, + 0x92a1958a7675175f,0xbfacd89ec191ec9, + 0xb749faed14125d36,0xcef980ec671f667b, + 0xe51c79a85916f484,0x82b7e12780e7401a, + 0x8f31cc0937ae58d2,0xd1b2ecb8b0908810, + 0xb2fe3f0b8599ef07,0x861fa7e6dcb4aa15, + 0xdfbdcece67006ac9,0x67a791e093e1d49a, + 0x8bd6a141006042bd,0xe0c8bb2c5c6d24e0, + 0xaecc49914078536d,0x58fae9f773886e18, + 0xda7f5bf590966848,0xaf39a475506a899e, + 0x888f99797a5e012d,0x6d8406c952429603, + 0xaab37fd7d8f58178,0xc8e5087ba6d33b83, + 0xd5605fcdcf32e1d6,0xfb1e4a9a90880a64, + 0x855c3be0a17fcd26,0x5cf2eea09a55067f, + 0xa6b34ad8c9dfc06f,0xf42faa48c0ea481e, + 0xd0601d8efc57b08b,0xf13b94daf124da26, + 0x823c12795db6ce57,0x76c53d08d6b70858, + 0xa2cb1717b52481ed,0x54768c4b0c64ca6e, + 0xcb7ddcdda26da268,0xa9942f5dcf7dfd09, + 0xfe5d54150b090b02,0xd3f93b35435d7c4c, + 0x9efa548d26e5a6e1,0xc47bc5014a1a6daf, + 0xc6b8e9b0709f109a,0x359ab6419ca1091b, + 0xf867241c8cc6d4c0,0xc30163d203c94b62, + 0x9b407691d7fc44f8,0x79e0de63425dcf1d, + 0xc21094364dfb5636,0x985915fc12f542e4, + 0xf294b943e17a2bc4,0x3e6f5b7b17b2939d, + 0x979cf3ca6cec5b5a,0xa705992ceecf9c42, + 0xbd8430bd08277231,0x50c6ff782a838353, + 0xece53cec4a314ebd,0xa4f8bf5635246428, + 0x940f4613ae5ed136,0x871b7795e136be99, + 0xb913179899f68584,0x28e2557b59846e3f, + 0xe757dd7ec07426e5,0x331aeada2fe589cf, + 0x9096ea6f3848984f,0x3ff0d2c85def7621, + 0xb4bca50b065abe63,0xfed077a756b53a9, + 0xe1ebce4dc7f16dfb,0xd3e8495912c62894, + 0x8d3360f09cf6e4bd,0x64712dd7abbbd95c, + 0xb080392cc4349dec,0xbd8d794d96aacfb3, + 0xdca04777f541c567,0xecf0d7a0fc5583a0, + 0x89e42caaf9491b60,0xf41686c49db57244, + 0xac5d37d5b79b6239,0x311c2875c522ced5, + 0xd77485cb25823ac7,0x7d633293366b828b, + 0x86a8d39ef77164bc,0xae5dff9c02033197, + 0xa8530886b54dbdeb,0xd9f57f830283fdfc, + 0xd267caa862a12d66,0xd072df63c324fd7b, + 0x8380dea93da4bc60,0x4247cb9e59f71e6d, + 0xa46116538d0deb78,0x52d9be85f074e608, + 0xcd795be870516656,0x67902e276c921f8b, + 0x806bd9714632dff6,0xba1cd8a3db53b6, + 0xa086cfcd97bf97f3,0x80e8a40eccd228a4, + 0xc8a883c0fdaf7df0,0x6122cd128006b2cd, + 0xfad2a4b13d1b5d6c,0x796b805720085f81, + 0x9cc3a6eec6311a63,0xcbe3303674053bb0, + 0xc3f490aa77bd60fc,0xbedbfc4411068a9c, + 0xf4f1b4d515acb93b,0xee92fb5515482d44, + 0x991711052d8bf3c5,0x751bdd152d4d1c4a, + 0xbf5cd54678eef0b6,0xd262d45a78a0635d, + 0xef340a98172aace4,0x86fb897116c87c34, + 0x9580869f0e7aac0e,0xd45d35e6ae3d4da0, + 0xbae0a846d2195712,0x8974836059cca109, + 0xe998d258869facd7,0x2bd1a438703fc94b, + 0x91ff83775423cc06,0x7b6306a34627ddcf, + 0xb67f6455292cbf08,0x1a3bc84c17b1d542, + 0xe41f3d6a7377eeca,0x20caba5f1d9e4a93, + 0x8e938662882af53e,0x547eb47b7282ee9c, + 0xb23867fb2a35b28d,0xe99e619a4f23aa43, + 0xdec681f9f4c31f31,0x6405fa00e2ec94d4, + 0x8b3c113c38f9f37e,0xde83bc408dd3dd04, + 0xae0b158b4738705e,0x9624ab50b148d445, + 0xd98ddaee19068c76,0x3badd624dd9b0957, + 0x87f8a8d4cfa417c9,0xe54ca5d70a80e5d6, + 0xa9f6d30a038d1dbc,0x5e9fcf4ccd211f4c, + 0xd47487cc8470652b,0x7647c3200069671f, + 0x84c8d4dfd2c63f3b,0x29ecd9f40041e073, + 0xa5fb0a17c777cf09,0xf468107100525890, + 0xcf79cc9db955c2cc,0x7182148d4066eeb4, + 0x81ac1fe293d599bf,0xc6f14cd848405530, + 0xa21727db38cb002f,0xb8ada00e5a506a7c, + 0xca9cf1d206fdc03b,0xa6d90811f0e4851c, + 0xfd442e4688bd304a,0x908f4a166d1da663, + 0x9e4a9cec15763e2e,0x9a598e4e043287fe, + 0xc5dd44271ad3cdba,0x40eff1e1853f29fd, + 0xf7549530e188c128,0xd12bee59e68ef47c, + 0x9a94dd3e8cf578b9,0x82bb74f8301958ce, + 0xc13a148e3032d6e7,0xe36a52363c1faf01, + 0xf18899b1bc3f8ca1,0xdc44e6c3cb279ac1, + 0x96f5600f15a7b7e5,0x29ab103a5ef8c0b9, + 0xbcb2b812db11a5de,0x7415d448f6b6f0e7, + 0xebdf661791d60f56,0x111b495b3464ad21, + 0x936b9fcebb25c995,0xcab10dd900beec34, + 0xb84687c269ef3bfb,0x3d5d514f40eea742, + 0xe65829b3046b0afa,0xcb4a5a3112a5112, + 0x8ff71a0fe2c2e6dc,0x47f0e785eaba72ab, + 0xb3f4e093db73a093,0x59ed216765690f56, + 0xe0f218b8d25088b8,0x306869c13ec3532c, + 0x8c974f7383725573,0x1e414218c73a13fb, + 0xafbd2350644eeacf,0xe5d1929ef90898fa, + 0xdbac6c247d62a583,0xdf45f746b74abf39, + 0x894bc396ce5da772,0x6b8bba8c328eb783, + 0xab9eb47c81f5114f,0x66ea92f3f326564, + 0xd686619ba27255a2,0xc80a537b0efefebd, + 0x8613fd0145877585,0xbd06742ce95f5f36, + 0xa798fc4196e952e7,0x2c48113823b73704, + 0xd17f3b51fca3a7a0,0xf75a15862ca504c5, + 0x82ef85133de648c4,0x9a984d73dbe722fb, + 0xa3ab66580d5fdaf5,0xc13e60d0d2e0ebba, + 0xcc963fee10b7d1b3,0x318df905079926a8, + 0xffbbcfe994e5c61f,0xfdf17746497f7052, + 0x9fd561f1fd0f9bd3,0xfeb6ea8bedefa633, + 0xc7caba6e7c5382c8,0xfe64a52ee96b8fc0, + 0xf9bd690a1b68637b,0x3dfdce7aa3c673b0, + 0x9c1661a651213e2d,0x6bea10ca65c084e, + 0xc31bfa0fe5698db8,0x486e494fcff30a62, + 0xf3e2f893dec3f126,0x5a89dba3c3efccfa, + 0x986ddb5c6b3a76b7,0xf89629465a75e01c, + 0xbe89523386091465,0xf6bbb397f1135823, + 0xee2ba6c0678b597f,0x746aa07ded582e2c, + 0x94db483840b717ef,0xa8c2a44eb4571cdc, + 0xba121a4650e4ddeb,0x92f34d62616ce413, + 0xe896a0d7e51e1566,0x77b020baf9c81d17, + 0x915e2486ef32cd60,0xace1474dc1d122e, + 0xb5b5ada8aaff80b8,0xd819992132456ba, + 0xe3231912d5bf60e6,0x10e1fff697ed6c69, + 0x8df5efabc5979c8f,0xca8d3ffa1ef463c1, + 0xb1736b96b6fd83b3,0xbd308ff8a6b17cb2, + 0xddd0467c64bce4a0,0xac7cb3f6d05ddbde, + 0x8aa22c0dbef60ee4,0x6bcdf07a423aa96b, + 0xad4ab7112eb3929d,0x86c16c98d2c953c6, + 0xd89d64d57a607744,0xe871c7bf077ba8b7, + 0x87625f056c7c4a8b,0x11471cd764ad4972, + 0xa93af6c6c79b5d2d,0xd598e40d3dd89bcf, + 0xd389b47879823479,0x4aff1d108d4ec2c3, + 0x843610cb4bf160cb,0xcedf722a585139ba, + 0xa54394fe1eedb8fe,0xc2974eb4ee658828, + 0xce947a3da6a9273e,0x733d226229feea32, + 0x811ccc668829b887,0x806357d5a3f525f, + 0xa163ff802a3426a8,0xca07c2dcb0cf26f7, + 0xc9bcff6034c13052,0xfc89b393dd02f0b5, + 0xfc2c3f3841f17c67,0xbbac2078d443ace2, + 0x9d9ba7832936edc0,0xd54b944b84aa4c0d, + 0xc5029163f384a931,0xa9e795e65d4df11, + 0xf64335bcf065d37d,0x4d4617b5ff4a16d5, + 0x99ea0196163fa42e,0x504bced1bf8e4e45, + 0xc06481fb9bcf8d39,0xe45ec2862f71e1d6, + 0xf07da27a82c37088,0x5d767327bb4e5a4c, + 0x964e858c91ba2655,0x3a6a07f8d510f86f, + 0xbbe226efb628afea,0x890489f70a55368b, + 0xeadab0aba3b2dbe5,0x2b45ac74ccea842e, + 0x92c8ae6b464fc96f,0x3b0b8bc90012929d, + 0xb77ada0617e3bbcb,0x9ce6ebb40173744, + 0xe55990879ddcaabd,0xcc420a6a101d0515, + 0x8f57fa54c2a9eab6,0x9fa946824a12232d, + 0xb32df8e9f3546564,0x47939822dc96abf9, + 0xdff9772470297ebd,0x59787e2b93bc56f7, + 0x8bfbea76c619ef36,0x57eb4edb3c55b65a, + 0xaefae51477a06b03,0xede622920b6b23f1, + 0xdab99e59958885c4,0xe95fab368e45eced, + 0x88b402f7fd75539b,0x11dbcb0218ebb414, + 0xaae103b5fcd2a881,0xd652bdc29f26a119, + 0xd59944a37c0752a2,0x4be76d3346f0495f, + 0x857fcae62d8493a5,0x6f70a4400c562ddb, + 0xa6dfbd9fb8e5b88e,0xcb4ccd500f6bb952, + 0xd097ad07a71f26b2,0x7e2000a41346a7a7, + 0x825ecc24c873782f,0x8ed400668c0c28c8, + 0xa2f67f2dfa90563b,0x728900802f0f32fa, + 0xcbb41ef979346bca,0x4f2b40a03ad2ffb9, + 0xfea126b7d78186bc,0xe2f610c84987bfa8, + 0x9f24b832e6b0f436,0xdd9ca7d2df4d7c9, + 0xc6ede63fa05d3143,0x91503d1c79720dbb, + 0xf8a95fcf88747d94,0x75a44c6397ce912a, + 0x9b69dbe1b548ce7c,0xc986afbe3ee11aba, + 0xc24452da229b021b,0xfbe85badce996168, + 0xf2d56790ab41c2a2,0xfae27299423fb9c3, + 0x97c560ba6b0919a5,0xdccd879fc967d41a, + 0xbdb6b8e905cb600f,0x5400e987bbc1c920, + 0xed246723473e3813,0x290123e9aab23b68, + 0x9436c0760c86e30b,0xf9a0b6720aaf6521, + 0xb94470938fa89bce,0xf808e40e8d5b3e69, + 0xe7958cb87392c2c2,0xb60b1d1230b20e04, + 0x90bd77f3483bb9b9,0xb1c6f22b5e6f48c2, + 0xb4ecd5f01a4aa828,0x1e38aeb6360b1af3, + 0xe2280b6c20dd5232,0x25c6da63c38de1b0, + 0x8d590723948a535f,0x579c487e5a38ad0e, + 0xb0af48ec79ace837,0x2d835a9df0c6d851, + 0xdcdb1b2798182244,0xf8e431456cf88e65, + 0x8a08f0f8bf0f156b,0x1b8e9ecb641b58ff, + 0xac8b2d36eed2dac5,0xe272467e3d222f3f, + 0xd7adf884aa879177,0x5b0ed81dcc6abb0f, + 0x86ccbb52ea94baea,0x98e947129fc2b4e9, + 0xa87fea27a539e9a5,0x3f2398d747b36224, + 0xd29fe4b18e88640e,0x8eec7f0d19a03aad, + 0x83a3eeeef9153e89,0x1953cf68300424ac, + 0xa48ceaaab75a8e2b,0x5fa8c3423c052dd7, + 0xcdb02555653131b6,0x3792f412cb06794d, + 0x808e17555f3ebf11,0xe2bbd88bbee40bd0, + 0xa0b19d2ab70e6ed6,0x5b6aceaeae9d0ec4, + 0xc8de047564d20a8b,0xf245825a5a445275, + 0xfb158592be068d2e,0xeed6e2f0f0d56712, + 0x9ced737bb6c4183d,0x55464dd69685606b, + 0xc428d05aa4751e4c,0xaa97e14c3c26b886, + 0xf53304714d9265df,0xd53dd99f4b3066a8, + 0x993fe2c6d07b7fab,0xe546a8038efe4029, + 0xbf8fdb78849a5f96,0xde98520472bdd033, + 0xef73d256a5c0f77c,0x963e66858f6d4440, + 0x95a8637627989aad,0xdde7001379a44aa8, + 0xbb127c53b17ec159,0x5560c018580d5d52, + 0xe9d71b689dde71af,0xaab8f01e6e10b4a6, + 0x9226712162ab070d,0xcab3961304ca70e8, + 0xb6b00d69bb55c8d1,0x3d607b97c5fd0d22, + 0xe45c10c42a2b3b05,0x8cb89a7db77c506a, + 0x8eb98a7a9a5b04e3,0x77f3608e92adb242, + 0xb267ed1940f1c61c,0x55f038b237591ed3, + 0xdf01e85f912e37a3,0x6b6c46dec52f6688, + 0x8b61313bbabce2c6,0x2323ac4b3b3da015, + 0xae397d8aa96c1b77,0xabec975e0a0d081a, + 0xd9c7dced53c72255,0x96e7bd358c904a21, + 0x881cea14545c7575,0x7e50d64177da2e54, + 0xaa242499697392d2,0xdde50bd1d5d0b9e9, + 0xd4ad2dbfc3d07787,0x955e4ec64b44e864, + 0x84ec3c97da624ab4,0xbd5af13bef0b113e, + 0xa6274bbdd0fadd61,0xecb1ad8aeacdd58e, + 0xcfb11ead453994ba,0x67de18eda5814af2, + 0x81ceb32c4b43fcf4,0x80eacf948770ced7, + 0xa2425ff75e14fc31,0xa1258379a94d028d, + 0xcad2f7f5359a3b3e,0x96ee45813a04330, + 0xfd87b5f28300ca0d,0x8bca9d6e188853fc, + 0x9e74d1b791e07e48,0x775ea264cf55347e, + 0xc612062576589dda,0x95364afe032a819e, + 0xf79687aed3eec551,0x3a83ddbd83f52205, + 0x9abe14cd44753b52,0xc4926a9672793543, + 0xc16d9a0095928a27,0x75b7053c0f178294, + 0xf1c90080baf72cb1,0x5324c68b12dd6339, + 0x971da05074da7bee,0xd3f6fc16ebca5e04, + 0xbce5086492111aea,0x88f4bb1ca6bcf585, + 0xec1e4a7db69561a5,0x2b31e9e3d06c32e6, + 0x9392ee8e921d5d07,0x3aff322e62439fd0, + 0xb877aa3236a4b449,0x9befeb9fad487c3, + 0xe69594bec44de15b,0x4c2ebe687989a9b4, + 0x901d7cf73ab0acd9,0xf9d37014bf60a11, + 0xb424dc35095cd80f,0x538484c19ef38c95, + 0xe12e13424bb40e13,0x2865a5f206b06fba, + 0x8cbccc096f5088cb,0xf93f87b7442e45d4, + 0xafebff0bcb24aafe,0xf78f69a51539d749, + 0xdbe6fecebdedd5be,0xb573440e5a884d1c, + 0x89705f4136b4a597,0x31680a88f8953031, + 0xabcc77118461cefc,0xfdc20d2b36ba7c3e, + 0xd6bf94d5e57a42bc,0x3d32907604691b4d, + 0x8637bd05af6c69b5,0xa63f9a49c2c1b110, + 0xa7c5ac471b478423,0xfcf80dc33721d54, + 0xd1b71758e219652b,0xd3c36113404ea4a9, + 0x83126e978d4fdf3b,0x645a1cac083126ea, + 0xa3d70a3d70a3d70a,0x3d70a3d70a3d70a4, + 0xcccccccccccccccc,0xcccccccccccccccd, + 0x8000000000000000,0x0, + 0xa000000000000000,0x0, + 0xc800000000000000,0x0, + 0xfa00000000000000,0x0, + 0x9c40000000000000,0x0, + 0xc350000000000000,0x0, + 0xf424000000000000,0x0, + 0x9896800000000000,0x0, + 0xbebc200000000000,0x0, + 0xee6b280000000000,0x0, + 0x9502f90000000000,0x0, + 0xba43b74000000000,0x0, + 0xe8d4a51000000000,0x0, + 0x9184e72a00000000,0x0, + 0xb5e620f480000000,0x0, + 0xe35fa931a0000000,0x0, + 0x8e1bc9bf04000000,0x0, + 0xb1a2bc2ec5000000,0x0, + 0xde0b6b3a76400000,0x0, + 0x8ac7230489e80000,0x0, + 0xad78ebc5ac620000,0x0, + 0xd8d726b7177a8000,0x0, + 0x878678326eac9000,0x0, + 0xa968163f0a57b400,0x0, + 0xd3c21bcecceda100,0x0, + 0x84595161401484a0,0x0, + 0xa56fa5b99019a5c8,0x0, + 0xcecb8f27f4200f3a,0x0, + 0x813f3978f8940984,0x4000000000000000, + 0xa18f07d736b90be5,0x5000000000000000, + 0xc9f2c9cd04674ede,0xa400000000000000, + 0xfc6f7c4045812296,0x4d00000000000000, + 0x9dc5ada82b70b59d,0xf020000000000000, + 0xc5371912364ce305,0x6c28000000000000, + 0xf684df56c3e01bc6,0xc732000000000000, + 0x9a130b963a6c115c,0x3c7f400000000000, + 0xc097ce7bc90715b3,0x4b9f100000000000, + 0xf0bdc21abb48db20,0x1e86d40000000000, + 0x96769950b50d88f4,0x1314448000000000, + 0xbc143fa4e250eb31,0x17d955a000000000, + 0xeb194f8e1ae525fd,0x5dcfab0800000000, + 0x92efd1b8d0cf37be,0x5aa1cae500000000, + 0xb7abc627050305ad,0xf14a3d9e40000000, + 0xe596b7b0c643c719,0x6d9ccd05d0000000, + 0x8f7e32ce7bea5c6f,0xe4820023a2000000, + 0xb35dbf821ae4f38b,0xdda2802c8a800000, + 0xe0352f62a19e306e,0xd50b2037ad200000, + 0x8c213d9da502de45,0x4526f422cc340000, + 0xaf298d050e4395d6,0x9670b12b7f410000, + 0xdaf3f04651d47b4c,0x3c0cdd765f114000, + 0x88d8762bf324cd0f,0xa5880a69fb6ac800, + 0xab0e93b6efee0053,0x8eea0d047a457a00, + 0xd5d238a4abe98068,0x72a4904598d6d880, + 0x85a36366eb71f041,0x47a6da2b7f864750, + 0xa70c3c40a64e6c51,0x999090b65f67d924, + 0xd0cf4b50cfe20765,0xfff4b4e3f741cf6d, + 0x82818f1281ed449f,0xbff8f10e7a8921a4, + 0xa321f2d7226895c7,0xaff72d52192b6a0d, + 0xcbea6f8ceb02bb39,0x9bf4f8a69f764490, + 0xfee50b7025c36a08,0x2f236d04753d5b4, + 0x9f4f2726179a2245,0x1d762422c946590, + 0xc722f0ef9d80aad6,0x424d3ad2b7b97ef5, + 0xf8ebad2b84e0d58b,0xd2e0898765a7deb2, + 0x9b934c3b330c8577,0x63cc55f49f88eb2f, + 0xc2781f49ffcfa6d5,0x3cbf6b71c76b25fb, + 0xf316271c7fc3908a,0x8bef464e3945ef7a, + 0x97edd871cfda3a56,0x97758bf0e3cbb5ac, + 0xbde94e8e43d0c8ec,0x3d52eeed1cbea317, + 0xed63a231d4c4fb27,0x4ca7aaa863ee4bdd, + 0x945e455f24fb1cf8,0x8fe8caa93e74ef6a, + 0xb975d6b6ee39e436,0xb3e2fd538e122b44, + 0xe7d34c64a9c85d44,0x60dbbca87196b616, + 0x90e40fbeea1d3a4a,0xbc8955e946fe31cd, + 0xb51d13aea4a488dd,0x6babab6398bdbe41, + 0xe264589a4dcdab14,0xc696963c7eed2dd1, + 0x8d7eb76070a08aec,0xfc1e1de5cf543ca2, + 0xb0de65388cc8ada8,0x3b25a55f43294bcb, + 0xdd15fe86affad912,0x49ef0eb713f39ebe, + 0x8a2dbf142dfcc7ab,0x6e3569326c784337, + 0xacb92ed9397bf996,0x49c2c37f07965404, + 0xd7e77a8f87daf7fb,0xdc33745ec97be906, + 0x86f0ac99b4e8dafd,0x69a028bb3ded71a3, + 0xa8acd7c0222311bc,0xc40832ea0d68ce0c, + 0xd2d80db02aabd62b,0xf50a3fa490c30190, + 0x83c7088e1aab65db,0x792667c6da79e0fa, + 0xa4b8cab1a1563f52,0x577001b891185938, + 0xcde6fd5e09abcf26,0xed4c0226b55e6f86, + 0x80b05e5ac60b6178,0x544f8158315b05b4, + 0xa0dc75f1778e39d6,0x696361ae3db1c721, + 0xc913936dd571c84c,0x3bc3a19cd1e38e9, + 0xfb5878494ace3a5f,0x4ab48a04065c723, + 0x9d174b2dcec0e47b,0x62eb0d64283f9c76, + 0xc45d1df942711d9a,0x3ba5d0bd324f8394, + 0xf5746577930d6500,0xca8f44ec7ee36479, + 0x9968bf6abbe85f20,0x7e998b13cf4e1ecb, + 0xbfc2ef456ae276e8,0x9e3fedd8c321a67e, + 0xefb3ab16c59b14a2,0xc5cfe94ef3ea101e, + 0x95d04aee3b80ece5,0xbba1f1d158724a12, + 0xbb445da9ca61281f,0x2a8a6e45ae8edc97, + 0xea1575143cf97226,0xf52d09d71a3293bd, + 0x924d692ca61be758,0x593c2626705f9c56, + 0xb6e0c377cfa2e12e,0x6f8b2fb00c77836c, + 0xe498f455c38b997a,0xb6dfb9c0f956447, + 0x8edf98b59a373fec,0x4724bd4189bd5eac, + 0xb2977ee300c50fe7,0x58edec91ec2cb657, + 0xdf3d5e9bc0f653e1,0x2f2967b66737e3ed, + 0x8b865b215899f46c,0xbd79e0d20082ee74, + 0xae67f1e9aec07187,0xecd8590680a3aa11, + 0xda01ee641a708de9,0xe80e6f4820cc9495, + 0x884134fe908658b2,0x3109058d147fdcdd, + 0xaa51823e34a7eede,0xbd4b46f0599fd415, + 0xd4e5e2cdc1d1ea96,0x6c9e18ac7007c91a, + 0x850fadc09923329e,0x3e2cf6bc604ddb0, + 0xa6539930bf6bff45,0x84db8346b786151c, + 0xcfe87f7cef46ff16,0xe612641865679a63, + 0x81f14fae158c5f6e,0x4fcb7e8f3f60c07e, + 0xa26da3999aef7749,0xe3be5e330f38f09d, + 0xcb090c8001ab551c,0x5cadf5bfd3072cc5, + 0xfdcb4fa002162a63,0x73d9732fc7c8f7f6, + 0x9e9f11c4014dda7e,0x2867e7fddcdd9afa, + 0xc646d63501a1511d,0xb281e1fd541501b8, + 0xf7d88bc24209a565,0x1f225a7ca91a4226, + 0x9ae757596946075f,0x3375788de9b06958, + 0xc1a12d2fc3978937,0x52d6b1641c83ae, + 0xf209787bb47d6b84,0xc0678c5dbd23a49a, + 0x9745eb4d50ce6332,0xf840b7ba963646e0, + 0xbd176620a501fbff,0xb650e5a93bc3d898, + 0xec5d3fa8ce427aff,0xa3e51f138ab4cebe, + 0x93ba47c980e98cdf,0xc66f336c36b10137, + 0xb8a8d9bbe123f017,0xb80b0047445d4184, + 0xe6d3102ad96cec1d,0xa60dc059157491e5, + 0x9043ea1ac7e41392,0x87c89837ad68db2f, + 0xb454e4a179dd1877,0x29babe4598c311fb, + 0xe16a1dc9d8545e94,0xf4296dd6fef3d67a, + 0x8ce2529e2734bb1d,0x1899e4a65f58660c, + 0xb01ae745b101e9e4,0x5ec05dcff72e7f8f, + 0xdc21a1171d42645d,0x76707543f4fa1f73, + 0x899504ae72497eba,0x6a06494a791c53a8, + 0xabfa45da0edbde69,0x487db9d17636892, + 0xd6f8d7509292d603,0x45a9d2845d3c42b6, + 0x865b86925b9bc5c2,0xb8a2392ba45a9b2, + 0xa7f26836f282b732,0x8e6cac7768d7141e, + 0xd1ef0244af2364ff,0x3207d795430cd926, + 0x8335616aed761f1f,0x7f44e6bd49e807b8, + 0xa402b9c5a8d3a6e7,0x5f16206c9c6209a6, + 0xcd036837130890a1,0x36dba887c37a8c0f, + 0x802221226be55a64,0xc2494954da2c9789, + 0xa02aa96b06deb0fd,0xf2db9baa10b7bd6c, + 0xc83553c5c8965d3d,0x6f92829494e5acc7, + 0xfa42a8b73abbf48c,0xcb772339ba1f17f9, + 0x9c69a97284b578d7,0xff2a760414536efb, + 0xc38413cf25e2d70d,0xfef5138519684aba, + 0xf46518c2ef5b8cd1,0x7eb258665fc25d69, + 0x98bf2f79d5993802,0xef2f773ffbd97a61, + 0xbeeefb584aff8603,0xaafb550ffacfd8fa, + 0xeeaaba2e5dbf6784,0x95ba2a53f983cf38, + 0x952ab45cfa97a0b2,0xdd945a747bf26183, + 0xba756174393d88df,0x94f971119aeef9e4, + 0xe912b9d1478ceb17,0x7a37cd5601aab85d, + 0x91abb422ccb812ee,0xac62e055c10ab33a, + 0xb616a12b7fe617aa,0x577b986b314d6009, + 0xe39c49765fdf9d94,0xed5a7e85fda0b80b, + 0x8e41ade9fbebc27d,0x14588f13be847307, + 0xb1d219647ae6b31c,0x596eb2d8ae258fc8, + 0xde469fbd99a05fe3,0x6fca5f8ed9aef3bb, + 0x8aec23d680043bee,0x25de7bb9480d5854, + 0xada72ccc20054ae9,0xaf561aa79a10ae6a, + 0xd910f7ff28069da4,0x1b2ba1518094da04, + 0x87aa9aff79042286,0x90fb44d2f05d0842, + 0xa99541bf57452b28,0x353a1607ac744a53, + 0xd3fa922f2d1675f2,0x42889b8997915ce8, + 0x847c9b5d7c2e09b7,0x69956135febada11, + 0xa59bc234db398c25,0x43fab9837e699095, + 0xcf02b2c21207ef2e,0x94f967e45e03f4bb, + 0x8161afb94b44f57d,0x1d1be0eebac278f5, + 0xa1ba1ba79e1632dc,0x6462d92a69731732, + 0xca28a291859bbf93,0x7d7b8f7503cfdcfe, + 0xfcb2cb35e702af78,0x5cda735244c3d43e, + 0x9defbf01b061adab,0x3a0888136afa64a7, + 0xc56baec21c7a1916,0x88aaa1845b8fdd0, + 0xf6c69a72a3989f5b,0x8aad549e57273d45, + 0x9a3c2087a63f6399,0x36ac54e2f678864b, + 0xc0cb28a98fcf3c7f,0x84576a1bb416a7dd, + 0xf0fdf2d3f3c30b9f,0x656d44a2a11c51d5, + 0x969eb7c47859e743,0x9f644ae5a4b1b325, + 0xbc4665b596706114,0x873d5d9f0dde1fee, + 0xeb57ff22fc0c7959,0xa90cb506d155a7ea, + 0x9316ff75dd87cbd8,0x9a7f12442d588f2, + 0xb7dcbf5354e9bece,0xc11ed6d538aeb2f, + 0xe5d3ef282a242e81,0x8f1668c8a86da5fa, + 0x8fa475791a569d10,0xf96e017d694487bc, + 0xb38d92d760ec4455,0x37c981dcc395a9ac, + 0xe070f78d3927556a,0x85bbe253f47b1417, + 0x8c469ab843b89562,0x93956d7478ccec8e, + 0xaf58416654a6babb,0x387ac8d1970027b2, + 0xdb2e51bfe9d0696a,0x6997b05fcc0319e, + 0x88fcf317f22241e2,0x441fece3bdf81f03, + 0xab3c2fddeeaad25a,0xd527e81cad7626c3, + 0xd60b3bd56a5586f1,0x8a71e223d8d3b074, + 0x85c7056562757456,0xf6872d5667844e49, + 0xa738c6bebb12d16c,0xb428f8ac016561db, + 0xd106f86e69d785c7,0xe13336d701beba52, + 0x82a45b450226b39c,0xecc0024661173473, + 0xa34d721642b06084,0x27f002d7f95d0190, + 0xcc20ce9bd35c78a5,0x31ec038df7b441f4, + 0xff290242c83396ce,0x7e67047175a15271, + 0x9f79a169bd203e41,0xf0062c6e984d386, + 0xc75809c42c684dd1,0x52c07b78a3e60868, + 0xf92e0c3537826145,0xa7709a56ccdf8a82, + 0x9bbcc7a142b17ccb,0x88a66076400bb691, + 0xc2abf989935ddbfe,0x6acff893d00ea435, + 0xf356f7ebf83552fe,0x583f6b8c4124d43, + 0x98165af37b2153de,0xc3727a337a8b704a, + 0xbe1bf1b059e9a8d6,0x744f18c0592e4c5c, + 0xeda2ee1c7064130c,0x1162def06f79df73, + 0x9485d4d1c63e8be7,0x8addcb5645ac2ba8, + 0xb9a74a0637ce2ee1,0x6d953e2bd7173692, + 0xe8111c87c5c1ba99,0xc8fa8db6ccdd0437, + 0x910ab1d4db9914a0,0x1d9c9892400a22a2, + 0xb54d5e4a127f59c8,0x2503beb6d00cab4b, + 0xe2a0b5dc971f303a,0x2e44ae64840fd61d, + 0x8da471a9de737e24,0x5ceaecfed289e5d2, + 0xb10d8e1456105dad,0x7425a83e872c5f47, + 0xdd50f1996b947518,0xd12f124e28f77719, + 0x8a5296ffe33cc92f,0x82bd6b70d99aaa6f, + 0xace73cbfdc0bfb7b,0x636cc64d1001550b, + 0xd8210befd30efa5a,0x3c47f7e05401aa4e, + 0x8714a775e3e95c78,0x65acfaec34810a71, + 0xa8d9d1535ce3b396,0x7f1839a741a14d0d, + 0xd31045a8341ca07c,0x1ede48111209a050, + 0x83ea2b892091e44d,0x934aed0aab460432, + 0xa4e4b66b68b65d60,0xf81da84d5617853f, + 0xce1de40642e3f4b9,0x36251260ab9d668e, + 0x80d2ae83e9ce78f3,0xc1d72b7c6b426019, + 0xa1075a24e4421730,0xb24cf65b8612f81f, + 0xc94930ae1d529cfc,0xdee033f26797b627, + 0xfb9b7cd9a4a7443c,0x169840ef017da3b1, + 0x9d412e0806e88aa5,0x8e1f289560ee864e, + 0xc491798a08a2ad4e,0xf1a6f2bab92a27e2, + 0xf5b5d7ec8acb58a2,0xae10af696774b1db, + 0x9991a6f3d6bf1765,0xacca6da1e0a8ef29, + 0xbff610b0cc6edd3f,0x17fd090a58d32af3, + 0xeff394dcff8a948e,0xddfc4b4cef07f5b0, + 0x95f83d0a1fb69cd9,0x4abdaf101564f98e, + 0xbb764c4ca7a4440f,0x9d6d1ad41abe37f1, + 0xea53df5fd18d5513,0x84c86189216dc5ed, + 0x92746b9be2f8552c,0x32fd3cf5b4e49bb4, + 0xb7118682dbb66a77,0x3fbc8c33221dc2a1, + 0xe4d5e82392a40515,0xfabaf3feaa5334a, + 0x8f05b1163ba6832d,0x29cb4d87f2a7400e, + 0xb2c71d5bca9023f8,0x743e20e9ef511012, + 0xdf78e4b2bd342cf6,0x914da9246b255416, + 0x8bab8eefb6409c1a,0x1ad089b6c2f7548e, + 0xae9672aba3d0c320,0xa184ac2473b529b1, + 0xda3c0f568cc4f3e8,0xc9e5d72d90a2741e, + 0x8865899617fb1871,0x7e2fa67c7a658892, + 0xaa7eebfb9df9de8d,0xddbb901b98feeab7, + 0xd51ea6fa85785631,0x552a74227f3ea565, + 0x8533285c936b35de,0xd53a88958f87275f, + 0xa67ff273b8460356,0x8a892abaf368f137, + 0xd01fef10a657842c,0x2d2b7569b0432d85, + 0x8213f56a67f6b29b,0x9c3b29620e29fc73, + 0xa298f2c501f45f42,0x8349f3ba91b47b8f, + 0xcb3f2f7642717713,0x241c70a936219a73, + 0xfe0efb53d30dd4d7,0xed238cd383aa0110, + 0x9ec95d1463e8a506,0xf4363804324a40aa, + 0xc67bb4597ce2ce48,0xb143c6053edcd0d5, + 0xf81aa16fdc1b81da,0xdd94b7868e94050a, + 0x9b10a4e5e9913128,0xca7cf2b4191c8326, + 0xc1d4ce1f63f57d72,0xfd1c2f611f63a3f0, + 0xf24a01a73cf2dccf,0xbc633b39673c8cec, + 0x976e41088617ca01,0xd5be0503e085d813, + 0xbd49d14aa79dbc82,0x4b2d8644d8a74e18, + 0xec9c459d51852ba2,0xddf8e7d60ed1219e, + 0x93e1ab8252f33b45,0xcabb90e5c942b503, + 0xb8da1662e7b00a17,0x3d6a751f3b936243, + 0xe7109bfba19c0c9d,0xcc512670a783ad4, + 0x906a617d450187e2,0x27fb2b80668b24c5, + 0xb484f9dc9641e9da,0xb1f9f660802dedf6, + 0xe1a63853bbd26451,0x5e7873f8a0396973, + 0x8d07e33455637eb2,0xdb0b487b6423e1e8, + 0xb049dc016abc5e5f,0x91ce1a9a3d2cda62, + 0xdc5c5301c56b75f7,0x7641a140cc7810fb, + 0x89b9b3e11b6329ba,0xa9e904c87fcb0a9d, + 0xac2820d9623bf429,0x546345fa9fbdcd44, + 0xd732290fbacaf133,0xa97c177947ad4095, + 0x867f59a9d4bed6c0,0x49ed8eabcccc485d, + 0xa81f301449ee8c70,0x5c68f256bfff5a74, + 0xd226fc195c6a2f8c,0x73832eec6fff3111, + 0x83585d8fd9c25db7,0xc831fd53c5ff7eab, + 0xa42e74f3d032f525,0xba3e7ca8b77f5e55, + 0xcd3a1230c43fb26f,0x28ce1bd2e55f35eb, + 0x80444b5e7aa7cf85,0x7980d163cf5b81b3, + 0xa0555e361951c366,0xd7e105bcc332621f, + 0xc86ab5c39fa63440,0x8dd9472bf3fefaa7, + 0xfa856334878fc150,0xb14f98f6f0feb951, + 0x9c935e00d4b9d8d2,0x6ed1bf9a569f33d3, + 0xc3b8358109e84f07,0xa862f80ec4700c8, + 0xf4a642e14c6262c8,0xcd27bb612758c0fa, + 0x98e7e9cccfbd7dbd,0x8038d51cb897789c, + 0xbf21e44003acdd2c,0xe0470a63e6bd56c3, + 0xeeea5d5004981478,0x1858ccfce06cac74, + 0x95527a5202df0ccb,0xf37801e0c43ebc8, + 0xbaa718e68396cffd,0xd30560258f54e6ba, + 0xe950df20247c83fd,0x47c6b82ef32a2069, + 0x91d28b7416cdd27e,0x4cdc331d57fa5441, + 0xb6472e511c81471d,0xe0133fe4adf8e952, + 0xe3d8f9e563a198e5,0x58180fddd97723a6, + 0x8e679c2f5e44ff8f,0x570f09eaa7ea7648,}; +using powers = powers_template<>; + +} + +#endif + +#ifndef FASTFLOAT_DECIMAL_TO_BINARY_H +#define FASTFLOAT_DECIMAL_TO_BINARY_H + +#include +#include +#include +#include +#include +#include +#include + +namespace duckdb_fast_float { + +// This will compute or rather approximate w * 5**q and return a pair of 64-bit words approximating +// the result, with the "high" part corresponding to the most significant bits and the +// low part corresponding to the least significant bits. +// +template +fastfloat_really_inline +value128 compute_product_approximation(int64_t q, uint64_t w) { + const int index = 2 * int(q - powers::smallest_power_of_five); + // For small values of q, e.g., q in [0,27], the answer is always exact because + // The line value128 firstproduct = full_multiplication(w, power_of_five_128[index]); + // gives the exact answer. + value128 firstproduct = full_multiplication(w, powers::power_of_five_128[index]); + static_assert((bit_precision >= 0) && (bit_precision <= 64), " precision should be in (0,64]"); + constexpr uint64_t precision_mask = (bit_precision < 64) ? + (uint64_t(0xFFFFFFFFFFFFFFFF) >> bit_precision) + : uint64_t(0xFFFFFFFFFFFFFFFF); + if((firstproduct.high & precision_mask) == precision_mask) { // could further guard with (lower + w < lower) + // regarding the second product, we only need secondproduct.high, but our expectation is that the compiler will optimize this extra work away if needed. + value128 secondproduct = full_multiplication(w, powers::power_of_five_128[index + 1]); + firstproduct.low += secondproduct.high; + if(secondproduct.high > firstproduct.low) { + firstproduct.high++; + } + } + return firstproduct; +} + +namespace detail { +/** + * For q in (0,350), we have that + * f = (((152170 + 65536) * q ) >> 16); + * is equal to + * floor(p) + q + * where + * p = log(5**q)/log(2) = q * log(5)/log(2) + * + * For negative values of q in (-400,0), we have that + * f = (((152170 + 65536) * q ) >> 16); + * is equal to + * -ceil(p) + q + * where + * p = log(5**-q)/log(2) = -q * log(5)/log(2) + */ + fastfloat_really_inline int power(int q) noexcept { + return (((152170 + 65536) * q) >> 16) + 63; + } +} // namespace detail + + +// w * 10 ** q +// The returned value should be a valid ieee64 number that simply need to be packed. +// However, in some very rare cases, the computation will fail. In such cases, we +// return an adjusted_mantissa with a negative power of 2: the caller should recompute +// in such cases. +template +fastfloat_really_inline +adjusted_mantissa compute_float(int64_t q, uint64_t w) noexcept { + adjusted_mantissa answer; + if ((w == 0) || (q < binary::smallest_power_of_ten())) { + answer.power2 = 0; + answer.mantissa = 0; + // result should be zero + return answer; + } + if (q > binary::largest_power_of_ten()) { + // we want to get infinity: + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + // At this point in time q is in [powers::smallest_power_of_five, powers::largest_power_of_five]. + + // We want the most significant bit of i to be 1. Shift if needed. + int lz = leading_zeroes(w); + w <<= lz; + + // The required precision is binary::mantissa_explicit_bits() + 3 because + // 1. We need the implicit bit + // 2. We need an extra bit for rounding purposes + // 3. We might lose a bit due to the "upperbit" routine (result too small, requiring a shift) + + value128 product = compute_product_approximation(q, w); + if(product.low == 0xFFFFFFFFFFFFFFFF) { // could guard it further + // In some very rare cases, this could happen, in which case we might need a more accurate + // computation that what we can provide cheaply. This is very, very unlikely. + // + const bool inside_safe_exponent = (q >= -27) && (q <= 55); // always good because 5**q <2**128 when q>=0, + // and otherwise, for q<0, we have 5**-q<2**64 and the 128-bit reciprocal allows for exact computation. + if(!inside_safe_exponent) { + answer.power2 = -1; // This (a negative value) indicates an error condition. + return answer; + } + } + // The "compute_product_approximation" function can be slightly slower than a branchless approach: + // value128 product = compute_product(q, w); + // but in practice, we can win big with the compute_product_approximation if its additional branch + // is easily predicted. Which is best is data specific. + int upperbit = int(product.high >> 63); + + answer.mantissa = product.high >> (upperbit + 64 - binary::mantissa_explicit_bits() - 3); + + answer.power2 = int(detail::power(int(q)) + upperbit - lz - binary::minimum_exponent()); + if (answer.power2 <= 0) { // we have a subnormal? + // Here have that answer.power2 <= 0 so -answer.power2 >= 0 + if(-answer.power2 + 1 >= 64) { // if we have more than 64 bits below the minimum exponent, you have a zero for sure. + answer.power2 = 0; + answer.mantissa = 0; + // result should be zero + return answer; + } + // next line is safe because -answer.power2 + 1 < 64 + answer.mantissa >>= -answer.power2 + 1; + // Thankfully, we can't have both "round-to-even" and subnormals because + // "round-to-even" only occurs for powers close to 0. + answer.mantissa += (answer.mantissa & 1); // round up + answer.mantissa >>= 1; + // There is a weird scenario where we don't have a subnormal but just. + // Suppose we start with 2.2250738585072013e-308, we end up + // with 0x3fffffffffffff x 2^-1023-53 which is technically subnormal + // whereas 0x40000000000000 x 2^-1023-53 is normal. Now, we need to round + // up 0x3fffffffffffff x 2^-1023-53 and once we do, we are no longer + // subnormal, but we can only know this after rounding. + // So we only declare a subnormal if we are smaller than the threshold. + answer.power2 = (answer.mantissa < (uint64_t(1) << binary::mantissa_explicit_bits())) ? 0 : 1; + return answer; + } + + // usually, we round *up*, but if we fall right in between and and we have an + // even basis, we need to round down + // We are only concerned with the cases where 5**q fits in single 64-bit word. + if ((product.low <= 1) && (q >= binary::min_exponent_round_to_even()) && (q <= binary::max_exponent_round_to_even()) && + ((answer.mantissa & 3) == 1) ) { // we may fall between two floats! + // To be in-between two floats we need that in doing + // answer.mantissa = product.high >> (upperbit + 64 - binary::mantissa_explicit_bits() - 3); + // ... we dropped out only zeroes. But if this happened, then we can go back!!! + if((answer.mantissa << (upperbit + 64 - binary::mantissa_explicit_bits() - 3)) == product.high) { + answer.mantissa &= ~uint64_t(1); // flip it so that we do not round up + } + } + + answer.mantissa += (answer.mantissa & 1); // round up + answer.mantissa >>= 1; + if (answer.mantissa >= (uint64_t(2) << binary::mantissa_explicit_bits())) { + answer.mantissa = (uint64_t(1) << binary::mantissa_explicit_bits()); + answer.power2++; // undo previous addition + } + + answer.mantissa &= ~(uint64_t(1) << binary::mantissa_explicit_bits()); + if (answer.power2 >= binary::infinite_power()) { // infinity + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + } + return answer; +} + + +} // namespace duckdb_fast_float + +#endif + + +#ifndef FASTFLOAT_ASCII_NUMBER_H +#define FASTFLOAT_ASCII_NUMBER_H + +#include +#include +#include +#include + + +namespace duckdb_fast_float { + +// Next function can be micro-optimized, but compilers are entirely +// able to optimize it well. +fastfloat_really_inline bool is_integer(char c) noexcept { return c >= '0' && c <= '9'; } + +fastfloat_really_inline uint64_t byteswap(uint64_t val) { + return (val & 0xFF00000000000000) >> 56 + | (val & 0x00FF000000000000) >> 40 + | (val & 0x0000FF0000000000) >> 24 + | (val & 0x000000FF00000000) >> 8 + | (val & 0x00000000FF000000) << 8 + | (val & 0x0000000000FF0000) << 24 + | (val & 0x000000000000FF00) << 40 + | (val & 0x00000000000000FF) << 56; +} + +fastfloat_really_inline uint64_t read_u64(const char *chars) { + uint64_t val; + ::memcpy(&val, chars, sizeof(uint64_t)); +#if FASTFLOAT_IS_BIG_ENDIAN == 1 + // Need to read as-if the number was in little-endian order. + val = byteswap(val); +#endif + return val; +} + +fastfloat_really_inline void write_u64(uint8_t *chars, uint64_t val) { +#if FASTFLOAT_IS_BIG_ENDIAN == 1 + // Need to read as-if the number was in little-endian order. + val = byteswap(val); +#endif + ::memcpy(chars, &val, sizeof(uint64_t)); +} + +// credit @aqrit +fastfloat_really_inline uint32_t parse_eight_digits_unrolled(uint64_t val) { + const uint64_t mask = 0x000000FF000000FF; + const uint64_t mul1 = 0x000F424000000064; // 100 + (1000000ULL << 32) + const uint64_t mul2 = 0x0000271000000001; // 1 + (10000ULL << 32) + val -= 0x3030303030303030; + val = (val * 10) + (val >> 8); // val = (val * 2561) >> 8; + val = (((val & mask) * mul1) + (((val >> 16) & mask) * mul2)) >> 32; + return uint32_t(val); +} + +fastfloat_really_inline uint32_t parse_eight_digits_unrolled(const char *chars) noexcept { + return parse_eight_digits_unrolled(read_u64(chars)); +} + +// credit @aqrit +fastfloat_really_inline bool is_made_of_eight_digits_fast(uint64_t val) noexcept { + return !((((val + 0x4646464646464646) | (val - 0x3030303030303030)) & + 0x8080808080808080)); +} + +fastfloat_really_inline bool is_made_of_eight_digits_fast(const char *chars) noexcept { + return is_made_of_eight_digits_fast(read_u64(chars)); +} + +struct parsed_number_string { + int64_t exponent; + uint64_t mantissa; + const char *lastmatch; + bool negative; + bool valid; + bool too_many_digits; +}; + + +// Assuming that you use no more than 19 digits, this will +// parse an ASCII string. +fastfloat_really_inline +parsed_number_string parse_number_string(const char *p, const char *pend, const char decimal_separator, chars_format fmt) noexcept { + parsed_number_string answer; + answer.valid = false; + answer.too_many_digits = false; + answer.negative = (*p == '-'); + if (*p == '-') { // C++17 20.19.3.(7.1) explicitly forbids '+' sign here + ++p; + if (p == pend) { + return answer; + } + if (!is_integer(*p) && (*p != decimal_separator)) { // a sign must be followed by an integer or the dot + return answer; + } + } + const char *const start_digits = p; + + uint64_t i = 0; // an unsigned int avoids signed overflows (which are bad) + + while ((p != pend) && is_integer(*p)) { + // a multiplication by 10 is cheaper than an arbitrary integer + // multiplication + i = 10 * i + + uint64_t(*p - '0'); // might overflow, we will handle the overflow later + ++p; + } + const char *const end_of_integer_part = p; + int64_t digit_count = int64_t(end_of_integer_part - start_digits); + int64_t exponent = 0; + if ((p != pend) && (*p == decimal_separator)) { + ++p; + // Fast approach only tested under little endian systems + if ((p + 8 <= pend) && is_made_of_eight_digits_fast(p)) { + i = i * 100000000 + parse_eight_digits_unrolled(p); // in rare cases, this will overflow, but that's ok + p += 8; + if ((p + 8 <= pend) && is_made_of_eight_digits_fast(p)) { + i = i * 100000000 + parse_eight_digits_unrolled(p); // in rare cases, this will overflow, but that's ok + p += 8; + } + } + while ((p != pend) && is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + ++p; + i = i * 10 + digit; // in rare cases, this will overflow, but that's ok + } + exponent = end_of_integer_part + 1 - p; + digit_count -= exponent; + } + // we must have encountered at least one integer! + if (digit_count == 0) { + return answer; + } + int64_t exp_number = 0; // explicit exponential part + if ((fmt & chars_format::scientific) && (p != pend) && (('e' == *p) || ('E' == *p))) { + const char * location_of_e = p; + ++p; + bool neg_exp = false; + if ((p != pend) && ('-' == *p)) { + neg_exp = true; + ++p; + } else if ((p != pend) && ('+' == *p)) { // '+' on exponent is allowed by C++17 20.19.3.(7.1) + ++p; + } + if ((p == pend) || !is_integer(*p)) { + if(!(fmt & chars_format::fixed)) { + // We are in error. + return answer; + } + // Otherwise, we will be ignoring the 'e'. + p = location_of_e; + } else { + while ((p != pend) && is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + if (exp_number < 0x10000) { + exp_number = 10 * exp_number + digit; + } + ++p; + } + if(neg_exp) { exp_number = - exp_number; } + exponent += exp_number; + } + } else { + // If it scientific and not fixed, we have to bail out. + if((fmt & chars_format::scientific) && !(fmt & chars_format::fixed)) { return answer; } + } + answer.lastmatch = p; + answer.valid = true; + + // If we frequently had to deal with long strings of digits, + // we could extend our code by using a 128-bit integer instead + // of a 64-bit integer. However, this is uncommon. + // + // We can deal with up to 19 digits. + if (digit_count > 19) { // this is uncommon + // It is possible that the integer had an overflow. + // We have to handle the case where we have 0.0000somenumber. + // We need to be mindful of the case where we only have zeroes... + // E.g., 0.000000000...000. + const char *start = start_digits; + while ((start != pend) && (*start == '0' || *start == decimal_separator)) { + if(*start == '0') { digit_count --; } + start++; + } + if (digit_count > 19) { + answer.too_many_digits = true; + // Let us start again, this time, avoiding overflows. + i = 0; + p = start_digits; + const uint64_t minimal_nineteen_digit_integer{1000000000000000000}; + while((i < minimal_nineteen_digit_integer) && (p != pend) && is_integer(*p)) { + i = i * 10 + uint64_t(*p - '0'); + ++p; + } + if (i >= minimal_nineteen_digit_integer) { // We have a big integers + exponent = end_of_integer_part - p + exp_number; + } else { // We have a value with a fractional component. + p++; // skip the decimal_separator + const char *first_after_period = p; + while((i < minimal_nineteen_digit_integer) && (p != pend) && is_integer(*p)) { + i = i * 10 + uint64_t(*p - '0'); + ++p; + } + exponent = first_after_period - p + exp_number; + } + // We have now corrected both exponent and i, to a truncated value + } + } + answer.exponent = exponent; + answer.mantissa = i; + return answer; +} + + +// This should always succeed since it follows a call to parse_number_string +// This function could be optimized. In particular, we could stop after 19 digits +// and try to bail out. Furthermore, we should be able to recover the computed +// exponent from the pass in parse_number_string. +fastfloat_really_inline decimal parse_decimal(const char *p, const char *pend, const char decimal_separator) noexcept { + decimal answer; + answer.num_digits = 0; + answer.decimal_point = 0; + answer.truncated = false; + answer.negative = (*p == '-'); + if (*p == '-') { // C++17 20.19.3.(7.1) explicitly forbids '+' sign here + ++p; + } + // skip leading zeroes + while ((p != pend) && (*p == '0')) { + ++p; + } + while ((p != pend) && is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + if ((p != pend) && (*p == decimal_separator)) { + ++p; + const char *first_after_period = p; + // if we have not yet encountered a zero, we have to skip it as well + if(answer.num_digits == 0) { + // skip zeros + while ((p != pend) && (*p == '0')) { + ++p; + } + } + // We expect that this loop will often take the bulk of the running time + // because when a value has lots of digits, these digits often + while ((p + 8 <= pend) && (answer.num_digits + 8 < max_digits)) { + uint64_t val = read_u64(p); + if(! is_made_of_eight_digits_fast(val)) { break; } + // We have eight digits, process them in one go! + val -= 0x3030303030303030; + write_u64(answer.digits + answer.num_digits, val); + answer.num_digits += 8; + p += 8; + } + while ((p != pend) && is_integer(*p)) { + if (answer.num_digits < max_digits) { + answer.digits[answer.num_digits] = uint8_t(*p - '0'); + } + answer.num_digits++; + ++p; + } + answer.decimal_point = int32_t(first_after_period - p); + } + // We want num_digits to be the number of significant digits, excluding + // leading *and* trailing zeros! Otherwise the truncated flag later is + // going to be misleading. + if(answer.num_digits > 0) { + // We potentially need the answer.num_digits > 0 guard because we + // prune leading zeros. So with answer.num_digits > 0, we know that + // we have at least one non-zero digit. + const char *preverse = p - 1; + int32_t trailing_zeros = 0; + while ((*preverse == '0') || (*preverse == decimal_separator)) { + if(*preverse == '0') { trailing_zeros++; }; + --preverse; + } + answer.decimal_point += int32_t(answer.num_digits); + answer.num_digits -= uint32_t(trailing_zeros); + } + if(answer.num_digits > max_digits) { + answer.truncated = true; + answer.num_digits = max_digits; + } + if ((p != pend) && (('e' == *p) || ('E' == *p))) { + ++p; + bool neg_exp = false; + if ((p != pend) && ('-' == *p)) { + neg_exp = true; + ++p; + } else if ((p != pend) && ('+' == *p)) { // '+' on exponent is allowed by C++17 20.19.3.(7.1) + ++p; + } + int32_t exp_number = 0; // exponential part + while ((p != pend) && is_integer(*p)) { + uint8_t digit = uint8_t(*p - '0'); + if (exp_number < 0x10000) { + exp_number = 10 * exp_number + digit; + } + ++p; + } + answer.decimal_point += (neg_exp ? -exp_number : exp_number); + } + // In very rare cases, we may have fewer than 19 digits, we want to be able to reliably + // assume that all digits up to max_digit_without_overflow have been initialized. + for(uint32_t i = answer.num_digits; i < max_digit_without_overflow; i++) { answer.digits[i] = 0; } + + return answer; +} +} // namespace duckdb_fast_float + +#endif + + +#ifndef FASTFLOAT_GENERIC_DECIMAL_TO_BINARY_H +#define FASTFLOAT_GENERIC_DECIMAL_TO_BINARY_H + +/** + * This code is meant to handle the case where we have more than 19 digits. + * + * It is based on work by Nigel Tao (at https://github.com/google/wuffs/) + * who credits Ken Thompson for the design (via a reference to the Go source + * code). + * + * Rob Pike suggested that this algorithm be called "Simple Decimal Conversion". + * + * It is probably not very fast but it is a fallback that should almost never + * be used in real life. Though it is not fast, it is "easily" understood and debugged. + **/ +#include + +namespace duckdb_fast_float { + +namespace detail { + +// remove all final zeroes +inline void trim(decimal &h) { + while ((h.num_digits > 0) && (h.digits[h.num_digits - 1] == 0)) { + h.num_digits--; + } +} + + + +inline uint32_t number_of_digits_decimal_left_shift(const decimal &h, uint32_t shift) { + shift &= 63; + const static uint16_t number_of_digits_decimal_left_shift_table[65] = { + 0x0000, 0x0800, 0x0801, 0x0803, 0x1006, 0x1009, 0x100D, 0x1812, 0x1817, + 0x181D, 0x2024, 0x202B, 0x2033, 0x203C, 0x2846, 0x2850, 0x285B, 0x3067, + 0x3073, 0x3080, 0x388E, 0x389C, 0x38AB, 0x38BB, 0x40CC, 0x40DD, 0x40EF, + 0x4902, 0x4915, 0x4929, 0x513E, 0x5153, 0x5169, 0x5180, 0x5998, 0x59B0, + 0x59C9, 0x61E3, 0x61FD, 0x6218, 0x6A34, 0x6A50, 0x6A6D, 0x6A8B, 0x72AA, + 0x72C9, 0x72E9, 0x7B0A, 0x7B2B, 0x7B4D, 0x8370, 0x8393, 0x83B7, 0x83DC, + 0x8C02, 0x8C28, 0x8C4F, 0x9477, 0x949F, 0x94C8, 0x9CF2, 0x051C, 0x051C, + 0x051C, 0x051C, + }; + uint32_t x_a = number_of_digits_decimal_left_shift_table[shift]; + uint32_t x_b = number_of_digits_decimal_left_shift_table[shift + 1]; + uint32_t num_new_digits = x_a >> 11; + uint32_t pow5_a = 0x7FF & x_a; + uint32_t pow5_b = 0x7FF & x_b; + const static uint8_t + number_of_digits_decimal_left_shift_table_powers_of_5[0x051C] = { + 5, 2, 5, 1, 2, 5, 6, 2, 5, 3, 1, 2, 5, 1, 5, 6, 2, 5, 7, 8, 1, 2, 5, 3, + 9, 0, 6, 2, 5, 1, 9, 5, 3, 1, 2, 5, 9, 7, 6, 5, 6, 2, 5, 4, 8, 8, 2, 8, + 1, 2, 5, 2, 4, 4, 1, 4, 0, 6, 2, 5, 1, 2, 2, 0, 7, 0, 3, 1, 2, 5, 6, 1, + 0, 3, 5, 1, 5, 6, 2, 5, 3, 0, 5, 1, 7, 5, 7, 8, 1, 2, 5, 1, 5, 2, 5, 8, + 7, 8, 9, 0, 6, 2, 5, 7, 6, 2, 9, 3, 9, 4, 5, 3, 1, 2, 5, 3, 8, 1, 4, 6, + 9, 7, 2, 6, 5, 6, 2, 5, 1, 9, 0, 7, 3, 4, 8, 6, 3, 2, 8, 1, 2, 5, 9, 5, + 3, 6, 7, 4, 3, 1, 6, 4, 0, 6, 2, 5, 4, 7, 6, 8, 3, 7, 1, 5, 8, 2, 0, 3, + 1, 2, 5, 2, 3, 8, 4, 1, 8, 5, 7, 9, 1, 0, 1, 5, 6, 2, 5, 1, 1, 9, 2, 0, + 9, 2, 8, 9, 5, 5, 0, 7, 8, 1, 2, 5, 5, 9, 6, 0, 4, 6, 4, 4, 7, 7, 5, 3, + 9, 0, 6, 2, 5, 2, 9, 8, 0, 2, 3, 2, 2, 3, 8, 7, 6, 9, 5, 3, 1, 2, 5, 1, + 4, 9, 0, 1, 1, 6, 1, 1, 9, 3, 8, 4, 7, 6, 5, 6, 2, 5, 7, 4, 5, 0, 5, 8, + 0, 5, 9, 6, 9, 2, 3, 8, 2, 8, 1, 2, 5, 3, 7, 2, 5, 2, 9, 0, 2, 9, 8, 4, + 6, 1, 9, 1, 4, 0, 6, 2, 5, 1, 8, 6, 2, 6, 4, 5, 1, 4, 9, 2, 3, 0, 9, 5, + 7, 0, 3, 1, 2, 5, 9, 3, 1, 3, 2, 2, 5, 7, 4, 6, 1, 5, 4, 7, 8, 5, 1, 5, + 6, 2, 5, 4, 6, 5, 6, 6, 1, 2, 8, 7, 3, 0, 7, 7, 3, 9, 2, 5, 7, 8, 1, 2, + 5, 2, 3, 2, 8, 3, 0, 6, 4, 3, 6, 5, 3, 8, 6, 9, 6, 2, 8, 9, 0, 6, 2, 5, + 1, 1, 6, 4, 1, 5, 3, 2, 1, 8, 2, 6, 9, 3, 4, 8, 1, 4, 4, 5, 3, 1, 2, 5, + 5, 8, 2, 0, 7, 6, 6, 0, 9, 1, 3, 4, 6, 7, 4, 0, 7, 2, 2, 6, 5, 6, 2, 5, + 2, 9, 1, 0, 3, 8, 3, 0, 4, 5, 6, 7, 3, 3, 7, 0, 3, 6, 1, 3, 2, 8, 1, 2, + 5, 1, 4, 5, 5, 1, 9, 1, 5, 2, 2, 8, 3, 6, 6, 8, 5, 1, 8, 0, 6, 6, 4, 0, + 6, 2, 5, 7, 2, 7, 5, 9, 5, 7, 6, 1, 4, 1, 8, 3, 4, 2, 5, 9, 0, 3, 3, 2, + 0, 3, 1, 2, 5, 3, 6, 3, 7, 9, 7, 8, 8, 0, 7, 0, 9, 1, 7, 1, 2, 9, 5, 1, + 6, 6, 0, 1, 5, 6, 2, 5, 1, 8, 1, 8, 9, 8, 9, 4, 0, 3, 5, 4, 5, 8, 5, 6, + 4, 7, 5, 8, 3, 0, 0, 7, 8, 1, 2, 5, 9, 0, 9, 4, 9, 4, 7, 0, 1, 7, 7, 2, + 9, 2, 8, 2, 3, 7, 9, 1, 5, 0, 3, 9, 0, 6, 2, 5, 4, 5, 4, 7, 4, 7, 3, 5, + 0, 8, 8, 6, 4, 6, 4, 1, 1, 8, 9, 5, 7, 5, 1, 9, 5, 3, 1, 2, 5, 2, 2, 7, + 3, 7, 3, 6, 7, 5, 4, 4, 3, 2, 3, 2, 0, 5, 9, 4, 7, 8, 7, 5, 9, 7, 6, 5, + 6, 2, 5, 1, 1, 3, 6, 8, 6, 8, 3, 7, 7, 2, 1, 6, 1, 6, 0, 2, 9, 7, 3, 9, + 3, 7, 9, 8, 8, 2, 8, 1, 2, 5, 5, 6, 8, 4, 3, 4, 1, 8, 8, 6, 0, 8, 0, 8, + 0, 1, 4, 8, 6, 9, 6, 8, 9, 9, 4, 1, 4, 0, 6, 2, 5, 2, 8, 4, 2, 1, 7, 0, + 9, 4, 3, 0, 4, 0, 4, 0, 0, 7, 4, 3, 4, 8, 4, 4, 9, 7, 0, 7, 0, 3, 1, 2, + 5, 1, 4, 2, 1, 0, 8, 5, 4, 7, 1, 5, 2, 0, 2, 0, 0, 3, 7, 1, 7, 4, 2, 2, + 4, 8, 5, 3, 5, 1, 5, 6, 2, 5, 7, 1, 0, 5, 4, 2, 7, 3, 5, 7, 6, 0, 1, 0, + 0, 1, 8, 5, 8, 7, 1, 1, 2, 4, 2, 6, 7, 5, 7, 8, 1, 2, 5, 3, 5, 5, 2, 7, + 1, 3, 6, 7, 8, 8, 0, 0, 5, 0, 0, 9, 2, 9, 3, 5, 5, 6, 2, 1, 3, 3, 7, 8, + 9, 0, 6, 2, 5, 1, 7, 7, 6, 3, 5, 6, 8, 3, 9, 4, 0, 0, 2, 5, 0, 4, 6, 4, + 6, 7, 7, 8, 1, 0, 6, 6, 8, 9, 4, 5, 3, 1, 2, 5, 8, 8, 8, 1, 7, 8, 4, 1, + 9, 7, 0, 0, 1, 2, 5, 2, 3, 2, 3, 3, 8, 9, 0, 5, 3, 3, 4, 4, 7, 2, 6, 5, + 6, 2, 5, 4, 4, 4, 0, 8, 9, 2, 0, 9, 8, 5, 0, 0, 6, 2, 6, 1, 6, 1, 6, 9, + 4, 5, 2, 6, 6, 7, 2, 3, 6, 3, 2, 8, 1, 2, 5, 2, 2, 2, 0, 4, 4, 6, 0, 4, + 9, 2, 5, 0, 3, 1, 3, 0, 8, 0, 8, 4, 7, 2, 6, 3, 3, 3, 6, 1, 8, 1, 6, 4, + 0, 6, 2, 5, 1, 1, 1, 0, 2, 2, 3, 0, 2, 4, 6, 2, 5, 1, 5, 6, 5, 4, 0, 4, + 2, 3, 6, 3, 1, 6, 6, 8, 0, 9, 0, 8, 2, 0, 3, 1, 2, 5, 5, 5, 5, 1, 1, 1, + 5, 1, 2, 3, 1, 2, 5, 7, 8, 2, 7, 0, 2, 1, 1, 8, 1, 5, 8, 3, 4, 0, 4, 5, + 4, 1, 0, 1, 5, 6, 2, 5, 2, 7, 7, 5, 5, 5, 7, 5, 6, 1, 5, 6, 2, 8, 9, 1, + 3, 5, 1, 0, 5, 9, 0, 7, 9, 1, 7, 0, 2, 2, 7, 0, 5, 0, 7, 8, 1, 2, 5, 1, + 3, 8, 7, 7, 7, 8, 7, 8, 0, 7, 8, 1, 4, 4, 5, 6, 7, 5, 5, 2, 9, 5, 3, 9, + 5, 8, 5, 1, 1, 3, 5, 2, 5, 3, 9, 0, 6, 2, 5, 6, 9, 3, 8, 8, 9, 3, 9, 0, + 3, 9, 0, 7, 2, 2, 8, 3, 7, 7, 6, 4, 7, 6, 9, 7, 9, 2, 5, 5, 6, 7, 6, 2, + 6, 9, 5, 3, 1, 2, 5, 3, 4, 6, 9, 4, 4, 6, 9, 5, 1, 9, 5, 3, 6, 1, 4, 1, + 8, 8, 8, 2, 3, 8, 4, 8, 9, 6, 2, 7, 8, 3, 8, 1, 3, 4, 7, 6, 5, 6, 2, 5, + 1, 7, 3, 4, 7, 2, 3, 4, 7, 5, 9, 7, 6, 8, 0, 7, 0, 9, 4, 4, 1, 1, 9, 2, + 4, 4, 8, 1, 3, 9, 1, 9, 0, 6, 7, 3, 8, 2, 8, 1, 2, 5, 8, 6, 7, 3, 6, 1, + 7, 3, 7, 9, 8, 8, 4, 0, 3, 5, 4, 7, 2, 0, 5, 9, 6, 2, 2, 4, 0, 6, 9, 5, + 9, 5, 3, 3, 6, 9, 1, 4, 0, 6, 2, 5, + }; + const uint8_t *pow5 = + &number_of_digits_decimal_left_shift_table_powers_of_5[pow5_a]; + uint32_t i = 0; + uint32_t n = pow5_b - pow5_a; + for (; i < n; i++) { + if (i >= h.num_digits) { + return num_new_digits - 1; + } else if (h.digits[i] == pow5[i]) { + continue; + } else if (h.digits[i] < pow5[i]) { + return num_new_digits - 1; + } else { + return num_new_digits; + } + } + return num_new_digits; +} + +inline uint64_t round(decimal &h) { + if ((h.num_digits == 0) || (h.decimal_point < 0)) { + return 0; + } else if (h.decimal_point > 18) { + return UINT64_MAX; + } + // at this point, we know that h.decimal_point >= 0 + uint32_t dp = uint32_t(h.decimal_point); + uint64_t n = 0; + for (uint32_t i = 0; i < dp; i++) { + n = (10 * n) + ((i < h.num_digits) ? h.digits[i] : 0); + } + bool round_up = false; + if (dp < h.num_digits) { + round_up = h.digits[dp] >= 5; // normally, we round up + // but we may need to round to even! + if ((h.digits[dp] == 5) && (dp + 1 == h.num_digits)) { + round_up = h.truncated || ((dp > 0) && (1 & h.digits[dp - 1])); + } + } + if (round_up) { + n++; + } + return n; +} + +// computes h * 2^-shift +inline void decimal_left_shift(decimal &h, uint32_t shift) { + if (h.num_digits == 0) { + return; + } + uint32_t num_new_digits = number_of_digits_decimal_left_shift(h, shift); + int32_t read_index = int32_t(h.num_digits - 1); + uint32_t write_index = h.num_digits - 1 + num_new_digits; + uint64_t n = 0; + + while (read_index >= 0) { + n += uint64_t(h.digits[read_index]) << shift; + uint64_t quotient = n / 10; + uint64_t remainder = n - (10 * quotient); + if (write_index < max_digits) { + h.digits[write_index] = uint8_t(remainder); + } else if (remainder > 0) { + h.truncated = true; + } + n = quotient; + write_index--; + read_index--; + } + while (n > 0) { + uint64_t quotient = n / 10; + uint64_t remainder = n - (10 * quotient); + if (write_index < max_digits) { + h.digits[write_index] = uint8_t(remainder); + } else if (remainder > 0) { + h.truncated = true; + } + n = quotient; + write_index--; + } + h.num_digits += num_new_digits; + if (h.num_digits > max_digits) { + h.num_digits = max_digits; + } + h.decimal_point += int32_t(num_new_digits); + trim(h); +} + +// computes h * 2^shift +inline void decimal_right_shift(decimal &h, uint32_t shift) { + uint32_t read_index = 0; + uint32_t write_index = 0; + + uint64_t n = 0; + + while ((n >> shift) == 0) { + if (read_index < h.num_digits) { + n = (10 * n) + h.digits[read_index++]; + } else if (n == 0) { + return; + } else { + while ((n >> shift) == 0) { + n = 10 * n; + read_index++; + } + break; + } + } + h.decimal_point -= int32_t(read_index - 1); + if (h.decimal_point < -decimal_point_range) { // it is zero + h.num_digits = 0; + h.decimal_point = 0; + h.negative = false; + h.truncated = false; + return; + } + uint64_t mask = (uint64_t(1) << shift) - 1; + while (read_index < h.num_digits) { + uint8_t new_digit = uint8_t(n >> shift); + n = (10 * (n & mask)) + h.digits[read_index++]; + h.digits[write_index++] = new_digit; + } + while (n > 0) { + uint8_t new_digit = uint8_t(n >> shift); + n = 10 * (n & mask); + if (write_index < max_digits) { + h.digits[write_index++] = new_digit; + } else if (new_digit > 0) { + h.truncated = true; + } + } + h.num_digits = write_index; + trim(h); +} + +} // namespace detail + +template +adjusted_mantissa compute_float(decimal &d) { + adjusted_mantissa answer; + if (d.num_digits == 0) { + // should be zero + answer.power2 = 0; + answer.mantissa = 0; + return answer; + } + // At this point, going further, we can assume that d.num_digits > 0. + // + // We want to guard against excessive decimal point values because + // they can result in long running times. Indeed, we do + // shifts by at most 60 bits. We have that log(10**400)/log(2**60) ~= 22 + // which is fine, but log(10**299995)/log(2**60) ~= 16609 which is not + // fine (runs for a long time). + // + if(d.decimal_point < -324) { + // We have something smaller than 1e-324 which is always zero + // in binary64 and binary32. + // It should be zero. + answer.power2 = 0; + answer.mantissa = 0; + return answer; + } else if(d.decimal_point >= 310) { + // We have something at least as large as 0.1e310 which is + // always infinite. + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + static const uint32_t max_shift = 60; + static const uint32_t num_powers = 19; + static const uint8_t decimal_powers[19] = { + 0, 3, 6, 9, 13, 16, 19, 23, 26, 29, // + 33, 36, 39, 43, 46, 49, 53, 56, 59, // + }; + int32_t exp2 = 0; + while (d.decimal_point > 0) { + uint32_t n = uint32_t(d.decimal_point); + uint32_t shift = (n < num_powers) ? decimal_powers[n] : max_shift; + detail::decimal_right_shift(d, shift); + if (d.decimal_point < -decimal_point_range) { + // should be zero + answer.power2 = 0; + answer.mantissa = 0; + return answer; + } + exp2 += int32_t(shift); + } + // We shift left toward [1/2 ... 1]. + while (d.decimal_point <= 0) { + uint32_t shift; + if (d.decimal_point == 0) { + if (d.digits[0] >= 5) { + break; + } + shift = (d.digits[0] < 2) ? 2 : 1; + } else { + uint32_t n = uint32_t(-d.decimal_point); + shift = (n < num_powers) ? decimal_powers[n] : max_shift; + } + detail::decimal_left_shift(d, shift); + if (d.decimal_point > decimal_point_range) { + // we want to get infinity: + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + exp2 -= int32_t(shift); + } + // We are now in the range [1/2 ... 1] but the binary format uses [1 ... 2]. + exp2--; + constexpr int32_t minimum_exponent = binary::minimum_exponent(); + while ((minimum_exponent + 1) > exp2) { + uint32_t n = uint32_t((minimum_exponent + 1) - exp2); + if (n > max_shift) { + n = max_shift; + } + detail::decimal_right_shift(d, n); + exp2 += int32_t(n); + } + if ((exp2 - minimum_exponent) >= binary::infinite_power()) { + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + + const int mantissa_size_in_bits = binary::mantissa_explicit_bits() + 1; + detail::decimal_left_shift(d, mantissa_size_in_bits); + + uint64_t mantissa = detail::round(d); + // It is possible that we have an overflow, in which case we need + // to shift back. + if(mantissa >= (uint64_t(1) << mantissa_size_in_bits)) { + detail::decimal_right_shift(d, 1); + exp2 += 1; + mantissa = detail::round(d); + if ((exp2 - minimum_exponent) >= binary::infinite_power()) { + answer.power2 = binary::infinite_power(); + answer.mantissa = 0; + return answer; + } + } + answer.power2 = exp2 - binary::minimum_exponent(); + if(mantissa < (uint64_t(1) << binary::mantissa_explicit_bits())) { answer.power2--; } + answer.mantissa = mantissa & ((uint64_t(1) << binary::mantissa_explicit_bits()) - 1); + return answer; +} + +template +adjusted_mantissa parse_long_mantissa(const char *first, const char* last) { + decimal d = parse_decimal(first, last); + return compute_float(d); +} + +} // namespace duckdb_fast_float +#endif + + +#ifndef FASTFLOAT_PARSE_NUMBER_H +#define FASTFLOAT_PARSE_NUMBER_H + +#include +#include +#include +#include +#include + +namespace duckdb_fast_float { + + +namespace detail { +/** + * Special case +inf, -inf, nan, infinity, -infinity. + * The case comparisons could be made much faster given that we know that the + * strings a null-free and fixed. + **/ +template +from_chars_result parse_infnan(const char *first, const char *last, T &value) noexcept { + from_chars_result answer; + answer.ptr = first; + answer.ec = std::errc(); // be optimistic + bool minusSign = false; + if (*first == '-') { // assume first < last, so dereference without checks; C++17 20.19.3.(7.1) explicitly forbids '+' here + minusSign = true; + ++first; + } + if (last - first >= 3) { + if (fastfloat_strncasecmp(first, "nan", 3)) { + answer.ptr = (first += 3); + value = minusSign ? -std::numeric_limits::quiet_NaN() : std::numeric_limits::quiet_NaN(); + // Check for possible nan(n-char-seq-opt), C++17 20.19.3.7, C11 7.20.1.3.3. At least MSVC produces nan(ind) and nan(snan). + if(first != last && *first == '(') { + for(const char* ptr = first + 1; ptr != last; ++ptr) { + if (*ptr == ')') { + answer.ptr = ptr + 1; // valid nan(n-char-seq-opt) + break; + } + else if(!(('a' <= *ptr && *ptr <= 'z') || ('A' <= *ptr && *ptr <= 'Z') || ('0' <= *ptr && *ptr <= '9') || *ptr == '_')) + break; // forbidden char, not nan(n-char-seq-opt) + } + } + return answer; + } + if (fastfloat_strncasecmp(first, "inf", 3)) { + if ((last - first >= 8) && fastfloat_strncasecmp(first + 3, "inity", 5)) { + answer.ptr = first + 8; + } else { + answer.ptr = first + 3; + } + value = minusSign ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); + return answer; + } + } + answer.ec = std::errc::invalid_argument; + return answer; +} + +template +fastfloat_really_inline void to_float(bool negative, adjusted_mantissa am, T &value) { + uint64_t word = am.mantissa; + word |= uint64_t(am.power2) << binary_format::mantissa_explicit_bits(); + word = negative + ? word | (uint64_t(1) << binary_format::sign_index()) : word; +#if FASTFLOAT_IS_BIG_ENDIAN == 1 + if (std::is_same::value) { + ::memcpy(&value, (char *)&word + 4, sizeof(T)); // extract value at offset 4-7 if float on big-endian + } else { + ::memcpy(&value, &word, sizeof(T)); + } +#else + // For little-endian systems: + ::memcpy(&value, &word, sizeof(T)); +#endif +} + +} // namespace detail + + + +template +from_chars_result from_chars(const char *first, const char *last, + T &value, const char decimal_separator, chars_format fmt + /*= chars_format::general*/) noexcept { + static_assert (std::is_same::value || std::is_same::value, "only float and double are supported"); + + + from_chars_result answer; + if (first == last) { + answer.ec = std::errc::invalid_argument; + answer.ptr = first; + return answer; + } + parsed_number_string pns = parse_number_string(first, last, decimal_separator, fmt); + if (!pns.valid) { + return detail::parse_infnan(first, last, value); + } + answer.ec = std::errc(); // be optimistic + answer.ptr = pns.lastmatch; + // Next is Clinger's fast path. + if (binary_format::min_exponent_fast_path() <= pns.exponent && pns.exponent <= binary_format::max_exponent_fast_path() && pns.mantissa <=binary_format::max_mantissa_fast_path() && !pns.too_many_digits) { + value = T(pns.mantissa); + if (pns.exponent < 0) { value = value / binary_format::exact_power_of_ten(-pns.exponent); } + else { value = value * binary_format::exact_power_of_ten(pns.exponent); } + if (pns.negative) { value = -value; } + return answer; + } + adjusted_mantissa am = compute_float>(pns.exponent, pns.mantissa); + if(pns.too_many_digits) { + if(am != compute_float>(pns.exponent, pns.mantissa + 1)) { + am.power2 = -1; // value is invalid. + } + } + // If we called compute_float>(pns.exponent, pns.mantissa) and we have an invalid power (am.power2 < 0), + // then we need to go the long way around again. This is very uncommon. + if(am.power2 < 0) { am = parse_long_mantissa>(first,last); } + detail::to_float(pns.negative, am, value); + return answer; +} + +} // namespace duckdb_fast_float + +#endif + diff --git a/src/duckdb/third_party/fastpforlib/bitpacking.cpp b/src/duckdb/third_party/fastpforlib/bitpacking.cpp new file mode 100644 index 00000000..91388868 --- /dev/null +++ b/src/duckdb/third_party/fastpforlib/bitpacking.cpp @@ -0,0 +1,1284 @@ +#include "bitpacking.h" + +#include +#include + +namespace duckdb_fastpforlib { +namespace internal { + +// Used for uint8_t, uint16_t and uint32_t +template +typename std::enable_if<(DELTA + SHR) < TYPE_SIZE>::type unpack_single_out(const TYPE *__restrict in, + TYPE *__restrict out) { + *out = ((*in) >> SHR) % (1 << DELTA); +} + +// Used for uint8_t, uint16_t and uint32_t +template +typename std::enable_if<(DELTA + SHR) >= TYPE_SIZE>::type unpack_single_out(const TYPE *__restrict &in, + TYPE *__restrict out) { + *out = (*in) >> SHR; + ++in; + + static const TYPE NEXT_SHR = SHR + DELTA - TYPE_SIZE; + *out |= ((*in) % (1U << NEXT_SHR)) << (TYPE_SIZE - SHR); +} + +template +typename std::enable_if<(DELTA + SHR) < 32>::type unpack_single_out(const uint32_t *__restrict in, + uint64_t *__restrict out) { + *out = ((static_cast(*in)) >> SHR) % (1ULL << DELTA); +} + +template +typename std::enable_if<(DELTA + SHR) >= 32 && (DELTA + SHR) < 64>::type +unpack_single_out(const uint32_t *__restrict &in, uint64_t *__restrict out) { + *out = static_cast(*in) >> SHR; + ++in; + if (DELTA + SHR > 32) { + static const uint8_t NEXT_SHR = SHR + DELTA - 32; + *out |= static_cast((*in) % (1U << NEXT_SHR)) << (32 - SHR); + } +} + +template +typename std::enable_if<(DELTA + SHR) >= 64>::type unpack_single_out(const uint32_t *__restrict &in, + uint64_t *__restrict out) { + *out = static_cast(*in) >> SHR; + ++in; + + *out |= static_cast(*in) << (32 - SHR); + ++in; + + if (DELTA + SHR > 64) { + static const uint8_t NEXT_SHR = DELTA + SHR - 64; + *out |= static_cast((*in) % (1U << NEXT_SHR)) << (64 - SHR); + } +} + +// Used for uint8_t, uint16_t and uint32_t +template + typename std::enable_if < DELTA + SHL::type pack_single_in(const TYPE in, TYPE *__restrict out) { + if (SHL == 0) { + *out = in & MASK; + } else { + *out |= (in & MASK) << SHL; + } +} + +// Used for uint8_t, uint16_t and uint32_t +template +typename std::enable_if= TYPE_SIZE>::type pack_single_in(const TYPE in, TYPE *__restrict &out) { + *out |= in << SHL; + ++out; + + if (DELTA + SHL > TYPE_SIZE) { + *out = (in & MASK) >> (TYPE_SIZE - SHL); + } +} + +template + typename std::enable_if < DELTA + SHL<32>::type pack_single_in64(const uint64_t in, uint32_t *__restrict out) { + if (SHL == 0) { + *out = static_cast(in & MASK); + } else { + *out |= (in & MASK) << SHL; + } +} +template + typename std::enable_if < DELTA + SHL >= 32 && + DELTA + SHL<64>::type pack_single_in64(const uint64_t in, uint32_t *__restrict &out) { + if (SHL == 0) { + *out = static_cast(in & MASK); + } else { + *out |= (in & MASK) << SHL; + } + + ++out; + + if (DELTA + SHL > 32) { + *out = static_cast((in & MASK) >> (32 - SHL)); + } +} +template +typename std::enable_if= 64>::type pack_single_in64(const uint64_t in, uint32_t *__restrict &out) { + *out |= in << SHL; + ++out; + + *out = static_cast((in & MASK) >> (32 - SHL)); + ++out; + + if (DELTA + SHL > 64) { + *out = (in & MASK) >> (64 - SHL); + } +} +template +struct Unroller8 { + static void Unpack(const uint8_t *__restrict &in, uint8_t *__restrict out) { + unpack_single_out(in, out + OINDEX); + + Unroller8::Unpack(in, out); + } + + static void Pack(const uint8_t *__restrict in, uint8_t *__restrict out) { + pack_single_in(in[OINDEX], out); + + Unroller8::Pack(in, out); + } + +};\ +template +struct Unroller8 { + enum { SHIFT = (DELTA * 7) % 8 }; + + static void Unpack(const uint8_t *__restrict in, uint8_t *__restrict out) { + out[7] = (*in) >> SHIFT; + } + + static void Pack(const uint8_t *__restrict in, uint8_t *__restrict out) { + *out |= (in[7] << SHIFT); + } +}; + +template +struct Unroller16 { + static void Unpack(const uint16_t *__restrict &in, uint16_t *__restrict out) { + unpack_single_out(in, out + OINDEX); + + Unroller16::Unpack(in, out); + } + + static void Pack(const uint16_t *__restrict in, uint16_t *__restrict out) { + pack_single_in(in[OINDEX], out); + + Unroller16::Pack(in, out); + } + +}; + +template +struct Unroller16 { + enum { SHIFT = (DELTA * 15) % 16 }; + + static void Unpack(const uint16_t *__restrict in, uint16_t *__restrict out) { + out[15] = (*in) >> SHIFT; + } + + static void Pack(const uint16_t *__restrict in, uint16_t *__restrict out) { + *out |= (in[15] << SHIFT); + } +}; + +template +struct Unroller { + static void Unpack(const uint32_t *__restrict &in, uint32_t *__restrict out) { + unpack_single_out(in, out + OINDEX); + + Unroller::Unpack(in, out); + } + + static void Unpack(const uint32_t *__restrict &in, uint64_t *__restrict out) { + unpack_single_out(in, out + OINDEX); + + Unroller::Unpack(in, out); + } + + static void Pack(const uint32_t *__restrict in, uint32_t *__restrict out) { + pack_single_in(in[OINDEX], out); + + Unroller::Pack(in, out); + } + + static void Pack(const uint64_t *__restrict in, uint32_t *__restrict out) { + pack_single_in64(in[OINDEX], out); + + Unroller::Pack(in, out); + } +}; + +template +struct Unroller { + enum { SHIFT = (DELTA * 31) % 32 }; + + static void Unpack(const uint32_t *__restrict in, uint32_t *__restrict out) { + out[31] = (*in) >> SHIFT; + } + + static void Unpack(const uint32_t *__restrict in, uint64_t *__restrict out) { + out[31] = (*in) >> SHIFT; + if (DELTA > 32) { + ++in; + out[31] |= static_cast(*in) << (32 - SHIFT); + } + } + + static void Pack(const uint32_t *__restrict in, uint32_t *__restrict out) { + *out |= (in[31] << SHIFT); + } + + static void Pack(const uint64_t *__restrict in, uint32_t *__restrict out) { + *out |= (in[31] << SHIFT); + if (DELTA > 32) { + ++out; + *out = static_cast(in[31] >> (32 - SHIFT)); + } + } +}; + +// Special cases +void __fastunpack0(const uint8_t *__restrict, uint8_t *__restrict out) { + for (uint8_t i = 0; i < 8; ++i) + *(out++) = 0; +} + +void __fastunpack0(const uint16_t *__restrict, uint16_t *__restrict out) { + for (uint16_t i = 0; i < 16; ++i) + *(out++) = 0; +} + +void __fastunpack0(const uint32_t *__restrict, uint32_t *__restrict out) { + for (uint32_t i = 0; i < 32; ++i) + *(out++) = 0; +} + +void __fastunpack0(const uint32_t *__restrict, uint64_t *__restrict out) { + for (uint32_t i = 0; i < 32; ++i) + *(out++) = 0; +} + +void __fastpack0(const uint8_t *__restrict, uint8_t *__restrict) { +} +void __fastpack0(const uint16_t *__restrict, uint16_t *__restrict) { +} +void __fastpack0(const uint32_t *__restrict, uint32_t *__restrict) { +} +void __fastpack0(const uint64_t *__restrict, uint32_t *__restrict) { +} + +// fastunpack for 8 bits +void __fastunpack1(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<1>::Unpack(in, out); +} + +void __fastunpack2(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<2>::Unpack(in, out); +} + +void __fastunpack3(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<3>::Unpack(in, out); +} + +void __fastunpack4(const uint8_t *__restrict in, uint8_t *__restrict out) { + for (uint8_t outer = 0; outer < 4; ++outer) { + for (uint8_t inwordpointer = 0; inwordpointer < 8; inwordpointer += 4) + *(out++) = ((*in) >> inwordpointer) % (1U << 4); + ++in; + } +} + +void __fastunpack5(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<5>::Unpack(in, out); +} + +void __fastunpack6(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<6>::Unpack(in, out); +} + +void __fastunpack7(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<7>::Unpack(in, out); +} + +void __fastunpack8(const uint8_t *__restrict in, uint8_t *__restrict out) { + for (int k = 0; k < 8; ++k) + out[k] = in[k]; +} + + +// fastunpack for 16 bits +void __fastunpack1(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<1>::Unpack(in, out); +} + +void __fastunpack2(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<2>::Unpack(in, out); +} + +void __fastunpack3(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<3>::Unpack(in, out); +} + +void __fastunpack4(const uint16_t *__restrict in, uint16_t *__restrict out) { + for (uint16_t outer = 0; outer < 4; ++outer) { + for (uint16_t inwordpointer = 0; inwordpointer < 16; inwordpointer += 4) + *(out++) = ((*in) >> inwordpointer) % (1U << 4); + ++in; + } +} + +void __fastunpack5(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<5>::Unpack(in, out); +} + +void __fastunpack6(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<6>::Unpack(in, out); +} + +void __fastunpack7(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<7>::Unpack(in, out); +} + +void __fastunpack8(const uint16_t *__restrict in, uint16_t *__restrict out) { + for (uint16_t outer = 0; outer < 8; ++outer) { + for (uint16_t inwordpointer = 0; inwordpointer < 16; inwordpointer += 8) + *(out++) = ((*in) >> inwordpointer) % (1U << 8); + ++in; + } +} + +void __fastunpack9(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<9>::Unpack(in, out); +} + +void __fastunpack10(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<10>::Unpack(in, out); +} + +void __fastunpack11(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<11>::Unpack(in, out); +} + +void __fastunpack12(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<12>::Unpack(in, out); +} + +void __fastunpack13(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<13>::Unpack(in, out); +} + +void __fastunpack14(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<14>::Unpack(in, out); +} + +void __fastunpack15(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<15>::Unpack(in, out); +} + +void __fastunpack16(const uint16_t *__restrict in, uint16_t *__restrict out) { + for (int k = 0; k < 16; ++k) + out[k] = in[k]; +} + +// fastunpack for 32 bits +void __fastunpack1(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<1>::Unpack(in, out); +} + +void __fastunpack2(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<2>::Unpack(in, out); +} + +void __fastunpack3(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<3>::Unpack(in, out); +} + +void __fastunpack4(const uint32_t *__restrict in, uint32_t *__restrict out) { + for (uint32_t outer = 0; outer < 4; ++outer) { + for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 4) + *(out++) = ((*in) >> inwordpointer) % (1U << 4); + ++in; + } +} + +void __fastunpack5(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<5>::Unpack(in, out); +} + +void __fastunpack6(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<6>::Unpack(in, out); +} + +void __fastunpack7(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<7>::Unpack(in, out); +} + +void __fastunpack8(const uint32_t *__restrict in, uint32_t *__restrict out) { + for (uint32_t outer = 0; outer < 8; ++outer) { + for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 8) + *(out++) = ((*in) >> inwordpointer) % (1U << 8); + ++in; + } +} + +void __fastunpack9(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<9>::Unpack(in, out); +} + +void __fastunpack10(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<10>::Unpack(in, out); +} + +void __fastunpack11(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<11>::Unpack(in, out); +} + +void __fastunpack12(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<12>::Unpack(in, out); +} + +void __fastunpack13(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<13>::Unpack(in, out); +} + +void __fastunpack14(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<14>::Unpack(in, out); +} + +void __fastunpack15(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<15>::Unpack(in, out); +} + +void __fastunpack16(const uint32_t *__restrict in, uint32_t *__restrict out) { + for (uint32_t outer = 0; outer < 16; ++outer) { + for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 16) + *(out++) = ((*in) >> inwordpointer) % (1U << 16); + ++in; + } +} + +void __fastunpack17(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<17>::Unpack(in, out); +} + +void __fastunpack18(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<18>::Unpack(in, out); +} + +void __fastunpack19(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<19>::Unpack(in, out); +} + +void __fastunpack20(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<20>::Unpack(in, out); +} + +void __fastunpack21(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<21>::Unpack(in, out); +} + +void __fastunpack22(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<22>::Unpack(in, out); +} + +void __fastunpack23(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<23>::Unpack(in, out); +} + +void __fastunpack24(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<24>::Unpack(in, out); +} + +void __fastunpack25(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<25>::Unpack(in, out); +} + +void __fastunpack26(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<26>::Unpack(in, out); +} + +void __fastunpack27(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<27>::Unpack(in, out); +} + +void __fastunpack28(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<28>::Unpack(in, out); +} + +void __fastunpack29(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<29>::Unpack(in, out); +} + +void __fastunpack30(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<30>::Unpack(in, out); +} + +void __fastunpack31(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<31>::Unpack(in, out); +} + +void __fastunpack32(const uint32_t *__restrict in, uint32_t *__restrict out) { + for (int k = 0; k < 32; ++k) + out[k] = in[k]; +} + +// fastupack for 64 bits +void __fastunpack1(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<1>::Unpack(in, out); +} + +void __fastunpack2(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<2>::Unpack(in, out); +} + +void __fastunpack3(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<3>::Unpack(in, out); +} + +void __fastunpack4(const uint32_t *__restrict in, uint64_t *__restrict out) { + for (uint32_t outer = 0; outer < 4; ++outer) { + for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 4) + *(out++) = ((*in) >> inwordpointer) % (1U << 4); + ++in; + } +} + +void __fastunpack5(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<5>::Unpack(in, out); +} + +void __fastunpack6(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<6>::Unpack(in, out); +} + +void __fastunpack7(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<7>::Unpack(in, out); +} + +void __fastunpack8(const uint32_t *__restrict in, uint64_t *__restrict out) { + for (uint32_t outer = 0; outer < 8; ++outer) { + for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 8) { + *(out++) = ((*in) >> inwordpointer) % (1U << 8); + } + ++in; + } +} + +void __fastunpack9(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<9>::Unpack(in, out); +} + +void __fastunpack10(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<10>::Unpack(in, out); +} + +void __fastunpack11(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<11>::Unpack(in, out); +} + +void __fastunpack12(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<12>::Unpack(in, out); +} + +void __fastunpack13(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<13>::Unpack(in, out); +} + +void __fastunpack14(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<14>::Unpack(in, out); +} + +void __fastunpack15(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<15>::Unpack(in, out); +} + +void __fastunpack16(const uint32_t *__restrict in, uint64_t *__restrict out) { + for (uint32_t outer = 0; outer < 16; ++outer) { + for (uint32_t inwordpointer = 0; inwordpointer < 32; inwordpointer += 16) + *(out++) = ((*in) >> inwordpointer) % (1U << 16); + ++in; + } +} + +void __fastunpack17(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<17>::Unpack(in, out); +} + +void __fastunpack18(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<18>::Unpack(in, out); +} + +void __fastunpack19(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<19>::Unpack(in, out); +} + +void __fastunpack20(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<20>::Unpack(in, out); +} + +void __fastunpack21(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<21>::Unpack(in, out); +} + +void __fastunpack22(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<22>::Unpack(in, out); +} + +void __fastunpack23(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<23>::Unpack(in, out); +} + +void __fastunpack24(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<24>::Unpack(in, out); +} + +void __fastunpack25(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<25>::Unpack(in, out); +} + +void __fastunpack26(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<26>::Unpack(in, out); +} + +void __fastunpack27(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<27>::Unpack(in, out); +} + +void __fastunpack28(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<28>::Unpack(in, out); +} + +void __fastunpack29(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<29>::Unpack(in, out); +} + +void __fastunpack30(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<30>::Unpack(in, out); +} + +void __fastunpack31(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<31>::Unpack(in, out); +} + +void __fastunpack32(const uint32_t *__restrict in, uint64_t *__restrict out) { + for (int k = 0; k < 32; ++k) + out[k] = in[k]; +} + +void __fastunpack33(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<33>::Unpack(in, out); +} + +void __fastunpack34(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<34>::Unpack(in, out); +} + +void __fastunpack35(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<35>::Unpack(in, out); +} + +void __fastunpack36(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<36>::Unpack(in, out); +} + +void __fastunpack37(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<37>::Unpack(in, out); +} + +void __fastunpack38(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<38>::Unpack(in, out); +} + +void __fastunpack39(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<39>::Unpack(in, out); +} + +void __fastunpack40(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<40>::Unpack(in, out); +} + +void __fastunpack41(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<41>::Unpack(in, out); +} + +void __fastunpack42(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<42>::Unpack(in, out); +} + +void __fastunpack43(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<43>::Unpack(in, out); +} + +void __fastunpack44(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<44>::Unpack(in, out); +} + +void __fastunpack45(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<45>::Unpack(in, out); +} + +void __fastunpack46(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<46>::Unpack(in, out); +} + +void __fastunpack47(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<47>::Unpack(in, out); +} + +void __fastunpack48(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<48>::Unpack(in, out); +} + +void __fastunpack49(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<49>::Unpack(in, out); +} + +void __fastunpack50(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<50>::Unpack(in, out); +} + +void __fastunpack51(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<51>::Unpack(in, out); +} + +void __fastunpack52(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<52>::Unpack(in, out); +} + +void __fastunpack53(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<53>::Unpack(in, out); +} + +void __fastunpack54(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<54>::Unpack(in, out); +} + +void __fastunpack55(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<55>::Unpack(in, out); +} + +void __fastunpack56(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<56>::Unpack(in, out); +} + +void __fastunpack57(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<57>::Unpack(in, out); +} + +void __fastunpack58(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<58>::Unpack(in, out); +} + +void __fastunpack59(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<59>::Unpack(in, out); +} + +void __fastunpack60(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<60>::Unpack(in, out); +} + +void __fastunpack61(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<61>::Unpack(in, out); +} + +void __fastunpack62(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<62>::Unpack(in, out); +} + +void __fastunpack63(const uint32_t *__restrict in, uint64_t *__restrict out) { + Unroller<63>::Unpack(in, out); +} + +void __fastunpack64(const uint32_t *__restrict in, uint64_t *__restrict out) { + for (int k = 0; k < 32; ++k) { + out[k] = in[k * 2]; + out[k] |= static_cast(in[k * 2 + 1]) << 32; + } +} + +// fastpack for 8 bits + +void __fastpack1(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<1>::Pack(in, out); +} + +void __fastpack2(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<2>::Pack(in, out); +} + +void __fastpack3(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<3>::Pack(in, out); +} + +void __fastpack4(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<4>::Pack(in, out); +} + +void __fastpack5(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<5>::Pack(in, out); +} + +void __fastpack6(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<6>::Pack(in, out); +} + +void __fastpack7(const uint8_t *__restrict in, uint8_t *__restrict out) { + Unroller8<7>::Pack(in, out); +} + +void __fastpack8(const uint8_t *__restrict in, uint8_t *__restrict out) { + for (int k = 0; k < 8; ++k) + out[k] = in[k]; +} + +// fastpack for 16 bits + +void __fastpack1(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<1>::Pack(in, out); +} + +void __fastpack2(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<2>::Pack(in, out); +} + +void __fastpack3(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<3>::Pack(in, out); +} + +void __fastpack4(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<4>::Pack(in, out); +} + +void __fastpack5(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<5>::Pack(in, out); +} + +void __fastpack6(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<6>::Pack(in, out); +} + +void __fastpack7(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<7>::Pack(in, out); +} + +void __fastpack8(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<8>::Pack(in, out); +} + +void __fastpack9(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<9>::Pack(in, out); +} + +void __fastpack10(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<10>::Pack(in, out); +} + +void __fastpack11(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<11>::Pack(in, out); +} + +void __fastpack12(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<12>::Pack(in, out); +} + +void __fastpack13(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<13>::Pack(in, out); +} + +void __fastpack14(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<14>::Pack(in, out); +} + +void __fastpack15(const uint16_t *__restrict in, uint16_t *__restrict out) { + Unroller16<15>::Pack(in, out); +} + +void __fastpack16(const uint16_t *__restrict in, uint16_t *__restrict out) { + for (int k = 0; k < 16; ++k) + out[k] = in[k]; +} + + +// fastpack for 32 bits + +void __fastpack1(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<1>::Pack(in, out); +} + +void __fastpack2(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<2>::Pack(in, out); +} + +void __fastpack3(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<3>::Pack(in, out); +} + +void __fastpack4(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<4>::Pack(in, out); +} + +void __fastpack5(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<5>::Pack(in, out); +} + +void __fastpack6(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<6>::Pack(in, out); +} + +void __fastpack7(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<7>::Pack(in, out); +} + +void __fastpack8(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<8>::Pack(in, out); +} + +void __fastpack9(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<9>::Pack(in, out); +} + +void __fastpack10(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<10>::Pack(in, out); +} + +void __fastpack11(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<11>::Pack(in, out); +} + +void __fastpack12(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<12>::Pack(in, out); +} + +void __fastpack13(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<13>::Pack(in, out); +} + +void __fastpack14(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<14>::Pack(in, out); +} + +void __fastpack15(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<15>::Pack(in, out); +} + +void __fastpack16(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<16>::Pack(in, out); +} + +void __fastpack17(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<17>::Pack(in, out); +} + +void __fastpack18(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<18>::Pack(in, out); +} + +void __fastpack19(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<19>::Pack(in, out); +} + +void __fastpack20(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<20>::Pack(in, out); +} + +void __fastpack21(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<21>::Pack(in, out); +} + +void __fastpack22(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<22>::Pack(in, out); +} + +void __fastpack23(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<23>::Pack(in, out); +} + +void __fastpack24(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<24>::Pack(in, out); +} + +void __fastpack25(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<25>::Pack(in, out); +} + +void __fastpack26(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<26>::Pack(in, out); +} + +void __fastpack27(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<27>::Pack(in, out); +} + +void __fastpack28(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<28>::Pack(in, out); +} + +void __fastpack29(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<29>::Pack(in, out); +} + +void __fastpack30(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<30>::Pack(in, out); +} + +void __fastpack31(const uint32_t *__restrict in, uint32_t *__restrict out) { + Unroller<31>::Pack(in, out); +} + +void __fastpack32(const uint32_t *__restrict in, uint32_t *__restrict out) { + for (int k = 0; k < 32; ++k) + out[k] = in[k]; +} + +// fastpack for 64 bits + +void __fastpack1(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<1>::Pack(in, out); +} + +void __fastpack2(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<2>::Pack(in, out); +} + +void __fastpack3(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<3>::Pack(in, out); +} + +void __fastpack4(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<4>::Pack(in, out); +} + +void __fastpack5(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<5>::Pack(in, out); +} + +void __fastpack6(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<6>::Pack(in, out); +} + +void __fastpack7(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<7>::Pack(in, out); +} + +void __fastpack8(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<8>::Pack(in, out); +} + +void __fastpack9(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<9>::Pack(in, out); +} + +void __fastpack10(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<10>::Pack(in, out); +} + +void __fastpack11(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<11>::Pack(in, out); +} + +void __fastpack12(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<12>::Pack(in, out); +} + +void __fastpack13(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<13>::Pack(in, out); +} + +void __fastpack14(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<14>::Pack(in, out); +} + +void __fastpack15(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<15>::Pack(in, out); +} + +void __fastpack16(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<16>::Pack(in, out); +} + +void __fastpack17(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<17>::Pack(in, out); +} + +void __fastpack18(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<18>::Pack(in, out); +} + +void __fastpack19(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<19>::Pack(in, out); +} + +void __fastpack20(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<20>::Pack(in, out); +} + +void __fastpack21(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<21>::Pack(in, out); +} + +void __fastpack22(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<22>::Pack(in, out); +} + +void __fastpack23(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<23>::Pack(in, out); +} + +void __fastpack24(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<24>::Pack(in, out); +} + +void __fastpack25(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<25>::Pack(in, out); +} + +void __fastpack26(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<26>::Pack(in, out); +} + +void __fastpack27(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<27>::Pack(in, out); +} + +void __fastpack28(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<28>::Pack(in, out); +} + +void __fastpack29(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<29>::Pack(in, out); +} + +void __fastpack30(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<30>::Pack(in, out); +} + +void __fastpack31(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<31>::Pack(in, out); +} + +void __fastpack32(const uint64_t *__restrict in, uint32_t *__restrict out) { + for (int k = 0; k < 32; ++k) { + out[k] = static_cast(in[k]); + } +} + +void __fastpack33(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<33>::Pack(in, out); +} + +void __fastpack34(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<34>::Pack(in, out); +} + +void __fastpack35(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<35>::Pack(in, out); +} + +void __fastpack36(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<36>::Pack(in, out); +} + +void __fastpack37(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<37>::Pack(in, out); +} + +void __fastpack38(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<38>::Pack(in, out); +} + +void __fastpack39(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<39>::Pack(in, out); +} + +void __fastpack40(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<40>::Pack(in, out); +} + +void __fastpack41(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<41>::Pack(in, out); +} + +void __fastpack42(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<42>::Pack(in, out); +} + +void __fastpack43(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<43>::Pack(in, out); +} + +void __fastpack44(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<44>::Pack(in, out); +} + +void __fastpack45(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<45>::Pack(in, out); +} + +void __fastpack46(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<46>::Pack(in, out); +} + +void __fastpack47(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<47>::Pack(in, out); +} + +void __fastpack48(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<48>::Pack(in, out); +} + +void __fastpack49(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<49>::Pack(in, out); +} + +void __fastpack50(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<50>::Pack(in, out); +} + +void __fastpack51(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<51>::Pack(in, out); +} + +void __fastpack52(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<52>::Pack(in, out); +} + +void __fastpack53(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<53>::Pack(in, out); +} + +void __fastpack54(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<54>::Pack(in, out); +} + +void __fastpack55(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<55>::Pack(in, out); +} + +void __fastpack56(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<56>::Pack(in, out); +} + +void __fastpack57(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<57>::Pack(in, out); +} + +void __fastpack58(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<58>::Pack(in, out); +} + +void __fastpack59(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<59>::Pack(in, out); +} + +void __fastpack60(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<60>::Pack(in, out); +} + +void __fastpack61(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<61>::Pack(in, out); +} + +void __fastpack62(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<62>::Pack(in, out); +} + +void __fastpack63(const uint64_t *__restrict in, uint32_t *__restrict out) { + Unroller<63>::Pack(in, out); +} + +void __fastpack64(const uint64_t *__restrict in, uint32_t *__restrict out) { + for (int i = 0; i < 32; ++i) { + out[2 * i] = static_cast(in[i]); + out[2 * i + 1] = in[i] >> 32; + } +} +} // namespace internal +} // namespace duckdb_fastpforlib diff --git a/src/duckdb/third_party/fastpforlib/bitpacking.h b/src/duckdb/third_party/fastpforlib/bitpacking.h new file mode 100644 index 00000000..ffe700e3 --- /dev/null +++ b/src/duckdb/third_party/fastpforlib/bitpacking.h @@ -0,0 +1,278 @@ +/** + * This code is released under the + * Apache License Version 2.0 http://www.apache.org/licenses/. + * + * (c) Daniel Lemire, http://fastpforlib.me/en/ + */ +#pragma once +#include +#include + +namespace duckdb_fastpforlib { +namespace internal { + +// Unpacks 8 uint8_t values +void __fastunpack0(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack1(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack2(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack3(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack4(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack5(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack6(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack7(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastunpack8(const uint8_t *__restrict in, uint8_t *__restrict out); + +// Unpacks 16 uint16_t values +void __fastunpack0(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack1(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack2(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack3(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack4(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack5(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack6(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack7(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack8(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack9(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack10(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack11(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack12(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack13(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack14(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack15(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastunpack16(const uint16_t *__restrict in, uint16_t *__restrict out); + +// Unpacks 32 uint32_t values +void __fastunpack0(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack1(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack2(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack3(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack4(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack5(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack6(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack7(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack8(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack9(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack10(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack11(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack12(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack13(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack14(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack15(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack16(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack17(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack18(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack19(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack20(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack21(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack22(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack23(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack24(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack25(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack26(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack27(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack28(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack29(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack30(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack31(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastunpack32(const uint32_t *__restrict in, uint32_t *__restrict out); + +// Unpacks 32 uint64_t values +void __fastunpack0(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack1(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack2(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack3(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack4(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack5(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack6(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack7(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack8(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack9(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack10(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack11(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack12(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack13(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack14(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack15(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack16(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack17(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack18(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack19(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack20(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack21(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack22(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack23(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack24(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack25(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack26(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack27(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack28(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack29(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack30(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack31(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack32(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack33(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack34(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack35(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack36(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack37(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack38(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack39(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack40(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack41(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack42(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack43(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack44(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack45(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack46(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack47(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack48(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack49(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack50(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack51(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack52(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack53(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack54(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack55(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack56(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack57(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack58(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack59(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack60(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack61(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack62(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack63(const uint32_t *__restrict in, uint64_t *__restrict out); +void __fastunpack64(const uint32_t *__restrict in, uint64_t *__restrict out); + +// Packs 8 int8_t values +void __fastpack0(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack1(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack2(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack3(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack4(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack5(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack6(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack7(const uint8_t *__restrict in, uint8_t *__restrict out); +void __fastpack8(const uint8_t *__restrict in, uint8_t *__restrict out); + +// Packs 16 int16_t values +void __fastpack0(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack1(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack2(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack3(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack4(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack5(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack6(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack7(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack8(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack9(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack10(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack11(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack12(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack13(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack14(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack15(const uint16_t *__restrict in, uint16_t *__restrict out); +void __fastpack16(const uint16_t *__restrict in, uint16_t *__restrict out); + +// Packs 32 int32_t values +void __fastpack0(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack1(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack2(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack3(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack4(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack5(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack6(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack7(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack8(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack9(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack10(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack11(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack12(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack13(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack14(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack15(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack16(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack17(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack18(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack19(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack20(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack21(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack22(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack23(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack24(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack25(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack26(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack27(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack28(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack29(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack30(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack31(const uint32_t *__restrict in, uint32_t *__restrict out); +void __fastpack32(const uint32_t *__restrict in, uint32_t *__restrict out); + +// Packs 32 int64_t values +void __fastpack0(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack1(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack2(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack3(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack4(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack5(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack6(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack7(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack8(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack9(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack10(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack11(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack12(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack13(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack14(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack15(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack16(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack17(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack18(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack19(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack20(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack21(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack22(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack23(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack24(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack25(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack26(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack27(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack28(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack29(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack30(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack31(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack32(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack33(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack34(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack35(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack36(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack37(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack38(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack39(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack40(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack41(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack42(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack43(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack44(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack45(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack46(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack47(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack48(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack49(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack50(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack51(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack52(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack53(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack54(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack55(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack56(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack57(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack58(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack59(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack60(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack61(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack62(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack63(const uint64_t *__restrict in, uint32_t *__restrict out); +void __fastpack64(const uint64_t *__restrict in, uint32_t *__restrict out); +} // namespace internal +} // namespace duckdb_fastpforlib diff --git a/src/duckdb/third_party/fastpforlib/bitpackinghelpers.h b/src/duckdb/third_party/fastpforlib/bitpackinghelpers.h new file mode 100644 index 00000000..3d3a0925 --- /dev/null +++ b/src/duckdb/third_party/fastpforlib/bitpackinghelpers.h @@ -0,0 +1,869 @@ +/** +* This code is released under the +* Apache License Version 2.0 http://www.apache.org/licenses/. +* +* (c) Daniel Lemire, http://lemire.me/en/ +*/ +#pragma once +#include "bitpacking.h" + + +#include + +namespace duckdb_fastpforlib { + +namespace internal { + +// Note that this only packs 8 values +inline void fastunpack_quarter(const uint8_t *__restrict in, uint8_t *__restrict out, const uint32_t bit) { + // Could have used function pointers instead of switch. + // Switch calls do offer the compiler more opportunities for optimization in + // theory. In this case, it makes no difference with a good compiler. + switch (bit) { + case 0: + internal::__fastunpack0(in, out); + break; + case 1: + internal::__fastunpack1(in, out); + break; + case 2: + internal::__fastunpack2(in, out); + break; + case 3: + internal::__fastunpack3(in, out); + break; + case 4: + internal::__fastunpack4(in, out); + break; + case 5: + internal::__fastunpack5(in, out); + break; + case 6: + internal::__fastunpack6(in, out); + break; + case 7: + internal::__fastunpack7(in, out); + break; + case 8: + internal::__fastunpack8(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} + +// Note that this only packs 8 values +inline void fastpack_quarter(const uint8_t *__restrict in, uint8_t *__restrict out, const uint32_t bit) { + // Could have used function pointers instead of switch. + // Switch calls do offer the compiler more opportunities for optimization in + // theory. In this case, it makes no difference with a good compiler. + switch (bit) { + case 0: + internal::__fastpack0(in, out); + break; + case 1: + internal::__fastpack1(in, out); + break; + case 2: + internal::__fastpack2(in, out); + break; + case 3: + internal::__fastpack3(in, out); + break; + case 4: + internal::__fastpack4(in, out); + break; + case 5: + internal::__fastpack5(in, out); + break; + case 6: + internal::__fastpack6(in, out); + break; + case 7: + internal::__fastpack7(in, out); + break; + case 8: + internal::__fastpack8(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} + +// Note that this only packs 16 values +inline void fastunpack_half(const uint16_t *__restrict in, uint16_t *__restrict out, const uint32_t bit) { + // Could have used function pointers instead of switch. + // Switch calls do offer the compiler more opportunities for optimization in + // theory. In this case, it makes no difference with a good compiler. + switch (bit) { + case 0: + internal::__fastunpack0(in, out); + break; + case 1: + internal::__fastunpack1(in, out); + break; + case 2: + internal::__fastunpack2(in, out); + break; + case 3: + internal::__fastunpack3(in, out); + break; + case 4: + internal::__fastunpack4(in, out); + break; + case 5: + internal::__fastunpack5(in, out); + break; + case 6: + internal::__fastunpack6(in, out); + break; + case 7: + internal::__fastunpack7(in, out); + break; + case 8: + internal::__fastunpack8(in, out); + break; + case 9: + internal::__fastunpack9(in, out); + break; + case 10: + internal::__fastunpack10(in, out); + break; + case 11: + internal::__fastunpack11(in, out); + break; + case 12: + internal::__fastunpack12(in, out); + break; + case 13: + internal::__fastunpack13(in, out); + break; + case 14: + internal::__fastunpack14(in, out); + break; + case 15: + internal::__fastunpack15(in, out); + break; + case 16: + internal::__fastunpack16(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} + +// Note that this only packs 16 values +inline void fastpack_half(const uint16_t *__restrict in, uint16_t *__restrict out, const uint32_t bit) { + // Could have used function pointers instead of switch. + // Switch calls do offer the compiler more opportunities for optimization in + // theory. In this case, it makes no difference with a good compiler. + switch (bit) { + case 0: + internal::__fastpack0(in, out); + break; + case 1: + internal::__fastpack1(in, out); + break; + case 2: + internal::__fastpack2(in, out); + break; + case 3: + internal::__fastpack3(in, out); + break; + case 4: + internal::__fastpack4(in, out); + break; + case 5: + internal::__fastpack5(in, out); + break; + case 6: + internal::__fastpack6(in, out); + break; + case 7: + internal::__fastpack7(in, out); + break; + case 8: + internal::__fastpack8(in, out); + break; + case 9: + internal::__fastpack9(in, out); + break; + case 10: + internal::__fastpack10(in, out); + break; + case 11: + internal::__fastpack11(in, out); + break; + case 12: + internal::__fastpack12(in, out); + break; + case 13: + internal::__fastpack13(in, out); + break; + case 14: + internal::__fastpack14(in, out); + break; + case 15: + internal::__fastpack15(in, out); + break; + case 16: + internal::__fastpack16(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} +} + +inline void fastunpack(const uint8_t *__restrict in, uint8_t *__restrict out, const uint32_t bit) { + for (uint8_t i = 0; i < 4; i++) { + internal::fastunpack_quarter(in + (i*bit), out+(i*8), bit); + } +} + +inline void fastunpack(const uint16_t *__restrict in, uint16_t *__restrict out, const uint32_t bit) { + internal::fastunpack_half(in, out, bit); + internal::fastunpack_half(in + bit, out+16, bit); +} + +inline void fastunpack(const uint32_t *__restrict in, + uint32_t *__restrict out, const uint32_t bit) { + // Could have used function pointers instead of switch. + // Switch calls do offer the compiler more opportunities for optimization in + // theory. In this case, it makes no difference with a good compiler. + switch (bit) { + case 0: + internal::__fastunpack0(in, out); + break; + case 1: + internal::__fastunpack1(in, out); + break; + case 2: + internal::__fastunpack2(in, out); + break; + case 3: + internal::__fastunpack3(in, out); + break; + case 4: + internal::__fastunpack4(in, out); + break; + case 5: + internal::__fastunpack5(in, out); + break; + case 6: + internal::__fastunpack6(in, out); + break; + case 7: + internal::__fastunpack7(in, out); + break; + case 8: + internal::__fastunpack8(in, out); + break; + case 9: + internal::__fastunpack9(in, out); + break; + case 10: + internal::__fastunpack10(in, out); + break; + case 11: + internal::__fastunpack11(in, out); + break; + case 12: + internal::__fastunpack12(in, out); + break; + case 13: + internal::__fastunpack13(in, out); + break; + case 14: + internal::__fastunpack14(in, out); + break; + case 15: + internal::__fastunpack15(in, out); + break; + case 16: + internal::__fastunpack16(in, out); + break; + case 17: + internal::__fastunpack17(in, out); + break; + case 18: + internal::__fastunpack18(in, out); + break; + case 19: + internal::__fastunpack19(in, out); + break; + case 20: + internal::__fastunpack20(in, out); + break; + case 21: + internal::__fastunpack21(in, out); + break; + case 22: + internal::__fastunpack22(in, out); + break; + case 23: + internal::__fastunpack23(in, out); + break; + case 24: + internal::__fastunpack24(in, out); + break; + case 25: + internal::__fastunpack25(in, out); + break; + case 26: + internal::__fastunpack26(in, out); + break; + case 27: + internal::__fastunpack27(in, out); + break; + case 28: + internal::__fastunpack28(in, out); + break; + case 29: + internal::__fastunpack29(in, out); + break; + case 30: + internal::__fastunpack30(in, out); + break; + case 31: + internal::__fastunpack31(in, out); + break; + case 32: + internal::__fastunpack32(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} + +inline void fastunpack(const uint32_t *__restrict in, + uint64_t *__restrict out, const uint32_t bit) { + // Could have used function pointers instead of switch. + // Switch calls do offer the compiler more opportunities for optimization in + // theory. In this case, it makes no difference with a good compiler. + switch (bit) { + case 0: + internal::__fastunpack0(in, out); + break; + case 1: + internal::__fastunpack1(in, out); + break; + case 2: + internal::__fastunpack2(in, out); + break; + case 3: + internal::__fastunpack3(in, out); + break; + case 4: + internal::__fastunpack4(in, out); + break; + case 5: + internal::__fastunpack5(in, out); + break; + case 6: + internal::__fastunpack6(in, out); + break; + case 7: + internal::__fastunpack7(in, out); + break; + case 8: + internal::__fastunpack8(in, out); + break; + case 9: + internal::__fastunpack9(in, out); + break; + case 10: + internal::__fastunpack10(in, out); + break; + case 11: + internal::__fastunpack11(in, out); + break; + case 12: + internal::__fastunpack12(in, out); + break; + case 13: + internal::__fastunpack13(in, out); + break; + case 14: + internal::__fastunpack14(in, out); + break; + case 15: + internal::__fastunpack15(in, out); + break; + case 16: + internal::__fastunpack16(in, out); + break; + case 17: + internal::__fastunpack17(in, out); + break; + case 18: + internal::__fastunpack18(in, out); + break; + case 19: + internal::__fastunpack19(in, out); + break; + case 20: + internal::__fastunpack20(in, out); + break; + case 21: + internal::__fastunpack21(in, out); + break; + case 22: + internal::__fastunpack22(in, out); + break; + case 23: + internal::__fastunpack23(in, out); + break; + case 24: + internal::__fastunpack24(in, out); + break; + case 25: + internal::__fastunpack25(in, out); + break; + case 26: + internal::__fastunpack26(in, out); + break; + case 27: + internal::__fastunpack27(in, out); + break; + case 28: + internal::__fastunpack28(in, out); + break; + case 29: + internal::__fastunpack29(in, out); + break; + case 30: + internal::__fastunpack30(in, out); + break; + case 31: + internal::__fastunpack31(in, out); + break; + case 32: + internal::__fastunpack32(in, out); + break; + case 33: + internal::__fastunpack33(in, out); + break; + case 34: + internal::__fastunpack34(in, out); + break; + case 35: + internal::__fastunpack35(in, out); + break; + case 36: + internal::__fastunpack36(in, out); + break; + case 37: + internal::__fastunpack37(in, out); + break; + case 38: + internal::__fastunpack38(in, out); + break; + case 39: + internal::__fastunpack39(in, out); + break; + case 40: + internal::__fastunpack40(in, out); + break; + case 41: + internal::__fastunpack41(in, out); + break; + case 42: + internal::__fastunpack42(in, out); + break; + case 43: + internal::__fastunpack43(in, out); + break; + case 44: + internal::__fastunpack44(in, out); + break; + case 45: + internal::__fastunpack45(in, out); + break; + case 46: + internal::__fastunpack46(in, out); + break; + case 47: + internal::__fastunpack47(in, out); + break; + case 48: + internal::__fastunpack48(in, out); + break; + case 49: + internal::__fastunpack49(in, out); + break; + case 50: + internal::__fastunpack50(in, out); + break; + case 51: + internal::__fastunpack51(in, out); + break; + case 52: + internal::__fastunpack52(in, out); + break; + case 53: + internal::__fastunpack53(in, out); + break; + case 54: + internal::__fastunpack54(in, out); + break; + case 55: + internal::__fastunpack55(in, out); + break; + case 56: + internal::__fastunpack56(in, out); + break; + case 57: + internal::__fastunpack57(in, out); + break; + case 58: + internal::__fastunpack58(in, out); + break; + case 59: + internal::__fastunpack59(in, out); + break; + case 60: + internal::__fastunpack60(in, out); + break; + case 61: + internal::__fastunpack61(in, out); + break; + case 62: + internal::__fastunpack62(in, out); + break; + case 63: + internal::__fastunpack63(in, out); + break; + case 64: + internal::__fastunpack64(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} + +inline void fastpack(const uint8_t *__restrict in, uint8_t *__restrict out, const uint32_t bit) { + + for (uint8_t i = 0; i < 4; i++) { + internal::fastpack_quarter(in+(i*8), out + (i*bit), bit); + } +} + +inline void fastpack(const uint16_t *__restrict in, uint16_t *__restrict out, const uint32_t bit) { + internal::fastpack_half(in, out, bit); + internal::fastpack_half(in+16, out + bit, bit); +} + +inline void fastpack(const uint32_t *__restrict in, + uint32_t *__restrict out, const uint32_t bit) { + // Could have used function pointers instead of switch. + // Switch calls do offer the compiler more opportunities for optimization in + // theory. In this case, it makes no difference with a good compiler. + switch (bit) { + case 0: + internal::__fastpack0(in, out); + break; + case 1: + internal::__fastpack1(in, out); + break; + case 2: + internal::__fastpack2(in, out); + break; + case 3: + internal::__fastpack3(in, out); + break; + case 4: + internal::__fastpack4(in, out); + break; + case 5: + internal::__fastpack5(in, out); + break; + case 6: + internal::__fastpack6(in, out); + break; + case 7: + internal::__fastpack7(in, out); + break; + case 8: + internal::__fastpack8(in, out); + break; + case 9: + internal::__fastpack9(in, out); + break; + case 10: + internal::__fastpack10(in, out); + break; + case 11: + internal::__fastpack11(in, out); + break; + case 12: + internal::__fastpack12(in, out); + break; + case 13: + internal::__fastpack13(in, out); + break; + case 14: + internal::__fastpack14(in, out); + break; + case 15: + internal::__fastpack15(in, out); + break; + case 16: + internal::__fastpack16(in, out); + break; + case 17: + internal::__fastpack17(in, out); + break; + case 18: + internal::__fastpack18(in, out); + break; + case 19: + internal::__fastpack19(in, out); + break; + case 20: + internal::__fastpack20(in, out); + break; + case 21: + internal::__fastpack21(in, out); + break; + case 22: + internal::__fastpack22(in, out); + break; + case 23: + internal::__fastpack23(in, out); + break; + case 24: + internal::__fastpack24(in, out); + break; + case 25: + internal::__fastpack25(in, out); + break; + case 26: + internal::__fastpack26(in, out); + break; + case 27: + internal::__fastpack27(in, out); + break; + case 28: + internal::__fastpack28(in, out); + break; + case 29: + internal::__fastpack29(in, out); + break; + case 30: + internal::__fastpack30(in, out); + break; + case 31: + internal::__fastpack31(in, out); + break; + case 32: + internal::__fastpack32(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} + +inline void fastpack(const uint64_t *__restrict in, + uint32_t *__restrict out, const uint32_t bit) { + switch (bit) { + case 0: + internal::__fastpack0(in, out); + break; + case 1: + internal::__fastpack1(in, out); + break; + case 2: + internal::__fastpack2(in, out); + break; + case 3: + internal::__fastpack3(in, out); + break; + case 4: + internal::__fastpack4(in, out); + break; + case 5: + internal::__fastpack5(in, out); + break; + case 6: + internal::__fastpack6(in, out); + break; + case 7: + internal::__fastpack7(in, out); + break; + case 8: + internal::__fastpack8(in, out); + break; + case 9: + internal::__fastpack9(in, out); + break; + case 10: + internal::__fastpack10(in, out); + break; + case 11: + internal::__fastpack11(in, out); + break; + case 12: + internal::__fastpack12(in, out); + break; + case 13: + internal::__fastpack13(in, out); + break; + case 14: + internal::__fastpack14(in, out); + break; + case 15: + internal::__fastpack15(in, out); + break; + case 16: + internal::__fastpack16(in, out); + break; + case 17: + internal::__fastpack17(in, out); + break; + case 18: + internal::__fastpack18(in, out); + break; + case 19: + internal::__fastpack19(in, out); + break; + case 20: + internal::__fastpack20(in, out); + break; + case 21: + internal::__fastpack21(in, out); + break; + case 22: + internal::__fastpack22(in, out); + break; + case 23: + internal::__fastpack23(in, out); + break; + case 24: + internal::__fastpack24(in, out); + break; + case 25: + internal::__fastpack25(in, out); + break; + case 26: + internal::__fastpack26(in, out); + break; + case 27: + internal::__fastpack27(in, out); + break; + case 28: + internal::__fastpack28(in, out); + break; + case 29: + internal::__fastpack29(in, out); + break; + case 30: + internal::__fastpack30(in, out); + break; + case 31: + internal::__fastpack31(in, out); + break; + case 32: + internal::__fastpack32(in, out); + break; + case 33: + internal::__fastpack33(in, out); + break; + case 34: + internal::__fastpack34(in, out); + break; + case 35: + internal::__fastpack35(in, out); + break; + case 36: + internal::__fastpack36(in, out); + break; + case 37: + internal::__fastpack37(in, out); + break; + case 38: + internal::__fastpack38(in, out); + break; + case 39: + internal::__fastpack39(in, out); + break; + case 40: + internal::__fastpack40(in, out); + break; + case 41: + internal::__fastpack41(in, out); + break; + case 42: + internal::__fastpack42(in, out); + break; + case 43: + internal::__fastpack43(in, out); + break; + case 44: + internal::__fastpack44(in, out); + break; + case 45: + internal::__fastpack45(in, out); + break; + case 46: + internal::__fastpack46(in, out); + break; + case 47: + internal::__fastpack47(in, out); + break; + case 48: + internal::__fastpack48(in, out); + break; + case 49: + internal::__fastpack49(in, out); + break; + case 50: + internal::__fastpack50(in, out); + break; + case 51: + internal::__fastpack51(in, out); + break; + case 52: + internal::__fastpack52(in, out); + break; + case 53: + internal::__fastpack53(in, out); + break; + case 54: + internal::__fastpack54(in, out); + break; + case 55: + internal::__fastpack55(in, out); + break; + case 56: + internal::__fastpack56(in, out); + break; + case 57: + internal::__fastpack57(in, out); + break; + case 58: + internal::__fastpack58(in, out); + break; + case 59: + internal::__fastpack59(in, out); + break; + case 60: + internal::__fastpack60(in, out); + break; + case 61: + internal::__fastpack61(in, out); + break; + case 62: + internal::__fastpack62(in, out); + break; + case 63: + internal::__fastpack63(in, out); + break; + case 64: + internal::__fastpack64(in, out); + break; + default: + throw std::logic_error("Invalid bit width for bitpacking"); + } +} +} // namespace fastpfor_lib diff --git a/src/duckdb/third_party/fmt/format.cc b/src/duckdb/third_party/fmt/format.cc new file mode 100644 index 00000000..a8d34d5a --- /dev/null +++ b/src/duckdb/third_party/fmt/format.cc @@ -0,0 +1,171 @@ +// Formatting library for C++ +// +// Copyright (c) 2012 - 2016, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#include "fmt/format-inl.h" + +FMT_BEGIN_NAMESPACE +namespace internal { + +template +int format_float(char* buf, std::size_t size, const char* format, int precision, + T value) { +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + if (precision > 100000) + throw std::runtime_error( + "fuzz mode - avoid large allocation inside snprintf"); +#endif + // Suppress the warning about nonliteral format string. + auto snprintf_ptr = FMT_SNPRINTF; + return precision < 0 ? snprintf_ptr(buf, size, format, value) + : snprintf_ptr(buf, size, format, precision, value); +} +struct sprintf_specs { + int precision; + char type; + bool alt : 1; + + template + constexpr sprintf_specs(basic_format_specs specs) + : precision(specs.precision), type(specs.type), alt(specs.alt) {} + + constexpr bool has_precision() const { return precision >= 0; } +}; + +// This is deprecated and is kept only to preserve ABI compatibility. +template +char* sprintf_format(Double value, internal::buffer& buf, + sprintf_specs specs) { + // Buffer capacity must be non-zero, otherwise MSVC's vsnprintf_s will fail. + FMT_ASSERT(buf.capacity() != 0, "empty buffer"); + + // Build format string. + enum { max_format_size = 10 }; // longest format: %#-*.*Lg + char format[max_format_size]; + char* format_ptr = format; + *format_ptr++ = '%'; + if (specs.alt || !specs.type) *format_ptr++ = '#'; + if (specs.precision >= 0) { + *format_ptr++ = '.'; + *format_ptr++ = '*'; + } + if (std::is_same::value) *format_ptr++ = 'L'; + + char type = specs.type; + + if (type == '%') + type = 'f'; + else if (type == 0 || type == 'n') + type = 'g'; +#if FMT_MSC_VER + if (type == 'F') { + // MSVC's printf doesn't support 'F'. + type = 'f'; + } +#endif + *format_ptr++ = type; + *format_ptr = '\0'; + + // Format using snprintf. + char* start = nullptr; + char* decimal_point_pos = nullptr; + for (;;) { + std::size_t buffer_size = buf.capacity(); + start = &buf[0]; + int result = + format_float(start, buffer_size, format, specs.precision, value); + if (result >= 0) { + unsigned n = internal::to_unsigned(result); + if (n < buf.capacity()) { + // Find the decimal point. + auto p = buf.data(), end = p + n; + if (*p == '+' || *p == '-') ++p; + if (specs.type != 'a' && specs.type != 'A') { + while (p < end && *p >= '0' && *p <= '9') ++p; + if (p < end && *p != 'e' && *p != 'E') { + decimal_point_pos = p; + if (!specs.type) { + // Keep only one trailing zero after the decimal point. + ++p; + if (*p == '0') ++p; + while (p != end && *p >= '1' && *p <= '9') ++p; + char* where = p; + while (p != end && *p == '0') ++p; + if (p == end || *p < '0' || *p > '9') { + if (p != end) std::memmove(where, p, to_unsigned(end - p)); + n -= static_cast(p - where); + } + } + } + } + buf.resize(n); + break; // The buffer is large enough - continue with formatting. + } + buf.reserve(n + 1); + } else { + // If result is negative we ask to increase the capacity by at least 1, + // but as std::vector, the buffer grows exponentially. + buf.reserve(buf.capacity() + 1); + } + } + return decimal_point_pos; +} +} // namespace internal + +template FMT_API char* internal::sprintf_format(double, internal::buffer&, + sprintf_specs); +template FMT_API char* internal::sprintf_format(long double, + internal::buffer&, + sprintf_specs); + +template struct FMT_API internal::basic_data; + +// Workaround a bug in MSVC2013 that prevents instantiation of format_float. +int (*instantiate_format_float)(double, int, internal::float_specs, + internal::buffer&) = + internal::format_float; + +// Explicit instantiations for char. + +template FMT_API std::string internal::grouping_impl(locale_ref); +template FMT_API char internal::thousands_sep_impl(locale_ref); +template FMT_API char internal::decimal_point_impl(locale_ref); + +template FMT_API void internal::buffer::append(const char*, const char*); + +template FMT_API void internal::arg_map::init( + const basic_format_args& args); + +template FMT_API std::string internal::vformat( + string_view, basic_format_args); + +template FMT_API format_context::iterator internal::vformat_to( + internal::buffer&, string_view, basic_format_args); + +template FMT_API int internal::snprintf_float(double, int, + internal::float_specs, + internal::buffer&); +template FMT_API int internal::snprintf_float(long double, int, + internal::float_specs, + internal::buffer&); +template FMT_API int internal::format_float(double, int, internal::float_specs, + internal::buffer&); +template FMT_API int internal::format_float(long double, int, + internal::float_specs, + internal::buffer&); + +// Explicit instantiations for wchar_t. + +template FMT_API std::string internal::grouping_impl(locale_ref); +template FMT_API wchar_t internal::thousands_sep_impl(locale_ref); +template FMT_API wchar_t internal::decimal_point_impl(locale_ref); + +template FMT_API void internal::buffer::append(const wchar_t*, + const wchar_t*); + +template FMT_API std::wstring internal::vformat( + wstring_view, basic_format_args); +FMT_END_NAMESPACE diff --git a/src/duckdb/third_party/fmt/include/fmt/core.h b/src/duckdb/third_party/fmt/include/fmt/core.h new file mode 100644 index 00000000..70d899a1 --- /dev/null +++ b/src/duckdb/third_party/fmt/include/fmt/core.h @@ -0,0 +1,1473 @@ +// Formatting library for C++ - the core API +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_CORE_H_ +#define FMT_CORE_H_ + +#include // std::FILE +#include +#include +#include +#include + +// The fmt library version in the form major * 10000 + minor * 100 + patch. +#define FMT_VERSION 60102 + +#ifdef __has_feature +# define FMT_HAS_FEATURE(x) __has_feature(x) +#else +# define FMT_HAS_FEATURE(x) 0 +#endif + +#if defined(__has_include) && !defined(__INTELLISENSE__) && \ + !(defined(__INTEL_COMPILER) && __INTEL_COMPILER < 1600) +# define FMT_HAS_INCLUDE(x) __has_include(x) +#else +# define FMT_HAS_INCLUDE(x) 0 +#endif + +#ifdef __has_cpp_attribute +# define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +# define FMT_HAS_CPP_ATTRIBUTE(x) 0 +#endif + +#if defined(__GNUC__) && !defined(__clang__) +# define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +#else +# define FMT_GCC_VERSION 0 +#endif + +#if __cplusplus >= 201103L || defined(__GXX_EXPERIMENTAL_CXX0X__) +# define FMT_HAS_GXX_CXX11 FMT_GCC_VERSION +#else +# define FMT_HAS_GXX_CXX11 0 +#endif + +#ifdef __NVCC__ +# define FMT_NVCC __NVCC__ +#else +# define FMT_NVCC 0 +#endif + +#ifdef _MSC_VER +# define FMT_MSC_VER _MSC_VER +#else +# define FMT_MSC_VER 0 +#endif + +// Check if relaxed C++14 constexpr is supported. +// GCC doesn't allow throw in constexpr until version 6 (bug 67371). +#if FMT_USE_CONSTEXPR +# define FMT_CONSTEXPR inline +# define FMT_CONSTEXPR_DECL +#else +# define FMT_CONSTEXPR inline +# define FMT_CONSTEXPR_DECL +#endif + +#ifndef FMT_OVERRIDE +# if FMT_HAS_FEATURE(cxx_override) || \ + (FMT_GCC_VERSION >= 408 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1900 +# define FMT_OVERRIDE override +# else +# define FMT_OVERRIDE +# endif +#endif + +// Check if exceptions are disabled. +#ifndef FMT_EXCEPTIONS +# if (defined(__GNUC__) && !defined(__EXCEPTIONS)) || \ + FMT_MSC_VER && !_HAS_EXCEPTIONS +# define FMT_EXCEPTIONS 0 +# else +# define FMT_EXCEPTIONS 1 +# endif +#endif + +// Define FMT_USE_NOEXCEPT to make fmt use noexcept (C++11 feature). +#ifndef FMT_USE_NOEXCEPT +# define FMT_USE_NOEXCEPT 0 +#endif + +#if FMT_USE_NOEXCEPT || FMT_HAS_FEATURE(cxx_noexcept) || \ + (FMT_GCC_VERSION >= 408 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1900 +# define FMT_DETECTED_NOEXCEPT noexcept +# define FMT_HAS_CXX11_NOEXCEPT 1 +#else +# define FMT_DETECTED_NOEXCEPT throw() +# define FMT_HAS_CXX11_NOEXCEPT 0 +#endif + +#ifndef FMT_NOEXCEPT +# if FMT_EXCEPTIONS || FMT_HAS_CXX11_NOEXCEPT +# define FMT_NOEXCEPT FMT_DETECTED_NOEXCEPT +# else +# define FMT_NOEXCEPT +# endif +#endif + +// [[noreturn]] is disabled on MSVC because of bogus unreachable code warnings. +#if FMT_EXCEPTIONS && FMT_HAS_CPP_ATTRIBUTE(noreturn) && !FMT_MSC_VER +# define FMT_NORETURN [[noreturn]] +#else +# define FMT_NORETURN +#endif + +#ifndef FMT_DEPRECATED +# if (FMT_HAS_CPP_ATTRIBUTE(deprecated) && __cplusplus >= 201402L) || \ + FMT_MSC_VER >= 1900 +# define FMT_DEPRECATED [[deprecated]] +# else +# if defined(__GNUC__) || defined(__clang__) +# define FMT_DEPRECATED __attribute__((deprecated)) +# elif FMT_MSC_VER +# define FMT_DEPRECATED __declspec(deprecated) +# else +# define FMT_DEPRECATED /* deprecated */ +# endif +# endif +#endif + +// Workaround broken [[deprecated]] in the Intel compiler and NVCC. +#if defined(__INTEL_COMPILER) || FMT_NVCC +# define FMT_DEPRECATED_ALIAS +#else +# define FMT_DEPRECATED_ALIAS FMT_DEPRECATED +#endif + +#ifndef FMT_BEGIN_NAMESPACE +# if FMT_HAS_FEATURE(cxx_inline_namespaces) || FMT_GCC_VERSION >= 404 || \ + FMT_MSC_VER >= 1900 +# define FMT_INLINE_NAMESPACE inline namespace +# define FMT_END_NAMESPACE \ + } \ + } +# else +# define FMT_INLINE_NAMESPACE namespace +# define FMT_END_NAMESPACE \ + } \ + using namespace v6; \ + } +# endif +# define FMT_BEGIN_NAMESPACE \ + namespace duckdb_fmt { \ + FMT_INLINE_NAMESPACE v6 { +#endif + +#if !defined(FMT_HEADER_ONLY) && defined(_WIN32) +# ifdef FMT_EXPORT +# define FMT_API __declspec(dllexport) +# elif defined(FMT_SHARED) +# define FMT_API __declspec(dllimport) +# define FMT_EXTERN_TEMPLATE_API FMT_API +# endif +#endif +#ifndef FMT_API +# define FMT_API +#endif +#ifndef FMT_EXTERN_TEMPLATE_API +# define FMT_EXTERN_TEMPLATE_API +#endif + +#ifndef FMT_HEADER_ONLY +# define FMT_EXTERN extern +#else +# define FMT_EXTERN +#endif + +// libc++ supports string_view in pre-c++17. +#if (FMT_HAS_INCLUDE() && \ + (__cplusplus > 201402L || defined(_LIBCPP_VERSION))) || \ + (defined(_MSVC_LANG) && _MSVC_LANG > 201402L && _MSC_VER >= 1910) +# include +# define FMT_USE_STRING_VIEW +#elif FMT_HAS_INCLUDE("experimental/string_view") && __cplusplus >= 201402L +# include +# define FMT_USE_EXPERIMENTAL_STRING_VIEW +#endif + +FMT_BEGIN_NAMESPACE + +// Implementations of enable_if_t and other types for pre-C++14 systems. +template +using enable_if_t = typename std::enable_if::type; +template +using conditional_t = typename std::conditional::type; +template using bool_constant = std::integral_constant; +template +using remove_reference_t = typename std::remove_reference::type; +template +using remove_const_t = typename std::remove_const::type; +template +using remove_cvref_t = typename std::remove_cv>::type; + +struct monostate {}; + +// An enable_if helper to be used in template parameters which results in much +// shorter symbols: https://godbolt.org/z/sWw4vP. Extra parentheses are needed +// to workaround a bug in MSVC 2019 (see #1140 and #1186). +#define FMT_ENABLE_IF(...) enable_if_t<(__VA_ARGS__), int> = 0 + +namespace internal { + +// A workaround for gcc 4.8 to make void_t work in a SFINAE context. +template struct void_t_impl { using type = void; }; + +#ifndef FMT_ASSERT +#define FMT_ASSERT(condition, message) +#endif + +#if defined(FMT_USE_STRING_VIEW) +template using std_string_view = std::basic_string_view; +#elif defined(FMT_USE_EXPERIMENTAL_STRING_VIEW) +template +using std_string_view = std::experimental::basic_string_view; +#else +template struct std_string_view {}; +#endif + +#ifdef FMT_USE_INT128 +// Do nothing. +#elif defined(__SIZEOF_INT128__) +# define FMT_USE_INT128 1 +using int128_t = __int128_t; +using uint128_t = __uint128_t; +#else +# define FMT_USE_INT128 0 +#endif +#if !FMT_USE_INT128 +struct int128_t {}; +struct uint128_t {}; +#endif + +// Casts a nonnegative integer to unsigned. +template +FMT_CONSTEXPR typename std::make_unsigned::type to_unsigned(Int value) { + FMT_ASSERT(value >= 0, "negative value"); + return static_cast::type>(value); +} +} // namespace internal + +template +using void_t = typename internal::void_t_impl::type; + +/** + An implementation of ``std::basic_string_view`` for pre-C++17. It provides a + subset of the API. ``fmt::basic_string_view`` is used for format strings even + if ``std::string_view`` is available to prevent issues when a library is + compiled with a different ``-std`` option than the client code (which is not + recommended). + */ +template class basic_string_view { + private: + const Char* data_; + size_t size_; + + public: + using char_type = Char; + using iterator = const Char*; + + FMT_CONSTEXPR basic_string_view() FMT_NOEXCEPT : data_(nullptr), size_(0) {} + + /** Constructs a string reference object from a C string and a size. */ + FMT_CONSTEXPR basic_string_view(const Char* s, size_t count) FMT_NOEXCEPT + : data_(s), + size_(count) {} + + /** + \rst + Constructs a string reference object from a C string computing + the size with ``std::char_traits::length``. + \endrst + */ + basic_string_view(const Char* s) + : data_(s), size_(std::char_traits::length(s)) {} + + /** Constructs a string reference from a ``std::basic_string`` object. */ + template + FMT_CONSTEXPR basic_string_view( + const std::basic_string& s) FMT_NOEXCEPT + : data_(s.data()), + size_(s.size()) {} + + template < + typename S, + FMT_ENABLE_IF(std::is_same>::value)> + FMT_CONSTEXPR basic_string_view(S s) FMT_NOEXCEPT : data_(s.data()), + size_(s.size()) {} + + /** Returns a pointer to the string data. */ + FMT_CONSTEXPR const Char* data() const { return data_; } + + /** Returns the string size. */ + FMT_CONSTEXPR size_t size() const { return size_; } + + FMT_CONSTEXPR iterator begin() const { return data_; } + FMT_CONSTEXPR iterator end() const { return data_ + size_; } + + FMT_CONSTEXPR const Char& operator[](size_t pos) const { return data_[pos]; } + + FMT_CONSTEXPR void remove_prefix(size_t n) { + data_ += n; + size_ -= n; + } + + std::string to_string() { + return std::string((char *) data(), size()); + } + + // Lexicographically compare this string reference to other. + int compare(basic_string_view other) const { + size_t str_size = size_ < other.size_ ? size_ : other.size_; + int result = std::char_traits::compare(data_, other.data_, str_size); + if (result == 0) + result = size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); + return result; + } + + friend bool operator==(basic_string_view lhs, basic_string_view rhs) { + return lhs.compare(rhs) == 0; + } + friend bool operator!=(basic_string_view lhs, basic_string_view rhs) { + return lhs.compare(rhs) != 0; + } + friend bool operator<(basic_string_view lhs, basic_string_view rhs) { + return lhs.compare(rhs) < 0; + } + friend bool operator<=(basic_string_view lhs, basic_string_view rhs) { + return lhs.compare(rhs) <= 0; + } + friend bool operator>(basic_string_view lhs, basic_string_view rhs) { + return lhs.compare(rhs) > 0; + } + friend bool operator>=(basic_string_view lhs, basic_string_view rhs) { + return lhs.compare(rhs) >= 0; + } +}; + +using string_view = basic_string_view; +using wstring_view = basic_string_view; + +// A UTF-8 code unit type. +#if FMT_HAS_FEATURE(__cpp_char8_t) +typedef char8_t fmt_char8_t; +#else +typedef char fmt_char8_t; +#endif + +/** Specifies if ``T`` is a character type. Can be specialized by users. */ +template struct is_char : std::false_type {}; +template <> struct is_char : std::true_type {}; +template <> struct is_char : std::true_type {}; +template <> struct is_char : std::true_type {}; +template <> struct is_char : std::true_type {}; + +/** + \rst + Returns a string view of `s`. In order to add custom string type support to + {fmt} provide an overload of `to_string_view` for it in the same namespace as + the type for the argument-dependent lookup to work. + + **Example**:: + + namespace my_ns { + inline string_view to_string_view(const my_string& s) { + return {s.data(), s.length()}; + } + } + std::string message = fmt::format(my_string("The answer is {}"), 42); + \endrst + */ +template ::value)> +inline basic_string_view to_string_view(const Char* s) { + return s; +} + +template +inline basic_string_view to_string_view( + const std::basic_string& s) { + return s; +} + +template +inline basic_string_view to_string_view(basic_string_view s) { + return s; +} + +template >::value)> +inline basic_string_view to_string_view( + internal::std_string_view s) { + return s; +} + +// A base class for compile-time strings. It is defined in the fmt namespace to +// make formatting functions visible via ADL, e.g. format(fmt("{}"), 42). +struct compile_string {}; + +template +struct is_compile_string : std::is_base_of {}; + +template ::value)> +FMT_CONSTEXPR basic_string_view to_string_view(const S& s) { + return s; +} + +namespace internal { +void to_string_view(...); +using duckdb_fmt::v6::to_string_view; + +// Specifies whether S is a string type convertible to fmt::basic_string_view. +// It should be a constexpr function but MSVC 2017 fails to compile it in +// enable_if and MSVC 2015 fails to compile it as an alias template. +template +struct is_string : std::is_class()))> { +}; + +template struct char_t_impl {}; +template struct char_t_impl::value>> { + using result = decltype(to_string_view(std::declval())); + using type = typename result::char_type; +}; + +struct error_handler { + FMT_CONSTEXPR error_handler() = default; + FMT_CONSTEXPR error_handler(const error_handler&) = default; + + // This function is intentionally not constexpr to give a compile-time error. + FMT_NORETURN FMT_API void on_error(std::string message); +}; +} // namespace internal + +/** String's character type. */ +template using char_t = typename internal::char_t_impl::type; + +/** + \rst + Parsing context consisting of a format string range being parsed and an + argument counter for automatic indexing. + + You can use one of the following type aliases for common character types: + + +-----------------------+-------------------------------------+ + | Type | Definition | + +=======================+=====================================+ + | format_parse_context | basic_format_parse_context | + +-----------------------+-------------------------------------+ + | wformat_parse_context | basic_format_parse_context | + +-----------------------+-------------------------------------+ + \endrst + */ +template +class basic_format_parse_context : private ErrorHandler { + private: + basic_string_view format_str_; + int next_arg_id_; + + public: + using char_type = Char; + using iterator = typename basic_string_view::iterator; + + explicit FMT_CONSTEXPR basic_format_parse_context( + basic_string_view format_str, ErrorHandler eh = ErrorHandler()) + : ErrorHandler(eh), format_str_(format_str), next_arg_id_(0) {} + + /** + Returns an iterator to the beginning of the format string range being + parsed. + */ + FMT_CONSTEXPR iterator begin() const FMT_NOEXCEPT { + return format_str_.begin(); + } + + /** + Returns an iterator past the end of the format string range being parsed. + */ + FMT_CONSTEXPR iterator end() const FMT_NOEXCEPT { return format_str_.end(); } + + /** Advances the begin iterator to ``it``. */ + FMT_CONSTEXPR void advance_to(iterator it) { + format_str_.remove_prefix(internal::to_unsigned(it - begin())); + } + + /** + Reports an error if using the manual argument indexing; otherwise returns + the next argument index and switches to the automatic indexing. + */ + FMT_CONSTEXPR int next_arg_id() { + if (next_arg_id_ >= 0) return next_arg_id_++; + on_error("cannot switch from manual to automatic argument indexing"); + return 0; + } + + /** + Reports an error if using the automatic argument indexing; otherwise + switches to the manual indexing. + */ + FMT_CONSTEXPR void check_arg_id(int) { + if (next_arg_id_ > 0) + on_error("cannot switch from automatic to manual argument indexing"); + else + next_arg_id_ = -1; + } + + FMT_CONSTEXPR void check_arg_id(basic_string_view) {} + + FMT_CONSTEXPR void on_error(std::string message) { + ErrorHandler::on_error(message); + } + + FMT_CONSTEXPR ErrorHandler error_handler() const { return *this; } +}; + +using format_parse_context = basic_format_parse_context; +using wformat_parse_context = basic_format_parse_context; + +template +using basic_parse_context FMT_DEPRECATED_ALIAS = + basic_format_parse_context; +using parse_context FMT_DEPRECATED_ALIAS = basic_format_parse_context; +using wparse_context FMT_DEPRECATED_ALIAS = basic_format_parse_context; + +template class basic_format_arg; +template class basic_format_args; + +// A formatter for objects of type T. +template +struct formatter { + // A deleted default constructor indicates a disabled formatter. + formatter() = delete; +}; + +template +struct FMT_DEPRECATED convert_to_int + : bool_constant::value && + std::is_convertible::value> {}; + +// Specifies if T has an enabled formatter specialization. A type can be +// formattable even if it doesn't have a formatter e.g. via a conversion. +template +using has_formatter = + std::is_constructible>; + +namespace internal { + +/** A contiguous memory buffer with an optional growing ability. */ +template class buffer { + private: + T* ptr_; + std::size_t size_; + std::size_t capacity_; + + protected: + // Don't initialize ptr_ since it is not accessed to save a few cycles. + buffer(std::size_t sz) FMT_NOEXCEPT : size_(sz), capacity_(sz) {} + + buffer(T* p = nullptr, std::size_t sz = 0, std::size_t cap = 0) FMT_NOEXCEPT + : ptr_(p), + size_(sz), + capacity_(cap) {} + + /** Sets the buffer data and capacity. */ + void set(T* buf_data, std::size_t buf_capacity) FMT_NOEXCEPT { + ptr_ = buf_data; + capacity_ = buf_capacity; + } + + /** Increases the buffer capacity to hold at least *capacity* elements. */ + virtual void grow(std::size_t capacity) = 0; + + public: + using value_type = T; + using const_reference = const T&; + + buffer(const buffer&) = delete; + void operator=(const buffer&) = delete; + virtual ~buffer() = default; + + T* begin() FMT_NOEXCEPT { return ptr_; } + T* end() FMT_NOEXCEPT { return ptr_ + size_; } + + /** Returns the size of this buffer. */ + std::size_t size() const FMT_NOEXCEPT { return size_; } + + /** Returns the capacity of this buffer. */ + std::size_t capacity() const FMT_NOEXCEPT { return capacity_; } + + /** Returns a pointer to the buffer data. */ + T* data() FMT_NOEXCEPT { return ptr_; } + + /** Returns a pointer to the buffer data. */ + const T* data() const FMT_NOEXCEPT { return ptr_; } + + /** + Resizes the buffer. If T is a POD type new elements may not be initialized. + */ + void resize(std::size_t new_size) { + reserve(new_size); + size_ = new_size; + } + + /** Clears this buffer. */ + void clear() { size_ = 0; } + + /** Reserves space to store at least *capacity* elements. */ + void reserve(std::size_t new_capacity) { + if (new_capacity > capacity_) grow(new_capacity); + } + + void push_back(const T& value) { + reserve(size_ + 1); + ptr_[size_++] = value; + } + + /** Appends data to the end of the buffer. */ + template void append(const U* begin, const U* end); + + T& operator[](std::size_t index) { return ptr_[index]; } + const T& operator[](std::size_t index) const { return ptr_[index]; } +}; + +// A container-backed buffer. +template +class container_buffer : public buffer { + private: + Container& container_; + + protected: + void grow(std::size_t capacity) FMT_OVERRIDE { + container_.resize(capacity); + this->set(&container_[0], capacity); + } + + public: + explicit container_buffer(Container& c) + : buffer(c.size()), container_(c) {} +}; + +// Extracts a reference to the container from back_insert_iterator. +template +inline Container& get_container(std::back_insert_iterator it) { + using bi_iterator = std::back_insert_iterator; + struct accessor : bi_iterator { + accessor(bi_iterator iter) : bi_iterator(iter) {} + using bi_iterator::container; + }; + return *accessor(it).container; +} + +template +struct fallback_formatter { + fallback_formatter() = delete; +}; + +// Specifies if T has an enabled fallback_formatter specialization. +template +using has_fallback_formatter = + std::is_constructible>; + +template struct named_arg_base; +template struct named_arg; + +enum type { + none_type, + named_arg_type, + // Integer types should go first, + int_type, + uint_type, + long_long_type, + ulong_long_type, + int128_type, + uint128_type, + bool_type, + char_type, + last_integer_type = char_type, + // followed by floating-point types. + float_type, + double_type, + long_double_type, + last_numeric_type = long_double_type, + cstring_type, + string_type, + pointer_type, + custom_type +}; + +// Maps core type T to the corresponding type enum constant. +template +struct type_constant : std::integral_constant {}; + +#define FMT_TYPE_CONSTANT(Type, constant) \ + template \ + struct type_constant : std::integral_constant {} + +FMT_TYPE_CONSTANT(const named_arg_base&, named_arg_type); +FMT_TYPE_CONSTANT(int, int_type); +FMT_TYPE_CONSTANT(unsigned, uint_type); +FMT_TYPE_CONSTANT(long long, long_long_type); +FMT_TYPE_CONSTANT(unsigned long long, ulong_long_type); +FMT_TYPE_CONSTANT(int128_t, int128_type); +FMT_TYPE_CONSTANT(uint128_t, uint128_type); +FMT_TYPE_CONSTANT(bool, bool_type); +FMT_TYPE_CONSTANT(Char, char_type); +FMT_TYPE_CONSTANT(float, float_type); +FMT_TYPE_CONSTANT(double, double_type); +FMT_TYPE_CONSTANT(long double, long_double_type); +FMT_TYPE_CONSTANT(const Char*, cstring_type); +FMT_TYPE_CONSTANT(basic_string_view, string_type); +FMT_TYPE_CONSTANT(const void*, pointer_type); + +FMT_CONSTEXPR bool is_integral_type(type t) { + FMT_ASSERT(t != named_arg_type, "invalid argument type"); + return t > none_type && t <= last_integer_type; +} + +FMT_CONSTEXPR bool is_arithmetic_type(type t) { + FMT_ASSERT(t != named_arg_type, "invalid argument type"); + return t > none_type && t <= last_numeric_type; +} + +template struct string_value { + const Char* data; + std::size_t size; +}; + +template struct custom_value { + using parse_context = basic_format_parse_context; + const void* value; + void (*format)(const void* arg, parse_context& parse_ctx, Context& ctx); +}; + +// A formatting argument value. +template class value { + public: + using char_type = typename Context::char_type; + + union { + int int_value; + unsigned uint_value; + long long long_long_value; + unsigned long long ulong_long_value; + int128_t int128_value; + uint128_t uint128_value; + bool bool_value; + char_type char_value; + float float_value; + double double_value; + long double long_double_value; + const void* pointer; + string_value string; + custom_value custom; + const named_arg_base* named_arg; + }; + + FMT_CONSTEXPR value(int val = 0) : int_value(val) {} + FMT_CONSTEXPR value(unsigned val) : uint_value(val) {} + value(long long val) : long_long_value(val) {} + value(unsigned long long val) : ulong_long_value(val) {} + value(int128_t val) : int128_value(val) {} + value(uint128_t val) : uint128_value(val) {} + value(float val) : float_value(val) {} + value(double val) : double_value(val) {} + value(long double val) : long_double_value(val) {} + value(bool val) : bool_value(val) {} + value(char_type val) : char_value(val) {} + value(const char_type* val) { string.data = val; } + value(basic_string_view val) { + string.data = val.data(); + string.size = val.size(); + } + value(const void* val) : pointer(val) {} + + template value(const T& val) { + custom.value = &val; + // Get the formatter type through the context to allow different contexts + // have different extension points, e.g. `formatter` for `format` and + // `printf_formatter` for `printf`. + custom.format = format_custom_arg< + T, conditional_t::value, + typename Context::template formatter_type, + fallback_formatter>>; + } + + value(const named_arg_base& val) { named_arg = &val; } + + private: + // Formats an argument of a custom type, such as a user-defined class. + template + static void format_custom_arg( + const void* arg, basic_format_parse_context& parse_ctx, + Context& ctx) { + Formatter f; + parse_ctx.advance_to(f.parse(parse_ctx)); + ctx.advance_to(f.format(*static_cast(arg), ctx)); + } +}; + +template +FMT_CONSTEXPR basic_format_arg make_arg(const T& value); + +// To minimize the number of types we need to deal with, long is translated +// either to int or to long long depending on its size. +enum { long_short = sizeof(long) == sizeof(int) }; +using long_type = conditional_t; +using ulong_type = conditional_t; + +// Maps formatting arguments to core types. +template struct arg_mapper { + using char_type = typename Context::char_type; + + FMT_CONSTEXPR int map(signed char val) { return val; } + FMT_CONSTEXPR unsigned map(unsigned char val) { return val; } + FMT_CONSTEXPR int map(short val) { return val; } + FMT_CONSTEXPR unsigned map(unsigned short val) { return val; } + FMT_CONSTEXPR int map(int val) { return val; } + FMT_CONSTEXPR unsigned map(unsigned val) { return val; } + FMT_CONSTEXPR long_type map(long val) { return val; } + FMT_CONSTEXPR ulong_type map(unsigned long val) { return val; } + FMT_CONSTEXPR long long map(long long val) { return val; } + FMT_CONSTEXPR unsigned long long map(unsigned long long val) { return val; } + FMT_CONSTEXPR int128_t map(int128_t val) { return val; } + FMT_CONSTEXPR uint128_t map(uint128_t val) { return val; } + FMT_CONSTEXPR bool map(bool val) { return val; } + + template ::value)> + FMT_CONSTEXPR char_type map(T val) { + static_assert( + std::is_same::value || std::is_same::value, + "mixing character types is disallowed"); + return val; + } + + FMT_CONSTEXPR float map(float val) { return val; } + FMT_CONSTEXPR double map(double val) { return val; } + FMT_CONSTEXPR long double map(long double val) { return val; } + + FMT_CONSTEXPR const char_type* map(char_type* val) { return val; } + FMT_CONSTEXPR const char_type* map(const char_type* val) { return val; } + template ::value)> + FMT_CONSTEXPR basic_string_view map(const T& val) { + static_assert(std::is_same>::value, + "mixing character types is disallowed"); + return to_string_view(val); + } + template , T>::value && + !is_string::value)> + FMT_CONSTEXPR basic_string_view map(const T& val) { + return basic_string_view(val); + } + template < + typename T, + FMT_ENABLE_IF( + std::is_constructible, T>::value && + !std::is_constructible, T>::value && + !is_string::value && !has_formatter::value)> + FMT_CONSTEXPR basic_string_view map(const T& val) { + return std_string_view(val); + } + FMT_CONSTEXPR const char* map(const signed char* val) { + static_assert(std::is_same::value, "invalid string type"); + return reinterpret_cast(val); + } + FMT_CONSTEXPR const char* map(const unsigned char* val) { + static_assert(std::is_same::value, "invalid string type"); + return reinterpret_cast(val); + } + + FMT_CONSTEXPR const void* map(void* val) { return val; } + FMT_CONSTEXPR const void* map(const void* val) { return val; } + FMT_CONSTEXPR const void* map(std::nullptr_t val) { return val; } + template FMT_CONSTEXPR int map(const T*) { + // Formatting of arbitrary pointers is disallowed. If you want to output + // a pointer cast it to "void *" or "const void *". In particular, this + // forbids formatting of "[const] volatile char *" which is printed as bool + // by iostreams. + static_assert(!sizeof(T), "formatting of non-void pointers is disallowed"); + return 0; + } + + template ::value && + !has_formatter::value && + !has_fallback_formatter::value)> + FMT_CONSTEXPR auto map(const T& val) -> decltype( + map(static_cast::type>(val))) { + return map(static_cast::type>(val)); + } + template < + typename T, + FMT_ENABLE_IF( + !is_string::value && !is_char::value && + !std::is_constructible, T>::value && + (has_formatter::value || + (has_fallback_formatter::value && + !std::is_constructible, T>::value)))> + FMT_CONSTEXPR const T& map(const T& val) { + return val; + } + + template + FMT_CONSTEXPR const named_arg_base& map( + const named_arg& val) { + auto arg = make_arg(val.value); + std::memcpy(val.data, &arg, sizeof(arg)); + return val; + } +}; + +// A type constant after applying arg_mapper. +template +using mapped_type_constant = + type_constant().map(std::declval())), + typename Context::char_type>; + +enum { packed_arg_bits = 5 }; +// Maximum number of arguments with packed types. +enum { max_packed_args = 63 / packed_arg_bits }; +enum : unsigned long long { is_unpacked_bit = 1ULL << 63 }; + +template class arg_map; +} // namespace internal + +// A formatting argument. It is a trivially copyable/constructible type to +// allow storage in basic_memory_buffer. +template class basic_format_arg { + private: + internal::value value_; + internal::type type_; + + template + friend FMT_CONSTEXPR basic_format_arg internal::make_arg( + const T& value); + + template + friend FMT_CONSTEXPR auto visit_format_arg(Visitor&& vis, + const basic_format_arg& arg) + -> decltype(vis(0)); + + friend class basic_format_args; + friend class internal::arg_map; + + using char_type = typename Context::char_type; + + public: + class handle { + public: + explicit handle(internal::custom_value custom) : custom_(custom) {} + + void format(basic_format_parse_context& parse_ctx, + Context& ctx) const { + custom_.format(custom_.value, parse_ctx, ctx); + } + + private: + internal::custom_value custom_; + }; + + FMT_CONSTEXPR basic_format_arg() : type_(internal::none_type) {} + + FMT_CONSTEXPR explicit operator bool() const FMT_NOEXCEPT { + return type_ != internal::none_type; + } + + internal::type type() const { return type_; } + + bool is_integral() const { return internal::is_integral_type(type_); } + bool is_arithmetic() const { return internal::is_arithmetic_type(type_); } +}; + +/** + \rst + Visits an argument dispatching to the appropriate visit method based on + the argument type. For example, if the argument type is ``double`` then + ``vis(value)`` will be called with the value of type ``double``. + \endrst + */ +template +FMT_CONSTEXPR auto visit_format_arg(Visitor&& vis, + const basic_format_arg& arg) + -> decltype(vis(0)) { + using char_type = typename Context::char_type; + switch (arg.type_) { + case internal::none_type: + break; + case internal::named_arg_type: + FMT_ASSERT(false, "invalid argument type"); + break; + case internal::int_type: + return vis(arg.value_.int_value); + case internal::uint_type: + return vis(arg.value_.uint_value); + case internal::long_long_type: + return vis(arg.value_.long_long_value); + case internal::ulong_long_type: + return vis(arg.value_.ulong_long_value); +#if FMT_USE_INT128 + case internal::int128_type: + return vis(arg.value_.int128_value); + case internal::uint128_type: + return vis(arg.value_.uint128_value); +#else + case internal::int128_type: + case internal::uint128_type: + break; +#endif + case internal::bool_type: + return vis(arg.value_.bool_value); + case internal::char_type: + return vis(arg.value_.char_value); + case internal::float_type: + return vis(arg.value_.float_value); + case internal::double_type: + return vis(arg.value_.double_value); + case internal::long_double_type: + return vis(arg.value_.long_double_value); + case internal::cstring_type: + return vis(arg.value_.string.data); + case internal::string_type: + return vis(basic_string_view(arg.value_.string.data, + arg.value_.string.size)); + case internal::pointer_type: + return vis(arg.value_.pointer); + case internal::custom_type: + return vis(typename basic_format_arg::handle(arg.value_.custom)); + } + return vis(monostate()); +} + +namespace internal { +// A map from argument names to their values for named arguments. +template class arg_map { + private: + using char_type = typename Context::char_type; + + struct entry { + basic_string_view name; + basic_format_arg arg; + }; + + entry* map_; + unsigned size_; + + void push_back(value val) { + const auto& named = *val.named_arg; + map_[size_] = {named.name, named.template deserialize()}; + ++size_; + } + + public: + arg_map(const arg_map&) = delete; + void operator=(const arg_map&) = delete; + arg_map() : map_(nullptr), size_(0) {} + void init(const basic_format_args& args); + ~arg_map() { delete[] map_; } + + basic_format_arg find(basic_string_view name) const { + // The list is unsorted, so just return the first matching name. + for (entry *it = map_, *end = map_ + size_; it != end; ++it) { + if (it->name == name) return it->arg; + } + return {}; + } +}; + +// A type-erased reference to an std::locale to avoid heavy include. +class locale_ref { + private: + const void* locale_; // A type-erased pointer to std::locale. + + public: + locale_ref() : locale_(nullptr) {} + template explicit locale_ref(const Locale& loc); + + explicit operator bool() const FMT_NOEXCEPT { return locale_ != nullptr; } + + template Locale get() const; +}; + +template constexpr unsigned long long encode_types() { return 0; } + +template +constexpr unsigned long long encode_types() { + return mapped_type_constant::value | + (encode_types() << packed_arg_bits); +} + +template +FMT_CONSTEXPR basic_format_arg make_arg(const T& value) { + basic_format_arg arg; + arg.type_ = mapped_type_constant::value; + arg.value_ = arg_mapper().map(value); + return arg; +} + +template +inline value make_arg(const T& val) { + return arg_mapper().map(val); +} + +template +inline basic_format_arg make_arg(const T& value) { + return make_arg(value); +} +} // namespace internal + +// Formatting context. +template class basic_format_context { + public: + /** The character type for the output. */ + using char_type = Char; + + private: + OutputIt out_; + basic_format_args args_; + internal::arg_map map_; + internal::locale_ref loc_; + + public: + using iterator = OutputIt; + using format_arg = basic_format_arg; + template using formatter_type = formatter; + + basic_format_context(const basic_format_context&) = delete; + void operator=(const basic_format_context&) = delete; + /** + Constructs a ``basic_format_context`` object. References to the arguments are + stored in the object so make sure they have appropriate lifetimes. + */ + basic_format_context(OutputIt out, + basic_format_args ctx_args, + internal::locale_ref loc = internal::locale_ref()) + : out_(out), args_(ctx_args), loc_(loc) {} + + format_arg arg(int id) const { return args_.get(id); } + + // Checks if manual indexing is used and returns the argument with the + // specified name. + format_arg arg(basic_string_view name); + + internal::error_handler error_handler() { return {}; } + void on_error(std::string message) { error_handler().on_error(message); } + + // Returns an iterator to the beginning of the output range. + iterator out() { return out_; } + + // Advances the begin iterator to ``it``. + void advance_to(iterator it) { out_ = it; } + + internal::locale_ref locale() { return loc_; } +}; + +template +using buffer_context = + basic_format_context>, + Char>; +using format_context = buffer_context; +using wformat_context = buffer_context; + +/** + \rst + An array of references to arguments. It can be implicitly converted into + `~fmt::basic_format_args` for passing into type-erased formatting functions + such as `~fmt::vformat`. + \endrst + */ +template class format_arg_store { + private: + static const size_t num_args = sizeof...(Args); + static const bool is_packed = num_args < internal::max_packed_args; + + using value_type = conditional_t, + basic_format_arg>; + + // If the arguments are not packed, add one more element to mark the end. + value_type data_[num_args + (num_args == 0 ? 1 : 0)]; + + friend class basic_format_args; + + public: + static constexpr unsigned long long types = + is_packed ? internal::encode_types() + : internal::is_unpacked_bit | num_args; + + format_arg_store(const Args&... args) + : data_{internal::make_arg(args)...} {} +}; + +/** + \rst + Constructs an `~fmt::format_arg_store` object that contains references to + arguments and can be implicitly converted to `~fmt::format_args`. `Context` + can be omitted in which case it defaults to `~fmt::context`. + See `~fmt::arg` for lifetime considerations. + \endrst + */ +template +inline format_arg_store make_format_args( + const Args&... args) { + return {args...}; +} + +/** Formatting arguments. */ +template class basic_format_args { + public: + using size_type = int; + using format_arg = basic_format_arg; + + private: + // To reduce compiled code size per formatting function call, types of first + // max_packed_args arguments are passed in the types_ field. + unsigned long long types_; + union { + // If the number of arguments is less than max_packed_args, the argument + // values are stored in values_, otherwise they are stored in args_. + // This is done to reduce compiled code size as storing larger objects + // may require more code (at least on x86-64) even if the same amount of + // data is actually copied to stack. It saves ~10% on the bloat test. + const internal::value* values_; + const format_arg* args_; + }; + + bool is_packed() const { return (types_ & internal::is_unpacked_bit) == 0; } + + internal::type type(int index) const { + int shift = index * internal::packed_arg_bits; + unsigned int mask = (1 << internal::packed_arg_bits) - 1; + return static_cast((types_ >> shift) & mask); + } + + friend class internal::arg_map; + + void set_data(const internal::value* values) { values_ = values; } + void set_data(const format_arg* args) { args_ = args; } + + format_arg do_get(int index) const { + format_arg arg; + if (!is_packed()) { + auto num_args = max_size(); + if (index < num_args) arg = args_[index]; + return arg; + } + if (index > internal::max_packed_args) return arg; + arg.type_ = type(index); + if (arg.type_ == internal::none_type) return arg; + internal::value& val = arg.value_; + val = values_[index]; + return arg; + } + + public: + basic_format_args() : types_(0) {} + + /** + \rst + Constructs a `basic_format_args` object from `~fmt::format_arg_store`. + \endrst + */ + template + basic_format_args(const format_arg_store& store) + : types_(store.types) { + set_data(store.data_); + } + + /** + \rst + Constructs a `basic_format_args` object from a dynamic set of arguments. + \endrst + */ + basic_format_args(const format_arg* args, int count) + : types_(internal::is_unpacked_bit | internal::to_unsigned(count)) { + set_data(args); + } + + /** Returns the argument at specified index. */ + format_arg get(int index) const { + format_arg arg = do_get(index); + if (arg.type_ == internal::named_arg_type) + arg = arg.value_.named_arg->template deserialize(); + return arg; + } + + int max_size() const { + unsigned long long max_packed = internal::max_packed_args; + return static_cast(is_packed() ? max_packed + : types_ & ~internal::is_unpacked_bit); + } +}; + +/** An alias to ``basic_format_args``. */ +// It is a separate type rather than an alias to make symbols readable. +struct format_args : basic_format_args { + template + format_args(Args&&... args) + : basic_format_args(std::forward(args)...) {} +}; +struct wformat_args : basic_format_args { + template + wformat_args(Args&&... args) + : basic_format_args(std::forward(args)...) {} +}; + +template struct is_contiguous : std::false_type {}; + +template +struct is_contiguous> : std::true_type {}; + +template +struct is_contiguous> : std::true_type {}; + +namespace internal { + +template +struct is_contiguous_back_insert_iterator : std::false_type {}; +template +struct is_contiguous_back_insert_iterator> + : is_contiguous {}; + +template struct named_arg_base { + basic_string_view name; + + // Serialized value. + mutable char data[sizeof(basic_format_arg>)]; + + named_arg_base(basic_string_view nm) : name(nm) {} + + template basic_format_arg deserialize() const { + basic_format_arg arg; + std::memcpy(&arg, data, sizeof(basic_format_arg)); + return arg; + } +}; + +template struct named_arg : named_arg_base { + const T& value; + + named_arg(basic_string_view name, const T& val) + : named_arg_base(name), value(val) {} +}; + +template ::value)> +inline void check_format_string(const S&) { +#if defined(FMT_ENFORCE_COMPILE_STRING) + static_assert(is_compile_string::value, + "FMT_ENFORCE_COMPILE_STRING requires all format strings to " + "utilize FMT_STRING() or fmt()."); +#endif +} +template ::value)> +void check_format_string(S); + +struct view {}; +template struct bool_pack; +template +using all_true = + std::is_same, bool_pack>; + +template > +inline format_arg_store, remove_reference_t...> +make_args_checked(const S& format_str, + const remove_reference_t&... args) { + static_assert(all_true<(!std::is_base_of>() || + !std::is_reference())...>::value, + "passing views as lvalues is disallowed"); + check_format_string>...>(format_str); + return {args...}; +} + +template +std::basic_string vformat(basic_string_view format_str, + basic_format_args> args); + +template +typename buffer_context::iterator vformat_to( + buffer& buf, basic_string_view format_str, + basic_format_args> args); +} // namespace internal + +/** + \rst + Returns a named argument to be used in a formatting function. + + The named argument holds a reference and does not extend the lifetime + of its arguments. + Consequently, a dangling reference can accidentally be created. + The user should take care to only pass this function temporaries when + the named argument is itself a temporary, as per the following example. + + **Example**:: + + fmt::print("Elapsed time: {s:.2f} seconds", fmt::arg("s", 1.23)); + \endrst + */ +template > +inline internal::named_arg arg(const S& name, const T& arg) { + static_assert(internal::is_string::value, ""); + return {name, arg}; +} + +// Disable nested named arguments, e.g. ``arg("a", arg("b", 42))``. +template +void arg(S, internal::named_arg) = delete; + +/** Formats a string and writes the output to ``out``. */ +// GCC 8 and earlier cannot handle std::back_insert_iterator with +// vformat_to(...) overload, so SFINAE on iterator type instead. +template , + FMT_ENABLE_IF( + internal::is_contiguous_back_insert_iterator::value)> +OutputIt vformat_to(OutputIt out, const S& format_str, + basic_format_args> args) { + using container = remove_reference_t; + internal::container_buffer buf((internal::get_container(out))); + internal::vformat_to(buf, to_string_view(format_str), args); + return out; +} + +template ::value&& internal::is_string::value)> +inline std::back_insert_iterator format_to( + std::back_insert_iterator out, const S& format_str, + Args&&... args) { + return vformat_to( + out, to_string_view(format_str), + {internal::make_args_checked(format_str, args...)}); +} + +template > +inline std::basic_string vformat( + const S& format_str, basic_format_args> args) { + return internal::vformat(to_string_view(format_str), args); +} + +/** + \rst + Formats arguments and returns the result as a string. + + **Example**:: + + #include + std::string message = fmt::format("The answer is {}", 42); + \endrst +*/ +// Pass char_t as a default template parameter instead of using +// std::basic_string> to reduce the symbol size. +template > +inline std::basic_string format(const S& format_str, Args&&... args) { + return internal::vformat( + to_string_view(format_str), + {internal::make_args_checked(format_str, args...)}); +} + +FMT_END_NAMESPACE + +#endif // FMT_CORE_H_ diff --git a/src/duckdb/third_party/fmt/include/fmt/format-inl.h b/src/duckdb/third_party/fmt/include/fmt/format-inl.h new file mode 100644 index 00000000..cef5359c --- /dev/null +++ b/src/duckdb/third_party/fmt/include/fmt/format-inl.h @@ -0,0 +1,1185 @@ +// Formatting library for C++ - implementation +// +// Copyright (c) 2012 - 2016, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_FORMAT_INL_H_ +#define FMT_FORMAT_INL_H_ + +#include "fmt/format.h" + +#include +#include +#include +#include +#include +#include // for std::memmove +#include +#if FMT_EXCEPTIONS +# define FMT_TRY try +# define FMT_CATCH(x) catch (x) +#else +# define FMT_TRY if (true) +# define FMT_CATCH(x) if (false) +#endif + +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable : 4702) // unreachable code +#endif + +// Dummy implementations of strerror_r and strerror_s called if corresponding +// system functions are not available. +inline duckdb_fmt::internal::null<> strerror_r(int, char*, ...) { return {}; } +inline duckdb_fmt::internal::null<> strerror_s(char*, std::size_t, ...) { return {}; } + +FMT_BEGIN_NAMESPACE +namespace internal { + +#ifndef _MSC_VER +# define FMT_SNPRINTF snprintf +#else // _MSC_VER +inline int fmt_snprintf(char* buffer, size_t size, const char* format, ...) { + va_list args; + va_start(args, format); + int result = vsnprintf_s(buffer, size, _TRUNCATE, format, args); + va_end(args); + return result; +} +# define FMT_SNPRINTF fmt_snprintf +#endif // _MSC_VER + +using format_func = void (*)(internal::buffer&, int, string_view); + +// A portable thread-safe version of strerror. +// Sets buffer to point to a string describing the error code. +// This can be either a pointer to a string stored in buffer, +// or a pointer to some static immutable string. +// Returns one of the following values: +// 0 - success +// ERANGE - buffer is not large enough to store the error message +// other - failure +// Buffer should be at least of size 1. +FMT_FUNC int safe_strerror(int error_code, char*& buffer, + std::size_t buffer_size) FMT_NOEXCEPT { + FMT_ASSERT(buffer != nullptr && buffer_size != 0, "invalid buffer"); + + class dispatcher { + private: + int error_code_; + char*& buffer_; + std::size_t buffer_size_; + + // A noop assignment operator to avoid bogus warnings. + void operator=(const dispatcher&) {} + + // Handle the result of XSI-compliant version of strerror_r. + int handle(int result) { + // glibc versions before 2.13 return result in errno. + return result == -1 ? errno : result; + } + + // Handle the result of GNU-specific version of strerror_r. + int handle(char* message) { + // If the buffer is full then the message is probably truncated. + if (message == buffer_ && strlen(buffer_) == buffer_size_ - 1) + return ERANGE; + buffer_ = message; + return 0; + } + + // Handle the case when strerror_r is not available. + int handle(internal::null<>) { + return fallback(strerror_s(buffer_, buffer_size_, error_code_)); + } + + // Fallback to strerror_s when strerror_r is not available. + int fallback(int result) { + // If the buffer is full then the message is probably truncated. + return result == 0 && strlen(buffer_) == buffer_size_ - 1 ? ERANGE + : result; + } + +#if !FMT_MSC_VER + // Fallback to strerror if strerror_r and strerror_s are not available. + int fallback(internal::null<>) { + errno = 0; + buffer_ = strerror(error_code_); + return errno; + } +#endif + + public: + dispatcher(int err_code, char*& buf, std::size_t buf_size) + : error_code_(err_code), buffer_(buf), buffer_size_(buf_size) {} + + int run() { return handle(strerror_r(error_code_, buffer_, buffer_size_)); } + }; + return dispatcher(error_code, buffer, buffer_size).run(); +} + +FMT_FUNC void format_error_code(internal::buffer& out, int error_code, + string_view message) FMT_NOEXCEPT { + // Report error code making sure that the output fits into + // inline_buffer_size to avoid dynamic memory allocation and potential + // bad_alloc. + out.resize(0); + static const char SEP[] = ": "; + static const char ERROR_STR[] = "error "; + // Subtract 2 to account for terminating null characters in SEP and ERROR_STR. + std::size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2; + auto abs_value = static_cast>(error_code); + if (internal::is_negative(error_code)) { + abs_value = 0 - abs_value; + ++error_code_size; + } + error_code_size += internal::to_unsigned(internal::count_digits(abs_value)); + internal::writer w(out); + if (message.size() <= inline_buffer_size - error_code_size) { + w.write(message); + w.write(SEP); + } + w.write(ERROR_STR); + w.write(error_code); + assert(out.size() <= inline_buffer_size); +} + +FMT_FUNC void report_error(format_func func, int error_code, + string_view message) FMT_NOEXCEPT { + memory_buffer full_message; + func(full_message, error_code, message); + /*// R does not allow us to have a reference to stderr even if we are not using it + // Don't use fwrite_fully because the latter may throw. + (void)std::fwrite(full_message.data(), full_message.size(), 1, stderr); + std::fputc('\n', stderr); + */ +} +} // namespace internal + +template +FMT_FUNC std::string internal::grouping_impl(locale_ref) { + return "\03"; +} +template +FMT_FUNC Char internal::thousands_sep_impl(locale_ref) { + return ','; +} +template +FMT_FUNC Char internal::decimal_point_impl(locale_ref) { + return '.'; +} + +namespace internal { + +template <> FMT_FUNC int count_digits<4>(internal::fallback_uintptr n) { + // fallback_uintptr is always stored in little endian. + int i = static_cast(sizeof(void*)) - 1; + while (i > 0 && n.value[i] == 0) --i; + auto char_digits = std::numeric_limits::digits / 4; + return i >= 0 ? i * char_digits + count_digits<4, unsigned>(n.value[i]) : 1; +} + +template +const char basic_data::digits[] = + "0001020304050607080910111213141516171819" + "2021222324252627282930313233343536373839" + "4041424344454647484950515253545556575859" + "6061626364656667686970717273747576777879" + "8081828384858687888990919293949596979899"; + +template +const char basic_data::hex_digits[] = "0123456789abcdef"; + +#define FMT_POWERS_OF_10(factor) \ + factor * 10, (factor)*100, (factor)*1000, (factor)*10000, (factor)*100000, \ + (factor)*1000000, (factor)*10000000, (factor)*100000000, \ + (factor)*1000000000 + +template +const uint64_t basic_data::powers_of_10_64[] = { + 1, FMT_POWERS_OF_10(1), FMT_POWERS_OF_10(1000000000ULL), + 10000000000000000000ULL}; + +template +const uint32_t basic_data::zero_or_powers_of_10_32[] = {0, + FMT_POWERS_OF_10(1)}; + +template +const uint64_t basic_data::zero_or_powers_of_10_64[] = { + 0, FMT_POWERS_OF_10(1), FMT_POWERS_OF_10(1000000000ULL), + 10000000000000000000ULL}; + +// Normalized 64-bit significands of pow(10, k), for k = -348, -340, ..., 340. +// These are generated by support/compute-powers.py. +template +const uint64_t basic_data::pow10_significands[] = { + 0xfa8fd5a0081c0288, 0xbaaee17fa23ebf76, 0x8b16fb203055ac76, + 0xcf42894a5dce35ea, 0x9a6bb0aa55653b2d, 0xe61acf033d1a45df, + 0xab70fe17c79ac6ca, 0xff77b1fcbebcdc4f, 0xbe5691ef416bd60c, + 0x8dd01fad907ffc3c, 0xd3515c2831559a83, 0x9d71ac8fada6c9b5, + 0xea9c227723ee8bcb, 0xaecc49914078536d, 0x823c12795db6ce57, + 0xc21094364dfb5637, 0x9096ea6f3848984f, 0xd77485cb25823ac7, + 0xa086cfcd97bf97f4, 0xef340a98172aace5, 0xb23867fb2a35b28e, + 0x84c8d4dfd2c63f3b, 0xc5dd44271ad3cdba, 0x936b9fcebb25c996, + 0xdbac6c247d62a584, 0xa3ab66580d5fdaf6, 0xf3e2f893dec3f126, + 0xb5b5ada8aaff80b8, 0x87625f056c7c4a8b, 0xc9bcff6034c13053, + 0x964e858c91ba2655, 0xdff9772470297ebd, 0xa6dfbd9fb8e5b88f, + 0xf8a95fcf88747d94, 0xb94470938fa89bcf, 0x8a08f0f8bf0f156b, + 0xcdb02555653131b6, 0x993fe2c6d07b7fac, 0xe45c10c42a2b3b06, + 0xaa242499697392d3, 0xfd87b5f28300ca0e, 0xbce5086492111aeb, + 0x8cbccc096f5088cc, 0xd1b71758e219652c, 0x9c40000000000000, + 0xe8d4a51000000000, 0xad78ebc5ac620000, 0x813f3978f8940984, + 0xc097ce7bc90715b3, 0x8f7e32ce7bea5c70, 0xd5d238a4abe98068, + 0x9f4f2726179a2245, 0xed63a231d4c4fb27, 0xb0de65388cc8ada8, + 0x83c7088e1aab65db, 0xc45d1df942711d9a, 0x924d692ca61be758, + 0xda01ee641a708dea, 0xa26da3999aef774a, 0xf209787bb47d6b85, + 0xb454e4a179dd1877, 0x865b86925b9bc5c2, 0xc83553c5c8965d3d, + 0x952ab45cfa97a0b3, 0xde469fbd99a05fe3, 0xa59bc234db398c25, + 0xf6c69a72a3989f5c, 0xb7dcbf5354e9bece, 0x88fcf317f22241e2, + 0xcc20ce9bd35c78a5, 0x98165af37b2153df, 0xe2a0b5dc971f303a, + 0xa8d9d1535ce3b396, 0xfb9b7cd9a4a7443c, 0xbb764c4ca7a44410, + 0x8bab8eefb6409c1a, 0xd01fef10a657842c, 0x9b10a4e5e9913129, + 0xe7109bfba19c0c9d, 0xac2820d9623bf429, 0x80444b5e7aa7cf85, + 0xbf21e44003acdd2d, 0x8e679c2f5e44ff8f, 0xd433179d9c8cb841, + 0x9e19db92b4e31ba9, 0xeb96bf6ebadf77d9, 0xaf87023b9bf0ee6b, +}; + +// Binary exponents of pow(10, k), for k = -348, -340, ..., 340, corresponding +// to significands above. +template +const int16_t basic_data::pow10_exponents[] = { + -1220, -1193, -1166, -1140, -1113, -1087, -1060, -1034, -1007, -980, -954, + -927, -901, -874, -847, -821, -794, -768, -741, -715, -688, -661, + -635, -608, -582, -555, -529, -502, -475, -449, -422, -396, -369, + -343, -316, -289, -263, -236, -210, -183, -157, -130, -103, -77, + -50, -24, 3, 30, 56, 83, 109, 136, 162, 189, 216, + 242, 269, 295, 322, 348, 375, 402, 428, 455, 481, 508, + 534, 561, 588, 614, 641, 667, 694, 720, 747, 774, 800, + 827, 853, 880, 907, 933, 960, 986, 1013, 1039, 1066}; + +template +const char basic_data::foreground_color[] = "\x1b[38;2;"; +template +const char basic_data::background_color[] = "\x1b[48;2;"; +template const char basic_data::reset_color[] = "\x1b[0m"; +template const wchar_t basic_data::wreset_color[] = L"\x1b[0m"; +template const char basic_data::signs[] = {0, '-', '+', ' '}; + +template struct bits { + static FMT_CONSTEXPR_DECL const int value = + static_cast(sizeof(T) * std::numeric_limits::digits); +}; + +class fp; +template fp normalize(fp value); + +// Lower (upper) boundary is a value half way between a floating-point value +// and its predecessor (successor). Boundaries have the same exponent as the +// value so only significands are stored. +struct boundaries { + uint64_t lower; + uint64_t upper; +}; + +// A handmade floating-point number f * pow(2, e). +class fp { + private: + using significand_type = uint64_t; + + // All sizes are in bits. + // Subtract 1 to account for an implicit most significant bit in the + // normalized form. + static FMT_CONSTEXPR_DECL const int double_significand_size = + std::numeric_limits::digits - 1; + static FMT_CONSTEXPR_DECL const uint64_t implicit_bit = + 1ULL << double_significand_size; + + public: + significand_type f; + int e; + + static FMT_CONSTEXPR_DECL const int significand_size = + bits::value; + + fp() : f(0), e(0) {} + fp(uint64_t f_val, int e_val) : f(f_val), e(e_val) {} + + // Constructs fp from an IEEE754 double. It is a template to prevent compile + // errors on platforms where double is not IEEE754. + template explicit fp(Double d) { assign(d); } + + // Normalizes the value converted from double and multiplied by (1 << SHIFT). + template friend fp normalize(fp value) { + // Handle subnormals. + const auto shifted_implicit_bit = fp::implicit_bit << SHIFT; + while ((value.f & shifted_implicit_bit) == 0) { + value.f <<= 1; + --value.e; + } + // Subtract 1 to account for hidden bit. + const auto offset = + fp::significand_size - fp::double_significand_size - SHIFT - 1; + value.f <<= offset; + value.e -= offset; + return value; + } + + // Assigns d to this and return true iff predecessor is closer than successor. + template + bool assign(Double d) { + // Assume double is in the format [sign][exponent][significand]. + using limits = std::numeric_limits; + const int exponent_size = + bits::value - double_significand_size - 1; // -1 for sign + const uint64_t significand_mask = implicit_bit - 1; + const uint64_t exponent_mask = (~0ULL >> 1) & ~significand_mask; + const int exponent_bias = (1 << exponent_size) - limits::max_exponent - 1; + auto u = bit_cast(d); + f = u & significand_mask; + auto biased_e = (u & exponent_mask) >> double_significand_size; + // Predecessor is closer if d is a normalized power of 2 (f == 0) other than + // the smallest normalized number (biased_e > 1). + bool is_predecessor_closer = f == 0 && biased_e > 1; + if (biased_e != 0) + f += implicit_bit; + else + biased_e = 1; // Subnormals use biased exponent 1 (min exponent). + e = static_cast(biased_e - exponent_bias - double_significand_size); + return is_predecessor_closer; + } + + template + bool assign(Double) { + *this = fp(); + return false; + } + + // Assigns d to this together with computing lower and upper boundaries, + // where a boundary is a value half way between the number and its predecessor + // (lower) or successor (upper). The upper boundary is normalized and lower + // has the same exponent but may be not normalized. + template boundaries assign_with_boundaries(Double d) { + bool is_lower_closer = assign(d); + fp lower = + is_lower_closer ? fp((f << 2) - 1, e - 2) : fp((f << 1) - 1, e - 1); + // 1 in normalize accounts for the exponent shift above. + fp upper = normalize<1>(fp((f << 1) + 1, e - 1)); + lower.f <<= lower.e - upper.e; + return boundaries{lower.f, upper.f}; + } + + template boundaries assign_float_with_boundaries(Double d) { + assign(d); + constexpr int min_normal_e = std::numeric_limits::min_exponent - + std::numeric_limits::digits; + significand_type half_ulp = 1 << (std::numeric_limits::digits - + std::numeric_limits::digits - 1); + if (min_normal_e > e) half_ulp <<= min_normal_e - e; + fp upper = normalize<0>(fp(f + half_ulp, e)); + fp lower = fp( + f - (half_ulp >> ((f == implicit_bit && e > min_normal_e) ? 1 : 0)), e); + lower.f <<= lower.e - upper.e; + return boundaries{lower.f, upper.f}; + } +}; + +inline bool operator==(fp x, fp y) { return x.f == y.f && x.e == y.e; } + +// Computes lhs * rhs / pow(2, 64) rounded to nearest with half-up tie breaking. +inline uint64_t multiply(uint64_t lhs, uint64_t rhs) { +#if FMT_USE_INT128 + auto product = static_cast<__uint128_t>(lhs) * rhs; + auto f = static_cast(product >> 64); + return (static_cast(product) & (1ULL << 63)) != 0 ? f + 1 : f; +#else + // Multiply 32-bit parts of significands. + uint64_t mask = (1ULL << 32) - 1; + uint64_t a = lhs >> 32, b = lhs & mask; + uint64_t c = rhs >> 32, d = rhs & mask; + uint64_t ac = a * c, bc = b * c, ad = a * d, bd = b * d; + // Compute mid 64-bit of result and round. + uint64_t mid = (bd >> 32) + (ad & mask) + (bc & mask) + (1U << 31); + return ac + (ad >> 32) + (bc >> 32) + (mid >> 32); +#endif +} + +inline fp operator*(fp x, fp y) { return {multiply(x.f, y.f), x.e + y.e + 64}; } + +// Returns a cached power of 10 `c_k = c_k.f * pow(2, c_k.e)` such that its +// (binary) exponent satisfies `min_exponent <= c_k.e <= min_exponent + 28`. +FMT_FUNC fp get_cached_power(int min_exponent, int& pow10_exponent) { + const uint64_t one_over_log2_10 = 0x4d104d42; // round(pow(2, 32) / log2(10)) + int index = static_cast( + static_cast( + (min_exponent + fp::significand_size - 1) * one_over_log2_10 + + ((uint64_t(1) << 32) - 1) // ceil + ) >> + 32 // arithmetic shift + ); + // Decimal exponent of the first (smallest) cached power of 10. + const int first_dec_exp = -348; + // Difference between 2 consecutive decimal exponents in cached powers of 10. + const int dec_exp_step = 8; + index = (index - first_dec_exp - 1) / dec_exp_step + 1; + pow10_exponent = first_dec_exp + index * dec_exp_step; + return {data::pow10_significands[index], data::pow10_exponents[index]}; +} + +// A simple accumulator to hold the sums of terms in bigint::square if uint128_t +// is not available. +struct accumulator { + uint64_t lower; + uint64_t upper; + + accumulator() : lower(0), upper(0) {} + explicit operator uint32_t() const { return static_cast(lower); } + + void operator+=(uint64_t n) { + lower += n; + if (lower < n) ++upper; + } + void operator>>=(int shift) { + assert(shift == 32); + (void)shift; + lower = (upper << 32) | (lower >> 32); + upper >>= 32; + } +}; + +class bigint { + private: + // A bigint is stored as an array of bigits (big digits), with bigit at index + // 0 being the least significant one. + using bigit = uint32_t; + using double_bigit = uint64_t; + enum { bigits_capacity = 32 }; + basic_memory_buffer bigits_; + int exp_; + + static FMT_CONSTEXPR_DECL const int bigit_bits = bits::value; + + friend struct formatter; + + void subtract_bigits(int index, bigit other, bigit& borrow) { + auto result = static_cast(bigits_[index]) - other - borrow; + bigits_[index] = static_cast(result); + borrow = static_cast(result >> (bigit_bits * 2 - 1)); + } + + void remove_leading_zeros() { + int num_bigits = static_cast(bigits_.size()) - 1; + while (num_bigits > 0 && bigits_[num_bigits] == 0) --num_bigits; + bigits_.resize(num_bigits + 1); + } + + // Computes *this -= other assuming aligned bigints and *this >= other. + void subtract_aligned(const bigint& other) { + FMT_ASSERT(other.exp_ >= exp_, "unaligned bigints"); + FMT_ASSERT(compare(*this, other) >= 0, ""); + bigit borrow = 0; + int i = other.exp_ - exp_; + for (int j = 0, n = static_cast(other.bigits_.size()); j != n; + ++i, ++j) { + subtract_bigits(i, other.bigits_[j], borrow); + } + while (borrow > 0) subtract_bigits(i, 0, borrow); + remove_leading_zeros(); + } + + void multiply(uint32_t value) { + const double_bigit wide_value = value; + bigit carry = 0; + for (size_t i = 0, n = bigits_.size(); i < n; ++i) { + double_bigit result = bigits_[i] * wide_value + carry; + bigits_[i] = static_cast(result); + carry = static_cast(result >> bigit_bits); + } + if (carry != 0) bigits_.push_back(carry); + } + + void multiply(uint64_t value) { + const bigit mask = ~bigit(0); + const double_bigit lower = value & mask; + const double_bigit upper = value >> bigit_bits; + double_bigit carry = 0; + for (size_t i = 0, n = bigits_.size(); i < n; ++i) { + double_bigit result = bigits_[i] * lower + (carry & mask); + carry = + bigits_[i] * upper + (result >> bigit_bits) + (carry >> bigit_bits); + bigits_[i] = static_cast(result); + } + while (carry != 0) { + bigits_.push_back(carry & mask); + carry >>= bigit_bits; + } + } + + public: + bigint() : exp_(0) {} + explicit bigint(uint64_t n) { assign(n); } + ~bigint() { assert(bigits_.capacity() <= bigits_capacity); } + + bigint(const bigint&) = delete; + void operator=(const bigint&) = delete; + + void assign(const bigint& other) { + bigits_.resize(other.bigits_.size()); + auto data = other.bigits_.data(); + std::copy(data, data + other.bigits_.size(), bigits_.data()); + exp_ = other.exp_; + } + + void assign(uint64_t n) { + int num_bigits = 0; + do { + bigits_[num_bigits++] = n & ~bigit(0); + n >>= bigit_bits; + } while (n != 0); + bigits_.resize(num_bigits); + exp_ = 0; + } + + int num_bigits() const { return static_cast(bigits_.size()) + exp_; } + + bigint& operator<<=(int shift) { + assert(shift >= 0); + exp_ += shift / bigit_bits; + shift %= bigit_bits; + if (shift == 0) return *this; + bigit carry = 0; + for (size_t i = 0, n = bigits_.size(); i < n; ++i) { + bigit c = bigits_[i] >> (bigit_bits - shift); + bigits_[i] = (bigits_[i] << shift) + carry; + carry = c; + } + if (carry != 0) bigits_.push_back(carry); + return *this; + } + + template bigint& operator*=(Int value) { + FMT_ASSERT(value > 0, ""); + multiply(uint32_or_64_or_128_t(value)); + return *this; + } + + friend int compare(const bigint& lhs, const bigint& rhs) { + int num_lhs_bigits = lhs.num_bigits(), num_rhs_bigits = rhs.num_bigits(); + if (num_lhs_bigits != num_rhs_bigits) + return num_lhs_bigits > num_rhs_bigits ? 1 : -1; + int i = static_cast(lhs.bigits_.size()) - 1; + int j = static_cast(rhs.bigits_.size()) - 1; + int end = i - j; + if (end < 0) end = 0; + for (; i >= end; --i, --j) { + bigit lhs_bigit = lhs.bigits_[i], rhs_bigit = rhs.bigits_[j]; + if (lhs_bigit != rhs_bigit) return lhs_bigit > rhs_bigit ? 1 : -1; + } + if (i != j) return i > j ? 1 : -1; + return 0; + } + + // Returns compare(lhs1 + lhs2, rhs). + friend int add_compare(const bigint& lhs1, const bigint& lhs2, + const bigint& rhs) { + int max_lhs_bigits = (std::max)(lhs1.num_bigits(), lhs2.num_bigits()); + int num_rhs_bigits = rhs.num_bigits(); + if (max_lhs_bigits + 1 < num_rhs_bigits) return -1; + if (max_lhs_bigits > num_rhs_bigits) return 1; + auto get_bigit = [](const bigint& n, int i) -> bigit { + return i >= n.exp_ && i < n.num_bigits() ? n.bigits_[i - n.exp_] : 0; + }; + double_bigit borrow = 0; + int min_exp = (std::min)((std::min)(lhs1.exp_, lhs2.exp_), rhs.exp_); + for (int i = num_rhs_bigits - 1; i >= min_exp; --i) { + double_bigit sum = + static_cast(get_bigit(lhs1, i)) + get_bigit(lhs2, i); + bigit rhs_bigit = get_bigit(rhs, i); + if (sum > rhs_bigit + borrow) return 1; + borrow = rhs_bigit + borrow - sum; + if (borrow > 1) return -1; + borrow <<= bigit_bits; + } + return borrow != 0 ? -1 : 0; + } + + // Assigns pow(10, exp) to this bigint. + void assign_pow10(int exp) { + assert(exp >= 0); + if (exp == 0) return assign(1); + // Find the top bit. + int bitmask = 1; + while (exp >= bitmask) bitmask <<= 1; + bitmask >>= 1; + // pow(10, exp) = pow(5, exp) * pow(2, exp). First compute pow(5, exp) by + // repeated squaring and multiplication. + assign(5); + bitmask >>= 1; + while (bitmask != 0) { + square(); + if ((exp & bitmask) != 0) *this *= 5; + bitmask >>= 1; + } + *this <<= exp; // Multiply by pow(2, exp) by shifting. + } + + void square() { + basic_memory_buffer n(std::move(bigits_)); + int num_bigits = static_cast(bigits_.size()); + int num_result_bigits = 2 * num_bigits; + bigits_.resize(num_result_bigits); + using accumulator_t = conditional_t; + auto sum = accumulator_t(); + for (int bigit_index = 0; bigit_index < num_bigits; ++bigit_index) { + // Compute bigit at position bigit_index of the result by adding + // cross-product terms n[i] * n[j] such that i + j == bigit_index. + for (int i = 0, j = bigit_index; j >= 0; ++i, --j) { + // Most terms are multiplied twice which can be optimized in the future. + sum += static_cast(n[i]) * n[j]; + } + bigits_[bigit_index] = static_cast(sum); + sum >>= bits::value; // Compute the carry. + } + // Do the same for the top half. + for (int bigit_index = num_bigits; bigit_index < num_result_bigits; + ++bigit_index) { + for (int j = num_bigits - 1, i = bigit_index - j; i < num_bigits;) + sum += static_cast(n[i++]) * n[j--]; + bigits_[bigit_index] = static_cast(sum); + sum >>= bits::value; + } + --num_result_bigits; + remove_leading_zeros(); + exp_ *= 2; + } + + // Divides this bignum by divisor, assigning the remainder to this and + // returning the quotient. + int divmod_assign(const bigint& divisor) { + FMT_ASSERT(this != &divisor, ""); + if (compare(*this, divisor) < 0) return 0; + int num_bigits = static_cast(bigits_.size()); + FMT_ASSERT(divisor.bigits_[divisor.bigits_.size() - 1] != 0, ""); + int exp_difference = exp_ - divisor.exp_; + if (exp_difference > 0) { + // Align bigints by adding trailing zeros to simplify subtraction. + bigits_.resize(num_bigits + exp_difference); + for (int i = num_bigits - 1, j = i + exp_difference; i >= 0; --i, --j) + bigits_[j] = bigits_[i]; + std::uninitialized_fill_n(bigits_.data(), exp_difference, 0); + exp_ -= exp_difference; + } + int quotient = 0; + do { + subtract_aligned(divisor); + ++quotient; + } while (compare(*this, divisor) >= 0); + return quotient; + } +}; + +enum round_direction { unknown, up, down }; + +// Given the divisor (normally a power of 10), the remainder = v % divisor for +// some number v and the error, returns whether v should be rounded up, down, or +// whether the rounding direction can't be determined due to error. +// error should be less than divisor / 2. +inline round_direction get_round_direction(uint64_t divisor, uint64_t remainder, + uint64_t error) { + FMT_ASSERT(remainder < divisor, ""); // divisor - remainder won't overflow. + FMT_ASSERT(error < divisor, ""); // divisor - error won't overflow. + FMT_ASSERT(error < divisor - error, ""); // error * 2 won't overflow. + // Round down if (remainder + error) * 2 <= divisor. + if (remainder <= divisor - remainder && error * 2 <= divisor - remainder * 2) + return down; + // Round up if (remainder - error) * 2 >= divisor. + if (remainder >= error && + remainder - error >= divisor - (remainder - error)) { + return up; + } + return unknown; +} + +namespace digits { +enum result { + more, // Generate more digits. + done, // Done generating digits. + error // Digit generation cancelled due to an error. +}; +} + +// Generates output using the Grisu digit-gen algorithm. +// error: the size of the region (lower, upper) outside of which numbers +// definitely do not round to value (Delta in Grisu3). +template +FMT_ALWAYS_INLINE digits::result grisu_gen_digits(fp value, uint64_t error, + int& exp, Handler& handler) { + const fp one(1ULL << -value.e, value.e); + // The integral part of scaled value (p1 in Grisu) = value / one. It cannot be + // zero because it contains a product of two 64-bit numbers with MSB set (due + // to normalization) - 1, shifted right by at most 60 bits. + auto integral = static_cast(value.f >> -one.e); + FMT_ASSERT(integral != 0, ""); + FMT_ASSERT(integral == value.f >> -one.e, ""); + // The fractional part of scaled value (p2 in Grisu) c = value % one. + uint64_t fractional = value.f & (one.f - 1); + exp = count_digits(integral); // kappa in Grisu. + // Divide by 10 to prevent overflow. + auto result = handler.on_start(data::powers_of_10_64[exp - 1] << -one.e, + value.f / 10, error * 10, exp); + if (result != digits::more) return result; + // Generate digits for the integral part. This can produce up to 10 digits. + do { + uint32_t digit = 0; + auto divmod_integral = [&](uint32_t divisor) { + digit = integral / divisor; + integral %= divisor; + }; + // This optimization by Milo Yip reduces the number of integer divisions by + // one per iteration. + switch (exp) { + case 10: + divmod_integral(1000000000); + break; + case 9: + divmod_integral(100000000); + break; + case 8: + divmod_integral(10000000); + break; + case 7: + divmod_integral(1000000); + break; + case 6: + divmod_integral(100000); + break; + case 5: + divmod_integral(10000); + break; + case 4: + divmod_integral(1000); + break; + case 3: + divmod_integral(100); + break; + case 2: + divmod_integral(10); + break; + case 1: + digit = integral; + integral = 0; + break; + default: + FMT_ASSERT(false, "invalid number of digits"); + } + --exp; + uint64_t remainder = + (static_cast(integral) << -one.e) + fractional; + result = handler.on_digit(static_cast('0' + digit), + data::powers_of_10_64[exp] << -one.e, remainder, + error, exp, true); + if (result != digits::more) return result; + } while (exp > 0); + // Generate digits for the fractional part. + for (;;) { + fractional *= 10; + error *= 10; + char digit = + static_cast('0' + static_cast(fractional >> -one.e)); + fractional &= one.f - 1; + --exp; + result = handler.on_digit(digit, one.f, fractional, error, exp, false); + if (result != digits::more) return result; + } +} + +// The fixed precision digit handler. +struct fixed_handler { + char* buf; + int size; + int precision; + int exp10; + bool fixed; + + digits::result on_start(uint64_t divisor, uint64_t remainder, uint64_t error, + int& exp) { + // Non-fixed formats require at least one digit and no precision adjustment. + if (!fixed) return digits::more; + // Adjust fixed precision by exponent because it is relative to decimal + // point. + precision += exp + exp10; + // Check if precision is satisfied just by leading zeros, e.g. + // format("{:.2f}", 0.001) gives "0.00" without generating any digits. + if (precision > 0) return digits::more; + if (precision < 0) return digits::done; + auto dir = get_round_direction(divisor, remainder, error); + if (dir == unknown) return digits::error; + buf[size++] = dir == up ? '1' : '0'; + return digits::done; + } + + digits::result on_digit(char digit, uint64_t divisor, uint64_t remainder, + uint64_t error, int, bool integral) { + FMT_ASSERT(remainder < divisor, ""); + buf[size++] = digit; + if (size < precision) return digits::more; + if (!integral) { + // Check if error * 2 < divisor with overflow prevention. + // The check is not needed for the integral part because error = 1 + // and divisor > (1 << 32) there. + if (error >= divisor || error >= divisor - error) return digits::error; + } else { + FMT_ASSERT(error == 1 && divisor > 2, ""); + } + auto dir = get_round_direction(divisor, remainder, error); + if (dir != up) return dir == down ? digits::done : digits::error; + ++buf[size - 1]; + for (int i = size - 1; i > 0 && buf[i] > '9'; --i) { + buf[i] = '0'; + ++buf[i - 1]; + } + if (buf[0] > '9') { + buf[0] = '1'; + buf[size++] = '0'; + } + return digits::done; + } +}; + +// The shortest representation digit handler. +struct grisu_shortest_handler { + char* buf; + int size; + // Distance between scaled value and upper bound (wp_W in Grisu3). + uint64_t diff; + + digits::result on_start(uint64_t, uint64_t, uint64_t, int&) { + return digits::more; + } + + // Decrement the generated number approaching value from above. + void round(uint64_t d, uint64_t divisor, uint64_t& remainder, + uint64_t error) { + while ( + remainder < d && error - remainder >= divisor && + (remainder + divisor < d || d - remainder >= remainder + divisor - d)) { + --buf[size - 1]; + remainder += divisor; + } + } + + // Implements Grisu's round_weed. + digits::result on_digit(char digit, uint64_t divisor, uint64_t remainder, + uint64_t error, int exp, bool integral) { + buf[size++] = digit; + if (remainder >= error) return digits::more; + uint64_t unit = integral ? 1 : data::powers_of_10_64[-exp]; + uint64_t up = (diff - 1) * unit; // wp_Wup + round(up, divisor, remainder, error); + uint64_t down = (diff + 1) * unit; // wp_Wdown + if (remainder < down && error - remainder >= divisor && + (remainder + divisor < down || + down - remainder > remainder + divisor - down)) { + return digits::error; + } + return 2 * unit <= remainder && remainder <= error - 4 * unit + ? digits::done + : digits::error; + } +}; + +// Formats value using a variation of the Fixed-Precision Positive +// Floating-Point Printout ((FPP)^2) algorithm by Steele & White: +// https://fmt.dev/p372-steele.pdf. +template +void fallback_format(Double d, buffer& buf, int& exp10) { + bigint numerator; // 2 * R in (FPP)^2. + bigint denominator; // 2 * S in (FPP)^2. + // lower and upper are differences between value and corresponding boundaries. + bigint lower; // (M^- in (FPP)^2). + bigint upper_store; // upper's value if different from lower. + bigint* upper = nullptr; // (M^+ in (FPP)^2). + fp value; + // Shift numerator and denominator by an extra bit or two (if lower boundary + // is closer) to make lower and upper integers. This eliminates multiplication + // by 2 during later computations. + // TODO: handle float + int shift = value.assign(d) ? 2 : 1; + uint64_t significand = value.f << shift; + if (value.e >= 0) { + numerator.assign(significand); + numerator <<= value.e; + lower.assign(1); + lower <<= value.e; + if (shift != 1) { + upper_store.assign(1); + upper_store <<= value.e + 1; + upper = &upper_store; + } + denominator.assign_pow10(exp10); + denominator <<= 1; + } else if (exp10 < 0) { + numerator.assign_pow10(-exp10); + lower.assign(numerator); + if (shift != 1) { + upper_store.assign(numerator); + upper_store <<= 1; + upper = &upper_store; + } + numerator *= significand; + denominator.assign(1); + denominator <<= shift - value.e; + } else { + numerator.assign(significand); + denominator.assign_pow10(exp10); + denominator <<= shift - value.e; + lower.assign(1); + if (shift != 1) { + upper_store.assign(1ULL << 1); + upper = &upper_store; + } + } + if (!upper) upper = &lower; + // Invariant: value == (numerator / denominator) * pow(10, exp10). + bool even = (value.f & 1) == 0; + int num_digits = 0; + char* data = buf.data(); + for (;;) { + int digit = numerator.divmod_assign(denominator); + bool low = compare(numerator, lower) - even < 0; // numerator <[=] lower. + // numerator + upper >[=] pow10: + bool high = add_compare(numerator, *upper, denominator) + even > 0; + data[num_digits++] = static_cast('0' + digit); + if (low || high) { + if (!low) { + ++data[num_digits - 1]; + } else if (high) { + int result = add_compare(numerator, numerator, denominator); + // Round half to even. + if (result > 0 || (result == 0 && (digit % 2) != 0)) + ++data[num_digits - 1]; + } + buf.resize(num_digits); + exp10 -= num_digits - 1; + return; + } + numerator *= 10; + lower *= 10; + if (upper != &lower) *upper *= 10; + } +} + +// Formats value using the Grisu algorithm +// (https://www.cs.tufts.edu/~nr/cs257/archive/florian-loitsch/printf.pdf) +// if T is a IEEE754 binary32 or binary64 and snprintf otherwise. +template +int format_float(T value, int precision, float_specs specs, buffer& buf) { + static_assert(!std::is_same(), ""); + FMT_ASSERT(value >= 0, "value is negative"); + + const bool fixed = specs.format == float_format::fixed; + if (value <= 0) { // <= instead of == to silence a warning. + if (precision <= 0 || !fixed) { + buf.push_back('0'); + return 0; + } + buf.resize(to_unsigned(precision)); + std::uninitialized_fill_n(buf.data(), precision, '0'); + return -precision; + } + + if (!specs.use_grisu) return snprintf_float(value, precision, specs, buf); + + int exp = 0; + const int min_exp = -60; // alpha in Grisu. + int cached_exp10 = 0; // K in Grisu. + if (precision != -1) { + if (precision > 17) return snprintf_float(value, precision, specs, buf); + fp normalized = normalize(fp(value)); + const auto cached_pow = get_cached_power( + min_exp - (normalized.e + fp::significand_size), cached_exp10); + normalized = normalized * cached_pow; + fixed_handler handler{buf.data(), 0, precision, -cached_exp10, fixed}; + if (grisu_gen_digits(normalized, 1, exp, handler) == digits::error) + return snprintf_float(value, precision, specs, buf); + int num_digits = handler.size; + if (!fixed) { + // Remove trailing zeros. + while (num_digits > 0 && buf[num_digits - 1] == '0') { + --num_digits; + ++exp; + } + } + buf.resize(to_unsigned(num_digits)); + } else { + fp fp_value; + auto boundaries = specs.binary32 + ? fp_value.assign_float_with_boundaries(value) + : fp_value.assign_with_boundaries(value); + fp_value = normalize(fp_value); + // Find a cached power of 10 such that multiplying value by it will bring + // the exponent in the range [min_exp, -32]. + const fp cached_pow = get_cached_power( + min_exp - (fp_value.e + fp::significand_size), cached_exp10); + // Multiply value and boundaries by the cached power of 10. + fp_value = fp_value * cached_pow; + boundaries.lower = multiply(boundaries.lower, cached_pow.f); + boundaries.upper = multiply(boundaries.upper, cached_pow.f); + assert(min_exp <= fp_value.e && fp_value.e <= -32); + --boundaries.lower; // \tilde{M}^- - 1 ulp -> M^-_{\downarrow}. + ++boundaries.upper; // \tilde{M}^+ + 1 ulp -> M^+_{\uparrow}. + // Numbers outside of (lower, upper) definitely do not round to value. + grisu_shortest_handler handler{buf.data(), 0, + boundaries.upper - fp_value.f}; + auto result = + grisu_gen_digits(fp(boundaries.upper, fp_value.e), + boundaries.upper - boundaries.lower, exp, handler); + if (result == digits::error) { + exp += handler.size - cached_exp10 - 1; + fallback_format(value, buf, exp); + return exp; + } + buf.resize(to_unsigned(handler.size)); + } + return exp - cached_exp10; +} + +template +int snprintf_float(T value, int precision, float_specs specs, + buffer& buf) { + // Buffer capacity must be non-zero, otherwise MSVC's vsnprintf_s will fail. + FMT_ASSERT(buf.capacity() > buf.size(), "empty buffer"); + static_assert(!std::is_same(), ""); + + // Subtract 1 to account for the difference in precision since we use %e for + // both general and exponent format. + if (specs.format == float_format::general || + specs.format == float_format::exp) + precision = (precision >= 0 ? precision : 6) - 1; + + // Build the format string. + enum { max_format_size = 7 }; // Ths longest format is "%#.*Le". + char format[max_format_size]; + char* format_ptr = format; + *format_ptr++ = '%'; + if (specs.trailing_zeros) *format_ptr++ = '#'; + if (precision >= 0) { + *format_ptr++ = '.'; + *format_ptr++ = '*'; + } + if (std::is_same()) *format_ptr++ = 'L'; + *format_ptr++ = specs.format != float_format::hex + ? (specs.format == float_format::fixed ? 'f' : 'e') + : (specs.upper ? 'A' : 'a'); + *format_ptr = '\0'; + + // Format using snprintf. + auto offset = buf.size(); + for (;;) { + auto begin = buf.data() + offset; + auto capacity = buf.capacity() - offset; +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + if (precision > 100000) + throw std::runtime_error( + "fuzz mode - avoid large allocation inside snprintf"); +#endif + // Suppress the warning about a nonliteral format string. + auto snprintf_ptr = FMT_SNPRINTF; + int result = precision >= 0 + ? snprintf_ptr(begin, capacity, format, precision, value) + : snprintf_ptr(begin, capacity, format, value); + if (result < 0) { + buf.reserve(buf.capacity() + 1); // The buffer will grow exponentially. + continue; + } + unsigned size = to_unsigned(result); + // Size equal to capacity means that the last character was truncated. + if (size >= capacity) { + buf.reserve(size + offset + 1); // Add 1 for the terminating '\0'. + continue; + } + auto is_digit = [](char c) { return c >= '0' && c <= '9'; }; + if (specs.format == float_format::fixed) { + if (precision == 0) { + buf.resize(size); + return 0; + } + // Find and remove the decimal point. + auto end = begin + size, p = end; + do { + --p; + } while (is_digit(*p)); + int fraction_size = static_cast(end - p - 1); + std::memmove(p, p + 1, fraction_size); + buf.resize(size - 1); + return -fraction_size; + } + if (specs.format == float_format::hex) { + buf.resize(size + offset); + return 0; + } + // Find and parse the exponent. + auto end = begin + size, exp_pos = end; + do { + --exp_pos; + } while (*exp_pos != 'e'); + char sign = exp_pos[1]; + assert(sign == '+' || sign == '-'); + int exp = 0; + auto p = exp_pos + 2; // Skip 'e' and sign. + do { + assert(is_digit(*p)); + exp = exp * 10 + (*p++ - '0'); + } while (p != end); + if (sign == '-') exp = -exp; + int fraction_size = 0; + if (exp_pos != begin + 1) { + // Remove trailing zeros. + auto fraction_end = exp_pos - 1; + while (*fraction_end == '0') --fraction_end; + // Move the fractional part left to get rid of the decimal point. + fraction_size = static_cast(fraction_end - begin - 1); + std::memmove(begin + 1, begin + 2, fraction_size); + } + buf.resize(fraction_size + offset + 1); + return exp - fraction_size; + } +} +} // namespace internal + +template <> struct formatter { + format_parse_context::iterator parse(format_parse_context& ctx) { + return ctx.begin(); + } + + format_context::iterator format(const internal::bigint& n, + format_context& ctx) { + auto out = ctx.out(); + bool first = true; + for (auto i = n.bigits_.size(); i > 0; --i) { + auto value = n.bigits_[i - 1]; + if (first) { + out = format_to(out, "{:x}", value); + first = false; + continue; + } + out = format_to(out, "{:08x}", value); + } + if (n.exp_ > 0) + out = format_to(out, "p{}", n.exp_ * internal::bigint::bigit_bits); + return out; + } +}; + +FMT_FUNC void internal::error_handler::on_error(std::string message) { + FMT_THROW(duckdb::Exception(message)); +} + +FMT_END_NAMESPACE + +#ifdef _MSC_VER +# pragma warning(pop) +#endif + +#endif // FMT_FORMAT_INL_H_ diff --git a/src/duckdb/third_party/fmt/include/fmt/format.h b/src/duckdb/third_party/fmt/include/fmt/format.h new file mode 100644 index 00000000..dbc45255 --- /dev/null +++ b/src/duckdb/third_party/fmt/include/fmt/format.h @@ -0,0 +1,3370 @@ +/* + Formatting library for C++ + + Copyright (c) 2012 - present, Victor Zverovich + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + --- Optional exception to the license --- + + As an exception, if, as a result of your compiling your source code, portions + of this Software are embedded into a machine-executable object form of such + source code, you may redistribute such embedded portions in such object form + without including the above copyright and permission notices. + */ + +#ifndef FMT_FORMAT_H_ +#define FMT_FORMAT_H_ + +#include "duckdb/common/exception.hpp" +#include "fmt/core.h" + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __clang__ +# define FMT_CLANG_VERSION (__clang_major__ * 100 + __clang_minor__) +#else +# define FMT_CLANG_VERSION 0 +#endif + +#ifdef __INTEL_COMPILER +# define FMT_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICL) +# define FMT_ICC_VERSION __ICL +#else +# define FMT_ICC_VERSION 0 +#endif + +#ifdef __NVCC__ +# define FMT_CUDA_VERSION (__CUDACC_VER_MAJOR__ * 100 + __CUDACC_VER_MINOR__) +#else +# define FMT_CUDA_VERSION 0 +#endif + +#ifdef __has_builtin +# define FMT_HAS_BUILTIN(x) __has_builtin(x) +#else +# define FMT_HAS_BUILTIN(x) 0 +#endif + +#if FMT_HAS_CPP_ATTRIBUTE(fallthrough) && \ + (__cplusplus >= 201703 || FMT_GCC_VERSION != 0) +# define FMT_FALLTHROUGH [[fallthrough]] +#else +# define FMT_FALLTHROUGH +#endif + +#ifndef FMT_THROW +# if FMT_EXCEPTIONS +# if FMT_MSC_VER +FMT_BEGIN_NAMESPACE +namespace internal { +template inline void do_throw(const Exception& x) { + // Silence unreachable code warnings in MSVC because these are nearly + // impossible to fix in a generic code. + volatile bool b = true; + if (b) throw x; +} +} // namespace internal +FMT_END_NAMESPACE +# define FMT_THROW(x) internal::do_throw(x) +# else +# define FMT_THROW(x) throw x +# endif +# else +# define FMT_THROW(x) \ + do { \ + static_cast(sizeof(x)); \ + FMT_ASSERT(false, ""); \ + } while (false) +# endif +#endif + +#ifndef FMT_USE_USER_DEFINED_LITERALS +// For Intel and NVIDIA compilers both they and the system gcc/msc support UDLs. +# if (FMT_HAS_FEATURE(cxx_user_literals) || FMT_GCC_VERSION >= 407 || \ + FMT_MSC_VER >= 1900) && \ + (!(FMT_ICC_VERSION || FMT_CUDA_VERSION) || FMT_ICC_VERSION >= 1500 || \ + FMT_CUDA_VERSION >= 700) +# define FMT_USE_USER_DEFINED_LITERALS 1 +# else +# define FMT_USE_USER_DEFINED_LITERALS 0 +# endif +#endif + +#ifndef FMT_USE_UDL_TEMPLATE +#define FMT_USE_UDL_TEMPLATE 0 +#endif + +// __builtin_clz is broken in clang with Microsoft CodeGen: +// https://github.com/fmtlib/fmt/issues/519 +#if (FMT_GCC_VERSION || FMT_HAS_BUILTIN(__builtin_clz)) && !FMT_MSC_VER +# define FMT_BUILTIN_CLZ(n) __builtin_clz(n) +#endif +#if (FMT_GCC_VERSION || FMT_HAS_BUILTIN(__builtin_clzll)) && !FMT_MSC_VER +# define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) +#endif + +// Some compilers masquerade as both MSVC and GCC-likes or otherwise support +// __builtin_clz and __builtin_clzll, so only define FMT_BUILTIN_CLZ using the +// MSVC intrinsics if the clz and clzll builtins are not available. +#if FMT_MSC_VER && !defined(FMT_BUILTIN_CLZLL) && !defined(_MANAGED) +# include // _BitScanReverse, _BitScanReverse64 + +FMT_BEGIN_NAMESPACE +namespace internal { +// Avoid Clang with Microsoft CodeGen's -Wunknown-pragmas warning. +# ifndef __clang__ +# pragma intrinsic(_BitScanReverse) +# endif +inline uint32_t clz(uint32_t x) { + unsigned long r = 0; + _BitScanReverse(&r, x); + + FMT_ASSERT(x != 0, ""); + // Static analysis complains about using uninitialized data + // "r", but the only way that can happen is if "x" is 0, + // which the callers guarantee to not happen. +# pragma warning(suppress : 6102) + return 31 - r; +} +# define FMT_BUILTIN_CLZ(n) internal::clz(n) + +# if defined(_WIN64) && !defined(__clang__) +# pragma intrinsic(_BitScanReverse64) +# endif + +inline uint32_t clzll(uint64_t x) { + unsigned long r = 0; +# ifdef _WIN64 + _BitScanReverse64(&r, x); +# else + // Scan the high 32 bits. + if (_BitScanReverse(&r, static_cast(x >> 32))) return 63 - (r + 32); + + // Scan the low 32 bits. + _BitScanReverse(&r, static_cast(x)); +# endif + + FMT_ASSERT(x != 0, ""); + // Static analysis complains about using uninitialized data + // "r", but the only way that can happen is if "x" is 0, + // which the callers guarantee to not happen. +# pragma warning(suppress : 6102) + return 63 - r; +} +# define FMT_BUILTIN_CLZLL(n) internal::clzll(n) +} // namespace internal +FMT_END_NAMESPACE +#endif + +// Enable the deprecated numeric alignment. +#ifndef FMT_NUMERIC_ALIGN +# define FMT_NUMERIC_ALIGN 1 +#endif + +// Enable the deprecated percent specifier. +#ifndef FMT_DEPRECATED_PERCENT +# define FMT_DEPRECATED_PERCENT 0 +#endif + +FMT_BEGIN_NAMESPACE +namespace internal { + +// A helper function to suppress bogus "conditional expression is constant" +// warnings. +template inline T const_check(T value) { return value; } + +// An equivalent of `*reinterpret_cast(&source)` that doesn't have +// undefined behavior (e.g. due to type aliasing). +// Example: uint64_t d = bit_cast(2.718); +template +inline Dest bit_cast(const Source& source) { + static_assert(sizeof(Dest) == sizeof(Source), "size mismatch"); + Dest dest; + std::memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +inline bool is_big_endian() { + auto u = 1u; + struct bytes { + char data[sizeof(u)]; + }; + return bit_cast(u).data[0] == 0; +} + +// A fallback implementation of uintptr_t for systems that lack it. +struct fallback_uintptr { + unsigned char value[sizeof(void*)]; + + fallback_uintptr() = default; + explicit fallback_uintptr(const void* p) { + *this = bit_cast(p); + if (is_big_endian()) { + for (size_t i = 0, j = sizeof(void*) - 1; i < j; ++i, --j) + std::swap(value[i], value[j]); + } + } +}; +#ifdef UINTPTR_MAX +using uintptr_t = ::uintptr_t; +inline uintptr_t to_uintptr(const void* p) { return bit_cast(p); } +#else +using uintptr_t = fallback_uintptr; +inline fallback_uintptr to_uintptr(const void* p) { + return fallback_uintptr(p); +} +#endif + +// Returns the largest possible value for type T. Same as +// std::numeric_limits::max() but shorter and not affected by the max macro. +template constexpr T max_value() { + return (std::numeric_limits::max)(); +} +template constexpr int num_bits() { + return std::numeric_limits::digits; +} +template <> constexpr int num_bits() { + return static_cast(sizeof(void*) * + std::numeric_limits::digits); +} + +// An approximation of iterator_t for pre-C++20 systems. +template +using iterator_t = decltype(std::begin(std::declval())); + +// Detect the iterator category of *any* given type in a SFINAE-friendly way. +// Unfortunately, older implementations of std::iterator_traits are not safe +// for use in a SFINAE-context. +template +struct iterator_category : std::false_type {}; + +template struct iterator_category { + using type = std::random_access_iterator_tag; +}; + +template +struct iterator_category> { + using type = typename It::iterator_category; +}; + +// Detect if *any* given type models the OutputIterator concept. +template class is_output_iterator { + // Check for mutability because all iterator categories derived from + // std::input_iterator_tag *may* also meet the requirements of an + // OutputIterator, thereby falling into the category of 'mutable iterators' + // [iterator.requirements.general] clause 4. The compiler reveals this + // property only at the point of *actually dereferencing* the iterator! + template + static decltype(*(std::declval())) test(std::input_iterator_tag); + template static char& test(std::output_iterator_tag); + template static const char& test(...); + + using type = decltype(test(typename iterator_category::type{})); + + public: + static const bool value = !std::is_const>::value; +}; + +// A workaround for std::string not having mutable data() until C++17. +template inline Char* get_data(std::basic_string& s) { + return &s[0]; +} +template +inline typename Container::value_type* get_data(Container& c) { + return c.data(); +} + +#ifdef _SECURE_SCL +// Make a checked iterator to avoid MSVC warnings. +template using checked_ptr = stdext::checked_array_iterator; +template checked_ptr make_checked(T* p, std::size_t size) { + return {p, size}; +} +#else +template using checked_ptr = T*; +template inline T* make_checked(T* p, std::size_t) { return p; } +#endif + +template ::value)> +inline checked_ptr reserve( + std::back_insert_iterator& it, std::size_t n) { + Container& c = get_container(it); + std::size_t size = c.size(); + c.resize(size + n); + return make_checked(get_data(c) + size, n); +} + +template +inline Iterator& reserve(Iterator& it, std::size_t) { + return it; +} + +// An output iterator that counts the number of objects written to it and +// discards them. +class counting_iterator { + private: + std::size_t count_; + + public: + using iterator_category = std::output_iterator_tag; + using difference_type = std::ptrdiff_t; + using pointer = void; + using reference = void; + using _Unchecked_type = counting_iterator; // Mark iterator as checked. + + struct value_type { + template void operator=(const T&) {} + }; + + counting_iterator() : count_(0) {} + + std::size_t count() const { return count_; } + + counting_iterator& operator++() { + ++count_; + return *this; + } + + counting_iterator operator++(int) { + auto it = *this; + ++*this; + return it; + } + + value_type operator*() const { return {}; } +}; + +template class truncating_iterator_base { + protected: + OutputIt out_; + std::size_t limit_; + std::size_t count_; + + truncating_iterator_base(OutputIt out, std::size_t limit) + : out_(out), limit_(limit), count_(0) {} + + public: + using iterator_category = std::output_iterator_tag; + using difference_type = void; + using pointer = void; + using reference = void; + using _Unchecked_type = + truncating_iterator_base; // Mark iterator as checked. + + OutputIt base() const { return out_; } + std::size_t count() const { return count_; } +}; + +// An output iterator that truncates the output and counts the number of objects +// written to it. +template ::value_type>::type> +class truncating_iterator; + +template +class truncating_iterator + : public truncating_iterator_base { + using traits = std::iterator_traits; + + mutable typename traits::value_type blackhole_; + + public: + using value_type = typename traits::value_type; + + truncating_iterator(OutputIt out, std::size_t limit) + : truncating_iterator_base(out, limit) {} + + truncating_iterator& operator++() { + if (this->count_++ < this->limit_) ++this->out_; + return *this; + } + + truncating_iterator operator++(int) { + auto it = *this; + ++*this; + return it; + } + + value_type& operator*() const { + return this->count_ < this->limit_ ? *this->out_ : blackhole_; + } +}; + +template +class truncating_iterator + : public truncating_iterator_base { + public: + using value_type = typename OutputIt::container_type::value_type; + + truncating_iterator(OutputIt out, std::size_t limit) + : truncating_iterator_base(out, limit) {} + + truncating_iterator& operator=(value_type val) { + if (this->count_++ < this->limit_) this->out_ = val; + return *this; + } + + truncating_iterator& operator++() { return *this; } + truncating_iterator& operator++(int) { return *this; } + truncating_iterator& operator*() { return *this; } +}; + +// A range with the specified output iterator and value type. +template +class output_range { + private: + OutputIt it_; + + public: + using value_type = T; + using iterator = OutputIt; + struct sentinel {}; + + explicit output_range(OutputIt it) : it_(it) {} + OutputIt begin() const { return it_; } + sentinel end() const { return {}; } // Sentinel is not used yet. +}; + +template +inline size_t count_code_points(basic_string_view s) { + return s.size(); +} + +// Counts the number of code points in a UTF-8 string. +inline size_t count_code_points(basic_string_view s) { + const fmt_char8_t* data = s.data(); + size_t num_code_points = 0; + for (size_t i = 0, size = s.size(); i != size; ++i) { + if ((data[i] & 0xc0) != 0x80) ++num_code_points; + } + return num_code_points; +} + +template +inline size_t code_point_index(basic_string_view s, size_t n) { + size_t size = s.size(); + return n < size ? n : size; +} + +// Calculates the index of the nth code point in a UTF-8 string. +inline size_t code_point_index(basic_string_view s, size_t n) { + const fmt_char8_t* data = s.data(); + size_t num_code_points = 0; + for (size_t i = 0, size = s.size(); i != size; ++i) { + if ((data[i] & 0xc0) != 0x80 && ++num_code_points > n) { + return i; + } + } + return s.size(); +} + +inline fmt_char8_t to_fmt_char8_t(char c) { return static_cast(c); } + +template +using needs_conversion = bool_constant< + std::is_same::value_type, + char>::value && + std::is_same::value>; + +template ::value)> +OutputIt copy_str(InputIt begin, InputIt end, OutputIt it) { + return std::copy(begin, end, it); +} + +template ::value)> +OutputIt copy_str(InputIt begin, InputIt end, OutputIt it) { + return std::transform(begin, end, it, to_fmt_char8_t); +} + +#ifndef FMT_USE_GRISU +# define FMT_USE_GRISU 1 +#endif + +template constexpr bool use_grisu() { + return FMT_USE_GRISU && std::numeric_limits::is_iec559 && + sizeof(T) <= sizeof(double); +} + +template +template +void buffer::append(const U* begin, const U* end) { + std::size_t new_size = size_ + to_unsigned(end - begin); + reserve(new_size); + std::uninitialized_copy(begin, end, make_checked(ptr_, capacity_) + size_); + size_ = new_size; +} +} // namespace internal + +// A range with an iterator appending to a buffer. +template +class buffer_range : public internal::output_range< + std::back_insert_iterator>, T> { + public: + using iterator = std::back_insert_iterator>; + using internal::output_range::output_range; + buffer_range(internal::buffer& buf) + : internal::output_range(std::back_inserter(buf)) {} +}; + +// A UTF-8 string view. +class u8string_view : public basic_string_view { + public: + u8string_view(const char* s) + : basic_string_view(reinterpret_cast(s)) {} + u8string_view(const char* s, size_t count) FMT_NOEXCEPT + : basic_string_view(reinterpret_cast(s), count) { + } +}; + +#if FMT_USE_USER_DEFINED_LITERALS +inline namespace literals { +inline u8string_view operator"" _u(const char* s, std::size_t n) { + return {s, n}; +} +} // namespace literals +#endif + +// The number of characters to store in the basic_memory_buffer object itself +// to avoid dynamic memory allocation. +enum { inline_buffer_size = 500 }; + +/** + \rst + A dynamically growing memory buffer for trivially copyable/constructible types + with the first ``SIZE`` elements stored in the object itself. + + You can use one of the following type aliases for common character types: + + +----------------+------------------------------+ + | Type | Definition | + +================+==============================+ + | memory_buffer | basic_memory_buffer | + +----------------+------------------------------+ + | wmemory_buffer | basic_memory_buffer | + +----------------+------------------------------+ + + **Example**:: + + fmt::memory_buffer out; + format_to(out, "The answer is {}.", 42); + + This will append the following output to the ``out`` object: + + .. code-block:: none + + The answer is 42. + + The output can be converted to an ``std::string`` with ``to_string(out)``. + \endrst + */ +template > +class basic_memory_buffer : private Allocator, public internal::buffer { + private: + T store_[SIZE]; + + // Deallocate memory allocated by the buffer. + void deallocate() { + T* data = this->data(); + if (data != store_) Allocator::deallocate(data, this->capacity()); + } + + protected: + void grow(std::size_t size) FMT_OVERRIDE; + + public: + using value_type = T; + using const_reference = const T&; + + explicit basic_memory_buffer(const Allocator& alloc = Allocator()) + : Allocator(alloc) { + this->set(store_, SIZE); + } + ~basic_memory_buffer() FMT_OVERRIDE { deallocate(); } + + private: + // Move data from other to this buffer. + void move(basic_memory_buffer& other) { + Allocator &this_alloc = *this, &other_alloc = other; + this_alloc = std::move(other_alloc); + T* data = other.data(); + std::size_t size = other.size(), capacity = other.capacity(); + if (data == other.store_) { + this->set(store_, capacity); + std::uninitialized_copy(other.store_, other.store_ + size, + internal::make_checked(store_, capacity)); + } else { + this->set(data, capacity); + // Set pointer to the inline array so that delete is not called + // when deallocating. + other.set(other.store_, 0); + } + this->resize(size); + } + + public: + /** + \rst + Constructs a :class:`fmt::basic_memory_buffer` object moving the content + of the other object to it. + \endrst + */ + basic_memory_buffer(basic_memory_buffer&& other) FMT_NOEXCEPT { move(other); } + + /** + \rst + Moves the content of the other ``basic_memory_buffer`` object to this one. + \endrst + */ + basic_memory_buffer& operator=(basic_memory_buffer&& other) FMT_NOEXCEPT { + FMT_ASSERT(this != &other, ""); + deallocate(); + move(other); + return *this; + } + + // Returns a copy of the allocator associated with this buffer. + Allocator get_allocator() const { return *this; } +}; + +template +void basic_memory_buffer::grow(std::size_t size) { +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + if (size > 1000) throw std::runtime_error("fuzz mode - won't grow that much"); +#endif + std::size_t old_capacity = this->capacity(); + std::size_t new_capacity = old_capacity + old_capacity / 2; + if (size > new_capacity) new_capacity = size; + T* old_data = this->data(); + T* new_data = std::allocator_traits::allocate(*this, new_capacity); + // The following code doesn't throw, so the raw pointer above doesn't leak. + std::uninitialized_copy(old_data, old_data + this->size(), + internal::make_checked(new_data, new_capacity)); + this->set(new_data, new_capacity); + // deallocate must not throw according to the standard, but even if it does, + // the buffer already uses the new storage and will deallocate it in + // destructor. + if (old_data != store_) Allocator::deallocate(old_data, old_capacity); +} + +using memory_buffer = basic_memory_buffer; +using wmemory_buffer = basic_memory_buffer; + +namespace internal { + +// Returns true if value is negative, false otherwise. +// Same as `value < 0` but doesn't produce warnings if T is an unsigned type. +template ::is_signed)> +FMT_CONSTEXPR bool is_negative(T value) { + return value < 0; +} +template ::is_signed)> +FMT_CONSTEXPR bool is_negative(T) { + return false; +} + +// Smallest of uint32_t, uint64_t, uint128_t that is large enough to +// represent all values of T. +template +using uint32_or_64_or_128_t = conditional_t< + std::numeric_limits::digits <= 32, uint32_t, + conditional_t::digits <= 64, uint64_t, uint128_t>>; + +// Static data is placed in this class template for the header-only config. +template struct FMT_EXTERN_TEMPLATE_API basic_data { + static const uint64_t powers_of_10_64[]; + static const uint32_t zero_or_powers_of_10_32[]; + static const uint64_t zero_or_powers_of_10_64[]; + static const uint64_t pow10_significands[]; + static const int16_t pow10_exponents[]; + static const char digits[]; + static const char hex_digits[]; + static const char foreground_color[]; + static const char background_color[]; + static const char reset_color[5]; + static const wchar_t wreset_color[5]; + static const char signs[]; +}; + +FMT_EXTERN template struct basic_data; + +// This is a struct rather than an alias to avoid shadowing warnings in gcc. +struct data : basic_data<> {}; + +#ifdef FMT_BUILTIN_CLZLL +// Returns the number of decimal digits in n. Leading zeros are not counted +// except for n == 0 in which case count_digits returns 1. +inline int count_digits(uint64_t n) { + // Based on http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10 + // and the benchmark https://github.com/localvoid/cxx-benchmark-count-digits. + int t = (64 - FMT_BUILTIN_CLZLL(n | 1)) * 1233 >> 12; + return t - (n < data::zero_or_powers_of_10_64[t]) + 1; +} +#else +// Fallback version of count_digits used when __builtin_clz is not available. +inline int count_digits(uint64_t n) { + int count = 1; + for (;;) { + // Integer division is slow so do it for a group of four digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + if (n < 10) return count; + if (n < 100) return count + 1; + if (n < 1000) return count + 2; + if (n < 10000) return count + 3; + n /= 10000u; + count += 4; + } +} +#endif + +#if FMT_USE_INT128 +inline int count_digits(uint128_t n) { + int count = 1; + for (;;) { + // Integer division is slow so do it for a group of four digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + if (n < 10) return count; + if (n < 100) return count + 1; + if (n < 1000) return count + 2; + if (n < 10000) return count + 3; + n /= 10000U; + count += 4; + } +} +#endif + +// Counts the number of digits in n. BITS = log2(radix). +template inline int count_digits(UInt n) { + int num_digits = 0; + do { + ++num_digits; + } while ((n >>= BITS) != 0); + return num_digits; +} + +template <> int count_digits<4>(internal::fallback_uintptr n); + +#if FMT_GCC_VERSION || FMT_CLANG_VERSION +# define FMT_ALWAYS_INLINE inline __attribute__((always_inline)) +#else +# define FMT_ALWAYS_INLINE +#endif + +#ifdef FMT_BUILTIN_CLZ +// Optional version of count_digits for better performance on 32-bit platforms. +inline int count_digits(uint32_t n) { + int t = (32 - FMT_BUILTIN_CLZ(n | 1)) * 1233 >> 12; + return t - (n < data::zero_or_powers_of_10_32[t]) + 1; +} +#endif + +template FMT_API std::string grouping_impl(locale_ref loc); +template inline std::string grouping(locale_ref loc) { + return grouping_impl(loc); +} +template <> inline std::string grouping(locale_ref loc) { + return grouping_impl(loc); +} + +template FMT_API Char thousands_sep_impl(locale_ref loc); +template inline Char thousands_sep(locale_ref loc) { + return Char(thousands_sep_impl(loc)); +} +template <> inline wchar_t thousands_sep(locale_ref loc) { + return thousands_sep_impl(loc); +} + +template FMT_API Char decimal_point_impl(locale_ref loc); +template inline Char decimal_point(locale_ref loc) { + return Char(decimal_point_impl(loc)); +} +template <> inline wchar_t decimal_point(locale_ref loc) { + return decimal_point_impl(loc); +} + +// Formats a decimal unsigned integer value writing into buffer. +// add_thousands_sep is called after writing each char to add a thousands +// separator if necessary. +template +inline Char* format_decimal(Char* buffer, UInt value, int num_digits, + F add_thousands_sep) { + FMT_ASSERT(num_digits >= 0, "invalid digit count"); + buffer += num_digits; + Char* end = buffer; + while (value >= 100) { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + auto index = static_cast((value % 100) * 2); + value /= 100; + *--buffer = static_cast(data::digits[index + 1]); + add_thousands_sep(buffer); + *--buffer = static_cast(data::digits[index]); + add_thousands_sep(buffer); + } + if (value < 10) { + *--buffer = static_cast('0' + value); + return end; + } + auto index = static_cast(value * 2); + *--buffer = static_cast(data::digits[index + 1]); + add_thousands_sep(buffer); + *--buffer = static_cast(data::digits[index]); + return end; +} + +template constexpr int digits10() noexcept { + return std::numeric_limits::digits10; +} +template <> constexpr int digits10() noexcept { return 38; } +template <> constexpr int digits10() noexcept { return 38; } + +template +inline Iterator format_decimal(Iterator out, UInt value, int num_digits, + F add_thousands_sep) { + FMT_ASSERT(num_digits >= 0, "invalid digit count"); + // Buffer should be large enough to hold all digits (<= digits10 + 1). + enum { max_size = digits10() + 1 }; + Char buffer[2 * max_size]; + auto end = format_decimal(buffer, value, num_digits, add_thousands_sep); + return internal::copy_str(buffer, end, out); +} + +template +inline It format_decimal(It out, UInt value, int num_digits) { + return format_decimal(out, value, num_digits, [](Char*) {}); +} + +template +inline Char* format_uint(Char* buffer, UInt value, int num_digits, + bool upper = false) { + buffer += num_digits; + Char* end = buffer; + do { + const char* digits = upper ? "0123456789ABCDEF" : data::hex_digits; + unsigned digit = (value & ((1 << BASE_BITS) - 1)); + *--buffer = static_cast(BASE_BITS < 4 ? static_cast('0' + digit) + : digits[digit]); + } while ((value >>= BASE_BITS) != 0); + return end; +} + +template +Char* format_uint(Char* buffer, internal::fallback_uintptr n, int num_digits, + bool = false) { + auto char_digits = std::numeric_limits::digits / 4; + int start = (num_digits + char_digits - 1) / char_digits - 1; + if (int start_digits = num_digits % char_digits) { + unsigned value = n.value[start--]; + buffer = format_uint(buffer, value, start_digits); + } + for (; start >= 0; --start) { + unsigned value = n.value[start]; + buffer += char_digits; + auto p = buffer; + for (int i = 0; i < char_digits; ++i) { + unsigned digit = (value & ((1 << BASE_BITS) - 1)); + *--p = static_cast(data::hex_digits[digit]); + value >>= BASE_BITS; + } + } + return buffer; +} + +template +inline It format_uint(It out, UInt value, int num_digits, bool upper = false) { + // Buffer should be large enough to hold all digits (digits / BASE_BITS + 1). + char buffer[num_bits() / BASE_BITS + 1]; + format_uint(buffer, value, num_digits, upper); + return internal::copy_str(buffer, buffer + num_digits, out); +} + +template struct null {}; + +// Workaround an array initialization issue in gcc 4.8. +template struct fill_t { + private: + Char data_[6]; + + public: + FMT_CONSTEXPR Char& operator[](size_t index) { return data_[index]; } + FMT_CONSTEXPR const Char& operator[](size_t index) const { + return data_[index]; + } + + static FMT_CONSTEXPR fill_t make() { + auto fill = fill_t(); + fill[0] = Char(' '); + return fill; + } +}; +} // namespace internal + +// We cannot use enum classes as bit fields because of a gcc bug +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61414. +namespace align { +enum type { none, left, right, center, numeric }; +} +using align_t = align::type; + +namespace sign { +enum type { none, minus, plus, space }; +} +using sign_t = sign::type; + +// Format specifiers for built-in and string types. +template struct basic_format_specs { + int width; + int precision; + char type; + align_t align : 4; + sign_t sign : 3; + bool alt : 1; // Alternate form ('#'). + internal::fill_t fill; + char thousands; + + constexpr basic_format_specs() + : width(0), + precision(-1), + type(0), + align(align::none), + sign(sign::none), + alt(false), + fill(internal::fill_t::make()), + thousands('\0'){} +}; + +using format_specs = basic_format_specs; + +namespace internal { + +// A floating-point presentation format. +enum class float_format : unsigned char { + general, // General: exponent notation or fixed point based on magnitude. + exp, // Exponent notation with the default precision of 6, e.g. 1.2e-3. + fixed, // Fixed point with the default precision of 6, e.g. 0.0012. + hex +}; + +struct float_specs { + int precision; + float_format format : 8; + sign_t sign : 8; + bool upper : 1; + bool locale : 1; + bool percent : 1; + bool binary32 : 1; + bool use_grisu : 1; + bool trailing_zeros : 1; +}; + +// Writes the exponent exp in the form "[+-]d{2,3}" to buffer. +template It write_exponent(int exp, It it) { + FMT_ASSERT(-10000 < exp && exp < 10000, "exponent out of range"); + if (exp < 0) { + *it++ = static_cast('-'); + exp = -exp; + } else { + *it++ = static_cast('+'); + } + if (exp >= 100) { + const char* top = data::digits + (exp / 100) * 2; + if (exp >= 1000) *it++ = static_cast(top[0]); + *it++ = static_cast(top[1]); + exp %= 100; + } + const char* d = data::digits + exp * 2; + *it++ = static_cast(d[0]); + *it++ = static_cast(d[1]); + return it; +} + +template class float_writer { + private: + // The number is given as v = digits_ * pow(10, exp_). + const char* digits_; + int num_digits_; + int exp_; + size_t size_; + float_specs specs_; + Char decimal_point_; + + template It prettify(It it) const { + // pow(10, full_exp - 1) <= v <= pow(10, full_exp). + int full_exp = num_digits_ + exp_; + if (specs_.format == float_format::exp) { + // Insert a decimal point after the first digit and add an exponent. + *it++ = static_cast(*digits_); + int num_zeros = specs_.precision - num_digits_; + bool trailing_zeros = num_zeros > 0 && specs_.trailing_zeros; + if (num_digits_ > 1 || trailing_zeros) *it++ = decimal_point_; + it = copy_str(digits_ + 1, digits_ + num_digits_, it); + if (trailing_zeros) + it = std::fill_n(it, num_zeros, static_cast('0')); + *it++ = static_cast(specs_.upper ? 'E' : 'e'); + return write_exponent(full_exp - 1, it); + } + if (num_digits_ <= full_exp) { + // 1234e7 -> 12340000000[.0+] + it = copy_str(digits_, digits_ + num_digits_, it); + it = std::fill_n(it, full_exp - num_digits_, static_cast('0')); + if (specs_.trailing_zeros) { + *it++ = decimal_point_; + int num_zeros = specs_.precision - full_exp; + if (num_zeros <= 0) { + if (specs_.format != float_format::fixed) + *it++ = static_cast('0'); + return it; + } +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + if (num_zeros > 1000) + throw std::runtime_error("fuzz mode - avoiding excessive cpu use"); +#endif + it = std::fill_n(it, num_zeros, static_cast('0')); + } + } else if (full_exp > 0) { + // 1234e-2 -> 12.34[0+] + it = copy_str(digits_, digits_ + full_exp, it); + if (!specs_.trailing_zeros) { + // Remove trailing zeros. + int num_digits = num_digits_; + while (num_digits > full_exp && digits_[num_digits - 1] == '0') + --num_digits; + if (num_digits != full_exp) *it++ = decimal_point_; + return copy_str(digits_ + full_exp, digits_ + num_digits, it); + } + *it++ = decimal_point_; + it = copy_str(digits_ + full_exp, digits_ + num_digits_, it); + if (specs_.precision > num_digits_) { + // Add trailing zeros. + int num_zeros = specs_.precision - num_digits_; + it = std::fill_n(it, num_zeros, static_cast('0')); + } + } else { + // 1234e-6 -> 0.001234 + *it++ = static_cast('0'); + int num_zeros = -full_exp; + if (specs_.precision >= 0 && specs_.precision < num_zeros) + num_zeros = specs_.precision; + int num_digits = num_digits_; + if (!specs_.trailing_zeros) + while (num_digits > 0 && digits_[num_digits - 1] == '0') --num_digits; + if (num_zeros != 0 || num_digits != 0) { + *it++ = decimal_point_; + it = std::fill_n(it, num_zeros, static_cast('0')); + it = copy_str(digits_, digits_ + num_digits, it); + } + } + return it; + } + + public: + float_writer(const char* digits, int num_digits, int exp, float_specs specs, + Char decimal_point) + : digits_(digits), + num_digits_(num_digits), + exp_(exp), + specs_(specs), + decimal_point_(decimal_point) { + int full_exp = num_digits + exp - 1; + int precision = specs.precision > 0 ? specs.precision : 16; + if (specs_.format == float_format::general && + !(full_exp >= -4 && full_exp < precision)) { + specs_.format = float_format::exp; + } + size_ = prettify(counting_iterator()).count(); + size_ += specs.sign ? 1 : 0; + } + + size_t size() const { return size_; } + size_t width() const { return size(); } + + template void operator()(It&& it) { + if (specs_.sign) *it++ = static_cast(data::signs[specs_.sign]); + it = prettify(it); + } +}; + +template +int format_float(T value, int precision, float_specs specs, buffer& buf); + +// Formats a floating-point number with snprintf. +template +int snprintf_float(T value, int precision, float_specs specs, + buffer& buf); + +template T promote_float(T value) { return value; } +inline double promote_float(float value) { return value; } + +template +FMT_CONSTEXPR void handle_int_type_spec(const Spec& specs, Handler&& handler) { + if (specs.thousands != '\0') { + handler.on_num(); + return; + } + switch (specs.type) { + case 0: + case 'd': + handler.on_dec(); + break; + case 'x': + case 'X': + handler.on_hex(); + break; + case 'b': + case 'B': + handler.on_bin(); + break; + case 'o': + handler.on_oct(); + break; + case 'n': + case 'l': + case 'L': + handler.on_num(); + break; + default: + handler.on_error("Invalid type specifier \"" + std::string(1, specs.type) + "\" for formatting a value of type int"); + } +} + +template +FMT_CONSTEXPR float_specs parse_float_type_spec( + const basic_format_specs& specs, ErrorHandler&& eh = {}) { + + auto result = float_specs(); + if (specs.thousands != '\0') { + eh.on_error("Thousand separators are not supported for floating point numbers"); + return result; + } + result.trailing_zeros = specs.alt; + switch (specs.type) { + case 0: + result.format = float_format::general; + result.trailing_zeros |= specs.precision != 0; + break; + case 'G': + result.upper = true; + FMT_FALLTHROUGH; + case 'g': + result.format = float_format::general; + break; + case 'E': + result.upper = true; + FMT_FALLTHROUGH; + case 'e': + result.format = float_format::exp; + result.trailing_zeros |= specs.precision != 0; + break; + case 'F': + result.upper = true; + FMT_FALLTHROUGH; + case 'f': + result.format = float_format::fixed; + result.trailing_zeros |= specs.precision != 0; + break; +#if FMT_DEPRECATED_PERCENT + case '%': + result.format = float_format::fixed; + result.percent = true; + break; +#endif + case 'A': + result.upper = true; + FMT_FALLTHROUGH; + case 'a': + result.format = float_format::hex; + break; + case 'n': + case 'l': + case 'L': + result.locale = true; + break; + default: + eh.on_error("Invalid type specifier \"" + std::string(1, specs.type) + "\" for formatting a value of type float"); + break; + } + return result; +} + +template +FMT_CONSTEXPR void handle_char_specs(const basic_format_specs* specs, + Handler&& handler) { + if (!specs) return handler.on_char(); + if (specs->type && specs->type != 'c') return handler.on_int(); + if (specs->align == align::numeric || specs->sign != sign::none || specs->alt) + handler.on_error("invalid format specifier for char"); + handler.on_char(); +} + +template +FMT_CONSTEXPR void handle_cstring_type_spec(Char spec, Handler&& handler) { + if (spec == 0 || spec == 's') + handler.on_string(); + else if (spec == 'p') + handler.on_pointer(); + else + handler.on_error("Invalid type specifier \"" + std::string(1, spec) + "\" for formatting a value of type string"); +} + +template +FMT_CONSTEXPR void check_string_type_spec(Char spec, ErrorHandler&& eh) { + if (spec != 0 && spec != 's') eh.on_error("Invalid type specifier \"" + std::string(1, spec) + "\" for formatting a value of type string"); +} + +template +FMT_CONSTEXPR void check_pointer_type_spec(Char spec, ErrorHandler&& eh) { + if (spec != 0 && spec != 'p') eh.on_error("Invalid type specifier \"" + std::string(1, spec) + "\" for formatting a value of type pointer"); +} + +template class int_type_checker : private ErrorHandler { + public: + FMT_CONSTEXPR explicit int_type_checker(ErrorHandler eh) : ErrorHandler(eh) {} + + FMT_CONSTEXPR void on_dec() {} + FMT_CONSTEXPR void on_hex() {} + FMT_CONSTEXPR void on_bin() {} + FMT_CONSTEXPR void on_oct() {} + FMT_CONSTEXPR void on_num() {} + + FMT_CONSTEXPR void on_error(std::string error) { + ErrorHandler::on_error(error); + } +}; + +template +class char_specs_checker : public ErrorHandler { + private: + char type_; + + public: + FMT_CONSTEXPR char_specs_checker(char type, ErrorHandler eh) + : ErrorHandler(eh), type_(type) {} + + FMT_CONSTEXPR void on_int() { + handle_int_type_spec(type_, int_type_checker(*this)); + } + FMT_CONSTEXPR void on_char() {} +}; + +template +class cstring_type_checker : public ErrorHandler { + public: + FMT_CONSTEXPR explicit cstring_type_checker(ErrorHandler eh) + : ErrorHandler(eh) {} + + FMT_CONSTEXPR void on_string() {} + FMT_CONSTEXPR void on_pointer() {} +}; + +template +void arg_map::init(const basic_format_args& args) { + if (map_) return; + map_ = new entry[internal::to_unsigned(args.max_size())]; + if (args.is_packed()) { + for (int i = 0;; ++i) { + internal::type arg_type = args.type(i); + if (arg_type == internal::none_type) return; + if (arg_type == internal::named_arg_type) push_back(args.values_[i]); + } + } + for (int i = 0, n = args.max_size(); i < n; ++i) { + auto type = args.args_[i].type_; + if (type == internal::named_arg_type) push_back(args.args_[i].value_); + } +} + +template struct nonfinite_writer { + sign_t sign; + const char* str; + static constexpr size_t str_size = 3; + + size_t size() const { return str_size + (sign ? 1 : 0); } + size_t width() const { return size(); } + + template void operator()(It&& it) const { + if (sign) *it++ = static_cast(data::signs[sign]); + it = copy_str(str, str + str_size, it); + } +}; + +// This template provides operations for formatting and writing data into a +// character range. +template class basic_writer { + public: + using char_type = typename Range::value_type; + using iterator = typename Range::iterator; + using format_specs = basic_format_specs; + + private: + iterator out_; // Output iterator. + locale_ref locale_; + + // Attempts to reserve space for n extra characters in the output range. + // Returns a pointer to the reserved range or a reference to out_. + auto reserve(std::size_t n) -> decltype(internal::reserve(out_, n)) { + return internal::reserve(out_, n); + } + + template struct padded_int_writer { + size_t size_; + string_view prefix; + char_type fill; + std::size_t padding; + F f; + + size_t size() const { return size_; } + size_t width() const { return size_; } + + template void operator()(It&& it) const { + if (prefix.size() != 0) + it = copy_str(prefix.begin(), prefix.end(), it); + it = std::fill_n(it, padding, fill); + f(it); + } + }; + + // Writes an integer in the format + // + // where are written by f(it). + template + void write_int(int num_digits, string_view prefix, format_specs specs, F f) { + std::size_t size = prefix.size() + to_unsigned(num_digits); + char_type fill = specs.fill[0]; + std::size_t padding = 0; + if (specs.align == align::numeric) { + auto unsiged_width = to_unsigned(specs.width); + if (unsiged_width > size) { + padding = unsiged_width - size; + size = unsiged_width; + } + } else if (specs.precision > num_digits) { + size = prefix.size() + to_unsigned(specs.precision); + padding = to_unsigned(specs.precision - num_digits); + fill = static_cast('0'); + } + if (specs.align == align::none) specs.align = align::right; + write_padded(specs, padded_int_writer{size, prefix, fill, padding, f}); + } + + // Writes a decimal integer. + template void write_decimal(Int value) { + auto abs_value = static_cast>(value); + bool negative = is_negative(value); + // Don't do -abs_value since it trips unsigned-integer-overflow sanitizer. + if (negative) abs_value = ~abs_value + 1; + int num_digits = count_digits(abs_value); + auto&& it = reserve((negative ? 1 : 0) + static_cast(num_digits)); + if (negative) *it++ = static_cast('-'); + it = format_decimal(it, abs_value, num_digits); + } + + // The handle_int_type_spec handler that writes an integer. + template struct int_writer { + using unsigned_type = uint32_or_64_or_128_t; + + basic_writer& writer; + const Specs& specs; + unsigned_type abs_value; + char prefix[4]; + unsigned prefix_size; + + string_view get_prefix() const { return string_view(prefix, prefix_size); } + + int_writer(basic_writer& w, Int value, const Specs& s) + : writer(w), + specs(s), + abs_value(static_cast(value)), + prefix_size(0) { + if (is_negative(value)) { + prefix[0] = '-'; + ++prefix_size; + abs_value = 0 - abs_value; + } else if (specs.sign != sign::none && specs.sign != sign::minus) { + prefix[0] = specs.sign == sign::plus ? '+' : ' '; + ++prefix_size; + } + } + + struct dec_writer { + unsigned_type abs_value; + int num_digits; + + template void operator()(It&& it) const { + it = internal::format_decimal(it, abs_value, num_digits); + } + }; + + void on_dec() { + int num_digits = count_digits(abs_value); + writer.write_int(num_digits, get_prefix(), specs, + dec_writer{abs_value, num_digits}); + } + + struct hex_writer { + int_writer& self; + int num_digits; + + template void operator()(It&& it) const { + it = format_uint<4, char_type>(it, self.abs_value, num_digits, + self.specs.type != 'x'); + } + }; + + void on_hex() { + if (specs.alt) { + prefix[prefix_size++] = '0'; + prefix[prefix_size++] = specs.type; + } + int num_digits = count_digits<4>(abs_value); + writer.write_int(num_digits, get_prefix(), specs, + hex_writer{*this, num_digits}); + } + + template struct bin_writer { + unsigned_type abs_value; + int num_digits; + + template void operator()(It&& it) const { + it = format_uint(it, abs_value, num_digits); + } + }; + + void on_bin() { + if (specs.alt) { + prefix[prefix_size++] = '0'; + prefix[prefix_size++] = static_cast(specs.type); + } + int num_digits = count_digits<1>(abs_value); + writer.write_int(num_digits, get_prefix(), specs, + bin_writer<1>{abs_value, num_digits}); + } + + void on_oct() { + int num_digits = count_digits<3>(abs_value); + if (specs.alt && specs.precision <= num_digits && abs_value != 0) { + // Octal prefix '0' is counted as a digit, so only add it if precision + // is not greater than the number of digits. + prefix[prefix_size++] = '0'; + } + writer.write_int(num_digits, get_prefix(), specs, + bin_writer<3>{abs_value, num_digits}); + } + + enum { sep_size = 1 }; + + struct num_writer { + unsigned_type abs_value; + int size; + const std::string& groups; + char_type sep; + + template void operator()(It&& it) const { + basic_string_view s(&sep, sep_size); + // Index of a decimal digit with the least significant digit having + // index 0. + int digit_index = 0; + std::string::const_iterator group = groups.cbegin(); + it = format_decimal( + it, abs_value, size, + [this, s, &group, &digit_index](char_type*& buffer) { + if (*group <= 0 || ++digit_index % *group != 0 || + *group == max_value()) + return; + if (group + 1 != groups.cend()) { + digit_index = 0; + ++group; + } + buffer -= s.size(); + std::uninitialized_copy(s.data(), s.data() + s.size(), + make_checked(buffer, s.size())); + }); + } + }; + + void on_num() { + std::string groups = grouping(writer.locale_); + if (groups.empty()) return on_dec(); + auto sep = specs.thousands; + if (!sep) return on_dec(); + int num_digits = count_digits(abs_value); + int size = num_digits; + std::string::const_iterator group = groups.cbegin(); + while (group != groups.cend() && num_digits > *group && *group > 0 && + *group != max_value()) { + size += sep_size; + num_digits -= *group; + ++group; + } + if (group == groups.cend()) + size += sep_size * ((num_digits - 1) / groups.back()); + writer.write_int(size, get_prefix(), specs, + num_writer{abs_value, size, groups, static_cast(sep)}); + } + + FMT_NORETURN void on_error(std::string error) { + FMT_THROW(duckdb::Exception(error)); + } + }; + + template struct str_writer { + const Char* s; + size_t size_; + + size_t size() const { return size_; } + size_t width() const { + return count_code_points(basic_string_view(s, size_)); + } + + template void operator()(It&& it) const { + it = copy_str(s, s + size_, it); + } + }; + + template struct pointer_writer { + UIntPtr value; + int num_digits; + + size_t size() const { return to_unsigned(num_digits) + 2; } + size_t width() const { return size(); } + + template void operator()(It&& it) const { + *it++ = static_cast('0'); + *it++ = static_cast('x'); + it = format_uint<4, char_type>(it, value, num_digits); + } + }; + + public: + explicit basic_writer(Range out, locale_ref loc = locale_ref()) + : out_(out.begin()), locale_(loc) {} + + iterator out() const { return out_; } + + // Writes a value in the format + // + // where is written by f(it). + template void write_padded(const format_specs& specs, F&& f) { + // User-perceived width (in code points). + unsigned width = to_unsigned(specs.width); + size_t size = f.size(); // The number of code units. + size_t num_code_points = width != 0 ? f.width() : size; + if (width <= num_code_points) return f(reserve(size)); + auto&& it = reserve(width + (size - num_code_points)); + char_type fill = specs.fill[0]; + std::size_t padding = width - num_code_points; + if (specs.align == align::right) { + it = std::fill_n(it, padding, fill); + f(it); + } else if (specs.align == align::center) { + std::size_t left_padding = padding / 2; + it = std::fill_n(it, left_padding, fill); + f(it); + it = std::fill_n(it, padding - left_padding, fill); + } else { + f(it); + it = std::fill_n(it, padding, fill); + } + } + + void write(int value) { write_decimal(value); } + void write(long value) { write_decimal(value); } + void write(long long value) { write_decimal(value); } + + void write(unsigned value) { write_decimal(value); } + void write(unsigned long value) { write_decimal(value); } + void write(unsigned long long value) { write_decimal(value); } + +#if FMT_USE_INT128 + void write(int128_t value) { write_decimal(value); } + void write(uint128_t value) { write_decimal(value); } +#endif + + template + void write_int(T value, const Spec& spec) { + handle_int_type_spec(spec, int_writer(*this, value, spec)); + } + + template ::value)> + void write(T value, format_specs specs = {}) { + float_specs fspecs = parse_float_type_spec(specs); + fspecs.sign = specs.sign; + if (std::signbit(value)) { // value < 0 is false for NaN so use signbit. + fspecs.sign = sign::minus; + value = -value; + } else if (fspecs.sign == sign::minus) { + fspecs.sign = sign::none; + } + + if (!std::isfinite(value)) { + auto str = std::isinf(value) ? (fspecs.upper ? "INF" : "inf") + : (fspecs.upper ? "NAN" : "nan"); + return write_padded(specs, nonfinite_writer{fspecs.sign, str}); + } + + if (specs.align == align::none) { + specs.align = align::right; + } else if (specs.align == align::numeric) { + if (fspecs.sign) { + auto&& it = reserve(1); + *it++ = static_cast(data::signs[fspecs.sign]); + fspecs.sign = sign::none; + if (specs.width != 0) --specs.width; + } + specs.align = align::right; + } + + memory_buffer buffer; + if (fspecs.format == float_format::hex) { + if (fspecs.sign) buffer.push_back(data::signs[fspecs.sign]); + snprintf_float(promote_float(value), specs.precision, fspecs, buffer); + write_padded(specs, str_writer{buffer.data(), buffer.size()}); + return; + } + int precision = specs.precision >= 0 || !specs.type ? specs.precision : 6; + if (fspecs.format == float_format::exp) ++precision; + if (const_check(std::is_same())) fspecs.binary32 = true; + fspecs.use_grisu = use_grisu(); + if (const_check(FMT_DEPRECATED_PERCENT) && fspecs.percent) value *= 100; + int exp = format_float(promote_float(value), precision, fspecs, buffer); + if (const_check(FMT_DEPRECATED_PERCENT) && fspecs.percent) { + buffer.push_back('%'); + --exp; // Adjust decimal place position. + } + fspecs.precision = precision; + char_type point = fspecs.locale ? decimal_point(locale_) + : static_cast('.'); + write_padded(specs, float_writer(buffer.data(), + static_cast(buffer.size()), + exp, fspecs, point)); + } + + void write(char value) { + auto&& it = reserve(1); + *it++ = value; + } + + template ::value)> + void write(Char value) { + auto&& it = reserve(1); + *it++ = value; + } + + void write(string_view value) { + auto&& it = reserve(value.size()); + it = copy_str(value.begin(), value.end(), it); + } + void write(wstring_view value) { + static_assert(std::is_same::value, ""); + auto&& it = reserve(value.size()); + it = std::copy(value.begin(), value.end(), it); + } + + template + void write(const Char* s, std::size_t size, const format_specs& specs) { + write_padded(specs, str_writer{s, size}); + } + + template + void write(basic_string_view s, const format_specs& specs = {}) { + const Char* data = s.data(); + std::size_t size = s.size(); + if (specs.precision >= 0 && to_unsigned(specs.precision) < size) + size = code_point_index(s, to_unsigned(specs.precision)); + write(data, size, specs); + } + + template + void write_pointer(UIntPtr value, const format_specs* specs) { + int num_digits = count_digits<4>(value); + auto pw = pointer_writer{value, num_digits}; + if (!specs) return pw(reserve(to_unsigned(num_digits) + 2)); + format_specs specs_copy = *specs; + if (specs_copy.align == align::none) specs_copy.align = align::right; + write_padded(specs_copy, pw); + } +}; + +using writer = basic_writer>; + +template struct is_integral : std::is_integral {}; +template <> struct is_integral : std::true_type {}; +template <> struct is_integral : std::true_type {}; + +template +class arg_formatter_base { + public: + using char_type = typename Range::value_type; + using iterator = typename Range::iterator; + using format_specs = basic_format_specs; + + private: + using writer_type = basic_writer; + writer_type writer_; + format_specs* specs_; + + struct char_writer { + char_type value; + + size_t size() const { return 1; } + size_t width() const { return 1; } + + template void operator()(It&& it) const { *it++ = value; } + }; + + void write_char(char_type value) { + if (specs_) + writer_.write_padded(*specs_, char_writer{value}); + else + writer_.write(value); + } + + void write_pointer(const void* p) { + writer_.write_pointer(internal::to_uintptr(p), specs_); + } + + protected: + writer_type& writer() { return writer_; } + FMT_DEPRECATED format_specs* spec() { return specs_; } + format_specs* specs() { return specs_; } + iterator out() { return writer_.out(); } + + void write(bool value) { + string_view sv(value ? "true" : "false"); + specs_ ? writer_.write(sv, *specs_) : writer_.write(sv); + } + + void write(const char_type* value) { + if (!value) { + FMT_THROW(duckdb::Exception("string pointer is null")); + } else { + auto length = std::char_traits::length(value); + basic_string_view sv(value, length); + specs_ ? writer_.write(sv, *specs_) : writer_.write(sv); + } + } + + public: + arg_formatter_base(Range r, format_specs* s, locale_ref loc) + : writer_(r, loc), specs_(s) {} + + iterator operator()(monostate) { + FMT_ASSERT(false, "invalid argument type"); + return out(); + } + + template ::value)> + iterator operator()(T value) { + if (specs_) + writer_.write_int(value, *specs_); + else + writer_.write(value); + return out(); + } + + iterator operator()(char_type value) { + internal::handle_char_specs( + specs_, char_spec_handler(*this, static_cast(value))); + return out(); + } + + iterator operator()(bool value) { + if (specs_ && specs_->type) return (*this)(value ? 1 : 0); + write(value != 0); + return out(); + } + + template ::value)> + iterator operator()(T value) { + writer_.write(value, specs_ ? *specs_ : format_specs()); + return out(); + } + + struct char_spec_handler : ErrorHandler { + arg_formatter_base& formatter; + char_type value; + + char_spec_handler(arg_formatter_base& f, char_type val) + : formatter(f), value(val) {} + + void on_int() { + if (formatter.specs_) + formatter.writer_.write_int(value, *formatter.specs_); + else + formatter.writer_.write(value); + } + void on_char() { formatter.write_char(value); } + }; + + struct cstring_spec_handler : internal::error_handler { + arg_formatter_base& formatter; + const char_type* value; + + cstring_spec_handler(arg_formatter_base& f, const char_type* val) + : formatter(f), value(val) {} + + void on_string() { formatter.write(value); } + void on_pointer() { formatter.write_pointer(value); } + }; + + iterator operator()(const char_type* value) { + if (!specs_) return write(value), out(); + internal::handle_cstring_type_spec(specs_->type, + cstring_spec_handler(*this, value)); + return out(); + } + + iterator operator()(basic_string_view value) { + if (specs_) { + internal::check_string_type_spec(specs_->type, internal::error_handler()); + writer_.write(value, *specs_); + } else { + writer_.write(value); + } + return out(); + } + + iterator operator()(const void* value) { + if (specs_) + check_pointer_type_spec(specs_->type, internal::error_handler()); + write_pointer(value); + return out(); + } +}; + +template FMT_CONSTEXPR bool is_name_start(Char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || '_' == c; +} + +// Parses the range [begin, end) as an unsigned integer. This function assumes +// that the range is non-empty and the first character is a digit. +template +FMT_CONSTEXPR int parse_nonnegative_int(const Char*& begin, const Char* end, + ErrorHandler&& eh) { + FMT_ASSERT(begin != end && '0' <= *begin && *begin <= '9', ""); + if (*begin == '0') { + ++begin; + return 0; + } + unsigned value = 0; + // Convert to unsigned to prevent a warning. + constexpr unsigned max_int = max_value(); + unsigned big = max_int / 10; + do { + // Check for overflow. + if (value > big) { + value = max_int + 1; + break; + } + value = value * 10 + unsigned(*begin - '0'); + ++begin; + } while (begin != end && '0' <= *begin && *begin <= '9'); + if (value > max_int) eh.on_error("number is too big"); + return static_cast(value); +} + +template class custom_formatter { + private: + using char_type = typename Context::char_type; + + basic_format_parse_context& parse_ctx_; + Context& ctx_; + + public: + explicit custom_formatter(basic_format_parse_context& parse_ctx, + Context& ctx) + : parse_ctx_(parse_ctx), ctx_(ctx) {} + + bool operator()(typename basic_format_arg::handle h) const { + h.format(parse_ctx_, ctx_); + return true; + } + + template bool operator()(T) const { return false; } +}; + +template +using is_integer = + bool_constant::value && !std::is_same::value && + !std::is_same::value && + !std::is_same::value>; + +template class width_checker { + public: + explicit FMT_CONSTEXPR width_checker(ErrorHandler& eh) : handler_(eh) {} + + template ::value)> + FMT_CONSTEXPR unsigned long long operator()(T value) { + if (is_negative(value)) handler_.on_error("negative width"); + return static_cast(value); + } + + template ::value)> + FMT_CONSTEXPR unsigned long long operator()(T) { + handler_.on_error("width is not integer"); + return 0; + } + + private: + ErrorHandler& handler_; +}; + +template class precision_checker { + public: + explicit FMT_CONSTEXPR precision_checker(ErrorHandler& eh) : handler_(eh) {} + + template ::value)> + FMT_CONSTEXPR unsigned long long operator()(T value) { + if (is_negative(value)) handler_.on_error("negative precision"); + return static_cast(value); + } + + template ::value)> + FMT_CONSTEXPR unsigned long long operator()(T) { + handler_.on_error("precision is not integer"); + return 0; + } + + private: + ErrorHandler& handler_; +}; + +// A format specifier handler that sets fields in basic_format_specs. +template class specs_setter { + public: + explicit FMT_CONSTEXPR specs_setter(basic_format_specs& specs) + : specs_(specs) {} + + FMT_CONSTEXPR specs_setter(const specs_setter& other) + : specs_(other.specs_) {} + + FMT_CONSTEXPR void on_align(align_t align) { specs_.align = align; } + FMT_CONSTEXPR void on_fill(Char fill) { specs_.fill[0] = fill; } + FMT_CONSTEXPR void on_plus() { specs_.sign = sign::plus; } + FMT_CONSTEXPR void on_minus() { specs_.sign = sign::minus; } + FMT_CONSTEXPR void on_space() { specs_.sign = sign::space; } + FMT_CONSTEXPR void on_comma() { specs_.thousands = ','; } + FMT_CONSTEXPR void on_underscore() { specs_.thousands = '_'; } + FMT_CONSTEXPR void on_single_quote() { specs_.thousands = '\''; } + FMT_CONSTEXPR void on_thousands(char sep) { specs_.thousands = sep; } + FMT_CONSTEXPR void on_hash() { specs_.alt = true; } + + FMT_CONSTEXPR void on_zero() { + specs_.align = align::numeric; + specs_.fill[0] = Char('0'); + } + + FMT_CONSTEXPR void on_width(int width) { specs_.width = width; } + FMT_CONSTEXPR void on_precision(int precision) { + specs_.precision = precision; + } + FMT_CONSTEXPR void end_precision() {} + + FMT_CONSTEXPR void on_type(Char type) { + specs_.type = static_cast(type); + } + + protected: + basic_format_specs& specs_; +}; + +template class numeric_specs_checker { + public: + FMT_CONSTEXPR numeric_specs_checker(ErrorHandler& eh, internal::type arg_type) + : error_handler_(eh), arg_type_(arg_type) {} + + FMT_CONSTEXPR void require_numeric_argument() { + if (!is_arithmetic_type(arg_type_)) + error_handler_.on_error("format specifier requires numeric argument"); + } + + FMT_CONSTEXPR void check_sign() { + require_numeric_argument(); + if (is_integral_type(arg_type_) && arg_type_ != int_type && + arg_type_ != long_long_type && arg_type_ != internal::char_type) { + error_handler_.on_error("format specifier requires signed argument"); + } + } + + FMT_CONSTEXPR void check_precision() { + if (is_integral_type(arg_type_) || arg_type_ == internal::pointer_type) + error_handler_.on_error("precision not allowed for this argument type"); + } + + private: + ErrorHandler& error_handler_; + internal::type arg_type_; +}; + +// A format specifier handler that checks if specifiers are consistent with the +// argument type. +template class specs_checker : public Handler { + public: + FMT_CONSTEXPR specs_checker(const Handler& handler, internal::type arg_type) + : Handler(handler), checker_(*this, arg_type) {} + + FMT_CONSTEXPR specs_checker(const specs_checker& other) + : Handler(other), checker_(*this, other.arg_type_) {} + + FMT_CONSTEXPR void on_align(align_t align) { + if (align == align::numeric) checker_.require_numeric_argument(); + Handler::on_align(align); + } + + FMT_CONSTEXPR void on_plus() { + checker_.check_sign(); + Handler::on_plus(); + } + + FMT_CONSTEXPR void on_minus() { + checker_.check_sign(); + Handler::on_minus(); + } + + FMT_CONSTEXPR void on_space() { + checker_.check_sign(); + Handler::on_space(); + } + + FMT_CONSTEXPR void on_hash() { + checker_.require_numeric_argument(); + Handler::on_hash(); + } + + FMT_CONSTEXPR void on_zero() { + checker_.require_numeric_argument(); + Handler::on_zero(); + } + + FMT_CONSTEXPR void end_precision() { checker_.check_precision(); } + + private: + numeric_specs_checker checker_; +}; + +template